本節書摘來自異步社群出版社《寫給程式員的資料挖掘實踐指南》一書中的第5章,第5.3節,作者:【美】ron zacharski(紮哈爾斯基),更多章節内容可以通路雲栖社群“異步社群”公衆号檢視。
到目前為止,通過計算下列精确率百分比,我們對分類器進行評估:
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZuBnL3ImM1UTOmZmZ2gTOzIDMkJWOlZjMlNGM2QGM5YmYzEDZwcjNiRmNz8CXt92Yu4GZjlGbh5SZslmZxl3Lc9CX6MHc0RHaiojIsJye.png)
有時,我們可能希望得到分類器算法的更詳細的性能。能夠詳細揭示性能的一種可視化方法是引入一個稱為混淆矩陣(confusion matrix)的表格。混淆矩陣的行代表測試樣本的真實類别,而列代表分類器所預測出的類别。
它之是以名為混淆矩陣,是因為很容易通過這個矩陣看清楚算法産生混淆的地方。下面以女運動員分類為例來展示這個矩陣。假設我們有一個由100名女子體操運動員、100名wnba籃球運動員及100名女子馬拉松運動員的屬性構成的資料集。我們利用10折交叉驗證法對分類器進行評估。在10折交叉測試中,每個執行個體正好隻被測試過一次。上述測試的結果可能如下面的混淆矩陣所示:
同前面一樣,每一行代表執行個體實際屬于的類别,每一列代表的是分類器預測的類别。是以,上述表格表明,有83個體操運動員被正确分類,但是卻有17個被錯分為馬拉松運動員。92個籃球運動員被正确分類,但是卻有8個被錯分為馬拉松運動員。85名馬拉松運動員被正确分類,但是卻有8個人被錯分為體操運動員,還有16個人被錯分為籃球運動員。
混淆矩陣的對角線給出了正确分類的執行個體數目。
上述表格中,算法的精确率為:
通過觀察上述矩陣很容易了解分類器的錯誤類型。在本例當中,分類器在區分體操運動員和籃球運動員上表現得相當不錯,而有時體操運動員和籃球運動員卻會被誤判為馬拉松運動員,馬拉松運動員有時被誤判為體操運動員或籃球運動員。
一個程式設計的例子
回到上一章當中提到的來自卡内基梅隆大學的汽車mpg資料集,該資料集的格式如下:
下面試圖基于氣缸的數目、排水量(立方英寸)、功率、重量和加速時間預測汽車的mpg。我将所有392個執行個體放到mpgdata.txt檔案中,然後編寫了如下的短python程式,該程式利用分層采樣方法将資料分到10個桶中(資料集及python代碼都可以從網站guidetodatamining.com下載下傳)。
執行上述代碼會産生10個分别為mpgdata01、mpgdata02… mpgdata10的檔案。
能否修改上一章中近鄰算法的代碼,以使test函數能夠在剛剛建構的10個檔案上進行10折交叉驗證(該資料集可以從網站guidetodatamining.com下載下傳)?
你的程式應該輸出類似如下矩陣的混淆矩陣:
.
該解答隻涉及如下方面:
修改initializer方法以便從9個桶中讀取資料;
加入一個新的方法對一個桶中的資料進行測試;
加入一個新的過程來執行10折交叉驗證過程。
下面依次來考察上述修改。
initializer方法的簽名看起來如下:
每個桶的檔案名類似于mpgdata-01、mpgdata-02,等等。這種情況下,bucketprefix将是“mpgdata”,而testbucketnumber是包含測試資料的桶。如果testbucketnumber為3,則分類器将會在桶1、2、4、5、6、7、8、9、10上進行訓練。dataformat是一個如何解釋資料中每列的字元串,比如:
它表示第一列代表執行個體的類别,下面5列代表執行個體的數值型屬性,最後一列會被看成注釋。
新的初始化方法的完整代碼如下:
下面編寫一個新的方法來測試一個桶中的資料。
它以bucketprefix和bucketnumber為輸入,如果前者為“mpgdata”、後者為3的話,測試資料将會從檔案mpgdata-03中讀取,而testbucket将會傳回如下格式的字典:
字典的鍵代表的是執行個體的真實類别。例如,上面第一行表示真實類别為35mpg的執行個體的結果。每個鍵的值是另一部字典,該字典代表分類器對執行個體進行分類的結果。例如行
'15':
<code>`</code>javascript
{'20': 3, '15': 4, '10': 1},
def tenfold(bucketprefix, dataformat):
results = {}
for i in range(1, 11):
c = classifier(bucketprefix, i, dataformat)
t = c.testbucket(bucketprefix, i)
for (key, value) in t.items():
results.setdefault(key, {})
for (ckey, cvalue) in value.items():
results[key].setdefault(ckey, 0)
resultskey += cvalue
# now print results
categories = list(results.keys())
categories.sort()
print( "n classified as: ")
header = " "
subheader = " +"
for category in categories:
header += category + " "
subheader += "----+"
print (header)
print (subheader)
total = 0.0
correct = 0.0
row = category + " |"
for c2 in categories:
if c2 in results[category]:
count = resultscategory
else:
count = 0
row += " %2i |" % count
total += count
if c2 == category:
correct += count
print(row)
print(subheader)
print("n%5.3f percent correct" %((correct * 100) / total))
print("total of %i instances" % total)
tenfold("mpgdata", "class num num num num num comment")