天天看点

sklearn 决策树_初识决策树及sklearn实现

1、决策树简介

决策树多被使用处理分类问题,也是最经常使用的数据挖掘算法。决策树的主要任务是发掘数据中所蕴含的知识信息,并从中给出一系列的分类规则对数据进行分类,其预测结果往往可以匹敌具有几十年专家经验所得出的预测规则。比较常用的决策树有ID3,C4.5和CART(Classification And Regression Tree),CART的分类效果一般优于其他决策树,所以,scikit-learn使用的为CART算法的优化版本。下面简单介绍下决策树的优缺点和使用技巧。

  • 优点:计算复杂度低,从而计算速度较快;不需要领域知识和参数假设;输出结果易于理解,分类规则准确性较高
  • 缺点:容易产生过拟合;忽略各属性之间的相关性
  • 使用技巧:适用的数据类型有数值型和标称型;可以采用剪枝来避免过拟合问题;需要考虑对输入数据进行预处理(比如降维等)

2、决策树实例

(1)数据来源:Iris数据集,具体介绍参见https://zhuanlan.zhihu.com/p/145542994中的数据来源部分。sklearn包中自带了Iri数据集和数据处理方法,可以直接将第五列类别信息转换为数字。

from sklearn.datasets import load_iris         #数据集
from sklearn.tree import DecisionTreeClassifier #训练器
from sklearn import tree
import graphviz                                #结果展示
from sklearn.tree import export_text           #结果展示                      
X, y = load_iris(return_X_y=True)          #其中X为(150, 4),y为(150,),类型均为 'numpy.ndarray'           

(2)网络结构的搭建

clf = tree.DecisionTreeClassifier()   #创建分类器
clf = clf.fit(X, y)                   #模型的训练
print(clf.predict([[2., 2.,,1,1]]))#模型的预测,输出为预测的类型,1类
print(clf.predict_proba([[2., 2.,1,1]]))   #模型的预测概率,输出为 [[0. 1. 0.]],代表每类的概率值           

(3)结果的展示

展示的方式有二种,绘制树和文本输出。绘制树可以直接使用

plot_tree

函数或者graphviz导出,代码和结果分别如下所示。

#(a)直接使用plot_tree()函数
tree.plot_tree(clf)
plt.show()
#(b)以Graphviz格式导出
iris = load_iris()
dot_data = tree.export_graphviz(clf,
                                out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph.render(filename='MyPicture',view=True)   #结果保存在MyPicture.pdf中
graph.render("iris")              #打印输出图表的所有配置项           

结果如下:

sklearn 决策树_初识决策树及sklearn实现
#以文本格式导出树export_text
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)           

结果如下:

sklearn 决策树_初识决策树及sklearn实现

(4)其他用途

使用DecisionTreeRegressor()函数,决策树也可以应用于回归问题,进行简单示例。

X = [[0, 0], [1, 1]]                   #定义输入
Y = [1, 2]                             #定义标签
clf = tree.DecisionTreeRegressor()    #构建回归模型
clf = clf.fit(X, Y)                   #模型训练
print(clf.predict([[1, 1]]))          #模型预测           

参考文献:

《机器学习实战》、《scikit-learn官方文档》

机器学习实战-当当网​search.dangdang.com scikit-learn: machine learning in Python​scikit-learn.org

sklearn 决策树_初识决策树及sklearn实现