Tensorflow中的AttentionCellWrapper:一種更通用的Attention機(jī)制

AttentionCellWrapper的疑問

關(guān)注Attention機(jī)制的同學(xué)們都知道,Attention最初是在Encoder-Decoder結(jié)構(gòu)中由Bahdanau提出來的:

《 Neural Machine Translation by Jointly Learning to Align and Translate 》
https://arxiv.org/abs/1409.0473

這種Attention結(jié)構(gòu)大致如下圖所示:


圖1 Encoder-Decoder中的Attention結(jié)構(gòu)圖



但細(xì)心的同學(xué)就發(fā)現(xiàn)了,我平常用的都是簡單的RNN,單向展開后就得到結(jié)果,不是Encoder-Decoder的結(jié)構(gòu),這種Attention機(jī)制用不了??!

于是大家就在Tensorflow的API里翻來覆去,終于找到了一個(gè)可以用于單向傳播的Attention機(jī)制API:AttentionCellWrapper

在RNN中使用這個(gè)API也非常簡單,只需要把cell包裹起來就好了:

cell = tf.contrib.rnn.LSTMCell(num_units)
cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length = attention_len) # attention_len是自己指定的Attention關(guān)注長度

但仔細(xì)看API里的說明:

Implementation based on 《 Neural Machine Translation by Jointly Learning to Align and Translate 》

是不是有點(diǎn)眼熟……我的RNN(LSTM)明明就一層,沒有第二層Decoder,怎么使用之前的信息呢?

AttentionCellWrapper靈感來源

經(jīng)過我的反復(fù)Google,Tensorflow的AttentionCellWrapper并非基于Encoder-Decoder的架構(gòu)設(shè)計(jì)的,其靈感來源于這篇文章的Attention:

https://magenta.tensorflow.org/2016/07/15/lookback-rnn-attention-rnn

這篇文章提出了一種單向RNN就能使用的Attention結(jié)構(gòu),在處理每一步的輸入時(shí),考慮前面N步的輸出,經(jīng)過映射加權(quán)后把這些歷史信息加到本次輸入的預(yù)測中。公式如下圖:


圖2 Attention公式

其中:


圖3 公式釋義

看公式有點(diǎn)懵,結(jié)合結(jié)構(gòu)圖看一下:


圖4 通用型Attention結(jié)構(gòu)圖

這里我沒畫LSTM的輸入及與其他step的連接,大家意會(huì)就好。其中:
綠色線:當(dāng)前step(第t步)的cell
藍(lán)色線:前面step的輸出(hi)
紅色線:每一個(gè)step經(jīng)Attention加權(quán)后的輸出,虛線代表權(quán)重小,實(shí)線代表權(quán)重大,對(duì)當(dāng)前step影響也大
黃色方塊:可學(xué)習(xí)的參數(shù)矩陣

如此一來,當(dāng)前step的處理就用到了此前所有step的輸出信息,至于每個(gè)step的輸出貢獻(xiàn)了多少,就要看Attention的這些矩陣學(xué)的怎么樣了。

AttentionCellWrapper結(jié)構(gòu)探秘

然而Tensorflow版的AttentionCellWrapper結(jié)構(gòu)比上述結(jié)構(gòu)還要再復(fù)雜一點(diǎn),它考慮了以下兩點(diǎn):

  1. 使用了前面所有step信息后的當(dāng)前step輸出,是否可以用于下一個(gè)step的輸入
  2. 是否可以同時(shí)使用當(dāng)前step的輸出和cell信息,加入到Attention的計(jì)算中

基于這兩點(diǎn),Tensorflow的AttentionCellWrapper實(shí)現(xiàn)了更為周全的結(jié)構(gòu),如下所示:


圖5 Tensorflow AttentionCellWrapper結(jié)構(gòu)圖

其中紫色的線是新添加的內(nèi)容。

總結(jié)

Tensorflow設(shè)計(jì)了一種通用的Attention結(jié)構(gòu),使得簡單的時(shí)序模型就能使用Attention機(jī)制,讓整個(gè)模型能更好地將注意力放在貢獻(xiàn)最大的step上,一定程度上解決了時(shí)序模型記憶力不足的問題。

本文為YoungLittleFat原創(chuàng)文章,未經(jīng)允許不得轉(zhuǎn)載。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容