目錄
- 前言
- 0、導入需要的包
- 1、smooth_BCE
- 2、BCEBlurWithLogitsLoss
- 3、FocalLoss
- 4、QFocalLoss
- 5、ComputeLoss類
-
- 5.1、__init__函數
- 5.2、build_targets
- 5.3、__call__函數
- 總結
- Reference
前言
源碼: YOLOv5源碼.
導航: 【YOLOV5-5.x 源碼講解】整體項目檔案導航.
這個檔案是yolov5的損失函數部分。代碼量不多,隻有300多行,但卻是整個項目最難,最精華的部分。在看這個檔案之前建議大家仔細看下下面兩篇關于BCE交叉熵損失函數的内容: 【PyTorch 理論】交叉熵損失函數的了解 和 【PyTorch】兩種常用的交叉熵損失函數BCELoss和BCEWithLogitsLoss 。另外,這個檔案涉及到了損失函數的計算、正負樣本取樣、平滑标簽增強、Focalloss、QFocalloss等操作,都是比較常用的trick,一樣都要弄懂!
0、導入需要的包
import torch
import torch.nn as nn
from utils.metrics import bbox_iou
from utils.torch_utils import is_parallel
1、smooth_BCE
這個函數是一個标簽平滑的政策(trick),是一種在 分類/檢測 問題中,防止過拟合的方法。如果要詳細了解這個政策的原理,可以看看我的另一篇博文: 【trick 1】Label Smoothing(标簽平滑)—— 分類問題中錯誤标注的一種解決方法.
smooth_BCE函數代碼:
def smooth_BCE(eps=0.1):
"""用在ComputeLoss類中
标簽平滑操作 [1, 0] => [0.95, 0.05]
https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
:params eps: 平滑參數
:return positive, negative label smoothing BCE targets 兩個值分别代表正樣本和負樣本的标簽取值
原先的正樣本=1 負樣本=0 改為 正樣本=1.0 - 0.5 * eps 負樣本=0.5 * eps
"""
return 1.0 - 0.5 * eps, 0.5 * eps
通常會用在分類損失當中,如下ComputeLoss類的__init__函數定義:
ComputeLoss類的__call__函數調用:
2、BCEBlurWithLogitsLoss
這個函數是BCE函數的一個替代,是yolov5作者的一個實驗性的函數,可以自己試試效果。
class BCEBlurWithLogitsLoss(nn.Module):
"""用在ComputeLoss類的__init__函數中
BCEwithLogitLoss() with reduced missing label effects.
https://github.com/ultralytics/yolov5/issues/1030
The idea was to reduce the effects of false positive (missing labels) 就是檢測成正樣本了 但是檢測錯了
"""
def __init__(self, alpha=0.05):
super(BCEBlurWithLogitsLoss, self).__init__()
self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
self.alpha = alpha
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
pred = torch.sigmoid(pred) # prob from logits
# dx = [-1, 1] 當pred=1 true=0時(網絡預測說這裡有個obj但是gt說這裡沒有), dx=1 => alpha_factor=0 => loss=0
# 這種就是檢測成正樣本了但是檢測錯了(false positive)或者missing label的情況 這種情況不應該過多的懲罰->loss=0
dx = pred - true # reduce only missing label effects
# 如果采樣絕對值的話 會減輕pred和gt差異過大而造成的影響
# dx = (pred - true).abs() # reduce missing label and false label effects
alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
loss *= alpha_factor
return loss.mean()
使用起來直接在ComputeLoss類的__init__函數中替代傳統的BCE函數即可:
3、FocalLoss
FocalLoss損失函數來自 Kaiming He在2017年發表的一篇論文:Focal Loss for Dense Object Detection. 這篇論文設計的主要思路: 希望那些hard examples對損失的貢獻變大,使網絡更傾向于從這些樣本上學習。防止由于easy examples過多,主導整個損失函數。
優點:
- 解決了one-stage object detection中圖檔中正負樣本(前景和背景)不均衡的問題;
- 降低簡單樣本的權重,使損失函數更關注困難樣本;
函數公式:
更多細節請看我的另一篇部落格: 【trick 4】Focal Loss —— 解決one-stage目标檢測中正負樣本不均衡的問題.
FocalLoss函數代碼:
class FocalLoss(nn.Module):
"""用在代替原本的BCEcls(分類損失)和BCEobj(置信度損失)
Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
論文: https://arxiv.org/abs/1708.02002
https://blog.csdn.net/qq_38253797/article/details/116292496
TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
"""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super(FocalLoss, self).__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()=Sigmoid+BCELoss 定義為多分類交叉熵損失函數
self.gamma = gamma # 參數gamma 用于削弱簡單樣本對loss的貢獻程度
self.alpha = alpha # 參數alpha 用于平衡正負樣本個數不均衡的問題
# self.reduction: 控制FocalLoss損失輸出模式 sum/mean/none 預設是Mean
self.reduction = loss_fcn.reduction
# focalloss中的BCE函數的reduction='None' BCE不使用Sum或者Mean
self.loss_fcn.reduction = 'none' # 需要将Focal loss應用于每一個樣本之中
def forward(self, pred, true):
loss = self.loss_fcn(pred, true) # 正常BCE的loss: loss = -log(p_t)
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
pred_prob = torch.sigmoid(pred) # prob from logits
# true=1 p_t=pred_prob true=0 p_t=1-pred_prob
p_t = true * pred_prob + (1 - true) * (1 - pred_prob) # p_t
# true=1 alpha_factor=self.alpha true=0 alpha_factor=1-self.alpha
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) # alpha_t
modulating_factor = (1.0 - p_t) ** self.gamma # 這裡代表Focal loss中的指數項
# 傳回最終的loss=BCE * 兩個參數 (看看公式就行了 和公式一模一樣)
loss *= alpha_factor * modulating_factor
# 最後選擇focalloss傳回的類型 預設是mean
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
這個函數用在代替原本的BCEcls和BCEobj:
4、QFocalLoss
QFocalLoss損失函數來自20年的一篇文章: Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection.
這篇文章我暫時還沒有看完,因為涉及太多的anchor free的内容,後面學完anchor free的一些經典論文再回來重寫。
如果對這篇論文感興趣可以看看大神部落格: 大白話 Generalized Focal Loss.
公式:
QFocalLoss函數代碼:
class QFocalLoss(nn.Module):
"""用來代替FocalLoss
QFocalLoss 來自General Focal Loss論文: https://arxiv.org/abs/2006.04388
Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
"""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super(QFocalLoss, self).__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
pred_prob = torch.sigmoid(pred) # prob from logits
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
# 和FocalLoss相比隻變了這裡
modulating_factor = torch.abs(true - pred_prob) ** self.gamma
loss *= alpha_factor * modulating_factor
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
用法就是直接在ComputeLoss時代替FocalLoss即可:
5、ComputeLoss類
5.1、__init__函數
這個函數就是定義一些後面要用到的變量,參數,函數等。
def __init__(self, model, autobalance=False):
super(ComputeLoss, self).__init__()
self.sort_obj_iou = False # 後面篩選置信度損失正樣本的時候是否先對iou排序
device = next(model.parameters()).device # get model device
h = model.hyp # hyperparameters
# Define criteria 定義分類損失和置信度損失
# BCEcls = BCEBlurWithLogitsLoss()
# BCEobj = BCEBlurWithLogitsLoss()
# h['cls_pw']=1 BCEWithLogitsLoss預設的正樣本權重也是1
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# 标簽平滑 eps=0代表不做标簽平滑-> cp=1 cn=0 eps!=0代表做标簽平滑 cp代表positive的标簽值 cn代表negative的标簽值
self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
# Focal loss g=0 代表不用focal loss
g = h['fl_gamma'] # focal loss gamma
if g > 0:
# g>0 将分類損失和置信度損失(BCE)都換成focalloss損失函數
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
# BCEcls, BCEobj = QFocalLoss(BCEcls, g), QFocalLoss(BCEobj, g)
# det: 傳回的是模型的檢測頭 Detector 3個 分别對應産生三個輸出feature map
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
# balance用來設定三個feature map對應輸出的置信度損失系數(平衡三個feature map的置信度損失)
# 從左到右分别對應大feature map(檢測小目标)到小feature map(檢測大目标)
# 思路: It seems that larger output layers may overfit earlier, so those numbers may need a bit of adjustment
# 一般來說,檢測小物體的難度大一點,是以會增加大特征圖的損失系數,讓模型更加側重小物體的檢測
# 如果det.nl=3就傳回[4.0, 1.0, 0.4]否則傳回[4.0, 1.0, 0.25, 0.06, .02]
# self.balance = {3: [4.0, 1.0, 0.4], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
# 三個預測頭的下采樣率det.stride: [8, 16, 32] .index(16): 求出下采樣率stride=16的索引
# 這個參數會用來自動計算更新3個feature map的置信度損失系數self.balance
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
# self.BCEcls: 類别損失函數 self.BCEobj: 置信度損失函數 self.hyp: 超參數
# self.gr: 計算真實框的置信度标準的iou ratio self.autobalance: 是否自動更新各feature map的置信度損失平衡系數 預設False
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
# na: number of anchors 每個grid_cell的anchor數量 = 3
# nc: number of classes 資料集的總類别 = 80
# nl: number of detection layers Detect的個數 = 3
# anchors: [3, 3, 2] 3個feature map 每個feature map上有3個anchor(w,h) 這裡的anchor尺寸是相對feature map的
for k in 'na', 'nc', 'nl', 'anchors':
# setattr: 給對象self的屬性k指派為getattr(det, k)
# getattr: 傳回det對象的k屬性
# 是以這句話的意思: 講det的k屬性指派給self.k屬性 其中k in 'na', 'nc', 'nl', 'anchors'
setattr(self, k, getattr(det, k))
5.2、build_targets
這個函數是用來為每個feature map上的三個anchor篩選相應的正樣本的(ground true)。篩選條件是比較GT和anchor的寬比和高比,大于一定的門檻值就是負樣本,反之正樣本。篩選到的各個feature的各個anchor的正樣本資訊(image_index, anchor_index, gridy, gridx),傳入__call__函數,通過這個資訊去篩選pred每個grid預測得到的資訊,保留對應grid_cell上的正樣本。通過build_targets篩選的GT中的正樣本和pred篩選出的對應位置的預測樣本進行計算損失。
def build_targets(self, p, targets):
"""
Build targets for compute_loss()
:params p: 預測框 由模型建構中的三個檢測頭Detector傳回的三個yolo層的輸出
tensor格式 list清單 存放三個tensor 對應的是三個yolo層的輸出
如: [4, 3, 112, 112, 85]、[4, 3, 56, 56, 85]、[4, 3, 28, 28, 85]
[bs, anchor_num, grid_h, grid_w, xywh+class+classes]
可以看出來這裡的預測值p是三個yolo層每個grid_cell(每個grid_cell有三個預測值)的預測值,後面肯定要進行正樣本篩選
:params targets: 資料增強後的真實框 [63, 6] [num_target, image_index+class+xywh] xywh為歸一化後的框
:return tcls: 表示這個target所屬的class index
tbox: xywh 其中xy為這個target對目前grid_cell左上角的偏移量
indices: b: 表示這個target屬于的image index
a: 表示這個target使用的anchor index
gj: 經過篩選後确定某個target在某個網格中進行預測(計算損失) gj表示這個網格的左上角y坐标
gi: 表示這個網格的左上角x坐标
anch: 表示這個target所使用anchor的尺度(相對于這個feature map) 注意可能一個target會使用大小不同anchor進行計算
"""
na, nt = self.na, targets.shape[0] # number of anchors 3, targets 63
tcls, tbox, indices, anch = [], [], [], [] # 初始化tcls tbox indices anch
# gain是為了後面将targets=[na,nt,7]中的歸一化了的xywh映射到相對feature map尺度上
# 7: image_index+class+xywh+anchor_index
gain = torch.ones(7, device=targets.device)
# 需要在3個anchor上都進行訓練 是以将标簽指派na=3個 ai代表3個anchor上在所有的target對應的anchor索引 就是用來标記下目前這個target屬于哪個anchor
# [1, 3] -> [3, 1] -> [3, 63]=[na, nt] 三行 第一行63個0 第二行63個1 第三行63個2
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
# [63, 6] [3, 63] -> [3, 63, 6] [3, 63, 1] -> [3, 63, 7] 7: [image_index+class+xywh+anchor_index]
# 對每一個feature map: 這一步是将target複制三份 對應一個feature map的三個anchor
# 先假設所有的target對三個anchor都是正樣本(複制三份) 再進行篩選 并将ai加進去标記目前是哪個anchor的target
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
# 這兩個變量是用來擴充正樣本的 因為預測框預測到target有可能不止目前的格子預測到了
# 可能周圍的格子也預測到了高品質的樣本 我們也要把這部分的預測資訊加入正樣本中
g = 0.5 # bias 中心偏移 用來衡量target中心點離哪個格子更近
# 以自身 + 周圍左上右下4個網格 = 5個網格 用來計算offsets
off = torch.tensor([[0, 0],
[1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
], device=targets.device).float() * g # offsets
# 周遊三個feature 篩選每個feature map(包含batch張圖檔)的每個anchor的正樣本
for i in range(self.nl): # self.nl: number of detection layers Detect的個數 = 3
# anchors: 目前feature map對應的三個anchor尺寸(相對feature map) [3, 2]
anchors = self.anchors[i]
# gain: 儲存每個輸出feature map的寬高 -> gain[2:6]=gain[whwh]
# [1, 1, 1, 1, 1, 1, 1] -> [1, 1, 112, 112, 112,112, 1]=image_index+class+xywh+anchor_index
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
# t = [3, 63, 7] 将target中的xywh的歸一化尺度放縮到相對目前feature map的坐标尺度
# [3, 63, image_index+class+xywh+anchor_index]
t = targets * gain
if nt: # 開始比對 Matches
# t=[na, nt, 7] t[:, :, 4:6]=[na, nt, 2]=[3, 63, 2]
# anchors[:, None]=[na, 1, 2]
# r=[na, nt, 2]=[3, 63, 2]
# 目前feature map的3個anchor的所有正樣本(沒删除前是所有的targets)與三個anchor的寬高比(w/w h/h)
r = t[:, :, 4:6] / anchors[:, None] # wh ratio (w/w h/h)
# 篩選條件 GT與anchor的寬比或高比超過一定的門檻值 就當作負樣本
# torch.max(r, 1. / r)=[3, 63, 2] 篩選出寬比w1/w2 w2/w1 高比h1/h2 h2/h1中最大的那個
# .max(2)傳回寬比 高比兩者中較大的一個值和它的索引 [0]傳回較大的一個值
# j: [3, 63] False: 目前gt是目前anchor的負樣本 True: 目前gt是目前anchor的正樣本
j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare
# yolov3 v4的篩選方法: wh_iou GT與anchor的wh_iou超過一定的門檻值就是正樣本
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
# 根據篩選條件j, 過濾負樣本, 得到目前feature map上三個anchor的所有正樣本t(batch_size張圖檔)
# t: [3, 63, 7] -> [126, 7] [num_Positive_sample, image_index+class+xywh+anchor_index]
t = t[j] # filter
# Offsets 篩選目前格子周圍格子 找到2個離target中心最近的兩個格子 可能周圍的格子也預測到了高品質的樣本 我們也要把這部分的預測資訊加入正樣本中
# 除了target所在的目前格子外, 還有2個格子對目标進行檢測(計算損失) 也就是說一個目标需要3個格子去預測(計算損失)
# 首先目前格子是其中1個 再從目前格子的上下左右四個格子中選擇2個 用這三個格子去預測這個目标(計算損失)
# feature map上的原點在左上角 向右為x軸正坐标 向下為y軸正坐标
gxy = t[:, 2:4] # grid xy 取target中心的坐标xy(相對feature map左上角的坐标)
gxi = gain[[2, 3]] - gxy # inverse 得到target中心點相對于右下角的坐标 gain[[2, 3]]為目前feature map的wh
# 篩選中心坐标 距離目前grid_cell的左、上方偏移小于g=0.5 且 中心坐标必須大于1(坐标不能在邊上 此時就沒有4個格子了)
# j: [126] bool 如果是True表示目前target中心點所在的格子的左邊格子也對該target進行回歸(後續進行計算損失)
# k: [126] bool 如果是True表示目前target中心點所在的格子的上邊格子也對該target進行回歸(後續進行計算損失)
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
# 篩選中心坐标 距離目前grid_cell的右、下方偏移小于g=0.5 且 中心坐标必須大于1(坐标不能在邊上 此時就沒有4個格子了)
# l: [126] bool 如果是True表示目前target中心點所在的格子的右邊格子也對該target進行回歸(後續進行計算損失)
# m: [126] bool 如果是True表示目前target中心點所在的格子的下邊格子也對該target進行回歸(後續進行計算損失)
l, m = ((gxi % 1. < g) & (gxi > 1.)).T
# j: [5, 126] torch.ones_like(j): 目前格子, 不需要篩選全是True j, k, l, m: 左上右下格子的篩選結果
j = torch.stack((torch.ones_like(j), j, k, l, m))
# 得到篩選後所有格子的正樣本 格子數<=3*126 都不在邊上等号成立
# t: [126, 7] -> 複制5份target[5, 126, 7] 分别對應目前格子和左上右下格子5個格子
# j: [5, 126] + t: [5, 126, 7] => t: [378, 7] 理論上是小于等于3倍的126 當且僅當沒有邊界的格子等号成立
t = t.repeat((5, 1, 1))[j]
# torch.zeros_like(gxy)[None]: [1, 126, 2] off[:, None]: [5, 1, 2] => [5, 126, 2]
# j篩選後: [378, 2] 得到所有篩選後的網格的中心相對于這個要預測的真實框所在網格邊界(左右上下邊框)的偏移量
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
else:
t = targets[0]
offsets = 0
# Define
b, c = t[:, :2].long().T # image_index, class
gxy = t[:, 2:4] # target的xy
gwh = t[:, 4:6] # target的wh
gij = (gxy - offsets).long() # 預測真實框的網格所在的左上角坐标(有左上右下的網格)
gi, gj = gij.T # grid xy indices
# Append
a = t[:, 6].long() # anchor index
# b: image index a: anchor index gj: 網格的左上角y坐标 gi: 網格的左上角x坐标
indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))
# tbix: xywh 其中xy為這個target對目前grid_cell左上角的偏移量
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
anch.append(anchors[a]) # 對應的所有anchors
tcls.append(c) # class
return tcls, tbox, indices, anch
5.3、__call__函數
這個函數相當于forward函數,在這個函數中進行損失函數的前向傳播。
def __call__(self, p, targets): # predictions, targets, model
"""
:params p: 預測框 由模型建構中的三個檢測頭Detector傳回的三個yolo層的輸出
tensor格式 list清單 存放三個tensor 對應的是三個yolo層的輸出
如: [4, 3, 112, 112, 85]、[4, 3, 56, 56, 85]、[4, 3, 28, 28, 85]
[bs, anchor_num, grid_h, grid_w, xywh+class+classes]
可以看出來這裡的預測值p是三個yolo層每個grid_cell(每個grid_cell有三個預測值)的預測值,後面肯定要進行正樣本篩選
:params targets: 資料增強後的真實框 [63, 6] [num_object, batch_index+class+xywh]
:params loss * bs: 整個batch的總損失 進行反向傳播
:params torch.cat((lbox, lobj, lcls, loss)).detach(): 回歸損失、置信度損失、分類損失和總損失 這個參數隻用來可視化參數或儲存資訊
"""
device = targets.device # 确定運作的裝置
# 初始化lcls, lbox, lobj三種損失值 tensor([0.])
lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
# 每一個都是append的 有feature map個 每個都是目前這個feature map中3個anchor篩選出的所有的target(3個grid_cell進行預測)
# tcls: 表示這個target所屬的class index
# tbox: xywh 其中xy為這個target對目前grid_cell左上角的偏移量
# indices: b: 表示這個target屬于的image index
# a: 表示這個target使用的anchor index
# gj: 經過篩選後确定某個target在某個網格中進行預測(計算損失) gj表示這個網格的左上角y坐标
# gi: 表示這個網格的左上角x坐标
# anch: 表示這個target所使用anchor的尺度(相對于這個feature map) 注意可能一個target會使用大小不同anchor進行計算
tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
# 依次周遊三個feature map的預測輸出pi
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image_index, anchor_index, gridy, gridx
tobj = torch.zeros_like(pi[..., 0], device=device) # 初始化target置信度(先全是負樣本 後面再篩選正樣本指派)
n = b.shape[0] # number of targets
if n:
# 精确得到第b張圖檔的第a個feature map的grid_cell(gi, gj)對應的預測值
# 用這個預測值與我們篩選的這個grid_cell的真實框進行預測(計算損失)
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
# Regression loss 隻計算所有正樣本的回歸損失
pxy = ps[:, :2].sigmoid() * 2. - 0.5 # 一個歸一化操作 和論文裡不同
# https://github.com/ultralytics/yolov3/issues/168
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] # 和論文裡不同 這裡是作者自己提出的公式
pbox = torch.cat((pxy, pwh), 1) # predicted box
# 這裡的tbox[i]中的xy是這個target對目前grid_cell左上角的偏移量[0,1] 而pbox.T是一個歸一化的值
# 就是要用這種方式訓練 傳回loss 修改梯度 讓pbox越來越接近tbox(偏移量)
iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
lbox += (1.0 - iou).mean() # iou loss
# Objectness loss stpe1
# iou.detach() 不會更新iou梯度 iou并不是反向傳播的參數 是以不需要反向傳播梯度資訊
score_iou = iou.detach().clamp(0).type(tobj.dtype) # .clamp(0)必須大于等于0
if self.sort_obj_iou: # 可以看下官方的解釋 我也不是很清楚為什麼這裡要對iou排序???
# https://github.com/ultralytics/yolov5/issues/3605
# There maybe several GTs match the same anchor when calculate ComputeLoss in the scene with dense targets
sort_id = torch.argsort(score_iou)
b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
# 預測資訊有置信度 但是真實框資訊是沒有置信度的 是以需要我們認為的給一個标準置信度
# self.gr是iou ratio [0, 1] self.gr越大置信度越接近iou self.gr越小越接近1(人為加大訓練難度)
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou # iou ratio
# tobj[b, a, gj, gi] = 1 # 如果發現預測的score不高 資料集目标太小太擁擠 困難樣本過多 可以試試這個
# Classification loss 隻計算所有正樣本的分類損失
if self.nc > 1: # cls loss (only if multiple classes)
# targets 原本負樣本是0 這裡使用smooth label 就是cn
t = torch.full_like(ps[:, 5:], self.cn, device=device)
t[range(n), tcls[i]] = self.cp # 篩選到的正樣本對應位置值是cp
lcls += self.BCEcls(ps[:, 5:], t) # BCE
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
# Objectness loss stpe2 置信度損失是用所有樣本(正樣本 + 負樣本)一起計算損失的
obji = self.BCEobj(pi[..., 4], tobj)
# 每個feature map的置信度損失權重不同 要乘以相應的權重系數self.balance[i]
# 一般來說,檢測小物體的難度大一點,是以會增加大特征圖的損失系數,讓模型更加側重小物體的檢測
lobj += obji * self.balance[i] # obj loss
if self.autobalance:
# 自動更新各個feature map的置信度損失系數
self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
if self.autobalance:
self.balance = [x / self.balance[self.ssi] for x in self.balance]
# 根據超參中的損失權重參數 對各個損失進行平衡 防止總損失被某個損失所左右
lbox *= self.hyp['box']
lobj *= self.hyp['obj']
lcls *= self.hyp['cls']
bs = tobj.shape[0] # batch size
loss = lbox + lobj + lcls # 平均每張圖檔的總損失
# loss * bs: 整個batch的總損失
# .detach() 利用損失值進行反向傳播 利用梯度資訊更新的是損失函數的參數 而對于損失這個值是不需要梯度反向傳播的
return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
train.py初始化損失函數類:
調用執行損失函數,計算損失:
這個代碼當中其實我也有點沒有讀懂:就是self.sort_obj_iou這個參數是幹嘛的,在篩選到的正樣本指派标準(target)置信度的時候,為什麼要先排序呢?作者是這樣說的:
There maybe several GTs match the same anchor when calculate ComputeLoss in the scene with dense targets
當在密集目标場景中,計算損失時可能有多個GT比對配相同的anchor?
what?這是什麼意思?如果有知道的朋友可以直接回複在下面。
總結
這個腳本最最最重要的就是ComputeLoss類了。看了很久,本來打算寫細一點的,但是看完代碼發現自己把想說的都已經寫在代碼的注釋當中了。代碼其實還是挺難的,尤其build_target各種花裡胡哨的矩陣操作較多,pytorch不熟的人會看的比較痛苦,但是如果你堅持看下來我的注釋再加上自己的debug的話,應該是能讀懂的。最後,一定要細讀ComputeLoss!!!!
另外這個腳本設計到的幾個其他的函數可以在這裡檢視到:【YOLOV5-5.x 源碼講解】整體項目檔案導航.
–2021.08.09 19:55
Reference
連結1: 部落格1.
連結2: 部落格2.
連結3: 部落格3.
連結4: 部落格4.
連結5: 部落格5.