在人工智能算法的實際應用場景中,不可避免地會出現訓練資料誤标現象,即訓練資料集上存在标簽噪聲。這會降低所訓模型的泛化能力。尤其是對于深度神經網絡這種描述能力極強的模型,标簽噪聲對推理精度的影響甚或是災難性的。論文證明了一個簡單的兩層網絡就能記住所有随機配置設定的标簽。
本篇博文将介紹一種對标簽噪聲魯棒的損失函數,即
General Cross Entropy
(
GCE
)。這種損失函數在2018年的
NIPS
會議論文中被提出,其內建了
Mean Absolute Error
(
MAE
)損失函數的噪聲魯棒性,以及傳統的
Cross Entropy
損失函數的訓練高效性。
博文主要内容如下:第一部分将論述
MAE
為何是噪聲魯棒的。第二部分将介紹
GCE
的具體表達;其被稱為廣義交叉熵的原因也将在這一部分揭示。第三部分給出了複現結果。最後将給出相關的參考連結。
一、 MAE
為何是噪聲魯棒的
MAE
一句話解釋,因為
MAE
是一種
Symmetric Loss
。那麼原來的一個問題就變成了兩個問題,(1) 什麼是
Symmetric Loss
;(2) 為什麼
Symmetric Loss
是噪聲魯棒的。
為了讨論第一個問題,我們考慮一個十分簡單的情形:隻包含一個訓練樣本的二分類問題。假設訓練樣本為 { x , y } \{x, y\} {x,y},其中 y ∈ { 0 , 1 } y\in\{0, 1\} y∈{0,1}; f θ ( x ) f_\theta(x) fθ(x)為模型輸出, θ \theta θ為待優化的參數,損失函數為 l l l。沒有标簽噪聲的情況下,待優化的目标為 l [ f θ ( x ) , y ] l[f_\theta(x),y] l[fθ(x),y]。考慮存在标簽噪聲的情形,那麼樣本 x x x有一定機率 ρ \rho ρ被誤标為 1 − y 1-y 1−y,那麼實際上的優化目标是 ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y] (1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]。
如果
arg min θ l [ f θ ( x ) , y ] = arg min θ { ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] } ( 1 ) \argmin_\theta l[f_\theta(x),y]=\argmin_\theta \{(1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]\}\qquad\qquad (1) θargminl[fθ(x),y]=θargmin{(1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]}(1)
那麼意味着無論有無噪聲,該優化問題都會得到同樣的解。這時候損失函數 l l l就是噪聲魯棒的。
( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] = ( 1 − 2 ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ { l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] } ( 2 ) (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]=(1-2\rho)\cdot l[f_\theta(x),y]+\rho\cdot\{l[f_\theta(x),y]+l[f_\theta(x),1-y]\}\qquad\qquad(2) (1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]=(1−2ρ)⋅l[fθ(x),y]+ρ⋅{l[fθ(x),y]+l[fθ(x),1−y]}(2)
整理後的第一項是無噪聲情況下的優化目标的一個固定倍數。而第二項是當樣本标簽等機率取遍所有可能值時,所産生的損失值。
敲黑闆,如果
l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] = C ( 3 ) l[f_\theta(x),y]+l[f_\theta(x),1-y]=C\qquad\qquad (3) l[fθ(x),y]+l[fθ(x),1−y]=C(3)
其中 C C C是常數時,那麼 l l l就是
Symmetric Loss
。這時候,(2)式右端的第二項對 θ \theta θ的優化不造成影響,是以(1)式就會成立,是以
Symmetric Loss
是對噪聲魯棒的。這裡的
Symmetric
是一種輪換對稱的含義,就是指當樣本等機率取遍所有可能标簽時,産生的損失值是常值。
對
MAE
,考慮到 f θ ( x ) ∈ [ 0 , 1 ] f_\theta(x)\in[0,1] fθ(x)∈[0,1],那麼
∣ y − f θ ( x ) ∣ + ∣ 1 − y − f θ ( x ) ∣ = ∣ 0 − f θ ( x ) ∣ + ∣ 1 − f θ ( x ) ∣ = 1 |y-f_\theta(x)|+|1-y-f_\theta(x)|=|0-f_\theta(x)|+|1-f_\theta(x)|=1 ∣y−fθ(x)∣+∣1−y−fθ(x)∣=∣0−fθ(x)∣+∣1−fθ(x)∣=1
是以
MAE
是
Symmetric Loss
,是以是噪聲魯棒的。
更嚴謹的論述,可以參考論文Making Risk Minimization Tolerant to Label Noise。
二、 GCE
GCE
在明了
MAE
為什麼對噪聲魯棒的機理後,我們再來看看
Cross Entropy Loss
為什麼是不魯棒的。如果還是用一句話來解釋,就是因為
CE
是無界的。考慮
(3)
式所表達的魯棒性條件,各個損失項非負,且其和為定值,那麼各個損失項必然是有界的。但在
CE
中,假設 f θ ( x ) f_\theta(x) fθ(x)表示樣本 x x x屬于類别
1
的機率,那麼損失值是 − log f θ ( x ) -\log f_\theta(x) −logfθ(x)或 − log ( 1 − f θ ( x ) ) -\log(1-f_\theta(x)) −log(1−fθ(x))。當 f θ ( x ) f_\theta(x) fθ(x)接近于
或者
1
時,損失值會非常大。這意味着模型會花費更多的功夫在這些樣本上。
這一特性是個雙刃劍。如果我們所有的訓練樣本的标注都是正确的,那麼這一特性使得模型關注于那些容易被錯判的樣本,是以訓練過程會十分高效。但如果訓練樣本存在标簽噪聲,這一特性将使得模型過度關注于誤标的樣本,以緻最終會得到一個過拟合到誤标樣本的模型。
好,那我們直接使用
MAE
不就行了麼。Hmmm…實際上
MAE
實戰效果并不好。因為
MAE
是平等地對待各個樣本的,是以其收斂速度比較慢 (具體的解釋可以參考
GCE
論文)。
那麼一個自然的想法就是能不能在
MAE
的噪聲魯棒性和
CE
的快速收斂性之間做一個折中。這就是
GCE
的基礎想法。具體做法是将
CE
中的 − log ( f θ ( x ) ) -\log(f_\theta(x)) −log(fθ(x))項,替換成一個指數項 1 − f θ q ( x ) q \frac{1-f_\theta^q(x)}{q} q1−fθq(x),其中幂次 q q q為一個超參數,取值範圍為 ( 0 , 1 ] (0,1] (0,1]。當 q = 1 q=1 q=1時,該指數項就蛻變為
MAE
的形式;當 q → 0 q\rightarrow 0 q→0時,由洛必達法則,該指數項将蛻變為
CE
的形式。是以 q q q控制着在
MAE
和
CE
之間的折中程度。原論文中僅給出了 q q q的經驗值,并未給出具體的選取方法。
為進一步抑制誤标資料的影響,原論文還給出了一種加強版的損失函數,該函數融合了樣本選擇的思想。其核心想法是,當訓練到一定程度後,模型已經擷取了關于資料的正确分類的模式資訊,此時,如果模型在樣本上輸出的機率值比較小,那麼這些樣本很可能是誤标樣本。是以,若将這些樣本從訓練資料中剔除,那麼接下來的訓練可能會更準确。具體做法是:假設在訓練過程中
validation
資料集上準确率最好的模型為 M b e s t \mathcal{M}_{best} Mbest,則在模型訓練一定次數後,每隔一定數目的
epoch
,将 M b e s t \mathcal{M}_{best} Mbest預測的機率值小于 k k k的樣本丢棄,而在其餘的資料上進行訓練。這裡又引入了一個新的超參數 k k k。同樣的,原論文隻給出了一些資料集上的經驗值,并未給明确的設定方法。
三、 Results
Results
以下為在
CIFAR-10
資料集上,傳統的交叉熵損失與廣義交叉熵損失的結果對比。所用參數與論文中一緻:
training
/
validation
/
testing
資料集大小分别為
45000
/
5000
/
10000
;模型選用
ResNet34
,注意其中第一個卷積層的參數為
kernel_size
=3,
stride
=1,而用于
ImageNet
資料集時
kernel_size
=7,
stride
=2,這是因為
CIFAR-10
的圖檔尺寸比較小,另外去除了第一個
pooling
層; q = 0.7 , k = 0.5 q=0.7, k=0.5 q=0.7,k=0.5;總共訓練120個
epochs
,初始學習速率為
1E-2
,第
40
、第
80
個
epoch
學習速率遞降
10
倍;從第
40
個
epoch
開始,每
10
個
epoch
篩選一遍資料。每組實驗運作了5次,結果表示為分類準确率的 μ ± σ \mu\pm\sigma μ±σ的形式。複現結果與論文所示大緻相當,甚至還要略優于論文中的結果。其中可以看到,存在
pair wise
類型的标簽噪聲時,準确率的提升不是特别明顯。
表一. 不同方法的對比
參考
- Generalized Cross Entropy:
的論文。GCE
- pytorch-Truncated-Loss.git: 基于
實作的pytorch
。GCE