RNN與LSTM
RNN網(wǎng)絡(luò)是在傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的基礎(chǔ)上加入了記憶的成分。對(duì)于RNN模型來(lái)說(shuō),序列被看做一系列隨著時(shí)間步長(zhǎng)遞進(jìn)的事件序列。這里的時(shí)間步長(zhǎng)并不是真實(shí)世界中所指的時(shí)間,而是指序列中的位置。RNN模型的特殊結(jié)構(gòu)可以讓他處理相互依賴的時(shí)間序列及變長(zhǎng)數(shù)據(jù)。長(zhǎng)期依賴對(duì)于文本理解是不可回避的問(wèn)題,但普通RNN結(jié)構(gòu)并不能很好的處理這個(gè)問(wèn)題,由于RNN的參數(shù)共享,在狀態(tài)傳遞的過(guò)程中會(huì)發(fā)生梯度消失或爆炸的問(wèn)題。LSTM就是為了解決長(zhǎng)期依賴問(wèn)題而產(chǎn)生的。與普通RNN相比,最主要的改進(jìn)就是多出了三個(gè)門控制器:輸入門、輸出門、遺忘門。
有關(guān)RNN與LSTM的具體數(shù)學(xué)推導(dǎo)可見(jiàn)相關(guān)技術(shù)博客,這里不做詳細(xì)闡述。
分類
這里用mnist數(shù)據(jù)集來(lái)做RNN的分類。RNN通常的輸入是三維張量[batch_size, step_time, cell.input_size],這里把28*28的圖片的每一行作為輸入,28行作為step_time,用128張圖片作為一個(gè)batch。
首先導(dǎo)入mnist數(shù)據(jù)集,設(shè)置超參數(shù)和占位符。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
lr = 0.001
n_input = 28
n_step = 28
n_digit = 10
n_cell = 128
n_batch = 128
train_times = 100000
x = tf.placeholder(tf.float32, [None, n_step, n_input])
y = tf.placeholder(tf.float32, [None, n_digit])
這里設(shè)計(jì)LSTM網(wǎng)絡(luò)來(lái)對(duì)圖片分類,LSTM的核心是一個(gè)隱藏的神經(jīng)層cell,包括各種門的參數(shù)和激活函數(shù),在cell的前后,還各需要輸入和輸出的網(wǎng)絡(luò)層。由于輸入是三維的張量,在進(jìn)行輸入時(shí),需要將其reshape成二維張量[batch_size*step_time, cell.input_size],再進(jìn)行權(quán)重計(jì)算。
x_new = tf.reshape(x, [-1, n_input])
cell_in = tf.layers.dense(x_new, n_cell)
cell_in = tf.reshape(cell_in, [-1, n_step, n_cell])
現(xiàn)在TensorFlow1.2將隱層網(wǎng)絡(luò)的設(shè)計(jì)進(jìn)行了封裝,可以直接調(diào)用tf.layers.dense.如下:
dense(
inputs,
units,
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
reuse=None
)
inputs是輸入的張量,units是神經(jīng)元的個(gè)數(shù),activation是激活函數(shù),默認(rèn)沒(méi)有激活函數(shù)。這里只需要將輸入進(jìn)行線性組合,所以不需要激活函數(shù)。
然后設(shè)計(jì)LSTM cell。
cell = tf.contrib.rnn.BasicLSTMCell(n_cell)
# Args:
# num_units: int, The number of units in the LSTM cell.
# forget_bias: float, The bias added to forget gates (see above).
# state_is_tuple: If True, accepted and returned states are 2-tuples of the c_state and m_state.
# If False, they are concatenated along the column axis. The latter behavior will soon be deprecated.
# activation: Activation function of the inner states. Default: tanh.
# reuse: (optional) Python boolean describing whether to reuse variables
# in an existing scope. If not True, and the existing scope already has
# the given variables, an error is raised.
init_state = cell.zero_state(n_batch, dtype=tf.float32)
output, state = tf.nn.dynamic_rnn(
cell, cell_in, initial_state=init_state, time_major=False)
調(diào)用tf.contrib.rnn.BasicLSTMCell設(shè)計(jì)cell。num_units是cell中神經(jīng)元的個(gè)數(shù),forget_bias默認(rèn)為1,表示遺忘門的初始值為1,表示遺忘之前的輸入聯(lián)系。state_is_tuple默認(rèn)為true,表示輸出的state是一個(gè)tuple,包含兩個(gè)list,state[0]是cell中的狀態(tài),cell[1]是輸出的狀態(tài)。開(kāi)始時(shí)需要初始化狀態(tài),如第二行代碼,然后調(diào)用tf.nn.dynamic_rnn得到RNN層的結(jié)果,output的shape是[batch_size, step_time, cell],state的shape是[batch_size, cell],cell表示RNN中的神經(jīng)元個(gè)數(shù)。
cell后再設(shè)計(jì)一個(gè)輸出層,與上類似
cell_out = tf.layers.dense(state[1], n_digit)
這里,使用的是state[1],也可以使用output,但是output是三維張量,對(duì)應(yīng)了每個(gè)step的輸出,所以需要shape之后再用output[-1].
網(wǎng)絡(luò)層設(shè)計(jì)好之后,就是模型的代價(jià)函數(shù)設(shè)計(jì),訓(xùn)練和評(píng)估了。
loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=cell_out))
train = tf.train.AdamOptimizer(lr).minimize(loss)
acc = tf.reduce_mean(
tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(cell_out, 1)), tf.float32))
with tf.Session() as sess:
tf.global_variables_initializer().run()
# print(sess.run(output, {x: mnist.train.images[
# :128].reshape([128, n_step, n_input])}).shape)
# output.shape:n_batch,n_step,n_cell >>= time_major=False
# print(sess.run(state[1], {x: mnist.train.images[
# :128].reshape([128, n_step, n_input])}).shape)
# state[1].shape:n_batch,n_cell
for i in range(train_times):
xs, ys = mnist.train.next_batch(n_batch)
xs = xs.reshape([n_batch, n_step, n_input])
sess.run(train, {x: xs, y: ys})
if i % 200 == 0:
print(sess.run(acc, {x: xs, y: ys}))
note
- 在圖片分類時(shí),由于每一張圖片前后沒(méi)有聯(lián)系,所以在初始化狀態(tài),使初始狀態(tài)為0后,并不需要改變,但在一些時(shí)間序列問(wèn)題上,每一個(gè)batch相互聯(lián)系,所以在第一次batch初始化狀態(tài)后,使后一次的狀態(tài)是前一次的輸出狀態(tài)state。
- RNN的輸入是[batch_size, step_time, cell.input_size],在每一次batch后得到一個(gè)輸出結(jié)果[batch_size,output],若選擇dynamic_rnn return的output進(jìn)行下一步計(jì)算,則輸出結(jié)果為[batch_size, step_time, output],選擇output[-1]時(shí),結(jié)果與用state[1]一致。