【轉(zhuǎn)載】用tensorflow搭建RNN(LSTM)進行MNIST 手寫數(shù)字辨識

原文鏈接:http://www.cnblogs.com/sandy-t/p/6930608.html

循環(huán)神經(jīng)網(wǎng)絡(luò)RNN相比傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)在處理序列化數(shù)據(jù)時更有優(yōu)勢,因為RNN能夠?qū)⒓尤肷希ㄏ拢┪男畔⑦M行考慮。一個簡單的RNN如下圖所示:

將這個循環(huán)展開得到下圖:

上一時刻的狀態(tài)會傳遞到下一時刻。這種鏈式特性決定了RNN能夠很好的處理序列化的數(shù)據(jù),RNN 在語音識別,語言建模,翻譯,圖片描述等問題上已經(jīng)取得了很到的結(jié)果。

根據(jù)輸入、輸出的不同和是否有延遲等一些情況,RNN在應(yīng)用中有如下一些形態(tài):

RNN存在的問題

RNN能夠把狀態(tài)傳遞到下一時刻,好像對一部分信息有記憶能力一樣,如下圖:

h3h3的值可能會由x1x1,x2x2的值來決定。

但是,對于一些復(fù)雜場景

由于距離太遠,中間間隔了太多狀態(tài),x1x1,x2x2對ht+1ht+1的值幾乎起不到任何作用。(梯度消失和梯度爆炸)

LSTM(Long Short Term Memory)

由于RNN不能很好地處理這種問題,于是出現(xiàn)了LSTM(Long Short Term Memory)一種加強版的RNN(LSTM可以改善梯度消失問題)。簡單來說就是原始RNN沒有長期的記憶能力,于是就給RNN加上了一些記憶控制器,實現(xiàn)對某些信息能夠較長期的記憶,而對某些信息只有短期記憶能力。

如上圖所示,LSTM中存在Forget Gate,Input Gate,Output Gate來控制信息的流動程度。

RNN:

LSTN:

加號圓圈表示線性相加,乘號圓圈表示用gate來過濾信息。

Understanding LSTM中對LSTM有非常詳細的介紹。(對應(yīng)的中文翻譯

LSTM MNIST手寫數(shù)字辨識

實際上,圖片文字識別這類任務(wù)用CNN來做效果更好,但是這里想要強行用LSTM來做一波。

MNIST_data中每一個image的大小是28*28,以行順序作為序列輸入,即第一行的28個像素作為$x_{0}

,第二行為,第二行為x_1,...,第28行的28個像素作為,...,第28行的28個像素作為x_28$輸入,一個網(wǎng)絡(luò)結(jié)構(gòu)總共的輸入是28個維度為28的向量,輸出值是10維的向量,表示的是0-9個數(shù)字的概率值。這是一個many to one的RNN結(jié)構(gòu)。

下面直接上代碼:

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 參數(shù)設(shè)置

BATCH_SIZE = 100? ? ? ? # BATCH的大小,相當于一次處理50個image

TIME_STEP = 28? ? ? ? ? # 一個LSTM中,輸入序列的長度,image有28行

INPUT_SIZE = 28? ? ? ? # x_i 的向量長度,image有28列

LR = 0.01? ? ? ? ? ? ? # 學(xué)習(xí)率

NUM_UNITS = 100? ? ? ? # 多少個LTSM單元

ITERATIONS=8000? ? ? ? # 迭代次數(shù)

N_CLASSES=10? ? ? ? ? ? # 輸出大小,0-9十個數(shù)字的概率

# 定義 placeholders 以便接收x,y

train_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])? ? ? # 維度是[BATCH_SIZE,TIME_STEP * INPUT_SIZE]

image = tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])? ? ? ? ? ? ? ? ? # 輸入的是二維數(shù)據(jù),將其還原為三維,維度是[BATCH_SIZE, TIME_STEP, INPUT_SIZE]

train_y = tf.placeholder(tf.int32, [None, N_CLASSES])? ? ? ? ? ? ? ? ? ?

# 定義RNN(LSTM)結(jié)構(gòu)

rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=NUM_UNITS)

outputs,final_state = tf.nn.dynamic_rnn(

? ? cell=rnn_cell,? ? ? ? ? ? ? # 選擇傳入的cell

? ? inputs=image,? ? ? ? ? ? ? # 傳入的數(shù)據(jù)

? ? initial_state=None,? ? ? ? # 初始狀態(tài)

? ? dtype=tf.float32,? ? ? ? ? # 數(shù)據(jù)類型

? ? time_major=False,? ? ? ? ? # False: (batch, time step, input); True: (time step, batch, input),這里根據(jù)image結(jié)構(gòu)選擇False

)

output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES)? ? ?

這里outputs,final_state = tf.nn.dynamic_rnn(...).

final_state包含兩個量,第一個為c保存了每個LSTM任務(wù)最后一個cell中每個神經(jīng)元的狀態(tài)值,第二個量h保存了每個LSTM任務(wù)最后一個cell中每個神經(jīng)元的輸出值,所以c和h的維度都是[BATCH_SIZE,NUM_UNITS]。

outputs的維度是[BATCH_SIZE,TIME_STEP,NUM_UNITS],保存了每個step中cell的輸出值h。

由于這里是一個many to one的任務(wù),只需要最后一個step的輸出outputs[:, -1, :],output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES) 通過一個全連接層將輸出限制為N_CLASSES。

loss = tf.losses.softmax_cross_entropy(onehot_labels=train_y, logits=output) # 計算loss

train_op = tf.train.AdamOptimizer(LR).minimize(loss)? ? ? #選擇優(yōu)化方法

correct_prediction = tf.equal(tf.argmax(train_y, axis=1),tf.argmax(output, axis=1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))? #計算正確率

sess = tf.Session()

sess.run(tf.global_variables_initializer())? ? # 初始化計算圖中的變量

for step in range(ITERATIONS):? ? # 開始訓(xùn)練

? ? x, y = mnist.train.next_batch(BATCH_SIZE)?

? ? test_x, test_y = mnist.test.next_batch(5000)

? ? _, loss_ = sess.run([train_op, loss], {train_x: x, train_y: y})

? ? if step % 500 == 0:? ? ? # test(validation)

? ? ? ? accuracy_ = sess.run(accuracy, {train_x: test_x, train_y: test_y})

? ? ? ? print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

訓(xùn)練過程輸出:

train loss: 2.2990 | test accuracy: 0.13

train loss: 0.1347 | test accuracy: 0.96

train loss: 0.0620 | test accuracy: 0.97

train loss: 0.0788 | test accuracy: 0.98

train loss: 0.0160 | test accuracy: 0.98

train loss: 0.0084 | test accuracy: 0.99

train loss: 0.0436 | test accuracy: 0.99

train loss: 0.0104 | test accuracy: 0.98

train loss: 0.0736 | test accuracy: 0.99

train loss: 0.0154 | test accuracy: 0.98

train loss: 0.0407 | test accuracy: 0.98

train loss: 0.0109 | test accuracy: 0.98

train loss: 0.0722 | test accuracy: 0.98

train loss: 0.1133 | test accuracy: 0.98

train loss: 0.0072 | test accuracy: 0.99

train loss: 0.0352 | test accuracy: 0.98

可以看到,雖然RNN是擅長處理序列類的任務(wù),在MNIST手寫數(shù)字圖片辨識這個任務(wù)上,RNN同樣可以取得很高的正確率。

參考:

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

https://yjango.gitbooks.io/superorganism/content/lstmgru.html

參考代碼

https://yjango.gitbooks.io/superorganism/content/lstmgru.html

參考代碼

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容