天天看點

生成對抗網絡(GAN)

  GAN的全稱是 Generative Adversarial Networks,中文名稱是生成對抗網絡。原始的GAN是一種無監督學習方法,巧妙的利用“博弈”的思想來學習生成式模型。

1 GAN的原理

  GAN的基本原理很簡單,其由兩個網絡組成,一個是生成網絡G(Generator) ,另外一個是判别網絡D(Discriminator)。它們的功能分别是:

  生成網絡G:負責生成圖檔,它接收一個随機的噪聲 $z$,通過該噪聲生成圖檔,将生成的圖檔記為 $G(z)$。

  判别網絡D:負責判别一張圖檔是真實的圖檔還是由G生成的假的圖檔。其輸入是一張圖檔 $x$ ,輸出是0, 1值,0代表圖檔是由G生成的,1代表是真實圖檔。

  在訓練過程中,生成網路G的目标是盡量生成真實的圖檔去欺騙判别網絡D。而判别網絡D的目标就是盡量把G生成的圖檔和真實的圖檔區分開來。這樣G和D就構成了一個動态的博弈過程。這是GAN的基本思想。

  在最理想的狀态下,G可以生成足以“以假亂真”的圖檔 $G(z)$。對于D來說,它難以判斷G生成的圖檔究竟是不是真實的,是以 $D(G(z)) = 0.5$ (在這裡我們輸入的真實圖檔和生成的圖檔是各一半的)。此時得到的生成網絡G就可以用來生成圖檔。

2 GAN損失函數

  從數學的角度上來看GAN,假設用于訓練的真實圖檔資料是 $x$,圖檔資料的分布為 $p_{data}(x)$,生成網絡G需要去學習到真實資料分布 $p_{data}(x)$。噪聲 $z$ 的分布假設為$p_z(z)$,在這裡 $p_z(z)$是已知的,而 $p_{data}(x)$ 是未知的。在理想的狀态下$G(z)$ 的分布應該是盡可能接近$p_{data}(x)$,G将已知分布的$z$ 變量映射到位置分布 $x$ 變量上。

  根據交叉熵損失,可以構造下面的損失函數:

  $ V(D,G) = E_{x~p_{data}(x)} [ln D(x)] + E_{z~p_z(z)} [ln(1-D(G(z)))] $

  其實從損失函數中可以看出和邏輯回歸的損失函數基本一樣,唯一不一樣的是負例的機率值為 $ 1-D(G(z))$。

  損失函數中加号的前一半是訓練資料中的真實樣本,後一半是從已知的噪聲分布中取的樣本。下面對這個損失函數較長的描述:

  1)整個式子有兩項構成。 $x$表示真實圖檔,$z$表示輸入G網絡的噪聲,而$G(z)$ 表示G網絡生成的圖檔。

  2)$D(x)$ 表示D網絡判斷真實圖檔是否真實的機率 ,即 $P(y=1 | x)$。而$D(G(z))$ 是D網絡判斷$G$生成的圖檔是否真實的機率。

  3)G的目的:G應該希望自己生成的圖檔越真實越好。也就是說G希望 $D(G(z))$ 盡可能大,即$P(G(z) = 1 | x)$,這時 $V(D, G)$ 盡可能小。

  4)D的目的:D的能力越強,$D(x)$ 就應該越大,$D(G(x))$應該越小(即假的圖檔都被識别為0)。是以D的目的和G的目的不同,D希望 $V(D, G)$ 越大越好。

3 GAN模組化流程

  在實際訓練中,使用梯度下降法,對D和G交替做優化,具體步驟如下:

  1)從已知的噪聲分布 $p_z(z)$中選取一些樣本

    ${z_1, z_2, ......, z_m}$

  2)從訓練資料中選出同樣個數的真實圖檔

    ${x_1, x_2, ......, x_m}$

  3)設判别器D的參數為 $\theta_d$,其損失函數的梯度為

    $ \nabla \frac{1}{m} \sum_{i=1}^m [lnD(x_i) + ln(1-D(G(Z_I)))] $

  4)設生成器G的參數為 $\theta_g$,其損失函數的梯度為

    $ \nabla \frac{1}{m} \sum_{i=1}^m [ln(1-D(G(Z_I)))] $

  在上面的步驟中,每更新一次D的參數,緊接着就更新一次G的參數,有時也可以在更新 $k$ 次D的參數,再更新一次G的參數。