RNN: Theoretical Foundations
Basic RNN Structure

rnn_base.png
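The basic structure of an RNN is shown in the left part of the figure above: the output depends not only on the current input but also on the state at the previous time step. Unrolled over time, the RNN looks like the right part of the figure, and propagation proceeds as follows:
- $I_{n}$ is the input at the current step
- $S_{n}$ is the current state, which depends on the current input and the previous state: $S_{n} = f(W * S_{n-1} + U * I_{n})$, where $f(x)$ is an activation function
- $O_{n}$ is the current output, which depends on the state: $O_{n} = g(V * S_{n})$, where $g(x)$ is an activation function
All time steps share the same parameters U, W and V.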
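The propagation rules above can be traced with a minimal scalar sketch. It assumes $f = \tanh$ and an identity $g$ (the text leaves both activations open), uses scalar weights instead of matrices, and the helper name `rnn_forward` is hypothetical; the point is only that the same W, U, V are reused at every step.

```python
import math


def rnn_forward(inputs, W, U, V, s0=0.0, f=math.tanh, g=lambda z: z):
    """Scalar RNN sketch: S_n = f(W*S_{n-1} + U*I_n), O_n = g(V*S_n).

    Assumptions: f = tanh and g = identity by default; W, U, V are scalars.
    The same W, U, V are used at every time step (parameter sharing).
    """
    s = s0
    states, outputs = [], []
    for x in inputs:
        s = f(W * s + U * x)      # new state from the previous state and the current input
        states.append(s)
        outputs.append(g(V * s))  # output depends only on the current state
    return states, outputs


states, outputs = rnn_forward([1.0, 0.0, -1.0], W=0.5, U=1.0, V=2.0)
print(states)
print(outputs)
```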
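When the input is very long, the information about the earliest inputs held in the RNN state gets "forgotten", so a plain RNN cannot handle very long inputs well.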
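The forgetting effect can be made concrete with the same kind of scalar sketch: run a tanh RNN on two sequences that differ only in their first input and measure how far apart the final states end up. The weights (W = 0.5, U = 1.0), the filler input 0.5 and the helper names are illustrative assumptions. Because tanh never amplifies a difference, each step multiplies the gap between the two states by at most |W| = 0.5, so the first input's influence decays geometrically with sequence length.

```python
import math


def final_state(inputs, W=0.5, U=1.0):
    """Run a scalar tanh RNN from S_0 = 0 and return the last state (assumed weights)."""
    s = 0.0
    for x in inputs:
        s = math.tanh(W * s + U * x)
    return s


def first_input_influence(length):
    """How much the final state changes when only the first input is flipped from +1 to -1."""
    tail = [0.5] * (length - 1)  # identical filler inputs after the first step
    return abs(final_state([1.0] + tail) - final_state([-1.0] + tail))


for n in (2, 5, 10, 20):
    print(n, first_input_influence(n))
```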
Basic LSTM Structure

lstm_base.png
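An LSTM is a special RNN cell designed to retain long-term memory. Its forward pass is as follows, where $\sigma$ is the sigmoid function, $[h_{n-1}, x_{n}]$ is the concatenation of the previous output and the current input, $*$ is matrix multiplication as in the RNN formulas, and $\odot$ is element-wise multiplication:
- Forget: decides how much of the previous state is forgotten. This is done by the forget gate layer, $f_{n} = \sigma(W_{f} * [h_{n-1}, x_{n}] + b_{f})$; its output is multiplied element-wise with $C_{n-1}$, which attenuates the state
- Input: decides which new information is added to the state. This is done by the input value layer and the input gate layer:
  - The input value layer produces the candidate new data, $CX_{n} = \tanh(W_{c} * [h_{n-1}, x_{n}] + b_{c})$
  - The input gate layer decides which of the new data enters the state, $I_{n} = \sigma(W_{i} * [h_{n-1}, x_{n}] + b_{i})$
  - The state is then updated as $C_{n} = C_{n-1} \odot f_{n} + I_{n} \odot CX_{n}$
- Output: decides which parts of the state are output. This is done by the output gate layer:
  - The output gate layer computes $O_{n} = \sigma(W_{o} * [h_{n-1}, x_{n}] + b_{o})$
  - The final output is $h_{n} = O_{n} \odot \tanh(C_{n})$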
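The equations above can be checked with a plain-Python sketch of a single LSTM step. Vectors are Python lists, each gate's parameters are a hypothetical `(W, b)` pair where W is a list of rows of length `hidden + input_dim`, and the helper names are illustrative rather than any framework's API.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """Compute W * v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]


def lstm_step(x, h_prev, c_prev, params):
    """One LSTM step following the equations above; params maps gate name -> (W, b)."""
    hx = h_prev + x                                        # concatenation [h_{n-1}, x_n]
    f = [sigmoid(z) for z in affine(*params["f"], hx)]     # forget gate f_n
    i = [sigmoid(z) for z in affine(*params["i"], hx)]     # input gate I_n
    cx = [math.tanh(z) for z in affine(*params["c"], hx)]  # candidate values CX_n
    o = [sigmoid(z) for z in affine(*params["o"], hx)]     # output gate O_n
    c = [cp * fg + ig * cv for cp, fg, ig, cv in zip(c_prev, f, i, cx)]  # C_n
    h = [og * math.tanh(cn) for og, cn in zip(o, c)]       # h_n
    return h, c


# hidden = input = 1, all weights and biases zero
zero = {gate: ([[0.0, 0.0]], [0.0]) for gate in ("f", "i", "c", "o")}
h, c = lstm_step([0.3], [0.0], [2.0], zero)
print(h, c)
```

With all parameters at zero, every gate outputs $\sigma(0) = 0.5$ and the candidate is $\tanh(0) = 0$, so the state is simply halved: starting from $C_{n-1} = 2$ the step gives $C_{n} = 1$ and $h_{n} = 0.5\tanh(1)$.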
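There are four parameter pairs in total, as listed in the table below:
| Purpose | Parameter pair |
|---|---|
| Forget gate layer, decides which parts of the state are forgotten | $W_{f}$, $b_{f}$ |
| Input gate layer, decides which new input is accumulated into the state | $W_{i}$, $b_{i}$ |
| Input value layer, produces the candidate new input | $W_{c}$, $b_{c}$ |
| Output gate layer, decides which parts of the state are output | $W_{o}$, $b_{o}$ |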
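Counting the entries of the four pairs in the table gives the size of one LSTM layer. The sketch below assumes the formulation above (each W of shape hidden × (hidden + input_dim), each b of length hidden); the helper name is hypothetical. Frameworks may lay the parameters out differently: MXNet's `LSTMCell`, for instance, keeps separate input-to-hidden and hidden-to-hidden weights (`i2h_weight`, `h2h_weight`), each with its own bias, which adds another 4 × hidden bias entries.

```python
def lstm_param_count(input_dim, hidden):
    """Number of parameters in one LSTM layer under the formulation above:
    4 gate layers, each with W of shape hidden x (hidden + input_dim) and b of length hidden."""
    per_gate = hidden * (hidden + input_dim) + hidden
    return 4 * per_gate


# Sizes used later in this post: num_embed = num_hidden = 256
print(lstm_param_count(256, 256))  # 525312
```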
Code Implementation
import mxnet as mx
Loading the Data
Downloading the data
import os
import requests
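def download_data(url,name):
    # Download url to the local file `name`, skipping files that already exist
    if not os.path.exists(name):
        file_content = requests.get(url)
        file_content.raise_for_status()  # stop on HTTP errors instead of saving an error page
        with open(name,"wb") as f:
            f.write(file_content.content)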
download_data("https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt","./ptb.train.txt")
download_data("https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt","./ptb.valid.txt")
download_data("https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt","./ptb.test.txt")
download_data("https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt","./input.txt")
Data processing function
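def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    # Read the corpus and split each line into words, dropping empty strings
    with open(fname) as f:
        lines = f.readlines()
    lines = [list(filter(None, i.split(' '))) for i in lines]
    # Map every word to an integer id; the trailing '\n' token is mapped to invalid_label
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label,
                                               start_label=start_label)
    return sentences, vocab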
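To see what `mx.rnn.encode_sentences` produces, here is a simplified pure-Python stand-in; the name `encode_sentences_sketch` is hypothetical and this is not MXNet's code. It follows the behaviour visible in this post's output: ids are assigned from `start_label` upward in order of first appearance, and the newline token left at the end of each line is mapped to `invalid_label`, which is why every encoded sentence ends in 0.

```python
def encode_sentences_sketch(sentences, vocab=None, invalid_label=0, invalid_key="\n", start_label=1):
    """Simplified stand-in for mx.rnn.encode_sentences: map each word to an integer id.

    A new vocabulary is built when vocab is None; with a given vocabulary,
    unknown words raise KeyError. invalid_key is always mapped to invalid_label.
    """
    new_vocab = vocab is None
    if new_vocab:
        vocab = {invalid_key: invalid_label}
    next_id = start_label
    encoded = []
    for sent in sentences:
        ids = []
        for word in sent:
            if word not in vocab:
                if not new_vocab:
                    raise KeyError("unknown token %r" % word)
                if next_id == invalid_label:  # never hand out the reserved id
                    next_id += 1
                vocab[word] = next_id
                next_id += 1
            ids.append(vocab[word])
        encoded.append(ids)
    return encoded, vocab


encoded, vocab_demo = encode_sentences_sketch([["the", "cat", "\n"], ["the", "dog", "\n"]])
print(encoded)  # [[1, 2, 0], [1, 3, 0]]
```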
Building the data iterators
start_label = 1
invalid_label = 0
train_sent, vocab = tokenize_text("./ptb.train.txt", start_label=start_label,invalid_label=invalid_label)
val_sent, _ = tokenize_text("./ptb.test.txt", vocab=vocab, start_label=start_label,invalid_label=invalid_label)
print(type(vocab),len(vocab))
<class 'dict'> 10000
print(type(train_sent),train_sent[:5])
<class 'list'> [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 0], [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 27, 0], [39, 26, 40, 41, 42, 26, 43, 32, 44, 45, 46, 0], [47, 26, 27, 28, 29, 48, 49, 41, 42, 50, 51, 52, 53, 54, 55, 35, 36, 37, 42, 56, 57, 58, 59, 0], [35, 60, 42, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 35, 71, 72, 42, 73, 74, 75, 35, 46, 42, 76, 77, 64, 78, 79, 80, 27, 28, 81, 82, 83, 0]]
batch_size = 50
buckets = [10,20,40,60,80]
# buckets = None
data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets,invalid_label=invalid_label)
data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,invalid_label=invalid_label)
WARNING: discarded 4 sentences longer than the largest bucket.
WARNING: discarded 0 sentences longer than the largest bucket.
for _,i in enumerate(data_train):
    print(i.data[0][:2],i.label[0][:2])
    break
data_train.reset()  # rewind the iterator so training starts from the first batch
[[ 1203. 373. 141. 119. 79. 64. 32. 891. 80. 4220.
3864. 119. 1407. 860. 467. 1930. 42. 668. 0. 0.]
[ 35. 114. 81. 5793. 119. 840. 432. 1516. 232. 926.
181. 923. 5845. 225. 98. 0. 0. 0. 0. 0.]]
<NDArray 2x20 @cpu(0)>
[[ 373. 141. 119. 79. 64. 32. 891. 80. 4220. 3864.
119. 1407. 860. 467. 1930. 42. 668. 0. 0. 0.]
[ 114. 81. 5793. 119. 840. 432. 1516. 232. 926. 181.
923. 5845. 225. 98. 0. 0. 0. 0. 0. 0.]]
<NDArray 2x20 @cpu(0)>
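As the output shows, the label of each batch is the data at the next time step: the data shifted left by one word, with the last position filled with invalid_label (0).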
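The padding and label shift can be sketched in plain Python. This mirrors the behaviour visible in the printed batch (sentences padded with `invalid_label` up to a bucket length, labels shifted left by one); the helper names are hypothetical and the sketch is not MXNet's implementation. It assumes each sentence goes into the smallest bucket that fits it, which would also explain the "discarded 4 sentences" warning above: a sentence longer than the largest bucket (80) has nowhere to go.

```python
import bisect


def pad_to_bucket(sent, buckets, invalid_label=0):
    """Pad sent with invalid_label up to the smallest bucket length that fits it.
    Returns None when sent is longer than the largest bucket (it would be discarded)."""
    idx = bisect.bisect_left(buckets, len(sent))
    if idx == len(buckets):
        return None
    return sent + [invalid_label] * (buckets[idx] - len(sent))


def shift_label(data, invalid_label=0):
    """The label at step t is the word at step t + 1; the last step gets invalid_label."""
    return data[1:] + [invalid_label]


demo_buckets = [10, 20, 40, 60, 80]
padded = pad_to_bucket([35, 114, 81, 5793, 0], demo_buckets)
print(padded)               # [35, 114, 81, 5793, 0, 0, 0, 0, 0, 0]
print(shift_label(padded))  # [114, 81, 5793, 0, 0, 0, 0, 0, 0, 0]
```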
Building the model
num_layers = 2
num_hidden = 256
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
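num_embed = 256
def sym_gen(seq_len):
    # Build the network unrolled for a bucket of length seq_len
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    # Map word ids to num_embed-dimensional vectors
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=num_embed, name='embed')
    stack.reset()  # reset the cell's internal counters before unrolling again
    # Unroll the stacked LSTM for seq_len steps; outputs has shape (batch, seq_len, num_hidden)
    outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)
    # Flatten to (batch * seq_len, num_hidden) and project onto the vocabulary
    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')
    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
    # Return the symbol together with the names of its data and label inputs
    return pred, ('data',), ('softmax_label',)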
a,_,_ = sym_gen(1)
mx.viz.plot_network(symbol=a)

model
Training the network
import logging
logging.getLogger().setLevel(logging.DEBUG) # logging to stdout
model = mx.mod.BucketingModule(sym_gen=sym_gen,
                               default_bucket_key=data_train.default_bucket_key,
                               context=mx.gpu())
model.fit(
    train_data=data_train,
    eval_data=data_val,
    eval_metric=mx.metric.Perplexity(invalid_label),
    kvstore='device',
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.01,
                      'momentum': 0.0,
                      'wd': 0.00001},
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    num_epoch=2,
    batch_end_callback=mx.callback.Speedometer(batch_size, 50, auto_reset=False))
WARNING:root:Already bound, ignoring bind()
WARNING:root:optimizer already initialized, ignoring.
INFO:root:Epoch[0] Batch [50] Speed: 240.74 samples/sec perplexity=1230.415304
INFO:root:Epoch[0] Batch [100] Speed: 203.97 samples/sec perplexity=1176.951186
INFO:root:Epoch[0] Batch [150] Speed: 222.01 samples/sec perplexity=1161.217528
INFO:root:Epoch[0] Batch [200] Speed: 214.61 samples/sec perplexity=1130.756199
INFO:root:Epoch[0] Batch [250] Speed: 209.55 samples/sec perplexity=1109.315310
INFO:root:Epoch[0] Batch [300] Speed: 213.95 samples/sec perplexity=1093.083615
INFO:root:Epoch[0] Batch [350] Speed: 232.20 samples/sec perplexity=1084.233586
INFO:root:Epoch[0] Batch [400] Speed: 202.13 samples/sec perplexity=1069.696013
INFO:root:Epoch[0] Batch [450] Speed: 218.14 samples/sec perplexity=1057.711184
INFO:root:Epoch[0] Batch [500] Speed: 236.57 samples/sec perplexity=1048.120406
INFO:root:Epoch[0] Train-perplexity=1044.812667
INFO:root:Epoch[0] Time cost=118.042
INFO:root:Epoch[0] Validation-perplexity=853.844612
INFO:root:Epoch[1] Batch [50] Speed: 228.59 samples/sec perplexity=932.793729
INFO:root:Epoch[1] Batch [100] Speed: 210.51 samples/sec perplexity=933.630035
INFO:root:Epoch[1] Batch [150] Speed: 215.88 samples/sec perplexity=941.272076
INFO:root:Epoch[1] Batch [200] Speed: 226.13 samples/sec perplexity=937.232755
INFO:root:Epoch[1] Batch [250] Speed: 199.27 samples/sec perplexity=926.975004
INFO:root:Epoch[1] Batch [300] Speed: 196.35 samples/sec perplexity=913.408955
INFO:root:Epoch[1] Batch [350] Speed: 216.76 samples/sec perplexity=907.031329
INFO:root:Epoch[1] Batch [400] Speed: 198.65 samples/sec perplexity=899.224687
INFO:root:Epoch[1] Batch [450] Speed: 238.68 samples/sec perplexity=896.943083
INFO:root:Epoch[1] Batch [500] Speed: 205.63 samples/sec perplexity=892.764729
INFO:root:Epoch[1] Batch [550] Speed: 206.36 samples/sec perplexity=888.453916
INFO:root:Epoch[1] Batch [600] Speed: 218.98 samples/sec perplexity=885.808878
INFO:root:Epoch[1] Batch [650] Speed: 229.98 samples/sec perplexity=884.451112
INFO:root:Epoch[1] Batch [700] Speed: 226.57 samples/sec perplexity=882.243212
INFO:root:Epoch[1] Batch [750] Speed: 234.16 samples/sec perplexity=878.481937
INFO:root:Epoch[1] Batch [800] Speed: 218.44 samples/sec perplexity=874.363066
INFO:root:Epoch[1] Train-perplexity=869.764287
INFO:root:Epoch[1] Time cost=194.924
INFO:root:Epoch[1] Validation-perplexity=747.663144
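The perplexity in these logs can be read with a small sketch of the metric, assuming the usual definition: the exponential of the mean negative log-probability the model assigns to each target word, skipping positions whose label equals `invalid_label` (passed to `mx.metric.Perplexity` as the label to ignore). The function below is a hypothetical stand-in working on per-position target probabilities, not MXNet's code. A model that guessed uniformly over the 10000-word vocabulary would score 10000, which puts the validation perplexity of about 748 after two epochs in context. In this setup the end-of-line token is also encoded as 0, so it is skipped along with the padding.

```python
import math


def perplexity(target_probs, labels, ignore_label=0):
    """exp(mean negative log-probability of the target words), ignoring padded positions.

    target_probs[k] is the probability the model gave to the true word at position k,
    and labels[k] is that word's id.
    """
    nll, count = 0.0, 0
    for p, y in zip(target_probs, labels):
        if y == ignore_label:
            continue  # padding (and, here, the end-of-line token) is not scored
        nll -= math.log(max(p, 1e-10))  # clip to avoid log(0)
        count += 1
    return math.exp(nll / count)


print(perplexity([0.25, 0.25, 0.9], [5, 7, 0]))  # about 4: the third position is ignored
print(perplexity([1 / 10000] * 3, [1, 2, 3]))     # about 10000: uniform guess over the vocabulary
```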