前段時間搭建了一個分類網絡模型,然後用自己的資料進行了800epoch’的訓練,最後預設生成了三個ckpt檔案.由于要同時運作幾個網絡,是以打算将這個網絡模型進行固化成pb檔案,然後直接調用.
主要包括一下内容:
1.檢視ckpt模型的輸入輸出張量名稱.
2.将ckpt檔案生成pb檔案.
3.檢視生成的pb檔案的輸入輸出節點
4.運作pb檔案,進行網絡預測
1.檢視ckpt網絡的輸入輸出張量名稱
下面是我的網絡訓練後生成的三個ckpt檔案
運作以下代碼檢視自己模型的輸入輸出張量名稱(用于儲存pb檔案時保留這兩個節點)
注意第三行代碼換成自己的ckpt檔案位址,名稱是三個檔案共有的 model.ckpt
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path=os.path.join('/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt')
reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map=reader.get_variable_to_shape_map()
for key in var_to_shape_map:
a=reader.get_tensor(key)
print( 'tensor_name: ',key)
print("a.shape:%s"%[a.shape])
我的代碼運作後結果如下:
如果你的模型輸入輸出張量很容易找到,那這個方法對于你來說應該是可以的,但是我就是在這裡花了一天的時間才找到自己模型的輸入輸出張量,因為這個模型比較複雜,并且這個程式輸出的張量是無序的.我使用的模型是别人語義分割模型的改進,是以模型張量不是很好找.
仍然找不到輸入輸出張量怎麼辦?
我的解決辦法:我通過程式找到了模型的定義,然後在模型的最前端列印出輸入張量,在最後列印出輸出張量
上圖中的第二行代碼是輸出"inputs"張量,倒數第二行代碼輸出"net"張量,然後運作包含模型代碼的程式就可以看到列印出來的兩個張量了.下圖就是運作後的輸出結果,這樣就找到自己模型的輸入和輸出張量了.
2.将ckpt檔案生成pb檔案.
以下是将ckpt檔案轉化為pb檔案的代碼
1)更改node_names後面的值,改成自己想要保留的節點名稱,我保留了首尾兩個,就是上面得到的兩個.
2)input_checkpoint位址改成自己的ckpt檔案的位址.(注意寫到.ckpt)
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型儲存路徑
:return:
'''
# 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
# 直接用最後輸出的節點,可以在tensorboard中查找到,tensorboard隻能在linux中使用
node_names = "Placeholder,FC-DenseNet56/logits/BiasAdd"
saver = tf.train.import_meta_graph(input_checkpoint+".meta" , clear_devices=True)
graph = tf.get_default_graph() # 獲得預設的圖
input_graph_def = graph.as_graph_def() # 傳回一個序列化的圖代表目前的圖
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver.restore(sess, input_checkpoint) #恢複圖并得到資料
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将變量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多個輸出節點,以逗号隔開
with tf.gfile.GFile(output_graph, "wb") as f: #儲存模型
f.write(output_graph_def.SerializeToString()) #序列化輸出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到目前圖有幾個操作節點
input_checkpoint="/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt"#輸入的ckpt檔案位置
output_graph="node.pb"#輸出節點的檔案名
freeze_graph(input_checkpoint,output_graph)
然後就可以得到一個node.pb檔案,名字可以自己更改
3.檢視生成的pb檔案的輸入輸出節點
檢視pb檔案的節點,隻是為了驗證一下,也可以不檢視,代碼如下:
隻需更改你的pb檔案的位址,運作後會得到一個txt檔案,打開可以檢視
import tensorflow as tf
import os
model_dir = './'
model_name = 'new_node.pb'
# 讀取并建立一個圖graph來存放Google訓練好的Inception_v3模型(函數)
def create_graph():
with tf.gfile.FastGFile(os.path.join(
model_dir, model_name), 'rb') as f:
# 使用tf.GraphDef()定義一個空的Graph
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Imports the graph from graph_def into the current default Graph.
tf.import_graph_def(graph_def, name='')
# 建立graph
create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
result_file = os.path.join(model_dir, 'result.txt')
with open(result_file, 'w+') as f:
for tensor_name in tensor_name_list:
f.write(tensor_name+'\n')
下面是我的txt檔案的内容,好像我的pb檔案生成了整個網絡的節點,并不隻是保留了輸入和輸出兩個,看一下輸入輸出節點和剛才檢視的是對應的.
4.運作pb檔案,進行網絡預測
以下是我用自己的pb檔案進行我自己圖檔的預測,代碼如下:
def get_RAC(image_path, pb_file_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_file_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
input_x = sess.graph.get_tensor_by_name("Placeholder:0")
final_result = sess.graph.get_tensor_by_name("FC-DenseNet56/logits/BiasAdd:0")
output_image = sess.run(final_result, feed_dict={input_x: input_x })
return output_image
運作上面代碼就可以得到網絡的輸出結果.