Transformer解讀(附pytorch代碼)

Transformer早在2017年就出現(xiàn)了,直到BERT問世,Transformer開始在NLP大放光彩,目前比較好的推進(jìn)就是Transformer-XL(后期附上)。這里主要針對(duì)論文和程序進(jìn)行解讀,如有不詳實(shí)之處,歡迎指出交流,如需了解更多細(xì)節(jié)之處,推薦知乎上川陀學(xué)者寫的。本文程序的git地址在這里。程序如果有不詳實(shí)之處,歡迎指出交流~

前言

2017年6月,Google發(fā)布了一篇論文《Attention is All You Need》,在這篇論文中,提出了 Transformer 的模型,其旨在全部利用Attention方式來替代掉RNN的循環(huán)機(jī)制,從而通過實(shí)現(xiàn)并行化計(jì)算提速。在Transformer出現(xiàn)之前,RNN系列網(wǎng)絡(luò)以及seq2seq+attention架構(gòu)基本上鑄就了所有NLP任務(wù)的鐵桶江山。由于Attention模型本身就可以看到全局的信息, Transformer實(shí)現(xiàn)了完全不依賴于RNN結(jié)構(gòu)僅利用Attention機(jī)制,在其并行性和對(duì)全局信息的有效處理上獲得了比之前更好的效果。

Transformer的整體結(jié)構(gòu)

圖1:Transformer整體結(jié)構(gòu)

Transformer的整體結(jié)構(gòu)就是分成編碼器和解碼器兩部分,并且兩部分之間是有聯(lián)系的,可以注意到編碼器的輸出是解碼器第二個(gè)Multi-head Attention中和的輸入,這里,我們把編碼器的輸出稱為state用來初始化解碼器的狀態(tài),而實(shí)際上對(duì)于解碼器而言,每一層的解碼器的state是一樣的(都是編碼器的輸出),并不會(huì)像RNN中的state一樣改變。對(duì)應(yīng)的pytorch程序如下:

class transformer(nn.Module):
    def __init__(self, enc_net, dec_net):
        super(transformer, self).__init__()
        self.enc_net = enc_net   # TransformerEncoder的對(duì)象   
        self.dec_net = dec_net   # TransformerDecoder的對(duì)象
    
    def forward(self, enc_X, dec_X, valid_length=None, max_seq_len=None):
        """
        enc_X: 編碼器的輸入
        dec_X: 解碼器的輸入
        valid_length: 編碼器的輸入對(duì)應(yīng)的valid_length,主要用于編碼器attention的masksoftmax中,
                      并且還用于解碼器的第二個(gè)attention的masksoftmax中
        max_seq_len:  位置編碼時(shí)調(diào)整sin和cos周期大小的,默認(rèn)大小為enc_X的第一個(gè)維度seq_len
        """
        
        # 1、通過編碼器得到編碼器最后一層的輸出enc_output
        enc_output = self.enc_net(enc_X, valid_length, max_seq_len)
        # 2、state為解碼器的初始狀態(tài),state包含兩個(gè)元素,分別為[enc_output, valid_length]
        state = self.dec_net.init_state(enc_output, valid_length)
        # 3、通過解碼器得到編碼器最后一層到線性層的輸出output,這里的output不是解碼器最后一層的輸出,而是
        #    最后一層再連接線性層的輸出
        output = self.dec_net(dec_X, state)
        return output

縱觀圖1整個(gè)Transformer的結(jié)構(gòu),其核心模塊其實(shí)就是三個(gè):Multi-Head attention、Feed Forward 以及 Add&Norm。這里關(guān)于Multi-Head attention部分只講程序的實(shí)現(xiàn),關(guān)于更多細(xì)節(jié)原理,請移至簡書開頭推薦的知乎鏈接。

Multi-Head Attention實(shí)現(xiàn)

Transformer中的attention采用的是多頭的self-attention結(jié)構(gòu),并且在編碼器中,由于不同的輸入mask的部分不一樣,因此在softmax之前采用了mask操作,并且解碼時(shí)由于不能看到t時(shí)刻之后的數(shù)據(jù),同樣在解碼器的第一個(gè)Multi-Head attention中采用了mask操作,但是二者是不同的。因?yàn)榫幋a器被mask的部分是需要在輸入到Transformer之前事先確定好,而解碼器第一個(gè)Multi-Head attention被mask的部分其實(shí)就是從t=1時(shí)刻開始一直到t=seq_len結(jié)束,對(duì)應(yīng)于圖2。在圖2中,橫坐標(biāo)表示解碼器一個(gè)batch上的輸入序列長度(也就是t),紫色部分為被mask的部分,黃色部分為未被mask的部分,可以看出,隨著t的增加,被mask的部分逐一減少。而解碼器第二個(gè)Multi-Head attention的mask操作和編碼器中是一樣的。


圖2:解碼器第一個(gè)Multi-Head attention中的mask操作

mask+softmax程序如下:

def masked_softmax(X, valid_length, value=-1e6):
    # 如果valid_length是一維的:valid_length的維度等于batch_size的大小
    # 對(duì)每一個(gè)batch去確定一個(gè)valid_length,因此valid_length的維度與batch_size大小相同
    # 再將valid_length內(nèi)的元素通過repeat操作將valid_length內(nèi)的元素repeat seq_len(X.size()[1])次
    # 結(jié)果就是對(duì)每一個(gè)batch上的X根據(jù)valid_length輸出相應(yīng)的attention weights,因此一個(gè)batch上的attention weights是一樣的

    # 如果valid_length是二維的:valid_length的維度等于[batch_size, seq_length]
    # 此時(shí)是針對(duì)每一個(gè)batch的每一句話都設(shè)置了seq_length
    if valid_length is None:
        return F.softmax(X, dim=-1)
    else:
        X_size = X.size()
        device = valid_length.device
        if valid_length.dim() == 1:
            valid_length = torch.tensor(valid_length.cpu().numpy().repeat(X_size[1], axis=0),
                                        dtype=torch.float, device=device) if valid_length.is_cuda \
                else torch.tensor(valid_length.numpy().repeat(X_size[1], axis=0),
                                  dtype=torch.float, device=device)
        else:
            valid_length = valid_length.view([-1])
        X = X.view([-1, X_size[-1]])
        max_seq_length = X_size[-1]
        valid_length = valid_length.to(torch.device('cpu'))
        mask = torch.arange(max_seq_length, dtype=torch.float)[None, :] >= valid_length[:, None]
        X[mask] = value
        X = X.view(X_size)
        return F.softmax(X, dim=-1)

mask操作其實(shí)就是對(duì)于無效的輸入,用一個(gè)負(fù)無窮的值代替這個(gè)輸入,這樣在softmax的時(shí)候其值就是0。而在attention中(attention操作見下式),softmax的操作出來的結(jié)果其實(shí)就是attention weights,當(dāng)attention weights為0時(shí),表示不需要attention該位置的信息。
softmax(\frac{QK^{T}}{\sqrtu0z1t8os})V
對(duì)于Multi-Head attention的實(shí)現(xiàn),其實(shí)并沒有像論文原文寫的那樣,逐一實(shí)現(xiàn)多個(gè)attention,再將最后的結(jié)果concat,并且通過一個(gè)輸出權(quán)重輸出。下面通過程序和公式講解一下實(shí)際的實(shí)現(xiàn)過程,這里假設(shè)Q,K,V的來源是一樣的,都是X,其維度為[batch_size, seq_len, input_size]。(需要注意的是在解碼器中第二個(gè)Multi-Head的輸入中QKV的來源不一樣)

圖3:論文原文中的attention操作

class DotProductAttention(nn.Module):
    # 經(jīng)過DotProductAttention之后,輸入輸出的維度是不變的,都是[batch_size*h, seq_len, d_model//h]
    def __init__(self, dropout,):
        super(DotProductAttention, self).__init__()
        self.drop = nn.Dropout(dropout)

    def forward(self, Q, K, V, valid_length):
        # Q, K, V shape:[batch_size*h, seq_len, d_model//h]
        d_model = Q.size()[-1]  # int
        # torch.bmm表示批次之間(>2維)的矩陣相乘
        attention_scores = torch.bmm(Q, K.transpose(1, 2))/math.sqrt(d_model)
        # attention_scores shape: [batch_size*h, seq_len, seq_len]
        attention_weights = self.drop(masked_softmax(attention_scores, valid_length))
        return torch.bmm(attention_weights, V)  # [batch_size*h, seq_len, d_model//h]
class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, dropout,):
        super(MultiHeadAttention, self).__init__()
        # 保證MultiHeadAttention的輸入輸出tensor的維度一樣
        assert hidden_size % num_heads == 0
        # hidden_size => d_model
        self.num_heads = num_heads
        # num_heads => h
        self.hidden_size = hidden_size
        # 這里的d_model為中間隱層單元的神經(jīng)元數(shù)目,d_model=h*d_v=h*d_k=h*d_q
        self.Wq = nn.Linear(input_size, hidden_size, bias=False)
        self.Wk = nn.Linear(input_size, hidden_size, bias=False)
        self.Wv = nn.Linear(input_size, hidden_size, bias=False)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attention = DotProductAttention(dropout)

    def _transpose_qkv(self, X):
        # X的輸入維度為[batch_size, seq_len, d_model]
        # 通過該函數(shù)將X的維度改變成[batch_size*num_heads, seq_len, d_model//num_heads]
        self._batch, self._seq_len = X.size()[0], X.size()[1]
        X = X.view([self._batch, self._seq_len, self.num_heads, self.hidden_size//self.num_heads])  # [batch_size, seq_len, num_heads, d_model//num_heads]
        X = X.permute([0, 2, 1, 3])  # [batch_size, num_heads, seq_len, d_model//num_heads]
        return X.contiguous().view([self._batch*self.num_heads, self._seq_len, self.hidden_size//self.num_heads])

    def _transpose_output(self, X):
        X = X.view([self._batch, self.num_heads, -1, self.hidden_size//self.num_heads])
        X = X.permute([0, 2, 1, 3])
        return X.contiguous().view([self._batch, -1, self.hidden_size])

    def forward(self, query, key, value, valid_length):
        Q = self._transpose_qkv(self.Wq(query))
        K = self._transpose_qkv(self.Wk(key))
        V = self._transpose_qkv(self.Wv(value))
        # 由于輸入的valid_length是相對(duì)batch輸入的,而經(jīng)過_transpose_qkv之后,
        # batch的大小發(fā)生了改變,Q的第一維度由原來的batch改為batch*num_heads
        # 因此,需要對(duì)valid_length進(jìn)行復(fù)制,也就是進(jìn)行np.title的操作
        if valid_length is not None:
            device = valid_length.device
            valid_length = valid_length.cpu().numpy() if valid_length.is_cuda else valid_length.numpy()
            if valid_length.ndim == 1:
                valid_length = np.tile(valid_length, self.num_heads)
            else:
                valid_length = np.tile(valid_length, [self.num_heads, 1])
            valid_length = torch.tensor(valid_length, dtype=torch.float, device=device)
        output = self.attention(Q, K, V, valid_length)
        output_concat = self._transpose_output(output)
        return self.Wo(output_concat)

首先,對(duì)于輸入X,通過三個(gè)權(quán)重變量得到Q,K,V,此時(shí)三者維度相同,都是[batch, seq_len, d_model],然后對(duì)其進(jìn)行維度變換:[batch, seq_len, h, d_model//h]==>[batch, h, seq_len, d]==>[batch×h, seq_len, d],其中d=d_model//h,因此直接將變換后的Q,K,V直接做DotProductAttention就可以實(shí)現(xiàn)Multi-Head attention,最后只需要將DotProductAttention輸出的維度依次變換回去,然后乘以輸出權(quán)重就可以了。關(guān)于程序中的參數(shù)valid_length已在程序中做了詳細(xì)的解讀,這里不再贅述,注意的是輸入的valid_length是針對(duì)batch這個(gè)維度的,而實(shí)際操作中由于X的batch維度發(fā)生了改變(由batch變成了batch×h),因此需要對(duì)valid_length進(jìn)行復(fù)制。

PositionWiseFFN的實(shí)現(xiàn)

FFN的實(shí)現(xiàn)是很容易的,其實(shí)就是對(duì)輸入進(jìn)行第一個(gè)線性變換,其輸出加上ReLU激活函數(shù),然后在進(jìn)行第二個(gè)線性變換就可以了。

class PositionWiseFFN(nn.Module):
    # y = w*[max(0, wx+b)]x+b
    def __init__(self, input_size, fft_hidden_size, output_size,):
        super(PositionWiseFFN, self).__init__()
        self.FFN1 = nn.Linear(input_size, fft_hidden_size)
        self.FFN2 = nn.Linear(fft_hidden_size, output_size)

    def forward(self, X):
        return self.FFN2(F.relu(self.FFN1(X)))

Add&Norm的實(shí)現(xiàn)

Add&norm的實(shí)現(xiàn)就是利用殘差網(wǎng)絡(luò)進(jìn)行連接,最后將連接的結(jié)果接上LN,值得注意的是,程序在Y的輸出中加入了dropout正則化。同樣的正則化技術(shù)還出現(xiàn)在masked softmax之后和positional encoding之后。

class AddNorm(nn.Module):
    def __init__(self, hidden_size, dropout,):
        super(AddNorm, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.LN = nn.LayerNorm(hidden_size)

    def forward(self, X, Y):
        assert X.size() == Y.size()
        return self.LN(self.drop(Y) + X)

positional encoding

positional encoding的實(shí)現(xiàn)很簡單,其實(shí)就是對(duì)輸入序列給定一個(gè)唯一的位置,采用sin和cos的方式給了一個(gè)位置編碼,其中sin處理的是偶數(shù)位置,cos處理的是奇數(shù)位置。但是,這一塊的工作確實(shí)非常重要的,因?yàn)閷?duì)于序列而言最主要的就是位置信息,顯然BERT是沒有去采用positional encoding(盡管在BERT的論文里有一個(gè)Position Embeddings的輸入,但是顯然描述的不是Transformer中要描述的位置信息),后續(xù)BERT在這一方面的改進(jìn)工作體現(xiàn)在了XLNet中(其采用了Transformer-XL的結(jié)構(gòu)),后續(xù)的簡書中再介紹該部分的內(nèi)容。

class PositionalEncoding(nn.Module):
    def __init__(self, dropout,):
        super(PositionalEncoding, self).__init__()

    def forward(self, X, max_seq_len=None):
        if max_seq_len is None:
            max_seq_len = X.size()[1]
        # X為wordEmbedding的輸入,PositionalEncoding與batch沒有關(guān)系
        # max_seq_len越大,sin()或者cos()的周期越小,同樣維度
        # 的X,針對(duì)不同的max_seq_len就可以得到不同的positionalEncoding
        assert X.size()[1] <= max_seq_len
        # X的維度為: [batch_size, seq_len, embed_size]
        # 其中: seq_len = l, embed_size = d
        l, d = X.size()[1], X.size()[-1]
        # P_{i,2j}   = sin(i/10000^{2j/d})
        # P_{i,2j+1} = cos(i/10000^{2j/d})
        # for i=0,1,...,l-1 and j=0,1,2,...,[(d-2)/2]
        max_seq_len = int((max_seq_len//l)*l)
        P = np.zeros([1, l, d])
        # T = i/10000^{2j/d}
        T = [i*1.0/10000**(2*j*1.0/d) for i in range(0, max_seq_len, max_seq_len//l) for j in range((d+1)//2)]
        T = np.array(T).reshape([l, (d+1)//2])
        if d % 2 != 0:
            P[0, :, 1::2] = np.cos(T[:, :-1])
        else:
            P[0, :, 1::2] = np.cos(T)
        P[0, :, 0::2] = np.sin(T)
        return torch.tensor(P, dtype=torch.float, device=X.device)

編碼器實(shí)現(xiàn)和解碼器的實(shí)現(xiàn)

無論是編碼器還是解碼器,其實(shí)都是用上面說的三個(gè)基本模塊堆疊而成,具體的實(shí)現(xiàn)細(xì)節(jié)大家可以看簡書開頭的git地址,這里需要強(qiáng)調(diào)的是以下幾點(diǎn):

  • 無論是編碼器還是解碼器,都在word embedding后面乘 上\sqrt{d_{model}},防止其值過小;
  • 論文里面提到了他們用的優(yōu)化器,是以\beta_1=0.9,\beta_2=0.98\epsilon=10^{-9}的Adam為基礎(chǔ),而后使用一種warmup的學(xué)習(xí)率調(diào)整方式來進(jìn)行調(diào)節(jié)。具體公式如下:基本上就是先用一個(gè)固定warmup_steps進(jìn)行學(xué)習(xí)率的線性增長,而后到達(dá)warmup_steps之后會(huì)隨著step_num的增長而逐漸減小。
    l_{rate}=d_{model}^{-0.5}*min(step\_num^{-0.5},step\_num*warmup\_steps^{-1.5})
class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer    # 優(yōu)化器
        self._step = 0                # 步長
        self.warmup = warmup          # warmup_steps
        self.factor = factor          # 學(xué)習(xí)率因子(就是學(xué)習(xí)率前面的系數(shù))
        self.model_size = model_size  # d_model
        self._rate = 0                # 學(xué)習(xí)率

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

簡書中出現(xiàn)的程序都在簡書開頭的git中了,直接執(zhí)行main.ipynb就可以運(yùn)行程序,如有不詳實(shí)之處,還請指出~~~

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容