1、决策树简介
决策树多被使用处理分类问题,也是最经常使用的数据挖掘算法。决策树的主要任务是发掘数据中所蕴含的知识信息,并从中给出一系列的分类规则对数据进行分类,其预测结果往往可以匹敌具有几十年专家经验所得出的预测规则。比较常用的决策树有ID3,C4.5和CART(Classification And Regression Tree),CART的分类效果一般优于其他决策树,所以,scikit-learn使用的为CART算法的优化版本。下面简单介绍下决策树的优缺点和使用技巧。
- 优点:计算复杂度低,从而计算速度较快;不需要领域知识和参数假设;输出结果易于理解,分类规则准确性较高
- 缺点:容易产生过拟合;忽略各属性之间的相关性
- 使用技巧:适用的数据类型有数值型和标称型;可以采用剪枝来避免过拟合问题;需要考虑对输入数据进行预处理(比如降维等)
2、决策树实例
(1)数据来源:Iris数据集,具体介绍参见https://zhuanlan.zhihu.com/p/145542994中的数据来源部分。sklearn包中自带了Iri数据集和数据处理方法,可以直接将第五列类别信息转换为数字。
from sklearn.datasets import load_iris #数据集
from sklearn.tree import DecisionTreeClassifier #训练器
from sklearn import tree
import graphviz #结果展示
from sklearn.tree import export_text #结果展示
X, y = load_iris(return_X_y=True) #其中X为(150, 4),y为(150,),类型均为 'numpy.ndarray'
(2)网络结构的搭建
clf = tree.DecisionTreeClassifier() #创建分类器
clf = clf.fit(X, y) #模型的训练
print(clf.predict([[2., 2.,,1,1]]))#模型的预测,输出为预测的类型,1类
print(clf.predict_proba([[2., 2.,1,1]])) #模型的预测概率,输出为 [[0. 1. 0.]],代表每类的概率值
(3)结果的展示
展示的方式有二种,绘制树和文本输出。绘制树可以直接使用
plot_tree
函数或者graphviz导出,代码和结果分别如下所示。
#(a)直接使用plot_tree()函数
tree.plot_tree(clf)
plt.show()
#(b)以Graphviz格式导出
iris = load_iris()
dot_data = tree.export_graphviz(clf,
out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
graph.render(filename='MyPicture',view=True) #结果保存在MyPicture.pdf中
graph.render("iris") #打印输出图表的所有配置项
结果如下:
#以文本格式导出树export_text
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
结果如下:
(4)其他用途
使用DecisionTreeRegressor()函数,决策树也可以应用于回归问题,进行简单示例。
X = [[0, 0], [1, 1]] #定义输入
Y = [1, 2] #定义标签
clf = tree.DecisionTreeRegressor() #构建回归模型
clf = clf.fit(X, Y) #模型训练
print(clf.predict([[1, 1]])) #模型预测
参考文献:
《机器学习实战》、《scikit-learn官方文档》
机器学习实战-当当网search.dangdang.com scikit-learn: machine learning in Pythonscikit-learn.org