- model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
- ckpt.data : 保存模型中每个权重的数值
1、获取需要持久化模型的输出节点名称,通常可以在正常的ckpt模型推断代码中找到:
如图1,假设需要持久化yolo_model,yolo_model推断的输出为pred_feature_maps;如图2所示pred_feature_maps中包含tensor的名称分别是:
yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3;图1 模型推断
图2 节点名称
2、获取节点名称后通过ckpt文件持久化模型,生成pb文件
from tensorflow.python.framework import graph_util
import tensorflow as tf
input_path = "./weights/yolov3.ckpt"
output_path = "./yolov3.pb"
## ckpt文件持久化模型
def freeze_graph(input_path, output_path):
#节点名称
output_node_names = "yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3"
saver = tf.train.import_meta_graph(input_path+".meta", clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
with tf.Session() as sess:
saver.restore(sess, input_path)
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=input_graph_def, # = sess.graph_def,
output_node_names=output_node_names.split(","))
with tf.gfile.GFile(output_path, 'wb') as fgraph:
fgraph.write(output_graph_def.SerializeToString())
3、读取生成的pb文件,并打印节点名称:
graph_path = "./yolov3.pb"
## 读取pb文件
def read_pb(graph_path):
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(graph_path, 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
for node in graph_def.node:
print(node.name)
图3 pb模型中节点名称
4、利用生成的pb文件实现推断:
graph_path = "./yolov3.pb"
## pb图模型推断
def pbInference(graph_path):
## 导入图模型
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(graph_path, 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
### PB图模型中节点名称
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
isess = tf.InteractiveSession()
## 默认是Placeholder可以在pb文件中查看节点名称,也可以在tensor_name_list这个变量中查看
images_placeholder = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
feat1 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_1:0")
feat2 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_2:0")
feat3 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_3:0")
image = cv2.imread("./images/000001.jpg")
image, resize_ratio, dw, dh = letterbox_resize(image, 416, 416)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = np.asarray(image, np.float32)
image = image[np.newaxis, :] / 255.
feat1_, feat2_, feat3_ = isess.run([feat1, feat2, feat3], feed_dict={images_placeholder: image})
print(feat1_.shape,feat2_.shape,feat3_.shape)
最的推断结果的维度与图2中的一致:
图4 推断结果的特征维度
代码的网盘链接:(提取码:i23a)
https://pan.baidu.com/s/1EB9IOf_azDc2QxnSo6NK3Apan.baidu.com