天天看点

2.2 KNN算法实现

1 源码下载

下载代码

2 代码截图

2.2 KNN算法实现

3 KNN代码实现

import random
import math
from operator import itemgetter

def loadData(fileName, split, trainingData, testData):
    """Read a comma-separated data file and randomly split its rows.

    Every non-blank line is split on commas into a list of strings. A row is
    appended to ``trainingData`` when ``random.random() < split`` and to
    ``testData`` otherwise, so ``split`` is roughly the fraction of rows that
    ends up in the training set. Both lists are mutated in place and nothing
    is returned.

    Args:
        fileName: Path of the text file to read.
        split: Threshold compared against ``random.random()`` for each row.
        trainingData: List that receives the training rows.
        testData: List that receives the test rows.
    """
    # The context manager closes the file even when reading raises.
    with open(fileName, 'r') as handle:
        content = handle.read()
    for line in content.strip('\n').split('\n'):
        # A blank line would otherwise become a [''] row with no features
        # and no label, which breaks the distance computation later on.
        if not line.strip():
            continue
        row = line.strip(',').split(',')
        if random.random() < split:
            trainingData.append(row)
        else:
            testData.append(row)

def getNeighbors(trainingDataList, testData, K):
    """Return the training rows closest to a single test row.

    Distances are computed with ``euclideanDistance`` over every field of
    ``testData`` except the last one. Rows are ordered by ascending distance;
    rows at equal distance keep their order from ``trainingDataList``.

    Args:
        trainingDataList: List of training rows.
        testData: One row to classify (a single sample, not a whole set).
        K: Maximum number of neighbours to return.

    Returns:
        A list with the ``K`` nearest training rows, or all of them when the
        training list holds fewer than ``K`` rows.
    """
    # The last field is left out of the distance; vote() reads it as the label.
    dimension = len(testData) - 1
    distances = [
        (trainingData, euclideanDistance(trainingData, testData, dimension))
        for trainingData in trainingDataList
    ]
    distances.sort(key=itemgetter(1))
    # Slicing instead of indexing 0..K-1 avoids an IndexError when K is larger
    # than the training set; max() keeps a negative K yielding an empty list.
    return [row for row, _ in distances[:max(K, 0)]]

def vote(neighbors):
    """Pick the most frequent class label among the neighbours (majority vote).

    The label of a neighbour is its last element. For example, with the
    tallies A: 3, B: 2, C: 5 the result is C. When several labels share the
    highest count, the one that entered the tally first wins (dicts keep
    insertion order on Python 3.7+).

    Args:
        neighbors: Sequence of rows whose last element is the class label.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    tally = {}
    for row in neighbors:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    ranked = list(tally.items())
    # list.sort is stable even with reverse=True, so ties keep tally order.
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    winner = ranked[0][0]
    return winner

def euclideanDistance(point1, point2, dimension):
    """Compute the Euclidean distance between two points.

    Only the first ``dimension`` coordinates of each point are used, and each
    coordinate is converted with ``float`` before subtracting. For example a
    point (7.5) has dimension 1, (1, 5) has dimension 2 and (1, 6, 32) has
    dimension 3.

    Args:
        point1: First point, an indexable sequence of numbers or numeric strings.
        point2: Second point, same layout as ``point1``.
        dimension: Number of leading coordinates to compare.

    Returns:
        The distance as a float (0.0 when ``dimension`` is 0 or negative).
    """
    squared_sum = 0
    for axis in range(dimension):
        delta = float(point1[axis]) - float(point2[axis])
        squared_sum += delta ** 2
    return math.sqrt(squared_sum)

def getAccuracy(forecastList):
    """Compute the fraction of forecasts that match the true label.

    Each entry is indexed as ``(row, predictedLabel)``; the true label is the
    last element of ``row``.

    Args:
        forecastList: Sequence of ``(row, predictedLabel)`` entries.

    Returns:
        ``1 - misses / total`` as a float between 0 and 1.

    Raises:
        ZeroDivisionError: If ``forecastList`` is empty.
    """
    misses = sum(1 for entry in forecastList if entry[1] != entry[0][-1])
    total = len(forecastList)
    return 1 - misses / total


def main(fileName="D:/workspace/MachineLearning/07stage/1-fundamental/01/KNN/irisdata.txt",
         split=0.9,
         K=5):
    """Run one KNN experiment and print its accuracy and error rate.

    The data file is loaded and randomly split with ``loadData``. Each test
    row is then classified by a majority vote over its ``K`` nearest training
    rows, and the resulting accuracy is printed as a percentage.

    Args:
        fileName: Path of the comma-separated data file. The default is the
            location this script was originally written against.
        split: Threshold passed to ``loadData`` for the train/test split.
        K: Number of neighbours that take part in the vote.
    """
    trainingDataList = []
    testDataList = []
    loadData(fileName, split, trainingDataList, testDataList)
    # The split is random, so either side can come out empty (especially on
    # small files); voting and the accuracy division would both fail then.
    if not trainingDataList or not testDataList:
        print("训练集或测试集为空,无法计算准确度")
        return
    forecastList = []
    for testData in testDataList:
        neighbors = getNeighbors(trainingDataList, testData, K)
        forecastList.append((testData, vote(neighbors)))
    accuracy = getAccuracy(forecastList)
    print("准确度:{}% ,错误率 : {} %".format(accuracy * 100, (1 - accuracy) * 100))
# Script driver: repeat the whole experiment 100 times. Every call to main()
# reloads the data file and draws a new random train/test split, so the
# printed accuracy can differ from one iteration to the next.
# NOTE(review): this loop is not under an `if __name__ == "__main__":` guard,
# so it also runs when the module is imported — confirm that is intended.
for x in range(100):
    main()

           

继续阅读