前言
一般機(jī)器學(xué)習(xí)框架都使用MNIST作為入門。就像"Hello World"對于任何一門編程語言一樣,要想入門機(jī)器學(xué)習(xí),就先要掌握MNIST。
筆者在學(xué)習(xí)的時(shí)候Tensorflow已經(jīng)成為十分流行的機(jī)器學(xué)習(xí)框架,網(wǎng)上有大量的“資源”,但是大多都限于皮毛。
很多教程就是給你一段代碼然后隨便講兩句,這樣對新手并不友好。
因此我萌生了寫一個(gè)詳解的想法。
筆者是一名網(wǎng)絡(luò)工程在讀大學(xué)生,知識水平有限,未必能做到面面俱到且處處正確,如有錯(cuò)誤請指出。
源代碼
- 訓(xùn)練集
請點(diǎn)擊此處下載。
提取碼:xgpy - 源代碼
在源代碼同一目錄下新建文件夾“訓(xùn)練集”,把百度云連接里面的.gz文件放入該文件夾。
# -*- coding: utf-8 -*-
import tensorflow as tf
import input_data
mnist = input_data.read_data_sets('./訓(xùn)練集', one_hot=True)
'''
#構(gòu)建運(yùn)算圖
'''
# X Y 都是占位符 占位而已 不表示具體的數(shù)據(jù)
x = tf.placeholder("float",[None,784]) # 圖像的大小為784;None表示第一個(gè)維度可以是任意長度
# 一個(gè)Variable代表一個(gè)可修改的張量,它們可以用于計(jì)算輸入值,也可以在計(jì)算中被修改
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
# 計(jì)算交叉熵
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 梯度下降算法(gradient descent algorithm)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# 在運(yùn)行計(jì)算之前,我們需要添加一個(gè)操作來初始化我們創(chuàng)建的變量:
init = tf.global_variables_initializer()
# 在一個(gè)Session里面啟動我們的模型,并且初始化變量:
sess = tf.Session()
sess.run(init)
# 訓(xùn)練模型1000次
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#print('-**-',accuracy,type(accuracy))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
詳解
這一大段代碼實(shí)現(xiàn)的功能是:
建立 y = w*x+b 的模型,其中x是輸入的
可以直觀的看到,以上代碼分為三部分:構(gòu)建圖、定義會話、啟動圖。
構(gòu)建圖
構(gòu)建圖也分為定義變量、定義交叉熵、定義優(yōu)化方法。
- 定義變量
由定義方法分類,本實(shí)例中主要有兩種變量。
第一類是由tf.Variable()定義的w、b
第二類是由tf.placeholder()定義的y_、x
順帶提一句y = tf.nn.softmax(tf.matmul(x,W) + b)是這兩者結(jié)合起來的。
那么這兩類有什么區(qū)別呢?
一般而言,Varibale主要用來保存tensorflow圖中的一些結(jié)構(gòu)中的參數(shù),如本例中的w權(quán)重,b偏置。需要初始化。
plceholder主要用來把要訓(xùn)練/測試的數(shù)據(jù)輸入模型,每次訓(xùn)練plceholder都有不一樣的值。在Session.run(feed_dict={})中的參數(shù)確定實(shí)際的值。
可視化網(wǎng)頁
https://www.cs.ryerson.ca/~aharley/vis/fc/
https://www.cs.ryerson.ca/~aharley/vis/conv/flat.html
https://poloclub.github.io/cnn-explainer/