AdaBoost 理論部分(公式)後期補充
1.建立資料集
# Keep the toy data set simple: a more complex one could make the single-level
# decision trees (stumps) used later perform poorly. Even so, no single stump
# can separate these five points perfectly -- on both features a +1 and a -1
# sample share the same value -- which is what makes boosting worthwhile here.
import numpy as np


def loadSimpData():
    """Return the five-sample toy training set.

    Returns:
        datMat - 5x2 np.matrix of feature values, one row per sample
        classLabels - list of the five +1.0 / -1.0 class labels
    """
    features = [
        [1, 2.1],
        [2, 1.1],
        [1.3, 1],
        [1, 1],
        [2, 1],
    ]
    labels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return np.matrix(features), labels


loadSimpData()
(matrix([[1. , 2.1],
[2. , 1.1],
[1.3, 1. ],
[1. , 1. ],
[2. , 1. ]]), [1.0, 1.0, -1.0, -1.0, 1.0])
2.單層決策樹生成函數
def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    """Classify samples with a one-feature threshold rule (a decision stump).

    A stump is a very weak classifier: it looks at a single feature only.
    Every sample starts in class +1.0. When threshIneq == 'lt', samples whose
    feature `dimen` is <= threshVal are moved to class -1.0; for any other
    value of threshIneq (the training code uses 'gt'), samples whose feature
    is > threshVal are moved to -1.0 instead.

    Parameters:
        dataMatrix - 2-D sample matrix, one row per sample
        dimen - column index of the feature to compare
        threshVal - threshold value
        threshIneq - 'lt' or 'gt'; selects which side of the threshold is -1
    Returns:
        retArray - (m, 1) ndarray of +1.0 / -1.0 predictions
    """
    column = dataMatrix[:, dimen]
    negative = column <= threshVal if threshIneq == 'lt' else column > threshVal
    retArray = np.ones((np.shape(dataMatrix)[0], 1))
    retArray[negative] = -1.0
    return retArray
def buildStump(dataArr, classLabels, D, numSteps=10.0, verbose=True):
    """Find the decision stump with the lowest weighted error.

    Every feature is scanned. For each one, candidate thresholds are placed at
    ``rangeMin + j * stepSize`` for ``j`` in ``-1 .. int(numSteps)``, where
    ``stepSize = (rangeMax - rangeMin) / numSteps``; both inequality
    directions ('lt' and 'gt') are tried at every threshold. The candidate
    whose predictions have the smallest weighted error ``D.T * errArr`` wins;
    on ties the first candidate found is kept.

    Parameters:
        dataArr - sample data (anything np.mat accepts), one row per sample
        classLabels - sequence of labels; a sample counts as correct only when
            its label equals the stump's +1.0 / -1.0 prediction
        D - sample weights as an (m, 1) column
        numSteps - number of threshold steps across each feature's range
            (default 10.0, the value previously hard-coded; use a whole number)
        verbose - if True (default), print every candidate split as before
    Returns:
        bestStump - dict with keys 'dim', 'thresh', 'ineq' of the best stump
        minError - weighted error of the best stump (1x1 matrix; float('inf')
            if dataArr has no feature columns)
        bestClasEst - (m, 1) array of the best stump's predictions
    """
    dataMatrix = np.mat(dataArr)
    labelMat = np.mat(classLabels).T
    m, n = np.shape(dataMatrix)
    bestStump = {}
    bestClasEst = np.mat(np.zeros((m, 1)))
    minError = float('inf')  # any real weighted error beats this
    for i in range(n):  # every feature
        rangeMin = dataMatrix[:, i].min()
        rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        # j starts at -1 so one threshold lies below the feature's minimum;
        # the last one (j == numSteps) sits at the maximum.
        for j in range(-1, int(numSteps) + 1):
            for inequal in ['lt', 'gt']:  # lt: less than, gt: greater than
                threshVal = rangeMin + float(j) * stepSize
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                errArr = np.mat(np.ones((m, 1)))
                errArr[predictedVals == labelMat] = 0  # correct predictions cost nothing
                # Weighted error: total weight of the misclassified samples.
                weightedError = D.T * errArr
                if verbose:
                    # .item() extracts the scalar explicitly; formatting a 1x1
                    # matrix with %f relies on a deprecated implicit
                    # array-to-scalar conversion.
                    print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError.item()))
                if weightedError < minError:
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst
# With buildStump, the first weak classifier (a single stump) can already be trained.
3.基於單層決策樹的AdaBoost訓練過程
def adaBoostTrainDS(dataArr, classLabels, numIt = 40):
    """Train an AdaBoost ensemble of decision stumps.

    Each iteration:
      1. finds the best stump under the current sample weights (buildStump),
      2. computes its vote weight alpha and stores the stump,
      3. re-weights the samples and normalizes D to sum to 1,
      4. adds the stump's weighted vote to the running ensemble estimate,
      5. stops early once the ensemble misclassifies no training sample.

    Parameters:
        dataArr - training samples, one row per sample
        classLabels - sequence of +1.0 / -1.0 labels
        numIt - maximum number of boosting rounds
    Returns:
        weakClassArr - list of stump dicts (keys 'dim', 'thresh', 'ineq', 'alpha')
        aggClassEst - (m, 1) matrix of the ensemble's accumulated scores
    """
    weakClassArr = []
    m = np.shape(dataArr)[0]
    labelMat = np.mat(classLabels).T  # loop-invariant: build it once
    D = np.mat(np.ones((m, 1)) / m)  # start with uniform sample weights
    aggClassEst = np.mat(np.zeros((m, 1)))
    for _ in range(numIt):
        # The weak learner; another weak classifier could be plugged in here.
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)
        print("D:", D.T)
        # buildStump reports the error as a 1x1 matrix; take the scalar
        # explicitly (implicit ndim>0 array -> float conversion is deprecated).
        error = np.asarray(error).item()
        # alpha = 0.5 * ln((1 - error) / error); the 1e-16 floor keeps the
        # denominator non-zero when a stump classifies everything correctly.
        alpha = float(0.5 * np.log((1.0 - error) / max(error, 1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)
        print("classEst: ", classEst.T)
        # Exponent -alpha * y * h(x): negative for correct samples (weight
        # shrinks), positive for misclassified ones (weight grows).
        expon = np.multiply(-1 * alpha * labelMat, classEst)
        D = np.multiply(D, np.exp(expon))
        D = D / D.sum()  # normalize so the weights sum to 1
        # Accumulate the weighted votes of every stump trained so far;
        # this running sum is the ensemble's prediction.
        aggClassEst += alpha * classEst
        print("aggClassEst: ", aggClassEst.T)
        # 1 where the ensemble's sign disagrees with the label, else 0.
        aggErrors = np.multiply(np.sign(aggClassEst) != labelMat, np.ones((m, 1)))
        errorRate = aggErrors.sum() / m
        print("total error: ", errorRate)
        if errorRate == 0.0:
            break  # training data fully separated; stop early
    return weakClassArr, aggClassEst
if __name__ == '__main__':
    # Train the ensemble on the toy set, then show the stumps and final scores.
    dataArr, classLabels = loadSimpData()
    weakClassArr, aggClassEst = adaBoostTrainDS(dataArr, classLabels)
    print(weakClassArr, aggClassEst, sep='\n')
4.測試演算法:基於AdaBoost的分類
def adaClassify(datToClass, classifierArr, verbose=True):
    """Classify samples with a trained AdaBoost ensemble.

    Each stump's +1/-1 prediction is scaled by its 'alpha' and summed; the
    sign of the sum is the ensemble's decision.

    Parameters:
        datToClass - samples to classify (anything np.mat accepts), one row per sample
        classifierArr - iterable of stump dicts as produced by adaBoostTrainDS
            (keys 'dim', 'thresh', 'ineq', 'alpha')
        verbose - if True (default), print the running score after each stump
    Returns:
        (m, 1) matrix of +1.0 / -1.0 predictions (0.0 where the score is exactly 0)
    """
    dataMatrix = np.mat(datToClass)
    m = np.shape(dataMatrix)[0]
    aggClassEst = np.mat(np.zeros((m, 1)))
    for stump in classifierArr:  # let every stump cast its weighted vote
        classEst = stumpClassify(dataMatrix, stump['dim'], stump['thresh'], stump['ineq'])
        aggClassEst += stump['alpha'] * classEst
        if verbose:
            print(aggClassEst)
    return np.sign(aggClassEst)
if __name__ == '__main__':
    # Train on the toy set, then classify two points that are not in it.
    dataArr, classLabels = loadSimpData()
    weakClassArr, aggClassEst = adaBoostTrainDS(dataArr, classLabels)
    print(adaClassify(np.array([[0, 0], [5, 5]]), weakClassArr))
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.600
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.400
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.200
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.800
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.600
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.400
D: [[0.2 0.2 0.2 0.2 0.2]]
classEst: [[-1. 1. -1. -1. 1.]]
aggClassEst: [[-0.69314718 0.69314718 -0.69314718 -0.69314718 0.69314718]]
total error: 0.2
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.250
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.750
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.750
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.250
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.125
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.875
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.750
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.250
D: [[0.5 0.125 0.125 0.125 0.125]]
classEst: [[ 1. 1. -1. -1. -1.]]
aggClassEst: [[ 0.27980789 1.66610226 -1.66610226 -1.66610226 -0.27980789]]
total error: 0.2
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.143
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.857
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.857
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.143
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.143
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.857
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.500
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.500
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.857
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.143
D: [[0.28571429 0.07142857 0.07142857 0.07142857 0.5 ]]
classEst: [[1. 1. 1. 1. 1.]]
aggClassEst: [[ 1.17568763 2.56198199 -0.77022252 -0.77022252 0.61607184]]
total error: 0.0
[[-0.69314718]
[ 0.69314718]]
[[-1.66610226]
[ 1.66610226]]
[[-2.56198199]
[ 2.56198199]]
[[-1.]
[ 1.]]