草庐IT

Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

Mr.Winter` 2023-04-15 原文

目录

1 数据集Dataset

Dataset类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自Dataset父类。当我们自定义数据集处理时,必须实现Dataset类中的三个接口:

  • 初始化
    def __init__(self)
    
    构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等
  • 数据集大小
    def __len__(self)
    
    返回数据集大小,不同的数据集有不同的衡量数据量的方式
  • 数据集索引
    def __getitem__(self, index):
    
    支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如
    class data(Dataset)
    	...
        def __getitem__(self, idx: int):
            if self.transforms:
                img = self.transforms(img)
            return img, label			# 返回的值即为后续迭代的循环变量
    
    for images, labels in dataLoader:
    	...
    

2 数据加载DataLoader

为什么有了数据集Dataset还需要数据加载器DataLoader呢?原因在于神经网络需要进一步借助DataLoader对数据进行划分,也就是我们常说的batch,此外DataLoader还实现了打乱数据集、多线程等操作。

DataLoader本质是一个可迭代对象,可以使用形如

for inputs, labels in dataloaders

进行可迭代对象的访问。

我们一般不需要去实现DataLoader的接口,只需要在构造函数中指定相应的参数即可,比如常见的batch_sizeshuffle等参数。

下面这张图非常好地说明了DatasetDataLoader的关系

接下来总结数据构造的三步法

  1. 继承Dataset对象,并实现__len__()__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中;
  2. DataLoader对象封装Dataset,使其成为可迭代对象;
  3. 遍历DataLoader对象以将数据加载到模型中进行训练。

3 常用预处理方法

在数据集Dataset__getitem__()中利用torchvision.transforms进行数据预处理与变换

常见的数据预处理变换方法总结如下表

序号变换含义
1RandomCrop(size, ...)对输入图像依据给定size随机裁剪
2CenterCrop(size, ...)对输入图像依据给定size从中心裁剪
3RandomResizedCrop(size, ...)对输入图像随机长宽比裁剪,再放缩到给定size
4FiveCrop(size, ...)对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量
5TenCrop(size, vertical_flip=False)对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量
6RandomHorizontalFlip(p=0.5)对输入图像按概率p随机进行水平翻转
7RandomVerticalFlip(p=0.5)对输入图像按概率p随机进行垂直翻转
8RandomRotation(degree, ...)对输入图像在degree内随机旋转某角度
9Resize(size, ...)对输入图像重置分辨率
10Normalize(mean, std)对输入图像各通道进行标准化
11ToTensor()将输入图像或ndarray 转换为tensor并归一化
12Pad(padding, fill=0, padding_mode=‘constant’)对输入图像进行填充
13ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)对输入图像修改亮度、对比度、饱和度、色度等
14Grayscale(num_output_channels=1)对输入图像转灰度
15LinearTransformation(matrix)对输入图像进行线性变换
16RandomAffine(...)对输入图像进行仿射变换
17RandomGrayscale(p=0.1)对输入图像按概率p随机转灰度
18ToPILImage(mode=None)对输入图像转PIL格式图像
19RandomOrder()随机打乱transforms操作顺序

4 模型处理

考虑以下场景:

网络的部分层级结构已经收敛、无需调整;大型复杂网络需要微调(Fine-tune)某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。

这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入预训练(pretrained)。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以.pth为后缀

# 保存已训练模型
torch.save(model.state_dict(), path)
# 加载预训练模型
model.load_state_dict(torch.load(path), device)

通过设置模型某些层可学习参数的requires_grad属性为False即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过mode.to(device)方法将模型部署到指定设备上(CPU/GPU),范式如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

工程上也常使用torch.nn.DataParallel(model, devices)来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。

5 实例:MNIST数据集处理

下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解

from abc import abstractmethod
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import Dataset
from PIL import Image

class mnistData(Dataset):
    '''
    * @breif: MNIST数据集抽象接口
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__()
        self.dataPath = dataPath
        self.transforms = transforms
        self.data, self.label = [], []

    def __len__(self) -> int:
        return len(self.label)

    def __getitem__(self, idx: int):
        img = self.data[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, self.label[idx]

    @abstractmethod
    def plot(self, index: int) -> None:
        pass

    @abstractmethod
    def load(self) -> list:
        pass

    def plotData(self, index: int, info: str=None) -> None:
        '''
        * @breif: 可视化训练数据
        * @param[in]: index -> 数据集索引
        * @param[in]: info -> 备注信息
        * @retval: None
        '''
        print(info, " --index:", index, "--label:", self.label[index])  if info else \
        print(" --index:", index, "--label:", self.label[index])          
        img = Image.fromarray(np.uint8(self.data[index]))
        img.show()

    def loadData(self, train: bool) -> list:
        '''
        * @breif: 下载与加载数据集
        * @param[in]: train -> 是否为训练集
        * @retval: 数据与标签列表
        '''    
        # 如果指定目录下不存在数据集则下载
        dataSet   = mnist.MNIST(self.dataPath, train=train, download=True)
        # 初始化数据与标签
        data  = [ i[0] for i in dataSet ]
        label = [ i[1] for i in dataSet ]
        return data, label

class mnistTrainData(mnistData):
    '''
    * @breif: MNIST训练集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()

    def plot(self, index: int) -> None:
        self.plotData(index, "trainSet data")

    def load(self) -> list:
        return self.loadData(train=True)


class mnistTestData(mnistData):
    '''
    * @breif: MNIST测试集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()

    def plot(self, index: int) -> None:
        self.plotData(index, "testSet data")

    def load(self) -> list:
        return self.loadData(train=False)

有关Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理的更多相关文章

  1. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  2. ruby - 如何指定 Rack 处理程序 - 2

    Rackup通过Rack的默认处理程序成功运行任何Rack应用程序。例如:classRackAppdefcall(environment)['200',{'Content-Type'=>'text/html'},["Helloworld"]]endendrunRackApp.new但是当最后一行更改为使用Rack的内置CGI处理程序时,rackup给出“NoMethodErrorat/undefinedmethod`call'fornil:NilClass”:Rack::Handler::CGI.runRackApp.newRack的其他内置处理程序也提出了同样的反对意见。例如Rack

  3. ruby - 如何在续集中重新加载表模式? - 2

    鉴于我有以下迁移:Sequel.migrationdoupdoalter_table:usersdoadd_column:is_admin,:default=>falseend#SequelrunsaDESCRIBEtablestatement,whenthemodelisloaded.#Atthispoint,itdoesnotknowthatusershaveais_adminflag.#Soitfails.@user=User.find(:email=>"admin@fancy-startup.example")@user.is_admin=true@user.save!ende

  4. ruby - RuntimeError(自动加载常量 Apps 多线程时检测到循环依赖 - 2

    我收到这个错误:RuntimeError(自动加载常量Apps时检测到循环依赖当我使用多线程时。下面是我的代码。为什么会这样?我尝试多线程的原因是因为我正在编写一个HTML抓取应用程序。对Nokogiri::HTML(open())的调用是一个同步阻塞调用,需要1秒才能返回,我有100,000多个页面要访问,所以我试图运行多个线程来解决这个问题。有更好的方法吗?classToolsController0)app.website=array.join(',')putsapp.websiteelseapp.website="NONE"endapp.saveapps=Apps.order("

  5. ruby - Ruby 有 `Pair` 数据类型吗? - 2

    有时我需要处理键/值数据。我不喜欢使用数组,因为它们在大小上没有限制(很容易不小心添加超过2个项目,而且您最终需要稍后验证大小)。此外,0和1的索引变成了魔数(MagicNumber),并且在传达含义方面做得很差(“当我说0时,我的意思是head...”)。散列也不合适,因为可能会不小心添加额外的条目。我写了下面的类来解决这个问题:classPairattr_accessor:head,:taildefinitialize(h,t)@head,@tail=h,tendend它工作得很好并且解决了问题,但我很想知道:Ruby标准库是否已经带有这样一个类? 最佳

  6. ruby-on-rails - 使用 config.threadsafe 时从 lib/加载模块/类的正确方法是什么!选项? - 2

    我一直致力于让我们的Rails2.3.8应用程序在JRuby下正确运行。一切正常,直到我启用config.threadsafe!以实现JRuby提供的并发性。这导致lib/中的模块和类不再自动加载。使用config.threadsafe!启用:$rubyscript/runner-eproduction'pSim::Sim200Provisioner'/Users/amchale/.rvm/gems/jruby-1.5.1@web-services/gems/activesupport-2.3.8/lib/active_support/dependencies.rb:105:in`co

  7. ruby - 我如何添加二进制数据来遏制 POST - 2

    我正在尝试使用Curbgem执行以下POST以解析云curl-XPOST\-H"X-Parse-Application-Id:PARSE_APP_ID"\-H"X-Parse-REST-API-Key:PARSE_API_KEY"\-H"Content-Type:image/jpeg"\--data-binary'@myPicture.jpg'\https://api.parse.com/1/files/pic.jpg用这个:curl=Curl::Easy.new("https://api.parse.com/1/files/lion.jpg")curl.multipart_form_

  8. 世界前沿3D开发引擎HOOPS全面讲解——集3D数据读取、3D图形渲染、3D数据发布于一体的全新3D应用开发工具 - 2

    无论您是想搭建桌面端、WEB端或者移动端APP应用,HOOPSPlatform组件都可以为您提供弹性的3D集成架构,同时,由工业领域3D技术专家组成的HOOPS技术团队也能为您提供技术支持服务。如果您的客户期望有一种在多个平台(桌面/WEB/APP,而且某些客户端是“瘦”客户端)快速、方便地将数据接入到3D应用系统的解决方案,并且当访问数据时,在各个平台上的性能和用户体验保持一致,HOOPSPlatform将帮助您完成。利用HOOPSPlatform,您可以开发在任何环境下的3D基础应用架构。HOOPSPlatform可以帮您打造3D创新型产品,HOOPSSDK包含的技术有:快速且准确的CAD

  9. FOHEART H1数据手套驱动Optitrack光学动捕双手运动(Unity3D) - 2

    本教程将在Unity3D中混合Optitrack与数据手套的数据流,在人体运动的基础上,添加双手手指部分的运动。双手手背的角度仍由Optitrack提供,数据手套提供双手手指的角度。 01  客户端软件分别安装MotiveBody与MotionVenus并校准人体与数据手套。MotiveBodyMotionVenus数据手套使用、校准流程参照:https://gitee.com/foheart_1/foheart-h1-data-summary.git02  数据转发打开MotiveBody软件的Streaming,开始向Unity3D广播数据;MotionVenus中设置->选项选择Unit

  10. 使用canal同步MySQL数据到ES - 2

    文章目录一、概述简介原理模块二、配置Mysql使用版本环境要求1.操作系统2.mysql要求三、配置canal-server离线下载在线下载上传解压修改配置单机配置集群配置分库分表配置1.修改全局配置2.实例配置垂直分库水平分库3.修改group-instance.xml4.启动监听四、配置canal-adapter1修改启动配置2配置映射文件3启动ES数据同步查询所有订阅同步数据同步开关启动4.验证五、配置canal-admin一、概述简介canal是Alibaba旗下的一款开源项目,Java开发。基于数据库增量日志解析,提供增量数据订阅&消费。Git地址:https://github.co

随机推荐