2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* BVDecomposeSegCVSub.java
19
* Copyright (C) 2003 Paul Conilione
21
* Based on the class: BVDecompose.java by Len Trigg (1999)
28
* Paul Conilione would like to express his deep gratitude and appreciation
29
* to his Chinese Buddhist Taoist Master Sifu Chow Yuk Nen for the abilities
30
* and insight that he has been taught, which have allowed him to program in
31
* a clear and efficient manner.
33
* Master Sifu Chow Yuk Nen's Teachings are unique and precious. They are
34
* applicable to any field of human endeavour. Through his unique and powerful
35
* ability to skilfully apply Chinese Buddhist Teachings, people have achieved
36
* success in; Computing, chemical engineering, business, accounting, philosophy
41
package weka.classifiers;
43
import weka.core.Attribute;
44
import weka.core.Instance;
45
import weka.core.Instances;
46
import weka.core.Option;
47
import weka.core.OptionHandler;
48
import weka.core.TechnicalInformation;
49
import weka.core.TechnicalInformation.Type;
50
import weka.core.TechnicalInformation.Field;
51
import weka.core.TechnicalInformationHandler;
52
import weka.core.Utils;
54
import java.io.BufferedReader;
55
import java.io.FileReader;
56
import java.io.Reader;
57
import java.util.Enumeration;
58
import java.util.Random;
59
import java.util.Vector;
62
<!-- globalinfo-start -->
63
* This class performs Bias-Variance decomposition on any classifier using the sub-sampled cross-validation procedure as specified in (1).<br/>
64
* The Kohavi and Wolpert definition of bias and variance is specified in (2).<br/>
65
* The Webb definition of bias and variance is specified in (3).<br/>
67
* Geoffrey I. Webb, Paul Conilione (2002). Estimating bias and variance from data. School of Computer Science and Software Engineering, Victoria, Australia.<br/>
69
* Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.<br/>
71
* Geoffrey I. Webb (2000). MultiBoosting: A Technique for Combining Boosting and Wagging. Machine Learning. 40(2):159-196.
73
<!-- globalinfo-end -->
75
<!-- technical-bibtex-start -->
79
* address = {School of Computer Science and Software Engineering, Victoria, Australia},
80
* author = {Geoffrey I. Webb and Paul Conilione},
81
* institution = {Monash University},
82
* title = {Estimating bias and variance from data},
84
* PDF = {http://www.csse.monash.edu.au/\~webb/Files/WebbConilione04.pdf}
87
* @inproceedings{Kohavi1996,
88
* author = {Ron Kohavi and David H. Wolpert},
89
* booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
90
* editor = {Lorenza Saitta},
92
* publisher = {Morgan Kaufmann},
93
* title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
95
* PS = {http://robotics.stanford.edu/\~ronnyk/biasVar.ps}
98
* @article{Webb2000,
99
* author = {Geoffrey I. Webb},
100
* journal = {Machine Learning},
103
* title = {MultiBoosting: A Technique for Combining Boosting and Wagging},
109
<!-- technical-bibtex-end -->
111
<!-- options-start -->
112
* Valid options are: <p/>
114
* <pre> -c <class index>
115
* The index of the class attribute.
116
* (default last)</pre>
119
* Turn on debugging output.</pre>
121
* <pre> -l <num>
122
* The number of times each instance is classified.
125
* <pre> -p <proportion of objects in common>
126
* The average proportion of instances common between any two training sets</pre>
128
* <pre> -s <seed>
129
* The random number seed used.</pre>
131
* <pre> -t <name of arff file>
132
* The name of the arff file used for the decomposition.</pre>
134
* <pre> -T <number of instances in training set>
135
* The number of instances in the training set.</pre>
137
* <pre> -W <classifier class name>
138
* Full class name of the learner used in the decomposition.
139
* eg: weka.classifiers.bayes.NaiveBayes</pre>
142
* Options specific to learner weka.classifiers.rules.ZeroR:
146
* If set, classifier is run in debug mode and
147
* may output additional info to the console</pre>
151
* Options after -- are passed to the designated sub-learner. <p>
153
* @author Paul Conilione (paulc4321@yahoo.com.au)
154
* @version $Revision: 1.6 $
156
public class BVDecomposeSegCVSub
157
implements OptionHandler, TechnicalInformationHandler {
159
/** Debugging mode, gives extra output if true. */
160
protected boolean m_Debug;
162
/** An instantiated base classifier used for getting and testing options. */
163
protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
165
/** The options to be passed to the base classifier. */
166
protected String [] m_ClassifierOptions;
168
/** The number of times an instance is classified*/
169
protected int m_ClassifyIterations;
171
/** The name of the data file used for the decomposition */
172
protected String m_DataFileName;
174
/** The index of the class attribute */
175
protected int m_ClassIndex = -1;
177
/** The random number seed */
178
protected int m_Seed = 1;
180
/** The calculated Kohavi & Wolpert bias (squared) */
181
protected double m_KWBias;
183
/** The calculated Kohavi & Wolpert variance */
184
protected double m_KWVariance;
186
/** The calculated Kohavi & Wolpert sigma */
187
protected double m_KWSigma;
189
/** The calculated Webb bias */
190
protected double m_WBias;
192
/** The calculated Webb variance */
193
protected double m_WVariance;
195
/** The error rate */
196
protected double m_Error;
198
/** The training set size */
199
protected int m_TrainSize;
201
/** Proportion of instances common between any two training sets. */
202
protected double m_P;
205
* Returns a string describing this object
206
* @return a description of the classifier suitable for
207
* displaying in the explorer/experimenter gui
209
public String globalInfo() {
211
"This class performs Bias-Variance decomposion on any classifier using the "
212
+ "sub-sampled cross-validation procedure as specified in (1).\n"
213
+ "The Kohavi and Wolpert definition of bias and variance is specified in (2).\n"
214
+ "The Webb definition of bias and variance is specified in (3).\n\n"
215
+ getTechnicalInformation().toString();
219
* Returns an instance of a TechnicalInformation object, containing
220
* detailed information about the technical background of this class,
221
* e.g., paper reference or book this class is based on.
223
* @return the technical information about this class
225
public TechnicalInformation getTechnicalInformation() {
226
TechnicalInformation result;
227
TechnicalInformation additional;
229
result = new TechnicalInformation(Type.MISC);
230
result.setValue(Field.AUTHOR, "Geoffrey I. Webb and Paul Conilione");
231
result.setValue(Field.YEAR, "2002");
232
result.setValue(Field.TITLE, "Estimating bias and variance from data");
233
result.setValue(Field.INSTITUTION, "Monash University");
234
result.setValue(Field.ADDRESS, "School of Computer Science and Software Engineering, Victoria, Australia");
235
result.setValue(Field.PDF, "http://www.csse.monash.edu.au/~webb/Files/WebbConilione04.pdf");
237
additional = result.add(Type.INPROCEEDINGS);
238
additional.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
239
additional.setValue(Field.YEAR, "1996");
240
additional.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
241
additional.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
242
additional.setValue(Field.PUBLISHER, "Morgan Kaufmann");
243
additional.setValue(Field.EDITOR, "Lorenza Saitta");
244
additional.setValue(Field.PAGES, "275-283");
245
additional.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
247
additional = result.add(Type.ARTICLE);
248
additional.setValue(Field.AUTHOR, "Geoffrey I. Webb");
249
additional.setValue(Field.YEAR, "2000");
250
additional.setValue(Field.TITLE, "MultiBoosting: A Technique for Combining Boosting and Wagging");
251
additional.setValue(Field.JOURNAL, "Machine Learning");
252
additional.setValue(Field.VOLUME, "40");
253
additional.setValue(Field.NUMBER, "2");
254
additional.setValue(Field.PAGES, "159-196");
260
* Returns an enumeration describing the available options.
262
* @return an enumeration of all the available options.
264
public Enumeration listOptions() {
266
Vector newVector = new Vector(8);
268
newVector.addElement(new Option(
269
"\tThe index of the class attribute.\n"+
271
"c", 1, "-c <class index>"));
272
newVector.addElement(new Option(
273
"\tTurn on debugging output.",
275
newVector.addElement(new Option(
276
"\tThe number of times each instance is classified.\n"
278
"l", 1, "-l <num>"));
279
newVector.addElement(new Option(
280
"\tThe average proportion of instances common between any two training sets",
281
"p", 1, "-p <proportion of objects in common>"));
282
newVector.addElement(new Option(
283
"\tThe random number seed used.",
284
"s", 1, "-s <seed>"));
285
newVector.addElement(new Option(
286
"\tThe name of the arff file used for the decomposition.",
287
"t", 1, "-t <name of arff file>"));
288
newVector.addElement(new Option(
289
"\tThe number of instances in the training set.",
290
"T", 1, "-T <number of instances in training set>"));
291
newVector.addElement(new Option(
292
"\tFull class name of the learner used in the decomposition.\n"
293
+"\teg: weka.classifiers.bayes.NaiveBayes",
294
"W", 1, "-W <classifier class name>"));
296
if ((m_Classifier != null) &&
297
(m_Classifier instanceof OptionHandler)) {
298
newVector.addElement(new Option(
300
"", 0, "\nOptions specific to learner "
301
+ m_Classifier.getClass().getName()
303
Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
304
while (enu.hasMoreElements()) {
305
newVector.addElement(enu.nextElement());
308
return newVector.elements();
313
* Sets the OptionHandler's options using the given list. All options
314
* will be set (or reset) during this call (i.e. incremental setting
315
* of options is not possible). <p/>
317
<!-- options-start -->
318
* Valid options are: <p/>
320
* <pre> -c <class index>
321
* The index of the class attribute.
322
* (default last)</pre>
325
* Turn on debugging output.</pre>
327
* <pre> -l <num>
328
* The number of times each instance is classified.
331
* <pre> -p <proportion of objects in common>
332
* The average proportion of instances common between any two training sets</pre>
334
* <pre> -s <seed>
335
* The random number seed used.</pre>
337
* <pre> -t <name of arff file>
338
* The name of the arff file used for the decomposition.</pre>
340
* <pre> -T <number of instances in training set>
341
* The number of instances in the training set.</pre>
343
* <pre> -W <classifier class name>
344
* Full class name of the learner used in the decomposition.
345
* eg: weka.classifiers.bayes.NaiveBayes</pre>
348
* Options specific to learner weka.classifiers.rules.ZeroR:
352
* If set, classifier is run in debug mode and
353
* may output additional info to the console</pre>
357
* @param options the list of options as an array of strings
358
* @throws Exception if an option is not supported
360
public void setOptions(String[] options) throws Exception {
361
setDebug(Utils.getFlag('D', options));
363
String classIndex = Utils.getOption('c', options);
364
if (classIndex.length() != 0) {
365
if (classIndex.toLowerCase().equals("last")) {
367
} else if (classIndex.toLowerCase().equals("first")) {
370
setClassIndex(Integer.parseInt(classIndex));
376
String classifyIterations = Utils.getOption('l', options);
377
if (classifyIterations.length() != 0) {
378
setClassifyIterations(Integer.parseInt(classifyIterations));
380
setClassifyIterations(10);
383
String prob = Utils.getOption('p', options);
384
if (prob.length() != 0) {
385
setP( Double.parseDouble(prob));
389
//throw new Exception("A proportion must be specified" + " with a -p option.");
391
String seedString = Utils.getOption('s', options);
392
if (seedString.length() != 0) {
393
setSeed(Integer.parseInt(seedString));
398
String dataFile = Utils.getOption('t', options);
399
if (dataFile.length() != 0) {
400
setDataFileName(dataFile);
402
throw new Exception("An arff file must be specified"
403
+ " with the -t option.");
406
String trainSize = Utils.getOption('T', options);
407
if (trainSize.length() != 0) {
408
setTrainSize(Integer.parseInt(trainSize));
412
//throw new Exception("A training set size must be specified" + " with a -T option.");
414
String classifierName = Utils.getOption('W', options);
415
if (classifierName.length() != 0) {
416
setClassifier(Classifier.forName(classifierName, Utils.partitionOptions(options)));
418
throw new Exception("A learner must be specified with the -W option.");
423
* Gets the current settings of the CheckClassifier.
425
* @return an array of strings suitable for passing to setOptions
427
public String [] getOptions() {
429
String [] classifierOptions = new String [0];
430
if ((m_Classifier != null) &&
431
(m_Classifier instanceof OptionHandler)) {
432
classifierOptions = ((OptionHandler)m_Classifier).getOptions();
434
String [] options = new String [classifierOptions.length + 14];
437
options[current++] = "-D";
439
options[current++] = "-c"; options[current++] = "" + getClassIndex();
440
options[current++] = "-l"; options[current++] = "" + getClassifyIterations();
441
options[current++] = "-p"; options[current++] = "" + getP();
442
options[current++] = "-s"; options[current++] = "" + getSeed();
443
if (getDataFileName() != null) {
444
options[current++] = "-t"; options[current++] = "" + getDataFileName();
446
options[current++] = "-T"; options[current++] = "" + getTrainSize();
447
if (getClassifier() != null) {
448
options[current++] = "-W";
449
options[current++] = getClassifier().getClass().getName();
452
options[current++] = "--";
453
System.arraycopy(classifierOptions, 0, options, current,
454
classifierOptions.length);
455
current += classifierOptions.length;
456
while (current < options.length) {
457
options[current++] = "";
463
* Sets the classifier being analysed
465
* @param newClassifier the Classifier to use.
467
public void setClassifier(Classifier newClassifier) {
469
m_Classifier = newClassifier;
473
* Gets the classifier being analysed
475
* @return the classifier being analysed.
477
public Classifier getClassifier() {
483
* Sets debugging mode
485
* @param debug true if debug output should be printed
487
public void setDebug(boolean debug) {
493
* Gets whether debugging is turned on
495
* @return true if debugging output is on
497
public boolean getDebug() {
504
* Sets the random number seed
506
* @param seed the random number seed
508
public void setSeed(int seed) {
514
* Gets the random number seed
516
* @return the random number seed
518
public int getSeed() {
524
* Sets the number of times an instance is classified
526
* @param classifyIterations number of times an instance is classified
528
public void setClassifyIterations(int classifyIterations) {
530
m_ClassifyIterations = classifyIterations;
534
* Gets the number of times an instance is classified
536
* @return the maximum number of times an instance is classified
538
public int getClassifyIterations() {
540
return m_ClassifyIterations;
544
* Sets the name of the dataset file.
546
* @param dataFileName name of dataset file.
548
public void setDataFileName(String dataFileName) {
550
m_DataFileName = dataFileName;
554
* Get the name of the data file used for the decomposition
556
* @return the name of the data file
558
public String getDataFileName() {
560
return m_DataFileName;
564
* Get the index (starting from 1) of the attribute used as the class.
566
* @return the index of the class attribute
568
public int getClassIndex() {
570
return m_ClassIndex + 1;
574
* Sets the index (starting from 1) of the attribute used as the class.
576
* @param classIndex the index (starting from 1) of the class attribute
578
public void setClassIndex(int classIndex) {
580
m_ClassIndex = classIndex - 1;
584
* Get the calculated bias squared according to the Kohavi and Wolpert definition
586
* @return the bias squared
588
public double getKWBias() {
594
* Get the calculated bias according to the Webb definition
599
public double getWBias() {
606
* Get the calculated variance according to the Kohavi and Wolpert definition
608
* @return the variance
610
public double getKWVariance() {
616
* Get the calculated variance according to the Webb definition
618
* @return the variance according to Webb
621
public double getWVariance() {
627
* Get the calculated sigma according to the Kohavi and Wolpert definition
632
public double getKWSigma() {
638
* Set the training size.
640
* @param size the size of the training set
643
public void setTrainSize(int size) {
649
* Get the training size
651
* @return the size of the training set
654
public int getTrainSize() {
660
* Set the proportion of instances that are common between two training sets
661
* used to train a classifier.
663
* @param proportion the proportion of instances that are common between training
667
public void setP(double proportion) {
673
* Get the proportion of instances that are common between two training sets.
675
* @return the proportion
678
public double getP() {
684
* Get the calculated error rate
686
* @return the error rate
688
public double getError() {
694
* Carry out the bias-variance decomposition using the sub-sampled cross-validation method.
696
* @throws Exception if the decomposition couldn't be carried out
698
public void decompose() throws Exception {
703
int tps; // training pool size, size of segment E.
704
int k; // number of folds in segment E.
705
int q; // number of segments of size tps.
707
dataReader = new BufferedReader(new FileReader(m_DataFileName)); //open file
708
data = new Instances(dataReader); // encapsulate in wrapper class called weka.Instances()
710
if (m_ClassIndex < 0) {
711
data.setClassIndex(data.numAttributes() - 1);
713
data.setClassIndex(m_ClassIndex);
716
if (data.classAttribute().type() != Attribute.NOMINAL) {
717
throw new Exception("Class attribute must be nominal");
719
int numClasses = data.numClasses();
721
data.deleteWithMissingClass();
722
if ( data.checkForStringAttributes() ) {
723
throw new Exception("Can't handle string attributes!");
726
// Dataset size must be greater than 2
727
if ( data.numInstances() <= 2 ){
728
throw new Exception("Dataset size must be greater than 2.");
731
if ( m_TrainSize == -1 ){ // default value
732
m_TrainSize = (int) Math.floor( (double) data.numInstances() / 2.0 );
733
}else if ( m_TrainSize < 0 || m_TrainSize >= data.numInstances() - 1 ) { // Check if 0 < training Size < D - 1
734
throw new Exception("Training set size of "+m_TrainSize+" is invalid.");
737
if ( m_P == -1 ){ // default value
738
m_P = (double) m_TrainSize / ( (double)data.numInstances() - 1 );
739
}else if ( m_P < ( m_TrainSize / ( (double)data.numInstances() - 1 ) ) || m_P >= 1.0 ) { //Check if p is in range: m/(|D|-1) <= p < 1.0
740
throw new Exception("Proportion is not in range: "+ (m_TrainSize / ((double) data.numInstances() - 1 )) +" <= p < 1.0 ");
743
//roundup tps from double to integer
744
tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 );
745
k = (int) Math.ceil( tps / (tps - (double) m_TrainSize));
747
// number of folds cannot be more than the number of instances in the training pool
749
throw new Exception("The required number of folds is too many."
750
+ "Change p or the size of the training set.");
753
// calculate the number of segments, round down.
754
q = (int) Math.floor( (double) data.numInstances() / (double)tps );
756
//create confusion matrix, columns = number of instances in data set, as all will be used, by rows = number of classes.
757
double [][] instanceProbs = new double [data.numInstances()][numClasses];
758
int [][] foldIndex = new int [ k ][ 2 ];
759
Vector segmentList = new Vector(q + 1);
762
Random random = new Random(m_Seed);
764
data.randomize(random);
766
//create index arrays for different segments
768
int currentDataIndex = 0;
770
for( int count = 1; count <= (q + 1); count++ ){
772
int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ];
773
for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
775
segmentIndex[index] = currentDataIndex;
777
segmentList.add(segmentIndex);
779
int [] segmentIndex = new int [ tps ];
781
for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
782
segmentIndex[index] = currentDataIndex;
784
segmentList.add(segmentIndex);
788
int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
790
//foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3, 3.3->4)
791
int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
795
for( int count = 0; count < k; count ++){
796
if( remainder != 0 && count == remainder ){
799
foldIndex[count][0] = index;
800
foldIndex[count][1] = foldSize;
804
for( int l = 0; l < m_ClassifyIterations; l++) {
806
for(int i = 1; i <= q; i++){
808
int [] currentSegment = (int[]) segmentList.get(i - 1);
810
randomize(currentSegment, random);
812
//CROSS FOLD VALIDATION for current Segment
813
for( int j = 1; j <= k; j++){
816
for(int foldNum = 1; foldNum <= k; foldNum++){
819
int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
820
foldSize = foldIndex[ foldNum - 1 ][ 1 ];
821
int endFoldIndex = startFoldIndex + foldSize - 1;
823
for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
826
TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
828
TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
834
TP.randomize(random);
836
if( getTrainSize() > TP.numInstances() ){
837
throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
838
+ TP.numInstances() );
841
Instances train = new Instances(TP, 0, m_TrainSize);
843
Classifier current = Classifier.makeCopy(m_Classifier);
844
current.buildClassifier(train); // create a clssifier using the instances in train.
846
int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
847
int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
848
int endTestIndex = currentTestIndex + testFoldSize - 1;
850
while( currentTestIndex <= endTestIndex ){
852
Instance testInst = data.instance( currentSegment[currentTestIndex] );
853
int pred = (int)current.classifyInstance( testInst );
856
if(pred != testInst.classValue()) {
857
m_Error++; // add 1 to mis-classifications.
859
instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
863
if( i == 1 && j == 1){
864
int[] segmentElast = (int[])segmentList.lastElement();
865
for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
866
Instance testInst = data.instance( segmentElast[currentIndex] );
867
int pred = (int)current.classifyInstance( testInst );
868
if(pred != testInst.classValue()) {
869
m_Error++; // add 1 to mis-classifications.
872
instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
879
m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
888
for (int i = 0; i < data.numInstances(); i++) {
890
Instance current = data.instance( i );
892
double [] predProbs = instanceProbs[ i ];
893
double pActual, pPred;
894
double bsum = 0, vsum = 0, ssum = 0;
895
double wBSum = 0, wVSum = 0;
897
Vector centralTendencies = findCentralTendencies( predProbs );
899
if( centralTendencies == null ){
900
throw new Exception("Central tendency was null.");
903
for (int j = 0; j < numClasses; j++) {
904
pActual = (current.classValue() == j) ? 1 : 0;
905
pPred = predProbs[j] / m_ClassifyIterations;
906
bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
907
vsum += pPred * pPred;
908
ssum += pActual * pActual;
912
m_KWVariance += (1 - vsum);
913
m_KWSigma += (1 - ssum);
915
for( int count = 0; count < centralTendencies.size(); count++ ) {
918
int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
920
// For a single instance xi, find the bias and variance.
921
for (int j = 0; j < numClasses; j++) {
924
if( j != (int)current.classValue() && j == centralTendency ) {
927
if( j != (int)current.classValue() && j != centralTendency ) {
932
wBSum += (double) wB;
933
wVSum += (double) wV;
936
// calculate bias by dividing bSum by the number of central tendencies and
937
// total number of instances. (effectively finding the average and dividing
938
// by the number of instances to get the normalised probability).
940
m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
941
// calculate variance by dividing vSum by the total number of iterations
942
m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
946
m_KWBias /= (2.0 * (double) data.numInstances());
947
m_KWVariance /= (2.0 * (double) data.numInstances());
948
m_KWSigma /= (2.0 * (double) data.numInstances());
950
// bias = bias / number of data instances
951
m_WBias /= (double) data.numInstances();
952
// variance = variance / number of data instances.
953
m_WVariance /= (double) data.numInstances();
956
System.err.println("Decomposition finished");
961
/** Finds the central tendency, given the classifications for an instance.
963
* Where the central tendency is defined as the class that was most commonly
964
* selected for a given instance.<p>
966
* For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
967
* so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
968
* and '3' = 3 times. Then the central tendency is '2'. <p>
970
* However, it is important to note that this method returns a list of all classes
971
* that have the highest number of classifications.
973
* In cases where there are several classes with the largest number of classifications, then
974
* all of these classes are returned. For example if 'x' is classified '1' = 4 times,
975
* '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
977
* @param predProbs the array of classifications for a single instance.
979
* @return a Vector containing Integer objects which store the class(s) which
980
* are the central tendency.
982
public Vector findCentralTendencies(double[] predProbs) {
984
int centralTValue = 0;
985
int currentValue = 0;
986
//array to store the list of classes the have the greatest number of classifictions.
987
Vector centralTClasses;
989
centralTClasses = new Vector(); //create an array with size of the number of classes.
991
// Go through array, finding the central tendency.
992
for( int i = 0; i < predProbs.length; i++) {
993
currentValue = (int) predProbs[i];
994
// if current value is greater than the central tendency value then
995
// clear vector and add new class to vector array.
996
if( currentValue > centralTValue) {
997
centralTClasses.clear();
998
centralTClasses.addElement( new Integer(i) );
999
centralTValue = currentValue;
1000
} else if( currentValue != 0 && currentValue == centralTValue) {
1001
centralTClasses.addElement( new Integer(i) );
1004
//return all classes that have the greatest number of classifications.
1005
if( centralTValue != 0){
1006
return centralTClasses;
1014
* Returns description of the bias-variance decomposition results.
1016
* @return the bias-variance decomposition results as a string
1018
public String toString() {
1020
String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n" +
1021
"with subsampling.\n";
1023
if (getClassifier() == null) {
1024
return "Invalid setup";
1027
result += "\nClassifier : " + getClassifier().getClass().getName();
1028
if (getClassifier() instanceof OptionHandler) {
1029
result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
1031
result += "\nData File : " + getDataFileName();
1032
result += "\nClass Index : ";
1033
if (getClassIndex() == 0) {
1036
result += getClassIndex();
1038
result += "\nIterations : " + getClassifyIterations();
1039
result += "\np : " + getP();
1040
result += "\nTraining Size : " + getTrainSize();
1041
result += "\nSeed : " + getSeed();
1043
result += "\n\nDefinition : " +"Kohavi and Wolpert";
1044
result += "\nError :" + Utils.doubleToString(getError(), 4);
1045
result += "\nBias^2 :" + Utils.doubleToString(getKWBias(), 4);
1046
result += "\nVariance :" + Utils.doubleToString(getKWVariance(), 4);
1047
result += "\nSigma^2 :" + Utils.doubleToString(getKWSigma(), 4);
1049
result += "\n\nDefinition : " +"Webb";
1050
result += "\nError :" + Utils.doubleToString(getError(), 4);
1051
result += "\nBias :" + Utils.doubleToString(getWBias(), 4);
1052
result += "\nVariance :" + Utils.doubleToString(getWVariance(), 4);
1060
* Test method for this class
1062
* @param args the command line arguments
1064
public static void main(String [] args) {
1067
BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
1070
bvd.setOptions(args);
1071
Utils.checkForRemainingOptions(args);
1072
} catch (Exception ex) {
1073
String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
1074
Enumeration enu = bvd.listOptions();
1075
while (enu.hasMoreElements()) {
1076
Option option = (Option) enu.nextElement();
1077
result += option.synopsis() + "\n" + option.description() + "\n";
1079
throw new Exception(result);
1084
System.out.println(bvd.toString());
1086
} catch (Exception ex) {
1087
System.err.println(ex.getMessage());
1093
* Accepts an array of ints and randomises the values in the array, using the
1096
*@param index is the array of integers
1097
*@param random is the random number generator used to pick the swap positions.
1099
public final void randomize(int[] index, Random random) {
1100
for( int j = index.length - 1; j > 0; j-- ){
1101
int k = random.nextInt( j + 1 );
1102
int temp = index[j];
1103
index[j] = index[k];