天天看點

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

🍊作者簡介:秃頭小蘇,緻力于用最通俗的語言描述問題

🍊專欄推薦:深度學習網絡原理與實戰

🍊近期目标:寫好專欄的每一篇文章

🍊支援小蘇:點贊👍🏼、收藏⭐、留言📩

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

寫在前面

Hello,大家好,我是小蘇🧒🏽🧒🏽🧒🏽

在前面的文章中,我已經介紹過挺多種GAN網絡了,感興趣的可以關注一下我的專欄:深度學習網絡原理與實戰 。目前專欄主要更新了GAN系列文章、Transformer系列和語義分割系列文章,都有理論詳解和代碼實戰,文中的講解都比較通俗易懂,如果你希望豐富這方面的知識,建議你閱讀試試,相信你會有蠻不錯的收獲。🍸🍸🍸

在閱讀本篇教程之前,你非常有必要閱讀下面兩篇文章:

  • [1]對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例
  • [2]對抗生成網絡GAN系列——WGAN原理及實戰演練

其實啊,我相信大家來看這篇文章的時候,一定是對上文提到的文章有所了解了,是以大家要是覺得自己對GAN和WGAN了解的已經足夠透徹了,那麼完全沒有必要再浪費時間閱讀了。如果你還對它們有一些疑惑或者過了很久已經忘了希望回顧一下的話,那麼文章[1]和文章[2]擷取對你有所幫助。

大家準備好了嘛,我們這就開始準備學習Spectral Normalization啦!🚖🚖🚖

Spectral Normalization原理詳解

​  首先,讓我們簡單的回顧一下WGAN。🌞🌞🌞由于原始GAN網絡存在訓練不穩定的現象,究其本質,是因為它的損失函數實際上是JS散度,而JS散度不會随着兩個分布的距離改變而改變(這句不嚴謹,細節參考WGAN中的描述),這就會導緻生成器的梯度會一直不變,進而導緻模型訓練效果很差。WGAN為了解決原始GAN網絡訓練不穩定的現象,引入了EM distance代替原有的JS散度,這樣的改變會使生成器梯度一直變化,進而使模型得到充分訓練。但是WGAN的提出伴随着一個難點,即如何讓判别器的參數矩陣滿足Lipschitz連續條件。

​  如何解決上述所說的難點呢?在WGAN中,我們采用了一種簡單粗暴的方式來滿足這一條件,即直接對判别器的權重參數進行剪裁,強制将權重限制在[-c,c]範圍内。大家可以動動我們的小腦瓜想想這種權重剪裁的方式有什麼樣的問題——(滴,揭曉答案🍍🍍🍍)如果權重剪裁的參數c很大,那麼任何權重可能都需要很長時間才能達到極限,進而使訓練判别器達到最優變得更加困難;如果權重剪裁的參數c很小,這又容易導緻梯度消失。是以,如何确定權重剪裁參數c是重要的,同時這也是困難的。WGAN提出之後,又提出了WGAN-GP來實作Lipschitz 連續條件,其主要通過添加一個懲罰項來實作。【關于WGAN-GP我沒有做相關教程,如果不明白的可以評論區留言】那麼本文提出了一種歸一化的手段Spectral Normalization來實作Lipschitz連續條件,這種歸一化具體是怎麼實作的呢,下面聽我慢慢道來。🍻🍻🍻

我們還是來先回顧一下Lipschitz連續條件,如下:

​             ∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ K ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)| \le K|x_1-x_2| ∣f(x1​)−f(x2​)∣≤K∣x1​−x2​∣

這個式子限制了函數 f ( ⋅ ) {\rm{f}}( \cdot ) f(⋅)的導數,即其導數的絕對值小于K, ∣ f ( x 1 ) − f ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ K \frac{|f(x_1)-f(x_2)|}{|x_1-x_2|} \le K ∣x1​−x2​∣∣f(x1​)−f(x2​)∣​≤K。 🍋🍋🍋

本文介紹的Spectral Normalization的K=1,讓我們一起來看看怎麼實作的吧!!!

  上文提到,WGAN的難點是如何讓判别器的參數矩陣滿足Lipschitz連續條件。那麼我們就從判别器入手和大家唠一唠。實際上,判别器也是由多層卷積神經網絡構成的,我們用下式表示第n層網絡輸出和第n-1層輸入的關系:

​             X n = a n ( W n ⋅ X n − 1 + b n ) X_n=a_n(W_n \cdot X_{n-1}+b_n) Xn​=an​(Wn​⋅Xn−1​+bn​)

  其中 a n ( ⋅ ) a_n(\cdot) an​(⋅)表示激活函數, W n W_n Wn​表示權重參數矩陣。為了友善起見,我們不設定偏置項 b n b_n bn​,即 b n = 0 b_n=0 bn​=0。那麼上式變為:

​             X n = a n ( W n ⋅ X n − 1 ) X_n=a_n(W_n \cdot X_{n-1}) Xn​=an​(Wn​⋅Xn−1​)

  再為了友善起見🤸🏽‍♂️🤸🏽‍♂️🤸🏽‍♂️,我們設 a n ( ⋅ ) a_n(\cdot) an​(⋅),即激活函數為Relu。Relu函數在大于0時為y=x,小于0時為y=0,函數圖像如下圖所示:

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

​  這樣的話式 X n = a n ( W n ⋅ X n − 1 ) X_n=a_n(W_n \cdot X_{n-1}) Xn​=an​(Wn​⋅Xn−1​)可以寫成 X n = D n ⋅ W n ⋅ X n − 1 X_n=D_n \cdot W_n \cdot X_{n-1} Xn​=Dn​⋅Wn​⋅Xn−1​,其中 D n D_n Dn​為對角矩陣。【大家這裡能否了解呢?如果我們的輸入為正數時,通過Relu函數值是不變的,那麼此時 D n D_n Dn​對應的對角元素應該為1;如果我們的輸入為負數時,通過Relu函數值将變成0,那麼此時 D n D_n Dn​對應的對角元素應該為0。也就是說我們将 X n X_n Xn​改寫成 D n ⋅ W n ⋅ X n − 1 D_n \cdot W_n \cdot X_{n-1} Dn​⋅Wn​⋅Xn−1​形式是可行的。】

​  接着我們做一些簡單的推理,得到判别器第n層輸出和原始輸入的關系,如下圖所示:

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

  最後一層的輸出 X n X_n Xn​即為判别器的輸出,接下來我們用 f ( x ) f(x) f(x)表示;原始輸入資料 x 0 x_0 x0​我們接下來用 x x x表示。則判别器最終輸入輸出的關系式如下:

​    f ( x ) = D n ⋅ W n ⋅ D n − 1 ⋅ W n − 1 ⋯ D 3 ⋅ W 3 ⋅ D 2 ⋅ W 2 ⋅ D 1 ⋅ W 1 ⋅ x f(x) = {D_n} \cdot {W_n} \cdot {D_{n - 1}} \cdot {W_{n - 1}} \cdots {D_3} \cdot {W_3} \cdot {D_2} \cdot {W_2} \cdot {D_1} \cdot {W_1} \cdot x f(x)=Dn​⋅Wn​⋅Dn−1​⋅Wn−1​⋯D3​⋅W3​⋅D2​⋅W2​⋅D1​⋅W1​⋅x

  上文說到Lipschitz連續條件本質上就是限制函數 f ( ⋅ ) {\rm{f}}( \cdot ) f(⋅)的導數變化範圍,其實就是對 f ( x ) f(x) f(x)梯度提出限制,如下:

∣ ∣ ∇ x f ( x ) ∣ ∣ 2 = ∣ ∣ D n ⋅ W n ⋅ D n − 1 ⋅ W n − 1 ⋯ D 3 ⋅ W 3 ⋅ D 2 ⋅ W 2 ⋅ D 1 ⋅ W 1 ∣ ∣ 2 ≤ ∣ ∣ D n ∣ ∣ 2 ⋅ ∣ ∣ W n ∣ ∣ 2 ⋅ ∣ ∣ D n − 1 ∣ ∣ 2 ⋅ ∣ ∣ W n − 1 ∣ ∣ 2 ⋯ ∣ ∣ D 1 ∣ ∣ 2 ⋅ ∣ ∣ W 1 ∣ ∣ 2 ||{\nabla _x}f(x)|{|_2} = ||{D_n} \cdot {W_n} \cdot {D_{n - 1}} \cdot {W_{n - 1}} \cdots {D_3} \cdot {W_3} \cdot {D_2} \cdot {W_2} \cdot {D_1} \cdot {W_1}|{|_2} \le ||{D_n}|{|_2} \cdot ||{W_n}|{|_2} \cdot ||{D_{n - 1}}|{|_2} \cdot ||{W_{n - 1}}|{|_2} \cdots ||{D_1}|{|_2} \cdot ||{W_1}|{|_2} ∣∣∇x​f(x)∣∣2​=∣∣Dn​⋅Wn​⋅Dn−1​⋅Wn−1​⋯D3​⋅W3​⋅D2​⋅W2​⋅D1​⋅W1​∣∣2​≤∣∣Dn​∣∣2​⋅∣∣Wn​∣∣2​⋅∣∣Dn−1​∣∣2​⋅∣∣Wn−1​∣∣2​⋯∣∣D1​∣∣2​⋅∣∣W1​∣∣2​

  其中 ∣ ∣ A ∣ ∣ 2 ||A||_2 ∣∣A∣∣2​表示矩陣A的2範數,也叫譜範數,它的值為 λ 1 \sqrt {{\lambda _1}} λ1​

​, λ 1 {\lambda _1} λ1​為 A H A {{\rm{A}}^H}{\rm{A}} AHA的最大特征值。 λ 1 \sqrt {{\lambda _1}} λ1​

​又稱作矩陣A的奇異值【注:奇異值是 A H A {{\rm{A}}^H}{\rm{A}} AHA的特征值的開根号,也就是說 λ 1 \sqrt {{\lambda _1}} λ1​

​為A的其中一個奇異值或譜範數是最大的奇異值,這裡我們将譜範數,即最大的奇異值記作 σ ( A ) = λ 1 \sigma {(A)} = \sqrt {{\lambda _1}} σ(A)=λ1​

​。由于D是對角矩陣且由0、1構成,其奇異值總是小于等于1,故有下式:

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

  即 ∇ x f ( x ) ∣ ∣ 2 = ∣ ∣ D n ∣ ∣ 2 ⋅ ∣ ∣ W n ∣ ∣ 2 ⋯ ∣ ∣ D 1 ∣ ∣ 2 ⋅ ∣ ∣ W 1 ∣ ∣ 2 ≤ Π 1 n σ ( W i ) {\nabla _x}f(x)|{|_2}= ||{D_n}|{|_2}\cdot ||{W_n}|{|_2} \cdots ||{D_1}|{|_2} \cdot ||{W_1}|{|_2} \le \mathop \Pi \limits_1^{\rm{n}} \sigma ({W_i}) ∇x​f(x)∣∣2​=∣∣Dn​∣∣2​⋅∣∣Wn​∣∣2​⋯∣∣D1​∣∣2​⋅∣∣W1​∣∣2​≤1Πn​σ(Wi​)。為滿足Lipschitz連續條件,我們應該讓 ∣ ∣ ∇ x f ( x ) ∣ ∣ 2 ≤ K ||{\nabla _x}f(x)|{|_2} \le K ∣∣∇x​f(x)∣∣2​≤K ,這裡的K設定為1。那具體要怎麼做呢,其實就是對上式做一個歸一化處理,讓每一層參數矩陣除以該層參數矩陣的譜範數,如下:

​   ∣ ∣ ∇ x f ( x ) ∣ ∣ 2 = ∣ D n ∣ ∣ 2 ⋅ ∣ ∣ W n ∣ ∣ 2 σ ( W n ) ⋯ ∣ ∣ D 1 ∣ ∣ 2 ⋅ ∣ ∣ W 1 ∣ ∣ 2 σ ( W 1 ) ≤ Π 1 n σ ( W i ) σ ( W i ) = 1 ||{\nabla _x}f(x)|{|_2} = |{D_n}|{|_2} \cdot \frac{{||{W_n}|{|_2}}}{{\sigma ({W_n})}} \cdots ||{D_1}|{|_2} \cdot \frac{{||{W_1}|{|_2}}}{{\sigma ({W_1})}} \le \mathop \Pi \limits_1^{\rm{n}} \frac{{\sigma ({W_i})}}{{\sigma ({W_i})}} = 1 ∣∣∇x​f(x)∣∣2​=∣Dn​∣∣2​⋅σ(Wn​)∣∣Wn​∣∣2​​⋯∣∣D1​∣∣2​⋅σ(W1​)∣∣W1​∣∣2​​≤1Πn​σ(Wi​)σ(Wi​)​=1

  這樣,其實我們的

Spectral Normalization

原理就講的差不多了,最後我們要做的就是求得每層參數矩陣的譜範數,然後再進行歸一化操作。要想求矩陣的譜範數,首先得求矩陣的奇異值,具體求法我放在附錄部分。

  但是按照正常求奇異值的方法會消耗大量的計算資源,是以論文中使用了一種近似求解譜範數的方法,僞代碼如下圖所示:

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

  在代碼的實戰中我們就是按照上圖的僞代碼求解譜範數的,屆時我們會為大家介紹。🍄🍄🍄

注:大家閱讀這部分有沒有什麼難度呢,我覺得可能還是挺難的,你需要一些矩陣分析的知識,我已經盡可能把這個問題描述的簡單了,有的文章寫的很好,公式推導的也很詳盡,我會在參考連結中給出。但是會涉及到最優化的一些理論,估計這就讓大家更頭疼了,是以大家慢慢消化吧!!!🍚🍚🍚在最後的附錄中,我會給出本節内容相關的矩陣分析知識,是我上課時的一些筆記,筆記包含本節的知識點,但針對性可能不是很強,也就是說可能包含一些其它内容,大家可以選擇忽略,當然了,你也可以細細的研究研究每個知識點,說不定後面就用到了呢!!!🥝🥝🥝

Spectral Normalization源碼解析

源碼下載下傳位址:Spectral Normalization📥📥📥

  這個代碼使用的是CIFAR10資料集,實作的是一般生成對抗網絡的圖像生成任務。我不打算再對每一句代碼進行詳細的解釋,有不明白的可以先去看看我專欄中的其它GAN網絡的文章,都有源碼解析,弄明白後再看這篇你會發現非常簡單。那麼這篇文章我主要來介紹一下

Spectral Normalization

部分的内容,其相關内容在

spectral_normalization.py

檔案中,我們理論部分提到

Spectral Normalization

關鍵的一步是求解每個參數矩陣的譜範數,相關代碼如下:

def _update_u_v(self):
    u = getattr(self.module, self.name + "_u")
    v = getattr(self.module, self.name + "_v")
    w = getattr(self.module, self.name + "_bar")
    height = w.data.shape[0]
    for _ in range(self.power_iterations):
        u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))  
        v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))

    sigma = u.dot(w.view(height, -1).mv(v))
    setattr(self.module, self.name, w / sigma.expand_as(w))
    
    
    
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)
           

  對上述代碼做一定的解釋,6,7,8,9,10行做的就是理論部分僞代碼的工作,最後會得到譜範數sigma。11行為使用參數矩陣除以譜範數sigma,以此實作歸一化的作用。【torch.mv實作的是矩陣乘法的操作,裡面可能還有些函數你沒見過,大家百度一下用法就知道了,非常簡單】

其實關鍵的代碼就這些,是不是發現特别簡單呢🍸🍸🍸每次介紹代碼時我都會強調自己動手調試的重要性,很多時候寫文章介紹源碼都覺得有些力不從心,一些想表達的點總是很難表述,總之,大家要是有什麼不明白的就盡情調試叭,或者評論區留言,我天天線上摸魚滴喔。⭐⭐⭐後期我也打算出一些視訊教學了,這樣的話就可以帶着大家一起調試,我想這樣介紹源碼彼此都會輕松很多。🛩🛩🛩

小結

  

Spectral Normalization

确實是有一定難度的,我也有許多地方了解的也不是很清楚,對于這種難啃的問題我是這樣認為的。我們可以先對其有一個大緻的了解,知道整個過程,知道代碼怎麼實作,能使用代碼跑通一些模型,然後考慮能否将其用在自己可能需要使用的地方,如果加入的效果不好,我們就沒必要深究原理了,如果發現效果好,這時候我們再回來慢慢細嚼原理也不遲。最後,希望各位都能擷取新知識,能夠學有所成叭!!!🌹🌹🌹

參考連結

GAN — Spectral Normalization 🍁🍁🍁

Spectral Normalization for GAN🍁🍁🍁

詳解GAN的譜歸一化(Spectral Normalization)🍁🍁🍁

譜歸一化(Spectral Normalization)的了解🍁🍁🍁

附錄

  這部分是我學習矩陣分析這門課程時的筆記,截取一些包含此部分的内容,有需求的感興趣的可以看一看。🌱🌱🌱

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析
對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析
對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析
對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

如若文章對你有所幫助,那就🛴🛴🛴

        

對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析對抗生成網絡GAN系列——Spectral Normalization原理詳解及源碼解析

繼續閱讀