草庐IT

【深度学习】7-矩阵乘法运算的反向传播求梯度

清风莫追 2023-07-11 原文

🚩 前言

本节以较简单的例子来理解矩阵乘法下的反向传播过程。为了稍微形象一些,这里同样会用到计算图来进行描述。

矩阵乘法下的反向传播,其实和标量计算下的反向传播区别不大,只是我们的研究对象从标量变成了矩阵。我们需要解决的就是矩阵乘法运算下求梯度的问题,而两个矩阵的乘法又可以分解为许多标量的运算。


文章目录


1. 求梯度的公式

在矩阵乘法的情况下,设有一个特征矩阵为 X X X,一个权值矩阵为 W W W,输出: Y = X W Y = XW Y=XW
如果我们要得到 Y Y Y关于 W W W的梯度,则可以使用公式: d W = X ⊤ d Y dW=X ^\top dY dW=XdY
同样的,如果求 Y Y Y关于 X X X的梯度,则可以使用公式: d X = d Y W ⊤ dX=dYW^\top dX=dYW

那么,为什么上面的公式确实可以求出我们所需要的梯度呢?

2. “举个栗子”:两个矩阵相乘

我们不妨看看两个简单矩阵相乘的过程,并将目光聚焦到求关于 W W W的梯度

求关于 W W W的梯度,则我们得到的 d W dW dW的形状应当是与 W W W相同的,即每个元素都有一个对应的梯度。我们看和 W 11 W_{11} W11有关的部分:

y 11 = X 11 W 11 + X 12 W 21 y_{11}=X_{11}W_{11}+X_{12}W_{21} y11=X11W11+X12W21
y 21 = X 21 W 11 + X 22 W 21 y_{21}=X_{21}W_{11}+X_{22}W_{21} y21=X21W11+X22W21
y 31 = X 31 W 11 + X 32 W 21 y_{31}=X_{31}W_{11}+X_{32}W_{21} y31=X31W11+X32W21

不难发现, W 11 W_{11} W11的系数有三个,那么 W 11 W_{11} W11的梯度就是这三个系数的和: X 11 + X 21 + X 31 X_{11}+X_{21}+X_{31} X11+X21+X31

  • 对应的系数作为梯度很好理解,可为什么是呢?而不是平均数?又或者其它的?
    我现在也没有很明白,求得的梯度为什么是它所有系数的和值,主要是对这个梯度值所代表的意义有些困惑。不过平均数其实没有什么意义,不过是给所有求得的梯度等比缩小了而已。

相应的, W W W第一行的元素,其梯度都是 X X X第一列的和;第二行的元素,其梯度都是 X X X第二列的和。
于是可以发现,通过公式 d W = X ⊤ d Y dW=X ^\top dY dW=XdY,如果 d Y dY dY的元素值都为1,我们就恰巧能得到上面的结果。

  • 在实际的模型中,矩阵乘法的运算只是作为很小的一个部分, d Y dY dY的值接受自下一层,而非简单的全为 1 1 1,因此不必担心出现每一行的权值只能同步更新的问题

3. 从计算图看:误差反向传播

前面我们是从表达式的系数得出的规律,接下来再从计算图来看一下反向传播求梯度的过程。

  • 在考虑神经网络中的误差的反向传播时,计算图确实是一个很棒的工具。对于复杂的矩阵乘法运算,我们可以把它分解成许多简单的加法和乘法运算来考虑。

求W11有关的部分计算图——正向推理

误差反向传播

这里我们得到: d W 11 = X 11 d y 11 + X 21 d y 21 + X 31 d y 31 dW_{11}=X_{11}dy_{11}+X_{21}dy_{21}+X_{31}dy_{31} dW11=X11dy11+X21dy21+X31dy31

这里只画出了举例子所需要的小部分计算图,将一个矩阵乘法运算完整地用计算图呈现出来,会显得比较错综复杂,也比较麻烦。但使用部分计算图来以点带面、帮助理解还是非常不错的。


感谢阅读

有关【深度学习】7-矩阵乘法运算的反向传播求梯度的更多相关文章

  1. ruby - 触发器 ruby​​ 中 3 点范围运算符和 2 点范围运算符的区别 - 2

    请帮助我理解范围运算符...和..之间的区别,作为Ruby中使用的“触发器”。这是PragmaticProgrammersguidetoRuby中的一个示例:a=(11..20).collect{|i|(i%4==0)..(i%3==0)?i:nil}返回:[nil,12,nil,nil,nil,16,17,18,nil,20]还有:a=(11..20).collect{|i|(i%4==0)...(i%3==0)?i:nil}返回:[nil,12,13,14,15,16,17,18,nil,20] 最佳答案 触发器(又名f/f)是

  2. 旋转矩阵的几何意义 - 2

    点向量坐标矩阵的几何意义介绍旋转矩阵的几何含义之前,先介绍一下点向量坐标矩阵的几何含义点:在一维空间下就是一个标量,如同一条直线上,以任意某一个位置为0点,以一定的尺度间隔为1,2,3...,相反方向为-1,-2,-3...;如此就形成了一维坐标系,这时候任何一个点都可以用一个数值表示,如点p1=5,即即从原点出发沿着x轴正方向移动5个尺度;点p2=-3,负方向移动3个尺度;     在一维坐标系上过原点做垂直于一维坐标系的直线,则形成了二维坐标系,此时描述一个点需要两个数值来表示点p3=(3,2),即从原点出发沿着x轴正方向移动3个尺度,在此基础上沿着y轴正方向移动两个尺度的位置就是点p3。

  3. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  4. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  5. MIMO-OFDM无线通信技术及MATLAB实现(1)无线信道:传播和衰落 - 2

     MIMO技术的优缺点优点通过下面三个增益来总体概括:阵列增益。阵列增益是指由于接收机通过对接收信号的相干合并而活得的平均SNR的提高。在发射机不知道信道信息的情况下,MIMO系统可以获得的阵列增益与接收天线数成正比复用增益。在采用空间复用方案的MIMO系统中,可以获得复用增益,即信道容量成倍增加。信道容量的增加与min(Nt,Nr)成正比分集增益。在采用空间分集方案的MIMO系统中,可以获得分集增益,即可靠性性能的改善。分集增益用独立衰落支路数来描述,即分集指数。在使用了空时编码的MIMO系统中,由于接收天线或发射天线之间的间距较远,可认为它们各自的大尺度衰落是相互独立的,因此分布式MIMO

  6. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  7. ruby - ruby 乘法语句中星号中断语法前的空格 - 2

    在添加一些空格以使代码更具可读性时(与上面的代码对齐),我遇到了这个:classCdefx42endendm=C.new现在这将给出“错误数量的参数”:m.x*m.x这将给出“语法错误,意外的tSTAR,期待$end”:2/m.x*m.x这里的解析器到底发生了什么?我使用Ruby1.9.2和2.1.5进行了测试。 最佳答案 *用于运算符(42*42)和参数解包(myfun*[42,42])。当你这样做时:m.x*m.x2/m.x*m.xRuby将此解释为参数解包,而不是*运算符(即乘法)。如果您不熟悉它,参数解包(有时也称为“spl

  8. ruby - 带括号和 splat 运算符的并行赋值 - 2

    我明白了:x,(y,z)=1,*[2,3]x#=>1y#=>2z#=>nil我想知道为什么z的值为nil。 最佳答案 x,(y,z)=1,*[2,3]右侧的splat*是内联扩展的,所以它等同于:x,(y,z)=1,2,3左边带括号的列表被视为嵌套赋值,所以它等价于:x=1y,z=23被丢弃,而z被分配给nil。 关于ruby-带括号和splat运算符的并行赋值,我们在StackOverflow上找到一个类似的问题: https://stackoverflow

  9. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  10. ruby-on-rails - 浮点乘法的 Ruby 奇怪问题 - 2

    有没有人用ruby​​解决这个问题:假设我们有:a=8.1999999我们想将它四舍五入为2位小数,即8.20,然后乘以1,000,000得到8,200,000我们是这样做的;(a.round(2)*1000000).to_i但是我们得到的是8199999,为什么?奇怪的是,如果我们乘以1000、100000或10000000而不是1000000,我们会得到正确的结果。有人知道为什么吗?我们正在使用ruby​​1.9.2并尝试使用1.9.3。谢谢! 最佳答案 每当你在计算中得到时髦的数字时使用bigdecimalrequire'bi

随机推荐