/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.stat.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @version $Revision: 771076 $ $Date: 2009-05-03 12:28:48 -0400 (Sun, 03 May 2009) $
 * @since 2.0
 */
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {

    /** Random generator for choosing initial centers. */
    private final Random random;

    /** Build a clusterer.
     * @param random random generator to use for choosing initial centers
     */
    public KMeansPlusPlusClusterer(final Random random) {
        this.random = random;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     * <p>
     * Initial centers are picked with the k-means++ seeding procedure, every
     * point is assigned to its nearest center, and the centers are then
     * repeatedly replaced by the centroid of their assigned points until no
     * center changes (according to {@code equals}) or the iteration limit is
     * reached.
     * </p>
     * <p>
     * If {@code points} holds fewer than {@code k} elements, every point
     * becomes a center and fewer than {@code k} clusters are returned. At
     * least one cluster is always created, even if {@code k} is not positive.
     * </p>
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for.  If negative, no maximum will be used
     * @return a list of clusters containing the points
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public List<Cluster<T>> cluster(final Collection<T> points,
                                    final int k, final int maxIterations) {
        // create the initial clusters
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
        assignPointsToClusters(clusters, points);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean clusteringChanged = false;
            final List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(clusters.size());
            for (final Cluster<T> cluster : clusters) {
                // NOTE(review): a cluster that received no points is passed to
                // centroidOf with an empty collection; the outcome depends on
                // the Clusterable implementation -- confirm it is handled.
                final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
                if (!newCenter.equals(cluster.getCenter())) {
                    clusteringChanged = true;
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            if (!clusteringChanged) {
                // converged: the current clusters already hold their points
                return clusters;
            }
            assignPointsToClusters(newClusters, points);
            clusters = newClusters;
        }
        return clusters;
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     */
    private static <T extends Clusterable<T>> void
        assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
        for (final T p : points) {
            final Cluster<T> cluster = getNearestCluster(clusters, p);
            cluster.addPoint(p);
        }
    }

    /**
     * Use K-means++ to choose the initial centers.
     * <p>
     * The first center is drawn uniformly from the points; each further
     * center is drawn from the remaining points with probability proportional
     * to the squared distance to the nearest center chosen so far. Points
     * chosen as centers are removed from the candidate list, so no point is
     * selected twice. Seeding stops early when no candidates remain, in which
     * case fewer than {@code k} centers are returned.
     * </p>
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     * @throws IllegalArgumentException if {@code points} is empty
     */
    private static <T extends Clusterable<T>> List<Cluster<T>>
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

        final List<T> pointSet = new ArrayList<T>(points);
        if (pointSet.isEmpty()) {
            throw new IllegalArgumentException(
                "cannot choose initial centers from an empty collection of points");
        }
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
        resultSet.add(new Cluster<T>(firstPoint));

        // Cumulative squared distances; sized for the largest candidate list,
        // only the first pointSet.size() entries are valid on each pass.
        final double[] dx2 = new double[pointSet.size()];
        while (resultSet.size() < k && !pointSet.isEmpty()) {
            final int remaining = pointSet.size();

            // For each data point x, compute D(x), the distance between x and
            // the nearest center that has already been chosen.
            // The running total must be a double: an int accumulator would
            // truncate every partial sum and destroy the D(x)^2 weighting.
            double sum = 0;
            for (int i = 0; i < remaining; i++) {
                final T p = pointSet.get(i);
                final Cluster<T> nearest = getNearestCluster(resultSet, p);
                final double d = p.distanceFrom(nearest.getCenter());
                sum += d * d;
                dx2[i] = sum;
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2.
            final double r = random.nextDouble() * sum;
            // Default to the last candidate so that the loop always makes
            // progress, even if no cumulative value compares >= r.
            int chosen = remaining - 1;
            for (int i = 0; i < remaining; i++) {
                if (dx2[i] >= r) {
                    chosen = i;
                    break;
                }
            }
            resultSet.add(new Cluster<T>(pointSet.remove(chosen)));
        }

        return resultSet;

    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     * <p>
     * Ties are resolved in favor of the cluster encountered first.
     * </p>
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the nearest {@link Cluster} to the given point, or {@code null}
     *     if {@code clusters} is empty or no cluster center lies at a distance
     *     strictly smaller than {@code Double.MAX_VALUE}
     */
    private static <T extends Clusterable<T>> Cluster<T>
        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        Cluster<T> minCluster = null;
        for (final Cluster<T> c : clusters) {
            final double distance = point.distanceFrom(c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = c;
            }
        }
        return minCluster;
    }

}