利用 python 實作 K-Means聚類
一.k-means聚類算法簡介
(一)k-means聚類算法的概念
k-means算法是很典型的基于距離的聚類算法,采用距離作為相似性的評價名額,即認為兩個對象的距離越近,其相似度就越大。該算法認為簇是由距離靠近的對象組成的,是以把得到緊湊且獨立的簇作為最終目标。
k個初始類聚類中心點的選取對聚類結果具有較大的影響,因為在該算法第一步中是随機的選取任意k個對象作為初始聚類的中心,初始地代表一個簇。該算法在每次疊代中對資料集中剩餘的每個對象,根據其與各個簇中心的距離将每個對象重新賦給最近的簇。當考察完所有資料對象後,一次疊代運算完成,新的聚類中心被計算出來。如果在一次疊代前後,J的值沒有發生變化,說明算法已經收斂。
(二)對k-means算法的認識
1.優點
(1)算法快速、簡單。
(2)對大資料集有較高的效率并且是可伸縮性的。
(3)時間複雜度近于線性,而且适合挖掘大規模資料集。K-Means聚類算法的時間複雜度是O(nkt) ,其中n代表資料集中對象的數量,t代表着算法疊代的次數,k代表着簇的數目。
2.缺點
(1)聚類是一種無監督的學習方法,在 K-means 算法中 K 是事先給定的,K均值算法需要使用者指定建立的簇數k,但這個 K 值的標明是非常難以估計的。
(2)在 K-means 算法中,首先需要根據初始聚類中心來确定一個初始劃分,然後對初始劃分進行優化。這個初始聚類中心的選擇對聚類結果有較大的影響,一旦初始值選擇的不好,可能無法得到有效的聚類結果,這也成為 K-means算法的一個主要問題。
(3)從 K-means 算法架構可以看出,該算法需要不斷地進行樣本分類調整,不斷地計算調整後的新的聚類中心,是以當資料量非常大時,算法的時間開銷是非常大的。是以需要對算法的時間複雜度進行分析、改進,提高算法應用範圍,而這導緻K均值算法在大資料集上收斂較慢。
二.算法過程
1.選取k個初始聚類中心點
2. 計算每個點與各簇中心點的歐氏距離
3. 根據每個點與各個簇中心的距離将每個對象重新賦給最近的簇
3. 更新簇的中心點
4. 疊代,直到收斂(通常停止疊代的條件: 簇的中心點不再發生明顯的變化,即收斂)
三.設計思想
1.利用pandas和numpy将CSV中的資料導入到python中,并轉化成多元向量。
2.先通過觀察所有點的分布來确定k的取值,将導入到python中的所有坐标值通過彩虹色來實作對k值的初步判斷。
3. 選取k=3,随機選擇三個中心點,計算所有值到簇中心點的距離并根據每個點與各個簇中心的距離将每個對象重新賦給距離最近的簇。
4.對新生成的簇的所有值求平均值來獲得新的簇中心點。
5.通過遞歸重複3、4兩步,直至簇的中心點不再發生明顯的變化。
6.生成疊代停止後的散點圖,以及各簇中所包含的點。
四.實作代碼
&emsp這一部分的代碼是将資料集中的所有點按照彩虹色畫出來;先通過觀察所有點的分布來确定k的取值。
#導入子產品
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
#設定工作路徑
os.chdir('d:/workpath')
#利用pandas和numpy将CSV中的資料導入到python中
df = pd.read_csv('d:\iris.csv',header=None)
content=pd.DataFrame(df,columns=[0,1,2,3])
a=np.array(content)
#建立繪制三維圖的環境
fig = plt.figure()
ax = Axes3D(fig)
#繪制圖像
ax.scatter(content[0], content[2], content[3],c=content[3],cmap='rainbow')
#調整觀察角度和方位角。這裡将俯仰角設為30度,把方位角調整為-45度
ax.view_init(30, -45)
plt.show()
根據圖1顯示所有的值的分布可以将資料大緻可以分為三類,是以我們取k=3來進行聚類分析。
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsICM38FdsYkRGZkRG9lcvx2bjxiNx8VZ6l2cs0TPB9UNnpWT6FEROBDOsJGcohVYsR2MMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2X0hXZ0xCMx81dvRWYoNHLrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdssmch1mclRXY39CXldWYtlWPzNXZj9mcw1ycz9WL49zZuBnL4cDNyEDMwUTM5IDMxkTMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
圖1 彩虹色散點圖
利用kmean方法進行聚類,将聚類結果輸出到out_file.txt中,并将不同的類型的點按不同顔色畫在坐标系中,
#生成空矩陣用于存儲各簇的資料
b=np.zeros(0)
c=np.zeros(0)
d=np.zeros(0)
#随機選擇三個中心點,計算所有值到中心點的距離
for i in range(0,len(a)):
dis1=np.linalg.norm(a[i:i+1]-a[7:8])
dis2=np.linalg.norm(a[i:i+1]-a[99:100])
dis3=np.linalg.norm(a[i:i+1]-a[112:113])
# 根據每個點與各個簇中心的距離将每個對象重新賦給最近的簇
if min(dis1,dis2,dis3)==dis1:
b=np.append(b,a[i:i+1])
x=int(len(b)/4)
b=np.reshape(b,newshape=(x,4))
if min(dis1,dis2,dis3)==dis2:
c=np.append(c,a[i:i+1])
x=int(len(c)/4)
c=np.reshape(c,newshape=(x,4))
if min(dis1,dis2,dis3)==dis3:
d=np.append(d,a[i:i+1])
x=int(len(d)/4)
d=np.reshape(d,newshape=(x,4))
#對新生成的簇的所有值求平均值來獲得新的簇中心點
b=np.mean(b,axis=0)
c=np.mean(c,axis=0)
d=np.mean(d,axis=0)
#建立k_mean函數
def k_mean(b,c,d):
#儲存新的簇中的所有點
list1 = []
list2 = []
list3 = []
#儲存新的簇中心點
gap1 = b
gap2 = c
gap3 = d
#生成空矩陣用于存儲各簇的資料
b = np.zeros(0)
c = np.zeros(0)
d = np.zeros(0)
#計算所有值到新的簇中心點的距離
for i in range(0,len(a)):
dis1=np.linalg.norm(a[i:i+1]-gap1)
dis2=np.linalg.norm(a[i:i+1]-gap2)
dis3=np.linalg.norm(a[i:i+1]-gap3)
# 根據每個點與各個簇中心的距離将每個對象重新賦給最近的簇
if min(dis1, dis2, dis3) == dis1:
list1.append(i)
b = np.append(b, a[i:i + 1])
x = int(len(b) / 4)
b = np.reshape(b, newshape=(x, 4))
if min(dis1, dis2, dis3) == dis2:
list2.append(i)
c = np.append(c, a[i:i + 1])
x = int(len(c) / 4)
c = np.reshape(c, newshape=(x, 4))
if min(dis1, dis2, dis3) == dis3:
list3.append(i)
d = np.append(d, a[i:i + 1])
x = int(len(d) / 4)
d = np.reshape(d, newshape=(x, 4))
#将新生成的簇的所有值轉化為多元向量
line1= pd.DataFrame(b)
line2=pd.DataFrame(c)
line3=pd.DataFrame(d)
#對新生成的簇的所有值求平均值來獲得新的簇中心點
b=np.mean(b,axis=0)
c=np.mean(c,axis=0)
d=np.mean(d,axis=0)
#當簇的中心點不再發生明顯的變化時停止遞歸
if abs(sum(b - gap1))+abs(sum(c - gap2))+abs(sum(d - gap3))<10**(-64):
# 建立out_file用于存儲輸出結果
out_file = open('out_file.txt', 'w')
out_file.write('Setosa\n')
out_file.writelines(str(list1))
out_file.write('\nVersicolor\n')
out_file.writelines(str(list2))
out_file.write('\nVirginica\n')
out_file.writelines(str(list3))
# 三維散點的資料
x1 = line1[0]
y1 = line1[2]
z1 = line1[3]
x2 = line2[0]
y2 = line2[2]
z2 = line2[3]
x3 = line3[0]
y3 = line3[2]
z3 = line3[3]
#建立繪制三維圖的環境
fig = plt.figure()
ax = Axes3D(fig)
# 繪制散點圖
ax.scatter(x1, y1, z1,cmap='Blues',label='Setosa')
ax.scatter(x2, y2, z2,c='g',label='Versicolor',marker='D')
ax.scatter(x3, y3, z3,c='r',label='Virginica')
ax.legend(loc='best')
# 調整觀察角度和方位角。這裡将俯仰角設為30度,把方位角調整為-45度
ax.view_init(30, -45)
plt.show()
#當簇的中心點發生明顯的變化時繼續遞歸k_mean函數
else:
gap1=b
gap2=c
gap3=d
return k_mean(b,c,d)
#運作k_mean函數
k_mean(b,c,d)
五.分析結果
随機選擇三個中心點,計算所有值到簇中心點的距離并根據每個點與各個簇中心的距離将每個對象重新賦給距離最近的簇。并對新生成的簇的所有值求平均值來獲得新的簇中心點。通過多次疊代獲得最終的聚類結果如圖2所示。
圖2 聚類結果展示
out_file.txt檔案中的輸出結果:
Setosa:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49]
Versicolor:
[51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
73, 74, 75, 76, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94,
95, 96, 97, 98, 99, 101, 106, 113, 114, 119, 121, 123, 126, 127, 133, 138, 142,
146, 149]
Virginica:
[50, 52, 77, 100, 102, 103, 104, 105, 107, 108, 109, 110, 111, 112, 115, 116,
117, 118, 120, 122, 124, 125, 128, 129, 130, 131, 132, 134, 135, 136, 137,
139, 140, 141, 143, 144, 145, 147, 148]
通過輸出的簇的各點可以看出聚類得到的結果和真實值存在一定的偏差,展現了 K-means 算法中的初始聚類中心的選擇對聚類結果的影響。輸出結果中有17個資料和原始資料的分類有差別,我們得到的結果和原始資料的相似率為89%。