注意力機制大合集:
https://github.com/xmu-xiaoma666/External-Attention-pytorch
1 Attention 和 Self-Attention
Attention的核心思想是:從關注全部到關注重點。Attention 機制很像人類看圖檔的邏輯,當看一張圖檔的時候,我們并沒有看清圖檔的全部内容,而是将注意力集中在了圖檔的焦點上。大家看下面這張圖自行體會:
對于CV中早期的Attention,例如:SENet,CBAM,通常是在通道或者空間計算注意力分布。
而Self-attention(NLP中往往稱為Scaled-Dot Attention)的結構有三個分支:query、key和value。計算時通常分為三步:
第一步:是将query和每個key進行相似度計算得到權重,常用的相似度函數有點積,cos相似度,拼接,感覺機等;
第二步:一般是使用一個softmax函數對這些權重進行歸一化;
第三步:将權重和相應的鍵值value進行權重求和得到最後的attention。
假設輸入的 feature maps 的大小 Batch_size×Channels×Width×Height,那麼通過三個1×1卷積(分别是query_conv , key_conv 和 value_conv)就可以得到query、key 和 value:
- query:在query_conv卷積中,輸入為B×C×W×H,輸出為B×C/8×W×H;
- key:在key_conv卷積中,輸入為B×C×W×H,輸出為B×C/8×W×H;
- value:在value_conv卷積中,輸入為B×C×W×H,輸出為B×C×W×H。
後續的操作可以檢視下面代碼及注釋:
class Self_Attn(nn.Module): """ Self attention Layer""" def __init__(self,in_dim,activation): super(Self_Attn,self).__init__() self.chanel_in = in_dim self.activation = activation self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1) self.gamma = nn.Parameter(torch.zeros(1)) self.softmax = nn.Softmax(dim=-1) def forward(self,x): """ inputs : x : input feature maps( B * C * W * H) returns : out : self attention value + input feature attention: B * N * N (N is Width*Height) """ m_batchsize,C,width ,height = x.size() proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B*N*C proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B*C*N energy = torch.bmm(proj_query,proj_key) # batch的matmul B*N*N attention = self.softmax(energy) # B * (N) * (N) proj_value = self.value_conv(x).view(m_batchsize,-1, width*height) # B * C * N out = torch.bmm(proj_value,attention.permute(0,2,1) ) # B*C*N out = out.view(m_batchsize,C,width,height) # B*C*H*W out = self.gamma*out + x return out,attention
2 【CVPR 2017】SENet
論文:https://arxiv.org/abs/1709.01507
代碼:https://github.com/hujie-frank/SENet
由Momenta研發的網路SENet,獲得ImageNet 2017 Image Classification 冠軍。将top-5 error從2.991%降到2.251%。
SENet是早期Attention,核心思想是學習 feature Channel 間的關系,以凸顯feature Channel不同的重要度(也就是注意力分布),進而提高模型表現。
上圖是SE Module 的示意圖。給定一個輸入 x,其特征通道數為 c_1,通過一系列卷積等一般變換後得到一個特徵通道數為 c_2 的特徵。與傳統的 CNN 不一樣的是,接下來通過三個操作來重标定前面得到的特征。
首先是 Squeeze 操作,從空間次元來進行特征壓縮,将h*w*c的特征變成一個1*1*c的特征,得到向量某種程度上具有全域性的感受野,并且輸出的通道數和輸入的特征通道數相比對,它表示在特征通道上響應的全域性分布。公式非常簡單,就是一個 global average pooling:
其次是 Excitation 操作,通過引入 w 參數來為每個特征通道生成權重,其中引數 w 是可學習的,并通過一個 Sigmoid 的門獲得 0~1 之間歸一化的權重,完成顯式地模組化特征通道間的相關性。公式如下:
最後是一個 Scale 的操作,将 Excitation 的輸出的權重看做是經過選擇後的每個特征通道的重要性,然後通過channel-wise multiplication 逐通道權重到先前的特征上,完成在通道次元上的對原始特征的重标定。公式如下:
介紹完具體的公式實作,下面介紹下SE block如何運用到具體的網絡中。
代碼:
class SELayer(nn.Module): def __init__(self, channel, reduction=16): super(SELayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) # 壓縮空間 self.fc = nn.Sequential( nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)
簡評:雖然沒有看到query、key和value的影子,但是其展現了不同channel應有不同權重,是早期的attention。
3【ECCV 2018】CBAM
論文:https://arxiv.org/abs/1807.06521
代碼:https://github.com/luuuyi/CBAM.PyTorch
這是2018年ECCV的一篇論文,引文超過1000篇。
CBAM可以無縫地內建到任何CNN架構中,開銷不會很大,而且可以與基本CNN網絡一起進行端到端的訓練。與SENet類似,CBAM 也是早期的Attention,沒有通過複雜的相似度計算得到注意力分布。
CBAM依次推導出尺寸為C×1×1的一維通道注意圖Mc和尺寸為1×H×W的二維空間注意圖Ms:
其中⨂表示element-wise的乘法,F''是最終的優化輸出。
實驗表明,sequential arrangement 比parallel arrangement效果,并且channel-first 順序略優于 spatial-first.。
ResBlock中的CBAM示例如下所示:
Channel Attention集中在輸入圖像的“channel”上。
為了有效地計算channel attention,對輸入特征映射的空間次元進行壓縮。
對于空間資訊的聚合,通常同時采用 average-pooling 和 max-pooling,以得到更精細的channel-wise attention。
Fcavg和Fcmax分别表示平均池特征和最大池特征,然後通過一個隐藏層的多層感覺器(MLP),σ表示sigmoid函數。
Spatial attention關注“空間”的資訊,是對Channel Attention的補充。
為了計算Spatial attention,在 Channel 軸上應用average-pooling 和 max-pooling,然後将它們連接配接起來生成一個有效的特征。
然後利用卷積層生成一個R×H×W的空間注意映射Ms(F),該映射對強調或抑制的位置進行編碼。
具體地說,通過兩次池生成兩個映射:1×H×W的Fsavg和1×H×W的Fsmax。σ表示sigmoid函數,f7×7表示濾波器尺寸為7×7的卷積運算。
代碼如下:
def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio=16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) # 壓縮空間 self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) self.relu1 = nn.ReLU() self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) out = avg_out + max_out # [b, C, 1, 1] return self.sigmoid(out) class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() assert kernel_size in (3, 7), 'kernel size must be 3 or 7' padding = 3 if kernel_size == 7 else 1 self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) # 壓縮通道 max_out, _ = torch.max(x, dim=1, keepdim=True) # 壓縮通道 x = torch.cat([avg_out, max_out], dim=1) # [b, 1, h, w] x = self.conv1(x) return self.sigmoid(x) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.ca = ChannelAttention(planes) self.sa = SpatialAttention() self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.ca(out) * out out = self.sa(out) * out if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.ca = ChannelAttention(planes * 4) self.sa = SpatialAttention() self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) out = self.ca(out) * out out = self.sa(out) * out if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out
4【CVPR2018 Non-local】
論文位址:https://arxiv.org/abs/1711.07971
再次回顧下Self-attention
Self-attention結構自上而下分為三個分支,分别是query、key和value。計算時通常分為三步:
第一步是将query和每個key進行相似度計算得到權重,常用的相似度函數有點積,cos相似度,拼接,感覺機等;
第二步一般是使用一個softmax函數對這些權重進行歸一化;
第三步将權重和相應的鍵值value進行權重求和得到最後的attention。
Non-local就是CV中的self-attetion。其計算公式如下:
x是輸入信号,CV中使用的一般是 feature map;
i 代表的是輸出位置,如空間、時間或者時空的索引,j 代表全局響應;
f 函數式計算i和j的相似度;
g 函數計算feature map在j位置的表示;
最終的y是通過響應因子 C(x) 進行标準化處理以後得到的。
結構圖如下,可以看到non-local的原理與self-attention運作原理一樣,通過 3 個1*1的卷積建構了query,key 和 value。
5 【CVPR 2019】DANet
論文題目:Dual Attention Network for Scene Segmentation
論文位址:https://arxiv.org/abs/1809.02983
DANet結構如上圖,包含了Position Attention Module 和 Channel Attention Module,和CBAM相似,隻是在spatial和channel次元利用self-attention思想建立全局上下文關系。如下所示:
6 總結
Self-attention能夠捕捉全局的特征,是以,在計算機視覺領域大放異彩,如 Detr,Sparse R-CNN等等,不過需要指出的是:Self attention 也是有缺陷的,如:計算量大,并且這類Set Prediction檢測器檢測準确性還不能夠超越之前的檢測算法。
是以,如果是做研究,那麼這是一個不錯的主題;如果是要産品落地,那麼直接拿來用可能就會被速度拖累。
問答起飛
如果你平時遇到任何困擾你已久、或面試中的遇到目标檢測等相關問題,可以加群(掃碼下方二維碼,備注互助群,就會拉入群),告訴我們,統一記錄在《Question List》中。盡力幫助大家解決難題,真正解決問題的那種!我們解決不了的,會在公衆号内發起求助,盡力解決問題。