# GAN: vanilla GAN demo on a 2-D 8-Gaussian toy dataset
#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: gan.py
@time: 2023/01/08
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt
# Hidden width of both MLPs (generator and discriminator).
h_dim=400
# Number of 2-D points per training batch.
batchsz=512
# Visdom client for live loss plotting; requires a running `visdom` server.
viz=visdom.Visdom()
class Generator(nn.Module):
    """Maps a 2-D latent vector z to a 2-D sample via a 3-hidden-layer MLP."""

    def __init__(self):
        super(Generator, self).__init__()
        # z: [b, 2] => [b, 2]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        # Bug fix: the original class defined no forward(), so calling G(z)
        # in the training loop would raise NotImplementedError.
        return self.net(z)
class Discriminator(nn.Module):
    """Scores a 2-D sample with a probability in (0, 1) that it is real."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # x: [b, 2] => [b, 1]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            # Bug fix: final layer was nn.Linear(h_dim, h_dim), which produced
            # [b, h_dim] instead of the single per-sample score [b, 1] that the
            # training loop expects (view(-1) would then yield [b*h_dim]).
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        # [b, 1] => [b]
        return output.view(-1)
def data_generator(batch_size=None):
    """
    Infinite generator over an 8-component Gaussian mixture in 2-D.

    The 8 component means sit on the axes and diagonals of a circle of
    radius ``scale``; each sample is a mean plus N(0, 0.02^2) noise, and the
    whole batch is normalized by 1/1.414.

    :param batch_size: points per batch; defaults to the module-level
        ``batchsz`` when None (backward compatible with the old no-arg call).
    :return: yields float32 arrays of shape [batch_size, 2].
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    # infinite looping data generator
    while True:
        n = batchsz if batch_size is None else batch_size
        dataset = []
        for _ in range(n):
            # N(0, 0.02^2) noise around a randomly chosen center
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            # Bug fix: the original added center[0] to BOTH coordinates,
            # collapsing all 8 modes onto the x=y diagonal.
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset
def main():
    """
    Train a vanilla GAN on the 8-Gaussian toy dataset.

    NOTE(review): requires a CUDA device and a running visdom server.
    `generate_image` is referenced but not defined in this file — presumably
    provided by a companion plotting module; TODO confirm.
    """
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)  # peek one batch: [b, 2]
    # print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # network structure
    # print(G)
    # print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    # Fix: the original viz.line(...) call was syntactically invalid.
    viz.line([[0, 0]], [0], win='loss',
             opts=dict(title='loss', legend=['D', 'G']))

    # core GAN training loop
    for epoch in range(50000):
        # 1. train the discriminator first (5 steps per generator step)
        for _ in range(5):
            # 1.1 train on real data: maximize E[D(x_real)]
            xr = next(data_iter)
            # Bug fix: originally converted the stale first batch `x`
            # instead of the freshly drawn `xr`.
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 train on fake data: minimize E[D(G(z))]
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # like tf.stop_gradient(): G is frozen here
            predf = D(xf)
            lossf = predf.mean()

            # aggregate and optimize D
            loss_D = lossr + lossf
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the generator: maximize E[D(G(z))]
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            # Fixes: viz.lines() did not exist; loss_D.item was missing
            # its call parentheses.
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()
# WGAN can mitigate GAN's training-instability problems
#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: wgan.py
@time: 2023/01/09
@desc:
"""
#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: JMS
@file: gan.py
@time: 2023/01/08
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt
## WGAN addresses GAN's unstable-training problem
# Hidden width of both MLPs (generator and critic).
h_dim=400
# Number of 2-D points per training batch.
batchsz=512
# Visdom client for live loss plotting; requires a running `visdom` server.
viz=visdom.Visdom()
class Generator(nn.Module):
    """Maps a 2-D latent vector z to a 2-D sample via a 3-hidden-layer MLP."""

    def __init__(self):
        super(Generator, self).__init__()
        # z: [b, 2] => [b, 2]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        # Bug fix: the original class defined no forward(), so calling G(z)
        # in the training loop would raise NotImplementedError.
        return self.net(z)
class Discriminator(nn.Module):
    """Critic that scores a 2-D sample; output squashed to (0, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # x: [b, 2] => [b, 1]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            # Bug fix: final layer was nn.Linear(h_dim, h_dim), which produced
            # [b, h_dim] instead of the single per-sample score [b, 1] that the
            # training loop expects (view(-1) would then yield [b*h_dim]).
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        # [b, 1] => [b]
        return output.view(-1)
def data_generator(batch_size=None):
    """
    Infinite generator over an 8-component Gaussian mixture in 2-D.

    The 8 component means sit on the axes and diagonals of a circle of
    radius ``scale``; each sample is a mean plus N(0, 0.02^2) noise, and the
    whole batch is normalized by 1/1.414.

    :param batch_size: points per batch; defaults to the module-level
        ``batchsz`` when None (backward compatible with the old no-arg call).
    :return: yields float32 arrays of shape [batch_size, 2].
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    # infinite looping data generator
    while True:
        n = batchsz if batch_size is None else batch_size
        dataset = []
        for _ in range(n):
            # N(0, 0.02^2) noise around a randomly chosen center
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            # Bug fix: the original added center[0] to BOTH coordinates,
            # collapsing all 8 modes onto the x=y diagonal.
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset
def gradient_penalty(D, xr, xf):
    """
    WGAN-GP gradient penalty, E[(||grad_x D(x)||_2 - 1)^2], evaluated at
    random interpolations between real and fake batches.

    :param D: critic network mapping [b, 2] -> [b]
    :param xr: real batch [b, 2]
    :param xf: fake batch [b, 2] (caller passes it detached)
    :return: scalar penalty tensor

    Fixes vs. the original: missing `def` keyword, `[1-t]` list instead of
    `(1 - t)`, `grad_output`/`only_iputs` typos, `ones_like(mid)` instead of
    `ones_like(pred)` (must match the output shape), and `grds` typo.
    """
    # [b, 1] interpolation weights (sized from xr rather than the global
    # batchsz, so any batch size works)
    t = torch.rand(xr.size(0), 1).cuda()
    # [b, 1] => [b, 2]
    t = t.expand_as(xr)
    # interpolation between real and fake samples
    mid = t * xr + (1 - t) * xf
    # we need gradients w.r.t. the interpolated points
    mid.requires_grad_()
    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp
def main():
    """
    Train a WGAN-GP on the 8-Gaussian toy dataset.

    NOTE(review): requires a CUDA device and a running visdom server.
    `generate_image` is referenced but not defined in this file — presumably
    provided by a companion plotting module; TODO confirm.
    """
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)  # peek one batch: [b, 2]
    # print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # network structure
    # print(G)
    # print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    # Fix: the original viz.line(...) call was syntactically invalid.
    viz.line([[0, 0]], [0], win='loss',
             opts=dict(title='loss', legend=['D', 'G']))

    # core WGAN training loop
    for epoch in range(50000):
        # 1. train the critic first (5 steps per generator step)
        for _ in range(5):
            # 1.1 train on real data: maximize E[D(x_real)]
            xr = next(data_iter)
            # Bug fix: originally converted the stale first batch `x`
            # instead of the freshly drawn `xr`.
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 train on fake data: minimize E[D(G(z))]
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # like tf.stop_gradient(): G is frozen here
            predf = D(xf)
            lossf = predf.mean()

            # 1.3 gradient penalty
            gp = gradient_penalty(D, xr, xf.detach())

            # aggregate all. Bug fix: the original computed gp but never
            # added it to loss_D; 0.2 is the penalty weight.
            loss_D = lossr + lossf + 0.2 * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the generator: maximize E[D(G(z))]
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            # Fixes: viz.lines() did not exist; loss_D.item was missing
            # its call parentheses.
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()