天天看點

對标簽噪聲魯棒的廣義交叉熵損失 (Generalized Cross Entropy)

在人工智能算法的實際應用場景中,不可避免地會出現訓練資料誤标現象,即訓練資料集上存在标簽噪聲。這會降低所訓模型的泛化能力。尤其是對于深度神經網絡這種描述能力極強的模型,标簽噪聲對推理精度的影響甚或是災難性的。論文證明了一個簡單的兩層網絡就能記住所有随機配置設定的标簽。

本篇博文将介紹一種對标簽噪聲魯棒的損失函數,即

General Cross Entropy

(

GCE

)。這種損失函數在2018年的

NIPS

會議論文中被提出,其內建了

Mean Absolute Error

(

MAE

)損失函數的噪聲魯棒性,以及傳統的

Cross Entropy

損失函數的訓練高效性。

博文主要内容如下:第一部分将論述

MAE

為何是噪聲魯棒的。第二部分将介紹

GCE

的具體表達;其被稱為廣義交叉熵的原因也将在這一部分揭示。第三部分給出了複現結果。最後将給出相關的參考連結。

一、

MAE

為何是噪聲魯棒的

一句話解釋,因為

MAE

是一種

Symmetric Loss

。那麼原來的一個問題就變成了兩個問題,(1) 什麼是

Symmetric Loss

;(2) 為什麼

Symmetric Loss

是噪聲魯棒的。

為了讨論第一個問題,我們考慮一個十分簡單的情形:隻包含一個訓練樣本的二分類問題。假設訓練樣本為 { x , y } \{x, y\} {x,y},其中 y ∈ { 0 , 1 } y\in\{0, 1\} y∈{0,1}; f θ ( x ) f_\theta(x) fθ​(x)為模型輸出, θ \theta θ為待優化的參數,損失函數為 l l l。沒有标簽噪聲的情況下,待優化的目标為 l [ f θ ( x ) , y ] l[f_\theta(x),y] l[fθ​(x),y]。考慮存在标簽噪聲的情形,那麼樣本 x x x有一定機率 ρ \rho ρ被誤标為 1 − y 1-y 1−y,那麼實際上的優化目标是 ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y] (1−ρ)⋅l[fθ​(x),y]+ρ⋅l[fθ​(x),1−y]。

如果

arg min ⁡ θ l [ f θ ( x ) , y ] = arg min ⁡ θ { ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] } ( 1 ) \argmin_\theta l[f_\theta(x),y]=\argmin_\theta \{(1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]\}\qquad\qquad (1) θargmin​l[fθ​(x),y]=θargmin​{(1−ρ)⋅l[fθ​(x),y]+ρ⋅l[fθ​(x),1−y]}(1)

那麼意味着無論有無噪聲,該優化問題都會得到同樣的解。這時候損失函數 l l l就是噪聲魯棒的。

( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] = ( 1 − 2 ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ { l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] } ( 2 ) (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]=(1-2\rho)\cdot l[f_\theta(x),y]+\rho\cdot\{l[f_\theta(x),y]+l[f_\theta(x),1-y]\}\qquad\qquad(2) (1−ρ)⋅l[fθ​(x),y]+ρ⋅l[fθ​(x),1−y]=(1−2ρ)⋅l[fθ​(x),y]+ρ⋅{l[fθ​(x),y]+l[fθ​(x),1−y]}(2)

整理後的第一項是無噪聲情況下的優化目标的一個固定倍數。而第二項是當樣本标簽等機率取遍所有可能值時,所産生的損失值。

敲黑闆,如果

l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] = C ( 3 ) l[f_\theta(x),y]+l[f_\theta(x),1-y]=C\qquad\qquad (3) l[fθ​(x),y]+l[fθ​(x),1−y]=C(3)

其中 C C C是常數時,那麼 l l l就是

Symmetric Loss

。這時候,(2)式右端的第二項對 θ \theta θ的優化不造成影響,是以(1)式就會成立,是以

Symmetric Loss

是對噪聲魯棒的。這裡的

Symmetric

是一種輪換對稱的含義,就是指當樣本等機率取遍所有可能标簽時,産生的損失值是常值。

MAE

,考慮到 f θ ( x ) ∈ [ 0 , 1 ] f_\theta(x)\in[0,1] fθ​(x)∈[0,1],那麼

∣ y − f θ ( x ) ∣ + ∣ 1 − y − f θ ( x ) ∣ = ∣ 0 − f θ ( x ) ∣ + ∣ 1 − f θ ( x ) ∣ = 1 |y-f_\theta(x)|+|1-y-f_\theta(x)|=|0-f_\theta(x)|+|1-f_\theta(x)|=1 ∣y−fθ​(x)∣+∣1−y−fθ​(x)∣=∣0−fθ​(x)∣+∣1−fθ​(x)∣=1

是以

MAE

Symmetric Loss

,是以是噪聲魯棒的。

更嚴謹的論述,可以參考論文Making Risk Minimization Tolerant to Label Noise。

二、

GCE

在明了

MAE

為什麼對噪聲魯棒的機理後,我們再來看看

Cross Entropy Loss

為什麼是不魯棒的。如果還是用一句話來解釋,就是因為

CE

是無界的。考慮

(3)

式所表達的魯棒性條件,各個損失項非負,且其和為定值,那麼各個損失項必然是有界的。但在

CE

中,假設 f θ ( x ) f_\theta(x) fθ​(x)表示樣本 x x x屬于類别

1

的機率,那麼損失值是 − log ⁡ f θ ( x ) -\log f_\theta(x) −logfθ​(x)或 − log ⁡ ( 1 − f θ ( x ) ) -\log(1-f_\theta(x)) −log(1−fθ​(x))。當 f θ ( x ) f_\theta(x) fθ​(x)接近于

或者

1

時,損失值會非常大。這意味着模型會花費更多的功夫在這些樣本上。

這一特性是個雙刃劍。如果我們所有的訓練樣本的标注都是正确的,那麼這一特性使得模型關注于那些容易被錯判的樣本,是以訓練過程會十分高效。但如果訓練樣本存在标簽噪聲,這一特性将使得模型過度關注于誤标的樣本,以緻最終會得到一個過拟合到誤标樣本的模型。

好,那我們直接使用

MAE

不就行了麼。Hmmm…實際上

MAE

實戰效果并不好。因為

MAE

是平等地對待各個樣本的,是以其收斂速度比較慢 (具體的解釋可以參考

GCE

論文)。

那麼一個自然的想法就是能不能在

MAE

的噪聲魯棒性和

CE

的快速收斂性之間做一個折中。這就是

GCE

的基礎想法。具體做法是将

CE

中的 − log ⁡ ( f θ ( x ) ) -\log(f_\theta(x)) −log(fθ​(x))項,替換成一個指數項 1 − f θ q ( x ) q \frac{1-f_\theta^q(x)}{q} q1−fθq​(x)​,其中幂次 q q q為一個超參數,取值範圍為 ( 0 , 1 ] (0,1] (0,1]。當 q = 1 q=1 q=1時,該指數項就蛻變為

MAE

的形式;當 q → 0 q\rightarrow 0 q→0時,由洛必達法則,該指數項将蛻變為

CE

的形式。是以 q q q控制着在

MAE

CE

之間的折中程度。原論文中僅給出了 q q q的經驗值,并未給出具體的選取方法。

為進一步抑制誤标資料的影響,原論文還給出了一種加強版的損失函數,該函數融合了樣本選擇的思想。其核心想法是,當訓練到一定程度後,模型已經擷取了關于資料的正确分類的模式資訊,此時,如果模型在樣本上輸出的機率值比較小,那麼這些樣本很可能是誤标樣本。是以,若将這些樣本從訓練資料中剔除,那麼接下來的訓練可能會更準确。具體做法是:假設在訓練過程中

validation

資料集上準确率最好的模型為 M b e s t \mathcal{M}_{best} Mbest​,則在模型訓練一定次數後,每隔一定數目的

epoch

,将 M b e s t \mathcal{M}_{best} Mbest​預測的機率值小于 k k k的樣本丢棄,而在其餘的資料上進行訓練。這裡又引入了一個新的超參數 k k k。同樣的,原論文隻給出了一些資料集上的經驗值,并未給明确的設定方法。

三、

Results

以下為在

CIFAR-10

資料集上,傳統的交叉熵損失與廣義交叉熵損失的結果對比。所用參數與論文中一緻:

training

/

validation

/

testing

資料集大小分别為

45000

/

5000

/

10000

;模型選用

ResNet34

,注意其中第一個卷積層的參數為

kernel_size

=3,

stride

=1,而用于

ImageNet

資料集時

kernel_size

=7,

stride

=2,這是因為

CIFAR-10

的圖檔尺寸比較小,另外去除了第一個

pooling

層; q = 0.7 , k = 0.5 q=0.7, k=0.5 q=0.7,k=0.5;總共訓練120個

epochs

,初始學習速率為

1E-2

,第

40

、第

80

epoch

學習速率遞降

10

倍;從第

40

epoch

開始,每

10

epoch

篩選一遍資料。每組實驗運作了5次,結果表示為分類準确率的 μ ± σ \mu\pm\sigma μ±σ的形式。複現結果與論文所示大緻相當,甚至還要略優于論文中的結果。其中可以看到,存在

pair wise

類型的标簽噪聲時,準确率的提升不是特别明顯。

對标簽噪聲魯棒的廣義交叉熵損失 (Generalized Cross Entropy)

表一. 不同方法的對比

參考

  • Generalized Cross Entropy:

    GCE

    的論文。
  • pytorch-Truncated-Loss.git: 基于

    pytorch

    實作的

    GCE

繼續閱讀