【GiantPandaCV導(dǎo)語(yǔ)】基于Transformer的骨干網(wǎng)絡(luò),同時(shí)使用卷積與自注意力機(jī)制來(lái)保持全局性和局部性。模型在ResNet最后三個(gè)BottleNeck中使用了MHSA替換3x3卷積。屬于早期的結(jié)合CNN+Transformer的工作。簡(jiǎn)單來(lái)講Non-Local+Self Attention+BottleNeck = BoTNet
引言
本文的發(fā)展脈絡(luò)如下圖所示:
實(shí)際上沿著Transformer Block改進(jìn)的方向進(jìn)行的,與CNN架構(gòu)也是兼容的。具體結(jié)構(gòu)如下圖所示:
兩者都遵循了BottleNeck的設(shè)計(jì)原則,可以有效降低計(jì)算量。
設(shè)計(jì)Transformer中self attention存在幾個(gè)挑戰(zhàn):
- 圖片尺寸比較大,比如目標(biāo)檢測(cè)中分辨率在1024x1024
- 內(nèi)存和計(jì)算量的占用高,導(dǎo)致訓(xùn)練開(kāi)銷比較大。
本文設(shè)計(jì)如下:
- 使用卷積識(shí)別底層特征的抽象信息。
- 使用self attention處理通過(guò)卷積層得到的高層信息。
這樣可以有效處理大分辨率圖像。
方法
BoTNet中MHSA模塊如下圖所示:
上邊的這個(gè)MHSA Block是核心創(chuàng)新點(diǎn),其與Transformer中的MHSA有所不同:
- 由于處理對(duì)象不是一維的,而是類似CNN模型,所以有非常多特性與此相關(guān)。
- 歸一化這里并沒(méi)有使用Layer Norm而是采用的Batch Norm,與CNN一致。
- 非線性激活,BoTNet使用了三個(gè)非線性激活
- 左側(cè)content-position模塊引入了二維的位置編碼,這是與Transformer中最大區(qū)別。
由于該模塊是處理BxCHW的形式,所以難免讓人想起來(lái)Non Local 操作,這里列出筆者以前繪制的一幅圖:
可以看出主要區(qū)別就是在于Content-postion模塊引入的位置信息。
BoTNet細(xì)節(jié)設(shè)計(jì):
整體的設(shè)計(jì)和ResNet50幾乎一樣,唯一不同在于最后一個(gè)階段中三個(gè)BottleNeck使用了MHSA模塊。具體這樣做的原因是Self attention需要消耗巨大的計(jì)算量,在模型最后加入時(shí)候feature map的size比較小,相對(duì)而言計(jì)算量比較小。
實(shí)驗(yàn)
在目標(biāo)檢測(cè)和分割領(lǐng)域性能對(duì)比

分辨率改變對(duì)BoTNet幫助更大
消融實(shí)驗(yàn)-相對(duì)位置編碼
BoTNet對(duì)ResNet系列模型的提升
BoTNet與更大的圖片適配
BoTNet與Non-Local Net的比較
與ImageNet上結(jié)果比較

模型放縮的影響
顯卡香氣飄來(lái),又是谷歌的騷操作,將EfficientNet方法放在BoTNet上:
可以看出與期望相符合,Transformer架構(gòu)帶來(lái)的性能上限要高于CNN,雖然模型大小比較小的時(shí)候性能比較弱,但是模型量變大以后其性能就有了保證。
代碼
核心模塊:MHSA (由第三方進(jìn)行實(shí)現(xiàn))
class MHSA(nn.Module):
def __init__(self, n_dims, width=14, height=14, heads=4):
super(MHSA, self).__init__()
self.heads = heads
self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
n_batch, C, width, height = x.size()
q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
content_content = torch.matmul(q.permute(0, 1, 3, 2), k)
content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
content_position = torch.matmul(content_position, q)
energy = content_content + content_position
attention = self.softmax(energy)
out = torch.matmul(v, attention.permute(0, 1, 3, 2))
out = out.view(n_batch, C, width, height)
return out
參考
https://arxiv.org/abs/2101.11605
https://zhuanlan.zhihu.com/p/347849929
https://github.com/leaderj1001/BottleneckTransformers/blob/main/model.py
跑不動(dòng)ImageNet,想試試Vision Transformer的同學(xué)可以看看這個(gè)倉(cāng)庫(kù),
https://github.com/pprp/pytorch-cifar-model-zoo
在CIFAR10上測(cè)試:
python train.py --model 'botnet' --name "fast_training" --sched 'cosine' --epochs 100 --cutout True --lr 0.1 --bs 128 --nw 4
目前可以在100個(gè)epoch內(nèi)達(dá)到驗(yàn)證集91.1%的準(zhǔn)確率。