天天看点

《TensorFlow实战》学习1——softmax regression

作为书中第一个实战例子,“Tensorflow实现Softmax Regression识别手写数字”中使用的网络很简单,因为没有隐含层,都算不上神经网络。我也简单的记录一下这个实例中比较有价值的点吧。

一. 数据集

本书中很多实例都是跑mnist数据集,此数据集很小只有55000张训练图片,10000张测试图片,5000张验证图片,图片的内容是0-9手写数字。图片是28*28的灰度图,空白像素点的值为0,有颜色的像素值全为1。数据集有两种输入格式:1*784或28*28.

二.训练网络

《TensorFlow实战》学习1——softmax regression

训练有3个过程:

1.      提取特征

《TensorFlow实战》学习1——softmax regression

2.      Softmax

《TensorFlow实战》学习1——softmax regression

3.      交叉熵和梯度下降优化

《TensorFlow实战》学习1——softmax regression

   梯度下降调用tf内部的操作即可。

三.测试网络

《TensorFlow实战》学习1——softmax regression

测试网络通过10000张验证图片求accuracy。

四.程序分析

# coding:utf-8

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf

#创建新的会话,图的执行和保存一般是在会话中完成,但这里sess貌似没有用到,这点不太明白。
sess = tf.InteractiveSession()
#placeholder:占位符,提供输入数据的地方,不是tensor.一般只有x,y输入才需要用到。
#[None, 784]:占位符的shape,也是输入tensor的shape,取决于输入数据的shape。784表示列数,None几乎等于输入行数任意。
x = tf.placeholder(tf.float32, [None, 784])

#Variable:持久化保存tensor。tensor本身一用完就会消失,像w,b这种一直迭代的参数需要持久化存在。
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

#tensorflow的nn中有大量神经网络组件,要用某种操作,首先到nn、train这样的模组种找。
#matmul矩阵乘法中的一个ops,最常用的ops有array、矩阵等文件下的ops,可在../tenserflow/python/ops下xx_ops.py中找。
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32,[None, 10])

#计算交叉熵,也就是求loss
#reduce_mean:对每个batch数据结果求均值
#reduce_sum:求和
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

#GradientDescentOptimizer:随机梯度下降(SGD),0.5为其学习率,minimize的参数为最小化的目标tensor.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

#初始化图中变量。
tf.global_variables_initializer().run()

#开始按Batch训练。
for i in range(1000):
    #读入100张图和对应的100个标签。
    batch_xs, batch_ys = mnist.train.next_batch(100)
    #执行优化
    #{x:batch_xs, y_:batch_ys}将数据输入到对应的placeholder中。
    train_step.run({x:batch_xs, y_:batch_ys})

#equal:判断两个参数是否相等,预测是否准确。
#argmax:求各预测数字中概率最大的一个。
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

#计算搜索测试的平均准确率。
#cast:转换类型,将第一个参数的类型转化为第二参数指定的类型。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print accuracy.eval({x:mnist.test.images, y_:mnist.test.labels})
           

继续阅读