在 Flutter 中使用 TensorFlow Lite 插件實現(xiàn)文字分類

如果您希望能有一種簡單、高效且靈活的方式把 TensorFlow 模型集成到 Flutter 應(yīng)用里,那請您一定不要錯過我們今天介紹的這個全新插件 tflite_flutter。這個插件的開發(fā)者是 Google Summer of Code(GSoC) 的一名實習(xí)生 Amish Garg,本文來自他在 Medium 上的一篇文章《在 Flutter 中使用 TensorFlow Lite 插件實現(xiàn)文字分類》。

tflite_flutter 插件的核心特性:

  • 它提供了與 TFLite Java 和 Swift API 相似的 Dart API,所以其靈活性和在這些平臺上的效果是完全一樣的
  • 通過 dart:ffi 直接與 TensorFlow Lite C API 相綁定,所以它比其它平臺集成方式更加高效。
  • 無需編寫特定平臺的代碼。
  • 通過 NNAPI 提供加速支持,在 Android 上使用 GPU Delegate,在 iOS 上使用 Metal Delegate。

本文中,我們將使用 tflite_flutter 構(gòu)建一個 文字分類 Flutter 應(yīng)用 帶您體驗 tflite_flutter 插件,首先從新建一個 Flutter 項目 text_classification_app 開始。

初始化配置

Linux 和 Mac 用戶

install.sh 拷貝到您應(yīng)用的根目錄,然后在根目錄執(zhí)行 sh install.sh,本例中就是目錄 text_classification_app/

Windows 用戶

install.bat 文件拷貝到應(yīng)用根目錄,并在根目錄運行批處理文件 install.bat,本例中就是目錄 text_classification_app/。

它會自動從 release assets 下載最新的二進制資源,然后把它放到指定的目錄下。

請點擊到 README 文件里查看更多 關(guān)于初始配置的信息。

獲取插件

pubspec.yaml 添加 tflite_flutter: ^<latest_version>詳情)。

下載模型

要在移動端上運行 TensorFlow 訓(xùn)練模型,我們需要使用 .tflite 格式。如果需要了解如何將 TensorFlow 訓(xùn)練的模型轉(zhuǎn)換為 .tflite 格式,請參閱官方指南

這里我們準備使用 TensorFlow 官方站點上預(yù)訓(xùn)練的文字分類模型,可從這里下載

該預(yù)訓(xùn)練的模型可以預(yù)測當前段落的情感是積極還是消極。它是基于來自 Mass 等人的 Large Movie Review Dataset v1.0 數(shù)據(jù)集進行訓(xùn)練的。數(shù)據(jù)集由基于 IMDB 電影評論所標記的積極或消極標簽組成,點擊查看更多信息。

text_classification.tflitetext_classification_vocab.txt 文件拷貝到 text_classification_app/assets/ 目錄下。

pubspec.yaml 文件中添加 assets/。

assets:    
  - assets/

現(xiàn)在萬事俱備,我們可以開始寫代碼了。 ??

實現(xiàn)分類器

預(yù)處理

正如 文字分類模型頁面 里所提到的??梢园凑障旅娴牟襟E使用模型對段落進行分類:

  1. 對段落文本進行分詞,然后使用預(yù)定義的詞匯集將它轉(zhuǎn)換為一組詞匯 ID;
  2. 將生成的這組詞匯 ID 輸入 TensorFlow Lite 模型里;
  3. 從模型的輸出里獲取當前段落是積極或者是消極的概率值。

我們首先寫一個方法對原始字符串進行分詞,其中使用 text_classification_vocab.txt 作為詞匯集。

lib/ 文件夾下創(chuàng)建一個新文件 classifier.dart。

這里先寫代碼加載 text_classification_vocab.txt 到字典里。

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';
  
  Map<String, int> _dict;

  Classifier() {
    _loadDictionary();
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }
  
}

加載字典

現(xiàn)在我們來編寫一個函數(shù)對原始字符串進行分詞。

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  // 單句的最大長度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;
  
  List<List<double>> tokenizeInputText(String text) {
    
    // 使用空格進行分詞
    final toks = text.split(' ');
    
    // 創(chuàng)建一個列表,它的長度等于 _sentenceLen,并且使用 <pad> 的對應(yīng)的字典值來填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 對于句子里的每個單詞在 dict 里找到相應(yīng)的 index 值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我們的解釋器輸入 tensor 所需的形狀 [1,256] 返回 List<List<double>>
    return [vec];
  }
}


使用 tflite_flutter 進行分析

這是本文的主體部分,這里我們會討論 tflite_flutter 插件的用途。

這里的分析是指基于輸入數(shù)據(jù)在設(shè)備上使用 TensorFlow Lite 模型的處理過程。要使用 TensorFlow Lite 模型進行分析,需要通過 解釋器 來運行它,了解更多

創(chuàng)建解釋器,加載模型

tflite_flutter 提供了一個方法直接通過資源創(chuàng)建解釋器。

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})

由于我們的模型在 assets/ 文件夾下,需要使用上面的方法來創(chuàng)建解析器。對于 InterpreterOptions 的相關(guān)說明,請 參考這里。

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名稱
  final _modelFile = 'text_classification.tflite';

  // TensorFlow Lite 解釋器對象
  Interpreter _interpreter;

  Classifier() {
    // 當分類器初始化以后加載模型
    _loadModel();
  }

  void _loadModel() async {
    
    // 使用 Interpreter.fromAsset 創(chuàng)建解釋器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

}

創(chuàng)建解釋器的代碼

如果您不希望將模型放在 assets/ 目錄下,tflite_flutter 還提供了工廠構(gòu)造函數(shù)創(chuàng)建解釋器,更多信息。

我們開始進行分析!

現(xiàn)在用下面方法啟動分析:

void run(Object input, Object output);

注意這里的方法和 Java API 中的是一樣的。

Object inputObject output 必須是和 Input Tensor 與 Output Tensor 維度相同的列表。

要查看 input tensors 和 output tensors 的維度,可以使用如下代碼:

_interpreter.allocateTensors();
// 打印 input tensor 列表
print(_interpreter.getInputTensors());
// 打印 output tensor 列表
print(_interpreter.getOutputTensors());

在本例中 text_classification 模型的輸出如下:

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data:  1024]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data:  8]

現(xiàn)在,我們實現(xiàn)分類方法,該方法返回值為 1 表示積極,返回值為 0 表示消極。

int classify(String rawText) {
    
    //  tokenizeInputText 返回形狀為 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);
   
    // [1,2] 形狀的輸出
    var output = List<double>(2).reshape([1, 2]);
    
    // run 方法會運行分析并且存儲輸出的值
    _interpreter.run(input, output);

    var result = 0;
    // 如果輸出中第一個元素的值比第二個大,那么句子就是消極的
    
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

用于分析的代碼

在 tflite_flutter 的 extension ListShape on List 下面定義了一些使用的擴展:

// 將提供的列表進行矩陣變形,輸入?yún)?shù)為元素總數(shù) // 保持相等 
// 用法:List(400).reshape([2,10,20]) 
// 返回  List<dynamic>

List reshape(List<int> shape)
// 返回列表的形狀
List<int> get shape
// 返回列表任意形狀的元素數(shù)量
int get computeNumElements

最終的 classifier.dart 應(yīng)該是這樣的:

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名稱
  final _modelFile = 'text_classification.tflite';
  final _vocabFile = 'text_classification_vocab.txt';

  // 語句的最大長度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  // TensorFlow Lite 解釋器對象
  Interpreter _interpreter;

  Classifier() {
    // 當分類器初始化的時候加載模型
    _loadModel();
    _loadDictionary();
  }

  void _loadModel() async {
    // 使用 Intepreter.fromAsset 創(chuàng)建解析器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

  int classify(String rawText) {
    // tokenizeInputText  返回形狀為 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);

    //輸出形狀為 [1, 2] 的矩陣
    var output = List<double>(2).reshape([1, 2]);

    // run 方法會運行分析并且將結(jié)果存儲在 output 中。
    _interpreter.run(input, output);

    var result = 0;
    // 如果第一個元素的輸出比第二個大,那么當前語句是消極的

    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

  List<List<double>> tokenizeInputText(String text) {
    // 用空格分詞
    final toks = text.split(' ');

    // 創(chuàng)建一個列表,它的長度等于 _sentenceLen,并且使用 <pad> 對應(yīng)的字典值來填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 對于句子中的每個單詞,在 dict 中找到相應(yīng)的 index 值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我們的解釋器輸入 tensor 所需的形狀 [1,256] 返回 List<List<double>>
    return [vec];
  }
}

現(xiàn)在,可以根據(jù)您的喜好實現(xiàn) UI 的代碼,分類器的用法比較簡單。

// 創(chuàng)建 Classifier 對象
Classifer _classifier = Classifier();
// 將目標語句作為參數(shù),調(diào)用 classify 方法
_classifier.classify("I liked the movie");
// 返回 1 (積極的)
_classifier.classify("I didn't liked the movie");
// 返回 0 (消極的)

請在這里查閱完整代碼:Text Classification Example app with UI。

Text Classification Example App

文字分類示例應(yīng)用

了解更多關(guān)于 tflite_flutter 插件的信息,請訪問 GitHub repo: am15h/tflite_flutter_plugin。

答疑

問:tflite_fluttertflite v1.0.5 有哪些區(qū)別?

tflite v1.0.5 側(cè)重于為特定用途的應(yīng)用場景提供高級特性,比如圖片分類、物體檢測等等。而新的 tflite_flutter 則提供了與 Java API 相同的特性和靈活性,而且可以用于任何 tflite 模型中,它還支持 delegate。

由于使用 dart:ffi (dart ?? (ffi) ?? C),tflite_flutter 非常快 (擁有低延時)。而 tflite 使用平臺集成 (dart ?? platform-channel ?? (Java/Swift) ?? JNI ?? C)。

問:如何使用 tflite_flutter 創(chuàng)建圖片分類應(yīng)用?有沒有類似 TensorFlow Lite Android Support Library 的依賴包?

更新(07/01/2020): TFLite Flutter Helper 開發(fā)庫已發(fā)布。

TensorFlow Lite Flutter Helper Library 為處理和控制輸入及輸出的 TFLite 模型提供了易用的架構(gòu)。它的 API 設(shè)計和文檔與 TensorFlow Lite Android Support Library 是一樣的。更多信息請 參考這里

以上是本文的全部內(nèi)容,歡迎大家對 tflite_flutter 插件進行反饋,請在這里 上報 bug 或提出功能需求。

謝謝關(guān)注。

感謝 Michael Thomsen。

致謝

  • 譯者:Yuan,谷創(chuàng)字幕組
  • 審校:Xinlei、Lynn Wang、Alex,CFUG 社區(qū)。

本文聯(lián)合發(fā)布在 TensorFlow 線上討論區(qū)、101.devFlutter 中文文檔,以及 Flutter 社區(qū)線上渠道。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容