天天看点

Spark MLlib - Decision Tree源码分析

以决策树作为开始,因为简单,而且也比较容易用到,当前的boosting或random forest也是常以其为基础的

决策树算法本身参考之前的blog,其实就是贪婪算法,每次切分使得数据变得最为有序

那么如何来定义有序或无序?

无序,node impurity 

Spark MLlib - Decision Tree源码分析

对于分类问题,我们可以用熵entropy或gini来表示信息的无序程度 

对于回归问题,我们用方差variance来表示无序程度,方差越大,说明数据间差异越大

information gain

用于表示,由父节点划分后得到子节点,所带来的impurity的下降,即有序性的增益

Spark MLlib - Decision Tree源码分析

mlib决策树的例子

下面直接看个regression的例子,分类的case,差不多,

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

还是比较简单的,

由于是回归,所以impurity的定义为variance 

maxdepth,最大树深,设为5 

maxbins,最大的划分数 

先理解什么是bin,决策树的算法就是对feature的取值不断的进行划分 

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

但如果是有序的,即按老,中,少的序,那么只有m-1个,即2种划分,老|中,少;老,中|少

对于连续的feature,其实就是进行范围划分,而划分的点就是split,划分出的区间就是bin 

对于连续feature,理论上划分点是无数的,但是出于效率我们总要选取合适的划分点 

有个比较常用的方法是取出训练集中该feature出现过的值作为划分点, 

但对于分布式数据,取出所有的值进行排序也比较费资源,所以可以采取sample的方式

源码分析

首先调用,decisiontree.trainregressor,类似调用静态函数(object decisiontree)

org.apache.spark.mllib.tree.decisiontree.scala

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

调用静态函数train

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

可以看到将所有参数封装到strategy类,然后初始化decisiontree类对象,继续调用成员函数train

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

可以看到,这里decisiontree的设计是基于randomforest的特例,即单颗树的randomforest 

所以调用randomforest.train(),最终因为只有一棵树,所以取trees(0)

org.apache.spark.mllib.tree.randomforest.scala

重点看下,randomforest里面的train做了什么?

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

1. decisiontreemetadata.buildmetadata

org.apache.spark.mllib.tree.impl.decisiontreemetadata.scala

这里生成一些后面需要用到的metadata 

最关键的是计算每个feature的bins和splits的数目,

计算bins的数目

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

其他case,bins数目等于feature的numcategories 

对于unordered情况,比较特殊,

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

根据bins数目,计算splits

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

2. decisiontree.findsplitsbins

首先找出每个feature上可能出现的splits和相应的bins,这是后续算法的基础 

这里的注释解释了上面如何计算splits和bins数目的算法

a,对于连续数据,对于一个feature,splits = numbins - 1;上面也说了对于连续值,其实splits可以无限的,如何找到numbins - 1个splits,很简单,这里用sample 

b,对于离散数据,两个case 

Spark MLlib - Decision Tree源码分析

    b.2, 有序的feature,用于regression,二元分类,或high-arity的多元分类,这种case下划分的可能比较少,m-1,所以用每个category作为划分

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

3. treepoint和baggedpoint

treepoint是labeledpoint的内部数据结构,这里需要做转换,

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

arr是findbin的结果, 

这里主要是针对连续特征做处理,将连续的值通过二分查找转换为相应bin的index 

对于离散数据,bin等同于featurevalue.toint

baggedpoint,由于random forest是比较典型的bagging算法,所以需要对训练集做bootstrap sample 

而对于decision tree是特殊的单根random forest,所以不需要做抽样 

baggedpoint.converttobaggedrddwithoutsampling(treeinput) 

其实只是做简单的封装

4. decisiontree.findbestsplits

这段代码写的有点复杂,尤其和randomforest混杂一起

总之,关键在

看看binstobestsplit的实现,为了清晰一点,我们只看continuous feature

四个参数,

binaggregates: dtstatsaggregator, 就是impurityaggregator,给出如果算出impurity的逻辑 

splits: array[array[split]], feature对应的splits 

featuresfornode: option[array[int]], tree node对应的feature  

node: node, 哪个tree node

返回值,

(split, informationgainstats, predict), 

split,最优的split对象(包含featureindex和splitindex) 

informationgainstats,该split产生的gain对象,表明产生多少增益,多大程度降低impurity 

predict,该节点的预测值,对于连续feature就是平均值,看后面的分析

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

predict,这个需要分析一下 

predictwithimpurity.get._1,predictwithimpurity元组的第一个元素 

calculatepredictimpurity的返回值中的predict

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

这里predict和impurity有什么不同,可以看出 

impurity = impuritycalculator.calculate() 

predict = impuritycalculator.predict

对于连续feature,我们就看variance的实现,

Spark MLlib - Decision Tree源码分析
Spark MLlib - Decision Tree源码分析

从calculate的实现可以看到,impurity求的就是方差, 不是标准差(均方差)

Spark MLlib - Decision Tree源码分析

继续阅读