/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * KernelEstimator.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 */
package weka.estimators;
import weka.core.Capabilities.Capability;
import weka.core.Capabilities;
import weka.core.Utils;
import weka.core.Statistics;
/**
 * Simple kernel density estimator. Uses one gaussian kernel per observed
 * data value.
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.7 $
 */
public class KernelEstimator extends Estimator implements IncrementalEstimator {
39
/** for serialization */
40
private static final long serialVersionUID = 3646923563367683925L;
42
/** Vector containing all of the values seen */
43
private double [] m_Values;
45
/** Vector containing the associated weights */
46
private double [] m_Weights;
48
/** Number of values stored in m_Weights and m_Values so far */
49
private int m_NumValues;
51
/** The sum of the weights so far */
52
private double m_SumOfWeights;
54
/** The standard deviation */
55
private double m_StandardDev;
57
/** The precision of data values */
58
private double m_Precision;
60
/** Whether we can optimise the kernel summation */
61
private boolean m_AllWeightsOne;
63
/** Maximum percentage error permitted in probability calculations */
64
private static double MAX_ERROR = 0.01;
68
* Execute a binary search to locate the nearest data value
70
* @param the data value to locate
71
* @return the index of the nearest data value
73
private int findNearestValue(double key) {
76
int high = m_NumValues;
79
middle = (low + high) / 2;
80
double current = m_Values[middle];
86
} else if (current < key) {
94
* Round a data value using the defined precision for this estimator
96
* @param data the value to round
97
* @return the rounded data value
99
private double round(double data) {
101
return Math.rint(data / m_Precision) * m_Precision;
/**
 * Constructor that takes a precision argument.
 *
 * @param precision the precision to which numeric values are given. For
 * example, if the precision is stated to be 0.1, the values in the
 * interval (0.25,0.35] are all treated as 0.3.
 */
public KernelEstimator(double precision) {

  m_Values = new double[50];
  m_Weights = new double[50];
  // explicit zero-inits (lost in extraction; identical to Java defaults)
  m_NumValues = 0;
  m_SumOfWeights = 0;
  m_AllWeightsOne = true;
  m_Precision = precision;
  // precision cannot be zero
  if (m_Precision < Utils.SMALL) m_Precision = Utils.SMALL;
  // m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
  // allow at most 3 sds within one interval
  m_StandardDev = m_Precision / (2 * 3);
}
130
* Add a new data value to the current estimator.
132
* @param data the new data value
133
* @param weight the weight assigned to the data value
135
public void addValue(double data, double weight) {
141
int insertIndex = findNearestValue(data);
142
if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
143
if (m_NumValues < m_Values.length) {
144
int left = m_NumValues - insertIndex;
145
System.arraycopy(m_Values, insertIndex,
146
m_Values, insertIndex + 1, left);
147
System.arraycopy(m_Weights, insertIndex,
148
m_Weights, insertIndex + 1, left);
150
m_Values[insertIndex] = data;
151
m_Weights[insertIndex] = weight;
154
double [] newValues = new double [m_Values.length * 2];
155
double [] newWeights = new double [m_Values.length * 2];
156
int left = m_NumValues - insertIndex;
157
System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
158
System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
159
newValues[insertIndex] = data;
160
newWeights[insertIndex] = weight;
161
System.arraycopy(m_Values, insertIndex,
162
newValues, insertIndex + 1, left);
163
System.arraycopy(m_Weights, insertIndex,
164
newWeights, insertIndex + 1, left);
166
m_Values = newValues;
167
m_Weights = newWeights;
170
m_AllWeightsOne = false;
173
m_Weights[insertIndex] += weight;
174
m_AllWeightsOne = false;
176
m_SumOfWeights += weight;
177
double range = m_Values[m_NumValues - 1] - m_Values[0];
179
m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
180
// allow at most 3 sds within one interval
181
m_Precision / (2 * 3));
186
* Get a probability estimate for a value.
188
* @param data the value to estimate the probability of
189
* @return the estimated probability of the supplied value
191
public double getProbability(double data) {
193
double delta = 0, sum = 0, currentProb = 0;
194
double zLower = 0, zUpper = 0;
195
if (m_NumValues == 0) {
196
zLower = (data - (m_Precision / 2)) / m_StandardDev;
197
zUpper = (data + (m_Precision / 2)) / m_StandardDev;
198
return (Statistics.normalProbability(zUpper)
199
- Statistics.normalProbability(zLower));
201
double weightSum = 0;
202
int start = findNearestValue(data);
203
for (int i = start; i < m_NumValues; i++) {
204
delta = m_Values[i] - data;
205
zLower = (delta - (m_Precision / 2)) / m_StandardDev;
206
zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
207
currentProb = (Statistics.normalProbability(zUpper)
208
- Statistics.normalProbability(zLower));
209
sum += currentProb * m_Weights[i];
211
System.out.print("zL" + (i + 1) + ": " + zLower + " ");
212
System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
213
System.out.print("P" + (i + 1) + ": " + currentProb + " ");
214
System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
216
weightSum += m_Weights[i];
217
if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
221
for (int i = start - 1; i >= 0; i--) {
222
delta = m_Values[i] - data;
223
zLower = (delta - (m_Precision / 2)) / m_StandardDev;
224
zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
225
currentProb = (Statistics.normalProbability(zUpper)
226
- Statistics.normalProbability(zLower));
227
sum += currentProb * m_Weights[i];
228
weightSum += m_Weights[i];
229
if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
233
return sum / m_SumOfWeights;
236
/** Display a representation of this estimator */
237
public String toString() {
239
String result = m_NumValues + " Normal Kernels. \nStandardDev = "
240
+ Utils.doubleToString(m_StandardDev,6,4)
241
+ " Precision = " + m_Precision;
242
if (m_NumValues == 0) {
243
result += " \nMean = 0";
245
result += " \nMeans =";
246
for (int i = 0; i < m_NumValues; i++) {
247
result += " " + m_Values[i];
249
if (!m_AllWeightsOne) {
250
result += "\nWeights = ";
251
for (int i = 0; i < m_NumValues; i++) {
252
result += " " + m_Weights[i];
256
return result + "\n";
260
* Returns default capabilities of the classifier.
262
* @return the capabilities of this classifier
264
public Capabilities getCapabilities() {
265
Capabilities result = super.getCapabilities();
268
result.enable(Capability.NUMERIC_ATTRIBUTES);
273
* Main method for testing this class.
275
* @param argv should contain a sequence of numeric values
277
public static void main(String [] argv) {
280
if (argv.length < 2) {
281
System.out.println("Please specify a set of instances.");
284
KernelEstimator newEst = new KernelEstimator(0.01);
285
for (int i = 0; i < argv.length - 3; i += 2) {
286
newEst.addValue(Double.valueOf(argv[i]).doubleValue(),
287
Double.valueOf(argv[i + 1]).doubleValue());
289
System.out.println(newEst);
291
double start = Double.valueOf(argv[argv.length - 2]).doubleValue();
292
double finish = Double.valueOf(argv[argv.length - 1]).doubleValue();
293
for (double current = start; current < finish;
294
current += (finish - start) / 50) {
295
System.out.println("Data: " + current + " "
296
+ newEst.getProbability(current));
298
} catch (Exception e) {
299
System.out.println(e.getMessage());