草庐IT

Pytorch中的grid_sample算子功能解析

追猫人 2023-10-04 原文

         pytorch中的grid_sample是一种特殊的采样算法。

调用接口为:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。

         input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。

         grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。

         mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。

         padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。

         align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。

         要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。

         基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。

import torch
import numpy as np
def grid_sample(input, grid):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    output = np.random.random((N,C,H,W))
    for i in range(N):
        for j in range(C):
            for k in range(H_out):
                for l in range(W_out):
                    param = [0.0, 0.0]
                    param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2
                    param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2
                    x0 = int(param[0])
                    x1 = x0 + 1
                    y0 = int(param[1])
                    y1 = y0 + 1
                    param[0] -= x0
                    param[1] -= y0
                    left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1])
                    left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1]
                    right_top = input[i][j][y0][x1] * param[0] * (1 - param[1])
                    right_bottom = input[i][j][y1][x1] * param[0] * param[1]
                    result = left_bottom + left_top + right_bottom + right_top
                    output[i][j][k][l] = result
    return output

N, C, H, W = 1, 1, 4, 4
input = np.random.random((N,C,H,W))
grid = np.random.random((N,H,W,2))
out = grid_sample(input, grid)
print(f'自定义实现输出结果:\n{out}')
input = torch.from_numpy(input)
grid = torch.from_numpy(grid)
output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)
print(f'grid_sample输出结果:\n{output}')

运行结果:

         从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。

        考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置'zero',也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torch
import numpy as np


def grid_sample(input, grid):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    output = np.random.random((N, C, H_out, W_out))
    for i in range(N):
        for j in range(C):
            for k in range(H_out):
                for l in range(W_out):
                    x, y = grid[i][k][l][0], grid[i][k][l][1]
                    param = [0.0, 0.0]
                    param[0] = (W_in - 1) * (x + 1) / 2
                    param[1] = (H_in - 1) * (y + 1) / 2
                    x1 = int(param[0] + 1)
                    x0 = x1 - 1
                    y1 = int(param[1] + 1)
                    y0 = y1 - 1
                    param[0] = abs(param[0] - x0)
                    param[1] = abs(param[1] - y0)
                    left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0
                    if 0 <= x0 < W_in and 0 <= y0 < H_in:
                        left_top_value = input[i][j][y0][x0]
                    if 0 <= x1 < W_in and 0 <= y0 < H_in:
                        right_top_value = input[i][j][y0][x1]
                    if 0 <= x0 < W_in and 0 <= y1 < H_in:
                        left_bottom_value = input[i][j][y1][x0]
                    if 0 <= x1 < W_in and 0 <= y1 < H_in:
                        right_bottom_value = input[i][j][y1][x1]
                    left_top = left_top_value * (1 - param[0]) * (1 - param[1])
                    left_bottom = left_bottom_value * (1 - param[0]) * param[1]
                    right_top = right_top_value * param[0] * (1 - param[1])
                    right_bottom = right_bottom_value * param[0] * param[1]
                    result = left_bottom + left_top + right_bottom + right_top
                    output[i][j][k][l] = result
    return output


N, C, H_in, W_in = 1, 1, 4, 4
H_out, W_out = 4, 4
input = np.random.random((N, C, H_in, W_in))
grid = np.random.random((N, H_out, W_out, 2))
grid[0][0][0] = [-1.2, 1.3]
out = grid_sample(input, grid)
print(f'自定义实现输出结果:\n{out}')
input = torch.from_numpy(input)
grid = torch.from_numpy(grid)
output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
print(f'grid_sample输出结果:\n{output}')

     测试结果:

   

有关Pytorch中的grid_sample算子功能解析的更多相关文章

  1. ruby - 如何从 ruby​​ 中的字符串运行任意对象方法? - 2

    总的来说,我对ruby​​还比较陌生,我正在为我正在创建的对象编写一些rspec测试用例。许多测试用例都非常基础,我只是想确保正确填充和返回值。我想知道是否有办法使用循环结构来执行此操作。不必为我要测试的每个方法都设置一个assertEquals。例如:describeitem,"TestingtheItem"doit"willhaveanullvaluetostart"doitem=Item.new#HereIcoulddotheitem.name.shouldbe_nil#thenIcoulddoitem.category.shouldbe_nilendend但我想要一些方法来使用

  2. Ruby 解析字符串 - 2

    我有一个字符串input="maybe(thisis|thatwas)some((nice|ugly)(day|night)|(strange(weather|time)))"Ruby中解析该字符串的最佳方法是什么?我的意思是脚本应该能够像这样构建句子:maybethisissomeuglynightmaybethatwassomenicenightmaybethiswassomestrangetime等等,你明白了......我应该一个字符一个字符地读取字符串并构建一个带有堆栈的状态机来存储括号值以供以后计算,还是有更好的方法?也许为此目的准备了一个开箱即用的库?

  3. ruby - 其他文件中的 Rake 任务 - 2

    我试图在一个项目中使用rake,如果我把所有东西都放到Rakefile中,它会很大并且很难读取/找到东西,所以我试着将每个命名空间放在lib/rake中它自己的文件中,我添加了这个到我的rake文件的顶部:Dir['#{File.dirname(__FILE__)}/lib/rake/*.rake'].map{|f|requiref}它加载文件没问题,但没有任务。我现在只有一个.rake文件作为测试,名为“servers.rake”,它看起来像这样:namespace:serverdotask:testdoputs"test"endend所以当我运行rakeserver:testid时

  4. ruby-on-rails - Ruby net/ldap 模块中的内存泄漏 - 2

    作为我的Rails应用程序的一部分,我编写了一个小导入程序,它从我们的LDAP系统中吸取数据并将其塞入一个用户表中。不幸的是,与LDAP相关的代码在遍历我们的32K用户时泄漏了大量内存,我一直无法弄清楚如何解决这个问题。这个问题似乎在某种程度上与LDAP库有关,因为当我删除对LDAP内容的调用时,内存使用情况会很好地稳定下来。此外,不断增加的对象是Net::BER::BerIdentifiedString和Net::BER::BerIdentifiedArray,它们都是LDAP库的一部分。当我运行导入时,内存使用量最终达到超过1GB的峰值。如果问题存在,我需要找到一些方法来更正我的代

  5. ruby-on-rails - Rails 3 中的多个路由文件 - 2

    Rails2.3可以选择随时使用RouteSet#add_configuration_file添加更多路由。是否可以在Rails3项目中做同样的事情? 最佳答案 在config/application.rb中:config.paths.config.routes在Rails3.2(也可能是Rails3.1)中,使用:config.paths["config/routes"] 关于ruby-on-rails-Rails3中的多个路由文件,我们在StackOverflow上找到一个类似的问题

  6. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  7. ruby-on-rails - Rails - 一个 View 中的多个模型 - 2

    我需要从一个View访问多个模型。以前,我的links_controller仅用于提供以不同方式排序的链接资源。现在我想包括一个部分(我假设)显示按分数排序的顶级用户(@users=User.all.sort_by(&:score))我知道我可以将此代码插入每个链接操作并从View访问它,但这似乎不是“ruby方式”,我将需要在不久的将来访问更多模型。这可能会变得很脏,是否有针对这种情况的任何技术?注意事项:我认为我的应用程序正朝着单一格式和动态页面内容的方向发展,本质上是一个典型的网络应用程序。我知道before_filter但考虑到我希望应用程序进入的方向,这似乎很麻烦。最终从任何

  8. ruby - 用逗号、双引号和编码解析 csv - 2

    我正在使用ruby​​1.9解析以下带有MacRoman字符的csv文件#encoding:ISO-8859-1#csv_parse.csvName,main-dialogue"Marceu","Giveittohimóhe,hiswife."我做了以下解析。require'csv'input_string=File.read("../csv_parse.rb").force_encoding("ISO-8859-1").encode("UTF-8")#=>"Name,main-dialogue\r\n\"Marceu\",\"Giveittohim\x97he,hiswife.\"\

  9. ruby-on-rails - Rails 3.2.1 中 ActionMailer 中的未定义方法 'default_content_type=' - 2

    我在我的项目中添加了一个系统来重置用户密码并通过电子邮件将密码发送给他,以防他忘记密码。昨天它运行良好(当我实现它时)。当我今天尝试启动服务器时,出现以下错误。=>BootingWEBrick=>Rails3.2.1applicationstartingindevelopmentonhttp://0.0.0.0:3000=>Callwith-dtodetach=>Ctrl-CtoshutdownserverExiting/Users/vinayshenoy/.rvm/gems/ruby-1.9.3-p0/gems/actionmailer-3.2.1/lib/action_mailer

  10. ruby-on-rails - Rails 应用程序中的 Rails : How are you using application_controller. rb 是新手吗? - 2

    刚入门rails,开始慢慢理解。有人可以解释或给我一些关于在application_controller中编码的好处或时间和原因的想法吗?有哪些用例。您如何为Rails应用程序使用应用程序Controller?我不想在那里放太多代码,因为据我了解,每个请求都会调用此Controller。这是真的? 最佳答案 ApplicationController实际上是您应用程序中的每个其他Controller都将从中继承的类(尽管这不是强制性的)。我同意不要用太多代码弄乱它并保持干净整洁的态度,尽管在某些情况下ApplicationContr

随机推荐