
Training a conventional autoencoder with PSO (Matlab code implementation)

💥💥💞💞Welcome to this blog❤️❤️💥💥

🏆About the blogger: 🌞🌞🌞the posts here aim to be rigorous in reasoning and clear in logic, for the reader's convenience.

⛳️Motto: ninety li is only half of a hundred-li journey.

Contents

💥1 Overview

📚2 Results

🎉3 References

🌈4 Matlab implementation

💥1 Overview

This post trains a conventional autoencoder with particle swarm optimization (PSO), one of the best-known random-search-based optimization algorithms; we will not cover it in depth here.
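For reference, the standard PSO update rules that such a random-search trainer relies on are (textbook notation, not extracted from the code below):

```latex
v_i^{t+1} = w\,v_i^{t} + c_1 r_1 \left(p_i - x_i^{t}\right) + c_2 r_2 \left(g - x_i^{t}\right), \qquad
x_i^{t+1} = x_i^{t} + v_i^{t+1}
```

where \(x_i^t\) and \(v_i^t\) are particle \(i\)'s position and velocity at iteration \(t\), \(p_i\) its personal best, \(g\) the swarm's global best, \(w\) the inertia weight, \(c_1, c_2\) acceleration coefficients, and \(r_1, r_2\) uniform random numbers in \([0,1]\). Here each particle encodes one candidate weight matrix of the autoencoder, and the fitness is the reconstruction error.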

📚2 Results

Training a conventional autoencoder with PSO (Matlab implementation)

Partial code:

clear all;
clc;
addpath('NEW_PSO','AE');

%% data preparation
original = imread('tu_pian.png');
original = imresize(original,[150,90]);
x = rgb2gray(original);
Inputs = double(x);

%% network initialization
number_neurons = 89;  % number of neurons in the hidden layer
LB = -10;             % lower bound of the weights
UB = 10;              % upper bound of the weights
n = 10;               % population size

%% training process
[net] = PSO_AE(Inputs,number_neurons,LB,UB,n);

%% illustration
regenerated = net.code*pinv(net.B'); % decode the hidden layer
subplot(121)
imagesc(regenerated);
colormap(gray);
Tc = num2str(net.prefomance);
Tc = ['RMSE = ' Tc];
xlabel('regenerated image')
title(Tc)
subplot(122)
plot(smooth(net.errors,52),'LineWidth',2);
xlabel('iterations')
ylabel('RMSE')
title('loss function behavior')
axis([0 length(net.errors) min(net.errors) max(net.errors)])
grid

function [net] = PSO_AE(Inputs,number_neurons,LB,UB,n)
% PSO_AE: trains an autoencoder using a random-search tool (PSO).
% Inputs:         the training set.
% number_neurons: number of neurons in the hidden layer.
% LB: lower bound constraint for the weights.
% UB: upper bound constraint for the weights.
% n:  population size (in PSO).
% net: structure holding the main characteristics of the trained network.
Inputs = scaledata(Inputs,0,1); % data normalization
alpha = size(Inputs);
% initialize the PSO parameters
m = number_neurons*alpha(2);    % number of decision variables (weights)
LB = ones(1,m)*(LB);
UB = ones(1,m)*(UB);
% solve the optimization problem by random search
[Fvalue,B,nb_iterations,fit_behavior] = PSO(m,n,LB,UB,Inputs,number_neurons);
% reshape the solution into the weight matrix
B = reshape(B,number_neurons,alpha(2));
% compute Inputs_hat: unlike other networks, the AE uses the same weight
% matrix B as the input weights for coding and the output weights for
% decoding, so the old input weights are no longer used.
Hnew = Inputs*B';           % the hidden layer
Inputs_hat = Hnew*pinv(B'); % the reconstructed input
% store the network characteristics
net.errors = fit_behavior;  % loss-function behavior during training
net.prefomance = sqrt(mse(Inputs-Inputs_hat)); % training performance (RMSE)
net.B = B;                  % the reconstruction weights
net.code = Hnew;            % the hidden-layer activations (the code)
end
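The helper `scaledata` is not listed above; it is a common File Exchange-style min-max normalization routine. A minimal sketch, assuming it linearly rescales all entries of a matrix into [minval, maxval]:

```matlab
function dataout = scaledata(datain, minval, maxval)
% Linearly rescale datain into the range [minval, maxval] (assumed behavior).
dataout = datain - min(datain(:));                        % shift so the minimum is 0
dataout = dataout * (maxval - minval) / max(dataout(:));  % scale the range
dataout = dataout + minval;                               % shift to the target minimum
end
```

With `scaledata(Inputs,0,1)`, every pixel of the grayscale image is mapped into [0, 1] before training.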

🎉3 References

🌈4 Matlab implementation
