草庐IT

使用Pytorch框架自己制作做数据集进行图像分类(一)

张_哈哈 2023-04-13 原文

第一章:Pytorch制作自己的数据集实现图像分类

第一章: Pytorch框架制作自己的数据集实现图像分类
第二章: Pytorch框架构建残差神经网络(ResNet)
第三章: Pytorch框架构建DenseNet神经网络


提示:本文代码,含有部分测试性输出语句,更改数据文件夹路径后可以直接跑通,文章末尾附全部代码

文章目录


前言

网上有很多直接利用已有数据集(如MNIST, CIFAR-10等),直接进行机器学习,图像分类的教程。但如何自己制作数据集,为图像制作相应标签等的教程较少。故写本文,分享一下自己利用Pytorch框架制作数据集的方法技巧。

开发环境:
Pycharm + Python 3.7.9
torch 1.10.2+cu102
torchvision 0.11.3+cu102


提示:以下是本篇文章正文内容

一、上网搜取相关照片作为数据


制作了三个文件夹,每个文件夹里面有十张图片,分别是关于云、雨、太阳,所有图片均来自百度图片。

这是cloud文件夹里面的内容,请注意图片命名格式

这是rain文件夹里面的内容,请注意图片命名格式

这是sun文件夹里面的内容,请注意图片命名格式

二、定义自己的数据类并读入图片数据

1.引入相关库

代码如下:

import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt

2.继承Dataset实现Mydataset子类

代码如下:

#通过创建data.Dataset子类Mydataset来创建输入
class Mydataset(data.Dataset):
# 类初始化
    def __init__(self, root):
        self.imgs_path = root
# 进行切片
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        return img_path
# 返回长度
    def __len__(self):
        return len(self.imgs_path)

init() 初始化方法,传入数据文件夹路径。
getitem() 切片方法,根据索引下标,获得相应的图片。
len() 计算长度方法,返回整个数据文件夹下所有文件的个数。


3.使用glob方法获取文件夹中所有图片路径

代码如下:

#使用glob方法来获取数据图片的所有路径
all_imgs_path = glob.glob(r'F:\weather\*\*.jpg')#数据文件夹路径,根据实际情况更改!
#循环遍历输出列表中的每个元素,显示出每个图片的路径
for var in all_imgs_path:
    print(var)


上图为运行结果部分显示


三、为图片制作标签,并划分训练集与测试集

1.利用自定义类Mydataset创建对象weather_dataset

代码如下:

#利用自定义类Mydataset创建对象weather_dataset
weather_dataset = Mydataset(all_imgs_path)
print(len(weather_dataset)) #返回文件夹中图片总个数
print(weather_dataset[12:15])#切片,显示第12至第十五张图片的路径
wheather_datalodaer = torch.utils.data.DataLoader(weather_dataset, batch_size=5) #每次迭代时返回五个数据
print(next(iter(wheather_datalodaer)))


上图为运行结果

2.为每张图片制作相应的标签

代码如下:

species = ['cloud','sun','rain']
species_to_id = dict((c, i) for i, c in enumerate(species))
print(species_to_id)
id_to_species = dict((v, k) for k, v in species_to_id.items())
print(id_to_species)
all_labels = []
#对所有图片路径进行迭代
for img in all_imgs_path:
    # 区分出每个img,应该属于什么类别
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)
print(all_labels) #得到所有标签


上图为运行结果

3.完善Mydataset类,将图片数据转换成Tensor,并展示部分图片与标签对应关系

代码如下:

# 对数据进行转换处理
transform = transforms.Compose([
                transforms.Resize((96,96)), #做的第一步转换
                transforms.ToTensor() #第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
])

class Mydatasetpro(data.Dataset):
# 类初始化
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
# 进行切片
    def __getitem__(self, index):                #根据给出的索引进行切片,并对其进行数据处理转换成Tensor,返回成Tensor
        img = self.imgs[index]
        label = self.labels[index]
        pil_img = Image.open(img)                 #pip install pillow
        data = self.transforms(pil_img)
        return data, label
# 返回长度
    def __len__(self):
        return len(self.imgs)

BATCH_SIZE = 10
weather_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
wheather_datalodaer = data.DataLoader(
                            weather_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True
)

imgs_batch, labels_batch = next(iter(wheather_datalodaer))
print(imgs_batch.shape)

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):
    img = img.permute(1, 2, 0).numpy()
    plt.subplot(2, 3, i+1)
    plt.title(id_to_species.get(label.item()))
    plt.imshow(img)
plt.show()#展示图片


上图为运行结果


4.划分数据集和测试集

代码如下:


#划分测试集和训练集
index = np.random.permutation(len(all_imgs_path))

all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

#80% as train
s = int(len(all_imgs_path)*0.8)
print(s)

train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_imgs_path[s:]

train_ds = Mydatasetpro(train_imgs, train_labels, transform) #TrainSet TensorData
test_ds = Mydatasetpro(test_imgs, test_labels, transform) #TestSet TensorData
train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)#TrainSet Labels
test_dl = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True)#TestSet Labels

至此我们把原数据集的80%作为训练集得到了:
train_ds 训练集数据
test_ds 测试集数据
train_dl 训练集标签
test_dl 测试集标签


总结

整体思路
1.将自己所找图片按照一定的规则命名后,放到文件夹中。
2.使用glob方法获取所有数据文件路径
3.创建DataSet类的子类Mydataset,用于后续通过路径读入数据,并易于后续相应处理操作
4.通过Transforms.Compose()方法,对图片数据进行统一处理,并转换成Tensor格式
5.创建Mydatasetpro类,调用相关方法,获得{‘图片名’:标签}和{标签:‘图片名’}字典
6.统计索引数量,按照百分比,划分出训练集和测试集
最后得到:训练集数据、测试集数据、训练集标签、测试集标签

本文代码

import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt

#通过创建data.Dataset子类Mydataset来创建输入
class Mydataset(data.Dataset):
# 类初始化
    def __init__(self, root):
        self.imgs_path = root
# 进行切片
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        return img_path
# 返回长度
    def __len__(self):
        return len(self.imgs_path)

#使用glob方法来获取数据图片的所有路径
all_imgs_path = glob.glob(r'F:\weather\*\*.jpg')
#循环遍历输出列表中的每个元素,显示出每个图片的路径
for var in all_imgs_path:
    print(var)


#利用自定义类Mydataset创建对象weather_dataset
weather_dataset = Mydataset(all_imgs_path)
print(len(weather_dataset)) #返回文件夹中图片总个数
print(weather_dataset[12:14])#切片,显示第12张、第十三张图片,python左闭右开
wheather_datalodaer = torch.utils.data.DataLoader(weather_dataset, batch_size=3) #每次迭代时返回五个数据
print(next(iter(wheather_datalodaer)))

species = ['cloud','sun','rain']
species_to_id = dict((c, i) for i, c in enumerate(species))
print(species_to_id)
id_to_species = dict((v, k) for k, v in species_to_id.items())
print(id_to_species)
all_labels = []
#对所有图片路径进行迭代
for img in all_imgs_path:
    # 区分出每个img,应该属于什么类别
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)
print(all_labels) #得到所有标签

# 对数据进行转换处理
transform = transforms.Compose([
                transforms.Resize((96,96)), #做的第一步转换
                transforms.ToTensor() #第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
])

class Mydatasetpro(data.Dataset):
# 类初始化
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
# 进行切片
    def __getitem__(self, index):                #根据给出的索引进行切片,并对其进行数据处理转换成Tensor,返回成Tensor
        img = self.imgs[index]
        label = self.labels[index]
        pil_img = Image.open(img)                 #pip install pillow
        data = self.transforms(pil_img)
        return data, label
# 返回长度
    def __len__(self):
        return len(self.imgs)

BATCH_SIZE = 10
weather_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
wheather_datalodaer = data.DataLoader(
                            weather_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True
)

imgs_batch, labels_batch = next(iter(wheather_datalodaer))
print(imgs_batch.shape)

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):
    img = img.permute(1, 2, 0).numpy()
    plt.subplot(2, 3, i+1)
    plt.title(id_to_species.get(label.item()))
    plt.imshow(img)
plt.show()

#划分测试集和训练集
index = np.random.permutation(len(all_imgs_path))

all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

#80% as train
s = int(len(all_imgs_path)*0.8)
print(s)

train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]

train_ds = Mydatasetpro(train_imgs, train_labels, transform) #TrainSet TensorData
test_ds = Mydatasetpro(test_imgs, test_labels, transform) #TestSet TensorData
train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)#TrainSet Labels
test_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)#TestSet Labels

有关使用Pytorch框架自己制作做数据集进行图像分类(一)的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 使用 RubyZip 生成 ZIP 文件时设置压缩级别 - 2

    我有一个Ruby程序,它使用rubyzip压缩XML文件的目录树。gem。我的问题是文件开始变得很重,我想提高压缩级别,因为压缩时间不是问题。我在rubyzipdocumentation中找不到一种为创建的ZIP文件指定压缩级别的方法。有人知道如何更改此设置吗?是否有另一个允许指定压缩级别的Ruby库? 最佳答案 这是我通过查看ruby​​zip内部创建的代码。level=Zlib::BEST_COMPRESSIONZip::ZipOutputStream.open(zip_file)do|zip|Dir.glob("**/*")d

  3. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  4. ruby-on-rails - 使用 Ruby on Rails 进行自动化测试 - 最佳实践 - 2

    很好奇,就使用ruby​​onrails自动化单元测试而言,你们正在做什么?您是否创建了一个脚本来在cron中运行rake作业并将结果邮寄给您?git中的预提交Hook?只是手动调用?我完全理解测试,但想知道在错误发生之前捕获错误的最佳实践是什么。让我们理所当然地认为测试本身是完美无缺的,并且可以正常工作。下一步是什么以确保他们在正确的时间将可能有害的结果传达给您? 最佳答案 不确定您到底想听什么,但是有几个级别的自动代码库控制:在处理某项功能时,您可以使用类似autotest的内容获得关于哪些有效,哪些无效的即时反馈。要确保您的提

  5. ruby - 在 Ruby 中使用匿名模块 - 2

    假设我做了一个模块如下:m=Module.newdoclassCendend三个问题:除了对m的引用之外,还有什么方法可以访问C和m中的其他内容?我可以在创建匿名模块后为其命名吗(就像我输入“module...”一样)?如何在使用完匿名模块后将其删除,使其定义的常量不再存在? 最佳答案 三个答案:是的,使用ObjectSpace.此代码使c引用你的类(class)C不引用m:c=nilObjectSpace.each_object{|obj|c=objif(Class===objandobj.name=~/::C$/)}当然这取决于

  6. ruby - 使用 ruby​​ 和 savon 的 SOAP 服务 - 2

    我正在尝试使用ruby​​和Savon来使用网络服务。测试服务为http://www.webservicex.net/WS/WSDetails.aspx?WSID=9&CATID=2require'rubygems'require'savon'client=Savon::Client.new"http://www.webservicex.net/stockquote.asmx?WSDL"client.get_quotedo|soap|soap.body={:symbol=>"AAPL"}end返回SOAP异常。检查soap信封,在我看来soap请求没有正确的命名空间。任何人都可以建议我

  7. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  8. ruby-on-rails - 'compass watch' 是如何工作的/它是如何与 rails 一起使用的 - 2

    我在我的项目目录中完成了compasscreate.和compassinitrails。几个问题:我已将我的.sass文件放在public/stylesheets中。这是放置它们的正确位置吗?当我运行compasswatch时,它不会自动编译这些.sass文件。我必须手动指定文件:compasswatchpublic/stylesheets/myfile.sass等。如何让它自动运行?文件ie.css、print.css和screen.css已放在stylesheets/compiled。如何在编译后不让它们重新出现的情况下删除它们?我自己编译的.sass文件编译成compiled/t

  9. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

  10. ruby - 在 64 位 Snow Leopard 上使用 rvm、postgres 9.0、ruby 1.9.2-p136 安装 pg gem 时出现问题 - 2

    我想为Heroku构建一个Rails3应用程序。他们使用Postgres作为他们的数据库,所以我通过MacPorts安装了postgres9.0。现在我需要一个postgresgem并且共识是出于性能原因你想要pggem。但是我对我得到的错误感到非常困惑当我尝试在rvm下通过geminstall安装pg时。我已经非常明确地指定了所有postgres目录的位置可以找到但仍然无法完成安装:$envARCHFLAGS='-archx86_64'geminstallpg--\--with-pg-config=/opt/local/var/db/postgresql90/defaultdb/po

随机推荐