【論文閱讀筆記】文本摘要任務(wù)中的copy機(jī)制(Summarization with Pointer-Generator Networks)

Pointer Network (Vinyals et al., 2015)

Pointer Network針對(duì)原seq2seq模型的輸出序列受限于固定大小的問題而提出,該框架期望decoder的輸出長度隨encoder模型的輸入長度變化而變化,本質(zhì)上是對(duì)基于attention機(jī)制的seq2seq模型的簡化,decoder的每一個(gè)時(shí)間步將輸出input sequence各token的概率分布,選擇概率最高的輸出,直至輸出<EOS>。

設(shè)輸入序列為\mathbf{X}=\{x_1,x_2,\dots,x_n\},輸出序列為\mathbf{Y}=\{y_1,y_2,\dots,y_{m(\mathbf{X})}\},此處的m(\mathbf{X})表示輸出序列的長度與輸入序列相關(guān)。將encoder部分的隱藏層狀態(tài)表示為(e_1,e_2,\dots,e_n),decoder部分的隱藏層狀態(tài)表示為(d_1,d_2,\dots,d_{m(\mathbf{X})})

Pointer Network在第i個(gè)位置的輸出P(y_i|y_1,\dots,y_{i-1},x_1,\dots,x_n)計(jì)算如下:
u_j^i=v^\top tanh(W_1e_j+W_2d_i) \qquad j \in (1, \dots , n) \\ P(y_i|y_1,\dots,y_{i-1},x_1,\dots,x_n)=softmax(u^i)
其中v、W_1、W_2均為模型需要學(xué)習(xí)的參數(shù),第一個(gè)式子則是attention機(jī)制中計(jì)算decoder第i個(gè)位置的隱藏狀態(tài)與encoder輸入序列各位置隱藏狀態(tài)的關(guān)聯(lián),對(duì)應(yīng)輸入序列中各token的分值,各分值經(jīng)過softmax歸一化操作得到的輸出視為輸入序列各token的概率分布,該步將選擇概率最大的token作為輸出。當(dāng)然此時(shí)的輸入序列與原seq2seq模型的不同在于需額外添加一個(gè)<EOS>的token。

Get To The Point: Summarization with Pointer-Generator Networks (See et al., 2017)

Pointer-Generator Networks可以視為一個(gè)基于attention機(jī)制的seq2seq模型和pointer network的混合體,既能從給定詞匯表中生成新token,又能從原輸入序列中拷貝舊token,其框架如下圖所示。

Pointer-Generator Networks

圖中Source Text中各token w_i經(jīng)過一個(gè)單層雙向LSTM將依次得到Encoder Hidden States序列,各隱藏層狀態(tài)表示為h_i。在每一個(gè)時(shí)間步t,decoder將根據(jù)上一個(gè)預(yù)測得到的單詞的embedding經(jīng)一個(gè)單層雙向LSTM得到Decoder Hidden State s_t,此時(shí)計(jì)算基于s_t的各h_i的Attention Distribution a^t計(jì)算如下:
e^t_i = v^\top tanh(W_h h_i + W_s s_t + b_{attn}) \\ a^t = softmax(e^t)
其中,vW_h、W_sb_{attn}均為模型要學(xué)習(xí)的參數(shù)。接下來Attention Distribution將被用于生成當(dāng)前時(shí)間步的上下文向量h_t^\star,繼而同Decoder Hidden State s_t拼接起來經(jīng)由兩個(gè)線性層產(chǎn)生基于輸出序列詞典的Vocabulary Distribution P_{vocab}
h_t^\star = \sum_i a_i^t h_i \\ P_{vocab} = softmax( V ^{'} ( V [s_t , h_t^\star ] + b ) + b^{'})
其中,VV^{'}、b、b^{'}均為模型需要學(xué)習(xí)的參數(shù)。

上述過程為傳統(tǒng)基于attention機(jī)制的seq2seq模型的計(jì)算過程。接下來為了在輸出中可以拷貝輸入序列中的token,將根據(jù)上下文向量h_t^\star、Decoder Hidden State s_t和Decoder input x_t計(jì)算生成概率p_{gen}:
p_{gen} = \sigma (w_{h^\star}^\top h_t^\star + w_s^\top s_t + w_x^\top x_t + b_{ptr})
其中w_{h^\star}^\top、w_s^\top、w_x^\topb_{ptr}均為模型要學(xué)習(xí)的參數(shù)。p_{gen}的作用在于調(diào)節(jié)生成的單詞是來自于根據(jù)P_{vocab}在輸出序列的詞典中的采樣還是來自于根據(jù)a^t在輸入序列的token中的采樣,最終的token分布表示如下:
P(w)=p_{gen} P_{vocab}(w)+\left(1-p_{gen}\right) \sum_{i: w_{i}=w} a_{i}^{t}
其中i: w_{i}=w表示輸入序列中的token w,這里會(huì)將在輸入序列中可能出現(xiàn)多次的w的注意力分布相加。當(dāng)w未在輸出序列的詞典中出現(xiàn)時(shí),P_{vocab}(w)=0;類似地,當(dāng)w未出現(xiàn)在輸入序列中時(shí),\sum_{i: w_{i}=w} a_{i}^{t} = 0.

總結(jié)

Pointer-Generator Networks在基于attention機(jī)制的seq2seq模型中融合copy機(jī)制,并應(yīng)用于文本摘要任務(wù)中,實(shí)則是基于上下文向量、decoder input以及decoder hidden state計(jì)算一個(gè)生成單詞的概率p,對(duì)應(yīng)拷貝單詞的概率則為1-p,根據(jù)概率綜合encoder的注意力分布和decoder的output分布得到一個(gè)綜合的基于input token和output vocabulary的token分布。此外本文關(guān)注的是多語句的摘要生成,因此額外考慮了生成摘要時(shí)的重復(fù)問題,在計(jì)算attention得分時(shí),除了考慮decoder hidden state和encoder hidden state外,還額外加入之前生成token的attention分布總和一項(xiàng),并在最終loss的計(jì)算上額外添加了一個(gè)名為 coverage loss的懲罰項(xiàng)(該部分上文尚未細(xì)述),以避免摘要生成時(shí)的重復(fù)問題。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容