2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* NormalEstimator.java
19
* Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
23
package weka.estimators;
25
import weka.core.Capabilities.Capability;
26
import weka.core.Capabilities;
27
import weka.core.Statistics;
28
import weka.core.Utils;
/**
 * Simple probability estimator that places a single normal distribution
 * over the observed values.
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.8 $
 */
37
public class NormalEstimator
39
implements IncrementalEstimator {
41
/** for serialization */
42
private static final long serialVersionUID = 93584379632315841L;
44
/** The sum of the weights */
45
private double m_SumOfWeights;
47
/** The sum of the values seen */
48
private double m_SumOfValues;
50
/** The sum of the values squared */
51
private double m_SumOfValuesSq;
53
/** The current mean */
54
private double m_Mean;
56
/** The current standard deviation */
57
private double m_StandardDev;
59
/** The precision of numeric values ( = minimum std dev permitted) */
60
private double m_Precision;
63
* Round a data value using the defined precision for this estimator
65
* @param data the value to round
66
* @return the rounded data value
68
private double round(double data) {
70
return Math.rint(data / m_Precision) * m_Precision;
78
* Constructor that takes a precision argument.
80
* @param precision the precision to which numeric values are given. For
81
* example, if the precision is stated to be 0.1, the values in the
82
* interval (0.25,0.35] are all treated as 0.3.
84
public NormalEstimator(double precision) {
86
m_Precision = precision;
88
// Allow at most 3 sd's within one interval
89
m_StandardDev = m_Precision / (2 * 3);
93
* Add a new data value to the current estimator.
95
* @param data the new data value
96
* @param weight the weight assigned to the data value
98
public void addValue(double data, double weight) {
104
m_SumOfWeights += weight;
105
m_SumOfValues += data * weight;
106
m_SumOfValuesSq += data * data * weight;
108
if (m_SumOfWeights > 0) {
109
m_Mean = m_SumOfValues / m_SumOfWeights;
110
double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq
111
- m_Mean * m_SumOfValues)
113
// If the stdDev ~= 0, we really have no idea of scale yet,
114
// so stick with the default. Otherwise...
115
if (stdDev > 1e-10) {
116
m_StandardDev = Math.max(m_Precision / (2 * 3),
117
// allow at most 3sd's within one interval
124
* Get a probability estimate for a value
126
* @param data the value to estimate the probability of
127
* @return the estimated probability of the supplied value
129
public double getProbability(double data) {
132
double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;
133
double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;
135
double pLower = Statistics.normalProbability(zLower);
136
double pUpper = Statistics.normalProbability(zUpper);
137
return pUpper - pLower;
141
* Display a representation of this estimator
143
public String toString() {
145
return "Normal Distribution. Mean = " + Utils.doubleToString(m_Mean, 4)
146
+ " StandardDev = " + Utils.doubleToString(m_StandardDev, 4)
147
+ " WeightSum = " + Utils.doubleToString(m_SumOfWeights, 4)
148
+ " Precision = " + m_Precision + "\n";
152
* Returns default capabilities of the classifier.
154
* @return the capabilities of this classifier
156
public Capabilities getCapabilities() {
157
Capabilities result = super.getCapabilities();
160
result.enable(Capability.NUMERIC_ATTRIBUTES);
165
* Main method for testing this class.
167
* @param argv should contain a sequence of numeric values
169
public static void main(String [] argv) {
173
if (argv.length == 0) {
174
System.out.println("Please specify a set of instances.");
177
NormalEstimator newEst = new NormalEstimator(0.01);
178
for(int i = 0; i < argv.length; i++) {
179
double current = Double.valueOf(argv[i]).doubleValue();
180
System.out.println(newEst);
181
System.out.println("Prediction for " + current
182
+ " = " + newEst.getProbability(current));
183
newEst.addValue(current, 1);
185
} catch (Exception e) {
186
System.out.println(e.getMessage());