天天看点

TensorFlow教程之完整教程 2.5 TensorFlow运作方式入门

本文档为tensorflow参考文档,本转载已得到tensorflow中文社区授权。

本篇教程的目的,是向大家展示如何利用tensorflow使用(经典)mnist数据集训练并评估一个用于识别手写数字的简易前馈神经网络(feed-forward neural network)。我们的目标读者,是有兴趣使用tensorflow的资深机器学习人士。

因此,撰写该系列教程并不是为了教大家机器学习领域的基础知识。

本教程引用如下文件:

文件

目的

<a href="https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/mnist.py" target="_blank"><code>mnist.py</code></a>

构建一个完全连接(fully connected)的minst模型所需的代码。

<a href="https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py" target="_blank"><code>fully_connected_feed.py</code></a>

利用下载的数据集训练构建好的mnist模型的主要代码,以数据反馈字典(feed dictionary)的形式作为输入模型。

只需要直接运行<code>fully_connected_feed.py</code>文件,就可以开始训练:

<code>python fully_connected_feed.py</code>

mnist是机器学习领域的一个经典问题,指的是让机器查看一系列大小为28x28像素的手写数字灰度图像,并判断这些图像代表0-9中的哪一个数字。

TensorFlow教程之完整教程 2.5 TensorFlow运作方式入门

在<code>run_training()</code>方法的一开始,<code>input_data.read_data_sets()</code>函数会确保你的本地训练文件夹中,已经下载了正确的数据,然后将这些数据解压并返回一个含有<code>dataset</code>实例的字典。

注意:<code>fake_data</code>标记是用于单元测试的,读者可以不必理会。

数据集

<code>data_sets.train</code>

55000个图像和标签(labels),作为主要训练集。

<code>data_sets.validation</code>

5000个图像和标签,用于迭代验证训练准确度。

<code>data_sets.test</code>

10000个图像和标签,用于最终测试训练准确度(trained accuracy)。

在训练循环(training loop)的后续步骤中,传入的整个图像和标签数据集会被切片,以符合每一个操作所设置的<code>batch_size</code>值,占位符操作将会填补以符合这个<code>batch_size</code>值。然后使用<code>feed_dict</code>参数,将数据传入<code>sess.run()</code>函数。

在为数据创建占位符之后,就可以运行<code>mnist.py</code>文件,经过三阶段的模式函数操作:<code>inference()</code>, <code>loss()</code>,和<code>training()</code>。图表就构建完成了。

1.<code>inference()</code> —— 尽可能地构建好图表,满足促使神经网络向前反馈并做出预测的要求。

2.<code>loss()</code> —— 往inference图表中添加生成损失(loss)所需要的操作(ops)。

3.<code>training()</code> —— 往损失图表中添加计算并应用梯度(gradients)所需的操作。

TensorFlow教程之完整教程 2.5 TensorFlow运作方式入门

<code>inference()</code>函数会尽可能地构建图表,做到返回包含了预测结果(output prediction)的tensor。

它接受图像占位符为输入,在此基础上借助relu(rectified linear units)激活函数,构建一对完全连接层(layers),以及一个有着十个节点(node)、指明了输出logtis模型的线性层。

例如,当这些层是在<code>hidden1</code>作用域下生成时,赋予权重变量的独特名称将会是"<code>hidden1/weights</code>"。

每个变量在构建时,都会获得初始化操作(initializer ops)。

最后,程序会返回包含了输出结果的<code>logits</code>tensor。

<code>loss()</code>函数通过添加所需的损失操作,进一步构建图表。

首先,<code>labels_placeholer</code>中的值,将被编码为一个含有1-hot values的tensor。例如,如果类标识符为“3”,那么该值就会被转换为: 

<code>[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</code>

最后,程序会返回包含了损失值的tensor。

注意:交叉熵是信息理论中的概念,可以让我们描述如果基于已有事实,相信神经网络所做的推测最坏会导致什么结果。更多详情,请查阅博文《可视化信息理论

<code>training()</code>函数添加了通过梯度下降(gradient descent)将损失最小化所需的操作。

最后,程序返回包含了训练操作(training op)输出结果的tensor。

一旦图表构建完毕,就通过<code>fully_connected_feed.py</code>文件中的用户代码进行循环地迭代式训练和评估。

<code>tf.graph</code>实例是一系列可以作为整体执行的操作。tensorflow的大部分场景只需要依赖默认图表一个实例即可。

利用多个图表的更加复杂的使用场景也是可能的,但是超出了本教程的范围。

另外,也可以利用<code>with</code>代码块生成<code>session</code>,限制作用域:

<code>session</code>函数中没有传入参数,表明该代码将会依附于(如果还没有创建会话,则会创建新的会话)默认的本地会话。

完成会话中变量的初始化之后,就可以开始训练了。

训练的每一步都是通过用户代码控制,而能实现有效训练的最简单循环就是:

但是,本教程中的例子要更为复杂一点,原因是我们必须把输入的数据根据每一步的情况进行切分,以匹配之前生成的占位符。

执行每一步时,我们的代码会生成一个反馈字典(feed dictionary),其中包含对应步骤中训练所要使用的例子,这些例子的哈希键就是其所代表的占位符操作。

<code>fill_feed_dict</code>函数会查询给定的<code>dataset</code>,索要下一批次<code>batch_size</code>的图像和标签,与占位符相匹配的tensor则会包含下一批次的图像和标签。

然后,以占位符为哈希键,创建一个python字典对象,键值则是其代表的反馈tensor。

这个字典随后作为<code>feed_dict</code>参数,传入<code>sess.run()</code>函数中,为这一步的训练提供输入样例。

在运行<code>sess.run</code>函数时,要在代码中明确其需要获取的两个值:<code>[train_op, loss]</code>。

因为要获取这两个值,<code>sess.run()</code>会返回一个有两个元素的元组。其中每一个<code>tensor</code>对象,对应了返回的元组中的numpy数组,而这些数组中包含了当前这步训练中对应tensor的值。由于<code>train_op</code>并不会产生输出,其在返回的元祖中的对应元素就是<code>none</code>,所以会被抛弃。但是,如果模型在训练中出现偏差,<code>loss</code>tensor的值可能会变成nan,所以我们要获取它的值,并记录下来。

假设训练一切正常,没有出现nan,训练循环会每隔100个训练步骤,就打印一行简单的状态文本,告知用户当前的训练状态。

最后,每次运行<code>summary_op</code>时,都会往事件文件中写入最新的即时数据,函数的输出会传入事件文件读写器(writer)的<code>add_summary()</code>函数。。

事件文件写入完毕之后,可以就训练文件夹打开一个tensorboard,查看即时数据的情况。

TensorFlow教程之完整教程 2.5 TensorFlow运作方式入门

每隔一千个训练步骤,我们的代码会尝试使用训练数据集与测试数据集,对模型进行评估。<code>do_eval</code>函数会被调用三次,分别使用训练数据集、验证数据集合测试数据集。

注意,更复杂的使用场景通常是,先隔绝<code>data_sets.test</code>测试数据集,只有在大量的超参数优化调整(hyperparameter tuning)之后才进行检查。但是,由于mnist问题比较简单,我们在这里一次性评估所有的数据。

在打开默认图表(graph)之前,我们应该先调用<code>get_data(train=false)</code>函数,抓取测试数据集。

在进入训练循环之前,我们应该先调用<code>mnist.py</code>文件中的<code>evaluation</code>函数,传入的logits和标签参数要与<code>loss</code>函数的一致。这样做事为了先构建eval操作。

之后,我们可以创建一个循环,往其中添加<code>feed_dict</code>,并在调用<code>sess.run()</code>函数时传入<code>eval_correct</code>操作,目的就是用给定的数据集评估模型。

<code>true_count</code>变量会累加所有<code>in_top_k</code>操作判定为正确的预测之和。接下来,只需要将正确测试的总数,除以例子总数,就可以得出准确率了。