Seq2Seq 模型詳解

Seq2Seq 是一種循環(huán)神經(jīng)網(wǎng)絡(luò)的變種,包括編碼器 (Encoder) 和解碼器 (Decoder) 兩部分。Seq2Seq 是自然語(yǔ)言處理中的一種重要模型,可以用于機(jī)器翻譯、對(duì)話系統(tǒng)、自動(dòng)文摘。

1. RNN 結(jié)構(gòu)及使用

RNN 結(jié)構(gòu)

在之前的文章《循環(huán)神經(jīng)網(wǎng)絡(luò) RNN、LSTM、GRU》中介紹了 RNN 模型,RNN 基本的模型如上圖所示,每個(gè)神經(jīng)元接受的輸入包括:前一個(gè)神經(jīng)元的隱藏層狀態(tài) h (用于記憶) 和當(dāng)前的輸入 x (當(dāng)前信息)。神經(jīng)元得到輸入之后,會(huì)計(jì)算出新的隱藏狀態(tài) h 和輸出 y,然后再傳遞到下一個(gè)神經(jīng)元。因?yàn)殡[藏狀態(tài) h 的存在,使得 RNN 具有一定的記憶功能。

針對(duì)不同任務(wù),通常要對(duì) RNN 模型結(jié)構(gòu)進(jìn)行少量的調(diào)整,根據(jù)輸入和輸出的數(shù)量,分為三種比較常見(jiàn)的結(jié)構(gòu):N vs N、1 vs N、N vs 1。

1.1 N vs N

N vs N 結(jié)構(gòu)

上圖是RNN 模型的一種 N vs N 結(jié)構(gòu),包含 N 個(gè)輸入 x1, x2, ..., xN,和 N 個(gè)輸出 y1, y2, ..., yN。N vs N 的結(jié)構(gòu)中,輸入和輸出序列的長(zhǎng)度是相等的,通常適合用于以下任務(wù):

  • 詞性標(biāo)注
  • 訓(xùn)練語(yǔ)言模型,使用之前的詞預(yù)測(cè)下一個(gè)詞等

1.2 1 vs N

1 vs N 結(jié)構(gòu)(1)

1 vs N 結(jié)構(gòu)(2)

在 1 vs N 結(jié)構(gòu)中,我們只有一個(gè)輸入 x,和 N 個(gè)輸出 y1, y2, ..., yN??梢杂袃煞N方式使用 1 vs N,第一種只將輸入 x 傳入第一個(gè) RNN 神經(jīng)元,第二種是將輸入 x 傳入所有的 RNN 神經(jīng)元。1 vs N 結(jié)構(gòu)適合用于以下任務(wù):

  • 圖像生成文字,輸入 x 就是一張圖片,輸出就是一段圖片的描述文字。
  • 根據(jù)音樂(lè)類別,生成對(duì)應(yīng)的音樂(lè)。
  • 根據(jù)小說(shuō)類別,生成相應(yīng)的小說(shuō)。

1.3 N vs 1

N vs 1 結(jié)構(gòu)

在 N vs 1 結(jié)構(gòu)中,我們有 N 個(gè)輸入 x1, x2, ..., xN,和一個(gè)輸出 y。N vs 1 結(jié)構(gòu)適合用于以下任務(wù):

  • 序列分類任務(wù),一段語(yǔ)音、一段文字的類別,句子的情感分析。

2. Seq2Seq 模型

2.1 Seq2Seq 結(jié)構(gòu)

上面的三種結(jié)構(gòu)對(duì)于 RNN 的輸入和輸出個(gè)數(shù)都有一定的限制,但實(shí)際中很多任務(wù)的序列的長(zhǎng)度是不固定的,例如機(jī)器翻譯中,源語(yǔ)言、目標(biāo)語(yǔ)言的句子長(zhǎng)度不一樣;對(duì)話系統(tǒng)中,問(wèn)句和答案的句子長(zhǎng)度不一樣。

Seq2Seq 是一種重要的 RNN 模型,也稱為 Encoder-Decoder 模型,可以理解為一種 N×M 的模型。模型包含兩個(gè)部分:Encoder 用于編碼序列的信息,將任意長(zhǎng)度的序列信息編碼到一個(gè)向量 c 里。而 Decoder 是解碼器,解碼器得到上下文信息向量 c 之后可以將信息解碼,并輸出為序列。Seq2Seq 模型結(jié)構(gòu)有很多種,下面是幾種比較常見(jiàn)的:

第一種

第一種 Seq2Seq 結(jié)構(gòu)

第二種

第二種 Seq2Seq 結(jié)構(gòu)

第三種

第三種 Seq2Seq 結(jié)構(gòu)

2.2 編碼器 Encoder

這三種 Seq2Seq 模型的主要區(qū)別在于 Decoder,他們的 Encoder 都是一樣的。下圖是 Encoder 部分,Encoder 的 RNN 接受輸入 x,最終輸出一個(gè)編碼所有信息的上下文向量 c,中間的神經(jīng)元沒(méi)有輸出。Decoder 主要傳入的是上下文向量 c,然后解碼出需要的信息。

Seq2Seq Encoder

從上圖可以看到,Encoder 與一般的 RNN 區(qū)別不大,只是中間神經(jīng)元沒(méi)有輸出。其中的上下文向量 c 可以采用多種方式進(jìn)行計(jì)算。

Encoder 上下文向量 c

從公式可以看到,c 可以直接使用最后一個(gè)神經(jīng)元的隱藏狀態(tài) hN 表示;也可以在最后一個(gè)神經(jīng)元的隱藏狀態(tài)上進(jìn)行某種變換 hN 而得到,q 函數(shù)表示某種變換;也可以使用所有神經(jīng)元的隱藏狀態(tài) h1, h2, ..., hN 計(jì)算得到。得到上下文向量 c 之后,需要傳遞到 Decoder。

2.3 解碼器 Decoder

Decoder 有多種不同的結(jié)構(gòu),這里主要介紹三種。

第一種

第一種 Decoder 結(jié)構(gòu)

第一種 Decoder 結(jié)構(gòu)比較簡(jiǎn)單,將上下文向量 c 當(dāng)成是 RNN 的初始隱藏狀態(tài),輸入到 RNN 中,后續(xù)只接受上一個(gè)神經(jīng)元的隱藏層狀態(tài) h' 而不接收其他的輸入 x。第一種 Decoder 結(jié)構(gòu)的隱藏層及輸出的計(jì)算公式:

第一種 Decoder 結(jié)構(gòu)隱藏層及輸出層

第二種

第二種 Decoder 結(jié)構(gòu)

第二種 Decoder 結(jié)構(gòu)有了自己的初始隱藏層狀態(tài) h'0,不再把上下文向量 c 當(dāng)成是 RNN 的初始隱藏狀態(tài),而是當(dāng)成 RNN 每一個(gè)神經(jīng)元的輸入??梢钥吹皆?Decoder 的每一個(gè)神經(jīng)元都擁有相同的輸入 c,這種 Decoder 的隱藏層及輸出計(jì)算公式:

第二種 Decoder 結(jié)構(gòu)隱藏層及輸出層

第三種

第三種 Decoder 結(jié)構(gòu)

第三種 Decoder 結(jié)構(gòu)和第二種類似,但是在輸入的部分多了上一個(gè)神經(jīng)元的輸出 y'。即每一個(gè)神經(jīng)元的輸入包括:上一個(gè)神經(jīng)元的隱藏層向量 h',上一個(gè)神經(jīng)元的輸出 y',當(dāng)前的輸入 c (Encoder 編碼的上下文向量)。對(duì)于第一個(gè)神經(jīng)元的輸入 y'0,通常是句子其實(shí)標(biāo)志位的 embedding 向量。第三種 Decoder 的隱藏層及輸出計(jì)算公式:

第三種 Decoder 結(jié)構(gòu)隱藏層及輸出層

3. Seq2Seq模型使用技巧

3.1 Teacher Forcing

Teacher Forcing 用于訓(xùn)練階段,主要針對(duì)上面第三種 Decoder 模型來(lái)說(shuō)的,第三種 Decoder 模型神經(jīng)元的輸入包括了上一個(gè)神經(jīng)元的輸出 y'。如果上一個(gè)神經(jīng)元的輸出是錯(cuò)誤的,則下一個(gè)神經(jīng)元的輸出也很容易錯(cuò)誤,導(dǎo)致錯(cuò)誤會(huì)一直傳遞下去。

而 Teacher Forcing 可以在一定程度上緩解上面的問(wèn)題,在訓(xùn)練 Seq2Seq 模型時(shí),Decoder 的每一個(gè)神經(jīng)元并非一定使用上一個(gè)神經(jīng)元的輸出,而是有一定的比例采用正確的序列作為輸入。

舉例說(shuō)明,在翻譯任務(wù)中,給定英文句子翻譯為中文。"I have a cat" 翻譯成 "我有一只貓",下圖是不使用 Teacher Forcing 的 Seq2Seq

不使用 Teacher Forcing

如果使用 Teacher Forcing,則神經(jīng)元直接使用正確的輸出作為當(dāng)前神經(jīng)元的輸入。

使用 Teacher Forcing

3.2 Attention

在 Seq2Seq 模型,Encoder 總是將源句子的所有信息編碼到一個(gè)固定長(zhǎng)度的上下文向量 c 中,然后在 Decoder 解碼的過(guò)程中向量 c 都是不變的。這存在著不少缺陷:

  • 對(duì)于比較長(zhǎng)的句子,很難用一個(gè)定長(zhǎng)的向量 c 完全表示其意義。
  • RNN 存在長(zhǎng)序列梯度消失的問(wèn)題,只使用最后一個(gè)神經(jīng)元得到的向量 c 效果不理想。
  • 與人類的注意力方式不同,即人類在閱讀文章的時(shí)候,會(huì)把注意力放在當(dāng)前的句子上。

Attention 即注意力機(jī)制,是一種將模型的注意力放在當(dāng)前翻譯單詞上的一種機(jī)制。例如翻譯 "I have a cat",翻譯到 "我" 時(shí),要將注意力放在源句子的 "I" 上,翻譯到 "貓" 時(shí)要將注意力放在源句子的 "cat" 上。

使用了 Attention 后,Decoder 的輸入就不是固定的上下文向量 c 了,而是會(huì)根據(jù)當(dāng)前翻譯的信息,計(jì)算當(dāng)前的 c。

Attention

Attention 需要保留 Encoder 每一個(gè)神經(jīng)元的隱藏層向量 h,然后 Decoder 的第 t 個(gè)神經(jīng)元要根據(jù)上一個(gè)神經(jīng)元的隱藏層向量 h't-1 計(jì)算出當(dāng)前狀態(tài)與 Encoder 每一個(gè)神經(jīng)元的相關(guān)性 et。et 是一個(gè) N 維的向量 (Encoder 神經(jīng)元個(gè)數(shù)為 N),若 et 的第 i 維越大,則說(shuō)明當(dāng)前節(jié)點(diǎn)與 Encoder 第 i 個(gè)神經(jīng)元的相關(guān)性越大。et 的計(jì)算方法有很多種,即相關(guān)性系數(shù)的計(jì)算函數(shù) a 有很多種:

Attention 相關(guān)性

上面得到相關(guān)性向量 et 后,需要進(jìn)行歸一化,使用 softmax 歸一化。然后用歸一化后的系數(shù)融合 Encoder 的多個(gè)隱藏層向量得到 Decoder 當(dāng)前神經(jīng)元的上下文向量 ct:

使用 Attention 計(jì)算上下文向量 c

3.3 beam search

beam search 方法不用于訓(xùn)練的過(guò)程,而是用在測(cè)試的。在每一個(gè)神經(jīng)元中,我們都選取當(dāng)前輸出概率值最大的 top k 個(gè)輸出傳遞到下一個(gè)神經(jīng)元。下一個(gè)神經(jīng)元分別用這 k 個(gè)輸出,計(jì)算出 L 個(gè)單詞的概率 (L 為詞匯表大小),然后在 kL 個(gè)結(jié)果中得到 top k 個(gè)最大的輸出,重復(fù)這一步驟。

4. Seq2Seq 總結(jié)

Seq2Seq 模型允許我們使用長(zhǎng)度不同的輸入和輸出序列,適用范圍相當(dāng)廣,可用于機(jī)器翻譯,對(duì)話系統(tǒng),閱讀理解等場(chǎng)景。

Seq2Seq 模型使用時(shí)可以利用 Teacher Forceing,Attention,beam search 等方法優(yōu)化。

參考文獻(xiàn)

RNN神經(jīng)網(wǎng)絡(luò)模型的不同結(jié)構(gòu)
Tensorflow中的Seq2Seq全家桶
Attention機(jī)制詳解(一)——Seq2Seq中的Attention

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容