天天看點

【從零開始學CenterNet】3. CenterNet骨幹網絡之hourglass

CenterNet中主要提供了三個骨幹網絡ResNet-18(ResNet-101), DLA-34, Hourglass-104,本文從結構和代碼先對hourglass進行講解。

本文對應代碼位置在:https://github.com/pprp/SimpleCVReproduction/tree/master/Simple_CenterNet

1. Ground Truth Heatmap

在開始講解骨幹網絡之前,先提一下上一篇文章中有朋友問我的問題:CenterNet為什麼要沿用CornerNet的半徑計算方式?

查詢了CenterNet論文還有官方實作的issue,其實沒有明确指出為何要用CornerNet的半徑,issue中回複也說是這是沿用了CornerNet的祖傳代碼。經過和@tangsipeng的讨論,讨論結果如下:

以下代碼是涉及到半徑計算的部分:

# 根據一進制二次方程計算出最小的半徑
radius = max(0, int(gaussian_radius((math.ceil(h), math.ceil(w)), self.gaussian_iou)))
# 得到高斯分布
draw_umich_gaussian(hmap[label], obj_c_int, radius)
           

在centerNet中,半徑的存在主要是用于計算高斯分布的sigma值,而這個值也是一個經驗性判定結果。

def draw_umich_gaussian(heatmap, center, radius, k=1):
    # 得到直徑
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
    # 一個圓對應内切正方形的高斯分布

    x, y = int(center[0]), int(center[1])

    height, width = heatmap.shape[0:2]

    # 對邊界進行限制,防止越界
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # 選擇對應區域
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    # 将高斯分布結果限制在邊界内
    masked_gaussian = gaussian[radius - top:radius + bottom, 
                               radius - left:radius + right]

    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
        # 将高斯分布覆寫到heatmap上,相當于不斷的在heatmap基礎上添加關鍵點的高斯,
        # 即同一種類型的框會在一個heatmap某一個類别通道上面上面不斷添加。
        # 最終通過函數總體的for循環,相當于不斷将目标畫到heatmap
    return heatmap
           

合理推測一下(不喜勿噴),之前很多人在知乎上issue裡讨論這個半徑計算的時候,有提到這樣的問題,就是如果将CenterNet對應的2a改正确了,反而效果會差。

我覺得這個問題可能和這裡的

sigma=diameter / 6

有一定的關系,作者當時用祖傳代碼(2a那部分有錯)進行調參,然後确定了sigma。這時這個sigma就和祖傳代碼是對應的,如果修改了祖傳代碼,同樣也需要改一下sigma或者調一下參數。

tangsipeng同學分享的文章《Training-Time-Friendly Network for Real-Time Object Detection》對應計算高斯核sigma部分就沒有用cornernet的祖傳代碼,對應代碼可以發現,這裡的sigma是一個和h,w相關的超參數,也是手工挑選的。

綜上,目前暫時認為CenterNet直接沿用CornerNet的祖傳代碼沒有官方的解釋,我們也暫時沒有想到解釋。如果對這個問題有研究的同學歡迎聯系筆者。

1. Hourglass

Hourglass網絡結構最初是在ECCV2016的Stacked hourglass networks for human pose estimation文章中提出的,用于人體姿态估計。Stacked Hourglass就是把多個漏鬥形狀的網絡級聯起來,可以擷取多尺度的資訊。

Hourglass的設計比較有層次,通過各個子產品的有規律組合成完整網絡。

1.1 最底層:Residual子產品

class residual(nn.Module):
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(residual, self).__init__()

        self.conv1 = nn.Conv2d(inp_dim,
                               out_dim, (3, 3),
                               padding=(1, 1),
                               stride=(stride, stride),
                               bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_dim,
                               out_dim, (3, 3),
                               padding=(1, 1),
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)

        self.skip = nn.Sequential(nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
                                  nn.BatchNorm2d(out_dim)) \
            if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        conv2 = self.conv2(relu1)
        bn2 = self.bn2(conv2)

        skip = self.skip(x)
        return self.relu(bn2 + skip)
           

就是簡單的殘差連結網絡中的最基礎的殘差子產品。

1.2 Hourglass子子產品

class kp_module(nn.Module):
    '''
    kp module指的是hourglass基本子產品
    '''
    def __init__(self, n, dims, modules):
        super(kp_module, self).__init__()

        self.n = n

        curr_modules = modules[0]
        next_modules = modules[1]

        curr_dim = dims[0]
        next_dim = dims[1]

        # curr_mod x residual,curr_dim -> curr_dim -> ... -> curr_dim
        self.top = make_layer(3, # 空間分辨率不變
                              curr_dim,
                              curr_dim,
                              curr_modules,
                              layer=residual)
        self.down = nn.Sequential() # 暫時沒用
        # curr_mod x residual,curr_dim -> next_dim -> ... -> next_dim
        self.low1 = make_layer(3,
                               curr_dim,
                               next_dim,
                               curr_modules,
                               layer=residual,
                               stride=2)# 降維
        # next_mod x residual,next_dim -> next_dim -> ... -> next_dim
        if self.n > 1:
            # 通過遞歸完成建構
            self.low2 = kp_module(n - 1, dims[1:], modules[1:])
        else:
            # 遞歸出口
            self.low2 = make_layer(3,
                                   next_dim,
                                   next_dim,
                                   next_modules,
                                   layer=residual)
        # curr_mod x residual,next_dim -> next_dim -> ... -> next_dim -> curr_dim
        self.low3 = make_layer_revr(3, # 升維
                                    next_dim,
                                    curr_dim,
                                    curr_modules,
                                    layer=residual)
        self.up = nn.Upsample(scale_factor=2) # 上采樣進行升維

    def forward(self, x):
        up1 = self.top(x)
        down = self.down(x)
        low1 = self.low1(down)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up(low3)
        return up1 + up2
           

其中有兩個主要的函數

make_layer

make_layer_revr

make_layer

将空間分辨率降維,

make_layer_revr

函數進行升維,是以将這個結構命名為hourglass(沙漏)。

核心建構是一個遞歸函數,遞歸層數是通過n來控制,稱之為n階hourglass子產品。

【從零開始學CenterNet】3. CenterNet骨幹網絡之hourglass

1.3 Hourglass

class exkp(nn.Module):
    '''
     整體模型調用
     large hourglass stack為2
     small hourglass stack為1
     n這裡控制的是hourglass的階數,以上兩個都用的是5階的hourglass
     exkp(n=5, nstack=2, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4]),
    '''
    def __init__(self, n, nstack, dims, modules, cnv_dim=256, num_classes=80):
        super(exkp, self).__init__()

        self.nstack = nstack # 堆疊多次hourglass
        self.num_classes = num_classes

        curr_dim = dims[0]

        # 快速降維為原來的1/4
        self.pre = nn.Sequential(convolution(7, 3, 128, stride=2),
                                 residual(3, 128, curr_dim, stride=2))

        # 堆疊nstack個hourglass
        self.kps = nn.ModuleList(
            [kp_module(n, dims, modules) for _ in range(nstack)])

        self.cnvs = nn.ModuleList(
            [convolution(3, curr_dim, cnv_dim) for _ in range(nstack)])

        self.inters = nn.ModuleList(
            [residual(3, curr_dim, curr_dim) for _ in range(nstack - 1)])

        self.inters_ = nn.ModuleList([
            nn.Sequential(nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False),
                          nn.BatchNorm2d(curr_dim)) for _ in range(nstack - 1)
        ])
        self.cnvs_ = nn.ModuleList([
            nn.Sequential(nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False),
                          nn.BatchNorm2d(curr_dim)) for _ in range(nstack - 1)
        ])
        # heatmap layers
        self.hmap = nn.ModuleList([
            make_kp_layer(cnv_dim, curr_dim, num_classes) # heatmap輸出通道為num_classes
            for _ in range(nstack)
        ])
        for hmap in self.hmap:
            # -2.19是focal loss中的預設參數,論文的4.1節有詳細說明,-ln((1-pi)/pi),這裡的pi取0.1
            hmap[-1].bias.data.fill_(-2.19)

        # regression layers
        self.regs = nn.ModuleList(
            [make_kp_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)]) # 回歸的輸出通道為2
        self.w_h_ = nn.ModuleList(
            [make_kp_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)]) # wh

        self.relu = nn.ReLU(inplace=True)

    def forward(self, image):
        inter = self.pre(image)

        outs = []
        for ind in range(self.nstack): # 堆疊兩次hourglass
            kp = self.kps[ind](inter)
            cnv = self.cnvs[ind](kp)

            if self.training or ind == self.nstack - 1:
                outs.append([
                    self.hmap[ind](cnv), self.regs[ind](cnv),
                    self.w_h_[ind](cnv)
                ])

            if ind < self.nstack - 1:
                inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
                inter = self.relu(inter)
                inter = self.inters[ind](inter)
        return outs
           

這裡需要注意的是inters變量,這個變量儲存的是中間監督過程,可以在這個位置添加loss,具體如下圖藍色部分所示,在這個部分可以添加loss,然後再用1x1卷積重新映射到對應的通道個數并相加。

【從零開始學CenterNet】3. CenterNet骨幹網絡之hourglass

然後再來談三個輸出,假設目前是COCO資料集,類别個數為80,那麼hmap相當于輸出了通道個數為80的heatmap,每個通道負責預測一個類别;wh代表對應中心點的寬和高;regs是偏置量。

CenterNet論文詳解可以點選【目标檢測Anchor-Free】CVPR 2019 Object as Points(CenterNet)

整個網絡就梳理完成了,筆者簡單畫了一下nstack為2時的hourglass網絡,如下圖所示:

【從零開始學CenterNet】3. CenterNet骨幹網絡之hourglass

3. Reference

https://blog.csdn.net/shenxiaolu1984/article/details/51428392

http://xxx.itp.ac.cn/pdf/1603.06937.pdf

http://xxx.itp.ac.cn/pdf/1904.07850v1

代碼改變世界

繼續閱讀