AttentionCellWrapper的疑問
關(guān)注Attention機(jī)制的同學(xué)們都知道,Attention最初是在Encoder-Decoder結(jié)構(gòu)中由Bahdanau提出來的:
《 Neural Machine Translation by Jointly Learning to Align and Translate 》
https://arxiv.org/abs/1409.0473
這種Attention結(jié)構(gòu)大致如下圖所示:

但細(xì)心的同學(xué)就發(fā)現(xiàn)了,我平常用的都是簡單的RNN,單向展開后就得到結(jié)果,不是Encoder-Decoder的結(jié)構(gòu),這種Attention機(jī)制用不了??!
于是大家就在Tensorflow的API里翻來覆去,終于找到了一個(gè)可以用于單向傳播的Attention機(jī)制API:AttentionCellWrapper
在RNN中使用這個(gè)API也非常簡單,只需要把cell包裹起來就好了:
cell = tf.contrib.rnn.LSTMCell(num_units)
cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length = attention_len) # attention_len是自己指定的Attention關(guān)注長度
但仔細(xì)看API里的說明:
Implementation based on 《 Neural Machine Translation by Jointly Learning to Align and Translate 》
是不是有點(diǎn)眼熟……我的RNN(LSTM)明明就一層,沒有第二層Decoder,怎么使用之前的信息呢?
AttentionCellWrapper靈感來源
經(jīng)過我的反復(fù)Google,Tensorflow的AttentionCellWrapper并非基于Encoder-Decoder的架構(gòu)設(shè)計(jì)的,其靈感來源于這篇文章的Attention:
https://magenta.tensorflow.org/2016/07/15/lookback-rnn-attention-rnn
這篇文章提出了一種單向RNN就能使用的Attention結(jié)構(gòu),在處理每一步的輸入時(shí),考慮前面N步的輸出,經(jīng)過映射加權(quán)后把這些歷史信息加到本次輸入的預(yù)測中。公式如下圖:

其中:

看公式有點(diǎn)懵,結(jié)合結(jié)構(gòu)圖看一下:

這里我沒畫LSTM的輸入及與其他step的連接,大家意會(huì)就好。其中:
綠色線:當(dāng)前step(第t步)的cell
藍(lán)色線:前面step的輸出(hi)
紅色線:每一個(gè)step經(jīng)Attention加權(quán)后的輸出,虛線代表權(quán)重小,實(shí)線代表權(quán)重大,對(duì)當(dāng)前step影響也大
黃色方塊:可學(xué)習(xí)的參數(shù)矩陣
如此一來,當(dāng)前step的處理就用到了此前所有step的輸出信息,至于每個(gè)step的輸出貢獻(xiàn)了多少,就要看Attention的這些矩陣學(xué)的怎么樣了。
AttentionCellWrapper結(jié)構(gòu)探秘
然而Tensorflow版的AttentionCellWrapper結(jié)構(gòu)比上述結(jié)構(gòu)還要再復(fù)雜一點(diǎn),它考慮了以下兩點(diǎn):
- 使用了前面所有step信息后的當(dāng)前step輸出,是否可以用于下一個(gè)step的輸入
- 是否可以同時(shí)使用當(dāng)前step的輸出和cell信息,加入到Attention的計(jì)算中
基于這兩點(diǎn),Tensorflow的AttentionCellWrapper實(shí)現(xiàn)了更為周全的結(jié)構(gòu),如下所示:

其中紫色的線是新添加的內(nèi)容。
總結(jié)
Tensorflow設(shè)計(jì)了一種通用的Attention結(jié)構(gòu),使得簡單的時(shí)序模型就能使用Attention機(jī)制,讓整個(gè)模型能更好地將注意力放在貢獻(xiàn)最大的step上,一定程度上解決了時(shí)序模型記憶力不足的問題。
本文為YoungLittleFat原創(chuàng)文章,未經(jīng)允許不得轉(zhuǎn)載。