天天看點

k折交叉驗證 k-fold cross-validation

文章目錄

    • k折交叉驗證
    • k值的确定
    • 執行個體
    • 使用scikit-learn進行交叉驗證

交叉驗證是用來評估機器學習方法的有效性的統計學方法,可以使用有限的樣本數量來評估模型對于驗證集或測試集資料的效果。

k折交叉驗證

參數 k k k表示,将給定的樣本資料分割成 k k k組。 k = 10 k=10 k=10時,稱為10折交叉驗證。

流程如下:

  1. 将資料集随機打亂。Shuffle the dataset randomly.
  2. 将資料集随機分割為 k k k組。
  3. 對于每一個組,進行如下操作:
    1. 将這一個組的資料當做測試集
    2. 剩餘的 k − 1 k-1 k−1個組的資料當做訓練集
    3. 使用測試集訓練模型,并在測試集上進行評測
    4. 保留評測的分數,抛棄模型
  4. 使用 k k k次評測分數的平均值來總結模型的性能,有時也會統計 k k k次評測分數的方差。

總結:每一個樣本都做了一次測試集,做了 k − 1 k-1 k−1次訓練集。

k值的确定

k k k值必須仔細确定。不合适的 k k k值會導緻不能準确評估模型的性能,會得出high variance或high bias的結果。There is a bias-variance trade-off associated with the choice of k in k-fold cross-validation

選擇 k k k值的幾點政策:

  1. 資料的代表性: k k k值必須使得每一組訓練集和測試集中的樣本數量都足夠大,使其在統計學意義上可以代表更廣泛的資料。
  2. k = 10 k=10 k=10:這是一個經過廣泛的實驗得到的一個經驗值。所得的結果會有較低的偏差和适量的方差 (low bias and modest variance)
  3. k = n k=n k=n:其中 n n n是樣本的數量。這樣一來,使得一個樣本作為測試集,其他樣本作為訓練集。是以也被稱作 leave-one-out 交叉驗證 (LOOCV)。

總結: k k k沒有固定值,不過通常取值5或10。随着 k k k值的增大,訓練集的大小和采樣子集之間的差異變小,對模型評估的偏差也會減小。 k = 10 k=10 k=10總是一個較優的選擇。

執行個體

給定一個樣本資料集如下:

標明 k = 3 k=3 k=3,則樣本分為3組,每組2個資料。

Fold1: [0.5, 0.2]
Fold2: [0.1, 0.3]
Fold3: [0.4, 0.6]
           

接下來訓練三個模型,并使用對應的測試集進行評測。

Model1: Trained on Fold1 + Fold2, Tested on Fold3

Model2: Trained on Fold2 + Fold3, Tested on Fold1

Model3: Trained on Fold1 + Fold3, Tested on Fold2

使用scikit-learn進行交叉驗證

使用KFold類:

其中:3表示 k = 3 k=3 k=3,True表示随機打亂資料集,1是随機數的種子seed。

接下來使用split函數将樣本進行分組:

# enumerate splits
for train, test in kfold.split(data):
	print('train: %s, test: %s' % (data[train], data[test]))
           

其中的train和test都是原始資料在data數組中的索引值。

完整的代碼如下:

# scikit-learn k-fold cross-validation
from numpy import array
from sklearn.model_selection import KFold
# data sample
data = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
# prepare cross validation
kfold = KFold(3, True, 1)
# enumerate splits
for train, test in kfold.split(data):
	print('train: %s, test: %s' % (data[train], data[test]))
           

輸出結果為:

# enumerate splits
train: [0.1 0.4 0.5 0.6], test: [0.2 0.3]
train: [0.2 0.3 0.4 0.6], test: [0.1 0.5]
train: [0.1 0.2 0.3 0.5], test: [0.4 0.6]
           

參考連結:

https://machinelearningmastery.com/k-fold-cross-validation/

繼續閱讀