Sunday, April 28, 2013

A Java Implementation of the K-means Clustering Algorithm for 3D points


In data mining, k-means clustering is a method of cluster analysis that aims to partition n points into k clusters where each point belongs to the cluster with the nearest mean. This results in a partitioning of the data space into Voronoi cells.

The following code is my implementation of the k-means algorithm for 3D points. The Java project can be downloaded from https://sites.google.com/site/moderntone/K-Means.zip




Cluster.java

package kmeans;
import java.util.*;

/**
 * One k-means cluster: a centroid plus the points currently assigned to it.
 */
public class Cluster {

 private final List<Point> points; //points assigned to this cluster in the current iteration
 private Point centroid;
 
 /**
  * Creates a cluster with no members whose initial centroid is the given seed.
  * The seed itself is not added to the member list.
  * @param firstPoint the point used as the initial centroid
  */
 public Cluster(Point firstPoint) {
  points = new ArrayList<Point>();
  centroid = firstPoint;
 }
 
 /** @return the current centroid */
 public Point getCentroid(){
  return centroid;
 }
 
 /**
  * Moves the centroid to the mean of the current members.
  * An empty cluster keeps its previous centroid: dividing by zero would produce a NaN
  * centroid, and a NaN distance never wins a nearest-cluster comparison, so the cluster
  * could never receive points again.
  */
 public void updateCentroid(){
  if (points.isEmpty())
   return;
  double sumX = 0d, sumY = 0d, sumZ = 0d;
  for (Point point : points){
   sumX += point.x; sumY += point.y; sumZ += point.z;
  }
  int n = points.size();
  centroid = new Point(sumX / n, sumY / n, sumZ / n);
 }
 
 /**
  * @return the live member list (not a copy); callers may clear it or add to it
  */
 public List<Point> getPoints() {
  return points;
 }
 
 /**
  * Lists the members, one per line, separated by commas (no comma after the last one).
  * An empty cluster yields just the header line.
  */
 @Override
 public String toString(){
  StringBuilder builder = new StringBuilder("This cluster contains the following points:\n");
  for (int i = 0; i < points.size(); i++) {
   builder.append(points.get(i));
   builder.append(i < points.size() - 1 ? ",\n" : "\n");
  }
  return builder.toString();
 }
}



Clusters.java

package kmeans;

import java.util.*;

/**
 * The list of k clusters together with every point being clustered, plus the
 * assignment step of k-means.
 */
public class Clusters extends ArrayList<Cluster> {

 private static final long serialVersionUID = 1L;
 private final List<Point> allPoints; //every point to distribute among the clusters
 private boolean isChanged; //set when an assignment pass moves any point to a different cluster
 
 /**
  * @param allPoints the points to be assigned to the clusters later added to this list
  */
 public Clusters(List<Point> allPoints){
  this.allPoints = allPoints;
 }
 
 /**
  * Finds the cluster whose centroid is nearest to the point (squared Euclidean distance);
  * ties go to the lowest index.
  * @param point the point to classify
  * @return the index of the Cluster nearest to the point, or -1 if no centroid is closer
  *         than Double.MAX_VALUE (e.g. there are no clusters, or every distance is NaN)
  */
 public Integer getNearestCluster(Point point){
  double minSquareOfDistance = Double.MAX_VALUE;
  int itsIndex = -1;
  for (int i = 0 ; i < size(); i++){
   double squareOfDistance = point.getSquareOfDistance(get(i).getCentroid());
   if (squareOfDistance < minSquareOfDistance){
    minSquareOfDistance = squareOfDistance;
    itsIndex = i;
   }
  }
  return itsIndex;
 }

 /**
  * Runs one k-means iteration: recompute every centroid, then reassign all points.
  * @return true if any point changed cluster during the reassignment
  */
 public boolean updateClusters(){
  for (Cluster cluster : this){
   cluster.updateCentroid();
   cluster.getPoints().clear();
  }
  isChanged = false;
  assignPointsToClusters();
  return isChanged;
 }
 
 /**
  * Adds every point to its nearest cluster and records the cluster index on the point.
  * Sets the change flag when a point's index differs from its previous one.
  * @throws IllegalStateException if some point has no nearest cluster
  */
 public void assignPointsToClusters(){
  for (Point point : allPoints){
   int newIndex = getNearestCluster(point);
   if (newIndex < 0)
    throw new IllegalStateException("No nearest cluster found for point " + point
      + " (are there any clusters?)");
   if (point.getIndex() != newIndex)
    isChanged = true;
   point.setIndex(newIndex);
   get(newIndex).getPoints().add(point);
  }
 }
}

Point.java

package kmeans;

/**
 * A mutable point in 3D space that remembers which cluster it is currently assigned to.
 */
public class Point {

 public double x, y, z;
 private int index = -1; //position of the owning Cluster; -1 means not yet assigned

 /** Builds an unassigned point from its three coordinates. */
 public Point(double x, double y, double z) {
  this.x = x; this.y = y; this.z = z;
 }

 /**
  * Squared Euclidean distance to another point (no square root is taken, which is
  * enough for comparing distances).
  */
 public Double getSquareOfDistance(Point anotherPoint){
  double dx = x - anotherPoint.x;
  double dy = y - anotherPoint.y;
  double dz = z - anotherPoint.z;
  return dx * dx + dy * dy + dz * dz;
 }

 /** @return the index of the assigned cluster, or -1 if unassigned */
 public int getIndex() {
  return index;
 }

 /** Records the index of the cluster this point now belongs to. */
 public void setIndex(int index) {
  this.index = index;
 }

 /** Formats the point as {@code (x,y,z)}. */
 @Override
 public String toString(){
  return new StringBuilder("(").append(x).append(',').append(y)
    .append(',').append(z).append(')').toString();
 }
}

KMeans.java

package kmeans;

import java.io.*;
import java.util.*;

/**
 * K-means clustering of 3D points read from a CSV file.
 * Clustering runs lazily on the first call to {@link #getPointsClusters()}.
 */
public class KMeans {

 private static final Random random = new Random(); //chooses the initial seeds
 public final List<Point> allPoints; //read-only list of input points, in file order
 public final int k; //number of clusters
 private Clusters pointClusters; //the k Clusters; null until clustering has run

 /**
  * Reads the input points; clustering itself is deferred to getPointsClusters().
  * @param pointsFile : UTF-8 csv file with one "x,y,z" point per line (blank lines are skipped,
  *                     columns after the third are ignored)
  * @param k : number of clusters; at least 2 and no more than the number of points
  * @throws IllegalArgumentException if k is out of range, the file cannot be read,
  *         or a line is not a valid point
  */
 public KMeans(String pointsFile, int k) {
  if (k < 2)
   throw new IllegalArgumentException("The value of k should be 2 or more.");
  this.k = k;
  List<Point> points;
  try {
   points = readPoints(pointsFile);
  } catch (IOException e) {
   throw new IllegalArgumentException("Cannot read points file: " + pointsFile, e);
  }
  // Seeding needs k distinct points; fail here rather than deep inside the random selection.
  if (points.size() < k)
   throw new IllegalArgumentException("k = " + k + " exceeds the number of points ("
     + points.size() + ") in " + pointsFile);
  this.allPoints = Collections.unmodifiableList(points);
 }

 /** Reads one point per non-blank line, closing the file even when parsing fails. */
 private static List<Point> readPoints(String pointsFile) throws IOException {
  List<Point> points = new ArrayList<Point>();
  BufferedReader reader = new BufferedReader(
    new InputStreamReader(new FileInputStream(pointsFile), "UTF-8"));
  try {
   String line;
   while ((line = reader.readLine()) != null) {
    if (line.trim().length() > 0) //tolerate blank lines such as a trailing newline
     points.add(getPointByLine(line));
   }
  } finally {
   reader.close();
  }
  return points;
 }

 /**
  * Parses "x,y,z" (extra columns ignored).
  * @throws IllegalArgumentException if there are fewer than three fields or one is not a number
  */
 private static Point getPointByLine(String line) {
  String[] xyz = line.split(",");
  if (xyz.length < 3)
   throw new IllegalArgumentException("Expected \"x,y,z\" but got: " + line);
  return new Point(Double.parseDouble(xyz[0]),
    Double.parseDouble(xyz[1]), Double.parseDouble(xyz[2]));
 }

 /**step 1: get random seeds as initial centroids of the k clusters
  */
 private void getInitialKRandomSeeds(){
  pointClusters = new Clusters(allPoints);
  List<Point> kRandomPoints = getKRandomPoints();
  for (int i = 0; i < k; i++){
   kRandomPoints.get(i).setIndex(i);
   pointClusters.add(new Cluster(kRandomPoints.get(i)));
  } 
 }
 
 /** Picks k distinct points uniformly at random (the constructor guarantees k <= size). */
 private List<Point> getKRandomPoints() {
  List<Point> shuffled = new ArrayList<Point>(allPoints);
  Collections.shuffle(shuffled, random);
  return new ArrayList<Point>(shuffled.subList(0, k));
 }
 
 /**step 2: assign points to initial Clusters
  */
 private void getInitialClusters(){
  pointClusters.assignPointsToClusters();
 }
 
 /** step 3: update the k Clusters until no changes in their members occur
  */
 private void updateClustersUntilNoChange(){
  boolean isChanged;
  do {
   isChanged = pointClusters.updateClusters();
  } while (isChanged);
 }
 
 /**do K-means clustering with this method (computed once, then cached)
  * @return the k clusters; the returned list is the internal, mutable instance
  */
 public List<Cluster> getPointsClusters() {
  if (pointClusters == null) {
   getInitialKRandomSeeds();
   getInitialClusters();
   updateClustersUntilNoChange();
  }
  return pointClusters;
 }
 
 public static void main(String[] args) {
  String pointsFilePath = "files/randomPoints.csv";
  KMeans kMeans = new KMeans(pointsFilePath, 6);
  List<Cluster> pointsClusters = kMeans.getPointsClusters();
  for (int i = 0 ; i < kMeans.k; i++)
   System.out.println("Cluster " + i + ": " + pointsClusters.get(i));
 }
}

Sunday, April 7, 2013

An implementation of the R-Tree algorithm in Java


The following classes are my implementation of an R-tree, which can be used to construct an R-tree for a list of points in a plane. The package, along with a CSV file storing the points to be inserted, can be downloaded from https://sites.google.com/site/moderntone/RTree.zip. Overflowing nodes are split using Antonin Guttman's quadratic split method.

I have done a little testing but am still not entirely sure that the source code below is free of bugs. Any reader who finds a bug is welcome to report it in the comments below, and if I find bugs later, I will update this post.




MBR.java

package RTree;

import java.util.ArrayList;
import java.util.Collections;

/**
 * A minimum bounding rectangle (MBR) in the plane that also serves as an R-tree node.
 * Leaf entries (the indexed points) are MBRs stored in a leaf node's entries list;
 * non-leaf nodes store child node MBRs in their children list. Overflowing nodes are
 * split with Guttman's quadratic split. Call {@link #initialize(int, int)} before inserting.
 */
public class MBR {
 
 protected double left, right, top, bottom; //region bounds; left <= right and bottom <= top
 
 private final ArrayList<MBR> children; //child node MBRs; non-empty only for non-leaf nodes
 
 //leaf entries; only leaf MBRs have entries with nonzero size
 private final ArrayList<MBR> entries;
 
 private MBR parent; //All leaf entries and node entries except the root have a parent MBR
 
 private static int idTrace = 0; //last id handed out
 private Integer id;
 
 private static final QuadraticComparator qc = new QuadraticComparator();
 
 private static Integer m, M; //min/max members per node; null until initialize() is called

 /**
  * Sets the node capacity bounds shared by every MBR.
  * @param m minimum number of members each group receives in a split; must not exceed M / 2
  * @param M maximum number of members per node; must be at least 2
  * @throws IllegalArgumentException if the bounds are inconsistent
  */
 public static void initialize(int m, int M){
  if (M < 2 || m > M / 2)
   throw new IllegalArgumentException("Improper m and M values: m = " + m + ", M = " + M);
  MBR.m = m; MBR.M = M;
 }
 
 /**
  * @throws IllegalArgumentException if left > right or bottom > top
  */
 public MBR(double left, double right, double top, double bottom) {
  if (left > right || bottom > top)
   throw new IllegalArgumentException("Left shouldn't be larger than right, " +
     "and bottom shouldn't be larger than top");
  this.left = left; this.right = right;
  this.top = top;  this.bottom = bottom;
  children = new ArrayList<MBR>(); entries = new ArrayList<MBR>();
  setId();
 }
 
 /** Fails fast instead of an unboxing NullPointerException when m/M were never set. */
 private static void requireInitialized() {
  if (m == null || M == null)
   throw new IllegalStateException("MBR.initialize(m, M) must be called first");
 }
 
 /**
  * Searches for the leaf MBR the leafEntry should be inserted into. At each level it descends
  * into the child needing the least enlargement (ties: smallest area, then first child).
  * @param leafEntry the entry to be inserted
  * @return a pair whose target is the chosen leaf node
  */
 public MBRPair search(MBR leafEntry) {
  if (children.isEmpty()) //has no children but may have leaf entries
   return new MBRPair(this, leafEntry);
  MBRPair best = null;
  for (MBR child : children) {
   MBRPair candidate = new MBRPair(child, leafEntry);
   if (best == null || candidate.compareTo(best) < 0)
    best = candidate;
  }
  return best.getTarget().search(leafEntry);
 }
 
 /**
  * Splits this leaf node if it holds more than M entries, then splits ancestors that
  * overflow as a result (possibly creating a new root).
  * @return null if no split occurs, or the root MBR if it does
  */
 public MBR splitWhenFull(){
  requireInitialized();
  if (entries.size() <= M)
   return null;
  MBR parentMbr = split_quardratic_forLeafEntries();
  while (parentMbr.getChildren().size() > M)
   parentMbr = parentMbr.split_quardratic_forNoneLeafEntries();
  return parentMbr.getRoot();
 }
 
 /**
  * Enlarges this node and, where needed, its ancestors so that they cover newChild.
  * Uses an exact containment test: an area-based test misses growth of degenerate
  * (zero-area) regions, such as a node holding points that share an x or y coordinate.
  */
 private void updateNodes(MBR newChild){
  MBR node = this;
  // An ancestor covers each of its descendants, so once one node covers newChild, all above it do.
  while (node != null && !node.contains(newChild)) {
   node.expandToInclude(newChild);
   node = node.getParent();
  }
 }
 
 /** @return true if other lies entirely within this region (borders included) */
 private boolean contains(MBR other) {
  return left <= other.left && other.right <= right
    && bottom <= other.bottom && other.top <= top;
 }
 
 /** Grows this region to the smallest rectangle covering both it and other. */
 private void expandToInclude(MBR other) {
  left = Math.min(left, other.left);
  right = Math.max(right, other.right);
  bottom = Math.min(bottom, other.bottom);
  top = Math.max(top, other.top);
 }
 
 /** Adds a child node MBR and enlarges this node and its ancestors as needed. */
 public void addNonLeafChild(MBR nonLeafChild){
  nonLeafChild.setParent(this);
  children.add(nonLeafChild);
  updateNodes(nonLeafChild);
 }
 
 /**
  * Same as addNonLeafChild(nonLeafChild). The pair argument is accepted for backward
  * compatibility; regions are maintained by exact containment rather than its enlargement.
  */
 public void addNonLeafChild(MBR nonLeafChild, MBRPair pair){
  addNonLeafChild(nonLeafChild);
 }
 
 /** Adds a leaf entry and enlarges this node and its ancestors as needed. */
 public void addLeafChild(MBR leafChild){
  leafChild.setParent(this);
  entries.add(leafChild);
  updateNodes(leafChild);
 }
 
 /**
  * Same as addLeafChild(leafChild). The pair argument is accepted for backward
  * compatibility; regions are maintained by exact containment rather than its enlargement.
  */
 public void addLeafChild(MBR leafChild, MBRPair pair){
  addLeafChild(leafChild);
 }
 
 /**
  * Quadratic split of this leaf node's entries into this node and a new sibling.
  * @return the parent of the two groups (a new root if this node was the root)
  */
 public MBR split_quardratic_forLeafEntries(){
  return splitQuadratic(true);
 }
 
 /**
  * Quadratic split of this non-leaf node's children into this node and a new sibling.
  * @return the parent of the two groups (a new root if this node was the root)
  */
 public MBR split_quardratic_forNoneLeafEntries(){
  return splitQuadratic(false);
 }
 
 /** The members being split: leaf entries for a leaf split, child nodes otherwise. */
 private ArrayList<MBR> members(boolean leaf) {
  return leaf ? entries : children;
 }
 
 /** Adds a member as a leaf entry or a child node, depending on the kind of split. */
 private void addMember(MBR member, boolean leaf) {
  if (leaf)
   addLeafChild(member);
  else
   addNonLeafChild(member);
 }
 
 /**
  * Guttman's quadratic split. This node keeps one seed (group1), and a new sibling under the same
  * parent takes the other (group2). The remaining members are then distributed between them.
  */
 private MBR splitQuadratic(boolean leaf){
  requireInitialized();
  MBR group1 = this;
  ArrayList<MBR> toSplit = members(leaf);
  
  // PickSeeds: the pair wasting the most area when covered together (first such pair wins ties).
  MBRPair seeds = null;
  for (int j = 1; j < toSplit.size(); j++){
   for (int i = 0; i < j; i++){
    MBRPair candidate = new MBRPair(toSplit.get(i), toSplit.get(j));
    if (seeds == null || qc.compare(candidate, seeds) > 0)
     seeds = candidate;
   }
  }
  MBR group1Seed = seeds.getTarget();
  MBR group2Seed = seeds.getToBeInserted();
  
  ArrayList<MBR> remaining = new ArrayList<MBR>(toSplit);
  remaining.remove(group1Seed);
  remaining.remove(group2Seed);
  
  // Restart group1 (this node) from its seed alone; it keeps its place under the parent.
  toSplit.clear();
  group1.adjustRegion(group1Seed);
  group1.addMember(group1Seed, leaf);
  
  if (parent == null){ //happens when splitting the root; the parent becomes the new root
   parent = new MBR(left, right, top, bottom);
   parent.addNonLeafChild(group1);
  }
  
  MBR group2 = new MBR(group2Seed.left, group2Seed.right, group2Seed.top, group2Seed.bottom);
  parent.addNonLeafChild(group2);
  group2.addMember(group2Seed, leaf);
  
  // Once a group holds M - m + 1 members, the rest must go to the other group so that it reaches m.
  int cap = M - m + 1;
  for (MBR child : remaining){
   if (group1.members(leaf).size() == cap)
    group2.addMember(child, leaf);
   else if (group2.members(leaf).size() == cap)
    group1.addMember(child, leaf);
   else if (prefersFirstGroup(group1, group2, child, leaf))
    group1.addMember(child, leaf);
   else
    group2.addMember(child, leaf);
  }
  return parent;
 }
 
 /**
  * Chooses between the two groups by the least enlargement needed to cover child. Ties go to
  * the smaller area, then to the group with fewer members of the kind being split, then to group1.
  */
 private static boolean prefersFirstGroup(MBR group1, MBR group2, MBR child, boolean leaf){
  double enlargement1 = new MBRPair(group1, child).getEnlargement();
  double enlargement2 = new MBRPair(group2, child).getEnlargement();
  if (enlargement1 < enlargement2) return true;
  if (enlargement2 < enlargement1) return false;
  double area1 = group1.getArea(), area2 = group2.getArea();
  if (area1 < area2) return true;
  if (area2 < area1) return false;
  return group1.members(leaf).size() <= group2.members(leaf).size();
 }

 /**
  * Replaces this region's bounds with those of newRegionMBR; children, entries and parent are unchanged.
  * @param newRegionMBR the MBR whose bounds are copied
  */
 private void adjustRegion(MBR newRegionMBR){
  this.left = newRegionMBR.left; this.right = newRegionMBR.right;
  this.top = newRegionMBR.top; this.bottom = newRegionMBR.bottom;
 }
 
 /** Area of the region, computed from the current bounds on every call (bounds are mutable). */
 public Double getArea(){
  return (right - left) * (top - bottom);
 }

 public MBR getParent() {
  return parent;
 }

 public void setParent(MBR parent) {
  this.parent = parent;
 }

 /** @return the live list of child node MBRs */
 public ArrayList<MBR> getChildren() {
  return children;
 }

 public Integer getId() {
  return id;
 }
 
 /** Assigns the next sequential id if this MBR has none yet. */
 public void setId() {
  if (id == null)
   id = ++idTrace; 
 }

 /** @return the live list of leaf entries */
 public ArrayList<MBR> getLeafEntries() {
  return entries;
 }
 
 /** @return the topmost ancestor (this MBR itself if it has no parent) */
 public MBR getRoot(){
  MBR root = this;
  while (root.getParent() != null)
   root = root.getParent();
  return root;
 }

 /** Prints this MBR's bounds, its children's bounds and its parent's bounds to stdout. */
 public void printDetails(){
  System.out.println("MBR.printDetails(): for MBR with id = " + id);
  System.out.println("left = " + left + ", right = " + right 
    + ", top = " + top + ", bottom = " + bottom);
  System.out.println("children.size() = " + children.size() + ", leafEntries.size() = " + entries.size());
  for (int i = 0; i < children.size(); i++){
   MBR child = children.get(i);
   System.out.println("child " + i + ": left = " + child.left + ", right = " + child.right 
     + ", top = " + child.top + ", bottom = " + child.bottom);
  }
  if (parent != null){
   System.out.println("parent.left = " + parent.left + ", parent.right = " + parent.right 
     + ", parent.top = " + parent.top + ", parent.bottom = " + parent.bottom);
  }
 } 
}

MBRPair.java
package RTree;

/**
 * A pair of MBRs: a target and an MBR that might be inserted into it.
 * Used to find the most suitable leaf MBR for a new entry, the most suitable
 * non-leaf MBR for a new child MBR, and the seeds of a quadratic split.
 * Note: the ordering from compareTo is not consistent with equals.
 */
public class MBRPair implements Comparable<MBRPair> {

 private Double enlargement; //computed once on first use
 private final MBR target, toBeInserted;
 private MBR mergedMBR; //smallest MBR covering both target and toBeInserted
 
 public MBRPair(MBR target, MBR toBeInserted){
  this.target = target; 
  this.toBeInserted = toBeInserted;
 }
 
 /**
  * Orders pairs for choosing a subtree:
  * - prefer the target whose region needs the least enlargement to cover the MBR to be inserted
  *   (zero enlargement when the target already contains it);
  * - if several targets tie, prefer the one with the smallest area.
  */
 @Override
 public int compareTo(MBRPair anotherPair) {
  int firstComparison = getEnlargement().compareTo(anotherPair.getEnlargement());
  if (firstComparison != 0)
   return firstComparison;
  return target.getArea().compareTo(anotherPair.getTarget().getArea());
 }
 
 /**
  * Area growth of target when merged with toBeInserted. The value is computed once and cached,
  * so create a fresh pair if either MBR changes afterwards.
  */
 public Double getEnlargement() {
  if (enlargement == null){
   mergedMBR = new MBR(Math.min(target.left, toBeInserted.left),
     Math.max(target.right, toBeInserted.right),
     Math.max(target.top, toBeInserted.top),
     Math.min(target.bottom, toBeInserted.bottom));
   enlargement = mergedMBR.getArea() - target.getArea();
  }
  return enlargement;
 }
 
 public MBR getTarget() {
  return target;
 }
 
 public MBR getToBeInserted() {
  return toBeInserted;
 }

 /** @return the smallest MBR covering both target and toBeInserted (computed lazily, then cached) */
 public MBR getMergedMBR() {
  if (mergedMBR == null)
   getEnlargement();
  return mergedMBR;
 }
}
QuadraticComparator.java
package RTree;

import java.util.Comparator;

/**
 * Orders MBR pairs by the area wasted if both MBRs were covered by one rectangle:
 * area(merged) - area(target) - area(toBeInserted). The quadratic split takes the
 * maximum under this ordering as its two seeds.
 */
public class QuadraticComparator implements Comparator<MBRPair> {
 
 @Override
 public int compare(MBRPair pair1, MBRPair pair2) {
  return Double.compare(computeAdditionalArea(pair1), computeAdditionalArea(pair2));
 }
 
 /** Dead area of covering both MBRs of the pair with a single rectangle. */
 private double computeAdditionalArea(MBRPair pair){
  return pair.getMergedMBR().getArea() - pair.getTarget().getArea() - pair.getToBeInserted().getArea();
 }
}
RTree.java
package RTree;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;


/**
 * Builds an R-tree over 2D points read from a CSV file.
 */
public class RTree {

 public RTree() {
  
 }
 MBR root = null; //root of the most recently constructed tree, or null if it was empty
 
 public static void main(String[] args) {
  MBR.initialize(3, 7);
  RTree rTree = new RTree();
  String pointsFilePath = "files/randomPoints.csv";
  MBR root = rTree.constructRTree(pointsFilePath);
  root.printDetails();
 }
 
 /**
  * Inserts every point of the file one at a time, splitting nodes that overflow.
  * MBR.initialize(m, M) must have been called beforehand.
  * @param pointsFilePath UTF-8 csv file with one "x,y" point per line (blank lines skipped)
  * @return the root of the tree, or null if the file contains no points
  * @throws IllegalArgumentException if the file cannot be read or a line is not a valid point
  */
 public MBR constructRTree(String pointsFilePath){
  ArrayList<MBR> leafMBRsToBeInserted = readLeafMBRs(pointsFilePath);
  if (leafMBRsToBeInserted.isEmpty())
   return root = null;
  
  MBR firstLeafMBR = leafMBRsToBeInserted.get(0);
  root = new MBR(firstLeafMBR.left, firstLeafMBR.right, firstLeafMBR.top, firstLeafMBR.bottom);
  root.addLeafChild(firstLeafMBR);
  
  for (MBR leaf : leafMBRsToBeInserted.subList(1, leafMBRsToBeInserted.size())){
   MBRPair pair = root.search(leaf);
   MBR targetMbr = pair.getTarget();
   targetMbr.addLeafChild(leaf, pair);
   
   MBR newRoot = targetMbr.splitWhenFull();
   if (newRoot != null)
    root = newRoot;
  }
  return root;
 }
 
 /**
  * Reads leaf MBRs to be inserted to the RTree from file.
  * The leaf MBRs of the RTree are zero-area points. The file is always closed, and
  * errors are reported rather than silently truncating the point list.
  */
 private ArrayList<MBR> readLeafMBRs(String pointsFilePath){
  ArrayList<MBR> points = new ArrayList<MBR>();
  try {
   BufferedReader reader = new BufferedReader(
     new InputStreamReader(new FileInputStream(pointsFilePath), "utf-8"));
   try {
    String line;
    int lineNumber = 0;
    while ((line = reader.readLine()) != null) {
     lineNumber++;
     if (line.trim().length() == 0) //tolerate blank lines such as a trailing newline
      continue;
     points.add(parseLeafMBR(line, lineNumber));
    }
   } finally {
    reader.close();
   }
  } catch (java.io.IOException e) {
   throw new IllegalArgumentException("Cannot read points file: " + pointsFilePath, e);
  }
  return points;
 }
 
 /** Parses an "x,y" line into a zero-area MBR located at that point. */
 private static MBR parseLeafMBR(String line, int lineNumber) {
  String[] xy = line.split(",");
  if (xy.length != 2)
   throw new IllegalArgumentException("Line " + lineNumber + ": expected \"x,y\" but got: " + line);
  try {
   double x = Double.parseDouble(xy[0]);
   double y = Double.parseDouble(xy[1]);
   return new MBR(x, x, y, y);
  } catch (NumberFormatException e) {
   throw new IllegalArgumentException("Line " + lineNumber + ": not a valid \"x,y\" point: " + line, e);
  }
 }

}