草庐IT

强化学习之路一 QLearning 算法

Please Call me 小强 2023-03-28 原文

Q-Learning算法

理论

Q-Learning是一种强化学习算法,用于学习在给定状态下采取不同行动的最佳策略。其公式如下:

\(Q(s,a) \leftarrow (1 - \alpha) \cdot Q(s,a) + \alpha \cdot (r + \gamma \cdot \max_{a'} Q(s',a'))\)

其中,\(Q(s,a)\)是在状态\(s\)下采取行动\(a\)的预期回报,\(\alpha\)是学习率,\(r\)是在状态\(s\)下采取行动\(a\)的即时回报,\(\gamma\)是折扣因子,\(s'\)是采取行动\(a\)后得到的新状态。\(\max_{a'} Q(s',a')\)是在新状态\(s'\)下采取不同行动所能获得的最大预期回报。

Q-Learning公式的意义是,在当前状态\(s\)下采取行动\(a\),更新当前状态下采取行动\(a\)的预期回报\(Q(s,a)\)。更新公式中的第一项表示当前状态下采取行动\(a\)的原始预期回报,第二项表示从当前状态采取行动\(a\)后得到的新状态\(s'\)的最大预期回报。通过不断更新\(Q(s,a)\),我们可以学习到在不同状态下采取不同行动的最佳策略。

将理论转换为简单易懂的python的代码:

alpha = 0.1
gamma = 0.5

# s 当前状态 就是一个位置信息
# a 执行动作 上下左右
# newS 当前状态执行动作后的新状态
# r 为执行动作a后,环境给的奖励
def updateQ(s, a, r):
    newS = None
    if a == 0: # 上
        newS = (s[0]-1, s[1])
    elif a == 1: # 下
        newS = (s[0]+1, s[1])
    elif a == 2: # 左
        newS = (s[0], s[1]-1)
    elif a == 3: # 右
        newS = (s[0], s[1]+1)

    Q[s][a] += alpha * (r + gamma * max(Q[newS]) - Q[s][a])

中间小插曲

刚开始看到理论后,就开始撸代码了,没有看其他人的写的代码, 结果翻车了。
根据我的理解,我刚开始代码的Q表是每个状态的价值表。动作的变化,引发环境改变,环境改变给出一个奖励, 然后在更新Q表。
大家一定要注意是对Q表存的是每个状态的每个动作的评价值
不过经过翻车,也算是加深了对Qlearning的理解

实际能运行demo

#coding:utf8
import random
import math
import gym
from gym import spaces
import numpy as np
S = "S" # 起始块
G = "G" # 目标块
F = "F" # 冻结块
H = "H" # 危险块
# 这个环境规则就是, 从S点走到G点,中间走到H点就GameOver
class MyEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self):
        self.board = np.array([
            [S, F, F, F],
            [F, H, F, H],
            [F, F, F, H],
            [H, F, F, G],
        ])
        self.height, self.width = self.board.shape

        # 定义动作空间和观察空间
        self.action_space = spaces.Discrete(4) # 上下左右
        self.observation_space = spaces.Tuple((
            spaces.Discrete(self.height),
            spaces.Discrete(self.width)
        ))
        self.reset()

    def step(self, action):
        if action == 0: # 上
            next_pos = (self.current_pos[0]-1, self.current_pos[1])
        elif action == 1: # 下
            next_pos = (self.current_pos[0]+1, self.current_pos[1])
        elif action == 2: # 左
            next_pos = (self.current_pos[0], self.current_pos[1]-1)
        elif action == 3: # 右
            next_pos = (self.current_pos[0], self.current_pos[1]+1)

        assert self._is_valid_pos(next_pos)

        # 步骤越多模型越差
        self.steps += 0.1
        if self.board[next_pos] == H:
            reward = -self.steps -self.width*self.height
            self.done = True
        elif self.board[next_pos] == G:
            reward = -self.steps
            self.done = True
        else:
            reward = -self.steps - abs(next_pos[0]-3) - abs(next_pos[1]-3)
            self.done = False

        self.current_pos = next_pos
        return self.current_pos, reward, self.done, self.board[next_pos] == H

    def reset(self):
        self.current_pos = (0, 0)
        self.done = False
        self.steps = 0
        return self.current_pos

    def render(self, mode='human'):
        for i in range(self.height):
            for j in range(self.width):
                if (i, j) == self.current_pos:
                    print("*", end="")
                else:
                    print(self.board[i][j], end="")
            print()
        print()

    def _is_valid_pos(self, pos):
        if pos[0] < 0 or pos[0] >= self.height or pos[1] < 0 or pos[1] >= self.width:
            return False
        return True
    
def softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)
    
# 定义获取当前位置动作的函数
def get_actions(row, col):
    actions = []
    if row < 3:  # 如果不在最后一行,则可以向下移动
        actions.append(1)
    if col < 3:  # 如果不在最后一列,则可以向右移动
        actions.append(3)
    if row > 0:  # 如果不在第一行,则可以向上移动
        actions.append(0)
    if col > 0:  # 如果不在第一列,则可以向左移动
        actions.append(2)
    return actions

env = MyEnv()
ACTIONS = np.arange(4)
ACTIONS_STR = '上|下|左|右'.split('|')

Q = np.random.rand(4, 4, 4)
for i in range(4):
    for j in range(4):
        actions = get_actions(i, j)
        for k in range(4):
            if k not in actions:
                # 经过soeftmax之后,执行这个动作的概率为0
                Q[(i, j, k)] = -float("inf")
            else:
                Q[(i, j, k)] = 0

def printQ():
    for i in range(4):
        for j in range(4):
            print("{}_{}: ".format(i,j), Q[(i,j)])

def getAction(s):
    action = np.argmax(softmax(Q[s]))
    return action

def train():
    alpha = 0.1
    gamma = 0.95
    # 90%概率
    useQ = 0.9
    for i in range(100):
        s = env.reset()
        while True:
            env.render()
            # 根据状态获取s, 选择一个动作
            can_actions = get_actions(s[0], s[1])
            action = getAction(s) if np.random.uniform() < useQ else np.random.choice(can_actions)
            assert action in can_actions
            nextPos, reward, done, isH =  env.step(action)
            if done: # game over 没有下一个状态
                Q[s][action] +=  alpha * (reward - Q[s][action])
                break
            else:
                Q[s][action] +=  alpha * (reward + gamma * max(Q[nextPos]) - Q[s][action])
            s = nextPos
def play():
    s = env.reset()
    env.render()
    while True:
        # 根据状态获取s, 选择一个动作
        action = getAction(s)
        print('执行了动作:', ACTIONS_STR[action])
        nextPos, reward, done, _ =  env.step(action)
        s = nextPos
        env.render()
        if done:
            print(reward)
            break

train()
printQ()
play()

展望

写这个花费了很久,第一个原因是Q表创建错误, 第二个是中间非常容易死循环。
写这个需要考虑到底需要迭代多少次合适,以及奖励应该怎么定合适,一定要有概率不按Q表选择动作, 因为容易出现死循环。训练步骤不能太少,Q表信息不够,也是容易出现死循环。
这个例子环境是固定的,环境变化,必须重新训练
奖励函数现在是 -step - 曼哈顿距离, 也就是说步骤越少以及距离越小,函数值越大
对无效动作给了-float("inf"), 充当动作的MASK, 使用softmax去映射,会得到0

有关强化学习之路一 QLearning 算法的更多相关文章

  1. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  2. 100个python算法超详细讲解:画直线 - 2

    1.问题描述使用Python的turtle(海龟绘图)模块提供的函数绘制直线。2.问题分析一幅复杂的图形通常都可以由点、直线、三角形、矩形、平行四边形、圆、椭圆和圆弧等基本图形组成。其中的三角形、矩形、平行四边形又可以由直线组成,而直线又是由两个点确定的。我们使用Python的turtle模块所提供的函数来绘制直线。在使用之前我们先介绍一下turtle模块的相关知识点。turtle模块提供面向对象和面向过程两种形式的海龟绘图基本组件。面向对象的接口类如下:1)TurtleScreen类:定义图形窗口作为绘图海龟的运动场。它的构造器需要一个tkinter.Canvas或ScrolledCanva

  3. ruby - 在 Ruby 中实现 Luhn 算法 - 2

    我一直在尝试用Ruby实现Luhn算法。我一直在执行以下步骤:该公式根据其包含的校验位验证数字,该校验位通常附加到部分帐号以生成完整帐号。此帐号必须通过以下测试:从最右边的校验位开始向左移动,每第二个数字的值加倍。将乘积的数字(例如,10=1+0=1、14=1+4=5)与原始数字的未加倍数字相加。如果总模10等于0(如果总和以零结尾),则根据Luhn公式该数字有效;否则无效。http://en.wikipedia.org/wiki/Luhn_algorithm这是我想出的:defvalidCreditCard(cardNumber)sum=0nums=cardNumber.to_s.s

  4. Ruby 斐波那契算法 - 2

    下面是我写的一个计算斐波那契数列中的值的方法:deffib(n)ifn==0return0endifn==1return1endifn>=2returnfib(n-1)+(fib(n-2))endend它工作到n=14,但在那之后我收到一条消息说程序响应时间太长(我正在使用repl.it)。有人知道为什么会这样吗? 最佳答案 Naivefibonacci进行了大量的重复计算-在fib(14)fib(4)中计算了很多次。您可以将内存添加到您的算法中以使其更快:deffib(n,memo={})ifn==0||n==1returnnen

  5. ruby-on-rails - Rails add_index 算法 : :concurrently still causes database lock up during migration - 2

    为了防止在迁移到生产站点期间出现数据库事务错误,我们遵循了https://github.com/LendingHome/zero_downtime_migrations中列出的建议。(具体由https://robots.thoughtbot.com/how-to-create-postgres-indexes-concurrently-in概述),但在特别大的表上创建索引期间,即使是索引创建的“并发”方法也会锁定表并导致该表上的任何ActiveRecord创建或更新导致各自的事务失败有PG::InFailedSqlTransaction异常。下面是我们运行Rails4.2(使用Acti

  6. ruby - 趋势算法 - 2

    我正在开发一个类似微论坛的项目,其中一个特殊用户发布一条快速(接近推文大小)的主题消息,订阅者可以用他们自己的类似大小的消息来响应。直截了当,没有任何形式的“挖掘”或投票,只是每个主题消息的响应按时间顺序排列。但预计会有很高的流量。我们想根据它们引起的响应嗡嗡声来标记主题消息,使用0到10的等级。在谷歌上搜索了一段时间的趋势算法和开源社区应用示例,到目前为止已经收集到两个有趣的引用资料,但我还没有完全理解它们:Understandingalgorithmsformeasuringtrends,关于使用基线趋势算法比较维基百科页面浏览量的讨论,在SO上。TheBritneySpearsP

  7. Ruby - 不支持的密码算法 (AES-256-GCM) - 2

    我收到错误:unsupportedcipheralgorithm(AES-256-GCM)(RuntimeError)但我似乎具备所有要求:ruby版本:$ruby--versionruby2.1.2p95OpenSSL会列出gcm:$opensslenc-help2>&1|grepgcm-aes-128-ecb-aes-128-gcm-aes-128-ofb-aes-192-ecb-aes-192-gcm-aes-192-ofb-aes-256-ecb-aes-256-gcm-aes-256-ofbRuby解释器:$irb2.1.2:001>require'openssl';puts

  8. java实现Dijkstra算法 - 2

    文章目录一.Dijkstra算法想解决的问题二.Dijkstra算法理论三.java代码实现一.Dijkstra算法想解决的问题解决的问题:求解单源最短路径,即各个节点到达源点的最短路径或权值考察其他所有节点到源点的最短路径和长度局限性:无法解决权值为负数的情况二.Dijkstra算法理论参数:S记录当前已经处理过的源点到最短节点U记录还未处理的节点dist[]记录各个节点到起始节点的最短权值path[]记录各个节点的上一级节点(用来联系该节点到起始节点的路径)Dijkstra算法步骤:(1)初始化:顶点集S:节点A到自已的最短路径长度为0。只包含源点,即S={A}顶点集U:包含除A外的其他顶

  9. 对于体育新闻中文文本关键字提取有哪些关键字提取算法及其步骤 - 2

    对于体育新闻中文文本的关键字提取,常用的算法包括TF-IDF、TextRank和LDA等。它们的基本步骤如下:1.TF-IDF算法: -将文本进行分词和词性标注处理。-统计每个词在文本中的词频(TF)。-计算每个词在整个语料库中出现的文档频率(DF)和逆文档频率(IDF)。-计算每个词的TF-IDF值,并按照值的大小进行排序,选择排名前几的词作为关键字。2.TextRank算法:-将文本进行分词和词性标注处理。-将分词结果转化成图模型,每个词语为节点,根据词语之间的共现关系建立边。-对图模型进行迭代计算,计算每个节点的PageRank值,表示该节点的重要性。-选择排名前几的节点作为关键字。3.

  10. arrays - ruby 中的最佳排列计数算法 - 2

    我正在尝试计算由二进制形式的1和0的P数表示的数字的数量。如果P=2,则表示的数字为0011、1100、0110、0101、1001、1010,所以计数为6。我试过:[0,0,1,1].permutation.to_a.uniq但这不是大数的最佳解决方案(P可以什么可能是最好的排列技术,或者我们是否有任何直接的数学来做到这一点? 最佳答案 Numberofpermutationcanbecalculatedusingfactorial.a=[0,0,1,1](1..a.size).inject(:*)#=>4!=>24要计算重复项,

随机推荐