/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * BVDecompose.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 */
23
package weka.classifiers;
25
import weka.core.Attribute;
26
import weka.core.Instance;
27
import weka.core.Instances;
28
import weka.core.Option;
29
import weka.core.OptionHandler;
30
import weka.core.TechnicalInformation;
31
import weka.core.TechnicalInformation.Type;
32
import weka.core.TechnicalInformation.Field;
33
import weka.core.TechnicalInformationHandler;
34
import weka.core.Utils;
36
import java.io.BufferedReader;
37
import java.io.FileReader;
38
import java.io.Reader;
39
import java.util.Enumeration;
40
import java.util.Random;
41
import java.util.Vector;
44
/**
 <!-- globalinfo-start -->
 * Class for performing a Bias-Variance decomposition on any classifier using the method specified in:<br/>
 * <br/>
 * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Kohavi1996,
 *    author = {Ron Kohavi and David H. Wolpert},
 *    booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
 *    editor = {Lorenza Saitta},
 *    pages = {275-283},
 *    publisher = {Morgan Kaufmann},
 *    title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
 *    year = {1996},
 *    PS = {http://robotics.stanford.edu/\~ronnyk/biasVar.ps}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -c &lt;class index&gt;
 *  The index of the class attribute.
 *  (default last)</pre>
 *
 * <pre> -t &lt;name of arff file&gt;
 *  The name of the arff file used for the decomposition.</pre>
 *
 * <pre> -T &lt;training pool size&gt;
 *  The number of instances placed in the training pool.
 *  The remainder will be used for testing. (default 100)</pre>
 *
 * <pre> -s &lt;seed&gt;
 *  The random number seed used.</pre>
 *
 * <pre> -x &lt;num&gt;
 *  The number of training repetitions used.
 *  (default 50)</pre>
 *
 * <pre> -D
 *  Turn on debugging output.</pre>
 *
 * <pre> -W &lt;classifier class name&gt;
 *  Full class name of the learner used in the decomposition.
 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 *
 * <pre>
 * Options specific to learner weka.classifiers.rules.ZeroR:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * Options after -- are passed to the designated sub-learner. <p>
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.14 $
 */
111
public class BVDecompose
112
implements OptionHandler, TechnicalInformationHandler {
114
/** Debugging mode, gives extra output if true */
115
protected boolean m_Debug;
117
/** An instantiated base classifier used for getting and testing options. */
118
protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
120
/** The options to be passed to the base classifier. */
121
protected String [] m_ClassifierOptions;
123
/** The number of train iterations */
124
protected int m_TrainIterations = 50;
126
/** The name of the data file used for the decomposition */
127
protected String m_DataFileName;
129
/** The index of the class attribute */
130
protected int m_ClassIndex = -1;
132
/** The random number seed */
133
protected int m_Seed = 1;
135
/** The calculated bias (squared) */
136
protected double m_Bias;
138
/** The calculated variance */
139
protected double m_Variance;
141
/** The calculated sigma (squared) */
142
protected double m_Sigma;
144
/** The error rate */
145
protected double m_Error;
147
/** The number of instances used in the training pool */
148
protected int m_TrainPoolSize = 100;
151
* Returns a string describing this object
152
* @return a description of the classifier suitable for
153
* displaying in the explorer/experimenter gui
155
public String globalInfo() {
158
"Class for performing a Bias-Variance decomposition on any classifier "
159
+ "using the method specified in:\n\n"
160
+ getTechnicalInformation().toString();
164
* Returns an instance of a TechnicalInformation object, containing
165
* detailed information about the technical background of this class,
166
* e.g., paper reference or book this class is based on.
168
* @return the technical information about this class
170
public TechnicalInformation getTechnicalInformation() {
171
TechnicalInformation result;
173
result = new TechnicalInformation(Type.INPROCEEDINGS);
174
result.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
175
result.setValue(Field.YEAR, "1996");
176
result.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
177
result.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
178
result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
179
result.setValue(Field.EDITOR, "Lorenza Saitta");
180
result.setValue(Field.PAGES, "275-283");
181
result.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
187
* Returns an enumeration describing the available options.
189
* @return an enumeration of all the available options.
191
public Enumeration listOptions() {
193
Vector newVector = new Vector(7);
195
newVector.addElement(new Option(
196
"\tThe index of the class attribute.\n"+
198
"c", 1, "-c <class index>"));
199
newVector.addElement(new Option(
200
"\tThe name of the arff file used for the decomposition.",
201
"t", 1, "-t <name of arff file>"));
202
newVector.addElement(new Option(
203
"\tThe number of instances placed in the training pool.\n"
204
+ "\tThe remainder will be used for testing. (default 100)",
205
"T", 1, "-T <training pool size>"));
206
newVector.addElement(new Option(
207
"\tThe random number seed used.",
208
"s", 1, "-s <seed>"));
209
newVector.addElement(new Option(
210
"\tThe number of training repetitions used.\n"
212
"x", 1, "-x <num>"));
213
newVector.addElement(new Option(
214
"\tTurn on debugging output.",
216
newVector.addElement(new Option(
217
"\tFull class name of the learner used in the decomposition.\n"
218
+"\teg: weka.classifiers.bayes.NaiveBayes",
219
"W", 1, "-W <classifier class name>"));
221
if ((m_Classifier != null) &&
222
(m_Classifier instanceof OptionHandler)) {
223
newVector.addElement(new Option(
225
"", 0, "\nOptions specific to learner "
226
+ m_Classifier.getClass().getName()
228
Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
229
while (enu.hasMoreElements()) {
230
newVector.addElement(enu.nextElement());
233
return newVector.elements();
237
* Parses a given list of options. <p/>
239
<!-- options-start -->
240
* Valid options are: <p/>
242
* <pre> -c <class index>
243
* The index of the class attribute.
244
* (default last)</pre>
246
* <pre> -t <name of arff file>
247
* The name of the arff file used for the decomposition.</pre>
249
* <pre> -T <training pool size>
250
* The number of instances placed in the training pool.
251
* The remainder will be used for testing. (default 100)</pre>
253
* <pre> -s <seed>
254
* The random number seed used.</pre>
256
* <pre> -x <num>
257
* The number of training repetitions used.
261
* Turn on debugging output.</pre>
263
* <pre> -W <classifier class name>
264
* Full class name of the learner used in the decomposition.
265
* eg: weka.classifiers.bayes.NaiveBayes</pre>
268
* Options specific to learner weka.classifiers.rules.ZeroR:
272
* If set, classifier is run in debug mode and
273
* may output additional info to the console</pre>
277
* Options after -- are passed to the designated sub-learner. <p>
279
* @param options the list of options as an array of strings
280
* @throws Exception if an option is not supported
282
public void setOptions(String[] options) throws Exception {
284
setDebug(Utils.getFlag('D', options));
286
String classIndex = Utils.getOption('c', options);
287
if (classIndex.length() != 0) {
288
if (classIndex.toLowerCase().equals("last")) {
290
} else if (classIndex.toLowerCase().equals("first")) {
293
setClassIndex(Integer.parseInt(classIndex));
299
String trainIterations = Utils.getOption('x', options);
300
if (trainIterations.length() != 0) {
301
setTrainIterations(Integer.parseInt(trainIterations));
303
setTrainIterations(50);
306
String trainPoolSize = Utils.getOption('T', options);
307
if (trainPoolSize.length() != 0) {
308
setTrainPoolSize(Integer.parseInt(trainPoolSize));
310
setTrainPoolSize(100);
313
String seedString = Utils.getOption('s', options);
314
if (seedString.length() != 0) {
315
setSeed(Integer.parseInt(seedString));
320
String dataFile = Utils.getOption('t', options);
321
if (dataFile.length() == 0) {
322
throw new Exception("An arff file must be specified"
323
+ " with the -t option.");
325
setDataFileName(dataFile);
327
String classifierName = Utils.getOption('W', options);
328
if (classifierName.length() == 0) {
329
throw new Exception("A learner must be specified with the -W option.");
331
setClassifier(Classifier.forName(classifierName,
332
Utils.partitionOptions(options)));
336
* Gets the current settings of the CheckClassifier.
338
* @return an array of strings suitable for passing to setOptions
340
public String [] getOptions() {
342
String [] classifierOptions = new String [0];
343
if ((m_Classifier != null) &&
344
(m_Classifier instanceof OptionHandler)) {
345
classifierOptions = ((OptionHandler)m_Classifier).getOptions();
347
String [] options = new String [classifierOptions.length + 14];
350
options[current++] = "-D";
352
options[current++] = "-c"; options[current++] = "" + getClassIndex();
353
options[current++] = "-x"; options[current++] = "" + getTrainIterations();
354
options[current++] = "-T"; options[current++] = "" + getTrainPoolSize();
355
options[current++] = "-s"; options[current++] = "" + getSeed();
356
if (getDataFileName() != null) {
357
options[current++] = "-t"; options[current++] = "" + getDataFileName();
359
if (getClassifier() != null) {
360
options[current++] = "-W";
361
options[current++] = getClassifier().getClass().getName();
363
options[current++] = "--";
364
System.arraycopy(classifierOptions, 0, options, current,
365
classifierOptions.length);
366
current += classifierOptions.length;
367
while (current < options.length) {
368
options[current++] = "";
374
* Get the number of instances in the training pool.
376
* @return number of instances in the training pool.
378
public int getTrainPoolSize() {
380
return m_TrainPoolSize;
384
* Set the number of instances in the training pool.
386
* @param numTrain number of instances in the training pool.
388
public void setTrainPoolSize(int numTrain) {
390
m_TrainPoolSize = numTrain;
394
* Set the classifiers being analysed
396
* @param newClassifier the Classifier to use.
398
public void setClassifier(Classifier newClassifier) {
400
m_Classifier = newClassifier;
404
* Gets the name of the classifier being analysed
406
* @return the classifier being analysed.
408
public Classifier getClassifier() {
414
* Sets debugging mode
416
* @param debug true if debug output should be printed
418
public void setDebug(boolean debug) {
424
* Gets whether debugging is turned on
426
* @return true if debugging output is on
428
public boolean getDebug() {
434
* Sets the random number seed
436
* @param seed the random number seed
438
public void setSeed(int seed) {
444
* Gets the random number seed
446
* @return the random number seed
448
public int getSeed() {
454
* Sets the maximum number of boost iterations
456
* @param trainIterations the number of boost iterations
458
public void setTrainIterations(int trainIterations) {
460
m_TrainIterations = trainIterations;
464
* Gets the maximum number of boost iterations
466
* @return the maximum number of boost iterations
468
public int getTrainIterations() {
470
return m_TrainIterations;
474
* Sets the name of the data file used for the decomposition
476
* @param dataFileName the data file to use
478
public void setDataFileName(String dataFileName) {
480
m_DataFileName = dataFileName;
484
* Get the name of the data file used for the decomposition
486
* @return the name of the data file
488
public String getDataFileName() {
490
return m_DataFileName;
494
* Get the index (starting from 1) of the attribute used as the class.
496
* @return the index of the class attribute
498
public int getClassIndex() {
500
return m_ClassIndex + 1;
504
* Sets index of attribute to discretize on
506
* @param classIndex the index (starting from 1) of the class attribute
508
public void setClassIndex(int classIndex) {
510
m_ClassIndex = classIndex - 1;
514
* Get the calculated bias squared
516
* @return the bias squared
518
public double getBias() {
524
* Get the calculated variance
526
* @return the variance
528
public double getVariance() {
534
* Get the calculated sigma squared
536
* @return the sigma squared
538
public double getSigma() {
544
* Get the calculated error rate
546
* @return the error rate
548
public double getError() {
554
* Carry out the bias-variance decomposition
556
* @throws Exception if the decomposition couldn't be carried out
558
public void decompose() throws Exception {
560
Reader dataReader = new BufferedReader(new FileReader(m_DataFileName));
561
Instances data = new Instances(dataReader);
563
if (m_ClassIndex < 0) {
564
data.setClassIndex(data.numAttributes() - 1);
566
data.setClassIndex(m_ClassIndex);
568
if (data.classAttribute().type() != Attribute.NOMINAL) {
569
throw new Exception("Class attribute must be nominal");
571
int numClasses = data.numClasses();
573
data.deleteWithMissingClass();
574
if (data.checkForStringAttributes()) {
575
throw new Exception("Can't handle string attributes!");
578
if (data.numInstances() < 2 * m_TrainPoolSize) {
579
throw new Exception("The dataset must contain at least "
580
+ (2 * m_TrainPoolSize) + " instances");
582
Random random = new Random(m_Seed);
583
data.randomize(random);
584
Instances trainPool = new Instances(data, 0, m_TrainPoolSize);
585
Instances test = new Instances(data, m_TrainPoolSize,
586
data.numInstances() - m_TrainPoolSize);
587
int numTest = test.numInstances();
588
double [][] instanceProbs = new double [numTest][numClasses];
591
for (int i = 0; i < m_TrainIterations; i++) {
593
System.err.println("Iteration " + (i + 1));
595
trainPool.randomize(random);
596
Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2);
598
Classifier current = Classifier.makeCopy(m_Classifier);
599
current.buildClassifier(train);
601
//// Evaluate the classifier on test, updating BVD stats
602
for (int j = 0; j < numTest; j++) {
603
int pred = (int)current.classifyInstance(test.instance(j));
604
if (pred != test.instance(j).classValue()) {
607
instanceProbs[j][pred]++;
610
m_Error /= (m_TrainIterations * numTest);
612
// Average the BV over each instance in test.
616
for (int i = 0; i < numTest; i++) {
617
Instance current = test.instance(i);
618
double [] predProbs = instanceProbs[i];
619
double pActual, pPred;
620
double bsum = 0, vsum = 0, ssum = 0;
621
for (int j = 0; j < numClasses; j++) {
622
pActual = (current.classValue() == j) ? 1 : 0; // Or via 1NN from test data?
623
pPred = predProbs[j] / m_TrainIterations;
624
bsum += (pActual - pPred) * (pActual - pPred)
625
- pPred * (1 - pPred) / (m_TrainIterations - 1);
626
vsum += pPred * pPred;
627
ssum += pActual * pActual;
630
m_Variance += (1 - vsum);
631
m_Sigma += (1 - ssum);
633
m_Bias /= (2 * numTest);
634
m_Variance /= (2 * numTest);
635
m_Sigma /= (2 * numTest);
638
System.err.println("Decomposition finished");
644
* Returns description of the bias-variance decomposition results.
646
* @return the bias-variance decomposition results as a string
648
public String toString() {
650
String result = "\nBias-Variance Decomposition\n";
652
if (getClassifier() == null) {
653
return "Invalid setup";
656
result += "\nClassifier : " + getClassifier().getClass().getName();
657
if (getClassifier() instanceof OptionHandler) {
658
result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
660
result += "\nData File : " + getDataFileName();
661
result += "\nClass Index : ";
662
if (getClassIndex() == 0) {
665
result += getClassIndex();
667
result += "\nTraining Pool: " + getTrainPoolSize();
668
result += "\nIterations : " + getTrainIterations();
669
result += "\nSeed : " + getSeed();
670
result += "\nError : " + Utils.doubleToString(getError(), 6, 4);
671
result += "\nSigma^2 : " + Utils.doubleToString(getSigma(), 6, 4);
672
result += "\nBias^2 : " + Utils.doubleToString(getBias(), 6, 4);
673
result += "\nVariance : " + Utils.doubleToString(getVariance(), 6, 4);
675
return result + "\n";
680
* Test method for this class
682
* @param args the command line arguments
684
public static void main(String [] args) {
687
BVDecompose bvd = new BVDecompose();
690
bvd.setOptions(args);
691
Utils.checkForRemainingOptions(args);
692
} catch (Exception ex) {
693
String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
694
Enumeration enu = bvd.listOptions();
695
while (enu.hasMoreElements()) {
696
Option option = (Option) enu.nextElement();
697
result += option.synopsis() + "\n" + option.description() + "\n";
699
throw new Exception(result);
703
System.out.println(bvd.toString());
704
} catch (Exception ex) {
705
System.err.println(ex.getMessage());