本节书摘来自异步社区出版社《写给程序员的数据挖掘实践指南》一书中的第5章,第5.3节,作者:【美】ron zacharski(扎哈尔斯基),更多章节内容可以访问云栖社区“异步社区”公众号查看。
到目前为止,通过计算下列精确率百分比,我们对分类器进行评估:
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZuBnL3ImM1UTOmZmZ2gTOzIDMkJWOlZjMlNGM2QGM5YmYzEDZwcjNiRmNz8CXt92Yu4GZjlGbh5SZslmZxl3Lc9CX6MHc0RHaiojIsJye.png)
有时,我们可能希望得到分类器算法的更详细的性能。能够详细揭示性能的一种可视化方法是引入一个称为混淆矩阵(confusion matrix)的表格。混淆矩阵的行代表测试样本的真实类别,而列代表分类器所预测出的类别。
它之所以名为混淆矩阵,是因为很容易通过这个矩阵看清楚算法产生混淆的地方。下面以女运动员分类为例来展示这个矩阵。假设我们有一个由100名女子体操运动员、100名wnba篮球运动员及100名女子马拉松运动员的属性构成的数据集。我们利用10折交叉验证法对分类器进行评估。在10折交叉测试中,每个实例正好只被测试过一次。上述测试的结果可能如下面的混淆矩阵所示:
同前面一样,每一行代表实例实际属于的类别,每一列代表的是分类器预测的类别。因此,上述表格表明,有83个体操运动员被正确分类,但是却有17个被错分为马拉松运动员。92个篮球运动员被正确分类,但是却有8个被错分为马拉松运动员。85名马拉松运动员被正确分类,但是却有8个人被错分为体操运动员,还有16个人被错分为篮球运动员。
混淆矩阵的对角线给出了正确分类的实例数目。
上述表格中,算法的精确率为:
通过观察上述矩阵很容易了解分类器的错误类型。在本例当中,分类器在区分体操运动员和篮球运动员上表现得相当不错,而有时体操运动员和篮球运动员却会被误判为马拉松运动员,马拉松运动员有时被误判为体操运动员或篮球运动员。
一个编程的例子
回到上一章当中提到的来自卡内基梅隆大学的汽车mpg数据集,该数据集的格式如下:
下面试图基于气缸的数目、排水量(立方英寸)、功率、重量和加速时间预测汽车的mpg。我将所有392个实例放到mpgdata.txt文件中,然后编写了如下的短python程序,该程序利用分层采样方法将数据分到10个桶中(数据集及python代码都可以从网站guidetodatamining.com下载)。
执行上述代码会产生10个分别为mpgdata01、mpgdata02… mpgdata10的文件。
能否修改上一章中近邻算法的代码,以使test函数能够在刚刚构建的10个文件上进行10折交叉验证(该数据集可以从网站guidetodatamining.com下载)?
你的程序应该输出类似如下矩阵的混淆矩阵:
.
该解答只涉及如下方面:
修改initializer方法以便从9个桶中读取数据;
加入一个新的方法对一个桶中的数据进行测试;
加入一个新的过程来执行10折交叉验证过程。
下面依次来考察上述修改。
initializer方法的签名看起来如下:
每个桶的文件名类似于mpgdata-01、mpgdata-02,等等。这种情况下,bucketprefix将是“mpgdata”,而testbucketnumber是包含测试数据的桶。如果testbucketnumber为3,则分类器将会在桶1、2、4、5、6、7、8、9、10上进行训练。dataformat是一个如何解释数据中每列的字符串,比如:
它表示第一列代表实例的类别,下面5列代表实例的数值型属性,最后一列会被看成注释。
新的初始化方法的完整代码如下:
下面编写一个新的方法来测试一个桶中的数据。
它以bucketprefix和bucketnumber为输入,如果前者为“mpgdata”、后者为3的话,测试数据将会从文件mpgdata-03中读取,而testbucket将会返回如下格式的字典:
字典的键代表的是实例的真实类别。例如,上面第一行表示真实类别为35mpg的实例的结果。每个键的值是另一部字典,该字典代表分类器对实例进行分类的结果。例如行
'15':
<code>`</code>javascript
{'20': 3, '15': 4, '10': 1},
def tenfold(bucketprefix, dataformat):
results = {}
for i in range(1, 11):
c = classifier(bucketprefix, i, dataformat)
t = c.testbucket(bucketprefix, i)
for (key, value) in t.items():
results.setdefault(key, {})
for (ckey, cvalue) in value.items():
results[key].setdefault(ckey, 0)
resultskey += cvalue
# now print results
categories = list(results.keys())
categories.sort()
print( "n classified as: ")
header = " "
subheader = " +"
for category in categories:
header += category + " "
subheader += "----+"
print (header)
print (subheader)
total = 0.0
correct = 0.0
row = category + " |"
for c2 in categories:
if c2 in results[category]:
count = resultscategory
else:
count = 0
row += " %2i |" % count
total += count
if c2 == category:
correct += count
print(row)
print(subheader)
print("n%5.3f percent correct" %((correct * 100) / total))
print("total of %i instances" % total)
tenfold("mpgdata", "class num num num num num comment")