seq2seq通俗理解----編碼器和解碼器(TensorFlow實(shí)現(xiàn))

1. 什么是seq2seq

在?然語?處理的很多應(yīng)?中,輸?和輸出都可以是不定?序列。以機(jī)器翻譯為例,輸?可以是?段不定?的英語?本序列,輸出可以是?段不定?的法語?本序列,例如:

英語輸?:“They”、“are”、“watching”、“.”

法語輸出:“Ils”、“regardent”、“.”

當(dāng)輸?和輸出都是不定?序列時(shí),我們可以使?編碼器—解碼器(encoder-decoder)或者seq2seq模型。序列到序列模型,簡(jiǎn)稱seq2seq模型。這兩個(gè)模型本質(zhì)上都?到了兩個(gè)循環(huán)神經(jīng)?絡(luò),分別叫做編碼器和解碼器。編碼器?來分析輸?序列,解碼器?來?成輸出序列。兩 個(gè)循環(huán)神經(jīng)網(wǎng)絡(luò)是共同訓(xùn)練的。

下圖描述了使?編碼器—解碼器將上述英語句?翻譯成法語句?的?種?法。在訓(xùn)練數(shù)據(jù)集中,我們可以在每個(gè)句?后附上特殊符號(hào)“<eos>”(end of sequence)以表?序列的終?。編碼器每個(gè)時(shí)間步的輸?依次為英語句?中的單詞、標(biāo)點(diǎn)和特殊符號(hào)“<eos>”。下圖中使?了編碼器在 最終時(shí)間步的隱藏狀態(tài)作為輸?句?的表征或編碼信息。解碼器在各個(gè)時(shí)間步中使?輸?句?的 編碼信息和上個(gè)時(shí)間步的輸出以及隱藏狀態(tài)作為輸?。我們希望解碼器在各個(gè)時(shí)間步能正確依次 輸出翻譯后的法語單詞、標(biāo)點(diǎn)和特殊符號(hào)“<eos>”。需要注意的是,解碼器在最初時(shí)間步的輸? ?到了?個(gè)表?序列開始的特殊符號(hào)“<bos>”(beginning of sequence)。

image

2. 編碼器

編碼器的作?是把?個(gè)不定?的輸?序列變換成?個(gè)定?的背景變量 c,并在該背景變量中編碼輸?序列信息。常?的編碼器是循環(huán)神經(jīng)?絡(luò)。

讓我們考慮批量?小為1的時(shí)序數(shù)據(jù)樣本。假設(shè)輸?序列是 x1, . . . , xT,例如 xi 是輸?句?中的第 i 個(gè)詞。在時(shí)間步 t,循環(huán)神經(jīng)?絡(luò)將輸? xt 的特征向量 xt 和上個(gè)時(shí)間步的隱藏狀態(tài)h_{t-1}變換為當(dāng)前時(shí)間步的隱藏狀態(tài)ht。我們可以?函數(shù) f 表達(dá)循環(huán)神經(jīng)?絡(luò)隱藏層的變換:

h_t=f(x_t,h_{t-1})

接下來,編碼器通過?定義函數(shù) q 將各個(gè)時(shí)間步的隱藏狀態(tài)變換為背景變量:

c=q(h_1,...,h_T)

例如,當(dāng)選擇 q(h1, . . . , h****T ) = h****T 時(shí),背景變量是輸?序列最終時(shí)間步的隱藏狀態(tài)h**T

以上描述的編碼器是?個(gè)單向的循環(huán)神經(jīng)?絡(luò),每個(gè)時(shí)間步的隱藏狀態(tài)只取決于該時(shí)間步及之前的輸??序列。我們也可以使?雙向循環(huán)神經(jīng)?絡(luò)構(gòu)造編碼器。在這種情況下,編碼器每個(gè)時(shí)間步的隱藏狀態(tài)同時(shí)取決于該時(shí)間步之前和之后的?序列(包括當(dāng)前時(shí)間步的輸?),并編碼了整個(gè)序列的信息。

3. 解碼器

剛剛已經(jīng)介紹,編碼器輸出的背景變量 c 編碼了整個(gè)輸?序列 x1, . . . , xT 的信息。給定訓(xùn)練樣本中的輸出序列 y1, y2, . . . , yT′ ,對(duì)每個(gè)時(shí)間步 t′(符號(hào)與輸?序列或編碼器的時(shí)間步 t 有區(qū)別),解碼器輸出 yt′ 的條件概率將基于之前的輸出序列 y_1,...,y_{t^{′}-1} 和背景變量 c,即:

P(y_{t^{′}}|y_1,...,y_{t^{′}-1},c)

為此,我們可以使?另?個(gè)循環(huán)神經(jīng)?絡(luò)作為解碼器。在輸出序列的時(shí)間步 t′,解碼器將上?時(shí)間步的輸出 y_{t^{′}-1} 以及背景變量 c 作為輸?,并將它們與上?時(shí)間步的隱藏狀態(tài) s_{t^{′}-1} 變換為當(dāng)前時(shí)間步的隱藏狀態(tài)st′。因此,我們可以?函數(shù) g 表達(dá)解碼器隱藏層的變換:

s_{t^{′}}=g(y_{t^{′}-1},c,s_{t^{′}-1})

有了解碼器的隱藏狀態(tài)后,我們可以使??定義的輸出層和softmax運(yùn)算來計(jì)算P(y_{t^{′}}|y_1,...,y_{t^{′}-1},c),例如,基于當(dāng)前時(shí)間步的解碼器隱藏狀態(tài) st′、上?時(shí)間步的輸出s_{t^{′}-1}以及背景變量 c 來計(jì)算當(dāng)前時(shí)間步輸出 yt′ 的概率分布。

4. 訓(xùn)練模型

根據(jù)最?似然估計(jì),我們可以最?化輸出序列基于輸?序列的條件概率:

P(y_1,...,y_{t^{′}-1}|x_1,...,x_T)=\prod_{t^{′}=1}^{T^{′}}P(y_{t^{′}}|y_1,...,y_{t^{′}-1},x_1,...,x_T)

=\prod_{t^{′}=1}^{T^{′}}P(y_{t^{′}}|y_1,...,y_{t^{′}-1},c)

并得到該輸出序列的損失:

-logP(y_1,...,y_{t^{′}-1}|x_1,...,x_T)=-\sum_{t^{′}=1}^{T^{′}}logP(y_{t^{′}}|y_1,...,y_{t^{′}-1},c)

在模型訓(xùn)練中,所有輸出序列損失的均值通常作為需要最小化的損失函數(shù)。在上圖所描述的模型預(yù)測(cè)中,我們需要將解碼器在上?個(gè)時(shí)間步的輸出作為當(dāng)前時(shí)間步的輸?。與此不同,在訓(xùn)練中我們也可以將標(biāo)簽序列(訓(xùn)練集的真實(shí)輸出序列)在上?個(gè)時(shí)間步的標(biāo)簽作為解碼器在當(dāng)前時(shí)間步的輸?。這叫作強(qiáng)制教學(xué)(teacher forcing)。

5. seq2seq模型預(yù)測(cè)

以上介紹了如何訓(xùn)練輸?和輸出均為不定?序列的編碼器—解碼器。本節(jié)我們介紹如何使?編碼器—解碼器來預(yù)測(cè)不定?的序列。

在準(zhǔn)備訓(xùn)練數(shù)據(jù)集時(shí),我們通常會(huì)在樣本的輸?序列和輸出序列后面分別附上?個(gè)特殊符號(hào)“<eos>”表?序列的終?。我們?cè)诮酉聛淼挠懻撝幸矊⒀?上?節(jié)的全部數(shù)學(xué)符號(hào)。為了便于討論,假設(shè)解碼器的輸出是?段?本序列。設(shè)輸出?本詞典Y(包含特殊符號(hào)“<eos>”)的?小為|Y|,輸出序列的最??度為T′。所有可能的輸出序列?共有 O(|y|^{T^{′}}) 種。這些輸出序列中所有特殊符號(hào)“<eos>”后?的?序列將被舍棄。

5.1 貪婪搜索

貪婪搜索(greedy search)。對(duì)于輸出序列任?時(shí)間步t′,我們從|Y|個(gè)詞中搜索出條件概率最?的詞:

y_{t^{′}}=argmax_{y\in_{}Y}P(y|y_1,...,y_{t^{′}-1},c)

作為輸出。?旦搜索出“<eos>”符號(hào),或者輸出序列?度已經(jīng)達(dá)到了最??度T′,便完成輸出。我們?cè)诿枋鼋獯a器時(shí)提到,基于輸?序列?成輸出序列的條件概率是\prod_{t^{′}=1}^{T^{′}}P(y_{t^{′}}|y_1,...,y_{t^{′}-1},c)。我們將該條件概率最?的輸出序列稱為最優(yōu)輸出序列。而貪婪搜索的主要問題是不能保證得到最優(yōu)輸出序列。

下?來看?個(gè)例?。假設(shè)輸出詞典??有“A”“B”“C”和“<eos>”這4個(gè)詞。下圖中每個(gè)時(shí)間步
下的4個(gè)數(shù)字分別代表了該時(shí)間步?成“A”“B”“C”和“<eos>”這4個(gè)詞的條件概率。在每個(gè)時(shí)間步,貪婪搜索選取條件概率最?的詞。因此,圖10.9中將?成輸出序列“A”“B”“C”“<eos>”。該輸出序列的條件概率是0.5 × 0.4 × 0.4 × 0.6 = 0.048。

image

接下來,觀察下面演?的例?。與上圖中不同,在時(shí)間步2中選取了條件概率第??的詞“C”
。由于時(shí)間步3所基于的時(shí)間步1和2的輸出?序列由上圖中的“A”“B”變?yōu)榱讼聢D中的“A”“C”,下圖中時(shí)間步3?成各個(gè)詞的條件概率發(fā)?了變化。我們選取條件概率最?的詞“B”。此時(shí)時(shí)間步4所基于的前3個(gè)時(shí)間步的輸出?序列為“A”“C”“B”,與上圖中的“A”“B”“C”不同。因此,下圖中時(shí)間步4?成各個(gè)詞的條件概率也與上圖中的不同。我們發(fā)現(xiàn),此時(shí)的輸出序列“A”“C”“B”“<eos>”的條件概率是0.5 × 0.3 × 0.6 × 0.6 = 0.054,?于貪婪搜索得到的輸出序列的條件概率。因此,貪婪搜索得到的輸出序列“A”“B”“C”“<eos>”并?最優(yōu)輸出序列。

image

5.2 窮舉搜索

如果?標(biāo)是得到最優(yōu)輸出序列,我們可以考慮窮舉搜索(exhaustive search):窮舉所有可能的輸出序列,輸出條件概率最?的序列。

雖然窮舉搜索可以得到最優(yōu)輸出序列,但它的計(jì)算開銷 O(|y|^{T^{′}}) 很容易過?。例如,當(dāng)|Y| =
10000且T′ = 10時(shí),我們將評(píng)估 10000^{10}=10^{40} 個(gè)序列:這?乎不可能完成。而貪婪搜索的計(jì)
算開銷是 O(|y|^{T^{′}}),通常顯著小于窮舉搜索的計(jì)算開銷。例如,當(dāng)|Y| = 10000且T′ = 10時(shí),我
們只需評(píng)估 10000*10=10^5 個(gè)序列。

5.3 束搜索

束搜索(beam search)是對(duì)貪婪搜索的?個(gè)改進(jìn)算法。它有?個(gè)束寬(beam size)超參數(shù)。我們將它設(shè)為 k。在時(shí)間步 1 時(shí),選取當(dāng)前時(shí)間步條件概率最?的 k 個(gè)詞,分別組成 k 個(gè)候選輸出序列的?詞。在之后的每個(gè)時(shí)間步,基于上個(gè)時(shí)間步的 k 個(gè)候選輸出序列,從 k |Y| 個(gè)可能的輸出序列中選取條件概率最?的 k 個(gè),作為該時(shí)間步的候選輸出序列。最終,我們從各個(gè)時(shí)間步的候選輸出序列中篩選出包含特殊符號(hào)“<eos>”的序列,并將它們中所有特殊符號(hào)“<eos>”后?的?序列舍棄,得到最終候選輸出序列的集合。

image

束寬為2,輸出序列最??度為3。候選輸出序列有A、C、AB、CE、ABD和CED。我們將根據(jù)這6個(gè)序列得出最終候選輸出序列的集合。在最終候選輸出序列的集合中,我們?nèi)∫韵路謹(jǐn)?shù)最?的序列作為輸出序列:

\frac{1}{L^{\alpha}}logP(y_1,...,y_L)=\frac{1}{L^{\alpha}}\sum_{t^{′}=1}^{T^{′}}logP(y_{t^{′}}|y_1,...,y_{t^{′}-1},c)

其中 L 為最終候選序列?度,α ?般可選為0.75。分?上的 Lα 是為了懲罰較?序列在以上分?jǐn)?shù)中較多的對(duì)數(shù)相加項(xiàng)。分析可知,束搜索的計(jì)算開銷為 O(k|y|^{T^{′}})。這介于貪婪搜索和窮舉搜索的計(jì)算開銷之間。此外,貪婪搜索可看作是束寬為 1 的束搜索。束搜索通過靈活的束寬 k 來權(quán)衡計(jì)算開銷和搜索質(zhì)量。

6. Bleu得分

評(píng)價(jià)機(jī)器翻譯結(jié)果通常使?BLEU(Bilingual Evaluation Understudy)(雙語評(píng)估替補(bǔ))。對(duì)于模型預(yù)測(cè)序列中任意的?序列,BLEU考察這個(gè)?序列是否出現(xiàn)在標(biāo)簽序列中。

具體來說,設(shè)詞數(shù)為 n 的?序列的精度為 pn。它是預(yù)測(cè)序列與標(biāo)簽序列匹配詞數(shù)為 n 的?序列的數(shù)量與預(yù)測(cè)序列中詞數(shù)為 n 的?序列的數(shù)量之?。舉個(gè)例?,假設(shè)標(biāo)簽序列為A、B、C、D、E、F,預(yù)測(cè)序列為A、B、B、C、D,那么:

P1= \frac{預(yù)測(cè)序列中的 1 元詞組在標(biāo)簽序列是否存在的個(gè)數(shù)}{預(yù)測(cè)序列 1 元詞組的個(gè)數(shù)之和}

預(yù)測(cè)序列一元詞組:A/B/C/D,都在標(biāo)簽序列里存在,所以P1=4/5,以此類推,p2 = 3/4, p3 = 1/3, p4 = 0。設(shè) len_{label}和len_{pred} 分別為標(biāo)簽序列和預(yù)測(cè)序列的詞數(shù),那么,BLEU的定義為:

exp(min(0,1-\frac{len_{label}}{len_{pred}}))\prod_{n=1}^{k}p_n^{\frac{1}{2^n}}

其中 k 是我們希望匹配的?序列的最?詞數(shù)??梢钥吹疆?dāng)預(yù)測(cè)序列和標(biāo)簽序列完全?致時(shí),
BLEU為1。

因?yàn)槠ヅ漭^??序列?匹配較短?序列更難,BLEU對(duì)匹配較??序列的精度賦予了更?權(quán)重。例如,當(dāng) pn 固定在0.5時(shí),隨著n的增?,0.5^{\frac{1}{2}}\approx0.7,0.5^{\frac{1}{4}}\approx0.84,0.5^{\frac{1}{8}}\approx0.92,0.5^{\frac{1}{16}}\approx0.96。另外,模型預(yù)測(cè)較短序列往往會(huì)得到較?pn 值。因此,上式中連乘項(xiàng)前?的系數(shù)是為了懲罰較短的輸出而設(shè)的。舉個(gè)例?,當(dāng)k = 2時(shí),假設(shè)標(biāo)簽序列為A、B、C、D、E、F,而預(yù)測(cè)序列為A、 B。雖然p1 = p2 = 1,但懲罰系數(shù)exp(1-6/2) ≈ 0.14,因此BLEU也接近0.14。

7. 代碼實(shí)現(xiàn)

TensorFlow seq2seq的基本實(shí)現(xiàn)

機(jī)器學(xué)習(xí)通俗易懂系列文章

3.png

8. 參考文獻(xiàn)

動(dòng)手學(xué)深度學(xué)習(xí)


作者:@mantchs

GitHub:https://github.com/NLP-LOVE/ML-NLP

歡迎大家加入討論!共同完善此項(xiàng)目!群號(hào):【541954936】點(diǎn)擊加入

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容