天天看点

用java实现K-均值聚类(k-means)

首先大家了解一下什么是K-均值聚类,如下:

K均值聚类算法是先随机选取K个对象作为初始的聚类中心。然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。一旦全部对象都被分配了,每个聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

我们查阅资料了解到K-均值聚类的python代码如下:

def distEclud(vecA, vecB):
    """Return the Euclidean distance between vecA and vecB.

    The difference is squared element-wise, summed, and square-rooted.
    NOTE(review): sqrt/sum/power are presumably numpy's (star import) -- confirm.
    """
    diff = vecA - vecB
    squared = power(diff, 2)
    return sqrt(sum(squared)) #la.norm(vecA-vecB)

def randCent(dataSet, k):
    """Create k random centroids lying within the bounds of each dimension.

    dataSet -- matrix with one sample per row and one feature per column
    k       -- number of centroids to generate
    Returns a k x n matrix, n being the number of columns of dataSet.
    """
    n = shape(dataSet)[1]
    centroids = mat(zeros((k,n)))  # one row per centroid
    for j in range(n):
        column = dataSet[:,j]
        # lower bound and width of the value range of dimension j
        minJ = min(column)
        rangeJ = float(max(column) - minJ)
        # k random values scaled into [minJ, minJ + rangeJ)
        centroids[:,j] = mat(minJ + rangeJ * random.rand(k,1))
    return centroids
    
def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    """Cluster the rows of dataSet into k groups with the k-means algorithm.

    dataSet    -- matrix with one sample per row
    k          -- number of clusters
    distMeas   -- distance function taking (centroid_row, sample_row)
    createCent -- function (dataSet, k) returning the initial k centroids

    Returns (centroids, clusterAssment). clusterAssment has one row per
    sample: column 0 is the index of the assigned centroid, column 1 is the
    squared distance to it.

    The loop repeats until a full pass reassigns no sample.
    """
    m = shape(dataSet)[0]
    clusterAssment = mat(zeros((m,2)))#create mat to assign data points
                                      #to a centroid, also holds SE of each point
    centroids = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):#for each data point assign it to the closest centroid
            minDist = inf; minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j,:],dataSet[i,:])
                if distJI < minDist:
                    minDist = distJI; minIndex = j
            if clusterAssment[i,0] != minIndex: clusterChanged = True
            clusterAssment[i,:] = minIndex,minDist**2
        # function-call form works on both Python 2 and Python 3
        print(centroids)
        for cent in range(k):#recalculate centroids
            ptsInClust = dataSet[nonzero(clusterAssment[:,0].A==cent)[0]]#get all the point in this cluster
            # an empty cluster keeps its previous centroid; taking the mean of
            # zero rows would overwrite it with NaN
            if shape(ptsInClust)[0] > 0:
                centroids[cent,:] = mean(ptsInClust, axis=0) #assign centroid to mean
    return centroids, clusterAssment
           

下面我们用 Java 代码来实现,首先是欧氏距离(欧几里得距离)的计算

/**
 * Computes the Euclidean distance between one row of {@code vecA} and one row of {@code vecB}.
 *
 * @param vecA     matrix holding the first vector as a row
 * @param vecB     matrix holding the second vector as a row
 * @param vecA_row index of the row to read from {@code vecA}
 * @param vecB_row index of the row to read from {@code vecB}
 * @return square root of the summed squared differences over the columns of {@code vecA}
 */
private static double distEclud(DenseMatrix64F vecA,DenseMatrix64F vecB,int vecA_row,int vecB_row) {
		double sumOfSquares = 0;
		int col = 0;
		// The column count of vecA drives the loop; vecB is read at the same column indices.
		while (col < vecA.numCols) {
			double delta = vecA.get(vecA_row, col) - vecB.get(vecB_row, col);
			sumOfSquares += Math.pow(delta, 2);
			col++;
		}
		return Math.sqrt(sumOfSquares);
	}
           

然后是簇的初始化

/**
 * Creates {@code k} random initial centroids, each coordinate drawn uniformly from the
 * value range that the corresponding column of {@code dataSet} spans.
 *
 * @param dataSet matrix with one sample per row and one feature per column
 * @param k       number of centroids to create
 * @return a {@code k x dataSet.numCols} matrix with one centroid per row
 */
private static DenseMatrix64F randCent(DenseMatrix64F dataSet,int k) {
		DenseMatrix64F centroids = new DenseMatrix64F(k,dataSet.numCols);
		centroids.zero();

		for(int j=0;j<dataSet.numCols;j++) {
			// Double.MIN_VALUE is the smallest POSITIVE double, so seeding the maximum
			// with it breaks on columns whose values are all negative. Start from the
			// infinities instead.
			double minJ = Double.POSITIVE_INFINITY;
			double maxJ = Double.NEGATIVE_INFINITY;
			for(int i=0;i<dataSet.numRows;i++) {
				double value = dataSet.get(i, j);
				if(value < minJ) {
					minJ = value;
				}
				if(value > maxJ) {
					maxJ = value;
				}
			}
			double rangeJ = maxJ - minJ;

			// Every centroid gets an independent random coordinate in [minJ, maxJ).
			for(int i=0;i<k;i++) {
				centroids.set(i, j, minJ + rangeJ * Math.random());
			}
		}

		return centroids;
	}
           

然后便是 k-means 的核心部分:迭代更新簇的函数

/**
 * Clusters the rows of {@code dataSet} into {@code k} groups with the k-means algorithm.
 *
 * <p>Each pass assigns every sample to its nearest centroid, moves each centroid to the
 * mean of its samples, and then sorts the centroids by their first coordinate. Passes
 * repeat until one of them reassigns no sample. Progress is printed to standard output.
 *
 * @param dataSet matrix with one sample per row and one feature per column
 * @param k       number of clusters
 * @return a two-element array: {@code [0]} holds the centroids (one per row), {@code [1]}
 *         holds one row per sample with the assigned centroid index in column 0 and the
 *         squared distance to that centroid in column 1
 */
public static DenseMatrix64F[] kMeans(DenseMatrix64F dataSet,int k) {
		DenseMatrix64F clusterAssment = new DenseMatrix64F(dataSet.numRows, 2);
		clusterAssment.zero();
		DenseMatrix64F centroids = randCent(dataSet,k);
		boolean clusterChanged = true;
		while(clusterChanged) {
			clusterChanged = false;

			int changed = 0;

			// Assignment step: attach every sample to its nearest centroid.
			for(int i=0;i<dataSet.numRows;i++) {
				double minDist = Double.MAX_VALUE;
				int minIndex = -1;
				for(int j=0;j<k;j++) {
					double distJI = distEclud(centroids,dataSet,j,i);
					if(distJI < minDist) {
						minDist = distJI;
						minIndex = j;
					}
				}

				if(clusterAssment.get(i, 0) != minIndex) {
					clusterChanged = true;
					changed++;
				}

				clusterAssment.set(i, 0, minIndex);
				clusterAssment.set(i, 1, minDist*minDist);
			}

			System.out.println("变动点数:"+changed);
			System.out.println(centroids);

			// Update step: move every centroid to the mean of the samples assigned to it.
			// The sums are taken over the cluster's own members; the earlier code summed
			// the first N rows of the whole data set instead, which is not the mean.
			for(int cent=0;cent<k;cent++) {
				double[] sums = new double[dataSet.numCols];
				int members = 0;

				for(int i=0;i<dataSet.numRows;i++) {
					if(clusterAssment.get(i, 0) == cent) {
						members++;
						for(int j=0;j<dataSet.numCols;j++) {
							sums[j] += dataSet.get(i, j);
						}
					}
				}

				// An empty cluster keeps its previous centroid.
				if(members > 0) {
					for(int j=0;j<dataSet.numCols;j++) {
						centroids.set(cent, j, sums[j]/members);
					}
				}
			}

			// Sort the centroids by their first coordinate (ascending), swapping whole rows.
			for(int i=0;i<centroids.numRows-1;i++) {
				for(int j=i+1;j<centroids.numRows;j++) {
					if(centroids.get(i, 0) > centroids.get(j, 0)) {
						for(int n=0;n<centroids.numCols;n++) {
							double tmp = centroids.get(j,n);
							centroids.set(j, n,centroids.get(i, n));
							centroids.set(i, n,tmp);
						}
					}
				}
			}
		}

		return new DenseMatrix64F[] {centroids, clusterAssment};
	}
           

这里有一个比较重要的细节:每次更新簇中心之后,需要对簇中心进行排序,否则在做聚类运算时簇中心会循环换位,导致无法跳出循环。

开始测试

// Read the tab-separated sample file into memory, one non-blank line per data point.
List<String> list = new ArrayList<String>();
        // try-with-resources closes the reader even when readLine() throws.
        try (BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch10\\testSet2.txt"))) {
            String s = null;
            while((s = br.readLine())!=null){
                // Skip blank lines (e.g. a trailing newline); they would otherwise
                // make Double.parseDouble fail further down.
                if(!s.trim().isEmpty()) {
                    list.add(s);
                }
            }
        }catch(Exception e){
            e.printStackTrace();
        }

        // One row per sample; the two columns are the first two tab-separated fields.
        DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(),2);

        for(int i=0;i<list.size();i++) {

            String[] items = list.get(i).split("	");
            dataMatIn.set(i, 0, Double.parseDouble(items[0]));
            dataMatIn.set(i,1, Double.parseDouble(items[1]));
        }

        // Cluster into 4 groups and print the centroids followed by the assignments.
        DenseMatrix64F[] test = kMeans(dataMatIn,4);

        System.out.println(test[0]);
        System.out.println(test[1]);
           

在多次循环后收敛

用java实现K-均值聚类(k-means)

ok搞定