Reinforced Self-Attention Network: a Hybrid of Hard and Soft Attention for Sequence Modeling

soft-attention hard-attention
優(yōu)點 1、參數(shù)少、訓(xùn)練快
2、可微分
能處理較長的輸入序列
缺點 softmax函數(shù)將較小但非零的概率分配給瑣碎的元素,這降低了少數(shù)真正重要元素的注意力,對于較長的輸入序列效果不好 1、序列采樣耗時較大
2、不可微分

文章的motivation是將soft attention和hard attention結(jié)合起來,使其保留二者的優(yōu)點,同時丟棄二者的缺點。具體地說,hard attention用于編碼關(guān)于上下文依賴的豐富的結(jié)構(gòu)信息,并將長序列修剪成短得多的序列,以便soft attention處理。相反,soft attention被用來提供一個穩(wěn)定的環(huán)境和強(qiáng)烈的award來幫助訓(xùn)練hard attention處理之后的序列。該方法既能提高soft attention的預(yù)測質(zhì)量,又能提高h(yuǎn)ard attention的可訓(xùn)練性,同時提高了對上下文依賴關(guān)系建模的能力。

背景知識

模型

Reinforced Sequence Sampling (RSS)

hard attention的目標(biāo)是從輸入序列中選擇關(guān)鍵的words,這些關(guān)鍵的words能夠提供足夠的信息來完成下游任務(wù),這樣就可以排除掉許多boring words,從而減少模型的訓(xùn)練時間。
給定一個輸入序列\boldsymbol{x} = \left[ x _ { 1 } , \dots , x _ { n } \right],RSS生成一個等長的向量\boldsymbol{z} = \left[ z _ { 1 } , \dots , z _ { n } \right],其中z_i=1意味著x_i會被選擇,而z_i=0則意味著x_i會被忽略掉。在RSS中,z_i是通過attention機(jī)制計算的結(jié)果作為其采樣的概率。RSS的目標(biāo)是學(xué)習(xí)到以下的分布:
\begin{aligned} p ( \boldsymbol { z } | \boldsymbol { x } ; \theta _ { r } ) & = \prod _ { i = 1 } ^ { n } p \left( z _ { i } | \boldsymbol { x } ; \theta _ { r } \right) \\ \text { where } p \left( z _ { i } | \boldsymbol { x } ; \theta _ { r } \right) & = g \left( f \left( \boldsymbol { x } ; \theta _ { f } \right) _ { i } ; \theta _ { g } \right) \end{aligned}
其中,f \left( \cdot ; \theta _ { f } \right)表示一個上下文融合層(context fusion layer),如Bi-LSTM,Bi-GRU等,為x_i生成一個上下文敏感的representation。g \left( \cdot ; \theta _ { g } \right)f \left( \cdot ; \theta _ { f } \right)映射到x_i被選中的概率。注意到z_i的計算方式不依賴于z_{i-1},因此這個步驟可以并行完成。為了進(jìn)一步提高了效率。文章通過下面這個式子來計算f \left( \boldsymbol{x} ; \theta _ { f } \right) _ { i }
f \left( \boldsymbol{x} ; \theta _ { f } \right) _ { i } = \left[ x _ { i } ; \text { mead_pooling } ( \boldsymbol{x} ) ; x _ { i } \odot \text { mead_pooling } ( \boldsymbol{x} ) \right]
g \left( f \left( x ; \theta _ { f } \right) _ { i } ; \theta _ { g } \right)的計算方式則與source2token self-attention相似,如下:

g \left( f \left( x ; \theta _ { f } \right)_ {i} ; \theta _ { g } \right) = \operatorname { sigmoid } \left( w ^ { T } \sigma \left( W ^ { ( R ) }f \left( x ; \theta _ { f } \right)_ {i} + b ^ { ( R ) } \right) + b \right)

Reinforced Self-Attention (ReSA)

ReSA

在ReSA中,兩個參數(shù)獨立的RSS分別對輸入序列的進(jìn)行采樣,采樣結(jié)果分別稱為head tokens和dependent tokens。
\begin{array} { l } { \hat { z } ^ { h } = \left[ \hat { z } _ { 1 } ^ { h } , \ldots , \hat { z } _ { n } ^ { h } \right] \sim \operatorname { RSS } \left( x ; \theta _ { r h } \right) } \\ { \hat { z } ^ { d } = \left[ \hat { z } _ { 1 } ^ { d } , \ldots , \hat { z } _ { n } ^ { d } \right] \sim \operatorname { RSS } \left( x ; \theta _ { r d } \right) } \end{array}
然后使用\hat { z } ^ { h }、\hat { z } ^ { d }生成一個mask M^{rss}
M _ { i j } ^ { r s s } = \left\{ \begin{array} { l l } { 0 , } & { \hat { z } _ { i } ^ { d } = \hat { z } _ { j } ^ { h } = 1 \& i \neq j } \\ { - \infty , } & { \text { otherwise } } \end{array} \right.
M^{rss}放到Masked Self-Attention中:
f ^ { r s s } \left( x _ { i } , x _ { j } \right) = c \cdot \tanh \left( \left[ W ^ { ( 1 ) } x _ { i } + W ^ { ( 2 ) } x _ { j } + b ^ { ( 1 ) } \right] / c \right) + M _ { i j } ^ { r s s }
f ^ { r s s } \left( x _ { i } , x _ { j } \right)即score function,然后使用softmax函數(shù)計算概率:
P ^ { j } = \operatorname { softmax } \left( \left[ f ^ { r s s } \left( x _ { i } , x _ { j } \right) \right] _ { i = 1 } ^ { n } \right) , \text { for } j = 1 , \ldots , n
x_j的上下文注意力特性通過以下方式計算:
s _ { j } = \sum _ { i = 1 } ^ { n } P _ { i } ^ { j } \odot x _ { i } , \text { for } j = 1 , \ldots , n
最后,使用與DiSAN相同的融合層給出最終的輸出:
\begin{aligned} F & = \operatorname { sigmoid } \left( W ^ { ( f ) } [ \boldsymbol { x } ; s ] + b ^ { ( f ) } \right) \\ \boldsymbol { u } & = F \odot \boldsymbol { x } + ( 1 - F ) \odot \boldsymbol { s } \end{aligned}

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容