前言
當(dāng)我們把使用Python訓(xùn)練的模型固化成PB文件之后,再進(jìn)行相應(yīng)的模型壓縮之后可以考慮往Mobile端移植了,本文主要講解TensorFlow Model移植到Android端。
TensorFlow1.0之后推出了Java版本,所以間接為Android開發(fā)TensorFlow程序帶來便利,以前我們需要用JNI去編寫,可是JNI難于調(diào)試,C++代碼對于普通Android開發(fā)者來講還是比Java繁瑣,所以本文以Java API講述開發(fā)過程。
正文
下面就正式開始一直TensorFlow model到Android中啦。
- 引入依賴
在TensorFlow更新到1.2.0版本之后,TensorFlow為廣大開發(fā)者提供了gradle依賴,現(xiàn)在我們想要引入TensorFlow只需要在gradle中加入
compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
即可引入TensorFlow的庫。
- 復(fù)制PB文件
快速開發(fā)的話直接把PB文件放在assets文件夾里就行,如果正式上線的時候覺得PB文件一起打包較大的話可以放在服務(wù)器,打開APP的時候提示下載再復(fù)制進(jìn)去就好。
- 創(chuàng)建TensorFlowInterface類
這個類指的是我們讀取、識別等一系列方法存放的類,名字隨你取。
- 載入TensorFlow
在類的第一行加入這句話,會在加載類的時候首先加載TensorFlow
{
System.loadLibrary("tensorflow_inference");
}
- 定義常量
在這一步,我們先定義一些常量,比如輸入節(jié)點名、輸出節(jié)點名、輸出圖像的尺寸、通道、輸入節(jié)點數(shù)據(jù)類型、輸出節(jié)點數(shù)據(jù)類型。代碼如下
private static final String input_layer = "inputs/X";
private static final String output_layer = "output/predict";
private Context context;
private static final int HEIGHT = 64;
private static final int WIDTH = 256;
private static final int CHANNEL = 1;
private float[] inputs = new float[HEIGHT*WIDTH*CHANNEL];
private long[] outputs = new long[11];
- 初始化模型
這一步TensorFlow的模型會載入到內(nèi)存中,傳入assets和PB文件名
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(),"rounded_graph.pb");
- 喂數(shù)據(jù)給輸入節(jié)點
這里的參數(shù)是輸入節(jié)點名,輸入數(shù)據(jù),輸入數(shù)據(jù)的shape
inferenceInterface.feed(input_layer,inputs,1,16384);
- run session
inferenceInterface.run(new String[] { output_layer }, false);
- 獲取輸出數(shù)據(jù)
根據(jù)你在Python定義的輸出格式,new一個接收輸出數(shù)據(jù)的變量,從輸出節(jié)點獲取數(shù)據(jù)
byte[] outPuts = new byte[88];
inferenceInterface.fetch(output_layer,outPuts);
- 數(shù)據(jù)變換
從輸出節(jié)點獲取到數(shù)據(jù)之后就需要你對自己的輸出數(shù)據(jù)進(jìn)行操作,比如我在我們model里最終輸出的結(jié)果進(jìn)行了Argmax的操作,Argmax返回的值類型是Int64的,在Android里只有l(wèi)ong對應(yīng),但fetch方法的接受變量的參數(shù)類型只有double、float、int、byte,所以這里需要使用byte獲取,再進(jìn)行轉(zhuǎn)換。這里跟傳統(tǒng)的byte[8]轉(zhuǎn)long有些不同,具體處理方式要看你定義的數(shù)據(jù)格式,我這里的byte[8]用網(wǎng)上的方法轉(zhuǎn)long發(fā)現(xiàn)數(shù)值非常大,于是遍歷一遍byte[8],發(fā)現(xiàn)每個子元素都是相同的數(shù)值,所以這里只取第一個元素,組成一個新的數(shù)組,再對這個數(shù)組進(jìn)行解析。
long[] tOutputs=new long[11];
for (int i=0;i<11;i++)
{
int k=i*8;
tOutputs[i]=outPuts[k];
Log.i("output",tOutputs[i]+"");
}
String outputStr="";
for(int i=0;i<11;i++){
long char_idx=tOutputs[i];
long char_code = 0;
if (char_idx<10){
char_code = char_idx + (int)('0');
}
else if (char_idx<36){
char_code = char_idx-10 + (int)('A');
}
else if (char_idx<62){
char_code = char_idx + (int)('a');
}
outputStr+= (char)char_code;
}
后記
有Java API確實相比C++來的更直觀方便,而且native debug也比JNI好操作,等TensorFlowLite出來的時候,Android TensorFlow應(yīng)用會更加廣泛吧。