天天看點

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

retinanet是ICCV2017的Best Student Paper Award(最佳學生論文),何凱明是其作者之一.文章中最為精華的部分就是損失函數 Focal loss的提出.

論文中提出類别失衡是造成two-stage與one-stage模型精确度差異的原因.并提出了Focal loss損失函數,通過調整類間平衡因子與難易度平衡因子.最終使one-stage模型達到了two-stage的精确度.

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

本項目基于pytorch實作focal loss,力圖給你原生pytorch損失函數的使用體驗.

一. 項目簡介

實作過程簡易明了,全中文備注.

阿爾法α 參數用于調整類别權重

伽馬γ 參數用于調整不同檢測難易樣本的權重,讓模型快速關注于困難樣本

完整項目位址:Github,歡迎star, fork.github還有其他視覺相關項目

github連接配接較慢的,可以去Gitee(國内的代碼托管網站),也有完整項目.

項目配有 Jupyter-Notebook 作為focal loss使用例子.

二. 損失函數公式

focal loss 損失函數基于交叉熵損失函數,在交叉熵的基礎上,引入了α與γ兩個不同的調整因子.

2.1 交叉熵損失

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

2.2 帶平衡因子的交叉熵

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

2.3 Focal loss損失

加入 (1-pt)γ 平衡難易樣本的權重,通過γ縮放因子調整,retainnet預設γ=2

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

2.4 帶平衡因子的Focal損失

論文中最終為帶平衡因子的focal loss, 本項目實作的也是這個版本

python torch exp_基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網...

三. Focal loss實作

# -*- coding: utf-8 -*-

# @Author : LG

from torch import nn

import torch

from torch.nn import functional as F

class focal_loss(nn.Module):

def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):

"""

focal_loss損失函數, -α(1-yi)**γ *ce_loss(xi,yi)

步驟詳細的實作了 focal_loss損失函數.

:param alpha: 阿爾法α,類别權重. 當α是清單時,為各類别權重,當α為常數時,類别權重為[α, 1-α, 1-α, ....],常用于 目标檢測算法中抑制背景類 , retainnet中設定為0.25

:param gamma: 伽馬γ,難易樣本調節參數. retainnet中設定為2

:param num_classes: 類别數量

:param size_average: 損失計算方式,預設取均值

"""

super(focal_loss,self).__init__()

self.size_average = size_average

if isinstance(alpha,list):

assert len(alpha)==num_classes # α可以以list方式輸入,size:[num_classes] 用于對不同類别精細地賦予權重

print("Focal_loss alpha = {}, 将對每一類權重進行精細化指派".format(alpha))

self.alpha = torch.Tensor(alpha)

else:

assert alpha<1 #如果α為一個常數,則降低第一類的影響,在目标檢測中為第一類

print(" --- Focal_loss alpha = {} ,将對背景類進行衰減,請在目标檢測任務中使用 --- ".format(alpha))

self.alpha = torch.zeros(num_classes)

self.alpha[0] += alpha

self.alpha[1:] += (1-alpha) # α 最終為 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

self.gamma = gamma

def forward(self, preds, labels):

"""

focal_loss損失計算

:param preds: 預測類别. size:[B,N,C] or [B,C] 分别對應與檢測與分類任務, B 批次, N檢測框數, C類别數

:param labels: 實際類别. size:[B,N] or [B]

:return:

"""

# assert preds.dim()==2 and labels.dim()==1

preds = preds.view(-1,preds.size(-1))

self.alpha = self.alpha.to(preds.device)

preds_softmax = F.softmax(preds, dim=1) # 這裡并沒有直接使用log_softmax, 因為後面會用到softmax的結果(當然你也可以使用log_softmax,然後進行exp操作)

preds_logsoft = torch.log(preds_softmax)

preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # 這部分實作nll_loss ( crossempty = log_softmax + nll )

preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))

self.alpha = self.alpha.gather(0,labels.view(-1))

loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) 為focal loss中 (1-pt)**γ

loss = torch.mul(self.alpha, loss.t())

if self.size_average:

loss = loss.mean()

else:

loss = loss.sum()

return loss

詳細的使用例子請到Github檢視jupyter-notebook.

說明

完整項目位址:Github,歡迎star, fork.

僅限用于交流學習,如需引用,請聯系作者.

原創文章,轉載請注明 :基于Pytorch實作Focal loss.(簡單、易用、全中文注釋、帶例子) - pytorch中文網

原文出處: https://ptorch.com/news/253.html

問題交流群 :168117787