天天看點

機器學習實戰之線性回歸

線性回歸原理與推導

如圖所示,這時一組二維的資料,我們先想想如何通過一條直線較好的拟合這些散點了?直白的說:盡量讓拟合的直線穿過這些散點(這些點離拟合直線很近)。

機器學習實戰之線性回歸

目标函數

要使這些點離拟合直線很近,我們需要用數學公式來表示。首先,我們要求的直線公式為:Y = XTw。我們這裡要求的就是這個w向量(類似于logistic回歸)。誤差最小,也就是預測值y和真實值的y的內插補點小,我們這裡采用平方誤差:

機器學習實戰之線性回歸

求解

我們所需要做的就是讓這個平方誤差最小即可,那就對w求導,最後w的計算公式為:

機器學習實戰之線性回歸

我們稱這個方法為OLS,也就是“普通最小二乘法”

線性回歸實踐

資料情況

我們首先讀入資料并用matplotlib庫來顯示這些資料。

def loadDataSet(filename):
    numFeat = len(open(filename).readline().split('\t')) - 1
    dataMat = [];labelMat = []
    fr = open(filename)
    for line in fr.readlines():
        lineArr = []
        curLine = line.strip().split('\t')
        for i in range(numFeat):
            lineArr.append(float(curLine[i]))
        dataMat.append(lineArr)
        labelMat.append(float(curLine[-1]))
    return dataMat, labelMat
           
機器學習實戰之線性回歸

回歸算法

這裡直接求w就行,然後對直線進行可視化。

def standRegres(Xarr,yarr):
        X = mat(Xarr);y = mat(yarr).T
        XTX = X.T * X
        if linalg.det(XTX) == 0:
            print('不能求逆')
            return
        w = XTX.I * (X.T*y)
        return w
           

算法優缺點

優點:易于了解和計算

缺點:精度不高

原文釋出時間為:2018-07-01

本文作者:羅羅攀

本文來自雲栖社群合作夥伴“

Python愛好者社群

”,了解相關資訊可以關注“

”。