Pytorch 使用attention實現(xiàn)轉(zhuǎn)換日期并可視化attention
實現(xiàn)環(huán)境:python3.6
pytorch1.0
import json
from matplotlib import ticker
from numpy import *
from collections import Counter
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# Select GPU when available, otherwise CPU; all tensors and models below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
對數(shù)據(jù)進行預處理:
首先從字符級層面統(tǒng)計字符的數(shù)量,然后將其轉(zhuǎn)為字符對應數(shù)字的字典,最后再形成一個數(shù)字對應字符的字典。
def build_vocab(texts, n=None):
    """Build a character-level vocabulary from an iterable of strings.

    Returns (char2index, index2char). Indices 0-3 are reserved for the
    special tokens pad '~', sos '^', eos '$' and unk '#', which is why
    ordinary characters are numbered starting from 4.
    """
    freq = Counter(''.join(texts))  # character frequencies over the whole corpus
    # most_common(n) yields (char, count) pairs, most frequent first.
    char2index = dict(
        (ch, idx) for idx, (ch, _cnt) in enumerate(freq.most_common(n), start=4)
    )
    # Reserved special symbols occupy the low indices.
    for sym, idx in (('~', 0), ('^', 1), ('$', 2), ('#', 3)):
        char2index[sym] = idx
    index2char = {idx: ch for ch, idx in char2index.items()}
    return char2index, index2char
數(shù)據(jù)下載鏈接:https://pan.baidu.com/s/132uS7mMzn7ISqEVg8i27eA
提取碼:36fu
# Load the (source phrase, target time) pairs. Use a context manager so the
# file handle is closed deterministically (the original leaked the handle
# returned by open()).
with open('./data/Time Dataset.json', 'rt', encoding='utf-8') as f:
    pairs = json.load(f)
print(pairs[:2])  # inspect the data format
[['six hours and fifty five am', '06:55'], ['48 min before 10 a.m', '09:12']]
將目標文本和原文本分開,建立各自的字典
# Separate source and target texts into two columns and build a vocabulary for each.
data = array(pairs)   # numpy array (array comes from `from numpy import *`)
src_texts = data[:,0] # all values of the first column (English time phrases)
trg_texts = data[:,1] # all values of the second column (HH:MM strings)
src_c2ix,src_ix2c = build_vocab(src_texts)
trg_c2ix,trg_ix2c = build_vocab(trg_texts)
接下來按批量更新,定義一個隨機批量生成的函數(shù),它能將文本轉(zhuǎn)成字典中的數(shù)字表示,并同時返回batch_size個樣本和它們的長度,這些樣本按照長度降序排序。pad的長度以batch中最長的為準。這主要是為了適應pack_padded_sequence這個函數(shù),因為輸入RNN的序列不需要將pad標志也輸入RNN中計算,RNN只需要循環(huán)計算得到其真實長度即可。
def indexes_from_text(text, char2index):
    """Convert a text to a list of vocab indices wrapped with sos (1) / eos (2).

    Characters missing from the vocab fall back to the unk index 3, matching
    the '#' token reserved in build_vocab (the original raised KeyError here).
    """
    return [1] + [char2index.get(c, 3) for c in text] + [2]
def pad_seq(seq, max_length):
    """Return seq right-padded with the pad index 0 up to max_length.

    Non-mutating: the caller's list is left untouched (the original version
    extended ``seq`` in place with ``+=`` as a hidden side effect).
    """
    return seq + [0] * (max_length - len(seq))
# Longest text on each side, plus 2 for the sos/eos wrapper tokens added
# by indexes_from_text.
max_src_len = max(len(t) for t in src_texts) + 2
max_trg_len = max(len(t) for t in trg_texts) + 2
max_src_len, max_trg_len  # notebook echo: (43, 7)
(43, 7)
def random_batch(batch_size, pairs, src_c2ix, trg_c2ix):
    """Sample a random mini-batch and convert it to padded index tensors.

    Returns (input_var, input_lengths, target_var, target_lengths); the
    tensors are shaped (seq_len, batch_size) and the batch is sorted by
    source length, longest first, as pack_padded_sequence requires.
    """
    # `random` comes from `from numpy import *`, so this is np.random.choice
    # (sampling batch_size indices with replacement).
    chosen = random.choice(len(pairs), batch_size)
    input_seqs = [indexes_from_text(pairs[i][0], src_c2ix) for i in chosen]
    target_seqs = [indexes_from_text(pairs[i][1], trg_c2ix) for i in chosen]

    # Sort both sequences together, descending by source length.
    ordered = sorted(zip(input_seqs, target_seqs), key=lambda p: len(p[0]), reverse=True)
    input_seqs, target_seqs = zip(*ordered)

    input_lengths = [len(s) for s in input_seqs]
    target_lengths = [len(s) for s in target_seqs]

    # Pad to the batch maximum, then transpose to (seq_len, batch_size).
    input_padded = [pad_seq(s, max(input_lengths)) for s in input_seqs]
    target_padded = [pad_seq(s, max(target_lengths)) for s in target_seqs]
    input_var = torch.LongTensor(input_padded).transpose(0, 1).to(device)
    target_var = torch.LongTensor(target_padded).transpose(0, 1).to(device)
    return input_var, input_lengths, target_var, target_lengths
"""
sort 與 sorted 區(qū)別:
sort 是應用在 list 上的方法,sorted 可以對所有可迭代的對象進行排序操作。
list 的 sort 方法返回的是對已經(jīng)存在的列表進行操作,無返回值,而內(nèi)建函數(shù) sorted 方法返回的是一個新的 list,而不是在原來的基礎上進行的操作
sorted(iterable[, cmp[, key[, reverse]]])
key -- 主要是用來進行比較的元素,只有一個參數(shù),具體的函數(shù)的參數(shù)就是取自于可迭代對象中,指定可迭代對象中的一個元素來進行排序。
reverse -- 排序規(guī)則,reverse = True 降序 , reverse = False 升序(默認)。
"""測試batch_size = 3時是否能夠正確輸出
random_batch(3,data,src_c2ix,trg_c2ix)
(tensor([[ 1, 1, 1],
[ 6, 23, 6],
[ 5, 9, 18],
[ 8, 23, 23],
[ 4, 37, 9],
[ 7, 4, 26],
[33, 13, 23],
[22, 9, 2],
[30, 11, 0],
[ 7, 9, 0],
[22, 2, 0],
[34, 0, 0],
[ 4, 0, 0],
[ 6, 0, 0],
[31, 0, 0],
[ 5, 0, 0],
[ 8, 0, 0],
[ 6, 0, 0],
[20, 0, 0],
[ 4, 0, 0],
[13, 0, 0],
[ 9, 0, 0],
[11, 0, 0],
[ 9, 0, 0],
[ 2, 0, 0]], device='cuda:0'), [25, 11, 8], tensor([[ 1, 1, 1],
[ 6, 5, 7],
[ 5, 8, 8],
[ 4, 4, 4],
[ 7, 8, 5],
[ 5, 12, 8],
[ 2, 2, 2]], device='cuda:0'), [7, 7, 7])
模型:
這里的模型框架分為encoder和decoder兩個部分,encoder部分比較簡單,就是一層embedding層加上兩層GRU。
前面對于batch的格式處理,主要是為了處理pack_padded_sequence和pad_packer_sequence這兩個類對GRU輸入輸出批量處理https://blog.csdn.net/lssc4205/article/details/79474735
https://blog.csdn.net/u012436149/article/details/79749409
class Encoder(nn.Module):
    """Character-level encoder: embedding -> dropout -> multi-layer GRU.

    Sequences are packed so that pad positions are never fed through the RNN;
    input_lengths must therefore be sorted in descending order (random_batch
    produces batches in exactly that form).
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim        # source vocab size (incl. special tokens)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, dropout=dropout)
        # NOTE: the original stored the raw float rate in self.dropout and then
        # immediately overwrote it with this module; the dead store is removed.
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seqs, input_lengths, hidden=None):
        """Encode a padded batch.

        input_seqs: LongTensor (seq_len, batch_size)
        input_lengths: true lengths, sorted descending
        Returns:
            outputs: (seq_len, batch_size, hidden_dim), zeros past each length
            hidden:  (num_layers, batch_size, hidden_dim)
        """
        embedded = self.dropout(self.embedding(input_seqs))
        # Pack so the GRU only iterates over real (non-pad) time steps.
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.rnn(packed, hidden)
        # Unpack back to a padded (seq_len, batch, hidden) tensor.
        outputs, _output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        return outputs, hidden
（說明：若不使用 pack_padded_sequence,等價的直接調(diào)用是 outputs, hidden = self.rnn(embedded, hidden)。
形狀:outputs = [sent_len, batch_size, hid_dim * n_directions],
hidden = [n_layers, batch_size, hid_dim];outputs 總是來自于最后一層。）
首先定義一下Attention層,這里主要是對encoder的輸出進行attention操作,也可以直接對embedding層的輸出進行attention。
論文Neural Machine Translation by Jointly Learning to Align and Translate中定義了attention的計算公式。
decoder 在時刻 $i$ 的輸出取決于 decoder 先前的輸出 $y_{i-1}$ 和 $s_i$,這里 $s_i$ 包括當前 GRU 輸出的 hidden state(這部分已經(jīng)考慮了先前的輸出)以及 attention 上下文向量 $c_i$(由 encoder 的輸出求得)。

所謂的上下文向量就是對 encoder 的所有輸出進行加權求和:
$$c_i = \sum_j \alpha_{ij} h_j$$
其中 $\alpha_{ij}$ 表示輸出的第 $i$ 個詞對 encoder 第 $j$ 個輸出 $h_j$ 的權重,由所有得分 $e_{ij}$ 經(jīng) softmax 歸一化得到:
$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$$
而每個得分 $e_{ij}$ 由 decoder 的上一個 hidden state $s_{i-1}$ 和指定的 encoder 輸出 $h_j$ 拼接后,經(jīng)過一個帶非線性激活的全連接層計算:
$$e_{ij} = v^\top \tanh\bigl(W\,[s_{i-1};\,h_j]\bigr)$$
此外,論文Effective Approaches to Attention-based Neural Machine Translation中提出了計算分值的不同方式。這里用到的是第三種。
class Attention(nn.Module):
    """'Concat' attention over the encoder outputs.

    score(s, h) = v^T tanh(W [s ; h]) — the third scoring variant from
    Luong et al., "Effective Approaches to Attention-based NMT".
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        self.attn = nn.Linear(self.hidden_dim * 2, hidden_dim)
        # Learned scoring vector v, re-initialised ~ N(0, 1/sqrt(hidden_dim)).
        self.v = nn.Parameter(torch.rand(hidden_dim))
        self.v.data.normal_(mean=0, std=1. / np.sqrt(self.v.size(0)))

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights.

        hidden: (num_layers*num_directions, batch, hidden) — only the top layer is used.
        encoder_outputs: (seq_len, batch, hidden)
        Returns softmax-normalised weights of shape (batch, seq_len).
        """
        max_len = encoder_outputs.size(0)
        # Broadcast the top decoder hidden state across every encoder time step.
        h = hidden[-1].repeat(max_len, 1, 1)            # (seq_len, batch, hidden)
        attn_energies = self.score(h, encoder_outputs)  # (batch, seq_len)
        return F.softmax(attn_energies, dim=1)          # normalise over the sequence

    def score(self, hidden, encoder_outputs):
        # (seq_len, batch, 2*hidden) -> (seq_len, batch, hidden)
        # torch.tanh replaces F.tanh, which is deprecated (removed in later torch).
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2)))
        energy = energy.permute(1, 2, 0)  # (batch, hidden, seq_len)
        # v repeated per batch item: (batch, 1, hidden)
        v = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1)
        energy = torch.bmm(v, energy)     # (batch, 1, seq_len)
        return energy.squeeze(1)          # (batch, seq_len)
接下來是加入attention層的decoder,GRU的輸出進入全連接層后,又進行了log_softmax操作計算輸出詞的概率,主要是為了方便NLLLoss損失函數(shù),如果用CrossEntropyLoss損失函數(shù),可以不用加softmax:損失函數(shù)NLLLoss() 的 輸入 是一個對數(shù)概率向量和一個目標標簽. 它不會為我們計算對數(shù)概率,適合最后一層是log_softmax()的網(wǎng)絡. 損失函數(shù) CrossEntropyLoss() 與 NLLLoss() 類似, 唯一的不同是它為我們?nèi)プ?softmax.可以理解為:CrossEntropyLoss()=log_softmax() + NLLLoss()
class Decoder(nn.Module):
    # Attention decoder run one time step at a time: embed the previous output
    # token, mix it with an attention context over the encoder outputs, feed
    # the GRU, and emit log-probabilities. log_softmax pairs with NLLLoss
    # (CrossEntropyLoss would subsume the log_softmax step).
    def __init__(self, output_dim, embedding_dim, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.output_dim = output_dim        # target vocab size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.attention = Attention(hidden_dim)
        # GRU input is the concatenation [embedded token ; attention context].
        self.rnn = nn.GRU(embedding_dim + hidden_dim, hidden_dim, num_layers=num_layers, dropout=dropout)
        # Output layer sees [embedded ; top hidden state ; context].
        self.out = nn.Linear(embedding_dim + hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        # input = [bsz] — previous output token per batch item
        # hidden = [n_layer*n_direction, batch_size, hid_dim]
        # encoder_outputs = [sent_len, batch_size, hid_dim*n_direction]
        input = input.unsqueeze(0)
        # input = [1, bsz]
        embedded = self.dropout(self.embedding(input))
        # embedded = [1, bsz, emb_dim]
        attn_weight = self.attention(hidden, encoder_outputs)
        # attn_weight = (batch_size, seq_len)
        # Weighted sum of encoder outputs -> context vector.
        # bmm: (bsz, 1, seq_len) x (bsz, seq_len, hid) -> (bsz, 1, hid),
        # then transpose to (1, bsz, hid) to match embedded's layout.
        context = attn_weight.unsqueeze(1).bmm(encoder_outputs.transpose(0, 1)).transpose(0, 1)
        emb_con = torch.cat((embedded, context), dim=2)
        # emb_con = [1, bsz, emb_dim + hid_dim]
        _, hidden = self.rnn(emb_con, hidden)
        # hidden = [n_layers*n_direction, batch_size, hid_dim]
        output = torch.cat((embedded.squeeze(0), hidden[-1], context.squeeze(0)), dim=1)
        output = F.log_softmax(self.out(output), 1)
        # output = (batch_size, vocab_size) log-probabilities
        return output, hidden, attn_weight
我們定義一個Seq2Seq類,將encoder和decoder結合起來,通過一個循環(huán),模型對每一個batch從前往后依次生成序列,訓練的時候可以使用teacher_forcing隨機使用真實詞或是模型輸出的詞作為target,測試的時候就不需要了。
class Seq2Seq(nn.Module):
    """Wires the encoder and the attention decoder into one seq2seq model."""

    def __init__(self, encoder, decoder, device, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, src_seqs, src_lengths, trg_seqs):
        """Teacher-forced decoding for training.

        src_seqs / trg_seqs: (seq_len, batch_size).
        Returns (trg_len, batch, vocab) log-probabilities; row 0 (the <sos>
        step) is left as zeros and is skipped by the loss.
        """
        batch_size = src_seqs.shape[1]
        max_len = trg_seqs.shape[0]
        vocab_size = self.decoder.output_dim
        # Buffer for every decoder step's log-probabilities.
        outputs = torch.zeros(max_len, batch_size, vocab_size).to(self.device)
        # hidden seeds the decoder; encoder_outputs feed the attention.
        encoder_outputs, hidden = self.encoder(src_seqs, src_lengths)
        # The first decoder input is the <sos> row of the targets.
        step_input = trg_seqs[0, :]
        for t in range(1, max_len):  # skip the <sos> position
            step_output, hidden, _ = self.decoder(step_input, hidden, encoder_outputs)
            outputs[t] = step_output
            # `random` is numpy.random (star import): coin-flip teacher forcing.
            use_teacher = random.random() < self.teacher_forcing_ratio
            step_input = trg_seqs[t] if use_teacher else step_output.max(1)[1]
        return outputs

    def predict(self, src_seqs, src_lengths, max_trg_len=20, start_ix=1):
        """Greedy decoding; also returns per-step attention weights."""
        max_src_len = src_seqs.shape[0]
        batch_size = src_seqs.shape[1]
        vocab_size = self.decoder.output_dim
        outputs = torch.zeros(max_trg_len, batch_size, vocab_size).to(self.device)
        encoder_outputs, hidden = self.encoder(src_seqs, src_lengths)
        # Start every sequence in the batch from <sos>.
        step_input = torch.LongTensor([start_ix] * batch_size).to(self.device)
        attn_weights = torch.zeros((max_trg_len, batch_size, max_src_len))
        for t in range(1, max_trg_len):
            step_output, hidden, attn_weight = self.decoder(step_input, hidden, encoder_outputs)
            outputs[t] = step_output
            step_input = step_output.max(1)[1]  # greedy argmax token
            attn_weights[t] = attn_weight
        return outputs, attn_weights
模型訓練:
直接使用1000個batch進行更新
import torch.optim as optim

embedding_dim = 100
hidden_dim = 100
batch_size = 256
clip = 5  # gradient-norm clipping threshold

# +1 on the vocab sizes keeps the original embedding sizing — presumably a
# safety margin on top of the indices produced by build_vocab (TODO confirm).
encoder = Encoder(len(src_c2ix) + 1, embedding_dim, hidden_dim)
decoder = Decoder(len(trg_c2ix) + 1, embedding_dim, hidden_dim)
model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = optim.Adam(model.parameters())
# NLLLoss pairs with the decoder's log_softmax; pad index 0 is ignored.
criterion = nn.NLLLoss(ignore_index=0).to(device)

model.train()
for batch_id in range(1, 1001):  # train directly on 1000 random batches
    src_seqs, src_lengths, trg_seqs, _ = random_batch(batch_size, pairs, src_c2ix, trg_c2ix)
    optimizer.zero_grad()
    output = model(src_seqs, src_lengths, trg_seqs)
    # Flatten (seq_len, batch, vocab) vs (seq_len, batch) for the loss.
    loss = criterion(output.view(-1, output.shape[2]), trg_seqs.view(-1))
    loss.backward()
    # Clip the global gradient norm to curb exploding gradients in the RNNs.
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    if batch_id % 100 == 0:
        # .item() extracts the Python float; formatting the tensor itself
        # with '{:.4f}' is fragile across torch versions.
        print('current loss:{:.4f}'.format(loss.item()))
        torch.save(model, 'model.pth')
current loss:0.8211
current loss:0.3182
current loss:0.2070
current loss:0.1032
current loss:0.0706
current loss:0.0345
current loss:0.0343
current loss:0.0215
current loss:0.0108
current loss:0.0169
c:\users\administrator\appdata\local\programs\python\python36\lib\site-packages\torch\serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Seq2Seq. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
c:\users\administrator\appdata\local\programs\python\python36\lib\site-packages\torch\serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Encoder. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
c:\users\administrator\appdata\local\programs\python\python36\lib\site-packages\torch\serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Decoder. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
c:\users\administrator\appdata\local\programs\python\python36\lib\site-packages\torch\serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Attention. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
進行測試:
主要實驗可視化attention權重
def show_attention(input_words, output_words, attentions):
    """Render the attention matrix as a heat map.

    Input characters run along the x axis, generated output characters
    along the y axis; brighter cells mean higher attention weight.
    """
    # High-DPI settings for both on-screen rendering and saved images.
    plt.rcParams.update({'savefig.dpi': 300, 'figure.dpi': 300})
    fig = plt.figure()
    ax = fig.add_subplot(111)
    heat = ax.matshow(attentions, cmap='bone')  # visualise the matrix
    fig.colorbar(heat)
    # Leading '' compensates for matshow's blank first tick position.
    ax.set_xticklabels([''] + input_words)
    ax.set_yticklabels([''] + output_words)
    # Put a tick (and thus a label) at every cell.
    ax.xaxis.set_major_locator(ticker.MultipleLocator())
    ax.yaxis.set_major_locator(ticker.MultipleLocator())
    plt.show()
    plt.close()
def evaluate(model, text, src_c2ix, trg_ix2c):
    """Run greedy inference on one source text and visualise its attention."""
    model.eval()
    with torch.no_grad():
        # Shape the single example as (seq_len, batch=1) and move to the device.
        seq = torch.LongTensor(indexes_from_text(text, src_c2ix)).view(-1, 1).to(device)
        outputs, attn_weights = model.predict(seq, [seq.size(0)], max_trg_len)
        # Drop the batch dimension and move to numpy for plotting.
        outputs = outputs.squeeze(1).cpu().numpy()
        attn_weights = attn_weights.squeeze(1).cpu().numpy()
        # Greedy pick: the highest-scoring target character at each step.
        predicted = [trg_ix2c[np.argmax(step_probs)] for step_probs in outputs]
        show_attention(list('^' + text + '$'), predicted, attn_weights)
# Qualitative spot checks with attention visualisation. The odd spellings in
# the first example are kept exactly as in the original tutorial input.
text = 'thirsty 1 before 3 clock affternoon'
evaluate(model,text,src_c2ix,trg_ix2c)
text = 'forty seven min before 10 p.m'
evaluate(model,text,src_c2ix,trg_ix2c)