天天看点

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

论文链接:SSD: Single Shot MultiBox Detector

文中所用代码(仅multibox_loss.py)链接:multibox_loss.py

Loss in SSD

数学公式部分

来自SSD论文第五页training objective段落。指示器

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

指: 第

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

- th 默认框(default box)对应

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

类物体的第

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

-th 目标框(Ground truth box)的匹配度,匹配度取值区间为0~1。根据这个匹配策略,可得所有的默认框对应

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

类物体的第

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

-th 目标框(Ground truth box)的匹配度的总和大于等于1,即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

。整体的目标损失函数(overall objective loss function)是定位损失(localization loss,简称loc)和置信损失(confidence loss,简称conf)的加权求和。

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

是能匹配的默认框数量,如果

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

,Loss为0。loc为预测框

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

(predicted box)和真实框

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

(ground truth box)之间的Smooth L1 loss。

对于loc, 根据公式一步步来

  1. 首先,公式(2)中第一行:先对每个可匹配的默认框求这个框的置信度器
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    的乘积,loc即为所有能匹配框的该值的总和。其中
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    分别如figure 1 中红绿箭头
  2. 拆开看
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    为该框的中心点的x和y值,
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    为框宽,
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    为框长。公式(2)中第二行说明:
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,公式(2)中第三行说明:
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,如果
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,则
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ;如果
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,则
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    。附:论文中未说明,但
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    应也为对默认框的比例值。
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

Figure 1. 预测框和真实框的图例,图中以center point举例说明,w和h信息未包含

对于conf,

先上公式(3),

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

拆分成两部分(Pos和Neg)看,即,在该次iteration计算里,一共有N(即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

时的i的总数为N个)个可匹配默认框:得出预测框

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

中为

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

类(category)的相对置信度

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

,再乘以框匹配度,在计算loss时,希望

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

,

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

,即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

,即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

。如果第

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

-th个默认框属于背景(Negative),则Neg部分:

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

(指背景类),

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

, 即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

, 即

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

。公式三要使

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

.

conf公式也拆分成两部分来看(conf loss 是指所有类别的loss):默认框为4个(如figure 1), 举例数据集有三类(0:背景,1:交通标志,2:信号灯)。

  1. 首先是
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    (即
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    时),2个默认框(Default BBox0,1,2)为可匹配默认框,因为
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ,根据公式
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    = [0.8log(0.9) + 0.8log(0.01)] + [0.6log(0.1)+0.6log(0.7)] = [(-0.037)+(-1.6)]+[(-0.6)+(-0.093)]= -2.33 这里(-1.6,-0.6)这两个值太大,即错误的框[(框0,类别2),(框1,类别1) ]loss大。正确的框/之后的预测框[(框0,类别1),(框1,类别2)]的loss小。
  2. (未完,代码hard negative mining部分说明)然后i 是
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    (即
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    ),从figure 2中可看到Default BBox 3未能和任何Groundtruth BBox重合,即IOU太小(IOU<0.5),归类为未能
focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

代码解读部分

一些相关参数(parameter):

  • batch_size = 32
def 
           

predictions是一个元组(tuple),其中包含三个值

  1. loc_data:预测框位置点矩阵 size: (batch_size,num_priors,4) -> (32,8732,4)
  2. conf_data:预测框置信度矩阵 size:(batch_size,num_priors,num_classes) -> (32,8732,3)
  3. priors:先验框矩阵 size:(num_priors,4) -> (8732,4)
# match priors (default boxes) and ground truth boxes
    
           

举例idx的循环:

idx: 0

targets: tensor([[ 0.2781, 0.8527, 0.3834, 1.0000, 8.0000],

[ 0.5586, 0.8045, 0.6864, 1.0000, 17.0000],

[ 0.3089, 0.7861, 0.4178, 0.9122, 17.0000],

[ 0.4544, 0.7564, 0.5278, 0.8286, 19.0000],

[ 0.2012, 0.7861, 0.2343, 0.8768, 19.0000]])

truths: torch.Size([5, 4])

labels: torch.Size([5])

idx: 1

targets: tensor([[ 0.8637, 0.7126, 0.9310, 0.9073, 14.0000]])

truths: torch.Size([1, 4])

labels: torch.Size([1])

...

match() function:

  • 返还新的loc_t (32, 8732, 4)和conf_t (32, 8732)值。loc_t: endcoded location targets。表示:32张图中,每张图里有8732个默认框:每个框对应该图中目标框的定位度loc_t( 4)和置信度conf_t(1)。每张图得出的维度分别为loc_t (8732, 4)和conf_t(8732, 1),一起处理32张图即维度分别为loc_t (32, 8732, 4)和loc_t (32, 8732)。
  • 其中,overlap_threshold: 0.5, 默认框和真实框的IOU小于0.5的,设定为背景,该位置返还的conf[i] = 0.

conf, pos, num_pos的输入和size:

con_f: tensor([[0, 0, 0, ..., 3, 3, 3],

[0, 0, 0, ..., 0, 0, 0],

[0, 0, 0, ..., 0, 0, 0],

...,

[0, 0, 0, ..., 0, 0, 0],

[0, 0, 0, ..., 8, 8, 0],

[0, 0, 0, ..., 0, 0, 0]]) torch.Size([32, 8732])

pos: tensor([[0, 0, 0, ..., 1, 1, 1],

[0, 0, 0, ..., 0, 0, 0],

[0, 0, 0, ..., 0, 0, 0],

...,

[0, 0, 0, ..., 0, 0, 0],

[0, 0, 0, ..., 1, 1, 0],

[0, 0, 0, ..., 0, 0, 0]], dtype=torch.uint8) torch.Size([32, 8732])

num_pos: tensor([[ 7], [64],[21],[53],[ 8],[ 2],[28],[ 8],[20],[66],[ 2], [14],[10],[11],[45],[12], [11],[ 8], [17],[ 9],[37],[24], [17], [43],[26], [11],[22], [10], [35],[14],[23],[33]]) torch.Size([32, 1])

# Localization Loss (Smooth L1)
    
           

Confidence Loss Including Positive and Negative Examples: 也就是equation(3).

pos_idx 
           

hard negative mining: 这里暂时先不讲,后期补充

# Compute max conf across batch for hard negative mining
    
           
  • 以下解释代码段“loss_c = log_sum_exp(batch_conf)- batch_conf.gather(1, conf_t.view(-1,1))”

其中, log_sum_exp(x):

focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

: 为何使用log_sum_exp(x)可参考Tricks of the Trade: LogSumExp. 这里简要说明。首先,Figure 3为代码“batch_conf.gather(1, conf_t.view(-1,1))”,输出shape为(279424,1)。下面是公式变形,其实 log_sum_exp(x)为计算Softmax的过程:

(以下公式来自链接Tricks of the Trade: LogSumExp)

  1. softmax:
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    (1),如何计算softmax如公式(1)
  2. log(softmax)即为
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
  3. focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
  4. 把(3)套入(2)即
    focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
  5. focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版
    即代码“loss_c = log_sum_exp(batch_conf)- batch_conf.gather(1, conf_t.view(-1,1))”
focal loss dice loss源码_SSD中Multibox_loss源码解读,Pytorch版

Figure 3