草庐IT

有监督学习——梯度下降

zh-jp 2023-03-28 原文

1. 梯度下降

梯度下降(Gradient Descent)是计算机计算能力有限的条件下启用的逐步逼近、迭代求解方法,在理论上不保证下降求得最优解。

e.g. 假设有三维曲面表达函数空间,长(x)、宽(y)轴为子变量,高(z)是因变量,若使用梯度下降法求解因变量最低点的步骤如下:

  1. 任取一点作为起始点。
  2. 查看当前点向哪个方向移动得到最小的z值,并向该方向移动。
  3. 重复上述步骤,直到无法找到更小的z值,此时认为达到最低点。

受起始点和目标函数的约束,有时该法无法找到全局最优点,但有着比OLS更快的求解速度,因此被广泛应用。

根据原理介绍几个梯度下降求解算法概念:

  • 步长(learning rate):每一步梯度下降时向目标方向前行的长度。
  • 假设函数(hypothesis function):由特征产生目标变量的函数,常用\(h()\)表示。
  • 损失函数(loss function):评估任意参数组合的函数,常用\(J()\)表示。

损失函数判断向周围哪个方向移动的原理是计算损失函数的偏导数向量该向量就是损失函数增长最快的方向,而其反方向则是以最小化损失函数为目标时需要前进的方向。

2. 随机梯度下降

随机梯度下降(Stochastic Gradient Descent, SGD),在损失函数计算时不便利所有样本,只采用单一或小批量样本的方差和作为损失值。因此,每次迭代计算速度非常快,通过每次随机选用不同的样本进行迭代达到对整体数据的拟合。

对比普通梯度下降,随机梯度下降的主要区别在于:

  • 迭代次数明显增加,但由于每次计算样本少,总体时间缩短。
  • 由于样本数据存在噪声,每次迭代方向不一定是“正确的”,但由于迭代次数的增加,总体的移动期望任朝着正确方向前进。
  • 能因为“不一定正确”的方向越过高点,从而找到最优解。

3. Python中的SGDRegression和SGDClassifier

scikit-learn中提供了随机梯度下降的线性回归器SGDRegressor和线性分类器SGDClassifier,使用它们可学习超大规模样本(样本数>\(10^5\)且特征维度>\(10^5\))。

Python中使用两者

from sklearn.linear_model import SGDRegressor, SGDClassifier
X = [[0, 0], [2, 1], [5, 4]]    # 样本特征
y = [0, 2, 2]                   # 样本目标分类
reg = SGDRegressor(penalty='l2', max_iter=10000)
reg.fit(X, y)
reg.predict([[4,3]])
# array([1.85046249])

reg.coef_       # 查看回归参数
# array([0.30716325, 0.16212611])

reg.intercept_  # 查看截距
# array([0.13543114])

clf = SGDClassifier(penalty='l2', max_iter=100) # 初始化分类器
clf.fit(X, y)
clf.predict([[4, 3]])   # 预测
# array([2])

两者最大的不同在于predict()函数的预测结果,SGDClassifier预测的结果一定是训练数据的目标值之一,SGDRegressor预测值是假设函数直接的计算结果。

而两者的对象初始化参数类似:

Attribute Introduce
penalty 损失函数惩罚项,取值none l1 l2elasticnet,"elasticnet"是"l1"和"l2"的综合
loss 损失函数类型,影响训练速度,取值squared_loss huber epsilon_insensitivesquared_epsilon_insensitive
tol 损失函数变化小于tol时认为获得最优解
max_iter 最大迭代次数,当迭代陷入抖动,无法满足tol时只能利用max_iter作为停止迭代条件
shuffle 完成一轮所有样本迭代后是否洗牌
n_jobs 训练中可利用的CPU数量
learning_rate 步长类型,取值constant optimalinvscaling,前者为固定步长。后两者为动态步长有利于在训练初期跳出局部解,同时后期避免抖动。
eta0 learning_rate 为 constant 或 invscaling 时的初始步长
fit_intercept 是否有截距,取值TrueFalse

3. 增量学习

增量学习(Incremental Learning)是指一种可以边读数据边训练的拟合方法。

在scikit-Learn中提供了partial_fit()函数接口,所有支持增量学习的模型都实现了该函数。SGD的增量学习调用方法举例:

from random import randint
import numpy as np
reg2 = SGDRegressor(loss="squared_error", penalty="l1", tol=1e-15)
X = np.linspace(0, 1, 50)  # 50个x值
Y = X/2 + 0.3 + np.random.normal(0, 0.15, len(X))   # 用y=x/2+0.3加随机数生成样本
X = X.reshape(-1, 1)

for i in range(10000):
    idx = randint(0, len(Y)-1)  # 随机选择一个样本索引
    reg2.partial_fit(X[idx: idx+10], Y[idx: idx+10])  # 用partial_fit()训练

print(reg2.coef_)   # 查看回归参数
# [0.56874507]
print(reg2.intercept_)  # 查看截距
# [0.2769033]

查看模型参数,当前模型应为:

\[y=0.56874507x+0.2769033 \]

与生成样本时的公式相近。

参考文献

[1]刘长龙. 从机器学习到深度学习[M]. 1. 电子工业出版社, 2019.3.

有关有监督学习——梯度下降的更多相关文章

  1. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  2. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  3. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  4. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  5. ruby - 我如何学习 ruby​​ 的正则表达式? - 2

    如何学习ruby​​的正则表达式?(对于假人) 最佳答案 http://www.rubular.com/在Ruby中使用正则表达式时是一个很棒的工具,因为它可以立即将结果可视化。 关于ruby-我如何学习ruby​​的正则表达式?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/1881231/

  6. 深度学习12. CNN经典网络 VGG16 - 2

    深度学习12.CNN经典网络VGG16一、简介1.VGG来源2.VGG分类3.不同模型的参数数量4.3x3卷积核的好处5.关于学习率调度6.批归一化二、VGG16层分析1.层划分2.参数展开过程图解3.参数传递示例4.VGG16各层参数数量三、代码分析1.VGG16模型定义2.训练3.测试一、简介1.VGG来源VGG(VisualGeometryGroup)是一个视觉几何组在2014年提出的深度卷积神经网络架构。VGG在2014年ImageNet图像分类竞赛亚军,定位竞赛冠军;VGG网络采用连续的小卷积核(3x3)和池化层构建深度神经网络,网络深度可以达到16层或19层,其中VGG16和VGG

  7. 映宇宙2022年营收63亿元:同比下降三成,毛利率提升4.3个百分点 - 2

    3月26日,映宇宙(HK:03700,即“映客”)发布截至2022年12月31日的2022年度业绩财务报告。财报显示,映宇宙2022年的总营收为63.19亿元,较2021年同期的91.76亿元下降31.1%。2022年,映宇宙的经营亏损为4698.7万元,2021年同期则为净利润4.57亿元;期内亏损(净亏损)为1.68亿元,2021年同期的净利润为4.33亿元;非国际财务报告准则经调整净利润为3.88亿元,2021年同期为4.82亿元,同比下降19.6%。 映宇宙在财报中表示,收入减少主要是由于行业竞争加剧,该集团对旗下产品采取更为谨慎的运营策略以应对市场变化。不过,映宇宙的毛利率则有所提升

  8. 机器学习——时间序列ARIMA模型(四):自相关函数ACF和偏自相关函数PACF用于判断ARIMA模型中p、q参数取值 - 2

    文章目录1、自相关函数ACF2、偏自相关函数PACF3、ARIMA(p,d,q)的阶数判断4、代码实现1、引入所需依赖2、数据读取与处理3、一阶差分与绘图4、ACF5、PACF1、自相关函数ACF自相关函数反映了同一序列在不同时序的取值之间的相关性。公式:ACF(k)=ρk=Cov(yt,yt−k)Var(yt)ACF(k)=\rho_{k}=\frac{Cov(y_{t},y_{t-k})}{Var(y_{t})}ACF(k)=ρk​=Var(yt​)Cov(yt​,yt−k​)​其中分子用于求协方差矩阵,分母用于计算样本方差。求出的ACF值为[-1,1]。但对于一个平稳的AR模型,求出其滞

  9. Unity Shader 学习笔记(5)Shader变体、Shader属性定义技巧、自定义材质面板 - 2

    写在之前Shader变体、Shader属性定义技巧、自定义材质面板,这三个知识点任何一个单拿出来都是一套知识体系,不能一概而论,本文章目的在于将学习和实际工作中遇见的问题进行总结,类似于网络笔记之用,方便后续回顾查看,如有以偏概全、不祥不尽之处,还望海涵。1、Shader变体先看一段代码......Properties{ [KeywordEnum(on,off)]USL_USE_COL("IsUseColorMixTex?",int)=0 [Toggle(IS_RED_ON)]_IsRed("IsRed?",int)=0}......//中间省略,后续会有完整代码 #pragmamulti_c

  10. ruby-on-rails - 这个 C 和 PHP 程序员如何学习 Ruby 和 Rails? - 2

    按照目前的情况,这个问题不适合我们的问答形式。我们希望答案得到事实、引用或专业知识的支持,但这个问题可能会引发辩论、争论、投票或扩展讨论。如果您觉得这个问题可以改进并可能重新打开,visitthehelpcenter指导。关闭9年前。我来自C、php和bash背景,很容易学习,因为它们都有相同的C结构,我可以将其与我已经知道的联系起来。然后2年前我学了Python并且学得很好,Python对我来说比Ruby更容易学。然后从去年开始,我一直在尝试学习Ruby,然后是Rails,我承认,直到现在我还是学不会,讽刺的是那些打着简单易学的烙印,但是对于我这样一个老练的程序员来说,我只是无法将它

随机推荐