草庐IT

Transformer网络-Self-attention is all your need

LeonYi 2023-06-22 原文

一、Transformer

Transformer最开始用于机器翻译任务,其架构是seq2seq的编码器解码器架构。其核心是自注意力机制: 每个输入都可以看到全局信息,从而缓解RNN的长期依赖问题。
输入: (待学习的)输入词嵌入 + 位置编码(相对位置)
编码器结构: 6层编码器: 一层编码器 = 多头注意力+残差(LN) + FFN+残差(LN)
输出:每一个位置上输出预测概率分布(K类类别分布)

1.1 自注意力

分解式

 

 

 

 

缩放内积注意力
1. 自注意力的优势
         a. 计算开销,计算可并行 (嵌入维度d,序列长度n,计算复杂度O(n^2d))
         b. 建模长期依赖 (稳定训练过程)
2. 自注意力缩放(内积过大,softmax饱和)
We suspect that for large values d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.To counteract this effect, we scale the dot products by sqrt(d_k)
如上为原文。作者怀疑,如果Q和K的维度特别大,会使得内积后的值也大。从而使softmax进入梯度极小的区域(类似于sigmoid的饱和区域)。 这样容易导致梯度消失。
所以,他们将内积值除以sqrt(d_k),进行一个缩放,而又不破坏相对比例。
 
多头注意力机制(multi-head attention)
Transformer 提出多头注意力机制(不同头结果拼起来,再做线性变换),增强了 attention 层的能力(参数量不变)。解释:
  1. 它扩展了模型关注不同位置的能力。不同注意力头,关注不同的位置。长距离依赖
  2. 多头注意力机制赋予 attention 层多个“子表示空间(训练之后,每组注意力可以看作是把输入的向量映射到一个”子表示空间“)
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
参数说明如下:
  • embed_dim:最终输出的 K、Q、V 矩阵的维度,这个维度需要和词向量的维度一样
  • num_heads:设置多头注意力的数量。如果设置为 1,那么只使用一组注意力。如果设置为其他数值,那么 num_heads 的值需要能够被 embed_dim 整除
  • dropout:这个 dropout 加在 attention score 后面
定义 MultiheadAttention 的对象后,调用时传入的参数如下。
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
  • query:对应于 Query 矩阵,形状是 (L,N,E) 。其中 L 是输出序列长度,N 是 batch size,E 是词向量的维度
  • key:对应于 Key 矩阵,形状是 (S,N,E) 。其中 S 是输入序列长度,N 是 batch size,E 是词向量的维度
  • value:对应于 Value 矩阵,形状是 (S,N,E) 。其中 S 是输入序列长度,N 是 batch size,E 是词向量的维度
  • key_padding_mask:如果提供了这个参数,那么计算 attention score 时,忽略 Key 矩阵中某些 padding 元素,不参与计算 attention(序列长度不同)。形状是 (N,S)。其中 N 是 batch size,S 是输入序列长度。
    • 如果 key_padding_mask 是 ByteTensor,那么非 0 元素对应的位置会被忽略
    • 如果 key_padding_mask 是 BoolTensor,那么 True 对应的位置会被忽略
  • attn_mask:计算输出时,忽略某些位置。形状可以是 2D (L,S),或者 3D (N∗numheads,L,S)。其中 L 是输出序列长度,S 是输入序列长度,N 是 batch size。
    • 如果 attn_mask 是 ByteTensor,那么非 0 元素对应的位置会被忽略
    • 如果 attn_mask 是 BoolTensor,那么 True 对应的位置会被忽略 
在实际中,K、V 矩阵的长度一样,而 Q 矩阵的序列长度可不一样。这种情况发生在:在解码器部分的encoder-decoder attention层中,Q 矩阵是来自解码器下层,而 K、V 矩阵则是来自编码器的输出。
 

2. Encoder 和 Decoder

 
编码器就是编码器层(多头注意力+(残差+LN),FFN+(残差+LN))的堆叠。
 
解码器
Self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position.We need to prevent leftward information flow in the decoder to preserve the auto-regressive property.We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections.
      为了保持自回归的性质,要保持从左往右的顺序 (这种情况下,不能利用要预测的未来来推断过去)。  这里将当前token以后的进行mask (即将注意力得分加上-inf,将其变成无穷小,使其注意力系数极小接近于无) [exp(-inf) = 0]
      GAT也是这样做的,只不过mask的是非邻居结点 (避免信息泄露,从而让模型学不好)。
 
避免信息泄露,在解码器中使用mask:
# 把 mask 不为空,那么就把 mask 为 0 的位置的 attention 分数设置为 -1e10(系数无穷小)
attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
if mask is not None:
  attention = attention.masked_fill(mask == 0, -1e10)
  attention = self.do(torch.softmax(attention, dim=-1))
  x = torch.matmul(attention, V)
 
交叉注意力层(Decoder-Encoder Attention, Decoder到Encoder输出的潜表示)
使用前一层的输出来构造 Query 矩阵,而 Key 矩阵和 Value 矩阵来自于编码器最终的输出(seq2seq都是这样的,预测当前输出时,不仅看之前的输出,同时也对输入隐状态进行关注)

 

 

 
预测
每一个位置有一个分类损失;总的损失就是每个位置损失之和。
 
训练
让我们假设输出词汇表只包含 6 个单词(“a”, “am”, “i”, “thanks”, “student”, and “”(“”表示句子末尾))。
 
这种架构本就可以用来做语言模型,只不过这里做了seq2seq的翻译。
如果训练数据中本身就有很多句子对,就可以直接通过语言模型实现翻译,例如GPT架构。
 
学习笔记,配图参考知乎-张贤同学、李宏毅机器学习。

有关Transformer网络-Self-attention is all your need的更多相关文章

  1. ruby - 用 Ruby 编写一个简单的网络服务器 - 2

    我想在Ruby中创建一个用于开发目的的极其简单的Web服务器(不,不想使用现成的解决方案)。代码如下:#!/usr/bin/rubyrequire'socket'server=TCPServer.new('127.0.0.1',8080)whileconnection=server.acceptheaders=[]length=0whileline=connection.getsheaders想法是从命令行运行这个脚本,提供另一个脚本,它将在其标准输入上获取请求,并在其标准输出上返回完整的响应。到目前为止一切顺利,但事实证明这真的很脆弱,因为它在第二个请求上中断并出现错误:/usr/b

  2. 网络编程套接字 - 2

    网络编程套接字网络编程基础知识理解源`IP`地址和目的`IP`地址理解源MAC地址和目的MAC地址认识端口号理解端口号和进程ID理解源端口号和目的端口号认识`TCP`协议认识`UDP`协议网络字节序socket编程接口`sockaddr``UDP`网络程序服务器端代码逻辑:需要用到的接口服务器端代码`udp`客户端代码逻辑`udp`客户端代码`TCP`网络程序服务器代码逻辑多个版本服务器单进程版本多进程版本多线程版本线程池版本服务器端代码客户端代码逻辑客户端代码TCP协议通讯流程TCP协议的客户端/服务器程序流程三次握手(建立连接)数据传输四次挥手(断开连接)TCP和UDP对比网络编程基础知识

  3. TimeSformer:抛弃CNN的Transformer视频理解框架 - 2

    Transformers开始在视频识别领域的“猪突猛进”,各种改进和魔改层出不穷。由此作者将开启VideoTransformer系列的讲解,本篇主要介绍了FBAI团队的TimeSformer,这也是第一篇使用纯Transformer结构在视频识别上的文章。如果觉得有用,就请点赞、收藏、关注!paper:https://arxiv.org/abs/2102.05095code(offical):https://github.com/facebookresearch/TimeSformeraccept:ICML2021author:FacebookAI一、前言Transformers(VIT)在图

  4. ruby-on-rails - 如何使用 ruby​​ 从 self 方法调用另一个方法? - 2

    #app/models/product.rbclassProduct我从Controller调用方法1。当我运行程序时。我收到一个错误:method_missing(atlinemethod2(param2)).rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/activerecord-5.0.0/lib/active_record/relation/batches.rb:59:in`block(2levels)infind_each... 最佳答案 classProduct说明:第一个是类

  5. ruby - 在参数为 `yield self` 的方法中使用 `&block` 和在没有参数 `yield self` 的方法中使用 `&block` 有什么区别吗? - 2

    我明白了defa(&block)block.call(self)end和defa()yieldselfend导致相同的结果,如果我假设有这样一个blocka{}。我的问题是-因为我偶然发现了一些这样的代码,它是否有任何区别或者是否有任何优势(如果我不使用变量/引用block):defa(&block)yieldselfend这是一个我不理解&block用法的具体案例:defrule(code,name,&block)@rules=[]if@rules.nil?@rules 最佳答案 我能想到的唯一优点就是自省(introspecti

  6. ruby - 从另一个私有(private)方法中使用 self.xxx() 调用私有(private)方法 xxx,导致错误 "private method ` xxx' called” - 2

    我正在尝试获得良好的Ruby编码风格。为防止意外调用具有相同名称的局部变量,我总是在适当的地方使用self.。但是现在我偶然发现了这个:classMyClass上面的代码导致错误privatemethodsanitize_namecalled但是当删除self.并仅使用sanitize_name时,它会起作用。这是为什么? 最佳答案 发生这种情况是因为无法使用显式接收器调用私有(private)方法,并且说self.sanitize_name是显式指定应该接收sanitize_name的对象(self),而不是依赖于隐式接收器(也是

  7. ruby-on-rails - self 在 Rails 模型中的值(value)是什么?为什么没有明显的实例方法可用? - 2

    我的rails3.1.6应用程序中有一个自定义访问器方法,它为一个属性分配一个值,即使该值不存在。my_attr属性是一个序列化的哈希,除非为空白,否则应与给定值合并指定了值,在这种情况下,它将当前值设置为空值。(添加了检查以确保值是它们应该的值,但为简洁起见被删除,因为它们不是我的问题的一部分。)我的setter定义为:defmy_attr=(new_val)cur_val=read_attribute(:my_attr)#storecurrentvalue#makesureweareworkingwithahash,andresetvalueifablankvalueisgiven

  8. ruby - 检查网络文件是否存在,而不下载它? - 2

    是否可以在不实际下载文件的情况下检查文件是否存在?我有这么大的(~40mb)文件,例如:http://mirrors.sohu.com/mysql/MySQL-6.0/MySQL-6.0.11-0.glibc23.src.rpm这与ruby​​不严格相关,但如果发件人可以设置内容长度就好了。RestClient.get"http://mirrors.sohu.com/mysql/MySQL-6.0/MySQL-6.0.11-0.glibc23.src.rpm",headers:{"Content-Length"=>100} 最佳答案

  9. ruby - 404 未找到,但可以从网络浏览器正常访问 - 2

    我在这方面尝试了很多URL,在我遇到这个特定的之前,它们似乎都很好:require'rubygems'require'nokogiri'require'open-uri'doc=Nokogiri::HTML(open("http://www.moxyst.com/fashion/men-clothing/underwear.html"))putsdoc这是结果:/Users/macbookair/.rvm/rubies/ruby-2.0.0-p481/lib/ruby/2.0.0/open-uri.rb:353:in`open_http':404NotFound(OpenURI::HT

  10. 深度学习12. CNN经典网络 VGG16 - 2

    深度学习12.CNN经典网络VGG16一、简介1.VGG来源2.VGG分类3.不同模型的参数数量4.3x3卷积核的好处5.关于学习率调度6.批归一化二、VGG16层分析1.层划分2.参数展开过程图解3.参数传递示例4.VGG16各层参数数量三、代码分析1.VGG16模型定义2.训练3.测试一、简介1.VGG来源VGG(VisualGeometryGroup)是一个视觉几何组在2014年提出的深度卷积神经网络架构。VGG在2014年ImageNet图像分类竞赛亚军,定位竞赛冠军;VGG网络采用连续的小卷积核(3x3)和池化层构建深度神经网络,网络深度可以达到16层或19层,其中VGG16和VGG

随机推荐