天天看点

tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)

tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)
tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)
  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
  • ckpt.data : 保存模型中每个权重的数值

1、获取需要持久化模型的输出节点名称,通常可以在正常的ckpt模型推断代码中找到:

如图1,假设需要持久化yolo_model,yolo_model推断的输出为pred_feature_maps;如图2所示pred_feature_maps中包含tensor的名称分别是:

yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3;
tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)

图1 模型推断

tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)

图2 节点名称

2、获取节点名称后通过ckpt文件持久化模型,生成pb文件

from tensorflow.python.framework import graph_util
import tensorflow as tf

input_path = "./weights/yolov3.ckpt"
output_path = "./yolov3.pb"
## ckpt文件持久化模型
def freeze_graph(input_path, output_path):
    #节点名称
    output_node_names = "yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3" 
    saver = tf.train.import_meta_graph(input_path+".meta", clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_path)
        output_graph_def = graph_util.convert_variables_to_constants(
                           sess=sess,
                           input_graph_def=input_graph_def,   # = sess.graph_def,
                           output_node_names=output_node_names.split(","))

        with tf.gfile.GFile(output_path, 'wb') as fgraph:
            fgraph.write(output_graph_def.SerializeToString())
           

3、读取生成的pb文件,并打印节点名称:

graph_path = "./yolov3.pb"
## 读取pb文件
def read_pb(graph_path):
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    for node in graph_def.node:
        print(node.name)
           
tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)

图3 pb模型中节点名称

4、利用生成的pb文件实现推断:

graph_path = "./yolov3.pb"
## pb图模型推断
def pbInference(graph_path):
    ## 导入图模型
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    ### PB图模型中节点名称
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

    isess = tf.InteractiveSession()
    ## 默认是Placeholder可以在pb文件中查看节点名称,也可以在tensor_name_list这个变量中查看
    images_placeholder = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
    feat1 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_1:0")
    feat2 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_2:0")
    feat3 = tf.get_default_graph().get_tensor_by_name("yolov3/yolov3_head/feature_map_3:0")

    image = cv2.imread("./images/000001.jpg")
    image, resize_ratio, dw, dh = letterbox_resize(image, 416, 416)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = np.asarray(image, np.float32)
    image = image[np.newaxis, :] / 255.
    feat1_, feat2_, feat3_ = isess.run([feat1, feat2, feat3], feed_dict={images_placeholder: image})
    print(feat1_.shape,feat2_.shape,feat3_.shape)
           

最的推断结果的维度与图2中的一致:

tensorboard ckpt pb 模型的输出节点_tensorflow中ckpt转pb文件(模型持久化)

图4 推断结果的特征维度

代码的网盘链接:(提取码:i23a)

https://pan.baidu.com/s/1EB9IOf_azDc2QxnSo6NK3A​pan.baidu.com

继续阅读