![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIn5GcuY2Y0kTNkJGOkRDN2czM2IDZ2ImZyQmMyMDZwUDZjNGZfdWbp9CXt92Yu4GZjlGbh5SZslmZxl3Lc9CX6MHc0RHaiojIsJye.png)
“回歸”與“樹”
在講解樹回歸之前,我們看看回歸和樹巧妙結合的原因。
線性回歸的弊端
- 線性回歸需要拟合所有樣本點,在特征多且特征關系複雜時,建構全局模型的想法就顯得太難。
- 實際生活中,問題很大程度上不是線性的,而是非線性的,是以線性回歸的很容易欠拟合。
傳統決策樹弊端與改進
決策樹可以解決資料的非線性問題,而且直覺易懂,是否可以通過決策樹來實作回歸任務?
我們來回顧下之前講過的決策樹方法,其在劃分子集的時候使用的方法是資訊增益(我們也叫ID3方法),其方法隻針對标稱型(離散型)資料有效,很難用于回歸;而且ID3算法切分過于迅速,容易過拟合,例如:一個特征有4個值,資料就會被切為四份,切分過後的特征在後面的過程中不再起作用。
CART(分類回歸樹)算法可以解決掉ID3的問題,該算法可用于分類和回歸。我們來看看針對ID3算法的問題,CART算法是怎樣解決的。
- 資訊增益無法切分連續型資料,如何計算連續型資料的混亂程度?其實,連續型的資料計算混亂程度很簡單,根本不需要資訊熵的理論。我們隻需要計算平方誤差的總值即可(先計算資料的均值,然後計算每條資料到均值的內插補點,進行平方求和)。
- ID3方法切分太快,CART算法采用二進制切分。
回歸樹
基于CART算法,當葉節點是分類值,就會是分類算法;如果是常數值(也就是回歸需要預測的值),就可以實作回歸算法。這裡的常數值的求解很簡單,就是該劃分資料的均值。
資料情況
首先,利用代碼帶入資料,資料情況如圖所示。
from numpy import *
def loadDataSet(filename):
dataMat = []
fr = open(filename)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float,curLine))
dataMat.append(fltLine)
return dataMat
代碼
其實CART算法直覺(代碼卻比較多。。。),其實隻用做兩件事:切分資料和構造樹。我們以這個資料為例:首先切分資料,找到一個中心點(平方誤差的總值最小),這樣就完成了劃分(左下和右上),然後構造樹(求左下和右上的均值為葉子節點)。我們來看代碼:
def regLeaf(dataSet):
return mean(dataSet[:,-1])
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0];tolN = ops[1]
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndex = 0;bestValue = 0
for featIndex in range(n-1):
for splitVal in set((dataSet[:,featIndex].T.tolist())[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0,mat1
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
看下結果,和我想的是一緻的。
模型樹
回歸樹的葉節點是常數值,而模型樹的葉節點是一個回歸方程。
讀入資料進行可視化,你會發現,這種資料如果用回歸樹拟合效果不好,如果切分為兩段,每段是一個回歸方程,就可以很好的對資料進行拟合。
前面的代碼大部分是不變的,隻需要少量修改就可以完成模型樹。
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))
def linearSolve(dataSet):
m, n = shape(dataSet)
X = mat(ones((m, n)))
Y = mat(ones((m, 1)))
X[:, 1: n] = dataSet[:, 0: n-1]
Y = dataSet[:, -1]
xTx = X.T * X
if linalg.det(xTx) == 0.0:
raise NameError('錯誤')
ws = xTx.I * (X.T * Y)
return ws, X, Y
結果如圖所示:
算法優缺點
- 優點:可對複雜資料進行模組化
- 缺點:容易過拟合