草庐IT

softmax回归详解

副歌微巫师 2023-04-13 原文

在一些其他场景中,我们的模型输出可能是一个图像类别这样的离散值,对于这样的离散值预测问题,可以使用 softmax 回归的分类模型。

1.1 分类问题

在一个简单图像分类问题中,输入图像的高和宽均是 2 像素,色彩为灰度,可以将图像中的 4 像素分别记为   ,假设训练集中图像的真实标签为狗、猫和鸡,也就是说通过这4种像素可以表示出这三种动物,这些标签对应着  。

1.2 softmax回归模型 

softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等
于标签中的类别数,对每个输入计算出输出:

 softmax回归是单层神经网络,每个输出的计算依赖于所有的输入

 

那么如何将输出的结果转换成对应的类别呢? 首先可以将输出值  中最大的输出所对应的类作为预测输出,例如  分别为 0.1,10,0.1,由于  最大 ,那么预测类别为 2 ,代表猫。

但是输出层的输出值的范围是不确定的,并且真实标签是离散值,难以计算出与输出值的误差。

softmax运算符解决了这些问题,通过下式将输出值变换成值为正并且和为 1 的概率分布:

其中

 

可以看出   ,且都小于1,若 ,无论剩余两个值是多少,图像类别是猫的概率为 80%。

1.3 小批量样本分类矢量计算表达式

给定一个小批量样本,批量大小为  ,输入特征数为  ,输出类别数为  ,设批量特征是,softmax 回归的权重和偏置为 ,计算表达式为:

1.4 交叉熵损失函数

可以使用线性回归那样的平方损失函数    但是想要预测分类结果正确,我们其实并不需要预测概率完全等于标签概率,只需要其中一个预测值比其他的都大就行了,即使  不管其他两个预测值为多少,类别预测均正确,而平方损失则过于严格。

交叉熵刻画的是两个概率分布之间的距离, 代表正确答案, 代表的是预测值,交叉熵越小,两个概率的分布越接近。

其中  为标签值, 为预测值

2.1获取数据集

在 softmax 中使用Fashion-MNIST数据集

导入需要的包

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

通过torch框架内置函数下载Fashion-MNIST数据集并读取到内存中

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

该数据集中包含10个类别的图像,每个类别由训练集中的6000张图像和测试集中的1000张图像组成,训练集和测试集分别包含60000和10000张图像。

len(mnist_train), len(mnist_test)
(60000, 10000)

每个输入图像的高度和宽度均为28像素

torch.Size([1, 28, 28])

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

展示训练集中前几个样本的图像以及标签

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

 2.2读取小批量数据

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
PyTorch 的 DataLoader 中一个很方便的功能是允许用多进程来加速数据读取 。这里我们 设置 4 个进程读取数据。

2.3 初始化模型参数

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

原始数据集中的每个样本都是28×28的图像,我们将展平每个图像,把它们看作长度为784的向量,在softmax回归中,我们的输出与类别一样多。因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2.4 实现softmax运算

def softmax(X):
 X_exp = X.exp()
 partition = X_exp.sum(dim=1, keepdim=True)
 return X_exp / partition

2.5 定义模型

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

2.6 定义损失函数

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)

2.7 计算分类精度

当预测值与标签分类一致时,那么就是正确的。分类精度即正确预测数量与总预测数量之比。所以我们使用argmax获得每行中最大元素的索引来获得预测类别,然后将预测类别与真实标签进行比较。

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

2.8训练模型

和线性回归类似,同样使用小批量随机梯度下降来优化模型的损失函数。
 
num_epochs, lr = 5, 0.1
# 本函数已保存在d2lzh包中⽅便以后使⽤
def train_ch3(net, train_iter, test_iter, loss, num_epochs,
batch_size,params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
 
 # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
 
            l.backward()
            if optimizer is None:
                 d2l.sgd(params, lr, batch_size)
            else:
                 optimizer.step() 
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) ==y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n,
        train_acc_sum / n,test_acc))
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs,
batch_size, [W, b], lr)
epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

2.9 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定⼀系列图像(第三行图像输出),我们比较⼀下它们的真实标签(第⼀行文本输出)和模型预测结果(第二行文本输出)。

X, y = iter(test_iter).next()
true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels =
d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels,
pred_labels)]
d2l.show_fashion_mnist(X[0:9], titles[0:9])

 

有关softmax回归详解的更多相关文章

  1. 物联网MQTT协议详解 - 2

    一、什么是MQTT协议MessageQueuingTelemetryTransport:消息队列遥测传输协议。是一种基于客户端-服务端的发布/订阅模式。与HTTP一样,基于TCP/IP协议之上的通讯协议,提供有序、无损、双向连接,由IBM(蓝色巨人)发布。原理:(1)MQTT协议身份和消息格式有三种身份:发布者(Publish)、代理(Broker)(服务器)、订阅者(Subscribe)。其中,消息的发布者和订阅者都是客户端,消息代理是服务器,消息发布者可以同时是订阅者。MQTT传输的消息分为:主题(Topic)和负载(payload)两部分Topic,可以理解为消息的类型,订阅者订阅(Su

  2. Tcl脚本入门笔记详解(一) - 2

    TCL脚本语言简介•TCL(ToolCommandLanguage)是一种解释执行的脚本语言(ScriptingLanguage),它提供了通用的编程能力:支持变量、过程和控制结构;同时TCL还拥有一个功能强大的固有的核心命令集。TCL经常被用于快速原型开发,脚本编程,GUI和测试等方面。•实际上包含了两个部分:一个语言和一个库。首先,Tcl是一种简单的脚本语言,主要使用于发布命令给一些互交程序如文本编辑器、调试器和shell。由于TCL的解释器是用C\C++语言的过程库实现的,因此在某种意义上我们又可以把TCL看作C库,这个库中有丰富的用于扩展TCL命令的C\C++过程和函数,所以,Tcl是

  3. 【详解】Docker安装Elasticsearch7.16.1集群 - 2

    开门见山|拉取镜像dockerpullelasticsearch:7.16.1|配置存放的目录#存放配置文件的文件夹mkdir-p/opt/docker/elasticsearch/node-1/config#存放数据的文件夹mkdir-p/opt/docker/elasticsearch/node-1/data#存放运行日志的文件夹mkdir-p/opt/docker/elasticsearch/node-1/log#存放IK分词插件的文件夹mkdir-p/opt/docker/elasticsearch/node-1/plugins若你使用了moba,直接右键新建即可如上图所示依次类推创建

  4. 【Elasticsearch基础】Elasticsearch索引、文档以及映射操作详解 - 2

    文章目录概念索引相关操作创建索引更新副本查看索引删除索引索引的打开与关闭收缩索引索引别名查询索引别名文档相关操作新建文档查询文档更新文档删除文档映射相关操作查询文档映射创建静态映射创建索引并添加映射概念es中有三个概念要清楚,分别为索引、映射和文档(不用死记硬背,大概有个印象就可以)索引可理解为MySQL数据库;映射可理解为MySQL的表结构;文档可理解为MySQL表中的每行数据静态映射和动态映射上面已经介绍了,映射可理解为MySQL的表结构,在MySQL中,向表中插入数据是需要先创建表结构的;但在es中不必这样,可以直接插入文档,es可以根据插入的文档(数据),动态的创建映射(表结构),这就

  5. 最强Http缓存策略之强缓存和协商缓存的详解与应用实例 - 2

    HTTP缓存是指浏览器或者代理服务器将已经请求过的资源保存到本地,以便下次请求时能够直接从缓存中获取资源,从而减少网络请求次数,提高网页的加载速度和用户体验。缓存分为强缓存和协商缓存两种模式。一.强缓存强缓存是指浏览器直接从本地缓存中获取资源,而不需要向web服务器发出网络请求。这是因为浏览器在第一次请求资源时,服务器会在响应头中添加相关缓存的响应头,以表明该资源的缓存策略。常见的强缓存响应头如下所述:Cache-ControlCache-Control响应头是用于控制强制缓存和协商缓存的缓存策略。该响应头中的指令如下:max-age:指定该资源在本地缓存的最长有效时间,以秒为单位。例如:Ca

  6. IDEA 2022 创建 Spring Boot 项目详解 - 2

    如何用IDEA2022创建并初始化一个SpringBoot项目?目录如何用IDEA2022创建并初始化一个SpringBoot项目?0. 环境说明1.  创建SpringBoot项目 2.编写初始化代码0. 环境说明IDEA2022.3.1JDK1.8SpringBoot1.  创建SpringBoot项目        打开IDEA,选择NewProject创建项目。        填写项目名称、项目构建方式、jdk版本,按需要修改项目文件路径等信息。        选择springboot版本以及需要的包,此处只选择了springweb。        此处需特别注意,若你使用的是jdk1

  7. 详解Unity中的粒子系统Particle System (二) - 2

    前言上一篇我们简要讲述了粒子系统是什么,如何添加,以及基本模块的介绍,以及对于曲线和颜色编辑器的讲解。从本篇开始,我们将按照模块结构讲解下去,本篇主要讲粒子系统的主模块,该模块主要是控制粒子的初始状态和全局属性的,以下是关于该模块的介绍,请大家指正。目录前言本系列提要一、粒子系统主模块1.阅读前注意事项2.参考图3.参数讲解DurationLoopingPrewarmStartDelayStartLifetimeStartSpeed3DStartSizeStartSize3DStartRotationStartRotationFlipRotationStartColorGravityModif

  8. VMware虚拟机与本地主机进行磁盘共享(详解) - 2

    VMware虚拟机与本地主机进行磁盘共享前提虚拟机版本为Windows10(专业版,不是可能有问题)本地主机为家庭版或学生版(此版本会有问题,但有替代方式)最好是专业版VMware操作1.关闭防火墙,全部关闭。2.打开电脑属性3.点击共享-》高级共享-》权限4.如果没有everyone,就添加权限选择完全控制,然后应用确定。5.打开cmd输入lusrmgr.msc(只有专业版可以打开)如果不是专业版,可以跳过这一步。点击用户-》administrator密码要复杂密码,否则不行。推荐admaiN@1234类型的密码。设置完密码,点击属性,将禁用解开。6.如果虚拟机的windows不是专业版,可

  9. ruby - 逻辑回归给出不正确的结果 - 2

    我在一个网站上工作,收集人们玩过的国际象棋比赛的结果。查看玩家的评分以及他们与对手的评分之间的差异,我绘制了一个图表,其中的点代表获胜(绿色)、平局(蓝色)和失败(红色)。根据这些信息,我还实现了逻辑回归算法来对获胜和获胜/平局的截止值进行分类。使用评级和差异作为我的两个特征,我得到了一个分类器,然后在图表上绘制了分类器改变其预测的边界。我的梯度下降、成本函数和sigmoid函数的代码如下。defgradient_descent()oldJ=0newJ=J()alpha=1.0#Learningraterun=0while(run0.001))thenrun-=20end#Do20mo

  10. Ruby 曲线拟合(对数回归)包 - 2

    我正在寻找进行对数回归(对数方程的曲线拟合)的Rubygem或库。我试过statsample(http://ruby-statsample.rubyforge.org/),但它似乎没有我要找的东西。有人有什么建议吗? 最佳答案 尝试使用“statsample”gem。您可以使用类似的方法执行指数、对数、幂、正弦或任何其他变换。我希望这有帮助。require'statsample'#IndependentVariablex_data=[Math.exp(1),Math.exp(2),Math.exp(3),Math.exp(4),Ma

随机推荐