草庐IT

Pytorch实战笔记(1)——BiLSTM 实现情感分析

野指针小李 2024-05-24 原文

本文展示的是使用 Pytorch 构建一个 BiLSTM 来实现情感分析。本文的架构是第一章详细介绍 BiLSTM,第二章粗略介绍 BiLSTM(就是说如果你想快速上手可以跳过第一章),第三章是核心代码部分。

目录

1. BiLSTM的详细介绍

坦白的说,其实我也不懂 LSTM,但是我这里还是尽我最大的可能解释这个模型。这里我就盗个图 [1](懒得自己画了,而且感觉好像他也是盗的李宏毅老师课件的图)。

简单来说,LSTM 在每个时刻的输入都是由该时刻输入的序列信息 X t X^t Xt 与上一时刻的隐藏状态 h t − 1 h^{t-1} ht1 通过四种不同的非线性变化映射而成,分别为:

  1. 遗忘门控信号:遗忘门控信号 z f z^f zf 的计算公式如下:
    z f = s i g m o i d ( W f [ X t ; h t − 1 ] ) , z^f = {\rm sigmoid}(W^f\left[ X^t; h^{t-1} \right]), zf=sigmoid(Wf[Xt;ht1]),
    其中, [ X t ; h t − 1 ] [X^t;h^{t-1}] [Xt;ht1] 是将 X t X^t Xt h t − 1 h^{t-1} ht1 拼接起来; W f W^f Wf 是权重; S i g m o i d ( ⋅ ) {\rm Sigmoid}(\cdot) Sigmoid() 是 Sigmoid 激活函数,用于将数据映射到 (0, 1) 的区间范围内。
  2. 记忆门控信号:记忆门控信号 z i z^i zi 的计算公式如下:
    z i = s i g m o i d ( W i [ X t ; h t − 1 ] ) . z^i={\rm sigmoid}(W^i\left[ X^t; h^{t-1} \right]). zi=sigmoid(Wi[Xt;ht1]).
  3. 输出门控信号:输出门控信号 z o z^o zo 的计算公式如下:
    z o = s i g m o i d ( W o [ X t ; h t − 1 ] ) . z^o = {\rm sigmoid}(W^o\left[ X^t; h^{t-1} \right]). zo=sigmoid(Wo[Xt;ht1]).
  4. 当前时刻的信息:当前时刻的信息 z z z 的计算公式如下:
    z = t a n h ( W [ X t ; h t − 1 ] ) , z = {\rm tanh}(W\left[ X^t; h^{t-1} \right]), z=tanh(W[Xt;ht1]),
    其中, t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh() 是将数据放缩到 (-1, 1) 的区间内。

通过以上的公式,我们可以发现, z f , z i , z o z^f, z^i, z^o zf,zi,zo 都是 (0, 1) 区间的值,而 z z z(-1, 1) 区间的值。

接着就是 LSTM 的内部计算公式,即图上所示的那几个,分别为:

  1. 当前时刻的细胞状态 c t c^t ct 的计算公式如下:
    c t = z f ⊙ c t − 1 + z i ⊙ z , c^t = z^f \odot c^{t-1} + z^i \odot z, ct=zfct1+ziz,
    其中, ⊙ \odot 是哈达玛积,即矩阵元素对位相乘,但是需要注意的是,哈达玛积数学上不可解释,但是跑出来效果好
  2. 当前时刻的隐藏状态 h t h^t ht 的计算公式如下:
    h t = z o ⊙ t a n h ( c t ) . h^t = z^o \odot {\rm tanh} (c^t). ht=zotanh(ct).
  3. 当前时刻的输出 y t y^t yt 的计算公式如下:
    y t = σ ( W ′ h t ) . y^t = \sigma (W'h^t). yt=σ(Wht).

公式列举完后,这里说一下我对这些公式的理解(不一定是对的哈)。

  • 首先是 c t c^t ct 的计算。我们看到 c t c^t ct 的计算分为了两部分。一部分是 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1,这一部分是 LSTM 的遗忘过程,由于刚刚提到, z f z^f zf 是 (0, 1) 区间范围内的值,同时,sigmoid 函数是一个无限趋近于 0 或者 1 的函数,也就是说, c t − 1 c^{t-1} ct1 无论怎样,都会有些数据被遗弃,始终不会完全保留下来,这也就模拟了一个遗忘的过程。同理,对于记忆部分 z i ⊙ z z^i \odot z ziz,这一步也是只会保留部分 z z z 的信息,也就模拟了人的记忆是由些许失真的过程。同时,两者相加后,那么就代表了当前细胞状态 c t c^t ct 中保留的是没有被遗忘掉的过去的信息和当前时刻被记忆下来的信息
  • 接着是 h t h^t ht 的计算。首先是为什么要先对 c t c^t ct 做一次 t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh(),这是因为由于 c t c^t ct 的区间范围不是 (-1, 1),因为 z i ⊙ z z^i \odot z ziz 的区间范围是 (-1, 1),再与 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1 相加,那么 c t c^t ct 的范围就有可能超出 (-1, 1),所以先用一个 tanh 将数值给放缩到 (-1, 1) 内。接着再与 z o z^o zo 做一次哈达玛积后,得到的隐藏状态就是 (-1, 1) 的数据,那么该数据放到后续模块中,就可以代表当前时刻的输入是正的还是负的,同时有多大。
  • 最后就是 y t y^t yt 的计算,实际上这就是个全连接层,将隐藏状态进行一次映射,再通过一个非线性变化的激活函数。

2. BiLSTM 的简单介绍

当然,其实你没看懂上面的部分也不重要,从使用的角度上来讲,会用就行了,就像你用手机,你不会去搞懂里面每个元器件是怎么做出来的,每个 APP 是怎么写出来的;就像你去打篮球,也不用梳个中分,穿个背带裤。

那么对于 BiLSTM,你需要了解的是什么?

  • 首先,这是一个序列模型,它接受一个序列的输入,并且输出这个序列的信息。对于序列中每个位置的输出,它会包含该位置的信息以及之前的信息。就是说 LSTM 能够捕获到位置 t t t 及其之前位置的信息。而对于 BiLSTM 的话,则能捕获到 t t t 的双向信息。
  • 如果是 BiLSTM,它的每个位置的输出,是前向 L S T M → \overrightarrow{LSTM} LSTM 的输出 y → \overrightarrow{y} y 与反向 L S T M ← \overleftarrow{LSTM} LSTM 的输出 y ← \overleftarrow{y} y 拼接在一起的, [ y → ; y ← ] [\overrightarrow{y}; \overleftarrow{y}] [y ;y ]。所以假设你设置 LSTM 的隐藏层维度为 128,那么单向 LSTM 的输出维度是 128,但是双向就是 256 (128*2).
  • 但是虽然说 LSTM 好像大概可能也许 maybe possibly 能够捕获长距离依赖信息哈,毕竟 LSTM 的全称都是 Long Short-Term Memory,但是实际上这是 LSTM 的骗局,LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力! 从数学上说,你经过这么多次 sigmoid,还能保留个啥?当然,在《An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling》这篇论文[2]中,作者用了大量的实验来说明了,LSTM 不仅并行计算能力差(因为要上一个时间步的信息才能计算下一个时间步,所以 LSTM 不是个并行系统),同时在它最吹嘘的长距离信息捕获能力上,都不如 CNN,所以以后在跑实验的时候,可以尝试使用 TextCNN 来试试,说不定效果比 BiLSTM 好(反正我做过的实验中 TextCNN 性能一般比 BiLSTM 高8-10个点)。

3. BiLSTM 实现情感分析

在本博客中仅介绍模型部分,详细代码见 github。

模型图如图所示:

具体而言,就是输入序列输入到一个双向 LSTM 中,并将双向 LSTM 的最后一个隐藏状态(即句向量)输入到一个全连接层(也可以说是分类器)中,输出最后的分类结果,具体模型的代码如下:

import torch.nn as nn

class BiLSTM_SA(nn.Module):

    def __init__(self, embed, config):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embed, freeze=False)
        self.LSTM = nn.LSTM(config.embed_size, config.lstm_hidden_size,
                            num_layers=config.num_layers, batch_first=True,
                            bidirectional=True)
        # 因为是双向 LSTM, 所以要乘2
        self.ffn = nn.Linear(config.lstm_hidden_size * 2,
                             config.dense_hidden_size)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(config.dense_hidden_size,
                                    config.num_outputs)

    def forward(self, inputs):
        # shape: (batch_size, max_seq_length, embed_size)
        embed = self.embedding(inputs)
        # shape: (batch_size, max_seq_length, lstm_hidden_size * 2)
        lstm_hidden_states, _ = self.LSTM(embed)
        # LSTM 的最后一个时刻的隐藏状态, 即句向量
        # shape: (batch, lstm_hidden_size * 2)
        lstm_hidden_states = lstm_hidden_states[:, -1, :]
        # shape: (batch, dense_hidden_size)
        ffn_outputs = self.relu(self.ffn(lstm_hidden_states))
        # shape: (batch, num_outputs)
        logits = self.classifier(ffn_outputs)

        return logits

全连接层我采用了两个全连接层,一个将维度从 256 压缩到 128,另外一个是分类器。

这里有个小细节要注意一下,通常在论文的公式里面,我们都会看到别人写的分类器的公式如下: y ^ = S o f t m a x ( W h + b ) \hat{y} = {\rm Softmax}(Wh+b) y^=Softmax(Wh+b),有个 softmax 的激活函数,但是在 pytorch 中实际不需要,就比如我代码里面是写的:

logits = self.classifier(ffn_outputs)

而不是:

y_hat = self.softmax(self.classifier(ffn_outputs))

这是因为如果你后面选用交叉熵作为损失函数,而且调用的是torch中的 nn.CrossEntropyLoss(),那么就没必要在输出的时候用 softmax,这是因为 nn.CrossEntropyLoss() 中自带有 softmax 操作,虽然这样对你的分类结果不会产生任何影响,但是你得损失会变得很大。

最后的测试集的实验结果为:

test loss 0.419664 | test accuracy 0.813760 | test precision 0.804267 | test recall 0.829360 | test F1 0.816621

参考

[1] 陈诚. 人人都能看懂的LSTM[EB/OL]. https://zhuanlan.zhihu.com/p/32085405, 2018
[2] Shaojie Bai, J. Zico Kolter, Vladlen Koltun. An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling [EB/OL]. https://arxiv.org/abs/1803.01271, 2018

有关Pytorch实战笔记(1)——BiLSTM 实现情感分析的更多相关文章

  1. ruby - 如何根据特征实现 FactoryGirl 的条件行为 - 2

    我有一个用户工厂。我希望默认情况下确认用户。但是鉴于unconfirmed特征,我不希望它们被确认。虽然我有一个基于实现细节而不是抽象的工作实现,但我想知道如何正确地做到这一点。factory:userdoafter(:create)do|user,evaluator|#unwantedimplementationdetailshereunlessFactoryGirl.factories[:user].defined_traits.map(&:name).include?(:unconfirmed)user.confirm!endendtrait:unconfirmeddoenden

  2. 华为OD机试用Python实现 -【明明的随机数】 2023Q1A - 2

    华为OD机试题本篇题目:明明的随机数题目输入描述输出描述:示例1输入输出说明代码编写思路最近更新的博客华为od2023|什么是华为od,od薪资待遇,od机试题清单华为OD机试真题大全,用Python解华为机试题|机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为o

  3. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  4. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  5. MIMO-OFDM无线通信技术及MATLAB实现(1)无线信道:传播和衰落 - 2

     MIMO技术的优缺点优点通过下面三个增益来总体概括:阵列增益。阵列增益是指由于接收机通过对接收信号的相干合并而活得的平均SNR的提高。在发射机不知道信道信息的情况下,MIMO系统可以获得的阵列增益与接收天线数成正比复用增益。在采用空间复用方案的MIMO系统中,可以获得复用增益,即信道容量成倍增加。信道容量的增加与min(Nt,Nr)成正比分集增益。在采用空间分集方案的MIMO系统中,可以获得分集增益,即可靠性性能的改善。分集增益用独立衰落支路数来描述,即分集指数。在使用了空时编码的MIMO系统中,由于接收天线或发射天线之间的间距较远,可认为它们各自的大尺度衰落是相互独立的,因此分布式MIMO

  6. 微信小程序开发入门与实战(Behaviors使用) - 2

    @作者:SYFStrive @博客首页:HomePage📜:微信小程序📌:个人社区(欢迎大佬们加入)👉:社区链接🔗📌:觉得文章不错可以点点关注👉:专栏连接🔗💃:感谢支持,学累了可以先看小段由小胖给大家带来的街舞👉微信小程序(🔥)目录自定义组件-behaviors    1、什么是behaviors    2、behaviors的工作方式    3、创建behavior    4、导入并使用behavior    5、behavior中所有可用的节点    6、同名字段的覆盖和组合规则总结最后自定义组件-behaviors    1、什么是behaviorsbehaviors是小程序中,用于实现

  7. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

  8. ruby - Arrays Sets 和 SortedSets 在 Ruby 中是如何实现的 - 2

    通常,数组被实现为内存块,集合被实现为HashMap,有序集合被实现为跳跃列表。在Ruby中也是如此吗?我正在尝试从性能和内存占用方面评估Ruby中不同容器的使用情况 最佳答案 数组是Ruby核心库的一部分。每个Ruby实现都有自己的数组实现。Ruby语言规范只规定了Ruby数组的行为,并没有规定任何特定的实现策略。它甚至没有指定任何会强制或至少建议特定实现策略的性能约束。然而,大多数Rubyist对数组的性能特征有一些期望,这会迫使不符合它们的实现变得默默无闻,因为实际上没有人会使用它:插入、前置或追加以及删除元素的最坏情况步骤复

  9. ruby - "public/protected/private"方法是如何实现的,我该如何模拟它? - 2

    在ruby中,你可以这样做:classThingpublicdeff1puts"f1"endprivatedeff2puts"f2"endpublicdeff3puts"f3"endprivatedeff4puts"f4"endend现在f1和f3是公共(public)的,f2和f4是私有(private)的。内部发生了什么,允许您调用一个类方法,然后更改方法定义?我怎样才能实现相同的功能(表面上是创建我自己的java之类的注释)例如...classThingfundeff1puts"hey"endnotfundeff2puts"hey"endendfun和notfun将更改以下函数定

  10. ruby - 实现k最近邻需要哪些数据? - 2

    我目前有一个reddit克隆类型的网站。我正在尝试根据我的用户之前喜欢的帖子推荐帖子。看起来K最近邻或k均值是执行此操作的最佳方法。我似乎无法理解如何实际实现它。我看过一些数学公式(例如k表示维基百科页面),但它们对我来说并没有真正意义。有人可以推荐一些伪代码,或者可以查看的地方,以便我更好地了解如何执行此操作吗? 最佳答案 K最近邻(又名KNN)是一种分类算法。基本上,您采用包含N个项目的训练组并对它们进行分类。如何对它们进行分类完全取决于您的数据,以及您认为该数据的重要分类特征是什么。在您的示例中,这可能是帖子类别、谁发布了该项

随机推荐