
Flower Image Classification (Part 4): Training the Model

I originally wanted to use TensorBoard to watch how the loss curve changes, but I never got my code changes right; if any of you have it working, please let me know, since I haven't bothered to fix it. As with the earlier parts, remember to change the paths in the code below to your own.
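For anyone who wants to try wiring this up, the usual TensorFlow 1.x way to get loss and accuracy curves into TensorBoard is to register scalar summaries before tf.summary.merge_all() is called. The lines below are only a minimal sketch, assuming the loss and acc tensors defined in the script that follows; I have not tested them against this model.

# Optional sketch: scalar summaries for TensorBoard, placed after loss/acc are defined
# and before summary_op = tf.summary.merge_all()
tf.summary.scalar('loss', loss)        # recorded each time summary_op is evaluated
tf.summary.scalar('accuracy', acc)
# The FileWriters in the script then write these values to logs_train_dir / logs_val_dir.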

# Imports
import os
import time
import warnings

import numpy as np
import tensorflow as tf

import input_data
import model

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow info/warning log output
warnings.filterwarnings('ignore')

# Variable declarations
N_CLASSES = 5   # five flower classes
IMG_W = 64      # images are resized to IMG_W x IMG_H; larger sizes make training slower
IMG_H = 64
BATCH_SIZE = 25
CAPACITY = 250  # queue capacity used by input_data.get_batch
MAX_STEP = 5000
learning_rate = 0.0005  # learning rate, kept small


train_dir = 'D:/flower_photos/input_data2/train'  # directory of training samples
val_dir = 'D:/flower_photos/input_data2/val'      # directory of validation samples
logs_train_dir = 'D:/save2/train'                 # where training logs and checkpoints are stored
logs_val_dir = 'D:/save2/val'                     # where validation logs are stored

train, train_label= input_data.get_files(train_dir)
val, val_label = input_data.get_files(val_dir)
# training images and labels
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# validation images and labels
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

# Placeholder for the dropout parameter (fed 0.45 during training, 1.0 during evaluation below)
dropout_placeholder = tf.placeholder(tf.float32)
# Placeholder flagging whether the network is in training mode (1) or evaluation mode (0)
is_training = tf.placeholder(tf.float32)

logits = model.inference(x, BATCH_SIZE, N_CLASSES, dropout_placeholder, is_training)
loss = model.losses(logits, y_)
acc = model.evaluation(logits, y_)
train_op = model.trainning(loss, learning_rate)

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    val_writer = tf.summary.FileWriter(logs_val_dir)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            # fetch a training batch and run one optimization step
            tra_images, tra_labels = sess.run([train_batch, train_label_batch])
            _, tra_loss, tra_acc = sess.run([train_op, loss, acc],
                                            feed_dict={x: tra_images, y_: tra_labels, dropout_placeholder: 0.45, is_training: 1})
            # every 100 steps: print training metrics and write a training summary
            if step % 100 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels, dropout_placeholder: 0.45, is_training: 1})
                train_writer.add_summary(summary_str, step)

            # every 100 steps: evaluate on a validation batch and write a validation summary
            if step % 100 == 0:
                val_images, val_labels = sess.run([val_batch, val_label_batch])
                val_loss, val_acc = sess.run([loss, acc],
                                             feed_dict={x: val_images, y_: val_labels, dropout_placeholder: 1.0, is_training: 0})
                print('** val loss = %.2f, val accuracy = %.2f%%  **' % (val_loss, val_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict={x: val_images, y_: val_labels, dropout_placeholder: 1.0, is_training: 0})
                val_writer.add_summary(summary_str, step)

            if step % 100 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
    coord.join(threads)
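
Once scalar summaries like the ones sketched near the top are in place, the curves written by train_writer and val_writer can be viewed by pointing TensorBoard at the log directory, for example tensorboard --logdir=D:/save2, and opening the address it prints (by default http://localhost:6006) in a browser.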

           

The save folder (D:/save2/train here) now holds the trained model checkpoints; they will be needed later when we test the model.
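As a quick preview of how those checkpoints get used, here is a minimal sketch of restoring the latest one. It assumes the graph is rebuilt with the same model.inference call and placeholders as in the training script above (is_training / dropout_placeholder are just the names used here); the actual test script is covered later in this series.

# Minimal sketch: rebuild the graph and restore the newest checkpoint from D:/save2/train
import tensorflow as tf
import model

N_CLASSES, IMG_W, IMG_H, BATCH_SIZE = 5, 64, 64, 25

x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
dropout_placeholder = tf.placeholder(tf.float32)
is_training = tf.placeholder(tf.float32)
logits = model.inference(x, BATCH_SIZE, N_CLASSES, dropout_placeholder, is_training)

saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint('D:/save2/train')  # logs_train_dir from the training script
    saver.restore(sess, ckpt)
    # the weights are now loaded; logits can be evaluated with sess.run(...) on new images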
