天天看點

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

公衆号關注 “ ML_NLP ” 設為 “ 星标 ”,重磅幹貨,第一時間送達!

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

原文連結:https://amaarora.github.io/2020/06/29/FocalLoss.html

原文作者:Aman Arora

Focal loss 是一個在目标檢測領域常用的損失函數。最近看到一篇部落格,趁這個機會,學習和翻譯一下,與大家一起交流和分享。

在這篇部落格中,我們将會了解什麼是Focal loss,并且什麼時候應該使用它。同時我們會深入了解下其背後的數學原理與pytorch 實作.

  1. 什麼是Focal loss,它是用來幹嘛的?
  2. 為什麼Focal loss有效,其中的原理是什麼?
  3. Alpha and Gamma?
  4. 怎麼在代碼中實作它?
  5. Credits

什麼是Focal loss,它是用來幹嘛的?

在了解什麼是Focal Loss以及有關它的所有詳細資訊之前,我們首先快速直覺地了解Focal Loss的實際作用。Focal loss最早是 He et al 在論文 Focal Loss for Dense Object Detection 中實作的。

在這篇文章發表之前,對象檢測實際上一直被認為是一個很難解決的問題,尤其是很難檢測圖像中的小尺寸對象。請參見下面的示例,與其他圖檔相比,機車的尺寸相對較小, 是以該模型無法很好地預測機車的存在。

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

fig-1  bce  在上圖中,模型無法預測機車的原因是因為該模型是使用了Binary Cross Entropy loss,這種訓練目标要求模型 對自己的預測真的很有信心。而Focal Loss所做的是,它使模型可以更"放松"地預測事物,而無需80-100%确信此對象是“某物”。簡而言之,它給模型提供了更多的自由,可以在進行預測時承擔一些風險。這在處理高度不平衡的資料集時尤其重要,因為在某些情況下(例如癌症檢測),即使預測結果為假陽性也可接受,确實需要模型承擔風險并盡量進行預測。

是以,Focal loss在樣本不平衡的情況下特别有用。特别是在“對象檢測”的情況下,大多數像素通常都是背景,圖像中隻有很少數的像素具有我們感興趣的對象。

這是經過Focal loss訓練後同一模型對同樣圖檔的預測。

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

fig-2  focal loss prediction 分析這兩者并觀察其中的差異,可能是個很好的主意。這将有助于我們對于Focal loss進行直覺的了解。

那麼為什麼Focal loss有效,其中的原理是什麼?

既然我們已經看到了“Focal loss”可以做什麼的一個例子,接下來讓我們嘗試去了解為什麼它可以起作用。下面是了解Focal loss的最重要的一張圖:

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

fig-3 FL vs CE

在上圖中,“藍”線代表交叉熵損失。X軸即“預測為真實标簽的機率”(為簡單起見,将其稱為pt)。舉例來說,假設模型預測某物是自行車的機率為0.6,而它确實是自行車, 在這種情況下的pt為0.6。而如果同樣的情況下對象不是自行車。則pt為0.4,因為此處的真實标簽是0,而對象不是自行車的機率為0.4(1-0.6)。

Y軸是給定pt後Focal loss和CE的loss的值。

從圖像中可以看出,當模型預測為真實标簽的機率為0.6左右時,交叉熵損失仍在0.5左右。是以,為了在訓練過程中減少損失,我們的模型将必須以更高的機率來預測到真實标簽。換句話說,交叉熵損失要求模型對自己的預測非常有信心。但這也同樣會給模型表現帶來負面影響。

深度學習模型會變得過度自信, 是以模型的泛化能力會下降.

這個模型過度自信的問題同樣在另一篇出色的論文 Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration 被強調過。

另外,作為重新思考計算機視覺的初始架構的一部分而引入的标簽平滑是解決該問題的另一種方法。

Focal loss與上述解決方案不同。從比較Focal loss與CrossEntropy的圖表可以看出,當使用γ> 1的Focal Loss可以減少“分類得好的樣本”或者說“模型預測正确機率大”的樣本的訓練損失,而對于“難以分類的示例”,比如預測機率小于0.5的,則不會減小太多損失。是以,在資料類别不平衡的情況下,會讓模型的注意力放在稀少的類别上,因為這些類别的樣本見過的少,比較難分。

Focal loss的數學定義如下:

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

Alpha and Gamma?

那麼在Focal loss 中的

alpha

gamma

是什麼呢?我們會将

alpha

記為

α

gamma

記為

γ

我們可以這樣來了解fig3

γ

 控制曲線的形狀. 

γ

的值越大, 好分類樣本的loss就越小, 我們就可以把模型的注意力投向那些難分類的樣本. 一個大的 

γ

 讓獲得小loss的樣本範圍擴大了.

同時,當

γ=0

時,這個表達式就退化成了Cross Entropy Loss,衆所周知地

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

定義“ pt”如下,按照其真實意義:

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

将上述兩個式子合并,Cross Entropy Loss其實就變成了下式。

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

現在我們知道了γ的作用,那麼α是幹什麼的呢?

除了Focal loss以外,另一種處理類别不均衡的方法是引入權重。給稀有類别以高權重,給統治地位的類或普通類以小權重。這些權重我們也可以用α表示。

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

alpha-CE

加上了這些權重确實幫助處理了類别的 不均衡,focal loss的論文報道:

類間不均衡較大會導緻,交叉熵損失在訓練的時候收到影響。易分類的樣本的分類錯誤的損失占了整體損失的絕大部分,并主導梯度。盡管α平衡了正面/負面例子的重要性,但它并未區分簡單/困難例子。

作者想要解釋的是:

盡管我們加上了α, 它也确實對不同的類别加上了不同的權重, 進而平衡了正負樣本的重要性 ,但在大多數例子中,隻做這個是不夠的. 我們同樣要做的是減少容易分類的樣本分類錯誤的損失。因為不然的話,這些容易分類的樣本就主導了我們的訓練.

那麼Focal loss 怎麼處理的呢,它相對交叉熵加上了一個乘性的因子

(1 − pt)**γ

,進而像我們上面所講的,降低了易分類樣本區間内産生的loss。

再看下Focal loss的表達,是不是清晰了許多。

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

怎麼在代碼中實作呢?

這是Focal loss在Pytorch中的實作。

class WeightedFocalLoss(nn.Module):    "Non weighted version of Focal Loss"    def __init__(self, alpha=.25, gamma=2):        super(WeightedFocalLoss, self).__init__()        self.alpha = torch.tensor([alpha, 1-alpha]).cuda()        self.gamma = gamma    def forward(self, inputs, targets):        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')        targets = targets.type(torch.long)        at = self.alpha.gather(0, targets.data.view(-1))        pt = torch.exp(-BCE_loss)        F_loss = at*(1-pt)**self.gamma * BCE_loss        return F_loss.mean()
           

如果你了解了

alpha

gamma

的意思,那麼這個實作應該都能了解。同時,像文章中提到的一樣,這裡是對BCE進行因子的相乘。

Credits

貼上作者的 twitter ,當然如果大家有什麼問題讨論,也可以在公衆号留言。

  • fig-1

     and 

    fig-2

     are from the Fastai 2018 course Lecture-09!

倉庫位址共享:

在機器學習算法與自然語言處理公衆号背景回複“代碼”,

即可擷取195篇NAACL+295篇ACL2019有代碼開源的論文。開源位址如下:https://github.com/yizhen20133868/NLP-Conferences-Code

重磅!憶臻自然語言處理-Pytorch交流群已正式成立!

群内有大量資源,歡迎大家進群學習!

注意:請大家添加時修改備注為 [學校/公司 + 姓名 + 方向]

例如 —— 哈工大+張三+對話系統。

号主,微商請自覺繞道。謝謝!

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)
kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

推薦閱讀:

Longformer:超越RoBERTa,為長文檔而生的預訓練模型

一文直覺了解KL散度

機器學習必讀TOP 100論文清單:高引用、分類全、覆寫面廣丨GitHub 21.4k星

kl散度的了解_10分鐘了解Focal loss數學原理與Pytorch代碼(翻譯)

繼續閱讀