天天看點

9. 載入訓練好的模型參數并用于測試集

儲存模型

saver = tf.train.Saver()

with tf.Session() as sess:
 	sess.run(init)
 	for epoch in range(21):
 		...
 	saver.save(sess, DIR + 'projector/projector/a_model.ckpt')

           

使用模型

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    # 列印準确率(不準确),因為此時權重矩陣為0,偏置值為0
    print("未載入模型時準确率:" + str(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})))
    # 載入已經訓練好的模型參數(權值矩陣、偏置值)
    saver.restore(sess, DIR + 'projector/projector/a_model.ckpt')
    # 列印準确率,此時權值矩陣、偏置值均不為0
    print("載入模型後準确率:" + str(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})))