In data mining, k-means clustering is a method of cluster analysis that aims to partition n points into k clusters where each point belongs to the cluster with the nearest mean. This results in a partitioning of the data space into Voronoi cells.
The following code is my implementation of the k-means algorithm for 3D points. The Java project can be downloaded from https://sites.google.com/site/moderntone/K-Means.zip
Cluster.java
package kmeans;

import java.util.*;

/**
 * One k-means cluster: the points currently assigned to it plus its centroid.
 */
public class Cluster {

    /** Points currently assigned to this cluster (mutable; exposed via {@link #getPoints()}). */
    private final List<Point> points;

    /** Current centroid of the cluster. */
    private Point centroid;

    /**
     * Creates a cluster whose initial centroid is the given seed point.
     * The seed is NOT added to the member list; members are added by the caller.
     *
     * @param firstPoint the initial centroid
     */
    public Cluster(Point firstPoint) {
        points = new ArrayList<Point>();
        centroid = firstPoint;
    }

    /** @return the current centroid */
    public Point getCentroid() {
        return centroid;
    }

    /**
     * Replaces the centroid with the arithmetic mean of the member points.
     * If the cluster has no members the previous centroid is kept, because
     * dividing by zero would yield a NaN centroid that can never again be the
     * nearest one to any point.
     */
    public void updateCentroid() {
        if (points.isEmpty()) {
            return;
        }
        double newx = 0d, newy = 0d, newz = 0d;
        for (Point point : points) {
            newx += point.x;
            newy += point.y;
            newz += point.z;
        }
        int count = points.size();
        centroid = new Point(newx / count, newy / count, newz / count);
    }

    /** @return the live, mutable list of member points */
    public List<Point> getPoints() {
        return points;
    }

    /**
     * Lists the member points one per line, separated by commas, after a
     * header line. An empty cluster yields just the header.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("This cluster contains the following points:\n");
        Iterator<Point> it = points.iterator();
        while (it.hasNext()) {
            builder.append(it.next());
            // comma after every point except the last; every point ends its line
            builder.append(it.hasNext() ? ",\n" : "\n");
        }
        return builder.toString();
    }
}
Clusters.java
package kmeans;

import java.util.*;

/**
 * The list of clusters of one k-means run, together with the full set of
 * points being clustered.
 */
public class Clusters extends ArrayList<Cluster> {

    private static final long serialVersionUID = 1L;

    /** Every point that takes part in the clustering. */
    private final List<Point> allPoints;

    /** Set to true when a point is assigned to a cluster other than the one recorded in its index. */
    private boolean isChanged;

    /**
     * @param allPoints the points to be distributed over the clusters in this list
     */
    public Clusters(List<Point> allPoints) {
        this.allPoints = allPoints;
    }

    /**
     * Finds the cluster whose centroid has the smallest squared distance to the point.
     * Ties are resolved in favour of the lowest index.
     *
     * @param point the point to classify
     * @return the index of the nearest cluster, or -1 if this list is empty or
     *         no distance compared below {@code Double.MAX_VALUE} (e.g. NaN)
     */
    public Integer getNearestCluster(Point point) {
        double minSquareOfDistance = Double.MAX_VALUE;
        int itsIndex = -1;
        for (int i = 0; i < size(); i++) {
            double squareOfDistance = point.getSquareOfDistance(get(i).getCentroid());
            if (squareOfDistance < minSquareOfDistance) {
                minSquareOfDistance = squareOfDistance;
                itsIndex = i;
            }
        }
        return itsIndex;
    }

    /**
     * Performs one iteration: recomputes every centroid from its current
     * members, empties the member lists, then reassigns all points.
     *
     * @return true if at least one point moved to a different cluster
     */
    public boolean updateClusters() {
        for (Cluster cluster : this) {
            cluster.updateCentroid();
            cluster.getPoints().clear();
        }
        isChanged = false;
        assignPointsToClusters();
        return isChanged;
    }

    /**
     * Adds every point to its nearest cluster and stores that cluster's index
     * on the point. Member lists are not cleared here.
     *
     * @throws IllegalStateException if no nearest cluster can be determined for a point
     */
    public void assignPointsToClusters() {
        for (Point point : allPoints) {
            int previousIndex = point.getIndex();
            int newIndex = getNearestCluster(point);
            if (newIndex < 0) {
                // fail with a clear message instead of an index error from get(-1)
                throw new IllegalStateException("No nearest cluster found for point " + point
                        + " among " + size() + " cluster(s).");
            }
            if (previousIndex != newIndex) {
                isChanged = true;
            }
            point.setIndex(newIndex);
            get(newIndex).getPoints().add(point);
        }
    }
}
Point.java
package kmeans; public class Point { private int index = -1; //denotes which Cluster it belongs to public double x, y, z; public Point(double x, double y, double z) { this.x = x; this.y = y; this.z = z; } public Double getSquareOfDistance(Point anotherPoint){ return (x - anotherPoint.x) * (x - anotherPoint.x) + (y - anotherPoint.y) * (y - anotherPoint.y) + (z - anotherPoint.z) * (z - anotherPoint.z); } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public String toString(){ return "(" + x + "," + y + "," + z + ")"; } }
KMeans.java
package kmeans;

import java.io.*;
import java.util.*;

/**
 * K-means clustering of 3D points read from a CSV file.
 * Clustering runs lazily, on the first call to {@link #getPointsClusters()}.
 */
public class KMeans {

    private static final Random random = new Random();

    /** All input points, as an unmodifiable list. */
    public final List<Point> allPoints;

    /** Number of clusters. */
    public final int k;

    /** The k clusters; null until clustering has been run. */
    private Clusters pointClusters;

    /**
     * Reads the points from a UTF-8 CSV file with one "x,y,z" record per line.
     * Blank lines are skipped.
     *
     * @param pointsFile the csv file for input points
     * @param k number of clusters
     * @throws IllegalArgumentException if k is less than 2, the file cannot be
     *         read, or a line does not contain three numeric fields
     */
    public KMeans(String pointsFile, int k) {
        if (k < 2) {
            throw new IllegalArgumentException("The value of k should be 2 or more.");
        }
        this.k = k;

        List<Point> points = new ArrayList<Point>();
        try {
            InputStream in = new FileInputStream(pointsFile);
            try {
                BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.trim().length() == 0) {
                        continue; // tolerate blank lines, e.g. a trailing newline
                    }
                    points.add(getPointByLine(line));
                }
            } finally {
                in.close(); // released even when reading or parsing fails
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("Cannot read points file: " + pointsFile, e);
        }
        this.allPoints = Collections.unmodifiableList(points);
    }

    /** Parses one "x,y,z" line into a Point; fields beyond the third are ignored. */
    private static Point getPointByLine(String line) {
        String[] xyz = line.split(",");
        if (xyz.length < 3) {
            throw new IllegalArgumentException("Expected three comma-separated values but got: " + line);
        }
        return new Point(Double.parseDouble(xyz[0]),
                Double.parseDouble(xyz[1]),
                Double.parseDouble(xyz[2]));
    }

    /** step 1: get random seeds as initial centroids of the k clusters */
    private void getInitialKRandomSeeds() {
        pointClusters = new Clusters(allPoints);
        List<Point> kRandomPoints = getKRandomPoints();
        for (int i = 0; i < k; i++) {
            Point seed = kRandomPoints.get(i);
            seed.setIndex(i);
            pointClusters.add(new Cluster(seed));
        }
    }

    /**
     * Picks k distinct input points uniformly at random.
     *
     * @throws IllegalArgumentException if there are fewer than k input points
     */
    private List<Point> getKRandomPoints() {
        if (allPoints.size() < k) {
            throw new IllegalArgumentException("Cannot choose " + k + " seeds from "
                    + allPoints.size() + " point(s).");
        }
        // shuffle a copy so the order of allPoints (which is unmodifiable) is untouched
        List<Point> shuffled = new ArrayList<Point>(allPoints);
        Collections.shuffle(shuffled, random);
        return new ArrayList<Point>(shuffled.subList(0, k));
    }

    /** step 2: assign points to initial Clusters */
    private void getInitialClusters() {
        pointClusters.assignPointsToClusters();
    }

    /** step 3: update the k Clusters until no changes in their members occur */
    private void updateClustersUntilNoChange() {
        boolean isChanged;
        do {
            isChanged = pointClusters.updateClusters();
        } while (isChanged);
    }

    /**
     * do K-means clustering with this method
     *
     * @return the k clusters; the clustering is computed on the first call and
     *         the same result is returned afterwards
     */
    public List<Cluster> getPointsClusters() {
        if (pointClusters == null) {
            getInitialKRandomSeeds();
            getInitialClusters();
            updateClustersUntilNoChange();
        }
        return pointClusters;
    }

    /** Clusters the points of files/randomPoints.csv into 6 clusters and prints them. */
    public static void main(String[] args) {
        String pointsFilePath = "files/randomPoints.csv";
        KMeans kMeans = new KMeans(pointsFilePath, 6);
        List<Cluster> pointsClusters = kMeans.getPointsClusters();
        for (int i = 0; i < kMeans.k; i++) {
            System.out.println("Cluster " + i + ": " + pointsClusters.get(i));
        }
    }
}
Thank you for your post on the k-means algorithm for 3D points; it's very useful for my project. :)
ReplyDeletehay can u send me the correct code of this
DeleteI have exceptions while compiling it could u please help me ? :(
ReplyDelete When I run this program, the compiler shows the error below:
ReplyDeleteException in thread "main" java.lang.Error: Unresolved compilation problems:
cluster cannot be resolved to a type
The method getPointsClusters() from the type KMeans refers to the missing type cluster
at kmeans.KMeans.main(KMeans.java:98)
What is the meaning of anotherPoint.x (and likewise y and z) in the getSquareOfDistance method of the Point class? And when it is called in the Clusters class, the argument is get(i).getCentroid() — how does that work? Please answer as soon as possible; I really need to solve this within a few days.
ReplyDelete