天天看點

gbdt 算法比随機森林容易_聊聊GBDT和随機森林

原标題:聊聊GBDT和随機森林

歡迎關注天善智能微信公衆号,我們是專注于商業智能BI,大資料,資料分析領域的垂直社群。 對商業智能BI、大資料分析挖掘、機器學習,python,R等資料領域感興趣的同學加微信:tstoutiao,邀請你進入頭條資料愛好者交流群,資料愛好者們都在這兒

gbdt 算法比随機森林容易_聊聊GBDT和随機森林

GBDT 和 随機森林都是基于決策樹而得到的。決策樹比較容易了解,它具有比較直覺的圖示。利用決策樹可以發現比較重要的變量,也可以挖掘變量之間的關系。決策樹也比較不易受到離群點和缺失值的影響。由于決策樹不考慮空間分布,也不考慮分類器的結構,它是一種無參算法。但是決策樹比較容易過拟合,另外,決策樹不易處理連續型變量。

Gradient Boosting 是一種提升的架構,可以用于決策樹算法,即GBDT。通常Boosting基于弱學習器,弱分類器具有高偏差,低方差。決策樹的深度較淺時,通常是弱學習器,比如一種比較極端的例子,隻有一個根節點和兩個葉子節點。Boosting這種政策主要通過減小偏差來降低總體誤差,一般來講,通過內建多個模型的結果也會減小方差。這裡的總體誤差可以看作由偏差和方差構成。由于GBDT是基于Boosting政策的,是以這種算法具有序貫性,不容易并行實作。

關于偏差和方差随模型複雜度變化,可以參見下圖。

gbdt 算法比随機森林容易_聊聊GBDT和随機森林

随機森林主要通過減小方差來降低總體誤差。随機森林是由多個決策樹構成的,是以需要基于原始資料集随機生成多個資料集,用于生成多個決策樹。這些決策樹之間的相關性越小,方差降低得越多。雖然随機森林可以減小方差,但是這種組合政策不能降低偏差,它會使得總體偏差大于森林中單個決策樹的偏差。

随機森林利用bagging來組合多個決策樹,容易過拟合。由于這種方法基于bagging思想,是以這種算法比較容易并行實作。随機森林能夠較好地應對缺失值和非平衡集的情形。

下面給出基于scikit-learn的随機森林示例:

fromsklearn.ensembleimportRandomForestClassifier

X=[[0,0],[1,1]]Y=[0,1]clf=RandomForestClassifier(n_estimators=10)clf=clf.fit(X,Y)

spark也内嵌了随機森林算法,示例如下:

frompyspark.mllib.treeimportRandomForest,RandomForestModelfrompyspark.mllib.utilimportMLUtils# Loadandparsethe datafileintoan RDD ofLabeledPoint.data=MLUtils.loadLibSVMFile(sc,'data/mllib/sample_libsvm_data.txt')# Splitthe dataintotraining andtestsets(30% held outfortesting)(trainingData,testData)=data.randomSplit([0.7,0.3])# Train a RandomForest model.# EmptycategoricalFeaturesInfo indicates all features arecontinuous.# Note: Uselarger numTrees inpractice.# Setting featureSubsetStrategy="auto"lets the algorithm choose.model=RandomForest.trainClassifier(trainingData,numClasses=2,categoricalFeaturesInfo={},

numTrees=3,featureSubsetStrategy="auto",

impurity='gini',maxDepth=4,maxBins=32)# Evaluatemodelontestinstances andcomputetesterrorpredictions=model.predict(testData.map(lambdax:x.features))labelsAndPredictions=testData.map(lambdalp:lp.label).zip(predictions)testErr=labelsAndPredictions.filter(lambda(v,p):v!=p).count()/float(testData.count())print('Test Error = '+str(testErr))print('Learned classification forest model:')print(model.toDebugString())# Saveandloadmodelmodel.save(sc,"target/tmp/myRandomForestClassificationModel")sameModel=RandomForestModel.load(sc,"target/tmp/myRandomForestClassificationModel")

這兩種算法也可以用于客戶管理或營銷領域,比如客戶流失預測(Bagging and boosting classification trees to predict churn

) 和點選率預估(

Feature Selection in Click-Through Rate Prediction Based on Gradient Boosting

)等。

轉載請保留以下内容:

本文來源自天善社群陳富強老師的部落格(公衆号)。

責任編輯: