草庐IT

pytorch 计算混淆矩阵

:)�东东要拼命 2023-12-14 原文

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏

 预测对了 为对角线 

还可以通过矩阵的上下角发现哪些容易出错

从这个 矩阵出发 可以得到 acc != precision recall  特异度?

 

 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率是什么 什么是准确率 什么是召回率_:)�东东要拼命的博客-CSDN博客

 acc  是对所有类别来说的

其他三个都是 对于类别来说的

下面给出源码 

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable
from torchvision import datasets
from torchvision.models import MobileNetV2
from torchvision.transforms import transforms


class ConfusionMatrix(object):
    """
    注意版本问题,使用numpy来进行数值计算的
    """

    def __init__(self, num_classes: int, labels: list):
            self.matrix = np.zeros((num_classes, num_classes))
            self.num_classes = num_classes
            self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[t, p] += 1

# 行代表预测标签 列表示真实标签




    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("acc is", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.fields_names = ["", "pre", "recall", "spec"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            pre = round(TP / (TP + FP), 3)    # round 保留三位小数
            recall = round(TP / (TP + FN), 3)
            spec = round(TN / (FP + FN), 3)
            table.add_row([self.labels[i], pre, recall, spec])
        print(table)


    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)  # 颜色变化从白色到蓝色

        # 设置 x  轴坐标 label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 将原来的 x 轴的数字替换成我们想要的信息 self.num_classes  x 轴旋转45度
        # 设置 y  轴坐标 label
        plt.yticks(range(self.num_classes), self.labels)

        # 显示 color bar  可以通过颜色的密度看出数值的分布
        plt.colorbar()
        plt.xlabel("true_label")
        plt.ylabel("Predicted_label")
        plt.title("ConfusionMatrix")

        # 在图中标注数量 概率信息
        thresh = matrix.max() / 2
        # 设定阈值来设定数值文本的颜色 开始遍历图像的时候一般是图像的左上角
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 这里矩阵的行列交换,因为遍历的方向 第y行 第x列
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        # 图形显示更加的紧凑
        plt.show()



if __name__ ==' __main__':
    device = torch.device("cuda:0" if torch.cuda.is_available()else "cpu")
    print(device)
    # 使用验证集的预处理方式
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor()
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    data_loot = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # get data root path
    image_path = data_loot + "/data_set/flower_data/"
    # flower data set path

    validate_dataset = datasets.ImageFolder(root=image_path +"val",
                                            transform=data_transform)

    batch_size = 16
    validate_loader = torch.utils.data.DataLoder(validate_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

    net = MobileNetV2(num_classes=5)
    #加载预训练的权重
    model_weight_path = "./MobileNetV2.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    #read class_indict
    try:
        json_file = open('./class_indicts.json', 'r')
        class_indict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)


    labels = [label for _, label in class_indict.item()]
    # 通过json文件读出来的label
    confusion = ConfusionMatrix(num_classes=5, labels=labels)
    net.eval()
    # 启动验证模式
    # 通过上下文管理器  no_grad  来停止pytorch的变量对梯度的跟踪
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # 获取概率最大的元素
            confusion.update(outputs.numpy(), val_labels.numpy())
            # 预测值和标签值
    confusion.plot()
    # 绘制混淆矩阵
    confusion.summary()
    # 来打印各个指标信息
































是这样的 这篇算是一个学习笔记,其中的基础图都源于我的导师

 霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

欢迎无依无靠的CV同学加入 

讲的非常好 代码其实也是导师给的 

我能做的就是读懂每一行加点注释

给不想看视频的同学留点时间

有关pytorch 计算混淆矩阵的更多相关文章

  1. ruby-on-rails - 使用一系列等级计算字母等级 - 2

    这里是Ruby新手。完成一些练习后碰壁了。练习:计算一系列成绩的字母等级创建一个方法get_grade来接受测试分数数组。数组中的每个分数应介于0和100之间,其中100是最大分数。计算平均分并将字母等级作为字符串返回,即“A”、“B”、“C”、“D”、“E”或“F”。我一直返回错误:avg.rb:1:syntaxerror,unexpectedtLBRACK,expecting')'defget_grade([100,90,80])^avg.rb:1:syntaxerror,unexpected')',expecting$end这是我目前所拥有的。我想坚持使用下面的方法或.join,

  2. 旋转矩阵的几何意义 - 2

    点向量坐标矩阵的几何意义介绍旋转矩阵的几何含义之前,先介绍一下点向量坐标矩阵的几何含义点:在一维空间下就是一个标量,如同一条直线上,以任意某一个位置为0点,以一定的尺度间隔为1,2,3...,相反方向为-1,-2,-3...;如此就形成了一维坐标系,这时候任何一个点都可以用一个数值表示,如点p1=5,即即从原点出发沿着x轴正方向移动5个尺度;点p2=-3,负方向移动3个尺度;     在一维坐标系上过原点做垂直于一维坐标系的直线,则形成了二维坐标系,此时描述一个点需要两个数值来表示点p3=(3,2),即从原点出发沿着x轴正方向移动3个尺度,在此基础上沿着y轴正方向移动两个尺度的位置就是点p3。

  3. 计算机毕业设计ssm+vue基本微信小程序的小学生兴趣延时班预约小程序 - 2

    项目介绍随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱小学生兴趣延时班预约小程序的设计与开发被用户普遍使用,为方便用户能够可以随时进行小学生兴趣延时班预约小程序的设计与开发的数据信息管理,特开发了小程序的设计与开发的管理系统。小学生兴趣延时班预约小程序的设计与开发的开发利用现有的成熟技术参考,以源代码为模板,分析功能调整与小学生兴趣延时班预约小程序的设计与开发的实际需求相结合,讨论了小学生兴趣延时班预约小程序的设计与开发的使用。开发环境开发说明:前端使用微信微信小程序开发工具:后端使用ssm:VU

  4. ruby - 如何计算 Liquid 中的变量 +1 - 2

    我对如何计算通过{%assignvar=0%}赋值的变量加一完全感到困惑。这应该是最简单的任务。到目前为止,这是我尝试过的:{%assignamount=0%}{%forvariantinproduct.variants%}{%assignamount=amount+1%}{%endfor%}Amount:{{amount}}结果总是0。也许我忽略了一些明显的东西。也许有更好的方法。我想要存档的只是获取运行的迭代次数。 最佳答案 因为{{incrementamount}}将输出您的变量值并且不会影响{%assign%}定义的变量,我

  5. ruby - 使用 Ruby,计算 n x m 数组的每一列中有多少个 true 的简单方法是什么? - 2

    给定一个nxmbool数组:[[true,true,false],[false,true,true],[false,true,true]]有什么简单的方法可以返回“该列中有多少个true?”结果应该是[1,3,2] 最佳答案 使用转置得到一个数组,其中每个子数组代表一列,然后将每一列映射到其中的true数:arr.transpose.map{|subarr|subarr.count(true)}这是一个带有inject的版本,应该在1.8.6上运行,没有任何依赖:arr.transpose.map{|subarr|subarr.in

  6. arrays - 计算数组中的匹配元素 - 2

    给定两个大小相等的数组,如何找到不考虑位置的匹配元素的数量?例如:[0,0,5]和[0,5,5]将返回2的匹配项,因为有一个0和一个5共同;[1,0,0,3]和[0,0,1,4]将返回3的匹配项,因为0有两场,1有一场;[1,2,2,3]和[1,2,3,4]将返回3的匹配项。我尝试了很多想法,但它们都变得相当粗糙和令人费解。我猜想有一些不错的Ruby习惯用法,或者可能是一个正则表达式,可以很好地回答这个解决方案。 最佳答案 您可以使用count完成它:a.count{|e|index=b.index(e)andb.delete_at

  7. ruby - 关于 Ruby 中 Dir[] 和 File.join() 的混淆 - 2

    我在Ruby中遇到了一个关于Dir[]和File.join()的简单程序,blobs_dir='/path/to/dir'Dir[File.join(blobs_dir,"**","*")].eachdo|file|FileUtils.rm_rf(file)ifFile.symlink?(file)我有两个困惑:首先,File.join(@blobs_dir,"**","*")中的第二个和第三个参数是什么意思?其次,Dir[]在Ruby中有什么用?我只知道它等价于Dir.glob(),但是,我对Dir.glob()确实不是很清楚。 最佳答案

  8. ruby-on-rails - 如何计算 Ruby/Rails 中 JSON 对象的数量 - 2

    Ruby中如何“一般地”计算以下格式(有根、无根)的JSON对象的数量?一般来说,我的意思是元素可能不同(例如“标题”被称为其他东西)。没有根:{[{"title":"Post1","body":"Hello!"},{"title":"Post2","body":"Goodbye!"}]}根包裹:{"posts":[{"title":"Post1","body":"Hello!"},{"title":"Post2","body":"Goodbye!"}]} 最佳答案 首先,withoutroot代码不是有效的json格式。它将没有包

  9. ruby - 如何计算自 Ruby 中给定日期以来的周数? - 2

    目标我正在尝试计算自给定日期以来周的距离,而无需跳过任何步骤。我更喜欢用普通的Ruby来做,但ActiveSupport无疑是一个可以接受的选择。我的代码我写了以下内容,这似乎可行,但对我来说似乎还有很长的路要走。require'date'DAYS_IN_WEEK=7.0defweeks_sincedate_stringdate=Date.parsedate_stringdays=Date.today-dateweeks=days/DAYS_IN_WEEKweeks.round2endweeks_since'2015-06-15'#=>32.57ActiveSupport的#weeks

  10. 华为OD机试真题 C++ 实现【带传送阵的矩阵游离】【2023 Q2 | 200分】 - 2

            所有题目均有五种语言实现。C实现目录、C++实现目录、Python实现目录、Java实现目录、JavaScript实现目录题目n行m列的矩阵,每个位置上有一个元素你可以上下左右行走,代价是前后两个位置元素值差的绝对值.另外,你最多可以使用一次传送阵(只能从一个数跳到另外一个相同的数)求从走上角走到右下角最少需要多少时间。输入描述:第一行两个整数n,m,分别代表矩阵的行和列。后面n行,每行m个整数,分别代表矩阵中的元素。输出描述:一个整数,表示最少需要多少时间。

随机推荐