
Hand-Coding Machine Learning in Action: The AdaBoost Algorithm (1)

The theory behind AdaBoost (the formulas) will be added in a later post.
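In the meantime, here is a minimal sketch of the standard AdaBoost quantities that the code below computes: the weighted error of the t-th stump, the classifier weight alpha, the sample-weight update, and the aggregate classifier. The notation is the usual one and not necessarily the book's.

\varepsilon_t = \sum_{i=1}^{m} D_t(i)\,\mathbb{1}\!\left[h_t(x_i) \neq y_i\right],
\qquad
\alpha_t = \tfrac{1}{2}\ln\frac{1-\varepsilon_t}{\varepsilon_t}

D_{t+1}(i) = \frac{D_t(i)\,e^{-\alpha_t y_i h_t(x_i)}}{\sum_{j=1}^{m} D_t(j)\,e^{-\alpha_t y_j h_t(x_j)}},
\qquad
H(x) = \operatorname{sign}\Big(\sum_t \alpha_t h_t(x)\Big)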

1. Building the Dataset

"""
此處不建構太複雜的資料集,不然可能後面用單層決策樹時候,效果不好。
因為用任何一個單層決策樹都無法完全分開這五個資料
"""
import numpy as np
def loadSimpData():
    datMat = np.matrix([[1,2.1],[2,1.1],[1.3,1],[1,1],[2,1]])
    classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return datMat,classLabels
loadSimpData()
           
(matrix([[1. , 2.1],
         [2. , 1.1],
         [1.3, 1. ],
         [1. , 1. ],
         [2. , 1. ]]), [1.0, 1.0, -1.0, -1.0, 1.0])
           

2. Decision Stump Generation Functions

"""
單層決策樹是一個很弱很弱的分類器,因為它隻選擇了一個特征就不繼續選擇了。
第一個函數stumpClassify()是通過門檻值比較對資料進彳了分類的。
所有在門檻值一邊的資料會分到類别-i, 而在另外一邊的資料分到類别+l。該函數可以通過數組過
濾來實作,首先将傳回數組的全部元素設定為1 ,然後将所有不滿足不等式要求的元素設定為-1。
可以基于資料集中的任一進制素進行比較,同時也可以将不等号在大于、小于之間切換。
Parameters:
        dataMatrix - 資料矩陣
        dimen - 第dimen列,也就是第幾個特征
        threshVal - 門檻值
        threshIneq - 标志
Returns:
        retArray - 分類結果
"""
def stumpClassify(dataMatrix,dimen,threshVal,threshIneq):
    #initialize retArray to all 1s
    retArray = np.ones((np.shape(dataMatrix)[0],1))
    if threshIneq == 'lt':
        retArray[dataMatrix[:,dimen] <= threshVal] = -1.0 #samples at or below the threshold are assigned -1
    else:
        retArray[dataMatrix[:,dimen] > threshVal] = -1.0  #samples above the threshold are assigned -1
    return retArray
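A quick usage sketch (not in the original post), assuming loadSimpData and stumpClassify are defined as above and numpy is already imported as np:

# Minimal sketch: classify the toy data with a single stump on feature 0, threshold 1.3.
datMat, classLabels = loadSimpData()
print(stumpClassify(datMat, 0, 1.3, 'lt').T)
# Expected, given the data above: [[-1.  1. -1. -1.  1.]]
# Only the first sample (true label +1) is misclassified, which matches the
# weighted error of 0.200 reported for dim 0, thresh 1.30, 'lt' in the log below.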



"""
buildStump()和三層for循環的詳細解釋《機器學習實戰》書籍中寫了。
此步驟利用單層決策樹尋找到分類錯誤率最低的門檻值即可。
Parameters:
        dataArr - 資料矩陣
        classLabels - 資料标簽
        D - 樣本權重
Returns:
        bestStump - 最佳單層決策樹資訊
        minError - 最小誤差
        bestClasEst - 最佳的分類結果
"""
def buildStump(dataArr,classLabels,D):
    dataMatrix = np.mat(dataArr); labelMat = np.mat(classLabels).T
    m,n = np.shape(dataMatrix)
    numSteps = 10.0; bestStump = {}; bestClasEst = np.mat(np.zeros((m,1)))
    minError = float('inf')  #initialize the minimum error to +infinity; it is updated below
    for i in range(n):  #iterate over all features
        rangeMin = dataMatrix[:,i].min(); rangeMax = dataMatrix[:,i].max() #find the minimum and maximum of this feature
        stepSize = (rangeMax - rangeMin) / numSteps  #compute the step size
        for j in range(-1, int(numSteps) + 1): #sweep all thresholds for the first feature, then the second, ...
            for inequal in ['lt', 'gt']: #try both directions; lt: less than, gt: greater than
                threshVal = (rangeMin + float(j) * stepSize) #compute the threshold
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal) #compute the predicted classes
                errArr = np.mat(np.ones((m,1)))             #initialize the error vector
                errArr[predictedVals == labelMat] = 0       #correctly classified samples are set to 0
                weightedError = D.T * errArr                #compute the weighted error: sample weights times the error vector
                print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError))
                if weightedError < minError:                #keep the split with the smallest weighted error
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst

"""
這一步驟,第一個弱分類器就訓練好了
"""
           

3. The AdaBoost Training Procedure Based on Decision Stumps

"""
整個實作的僞代碼如下:
對每次疊代:
利用buildStump()函數找到最佳的單層決策樹
将最佳單層決策樹 入到單層決策樹數組
計算alpha
計算新的權重向量D
更新累計類别估計值
如果錯誤率等于0,則退出循環
"""
def adaBoostTrainDS(dataArr, classLabels, numIt = 40):
    weakClassArr = []
    m = np.shape(dataArr)[0]
    D = np.mat(np.ones((m, 1)) / m)         #initialize the sample weights uniformly
    aggClassEst = np.mat(np.zeros((m,1)))
    for i in range(numIt):
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)     #build a decision stump; any other weak learner could be substituted here
        print("D:",D.T)
        #compute the weak learner weight alpha; max(error, 1e-16) keeps the denominator from being zero
        alpha = float(0.5 * np.log((1.0 - error) / max(error, 1e-16)))
        bestStump['alpha'] = alpha                       #store the weak learner weight
        weakClassArr.append(bestStump)                   #store the decision stump
        print("classEst: ", classEst.T)
        expon = np.multiply(-1 * alpha * np.mat(classLabels).T, classEst)    #compute the exponent -alpha * y * h(x)
        D = np.multiply(D, np.exp(expon))
        D = D / D.sum() #update the sample weights with the weight-update formula and normalize
        #compute the AdaBoost error and exit the loop when it reaches 0; this is where the ensemble idea comes in
        aggClassEst += alpha * classEst #accumulate the aggregate class estimates over every weak classifier trained so far
        print("aggClassEst: ", aggClassEst.T)
        aggErrors = np.multiply(np.sign(aggClassEst) != np.mat(classLabels).T, np.ones((m,1)))     #compute the errors: 1 where prediction and label differ, 0 where they agree
        errorRate = aggErrors.sum() / m
        print("total error: ", errorRate)
        if errorRate == 0.0: break          #error rate is 0, exit the loop
    return weakClassArr, aggClassEst

if __name__ == '__main__':
    dataArr,classLabels = loadSimpData()
    weakClassArr, aggClassEst = adaBoostTrainDS(dataArr, classLabels)
    print(weakClassArr)
    print(aggClassEst)
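A quick hand check of the first iteration against the log below: the best stump misclassifies one of the five equally weighted samples, so

\varepsilon_1 = \tfrac{1}{5} = 0.2,
\qquad
\alpha_1 = \tfrac{1}{2}\ln\frac{1-0.2}{0.2} = \tfrac{1}{2}\ln 4 \approx 0.6931,

which is exactly the magnitude of every entry in the first aggClassEst line. The misclassified sample's weight then grows to 0.2\,e^{\alpha_1}/Z_1 = 0.4/0.8 = 0.5 while the other four shrink to 0.125, matching the D vector printed at the start of the second iteration.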

           

4. Testing the Algorithm: Classification with AdaBoost

"""
Parameters:
        datToClass - 待分類樣例
        classifierArr - 訓練好的分類器
Returns:
        分類結果
"""
def adaClassify(datToClass,classifierArr):
    dataMatrix = np.mat(datToClass)
    m = np.shape(dataMatrix)[0]
    aggClassEst = np.mat(np.zeros((m,1)))
    for i in range(len(classifierArr)): #iterate over all classifiers and classify
        classEst = stumpClassify(dataMatrix, classifierArr[i]['dim'], classifierArr[i]['thresh'], classifierArr[i]['ineq'])            
        aggClassEst += classifierArr[i]['alpha'] * classEst
        print(aggClassEst)
    return np.sign(aggClassEst)
if __name__ == '__main__':
    dataArr,classLabels = loadSimpData()
    weakClassArr, aggClassEst = adaBoostTrainDS(dataArr, classLabels)
    print(adaClassify([[0,0],[5,5]], weakClassArr))
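As a hand check against the tail of the log below: the three stumps recovered during training (dim 0 / thresh 1.3, dim 1 / thresh 1.0, dim 0 / thresh 0.9, all 'lt') each predict +1 for the point [5, 5] and -1 for [0, 0], so the running aggClassEst simply accumulates the alphas from training:

[5, 5]:  +0.6931 -> +1.6661 -> +2.5620  => sign = +1
[0, 0]:  -0.6931 -> -1.6661 -> -2.5620  => sign = -1

These are the four 2-by-1 matrices printed at the very end of the log (one per stump, plus the final sign).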

           
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.400
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.600
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.200
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.800
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.600
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.400
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.200
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.800
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.400
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.600
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.600
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.400
D: [[0.2 0.2 0.2 0.2 0.2]]
classEst:  [[-1.  1. -1. -1.  1.]]
aggClassEst:  [[-0.69314718  0.69314718 -0.69314718 -0.69314718  0.69314718]]
total error:  0.2
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.250
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.750
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.625
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.375
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.500
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.500
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.750
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.250
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.125
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.875
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.250
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.750
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.750
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.250
D: [[0.5   0.125 0.125 0.125 0.125]]
classEst:  [[ 1.  1. -1. -1. -1.]]
aggClassEst:  [[ 0.27980789  1.66610226 -1.66610226 -1.66610226 -0.27980789]]
total error:  0.2
split: dim 0, thresh 0.90, thresh ineqal: lt, the weighted error is 0.143
split: dim 0, thresh 0.90, thresh ineqal: gt, the weighted error is 0.857
split: dim 0, thresh 1.00, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.00, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.10, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.10, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.20, thresh ineqal: lt, the weighted error is 0.357
split: dim 0, thresh 1.20, thresh ineqal: gt, the weighted error is 0.643
split: dim 0, thresh 1.30, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.30, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.40, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.40, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.50, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.50, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.60, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.60, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.70, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.70, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.80, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.80, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 1.90, thresh ineqal: lt, the weighted error is 0.286
split: dim 0, thresh 1.90, thresh ineqal: gt, the weighted error is 0.714
split: dim 0, thresh 2.00, thresh ineqal: lt, the weighted error is 0.857
split: dim 0, thresh 2.00, thresh ineqal: gt, the weighted error is 0.143
split: dim 1, thresh 0.89, thresh ineqal: lt, the weighted error is 0.143
split: dim 1, thresh 0.89, thresh ineqal: gt, the weighted error is 0.857
split: dim 1, thresh 1.00, thresh ineqal: lt, the weighted error is 0.500
split: dim 1, thresh 1.00, thresh ineqal: gt, the weighted error is 0.500
split: dim 1, thresh 1.11, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.11, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.22, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.22, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.33, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.33, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.44, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.44, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.55, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.55, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.66, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.66, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.77, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.77, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.88, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.88, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 1.99, thresh ineqal: lt, the weighted error is 0.571
split: dim 1, thresh 1.99, thresh ineqal: gt, the weighted error is 0.429
split: dim 1, thresh 2.10, thresh ineqal: lt, the weighted error is 0.857
split: dim 1, thresh 2.10, thresh ineqal: gt, the weighted error is 0.143
D: [[0.28571429 0.07142857 0.07142857 0.07142857 0.5       ]]
classEst:  [[1. 1. 1. 1. 1.]]
aggClassEst:  [[ 1.17568763  2.56198199 -0.77022252 -0.77022252  0.61607184]]
total error:  0.0
[[-0.69314718]
 [ 0.69314718]]
[[-1.66610226]
 [ 1.66610226]]
[[-2.56198199]
 [ 2.56198199]]
[[-1.]
 [ 1.]]
           
