CNN、Transformer、MLP架構(gòu)的經(jīng)驗(yàn)性分析

A Battle of Network Structures: An Empirical Study of CNN, Transformer, and MLP

【GiantPandaCV導(dǎo)語】ViT的興起挑戰(zhàn)了CNN的地位,隨之而來的是MLP系列方法。三種架構(gòu)各有特點(diǎn),為了公平地比較幾種架構(gòu),本文提出了統(tǒng)一化的框架SPACH來對(duì)比,得到了具有一定insight的結(jié)論。

背景

近期Transformer MLP系列模型的出現(xiàn),增加了CV領(lǐng)域的多樣性,MLP-Mixer的出現(xiàn)表明卷積或者注意力都不是模型性能優(yōu)異的必要條件。不同架構(gòu)的模型進(jìn)行比較的過程中,會(huì)使用不同的正則化方法、訓(xùn)練技巧等,為了比較的公平性,本文提出了SPACH的統(tǒng)一框架,期望對(duì)幾種架構(gòu)進(jìn)行對(duì)比,同時(shí)探究他們各自的特點(diǎn)。

image

這個(gè)框架總體來說有兩種模式:多階段和單階段。每個(gè)階段內(nèi)部采用的是Mixing Block,而該Mixing Block可以是卷積層、Transformer層以及MLP層。

經(jīng)過實(shí)驗(yàn)發(fā)現(xiàn)了以下幾個(gè)結(jié)論:

  • 多階段框架效果優(yōu)于單節(jié)段框架(通過降采樣劃分階段)
  • 局部性建模具有高效性和重要性。
    • 通過使用輕量級(jí)深度卷積(depth wise conv),基于卷積的模型就可以取得與Transformer模型類似的性能。
    • 在MLP和Transformer的架構(gòu)的支路中使用一些局部的建模可以在有效提升性能同時(shí),只增加一點(diǎn)點(diǎn)參數(shù)量。
  • MLP在小型模型中具有非常強(qiáng)的性能表現(xiàn),但是模型容量擴(kuò)大的時(shí)候會(huì)出現(xiàn)過擬合問題,過擬合是MLP成功路上的攔路虎。
  • 卷積操作和Transformer操作是互補(bǔ)的,卷積的泛化性能更強(qiáng),Transformer結(jié)構(gòu)模型容量更大。通過靈活組合兩者可以掌控從小到大的所有模型。

統(tǒng)一框架

本文提出一統(tǒng)MLP、Transformer、Convolution的框架:SPACH

下表展示的是各個(gè)模塊中可選的參數(shù),并提出了三種變體空間。

image

其中各個(gè)模塊設(shè)計(jì)如下:

image
  • (a)展示的是卷積部分操作,使用的是3x3深度可分離卷積。
  • (b)展示的是Transformer模塊,使用了positional embedding(由于目前一些研究使用absolute positional embedding會(huì)導(dǎo)致模塊模型的平移不變性,因此采用Convolutional Position Encoding(CPE)。
  • (c)展示的是MLP模塊,參考了MLP-Mixer的設(shè)計(jì),雖然MLP-Mixer中并沒有使用Positional Embedding,但是作者發(fā)現(xiàn)通過增加輕量級(jí)的CPE能夠有效提升模型性能。

注:感覺這三種模塊的設(shè)計(jì)注入了很多經(jīng)驗(yàn)型設(shè)計(jì),比如卷積并沒有用普通卷積,用深度可分離卷積其實(shí)類似MLP中的操作,此外為MLP引入CPE的操作也非常具有技巧性。

三種模塊具有不同的屬性:

image

所謂dynamic weight是Transformer中可以根據(jù)圖片輸入的不同動(dòng)態(tài)控制權(quán)重,這樣的模型的容量相較CNN更高。CNN中也有這樣的趨勢(shì),dynamic network的出現(xiàn)也是為了實(shí)現(xiàn)動(dòng)態(tài)權(quán)重。(感謝zzk老師的講解)Transformer側(cè)重是關(guān)系的學(xué)習(xí)和建模,不完全依賴于數(shù)據(jù),CNN側(cè)重模板的匹配和建模,比較依賴于數(shù)據(jù)。

Transformer CNN
Dynamic Attention Multi-scale Features by multi-stage
Global Context Fusion Shift,scale and distortion invariance
Better Generalization(學(xué)習(xí)關(guān)系,不完全依賴數(shù)據(jù)) Local Spatial Modeling

實(shí)驗(yàn)

實(shí)驗(yàn)設(shè)置:

  • 數(shù)據(jù)集選擇ImageNet-1K
  • 輸入分辨率224x224
  • 訓(xùn)練設(shè)置參看DeiT
  • AdamW優(yōu)化器訓(xùn)練300個(gè)epoch
  • weight decay: 0.05 (T用的weight decay更?。?/li>
  • learning rate:0.005 對(duì)應(yīng) 512 batch size(T用的lr更?。?/li>

結(jié)論1:multi-stage 要比 single-stage性能更好

image

具體性能如下表所記錄,Multi-Stage能夠顯著超過Single Stage的模型。

image

可以發(fā)現(xiàn),有一個(gè)例外,在xxs尺度下,Transformer進(jìn)度損失了2.6個(gè)百分點(diǎn),因?yàn)槎嚯A段模型恰好只有單階段模型一半的參數(shù)量和Flops。

隨著參數(shù)量的增加,模型最高精度先后由MLP、Conv、Transformer所主導(dǎo)。

結(jié)論2:局部建模非常重要

image

上表展示了具有局部建模以及去除局部建模的效果,可以發(fā)現(xiàn)使用卷積旁路的時(shí)候吞吐量略微降低,但是精度有顯著提高。

結(jié)論3:MLP的細(xì)節(jié)分析

MLP性能不足主要源自于過擬合問題,可以使用兩種機(jī)制來緩解這個(gè)問題。

  • Multi-Stage的網(wǎng)絡(luò)機(jī)制,可以從以上實(shí)驗(yàn)發(fā)現(xiàn),multi-stage能夠有效降低過擬合,提高模型性能。
image
  • 權(quán)重共享機(jī)制,MLP在模型參數(shù)量比較大的情況下容易過擬合,但是如果使用權(quán)重共享可以有效緩解過擬合問題。具體共享的方法是對(duì)于某個(gè)stage的所有Mixing Block均使用相同的MLP進(jìn)行處理。
image

結(jié)論4:卷積與Transformer具有互補(bǔ)性

作者認(rèn)為卷積具有的泛化能力更強(qiáng),而Transformer具有更大的模型容量,如下圖所示,在Loss比較大的情況下,整體的準(zhǔn)確率是超過了Transformer空間的。

image

結(jié)論5: 混合架構(gòu)的模型

在multi-stage的卷積網(wǎng)絡(luò)基礎(chǔ)上將某些Mixing Block替換為Transformer的Block, 并且處于對(duì)他們建模能力的考量,選擇在淺層網(wǎng)絡(luò)使用CNN,深層網(wǎng)絡(luò)使用Transformer,得到兩種模型空間:

image

SOTA模型比較結(jié)果:

image

整體結(jié)論是:

  • Transformer能力要比MLP強(qiáng),因此不考慮使用MLP作為混合架構(gòu)
  • 混合Transformer+CNN的架構(gòu)性能上能夠超越單獨(dú)的CNN架構(gòu)或者Transformer架構(gòu)
  • FLOPS與ACC的權(quán)衡做的比較出色,能夠超越Swin Transformer以及NAS搜索得到的RegNet系列。

最后作者還向讀者進(jìn)行提問:

  • MLP性能欠佳是由于過擬合帶來的,能夠設(shè)計(jì)高性能MLP模型防止過擬合呢?
  • 目前的分析證明卷積或者Transformer并不是一家獨(dú)大,如何用更好的方式融合兩種架構(gòu)?
  • 是否存在MLP,CNN,Transformer之外的更有效地架構(gòu)呢?

代碼

對(duì)照下圖逐步給出各個(gè)Mixing Block:

image

(a)卷積模塊 ,kernel為3的深度可分離卷積

class DWConv(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super(DWConv, self).__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        padding = (kernel_size - 1) // 2
        self.net = nn.Sequential(Reshape2HW(),
                                 nn.Conv2d(dim, dim, kernel_size, 1, padding, groups=dim),
                                 Reshape2N())

    def forward(self, x):
        x = self.net(x)
        return x

(b)Transformer


class SpatialAttention(nn.Module):
    """Spatial Attention"""
    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., **kwargs):
        super(SpatialAttention, self).__init__()
        head_dim = dim // num_heads

        self.num_heads = num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b n (three heads head_c) -> three b heads n head_c", three=3, heads=self.num_heads)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1))  # B, head, N, N
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v)  # B, head, N, C
        out = rearrange(out, "b heads n head_c -> b n (heads head_c)")

        out = self.proj(out)
        out = self.proj_drop(out)

        return out

(c)MLP模塊,分為channel mlp和spatial mlp,與MLP-Mixer保持一致

class ChannelMLP(nn.Module):
    """Channel MLP"""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., **kwargs):
        super(ChannelMLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.hidden_features = hidden_features
        self.out_features = out_features

    def forward(self, x):
        B, N, C = x.shape
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SpatialAttention(nn.Module):
    """Spatial Attention"""
    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., **kwargs):
        super(SpatialAttention, self).__init__()
        head_dim = dim // num_heads

        self.num_heads = num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b n (three heads head_c) -> three b heads n head_c", three=3, heads=self.num_heads)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1))  # B, head, N, N
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v)  # B, head, N, C
        out = rearrange(out, "b heads n head_c -> b n (heads head_c)")

        out = self.proj(out)
        out = self.proj_drop(out)

        return out

SPACH骨干網(wǎng)絡(luò)的構(gòu)建: MixingBlock

class MixingBlock(nn.Module):
    def __init__(self, dim,
                 spatial_func=None, scaled=True, init_values=1e-4, shared_spatial_func=False,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop_path=0., cpe=True,
                 num_heads=None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,  # attn
                 in_features=None, hidden_features=None, drop=0.,  # mlp
                 channel_ratio=2.0
                 ):
        super(MixingBlock, self).__init__()

        spatial_kwargs = dict(act_layer=act_layer,
                              in_features=in_features, hidden_features=hidden_features, drop=drop,  # mlp
                              dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop  # attn
                              )

        self.valid_spatial_func = True

        if spatial_func is not None:
            if shared_spatial_func:
                self.spatial_func = spatial_func
            else:
                self.spatial_func = spatial_func(**spatial_kwargs)
            self.norm1 = norm_layer(dim)
            if scaled:
                self.gamma_1 = nn.Parameter(init_values * torch.ones(1, 1, dim), requires_grad=True)
            else:
                self.gamma_1 = 1.
        else:
            self.valid_spatial_func = False

        self.channel_func = ChannelMLP(in_features=dim, hidden_features=int(dim*channel_ratio), act_layer=act_layer,
                                       drop=drop)

        self.norm2 = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()


        self.cpe = cpe
        if cpe:
            self.cpe_net = DWConv(dim)


    def forward(self, x):
        in_x = x
        if self.valid_spatial_func:
            x = x + self.drop_path(self.gamma_1 * self.spatial_func(self.norm1(in_x)))
        if self.cpe:
            x = x + self.cpe_net(in_x)

        x = x + self.drop_path(self.channel_func(self.norm2(x)))

        return 

SPACH構(gòu)建:

class Spach(nn.Module):
    def __init__(self,
                 num_classes=1000,
                 img_size=224,
                 in_chans=3,
                 hidden_dim=384,
                 patch_size=16,
                 net_arch=None,
                 act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 stem_type='conv1',
                 scaled=True, init_values=1e-4, drop_path_rate=0., cpe=True, shared_spatial_func=False,  # mixing block
                 num_heads=12, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,  # attn
                 token_ratio=0.5, channel_ratio=2.0, drop_rate=0.,  # mlp
                 downstream=False,
                 **kwargs
                 ):
        super(Spach, self).__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.downstream = downstream

        self.stem = STEM_LAYER[stem_type](
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, downstream=downstream)
        self.norm1 = norm_layer(hidden_dim)

        block_kwargs = dict(dim=hidden_dim, scaled=scaled, init_values=init_values, cpe=cpe,
                            shared_spatial_func=shared_spatial_func, norm_layer=norm_layer, act_layer=act_layer,
                            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop,  # attn
                            in_features=self.stem.num_patches, hidden_features=int(self.stem.num_patches * token_ratio), channel_ratio=channel_ratio, drop=drop_rate)  # mlp

        self.blocks = self.make_blocks(net_arch, block_kwargs, drop_path_rate, shared_spatial_func)
        self.norm2 = norm_layer(hidden_dim)

        if not downstream:
            self.pool = Reduce('b n c -> b c', reduction='mean')
            self.head = nn.Linear(hidden_dim, self.num_classes)

        self.init_weights()

    def make_blocks(self, net_arch, block_kwargs, drop_path, shared_spatial_func):
        if shared_spatial_func:
            assert len(net_arch) == 1, '`shared_spatial_func` only support unitary spatial function'
            assert net_arch[0][0] != 'pass', '`shared_spatial_func` do not support pass'
            spatial_func = SPATIAL_FUNC[net_arch[0][0]](**block_kwargs)
        else:
            spatial_func = None
        blocks = []
        for func_type, depth in net_arch:
            for i in range(depth):
                blocks.append(MixingBlock(spatial_func=spatial_func or SPATIAL_FUNC[func_type], drop_path=drop_path,
                                          **block_kwargs))
        return nn.Sequential(*blocks)

    def init_weights(self):
        for n, m in self.named_modules():
            _init_weights(m, n)

    def forward_features(self, x):
        x = self.stem(x)
        x = reshape2n(x)
        x = self.norm1(x)

        x = self.blocks(x)
        x = self.norm2(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.pool(x)
        x = self.head(x)
        return x

參考

https://github.com/microsoft/SPACH

https://zhuanlan.zhihu.com/p/411145994

https://arxiv.org/pdf/2108.13002v2.pdf

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 愛可可老師7月至9月的推薦(2021.10.12整理)——MLPs與Vision Transformer主干架構(gòu)專...
    Valar_Morghulis閱讀 1,035評(píng)論 0 0
  • 0. 前言 近兩年學(xué)術(shù)界對(duì)Transformer在CV上的應(yīng)用可謂異常青睞,這里重點(diǎn)強(qiáng)調(diào)學(xué)術(shù)界的原因是目前工業(yè)界還...
    mrhalyang閱讀 2,577評(píng)論 0 0
  • ?作者 |小欣 01 簡介 最近,為了解決卷積神經(jīng)網(wǎng)絡(luò)的一些普遍存在的缺點(diǎn),如對(duì)于上下文信息的建模差、全局信息理解...
    CV案例精選閱讀 1,049評(píng)論 0 0
  • transformer是一種主要基于自注意力機(jī)制的深度神經(jīng)網(wǎng)絡(luò),最初應(yīng)用于自然語言處理領(lǐng)域。受到transform...
    Carrie_Hou閱讀 2,456評(píng)論 0 0
  • 《MLP-Mixer: An all-MLP Architecture for Vision》[https://a...
    Valar_Morghulis閱讀 905評(píng)論 0 1

友情鏈接更多精彩內(nèi)容