草庐IT

你需要知道的11个Torchvision计算机视觉数据集

王瑞平 2023-04-11 原文

译者 | 王瑞平

51CTO读者成长计划社群招募,咨询小助手(微信号:TTalkxiaozhuli)

计算机视觉是一个显著增长的领域,有许多实际应用,从自动驾驶汽车到面部识别系统。该领域的主要挑战之一是获得高质量的数据集来训练机器学习模型。

Torchvision作为Pytorch的图形库,一直服务于PyTorch深度学习框架,主要用于构建计算机视觉模型。

为了解决这一挑战,Torchvision提供了访问预先构建的数据集、模型和专门为计算机视觉任务设计的转换。此外,Torchvision还支持CPU和GPU的加速,使其成为开发计算机视觉应用程序的灵活且强大的工具。

一、什么是“Torchvision数据集”?

Torchvision数据集是计算机视觉中常用的用于开发和测试机器学习模型的流行数据集集合。运用Torchvision数据集,开发人员可以在一系列任务上训练和测试他们的机器学习模型,例如,图像分类、对象检测和分割。数据集还经过预处理、标记并组织成易于加载和使用的格式。

据了解,Torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是“常用数据集+常见模型+常见图像增强”方法。

Torchvision中的数据集共有11种:MNIST、CIFAR-10等,下面具体说说。

二、Torchvision中的11种数据集

1.MNIST手写数字数据库

这个Torchvision数据集在机器学习和计算机视觉领域中非常流行和广泛应用。它由7万张手写数字0-9的灰度图像组成。其中,6万张用于训练,1万张用于测试。每张图像的大小为28×28像素,并有相应的标签表示它所代表的数字。

要访问此数据集,您可以直接从Kaggle下载或使用torchvision加载数据集:

import torchvision.datasets as datasets# Load the training dataset
train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)# Load the testing dataset
test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)

2.CIFAR-10(广泛使用的标准数据集)

CIFAR-10数据集由6万张32×32彩色图像组成,分为10个类别,每个类别有6000张图像,总共有5万张训练图像和1万张测试图像。这些图像又分为5个训练批次和一个测试批次,每个批次有1万张图像。数据集可以从Kaggle下载。

import torchimport torchvisionimport torchvision.transforms as transforms

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

在此提醒一句,您可以根据需要调整数据加载器的批处理大小和工作进程的数量。

3.CIFAR-100(广泛使用的标准数据集)

CIFAR-100数据集在100个类中有60,000张(50,000张训练图像和10,000张测试图像)32×32的彩色图像。每个类有600张图像。这100个类被分成20个超类,用一个细标签表示它的类,另一个粗标签表示它所属的超类。

import torchimport torchvisionimport torchvision.transforms as transforms

import torchvision.datasets as datasetsimport torchvision.transforms as transforms# Define transform to normalize data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
# Load CIFAR-100 train and test datasets
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
# Create data loaders for train and test datasets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

4.ImageNet数据集

Torchvision中的ImageNet数据集包含大约120万张训练图像,5万张验证图像和10万张测试图像。数据集中的每张图像都被标记为1000个类别中的一个,如“猫”、“狗”、“汽车”、“飞机”等。

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
# Set the path to the ImageNet dataset on your machine
data_path = "/path/to/imagenet"
# Create the ImageNet dataset object with custom options
imagenet_train = datasets.ImageNet(
root=data_path,
split='train',
transform=transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]),
download=False)

imagenet_val = datasets.ImageNet(
root=data_path,
split='val',
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]),
download=False)
# Print the number of images in the training and validation setsprint("Number of images in the training set:", len(imagenet_train))print("Number of images in the validation set:", len(imagenet_val))

5.MSCoco数据集

Microsoft Common Objects in Context(MS Coco)数据集包含32.8万张日常物体和人类的高质量视觉图像,通常用作实时物体检测中比较算法性能的标准。

6.Fashion-MNIST数据集

时尚MNIST数据集是由Zalando Research创建的,作为原始MNIST数据集的替代品。Fashion MNIST数据集由70000张服装灰度图像(训练集60000张,测试集10000张)组成。

图片大小为28×28像素,代表10种不同类别的服装,包括:t恤/上衣、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和短靴。它类似于原始的MNIST数据集,但由于服装项目的复杂性和多样性,分类任务更具挑战性。这个Torchvision数据集可以从Kaggle下载。

import torchimport torchvisionimport torchvision.transforms as transforms
# Define transformations
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])# Load the dataset
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
download=True, transform=transform)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
download=True, transform=transform)
# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

7.SVHN数据集

SVHN(街景门牌号)数据集是一个来自谷歌街景图像的图像数据集,它由从街道级图像中截取的门牌号的裁剪图像组成。它包含所有门牌号及其包围框的完整格式和仅包含门牌号的裁剪格式。完整格式通常用于对象检测任务,而裁剪格式通常用于分类任务。

SVHN数据集也包含在Torchvision包中,它包含了73,257张用于训练的图像、26,032张用于测试的图像和531,131张用于额外训练数据的额外图像。

import torchvisionimport torch
# Load the train and test sets
train_set = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=torchvision.transforms.ToTensor())
test_set = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=torchvision.transforms.ToTensor())
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

8.STL-10数据集

STL-10数据集是一个图像识别数据集,由10个类组成,总共约6000 +张图像。STL-10代表“图像识别标准训练和测试集-10类”,数据集中的10个类是:飞机、鸟、汽车、猫、鹿、狗、马、猴子、船、卡车。您可以直接从Kaggle下载数据集。

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
# Define the transformation to apply to the data
transform = transforms.Compose([
transforms.ToTensor(),
# Convert PIL image to PyTorch tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the data])
# Load the STL-10 dataset
train_dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
test_dataset = datasets.STL10(root='./data', split='test', download=True, transform=transform)

9.CelebA数据集

这个Torchvision数据集是一个流行的大规模面部属性数据集,包含超过20万张名人图像。2015年,香港中文大学的研究人员首次发布了这一数据。CelebA中的图像包含40个面部属性,如,年龄、头发颜色、面部表情和性别。

此外,这些图片是从互联网上检索到的,涵盖了广泛的面部外观,包括不同的种族、年龄和性别。每个图像中面部位置的边界框注释,以及眼睛、鼻子和嘴巴的5个地标点。

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
transform = transforms.Compose([
transforms.CenterCrop(178),
transforms.Resize(128),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
celeba_dataset = datasets.CelebA(root='./data', split='train', transform=transform, download=True)

9.PASCAL VOC数据集

VOC数据集(视觉对象类)于2005年作为PASCAL VOC挑战的一部分首次引入。该挑战旨在推进视觉识别的最新水平。它由20种不同类别的物体组成,包括:动物、交通工具和常见的家用物品。这些图像中的每一个都标注了图像中物体的位置和分类。注释包括边界框和像素级分割掩码。

数据集分为两个主要集:训练集和验证集。

训练集包含大约5000张带有注释的图像,而验证集包含大约5000张没有注释的图像。此外,该数据集还包括一个包含大约10,000张图像的测试集,但该测试集的注释是不可公开的。

import torchimport torchvisionfrom torchvision import transforms
# Define transformations to apply to the images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Load the train and validation datasets
train_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='train', transform=transform)
val_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='val', transform=transform)# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

11.Places365数据集

Places365数据集是一个大型场景识别数据集,拥有超过180万张图像,涵盖365个场景类别。Places365标准数据集包含约180万张图像,而Places365挑战数据集包含5万张额外的验证图像,这些图像对识别模型更具挑战性。

import torchimport torchvisionfrom torchvision import transforms
# Define transformations to apply to the images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Load the train and validation datasets
train_dataset = torchvision.datasets.Places365(root='./data', split='train-standard', transform=transform)
val_dataset = torchvision.datasets.Places365(root='./data', split='val', transform=transform)# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

三、总结

总之,Torchvision数据集通常用于训练和评估机器学习模型,如卷积神经网络(CNNs)。这些模型通常用于计算机视觉应用,任何人都可以免费下载和使用。本文的主要图像是通过HackerNoon的AI稳定扩散模型生成的。

参考链接:​​https://hackernoon.com/11-torchvision-datasets-for-computer-vision-you-need-to-know​

有关你需要知道的11个Torchvision计算机视觉数据集的更多相关文章

  1. ruby - 我需要将 Bundler 本身添加到 Gemfile 中吗? - 2

    当我使用Bundler时,是否需要在我的Gemfile中将其列为依赖项?毕竟,我的代码中有些地方需要它。例如,当我进行Bundler设置时:require"bundler/setup" 最佳答案 没有。您可以尝试,但首先您必须用鞋带将自己抬离地面。 关于ruby-我需要将Bundler本身添加到Gemfile中吗?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/4758609/

  2. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  3. ruby - rspec 需要 .rspec 文件中的 spec_helper - 2

    我注意到像bundler这样的项目在每个specfile中执行requirespec_helper我还注意到rspec使用选项--require,它允许您在引导rspec时要求一个文件。您还可以将其添加到.rspec文件中,因此只要您运行不带参数的rspec就会添加它。使用上述方法有什么缺点可以解释为什么像bundler这样的项目选择在每个规范文件中都需要spec_helper吗? 最佳答案 我不在Bundler上工作,所以我不能直接谈论他们的做法。并非所有项目都checkin.rspec文件。原因是这个文件,通常按照当前的惯例,只

  4. ruby - 如何在 Lion 上安装 Xcode 4.6,需要用 RVM 升级 ruby - 2

    我实际上是在尝试使用RVM在我的OSX10.7.5上更新ruby,并在输入以下命令后:rvminstallruby我得到了以下回复:Searchingforbinaryrubies,thismighttakesometime.Checkingrequirementsforosx.Installingrequirementsforosx.Updatingsystem.......Errorrunning'requirements_osx_brew_update_systemruby-2.0.0-p247',pleaseread/Users/username/.rvm/log/138121

  5. ruby - Ruby 有 `Pair` 数据类型吗? - 2

    有时我需要处理键/值数据。我不喜欢使用数组,因为它们在大小上没有限制(很容易不小心添加超过2个项目,而且您最终需要稍后验证大小)。此外,0和1的索引变成了魔数(MagicNumber),并且在传达含义方面做得很差(“当我说0时,我的意思是head...”)。散列也不合适,因为可能会不小心添加额外的条目。我写了下面的类来解决这个问题:classPairattr_accessor:head,:taildefinitialize(h,t)@head,@tail=h,tendend它工作得很好并且解决了问题,但我很想知道:Ruby标准库是否已经带有这样一个类? 最佳

  6. ruby - 为什么在 ruby​​ 中创建 Rational 不需要新方法 - 2

    这个问题在这里已经有了答案:关闭10年前。PossibleDuplicate:Rubysyntaxquestion:Rational(a,b)andRational.new!(a,b)我正在阅读ruby镐书,我对创建有理数的语法感到困惑。Rational(3,4)*Rational(1,2)产生=>3/8为什么Rational不需要new方法(我还注意到例如我可以在没有new方法的情况下创建字符串)?

  7. ruby - 我如何添加二进制数据来遏制 POST - 2

    我正在尝试使用Curbgem执行以下POST以解析云curl-XPOST\-H"X-Parse-Application-Id:PARSE_APP_ID"\-H"X-Parse-REST-API-Key:PARSE_API_KEY"\-H"Content-Type:image/jpeg"\--data-binary'@myPicture.jpg'\https://api.parse.com/1/files/pic.jpg用这个:curl=Curl::Easy.new("https://api.parse.com/1/files/lion.jpg")curl.multipart_form_

  8. 世界前沿3D开发引擎HOOPS全面讲解——集3D数据读取、3D图形渲染、3D数据发布于一体的全新3D应用开发工具 - 2

    无论您是想搭建桌面端、WEB端或者移动端APP应用,HOOPSPlatform组件都可以为您提供弹性的3D集成架构,同时,由工业领域3D技术专家组成的HOOPS技术团队也能为您提供技术支持服务。如果您的客户期望有一种在多个平台(桌面/WEB/APP,而且某些客户端是“瘦”客户端)快速、方便地将数据接入到3D应用系统的解决方案,并且当访问数据时,在各个平台上的性能和用户体验保持一致,HOOPSPlatform将帮助您完成。利用HOOPSPlatform,您可以开发在任何环境下的3D基础应用架构。HOOPSPlatform可以帮您打造3D创新型产品,HOOPSSDK包含的技术有:快速且准确的CAD

  9. FOHEART H1数据手套驱动Optitrack光学动捕双手运动(Unity3D) - 2

    本教程将在Unity3D中混合Optitrack与数据手套的数据流,在人体运动的基础上,添加双手手指部分的运动。双手手背的角度仍由Optitrack提供,数据手套提供双手手指的角度。 01  客户端软件分别安装MotiveBody与MotionVenus并校准人体与数据手套。MotiveBodyMotionVenus数据手套使用、校准流程参照:https://gitee.com/foheart_1/foheart-h1-data-summary.git02  数据转发打开MotiveBody软件的Streaming,开始向Unity3D广播数据;MotionVenus中设置->选项选择Unit

  10. 使用canal同步MySQL数据到ES - 2

    文章目录一、概述简介原理模块二、配置Mysql使用版本环境要求1.操作系统2.mysql要求三、配置canal-server离线下载在线下载上传解压修改配置单机配置集群配置分库分表配置1.修改全局配置2.实例配置垂直分库水平分库3.修改group-instance.xml4.启动监听四、配置canal-adapter1修改启动配置2配置映射文件3启动ES数据同步查询所有订阅同步数据同步开关启动4.验证五、配置canal-admin一、概述简介canal是Alibaba旗下的一款开源项目,Java开发。基于数据库增量日志解析,提供增量数据订阅&消费。Git地址:https://github.co

随机推荐