
Variational Autoencoder (VAE): Generating MNIST Images with PyTorch

I've been wanting to learn about GANs recently, so I decided to start by studying VAEs first.


Part of the code here comes from this book. Since the book provides PyTorch implementations, I think it's fairly good, but its drawbacks are also obvious: the theory is not explained in enough depth, and some of the code is wrong or incomplete:


Code implementation:

import os  # for creating the output directory for saved images
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image  # Save a given Tensor into an image file.
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np


# Build the VAE model, consisting mainly of an Encoder and a Decoder
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
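        # Encoder: fc1 (784 -> 400), then fc2 -> mu and fc3 -> log_var (both 400 -> 20);
        # Decoder: fc4 (20 -> 400) and fc5 (400 -> 784) map the latent code back to a flattened image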
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so that gradients can flow back through mu and log_var
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)  # standard normal noise (rand_like would sample from U[0, 1))
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))  # F.sigmoid is deprecated in recent PyTorch versions

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var



def loss_function(x_reconst, x, mu, log_var):  # Loss function (adapted from an online example)
    # Reconstruction term: pixel-wise binary cross-entropy, summed over the whole batch
    BCE_loss = nn.BCELoss(reduction='sum')
    reconstruction_loss = BCE_loss(x_reconst, x)
    # KL divergence between the approximate posterior N(mu, sigma^2) and the prior N(0, I)
    KL_divergence = -0.5 * torch.sum(1 + log_var - torch.exp(log_var) - mu ** 2)
    return reconstruction_loss + KL_divergence
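# Note: this objective is the negative ELBO. For a Gaussian encoder q(z|x) = N(mu, sigma^2)
# and a standard normal prior p(z) = N(0, I), the KL term has the closed form
#     KL(q || p) = -0.5 * sum(1 + log(sigma^2) - sigma^2 - mu^2),
# which is exactly what KL_divergence computes above, with log_var = log(sigma^2).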




if __name__ == '__main__':
    image_size = 784
    h_dim = 400
    z_dim = 20
    num_epochs = 30
    batch_size = 128
    learning_rate = 0.001

    dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    examples = enumerate(data_loader)  # wrap the loader in an enumerator so we can grab one batch
    batch_idx, (example_data, example_targets) = next(examples)

    fig = plt.figure()  # show a few images from the data_loader
    for i in range(6) :
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title('Ground Truth: {}'.format(example_targets[i]))
        plt.xticks([])
        plt.yticks([])
    plt.show()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    losses = []
    os.makedirs('MNIST_fake_pics', exist_ok=True)  # save_image does not create the output directory itself
    for epoch in range(num_epochs):  # start training
        train_loss = 0
        model.train()
        for imgs, labels in data_loader :
            imgs = imgs.to(device)
            labels = labels.to(device)
            real_imgs = torch.flatten(imgs, start_dim=1)  # flatten 28x28 images to 784-dim vectors
            # Forward pass
            gen_imgs, mu, log_var = model(real_imgs)
            loss = loss_function(gen_imgs, real_imgs, mu, log_var)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate the loss
            train_loss += loss.item()
        print('epoch: {}, loss: {}'.format(epoch, train_loss / len(data_loader)))
        losses.append(train_loss / len(data_loader))

        fake_images = gen_imgs.view(-1, 1, 28, 28)
        x_concat = torch.cat([imgs.view(-1, 1, 28, 28), fake_images], dim=3)
        save_image(x_concat, 'MNIST_fake_pics/fake_images-{}.png'.format(epoch + 1))  # save originals and reconstructions side by side for comparison
    torch.save(model.state_dict(), './vae.pth')
    plt.title('trainloss')
    plt.plot(np.arange(len(losses)), losses, linewidth=1.5, linestyle='dashed', label='train_losses')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()
           

Results:

[Result images: original MNIST digits and the VAE's reconstructions shown side by side]

In the images above, the even-numbered columns are the original MNIST digits and the odd-numbered columns are the generated ones. The generation quality is quite good, although the generated digits come out a little fainter.
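Since the decoder alone maps latent codes to images, the trained model can also produce brand-new digits, not just reconstructions, by sampling from the prior and decoding. A minimal sketch, assuming the script above has already run and saved ./vae.pth (the output filename sampled_digits.png is just an example):

import torch
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)  # same VAE class as above (z_dim=20 by default)
model.load_state_dict(torch.load('./vae.pth', map_location=device))
model.eval()

with torch.no_grad():
    z = torch.randn(64, 20).to(device)               # 64 latent codes drawn from the prior N(0, I)
    samples = model.decode(z).view(-1, 1, 28, 28)    # decode to 28x28 grayscale images
    save_image(samples, 'sampled_digits.png', nrow=8)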
