草庐IT

1. 梯度下降法

YL-Wang 2023-03-28 原文

1. 简介

梯度下降法是一种函数极值的优化算法。在机器学习中,主要用于寻找最小化损失函数的的最优解。是算法更新模型参数的常用的方法之一。

2. 相关概念

1. 导数

  • 定义

设一元函数\(f(x)\)\(x_0\)的临域内有定义,若极限

\[f^{`}(x_0)=\lim_{\Delta x\to0}\frac{f(x+\Delta x)-f(x)}{\Delta } \]

存在,则称\(f^{`}(x_0)\)\(f(x)\)\(x=x_0\)处的导数。

  • 意义
    1. 导数的绝对值大小代表了当前函数的在该处的变化速度
    2. 导数的正负代表了在一定临域内随着自变量\(x\)的增加,函数值是增大还是减小

2. 偏导数

  • 定义

对于多元函数\(f(x),x \in R^p\)\(f(x)\)在对\(x_i\)的偏导数定义为

\[\frac{\partial f(x)}{\partial x_i}=\lim_{\Delta x \to 0}\frac{f(x_1,x_2,\cdots,x_i+\Delta x,\cdots,x_p)-f(x_1,x_2,\cdots,x_i,\cdots,x_p)}{\Delta x} \]

  • 意义

偏导数定义了多元函数在某个数轴方向上的变化情况。

3. 方向导数

  • 定义

函数的偏导数定义了在各个数轴上的变化率,方向导数则为函数在任意方向上的变化率。以二元函数\(f(x,y)\)为例:

\[\nabla \frac{\partial f(x)}{\partial l}|_{(x_0,y_0)}=\frac{\partial f(x)}{\partial x}\cos(\alpha)+\frac{\partial f(x)}{\partial y}\cos(\beta) \]

  • 意义

多元函数在某点处的方向导数有无数个,每一个方向导数的值代表了在该方向上的变化程度,我们要寻找在某点处函数变化最快的方向就可以转化成寻找在该点处方向导数的绝对值最大时对应的那个方向

4. 梯度

  • 定义

梯度是一个矢量,表示函数沿着该方向的变化率最大,记为

\[f(x)=(\frac{\partial f(x)}{\partial x_1},\frac{\partial f(x)}{\partial x_2},\cdots,\frac{\partial f(x)}{\partial x_p})^T \]

  • 为什么该方向为变化最快的方向

根据方向导数定义,

\[\begin{align*} \frac{\partial f(x)}{\partial l}|_{(x_0,y_0)} &=\frac{\partial f(x)}{\partial x}\cos(\alpha)+\frac{\partial f(x)}{\partial y}\cos(\beta) \\ &= (\frac{\partial f(x)}{\partial x},\frac{\partial f(x)}{\partial y})(\cos(\alpha),\cos(\beta))^T \\ &= A\cdot I \quad\quad (A=(\frac{\partial f(x)}{\partial x},\frac{\partial f(x)}{\partial y}),I=(\cos(\alpha),\cos(\beta))^T ) \\ &= ||A||\times||I||\cos(\theta) \qquad (\theta为两个向量的夹角) \end{align*} \]

当且仅当 \(\theta=0\),即\(A\)\(I\)通向时,方向导数取得最大值,因此梯度表示变化率最大的方向,此时方向导数为正。因此梯度指向函数增大的方向。

3 原理详解

假设在一个类是凹函数的山中放一个小球,让它自然的滚动到山谷(最小值点)处,那么小球滚动每个地点滚动的方向都是梯度的负方向。

现在有一个凹函数,要找到它的最小值,在不考虑解析解的情况下,也可以利用类似的方法去求解。先随机找一个初始点\(x_0\),然后求出该点的梯度,利用公式\(x_1=x_0-lr*\nabla f(x)\)模拟小球的滚动,其中\(lr\)为滚动的步长,也称为学习率

通过迭代公式 \(x_n=x_{n-1}-lr* \nabla f(x)\)一步步去逼近函数的极小值点。通常迭代的结束条件有:

  • 指定迭代次数
  • 计算迭代前后函数值的差距,若在一个非常小的阈值以为就可以认为已经找到最小值

4. 代码实现

案例 :\(f(x)=(x_1-2)^2+(x_2-3)^2+(x_3-4)^4\)

import numpy as np
#定义函数
def func(x):
    return (x[0]-2)**2+(x[1]-3)**2+(x[2]-4)**2
#定义梯度
def gradFunc(x):
    return np.array([(x[0]-2)*2,(x[1]-3)*2,(x[2]-4)*2])
# 定义梯度下降法
def SGD(init_x,func,gradFunc,lr=0.01,maxIter=100000,error=1e-10):
    x=init_x
    for iter in range(0,maxIter):
        gd=gradFunc(x)
        x_new=x-lr*gd
        if(np.abs(func(x)-func(x_new))<error):
            return x_new
        x=x_new
    return x_new
SGD(np.array([1,1,1]),func,gradFunc) 

array([1.99998703, 2.99997406, 3.99996109])

SGD(np.array([10,10,10]),func,gradFunc)

array([2.00003215, 3.00002813, 4.00002411])

有关1. 梯度下降法的更多相关文章

  1. 映宇宙2022年营收63亿元:同比下降三成,毛利率提升4.3个百分点 - 2

    3月26日,映宇宙(HK:03700,即“映客”)发布截至2022年12月31日的2022年度业绩财务报告。财报显示,映宇宙2022年的总营收为63.19亿元,较2021年同期的91.76亿元下降31.1%。2022年,映宇宙的经营亏损为4698.7万元,2021年同期则为净利润4.57亿元;期内亏损(净亏损)为1.68亿元,2021年同期的净利润为4.33亿元;非国际财务报告准则经调整净利润为3.88亿元,2021年同期为4.82亿元,同比下降19.6%。 映宇宙在财报中表示,收入减少主要是由于行业竞争加剧,该集团对旗下产品采取更为谨慎的运营策略以应对市场变化。不过,映宇宙的毛利率则有所提升

  2. javascript - npm 库/框架下载量大幅下降。有人知道为什么吗? - 2

    在查看npmtrends.com时,我注意到几乎每个npm库/框架的下载量在2018年6月初都大幅下降。有人知道这是为什么吗?也许npm宕机了,或者每个人都在暑假休息了? 最佳答案 在更新npm,Inc.方面的计数时似乎出现了问题:We'reinvestigatingaknownissuewithdownloadcountsnotbeingupdatedproperlyinthepastfewdays.Posted[…]Jun04,2018-17:30UTC(引自https://status.npmjs.org/incidents/

  3. javascript - 简单的递归下降解析器? - 2

    我正在为一种编译成JS(如果相关的话)的模板语言编写解析器。我从几个简单的正则表达式开始,它们似乎可以工作,但正则表达式非常脆弱,所以我决定改写一个解析器。我首先编写了一个简单的解析器,它通过压入/弹出堆栈来记住状态,但事情一直在升级,直到我手上有了一个递归下降解析器。不久之后,我比较了我以前所有解析方法的性能。递归下降解析器到目前为止是最慢的。我被卡住了:是否值得为一些简单的事情使用递归下降解析器,或者我是否有理由走捷径?我很想走纯正则表达式路线,它非常快(几乎比RD解析器快3倍),但在某种程度上非常hacky和不可维护。我认为性能不是非常重要,因为编译后的模板被缓存了,但是递归下降

  4. JavaScript 引用下降 - 2

    我正在创建一个扩展现有应用程序的模块。我收到了一个变量device,我想创建myDevice来始终保存相同的数据。假设数据包含在一个数组中:https://jsfiddle.net/hmkg9q60/2/vardevice={name:"one",data:[1,2,3]};varmyDevice={name:"two",data:[]};myDevice.data=device.data;//Assignarrayreferencedevice.data.push(4);//Pushworksonarrayreferenceconsole.log(device.data);//[1,

  5. javascript - 从具有 O(n) 的数组中获取最大的时间顺​​序下降、最小值和最大值 - 2

    我编写了一个javascript函数来分析数组中最大的落差。但是还有一个小问题。作为最大值,我总是从我的孔阵列而不是我的下降中获得最大值。例子:数组:[100,90,80,120]最大下降值在100到80之间。因此最大值必须为100,最小值必须为80。我的函数总是返回整个数组中的最大值。在我的例子中是120functioncheckData(data){letmax=0letmin=0letdrop=0for(leti=0;i我想从左到右获得按时间顺序正确的最大增量 最佳答案 您的循环应该跟踪当前的下降并将其与之前最大的下降进行比较

  6. Javascript递归函数性能下降 - 2

    我在招聘流程技能测试中被问到以下问题:varx=function(z){console.log(z);if(z>0){x(z-1);}};whythisisprogressivelysloweraszgethigher?proposeabetterversion,keepingitrecursive.我想知道答案只是为了了解它。我回答说它变慢了,因为随着z的增加,递归调用的数量也增加了,但我无法提供更好的版本。另外,我不知道是否还有其他原因导致函数随着z变高而变慢。 最佳答案 正确的答案应该是,“随着z变高,它应该不逐渐变慢”。事实

  7. optimization - 函数调用导致性能下降 - 2

    对于以下函数:funcCycleClock(c*ballclock.Clock)int{fori:=0;i其中c.BallQueue定义为[]int,CalculateBallCycle定义为funcCalculateBallCycle(s[]int)整数。for循环和return语句之间的性能大幅下降。我写了以下基准测试。第一个基准测试整个函数,第二个基准测试for循环,而第三个基准测试CalculateBallCycle函数:funcBenchmarkCycleClock(b*testing.B){fori:=ballclock.MinBalls;i使用123个球,得到以下结果:B

  8. 随着请求数量的增加,Go 网络服务器性能急剧下降 - 2

    我正在使用wrk对一个用Go编写的简单网络服务器进行基准测试。服务器在具有4GBRAM的机器上运行。在测试开始时,代码服务高达2000个请求/秒,性能非常好。但随着时间的推移,进程使用的内存逐渐增加,一旦达到85%(我正在使用top进行检查),吞吐量就会下降到约100个请求/秒。一旦我重新启动服务器,吞吐量再次增加到最佳数量。性能下降是内存问题吗?为什么Go不释放这段内存?我的Go服务器看起来像这样:funcmain(){deferfunc(){//Waitforallmessagestodrainoutbeforeclosingtheproducerp.Flush(1000)p.Cl

  9. database - KyotoCabinet (TreeDB) 性能严重下降 - 2

    我选择TreeDB作为KyotoCabinet后端,希望它能扩展到巨大的值(value)。不幸的是,有一个问题:#./kyotobenchGeneratedstringlength:10241000records,typet74.008887msthroughput:13511/sec2000records,typet145.390096msthroughput:13756/sec4000records,typet290.13486msthroughput:13786/sec8000records,typet584.46691msthroughput:13687/sec16000rec

  10. android - 使用 HOGDescriptor(适用于 Android 的 OpenCV)对图像进行梯度和角度可视化 - 2

    我尝试可视化由OpenCVLibforAndroid的HOGDescriptor计算的图像的渐变和角度。一开始我有一个3channel图像Mat()和8位无符号整数(CV_8UC3)。计算的结果是梯度的MAT()(CV_32FC2)和角度的Mat()(CV_8UC2)。我如何可视化此结果?什么代表值(value)观?为什么角度Mat()有2个channel?渐变Mat()的2个channel是渐变的x和y分量吗?我找不到computeGradiant-Method的文档。 最佳答案 HOG描述符是定向梯度的直方图:它是一个直方图,其

随机推荐