注意:本部分的ppt來源于中國大學mooc網站:https://www.icourse163.org/learn/ZUCC-1206146808?tid=1206445215&from=study#/learn/content?type=detail&id=1211168244&cid=1213754001
#MNIST手寫數字識别資料集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#了解MNIST手寫數字識别資料集
print("訓練集train數量:",mnist.train.num_examples,
",驗證集 validation數量:",mnist.validation.num_examples,
",測試集 test 數量:",mnist.test.num_examples)
print("train image shape:",mnist.train.images.shape,
"labels shape:",mnist.train.labels.shape)
全部源碼:
#MNIST手寫數字識别資料集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
import os
#讀取相關的資料
mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#定義待輸入資料的占位符
#mnist中每張圖檔共有28*28=784個像素點
x=tf.placeholder(tf.float32,[None,784],name="X")
#0-9一共10個數字====》10個類别
y=tf.placeholder(tf.float32,[None,10],name="y")
#定義模型變量
'''
在本案例中,以正态分布的随機數初始化權重W,以常數0初始化偏置b
'''
#定義變量
w=tf.Variable(tf.random_normal([784,10]),name="w")
b=tf.Variable(tf.zeros([10]),name="b")
#用單個神經元建構神經網絡
forward=tf.matmul(x,w)+b#前向計算
pred=tf.nn.softmax(forward)#softmax分類
#設定訓練參數
train_epochs=100#訓練輪數
batch_size=100#單次訓練樣本數(批次大小)
total_batch=int(mnist.train.num_examples/batch_size)#一輪訓練有多少批次
display_step=1#顯示粒度
learning_rate=0.01#學習率
#定義損失函數(定義交叉商的損失函數)
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
#梯度下降優化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
#檢查預測類别tf.argmax(ored,1)與實際類别tf.argmax(y,1)的比對情況
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#準确率,将布爾值轉化為浮點數,并計算平均值
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
sess=tf.Session()#聲明會話
init=tf.global_variables_initializer()#變量初始化
sess.run(init)
#訓練模型的儲存
#儲存模型的粒子
save_step=5
#建立儲存模型檔案的目錄
ckpt_dir="./ckpt_dir/"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
#聲明完所有 變量之後,使用tf.train.Saver()
saver=tf.train.Saver()
#模型訓練
#開始訓練
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys=mnist.train.next_batch(batch_size)#讀取批次資料
sess.run(optimizer,feed_dict={x:xs,y:ys})#執行批次訓練
#total_batch個批次訓練完成後,使用驗證資料計算誤差與準确率:驗證沒有分批
loss,acc=sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
#列印訓練過程中的詳細資訊
if(epoch+1)%display_step==0:
print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9}".format(loss),"Accuracy=","{:.4f}".format(acc))
if(epoch+1)%save_step==0:
saver.save(sess,os.path.join(ckpt_dir,'mnist_h256_model_{:06d}.ckpt'.format(epoch+1)))
print('mnist_h256_model_{:06d}.ckpt'.format(epoch+1))
#對訓練的模型進行儲存
saver.save(sess,os.path.join(ckpt_dir,'mnist_h256_model_ckpt'))
print("Train Finished")
#評估模型
#完成訓練之後,在測試集上評估模型的準确率
def accu_test():
accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
def acc_validation():
#完成訓練之後在驗證集上評估模型的準确率
acc_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Validation Accuracy:",acc_validation)
def acc_train():
#完成訓練之後,在訓練集上評估模型的準确率
acc_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Train Accuracy:",acc_train)
#定義資料可視化
def plot_image_labels_prediction(images,labels,prediction,index,num=10):
'''
image:圖像清單
labels:标簽清單
prediction:預測值清單
index:從第index個開始顯示
num:一次顯示多少副圖檔,預設的話一次顯示10個
'''
fig=plt.gcf()#擷取目前圖表,Get Current Figure
fig.set_size_inches(10,12)#1英寸等于1.54cm
if num>25:
num=25#設定最多顯示25個子圖
for i in range(0,num):
ax=plt.subplot(5,5,i+1)#擷取目前要處理的子圖
ax.imshow(np.reshape(images[index],(28,28)),cmap="binary")
title="label="+str(np.argmax(labels[index]))#建構該圖上要顯示的title資訊
if len(prediction)>0:
title+=",predict="+str(prediction[index])
ax.set_title(title,fontsize=10)#顯示圖上的title資訊
ax.set_xticks([])#不顯示坐标軸
ax.set_yticks([])
index+=1
plt.show()
#MNIST手寫數字識别資料集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#獨熱編碼如何取值
print(mnist.train.labels[1])
#argmax()取出獨熱編碼中最大值的下标
print(np.argmax(mnist.train.labels[1]))
一紙高中萬裡風,寒窗讀破華堂空。
莫道長安花看盡,由來枝葉幾相同?