Mask-Predict: Parallel Decoding of Conditional Masked Language Models

Mask-Predict: Parallel Decoding of Conditional Masked Language Models

<center> 來源:EMNLP2019 </center>
<center> 鏈接:https://arxiv.org/pdf/1904.09324.pdf </center>
<center> 代碼:https://github.com/facebookresearch/Mask-Predict </center>

動機

大部分的神經(jīng)機器翻譯模型(seq2seq)都是以自回歸的方式(autoregressive)進行的,decode階段從左到右依次生成token,因此,decoder階段的時間會隨著輸入句子的長度增加而增加。雖然有一些工作探索了非自回歸的方式的seq2seq模型,decoder階段同時生成target端的所有token,但是與自回歸的seq2seq模型相比,效果相差較大。

解決方法

論文提出了一種新的decoder方法:mask-predict

在decoder過程中,首先用非自回歸的方式生成target端全部token(遮住全部token),然后再遮住模型不確定的一些token,依據(jù)target端其他未遮蓋token文以及源文本重新進行預測,迭代進行。

模型及訓練

Conditional Masked Language Models (CMLM)

P\left(y | X, Y_{o b s}\right)

給定源文X和部分的目標文本Y_{o b s},來預測剩下的目標文本Y_{\text {mask}}

N即是目標端文本的長度N=\left|Y_{m a s k}\right|+\left|Y_{o b s}\right|

論文采用標準的transformer的模型,不同之處在于decode端的self-attention不再使用attention mask防止看到預測單詞之后的詞。

訓練的時候,首先確定需要遮蓋的token個數(shù)(均勻分布在0到target端文本長度),然后隨機選取target端的token進行遮蓋,并進行預測,訓練目標是交叉熵損失函數(shù),只計算遮蓋位置的loss。

預測target端文本的長度

自回歸的方法中,decode從左到右依次生成token,直到預測出EOS結束符,或者超過的給定的最大長度。論文提出的方法并行的生成生成target端全部token,因此必須事先知道target端文本的長度。

仿照Bert原論文,添加一個 [CLS] token來處理分類任務,論文在encoder端添加了一個 [LENGTH] token來預測target端文本的長度,計算的loss和CMLM的loss相加。

inference

使用Mask-Predict解碼

主要可以分為兩步:

Mask

\begin{aligned} Y_{\text {mask}}^{(t)} &=\arg \min _{i}\left(p_{i}, n\right) \\ Y_{o b s}^{(t)} &=Y \backslash Y_{\text {mask}}^{(t)} \end{aligned}

n=N \cdot \frac{T-t}{T}`

第t次迭代,遮住的單詞數(shù)n,T為預先設定的迭代次數(shù)。第0次迭代,遮住target端所有token;隨著迭代次數(shù)的增加,逐步減少遮蓋的單詞數(shù)。
p_{i}是當前迭代次數(shù)的預測概率

Predict

\begin{aligned} y_{i}^{(t)} &=\arg \max _{w} P\left(y_{i}=w | X, Y_{o b s}^{(t)}\right) \\ p_{i}^{(t)} &=\max _{w} P\left(y_{i}=w | X, Y_{o b s}^{(t)}\right) \end{aligned}

給定源文本X和部分的目標文本Y_{o b s},來預測mask的目標文本Y_{\text {mask}},并更新概率

沒有被mask的token Y_{o b s},概率不變

\begin{aligned} y_{i}^{(t)} &=y_{i}^{(t-1)} \\ p_{i}^{(t)} &=p_{i}^{(t-1)} \end{aligned}

兩個步驟迭代進行,下面是一個例子


在這里插入圖片描述

第0步,遮蓋整個target端所有tokens,獲得target端全部輸出,可以看到有很多重復的單詞,整個句子語法也不合理。

第1步,根據(jù)第0步得到的概率,從12個單詞中選出8個低概率的單詞(標黃的部分),使用CMLM重新預測。此時的結果比第一步更加準確合理。

第2步,根據(jù)第1步得到的概率(被重新預測的8個單詞的概率需要更新),從12個單詞中選出4個低概率的單詞,再使用CMLM重新預測。值得注意的是第1步?jīng)]有被重新預測的詞也可能被選中。

預測target端文本的長度

獲取encoder部分 [LENGTH] token處預測到的長度分布之后,選在top l個長度作為候選,并行decode,然后選擇平均概率 \frac{1}{N} \sum\log p_{i}^{(T)} 最好的結果

實驗結果

在機器翻譯上面的結果:


在這里插入圖片描述

在這里插入圖片描述
  1. 論文提出的方法(CMLM with Mask-Predict) 超過了其他的非自回歸的方法。
  2. 相比自回歸的方法,結果差了 0.5-1.2 BLEU, 但是速度加快了。
  3. 在中英(兩種語言相差比較大)翻譯的數(shù)據(jù)集上結果也比較好

速度與效果的trade-off

在這里插入圖片描述
  1. 犧牲2個BLEU的情況下,可以提升3倍多的速度(T=4, l= 2)
  2. 達到相當?shù)男Ч?7.03 vs 27.74),速度提升30%

模型簡化測試

分析decoder是否必須迭代多次

decoder迭代多次主要是為了改善,非自回歸方法共有的 token repetitions 問題。 隨著token重復率的減少,bleu也有所上升

在這里插入圖片描述

分析長句是否需要迭代更多的次數(shù)

在這里插入圖片描述

長句迭代多次有更好的效果。

分析使用更多的長度候選是否能改善效果

在這里插入圖片描述

不能

分析知識蒸餾是否必要

這里知識蒸餾的意思是,訓練非自回歸模型時,訓練集不使用原始訓練集中的target端數(shù)據(jù),而使用teacher模型預測的target結果。

在這里插入圖片描述

必要

總結

  1. 在機器翻譯任務上,與其他非自回歸seq2seq模型相比,效果達到最好;
  2. 在WMT en-de任務上,比baseline自回歸模型低0.7 BLEU,速度提升了30%; bleu損失兩個點,inference的速度可以達到原來的三倍多。
  3. target端文本長度的預測也比較重要,如果完全預測對,bleu還有提升空間
  4. 知識蒸餾的teacher網(wǎng)絡使用了強模型,這也會使得論文的結果強于一般模型。同時,論文也和先前其他論文一樣沒有解釋清楚為什么需要知識蒸餾
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容