Sunday, April 28, 2013

A Java Implementation of the K-means Clustering Algorithm for 3D points


In data mining, k-means clustering is a method of cluster analysis that aims to partition n points into k clusters where each point belongs to the cluster with the nearest mean. This results in a partitioning of the data space into Voronoi cells.

The following code is my implementation of the K-means algorithm for 3D points. The Java project can be downloaded from https://sites.google.com/site/moderntone/K-Means.zip




Cluster.java

package kmeans;
import java.util.*;

/**
 * A group of points together with the centroid that currently represents them.
 */
public class Cluster {

 private final List<Point> points;
 private Point centroid;

 /**
  * Creates a cluster with no members whose initial centroid is the given point.
  * The point is used as the centroid only; it is not added to the member list.
  *
  * @param firstPoint the seed used as the initial centroid
  */
 public Cluster(Point firstPoint) {
  points = new ArrayList<Point>();
  centroid = firstPoint;
 }

 /**
  * @return the current centroid of this cluster
  */
 public Point getCentroid(){
  return centroid;
 }

 /**
  * Replaces the centroid with the arithmetic mean of the member points.
  * A cluster without members keeps its previous centroid, because averaging
  * zero points would divide by zero and turn the centroid into NaN.
  */
 public void updateCentroid(){
  if (points.isEmpty()) {
   return;
  }
  double sumX = 0d, sumY = 0d, sumZ = 0d;
  for (Point point : points){
   sumX += point.x;
   sumY += point.y;
   sumZ += point.z;
  }
  int count = points.size();
  centroid = new Point(sumX / count, sumY / count, sumZ / count);
 }

 /**
  * @return the live, mutable list of member points (not a copy)
  */
 public List<Point> getPoints() {
  return points;
 }

 /**
  * Lists the member points one per line, separated by commas.
  *
  * @return a header line followed by one line per member point
  */
 @Override
 public String toString(){
  StringBuilder builder = new StringBuilder("This cluster contains the following points:\n");
  for (Point point : points) {
   builder.append(point).append(",\n");
  }
  // Drop the comma after the last point; with no points there is no comma,
  // and deleting at length - 2 would remove the ':' of the header instead.
  if (!points.isEmpty()) {
   builder.deleteCharAt(builder.length() - 2);
  }
  return builder.toString();
 }
}



Clusters.java

package kmeans;

import java.util.*;

/**
 * The list of clusters being refined, together with every point that has to be
 * distributed among them.
 */
public class Clusters extends ArrayList<Cluster> {

 private static final long serialVersionUID = 1L;
 private final List<Point> allPoints;
 // Set to true by assignPointsToClusters() when a point lands in a cluster
 // other than the one recorded in its index; reset only by updateClusters().
 private boolean isChanged;

 /**
  * @param allPoints every point to be clustered; the list is stored as given, not copied
  */
 public Clusters(List<Point> allPoints){
  this.allPoints = allPoints;
 }

 /**
  * Finds the cluster whose centroid has the smallest squared distance to the point.
  * On a tie the cluster with the lowest index wins.
  *
  * @param point the point to locate
  * @return the index of the Cluster nearest to the point, or -1 when no cluster
  *         yields a squared distance below {@code Double.MAX_VALUE}
  *         (for example when this list is empty)
  */
 public Integer getNearestCluster(Point point){
  double minSquareOfDistance = Double.MAX_VALUE;
  int itsIndex = -1;
  for (int i = 0 ; i < size(); i++){
   double squareOfDistance = point.getSquareOfDistance(get(i).getCentroid());
   if (squareOfDistance < minSquareOfDistance){
    minSquareOfDistance = squareOfDistance;
    itsIndex = i;
   }
  }
  return itsIndex;
 }

 /**
  * Runs one refinement pass: recomputes every centroid from the current
  * members, empties the member lists and reassigns all points.
  *
  * @return true if at least one point moved to a different cluster during this pass
  */
 public boolean updateClusters(){
  for (Cluster cluster : this){
   // The centroid must be computed before the members are discarded.
   cluster.updateCentroid();
   cluster.getPoints().clear();
  }
  isChanged = false;
  assignPointsToClusters();
  return isChanged;
 }

 /**
  * Adds every point to the member list of its nearest cluster and records that
  * cluster's index on the point. Member lists are not cleared here, so callers
  * that reassign repeatedly must clear them first (updateClusters() does).
  */
 public void assignPointsToClusters(){
  for (Point point : allPoints){
   int previousIndex = point.getIndex();
   int newIndex = getNearestCluster(point);
   if (previousIndex != newIndex)
    isChanged = true;
   Cluster target = get(newIndex);
   point.setIndex(newIndex);
   target.getPoints().add(point);
  }
 }
}

Point.java

package kmeans;

public class Point {
 
 private int index = -1; //denotes which Cluster it belongs to
 public double x, y, z;
 
 public Point(double x, double y, double z) {
  this.x = x;
  this.y = y;
  this.z = z;
 }
 
 public Double getSquareOfDistance(Point anotherPoint){
  return  (x - anotherPoint.x) * (x - anotherPoint.x)
    + (y - anotherPoint.y) *  (y - anotherPoint.y) 
    + (z - anotherPoint.z) *  (z - anotherPoint.z);
 }

 public int getIndex() {
  return index;
 }

 public void setIndex(int index) {
  this.index = index;
 }
 
 public String toString(){
  return "(" + x + "," + y + "," + z + ")";
 } 
}

KMeans.java

package kmeans;

import java.io.*;
import java.util.*;

/**
 * K-means clustering of 3D points read from a CSV file with one "x,y,z" row per point.
 * Clustering runs lazily on the first call to {@link #getPointsClusters()}.
 */
public class KMeans {

 private static final Random random = new Random();
 public final List<Point> allPoints;
 public final int k;
 private Clusters pointClusters; //the k Clusters; null until clustering has run

 /**
  * Loads the points; the clustering itself is deferred.
  *
  * @param pointsFile : the csv file for input points (read as UTF-8)
  * @param k : number of clusters
  * @throws IllegalArgumentException if k is smaller than 2, if the file cannot
  *         be read, or if a non-blank row does not hold three numbers
  */
 public KMeans(String pointsFile, int k) {
  // Actually reject the value: merely printing a stack trace let clustering
  // continue with an invalid k.
  if (k < 2) {
   throw new IllegalArgumentException("The value of k should be 2 or more.");
  }
  this.k = k;
  this.allPoints = Collections.unmodifiableList(readPoints(pointsFile));
 }

 /** Reads every non-blank row of the csv file into a point. */
 private static List<Point> readPoints(String pointsFile) {
  List<Point> points = new ArrayList<Point>();
  BufferedReader reader = null;
  try {
   reader = new BufferedReader(new InputStreamReader(
     new FileInputStream(pointsFile), "UTF-8"));
   String line;
   while ((line = reader.readLine()) != null) {
    // Blank rows (typically a trailing newline) carry no point.
    if (line.trim().length() == 0) {
     continue;
    }
    points.add(getPointByLine(line));
   }
  } catch (IOException e) {
   // Fail here instead of clustering an empty or truncated data set.
   throw new IllegalArgumentException("Could not read points file: " + pointsFile, e);
  } finally {
   // Close on every path, including when a row fails to parse.
   if (reader != null) {
    try {
     reader.close();
    } catch (IOException ignored) {
     // Nothing useful can be done if closing a reader fails.
    }
   }
  }
  return points;
 }

 /** Parses one "x,y,z" row; fields beyond the third are ignored. */
 private static Point getPointByLine(String line) {
  String[] xyz = line.split(",");
  if (xyz.length < 3) {
   throw new IllegalArgumentException(
     "Expected a row of the form x,y,z but found: " + line);
  }
  return new Point(Double.parseDouble(xyz[0]),
    Double.parseDouble(xyz[1]), Double.parseDouble(xyz[2]));
 }

 /**step 1: get random seeds as initial centroids of the k clusters
  */
 private void getInitialKRandomSeeds(){
  pointClusters = new Clusters(allPoints);
  List<Point> kRandomPoints = getKRandomPoints();
  for (int i = 0; i < k; i++){
   Point seed = kRandomPoints.get(i);
   seed.setIndex(i);
   pointClusters.add(new Cluster(seed));
  }
 }

 /**
  * Picks k distinct points at random.
  *
  * @throws IllegalArgumentException if there are fewer than k points
  */
 private List<Point> getKRandomPoints() {
  int size = allPoints.size();
  if (k > size) {
   throw new IllegalArgumentException(
     "Cannot choose " + k + " distinct seeds from " + size + " points.");
  }
  List<Point> kRandomPoints = new ArrayList<Point>();
  boolean[] alreadyChosen = new boolean[size];
  for (int i = 0; i < k; i++) {
   // r is the 1-based rank of the pick among the points not chosen yet;
   // size counts those remaining points and shrinks with every pick.
   int index = -1, r = random.nextInt(size--) + 1;
   for (int j = 0; j < r; j++) {
    index++;
    while (alreadyChosen[index])
     index++;
   }
   kRandomPoints.add(allPoints.get(index));
   alreadyChosen[index] = true;
  }
  return kRandomPoints;
 }

 /**step 2: assign points to initial Clusters
  */
 private void getInitialClusters(){
  pointClusters.assignPointsToClusters();
 }

 /** step 3: update the k Clusters until no changes in their members occur
  */
 private void updateClustersUntilNoChange(){
  boolean isChanged = pointClusters.updateClusters();
  while (isChanged)
   isChanged = pointClusters.updateClusters();
 }

 /**do K-means clustering with this method
  *
  * The clustering is computed on the first call only; later calls return the
  * same result object.
  *
  * @return the k clusters
  */
 public List<Cluster> getPointsClusters() {
  if (pointClusters == null) {
   getInitialKRandomSeeds();
   getInitialClusters();
   updateClustersUntilNoChange();
  }
  return pointClusters;
 }

 /** Clusters the points of files/randomPoints.csv into 6 clusters and prints them. */
 public static void main(String[] args) {
  String pointsFilePath = "files/randomPoints.csv";
  KMeans kMeans = new KMeans(pointsFilePath, 6);
  List<Cluster> pointsClusters = kMeans.getPointsClusters();
  for (int i = 0 ; i < kMeans.k; i++)
   System.out.println("Cluster " + i + ": " + pointsClusters.get(i));
 }
}

3 comments:

  1. Thank you for your post on the k-means algorithm for 3D points; it's very useful for my project. :)

    ReplyDelete
  2. I get exceptions while compiling it. Could you please help me? :(

    ReplyDelete
  3. When I run this program, the compiler shows the error below:
    Exception in thread "main" java.lang.Error: Unresolved compilation problems:
    cluster cannot be resolved to a type
    The method getPointsClusters() from the type KMeans refers to the missing type cluster

    at kmeans.KMeans.main(KMeans.java:98)

    ReplyDelete