通過(guò)遷移訓(xùn)練來(lái)定制 TensorFlow 模型

在我上一篇 構(gòu)建一個(gè)基于 TensorFlow 的 Android 應(yīng)用 的文章最后提到:我們可以通過(guò)對(duì)現(xiàn)有模型進(jìn)行遷移訓(xùn)練(retrain)來(lái)定制我們自己的模型。

下面就通過(guò)對(duì)現(xiàn)有的 Google Inception-V3 模型進(jìn)行 retrain ,對(duì) 5 種花朵樣本數(shù)據(jù)的進(jìn)行訓(xùn)練,來(lái)完成一個(gè)可以識(shí)別五種花朵的模型,并將新訓(xùn)練的模型遷移到 Android 端平臺(tái)。

相關(guān)代碼可查看:GitHub 項(xiàng)目地址

安裝 TensorFlow (Mac 為例)

其他平臺(tái)可以直接參考官網(wǎng)說(shuō)明:Installing TensorFlow

首先檢查系統(tǒng)是否安裝了 Python

要安裝 TensorFlow ,你的系統(tǒng)必須依據(jù)安裝了以下任一 Python 版本:

  • Python 2.7
  • Python 3.3+

查看 Python 版本的命令:

# Python 2
$ python --version
# Python 3
$ python3 --version

如果你的系統(tǒng)還沒(méi)有安裝符合以上版本的 Python,現(xiàn)在安裝。

檢查 Pip 是否安裝

Pip 是 Python 的安裝和包管理工具,要使用本地 pip 安裝 TensorFlow,系統(tǒng)上必須安裝下面的任一版本的 pip :

  • pip for Python 2.7
  • pip3 for Python 3.n.

pip 或者 pip3 可能在你安裝 Python 的時(shí)候已經(jīng)安裝了,執(zhí)行以下任一命令確認(rèn)系統(tǒng)上是否安裝了 pip 或 pip3

$ pip -V  # for Python 2.7
$ pip3 -V # for Python 3.n 

建議使用 pip 或者 pip3 為 8.1 或者更新的版本安裝 TensorFlow,如果沒(méi)有安裝,執(zhí)行以下任一命令安裝或更新:

$ sudo easy_install --upgrade pip
$ sudo easy_install --upgrade six

通過(guò) pip 安裝 TensorFlow

# Python 2
$ pip install tensorflow
# Python 3
$ pip3 install tensorflow 

通過(guò)官方樣例測(cè)試 TensorFlow 是否正常安裝

進(jìn)入 Python 環(huán)境后輸入以下代碼,當(dāng)出現(xiàn) “Hello, TensorFlow!” 時(shí)表明已經(jīng)安裝成功,可正常使用 TensorFlow 了。

$ python
...
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> print(sess.run(hello))
Hello, TensorFlow!
image.png

準(zhǔn)備訓(xùn)練樣本

前面說(shuō)到我們要訓(xùn)練花朵的識(shí)別,這里我們直接找 Google 提供的一個(gè)訓(xùn)練樣本。我們?yōu)闉槟P瓦w移訓(xùn)練專(zhuān)門(mén)新建一個(gè)文件夾用于存放。

下載并解壓得到訓(xùn)練樣本

$ cd TensorFlowRetrainInceptionV3
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
$ tar xzf flower_photos.tgz

打開(kāi)訓(xùn)練樣本文件夾 flower_photos ,里面有 5 種類(lèi)別的花:daisy(雛菊), dandelion(蒲公英), roses(玫瑰), sunflowers(向日葵) , tulips(郁金香),每個(gè)類(lèi)別的大概有 600-700 張訓(xùn)練樣本圖片。

可以根據(jù)自身情況,減少訓(xùn)練樣本數(shù)量,減少訓(xùn)練時(shí)間。

開(kāi)始訓(xùn)練

下載 retrain 腳本
該腳本會(huì)自動(dòng)下載 google Inception v3 模型相關(guān)文件,retrain.py 是 Google 提供的遷移訓(xùn)練腳本。

$ cd TensorFlowRetrainInceptionV3
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

啟動(dòng) TensorBoard
TensorBoard 是為 TensorFlow 訓(xùn)練效果提供可視化的工具,具體效果如下圖所示:

$ cd TensorFlowRetrainInceptionV3
$ tensorboard --logdir training_summaries &

啟動(dòng) TensorBoard 會(huì)占用系統(tǒng) 6006 端口 ,再啟動(dòng)一個(gè)新的 TensorBoard 之前,必須要 kill 已在運(yùn)行的 TensorBoard 任務(wù)。

 $ pkill -f "tensorboard

啟動(dòng)訓(xùn)練腳本

在運(yùn)行 retrain.py 腳本時(shí),需要配置一些運(yùn)行命令參數(shù),比如指定模型輸入輸出相關(guān)名稱(chēng)和其他訓(xùn)練要求的配置。

$ cd TensorFlowRetrainInceptionV3
$ python3 retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=500 \
  --model_dir=inception \
  --summaries_dir=training_summaries/basic \
  --output_graph=retrained_graph.pb \
  --output_labels=retrained_labels.txt \
  --image_dir=flower_photos

如果不添加--how_many_training_steps=500配置,默認(rèn)值為4000,會(huì)相當(dāng)耗時(shí),建議測(cè)試階段可以減少這個(gè)值。

啟動(dòng)瀏覽器查看 TensorBoard

等待當(dāng)前目錄下的 bottlenecks 文件夾中的文件生成結(jié)束后,可以啟動(dòng)瀏覽器,在地址欄中輸入 localhost:6006 來(lái)查看訓(xùn)練進(jìn)度。

等到訓(xùn)練完成后,我們將得到新生成的 retrained_labels.txtretrained_graph.pb 這兩個(gè)模型相關(guān)文件。

測(cè)試重新訓(xùn)練后的模型

同樣的,我們先下載測(cè)試模型的腳本 label_image.py,測(cè)試重新訓(xùn)練后的模型的識(shí)別準(zhǔn)確率。

$ cd TensorFlowRetrainInceptionV3
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python3 label_image.py flower_photos/daisy/488202750_c420cbce61.jpg

經(jīng)過(guò)簡(jiǎn)單的實(shí)際測(cè)試,對(duì)已有樣本數(shù)據(jù)的識(shí)別準(zhǔn)確率基本在 90% 以上,可以知道重新訓(xùn)練后模型滿(mǎn)足使用要求,下面就按照前面的 Android 應(yīng)用集成 TensorFlow 教程,將新的模型導(dǎo)入到上面的項(xiàng)目中。

將新訓(xùn)練的 TensorFlow 模型移植到 Android 中

下圖是完成遷移訓(xùn)練后的新模型文件,新打包出來(lái)的 GraphDef 文件(PB文件)達(dá)到了87.5 MB 。考慮到我們要將這個(gè)模型移植到 Android 端去加載,這不僅會(huì)對(duì)應(yīng)用的運(yùn)行內(nèi)存造成巨大壓力,而且會(huì)導(dǎo)致安裝包增大很多,對(duì)于一個(gè)簡(jiǎn)單的花朵識(shí)別應(yīng)用來(lái)說(shuō),現(xiàn)在模型文件有些大了。因此,我們要考慮對(duì)模型文件進(jìn)行優(yōu)化,壓縮它的體積。

優(yōu)化模型文件

如前面所說(shuō),重新訓(xùn)練后的模型移植到 Android 平臺(tái)前需要對(duì)模型文件進(jìn)行優(yōu)化才行,下面我們就來(lái)看看官方推薦的幾種方法。

Optimize for inference

通過(guò)調(diào)用 optimize_for_inference 腳本,會(huì)自動(dòng)刪除模型中輸入層和輸出層之間所有不需要的節(jié)點(diǎn)。
同時(shí)該腳本還做了一些其他優(yōu)化以提高運(yùn)行速度。例如它把顯式批處理標(biāo)準(zhǔn)化運(yùn)算跟卷積權(quán)重進(jìn)行了合并,從而降低了計(jì)算量。

1. 用 bazel 工具構(gòu)建 optimize_for_inference 腳本文件

# 在 tensorflow 項(xiàng)目的根目錄(WORKSPACE 文件所在)執(zhí)行下面的 build 命令
bazel build tensorflow/python/tools:optimize_for_inference

build 完成后腳本文件路徑:tensorflow/python/tools/optimize_for_inference.py

如果還沒(méi)安裝 bazel ,建議先看看前一篇文章

2. 調(diào)用 optimize_for_inference.py 腳本進(jìn)行優(yōu)化

調(diào)用腳本時(shí),我們需要提供幾個(gè)命令參數(shù),比如輸入的 PB 文件路徑,輸出的 PB 文件路徑,輸入節(jié)點(diǎn)名以及輸出節(jié)點(diǎn)名等。

python3 -m tensorflow.python.tools.optimize_for_inference \
  --input=retrained_graph.pb \
  --output=optimized_graph.pb \
  --input_names="Cast" \
  --output_names="final_result"

查看腳本執(zhí)行完成后輸出的 optimized_graph.pb 文件

可以看到,經(jīng)過(guò) optimize_for_inference 優(yōu)化過(guò)后的模型依然是非常的大的。經(jīng)過(guò)這一次的優(yōu)化,文件只是變小了一些,但還不足以我們放到手機(jī)端去運(yùn)行,所以我們要進(jìn)一步的壓縮模型,同時(shí)還要保證準(zhǔn)確率。

Quantize the network weights

Android 項(xiàng)目中,通常我們把模型 PB 文件放在 assets 文件夾中加載,其實(shí)不管是直接打包進(jìn) APP 還是進(jìn)入 APP 后再進(jìn)行下載,模型文件占用太大的問(wèn)題還是沒(méi)得到解決。我們知道 Android 的 APK 文件在構(gòu)建過(guò)程中會(huì)進(jìn)行 zip 壓縮。那有沒(méi)有一種行之有效的方法在不過(guò)多的降低精確度的情況下壓縮更大的空間呢?

Google 就提供了這么一個(gè)腳本,經(jīng)過(guò)這個(gè)腳本優(yōu)化后的模型 PB 文件大小不會(huì)改變,但會(huì)有更多的可利用的重復(fù)性,所以在打包構(gòu)建APK 包時(shí)對(duì) PB 文件進(jìn)行 zip 壓縮后,最終按照中的 PB 文件會(huì)縮小大約 3~4 倍的大小。

1. 用 bazel 工具構(gòu)建 quantize_graph 腳本

# 在 tensorflow 項(xiàng)目的根目錄(WORKSPACE 文件所在)執(zhí)行下面的 build 命令
$ bazel build tensorflow/tools/quantization:quantize_graph.py

build 完成后腳本文件路徑:tensorflow/tools/quantization/quantize_graph.py

2. 調(diào)用 quantize_graph 腳本進(jìn)行優(yōu)化
將生成的 quantize_graph.py 文件拷貝到 retrain 文件夾下,在目錄下執(zhí)行腳本。

輸入的參數(shù)依然是:輸入的 PB 文件路徑,輸出的 PB 文件路徑,輸出節(jié)點(diǎn)名,這里還有個(gè)特別的參數(shù) mode ,這個(gè)參數(shù)是告訴腳本我們選擇哪種壓縮方式,這里我們選擇了對(duì)權(quán)重進(jìn)行四舍五入。

python3 -m quantize_graph \
  --input=optimized_graph.pb \
  --output=rounded_graph.pb \
  --output_node_names=final_result \
  --mode=weights_rounded

可以看到最終的輸出文件 rounded_graph.pb 大小并沒(méi)有改變,下面我們就將優(yōu)化后的遷移訓(xùn)練模型文件重新導(dǎo)入到我們?cè)瓉?lái)的 Android 項(xiàng)目中。

把新訓(xùn)練的模型導(dǎo)入到 Android 中

同樣的,我們把新訓(xùn)練的模型 pb 文件和 labels 文件復(fù)制到 assets 文件夾下。

因?yàn)樾掠?xùn)練的模型,輸入和輸出層名稱(chēng)也發(fā)生的改變,這里要修改之前 TensorFlowImageClassifier.create 方法傳入的參數(shù)。

   /**
     * retrained inception-v3 model, flower classifier
     */
    private static final int INPUT_SIZE = 299;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 1f;
    private static final String INPUT_NAME = "Mul";
    private static final String OUTPUT_NAME = "final_result";
    private static final String MODEL_FILE = "file:///android_asset/model/rounded_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/retrained_labels.txt";

最終打包出來(lái)的 APK 文件,可以看到壓縮后的 pb 文件只有 22 MB

參考

教程:在 Mac OS X 上安裝 TensorFlow
當(dāng)Android開(kāi)發(fā)者遇見(jiàn)TensorFlow
Retrain a tensorflow model based on Inception v3
TensorFlow Mobile模型壓縮

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 1. 介紹 首先讓我們來(lái)看看TensorFlow! 但是在我們開(kāi)始之前,我們先來(lái)看看Python API中的Ten...
    JasonJe閱讀 11,978評(píng)論 1 32
  • Android 自定義View的各種姿勢(shì)1 Activity的顯示之ViewRootImpl詳解 Activity...
    passiontim閱讀 179,012評(píng)論 25 709
  • 2016.12.12開(kāi)始嘗試古箏,長(zhǎng)這么大第一次系統(tǒng)地學(xué)習(xí)一門(mén)樂(lè)器。2017年,你好!我會(huì)好好成長(zhǎng)的!
    四只閱讀 232評(píng)論 1 0
  • 第一次在那家小店就餐后,我們打了85分。 那天的服務(wù)生是一個(gè)靦腆的男孩?!澳c(diǎn)兩菜一湯就可以了,再點(diǎn)吃不完,浪費(fèi)了...
    老丁子閱讀 317評(píng)論 2 1
  • 離北京最近的沙漠是庫(kù)布齊,也是中國(guó)第七大沙漠,“庫(kù)布齊”是蒙古語(yǔ),意思是弓上的弦,這個(gè)弓是黃河,庫(kù)布齊像弦一樣就掛...
    蘭苑1972閱讀 471評(píng)論 1 1

友情鏈接更多精彩內(nèi)容