草庐IT

基于pytorch的图像识别基础完整教程

@________ 2023-04-10 原文

一、数据集爬取

现在的深度学习对数据集量的需求越来越大了,也有了许多现成的数据集可供大家查找下载,但是如果你只是想要做一下深度学习的实例以此熟练一下或者找不到好的数据集,那么你也可以尝试自己制作数据集——自己从网上爬取图片,下面是通过百度图片爬取数据的示例。

import os
import time
import requests
import re
def imgdata_set(save_path,word,epoch):
    q=0     #停止爬取图片条件
    a=0     #图片名称
    while(True):
        time.sleep(1)
        url="https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word={}&pn={}&ct=&ic=0&lm=-1&width=0&height=0".format(word,q)
        #word=需要搜索的名字
        headers={
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Edg/88.0.705.56'
        }
        response=requests.get(url,headers=headers)
        # print(response.request.headers)
        html=response.text
        # print(html)
        urls=re.findall('"objURL":"(.*?)"',html)
        # print(urls)
        for url in urls:
            print(a)    #图片的名字
            response = requests.get(url, headers=headers)
            image=response.content
            with open(os.path.join(save_path,"{}.jpg".format(a)),'wb') as f:
                f.write(image)
            a=a+1
        q=q+20
        if (q/20)>=int(epoch):
            break
if __name__=="__main__":
    save_path = input('你想保存的路径:')
    word = input('你想要下载什么图片?请输入:')
    epoch = input('你想要下载几轮图片?请输入(一轮为60张左右图片):')  # 需要迭代几次图片
    imgdata_set(save_path, word, epoch)

通过上述的代码可以自行选择自己需要保存的图片路径、图片种类和图片数目。如我下面做的几种常见的盆栽植物的图片爬取,只需要执行六次代码,改变相应的盆栽植物的名称就可以了。下面是爬取盆栽芦荟的输入示例,输入完成后按Enter执行即可,会自动爬取图片保存到指定文件夹,

如图即为爬取后的图片。

可以看到图片中出现了一些无法打开的图片,同时因为是直接爬取的网络上的图片,可能会出现一些相同的图片,这些都需要进行删除,这就需要我们进行第二步处理了。

二、数据处理

由于上面直接爬取到的图片有一些瑕疵,这就需要对图片进行进一步的处理了,对图片进行去重处理,可以参考下面链接:
文件夹去除重复图片
通过重复图片去重处理,将自己需要的数据集按照种类分别保存在各自的文件夹里。同样,由于数据集可能存在无法打开的图片,这就需要对数据集进行下一步处理了。
首先将上面去重处理后的文件夹统一保存在同一个文件夹里面,如下图所示。

记住此文件夹路径,我这里是‘C:\Users\Lenovo\Desktop\data’,将此路径输入到下面代码中。

import os
from PIL import Image
root_path=r"C:\Users\Lenovo\Desktop\data"   #待处理文件夹绝对路径(可按‘Ctrl+Shift+c’复制)
root_names=os.listdir(root_path)

for root_name in root_names:
    path=os.path.join(root_path,root_name)
    print("正在删除文件夹:",path)
    names=os.listdir(path)
    names_path=[]
    for name in names:
        # print(name)
        img=Image.open(os.path.join(path,name))
        name_path=os.path.join(path,name)
        if img==None:           #筛选无法打开的图片
            names_path.append(name_path)
            print('成功保存错误图片路径:{}'.format(name))
        else:
            w,h=img.size
            if w<50 or h<50:    #筛选错误图片
                names_path.append(name_path)
                print('成功保存特小图片路径:{}'.format(name))
    print("开始删除需删除的图片")
    for r in names_path:
        os.remove(r)
        print("已删除:",r)

经过上述处理即完成了图片数据集的处理。最后,也可以对图片数据集进行图片名称的处理,使图片的名称重新从零开始依次排列,方便计数(注意下面代码中的rename将会删除掉原文件夹中的图片)。

import os
root_dir=r"C:\Users\Lenovo\Desktop\pzlh"    #原文件夹路径
save_path=r"C:\Users\Lenovo\Desktop\pzlh2"  #新建文件夹路径
img_path=os.listdir(root_dir)
a=0
for i in img_path:
    a+=1
    i= os.path.join(os.path.abspath(root_dir), i)
    new_name=os.path.join(os.path.abspath(save_path), str(a) + '_pzlh.jpg')    #此处可以修改图片名称
    os.rename(i,new_name)       #特别注意:rename会删除原图

最后,我们可以得到一个将完整的常见盆栽植物的数据集。如果此时数据集的图片数量不多,我们还可以采用数据增强的方法,如旋转,加噪等步骤,都可以在网上找到相应的教程。最后,我们可以得到数据集如下图所示。

三、开始识别

首先,先为上面的图片数据集生成对应的标签文件,运行下面代码可以自动生成对应的标签文件。

import os
root_path=r"C:\Users\Lenovo\Desktop\data"
save_path=r"C:\Users\Lenovo\Desktop\data_label" #对应的label文件夹下也要建好相应的空子文件夹
names=os.listdir(root_path) #得到images文件夹下的子文件夹的名称
for name in names:
    path=os.path.join(root_path,name)
    img_names=os.listdir(path)  #得到子文件夹下的图片的名称
    for img_name in img_names:
        save_name = img_name.split(".jpg")[0]+'.txt'    #得到相应的lable名称
        txt_path=os.path.join(save_path,name)           #得到label的子文件夹的路径
        with open(os.path.join(txt_path,save_name), "w") as f:  #结合子文件夹路径和相应子文件夹下图片的名称生成相应的子文件夹txt文件
            f.write(name)       #将label写入对应txt文件夹
            print(f.name)

然后,将上面已经准备好的数据集按照7:3(其他比例也可以)分为训练数据集和验证数据集(图片和标签一定要完全对应即对应图片和标签应该都处于训练集或者数据集),并如下图所示放置。

最后,数据集准备好后,即可导入到模型开始训练,运行下列代码

import time
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn as nn
import torch

print("是否使用GPU训练:{}".format(torch.cuda.is_available()))    #打印是否采用gpu训练
if torch.cuda.is_available:
    print("GPU名称为:{}".format(torch.cuda.get_device_name()))  #打印相应的gpu信息
#数据增强太多也可能造成训练出不好的结果,而且耗时长,宜增强两三倍即可。
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])  #规范化
transform=transforms.Compose([                                  #数据处理
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    normalize
])
dataset_train=ImageFolder('data/train',transform=transform)     #训练数据集
# print(dataset_tran[0])
dataset_valid=ImageFolder('data/valid',transform=transform)     #验证或测试数据集
# print(dataset_train.classer)#返回类别
print(dataset_train.class_to_idx)                               #返回类别及其索引
# print(dataset_train.imgs)#返回图片路径
print(dataset_valid.class_to_idx)
train_data_size=len(dataset_train)                              #放回数据集长度
test_data_size=len(dataset_valid)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
#torch自带的标准数据集加载函数
dataloader_train=DataLoader(dataset_train,batch_size=4,shuffle=True,num_workers=0,drop_last=True)
dataloader_test=DataLoader(dataset_valid,batch_size=4,shuffle=True,num_workers=0,drop_last=True)

#2.模型加载
model_ft=models.resnet18(pretrained=True)#使用迁移学习,加载预训练权重
# print(model_ft)

in_features=model_ft.fc.in_features
model_ft.fc=nn.Sequential(nn.Linear(in_features,36),
                          nn.Linear(36,6))#将最后的全连接改为(36,6),使输出为六个小数,对应六种植物的置信度
#冻结卷积层函数
# for i,para in enumerate(model_ft.parameters()):
#     if i<18:
#         para.requires_grad=False

# print(model_ft)


# model_ft.half()#可改为半精度,加快训练速度,在这里不适用

model_ft=model_ft.cuda()#将模型迁移到gpu
#3.优化器
loss_fn=nn.CrossEntropyLoss()

loss_fn=loss_fn.cuda()  #将loss迁移到gpu
learn_rate=0.01         #设置学习率
optimizer=torch.optim.SGD(model_ft.parameters(),lr=learn_rate,momentum=0.01)#可调超参数

total_train_step=0
total_test_step=0
epoch=50                #迭代次数
writer=SummaryWriter("logs_train_yaopian")
best_acc=-1
ss_time=time.time()

for i in range(epoch):
    start_time = time.time()
    print("--------第{}轮训练开始---------".format(i+1))
    model_ft.train()
    for data in dataloader_train:
        imgs,targets=data
        # if torch.cuda.is_available():
        # imgs.float()
        # imgs=imgs.float()#为上述改为半精度操作,在这里不适用
        imgs=imgs.cuda()
        targets=targets.cuda()
        # imgs=imgs.half()
        outputs=model_ft(imgs)
        loss=loss_fn(outputs,targets)

        optimizer.zero_grad()   #梯度归零
        loss.backward()         #反向传播计算梯度
        optimizer.step()        #梯度优化

        total_train_step=total_train_step+1
        if total_train_step%100==0:#一轮时间过长可以考虑加一个
            end_time=time.time()
            print("使用GPU训练100次的时间为:{}".format(end_time-start_time))
            print("训练次数:{},loss:{}".format(total_train_step,loss.item()))
            # writer.add_scalar("valid_loss",loss.item(),total_train_step)
    model_ft.eval()
    total_test_loss=0
    total_accuracy=0
    with torch.no_grad():       #验证数据集时禁止反向传播优化权重
        for data in dataloader_test:
            imgs,targets=data
            # if torch.cuda.is_available():
            # imgs.float()
            # imgs=imgs.float()
            imgs = imgs.cuda()
            targets = targets.cuda()
            # imgs=imgs.half()
            outputs=model_ft(imgs)
            loss=loss_fn(outputs,targets)
            total_test_loss=total_test_loss+loss.item()
            accuracy=(outputs.argmax(1)==targets).sum()
            total_accuracy=total_accuracy+accuracy
        print("整体测试集上的loss:{}(越小越好,与上面的loss无关此为测试集的总loss)".format(total_test_loss))
        print("整体测试集上的正确率:{}(越大越好)".format(total_accuracy / len(dataset_valid)))

        writer.add_scalar("valid_loss",(total_accuracy/len(dataset_valid)),(i+1))#选择性使用哪一个
        total_test_step = total_test_step + 1
        if total_accuracy > best_acc:   #保存迭代次数中最好的模型
            print("已修改模型")
            best_acc = total_accuracy
            torch.save(model_ft, "best_model_yaopian.pth")
ee_time=time.time()
zong_time=ee_time-ss_time
print("训练总共用时:{}h:{}m:{}s".format(int(zong_time//3600),int((zong_time%3600)//60),int(zong_time%60))) #打印训练总耗时
writer.close()

上述采用的迁移学习直接使用resnet18的模型进行训练,只对全连接的输出进行修改,是一种十分方便且实用的方法,同样,你也可以自己编写模型,然后使用自己的模型进行训练,但是这种方法显然需要训练更长的时间才能达到拟合。如图所示,只需要修改矩形框内部分,将‘model_ft=models.resnet18(pretrained=True)'改为自己的模型‘model_ft=model’即可。

四、模型测试

经过上述的步骤后,我们将会得到一个‘best_model_yaopian.pth’的模型权重文件,最后运行下列代码就可以对图片进行识别了

import os
import torch
import torchvision
from PIL import Image
from torch import nn
i=0 #识别图片计数
root_path="测试_data"         #待测试文件夹
names=os.listdir(root_path)
for name in names:
    print(name)
    i=i+1
    data_class=['滴水观音','发财树','非洲茉莉','君子兰','盆栽芦荟','文竹']   #按文件索引顺序排列
    image_path=os.path.join(root_path,name)             
    image=Image.open(image_path)
    print(image)
    transforms=torchvision.transforms.Compose([torchvision.transforms.Resize((64,64)),
                                              torchvision.transforms.ToTensor()])
    image=transforms(image)
    print(image.shape)

    model_ft=torchvision.models.resnet18()      #需要使用训练时的相同模型
    # print(model_ft)
    in_features=model_ft.fc.in_features
    model_ft.fc=nn.Sequential(nn.Linear(in_features,36),
                              nn.Linear(36,6))     #此处也要与训练模型一致

    model=torch.load("best_model_yaopian.pth",map_location=torch.device("cpu")) #选择训练后得到的模型文件
    # print(model)
    image=torch.reshape(image,(1,3,64,64))      #修改待预测图片尺寸,需要与训练时一致
    model.eval()
    with torch.no_grad():
        output=model(image)
    print(output)               #输出预测结果
    # print(int(output.argmax(1)))
    print("第{}张图片预测为:{}".format(i,data_class[int(output.argmax(1))]))   #对结果进行处理,使直接显示出预测的植物种类

最后,通过上述步骤我们可以得到一个简单的盆栽植物智能识别程序,对盆栽植物进行识别,如下图是识别结果说明。

到这里,我们就实现了一个简单的深度学习图像识别示例了。

有关基于pytorch的图像识别基础完整教程的更多相关文章

  1. ruby-on-rails - 添加回形针新样式不影响旧上传的图像 - 2

    我有带有Logo图像的公司模型has_attached_file:logo我用他们的Logo创建了许多公司。现在,我需要添加新样式has_attached_file:logo,:styles=>{:small=>"30x15>",:medium=>"155x85>"}我是否应该重新上传所有旧数据以重新生成新样式?我不这么认为……或者有什么rake任务可以重新生成样式吗? 最佳答案 参见Thumbnail-Generation.如果rake任务不适合你,你应该能够在控制台中使用一个片段来调用重新处理!关于相关公司

  2. 报告回顾丨模型进化狂飙,DetectGPT能否识别最新模型生成结果? - 2

    导读语言模型给我们的生产生活带来了极大便利,但同时不少人也利用他们从事作弊工作。如何规避这些难辨真伪的文字所产生的负面影响也成为一大难题。在3月9日智源Live第33期活动「DetectGPT:判断文本是否为机器生成的工具」中,主讲人Eric为我们讲解了DetectGPT工作背后的思路——一种基于概率曲率检测的用于检测模型生成文本的工具,它可以帮助我们更好地分辨文章的来源和可信度,对保护信息真实、防止欺诈等方面具有重要意义。本次报告主要围绕其功能,实现和效果等展开。(文末点击“阅读原文”,查看活动回放。)Ericmitchell斯坦福大学计算机系四年级博士生,由ChelseaFinn和Chri

  3. 叮咚买菜基于 Apache Doris 统一 OLAP 引擎的应用实践 - 2

    导读:随着叮咚买菜业务的发展,不同的业务场景对数据分析提出了不同的需求,他们希望引入一款实时OLAP数据库,构建一个灵活的多维实时查询和分析的平台,统一数据的接入和查询方案,解决各业务线对数据高效实时查询和精细化运营的需求。经过调研选型,最终引入ApacheDoris作为最终的OLAP分析引擎,Doris作为核心的OLAP引擎支持复杂地分析操作、提供多维的数据视图,在叮咚买菜数十个业务场景中广泛应用。作者|叮咚买菜资深数据工程师韩青叮咚买菜创立于2017年5月,是一家专注美好食物的创业公司。叮咚买菜专注吃的事业,为满足更多人“想吃什么”而努力,通过美好食材的供应、美好滋味的开发以及美食品牌的孵

  4. [Vuforia]二.3D物体识别 - 2

    之前说过10之后的版本没有3dScan了,所以还是9.8的版本或者之前更早的版本。 3d物体扫描需要先下载扫描的APK进行扫面。首先要在手机上装一个扫描程序,扫描现实中的三维物体,然后上传高通官网,在下载成UnityPackage类型让Unity能够使用这个扫描程序可以从高通官网上进行下载,是一个安卓程序。点到Tools往下滑,找到VuforiaObjectScanner下载后解压数据线连接手机,将apk文件拷入手机安装然后刚才解压文件中的Media文件夹打开,两个PDF图打印第一张A4-ObjectScanningTarget.pdf,主要是用来辅助扫描的。好了,接下来就是扫描三维物体。将瓶

  5. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  6. postman接口测试工具-基础使用教程 - 2

    1.postman介绍Postman一款非常流行的API调试工具。其实,开发人员用的更多。因为测试人员做接口测试会有更多选择,例如Jmeter、soapUI等。不过,对于开发过程中去调试接口,Postman确实足够的简单方便,而且功能强大。2.下载安装官网地址:https://www.postman.com/下载完成后双击安装吧,安装过程极其简单,无需任何操作3.使用教程这里以百度为例,工具使用简单,填写URL地址即可发送请求,在下方查看响应结果和响应状态码常用方法都有支持请求方法:getpostputdeleteGet、Post、Put与Delete的作用get:请求方法一般是用于数据查询,

  7. 软件测试基础 - 2

    Ⅰ软件测试基础一、软件测试基础理论1、软件测试的必要性所有的产品或者服务上线都需要测试2、测试的发展过程3、什么是软件测试找bug,发现缺陷4、测试的定义使用人工或自动的手段来运行或者测试某个系统的过程。目的在于检测它是否满足规定的需求。弄清预期结果和实际结果的差别。5、测试的目的以最小的人力、物力和时间找出软件中潜在的错误和缺陷6、测试的原则28原则:20%的主要功能要重点测(eg:支付宝的支付功能,其他功能都是次要的)80%的错误存在于20%的代码中7、测试标准8、测试的基本要求功能测试性能测试安全性测试兼容性测试易用性测试外观界面测试可靠性测试二、质量模型衡量一个优秀软件的维度①功能性功

  8. ruby-on-rails - 在 heroku 的 .fonts 文件夹中包含自定义字体,似乎无法识别它们 - 2

    Heroku支持人员告诉我,为了在我的Web应用程序中使用自定义字体(未安装在系统中,您可以在bash控制台中使用fc-list查看已安装的字体)我必须部署一个包含所有字体的.fonts文件夹里面的字体。问题是我不知道该怎么做。我的意思是,我不知道文件名是否必须遵循heroku的任何特殊模式,或者我必须在我的代码中做一些事情来考虑这种字体,或者如果我将它包含在文件夹中它是自动的......事实是,我尝试以不同的方式更改字体的文件名,但根本没有使用该字体。为了提供更多详细信息,我们使用字体的过程是将PDF转换为图像,更具体地说,使用rghostgem。并且最终图像根本不使用自定义字体。在

  9. 在VMware16虚拟机安装Ubuntu详细教程 - 2

    在VMware16.2.4安装Ubuntu一、安装VMware1.打开VMwareWorkstationPro官网,点击即可进入。2.进入后向下滑动找到Workstation16ProforWindows,点击立即下载。3.下载完成,文件大小615MB,如下图:4.鼠标右击,以管理员身份运行。5.点击下一步6.勾选条款,点击下一步7.先勾选,再点击下一步8.去掉勾选,点击下一步9.点击下一步10.点击安装11.点击许可证12.在百度上搜索VM16许可证,复制填入,然后点击输入即可,亲测有效。13.点击完成14.重启系统,点击是15.双击VMwareWorkstationPro图标,进入虚拟机主

  10. kvm虚拟机安装centos7基于ubuntu20.04系统 - 2

    需求:要创建虚拟机,就需要给他提供一个虚拟的磁盘,我们就在/opt目录下创建一个10G大小的raw格式的虚拟磁盘CentOS-7-x86_64.raw命令格式:qemu-imgcreate-f磁盘格式磁盘名称磁盘大小qemu-imgcreate-f磁盘格式-o?1.创建磁盘qemu-imgcreate-fraw/opt/CentOS-7-x86_64.raw10G执行效果#ls/opt/CentOS-7-x86_64.raw2.安装虚拟机使用virt-install命令,基于我们提供的系统镜像和虚拟磁盘来创建一个虚拟机,另外在创建虚拟机之前,提前打开vnc客户端,在创建虚拟机的时候,通过vnc

随机推荐