最近想學習一下GAN,於是先學習一下VAE。
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiInVGcq5iM4MjNkZjZzEWZ3MWO2MzMmNWNzYWZhF2YlBzYjVzN18CX0JXZ252bj91Ztl2Lc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpeg)
代碼實作一部分出自這本書,因為這本書會給出PyTorch的代碼實作,所以我覺得還不錯。但是缺點也很明顯:理論講解不夠,代碼還有錯誤或者不全:
代碼實作:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image # Save a given Tensor into an image file.
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
# Build the VAE model: an encoder (fc1 -> fc2/fc3) and a decoder (fc4 -> fc5).
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of width ``image_size`` to the mean
    and log-variance of a ``z_dim``-dimensional diagonal Gaussian; the
    decoder maps a latent sample back to ``image_size`` values in (0, 1).

    Args:
        image_size: Width of the flattened input vector.
        h_dim: Width of the single hidden layer used by encoder and decoder.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head: mean
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head: log-variance
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)
        # The noise must be standard normal (randn_like). rand_like would
        # draw from U[0, 1), which is inconsistent with the Gaussian
        # assumption behind the KL term of the loss.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions with values in (0, 1)."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(x_reconst, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var
def loss_function(x_reconst, x, mu, log_var):
    """Return the VAE training loss (negative ELBO) summed over the batch.

    Args:
        x_reconst: Decoder output, values in [0, 1], same shape as ``x``.
        x: Target values in [0, 1].
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor: summed binary cross-entropy between ``x_reconst`` and
        ``x`` plus the closed-form KL divergence between N(mu, exp(log_var))
        and the standard normal prior.
    """
    # Functional form: same value as nn.BCELoss(reduction='sum') without
    # constructing a new module on every call.
    reconstruction_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + log_var - torch.exp(log_var) - mu ** 2)
    return reconstruction_loss + kl_divergence
if __name__ == '__main__':
    # Script entry point: train the VAE on MNIST, save a comparison image
    # after every epoch, save the final weights and plot the loss curve.
    import os  # local import: only the script entry point needs it

    image_size = 784
    h_dim = 400
    z_dim = 20
    num_epochs = 30
    batch_size = 128
    learning_rate = 0.001

    # save_image does not create missing directories, so make sure the
    # output folder exists before the first epoch finishes.
    os.makedirs('MNIST_fake_pics', exist_ok=True)

    dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    # Show a few images from the data loader as a sanity check.
    example_data, example_targets = next(iter(data_loader))
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title('Groud Truth: {}'.format(example_targets[i]))
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pass the hyperparameters explicitly so changing them above takes effect.
    model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    losses = []
    for epoch in range(num_epochs):  # training loop
        train_loss = 0
        model.train()
        for imgs, _ in data_loader:  # labels are not needed to train a VAE
            imgs = imgs.to(device)
            real_imgs = torch.flatten(imgs, start_dim=1)

            # Forward pass
            gen_imgs, mu, log_var = model(real_imgs)
            loss = loss_function(gen_imgs, real_imgs, mu, log_var)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the (batch-summed) loss
            train_loss += loss.item()

        epoch_loss = train_loss / len(data_loader)
        print('epoch: {}, loss: {}'.format(epoch, epoch_loss))
        losses.append(epoch_loss)

        # Put the originals (left) next to the reconstructions (right) of the
        # last batch of the epoch for visual comparison. detach() drops the
        # autograd graph, which is not needed for saving.
        fake_images = gen_imgs.detach().view(-1, 1, 28, 28)
        x_concat = torch.cat([imgs.view(-1, 1, 28, 28), fake_images], dim=3)
        save_image(x_concat, 'MNIST_fake_pics/fake_images-{}.png'.format(epoch + 1))

    torch.save(model.state_dict(), './vae.pth')

    plt.title('trainloss')
    plt.plot(np.arange(len(losses)), losses, linewidth=1.5, linestyle='dashed', label='train_losses')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()
結果:
上圖中每一對圖片裡,左邊(從0開始數的偶數列)是MNIST原圖,右邊(奇數列)是生成的圖。可以發現生成效果還不錯,雖然還是會淡一點點。