天天看點

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

在正式介紹MIRNetV2之前,我們先來看一下它與MIRNetV1的性能對比,見下表。真可謂,MIRNetV2把MIRNetV1放在地上使勁的“摩擦”!關于MIRNet的介紹,可參見:https://zhuanlan.zhihu.com/p/261580767 。

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

1Method

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

上圖為MIRNetV2的網絡架構示意圖,它是在MIRNet的基礎上演變而來(MIRNet的整體架構形态與RCAN非常相似,差別在于其核心Block)。MIRNetV2的核心子產品為MRB,它是一種多尺度特征提取、聚合子產品。在多尺度方面,它通過下采樣方式建構了三個尺度的特征;在特征聚合方面,它采用了SKNet一文的特征融合機制;在特征提取方面,它采用了一種全新的RCB子產品(詳見後文介紹)。

SKFF

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

上圖為MRB中用于特征聚合的SKFF子產品結構示意圖,關于該子產品的介紹已經非常多了,也在不同結構設計中得到了廣泛應用,該子產品對于多尺度特征融合有非常優秀的效果,同時具有資料自适應性。關于SKFF直接看如下code可以看了,注:SKFF的輸入為[feat1, feat2]。

class SKFF(nn.Module):
    def __init__(self, in_channels, height=3,reduction=8,bias=False):
        super(SKFF, self).__init__()
        
        self.height = height
        d = max(int(in_channels/reduction),4)
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.LeakyReLU(0.2))

        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1,bias=bias))
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats =  inp_feats[0].shape[1]
        

        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
        
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        # stx()
        attention_vectors = self.softmax(attention_vectors)
        
        feats_V = torch.sum(inp_feats*attention_vectors, dim=1)
        
        return feats_V 
           

複制

RCB

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

上圖為MRB的核心子產品RCB結構示意圖,它是在ResBlock基礎上演變而來。正常ResBlock的殘差分支隻包含兩個卷積一個非線性激活;而RCB将ResBlock中卷積的groups參數設為,然後引入通道相關性模組化(即上圖中的Modeling)資訊,将該資訊進行變換後輸入特征相融合。關于RCB,看一下下面Modeling的實作就差不多了。

class ContextBlock(nn.Module):

    def __init__(self, n_feat, bias=False):
        super(ContextBlock, self).__init__()

        self.conv_mask = nn.Conv2d(n_feat, 1, kernel_size=1, bias=bias)
        self.softmax = nn.Softmax(dim=2)

        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
        )

    def modeling(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(3)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.modeling(x)

        # [N, C, 1, 1]
        channel_add_term = self.channel_add_conv(context)
        x = x + channel_add_term

        return x
           

複制

Training Regime

在訓練方面,圖像複原算法基本都采用随機裁剪的圖像塊進行模型訓練。對于較大的圖像塊,CNN可以捕獲更細粒度的細節并可以取得更優的性能,但會導緻更長的訓練時長;對于較小的圖像塊,盡管訓練速度快,但會導緻模型性能下降。

為達成訓練效率與性能均衡,本文提出一種漸進式學習方案:網絡先在小圖像塊上進行訓練,在訓練過程中階段性的将圖像塊的尺寸調大。這種混合尺寸學習機制不僅可以加速訓練,同時可以提升模型性能,可參照下表。

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

2Experiments

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量
MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

3MIRNetV1 vs MIRNetV2

MIRNet重磅更新!MIRNetV2 更快、更強、更輕量

對于沒有看過MIRNetV1的同學來說,直接看MIRNetV2的話,可能不知道MIRNetV2的改進點在哪裡。在這裡,我們對兩者的差異進行簡單的彙總(可參考上圖),主要展現在兩個方面:

  • 子產品方面:子產品方面的差異可以參考上圖。這裡又有兩個差異:(1) MRB的核心方面方面,MIRNet在ResBlock基礎上引入了對偶注意力單元,MIRNetV2則引入了前面所提到的RCB單元;(2) 特征聚合方面,MIRNet對于每個尺度的特征都與其他兩個尺寸的特征進行一次融合,而MIRNetV2則隻進行了低分辨率特征向高分辨率特征的融合。事實上,在實作上,兩者所使用的上采樣和下采樣也存在一些差異,MIRNetV1的實作更“臃腫”(多尺度),MIRNetV2的實作則更“簡單”(插值+卷積)。
  • 訓練機制:MIRNetV1采用的是最基本的固定塊尺寸方式進行訓練,而MIRNetV2則采用了漸進式(伴随訓練周期提升圖像塊尺寸)機制進行訓練。