In data mining, k-means clustering is a method of cluster analysis that aims to partition n points into k clusters where each point belongs to the cluster with the nearest mean. This results in a partitioning of the data space into Voronoi cells.
The following code is my implementation of the k-means algorithm for 3D points. The Java project can be downloaded from https://sites.google.com/site/moderntone/K-Means.zip
Cluster.java
package kmeans;
import java.util.*;
/**
 * One cluster of the k-means partition: a centroid plus the points
 * currently assigned to it.
 */
public class Cluster {
    /** Points currently assigned to this cluster. */
    private final List<Point> points;
    /** Current centre of the cluster. */
    private Point centroid;

    /**
     * Creates an empty cluster whose initial centroid is the given point.
     * The point is used only as the centroid; it is not added to the member list.
     *
     * @param firstPoint the initial centroid
     */
    public Cluster(Point firstPoint) {
        points = new ArrayList<Point>();
        centroid = firstPoint;
    }

    /**
     * @return the current centroid of this cluster
     */
    public Point getCentroid() {
        return centroid;
    }

    /**
     * Moves the centroid to the arithmetic mean of the member points.
     * An empty cluster keeps its previous centroid: averaging zero points
     * would yield NaN coordinates, and a NaN centroid can never be the
     * nearest one again.
     */
    public void updateCentroid() {
        if (points.isEmpty()) {
            return;
        }
        double newx = 0d, newy = 0d, newz = 0d;
        for (Point point : points) {
            newx += point.x;
            newy += point.y;
            newz += point.z;
        }
        int count = points.size();
        centroid = new Point(newx / count, newy / count, newz / count);
    }

    /**
     * @return the live, mutable list of member points (not a copy)
     */
    public List<Point> getPoints() {
        return points;
    }

    /**
     * Lists the member points one per line, separated by commas, after a
     * header line. The text ends with a newline; no comma follows the last point.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("This cluster contains the following points:\n");
        boolean first = true;
        for (Point point : points) {
            if (!first) {
                builder.append(",\n");
            }
            builder.append(point.toString());
            first = false;
        }
        // Only terminate the last point's line; with no points the header
        // already ends in a newline and must stay intact.
        if (!first) {
            builder.append("\n");
        }
        return builder.toString();
    }
}
Clusters.java
package kmeans;
import java.util.*;
/**
 * The list of clusters together with the full set of points being
 * clustered. Provides the assignment and update steps of k-means.
 */
public class Clusters extends ArrayList<Cluster> {
    private static final long serialVersionUID = 1L;
    /** Every point that takes part in the clustering. */
    private final List<Point> allPoints;
    /** Set to true when an assignment pass moves a point to a different cluster. */
    private boolean isChanged;

    /**
     * @param allPoints all points to be distributed among the clusters
     */
    public Clusters(List<Point> allPoints) {
        this.allPoints = allPoints;
    }

    /**
     * Finds the cluster whose centroid has the smallest squared distance to
     * the given point. On ties the cluster with the lowest index wins.
     *
     * @param point the point to classify
     * @return the index of the cluster nearest to the point, or -1 when no
     *         cluster has a squared distance below {@code Double.MAX_VALUE}
     *         (for example when this list is empty)
     */
    public Integer getNearestCluster(Point point) {
        double minSquareOfDistance = Double.MAX_VALUE;
        int itsIndex = -1;
        for (int i = 0; i < size(); i++) {
            double squareOfDistance = point.getSquareOfDistance(get(i).getCentroid());
            if (squareOfDistance < minSquareOfDistance) {
                minSquareOfDistance = squareOfDistance;
                itsIndex = i;
            }
        }
        return itsIndex;
    }

    /**
     * Runs one k-means iteration: recomputes every centroid from its current
     * members, empties the member lists, then reassigns all points.
     *
     * @return true if at least one point changed cluster during reassignment
     */
    public boolean updateClusters() {
        for (Cluster cluster : this) {
            // The centroid must be recomputed before the members are discarded.
            cluster.updateCentroid();
            cluster.getPoints().clear();
        }
        isChanged = false;
        assignPointsToClusters();
        return isChanged;
    }

    /**
     * Assigns every point to its nearest cluster, storing the cluster index
     * on the point and flagging a change whenever that index differs from
     * the one the point held before.
     */
    public void assignPointsToClusters() {
        for (Point point : allPoints) {
            int previousIndex = point.getIndex();
            int newIndex = getNearestCluster(point);
            if (previousIndex != newIndex) {
                isChanged = true;
            }
            Cluster target = get(newIndex);
            point.setIndex(newIndex);
            target.getPoints().add(point);
        }
    }
}
Point.java
package kmeans;
public class Point {
private int index = -1; //denotes which Cluster it belongs to
public double x, y, z;
public Point(double x, double y, double z) {
this.x = x;
this.y = y;
this.z = z;
}
public Double getSquareOfDistance(Point anotherPoint){
return (x - anotherPoint.x) * (x - anotherPoint.x)
+ (y - anotherPoint.y) * (y - anotherPoint.y)
+ (z - anotherPoint.z) * (z - anotherPoint.z);
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public String toString(){
return "(" + x + "," + y + "," + z + ")";
}
}
KMeans.java
package kmeans;
import java.io.*;
import java.util.*;
/**
 * K-means clustering of 3D points read from a CSV file with one
 * {@code x,y,z} triple per line.
 */
public class KMeans {
    private static final Random random = new Random();
    /** All input points, as an unmodifiable list. */
    public final List<Point> allPoints;
    /** Number of clusters to produce. */
    public final int k;
    /** The k clusters; null until {@link #getPointsClusters()} has run. */
    private Clusters pointClusters;

    /**
     * Reads the input points. If the file cannot be read, the stack trace is
     * printed and the instance keeps whatever points were read before the failure.
     *
     * @param pointsFile the csv file for input points
     * @param k number of clusters
     * @throws IllegalArgumentException if k is less than 2, or if a non-blank
     *         line does not hold three numeric fields
     */
    public KMeans(String pointsFile, int k) {
        if (k < 2) {
            throw new IllegalArgumentException("The value of k should be 2 or more.");
        }
        this.k = k;
        List<Point> points = new ArrayList<Point>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(pointsFile), "UTF-8"));
            String line;
            while ((line = reader.readLine()) != null) {
                // Tolerate blank lines, e.g. a trailing newline at the end of the file.
                if (line.trim().isEmpty()) {
                    continue;
                }
                points.add(getPointByLine(line));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Close the reader even when reading or parsing failed part-way.
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        this.allPoints = Collections.unmodifiableList(points);
    }

    /**
     * Parses one comma-separated line into a point. Fields beyond the third
     * are ignored.
     *
     * @throws IllegalArgumentException if the line has fewer than three fields
     *         or a field is not a valid number
     */
    private Point getPointByLine(String line) {
        String[] xyz = line.split(",");
        if (xyz.length < 3) {
            throw new IllegalArgumentException("Expected x,y,z but got: " + line);
        }
        return new Point(Double.parseDouble(xyz[0]),
                Double.parseDouble(xyz[1]), Double.parseDouble(xyz[2]));
    }

    /**
     * Step 1: pick k distinct random points as the initial centroids of the
     * k clusters.
     */
    private void getInitialKRandomSeeds() {
        pointClusters = new Clusters(allPoints);
        List<Point> kRandomPoints = getKRandomPoints();
        for (int i = 0; i < k; i++) {
            Point seed = kRandomPoints.get(i);
            seed.setIndex(i);
            pointClusters.add(new Cluster(seed));
        }
    }

    /**
     * Chooses k distinct entries of {@link #allPoints} at random.
     *
     * @throws IllegalArgumentException if there are fewer than k points
     */
    private List<Point> getKRandomPoints() {
        if (k > allPoints.size()) {
            throw new IllegalArgumentException("k (" + k
                    + ") exceeds the number of points (" + allPoints.size() + ").");
        }
        // Shuffle a private copy: allPoints is unmodifiable and its order must not change.
        List<Point> shuffled = new ArrayList<Point>(allPoints);
        Collections.shuffle(shuffled, random);
        return new ArrayList<Point>(shuffled.subList(0, k));
    }

    /**
     * Step 2: assign points to the initial clusters.
     */
    private void getInitialClusters() {
        pointClusters.assignPointsToClusters();
    }

    /**
     * Step 3: update the k clusters until no changes in their members occur.
     */
    private void updateClustersUntilNoChange() {
        boolean isChanged;
        do {
            isChanged = pointClusters.updateClusters();
        } while (isChanged);
    }

    /**
     * Performs the k-means clustering on first use and returns the result;
     * later calls return the same clusters without recomputing.
     *
     * @return the k clusters
     */
    public List<Cluster> getPointsClusters() {
        if (pointClusters == null) {
            getInitialKRandomSeeds();
            getInitialClusters();
            updateClustersUntilNoChange();
        }
        return pointClusters;
    }

    /**
     * Clusters the points of {@code files/randomPoints.csv} into 6 clusters
     * and prints each cluster.
     */
    public static void main(String[] args) {
        String pointsFilePath = "files/randomPoints.csv";
        KMeans kMeans = new KMeans(pointsFilePath, 6);
        List<Cluster> pointsClusters = kMeans.getPointsClusters();
        for (int i = 0; i < kMeans.k; i++) {
            System.out.println("Cluster " + i + ": " + pointsClusters.get(i));
        }
    }
}
thank you for your post of k means algo for 3d its very useful for my project....... :)
ReplyDeletehay can u send me the correct code of this
DeleteI have exceptions while compiling it could u please help me ? :(
ReplyDelete When I run this program, the compiler shows the error below:
ReplyDeleteException in thread "main" java.lang.Error: Unresolved compilation problems:
cluster cannot be resolved to a type
The method getPointsClusters() from the type KMeans refers to the missing type cluster
at kmeans.KMeans.main(KMeans.java:98)
What is the meaning of anotherPoint.x (and likewise y and z) in the getSquareOfDistance method of the Point class? And when it is called in the Clusters class, the argument is get(i).getCentroid() — how does that work? Please answer as soon as possible; I really need to solve this within a few days.
ReplyDelete