天天看點

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

文章目錄

    • General Framework of GAN
      • f-divergence(通用的divergence模型)
      • Fenchel Conjugate(凸共轭)
      • Connection with GAN
      • GAN訓練過程中可能産生的問題
        • Mode Collapse
          • 介紹
          • 解決辦法
        • Mode Dropping
          • 介紹
          • 問題分析

General Framework of GAN

之前在講GAN的時候,提到我們實際是在用Discriminator來衡量兩個資料的分布之間的JS divergence,那能不能是其他類型的divergence來衡量真實資料和生成資料之間的差距?又如何進行衡量?(雖然在實作上用不同divergence結果沒有很大差别)

李老師原話:在數學上感覺非常的屌

在開始講fGAN之前,需要先補充兩個基礎知識,f-divergence和Fenchel Conjugate。為什麼要講它們,後面會提到。

f-divergence(通用的divergence模型)

任意的divergence都可以用來衡量真實資料和生成資料之間的差距,用f-divergence進行衡量的算法就叫fGAN。先來看看f-divergence的概念:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

假設有兩個分布 P 和 Q,p(x) 和 q(x) 分别代表樣本 x 從這兩個分布中采樣出來的機率,則能定義 P 和 Q 之間的 f-divergence 為:

D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x D_f(P||Q) = \int_xq(x)f(\frac{p(x)}{q(x)})dx Df​(P∣∣Q)=∫x​q(x)f(q(x)p(x)​)dx

其中對于函數 f(x) 有兩個要求:

  1. f 要求是凸函數
  2. 要求 f(1)=0

更換不同的函數 f(x),就能得到不同的 f-divergence。

為什麼這個式子可以用來衡量 P 和 Q 之間的差異呢?它有以下幾個特點:

  • 如果對于所有的 x,有 p(x)=q(x),那麼 D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( 1 ) d x = 0 D_f(P||Q) = \int_xq(x)f(1)dx = 0 Df​(P∣∣Q)=∫x​q(x)f(1)dx=0 。原因很簡單,f(1)=0。
  • 恒定的,有 D f ( P ∣ ∣ Q ) ≥ 0 D_f(P||Q)≥0 Df​(P∣∣Q)≥0 。因為 f(x) 是一個凸函數,是以有

D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x ≥ f ( ∫ x q ( x ) p ( x ) q ( x ) d x ) = f ( ∫ x p ( x ) d x ) = f ( 1 ) = 0 \begin{aligned} D_f(P||Q) = \int_xq(x)f(\frac{p(x)}{q(x)})dx ≥ & f(\int_xq(x)\frac{p(x)}{q(x)}dx) \\ = & f(\int_xp(x)dx) \\ = & f(1)\\ = & 0 \end{aligned} Df​(P∣∣Q)=∫x​q(x)f(q(x)p(x)​)dx≥===​f(∫x​q(x)q(x)p(x)​dx)f(∫x​p(x)dx)f(1)0​

是以, D f ( P ∣ ∣ Q ) D_f(P||Q) Df​(P∣∣Q) 可用來對 P 和 Q 之間差距作出衡量。

下面列舉幾個常見的 f(x) 取值及其得到的 f-divergence。

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

Fenchel Conjugate(凸共轭)

每一個凸函數都有一個共轭函數(conjugate function),記為 f ∗ f^* f∗ ,長這樣:

f ∗ ( t ) = max ⁡ x ∈ d o m ( f ) { x t − f ( x ) } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} f∗(t)=x∈dom(f)max​{xt−f(x)}

在這個函數中,t 是自變量,就是說帶一個值 t 到 f ∗ f^* f∗ 裡面,窮舉所有的x,看看哪個x可以使得 f ∗ f^* f∗ 最大。

比較笨的窮舉法如下:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

另外一種方法:函數 x t − f ( x ) xt-f(x) xt−f(x)是直線,我們帶不同的 x x x得到不同的直線,例如下面有三條直線:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

然後找不同的t對應的最大值。(就是所有直線的upper bound)

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

上面的紅線無論你如何畫,最後都是convex的。

看個例子,假設: f ( x ) = x l o g x f(x) = xlogx f(x)=xlogx,把 x = 0.1 , x = 1 , x = 10 x=0.1,x=1,x=10 x=0.1,x=1,x=10帶入,結果如圖所示:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

紅線最後接近:

f ∗ ( t ) = e t − 1 f^*(t) = e^{t-1} f∗(t)=et−1

下面是數學證明:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

假設 f ( x ) = x l o g x f(x) = xlogx f(x)=xlogx,則:

f ∗ ( t ) = max ⁡ x ∈ d o m ( f ) { x t − f ( x ) } = max ⁡ x ∈ d o m ( f ) { x t − x l o g x } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} = \max\limits_{x∈dom(f)}\{xt-xlogx\} f∗(t)=x∈dom(f)max​{xt−f(x)}=x∈dom(f)max​{xt−xlogx}

令上式中 x t − x l o g x = g ( x ) xt-xlogx = g(x) xt−xlogx=g(x),給一個t,求得最大的g(x)

如何求呢?求極值,就是用g(x)對x進行求導等于0:

g ′ ( x ) = t − l o g x − 1 = 0     = >     x = e t − 1 g'(x) = t-logx - 1 = 0 \ \ \ => \ \ \ x=e^{t-1} g′(x)=t−logx−1=0   =>   x=et−1

把上面内容代入公式 f ∗ ( t ) f^*(t) f∗(t) :

f ∗ ( t ) = x t − x l o g x = e t − 1 × t − e t − 1 × ( t − 1 ) = e t − 1 f^*(t) = xt- xlogx = e^{t-1}×t-e^{t-1}×(t-1) = e^{t-1} f∗(t)=xt−xlogx=et−1×t−et−1×(t−1)=et−1

一般化後:

( f ∗ ) ∗ = f (f^*)^* = f (f∗)∗=f

對于共轭函數,還有一個性質,就是

共轭函數是互相的

,也就是對每一對共轭函數來說,有:

f ∗ ( t ) = max ⁡ x ∈ d o m ( f ) { x t − f ( x ) } ← → f ( x ) = max ⁡ t ∈ d o m ( f ) { x t − f ∗ ( t ) } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} ←→f(x) = \max\limits_{t∈dom(f)}\{xt-f^*(t)\} f∗(t)=x∈dom(f)max​{xt−f(x)}←→f(x)=t∈dom(f)max​{xt−f∗(t)}

Connection with GAN

上面的内容到底和GAN有什麼關系呢?

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

我們假設有一個divergence:

D f ( P ∣ ∣ D ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x = ∫ x q ( x ) ( max ⁡ t ∈ d o m ( f ∗ ) { p ( x ) q ( x ) t − f ∗ ( t ) } ) d x D_f(P||D) = \int_xq(x)f(\frac{p(x)}{q(x)})dx = \int_xq(x)\left(\max\limits_{t∈dom(f^*)}\{\frac{p(x)}{q(x)}t-f^*(t)\}\right)dx Df​(P∣∣D)=∫x​q(x)f(q(x)p(x)​)dx=∫x​q(x)(t∈dom(f∗)max​{q(x)p(x)​t−f∗(t)})dx

接下來,我們學習一個function D,這個D的輸入是x, 輸出是t,我們将t用D(x)替代,同時我們去掉max的概念,改為 ≥ ≥ ≥ ,那麼上式可改寫為:

D f ( P ∣ ∣ D ) ≥ ∫ x q ( x ) ( p ( x ) q ( x ) D ( x ) − f ∗ ( D ( x ) ) ) d x = ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||D) ≥ \int_xq(x)\left(\frac{p(x)}{q(x)}D(x)-f^*\left(D(x)\right)\right)dx = \int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx Df​(P∣∣D)≥∫x​q(x)(q(x)p(x)​D(x)−f∗(D(x)))dx=∫x​p(x)D(x)dx−∫x​q(x)f∗(D(x))dx

那麼其實相當于:

D f ( P ∣ ∣ D ) ≈ max ⁡ D ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||D) ≈ \max\limits_D\int_xp(x)D(x)dx - \int_xq(x)f^*(D(x))dx Df​(P∣∣D)≈Dmax​∫x​p(x)D(x)dx−∫x​q(x)f∗(D(x))dx

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

上面這個公式,我們把它改寫一下:

D f ( P ∣ ∣ D ) = max ⁡ D { E x ∼ P [ D ( x ) ] − E x ∼ Q [ f ∗ ( D ( x ) ) ] } D_f(P||D) = \max\limits_D\left\{ E_{x \sim P}[D(x)] - E_{x\sim Q}[f^*(D(x))] \right\} Df​(P∣∣D)=Dmax​{Ex∼P​[D(x)]−Ex∼Q​[f∗(D(x))]}

我們令 P = P d a t a , Q = P G P=P_{data},Q=P_G P=Pdata​,Q=PG​ ,那麼 P d a t a , P G P_{data},P_G Pdata​,PG​ 的f-divergence就可以寫成:

D f ( P ∣ ∣ D ) = max ⁡ D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ G [ f ∗ ( D ( x ) ) ] } D_f(P||D) = \max\limits_D\left\{ E_{x \sim P_{data}}[D(x)] - E_{x\sim G}[f^*(D(x))] \right\} Df​(P∣∣D)=Dmax​{Ex∼Pdata​​[D(x)]−Ex∼G​[f∗(D(x))]}

這個 f ∗ f^* f∗ 取決于f-divergence是什麼。

這個式子怎麼看起來好像GAN需要minimize的目标呢?在機器學習-35-Theory behind GAN(GAN背後的數學理論) 我們提到了GAN的訓練目标是 G ∗ = a r g min ⁡ G D f ( P d a t a ∣ ∣ P G ) G^* = arg \min\limits_GD_f(P_{data}||P_G) G∗=argGmin​Df​(Pdata​∣∣PG​),我們把GAN的訓練目标展開,就是:

G ∗ = a r g min ⁡ G D f ( P d a t a ∣ ∣ P G ) = a r g min ⁡ G max ⁡ D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] } = a r g min ⁡ G max ⁡ D V ( G , D ) \begin{aligned} G^* = &arg \min\limits_GD_f(P_{data}||P_G)\\ =& arg \min\limits_G\max\limits_D\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\} \\ =& arg \min\limits_G \max\limits_D V(G,D) \end{aligned} G∗===​argGmin​Df​(Pdata​∣∣PG​)argGmin​Dmax​{Ex∼Pdata​​[D(x)]−Ex∼PG​​[f∗(D(x))]}argGmin​Dmax​V(G,D)​

是以你可以選用不同的f-divergence,優化不同的divergence,論文裡面給了個清單,你可以自己選:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

那麼使用不同的divergence會有什麼用處嗎?它可能能夠用于解決GAN在訓練過程中會出現的一些問題(衆所周知GAN難以訓練)

GAN訓練過程中可能産生的問題

Mode Collapse

介紹

這個概念是GAN難以訓練的原因之一,它指的是GAN産生的樣本單一,認為滿足某一分布的結果為True,其餘為False。如下圖,原始資料分布的範圍要比GAN訓練的結果大得多。進而導緻generator訓練出來的結果可能都差不多,圖檔差異性不大。

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

當我們的GAN模型Training with too many iterations……

有些人臉就會比較像,除了一些顔色不太一樣

解決辦法

對于mode collapse到底有沒有一些更加通用的解決辦法呢?你可以用ensemble的方法,其實就是訓練多個Generator,然後在使用的時候随機挑一個generator來生成結果,當然是一個很流氓的招數,看你用到哪裡了…

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

Mode Dropping

介紹

這個問題從字面上也好了解,假設原始分布有兩個比較集中的波峰,而GAN有可能把分布集中在其中一個波峰,而抛棄掉了另一個,如下圖:

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

例如下面的人臉,一個循環隻有白種人,一個循環隻有黃種人,一個循環中隻有黑人。

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)
問題分析

為什麼會有這樣的結果呢,一個猜測是divergence選得不好,選擇不同的divergence,最後generator得到的distribution會不一樣。如下圖,minimize KL divergence和reverse KL divergence的時候,最後得到的分布是不一樣的, 前者容易導緻模糊的問題,後者則可能導緻mode dropping。如果你覺得在訓練過程中出現的mode collapse或者mode dropping是由于divergence導緻的,你可以通過嘗試更換 f ∗ f^* f∗ 來實驗。當然不一定說就一定有效果,這裡隻是提供一種可能的猜測。

機器學習-38-GAN-05-General Framework of GAN(fGAN,GAN的一般架構)

繼續閱讀