25、SKAttention模塊
論文《Selective Kernel Networks》
1、作用
該論文介紹了選擇性核網(wǎng)絡(luò)(SKNets),這是一種在卷積神經(jīng)網(wǎng)絡(luò)(CNN)中的動(dòng)態(tài)選擇機(jī)制,允許每個(gè)神經(jīng)元根據(jù)輸入自適應(yīng)地調(diào)整其感受野大小。這種方法受到視覺皮層神經(jīng)元對(duì)不同刺激響應(yīng)時(shí)感受野大小變化的啟發(fā),在CNN設(shè)計(jì)中不常利用此特性。
2、機(jī)制
SKNets利用了一個(gè)稱為選擇性核(SK)單元的構(gòu)建模塊,該模塊包含具有不同核大小的多個(gè)分支。這些分支通過一個(gè)softmax注意力機(jī)制融合,由這些分支中的信息引導(dǎo)。這個(gè)融合過程使得神經(jīng)元能夠根據(jù)輸入自適應(yīng)地調(diào)整其有效感受野大小。
3、獨(dú)特優(yōu)勢(shì)
1、自適應(yīng)感受野:
SKNets中的神經(jīng)元可以基于輸入動(dòng)態(tài)調(diào)整其感受野大小,模仿生物神經(jīng)元的適應(yīng)能力。這允許在不同尺度上更有效地處理視覺信息。
2、計(jì)算效率:
盡管為了適應(yīng)性而納入了多種核大小,SKNets仍然保持了較低的模型復(fù)雜度,與現(xiàn)有最先進(jìn)的架構(gòu)相比。通過仔細(xì)的設(shè)計(jì)選擇,如使用高效的分組/深度卷積和注意力機(jī)制中的縮減比率來(lái)控制參數(shù)數(shù)量,實(shí)現(xiàn)了這種效率。
3、性能提升:
在ImageNet和CIFAR等基準(zhǔn)測(cè)試上的實(shí)驗(yàn)結(jié)果顯示,SKNets在具有相似或更低模型復(fù)雜度的情況下,超過了其他最先進(jìn)的架構(gòu)。適應(yīng)性調(diào)整感受野的能力可能有助于更有效地捕捉不同尺度的目標(biāo)對(duì)象,提高識(shí)別性能。
4、代碼
import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict
class SKAttention(nn.Module):
def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
super().__init__()
# 計(jì)算維度壓縮后的向量長(zhǎng)度
self.d = max(L, channel // reduction)
# 不同尺寸的卷積核組成的卷積層列表
self.convs = nn.ModuleList([])
for k in kernels:
self.convs.append(
nn.Sequential(OrderedDict([
('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
('bn', nn.BatchNorm2d(channel)),
('relu', nn.ReLU())
]))
)
# 通道數(shù)壓縮的全連接層
self.fc = nn.Linear(channel, self.d)
# 為每個(gè)卷積核尺寸對(duì)應(yīng)的特征圖計(jì)算注意力權(quán)重的全連接層列表
self.fcs = nn.ModuleList([])
for i in range(len(kernels)):
self.fcs.append(nn.Linear(self.d, channel))
# 注意力權(quán)重的Softmax層
self.softmax = nn.Softmax(dim=0)
def forward(self, x):
bs, c, _, _ = x.size()
conv_outs = []
# 通過不同尺寸的卷積核處理輸入
for conv in self.convs:
conv_outs.append(conv(x))
feats = torch.stack(conv_outs, 0) # k,bs,channel,h,w
# 將所有卷積核的輸出求和得到融合特征圖U
U = sum(conv_outs) # bs,c,h,w
# 對(duì)融合特征圖U進(jìn)行全局平均池化,并通過全連接層降維得到Z
S = U.mean(-1).mean(-1) # bs,c
Z = self.fc(S) # bs,d
# 計(jì)算每個(gè)卷積核對(duì)應(yīng)的注意力權(quán)重
weights = []
for fc in self.fcs:
weight = fc(Z)
weights.append(weight.view(bs, c, 1, 1)) # bs,channel
attention_weights = torch.stack(weights, 0) # k,bs,channel,1,1
attention_weights = self.softmax(attention_weights) # k,bs,channel,1,1
# 將注意力權(quán)重應(yīng)用到對(duì)應(yīng)的特征圖上,并對(duì)所有特征圖進(jìn)行加權(quán)求和得到最終的輸出V
V = (attention_weights * feats).sum(0)
return V
# 示例用法
if __name__ == '__main__':
input = torch.randn(50, 512, 7, 7)
sk = SKAttention(channel=512, reduction=8)
output = sk(input)
print(output.shape) # 輸出經(jīng)過SK注意力處理后的特征圖形狀