草庐IT

pytorch 笔记:torch.distributions 概率分布相关(更新中)

UQI-LIUWJ 2023-04-19 原文

1 包介绍

        torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

        不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

  • 评分函数估计量 score function estimato
  • 似然比估计量 likelihood ratio estimator
  • REINFORCE
  • 路径导数估计量 pathwise derivative estimator

REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

        虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。、

1.1 REINFORCE

        我们以reinforce 为例:

        当概率密度函数关于其参数可微时,我们只需要 sample() 和 log_prob() 来实现 REINFORCE:

        

        其中θ是参数,α是学习率,r是奖励,是在状态s的时候,根据策略使用动作a的概率

        (这个也就是policy gradient)

强化学习笔记:Policy-based Approach_UQI-LIUWJ的博客-CSDN博客

         在实践中,我们会从网络的输出中采样一个动作,在一个环境中应用这个动作,然后使用 log_prob 构造一个等效的损失函数。

         对于分类策略,实现 REINFORCE 的代码如下:(这只是一个示意代码,跑不起来的)

probs = policy_network(state)
#在状态state的时候,各个action的概率

m = Categorical(probs)
#分类概率

action = m.sample()
#采样一个action

next_state, reward = env.step(action)
#这里为了简化考虑,一个episode只有一个action

loss = -m.log_prob(action) * reward
#m.log_prob(action) 就是 logp
#reward就是前面的r
#这里用负号是因为强化学习是梯度上升

loss.backward()

  2 包所涉及的类

2.1 伯努利分布

torch.distributions.bernoulli.Bernoulli(
    probs=None, 
    logits=None, 
    validate_args=None)

        创建由 probs 或 logits(但不是两者同时)参数化的伯努利分布。

        样本是二进制的(0 或 1)。 它们取值 1 的概率为 p,取值 0 的概率为 1 - p。

2.1.1 参数

probs (Number,Tensor采样概率
logits (Number,Tensor采样的对数几率

2.1.2 函数 & 属性

sample()

采样,默认采样一个值

还可以按照shape 采样

entropy()

计算熵

enumerate_support()

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean

均值

probs, logits两个输入的参数
param_shape

参数的形状

variance

方差

2.2 贝塔分布

torch.distributions.beta.Beta(
    concentration1, 
    concentration0, 
    validate_args=None)

由concentration 1 (α)和concentration 0 (β)参数化的 Beta 分布。

 2.2.1 函数

采样

默认是采样一个值,也可以设置采样的维数

entropy

计算熵


rsample(sample_shape)

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

注:生成Beta分布的时候,两个参数必须至少有一个是Tensor,否则rsample效果失效

mean,variance

均值 & 方差

 2.3 Chi2 分布

torch.distributions.chi2.Chi2(
    df, 
    validate_args=None)

 它只有sample一个函数 

2.4 连续伯努利

参数和伯努利很类似

torch.distributions.continuous_bernoulli.ContinuousBernoulli(
    probs=None, 
    logits=None, 
    lims=(0.499, 0.501), 
    validate_args=None)

请注意,与伯努利不同,这里的“probs”不对应于伯努利的“probs”,这里的“logits”不对应于伯努利的“logits”,但由于与伯努利的相似性,使用了相同的名称。 

2.4.1 函数

sample还是采样
cdf

返回以 value 计算的累积概率密度函数。

icdf

返回以 value 计算的逆累积密度/质量函数。

entropy

还是计算熵

rsample

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

和前面Beta分布类似,只有创建时参数为Tensor,才会有rsample效果

mean,variance均值 方差

 2.5 二项分布

torch.distributions.binomial.Binomial(
    total_count=1, 
    probs=None, 
    logits=None, 
    validate_args=None)

 

         创建由 total_count 和 probs 或 logits(但不是两者)参数化的二项分布。 total_count 必须可以用 probs/logits 广播。

2.5.1 函数&参数

sample

采样

 

100被广播到0,0.2,0.8,1 所以每次相当于是四个二项分布

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean,variance

均值,方差

2.6  分类分布

torch.distributions.categorical.Categorical(
    probs=None, 
    logits=None, 
    validate_args=None)

 样本是来{0,...,K−1} 的整数,其中 K 是 probs.size(-1)。

2.6.1 函数

sample采样

entropy

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

2.6.2 注意:

创建分类分布时候的Tensor中元素的和可以不是1,最后归一化到1即可

import torch
import math
m=torch.distributions.Categorical(torch.Tensor([1,2,4]))
m.enumerate_support()
#tensor([0, 1, 2])

m.probs
#tensor([0.1429, 0.2857, 0.5714])

3 log_probs

很多分类都有这样一个函数log_probs,我们就统一说一下

假设m是一个torch的分类,那么m.log_prob(action)相当于

probs.log()[0][action.item()].unsqueeze(0)

(对这个action的概率添加log操作) 

import torch
import math
m=torch.distributions.Categorical(torch.Tensor([1,2,4]))
m.enumerate_support()
#tensor([0, 1, 2])

a=m.sample()
a
#tensor(2)

m.probs
#tensor([0.1429, 0.2857, 0.5714])

m.probs.log()
#tensor([-1.9459, -1.2528, -0.5596])

m.log_prob(a)
#tensor(-0.5596)

m.probs.log()[a.item()]
#tensor(-0.5596)

有关pytorch 笔记:torch.distributions 概率分布相关(更新中)的更多相关文章

  1. ruby-on-rails - 如何验证 update_all 是否实际在 Rails 中更新 - 2

    给定这段代码defcreate@upgrades=User.update_all(["role=?","upgraded"],:id=>params[:upgrade])redirect_toadmin_upgrades_path,:notice=>"Successfullyupgradeduser."end我如何在该操作中实际验证它们是否已保存或未重定向到适当的页面和消息? 最佳答案 在Rails3中,update_all不返回任何有意义的信息,除了已更新的记录数(这可能取决于您的DBMS是否返回该信息)。http://ar.ru

  2. ruby-on-rails - 使用 rails 4 设计而不更新用户 - 2

    我将应用程序升级到Rails4,一切正常。我可以登录并转到我的编辑页面。也更新了观点。使用标准View时,用户会更新。但是当我添加例如字段:name时,它​​不会在表单中更新。使用devise3.1.1和gem'protected_attributes'我需要在设备或数据库上运行某种更新命令吗?我也搜索过这个地方,找到了许多不同的解决方案,但没有一个会更新我的用户字段。我没有添加任何自定义字段。 最佳答案 如果您想允许额外的参数,您可以在ApplicationController中使用beforefilter,因为Rails4将参数

  3. ruby-on-rails - 相关表上的范围为 "WHERE ... LIKE" - 2

    我正在尝试从Postgresql表(table1)中获取数据,该表由另一个相关表(property)的字段(table2)过滤。在纯SQL中,我会这样编写查询:SELECT*FROMtable1JOINtable2USING(table2_id)WHEREtable2.propertyLIKE'query%'这工作正常:scope:my_scope,->(query){includes(:table2).where("table2.property":query)}但我真正需要的是使用LIKE运算符进行过滤,而不是严格相等。然而,这是行不通的:scope:my_scope,->(que

  4. ruby - 分布式事务和队列,ruby,erlang,scala - 2

    我有一个涉及多台机器、消息队列和事务的问题。因此,例如用户点击网页,点击将消息发送到另一台机器,该机器将付款添加到用户的帐户。每秒可能有数千次点击。事务的所有方面都应该是容错的。我以前从未遇到过这样的事情,但一些阅读表明这是一个众所周知的问题。所以我的问题。我假设安全的方法是使用两阶段提交,但协议(protocol)是阻塞的,所以我不会获得所需的性能,我是否正确?我通常写Ruby,但似乎Redis之类的数据库和Rescue、RabbitMQ等消息队列系统对我的帮助不大——即使我实现某种两阶段提交,如果Redis崩溃,数据也会丢失,因为它本质上只是内存。所有这些让我开始关注erlang和

  5. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  6. objective-c - 在设置 Cocoa Pods 和安装 Ruby 更新时出错 - 2

    我正在尝试为我的iOS应用程序设置cocoapods但是当我执行命令时:sudogemupdate--system我收到错误消息:当前已安装最新版本。中止。当我进入cocoapods的下一步时:sudogeminstallcocoapods我在MacOS10.8.5上遇到错误:ERROR:Errorinstallingcocoapods:cocoapods-trunkrequiresRubyversion>=2.0.0.我在MacOS10.9.4上尝试了同样的操作,但出现错误:ERROR:Couldnotfindavalidgem'cocoapods'(>=0),hereiswhy:U

  7. ruby-on-rails - Rails Associations 的更新方法是什么? - 2

    这太简单了,太荒谬了,我在任何地方都找不到关于它的任何信息,包括API文档和Rails源代码:我有一个:belongs_to关联,我开始理解当您没有关联时您在Controller中调用的正常模型方法与您有关联时调用的方法略有不同。例如,我的关联在创建Controller操作时运行良好:@user=current_user@building=Building.new(params[:building])respond_todo|format|if@user.buildings.create(params[:building])#etcetera但我找不到关于更新如何工作的文档:@user

  8. ruby-on-rails - 在具有 ActiveRecord 条件的相关模型中按字段排序 - 2

    我正在尝试按Rails相关模型中的字段进行排序。我研究的所有解决方案都没有解决如果相关模型被另一个参数过滤?元素模型classItem相关模型:classPriority我正在使用where子句检索项目:@items=Item.where('company_id=?andapproved=?',@company.id,true).all我需要按相关表格中的“位置”列进行排序。问题在于,在优先级模型中,一个项目可能会被多家公司列出。因此,这些职位取决于他们拥有的company_id。当我显示项目时,它是针对一个公司的,按公司内的职位排序。完成此任务的正确方法是什么?感谢您的帮助。PS-我

  9. ruby-on-rails - OSX Yosemite 更新破坏了 pow.cx - 2

    升级到OSXYosemite后,我现有的pow.cx安装不起作用。升级到最新的pow.cx无效。通过事件监视器重新启动它也没有成功。 最佳答案 卸载(!)并重新安装解决了这个问题。curlget.pow.cx/uninstall.sh|shcurlget.pow.cx|sh 关于ruby-on-rails-OSXYosemite更新破坏了pow.cx,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/q

  10. ruby - 将 Gitlab 从 9.3.7 更新到 9.3.8 安装 re2 时出错 - 2

    我们在Ubuntu14.04和Gitlab9.3.7上运行,运行良好。我们正在尝试更新到Gitlabv9.3.8的最新安全补丁,但它给我们这个错误:Gem::Ext::BuildError:ERROR:Failedtobuildgemnativeextension.currentdirectory:/home/git/gitlab/vendor/bundle/ruby/2.3.0/gems/re2-1.0.0/ext/re2/usr/local/bin/ruby-r./siteconf20170720-19622-15i0edf.rbextconf.rbcheckingformain(

随机推荐