天天看點

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

前段時間搭建了一個分類網絡模型,然後用自己的資料進行了800epoch’的訓練,最後預設生成了三個ckpt檔案.由于要同時運作幾個網絡,是以打算将這個網絡模型進行固化成pb檔案,然後直接調用.

主要包括一下内容:

1.檢視ckpt模型的輸入輸出張量名稱.

2.将ckpt檔案生成pb檔案.

3.檢視生成的pb檔案的輸入輸出節點

4.運作pb檔案,進行網絡預測

1.檢視ckpt網絡的輸入輸出張量名稱

下面是我的網絡訓練後生成的三個ckpt檔案

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

運作以下代碼檢視自己模型的輸入輸出張量名稱(用于儲存pb檔案時保留這兩個節點)

注意第三行代碼換成自己的ckpt檔案位址,名稱是三個檔案共有的 model.ckpt

from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path=os.path.join('/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt')
reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map=reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    a=reader.get_tensor(key)
    print( 'tensor_name: ',key)
    print("a.shape:%s"%[a.shape])
           

我的代碼運作後結果如下:

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

如果你的模型輸入輸出張量很容易找到,那這個方法對于你來說應該是可以的,但是我就是在這裡花了一天的時間才找到自己模型的輸入輸出張量,因為這個模型比較複雜,并且這個程式輸出的張量是無序的.我使用的模型是别人語義分割模型的改進,是以模型張量不是很好找.

仍然找不到輸入輸出張量怎麼辦?

我的解決辦法:我通過程式找到了模型的定義,然後在模型的最前端列印出輸入張量,在最後列印出輸出張量

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

上圖中的第二行代碼是輸出"inputs"張量,倒數第二行代碼輸出"net"張量,然後運作包含模型代碼的程式就可以看到列印出來的兩個張量了.下圖就是運作後的輸出結果,這樣就找到自己模型的輸入和輸出張量了.

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

2.将ckpt檔案生成pb檔案.

以下是将ckpt檔案轉化為pb檔案的代碼

1)更改node_names後面的值,改成自己想要保留的節點名稱,我保留了首尾兩個,就是上面得到的兩個.

2)input_checkpoint位址改成自己的ckpt檔案的位址.(注意寫到.ckpt)

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型儲存路徑
    :return:
    '''
    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    # 直接用最後輸出的節點,可以在tensorboard中查找到,tensorboard隻能在linux中使用
    node_names = "Placeholder,FC-DenseNet56/logits/BiasAdd"
    saver = tf.train.import_meta_graph(input_checkpoint+".meta" , clear_devices=True)
    graph = tf.get_default_graph() # 獲得預設的圖
    input_graph_def = graph.as_graph_def()  # 傳回一個序列化的圖代表目前的圖
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, input_checkpoint) #恢複圖并得到資料
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将變量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多個輸出節點,以逗号隔開
 
        with tf.gfile.GFile(output_graph, "wb") as f: #儲存模型
            f.write(output_graph_def.SerializeToString()) #序列化輸出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到目前圖有幾個操作節點
input_checkpoint="/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt"#輸入的ckpt檔案位置
output_graph="node.pb"#輸出節點的檔案名
freeze_graph(input_checkpoint,output_graph)
           

然後就可以得到一個node.pb檔案,名字可以自己更改

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

3.檢視生成的pb檔案的輸入輸出節點

檢視pb檔案的節點,隻是為了驗證一下,也可以不檢視,代碼如下:

隻需更改你的pb檔案的位址,運作後會得到一個txt檔案,打開可以檢視

import tensorflow as tf
import os
 
model_dir = './'
model_name = 'new_node.pb'
 
# 讀取并建立一個圖graph來存放Google訓練好的Inception_v3模型(函數)
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        # 使用tf.GraphDef()定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')
 
# 建立graph
create_graph()
 
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
result_file = os.path.join(model_dir, 'result.txt') 
with open(result_file, 'w+') as f:
    for tensor_name in tensor_name_list:
        f.write(tensor_name+'\n')
           

下面是我的txt檔案的内容,好像我的pb檔案生成了整個網絡的節點,并不隻是保留了輸入和輸出兩個,看一下輸入輸出節點和剛才檢視的是對應的.

将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.
将ckpt檔案生成pb檔案的詳細過程,并調用pb檔案進行模型預測.

4.運作pb檔案,進行網絡預測

以下是我用自己的pb檔案進行我自己圖檔的預測,代碼如下:

def get_RAC(image_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            
            input_x = sess.graph.get_tensor_by_name("Placeholder:0")
            final_result = sess.graph.get_tensor_by_name("FC-DenseNet56/logits/BiasAdd:0")
            output_image = sess.run(final_result, feed_dict={input_x: input_x })
            return output_image
           

運作上面代碼就可以得到網絡的輸出結果.