~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/Evaluation.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    Evaluation.java
 
19
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers;
 
24
 
 
25
import weka.classifiers.evaluation.NominalPrediction;
 
26
import weka.classifiers.evaluation.ThresholdCurve;
 
27
import weka.classifiers.xml.XMLClassifier;
 
28
import weka.core.Drawable;
 
29
import weka.core.FastVector;
 
30
import weka.core.Instance;
 
31
import weka.core.Instances;
 
32
import weka.core.Option;
 
33
import weka.core.OptionHandler;
 
34
import weka.core.Range;
 
35
import weka.core.Summarizable;
 
36
import weka.core.Utils;
 
37
import weka.core.Version;
 
38
import weka.core.converters.ConverterUtils.DataSink;
 
39
import weka.core.converters.ConverterUtils.DataSource;
 
40
import weka.core.xml.KOML;
 
41
import weka.core.xml.XMLOptions;
 
42
import weka.core.xml.XMLSerialization;
 
43
import weka.estimators.Estimator;
 
44
import weka.estimators.KernelEstimator;
 
45
 
 
46
import java.io.BufferedInputStream;
 
47
import java.io.BufferedOutputStream;
 
48
import java.io.BufferedReader;
 
49
import java.io.FileInputStream;
 
50
import java.io.FileOutputStream;
 
51
import java.io.FileReader;
 
52
import java.io.InputStream;
 
53
import java.io.ObjectInputStream;
 
54
import java.io.ObjectOutputStream;
 
55
import java.io.OutputStream;
 
56
import java.io.Reader;
 
57
import java.util.Date;
 
58
import java.util.Enumeration;
 
59
import java.util.Random;
 
60
import java.util.zip.GZIPInputStream;
 
61
import java.util.zip.GZIPOutputStream;
 
62
 
 
63
/**
 
64
 * Class for evaluating machine learning models. <p/>
 
65
 *
 
66
 * ------------------------------------------------------------------- <p/>
 
67
 *
 
68
 * General options when evaluating a learning scheme from the command-line: <p/>
 
69
 *
 
70
 * -t filename <br/>
 
71
 * Name of the file with the training data. (required) <p/>
 
72
 *
 
73
 * -T filename <br/>
 
74
 * Name of the file with the test data. If missing a cross-validation 
 
75
 * is performed. <p/>
 
76
 *
 
77
 * -c index <br/>
 
78
 * Index of the class attribute (1, 2, ...; default: last). <p/>
 
79
 *
 
80
 * -x number <br/>
 
81
 * The number of folds for the cross-validation (default: 10). <p/>
 
82
 *
 
83
 * -no-cv <br/>
 
84
 * No cross validation.  If no test file is provided, no evaluation
 
85
 * is done. <p/>
 
86
 * 
 
87
 * -split-percentage percentage <br/>
 
88
 * Sets the percentage for the train/test set split, e.g., 66. <p/>
 
89
 * 
 
90
 * -preserve-order <br/>
 
91
 * Preserves the order in the percentage split instead of randomizing
 
92
 * the data first with the seed value ('-s'). <p/>
 
93
 *
 
94
 * -s seed <br/>
 
95
 * Random number seed for the cross-validation and percentage split
 
96
 * (default: 1). <p/>
 
97
 *
 
98
 * -m filename <br/>
 
99
 * The name of a file containing a cost matrix. <p/>
 
100
 *
 
101
 * -l filename <br/>
 
102
 * Loads classifier from the given file. In case the filename ends with ".xml" 
 
103
 * the options are loaded from XML. <p/>
 
104
 *
 
105
 * -d filename <br/>
 
106
 * Saves classifier built from the training data into the given file. In case 
 
107
 * the filename ends with ".xml" the options are saved XML, not the model. <p/>
 
108
 *
 
109
 * -v <br/>
 
110
 * Outputs no statistics for the training data. <p/>
 
111
 *
 
112
 * -o <br/>
 
113
 * Outputs statistics only, not the classifier. <p/>
 
114
 * 
 
115
 * -i <br/>
 
116
 * Outputs information-retrieval statistics per class. <p/>
 
117
 *
 
118
 * -k <br/>
 
119
 * Outputs information-theoretic statistics. <p/>
 
120
 *
 
121
 * -p range <br/>
 
122
 * Outputs predictions for test instances (or the train instances if no test
 
123
 * instances provided), along with the attributes in the specified range 
 
124
 * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 
125
 * 
 
126
 * -distribution <br/>
 
127
 * Outputs the distribution instead of only the prediction
 
128
 * in conjunction with the '-p' option (only nominal classes). <p/>
 
129
 *
 
130
 * -r <br/>
 
131
 * Outputs cumulative margin distribution (and nothing else). <p/>
 
132
 *
 
133
 * -g <br/> 
 
134
 * Only for classifiers that implement "Graphable." Outputs
 
135
 * the graph representation of the classifier (and nothing
 
136
 * else). <p/>
 
137
 * 
 
138
 * -xml filename | xml-string <br/>
 
139
 * Retrieves the options from the XML-data instead of the command line. <p/>
 
140
 * 
 
141
 * -threshold-file file <br/>
 
142
 * The file to save the threshold data to.
 
143
 * The format is determined by the extensions, e.g., '.arff' for ARFF
 
144
 * format or '.csv' for CSV. <p/>
 
145
 *         
 
146
 * -threshold-label label <br/>
 
147
 * The class label to determine the threshold data for
 
148
 * (default is the first label) <p/>
 
149
 *         
 
150
 * ------------------------------------------------------------------- <p/>
 
151
 *
 
152
 * Example usage as the main of a classifier (called FunkyClassifier):
 
153
 * <code> <pre>
 
154
 * public static void main(String [] args) {
 
155
 *   runClassifier(new FunkyClassifier(), args);
 
156
 * }
 
157
 * </pre> </code> 
 
158
 * <p/>
 
159
 *
 
160
 * ------------------------------------------------------------------ <p/>
 
161
 *
 
162
 * Example usage from within an application:
 
163
 * <code> <pre>
 
164
 * Instances trainInstances = ... instances got from somewhere
 
165
 * Instances testInstances = ... instances got from somewhere
 
166
 * Classifier scheme = ... scheme got from somewhere
 
167
 *
 
168
 * Evaluation evaluation = new Evaluation(trainInstances);
 
169
 * evaluation.evaluateModel(scheme, testInstances);
 
170
 * System.out.println(evaluation.toSummaryString());
 
171
 * </pre> </code> 
 
172
 *
 
173
 *
 
174
 * @author   Eibe Frank (eibe@cs.waikato.ac.nz)
 
175
 * @author   Len Trigg (trigg@cs.waikato.ac.nz)
 
176
 * @version  $Revision: 1.83 $
 
177
 */
 
178
public class Evaluation
 
179
implements Summarizable {
 
180
 
 
181
  /** The number of classes. */
 
182
  protected int m_NumClasses;
 
183
 
 
184
  /** The number of folds for a cross-validation. */
 
185
  protected int m_NumFolds;
 
186
 
 
187
  /** The weight of all incorrectly classified instances. */
 
188
  protected double m_Incorrect;
 
189
 
 
190
  /** The weight of all correctly classified instances. */
 
191
  protected double m_Correct;
 
192
 
 
193
  /** The weight of all unclassified instances. */
 
194
  protected double m_Unclassified;
 
195
 
 
196
  /*** The weight of all instances that had no class assigned to them. */
 
197
  protected double m_MissingClass;
 
198
 
 
199
  /** The weight of all instances that had a class assigned to them. */
 
200
  protected double m_WithClass;
 
201
 
 
202
  /** Array for storing the confusion matrix. */
 
203
  protected double [][] m_ConfusionMatrix;
 
204
 
 
205
  /** The names of the classes. */
 
206
  protected String [] m_ClassNames;
 
207
 
 
208
  /** Is the class nominal or numeric? */
 
209
  protected boolean m_ClassIsNominal;
 
210
 
 
211
  /** The prior probabilities of the classes */
 
212
  protected double [] m_ClassPriors;
 
213
 
 
214
  /** The sum of counts for priors */
 
215
  protected double m_ClassPriorsSum;
 
216
 
 
217
  /** The cost matrix (if given). */
 
218
  protected CostMatrix m_CostMatrix;
 
219
 
 
220
  /** The total cost of predictions (includes instance weights) */
 
221
  protected double m_TotalCost;
 
222
 
 
223
  /** Sum of errors. */
 
224
  protected double m_SumErr;
 
225
 
 
226
  /** Sum of absolute errors. */
 
227
  protected double m_SumAbsErr;
 
228
 
 
229
  /** Sum of squared errors. */
 
230
  protected double m_SumSqrErr;
 
231
 
 
232
  /** Sum of class values. */
 
233
  protected double m_SumClass;
 
234
 
 
235
  /** Sum of squared class values. */
 
236
  protected double m_SumSqrClass;
 
237
 
 
238
  /*** Sum of predicted values. */
 
239
  protected double m_SumPredicted;
 
240
 
 
241
  /** Sum of squared predicted values. */
 
242
  protected double m_SumSqrPredicted;
 
243
 
 
244
  /** Sum of predicted * class values. */
 
245
  protected double m_SumClassPredicted;
 
246
 
 
247
  /** Sum of absolute errors of the prior */
 
248
  protected double m_SumPriorAbsErr;
 
249
 
 
250
  /** Sum of absolute errors of the prior */
 
251
  protected double m_SumPriorSqrErr;
 
252
 
 
253
  /** Total Kononenko & Bratko Information */
 
254
  protected double m_SumKBInfo;
 
255
 
 
256
  /*** Resolution of the margin histogram */
 
257
  protected static int k_MarginResolution = 500;
 
258
 
 
259
  /** Cumulative margin distribution */
 
260
  protected double m_MarginCounts [];
 
261
 
 
262
  /** Number of non-missing class training instances seen */
 
263
  protected int m_NumTrainClassVals;
 
264
 
 
265
  /** Array containing all numeric training class values seen */
 
266
  protected double [] m_TrainClassVals;
 
267
 
 
268
  /** Array containing all numeric training class weights */
 
269
  protected double [] m_TrainClassWeights;
 
270
 
 
271
  /** Numeric class error estimator for prior */
 
272
  protected Estimator m_PriorErrorEstimator;
 
273
 
 
274
  /** Numeric class error estimator for scheme */
 
275
  protected Estimator m_ErrorEstimator;
 
276
 
 
277
  /**
 
278
   * The minimum probablility accepted from an estimator to avoid
 
279
   * taking log(0) in Sf calculations.
 
280
   */
 
281
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;
 
282
 
 
283
  /** Total entropy of prior predictions */
 
284
  protected double m_SumPriorEntropy;
 
285
 
 
286
  /** Total entropy of scheme predictions */
 
287
  protected double m_SumSchemeEntropy;
 
288
 
 
289
  /** The list of predictions that have been generated (for computing AUC) */
 
290
  private FastVector m_Predictions;
 
291
 
 
292
  /** enables/disables the use of priors, e.g., if no training set is
 
293
   * present in case of de-serialized schemes */
 
294
  protected boolean m_NoPriors = false;
 
295
 
 
296
  /**
 
297
   * Initializes all the counters for the evaluation. 
 
298
   * Use <code>useNoPriors()</code> if the dataset is the test set and you
 
299
   * can't initialize with the priors from the training set via 
 
300
   * <code>setPriors(Instances)</code>.
 
301
   *
 
302
   * @param data        set of training instances, to get some header 
 
303
   *                    information and prior class distribution information
 
304
   * @throws Exception  if the class is not defined
 
305
   * @see               #useNoPriors()
 
306
   * @see               #setPriors(Instances)
 
307
   */
 
308
  public Evaluation(Instances data) throws Exception {
 
309
 
 
310
    this(data, null);
 
311
  }
 
312
 
 
313
  /**
 
314
   * Initializes all the counters for the evaluation and also takes a
 
315
   * cost matrix as parameter.
 
316
   * Use <code>useNoPriors()</code> if the dataset is the test set and you
 
317
   * can't initialize with the priors from the training set via 
 
318
   * <code>setPriors(Instances)</code>.
 
319
   *
 
320
   * @param data        set of training instances, to get some header 
 
321
   *                    information and prior class distribution information
 
322
   * @param costMatrix  the cost matrix---if null, default costs will be used
 
323
   * @throws Exception  if cost matrix is not compatible with 
 
324
   *                    data, the class is not defined or the class is numeric
 
325
   * @see               #useNoPriors()
 
326
   * @see               #setPriors(Instances)
 
327
   */
 
328
  public Evaluation(Instances data, CostMatrix costMatrix) 
 
329
  throws Exception {
 
330
 
 
331
    m_NumClasses = data.numClasses();
 
332
    m_NumFolds = 1;
 
333
    m_ClassIsNominal = data.classAttribute().isNominal();
 
334
 
 
335
    if (m_ClassIsNominal) {
 
336
      m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
 
337
      m_ClassNames = new String [m_NumClasses];
 
338
      for(int i = 0; i < m_NumClasses; i++) {
 
339
        m_ClassNames[i] = data.classAttribute().value(i);
 
340
      }
 
341
    }
 
342
    m_CostMatrix = costMatrix;
 
343
    if (m_CostMatrix != null) {
 
344
      if (!m_ClassIsNominal) {
 
345
        throw new Exception("Class has to be nominal if cost matrix " + 
 
346
        "given!");
 
347
      }
 
348
      if (m_CostMatrix.size() != m_NumClasses) {
 
349
        throw new Exception("Cost matrix not compatible with data!");
 
350
      }
 
351
    }
 
352
    m_ClassPriors = new double [m_NumClasses];
 
353
    setPriors(data);
 
354
    m_MarginCounts = new double [k_MarginResolution + 1];
 
355
  }
 
356
 
 
357
  /**
 
358
   * Returns the area under ROC for those predictions that have been collected
 
359
   * in the evaluateClassifier(Classifier, Instances) method. Returns 
 
360
   * Instance.missingValue() if the area is not available.
 
361
   *
 
362
   * @param classIndex the index of the class to consider as "positive"
 
363
   * @return the area under the ROC curve or not a number
 
364
   */
 
365
  public double areaUnderROC(int classIndex) {
 
366
 
 
367
    // Check if any predictions have been collected
 
368
    if (m_Predictions == null) {
 
369
      return Instance.missingValue();
 
370
    } else {
 
371
      ThresholdCurve tc = new ThresholdCurve();
 
372
      Instances result = tc.getCurve(m_Predictions, classIndex);
 
373
      return ThresholdCurve.getROCArea(result);
 
374
    }
 
375
  }
 
376
 
 
377
  /**
 
378
   * Returns a copy of the confusion matrix.
 
379
   *
 
380
   * @return a copy of the confusion matrix as a two-dimensional array
 
381
   */
 
382
  public double[][] confusionMatrix() {
 
383
 
 
384
    double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
 
385
 
 
386
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
 
387
      newMatrix[i] = new double[m_ConfusionMatrix[i].length];
 
388
      System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
 
389
          m_ConfusionMatrix[i].length);
 
390
    }
 
391
    return newMatrix;
 
392
  }
 
393
 
 
394
  /**
 
395
   * Performs a (stratified if class is nominal) cross-validation 
 
396
   * for a classifier on a set of instances. Now performs
 
397
   * a deep copy of the classifier before each call to 
 
398
   * buildClassifier() (just in case the classifier is not
 
399
   * initialized properly).
 
400
   *
 
401
   * @param classifier the classifier with any options set.
 
402
   * @param data the data on which the cross-validation is to be 
 
403
   * performed 
 
404
   * @param numFolds the number of folds for the cross-validation
 
405
   * @param random random number generator for randomization 
 
406
   * @throws Exception if a classifier could not be generated 
 
407
   * successfully or the class is not defined
 
408
   */
 
409
  public void crossValidateModel(Classifier classifier,
 
410
      Instances data, int numFolds, Random random) 
 
411
  throws Exception {
 
412
 
 
413
    // Make a copy of the data we can reorder
 
414
    data = new Instances(data);
 
415
    data.randomize(random);
 
416
    if (data.classAttribute().isNominal()) {
 
417
      data.stratify(numFolds);
 
418
    }
 
419
    // Do the folds
 
420
    for (int i = 0; i < numFolds; i++) {
 
421
      Instances train = data.trainCV(numFolds, i, random);
 
422
      setPriors(train);
 
423
      Classifier copiedClassifier = Classifier.makeCopy(classifier);
 
424
      copiedClassifier.buildClassifier(train);
 
425
      Instances test = data.testCV(numFolds, i);
 
426
      evaluateModel(copiedClassifier, test);
 
427
    }
 
428
    m_NumFolds = numFolds;
 
429
  }
 
430
 
 
431
  /**
 
432
   * Performs a (stratified if class is nominal) cross-validation 
 
433
   * for a classifier on a set of instances.
 
434
   *
 
435
   * @param classifierString a string naming the class of the classifier
 
436
   * @param data the data on which the cross-validation is to be 
 
437
   * performed 
 
438
   * @param numFolds the number of folds for the cross-validation
 
439
   * @param options the options to the classifier. Any options
 
440
   * @param random the random number generator for randomizing the data
 
441
   * accepted by the classifier will be removed from this array.
 
442
   * @throws Exception if a classifier could not be generated 
 
443
   * successfully or the class is not defined
 
444
   */
 
445
  public void crossValidateModel(String classifierString,
 
446
      Instances data, int numFolds,
 
447
      String[] options, Random random) 
 
448
  throws Exception {
 
449
 
 
450
    crossValidateModel(Classifier.forName(classifierString, options),
 
451
        data, numFolds, random);
 
452
  }
 
453
 
 
454
  /**
 
455
   * Evaluates a classifier with the options given in an array of
 
456
   * strings. <p/>
 
457
   *
 
458
   * Valid options are: <p/>
 
459
   *
 
460
   * -t filename <br/>
 
461
   * Name of the file with the training data. (required) <p/>
 
462
   *
 
463
   * -T filename <br/>
 
464
   * Name of the file with the test data. If missing a cross-validation 
 
465
   * is performed. <p/>
 
466
   *
 
467
   * -c index <br/>
 
468
   * Index of the class attribute (1, 2, ...; default: last). <p/>
 
469
   *
 
470
   * -x number <br/>
 
471
   * The number of folds for the cross-validation (default: 10). <p/>
 
472
   *
 
473
   * -no-cv <br/>
 
474
   * No cross validation.  If no test file is provided, no evaluation
 
475
   * is done. <p/>
 
476
   * 
 
477
   * -split-percentage percentage <br/>
 
478
   * Sets the percentage for the train/test set split, e.g., 66. <p/>
 
479
   * 
 
480
   * -preserve-order <br/>
 
481
   * Preserves the order in the percentage split instead of randomizing
 
482
   * the data first with the seed value ('-s'). <p/>
 
483
   *
 
484
   * -s seed <br/>
 
485
   * Random number seed for the cross-validation and percentage split
 
486
   * (default: 1). <p/>
 
487
   *
 
488
   * -m filename <br/>
 
489
   * The name of a file containing a cost matrix. <p/>
 
490
   *
 
491
   * -l filename <br/>
 
492
   * Loads classifier from the given file. In case the filename ends with
 
493
   * ".xml" the options are loaded from XML. <p/>
 
494
   *
 
495
   * -d filename <br/>
 
496
   * Saves classifier built from the training data into the given file. In case 
 
497
   * the filename ends with ".xml" the options are saved XML, not the model. <p/>
 
498
   *
 
499
   * -v <br/>
 
500
   * Outputs no statistics for the training data. <p/>
 
501
   *
 
502
   * -o <br/>
 
503
   * Outputs statistics only, not the classifier. <p/>
 
504
   * 
 
505
   * -i <br/>
 
506
   * Outputs detailed information-retrieval statistics per class. <p/>
 
507
   *
 
508
   * -k <br/>
 
509
   * Outputs information-theoretic statistics. <p/>
 
510
   *
 
511
   * -p range <br/>
 
512
   * Outputs predictions for test instances (or the train instances if no test
 
513
   * instances provided), along with the attributes in the specified range (and 
 
514
   *  nothing else). Use '-p 0' if no attributes are desired. <p/>
 
515
   *
 
516
   * -distribution <br/>
 
517
   * Outputs the distribution instead of only the prediction
 
518
   * in conjunction with the '-p' option (only nominal classes). <p/>
 
519
   *
 
520
   * -r <br/>
 
521
   * Outputs cumulative margin distribution (and nothing else). <p/>
 
522
   *
 
523
   * -g <br/> 
 
524
   * Only for classifiers that implement "Graphable." Outputs
 
525
   * the graph representation of the classifier (and nothing
 
526
   * else). <p/>
 
527
   *
 
528
   * -xml filename | xml-string <br/>
 
529
   * Retrieves the options from the XML-data instead of the command line. <p/>
 
530
   * 
 
531
   * -threshold-file file <br/>
 
532
   * The file to save the threshold data to.
 
533
   * The format is determined by the extensions, e.g., '.arff' for ARFF
 
534
   * format or '.csv' for CSV. <p/>
 
535
   *         
 
536
   * -threshold-label label <br/>
 
537
   * The class label to determine the threshold data for
 
538
   * (default is the first label) <p/>
 
539
   *
 
540
   * @param classifierString class of machine learning classifier as a string
 
541
   * @param options the array of string containing the options
 
542
   * @throws Exception if model could not be evaluated successfully
 
543
   * @return a string describing the results 
 
544
   */
 
545
  public static String evaluateModel(String classifierString, 
 
546
      String [] options) throws Exception {
 
547
 
 
548
    Classifier classifier;       
 
549
 
 
550
    // Create classifier
 
551
    try {
 
552
      classifier = 
 
553
        (Classifier)Class.forName(classifierString).newInstance();
 
554
    } catch (Exception e) {
 
555
      throw new Exception("Can't find class with name " 
 
556
          + classifierString + '.');
 
557
    }
 
558
    return evaluateModel(classifier, options);
 
559
  }
 
560
 
 
561
  /**
 
562
   * A test method for this class. Just extracts the first command line
 
563
   * argument as a classifier class name and calls evaluateModel.
 
564
   * @param args an array of command line arguments, the first of which
 
565
   * must be the class name of a classifier.
 
566
   */
 
567
  public static void main(String [] args) {
 
568
 
 
569
    try {
 
570
      if (args.length == 0) {
 
571
        throw new Exception("The first argument must be the class name"
 
572
            + " of a classifier");
 
573
      }
 
574
      String classifier = args[0];
 
575
      args[0] = "";
 
576
      System.out.println(evaluateModel(classifier, args));
 
577
    } catch (Exception ex) {
 
578
      ex.printStackTrace();
 
579
      System.err.println(ex.getMessage());
 
580
    }
 
581
  }
 
582
 
 
583
  /**
 
584
   * Evaluates a classifier with the options given in an array of
 
585
   * strings. <p/>
 
586
   *
 
587
   * Valid options are: <p/>
 
588
   *
 
589
   * -t name of training file <br/>
 
590
   * Name of the file with the training data. (required) <p/>
 
591
   *
 
592
   * -T name of test file <br/>
 
593
   * Name of the file with the test data. If missing a cross-validation 
 
594
   * is performed. <p/>
 
595
   *
 
596
   * -c class index <br/>
 
597
   * Index of the class attribute (1, 2, ...; default: last). <p/>
 
598
   *
 
599
   * -x number of folds <br/>
 
600
   * The number of folds for the cross-validation (default: 10). <p/>
 
601
   *
 
602
   * -no-cv <br/>
 
603
   * No cross validation.  If no test file is provided, no evaluation
 
604
   * is done. <p/>
 
605
   * 
 
606
   * -split-percentage percentage <br/>
 
607
   * Sets the percentage for the train/test set split, e.g., 66. <p/>
 
608
   * 
 
609
   * -preserve-order <br/>
 
610
   * Preserves the order in the percentage split instead of randomizing
 
611
   * the data first with the seed value ('-s'). <p/>
 
612
   *
 
613
   * -s seed <br/>
 
614
   * Random number seed for the cross-validation and percentage split
 
615
   * (default: 1). <p/>
 
616
   *
 
617
   * -m file with cost matrix <br/>
 
618
   * The name of a file containing a cost matrix. <p/>
 
619
   *
 
620
   * -l filename <br/>
 
621
   * Loads classifier from the given file. In case the filename ends with
 
622
   * ".xml" the options are loaded from XML. <p/>
 
623
   *
 
624
   * -d filename <br/>
 
625
   * Saves classifier built from the training data into the given file. In case 
 
626
   * the filename ends with ".xml" the options are saved XML, not the model. <p/>
 
627
   *
 
628
   * -v <br/>
 
629
   * Outputs no statistics for the training data. <p/>
 
630
   *
 
631
   * -o <br/>
 
632
   * Outputs statistics only, not the classifier. <p/>
 
633
   * 
 
634
   * -i <br/>
 
635
   * Outputs detailed information-retrieval statistics per class. <p/>
 
636
   *
 
637
   * -k <br/>
 
638
   * Outputs information-theoretic statistics. <p/>
 
639
   *
 
640
   * -p range <br/>
 
641
   * Outputs predictions for test instances (or the train instances if no test
 
642
   * instances provided), along with the attributes in the specified range 
 
643
   * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 
644
   *
 
645
   * -distribution <br/>
 
646
   * Outputs the distribution instead of only the prediction
 
647
   * in conjunction with the '-p' option (only nominal classes). <p/>
 
648
   *
 
649
   * -r <br/>
 
650
   * Outputs cumulative margin distribution (and nothing else). <p/>
 
651
   *
 
652
   * -g <br/> 
 
653
   * Only for classifiers that implement "Graphable." Outputs
 
654
   * the graph representation of the classifier (and nothing
 
655
   * else). <p/>
 
656
   *
 
657
   * -xml filename | xml-string <br/>
 
658
   * Retrieves the options from the XML-data instead of the command line. <p/>
 
659
   *
 
660
   * @param classifier machine learning classifier
 
661
   * @param options the array of string containing the options
 
662
   * @throws Exception if model could not be evaluated successfully
 
663
   * @return a string describing the results 
 
664
   */
 
665
  public static String evaluateModel(Classifier classifier,
 
666
      String [] options) throws Exception {
 
667
 
 
668
    Instances train = null, tempTrain, test = null, template = null;
 
669
    int seed = 1, folds = 10, classIndex = -1;
 
670
    boolean noCrossValidation = false;
 
671
    String trainFileName, testFileName, sourceClass, 
 
672
    classIndexString, seedString, foldsString, objectInputFileName, 
 
673
    objectOutputFileName, attributeRangeString;
 
674
    boolean noOutput = false,
 
675
    printClassifications = false, trainStatistics = true,
 
676
    printMargins = false, printComplexityStatistics = false,
 
677
    printGraph = false, classStatistics = false, printSource = false;
 
678
    StringBuffer text = new StringBuffer();
 
679
    DataSource trainSource = null, testSource = null;
 
680
    ObjectInputStream objectInputStream = null;
 
681
    BufferedInputStream xmlInputStream = null;
 
682
    CostMatrix costMatrix = null;
 
683
    StringBuffer schemeOptionsText = null;
 
684
    Range attributesToOutput = null;
 
685
    long trainTimeStart = 0, trainTimeElapsed = 0,
 
686
    testTimeStart = 0, testTimeElapsed = 0;
 
687
    String xml = "";
 
688
    String[] optionsTmp = null;
 
689
    Classifier classifierBackup;
 
690
    Classifier classifierClassifications = null;
 
691
    boolean printDistribution = false;
 
692
    int actualClassIndex = -1;  // 0-based class index
 
693
    String splitPercentageString = "";
 
694
    int splitPercentage = -1;
 
695
    boolean preserveOrder = false;
 
696
    boolean trainSetPresent = false;
 
697
    boolean testSetPresent = false;
 
698
    String thresholdFile;
 
699
    String thresholdLabel;
 
700
 
 
701
    // help requested?
 
702
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {
 
703
      throw new Exception("\nHelp requested." + makeOptionString(classifier));
 
704
    }
 
705
    
 
706
    try {
 
707
      // do we get the input from XML instead of normal parameters?
 
708
      xml = Utils.getOption("xml", options);
 
709
      if (!xml.equals(""))
 
710
        options = new XMLOptions(xml).toArray();
 
711
 
 
712
      // is the input model only the XML-Options, i.e. w/o built model?
 
713
      optionsTmp = new String[options.length];
 
714
      for (int i = 0; i < options.length; i++)
 
715
        optionsTmp[i] = options[i];
 
716
 
 
717
      if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
 
718
        // load options from serialized data ('-l' is automatically erased!)
 
719
        XMLClassifier xmlserial = new XMLClassifier();
 
720
        Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));
 
721
        // merge options
 
722
        optionsTmp = new String[options.length + cl.getOptions().length];
 
723
        System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
 
724
        System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
 
725
        options = optionsTmp;
 
726
      }
 
727
 
 
728
      noCrossValidation = Utils.getFlag("no-cv", options);
 
729
      // Get basic options (options the same for all schemes)
 
730
      classIndexString = Utils.getOption('c', options);
 
731
      if (classIndexString.length() != 0) {
 
732
        if (classIndexString.equals("first"))
 
733
          classIndex = 1;
 
734
        else if (classIndexString.equals("last"))
 
735
          classIndex = -1;
 
736
        else
 
737
          classIndex = Integer.parseInt(classIndexString);
 
738
      }
 
739
      trainFileName = Utils.getOption('t', options); 
 
740
      objectInputFileName = Utils.getOption('l', options);
 
741
      objectOutputFileName = Utils.getOption('d', options);
 
742
      testFileName = Utils.getOption('T', options);
 
743
      foldsString = Utils.getOption('x', options);
 
744
      if (foldsString.length() != 0) {
 
745
        folds = Integer.parseInt(foldsString);
 
746
      }
 
747
      seedString = Utils.getOption('s', options);
 
748
      if (seedString.length() != 0) {
 
749
        seed = Integer.parseInt(seedString);
 
750
      }
 
751
      if (trainFileName.length() == 0) {
 
752
        if (objectInputFileName.length() == 0) {
 
753
          throw new Exception("No training file and no object "+
 
754
          "input file given.");
 
755
        } 
 
756
        if (testFileName.length() == 0) {
 
757
          throw new Exception("No training file and no test "+
 
758
          "file given.");
 
759
        }
 
760
      } else if ((objectInputFileName.length() != 0) &&
 
761
          ((!(classifier instanceof UpdateableClassifier)) ||
 
762
              (testFileName.length() == 0))) {
 
763
        throw new Exception("Classifier not incremental, or no " +
 
764
            "test file provided: can't "+
 
765
        "use both train and model file.");
 
766
      }
 
767
      try {
 
768
        if (trainFileName.length() != 0) {
 
769
          trainSetPresent = true;
 
770
          trainSource = new DataSource(trainFileName);
 
771
        }
 
772
        if (testFileName.length() != 0) {
 
773
          testSetPresent = true;
 
774
          testSource = new DataSource(testFileName);
 
775
        }
 
776
        if (objectInputFileName.length() != 0) {
 
777
          InputStream is = new FileInputStream(objectInputFileName);
 
778
          if (objectInputFileName.endsWith(".gz")) {
 
779
            is = new GZIPInputStream(is);
 
780
          }
 
781
          // load from KOML?
 
782
          if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
 
783
            objectInputStream = new ObjectInputStream(is);
 
784
            xmlInputStream    = null;
 
785
          }
 
786
          else {
 
787
            objectInputStream = null;
 
788
            xmlInputStream    = new BufferedInputStream(is);
 
789
          }
 
790
        }
 
791
      } catch (Exception e) {
 
792
        throw new Exception("Can't open file " + e.getMessage() + '.');
 
793
      }
 
794
      if (testSetPresent) {
 
795
        template = test = testSource.getStructure();
 
796
        if (classIndex != -1) {
 
797
          test.setClassIndex(classIndex - 1);
 
798
        } else {
 
799
          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
 
800
            test.setClassIndex(test.numAttributes() - 1);
 
801
        }
 
802
        actualClassIndex = test.classIndex();
 
803
      }
 
804
      else {
 
805
        // percentage split
 
806
        splitPercentageString = Utils.getOption("split-percentage", options);
 
807
        if (splitPercentageString.length() != 0) {
 
808
          if (foldsString.length() != 0)
 
809
            throw new Exception(
 
810
                "Percentage split cannot be used in conjunction with "
 
811
                + "cross-validation ('-x').");
 
812
          splitPercentage = Integer.parseInt(splitPercentageString);
 
813
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
 
814
            throw new Exception("Percentage split value needs be >0 and <100.");
 
815
        }
 
816
        else {
 
817
          splitPercentage = -1;
 
818
        }
 
819
        preserveOrder = Utils.getFlag("preserve-order", options);
 
820
        if (preserveOrder) {
 
821
          if (splitPercentage == -1)
 
822
            throw new Exception("Percentage split ('-percentage-split') is missing.");
 
823
        }
 
824
        // create new train/test sources
 
825
        if (splitPercentage > 0) {
 
826
          testSetPresent = true;
 
827
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
 
828
          if (!preserveOrder)
 
829
            tmpInst.randomize(new Random(seed));
 
830
          int trainSize = tmpInst.numInstances() * splitPercentage / 100;
 
831
          int testSize  = tmpInst.numInstances() - trainSize;
 
832
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
 
833
          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
 
834
          trainSource = new DataSource(trainInst);
 
835
          testSource  = new DataSource(testInst);
 
836
          template = test = testSource.getStructure();
 
837
          if (classIndex != -1) {
 
838
            test.setClassIndex(classIndex - 1);
 
839
          } else {
 
840
            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
 
841
              test.setClassIndex(test.numAttributes() - 1);
 
842
          }
 
843
          actualClassIndex = test.classIndex();
 
844
        }
 
845
      }
 
846
      if (trainSetPresent) {
 
847
        template = train = trainSource.getStructure();
 
848
        if (classIndex != -1) {
 
849
          train.setClassIndex(classIndex - 1);
 
850
        } else {
 
851
          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
 
852
            train.setClassIndex(train.numAttributes() - 1);
 
853
        }
 
854
        actualClassIndex = train.classIndex();
 
855
        if ((testSetPresent) && !test.equalHeaders(train)) {
 
856
          throw new IllegalArgumentException("Train and test file not compatible!");
 
857
        }
 
858
      }
 
859
      if (template == null) {
 
860
        throw new Exception("No actual dataset provided to use as template");
 
861
      }
 
862
      costMatrix = handleCostOption(
 
863
          Utils.getOption('m', options), template.numClasses());
 
864
 
 
865
      classStatistics = Utils.getFlag('i', options);
 
866
      noOutput = Utils.getFlag('o', options);
 
867
      trainStatistics = !Utils.getFlag('v', options);
 
868
      printComplexityStatistics = Utils.getFlag('k', options);
 
869
      printMargins = Utils.getFlag('r', options);
 
870
      printGraph = Utils.getFlag('g', options);
 
871
      sourceClass = Utils.getOption('z', options);
 
872
      printSource = (sourceClass.length() != 0);
 
873
      printDistribution = Utils.getFlag("distribution", options);
 
874
      thresholdFile = Utils.getOption("threshold-file", options);
 
875
      thresholdLabel = Utils.getOption("threshold-label", options);
 
876
 
 
877
      // Check -p option
 
878
      try {
 
879
        attributeRangeString = Utils.getOption('p', options);
 
880
      }
 
881
      catch (Exception e) {
 
882
        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
 
883
            "It now expects a parameter specifying a range of attributes " +
 
884
        "to list with the predictions. Use '-p 0' for none.");
 
885
      }
 
886
      if (attributeRangeString.length() != 0) {
 
887
        printClassifications = true;
 
888
        if (!attributeRangeString.equals("0")) 
 
889
          attributesToOutput = new Range(attributeRangeString);
 
890
      }
 
891
 
 
892
      if (!printClassifications && printDistribution)
 
893
        throw new Exception("Cannot print distribution without '-p' option!");
 
894
 
 
895
      // if no training file given, we don't have any priors
 
896
      if ( (!trainSetPresent) && (printComplexityStatistics) )
 
897
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
 
898
 
 
899
      // If a model file is given, we can't process 
 
900
      // scheme-specific options
 
901
      if (objectInputFileName.length() != 0) {
 
902
        Utils.checkForRemainingOptions(options);
 
903
      } else {
 
904
 
 
905
        // Set options for classifier
 
906
        if (classifier instanceof OptionHandler) {
 
907
          for (int i = 0; i < options.length; i++) {
 
908
            if (options[i].length() != 0) {
 
909
              if (schemeOptionsText == null) {
 
910
                schemeOptionsText = new StringBuffer();
 
911
              }
 
912
              if (options[i].indexOf(' ') != -1) {
 
913
                schemeOptionsText.append('"' + options[i] + "\" ");
 
914
              } else {
 
915
                schemeOptionsText.append(options[i] + " ");
 
916
              }
 
917
            }
 
918
          }
 
919
          ((OptionHandler)classifier).setOptions(options);
 
920
        }
 
921
      }
 
922
      Utils.checkForRemainingOptions(options);
 
923
    } catch (Exception e) {
 
924
      throw new Exception("\nWeka exception: " + e.getMessage()
 
925
          + makeOptionString(classifier));
 
926
    }
 
927
 
 
928
    // Setup up evaluation objects
 
929
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
 
930
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
 
931
 
 
932
    // disable use of priors if no training file given
 
933
    if (!trainSetPresent)
 
934
      testingEvaluation.useNoPriors();
 
935
 
 
936
    if (objectInputFileName.length() != 0) {
 
937
      // Load classifier from file
 
938
      if (objectInputStream != null) {
 
939
        classifier = (Classifier) objectInputStream.readObject();
 
940
        // try and read a header (if present)
 
941
        Instances savedStructure = null;
 
942
        try {
 
943
          savedStructure = (Instances) objectInputStream.readObject();
 
944
        } catch (Exception ex) {
 
945
          // don't make a fuss
 
946
        }
 
947
        if (savedStructure != null) {
 
948
          // test for compatibility with template
 
949
          if (!template.equalHeaders(savedStructure)) {
 
950
            throw new Exception("training and test set are not compatible");
 
951
          }
 
952
        }
 
953
        objectInputStream.close();
 
954
      }
 
955
      else {
 
956
        // whether KOML is available has already been checked (objectInputStream would null otherwise)!
 
957
        classifier = (Classifier) KOML.read(xmlInputStream);
 
958
        xmlInputStream.close();
 
959
      }
 
960
    }
 
961
 
 
962
    // backup of fully setup classifier for cross-validation
 
963
    classifierBackup = Classifier.makeCopy(classifier);
 
964
 
 
965
    // Build the classifier if no object file provided
 
966
    if ((classifier instanceof UpdateableClassifier) &&
 
967
        (testSetPresent) &&
 
968
        (costMatrix == null) &&
 
969
        (trainSetPresent)) {
 
970
 
 
971
      // Build classifier incrementally
 
972
      trainingEvaluation.setPriors(train);
 
973
      testingEvaluation.setPriors(train);
 
974
      trainTimeStart = System.currentTimeMillis();
 
975
      if (objectInputFileName.length() == 0) {
 
976
        classifier.buildClassifier(train);
 
977
      }
 
978
      Instance trainInst;
 
979
      while (trainSource.hasMoreElements(train)) {
 
980
        trainInst = trainSource.nextElement(train);
 
981
        trainingEvaluation.updatePriors(trainInst);
 
982
        testingEvaluation.updatePriors(trainInst);
 
983
        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
 
984
      }
 
985
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
 
986
    } else if (objectInputFileName.length() == 0) {
 
987
      // Build classifier in one go
 
988
      tempTrain = trainSource.getDataSet(actualClassIndex);
 
989
      trainingEvaluation.setPriors(tempTrain);
 
990
      testingEvaluation.setPriors(tempTrain);
 
991
      trainTimeStart = System.currentTimeMillis();
 
992
      classifier.buildClassifier(tempTrain);
 
993
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
 
994
    } 
 
995
 
 
996
    // backup of fully trained classifier for printing the classifications
 
997
    if (printClassifications)
 
998
      classifierClassifications = Classifier.makeCopy(classifier);
 
999
 
 
1000
    // Save the classifier if an object output file is provided
 
1001
    if (objectOutputFileName.length() != 0) {
 
1002
      OutputStream os = new FileOutputStream(objectOutputFileName);
 
1003
      // binary
 
1004
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
 
1005
        if (objectOutputFileName.endsWith(".gz")) {
 
1006
          os = new GZIPOutputStream(os);
 
1007
        }
 
1008
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
 
1009
        objectOutputStream.writeObject(classifier);
 
1010
        if (template != null) {
 
1011
          objectOutputStream.writeObject(template);
 
1012
        }
 
1013
        objectOutputStream.flush();
 
1014
        objectOutputStream.close();
 
1015
      }
 
1016
      // KOML/XML
 
1017
      else {
 
1018
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
 
1019
        if (objectOutputFileName.endsWith(".xml")) {
 
1020
          XMLSerialization xmlSerial = new XMLClassifier();
 
1021
          xmlSerial.write(xmlOutputStream, classifier);
 
1022
        }
 
1023
        else
 
1024
          // whether KOML is present has already been checked
 
1025
          // if not present -> ".koml" is interpreted as binary - see above
 
1026
          if (objectOutputFileName.endsWith(".koml")) {
 
1027
            KOML.write(xmlOutputStream, classifier);
 
1028
          }
 
1029
        xmlOutputStream.close();
 
1030
      }
 
1031
    }
 
1032
 
 
1033
    // If classifier is drawable output string describing graph
 
1034
    if ((classifier instanceof Drawable) && (printGraph)){
 
1035
      return ((Drawable)classifier).graph();
 
1036
    }
 
1037
 
 
1038
    // Output the classifier as equivalent source
 
1039
    if ((classifier instanceof Sourcable) && (printSource)){
 
1040
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
 
1041
    }
 
1042
 
 
1043
    // Output model
 
1044
    if (!(noOutput || printMargins)) {
 
1045
      if (classifier instanceof OptionHandler) {
 
1046
        if (schemeOptionsText != null) {
 
1047
          text.append("\nOptions: "+schemeOptionsText);
 
1048
          text.append("\n");
 
1049
        }
 
1050
      }
 
1051
      text.append("\n" + classifier.toString() + "\n");
 
1052
    }
 
1053
 
 
1054
    if (!printMargins && (costMatrix != null)) {
 
1055
      text.append("\n=== Evaluation Cost Matrix ===\n\n");
 
1056
      text.append(costMatrix.toString());
 
1057
    }
 
1058
 
 
1059
    // Output test instance predictions only
 
1060
    if (printClassifications) {
 
1061
      DataSource source = testSource;
 
1062
      // no test set -> use train set
 
1063
      if (source == null)
 
1064
        source = trainSource;
 
1065
      return printClassifications(classifierClassifications, new Instances(template, 0),
 
1066
          source, actualClassIndex + 1, attributesToOutput,
 
1067
          printDistribution);
 
1068
    }
 
1069
 
 
1070
    // Compute error estimate from training data
 
1071
    if ((trainStatistics) && (trainSetPresent)) {
 
1072
 
 
1073
      if ((classifier instanceof UpdateableClassifier) &&
 
1074
          (testSetPresent) &&
 
1075
          (costMatrix == null)) {
 
1076
 
 
1077
        // Classifier was trained incrementally, so we have to 
 
1078
        // reset the source.
 
1079
        trainSource.reset();
 
1080
 
 
1081
        // Incremental testing
 
1082
        train = trainSource.getStructure(actualClassIndex);
 
1083
        testTimeStart = System.currentTimeMillis();
 
1084
        Instance trainInst;
 
1085
        while (trainSource.hasMoreElements(train)) {
 
1086
          trainInst = trainSource.nextElement(train);
 
1087
          trainingEvaluation.evaluateModelOnce((Classifier)classifier, trainInst);
 
1088
        }
 
1089
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
 
1090
      } else {
 
1091
        testTimeStart = System.currentTimeMillis();
 
1092
        trainingEvaluation.evaluateModel(
 
1093
            classifier, trainSource.getDataSet(actualClassIndex));
 
1094
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
 
1095
      }
 
1096
 
 
1097
      // Print the results of the training evaluation
 
1098
      if (printMargins) {
 
1099
        return trainingEvaluation.toCumulativeMarginDistributionString();
 
1100
      } else {
 
1101
        text.append("\nTime taken to build model: "
 
1102
            + Utils.doubleToString(trainTimeElapsed / 1000.0,2)
 
1103
            + " seconds");
 
1104
        
 
1105
        if (splitPercentage > 0)
 
1106
          text.append("\nTime taken to test model on training split: ");
 
1107
        else
 
1108
          text.append("\nTime taken to test model on training data: ");
 
1109
        text.append(Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");
 
1110
 
 
1111
        if (splitPercentage > 0)
 
1112
          text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
 
1113
              + " split ===\n", printComplexityStatistics));
 
1114
        else
 
1115
          text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
 
1116
              + " data ===\n", printComplexityStatistics));
 
1117
        
 
1118
        if (template.classAttribute().isNominal()) {
 
1119
          if (classStatistics) {
 
1120
            text.append("\n\n" + trainingEvaluation.toClassDetailsString());
 
1121
          }
 
1122
          if (!noCrossValidation)
 
1123
            text.append("\n\n" + trainingEvaluation.toMatrixString());
 
1124
        }
 
1125
 
 
1126
      }
 
1127
    }
 
1128
 
 
1129
    // Compute proper error estimates
 
1130
    if (testSource != null) {
 
1131
      // Testing is on the supplied test data
 
1132
      Instance testInst;
 
1133
      while (testSource.hasMoreElements(test)) {
 
1134
        testInst = testSource.nextElement(test);
 
1135
        testingEvaluation.evaluateModelOnceAndRecordPrediction(
 
1136
            (Classifier)classifier, testInst);
 
1137
      }
 
1138
 
 
1139
      if (splitPercentage > 0)
 
1140
        text.append("\n\n" + testingEvaluation.
 
1141
            toSummaryString("=== Error on test split ===\n",
 
1142
                printComplexityStatistics));
 
1143
      else
 
1144
        text.append("\n\n" + testingEvaluation.
 
1145
            toSummaryString("=== Error on test data ===\n",
 
1146
                printComplexityStatistics));
 
1147
 
 
1148
    } else if (trainSource != null) {
 
1149
      if (!noCrossValidation) {
 
1150
        // Testing is via cross-validation on training data
 
1151
        Random random = new Random(seed);
 
1152
        // use untrained (!) classifier for cross-validation
 
1153
        classifier = Classifier.makeCopy(classifierBackup);
 
1154
        testingEvaluation.crossValidateModel(
 
1155
            classifier, trainSource.getDataSet(actualClassIndex), folds, random);
 
1156
        if (template.classAttribute().isNumeric()) {
 
1157
          text.append("\n\n\n" + testingEvaluation.
 
1158
              toSummaryString("=== Cross-validation ===\n",
 
1159
                  printComplexityStatistics));
 
1160
        } else {
 
1161
          text.append("\n\n\n" + testingEvaluation.
 
1162
              toSummaryString("=== Stratified " + 
 
1163
                  "cross-validation ===\n",
 
1164
                  printComplexityStatistics));
 
1165
        }
 
1166
      }
 
1167
    }
 
1168
    if (template.classAttribute().isNominal()) {
 
1169
      if (classStatistics) {
 
1170
        text.append("\n\n" + testingEvaluation.toClassDetailsString());
 
1171
      }
 
1172
      if (!noCrossValidation)
 
1173
        text.append("\n\n" + testingEvaluation.toMatrixString());
 
1174
    }
 
1175
 
 
1176
    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
 
1177
      int labelIndex = 0;
 
1178
      if (thresholdLabel.length() != 0)
 
1179
        labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
 
1180
      if (labelIndex == -1)
 
1181
        throw new IllegalArgumentException(
 
1182
            "Class label '" + thresholdLabel + "' is unknown!");
 
1183
      ThresholdCurve tc = new ThresholdCurve();
 
1184
      Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
 
1185
      DataSink.write(thresholdFile, result);
 
1186
    }
 
1187
    
 
1188
    return text.toString();
 
1189
  }
 
1190
 
 
1191
  /**
 
1192
   * Attempts to load a cost matrix.
 
1193
   *
 
1194
   * @param costFileName the filename of the cost matrix
 
1195
   * @param numClasses the number of classes that should be in the cost matrix
 
1196
   * (only used if the cost file is in old format).
 
1197
   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
 
1198
   * @throws Exception if an error occurs.
 
1199
   */
 
1200
  protected static CostMatrix handleCostOption(String costFileName, 
 
1201
      int numClasses) 
 
1202
  throws Exception {
 
1203
 
 
1204
    if ((costFileName != null) && (costFileName.length() != 0)) {
 
1205
      System.out.println(
 
1206
          "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
 
1207
          +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
 
1208
          +" only. For cost-sensitive *prediction*, use one of the"
 
1209
          +" cost-sensitive metaschemes such as"
 
1210
          +" weka.classifiers.meta.CostSensitiveClassifier or"
 
1211
          +" weka.classifiers.meta.MetaCost");
 
1212
 
 
1213
      Reader costReader = null;
 
1214
      try {
 
1215
        costReader = new BufferedReader(new FileReader(costFileName));
 
1216
      } catch (Exception e) {
 
1217
        throw new Exception("Can't open file " + e.getMessage() + '.');
 
1218
      }
 
1219
      try {
 
1220
        // First try as a proper cost matrix format
 
1221
        return new CostMatrix(costReader);
 
1222
      } catch (Exception ex) {
 
1223
        try {
 
1224
          // Now try as the poxy old format :-)
 
1225
          //System.err.println("Attempting to read old format cost file");
 
1226
          try {
 
1227
            costReader.close(); // Close the old one
 
1228
            costReader = new BufferedReader(new FileReader(costFileName));
 
1229
          } catch (Exception e) {
 
1230
            throw new Exception("Can't open file " + e.getMessage() + '.');
 
1231
          }
 
1232
          CostMatrix costMatrix = new CostMatrix(numClasses);
 
1233
          //System.err.println("Created default cost matrix");
 
1234
          costMatrix.readOldFormat(costReader);
 
1235
          return costMatrix;
 
1236
          //System.err.println("Read old format");
 
1237
        } catch (Exception e2) {
 
1238
          // re-throw the original exception
 
1239
          //System.err.println("Re-throwing original exception");
 
1240
          throw ex;
 
1241
        }
 
1242
      }
 
1243
    } else {
 
1244
      return null;
 
1245
    }
 
1246
  }
 
1247
 
 
1248
  /**
 
1249
   * Evaluates the classifier on a given set of instances. Note that
 
1250
   * the data must have exactly the same format (e.g. order of
 
1251
   * attributes) as the data used to train the classifier! Otherwise
 
1252
   * the results will generally be meaningless.
 
1253
   *
 
1254
   * @param classifier machine learning classifier
 
1255
   * @param data set of test instances for evaluation
 
1256
   * @return the predictions
 
1257
   * @throws Exception if model could not be evaluated 
 
1258
   * successfully 
 
1259
   */
 
1260
  public double[] evaluateModel(Classifier classifier,
 
1261
      Instances data) throws Exception {
 
1262
 
 
1263
    double predictions[] = new double[data.numInstances()];
 
1264
 
 
1265
    // Need to be able to collect predictions if appropriate (for AUC)
 
1266
 
 
1267
    for (int i = 0; i < data.numInstances(); i++) {
 
1268
      predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier, 
 
1269
          data.instance(i));
 
1270
    }
 
1271
 
 
1272
    return predictions;
 
1273
  }
 
1274
 
 
1275
  /**
 
1276
   * Evaluates the classifier on a single instance and records the
 
1277
   * prediction (if the class is nominal).
 
1278
   *
 
1279
   * @param classifier machine learning classifier
 
1280
   * @param instance the test instance to be classified
 
1281
   * @return the prediction made by the clasifier
 
1282
   * @throws Exception if model could not be evaluated 
 
1283
   * successfully or the data contains string attributes
 
1284
   */
 
1285
  public double evaluateModelOnceAndRecordPrediction(Classifier classifier,
 
1286
      Instance instance) throws Exception {
 
1287
 
 
1288
    Instance classMissing = (Instance)instance.copy();
 
1289
    double pred = 0;
 
1290
    classMissing.setDataset(instance.dataset());
 
1291
    classMissing.setClassMissing();
 
1292
    if (m_ClassIsNominal) {
 
1293
      if (m_Predictions == null) {
 
1294
        m_Predictions = new FastVector();
 
1295
      }
 
1296
      double [] dist = classifier.distributionForInstance(classMissing);
 
1297
      pred = Utils.maxIndex(dist);
 
1298
      if (dist[(int)pred] <= 0) {
 
1299
        pred = Instance.missingValue();
 
1300
      }
 
1301
      updateStatsForClassifier(dist, instance);
 
1302
      m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist, 
 
1303
          instance.weight()));
 
1304
    } else {
 
1305
      pred = classifier.classifyInstance(classMissing);
 
1306
      updateStatsForPredictor(pred, instance);
 
1307
    }
 
1308
    return pred;
 
1309
  }
 
1310
 
 
1311
  /**
 
1312
   * Evaluates the classifier on a single instance.
 
1313
   *
 
1314
   * @param classifier machine learning classifier
 
1315
   * @param instance the test instance to be classified
 
1316
   * @return the prediction made by the clasifier
 
1317
   * @throws Exception if model could not be evaluated 
 
1318
   * successfully or the data contains string attributes
 
1319
   */
 
1320
  public double evaluateModelOnce(Classifier classifier,
 
1321
      Instance instance) throws Exception {
 
1322
 
 
1323
    Instance classMissing = (Instance)instance.copy();
 
1324
    double pred = 0;
 
1325
    classMissing.setDataset(instance.dataset());
 
1326
    classMissing.setClassMissing();
 
1327
    if (m_ClassIsNominal) {
 
1328
      double [] dist = classifier.distributionForInstance(classMissing);
 
1329
      pred = Utils.maxIndex(dist);
 
1330
      if (dist[(int)pred] <= 0) {
 
1331
        pred = Instance.missingValue();
 
1332
      }
 
1333
      updateStatsForClassifier(dist, instance);
 
1334
    } else {
 
1335
      pred = classifier.classifyInstance(classMissing);
 
1336
      updateStatsForPredictor(pred, instance);
 
1337
    }
 
1338
    return pred;
 
1339
  }
 
1340
 
 
1341
  /**
 
1342
   * Evaluates the supplied distribution on a single instance.
 
1343
   *
 
1344
   * @param dist the supplied distribution
 
1345
   * @param instance the test instance to be classified
 
1346
   * @return the prediction
 
1347
   * @throws Exception if model could not be evaluated 
 
1348
   * successfully
 
1349
   */
 
1350
  public double evaluateModelOnce(double [] dist, 
 
1351
      Instance instance) throws Exception {
 
1352
    double pred;
 
1353
    if (m_ClassIsNominal) {
 
1354
      pred = Utils.maxIndex(dist);
 
1355
      if (dist[(int)pred] <= 0) {
 
1356
        pred = Instance.missingValue();
 
1357
      }
 
1358
      updateStatsForClassifier(dist, instance);
 
1359
    } else {
 
1360
      pred = dist[0];
 
1361
      updateStatsForPredictor(pred, instance);
 
1362
    }
 
1363
    return pred;
 
1364
  }
 
1365
 
 
1366
  /**
 
1367
   * Evaluates the supplied distribution on a single instance.
 
1368
   *
 
1369
   * @param dist the supplied distribution
 
1370
   * @param instance the test instance to be classified
 
1371
   * @return the prediction
 
1372
   * @throws Exception if model could not be evaluated 
 
1373
   * successfully
 
1374
   */
 
1375
  public double evaluateModelOnceAndRecordPrediction(double [] dist, 
 
1376
      Instance instance) throws Exception {
 
1377
    double pred;
 
1378
    if (m_ClassIsNominal) {
 
1379
      if (m_Predictions == null) {
 
1380
        m_Predictions = new FastVector();
 
1381
      }
 
1382
      pred = Utils.maxIndex(dist);
 
1383
      if (dist[(int)pred] <= 0) {
 
1384
        pred = Instance.missingValue();
 
1385
      }
 
1386
      updateStatsForClassifier(dist, instance);
 
1387
      m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist, 
 
1388
          instance.weight()));
 
1389
    } else {
 
1390
      pred = dist[0];
 
1391
      updateStatsForPredictor(pred, instance);
 
1392
    }
 
1393
    return pred;
 
1394
  }
 
1395
 
 
1396
  /**
 
1397
   * Evaluates the supplied prediction on a single instance.
 
1398
   *
 
1399
   * @param prediction the supplied prediction
 
1400
   * @param instance the test instance to be classified
 
1401
   * @throws Exception if model could not be evaluated 
 
1402
   * successfully
 
1403
   */
 
1404
  public void evaluateModelOnce(double prediction,
 
1405
      Instance instance) throws Exception {
 
1406
 
 
1407
    if (m_ClassIsNominal) {
 
1408
      updateStatsForClassifier(makeDistribution(prediction), 
 
1409
          instance);
 
1410
    } else {
 
1411
      updateStatsForPredictor(prediction, instance);
 
1412
    }
 
1413
  }
 
1414
 
 
1415
  /**
 
1416
   * Returns the predictions that have been collected.
 
1417
   *
 
1418
   * @return a reference to the FastVector containing the predictions
 
1419
   * that have been collected. This should be null if no predictions
 
1420
   * have been collected (e.g. if the class is numeric).
 
1421
   */
 
1422
  public FastVector predictions() {
 
1423
 
 
1424
    return m_Predictions;
 
1425
  }
 
1426
 
 
1427
  /**
 
1428
   * Wraps a static classifier in enough source to test using the weka
 
1429
   * class libraries.
 
1430
   *
 
1431
   * @param classifier a Sourcable Classifier
 
1432
   * @param className the name to give to the source code class
 
1433
   * @return the source for a static classifier that can be tested with
 
1434
   * weka libraries.
 
1435
   * @throws Exception if code-generation fails
 
1436
   */
 
1437
  public static String wekaStaticWrapper(Sourcable classifier, String className)     
 
1438
    throws Exception {
 
1439
 
 
1440
    StringBuffer result = new StringBuffer();
 
1441
    String staticClassifier = classifier.toSource(className);
 
1442
    
 
1443
    result.append("// Generated with Weka " + Version.VERSION + "\n");
 
1444
    result.append("//\n");
 
1445
    result.append("// This code is public domain and comes with no warranty.\n");
 
1446
    result.append("//\n");
 
1447
    result.append("// Timestamp: " + new Date() + "\n");
 
1448
    result.append("\n");
 
1449
    result.append("package weka.classifiers;\n");
 
1450
    result.append("\n");
 
1451
    result.append("import weka.core.Attribute;\n");
 
1452
    result.append("import weka.core.Capabilities;\n");
 
1453
    result.append("import weka.core.Capabilities.Capability;\n");
 
1454
    result.append("import weka.core.Instance;\n");
 
1455
    result.append("import weka.core.Instances;\n");
 
1456
    result.append("import weka.classifiers.Classifier;\n");
 
1457
    result.append("\n");
 
1458
    result.append("public class WekaWrapper\n");
 
1459
    result.append("  extends Classifier {\n");
 
1460
    
 
1461
    // globalInfo
 
1462
    result.append("\n");
 
1463
    result.append("  /**\n");
 
1464
    result.append("   * Returns only the toString() method.\n");
 
1465
    result.append("   *\n");
 
1466
    result.append("   * @return a string describing the classifier\n");
 
1467
    result.append("   */\n");
 
1468
    result.append("  public String globalInfo() {\n");
 
1469
    result.append("    return toString();\n");
 
1470
    result.append("  }\n");
 
1471
    
 
1472
    // getCapabilities
 
1473
    result.append("\n");
 
1474
    result.append("  /**\n");
 
1475
    result.append("   * Returns the capabilities of this classifier.\n");
 
1476
    result.append("   *\n");
 
1477
    result.append("   * @return the capabilities\n");
 
1478
    result.append("   */\n");
 
1479
    result.append("  public Capabilities getCapabilities() {\n");
 
1480
    result.append(((Classifier) classifier).getCapabilities().toSource("result", 4));
 
1481
    result.append("    return result;\n");
 
1482
    result.append("  }\n");
 
1483
    
 
1484
    // buildClassifier
 
1485
    result.append("\n");
 
1486
    result.append("  /**\n");
 
1487
    result.append("   * only checks the data against its capabilities.\n");
 
1488
    result.append("   *\n");
 
1489
    result.append("   * @param i the training data\n");
 
1490
    result.append("   */\n");
 
1491
    result.append("  public void buildClassifier(Instances i) throws Exception {\n");
 
1492
    result.append("    // can classifier handle the data?\n");
 
1493
    result.append("    getCapabilities().testWithFail(i);\n");
 
1494
    result.append("  }\n");
 
1495
    
 
1496
    // classifyInstance
 
1497
    result.append("\n");
 
1498
    result.append("  /**\n");
 
1499
    result.append("   * Classifies the given instance.\n");
 
1500
    result.append("   *\n");
 
1501
    result.append("   * @param i the instance to classify\n");
 
1502
    result.append("   * @return the classification result\n");
 
1503
    result.append("   */\n");
 
1504
    result.append("  public double classifyInstance(Instance i) throws Exception {\n");
 
1505
    result.append("    Object[] s = new Object[i.numAttributes()];\n");
 
1506
    result.append("    \n");
 
1507
    result.append("    for (int j = 0; j < s.length; j++) {\n");
 
1508
    result.append("      if (!i.isMissing(j)) {\n");
 
1509
    result.append("        if (i.attribute(j).isNominal())\n");
 
1510
    result.append("          s[j] = new String(i.stringValue(j));\n");
 
1511
    result.append("        else if (i.attribute(j).isNumeric())\n");
 
1512
    result.append("          s[j] = new Double(i.value(j));\n");
 
1513
    result.append("      }\n");
 
1514
    result.append("    }\n");
 
1515
    result.append("    \n");
 
1516
    result.append("    // set class value to missing\n");
 
1517
    result.append("    s[i.classIndex()] = null;\n");
 
1518
    result.append("    \n");
 
1519
    result.append("    return " + className + ".classify(s);\n");
 
1520
    result.append("  }\n");
 
1521
    
 
1522
    // toString
 
1523
    result.append("\n");
 
1524
    result.append("  /**\n");
 
1525
    result.append("   * Returns only the classnames and what classifier it is based on.\n");
 
1526
    result.append("   *\n");
 
1527
    result.append("   * @return a short description\n");
 
1528
    result.append("   */\n");
 
1529
    result.append("  public String toString() {\n");
 
1530
    result.append("    return \"Auto-generated classifier wrapper, based on " 
 
1531
        + classifier.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n" 
 
1532
        + "\" + this.getClass().getName() + \"/" + className + "\";\n");
 
1533
    result.append("  }\n");
 
1534
    
 
1535
    // main
 
1536
    result.append("\n");
 
1537
    result.append("  /**\n");
 
1538
    result.append("   * Runs the classfier from commandline.\n");
 
1539
    result.append("   *\n");
 
1540
    result.append("   * @param args the commandline arguments\n");
 
1541
    result.append("   */\n");
 
1542
    result.append("  public static void main(String args[]) {\n");
 
1543
    result.append("    runClassifier(new WekaWrapper(), args);\n");
 
1544
    result.append("  }\n");
 
1545
    result.append("}\n");
 
1546
    
 
1547
    // actual classifier code
 
1548
    result.append("\n");
 
1549
    result.append(staticClassifier);
 
1550
    
 
1551
    return result.toString();
 
1552
  }
 
1553
 
 
1554
  /**
 
1555
   * Gets the number of test instances that had a known class value
 
1556
   * (actually the sum of the weights of test instances with known 
 
1557
   * class value).
 
1558
   *
 
1559
   * @return the number of test instances with known class
 
1560
   */
 
1561
  public final double numInstances() {
 
1562
 
 
1563
    return m_WithClass;
 
1564
  }
 
1565
 
 
1566
  /**
 
1567
   * Gets the number of instances incorrectly classified (that is, for
 
1568
   * which an incorrect prediction was made). (Actually the sum of the weights
 
1569
   * of these instances)
 
1570
   *
 
1571
   * @return the number of incorrectly classified instances 
 
1572
   */
 
1573
  public final double incorrect() {
 
1574
 
 
1575
    return m_Incorrect;
 
1576
  }
 
1577
 
 
1578
  /**
 
1579
   * Gets the percentage of instances incorrectly classified (that is, for
 
1580
   * which an incorrect prediction was made).
 
1581
   *
 
1582
   * @return the percent of incorrectly classified instances 
 
1583
   * (between 0 and 100)
 
1584
   */
 
1585
  public final double pctIncorrect() {
 
1586
 
 
1587
    return 100 * m_Incorrect / m_WithClass;
 
1588
  }
 
1589
 
 
1590
  /**
 
1591
   * Gets the total cost, that is, the cost of each prediction times the
 
1592
   * weight of the instance, summed over all instances.
 
1593
   *
 
1594
   * @return the total cost
 
1595
   */
 
1596
  public final double totalCost() {
 
1597
 
 
1598
    return m_TotalCost;
 
1599
  }
 
1600
 
 
1601
  /**
 
1602
   * Gets the average cost, that is, total cost of misclassifications
 
1603
   * (incorrect plus unclassified) over the total number of instances.
 
1604
   *
 
1605
   * @return the average cost.  
 
1606
   */
 
1607
  public final double avgCost() {
 
1608
 
 
1609
    return m_TotalCost / m_WithClass;
 
1610
  }
 
1611
 
 
1612
  /**
 
1613
   * Gets the number of instances correctly classified (that is, for
 
1614
   * which a correct prediction was made). (Actually the sum of the weights
 
1615
   * of these instances)
 
1616
   *
 
1617
   * @return the number of correctly classified instances
 
1618
   */
 
1619
  public final double correct() {
 
1620
 
 
1621
    return m_Correct;
 
1622
  }
 
1623
 
 
1624
  /**
 
1625
   * Gets the percentage of instances correctly classified (that is, for
 
1626
   * which a correct prediction was made).
 
1627
   *
 
1628
   * @return the percent of correctly classified instances (between 0 and 100)
 
1629
   */
 
1630
  public final double pctCorrect() {
 
1631
 
 
1632
    return 100 * m_Correct / m_WithClass;
 
1633
  }
 
1634
 
 
1635
  /**
 
1636
   * Gets the number of instances not classified (that is, for
 
1637
   * which no prediction was made by the classifier). (Actually the sum
 
1638
   * of the weights of these instances)
 
1639
   *
 
1640
   * @return the number of unclassified instances
 
1641
   */
 
1642
  public final double unclassified() {
 
1643
 
 
1644
    return m_Unclassified;
 
1645
  }
 
1646
 
 
1647
  /**
 
1648
   * Gets the percentage of instances not classified (that is, for
 
1649
   * which no prediction was made by the classifier).
 
1650
   *
 
1651
   * @return the percent of unclassified instances (between 0 and 100)
 
1652
   */
 
1653
  public final double pctUnclassified() {
 
1654
 
 
1655
    return 100 * m_Unclassified / m_WithClass;
 
1656
  }
 
1657
 
 
1658
  /**
 
1659
   * Returns the estimated error rate or the root mean squared error
 
1660
   * (if the class is numeric). If a cost matrix was given this
 
1661
   * error rate gives the average cost.
 
1662
   *
 
1663
   * @return the estimated error rate (between 0 and 1, or between 0 and 
 
1664
   * maximum cost)
 
1665
   */
 
1666
  public final double errorRate() {
 
1667
 
 
1668
    if (!m_ClassIsNominal) {
 
1669
      return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
 
1670
    }
 
1671
    if (m_CostMatrix == null) {
 
1672
      return m_Incorrect / m_WithClass;
 
1673
    } else {
 
1674
      return avgCost();
 
1675
    }
 
1676
  }
 
1677
 
 
1678
  /**
 
1679
   * Returns value of kappa statistic if class is nominal.
 
1680
   *
 
1681
   * @return the value of the kappa statistic
 
1682
   */
 
1683
  public final double kappa() {
 
1684
 
 
1685
 
 
1686
    double[] sumRows = new double[m_ConfusionMatrix.length];
 
1687
    double[] sumColumns = new double[m_ConfusionMatrix.length];
 
1688
    double sumOfWeights = 0;
 
1689
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
 
1690
      for (int j = 0; j < m_ConfusionMatrix.length; j++) {
 
1691
        sumRows[i] += m_ConfusionMatrix[i][j];
 
1692
        sumColumns[j] += m_ConfusionMatrix[i][j];
 
1693
        sumOfWeights += m_ConfusionMatrix[i][j];
 
1694
      }
 
1695
    }
 
1696
    double correct = 0, chanceAgreement = 0;
 
1697
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
 
1698
      chanceAgreement += (sumRows[i] * sumColumns[i]);
 
1699
      correct += m_ConfusionMatrix[i][i];
 
1700
    }
 
1701
    chanceAgreement /= (sumOfWeights * sumOfWeights);
 
1702
    correct /= sumOfWeights;
 
1703
 
 
1704
    if (chanceAgreement < 1) {
 
1705
      return (correct - chanceAgreement) / (1 - chanceAgreement);
 
1706
    } else {
 
1707
      return 1;
 
1708
    }
 
1709
  }
 
1710
 
 
1711
  /**
 
1712
   * Returns the correlation coefficient if the class is numeric.
 
1713
   *
 
1714
   * @return the correlation coefficient
 
1715
   * @throws Exception if class is not numeric
 
1716
   */
 
1717
  public final double correlationCoefficient() throws Exception {
 
1718
 
 
1719
    if (m_ClassIsNominal) {
 
1720
      throw
 
1721
      new Exception("Can't compute correlation coefficient: " + 
 
1722
      "class is nominal!");
 
1723
    }
 
1724
 
 
1725
    double correlation = 0;
 
1726
    double varActual = 
 
1727
      m_SumSqrClass - m_SumClass * m_SumClass / 
 
1728
      (m_WithClass - m_Unclassified);
 
1729
    double varPredicted = 
 
1730
      m_SumSqrPredicted - m_SumPredicted * m_SumPredicted / 
 
1731
      (m_WithClass - m_Unclassified);
 
1732
    double varProd = 
 
1733
      m_SumClassPredicted - m_SumClass * m_SumPredicted / 
 
1734
      (m_WithClass - m_Unclassified);
 
1735
 
 
1736
    if (varActual * varPredicted <= 0) {
 
1737
      correlation = 0.0;
 
1738
    } else {
 
1739
      correlation = varProd / Math.sqrt(varActual * varPredicted);
 
1740
    }
 
1741
 
 
1742
    return correlation;
 
1743
  }
 
1744
 
 
1745
  /**
 
1746
   * Returns the mean absolute error. Refers to the error of the
 
1747
   * predicted values for numeric classes, and the error of the 
 
1748
   * predicted probability distribution for nominal classes.
 
1749
   *
 
1750
   * @return the mean absolute error 
 
1751
   */
 
1752
  public final double meanAbsoluteError() {
 
1753
 
 
1754
    return m_SumAbsErr / (m_WithClass - m_Unclassified);
 
1755
  }
 
1756
 
 
1757
  /**
 
1758
   * Returns the mean absolute error of the prior.
 
1759
   *
 
1760
   * @return the mean absolute error 
 
1761
   */
 
1762
  public final double meanPriorAbsoluteError() {
 
1763
 
 
1764
    if (m_NoPriors)
 
1765
      return Double.NaN;
 
1766
 
 
1767
    return m_SumPriorAbsErr / m_WithClass;
 
1768
  }
 
1769
 
 
1770
  /**
 
1771
   * Returns the relative absolute error.
 
1772
   *
 
1773
   * @return the relative absolute error 
 
1774
   * @throws Exception if it can't be computed
 
1775
   */
 
1776
  public final double relativeAbsoluteError() throws Exception {
 
1777
 
 
1778
    if (m_NoPriors)
 
1779
      return Double.NaN;
 
1780
 
 
1781
    return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
 
1782
  }
 
1783
 
 
1784
  /**
 
1785
   * Returns the root mean squared error.
 
1786
   *
 
1787
   * @return the root mean squared error 
 
1788
   */
 
1789
  public final double rootMeanSquaredError() {
 
1790
 
 
1791
    return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
 
1792
  }
 
1793
 
 
1794
  /**
 
1795
   * Returns the root mean prior squared error.
 
1796
   *
 
1797
   * @return the root mean prior squared error 
 
1798
   */
 
1799
  public final double rootMeanPriorSquaredError() {
 
1800
 
 
1801
    if (m_NoPriors)
 
1802
      return Double.NaN;
 
1803
 
 
1804
    return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
 
1805
  }
 
1806
 
 
1807
  /**
 
1808
   * Returns the root relative squared error if the class is numeric.
 
1809
   *
 
1810
   * @return the root relative squared error 
 
1811
   */
 
1812
  public final double rootRelativeSquaredError() {
 
1813
 
 
1814
    if (m_NoPriors)
 
1815
      return Double.NaN;
 
1816
 
 
1817
    return 100.0 * rootMeanSquaredError() / 
 
1818
    rootMeanPriorSquaredError();
 
1819
  }
 
1820
 
 
1821
  /**
 
1822
   * Calculate the entropy of the prior distribution
 
1823
   *
 
1824
   * @return the entropy of the prior distribution
 
1825
   * @throws Exception if the class is not nominal
 
1826
   */
 
1827
  public final double priorEntropy() throws Exception {
 
1828
 
 
1829
    if (!m_ClassIsNominal) {
 
1830
      throw
 
1831
      new Exception("Can't compute entropy of class prior: " + 
 
1832
      "class numeric!");
 
1833
    }
 
1834
 
 
1835
    if (m_NoPriors)
 
1836
      return Double.NaN;
 
1837
 
 
1838
    double entropy = 0;
 
1839
    for(int i = 0; i < m_NumClasses; i++) {
 
1840
      entropy -= m_ClassPriors[i] / m_ClassPriorsSum 
 
1841
      * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
 
1842
    }
 
1843
    return entropy;
 
1844
  }
 
1845
 
 
1846
  /**
 
1847
   * Return the total Kononenko & Bratko Information score in bits
 
1848
   *
 
1849
   * @return the K&B information score
 
1850
   * @throws Exception if the class is not nominal
 
1851
   */
 
1852
  public final double KBInformation() throws Exception {
 
1853
 
 
1854
    if (!m_ClassIsNominal) {
 
1855
      throw
 
1856
      new Exception("Can't compute K&B Info score: " + 
 
1857
      "class numeric!");
 
1858
    }
 
1859
 
 
1860
    if (m_NoPriors)
 
1861
      return Double.NaN;
 
1862
 
 
1863
    return m_SumKBInfo;
 
1864
  }
 
1865
 
 
1866
  /**
 
1867
   * Return the Kononenko & Bratko Information score in bits per 
 
1868
   * instance.
 
1869
   *
 
1870
   * @return the K&B information score
 
1871
   * @throws Exception if the class is not nominal
 
1872
   */
 
1873
  public final double KBMeanInformation() throws Exception {
 
1874
 
 
1875
    if (!m_ClassIsNominal) {
 
1876
      throw
 
1877
      new Exception("Can't compute K&B Info score: "
 
1878
          + "class numeric!");
 
1879
    }
 
1880
 
 
1881
    if (m_NoPriors)
 
1882
      return Double.NaN;
 
1883
 
 
1884
    return m_SumKBInfo / (m_WithClass - m_Unclassified);
 
1885
  }
 
1886
 
 
1887
  /**
 
1888
   * Return the Kononenko & Bratko Relative Information score
 
1889
   *
 
1890
   * @return the K&B relative information score
 
1891
   * @throws Exception if the class is not nominal
 
1892
   */
 
1893
  public final double KBRelativeInformation() throws Exception {
 
1894
 
 
1895
    if (!m_ClassIsNominal) {
 
1896
      throw
 
1897
      new Exception("Can't compute K&B Info score: " + 
 
1898
      "class numeric!");
 
1899
    }
 
1900
 
 
1901
    if (m_NoPriors)
 
1902
      return Double.NaN;
 
1903
 
 
1904
    return 100.0 * KBInformation() / priorEntropy();
 
1905
  }
 
1906
 
 
1907
  /**
 
1908
   * Returns the total entropy for the null model
 
1909
   * 
 
1910
   * @return the total null model entropy
 
1911
   */
 
1912
  public final double SFPriorEntropy() {
 
1913
 
 
1914
    if (m_NoPriors)
 
1915
      return Double.NaN;
 
1916
 
 
1917
    return m_SumPriorEntropy;
 
1918
  }
 
1919
 
 
1920
  /**
 
1921
   * Returns the entropy per instance for the null model
 
1922
   * 
 
1923
   * @return the null model entropy per instance
 
1924
   */
 
1925
  public final double SFMeanPriorEntropy() {
 
1926
 
 
1927
    if (m_NoPriors)
 
1928
      return Double.NaN;
 
1929
 
 
1930
    return m_SumPriorEntropy / m_WithClass;
 
1931
  }
 
1932
 
 
1933
  /**
 
1934
   * Returns the total entropy for the scheme
 
1935
   * 
 
1936
   * @return the total scheme entropy
 
1937
   */
 
1938
  public final double SFSchemeEntropy() {
 
1939
 
 
1940
    if (m_NoPriors)
 
1941
      return Double.NaN;
 
1942
 
 
1943
    return m_SumSchemeEntropy;
 
1944
  }
 
1945
 
 
1946
  /**
 
1947
   * Returns the entropy per instance for the scheme
 
1948
   * 
 
1949
   * @return the scheme entropy per instance
 
1950
   */
 
1951
  public final double SFMeanSchemeEntropy() {
 
1952
 
 
1953
    if (m_NoPriors)
 
1954
      return Double.NaN;
 
1955
 
 
1956
    return m_SumSchemeEntropy / (m_WithClass - m_Unclassified);
 
1957
  }
 
1958
 
 
1959
  /**
 
1960
   * Returns the total SF, which is the null model entropy minus
 
1961
   * the scheme entropy.
 
1962
   * 
 
1963
   * @return the total SF
 
1964
   */
 
1965
  public final double SFEntropyGain() {
 
1966
 
 
1967
    if (m_NoPriors)
 
1968
      return Double.NaN;
 
1969
 
 
1970
    return m_SumPriorEntropy - m_SumSchemeEntropy;
 
1971
  }
 
1972
 
 
1973
  /**
 
1974
   * Returns the SF per instance, which is the null model entropy
 
1975
   * minus the scheme entropy, per instance.
 
1976
   * 
 
1977
   * @return the SF per instance
 
1978
   */
 
1979
  public final double SFMeanEntropyGain() {
 
1980
 
 
1981
    if (m_NoPriors)
 
1982
      return Double.NaN;
 
1983
 
 
1984
    return (m_SumPriorEntropy - m_SumSchemeEntropy) / 
 
1985
      (m_WithClass - m_Unclassified);
 
1986
  }
 
1987
 
 
1988
  /**
 
1989
   * Output the cumulative margin distribution as a string suitable
 
1990
   * for input for gnuplot or similar package.
 
1991
   *
 
1992
   * @return the cumulative margin distribution
 
1993
   * @throws Exception if the class attribute is nominal
 
1994
   */
 
1995
  public String toCumulativeMarginDistributionString() throws Exception {
 
1996
 
 
1997
    if (!m_ClassIsNominal) {
 
1998
      throw new Exception("Class must be nominal for margin distributions");
 
1999
    }
 
2000
    String result = "";
 
2001
    double cumulativeCount = 0;
 
2002
    double margin;
 
2003
    for(int i = 0; i <= k_MarginResolution; i++) {
 
2004
      if (m_MarginCounts[i] != 0) {
 
2005
        cumulativeCount += m_MarginCounts[i];
 
2006
        margin = (double)i * 2.0 / k_MarginResolution - 1.0;
 
2007
        result = result + Utils.doubleToString(margin, 7, 3) + ' ' 
 
2008
        + Utils.doubleToString(cumulativeCount * 100 
 
2009
            / m_WithClass, 7, 3) + '\n';
 
2010
      } else if (i == 0) {
 
2011
        result = Utils.doubleToString(-1.0, 7, 3) + ' ' 
 
2012
        + Utils.doubleToString(0, 7, 3) + '\n';
 
2013
      }
 
2014
    }
 
2015
    return result;
 
2016
  }
 
2017
 
 
2018
 
 
2019
  /**
 
2020
   * Calls toSummaryString() with no title and no complexity stats
 
2021
   *
 
2022
   * @return a summary description of the classifier evaluation
 
2023
   */
 
2024
  public String toSummaryString() {
 
2025
 
 
2026
    return toSummaryString("", false);
 
2027
  }
 
2028
 
 
2029
  /**
 
2030
   * Calls toSummaryString() with a default title.
 
2031
   *
 
2032
   * @param printComplexityStatistics if true, complexity statistics are
 
2033
   * returned as well
 
2034
   * @return the summary string
 
2035
   */
 
2036
  public String toSummaryString(boolean printComplexityStatistics) {
 
2037
 
 
2038
    return toSummaryString("=== Summary ===\n", printComplexityStatistics);
 
2039
  }
 
2040
 
 
2041
  /**
 
2042
   * Outputs the performance statistics in summary form. Lists 
 
2043
   * number (and percentage) of instances classified correctly, 
 
2044
   * incorrectly and unclassified. Outputs the total number of 
 
2045
   * instances classified, and the number of instances (if any) 
 
2046
   * that had no class value provided. 
 
2047
   *
 
2048
   * @param title the title for the statistics
 
2049
   * @param printComplexityStatistics if true, complexity statistics are
 
2050
   * returned as well
 
2051
   * @return the summary as a String
 
2052
   */
 
2053
  public String toSummaryString(String title, 
 
2054
      boolean printComplexityStatistics) { 
 
2055
 
 
2056
    StringBuffer text = new StringBuffer();
 
2057
 
 
2058
    if (printComplexityStatistics && m_NoPriors) {
 
2059
      printComplexityStatistics = false;
 
2060
      System.err.println("Priors disabled, cannot print complexity statistics!");
 
2061
    }
 
2062
 
 
2063
    text.append(title + "\n");
 
2064
    try {
 
2065
      if (m_WithClass > 0) {
 
2066
        if (m_ClassIsNominal) {
 
2067
 
 
2068
          text.append("Correctly Classified Instances     ");
 
2069
          text.append(Utils.doubleToString(correct(), 12, 4) + "     " +
 
2070
              Utils.doubleToString(pctCorrect(),
 
2071
                  12, 4) + " %\n");
 
2072
          text.append("Incorrectly Classified Instances   ");
 
2073
          text.append(Utils.doubleToString(incorrect(), 12, 4) + "     " +
 
2074
              Utils.doubleToString(pctIncorrect(),
 
2075
                  12, 4) + " %\n");
 
2076
          text.append("Kappa statistic                    ");
 
2077
          text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");
 
2078
 
 
2079
          if (m_CostMatrix != null) {
 
2080
            text.append("Total Cost                         ");
 
2081
            text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
 
2082
            text.append("Average Cost                       ");
 
2083
            text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
 
2084
          }
 
2085
          if (printComplexityStatistics) {
 
2086
            text.append("K&B Relative Info Score            ");
 
2087
            text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4) 
 
2088
                + " %\n");
 
2089
            text.append("K&B Information Score              ");
 
2090
            text.append(Utils.doubleToString(KBInformation(), 12, 4) 
 
2091
                + " bits");
 
2092
            text.append(Utils.doubleToString(KBMeanInformation(), 12, 4) 
 
2093
                + " bits/instance\n");
 
2094
          }
 
2095
        } else {        
 
2096
          text.append("Correlation coefficient            ");
 
2097
          text.append(Utils.doubleToString(correlationCoefficient(), 12 , 4) +
 
2098
          "\n");
 
2099
        }
 
2100
        if (printComplexityStatistics) {
 
2101
          text.append("Class complexity | order 0         ");
 
2102
          text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4) 
 
2103
              + " bits");
 
2104
          text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4) 
 
2105
              + " bits/instance\n");
 
2106
          text.append("Class complexity | scheme          ");
 
2107
          text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4) 
 
2108
              + " bits");
 
2109
          text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4) 
 
2110
              + " bits/instance\n");
 
2111
          text.append("Complexity improvement     (Sf)    ");
 
2112
          text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
 
2113
          text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4) 
 
2114
              + " bits/instance\n");
 
2115
        }
 
2116
 
 
2117
        text.append("Mean absolute error                ");
 
2118
        text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4) 
 
2119
            + "\n");
 
2120
        text.append("Root mean squared error            ");
 
2121
        text.append(Utils.
 
2122
            doubleToString(rootMeanSquaredError(), 12, 4) 
 
2123
            + "\n");
 
2124
        if (!m_NoPriors) {
 
2125
          text.append("Relative absolute error            ");
 
2126
          text.append(Utils.doubleToString(relativeAbsoluteError(), 
 
2127
              12, 4) + " %\n");
 
2128
          text.append("Root relative squared error        ");
 
2129
          text.append(Utils.doubleToString(rootRelativeSquaredError(), 
 
2130
              12, 4) + " %\n");
 
2131
        }
 
2132
      }
 
2133
      if (Utils.gr(unclassified(), 0)) {
 
2134
        text.append("UnClassified Instances             ");
 
2135
        text.append(Utils.doubleToString(unclassified(), 12,4) +  "     " +
 
2136
            Utils.doubleToString(pctUnclassified(),
 
2137
                12, 4) + " %\n");
 
2138
      }
 
2139
      text.append("Total Number of Instances          ");
 
2140
      text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
 
2141
      if (m_MissingClass > 0) {
 
2142
        text.append("Ignored Class Unknown Instances            ");
 
2143
        text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
 
2144
      }
 
2145
    } catch (Exception ex) {
 
2146
      // Should never occur since the class is known to be nominal 
 
2147
      // here
 
2148
      System.err.println("Arggh - Must be a bug in Evaluation class");
 
2149
    }
 
2150
 
 
2151
    return text.toString(); 
 
2152
  }
 
2153
 
 
2154
  /**
 
2155
   * Calls toMatrixString() with a default title.
 
2156
   *
 
2157
   * @return the confusion matrix as a string
 
2158
   * @throws Exception if the class is numeric
 
2159
   */
 
2160
  public String toMatrixString() throws Exception {
 
2161
 
 
2162
    return toMatrixString("=== Confusion Matrix ===\n");
 
2163
  }
 
2164
 
 
2165
  /**
 
2166
   * Outputs the performance statistics as a classification confusion
 
2167
   * matrix. For each class value, shows the distribution of 
 
2168
   * predicted class values.
 
2169
   *
 
2170
   * @param title the title for the confusion matrix
 
2171
   * @return the confusion matrix as a String
 
2172
   * @throws Exception if the class is numeric
 
2173
   */
 
2174
  public String toMatrixString(String title) throws Exception {
 
2175
 
 
2176
    StringBuffer text = new StringBuffer();
 
2177
    char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
 
2178
        'k','l','m','n','o','p','q','r','s','t',
 
2179
        'u','v','w','x','y','z'};
 
2180
    int IDWidth;
 
2181
    boolean fractional = false;
 
2182
 
 
2183
    if (!m_ClassIsNominal) {
 
2184
      throw new Exception("Evaluation: No confusion matrix possible!");
 
2185
    }
 
2186
 
 
2187
    // Find the maximum value in the matrix
 
2188
    // and check for fractional display requirement 
 
2189
    double maxval = 0;
 
2190
    for(int i = 0; i < m_NumClasses; i++) {
 
2191
      for(int j = 0; j < m_NumClasses; j++) {
 
2192
        double current = m_ConfusionMatrix[i][j];
 
2193
        if (current < 0) {
 
2194
          current *= -10;
 
2195
        }
 
2196
        if (current > maxval) {
 
2197
          maxval = current;
 
2198
        }
 
2199
        double fract = current - Math.rint(current);
 
2200
        if (!fractional
 
2201
            && ((Math.log(fract) / Math.log(10)) >= -2)) {
 
2202
          fractional = true;
 
2203
        }
 
2204
      }
 
2205
    }
 
2206
 
 
2207
    IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10) 
 
2208
        + (fractional ? 3 : 0)),
 
2209
        (int)(Math.log(m_NumClasses) / 
 
2210
            Math.log(IDChars.length)));
 
2211
    text.append(title).append("\n");
 
2212
    for(int i = 0; i < m_NumClasses; i++) {
 
2213
      if (fractional) {
 
2214
        text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
 
2215
        .append("   ");
 
2216
      } else {
 
2217
        text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
 
2218
      }
 
2219
    }
 
2220
    text.append("   <-- classified as\n");
 
2221
    for(int i = 0; i< m_NumClasses; i++) { 
 
2222
      for(int j = 0; j < m_NumClasses; j++) {
 
2223
        text.append(" ").append(
 
2224
            Utils.doubleToString(m_ConfusionMatrix[i][j],
 
2225
                IDWidth,
 
2226
                (fractional ? 2 : 0)));
 
2227
      }
 
2228
      text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
 
2229
      .append(" = ").append(m_ClassNames[i]).append("\n");
 
2230
    }
 
2231
    return text.toString();
 
2232
  }
 
2233
 
 
2234
  /**
 
2235
   * Generates a breakdown of the accuracy for each class (with default title),
 
2236
   * incorporating various information-retrieval statistics, such as
 
2237
   * true/false positive rate, precision/recall/F-Measure.  Should be
 
2238
   * useful for ROC curves, recall/precision curves.  
 
2239
   * 
 
2240
   * @return the statistics presented as a string
 
2241
   * @throws Exception if class is not nominal
 
2242
   */
 
2243
  public String toClassDetailsString() throws Exception {
 
2244
 
 
2245
    return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
 
2246
  }
 
2247
 
 
2248
  /**
 
2249
   * Generates a breakdown of the accuracy for each class,
 
2250
   * incorporating various information-retrieval statistics, such as
 
2251
   * true/false positive rate, precision/recall/F-Measure.  Should be
 
2252
   * useful for ROC curves, recall/precision curves.  
 
2253
   * 
 
2254
   * @param title the title to prepend the stats string with 
 
2255
   * @return the statistics presented as a string
 
2256
   * @throws Exception if class is not nominal
 
2257
   */
 
2258
  public String toClassDetailsString(String title) throws Exception {
 
2259
 
 
2260
    if (!m_ClassIsNominal) {
 
2261
      throw new Exception("Evaluation: No confusion matrix possible!");
 
2262
    }
 
2263
    StringBuffer text = new StringBuffer(title 
 
2264
        + "\nTP Rate   FP Rate"
 
2265
        + "   Precision   Recall"
 
2266
        + "  F-Measure   ROC Area  Class\n");
 
2267
    for(int i = 0; i < m_NumClasses; i++) {
 
2268
      text.append(Utils.doubleToString(truePositiveRate(i), 7, 3))
 
2269
      .append("   ");
 
2270
      text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
 
2271
      .append("    ");
 
2272
      text.append(Utils.doubleToString(precision(i), 7, 3))
 
2273
      .append("   ");
 
2274
      text.append(Utils.doubleToString(recall(i), 7, 3))
 
2275
      .append("   ");
 
2276
      text.append(Utils.doubleToString(fMeasure(i), 7, 3))
 
2277
      .append("    ");
 
2278
      double rocVal = areaUnderROC(i);
 
2279
      if (Instance.isMissingValue(rocVal)) {
 
2280
        text.append("  ?    ")
 
2281
        .append("    ");
 
2282
      } else {
 
2283
        text.append(Utils.doubleToString(rocVal, 7, 3))
 
2284
        .append("    ");
 
2285
      }
 
2286
      text.append(m_ClassNames[i]).append('\n');
 
2287
    }
 
2288
    return text.toString();
 
2289
  }
 
2290
 
 
2291
  /**
 
2292
   * Calculate the number of true positives with respect to a particular class. 
 
2293
   * This is defined as<p/>
 
2294
   * <pre>
 
2295
   * correctly classified positives
 
2296
   * </pre>
 
2297
   *
 
2298
   * @param classIndex the index of the class to consider as "positive"
 
2299
   * @return the true positive rate
 
2300
   */
 
2301
  public double numTruePositives(int classIndex) {
 
2302
 
 
2303
    double correct = 0;
 
2304
    for (int j = 0; j < m_NumClasses; j++) {
 
2305
      if (j == classIndex) {
 
2306
        correct += m_ConfusionMatrix[classIndex][j];
 
2307
      }
 
2308
    }
 
2309
    return correct;
 
2310
  }
 
2311
 
 
2312
  /**
 
2313
   * Calculate the true positive rate with respect to a particular class. 
 
2314
   * This is defined as<p/>
 
2315
   * <pre>
 
2316
   * correctly classified positives
 
2317
   * ------------------------------
 
2318
   *       total positives
 
2319
   * </pre>
 
2320
   *
 
2321
   * @param classIndex the index of the class to consider as "positive"
 
2322
   * @return the true positive rate
 
2323
   */
 
2324
  public double truePositiveRate(int classIndex) {
 
2325
 
 
2326
    double correct = 0, total = 0;
 
2327
    for (int j = 0; j < m_NumClasses; j++) {
 
2328
      if (j == classIndex) {
 
2329
        correct += m_ConfusionMatrix[classIndex][j];
 
2330
      }
 
2331
      total += m_ConfusionMatrix[classIndex][j];
 
2332
    }
 
2333
    if (total == 0) {
 
2334
      return 0;
 
2335
    }
 
2336
    return correct / total;
 
2337
  }
 
2338
 
 
2339
  /**
 
2340
   * Calculate the number of true negatives with respect to a particular class. 
 
2341
   * This is defined as<p/>
 
2342
   * <pre>
 
2343
   * correctly classified negatives
 
2344
   * </pre>
 
2345
   *
 
2346
   * @param classIndex the index of the class to consider as "positive"
 
2347
   * @return the true positive rate
 
2348
   */
 
2349
  public double numTrueNegatives(int classIndex) {
 
2350
 
 
2351
    double correct = 0;
 
2352
    for (int i = 0; i < m_NumClasses; i++) {
 
2353
      if (i != classIndex) {
 
2354
        for (int j = 0; j < m_NumClasses; j++) {
 
2355
          if (j != classIndex) {
 
2356
            correct += m_ConfusionMatrix[i][j];
 
2357
          }
 
2358
        }
 
2359
      }
 
2360
    }
 
2361
    return correct;
 
2362
  }
 
2363
 
 
2364
  /**
 
2365
   * Calculate the true negative rate with respect to a particular class. 
 
2366
   * This is defined as<p/>
 
2367
   * <pre>
 
2368
   * correctly classified negatives
 
2369
   * ------------------------------
 
2370
   *       total negatives
 
2371
   * </pre>
 
2372
   *
 
2373
   * @param classIndex the index of the class to consider as "positive"
 
2374
   * @return the true positive rate
 
2375
   */
 
2376
  public double trueNegativeRate(int classIndex) {
 
2377
 
 
2378
    double correct = 0, total = 0;
 
2379
    for (int i = 0; i < m_NumClasses; i++) {
 
2380
      if (i != classIndex) {
 
2381
        for (int j = 0; j < m_NumClasses; j++) {
 
2382
          if (j != classIndex) {
 
2383
            correct += m_ConfusionMatrix[i][j];
 
2384
          }
 
2385
          total += m_ConfusionMatrix[i][j];
 
2386
        }
 
2387
      }
 
2388
    }
 
2389
    if (total == 0) {
 
2390
      return 0;
 
2391
    }
 
2392
    return correct / total;
 
2393
  }
 
2394
 
 
2395
  /**
 
2396
   * Calculate number of false positives with respect to a particular class. 
 
2397
   * This is defined as<p/>
 
2398
   * <pre>
 
2399
   * incorrectly classified negatives
 
2400
   * </pre>
 
2401
   *
 
2402
   * @param classIndex the index of the class to consider as "positive"
 
2403
   * @return the false positive rate
 
2404
   */
 
2405
  public double numFalsePositives(int classIndex) {
 
2406
 
 
2407
    double incorrect = 0;
 
2408
    for (int i = 0; i < m_NumClasses; i++) {
 
2409
      if (i != classIndex) {
 
2410
        for (int j = 0; j < m_NumClasses; j++) {
 
2411
          if (j == classIndex) {
 
2412
            incorrect += m_ConfusionMatrix[i][j];
 
2413
          }
 
2414
        }
 
2415
      }
 
2416
    }
 
2417
    return incorrect;
 
2418
  }
 
2419
 
 
2420
  /**
 
2421
   * Calculate the false positive rate with respect to a particular class. 
 
2422
   * This is defined as<p/>
 
2423
   * <pre>
 
2424
   * incorrectly classified negatives
 
2425
   * --------------------------------
 
2426
   *        total negatives
 
2427
   * </pre>
 
2428
   *
 
2429
   * @param classIndex the index of the class to consider as "positive"
 
2430
   * @return the false positive rate
 
2431
   */
 
2432
  public double falsePositiveRate(int classIndex) {
 
2433
 
 
2434
    double incorrect = 0, total = 0;
 
2435
    for (int i = 0; i < m_NumClasses; i++) {
 
2436
      if (i != classIndex) {
 
2437
        for (int j = 0; j < m_NumClasses; j++) {
 
2438
          if (j == classIndex) {
 
2439
            incorrect += m_ConfusionMatrix[i][j];
 
2440
          }
 
2441
          total += m_ConfusionMatrix[i][j];
 
2442
        }
 
2443
      }
 
2444
    }
 
2445
    if (total == 0) {
 
2446
      return 0;
 
2447
    }
 
2448
    return incorrect / total;
 
2449
  }
 
2450
 
 
2451
  /**
 
2452
   * Calculate number of false negatives with respect to a particular class. 
 
2453
   * This is defined as<p/>
 
2454
   * <pre>
 
2455
   * incorrectly classified positives
 
2456
   * </pre>
 
2457
   *
 
2458
   * @param classIndex the index of the class to consider as "positive"
 
2459
   * @return the false positive rate
 
2460
   */
 
2461
  public double numFalseNegatives(int classIndex) {
 
2462
 
 
2463
    double incorrect = 0;
 
2464
    for (int i = 0; i < m_NumClasses; i++) {
 
2465
      if (i == classIndex) {
 
2466
        for (int j = 0; j < m_NumClasses; j++) {
 
2467
          if (j != classIndex) {
 
2468
            incorrect += m_ConfusionMatrix[i][j];
 
2469
          }
 
2470
        }
 
2471
      }
 
2472
    }
 
2473
    return incorrect;
 
2474
  }
 
2475
 
 
2476
  /**
 
2477
   * Calculate the false negative rate with respect to a particular class. 
 
2478
   * This is defined as<p/>
 
2479
   * <pre>
 
2480
   * incorrectly classified positives
 
2481
   * --------------------------------
 
2482
   *        total positives
 
2483
   * </pre>
 
2484
   *
 
2485
   * @param classIndex the index of the class to consider as "positive"
 
2486
   * @return the false positive rate
 
2487
   */
 
2488
  public double falseNegativeRate(int classIndex) {
 
2489
 
 
2490
    double incorrect = 0, total = 0;
 
2491
    for (int i = 0; i < m_NumClasses; i++) {
 
2492
      if (i == classIndex) {
 
2493
        for (int j = 0; j < m_NumClasses; j++) {
 
2494
          if (j != classIndex) {
 
2495
            incorrect += m_ConfusionMatrix[i][j];
 
2496
          }
 
2497
          total += m_ConfusionMatrix[i][j];
 
2498
        }
 
2499
      }
 
2500
    }
 
2501
    if (total == 0) {
 
2502
      return 0;
 
2503
    }
 
2504
    return incorrect / total;
 
2505
  }
 
2506
 
 
2507
  /**
 
2508
   * Calculate the recall with respect to a particular class. 
 
2509
   * This is defined as<p/>
 
2510
   * <pre>
 
2511
   * correctly classified positives
 
2512
   * ------------------------------
 
2513
   *       total positives
 
2514
   * </pre><p/>
 
2515
   * (Which is also the same as the truePositiveRate.)
 
2516
   *
 
2517
   * @param classIndex the index of the class to consider as "positive"
 
2518
   * @return the recall
 
2519
   */
 
2520
  public double recall(int classIndex) {
 
2521
 
 
2522
    return truePositiveRate(classIndex);
 
2523
  }
 
2524
 
 
2525
  /**
 
2526
   * Calculate the precision with respect to a particular class. 
 
2527
   * This is defined as<p/>
 
2528
   * <pre>
 
2529
   * correctly classified positives
 
2530
   * ------------------------------
 
2531
   *  total predicted as positive
 
2532
   * </pre>
 
2533
   *
 
2534
   * @param classIndex the index of the class to consider as "positive"
 
2535
   * @return the precision
 
2536
   */
 
2537
  public double precision(int classIndex) {
 
2538
 
 
2539
    double correct = 0, total = 0;
 
2540
    for (int i = 0; i < m_NumClasses; i++) {
 
2541
      if (i == classIndex) {
 
2542
        correct += m_ConfusionMatrix[i][classIndex];
 
2543
      }
 
2544
      total += m_ConfusionMatrix[i][classIndex];
 
2545
    }
 
2546
    if (total == 0) {
 
2547
      return 0;
 
2548
    }
 
2549
    return correct / total;
 
2550
  }
 
2551
 
 
2552
  /**
 
2553
   * Calculate the F-Measure with respect to a particular class. 
 
2554
   * This is defined as<p/>
 
2555
   * <pre>
 
2556
   * 2 * recall * precision
 
2557
   * ----------------------
 
2558
   *   recall + precision
 
2559
   * </pre>
 
2560
   *
 
2561
   * @param classIndex the index of the class to consider as "positive"
 
2562
   * @return the F-Measure
 
2563
   */
 
2564
  public double fMeasure(int classIndex) {
 
2565
 
 
2566
    double precision = precision(classIndex);
 
2567
    double recall = recall(classIndex);
 
2568
    if ((precision + recall) == 0) {
 
2569
      return 0;
 
2570
    }
 
2571
    return 2 * precision * recall / (precision + recall);
 
2572
  }
 
2573
 
 
2574
  /**
 
2575
   * Sets the class prior probabilities
 
2576
   *
 
2577
   * @param train the training instances used to determine
 
2578
   * the prior probabilities
 
2579
   * @throws Exception if the class attribute of the instances is not
 
2580
   * set
 
2581
   */
 
2582
  public void setPriors(Instances train) throws Exception {
 
2583
    m_NoPriors = false;
 
2584
 
 
2585
    if (!m_ClassIsNominal) {
 
2586
 
 
2587
      m_NumTrainClassVals = 0;
 
2588
      m_TrainClassVals = null;
 
2589
      m_TrainClassWeights = null;
 
2590
      m_PriorErrorEstimator = null;
 
2591
      m_ErrorEstimator = null;
 
2592
 
 
2593
      for (int i = 0; i < train.numInstances(); i++) {
 
2594
        Instance currentInst = train.instance(i);
 
2595
        if (!currentInst.classIsMissing()) {
 
2596
          addNumericTrainClass(currentInst.classValue(), 
 
2597
              currentInst.weight());
 
2598
        }
 
2599
      }
 
2600
 
 
2601
    } else {
 
2602
      for (int i = 0; i < m_NumClasses; i++) {
 
2603
        m_ClassPriors[i] = 1;
 
2604
      }
 
2605
      m_ClassPriorsSum = m_NumClasses;
 
2606
      for (int i = 0; i < train.numInstances(); i++) {
 
2607
        if (!train.instance(i).classIsMissing()) {
 
2608
          m_ClassPriors[(int)train.instance(i).classValue()] += 
 
2609
            train.instance(i).weight();
 
2610
          m_ClassPriorsSum += train.instance(i).weight();
 
2611
        }
 
2612
      }
 
2613
    }
 
2614
  }
 
2615
 
 
2616
  /**
 
2617
   * Get the current weighted class counts
 
2618
   * 
 
2619
   * @return the weighted class counts
 
2620
   */
 
2621
  public double [] getClassPriors() {
 
2622
    return m_ClassPriors;
 
2623
  }
 
2624
 
 
2625
  /**
 
2626
   * Updates the class prior probabilities (when incrementally 
 
2627
   * training)
 
2628
   *
 
2629
   * @param instance the new training instance seen
 
2630
   * @throws Exception if the class of the instance is not
 
2631
   * set
 
2632
   */
 
2633
  public void updatePriors(Instance instance) throws Exception {
 
2634
    if (!instance.classIsMissing()) {
 
2635
      if (!m_ClassIsNominal) {
 
2636
        if (!instance.classIsMissing()) {
 
2637
          addNumericTrainClass(instance.classValue(), 
 
2638
              instance.weight());
 
2639
        }
 
2640
      } else {
 
2641
        m_ClassPriors[(int)instance.classValue()] += 
 
2642
          instance.weight();
 
2643
        m_ClassPriorsSum += instance.weight();
 
2644
      }
 
2645
    }    
 
2646
  }
 
2647
 
 
2648
  /**
 
2649
   * disables the use of priors, e.g., in case of de-serialized schemes
 
2650
   * that have no access to the original training set, but are evaluated
 
2651
   * on a set set.
 
2652
   */
 
2653
  public void useNoPriors() {
 
2654
    m_NoPriors = true;
 
2655
  }
 
2656
 
 
2657
  /**
 
2658
   * Tests whether the current evaluation object is equal to another
 
2659
   * evaluation object
 
2660
   *
 
2661
   * @param obj the object to compare against
 
2662
   * @return true if the two objects are equal
 
2663
   */
 
2664
  public boolean equals(Object obj) {
 
2665
 
 
2666
    if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
 
2667
      return false;
 
2668
    }
 
2669
    Evaluation cmp = (Evaluation) obj;
 
2670
    if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
 
2671
    if (m_NumClasses != cmp.m_NumClasses) return false;
 
2672
 
 
2673
    if (m_Incorrect != cmp.m_Incorrect) return false;
 
2674
    if (m_Correct != cmp.m_Correct) return false;
 
2675
    if (m_Unclassified != cmp.m_Unclassified) return false;
 
2676
    if (m_MissingClass != cmp.m_MissingClass) return false;
 
2677
    if (m_WithClass != cmp.m_WithClass) return false;
 
2678
 
 
2679
    if (m_SumErr != cmp.m_SumErr) return false;
 
2680
    if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
 
2681
    if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
 
2682
    if (m_SumClass != cmp.m_SumClass) return false;
 
2683
    if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
 
2684
    if (m_SumPredicted != cmp.m_SumPredicted) return false;
 
2685
    if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
 
2686
    if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;
 
2687
 
 
2688
    if (m_ClassIsNominal) {
 
2689
      for (int i = 0; i < m_NumClasses; i++) {
 
2690
        for (int j = 0; j < m_NumClasses; j++) {
 
2691
          if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
 
2692
            return false;
 
2693
          }
 
2694
        }
 
2695
      }
 
2696
    }
 
2697
 
 
2698
    return true;
 
2699
  }
 
2700
 
 
2701
  /**
 
2702
   * Prints the predictions for the given dataset into a String variable.
 
2703
   * 
 
2704
   * @param classifier          the classifier to use
 
2705
   * @param train               the training data
 
2706
   * @param testSource          the test set
 
2707
   * @param classIndex          the class index (1-based), if -1 ot does not 
 
2708
   *                            override the class index is stored in the data 
 
2709
   *                            file (by using the last attribute)
 
2710
   * @param attributesToOutput  the indices of the attributes to output
 
2711
   * @return                    the generated predictions for the attribute range
 
2712
   * @throws Exception          if test file cannot be opened
 
2713
   */
 
2714
  protected static String printClassifications(Classifier classifier, 
 
2715
      Instances train,
 
2716
      DataSource testSource,
 
2717
      int classIndex,
 
2718
      Range attributesToOutput) throws Exception {
 
2719
    
 
2720
    return printClassifications(
 
2721
        classifier, train, testSource, classIndex, attributesToOutput, false);
 
2722
  }
 
2723
 
 
2724
  /**
 
2725
   * Prints the predictions for the given dataset into a String variable.
 
2726
   * 
 
2727
   * @param classifier          the classifier to use
 
2728
   * @param train               the training data
 
2729
   * @param testSource          the test set
 
2730
   * @param classIndex          the class index (1-based), if -1 ot does not 
 
2731
   *                            override the class index is stored in the data 
 
2732
   *                            file (by using the last attribute)
 
2733
   * @param attributesToOutput  the indices of the attributes to output
 
2734
   * @param printDistribution   prints the complete distribution for nominal 
 
2735
   *                            classes, not just the predicted value
 
2736
   * @return                    the generated predictions for the attribute range
 
2737
   * @throws Exception          if test file cannot be opened
 
2738
   */
 
2739
  protected static String printClassifications(Classifier classifier, 
 
2740
      Instances train,
 
2741
      DataSource testSource,
 
2742
      int classIndex,
 
2743
      Range attributesToOutput,
 
2744
      boolean printDistribution) throws Exception {
 
2745
 
 
2746
    StringBuffer text = new StringBuffer();
 
2747
    if (testSource != null) {
 
2748
      Instances test = testSource.getStructure();
 
2749
      if (classIndex != -1) {
 
2750
        test.setClassIndex(classIndex - 1);
 
2751
      } else {
 
2752
        if (test.classIndex() == -1)
 
2753
          test.setClassIndex(test.numAttributes() - 1);
 
2754
      }
 
2755
 
 
2756
      // print header
 
2757
      if (test.classAttribute().isNominal())
 
2758
        if (printDistribution)
 
2759
          text.append(" inst#     actual  predicted error distribution");
 
2760
        else
 
2761
          text.append(" inst#     actual  predicted error prediction");
 
2762
      else
 
2763
        text.append(" inst#     actual  predicted      error");
 
2764
      if (attributesToOutput != null) {
 
2765
        attributesToOutput.setUpper(test.numAttributes() - 1);
 
2766
        text.append(" (");
 
2767
        boolean first = true;
 
2768
        for (int i = 0; i < test.numAttributes(); i++) {
 
2769
          if (i == test.classIndex())
 
2770
            continue;
 
2771
 
 
2772
          if (attributesToOutput.isInRange(i)) {
 
2773
            if (!first)
 
2774
              text.append(",");
 
2775
            text.append(test.attribute(i).name());
 
2776
            first = false;
 
2777
          }
 
2778
        }
 
2779
        text.append(")");
 
2780
      }
 
2781
      text.append("\n");
 
2782
 
 
2783
      // print predictions
 
2784
      int i = 0;
 
2785
      testSource.reset();
 
2786
      test = testSource.getStructure(test.classIndex());
 
2787
      while (testSource.hasMoreElements(test)) {
 
2788
        Instance inst = testSource.nextElement(test);
 
2789
        text.append(
 
2790
            predictionText(
 
2791
                classifier, inst, i, attributesToOutput, printDistribution));
 
2792
        i++;
 
2793
      }
 
2794
    }
 
2795
    return text.toString();
 
2796
  }
 
2797
 
 
2798
  /**
 
2799
   * returns the prediction made by the classifier as a string
 
2800
   * 
 
2801
   * @param classifier          the classifier to use
 
2802
   * @param inst                the instance to generate text from
 
2803
   * @param instNum             the index in the dataset
 
2804
   * @param attributesToOutput  the indices of the attributes to output
 
2805
   * @param printDistribution   prints the complete distribution for nominal 
 
2806
   *                            classes, not just the predicted value
 
2807
   * @return                    the generated text
 
2808
   * @throws Exception          if something goes wrong
 
2809
   * @see                       #printClassifications(Classifier, Instances, String, int, Range, boolean)
 
2810
   */
 
2811
  protected static String predictionText(Classifier classifier, 
 
2812
      Instance inst, 
 
2813
      int instNum,
 
2814
      Range attributesToOutput,
 
2815
      boolean printDistribution) 
 
2816
  throws Exception {
 
2817
 
 
2818
    StringBuffer result = new StringBuffer();
 
2819
    int width = 10;
 
2820
    int prec = 3;
 
2821
 
 
2822
    Instance withMissing = (Instance)inst.copy();
 
2823
    withMissing.setDataset(inst.dataset());
 
2824
    double predValue = ((Classifier)classifier).classifyInstance(withMissing);
 
2825
 
 
2826
    // index
 
2827
    result.append(Utils.padLeft("" + (instNum+1), 6));
 
2828
 
 
2829
    if (inst.dataset().classAttribute().isNumeric()) {
 
2830
      // actual
 
2831
      if (inst.classIsMissing())
 
2832
        result.append(" " + Utils.padLeft("?", width));
 
2833
      else
 
2834
        result.append(" " + Utils.doubleToString(inst.classValue(), width, prec));
 
2835
      // predicted
 
2836
      if (Instance.isMissingValue(predValue))
 
2837
        result.append(" " + Utils.padLeft("?", width));
 
2838
      else
 
2839
        result.append(" " + Utils.doubleToString(predValue, width, prec));
 
2840
      // error
 
2841
      if (Instance.isMissingValue(predValue) || inst.classIsMissing())
 
2842
        result.append(" " + Utils.padLeft("?", width));
 
2843
      else
 
2844
        result.append(" " + Utils.doubleToString(predValue - inst.classValue(), width, prec));
 
2845
    } else {
 
2846
      // actual
 
2847
      result.append(" " + Utils.padLeft(((int) inst.classValue()+1) + ":" + inst.toString(inst.classIndex()), width));
 
2848
      // predicted
 
2849
      if (Instance.isMissingValue(predValue))
 
2850
        result.append(" " + Utils.padLeft("?", width));
 
2851
      else
 
2852
        result.append(" " + Utils.padLeft(((int) predValue+1) + ":" + inst.dataset().classAttribute().value((int)predValue), width));
 
2853
      // error?
 
2854
      if ((int) predValue+1 != (int) inst.classValue()+1)
 
2855
        result.append(" " + "  +  ");
 
2856
      else
 
2857
        result.append(" " + "     ");
 
2858
      // prediction/distribution
 
2859
      if (printDistribution) {
 
2860
        if (Instance.isMissingValue(predValue)) {
 
2861
          result.append(" " + "?");
 
2862
        }
 
2863
        else {
 
2864
          result.append(" ");
 
2865
          double[] dist = classifier.distributionForInstance(withMissing);
 
2866
          for (int n = 0; n < dist.length; n++) {
 
2867
            if (n > 0)
 
2868
              result.append(",");
 
2869
            if (n == (int) predValue)
 
2870
              result.append("*");
 
2871
            result.append(Utils.doubleToString(dist[n], prec));
 
2872
          }
 
2873
        }
 
2874
      }
 
2875
      else {
 
2876
        if (Instance.isMissingValue(predValue))
 
2877
          result.append(" " + "?");
 
2878
        else
 
2879
          result.append(" " + Utils.doubleToString(classifier.distributionForInstance(withMissing) [(int)predValue], prec));
 
2880
      }
 
2881
    }
 
2882
 
 
2883
    // attributes
 
2884
    result.append(" " + attributeValuesString(withMissing, attributesToOutput) + "\n");
 
2885
 
 
2886
    return result.toString();
 
2887
  }
 
2888
 
 
2889
  /**
 
2890
   * Builds a string listing the attribute values in a specified range of indices,
 
2891
   * separated by commas and enclosed in brackets.
 
2892
   *
 
2893
   * @param instance the instance to print the values from
 
2894
   * @param attRange the range of the attributes to list
 
2895
   * @return a string listing values of the attributes in the range
 
2896
   */
 
2897
  protected static String attributeValuesString(Instance instance, Range attRange) {
 
2898
    StringBuffer text = new StringBuffer();
 
2899
    if (attRange != null) {
 
2900
      boolean firstOutput = true;
 
2901
      attRange.setUpper(instance.numAttributes() - 1);
 
2902
      for (int i=0; i<instance.numAttributes(); i++)
 
2903
        if (attRange.isInRange(i) && i != instance.classIndex()) {
 
2904
          if (firstOutput) text.append("(");
 
2905
          else text.append(",");
 
2906
          text.append(instance.toString(i));
 
2907
          firstOutput = false;
 
2908
        }
 
2909
      if (!firstOutput) text.append(")");
 
2910
    }
 
2911
    return text.toString();
 
2912
  }
 
2913
 
 
2914
  /**
 
2915
   * Make up the help string giving all the command line options
 
2916
   *
 
2917
   * @param classifier the classifier to include options for
 
2918
   * @return a string detailing the valid command line options
 
2919
   */
 
2920
  protected static String makeOptionString(Classifier classifier) {
 
2921
 
 
2922
    StringBuffer optionsText = new StringBuffer("");
 
2923
 
 
2924
    // General options
 
2925
    optionsText.append("\n\nGeneral options:\n\n");
 
2926
    optionsText.append("-t <name of training file>\n");
 
2927
    optionsText.append("\tSets training file.\n");
 
2928
    optionsText.append("-T <name of test file>\n");
 
2929
    optionsText.append("\tSets test file. If missing, a cross-validation will be performed\n");
 
2930
    optionsText.append("\ton the training data.\n");
 
2931
    optionsText.append("-c <class index>\n");
 
2932
    optionsText.append("\tSets index of class attribute (default: last).\n");
 
2933
    optionsText.append("-x <number of folds>\n");
 
2934
    optionsText.append("\tSets number of folds for cross-validation (default: 10).\n");
 
2935
    optionsText.append("-no-cv\n");
 
2936
    optionsText.append("\tDo not perform any cross validation.\n");
 
2937
    optionsText.append("-split-percentage <percentage>\n");
 
2938
    optionsText.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
 
2939
    optionsText.append("-preserve-order\n");
 
2940
    optionsText.append("\tPreserves the order in the percentage split.\n");
 
2941
    optionsText.append("-s <random number seed>\n");
 
2942
    optionsText.append("\tSets random number seed for cross-validation or percentage split\n");
 
2943
    optionsText.append("\t(default: 1).\n");
 
2944
    optionsText.append("-m <name of file with cost matrix>\n");
 
2945
    optionsText.append("\tSets file with cost matrix.\n");
 
2946
    optionsText.append("-l <name of input file>\n");
 
2947
    optionsText.append("\tSets model input file. In case the filename ends with '.xml',\n");
 
2948
    optionsText.append("\tthe options are loaded from the XML file.\n");
 
2949
    optionsText.append("-d <name of output file>\n");
 
2950
    optionsText.append("\tSets model output file. In case the filename ends with '.xml',\n");
 
2951
    optionsText.append("\tonly the options are saved to the XML file, not the model.\n");
 
2952
    optionsText.append("-v\n");
 
2953
    optionsText.append("\tOutputs no statistics for training data.\n");
 
2954
    optionsText.append("-o\n");
 
2955
    optionsText.append("\tOutputs statistics only, not the classifier.\n");
 
2956
    optionsText.append("-i\n");
 
2957
    optionsText.append("\tOutputs detailed information-retrieval");
 
2958
    optionsText.append(" statistics for each class.\n");
 
2959
    optionsText.append("-k\n");
 
2960
    optionsText.append("\tOutputs information-theoretic statistics.\n");
 
2961
    optionsText.append("-p <attribute range>\n");
 
2962
    optionsText.append("\tOnly outputs predictions for test instances (or the train\n"
 
2963
        + "\tinstances if no test instances provided), along with attributes\n"
 
2964
        + "\t(0 for none).\n");
 
2965
    optionsText.append("-distribution\n");
 
2966
    optionsText.append("\tOutputs the distribution instead of only the prediction\n");
 
2967
    optionsText.append("\tin conjunction with the '-p' option (only nominal classes).\n");
 
2968
    optionsText.append("-r\n");
 
2969
    optionsText.append("\tOnly outputs cumulative margin distribution.\n");
 
2970
    if (classifier instanceof Sourcable) {
 
2971
      optionsText.append("-z <class name>\n");
 
2972
      optionsText.append("\tOnly outputs the source representation"
 
2973
          + " of the classifier,\n\tgiving it the supplied"
 
2974
          + " name.\n");
 
2975
    }
 
2976
    if (classifier instanceof Drawable) {
 
2977
      optionsText.append("-g\n");
 
2978
      optionsText.append("\tOnly outputs the graph representation"
 
2979
          + " of the classifier.\n");
 
2980
    }
 
2981
    optionsText.append("-xml filename | xml-string\n");
 
2982
    optionsText.append("\tRetrieves the options from the XML-data instead of the " 
 
2983
        + "command line.\n");
 
2984
    optionsText.append("-threshold-file <file>\n");
 
2985
    optionsText.append("\tThe file to save the threshold data to.\n"
 
2986
        + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
 
2987
        + "\tformat or '.csv' for CSV.\n");
 
2988
    optionsText.append("-threshold-label <label>\n");
 
2989
    optionsText.append("\tThe class label to determine the threshold data for\n"
 
2990
        + "\t(default is the first label)\n");
 
2991
 
 
2992
    // Get scheme-specific options
 
2993
    if (classifier instanceof OptionHandler) {
 
2994
      optionsText.append("\nOptions specific to "
 
2995
          + classifier.getClass().getName()
 
2996
          + ":\n\n");
 
2997
      Enumeration enu = ((OptionHandler)classifier).listOptions();
 
2998
      while (enu.hasMoreElements()) {
 
2999
        Option option = (Option) enu.nextElement();
 
3000
        optionsText.append(option.synopsis() + '\n');
 
3001
        optionsText.append(option.description() + "\n");
 
3002
      }
 
3003
    }
 
3004
    return optionsText.toString();
 
3005
  }
 
3006
 
 
3007
  /**
 
3008
   * Method for generating indices for the confusion matrix.
 
3009
   *
 
3010
   * @param num         integer to format
 
3011
   * @param IDChars     the characters to use
 
3012
   * @param IDWidth     the width of the entry
 
3013
   * @return            the formatted integer as a string
 
3014
   */
 
3015
  protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
 
3016
 
 
3017
    char ID [] = new char [IDWidth];
 
3018
    int i;
 
3019
 
 
3020
    for(i = IDWidth - 1; i >=0; i--) {
 
3021
      ID[i] = IDChars[num % IDChars.length];
 
3022
      num = num / IDChars.length - 1;
 
3023
      if (num < 0) {
 
3024
        break;
 
3025
      }
 
3026
    }
 
3027
    for(i--; i >= 0; i--) {
 
3028
      ID[i] = ' ';
 
3029
    }
 
3030
 
 
3031
    return new String(ID);
 
3032
  }
 
3033
 
 
3034
  /**
 
3035
   * Convert a single prediction into a probability distribution
 
3036
   * with all zero probabilities except the predicted value which
 
3037
   * has probability 1.0;
 
3038
   *
 
3039
   * @param predictedClass the index of the predicted class
 
3040
   * @return the probability distribution
 
3041
   */
 
3042
  protected double [] makeDistribution(double predictedClass) {
 
3043
 
 
3044
    double [] result = new double [m_NumClasses];
 
3045
    if (Instance.isMissingValue(predictedClass)) {
 
3046
      return result;
 
3047
    }
 
3048
    if (m_ClassIsNominal) {
 
3049
      result[(int)predictedClass] = 1.0;
 
3050
    } else {
 
3051
      result[0] = predictedClass;
 
3052
    }
 
3053
    return result;
 
3054
  } 
 
3055
 
 
3056
  /**
 
3057
   * Updates all the statistics about a classifiers performance for 
 
3058
   * the current test instance.
 
3059
   *
 
3060
   * @param predictedDistribution the probabilities assigned to 
 
3061
   * each class
 
3062
   * @param instance the instance to be classified
 
3063
   * @throws Exception if the class of the instance is not
 
3064
   * set
 
3065
   */
 
3066
  protected void updateStatsForClassifier(double [] predictedDistribution,
 
3067
      Instance instance)
 
3068
  throws Exception {
 
3069
 
 
3070
    int actualClass = (int)instance.classValue();
 
3071
 
 
3072
    if (!instance.classIsMissing()) {
 
3073
      updateMargins(predictedDistribution, actualClass, instance.weight());
 
3074
 
 
3075
      // Determine the predicted class (doesn't detect multiple 
 
3076
      // classifications)
 
3077
      int predictedClass = -1;
 
3078
      double bestProb = 0.0;
 
3079
      for(int i = 0; i < m_NumClasses; i++) {
 
3080
        if (predictedDistribution[i] > bestProb) {
 
3081
          predictedClass = i;
 
3082
          bestProb = predictedDistribution[i];
 
3083
        }
 
3084
      }
 
3085
 
 
3086
      m_WithClass += instance.weight();
 
3087
 
 
3088
      // Determine misclassification cost
 
3089
      if (m_CostMatrix != null) {
 
3090
        if (predictedClass < 0) {
 
3091
          // For missing predictions, we assume the worst possible cost.
 
3092
          // This is pretty harsh.
 
3093
          // Perhaps we could take the negative of the cost of a correct
 
3094
          // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
 
3095
          // although often this will be zero
 
3096
          m_TotalCost += instance.weight()
 
3097
          * m_CostMatrix.getMaxCost(actualClass, instance);
 
3098
        } else {
 
3099
          m_TotalCost += instance.weight() 
 
3100
          * m_CostMatrix.getElement(actualClass, predictedClass,
 
3101
              instance);
 
3102
        }
 
3103
      }
 
3104
 
 
3105
      // Update counts when no class was predicted
 
3106
      if (predictedClass < 0) {
 
3107
        m_Unclassified += instance.weight();
 
3108
        return;
 
3109
      }
 
3110
 
 
3111
      double predictedProb = Math.max(MIN_SF_PROB,
 
3112
          predictedDistribution[actualClass]);
 
3113
      double priorProb = Math.max(MIN_SF_PROB,
 
3114
          m_ClassPriors[actualClass]
 
3115
                        / m_ClassPriorsSum);
 
3116
      if (predictedProb >= priorProb) {
 
3117
        m_SumKBInfo += (Utils.log2(predictedProb) - 
 
3118
            Utils.log2(priorProb))
 
3119
            * instance.weight();
 
3120
      } else {
 
3121
        m_SumKBInfo -= (Utils.log2(1.0-predictedProb) - 
 
3122
            Utils.log2(1.0-priorProb))
 
3123
            * instance.weight();
 
3124
      }
 
3125
 
 
3126
      m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
 
3127
      m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
 
3128
 
 
3129
      updateNumericScores(predictedDistribution, 
 
3130
          makeDistribution(instance.classValue()), 
 
3131
          instance.weight());
 
3132
 
 
3133
      // Update other stats
 
3134
      m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
 
3135
      if (predictedClass != actualClass) {
 
3136
        m_Incorrect += instance.weight();
 
3137
      } else {
 
3138
        m_Correct += instance.weight();
 
3139
      }
 
3140
    } else {
 
3141
      m_MissingClass += instance.weight();
 
3142
    }
 
3143
  }
 
3144
 
 
3145
  /**
 
3146
   * Updates all the statistics about a predictors performance for 
 
3147
   * the current test instance.
 
3148
   *
 
3149
   * @param predictedValue the numeric value the classifier predicts
 
3150
   * @param instance the instance to be classified
 
3151
   * @throws Exception if the class of the instance is not
 
3152
   * set
 
3153
   */
 
3154
  protected void updateStatsForPredictor(double predictedValue,
 
3155
      Instance instance) 
 
3156
  throws Exception {
 
3157
 
 
3158
    if (!instance.classIsMissing()){
 
3159
 
 
3160
      // Update stats
 
3161
      m_WithClass += instance.weight();
 
3162
      if (Instance.isMissingValue(predictedValue)) {
 
3163
        m_Unclassified += instance.weight();
 
3164
        return;
 
3165
      }
 
3166
      m_SumClass += instance.weight() * instance.classValue();
 
3167
      m_SumSqrClass += instance.weight() * instance.classValue()
 
3168
      * instance.classValue();
 
3169
      m_SumClassPredicted += instance.weight() 
 
3170
      * instance.classValue() * predictedValue;
 
3171
      m_SumPredicted += instance.weight() * predictedValue;
 
3172
      m_SumSqrPredicted += instance.weight() * predictedValue * predictedValue;
 
3173
 
 
3174
      if (m_ErrorEstimator == null) {
 
3175
        setNumericPriorsFromBuffer();
 
3176
      }
 
3177
      double predictedProb = Math.max(m_ErrorEstimator.getProbability(
 
3178
          predictedValue 
 
3179
          - instance.classValue()),
 
3180
          MIN_SF_PROB);
 
3181
      double priorProb = Math.max(m_PriorErrorEstimator.getProbability(
 
3182
          instance.classValue()),
 
3183
          MIN_SF_PROB);
 
3184
 
 
3185
      m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
 
3186
      m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
 
3187
      m_ErrorEstimator.addValue(predictedValue - instance.classValue(), 
 
3188
          instance.weight());
 
3189
 
 
3190
      updateNumericScores(makeDistribution(predictedValue),
 
3191
          makeDistribution(instance.classValue()),
 
3192
          instance.weight());
 
3193
 
 
3194
    } else
 
3195
      m_MissingClass += instance.weight();
 
3196
  }
 
3197
 
 
3198
  /**
 
3199
   * Update the cumulative record of classification margins
 
3200
   *
 
3201
   * @param predictedDistribution the probability distribution predicted for
 
3202
   * the current instance
 
3203
   * @param actualClass the index of the actual instance class
 
3204
   * @param weight the weight assigned to the instance
 
3205
   */
 
3206
  protected void updateMargins(double [] predictedDistribution, 
 
3207
      int actualClass, double weight) {
 
3208
 
 
3209
    double probActual = predictedDistribution[actualClass];
 
3210
    double probNext = 0;
 
3211
 
 
3212
    for(int i = 0; i < m_NumClasses; i++)
 
3213
      if ((i != actualClass) &&
 
3214
          (predictedDistribution[i] > probNext))
 
3215
        probNext = predictedDistribution[i];
 
3216
 
 
3217
    double margin = probActual - probNext;
 
3218
    int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
 
3219
    m_MarginCounts[bin] += weight;
 
3220
  }
 
3221
 
 
3222
  /**
 
3223
   * Update the numeric accuracy measures. For numeric classes, the
 
3224
   * accuracy is between the actual and predicted class values. For 
 
3225
   * nominal classes, the accuracy is between the actual and 
 
3226
   * predicted class probabilities.
 
3227
   *
 
3228
   * @param predicted the predicted values
 
3229
   * @param actual the actual value
 
3230
   * @param weight the weight associated with this prediction
 
3231
   */
 
3232
  protected void updateNumericScores(double [] predicted, 
 
3233
      double [] actual, double weight) {
 
3234
 
 
3235
    double diff;
 
3236
    double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
 
3237
    double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
 
3238
    for(int i = 0; i < m_NumClasses; i++) {
 
3239
      diff = predicted[i] - actual[i];
 
3240
      sumErr += diff;
 
3241
      sumAbsErr += Math.abs(diff);
 
3242
      sumSqrErr += diff * diff;
 
3243
      diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
 
3244
      sumPriorAbsErr += Math.abs(diff);
 
3245
      sumPriorSqrErr += diff * diff;
 
3246
    }
 
3247
    m_SumErr += weight * sumErr / m_NumClasses;
 
3248
    m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
 
3249
    m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
 
3250
    m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
 
3251
    m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
 
3252
  }
 
3253
 
 
3254
  /**
 
3255
   * Adds a numeric (non-missing) training class value and weight to 
 
3256
   * the buffer of stored values.
 
3257
   *
 
3258
   * @param classValue the class value
 
3259
   * @param weight the instance weight
 
3260
   */
 
3261
  protected void addNumericTrainClass(double classValue, double weight) {
 
3262
 
 
3263
    if (m_TrainClassVals == null) {
 
3264
      m_TrainClassVals = new double [100];
 
3265
      m_TrainClassWeights = new double [100];
 
3266
    }
 
3267
    if (m_NumTrainClassVals == m_TrainClassVals.length) {
 
3268
      double [] temp = new double [m_TrainClassVals.length * 2];
 
3269
      System.arraycopy(m_TrainClassVals, 0, 
 
3270
          temp, 0, m_TrainClassVals.length);
 
3271
      m_TrainClassVals = temp;
 
3272
 
 
3273
      temp = new double [m_TrainClassWeights.length * 2];
 
3274
      System.arraycopy(m_TrainClassWeights, 0, 
 
3275
          temp, 0, m_TrainClassWeights.length);
 
3276
      m_TrainClassWeights = temp;
 
3277
    }
 
3278
    m_TrainClassVals[m_NumTrainClassVals] = classValue;
 
3279
    m_TrainClassWeights[m_NumTrainClassVals] = weight;
 
3280
    m_NumTrainClassVals++;
 
3281
  }
 
3282
 
 
3283
  /**
 
3284
   * Sets up the priors for numeric class attributes from the 
 
3285
   * training class values that have been seen so far.
 
3286
   */
 
3287
  protected void setNumericPriorsFromBuffer() {
 
3288
 
 
3289
    double numPrecision = 0.01; // Default value
 
3290
    if (m_NumTrainClassVals > 1) {
 
3291
      double [] temp = new double [m_NumTrainClassVals];
 
3292
      System.arraycopy(m_TrainClassVals, 0, temp, 0, m_NumTrainClassVals);
 
3293
      int [] index = Utils.sort(temp);
 
3294
      double lastVal = temp[index[0]];
 
3295
      double deltaSum = 0;
 
3296
      int distinct = 0;
 
3297
      for (int i = 1; i < temp.length; i++) {
 
3298
        double current = temp[index[i]];
 
3299
        if (current != lastVal) {
 
3300
          deltaSum += current - lastVal;
 
3301
          lastVal = current;
 
3302
          distinct++;
 
3303
        }
 
3304
      }
 
3305
      if (distinct > 0) {
 
3306
        numPrecision = deltaSum / distinct;
 
3307
      }
 
3308
    }
 
3309
    m_PriorErrorEstimator = new KernelEstimator(numPrecision);
 
3310
    m_ErrorEstimator = new KernelEstimator(numPrecision);
 
3311
    m_ClassPriors[0] = m_ClassPriorsSum = 0;
 
3312
    for (int i = 0; i < m_NumTrainClassVals; i++) {
 
3313
      m_ClassPriors[0] += m_TrainClassVals[i] * m_TrainClassWeights[i];
 
3314
      m_ClassPriorsSum += m_TrainClassWeights[i];
 
3315
      m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
 
3316
          m_TrainClassWeights[i]);
 
3317
    }
 
3318
  }
 
3319
}