天天看點

TensorFlow：列印 .ckpt / .pb 模型的節點名稱（tensor_name），以及將 .ckpt 模型轉換為 .pb 模型

轉換模型時，首先要知道模型是從哪個節點輸出的；如果沒有原始碼，就很難弄清楚節點資訊，因此可以先用下面的程式把模型中的節點名稱列印出來。

擷取ckpt模型的節點名稱

import os
from tensorflow.python import pywrap_tensorflow

# Path *prefix* of the checkpoint to inspect (the same prefix that the
# ".meta" file is named after), not one of its .index/.data files.
checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
# Set to True to also dump every variable's stored value (can be huge).
PRINT_VALUES = False

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Sort so the listing is stable across runs and easy to diff/grep.
for key in sorted(var_to_shape_map):
    print("tensor_name: ", key, "shape: ", var_to_shape_map[key])
    if PRINT_VALUES:
        print(reader.get_tensor(key))  # the variable's stored value
           

擷取pb模型的節點名稱

import tensorflow as tf
import os

# Directory and file name of the frozen .pb graph to inspect.
model_dir = './'
model_name = 'model.pb'


def create_graph(pb_path=None):
    """Load a serialized GraphDef from a .pb file into the default graph.

    Args:
        pb_path: Path of the .pb file. Defaults to
            ``os.path.join(model_dir, model_name)``.

    Nodes are imported with ``name=''`` so they keep their original names
    instead of getting an ``import/`` prefix.
    """
    if pb_path is None:
        pb_path = os.path.join(model_dir, model_name)
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # The file is no longer needed once parsed; import outside the `with`.
    tf.import_graph_def(graph_def, name='')
        
create_graph()
# Each entry of GraphDef.node is a NodeDef; collect their names.
_imported_nodes = tf.get_default_graph().as_graph_def().node
tensor_name_list = [node_def.name for node_def in _imported_nodes]
for name in tensor_name_list:
    print(name, '\n')

           

ckpt轉換為pb模型

from tensorflow.python.tools import inspect_checkpoint as chkp
import tensorflow as tf

# Rebuild the graph from the checkpoint's .meta file; clear_devices drops
# device placements recorded at training time.
saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)

# [Key point!] Fill in the output node name(s) here.
# convert_variables_to_constants expects *node* names (e.g. "logits"), not
# tensor names (e.g. "logits:0"); any ":<index>" suffix is stripped below.
output_nodes = ["xxx"]
output_node_names = [node_name.split(':')[0] for node_name in output_nodes]

with tf.Session(graph=tf.get_default_graph()) as sess:
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    input_graph_def = sess.graph.as_graph_def()
    # Replace the graph's variables with constants holding the restored values.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names)
    with open("frozen_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())
           

[新版] ckpt 轉 pb：在程式中重新建構推論圖，還原 checkpoint 後再凍結為 pb

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Checkpoint prefix to restore, and destination path of the frozen graph.
# Both are placeholders ("xxx") — replace with real paths before running.
checkpoint = "model.ckpt-xxx"
graph_file = "xxx.pb"


def return_ops(candidate):
    """Flatten a (possibly nested) list/tuple of tensors into their ops.

    Args:
        candidate: An object exposing ``.op`` (e.g. a tensor), or an
            arbitrarily nested list/tuple of such objects.

    Returns:
        A flat list of the ``.op`` attributes, in depth-first order.
    """
    if not isinstance(candidate, (list, tuple)):
        return [candidate.op]
    collected = []
    for item in candidate:
        collected.extend(return_ops(item))
    return collected


def dump_graph():
    """Rebuild the inference graph, restore ``checkpoint`` and freeze it.

    The frozen GraphDef is written to ``graph_file``. A text summary of the
    input/output tensors is printed and written to ``graph_file + '.info'``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Batch dimension left dynamic; presumably NHWC 224x224 RGB images.
        image_input = setup_input(dtype=tf.float32,
                                  shape=[None, 224, 224, 3],
                                  name='graph_input')
        network_out = model_inference(image_input, 1000)

        summary = gen_info(image_input, network_out)
        print(summary)

        restorer = tf.train.Saver(tf.global_variables())
        output_ops = return_ops(network_out)

        with tf.Session() as session:
            restorer.restore(session, checkpoint)
            # Freeze: variables become constants; keep what the outputs need.
            frozen = graph_util.convert_variables_to_constants(
                session,
                session.graph.as_graph_def(),
                [op.name for op in output_ops],
            )

            with tf.gfile.GFile(graph_file, 'wb') as pb_file:
                pb_file.write(frozen.SerializeToString())

            with open(graph_file + '.info', 'w') as summary_file:
                summary_file.write(summary)


def setup_input(dtype, shape, name=None):
    """Create the graph's input placeholder.

    Args:
        dtype: Element type of the input (e.g. ``tf.float32``).
        shape: Static shape; ``None`` entries are dynamic dimensions.
        name: Optional op name; it becomes the input node name in the
            frozen graph.

    Returns:
        The placeholder tensor.
    """
    # The TF1 constructor is lowercase ``tf.placeholder``; there is no
    # ``tf.Placeholder`` attribute.
    return tf.placeholder(dtype=dtype, shape=shape, name=name)


def gen_info(inp, o):
    """Build a text summary of the model's input and output tensors.

    Args:
        inp: The input tensor (anything with ``.name`` and ``.get_shape()``).
        o: The output tensor, or a (possibly nested) list/tuple of output
            tensors — the same forms ``return_ops`` accepts.

    Returns:
        One ``[input tensor]/[input shape]`` section followed by one
        ``[output tensor]/[output shape]`` section per output tensor,
        in depth-first order.
    """
    def _flatten(candidate):
        # Yield leaf tensors from arbitrarily nested lists/tuples.
        if isinstance(candidate, (list, tuple)):
            for item in candidate:
                for leaf in _flatten(item):
                    yield leaf
        else:
            yield candidate

    info_text = '[input tensor]: {0}\n[input shape]: {1}\n'.format(
        inp.name, inp.get_shape())
    for out in _flatten(o):
        info_text += '[output tensor]: {0}\n[output shape]: {1}\n'.format(
            out.name, out.get_shape())

    return info_text


def model_inference(images, num_classes):
    """Build the network on ``images`` and return its output tensor(s).

    TODO: template stub — replace ``tf.xxx`` and the ``'xxx'`` scope with the
    real model. Presumably the scope/variable names must match those used at
    training time so ``Saver.restore`` can find the checkpoint variables;
    verify against the training code.

    Args:
        images: Input tensor created by ``setup_input``.
        num_classes: Unused by the stub; presumably the size of the final
            classification layer — confirm when filling in the model.
    """
    with tf.variable_scope('xxx'):
        return tf.xxx


if __name__ == "__main__":
    # Script entry point: freeze `checkpoint` into `graph_file`, then report.
    dump_graph()
    print('dump finish!')

           

繼續閱讀