天天看点

TensorFlow图像分类教程

深度学习算法与计算机硬件性能的发展,使研究人员和企业在图像识别、语音识别、推荐引擎和机器翻译等领域取得了巨大的进步。六年前,视觉模式识别领域取得了第一个超凡的成果。两年前,Google大脑团队开发了TensorFlow,并将深度学习巧妙的应用于各个领域。现在,TensorFlow则超越了很多用于深度学习的复杂工具。

利用TensorFlow,你可以获得具有强大能力的复杂功能,其强大的基石来自于TensorFlow的易用性。

在这个由两部分组成的系列中,我将讲述如何快速的创建一个应用于图像识别的卷积神经网络。TensorFlow计算步骤是并行的,可对其配置进行逐帧视频分析,也可对其扩展进行时间感知视频分析。

本系列文章直接切入关键的部分,只需要对命令行和Python有最基本的了解,就可以在家快速地创建一些令你激动不已的项目。本文不会深入探讨TensorFlow的工作原理,如果你想了解更多,我会提供大量额外的参考资料。本系列所有的库和工具都是免费开源的软件。

<b>工作原理</b><b></b>

本教程旨在把一个事先被放到训练过的类别里的图片,通过运行一个命令以识别该图像具体属于哪个类别。步骤如下图所示:

TensorFlow图像分类教程

• 标注:管理训练数据。例如花卉,将雏菊的图像放到“雏菊”目录下,将玫瑰放到“玫瑰”目录下等等,将尽可能多的不同种类的花朵按照类别不同放在不同的目录下。如果我们不标注“蕨类植物”,那么分类器永远也不会返回“蕨类植物”。这需要每个类型的很多样本,因此这一步很重要,并且很耗时。(本文使用预先标记好的数据以提高效率)

• 训练:将标记好的数据(图像)提供给模型。有一个工具将随机抓取一批图像,使用模型猜测每种花的类型,测试猜测的准确性,重复执行,直到使用了大部分训练数据为止。最后一批未被使用的图像用于计算该训练模型的准确性。

• 分类:在新的图像上使用模型。例如,输入:IMG207.JPG,输出:雏菊。这个步骤快速简单,且衡量的代价小。

<b>训练和分类</b><b></b>

本教程将训练一个用于识别不同类型花朵的图像分类器。深度学习需要大量的训练数据,因此,我们需要大量已分类的花朵图像。值得庆幸的是,另外一个模型在图像收集和分类这方面做得非常出色,所以我们使用这个带有脚本的已分类数据集,它有现成且完全训练过的图像分类模型,重新训练模型的最后几层以达到我们想要的结果,这种技术称为迁移学习。

直到我们做了这个约20分钟的训练,Inception才知道如何识别雏菊和郁金香,这就是深度学习中的“学习”部分。

<b>安装</b><b></b>

在很多TensorFlow教程中最先且唯一依赖的就是Docker(应该表明这是个合理的开始)。我也更喜欢这种安装TensorFlow的方法,因为不需要安装一系列的依赖项,这可以保持主机(笔记本或桌面)比较干净。

<b>Bootstrap TensorFlow</b><b></b>

安装Docker后,我们准备启动一个训练和分类的TensorFlow容器。在硬盘上创建一个2GB空闲空间的工作目录,创建一个名为local的子目录,并记录完整路径。

<b>docker run </b><b>-v</b><b> </b><b>/</b><b>path</b><b>/</b><b>to</b><b>/</b><b>local:</b><b>/</b><b>notebooks</b><b>/</b><b>local</b><b> </b><b>--rm</b><b> </b><b>-it</b><b> </b><b>--name</b><b> tensorflow </b><b></b>

<b>tensorflow</b><b>/</b><b>tensorflow:nightly </b><b>/</b><b>bin</b><b>/</b><b>bash</b><b></b>

下面是命令解析:

<b>-v /path/to/local:/notebooks/local</b>将刚创建的local目录挂载到容器中适当的位置。如果使用RHEL、Fedora或其他支持SELinux的系统,添加:Z允许容器访问目录。

<b>--rm </b>退出时令docker删除容器

<b>-it </b>连接输入输出,实现交互。

<b>--name tensorflow</b>将容器命名为tensorflow,而不是sneaky_chowderhead或任何Docker定义的随机名字。

<b>tensorflow/tensorflow:nightly</b>从Docker Hub (公共图像存储库)运行tensorflow/tensorflow的<b>nightly</b> 图像,而不是最新的图像(默认为最近建立/可用图像)。使用<b>nightly</b>图像而不是latest图像,是因为(在写入时)latest包含的一个bug会破坏TensorBoard,这是我们稍后需要的一个数据可视化工具。

<b>/bin/bash</b>指定运行Bash shell,而不运行系统默认命令。

<b>训练模型</b><b></b>

在容器中运行下述命令,对训练数据进行下载和完整性检查。

<b>curl </b><b>-O</b><b> http:</b><b>//</b><b>download.tensorflow.org</b><b>/</b><b>example_images</b><b>/</b><b>flower_photos.tgz</b><b></b>

<b>echo</b><b> </b><b>'db6b71d5d3afff90302ee17fd1fefc11d57f243f  flower_photos.tgz'</b><b> </b><b>|</b><b> sha1sum </b><b>-c</b><b></b>

如果没有看到“flower_photos.tgz”信息:说明文件不正确。如果上诉curl 或sha1sum步骤失败,请手动下载训练数据包并解压(SHA-1 校验码:db6b71d5d3afff90302ee17fd1fefc11d57f243f)到本地主机的local目录下。

现在把训练数据放好,然后对再训练脚本进行下载和完整性检查。

<b>mv</b><b> flower_photos.tgz local</b><b>/</b><b></b>

<b>cd</b><b> </b><b>local</b><b></b>

<b>curl </b><b>-O</b><b> https:</b><b>//</b><b>raw.githubusercontent.com</b><b>/</b><b>tensorflow</b><b>/</b><b>tensorflow</b><b>/</b><b>10cf65b48e1b2f16eaa82</b><b></b>

<b>6d2793cb67207a085d0</b><b>/</b><b>tensorflow</b><b>/</b><b>examples</b><b>/</b><b>image_retraining</b><b>/</b><b>retrain.py</b><b></b>

<b>echo</b><b> </b><b>'a74361beb4f763dc2d0101cfe87b672ceae6e2f5  retrain.py'</b><b> </b><b>|</b><b> sha1sum </b><b>-c</b><b></b>

确认retrain.py有正确的内容,你应该看到retrain.py: OK.。

最后,开始学习!运行再训练脚本。

<b>python retrain.py --image_dir flower_photos --output_graph output_graph.pb </b><b></b>

<b>--output_labels output_labels.txt</b><b></b>

如果遇到如下错误,忽略它:

TypeError: not all arguments converted during string formatting Logged from file

tf_logging.py, line 82.

随着retrain.py 的运行,训练图像会自动的分批次训练、测试和验证数据集。

请注意控制台输出的最后一行:

<b>INFO:tensorflow:Final</b><b> </b><b>test</b><b> </b><b>accuracy =</b><b> </b><b>89.1%</b><b> </b><b>(</b><b>N</b><b>=</b><b>340</b><b>)</b><b></b>

这说明我们已经得到了一个模型:给定一张图像,10次中有9次可正确猜出是五种花朵类型中的哪一种。由于提供给训练过程的随机数不同,分类的精确度也会有所不同。

<b>分类</b><b></b>

再添加一个小脚本,就可以将新的花朵图像添加到模型中,并输出测试结果。这就是图像分类。

将下述脚本命名为 classify.py保存在本地local目录:

<b>import</b><b> tensorflow </b><b>as</b><b> tf</b><b>,</b><b> </b><b>sys</b><b></b>

<b> </b><b></b>

<b>image_path </b><b>=</b><b> </b><b>sys</b><b>.</b><b>argv[</b><b>1</b><b>]</b><b></b>

<b>graph_path </b><b>=</b><b> </b><b>'output_graph.pb'</b><b></b>

<b>labels_path </b><b>=</b><b> </b><b>'output_labels.txt'</b><b></b>

<b># Read in the image_data</b><b></b>

<b>image_data </b><b>=</b><b> tf.</b><b>gfile</b><b>.</b><b>FastGFile(</b><b>image_path</b><b>,</b><b> </b><b>'rb'</b><b>)</b><b>.</b><b>read()</b><b></b>

<b># Loads label file, strips off carriage return</b><b></b>

<b>label_lines </b><b>=</b><b> </b><b>[</b><b>line.</b><b>rstrip()</b><b> </b><b>for</b><b> line</b><b></b>

<b>    </b><b>in</b><b> tf.</b><b>gfile</b><b>.</b><b>GFile(</b><b>labels_path</b><b>)]</b><b></b>

<b># Unpersists graph from file</b><b></b>

<b>with</b><b> tf.</b><b>gfile</b><b>.</b><b>FastGFile(</b><b>graph_path</b><b>,</b><b> </b><b>'rb'</b><b>)</b><b> </b><b>as</b><b> f:</b><b></b>

<b>    graph_def </b><b>=</b><b> tf.</b><b>GraphDef()</b><b></b>

<b>    graph_def.</b><b>ParseFromString(</b><b>f.</b><b>read())</b><b></b>

<b>    _ </b><b>=</b><b> tf.</b><b>import_graph_def(</b><b>graph_def</b><b>,</b><b> name</b><b>=</b><b>''</b><b>)</b><b></b>

<b># Feed the image_data as input to the graph and get first prediction</b><b></b>

<b>with</b><b> tf.</b><b>Session()</b><b> </b><b>as</b><b> sess:</b><b></b>

<b>    softmax_tensor </b><b>=</b><b> sess.</b><b>graph</b><b>.</b><b>get_tensor_by_name(</b><b>'final_result:0'</b><b>)</b><b></b>

<b>    predictions </b><b>=</b><b> sess.</b><b>run(</b><b>softmax_tensor</b><b>,</b><b> </b><b></b>

<b>    </b><b>{</b><b>'DecodeJpeg/contents:0'</b><b>: image_data</b><b>})</b><b></b>

<b>    </b><b># Sort to show labels of first prediction in order of confidence</b><b></b>

<b>    top_k </b><b>=</b><b> predictions</b><b>[</b><b>0</b><b>]</b><b>.</b><b>argsort()[</b><b>-</b><b>len</b><b>(</b><b>predictions</b><b>[</b><b>0</b><b>])</b><b>:</b><b>][</b><b>::-</b><b>1</b><b>]</b><b></b>

<b>    </b><b>for</b><b> node_id </b><b>in</b><b> top_k:</b><b></b>

<b>         human_string </b><b>=</b><b> label_lines</b><b>[</b><b>node_id</b><b>]</b><b></b>

<b>         score </b><b>=</b><b> predictions</b><b>[</b><b>0</b><b>][</b><b>node_id</b><b>]</b><b></b>

<b>         </b><b>print</b><b>(</b><b>'%s (score = %.5f)'</b><b> % </b><b>(</b><b>human_string</b><b>,</b><b> score</b><b>))</b>

为了测试你自己的图像,保存在local目录下并命名为test.jpg,运行(在容器内) <b>python classify.py test.jpg</b>。输出结果如下:

<b>sunflowers </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.78311</b><b>)</b><b></b>

<b>daisy </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.20722</b><b>)</b><b></b>

<b>dandelion </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00605</b><b>)</b><b></b>

<b>tulips </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00289</b><b>)</b><b></b>

<b>roses </b><b>(</b><b>score </b><b>=</b><b> </b><b>0.00073</b><b>)</b>

数据说明了一切!模型确定图像中的花朵是向日葵的准确度为78.311%。数值越高表明匹配度越高。请注意,只能有一个匹配类型。多标签分类则需要另外一个不同的方法。

分类脚本中的图表加载代码已经被破坏,在这里,我用graph_def = tf.GraphDef()等作为图表加载代码。

利用零基础知识和一些代码,我们建了一个相当好的花卉图像分类器,在现有的笔记本电脑上每秒大约可以处理5张图像。

    希望你能够继续关注本博客的系列博文。

以上为译文。

<b>文章原标题《</b>Learn how to classify images with TensorFlow<b>》</b><b>,译者:</b><b>Mags,审校:袁虎。</b>

继续阅读