近年來,Attention Model在自然語言處理領(lǐng)域大火,在多項(xiàng)NLP任務(wù)中取得了顯著的成績(jī),成為構(gòu)成Transformer,BERT等多個(gè)語言模型的基石。本文將主要介紹一下Attention Model的基本原理及計(jì)算過程。
seq2seq模型入門
在NLP領(lǐng)域,Bahdanau首先將Attention機(jī)制引入神經(jīng)網(wǎng)絡(luò)機(jī)器翻譯(NMT)中,而NMT任務(wù)是一個(gè)典型的sequence to sequence(簡(jiǎn)稱seq2seq)任務(wù)。所以在介紹Attention之前,我們先來簡(jiǎn)單回顧一下經(jīng)典seq2seq模型以及其所面臨的瓶頸。
seq2seq的任務(wù)目標(biāo)是:輸入一個(gè)序列,輸出另一個(gè)序列。這種任務(wù)形式普遍存在于翻譯、對(duì)話、文本摘要生成等多種NLP任務(wù)中。seq2seq模型基本可以歸納為經(jīng)典的Encoder-Decoder結(jié)構(gòu)。該結(jié)構(gòu)的經(jīng)典表示如下圖所示:
Encoder-Decoder網(wǎng)絡(luò)結(jié)構(gòu)
該結(jié)構(gòu)由兩個(gè)RNN組成。其中,Encoder部分RNN只在序列結(jié)束時(shí)輸出一個(gè)語義向量C,該向量可以看成擁有輸入序列的全部上下文語義信息。將C復(fù)制N份,與輸出序列上一個(gè)時(shí)刻的預(yù)測(cè)值yt-1一起,作為Decoder部分每個(gè)RNN序列的輸入。Decoder部分在t時(shí)刻的隱藏層狀態(tài)ht由ht-1, yt-1, c共同決定,即以下公式:
采取該網(wǎng)絡(luò)結(jié)構(gòu),相當(dāng)于:將在給定yt-1, yt-2,……,y1和輸入語義向量c的情況下,求yt時(shí)刻輸出概率最大值的問題,等價(jià)于給定Decoder當(dāng)前隱狀態(tài)ht,上一個(gè)時(shí)刻預(yù)測(cè)輸出yt-1,輸入語義向量c的情況下,求yt時(shí)刻輸出概率最大值。即以下公式:
取對(duì)數(shù)似然條件概率,即得到整個(gè)模型的優(yōu)化目標(biāo):
Attention機(jī)制介紹
然而,上述模型有一個(gè)很大的問題:對(duì)于輸出序列而言,每一個(gè)時(shí)刻傳入的輸入語義向量均為同樣的值。而這很顯然跟我們的生活常識(shí)不符,例如:當(dāng)我們翻譯一句話中某個(gè)單詞時(shí),跟它相鄰詞的參考價(jià)值,往往要大于遠(yuǎn)離它的詞。所以為了解決這個(gè)問題,NMT框架中引入了Attention機(jī)制。通過參數(shù),來控制每一個(gè)詞在語義向量中的權(quán)重,從而提升最終效果。其網(wǎng)絡(luò)結(jié)構(gòu)如下:
Attention Model
其中下半部分為Encoder結(jié)構(gòu),這里采用雙向RNN構(gòu)成,前向RNN順序輸入單詞,后向RNN反序輸入單詞。將同一時(shí)刻的兩個(gè)RNN單元的隱狀態(tài)做拼接形成最終的隱狀態(tài)輸出ht,這樣ht既包含當(dāng)前單詞前一個(gè)時(shí)刻的信息,也包含后一個(gè)時(shí)刻的信息。上半部分為Decoder結(jié)構(gòu),為一個(gè)單向的RNN。中間部分就是Attention,采用如下公式計(jì)算:
其中,si-1為Decoder上一個(gè)時(shí)刻的隱狀態(tài),hj為j時(shí)刻Encoder隱藏層輸出狀態(tài)。使用一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)a訓(xùn)練,得到的分值eij表示j時(shí)刻輸入與i時(shí)刻輸出之間的匹配程度。之后用一個(gè)softmax函數(shù)歸一化,得到的標(biāo)準(zhǔn)概率表示alpha ij即為hj在翻譯yi中的重要性表示。最后以對(duì)應(yīng)的alpha作為權(quán)值,加權(quán)計(jì)算每一個(gè)時(shí)刻的輸入語義向量ci,即體現(xiàn)了每一個(gè)輸入單詞,在翻譯不同輸出單詞中的重要性。
以上就是關(guān)于Attention Model的基本原理的介紹??戳酥笫遣皇歉杏X也沒有很復(fù)雜呢?