天天看点

机器学习(十一)- Multiclass Classification

Multiclass Classification

其实多分类问题,之前就要讲的,但是正好programming exercies 3有关于用神经网络进行多分类的作业,于是就在这里一起讲了,正好比较一下逻辑回归和神经网络的多分类的区别。

之前一直讲的都是二分类问题,对于现实生活,这显然是不够的,更多的用到的是多分类。

逻辑回归

对于二分类,y只有0,1,那么对于多分类的话,y自然就不止两个了。下面我们以三个分类为例,来展示多分类的分类思想

机器学习(十一)- Multiclass Classification

当你看完上图,你一定有一种醍醐灌顶的感觉,至少我第一次看到的时候有这种感觉,不去纠结怎么去构造复杂的函数一次性分出3类,而是直接构造3个二分类器,毕竟对一某一类来说,另外两类都可以叫做“不是这一类”。最后只要选择一个概率最高的作为最后的结论即可。下图是一般式。

机器学习(十一)- Multiclass Classification
% 写在前面,下面的只是核心函数和简单步骤,函数要使用需要单独贴到.m文件里
% 计算梯度和代价函数
function [J, grad] = lrCostFunction(theta, X, y, lambda)
m = length(y); % number of training examples
J = -1/m*(y'*log(sigmoid(X*theta))+ ...
    (1-y')*log(1-sigmoid(X*theta)))+ ...
    lambda/(2*m)*sum(theta(2:end).^2);
grad(1) = 1/m*(X(:,1)'*(sigmoid(X*theta)-y));
grad(2:end) = 1/m*(X(:,2:end)'*(sigmoid(X*theta)-y))+lambda/m*theta(2:end);
grad = grad(:);
end

% 训练模型
function [all_theta] = oneVsAll(X, y, num_labels, lambda)
% Some useful variables
m = size(X, 1);
n = size(X, 2);
% Add ones to the X data matrix
X = [ones(m, 1) X]; 
for c = 1:num_labels
    initial_theta = zeros(n + 1, 1);
    options = optimset('GradObj', 'on', 'MaxIter', 50); 
    [all_theta(c,:)] = ...
        fmincg (@(t)(lrCostFunction(t, X, (y == c), lambda)), ...
        initial_theta, options);
end
end

% 预测函数
function p = predictOneVsAll(all_theta, X)
m = size(X, 1);
num_labels = size(all_theta, 1);
% Add ones to the X data matrix
X = [ones(m, 1) X];
[~ , p] = max(sigmoid(X*all_theta'),[],2);
end

lambda = 0.1;
[all_theta] = oneVsAll(X, y, num_labels, lambda);

pred = predictOneVsAll(all_theta, X);
           

神经网络

对于神经网络,多分类任务,我们的输出层不只只有一个输出元,而是有几个分类,就有几个元,这样神经网络最终的输出就是一个向量,哪个index下的概率最高,该输出就判定为这个index对应的分类,比如输出向量的第一个代表行人,如果输入一张图片进入网络,最后输出的 h θ ( x ) h_\theta(x) hθ​(x)如下图左下角,那么这个图片就是行人。

机器学习(十一)- Multiclass Classification
% 写在前面,下面的只是核心函数和简单步骤,函数要使用需要单独贴到.m文件里
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
num_labels = size(Theta2, 1);
X = [ones(m, 1) X];

[~,p] = max(sigmoid([ones(m,1) sigmoid(X*Theta1')]*Theta2'),[],2);
end

pred = predict(Theta1, Theta2, X);
           

由于本次作业,神经网络的参数都是给定的,没有进行训练,所以无法比较两者的运算速度。

继续阅读