零、學習目标
- tensorflow 資料讀取原理
- 深度學習資料增強原理
一、CIFAR-10資料集簡介
是用于普通物體識别的小型資料集,一共包含 10個類别 的 RGB彩×××片(包含:(飛機、汽車、鳥類、貓、鹿、狗、蛙、馬、船、卡車)。圖檔大小均為 3232像素**,資料集中一共有 50000 張訓練圖檔和 1000*** 張測試圖檔。部分代碼來自于tensorflow官方,以下表格列出了所需的官方代碼。
檔案 | 用途 |
---|---|
cifar10.py | 建立CIFAR-1O預測模型 |
cifar10_input.py | 在tensorflow中讀入CIFAR-10訓練圖檔 |
cifar10_input_test.py | cifar10_input 的測試用例檔案 |
cifar10_train.py | 使用單個GPU或CPU訓練模型 |
cifar10_train_multi_gpu.py | 使用多個gpu訓練模型 |
cifar10_eval.py | 在測試集上測試模型的性能 |
二、下載下傳CIFAR-10資料
在工程根目錄建立 cifar10_download.py ,輸入如下代碼建立下載下傳資料的程式:
# 引入目前目錄中已經編寫好的cifar10子產品
import cifar10
# 引入tensorflow
import tensorflow as tf
# 定義全局變量存儲器,可用于指令行參數的處理
# tf.app.flags.FLAGS 是tensorflow 内部的一個全局變量存儲器
FLAGS = tf.app.flags.FLAGS
# 在cifar10 子產品中預先定義了cifar-10的資料存儲路徑,修改資料存儲路徑
FLAGS.data_dir = 'cifar10_data/'
# 如果資料不存在,則下載下傳
cifar10.maybe_download_and_extract()
執行完這段代碼後,CIFAR-10資料集會下載下傳到目錄 cifar10_data 目錄下。預設的存儲路徑書 tmp/cifar10_data,定義在代碼檔案cifar10.py中,位置大約在53行附近。
修改完資料存儲路徑後,通過
cifar10.maybe_download_and_extract()
來下載下傳資料,下載下傳期間如果資料存在于資料檔案夾中則跳過下載下傳資料,反之下載下傳資料。下載下傳成功後會提示 Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.
下載下傳完成後,cifar10_data/cifar-10-batches-bin 中将出現8個檔案,名稱和用途如下表:
檔案名 | |
---|---|
batches.meta.txt | 存儲每個類别的英文名 |
data_batch_1.bin、......、data_batch_5.bin | CIFAR-10的五個訓練集,每個訓練集用二進制格式存儲了10000張32*32的彩×××像和圖相對應的标簽,沒個樣本由3073個位元組組成,第一個位元組未标簽,剩下的位元組未圖像資料 |
test_batch.bin | 存儲1000張用于測試的圖像和對應的标簽 |
readme.html | 資料集介紹檔案 |
三、TensorFlow 讀取資料的機制
-
普通方式
将硬碟上的資料讀入記憶體中,然後提供給CPU或者GPU處理
-
記憶體隊列方式
普通方式讀取資料會出現GPU或CPU在一段時間記憶體在空閑,導緻運算效率降低。利用記憶體隊列,将資料讀取和計算放在兩個線程中,讀取線程隻需向記憶體隊列中讀入檔案,而計算線程隻用從記憶體隊列中讀取計算需要的資料,這樣就解決了GPU或者CPU的空閑問題。
-
檔案名隊列+記憶體隊列
TensorFlow采用 檔案名隊列+記憶體隊列,這種方式可以很好的管理epoch(注1)和避免計算單元的空閑問題。舉個例子,假設有三個資料檔案要執行一次epoch,那麼就在檔案名隊列中放入這三個資料檔案各一次,并且在最後放入的資料檔案後面标注隊列結束。記憶體隊列依次從檔案名隊列的頂部讀取資料檔案,讀到結束标記後就會自動抛出異常,捕獲這個異常後程式就可以結束。如果是執行N次epoch,那麼就把每個資料檔案放入檔案名隊列N次。
注1:
對于資料集來說,運作一次epoch就是将資料集裡的所有資料完整的計算一遍,以此類推運作N次epoch就是将資料集裡的所有資料完整的計算N遍
四、建立檔案名隊列和記憶體隊列
-
建立檔案名隊列
利用tensorflow的
(注2) 函數。給函數傳入一個檔案名清單,系統将會轉換未檔案名隊列。tf.train.string_input_producer() 函數有兩個重要的參數,分别是 num_epochs 和 shuffle ,num_epochs表示epochs數,shuffle表示是否打亂檔案名隊列内檔案的順序,如果是True表示不按照檔案名清單添加的順序進入檔案名隊列,如果是Flase表示按照檔案名清單添加的順序進入檔案名隊列。tf.train.string_input_producer()
-
建立記憶體隊列
在tensorflow中不手動建立記憶體隊列,隻需使用
對象從檔案名隊列中讀取資料就可以了。reader
注2:
使用tf.train.string_input_producer() 建立完檔案名隊列後,檔案名并沒有被加入到隊列中,如果此時開始計算,會導緻整個系統處于阻塞狀态。
在建立完檔案名隊列後,應調用
方法才會啟動檔案名隊列的填充,整個程式才能正常運作起來。tf.train.start_queue_runners
- 代碼
import tensorflow as tf
# 建立session
with tf.Session() as sess:
# 要讀取的三張圖檔
filename = ['img/1.jpg', 'img/2.jpg', 'img/3.jpg']
# 建立檔案名隊列
filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False)
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
# 初始化變量(epoch)
tf.local_variables_initializer().run()
threads = tf.train.start_queue_runners(sess=sess)
i = 0
while True:
i += 1
# 擷取圖檔儲存資料
image_data = sess.run(value)
with open('read/test_%d.jpg' % i, 'wb') as f:
f.write(image_data)
五、資料增強
對于圖像資料來說,資料增強方法就是利用平移、縮放、顔色等變換增大訓練集樣本個數,進而達到更好的效果(注3),使用資料增強可以大大提高模型的泛化能力,并且能夠預防過拟合。
常用的圖像資料增強方法如下表
方法 | 說明 |
---|---|
平移 | 将圖像在一定尺度範圍内平移 |
旋轉 | 将圖像在一定角度範圍内旋轉 |
翻轉 | 水準翻轉或者上下翻轉圖檔 |
裁剪 | 在原圖上裁剪出一塊 |
縮放 | 将圖像在一定尺度内放大或縮小 |
顔色變換 | 對圖像的RGB顔色空間進行一些變換 |
噪聲擾動 | 給圖像加入一些人工生成的噪聲 |
注3:
使用資料增強的方法前提是,這些資料增強方法不會改變圖像的原有标簽。比如數字6的圖檔,經過上下翻轉之後就變成了數字9的圖檔。
六、CIFAR-10識别模型
建立模型的代碼在cifar10.py檔案額inference函數中,代碼在這裡不進行詳解,讀者可以去閱讀代碼中的注釋。
這裡我們通過以下指令訓練模型:
python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/
這段指令中 --data_dir cifar10_data/ 表示資料儲存的位置, --train_dir cifar10_train/ 表示儲存模型參數和訓練時日志資訊的位置
七、檢視訓練進度
在訓練的時候我們往往需要知道損失的變化和每層的訓練情況,這個時候我們就會用到tensorflow提供的 TensorBoard。打開一個新的指令行,輸入如下指令:
tensorboard --logdir cifar10_train/
其中 --logdir cifar10_train/ 表示模型訓練日志儲存的位置,運作該指令後将會在指令行看到類似如下的内容
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLicmbw5SMTlHdKl2LchDMvwFMx8CX4EDMy8CXt92YugXM4FmLxM3Lc9CX6MHc0RHaiojIsJye.png)
在浏覽器上輸入顯示的位址,即可通路TensorBoard。簡單解釋一下常用的幾個标簽:
标簽 | |
---|---|
total_loss_1 | loss 的變化曲線,變化曲線會根據時間實時變化 |
learning_rate | 學習率變化曲線 |
global_step | 美妙訓練步數的情況,如果訓練速度變化較大,或者越來越慢,就說明程式有可能存在錯誤 |
八、檢測模型的準确性
在指令行視窗輸入如下指令:
python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/
tensorboard --logdir cifar10_eval/ --port 6007