~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/estimators/KernelEstimator.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    KernelEstimator.java
 
19
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.estimators;
 
24
 
 
25
import weka.core.Capabilities.Capability;
 
26
import weka.core.Capabilities;
 
27
import weka.core.Utils;
 
28
import weka.core.Statistics;
 
29
 
 
30
/** 
 
31
 * Simple kernel density estimator. Uses one gaussian kernel per observed
 
32
 * data value.
 
33
 *
 
34
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 
35
 * @version $Revision: 1.7 $
 
36
 */
 
37
public class KernelEstimator extends Estimator implements IncrementalEstimator {
 
38
 
 
39
  /** for serialization */
 
40
  private static final long serialVersionUID = 3646923563367683925L;
 
41
 
 
42
  /** Vector containing all of the values seen */
 
43
  private double [] m_Values;
 
44
 
 
45
  /** Vector containing the associated weights */
 
46
  private double [] m_Weights;
 
47
 
 
48
  /** Number of values stored in m_Weights and m_Values so far */
 
49
  private int m_NumValues;
 
50
 
 
51
  /** The sum of the weights so far */
 
52
  private double m_SumOfWeights;
 
53
 
 
54
  /** The standard deviation */
 
55
  private double m_StandardDev;
 
56
 
 
57
  /** The precision of data values */
 
58
  private double m_Precision;
 
59
 
 
60
  /** Whether we can optimise the kernel summation */
 
61
  private boolean m_AllWeightsOne;
 
62
 
 
63
  /** Maximum percentage error permitted in probability calculations */
 
64
  private static double MAX_ERROR = 0.01;
 
65
 
 
66
 
 
67
  /**
 
68
   * Execute a binary search to locate the nearest data value
 
69
   *
 
70
   * @param the data value to locate
 
71
   * @return the index of the nearest data value
 
72
   */
 
73
  private int findNearestValue(double key) {
 
74
 
 
75
    int low = 0; 
 
76
    int high = m_NumValues;
 
77
    int middle = 0;
 
78
    while (low < high) {
 
79
      middle = (low + high) / 2;
 
80
      double current = m_Values[middle];
 
81
      if (current == key) {
 
82
        return middle;
 
83
      }
 
84
      if (current > key) {
 
85
        high = middle;
 
86
      } else if (current < key) {
 
87
        low = middle + 1;
 
88
      }
 
89
    }
 
90
    return low;
 
91
  }
 
92
 
 
93
  /**
 
94
   * Round a data value using the defined precision for this estimator
 
95
   *
 
96
   * @param data the value to round
 
97
   * @return the rounded data value
 
98
   */
 
99
  private double round(double data) {
 
100
 
 
101
    return Math.rint(data / m_Precision) * m_Precision;
 
102
  }
 
103
  
 
104
  // ===============
 
105
  // Public methods.
 
106
  // ===============
 
107
  
 
108
  /**
 
109
   * Constructor that takes a precision argument.
 
110
   *
 
111
   * @param precision the  precision to which numeric values are given. For
 
112
   * example, if the precision is stated to be 0.1, the values in the
 
113
   * interval (0.25,0.35] are all treated as 0.3. 
 
114
   */
 
115
  public KernelEstimator(double precision) {
 
116
 
 
117
    m_Values = new double [50];
 
118
    m_Weights = new double [50];
 
119
    m_NumValues = 0;
 
120
    m_SumOfWeights = 0;
 
121
    m_AllWeightsOne = true;
 
122
    m_Precision = precision;
 
123
    // precision cannot be zero
 
124
    if (m_Precision < Utils.SMALL) m_Precision = Utils.SMALL;
 
125
    //    m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
 
126
    m_StandardDev = m_Precision / (2 * 3);
 
127
  }
 
128
 
 
129
  /**
 
130
   * Add a new data value to the current estimator.
 
131
   *
 
132
   * @param data the new data value 
 
133
   * @param weight the weight assigned to the data value 
 
134
   */
 
135
  public void addValue(double data, double weight) {
 
136
    
 
137
    if (weight == 0) {
 
138
      return;
 
139
    }
 
140
    data = round(data);
 
141
    int insertIndex = findNearestValue(data);
 
142
    if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
 
143
      if (m_NumValues < m_Values.length) {
 
144
        int left = m_NumValues - insertIndex; 
 
145
        System.arraycopy(m_Values, insertIndex, 
 
146
            m_Values, insertIndex + 1, left);
 
147
        System.arraycopy(m_Weights, insertIndex, 
 
148
            m_Weights, insertIndex + 1, left);
 
149
        
 
150
        m_Values[insertIndex] = data;
 
151
        m_Weights[insertIndex] = weight;
 
152
        m_NumValues++;
 
153
      } else {
 
154
        double [] newValues = new double [m_Values.length * 2];
 
155
        double [] newWeights = new double [m_Values.length * 2];
 
156
        int left = m_NumValues - insertIndex; 
 
157
        System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
 
158
        System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
 
159
        newValues[insertIndex] = data;
 
160
        newWeights[insertIndex] = weight;
 
161
        System.arraycopy(m_Values, insertIndex, 
 
162
            newValues, insertIndex + 1, left);
 
163
        System.arraycopy(m_Weights, insertIndex, 
 
164
            newWeights, insertIndex + 1, left);
 
165
        m_NumValues++;
 
166
        m_Values = newValues;
 
167
        m_Weights = newWeights;
 
168
      }
 
169
      if (weight != 1) {
 
170
        m_AllWeightsOne = false;
 
171
      }
 
172
    } else {
 
173
      m_Weights[insertIndex] += weight;
 
174
      m_AllWeightsOne = false;      
 
175
    }
 
176
    m_SumOfWeights += weight;
 
177
    double range = m_Values[m_NumValues - 1] - m_Values[0];
 
178
    if (range > 0) {
 
179
      m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights), 
 
180
          // allow at most 3 sds within one interval
 
181
          m_Precision / (2 * 3));
 
182
    }
 
183
  }
 
184
  
 
185
  /**
 
186
   * Get a probability estimate for a value.
 
187
   *
 
188
   * @param data the value to estimate the probability of
 
189
   * @return the estimated probability of the supplied value
 
190
   */
 
191
  public double getProbability(double data) {
 
192
 
 
193
    double delta = 0, sum = 0, currentProb = 0;
 
194
    double zLower = 0, zUpper = 0;
 
195
    if (m_NumValues == 0) {
 
196
      zLower = (data - (m_Precision / 2)) / m_StandardDev;
 
197
      zUpper = (data + (m_Precision / 2)) / m_StandardDev;
 
198
      return (Statistics.normalProbability(zUpper)
 
199
              - Statistics.normalProbability(zLower));
 
200
    }
 
201
    double weightSum = 0;
 
202
    int start = findNearestValue(data);
 
203
    for (int i = start; i < m_NumValues; i++) {
 
204
      delta = m_Values[i] - data;
 
205
      zLower = (delta - (m_Precision / 2)) / m_StandardDev;
 
206
      zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
 
207
      currentProb = (Statistics.normalProbability(zUpper)
 
208
                     - Statistics.normalProbability(zLower));
 
209
      sum += currentProb * m_Weights[i];
 
210
      /*
 
211
      System.out.print("zL" + (i + 1) + ": " + zLower + " ");
 
212
      System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
 
213
      System.out.print("P" + (i + 1) + ": " + currentProb + " ");
 
214
      System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
 
215
      */
 
216
      weightSum += m_Weights[i];
 
217
      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
 
218
        break;
 
219
      }
 
220
    }
 
221
    for (int i = start - 1; i >= 0; i--) {
 
222
      delta = m_Values[i] - data;
 
223
      zLower = (delta - (m_Precision / 2)) / m_StandardDev;
 
224
      zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
 
225
      currentProb = (Statistics.normalProbability(zUpper)
 
226
                     - Statistics.normalProbability(zLower));
 
227
      sum += currentProb * m_Weights[i];
 
228
      weightSum += m_Weights[i];
 
229
      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
 
230
        break;
 
231
      }
 
232
    }
 
233
    return sum / m_SumOfWeights;
 
234
  }
 
235
 
 
236
  /** Display a representation of this estimator */
 
237
  public String toString() {
 
238
 
 
239
    String result = m_NumValues + " Normal Kernels. \nStandardDev = " 
 
240
      + Utils.doubleToString(m_StandardDev,6,4)
 
241
      + " Precision = " + m_Precision;
 
242
    if (m_NumValues == 0) {
 
243
      result += "  \nMean = 0";
 
244
    } else {
 
245
      result += "  \nMeans =";
 
246
      for (int i = 0; i < m_NumValues; i++) {
 
247
        result += " " + m_Values[i];
 
248
      }
 
249
      if (!m_AllWeightsOne) {
 
250
        result += "\nWeights = ";
 
251
        for (int i = 0; i < m_NumValues; i++) {
 
252
          result += " " + m_Weights[i];
 
253
        }
 
254
      }
 
255
    }
 
256
    return result + "\n";
 
257
  }
 
258
 
 
259
  /**
 
260
   * Returns default capabilities of the classifier.
 
261
   *
 
262
   * @return      the capabilities of this classifier
 
263
   */
 
264
  public Capabilities getCapabilities() {
 
265
    Capabilities result = super.getCapabilities();
 
266
    
 
267
    // attributes
 
268
    result.enable(Capability.NUMERIC_ATTRIBUTES);
 
269
    return result;
 
270
  }
 
271
 
 
272
  /**
 
273
   * Main method for testing this class.
 
274
   *
 
275
   * @param argv should contain a sequence of numeric values
 
276
   */
 
277
  public static void main(String [] argv) {
 
278
 
 
279
    try {
 
280
      if (argv.length < 2) {
 
281
        System.out.println("Please specify a set of instances.");
 
282
        return;
 
283
      }
 
284
      KernelEstimator newEst = new KernelEstimator(0.01);
 
285
      for (int i = 0; i < argv.length - 3; i += 2) {
 
286
        newEst.addValue(Double.valueOf(argv[i]).doubleValue(), 
 
287
                        Double.valueOf(argv[i + 1]).doubleValue());
 
288
      }
 
289
      System.out.println(newEst);
 
290
 
 
291
      double start = Double.valueOf(argv[argv.length - 2]).doubleValue();
 
292
      double finish = Double.valueOf(argv[argv.length - 1]).doubleValue();
 
293
      for (double current = start; current < finish; 
 
294
          current += (finish - start) / 50) {
 
295
        System.out.println("Data: " + current + " " 
 
296
                           + newEst.getProbability(current));
 
297
      }
 
298
    } catch (Exception e) {
 
299
      System.out.println(e.getMessage());
 
300
    }
 
301
  }
 
302
}
 
303
 
 
304
 
 
305
 
 
306
 
 
307
 
 
308
 
 
309
 
 
310