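When implementing Batch Normalization in TensorFlow, two APIs do most of the work. Batch Normalization normalizes the outputs of each network layer. This mainly stabilizes and speeds up training, and it also has a mild regularizing effect that helps against overfitting. The two APIs are: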
1)tf.nn.moments(x, axes, name=None, keep_dims=False) ⇒ mean, variance:
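It computes statistical moments: mean is the first moment and variance is the second central moment. The parameters are:
- x: the input tensor, typically a layer's output with shape [batchsize, height, width, kernels]
- axes: a list of the dimensions to reduce over, e.g. [0, 1, 2]
- name: an optional name for the operation
- keep_dims: whether to keep the reduced dimensions (as size-1 axes) in the result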
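The effect of axes and keep_dims is easiest to see on a fake convolutional output. For a tensor of shape [batchsize, height, width, kernels], reducing over axes [0, 1, 2] gives one mean and one variance per kernel (channel). The NumPy sketch below imitates that computation. The helper name `moments` and the random test data are assumptions for illustration only; this is not the TensorFlow API itself.

```python
import numpy as np

def moments(x, axes, keep_dims=False):
    """Hypothetical NumPy stand-in for tf.nn.moments (not the TF API)."""
    axes = tuple(axes)
    # First moment: the mean over the reduced axes
    mean = np.mean(x, axis=axes, keepdims=keep_dims)
    # Second central moment: mean squared deviation (divides by N, not N - 1)
    centered = x - np.mean(x, axis=axes, keepdims=True)
    variance = np.mean(centered ** 2, axis=axes, keepdims=keep_dims)
    return mean, variance

# A fake conv output: [batchsize=8, height=5, width=5, kernels=4]
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 5, 5, 4))

mean, variance = moments(x, [0, 1, 2])
print(mean.shape, variance.shape)   # one value per kernel

mean_k, variance_k = moments(x, [0, 1, 2], keep_dims=True)
print(mean_k.shape)                 # reduced axes kept with size 1
```

Keeping the dimensions (shape (1, 1, 1, 4) here) is handy when the statistics are later broadcast back against x.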
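Here is an example:
```python
import tensorflow as tf  # TensorFlow 1.x API

img = tf.Variable(tf.random_normal([2, 3]))
# Reduce over every axis except the last; for this 2-D tensor that is [0]
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
```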
The output of one run was as follows. img is random, so your values will differ.
```
img = [[ 0.69495416  2.08983064 -1.08764684]
       [ 0.31431156 -0.98923939 -0.34656194]]
mean = [ 0.50463283  0.55029559 -0.71710438]
variance = [ 0.0362222   2.37016821  0.13730171]
```
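The example is straightforward. For this 2-D input, axis evaluates to [0], so the function computes the mean and variance along dimension 0, giving one mean and one variance per column. Note that the variance is the population variance, divided by N rather than N - 1. For the first column, ((0.69495416 - 0.31431156) / 2)² ≈ 0.0362222.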
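The printed numbers can be checked independently. The short NumPy sketch below recomputes them from the img values shown in the output. It assumes only that NumPy's `var` with its default `ddof=0` matches what tf.nn.moments returns.

```python
import numpy as np

# The img values printed by the TensorFlow example
img = np.array([[0.69495416, 2.08983064, -1.08764684],
                [0.31431156, -0.98923939, -0.34656194]])

mean = img.mean(axis=0)             # per-column mean, same as axes=[0]
variance = img.var(axis=0)          # ddof=0: divides by N, like tf.nn.moments
unbiased = img.var(axis=0, ddof=1)  # divides by N - 1, NOT what tf.nn.moments returns

print(mean)
print(variance)
print(unbiased)
```

The results agree with the TensorFlow output up to float32 rounding. The unbiased estimate is exactly twice as large here because N = 2.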
2) tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
There is also an older variant, tf.nn.batch_norm_with_global_normalization(t, m, v, beta, gamma, variance_epsilon, scale_after_normalization, name=None). TensorFlow marks it as deprecated in favor of tf.nn.batch_normalization.
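The formula tf.nn.batch_normalization applies is given below. It is computed elementwise, with mean, variance, offset and scale broadcast against x. The symbols are our own shorthand: μ for mean, σ² for variance, β for offset, γ for scale and ε for variance_epsilon. TensorFlow's documentation also allows offset or scale to be None, in which case that term is left out.

$$
y = \gamma \, \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$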
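As the signatures show, the mean and variance returned by tf.nn.moments are passed on as arguments to tf.nn.batch_normalization.

Of its parameters, x, mean and variance are already known: mean and variance are the values computed by tf.nn.moments. The other two, offset and scale, are usually trainable. offset is typically initialized to 0 and scale to 1, and both have the same shape as mean. variance_epsilon is a small constant added to the variance to avoid division by zero.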
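The two steps can be chained to show what happens at initialization. The NumPy sketch below computes the batch statistics, then applies the normalization formula with offset = 0 and scale = 1. Each output column then has mean 0 and a standard deviation just under 1, shrunk slightly by epsilon. The helper `batch_normalization` is a hypothetical imitation of the formula, not TensorFlow's implementation, and the input data is made up.

```python
import numpy as np

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon):
    """Hypothetical NumPy imitation of tf.nn.batch_normalization's formula."""
    return scale * (x - mean) / np.sqrt(variance + variance_epsilon) + offset

rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 10))  # [batch, values]

# Step 1: statistics over the batch axis, as tf.nn.moments(x, [0]) would return
mean, variance = x.mean(axis=0), x.var(axis=0)

# Step 2: the trainable parameters at their usual initial values, same shape as mean
offset = np.zeros_like(mean)  # beta, starts at 0
scale = np.ones_like(mean)    # gamma, starts at 1

y = batch_normalization(x, mean, variance, offset, scale, variance_epsilon=1e-3)
print(y.mean(axis=0).round(6))  # each column now has mean close to 0
print(y.std(axis=0).round(3))   # and standard deviation close to 1
```

During training, gradient descent then moves scale and offset away from 1 and 0 if a different output range suits the next layer better.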
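Putting it together, the following batch_norm layer handles 2-D inputs. It normalizes with batch statistics during training and with moving-average (population) statistics at inference:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.cond, ...)

def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """Batch-normalize a 2-D [batch, values] tensor.

    `training` must be a scalar boolean tensor (e.g. tf.placeholder(tf.bool)),
    because tf.cond does not accept a Python bool.
    """
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]
        # Trainable scale (gamma) and offset (beta): start at 1 and 0, same shape as the mean
        scale = tf.get_variable('scale', [size], initializer=tf.ones_initializer())
        offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer())
        # Population statistics, tracked by a moving average and used at inference time
        pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False)
        pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False)

        def batch_statistics():
            # Created inside the branch so the moving-average updates only run in training mode
            batch_mean, batch_var = tf.nn.moments(x, [0])
            train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        return tf.cond(training, batch_statistics, population_statistics)
```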
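The batch_norm function above can be summarized without TensorFlow's graph machinery. During training it normalizes with the batch's own moments and updates an exponential moving average of them. At inference it uses those averages instead. The plain NumPy class below is a hypothetical sketch of that logic; the class and method names are ours, not TensorFlow's. It makes it easy to check that pop_mean and pop_var converge to the data's statistics.

```python
import numpy as np

class BatchNorm1D:
    """Hypothetical NumPy sketch of the batch_norm() logic, for [batch, values] inputs."""

    def __init__(self, size, epsilon=1e-3, decay=0.99):
        self.epsilon, self.decay = epsilon, decay
        self.scale = np.ones(size)      # gamma (trainable in the real model)
        self.offset = np.zeros(size)    # beta (trainable in the real model)
        self.pop_mean = np.zeros(size)  # moving averages, not trained by gradients
        self.pop_var = np.ones(size)

    def _normalize(self, x, mean, var):
        return self.scale * (x - mean) / np.sqrt(var + self.epsilon) + self.offset

    def __call__(self, x, training):
        if training:
            # Like batch_statistics(): normalize with the batch's own moments
            # and fold them into the moving averages
            batch_mean, batch_var = x.mean(axis=0), x.var(axis=0)
            d = self.decay
            self.pop_mean = self.pop_mean * d + batch_mean * (1 - d)
            self.pop_var = self.pop_var * d + batch_var * (1 - d)
            return self._normalize(x, batch_mean, batch_var)
        # Like population_statistics(): use the accumulated estimates
        return self._normalize(x, self.pop_mean, self.pop_var)

rng = np.random.default_rng(2)
bn = BatchNorm1D(size=3)
true_mean = np.array([1.0, -2.0, 0.5])
true_std = np.array([2.0, 0.5, 1.0])
for _ in range(2000):  # many training steps
    bn(rng.normal(true_mean, true_std, size=(32, 3)), training=True)

print(bn.pop_mean)  # close to true_mean
print(bn.pop_var)   # close to true_std ** 2, slightly low since batch variance divides by N

test_batch = rng.normal(true_mean, true_std, size=(1, 3))
print(bn(test_batch, training=False))  # inference works even for a single example
```

Using population statistics at inference is what makes the output for one example independent of the other examples in its batch.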