《Attention Is All You Need》算法詳解

該篇文章右谷歌大腦團(tuán)隊在17年提出,目的是解決對于NLP中使用RNN不能并行計算(詳情參考《【譯】理解LSTM(通俗易懂版)》),從而導(dǎo)致算法效率低的問題。該篇文章中的模型就是近幾年大家到處可以聽到的Transformer模型。

一、算法介紹前的說明

由于該文章提出是解決NLP(Nature Language Processing)中的任務(wù),例如文章實驗是在翻譯任務(wù)上做的。為了CV同學(xué)更好的理解,先簡單介紹一下NLP任務(wù)的一個工作流程,來理解模型的輸入和輸出是什么。

1.1 CV模型的輸入和輸出

首先拿CV中的分類任務(wù)來說,訓(xùn)練前我們會有以下幾個常見步驟:

  1. 獲取圖片
  2. 定義待分類的類別,用數(shù)字標(biāo)簽或者one-hot向量標(biāo)簽表示
  3. 對圖片進(jìn)行類別的標(biāo)注
  4. 圖片預(yù)處理(翻轉(zhuǎn)、裁剪、縮放等)
  5. 將預(yù)處理后的圖片輸入到模型中

所以對于分類任務(wù)來說,模型的輸入為預(yù)處理過的圖片,輸出為圖片的類別(一般為預(yù)測的向量,然后求argmax獲得類別)。

1.2 NLP模型的輸入

在介紹NLP任務(wù)預(yù)處理流程前,先解釋兩個詞,一個是tokenize,一個是embedding。

tokenize是把文本切分成一個字符串序列,可以暫且簡單的理解為對輸入的文本進(jìn)行分詞操作。對英文來說分詞操作輸出一個一個的單詞,對中文來說分詞操作輸出一個一個的字。(實際的分詞操作多有種方式,會復(fù)雜一點(diǎn),這里說的只是一種分詞方式,姑且這么定,方便下面的理解。)

embedding是可以簡單理解為通過某種方式將詞向量化,即輸入一個詞輸出該詞對應(yīng)的一個向量。(embedding可以采用訓(xùn)練好的模型如GLOVE等進(jìn)行處理,也可以直接利用深度學(xué)習(xí)模型直接學(xué)習(xí)一個embedding層,Transformer模型的embedding方式是第二種,即自己去學(xué)習(xí)的一個embedding層。)

在NLP中,拿翻譯任務(wù)(英文翻譯為中文)來說,訓(xùn)練模型前存在下面步驟:

  1. 獲取英文中文對應(yīng)的句子
  2. 定義英文詞表(常用的英文單詞作為一個類別)和中文詞表(一個字為一個類別)
  3. 對中英文進(jìn)行分詞
  4. 將分好的詞根據(jù)步驟2定義好的詞表獲得句子中每個詞的one-hot向量
  5. 對每個詞進(jìn)行embedding(輸入one-hot輸出與該詞對應(yīng)的embedding向量)
  6. embedding向量輸入到模型中去

所以對于翻譯任務(wù)來說,翻譯模型的輸入為句子每個詞的one-hot向量或者embedding后的向量(取決于embedding是否是翻譯模型自己學(xué)習(xí)的,如果是則輸入one-hot就可以了,如果不是那么輸入就是通過別的模型獲得的embedding向量)組成的序列,輸出為當(dāng)前預(yù)測詞的類別(一般為詞表大小維度的向量)

二、Transformer的結(jié)構(gòu)

知道了Transformer模型的輸入和輸出后,下面來介紹一下Transformer模型的結(jié)構(gòu)。

先來看看Transformer的整體結(jié)構(gòu),如下圖所示:


1.png

可以看出它是一個典型的seq2seq結(jié)構(gòu)(encoder-decoder結(jié)構(gòu)),Encoder里面有N個重復(fù)的block結(jié)構(gòu),Decoder里面也有N個重復(fù)的block結(jié)構(gòu)。


2.jpg

2.1 Embedding

可以注意到這里的embedding操作是與翻譯模型一起學(xué)習(xí)的。所以Transformer模型的輸入為對句子分詞后,每個詞的one-hot向量組成的一個向量序列,輸出為預(yù)測的每個詞的預(yù)測向量。

2.2 Positional Encoding

為了更好的利用序列的位置信息,在對embedding后的向量加上位置相關(guān)的編碼。文章采用的是人工預(yù)設(shè)的方式計算出來的編碼。計算方式如下

PE_{(pos, 2i)}=sin(pos/10000^{2i/d_{model}})

PE_{(pos, 2i+1)}=cos(pos/10000^{2i/d_{model}})

上式中,pos表示當(dāng)前詞在句子中的位置,例如輸入的序列長L=5,那么pos取值分別為0-4,i表示維度的位置,偶數(shù)位置用PE(pos, 2i)公式計算, 奇數(shù)位置用PE(pos, 2i+1)公式計算。

文章也采用了加入模型訓(xùn)練來自動學(xué)習(xí)位置編碼的方式,發(fā)現(xiàn)效果與人工預(yù)設(shè)方式差不多。

2.3 Encoder結(jié)構(gòu)

Encoder包含了N個重復(fù)的block結(jié)構(gòu),文章N=6。下面來拆解一個每個塊的具體結(jié)構(gòu)。


6.png
2.3.1 Multi-Head Attention(encoder)

為了便于理解,介紹Multi-Head Attention結(jié)構(gòu)前,先介紹一下基礎(chǔ)的Scaled Dot-Product Attention結(jié)構(gòu),該結(jié)構(gòu)是Transformer的核心結(jié)構(gòu)。

Scaled Dot-Product Attention結(jié)構(gòu)如下圖所示

3.png

Scaled Dot-Product Attention模塊用公式表示如下

Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

上式中,可以假設(shè)Q\K的維度皆為{(L, d_k)},V的維度為(L, d_v),L為輸入的句子長度,d_k,d_v為特征維度。

softmax(QK^T)得到的維度為(L, L),該張量可以理解為計算Q與K中向量兩兩間的相似度或者說是模型應(yīng)該著重關(guān)注(attention)的地方。這里還除了\sqrt{d_k},文章解釋是防止維度d_k太大得到的值就會太大,導(dǎo)致后續(xù)的導(dǎo)數(shù)會太小。(這里為什么一定要除\sqrt{d_k}而不是{d_k}或者其它數(shù)值,文章沒有給出解釋。)

經(jīng)過softmax(\frac{QK^T}{\sqrt{d_k}})獲得attention權(quán)重后,與V相乘,既可以得到attention后的張量信息。最終的Attention(Q, K, V)輸出維度為(L, d_v)

這里還可以看到在Scaled Dot-Product Attention模塊中還存在一個可選的Mask模塊(Mask(opt.)),后續(xù)會介紹它的作用。

文章認(rèn)為采用多頭(Multi-Head)機(jī)制有利于模型的性能提高,所以文章引入了Multi-Head Attention結(jié)構(gòu)。

Multi-Head Attention結(jié)構(gòu)如下圖所示

4.png

Multi-Head Attention結(jié)構(gòu)用公式表示如下

MultiHead(Q, K, V) = Concat(head_1, ..., head_n)W^O\\ where head_i = Attention(QW^Q_i, KW^K_i, VW^V_i)

上述參數(shù)矩陣為W^Q_i\in R^{d_{model} \times d_k}, W^K_i\in R^{d_{model} \times d_k}, W^V_i\in R^{d_{model} \times d_v}, W^O_i\in R^{hd_v \times d_{model}}d_{model}為multi-head attention模塊輸入與輸出張量的通道維度,h為head個數(shù)。文中h=8,d_k=d_v=d_{model}/h=64d_{model}=512

關(guān)于multi-head機(jī)制為什么可以提高模型性能

文章末尾給出了多頭中其中兩個頭的attention可視化結(jié)果,如下所示

5.png

圖中,線條越粗表示attention的權(quán)重越大,可以看出,兩個頭關(guān)注的地方不一樣,綠色圖說明該頭更關(guān)注全局信息,紅色圖說明該頭更關(guān)注局部信息。

2.3.2 Add&Norm結(jié)構(gòu)

從結(jié)構(gòu)圖不難看出網(wǎng)絡(luò)加入了residual結(jié)構(gòu),所以add很好理解,就是輸入張量與輸出張量相加的操作。

Norm操作與CV常用的BN不太一樣,這里采用NLP領(lǐng)域較常用的LN(Layer Norm)。(關(guān)于BN、LN、IN、GN的計算方式可以參考《GN-Group Normalization》

還要多說一下的是,文章中共Add&Norm結(jié)構(gòu)是先相加再進(jìn)行Norm操作。

2.3.3 Feed Forward結(jié)構(gòu)

該結(jié)構(gòu)很簡單,由兩個全連接(或者kernel size為1的卷積)和一個ReLU激活單元組成。

Feed Forward結(jié)構(gòu)用公式表示如下

FFN(x)=max(0, xW_1 + b_1)W_2 + b_2

2.4 Decoder結(jié)構(gòu)

Decoder同樣也包含了N個重復(fù)的block結(jié)構(gòu),文章N=6。下面來拆解一個每個塊的具體結(jié)構(gòu)。


7.png
2.4.1 Masked Multi-Head Attention

從名字可以看出它比2.3.1部分介紹的Multi-Head Attention結(jié)構(gòu)多一個masked,其實它的基本結(jié)構(gòu)如下圖所示

3.png

可以看出這就是Scaled Dot-Product Attention,只是這里mask是啟用的狀態(tài)。

這里先從維度角度考慮mask是怎么工作的,然后再解釋為什么要加這個mask操作。

mask工作方式

為了方便解釋,先不考慮多batch和多head情況。

可以假設(shè)Q\K的維度皆為{(L, d_k)},V的維度為(L, d_v)。

那么在進(jìn)行mask操作前,經(jīng)過MatMul和Scale后得到的張量維度為 \frac{QK^T}{\sqrt{d_k}}\in R^{(L, L)}。

現(xiàn)在有一個提前計算好的mask為M\in R^{(L, L)},M是一個上三角為-inf,下三角為0的方陣。如下圖所示(圖中假設(shè)L=5)。

8.png

softmax(\frac{QK^T}{\sqrt{d_k}}+M)的結(jié)果如下圖所示(圖中假設(shè)L=5)

注意:下圖中的非0區(qū)域的值不一定是一樣的,這里為了方便顯示畫成了一樣的顏色

9.png

現(xiàn)在Scaled Dot-Product Attention的公式如下所示

Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}}+M)V

可以看出經(jīng)過M后,softmax在-inf處輸出結(jié)果為0,其它地方為非0,所以softmax的輸出為softmax(\frac{QK^T}{\sqrt{d_k}}+M)\in R^{(L, L)},該結(jié)果為上三角為0的方陣。與V\in R^{(L, d_v)}進(jìn)行相乘得到結(jié)果為Attention(Q, K, V) \in R^{(L, d_v)}。

從上述運(yùn)算可以看出mask的目的是為了讓V與attention權(quán)重計算attention操作時只考慮當(dāng)前元素以前的所有元素,而忽略之后元素的影響。即V的維度為(L, d_v),那么第i個元素只考慮0-i元素來得出attention的結(jié)果。

mask操作的作用

在解釋mask作用之前,我們先解釋一個概念叫teacher forcing。

teacher forcing這個操作方式經(jīng)常在訓(xùn)練序列任務(wù)時被用到,它的含義是在訓(xùn)練一個序列預(yù)測模型時,模型的輸入是ground truth。

舉例來說,對于"I Love China -> 我愛中國"這個翻譯任務(wù)來說,測試階段,Encoder會將輸入英文編譯為feature,Decoder解碼時首先會收到一個BOS(Begin Of Sentence)標(biāo)識,模型輸出"我",然后將"我"作為decoder的輸入,輸出"愛",重復(fù)這個步驟直到輸出EOS(End Of Sentence)標(biāo)志。

但是為了能快速的訓(xùn)練一個效果好的網(wǎng)絡(luò),在訓(xùn)練時,不管decoder輸出是什么,它的輸入都是ground truth。例如,網(wǎng)絡(luò)在收到BOS后,輸出的是"你",那么下一步的網(wǎng)絡(luò)輸入依然還是使用gt中的"我"。這種訓(xùn)練方式稱為teacher forcing。如下圖所示


12.jpg

我們看下面兩張圖,第一張是沒有mask操作時的示例圖,第二張是有mask操作時的示例圖??梢钥吹?,按照teacher forcing的訓(xùn)練方式來訓(xùn)練Transformer,如果沒有mask操作,模型在預(yù)測"我"這個詞時,就會利用到"我愛中國"所有文字的信息,這不合理。所以需要加入mask,使得網(wǎng)絡(luò)只能利用部分已知的信息來模擬推斷階段的流程。

13.jpg

14.jpg
2.4.2 Multi-Head Attention(decoder)

decoder中的Multi-Head Attention內(nèi)部結(jié)構(gòu)與encoder是一模一樣的,只是輸入中的Q為2.4.1部分提到的Masked Multi-Head Attention的輸出,輸入中的K與V則都是encoder模塊的輸出。

下面用一張圖來展示encoder和decoder之間的信息傳遞關(guān)系


15.jpg

decoder中Add&Norm和Feed Forward結(jié)構(gòu)都與encoder一模一樣了。

2.5 其它說明

1. 從圖中看出encoder和decoder中每個block的輸入都是一個張量,但是輸入給attention確實Q\K\V三個張量?

對于block來說,Q=K=V=輸入張量

2. 推斷階段,解碼可以并行嗎?

不可以,上面說的并行是采用了teacher forcing+mask的操作,是的訓(xùn)練可以并行計算。但是推斷時的解碼過程同RNN,都是通過auto-regression方式獲得結(jié)果的。(當(dāng)然也有non auto-regression方面的研究,就是一次估計出最終結(jié)果)

參考:

  1. https://arxiv.org/abs/1706.03762
  2. https://www.youtube.com/watch?v=ugWDIIOHtPA&t=1697s
  3. http://nlp.seas.harvard.edu/2018/04/03/attention.html
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容