草庐IT

ID3 决策树的原理、构造及可视化(附完整源代码)

白白净净吃了没病 2024-06-02 原文

目录

一、本文的问题定义和(决策树中)信息熵的回顾

① 本文的问题定义

②(决策树中)信息熵的回顾

二、ID3 决策树的原理及构造

三、ID3 决策树的可视化源码(含构造过程)

四、ID3 决策树可视化的效果及测试结果

① ID3 决策树可视化的效果

② ID3 决策树的文本化结果和用例的测试结果

五、ID3 算法的优缺点


说明:

1、第一节至第三节来源于《机器学习及应用》李克清 时允田主编一书,大约在 57 页的位置。

2、源代码部分是我根据书中原理并参考源码后,自己重写。其中,源代码中的变量的定义对应第二节介绍的原理部分的数学符号,以便于适合对应学习。源代码中的注释是根据自己的理解所写。

3、本文是自己的学习过程的记录,还望读者海涵。如果有幸对大家产生帮助,不胜感激。


一、本文的问题定义和(决策树中)信息熵的回顾

① 本文的问题定义


②(决策树中)信息熵的回顾

(决策树中的)信息熵和样本分类的信息熵计算源代码_白白净净吃了没病的博客-CSDN博客本文包含(决策树)中的信息熵和实现源码,以及极为详细的注释和说明https://angxiao.blog.csdn.net/article/details/127156554


二、ID3 决策树的原理及构造

本文不写任何复杂通用的公式,以书中的例子作为本文的例子,个人觉得能够更通俗易懂:

继续对其他属性逐个计算信息增益,直到不能划分为止。在这个过程,不断找到具有最大信息熵的特征,采用递归思想来构造子树,最终构造出 ID3 的分类决策树。


三、ID3 决策树的可视化源码(含构造过程)

① main.py

from math import log2

# import treePlotter # 导入失败
import help  # 新建模块,然后导入


class ID3Tree(object):
    def __init__(self, data, cols):
        self.all_data = data  # 初始化数据集
        self.all_cols = cols  # 初始化数据集的各个列名(类别名)
        self.tree = {}  # 初始化ID3决策树

    def train(self):
        self.tree = self.make_tree(self.all_data, self.all_cols)

    def make_tree(self, data, cols):
        """
        :param data: 数据集:可能是总集,也可能是子集
        :param cols: 列名(特征名):可能是全列名,也可能是当前数据集去掉最大信息熵特征后的列名集
        :return:树
        """
        all_label_datas = [item[-1] for item in data]  # 所有的标签对应的所有值
        # 1、如果这个数据集的全部标签都是一样的,那么没有属性划分的必要,决策树就一个叶子节点(即拿这个标签直接作为决策)
        if all_label_datas.count(all_label_datas[0]) == len(all_label_datas):
            return all_label_datas[0]
        # 2、如果这个数据集中每条数据(默认每条数据的长度和格式都一样)没有任何属性,只有标签
        # 我们看哪类标签出现的次数最多,直接拿它作为决策结果,这种情况决策树也就一个叶子结点
        elif len(data[0]) == 1:
            # 初始化出现的次数
            max_num = all_label_datas.count(all_label_datas[0])
            # 初始化出现次数最多的标签
            max_sort_data = all_label_datas[0]
            # set对原列表去重,但不改变原列表
            for i in list(set(all_label_datas)):
                if all_label_datas.count(all_label_datas[i]) > max_num:
                    max_num = all_label_datas.count(all_label_datas[i])
                    max_sort_data = i
            return max_sort_data
        # 3、正常情况,我们来构建决策树
        # *** 选取信息熵最大的属性(特征)***
        best_xns_feature_index = self.find_best_xns_feature(data)  # 找到香农熵最大的特征的下标
        best_feature_label = cols[best_xns_feature_index]  # 找到香农熵最大的特征的名称
        tree = {best_feature_label: {}}  # 构造一个(新的)树结点,一个根节点,大括号是子树
        del (cols[best_xns_feature_index])  # 删除数据集中香农熵最大的特征所在的列
        # 抽取最大增益的特征对应的列的数据
        best_xns_feature_values = [item[best_xns_feature_index] for item in data]
        for value in list(set(best_xns_feature_values)):
            # 此时的all_data是上次all_data去掉一列特征得到的
            sub_cols = cols
            sub_data = self.construct_new_dataset(data, best_xns_feature_index, value)
            # 递归构造子树
            tree[best_feature_label][value] = self.make_tree(sub_data, sub_cols)  # 向子树中放入值
        return tree

    def find_best_xns_feature(self, data):
        """
        计算各个特征的香农熵的大小,并返回香农熵最大的特征的下标
        :return: 香农熵最大的特征的下标
        """
        data_num = len(data)  # 数据集中样本的总数
        feature_nums = len(data[0]) - 1  # 数据集中所有特征的数量,-1是因为数据中不止有特征,还有标签
        I = self.calculate_xns(data)  # 数据集(样本标签)的香农熵
        best_xns_feature_value = 0  # 初始化香农熵最大的特征的值
        best_xns_feature_index = -1  # 初始化香农熵最大的特征的下标

        for i in range(feature_nums):
            feat_values = [number[i] for number in data]  # 得到某个特征列(随机变量)下的所有值
            feat_sorts = set(feat_values)  # 去重,得到特征的所有无重复的取值
            E = 0  # 初始化当前特征的信息熵
            # 对当前特征下具有相同特征值的子集,根据正负样本算出信息熵,并乘以prob。在不同特征值下计算完后,进行加和,得到E
            for value in feat_sorts:
                sub_dataset = self.construct_new_dataset(data, i, value)  # 得到i特征下,特征值为value的数据,去除特征i构成的集合
                prob = len(sub_dataset) / float(data_num)  # 特征i的值为value的数据所占的比例
                E += prob * self.calculate_xns(sub_dataset)
            # 用 I 减去 E,得到当前特征的信息增益gain
            gain = I - E  # 当前i特征的信息增益
            # 保留最大的信息熵及其对应的特征索引
            if gain > best_xns_feature_value:
                best_xns_feature_value = gain
                best_xns_feature_index = i

        return best_xns_feature_index  # 返回最大信息增益的特征的下标

    def construct_new_dataset(self, data, axis, value):
        """
        从数据集的某个特征中,选取值为某个特征值的数据,并去掉此特征,然后将这类数据构成新的数据集
        比如,在性别这个特征中,把特征值是男的数据抽出来,然后把这些数据的性别列去掉,构成数据集
        :param data:数据集
        :param axis:数据集中某个特征在数据中的索引
        :param value:此特征下的一个特征值
        :return:数据集中特征值是给定特征值的数据构成的子集
        """
        remain_dataset = []
        for item in data:  # 数据集中的每条数据
            if item[axis] == value:  # 如果这条数据的特征等于给定的某个特征值时
                # 把此条数据去掉这个特征列,重构此条数据
                remain_data = item[:axis]
                remain_data.extend(item[axis + 1:])
                remain_dataset.append(remain_data)  # 将重构后的数据加入列表中
        return remain_dataset

    def calculate_xns(self, data):
        """
        计算给定数据集的香农熵(信息熵)
        :return:数据集的香农熵
        """
        xns = 0.0  # 香农熵
        data_num = len(data)  # 样本集的总数,用于计算分类标签出现的概率

        # 将数据集样本标签的特征值(分类值)放入列表
        all_labels = [c[-1] for c in data]  # c[-1]:即取数据集中的每条数据的标签:Yes 或 No
        # print(all_labels)  # 得到 [Yes,No,No,...] 的结果
        # 按标签的种类进行统计,Yes这一类几个;No这一类几个
        every_label = {}  # 以词典形式存储每个类别(键)及个数(值)
        for item in list(set(all_labels)):  # 对每个类别计数,并放入词典, 其中set(all_labels) = [Yes,No]
            every_label[item] = all_labels.count(item)
        # 计算样本标签的香农熵,即数据集的香农熵
        for item2 in every_label:
            prob = every_label[item2] / float(data_num)  # 每个特征值出现的概率
            xns -= prob * log2(prob)  # xns是全局变量,这样就可以计算关于决策的要考虑的某个随机变量(如收入特征)的香农熵
        return xns


if __name__ == "__main__":
    dataset = [['sunny', 'hot', 'high', 'weak', 'NO'],
               ['sunny', 'hot', 'high', 'strong', 'NO'],
               ['overcast', 'hot', 'high', 'weak', 'YES'],
               ['rain', 'mild', 'high', 'weak', 'YES'],
               ['rain', 'cool', 'normal', 'weak', 'YES'],
               ['rain', 'cool', 'normal', 'strong', 'NO'],
               ['overcast', 'cool', 'normal', 'strong', 'YES'],
               ['sunny', 'mild', 'high', 'weak', 'NO'],
               ['sunny', 'cool', 'normal', 'weak', 'YES'],
               ['rain', 'mild', 'normal', 'weak', 'YES'],
               ['sunny', 'mild', 'normal', 'strong', 'YES'],
               ['overcast', 'mild', 'high', 'strong', 'YES'],
               ['overcast', 'hot', 'normal', 'weak', 'YES'],
               ['rain', 'mild', 'high', 'strong', 'NO']]
    # 前四列的名字(特征列)分别为天气、温度、湿度、风速
    labels = ['Outlook', 'Temp', 'Humidity', 'Windy']
    id3 = ID3Tree(dataset, labels)  # 实例化决策树对象
    id3.train()
    print(id3.tree)  # 输出决策树
    # treeplotter.createPlot(id3.tree) # 因treePlotter不能直接导入,这里会报错
    help.createPlot(id3.tree)  # 可视化决策树

    # 给定新一天的气象数据指标,根据决策树,来判断是否会去打球
    def predict_play(tree, new_dic):
        """
        根据构造的决策树,对未知数据进行预测
        :param tree: 决策树(根据已知数据构造的)
        :param new_dic: 一条待预测的数据
        :return:返回叶子节点,也就是最终的决策
        """
        while type(tree).__name__ == "dict":
            key = list(tree.keys())[0]
            tree = tree[key][new_dic[key]]
        return tree


    # 输出决策结果
    print(predict_play(id3.tree, {'Outlook': 'rain', 'Temp': 'mild', 'Humidity': 'high', 'Windy': 'weak'}))

② help.py 

由于 treePlotter这个模块一直导入失败,目前未知原因。因此使用并在 main.py 中导入以下这个模块,用于构建 ID3 决策树。

import matplotlib.pyplot as plt

"""绘决策树的函数"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 定义分支点的样式
leafNode = dict(boxstyle="round4", fc="0.8")  # 定义叶节点的样式
arrow_args = dict(arrowstyle="<-")  # 定义箭头标识样式


# 计算树的叶子节点数量
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# 计算树的最大深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 画出节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)


# 标箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.x0ff + \
              (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

四、ID3 决策树可视化的效果及测试结果

① ID3 决策树可视化的效果


② ID3 决策树的文本化结果和用例的测试结果


五、ID3 算法的优缺点

ID3 算法是决策树的经典构建算法,它根据信息增益来评估和选择特征进行划分,每次选择信息增益最大的特征作为判断的模块(即特征节点),可用于划分标称型数据集(即数据中没有缺省特征值的数据集),虽然 ID3 比较灵活方便,但是有以下几个缺点:

(1)采用信息增益进行分裂,缺少剪枝的过程,很可能会出现过拟合的问题。我们可以合并相邻的无法产生大量信息增益的叶子结点(如设置信息增益阈值)。

(2)信息增益和属性的值域范围成正比,也就是有些特征(属性)取值很多,ID3算法很大可能将其作为分裂属性,导致分裂的精确度可能没有采用信息增益率进行分裂高。

(3)不能处理连续分布的数据特征,只能通过将连续性数据转化为离散型数据来解决,也不能处理数据集中的缺省值。

有关ID3 决策树的原理、构造及可视化(附完整源代码)的更多相关文章

  1. ruby - 如何在 buildr 项目中使用 Ruby 代码? - 2

    如何在buildr项目中使用Ruby?我在很多不同的项目中使用过Ruby、JRuby、Java和Clojure。我目前正在使用我的标准Ruby开发一个模拟应用程序,我想尝试使用Clojure后端(我确实喜欢功能代码)以及JRubygui和测试套件。我还可以看到在未来的不同项目中使用Scala作为后端。我想我要为我的项目尝试一下buildr(http://buildr.apache.org/),但我注意到buildr似乎没有设置为在项目中使用JRuby代码本身!这看起来有点傻,因为该工具旨在统一通用的JVM语言并且是在ruby中构建的。除了将输出的jar包含在一个独特的、仅限ruby​​

  2. ruby-on-rails - Rails 源代码 : initialize hash in a weird way? - 2

    在rails源中:https://github.com/rails/rails/blob/master/activesupport/lib/active_support/lazy_load_hooks.rb可以看到以下内容@load_hooks=Hash.new{|h,k|h[k]=[]}在IRB中,它只是初始化一个空哈希。和做有什么区别@load_hooks=Hash.new 最佳答案 查看rubydocumentationforHashnew→new_hashclicktotogglesourcenew(obj)→new_has

  3. ruby-on-rails - 浏览 Ruby 源代码 - 2

    我的主要目标是能够完全理解我正在使用的库/gem。我尝试在Github上从头到尾阅读源代码,但这真的很难。我认为更有趣、更温和的踏脚石就是在使用时阅读每个库/gem方法的源代码。例如,我想知道RubyonRails中的redirect_to方法是如何工作的:如何查找redirect_to方法的源代码?我知道在pry中我可以执行类似show-methodmethod的操作,但我如何才能对Rails框架中的方法执行此操作?您对我如何更好地理解Gem及其API有什么建议吗?仅仅阅读源代码似乎真的很难,尤其是对于框架。谢谢! 最佳答案 Ru

  4. ruby - 模块嵌套代码风格偏好 - 2

    我的假设是moduleAmoduleBendend和moduleA::Bend是一样的。我能够从thisblog找到解决方案,thisSOthread和andthisSOthread.为什么以及什么时候应该更喜欢紧凑语法A::B而不是另一个,因为它显然有一个缺点?我有一种直觉,它可能与性能有关,因为在更多命名空间中查找常量需要更多计算。但是我无法通过对普通类进行基准测试来验证这一点。 最佳答案 这两种写作方法经常被混淆。首先要说的是,据我所知,没有可衡量的性能差异。(在下面的书面示例中不断查找)最明显的区别,可能也是最著名的,是你的

  5. ruby - Ruby 中的波形可视化 - 2

    我即将开始一个将录制和编辑音频文件的项目,我正在寻找一个好的库(最好是Ruby,但会考虑Java或.NET以外的任何库)以进行实时可视化波形。有人知道我应该从哪里开始搜索吗? 最佳答案 要流入浏览器的数据量很大。Flash或Flex图表可能是唯一能提高内存效率的解决方案。Javascript图表往往会因大型数据集而崩溃。 关于ruby-Ruby中的波形可视化,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.c

  6. ruby - 寻找通过阅读代码确定编程语言的ruby gem? - 2

    几个月前,我读了一篇关于ruby​​gem的博客文章,它可以通过阅读代码本身来确定编程语言。对于我的生活,我不记得博客或gem的名称。谷歌搜索“ruby编程语言猜测”及其变体也无济于事。有人碰巧知道相关gem的名称吗? 最佳答案 是这个吗:http://github.com/chrislo/sourceclassifier/tree/master 关于ruby-寻找通过阅读代码确定编程语言的rubygem?,我们在StackOverflow上找到一个类似的问题:

  7. ruby - Net::HTTP 获取源代码和状态 - 2

    我目前正在使用以下方法获取页面的源代码:Net::HTTP.get(URI.parse(page.url))我还想获取HTTP状态,而无需发出第二个请求。有没有办法用另一种方法做到这一点?我一直在查看文档,但似乎找不到我要找的东西。 最佳答案 在我看来,除非您需要一些真正的低级访问或控制,否则最好使用Ruby的内置Open::URI模块:require'open-uri'io=open('http://www.example.org/')#=>#body=io.read[0,50]#=>"["200","OK"]io.base_ur

  8. 程序员如何提高代码能力? - 2

    前言作为一名程序员,自己的本质工作就是做程序开发,那么程序开发的时候最直接的体现就是代码,检验一个程序员技术水平的一个核心环节就是开发时候的代码能力。众所周知,程序开发的水平提升是一个循序渐进的过程,每一位程序员都是从“菜鸟”变成“大神”的,所以程序员在程序开发过程中的代码能力也是根据平时开发中的业务实践来积累和提升的。提高代码能力核心要素程序员要想提高自身代码能力,尤其是新晋程序员的代码能力有很大的提升空间的时候,需要针对性的去提高自己的代码能力。提高代码能力其实有几个比较关键的点,只要把握住这些方面,就能很好的、快速的提高自己的一部分代码能力。1、多去阅读开源项目,如有机会可以亲自参与开源

  9. 7个大一C语言必学的程序 / C语言经典代码大全 - 2

    嗨~大家好,这里是可莉!今天给大家带来的是7个C语言的经典基础代码~那一起往下看下去把【程序一】打印100到200之间的素数#includeintmain(){ inti; for(i=100;i 【程序二】输出乘法口诀表#includeintmain(){inti;for(i=1;i 【程序三】判断1000年---2000年之间的闰年#includeintmain(){intyear;for(year=1000;year 【程序四】给定两个整形变量的值,将两个值的内容进行交换。这里提供两种方法来进行交换,第一种为创建临时变量来进行交换,第二种是不创建临时变量而直接进行交换。1.创建临时变量来

  10. git使用常见问题(提交代码,合并冲突) - 2

    文章目录git常用命令(简介,详细参数往下看)Git提交代码步骤gitpullgitstatusgitaddgitcommitgitpushgit代码冲突合并问题方法一:放弃本地代码方法二:合并代码常用命令以及详细参数gitadd将文件添加到仓库:gitdiff比较文件异同gitlog查看历史记录gitreset代码回滚版本库相关操作远程仓库相关操作分支相关操作创建分支查看分支:gitbranch合并分支:gitmerge删除分支:gitbranch-ddev查看分支合并图:gitlog–graph–pretty=oneline–abbrev-commit撤消某次提交git用户名密码相关配置g

随机推荐