天天看點

tensorflow中batch normalization的用法

網上找了下tensorflow中使用batch normalization的部落格,發現寫的都不是很好,在此總結下:

1.原理

公式如下:

y=γ(x-μ)/σ+β

其中x是輸入,y是輸出,μ是均值,σ是方差,γ和β是縮放(scale)、偏移(offset)系數。

一般來講,這些參數都是基于channel來做的,比如輸入x是一個16*32*32*128(NWHC格式)的feature map,那麼上述參數都是128維的向量。其中γ和β是可有可無的,有的話,就是一個可以學習的參數(參與前向後向),沒有的話,就簡化成y=(x-μ)/σ。而μ和σ,在訓練的時候,使用的是batch内的統計值,測試/預測的時候,采用的是訓練時計算出的滑動平均值。

2.tensorflow中使用

tensorflow中batch normalization的實作主要有下面三個:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封裝程度逐個遞進,建議使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因為在tensorflow官網的解釋比較詳細。我平時多使用tf.layers.batch_normalization,是以下面的步驟都是基于這個。

3.訓練

訓練的時候需要注意兩點,(1)輸入參數

training=True,(2)計算loss時,要添加以下代碼(即添加

update_ops到最後的train_op中)。這樣才能計算μ和σ的滑動平均(測試時會用到)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)      

4.測試

測試時需要注意一點,輸入參數

training=False,其他就沒了

5.預測

預測時比較特别,因為這一步一般都是從checkpoint檔案中讀取模型參數,然後做預測。一般來說,儲存checkpoint的時候,不會把所有模型參數都儲存下來,因為一些無關資料會增大模型的尺寸,常見的方法是隻儲存那些訓練時更新的參數(可訓練參數),如下:

var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)      

但使用了batch_normalization,γ和β是可訓練參數沒錯,μ和σ不是,它們僅僅是通過滑動平均計算出的,如果按照上面的方法儲存模型,在讀取模型預測時,會報錯找不到μ和σ。更詭異的是,利用

tf.moving_average_variables()也沒法擷取bn層中的μ和σ(也可能是我用法不對),不過好在所有的參數都在

tf.global_variables()中,是以可以這麼寫:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)      

按照上述寫法,即可把μ和σ儲存下來,讀取模型預測時也不會報錯,當然輸入參數training=False還是要的。

注意上面有個不嚴謹的地方,因為我的網絡結構中隻有bn層包含moving_mean和moving_variance,是以隻根據這兩個字元串做了過濾,如果你的網絡結構中其他層也有這兩個參數,但你不需要儲存,建議使用諸如bn/moving_mean的字元串進行過濾。

2018.4.22更新

提供一個基于mnist的示例,供大家參考。包含兩個檔案,分别用于train/test。注意bn_train.py檔案的51-61行,僅儲存了網絡中的可訓練變量和bn層利用統計得到的mean和var。注意示例中需要下載下傳mnist資料集,要保持電腦可以聯網。

繼續閱讀