前言:分类是机器学习中一类重要的任务。在机器学习的发展历程中,诞生了大量的分类算法,每种算法都有其优势和不足。
本文汇总了常用的分类算法及其基于 scikit-learn 的实现方式,方便快速查阅使用。(本文使用鸢尾花(Iris)数据集,这是一个三分类问题。)
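以下9种分类算法使用相同的训练/测试划分(80% 训练、20% 测试,random_state=0)进行训练和测试,其中前6种在测试集上的准确率(accuracy)分别为:
1.随机森林:100%
2.决策树:100%
3.K近邻:100%
4.支持向量机:100%
5.逻辑回归:96.67%
6.线性支持向量机:100%
其余3种(7.随机梯度下降、8.感知机、9.朴素贝叶斯)的准确率可运行下方代码查看。注意:随机森林、决策树、线性支持向量机和随机梯度下降含有随机性,若未固定 random_state,多次运行的结果可能略有差异。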
# Compare nine scikit-learn classifiers on the Iris dataset. For each model, print
# its 3-fold cross-validation accuracy on the training split and its accuracy on
# the held-out test split, then print a labelled summary of the test accuracies.
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn
from sklearn import datasets
from sklearn.metrics import accuracy_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import LogisticRegression

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.naive_bayes import GaussianNB

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

# NOTE: numpy, pandas, matplotlib, sklearn and GridSearchCV are imported but not
# used in this script.
from sklearn.model_selection import GridSearchCV

# One seed for the data split and for every estimator with internal randomness,
# so repeated runs print the same numbers.
RANDOM_STATE = 0


def evaluate(name, clf, x_train, y_train, x_test, y_test, cv=3):
    """Fit ``clf`` on the training split and report how well it classifies.

    Prints ``name``, then the per-fold accuracies of ``cv``-fold
    cross-validation on the training split, then the test-split accuracy,
    followed by a blank line.

    Returns:
        The accuracy of ``clf`` on ``(x_test, y_test)``.
    """
    print(name)
    clf.fit(x_train, y_train)
    # cross_val_score fits clones of clf, so it does not disturb the fit above.
    cross_score = cross_val_score(clf, x_train, y_train, cv=cv, scoring="accuracy")
    print(cross_score)
    score = accuracy_score(y_test, clf.predict(x_test))
    print('test accuracy:', score)
    print()
    return score


iris = datasets.load_iris()
x, y = iris.data, iris.target

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=RANDOM_STATE)

# (printed name, unfitted estimator); the numbering matches the article's list.
classifiers = [
    # 1. Random forest
    ('随机森林分类', RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE)),
    # 2. Decision tree
    ('决策树分类', DecisionTreeClassifier(random_state=RANDOM_STATE)),
    # 3. k-nearest neighbours
    ('KNN', KNeighborsClassifier(n_neighbors=13)),
    # 4. Kernel SVM
    ('SVM', SVC(gamma='scale')),
    # 5. Logistic regression, one-vs-rest.
    # NOTE: `multi_class` is deprecated since scikit-learn 1.5. Where it has been
    # removed, use sklearn.multiclass.OneVsRestClassifier(LogisticRegression(solver='lbfgs'))
    # to keep the one-vs-rest behaviour.
    ('LogisticRegression', LogisticRegression(solver='lbfgs', multi_class='ovr')),
    # 6. Linear SVM
    ('linear SVM', LinearSVC(max_iter=10000, random_state=RANDOM_STATE)),
    # 7. Stochastic gradient descent
    ('SGD', SGDClassifier(max_iter=1000, tol=1e-3, random_state=RANDOM_STATE)),
    # 8. Perceptron
    ('Perceptron', Perceptron(max_iter=1000, tol=1e-3, random_state=RANDOM_STATE)),
    # 9. Gaussian naive Bayes
    ('Naive Bayes', GaussianNB()),
]

# Test accuracies, in the same order as `classifiers`.
res = []
for name, clf in classifiers:
    res.append(evaluate(name, clf, x_train, y_train, x_test, y_test))

# 10. Score comparison
for (name, _), score in zip(classifiers, res):
    print(f'{name}: {score:.2%}')