天天看点

决策树(三):CART算法

CART(分类与回归树),也就是说CART算法既可以用于分类,也可以用于回归,它是在给定输入随机变量X条件下输出随机变量Y的条件概率分布的学习方法,其也和回归树一样是二叉树。

是CART算法,也是分为:特征选择,树的生成,树的剪枝。其实感觉前两步可以合并为一步,因为树的生成过程中就是不断的进行特征的选择。

李航《统计学习方法》中说,决策树生成阶段,生成的决策树要尽可能的大,也就是生成阶段,模型要复杂。为什么这么说,我觉得因为后面还是剪枝的过程,在复杂的模型上剪枝优化得到最优模型,这是最合理也是最有可能得到最优模型的方式。

1.回归树的生成

回归树的生成方式可以参考我的另一篇文章《决策树(二):回归树和模型树》中的回归树的生成方式,这里不再赘述。

2.分类树的生成

CART算法生成的是二叉树,而分类树适用的是标称型数据,每个特征有好多个特征值,所以怎么生成二叉树是个问题。之前计算分类树某个节点的划分标准的时候是计算某个特征的信息增益,而CART中计算每个特征值的基尼系数,找出基尼系数最小的,即选择最优特征,作为划分值,等于那个特征对应的特征值放到左子树,其余的放到右子树。

假设有K个分类,样本点数据第k类的概率为pk,则概率分布的基尼指数定义为:

决策树(三):CART算法

所以对于给定的样本集合D,其基尼指数为:

决策树(三):CART算法

其中Ck是D中属于第k类的样本子集,K是类的个数。

下面一段话及公式是李航书的原来内容,我感觉不太好理解:

如果样本集合D根据特征A是否取某一可能值a被分割成D1和D2两部分,即

决策树(三):CART算法

则在特征A的条件下,集合D的基尼指数定义为

决策树(三):CART算法

怎么理解上述公式?我们其它文章介绍过,条件信息熵是针对整个特征求的熵,而基尼指数是需要计算某个特征的某个特征值的基尼指数,所以我们这里就不要考虑特征值之间的关系了。假设我们根据特征A的某个特征值a,将整个数据集分为D1和D2两个部分,D1是特征A是a的数据集合,而D2是特征A不是a的集合;|D1|/|D|就是特征A的值是a的数据集个数占总集合个数的比例,|D2|/|D|亦然;再看基尼指数的计算公式Gini(D1)=1-

决策树(三):CART算法

,其中K是集合D1中分类结果的总个数,即有多少种分类结果,相当于Y的个数,|D1|是集合D1中元素的个数,|Ck|是第k个分类中元素的个数;对于D2集合的理解亦同理;

其实从上面我么就可以看出和信息增益算法的区别:信息增益如果选择了某个特征分类,是按照所有兄弟种类分的,而CART算法因为是二叉树,所以会选择某个特征的某个值去分,所以分的子集是特征值是a的分一个子集,特征值不是a的是一个子集。

这里有个容易忽略的点,看李航书中的CART生成算法第(3)步,有这么一句话:对两个子节点递归调用(1),(2),直至满足停止条件。所以说特征值等于a的分出来的子集还是需要进行再切分的;另外书上没有说停止条件,但是前面我们说过生成树尽可能要复杂,可以参考这个条件,怎么复杂我觉得可以自己把握,例如基尼指数都小于某个阈值什么的。

3.CART剪枝

CART算法的第一阶段我们得到的是一个比较复杂的决策树,所以第二步我们就需要对生成的决策树进行剪枝,从而得到泛化能力更强的决策树。

CART剪枝算法会生成好多个子树{T1,T2,T3,…,Tn},然后用交叉验证的方式选出损失最少的决策树,作为最终的决策树。

CART剪枝算法中用到的损失函数,和其它剪枝算法的损失函数的基本表达式是一样的:Cα(T) = C(T)+α|T|,其中C(T)是基本的损失函数,参考决策树生成过程中的误差计算概念,我们可以理解为数据的混乱程度(如基尼指数,信息熵),|T|是树的叶子节点的个数,是模型的复杂度(叶子节点越多,模型越复杂),α>=0是参数,权衡训练数据的拟合程度与模型的复杂程度。

当α固定时,则一定存在一棵树,损失是最小的,用Tα表示;当α较大时,|T|较小的话整体损失就会达到一个较大的值,所以偏向选择模型较简单(叶子节点较少)的模型;当α较小时,|T|取到较大的值,所以偏向选择模型相对复杂(叶子节点较多)的模型。

CART算法对于一个子树来说需要定义两个损失函数:

  1. 对于树内部的任意节点t,以t为单节点数的损失函数是:
    决策树(三):CART算法
    ,这个损失函数是假设对子树进行剪枝后,就只保留一个这个子树的根节点的损失函数。
  2. 以t为根节点的子树Tt的损失函数是:
    决策树(三):CART算法
    ,这个损失函数是假设不剪枝的损失函数。

当α=0及α充分小时,有不等式

决策树(三):CART算法

,当α增大时,在某一α有

决策树(三):CART算法

。李航《统计学习方法》一书上是这么写的,为什么说这么说,没太明白。以下是个人理解:

就拿以t为根节点的子树来说,所有分到t节点的样本,对于这个子树来说就是这个子树的所有样本,也就是还没有开始往下分类的样本。对于这个子树来说,较小的α偏向选择较复杂的模型,也就是有一定的叶子节点(非根节点t),那么就是对于根节点t来说就是有分支结构的,那么就是更加细分了,整体损失就更小了,就会有

决策树(三):CART算法

;当α变大的时候,因为根节点的|T|=1,而这颗子树的|T| > 1,所以子树的整体损失增加较快,α到达一定程度的时候,就会有

决策树(三):CART算法

,此时由于剪枝前后损失相同,并且只有根节点t,模型偏向简单,所有只保留根节点比保留以t为根节点的子树更加可取,就进行剪枝。

所以根据以上等式,我们可以得到

决策树(三):CART算法

我们要保证α从小到大增加,所以对CART第一阶段生成的树的每个节点计算上述值,然后选择最小的α,然后将其对应的子树剪枝,得到子树T1,然后在T1基础上再重复上述过程,直到剪枝到根节点。最后得到一系列子树T1,T2,T3,…,Tn并且对应一些列参数α1,α2,α3,…,αn。得到这些子树后,采用交叉验证的方式在一系列子树中选取最优子树Tα。

通过以上过程,我们可以看到CART剪枝比一般剪枝算法还是复杂一些的;最起码普通的剪枝算法,不需要生成那么多的树,只需要在一棵树上不断计算剪枝前后的误差,然后决定是否剪枝,思路也相对好理解,CART剪枝就不那么好理解了。

继续阅读