TensorFlow Android調(diào)用

前言

當(dāng)我們把使用Python訓(xùn)練的模型固化成PB文件之后,再進(jìn)行相應(yīng)的模型壓縮之后可以考慮往Mobile端移植了,本文主要講解TensorFlow Model移植到Android端。

TensorFlow1.0之后推出了Java版本,所以間接為Android開發(fā)TensorFlow程序帶來便利,以前我們需要用JNI去編寫,可是JNI難于調(diào)試,C++代碼對于普通Android開發(fā)者來講還是比Java繁瑣,所以本文以Java API講述開發(fā)過程。

正文

下面就正式開始一直TensorFlow model到Android中啦。

  • 引入依賴

在TensorFlow更新到1.2.0版本之后,TensorFlow為廣大開發(fā)者提供了gradle依賴,現(xiàn)在我們想要引入TensorFlow只需要在gradle中加入

compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'

即可引入TensorFlow的庫。

  • 復(fù)制PB文件

快速開發(fā)的話直接把PB文件放在assets文件夾里就行,如果正式上線的時候覺得PB文件一起打包較大的話可以放在服務(wù)器,打開APP的時候提示下載再復(fù)制進(jìn)去就好。

  • 創(chuàng)建TensorFlowInterface類

這個類指的是我們讀取、識別等一系列方法存放的類,名字隨你取。

  • 載入TensorFlow

在類的第一行加入這句話,會在加載類的時候首先加載TensorFlow

    {
        System.loadLibrary("tensorflow_inference");
    }
  • 定義常量

在這一步,我們先定義一些常量,比如輸入節(jié)點名、輸出節(jié)點名、輸出圖像的尺寸、通道、輸入節(jié)點數(shù)據(jù)類型、輸出節(jié)點數(shù)據(jù)類型。代碼如下

    private static final String input_layer = "inputs/X";
    private static final String output_layer = "output/predict";

    private Context context;
    private static final int HEIGHT = 64;
    private static final int WIDTH = 256;
    private static final int CHANNEL = 1;

    private float[] inputs = new float[HEIGHT*WIDTH*CHANNEL];
    private long[] outputs = new long[11];
  • 初始化模型

這一步TensorFlow的模型會載入到內(nèi)存中,傳入assets和PB文件名

TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(),"rounded_graph.pb"); 
  • 喂數(shù)據(jù)給輸入節(jié)點

這里的參數(shù)是輸入節(jié)點名,輸入數(shù)據(jù),輸入數(shù)據(jù)的shape

inferenceInterface.feed(input_layer,inputs,1,16384);
  • run session
inferenceInterface.run(new String[] { output_layer }, false);
  • 獲取輸出數(shù)據(jù)

根據(jù)你在Python定義的輸出格式,new一個接收輸出數(shù)據(jù)的變量,從輸出節(jié)點獲取數(shù)據(jù)

byte[] outPuts = new byte[88];
inferenceInterface.fetch(output_layer,outPuts);
  • 數(shù)據(jù)變換

從輸出節(jié)點獲取到數(shù)據(jù)之后就需要你對自己的輸出數(shù)據(jù)進(jìn)行操作,比如我在我們model里最終輸出的結(jié)果進(jìn)行了Argmax的操作,Argmax返回的值類型是Int64的,在Android里只有l(wèi)ong對應(yīng),但fetch方法的接受變量的參數(shù)類型只有double、float、int、byte,所以這里需要使用byte獲取,再進(jìn)行轉(zhuǎn)換。這里跟傳統(tǒng)的byte[8]轉(zhuǎn)long有些不同,具體處理方式要看你定義的數(shù)據(jù)格式,我這里的byte[8]用網(wǎng)上的方法轉(zhuǎn)long發(fā)現(xiàn)數(shù)值非常大,于是遍歷一遍byte[8],發(fā)現(xiàn)每個子元素都是相同的數(shù)值,所以這里只取第一個元素,組成一個新的數(shù)組,再對這個數(shù)組進(jìn)行解析。

long[] tOutputs=new long[11];
for (int i=0;i<11;i++)
{
    int k=i*8;
    tOutputs[i]=outPuts[k];
    Log.i("output",tOutputs[i]+"");
}
String outputStr="";
for(int i=0;i<11;i++){
    long char_idx=tOutputs[i];
    long char_code = 0;
    if (char_idx<10){
        char_code = char_idx + (int)('0');
    }
    else if (char_idx<36){
        char_code = char_idx-10 + (int)('A');
    }
    else if (char_idx<62){
        char_code = char_idx + (int)('a');
    }
    outputStr+= (char)char_code;
}

后記

有Java API確實相比C++來的更直觀方便,而且native debug也比JNI好操作,等TensorFlowLite出來的時候,Android TensorFlow應(yīng)用會更加廣泛吧。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • Android 自定義View的各種姿勢1 Activity的顯示之ViewRootImpl詳解 Activity...
    passiontim閱讀 179,048評論 25 709
  • 用合適的時間顆粒度,來保持注意力的專注度! 注意點: 1.防打斷 - 被動打斷(電話) - 誘惑性打斷(微信,微博...
    靖杰閱讀 329評論 0 0
  • 馬戲場里,當(dāng)一個節(jié)目表演完畢后,觀眾們大呼:“再來一個!”孩子不理解地問母親這是什么意思,母親解釋說:“這是表示歡...
    梓毓爸閱讀 355評論 0 3
  • 1. 上海今天有霾,我的心情和天氣一樣不怎么好,特別想念陽光燦爛。 下午洗完澡后,舍友們都在看論文,在一旁啃蘋果的...
    清音素閱讀 199評論 0 0
  • 1.感恩父母的生養(yǎng)之恩,感恩外公外婆,親人們在我成長道路上的付出; 2.感恩死黨,導(dǎo)師,老師,朋友們的支持,讓我生...
    心靈陪伴閱讀 144評論 0 1

友情鏈接更多精彩內(nèi)容