Transformer架構(gòu)詳解

Google 2017年論文Attention is all you need提出了Transformer模型,完全基于Attention mechanism,拋棄了傳統(tǒng)的CNNRNN。

1. Transformer架構(gòu)

Transformer

解釋下這個(gè)結(jié)構(gòu)圖。首先,Transformer模型也是使用經(jīng)典的encoder-decoder架構(gòu),由encoder和decoder兩部分組成。

上圖左側(cè)用Nx框出來(lái)的,就是我們encoder的一層。encoder一共有6層這樣的結(jié)構(gòu)。

上圖右側(cè)用Nx框出來(lái)的,就是我們decoder的一層。decoder一共有6層這樣的結(jié)構(gòu)。

輸入序列經(jīng)過(guò)word embeddingpositional embedding相加后,輸入到encoder中。

輸出序列經(jīng)過(guò)word embeddingpositional embedding相加后,輸入到decoder中。

最后,decoder輸出的結(jié)果,經(jīng)過(guò)一個(gè)線性層,然后計(jì)算softmax。

2. Encoder

encoder由6層相同的層組成,每一層分別由兩部分組成:

  • 第一部分是multi-head self-attention mechanism
  • 第二部分是position-wise feed-forward network,是一個(gè)全連接層。

兩部分,都有一個(gè)殘差連接(residual connection),然后接著一個(gè)Layer Normalization。

3. Decoder

與encoder類似,decoder也是由6個(gè)相同層組成,每一個(gè)層包括以下3個(gè)部分:

  • 第一部分是multi-head self-attention mechanism
  • 第二部分是multi-head context-attention mechanism
  • 第三部分是position-wise feed-forward network

同樣,上面三部分中每一部分,都有一個(gè)殘差連接(residual connection),后接著一個(gè)Layer Normalization。

4. Attention機(jī)制

Attention是指對(duì)于某個(gè)時(shí)刻的輸出y,它在輸入x上各個(gè)部分的注意力。這個(gè)注意力可以理解為權(quán)重。

attention機(jī)制有很多計(jì)算方式,下面是一張比較全面的表格:

image.png

seq2seq模型中,使用的是加性注意力(addtion attention)較多。

為什么這種attention叫做addtion attention呢?很簡(jiǎn)單,對(duì)于輸入序列隱狀態(tài)h_i和輸出序列的隱狀態(tài)s_t,它的處理方式很簡(jiǎn)單,直接合并為[s_t;h_i]

但是transformer模型使用的不是這種attention機(jī)制,使用的是另一種,叫做乘性注意力(multiplicative attention)。

那么這種乘性注意力機(jī)制是怎么樣的呢?從上表中的公式也可以看出來(lái):兩個(gè)隱狀態(tài)進(jìn)行點(diǎn)積!

4.1 Self-attention是什么?

上面我們說(shuō)的attention機(jī)制的時(shí)候,都會(huì)提到兩個(gè)隱狀態(tài),分別是h_is_t,前者是輸入序列第i個(gè)位置產(chǎn)生的隱狀態(tài),后者是輸出序列在第t個(gè)位置產(chǎn)生的隱狀態(tài)。

所謂self-attention實(shí)際上就是輸出序列就是輸入序列,因此計(jì)算自己的attention得分,就叫做self-attention!

4.2 Context-attention是什么?

context-attention是encoder和decoder之間的attention!,所以,也可以成為encoder-decoder attention!

不管是self-attention還是context-attention,它們計(jì)算attention分?jǐn)?shù)的時(shí)候,可以選擇很多方式,比如上面表中提到的:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那么Transformer模型,采用的是哪種呢?答案是:scaled dot-product attention。

4.3 Scaled dot-product attention是什么?

論文Attention is all you need里面對(duì)于attention機(jī)制的描述是這樣的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

這句話描述得很清楚了。翻譯過(guò)來(lái)就是:通過(guò)確定Q和K之間的相似程度來(lái)選擇V

用公式來(lái)描述更加清晰:
Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V\tag{4.3.1}

scaled dot-product attentiondot-product attention唯一區(qū)別是,scaled dot-product attention有一個(gè)縮放因子\frac{1}{\sqrt{d_k}}。

上面公式中d_k表示的是K的維度,在論文中,默認(rèn)是64。

那么為什么需要加上這個(gè)縮放因子呢?論文中給出了解釋:對(duì)于d_k很大時(shí),點(diǎn)積得到的結(jié)果維度很大,使得結(jié)果處理softmax函數(shù)梯度很小的區(qū)域。

我們知道,梯度很小時(shí),這對(duì)反向傳播不利。為了克服這個(gè)負(fù)面影響,除以一個(gè)縮放因子,在一定程度上減緩這種情況。

為什么是\frac{1}{\sqrt{d_k}}呢?論文沒(méi)有進(jìn)一步說(shuō)明。個(gè)人覺(jué)得你可以使用其他縮放因子,看看模型效果有沒(méi)有提升。

論文中也提供了一張很清晰的結(jié)果圖,供大家參考:

image.png

首先說(shuō)明一下我們的K、Q、V是什么:

  • 在encoder的self-attention中,Q、K、V都來(lái)自同一個(gè)地方(相等),他們是上一層encoder的輸出。對(duì)于第一層encoder,它們就是word embeddingpositional encoding相加得到的輸入。

  • 在decoder的self-attention中,Q、K、V都來(lái)自同一個(gè)地方(相等),他們是上一層decoder的輸出。對(duì)于第一層decoder,它們就是word embeddingpositional encoding相加得到的輸入。但是對(duì)于decoder,我們不希望它能獲得下一個(gè)time step,因此我們需要進(jìn)行sequence masking。

  • 在encoder-decoder attention中,Q來(lái)自于decoder的上一層的輸出,K和V來(lái)自于encoder的輸出,K和V是一樣的。

  • Q、K、V三者的維度一樣,即d_q=d_k=d_v。

4.4 Scaled dot-product attention代碼實(shí)現(xiàn)

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention mechanism.
    """

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        前向傳播

        args:
            q: Queries張量,形狀[B, L_q, D_q]
            k: keys張量, 形狀[B, L_k, D_k]
            v: Values張量,形狀[B, L_v, D_v]
            scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
            attn_mask: Masking張量,形狀[B, L_q, L_k]
        returns:
            上下文張量和attention張量
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask:
            # 給需要mask的地方設(shè)置一個(gè)負(fù)無(wú)窮
            attention = attention.masked_fill_(attn_mask, -np.inf)
        # 計(jì)算softmax
        attention = self.softmax(attention)
        # 添加dropout
        attention = self.dropout(attention)
        # 和V做點(diǎn)積
        context = torch.bmm(attention, v)

        return context, attention

5. Multi-head attention是什么呢?

理解了Scaled dot-product attention,Multi-head attention也很簡(jiǎn)單了。論文提到,他們發(fā)現(xiàn)將Q、K、V通過(guò)一個(gè)線性映射之后,分成h份,對(duì)每一份進(jìn)行scaled dot-product attention效果更好。然后,把各個(gè)部分的結(jié)果合并起來(lái),再次經(jīng)過(guò)線性映射,得到最終的輸出。這就是所謂的multi-head attention。上面的超參數(shù)h就是heads數(shù)量。論文默認(rèn)是8

multi-head attention的結(jié)構(gòu)圖如下:


image.png

值得注意的是,上面所說(shuō)的分成h份是在d_k、d_q、d_v維度上面進(jìn)行切分的。因此,進(jìn)入到scaled dot-product attention的d_k實(shí)際上等于未進(jìn)入之前的D_k/h。

Multi-head attention允許模型加入不同位置的表示子空間的信息。

Multi-head attention的公式如下:
MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O\tag{5.1}

其中,
head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)\tag{5.2}

論文中,d_{model}=512, h=8。所以scaled dot-product attention里面的
d_q=d_k=d_v=d_{model}/h=512/8=64

5.1 Multi-head attention代碼實(shí)現(xiàn)

class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim / num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        # multi-head attention之后需要做layer norm
        self.layer_num = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        # 殘差連接
        residual = query

        batch_size = key.size(0)

        # linear projection
        query = self.linear_q(query) # [B, L, D]
        key = self.linear_k(key) # [B, L, D]
        value = self.linear_v(value) # [B, L, D]

        # split by head
        query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
        key = key.view(batch_size * num_heads, -1, dim_per_head) # 
        value = value.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
            query, key, value, scale, attn_mask
        ) 

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)
        
        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_num(residual + output)

        return output, attention

上面代碼中出現(xiàn)了 Residual connectionLayer normalization。下面進(jìn)行解釋:

5.1.1 Residual connection是什么?

殘差連接其實(shí)比較簡(jiǎn)單!看圖就會(huì)比較清晰:

image.png

假設(shè)網(wǎng)絡(luò)中某個(gè)層對(duì)輸入x作用后的輸出為F(x),那么增加residual connection之后,變成:

F(x) + x \tag{5.2.1}

這個(gè)+x操作被稱為shotcut

殘差結(jié)構(gòu)因?yàn)樵黾恿艘豁?xiàng)x,該層網(wǎng)絡(luò)對(duì)x求偏導(dǎo)時(shí),為常數(shù)項(xiàng)1!所以可以在反向傳播過(guò)程中,梯度連乘,不會(huì)造成梯度消失!

5.1.2 Layer normalization是什么?

歸一化層,主要有這幾種方法,BatchNorm(2015年)、LayerNorm(2016年)、InstanceNorm(2016年)、GroupNorm(2018年);
將輸入的圖像shape記為[N,C,H,W],這幾個(gè)方法主要區(qū)別是:

  • BatchNorm:batch方向做歸一化,計(jì)算NHW的均值,對(duì)小batchsize效果不好;(BN主要缺點(diǎn)是對(duì)batchsize的大小比較敏感,由于每次計(jì)算均值和方差是在一個(gè)batch上,所以如果batchsize太小,則計(jì)算的均值、方差不足以代表整個(gè)數(shù)據(jù)分布)

  • LayerNorm:channel方向做歸一化,計(jì)算CHW的均值;(對(duì)RNN作用明顯)

  • InstanceNorm:一個(gè)batch,一個(gè)channel內(nèi)做歸一化。計(jì)算HW的均值,用在風(fēng)格化遷移;(因?yàn)樵趫D像風(fēng)格化中,生成結(jié)果主要依賴于某個(gè)圖像實(shí)例,所以對(duì)整個(gè)batch歸一化不適合圖像風(fēng)格化中,因而對(duì)HW做歸一化。可以加速模型收斂,并且保持每個(gè)圖像實(shí)例之間的獨(dú)立。)

  • GroupNorm:將channel方向分group,然后每個(gè)group內(nèi)做歸一化,算(C//G)HW的均值;這樣與batchsize無(wú)關(guān),不受其約束。

Normalization layers

6. Mask是什么?

mask顧名思義就是掩碼,大概意思是對(duì)某些值進(jìn)行掩蓋,使其不產(chǎn)生效果.

需要說(shuō)明的是,Transformer模型中有兩種mask。分別是padding masksequence mask。其中,padding mask在所有的scaled dot-product attention里都需要用到,而sequence mask只在decoder的self-attention中用到。

所以,我們之前的ScaledDotProductAttention的forward方法里的參數(shù)attn_mask在不同的地方有不同的含義。

6.1 Padding mask

什么是padding mask呢?回想一下,我們的每個(gè)批次輸入序列長(zhǎng)度是不一樣的!也就是說(shuō),我們要對(duì)輸入序列進(jìn)行對(duì)齊!具體來(lái)說(shuō),就是給較短序列后面填充0。因?yàn)檫@些填充位置,其實(shí)沒(méi)有意義,所以我們的attention機(jī)制不應(yīng)該把注意力放在這些位置上,所以我們需要進(jìn)行一些處理。

具體做法是:把這些位置的值加上一個(gè)非常大的負(fù)數(shù)(可以是負(fù)無(wú)窮),這樣的話,經(jīng)過(guò)softmax,這些位置的概率就會(huì)接近0。

而我們的padding mask實(shí)際上是一個(gè)張量,每個(gè)值都是一個(gè)Boolean,值為False的地方就是我們要進(jìn)行處理的地方。

下面是代碼實(shí)現(xiàn):

def padding_mask(seq_q, seq_k):
    # seq_k和seq_q的形狀都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]

[B,L]->[B,1,L]->[B,L,L]

F F T T
F F T T
F F T T
F F T T

6.2 Sequence mask

sequence mask是為了使得decoder不能看到未來(lái)的信息。也就是對(duì)于一個(gè)序列,在time step為t的時(shí)刻,我們的解碼輸出只能依賴于t時(shí)刻之前的輸出,而不能依賴t之后的輸出。因此我們需要想一個(gè)辦法,把t之后的信息給隱藏起來(lái)。

那具體如何做呢?也很簡(jiǎn)單:產(chǎn)生一個(gè)上三角矩陣,上三角矩陣的值全為1,下三角的值全為0,對(duì)角線值也為0。把這個(gè)矩陣作用在每一個(gè)序列上,就可以達(dá)到我們的目的。

具體代碼如下:

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

[B,L,L]

0 1 1 1
0 0 1 1
0 0 0 1
0 0 0 0

哈佛大學(xué)的文章The Annotated Transformer有一張效果圖:

image.png

值得注意的是,本來(lái)mask只需要二維矩陣即可,但是考慮到我們的輸入序列都是批量的,所以我們需要把原本二維矩陣擴(kuò)張成3維張量。上面代碼中,已經(jīng)做了處理。

回到本節(jié)開(kāi)始的問(wèn)題,attn_mask參數(shù)有幾種情況?分別是什么意思?

  • 對(duì)于decoder的self-attention,里面使用的scaled dot-product attention,同時(shí)需要padding masksequence mask作為attn_mask,具體實(shí)現(xiàn)就是兩個(gè)mask相加作為attn_mask。
  • 其它情況,attn_mask都等于padding mask。

7. Positional encoding是什么?

就目前而言,Transformer架構(gòu)似乎少了點(diǎn)東西。沒(méi)錯(cuò),那就是它對(duì)序列的順序沒(méi)有約束!我們知道序列的順序是一個(gè)很重要的信息,如果缺失了這個(gè)信息,可能我們的結(jié)果就是:所有詞語(yǔ)都對(duì)了,但是無(wú)法組成有意義的語(yǔ)句。

為了解決這個(gè)問(wèn)題,論文中提出了positional encoding。一句話概括就是:對(duì)序列中的詞語(yǔ)出現(xiàn)的位置進(jìn)行編碼!如果對(duì)位置進(jìn)行編碼,那么我們的模型就可以捕捉順序信息。

那么具體怎么做呢?論文的實(shí)現(xiàn)是使用正余弦函數(shù)。公式如下:
PF(pos,2i)=sin(pos/10000^{2i/d_{model}})\tag{7.1}

PF(pos,2i+1)=cos(pos/10000^{2i/d_{model}})\tag{7.2}

其中,pos是指詞語(yǔ)在序列中的位置??梢钥闯?,在偶數(shù)位置,使用正弦編碼,在奇數(shù)位置,使用余弦編碼。

上面公式中的d_{model}是模型的維度,論文默認(rèn)是512。

這個(gè)編碼公式的意思就是:給定詞語(yǔ)的位置pos,我們可以把它編碼成d_{model}維的向量!也就是說(shuō),位置編碼的每一個(gè)維度對(duì)應(yīng)正弦曲線,波長(zhǎng)構(gòu)成了從2\pi10000*2\pi的等比序列。

Postional encoding是對(duì)詞匯的位置編碼。

7.1 Positional encoding代碼實(shí)現(xiàn)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_seq_len):
        """
        初始化

        args:
            d_model: 一個(gè)標(biāo)量。模型的維度,論文默認(rèn)是512
            max_seq_len: 一個(gè)標(biāo)量。文本序列的最大長(zhǎng)度
        """
        super(PositionalEncoding, self).__init__()

        # 根據(jù)論文給出的公式,構(gòu)造出PE矩陣
        position_encoding = np.array([
            [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
            for pos in range(max_seq_len)
        ])
        # 偶數(shù)列使用sin,奇數(shù)列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩陣的一次行,加上一個(gè)全是0的向量,代表這`PAD`的positional_encoding
        # 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding,兩者十分類似
        # 那么為什么需要這個(gè)額外的PAD的編碼呢?很簡(jiǎn)單,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易,我們需要對(duì)齊,
        # 短的序列我們使用0在結(jié)尾不全,我們也需要這些補(bǔ)全位置的編碼,也就是`PAD`對(duì)應(yīng)的位置編碼
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))

        # 嵌入操作,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
        # word embedding中如果詞典增加`UNK`,我們也需要+1。
        self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)

    def forward(self, input_len):
        """
        神經(jīng)網(wǎng)絡(luò)前向傳播

        args:
            input_len: 一個(gè)張量,形狀為[BATCH_SIZE, 1]。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度。

        returns:
            返回這一批序列的位置編碼,進(jìn)行了對(duì)齊。
        """

        # 找出這一批序列的最大長(zhǎng)度
        max_len = torch.max(input_len)
        # 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
        # 這里range從1開(kāi)始也是因?yàn)橐荛_(kāi)PAD(0)的位置
        input_pos = torch.LongTensor(
            [list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
        )
        return self.position_encoding(input_pos)

8. Word embedding是什么?

Word embedding是對(duì)序列中的詞匯的編碼,把每一個(gè)詞匯編碼成d_{model}維的向量!它實(shí)際上就是一個(gè)二維浮點(diǎn)矩陣,里面的權(quán)重是可訓(xùn)練參數(shù),我們只需要把這個(gè)矩陣構(gòu)建出來(lái)就完成了word embedding的工作。

embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)

上面vocab_size是詞典大小,embedding_size是詞嵌入的維度大小,論文里面就是等于d_{model}=512。所以word embedding矩陣就是一個(gè)vocab_size*embedding_size的二維張量。

9. Position-wise Feed-Forward netword是什么?

這是一個(gè)全連接網(wǎng)絡(luò),包含連個(gè)線性變換和一個(gè)非線性函數(shù)(ReLU)。公式如下:
FFN(x)=max(0,xW_1+b_1)W2+b2\tag{9.1}

這個(gè)線性變換在不同的位置都是一樣的,并且在不同的層之間使用不同的參數(shù)。

論文提到,這個(gè)公式還可以用兩個(gè)核大小為1的一維卷積來(lái)解釋,卷積的輸入輸出都是d_{model}=512,中間層維度是d_{ff}=2048

代碼如下:

class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output

10. 完整代碼

至此,所有的細(xì)節(jié)都解釋完了。

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention mechanism.
    """

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        前向傳播

        args:
            q: Queries張量,形狀[B, L_q, D_q]
            k: keys張量, 形狀[B, L_k, D_k]
            v: Values張量,形狀[B, L_v, D_v]
            scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
            attn_mask: Masking張量,形狀[B, L_q, L_k]
        returns:
            上下文張量和attention張量
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask:
            # 給需要mask的地方設(shè)置一個(gè)負(fù)無(wú)窮
            attention = attention.masked_fill_(attn_mask, -np.inf)
        # 計(jì)算softmax
        attention = self.softmax(attention)
        # 添加dropout
        attention = self.dropout(attention)
        # 和V做點(diǎn)積
        context = torch.bmm(attention, v)

        return context, attention

class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim / num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        # multi-head attention之后需要做layer norm
        self.layer_num = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        # 殘差連接
        residual = query

        batch_size = key.size(0)

        # linear projection
        query = self.linear_q(query) # [B, L, D]
        key = self.linear_k(key) # [B, L, D]
        value = self.linear_v(value) # [B, L, D]

        # split by head
        query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
        key = key.view(batch_size * num_heads, -1, dim_per_head) # 
        value = value.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
            query, key, value, scale, attn_mask
        ) 

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)
        
        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_num(residual + output)

        return output, attention

def padding_mask(seq_q, seq_k):
    # seq_k和seq_q的形狀都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_seq_len):
        """
        初始化

        args:
            d_model: 一個(gè)標(biāo)量。模型的維度,論文默認(rèn)是512
            max_seq_len: 一個(gè)標(biāo)量。文本序列的最大長(zhǎng)度
        """
        super(PositionalEncoding, self).__init__()

        # 根據(jù)論文給出的公式,構(gòu)造出PE矩陣
        position_encoding = np.array([
            [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
            for pos in range(max_seq_len)
        ])
        # 偶數(shù)列使用sin,奇數(shù)列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩陣的一次行,加上一個(gè)全是0的向量,代表這`PAD`的positional_encoding
        # 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding,兩者十分類似
        # 那么為什么需要這個(gè)額外的PAD的編碼呢?很簡(jiǎn)單,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易,我們需要對(duì)齊,
        # 短的序列我們使用0在結(jié)尾不全,我們也需要這些補(bǔ)全位置的編碼,也就是`PAD`對(duì)應(yīng)的位置編碼
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))

        # 嵌入操作,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
        # word embedding中如果詞典增加`UNK`,我們也需要+1。
        self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)

    def forward(self, input_len):
        """
        神經(jīng)網(wǎng)絡(luò)前向傳播

        args:
            input_len: 一個(gè)張量,形狀為[BATCH_SIZE, 1]。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度。

        returns:
            返回這一批序列的位置編碼,進(jìn)行了對(duì)齊。
        """

        # 找出這一批序列的最大長(zhǎng)度
        max_len = torch.max(input_len)
        # 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
        # 這里range從1開(kāi)始也是因?yàn)橐荛_(kāi)PAD(0)的位置
        input_pos = torch.LongTensor(
            [list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
        )
        return self.position_encoding(input_pos)

# embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 獲得輸入的詞嵌入編碼
# seq_embedding = seq_embedding(inputs) * np.sqrt(d_model)

class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output

class EncoderLayer(nn.Module):
    """Encoder的一層。"""
    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, inputs, attn_mask=None):
        # self attention
        context, attention = self.attention(inputs, inputs, inputs, attn_mask)

        # feed forward network
        output = self.feed_forward(context)

        return output, attention


class Encoder(nn.Module):
    """多層EncoderLayer組成的Encoder"""
    def __init__(self,
                vocab_size,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Encoder, self).__init__()

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
        )

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_mask = padding_mask(inputs, inputs)

        attentions = []
        for encoder in self.encoder_layers:
            output, attention = encoder(output, self_attention_mask)
            attentions.append(attention)

        return output, attentions

class DecoderLayer(nn.Module):
    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self,
                dec_inputs,
                enc_outputs,
                self_attn_mask=None,
                context_attn_mask=None):
        # self attention, all inputs are decoder inputs
        dec_output, self_attention = self.attention(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        # context attention
        # query is decoder's outputs, key and value are encoder's inputs
        dec_output, context_attention = self.attention(dec_output, enc_outputs, enc_outputs, context_attn_mask)

        # decoder's output, or context
        dec_output = self.feed_forward(dec_output)

        return dec_output, self_attention, context_attention

class Decoder(nn.Module):
    def __init__(self,
                vocab_size,
                max_seq_len,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Decoder).__init__()

        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
        )
        
        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

        self_attentions = []
        context_attentions = []
        for decoder in self.decoder_layers:
            output, self_attn, context_attn = decoder(
            output, enc_output, self_attn_mask, context_attn_mask)
            self_attentions.append(self_attn)
            context_attentions.append(context_attn)

        return output, self_attentions, context_attentions

    
class Transformer(nn.Module):
    def __init__(self,
                src_vocab_size,
                src_max_len,
                tgt_vocab_size,
                tgt_max_len,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Transformer).__init__()

        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)
        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)

        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
        self.softmax = nn.Softmax()

    def forward(self, src_seq, src_len, tgt_seq, tgt_len):
        context_attn_mask = padding_mask(tgt_seq, src_seq)

        output, enc_self_attn = self.encoder(src_seq, src_len)

        output, dec_self_attn, ctx_attn = self.decoder(tgt_seq, tgt_len, output, context_attn_mask)

        output = self.linear(output)
        output = self.softmax(output)

        return output, enc_self_attn, dec_self_attn, ctx_attn
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容