草庐IT

强化学习,直接策略搜索,策略梯度,入门样例

Wei_Xiong 2023-03-28 原文

策略梯度,入门样例

原文链接:
https://www.cnblogs.com/Twobox/
参考链接:

https://datawhalechina.github.io/easy-rl/#/chapter4/chapter4

https://zhuanlan.zhihu.com/p/358700228

策略网路结构

算法流程与策略梯度

添加一个基线

调整更合适的分数

代码结构

需要的包

import numpy as np
import gym
import matplotlib.pyplot as plt
import torch  # torch.optim.SGD 内置优化器
import torch.nn as nn  # 模型库
import torch.nn.functional as F  # 内置loss函数
from torch.utils.data import TensorDataset  # 包装
from torch.utils.data import DataLoader  # 迭代器

model.py

def loss_fun(p, advantage, N):
    # p就是p(a|s)  advantage 就是权重优势
    # p Tensor格式  advantage为数字数组1

    advantage = torch.Tensor(advantage)
    # 目标函数 1/N sum(sum(a' * log p'))
    loss = -torch.sum(torch.log(p) * advantage) / N
    return loss


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(4, 128)
        self.linear2 = nn.Linear(128, 2)
        # self.linear3 = nn.Linear(20, 2)

    def forward(self, x):
        # xb = xb.view(xb.size(0), -1)
        out = self.linear1(x)
        out = F.relu(out)
        out = self.linear2(out)
        out = F.softmax(out, dim=-1)
        return out

    def fit(self, p, advantage, N):
        opt = torch.optim.Adam(self.parameters(), 0.005)
        loss = loss_fun(p, advantage, N)
        opt.zero_grad()
        loss.backward()
        opt.step()
        opt.zero_grad()

agent.py

class Agent:
    def __init__(self, gamma):
        self.model = Model()

        # 目标函数 1/N sum(sum(a' * log p'))
        self.p = []
        self.advantage = []
        self.N = 0
        self.gamma = gamma

    def get_action_p(self, state):
        # 转化为Tensor , 此时为一维
        state = torch.FloatTensor(state)
        # 转化为二维,最外面加个[]
        state = torch.unsqueeze(state, 0)

        p = self.model(state)
        return p  # tensor

    def clear(self):
        self.advantage.clear()
        self.p.clear()
        self.N = 0

    def pay_n_times(self, N, env):
        # 玩N次,追加存储N次经验
        self.N += N
        r_sum = 0  # 所有奖励
        advantage = []

        for n in range(N):
            state = env.reset()
            r_list = []  # 一个回合 每个动作的奖励
            done = False
            while not done:
                p = self.get_action_p(state)
                # 按概率采样下表;在dim为1的位置进行采样;这里的结果为[[0 or 1]]
                action = torch.multinomial(p, 1).item()  # 这时候直接是数字
                s_, r, done, _ = env.step(action)
                state = s_
                r_list.append(r)
                # 后续需要对self.p使用torch.cat方法
                self.p.append(p[0][action].unsqueeze(0))  

            r_sum += sum(r_list)
            # sum(gamma^i * r)
            ad_list = []
            ad_temp = 0
            for i in reversed(range(len(r_list))):
                ad_temp = ad_temp * self.gamma + r_list[i]
                ad_list.append(ad_temp)

            ad_list.reverse()
            advantage += ad_list

        b = r_sum / N
        advantage = [a - b for a in advantage]
        self.advantage += advantage

        # 返回平均分数
        return b

    def learn(self):
        p = torch.cat(self.p)
        advantage = torch.FloatTensor(self.advantage)
        self.model.fit(p, advantage, self.N)

main.py

env = gym.make("CartPole-v1")
agent = Agent(0.95)

T = 1000 # 更新多少次梯度
N = 50 # 每次跟新需要采样多少回合的经验
x, y = [], []
for t in range(T):
    avg_r = agent.pay_n_times(N, env)
    x.append(t)
    y.append(avg_r)
    print("{} : {}".format(t, avg_r))
    agent.learn()
    agent.clear()

    plt.plot(x,y)
    plt.pause(0.1)

plt.plot(x,y)
plt.show()

结果

本文原创作者:魏雄
原文链接:
https://www.cnblogs.com/Twobox/

有关强化学习,直接策略搜索,策略梯度,入门样例的更多相关文章

  1. ruby-on-rails - Nokogiri:使用 XPath 搜索 <div> - 2

    我使用Nokogiri(Rubygem)css搜索寻找某些在我的html里面。看起来Nokogiri的css搜索不喜欢正则表达式。我想切换到Nokogiri的xpath搜索,因为这似乎支持搜索字符串中的正则表达式。如何在xpath搜索中实现下面提到的(伪)css搜索?require'rubygems'require'nokogiri'value=Nokogiri::HTML.parse(ABBlaCD3"HTML_END#my_blockisgivenmy_bl="1"#my_eqcorrespondstothisregexmy_eq="\/[0-9]+\/"#FIXMEThefoll

  2. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  3. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  4. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  5. 微信小程序开发入门与实战(Behaviors使用) - 2

    @作者:SYFStrive @博客首页:HomePage📜:微信小程序📌:个人社区(欢迎大佬们加入)👉:社区链接🔗📌:觉得文章不错可以点点关注👉:专栏连接🔗💃:感谢支持,学累了可以先看小段由小胖给大家带来的街舞👉微信小程序(🔥)目录自定义组件-behaviors    1、什么是behaviors    2、behaviors的工作方式    3、创建behavior    4、导入并使用behavior    5、behavior中所有可用的节点    6、同名字段的覆盖和组合规则总结最后自定义组件-behaviors    1、什么是behaviorsbehaviors是小程序中,用于实现

  6. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

  7. ES基础入门 - 2

    ES一、简介1、ElasticStackES技术栈:ElasticSearch:存数据+搜索;QL;Kibana:Web可视化平台,分析。LogStash:日志收集,Log4j:产生日志;log.info(xxx)。。。。使用场景:metrics:指标监控…2、基本概念Index(索引)动词:保存(插入)名词:类似MySQL数据库,给数据Type(类型)已废弃,以前类似MySQL的表现在用索引对数据分类Document(文档)真正要保存的一个JSON数据{name:"tcx"}二、入门实战{"name":"DESKTOP-1TSVGKG","cluster_name":"elasticsear

  8. ruby - 如何搜索有用的 ruby - 2

    寻找有用的ruby的好网站是什么? 最佳答案 AgileWebDevelopment列出插件(虽然不是ruby​​gems,我不确定为什么),并允许人们对它们进行评级。RubyToolbox按类别列出gem并比较它们的受欢迎程度。Rubygems有一个搜索框。StackOverflow对最有用的rails插件和ruby​​gems有疑问。 关于ruby-如何搜索有用的ruby,我们在StackOverflow上找到一个类似的问题: https://stacko

  9. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  10. ruby - 如何搜索、递增和替换 Ruby 字符串中的整数子字符串? - 2

    我有很多这样的文档:foo_1foo_2foo_3bar_1foo_4...我想通过获取foo_[X]的所有实例并将它们中的每一个替换为foo_[X+1]来转换它们。在这个例子中:foo_2foo_3foo_4bar_1foo_5...我可以用gsub和一个block来做到这一点吗?如果不是,最干净的方法是什么?我真的在寻找一个优雅的解决方案,因为我总是可以暴力破解它,但我觉得有一些正则表达式技巧值得学习。 最佳答案 我(完全)不懂Ruby,但类似这样的东西应该可以工作:"foo_1foo_2".gsub(/(foo_)(\d+)/

随机推荐