BART 詳解

本文為轉(zhuǎn)載,原文鏈接 https://wmathor.com/index.php/archives/1505/

一切都得從 Transformer 說(shuō)起。Transformer 左半邊為 Encoder,右半邊為 Decoder。我們將 Encoder 輸入的句子稱為 source,Decoder 輸入的句子稱為 target

Encoder 負(fù)責(zé)將 source 進(jìn)行 self-attention 并獲得句子中每個(gè)詞的 representation,最經(jīng)典的 Encoder 架構(gòu)就是 BERT,通過(guò) Masked Language Model 來(lái)學(xué)習(xí)詞之間的關(guān)系,另外還有 XLNet, RoBERTa, ALBERT, DistilBERT 等等。但是單獨(dú) Encoder 結(jié)構(gòu)不適用于生成任務(wù)

Decoder 如下圖所示,輸入與輸出之間差一個(gè)位置,主要是模擬在 Inference 時(shí),不能讓模型看到未來(lái)的詞,這種方式稱為 AutoRegressive,常見(jiàn)的基于 Decoder 的模型通常是用來(lái)做序列生成的,例如 GPT, CTRL 等等。但是單獨(dú) Decoder 結(jié)構(gòu)僅基于左側(cè)上下文預(yù)測(cè)單詞,無(wú)法學(xué)習(xí)雙向交互


而兩者合在一起后,就能當(dāng)成一種 Seq2Seq 模型,進(jìn)行翻譯任務(wù)。下圖是 BART 的主要結(jié)構(gòu),看上去似乎和 Transformer 沒(méi)什么不同,主要區(qū)別在于 source 和 target

訓(xùn)練階段,Encoder 端使用雙向模型編碼被破壞的文本,然后 Decoder 采用自回歸的方式計(jì)算出原始輸入;測(cè)試階段或者是微調(diào)階段,Encoder 和 Decoder 的輸入都是未被破壞的文本

BART vs Transformer

BART 使用標(biāo)準(zhǔn)的 Transformer 模型,不過(guò)做了一些改變:

  1. 同 GPT 一樣,將 ReLU 激活函數(shù)改為 GeLU,并且參數(shù)初始化服從正態(tài)分布 N(0,0.02)
  2. BART base 模型的 Encoder 和 Decoder 各有 6 層,large 模型增加到了 12 層
  3. BART 解碼器的各層對(duì)編碼器最終隱藏層額外執(zhí)行 cross-attention
  4. BERT 在詞預(yù)測(cè)之前使用了額外的 Feed Forward Layer,而 BART 沒(méi)有

Pre-training BART

BART 作者嘗試了不同的方式來(lái)破壞輸入:

  • Token Masking:Following BERT (Devlin et al., 2019), random tokens are sampled and replaced with [MASK] elements.
  • Sentence Permutation:A document is divided into sentences based on full stops, and these sentences are shuffled in a random order.
  • Document Rotation:A token is chosen uniformly at random, and the document is rotated so that it begins with that token. This task trains the model to identify the start of the document.
  • Token Deletion:Random tokens are deleted from the input. In contrast to token masking, the model must decide which positions are missing inputs.
  • Text Infilling:A number of text spans are sampled, with span lengths drawn from a Poisson distribution (\lambda=3). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens. Text infilling teaches the model to predict how many tokens are missing from a span.

Fine-tuning BART

Sequence Classification Tasks

序列分類(lèi)任務(wù)中,編碼器和解碼器的輸入相同,解碼器 token 的最終隱藏狀態(tài)被輸入到多類(lèi)別線性分類(lèi)器中。BART 在解碼器最后額外添加了一個(gè) token,如下圖所示,該 token 位置的輸出可以被認(rèn)為是該句子的 representation


Sequence Generation Tasks

由于 BART 具備自回歸解碼器,因此它可以針對(duì)序列生成任務(wù)進(jìn)行直接微調(diào),如問(wèn)答或者文本摘要

Machine Translation

作者采用新的隨機(jī)初始化 Encoder 替換 BART 編碼器的 Embedding 層。該模型以端到端的方式進(jìn)行訓(xùn)練,即訓(xùn)練一個(gè)新的編碼器將外來(lái)詞映射到輸入。新的編碼器可以使用不同于原始 BART 模型的詞匯。其中隨機(jī)初始化 Encoder 的訓(xùn)練分兩步,均需要將來(lái)自 BART 模型輸出的交叉熵?fù)p失進(jìn)行反向傳播。第一步,作者凍結(jié) BART 的大部分參數(shù),僅更新隨機(jī)初始化的 Encoder、BART 位置嵌入和 BART 編碼器第一層的自注意力輸入投影矩陣。第二步,作者將所有模型參數(shù)進(jìn)行少量迭代訓(xùn)練


Results

從上表可以看出,貌似帶上 Document Rotation 或 Sentence Shuffling 效果都不是太好,可以這么理解,假如模型在訓(xùn)練的時(shí)候看到的句子順序都是亂的,它可能就認(rèn)為這個(gè)世界的句子順序都是亂的,當(dāng)你做測(cè)試的時(shí)候,輸入的句子是正序的,可能模型就不知所措了。實(shí)際上 Text Infilling 可以看作是 Token Masking+Token Deletion,所以 Text Infilling 效果這么好也可以理解


Reference

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容