~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/estimators/NormalEstimator.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    NormalEstimator.java
 
19
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.estimators;
 
24
 
 
25
import weka.core.Capabilities.Capability;
 
26
import weka.core.Capabilities;
 
27
import weka.core.Statistics;
 
28
import weka.core.Utils;
 
29
 
 
30
/** 
 
31
 * Simple probability estimator that places a single normal distribution
 
32
 * over the observed values.
 
33
 *
 
34
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 
35
 * @version $Revision: 1.8 $
 
36
 */
 
37
public class NormalEstimator
 
38
  extends Estimator
 
39
  implements IncrementalEstimator {
 
40
 
 
41
  /** for serialization */
 
42
  private static final long serialVersionUID = 93584379632315841L;
 
43
 
 
44
  /** The sum of the weights */
 
45
  private double m_SumOfWeights;
 
46
 
 
47
  /** The sum of the values seen */
 
48
  private double m_SumOfValues;
 
49
 
 
50
  /** The sum of the values squared */
 
51
  private double m_SumOfValuesSq;
 
52
 
 
53
  /** The current mean */
 
54
  private double m_Mean;
 
55
 
 
56
  /** The current standard deviation */
 
57
  private double m_StandardDev;
 
58
 
 
59
  /** The precision of numeric values ( = minimum std dev permitted) */
 
60
  private double m_Precision;
 
61
 
 
62
  /**
 
63
   * Round a data value using the defined precision for this estimator
 
64
   *
 
65
   * @param data the value to round
 
66
   * @return the rounded data value
 
67
   */
 
68
  private double round(double data) {
 
69
 
 
70
    return Math.rint(data / m_Precision) * m_Precision;
 
71
  }
 
72
  
 
73
  // ===============
 
74
  // Public methods.
 
75
  // ===============
 
76
  
 
77
  /**
 
78
   * Constructor that takes a precision argument.
 
79
   *
 
80
   * @param precision the precision to which numeric values are given. For
 
81
   * example, if the precision is stated to be 0.1, the values in the
 
82
   * interval (0.25,0.35] are all treated as 0.3. 
 
83
   */
 
84
  public NormalEstimator(double precision) {
 
85
 
 
86
    m_Precision = precision;
 
87
 
 
88
    // Allow at most 3 sd's within one interval
 
89
    m_StandardDev = m_Precision / (2 * 3);
 
90
  }
 
91
 
 
92
  /**
 
93
   * Add a new data value to the current estimator.
 
94
   *
 
95
   * @param data the new data value 
 
96
   * @param weight the weight assigned to the data value 
 
97
   */
 
98
  public void addValue(double data, double weight) {
 
99
 
 
100
    if (weight == 0) {
 
101
      return;
 
102
    }
 
103
    data = round(data);
 
104
    m_SumOfWeights += weight;
 
105
    m_SumOfValues += data * weight;
 
106
    m_SumOfValuesSq += data * data * weight;
 
107
 
 
108
    if (m_SumOfWeights > 0) {
 
109
      m_Mean = m_SumOfValues / m_SumOfWeights;
 
110
      double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq 
 
111
                                          - m_Mean * m_SumOfValues) 
 
112
                                         / m_SumOfWeights);
 
113
      // If the stdDev ~= 0, we really have no idea of scale yet, 
 
114
      // so stick with the default. Otherwise...
 
115
      if (stdDev > 1e-10) {
 
116
        m_StandardDev = Math.max(m_Precision / (2 * 3), 
 
117
                                 // allow at most 3sd's within one interval 
 
118
                                 stdDev);
 
119
      }
 
120
    }
 
121
  }
 
122
 
 
123
  /**
 
124
   * Get a probability estimate for a value
 
125
   *
 
126
   * @param data the value to estimate the probability of
 
127
   * @return the estimated probability of the supplied value
 
128
   */
 
129
  public double getProbability(double data) {
 
130
 
 
131
    data = round(data);
 
132
    double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;
 
133
    double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;
 
134
    
 
135
    double pLower = Statistics.normalProbability(zLower);
 
136
    double pUpper = Statistics.normalProbability(zUpper);
 
137
    return pUpper - pLower;
 
138
  }
 
139
 
 
140
  /**
 
141
   * Display a representation of this estimator
 
142
   */
 
143
  public String toString() {
 
144
 
 
145
    return "Normal Distribution. Mean = " + Utils.doubleToString(m_Mean, 4)
 
146
      + " StandardDev = " + Utils.doubleToString(m_StandardDev, 4)
 
147
      + " WeightSum = " + Utils.doubleToString(m_SumOfWeights, 4)
 
148
      + " Precision = " + m_Precision + "\n";
 
149
  }
 
150
 
 
151
  /**
 
152
   * Returns default capabilities of the classifier.
 
153
   *
 
154
   * @return      the capabilities of this classifier
 
155
   */
 
156
  public Capabilities getCapabilities() {
 
157
    Capabilities result = super.getCapabilities();
 
158
    
 
159
    // attributes
 
160
    result.enable(Capability.NUMERIC_ATTRIBUTES);
 
161
    return result;
 
162
  }
 
163
 
 
164
  /**
 
165
   * Main method for testing this class.
 
166
   *
 
167
   * @param argv should contain a sequence of numeric values
 
168
   */
 
169
  public static void main(String [] argv) {
 
170
 
 
171
    try {
 
172
 
 
173
      if (argv.length == 0) {
 
174
        System.out.println("Please specify a set of instances.");
 
175
        return;
 
176
      }
 
177
      NormalEstimator newEst = new NormalEstimator(0.01);
 
178
      for(int i = 0; i < argv.length; i++) {
 
179
        double current = Double.valueOf(argv[i]).doubleValue();
 
180
        System.out.println(newEst);
 
181
        System.out.println("Prediction for " + current 
 
182
                           + " = " + newEst.getProbability(current));
 
183
        newEst.addValue(current, 1);
 
184
      }
 
185
    } catch (Exception e) {
 
186
      System.out.println(e.getMessage());
 
187
    }
 
188
  }
 
189
}
 
190
 
 
191
 
 
192
 
 
193
 
 
194
 
 
195
 
 
196
 
 
197