草庐IT

手把手实战PyTorch手写数据集MNIST识别项目全流程

果子尝尝 2023-05-03 原文

目录

摘要

一、认识MNIST手写数据集

二、实战流程

1、加载必要的库

2、定义超参数

3、构建transform,对图像做处理

4、下载、处理、加载数据集

下载、处理数据集

加载数据集

5、构建网络模型

6、定义优化器

7、定义训练方法

8、定义测试方法

9、调用方法7和8

10、运行

三、完整代码

 


摘要

MNIST手写数据集是跑深度学习模型中很基础的、几乎所有初学者都会用到的数据集,认真领悟手写数据集的识别过程对于深度学习框架有着弥足重要的意义。然而目前各类文章中关于项目完全实战的记录较少,无法满足广大初学者的要求,故本文受B站Tommy启发来手把手从引入库开始进行对整个手写数据集识别的流程,这对于笔者以后的深度学习有着很大的必要性。

一、认识MNIST手写数据集

MNIST 数据集是由 0〜9 手写数字图片和数字标签所组成的,由 60000 个训练样本和 10000 个测试样本组成,每个样本都是一张 28*28 像素的灰度手写数字图片。如下图所示。

 可以看到,每个阿拉伯数字都形态各异,而本文的任务就是把它们识别出来。

二、实战流程

1、加载必要的库

MNIST手写识别需要的库有基本库torch、包含了构筑神经网络结构基本元素的包torch.nn、torch.nn.functional、优化器optim、对数据库进行操作的torchvision。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

2、定义超参数

超参数:在机器学习中,超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据。通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果。

由于实操中数据往往会过多,一次加载不完,内存不够,所以我们将数据切割,选择超参数batch_size(每批处理的数据)为128(根据性能)。

第二个超参数定义一个DEVICE来判断用CPU还是GPU训练。

第三个超参数决定进行几轮训练,本文选择100轮训练.

BATCH_SIZE = 128      
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")      
EPOCHS = 20

3、构建transform,对图像做处理

PyTorch内置很多库,直接调用方法transforms即可:

将图片转换成PyTorch处理的tensor格式,然后进行正则化(对抗过拟合)。

其中0.1307,0.3081分别为官网查得的均值和方差值。

tranform = transforms.Compose([
    transforms.ToTensor(),      
    transforms.Normalize((0.1307,), (0.3081,))      #正则化
])

4、下载、处理、加载数据集

下载、处理数据集

由于笔者已经提前下载MNIST文件到项目目录里,故download = False,如果提前未下载改成True等待下载成功即可。

from torch.utils.data import DataLoader
train_data = datasets.MNIST(root="./MNIST",
                            train=True,
                            transform=tranform,
                            download=False)

test_data = datasets.MNIST(root="./MNIST",
                           train=False,
                           transform=tranform,
                           download=False)

加载数据集

其中shuffle决定的是是否打乱数据,为了提高模型精度选择True打乱。

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

5、构建网络模型

class Digit(nn.Module):                    #继承父类
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # 二维卷积、输入通道,输出通道,5*5 kernel
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.fc1 = nn.Linear(20*10*10, 500)    # 全连接层,输入通道, 输出通道
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):          # 前馈
        input_size = x.size(0)     # 得到batch_size
        x = self.conv1(x)          # 输入:batch*1*28*28, 输出:batch*10*24*24(28-5+1)
        x = F.relu(x)              # 使表达能力更强大的激活函数, 输出batch*10*24*24
        x = F.max_pool2d(x, 2, 2)  # 最大池化层,输入batch*10*24*24,输出batch*10*12*12

        x = self.conv2(x)          # 输入batch*10*12*12,输出batch*20*10*10
        x = F.relu(x)

        x = x.view(input_size, -1) # 拉平, 自动计算维度,20*10*10= 2000

        x = self.fc1(x)            # 输入batch*2000,输出batch*500
        x = F.relu(x)

        x = self.fc2(x)            # 输入batch*500 输出batch*10

        output = F.log_softmax(x, dim=1)  # 计算分类后每个数字的概率值

        return output

6、定义优化器

将模型部署到GPU

优化器:更新模型参数,使训练结果达到最优值

model = Digit().to(DEVICE)   

optimizer = optim.Adam(model.parameters())

7、定义训练方法

enumerate函数:来遍历一个集合对象,它在遍历的同时还可以得到当前元素的索引位置。

反向传播:不断迭代权重,降低误差。

loss.item():取出单元素张量的元素值(loss值)并返回该值,保持原元素类型不变。

def train_model(model, device, train_loader, optimizer, epoch):
    model.train()                     #PyTorch提供的训练方法
    for batch_index, (data, label) in enumerate(train_loader):
        #部署到DEVICE
        data, label = data.to(device), label.to(device)
        #梯度初始化为0
        optimizer.zero_grad()
        #训练后的结果
        output = model(data)
        #计算损失(针对多分类任务交叉熵,二分类用sigmoid)
        loss = F.cross_entropy(output, label)
        #找到最大概率的下标
        pred = output.argmax(dim=1)
        #反向传播Backpropagation
        loss.backward()
        #参数的优化
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch : {} \t Loss : {:.6f}".format(epoch, loss.item()))

8、定义测试方法

def test_model(model, device, test_loader):
    #模型验证
    model.eval()
    #统计正确率
    correct = 0.0
    #测试损失
    test_loss = 0.0
    with torch.no_grad():    # 不计算梯度,不反向传播
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            #测试数据
            output = model(data)
            #计算测试损失
            test_loss += F.cross_entropy(output, label).item()
            #找到概率值最大的下标
            pred = output.argmax(dim=1)
            #累计正确率
            correct += pred.eq(label.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print("Test —— Average loss : {:.4f}, Accuracy : {:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))

9、调用方法7和8

for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)

10、运行

接下来运行即可,笔者运行结果如下图示:

三、完整代码

完整代码如下:

#1加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


#2定义超参数(参数:由模型学习来决定的)数据太多一次放不完,切割
BATCH_SIZE = 128      # 每批处理的数据
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")      # CPU还是GPU?
EPOCHS = 100


#3构建transform, 对图像进行各种处理(旋转拉伸,放大缩小等)
tranform = transforms.Compose([
    transforms.ToTensor(),       # 将图片转换成Tensor
    transforms.Normalize((0.1307,), (0.3081,))      # 均值和方差,正则化(对抗过拟合):降低模型复杂度
])


#4下载、加载数据集
from torch.utils.data import DataLoader
train_data = datasets.MNIST(root="./MNIST",
                            train=True,
                            transform=tranform,
                            download=False)

test_data = datasets.MNIST(root="./MNIST",
                           train=False,
                           transform=tranform,
                           download=False)
#加载数据集
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


#5构建网络模型
class Digit(nn.Module):
    def __init__(self):                    #继承父类
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # 输入通道,输出通道,5*5 kernel
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.fc1 = nn.Linear(20*10*10, 500)    # 全连接层,输入通道, 输出通道
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):          # 前馈
        input_size = x.size(0)     # 得到batch_size
        x = self.conv1(x)          # 输入:batch*1*28*28, 输出:batch*10*24*24(28-5+1)
        x = F.relu(x)              # 使表达能力更强大激活函数, 输出batch*10*24*24
        x = F.max_pool2d(x, 2, 2)  # 最大池化层,输入batch*10*24*24,输出batch*10*12*12

        x = self.conv2(x)          # 输入batch*10*12*12,输出batch*20*10*10
        x = F.relu(x)

        x = x.view(input_size, -1) # 拉平, 自动计算维度,20*10*10= 2000

        x = self.fc1(x)            # 输入batch*2000,输出batch*500
        x = F.relu(x)

        x = self.fc2(x)            # 输入batch*500 输出batch*10

        output = F.log_softmax(x, dim=1)  # 计算分类后每个数字的概率值

        return output


#6定义优化器
model = Digit().to(DEVICE)    # 创建模型部署到DEVICE

optimizer = optim.Adam(model.parameters())


#7定义训练方法
def train_model(model, device, train_loader, optimizer, epoch):
    model.train()                    #PyTorch提供的训练方法
    for batch_index, (data, label) in enumerate(train_loader):
        #部署到DEVICE
        data, label = data.to(device), label.to(device)
        #梯度初始化为0
        optimizer.zero_grad()
        #训练后的结果
        output = model(data)
        #计算损失(针对多分类任务交叉熵,二分类用sigmoid)
        loss = F.cross_entropy(output, label)
        #找到最大概率的下标
        pred = output.argmax(dim=1)
        #反向传播Backpropagation
        loss.backward()
        #参数的优化
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch : {} \t Loss : {:.6f}".format(epoch, loss.item()))


#8定义测试方法
def test_model(model, device, test_loader):
    #模型验证
    model.eval()
    #统计正确率
    correct = 0.0
    #测试损失
    test_loss = 0.0
    with torch.no_grad():    # 不计算梯度,不反向传播
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            #测试数据
            output = model(data)
            #计算测试损失
            test_loss += F.cross_entropy(output, label).item()
            #找到概率值最大的下标
            pred = output.argmax(dim=1)
            #累计正确率
            correct += pred.eq(label.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print("Test —— Average loss : {:.4f}, Accuracy : {:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))


#9 调用方法7和8
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)

 

有关手把手实战PyTorch手写数据集MNIST识别项目全流程的更多相关文章

  1. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  2. ruby - Ruby 有 `Pair` 数据类型吗? - 2

    有时我需要处理键/值数据。我不喜欢使用数组,因为它们在大小上没有限制(很容易不小心添加超过2个项目,而且您最终需要稍后验证大小)。此外,0和1的索引变成了魔数(MagicNumber),并且在传达含义方面做得很差(“当我说0时,我的意思是head...”)。散列也不合适,因为可能会不小心添加额外的条目。我写了下面的类来解决这个问题:classPairattr_accessor:head,:taildefinitialize(h,t)@head,@tail=h,tendend它工作得很好并且解决了问题,但我很想知道:Ruby标准库是否已经带有这样一个类? 最佳

  3. ruby - 我如何添加二进制数据来遏制 POST - 2

    我正在尝试使用Curbgem执行以下POST以解析云curl-XPOST\-H"X-Parse-Application-Id:PARSE_APP_ID"\-H"X-Parse-REST-API-Key:PARSE_API_KEY"\-H"Content-Type:image/jpeg"\--data-binary'@myPicture.jpg'\https://api.parse.com/1/files/pic.jpg用这个:curl=Curl::Easy.new("https://api.parse.com/1/files/lion.jpg")curl.multipart_form_

  4. 世界前沿3D开发引擎HOOPS全面讲解——集3D数据读取、3D图形渲染、3D数据发布于一体的全新3D应用开发工具 - 2

    无论您是想搭建桌面端、WEB端或者移动端APP应用,HOOPSPlatform组件都可以为您提供弹性的3D集成架构,同时,由工业领域3D技术专家组成的HOOPS技术团队也能为您提供技术支持服务。如果您的客户期望有一种在多个平台(桌面/WEB/APP,而且某些客户端是“瘦”客户端)快速、方便地将数据接入到3D应用系统的解决方案,并且当访问数据时,在各个平台上的性能和用户体验保持一致,HOOPSPlatform将帮助您完成。利用HOOPSPlatform,您可以开发在任何环境下的3D基础应用架构。HOOPSPlatform可以帮您打造3D创新型产品,HOOPSSDK包含的技术有:快速且准确的CAD

  5. FOHEART H1数据手套驱动Optitrack光学动捕双手运动(Unity3D) - 2

    本教程将在Unity3D中混合Optitrack与数据手套的数据流,在人体运动的基础上,添加双手手指部分的运动。双手手背的角度仍由Optitrack提供,数据手套提供双手手指的角度。 01  客户端软件分别安装MotiveBody与MotionVenus并校准人体与数据手套。MotiveBodyMotionVenus数据手套使用、校准流程参照:https://gitee.com/foheart_1/foheart-h1-data-summary.git02  数据转发打开MotiveBody软件的Streaming,开始向Unity3D广播数据;MotionVenus中设置->选项选择Unit

  6. 使用canal同步MySQL数据到ES - 2

    文章目录一、概述简介原理模块二、配置Mysql使用版本环境要求1.操作系统2.mysql要求三、配置canal-server离线下载在线下载上传解压修改配置单机配置集群配置分库分表配置1.修改全局配置2.实例配置垂直分库水平分库3.修改group-instance.xml4.启动监听四、配置canal-adapter1修改启动配置2配置映射文件3启动ES数据同步查询所有订阅同步数据同步开关启动4.验证五、配置canal-admin一、概述简介canal是Alibaba旗下的一款开源项目,Java开发。基于数据库增量日志解析,提供增量数据订阅&消费。Git地址:https://github.co

  7. Unity 3D 制作开关门动画,旋转门制作,推拉门制作,门把手动画制作 - 2

    Unity自动旋转动画1.开门需要门把手先动,门再动2.关门需要门先动,门把手再动3.中途播放过程中不可以再次进行操作觉得太复杂?查看我的文章开关门简易进阶版效果:如果这个门可以直接打开的话,就不需要放置"门把手"如果门把手还有钥匙需要旋转,那就可以把钥匙放在门把手的"门把手",理论上是可以无限套娃的可调整参数有:角度,反向,轴向,速度运行时点击Test进行测试自己写的代码比较垃圾,命名与结构比较拉,高手轻点喷,新手有类似的需求可以拿去做参考上代码usingSystem.Collections;usingSystem.Collections.Generic;usingUnityEngine;u

  8. ruby-on-rails - 创建 ruby​​ 数据库时惰性符号绑定(bind)失败 - 2

    我正在尝试在Rails上安装ruby​​,到目前为止一切都已安装,但是当我尝试使用rakedb:create创建数据库时,我收到一个奇怪的错误:dyld:lazysymbolbindingfailed:Symbolnotfound:_mysql_get_client_infoReferencedfrom:/Library/Ruby/Gems/1.8/gems/mysql2-0.3.11/lib/mysql2/mysql2.bundleExpectedin:flatnamespacedyld:Symbolnotfound:_mysql_get_client_infoReferencedf

  9. STM32读取串口传感器数据(颗粒物传感器,主动上传) - 2

    文章目录1.开发板选择*用到的资源2.串口通信(个人理解)3.代码分析(注释比较详细)1.主函数2.串口1配置3.串口2配置以及中断函数4.注意问题5.源码链接1.开发板选择我用的是STM32F103RCT6的板子,不过代码大概在F103系列的板子上都可以运行,我试过在野火103的霸道板上也可以,主要看一下串口对应的引脚一不一样就行了,不一样的就更改一下。*用到的资源keil5软件这里用到了两个串口资源,采集数据一个,串口通信一个,板子对应引脚如下:串口1,TX:PA9,RX:PA10串口2,TX:PA2,RX:PA32.串口通信(个人理解)我就从串口采集传感器数据这个过程说一下我自己的理解,

  10. SPI接收数据异常问题总结 - 2

    SPI接收数据左移一位问题目录SPI接收数据左移一位问题一、问题描述二、问题分析三、探究原理四、经验总结最近在工作在学习调试SPI的过程中遇到一个问题——接收数据整体向左移了一位(1bit)。SPI数据收发是数据交换,因此接收数据时从第二个字节开始才是有效数据,也就是数据整体向右移一个字节(1byte)。请教前辈之后也没有得到解决,通过在网上查阅前人经验终于解决问题,所以写一个避坑经验总结。实际背景:MCU与一款芯片使用spi通信,MCU作为主机,芯片作为从机。这款芯片采用的是它规定的六线SPI,多了两根线:RDY和INT,这样从机就可以主动请求主机给主机发送数据了。一、问题描述根据从机芯片手

随机推荐