天天看點

幹貨 | 這可能全網最好的BatchNorm詳解

文章來自:公衆号【機器學習煉丹術】。求關注~

其實關于BN層,我在之前的文章“梯度爆炸”那一篇中已經涉及到了,但是鑒于面試經曆中多次問道這個,這裡再做一個更加全面的講解。

Internal Covariate Shift(ICS)

Batch Normalization的原論文作者給了Internal Covariate Shift一個較規範的定義:在深層網絡訓練的過程中,由于網絡中參數變化而引起内部結點資料分布發生變化的這一過程被稱作Internal Covariate Shift。

這裡做一個簡單的數學定義,對于全連結網絡而言,第i層的數學表達可以展現為:

\(Z^i=W^i\times input^i+b^i\)

\(input^{i+1}=g^i(Z^i)\)

  • 第一個公式就是一個簡單的線性變換;
  • 第二個公式是表示一個激活函數的過程。

【怎麼了解ICS問題】

我們知道,随着梯度下降的進行,每一層的參數\(W^i,b^i\)都會不斷地更新,這意味着\(Z^i\)的分布也不斷地改變,進而\(input^{i+1}\)的分布發生了改變。這意味着,除了第一層的輸入資料不改變,之後所有層的輸入資料的分布都會随着模型參數的更新發生改變,而每一層就要不停的去适應這種資料分布的變化,這個過程就是Internal Covariate Shift。

BN解決的問題

【ICS帶來的收斂速度慢】

因為每一層的參數不斷發生變化,進而每一層的計算結果的分布發生變化,後層網絡不斷地适應這種分布變化,這個時候會讓整個網絡的學習速度過慢。

【梯度飽和問題】

因為神經網絡中經常會采用sigmoid,tanh這樣的飽和激活函數(saturated actication function),是以模型訓練有陷入梯度飽和區的風險。解決這樣的梯度飽和問題有兩個思路:第一種就是更為非飽和性激活函數,例如線性整流函數ReLU可以在一定程度上解決訓練進入梯度飽和區的問題。另一種思路是,我們可以讓激活函數的輸入分布保持在一個穩定狀态來盡可能避免它們陷入梯度飽和區,這也就是Normalization的思路。

Batch Normalization

batchNormalization就像是名字一樣,對一個batch的資料進行normalization。

現在假設一個batch有3個資料,每個資料有兩個特征:(1,2),(2,3),(0,1)

如果做一個簡單的normalization,那麼就是計算均值和方差,把資料減去均值除以标準差,變成0均值1方差的标準形式。

對于第一個特征來說:

\(\mu=\frac{1}{3}(1+2+0)=1\)

\(\sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67\)

【通用公式】

\(\mu=\frac{1}{m}\sum_{i=1}^m{Z}\)

\(\sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu)\)

\(\hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}\)

  • 其中m表示一個batch的數量。
  • \(\epsilon\)是一個極小數,防止分母為0。

目前為止,我們做到了讓每個特征的分布均值為0,方差為1。這樣分布都一樣,一定不會有ICS問題

如同上面提到的,Normalization操作我們雖然緩解了ICS問題,讓每一層網絡的輸入資料分布都變得穩定,但卻導緻了資料表達能力的缺失。每一層的分布都相同,所有任務的資料分布都相同,模型學啥呢

【0均值1方差資料的弊端】

  1. 資料表達能力的缺失;
  2. 通過讓每一層的輸入分布均值為0,方差為1,會使得輸入在經過sigmoid或tanh激活函數時,容易陷入非線性激活函數的線性區域。(線性區域和飽和區域都不理想,最好是非線性區域)

為了解決這個問題,BN層引入了兩個可學習的參數\(\gamma\)和\(\beta\),這樣,經過BN層normalization的資料其實是服從\(\beta\)均值,\(\gamma^2\)方差的資料。

是以對于某一層的網絡來說,我們現在變成這樣的流程:

  1. \(Z=W\times input^i+b\)
  2. \(\hat{Z}=\gamma \times \frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta\)
  3. \(input^{i+1}=g(\hat{Z})\)

(上面公式中,省略了\(i\),總的來說是表示第i層的網絡層産生第i+1層輸入資料的過程)

測試階段的BN

我們知道BN在每一層計算的\(\mu\)與\(\sigma^2\) 都是基于目前batch中的訓練資料,但是這就帶來了一個問題:我們在預測階段,有可能隻需要預測一個樣本或很少的樣本,沒有像訓練樣本中那麼多的資料,這樣的\(\sigma^2\)和\(\mu\)要怎麼計算呢?

利用訓練集訓練好模型之後,其實每一層的BN層都保留下了每一個batch算出來的\(\mu\)和\(\sigma^2\).然後呢利用整體的訓練集來估計測試集的\(\mu_{test}\)和\(\sigma_{test}^2\)

\(\mu_{test}=E(\mu_{train})\)

\(\sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2)\)

然後再對測試機進行BN層:

當然,計算訓練集的\(\mu\)和\(\simga\)的方法除了上面的求均值之外。吳恩達老師在其課程中也提出了,可以使用指數權重平均的方法。不過都是同樣的道理,根據整個訓練集來估計測試機的均值方差。

BN層的好處有哪些

  1. BN使得網絡中每層輸入資料的分布相對穩定,加速模型學習速度。

    BN通過規範化與線性變換使得每一層網絡的輸入資料的均值與方差都在一定範圍内,使得後一層網絡不必不斷去适應底層網絡中輸入的變化,進而實作了網絡中層與層之間的解耦,允許每一層進行獨立學習,有利于提高整個神經網絡的學習速度。

  2. BN允許網絡使用飽和性激活函數(例如sigmoid,tanh等),緩解梯度消失問題

    通過normalize操作可以讓激活函數的輸入資料落在梯度非飽和區,緩解梯度消失的問題;另外通過自适應學習\(\gamma\)與 \(\beta\) 又讓資料保留更多的原始資訊。

  3. BN具有一定的正則化效果

    在Batch Normalization中,由于我們使用mini-batch的均值與方差作為對整體訓練樣本均值與方差的估計,盡管每一個batch中的資料都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會有所不同,這就為網絡的學習過程中增加了随機噪音

BN與其他normalizaiton的比較

【weight normalization】

Weight Normalization是對網絡權值進行normalization,也就是L2 norm。

相對于BN有下面的優勢:

  1. WN通過重寫神經網絡的權重的方式來加速網絡參數的收斂,不依賴于mini-batch。BN因為以來minibatch是以BN不能用于RNN網路,而WN可以。而且BN要儲存每一個batch的均值方差,是以WN節省記憶體;
  2. BN的優點中有正則化效果,但是添加噪音不适合對噪聲敏感的強化學習、GAN等網絡。WN可以引入更小的噪音。

但是WN要特别注意參數初始化的選擇。

【Layer normalization】

更常見的比較是BN與LN的比較。

BN層有兩個缺點:

  1. 無法進行線上學習,因為線上學習的mini-batch為1;LN可以
  2. 之前提到的BN不能用在RNN中;LN可以
  3. 消耗一定的記憶體來記錄均值和方差;LN不用

但是,在CNN中LN并沒有取得比BN更好的效果。

參考連結:

  1. https://zhuanlan.zhihu.com/p/34879333
  2. https://www.zhihu.com/question/59728870
  3. https://zhuanlan.zhihu.com/p/113233908
  4. https://www.zhihu.com/question/55890057/answer/267872896