Pay Attention to MLPs

Pay Attention to MLPs

Ref: https://arxiv.org/pdf/2105.08050.pdf
code:https://github.com/lucidrains/g-mlp-pytorch/blob/54209f0fb2a52557a1c64409f26df9ebd8d5c257/g_mlp_pytorch/g_mlp_pytorch.py

背景

Transformers 自橫空出世以來(lái),在NLP領(lǐng)域,大規(guī)模了取代了LSTM-RNN模型,在CV上,ConvNets也不再是唯一選擇。它有2個(gè)重要特性:

  1. recurrent-free結(jié)構(gòu),可以并行化計(jì)算每個(gè)token的表達(dá);
  2. multi-head self-attention blocks, 可以聚合token之間的空間信息。

其中的attention mechanism一直被認(rèn)為transformers取得優(yōu)秀成績(jī)的重要因素。和MLP相比,attention可以根據(jù)模型輸入,調(diào)整參數(shù),而MLP的參數(shù)是固定的。那么問(wèn)題來(lái)了,transformers效果那么好,是self-attention起的決定性作用嗎,self-attention是必要的嗎

本文提出了gMLPs,一種attention-free, 以MLP為基礎(chǔ)的由channel projections, spatial projections 和gating組成的網(wǎng)絡(luò)結(jié)構(gòu)。

實(shí)驗(yàn)顯示:

  1. 在CV上,可以達(dá)到和vision transformers差不多的準(zhǔn)確率;和MLP-Mixer相比,參數(shù)減少66%,準(zhǔn)確率還提升了3%;
  2. 在NLP上,將gMLPs應(yīng)用到BERT的MLM,和transformers一樣,在預(yù)訓(xùn)練實(shí)時(shí)能最小化perplexity。同時(shí),實(shí)驗(yàn)也顯示,perplexity和模型規(guī)模有關(guān),而對(duì)attention不敏感;
    2.1 隨著模型的capacity上升,gMLPs的預(yù)訓(xùn)練和finetuning指標(biāo)會(huì)快速接近Transformers,這意味著,只要擴(kuò)大模型規(guī)模,那么無(wú)需self-attention,gMLPs和Transformers的差距會(huì)不斷縮?。?br> 2.2 batch-size為256,進(jìn)過(guò)1Mstep,gMLPs相比Bert,在MNLI達(dá)到了86.4%的準(zhǔn)確率,在SQuAD達(dá)到了89.5%的F1;
    2.3 在finetuning階段,模型規(guī)模和perplexity接近的情況下, Transformers在cross-sentence alignment任務(wù)上比gMLPs效果好[MNLI任務(wù)高1.8%]。但是,當(dāng)gMLPs的參數(shù)量是transformers的3倍時(shí),模型效果就很接近;
    2.4 同時(shí),文中提出一個(gè)trick,在gMLPs后接一個(gè)single-head 128d 的attention,在NLP的各項(xiàng)任務(wù)上,就超過(guò)了transformers。

因此,本文覺(jué)得,提高數(shù)據(jù)量和算力,無(wú)需self-attention,gMLPs,就可以和transformers媲美。

Model

輸入:序列長(zhǎng)度為n,embedding維度為d:
X\in R^{n\times d}

使用L個(gè)block,每個(gè)block進(jìn)行如下操作:

Z = \sigma (XU) = GeLU(XU)
\tilde Z = s(Z)
Y = \tilde ZV

其中:
U,V為沿著channel[可理解為hidden維度]的線性投影,同Transformers的FFN;
s(\cdot)為空間上的交互,用于獲取tokens之間的關(guān)系。本文認(rèn)為s(\cdot)可以學(xué)習(xí)到位置信息,因此,沒(méi)有使用positional embedding。

gMLPs

Spatial Gating Unit

為了實(shí)現(xiàn)token之間的交互,在s(\cdot)層,就要包含一個(gè)空間維度的交叉操作。

文中主要介紹了2種SGU:

  1. 比較直觀的,就是使用線性投影:
    f_{W,b} (Z) = WZ + b

其中:
W\in R^{n\times n}, n為序列長(zhǎng)度;b可以是一個(gè)矩陣,也可以是一個(gè)常量。
空間交互通過(guò)element-wise實(shí)現(xiàn):
s(Z) = Z \odot f_{W,b} (Z)

為確保訓(xùn)練的穩(wěn)定性,W初始化值接近于0, b為1。這相當(dāng)于初始化的FFN,開(kāi)始每個(gè)token相互獨(dú)立,隨著訓(xùn)練逐漸考慮token之間的交互信息。

  1. 除了線性投影的gatef_{W,b} (\cdot), 文中還將Z沿著channel分解成(Z_1,Z_2),借鑒GLUs的思路:
    s(Z) = Z_1 \odot f_{W,b} (Z_2)

源代碼分析

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        """dim: embedding size 
            dim_seq: sequence length """
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1) 
        # 常規(guī)卷積,卷積的是詞向量的維度。本文是空間上的信息交互,因此輸入/輸出通道是序列長(zhǎng)度,卷積核尺寸為1。

        self.act = act

        init_eps /= dim_seq
        nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res = None):
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1) #沿著詞向量維度,分成2個(gè)矩陣。
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
            weight = weight.masked_fill(mask[..., None], 0.)

        gate = F.conv1d(gate, weight, bias)

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res

GLUs(Gated linear units)補(bǔ)充:

由Language model with gated convolutional network提出,使用CNN學(xué)習(xí)長(zhǎng)文本,為緩解梯度消散,并保留非線性能力,使用門控機(jī)制。即:
沒(méi)有經(jīng)過(guò)非線性轉(zhuǎn)換的卷積層輸出*經(jīng)過(guò)非線性轉(zhuǎn)換的卷積層輸出
h(x) = (X*W+b)\odot \sigma(X*V + b)

其中:
\odot:element-wise product
X\in R^{N \times m}
W,V \in R^{k \times m \times n}

注意,GLUs是沿著channel維度[per token]的處理,而SGU是沿著空間維度[cross-token]的處理。

Image Classification

在圖片分類ImageNet數(shù)據(jù)集上,無(wú)需添加外部數(shù)據(jù),訓(xùn)練gMLPs。
模型配置如下,輸入和輸出沿用的ViT(Vision Transformer)格式,模型的深度和寬度配置也和ViT/DeiT模型相似。
結(jié)果:和Transformer一樣,gMLPs在訓(xùn)練集上過(guò)擬合,因此采用了DeiT的正則化處理(mixup, cutmix);同時(shí),對(duì)模型的維度做了調(diào)整。


CV gMLPs
ImageNet模型結(jié)果
圖片分類準(zhǔn)確率和模型規(guī)模關(guān)系

Masked Language Modeling with BERT

DepthWise convolution補(bǔ)充

一個(gè)卷積核負(fù)責(zé)一個(gè)通道,卷積核數(shù)量要和圖片通道數(shù)相同。
f_{W,b}( \cdot)好比一個(gè)寬的depthwise convolution,接收整個(gè)句子的信息。但是depthwise convolution面向的是通道的filter,而gMLPs只使用一個(gè)W共享交叉通道。

在NLP上,gMLPs進(jìn)行了多個(gè)ablation實(shí)驗(yàn)。

1. Ablation:the importance of gating in gMLP for BERT's Pretraining

  1. 使用Bert的absolute position embeddings;
  2. Bert框架 + T5-stype的relative position biases;
  3. 同1,2,但只保留relative positional biases,去掉content-dependent terms inside the softmax。

困惑度:交叉熵的指數(shù)形式。語(yǔ)言模型越好,句子概率越大,熵越小,困惑度越低。

各種模型的perplexity比較

使用SGU可以讓gMLPs得到與Bert差不多的perplexity。

2. Case Study: The Behavior of gMLPs as Model Size Increases

模型規(guī)模和finetuing結(jié)果比較

Transformer中的6+6:self-attention使用6層,F(xiàn)FN使用6層。
finetuning任務(wù)用GLUE表示模型效果。
結(jié)果顯示:

  1. gMLPs越深,pretraining perplexity越小,和transformer的模型效果越逼近;
  2. pretraining的perplexity越小,不意味著finetuning結(jié)果越好,比如gMLPs的perplexity比transformer小的時(shí)候,在SST-2的模型結(jié)果更好,但是MNLI-m的模型結(jié)果更差;

3. Ablation: The Usefulness of Tiny Attention in BERT's Finetuning

文中還做了個(gè)測(cè)試,在一些下游任務(wù)上,主要是設(shè)計(jì)到句子對(duì)的任務(wù)上,gMLPs表現(xiàn)比Transformers差。 那就再加一個(gè)tiny attention,來(lái)加強(qiáng)模型對(duì)cross-sentence alignment的學(xué)習(xí)。

Hybrid

這種混個(gè)gMLPs和attention的模型,稱為aMLPs。結(jié)果顯示,aMLPs的效果比gMLPs和transformer都要好。


模型比較

4.Main Results for MLM in the BERT Setup

模型效果總結(jié)
  1. 以SQuADv2.0任務(wù)為例,base模型,Bert模型的f1達(dá)到了78.6,gMLPs只有70.1, 差距8.5%;到了large模型,差距只有81.0-78.3=2.7;
  2. aMLPs使用128d的attention size,在SQuADv2.0任務(wù),比Bert還要高4.4%的F1.

前面做的幾個(gè)實(shí)驗(yàn)的總結(jié):

  1. 在finetuning階段,gMLPs不如transformer,但是,隨著模型變大,和transformer的差距會(huì)不斷縮?。?/li>
  2. aMLPs 不同的attention size(64,128),足夠使得模型效果優(yōu)于其他2個(gè)。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • BERT-Google Code Pre-training of Deep Bidirectional Trans...
    呢嘻嘻嘻嘻嘻閱讀 1,186評(píng)論 0 0
  • Contextual Word Representations and Pretraining 一、Word Re...
    Evermemo閱讀 2,822評(píng)論 0 4
  • 目錄 一、前言 二、如何理解BERT模型 三、BERT模型解析 1、論文的主要貢獻(xiàn)2、模型架構(gòu)3、關(guān)鍵創(chuàng)新3、實(shí)驗(yàn)...
    奇點(diǎn)機(jī)智閱讀 95,210評(píng)論 1 35
  • 表情是什么,我認(rèn)為表情就是表現(xiàn)出來(lái)的情緒。表情可以傳達(dá)很多信息。高興了當(dāng)然就笑了,難過(guò)就哭了。兩者是相互影響密不可...
    Persistenc_6aea閱讀 129,412評(píng)論 2 7
  • 16宿命:用概率思維提高你的勝算 以前的我是風(fēng)險(xiǎn)厭惡者,不喜歡去冒險(xiǎn),但是人生放棄了冒險(xiǎn),也就放棄了無(wú)數(shù)的可能。 ...
    yichen大刀閱讀 7,544評(píng)論 0 4

友情鏈接更多精彩內(nèi)容