文章目錄
-
- General Framework of GAN
-
- f-divergence(通用的divergence模型)
- Fenchel Conjugate(凸共轭)
- Connection with GAN
- GAN訓練過程中可能産生的問題
-
- Mode Collapse
-
- 介紹
- 解決辦法
- Mode Dropping
-
- 介紹
- 問題分析
General Framework of GAN
之前在講GAN的時候,提到我們實際是在用Discriminator來衡量兩個資料的分布之間的JS divergence,那能不能是其他類型的divergence來衡量真實資料和生成資料之間的差距?又如何進行衡量?(雖然在實作上用不同divergence結果沒有很大差别)
李老師原話:在數學上感覺非常的屌
在開始講fGAN之前,需要先補充兩個基礎知識,f-divergence和Fenchel Conjugate。為什麼要講它們,後面會提到。
f-divergence(通用的divergence模型)
任意的divergence都可以用來衡量真實資料和生成資料之間的差距,用f-divergence進行衡量的算法就叫fGAN。先來看看f-divergence的概念:
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZuBnLzAjN0IDOwcTM0ETMwEjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
假設有兩個分布 P 和 Q,p(x) 和 q(x) 分别代表樣本 x 從這兩個分布中采樣出來的機率,則能定義 P 和 Q 之間的 f-divergence 為:
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x D_f(P||Q) = \int_xq(x)f(\frac{p(x)}{q(x)})dx Df(P∣∣Q)=∫xq(x)f(q(x)p(x))dx
其中對于函數 f(x) 有兩個要求:
- f 要求是凸函數
- 要求 f(1)=0
更換不同的函數 f(x),就能得到不同的 f-divergence。
為什麼這個式子可以用來衡量 P 和 Q 之間的差異呢?它有以下幾個特點:
- 如果對于所有的 x,有 p(x)=q(x),那麼 D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( 1 ) d x = 0 D_f(P||Q) = \int_xq(x)f(1)dx = 0 Df(P∣∣Q)=∫xq(x)f(1)dx=0 。原因很簡單,f(1)=0。
- 恒定的,有 D f ( P ∣ ∣ Q ) ≥ 0 D_f(P||Q)≥0 Df(P∣∣Q)≥0 。因為 f(x) 是一個凸函數,是以有
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x ≥ f ( ∫ x q ( x ) p ( x ) q ( x ) d x ) = f ( ∫ x p ( x ) d x ) = f ( 1 ) = 0 \begin{aligned} D_f(P||Q) = \int_xq(x)f(\frac{p(x)}{q(x)})dx ≥ & f(\int_xq(x)\frac{p(x)}{q(x)}dx) \\ = & f(\int_xp(x)dx) \\ = & f(1)\\ = & 0 \end{aligned} Df(P∣∣Q)=∫xq(x)f(q(x)p(x))dx≥===f(∫xq(x)q(x)p(x)dx)f(∫xp(x)dx)f(1)0
是以, D f ( P ∣ ∣ Q ) D_f(P||Q) Df(P∣∣Q) 可用來對 P 和 Q 之間差距作出衡量。
下面列舉幾個常見的 f(x) 取值及其得到的 f-divergence。
Fenchel Conjugate(凸共轭)
每一個凸函數都有一個共轭函數(conjugate function),記為 f ∗ f^* f∗ ,長這樣:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} f∗(t)=x∈dom(f)max{xt−f(x)}
在這個函數中,t 是自變量,就是說帶一個值 t 到 f ∗ f^* f∗ 裡面,窮舉所有的x,看看哪個x可以使得 f ∗ f^* f∗ 最大。
比較笨的窮舉法如下:
另外一種方法:函數 x t − f ( x ) xt-f(x) xt−f(x)是直線,我們帶不同的 x x x得到不同的直線,例如下面有三條直線:
然後找不同的t對應的最大值。(就是所有直線的upper bound)
上面的紅線無論你如何畫,最後都是convex的。
看個例子,假設: f ( x ) = x l o g x f(x) = xlogx f(x)=xlogx,把 x = 0.1 , x = 1 , x = 10 x=0.1,x=1,x=10 x=0.1,x=1,x=10帶入,結果如圖所示:
紅線最後接近:
f ∗ ( t ) = e t − 1 f^*(t) = e^{t-1} f∗(t)=et−1
下面是數學證明:
假設 f ( x ) = x l o g x f(x) = xlogx f(x)=xlogx,則:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } = max x ∈ d o m ( f ) { x t − x l o g x } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} = \max\limits_{x∈dom(f)}\{xt-xlogx\} f∗(t)=x∈dom(f)max{xt−f(x)}=x∈dom(f)max{xt−xlogx}
令上式中 x t − x l o g x = g ( x ) xt-xlogx = g(x) xt−xlogx=g(x),給一個t,求得最大的g(x)
如何求呢?求極值,就是用g(x)對x進行求導等于0:
g ′ ( x ) = t − l o g x − 1 = 0 = > x = e t − 1 g'(x) = t-logx - 1 = 0 \ \ \ => \ \ \ x=e^{t-1} g′(x)=t−logx−1=0 => x=et−1
把上面内容代入公式 f ∗ ( t ) f^*(t) f∗(t) :
f ∗ ( t ) = x t − x l o g x = e t − 1 × t − e t − 1 × ( t − 1 ) = e t − 1 f^*(t) = xt- xlogx = e^{t-1}×t-e^{t-1}×(t-1) = e^{t-1} f∗(t)=xt−xlogx=et−1×t−et−1×(t−1)=et−1
一般化後:
( f ∗ ) ∗ = f (f^*)^* = f (f∗)∗=f
對于共轭函數,還有一個性質,就是
共轭函數是互相的
,也就是對每一對共轭函數來說,有:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } ← → f ( x ) = max t ∈ d o m ( f ) { x t − f ∗ ( t ) } f^*(t) = \max\limits_{x∈dom(f)}\{xt-f(x)\} ←→f(x) = \max\limits_{t∈dom(f)}\{xt-f^*(t)\} f∗(t)=x∈dom(f)max{xt−f(x)}←→f(x)=t∈dom(f)max{xt−f∗(t)}
Connection with GAN
上面的内容到底和GAN有什麼關系呢?
我們假設有一個divergence:
D f ( P ∣ ∣ D ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x = ∫ x q ( x ) ( max t ∈ d o m ( f ∗ ) { p ( x ) q ( x ) t − f ∗ ( t ) } ) d x D_f(P||D) = \int_xq(x)f(\frac{p(x)}{q(x)})dx = \int_xq(x)\left(\max\limits_{t∈dom(f^*)}\{\frac{p(x)}{q(x)}t-f^*(t)\}\right)dx Df(P∣∣D)=∫xq(x)f(q(x)p(x))dx=∫xq(x)(t∈dom(f∗)max{q(x)p(x)t−f∗(t)})dx
接下來,我們學習一個function D,這個D的輸入是x, 輸出是t,我們将t用D(x)替代,同時我們去掉max的概念,改為 ≥ ≥ ≥ ,那麼上式可改寫為:
D f ( P ∣ ∣ D ) ≥ ∫ x q ( x ) ( p ( x ) q ( x ) D ( x ) − f ∗ ( D ( x ) ) ) d x = ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||D) ≥ \int_xq(x)\left(\frac{p(x)}{q(x)}D(x)-f^*\left(D(x)\right)\right)dx = \int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx Df(P∣∣D)≥∫xq(x)(q(x)p(x)D(x)−f∗(D(x)))dx=∫xp(x)D(x)dx−∫xq(x)f∗(D(x))dx
那麼其實相當于:
D f ( P ∣ ∣ D ) ≈ max D ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||D) ≈ \max\limits_D\int_xp(x)D(x)dx - \int_xq(x)f^*(D(x))dx Df(P∣∣D)≈Dmax∫xp(x)D(x)dx−∫xq(x)f∗(D(x))dx
上面這個公式,我們把它改寫一下:
D f ( P ∣ ∣ D ) = max D { E x ∼ P [ D ( x ) ] − E x ∼ Q [ f ∗ ( D ( x ) ) ] } D_f(P||D) = \max\limits_D\left\{ E_{x \sim P}[D(x)] - E_{x\sim Q}[f^*(D(x))] \right\} Df(P∣∣D)=Dmax{Ex∼P[D(x)]−Ex∼Q[f∗(D(x))]}
我們令 P = P d a t a , Q = P G P=P_{data},Q=P_G P=Pdata,Q=PG ,那麼 P d a t a , P G P_{data},P_G Pdata,PG 的f-divergence就可以寫成:
D f ( P ∣ ∣ D ) = max D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ G [ f ∗ ( D ( x ) ) ] } D_f(P||D) = \max\limits_D\left\{ E_{x \sim P_{data}}[D(x)] - E_{x\sim G}[f^*(D(x))] \right\} Df(P∣∣D)=Dmax{Ex∼Pdata[D(x)]−Ex∼G[f∗(D(x))]}
這個 f ∗ f^* f∗ 取決于f-divergence是什麼。
這個式子怎麼看起來好像GAN需要minimize的目标呢?在機器學習-35-Theory behind GAN(GAN背後的數學理論) 我們提到了GAN的訓練目标是 G ∗ = a r g min G D f ( P d a t a ∣ ∣ P G ) G^* = arg \min\limits_GD_f(P_{data}||P_G) G∗=argGminDf(Pdata∣∣PG),我們把GAN的訓練目标展開,就是:
G ∗ = a r g min G D f ( P d a t a ∣ ∣ P G ) = a r g min G max D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] } = a r g min G max D V ( G , D ) \begin{aligned} G^* = &arg \min\limits_GD_f(P_{data}||P_G)\\ =& arg \min\limits_G\max\limits_D\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\} \\ =& arg \min\limits_G \max\limits_D V(G,D) \end{aligned} G∗===argGminDf(Pdata∣∣PG)argGminDmax{Ex∼Pdata[D(x)]−Ex∼PG[f∗(D(x))]}argGminDmaxV(G,D)
是以你可以選用不同的f-divergence,優化不同的divergence,論文裡面給了個清單,你可以自己選:
那麼使用不同的divergence會有什麼用處嗎?它可能能夠用于解決GAN在訓練過程中會出現的一些問題(衆所周知GAN難以訓練)
GAN訓練過程中可能産生的問題
Mode Collapse
介紹
這個概念是GAN難以訓練的原因之一,它指的是GAN産生的樣本單一,認為滿足某一分布的結果為True,其餘為False。如下圖,原始資料分布的範圍要比GAN訓練的結果大得多。進而導緻generator訓練出來的結果可能都差不多,圖檔差異性不大。
當我們的GAN模型Training with too many iterations……
有些人臉就會比較像,除了一些顔色不太一樣
解決辦法
對于mode collapse到底有沒有一些更加通用的解決辦法呢?你可以用ensemble的方法,其實就是訓練多個Generator,然後在使用的時候随機挑一個generator來生成結果,當然是一個很流氓的招數,看你用到哪裡了…
Mode Dropping
介紹
這個問題從字面上也好了解,假設原始分布有兩個比較集中的波峰,而GAN有可能把分布集中在其中一個波峰,而抛棄掉了另一個,如下圖:
例如下面的人臉,一個循環隻有白種人,一個循環隻有黃種人,一個循環中隻有黑人。
問題分析
為什麼會有這樣的結果呢,一個猜測是divergence選得不好,選擇不同的divergence,最後generator得到的distribution會不一樣。如下圖,minimize KL divergence和reverse KL divergence的時候,最後得到的分布是不一樣的, 前者容易導緻模糊的問題,後者則可能導緻mode dropping。如果你覺得在訓練過程中出現的mode collapse或者mode dropping是由于divergence導緻的,你可以通過嘗試更換 f ∗ f^* f∗ 來實驗。當然不一定說就一定有效果,這裡隻是提供一種可能的猜測。