天天看点

(七)SN-GAN论文笔记与实战

(七)SN-GAN论文笔记与实战

        • 一、论文笔记
        • 二、完整代码
        • 三、遇到的问题及解决

一、论文笔记

在WGAN-GP中使用gradient penalty 的方法来限制判别器,但这种方法只能对生成数据分布与真实分布之间的分布空间的数据做梯度惩罚,无法对整个空间的数据做惩罚。这会导致随着训练的进行,生成数据分布与真实数据分布之间的空间会逐渐变化,从而导致gradient penalty 正则化方式不稳定。此外,WGAN-GP涉及比较多的运算,所以训练WGAN-GP的网络也比较耗时。

SN-GAN提出使用Spectral Normalization(谱归一化)的方法来让判别器D满足Lipschitz约束,简单而言,SN-GAN只需要改变判别器权值矩阵的最大奇异值,这种方法可以最大限度地保存判别器权值矩阵的信息,这个优势可以让SN-GAN使用类别较多的数据集作为训练数据,依旧可以获得比较好的生成效果。

从SN-GAN 论文中的实际效果来看,SN-GAN是目前仅有的可以使用单个生成器与判别器从ImageNet数据集(其中的图像有非常多的类别)生成高质量图像的GAN模型,WGAN、WGAN-GP等GAN模型在多类别图像中无法生成高质量的图像。其中一个可能的原因就是,在训练过程中,WGAN、WGAN-GP等GAN模型丧失了较多的原始信息。

简单而言,SN-GAN具有如下优势:

  1. 以Spectral Normalization 方法让判别器D满足Lipschitz约束,Lipschitz的常数K是唯一需要调整的超参数。
  2. 整体上SN-GAN只改变判别器权值矩阵的最大奇异值,从而可以最大限度地保留原始信息。
  3. 具体训练模型时,使用power iteration(幂迭代法),加快训练速度,可比WGAN-GP快许多。WGAN-GP慢的原因是使用gradient penalty后,模型在梯度下降的过程中相当于计算两次梯度,计算量更大,所以整体训练速度就变慢了。

判别器的目标就是对判别器的所有权重都做 $W' = W / \|W\|_2$(谱归一化),证明见原论文。

但是由于直接计算谱范数 $\|W\|_2$ 是比较耗时的,所以为了让训练模型速度更快,就需要使用一个技巧。power iteration(幂迭代)方法通过迭代计算的思想可以比较快地计算出谱范数的近似值。

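因为谱范数 $\|W\|_2$ 等于 $W^T W$ 的最大特征根的平方根(即 $W$ 的最大奇异值),所以要求解谱范数,就可以转变为求 $W^T W$ 的最大特征根,使用power iteration的方法如下:随机初始化向量 $u$,然后重复迭代 $v \leftarrow W^T u / \|W^T u\|_2$,$u \leftarrow W v / \|W v\|_2$,最后取 $\|W\|_2 \approx u^T W v$。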

(七)SN-GAN论文笔记与实战

二、完整代码

代码跑不出来效果,没找到原因,不知道问题出在哪里,希望后面随着学习的深入再来看看。

import torch
import torchvision
import torch.nn as nn
import argparse
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.optim.lr_scheduler import LambdaLR
import random
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import numpy as np
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200)      # total training epochs
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.0002)       # Adam learning rate
parser.add_argument('--b1', type=float, default=0.5)          # Adam beta1
parser.add_argument('--b2', type=float, default=0.999)        # Adam beta2
parser.add_argument('--decay_epochs', type=int, default=100)  # epoch at which LR decay starts
parser.add_argument('--z_dim', type=int, default=128, help='latent vector')

# Passing an explicit empty list makes argparse ignore sys.argv, so the
# defaults above are always used (convenient inside a notebook kernel).
opt = parser.parse_args(args=[])
print(opt)

# Fix the Python and PyTorch RNG seeds for repeatable runs.
random.seed(22)
torch.manual_seed(22)

# Output folders for sample images, TensorBoard logs and checkpoints.
os.makedirs('Picture/SNGAN', exist_ok=True)
os.makedirs('runs/SNGAN', exist_ok=True)
os.makedirs('Model/SNGAN', exist_ok=True)

# The device string must not contain whitespace: current PyTorch versions
# reject 'cuda: 0' with "Invalid device string" as soon as CUDA is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
# ToTensor yields values in [0, 1]; normalising each channel with mean 0.5 and
# std 0.5 rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# The 50000 CIFAR-10 training images are used as the training set.
# download=False: the data is expected to already be present under ../dataset.
train_set = torchvision.datasets.CIFAR10(
    root='../dataset', train=True, transform=transform, download=False)
train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=0)
print(train_set[0][0].shape)

'''自定义学习率类'''
class LambdaLR:
    """Linear learning-rate decay factor.

    The factor is 1.0 until ``decay_epochs`` and then decreases linearly,
    reaching 0.0 at ``n_epochs``. The bound ``step`` method is meant to be
    handed to a scheduler as its ``lr_lambda`` callable.
    """

    def __init__(self, n_epochs, decay_epochs):
        self.decay_epochs = decay_epochs
        self.n_epochs = n_epochs

    def step(self, epoch):
        """Return the multiplicative learning-rate factor for ``epoch``."""
        decay_span = self.n_epochs - self.decay_epochs
        decayed = (epoch - self.decay_epochs) / decay_span
        # Before the decay phase the fraction is not positive: no decay yet.
        if not decayed > 0:
            decayed = 0
        return 1.0 - decayed
    
'''Spectral Normalization -- 谱归一化类'''
class SpectralNorm(nn.Module):
    def __init__(self, layer, name = 'weight', power_iterations = 1):
        super(SpectralNorm, self).__init__()
        '''params:
        layer: 传入的需要使得参数谱归一化的网路层
        name : 谱归一化的参数
        power_iterations:幂迭代的次数,论文中提到,实际上迭代一次已足够
        '''
        self.layer = layer
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params(): # 如果迭代参数未初始化,则初始化
            self._make_params()
            
    def _update_u_v(self):
        u = getattr(self.layer, self.name+'_u')
        v = getattr(self.layer, self.name+'_v')
        w = getattr(self.layer, self.name+'_bar')
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = self.l2Norm(torch.mv(torch.t(w.view(height, -1).data), u.data)) # 计算:v <- (W^t*u)/||W^t*u||   2范数
            u.data = self.l2Norm(torch.mv(w.view(height, -1).data, v.data)) # 计算:u <- (Wv)/||Wv||
        sigma = u.dot(w.view(height, -1).mv(v)) # 计算 W的谱范数 ≈ u^t * W * v
        setattr(self.layer, self.name, w/sigma.expand_as(w))
        
    def _made_params(self):
        # 存在这些参数则返回True, 否则返回False
        try:
            u = getattr(self.layer, self.name + '_u')
            v = getattr(self.layer, self.name + '_v')
            w = getattr(self.layer, self.name + '_bar')
            return True
        except AttributeError:
            return False
    def _make_params(self):
        w = getattr(self.layer, self.name)
        height = w.data.shape[0] # 输出的卷积核的数目
        width = w.view(height, -1).data.shape[1] # width为 in_feature*kernel*kernel 的值
        # .new()创建一个新的Tensor,该Tensor的type和device都和原有Tensor一致
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad = False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad = False)
        u.data = self.l2Norm(u.data)
        v.data = self.l2Norm(v.data)
        w_bar = nn.Parameter(w.data)
        del self.layer._parameters[self.name] # 删除以前的weight参数
        # 注册参数
        self.layer.register_parameter(self.name+'_u', u) # 传入的值u,v必须是Parameter类型
        self.layer.register_parameter(self.name+'_v', v)
        self.layer.register_parameter(self.name+'_bar', w_bar)
        
    def l2Norm(self, v, eps = 1e-12): # 用于计算例如:v/||v||
        return v/(v.norm() + eps) 
    
    def forward(self, *args):
        self._update_u_v()
        return self.layer.forward(*args)

'''网络模型'''
# DCGAN-like generator and discriminator
class Generator(nn.Module):
    """DCGAN-like generator mapping a latent tensor to a 3-channel 32x32 image.

    Args:
        z_dim: number of channels of the latent input, which is expected to
            have shape ``[b, z_dim, 1, 1]``.
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.z_dim = z_dim

        def up_block(in_ch, out_ch, stride, padding):
            # Transposed 4x4 convolution followed by batch norm and in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        stages = []
        stages.extend(up_block(z_dim, 512, 1, 0))  # --> [b, 512, 4, 4]
        stages.extend(up_block(512, 256, 2, 1))    # --> [b, 256, 8, 8]
        stages.extend(up_block(256, 128, 2, 1))    # --> [b, 128, 16, 16]
        stages.extend(up_block(128, 64, 2, 1))     # --> [b, 64, 32, 32]
        # Final 3x3 projection to RGB; Tanh bounds the output to [-1, 1].
        stages.append(nn.ConvTranspose2d(64, 3, 3, stride=1, padding=1, bias=False))  # --> [b, 3, 32, 32]
        stages.append(nn.Tanh())
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        images = self.model(z)
        return images
    
class Discriminator(nn.Module):
    """Convolutional critic whose layers are all wrapped in SpectralNorm.

    Seven convolutions (three of them with stride 2) are followed by a linear
    layer that consumes ``4*4*512`` features and produces one score per image.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Conv2d positional arguments for conv1..conv7:
        # (in_channels, out_channels, kernel_size, stride, padding)
        conv_specs = (
            (3, 64, 3, 1, 1),
            (64, 64, 4, 2, 1),
            (64, 128, 3, 1, 1),
            (128, 128, 4, 2, 1),
            (128, 256, 3, 1, 1),
            (256, 256, 4, 2, 1),
            (256, 512, 3, 1, 1),
        )
        for index, spec in enumerate(conv_specs, start=1):
            setattr(self, 'conv{}'.format(index), SpectralNorm(nn.Conv2d(*spec)))
        self.fc = SpectralNorm(nn.Linear(4 * 4 * 512, 1))

    def forward(self, img):
        """Return one unbounded score per input image."""
        activation = nn.LeakyReLU(0.1)
        conv_stack = (
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.conv5, self.conv6, self.conv7,
        )
        features = img
        for conv in conv_stack:
            features = activation(conv(features))
        # Flatten to [b, 4*4*512] for the final linear layer.
        flat = features.view(-1, 4 * 4 * 512)
        return self.fc(flat)
    
# Build both networks on the selected device and print their structure.
generator = Generator(opt.z_dim).to(device)
discriminator = Discriminator().to(device)
print(generator)
print(discriminator)

# Fixed latent batch, reused to render comparable sample grids.
test_noise = torch.randn(64, opt.z_dim, 1, 1, device=device)

# The spectral normalization module creates parameters that do not require
# gradients (u and v); those must not be handed to the optimizer, so only
# parameters that do require gradients are passed for the discriminator.
# TODO: replace Parameters with buffers, which aren't returned from .parameters().
adam_betas = (opt.b1, opt.b2)
optim_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
trainable_d_params = [p for p in discriminator.parameters() if p.requires_grad]
optim_D = torch.optim.Adam(trainable_d_params, lr=opt.lr, betas=adam_betas)

# Each optimizer gets its own scheduler driven by the linear decay factor.
lr_schedual_G = torch.optim.lr_scheduler.LambdaLR(
    optim_G, lr_lambda=LambdaLR(opt.n_epochs, opt.decay_epochs).step)
lr_schedual_D = torch.optim.lr_scheduler.LambdaLR(
    optim_D, lr_lambda=LambdaLR(opt.n_epochs, opt.decay_epochs).step)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# NOTE(review): the SN-GAN paper trains with the standard adversarial loss or
# the hinge loss, whereas this loop uses a Wasserstein-style critic loss. That
# may be related to the poor sample quality reported in the article text --
# to be confirmed experimentally.
n_critic = 5  # the generator is updated once every n_critic discriminator steps
writer = SummaryWriter('runs/SNGAN')
steps_per_epoch = len(train_loader)
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # Monotonic x-axis for TensorBoard; logging every batch under `epoch`
        # would stack all points of an epoch on the same step.
        global_step = epoch * steps_per_epoch + i

        ############################
        #     discriminator
        ###########################
        b_size = imgs.size(0)
        optim_D.zero_grad()
        z = torch.randn(b_size, opt.z_dim, 1, 1, device=device)
        real_imgs = imgs.to(device)
        # detach(): no gradient flows into the generator during the D step.
        fake_imgs = generator(z).detach()

        loss_D = torch.mean(discriminator(fake_imgs)) - torch.mean(discriminator(real_imgs))
        loss_D.backward()
        optim_D.step()
        writer.add_scalar('D_loss', loss_D.item(), global_step)

        ############################
        #      generator
        ###########################
        if i % n_critic == 0:
            optim_G.zero_grad()
            fake_imgs = generator(z)
            loss_G = -torch.mean(discriminator(fake_imgs))
            loss_G.backward()
            optim_G.step()
            # G_loss is only logged on the steps where it was recomputed.
            writer.add_scalar('G_loss', loss_G.item(), global_step)

            print('[Epoch {}/{}] [step {}/{}] [D_loss {}] [G_loss {}]'.format(
                epoch, opt.n_epochs, i, steps_per_epoch, loss_D.item(), loss_G.item()))

    lr_schedual_D.step()
    lr_schedual_G.step()

    # Render the fixed latent batch once per epoch to track progress.
    with torch.no_grad():
        gen_imgs = generator(test_noise)
        torchvision.utils.save_image(
            gen_imgs, 'Picture/SNGAN/generator_{}.png'.format(epoch), nrow=8, normalize=True)

# Flush pending events and release the log file.
writer.close()

           

三、遇到的问题及解决

一、Python的hasattr() getattr() setattr() 函数使用方法详解

二、Pytorch中.new()的作用详解

三、查看模型的层和参数信息的几种方式

四、pytorch中tensor.expand()和tensor.expand_as()函数解读

五、pytorch中的register_parameter()和parameter()

六、为什么spectral norm对应的SNGAN未使用WGAN的loss?

七、[

继续阅读