~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/BVDecomposeSegCVSub.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    BVDecomposeSegCVSub.java
 
19
 *    Copyright (C) 2003 Paul Conilione
 
20
 *
 
21
 *    Based on the class: BVDecompose.java by Len Trigg (1999)
 
22
 */
 
23
 
 
24
 
 
25
/*
 
26
 *    DEDICATION
 
27
 *
 
28
 *    Paul Conilione would like to express his deep gratitude and appreciation
 
29
 *    to his Chinese Buddhist Taoist Master Sifu Chow Yuk Nen for the abilities
 
30
 *    and insight that he has been taught, which have allowed him to program in 
 
31
 *    a clear and efficient manner.
 
32
 *
 
33
 *    Master Sifu Chow Yuk Nen's Teachings are unique and precious. They are
 
34
 *    applicable to any field of human endeavour. Through his unique and powerful
 
35
 *    ability to skilfully apply Chinese Buddhist Teachings, people have achieved
 
36
 *    success in; Computing, chemical engineering, business, accounting, philosophy
 
37
 *    and more.
 
38
 *
 
39
 */
 
40
 
 
41
package weka.classifiers;
 
42
 
 
43
import weka.core.Attribute;
 
44
import weka.core.Instance;
 
45
import weka.core.Instances;
 
46
import weka.core.Option;
 
47
import weka.core.OptionHandler;
 
48
import weka.core.TechnicalInformation;
 
49
import weka.core.TechnicalInformation.Type;
 
50
import weka.core.TechnicalInformation.Field;
 
51
import weka.core.TechnicalInformationHandler;
 
52
import weka.core.Utils;
 
53
 
 
54
import java.io.BufferedReader;
 
55
import java.io.FileReader;
 
56
import java.io.Reader;
 
57
import java.util.Enumeration;
 
58
import java.util.Random;
 
59
import java.util.Vector;
 
60
 
 
61
/**
 
62
 <!-- globalinfo-start -->
 
63
 * This class performs Bias-Variance decomposion on any classifier using the sub-sampled cross-validation procedure as specified in (1).<br/>
 
64
 * The Kohavi and Wolpert definition of bias and variance is specified in (2).<br/>
 
65
 * The Webb definition of bias and variance is specified in (3).<br/>
 
66
 * <br/>
 
67
 * Geoffrey I. Webb, Paul Conilione (2002). Estimating bias and variance from data. School of Computer Science and Software Engineering, Victoria, Australia.<br/>
 
68
 * <br/>
 
69
 * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.<br/>
 
70
 * <br/>
 
71
 * Geoffrey I. Webb (2000). MultiBoosting: A Technique for Combining Boosting and Wagging. Machine Learning. 40(2):159-196.
 
72
 * <p/>
 
73
 <!-- globalinfo-end -->
 
74
 * 
 
75
 <!-- technical-bibtex-start -->
 
76
 * BibTeX:
 
77
 * <pre>
 
78
 * &#64;misc{Webb2002,
 
79
 *    address = {School of Computer Science and Software Engineering, Victoria, Australia},
 
80
 *    author = {Geoffrey I. Webb and Paul Conilione},
 
81
 *    institution = {Monash University},
 
82
 *    title = {Estimating bias and variance from data},
 
83
 *    year = {2002},
 
84
 *    PDF = {http://www.csse.monash.edu.au/\~webb/Files/WebbConilione04.pdf}
 
85
 * }
 
86
 * 
 
87
 * &#64;inproceedings{Kohavi1996,
 
88
 *    author = {Ron Kohavi and David H. Wolpert},
 
89
 *    booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
 
90
 *    editor = {Lorenza Saitta},
 
91
 *    pages = {275-283},
 
92
 *    publisher = {Morgan Kaufmann},
 
93
 *    title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
 
94
 *    year = {1996},
 
95
 *    PS = {http://robotics.stanford.edu/\~ronnyk/biasVar.ps}
 
96
 * }
 
97
 * 
 
98
 * &#64;article{Webb2000,
 
99
 *    author = {Geoffrey I. Webb},
 
100
 *    journal = {Machine Learning},
 
101
 *    number = {2},
 
102
 *    pages = {159-196},
 
103
 *    title = {MultiBoosting: A Technique for Combining Boosting and Wagging},
 
104
 *    volume = {40},
 
105
 *    year = {2000}
 
106
 * }
 
107
 * </pre>
 
108
 * <p/>
 
109
 <!-- technical-bibtex-end -->
 
110
 *
 
111
 <!-- options-start -->
 
112
 * Valid options are: <p/>
 
113
 * 
 
114
 * <pre> -c &lt;class index&gt;
 
115
 *  The index of the class attribute.
 
116
 *  (default last)</pre>
 
117
 * 
 
118
 * <pre> -D
 
119
 *  Turn on debugging output.</pre>
 
120
 * 
 
121
 * <pre> -l &lt;num&gt;
 
122
 *  The number of times each instance is classified.
 
123
 *  (default 10)</pre>
 
124
 * 
 
125
 * <pre> -p &lt;proportion of objects in common&gt;
 
126
 *  The average proportion of instances common between any two training sets</pre>
 
127
 * 
 
128
 * <pre> -s &lt;seed&gt;
 
129
 *  The random number seed used.</pre>
 
130
 * 
 
131
 * <pre> -t &lt;name of arff file&gt;
 
132
 *  The name of the arff file used for the decomposition.</pre>
 
133
 * 
 
134
 * <pre> -T &lt;number of instances in training set&gt;
 
135
 *  The number of instances in the training set.</pre>
 
136
 * 
 
137
 * <pre> -W &lt;classifier class name&gt;
 
138
 *  Full class name of the learner used in the decomposition.
 
139
 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 
140
 * 
 
141
 * <pre> 
 
142
 * Options specific to learner weka.classifiers.rules.ZeroR:
 
143
 * </pre>
 
144
 * 
 
145
 * <pre> -D
 
146
 *  If set, classifier is run in debug mode and
 
147
 *  may output additional info to the console</pre>
 
148
 * 
 
149
 <!-- options-end -->
 
150
 *
 
151
 * Options after -- are passed to the designated sub-learner. <p>
 
152
 *
 
153
 * @author Paul Conilione (paulc4321@yahoo.com.au)
 
154
 * @version $Revision: 1.6 $
 
155
 */
 
156
public class BVDecomposeSegCVSub
 
157
    implements OptionHandler, TechnicalInformationHandler {
 
158
    
 
159
    /** Debugging mode, gives extra output if true. */
 
160
    protected boolean m_Debug;
 
161
    
 
162
    /** An instantiated base classifier used for getting and testing options. */
 
163
    protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
 
164
    
 
165
    /** The options to be passed to the base classifier. */
 
166
    protected String [] m_ClassifierOptions;
 
167
    
 
168
    /** The number of times an instance is classified*/
 
169
    protected int m_ClassifyIterations;
 
170
    
 
171
    /** The name of the data file used for the decomposition */
 
172
    protected String m_DataFileName;
 
173
    
 
174
    /** The index of the class attribute */
 
175
    protected int m_ClassIndex = -1;
 
176
    
 
177
    /** The random number seed */
 
178
    protected int m_Seed = 1;
 
179
    
 
180
    /** The calculated Kohavi & Wolpert bias (squared) */
 
181
    protected double m_KWBias;
 
182
    
 
183
    /** The calculated Kohavi & Wolpert variance */
 
184
    protected double m_KWVariance;
 
185
    
 
186
    /** The calculated Kohavi & Wolpert sigma */
 
187
    protected double m_KWSigma;
 
188
    
 
189
    /** The calculated Webb bias */
 
190
    protected double m_WBias;
 
191
    
 
192
    /** The calculated Webb variance */
 
193
    protected double m_WVariance;
 
194
    
 
195
    /** The error rate */
 
196
    protected double m_Error;
 
197
    
 
198
    /** The training set size */
 
199
    protected int m_TrainSize;
 
200
    
 
201
    /** Proportion of instances common between any two training sets. */
 
202
    protected double m_P;
 
203
    
 
204
    /**
 
205
     * Returns a string describing this object
 
206
     * @return a description of the classifier suitable for
 
207
     * displaying in the explorer/experimenter gui
 
208
     */
 
209
    public String globalInfo() {
 
210
      return 
 
211
          "This class performs Bias-Variance decomposion on any classifier using the "
 
212
        + "sub-sampled cross-validation procedure as specified in (1).\n"
 
213
        + "The Kohavi and Wolpert definition of bias and variance is specified in (2).\n"
 
214
        + "The Webb definition of bias and variance is specified in (3).\n\n"
 
215
        + getTechnicalInformation().toString();
 
216
    }
 
217
 
 
218
    /**
 
219
     * Returns an instance of a TechnicalInformation object, containing 
 
220
     * detailed information about the technical background of this class,
 
221
     * e.g., paper reference or book this class is based on.
 
222
     * 
 
223
     * @return the technical information about this class
 
224
     */
 
225
    public TechnicalInformation getTechnicalInformation() {
 
226
      TechnicalInformation      result;
 
227
      TechnicalInformation      additional;
 
228
      
 
229
      result = new TechnicalInformation(Type.MISC);
 
230
      result.setValue(Field.AUTHOR, "Geoffrey I. Webb and Paul Conilione");
 
231
      result.setValue(Field.YEAR, "2002");
 
232
      result.setValue(Field.TITLE, "Estimating bias and variance from data");
 
233
      result.setValue(Field.INSTITUTION, "Monash University");
 
234
      result.setValue(Field.ADDRESS, "School of Computer Science and Software Engineering, Victoria, Australia");
 
235
      result.setValue(Field.PDF, "http://www.csse.monash.edu.au/~webb/Files/WebbConilione04.pdf");
 
236
 
 
237
      additional = result.add(Type.INPROCEEDINGS);
 
238
      additional.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
 
239
      additional.setValue(Field.YEAR, "1996");
 
240
      additional.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
 
241
      additional.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
 
242
      additional.setValue(Field.PUBLISHER, "Morgan Kaufmann");
 
243
      additional.setValue(Field.EDITOR, "Lorenza Saitta");
 
244
      additional.setValue(Field.PAGES, "275-283");
 
245
      additional.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
 
246
 
 
247
      additional = result.add(Type.ARTICLE);
 
248
      additional.setValue(Field.AUTHOR, "Geoffrey I. Webb");
 
249
      additional.setValue(Field.YEAR, "2000");
 
250
      additional.setValue(Field.TITLE, "MultiBoosting: A Technique for Combining Boosting and Wagging");
 
251
      additional.setValue(Field.JOURNAL, "Machine Learning");
 
252
      additional.setValue(Field.VOLUME, "40");
 
253
      additional.setValue(Field.NUMBER, "2");
 
254
      additional.setValue(Field.PAGES, "159-196");
 
255
 
 
256
      return result;
 
257
    }
 
258
    
 
259
    /**
 
260
     * Returns an enumeration describing the available options.
 
261
     *
 
262
     * @return an enumeration of all the available options.
 
263
     */
 
264
    public Enumeration listOptions() {
 
265
        
 
266
        Vector newVector = new Vector(8);
 
267
        
 
268
        newVector.addElement(new Option(
 
269
        "\tThe index of the class attribute.\n"+
 
270
        "\t(default last)",
 
271
        "c", 1, "-c <class index>"));
 
272
        newVector.addElement(new Option(
 
273
        "\tTurn on debugging output.",
 
274
        "D", 0, "-D"));
 
275
        newVector.addElement(new Option(
 
276
        "\tThe number of times each instance is classified.\n"
 
277
        +"\t(default 10)",
 
278
        "l", 1, "-l <num>"));
 
279
        newVector.addElement(new Option(
 
280
        "\tThe average proportion of instances common between any two training sets",
 
281
        "p", 1, "-p <proportion of objects in common>"));
 
282
        newVector.addElement(new Option(
 
283
        "\tThe random number seed used.",
 
284
        "s", 1, "-s <seed>"));
 
285
        newVector.addElement(new Option(
 
286
        "\tThe name of the arff file used for the decomposition.",
 
287
        "t", 1, "-t <name of arff file>"));
 
288
        newVector.addElement(new Option(
 
289
        "\tThe number of instances in the training set.",
 
290
        "T", 1, "-T <number of instances in training set>"));
 
291
        newVector.addElement(new Option(
 
292
        "\tFull class name of the learner used in the decomposition.\n"
 
293
        +"\teg: weka.classifiers.bayes.NaiveBayes",
 
294
        "W", 1, "-W <classifier class name>"));
 
295
        
 
296
        if ((m_Classifier != null) &&
 
297
        (m_Classifier instanceof OptionHandler)) {
 
298
            newVector.addElement(new Option(
 
299
            "",
 
300
            "", 0, "\nOptions specific to learner "
 
301
            + m_Classifier.getClass().getName()
 
302
            + ":"));
 
303
            Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
 
304
            while (enu.hasMoreElements()) {
 
305
                newVector.addElement(enu.nextElement());
 
306
            }
 
307
        }
 
308
        return newVector.elements();
 
309
    }
 
310
    
 
311
    
 
312
    /** 
 
313
     * Sets the OptionHandler's options using the given list. All options
 
314
     * will be set (or reset) during this call (i.e. incremental setting
 
315
     * of options is not possible). <p/>
 
316
     *
 
317
     <!-- options-start -->
 
318
     * Valid options are: <p/>
 
319
     * 
 
320
     * <pre> -c &lt;class index&gt;
 
321
     *  The index of the class attribute.
 
322
     *  (default last)</pre>
 
323
     * 
 
324
     * <pre> -D
 
325
     *  Turn on debugging output.</pre>
 
326
     * 
 
327
     * <pre> -l &lt;num&gt;
 
328
     *  The number of times each instance is classified.
 
329
     *  (default 10)</pre>
 
330
     * 
 
331
     * <pre> -p &lt;proportion of objects in common&gt;
 
332
     *  The average proportion of instances common between any two training sets</pre>
 
333
     * 
 
334
     * <pre> -s &lt;seed&gt;
 
335
     *  The random number seed used.</pre>
 
336
     * 
 
337
     * <pre> -t &lt;name of arff file&gt;
 
338
     *  The name of the arff file used for the decomposition.</pre>
 
339
     * 
 
340
     * <pre> -T &lt;number of instances in training set&gt;
 
341
     *  The number of instances in the training set.</pre>
 
342
     * 
 
343
     * <pre> -W &lt;classifier class name&gt;
 
344
     *  Full class name of the learner used in the decomposition.
 
345
     *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 
346
     * 
 
347
     * <pre> 
 
348
     * Options specific to learner weka.classifiers.rules.ZeroR:
 
349
     * </pre>
 
350
     * 
 
351
     * <pre> -D
 
352
     *  If set, classifier is run in debug mode and
 
353
     *  may output additional info to the console</pre>
 
354
     * 
 
355
     <!-- options-end -->
 
356
     *
 
357
     * @param options the list of options as an array of strings
 
358
     * @throws Exception if an option is not supported
 
359
     */
 
360
    public void setOptions(String[] options) throws Exception {
 
361
        setDebug(Utils.getFlag('D', options));
 
362
        
 
363
        String classIndex = Utils.getOption('c', options);
 
364
        if (classIndex.length() != 0) {
 
365
            if (classIndex.toLowerCase().equals("last")) {
 
366
                setClassIndex(0);
 
367
            } else if (classIndex.toLowerCase().equals("first")) {
 
368
                setClassIndex(1);
 
369
            } else {
 
370
                setClassIndex(Integer.parseInt(classIndex));
 
371
            }
 
372
        } else {
 
373
            setClassIndex(0);
 
374
        }
 
375
        
 
376
        String classifyIterations = Utils.getOption('l', options);
 
377
        if (classifyIterations.length() != 0) {
 
378
            setClassifyIterations(Integer.parseInt(classifyIterations));
 
379
        } else {
 
380
            setClassifyIterations(10);
 
381
        }
 
382
        
 
383
        String prob = Utils.getOption('p', options);
 
384
        if (prob.length() != 0) {
 
385
            setP( Double.parseDouble(prob));
 
386
        } else {
 
387
            setP(-1);
 
388
        }
 
389
        //throw new Exception("A proportion must be specified" + " with a -p option.");
 
390
        
 
391
        String seedString = Utils.getOption('s', options);
 
392
        if (seedString.length() != 0) {
 
393
            setSeed(Integer.parseInt(seedString));
 
394
        } else {
 
395
            setSeed(1);
 
396
        }
 
397
        
 
398
        String dataFile = Utils.getOption('t', options);
 
399
        if (dataFile.length() != 0) {
 
400
            setDataFileName(dataFile);
 
401
        } else {
 
402
            throw new Exception("An arff file must be specified"
 
403
            + " with the -t option.");
 
404
        }
 
405
        
 
406
        String trainSize = Utils.getOption('T', options);
 
407
        if (trainSize.length() != 0) {
 
408
            setTrainSize(Integer.parseInt(trainSize));
 
409
        } else {
 
410
            setTrainSize(-1);
 
411
        }
 
412
        //throw new Exception("A training set size must be specified" + " with a -T option.");
 
413
        
 
414
        String classifierName = Utils.getOption('W', options);
 
415
        if (classifierName.length() != 0) {
 
416
            setClassifier(Classifier.forName(classifierName, Utils.partitionOptions(options)));
 
417
        } else {
 
418
            throw new Exception("A learner must be specified with the -W option.");
 
419
        }
 
420
    }
 
421
    
 
422
    /**
 
423
     * Gets the current settings of the CheckClassifier.
 
424
     *
 
425
     * @return an array of strings suitable for passing to setOptions
 
426
     */
 
427
    public String [] getOptions() {
 
428
        
 
429
        String [] classifierOptions = new String [0];
 
430
        if ((m_Classifier != null) &&
 
431
        (m_Classifier instanceof OptionHandler)) {
 
432
            classifierOptions = ((OptionHandler)m_Classifier).getOptions();
 
433
        }
 
434
        String [] options = new String [classifierOptions.length + 14];
 
435
        int current = 0;
 
436
        if (getDebug()) {
 
437
            options[current++] = "-D";
 
438
        }
 
439
        options[current++] = "-c"; options[current++] = "" + getClassIndex();
 
440
        options[current++] = "-l"; options[current++] = "" + getClassifyIterations();
 
441
        options[current++] = "-p"; options[current++] = "" + getP();
 
442
        options[current++] = "-s"; options[current++] = "" + getSeed();
 
443
        if (getDataFileName() != null) {
 
444
            options[current++] = "-t"; options[current++] = "" + getDataFileName();
 
445
        }
 
446
        options[current++] = "-T"; options[current++] = "" + getTrainSize();
 
447
        if (getClassifier() != null) {
 
448
            options[current++] = "-W";
 
449
            options[current++] = getClassifier().getClass().getName();
 
450
        }
 
451
        
 
452
        options[current++] = "--";
 
453
        System.arraycopy(classifierOptions, 0, options, current,
 
454
        classifierOptions.length);
 
455
        current += classifierOptions.length;
 
456
        while (current < options.length) {
 
457
            options[current++] = "";
 
458
        }
 
459
        return options;
 
460
    }
 
461
    
 
462
    /**
 
463
     * Set the classifiers being analysed
 
464
     *
 
465
     * @param newClassifier the Classifier to use.
 
466
     */
 
467
    public void setClassifier(Classifier newClassifier) {
 
468
        
 
469
        m_Classifier = newClassifier;
 
470
    }
 
471
    
 
472
    /**
 
473
     * Gets the name of the classifier being analysed
 
474
     *
 
475
     * @return the classifier being analysed.
 
476
     */
 
477
    public Classifier getClassifier() {
 
478
        
 
479
        return m_Classifier;
 
480
    }
 
481
    
 
482
    /**
 
483
     * Sets debugging mode
 
484
     *
 
485
     * @param debug true if debug output should be printed
 
486
     */
 
487
    public void setDebug(boolean debug) {
 
488
        
 
489
        m_Debug = debug;
 
490
    }
 
491
    
 
492
    /**
 
493
     * Gets whether debugging is turned on
 
494
     *
 
495
     * @return true if debugging output is on
 
496
     */
 
497
    public boolean getDebug() {
 
498
        
 
499
        return m_Debug;
 
500
    }
 
501
    
 
502
    
 
503
    /**
 
504
     * Sets the random number seed
 
505
     * 
 
506
     * @param seed the random number seed
 
507
     */
 
508
    public void setSeed(int seed) {
 
509
        
 
510
        m_Seed = seed;
 
511
    }
 
512
    
 
513
    /**
 
514
     * Gets the random number seed
 
515
     *
 
516
     * @return the random number seed
 
517
     */
 
518
    public int getSeed() {
 
519
        
 
520
        return m_Seed;
 
521
    }
 
522
    
 
523
    /**
 
524
     * Sets the number of times an instance is classified
 
525
     *
 
526
     * @param classifyIterations number of times an instance is classified
 
527
     */
 
528
    public void setClassifyIterations(int classifyIterations) {
 
529
        
 
530
        m_ClassifyIterations = classifyIterations;
 
531
    }
 
532
    
 
533
    /**
 
534
     * Gets the number of times an instance is classified
 
535
     *
 
536
     * @return the maximum number of times an instance is classified
 
537
     */
 
538
    public int getClassifyIterations() {
 
539
        
 
540
        return m_ClassifyIterations;
 
541
    }
 
542
    
 
543
    /**
 
544
     * Sets the name of the dataset file.
 
545
     *
 
546
     * @param dataFileName name of dataset file.
 
547
     */
 
548
    public void setDataFileName(String dataFileName) {
 
549
        
 
550
        m_DataFileName = dataFileName;
 
551
    }
 
552
    
 
553
    /**
 
554
     * Get the name of the data file used for the decomposition
 
555
     *
 
556
     * @return the name of the data file
 
557
     */
 
558
    public String getDataFileName() {
 
559
        
 
560
        return m_DataFileName;
 
561
    }
 
562
    
 
563
    /**
 
564
     * Get the index (starting from 1) of the attribute used as the class.
 
565
     *
 
566
     * @return the index of the class attribute
 
567
     */
 
568
    public int getClassIndex() {
 
569
        
 
570
        return m_ClassIndex + 1;
 
571
    }
 
572
    
 
573
    /**
 
574
     * Sets index of attribute to discretize on
 
575
     *
 
576
     * @param classIndex the index (starting from 1) of the class attribute
 
577
     */
 
578
    public void setClassIndex(int classIndex) {
 
579
        
 
580
        m_ClassIndex = classIndex - 1;
 
581
    }
 
582
    
 
583
    /**
 
584
     * Get the calculated bias squared according to the Kohavi and Wolpert definition
 
585
     *
 
586
     * @return the bias squared
 
587
     */
 
588
    public double getKWBias() {
 
589
        
 
590
        return m_KWBias;
 
591
    }
 
592
    
 
593
    /**
 
594
     * Get the calculated bias according to the Webb definition
 
595
     *
 
596
     * @return the bias
 
597
     *
 
598
     */
 
599
    public double getWBias() {
 
600
        
 
601
        return m_WBias;
 
602
    }
 
603
    
 
604
    
 
605
    /**
 
606
     * Get the calculated variance according to the Kohavi and Wolpert definition
 
607
     *
 
608
     * @return the variance
 
609
     */
 
610
    public double getKWVariance() {
 
611
        
 
612
        return m_KWVariance;
 
613
    }
 
614
    
 
615
    /**
 
616
     * Get the calculated variance according to the Webb definition
 
617
     *
 
618
     * @return the variance according to Webb
 
619
     *
 
620
     */
 
621
    public double getWVariance() {
 
622
        
 
623
        return m_WVariance;
 
624
    }
 
625
    
 
626
    /**
 
627
     * Get the calculated sigma according to the Kohavi and Wolpert definition
 
628
     *
 
629
     * @return the sigma
 
630
     *
 
631
     */
 
632
    public double getKWSigma() {
 
633
        
 
634
        return m_KWSigma;
 
635
    }
 
636
    
 
637
    /**
 
638
     * Set the training size.
 
639
     *
 
640
     * @param size the size of the training set
 
641
     *
 
642
     */
 
643
    public void setTrainSize(int size) {
 
644
        
 
645
        m_TrainSize = size;
 
646
    }
 
647
    
 
648
    /**
 
649
     * Get the training size
 
650
     *
 
651
     * @return the size of the training set
 
652
     *
 
653
     */
 
654
    public int getTrainSize() {
 
655
        
 
656
        return m_TrainSize;
 
657
    }
 
658
    
 
659
    /**
 
660
     * Set the proportion of instances that are common between two training sets
 
661
     * used to train a classifier.
 
662
     *
 
663
     * @param proportion the proportion of instances that are common between training
 
664
     * sets.
 
665
     *
 
666
     */
 
667
    public void setP(double proportion) {
 
668
        
 
669
        m_P = proportion;
 
670
    }
 
671
    
 
672
    /**
 
673
     * Get the proportion of instances that are common between two training sets.
 
674
     *
 
675
     * @return the proportion
 
676
     *
 
677
     */
 
678
    public double getP() {
 
679
        
 
680
        return m_P;
 
681
    }
 
682
    
 
683
    /**
 
684
     * Get the calculated error rate
 
685
     *
 
686
     * @return the error rate
 
687
     */
 
688
    public double getError() {
 
689
        
 
690
        return m_Error;
 
691
    }
 
692
    
 
693
    /**
 
694
     * Carry out the bias-variance decomposition using the sub-sampled cross-validation method.
 
695
     *
 
696
     * @throws Exception if the decomposition couldn't be carried out
 
697
     */
 
698
    public void decompose() throws Exception {
 
699
        
 
700
        Reader dataReader;
 
701
        Instances data;
 
702
        
 
703
        int tps; // training pool size, size of segment E.
 
704
        int k; // number of folds in segment E.
 
705
        int q; // number of segments of size tps.
 
706
        
 
707
        dataReader = new BufferedReader(new FileReader(m_DataFileName)); //open file
 
708
        data = new Instances(dataReader); // encapsulate in wrapper class called weka.Instances()
 
709
        
 
710
        if (m_ClassIndex < 0) {
 
711
            data.setClassIndex(data.numAttributes() - 1);
 
712
        } else {
 
713
            data.setClassIndex(m_ClassIndex);
 
714
        }
 
715
        
 
716
        if (data.classAttribute().type() != Attribute.NOMINAL) {
 
717
            throw new Exception("Class attribute must be nominal");
 
718
        }
 
719
        int numClasses = data.numClasses();
 
720
        
 
721
        data.deleteWithMissingClass();
 
722
        if ( data.checkForStringAttributes() ) {
 
723
            throw new Exception("Can't handle string attributes!");
 
724
        }
 
725
        
 
726
        // Dataset size must be greater than 2
 
727
        if ( data.numInstances() <= 2 ){
 
728
            throw new Exception("Dataset size must be greater than 2.");
 
729
        }
 
730
        
 
731
        if ( m_TrainSize == -1 ){ // default value
 
732
            m_TrainSize = (int) Math.floor( (double) data.numInstances() / 2.0 );
 
733
        }else  if ( m_TrainSize < 0 || m_TrainSize >= data.numInstances() - 1 ) {  // Check if 0 < training Size < D - 1
 
734
            throw new Exception("Training set size of "+m_TrainSize+" is invalid.");
 
735
        }
 
736
        
 
737
        if ( m_P == -1 ){ // default value
 
738
            m_P = (double) m_TrainSize / ( (double)data.numInstances() - 1 );
 
739
        }else if (  m_P < ( m_TrainSize / ( (double)data.numInstances() - 1 ) ) || m_P >= 1.0  ) { //Check if p is in range: m/(|D|-1) <= p < 1.0
 
740
            throw new Exception("Proportion is not in range: "+ (m_TrainSize / ((double) data.numInstances() - 1 )) +" <= p < 1.0 ");
 
741
        }
 
742
        
 
743
        //roundup tps from double to integer
 
744
        tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 );
 
745
        k = (int) Math.ceil( tps / (tps - (double) m_TrainSize));
 
746
        
 
747
        // number of folds cannot be more than the number of instances in the training pool
 
748
        if ( k > tps ) {
 
749
            throw new Exception("The required number of folds is too many."
 
750
            + "Change p or the size of the training set.");
 
751
        }
 
752
        
 
753
        // calculate the number of segments, round down.
 
754
        q = (int) Math.floor( (double) data.numInstances() / (double)tps );
 
755
        
 
756
        //create confusion matrix, columns = number of instances in data set, as all will be used,  by rows = number of classes.
 
757
        double [][] instanceProbs = new double [data.numInstances()][numClasses];
 
758
        int [][] foldIndex = new int [ k ][ 2 ];
 
759
        Vector segmentList = new Vector(q + 1);
 
760
        
 
761
        //Set random seed
 
762
        Random random = new Random(m_Seed);
 
763
        
 
764
        data.randomize(random);
 
765
        
 
766
        //create index arrays for different segments
 
767
        
 
768
        int currentDataIndex = 0;
 
769
 
 
770
        for( int count = 1; count <= (q + 1); count++ ){
 
771
            if( count > q){
 
772
                int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ];
 
773
                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
 
774
                    
 
775
                    segmentIndex[index] = currentDataIndex;
 
776
                }
 
777
                segmentList.add(segmentIndex);
 
778
            } else {
 
779
                int [] segmentIndex = new int [ tps ];
 
780
                
 
781
                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
 
782
                    segmentIndex[index] = currentDataIndex;
 
783
                }
 
784
                segmentList.add(segmentIndex);
 
785
            }
 
786
        }
 
787
        
 
788
        int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
 
789
        
 
790
        //foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3,  3.3->4)
 
791
        int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
 
792
        int index = 0;
 
793
        int currentIndex;
 
794
        
 
795
        for( int count = 0; count < k; count ++){
 
796
            if( remainder != 0 && count == remainder ){
 
797
                foldSize -= 1;
 
798
            }
 
799
            foldIndex[count][0] = index;
 
800
            foldIndex[count][1] = foldSize;
 
801
            index += foldSize;
 
802
        }
 
803
        
 
804
        for( int l = 0; l < m_ClassifyIterations; l++) {
 
805
            
 
806
            for(int i = 1; i <= q; i++){
 
807
                
 
808
                int [] currentSegment = (int[]) segmentList.get(i - 1);
 
809
                
 
810
                randomize(currentSegment, random);
 
811
                
 
812
                //CROSS FOLD VALIDATION for current Segment
 
813
                for( int j = 1; j <= k; j++){
 
814
                    
 
815
                    Instances TP = null;
 
816
                    for(int foldNum = 1; foldNum <= k; foldNum++){
 
817
                        if( foldNum != j){
 
818
                            
 
819
                            int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
 
820
                            foldSize = foldIndex[ foldNum - 1 ][ 1 ];
 
821
                            int endFoldIndex = startFoldIndex + foldSize - 1;
 
822
                            
 
823
                            for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
 
824
                                
 
825
                                if( TP == null ){
 
826
                                    TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
 
827
                                }else{
 
828
                                    TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
 
829
                                }
 
830
                            }
 
831
                        }
 
832
                    }
 
833
                    
 
834
                    TP.randomize(random);
 
835
                    
 
836
                    if( getTrainSize() > TP.numInstances() ){
 
837
                        throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
 
838
                        + TP.numInstances() );
 
839
                    }
 
840
                    
 
841
                    Instances train = new Instances(TP, 0, m_TrainSize);
 
842
                    
 
843
                    Classifier current = Classifier.makeCopy(m_Classifier);
 
844
                    current.buildClassifier(train); // create a clssifier using the instances in train.
 
845
                    
 
846
                    int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
 
847
                    int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
 
848
                    int endTestIndex = currentTestIndex + testFoldSize - 1;
 
849
                    
 
850
                    while( currentTestIndex <= endTestIndex ){
 
851
                        
 
852
                        Instance testInst = data.instance( currentSegment[currentTestIndex] );
 
853
                        int pred = (int)current.classifyInstance( testInst );
 
854
                        
 
855
                        
 
856
                        if(pred != testInst.classValue()) {
 
857
                            m_Error++; // add 1 to mis-classifications.
 
858
                        }
 
859
                        instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
 
860
                        currentTestIndex++;
 
861
                    }
 
862
                    
 
863
                    if( i == 1 && j == 1){
 
864
                        int[] segmentElast = (int[])segmentList.lastElement();
 
865
                        for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
 
866
                            Instance testInst = data.instance( segmentElast[currentIndex] );
 
867
                            int pred = (int)current.classifyInstance( testInst );
 
868
                            if(pred != testInst.classValue()) {
 
869
                                m_Error++; // add 1 to mis-classifications.
 
870
                            }
 
871
                            
 
872
                            instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
 
873
                        }
 
874
                    }
 
875
                }
 
876
            }
 
877
        }
 
878
        
 
879
        m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
 
880
        
 
881
        m_KWBias = 0.0;
 
882
        m_KWVariance = 0.0;
 
883
        m_KWSigma = 0.0;
 
884
        
 
885
        m_WBias = 0.0;
 
886
        m_WVariance = 0.0;
 
887
        
 
888
        for (int i = 0; i < data.numInstances(); i++) {
 
889
            
 
890
            Instance current = data.instance( i );
 
891
            
 
892
            double [] predProbs = instanceProbs[ i ];
 
893
            double pActual, pPred;
 
894
            double bsum = 0, vsum = 0, ssum = 0;
 
895
            double wBSum = 0, wVSum = 0;
 
896
            
 
897
            Vector centralTendencies = findCentralTendencies( predProbs );
 
898
            
 
899
            if( centralTendencies == null ){
 
900
                throw new Exception("Central tendency was null.");
 
901
            }
 
902
            
 
903
            for (int j = 0; j < numClasses; j++) {
 
904
                pActual = (current.classValue() == j) ? 1 : 0;
 
905
                pPred = predProbs[j] / m_ClassifyIterations;
 
906
                bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
 
907
                vsum += pPred * pPred;
 
908
                ssum += pActual * pActual;
 
909
            }
 
910
            
 
911
            m_KWBias += bsum;
 
912
            m_KWVariance += (1 - vsum);
 
913
            m_KWSigma += (1 - ssum);
 
914
            
 
915
            for( int count = 0; count < centralTendencies.size(); count++ ) {
 
916
                
 
917
                int wB = 0, wV = 0;
 
918
                int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
 
919
                
 
920
                // For a single instance xi, find the bias and variance.
 
921
                for (int j = 0; j < numClasses; j++) {
 
922
                    
 
923
                    //Webb definition
 
924
                    if( j != (int)current.classValue() && j == centralTendency ) {
 
925
                        wB += predProbs[j];
 
926
                    }
 
927
                    if( j != (int)current.classValue() && j != centralTendency ) {
 
928
                        wV += predProbs[j];
 
929
                    }
 
930
                    
 
931
                }
 
932
                wBSum += (double) wB;
 
933
                wVSum += (double) wV;
 
934
            }
 
935
            
 
936
            // calculate bais by dividing bSum by the number of central tendencies and
 
937
            // total number of instances. (effectively finding the average and dividing
 
938
            // by the number of instances to get the nominalised probability).
 
939
            
 
940
            m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
 
941
            // calculate variance by dividing vSum by the total number of interations
 
942
            m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
 
943
            
 
944
        }
 
945
        
 
946
        m_KWBias /= (2.0 * (double) data.numInstances());
 
947
        m_KWVariance /= (2.0 * (double) data.numInstances());
 
948
        m_KWSigma /= (2.0 * (double) data.numInstances());
 
949
        
 
950
        // bias = bias / number of data instances
 
951
        m_WBias /= (double) data.numInstances();
 
952
        // variance = variance / number of data instances.
 
953
        m_WVariance /= (double) data.numInstances();
 
954
        
 
955
        if (m_Debug) {
 
956
            System.err.println("Decomposition finished");
 
957
        }
 
958
        
 
959
    }
 
960
    
 
961
    /** Finds the central tendency, given the classifications for an instance.
 
962
     *
 
963
     * Where the central tendency is defined as the class that was most commonly
 
964
     * selected for a given instance.<p>
 
965
     *
 
966
     * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
 
967
     * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
 
968
     * and '3' = 3 times. Then the central tendency is '2'. <p>
 
969
     *
 
970
     * However, it is important to note that this method returns a list of all classes
 
971
     * that have the highest number of classifications.
 
972
     *
 
973
     * In cases where there are several classes with the largest number of classifications, then
 
974
     * all of these classes are returned. For example if 'x' is classified '1' = 4 times,
 
975
     * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
 
976
     *
 
977
     * @param predProbs the array of classifications for a single instance.
 
978
     *
 
979
     * @return a Vector containing Integer objects which store the class(s) which
 
980
     * are the central tendency.
 
981
     */
 
982
    public Vector findCentralTendencies(double[] predProbs) {
 
983
        
 
984
        int centralTValue = 0;
 
985
        int currentValue = 0;
 
986
        //array to store the list of classes the have the greatest number of classifictions.
 
987
        Vector centralTClasses;
 
988
        
 
989
        centralTClasses = new Vector(); //create an array with size of the number of classes.
 
990
        
 
991
        // Go through array, finding the central tendency.
 
992
        for( int i = 0; i < predProbs.length; i++) {
 
993
            currentValue = (int) predProbs[i];
 
994
            // if current value is greater than the central tendency value then
 
995
            // clear vector and add new class to vector array.
 
996
            if( currentValue > centralTValue) {
 
997
                centralTClasses.clear();
 
998
                centralTClasses.addElement( new Integer(i) );
 
999
                centralTValue = currentValue;
 
1000
            } else if( currentValue != 0 && currentValue == centralTValue) {
 
1001
                centralTClasses.addElement( new Integer(i) );
 
1002
            }
 
1003
        }
 
1004
        //return all classes that have the greatest number of classifications.
 
1005
        if( centralTValue != 0){
 
1006
            return centralTClasses;
 
1007
        } else {
 
1008
            return null;
 
1009
        }
 
1010
        
 
1011
    }
 
1012
    
 
1013
    /**
 
1014
     * Returns description of the bias-variance decomposition results.
 
1015
     *
 
1016
     * @return the bias-variance decomposition results as a string
 
1017
     */
 
1018
    public String toString() {
 
1019
        
 
1020
        String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n" +
 
1021
        "with subsampling.\n";
 
1022
        
 
1023
        if (getClassifier() == null) {
 
1024
            return "Invalid setup";
 
1025
        }
 
1026
        
 
1027
        result += "\nClassifier    : " + getClassifier().getClass().getName();
 
1028
        if (getClassifier() instanceof OptionHandler) {
 
1029
            result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
 
1030
        }
 
1031
        result += "\nData File     : " + getDataFileName();
 
1032
        result += "\nClass Index   : ";
 
1033
        if (getClassIndex() == 0) {
 
1034
            result += "last";
 
1035
        } else {
 
1036
            result += getClassIndex();
 
1037
        }
 
1038
        result += "\nIterations    : " + getClassifyIterations();
 
1039
        result += "\np             : " + getP();
 
1040
        result += "\nTraining Size : " + getTrainSize();
 
1041
        result += "\nSeed          : " + getSeed();
 
1042
        
 
1043
        result += "\n\nDefinition   : " +"Kohavi and Wolpert";
 
1044
        result += "\nError         :" + Utils.doubleToString(getError(), 4);
 
1045
        result += "\nBias^2        :" + Utils.doubleToString(getKWBias(), 4);
 
1046
        result += "\nVariance      :" + Utils.doubleToString(getKWVariance(), 4);
 
1047
        result += "\nSigma^2       :" + Utils.doubleToString(getKWSigma(), 4);
 
1048
        
 
1049
        result += "\n\nDefinition   : " +"Webb";
 
1050
        result += "\nError         :" + Utils.doubleToString(getError(), 4);
 
1051
        result += "\nBias          :" + Utils.doubleToString(getWBias(), 4);
 
1052
        result += "\nVariance      :" + Utils.doubleToString(getWVariance(), 4);
 
1053
        
 
1054
        return result;
 
1055
    }
 
1056
    
 
1057
    
 
1058
    
 
1059
    /**
 
1060
     * Test method for this class
 
1061
     *
 
1062
     * @param args the command line arguments
 
1063
     */
 
1064
    public static void main(String [] args) {
 
1065
        
 
1066
        try {
 
1067
            BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
 
1068
            
 
1069
            try {
 
1070
                bvd.setOptions(args);
 
1071
                Utils.checkForRemainingOptions(args);
 
1072
            } catch (Exception ex) {
 
1073
                String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
 
1074
                Enumeration enu = bvd.listOptions();
 
1075
                while (enu.hasMoreElements()) {
 
1076
                    Option option = (Option) enu.nextElement();
 
1077
                    result += option.synopsis() + "\n" + option.description() + "\n";
 
1078
                }
 
1079
                throw new Exception(result);
 
1080
            }
 
1081
            
 
1082
            bvd.decompose();
 
1083
            
 
1084
            System.out.println(bvd.toString());
 
1085
            
 
1086
        } catch (Exception ex) {
 
1087
            System.err.println(ex.getMessage());
 
1088
        }
 
1089
        
 
1090
    }
 
1091
    
 
1092
    /**
 
1093
     * Accepts an array of ints and randomises the values in the array, using the
 
1094
     * random seed.
 
1095
     *
 
1096
     *@param index is the array of integers
 
1097
     *@param random is the Random seed.
 
1098
     */
 
1099
    public final void randomize(int[] index, Random random) {
 
1100
        for( int j = index.length - 1; j > 0; j-- ){
 
1101
            int k = random.nextInt( j + 1 );
 
1102
            int temp = index[j];
 
1103
            index[j] = index[k];
 
1104
            index[k] = temp;
 
1105
        }
 
1106
    }
 
1107
}