
On Using Batch Normalization in TensorFlow

When implementing Batch Normalization in TensorFlow (normalizing each layer's outputs, which stabilizes training and also has a mild regularizing effect), two APIs are mainly involved:

1) tf.nn.moments(x, axes, name=None, keep_dims=False) ⇒ mean, variance

The return values are statistical moments: mean is the first moment and variance is the second central moment. The parameters mean the following:

  • x is the input data, e.g. a tensor of shape [batchsize, height, width, kernels]
  • axes is a list of the dimensions along which to compute the moments, e.g. [0, 1, 2]
  • name is an optional name for the operation
  • keep_dims controls whether the reduced dimensions are kept

Here is an example:

import tensorflow as tf

img = tf.Variable(tf.random_normal([2, 3]))
axis = list(range(len(img.get_shape()) - 1))  # here: [0], the batch dimension
mean, variance = tf.nn.moments(img, axis)

Running this in a session (after variable initialization) prints values like the following; the exact numbers vary from run to run because img is random:

img = [[ 0.69495416  2.08983064 -1.08764684]
       [ 0.31431156 -0.98923939 -0.34656194]]
mean = [ 0.50463283  0.55029559 -0.71710438]
variance = [ 0.0362222   2.37016821  0.13730171]

This example is straightforward: with axis = [0], the function computes a mean and a variance over dimension 0 (the batch dimension), giving one value per column.
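As a cross-check, the same statistics can be reproduced in plain NumPy, using the sample values printed above. One detail worth noting: tf.nn.moments returns the population variance (ddof=0), not the sample variance.

```python
import numpy as np

# Sample values from the moments example above.
img = np.array([[ 0.69495416,  2.08983064, -1.08764684],
                [ 0.31431156, -0.98923939, -0.34656194]])

# tf.nn.moments(img, [0]) corresponds to the per-column mean and the
# *population* variance (ddof=0), not the sample variance (ddof=1).
mean = img.mean(axis=0)
variance = img.var(axis=0)

print(mean)      # ≈ [ 0.50463283  0.55029559 -0.71710438]
print(variance)  # ≈ [ 0.0362222   2.37016821  0.13730171]
```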

2) tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

tf.nn.batch_norm_with_global_normalization(t, m, v, beta, gamma, variance_epsilon, scale_after_normalization, name=None) (an older interface for the same operation)

As the interfaces show, the mean and variance returned by tf.nn.moments are passed directly as arguments to tf.nn.batch_normalization.

Of these parameters, x, mean, and variance are already known: they come from tf.nn.moments. The remaining two, offset and scale, are generally learned during training; offset is usually initialized to 0 and scale to 1, and both have the same shape as mean.
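The computation that tf.nn.batch_normalization performs is y = scale * (x - mean) / sqrt(variance + variance_epsilon) + offset. A minimal NumPy sketch of that formula, with hypothetical input values:

```python
import numpy as np

def batch_normalization_np(x, mean, variance, offset, scale, epsilon):
    # Same formula that tf.nn.batch_normalization applies:
    # y = scale * (x - mean) / sqrt(variance + epsilon) + offset
    return scale * (x - mean) / np.sqrt(variance + epsilon) + offset

# Hypothetical 2-D [batch, values] input.
x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
mean, variance = x.mean(axis=0), x.var(axis=0)

# offset initialized to 0 and scale to 1, as described above.
y = batch_normalization_np(x, mean, variance,
                           offset=np.zeros(2), scale=np.ones(2),
                           epsilon=1e-3)
print(y)  # each column now has (roughly) zero mean and unit variance
```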

import tensorflow as tf

def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """Assumes a 2-D [batch, values] tensor."""
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]
        # Learned affine parameters: scale starts at 1, offset at 0.
        scale = tf.get_variable('scale', [size], initializer=tf.ones_initializer())
        offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer())

        # Running (population) statistics, used at inference time.
        pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False)
        pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False)

        batch_mean, batch_var = tf.nn.moments(x, [0])
        # Exponential moving averages of the batch statistics.
        train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

        def batch_statistics():
            # Update the running statistics, then normalize with the batch ones.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        # `training` is a boolean tensor selecting train vs. inference behavior.
        return tf.cond(training, batch_statistics, population_statistics)
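A side note on train_mean_op and train_var_op above: they maintain exponential moving averages of the batch statistics, so pop_mean and pop_var gradually track the data seen during training. A small standalone sketch of that update rule, with a hypothetical stream of batch means:

```python
# Exponential moving average, as in train_mean_op / train_var_op above.
decay = 0.99
pop_mean = 0.0               # starts at tf.zeros_initializer()
batch_means = [5.0] * 500    # hypothetical: every batch happens to have mean 5.0

for batch_mean in batch_means:
    pop_mean = pop_mean * decay + batch_mean * (1 - decay)

print(pop_mean)  # close to 5.0: the running estimate has tracked the batches
```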
           

References:
[1] https://www.jianshu.com/p/0312e04e4e83
[2] http://blog.csdn.net/lanchunhui/article/details/70792458

