/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */

/*
 * MixtureDistribution.java
 * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 */
package weka.classifiers.functions.pace;
24
import weka.core.TechnicalInformation;
25
import weka.core.TechnicalInformation.Type;
26
import weka.core.TechnicalInformation.Field;
27
import weka.core.TechnicalInformationHandler;
28
import weka.core.matrix.DoubleVector;
29
import weka.core.matrix.IntVector;
/**
 * Abstract class for manipulating mixture distributions. <p>
 *
 * References: <p>
 *
 * Wang, Y. (2000). "A new approach to fitting linear models in high
 * dimensional spaces." PhD Thesis. Department of Computer Science,
 * University of Waikato, New Zealand. <p>
 *
 * Wang, Y. and Witten, I. H. (2002). "Modeling for optimal probability
 * prediction." Proceedings of ICML'2002. Sydney. <p>
 *
 * @author Yong Wang (yongwang@cs.waikato.ac.nz)
 * @version $Revision: 1.4 $ */
public abstract class MixtureDistribution
47
implements TechnicalInformationHandler {
49
protected DiscreteFunction mixingDistribution;
51
/** The nonnegative-measure-based method */
52
public static final int NNMMethod = 1;
54
/** The probability-measure-based method */
55
public static final int PMMethod = 2;
57
// The CDF-based method
58
// public static final int CDFMethod = 3;
60
// The method based on the Kolmogrov and von Mises measure
61
// public static final int ModifiedCDFMethod = 4;
64
* Returns an instance of a TechnicalInformation object, containing
65
* detailed information about the technical background of this class,
66
* e.g., paper reference or book this class is based on.
68
* @return the technical information about this class
70
public TechnicalInformation getTechnicalInformation() {
71
TechnicalInformation result;
72
TechnicalInformation additional;
74
result = new TechnicalInformation(Type.PHDTHESIS);
75
result.setValue(Field.AUTHOR, "Wang, Y");
76
result.setValue(Field.YEAR, "2000");
77
result.setValue(Field.TITLE, "A new approach to fitting linear models in high dimensional spaces");
78
result.setValue(Field.SCHOOL, "Department of Computer Science, University of Waikato");
79
result.setValue(Field.ADDRESS, "Hamilton, New Zealand");
81
additional = result.add(Type.INPROCEEDINGS);
82
additional.setValue(Field.AUTHOR, "Wang, Y. and Witten, I. H.");
83
additional.setValue(Field.YEAR, "2002");
84
additional.setValue(Field.TITLE, "Modeling for optimal probability prediction");
85
additional.setValue(Field.BOOKTITLE, "Proceedings of the Nineteenth International Conference in Machine Learning");
86
additional.setValue(Field.YEAR, "2002");
87
additional.setValue(Field.PAGES, "650-657");
88
additional.setValue(Field.ADDRESS, "Sydney, Australia");
94
* Gets the mixing distribution
96
* @return the mixing distribution
98
public DiscreteFunction getMixingDistribution() {
99
return mixingDistribution;
102
/** Sets the mixing distribution
103
* @param d the mixing distribution
105
public void setMixingDistribution( DiscreteFunction d ) {
106
mixingDistribution = d;
109
/** Fits the mixture (or mixing) distribution to the data. The default
110
* method is the nonnegative-measure-based method.
111
* @param data the data, supposedly generated from the mixture model */
112
public void fit( DoubleVector data ) {
113
fit( data, NNMMethod );
116
/**
 * Fits the mixture (or mixing) distribution to the data.
 *
 * A clone of the data is sorted and scanned for split points. After
 * position i the data are split when separable() holds both for
 * data2.get(i+1) against the current group [start, i] and for
 * data2.get(i) against the remainder [i+1, n-1]. Each group is fitted
 * with fitForSingleCluster(), scaled by the group's size, and summed
 * into d. d is finally stored in mixingDistribution.
 *
 * @param data the data supposedly generated from the mixture
 * @param method the method to be used. Refer to the static final
 *        variables of this class.
 */
120
public void fit( DoubleVector data, int method ) {
121
// Sort a clone rather than data itself (assumes DoubleVector.clone()
// copies the elements -- TODO confirm).
DoubleVector data2 = (DoubleVector) data.clone();
122
if( data2.unsorted() ) data2.sort();
124
int n = data2.size();
// NOTE(review): `start` and `subset` are used below but are not declared
// anywhere in this copy. Nothing advances `start` after a group is
// emitted, and the loop/if closing braces are absent. Several short lines
// appear to have been lost here -- TODO restore from upstream.
127
DiscreteFunction d = new DiscreteFunction();
128
for( int i = 0; i < n-1; i++ ) {
129
// Close the group [start, i] only if both sides of the gap between i
// and i+1 can be estimated separately.
if( separable( data2, start, i, data2.get(i+1) ) &&
130
separable( data2, i+1, n-1, data2.get(i) ) ) {
131
subset = (DoubleVector) data2.subvector( start, i );
132
// Weight the group's fitted distribution by its size, i - start + 1.
d.plusEquals( fitForSingleCluster( subset, method ).
133
timesEquals(i - start + 1) );
137
// The remaining elements start .. n-1 form the last group, weighted by
// its size n - start.
subset = (DoubleVector) data2.subvector( start, n-1 );
138
d.plusEquals( fitForSingleCluster( subset, method ).
139
timesEquals(n - start) );
142
// NOTE(review): lines between the last plusEquals and this assignment are
// missing in this copy. Whether d is sorted/normalized before being
// stored cannot be determined here -- TODO confirm against upstream.
mixingDistribution = d;
// Fits the mixing distribution to a single group of data, without further
// clustering. NOTE(review): the javadoc lines that follow have lost their
// opening and closing comment delimiters in this copy -- TODO restore.
146
* Fits the mixture (or mixing) distribution to the data. The data is
147
* not pre-clustered for computational efficiency.
149
* @param data the data supposedly generated from the mixture
150
* @param method the method to be used. Refer to the static final
151
* variables of this class.
152
* @return the generated distribution
154
public DiscreteFunction fitForSingleCluster( DoubleVector data,
// NOTE(review): the rest of the parameter list (presumably `int method ) {`,
// since `method` is documented above) is missing in this copy.
157
// Fewer than two points: build the result directly from the data.
if( data.size() < 2 ) return new DiscreteFunction( data );
158
// Support points, fitting intervals, and the probability matrix of the
// support points over those intervals (all supplied by the subclass).
DoubleVector sp = supportPoints( data, 0 );
159
PaceMatrix fi = fittingIntervals( data );
160
PaceMatrix pm = probabilityMatrix( sp, fi );
// NOTE(review): `epm` is used below but the start of its declaration
// (apparently `PaceMatrix epm = new`, continued on the next line) is
// missing in this copy.
162
// Empirical per-interval counts divided by the number of data points.
PaceMatrix( empiricalProbability( data, fi ).
163
timesEquals( 1. / data.size() ) );
165
// pvt starts as the indices 0 .. sp.size()-1 of all support points; it
// is passed to the solver below and later used to select entries of sp.
IntVector pvt = (IntVector) IntVector.seq(0, sp.size()-1);
166
DoubleVector weights;
// NOTE(review): the three statements below were presumably the arms of a
// switch on `method` (NNMMethod -> nnls, PMMethod -> nnlse1, otherwise
// throw). The switch/case/break lines are missing in this copy.
170
weights = pm.nnls( epm, pvt );
173
weights = pm.nnlse1( epm, pvt );
176
throw new IllegalArgumentException("unknown method");
179
// Keep only the support points whose indices are listed in pvt.
DoubleVector sp2 = new DoubleVector( pvt.size() );
180
for( int i = 0; i < sp2.size(); i++ ){
181
sp2.set( i, sp.get(pvt.get(i)) );
184
DiscreteFunction d = new DiscreteFunction( sp2, weights );
// NOTE(review): the loop's closing brace and the method's tail (at least
// a `return` of d, possibly other post-processing) are missing in this
// copy -- TODO restore from upstream.
191
* Return true if a value can be considered for mixture estimatino
192
* separately from the data indexed between i0 and i1
194
* @param data the data supposedly generated from the mixture
195
* @param i0 the index of the first element in the group
196
* @param i1 the index of the last element in the group
198
* @return true if a value can be considered
200
public abstract boolean separable( DoubleVector data,
201
int i0, int i1, double x );
204
* Contructs the set of support points for mixture estimation.
206
* @param data the data supposedly generated from the mixture
207
* @param ne the number of extra data that are suppposedly discarded
208
* earlier and not passed into here
209
* @return the set of support points
211
public abstract DoubleVector supportPoints( DoubleVector data, int ne );
214
* Contructs the set of fitting intervals for mixture estimation.
216
* @param data the data supposedly generated from the mixture
217
* @return the set of fitting intervals
219
public abstract PaceMatrix fittingIntervals( DoubleVector data );
222
* Contructs the probability matrix for mixture estimation, given a set
223
* of support points and a set of intervals.
225
* @param s the set of support points
226
* @param intervals the intervals
227
* @return the probability matrix
229
public abstract PaceMatrix probabilityMatrix( DoubleVector s,
230
PaceMatrix intervals );
233
* Computes the empirical probabilities of the data over a set of
236
* @param data the data
237
* @param intervals the intervals
238
* @return the empirical probabilities
240
public PaceMatrix empiricalProbability( DoubleVector data,
241
PaceMatrix intervals )
244
int k = intervals.getRowDimension();
245
PaceMatrix epm = new PaceMatrix( k, 1, 0 );
248
for( int j = 0; j < n; j ++ ) {
249
for(int i = 0; i < k; i++ ) {
251
if( intervals.get(i, 0) == data.get(j) ||
252
intervals.get(i, 1) == data.get(j) ) point = 0.5;
253
else if( intervals.get(i, 0) < data.get(j) &&
254
intervals.get(i, 1) > data.get(j) ) point = 1.0;
255
epm.setPlus( i, 0, point);
262
* Converts to a string
264
* @return a string representation
266
public String toString()
268
return "The mixing distribution:\n" + mixingDistribution.toString();