Transformer是近兩三年非常火的一種適用于NLP領(lǐng)域的一種模型,本質(zhì)上是Encoder-Decoder結(jié)構(gòu),所以多應(yīng)用在機(jī)器翻譯(輸入一個句子輸出一個句子)、語音識別(輸入語音輸出文字)、問答系統(tǒng)等領(lǐng)域。本文基于Attention is all you need這篇論文,講解Transformer的結(jié)構(gòu),涉及到的圖片均為論文中或經(jīng)典圖片,參數(shù)值均來自論文,具體問題可以具體調(diào)整超參數(shù)。
Transformer的組成模塊分為:Attention(包括multi-head self-Attention & context-Attention),Normalization(使用的是layer Norm,區(qū)別于Batch Norm),mask(padding mask & sequence mask),positional encoding,feed forword network(FFN)。
1、總體結(jié)構(gòu)
Transformer的總架構(gòu)如下圖所示:

這是典型的Transformer結(jié)構(gòu),簡單來說,Transformer = 預(yù)訓(xùn)練(input) + Encoder*N + 預(yù)訓(xùn)練(output) + Decoder*N+output。
模型的運(yùn)行步驟為:
① 對Input做Embedding,可以使用Word2Vec等工具,維度為512維,Embedding過后結(jié)合positional encoding,它記錄了輸入單詞的位置信息。
② 預(yù)處理后的輸入向量經(jīng)過多頭Attention層處理,加入殘差、規(guī)則化,數(shù)據(jù)給到FFN(全連接層),再加入殘差、規(guī)則化。如此反復(fù),經(jīng)過6個這樣的Encoder(即Nx=6x),編碼部分結(jié)束。
③ 編碼部分的第一個Decoder的第一個Attention接受的是來自O(shè)utputs的信息,其余的均接受來自Encoder和上一層Decoder的信息。最終的output的串行生成的,每生成一個,就放到Decoder最下面的outputs座位Decoder的輸入。
④ Decoder也是6個,最終的輸出要經(jīng)過線性層和Softmax得到最終輸出。
要注意的是,Encoder和Decoder的結(jié)構(gòu)是相同的,但不共享權(quán)重;在Encoder部分,每個單詞在Attention層的路徑具有依賴關(guān)系,串行執(zhí)行,在FFN層不具有依賴關(guān)系,并行執(zhí)行。
2、Attention
在這個結(jié)構(gòu)中,存在這樣幾個Attention,有:self-attention & context attention & scaled dot-product attention & multi-headed attention。要說明的是scaled dot-product attention和multi-headed attention是兩種attention的計算方法,后面會介紹,前兩個Attention均使用的是這兩種計算方法。
2.1?scaled dot-product attention
這種Attention的計算公式為:
以第一個Encoder為例對流程解釋如下:
① 為Encoder的每個單詞創(chuàng)建如下的三個向量:Query vector , Key vector , Value vector。這三個向量由輸入的Embedding乘以三個向量矩陣得到。要注意的是,Embedding向量維度為512,Q K V向量維度是64。

② 計算Score:對于每個詞,計算它自身的與所有的
的乘積。
③ 計算Attention:按上面Attention的公式,將Score除以一個定值(這個操作稱為“scaled”),進(jìn)行Softmax變換,使所有Score之和為1。最后乘以對應(yīng)位置的,得到該單詞的Attention。

這就是scaled dot-product attention這種機(jī)制的計算方法,Transformer架構(gòu)中的兩種Attention都使用的是這種計算方法,不同的是二者的Q K V的來源有些差異。
注:為什么Softmax中要除以一個根號?論文中給出的原因是本來
和
都是均值為0、方差為1的變量,假設(shè)二者分布相互獨立,他們乘積的分布就是均值為0、方差為
,除以根號使得Softmax內(nèi)的值保持均值為0、方差為1利于梯度計算。如果不加根號會使得計算收斂很慢,因為Softmax中的值處于梯度消失區(qū)。
進(jìn)一步思考:為什么很多Attention中沒有Scaled這一步?Attention分為兩種,前面那種是乘法,還有加法的一種:。實驗表明,加法雖然看起來簡單但計算起來并沒有快多少(tanh相當(dāng)于一個完整的隱層),在維度較高時確實更好,但如果加上Scaled也差不多。Transformer中選擇乘法是為了計算更快,維度大的話就加上Scaled。
2.2 multi-headed attention
多頭注意力機(jī)制也是一種處理的技巧,主要提高了Attention層的性能。因為上面介紹的self-attention雖然包含了其余位置的編碼,但主導(dǎo)的還是自身位置的單詞,而有時我們更需要關(guān)注其他位置,比如機(jī)器翻譯中的代詞具體指代哪個主語。
多頭注意力機(jī)制是把Q K V三個矩陣通過h個線性變換投影,然后進(jìn)行h次self-attention的計算,最后再把h個計算結(jié)果拼接起來。
2.3 self-attention & context attention
在Encoder的self-attention中,Q K V均是上一層Encoder的輸出,對于第一個Encoder來說,他們就是輸入的Embedding與positional encoding之和。
在Decoder的self-attention中,Q K V也是上一層Decoder的輸出,對于第一個Decoder來說,他們是輸入的Embedding與positional encoding之和。要注意的是,這部分我們不希望獲取到后面時刻的數(shù)據(jù),只想考慮已經(jīng)預(yù)測出來的信息,所以要進(jìn)行sequence masking(后面講到)。
在Encoder-Decoder attention(即context attention)中,Q是Decoder上一層的輸出,K V是Encoder的輸出。
3、Layer Normalization
Transformer中使用的是LN,并非BN(Batch Normalization)。什么是Norm規(guī)范化,一般地,可以用下面公式來表達(dá):
公式一為規(guī)范化處理前,公式二為處理后。規(guī)范化是對數(shù)據(jù)分布的調(diào)整,比如本身數(shù)據(jù)是正態(tài)分布,調(diào)整后的數(shù)據(jù)分布就是標(biāo)準(zhǔn)正態(tài)分布,相當(dāng)于調(diào)整了均值和方差。這樣做的意義一是讓激活值落入激活函數(shù)敏感區(qū)間,梯度更新變大,訓(xùn)練加快,二是消除極端值,提升訓(xùn)練穩(wěn)定性。
Transformer使用的是LN,而不是BN。首先看二者的區(qū)別如圖:

LN是對每個樣本自身進(jìn)行規(guī)范化,BN是對一個批次的數(shù)據(jù)在同一維度上規(guī)范化,是跨樣本的。在CNN任務(wù)中,BatchSize較大,并且訓(xùn)練時全局記錄了樣本均值和方差,適用于BN。而時序問題中,對每個神經(jīng)元進(jìn)行統(tǒng)計是不現(xiàn)實的。LN的限制相對來說就小很多,即時BatchSize=1也無妨。
4、mask
mask分為兩種,一是padding mask,二是sequence mask,這兩種在Transformer中出現(xiàn)的位置不同:padding mask在所有scaled dot-product attention中均出現(xiàn),sequence mask僅在decoder的self-attention中出現(xiàn)。
4.1 padding mask
由于每個batch的輸入序列的長度不同,padding mask被用來對齊序列長度,簡單來說就是短序列向長序列對齊,對齊的方法就是補(bǔ)0。補(bǔ)充上的地方是沒有意義的,那么Attention就不應(yīng)該給以關(guān)注。實際上,我們并不是直接在相應(yīng)位置上補(bǔ)充0,而是補(bǔ)充-inf(負(fù)無窮),這樣在Softmax之后,這些位置的概率就接近0了。
在處理過程中,padding mask是一個bool張量,false的地方就是補(bǔ)0的地方。
4.2 sequence mask
前面提到,sequence mask的作用是不讓decoder看到當(dāng)前時刻以后的信息,所以要把后面那部分信息完全遮蓋住。具體的做法是,產(chǎn)生一個上三角矩陣,上三角的值均為1,下三角和對角線均為0。
在decoder的self-attention部分,sequence mask 和 padding mask同時作用,二者相加作為mask。
5、positional encoding
RNN處理序列問題是天然有序的,而Transformer消除了這種時序上的依賴。以機(jī)器翻譯為例,輸出要是一個完整的合理的句子,就需要對輸入數(shù)據(jù)處理時加入位置信息,否則可能輸出結(jié)果的每個字是對的,但組成不了一句話。positional encoding是對輸入信息的位置進(jìn)行編碼,再和輸入的Embedding相加。
positional encoding使用的是正余弦編碼:
在偶數(shù)位置,使用公式一正弦編碼,奇數(shù)位置使用公式二余弦編碼。由于正余弦函數(shù)的特性,這種編碼既是絕對位置編碼,也包含了相對位置編碼的信息。
相對位置編碼信息主要依賴于三角函數(shù)和角公式:
6、FFN
FFN 是一個全連接網(wǎng)絡(luò),順序上先線性變換,再ReLU非線性變換,再線性變換,公式如下:
參考文獻(xiàn):