天天看點

RVM算法的matlab實作

相比于SVM,MLaPP對RVM評價頗高:We attribute the enormous popularity of SVMs not to their superiority, but to ignorance of the alternatives, and also to the lack of high quality software implementing the alternatives.

這裡用Matlab實作了簡單的RVM算法,具體算法詳見PRML 7.2節。

下圖為針對訓練集預測結果,可以看到相關向量(RV:Relevance Vectors)确實很少。

RVM算法的matlab實作

代碼如下:

close all;
clear all;
clc;

%% parameters
N=200;     % 訓練集樣本數
Nts=1000;    % 測試集/預測集樣本數

%% data generation and display
[x,t]=datagen(N);

figure(1);
hold on;
plot(x,t,'k.');

%% RVM
K=RBFkernal(x,x); %生成K矩陣:N*(N+1)
% 随機初始化 alpha(系數w的先驗方差倒數) 和 
% beta(樣本點誤差的方差倒數,p(t|w,x,1/beta)=N(t|y(x),beta)=N(t|w*K,1/beta)
m=size(K,2);
alp=rand(1,m);
beta=rand();

for ii=1:1000,
    % 計算原理詳見PRML 7.2節
    sig=pinv(diag(alp)+beta*(K'*K));    % 系數w的後驗方差矩陣Sigma
    mu=sig*K'*t*beta;                   % 系數w的後驗均值mu/u
    
    gamma=1-alp.*diag(sig)';
    
    alp_old=alp;
    beta_old=beta;
    idx=abs(alp)<1e3; % 部分alp會趨向于無窮大,對應的mu會趨向于會向于0,對于這部分alp不再更新
    alp(idx)=gamma(idx)./(mu(idx)'.^2);
    beta=(N-sum(gamma))/((t-K*mu)'*(t-K*mu));
    
    % 判斷收斂則退出循環
    tmp_err=max(abs(alp(idx)-alp_old(idx))./abs(alp(idx)))+abs(beta-beta_old)/abs(beta);
    if tmp_err<0.1,
        break;
    end;
    
end;

% 計算并呈現訓練集的預測結果,注意,這裡在計算預測結果時僅使用了有效的mu值,即不為零的mu值
tpred=K(:,idx)*mu(idx);
figure(1);
plot(x,tpred,'b*');

% 呈現相關向量(Relevance Vectors)
figure(1);
plot(x(idx(2:end)),t(idx(2:end)),'ro');
title('Training Data');
xlabel('x');
ylabel('y');
legend('RAW Data','Predicted','Relevance Vector');

% 測試集/驗證集資料生成、預測及呈現
[xts,tts]=datagen(Nts);
figure; hold;
plot(xts,tts,'.');

xtsK=RBFkernal(xts,x);
ttspred=xtsK(:,idx)*mu(idx);
plot(xts,ttspred,'o');
title('Testing Data');
xlabel('x');
ylabel('y');
legend('RAW Data','Predicted');

%%%%%%%%%%%%%%%%%%%%%

function [ x, y ] = datagen( N )
%DATAGEN Summary of this function goes here
%   Detailed explanation goes here

x=(rand(N,1)-0.5)*4;
y=cos(x*5)-x+randn(N,1)/4;  % 注意:按照RVM算法,各資料點的誤差方差應該一樣

end

%%%%%%%%%%%%%%%%%%

function phi = RBFkernal( x, xb )
%RBFKERNAL Summary of this function goes here
%   Detailed explanation goes here
%   
%   x和xb之間的K矩陣:k(i,j)=exp(-(x(i)-xb(j))^2)

len=size(x,1);
lenb=size(xb,1);
phi=ones(len,lenb+1);

for ii=1:lenb,
    phi(:,ii+1)=exp(-(x-xb(ii)).^2);
end;

end

           

繼續閱讀