【CV中的Attention機制】CBAM模塊

前言: CBAM模塊由于其使用的廣泛性以及易于集成得到很多應(yīng)用。目前cv領(lǐng)域中的attention機制也是在2019年論文中非?;稹_@篇cbam雖然是在2018年提出的,但是其影響力比較深遠,在很多領(lǐng)域都用到了該模塊。

1. 什么是注意力機制?

注意力機制(Attention Mechanism)是機器學(xué)習(xí)中的一種數(shù)據(jù)處理方法,廣泛應(yīng)用在自然語言處理、圖像識別及語音識別等各種不同類型的機器學(xué)習(xí)任務(wù)中。

通俗來講:注意力機制就是希望網(wǎng)絡(luò)能夠自動學(xué)出來圖片或者文字序列中的需要注意的地方。比如人眼在看一幅畫的時候,不會將注意力平等地分配給畫中的所有像素,而是將更多注意力分配給人們關(guān)注的地方。

從實現(xiàn)的角度來講:注意力機制通過神經(jīng)網(wǎng)絡(luò)的操作生成一個掩碼mask, mask上的值一個打分,評價當前需要關(guān)注的點的評分。

注意力機制可以分為:

  • 通道注意力機制:對通道生成掩碼mask,進行打分,代表是senet, Channel Attention Module
  • 空間注意力機制:對空間進行掩碼的生成,進行打分,代表是Spatial Attention Module
  • 混合域注意力機制:同時對通道注意力和空間注意力進行評價打分,代表的有BAM, CBAM

2. CBAM模塊的實現(xiàn)

CBAM全稱是Convolutional Block Attention Module, 是在ECCV2018上發(fā)表的注意力機制代表作之一。本人在打比賽的時候遇見過有人使用過該模塊取得了第一名的好成績,證明了其有效性。

在該論文中,作者研究了網(wǎng)絡(luò)架構(gòu)中的注意力,注意力不僅要告訴我們重點關(guān)注哪里,還要提高關(guān)注點的表示。 目標是通過使用注意機制來增加表現(xiàn)力,關(guān)注重要特征并抑制不必要的特征。為了強調(diào)空間和通道這兩個維度上的有意義特征,作者依次應(yīng)用通道和空間注意模塊,來分別在通道和空間維度上學(xué)習(xí)關(guān)注什么、在哪里關(guān)注。此外,通過了解要強調(diào)或抑制的信息也有助于網(wǎng)絡(luò)內(nèi)的信息流動。

主要網(wǎng)絡(luò)架構(gòu)也很簡單,一個是通道注意力模塊,另一個是空間注意力模塊,CBAM就是先后集成了通道注意力模塊和空間注意力模塊。

2.1 通道注意力機制

image

通道注意力機制按照上圖進行實現(xiàn):

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, rotio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(in_planes // rotio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

核心的部分Shared MLP使用了1\times1卷積完成的,進行信息的提取。需要注意的是,其中的bias需要人工設(shè)置為False。

2.2 空間注意力機制

image

空間注意力機制按照上圖進行實現(xiàn):

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3,7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

這個部分實現(xiàn)也很簡單,分別從通道維度進行求平均和求最大,合并得到一個通道數(shù)為2的卷積層,然后通過一個卷積,得到了一個通道數(shù)為1的spatial attention。

2.3 Convolutional bottleneck attention module

image
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out  # 廣播機制
        out = self.sa(out) * out  # 廣播機制
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

最后的使用一個類進行兩個模塊的集成,得到的通道注意力和空間注意力以后,使用廣播機制對原有的feature map進行信息提煉,最終得到提煉后的feature map。以上代碼以ResNet中的模塊作為對象,實際運用可以單獨將以下模塊融合到網(wǎng)絡(luò)中:

 class cbam(nn.Module):
    def __init__(self, planes):
        self.ca = ChannelAttention(planes)# planes是feature map的通道個數(shù)
        self.sa = SpatialAttention()
     def forward(self, x):
        x = self.ca(out) * x  # 廣播機制
        x = self.sa(out) * x  # 廣播機制

3. 在什么情況下可以使用?

提出CBAM的作者主要對分類網(wǎng)絡(luò)和目標檢測網(wǎng)絡(luò)進行了實驗,證明了CBAM模塊確實是有效的。

以ResNet為例,論文中提供了改造的示意圖,如下圖所示:

image

也就是在ResNet中的每個block中添加了CBAM模塊,訓(xùn)練數(shù)據(jù)來自benchmark ImageNet-1K。檢測使用的是Faster R-CNN, Backbone選擇的ResNet34,ResNet50, WideResNet18, ResNeXt50等,還跟SE等進行了對比。

消融實驗:消融實驗一般是控制變量,最能看出模型變好起作用的部分在那里。分為三個部分:

  1. 如何更有效地計算channel attention?
image

可以看出來,使用avgpool和maxpool可以更好的降低錯誤率,大概有1-2%的提升,這個組合就是dual pooling,能提供更加精細的信息,有利于提升模型的表現(xiàn)。

  1. 如何更有效地計算spatial attention?
image

這里的空間注意力機制參數(shù)也是有avg, max組成,另外還有一個卷積的參數(shù)kernel_size(k), 通過以上實驗,可以看出,當前使用通道的平均和通道的最大化,并且設(shè)置kernel size=7是最好的。

  1. 如何組織這兩個部分?
image

可以看出,這里與SENet中的SE模塊也進行了比較,這里使用CBAM也是超出了SE的表現(xiàn)。除此以外,還進行了順序和并行的測試,發(fā)現(xiàn),先channel attention然后spatial attention效果最好,所以也是最終的CBAM模塊的組成。

在MSCOCO數(shù)據(jù)及使用了ResNet50,ResNet101為backbone, Faster RCNN為檢測器的模型進行目標檢測,如下圖所示:

image

在VOC2007數(shù)據(jù)集中采用了StairNet進行了測試,如下圖所示:

image

官方貌似沒有提供目標檢測部分的代碼,CBAM的作用在于對信息進行精細化分配和處理,所以猜測是在backbone的分類器之前添加的CBAM模塊,歡迎有研究的小伙伴留言。

4. 參考

CBAM arxiv link: https://arxiv.org/pdf/1807.06521.pdf

核心代碼:https://github.com/pprp/SimpleCVReproduction/blob/master/attention/CBAM/cbam.py

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容