本文介紹CGAN(Conditional Generative Adversarial Nets)- 條件生成對抗網絡
相關論文 https://arxiv.org/pdf/1411.1784.pdf
1、概述
CGAN( Conditional Generative Adversarial Nets),條件生成對抗網絡。條件生成對抗網絡指的是在生成對抗網絡中
加入條件(condition),條件的作用是監督生成對抗網絡。最基本的對抗網具有以下優點:永遠不需要馬爾可夫鍊,僅使用
反向傳播來獲得梯度,在學習期間不需要推理,并且可以容易地将各種因素和互相作用結合到模型中。
在無條件的生成模型中,無法控制正在生成的資料的模式。但是,通過在附加資訊上調整模型,可以指導資料生成過程。
這種調節可以基于類别标簽,在某些部分資料上進行修複,甚至是來自不同模态的資料。
2、生成對抗網絡
GAN(Generative Adversarial Nets)由兩個“對抗”模型組成:一個捕獲資料分布的生成模型G和一個判别模型D,
它估計樣本來自訓練資料的機率而不是生成樣本的機率. G和D都可以是非線性的映射函數,例如多層感覺器。
為了在資料資料x上學習生成器分布p_z(z) ,生成器建立從先前噪聲分布p_z(z)到資料空間的映射函數,如G(z;θg)。
鑒别器D(x;θd)輸入是真實圖像或者生成圖像,輸出單個标量,該标量表示x來自訓練資料而不是p_g的機率。
G和D都同時訓練:固定判别模型 D,調整 G 的參數使得 log(1−D(G(z))的期望最小化;固定生成模型 G,調整 D 的參數
使得 logD(X)+log(1−D(G(z)))log 的期望最大化,這個優化過程歸結為二進制極小極大博弈(minimax two-player game)”問題:
3、條件生成對抗網絡
條件生成式對抗網絡(CGAN)是對原始GAN的一個擴充,生成器和判别器都增加額外資訊 y為條件, y可以使任意資訊,
例如類别資訊,或者其他模态的資料。
如下圖所示,通過将額外資訊 y 輸送給判别模型和生成模型,作為輸入層的一部分,進而實作條件GAN。在生成模型中,
先驗輸入噪聲 p(z) 和條件資訊 y 聯合組成了聯合隐層表征。對抗訓練架構在隐層表征的組成方式方面相當地靈活。類似地,
條件 GAN 的目标函數是帶有條件機率的二人極小極大值博弈(two-player minimax game ):
4、TensorFlow實作通過mnist資料集生成手寫數字
完整代碼 https://github.com/clark82/deeplearning
4.1 代碼解讀
輸入資料,其中 real_img_digit為真實資料的标簽資料,10維的向量,即條件資訊。此資訊可以引導生成哪個數字。
def inputs(real_size, noise_size):
"""
真實圖像tensor與噪聲圖像tensor
"""
real_img_digit = tf.placeholder(tf.float32, [None, k], name='real_img_digit')
real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')
noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')
return real_img, noise_img, real_img_digit
生成器,和基礎的GAN基本一樣,先條件資訊和noise資料拼接,之後操作和GAN完全一樣
def generator(digit, noise_img, n_units, out_dim, reuse=False, alpha=0.01):
"""
digit:輸入的條件資訊
noise_img: 生成器的輸入
n_units: 隐層單元個數
out_dim: 生成器輸出tensor的size,這裡應該為32*32=784
alpha: leaky ReLU系數
"""
with tf.variable_scope("generator", reuse=reuse):
concatenated_img_digit = tf.concat([digit, noise_img], 1)
# hidden layer
hidden1 = tf.layers.dense(concatenated_img_digit, n_units)
# leaky ReLU ,和ReLU差別:ReLU是将所有的負值都設為零,相反,Leaky ReLU是給所有負值賦予一個非零斜率。
hidden1 = tf.maximum(alpha * hidden1, hidden1)
# dropout
hidden1 = tf.layers.dropout(hidden1, rate=0.2)
# logits & outputs
logits = tf.layers.dense(hidden1, out_dim)
outputs = tf.tanh(logits)
return outputs
判别器,和基礎的GAN基本一樣,先條件資訊和真實資料拼接,之後操作和GAN完全一樣
def discriminator(digit, img, n_units, reuse=False, alpha=0.01):
"""
digit:輸入的條件資訊
n_units: 隐層結點數量
alpha: Leaky ReLU系數
"""
with tf.variable_scope("discriminator", reuse=reuse):
concatenated_img_digit = tf.concat([digit, img], 1)
# hidden layer
hidden1 = tf.layers.dense(concatenated_img_digit, n_units)
hidden1 = tf.maximum(alpha * hidden1, hidden1)
# logits & outputs
logits = tf.layers.dense(hidden1, 1)
outputs = tf.sigmoid(logits)
return logits, outputs
損失函數和GAN完全一樣,訓練過程增加标簽資訊
for batch_i in range(mnist.train.num_examples//batch_size):
batch = mnist.train.next_batch(batch_size)
# 這裡讀取标簽資訊,作為real_img_digit: digits,batch資料
digits = batch[1]
batch_images = batch[0].reshape((batch_size, 784))
# 對圖像像素進行scale,這是因為tanh輸出的結果介于(-1,1),real和fake圖檔共享discriminator的參數
# 把圖檔灰階0~1變成 -1 到1的值, 以适應generator輸出的結果(-1,1)
batch_images = batch_images*2 - 1
# generator的輸入噪聲
batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
# Run optimizers
_ = sess.run(d_train_opt, feed_dict={real_img_digit: digits, real_img: batch_images, noise_img: batch_noise})
_ = sess.run(g_train_opt, feed_dict={real_img_digit: digits, noise_img: batch_noise})
4.2 調試
完整的代碼https://github.com/clark82/deeplearning
1、下mnist資料,拷貝到MNIST_data目錄下
2、訓練模型
python train.py -f 0
輸出如下過程,訓練正常運作
Epoch 1/300… Discriminator Loss: 0.2163(Real: 0.0101 + Fake: 0.2062)… Generator Loss: 1.9151
Epoch 2/300… Discriminator Loss: 0.2752(Real: 0.0348 + Fake: 0.2404)… Generator Loss: 6.2620
Epoch 3/300… Discriminator Loss: 0.6922(Real: 0.3858 + Fake: 0.3064)… Generator Loss: 2.3120
Epoch 4/300… Discriminator Loss: 2.1965(Real: 0.8177 + Fake: 1.3788)… Generator Loss: 1.0322
3、檢視訓練過程中生成狀态
python train.py -f 2
4、驗證模型,生成一批資料
python train.py -f 1
這裡可以把生成條件資訊(标簽)列印出來,觀察其生成資料的關系
# 生成标簽使用者生成圖檔
digits = np.zeros((25, k))
for i in range(0, 25):
j = np.random.randint(0, 9, 1)
digits[i][j] = 1
print (digits)
gen_samples = sess.run(generator(real_img_digit, noise_img, g_units, img_size, reuse=True),
feed_dict={real_img_digit: digits, noise_img: sample_noise})
輸出結果,可以看到輸入标簽和生成的數字一一對應
[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] – 5
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] – 1
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] – 6
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]