| soft-attention | hard-attention | |
|---|---|---|
| 優(yōu)點 | 1、參數(shù)少、訓(xùn)練快 2、可微分 |
能處理較長的輸入序列 |
| 缺點 | softmax函數(shù)將較小但非零的概率分配給瑣碎的元素,這降低了少數(shù)真正重要元素的注意力,對于較長的輸入序列效果不好 | 1、序列采樣耗時較大 2、不可微分 |
文章的motivation是將soft attention和hard attention結(jié)合起來,使其保留二者的優(yōu)點,同時丟棄二者的缺點。具體地說,hard attention用于編碼關(guān)于上下文依賴的豐富的結(jié)構(gòu)信息,并將長序列修剪成短得多的序列,以便soft attention處理。相反,soft attention被用來提供一個穩(wěn)定的環(huán)境和強(qiáng)烈的award來幫助訓(xùn)練hard attention處理之后的序列。該方法既能提高soft attention的預(yù)測質(zhì)量,又能提高h(yuǎn)ard attention的可訓(xùn)練性,同時提高了對上下文依賴關(guān)系建模的能力。
背景知識
模型
Reinforced Sequence Sampling (RSS)
hard attention的目標(biāo)是從輸入序列中選擇關(guān)鍵的words,這些關(guān)鍵的words能夠提供足夠的信息來完成下游任務(wù),這樣就可以排除掉許多boring words,從而減少模型的訓(xùn)練時間。
給定一個輸入序列,RSS生成一個等長的向量
,其中
意味著
會被選擇,而
則意味著
會被忽略掉。在RSS中,
是通過attention機(jī)制計算的結(jié)果作為其采樣的概率。RSS的目標(biāo)是學(xué)習(xí)到以下的分布:
其中,表示一個上下文融合層(context fusion layer),如Bi-LSTM,Bi-GRU等,為
生成一個上下文敏感的representation。
將
映射到
被選中的概率。注意到
的計算方式不依賴于
,因此這個步驟可以并行完成。為了進(jìn)一步提高了效率。文章通過下面這個式子來計算
:
而的計算方式則與source2token self-attention相似,如下:
Reinforced Self-Attention (ReSA)

在ReSA中,兩個參數(shù)獨立的RSS分別對輸入序列的進(jìn)行采樣,采樣結(jié)果分別稱為head tokens和dependent tokens。
然后使用、
生成一個mask
:
把放到Masked Self-Attention中:
即score function,然后使用softmax函數(shù)計算概率:
的上下文注意力特性通過以下方式計算:
最后,使用與DiSAN相同的融合層給出最終的輸出: