前言:分類是機器學習中一項重要的任務。在機器學習的研究歷史中,誕生了大量的分類算法,而每種算法都有其優勢和不足。
本文彙總了常用的分類算法及其實作方式,方便快速查詢使用。(本文使用鳶尾花資料集,屬於三類別分類問題)
以下9種分類算法使用相同的資料進行訓練和測試,在測試集上的準確率(accuracy)分別如下(此處僅列出前6種的結果):
1.隨機森林:100%
2.決策樹:100%
3.K近鄰:100%
4.支援向量機:100%
5.邏輯回歸:96.67%
6.線性支援向量機:100%
# Compare nine scikit-learn classifiers on the iris data set.
#
# Every model is trained and evaluated on the same 80/20 train/test split.
# For each model the script prints its title, its 3-fold cross-validation
# accuracies on the training split and a blank line. The held-out test
# accuracies are collected in `res` (same order as `classifiers`) and printed
# at the end.
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn
from sklearn import datasets
from sklearn.metrics import accuracy_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import LogisticRegression

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.naive_bayes import GaussianNB

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

from sklearn.model_selection import GridSearchCV

# One seed for the split and for every estimator that draws random numbers,
# so that repeated runs print the same accuracies.
RANDOM_STATE = 0

iris = datasets.load_iris()
x, y = iris.data, iris.target

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=RANDOM_STATE
)

# (printed title, estimator) pairs, evaluated in this order.
# The titles are runtime output and are kept exactly as in the original.
classifiers = [
    # 1. Random forest
    ('随機森林分類',
     RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE)),
    # 2. Decision tree
    ('決策樹分類', DecisionTreeClassifier(random_state=RANDOM_STATE)),
    # 3. K nearest neighbours
    ('KNN', KNeighborsClassifier(n_neighbors=13)),
    # 4. Support vector machine
    ('SVM', SVC(gamma='scale')),
    # 5. Logistic regression
    # NOTE(review): `multi_class` is reported as deprecated in recent
    # scikit-learn releases (1.5+) -- confirm against the installed version.
    # It is kept here to preserve the original one-vs-rest behaviour.
    ('LogisticRegression',
     LogisticRegression(solver='lbfgs', multi_class='ovr')),
    # 6. Linear support vector machine
    ('linear SVM', LinearSVC(max_iter=10000, random_state=RANDOM_STATE)),
    # 7. Stochastic gradient descent
    ('SGD',
     SGDClassifier(max_iter=1000, tol=1e-3, random_state=RANDOM_STATE)),
    # 8. Perceptron
    ('Perceptron',
     Perceptron(max_iter=1000, tol=1e-3, random_state=RANDOM_STATE)),
    # 9. Gaussian naive Bayes
    ('Naive Bayes', GaussianNB()),
]

# Test-set accuracy of each classifier, in the order of `classifiers`.
res = []

for title, clf in classifiers:
    print(title)
    clf.fit(x_train, y_train)
    # cross_val_score works on clones of `clf`, so the model fitted above on
    # the full training split is the one used for the test-set prediction.
    cross_score = cross_val_score(
        clf, x_train, y_train, cv=3, scoring="accuracy"
    )
    print(cross_score)
    y_predict = clf.predict(x_test)
    score = accuracy_score(y_test, y_predict)
    res.append(score)
    print()

# 10. Compare the test-set scores of all classifiers.
print(res)