
batch_norm in TensorFlow, and a look at tf.control_dependencies and tf.GraphKeys.UPDATE_OPS

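I recently used batch_norm in TensorFlow without first understanding how it works internally, used it incorrectly as a result, and got results that did not match my expectations. Afterwards I looked into the problem, and this post summarizes what I found.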

I. The incorrect usage and its result

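At first I only knew about the function tf.layers.batch_normalization, so I called it directly in my model. The function takes a parameter named training, which should be True during training and False during testing. After training finished, however, something odd happened: with training set to True, test accuracy was normal, but with training set to False, test accuracy was very low. The incorrect usage boils down to the following snippet: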

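is_training = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=is_training)
loss = ...
# Mistake: the moving-average update ops are never attached to train_op.
train_op = optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The placeholder has to be fed, otherwise the run fails.
    sess.run(train_op, feed_dict={is_training: True})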

II. batch_normalization

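Here is a brief introduction to batch_normalization first. A schematic of this normalization method and its algorithm are shown in the figures below.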

[Figures: schematic and algorithm of batch normalization]

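In short, for an input batch of shape [batch_num, height, width, channel], the mean and variance are computed channel by channel over all the data in the batch. The input is then normalized with that mean and variance, and the normalized output goes through a linear transform (a scale and a shift) to give the final batch_norm result. In pseudocode: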

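for i in range(channel):
    x = input[:, :, :, i]
    batch_mean = mean(x)          # over batch, height and width
    batch_variance = variance(x)
    x = (x - batch_mean) / sqrt(batch_variance + epsilon)  # epsilon avoids division by zero
    x = scale[i] * x + offset[i]  # per-channel scale and shift
    input[:, :, :, i] = x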
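
The pseudocode can be turned into a small runnable sketch in plain Python. This is only an illustration under a few assumptions: the batch, height and width axes are flattened into a single axis, so the input is a list of rows with one value per channel, and the names `batch_norm_channels` and `epsilon` are made up for this example rather than taken from TensorFlow.

```python
import math

def batch_norm_channels(rows, scale, offset, epsilon=1e-3):
    """Normalize each channel (column) of `rows` over all rows.

    `rows` stands in for an input of shape [batch * height * width, channel].
    """
    num_channels = len(rows[0])
    out = [[0.0] * num_channels for _ in rows]
    for c in range(num_channels):
        column = [row[c] for row in rows]
        mean = sum(column) / len(column)
        variance = sum((v - mean) ** 2 for v in column) / len(column)
        for r, v in enumerate(column):
            normalized = (v - mean) / math.sqrt(variance + epsilon)
            out[r][c] = scale[c] * normalized + offset[c]
    return out

# Channel 0 varies across rows; channel 1 is constant.
rows = [[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]]
result = batch_norm_channels(rows, scale=[1.0, 1.0], offset=[0.0, 0.5])
for row in result:
    print(row)
```

After normalization, channel 0 is centered on zero, and the constant channel 1 collapses to its offset because every value equals the channel mean.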

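In practice, the implementation tracks a running average of the mean and variance of the training data during training, stored as moving_mean and moving_variance, and uses those values in place of the batch statistics at test time. This is what the training parameter controls. The moving statistics are usually updated gradually with a decay coefficient: moving_mean = moving_mean * decay + new_batch_mean * (1 - decay)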
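
To get a feel for how slowly the moving statistics converge, here is a minimal sketch of the update rule in plain Python. The helper name `update_moving_average` and the constant batch mean of 5.0 are assumptions chosen for illustration.

```python
def update_moving_average(moving_value, batch_value, decay=0.99):
    """One step of the update rule: keep `decay` of the old value."""
    return moving_value * decay + batch_value * (1 - decay)

first_step = update_moving_average(0.0, 5.0)

moving_mean = 0.0  # same starting point as a zeros initializer
for _ in range(1000):
    moving_mean = update_moving_average(moving_mean, batch_value=5.0)

print(first_step)
print(moving_mean)
```

With decay=0.99 a single step moves the estimate only 1% of the way toward the batch value, so many training steps are needed before moving_mean is a useful estimate.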

III. Three implementations in TensorFlow

TensorFlow currently offers three ways to do batch_norm.

1. tf.nn.batch_normalization (the lowest-level implementation)

tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)      

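This function is the lowest-level implementation. The caller has to pass in and update mean, variance, scale and offset, so in practice it needs to be wrapped before use. It is generally not recommended, but it is very helpful for understanding how batch_norm works.

An example wrapper is shown below: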

import tensorflow as tf

def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """ Assume nd [batch, N1, N2, ..., Nm, Channel] tensor"""
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[-1]
        scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(0.1))
        offset = tf.get_variable('offset', [size])

        pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False)
        pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, list(range(len(x.get_shape())-1)))

        def batch_statistics():
            # Build the assign ops inside the branch: ops created outside
            # tf.cond would run no matter which branch is taken.
            train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)
        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        return tf.cond(training, batch_statistics, population_statistics)

is_training = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = batch_norm(input, name_scope='batch_norm_nn', training=is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "batch_norm_nn/Model")      

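batch_norm first computes the per-channel mean and var of x and defines the ops that update pop_mean and pop_var. Then, depending on whether it is in the training or the testing phase, it passes either the current batch's mean and var or the mean and var saved during training, together with the newly defined variables scale and offset, to tf.nn.batch_normalization.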

2. tf.layers.batch_normalization

tf.layers.batch_normalization(
    inputs,
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=True,
    scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(),
    beta_regularizer=None,
    gamma_regularizer=None,
    beta_constraint=None,
    gamma_constraint=None,
    training=False,
    trainable=True,
    name=None,
    reuse=None,
    renorm=False,
    renorm_clipping=None,
    renorm_momentum=0.99,
    fused=None,
    virtual_batch_size=None,
    adjustment=None
)      

This is the function I used earlier. The official documentation says:

Note: when training, the moving_mean and moving_variance need to be updated. 
By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. 

Also, be sure to add any batch_normalization ops before getting the update_ops collection. 
Otherwise, update_ops will be empty, and training/inference will not work properly. For example:

x_norm = tf.layers.batch_normalization(x, training=training)

# ...

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

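As you can see, the difference from my earlier incorrect implementation lies mainly in these two lines:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):

The wrapper around tf.nn.batch_normalization in the first method uses a similar technique. The details are explained in the next section.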

3. tf.contrib.layers.batch_norm (slim)

tf.contrib.layers.batch_norm(
    inputs,
    decay=0.999,
    center=True,
    scale=False,
    epsilon=0.001,
    activation_fn=None,
    param_initializers=None,
    param_regularizers=None,
    updates_collections=tf.GraphKeys.UPDATE_OPS,
    is_training=True,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    batch_weights=None,
    fused=None,
    data_format=DATA_FORMAT_NHWC,
    zero_debias_moving_mean=False,
    scope=None,
    renorm=False,
    renorm_clipping=None,
    renorm_decay=0.99,
    adjustment=None
)      

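This function is used in much the same way as tf.layers.batch_normalization. The main difference between the two is the default values of the scale and center parameters, which correspond to the scale and offset of the linear transform applied after the input has been normalized by mean and variance, as described in the overview above. Both default center (the offset) to True, but scale defaults to True in tf.layers.batch_normalization and to False in tf.contrib.layers.batch_norm. In other words, by default tf.contrib.layers.batch_norm does not rescale the normalized input; it only adds an offset. The signatures above also show a different default for the moving-average coefficient: decay=0.999 here versus momentum=0.99 in tf.layers.batch_normalization.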

IV. About tf.GraphKeys.UPDATE_OPS

Two concepts still remain to be introduced: tf.GraphKeys.UPDATE_OPS and tf.control_dependencies.

1. tf.control_dependencies

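Let us start with tf.control_dependencies. This function guarantees that the operations created within its scope run only after the operations passed as its argument have completed. Consider the following example.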

import tensorflow as tf
a_1 = tf.Variable(1)
b_1 = tf.Variable(2)
update_op = tf.assign(a_1, 10)
add = tf.add(a_1, b_1)

a_2 = tf.Variable(1)
b_2 = tf.Variable(2)
update_op = tf.assign(a_2, 10)
with tf.control_dependencies([update_op]):
    add_with_dependencies = tf.add(a_2, b_2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ans_1, ans_2 = sess.run([add, add_with_dependencies])
    print("Add: ", ans_1)
    print("Add_with_dependency: ", ans_2)



Output:
Add:  3
Add_with_dependency:  12      

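Comparing the two additions: in a plain graph, computing add does not go through update_op, so a is 1 at the time of the addition. With tf.control_dependencies, update_op is forced to finish before the add, so a is 10 at the time of the addition. That is why the two results differ.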
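
The same behavior can be mimicked with a toy graph executor in plain Python. This is a deliberately simplified model, not TensorFlow's actual machinery: the `Op` class, the `run` function and `make_graph` are hypothetical names invented for this sketch. The point it illustrates is that fetching a node only executes what that node depends on, and a control input is a dependency that carries no data.

```python
class Op:
    """A node in a toy dataflow graph."""
    def __init__(self, fn, inputs=(), control_inputs=()):
        self.fn = fn
        self.inputs = list(inputs)
        self.control_inputs = list(control_inputs)

def run(op, executed=None):
    """Evaluate `op`: control inputs first, then data inputs, then the op itself."""
    if executed is None:
        executed = {}
    if id(op) in executed:
        return executed[id(op)]
    for dep in op.control_inputs:
        run(dep, executed)
    args = [run(i, executed) for i in op.inputs]
    result = op.fn(*args)
    executed[id(op)] = result
    return result

def make_graph(with_dependency):
    state = {"a": 1, "b": 2}  # plays the role of the two variables

    def assign_a():
        state["a"] = 10

    update_op = Op(assign_a)
    read_a = Op(lambda: state["a"])
    read_b = Op(lambda: state["b"])
    control = [update_op] if with_dependency else []
    return Op(lambda x, y: x + y, inputs=[read_a, read_b], control_inputs=control)

plain_result = run(make_graph(with_dependency=False))
dependent_result = run(make_graph(with_dependency=True))
print("Add:", plain_result)
print("Add_with_dependency:", dependent_result)
```

Without the control input, update_op exists in the graph but nothing fetched depends on it, so it never runs.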

2. tf.GraphKeys.UPDATE_OPS

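tf.GraphKeys.UPDATE_OPS is a collection built into the TensorFlow computation graph. It holds operations that need to finish before the training operation runs, and it is used together with tf.control_dependencies.

For batch_norm, these are the operations that update the moving mean and variance. The following example shows how tf.layers.batch_normalization handles this.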

import tensorflow as tf

is_training = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=is_training)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
# with tf.control_dependencies(update_ops):
    # train_op = optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "batch_norm_layer/Model")
    
 


Output:
 [<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(3,) dtype=float32_ref>, 
  <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(3,) dtype=float32_ref>]      

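The output is the two operations in batch_normalization that update the moving mean and variance. They must be guaranteed to complete before train_op.

TensorFlow's internal implementation adds these two operations to the tf.GraphKeys.UPDATE_OPS collection automatically. tf.contrib.layers.batch_norm exposes an updates_collections parameter whose default value is tf.GraphKeys.UPDATE_OPS, whereas tf.layers.batch_normalization puts the two update operations directly into that collection.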
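
The documentation note quoted earlier warns that the collection must be read after the batch_normalization layers have been created. A toy registry in plain Python shows why. Everything here is a stand-in made up for illustration (`add_to_collection`, `get_collection`, `toy_batch_norm` and the string op names); the one assumption carried over from TensorFlow is that reading a collection returns a copy of its current contents.

```python
UPDATE_OPS = "update_ops"  # stand-in for tf.GraphKeys.UPDATE_OPS
_collections = {}

def add_to_collection(name, value):
    _collections.setdefault(name, []).append(value)

def get_collection(name):
    # Return a copy, so later additions do not show up in an earlier result.
    return list(_collections.get(name, []))

def toy_batch_norm(layer_name):
    """Pretend to build a layer and register its two update ops by name."""
    add_to_collection(UPDATE_OPS, layer_name + "/AssignMovingAvg")
    add_to_collection(UPDATE_OPS, layer_name + "/AssignMovingAvg_1")

too_early = get_collection(UPDATE_OPS)   # read before the layer exists
toy_batch_norm("batch_normalization")
update_ops = get_collection(UPDATE_OPS)  # read after the layer exists

print(too_early)
print(update_ops)
```

A list read too early stays empty, so a train_op built under it would silently depend on nothing.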

V. Reflections on the original mistake

Finally, I thought about why my original usage led to the error. The code that implements batch_normalization in TensorFlow lives in tensorflow\python\layers\normalization.py; some key excerpts are shown below.

(tf.layers.batch_normalization:)

if self.scale:
    self.gamma = self.add_variable(
          name='gamma',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
else:
    self.gamma = None
      
if self.center:
    self.beta = self.add_variable(
          name='beta',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
else:
    self.beta = None
    
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

self.moving_mean = self._add_tower_local_variable(
          name='moving_mean',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_mean_initializer,
          trainable=False)

self.moving_variance = self._add_tower_local_variable(
          name='moving_variance',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_variance_initializer,
          trainable=False)

def _assign_moving_average(self, variable, value, momentum):
    with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope:
        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
        if decay.dtype != variable.dtype.base_dtype:
            decay = math_ops.cast(decay, variable.dtype.base_dtype)
        update_delta = (variable - value) * decay
        return state_ops.assign_sub(variable, update_delta, name=scope)
 
def _do_update(var, value):
    return self._assign_moving_average(var, value, self.momentum)



# Determine a boolean value for `training`: could be True, False, or None.
training_value = utils.constant_value(training)
if training_value is not False:
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
    moving_mean = self.moving_mean
    moving_variance = self.moving_variance
    mean = utils.smart_cond(training,
                              lambda: mean,
                              lambda: moving_mean)
    variance = utils.smart_cond(training,
                                  lambda: variance,
                                  lambda: moving_variance)
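    # In the library source, the next line is the `else` branch of an
    # `if self.renorm:` check omitted here, not of `if training_value is not False:`.
    new_mean, new_variance = mean, variance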
    
mean_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
variance_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
if not context.executing_eagerly():
    self.add_update(mean_update, inputs=inputs)
    self.add_update(variance_update, inputs=inputs)
outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)      

Its internal logic is similar to the wrapper I showed in the section on tf.nn.batch_normalization.
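
One detail in the excerpt is easy to misread: _assign_moving_average names its local variable decay, but sets it to 1.0 - momentum, and then subtracts (variable - value) * decay from the variable. Algebraically this is the same update as the formula in section II. The check below is a plain-Python sketch; the function names `assign_sub_form` and `weighted_form` are made up for this comparison.

```python
def assign_sub_form(variable, value, momentum):
    """The form used in the excerpt: variable -= (variable - value) * (1 - momentum)."""
    decay = 1.0 - momentum
    return variable - (variable - value) * decay

def weighted_form(variable, value, momentum):
    """The form from section II: keep `momentum` of the old value."""
    return variable * momentum + value * (1.0 - momentum)

pairs = [(0.0, 5.0), (1.0, 2.5), (4.0, -3.0)]
max_gap = max(
    abs(assign_sub_form(v, b, 0.99) - weighted_form(v, b, 0.99))
    for v, b in pairs
)
print(max_gap)
```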

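If tf.control_dependencies is not added, then during training (training=True) each step only computes the current batch's mean and var and passes them to tf.nn.batch_normalization for normalization. mean_update and variance_update are not on the dependency path of that computation in the graph, so they are never executed on their own. In other words, moving_mean and moving_variance are never updated during training and keep their initial values (zeros and ones by default). This explains the symptom from section I: with training=False the layer normalizes with statistics that do not match the data, so test accuracy collapses.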
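
The effect can be reproduced numerically with a small simulation in plain Python. This is a toy single-channel model, not the TensorFlow layer: `ToyBatchNorm`, `train` and the `run_update_ops` flag are hypothetical names, and scale and offset are left out to keep the sketch short.

```python
import math

class ToyBatchNorm:
    """Single-channel batch norm with explicit moving statistics."""
    def __init__(self, decay=0.99, epsilon=1e-3):
        self.moving_mean = 0.0      # zeros initializer
        self.moving_variance = 1.0  # ones initializer
        self.decay = decay
        self.epsilon = epsilon

    @staticmethod
    def batch_stats(batch):
        mean = sum(batch) / len(batch)
        variance = sum((v - mean) ** 2 for v in batch) / len(batch)
        return mean, variance

    def update_moving_stats(self, batch):
        mean, variance = self.batch_stats(batch)
        self.moving_mean = self.moving_mean * self.decay + mean * (1 - self.decay)
        self.moving_variance = (self.moving_variance * self.decay
                                + variance * (1 - self.decay))

    def __call__(self, batch, training):
        if training:
            mean, variance = self.batch_stats(batch)
        else:
            mean, variance = self.moving_mean, self.moving_variance
        return [(v - mean) / math.sqrt(variance + self.epsilon) for v in batch]

def train(layer, batch, steps, run_update_ops):
    for _ in range(steps):
        layer(batch, training=True)  # forward pass uses batch statistics
        if run_update_ops:           # the step that the wrong usage skips
            layer.update_moving_stats(batch)

batch = [3.0, 5.0, 7.0]  # mean 5, far from the initial moving_mean of 0

broken = ToyBatchNorm()
train(broken, batch, steps=1000, run_update_ops=False)
fixed = ToyBatchNorm()
train(fixed, batch, steps=1000, run_update_ops=True)

broken_eval = broken(batch, training=False)
fixed_eval = fixed(batch, training=False)
print(broken_eval)
print(fixed_eval)
```

In the broken case the inference output is not centered at all, because it is still normalized with a mean of 0 and a variance of 1. In the fixed case the moving statistics have converged toward the batch statistics, so the inference output closely matches what the layer produced during training.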