Mask-Predict: Parallel Decoding of Conditional Masked Language Models
<center> 來源:EMNLP2019 </center>
<center> 鏈接:https://arxiv.org/pdf/1904.09324.pdf </center>
<center> 代碼:https://github.com/facebookresearch/Mask-Predict </center>
動機
大部分的神經(jīng)機器翻譯模型(seq2seq)都是以自回歸的方式(autoregressive)進行的,decode階段從左到右依次生成token,因此,decoder階段的時間會隨著輸入句子的長度增加而增加。雖然有一些工作探索了非自回歸的方式的seq2seq模型,decoder階段同時生成target端的所有token,但是與自回歸的seq2seq模型相比,效果相差較大。
解決方法
論文提出了一種新的decoder方法:mask-predict
在decoder過程中,首先用非自回歸的方式生成target端全部token(遮住全部token),然后再遮住模型不確定的一些token,依據(jù)target端其他未遮蓋token文以及源文本重新進行預測,迭代進行。
模型及訓練
Conditional Masked Language Models (CMLM)
給定源文和部分的目標文本
,來預測剩下的目標文本
即是目標端文本的長度
論文采用標準的transformer的模型,不同之處在于decode端的self-attention不再使用attention mask防止看到預測單詞之后的詞。
訓練的時候,首先確定需要遮蓋的token個數(shù)(均勻分布在0到target端文本長度),然后隨機選取target端的token進行遮蓋,并進行預測,訓練目標是交叉熵損失函數(shù),只計算遮蓋位置的loss。
預測target端文本的長度
自回歸的方法中,decode從左到右依次生成token,直到預測出EOS結束符,或者超過的給定的最大長度。論文提出的方法并行的生成生成target端全部token,因此必須事先知道target端文本的長度。
仿照Bert原論文,添加一個 [CLS] token來處理分類任務,論文在encoder端添加了一個 [LENGTH] token來預測target端文本的長度,計算的loss和CMLM的loss相加。
inference
使用Mask-Predict解碼
主要可以分為兩步:
Mask
`
第t次迭代,遮住的單詞數(shù)n,T為預先設定的迭代次數(shù)。第0次迭代,遮住target端所有token;隨著迭代次數(shù)的增加,逐步減少遮蓋的單詞數(shù)。
是當前迭代次數(shù)的預測概率
Predict
給定源文本和部分的目標文本
,來預測mask的目標文本
,并更新概率
沒有被mask的token ,概率不變
兩個步驟迭代進行,下面是一個例子
第0步,遮蓋整個target端所有tokens,獲得target端全部輸出,可以看到有很多重復的單詞,整個句子語法也不合理。
第1步,根據(jù)第0步得到的概率,從12個單詞中選出8個低概率的單詞(標黃的部分),使用CMLM重新預測。此時的結果比第一步更加準確合理。
第2步,根據(jù)第1步得到的概率(被重新預測的8個單詞的概率需要更新),從12個單詞中選出4個低概率的單詞,再使用CMLM重新預測。值得注意的是第1步?jīng)]有被重新預測的詞也可能被選中。
預測target端文本的長度
獲取encoder部分 [LENGTH] token處預測到的長度分布之后,選在top l個長度作為候選,并行decode,然后選擇平均概率 最好的結果
實驗結果
在機器翻譯上面的結果:
- 論文提出的方法(CMLM with Mask-Predict) 超過了其他的非自回歸的方法。
- 相比自回歸的方法,結果差了 0.5-1.2 BLEU, 但是速度加快了。
- 在中英(兩種語言相差比較大)翻譯的數(shù)據(jù)集上結果也比較好
速度與效果的trade-off
- 犧牲2個BLEU的情況下,可以提升3倍多的速度(T=4, l= 2)
- 達到相當?shù)男Ч?7.03 vs 27.74),速度提升30%
模型簡化測試
分析decoder是否必須迭代多次
decoder迭代多次主要是為了改善,非自回歸方法共有的 token repetitions 問題。 隨著token重復率的減少,bleu也有所上升
分析長句是否需要迭代更多的次數(shù)
長句迭代多次有更好的效果。
分析使用更多的長度候選是否能改善效果
不能
分析知識蒸餾是否必要
這里知識蒸餾的意思是,訓練非自回歸模型時,訓練集不使用原始訓練集中的target端數(shù)據(jù),而使用teacher模型預測的target結果。
必要
總結
- 在機器翻譯任務上,與其他非自回歸seq2seq模型相比,效果達到最好;
- 在WMT en-de任務上,比baseline自回歸模型低0.7 BLEU,速度提升了30%; bleu損失兩個點,inference的速度可以達到原來的三倍多。
- target端文本長度的預測也比較重要,如果完全預測對,bleu還有提升空間
- 知識蒸餾的teacher網(wǎng)絡使用了強模型,這也會使得論文的結果強于一般模型。同時,論文也和先前其他論文一樣沒有解釋清楚為什么需要知識蒸餾