文章目錄
- 一、支援向量機的原理
-
- 解決的問題:
- 線性分類及其限制條件:
- 二、實戰
-
- 2.1、線性回歸
- 2.2、支援向量機SVM
- 2.3、多項式特征
一、支援向量機的原理
- Support Vector Machine。支援向量機,其含義是通過支援向量運算的分類器。其中“機”的意思是機器,可以了解為分類器。
- 那麼什麼是支援向量呢?在求解的過程中,會發現隻根據部分資料就可以确定分類器,這些資料稱為支援向量。
- 見下圖,在一個二維環境中,其中點R,S,G點和其它靠近中間黑線的點可以看作為支援向量,它們可以決定分類器,也就是黑線的具體參數。
解決的問題:
-
線性分類
在訓練資料中,每個資料都有n個的屬性和一個二類類别标志,我們可以認為這些資料在一個n維空間裡。我們的目标是找到一個n-1維的超平面(hyperplane),這個超平面可以将資料分成兩部分,每部分資料都屬于同一個類别。
其實這樣的超平面有很多,我們要找到一個最佳的。是以,增加一個限制條件:這個超平面到每邊最近資料點的距離是最大的。也成為最大間隔超平面(maximum-margin hyperplane)。這個分類器也成為最大間隔分類器(maximum-margin classifier)。
支援向量機是一個二類分類器。
-
非線性分類
SVM的一個優勢是支援非線性分類。它結合使用拉格朗日乘子法和KKT條件,以及核函數可以産生非線性分類器。
線性分類及其限制條件:
SVM的解決問題的思路是找到離超平面的最近點,通過其限制條件求出最優解。
二、實戰
2.1、線性回歸
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,5,30).reshape(-1,1) #轉化為二維的資料
y = (X - 3)**2 + 3*X - 10 #假設一個函數
#引入線性回歸函數
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X,y) #學習上面的函數
#加入測試集
X_test = np.linspace(-5,5,130).reshape(-1,1)
y_ = lr.predict(X_test) #預測的結果
#資料可視化
plt.scatter(X,y) #原函數散點圖圖像
plt.plot(X_test,y_,c = 'r') #回歸線
plt.show()
結果分析: 從上面的圖中可以看出,線性的拟合并不理想,因為原函數不屬于線性分布模型。
2.2、支援向量機SVM
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,10,30).reshape(-1,1) #轉化為二維的資料
y = X**3 + X**2 + X + 10 #假設一個函數
#引入支援向量機
from sklearn.svm import SVR
svr = SVR(kernel='poly',degree=3) #degree度
svr.fit(X,y) #學習上面的函數
#加入測試集預測
X_test = np.linspace(-10,20,300).reshape(-1,1)
y_ = svr.predict(X_test) #預測的結果
#資料可視化
plt.scatter(X,y) #原函數散點圖圖像
plt.plot(X_test,y_,c = 'r') #回歸線
plt.show()
結果分析: 從圖中可以看出,對于曲線的預測評估,SVM的準确度比線性回歸好很多。
2.3、多項式特征
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,10,30).reshape(-1,1) #轉化為二維的資料
y = X**3 + X**2 + X + 10 #假設一個函數
# 資料清洗
from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=3)
X3 = poly.fit_transform(X)
#引入線性回歸函數
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X3,y) #學習模型
#加入測試集預測
X_test = np.linspace(-10,20,300).reshape(-1,1)
X_test3 = poly.fit_transform(X_test) #需要統一資料次元
y_ = lr.predict(X_test3) #預測
plt.scatter(X,y)
plt.plot(X_test,y_,color = 'r')
plt.show()
結果分析: 如上圖所示,它的的資料拟合度幾乎完美了!