天天看點

卷積神經網絡處理圖像識别(三)

本篇接着上一篇來介紹卷積神經網絡的訓練(即反向傳播)和應用。

卷積神經網絡處理圖像識别(三)

訓練神經網絡和儲存訓練結果的代碼如下:

import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference

# Directory and base file name of the checkpoints written by train().
MODEL_SAVE_PATH ="E:/Python36/my tensorflow/CNN/model_path/"
MODEL_NAME = "MNIST_CNNmodel.ckpt"
# Runs at import time, so every module that imports this file prints the path too.
print(os.path.join(MODEL_SAVE_PATH, MODEL_NAME))
# Number of training examples requested from the data set per training step.
BATCH_SIZE  =100
# Initial learning rate and decay factor handed to tf.train.exponential_decay.
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
# Scale passed to the L2 regularizer.
REGULARIZATION_RATE = 0.0001
# Decay of the exponential moving average kept over the trainable variables.
MOVING_AVERAGE_DECAY = 0.99
# Number of iterations of the training loop.
TRAINING_STEPS = 20000

def train(mnist):
    """Train the CNN on MNIST, save checkpoints, and plot the training curves.

    Builds the loss (mean cross entropy plus everything collected under
    ``'losses'``), an exponentially decayed learning rate and an exponential
    moving average of the trainable variables, then runs ``TRAINING_STEPS``
    gradient-descent steps. Every 25 steps the validation accuracy is printed
    and a checkpoint is written under ``MODEL_SAVE_PATH``. After the loop the
    test accuracy is printed, the graph is exported for TensorBoard, and two
    matplotlib plots (batch loss and validation accuracy) are shown.

    Args:
        mnist: Data set object as returned by ``input_data.read_data_sets``
            with one-hot labels; its ``train``, ``validation`` and ``test``
            splits are read.
    """
    x = tf.placeholder(tf.float32,
                       [None,
                        CNN_MNIST_inference.IMAGE_HEIGHT,
                        CNN_MNIST_inference.IMAGE_WIDTH,
                        CNN_MNIST_inference.NUM_CHANNELS], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')
    # L2 regularization.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # Forward pass that is optimized: training mode, with the regularizer.
    y = CNN_MNIST_inference.inference(x, True, regularizer, None, reuse=False)
    global_step = tf.Variable(0, trainable=False)
    # Exponential moving average over all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Forward pass used only to measure accuracy, built on the averaged
    # weights. It is built in evaluation mode (train flag False, no
    # regularizer) instead of repeating the training-mode call.
    # NOTE(review): inference() lives in CNN_MNIST_inference and is not
    # visible here. This presumes the flag enables training-only behaviour
    # (e.g. dropout) and that a regularizer adds terms to the 'losses'
    # collection, which the previous second call would then have added
    # twice -- confirm against CNN_MNIST_inference.
    average_y = CNN_MNIST_inference.inference(x, False, None, variable_averages, reuse=True)

    # Loss: the labels are one-hot, so argmax turns them into class indices.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    tf.add_to_collection('losses', cross_entropy_mean)
    loss = tf.add_n(tf.get_collection('losses'))

    # Learning rate with staircase decay: it drops by LEARNING_RATE_DECAY
    # once every (num_examples / BATCH_SIZE) steps.
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # One op that performs the gradient step and refreshes the moving averages.
    train_op = tf.group(train_step, variables_averages_op)
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()  # checkpoint writer

    with tf.Session() as sess:
        tf.global_variables_initializer().run()  # actually initialize the variables

        # The data set stores flat images; reshape them to the network input shape.
        validation_set = np.reshape(mnist.validation.images,
                                    [-1,
                                     CNN_MNIST_inference.IMAGE_HEIGHT,
                                     CNN_MNIST_inference.IMAGE_WIDTH,
                                     CNN_MNIST_inference.NUM_CHANNELS])
        validate_feed = {x: validation_set, y_: mnist.validation.labels}  # validation set

        test_set = np.reshape(mnist.test.images,
                              [-1,
                               CNN_MNIST_inference.IMAGE_HEIGHT,
                               CNN_MNIST_inference.IMAGE_WIDTH,
                               CNN_MNIST_inference.NUM_CHANNELS])
        test_feed = {x: test_set, y_: mnist.test.labels}  # test set

        steps = []  # only for plot
        accs = []  # only for plot
        losses = []  # only for plot
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            xs = np.reshape(xs,
                            [BATCH_SIZE,
                             CNN_MNIST_inference.IMAGE_HEIGHT,
                             CNN_MNIST_inference.IMAGE_WIDTH,
                             CNN_MNIST_inference.NUM_CHANNELS])

            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})

            if i % 25 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)  # validation accuracy
                steps.append(step)
                accs.append(validate_acc * 100)
                losses.append(loss_value)
                # loss_value comes from the training batch fed just above,
                # so the message calls it a training loss.
                print("After %d training steps, validation dataset accuracy after this batch is %g%%, training loss on this batch is %g" % (step, validate_acc * 100, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model is %g%%" %
              (TRAINING_STEPS, test_acc * 100))
        # Export the graph for TensorBoard; close the writer so the event
        # file is flushed and its handle released.
        writer = tf.summary.FileWriter("E://TensorBoard//test", sess.graph)
        writer.close()

        saver.save(sess, r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt")
    # only for plot
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    plt.subplot(211)
    plt.plot(steps, losses, color="red")
    plt.scatter(steps, losses, s=20, color="red")
    plt.xlabel("訓練的步數(Batch數)")
    plt.ylabel("訓練batch上的Loss(含L2正則Loss)")
    plt.subplot(212)
    plt.plot(steps, accs, color="green")
    plt.scatter(steps, accs, s=20, color="green")
    yticks = mtick.FormatStrFormatter("%.3f%%")
    plt.gca().yaxis.set_major_formatter(yticks)
    plt.xlabel("step")
    plt.ylabel("驗證集上的預測準确率")
    plt.show()
 
def main(argv = None):
    """Read the MNIST data set with one-hot labels and run train() on it.

    Args:
        argv: Unused; accepted so the function can be invoked by tf.app.run().
    """
    dataset = input_data.read_data_sets(
        r"E:\Python36\my tensorflow\MNIST_data",
        one_hot=True,
    )
    train(dataset)

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    tf.app.run() # calls main()

複制

下面是訓練Batch上的總Loss和驗證集上準確率的收斂趨勢圖。由於我的電腦性能不好,所以我大幅度削減了待訓練參數的個數。儘管如此,2000輪訓練之後,在驗證集的5000張圖片上的預測正確率已達98.3%。如若不削減參數,準確率可達99.4%。

卷積神經網絡處理圖像識别(三)

下面的代碼利用訓練好的卷積神經網絡模型來評估它在驗證集上的準確率(正式訓練時可以不做這一評估,從而節省訓練時間),並用它來識別單張圖片。

import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference
import CNN_MNIST_train
import matplotlib.pyplot as plt

def evaluate(mnist):
    """Print the validation accuracy of the most recent checkpoint.

    A fresh graph is built with the network in evaluation mode, the variables
    are restored from their moving-average counterparts stored in the newest
    checkpoint under ``CNN_MNIST_train.MODEL_SAVE_PATH``, and the accuracy on
    the validation split is printed. If no checkpoint is found a message is
    printed instead.

    Args:
        mnist: Data set object as returned by ``input_data.read_data_sets``
            with one-hot labels; only its ``validation`` split is read.
    """
    with tf.Graph().as_default():
        image_shape = [None,
                       CNN_MNIST_inference.IMAGE_HEIGHT,
                       CNN_MNIST_inference.IMAGE_WIDTH,
                       CNN_MNIST_inference.NUM_CHANNELS]
        x = tf.placeholder(tf.float32, image_shape, name='x-input')
        y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')

        # Flat validation images reshaped to the network's input layout.
        validation_images = np.reshape(mnist.validation.images,
                                       [-1,
                                        CNN_MNIST_inference.IMAGE_HEIGHT,
                                        CNN_MNIST_inference.IMAGE_WIDTH,
                                        CNN_MNIST_inference.NUM_CHANNELS])
        validate_feed = {x: validation_images, y_: mnist.validation.labels}

        logits = CNN_MNIST_inference.inference(x, False, None, None, reuse=False)
        predicted_class = tf.argmax(logits, 1)
        true_class = tf.argmax(y_, 1)
        hits = tf.equal(predicted_class, true_class)
        accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

        # Map every variable to its moving-average name in the checkpoint.
        ema = tf.train.ExponentialMovingAverage(CNN_MNIST_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        with tf.Session() as sess:
            # Locate the newest checkpoint in the model directory.
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            saver.restore(sess, ckpt.model_checkpoint_path)
            # The step count is the text after the last "-" of the file name.
            file_name = ckpt.model_checkpoint_path.rsplit('/', 1)[-1]
            global_step = file_name.rsplit("-", 1)[-1]  # a str, hence %s below
            accuary_score = sess.run(accuracy, feed_dict=validate_feed)
            print("After %s training steps, validation accuary = %g" % (global_step, accuary_score))
                
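# Create all input tensors and constants inside the same
# `with tf.Graph().as_default():` block so that they belong to one graph.
# Otherwise the inputs are placed in the graph TensorFlow creates by default,
# which is not the graph opened by a new `with tf.Graph().as_default():`,
# and mixing the two raises an error.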
def recognize(input_x):
    """Print the class the restored network predicts for ``input_x``.

    The network is built in the current default graph, because ``input_x``
    already lives there, and its variables are restored from the newest
    checkpoint under ``CNN_MNIST_train.MODEL_SAVE_PATH``. If no checkpoint is
    found a message is printed instead.

    Args:
        input_x: Input tensor created in the default graph and passed
            straight to ``CNN_MNIST_inference.inference``.
    """
    with tf.get_default_graph().as_default():
        logits = CNN_MNIST_inference.inference(input_x, False, None, None, reuse=False)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # Locate the newest checkpoint in the model directory.
            checkpoint = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (checkpoint and checkpoint.model_checkpoint_path):
                print('No checkpoint file found')
                return

            saver.restore(sess, checkpoint.model_checkpoint_path)
            label_op = tf.argmax(logits, 1)
            first_label = sess.run(label_op)[0]
            print("predicted_label: ", first_label)
                
def plotImage(path):
    """Show the JPEG file at ``path`` as a grayscale image.

    Only used to display the picture that is about to be recognized.

    Args:
        path: Path of the JPEG file to display.
    """
    # Read through a context manager so the file handle is closed.
    with tf.gfile.FastGFile(path, "rb") as image_file:
        image_rawdata = image_file.read()
    # channels=1 makes the decoder emit a single grayscale channel, so the
    # two-dimensional reshape below also works for colour JPEGs.
    img_data = tf.image.decode_jpeg(image_rawdata, channels=1)
    # Returns the tensor unchanged when it already is float32.
    img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
    with tf.Session() as sess:
        image_data = img_data.eval()  # numpy array; eval() needs an active session
    # Drop the channel axis: imshow expects a 2-D array for a gray image.
    image_2d = image_data.reshape(image_data.shape[0], image_data.shape[1])
    plt.imshow(image_2d, cmap='gray')
    plt.show()
    
def main(argv=None):
    """Evaluate the trained model, then display and classify one image.

    Prints the validation accuracy via evaluate(), decodes a single JPEG,
    reshapes it to the network's input shape, shows it with plotImage() and
    prints the predicted label via recognize().

    Args:
        argv: Unused; accepted so the function can be invoked by tf.app.run().
    """
    mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data", one_hot=True)
    evaluate(mnist)  # accuracy on the validation set

    # Input image.
    image_path = r"E:\Python36\MNIST picture\test\50.jpg"
    # Read through a context manager so the file handle is closed.
    with tf.gfile.FastGFile(image_path, "rb") as image_file:
        image_rawdata = image_file.read()
    img_data = tf.image.decode_jpeg(image_rawdata)
    # Convert unconditionally: the call returns the tensor unchanged when it
    # already is float32, and img_data is then always defined for the
    # reshape below (it used to be bound only inside an `if`).
    img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)

    # Reshape the image to the input shape the network expects.
    input_x = tf.reshape(img_data, [1,
                                    CNN_MNIST_inference.IMAGE_HEIGHT,
                                    CNN_MNIST_inference.IMAGE_WIDTH,
                                    CNN_MNIST_inference.NUM_CHANNELS])
    plotImage(image_path)
    recognize(input_x)

# Script entry point: only runs when the file is executed directly.
if __name__ =="__main__":
    #tf.app.run() # calls main()
    main()# main() is invoked directly here rather than through tf.app.run()

複制