草庐IT

Pytorch教程入门系列 10----优化器介绍

CV_Today 2023-04-09 原文

文章目录


前言

一、什么叫优化器

用于优化模型的参数。在选择优化器时,需要考虑模型的结构、模型的数据量、模型的目标函数等因素。
优化器是一种算法,用于训练模型并使模型的损失最小化。它通过不断更新模型的参数来实现这一目的。
优化器通常用于深度学习模型,因为这些模型通常具有大量可训练参数,并且需要大量数据和计算来优化。优化器通过不断更新模型的参数来拟合训练数据,从而使模型在新数据上表现良好。

二、优化器的种类介绍

1、SGD(Stochastic Gradient Descent)

  • 思想

SGD是一种经典的优化器,用于优化模型的参数。SGD的基本思想是,通过梯度下降的方法,不断调整模型的参数,使模型的损失函数最小化。SGD的优点是实现简单、效率高,缺点是收敛速度慢、容易陷入局部最小值。

  • 数学表达

通过如下的方式来更新模型的参数:

θ ( t + 1 ) = θ ( t ) − α ⋅ ∇ θ J ( θ ( t ) ) \theta^{(t+1)} = \theta^{(t)} - \alpha \cdot \nabla_{\theta} J(\theta^{(t)}) θ(t+1)=θ(t)αθJ(θ(t))
其中, θ ( t ) 表 示 模 型 在 第 \theta^{(t)}表示模型在第 θ(t)t$次迭代时的参数值, α \alpha α表示学习率, ∇ θ J ( θ ( t ) ) \nabla_{\theta} J(\theta^{(t)}) θJ(θ(t))表示损失函数 J ( θ ) J(\theta) J(θ)关于模型参数 θ \theta θ的梯度。

  • 实际使用

在PyTorch中,可以使用torch.optim.SGD类来实现SGD。

# 定义模型
model = ...

# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 训练模型
for inputs, labels in dataset:
    # 计算损失函数
    outputs = model(inputs)
    loss = ...

    # 计算梯度
    optimizer.zero_grad()
    loss.backward()

    # 更新参数
    optimizer.step()

首先定义了模型,然后定义了SGD优化器,并指定了学习率为0.1。接着,通过循环迭代数据集,计算损失函数和梯度,并更新模型的参数。通过这样的方式,就可以在PyTorch中使用SGD来训练模型了。

2、Adam

  • 思想

Adam是一种近似于随机梯度下降的优化器,用于优化模型的参数。Adam的基本思想是,通过维护模型的梯度和梯度平方的一阶动量和二阶动量,来调整模型的参数。Adam的优点是计算效率高,收敛速度快,缺点是需要调整超参数。

  • 数学表达

  • 通过如下的方式来更新模型的参数:

m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t mt=β1mt1+(1β1)gt

v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 vt=β2vt1+(1β2)gt2

其中, m t m_t mt v t v_t vt分别表示梯度的一阶动量和二阶动量, g t g_t gt表示模型在第 t t t次迭代时的梯度, β 1 \beta_1 β1 β 2 \beta_2 β2是超参数。

θ ( t + 1 ) = θ ( t ) − α v t + ϵ m t \theta^{(t+1)} = \theta^{(t)} - \frac{\alpha}{\sqrt{v_t} + \epsilon} m_t θ(t+1)=θ(t)vt +ϵαmt
其中, θ ( t ) \theta^{(t)} θ(t)表示模型在第 t t t次迭代时的参数值, α \alpha α表示学习率, m t m_t mt v t v_t vt分别表示梯度的一阶动量和二阶动量, ϵ \epsilon ϵ是一个小常数,用于防止分母为0。

  • 实际使用

在PyTorch中,可以使用torch.optim.Adam类来实现Adam。

# 定义模型
model = ...

# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))

# 训练模型
for inputs, labels in dataset:
    # 计算损失函数
    outputs = model(inputs)
    loss = ...

    # 计算梯度
    optimizer.zero_grad()
    loss.backward()

    # 更新参数
    optimizer.step()

上面的代码中,首先定义了模型,然后定义了Adam优化器,并指定了学习率为0.1, β 1 \beta_1 β1 β 2 \beta_2 β2的值分别为0.9和0.999。接着,通过循环迭代数据集,计算损失函数和梯度,并更新模型的参数。通过这样的方式,就可以在PyTorch中使用Adam来训练模型了。

3、RMSprop(Root Mean Square Propagation)

  • 思想

RMSprop是一种改进的随机梯度下降优化器,用于优化模型的参数。RMSprop的基本思想是,通过维护模型的梯度平方的指数加权平均,来调整模型的参数。RMSprop的优点是收敛速度快,缺点是计算复杂度高,需要调整超参数。

  • 数学表达

具体来说,RMSprop优化算法的公式如下:

g t + 1 = α g t + ( 1 − α ) g t 2 g_{t+1} = \alpha g_t + (1 - \alpha) g_t^2 gt+1=αgt+(1α)gt2

θ t + 1 = θ t − η g t + 1 + ϵ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{g_{t+1} + \epsilon}} θt+1=θtgt+1+ϵ η

其中, g t g_t gt表示模型在第 t t t次迭代中的梯度的平方和, θ t \theta_t θt表示模型在第 t t t次迭代中的参数值, α \alpha α表示梯度的指数衰减率, η \eta η表示学习率, ϵ \epsilon ϵ表示一个小常数,用于防止除数为0。

  • 实际使用

在PyTorch中,可以使用torch.optim.Adam类来实现Adam。

import torch

# 定义模型
model = MyModel()

# 如果可用则model移至GPU
if torch.cuda.is_available():
    model = model.cuda()

# 设定训练模式
model.train()
# 定义 RMSprop 优化器
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)

# 循环训练
for input, target in dataset:
    # 如果可用则将input、target移至GPU
    if torch.cuda.is_available():
        input = input.cuda()
        target = target.cuda()

    # 前向传递:通过将输入传递给模型来计算预测输出
    output = model(input)

    # 计算损失
    loss = loss_fn(output, target)

    # 清除所有优化变量的梯度
    optimizer.zero_grad()

    # 反向传递:计算损失相对于模型参数的梯度
    loss.backward()

    # 执行单个优化步骤(参数更新)
    optimizer.step()

上面的代码中,首先定义了模型,并将其转换为训练模式。然后定义了RMSprop优化器,并指定了要优化的模型参数,学习率为0.1, α \alpha α的值为0.9。接着,通过循环迭代数据集,计算损失函数和梯度,并更新模型的参数。通过这样的方式,就可以在PyTorch中使用RMSprop来训练模型了。


总结

除了上面提到的三种优化器,PyTorch还提供了多种优化器,比如Adadelta、Adagrad、AdamW、SparseAdam等。要使用优化器,需要定义模型并转换为训练模式,然后定义优化器并指定要优化的模型参数和学习率。在训练循环中,每次迭代都要计算模型的损失,然后使用优化器来更新模型参数。选择优化器时,需要根据实际情况选择合适的优化器。另外,优化器的超参数也需要适当调整,以获得较好的优化效果。

有关Pytorch教程入门系列 10----优化器介绍的更多相关文章

  1. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  2. ruby-on-rails - 使用一系列等级计算字母等级 - 2

    这里是Ruby新手。完成一些练习后碰壁了。练习:计算一系列成绩的字母等级创建一个方法get_grade来接受测试分数数组。数组中的每个分数应介于0和100之间,其中100是最大分数。计算平均分并将字母等级作为字符串返回,即“A”、“B”、“C”、“D”、“E”或“F”。我一直返回错误:avg.rb:1:syntaxerror,unexpectedtLBRACK,expecting')'defget_grade([100,90,80])^avg.rb:1:syntaxerror,unexpected')',expecting$end这是我目前所拥有的。我想坚持使用下面的方法或.join,

  3. 【鸿蒙应用开发系列】- 获取系统设备信息以及版本API兼容调用方式 - 2

    在应用开发中,有时候我们需要获取系统的设备信息,用于数据上报和行为分析。那在鸿蒙系统中,我们应该怎么去获取设备的系统信息呢,比如说获取手机的系统版本号、手机的制造商、手机型号等数据。1、获取方式这里分为两种情况,一种是设备信息的获取,一种是系统信息的获取。1.1、获取设备信息获取设备信息,鸿蒙的SDK包为我们提供了DeviceInfo类,通过该类的一些静态方法,可以获取设备信息,DeviceInfo类的包路径为:ohos.system.DeviceInfo.具体的方法如下:ModifierandTypeMethodDescriptionstatic StringgetAbiList​()Obt

  4. Unity 热更新技术 | (三) Lua语言基本介绍及下载安装 - 2

    ?博客主页:https://xiaoy.blog.csdn.net?本文由呆呆敲代码的小Y原创,首发于CSDN??学习专栏推荐:Unity系统学习专栏?游戏制作专栏推荐:游戏制作?Unity实战100例专栏推荐:Unity实战100例教程?欢迎点赞?收藏⭐留言?如有错误敬请指正!?未来很长,值得我们全力奔赴更美好的生活✨------------------❤️分割线❤️-------------------------

  5. postman接口测试工具-基础使用教程 - 2

    1.postman介绍Postman一款非常流行的API调试工具。其实,开发人员用的更多。因为测试人员做接口测试会有更多选择,例如Jmeter、soapUI等。不过,对于开发过程中去调试接口,Postman确实足够的简单方便,而且功能强大。2.下载安装官网地址:https://www.postman.com/下载完成后双击安装吧,安装过程极其简单,无需任何操作3.使用教程这里以百度为例,工具使用简单,填写URL地址即可发送请求,在下方查看响应结果和响应状态码常用方法都有支持请求方法:getpostputdeleteGet、Post、Put与Delete的作用get:请求方法一般是用于数据查询,

  6. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  7. 在VMware16虚拟机安装Ubuntu详细教程 - 2

    在VMware16.2.4安装Ubuntu一、安装VMware1.打开VMwareWorkstationPro官网,点击即可进入。2.进入后向下滑动找到Workstation16ProforWindows,点击立即下载。3.下载完成,文件大小615MB,如下图:4.鼠标右击,以管理员身份运行。5.点击下一步6.勾选条款,点击下一步7.先勾选,再点击下一步8.去掉勾选,点击下一步9.点击下一步10.点击安装11.点击许可证12.在百度上搜索VM16许可证,复制填入,然后点击输入即可,亲测有效。13.点击完成14.重启系统,点击是15.双击VMwareWorkstationPro图标,进入虚拟机主

  8. 微信小程序开发入门与实战(Behaviors使用) - 2

    @作者:SYFStrive @博客首页:HomePage📜:微信小程序📌:个人社区(欢迎大佬们加入)👉:社区链接🔗📌:觉得文章不错可以点点关注👉:专栏连接🔗💃:感谢支持,学累了可以先看小段由小胖给大家带来的街舞👉微信小程序(🔥)目录自定义组件-behaviors    1、什么是behaviors    2、behaviors的工作方式    3、创建behavior    4、导入并使用behavior    5、behavior中所有可用的节点    6、同名字段的覆盖和组合规则总结最后自定义组件-behaviors    1、什么是behaviorsbehaviors是小程序中,用于实现

  9. 阿里云RDS——产品系列概述 - 2

    基础版云数据库RDS的产品系列包括基础版、高可用版、集群版、三节点企业版,本文介绍基础版实例的相关信息。RDS基础版实例也称为单机版实例,只有单个数据库节点,计算与存储分离,性价比超高。说明RDS基础版实例只有一个数据库节点,没有备节点作为热备份,因此当该节点意外宕机或者执行重启实例、变更配置、版本升级等任务时,会出现较长时间的不可用。如果业务对数据库的可用性要求较高,不建议使用基础版实例,可选择其他系列(如高可用版),部分基础版实例也支持升级为高可用版。基础版与高可用版的对比拓扑图如下所示。优势 性能由于不提供备节点,主节点不会因为实时的数据库复制而产生额外的性能开销,因此基础版的性能相对于

  10. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

随机推荐