儲存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
...
saver.save(sess, DIR + 'projector/projector/a_model.ckpt')
使用模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
# 列印準确率(不準确),因為此時權重矩陣為0,偏置值為0
print("未載入模型時準确率:" + str(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})))
# 載入已經訓練好的模型參數(權值矩陣、偏置值)
saver.restore(sess, DIR + 'projector/projector/a_model.ckpt')
# 列印準确率,此時權值矩陣、偏置值均不為0
print("載入模型後準确率:" + str(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})))