TensorFlow: Saving and Loading a Model (freeze method)
Hardware: NVIDIA GTX 1080
Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0
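I. Background
freeze: a TensorFlow checkpoint spreads a model over three files. The graph is in .meta, and the variable values are in .index and .data-*. Freezing merges these into a single .pb file by converting the variables into constants. The result is usually smaller, because only the nodes needed to compute the chosen outputs are kept. It is also easier to port, because loading it needs no checkpoint restore.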
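The background above starts from checkpoint files, but the saving code below freezes straight from a live session. If a model has already been saved with tf.train.Saver, a sketch like the following could freeze it afterwards. This is an illustration under assumptions, not the author's code:

- The helper names save_checkpoint and freeze_checkpoint and the temporary directory are hypothetical.
- The try/except import shim is added so the same sketch runs on the 1.4.0 release used here and also on TensorFlow 2.x through tf.compat.v1.

```python
import os
import tempfile

try:
    # TensorFlow 2.x (and late 1.x): use the 1.x graph/session API
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except (ImportError, AttributeError):
    # older 1.x releases such as 1.4.0 have no tf.compat.v1
    import tensorflow as tf


def save_checkpoint(ckpt_prefix):
    """Hypothetical helper: build the toy graph and save it as a checkpoint."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.int32, name='x')
        y = tf.placeholder(tf.int32, name='y')
        b = tf.Variable(1, name='b')
        tf.add(tf.multiply(x, y), b, name='output')
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # writes <prefix>.meta (graph), <prefix>.index and
            # <prefix>.data-00000-of-00001 (variable values), plus 'checkpoint'
            saver.save(sess, ckpt_prefix)


def freeze_checkpoint(ckpt_prefix, pb_path, output_names):
    """Hypothetical helper: restore a checkpoint and write one frozen .pb file."""
    graph = tf.Graph()
    with graph.as_default():
        # rebuild the graph structure from the .meta file
        saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        with tf.Session(graph=graph) as sess:
            # load the variable values from the .index/.data files
            saver.restore(sess, ckpt_prefix)
            # replace variables with constants, keep only what output_names need
            frozen = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), output_names)
    with open(pb_path, 'wb') as f:
        f.write(frozen.SerializeToString())
    return frozen


workdir = tempfile.mkdtemp()
prefix = os.path.join(workdir, 'model.ckpt')
save_checkpoint(prefix)
frozen_def = freeze_checkpoint(prefix, os.path.join(workdir, 'model.pb'), ['output'])
print(sorted(os.listdir(workdir)))
print(sorted({node.op for node in frozen_def.node}))
```

The checkpoint files stay on disk next to model.pb. Only the .pb file is needed to load the frozen model later.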
II. Code
1. Saving the model
import tensorflow as tf
from tensorflow.python.framework import graph_util
# input
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
# variable b to be saved, e.g. a weight or bias
b = tf.Variable(1, name='b')
# output
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output')  # name='output' is required: convert_variables_to_constants refers to the node by this name
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # convert variables to constants, keeping only the nodes needed to compute 'output'
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
    # write the pb file
    with tf.gfile.FastGFile('model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    # test: 10 * 3 + 1 = 31
    print(sess.run(output, feed_dict={x: 10, y: 3}))
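To see what convert_variables_to_constants actually changed, one option is to compare the graph definition before and after freezing. The sketch below rebuilds the same toy graph (names x, y, b, output) in a fresh tf.Graph and freezes it in memory without writing a file. It then prints the op type of node b on both sides and the value baked into the constant.

Some details are assumptions or vary by version:

- The exact op names depend on the TensorFlow version. For example, the variable op is VariableV2 in 1.x and VarHandleOp under TensorFlow 2.x.
- The try/except import shim is only there so the sketch also runs on TensorFlow 2.x.
- tensor_util.MakeNdarray is an internal module, imported the same way the post imports graph_util.

```python
try:
    # TensorFlow 2.x (and late 1.x): use the 1.x graph/session API
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except (ImportError, AttributeError):
    # older 1.x releases such as 1.4.0 have no tf.compat.v1
    import tensorflow as tf
from tensorflow.python.framework import tensor_util

graph = tf.Graph()
with graph.as_default():
    # same toy graph as the saving code
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    output = tf.add(tf.multiply(x, y), b, name='output')
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        before = graph.as_graph_def()
        frozen = tf.graph_util.convert_variables_to_constants(sess, before, ['output'])

before_ops = {node.name: node.op for node in before.node}
after_nodes = {node.name: node for node in frozen.node}

# the node named 'b' changes from a variable op to a Const op
print('b before freezing:', before_ops['b'])
print('b after freezing: ', after_nodes['b'].op)
# the variable's value at freeze time is stored inside the Const node
b_value = tensor_util.MakeNdarray(after_nodes['b'].attr['value'].tensor)
print('value stored in b:', b_value)
# nodes that 'output' does not depend on (initializers, assign ops) are dropped
print('dropped nodes:', sorted(set(before_ops) - set(after_nodes)))
```

Because node b keeps its name after freezing, the loading code can still fetch it as 'b:0'. It is now a constant, not a variable.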
2. Loading the model
import tensorflow as tf
from tensorflow.python.platform import gfile
sess = tf.Session()
# import the frozen graph into the session's graph
with gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')
# no initializer is needed: freezing turned every variable into a constant
# input
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
# b is now a constant node; its value (1) was baked in when the model was frozen
print(sess.run('b:0'))
# output
output = sess.graph.get_tensor_by_name('output:0')
# test: 5 * 5 + 1 = 26
print(sess.run(output, feed_dict={input_x: 5, input_y: 5}))
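The loading code imports with name='' so the tensors keep their original names (x:0, y:0, output:0). The sketch below shows an alternative for cases where a separate tf.Graph and a name prefix are wanted, for example to load two frozen models side by side. tf.import_graph_def can prefix every node name and return the requested tensors directly through return_elements.

Some parts are hypothetical or added for portability:

- The helper names write_frozen_model and load_frozen_model are hypothetical.
- The prefix 'frozen' and the temporary file path are also hypothetical.
- The try/except import shim only makes the sketch runnable on TensorFlow 2.x as well.

```python
import os
import tempfile

try:
    # TensorFlow 2.x (and late 1.x): use the 1.x graph/session API
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except (ImportError, AttributeError):
    # older 1.x releases such as 1.4.0 have no tf.compat.v1
    import tensorflow as tf


def write_frozen_model(pb_path):
    """Hypothetical helper: build and freeze the toy graph into pb_path."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.int32, name='x')
        y = tf.placeholder(tf.int32, name='y')
        b = tf.Variable(1, name='b')
        tf.add(tf.multiply(x, y), b, name='output')
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            frozen = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), ['output'])
    with open(pb_path, 'wb') as f:
        f.write(frozen.SerializeToString())


def load_frozen_model(pb_path, prefix='frozen'):
    """Hypothetical helper: import a frozen .pb into its own graph under a prefix."""
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # return_elements hands back the tensors in the order requested
        x, y, output = tf.import_graph_def(
            graph_def, return_elements=['x:0', 'y:0', 'output:0'], name=prefix)
    return graph, x, y, output


pb_path = os.path.join(tempfile.mkdtemp(), 'model.pb')
write_frozen_model(pb_path)
graph, x, y, output = load_frozen_model(pb_path)
print(output.name)  # frozen/output:0
with tf.Session(graph=graph) as sess:
    result = sess.run(output, feed_dict={x: 5, y: 5})
print(result)  # 26
```

With a prefix, get_tensor_by_name('output:0') would no longer find the tensor; it would have to be 'frozen/output:0'. Using return_elements avoids spelling those names out by hand.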
For any questions, add my only QQ account: 2258205918 (name: samylee)!
Or my only WeChat (VX) account: samylee_csdn