草庐IT

度量学习——总结

不说话装高手H 2023-04-03 原文

传统方法

User guide: contents — metric-learn 0.6.2 documentation

深度学习

基于深度学习的度量学习方法大都由两个部分组成:特征提取模块和距离度量模块。距离度量模块的任务是使同一类样本间的距离更靠近,而不同类样本间的距离更远离。这一模块更多的实现方法是改进损失函数,对模型的学习更加“赏罚分明”。

基于正负样本对的方法

也可以称为基于对比学习的方法,抽出正负样本对学习。对比学习的方法现在正广泛的应用于学习更好的特征提取模块,即用自监督学习的方法来学习更好的特征表达,更强大的 backbone,如 MoCo、SimCLR 等。Contrastive Representation Learning | Lil'Log

而为了学习更好的距离度量模块,越来越多基于样本对的损失函数被提出:

Contrastive Loss

是最简单也最直观的损失函数:

直观上分析这个公式,当两个样本的标签相同时,模型的损失函数值为这两个样本在特征空间内的距离,这时梯度回传是为了使这两个样本更“靠近”。而当这两个样本的标签不同时,这里的 α 是 margin,当这两个样本对的在特征空间的距离大于 margin 时,就使损失为0,即不更新网络维持现状,当小于 margin 时,我们就惩罚模型,使这两个样本的距离不断逼近 margin。同时 margin 更重要的作用是避免模型欺骗损失函数,即将所有的样本都映射到特征空间的同一个点,学到一个“捷径”使损失不断接近0。

Triplet Loss

上面的方法只有一个样本对,Triplet Loss 则引入了正负样本对的概念。 

xa 称为 anchor 样本,xp 为正样本,xn 为负样本。这里的 margin 同样有避免模型将所有样本映射到特征空间中同一个点的作用。同时 Triplet Loss 的一个关键点是负例挖掘(Negative Samples Mining),将 anchor 样本与正样本间的距离尽可能限制到0附近,同时将 anchor 样本与负样本的距离推开至 margin 左右,

The only requirement is that given two positive examples of the same class and one negative example, the negative should be farther away than the positive by some margin.

Triplet Loss 实现:Triplet Loss and Online Triplet Mining in TensorFlow | Olivier Moindrot blog

Quadruplet Loss

是对 Triplet Loss 的更进一步,其中 xa、xp 和 xs 都属于同一类:

Structured Loss

Triplet Loss 中只考虑了一个负例,而忽略了其他负例。与 Triplet Loss 不同,Structured Loss 考虑的是 batch 中所有的正样本对,以及距离正样本对中两个点最近的负样本:

N-Pair Loss 也是考虑到 Triplet Loss 中的这个缺点而提出的,只不过处理方法与 Structured Loss 不同。

基于交叉熵的方法

上述基于对比方法的正负样本采样问题在多卡训练的情况下变得尤其复杂,而且不能保证具有相似标签的样本被很好地分开,为了解决这两个问题,越来越多基于交叉熵的方法被提出。这些方法都是基于最基本的交叉熵损失函数(也可以称为 softmax loss):

其中 wz+b 是输出分类结果前的那一个全连接层。将模型的输出,标签为 i 的特征向量 z,投影到类别 i 的权重 wi 上,从几何上来说这个结果就是 vector center ,即这个向量 z 在特征空间中的映射点。将在 MNIST 数据集上用 softmax loss 训练的网络中特征的分布可视化(左边为 train,右边为 val):

可以看到还有一些部分不能被很好地分开,即使是在这样简单的一个数据集上。

Center Loss

在 Softmax Loss 的基础上加上正则项,将不同类样本在特征空间内“推开”。

其中 z 是全连接层的输入,c 是可学习的向量,可以理解为每一个类中特征向量的移动平均值(moving mean vector)。

可以看到 Center Loss 将每一类聚类到类别中心,在特征空间内更好地将类别分开。

SphereFace

Center Loss 的问题是,我们不能预先知道数据集中各类别聚类中心在特征空间内是相互远离的,如果他们很靠近的话,这个正则项就起不了多大的作用。但是如果我们将每个类别的聚类中心放到距离特征空间“原点”相同的距离,即将聚类中心都映射到一个“圆”上,只要这个“圆”的半径够大,理论上我们就可以把各类别的聚类中心“推”的更加分散。这个“圆”其实就是一个超平面。这也就是 SphereFace 的做法,它是由 Softmax Loss 一步步进化而来:

由向量积的公式我们知道:

所以上述公式中的做法就是将各类别的权重 w 正则化,正则化后我们完成了将聚类中心“放置”在超平面上的操作,之后再将全连接层的偏置置0,为了可以更简单的分析。公式中的 θ 就是 z 和 wi 之间的夹角,它是大于0小于 pai 的。

在 Softmax Loss 推理时,我们将特征向量 z 通过全连接层,也就是将 z 映射到各个类别的权重上,哪一个结果大,那么它就属于哪一类。反应到下图中,将 z(图中的 x)在特征空间的映射向每一个类别的权重(W1 和 W2)做垂直平分线,这个交点到特征空间“原点”的距离即为分类依据,这个距离就是 Cosθ,也就是全连接层的输出

在 SphereFace 中,由于我们对 w 和 b 进行了一些操作,并且 z 进行了正则化是一个常数。所以当 z 和各类别的权重 w 之间的夹角 θ 更小,那么它就属于哪一类。这样的“决策边界”依然不能保证十分正确的分类,因为我们没有对 z 在特征空间的映射到各聚类中心点在超平面上的距离施加正则或者惩罚项。这也是 SphereFace 第二个创新点,margin μ。SphereFace 公式如下:

所以在推理时只有当 z 与一个类别 w 的夹角 μθ 大于与其他类别的夹角 θ 时,模型才会判定 z 属于这个类别。也就是说 θ 被限定到了如下区间:

通过损失函数影响模型,让模型将 z 映射到特征空间内更小的角度,这样在推理的时候可以更好地判别。下图可以看出 SphereFace 的效果,图中两个 w 之间的红线即为决策边界。

图1

SphereFace 开创了用角度距离来完成分类的先河,接下来的几种方法都是基于此提出。

CosFace

指出 SphereFace 用计算出的角度经过 Cos 函数输出特征向量的调整结果(或者说调整特征向量在特征空间内的映射),但是 Cos 函数不是单调的,所以给优化带来了困难。同时只通过角度的余弦值来判断属于哪一类的话,会导致类别间的距离有的大有的小,降低了区分能力。

CosFace 的做法是将特征向量 z 也进行正则化,同时加上两个超参数 s 和 m:

其中 s 是放缩系数,而 m 是 margin。直观理解是:将垂直平分线交点往特征空间“原点”拉近,类别中各样本在特征空间的映射就会更加靠近“原点”,如图1中每个类别画出的“圆弧”会更加小,及增大了决策边界的角度。

其中 s 和 m 的选择颇为讲究:

其中 K 是特征维度,C 是数据集中类别数,PW 是 expected minimum posterior probability of class center。随着类别数的增加,类别之间 cosine margin 的上限相应地减少。

ArcFace

论文:https://arxiv.org/pdf/1801.07698.pdf

对全连接层的分类输出进行调整,再计算交叉熵损失。

从 softmax loss 的看到这里,这张图就不难理解了。

对全连接层的输入和权重进行正则化后结果为 cosθ,再将其乘上一个超参数 s:

将 cos(θyi) 用 cos(θyi+m) 代替,这部分是 ArcFace 的核心,其背后的意义是是直接在角度空间(angular space)中最大化分类界限。而 CosFace 是将类别映射的更紧凑以期望来达到最大化分类界限的目的,与 ArcFace 在公式上的区别就是增加 margin 的位置。m 为超参数 margin:

下面代码算出的是 L3 中的指数函数 e 的输入,用 ArcFace 进行调整后输入到交叉熵损失中输出损失:

class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0,
                 m=0.30, easy_margin=False, ls_eps=0.0, device="cpu"):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.device = device

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # F.normalize(input)、F.normalize(self.weight) 是公式中对输入和权重的正则化
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m    # 三角公式
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        # one_hot = torch.zeros(cosine.size(), device=self.device)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

AdaCos

论文:https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhang_AdaCos_Adaptively_Scaling_Cosine_Logits_for_Effectively_Learning_Deep_Face_CVPR_2019_paper.pdf

这篇文章中对 ArcFace 中的两个超参数 s 和 m 进行了消融实验,P 是 softmax 后归一化的后验概率:

对 s:

当 s 过小时,模型的分类概率达不到 1,这样模型无法做出“自信”的判断,就导致损失函数惩罚了正例;而当 S 过大时,模型过于自信,这时损失函数无法正确地惩罚负例。

对 m:

当 m 过大,当 θ 变得稍大时,模型就不会将其判断为这一类,可以证明加上了 margin 的损失函数比不加 margin 的损失函数使模型的预测的细粒度更小。

Sub-center ArcFace

论文:https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf

解决了 Sphere Face、CosFace 和 ArcFace 对噪声数据敏感的问题,免去了数据清洗的工作更贴近日常生活中的数据。为每一类设定 K 个子中心而不是像之前的做法每一类中只有一个。这样样本的大部分都会靠近 dominant centers,而那些 noisy / hard sample 则会被推向其他的 undominant centers。

但这样做破坏了类内的紧致性,对此文章中的做法是当网络具有足够的识别能力后,直接去掉那些 undominant centers。同时引入了一个恒定的角度阈值来降低高置信噪声数据,在此之后在自动清理的数据集上从头开始重新训练模型。

其中

对应下列代码只是公式中 cosine, _ = torch.max(cosine_all, dim=2) 这个操作:

class ArcMarginProduct_subcenter(nn.Module):
    def __init__(self, in_features, out_features, k):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine

下面大部分都是 ArcFace 中的操作,其中 phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) 理解为引入了一个恒定的角度阈值来降低高置信噪声数据:

class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    def __init__(self, margins, out_dim, s):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.register_buffer('margins', torch.tensor(margins, device="cuda:0"))
        self.out_dim = out_dim

    def forward(self, logits, labels):
        # ms = []
        # ms = self.margins[labels.cpu().numpy()]
        ms = self.margins[labels]
        cos_m = torch.cos(ms)  # torch.from_numpy(np.cos(ms)).float().cuda()
        sin_m = torch.sin(ms)  # torch.from_numpy(np.sin(ms)).float().cuda()
        th = torch.cos(math.pi - ms)  # torch.from_numpy(np.cos(math.pi - ms)).float().cuda()
        mm = torch.sin(math.pi - ms) * ms  # torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda()
        labels = F.one_hot(labels, self.out_dim)
        labels = labels.half() if CFG.MIXED_PRECISION else labels.float()
        cosine = logits
        sine = torch.sqrt(1.0 - cosine * cosine)
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        loss = self.crit(output, labels)
        return loss

ArcFace with Dynamic Margin

为了应对严重的类别不均衡而提出,样本更少的类应该具有更大的 margin,以期望更好地与其他类分开。每一类的 margin:

其中 a 和 b 控制着 margin 的上下界,n 为各类别的样本数,λ 控制着这个函数的形状。

参考

Deep Metric Learning: a (Long) Survey – Chan Kha Vu

有关度量学习——总结的更多相关文章

  1. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  2. SPI接收数据异常问题总结 - 2

    SPI接收数据左移一位问题目录SPI接收数据左移一位问题一、问题描述二、问题分析三、探究原理四、经验总结最近在工作在学习调试SPI的过程中遇到一个问题——接收数据整体向左移了一位(1bit)。SPI数据收发是数据交换,因此接收数据时从第二个字节开始才是有效数据,也就是数据整体向右移一个字节(1byte)。请教前辈之后也没有得到解决,通过在网上查阅前人经验终于解决问题,所以写一个避坑经验总结。实际背景:MCU与一款芯片使用spi通信,MCU作为主机,芯片作为从机。这款芯片采用的是它规定的六线SPI,多了两根线:RDY和INT,这样从机就可以主动请求主机给主机发送数据了。一、问题描述根据从机芯片手

  3. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  4. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  5. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  6. ruby - 我如何学习 ruby​​ 的正则表达式? - 2

    如何学习ruby​​的正则表达式?(对于假人) 最佳答案 http://www.rubular.com/在Ruby中使用正则表达式时是一个很棒的工具,因为它可以立即将结果可视化。 关于ruby-我如何学习ruby​​的正则表达式?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/1881231/

  7. 深度学习12. CNN经典网络 VGG16 - 2

    深度学习12.CNN经典网络VGG16一、简介1.VGG来源2.VGG分类3.不同模型的参数数量4.3x3卷积核的好处5.关于学习率调度6.批归一化二、VGG16层分析1.层划分2.参数展开过程图解3.参数传递示例4.VGG16各层参数数量三、代码分析1.VGG16模型定义2.训练3.测试一、简介1.VGG来源VGG(VisualGeometryGroup)是一个视觉几何组在2014年提出的深度卷积神经网络架构。VGG在2014年ImageNet图像分类竞赛亚军,定位竞赛冠军;VGG网络采用连续的小卷积核(3x3)和池化层构建深度神经网络,网络深度可以达到16层或19层,其中VGG16和VGG

  8. 机器学习——时间序列ARIMA模型(四):自相关函数ACF和偏自相关函数PACF用于判断ARIMA模型中p、q参数取值 - 2

    文章目录1、自相关函数ACF2、偏自相关函数PACF3、ARIMA(p,d,q)的阶数判断4、代码实现1、引入所需依赖2、数据读取与处理3、一阶差分与绘图4、ACF5、PACF1、自相关函数ACF自相关函数反映了同一序列在不同时序的取值之间的相关性。公式:ACF(k)=ρk=Cov(yt,yt−k)Var(yt)ACF(k)=\rho_{k}=\frac{Cov(y_{t},y_{t-k})}{Var(y_{t})}ACF(k)=ρk​=Var(yt​)Cov(yt​,yt−k​)​其中分子用于求协方差矩阵,分母用于计算样本方差。求出的ACF值为[-1,1]。但对于一个平稳的AR模型,求出其滞

  9. Unity Shader 学习笔记(5)Shader变体、Shader属性定义技巧、自定义材质面板 - 2

    写在之前Shader变体、Shader属性定义技巧、自定义材质面板,这三个知识点任何一个单拿出来都是一套知识体系,不能一概而论,本文章目的在于将学习和实际工作中遇见的问题进行总结,类似于网络笔记之用,方便后续回顾查看,如有以偏概全、不祥不尽之处,还望海涵。1、Shader变体先看一段代码......Properties{ [KeywordEnum(on,off)]USL_USE_COL("IsUseColorMixTex?",int)=0 [Toggle(IS_RED_ON)]_IsRed("IsRed?",int)=0}......//中间省略,后续会有完整代码 #pragmamulti_c

  10. Simulink方法总结和避坑指南(一)——Simulink入门与基本调试方法 - 2

    文章目录一、项目场景二、基本模块原理与调试方法分析——信源部分:三、信号处理部分和显示部分:四、基本的通信链路搭建:四、特殊模块:interpretedMATLABfunction:五、总结和坑点提醒一、项目场景  最近一个任务是使用simulink搭建一个MIMO串扰消除的链路,并用实际收到的数据进行测试,在搭建的过程中也遇到了不少的问题(当然这比vivado里面的debug好不知道多少倍)。准备趁着这个机会,先以一个很基本的通信链路对simulink基础和相关的debug方法进行总结。  在本篇中,主要记录simulink的基本原理和基本的SISO通信传输链路(QPSK方式),计划在下篇记

随机推荐