Notes on deploying your own deep learning model on iOS with TensorFlow Lite, following the official TensorFlow Lite example; Android may be added later.
Environment setup
Install TensorFlow into your Python environment with pip:
pip3 install tensorflow
Clone the project from GitHub:
git clone https://github.com/googlecodelabs/tensorflow-for-poets-2
Download the dataset:
wget http://download.tensorflow.org/example_images/flower_photos.tgz
The dataset contains photos of 5 flower species; we will use it to train a model that classifies the type of flower.
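retrain.py treats each subdirectory of the image directory as one class, and the evaluation script later in this post assumes the integer labels follow the subdirectory names in sorted order. A minimal stdlib sketch of that convention (the helper name is ours, for illustration):

```python
import os

def class_labels(image_dir):
    """Map each class subdirectory of image_dir to an integer label,
    taking the subdirectories in sorted order (plain files are ignored)."""
    subdirs = [d for d in sorted(os.listdir(image_dir))
               if os.path.isdir(os.path.join(image_dir, d))]
    return {name: index for index, name in enumerate(subdirs)}
```

For the flower dataset this yields something like {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}.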
After downloading, extract it into the tensorflow-for-poets-2/tf_files/ directory:
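The extraction step can also be scripted; a minimal sketch using only the Python standard library (it assumes the archive sits in the current directory, and is equivalent to tar -xzf flower_photos.tgz -C tensorflow-for-poets-2/tf_files):

```python
import os
import tarfile

def extract_flowers(archive="flower_photos.tgz",
                    dest="tensorflow-for-poets-2/tf_files"):
    """Unpack the downloaded dataset archive into the project's tf_files
    directory and return the path of the extracted flower_photos folder."""
    os.makedirs(dest, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)
    return os.path.join(dest, "flower_photos")
```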

Model training
The scripts directory contains several helper scripts. Among them, retrain.py retrains a model pretrained on ImageNet (Inception or MobileNet; the checkpoint is downloaded automatically when the script runs) for our flower classification task, and it exposes many configurable parameters, for example:
--architecture ARCHITECTURE
Which model architecture to use. 'inception_v3' is the
most accurate, but also the slowest. For faster or
smaller models, chose a MobileNet with the form
'mobilenet_<parameter size>_<input_size>[_quantized]'.
For example, 'mobilenet_1.0_224' will pick a model
that is 17 MB in size and takes 224 pixel input
images, while 'mobilenet_0.25_128_quantized' will
choose a much less accurate, but smaller and faster
network that's 920 KB on disk and takes 128x128
images. See
https://research.googleblog.com/2017/06/mobilenets-
open-source-models-for.html for more information on
Mobilenet.
Training command:
python scripts/retrain.py \
--output_graph=tf_files/retrained_graph.pb \
--output_labels=tf_files/retrained_labels.txt \
--image_dir=tf_files/flower_photos \
--architecture=mobilenet_1.0_224 \
--summaries_dir tf_files/training_summaries/mobilenet_1.0_224

Open TensorBoard to view the loss/accuracy curves during fine-tuning:
tensorboard --logdir=tf_files/training_summaries/mobilenet_1.0_224

Model conversion
To convert the trained frozen graph into a TFLite model we use toco, Google's official conversion tool. For an introduction to toco, see my other article on converting TensorFlow models for mobile.
IMAGE_SIZE=224
toco \
--graph_def_file=tf_files/retrained_graph.pb \
--output_file=tf_files/optimized_graph.lite \
--output_format=TFLITE \
--input_shape=1,${IMAGE_SIZE},${IMAGE_SIZE},3 \
--input_array=input \
--output_array=final_result \
--inference_type=FLOAT \
--inference_input_type=FLOAT
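Since the flags above are easy to mistype, the same invocation can be assembled from Python and launched with subprocess (a sketch; the helper name is ours, and it assumes toco is on your PATH):

```python
import subprocess

def toco_command(image_size=224,
                 graph_def="tf_files/retrained_graph.pb",
                 output="tf_files/optimized_graph.lite"):
    """Build the toco invocation used above for a float (non-quantized) model."""
    return [
        "toco",
        "--graph_def_file={}".format(graph_def),
        "--output_file={}".format(output),
        "--output_format=TFLITE",
        "--input_shape=1,{0},{0},3".format(image_size),
        "--input_array=input",
        "--output_array=final_result",
        "--inference_type=FLOAT",
        "--inference_input_type=FLOAT",
    ]

# subprocess.run(toco_command(), check=True)  # uncomment to run the conversion
```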
Evaluating the accuracy of the TFLite model
In practice the model loses some accuracy during conversion, so once we have the converted TFLite model we want to measure its accuracy. This requires invoking the TFLite interpreter directly from a Python script and computing the model's accuracy on a test set.
Below is an example script (TensorFlow's APIs change quickly, so it is not guaranteed to keep working):
import os

import cv2
import numpy as np
import tensorflow as tf
from skimage.transform import resize

def predict(interpreter, input_details, output_details, input_data):
    """Generate a class prediction for input_data.

    interpreter: the TFLite interpreter used to run the model
    input_details, output_details: tensor metadata from the interpreter
    input_data: a single image as loaded by cv2
    """
    input_shape = input_details[0]['shape']
    input_data = resize(input_data, input_shape[1:])
    input_data = input_data.reshape(input_shape)
    input_data = input_data.astype("float32")
    # input_data = (input_data - 127.5) / 127.5
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    return np.argmax(output_data)

if __name__ == "__main__":
    # Load the TFLite model and allocate tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path="tf_files/optimized_graph.lite")
    interpreter.allocate_tensors()
    # Get input and output tensor metadata.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Test the model on our own data: each subdirectory holds one class,
    # and sorted directory order matches the integer labels.
    data_path = "/Users/yuhua.cheng/Opt/temp/tensorflow-for-poets-2/tf_files/flower_photos"
    sub_classes = [f for f in sorted(os.listdir(data_path))
                   if os.path.isdir(os.path.join(data_path, f))]
    print(sub_classes)
    count = 0
    total = 0
    for label, sub_class in enumerate(sub_classes):
        print("processing:", sub_class)
        sub_path = os.path.join(data_path, sub_class)
        img_files = [f for f in os.listdir(sub_path) if not f.startswith('.')]
        for img_file in img_files:
            img = cv2.imread(os.path.join(sub_path, img_file), -1)
            pred = predict(interpreter, input_details, output_details, img)
            if pred == label:
                count += 1
            total += 1
    print('accuracy:', count / total)
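A note on the commented-out normalization line above: skimage's resize converts uint8 images to floats in [0, 1], while the (x - 127.5) / 127.5 line maps pixels to [-1, 1], which is what stock MobileNet checkpoints expect. Which convention your converted model needs depends on how it was trained; if accuracy comes out near chance, try the other one. A stdlib sketch of the two mappings:

```python
def to_unit_range(pixels):
    """Scale 8-bit pixel values from [0, 255] to [0.0, 1.0]."""
    return [p / 255.0 for p in pixels]

def to_symmetric_range(pixels):
    """Scale 8-bit pixel values from [0, 255] to [-1.0, 1.0],
    the MobileNet-style (x - 127.5) / 127.5 normalization."""
    return [(p - 127.5) / 127.5 for p in pixels]
```

For example, to_unit_range([0, 255]) gives [0.0, 1.0] and to_symmetric_range([0, 255]) gives [-1.0, 1.0].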
Calling the TFLite model from an iOS project
First install the necessary dependencies:
xcode-select --install
sudo gem install cocoapods
pod install --project-directory=ios/tflite/
Open the iOS project:
open ios/tflite/tflite_camera_example.xcworkspace
Copy the model file and the label file to the corresponding paths in the project:
cp tf_files/optimized_graph.lite ios/tflite/data/graph.lite
cp tf_files/retrained_labels.txt ios/tflite/data/labels.txt
Connect a phone and run the app directly:
Results reproduced on the phone:
--------- Converting and running my own trained model, beyond the official tutorial, will be added later ---------
Issue log
- toco reported an error when converting the original simplenet.pb to TFLite.
Original model architecture:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 227, 227, 3) 0
_________________________________________________________________
block1_0_conv (Conv2D) (None, 76, 76, 64) 9408
_________________________________________________________________
block1_0_bn (BatchNormalizat (None, 76, 76, 64) 192
_________________________________________________________________
block1_0_relu (Activation) (None, 76, 76, 64) 0
_________________________________________________________________
block1_0_dropout (Dropout) (None, 76, 76, 64) 0
_________________________________________________________________
block1_1_conv (Conv2D) (None, 76, 76, 32) 18432
_________________________________________________________________
block1_1_bn (BatchNormalizat (None, 76, 76, 32) 96
_________________________________________________________________
block1_1_relu (Activation) (None, 76, 76, 32) 0
_________________________________________________________________
block1_1_dropout (Dropout) (None, 76, 76, 32) 0
_________________________________________________________________
block2_0_conv (Conv2D) (None, 76, 76, 32) 9216
_________________________________________________________________
block2_0_bn (BatchNormalizat (None, 76, 76, 32) 96
_________________________________________________________________
block2_0_relu (Activation) (None, 76, 76, 32) 0
_________________________________________________________________
block2_0_dropout (Dropout) (None, 76, 76, 32) 0
_________________________________________________________________
block2_1_conv (Conv2D) (None, 76, 76, 32) 9216
_________________________________________________________________
block2_1_bn (BatchNormalizat (None, 76, 76, 32) 96
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 38, 38, 32) 0
_________________________________________________________________
block2_1_relu (Activation) (None, 38, 38, 32) 0
_________________________________________________________________
block2_1_dropout (Dropout) (None, 38, 38, 32) 0
_________________________________________________________________
block2_2_conv (Conv2D) (None, 38, 38, 32) 9216
_________________________________________________________________
block2_2_bn (BatchNormalizat (None, 38, 38, 32) 96
_________________________________________________________________
block2_2_relu (Activation) (None, 38, 38, 32) 0
_________________________________________________________________
block2_2_dropout (Dropout) (None, 38, 38, 32) 0
_________________________________________________________________
block3_0_conv (Conv2D) (None, 38, 38, 32) 9216
_________________________________________________________________
block3_0_bn (BatchNormalizat (None, 38, 38, 32) 96
_________________________________________________________________
block3_0_relu (Activation) (None, 38, 38, 32) 0
_________________________________________________________________
block3_0_dropout (Dropout) (None, 38, 38, 32) 0
_________________________________________________________________
block4_0_conv (Conv2D) (None, 38, 38, 64) 18432
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 19, 19, 64) 0
_________________________________________________________________
block4_0_bn (BatchNormalizat (None, 19, 19, 64) 192
_________________________________________________________________
block4_0_relu (Activation) (None, 19, 19, 64) 0
_________________________________________________________________
block4_0_dropout (Dropout) (None, 19, 19, 64) 0
_________________________________________________________________
block4_1_conv (Conv2D) (None, 19, 19, 64) 36864
_________________________________________________________________
block4_1_bn (BatchNormalizat (None, 19, 19, 64) 192
_________________________________________________________________
block4_1_relu (Activation) (None, 19, 19, 64) 0
_________________________________________________________________
block4_1_dropout (Dropout) (None, 19, 19, 64) 0
_________________________________________________________________
block4_2_conv (Conv2D) (None, 19, 19, 64) 36864
_________________________________________________________________
block4_2_bn (BatchNormalizat (None, 19, 19, 64) 192
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 9, 9, 64) 0
_________________________________________________________________
block4_2_relu (Activation) (None, 9, 9, 64) 0
_________________________________________________________________
block4_2_dropout (Dropout) (None, 9, 9, 64) 0
_________________________________________________________________
cccp4 (Conv2D) (None, 9, 9, 256) 16640
_________________________________________________________________
cccp5 (Conv2D) (None, 9, 9, 64) 16448
_________________________________________________________________
poolcp5 (MaxPooling2D) (None, 4, 4, 64) 0
_________________________________________________________________
cccp6 (Conv2D) (None, 4, 4, 64) 36928
_________________________________________________________________
poolcp6 (GlobalMaxPooling2D) (None, 64) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 650
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 228,778
Trainable params: 227,946
Non-trainable params: 832
_________________________________________________________________
Conversion error:
Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.contrib.lite.toco_convert(). Here is a list of operators for which you will need custom implementations: Max.
Cause: some Keras layers are built on TensorFlow ops that are not fully supported when converting to TFLite: https://github.com/tensorflow/tensorflow/issues/20042
Planned fix: re-implement those layers with TensorFlow's own supported ops.
In the end, upgrading TensorFlow from 1.10 to 1.12 resolved the issue and the conversion succeeded:
pip install --upgrade tensorflow
- Xcode reported an error when running the TFLite model:
Op builtin_code out or range: 82. Are you using old TFLite binary with newer model?
Registration failed.
Setting a breakpoint showed that the failure occurred at:
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
Changing the first convolution layer's stride from 3 to 2 fixed it; TFLite probably has no implementation for stride 3 in that op.
Reference
- How to deploy your own deep learning model on iOS (official TensorFlow example): https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/#0
- https://v-play.net/cross-platform-development/machine-learning-add-image-classification-for-ios-and-android-with-qt-and-tensorflow
- https://heartbeat.fritz.ai/neural-networks-on-mobile-devices-with-tensorflow-lite-a-tutorial-85b41f53230c
- How to quantize a model: https://www.tensorflow.org/lite/performance/post_training_quantization
- TensorFlow model vs. TFLite model accuracy mismatch: https://github.com/tensorflow/tensorflow/issues/21921