草庐IT

DenseNet详解

tt丫 2023-09-06 原文

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

✨完整代码在我的github上,有需要的朋友可以康康✨

https://github.com/tt-s-t/Deep-Learning.git

目录

一、DenseNet网络的背景

二、DenseNet网络结构

1、Dense Block——特征重用

2、Transition层

3、网络结构

四、DenseNet优缺点

1、优点

(1)相比ResNet拥有更少的参数数量

(2)传播与预测都保留了低层次的特征

(3)旁路加强了特征的重用,导致直接的监督

(4)网络更易于训练,并具有一定的正则化效果

(5)缓解了梯度消失/爆炸和网络退化的问题

2、不足

五、DenseNet代码实现

1、DenseLayer

2、DenseBlock

3、Transition

4、DenseNet整体构建


一、DenseNet网络的背景

       DenseNet模型的基本思路与ResNet一致,但它建立的是前面所有层与后面层的密集连接(即相加变连结),它的名称也是由此而来。

      DenseNet的另一大特色是通过特征在通道上的连接来实现特征重用。这些特点让DenseNet的参数量和计算成本都变得更少了(相对ResNet),效果也更好了。

      ResNet解决了深层网络梯度消失问题,它是从深度方向研究的。宽度方向是GoogleNet的Inception。而DenseNet是从feature入手,通过对feature的极致利用能达到更好的效果和减少参数。

      DenseNet斩获CVPR 2017的最佳论文奖。


二、DenseNet网络结构

1、Dense Block——特征重用

       DenseBlock包含很多层,每个层的特征图大小相同(才可以在通道上进行连结),层与层之间采用密集连接方式。

如下图所示:

 

        上图是一个包含5层layer的Dense Block。可以看出Dense Block互相连接所有的层,具体来说就是每一层的输入都来自于它前面所有层的特征图,每一层的输出均会直接连接到它后面所有层的输入。所以对于一个L层的DenseBlock,共包含 L*(L+1)/2 个连接(等差数列求和公式),如果是ResNet的话则为(L-1)*2+1。从这里可以看出:相比ResNet,Dense Block采用密集连接。而且Dense Block是直接concat来自不同层的特征图,这可以实现特征重用(即对不同“级别”的特征——不同表征进行总体性地再探索),提升效率,这一特点是DenseNet与ResNet最主要的区别。

Note:k —— DenseNet中的growth rate(增长率),这是一个超参数。一般情况下使用较小的k(比如12),就可以得到较佳的性能。

假定输入层的特征图的通道数为k0,那么L层输入的channel数为 k0+k*(L-1),因此随着层数增加,尽管k设定得较小,DenseBlock中每一层输入依旧会越来越多。

        另外一个特殊的点:DenseBlock中采用BN+ReLU+Conv的结构,平常我们常见的是Conv+BN+ReLU。这么做的原因是:卷积层的输入包含了它前面所有层的输出特征,它们来自不同层的输出,因此数值分布差异比较大,所以它们在输入到下一个卷积层时,必须先经过BN层将其数值进行标准化,然后再进行卷积操作。

       通常为了减少参数,一般还会先加一个1x1conv来减少参数量。所以DenseBlock中的每一层采用BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv的结构(或者BottleNeck)。

2、Transition层

       它主要用于连接两个相邻的DenseBlock,整合上一个DenseBlock获得的特征,缩小上一个DenseBlock的宽高,达到下采样效果,特征图的宽高减半。Transition层包括一个1x1卷积(用于调整通道数)和2x2AvgPooling(用于降低特征图大小),结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。因此,Transition层可以起到压缩模型的作用 。

      超参调节:θ是压缩系数,取值(0,1],当θ=1时,原feature维度不变,即无压缩;

而当压缩系数小于1时,这种结构称为DenseNet-C(文中使用θ=0.5);

对于使用bottleneck层的DenseBlock结构和压缩系数小于1的Transition组合结构称为DenseNet-BC。后面在使用的DenseNet默认都是DenseNet-BC,因为它的效果最好。

3、网络结构

      DenseNet的网络结构主要由DenseBlock和Transition组成,一个DenseNet中有3个或4个DenseBlock。而一个DenseBlock中也会有多个Bottleneck layers。最后的DenseBlock之后是一个global AvgPooling层,然后送入一个softmax分类器,得到每个类别所属分数。


四、DenseNet优缺点

1、优点

(1)相比ResNet拥有更少的参数数量

参数减少,计算效率更高,效果更好(相较于其他网络)

(2)传播与预测都保留了低层次的特征

在以前的卷积神经网络中,最终输出只会利用最高层次的特征。而DenseNet实现特征重用,同时利用低层次和高层次的特征。

(3)旁路加强了特征的重用,导致直接的监督

因为每一层都建立起了与前面层的连接,误差信号可以很容易地传播到较早的层,所以较早的层可以从最终分类层获得直接的监督。

(4)网络更易于训练,并具有一定的正则化效果

(网上资料都有说这一句,但是我不太清楚他是怎么体现正则化效果的)

(5)缓解了梯度消失/爆炸和网络退化的问题

特征重用实现了梯度的提前传播,也至少保留了前面网络的能力,不至于变弱(最少也是个恒等变换)

2、不足

由于需要进行多次Concatnate操作,数据需要被复制多次,显存容易增加得很快,需要一定的显存优化技术。因此在训练过程中,训练的时间要比Resnet作为backbone长很多。所以相对而言,ResNet更常用。

并且ResNet更加的简洁,变体也多,更加成熟,因此后来更多使用的是ResNet,但是DenseNet的思想贡献也是如今很常见的。


五、DenseNet代码实现

完整代码在我的github上,用自行搭建的DenseNet实现对cifar10数据的分类(里面超参数是我一拍脑袋设的,要有好的分类效果还要调一调,这里主要是想自己搭一下网络实现而已)

https://github.com/tt-s-t/Deep-Learning.git

以下展示网络搭建

1、DenseLayer

import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer) """
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(DenseLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(num_input_features, bn_size*growth_rate,
                                           kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bn_size*growth_rate)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(bn_size*growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False)
        self.drop_rate = drop_rate

    def forward(self, x):
        output = self.bn1(x)
        output = self.relu1(output)
        output = self.conv1(output)

        output = self.bn2(output)
        output = self.relu2(output)
        output = self.conv2(output)

        if self.drop_rate > 0:
            output = F.dropout(output, p=self.drop_rate)
        return torch.cat([x, output], 1)

2、DenseBlock

class DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(DenseBlock, self).__init__()
        for i in range(num_layers):
            if i == 0:
                self.layer = nn.Sequential(
                    DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,drop_rate)
                )
            else:
                layer = DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,drop_rate)
                self.layer.add_module("denselayer%d" % (i+1), layer)
    
    def forward(self,input):
        return self.layer(input)

3、Transition

class Transition(nn.Sequential):
    def __init__(self, num_input_feature, num_output_features):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(num_input_feature)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(num_input_feature, num_output_features,
                                          kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(2, stride=2)

    def forward(self,input):
        output = self.bn(input)
        output = self.relu(output)
        output = self.conv(output)
        output = self.pool(output)

        return output

4、DenseNet整体构建

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()

        # 前部
        self.features = nn.Sequential(
            #第一层
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(),
            #第二层
            nn.MaxPool2d(3, stride=2, padding=1)
        )

        # DenseBlock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_layers, num_features, bn_size, growth_rate,drop_rate)
            if i == 0:
                self.block_tran = nn.Sequential(
                    block
                )
            else:
                self.block_tran.add_module("denseblock%d" % (i + 1), block)#添加一个block
            num_features += num_layers*growth_rate#更新通道数
            if i != len(block_config) - 1:#除去最后一层不需要加Transition来连接两个相邻的DenseBlock
                transition = Transition(num_features, int(num_features*compression_rate))
                self.block_tran.add_module("transition%d" % (i + 1), transition)#添加Transition
                num_features = int(num_features * compression_rate)#更新通道数

        # 后部 bn+ReLU
        self.tail = nn.Sequential(
            nn.BatchNorm2d(num_features),
            nn.ReLU()
        )

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):#如果是卷积层,参数kaiming分布处理
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):#如果是批量归一化则伸缩参数为1,偏移为0
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):#如果是线性层偏移为0
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        block_output = self.block_tran(features)
        tail_output = self.tail(block_output)
        out = F.avg_pool2d(tail_output, 7, stride=1).view(tail_output.size(0), -1)#平均池化
        out = self.classifier(out)
        return out

欢迎大家在评论区批评指正,谢谢大家~

有关DenseNet详解的更多相关文章

  1. 物联网MQTT协议详解 - 2

    一、什么是MQTT协议MessageQueuingTelemetryTransport:消息队列遥测传输协议。是一种基于客户端-服务端的发布/订阅模式。与HTTP一样,基于TCP/IP协议之上的通讯协议,提供有序、无损、双向连接,由IBM(蓝色巨人)发布。原理:(1)MQTT协议身份和消息格式有三种身份:发布者(Publish)、代理(Broker)(服务器)、订阅者(Subscribe)。其中,消息的发布者和订阅者都是客户端,消息代理是服务器,消息发布者可以同时是订阅者。MQTT传输的消息分为:主题(Topic)和负载(payload)两部分Topic,可以理解为消息的类型,订阅者订阅(Su

  2. Tcl脚本入门笔记详解(一) - 2

    TCL脚本语言简介•TCL(ToolCommandLanguage)是一种解释执行的脚本语言(ScriptingLanguage),它提供了通用的编程能力:支持变量、过程和控制结构;同时TCL还拥有一个功能强大的固有的核心命令集。TCL经常被用于快速原型开发,脚本编程,GUI和测试等方面。•实际上包含了两个部分:一个语言和一个库。首先,Tcl是一种简单的脚本语言,主要使用于发布命令给一些互交程序如文本编辑器、调试器和shell。由于TCL的解释器是用C\C++语言的过程库实现的,因此在某种意义上我们又可以把TCL看作C库,这个库中有丰富的用于扩展TCL命令的C\C++过程和函数,所以,Tcl是

  3. 【详解】Docker安装Elasticsearch7.16.1集群 - 2

    开门见山|拉取镜像dockerpullelasticsearch:7.16.1|配置存放的目录#存放配置文件的文件夹mkdir-p/opt/docker/elasticsearch/node-1/config#存放数据的文件夹mkdir-p/opt/docker/elasticsearch/node-1/data#存放运行日志的文件夹mkdir-p/opt/docker/elasticsearch/node-1/log#存放IK分词插件的文件夹mkdir-p/opt/docker/elasticsearch/node-1/plugins若你使用了moba,直接右键新建即可如上图所示依次类推创建

  4. 【Elasticsearch基础】Elasticsearch索引、文档以及映射操作详解 - 2

    文章目录概念索引相关操作创建索引更新副本查看索引删除索引索引的打开与关闭收缩索引索引别名查询索引别名文档相关操作新建文档查询文档更新文档删除文档映射相关操作查询文档映射创建静态映射创建索引并添加映射概念es中有三个概念要清楚,分别为索引、映射和文档(不用死记硬背,大概有个印象就可以)索引可理解为MySQL数据库;映射可理解为MySQL的表结构;文档可理解为MySQL表中的每行数据静态映射和动态映射上面已经介绍了,映射可理解为MySQL的表结构,在MySQL中,向表中插入数据是需要先创建表结构的;但在es中不必这样,可以直接插入文档,es可以根据插入的文档(数据),动态的创建映射(表结构),这就

  5. 最强Http缓存策略之强缓存和协商缓存的详解与应用实例 - 2

    HTTP缓存是指浏览器或者代理服务器将已经请求过的资源保存到本地,以便下次请求时能够直接从缓存中获取资源,从而减少网络请求次数,提高网页的加载速度和用户体验。缓存分为强缓存和协商缓存两种模式。一.强缓存强缓存是指浏览器直接从本地缓存中获取资源,而不需要向web服务器发出网络请求。这是因为浏览器在第一次请求资源时,服务器会在响应头中添加相关缓存的响应头,以表明该资源的缓存策略。常见的强缓存响应头如下所述:Cache-ControlCache-Control响应头是用于控制强制缓存和协商缓存的缓存策略。该响应头中的指令如下:max-age:指定该资源在本地缓存的最长有效时间,以秒为单位。例如:Ca

  6. IDEA 2022 创建 Spring Boot 项目详解 - 2

    如何用IDEA2022创建并初始化一个SpringBoot项目?目录如何用IDEA2022创建并初始化一个SpringBoot项目?0. 环境说明1.  创建SpringBoot项目 2.编写初始化代码0. 环境说明IDEA2022.3.1JDK1.8SpringBoot1.  创建SpringBoot项目        打开IDEA,选择NewProject创建项目。        填写项目名称、项目构建方式、jdk版本,按需要修改项目文件路径等信息。        选择springboot版本以及需要的包,此处只选择了springweb。        此处需特别注意,若你使用的是jdk1

  7. 详解Unity中的粒子系统Particle System (二) - 2

    前言上一篇我们简要讲述了粒子系统是什么,如何添加,以及基本模块的介绍,以及对于曲线和颜色编辑器的讲解。从本篇开始,我们将按照模块结构讲解下去,本篇主要讲粒子系统的主模块,该模块主要是控制粒子的初始状态和全局属性的,以下是关于该模块的介绍,请大家指正。目录前言本系列提要一、粒子系统主模块1.阅读前注意事项2.参考图3.参数讲解DurationLoopingPrewarmStartDelayStartLifetimeStartSpeed3DStartSizeStartSize3DStartRotationStartRotationFlipRotationStartColorGravityModif

  8. VMware虚拟机与本地主机进行磁盘共享(详解) - 2

    VMware虚拟机与本地主机进行磁盘共享前提虚拟机版本为Windows10(专业版,不是可能有问题)本地主机为家庭版或学生版(此版本会有问题,但有替代方式)最好是专业版VMware操作1.关闭防火墙,全部关闭。2.打开电脑属性3.点击共享-》高级共享-》权限4.如果没有everyone,就添加权限选择完全控制,然后应用确定。5.打开cmd输入lusrmgr.msc(只有专业版可以打开)如果不是专业版,可以跳过这一步。点击用户-》administrator密码要复杂密码,否则不行。推荐admaiN@1234类型的密码。设置完密码,点击属性,将禁用解开。6.如果虚拟机的windows不是专业版,可

  9. ElasticSearch之 ik分词器详解 - 2

    IK分词器本文分为简介、安装、使用三个角度进行讲解。简介倒排索引众所周知,ES是一个及其强大的搜索引擎,那么它为什么搜索效率极高呢,当然和他的存储方式脱离不了关系,ES采取的是倒排索引,就是反向索引;常见索引结构几乎都是通过key找value,例如Map;倒排索引的优势就是有效利用Value,将多个含有相同Value的值存储至同一位置。分词器为了配合倒排索引,分词器也就诞生了,只有合理的利用Value,才会让倒排索引更加高效,如果一整个Value不进行任何操作直接进行存储,那么Value和key毫无区别。分词器Analyzer通常会对Value进行操作:一、字符过滤,过滤掉html标签;二、分

  10. Educational Codeforces Round 146 (Rated for Div. 2)(B,E详解) - 2

    题外话:抑郁场,开局一小时只出A,死活想不来B,最后因为D题出锅ura才保住可怜的分。但咱本来就写不到DB-LongLegs(数论)本题题解法一学自同样抑郁的知乎作者幽血魅影的题解,有讲解原理。法二来着知乎巨佬cup-pyy(大佬说《不难发现》呜呜)题意三种操作:向上走mmm步向右走mmm步给自己一次走的步数加111,即使得m=m+1m=m+1m=m+1问从(0,0)(0,0)(0,0)走到(a,b)(a,b)(a,b)的最小操作次数,值得注意的是操作三不可逆。解析假设我们最终一步的大小增长到mmm,那么在这个过程中我能以[1,m][1,m][1,m](当步数增长到该数时)之间的任何数字向上或

随机推荐