retinanet是ICCV2017的Best Student Paper Award(最佳學生論文),何凱明是其作者之一.文章中最為精華的部分就是損失函數 Focal loss的提出.
論文中提出類别失衡是造成two-stage與one-stage模型精确度差異的原因.并提出了Focal loss損失函數,通過調整類間平衡因子與難易度平衡因子.最終使one-stage模型達到了two-stage的精确度.
本項目基于pytorch實作focal loss,力圖給你原生pytorch損失函數的使用體驗.
一. 項目簡介
實作過程簡易明了,全中文備注.
阿爾法α 參數用于調整類别權重
伽馬γ 參數用于調整不同檢測難易樣本的權重,讓模型快速關注于困難樣本
完整項目位址:Github,歡迎star, fork.github還有其他視覺相關項目
github連接配接較慢的,可以去Gitee(國内的代碼托管網站),也有完整項目.
項目配有 Jupyter-Notebook 作為focal loss使用例子.
二. 損失函數公式
focal loss 損失函數基于交叉熵損失函數,在交叉熵的基礎上,引入了α與γ兩個不同的調整因子.
2.1 交叉熵損失
2.2 帶平衡因子的交叉熵
2.3 Focal loss損失
加入 (1-pt)γ 平衡難易樣本的權重,通過γ縮放因子調整,retainnet預設γ=2
2.4 帶平衡因子的Focal損失
論文中最終為帶平衡因子的focal loss, 本項目實作的也是這個版本
三. Focal loss實作
# -*- coding: utf-8 -*-
# @Author : LG
from torch import nn
import torch
from torch.nn import functional as F
class focal_loss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):
"""
focal_loss損失函數, -α(1-yi)**γ *ce_loss(xi,yi)
步驟詳細的實作了 focal_loss損失函數.
:param alpha: 阿爾法α,類别權重. 當α是清單時,為各類别權重,當α為常數時,類别權重為[α, 1-α, 1-α, ....],常用于 目标檢測算法中抑制背景類 , retainnet中設定為0.25
:param gamma: 伽馬γ,難易樣本調節參數. retainnet中設定為2
:param num_classes: 類别數量
:param size_average: 損失計算方式,預設取均值
"""
super(focal_loss,self).__init__()
self.size_average = size_average
if isinstance(alpha,list):
assert len(alpha)==num_classes # α可以以list方式輸入,size:[num_classes] 用于對不同類别精細地賦予權重
print("Focal_loss alpha = {}, 将對每一類權重進行精細化指派".format(alpha))
self.alpha = torch.Tensor(alpha)
else:
assert alpha<1 #如果α為一個常數,則降低第一類的影響,在目标檢測中為第一類
print(" --- Focal_loss alpha = {} ,将對背景類進行衰減,請在目标檢測任務中使用 --- ".format(alpha))
self.alpha = torch.zeros(num_classes)
self.alpha[0] += alpha
self.alpha[1:] += (1-alpha) # α 最終為 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
self.gamma = gamma
def forward(self, preds, labels):
"""
focal_loss損失計算
:param preds: 預測類别. size:[B,N,C] or [B,C] 分别對應與檢測與分類任務, B 批次, N檢測框數, C類别數
:param labels: 實際類别. size:[B,N] or [B]
:return:
"""
# assert preds.dim()==2 and labels.dim()==1
preds = preds.view(-1,preds.size(-1))
self.alpha = self.alpha.to(preds.device)
preds_softmax = F.softmax(preds, dim=1) # 這裡并沒有直接使用log_softmax, 因為後面會用到softmax的結果(當然你也可以使用log_softmax,然後進行exp操作)
preds_logsoft = torch.log(preds_softmax)
preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # 這部分實作nll_loss ( crossempty = log_softmax + nll )
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
self.alpha = self.alpha.gather(0,labels.view(-1))
loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) 為focal loss中 (1-pt)**γ
loss = torch.mul(self.alpha, loss.t())
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss
詳細的使用例子請到Github檢視jupyter-notebook.
說明
完整項目位址:Github,歡迎star, fork.
僅限用于交流學習,如需引用,請聯系作者.
原創文章,轉載請注明 :基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網
原文出處: https://ptorch.com/news/253.html
問題交流群 :168117787