草庐IT

PyTorch 打印模型结构、输出维度和参数信息(torchsummary)

梁小憨憨 2023-05-30 原文

使用 PyTorch 深度学习搭建模型后,如果想查看模型结构,可以直接使用 print(model) 函数打印。但该输出结果不是特别直观,查阅发现有个能输出类似 keras 风格 model.summary() 的模型可视化工具。这里记录一下方便以后查阅。

PyTorch 打印模型结构、输出维度和参数信息(torchsummary)

安装 torchsummary

pip install torchsummary

输出网络信息

summary函数介绍

model:网络模型
input_size:网络输入图片的shape,这里不用加batch_size进去
batch_size:batch_size参数,默认是-1
device:在GPU还是CPU上运行,默认是cuda在GPU上运行,如果想在CPU上执行将参数改为CPU即可

import torch
import torch.nn as nn
from torchsummary import summary



class Shallow_ConvNet(nn.Module):
    def __init__(self, in_channel, conv_channel_temp, kernel_size_temp, conv_channel_spat, kernel_size_spat,
                              pooling_size, pool_stride_size, dropoutRate, n_classes, class_kernel_size) :
        super(Shallow_ConvNet, self).__init__()

        self.temp_conv = nn.Conv2d(in_channels=in_channel,
                                                                    out_channels=conv_channel_temp,
                                                                    kernel_size=(1, kernel_size_temp),
                                                                    stride=1,
                                                                    bias=False)

        self.spat_conv = nn.Conv2d(in_channels=conv_channel_temp,
                                                                  out_channels=conv_channel_spat,
                                                                  kernel_size=(kernel_size_spat, 1),
                                                                  stride=1,
                                                                  bias=False)

        self.bn = nn.BatchNorm2d(num_features=conv_channel_spat)

        # slef.act_conv = x*x

        self.pooling = nn.AvgPool2d(kernel_size=(1, pooling_size),
                                                                   stride=(1, pool_stride_size))

        # slef.act_pool = log(max(x, eps))

        self.dropout = nn.Dropout(p=dropoutRate)

        self.class_conv = nn.Conv2d(in_channels=conv_channel_spat,
                                                                    out_channels=n_classes,
                                                                    kernel_size=(1, class_kernel_size),
                                                                    bias=False)

        self.softmax = nn.Softmax(dim=1)

    def safe_log(self, x):
        """ Prevents :math:`log(0)` by using :math:`log(max(x, eps))`."""
        return torch.log(torch.clamp(x, min=1e-6))
    
    def forward(self, x):
        # input shape (batch_size, C, T)
        if len(x.shape) is not 4:
            x = torch.unsqueeze(x, 1)
        # input shape (batch_size, 1, C, T)
        x = self.temp_conv(x)
        x = self.spat_conv(x)
        x = self.bn(x)
        x = x*x # conv_activate
        x = self.pooling(x)
        x = self.safe_log(x) # pool_activate
        x = self.dropout(x)
        x = self.class_conv(x)
        x= self.softmax(x)
        out = torch.squeeze(x)

        return out


###============================ Initialization parameters ============================###
channels = 44
samples = 534

in_channel = 1
conv_channel_temp = 40
kernel_size_temp = 25
conv_channel_spat = 40
kernel_size_spat = channels
pooling_size = 75
pool_stride_size = 15
dropoutRate = 0.3
n_classes = 4
class_kernel_size = 30

def main():
    input = torch.randn(32, 1, channels, samples)
    model = Shallow_ConvNet(in_channel, conv_channel_temp, kernel_size_temp, conv_channel_spat, kernel_size_spat,
                                                            pooling_size, pool_stride_size, dropoutRate, n_classes, class_kernel_size)
    out = model(input)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    summary(model=model, input_size=(1,channels,samples), batch_size=32, device="cpu")

if __name__ == "__main__":
    main()

输出:

out torch.Size([32, 4])
model Shallow_ConvNet(
  (temp_conv): Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1), bias=False)
  (spat_conv): Conv2d(40, 40, kernel_size=(44, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pooling): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0)
  (dropout): Dropout(p=0.3, inplace=False)
  (class_conv): Conv2d(40, 4, kernel_size=(1, 30), stride=(1, 1), bias=False)
  (softmax): Softmax(dim=1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [32, 40, 44, 510]           1,000
            Conv2d-2           [32, 40, 1, 510]          70,400
       BatchNorm2d-3           [32, 40, 1, 510]              80
         AvgPool2d-4            [32, 40, 1, 30]               0
           Dropout-5            [32, 40, 1, 30]               0
            Conv2d-6              [32, 4, 1, 1]           4,800
           Softmax-7              [32, 4, 1, 1]               0
================================================================
Total params: 76,280
Trainable params: 76,280
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 2.87
Forward/backward pass size (MB): 229.69
Params size (MB): 0.29
Estimated Total Size (MB): 232.85
----------------------------------------------------------------

AttributeError: ‘tuple’ object has no attribute ‘size’

旧的summary加入LSTM之类的会报错,需要用新的summarry

pip install torchinfo
from torchinfo import summary

def main():
    input = torch.randn(32, window_size, channels, samples)
    model = Cascade_Conv_LSTM(in_channel, out_channel_conv1, out_channel_conv2, out_channel_conv3, kernel_conv123, stride_conv123, padding_conv123,
                                                                    fc1_in, fc1_out, dropoutRate1, lstm1_in, lstm1_hidden, lstm1_layer, lstm2_in, lstm2_hidden, lstm2_layer, fc2_in, fc2_out, dropoutRate2,
                                                                    fc3_in, n_classes)
    # model = model.to('cuda:1')
    # input = torch.from_numpy(input).to('cuda:1').to(torch.float32).requires_grad_()
    out = model(input)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    summary(model=model, input_size=(32,10,channels,samples), device="cpu")

if __name__ == "__main__":
    main()
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Cascade_Conv_LSTM                        [32, 4]                   --
├─Sequential: 1-1                        [320, 32, 10, 11]         --
│    └─Conv2d: 2-1                       [320, 32, 10, 11]         288
│    └─ELU: 2-2                          [320, 32, 10, 11]         --
├─Sequential: 1-2                        [320, 64, 10, 11]         --
│    └─Conv2d: 2-3                       [320, 64, 10, 11]         18,432
│    └─ELU: 2-4                          [320, 64, 10, 11]         --
├─Sequential: 1-3                        [320, 128, 10, 11]        --
│    └─Conv2d: 2-5                       [320, 128, 10, 11]        73,728
│    └─ELU: 2-6                          [320, 128, 10, 11]        --
├─Sequential: 1-4                        [320, 1024]               --
│    └─Linear: 2-7                       [320, 1024]               14,418,944
│    └─ELU: 2-8                          [320, 1024]               --
├─Dropout: 1-5                           [320, 1024]               --
├─LSTM: 1-6                              [32, 10, 1024]            8,396,800
├─LSTM: 1-7                              [32, 10, 1024]            8,396,800
├─Sequential: 1-8                        [32, 1024]                --
│    └─Linear: 2-9                       [32, 1024]                1,049,600
│    └─ELU: 2-10                         [32, 1024]                --
├─Dropout: 1-9                           [32, 1024]                --
├─Linear: 1-10                           [32, 4]                   4,100
├─Softmax: 1-11                          [32, 4]                   --
==========================================================================================
Total params: 32,358,692
Trainable params: 32,358,692
Non-trainable params: 0
Total mult-adds (G): 13.28
==========================================================================================
Input size (MB): 0.14
Forward/backward pass size (MB): 71.21
Params size (MB): 129.43
Estimated Total Size (MB): 200.78
==========================================================================================

有关PyTorch 打印模型结构、输出维度和参数信息(torchsummary)的更多相关文章

  1. ruby-on-rails - Rails - 子类化模型的设计模式是什么? - 2

    我有一个模型:classItem项目有一个属性“商店”基于存储的值,我希望Item对象对特定方法具有不同的行为。Rails中是否有针对此的通用设计模式?如果方法中没有大的if-else语句,这是如何干净利落地完成的? 最佳答案 通常通过Single-TableInheritance. 关于ruby-on-rails-Rails-子类化模型的设计模式是什么?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.co

  2. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

  3. ruby-on-rails - Rails 常用字符串(用于通知和错误信息等) - 2

    大约一年前,我决定确保每个包含非唯一文本的Flash通知都将从模块中的方法中获取文本。我这样做的最初原因是为了避免一遍又一遍地输入相同的字符串。如果我想更改措辞,我可以在一个地方轻松完成,而且一遍又一遍地重复同一件事而出现拼写错误的可能性也会降低。我最终得到的是这样的:moduleMessagesdefformat_error_messages(errors)errors.map{|attribute,message|"Error:#{attribute.to_s.titleize}#{message}."}enddeferror_message_could_not_find(obje

  4. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  5. ruby-on-rails - Rails - 一个 View 中的多个模型 - 2

    我需要从一个View访问多个模型。以前,我的links_controller仅用于提供以不同方式排序的链接资源。现在我想包括一个部分(我假设)显示按分数排序的顶级用户(@users=User.all.sort_by(&:score))我知道我可以将此代码插入每个链接操作并从View访问它,但这似乎不是“ruby方式”,我将需要在不久的将来访问更多模型。这可能会变得很脏,是否有针对这种情况的任何技术?注意事项:我认为我的应用程序正朝着单一格式和动态页面内容的方向发展,本质上是一个典型的网络应用程序。我知道before_filter但考虑到我希望应用程序进入的方向,这似乎很麻烦。最终从任何

  6. ruby - 检查 "command"的输出应该包含 NilClass 的意外崩溃 - 2

    为了将Cucumber用于命令行脚本,我按照提供的说明安装了arubagem。它在我的Gemfile中,我可以验证是否安装了正确的版本并且我已经包含了require'aruba/cucumber'在'features/env.rb'中为了确保它能正常工作,我写了以下场景:@announceScenario:Testingcucumber/arubaGivenablankslateThentheoutputfrom"ls-la"shouldcontain"drw"假设事情应该失败。它确实失败了,但失败的原因是错误的:@announceScenario:Testingcucumber/ar

  7. ruby-on-rails - 如何在 ruby​​ 中使用两个参数异步运行 exe? - 2

    exe应该在我打开页面时运行。异步进程需要运行。有什么方法可以在ruby​​中使用两个参数异步运行exe吗?我已经尝试过ruby​​命令-system()、exec()但它正在等待过程完成。我需要用参数启动exe,无需等待进程完成是否有任何ruby​​gems会支持我的问题? 最佳答案 您可以使用Process.spawn和Process.wait2:pid=Process.spawn'your.exe','--option'#Later...pid,status=Process.wait2pid您的程序将作为解释器的子进程执行。除

  8. ruby-on-rails - 在混合/模块中覆盖模型的属性访问器 - 2

    我有一个包含模块的模型。我想在模块中覆盖模型的访问器方法。例如:classBlah这显然行不通。有什么想法可以实现吗? 最佳答案 您的代码看起来是正确的。我们正在毫无困难地使用这个确切的模式。如果我没记错的话,Rails使用#method_missing作为属性setter,因此您的模块将优先,阻止ActiveRecord的setter。如果您正在使用ActiveSupport::Concern(参见thisblogpost),那么您的实例方法需要进入一个特殊的模块:classBlah

  9. ruby - 通过 erb 模板输出 ruby​​ 数组 - 2

    我正在使用puppet为ruby​​程序提供一组常量。我需要提供一组主机名,我的程序将对其进行迭代。在我之前使用的bash脚本中,我只是将它作为一个puppet变量hosts=>"host1,host2"我将其提供给bash脚本作为HOSTS=显然这对ruby​​不太适用——我需要它的格式hosts=["host1","host2"]自从phosts和putsmy_array.inspect提供输出["host1","host2"]我希望使用其中之一。不幸的是,我终其一生都无法弄清楚如何让它发挥作用。我尝试了以下各项:我发现某处他们指出我需要在函数调用前放置“function_”……这

  10. ruby - RSpec - 使用测试替身作为 block 参数 - 2

    我有一些Ruby代码,如下所示:Something.createdo|x|x.foo=barend我想编写一个测试,它使用double代替block参数x,这样我就可以调用:x_double.should_receive(:foo).with("whatever").这可能吗? 最佳答案 specify'something'dox=doublex.should_receive(:foo=).with("whatever")Something.should_receive(:create).and_yield(x)#callthere

随机推荐