草庐IT

小熊飞桨练习册-01手写数字识别

小熊宝宝啊 2023-03-28 原文

小熊飞桨练习册-01手写数字识别

简介

小熊飞桨练习册-01手写数字识别,本项目开发和测试均在 Ubuntu 20.04 系统下进行。
项目最新代码查看主页:小熊飞桨练习册
百度飞桨 AI Studio 主页:小熊飞桨练习册-01手写数字识别
Ubuntu 系统安装 CUDA 参考:Ubuntu 百度飞桨和 CUDA 的安装

文件说明

文件 说明
train.py 训练程序
test.py 测试程序
report.py 报表程序
onekey.sh 一键获取数据到 dataset 目录下
get-data.sh 获取数据到 dataset 目录下
check-data.sh 检查 dataset 目录下的数据是否存在
mod/lenet.py LeNet 网络模型
mod/dataset.py MNIST 手写数据集解析
mod/utils.py 杂项
mod/config.py 配置
mod/report.py 结果报表
dataset 数据集目录
params 模型参数保存目录
log VisualDL 日志保存目录

数据集

数据集来源于百度飞桨公共数据集:经典MNIST数据集

一键获取数据

  • 运行脚本,包含以下步骤:获取数据,检查数据。

如果运行在本地计算机,下载完数据,文件放到 dataset 目录下,在项目目录下运行下面脚本。
如果运行在百度 AI Studio 环境,查看 data 目录是否有数据,在项目目录下运行下面脚本。

bash onekey.sh

获取数据

如果运行在本地计算机,下载完数据,文件放到 dataset 目录下,在项目目录下运行下面脚本。
如果运行在百度 AI Studio 环境,查看 data 目录是否有数据,在项目目录下运行下面脚本。

bash get-data.sh

检查数据

获取数据完毕后,在项目目录下运行下面脚本,检查 dataset 目录下的数据是否存在。

bash check-data.sh

网络模型

网络模型使用 LeNet 网络模型 来源百度飞桨教程和网络
LeNet 网络模型 参考: 百度飞桨教程

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


# LeNet 网络模型
class LeNet(nn.Layer):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        if num_classes < 1:
            raise Exception("分类数量 num_classes 必须大于 0: {}".format(num_classes))
        self.num_classes = num_classes
        self.conv1 = nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.avg_pool1 = nn.AvgPool2D(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2D(
            in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.avg_pool2 = nn.AvgPool2D(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2D(
            in_channels=16, out_channels=120, kernel_size=4, stride=1)
        self.fc1 = nn.Linear(in_features=120, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.avg_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.avg_pool2(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

数据集解析

数据集解析方法来源百度飞桨教程和网络,和百度飞桨 MNIST 数据集稍有不同

import paddle
import os
import struct
import numpy as np


class MNIST(paddle.io.Dataset):
    """
    MNIST 手写数据集解析, 继承 paddle.io.Dataset 类
    """

    def __init__(self,
                 images_path: str,
                 labels_path: str,
                 transform=None,
                 ):
        """
        构造函数,定义数据集大小

        Args:
            images_path (str): 图像集路径
            labels_path (str): 标签集路径
            transform (Compose, optional): 转换数据的操作组合, 默认 None
        """
        super(MNIST, self).__init__()
        self.images_path = images_path
        self.labels_path = labels_path
        self._check_path(images_path, "数据路径错误")
        self._check_path(labels_path, "标签路径错误")
        self.transform = transform
        self.images, self.labels = self.parse_dataset(images_path, labels_path)

    def __getitem__(self, idx):
        """
        获取单个数据和标签

        Args:
            idx (Any): 索引

        Returns:
            image (float32): 图像
            label (int64): 标签
        """
        image, label = self.images[idx], self.labels[idx]
        # 这里 reshape 是2维 [28 ,28]
        image = np.reshape(image, [28, 28])
        if self.transform is not None:
            image = self.transform(image)
        # label.astype 如果是整型,只能是 int64
        return image.astype('float32'), label.astype('int64')

    def __len__(self):
        """
        数据数量

        Returns:
            int: 数据数量
        """
        return len(self.labels)

    def _check_path(self, path: str, msg: str):
        """
        检查路径是否存在

        Args:
            path (str): 路径
            msg (str, optional): 异常消息

        Raises:
            Exception: 路径错误, 异常
        """
        if not os.path.exists(path):
            raise Exception("{}: {}".format(msg, path))

    @staticmethod
    def parse_dataset(images_path: str, labels_path: str):
        """
        数据集解析

        Args:
            images_path (str): 图像集路径
            labels_path (str): 标签集路径

        Returns:
            images: 图像集
            labels: 标签集
        """
        with open(images_path, 'rb') as imgpath:
            # 解析图像集
            magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
            # 这里 reshape 是1维 [786]
            images = np.fromfile(
                imgpath, dtype=np.uint8).reshape(num, rows * cols)
        with open(labels_path, 'rb') as lbpath:
            # 解析标签集
            magic, n = struct.unpack('>II', lbpath.read(8))
            labels = np.fromfile(lbpath, dtype=np.uint8)
        return images, labels

配置模块

可以查看修改 mod/config.py 文件,有详细的说明

开始训练

运行 train.py 文件,查看命令行参数加 -h

python3 train.py
  --cpu             是否使用 cpu 计算,默认使用 CUDA
  --learning-rate   学习率,默认 0.001
  --epochs          训练几轮,默认 2 轮
  --batch-size      一批次数量,默认 128
  --num-workers     线程数量,默认 2
  --no-save         是否保存模型参数,默认保存, 选择后不保存模型参数
  --load-dir        读取模型参数,读取 params 目录下的子文件夹, 默认不读取
  --log             是否输出 VisualDL 日志,默认不输出
  --summary         输出网络模型信息,默认不输出,选择后只输出信息,不会开启训练

测试模型

运行 test.py 文件,查看命令行参数加 -h

python3 test.py
  --cpu           是否使用 cpu 计算,默认使用 CUDA
  --batch-size    一批次数量,默认 128
  --num-workers   线程数量,默认 2
  --load-dir      读取模型参数,读取 params 目录下的子文件夹, 默认 best 目录

查看结果报表

运行 report.py 文件,可以显示 params 目录下所有子目录的 report.json
加参数 --best 根据 loss 最小的模型参数保存在 best 子目录下。

python3 report.py

report.json 说明

键名 说明
id 根据时间生成的字符串 ID
loss 本次训练的 loss 值
acc 本次训练的 acc 值
epochs 本次训练的 epochs 值
batch_size 本次训练的 batch_size 值
learning_rate 本次训练的 learning_rate 值

VisualDL 可视化分析工具

  • 安装和使用说明参考:VisualDL
  • 训练的时候加上参数 --log
  • 如果是 AI Studio 环境训练的把 log 目录下载下来,解压缩后放到本地项目目录下 log 目录
  • 在项目目录下运行下面命令
  • 然后根据提示的网址,打开浏览器访问提示的网址即可
visualdl --logdir ./log

有关小熊飞桨练习册-01手写数字识别的更多相关文章

  1. ruby - 查找字符串中的内容类型(数字、日期、时间、字符串等) - 2

    我正在尝试解析一个CSV文件并使用SQL命令自动为其创建一个表。CSV中的第一行给出了列标题。但我需要推断每个列的类型。Ruby中是否有任何函数可以找到每个字段中内容的类型。例如,CSV行:"12012","Test","1233.22","12:21:22","10/10/2009"应该产生像这样的类型['integer','string','float','time','date']谢谢! 最佳答案 require'time'defto_something(str)if(num=Integer(str)rescueFloat(s

  2. 报告回顾丨模型进化狂飙,DetectGPT能否识别最新模型生成结果? - 2

    导读语言模型给我们的生产生活带来了极大便利,但同时不少人也利用他们从事作弊工作。如何规避这些难辨真伪的文字所产生的负面影响也成为一大难题。在3月9日智源Live第33期活动「DetectGPT:判断文本是否为机器生成的工具」中,主讲人Eric为我们讲解了DetectGPT工作背后的思路——一种基于概率曲率检测的用于检测模型生成文本的工具,它可以帮助我们更好地分辨文章的来源和可信度,对保护信息真实、防止欺诈等方面具有重要意义。本次报告主要围绕其功能,实现和效果等展开。(文末点击“阅读原文”,查看活动回放。)Ericmitchell斯坦福大学计算机系四年级博士生,由ChelseaFinn和Chri

  3. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  4. [Vuforia]二.3D物体识别 - 2

    之前说过10之后的版本没有3dScan了,所以还是9.8的版本或者之前更早的版本。 3d物体扫描需要先下载扫描的APK进行扫面。首先要在手机上装一个扫描程序,扫描现实中的三维物体,然后上传高通官网,在下载成UnityPackage类型让Unity能够使用这个扫描程序可以从高通官网上进行下载,是一个安卓程序。点到Tools往下滑,找到VuforiaObjectScanner下载后解压数据线连接手机,将apk文件拷入手机安装然后刚才解压文件中的Media文件夹打开,两个PDF图打印第一张A4-ObjectScanningTarget.pdf,主要是用来辅助扫描的。好了,接下来就是扫描三维物体。将瓶

  5. ruby-on-rails - 在 heroku 的 .fonts 文件夹中包含自定义字体,似乎无法识别它们 - 2

    Heroku支持人员告诉我,为了在我的Web应用程序中使用自定义字体(未安装在系统中,您可以在bash控制台中使用fc-list查看已安装的字体)我必须部署一个包含所有字体的.fonts文件夹里面的字体。问题是我不知道该怎么做。我的意思是,我不知道文件名是否必须遵循heroku的任何特殊模式,或者我必须在我的代码中做一些事情来考虑这种字体,或者如果我将它包含在文件夹中它是自动的......事实是,我尝试以不同的方式更改字体的文件名,但根本没有使用该字体。为了提供更多详细信息,我们使用字体的过程是将PDF转换为图像,更具体地说,使用rghostgem。并且最终图像根本不使用自定义字体。在

  6. 牛客网专项练习30天Pytnon篇第02天 - 2

    1.在Python3中,下列关于数学运算结果正确的是:(B)a=10b=3print(a//b)print(a%b)print(a/b)A.3,3,3.3333...B.3,1,3.3333...C.3.3333...,3.3333...,3D.3.3333...,1,3.3333...解析:    在Python中,//表示地板除(向下取整),%表示取余,/表示除(Python2向下取整返回3)2.如下程序Python2会打印多少个数:(D)k=1000whilek>1:    print(k)k=k/2A.1000 B.10C.11D.9解析:    按照题意每次循环K/2,直到K值小于等

  7. ruby-on-rails - 没有这样的文件或目录 - 用 Mini Magick 识别 - 2

    在我让另一个人重做我的前端UI之前,我的Rails应用程序运行平稳。我已经尝试解决此错误3天了。这是错误:Nosuchfileordirectory-identifyExtractedsource(aroundline#59):575859606162@post=Post.find(params[:id])authorize@postif@post.update_attributes(post_params)flash[:notice]="Postwasupdated."redirect_to[@topic,@post]else{"utf8"=>"✓","_method"=>"patc

  8. ruby - 将n维数组的每个元素乘以Ruby中的数字 - 2

    在Ruby中,是否有一种简单的方法可以将n维数组中的每个元素乘以一个数字?这样:[1,2,3,4,5].multiplied_by2==[2,4,6,8,10]和[[1,2,3],[1,2,3]].multiplied_by2==[[2,4,6],[2,4,6]]?(很明显,我编写了multiplied_by函数以区别于*,它似乎连接了数组的多个副本,不幸的是这不是我需要的)。谢谢! 最佳答案 它的长格式等价物是:[1,2,3,4,5].collect{|n|n*2}其实并没有那么复杂。你总是可以使你的multiply_by方法:c

  9. Ruby 的数字方法性能 - 2

    我正在使用Ruby解决一些ProjectEuler问题,特别是这里我要讨论的问题25(Fibonacci数列中包含1000位数字的第一项的索引是多少?)。起初,我使用的是Ruby2.2.3,我将问题编码为:number=3a=1b=2whileb.to_s.length但后来我发现2.4.2版本有一个名为digits的方法,这正是我需要的。我转换为代码:whileb.digits.length当我比较这两种方法时,digits慢得多。时间./025/problem025.rb0.13s用户0.02s系统80%cpu0.190总计./025/problem025.rb2.19s用户0.0

  10. ruby - 按数字(从大到大)然后按字母(字母顺序)对对象集合进行排序 - 2

    我正在构建一个小部件来显示奥运会的奖牌数。我有一个“国家”对象的集合,其中每个对象都有一个“名称”属性,以及奖牌计数的“金”、“银”、“铜”。列表应该排序:1.首先是奖牌总数2.如果奖牌相同,按类型分割(金>银>铜,即2金>1金+1银)3.如果奖牌和类型相同,则按字母顺序子排序我正在用ruby​​做这件事,但我想语言并不重要。我确实找到了一个解决方案,但如果感觉必须有更优雅的方法来实现它。这是我做的:使用加权奖牌总数创建一个虚拟属性。因此,如果他们有2个金牌和1个银牌,加权总数将为“3.020100”。1金1银1铜为“3.010101”由于我们希望将奖牌数排序为最高的,因此列表按降序排

随机推荐