草庐IT

动手实现深度学习(12): 卷积层的实现与优化(img2col)

修雨轩陈 2023-03-28 原文

9.1 卷积层的运算

传送门: https://www.cnblogs.com/greentomlee/p/12314064.html

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

卷积的forward

卷积的计算过程网上的资料已经做够好了,没必要自己再写一遍。只把资料搬运到这里:

http://deeplearning.net/software/theano_versions/dev/tutorial/conv_arithmetic.html#transposed-convolution-arithmetic

https://www.zhihu.com/question/43609045

https://blog.csdn.net/weixin_44106928/article/details/103079668

这里总结一下有padding\stride的卷积操作:

假设,输入大小为(H,W,C),fileter大小为(FH,FW,C)*N ; padding=P, stride=S,卷积后的形状为(OH,OW,OC)

  1 def forward(self, x):
  2     '''
  3     使用im2col 将输入的x 转换成2D矩阵
  4     然后 y= w*x+b 以矩阵的形式完成
  5     最后返回y
  6     :param x: x为4D tensor, 输入数据
  7     :return: out=w*x+b
  8     '''
  9     FN, C, FH, FW = self.W.shape
 10     N, C, H, W = x.shape
 11     out_h = 1 + int((H + 2 * self.pad - FH) / self.stride)
 12     out_w = 1 + int((W + 2 * self.pad - FW) / self.stride)
 13 
 14     col = im2col(x, FH, FW, self.stride, self.pad)
 15     col_W = self.W.reshape(FN, -1).T
 16     print("col.shape=%s"%str(col.shape))
 17     print("col_W.shape=%s"%str(col_W.shape))
 18 
 19     out = np.dot(col, col_W)
 20     print("out.shape=%s"%str(out.shape))
 21     out=out+ self.b
 22     out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
 23 
 24     self.x = x
 25     self.col = col
 26     self.col_W = col_W
 27 
 28     return out
 29 

 

卷积的backward

概念介绍: https://zhuanlan.zhihu.com/p/33802329

卷积的backward是对卷积的求导。

代码实现如下:

  1 def backward(self, dout):
  2     '''
  3     反馈过程中也需要将2D 矩阵转换为4D tensor
  4     :param dout: 梯度差
  5     :return:
  6     '''
  7     FN, C, FH, FW = self.W.shape
  8     dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN) # NCHW
  9 
 10     self.db = np.sum(dout, axis=0)# NHWC , 求和
 11     self.dW = np.dot(self.col.T, dout) # 点乘w
 12     self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
 13 
 14     dcol = np.dot(dout, self.col_W.T)
 15     dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
 16 
 17     return dx

9.2 引入im2col 概念

再讲卷积的实现之前,首先抛出一个问题:如果按照上述的卷积方式计算,是否会影响性能?

答案是肯定会受影响的。

因此,我们需要向优化一下conv的计算方式.

按照“以空间换时间”的思想,我们可以做一些优化,使得在conv和pool的时候运算速度加快。

首先,我们知道Numpy对大型矩阵的运算是有做优化的,这个特点我们应该好好利用;

其次,我们知道Numpy在做多个嵌套的for循环的时候,O(n)会很大;应该避免做多个for循环;

因此,要是将4D的卷积运算转换成2D的矩阵乘法就会好很多;filter也可以变成2D的数组;

Im2col便是将4D数据转换成2D矩阵的函数。

该函数大致的思路是:filter按照行列展开成一个2D矩阵即可,input_data按照计算的单元重新组合。因此需要写一个函数将图像转换成2D矩阵,该函数可以将图像展开成适合与滤波去做乘法的矩阵。

展开和计算的流程如下:

 

9.3 单元测试im2col

对filter计算有影响的因素有input_data,filter_h,filter_w,stride, padding;im2col会应该根据以上的因因素展开input_data,展开后的input_data一定是比之前要大的;

我们可以尝试计算一下input_data展开后的数据形状:

假设,输入数据为4*4*3大小的tensor; filter有两个为2*(2*2*3),filter_h=2,filter_w=2,stride=1, padding=0;这里可以计算出展开以后的大小:

Filter为有两个,分别为f1和f2; shape=(2*2*3), 按照行展开成2D的矩阵以后如下图所示:

 

 

Input_data为4*4*3的tensor,如下图所示:

 

Input_data首先会找出filter对应的计算单元,这些还是需要padding\stride\filter_w\filter_h相关,找出计算的单元以后,按照行展开。最后得到的数据便是im2col的结果:

 

Input_data和filter这样展开以后,卷积计算就可以按照矩阵乘法的方式计算,避免了重复的for循环。如下图所示,黑色和灰色区域是计算的结果。不必担心矩阵过大是否会影响计算速度,Numpy对大规模矩阵乘法内部有优化加速,这样展开以后恰恰也能充分的利用numpy的特性。

 

Im2col的实现:

  1 def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
  2     '''
  3 
  4     :param input_data: 输入数据由4维数组组成(N,C,H,W)
  5     :param filter_h:   filer的高
  6     :param filter_w:   filter的宽
  7     :param stride:     stride
  8     :param pad:        padding
  9     :return:           2D矩阵
 10     '''
 11     # 计算输出的大小
 12     N, C, H, W = input_data.shape
 13     out_h = (H + 2*pad - filter_h)//stride + 1
 14     out_w = (W + 2*pad - filter_w)//stride + 1
 15     # padding
 16     img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
 17     col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
 18     # 计算单元
 19     for y in range(filter_h):
 20         y_max = y + stride*out_h
 21         for x in range(filter_w):
 22             x_max = x + stride*out_w
 23             col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
 24     # 重新排列
 25     col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
 26     return col

 

测试代码:

  1 # -*- coding: utf-8 -*-
  2 # @File  : test_im2col.py
  3 # @Author: lizhen
  4 # @Date  : 2020/2/14
  5 # @Desc  : 测试im2col
  6 import numpy as np
  7 
  8 from src.common.util import im2col,col2im
  9 
 10 if __name__ == '__main__':
 11     raw_data = [3, 0, 4, 2,
 12                 6, 5, 4, 3,
 13                 3, 0, 2, 3,
 14                 1, 0, 3, 1,
 15 
 16                 1, 2, 0, 1,
 17                 3, 0, 2, 4,
 18                 1, 0, 3, 2,
 19                 4, 3, 0, 1,
 20 
 21                 4, 2, 0, 1,
 22                 1, 2, 0, 4,
 23                 3, 0, 4, 2,
 24                 6, 2, 4, 5
 25     ]
 26 
 27     input_data = np.array(raw_data)
 28     input_data = input_data.reshape(1,3,4,4)
 29     print(input_data.shape)
 30     col1 = im2col(input_data=input_data,filter_h=2,filter_w=2,stride=1,pad=0)#input_data, filter_h, filter_w, stride=1, pad=0
 31     print(col1)
 32 

 

 

========输出:可以发现和上面的绘图的结果是一致的 =====

(1, 3, 4, 4)

[[3. 0. 6. 5. 1. 2. 3. 0. 4. 2. 1. 2.]

[0. 4. 5. 4. 2. 0. 0. 2. 2. 0. 2. 0.]

[4. 2. 4. 3. 0. 1. 2. 4. 0. 1. 0. 4.]

[6. 5. 3. 0. 3. 0. 1. 0. 1. 2. 3. 0.]

[5. 4. 0. 2. 0. 2. 0. 3. 2. 0. 0. 4.]

[4. 3. 2. 3. 2. 4. 3. 2. 0. 4. 4. 2.]

[3. 0. 1. 0. 1. 0. 4. 3. 3. 0. 6. 2.]

[0. 2. 0. 3. 0. 3. 3. 0. 0. 4. 2. 4.]

[2. 3. 3. 1. 3. 2. 0. 1. 4. 2. 4. 5.]]

 

9.3 卷积操作的实现

卷积操作也需要实现forward和backward函数。

Forward函数中用到了9.1\9.2的im2col

 

  1 class Convolution:
  2     def __init__(self, W, b, stride=1, pad=0):
  3         '''
  4         conv的构造函数
  5         :param W: 2D矩阵
  6         :param b:
  7         :param stride:
  8         :param pad:
  9         '''
 10         self.W = W
 11         self.b = b
 12         self.stride = stride
 13         self.pad = pad
 14 
 15         # 中间结果(backward的时候使用)
 16         self.x = None
 17         self.col = None
 18         self.col_W = None
 19 
 20         # 权重的梯度/偏置的梯度
 21         self.dW = None
 22         self.db = None
 23 
 24     def forward(self, x):
 25         '''
 26         使用im2col 将输入的x 转换成2D矩阵
 27         然后 y= w*x+b 以矩阵的形式完成
 28         最后返回y
 29         :param x: x为4D tensor, 输入数据
 30         :return: out=w*x+b
 31         '''
 32         FN, C, FH, FW = self.W.shape
 33         N, C, H, W = x.shape
 34         out_h = 1 + int((H + 2 * self.pad - FH) / self.stride)
 35         out_w = 1 + int((W + 2 * self.pad - FW) / self.stride)
 36 
 37         col = im2col(x, FH, FW, self.stride, self.pad)
 38         col_W = self.W.reshape(FN, -1).T
 39 
 40         out = np.dot(col, col_W) + self.b
 41         out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
 42 
 43         self.x = x
 44         self.col = col
 45         self.col_W = col_W
 46 
 47         return out
 48 
 49     def backward(self, dout):
 50         '''
 51         反馈过程中也需要将2D 矩阵转换为4D tensor
 52         :param dout: 梯度差
 53         :return:
 54         '''
 55         FN, C, FH, FW = self.W.shape
 56         dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)
 57 
 58         self.db = np.sum(dout, axis=0)
 59         self.dW = np.dot(self.col.T, dout)
 60         self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
 61 
 62         dcol = np.dot(dout, self.col_W.T)
 63         dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
 64 
 65         return dx
 66 

9.4单元测试卷积操作

输入:input_data\filters

输出:output

测试代码:

 

  1   2 # -*- coding: utf-8 -*-
  3 # @File  : test_im2col.py
  4 # @Author: lizhen
  5 # @Date  : 2020/2/14
  6 # @Desc  : 测试im2col
  7 import numpy as np
  8 
  9 from src.common.util import im2col,col2im
 10 from src.common.layers import Convolution
 11 
 12 
 13 if __name__ == '__main__':
 14     raw_data = [3, 0, 4, 2,
 15                 6, 5, 4, 3,
 16                 3, 0, 2, 3,
 17                 1, 0, 3, 1,
 18 
 19                 1, 2, 0, 1,
 20                 3, 0, 2, 4,
 21                 1, 0, 3, 2,
 22                 4, 3, 0, 1,
 23 
 24                 4, 2, 0, 1,
 25                 1, 2, 0, 4,
 26                 3, 0, 4, 2,
 27                 6, 2, 4, 5
 28     ]
 29 
 30     raw_filter=[
 31         1,    1,    1,    1,    1,    1,
 32         1,    1,    1,    1,    1,    1,
 33         2,    2,    2,    2,    2,   2,
 34         2,    2,    2,    2,    2,   2,
 35 
 36     ]
 37 
 38 
 39 
 40     input_data = np.array(raw_data)
 41     filter_data = np.array(raw_filter)
 42 
 43     x = input_data.reshape(1,3,4,4)# NCHW
 44     W = filter_data.reshape(2,3,2,2) # NHWC
 45     b = np.zeros(2)
 46     # b = b.reshape((2,1))
 47     # col1 = im2col(input_data=x,filter_h=2,filter_w=2,stride=1,pad=0)#input_data, filter_h, filter_w, stride=1, pad=0
 48     # print(col1)
 49 
 50     print("input_data.shape=%s"%str(input_data.shape))
 51     print("W.shape=%s"%str(W.shape))
 52     print("b.shape=%s"%str(b.shape))
 53     conv = Convolution(W,b) # def __init__(self, W, b, stride=1, pad=0)
 54     out = conv.forward(x)
 55     print("bout.shape=%s"%str(out.shape))
 56     print(out)

Conv的输出结果,与上图的结果一致。

 

有关动手实现深度学习(12): 卷积层的实现与优化(img2col)的更多相关文章

  1. ruby - 如何根据特征实现 FactoryGirl 的条件行为 - 2

    我有一个用户工厂。我希望默认情况下确认用户。但是鉴于unconfirmed特征,我不希望它们被确认。虽然我有一个基于实现细节而不是抽象的工作实现,但我想知道如何正确地做到这一点。factory:userdoafter(:create)do|user,evaluator|#unwantedimplementationdetailshereunlessFactoryGirl.factories[:user].defined_traits.map(&:name).include?(:unconfirmed)user.confirm!endendtrait:unconfirmeddoenden

  2. 华为OD机试用Python实现 -【明明的随机数】 2023Q1A - 2

    华为OD机试题本篇题目:明明的随机数题目输入描述输出描述:示例1输入输出说明代码编写思路最近更新的博客华为od2023|什么是华为od,od薪资待遇,od机试题清单华为OD机试真题大全,用Python解华为机试题|机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为o

  3. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  4. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  5. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  6. MIMO-OFDM无线通信技术及MATLAB实现(1)无线信道:传播和衰落 - 2

     MIMO技术的优缺点优点通过下面三个增益来总体概括:阵列增益。阵列增益是指由于接收机通过对接收信号的相干合并而活得的平均SNR的提高。在发射机不知道信道信息的情况下,MIMO系统可以获得的阵列增益与接收天线数成正比复用增益。在采用空间复用方案的MIMO系统中,可以获得复用增益,即信道容量成倍增加。信道容量的增加与min(Nt,Nr)成正比分集增益。在采用空间分集方案的MIMO系统中,可以获得分集增益,即可靠性性能的改善。分集增益用独立衰落支路数来描述,即分集指数。在使用了空时编码的MIMO系统中,由于接收天线或发射天线之间的间距较远,可认为它们各自的大尺度衰落是相互独立的,因此分布式MIMO

  7. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  8. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

  9. ruby - Arrays Sets 和 SortedSets 在 Ruby 中是如何实现的 - 2

    通常,数组被实现为内存块,集合被实现为HashMap,有序集合被实现为跳跃列表。在Ruby中也是如此吗?我正在尝试从性能和内存占用方面评估Ruby中不同容器的使用情况 最佳答案 数组是Ruby核心库的一部分。每个Ruby实现都有自己的数组实现。Ruby语言规范只规定了Ruby数组的行为,并没有规定任何特定的实现策略。它甚至没有指定任何会强制或至少建议特定实现策略的性能约束。然而,大多数Rubyist对数组的性能特征有一些期望,这会迫使不符合它们的实现变得默默无闻,因为实际上没有人会使用它:插入、前置或追加以及删除元素的最坏情况步骤复

  10. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

随机推荐