天天看點

Matlab圖像識别/檢索系列(6)-10行代碼完成深度學習網絡之基于CNN的圖像分類

源站:http://blog.51cto.com/8764888/2053960?utm_source=oschina-app

在Matlab2017中,完成一個使用CNN網絡進行分類的示例非常簡單。為了便于建立圖像集,Matlab2015引入了 ImageDatastore對象,實作函數為imageDatastore,該函數可以輕易的完成周遊一個檔案夾中的圖像建立圖像及的功能,不管該檔案夾是否含有子檔案夾。這也是它差別于imageSet的地方之一。代碼如下。

%exam1.m
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');
%建立圖像集,參數設定為包含子檔案夾、子檔案夾名作為類标簽
digitData = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,'LabelSource','foldernames');
figure;
%取20個置亂數字
perm = randperm(10000,20);
%顯示20幅圖像
for i = 1:20
    subplot(4,5,i);
    imshow(digitData.Files{perm(i)});
end
trainingNumFiles = 750;
%若報錯,可改為rng('default')
rng(1) 
%在圖象集每一類中随機取trainingNumFiles個圖像作為訓練圖像,其餘作為測試圖像
[trainDigitData,testDigitData] = splitEachLabel(digitData,...
                trainingNumFiles,'randomize');
%建立簡單CNN網絡
layers = [imageInputLayer([28 28 1]);
          convolution2dLayer(5,20);
          reluLayer();
          maxPooling2dLayer(2,'Stride',2);
          fullyConnectedLayer(10);
          softmaxLayer();
          classificationLayer()];
%設定訓練參數
options = trainingOptions('sgdm','MaxEpochs',20,...
    'InitialLearnRate',0.0001);
%訓練CNN網絡
convnet = trainNetwork(trainDigitData,layers,options);
%對測試圖像進行分類
YTest = classify(convnet,testDigitData);
%顯示測試圖像标簽
TTest = testDigitData.Labels;