文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021
下面說一下我對(duì)這篇文章的淺陋之見, 如有錯(cuò)誤, 請(qǐng)多包涵指正.
文章的核心方法
如下圖所示為其處理流程:

圖中 X 為CNN骨干網(wǎng)絡(luò)提取得到的feature, 其大小為 d*h*w , 為1個(gè)batch數(shù)據(jù). 一般 d*h*w=2048*7*7 .
從圖中可以看到, 有2個(gè)分支, 一個(gè)是 average pooling, 一個(gè)是 spatial pooling, 最后二者加權(quán)融合得到 residual attention .
Spatial pooling
其過程為:

這里有個(gè) 1*1 的卷積操作FC , 其大小為 C*d*1*1 , C 為類別數(shù), 如果直接使用矩陣乘法計(jì)算, FC(X) 后的大小為 C*h*w .
但文章中的公式是將其展開為對(duì)每個(gè)空間點(diǎn)單獨(dú)計(jì)算, 其中 為
FC 第i 個(gè)類別的參數(shù), 其大小為 d*1*1, 計(jì)算得到的 為第
i 個(gè)類別在第 j 個(gè)位置的概率, 為第
i 個(gè)類別的特征, 其大小為 d*1 .
如果, 和
計(jì)算就可以得到第
i 個(gè)類別的概率. 這樣就可以用到每個(gè)空間點(diǎn)的特征, 有利于不同目標(biāo)不同類別物體的分類識(shí)別.
公式中有個(gè)溫度參數(shù) T 用來控制 的大小, 當(dāng)
T 趨于無窮時(shí), spatial pooling 就變成了 max pooling
Average pooling
其過程為:

上式其實(shí)就是一般分類模型的做法, 全局均值池化.
Residual Attention
如下所示, 將上述2個(gè)過程進(jìn)行加權(quán)融合:

其中, 大小為
d*1, 為第
i 個(gè)類別的概率.
至于為什么叫 Residual Attention , 文章中的說法是:
the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.
我的理解是, 公式5形式有點(diǎn)像 residual 形式.
文章實(shí)驗(yàn)結(jié)果
多標(biāo)簽
如下表所示為作者對(duì)多個(gè)數(shù)據(jù)集的測(cè)試, 除了ImageNet 為單標(biāo)簽外, 其它都為多標(biāo)簽. 可以看到多標(biāo)簽提升還是不錯(cuò)的.

熱力圖
由于利用到了不同位置空間點(diǎn)的信息, 獲得的 heatmap會(huì)更加準(zhǔn)確, 文章中給出了一張結(jié)果, 如下:

我覺得這里有個(gè)遺憾的是, 文中沒有進(jìn)行對(duì)比.
個(gè)人理解
關(guān)于原理
根據(jù)流程圖, 結(jié)合文中作者給出的核心代碼, 其基本原理就是 average pooling + max pooling.

上述代碼中: y_avg 大小為 C*1, 為 average pooling ; y_max 大小為 C*1, 為 max pooling .
下面是上述代碼的一個(gè)例子, y_raw 的大小為 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 剛好為 average pooling , y_max 剛好為 max pooling .
關(guān)于公式
公式中的溫度參數(shù) T 用于調(diào)整參數(shù)大小, 而給出的核心代碼中, 只有T趨于無窮的情況(等價(jià)于max pooling), 對(duì)于多個(gè) Head 的情況, T=2,3,4,5 等, 代碼中是如何體現(xiàn)出來的?
關(guān)于效果
對(duì)于 multi-label , 使用了 spatial pooling 和 multi-head 來提高效果, 從實(shí)驗(yàn)結(jié)果來看, 確實(shí)有效果, 但對(duì)于單標(biāo)簽情況, max pooling 應(yīng)該改善不大, 從實(shí)驗(yàn)結(jié)果上看也確實(shí)可以看到, 單標(biāo)簽數(shù)據(jù)集上, 最高提升了0.02個(gè)百分點(diǎn).
測(cè)試代碼
測(cè)試代碼如下, 可以參考這里.
import torch
from torch import nn
class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score
if __name__ == '__main__':
channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)