在我上一篇 構(gòu)建一個(gè)基于 TensorFlow 的 Android 應(yīng)用 的文章最后提到:我們可以通過(guò)對(duì)現(xiàn)有模型進(jìn)行遷移訓(xùn)練(retrain)來(lái)定制我們自己的模型。
下面就通過(guò)對(duì)現(xiàn)有的 Google Inception-V3 模型進(jìn)行 retrain ,對(duì) 5 種花朵樣本數(shù)據(jù)的進(jìn)行訓(xùn)練,來(lái)完成一個(gè)可以識(shí)別五種花朵的模型,并將新訓(xùn)練的模型遷移到 Android 端平臺(tái)。
相關(guān)代碼可查看:GitHub 項(xiàng)目地址
安裝 TensorFlow (Mac 為例)
其他平臺(tái)可以直接參考官網(wǎng)說(shuō)明:Installing TensorFlow
首先檢查系統(tǒng)是否安裝了 Python
要安裝 TensorFlow ,你的系統(tǒng)必須依據(jù)安裝了以下任一 Python 版本:
- Python 2.7
- Python 3.3+
查看 Python 版本的命令:
# Python 2
$ python --version
# Python 3
$ python3 --version
如果你的系統(tǒng)還沒(méi)有安裝符合以上版本的 Python,現(xiàn)在安裝。
檢查 Pip 是否安裝
Pip 是 Python 的安裝和包管理工具,要使用本地 pip 安裝 TensorFlow,系統(tǒng)上必須安裝下面的任一版本的 pip :
- pip for Python 2.7
- pip3 for Python 3.n.
pip 或者 pip3 可能在你安裝 Python 的時(shí)候已經(jīng)安裝了,執(zhí)行以下任一命令確認(rèn)系統(tǒng)上是否安裝了 pip 或 pip3
$ pip -V # for Python 2.7
$ pip3 -V # for Python 3.n
建議使用 pip 或者 pip3 為 8.1 或者更新的版本安裝 TensorFlow,如果沒(méi)有安裝,執(zhí)行以下任一命令安裝或更新:
$ sudo easy_install --upgrade pip
$ sudo easy_install --upgrade six
通過(guò) pip 安裝 TensorFlow
# Python 2
$ pip install tensorflow
# Python 3
$ pip3 install tensorflow
通過(guò)官方樣例測(cè)試 TensorFlow 是否正常安裝
進(jìn)入 Python 環(huán)境后輸入以下代碼,當(dāng)出現(xiàn) “Hello, TensorFlow!” 時(shí)表明已經(jīng)安裝成功,可正常使用 TensorFlow 了。
$ python
...
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> print(sess.run(hello))
Hello, TensorFlow!

準(zhǔn)備訓(xùn)練樣本
前面說(shuō)到我們要訓(xùn)練花朵的識(shí)別,這里我們直接找 Google 提供的一個(gè)訓(xùn)練樣本。我們?yōu)闉槟P瓦w移訓(xùn)練專(zhuān)門(mén)新建一個(gè)文件夾用于存放。
下載并解壓得到訓(xùn)練樣本
$ cd TensorFlowRetrainInceptionV3
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
$ tar xzf flower_photos.tgz
打開(kāi)訓(xùn)練樣本文件夾 flower_photos ,里面有 5 種類(lèi)別的花:daisy(雛菊), dandelion(蒲公英), roses(玫瑰), sunflowers(向日葵) , tulips(郁金香),每個(gè)類(lèi)別的大概有 600-700 張訓(xùn)練樣本圖片。
可以根據(jù)自身情況,減少訓(xùn)練樣本數(shù)量,減少訓(xùn)練時(shí)間。
開(kāi)始訓(xùn)練
下載 retrain 腳本
該腳本會(huì)自動(dòng)下載 google Inception v3 模型相關(guān)文件,retrain.py 是 Google 提供的遷移訓(xùn)練腳本。
$ cd TensorFlowRetrainInceptionV3
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py
啟動(dòng) TensorBoard
TensorBoard 是為 TensorFlow 訓(xùn)練效果提供可視化的工具,具體效果如下圖所示:

$ cd TensorFlowRetrainInceptionV3
$ tensorboard --logdir training_summaries &
啟動(dòng) TensorBoard 會(huì)占用系統(tǒng) 6006 端口 ,再啟動(dòng)一個(gè)新的 TensorBoard 之前,必須要 kill 已在運(yùn)行的 TensorBoard 任務(wù)。
$ pkill -f "tensorboard
啟動(dòng)訓(xùn)練腳本
在運(yùn)行 retrain.py 腳本時(shí),需要配置一些運(yùn)行命令參數(shù),比如指定模型輸入輸出相關(guān)名稱(chēng)和其他訓(xùn)練要求的配置。
$ cd TensorFlowRetrainInceptionV3
$ python3 retrain.py \
--bottleneck_dir=bottlenecks \
--how_many_training_steps=500 \
--model_dir=inception \
--summaries_dir=training_summaries/basic \
--output_graph=retrained_graph.pb \
--output_labels=retrained_labels.txt \
--image_dir=flower_photos
如果不添加
--how_many_training_steps=500配置,默認(rèn)值為4000,會(huì)相當(dāng)耗時(shí),建議測(cè)試階段可以減少這個(gè)值。
啟動(dòng)瀏覽器查看 TensorBoard
等待當(dāng)前目錄下的 bottlenecks 文件夾中的文件生成結(jié)束后,可以啟動(dòng)瀏覽器,在地址欄中輸入 localhost:6006 來(lái)查看訓(xùn)練進(jìn)度。
等到訓(xùn)練完成后,我們將得到新生成的 retrained_labels.txt 和 retrained_graph.pb 這兩個(gè)模型相關(guān)文件。
測(cè)試重新訓(xùn)練后的模型
同樣的,我們先下載測(cè)試模型的腳本 label_image.py,測(cè)試重新訓(xùn)練后的模型的識(shí)別準(zhǔn)確率。
$ cd TensorFlowRetrainInceptionV3
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python3 label_image.py flower_photos/daisy/488202750_c420cbce61.jpg

經(jīng)過(guò)簡(jiǎn)單的實(shí)際測(cè)試,對(duì)已有樣本數(shù)據(jù)的識(shí)別準(zhǔn)確率基本在 90% 以上,可以知道重新訓(xùn)練后模型滿(mǎn)足使用要求,下面就按照前面的 Android 應(yīng)用集成 TensorFlow 教程,將新的模型導(dǎo)入到上面的項(xiàng)目中。
將新訓(xùn)練的 TensorFlow 模型移植到 Android 中
下圖是完成遷移訓(xùn)練后的新模型文件,新打包出來(lái)的 GraphDef 文件(PB文件)達(dá)到了87.5 MB 。考慮到我們要將這個(gè)模型移植到 Android 端去加載,這不僅會(huì)對(duì)應(yīng)用的運(yùn)行內(nèi)存造成巨大壓力,而且會(huì)導(dǎo)致安裝包增大很多,對(duì)于一個(gè)簡(jiǎn)單的花朵識(shí)別應(yīng)用來(lái)說(shuō),現(xiàn)在模型文件有些大了。因此,我們要考慮對(duì)模型文件進(jìn)行優(yōu)化,壓縮它的體積。

優(yōu)化模型文件
如前面所說(shuō),重新訓(xùn)練后的模型移植到 Android 平臺(tái)前需要對(duì)模型文件進(jìn)行優(yōu)化才行,下面我們就來(lái)看看官方推薦的幾種方法。
Optimize for inference
通過(guò)調(diào)用 optimize_for_inference 腳本,會(huì)自動(dòng)刪除模型中輸入層和輸出層之間所有不需要的節(jié)點(diǎn)。
同時(shí)該腳本還做了一些其他優(yōu)化以提高運(yùn)行速度。例如它把顯式批處理標(biāo)準(zhǔn)化運(yùn)算跟卷積權(quán)重進(jìn)行了合并,從而降低了計(jì)算量。
1. 用 bazel 工具構(gòu)建 optimize_for_inference 腳本文件
# 在 tensorflow 項(xiàng)目的根目錄(WORKSPACE 文件所在)執(zhí)行下面的 build 命令
bazel build tensorflow/python/tools:optimize_for_inference
build 完成后腳本文件路徑:tensorflow/python/tools/optimize_for_inference.py
如果還沒(méi)安裝 bazel ,建議先看看前一篇文章
2. 調(diào)用 optimize_for_inference.py 腳本進(jìn)行優(yōu)化
調(diào)用腳本時(shí),我們需要提供幾個(gè)命令參數(shù),比如輸入的 PB 文件路徑,輸出的 PB 文件路徑,輸入節(jié)點(diǎn)名以及輸出節(jié)點(diǎn)名等。
python3 -m tensorflow.python.tools.optimize_for_inference \
--input=retrained_graph.pb \
--output=optimized_graph.pb \
--input_names="Cast" \
--output_names="final_result"
查看腳本執(zhí)行完成后輸出的 optimized_graph.pb 文件

可以看到,經(jīng)過(guò) optimize_for_inference 優(yōu)化過(guò)后的模型依然是非常的大的。經(jīng)過(guò)這一次的優(yōu)化,文件只是變小了一些,但還不足以我們放到手機(jī)端去運(yùn)行,所以我們要進(jìn)一步的壓縮模型,同時(shí)還要保證準(zhǔn)確率。
Quantize the network weights
Android 項(xiàng)目中,通常我們把模型 PB 文件放在 assets 文件夾中加載,其實(shí)不管是直接打包進(jìn) APP 還是進(jìn)入 APP 后再進(jìn)行下載,模型文件占用太大的問(wèn)題還是沒(méi)得到解決。我們知道 Android 的 APK 文件在構(gòu)建過(guò)程中會(huì)進(jìn)行 zip 壓縮。那有沒(méi)有一種行之有效的方法在不過(guò)多的降低精確度的情況下壓縮更大的空間呢?
Google 就提供了這么一個(gè)腳本,經(jīng)過(guò)這個(gè)腳本優(yōu)化后的模型 PB 文件大小不會(huì)改變,但會(huì)有更多的可利用的重復(fù)性,所以在打包構(gòu)建APK 包時(shí)對(duì) PB 文件進(jìn)行 zip 壓縮后,最終按照中的 PB 文件會(huì)縮小大約 3~4 倍的大小。
1. 用 bazel 工具構(gòu)建 quantize_graph 腳本
# 在 tensorflow 項(xiàng)目的根目錄(WORKSPACE 文件所在)執(zhí)行下面的 build 命令
$ bazel build tensorflow/tools/quantization:quantize_graph.py
build 完成后腳本文件路徑:tensorflow/tools/quantization/quantize_graph.py
2. 調(diào)用 quantize_graph 腳本進(jìn)行優(yōu)化
將生成的 quantize_graph.py 文件拷貝到 retrain 文件夾下,在目錄下執(zhí)行腳本。
輸入的參數(shù)依然是:輸入的 PB 文件路徑,輸出的 PB 文件路徑,輸出節(jié)點(diǎn)名,這里還有個(gè)特別的參數(shù) mode ,這個(gè)參數(shù)是告訴腳本我們選擇哪種壓縮方式,這里我們選擇了對(duì)權(quán)重進(jìn)行四舍五入。
python3 -m quantize_graph \
--input=optimized_graph.pb \
--output=rounded_graph.pb \
--output_node_names=final_result \
--mode=weights_rounded
可以看到最終的輸出文件 rounded_graph.pb 大小并沒(méi)有改變,下面我們就將優(yōu)化后的遷移訓(xùn)練模型文件重新導(dǎo)入到我們?cè)瓉?lái)的 Android 項(xiàng)目中。

把新訓(xùn)練的模型導(dǎo)入到 Android 中
同樣的,我們把新訓(xùn)練的模型 pb 文件和 labels 文件復(fù)制到 assets 文件夾下。

因?yàn)樾掠?xùn)練的模型,輸入和輸出層名稱(chēng)也發(fā)生的改變,這里要修改之前 TensorFlowImageClassifier.create 方法傳入的參數(shù)。
/**
* retrained inception-v3 model, flower classifier
*/
private static final int INPUT_SIZE = 299;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 1f;
private static final String INPUT_NAME = "Mul";
private static final String OUTPUT_NAME = "final_result";
private static final String MODEL_FILE = "file:///android_asset/model/rounded_graph.pb";
private static final String LABEL_FILE = "file:///android_asset/model/retrained_labels.txt";
最終打包出來(lái)的 APK 文件,可以看到壓縮后的 pb 文件只有 22 MB

參考
教程:在 Mac OS X 上安裝 TensorFlow
當(dāng)Android開(kāi)發(fā)者遇見(jiàn)TensorFlow
Retrain a tensorflow model based on Inception v3
TensorFlow Mobile模型壓縮