草庐IT

一文详解对抗训练方法

晓柒NLP与药物设计 2023-09-21 原文

对抗训练方法

Adversarial learning主要是用于样本生成或者对抗攻击领域,主要方法是通过添加鉴别器或者根据梯度回传生成新样本,其主要是为了提升当前主干模型生成样本的能力或者鲁棒性

一. 对抗训练定义

==对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力==

1.1 对抗训练特点

  • 相对于原始输入,所添加的扰动是微小的
  • 添加的噪声可以使得模型预测错误

1.2 对抗训练的基本概念

就是在原始输入样本上加上一个扰动得到对抗样本,再用其进行训练,这个问题可以抽象成这样一个模型:

其中,ground truth,是模型参数。意思就是即使在扰动的情况下求使得预测出的概率最大的参数,扰动可以被定义为:

其中,为符号函数,为损失函数

最后,GoodFellow还总结了对抗训练的两个作用:

  1. 提高模型应对恶意对抗样本时的鲁棒性
  2. 作为一种regularization,减少overfitting,提高泛化能力

1.3 Min-Max公式

Madry在2018年的ICLR论文Towards Deep Learning Models Resistant to Adversarial Attacks中总结了之前的工作,对抗训练可以统一写成如下格式:

其中代表输入样本的分布,代表输入,代表标签,是模型参数,是单个样本的loss,是扰动,是扰动空间。这个式子可以分布理解如下:

  1. 内部max是指往中添加扰动的目的是让越大越好,也就是说尽可能让现有模型预测出错。但是,也是有约束的,要在范围内. 常规的约束是,其中是一个常数
  2. 外部min是指找到最鲁棒的参数是预测的分布符合原数据集的分布

这就解决了两个问题:如何构建足够强的对抗样本、和如何使得分布仍然尽可能接近原始分布

1.4 NLP领域的对抗训练

对于CV领域,图像被认为是连续的,因此可以直接在原始图像上添加扰动;而对于NLP,它的输入是文本的本质是one-hot,而one-hot之间的欧式距离恒为,理论上不存在微小的扰动,而且,在Embedding向量上加上微小扰动可能就找不到与之对应的词了,不是真正意义上的对抗样本,因为对抗样本依旧能对应一个合理的原始输入,既然不能对Embedding向量添加扰动,可以对Embedding层添加扰动,使其产生更鲁棒的Embedding向量

二. 对抗训练方法

2.1 FGM(Fast Gradient Method) ICLR2017

FGM是根据具体的梯度进行scale,得到更好的对抗样本:

整个对抗训练的过程如下,伪代码如下:

  1. 计算x的前向loss、反向传播得到梯度
  2. 根据embedding矩阵的梯度计算出,并加到当前embedding上,相当于
  3. 计算的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上
  4. 将embedding恢复为(1)时的值
  5. 根据(3)的梯度对参数进行更新
class FGM:
    def __init__(self, model: nn.Module, eps=1.):
        self.model = (model.module if hasattr(model, "module") else model)
        self.eps = eps
        self.backup = {}
    # only attack word embedding
    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm and not torch.isnan(norm):
                    r_at = self.eps * param.grad / norm
                    param.data.add_(r_at)
    def restore(self, emb_name='word_embeddings'):
        for name, para in self.model.named_parameters():
            if para.requires_grad and emb_name in name:
                assert name in self.backup
                para.data = self.backup[name]
        self.backup = {}

2.2 FGSM (Fast Gradient Sign Method) ICLR2015

FGSM的全称是Fast Gradient Sign Method. FGSM和FGM的核心区别在计算扰动的方式不一样,FGSM扰动的计算方式如下:

def FGSM(image, epsilon, data_grad):
    """
    :param image: 需要攻击的图像
    :param epsilon: 扰动值的范围
    :param data_grad: 图像的梯度
    :return: 扰动后的图像
    """
    # 收集数据梯度的元素符号
    sign_data_grad = data_grad.sign()
    # 通过调整输入图像的每个像素来创建扰动图像
    perturbed_image = image + epsilon*sign_data_grad
    # 添加剪切以维持[0,1]范围
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # 返回被扰动的图像
    return perturbed_image

2.3 PGD(Projected Gradient Descent)

FGM直接通过epsilon参数算出了对抗扰动,这样得到的可能不是最优的。因此PGD进行了改进,通过迭代慢慢找到最优的扰动

并且

PGD整个对抗训练的过程如下

  1. 计算的前向loss、反向传播得到梯度并备份

  2. 对于每步:

    1. 根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于(超出范围则投影回epsilon内)
    2. if t不是最后一步: 将梯度归0,根据(1)计算前后向并得到梯度
    3. if t是最后一步: 恢复(1)的梯度,计算最后的并将梯度累加到(1)
  3. 将embedding恢复为(1)时的值

  4. 根据(5)的梯度对参数进行更新

在循环中是逐渐累加的,要注意的是最后更新参数只使用最后一个算出来的梯度

class PGD():
    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}
    def attack(self, epsilon=1., alpha=0.3, emb_name='emb.', is_first_attack=False):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)
    def restore(self, emb_name='emb.'):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name: 
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}
    def project(self, param_name, param_data, epsilon):
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r
    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = param.grad.clone()
    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]

2.4 FreeAT(Free Adversarial Training)

从FGSM到PGD,主要是优化对抗扰动的计算,虽然取得了更好的效果,但计算量也一步步增加。对于每个样本,FGSM和FGM都只用计算两次,一次是计算的前后向,一次是计算的前后向。而PGD则计算了K+1次,消耗了更多的计算资源。因此FreeAT被提了出来,在PGD的基础上进行训练速度的优化

FreeAT的思想是在对每个样本连续重复次训练,计算时复用上一步的梯度,为了保证速度,整体epoch会除以的更新公式为:

FreeAT的训练过程如下:

  1. 初始化
  2. 对于epoch=:
    1. 对于每个:
      1. 对于每步:
        1. 利用上一步的,计算的前后向,得到梯度
        2. 根据梯度更新参数
        3. 根据梯度更新

FreeAT的问题在于每次的对于当前的参数都是次优的(无法最大化loss),因为当前是由计算出来的,是对于的最优

2.5 YOPO(You Only Propagate Once)

YOPO的出发点是利用神经网络的结构来降低梯度计算的计算量。从极大值原理PMP(Pontryagin’s maximum principle)出发,对抗扰动只和网络的第0层有关,即在embedding层上添加扰动。再加之层之间是解耦合的,那就不需要每次都计算完整的前后向传播

基于这个想法,复用后面几层的梯度,减少非必要的完整传播。可以将PGD次攻击拆成次:

则对r的更新就可以变为:

其算法流程为:

对于每个样本,初始化,对于:

  1. 根据,计算对于:
  2. 计算

2.6 FreeLB (Free Large-Batch)

YOPO的假设对于ReLU-based网络来说是不成立的,因为YOPO要求损失是两次可微的,于是,FreeLB在FreeAT的基础上将每次inner-max中更新模型参数这一操作换掉,利用步之后累积的参数梯度进行更新,于是总体任务的目标函数就记为:

可以看成两个球形邻域的交上局部最大的近似。同时,通过累积参数梯度的操作,可以看作是输入了这样一个虚拟的倍大小的batch。其中input subwords的one-hot representations记为,embedding matrix记为,subwords embedding记为

依据下面算法中的数学符号,PGD需要进行次梯度计算,FreeAT需要进行次,FreeLB需要次。虽然FreeLB在效率上并没有特别大的优势,但是其效果十分不错

另外,论文中指出对抗训练和dropout不能同时使用,加上dropout相当于改变了网络的结果,影响扰动的计算。如果一定要加入dropout操作,需要在K步中都使用同一个mask

2.7 SMART(SMoothness-inducing Adversarial Regularization)

SMART放弃了Min-Max公式,选择通过正则项Smoothness-inducing Adversarial Regularization完成对抗学习。为了解决这个新的目标函数作者又提出了优化算法Bregman Proximal Point Optimization,这就是SMART的两个主要内容

SMART的主要想法是强制模型在neighboring data points上作出相似的预测,加入正则项后的目标函数如下所示:

是具体任务的损失函数,是generated neighbors of training points,在分类任务中使用对称的KL散度,即;在回归任务中使用平方损失,此时可以看到对抗发生在正则化项上,对抗的目标是最大扰动前后的输出

Bregman Proximal Point Optimization也可以看作是一个正则项,防止更新的时候和前面的变化过大。在第次迭代时,采用vanilla Bregman proximal point (VBPP) method

其中表示Bregman divergence定义为:

是上面给出的对称KL散度

使用动量来加速VBPP,此时定义为动量,记表示指数移动平均,那么momentum Bregman proximal point (MBPP) method就可以表示为:

下面是SMART的完整算法流程:

  1. 对于轮迭代:
    1. 备份,作为Bregman divergence计算的
    2. 对于每一个
      1. 使用正态分布随机初始化扰动,结合得到
      2. 循环小步:计
        1. 算扰动下的梯度
        2. 基于和学习率更新
      3. 基于重新计算梯度,更新参数
    3. 更新

三. Reference

  1. Madry A, Makelov A, Schmidt L, et al. Towards deep learning models resistant to adversarial attacks[J]. arXiv preprint arXiv:1706.06083, 2017.

  2. Goodfellow I J, Shlens J, Szegedy C. Explaining and harnessing adversarial examples[J]. arXiv preprint arXiv:1412.6572, 2014.

  3. Miyato T, Dai A M, Goodfellow I. Adversarial training methods for semi-supervised text classification[J]. arXiv preprint arXiv:1605.07725, 2016.

  4. Shafahi A, Najibi M, Ghiasi A, et al. Adversarial training for free![J]. arXiv preprint arXiv:1904.12843, 2019.

  5. Zhang D, Zhang T, Lu Y, et al. You only propagate once: Accelerating adversarial training via maximal principle[J]. arXiv preprint arXiv:1905.00877, 2019.

  6. Zhu C, Cheng Y, Gan Z, et al. Freelb: Enhanced adversarial training for natural language understanding[J]. arXiv preprint arXiv:1909.11764, 2019.

  7. Jiang H, He P, Chen W, et al. Smart: Robust and efficient fine-tuning for pre-trained natural language models through principled regularized optimization[J]. arXiv preprint arXiv:1911.03437, 2019.

有关一文详解对抗训练方法的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 如何从 ruby​​ 中的字符串运行任意对象方法? - 2

    总的来说,我对ruby​​还比较陌生,我正在为我正在创建的对象编写一些rspec测试用例。许多测试用例都非常基础,我只是想确保正确填充和返回值。我想知道是否有办法使用循环结构来执行此操作。不必为我要测试的每个方法都设置一个assertEquals。例如:describeitem,"TestingtheItem"doit"willhaveanullvaluetostart"doitem=Item.new#HereIcoulddotheitem.name.shouldbe_nil#thenIcoulddoitem.category.shouldbe_nilendend但我想要一些方法来使用

  3. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  4. ruby - Facter::Util::Uptime:Module 的未定义方法 get_uptime (NoMethodError) - 2

    我正在尝试设置一个puppet节点,但ruby​​gems似乎不正常。如果我通过它自己的二进制文件(/usr/lib/ruby/gems/1.8/gems/facter-1.5.8/bin/facter)在cli上运行facter,它工作正常,但如果我通过由ruby​​gems(/usr/bin/facter)安装的二进制文件,它抛出:/usr/lib/ruby/1.8/facter/uptime.rb:11:undefinedmethod`get_uptime'forFacter::Util::Uptime:Module(NoMethodError)from/usr/lib/ruby

  5. Ruby 方法() 方法 - 2

    我想了解Ruby方法methods()是如何工作的。我尝试使用“ruby方法”在Google上搜索,但这不是我需要的。我也看过ruby​​-doc.org,但我没有找到这种方法。你能详细解释一下它是如何工作的或者给我一个链接吗?更新我用methods()方法做了实验,得到了这样的结果:'labrat'代码classFirstdeffirst_instance_mymethodenddefself.first_class_mymethodendendclassSecond使用类#returnsavailablemethodslistforclassandancestorsputsSeco

  6. ruby-on-rails - Rails 3.2.1 中 ActionMailer 中的未定义方法 'default_content_type=' - 2

    我在我的项目中添加了一个系统来重置用户密码并通过电子邮件将密码发送给他,以防他忘记密码。昨天它运行良好(当我实现它时)。当我今天尝试启动服务器时,出现以下错误。=>BootingWEBrick=>Rails3.2.1applicationstartingindevelopmentonhttp://0.0.0.0:3000=>Callwith-dtodetach=>Ctrl-CtoshutdownserverExiting/Users/vinayshenoy/.rvm/gems/ruby-1.9.3-p0/gems/actionmailer-3.2.1/lib/action_mailer

  7. ruby - Highline 询问方法不会使用同一行 - 2

    设置:狂欢ruby1.9.2高线(1.6.13)描述:我已经相当习惯在其他一些项目中使用highline,但已经有几个月没有使用它了。现在,在Ruby1.9.2上全新安装时,它似乎不允许在同一行回答提示。所以以前我会看到类似的东西:require"highline/import"ask"Whatisyourfavoritecolor?"并得到:Whatisyourfavoritecolor?|现在我看到类似的东西:Whatisyourfavoritecolor?|竖线(|)符号是我的终端光标。知道为什么会发生这种变化吗? 最佳答案

  8. ruby - 主要 :Object when running build from sublime 的未定义方法 `require_relative' - 2

    我已经从我的命令行中获得了一切,所以我可以运行rubymyfile并且它可以正常工作。但是当我尝试从sublime中运行它时,我得到了undefinedmethod`require_relative'formain:Object有人知道我的sublime设置中缺少什么吗?我正在使用OSX并安装了rvm。 最佳答案 或者,您可以只使用“require”,它应该可以正常工作。我认为“require_relative”仅适用于ruby​​1.9+ 关于ruby-主要:Objectwhenrun

  9. ruby - 多个属性的 update_column 方法 - 2

    我有一个具有一些属性的模型:attr1、attr2和attr3。我需要在不执行回调和验证的情况下更新此属性。我找到了update_column方法,但我想同时更新三个属性。我需要这样的东西:update_columns({attr1:val1,attr2:val2,attr3:val3})代替update_column(attr1,val1)update_column(attr2,val2)update_column(attr3,val3) 最佳答案 您可以使用update_columns(attr1:val1,attr2:val2

  10. ruby - 检查方法参数的类型 - 2

    我不确定传递给方法的对象的类型是否正确。我可能会将一个字符串传递给一个只能处理整数的函数。某种运行时保证怎么样?我看不到比以下更好的选择:defsomeFixNumMangler(input)raise"wrongtype:integerrequired"unlessinput.class==FixNumother_stuffend有更好的选择吗? 最佳答案 使用Kernel#Integer在使用之前转换输入的方法。当无法以任何合理的方式将输入转换为整数时,它将引发ArgumentError。defmy_method(number)

随机推荐