聲明:轉(zhuǎn)載請(qǐng)?jiān)跇?biāo)題標(biāo)明轉(zhuǎn)載,并添加原文鏈接。
簡(jiǎn)介
這篇博客的主要內(nèi)容是對(duì)谷歌提出的transformer 進(jìn)行論文解讀,包含算法復(fù)雜度的分析。對(duì)應(yīng)的論文是 “Attention is all you need", 鏈接如下 https://arxiv.org/pdf/1706.03762.pdf 。
選擇這篇論文的原因有三點(diǎn)。
1. 這篇論文達(dá)到了 new the state-of-the-art result,? 應(yīng)該是現(xiàn)在做神經(jīng)翻譯里最好的BLUE結(jié)果。
2. 這篇文章提出的算法另辟蹊徑,沒(méi)有采取大熱的RNN/LSTM/GRU的結(jié)構(gòu),而是使用attention layer 和全連接層,達(dá)到了較好的效果,并且解決了 RNN/LSTM/GRU 里的long dependency problem 。
3. 這篇文章的算法解決了傳統(tǒng)RNN 訓(xùn)練并行度的問(wèn)題,并降低了計(jì)算復(fù)雜度。
接下來(lái)會(huì)按照 "Attention is all you need" 論文中的邏輯, 逐個(gè)模塊介紹, 希望能對(duì)大家有所幫助。原文寫(xiě)在我的筆記上。
https://shimo.im/docs/gmRW4WV2mjoXzKA1/
模型結(jié)構(gòu)
上面這個(gè)Fig.1 就是谷歌提出的transformer 的架構(gòu)。這其中左半部分是 encoder 右半部分是 decoder.
Encoder: 這里面有 N=6 個(gè) 一樣的layers, 每一層包含了兩個(gè)sub-layers. 第一個(gè)sub-layer 就是多頭注意力層(multi-head attention layer) 然后是一個(gè)簡(jiǎn)單的全連接層。 這里還有一個(gè)殘差連接 (residual connection), 在這個(gè)基礎(chǔ)上, 還有一個(gè)layer norm.? 這里的注意力層會(huì)在下文詳細(xì)解釋。
Decoder: 這里同樣是有六個(gè)一樣的Layer是,但是這里的layer 和encoder 不一樣, 這里的layer 包含了三個(gè)sub-layers,? 其中有 一個(gè)self-attention layer, encoder-decoder attention layer 最后是一個(gè)全連接層。 前兩個(gè)sub-layer 都是基于multi-head attention layer.? 這里有個(gè)特別點(diǎn)就是masking,? masking 的作用就是防止在訓(xùn)練的時(shí)候 使用未來(lái)的輸出的單詞。 比如訓(xùn)練時(shí), 第一個(gè)單詞是不能參考第二個(gè)單詞的生成結(jié)果的。 Masking就會(huì)把這個(gè)信息變成0, 用來(lái)保證預(yù)測(cè)位置 i 的信息只能基于比 i 小的輸出。
Attention
Scaled dot-product attention
這里就詳細(xì)討論scaled dot-product attention. 在原文里, 這個(gè)算法是通過(guò)queriies, keys and values 的形式描述的, 非常抽象。這里我用了一張CMU NLP 課里的圖來(lái)解釋?zhuān)?Q(queries), K (keys) and V(Values), 其中 Key and values 一般對(duì)應(yīng)同樣的 vector, K=V 而Query vecotor? 是對(duì)應(yīng)目標(biāo)句子的 word vector.
Fig. 2 里的quary vector? 來(lái)自 decoder state, key/value vector來(lái)自所有的encoder state.? 具體的操作有三個(gè)步驟。
1. 每個(gè)query-key 會(huì)做出一個(gè)點(diǎn)乘的運(yùn)算過(guò)程
2. 最后會(huì)使用soft max 把他們歸一。
3. 再到最后會(huì)乘以V (values) 用來(lái)當(dāng)做attention vector.
這里的數(shù)學(xué)表達(dá)式如下。

如果用Numpy 來(lái)寫(xiě), scaled dot-product attention, 內(nèi)容如下

這個(gè)在實(shí)際呢, 是一個(gè)tensor dot product. 對(duì)于新手來(lái)說(shuō)tensor dot product 可能還有些陌生。 對(duì)于一維向量的點(diǎn)乘(dot product), 結(jié)果是一個(gè) 標(biāo)量scalar, 對(duì)于一個(gè)高緯度的tensor? dot product, 結(jié)果就不那么好理解了。 在numpy 里的 numpy.dot? 解釋 如下,
If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]), source: https://docs.scipy.org/doc/numpy 1.14.0/reference/generated/numpy.dot.html
舉個(gè)例子, 比如 張量(tensor)a? 是一個(gè)四維矩陣,維度是[3,4,5,6], 張量 b 也是一個(gè)四維矩陣, 維度是[5,4,6,3], 那么 dot(a,b) 的維度就是 [3,4,5,5,4,3].

這里面如果考慮計(jì)算復(fù)雜度呢, c 里的每一個(gè)元素 都是通過(guò) sum(a[i,j,:] * b[k,:,m])計(jì)算得來(lái),對(duì)于上圖的例子,就需要6次乘法 5 次加法, 總共11 FLOPs, 那這樣的張量dot product 就是11*3*4*5*5*4*3 FLOPs的運(yùn)算量了。
Multi-head attention
上面的scaled dot-product attention, 看起來(lái)還有點(diǎn)簡(jiǎn)單, 網(wǎng)絡(luò)的表達(dá)能力還有一些簡(jiǎn)單,所以提出了多頭注意力機(jī)制(multi-head attention)。參考文獻(xiàn)是[3] 用圖說(shuō)明。

這里的self-attention, 就是上文中query (Q), keys (K) and Values (V) 都是相同的, 表達(dá)同一段句子 (Q=K=V). Fig. 3 右圖說(shuō)明了, 對(duì)于self-attention, target node (生成的那個(gè)點(diǎn)) 實(shí)際上和 輸入中的任意一點(diǎn)的距離是相同的。 這點(diǎn)是可以用來(lái)獲得 long path dependency的, 在翻譯句子里尤為重要。 而convolution, 因?yàn)橛芯矸e窗口的 問(wèn)題, 任意兩點(diǎn)的局距離是 Log_k(n), 其中k 是卷積窗口的大小,而n句子長(zhǎng)度。 所以說(shuō)這個(gè)self-attention 可以解決 翻譯句子 source -target 對(duì)齊的問(wèn)題。? ? 一般有三種注意力方式, 如下圖所示。
這里的masked decoder self attention, 就是為了防止 當(dāng)前的單詞生成 對(duì)未來(lái)的單詞產(chǎn)生依賴性。
多頭注意力機(jī)制很棒啊。 首先each head, 是可以并行計(jì)算的, 然后每個(gè)head 都有自己對(duì)應(yīng)的weight, 實(shí)現(xiàn)不同的線性轉(zhuǎn)換, 這樣每個(gè)head 也就有了自己特別的表達(dá)信息。 所以Fig. 5 里的每個(gè)連接 是用彩色表示的。
那multi-head attention 是怎么實(shí)現(xiàn)的呢。 在執(zhí)行self-attention 之前 Q, K, V 都會(huì)先乘以一個(gè)矩陣 做linear project, 把他投影到維度num_head*(d_q, d_k, d_v)。 這個(gè)地方舉個(gè)例子就比較容易說(shuō)明了, 比如以前的head 維度是 64, 我需要8 個(gè)head, 那我就用線性轉(zhuǎn)換, 把我的輸入變成的維度變成 512, 再reshape 一下。
具體就是, 我單個(gè)self head attention 的維度是 [1, 16, 512]? 對(duì)應(yīng)的是[batch, seq_length, head_dim]。 那我現(xiàn)在做multi-head attention, 那我就先把512 通過(guò)linear transform 變成64,(multiplied by 512*64 mixtrx) 那得到了 [1,16,64], 同樣的操作作進(jìn)行八次, reshape, 就是 [1, 8, 16,64].? 這樣multi-head attention 就實(shí)現(xiàn)了。 接下來(lái)繼續(xù)按照Fig.2 中介紹的三步驟來(lái)實(shí)現(xiàn) attention。
我們回到multi-head attention 論文中的討論,
有了前文的解釋?zhuān)現(xiàn)ig. 6 中的原文討論就容易理解的多了。 第一段就是我上文的討論, 說(shuō)出了 a single attention head 的問(wèn)題。? d_model 就是 512, d_k= d_v= 8.? Q 是 [1, 16, 512] 維度, 而且W_i^Q 的維度是 [512, 64].
這里有個(gè)潛在的問(wèn)題, 計(jì)算復(fù)雜度變高了。 比如Q, [1,16,512], 512? 前文說(shuō)是head dimension,? 其實(shí)是有物理含義的, 對(duì)于第一個(gè)attention 模塊, 512 表達(dá)的是 word vector length, 就是用多長(zhǎng)的vector 來(lái)代表一個(gè)單詞。? 而這個(gè)計(jì)算最終導(dǎo)致, 增加了O(n*d^2) 的復(fù)雜度, 其中 n 是sequence length, d 是word vector representation dimension.
Positional Embedding
討論到現(xiàn)在, 還沒(méi)有信息來(lái)表達(dá)單詞和單詞相關(guān)的位置關(guān)系。 這里, 谷歌用了sine 和cosine 函數(shù)來(lái)表達(dá)。 先上截圖。

這里使用了兩個(gè)構(gòu)造函數(shù), sin, cosine. pos 用來(lái)表示單詞的位置信息, 比如 第一個(gè)單詞啦, 第二個(gè)單詞什么的。 而 i? 用來(lái)表達(dá)dimension。 現(xiàn)在的例子里, d_{model} 是512, 那 i 應(yīng)該是 0 到255.? 這里呢, 為了好說(shuō)明, 如果2i= d_{model}, PE 的函數(shù)就是sin(pos/10000), 那它的波長(zhǎng)就是10000*2pi,? 如果i =0, 那么他的波長(zhǎng) 就是2pi.? 這樣的sin, cosin 的函數(shù) 是可以通過(guò)線性關(guān)系互相表達(dá)的。
就是這么回事啦。
Auto recursive decoding
在Transformer里有一個(gè)概念就是auto recursive decoding. 這個(gè)并不好理解, 論文也沒(méi)有詳細(xì)討論,讓我思考許久。好在谷歌博客做了一個(gè)動(dòng)圖, 一目了然。

Fig. 8? 這個(gè)圖的encoding? 過(guò)程, 主要是self attention,? 有三層。 接下來(lái)是decoding過(guò)程, 也是有三層, 第一個(gè)預(yù)測(cè)結(jié)果 <start> 符號(hào), 是完全通過(guò)encoding 里的attention vector 做出的決策。 而第二個(gè)預(yù)測(cè)結(jié)果Je, 是基于encoding attention vector &? <start> attention vector? 做出的決策。按照這個(gè)邏輯,新翻譯的單詞不僅僅依賴 encoding attention vector, 也依賴過(guò)去翻譯好的單詞的attention vector。 隨著翻譯出來(lái)的句子越來(lái)越多,翻譯下一個(gè)單詞的運(yùn)算量也就會(huì)相應(yīng)增加。 如果詳細(xì)分析,復(fù)雜度是 (n^2d), 其中n是翻譯句子的長(zhǎng)度,d是word vector 的維度。
計(jì)算復(fù)雜度

這里n, 是 sequence length, d 是word representation dimension. 通常 n<<d, 一個(gè)句子的長(zhǎng)度 n 一般是3 ---30,? 但是一個(gè)word representation dimension,大概300, 也有更多的, 比如512.
后續(xù)
實(shí)驗(yàn)結(jié)果我就先不分析了, 如果需要,請(qǐng)留言, 我再找時(shí)間。歡迎大家留言討論。
小更新,這個(gè)博客也特別棒http://jalammar.github.io/illustrated-transformer/
參考文獻(xiàn)
[1] Michal Chromiak Blog, https://mchromiak.github.io/articles/2017/Sep/01/Primer-NN/#attention-basis
[2] Google Blog, https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html
[3] Tensor2tensor transformer, https://nlp.stanford.edu/seminar/details/lkaiser.pdf