天天看點

Deep Auto-encoder的代碼實作

李宏毅講Auto-encoder視訊:連結位址

看了上面的Auto-encoder視訊,想試着做一下裡面的Deep Auto-encoder,看看效果如何,神經網絡架構如下:

Deep Auto-encoder的代碼實作

按照上面的網絡架構,采用Keras實作Deep Auto-encoder,loss函數采用均方誤差函數。

疊代之後的loss下降圖:

Deep Auto-encoder的代碼實作

最終的效果:

上面是原圖,下面是由原圖經過整個網絡生成的圖檔:

Deep Auto-encoder的代碼實作

效果并沒有論文中顯示的那麼好,暫時還沒找到原因,明天看看論文,看看能不能解決。

更新:論文上面說如果要訓練得很好,那麼需要将參數初始化到最優附近,然後通過反向傳播算法進行fine-tune可以使得結果很好,初始化感覺比較複雜,難搞。

代碼如下:

#coding=utf-8

from keras.datasets import mnist
from keras.layers import Input,Dense,Reshape
from keras.models import Sequential, Model
from keras.optimizers import Adam,SGD

import matplotlib.pyplot as plt
import sys
import os
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

img_height = 28
img_width = 28
def build_model():

    encoder = Sequential()
    encoder.add(Dense(units=1000))
    encoder.add(Dense(units=500))
    encoder.add(Dense(units=250))
    encoder.add(Dense(units=30))

    decoder = Sequential()
    decoder.add(Dense(units=250))
    decoder.add(Dense(units=500))
    decoder.add(Dense(units=1000))
    decoder.add(Dense(units=784))

    img_input = Input(shape=[img_width*img_height])
    code = encoder(img_input)

    reconstruct_img = decoder(code)

    combined = Model(img_input,reconstruct_img)

    optimizer = Adam(0.001)
    combined.compile(loss='mse', optimizer=optimizer)
    return encoder,decoder,combined


epochs = 100000
batch_size = 64
mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)
def train():
    losses = []
    for i in range(epochs):
        imgs,labels = mnist.train.next_batch(batch_size)
        imgs = imgs/2.0 - 1 # -1.0 - 1.0
        loss = combined.train_on_batch(imgs,imgs)
        if i % 5 == 0:
            print("epoch:%d,loss:%f"%(i,loss))
            losses.append(loss)
    plt.plot(np.arange(0,epochs,5),losses)
    
def test():
    mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)
    imgs,labels = mnist.test.next_batch(3)
    imgs = imgs*2.0 - 1 # -1.0 - 1.0
    output_imgs = combined.predict(imgs)
    output_imgs = (output_imgs+1)/2.0 # -1.0 - 1.0
    for i in range(3):
        plt.figure(1)
        plt.subplot(2,3,i+1) #兩行一列的第一個子圖
        plt.imshow(imgs[i].reshape((28,28)), cmap='gray')
        plt.subplot(2,3,i+1+3) #兩行一列的第二個子圖
        plt.imshow(output_imgs[i].reshape((28,28)), cmap='gray')
        
if __name__ == '__main__':
    encoder, decoder, combined = build_model()
    train()
    test()
           

繼續閱讀