Attention Is All You Need-谷歌的"自注意力"

上一篇文章記錄了自然語(yǔ)言處理中的注意力機(jī)制,這篇文章分析一下google的一篇論文Attention Is All You Need。

為什么不使用循環(huán)神經(jīng)網(wǎng)絡(luò)

其實(shí)早在google之前,facebook就在[1]中拋棄了RNN等提出了基于卷積的sequence to sequence模型。由于RNN中時(shí)間步之間存在依賴(lài)關(guān)系,因此各時(shí)間步無(wú)法并行運(yùn)算,使得GPU并行計(jì)算的優(yōu)勢(shì)無(wú)法發(fā)揮。在同等神經(jīng)元量級(jí)的情況下,RNN訓(xùn)練速度較CNN相比更慢。因此很多研究中希望用其他的網(wǎng)絡(luò)類(lèi)型來(lái)替代RNN,這可能也是google這篇論文的出發(fā)點(diǎn)之一。

主要結(jié)構(gòu)

首先看一些論文中對(duì)attention的定義,對(duì)于Q\in R^{m*d_k}, K \in R^{n*d_k}, V \in R^{n*d_v},有:
(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
這個(gè)定義表面上看和我們之前提到的注意力機(jī)制有很大的不同, Q, K, V 這幾個(gè)矩陣也不知道什么意思。這里可以類(lèi)比翻譯任務(wù),Q可以是看做是decoder的隱藏狀態(tài),而K可以看做encoder的隱藏狀態(tài),對(duì)于機(jī)器翻譯任務(wù)K=V。在其他應(yīng)用場(chǎng)景,例如在memory network中,K, V可能是分別用address的向量和用于修改memory的向量。這里QK^T實(shí)際上就是計(jì)算出一個(gè)權(quán)重,然后對(duì)V進(jìn)行加權(quán),也和之前提到的注意力結(jié)構(gòu)一致。這里除以\sqrt {d_k}是避免QK乘積過(guò)大,softmax計(jì)算出結(jié)果出現(xiàn)上溢出。論文中給出的注意力計(jì)算流程如下:

(left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.

其中左邊是普通注意力機(jī)制,右邊則是論文中提到的對(duì)普通注意力機(jī)制的一種改進(jìn)。也就是Multi-Head Attention。

Multi-Head Attention

Multi-Head Attention首先對(duì)Q, K,V分別乘不同的變換矩陣進(jìn)行變換,并重復(fù)多次這樣的操作。寫(xiě)成公式就是:
\begin{aligned} \text { MultiHead }(Q, K, V) &=\text { Concat }\left(\text { head }_{1}, \ldots, \text { head }_{\mathrm{h}}\right) W^{O} \\ \text { where head }_{\mathrm{i}} &=\text { Attention }\left(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}\right) \end{aligned}
論文中提到這種方式可以學(xué)習(xí)到不同子空間的特征。這里其實(shí)有一點(diǎn)CNN的意思,對(duì)相同的數(shù)據(jù)用不同的核進(jìn)行處理。

自注意力機(jī)制

論文中還提出了這種注意力機(jī)制的一個(gè)應(yīng)用,也就是自注意力機(jī)制。自注意力機(jī)制也就是Q=K=V的情況。使用這種方式,論文中提出了transformer結(jié)構(gòu),并采用了這種結(jié)構(gòu)實(shí)現(xiàn)了sequence to sequence模型:

transformer結(jié)構(gòu)

下面分別對(duì)其中的模塊進(jìn)行講解:

PE(Positional Encoding)

從圖中可以看出輸入先經(jīng)過(guò)了一個(gè)Positional Encoding(位置嵌入, PE)。其實(shí)PE在其他論文[2]也有體現(xiàn)。進(jìn)行PE的的主要原因是transformer中并沒(méi)有任何可以體現(xiàn)輸入順序的結(jié)構(gòu)。對(duì)于NLP來(lái)說(shuō),詞語(yǔ)的順序是非常重要的。因此論文中使用了PE。論文中直接指定了公式:
P E_{(p o s, 2 i)}=\sin \left(p o s / 10000^{2 i / d_{\mathrm{matel}}}\right)
P E_{(p o s, 2 i+1)}=\cos \left(p o s / 10000^{2 i / d_{\mathrm{model}}}\right)
其中pos是指詞語(yǔ)在句子中的位置,i是每一個(gè)詞向量中的第i個(gè)元素,d_model是詞向量的維度。論文中提到使用學(xué)得的詞向量結(jié)果與這種方式相似,于是論文就采用了這種方式進(jìn)行PE。

Position-wise Feed-Forward Networks

論文中對(duì)該層的描述是:

a fully connected feed-forward network, which is applied to each position separately and identically

其實(shí)也就是核大小為1的卷積網(wǎng)絡(luò)。公式描述如下:
\mathrm{FFN}(x)=\max \left(0, x W_{1}+b_{1}\right) W_{2}+b_{2}
max(0, .)也就是relu激活函數(shù)。

除了上面提到的幾個(gè)模塊之外,在Multi-Head Attention以及Feed Forward層都使用了殘差連接以及l(fā)ayer normalization?!?/p>

Self-Attention的優(yōu)勢(shì)

論文中給出了Self-Attention的幾種優(yōu)勢(shì):

  1. 由于Multi-Head Attention每一層都可以并行計(jì)算,因此計(jì)算速度相比RNN有優(yōu)勢(shì)。
  2. 在長(zhǎng)距離的依賴(lài)問(wèn)題有優(yōu)勢(shì)。在RNN中,反向傳遞梯度容易彌散。雖然在LSTM中引入了遺忘門(mén)等記憶單元,但是仍然在長(zhǎng)序列時(shí)出現(xiàn)輸出只依賴(lài)于最近幾個(gè)輸入的情況。也就是當(dāng)依賴(lài)路徑變長(zhǎng)時(shí),當(dāng)前輸出受其他輸入之間的影響逐漸變小。但是在Self-Attention中,每一個(gè)輸入都與其他輸入進(jìn)行了attention,因此在每一個(gè)輸入中都包含了來(lái)至于其他輸入的信息,這樣使得每一個(gè)輸入與其他輸入的依賴(lài)路徑變得更短,更容易學(xué)得更長(zhǎng)的依賴(lài)關(guān)系。

實(shí)現(xiàn)

下面是我使用keras對(duì)Multi-Head Attention的一個(gè)實(shí)現(xiàn),源代碼如下:

class MultiHeadAttention(Layer):
    """
        multi head attention 的實(shí)現(xiàn)
        參考論文: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    """

    def __init__(self, num_heads, projection_shape, d_model, **kwargs):
        self._num_heads = num_heads
        self._d_model = d_model
        self._projection_shape = projection_shape
        super(MultiHeadAttention, self).__init__(**kwargs)

    def __add_weight(self, shape, name):
        return self.add_weight(name=name, shape=shape, initializer='normal', trainable=True)

    def build(self, input_shape):
        """
           以下不包括 batch_size
            input_shape = [m, n, d_model]
            這里的 m, n, d_model 代表的是為映射前的輸入大小:
            shape(Q) = [m, d_model]
            shape(K) = [n, d_model]
            shape(V) = [n, d_model]

            d_k, d_v 是指 multi-head-attention 映射之后:
                            Q|K|V      x     QW|KW|VW 
            shape(Q*WQ) = [m, d_model] x [d_model, d_k] = [m, d_k]
            shape(K*WK) = [n, d_model] x [d_model, d_k] = [n, d_k]
            shape(V*WV) = [n, d_model] x [d_model, d_v] = [n, d_v]
            一般來(lái)說(shuō), dk, dv < d_model 因?yàn)檎撐闹兄赋鲇成鋵?shí)際上有降維的左右,這樣可以加快計(jì)算速度

            另外:
            shape(W_O)  =  [h * d_v, d_model]  
            
            符號(hào)與 *Attention is All You Need* 一致
            
        """
        d_k, d_v = self._projection_shape
        head_weight = self.__add_weight
        self._QW = head_weight([self._d_model, d_k * self._num_heads], "Q")
        self._KW = head_weight([self._d_model, d_k * self._num_heads], "K")
        self._VW = head_weight([self._d_model, d_k * self._num_heads], "V")
        self._OW = head_weight([self._num_heads * d_v, self._d_model], "O")

        super(MultiHeadAttention, self).build(input_shape)

    def __attention(self, Q, K, V):
        batch_dot = backend.batch_dot
        d_k, _ = self._projection_shape
        raw_weights = batch_dot(Q, tf.transpose(K, [0, 2, 1])) / tf.sqrt(tf.constant(d_k, dtype=tf.float32))
        attention_weights = tf.nn.softmax(raw_weights, axis=2)  # 每一行進(jìn)行 soft max
        return batch_dot(attention_weights, V)

    def _mul_every_batch_with(self, inputs, y):
        """
        將 inputs 的每一個(gè) batch 與 y 相乘, 產(chǎn)生一個(gè)新的張量
        :param inputs: [batch_size, m, n]
        :param y: [n, x]
        :return:[batch_size, m, x]
        """
        return tool.mul_every_batch_with(inputs, y)

    def _multi_attention(self, Q, K, V):
        mul_every_batch_with = self._mul_every_batch_with
        # 多次線性映射,連接
        QP = mul_every_batch_with(Q, self._QW)
        KP = mul_every_batch_with(K, self._KW)
        VP = mul_every_batch_with(V, self._VW)
        attention = self.__attention(QP, KP, VP)
        return mul_every_batch_with(attention, self._OW)

    def call(self, inputs, **kwargs):
        Q, K, V = inputs
        return self._multi_attention(Q, K, V)

也可以直接點(diǎn)鏈接attention.py。上面的代碼只單純的實(shí)現(xiàn)了MultiHeadAttention,并沒(méi)有實(shí)現(xiàn)transformer結(jié)構(gòu)。

參考

[1] Convolutional Sequence to Sequence Learning
[2] End-To-End Memory Networks
[3] Attention Is All You Need

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容