2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20
* Algorithm developed by: Fei ZHENG and Geoff Webb
21
* Code written by: Fei ZHENG and Janice Boughton
24
package weka.classifiers.bayes;
26
import weka.classifiers.Classifier;
27
import weka.classifiers.UpdateableClassifier;
28
import weka.core.Capabilities;
29
import weka.core.Instance;
30
import weka.core.Instances;
31
import weka.core.Option;
32
import weka.core.OptionHandler;
33
import weka.core.TechnicalInformation;
34
import weka.core.TechnicalInformationHandler;
35
import weka.core.Utils;
36
import weka.core.WeightedInstancesHandler;
37
import weka.core.Capabilities.Capability;
38
import weka.core.TechnicalInformation.Field;
39
import weka.core.TechnicalInformation.Type;
41
import java.util.Enumeration;
42
import java.util.Vector;
46
<!-- globalinfo-start -->
47
* AODEsr augments AODE with Subsumption Resolution.
48
* AODEsr detects specializations between two attribute values at
49
* classification time and deletes the generalization attribute value.
51
* For more information, see<br/>
53
* Zheng, F., Webb, G.I. (2006): Efficient lazy elimination for
54
* averaged-one dependence
55
* estimators. In: Proc. 23rd Int. Conf. Machine Learning (ICML 2006),
58
* Note: the subsumption resolution technique is called lazy elimination
61
<!-- globalinfo-end -->
63
<!-- technical-bibtex-start -->
66
* @INPROCEEDINGS{ZhengWebbICML2006,
67
* AUTHOR = {Fei Zheng and Geoffrey I. Webb},
68
* TITLE = {Efficient Lazy Elimination for Averaged-One Dependence
70
* BOOKTITLE = {Proceedings of the Twenty-third International
71
* Conference on Machine Learning (ICML 2006)},
72
* ISBN = {1-59593-383-2},
73
* PAGES = {1113--1120},
74
* PUBLISHER = {ACM Press},
79
<!-- technical-bibtex-end -->
81
<!-- options-start -->
82
* Valid options are:<p/>
85
* Output debugging information
88
* <pre> -F <int>
89
* Impose a frequency limit for superParents
90
* (default is 1)</pre>
93
* Use Laplace estimation
94
* (default is m-estimation)</pre>
96
* <pre> -M <double>
97
* Specify the m value of m-estimation
98
* (default is 1)</pre>
100
* <pre>-C <int>
101
* Specify critical value for specialization-generalization.
103
* Larger values than the default of 50 substantially reduce
104
* the risk of incorrectly inferring that one value subsumes
105
* another, but also reduces the number of true subsumptions
106
* that are detected.</pre>
111
* @author Janice Boughton
112
* @version $Revision: 1.2 $
114
public class AODEsr extends Classifier
115
implements OptionHandler, WeightedInstancesHandler, UpdateableClassifier,
116
TechnicalInformationHandler {
118
/** for serialization */
119
static final long serialVersionUID = 5602143019183068848L;
122
* 3D array (m_NumClasses * m_TotalAttValues * m_TotalAttValues)
123
* of attribute counts, i.e. the number of times an attribute value occurs
124
* in conjunction with another attribute value and a class value.
126
private double [][][] m_CondiCounts;
129
* 2D array (m_TotalAttValues * m_TotalAttValues) of attributes counts.
130
* similar to m_CondiCounts, but ignoring class value.
132
private double [][] m_CondiCountsNoClass;
134
/** The number of times each class value occurs in the dataset */
135
private double [] m_ClassCounts;
137
/** The sums of attribute-class counts
138
* -- if there are no missing values for att, then
139
* m_SumForCounts[classVal][att] will be the same as
140
* m_ClassCounts[classVal]
142
private double [][] m_SumForCounts;
144
/** The number of classes */
145
private int m_NumClasses;
147
/** The number of attributes in dataset, including class */
148
private int m_NumAttributes;
150
/** The number of instances in the dataset */
151
private int m_NumInstances;
153
/** The index of the class attribute */
154
private int m_ClassIndex;
157
private Instances m_Instances;
160
* The total number of values (including an extra for each attribute's
161
* missing value, which are included in m_CondiCounts) for all attributes
162
* (not including class). Eg. for three atts each with two possible values,
163
* m_TotalAttValues would be 9 (6 values + 3 missing).
164
* This variable is used when allocating space for m_CondiCounts matrix.
166
private int m_TotalAttValues;
168
/** The starting index (in the m_CondiCounts matrix) of the values for each attribute */
169
private int [] m_StartAttIndex;
171
/** The number of values for each attribute */
172
private int [] m_NumAttValues;
174
/** The frequency of each attribute value for the dataset */
175
private double [] m_Frequencies;
177
/** The number of valid class values observed in dataset
178
* -- with no missing classes, this number is the same as m_NumInstances.
180
private double m_SumInstances;
182
/** An att's frequency must be this value or more to be a superParent */
183
private int m_Limit = 1;
185
/** If true, outputs debugging info */
186
private boolean m_Debug = false;
188
/** m value for m-estimation */
189
protected double m_MWeight = 1.0;
191
/** Using LapLace estimation or not*/
192
private boolean m_Laplace = false;
194
/** the critical value for the specialization-generalization */
195
private int m_Critical = 50;
199
* Returns a string describing this classifier
200
* @return a description of the classifier suitable for
201
* displaying in the explorer/experimenter gui
203
public String globalInfo() {
205
return "AODEsr augments AODE with Subsumption Resolution."
206
+"AODEsr detects specializations between two attribute "
207
+"values at classification time and deletes the generalization "
208
+"attribute value.\n"
209
+"For more information, see:\n"
210
+"Zheng, F., Webb, G.I. (2006): Efficient lazy elimination for "
211
+"averaged-one dependence "
212
+"estimators. In: Proc. 23th Int. Conf. Machine Learning (ICML 2006), "
217
* Returns an instance of a TechnicalInformation object, containing
218
* detailed information about the technical background of this class,
219
* e.g., paper reference or book this class is based on.
221
* @return the technical information about this class
223
public TechnicalInformation getTechnicalInformation() {
224
TechnicalInformation result;
226
result = new TechnicalInformation(Type.INPROCEEDINGS);
227
result.setValue(Field.AUTHOR, "Fei Zheng and Geoffrey I. Webb");
228
result.setValue(Field.YEAR, "2006");
229
result.setValue(Field.TITLE, "Efficient Lazy Elimination for Averaged-One Dependence Estimators");
230
result.setValue(Field.PAGES, "1113-1120");
231
result.setValue(Field.BOOKTITLE, "Proceedings of the Twenty-third International Conference on Machine Learning (ICML 2006)");
232
result.setValue(Field.PUBLISHER, "ACM Press");
233
result.setValue(Field.ISBN, "1-59593-383-2");
239
* Returns default capabilities of the classifier.
241
* @return the capabilities of this classifier
243
public Capabilities getCapabilities() {
244
Capabilities result = super.getCapabilities();
247
result.enable(Capability.NOMINAL_ATTRIBUTES);
248
result.enable(Capability.MISSING_VALUES);
251
result.enable(Capability.NOMINAL_CLASS);
252
result.enable(Capability.MISSING_CLASS_VALUES);
255
result.setMinimumNumberInstances(0);
261
* Generates the classifier.
263
* @param instances set of instances serving as training data
264
* @throws Exception if the classifier has not been generated
267
public void buildClassifier(Instances instances) throws Exception {
269
// can classifier handle the data?
270
getCapabilities().testWithFail(instances);
272
// remove instances with missing class
273
m_Instances = new Instances(instances);
274
m_Instances.deleteWithMissingClass();
276
// reset variable for this fold
278
m_ClassIndex = instances.classIndex();
279
m_NumInstances = m_Instances.numInstances();
280
m_NumAttributes = instances.numAttributes();
281
m_NumClasses = instances.numClasses();
283
// allocate space for attribute reference arrays
284
m_StartAttIndex = new int[m_NumAttributes];
285
m_NumAttValues = new int[m_NumAttributes];
287
m_TotalAttValues = 0;
288
for(int i = 0; i < m_NumAttributes; i++) {
289
if(i != m_ClassIndex) {
290
m_StartAttIndex[i] = m_TotalAttValues;
291
m_NumAttValues[i] = m_Instances.attribute(i).numValues();
292
m_TotalAttValues += m_NumAttValues[i] + 1;
293
// + 1 so room for missing value count
295
// m_StartAttIndex[i] = -1; // class isn't included
296
m_NumAttValues[i] = m_NumClasses;
300
// allocate space for counts and frequencies
301
m_CondiCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
302
m_ClassCounts = new double[m_NumClasses];
303
m_SumForCounts = new double[m_NumClasses][m_NumAttributes];
304
m_Frequencies = new double[m_TotalAttValues];
305
m_CondiCountsNoClass = new double[m_TotalAttValues][m_TotalAttValues];
307
// calculate the counts
308
for(int k = 0; k < m_NumInstances; k++) {
309
addToCounts((Instance)m_Instances.instance(k));
312
// free up some space
313
m_Instances = new Instances(m_Instances, 0);
318
* Updates the classifier with the given instance.
320
* @param instance the new training instance to include in the model
321
* @throws Exception if the instance could not be incorporated in
324
public void updateClassifier(Instance instance) {
325
this.addToCounts(instance);
329
* Puts an instance's values into m_CondiCounts, m_ClassCounts and
332
* @param instance the instance whose values are to be put into the
335
private void addToCounts(Instance instance) {
337
double [] countsPointer;
338
double [] countsNoClassPointer;
340
if(instance.classIsMissing())
341
return; // ignore instances with missing class
343
int classVal = (int)instance.classValue();
344
double weight = instance.weight();
346
m_ClassCounts[classVal] += weight;
347
m_SumInstances += weight;
349
// store instance's att val indexes in an array, b/c accessing it
350
// in loop(s) is more efficient
351
int [] attIndex = new int[m_NumAttributes];
352
for(int i = 0; i < m_NumAttributes; i++) {
353
if(i == m_ClassIndex)
354
attIndex[i] = -1; // we don't use the class attribute in counts
356
if(instance.isMissing(i))
357
attIndex[i] = m_StartAttIndex[i] + m_NumAttValues[i];
359
attIndex[i] = m_StartAttIndex[i] + (int)instance.value(i);
363
for(int Att1 = 0; Att1 < m_NumAttributes; Att1++) {
364
if(attIndex[Att1] == -1)
365
continue; // avoid pointless looping as Att1 is currently the class attribute
367
m_Frequencies[attIndex[Att1]] += weight;
369
// if this is a missing value, we don't want to increase sumforcounts
370
if(!instance.isMissing(Att1))
371
m_SumForCounts[classVal][Att1] += weight;
373
// save time by referencing this now, rather than repeatedly in the loop
374
countsPointer = m_CondiCounts[classVal][attIndex[Att1]];
375
countsNoClassPointer = m_CondiCountsNoClass[attIndex[Att1]];
377
for(int Att2 = 0; Att2 < m_NumAttributes; Att2++) {
378
if(attIndex[Att2] != -1) {
379
countsPointer[attIndex[Att2]] += weight;
380
countsNoClassPointer[attIndex[Att2]] += weight;
388
* Calculates the class membership probabilities for the given test
391
* @param instance the instance to be classified
392
* @return predicted class probability distribution
393
* @throws Exception if there is a problem generating the prediction
395
public double [] distributionForInstance(Instance instance) throws Exception {
397
// accumulates posterior probabilities for each class
398
double [] probs = new double[m_NumClasses];
400
// index for parent attribute value, and a count of parents used
401
int pIndex, parentCount;
403
int [] SpecialGeneralArray = new int[m_NumAttributes];
405
// pointers for efficiency
406
double [][] countsForClass;
407
double [] countsForClassParent;
408
double [] countsForAtti;
409
double [] countsForAttj;
411
// store instance's att values in an int array, so accessing them
412
// is more efficient in loop(s).
413
int [] attIndex = new int[m_NumAttributes];
414
for(int att = 0; att < m_NumAttributes; att++) {
415
if(instance.isMissing(att) || att == m_ClassIndex)
416
attIndex[att] = -1; // can't use class & missing vals in calculations
418
attIndex[att] = m_StartAttIndex[att] + (int)instance.value(att);
420
// -1 indicates attribute is not a generalization of any other attributes
421
for(int i = 0; i < m_NumAttributes; i++) {
422
SpecialGeneralArray[i] = -1;
425
// calculate the specialization-generalization array
426
for(int i = 0; i < m_NumAttributes; i++){
427
// skip i if it's the class or is missing
428
if(attIndex[i] == -1) continue;
429
countsForAtti = m_CondiCountsNoClass[attIndex[i]];
431
for(int j = 0; j < m_NumAttributes; j++) {
432
// skip j if it's the class, missing, is i or a generalization of i
433
if((attIndex[j] == -1) || (i == j) || (SpecialGeneralArray[j] == i))
436
countsForAttj = m_CondiCountsNoClass[attIndex[j]];
438
// check j's frequency is above critical value
439
if(countsForAttj[attIndex[j]] > m_Critical) {
441
// skip j if the frequency of i and j together is not equivalent
442
// to the frequency of j alone
443
if(countsForAttj[attIndex[j]] == countsForAtti[attIndex[j]]) {
445
// if attributes i and j are both a specialization of each other
446
// avoid deleting both by skipping j
447
if((countsForAttj[attIndex[j]] == countsForAtti[attIndex[i]])
451
// set the specialization relationship
452
SpecialGeneralArray[i] = j;
453
break; // break out of j loop because a specialization has been found
460
// calculate probabilities for each possible class value
461
for(int classVal = 0; classVal < m_NumClasses; classVal++) {
467
countsForClass = m_CondiCounts[classVal];
469
// each attribute has a turn of being the parent
470
for(int parent = 0; parent < m_NumAttributes; parent++) {
471
if(attIndex[parent] == -1)
472
continue; // skip class attribute or missing value
474
// determine correct index for the parent in m_CondiCounts matrix
475
pIndex = attIndex[parent];
477
// check that the att value has a frequency of m_Limit or greater
478
if(m_Frequencies[pIndex] < m_Limit)
481
// delete the generalization attributes.
482
if(SpecialGeneralArray[parent] != -1)
485
countsForClassParent = countsForClass[pIndex];
487
// block the parent from being its own child
488
attIndex[parent] = -1;
492
double classparentfreq = countsForClassParent[pIndex];
494
// find the number of missing values for parent's attribute
495
double missing4ParentAtt =
496
m_Frequencies[m_StartAttIndex[parent] + m_NumAttValues[parent]];
498
// calculate the prior probability -- P(parent & classVal)
500
x = LaplaceEstimate(classparentfreq, m_SumInstances - missing4ParentAtt,
501
m_NumClasses * m_NumAttValues[parent]);
504
x = MEstimate(classparentfreq, m_SumInstances - missing4ParentAtt,
505
m_NumClasses * m_NumAttValues[parent]);
510
// take into account the value of each attribute
511
for(int att = 0; att < m_NumAttributes; att++) {
512
if(attIndex[att] == -1) // skip class attribute or missing value
514
// delete the generalization attributes.
515
if(SpecialGeneralArray[att] != -1)
519
double missingForParentandChildAtt =
520
countsForClassParent[m_StartAttIndex[att] + m_NumAttValues[att]];
523
x *= LaplaceEstimate(countsForClassParent[attIndex[att]],
524
classparentfreq - missingForParentandChildAtt, m_NumAttValues[att]);
526
x *= MEstimate(countsForClassParent[attIndex[att]],
527
classparentfreq - missingForParentandChildAtt, m_NumAttValues[att]);
531
// add this probability to the overall probability
532
probs[classVal] += x;
534
// unblock the parent
535
attIndex[parent] = pIndex;
538
// check that at least one att was a parent
539
if(parentCount < 1) {
541
// do plain naive bayes conditional prob
542
probs[classVal] = NBconditionalProb(instance, classVal);
543
//probs[classVal] = Double.NaN;
547
// divide by number of parent atts to get the mean
548
probs[classVal] /= (double)(parentCount);
551
Utils.normalize(probs);
557
* Calculates the probability of the specified class for the given test
558
* instance, using naive Bayes.
560
* @param instance the instance to be classified
561
* @param classVal the class for which to calculate the probability
562
* @return predicted class probability
563
* @throws Exception if there is a problem generating the prediction
565
public double NBconditionalProb(Instance instance, int classVal)
571
// calculate the prior probability
573
prob = LaplaceEstimate(m_ClassCounts[classVal],m_SumInstances,m_NumClasses);
575
prob = MEstimate(m_ClassCounts[classVal], m_SumInstances, m_NumClasses);
577
pointer = m_CondiCounts[classVal];
579
// consider effect of each att value
580
for(int att = 0; att < m_NumAttributes; att++) {
581
if(att == m_ClassIndex || instance.isMissing(att))
584
// determine correct index for att in m_CondiCounts
585
attIndex = m_StartAttIndex[att] + (int)instance.value(att);
587
prob *= LaplaceEstimate((double)pointer[attIndex][attIndex],
588
(double)m_SumForCounts[classVal][att], m_NumAttValues[att]);
590
prob *= MEstimate((double)pointer[attIndex][attIndex],
591
(double)m_SumForCounts[classVal][att], m_NumAttValues[att]);
599
* Returns the probability estimate, using m-estimate
601
* @param frequency frequency of value of interest
602
* @param total count of all values
603
* @param numValues number of different values
604
* @return the probability estimate
606
public double MEstimate(double frequency, double total,
609
return (frequency + m_MWeight / numValues) / (total + m_MWeight);
613
* Returns the probability estimate, using laplace correction
615
* @param frequency frequency of value of interest
616
* @param total count of all values
617
* @param numValues number of different values
618
* @return the probability estimate
620
public double LaplaceEstimate(double frequency, double total,
623
return (frequency + 1.0) / (total + numValues);
628
* Returns an enumeration describing the available options
630
* @return an enumeration of all the available options
632
public Enumeration listOptions() {
634
Vector newVector = new Vector(5);
636
newVector.addElement(
637
new Option("\tOutput debugging information\n",
639
newVector.addElement(
640
new Option("\tImpose a critcal value for specialization-generalization relationship\n"
641
+ "\t(default is 50)", "C", 1,"-C"));
642
newVector.addElement(
643
new Option("\tImpose a frequency limit for superParents\n"
644
+ "\t(default is 1)", "F", 2,"-F"));
645
newVector.addElement(
646
new Option("\tUsing Laplace estimation\n"
647
+ "\t(default is m-esimation (m=1))",
649
newVector.addElement(
650
new Option("\tWeight value for m-estimation\n"
651
+ "\t(default is 1.0)", "M", 4,"-M"));
653
return newVector.elements();
658
* Parses a given list of options. <p/>
660
<!-- options-start -->
661
* Valid options are:<p/>
664
* Output debugging information
667
* <pre> -F <int>
668
* Impose a frequency limit for superParents
669
* (default is 1)</pre>
672
* Use Laplace estimation
673
* (default is m-estimation)</pre>
675
* <pre> -M <double>
676
* Specify the m value of m-estimation
677
* (default is 1)</pre>
679
* <pre>-C <int>
680
* Specify critical value for specialization-generalization.
682
* Larger values than the default of 50 substantially reduce
683
* the risk of incorrectly inferring that one value subsumes
684
* another, but also reduces the number of true subsumptions
685
* that are detected.</pre>
689
* @param options the list of options as an array of strings
690
* @throws Exception if an option is not supported
692
public void setOptions(String[] options) throws Exception {
694
m_Debug = Utils.getFlag('D', options);
696
String Critical = Utils.getOption('C', options);
697
if(Critical.length() != 0)
698
m_Critical = Integer.parseInt(Critical);
702
String Freq = Utils.getOption('F', options);
703
if(Freq.length() != 0)
704
m_Limit = Integer.parseInt(Freq);
708
m_Laplace = Utils.getFlag('L', options);
709
String MWeight = Utils.getOption('M', options);
710
if(MWeight.length() != 0) {
712
throw new Exception("weight for m-estimate is pointless if using laplace estimation!");
713
m_MWeight = Double.parseDouble(MWeight);
717
Utils.checkForRemainingOptions(options);
721
* Gets the current settings of the classifier.
723
* @return an array of strings suitable for passing to setOptions
725
public String [] getOptions() {
727
Vector result = new Vector();
733
result.add("" + m_Limit);
739
result.add("" + m_MWeight);
743
result.add("" + m_Critical);
745
return (String[]) result.toArray(new String[result.size()]);
749
* Returns the tip text for this property
750
* @return tip text for this property suitable for
751
* displaying in the explorer/experimenter gui
753
public String mestWeightTipText() {
754
return "Set the weight for m-estimate.";
758
* Sets the weight for m-estimate
760
* @param w the weight
762
public void setMestWeight(double w) {
763
if (getUseLaplace()) {
765
"Weight is only used in conjunction with m-estimate - ignored!");
770
System.out.println("M-Estimate Weight must be greater than 0!");
775
* Gets the weight used in m-estimate
777
* @return the weight for m-estimation
779
public double getMestWeight() {
784
* Returns the tip text for this property
785
* @return tip text for this property suitable for
786
* displaying in the explorer/experimenter gui
788
public String useLaplaceTipText() {
789
return "Use Laplace correction instead of m-estimation.";
793
* Gets if laplace correction is being used.
795
* @return Value of m_Laplace.
797
public boolean getUseLaplace() {
802
* Sets if laplace correction is to be used.
804
* @param value Value to assign to m_Laplace.
806
public void setUseLaplace(boolean value) {
811
* Returns the tip text for this property
812
* @return tip text for this property suitable for
813
* displaying in the explorer/experimenter gui
815
public String frequencyLimitTipText() {
816
return "Attributes with a frequency in the train set below "
817
+ "this value aren't used as parents.";
821
* Sets the frequency limit
823
* @param f the frequency limit
825
public void setFrequencyLimit(int f) {
830
* Gets the frequency limit.
832
* @return the frequency limit
834
public int getFrequencyLimit() {
839
* Returns the tip text for this property
840
* @return tip text for this property suitable for
841
* displaying in the explorer/experimenter gui
843
public String criticalValueTipText() {
844
return "Specify critical value for specialization-generalization "
845
+ "relationship (default 50).";
849
* Sets the critical value
851
* @param c the critical value
853
public void setCriticalValue(int c) {
858
* Gets the critical value.
860
* @return the critical value
862
public int getCriticalValue() {
867
* Returns a description of the classifier.
869
* @return a description of the classifier as a string.
871
public String toString() {
873
StringBuffer text = new StringBuffer();
875
text.append("The AODEsr Classifier");
876
if (m_Instances == null) {
877
text.append(": No model built yet.");
880
for (int i = 0; i < m_NumClasses; i++) {
881
// print to string, the prior probabilities of class values
882
text.append("\nClass " + m_Instances.classAttribute().value(i) +
883
": Prior probability = " + Utils.
884
doubleToString(((m_ClassCounts[i] + 1)
885
/(m_SumInstances + m_NumClasses)), 4, 2)+"\n\n");
888
text.append("Dataset: " + m_Instances.relationName() + "\n"
889
+ "Instances: " + m_NumInstances + "\n"
890
+ "Attributes: " + m_NumAttributes + "\n"
891
+ "Frequency limit for superParents: " + m_Limit + "\n"
892
+ "Critical value for the specializtion-generalization "
893
+ "relationship: " + m_Critical + "\n");
895
text.append("Using LapLace estimation.");
897
text.append("Using m-estimation, m = " + m_MWeight);
899
} catch (Exception ex) {
900
text.append(ex.getMessage());
903
return text.toString();
908
* Main method for testing this class.
910
* @param argv the options
912
public static void main(String [] argv) {
913
runClassifier(new AODEsr(), argv);