草庐IT

使用MindSpore计算旋转矩阵

Dechin的博客 2023-03-28 原文

技术背景

坐标变换、旋转矩阵,是在线性空间常用的操作,在分子动力学模拟领域有非常广泛的应用。比如在一个体系中切换坐标,或者对整体分子进行旋转平移等。如果直接使用Numpy,是很容易可以实现的,只要把相关的旋转矩阵写成numpy.array的形式即可。但是在一些使用GPU计算的深度学习框架中,比如MindSpore框架,则是不能直接支持这样操作的。因此我们需要探索一下如何在MindSpore框架中实现一个简单的旋转矩阵,并使用旋转矩阵进行一些旋转操作。

Jax.numpy旋转矩阵

我们先介绍一下在常用的Numpy库中是如何实现一个旋转矩阵的,这里为了演示方便,简化编程工作量,我们选择用Jax中所集成的Numpy来进行试验和对比。这里我们计算的场景是,给定一个N原子的分子体系,其空间维度为D=3,我们通过给定三个欧拉角,来旋转整个分子系统。这就需要我们分别定义三个维度的旋转矩阵\(R_x(\phi),R_y(\psi),R_z(\theta)\),分别表示绕\(X\)轴旋转\(\phi\)的角度、绕\(Y\)轴旋转\(\psi\)的角度,以及绕\(Z\)轴旋转\(\theta\)的角度。如果使用Jax来进行编程,那我们得到的旋转矩阵应该是如下的形式:

def rotation(psi,phi,theta,v):
    """ Module of rotation in 3 Euler angles. """
    RY = np.array([[np.cos(psi),0,-np.sin(psi)],
                   [0, 1, 0],
                   [np.sin(psi),0,np.cos(psi)]])
    RX = np.array([[1,0,0],
                   [0,np.cos(phi),-np.sin(phi)],
                   [0,np.sin(phi),np.cos(phi)]])
    RZ = np.array([[np.cos(theta),-np.sin(theta),0],
                   [np.sin(theta),np.cos(theta),0],
                   [0,0,1]])
    return np.dot(RZ,np.dot(RX,np.dot(RY,v)))

multi_rotation = jit(vmap(rotation,(None,None,None,0)))

接下来使用这个旋转矩阵来展示一个具体的案例:

In [1]: from jax import numpy as np
In [2]: from jax import jit, vmap

In [3]: def rotation(psi,phi,theta,v):
   ...:     """ Module of rotation in 3 Euler angles. """
   ...:     RY = np.array([[np.cos(psi),0,-np.sin(psi)],
   ...:                    [0, 1, 0],
   ...:                    [np.sin(psi),0,np.cos(psi)]])
   ...:     RX = np.array([[1,0,0],
   ...:                    [0,np.cos(phi),-np.sin(phi)],
   ...:                    [0,np.sin(phi),np.cos(phi)]])
   ...:     RZ = np.array([[np.cos(theta),-np.sin(theta),0],
   ...:                    [np.sin(theta),np.cos(theta),0],
   ...:                    [0,0,1]])
   ...:     return np.dot(RZ,np.dot(RX,np.dot(RY,v)))
   ...: 

In [4]: multi_rotation = jit(vmap(rotation,(None,None,None,0)))

In [5]: import numpy as onp

In [6]: v=onp.random.random((3,3))

In [7]: v
Out[7]: 
array([[0.97911664, 0.48098486, 0.44966794],
       [0.25350689, 0.50949849, 0.77506796],
       [0.24502845, 0.23313826, 0.72014647]])

In [8]: multi_rotation(onp.pi, onp.pi, 0, v)
Out[8]: 
DeviceArray([[-0.97911656, -0.4809849 ,  0.449668  ],
             [-0.25350684, -0.50949854,  0.7750679 ],
             [-0.24502839, -0.23313832,  0.7201465 ]], dtype=float32)

在这个案例中,我们给定了绕X和Y轴分别旋转180度的操作,而对Z轴则保持相对静止。可想而知我们所得到的结果会使得X和Y的值分别取负号,而Z的值保持不变,上述的测试结果也表明这个计算过程是正确的。

MindSpore旋转矩阵

在MindSpore深度学习框架中,有一点不同于Numpy和Jax的是,MindSpore的Tensor中的元素不能包含有object。在上一个章节的案例中其实我们可以发现,旋转矩阵的元素中包含了一些正弦余弦函数的使用。假如我们使用MindSpore去计算正余弦函数值的话,得到的输出结果会是一个Tensor,而不是一个常数。比较尴尬的是,MindSpore的Tensor只能使用常数来初始化,这里矛盾点就出现了。那么我们只有两个途径可以解决这个问题:将输入的角度转化成普通numpy的格式,使用cpu上的numpy计算完成旋转矩阵之后,在输出的时候再转化为MindSpore的Tensor。而另一操作就是,先把所有的旋转矩阵的元素计算好之后,将这些元素concat起来变成一个一维的Tensor,再对该Tensor做一个reshape,就可以得到我们想要的旋转矩阵所对应的Tensor。在如下的示例中我们使用的是第二种方案:

In [1]: from mindspore import ops, Tensor

In [2]: import mindspore as ms

In [3]: import numpy as np

In [4]: psi = Tensor([np.pi], ms.float32)

In [5]: phi = Tensor([np.pi], ms.float32)

In [6]: theta = Tensor([0.], ms.float32)

In [7]: v = Tensor(np.random.random((3,3)), ms.float32)

In [8]: v
Out[8]: 
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 4.51581478e-01,  7.52180338e-01,  2.84639597e-01],
 [ 8.46439958e-01,  2.95659006e-01,  1.81022584e-01],
 [ 8.94563913e-01,  2.25287616e-01,  1.71754003e-01]])

In [9]: zero = Tensor([0.], ms.float32)

In [10]: one = Tensor([1.], ms.float32)

In [11]: def rotation(psi, phi, theta, v):
    ...:     RY = ops.Concat(-1)((ops.Cos()(psi), zero, -ops.Sin()(psi),
    ...:                          zero, one, zero,
    ...:                          ops.Sin()(psi), zero, ops.Cos()(psi)))
    ...:     RY = RY.reshape(3, 3)
    ...:     RX = ops.Concat(-1)((one, zero, zero,
    ...:                          zero, ops.Cos()(phi), -ops.Sin()(phi),
    ...:                          zero, ops.Sin()(phi), ops.Cos()(phi)))
    ...:     RX = RX.reshape(3, 3)
    ...:     RZ = ops.Concat(-1)((ops.Cos()(theta), -ops.Sin()(theta), zero,
    ...:                        ops.Sin()(theta), ops.Cos()(theta), zero,
    ...:                        zero, zero, one))
    ...:     RZ = RZ.reshape(3, 3)
    ...:     dot = ops.Einsum('ij,kj->ki')
    ...:     return dot((RZ, dot((RX, dot((RY, v))))))
    ...: 

In [12]: rotation(psi, phi, theta, v)
Out[12]: 
Tensor(shape=[3, 3], dtype=Float32, value=
[[-4.51581448e-01, -7.52180338e-01,  2.84639567e-01],
 [-8.46439958e-01, -2.95659035e-01,  1.81022629e-01],
 [-8.94563913e-01, -2.25287631e-01,  1.71754062e-01]])

从这个计算结果中,我们可以看到跟Jax的案例一样,也是得到了X和Y值分别取负数的结果,程序是正确运行的。但是这里关于案例代码,需要一些额外的解释:

  1. 在上述案例中,我们先定义了一系列的一维Tensor来作为旋转矩阵的元素,使用MindSpore的Concat算子将这些一维Tensor的最后一维取出组成一个新的Tensor,再对其做reshape操作,得到一个我们所需要的旋转矩阵。
  2. 在Jax中我们是使用了vmap将旋转矩阵对单个矢量旋转的操作扩展到对多个矢量的旋转操作,而在MindSpore中虽然也支持了Vmap的算子,但是这里我们使用的是MindSpore所支持的另外一个功能:爱因斯坦求和算子。使用这个算子,我们就允许了旋转矩阵直接对多个矢量输入的指定维度进行运算,一样也可以得到我们想要的计算结果。

总结概要

本文介绍了两个不同的深度学习框架:Jax和MindSpore下的旋转矩阵的实现,对于不同的框架来说同一个功能会涉及到不同的实现方式。在Jax中,由于其函数式编程的特性,就允许我们更加简单的去构造和扩展一个旋转矩阵。MindSpore是一个面向对象编程的框架,其优势在于构建大型的模型应用。但构造一个可用的简单模型,相对而言就会走一些弯路。就比如我们需要使用Concat+Reshape的算子来拼接一个旋转矩阵,看起来会相对麻烦一些。而构建好旋转矩阵之后,则可以使用跟Jax一样的Vmap操作,或者是直接使用爱因斯坦求和来计算旋转矩阵对多个矢量输入的计算,从文章中的案例中可以看到两者所得到的计算结果是一致的。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/mindrot.html

作者ID:DechinPhy

更多原著文章请参考:https://www.cnblogs.com/dechinphy/

打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

CSDN同步链接:https://blog.csdn.net/baidu_37157624?spm=1008.2028.3001.5343

51CTO同步链接:https://blog.51cto.com/u_15561675

有关使用MindSpore计算旋转矩阵的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 使用 RubyZip 生成 ZIP 文件时设置压缩级别 - 2

    我有一个Ruby程序,它使用rubyzip压缩XML文件的目录树。gem。我的问题是文件开始变得很重,我想提高压缩级别,因为压缩时间不是问题。我在rubyzipdocumentation中找不到一种为创建的ZIP文件指定压缩级别的方法。有人知道如何更改此设置吗?是否有另一个允许指定压缩级别的Ruby库? 最佳答案 这是我通过查看ruby​​zip内部创建的代码。level=Zlib::BEST_COMPRESSIONZip::ZipOutputStream.open(zip_file)do|zip|Dir.glob("**/*")d

  3. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  4. ruby-on-rails - 使用 Ruby on Rails 进行自动化测试 - 最佳实践 - 2

    很好奇,就使用ruby​​onrails自动化单元测试而言,你们正在做什么?您是否创建了一个脚本来在cron中运行rake作业并将结果邮寄给您?git中的预提交Hook?只是手动调用?我完全理解测试,但想知道在错误发生之前捕获错误的最佳实践是什么。让我们理所当然地认为测试本身是完美无缺的,并且可以正常工作。下一步是什么以确保他们在正确的时间将可能有害的结果传达给您? 最佳答案 不确定您到底想听什么,但是有几个级别的自动代码库控制:在处理某项功能时,您可以使用类似autotest的内容获得关于哪些有效,哪些无效的即时反馈。要确保您的提

  5. ruby - 在 Ruby 中使用匿名模块 - 2

    假设我做了一个模块如下:m=Module.newdoclassCendend三个问题:除了对m的引用之外,还有什么方法可以访问C和m中的其他内容?我可以在创建匿名模块后为其命名吗(就像我输入“module...”一样)?如何在使用完匿名模块后将其删除,使其定义的常量不再存在? 最佳答案 三个答案:是的,使用ObjectSpace.此代码使c引用你的类(class)C不引用m:c=nilObjectSpace.each_object{|obj|c=objif(Class===objandobj.name=~/::C$/)}当然这取决于

  6. ruby - 使用 ruby​​ 和 savon 的 SOAP 服务 - 2

    我正在尝试使用ruby​​和Savon来使用网络服务。测试服务为http://www.webservicex.net/WS/WSDetails.aspx?WSID=9&CATID=2require'rubygems'require'savon'client=Savon::Client.new"http://www.webservicex.net/stockquote.asmx?WSDL"client.get_quotedo|soap|soap.body={:symbol=>"AAPL"}end返回SOAP异常。检查soap信封,在我看来soap请求没有正确的命名空间。任何人都可以建议我

  7. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  8. ruby-on-rails - 'compass watch' 是如何工作的/它是如何与 rails 一起使用的 - 2

    我在我的项目目录中完成了compasscreate.和compassinitrails。几个问题:我已将我的.sass文件放在public/stylesheets中。这是放置它们的正确位置吗?当我运行compasswatch时,它不会自动编译这些.sass文件。我必须手动指定文件:compasswatchpublic/stylesheets/myfile.sass等。如何让它自动运行?文件ie.css、print.css和screen.css已放在stylesheets/compiled。如何在编译后不让它们重新出现的情况下删除它们?我自己编译的.sass文件编译成compiled/t

  9. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

  10. ruby - 在 64 位 Snow Leopard 上使用 rvm、postgres 9.0、ruby 1.9.2-p136 安装 pg gem 时出现问题 - 2

    我想为Heroku构建一个Rails3应用程序。他们使用Postgres作为他们的数据库,所以我通过MacPorts安装了postgres9.0。现在我需要一个postgresgem并且共识是出于性能原因你想要pggem。但是我对我得到的错误感到非常困惑当我尝试在rvm下通过geminstall安装pg时。我已经非常明确地指定了所有postgres目录的位置可以找到但仍然无法完成安装:$envARCHFLAGS='-archx86_64'geminstallpg--\--with-pg-config=/opt/local/var/db/postgresql90/defaultdb/po

随机推荐