天天看點

深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字

 本文介紹CGAN(Conditional Generative Adversarial Nets)- 條件生成對抗網絡

 相關論文 https://arxiv.org/pdf/1411.1784.pdf

1、概述

 CGAN( Conditional Generative Adversarial Nets),條件生成對抗網絡。條件生成對抗網絡指的是在生成對抗網絡中

加入條件(condition),條件的作用是監督生成對抗網絡。最基本的對抗網具有以下優點:永遠不需要馬爾可夫鍊,僅使用

反向傳播來獲得梯度,在學習期間不需要推理,并且可以容易地将各種因素和互相作用結合到模型中。

 在無條件的生成模型中,無法控制正在生成的資料的模式。但是,通過在附加資訊上調整模型,可以指導資料生成過程。

這種調節可以基于類别标簽,在某些部分資料上進行修複,甚至是來自不同模态的資料。

2、生成對抗網絡

 GAN(Generative Adversarial Nets)由兩個“對抗”模型組成:一個捕獲資料分布的生成模型G和一個判别模型D,

它估計樣本來自訓練資料的機率而不是生成樣本的機率. G和D都可以是非線性的映射函數,例如多層感覺器。

為了在資料資料x上學習生成器分布p_z(z) ,生成器建立從先前噪聲分布p_z(z)到資料空間的映射函數,如G(z;θg)。

鑒别器D(x;θd)輸入是真實圖像或者生成圖像,輸出單個标量,該标量表示x來自訓練資料而不是p_g的機率。

G和D都同時訓練:固定判别模型 D,調整 G 的參數使得 log(1−D(G(z))的期望最小化;固定生成模型 G,調整 D 的參數

使得 logD(X)+log(1−D(G(z)))log 的期望最大化,這個優化過程歸結為二進制極小極大博弈(minimax two-player game)”問題:

深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字

3、條件生成對抗網絡

 條件生成式對抗網絡(CGAN)是對原始GAN的一個擴充,生成器和判别器都增加額外資訊 y為條件, y可以使任意資訊,

例如類别資訊,或者其他模态的資料。

 如下圖所示,通過将額外資訊 y 輸送給判别模型和生成模型,作為輸入層的一部分,進而實作條件GAN。在生成模型中,

先驗輸入噪聲 p(z) 和條件資訊 y 聯合組成了聯合隐層表征。對抗訓練架構在隐層表征的組成方式方面相當地靈活。類似地,

條件 GAN 的目标函數是帶有條件機率的二人極小極大值博弈(two-player minimax game ):

深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字
深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字

4、TensorFlow實作通過mnist資料集生成手寫數字

 完整代碼 https://github.com/clark82/deeplearning

4.1 代碼解讀

 輸入資料,其中 real_img_digit為真實資料的标簽資料,10維的向量,即條件資訊。此資訊可以引導生成哪個數字。

def inputs(real_size, noise_size):
    """
    真實圖像tensor與噪聲圖像tensor
    """

    real_img_digit = tf.placeholder(tf.float32, [None, k], name='real_img_digit')
    real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')
    
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')
      
    return real_img, noise_img, real_img_digit
           

 生成器,和基礎的GAN基本一樣,先條件資訊和noise資料拼接,之後操作和GAN完全一樣

def generator(digit, noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    """
    digit:輸入的條件資訊
    noise_img: 生成器的輸入
    n_units: 隐層單元個數
    out_dim: 生成器輸出tensor的size,這裡應該為32*32=784
    alpha: leaky ReLU系數
    """
    with tf.variable_scope("generator", reuse=reuse):

    concatenated_img_digit = tf.concat([digit, noise_img], 1)

    # hidden layer
    hidden1 = tf.layers.dense(concatenated_img_digit, n_units)
    # leaky ReLU ,和ReLU差別:ReLU是将所有的負值都設為零,相反,Leaky ReLU是給所有負值賦予一個非零斜率。
    hidden1 = tf.maximum(alpha * hidden1, hidden1)
    # dropout
    hidden1 = tf.layers.dropout(hidden1, rate=0.2)

    # logits & outputs
    logits = tf.layers.dense(hidden1, out_dim)
    outputs = tf.tanh(logits)
    
    return outputs
           

 判别器,和基礎的GAN基本一樣,先條件資訊和真實資料拼接,之後操作和GAN完全一樣

def discriminator(digit, img, n_units, reuse=False, alpha=0.01):
    """
    digit:輸入的條件資訊
    n_units: 隐層結點數量
    alpha: Leaky ReLU系數
    """
    
    with tf.variable_scope("discriminator", reuse=reuse):

        concatenated_img_digit = tf.concat([digit, img], 1)

        # hidden layer
        hidden1 = tf.layers.dense(concatenated_img_digit, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        
        # logits & outputs
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        
        return logits, outputs
           

 損失函數和GAN完全一樣,訓練過程增加标簽資訊

for batch_i in range(mnist.train.num_examples//batch_size):

                batch = mnist.train.next_batch(batch_size)
                # 這裡讀取标簽資訊,作為real_img_digit: digits,batch資料
                digits = batch[1]
               
                batch_images = batch[0].reshape((batch_size, 784))
                # 對圖像像素進行scale,這是因為tanh輸出的結果介于(-1,1),real和fake圖檔共享discriminator的參數
                # 把圖檔灰階0~1變成 -1 到1的值, 以适應generator輸出的結果(-1,1)
                batch_images = batch_images*2 - 1
               
                # generator的輸入噪聲
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                
                # Run optimizers
                _ = sess.run(d_train_opt, feed_dict={real_img_digit: digits, real_img: batch_images, noise_img: batch_noise})
                _ = sess.run(g_train_opt, feed_dict={real_img_digit: digits, noise_img: batch_noise})
           

4.2 調試

 完整的代碼https://github.com/clark82/deeplearning

 1、下mnist資料,拷貝到MNIST_data目錄下

 2、訓練模型

 python train.py -f 0

 輸出如下過程,訓練正常運作

 Epoch 1/300… Discriminator Loss: 0.2163(Real: 0.0101 + Fake: 0.2062)… Generator Loss: 1.9151

 Epoch 2/300… Discriminator Loss: 0.2752(Real: 0.0348 + Fake: 0.2404)… Generator Loss: 6.2620

 Epoch 3/300… Discriminator Loss: 0.6922(Real: 0.3858 + Fake: 0.3064)… Generator Loss: 2.3120

 Epoch 4/300… Discriminator Loss: 2.1965(Real: 0.8177 + Fake: 1.3788)… Generator Loss: 1.0322

 3、檢視訓練過程中生成狀态

 python train.py -f 2

深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字

 4、驗證模型,生成一批資料

 python train.py -f 1

 這裡可以把生成條件資訊(标簽)列印出來,觀察其生成資料的關系

# 生成标簽使用者生成圖檔
digits = np.zeros((25, k))
for i in range(0, 25):
    j = np.random.randint(0, 9, 1)
    digits[i][j] = 1

print (digits)
gen_samples = sess.run(generator(real_img_digit, noise_img, g_units, img_size, reuse=True),
                    feed_dict={real_img_digit: digits, noise_img: sample_noise})
           

 輸出結果,可以看到輸入标簽和生成的數字一一對應

[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] – 5

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] – 1

[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] – 6

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]

[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]

深度學習之 CGAN及TensorFlow 實作1、概述2、生成對抗網絡3、條件生成對抗網絡4、TensorFlow實作通過mnist資料集生成手寫數字