天天看点

生成对抗网络GAN(Generative Adversarial Network)

生成对抗网络GAN(Generative Adversarial Network)

生成对抗网络GAN是一种深度学习模型,它源于2014年发表的论文:《Generative Adversarial Nets》,论文地址:​​https://arxiv.org/pdf/1406.2661.pdf​​。

GAN的用途非常广泛,比如:有大量的卡通头像,想通过学习自动生成卡通图片,此问题只提供正例,可视为无监督学习问题。不可能通过人工判断大量数据。如何生成图片?如何评价生成的图片好坏?GAN为此类问题提供了解决方法。

GAN同时训练两个模型:生成模型G(Generative Model)和判别模型D(Discriminative Model),生成模型G的目标是学习数据的分布,判别模型D的目标是区别真实数据和模型G生成的数据。以生成卡通图片为例,生成网络G的目标是生成尽量真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。G和D构成了一个动态的“博弈过程”,通过迭代双方能力都不断提高。

用途

  • 生成数据 GAN常用于实现复杂分布上的无监督学习和半监督学习,学习数据的分布,模拟现有数据生成同类型的图片、文本、旋律等等。
  • 数据增强 GAN也用于扩展现有的数据集,即数据增强。使用它训练好的生成网络,可以在数据不足时用于补充数据。
  • 生成特定数据 GAN掌握了数据生成能力后,可通过加入限制,使模型生成特定类型的数据。比如改变图片风格,隐去敏感信息,实现诸如数据加密的功能。
  • 使用判断模型 训练好的判别模型可以用于判断数据是否属于该类别,判断数据的真实性,以及判断异常数据。
生成对抗网络GAN(Generative Adversarial Network)

我们现在拥有大量的手写数字的数据集,我们希望通过GAN生成一些能够以假乱真的手写字图片。主要由如下两个部分组成:

  1. 定义一个模型来作为生成器(图中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像。
  2. 定义一个分类器来作为判别器(图中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集Training set中的还是生成器中生成的Fake image),输入为手写图片,输出为判别图片的标签。

如何训练

接下来说明如何进行训练。

基本流程如下:

  • 初始化判别器D的参数 和生成器G的参数 。
  • 从真实样本中采样 个样本 { } ,从先验分布噪声中采样 个噪声样本 { } 并通过生成器获取 个生成样本 { } 。固定生成器G,训练判别器D尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本。
  • 循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数,训练生成器使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器判别错误。
  • 多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出。亦即最终样本判别概率均为0.5。
Tips: 之所以要训练k次判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新。更直观的理解可以参考下图:
生成对抗网络GAN(Generative Adversarial Network)

图四 生成器判别器与样本示意图

注:图中的黑色虚线表示真实的样本的分布情况,蓝色虚线表示判别器判别概率的分布情况,绿色实线表示生成样本的分布。 表示噪声, 到 表示通过生成器之后的分布的映射情况。

我们的目标是使用生成样本分布(绿色实线)去拟合真实的样本分布(黑色虚线),来达到生成以假乱真样本的目的。

可以看到在(a)状态处于最初始的状态的时候,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不是很稳定,因此会先训练判别器来更好地分辨样本。通过多次训练判别器来达到(b)样本状态,此时判别样本区分得非常显著和良好。然后再对生成器进行训练。训练生成器之后达到(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。经过多次反复训练迭代之后,最终希望能够达到(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的(判别概率均为0.5)。也就是说我们这个时候就可以生成出非常真实的样本啦,目的达到。

GAN是一种全新的非监督式的架构(如下图所示)。GAN包括了两套独立的网络,两者之间作为互相对抗的目标。第一套网络是我们需要训练的分类器(下图中的D),用来分辨是否是真实数据还是虚假数据;第二套网络是生成器(下图中的G),生成类似于真实样本的随机样本,并将其作为假样本。

生成对抗网络GAN(Generative Adversarial Network)

详细说明:

D作为一个图片分类器,对于一系列图片区分不同的动物。生成器G的目标是绘制出非常接近的伪造图片来欺骗D,做法是选取训练数据潜在空间中的元素进行组合,并加入随机噪音,例如在这里可以选取一个猫的图片,然后给猫加上第三只眼睛,以此作为假数据。

在训练过程中,D会接收真数据和G产生的假数据,它的任务是判断图片是属于真数据的还是假数据的。对于最后输出的结果,可以同时对两方的参数进行调优。如果D判断正确,那就需要调整G的参数从而使得生成的假数据更为逼真;如果D判断错误,则需调节D的参数,避免下次类似判断出错。训练会一直持续到两者进入到一个均衡和谐的状态。

训练后的产物是一个质量较高的自动生成器和一个判断能力较强强的分类器。前者可以用于机器创作(自动画出“猫”“狗”),而后者则可以用来机器分类(自动判断“猫”“狗”)。

生成对抗网络GAN(Generative Adversarial Network)

GAN模型的目标函数如下:

在这里,训练网络D使得最大概率地分对训练样本的标签(最大化log D(x)和log(1—D(G(z)))),训练网络G最小化log(1-D(G(z))),即最大化D的损失。而训练过程中固定一方,更新另一个网络的参数,交替迭代,使得对方的错误最大化,最终,G 能估测出样本数据的分布,也就是生成的样本更加的真实。

    或者我们可以直接理解G网络的loss是log(1—D(G(z))),

    D的loss是 —(log D(x))+log(1—D(G(z)))

然后从式子中解释对抗,我们知道G网络的训练是希望D(G(z))趋近于1,也就是正类,这样G的loss就会最小。而D网络的训练就是一个2分类,目标是分清楚真实数据和生成数据,也就是希望真实数据的D输出趋近于1,而生成数据的输出即D(G(z))趋近于0,或是负类。这里就是体现了对抗的思想。