~ubuntu-branches/ubuntu/maverick/commons-math/maverick

« back to all changes in this revision

Viewing changes to src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java

  • Committer: Bazaar Package Importer
  • Author(s): Damien Raude-Morvan
  • Date: 2009-08-22 01:13:25 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20090822011325-hi4peq1ua5weguwn
Tags: 2.0-1
* New upstream release.
* Set Maintainer field to Debian Java Team
* Add myself as Uploaders
* Switch to Quilt patch system:
  - Refresh all patchs
  - Remove B-D on dpatch, Add B-D on quilt
  - Include patchsys-quilt.mk in debian/rules
* Bump Standards-Version to 3.8.3:
  - Add a README.source to describe patch system
* Maven POMs:
  - Add a Build-Depends-Indep dependency on maven-repo-helper
  - Use mh_installpom and mh_installjar to install the POM and the jar to the
    Maven repository
* Use default-jdk/jre:
  - Depends on java5-runtime-headless
  - Build-Depends on default-jdk
  - Use /usr/lib/jvm/default-java as JAVA_HOME
* Move api documentation to /usr/share/doc/libcommons-math-java/api
* Build-Depends on junit4 instead of junit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * Licensed to the Apache Software Foundation (ASF) under one or more
 
3
 * contributor license agreements.  See the NOTICE file distributed with
 
4
 * this work for additional information regarding copyright ownership.
 
5
 * The ASF licenses this file to You under the Apache License, Version 2.0
 
6
 * (the "License"); you may not use this file except in compliance with
 
7
 * the License.  You may obtain a copy of the License at
 
8
 *
 
9
 *      http://www.apache.org/licenses/LICENSE-2.0
 
10
 *
 
11
 * Unless required by applicable law or agreed to in writing, software
 
12
 * distributed under the License is distributed on an "AS IS" BASIS,
 
13
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
14
 * See the License for the specific language governing permissions and
 
15
 * limitations under the License.
 
16
 */
 
17
 
 
18
package org.apache.commons.math.stat.clustering;
 
19
 
 
20
import java.util.ArrayList;
 
21
import java.util.Collection;
 
22
import java.util.List;
 
23
import java.util.Random;
 
24
 
 
25
/**
 
26
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
 
27
 * @param <T> type of the points to cluster
 
28
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 
29
 * @version $Revision: 771076 $ $Date: 2009-05-03 12:28:48 -0400 (Sun, 03 May 2009) $
 
30
 * @since 2.0
 
31
 */
 
32
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
 
33
 
 
34
    /** Random generator for choosing initial centers. */
 
35
    private final Random random;
 
36
 
 
37
    /** Build a clusterer.
 
38
     * @param random random generator to use for choosing initial centers
 
39
     */
 
40
    public KMeansPlusPlusClusterer(final Random random) {
 
41
        this.random = random;
 
42
    }
 
43
 
 
44
    /**
 
45
     * Runs the K-means++ clustering algorithm.
 
46
     * 
 
47
     * @param points the points to cluster
 
48
     * @param k the number of clusters to split the data into
 
49
     * @param maxIterations the maximum number of iterations to run the algorithm
 
50
     *     for.  If negative, no maximum will be used
 
51
     * @return a list of clusters containing the points
 
52
     */
 
53
    public List<Cluster<T>> cluster(final Collection<T> points,
 
54
                                    final int k, final int maxIterations) {
 
55
        // create the initial clusters
 
56
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
 
57
        assignPointsToClusters(clusters, points);
 
58
 
 
59
        // iterate through updating the centers until we're done
 
60
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 
 
61
        for (int count = 0; count < max; count++) {
 
62
            boolean clusteringChanged = false;
 
63
            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
 
64
            for (final Cluster<T> cluster : clusters) {
 
65
                final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
 
66
                if (!newCenter.equals(cluster.getCenter())) {
 
67
                    clusteringChanged = true;
 
68
                }
 
69
                newClusters.add(new Cluster<T>(newCenter));
 
70
            }
 
71
            if (!clusteringChanged) {
 
72
                return clusters;
 
73
            }
 
74
            assignPointsToClusters(newClusters, points);
 
75
            clusters = newClusters;
 
76
        }
 
77
        return clusters;
 
78
    }
 
79
 
 
80
    /**
 
81
     * Adds the given points to the closest {@link Cluster}.
 
82
     * 
 
83
     * @param <T> type of the points to cluster
 
84
     * @param clusters the {@link Cluster}s to add the points to
 
85
     * @param points the points to add to the given {@link Cluster}s
 
86
     */
 
87
    private static <T extends Clusterable<T>> void
 
88
        assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
 
89
        for (final T p : points) {
 
90
            Cluster<T> cluster = getNearestCluster(clusters, p);
 
91
            cluster.addPoint(p);
 
92
        }
 
93
    }
 
94
 
 
95
    /**
 
96
     * Use K-means++ to choose the initial centers.
 
97
     * 
 
98
     * @param <T> type of the points to cluster
 
99
     * @param points the points to choose the initial centers from
 
100
     * @param k the number of centers to choose
 
101
     * @param random random generator to use
 
102
     * @return the initial centers
 
103
     */
 
104
    private static <T extends Clusterable<T>> List<Cluster<T>>
 
105
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
 
106
 
 
107
        final List<T> pointSet = new ArrayList<T>(points);
 
108
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
 
109
 
 
110
        // Choose one center uniformly at random from among the data points.
 
111
        final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
 
112
        resultSet.add(new Cluster<T>(firstPoint));
 
113
 
 
114
        final double[] dx2 = new double[pointSet.size()];
 
115
        while (resultSet.size() < k) {
 
116
            // For each data point x, compute D(x), the distance between x and 
 
117
            // the nearest center that has already been chosen.
 
118
            int sum = 0;
 
119
            for (int i = 0; i < pointSet.size(); i++) {
 
120
                final T p = pointSet.get(i);
 
121
                final Cluster<T> nearest = getNearestCluster(resultSet, p);
 
122
                final double d = p.distanceFrom(nearest.getCenter());
 
123
                sum += d * d;
 
124
                dx2[i] = sum;
 
125
            }
 
126
 
 
127
            // Add one new data point as a center. Each point x is chosen with
 
128
            // probability proportional to D(x)2
 
129
            final double r = random.nextDouble() * sum;
 
130
            for (int i = 0 ; i < dx2.length; i++) {
 
131
                if (dx2[i] >= r) {
 
132
                    final T p = pointSet.remove(i);
 
133
                    resultSet.add(new Cluster<T>(p));
 
134
                    break;
 
135
                }
 
136
            }
 
137
        }
 
138
 
 
139
        return resultSet;
 
140
 
 
141
    }
 
142
 
 
143
    /**
 
144
     * Returns the nearest {@link Cluster} to the given point
 
145
     * 
 
146
     * @param <T> type of the points to cluster
 
147
     * @param clusters the {@link Cluster}s to search
 
148
     * @param point the point to find the nearest {@link Cluster} for
 
149
     * @return the nearest {@link Cluster} to the given point
 
150
     */
 
151
    private static <T extends Clusterable<T>> Cluster<T>
 
152
        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
 
153
        double minDistance = Double.MAX_VALUE;
 
154
        Cluster<T> minCluster = null;
 
155
        for (final Cluster<T> c : clusters) {
 
156
            final double distance = point.distanceFrom(c.getCenter());
 
157
            if (distance < minDistance) {
 
158
                minDistance = distance;
 
159
                minCluster = c;
 
160
            }
 
161
        }
 
162
        return minCluster;
 
163
    }
 
164
 
 
165
}