2
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
package org.apache.commons.math.stat.clustering;
20
import java.util.ArrayList;
21
import java.util.Collection;
22
import java.util.List;
23
import java.util.Random;
26
/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii's k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @version $Revision: 771076 $ $Date: 2009-05-03 12:28:48 -0400 (Sun, 03 May 2009) $
 */
32
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
34
/** Random generator for choosing initial centers. */
35
private final Random random;
37
/** Build a clusterer.
38
* @param random random generator to use for choosing initial centers
40
public KMeansPlusPlusClusterer(final Random random) {
45
* Runs the K-means++ clustering algorithm.
47
* @param points the points to cluster
48
* @param k the number of clusters to split the data into
49
* @param maxIterations the maximum number of iterations to run the algorithm
50
* for. If negative, no maximum will be used
51
* @return a list of clusters containing the points
53
public List<Cluster<T>> cluster(final Collection<T> points,
54
final int k, final int maxIterations) {
55
// create the initial clusters
56
List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
57
assignPointsToClusters(clusters, points);
59
// iterate through updating the centers until we're done
60
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
61
for (int count = 0; count < max; count++) {
62
boolean clusteringChanged = false;
63
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
64
for (final Cluster<T> cluster : clusters) {
65
final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
66
if (!newCenter.equals(cluster.getCenter())) {
67
clusteringChanged = true;
69
newClusters.add(new Cluster<T>(newCenter));
71
if (!clusteringChanged) {
74
assignPointsToClusters(newClusters, points);
75
clusters = newClusters;
81
* Adds the given points to the closest {@link Cluster}.
83
* @param <T> type of the points to cluster
84
* @param clusters the {@link Cluster}s to add the points to
85
* @param points the points to add to the given {@link Cluster}s
87
private static <T extends Clusterable<T>> void
88
assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
89
for (final T p : points) {
90
Cluster<T> cluster = getNearestCluster(clusters, p);
96
* Use K-means++ to choose the initial centers.
98
* @param <T> type of the points to cluster
99
* @param points the points to choose the initial centers from
100
* @param k the number of centers to choose
101
* @param random random generator to use
102
* @return the initial centers
104
private static <T extends Clusterable<T>> List<Cluster<T>>
105
chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
107
final List<T> pointSet = new ArrayList<T>(points);
108
final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
110
// Choose one center uniformly at random from among the data points.
111
final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
112
resultSet.add(new Cluster<T>(firstPoint));
114
final double[] dx2 = new double[pointSet.size()];
115
while (resultSet.size() < k) {
116
// For each data point x, compute D(x), the distance between x and
117
// the nearest center that has already been chosen.
119
for (int i = 0; i < pointSet.size(); i++) {
120
final T p = pointSet.get(i);
121
final Cluster<T> nearest = getNearestCluster(resultSet, p);
122
final double d = p.distanceFrom(nearest.getCenter());
127
// Add one new data point as a center. Each point x is chosen with
128
// probability proportional to D(x)2
129
final double r = random.nextDouble() * sum;
130
for (int i = 0 ; i < dx2.length; i++) {
132
final T p = pointSet.remove(i);
133
resultSet.add(new Cluster<T>(p));
144
/**
 * Returns the nearest {@link Cluster} to the given point.
 *
 * @param <T> type of the points to cluster
 * @param clusters the {@link Cluster}s to search
 * @param point the point to find the nearest {@link Cluster} for
 * @return the nearest {@link Cluster} to the given point
 */
151
private static <T extends Clusterable<T>> Cluster<T>
152
getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
153
double minDistance = Double.MAX_VALUE;
154
Cluster<T> minCluster = null;
155
for (final Cluster<T> c : clusters) {
156
final double distance = point.distanceFrom(c.getCenter());
157
if (distance < minDistance) {
158
minDistance = distance;