草庐IT

强化学习之 PPO 算法

红龙96 2023-06-02 原文
简述 PPO
        PPO 算法是一种基于策略的、使用两个神经网络的强化学习算法。通过将“智体”当前
的“状态”输入神经网络,最终会得到相应的“动作”和“奖励”,再根据“动作”来更新
“智体”的状态,根据包含有“奖励”和“动作”的目标函数,运用梯度上升来更新神经网
络中的权重参数,从而能得到使得总体奖励值更大的“动作”判断。
月球飞船降落
        本文根据 gym 来跑强化学习,在该游戏中,“状态”与“奖励”的更新都使用 gym 内部
封装的函数来实行,所以我们只需要考虑“状态”→“神经网络”→“动作”就行了。
        下载 gym 的步骤如下:
                pip install gym
                pip install box2d box2d-kengz --user
如果在 cmd 下安装不成功,建议在 Anaconda 中安装
神经网络
         PPO 算法需要使用到两个神经网络,其中一个网络我们命名为“actor_net ”,“状态”就是通过“actor_net ”做出了采取什么动作的判断;另一个网络我们叫它“critic_net”,进入这个网络的也是“状态”,但通过这个网络得到的是一个值“value”,具体作用我们后面会详述。正如前面所说,“状态”通过“actor_net ”会得到“动作”,然后根据 gym 自带的函数会返回给我们新的“状态”以及“奖励”;“状态”通过“critic_net”会得到一个“value”。 这里得到的“动作”是经过 softmax 后得到的一个概率值,再经过采样后会得到相应动作的索引值。将“状态”输入神经网络是一个“智体”与环境不断交互产生数据的阶段,这也是强化学习里数据的来源,要注意的有交互并不等于有训练。
数学公式
        在深度强化学习中,我们的最终目的就是要尽可能地让得到的总的奖励最大。但是,在
PPO 算法中我们的目标函数求出的并不是“奖励”的期望,而是求“奖励”期望。
        公式如下:
        期望就是把概率和求出的值相乘,目的是求出一个较为平均的数。在强化学习中,“状态”时刻都在变化,即使是同样权重参数,也可能得到不同的判断,如果只是求单一状态下奖励的最大值,无法得到理想的效果。所以 PPO 算法将目标函数设为求奖励的期望值,这样可以一个较为总体的结果。
        再根据大数定律,我们可以把目标函数直接看成每一次的“动作”与“奖励”相乘结果的求和求平均值,公式如下:这里求出的“动作”其实是实行该动作的概率值。
我们再来说说“奖励”。
我们并不是使用直接得到的“奖励”进行参数的更新。我们前面有提到“
critic_net ”网
络会得到一个“
value ”,这时候我们要做的是拿“奖励”减去“
value ”得到一个新的值,我
们用来参与训练的便是这个新的到的值。
因为在一般的交互过程中,很难会有奖励值为负,也就是惩罚产生,这样显然是不利于
训练流程的。同时,加大惩罚,可以让“智体”不断探索发现能带来更大奖励值的“动作”。
actor_net ”负责产生“动作”,从而产生“奖励”和新的“状态”。而产生用于调整“奖
励”的“
value ”值,就是“
critic_net ”任务。
参数更新
我们前面已经得到了求期望的公式,但这个公式并不是我们要使用的目标函数,我们还
要进行一些加工,最终得到了如下函数:
该式子求出了目标函数的梯度,最终我们可能根据梯度上升的思想来更新参数。
这里要注意的是,我们是根据每一次的“状态”、“动作”和“奖励”来更新参数,因此
必须要把每一次交互的数据记录下来,等参数更新之后再删除。
Of policy Off policy
        PPO 算法里的 policy 就是使用的神经网络, Of policy 就是有经过神经网络, Off policy
是不经过神经网络。
假设我们设定,在飞船降落月球的过程中,一次降落最多产生 300 次动作,而最多可以
降落 1000 次,这样的话“智体”就会与环境交互 30 万次。
前面有提到,参数的更新需要用到全部的数据,然后删除。但是,如果一整个流程就更
新了这么一次参数,不仅慢而且浪费。因为 PPO 算法采用了两个“智体”,也就是两套参数,
其中一套我们成为“打工人”钻进神经网络,不断与环境交互产生数据。另一个我们成为“大
少爷”,平时就呆着不动,等到交互达到一定次数时,比如我们可以设为 400 次,这个看我
们自己,“打工人”将自己收集到的数据传递给“大少爷”,“大少爷”就根据前面的公式进 行参数更新。
因为是用了别的“智体”代替自己与环境交互,所以这个代替的“智体”与本来的“智
体”差异绝对不能太大,不然最终结果会有偏差。这里的差异,主要看的是判断出“动作”
的差异,因此,可以用以下代码来做一个限制:
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
Surr1 就是直接求两个“智体”所得值得比例, surr2 则是加了一个限制,如果将 eps_clip
设为 0.2 ,意思就是比值差异不能超过 0.2 。求出两个 surr ,然后取它们中得最小值,再加一
个负号,这就是我们用到的损失函数,也是之前目标函数加工后得来的。

有关强化学习之 PPO 算法的更多相关文章

  1. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  2. 100个python算法超详细讲解:画直线 - 2

    1.问题描述使用Python的turtle(海龟绘图)模块提供的函数绘制直线。2.问题分析一幅复杂的图形通常都可以由点、直线、三角形、矩形、平行四边形、圆、椭圆和圆弧等基本图形组成。其中的三角形、矩形、平行四边形又可以由直线组成,而直线又是由两个点确定的。我们使用Python的turtle模块所提供的函数来绘制直线。在使用之前我们先介绍一下turtle模块的相关知识点。turtle模块提供面向对象和面向过程两种形式的海龟绘图基本组件。面向对象的接口类如下:1)TurtleScreen类:定义图形窗口作为绘图海龟的运动场。它的构造器需要一个tkinter.Canvas或ScrolledCanva

  3. ruby - 在 Ruby 中实现 Luhn 算法 - 2

    我一直在尝试用Ruby实现Luhn算法。我一直在执行以下步骤:该公式根据其包含的校验位验证数字,该校验位通常附加到部分帐号以生成完整帐号。此帐号必须通过以下测试:从最右边的校验位开始向左移动,每第二个数字的值加倍。将乘积的数字(例如,10=1+0=1、14=1+4=5)与原始数字的未加倍数字相加。如果总模10等于0(如果总和以零结尾),则根据Luhn公式该数字有效;否则无效。http://en.wikipedia.org/wiki/Luhn_algorithm这是我想出的:defvalidCreditCard(cardNumber)sum=0nums=cardNumber.to_s.s

  4. Ruby 斐波那契算法 - 2

    下面是我写的一个计算斐波那契数列中的值的方法:deffib(n)ifn==0return0endifn==1return1endifn>=2returnfib(n-1)+(fib(n-2))endend它工作到n=14,但在那之后我收到一条消息说程序响应时间太长(我正在使用repl.it)。有人知道为什么会这样吗? 最佳答案 Naivefibonacci进行了大量的重复计算-在fib(14)fib(4)中计算了很多次。您可以将内存添加到您的算法中以使其更快:deffib(n,memo={})ifn==0||n==1returnnen

  5. ruby-on-rails - Rails add_index 算法 : :concurrently still causes database lock up during migration - 2

    为了防止在迁移到生产站点期间出现数据库事务错误,我们遵循了https://github.com/LendingHome/zero_downtime_migrations中列出的建议。(具体由https://robots.thoughtbot.com/how-to-create-postgres-indexes-concurrently-in概述),但在特别大的表上创建索引期间,即使是索引创建的“并发”方法也会锁定表并导致该表上的任何ActiveRecord创建或更新导致各自的事务失败有PG::InFailedSqlTransaction异常。下面是我们运行Rails4.2(使用Acti

  6. ruby - 趋势算法 - 2

    我正在开发一个类似微论坛的项目,其中一个特殊用户发布一条快速(接近推文大小)的主题消息,订阅者可以用他们自己的类似大小的消息来响应。直截了当,没有任何形式的“挖掘”或投票,只是每个主题消息的响应按时间顺序排列。但预计会有很高的流量。我们想根据它们引起的响应嗡嗡声来标记主题消息,使用0到10的等级。在谷歌上搜索了一段时间的趋势算法和开源社区应用示例,到目前为止已经收集到两个有趣的引用资料,但我还没有完全理解它们:Understandingalgorithmsformeasuringtrends,关于使用基线趋势算法比较维基百科页面浏览量的讨论,在SO上。TheBritneySpearsP

  7. Ruby - 不支持的密码算法 (AES-256-GCM) - 2

    我收到错误:unsupportedcipheralgorithm(AES-256-GCM)(RuntimeError)但我似乎具备所有要求:ruby版本:$ruby--versionruby2.1.2p95OpenSSL会列出gcm:$opensslenc-help2>&1|grepgcm-aes-128-ecb-aes-128-gcm-aes-128-ofb-aes-192-ecb-aes-192-gcm-aes-192-ofb-aes-256-ecb-aes-256-gcm-aes-256-ofbRuby解释器:$irb2.1.2:001>require'openssl';puts

  8. java实现Dijkstra算法 - 2

    文章目录一.Dijkstra算法想解决的问题二.Dijkstra算法理论三.java代码实现一.Dijkstra算法想解决的问题解决的问题:求解单源最短路径,即各个节点到达源点的最短路径或权值考察其他所有节点到源点的最短路径和长度局限性:无法解决权值为负数的情况二.Dijkstra算法理论参数:S记录当前已经处理过的源点到最短节点U记录还未处理的节点dist[]记录各个节点到起始节点的最短权值path[]记录各个节点的上一级节点(用来联系该节点到起始节点的路径)Dijkstra算法步骤:(1)初始化:顶点集S:节点A到自已的最短路径长度为0。只包含源点,即S={A}顶点集U:包含除A外的其他顶

  9. 对于体育新闻中文文本关键字提取有哪些关键字提取算法及其步骤 - 2

    对于体育新闻中文文本的关键字提取,常用的算法包括TF-IDF、TextRank和LDA等。它们的基本步骤如下:1.TF-IDF算法: -将文本进行分词和词性标注处理。-统计每个词在文本中的词频(TF)。-计算每个词在整个语料库中出现的文档频率(DF)和逆文档频率(IDF)。-计算每个词的TF-IDF值,并按照值的大小进行排序,选择排名前几的词作为关键字。2.TextRank算法:-将文本进行分词和词性标注处理。-将分词结果转化成图模型,每个词语为节点,根据词语之间的共现关系建立边。-对图模型进行迭代计算,计算每个节点的PageRank值,表示该节点的重要性。-选择排名前几的节点作为关键字。3.

  10. arrays - ruby 中的最佳排列计数算法 - 2

    我正在尝试计算由二进制形式的1和0的P数表示的数字的数量。如果P=2,则表示的数字为0011、1100、0110、0101、1001、1010,所以计数为6。我试过:[0,0,1,1].permutation.to_a.uniq但这不是大数的最佳解决方案(P可以什么可能是最好的排列技术,或者我们是否有任何直接的数学来做到这一点? 最佳答案 Numberofpermutationcanbecalculatedusingfactorial.a=[0,0,1,1](1..a.size).inject(:*)#=>4!=>24要计算重复项,

随机推荐