Output
ML & Decision Trees: training decision trees on a simple regression problem (DIY dataset + seven decision trees of depth 1–7, each evaluated with 10-fold cross-validation)
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLicmbw5iM2gTN4ETY2QTZ2UTNhdDO4EjY0M2MxQTMklzNjZWZm9CX5d2bs92Yl1iclB3bsVmdlR2LcNWaw9CXt92Yu4GZjlGbh5yYjV3Lc9CX6MHc0RHaiojIsJye.png)
Design approach
Build a DIY regression dataset; then, for each tree depth from 1 to 7, run 10-fold cross-validation and record the out-of-sample MSE so that the depths (model complexities) can be compared.
Core code
```python
from sklearn.tree import DecisionTreeRegressor

for iDepth in depthList:
    for ixval in range(nxval):
        # every nxval-th row forms the held-out fold; the rest train
        idxTest = [a for a in range(nrow) if a % nxval == ixval]
        idxTrain = [a for a in range(nrow) if a % nxval != ixval]
        xTrain = [x[r] for r in idxTrain]
        xTest = [x[r] for r in idxTest]
        yTrain = [y[r] for r in idxTrain]
        yTest = [y[r] for r in idxTest]
        treeModel = DecisionTreeRegressor(max_depth=iDepth)
        treeModel.fit(xTrain, yTrain)
        treePrediction = treeModel.predict(xTest)
        error = [yTest[r] - treePrediction[r] for r in range(len(yTest))]
        # accumulate out-of-sample squared error across the folds
        if ixval == 0:
            oosErrors = sum(e * e for e in error)
        else:
            oosErrors += sum(e * e for e in error)
    # mean squared error for this depth, averaged over all rows
    mse = oosErrors / nrow
    xvalMSE.append(mse)
```
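The snippet above assumes `x`, `y`, `nrow`, `nxval`, `depthList`, and `xvalMSE` are already defined. A minimal self-contained sketch of the whole experiment follows; the noisy linear 1-D dataset here is an assumption standing in for the post's DIY data, and the variable names mirror the core code:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# DIY dataset (assumed): 200 noisy samples of y = x on [-1, 1]
rng = np.random.RandomState(1)
x = np.sort(rng.uniform(-1.0, 1.0, size=200)).reshape(-1, 1)
y = x.ravel() + rng.normal(scale=0.1, size=200)

nrow = len(x)
nxval = 10               # 10-fold cross-validation
depthList = range(1, 8)  # tree depths 1..7
xvalMSE = []

for iDepth in depthList:
    oosErrors = 0.0
    for ixval in range(nxval):
        # deterministic fold assignment: every nxval-th row is held out
        idxTest = [a for a in range(nrow) if a % nxval == ixval]
        idxTrain = [a for a in range(nrow) if a % nxval != ixval]
        treeModel = DecisionTreeRegressor(max_depth=iDepth)
        treeModel.fit(x[idxTrain], y[idxTrain])
        pred = treeModel.predict(x[idxTest])
        oosErrors += float(np.sum((y[idxTest] - pred) ** 2))
    # average the accumulated squared error over all rows
    xvalMSE.append(oosErrors / nrow)

bestDepth = depthList[int(np.argmin(xvalMSE))]
print("cross-validated MSE per depth:", xvalMSE)
print("best depth:", bestDepth)
```

Note that `DecisionTreeRegressor.fit` expects a 2-D feature array, which is why `x` is reshaped to one column; the original core code implicitly relies on each row of `x` already being a feature vector.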