天天看點

CIFAR-10 圖像識别

零、學習目标

  1. tensorflow 資料讀取原理
  2. 深度學習資料增強原理

一、CIFAR-10資料集簡介

是用于普通物體識别的小型資料集,一共包含 10個類别 的 RGB彩×××片(包含:(飛機、汽車、鳥類、貓、鹿、狗、蛙、馬、船、卡車)。圖檔大小均為 3232像素**,資料集中一共有 50000 張訓練圖檔和 1000*** 張測試圖檔。部分代碼來自于tensorflow官方,以下表格列出了所需的官方代碼。

檔案 用途
cifar10.py 建立CIFAR-1O預測模型
cifar10_input.py 在tensorflow中讀入CIFAR-10訓練圖檔
cifar10_input_test.py cifar10_input 的測試用例檔案
cifar10_train.py 使用單個GPU或CPU訓練模型
cifar10_train_multi_gpu.py 使用多個gpu訓練模型
cifar10_eval.py 在測試集上測試模型的性能

二、下載下傳CIFAR-10資料

在工程根目錄建立 cifar10_download.py ,輸入如下代碼建立下載下傳資料的程式:

# 引入目前目錄中已經編寫好的cifar10子產品
import cifar10
# 引入tensorflow
import tensorflow as tf

# 定義全局變量存儲器,可用于指令行參數的處理
# tf.app.flags.FLAGS 是tensorflow 内部的一個全局變量存儲器
FLAGS = tf.app.flags.FLAGS
# 在cifar10 子產品中預先定義了cifar-10的資料存儲路徑,修改資料存儲路徑
FLAGS.data_dir = 'cifar10_data/'
# 如果資料不存在,則下載下傳
cifar10.maybe_download_and_extract()           

執行完這段代碼後,CIFAR-10資料集會下載下傳到目錄 cifar10_data 目錄下。預設的存儲路徑書 tmp/cifar10_data,定義在代碼檔案cifar10.py中,位置大約在53行附近。

修改完資料存儲路徑後,通過

cifar10.maybe_download_and_extract()

來下載下傳資料,下載下傳期間如果資料存在于資料檔案夾中則跳過下載下傳資料,反之下載下傳資料。下載下傳成功後會提示 Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.

下載下傳完成後,cifar10_data/cifar-10-batches-bin 中将出現8個檔案,名稱和用途如下表:

檔案名
batches.meta.txt 存儲每個類别的英文名
data_batch_1.bin、......、data_batch_5.bin CIFAR-10的五個訓練集,每個訓練集用二進制格式存儲了10000張32*32的彩×××像和圖相對應的标簽,沒個樣本由3073個位元組組成,第一個位元組未标簽,剩下的位元組未圖像資料
test_batch.bin 存儲1000張用于測試的圖像和對應的标簽
readme.html 資料集介紹檔案

三、TensorFlow 讀取資料的機制

  1. 普通方式

    将硬碟上的資料讀入記憶體中,然後提供給CPU或者GPU處理

  2. 記憶體隊列方式

    普通方式讀取資料會出現GPU或CPU在一段時間記憶體在空閑,導緻運算效率降低。利用記憶體隊列,将資料讀取和計算放在兩個線程中,讀取線程隻需向記憶體隊列中讀入檔案,而計算線程隻用從記憶體隊列中讀取計算需要的資料,這樣就解決了GPU或者CPU的空閑問題。

  3. 檔案名隊列+記憶體隊列

    TensorFlow采用 檔案名隊列+記憶體隊列,這種方式可以很好的管理epoch(注1)和避免計算單元的空閑問題。舉個例子,假設有三個資料檔案要執行一次epoch,那麼就在檔案名隊列中放入這三個資料檔案各一次,并且在最後放入的資料檔案後面标注隊列結束。記憶體隊列依次從檔案名隊列的頂部讀取資料檔案,讀到結束标記後就會自動抛出異常,捕獲這個異常後程式就可以結束。如果是執行N次epoch,那麼就把每個資料檔案放入檔案名隊列N次。

    注1:

    對于資料集來說,運作一次epoch就是将資料集裡的所有資料完整的計算一遍,以此類推運作N次epoch就是将資料集裡的所有資料完整的計算N遍

四、建立檔案名隊列和記憶體隊列

  1. 建立檔案名隊列

    利用tensorflow的

    tf.train.string_input_producer()

    (注2) 函數。給函數傳入一個檔案名清單,系統将會轉換未檔案名隊列。tf.train.string_input_producer() 函數有兩個重要的參數,分别是 num_epochs 和 shuffle ,num_epochs表示epochs數,shuffle表示是否打亂檔案名隊列内檔案的順序,如果是True表示不按照檔案名清單添加的順序進入檔案名隊列,如果是Flase表示按照檔案名清單添加的順序進入檔案名隊列。
  2. 建立記憶體隊列

    在tensorflow中不手動建立記憶體隊列,隻需使用

    reader

    對象從檔案名隊列中讀取資料就可以了。

    注2:

    使用tf.train.string_input_producer() 建立完檔案名隊列後,檔案名并沒有被加入到隊列中,如果此時開始計算,會導緻整個系統處于阻塞狀态。

    在建立完檔案名隊列後,應調用

    tf.train.start_queue_runners

    方法才會啟動檔案名隊列的填充,整個程式才能正常運作起來。
  3. 代碼
import tensorflow as tf

# 建立session
with tf.Session() as sess:
    # 要讀取的三張圖檔
    filename = ['img/1.jpg', 'img/2.jpg', 'img/3.jpg']
    # 建立檔案名隊列
    filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    # 初始化變量(epoch)
    tf.local_variables_initializer().run()
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 擷取圖檔儲存資料
        image_data = sess.run(value)
        with open('read/test_%d.jpg' % i, 'wb') as f:
            f.write(image_data)           

五、資料增強

對于圖像資料來說,資料增強方法就是利用平移、縮放、顔色等變換增大訓練集樣本個數,進而達到更好的效果(注3),使用資料增強可以大大提高模型的泛化能力,并且能夠預防過拟合。

常用的圖像資料增強方法如下表

方法 說明
平移 将圖像在一定尺度範圍内平移
旋轉 将圖像在一定角度範圍内旋轉
翻轉 水準翻轉或者上下翻轉圖檔
裁剪 在原圖上裁剪出一塊
縮放 将圖像在一定尺度内放大或縮小
顔色變換 對圖像的RGB顔色空間進行一些變換
噪聲擾動 給圖像加入一些人工生成的噪聲

注3:

使用資料增強的方法前提是,這些資料增強方法不會改變圖像的原有标簽。比如數字6的圖檔,經過上下翻轉之後就變成了數字9的圖檔。

六、CIFAR-10識别模型

建立模型的代碼在cifar10.py檔案額inference函數中,代碼在這裡不進行詳解,讀者可以去閱讀代碼中的注釋。

這裡我們通過以下指令訓練模型:

python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/           

這段指令中 --data_dir cifar10_data/ 表示資料儲存的位置, --train_dir cifar10_train/ 表示儲存模型參數和訓練時日志資訊的位置

七、檢視訓練進度

在訓練的時候我們往往需要知道損失的變化和每層的訓練情況,這個時候我們就會用到tensorflow提供的 TensorBoard。打開一個新的指令行,輸入如下指令:

tensorboard --logdir cifar10_train/           

其中 --logdir cifar10_train/ 表示模型訓練日志儲存的位置,運作該指令後将會在指令行看到類似如下的内容

CIFAR-10 圖像識别

在浏覽器上輸入顯示的位址,即可通路TensorBoard。簡單解釋一下常用的幾個标簽:

标簽
total_loss_1 loss 的變化曲線,變化曲線會根據時間實時變化
learning_rate 學習率變化曲線
global_step 美妙訓練步數的情況,如果訓練速度變化較大,或者越來越慢,就說明程式有可能存在錯誤

八、檢測模型的準确性

在指令行視窗輸入如下指令:

python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/           
tensorboard --logdir cifar10_eval/ --port 6007           

九、代碼下載下傳

繼續閱讀