草庐IT

GAN的训练技巧:炼丹师养成计划 ——生成式对抗网络训练、调参和改进

中杯可乐多加冰 2023-04-19 原文

目录

生成对抗网络(GAN:Generative adversarial networks)是深度学习领域的一个重要生成模型,即两个网络(生成器和鉴别器)在同一时间训练并且在极小化极大算法(minimax)中进行竞争。这种对抗方式避免了一些传统生成模型在实际应用中的一些困难,巧妙地通过对抗学习来近似一些不可解的损失函数。

之前我们介绍了GAN的原理:深入浅出 理解GAN中的数学原理,GAN最重要的就是找到D与G之间的纳什均衡,但是在实际中会发现GAN的训练不稳定,训练方法不佳很容易出现模式崩溃等问题,本篇将记录一些训练技巧,不一定适合你的模型,也可能有疏漏和错误,供学习参考,欢迎指正和补充。

一、模式崩溃: 生成器产生的结果模式较为单一

模式崩溃现象狭义上来说是生成器仅仅产生单个或有限的模式来欺骗鉴别器,仅仅只是为了得到最低的判别器损失D_loss,却忽视了数据集的分布,比如一个动物图像数据集,GAN在训练时候发现生成猫和狗的效果非常好,生成牛、羊、猴子等效果很差,整个G就只去生成猫狗,根本不去学习生成其他的动物图像,就会导致生成的图像单一。模式崩溃现象本质上还是GAN的训练优化问题,即使是最优秀的 GAN 研究人员也在与模式崩溃作斗争。

解决模式崩溃有很多方法,如下:

1.1、改进训练方法

  1. 小批量鉴别器(mini-batch discriminator):因为判别器每次只能独立处理一个样本,生成器在每个样本上获得的梯度信息缺乏“统一协调”,都指向了同一个方向。于是小批量让判别器不再独立考虑一个样本,而是同时考虑一个小批量的所有样本,具体实现可以看:小批量判别器如何解决模式崩溃问题
  2. 经验重播:每隔一段时间向鉴别器显示旧的假样本,可以使模式间的跳来跳去最小化。这可以防止鉴别器变得太容易被利用,但仅限于生成器过去已经探索过的模式。
  3. 调整GAN的学习速度(学习率):通过改变这个特定的超参数来克服这个阻碍,使用较小的学习率,并从头开始训练,学习速度是最重要的超参数之一,即使不是最重要的超参数,即使是它微小变化也可能导致训练过程中的根本性变化。
  4. 特征匹配:特征匹配改变了生成器的cost function,以最小化真实图像和所生成图像的特征之间的统计差异,测量它们的特征向量均值之间的 L2 距离。
  5. 把多个属于同一类的样本进行打包,然后传递给判别网络 D。
  6. 预计反攻:生成器在更新时,不仅仅考虑当前生成器的状态,还会额外考虑K次更新后判别器的状态,综合两个信息做出最优解,即参数更新方式为采用梯度下降方式连续更新K次,提高生成器的“先见之明”,从而避免了短视行为。首先将参数更新方式改为采用梯度下降方式连续更新K次,如下:
    θ D 0 = θ D … ⋯ θ D K = θ D K − 1 + η ∂ f ( θ G , θ D K − 1 ) ∂ θ D K − 1 \begin{aligned} \theta_{D}^{0} &=\theta_{D} \\ & \ldots \cdots \\ \theta_{D}^{K} &=\theta_{D}^{K-1}+\eta \frac{\partial f\left(\theta_{G}, \theta_{D}^{K-1}\right)}{\partial \theta_{D}^{K-1}} \end{aligned} θD0θDK=θD=θDK1+ηθDK1f(θG,θDK1)
    生成器的优化目标改为: θ G = arg ⁡ min ⁡ θ G f ( θ G , θ D K ( θ G , θ D ) ) \theta_{G}=\arg \min _{\theta_{G}} f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right) θG=argminθGf(θG,θDK(θG,θD)),梯度的变化改为: d f K ( θ G , θ D ) d θ G = ∂ f ( θ G , θ D K ( θ G , θ D ) ) ∂ θ G + ∂ f ( θ G , θ D K ( θ G , θ D ) ) ∂ θ D K ( θ G , θ D ) ∂ θ D K ( θ G , θ D ) ∂ θ G \frac{d f_{K}\left(\theta_{G}, \theta_{D}\right)}{d \theta_{G}}=\frac{\partial f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right)}{\partial \theta_{G}}+\frac{\partial f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right)}{\partial \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)} \frac{\partial \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)}{\partial \theta_{G}} dθGdfK(θG,θD)=θGf(θG,θDK(θG,θD))+θDK(θG,θD)f(θG,θDK(θG,θD))θGθDK(θG,θD)

1.2、改进目标函数

  1. 特征匹配:改变生成器的损失函数;
  2. 用Wassernstein距离代替JS散度;
  3. 在梯度上加入惩罚项:WGAN-GP、DRAGAN;
  4. 引入pixel级别loss,特别是在训练早期,如L1, L2等;
  5. 在损失函数上加上正则项,帮助GAN找到更多多样性;
  6. 使用均方损失( mean squared loss )替代对数损失( log loss )。

1.3、改进网络架构

  1. 使用多个生成器,简单地接受GAN只覆盖数据集中模式的一个子集,并为不同模式训练多个生成器,而不是对抗模式崩溃,一起去生成图像,这样就可以生成多样化的图像;
  2. 自注意力机制:全局信息(长距依赖)会用于生成更好的图像。

二、训练缓慢:发生了梯度消失

  1. 网络使用残差结构:自适应网络深度,同时避免梯度消失;
  2. softmax+CrossEntropy loss:通过损失函数来抵消激活函数求导后造成的梯度消失影响
  3. 使用Adam优化器;
  4. 不要把判别器训练得太好,以避免后期梯度消失导致无法训练生成器,判别器的任务是辅助学习数据集的本质概率分布和生成器定义的隐式概率分布之间的某种距离,生成器的任务是使该距离达到最小;
  5. 对于层数过深的模型,尽量避免使用全连接层。

三、不收敛:训练不稳定,收敛的慢

  1. 生成器或鉴别器损失突然增加或减少时,不要随意停止训练,损失函数往往是随机上升或下降的,这个现象并没有什么问题,遇到突然的不稳定时,多进行一些训练,关注生成图像的质量,视觉的理解通常比一些损失数字更有意义;
  2. 添加噪声:通过添加噪声有利于提高系统的整体多样性和稳定性,在真实数据和合成数据(例如由生成器生成的图像)中添加噪声;在数学领域中,这应该是有效的,因为它有助于为两个相互竞争的网络的数据分布提供一定的稳定性;
  3. 软标签或者带噪声的标签:如果真实图像的标签设置为1,我们将它更改为一个低一点的值,比如0.9。这个解决方案阻止判别器对其分类标签过于确信,或者换句话说,不依赖非常有限的一组特征来判断图像是真还是假。

四、过拟合

在GAN中,如果鉴别器依赖于一小组特征来检测真实图像,则生成器可以仅生成这些特征以仅利用鉴别器。优化可能变得过于贪婪并且不会产生长期效益;

  1. 使用正则化来避免过拟合,常用的有L1、L2两种算法,如果已经使用了,调整其参数大小;
  2. dropout:让某些神经元以一定的概率停止工作。从隐藏层神经元中随机选择一个子集临时删除掉,然后训练时没有被删除的那一部分参数更新,删除的神经元参数保持被删除前的结果,不断重复这一过程;
  3. 软标签或者带噪声的标签(同上三)。

五、尽早发现失败

  1. D的loss一直接近于0,直接宣告失败。鉴别器太强了,生成器已经无法再产生更好的假数据了,也可以认为梯度消失了,这种情况很常见因为识别真假样本通常比伪造真实样本容易;
  2. D的loss居高不下,生成的图像很模糊不清,极有可能已失败。判别网络能力不行,胡乱分辨真假,甚至把真的误认为假的,假的误认为真的,生成器无法从判别器D那里学习到东西;
  3. 观察图像发现生成出来的图像单一,发生了模式崩溃,生成网络凑巧在生成某类真样本上特别得心应手,或者,判别网络对某类样本的辨别能力相对较差,那么生成网络会扬长避短,尽量多生成这类样本;
  4. 在一定的epoch后观察图像发现生成出来的图像模糊,全是噪声,极有可能已失败,梯度更新已经开始无意义,再往下训练也不会有改善,所以不要把时间浪费在无谓,病态的梯度更新上;
  5. GAN中loss体现的是判别器的判别能力,整体变化应该是降升、降升,最后趋于稳定。降是因为判别器性能增强了,升是因为生成器生成能力变好了。

六、一些训练技巧

  1. 将图像像素值缩放在-1到1之间,tanh作为生成器的输出层;
  2. 使用Adam优化器通常比其他更好;
  3. 使用PixelShuffle和转置卷积进行上采样;
  4. 使用Batch Normalization,其能提高网络泛化能力,使用BN后还可以不用理会过拟合中的drop out和L2正则化参数选择;
  5. 在将图像输入鉴别器之前,将噪声添加到实际图像和生成的图像中;
  6. 噪声尽量使用正态分布而不是均匀分布;
  7. 梯度惩罚;
  8. 激活函数使用LeakyRelu
  9. Two Timescale Update Rule (TTUR):不同的学习率,低速更新规则用于生成网络 G ,判别网络 D使用 高速更新规则,将判别器的学习率选为0.0004,将生成器的学习率选为0.0001也许可以达到不错的效果
  10. 反转标签,故意在部分样本上颠倒黑白,这个被放过的小鬼也许能刺激GAN别一条道走到黑;
  11. 在一定情况下打乱数据集,不然会导致网络在学习过程中产生偏见;
  12. 优先级:调参>更换损失函数>调整网络结构
  13. 不要采用早停法,要相信奇迹,除非判别器损失迅速趋近于 0;
  14. 不要放弃,一些微小改动将决定你的GAN模型能否训练成功。

部分参考自:
https://arxiv.org/pdf/1606.03498.pdf
https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
https://www.zhihu.com/people/xiaomizhou94/posts

最后

💖 个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向

📝 个人主页:中杯可乐多加冰

🔥 限时免费订阅:文本生成图像T2I专栏

🎉 支持我:点赞👍+收藏⭐️+留言📝

如果这篇文章帮助到你很多,希望能点击下方打赏我一杯可乐!多加冰哦

有关GAN的训练技巧:炼丹师养成计划 ——生成式对抗网络训练、调参和改进的更多相关文章

  1. 动漫制作技巧如何制作动漫视频 - 2

    动漫制作技巧是很多新人想了解的问题,今天小编就来解答与大家分享一下动漫制作流程,为了帮助有兴趣的同学理解,大多数人会选择动漫培训机构,那么今天小编就带大家来看看动漫制作要掌握哪些技巧?一、动漫作品首先完成草图设计和原型制作。设计草图要有目的、有对象、有步骤、要形象、要简单、符合实际。设计图要一致性,以保证制作的顺利进行。二、原型制作是根据设计图纸和制作材料,可以是手绘也可以是3d软件创建。在此步骤中,要注意的问题是色彩和平面布局。三、动漫制作制作完成后,加工成型。完成不同的表现形式后,就要对设计稿进行加工处理,使加工的难易度降低,并得到一些基本准确的概念,以便于后续的大样、准确的尺寸制定。四、

  2. ruby-on-rails - 我可以用鸭子类型(duck typing)改进这种方法吗? - 2

    希望我没有误解“ducktyping”的含义,但从我读到的内容来看,这意味着我应该根据对象如何响应方法而不是它是什么类型/类来编写代码。代码如下:defconvert_hash(hash)ifhash.keys.all?{|k|k.is_a?(Integer)}returnhashelsifhash.keys.all?{|k|k.is_a?(Property)}new_hash={}hash.each_pair{|k,v|new_hash[k.id]=v}returnnew_hashelseraise"CustomattributekeysshouldbeID'sorPropertyo

  3. Unity Shader 学习笔记(5)Shader变体、Shader属性定义技巧、自定义材质面板 - 2

    写在之前Shader变体、Shader属性定义技巧、自定义材质面板,这三个知识点任何一个单拿出来都是一套知识体系,不能一概而论,本文章目的在于将学习和实际工作中遇见的问题进行总结,类似于网络笔记之用,方便后续回顾查看,如有以偏概全、不祥不尽之处,还望海涵。1、Shader变体先看一段代码......Properties{ [KeywordEnum(on,off)]USL_USE_COL("IsUseColorMixTex?",int)=0 [Toggle(IS_RED_ON)]_IsRed("IsRed?",int)=0}......//中间省略,后续会有完整代码 #pragmamulti_c

  4. ruby - 在 Ruby 中训练神经网络 - 2

    在神经网络方面,我完全是个初学者。我整天都在与ruby​​-fann和ai4r搏斗,不幸的是我没有任何东西可以展示,所以我想我会来到StackOverflow并询问这里的知识渊博的人。我有一组样本——每天都有一个数据点,但它们不符合我能够找出的任何明确模式(我尝试了几次回归)。不过,我认为看看是否有任何方法可以仅从日期预测future的数据会很好,而且我认为神经网络将是生成希望表达这种关系的函数的好方法.日期是DateTime对象,数据点是十进制数,例如7.68。我一直在将DateTime对象转换为float,然后除以10,000,000,000得到一个介于0和1之间的数字,我一直在将

  5. ruby - 在 Ruby 中为 XOR 训练神经网络 - 2

    我正在尝试训练一个前馈网络来使用Ruby库AI4R执行异或运算。然而,当我在训练后评估XOR时。我没有得到正确的输出。有没有人以前使用过这个库并得到它来学习异或运算。我使用了两个输入神经元,一个隐藏层中的三个神经元,一个输出层,正如我看到的预计算XOR前馈神经网络就像这样。require"rubygems"require"ai4r"#Createthenetworkwith:#2inputs#1hiddenlayerwith3neurons#1outputsnet=Ai4r::NeuralNetwork::Backpropagation.new([2,3,1])example=[[0,

  6. 关于yolov5训练时参数workers和batch-size的理解 - 2

    关于yolov5训练时参数workers和batch-size的理解yolov5训练命令workers和batch-size参数的理解两个参数的调优总结yolov5训练命令python.\train.py--datamy.yaml--workers8--batch-size32--epochs100yolov5的训练很简单,下载好仓库,装好依赖后,只需自定义一下data目录中的yaml文件就可以了。这里我使用自定义的my.yaml文件,里面就是定义数据集位置和训练种类数和名字。workers和batch-size参数的理解一般训练主要需要调整的参数是这两个:workers指数据装载时cpu所使

  7. ruby - 需要帮助改进 Ruby DSL 以控制 Arduino 控制的饮料分配器(bar monkey) - 2

    我正在用Ruby编写DSL来控制我正在处理的Arduino项目;巴尔迪诺。这是一只酒吧猴子,将由软件控制来提供饮料。Arduino通过串行端口接收命令,告诉Arduino要打开什么泵以及打开多长时间。它目前正在读取一个食谱(见下文)并将其打印出来。串行通信的代码以及我在下面提到的其他一些想法仍然需要改进。这是我的第一个DSL,我正在处理之前的示例,所以它的边缘非常粗糙。任何批评、代码改进(是否有任何关于RubyDSL最佳实践或习语的良好引用?)或任何一般性评论。我目前有DSL的粗略草稿,因此饮料配方如下所示(Githublink):desc"Simpleglassofwater"rec

  8. NEUQ-acm 预备队训练Week4—BFS/DFS - 2

    1.深度优先搜索(DFS)深度优先遍历主要思路是从图中一个未访问的顶点V开始,沿着一条路一直走到底,然后从这条路尽头的节点回退到上一个节点,再从另一条路开始走到底…,不断递归重复此过程,直到所有的顶点都遍历完成。例题P1605迷宫题目描述给定一个N×MN\timesMN×M方格的迷宫,迷宫里有TTT处障碍,障碍处不可通过。在迷宫中移动有上下左右四种方式,每次只能移动一个方格。数据保证起点上没有障碍。给定起点坐标和终点坐标,每个方格最多经过一次,问有多少种从起点坐标到终点坐标的方案。输入格式第一行为三个正整数N,M,TN,M,TN,M,T,分别表示迷宫的长宽和障碍总数。第二行为四个正整数SX,S

  9. ruby-on-rails - Ruby 改进和钩子(Hook) - 2

    我正在尝试使用ruby​​改进来应用Rails钩子(Hook)。我想避免猴子补丁。当猴子修补时它会这样工作ActiveRecord::Base.class_evaldoafter_finddo#dosomethingwithmy_methodenddefmy_method#somethingusefulendend我已经能够通过做这样的事情来拥有类方法:moduleActiveRecordRefinementsrefineActiveRecord::Base.singleton_classdodefmy_method#somethingcoolendendend但我无法运行钩子(Hoo

  10. ruby - 如何改进 Ruby 中的模块方法? - 2

    您可以使用优化您的类(class)moduleRefinedStringrefineStringdodefto_boolean(text)!!(text=~/^(true|t|yes|y|1)$/i)endendend但是如何细化模块方法呢?这:moduleRefinedMathrefineMathdodefPI22/7endendend引发:TypeError:错误的参数类型模块(预期类) 最佳答案 这段代码可以工作:moduleMathdefself.piputs'originalmethod'endendmoduleRefin

随机推荐