天天看點

機器學習實戰之線性回歸

機器學習實戰之線性回歸

之前我們學習的機器學習算法都是屬于分類算法,也就是預測值是離散值。當預測值為連續值時,就需要使用回歸算法。本文将介紹線性回歸的原理和代碼實作。

線性回歸原理與推導

如圖所示,這時一組二維的資料,我們先想想如何通過一條直線較好的拟合這些散點了?直白的說:盡量讓拟合的直線穿過這些散點(這些點離拟合直線很近)。

機器學習實戰之線性回歸
目标函數

要使這些點離拟合直線很近,我們需要用數學公式來表示。首先,我們要求的直線公式為:Y = XTw。我們這裡要求的就是這個w向量(類似于logistic回歸)。誤差最小,也就是預測值y和真實值的y的內插補點小,我們這裡采用平方誤差:

機器學習實戰之線性回歸
求解

我們所需要做的就是讓這個平方誤差最小即可,那就對w求導,最後w的計算公式為:

機器學習實戰之線性回歸

我們稱這個方法為OLS,也就是“普通最小二乘法”

線性回歸實踐

資料情況

我們首先讀入資料并用matplotlib庫來顯示這些資料。

def loadDataSet(filename):
    numFeat = len(open(filename).readline().split('\t')) - 1
    dataMat = [];labelMat = []
    fr = open(filename)
    for line in fr.readlines():
        lineArr = []
        curLine = line.strip().split('\t')
        for i in range(numFeat):
            lineArr.append(float(curLine[i]))
        dataMat.append(lineArr)
        labelMat.append(float(curLine[-1]))
    return dataMat, labelMat
           
機器學習實戰之線性回歸
回歸算法

這裡直接求w就行,然後對直線進行可視化。

def standRegres(Xarr,yarr):
    X = mat(Xarr);y = mat(yarr).T
    XTX = X.T * X
    if linalg.det(XTX) == 0:
        print('不能求逆')
        return
    w = XTX.I * (X.T*y)
    return w
           
機器學習實戰之線性回歸

算法優缺點

  • 優點:易于了解和計算
  • 缺點:精度不高