草庐IT

解决ValueError: Expected input batch_size (40) to match target batch_size (8).

翰墨大人 2023-04-21 原文

已解决!!!有bug不要放弃一定要细心追根溯源,花点时间很正常的。

1:bug出现的地方

根据报错的信息,我们可以定位在损失函数losses = loss_function_train(pred_scales, target_scales),还有在损失函数的原函数处class CrossEntropyLoss2d(nn.Module):

2:什么原因导致的bug:

这是由于维度不匹配导致的,那是什么维度不匹配?,以及那两个维度不匹配的呢?。

①:在网上冲浪了大半天,大部分都是因为view函数使用错误,导致nn.linear函数的输入和输出不匹配。因此需要回模型检查view函数前的维度,通过print函数检查view函数输入前的维度,经过我认真检查维度,对每一个层都进行print后发现模型维度没有任何的错误,所以这个方法不适用于我,但是还把链接放在这里大家检查一下自己的模型batch维度不匹配

②:然后我就在losses处前面加上print,即打印pred_scales,target_scales的shape。

        # print(pred_scales.shape) #torch.Size([8, 40, 448, 448])
        # print(target_scales.shape) #torch.Size([8, 448, 448])
        losses = loss_function_train(pred_scales, target_scales)

这里还有一个小插曲,刚开始target_scales的size还打印不出来,是因为target_scales是一个列表,里面是totch,经过分析把target_scales旁边的中括号去掉就可以打印了。

这里我们看一下pred_scales,target_scales到底是啥:

pred_scales = model(image, depth)
        if modality in ['rgbd', 'rgb']:
            image = sample['image'].to(device)
            # print(image.shape) #torch.Size([8, 3, 448, 448])
            batch_size = image.data.shape[0]
        if modality in ['rgbd', 'depth']:
            depth = sample['depth'].to(device)
            # print(depth.shape) #torch.Size([8, 1, 448, 448])
            batch_size = depth.data.shape[0]
            # print(batch_size) # 8
        target_scales = sample['label'].to(device)

model是我们实例化后的模型,这里将rgb和depth输入,pred_scales就是我们的模型输出,这里是(8,40,448,448),target_scales是标签。这里我们可以看出target_scales是sample列表中['label']索引对应的数据,同理image和depth也是rgb和depth索引对应的数据。

而sample是什么呢?

train_data = Dataset(
       data_dir=args.dataset_dir,
       split='train',
       depth_mode=depth_mode,
       with_input_orig=with_input_orig,
       **dataset_kwargs)

train_loader = DataLoader(train_data,
                          batch_size=args.batch_size,
                          num_workers=args.workers,
                          drop_last=True,
                          shuffle=True)

train_loader, valid_loader = data_loaders

for i, sample in enumerate(train_loader):

我们看一下数据传递的流程,首先获取data路径,经过dataset获得图片,然后经过dataloader取一个batch的数据得到trainloader,遍历trainloader的列表,得到索引i和数据sample。因为trainloader取的一个batch=8的数据,所以samle里面包含了image,depth,label他们的大小分别为torch.Size([8, 3, 448, 448]),torch.Size([8, 1, 448, 448]),torch.Size([8,  448, 448])。即

pred_scales大小为(8,40,448,448),我们有40个类别,target_scales大小为torch.Size([8,  448, 448])。

这里延伸一下pytorch如何进行损失函数计算    参考

标签没有通道,每一个像素代表一个类别,且大小和图片的输入相同,为什么不需要one-hot编码是因为pytorch自动进行编码了。这里有一个坑:预测值和标签进行损失计算,他们两个都必须有batch,否则是不能计算成功的。

 下面一个例子演示一下:

inputs_scales = torch.rand(8,40,448,448)
targets_scales = torch.rand(8,448,448)
for inputs, targets in zip(inputs_scales, targets_scales):
    # inputs = inputs.unsqueeze(0)
    # targets = targets.unsqueeze(0)
    print(inputs.shape)
    print(targets.shape)
    loss2 = nn.CrossEntropyLoss()
    result2 = loss2(inputs, targets.long())
    print(result2)
torch.Size([40, 448, 448])
torch.Size([448, 448])
Traceback (most recent call last):
  File "/tmp/pycharm_project_346/kong.py", line 816, in <module>
    result2 = loss2(inputs, targets.long())
  File "/home/software/anaconda3/envs/pycharm329/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/software/anaconda3/envs/pycharm329/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 1166, in forward
    label_smoothing=self.label_smoothing)
  File "/home/software/anaconda3/envs/pycharm329/lib/python3.7/site-packages/torch/nn/functional.py", line 3014, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (40) to match target batch_size (448).

类似于题目中的bug是吧!

我们增加batch维度后:batch为8,所以遍历八次,每次都做损失计算。

inputs_scales = torch.rand(8,40,448,448)
targets_scales = torch.rand(8,448,448)
for inputs, targets in zip(inputs_scales, targets_scales):
    inputs = inputs.unsqueeze(0)
    targets = targets.unsqueeze(0)
    print(inputs.shape)
    print(targets.shape)
    loss2 = nn.CrossEntropyLoss()
    result2 = loss2(inputs, targets.long())
    print(result2)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7298)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7283)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7302)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7282)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7296)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7296)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7289)
torch.Size([1, 40, 448, 448])
torch.Size([1, 448, 448])
tensor(3.7292)

在损失的定义中:

inputs_scales和 targets_scales的维度分别为torch.Size([8, 40, 448, 448]),torch.Size([8, 448, 448]),遍历inputs_scales和 targets_scales,他们的维度就是如下,他们是不能进行损失计算的。

            print(targets.shape) torch.Size([448, 448])

            print(inputs.shape) torch.Size([40, 448, 448])
class CrossEntropyLoss2d(nn.Module):
    def __init__(self, device, weight):
        super(CrossEntropyLoss2d, self).__init__()
        self.weight = torch.tensor(weight).to(device)
        self.num_classes = len(self.weight) + 1  # +1 for void
        if self.num_classes < 2**8:
            self.dtype = torch.uint8
        else:
            self.dtype = torch.int16
        self.ce_loss = nn.CrossEntropyLoss(
            torch.from_numpy(np.array(weight)).float(),
            reduction='none',
            ignore_index=-1
        )
        self.ce_loss.to(device)

    def forward(self, inputs_scales, targets_scales):
        losses = []
        for inputs, targets in zip(inputs_scales, targets_scales):
            # mask = targets > 0
            # 返回一个和源张量同shape、dtype和device的张量,与源张量不共享数据内存,但提供梯度的回溯
            targets_m = targets.clone()
            targets_m -= 1
            print(inputs.size())
            print(targets_m.size())
            loss_all = self.ce_loss(inputs, targets_m.long())

            number_of_pixels_per_class = \
                torch.bincount(targets.flatten().type(self.dtype),
                               minlength=self.num_classes)
            divisor_weighted_pixel_sum = \
                torch.sum(number_of_pixels_per_class[1:] * self.weight)   # without void
            losses.append(torch.sum(loss_all) / divisor_weighted_pixel_sum)
            # losses.append(torch.sum(loss_all) / torch.sum(mask.float()))

        return losses

3:如何解决?

所以我们要给遍历的两个数据增加维度,或者说遍历[8, 40, 448, 448],我们希望的输出是[1,40,448,448],直接增加维度也是同理。然后我们就可以运行了。

            inputs = inputs.unsqueeze(0)
            targets = targets.unsqueeze(0)
            targets_m = targets.clone()

总结:预测图和标签又要有batch这一维度,才能够匹配,才能够输入到损失函数中。正好就对应了bug,batch的不匹配。

有关解决ValueError: Expected input batch_size (40) to match target batch_size (8).的更多相关文章

  1. 屏幕录制为什么没声音?检查这2项,轻松解决 - 2

    相信很多人在录制视频的时候都会遇到各种各样的问题,比如录制的视频没有声音。屏幕录制为什么没声音?今天小编就和大家分享一下如何录制音画同步视频的具体操作方法。如果你有录制的视频没有声音,你可以试试这个方法。 一、检查是否打开电脑系统声音相信很多小伙伴在录制视频后会发现录制的视频没有声音,屏幕录制为什么没声音?如果当时没有打开音频录制,则录制好的视频是没有声音的。因此,建议在录制前进行检查。屏幕上没有声音,很可能是因为你的电脑系统的声音被禁止了。您只需打开电脑系统的声音,即可录制音频和图画同步视频。操作方法:步骤1:点击电脑屏幕右下侧的“小喇叭”图案,在上方的选项中,选择“声音”。 步骤2:在“声

  2. 【高数】用拉格朗日中值定理解决极限问题 - 2

    首先回顾一下拉格朗日定理的内容:函数f(x)是在闭区间[a,b]上连续、开区间(a,b)上可导的函数,那么至少存在一个,使得:通过这个表达式我们可以知道,f(x)是函数的主体,a和b可以看作是主体函数f(x)中所取的两个值。那么可以有,  也就意味着我们可以用来替换 这种替换可以用在求某些多项式差的极限中。方法: 外层函数f(x)是一致的,并且h(x)和g(x)是等价无穷小。此时,利用拉格朗日定理,将原式替换为 ,再进行求解,往往会省去复合函数求极限的很多麻烦。使用要注意:1.要先找到主体函数f(x),即外层函数必须相同。2.f(x)找到后,复合部分是等价无穷小。3.要满足作差的形式。如果是加

  3. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  4. ruby - 在 Ruby 中,为什么 Array.new(size, object) 创建一个由对同一对象的多个引用组成的数组? - 2

    如thisanswer中所述,Array.new(size,object)创建一个数组,其中size引用相同的object。hash=Hash.newa=Array.new(2,hash)a[0]['cat']='feline'a#=>[{"cat"=>"feline"},{"cat"=>"feline"}]a[1]['cat']='Felix'a#=>[{"cat"=>"Felix"},{"cat"=>"Felix"}]为什么Ruby会这样做,而不是对object进行dup或clone? 最佳答案 因为那是thedocumenta

  5. ruby - 如何更快地解决 project euler #21? - 2

    原始问题Letd(n)bedefinedasthesumofproperdivisorsofn(numberslessthannwhichdivideevenlyinton).Ifd(a)=bandd(b)=a,whereab,thenaandbareanamicablepairandeachofaandbarecalledamicablenumbers.Forexample,theproperdivisorsof220are1,2,4,5,10,11,20,22,44,55and110;therefored(220)=284.Theproperdivisorsof284are1,2,

  6. ruby - 为什么这些方法没有解决? - 2

    这个问题在这里已经有了答案:WhydoRubysettersneed"self."qualificationwithintheclass?(3个答案)关闭29天前。给定这段代码:classSomethingattr_accessor:my_variabledefinitialize@my_variable=0enddeffoomy_variable=my_variable+3endends=Something.news.foo我收到这个错误:test.rb:9:in`foo':undefinedmethod`+'fornil:NilClass(NoMethodError)fromtes

  7. 电脑启动后显示器黑屏怎么办?排查下面4个问题,快速解决 - 2

    电脑启动出现显示器黑屏是一个相当常见的问题。如果您遇到了这个问题,不要惊慌,因为它有很多可能的原因,可以采取一些简单的措施来解决它。在本文中,小编将介绍下面4种常见的电脑启动后显示器黑屏的原因,排查这些原因,快速解决! 演示机型:联想Ideapad700-15ISK-ISE系统版本:Windows10一、显示器问题如果出现电脑启动后显示器黑屏的情况。那么首先您需要检查一下显示器是否正常工作。您可以通过更换另一个显示器或将当前显示器连接到另一台计算机来检查显示器是否存在问题。如果问题仍然存在,那么您可以排除显示器故障的可能性。 二、显卡问题如果您的电脑配备了独立显卡,那么显卡故障也可能是导致电脑

  8. 关于Qt程序打包后运行库依赖的常见问题分析及解决方法 - 2

    目录一.大致如下常见问题:(1)找不到程序所依赖的Qt库version`Qt_5'notfound(requiredby(2)CouldnotLoadtheQtplatformplugin"xcb"in""eventhoughitwasfound(3)打包到在不同的linux系统下,或者打包到高版本的相同系统下,运行程序时,直接提示段错误即segmentationfault,或者Illegalinstruction(coredumped)非法指令(4)ldd应用程序或者库,查看运行所依赖的库时,直接报段错误二.问题逐个分析,得出解决方法:(1)找不到程序所依赖的Qt库version`Qt_5'

  9. 【RuntimeError: CUDA error: device-side assert triggered】问题与解决 - 2

    RuntimeError:CUDAerror:device-sideasserttriggered问题描述解决思路发现问题:总结问题描述当我在调试模型的时候,出现了如下的问题/opt/conda/conda-bld/pytorch_1656352465323/work/aten/src/ATen/native/cuda/IndexKernel.cu:91:operator():block:[5,0,0],thread:[63,0,0]Assertion`index>=-sizes[i]&&index通过提示信息可以知道是个数组越界的问题。但是如图一中第二行话所说这个问题可能并不出在提示的代码段

  10. ruby-on-rails - 如何解决#<Book::ActiveRecord_Relation:0x007fb709a6a8c0> 的未定义方法 `to_key'? - 2

    我遇到了未定义方法`to_key'的问题这是我的books_controller.rbclassBooksController和我的索引页如下。index.html.erb......现在当我要访问索引页面时出现如下错误。undefinedmethod`to_key'for# 最佳答案 index通常返回一个集合。事实上,您的Controller符合要求。但是,您的View试图为其定义一个表单。正如您所发现的,这不会成功。表单适用于实体,而不适用于集合。该错误在您看来以及您希望如何处理index。

随机推荐