草庐IT

XGBoost算法介绍

月落乌啼silence 2024-01-22 原文

XGBoost算法介绍

一、简介

  XGBoost(eXtreme Gradient Boosting)又叫极度梯度提升树,是boosting算法的一种实现方式。针对分类或回归问题,效果非常好。在各种数据竞赛中大放异彩,而且在工业界也是应用广泛,主要是因为其效果优异,使用简单,速度快等优点。本文主要从以下几个方面介绍该算法模型:

二、基本原理

  xgb是boosting算法的一种实现方式,主要是降低偏差,也就是降低模型的误差。因此它是采用多个基学习器,每个基学习器都比较简单,避免过拟合,下一个学习器是学习前面基学习器的结果 y i t y^{t}_{i} yit和实际值 y i y_{i} yi的差值,通过多个学习器的学习,不断降低模型值和实际值的差。
y i 0 = 0 y_{i}^{0} = 0 yi0=0
y i 1 = f 1 ( x i ) = y i 0 + f 1 ( x i ) y_{i}^{1} = f_{1}(x_{i}) = y_{i}^{0}+f_{1}(x_{i}) yi1=f1(xi)=yi0+f1(xi)
$ y i 2 = f 1 ( x i ) + f 2 ( x i ) = y i 1 + f 2 ( x i ) y_{i}^{2}=f_{1}(x_{i})+f_{2}(x_{i})=y_{i}^{1}+f_{2}(x_{i}) yi2=f1(xi)+f2(xi)=yi1+f2(xi)
y i t = ∑ k = 1 t f k ( x i ) = y i t − 1 + f t ( x i ) y_{i}^{t}=\sum_{k=1}^{t}f_{k}(x_{i})=y_{i}^{t-1}+f_{t}(x_{i}) yit=k=1tfk(xi)=yit1+ft(xi)
基本思路就是不断生成新的树,每棵树都是基于上一颗树和目标值的差值来进行学习,从而降低模型的偏差。最终模型结果的输出如下: y i = ∑ k = 1 t f k ( x i ) y_{i}=\sum_{k=1}^{t}f_{k}(x_{i}) yi=k=1tfk(xi),即所有树的结果累加起来才是模型对一个样本的预测值。那在每一步如何选择/生成一个较优的树呢?那就是由我们的目标函数来决定。

三、目标函数

  目标函数由两部分组成,一是模型误差,即样本真实值和预测值之间的差值,二是模型的结构误差,即正则项,用于限制模型的复杂度。
O b j ( θ ) = L ( θ ) + Ω ( θ ) = L ( y i , y i t ) + ∑ k = 1 t Ω ( f k ( x i ) ) Obj(\theta)=L(\theta)+\Omega(\theta)=L(y_{i},y_{i}^{t})+\sum_{k=1}^{t}\Omega(f_{k}(x_{i})) Obj(θ)=L(θ)+Ω(θ)=L(yi,yit)+k=1tΩ(fk(xi))
因为 y i t = y i t − 1 + f t ( x i ) y_{i}^{t}=y_{i}^{t-1}+f_{t}(x_{i}) yit=yit1+ft(xi),所以将其带入上面的公式中转换为:
O b j t = ∑ n = 1 n L ( y i , y i t − 1 + f t ( x i ) ) + Ω ( f t ) + ∑ t = 1 T − 1 Ω ( f t ) Obj^{t}=\sum_{n=1}^{n}L(y_{i},y_{i}^{t-1}+f_{t}(x_{i}))+\Omega(f_{t})+\sum_{t=1}^{T-1}\Omega(f_{t}) Objt=n=1nL(yi,yit1+ft(xi))+Ω(ft)+t=1T1Ω(ft),第t颗树的误差由三部分组成,n个样本在第t颗树的误差求和,以及第t颗树的结构误差和前t-1颗树的结构误差。其中前t-1颗树的结构误差是常数,因为我们已经知道前t-1颗树的结构了。
  假设我们的损失函数是平方损失函数(mse),则上述目标函数转换为:
O b j t = ∑ i = 1 n L ( y i , y i t − 1 + f t ( x i ) ) + Ω ( f t ) + ∑ t = 1 T − 1 Ω ( f t ) = ∑ i = 1 n ( y i − ( y i t − 1 + f t ( x i ) ) ) 2 + Ω ( f t ) + c o n s t a n t Obj^{t}=\sum_{i=1}^{n}L(y_{i},y_{i}^{t-1}+f_{t}(x_{i}))+\Omega(f_{t})+\sum_{t=1}^{T-1}\Omega(f_{t}) \\ =\sum_{i=1}^{n}(y_{i}-(y_{i}^{t-1}+f_{t}(x_{i})))^2+\Omega(f_{t})+constant Objt=i=1nL(yi,yit1+ft(xi))+Ω(ft)+t=1T1Ω(ft)=i=1n(yi(yit1+ft(xi)))2+Ω(ft)+constant
上述公式即为损失函数为mse时xgb第t步的目标函数。唯一的变量即为 f t f_{t} ft,此处的损失函数仍然是一个相对复杂的表达式,所以为了简化它,采用二阶泰勒展开来近似表达,即 f ( x + Δ x ) = f ( x ) + f ′ ( x ) Δ x + 1 / 2 f ′ ′ ( x ) Δ x 2 f(x+\Delta x)=f(x)+f^{'}(x)\Delta x+1/2f^{''}(x)\Delta x^2 f(x+Δx)=f(x)+f(x)Δx+1/2f(x)Δx2,所以另 g i = ∂ y i t − 1 l ( y i , y i t − 1 ) g_{i}=\partial _{y_{i}^{t-1}}l(y_{i},y_{i}^{t-1}) gi=yit1l(yi,yit1) h i = ∂ y i t − 1 2 l ( y i , y i t − 1 ) h_{i}=\partial _{y_{i}^{t-1}} ^ 2 l(y_{i},y_{i}^{t-1}) hi=yit12l(yi,yit1),即分别是 l ( y i , y i t − 1 ) l(y_{i},y_{i}^{t-1}) l(yi,yit1)的一阶导和二阶导。则上述损失函数转换为二阶导之后, O b j t = ∑ i = 1 n [ l ( y i , y i t − 1 ) + g i f t ( x ) + 1 / 2 h i f t 2 ( x ) ] + Ω ( f t ) + c o n s t a n t Obj^{t}=\sum_{i=1}^{n}[l(y_{i},y_{i}^{t-1})+g_{i} f_{t}(x_{})+1/2h_{i} f_{t}^2(x)]+\Omega(f_{t})+constant Objt=i=1n[l(yi,yit1)+gift(x)+1/2hift2(x)]+Ω(ft)+constant
  所以当损失函数是mse时, g i = 2 ( y i t − 1 − y i ) g_{i}=2(y_{i}^{t-1}-y_{i}) gi=2(yit1yi) h i = 2 h_{i}=2 hi=2
  经过转换之后,其中第一项是所有样本与第t-1颗树的误差之和,因为第t-1颗树是已知的,所以可以将其视为常数项,我们暂时在目标函数中将其舍去,我们的目标函数变为关于 f t ( x ) f_{t}(x) ft(x)的函数了。而 f t ( x ) f_{t}(x) ft(x)则是关于叶子节点输出 w w w的函数,所以我们的目标函数全部转换为关于 w w w的函数, O b j t = ∑ i = 1 n [ g i f t ( x ) + 1 / 2 h i f t 2 ( x ) ] + Ω ( f t ) + c o n s t a n t = ∑ i = 1 n [ g i w q ( x i ) + 1 / 2 h i w q 2 ( x i ) ] + γ T + 1 / 2 λ ∑ j = 1 T w j 2 = ∑ j = 1 T [ ∑ i ∈ I j ( g i ) ∗ w j + 1 / 2 ∗ ∑ i ∈ I j ( h i + λ ) w j 2 ] + γ T Obj^{t}=\sum_{i=1}^{n}[g_{i} f_{t}(x_{})+1/2h_{i} f_{t}^2(x)]+\Omega(f_{t})+constant \\ =\sum_{i=1}^{n}[g_{i}w_{q}(x_{i})+1/2h_{i}w_{q}^2(x_{i})]+\gamma T+1/2\lambda\sum_{j=1}^{T}w_{j}^{2} \\ =\sum_{j=1}^{T}[\sum_{i \in I_{j}}(g_{i})*w_{j}+1/2*\sum_{i \in I_{j}}(h_{i}+\lambda)w_{j}^2]+\gamma T Objt=i=1n[gift(x)+1/2hift2(x)]+Ω(ft)+constant=i=1n[giwq(xi)+1/2hiwq2(xi)]+γT+1/2λj=1Twj2=j=1T[iIj(gi)wj+1/2iIj(hi+λ)wj2]+γT。我们令 G j = ∑ i ∈ I j ( g i ) G_{j}=\sum_{i \in I_{j}}(g_{i}) Gj=iIj(gi),令 H j = ∑ i ∈ I j ( h i ) H_{j}=\sum i \in I_{j}(h_{i}) Hj=iIj(hi),则我们的目标函数转换为 O b j t = ∑ j = 1 T G j ∗ w j + 1 / 2 ( H j + λ ) ∗ w j 2 + λ T Obj^{t}=\sum_{j=1}^{T}G_{j}*w_{j}+1/2(H_{j}+\lambda)*w_{j}^{2}+\lambda T Objt=j=1TGjwj+1/2(Hj+λ)wj2+λT。在上述表达式中, j 表 示 第 j 个 节 点 j表示第j个节点 jj i 表 示 第 i 个 样 本 i表示第i个样本 ii。所以整个目标函数转换成了关于 w w w即叶节点分数的一元二次函数,想要优化目标函数,就是求解最优的w,因此我们对目标求导,得到 w ∗ = − G i / ( H i + λ ) w^{*}=-G_{i}/(H_{i}+\lambda) w=Gi/(Hi+λ),将 w ∗ w^{*} w代入目标函数中,则目标函数变为 O b j t = − 1 / 2 ∑ j = 1 T G j 2 / ( H j + λ ) + λ T Obj^{t}=-1/2\sum_{j=1}^{T}G_{j}^{2}/(H_{j}+\lambda)+\lambda T Objt=1/2j=1TGj2/(Hj+λ)+λT。如此简单,所以在求解二叉树的目标函数时,只要知道损失函数的一阶导、二阶导,以及样本落在哪个叶子节点上,我们只要求出在每个叶子节点上,该样本的一阶导和二阶导就能求出目标函数。也就能决定是否分裂该节点,依据哪个节点的特征值来进行分裂。

三、节点分裂

   xgb节点是否分裂取决于信息增益的变化,若分裂当前节点,信息增益>0,则进行分裂,若不大于0则不分裂,如何判断分列前后信息增益的变化呢。那就可以使用我们的目标函数来表示了。
G a i n = G L 2 / ( H L + λ ) + G R 2 / ( H R + λ ) − ( G L + G R ) 2 / ( H L + H R + λ ) + γ Gain=G_{L}^{2}/(H_{L}+\lambda)+G_{R}^{2}/(H_{R}+\lambda)-(G_{L}+G_{R})^2/(H_{L}+H_{R}+\lambda)+\gamma Gain=GL2/(HL+λ)+GR2/(HR+λ)(GL+GR)2/(HL+HR+λ)+γ
  节点分裂有两种方式:1、贪心算法,2、近似算法

3.1 贪心算法

  贪心算法分裂的方式就是一种暴力搜索的方式,遍历每一个特征,遍历该特征的每一个取值,计算分裂前后的增益,选择增益最大的特征取值作为分裂点。

分裂流程如上图所示。

3.2 近似算法

   近似算法,其实就是分桶,目的是为了提升计算速度,降低遍历的次数,所以对特征进行分桶。就是将每一个特征的取值按照分位数划分到不同的桶中,利用桶的边界值作为分裂节点的候选集,每次遍历时不再是遍历所有特征取值,而是仅遍历该特征的几个桶(每个桶可以理解为该特征取值的分位数)就可以,这样可以降低遍历特征取值的次数。
  分桶算法分为global模式和local模式,global模式就是在第一次划分桶之后,不再更新桶,一直使用划分完成的桶进行后续的分裂。这样做就是计算复杂度降低,但是经过多次划分之后,可能会存在一些桶是空的,即该桶中已经没有了数据。
  local模式就是在每次分列前都重新划分桶,优点是每次分桶都能保证各桶中的样本数量都是均匀的,不足的地方就是计算量大。

四、其它特点

4.1 缺失值处理

   对于存在某一维特征缺失的样本,xgb会尝试将其放到左子树计算一次增益,再放到右子树计算一次增益,对比放在左右子树增益的大小决定放在哪个子树。

4.2 防止过拟合

   xgb提出了两种防止过拟合的方法:第一种称为Shrinkage,即学习率,在每次迭代一棵树的时候对每个叶子结点的权重乘上一个缩减系数,使每棵树的影响不会过大,并且给后面的树留下更大的空间优化。另一个方法称为Column Subsampling,类似于随机森林选区部分特征值进行建树,其中又分为两个方式:方式一按层随机采样,在对同一层结点分裂前,随机选取部分特征值进行遍历,计算信息增益;方式二在建一棵树前随机采样部分特征值,然后这棵树的所有结点分裂都遍历这些特征值,计算信息增益。

五、总结

  以上是对xgb的一些理解,大多是观看了很多大神的博客,通过不断的看别人总结的部分以及公式的推导,才让我逐渐理解xgb的各种特征。本文还是有很多不足的地方,后续逐渐补充,完善。

有关XGBoost算法介绍的更多相关文章

  1. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  2. Unity 热更新技术 | (三) Lua语言基本介绍及下载安装 - 2

    ?博客主页:https://xiaoy.blog.csdn.net?本文由呆呆敲代码的小Y原创,首发于CSDN??学习专栏推荐:Unity系统学习专栏?游戏制作专栏推荐:游戏制作?Unity实战100例专栏推荐:Unity实战100例教程?欢迎点赞?收藏⭐留言?如有错误敬请指正!?未来很长,值得我们全力奔赴更美好的生活✨------------------❤️分割线❤️-------------------------

  3. 100个python算法超详细讲解:画直线 - 2

    1.问题描述使用Python的turtle(海龟绘图)模块提供的函数绘制直线。2.问题分析一幅复杂的图形通常都可以由点、直线、三角形、矩形、平行四边形、圆、椭圆和圆弧等基本图形组成。其中的三角形、矩形、平行四边形又可以由直线组成,而直线又是由两个点确定的。我们使用Python的turtle模块所提供的函数来绘制直线。在使用之前我们先介绍一下turtle模块的相关知识点。turtle模块提供面向对象和面向过程两种形式的海龟绘图基本组件。面向对象的接口类如下:1)TurtleScreen类:定义图形窗口作为绘图海龟的运动场。它的构造器需要一个tkinter.Canvas或ScrolledCanva

  4. H2数据库配置及相关使用方式一站式介绍(极为详细并整理官方文档) - 2

    目录H2数据库入门以及实际开发时的使用1.H2数据库的初识1.1H2数据库介绍1.2为什么要使用嵌入式数据库?1.3嵌入式数据库对比1.3.1性能对比1.4技术选型思考2.H2数据库实战2.1H2数据库下载搭建以及部署2.1.1H2数据库的下载2.1.2数据库启动2.1.2.1windows系统可以在bin目录下执行h2.bat2.1.2.2同理可以通过cmd直接使用命令进行启动:2.1.2.3启动后控制台页面:2.1.3spring整合H2数据库2.1.3.1引入依赖文件2.1.4数据库通过file模式实际保存数据的位置2.2H2数据库操作2.2.1Mysql兼容模式2.2.2Mysql模式

  5. ruby - 在 Ruby 中实现 Luhn 算法 - 2

    我一直在尝试用Ruby实现Luhn算法。我一直在执行以下步骤:该公式根据其包含的校验位验证数字,该校验位通常附加到部分帐号以生成完整帐号。此帐号必须通过以下测试:从最右边的校验位开始向左移动,每第二个数字的值加倍。将乘积的数字(例如,10=1+0=1、14=1+4=5)与原始数字的未加倍数字相加。如果总模10等于0(如果总和以零结尾),则根据Luhn公式该数字有效;否则无效。http://en.wikipedia.org/wiki/Luhn_algorithm这是我想出的:defvalidCreditCard(cardNumber)sum=0nums=cardNumber.to_s.s

  6. Ruby 斐波那契算法 - 2

    下面是我写的一个计算斐波那契数列中的值的方法:deffib(n)ifn==0return0endifn==1return1endifn>=2returnfib(n-1)+(fib(n-2))endend它工作到n=14,但在那之后我收到一条消息说程序响应时间太长(我正在使用repl.it)。有人知道为什么会这样吗? 最佳答案 Naivefibonacci进行了大量的重复计算-在fib(14)fib(4)中计算了很多次。您可以将内存添加到您的算法中以使其更快:deffib(n,memo={})ifn==0||n==1returnnen

  7. ruby-on-rails - Rails add_index 算法 : :concurrently still causes database lock up during migration - 2

    为了防止在迁移到生产站点期间出现数据库事务错误,我们遵循了https://github.com/LendingHome/zero_downtime_migrations中列出的建议。(具体由https://robots.thoughtbot.com/how-to-create-postgres-indexes-concurrently-in概述),但在特别大的表上创建索引期间,即使是索引创建的“并发”方法也会锁定表并导致该表上的任何ActiveRecord创建或更新导致各自的事务失败有PG::InFailedSqlTransaction异常。下面是我们运行Rails4.2(使用Acti

  8. ruby - 趋势算法 - 2

    我正在开发一个类似微论坛的项目,其中一个特殊用户发布一条快速(接近推文大小)的主题消息,订阅者可以用他们自己的类似大小的消息来响应。直截了当,没有任何形式的“挖掘”或投票,只是每个主题消息的响应按时间顺序排列。但预计会有很高的流量。我们想根据它们引起的响应嗡嗡声来标记主题消息,使用0到10的等级。在谷歌上搜索了一段时间的趋势算法和开源社区应用示例,到目前为止已经收集到两个有趣的引用资料,但我还没有完全理解它们:Understandingalgorithmsformeasuringtrends,关于使用基线趋势算法比较维基百科页面浏览量的讨论,在SO上。TheBritneySpearsP

  9. Ruby - 不支持的密码算法 (AES-256-GCM) - 2

    我收到错误:unsupportedcipheralgorithm(AES-256-GCM)(RuntimeError)但我似乎具备所有要求:ruby版本:$ruby--versionruby2.1.2p95OpenSSL会列出gcm:$opensslenc-help2>&1|grepgcm-aes-128-ecb-aes-128-gcm-aes-128-ofb-aes-192-ecb-aes-192-gcm-aes-192-ofb-aes-256-ecb-aes-256-gcm-aes-256-ofbRuby解释器:$irb2.1.2:001>require'openssl';puts

  10. java实现Dijkstra算法 - 2

    文章目录一.Dijkstra算法想解决的问题二.Dijkstra算法理论三.java代码实现一.Dijkstra算法想解决的问题解决的问题:求解单源最短路径,即各个节点到达源点的最短路径或权值考察其他所有节点到源点的最短路径和长度局限性:无法解决权值为负数的情况二.Dijkstra算法理论参数:S记录当前已经处理过的源点到最短节点U记录还未处理的节点dist[]记录各个节点到起始节点的最短权值path[]记录各个节点的上一级节点(用来联系该节点到起始节点的路径)Dijkstra算法步骤:(1)初始化:顶点集S:节点A到自已的最短路径长度为0。只包含源点,即S={A}顶点集U:包含除A外的其他顶

随机推荐