天天看點

TensorFlow圖像分類教程

深度學習算法與計算機硬體性能的發展,使研究人員和企業在圖像識别、語音識别、推薦引擎和機器翻譯等領域取得了巨大的進步。六年前,視覺模式識别領域取得了第一個超凡的成果。兩年前,Google大腦團隊開發了TensorFlow,并将深度學習巧妙的應用于各個領域。現在,TensorFlow則超越了很多用于深度學習的複雜工具。

利用TensorFlow,你可以獲得具有強大能力的複雜功能,其強大的基石來自于TensorFlow的易用性。

在這個由兩部分組成的系列中,我将講述如何快速的建立一個應用于圖像識别的卷積神經網絡。TensorFlow計算步驟是并行的,可對其配置進行逐幀視訊分析,也可對其擴充進行時間感覺視訊分析。

本系列文章直接切入關鍵的部分,隻需要對指令行和Python有最基本的了解,就可以在家快速地建立一些令你激動不已的項目。本文不會深入探讨TensorFlow的工作原理,如果你想了解更多,我會提供大量額外的參考資料。本系列所有的庫和工具都是免費開源的軟體。

<b>工作原理</b><b></b>

本教程旨在把一個事先被放到訓練過的類别裡的圖檔,通過運作一個指令以識别該圖像具體屬于哪個類别。步驟如下圖所示:

TensorFlow圖像分類教程

• 标注:管理訓練資料。例如花卉,将雛菊的圖像放到“雛菊”目錄下,将玫瑰放到“玫瑰”目錄下等等,将盡可能多的不同種類的花朵按照類别不同放在不同的目錄下。如果我們不标注“蕨類植物”,那麼分類器永遠也不會傳回“蕨類植物”。這需要每個類型的很多樣本,是以這一步很重要,并且很耗時。(本文使用預先标記好的資料以提高效率)

• 訓練:将标記好的資料(圖像)提供給模型。有一個工具将随機抓取一批圖像,使用模型猜測每種花的類型,測試猜測的準确性,重複執行,直到使用了大部分訓練資料為止。最後一批未被使用的圖像用于計算該訓練模型的準确性。

• 分類:在新的圖像上使用模型。例如,輸入:IMG207.JPG,輸出:雛菊。這個步驟快速簡單,且衡量的代價小。

<b>訓練和分類</b><b></b>

本教程将訓練一個用于識别不同類型花朵的圖像分類器。深度學習需要大量的訓練資料,是以,我們需要大量已分類的花朵圖像。值得慶幸的是,另外一個模型在圖像收集和分類這方面做得非常出色,是以我們使用這個帶有腳本的已分類資料集,它有現成且完全訓練過的圖像分類模型,重新訓練模型的最後幾層以達到我們想要的結果,這種技術稱為遷移學習。

直到我們做了這個約20分鐘的訓練,Inception才知道如何識别雛菊和郁金香,這就是深度學習中的“學習”部分。

<b>安裝</b><b></b>

在很多TensorFlow教程中最先且唯一依賴的就是Docker(應該表明這是個合理的開始)。我也更喜歡這種安裝TensorFlow的方法,因為不需要安裝一系列的依賴項,這可以保持主機(筆記本或桌面)比較幹淨。

<b>Bootstrap TensorFlow</b><b></b>

安裝Docker後,我們準備啟動一個訓練和分類的TensorFlow容器。在硬碟上建立一個2GB空閑空間的工作目錄,建立一個名為local的子目錄,并記錄完整路徑。

<b>docker run </b><b>-v</b><b> </b><b>/</b><b>path</b><b>/</b><b>to</b><b>/</b><b>local:</b><b>/</b><b>notebooks</b><b>/</b><b>local</b><b> </b><b>--rm</b><b> </b><b>-it</b><b> </b><b>--name</b><b> tensorflow </b><b></b>

<b>tensorflow</b><b>/</b><b>tensorflow:nightly </b><b>/</b><b>bin</b><b>/</b><b>bash</b><b></b>

下面是指令解析:

<b>-v /path/to/local:/notebooks/local</b>将剛建立的local目錄挂載到容器中适當的位置。如果使用RHEL、Fedora或其他支援SELinux的系統,添加:Z允許容器通路目錄。

<b>--rm </b>退出時令docker删除容器

<b>-it </b>連接配接輸入輸出,實作互動。

<b>--name tensorflow</b>将容器命名為tensorflow,而不是sneaky_chowderhead或任何Docker定義的随機名字。

<b>tensorflow/tensorflow:nightly</b>從Docker Hub (公共圖像存儲庫)運作tensorflow/tensorflow的<b>nightly</b> 圖像,而不是最新的圖像(預設為最近建立/可用圖像)。使用<b>nightly</b>圖像而不是latest圖像,是因為(在寫入時)latest包含的一個bug會破壞TensorBoard,這是我們稍後需要的一個資料可視化工具。

<b>/bin/bash</b>指定運作Bash shell,而不運作系統預設指令。

<b>訓練模型</b><b></b>

在容器中運作下述指令,對訓練資料進行下載下傳和完整性檢查。

<b>curl </b><b>-O</b><b> http:</b><b>//</b><b>download.tensorflow.org</b><b>/</b><b>example_images</b><b>/</b><b>flower_photos.tgz</b><b></b>

<b>echo</b><b> </b><b>'db6b71d5d3afff90302ee17fd1fefc11d57f243f  flower_photos.tgz'</b><b> </b><b>|</b><b> sha1sum </b><b>-c</b><b></b>

如果沒有看到“flower_photos.tgz”資訊:說明檔案不正确。如果上訴curl 或sha1sum步驟失敗,請手動下載下傳訓練資料包并解壓(SHA-1 校驗碼:db6b71d5d3afff90302ee17fd1fefc11d57f243f)到本地主機的local目錄下。

現在把訓練資料放好,然後對再訓練腳本進行下載下傳和完整性檢查。

<b>mv</b><b> flower_photos.tgz local</b><b>/</b><b></b>

<b>cd</b><b> </b><b>local</b><b></b>

<b>curl </b><b>-O</b><b> https:</b><b>//</b><b>raw.githubusercontent.com</b><b>/</b><b>tensorflow</b><b>/</b><b>tensorflow</b><b>/</b><b>10cf65b48e1b2f16eaa82</b><b></b>

<b>6d2793cb67207a085d0</b><b>/</b><b>tensorflow</b><b>/</b><b>examples</b><b>/</b><b>image_retraining</b><b>/</b><b>retrain.py</b><b></b>

<b>echo</b><b> </b><b>'a74361beb4f763dc2d0101cfe87b672ceae6e2f5  retrain.py'</b><b> </b><b>|</b><b> sha1sum </b><b>-c</b><b></b>

确認retrain.py有正确的内容,你應該看到retrain.py: OK.。

最後,開始學習!運作再訓練腳本。

<b>python retrain.py --image_dir flower_photos --output_graph output_graph.pb </b><b></b>

<b>--output_labels output_labels.txt</b><b></b>

如果遇到如下錯誤,忽略它:

TypeError: not all arguments converted during string formatting Logged from file

tf_logging.py, line 82.

随着retrain.py 的運作,訓練圖像會自動的分批次訓練、測試和驗證資料集。

請注意控制台輸出的最後一行:

<b>INFO:tensorflow:Final</b><b> </b><b>test</b><b> </b><b>accuracy =</b><b> </b><b>89.1%</b><b> </b><b>(</b><b>N</b><b>=</b><b>340</b><b>)</b><b></b>

這說明我們已經得到了一個模型:給定一張圖像,10次中有9次可正确猜出是五種花朵類型中的哪一種。由于提供給訓練過程的随機數不同,分類的精确度也會有所不同。

<b>分類</b><b></b>

再添加一個小腳本,就可以将新的花朵圖像添加到模型中,并輸出測試結果。這就是圖像分類。

将下述腳本命名為 classify.py儲存在本地local目錄:

<b>import</b><b> tensorflow </b><b>as</b><b> tf</b><b>,</b><b> </b><b>sys</b><b></b>

<b> </b><b></b>

<b>image_path </b><b>=</b><b> </b><b>sys</b><b>.</b><b>argv[</b><b>1</b><b>]</b><b></b>

<b>graph_path </b><b>=</b><b> </b><b>'output_graph.pb'</b><b></b>

<b>labels_path </b><b>=</b><b> </b><b>'output_labels.txt'</b><b></b>

<b># Read in the image_data</b><b></b>

<b>image_data </b><b>=</b><b> tf.</b><b>gfile</b><b>.</b><b>FastGFile(</b><b>image_path</b><b>,</b><b> </b><b>'rb'</b><b>)</b><b>.</b><b>read()</b><b></b>

<b># Loads label file, strips off carriage return</b><b></b>

<b>label_lines </b><b>=</b><b> </b><b>[</b><b>line.</b><b>rstrip()</b><b> </b><b>for</b><b> line</b><b></b>

<b>    </b><b>in</b><b> tf.</b><b>gfile</b><b>.</b><b>GFile(</b><b>labels_path</b><b>)]</b><b></b>

<b># Unpersists graph from file</b><b></b>

<b>with</b><b> tf.</b><b>gfile</b><b>.</b><b>FastGFile(</b><b>graph_path</b><b>,</b><b> </b><b>'rb'</b><b>)</b><b> </b><b>as</b><b> f:</b><b></b>

<b>    graph_def </b><b>=</b><b> tf.</b><b>GraphDef()</b><b></b>

<b>    graph_def.</b><b>ParseFromString(</b><b>f.</b><b>read())</b><b></b>

<b>    _ </b><b>=</b><b> tf.</b><b>import_graph_def(</b><b>graph_def</b><b>,</b><b> name</b><b>=</b><b>''</b><b>)</b><b></b>

<b># Feed the image_data as input to the graph and get first prediction</b><b></b>

<b>with</b><b> tf.</b><b>Session()</b><b> </b><b>as</b><b> sess:</b><b></b>

<b>    softmax_tensor </b><b>=</b><b> sess.</b><b>graph</b><b>.</b><b>get_tensor_by_name(</b><b>'final_result:0'</b><b>)</b><b></b>

<b>    predictions </b><b>=</b><b> sess.</b><b>run(</b><b>softmax_tensor</b><b>,</b><b> </b><b></b>

<b>    </b><b>{</b><b>'DecodeJpeg/contents:0'</b><b>: image_data</b><b>})</b><b></b>

<b>    </b><b># Sort to show labels of first prediction in order of confidence</b><b></b>

<b>    top_k </b><b>=</b><b> predictions</b><b>[</b><b>0</b><b>]</b><b>.</b><b>argsort()[</b><b>-</b><b>len</b><b>(</b><b>predictions</b><b>[</b><b>0</b><b>])</b><b>:</b><b>][</b><b>::-</b><b>1</b><b>]</b><b></b>

<b>    </b><b>for</b><b> node_id </b><b>in</b><b> top_k:</b><b></b>

<b>         human_string </b><b>=</b><b> label_lines</b><b>[</b><b>node_id</b><b>]</b><b></b>

<b>         score </b><b>=</b><b> predictions</b><b>[</b><b>0</b><b>][</b><b>node_id</b><b>]</b><b></b>

<b>         </b><b>print</b><b>(</b><b>'%s (score = %.5f)'</b><b> % </b><b>(</b><b>human_string</b><b>,</b><b> score</b><b>))</b>

為了測試你自己的圖像,儲存在local目錄下并命名為test.jpg,運作(在容器内) <b>python classify.py test.jpg</b>。輸出結果如下:

<b>sunflowers </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.78311</b><b>)</b><b></b>

<b>daisy </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.20722</b><b>)</b><b></b>

<b>dandelion </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00605</b><b>)</b><b></b>

<b>tulips </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00289</b><b>)</b><b></b>

<b>roses </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00073</b><b>)</b>

資料說明了一切!模型确定圖像中的花朵是向日葵的準确度為78.311%。數值越高表明比對度越高。請注意,隻能有一個比對類型。多标簽分類則需要另外一個不同的方法。

分類腳本中的圖表加載代碼已經被破壞,在這裡,我用graph_def = tf.GraphDef()等作為圖表加載代碼。

利用零基礎知識和一些代碼,我們建了一個相當好的花卉圖像分類器,在現有的筆記本電腦上每秒大約可以處理5張圖像。

    希望你能夠繼續關注本部落格的系列博文。

以上為譯文。

<b>文章原标題《</b>Learn how to classify images with TensorFlow<b>》</b><b>,譯者:</b><b>Mags,審校:袁虎。</b>

繼續閱讀