天天看点

Python学习笔记——GAN

GAN

#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: gan.py
@time: 2023/01/08
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt

# Hidden-layer width shared by the Generator and Discriminator MLPs.
h_dim=400
# Number of 2-D points per training batch.
batchsz=512
# Visdom client for live loss plotting (requires a running visdom server).
viz=visdom.Visdom()

class Generator(nn.Module):
    """MLP generator: maps a 2-D latent noise vector to a 2-D sample."""

    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    # BUG FIX: the original class had no forward(), so calling G(z) in main()
    # would raise NotImplementedError.
    def forward(self, z):
        """z: [b, 2] latent noise -> [b, 2] generated point."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP discriminator: scores a 2-D point with the probability it is real.

    Input [b, 2] -> output [b] of sigmoid probabilities.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            # BUG FIX: original ended in Linear(h_dim, h_dim), which cannot be
            # flattened to one score per sample; main() expects [b,2] => [b,1].
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        )

    # BUG FIX: forward() was nested inside __init__ in the original, so it was
    # a local function and never registered as a method — D(x) would fail.
    def forward(self, x):
        output = self.net(x)
        # Flatten [b, 1] to [b] so the loss can take a plain mean.
        return output.view(-1)
def data_generator():
    """Infinite generator over an 8-Gaussian mixture on a circle.

    Yields:
        np.ndarray of shape [batchsz, 2], float32, scaled by 1/sqrt(2).
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for _ in range(batchsz):
            # Small Gaussian noise around a randomly chosen mixture center.
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            # BUG FIX: original added center[0] to both coordinates, which
            # collapses the 8 modes onto the line y = x.
            point[1] += center[1]
            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414  # ~sqrt(2): normalizes the mode radius
        # Infinite loop: this generator never raises StopIteration.
        yield dataset


def main():
    """Train a vanilla GAN on the 8-Gaussian toy dataset (requires CUDA + visdom)."""
    torch.manual_seed(23)
    np.random.seed(23)
    data_iter = data_generator()
    x = next(data_iter)  # [b, 2]; warm-up draw, shape sanity check
    # print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # print(G); print(D)  # network structure
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    # BUG FIX: the original viz.line call was not valid Python syntax.
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    # GAN core training loop
    for epoch in range(50000):

        # 1. Train the discriminator first (5 steps per generator step).
        for _ in range(5):
            # 1.1 Train on real data: maximize D(xr) <=> minimize -D(xr).
            # BUG FIX: original wrapped the stale warm-up batch `x` instead of
            # the freshly drawn `xr`.
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 Train on fake data: minimize D(G(z)).
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # like tf.stop_gradient(): do not update G here
            predf = D(xf)
            lossf = predf.mean()

            # Aggregate and optimize D.
            loss_D = lossr + lossf
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. Train the generator: maximize D(G(z)).
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            # BUG FIX: original called the nonexistent viz.lines() and passed
            # loss_D.item (the method object) instead of loss_D.item().
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            # NOTE(review): generate_image is not defined in this file —
            # confirm it is provided elsewhere before running.
            generate_image(D, G, xr, epoch)
# Script entry point: start training when run directly.
if __name__=='__main__':
    main()



           

WGAN可以改善GAN的训练问题

#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: wgan.py
@time: 2023/01/09
@desc:
"""
#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: gan.py
@time: 2023/01/08
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt

##WGAN解决GAN的训练不稳定问题
# Hidden-layer width shared by the Generator and Discriminator MLPs.
h_dim=400
# Number of 2-D points per training batch.
batchsz=512
# Visdom client for live loss plotting (requires a running visdom server).
viz=visdom.Visdom()

class Generator(nn.Module):
    """MLP generator: maps a 2-D latent noise vector to a 2-D sample."""

    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    # BUG FIX: the original class had no forward(), so calling G(z) in main()
    # would raise NotImplementedError.
    def forward(self, z):
        """z: [b, 2] latent noise -> [b, 2] generated point."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP critic: scores a 2-D point. Input [b, 2] -> output [b].

    NOTE(review): a WGAN critic usually omits the final Sigmoid; it is kept
    here to preserve the original code's behavior — confirm intent.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            # BUG FIX: original ended in Linear(h_dim, h_dim), which cannot be
            # flattened to one score per sample; main() expects [b,2] => [b,1].
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        )

    # BUG FIX: forward() was nested inside __init__ in the original, so it was
    # a local function and never registered as a method — D(x) would fail.
    def forward(self, x):
        output = self.net(x)
        # Flatten [b, 1] to [b] so the loss can take a plain mean.
        return output.view(-1)
def data_generator():
    """Infinite generator over an 8-Gaussian mixture on a circle.

    Yields:
        np.ndarray of shape [batchsz, 2], float32, scaled by 1/sqrt(2).
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for _ in range(batchsz):
            # Small Gaussian noise around a randomly chosen mixture center.
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            # BUG FIX: original added center[0] to both coordinates, which
            # collapses the 8 modes onto the line y = x.
            point[1] += center[1]
            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414  # ~sqrt(2): normalizes the mode radius
        # Infinite loop: this generator never raises StopIteration.
        yield dataset
def gradient_penalty(D, xr, xf):
    """WGAN-GP gradient penalty.

    Args:
        D: the critic network.
        xr: real samples, [b, 2].
        xf: fake samples, [b, 2] (should be detached from G's graph).

    Returns:
        Scalar tensor: mean of (||grad_x D(x_mid)||_2 - 1)^2.
    """
    # BUG FIX: the original was missing the `def` keyword and had a bare
    # `[b,1]=>[b,2]` line that was not a comment.
    # [b, 1] random interpolation weights
    t = torch.rand(batchsz, 1).cuda()
    # [b, 1] => [b, 2]
    t = t.expand_as(xr)
    # Interpolate between real and fake samples.
    # BUG FIX: original used [1-t] (a Python list) instead of (1 - t).
    mid = t * xr + (1 - t) * xf
    # Track gradients through the interpolated points.
    mid.requires_grad_()

    pred = D(mid)
    # BUG FIX: kwargs were misspelled (grad_output, only_iputs) and
    # grad_outputs must match pred's shape, not mid's.
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    # Penalize deviation of the gradient norm from 1 (typo `grds` fixed).
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp




def main():
    """Train a WGAN-GP on the 8-Gaussian toy dataset (requires CUDA + visdom)."""
    torch.manual_seed(23)
    np.random.seed(23)
    data_iter = data_generator()
    x = next(data_iter)  # [b, 2]; warm-up draw, shape sanity check
    # print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # print(G); print(D)  # network structure
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    # BUG FIX: the original viz.line call was not valid Python syntax.
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    # WGAN core training loop
    for epoch in range(50000):

        # 1. Train the critic first (5 steps per generator step).
        for _ in range(5):
            # 1.1 Train on real data: maximize D(xr) <=> minimize -D(xr).
            # BUG FIX: original wrapped the stale warm-up batch `x` instead of
            # the freshly drawn `xr`.
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 Train on fake data: minimize D(G(z)).
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # like tf.stop_gradient(): do not update G here
            predf = D(xf)
            lossf = predf.mean()

            # 1.3 Gradient penalty (xf is already detached above).
            gp = gradient_penalty(D, xr, xf)

            # Aggregate and optimize D.
            # BUG FIX: the original computed gp but never added it to the loss.
            # NOTE(review): weight 0.2 follows the course code this derives
            # from; the WGAN-GP paper uses lambda = 10 — confirm.
            loss_D = lossr + lossf + 0.2 * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. Train the generator: maximize D(G(z)).
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            # BUG FIX: original called the nonexistent viz.lines() and passed
            # loss_D.item (the method object) instead of loss_D.item().
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            # NOTE(review): generate_image is not defined in this file —
            # confirm it is provided elsewhere before running.
            generate_image(D, G, xr, epoch)
# Script entry point: start training when run directly.
if __name__=='__main__':
    main()