草庐IT

使用TensorFlow训练图像分类模型的指南

陈峻 2023-03-29 原文

译者 | 陈峻

审校 | 孙淑娟

众所周知,人类在很小的时候就学会了识别和标记自己所看到的事物。如今,随着机器学习和深度学习算法的不断迭代,计算机已经能够以非常高的精度,对捕获到的图像进行大规模的分类了。目前,此类先进算法的应用场景已经涵括到了包括:解读肺部扫描影像是否健康,通过移动设备进行面部识别,以及为零售商区分不同的消费对象类型等领域。

下面,我将和您共同探讨计算机视觉(Computer Vision)的一种应用——图像分类,并逐步展示如何使用TensorFlow,在小型图像数据集上进行模型的训练。

1、数据集和目标

在本示例中,我们将使用MNIST数据集的从0到9的数字图像。其形态如下图所示:

我们训练该模型的目的是为了将图像分类到其各自的标签下,即:它们在上图中各自对应的数字处。通常,深度神经网络架构会提供一个输入、一个输出、两个隐藏层(Hidden Layers)和一个用于训练模型的Dropout层。而CNN或卷积神经网络(Convolutional Neural Network)是识别较大图像的首选,它能够在减少输入量的同时,捕获到相关的信息。

2、准备工作

首先,让我们通过TensorFlow、to_categorical(用于将数字类的值转换为其他类别)、Sequential、Flatten、Dense、以及用于构建神经网络架构的 Dropout,来导入所有相关的代码库。您可能会对此处提及的部分代码库略感陌生。我会在下文中对它们进行详细的解释。

3、超参数

  • 我将通过如下方面,来选择正确的超参数集:

  • 首先,让我们定义一些超参数作为起点。后续,您可以针对不同的需求,对其进行调整。在此,我选择了128作为较小的批量尺寸(batch size)。其实,批量尺寸可以取任何值,但是2的幂次方大小往往能够提高内存的效率,因此应作为首选。值得注意的是,在决定合适的批量尺寸时,其背后的主要参考依据是:过小的批量尺寸会使收敛过于繁琐,而过大的批量尺寸则可能并不适合您的计算机内存。
  • 让我们将epoch(训练集中每一个样本都参与一次训练)的数量保持为50 ,以实现对模型的快速训练。epoch数值越低,越适合小而简单的数据集。
  • 接着,您需要添加隐藏层。在此,我为每个隐藏层都保留了128个神经元。当然,你也可以用64和32个神经元进行测试。就本例而言,像MINST这样的简单数据集,我并不建议使用较高的数值。
  • 您可以尝试不同的学习率(learning rate),例如0.01、0.05和0.1。在本例中,我将其保持为0.01。
  • 对于其他超参数,我将衰减步骤(decay steps)和衰减率(decay rate)分别选择为2000和0.9。而随着训练的进行,它们可以被用来降低学习率。
  • 在此,我选择Adamax作为优化器。当然,您也可以选择诸如Adam、RMSProp、SGD等其他优化器。
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout
params = {
'dropout': 0.25,
'batch-size': 128,
'epochs': 50,
'layer-1-size': 128,
'layer-2-size': 128,
'initial-lr': 0.01,
'decay-steps': 2000,
'decay-rate': 0.9,
'optimizer': 'adamax'
}
mnist = tf.keras.datasets.mnist
num_class = 10
# split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# reshape and normalize the data
x_train = x_train.reshape(60000, 784).astype("float32")/255
x_test = x_test.reshape(10000, 784).astype("float32")/255
# convert class vectors to binary class matrices
y_train = to_categorical(y_train, num_class)
y_test = to_categorical(y_test, num_class)

4、创建训练和测试集

由于TensorFlow库也包括了MNIST数据集,因此您可以通过调用对象上的 datasets.mnist ,再调用load_data() 的方法,来分别获取训练(60,000个样本)和测试(10,000个样本)的数据集。

接着,您需要对训练和测试的图像进行整形和归一化。其中,归一化会将图像的像素强度限制在0和1之间。

最后,我们使用之前已导入的to_categorical 方法,将训练和测试标签转换为已分类标签。这对于向TensorFlow框架传达输出的标签(即:0到9)为类(class),而不是数字类型,是非常重要的。

5、设计神经网络架构

下面,让我们来了解如何在细节上设计神经网络架构。

我们通过添加Flatten ,将2D图像矩阵转换为向量,以定义DNN(深度神经网络)的结构。输入的神经元在此处对应向量中的数字。

接着,我使用Dense() 方法,添加两个隐藏的密集层,并从之前已定义的“params”字典中提取各项超参数。我们可以将“relu”(Rectified Linear Unit)作为这些层的激活函数。它是神经网络隐藏层中最常用的激活函数之一。

然后,我们使用Dropout方法添加Dropout层。它将被用于在训练神经网络时,避免出现过拟合(overfitting)。毕竟,过度拟合模型倾向于准确地记住训练集,并且无法泛化那些不可见(unseen)的数据集。

输出层是我们网络中的最后一层,它是使用Dense() 方法来定义的。需要注意的是,输出层有10个神经元,这对应于类(数字)的数量。

# Model Definition
# Get parameters from logged hyperparameters
model = Sequential([
Flatten(input_shape=(784, )),
Dense(params('layer-1-size'), activatinotallow='relu'),
Dense(params('layer-2-size'), activatinotallow='relu'),
Dropout(params('dropout')),
Dense(10)
])
lr_schedule =
tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=experiment.get_parameter('initial-lr'),
decay_steps=experiment.get_parameter('decay-steps'),
decay_rate=experiment.get_parameter('decay-rate')
)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adamax',
loss=loss_fn,
metrics=['accuracy'])
model.fit(x_train, y_train,
batch_size=experiment.get_parameter('batch-size'),
epochs=experiment.get_parameter('epochs'),
validation_data=(x_test, y_test),)
score = model.evaluate(x_test, y_test)
# Log Model
model.save('tf-mnist-comet.h5')

6、训练

至此,我们已经定义好了架构。下面让我们用给定的训练数据,来编译和训练神经网络。

首先,我们以初始学习率、衰减步骤和衰减率作为参数,使用ExponentialDecay(指数衰减学习率)来定义学习率计划。

其次,将损失函数定义为CategoricalCrossentropy(用于多类式分类)。

接着,通过将优化器 (即:adamax)、损失函数、以及各项指标(由于所有类都同等重要、且均匀分布,因此我选择了准确性)作为参数,来编译模型。

然后,我们通过使用x_train、y_train、batch_size、epochs和validation_data去调用一个拟合方法,并拟合出模型。

同时,我们调用模型对象的评估方法,以获得模型在不可见数据集上的表现分数。

最后,您可以使用在模型对象上调用的save方法,保存要在生产环境中部署的模型对象。

7、小结

综上所述,我们讨论了为图像分类任务,训练深度神经网络的一些入门级的知识。您可以将其作为熟悉使用神经网络,进行图像分类的一个起点。据此,您可了解到该如何选择正确的参数集、以及架构背后的思考逻辑。

原文链接:https://www.kdnuggets.com/2022/12/guide-train-image-classification-model-tensorflow.html

有关使用TensorFlow训练图像分类模型的指南的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 使用 RubyZip 生成 ZIP 文件时设置压缩级别 - 2

    我有一个Ruby程序,它使用rubyzip压缩XML文件的目录树。gem。我的问题是文件开始变得很重,我想提高压缩级别,因为压缩时间不是问题。我在rubyzipdocumentation中找不到一种为创建的ZIP文件指定压缩级别的方法。有人知道如何更改此设置吗?是否有另一个允许指定压缩级别的Ruby库? 最佳答案 这是我通过查看ruby​​zip内部创建的代码。level=Zlib::BEST_COMPRESSIONZip::ZipOutputStream.open(zip_file)do|zip|Dir.glob("**/*")d

  3. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  4. ruby-on-rails - 使用 Ruby on Rails 进行自动化测试 - 最佳实践 - 2

    很好奇,就使用ruby​​onrails自动化单元测试而言,你们正在做什么?您是否创建了一个脚本来在cron中运行rake作业并将结果邮寄给您?git中的预提交Hook?只是手动调用?我完全理解测试,但想知道在错误发生之前捕获错误的最佳实践是什么。让我们理所当然地认为测试本身是完美无缺的,并且可以正常工作。下一步是什么以确保他们在正确的时间将可能有害的结果传达给您? 最佳答案 不确定您到底想听什么,但是有几个级别的自动代码库控制:在处理某项功能时,您可以使用类似autotest的内容获得关于哪些有效,哪些无效的即时反馈。要确保您的提

  5. ruby - 在 Ruby 中使用匿名模块 - 2

    假设我做了一个模块如下:m=Module.newdoclassCendend三个问题:除了对m的引用之外,还有什么方法可以访问C和m中的其他内容?我可以在创建匿名模块后为其命名吗(就像我输入“module...”一样)?如何在使用完匿名模块后将其删除,使其定义的常量不再存在? 最佳答案 三个答案:是的,使用ObjectSpace.此代码使c引用你的类(class)C不引用m:c=nilObjectSpace.each_object{|obj|c=objif(Class===objandobj.name=~/::C$/)}当然这取决于

  6. ruby - 使用 ruby​​ 和 savon 的 SOAP 服务 - 2

    我正在尝试使用ruby​​和Savon来使用网络服务。测试服务为http://www.webservicex.net/WS/WSDetails.aspx?WSID=9&CATID=2require'rubygems'require'savon'client=Savon::Client.new"http://www.webservicex.net/stockquote.asmx?WSDL"client.get_quotedo|soap|soap.body={:symbol=>"AAPL"}end返回SOAP异常。检查soap信封,在我看来soap请求没有正确的命名空间。任何人都可以建议我

  7. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  8. ruby-on-rails - Rails - 子类化模型的设计模式是什么? - 2

    我有一个模型:classItem项目有一个属性“商店”基于存储的值,我希望Item对象对特定方法具有不同的行为。Rails中是否有针对此的通用设计模式?如果方法中没有大的if-else语句,这是如何干净利落地完成的? 最佳答案 通常通过Single-TableInheritance. 关于ruby-on-rails-Rails-子类化模型的设计模式是什么?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.co

  9. ruby-on-rails - 'compass watch' 是如何工作的/它是如何与 rails 一起使用的 - 2

    我在我的项目目录中完成了compasscreate.和compassinitrails。几个问题:我已将我的.sass文件放在public/stylesheets中。这是放置它们的正确位置吗?当我运行compasswatch时,它不会自动编译这些.sass文件。我必须手动指定文件:compasswatchpublic/stylesheets/myfile.sass等。如何让它自动运行?文件ie.css、print.css和screen.css已放在stylesheets/compiled。如何在编译后不让它们重新出现的情况下删除它们?我自己编译的.sass文件编译成compiled/t

  10. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

随机推荐