最近有一些關于資料是否是新石油的争論。 無論如何,為我們的機器學習工作擷取訓練資料可能是昂貴的(在工時、許可費、裝置運作時間等方面)。 是以,機器學習項目中的一個關鍵問題是确定需要多少訓練資料才能實作特定的性能目标(即分類器準确性)。
在這篇文章中,我們将在從回歸分析到深度學習等領域對有關訓練資料大小的實證和研究文獻結果進行快速但廣泛的審查。 訓練資料大小問題在文獻中也稱為樣本複雜度。 具體來說,我們将:
- 說明回歸任務和計算機視覺任務訓練資料的經驗範圍;
- 給定統計檢驗的檢驗效能,讨論如何确定樣本數量。這是一個統計學的話題,然而,由于它與确定機器學習訓練資料量密切相關,是以也将包含在本讨論中;
- 展示統計理論學習的結果,說明是什麼決定了訓練資料的多少;
- 給出下面問題的答案:随着訓練資料的增加,模型性能是否會繼續改善?在深度學習的情況下又會如何?
- 提出一種在分類任務中确定訓練資料量的方法;
- 最後,我們将回答這個問題:增加訓練資料是處理資料不平衡的最佳方式嗎?
1、訓練資料大小的經驗界限
讓我們首先根據我們使用的模型類型讨論一些廣泛使用的經驗方法來确定訓練資料的大小:
回歸分析:根據 1/10 的經驗規則,每個預測因子 [3] 需要 10 個樣例。在 [4] 中讨論了這種方法的其他版本,比如用 1/20 來處理回歸系數減小的問題,在 [5] 中提出了一個令人興奮的二進制邏輯回歸變量。具體地說,作者通過考慮預測變量的數量、總體樣本量以及正樣本量/總體樣本量的比例來估計訓練資料的多少。
計算機視覺:對于使用深度學習的圖像分類,經驗法則是每一個分類需要 1000 幅圖像,如果使用預訓練的模型 [6],這個需求可以顯著下降。
2、假設檢驗中樣本大小的确定
假設檢驗是資料科學家可以用來檢驗人群差異、确定新藥效果等的工具之一。考慮到檢驗的功效,在這裡通常需要确定樣本量。
讓我們考慮這個例子:一家科技巨頭搬到了 A 市,那裡的房價大幅上漲。 記者想知道,較高價的電梯大廈的新平均價格是多少。 鑒于較高價的電梯大廈價格的标準偏差為 60K,可接受的誤差幅度為 10K,他應該平均多少較高價的電梯大廈銷售價格,有 95% 的置信度? 相應的公式如下所示; N 是他需要的樣本量,1.96 是對應于 95% 置信度的标準正态分布的數字。
根據上述等式,記者将需要考慮大約 138 套較高價的電梯大廈的價格。上述公式根據具體測試而變化,但它始終包括置信區間、可接受的誤差幅度和标準偏差的度量。
3、訓練資料大小的統計學習理論
讓我們首先介紹一下著名的 Vapnik-Chevronenkis 次元 ( VC 維) [8]。VC 維是模型複雜度的度量,模型越複雜,VC 維越大。在下一段中,我們将介紹一個用 VC 表示訓練資料大小的公式。
首先,讓我們看一個經常用于展示 VC 維如何計算的例子:假設我們的分類器是二維平面上的一條直線,有 3 個點需要分類。
無論這 3 個點的正/負組合是什麼(都是正的、2個正的、1個正的,等等),一條直線都可以正确地分類/區分這些正樣本和負樣本。
我們說線性分類器可以區分所有的點,是以,它的 VC 維至少是 3,又因為我們可以找到4個不能被直線準确區分的點的例子,是以我們說線性分類器的 VC 維正好是3。結果表明,訓練資料大小 N 是 VC 的函數 [8]:
從 VC 維估計訓練資料的大小
其中 d 為失效機率,epsilon 為學習誤差。是以,正如 [9] 所指出的,學習所需的資料量取決于模型的複雜度。一個明顯的例子是衆所周知的神經網絡對訓練資料的貪婪,因為它們非常複雜。
4、随着訓練資料的增加,模型性能會繼續提高嗎?在深度學習的情況下又會怎樣?
學習曲線
上圖展示了在傳統機器學習 [10] 算法(回歸等)和深度學習 [11] 的情況下,機器學習算法的性能随着資料量的增加而如何變化。
具體來說,對于傳統的機器學習算法,性能是按照幂律增長的,一段時間後趨于平穩。 文獻 [12]-[16],[18] 的研究展示了對于深度學習,随着資料量的增加性能如何變化。
圖1顯示了目前大多數研究的共識:對于深度學習,根據幂次定律,性能會随着資料量的增加而增加。
例如,在文獻 [13] 中,作者使用深度學習技術對3億幅圖像進行分類,他們發現随着訓練資料的增加模型性能呈對數增長。
讓我們看看另一些在深度學習領域值得注意的,與上述沖突的結果。具體來說,在文獻 [15] 中,作者使用卷積網絡來處理 1 億張 Flickr 圖檔和标題的資料集。
對于訓練集的資料量,他們報告說,模型性能會随着資料量的增加而增加,然而,在 5000 萬張圖檔之後,它就停滞不前了。
在文獻[16]中,作者發現圖像分類準确度随着訓練集的增大而增加,然而,模型的魯棒性在超過與模型特定相關的某一點後便開始下降。
5、在分類任務中确定訓練資料量的方法
衆所周知的學習曲線,通常是誤差與訓練資料量的關系圖。[17] 和 [18] 是了解機器學習中學習曲線以及它們如何随着偏差或方差的增加而變化的參考資料。Python 在 scikit-learn [17] 也中提供了一個學習曲線的函數。
在分類任務中,我們通常使用一個稍微變化的學習曲線形式:分類準确度與訓練資料量的關系圖。
确定訓練資料量的方法很簡單:首先根據任務确定一個學習曲線形式,然後簡單地在圖上找到所需分類準确度對應的點。例如,在文獻 [19]、[20] 中,作者在醫學領域中使用了學習曲線法,并用幂律函數表示:
學習曲線方程
上式中 y 為分類準确度,x 為訓練資料,b1、b2 分别對應學習率和衰減率。參數的設定随問題的不同而變化,可以用非線性回歸或權重非線性回歸對它們進行估計。
6、增加訓練資料是處理資料不平衡的最好方法嗎?
這個問題在文獻 [9] 中得到了解決。作者提出了一個有趣的觀點:在資料不平衡的情況下,準确性并不是衡量分類器性能的最佳名額。
原因很直覺:讓我們假設負樣本是占絕大多數,然後如果我們在大部分時間裡都預測為負樣本,就可以達到很高的準确度。
相反,他們建議準确度和召回率(也稱為靈敏度)是衡量資料不平衡性能的最合适名額。除了上述明顯的準确度問題外,作者還認為,測量精度對不平衡區域的内在影響更大。
例如,在醫院的警報系統中,高精确度意味着當警報響起時,病人很可能确實有問題。
選擇适當的性能測量方法,作者比較了在 imbalanced-learn (Python scikit-learn 庫)中的不平衡校正方法和簡單的使用一個更大的訓練資料集。
具體地說,他們在一個 50,000 個樣本的藥物相關的資料集上,使用 imbalance-correction 中的K近鄰方法進行資料不平衡校正,這些不平衡校正技術包括欠采樣、過采樣和內建學習等,然後在與原資料集相近的 100 萬資料集上訓練了一個神經網絡。
作者重複實驗了 200 次,最終的結論簡單而深刻:在測量準确度和召回率方面,沒有任何一種不平衡校正技術可以與增加更多的訓練資料相媲美。
轉載自:https://towardsdatascience.com/how-do-you-know-you-have-enough-training-data-ad9b1fd679ee