
TensorFlow: saving and loading a model (freeze method)


Hardware: NVIDIA GTX 1080

Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0

I. Background

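Freezing merges the three checkpoint (ckpt) files (.meta for the graph, .index and .data for the variable values) into a single .pb file and converts the variables into constants. The result is smaller and easier to port to other environments.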
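As a rough illustration of what "converting variables to constants" means, the plain-Python sketch below uses a hypothetical dict-based graph (not TensorFlow's GraphDef) that mirrors the model in the code section: output = x * y + b. The hypothetical `freeze` function copies each variable's value from a separate checkpoint dict into a `Const` node with the same name, so the graph structure and the weights end up in one serializable object. It only illustrates the idea; it is not how TensorFlow implements freezing.

```python
import json


def freeze(graph, variable_values):
    """Return a copy of `graph` where every Variable node is replaced by a
    Const node (same name) holding its value from `variable_values`."""
    frozen = {}
    for name, node in graph.items():
        if node["op"] == "Variable":
            # the weight is baked into the graph itself
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            # copy the node (and its input list) unchanged
            frozen[name] = {**node, "inputs": list(node["inputs"])}
    return frozen


# Hypothetical graph for output = x * y + b
graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "Placeholder", "inputs": []},
    "b": {"op": "Variable", "inputs": []},
    "mul": {"op": "Mul", "inputs": ["x", "y"]},
    "output": {"op": "Add", "inputs": ["mul", "b"]},
}
checkpoint = {"b": 1}  # stands in for the values stored in the ckpt data files

frozen = freeze(graph, checkpoint)
serialized = json.dumps(frozen, sort_keys=True)  # one self-contained blob, like model.pb
print(frozen["b"])
```

TensorFlow's converter likewise keeps the original node name, which is why the loading code can still fetch `'b:0'` after freezing.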

II. Code

1. Saving the model

import tensorflow as tf
from tensorflow.python.framework import graph_util

# inputs
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')

# variable b to be saved (stands in for a weight or bias)
b = tf.Variable(1, name='b')

# output
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output')  # the name 'output' is required: it is referenced when freezing and when loading

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # replace the variables with constants holding their current values,
    # keeping only the nodes that 'output' depends on
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])

    # write the frozen graph to a single .pb file
    with tf.gfile.FastGFile('model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # test: 10 * 3 + 1 = 31
    print(sess.run(output, feed_dict={x: 10, y: 3}))
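Passing `['output']` to `convert_variables_to_constants` matters because the converter keeps only the nodes that the listed output nodes depend on. The plain-Python sketch below shows that pruning step as a backward walk over node inputs. The hypothetical `extract_subgraph` helper and the node names (roughly what TensorFlow 1.x generates for the script above) are assumptions for illustration, not TensorFlow's own code.

```python
from collections import deque


def extract_subgraph(inputs_of, output_names):
    """Keep only the nodes reachable from `output_names` by following inputs
    backwards. `inputs_of` maps node name -> list of input node names."""
    keep = set()
    queue = deque(output_names)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(inputs_of[name])
    # return the kept nodes in their original order
    return [name for name in inputs_of if name in keep]


# Hypothetical node list roughly mirroring the saving script's graph
inputs_of = {
    "x": [],
    "y": [],
    "b/initial_value": [],
    "b": [],
    "b/Assign": ["b", "b/initial_value"],
    "b/read": ["b"],
    "Mul": ["x", "y"],
    "output": ["Mul", "b/read"],
    "init": ["b/Assign"],  # simplified: in TensorFlow this is a control dependency
}

kept = extract_subgraph(inputs_of, ["output"])
print(kept)
```

Because `init`, `b/Assign` and `b/initial_value` are not reachable from `output`, they are dropped from the frozen graph.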
           

2. Loading the model

import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()

# read the frozen graph and import it into the session's graph
with gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')  # name='' keeps the original node names

# no initialization needed: freezing turned b into a constant, so the graph has no variables

# inputs
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

# check the saved value of b (stands in for a weight or bias); prints 1
print(sess.run('b:0'))

# output
output = sess.graph.get_tensor_by_name('output:0')

# test: 5 * 5 + 1 = 26
print(sess.run(output, feed_dict={input_x: 5, input_y: 5}))
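The strings passed to `get_tensor_by_name` follow TensorFlow's `<op_name>:<output_index>` convention, so `'x:0'` is the first output of the op named `x`. Because the graph is imported with `name=''`, the names stay unchanged; with the default `name` of `tf.import_graph_def` (`"import"`), every name would get an `import/` prefix. The two helpers below are hypothetical plain-Python functions, not part of TensorFlow, that make both rules explicit.

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into ('op_name', index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        raise ValueError(f"expected 'op_name:index', got {tensor_name!r}")
    return op_name, int(index)


def imported_tensor_name(tensor_name, prefix="import"):
    """Name a tensor gets after tf.import_graph_def(..., name=prefix)."""
    op_name, index = split_tensor_name(tensor_name)
    full_op = f"{prefix}/{op_name}" if prefix else op_name
    return f"{full_op}:{index}"


print(imported_tensor_name("x:0", prefix=""))  # name='' as in the script above -> x:0
print(imported_tensor_name("output:0"))        # default prefix -> import/output:0
```

If the graph were imported with the default prefix, the script would have to fetch `'import/x:0'`, `'import/y:0'` and `'import/output:0'` instead.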
           

For any questions, contact QQ 2258205918 (name: samylee)

or WeChat: samylee_csdn
