本文基于TensorFlow實作了MNIST手寫數字識别,并将訓練好的模型移植到了Android上。
環境配置
TensorFlow 1.0.1
Python2.7
Android Studio 2.2
主要步驟
生成pb檔案:使用TensorFlow Python API 建構并訓練網絡,最後将訓練後的網絡的拓撲結構和參數儲存為pb檔案。
将pb檔案、jar包以及so庫引入Android工程中,并基于TensorFlowInferenceInterface類完成識别。
移植過程
一、生成pb檔案
pb檔案中儲存了網絡的拓撲結構和參數。為了得到pb檔案需要先基于TensorFlow Python API 建構并訓練網絡。
1. 給網絡拓撲中的關鍵節點指定名稱
網絡的輸入節點和輸出節點在使用tf.placeholder定義的時候必須要通過name形參指定名稱,便于在将模型移植到Android後可以通過名稱來擷取指定節點的值,或者給指定節點指派。
x = tf.placeholder(tf.float32, [None, height, width], name='input') #輸入節點的名字這裡取名為'input'
sofmax_out = tf.nn.softmax(logits,name="out_softmax") #輸出節點
# keep_prob_placeholder這個節點也命名了,便于後期用于區分訓練和測試。
keep_prob_placeholder = tf.placeholder(tf.float32, name='keep_prob_placeholder')
2. 将訓練好後的網絡模型儲存為pb檔案
這是通過convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None)函數實作的,該函數的定義見這。
convert_variables_to_constants完成如下兩件事情:(@mirosval的回答)
convert_variables_to_constants() does two things:
It freezes the weights by replacing variables with constants
It removes nodes which are not related to feedforward prediction
from tensorflow.python.framework import graph_util
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out_softmax"])
with tf.gfile.FastGFile(pb_file_path,mode='wb') as f:
f.write(constant_graph.SerializeToString())
二、建構jar包和so庫
這裡簡要地總結一下主要步驟。
1. 安裝 Bazel,Android NDK,Android SDK
2. 下載下傳TensorFlow源碼,修改項目根目錄下的WORKSPACE檔案
修改WORKSPACE檔案中的Android SDK和Android NDK的配置資訊,其中的路徑等資訊根據之前的安裝情況進行修改。
本文将WORKSPACE檔案的配置修改如下:
# Uncomment and update the paths in these entries to build the Android demo.
android_sdk_repository(
name = "androidsdk",
api_level = 25,
build_tools_version = "25.0.2",
# Replace with path to Android SDK on your system
path = "/home/tsiangleo/android_dev/tool/android-sdk-linux",
)
android_ndk_repository(
name="androidndk",
path="/home/tsiangleo/android_dev/tool/android-ndk-r13b",
api_level=21)
3. 建構so庫
在TensorFlow源碼的根目錄下執行如下指令,建構so庫。
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
--crosstool_top=//external:android/crosstool \
[email protected]_tools//tools/cpp:toolchain \
--cpu=armeabi-v7a
建構成功後,可在如下目錄找到so庫。
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
4. 建構jar包
在TensorFlow源碼的根目錄下執行如下指令,建構jar包。
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
建構成功後,可在如下目錄找到jar包。
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
三、整合到Android Studio工程
以下操作針對Android Studio。
1.将pb檔案放入Android項目中
打開 Project view ,app/src/main/assets。
若不存在assets目錄,右鍵main->new->Directory,輸入assets。
2.将jar包引入Android項目中
打開Project view,将jar包拷貝到app->libs下
選中jar檔案,右鍵 add as library
3.将so庫引入Android項目中
打開 Project view,将libtensorflow_inference.so檔案拷貝到 app/src/main/jniLibs/armeabi-v7a下(若jniLibs/armeabi-v7a目錄不存在,則先建立,方法同1。)。
4.基于TensorFlowInferenceInterface類,編寫代碼進行識别。
下面以識别MNIST手寫數字為例來介紹,具體代碼見github。
(1)定義一些關鍵的常量
public static final String MODEL_FILE = "file:///android_asset/mnist-tf1.0.1.pb"; //asserts目錄下的pb檔案名字
public static final String INPUT_NODE = "input"; //輸入節點的名稱
public static final String OUTPUT_NODE = "out_softmax"; //輸出節點的名稱
public static final String KEEP_PROB_NODE = "keep_prob_placeholder"; // keep_prob節點的名稱
public static final int NUM_CLASSES = 10; //輸出節點的個數,即總的類别數。
public static final int HEIGHT = 28; //輸入圖檔的像素高
public static final int WIDTH = 28; //輸入圖檔的像素寬
(2)建立TensorFlowInferenceInterface對象并初始化
//初始化TensorFlowInferenceInterface對象。
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface();
//根據指定的MODEL_FILE建立一個本地的TensorFlow session
inferenceInterface.initializeTensorFlow(context.getAssets(), MODEL_FILE);
(3)輸入圖檔的像素點,得到分類結果
// 輸入資料pixelArray,pixelArray的資料類型是float[],存放了一張圖檔的像素點。
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH}, pixelArray);
inferenceInterface.fillNodeFloat(KEEP_PROB_NODE,new int[]{1},new float[]{1.0f});
//進行模型的推理
inferenceInterface.runInference(new String[]{OUTPUT_NODE});
//擷取圖檔屬于各個分類的機率,存放在outputs數組中。
float[] outputs = new float[NUM_CLASSES]; //用于存儲模型的輸出資料
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //擷取輸出資料
本文源碼
參考文檔