天天看點

android 手寫數字,基于TensorFlow的MNIST手寫數字識别與Android移植

本文基于TensorFlow實作了MNIST手寫數字識别,并将訓練好的模型移植到了Android上。

環境配置

TensorFlow 1.0.1

Python2.7

Android Studio 2.2

主要步驟

生成pb檔案:使用TensorFlow Python API 建構并訓練網絡,最後将訓練後的網絡的拓撲結構和參數儲存為pb檔案。

将pb檔案、jar包以及so庫引入Android工程中,并基于TensorFlowInferenceInterface類完成識别。

移植過程

一、生成pb檔案

pb檔案中儲存了網絡的拓撲結構和參數。為了得到pb檔案需要先基于TensorFlow Python API 建構并訓練網絡。

1. 給網絡拓撲中的關鍵節點指定名稱

網絡的輸入節點和輸出節點在使用tf.placeholder定義的時候必須要通過name形參指定名稱,便于在将模型移植到Android後可以通過名稱來擷取指定節點的值,或者給指定節點指派。

x = tf.placeholder(tf.float32, [None, height, width], name='input') #輸入節點的名字這裡取名為'input'

sofmax_out = tf.nn.softmax(logits,name="out_softmax") #輸出節點

# keep_prob_placeholder這個節點也命名了,便于後期用于區分訓練和測試。

keep_prob_placeholder = tf.placeholder(tf.float32, name='keep_prob_placeholder')

2. 将訓練好後的網絡模型儲存為pb檔案

這是通過convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None)函數實作的,該函數的定義見這。

convert_variables_to_constants完成如下兩件事情:(@mirosval的回答)

convert_variables_to_constants() does two things:

It freezes the weights by replacing variables with constants

It removes nodes which are not related to feedforward prediction

from tensorflow.python.framework import graph_util

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out_softmax"])

with tf.gfile.FastGFile(pb_file_path,mode='wb') as f:

f.write(constant_graph.SerializeToString())

二、建構jar包和so庫

這裡簡要地總結一下主要步驟。

1. 安裝 Bazel,Android NDK,Android SDK

2. 下載下傳TensorFlow源碼,修改項目根目錄下的WORKSPACE檔案

修改WORKSPACE檔案中的Android SDK和Android NDK的配置資訊,其中的路徑等資訊根據之前的安裝情況進行修改。

本文将WORKSPACE檔案的配置修改如下:

# Uncomment and update the paths in these entries to build the Android demo.

android_sdk_repository(

name = "androidsdk",

api_level = 25,

build_tools_version = "25.0.2",

# Replace with path to Android SDK on your system

path = "/home/tsiangleo/android_dev/tool/android-sdk-linux",

)

android_ndk_repository(

name="androidndk",

path="/home/tsiangleo/android_dev/tool/android-ndk-r13b",

api_level=21)

3. 建構so庫

在TensorFlow源碼的根目錄下執行如下指令,建構so庫。

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \

--crosstool_top=//external:android/crosstool \

[email protected]_tools//tools/cpp:toolchain \

--cpu=armeabi-v7a

建構成功後,可在如下目錄找到so庫。

bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

4. 建構jar包

在TensorFlow源碼的根目錄下執行如下指令,建構jar包。

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java

建構成功後,可在如下目錄找到jar包。

bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

三、整合到Android Studio工程

以下操作針對Android Studio。

1.将pb檔案放入Android項目中

打開 Project view ,app/src/main/assets。

若不存在assets目錄,右鍵main->new->Directory,輸入assets。

2.将jar包引入Android項目中

打開Project view,将jar包拷貝到app->libs下

選中jar檔案,右鍵 add as library

3.将so庫引入Android項目中

打開 Project view,将libtensorflow_inference.so檔案拷貝到 app/src/main/jniLibs/armeabi-v7a下(若jniLibs/armeabi-v7a目錄不存在,則先建立,方法同1。)。

4.基于TensorFlowInferenceInterface類,編寫代碼進行識别。

下面以識别MNIST手寫數字為例來介紹,具體代碼見github。

(1)定義一些關鍵的常量

public static final String MODEL_FILE = "file:///android_asset/mnist-tf1.0.1.pb"; //asserts目錄下的pb檔案名字

public static final String INPUT_NODE = "input"; //輸入節點的名稱

public static final String OUTPUT_NODE = "out_softmax"; //輸出節點的名稱

public static final String KEEP_PROB_NODE = "keep_prob_placeholder"; // keep_prob節點的名稱

public static final int NUM_CLASSES = 10; //輸出節點的個數,即總的類别數。

public static final int HEIGHT = 28; //輸入圖檔的像素高

public static final int WIDTH = 28; //輸入圖檔的像素寬

(2)建立TensorFlowInferenceInterface對象并初始化

//初始化TensorFlowInferenceInterface對象。

TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface();

//根據指定的MODEL_FILE建立一個本地的TensorFlow session

inferenceInterface.initializeTensorFlow(context.getAssets(), MODEL_FILE);

(3)輸入圖檔的像素點,得到分類結果

// 輸入資料pixelArray,pixelArray的資料類型是float[],存放了一張圖檔的像素點。

inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH}, pixelArray);

inferenceInterface.fillNodeFloat(KEEP_PROB_NODE,new int[]{1},new float[]{1.0f});

//進行模型的推理

inferenceInterface.runInference(new String[]{OUTPUT_NODE});

//擷取圖檔屬于各個分類的機率,存放在outputs數組中。

float[] outputs = new float[NUM_CLASSES]; //用于存儲模型的輸出資料

inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //擷取輸出資料

本文源碼

參考文檔