本文改編自TensorFLow官方教程中文版,力求更加簡潔、清晰。
一、介紹
TensorFlow是當(dāng)前最流行的機器學(xué)習(xí)框架,有了它,開發(fā)人工智能程序就像Java編程一樣簡單。今天,就讓我們從手寫體數(shù)字識別入手,看看如何用機器學(xué)習(xí)的方法解決這個問題。
二、編程環(huán)境
Python2.7+TensorFlow0.5.0下測試通過,Python3.5下未測試。請參考《TensorFLow下載與安裝》配置環(huán)境。
三、思路
沒有接觸過圖像處理的人可能會很納悶,從一張圖片識別出里面的內(nèi)容似乎是件相當(dāng)神奇的事情。其實,當(dāng)你把圖片當(dāng)成一枚枚像素來看的話,就沒那么神秘了。下圖為手寫體數(shù)字1的圖片,它在計算機中的存儲其實是一個二維矩陣,每個元素都是0~1之間的數(shù)字,0代表白色,1代表黑色,小數(shù)代表某種程度的灰色。

現(xiàn)在,對于MNIST數(shù)據(jù)集中的圖片來說,我們只要把它當(dāng)成長度為784的向量就可以了(忽略它的二維結(jié)構(gòu),28×28=784)。我們的任務(wù)就是讓這個向量經(jīng)過一個函數(shù)后輸出一個類別,吶,就是下邊這個函數(shù),稱為Softmax分類器。

這個式子里的圖片向量的長度只有3,用x表示。乘上一個系數(shù)矩陣W,再加上一個列向量b,然后輸入softmax函數(shù),輸出就是分類結(jié)果y。W是一個權(quán)重矩陣,W的每一行與整個圖片像素相乘的結(jié)果是一個分?jǐn)?shù)score,分?jǐn)?shù)越高表示圖片越接近該行代表的類別。因此,W x + b 的結(jié)果其實是一個列向量,每一行代表圖片屬于該類的評分。熟悉圖像分類的同學(xué)應(yīng)該了解,通常分類的結(jié)果并非評分,而是概率,表示有多大的概率屬于此類別。因此,Softmax函數(shù)的作用就是把評分轉(zhuǎn)換成概率,并使總的概率為1。
有了這個模型,如何訓(xùn)練它呢?
對于機器學(xué)習(xí)算法來說,訓(xùn)練就是不斷調(diào)整模型參數(shù)使誤差達到最小的過程。這里的模型參數(shù)就是W和b。接下來我們需要定義誤差。誤差當(dāng)然是把預(yù)測的結(jié)果y和正確結(jié)果相比較得到的,但是由于正確結(jié)果是one_hot向量(即只有一個元素是1,其它元素都是0),而預(yù)測結(jié)果是個概率向量,用什么方法比較其實是個需要深入考慮的事情。事實上,我們使用的是交叉熵?fù)p失(cross-entropy loss),為什么用這個,其實我現(xiàn)在也不太清楚,所以姑且先用著吧,以后見得多了自然就明白了。
好了,到這里思路大體上就講完了,還有不清楚的地方讓我們看看代碼就能理解了。
四、TensorFlow實現(xiàn)
說實話,這個代碼比想象中還要簡練,只有33行,所以我把它直接貼出來。
# coding=utf-8
import tensorflow as tf
import input_data
# 下載MNIST數(shù)據(jù)集到'MNIST_data'文件夾并解壓
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 設(shè)置權(quán)重weights和偏置biases作為優(yōu)化變量,初始值設(shè)為0
weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))
# 構(gòu)建模型
x = tf.placeholder("float", [None, 784])
y = tf.nn.softmax(tf.matmul(x, weights) + biases) # 模型的預(yù)測值
y_real = tf.placeholder("float", [None, 10]) # 真實值
cross_entropy = -tf.reduce_sum(y_real * tf.log(y)) # 預(yù)測值與真實值的交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 使用梯度下降優(yōu)化器最小化交叉熵
# 開始訓(xùn)練
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100) # 每次隨機選取100個數(shù)據(jù)進行訓(xùn)練,即所謂的“隨機梯度下降(Stochastic Gradient Descent,SGD)”
sess.run(train_step, feed_dict={x: batch_xs, y_real:batch_ys}) # 正式執(zhí)行train_step,用feed_dict的數(shù)據(jù)取代placeholder
if i % 100 == 0:
# 每訓(xùn)練100次后評估模型
correct_prediction = tf.equal(tf.argmax(y, 1), tf.arg_max(y_real, 1)) # 比較預(yù)測值和真實值是否一致
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # 統(tǒng)計預(yù)測正確的個數(shù),取均值得到準(zhǔn)確率
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_real: mnist.test.labels})
這里用到了官方給的一個代碼文件input_data,我已經(jīng)放到工程里了。導(dǎo)入input_data,就可以用它來讀取MNIST數(shù)據(jù)集,非常方便。
整體來說,使用TensorFLow編程主要分為兩個階段,第一個階段是構(gòu)建模型,把網(wǎng)絡(luò)模型用代碼搭建起來。TensorFlow的本質(zhì)是數(shù)據(jù)流圖,因此這一階段其實是在規(guī)定數(shù)據(jù)的流動方向。第二個階段是開始訓(xùn)練,把數(shù)據(jù)輸入到模型中,并通過梯度下降等方法優(yōu)化變量的值。
首先,我們需要把權(quán)重weights和偏置biases設(shè)置成優(yōu)化變量,只有優(yōu)化變量才可以在后面被Optimizer優(yōu)化。并且需要為它們賦初值,這里將weights設(shè)為784×10的zero矩陣,把biases設(shè)為1×10的zero矩陣。
然后構(gòu)建模型。模型的輸入一般設(shè)置為placeholder,譯為占位符。在訓(xùn)練的過程中只有placeholder可以允許數(shù)據(jù)輸入。第一維的長度為None表示允許輸入任意長度,也就是說輸入可以是任意張圖像。
使用tf.log計算y中每個元素的對數(shù),并逐個與y_real相乘,再求和并取反,就得到了交叉熵。使用梯度下降優(yōu)化器最小化交叉熵作為訓(xùn)練步驟train_step。
接下來開始訓(xùn)練。首先要調(diào)用tf.initialize_all_variables()方法初始化所有變量。再創(chuàng)建一個tf.Session對象來控制整個訓(xùn)練流程。循環(huán)訓(xùn)練1000次,每次從訓(xùn)練集中隨機取100個數(shù)據(jù)進行訓(xùn)練。
在訓(xùn)練的過程中,每隔100次對模型進行一次評估。評估使用測試集數(shù)據(jù),統(tǒng)計正確預(yù)測的個數(shù)的百分比并輸出。結(jié)果如下:
$ /usr/bin/python2.7 /home/wjg/projects/MNISTRecognition/main.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.894
0.8989
0.9012
0.904
0.9105
0.9086
0.9137
0.9105
0.9174
Process finished with exit code 0
可見預(yù)測準(zhǔn)確率逐漸上升,最后達到91%。
五、總結(jié)
這是我第一次使用TensorFlow,它給我的感覺是非常方便,很貼合程序員的開發(fā)習(xí)慣。相比之下,之前用Caffe的時候就總是摸不著頭腦。當(dāng)然也可能是因為TensorFlow的官方文檔更友好的緣故。
本文在很多地方都語焉不詳,因為作者水平有限,有關(guān)深奧的數(shù)學(xué)原理都一帶而過。所以如果想要深入了解,還是推薦大家看官方教程。文末的參考資料一欄列出了一些有幫助的文章和視頻。
最后,可以從我的GitHub上下載完整代碼:https://github.com/jingedawang/MNISTRecognition
另外,熟悉多維矩陣操作(NumPy中的切片和廣播)可以更好的地理解代碼,建議閱讀參考資料最后一條:P
六、參考資料
MNIST機器學(xué)習(xí)入門 TensorFlow中文社區(qū)
莫煩 Tensorflow 16 Classification 分類學(xué)習(xí) 莫煩
Classification 分類學(xué)習(xí) 莫煩
Softmax 函數(shù)的特點和作用是什么? 知乎
CS231n課程筆記翻譯:線性分類筆記(下) 杜客譯
CS231n課程筆記翻譯:Python Numpy教程 杜客譯