首先大家了解一下什么是K-均值聚类,如下:
K均值聚类算法是先随机选取K个对象作为初始的聚类中心。然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。一旦全部对象都被分配了,每个聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是:没有(或最小数目)对象被重新分配给不同的聚类;没有(或最小数目)聚类中心再发生变化;或误差平方和达到局部最小。
我们查阅资料了解到K-均值聚类的python代码如下:
def distEclud(vecA, vecB):
    """Return the Euclidean distance between vecA and vecB.

    Same quantity as ``la.norm(vecA - vecB)`` for vector inputs.
    """
    squared = power(vecA - vecB, 2)
    return sqrt(sum(squared))
def randCent(dataSet, k):
    """Build k random centroids inside the bounding box of dataSet.

    Column ``col`` of the returned k x n matrix is drawn uniformly from
    ``[min(column col), max(column col))`` of dataSet.
    """
    numFeatures = shape(dataSet)[1]
    centroids = mat(zeros((k, numFeatures)))
    for col in range(numFeatures):
        low = min(dataSet[:, col])
        span = float(max(dataSet[:, col]) - low)
        centroids[:, col] = mat(low + span * random.rand(k, 1))
    return centroids
def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    """Cluster the rows of dataSet into k groups with the k-means iteration.

    Each pass assigns every row to its nearest centroid (according to
    distMeas), prints the current centroids, then moves every centroid to the
    mean of its rows. The loop stops once a pass changes no assignment.
    A cluster that ends up empty keeps its previous centroid.

    Returns (centroids, clusterAssment): centroids is the matrix produced by
    createCent, updated in place; clusterAssment is an m x 2 matrix whose
    column 0 is the assigned centroid index and column 1 the squared distance
    to that centroid.
    """
    m = shape(dataSet)[0]
    clusterAssment = mat(zeros((m, 2)))
    # Start with an impossible label so the first pass always counts as a
    # change; all-zero initial labels would make points nearest to centroid 0
    # look unchanged, and the loop could stop before re-checking assignments
    # against the updated centroids.
    clusterAssment[:, 0] = -1
    centroids = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):  # assign each point to its closest centroid
            minDist = inf
            minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j, :], dataSet[i, :])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            clusterAssment[i, :] = minIndex, minDist ** 2
        print(centroids)  # call form is valid on both Python 2 and Python 3
        for cent in range(k):  # recalculate centroids
            ptsInClust = dataSet[nonzero(clusterAssment[:, 0].A == cent)[0]]
            # mean() over zero rows yields NaN; keep the previous centroid instead.
            if shape(ptsInClust)[0] > 0:
                centroids[cent, :] = mean(ptsInClust, axis=0)
    return centroids, clusterAssment
下面用Java代码来实现,首先是欧氏距离的计算:
/**
 * Euclidean distance between row {@code vecA_row} of {@code vecA} and row {@code vecB_row}
 * of {@code vecB}, taken over the columns of {@code vecA}.
 */
private static double distEclud(DenseMatrix64F vecA,DenseMatrix64F vecB,int vecA_row,int vecB_row) {
    double sumOfSquares = 0;
    int col = 0;
    while (col < vecA.numCols) {
        double delta = vecA.get(vecA_row, col) - vecB.get(vecB_row, col);
        sumOfSquares += Math.pow(delta, 2);
        col++;
    }
    return Math.sqrt(sumOfSquares);
}
然后是簇中心的初始化(随机生成k个初始质心):
/**
 * Creates {@code k} random initial centroids inside the bounding box of {@code dataSet}.
 *
 * <p>Coordinate {@code j} of every centroid is {@code minJ + (maxJ - minJ) * Math.random()},
 * where {@code minJ}/{@code maxJ} are the smallest/largest values in column {@code j}.
 *
 * @param dataSet one point per row
 * @param k number of centroids (rows of the result)
 * @return a {@code k x dataSet.numCols} matrix; all zeros when {@code dataSet} has no rows
 */
private static DenseMatrix64F randCent(DenseMatrix64F dataSet,int k) {
    DenseMatrix64F centroids = new DenseMatrix64F(k, dataSet.numCols);
    centroids.zero();
    if (dataSet.numRows == 0) {
        return centroids; // no column bounds to sample from
    }
    for (int j = 0; j < dataSet.numCols; j++) {
        // Start from the infinities. Double.MIN_VALUE is the smallest *positive* double, not the
        // most negative one, so using it as the initial max breaks columns that are all negative.
        double minJ = Double.POSITIVE_INFINITY;
        double maxJ = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dataSet.numRows; i++) {
            double value = dataSet.get(i, j);
            if (value < minJ) {
                minJ = value;
            }
            if (value > maxJ) {
                maxJ = value;
            }
        }
        double rangeJ = maxJ - minJ;
        for (int i = 0; i < k; i++) {
            centroids.set(i, j, minJ + rangeJ * Math.random());
        }
    }
    return centroids;
}
然后便是k-means的核心:迭代更新簇的函数
/**
 * Clusters the rows of {@code dataSet} into {@code k} groups with the k-means iteration.
 *
 * <p>Each pass assigns every point to its nearest centroid (Euclidean distance). It then prints
 * the number of points whose assignment changed and the current centroids. Finally, it moves
 * every centroid to the mean of its points. The loop stops when a pass changes no assignment.
 * A cluster that ends up empty keeps its previous centroid.
 *
 * @param dataSet one point per row
 * @param k number of clusters, must be at least 1
 * @return a two-element array: {@code [0]} is the {@code k x dataSet.numCols} centroid matrix
 *     with rows sorted ascending by column 0; {@code [1]} is a {@code dataSet.numRows x 2} matrix
 *     whose column 0 is the index of the centroid row each point belongs to and whose column 1 is
 *     the squared distance to that centroid
 * @throws IllegalArgumentException if {@code k < 1}
 */
public static DenseMatrix64F[] kMeans(DenseMatrix64F dataSet,int k) {
    if (k < 1) {
        throw new IllegalArgumentException("k must be at least 1, got " + k);
    }
    DenseMatrix64F clusterAssment = new DenseMatrix64F(dataSet.numRows, 2);
    // Start with an impossible label so the first pass always counts as a change. All-zero initial
    // labels would make points nearest to centroid 0 look unchanged, and the loop could stop
    // before re-checking assignments against the updated centroids.
    for (int i = 0; i < dataSet.numRows; i++) {
        clusterAssment.set(i, 0, -1);
        clusterAssment.set(i, 1, 0);
    }
    DenseMatrix64F centroids = randCent(dataSet, k);
    boolean clusterChanged = true;
    while (clusterChanged) {
        clusterChanged = false;
        int changed = 0;
        for (int i = 0; i < dataSet.numRows; i++) {
            double minDist = Double.MAX_VALUE;
            int minIndex = -1;
            for (int j = 0; j < k; j++) {
                double distJI = distEclud(centroids, dataSet, j, i);
                if (distJI < minDist) {
                    minDist = distJI;
                    minIndex = j;
                }
            }
            if (clusterAssment.get(i, 0) != minIndex) {
                clusterChanged = true;
                changed++;
            }
            clusterAssment.set(i, 0, minIndex);
            clusterAssment.set(i, 1, minDist * minDist);
        }
        System.out.println("变动点数:"+changed);
        System.out.println(centroids);
        updateCentroids(dataSet, clusterAssment, centroids);
    }
    // Sort once, after convergence. Sorting inside the loop relabels clusters behind the
    // assignments' back, so the next pass reports spurious changes.
    sortCentroidsByFirstColumn(centroids, clusterAssment);
    return new DenseMatrix64F[] {centroids, clusterAssment};
}

/** Moves each centroid to the mean of the points assigned to it; empty clusters keep their centroid. */
private static void updateCentroids(DenseMatrix64F dataSet, DenseMatrix64F clusterAssment,
        DenseMatrix64F centroids) {
    int k = centroids.numRows;
    int cols = dataSet.numCols;
    double[][] sums = new double[k][cols];
    int[] counts = new int[k];
    // Accumulate, per cluster, the coordinates of the points that belong to that cluster.
    for (int i = 0; i < dataSet.numRows; i++) {
        int cent = (int) clusterAssment.get(i, 0);
        if (cent < 0) {
            continue; // label -1: no centroid was closer than Double.MAX_VALUE (e.g. NaN data)
        }
        counts[cent]++;
        for (int j = 0; j < cols; j++) {
            sums[cent][j] += dataSet.get(i, j);
        }
    }
    for (int cent = 0; cent < k; cent++) {
        if (counts[cent] > 0) {
            for (int j = 0; j < cols; j++) {
                centroids.set(cent, j, sums[cent][j] / counts[cent]);
            }
        }
    }
}

/** Reorders centroid rows ascending by column 0 and relabels clusterAssment column 0 to match. */
private static void sortCentroidsByFirstColumn(DenseMatrix64F centroids,
        DenseMatrix64F clusterAssment) {
    int k = centroids.numRows;
    int cols = centroids.numCols;
    if (k < 2 || cols == 0) {
        return;
    }
    // order[newRow] = oldRow, built with a stable insertion sort on column 0.
    int[] order = new int[k];
    for (int i = 0; i < k; i++) {
        order[i] = i;
    }
    for (int i = 1; i < k; i++) {
        int row = order[i];
        double key = centroids.get(row, 0);
        int p = i - 1;
        while (p >= 0 && centroids.get(order[p], 0) > key) {
            order[p + 1] = order[p];
            p--;
        }
        order[p + 1] = row;
    }
    double[][] sortedRows = new double[k][cols];
    int[] newLabel = new int[k];
    for (int newRow = 0; newRow < k; newRow++) {
        int oldRow = order[newRow];
        newLabel[oldRow] = newRow;
        for (int c = 0; c < cols; c++) {
            sortedRows[newRow][c] = centroids.get(oldRow, c);
        }
    }
    for (int r = 0; r < k; r++) {
        for (int c = 0; c < cols; c++) {
            centroids.set(r, c, sortedRows[r][c]);
        }
    }
    // Keep every point pointing at the same centroid after the rows moved.
    for (int i = 0; i < clusterAssment.numRows; i++) {
        int label = (int) clusterAssment.get(i, 0);
        if (label >= 0) {
            clusterAssment.set(i, 0, newLabel[label]);
        }
    }
}
这里面有一个比较重要的细节:更新簇之后要对簇进行排序,否则聚类迭代时簇中心会循环换位,导致跳不出循环。
开始测试
// Load the sample data set (one 2-D point per line) and cluster it into 4 groups.
List<String> list = new ArrayList<String>();
// try-with-resources closes the reader even if readLine() throws.
try (BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch10\\testSet2.txt"))) {
    String s;
    while ((s = br.readLine()) != null) {
        // Skip blank lines (e.g. a trailing newline), which parseDouble would reject.
        if (!s.trim().isEmpty()) {
            list.add(s);
        }
    }
} catch (java.io.IOException e) {
    // Without the data there is nothing meaningful to cluster, so fail loudly and keep the cause.
    throw new java.io.UncheckedIOException(e);
}
DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(), 2);
for (int i = 0; i < list.size(); i++) {
    // Split on any run of whitespace so tab-separated and space-separated files both parse.
    String[] items = list.get(i).trim().split("\\s+");
    dataMatIn.set(i, 0, Double.parseDouble(items[0]));
    dataMatIn.set(i, 1, Double.parseDouble(items[1]));
}
DenseMatrix64F[] test = kMeans(dataMatIn, 4);
System.out.println(test[0]);
System.out.println(test[1]);
经过多次迭代后收敛
ok搞定