天天看点

gbdt 算法比随机森林容易_聊聊GBDT和随机森林

原标题:聊聊GBDT和随机森林

欢迎关注天善智能微信公众号,我们是专注于商业智能BI,大数据,数据分析领域的垂直社区。 对商业智能BI、大数据分析挖掘、机器学习,python,R等数据领域感兴趣的同学加微信:tstoutiao,邀请你进入头条数据爱好者交流群,数据爱好者们都在这儿

gbdt 算法比随机森林容易_聊聊GBDT和随机森林

GBDT 和 随机森林都是基于决策树而得到的。决策树比较容易理解,它具有比较直观的图示。利用决策树可以发现比较重要的变量,也可以挖掘变量之间的关系。决策树也比较不易受到离群点和缺失值的影响。由于决策树不考虑空间分布,也不考虑分类器的结构,它是一种无参算法。但是决策树比较容易过拟合,另外,决策树不易处理连续型变量。

Gradient Boosting 是一种提升的框架,可以用于决策树算法,即GBDT。通常Boosting基于弱学习器,弱分类器具有高偏差,低方差。决策树的深度较浅时,通常是弱学习器,比如一种比较极端的例子,只有一个根节点和两个叶子节点。Boosting这种策略主要通过减小偏差来降低总体误差,一般来讲,通过集成多个模型的结果也会减小方差。这里的总体误差可以看作由偏差和方差构成。由于GBDT是基于Boosting策略的,所以这种算法具有序贯性,不容易并行实现。

关于偏差和方差随模型复杂度变化,可以参见下图。

gbdt 算法比随机森林容易_聊聊GBDT和随机森林

随机森林主要通过减小方差来降低总体误差。随机森林是由多个决策树构成的,因此需要基于原始数据集随机生成多个数据集,用于生成多个决策树。这些决策树之间的相关性越小,方差降低得越多。虽然随机森林可以减小方差,但是这种组合策略不能降低偏差,它会使得总体偏差大于森林中单个决策树的偏差。

随机森林利用bagging来组合多个决策树,容易过拟合。由于这种方法基于bagging思想,因此这种算法比较容易并行实现。随机森林能够较好地应对缺失值和非平衡集的情形。

下面给出基于scikit-learn的随机森林示例:

fromsklearn.ensembleimportRandomForestClassifier

X=[[0,0],[1,1]]Y=[0,1]clf=RandomForestClassifier(n_estimators=10)clf=clf.fit(X,Y)

spark也内嵌了随机森林算法,示例如下:

frompyspark.mllib.treeimportRandomForest,RandomForestModelfrompyspark.mllib.utilimportMLUtils# Loadandparsethe datafileintoan RDD ofLabeledPoint.data=MLUtils.loadLibSVMFile(sc,'data/mllib/sample_libsvm_data.txt')# Splitthe dataintotraining andtestsets(30% held outfortesting)(trainingData,testData)=data.randomSplit([0.7,0.3])# Train a RandomForest model.# EmptycategoricalFeaturesInfo indicates all features arecontinuous.# Note: Uselarger numTrees inpractice.# Setting featureSubsetStrategy="auto"lets the algorithm choose.model=RandomForest.trainClassifier(trainingData,numClasses=2,categoricalFeaturesInfo={},

numTrees=3,featureSubsetStrategy="auto",

impurity='gini',maxDepth=4,maxBins=32)# Evaluatemodelontestinstances andcomputetesterrorpredictions=model.predict(testData.map(lambdax:x.features))labelsAndPredictions=testData.map(lambdalp:lp.label).zip(predictions)testErr=labelsAndPredictions.filter(lambda(v,p):v!=p).count()/float(testData.count())print('Test Error = '+str(testErr))print('Learned classification forest model:')print(model.toDebugString())# Saveandloadmodelmodel.save(sc,"target/tmp/myRandomForestClassificationModel")sameModel=RandomForestModel.load(sc,"target/tmp/myRandomForestClassificationModel")

这两种算法也可以用于客户管理或营销领域,比如客户流失预测(Bagging and boosting classification trees to predict churn

) 和点击率预估(

Feature Selection in Click-Through Rate Prediction Based on Gradient Boosting

)等。

转载请保留以下内容:

本文来源自天善社区陈富强老师的博客(公众号)。

责任编辑: