2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
23
package weka.classifiers;
25
import weka.classifiers.evaluation.NominalPrediction;
26
import weka.classifiers.evaluation.ThresholdCurve;
27
import weka.classifiers.xml.XMLClassifier;
28
import weka.core.Drawable;
29
import weka.core.FastVector;
30
import weka.core.Instance;
31
import weka.core.Instances;
32
import weka.core.Option;
33
import weka.core.OptionHandler;
34
import weka.core.Range;
35
import weka.core.Summarizable;
36
import weka.core.Utils;
37
import weka.core.Version;
38
import weka.core.converters.ConverterUtils.DataSink;
39
import weka.core.converters.ConverterUtils.DataSource;
40
import weka.core.xml.KOML;
41
import weka.core.xml.XMLOptions;
42
import weka.core.xml.XMLSerialization;
43
import weka.estimators.Estimator;
44
import weka.estimators.KernelEstimator;
46
import java.io.BufferedInputStream;
47
import java.io.BufferedOutputStream;
48
import java.io.BufferedReader;
49
import java.io.FileInputStream;
50
import java.io.FileOutputStream;
51
import java.io.FileReader;
52
import java.io.InputStream;
53
import java.io.ObjectInputStream;
54
import java.io.ObjectOutputStream;
55
import java.io.OutputStream;
56
import java.io.Reader;
57
import java.util.Date;
58
import java.util.Enumeration;
59
import java.util.Random;
60
import java.util.zip.GZIPInputStream;
61
import java.util.zip.GZIPOutputStream;
64
* Class for evaluating machine learning models. <p/>
66
* ------------------------------------------------------------------- <p/>
68
* General options when evaluating a learning scheme from the command-line: <p/>
71
* Name of the file with the training data. (required) <p/>
74
* Name of the file with the test data. If missing a cross-validation
78
* Index of the class attribute (1, 2, ...; default: last). <p/>
81
* The number of folds for the cross-validation (default: 10). <p/>
84
* No cross validation. If no test file is provided, no evaluation
87
* -split-percentage percentage <br/>
88
* Sets the percentage for the train/test set split, e.g., 66. <p/>
90
* -preserve-order <br/>
91
* Preserves the order in the percentage split instead of randomizing
92
* the data first with the seed value ('-s'). <p/>
95
* Random number seed for the cross-validation and percentage split
99
* The name of a file containing a cost matrix. <p/>
102
* Loads classifier from the given file. In case the filename ends with ".xml"
103
* the options are loaded from XML. <p/>
106
* Saves classifier built from the training data into the given file. In case
107
* the filename ends with ".xml" the options are saved XML, not the model. <p/>
110
* Outputs no statistics for the training data. <p/>
113
* Outputs statistics only, not the classifier. <p/>
116
* Outputs information-retrieval statistics per class. <p/>
119
* Outputs information-theoretic statistics. <p/>
122
* Outputs predictions for test instances (or the train instances if no test
123
* instances provided), along with the attributes in the specified range
124
* (and nothing else). Use '-p 0' if no attributes are desired. <p/>
126
* -distribution <br/>
127
* Outputs the distribution instead of only the prediction
128
* in conjunction with the '-p' option (only nominal classes). <p/>
131
* Outputs cumulative margin distribution (and nothing else). <p/>
134
* Only for classifiers that implement "Graphable." Outputs
135
* the graph representation of the classifier (and nothing
138
* -xml filename | xml-string <br/>
139
* Retrieves the options from the XML-data instead of the command line. <p/>
141
* -threshold-file file <br/>
142
* The file to save the threshold data to.
143
* The format is determined by the extensions, e.g., '.arff' for ARFF
144
* format or '.csv' for CSV. <p/>
146
* -threshold-label label <br/>
147
* The class label to determine the threshold data for
148
* (default is the first label) <p/>
150
* ------------------------------------------------------------------- <p/>
152
* Example usage as the main of a classifier (called FunkyClassifier):
154
* public static void main(String [] args) {
155
* runClassifier(new FunkyClassifier(), args);
160
* ------------------------------------------------------------------ <p/>
162
* Example usage from within an application:
164
* Instances trainInstances = ... instances got from somewhere
165
* Instances testInstances = ... instances got from somewhere
166
* Classifier scheme = ... scheme got from somewhere
168
* Evaluation evaluation = new Evaluation(trainInstances);
169
* evaluation.evaluateModel(scheme, testInstances);
170
* System.out.println(evaluation.toSummaryString());
174
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
175
* @author Len Trigg (trigg@cs.waikato.ac.nz)
176
* @version $Revision: 1.83 $
178
public class Evaluation
179
implements Summarizable {
181
/** The number of classes. */
182
protected int m_NumClasses;
184
/** The number of folds for a cross-validation. */
185
protected int m_NumFolds;
187
/** The weight of all incorrectly classified instances. */
188
protected double m_Incorrect;
190
/** The weight of all correctly classified instances. */
191
protected double m_Correct;
193
/** The weight of all unclassified instances. */
194
protected double m_Unclassified;
196
/** The weight of all instances that had no class assigned to them. */
197
protected double m_MissingClass;
199
/** The weight of all instances that had a class assigned to them. */
200
protected double m_WithClass;
202
/** Array for storing the confusion matrix. */
203
protected double [][] m_ConfusionMatrix;
205
/** The names of the classes. */
206
protected String [] m_ClassNames;
208
/** Is the class nominal or numeric? */
209
protected boolean m_ClassIsNominal;
211
/** The prior probabilities of the classes */
212
protected double [] m_ClassPriors;
214
/** The sum of counts for priors */
215
protected double m_ClassPriorsSum;
217
/** The cost matrix (if given). */
218
protected CostMatrix m_CostMatrix;
220
/** The total cost of predictions (includes instance weights) */
221
protected double m_TotalCost;
223
/** Sum of errors. */
224
protected double m_SumErr;
226
/** Sum of absolute errors. */
227
protected double m_SumAbsErr;
229
/** Sum of squared errors. */
230
protected double m_SumSqrErr;
232
/** Sum of class values. */
233
protected double m_SumClass;
235
/** Sum of squared class values. */
236
protected double m_SumSqrClass;
238
/** Sum of predicted values. */
239
protected double m_SumPredicted;
241
/** Sum of squared predicted values. */
242
protected double m_SumSqrPredicted;
244
/** Sum of predicted * class values. */
245
protected double m_SumClassPredicted;
247
/** Sum of absolute errors of the prior */
248
protected double m_SumPriorAbsErr;
250
/** Sum of absolute errors of the prior */
251
protected double m_SumPriorSqrErr;
253
/** Total Kononenko & Bratko Information */
254
protected double m_SumKBInfo;
256
/** Resolution of the margin histogram */
257
protected static int k_MarginResolution = 500;
259
/** Cumulative margin distribution */
260
protected double m_MarginCounts [];
262
/** Number of non-missing class training instances seen */
263
protected int m_NumTrainClassVals;
265
/** Array containing all numeric training class values seen */
266
protected double [] m_TrainClassVals;
268
/** Array containing all numeric training class weights */
269
protected double [] m_TrainClassWeights;
271
/** Numeric class error estimator for prior */
272
protected Estimator m_PriorErrorEstimator;
274
/** Numeric class error estimator for scheme */
275
protected Estimator m_ErrorEstimator;
278
* The minimum probability accepted from an estimator to avoid
279
* taking log(0) in Sf calculations.
281
protected static final double MIN_SF_PROB = Double.MIN_VALUE;
283
/** Total entropy of prior predictions */
284
protected double m_SumPriorEntropy;
286
/** Total entropy of scheme predictions */
287
protected double m_SumSchemeEntropy;
289
/** The list of predictions that have been generated (for computing AUC) */
290
private FastVector m_Predictions;
292
/** enables/disables the use of priors, e.g., if no training set is
293
* present in case of de-serialized schemes */
294
protected boolean m_NoPriors = false;
297
* Initializes all the counters for the evaluation.
298
* Use <code>useNoPriors()</code> if the dataset is the test set and you
299
* can't initialize with the priors from the training set via
300
* <code>setPriors(Instances)</code>.
302
* @param data set of training instances, to get some header
303
* information and prior class distribution information
304
* @throws Exception if the class is not defined
305
* @see #useNoPriors()
306
* @see #setPriors(Instances)
308
public Evaluation(Instances data) throws Exception {
314
* Initializes all the counters for the evaluation and also takes a
315
* cost matrix as parameter.
316
* Use <code>useNoPriors()</code> if the dataset is the test set and you
317
* can't initialize with the priors from the training set via
318
* <code>setPriors(Instances)</code>.
320
* @param data set of training instances, to get some header
321
* information and prior class distribution information
322
* @param costMatrix the cost matrix---if null, default costs will be used
323
* @throws Exception if cost matrix is not compatible with
324
* data, the class is not defined or the class is numeric
325
* @see #useNoPriors()
326
* @see #setPriors(Instances)
328
public Evaluation(Instances data, CostMatrix costMatrix)
331
m_NumClasses = data.numClasses();
333
m_ClassIsNominal = data.classAttribute().isNominal();
335
if (m_ClassIsNominal) {
336
m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
337
m_ClassNames = new String [m_NumClasses];
338
for(int i = 0; i < m_NumClasses; i++) {
339
m_ClassNames[i] = data.classAttribute().value(i);
342
m_CostMatrix = costMatrix;
343
if (m_CostMatrix != null) {
344
if (!m_ClassIsNominal) {
345
throw new Exception("Class has to be nominal if cost matrix " +
348
if (m_CostMatrix.size() != m_NumClasses) {
349
throw new Exception("Cost matrix not compatible with data!");
352
m_ClassPriors = new double [m_NumClasses];
354
m_MarginCounts = new double [k_MarginResolution + 1];
358
* Returns the area under ROC for those predictions that have been collected
359
* in the evaluateClassifier(Classifier, Instances) method. Returns
360
* Instance.missingValue() if the area is not available.
362
* @param classIndex the index of the class to consider as "positive"
363
* @return the area under the ROC curve or not a number
365
public double areaUnderROC(int classIndex) {
367
// Check if any predictions have been collected
368
// No predictions recorded -> AUC is unavailable; signal with missing value.
if (m_Predictions == null) {
369
return Instance.missingValue();
371
// NOTE(review): the "} else {" line (original line 370) was dropped by the
// extraction; the remaining statements compute the ROC curve for the given
// class index and return the area under it.
ThresholdCurve tc = new ThresholdCurve();
372
Instances result = tc.getCurve(m_Predictions, classIndex);
373
return ThresholdCurve.getROCArea(result);
378
* Returns a copy of the confusion matrix.
380
* @return a copy of the confusion matrix as a two-dimensional array
382
public double[][] confusionMatrix() {
384
// Allocate the outer array; each row is allocated to its own length below,
// so a possibly-ragged matrix is copied faithfully.
double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
386
for (int i = 0; i < m_ConfusionMatrix.length; i++) {
387
newMatrix[i] = new double[m_ConfusionMatrix[i].length];
388
// Deep-copy each row so callers cannot mutate the internal matrix.
System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
389
m_ConfusionMatrix[i].length);
395
* Performs a (stratified if class is nominal) cross-validation
396
* for a classifier on a set of instances. Now performs
397
* a deep copy of the classifier before each call to
398
* buildClassifier() (just in case the classifier is not
399
* initialized properly).
401
* @param classifier the classifier with any options set.
402
* @param data the data on which the cross-validation is to be
404
* @param numFolds the number of folds for the cross-validation
405
* @param random random number generator for randomization
406
* @throws Exception if a classifier could not be generated
407
* successfully or the class is not defined
409
public void crossValidateModel(Classifier classifier,
410
Instances data, int numFolds, Random random)
413
// NOTE(review): the "throws Exception {" line and the loop/if closing braces
// were dropped by the extraction; code tokens are left byte-identical.
// Make a copy of the data we can reorder
414
data = new Instances(data);
415
data.randomize(random);
416
// Stratify only makes sense for a nominal class.
if (data.classAttribute().isNominal()) {
417
data.stratify(numFolds);
419
// One train/build/test cycle per fold; a fresh copy of the classifier is
// built each time so state does not leak between folds.
for (int i = 0; i < numFolds; i++) {
421
Instances train = data.trainCV(numFolds, i, random);
423
Classifier copiedClassifier = Classifier.makeCopy(classifier);
424
copiedClassifier.buildClassifier(train);
425
Instances test = data.testCV(numFolds, i);
426
evaluateModel(copiedClassifier, test);
428
// Remember the fold count for later reporting.
m_NumFolds = numFolds;
432
* Performs a (stratified if class is nominal) cross-validation
433
* for a classifier on a set of instances.
435
* @param classifierString a string naming the class of the classifier
436
* @param data the data on which the cross-validation is to be
438
* @param numFolds the number of folds for the cross-validation
439
* @param options the options to the classifier. Any options
440
* @param random the random number generator for randomizing the data
441
* accepted by the classifier will be removed from this array.
442
* @throws Exception if a classifier could not be generated
443
* successfully or the class is not defined
445
public void crossValidateModel(String classifierString,
446
Instances data, int numFolds,
447
String[] options, Random random)
450
// Convenience overload: instantiate the classifier by class name (options
// accepted by the scheme are consumed from the array) and delegate to the
// Classifier-based overload.
crossValidateModel(Classifier.forName(classifierString, options),
451
data, numFolds, random);
455
* Evaluates a classifier with the options given in an array of
458
* Valid options are: <p/>
461
* Name of the file with the training data. (required) <p/>
464
* Name of the file with the test data. If missing a cross-validation
468
* Index of the class attribute (1, 2, ...; default: last). <p/>
471
* The number of folds for the cross-validation (default: 10). <p/>
474
* No cross validation. If no test file is provided, no evaluation
477
* -split-percentage percentage <br/>
478
* Sets the percentage for the train/test set split, e.g., 66. <p/>
480
* -preserve-order <br/>
481
* Preserves the order in the percentage split instead of randomizing
482
* the data first with the seed value ('-s'). <p/>
485
* Random number seed for the cross-validation and percentage split
489
* The name of a file containing a cost matrix. <p/>
492
* Loads classifier from the given file. In case the filename ends with
493
* ".xml" the options are loaded from XML. <p/>
496
* Saves classifier built from the training data into the given file. In case
497
* the filename ends with ".xml" the options are saved XML, not the model. <p/>
500
* Outputs no statistics for the training data. <p/>
503
* Outputs statistics only, not the classifier. <p/>
506
* Outputs detailed information-retrieval statistics per class. <p/>
509
* Outputs information-theoretic statistics. <p/>
512
* Outputs predictions for test instances (or the train instances if no test
513
* instances provided), along with the attributes in the specified range (and
514
* nothing else). Use '-p 0' if no attributes are desired. <p/>
516
* -distribution <br/>
517
* Outputs the distribution instead of only the prediction
518
* in conjunction with the '-p' option (only nominal classes). <p/>
521
* Outputs cumulative margin distribution (and nothing else). <p/>
524
* Only for classifiers that implement "Graphable." Outputs
525
* the graph representation of the classifier (and nothing
528
* -xml filename | xml-string <br/>
529
* Retrieves the options from the XML-data instead of the command line. <p/>
531
* -threshold-file file <br/>
532
* The file to save the threshold data to.
533
* The format is determined by the extensions, e.g., '.arff' for ARFF
534
* format or '.csv' for CSV. <p/>
536
* -threshold-label label <br/>
537
* The class label to determine the threshold data for
538
* (default is the first label) <p/>
540
* @param classifierString class of machine learning classifier as a string
541
* @param options the array of string containing the options
542
* @throws Exception if model could not be evaluated successfully
543
* @return a string describing the results
545
public static String evaluateModel(String classifierString,
546
String [] options) throws Exception {
548
Classifier classifier;
553
// NOTE(review): the extraction dropped the "try {" line and the
// "classifier =" portion of the assignment (original lines ~549-552);
// the reflective instantiation below is otherwise byte-identical.
(Classifier)Class.forName(classifierString).newInstance();
554
} catch (Exception e) {
555
// Wrap reflection failures in a friendlier message naming the class.
throw new Exception("Can't find class with name "
556
+ classifierString + '.');
558
// Delegate to the Classifier-based overload for the actual evaluation.
return evaluateModel(classifier, options);
562
* A test method for this class. Just extracts the first command line
563
* argument as a classifier class name and calls evaluateModel.
564
* @param args an array of command line arguments, the first of which
565
* must be the class name of a classifier.
567
public static void main(String [] args) {
570
// NOTE(review): the "try {" line (original lines ~568-569) was dropped by
// the extraction; code tokens below are left byte-identical.
// First argument must name the classifier class to evaluate.
if (args.length == 0) {
571
throw new Exception("The first argument must be the class name"
572
+ " of a classifier");
574
String classifier = args[0];
576
// Evaluate and print the textual results summary.
System.out.println(evaluateModel(classifier, args));
577
} catch (Exception ex) {
578
// Report failures on stderr rather than terminating with a stack dump only.
ex.printStackTrace();
579
System.err.println(ex.getMessage());
584
* Evaluates a classifier with the options given in an array of
587
* Valid options are: <p/>
589
* -t name of training file <br/>
590
* Name of the file with the training data. (required) <p/>
592
* -T name of test file <br/>
593
* Name of the file with the test data. If missing a cross-validation
596
* -c class index <br/>
597
* Index of the class attribute (1, 2, ...; default: last). <p/>
599
* -x number of folds <br/>
600
* The number of folds for the cross-validation (default: 10). <p/>
603
* No cross validation. If no test file is provided, no evaluation
606
* -split-percentage percentage <br/>
607
* Sets the percentage for the train/test set split, e.g., 66. <p/>
609
* -preserve-order <br/>
610
* Preserves the order in the percentage split instead of randomizing
611
* the data first with the seed value ('-s'). <p/>
614
* Random number seed for the cross-validation and percentage split
617
* -m file with cost matrix <br/>
618
* The name of a file containing a cost matrix. <p/>
621
* Loads classifier from the given file. In case the filename ends with
622
* ".xml" the options are loaded from XML. <p/>
625
* Saves classifier built from the training data into the given file. In case
626
* the filename ends with ".xml" the options are saved XML, not the model. <p/>
629
* Outputs no statistics for the training data. <p/>
632
* Outputs statistics only, not the classifier. <p/>
635
* Outputs detailed information-retrieval statistics per class. <p/>
638
* Outputs information-theoretic statistics. <p/>
641
* Outputs predictions for test instances (or the train instances if no test
642
* instances provided), along with the attributes in the specified range
643
* (and nothing else). Use '-p 0' if no attributes are desired. <p/>
645
* -distribution <br/>
646
* Outputs the distribution instead of only the prediction
647
* in conjunction with the '-p' option (only nominal classes). <p/>
650
* Outputs cumulative margin distribution (and nothing else). <p/>
653
* Only for classifiers that implement "Graphable." Outputs
654
* the graph representation of the classifier (and nothing
657
* -xml filename | xml-string <br/>
658
* Retrieves the options from the XML-data instead of the command line. <p/>
660
* @param classifier machine learning classifier
661
* @param options the array of string containing the options
662
* @throws Exception if model could not be evaluated successfully
663
* @return a string describing the results
665
public static String evaluateModel(Classifier classifier,
666
String [] options) throws Exception {
668
Instances train = null, tempTrain, test = null, template = null;
669
int seed = 1, folds = 10, classIndex = -1;
670
boolean noCrossValidation = false;
671
String trainFileName, testFileName, sourceClass,
672
classIndexString, seedString, foldsString, objectInputFileName,
673
objectOutputFileName, attributeRangeString;
674
boolean noOutput = false,
675
printClassifications = false, trainStatistics = true,
676
printMargins = false, printComplexityStatistics = false,
677
printGraph = false, classStatistics = false, printSource = false;
678
StringBuffer text = new StringBuffer();
679
DataSource trainSource = null, testSource = null;
680
ObjectInputStream objectInputStream = null;
681
BufferedInputStream xmlInputStream = null;
682
CostMatrix costMatrix = null;
683
StringBuffer schemeOptionsText = null;
684
Range attributesToOutput = null;
685
long trainTimeStart = 0, trainTimeElapsed = 0,
686
testTimeStart = 0, testTimeElapsed = 0;
688
String[] optionsTmp = null;
689
Classifier classifierBackup;
690
Classifier classifierClassifications = null;
691
boolean printDistribution = false;
692
int actualClassIndex = -1; // 0-based class index
693
String splitPercentageString = "";
694
int splitPercentage = -1;
695
boolean preserveOrder = false;
696
boolean trainSetPresent = false;
697
boolean testSetPresent = false;
698
String thresholdFile;
699
String thresholdLabel;
702
if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {
703
throw new Exception("\nHelp requested." + makeOptionString(classifier));
707
// do we get the input from XML instead of normal parameters?
708
xml = Utils.getOption("xml", options);
710
options = new XMLOptions(xml).toArray();
712
// is the input model only the XML-Options, i.e. w/o built model?
713
optionsTmp = new String[options.length];
714
for (int i = 0; i < options.length; i++)
715
optionsTmp[i] = options[i];
717
if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
718
// load options from serialized data ('-l' is automatically erased!)
719
XMLClassifier xmlserial = new XMLClassifier();
720
Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));
722
optionsTmp = new String[options.length + cl.getOptions().length];
723
System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
724
System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
725
options = optionsTmp;
728
noCrossValidation = Utils.getFlag("no-cv", options);
729
// Get basic options (options the same for all schemes)
730
classIndexString = Utils.getOption('c', options);
731
if (classIndexString.length() != 0) {
732
if (classIndexString.equals("first"))
734
else if (classIndexString.equals("last"))
737
classIndex = Integer.parseInt(classIndexString);
739
trainFileName = Utils.getOption('t', options);
740
objectInputFileName = Utils.getOption('l', options);
741
objectOutputFileName = Utils.getOption('d', options);
742
testFileName = Utils.getOption('T', options);
743
foldsString = Utils.getOption('x', options);
744
if (foldsString.length() != 0) {
745
folds = Integer.parseInt(foldsString);
747
seedString = Utils.getOption('s', options);
748
if (seedString.length() != 0) {
749
seed = Integer.parseInt(seedString);
751
if (trainFileName.length() == 0) {
752
if (objectInputFileName.length() == 0) {
753
throw new Exception("No training file and no object "+
754
"input file given.");
756
if (testFileName.length() == 0) {
757
throw new Exception("No training file and no test "+
760
} else if ((objectInputFileName.length() != 0) &&
761
((!(classifier instanceof UpdateableClassifier)) ||
762
(testFileName.length() == 0))) {
763
throw new Exception("Classifier not incremental, or no " +
764
"test file provided: can't "+
765
"use both train and model file.");
768
if (trainFileName.length() != 0) {
769
trainSetPresent = true;
770
trainSource = new DataSource(trainFileName);
772
if (testFileName.length() != 0) {
773
testSetPresent = true;
774
testSource = new DataSource(testFileName);
776
if (objectInputFileName.length() != 0) {
777
InputStream is = new FileInputStream(objectInputFileName);
778
if (objectInputFileName.endsWith(".gz")) {
779
is = new GZIPInputStream(is);
782
if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
783
objectInputStream = new ObjectInputStream(is);
784
xmlInputStream = null;
787
objectInputStream = null;
788
xmlInputStream = new BufferedInputStream(is);
791
} catch (Exception e) {
792
throw new Exception("Can't open file " + e.getMessage() + '.');
794
if (testSetPresent) {
795
template = test = testSource.getStructure();
796
if (classIndex != -1) {
797
test.setClassIndex(classIndex - 1);
799
if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
800
test.setClassIndex(test.numAttributes() - 1);
802
actualClassIndex = test.classIndex();
806
splitPercentageString = Utils.getOption("split-percentage", options);
807
if (splitPercentageString.length() != 0) {
808
if (foldsString.length() != 0)
810
"Percentage split cannot be used in conjunction with "
811
+ "cross-validation ('-x').");
812
splitPercentage = Integer.parseInt(splitPercentageString);
813
if ((splitPercentage <= 0) || (splitPercentage >= 100))
814
throw new Exception("Percentage split value needs be >0 and <100.");
817
splitPercentage = -1;
819
preserveOrder = Utils.getFlag("preserve-order", options);
821
if (splitPercentage == -1)
822
throw new Exception("Percentage split ('-percentage-split') is missing.");
824
// create new train/test sources
825
if (splitPercentage > 0) {
826
testSetPresent = true;
827
Instances tmpInst = trainSource.getDataSet(actualClassIndex);
829
tmpInst.randomize(new Random(seed));
830
int trainSize = tmpInst.numInstances() * splitPercentage / 100;
831
int testSize = tmpInst.numInstances() - trainSize;
832
Instances trainInst = new Instances(tmpInst, 0, trainSize);
833
Instances testInst = new Instances(tmpInst, trainSize, testSize);
834
trainSource = new DataSource(trainInst);
835
testSource = new DataSource(testInst);
836
template = test = testSource.getStructure();
837
if (classIndex != -1) {
838
test.setClassIndex(classIndex - 1);
840
if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
841
test.setClassIndex(test.numAttributes() - 1);
843
actualClassIndex = test.classIndex();
846
if (trainSetPresent) {
847
template = train = trainSource.getStructure();
848
if (classIndex != -1) {
849
train.setClassIndex(classIndex - 1);
851
if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
852
train.setClassIndex(train.numAttributes() - 1);
854
actualClassIndex = train.classIndex();
855
if ((testSetPresent) && !test.equalHeaders(train)) {
856
throw new IllegalArgumentException("Train and test file not compatible!");
859
if (template == null) {
860
throw new Exception("No actual dataset provided to use as template");
862
costMatrix = handleCostOption(
863
Utils.getOption('m', options), template.numClasses());
865
classStatistics = Utils.getFlag('i', options);
866
noOutput = Utils.getFlag('o', options);
867
trainStatistics = !Utils.getFlag('v', options);
868
printComplexityStatistics = Utils.getFlag('k', options);
869
printMargins = Utils.getFlag('r', options);
870
printGraph = Utils.getFlag('g', options);
871
sourceClass = Utils.getOption('z', options);
872
printSource = (sourceClass.length() != 0);
873
printDistribution = Utils.getFlag("distribution", options);
874
thresholdFile = Utils.getOption("threshold-file", options);
875
thresholdLabel = Utils.getOption("threshold-label", options);
879
attributeRangeString = Utils.getOption('p', options);
881
catch (Exception e) {
882
throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
883
"It now expects a parameter specifying a range of attributes " +
884
"to list with the predictions. Use '-p 0' for none.");
886
if (attributeRangeString.length() != 0) {
887
printClassifications = true;
888
if (!attributeRangeString.equals("0"))
889
attributesToOutput = new Range(attributeRangeString);
892
if (!printClassifications && printDistribution)
893
throw new Exception("Cannot print distribution without '-p' option!");
895
// if no training file given, we don't have any priors
896
if ( (!trainSetPresent) && (printComplexityStatistics) )
897
throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
899
// If a model file is given, we can't process
900
// scheme-specific options
901
if (objectInputFileName.length() != 0) {
902
Utils.checkForRemainingOptions(options);
905
// Set options for classifier
906
if (classifier instanceof OptionHandler) {
907
for (int i = 0; i < options.length; i++) {
908
if (options[i].length() != 0) {
909
if (schemeOptionsText == null) {
910
schemeOptionsText = new StringBuffer();
912
if (options[i].indexOf(' ') != -1) {
913
schemeOptionsText.append('"' + options[i] + "\" ");
915
schemeOptionsText.append(options[i] + " ");
919
((OptionHandler)classifier).setOptions(options);
922
Utils.checkForRemainingOptions(options);
923
} catch (Exception e) {
924
throw new Exception("\nWeka exception: " + e.getMessage()
925
+ makeOptionString(classifier));
928
// Setup up evaluation objects
929
Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
930
Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
932
// disable use of priors if no training file given
933
if (!trainSetPresent)
934
testingEvaluation.useNoPriors();
936
if (objectInputFileName.length() != 0) {
937
// Load classifier from file
938
if (objectInputStream != null) {
939
classifier = (Classifier) objectInputStream.readObject();
940
// try and read a header (if present)
941
Instances savedStructure = null;
943
savedStructure = (Instances) objectInputStream.readObject();
944
} catch (Exception ex) {
947
if (savedStructure != null) {
948
// test for compatibility with template
949
if (!template.equalHeaders(savedStructure)) {
950
throw new Exception("training and test set are not compatible");
953
objectInputStream.close();
956
// whether KOML is available has already been checked (objectInputStream would null otherwise)!
957
classifier = (Classifier) KOML.read(xmlInputStream);
958
xmlInputStream.close();
962
// backup of fully setup classifier for cross-validation
963
classifierBackup = Classifier.makeCopy(classifier);
965
// Build the classifier if no object file provided
966
if ((classifier instanceof UpdateableClassifier) &&
968
(costMatrix == null) &&
971
// Build classifier incrementally
972
trainingEvaluation.setPriors(train);
973
testingEvaluation.setPriors(train);
974
trainTimeStart = System.currentTimeMillis();
975
if (objectInputFileName.length() == 0) {
976
classifier.buildClassifier(train);
979
while (trainSource.hasMoreElements(train)) {
980
trainInst = trainSource.nextElement(train);
981
trainingEvaluation.updatePriors(trainInst);
982
testingEvaluation.updatePriors(trainInst);
983
((UpdateableClassifier)classifier).updateClassifier(trainInst);
985
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
986
} else if (objectInputFileName.length() == 0) {
987
// Build classifier in one go
988
tempTrain = trainSource.getDataSet(actualClassIndex);
989
trainingEvaluation.setPriors(tempTrain);
990
testingEvaluation.setPriors(tempTrain);
991
trainTimeStart = System.currentTimeMillis();
992
classifier.buildClassifier(tempTrain);
993
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
996
// backup of fully trained classifier for printing the classifications
997
if (printClassifications)
998
classifierClassifications = Classifier.makeCopy(classifier);
1000
// Save the classifier if an object output file is provided
1001
if (objectOutputFileName.length() != 0) {
1002
OutputStream os = new FileOutputStream(objectOutputFileName);
1004
if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
1005
if (objectOutputFileName.endsWith(".gz")) {
1006
os = new GZIPOutputStream(os);
1008
ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
1009
objectOutputStream.writeObject(classifier);
1010
if (template != null) {
1011
objectOutputStream.writeObject(template);
1013
objectOutputStream.flush();
1014
objectOutputStream.close();
1018
BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
1019
if (objectOutputFileName.endsWith(".xml")) {
1020
XMLSerialization xmlSerial = new XMLClassifier();
1021
xmlSerial.write(xmlOutputStream, classifier);
1024
// whether KOML is present has already been checked
1025
// if not present -> ".koml" is interpreted as binary - see above
1026
if (objectOutputFileName.endsWith(".koml")) {
1027
KOML.write(xmlOutputStream, classifier);
1029
xmlOutputStream.close();
1033
// If classifier is drawable output string describing graph
1034
if ((classifier instanceof Drawable) && (printGraph)){
1035
return ((Drawable)classifier).graph();
1038
// Output the classifier as equivalent source
1039
if ((classifier instanceof Sourcable) && (printSource)){
1040
return wekaStaticWrapper((Sourcable) classifier, sourceClass);
1044
if (!(noOutput || printMargins)) {
1045
if (classifier instanceof OptionHandler) {
1046
if (schemeOptionsText != null) {
1047
text.append("\nOptions: "+schemeOptionsText);
1051
text.append("\n" + classifier.toString() + "\n");
1054
if (!printMargins && (costMatrix != null)) {
1055
text.append("\n=== Evaluation Cost Matrix ===\n\n");
1056
text.append(costMatrix.toString());
1059
// Output test instance predictions only
1060
if (printClassifications) {
1061
DataSource source = testSource;
1062
// no test set -> use train set
1064
source = trainSource;
1065
return printClassifications(classifierClassifications, new Instances(template, 0),
1066
source, actualClassIndex + 1, attributesToOutput,
1070
// Compute error estimate from training data
1071
if ((trainStatistics) && (trainSetPresent)) {
1073
if ((classifier instanceof UpdateableClassifier) &&
1075
(costMatrix == null)) {
1077
// Classifier was trained incrementally, so we have to
1078
// reset the source.
1079
trainSource.reset();
1081
// Incremental testing
1082
train = trainSource.getStructure(actualClassIndex);
1083
testTimeStart = System.currentTimeMillis();
1085
while (trainSource.hasMoreElements(train)) {
1086
trainInst = trainSource.nextElement(train);
1087
trainingEvaluation.evaluateModelOnce((Classifier)classifier, trainInst);
1089
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
1091
testTimeStart = System.currentTimeMillis();
1092
trainingEvaluation.evaluateModel(
1093
classifier, trainSource.getDataSet(actualClassIndex));
1094
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
1097
// Print the results of the training evaluation
1099
return trainingEvaluation.toCumulativeMarginDistributionString();
1101
text.append("\nTime taken to build model: "
1102
+ Utils.doubleToString(trainTimeElapsed / 1000.0,2)
1105
if (splitPercentage > 0)
1106
text.append("\nTime taken to test model on training split: ");
1108
text.append("\nTime taken to test model on training data: ");
1109
text.append(Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");
1111
if (splitPercentage > 0)
1112
text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
1113
+ " split ===\n", printComplexityStatistics));
1115
text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
1116
+ " data ===\n", printComplexityStatistics));
1118
if (template.classAttribute().isNominal()) {
1119
if (classStatistics) {
1120
text.append("\n\n" + trainingEvaluation.toClassDetailsString());
1122
if (!noCrossValidation)
1123
text.append("\n\n" + trainingEvaluation.toMatrixString());
1129
// Compute proper error estimates
1130
if (testSource != null) {
1131
// Testing is on the supplied test data
1133
while (testSource.hasMoreElements(test)) {
1134
testInst = testSource.nextElement(test);
1135
testingEvaluation.evaluateModelOnceAndRecordPrediction(
1136
(Classifier)classifier, testInst);
1139
if (splitPercentage > 0)
1140
text.append("\n\n" + testingEvaluation.
1141
toSummaryString("=== Error on test split ===\n",
1142
printComplexityStatistics));
1144
text.append("\n\n" + testingEvaluation.
1145
toSummaryString("=== Error on test data ===\n",
1146
printComplexityStatistics));
1148
} else if (trainSource != null) {
1149
if (!noCrossValidation) {
1150
// Testing is via cross-validation on training data
1151
Random random = new Random(seed);
1152
// use untrained (!) classifier for cross-validation
1153
classifier = Classifier.makeCopy(classifierBackup);
1154
testingEvaluation.crossValidateModel(
1155
classifier, trainSource.getDataSet(actualClassIndex), folds, random);
1156
if (template.classAttribute().isNumeric()) {
1157
text.append("\n\n\n" + testingEvaluation.
1158
toSummaryString("=== Cross-validation ===\n",
1159
printComplexityStatistics));
1161
text.append("\n\n\n" + testingEvaluation.
1162
toSummaryString("=== Stratified " +
1163
"cross-validation ===\n",
1164
printComplexityStatistics));
1168
if (template.classAttribute().isNominal()) {
1169
if (classStatistics) {
1170
text.append("\n\n" + testingEvaluation.toClassDetailsString());
1172
if (!noCrossValidation)
1173
text.append("\n\n" + testingEvaluation.toMatrixString());
1176
if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
1178
if (thresholdLabel.length() != 0)
1179
labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
1180
if (labelIndex == -1)
1181
throw new IllegalArgumentException(
1182
"Class label '" + thresholdLabel + "' is unknown!");
1183
ThresholdCurve tc = new ThresholdCurve();
1184
Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
1185
DataSink.write(thresholdFile, result);
1188
return text.toString();
1192
* Attempts to load a cost matrix.
1194
* @param costFileName the filename of the cost matrix
1195
* @param numClasses the number of classes that should be in the cost matrix
1196
* (only used if the cost file is in old format).
1197
* @return a <code>CostMatrix</code> value, or null if costFileName is empty
1198
* @throws Exception if an error occurs.
1200
protected static CostMatrix handleCostOption(String costFileName,
1204
if ((costFileName != null) && (costFileName.length() != 0)) {
1206
"NOTE: The behaviour of the -m option has changed between WEKA 3.0"
1207
+" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
1208
+" only. For cost-sensitive *prediction*, use one of the"
1209
+" cost-sensitive metaschemes such as"
1210
+" weka.classifiers.meta.CostSensitiveClassifier or"
1211
+" weka.classifiers.meta.MetaCost");
1213
Reader costReader = null;
1215
costReader = new BufferedReader(new FileReader(costFileName));
1216
} catch (Exception e) {
1217
throw new Exception("Can't open file " + e.getMessage() + '.');
1220
// First try as a proper cost matrix format
1221
return new CostMatrix(costReader);
1222
} catch (Exception ex) {
1224
// Now try as the poxy old format :-)
1225
//System.err.println("Attempting to read old format cost file");
1227
costReader.close(); // Close the old one
1228
costReader = new BufferedReader(new FileReader(costFileName));
1229
} catch (Exception e) {
1230
throw new Exception("Can't open file " + e.getMessage() + '.');
1232
CostMatrix costMatrix = new CostMatrix(numClasses);
1233
//System.err.println("Created default cost matrix");
1234
costMatrix.readOldFormat(costReader);
1236
//System.err.println("Read old format");
1237
} catch (Exception e2) {
1238
// re-throw the original exception
1239
//System.err.println("Re-throwing original exception");
1249
* Evaluates the classifier on a given set of instances. Note that
1250
* the data must have exactly the same format (e.g. order of
1251
* attributes) as the data used to train the classifier! Otherwise
1252
* the results will generally be meaningless.
1254
* @param classifier machine learning classifier
1255
* @param data set of test instances for evaluation
1256
* @return the predictions
1257
* @throws Exception if model could not be evaluated
1260
public double[] evaluateModel(Classifier classifier,
1261
Instances data) throws Exception {
1263
double predictions[] = new double[data.numInstances()];
1265
// Need to be able to collect predictions if appropriate (for AUC)
1267
for (int i = 0; i < data.numInstances(); i++) {
1268
predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier,
1276
* Evaluates the classifier on a single instance and records the
1277
* prediction (if the class is nominal).
1279
* @param classifier machine learning classifier
1280
* @param instance the test instance to be classified
1281
* @return the prediction made by the clasifier
1282
* @throws Exception if model could not be evaluated
1283
* successfully or the data contains string attributes
1285
public double evaluateModelOnceAndRecordPrediction(Classifier classifier,
1286
Instance instance) throws Exception {
1288
Instance classMissing = (Instance)instance.copy();
1290
classMissing.setDataset(instance.dataset());
1291
classMissing.setClassMissing();
1292
if (m_ClassIsNominal) {
1293
if (m_Predictions == null) {
1294
m_Predictions = new FastVector();
1296
double [] dist = classifier.distributionForInstance(classMissing);
1297
pred = Utils.maxIndex(dist);
1298
if (dist[(int)pred] <= 0) {
1299
pred = Instance.missingValue();
1301
updateStatsForClassifier(dist, instance);
1302
m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
1303
instance.weight()));
1305
pred = classifier.classifyInstance(classMissing);
1306
updateStatsForPredictor(pred, instance);
1312
* Evaluates the classifier on a single instance.
1314
* @param classifier machine learning classifier
1315
* @param instance the test instance to be classified
1316
* @return the prediction made by the clasifier
1317
* @throws Exception if model could not be evaluated
1318
* successfully or the data contains string attributes
1320
public double evaluateModelOnce(Classifier classifier,
1321
Instance instance) throws Exception {
1323
Instance classMissing = (Instance)instance.copy();
1325
classMissing.setDataset(instance.dataset());
1326
classMissing.setClassMissing();
1327
if (m_ClassIsNominal) {
1328
double [] dist = classifier.distributionForInstance(classMissing);
1329
pred = Utils.maxIndex(dist);
1330
if (dist[(int)pred] <= 0) {
1331
pred = Instance.missingValue();
1333
updateStatsForClassifier(dist, instance);
1335
pred = classifier.classifyInstance(classMissing);
1336
updateStatsForPredictor(pred, instance);
1342
* Evaluates the supplied distribution on a single instance.
1344
* @param dist the supplied distribution
1345
* @param instance the test instance to be classified
1346
* @return the prediction
1347
* @throws Exception if model could not be evaluated
1350
public double evaluateModelOnce(double [] dist,
1351
Instance instance) throws Exception {
1353
if (m_ClassIsNominal) {
1354
pred = Utils.maxIndex(dist);
1355
if (dist[(int)pred] <= 0) {
1356
pred = Instance.missingValue();
1358
updateStatsForClassifier(dist, instance);
1361
updateStatsForPredictor(pred, instance);
1367
* Evaluates the supplied distribution on a single instance.
1369
* @param dist the supplied distribution
1370
* @param instance the test instance to be classified
1371
* @return the prediction
1372
* @throws Exception if model could not be evaluated
1375
public double evaluateModelOnceAndRecordPrediction(double [] dist,
1376
Instance instance) throws Exception {
1378
if (m_ClassIsNominal) {
1379
if (m_Predictions == null) {
1380
m_Predictions = new FastVector();
1382
pred = Utils.maxIndex(dist);
1383
if (dist[(int)pred] <= 0) {
1384
pred = Instance.missingValue();
1386
updateStatsForClassifier(dist, instance);
1387
m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
1388
instance.weight()));
1391
updateStatsForPredictor(pred, instance);
1397
* Evaluates the supplied prediction on a single instance.
1399
* @param prediction the supplied prediction
1400
* @param instance the test instance to be classified
1401
* @throws Exception if model could not be evaluated
1404
public void evaluateModelOnce(double prediction,
1405
Instance instance) throws Exception {
1407
if (m_ClassIsNominal) {
1408
updateStatsForClassifier(makeDistribution(prediction),
1411
updateStatsForPredictor(prediction, instance);
1416
* Returns the predictions that have been collected.
1418
* @return a reference to the FastVector containing the predictions
1419
* that have been collected. This should be null if no predictions
1420
* have been collected (e.g. if the class is numeric).
1422
public FastVector predictions() {
1424
return m_Predictions;
1428
* Wraps a static classifier in enough source to test using the weka
1431
* @param classifier a Sourcable Classifier
1432
* @param className the name to give to the source code class
1433
* @return the source for a static classifier that can be tested with
1435
* @throws Exception if code-generation fails
1437
public static String wekaStaticWrapper(Sourcable classifier, String className)
1440
StringBuffer result = new StringBuffer();
1441
String staticClassifier = classifier.toSource(className);
1443
result.append("// Generated with Weka " + Version.VERSION + "\n");
1444
result.append("//\n");
1445
result.append("// This code is public domain and comes with no warranty.\n");
1446
result.append("//\n");
1447
result.append("// Timestamp: " + new Date() + "\n");
1448
result.append("\n");
1449
result.append("package weka.classifiers;\n");
1450
result.append("\n");
1451
result.append("import weka.core.Attribute;\n");
1452
result.append("import weka.core.Capabilities;\n");
1453
result.append("import weka.core.Capabilities.Capability;\n");
1454
result.append("import weka.core.Instance;\n");
1455
result.append("import weka.core.Instances;\n");
1456
result.append("import weka.classifiers.Classifier;\n");
1457
result.append("\n");
1458
result.append("public class WekaWrapper\n");
1459
result.append(" extends Classifier {\n");
1462
result.append("\n");
1463
result.append(" /**\n");
1464
result.append(" * Returns only the toString() method.\n");
1465
result.append(" *\n");
1466
result.append(" * @return a string describing the classifier\n");
1467
result.append(" */\n");
1468
result.append(" public String globalInfo() {\n");
1469
result.append(" return toString();\n");
1470
result.append(" }\n");
1473
result.append("\n");
1474
result.append(" /**\n");
1475
result.append(" * Returns the capabilities of this classifier.\n");
1476
result.append(" *\n");
1477
result.append(" * @return the capabilities\n");
1478
result.append(" */\n");
1479
result.append(" public Capabilities getCapabilities() {\n");
1480
result.append(((Classifier) classifier).getCapabilities().toSource("result", 4));
1481
result.append(" return result;\n");
1482
result.append(" }\n");
1485
result.append("\n");
1486
result.append(" /**\n");
1487
result.append(" * only checks the data against its capabilities.\n");
1488
result.append(" *\n");
1489
result.append(" * @param i the training data\n");
1490
result.append(" */\n");
1491
result.append(" public void buildClassifier(Instances i) throws Exception {\n");
1492
result.append(" // can classifier handle the data?\n");
1493
result.append(" getCapabilities().testWithFail(i);\n");
1494
result.append(" }\n");
1497
result.append("\n");
1498
result.append(" /**\n");
1499
result.append(" * Classifies the given instance.\n");
1500
result.append(" *\n");
1501
result.append(" * @param i the instance to classify\n");
1502
result.append(" * @return the classification result\n");
1503
result.append(" */\n");
1504
result.append(" public double classifyInstance(Instance i) throws Exception {\n");
1505
result.append(" Object[] s = new Object[i.numAttributes()];\n");
1506
result.append(" \n");
1507
result.append(" for (int j = 0; j < s.length; j++) {\n");
1508
result.append(" if (!i.isMissing(j)) {\n");
1509
result.append(" if (i.attribute(j).isNominal())\n");
1510
result.append(" s[j] = new String(i.stringValue(j));\n");
1511
result.append(" else if (i.attribute(j).isNumeric())\n");
1512
result.append(" s[j] = new Double(i.value(j));\n");
1513
result.append(" }\n");
1514
result.append(" }\n");
1515
result.append(" \n");
1516
result.append(" // set class value to missing\n");
1517
result.append(" s[i.classIndex()] = null;\n");
1518
result.append(" \n");
1519
result.append(" return " + className + ".classify(s);\n");
1520
result.append(" }\n");
1523
result.append("\n");
1524
result.append(" /**\n");
1525
result.append(" * Returns only the classnames and what classifier it is based on.\n");
1526
result.append(" *\n");
1527
result.append(" * @return a short description\n");
1528
result.append(" */\n");
1529
result.append(" public String toString() {\n");
1530
result.append(" return \"Auto-generated classifier wrapper, based on "
1531
+ classifier.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n"
1532
+ "\" + this.getClass().getName() + \"/" + className + "\";\n");
1533
result.append(" }\n");
1536
result.append("\n");
1537
result.append(" /**\n");
1538
result.append(" * Runs the classfier from commandline.\n");
1539
result.append(" *\n");
1540
result.append(" * @param args the commandline arguments\n");
1541
result.append(" */\n");
1542
result.append(" public static void main(String args[]) {\n");
1543
result.append(" runClassifier(new WekaWrapper(), args);\n");
1544
result.append(" }\n");
1545
result.append("}\n");
1547
// actual classifier code
1548
result.append("\n");
1549
result.append(staticClassifier);
1551
return result.toString();
1555
* Gets the number of test instances that had a known class value
1556
* (actually the sum of the weights of test instances with known
1559
* @return the number of test instances with known class
1561
public final double numInstances() {
1567
* Gets the number of instances incorrectly classified (that is, for
1568
* which an incorrect prediction was made). (Actually the sum of the weights
1569
* of these instances)
1571
* @return the number of incorrectly classified instances
1573
public final double incorrect() {
1579
* Gets the percentage of instances incorrectly classified (that is, for
1580
* which an incorrect prediction was made).
1582
* @return the percent of incorrectly classified instances
1583
* (between 0 and 100)
1585
public final double pctIncorrect() {
1587
return 100 * m_Incorrect / m_WithClass;
1591
* Gets the total cost, that is, the cost of each prediction times the
1592
* weight of the instance, summed over all instances.
1594
* @return the total cost
1596
public final double totalCost() {
1602
* Gets the average cost, that is, total cost of misclassifications
1603
* (incorrect plus unclassified) over the total number of instances.
1605
* @return the average cost.
1607
public final double avgCost() {
1609
return m_TotalCost / m_WithClass;
1613
* Gets the number of instances correctly classified (that is, for
1614
* which a correct prediction was made). (Actually the sum of the weights
1615
* of these instances)
1617
* @return the number of correctly classified instances
1619
public final double correct() {
1625
* Gets the percentage of instances correctly classified (that is, for
1626
* which a correct prediction was made).
1628
* @return the percent of correctly classified instances (between 0 and 100)
1630
public final double pctCorrect() {
1632
return 100 * m_Correct / m_WithClass;
1636
* Gets the number of instances not classified (that is, for
1637
* which no prediction was made by the classifier). (Actually the sum
1638
* of the weights of these instances)
1640
* @return the number of unclassified instances
1642
public final double unclassified() {
1644
return m_Unclassified;
1648
* Gets the percentage of instances not classified (that is, for
1649
* which no prediction was made by the classifier).
1651
* @return the percent of unclassified instances (between 0 and 100)
1653
public final double pctUnclassified() {
1655
return 100 * m_Unclassified / m_WithClass;
1659
* Returns the estimated error rate or the root mean squared error
1660
* (if the class is numeric). If a cost matrix was given this
1661
* error rate gives the average cost.
1663
* @return the estimated error rate (between 0 and 1, or between 0 and
1666
public final double errorRate() {
1668
if (!m_ClassIsNominal) {
1669
return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
1671
if (m_CostMatrix == null) {
1672
return m_Incorrect / m_WithClass;
1679
* Returns value of kappa statistic if class is nominal.
1681
* @return the value of the kappa statistic
1683
public final double kappa() {
1686
double[] sumRows = new double[m_ConfusionMatrix.length];
1687
double[] sumColumns = new double[m_ConfusionMatrix.length];
1688
double sumOfWeights = 0;
1689
for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1690
for (int j = 0; j < m_ConfusionMatrix.length; j++) {
1691
sumRows[i] += m_ConfusionMatrix[i][j];
1692
sumColumns[j] += m_ConfusionMatrix[i][j];
1693
sumOfWeights += m_ConfusionMatrix[i][j];
1696
double correct = 0, chanceAgreement = 0;
1697
for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1698
chanceAgreement += (sumRows[i] * sumColumns[i]);
1699
correct += m_ConfusionMatrix[i][i];
1701
chanceAgreement /= (sumOfWeights * sumOfWeights);
1702
correct /= sumOfWeights;
1704
if (chanceAgreement < 1) {
1705
return (correct - chanceAgreement) / (1 - chanceAgreement);
1712
* Returns the correlation coefficient if the class is numeric.
1714
* @return the correlation coefficient
1715
* @throws Exception if class is not numeric
1717
public final double correlationCoefficient() throws Exception {
1719
if (m_ClassIsNominal) {
1721
new Exception("Can't compute correlation coefficient: " +
1722
"class is nominal!");
1725
double correlation = 0;
1727
m_SumSqrClass - m_SumClass * m_SumClass /
1728
(m_WithClass - m_Unclassified);
1729
double varPredicted =
1730
m_SumSqrPredicted - m_SumPredicted * m_SumPredicted /
1731
(m_WithClass - m_Unclassified);
1733
m_SumClassPredicted - m_SumClass * m_SumPredicted /
1734
(m_WithClass - m_Unclassified);
1736
if (varActual * varPredicted <= 0) {
1739
correlation = varProd / Math.sqrt(varActual * varPredicted);
1746
* Returns the mean absolute error. Refers to the error of the
1747
* predicted values for numeric classes, and the error of the
1748
* predicted probability distribution for nominal classes.
1750
* @return the mean absolute error
1752
public final double meanAbsoluteError() {
1754
return m_SumAbsErr / (m_WithClass - m_Unclassified);
1758
* Returns the mean absolute error of the prior.
1760
* @return the mean absolute error
1762
public final double meanPriorAbsoluteError() {
1767
return m_SumPriorAbsErr / m_WithClass;
1771
* Returns the relative absolute error.
1773
* @return the relative absolute error
1774
* @throws Exception if it can't be computed
1776
public final double relativeAbsoluteError() throws Exception {
1781
return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
1785
* Returns the root mean squared error.
1787
* @return the root mean squared error
1789
public final double rootMeanSquaredError() {
1791
return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
1795
* Returns the root mean prior squared error.
1797
* @return the root mean prior squared error
1799
public final double rootMeanPriorSquaredError() {
1804
return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
1808
* Returns the root relative squared error if the class is numeric.
1810
* @return the root relative squared error
1812
public final double rootRelativeSquaredError() {
1817
return 100.0 * rootMeanSquaredError() /
1818
rootMeanPriorSquaredError();
1822
* Calculate the entropy of the prior distribution
1824
* @return the entropy of the prior distribution
1825
* @throws Exception if the class is not nominal
1827
public final double priorEntropy() throws Exception {
1829
if (!m_ClassIsNominal) {
1831
new Exception("Can't compute entropy of class prior: " +
1839
for(int i = 0; i < m_NumClasses; i++) {
1840
entropy -= m_ClassPriors[i] / m_ClassPriorsSum
1841
* Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
1847
* Return the total Kononenko & Bratko Information score in bits
1849
* @return the K&B information score
1850
* @throws Exception if the class is not nominal
1852
public final double KBInformation() throws Exception {
1854
if (!m_ClassIsNominal) {
1856
new Exception("Can't compute K&B Info score: " +
1867
* Return the Kononenko & Bratko Information score in bits per
1870
* @return the K&B information score
1871
* @throws Exception if the class is not nominal
1873
public final double KBMeanInformation() throws Exception {
1875
if (!m_ClassIsNominal) {
1877
new Exception("Can't compute K&B Info score: "
1878
+ "class numeric!");
1884
return m_SumKBInfo / (m_WithClass - m_Unclassified);
1888
* Return the Kononenko & Bratko Relative Information score
1890
* @return the K&B relative information score
1891
* @throws Exception if the class is not nominal
1893
public final double KBRelativeInformation() throws Exception {
1895
if (!m_ClassIsNominal) {
1897
new Exception("Can't compute K&B Info score: " +
1904
return 100.0 * KBInformation() / priorEntropy();
1908
* Returns the total entropy for the null model
1910
* @return the total null model entropy
1912
public final double SFPriorEntropy() {
1917
return m_SumPriorEntropy;
1921
* Returns the entropy per instance for the null model
1923
* @return the null model entropy per instance
1925
public final double SFMeanPriorEntropy() {
1930
return m_SumPriorEntropy / m_WithClass;
1934
* Returns the total entropy for the scheme
1936
* @return the total scheme entropy
1938
public final double SFSchemeEntropy() {
1943
return m_SumSchemeEntropy;
1947
* Returns the entropy per instance for the scheme
1949
* @return the scheme entropy per instance
1951
public final double SFMeanSchemeEntropy() {
1956
return m_SumSchemeEntropy / (m_WithClass - m_Unclassified);
1960
* Returns the total SF, which is the null model entropy minus
1961
* the scheme entropy.
1963
* @return the total SF
1965
public final double SFEntropyGain() {
1970
return m_SumPriorEntropy - m_SumSchemeEntropy;
1974
* Returns the SF per instance, which is the null model entropy
1975
* minus the scheme entropy, per instance.
1977
* @return the SF per instance
1979
public final double SFMeanEntropyGain() {
1984
return (m_SumPriorEntropy - m_SumSchemeEntropy) /
1985
(m_WithClass - m_Unclassified);
1989
* Output the cumulative margin distribution as a string suitable
1990
* for input for gnuplot or similar package.
1992
* @return the cumulative margin distribution
1993
* @throws Exception if the class attribute is nominal
1995
public String toCumulativeMarginDistributionString() throws Exception {
1997
if (!m_ClassIsNominal) {
1998
throw new Exception("Class must be nominal for margin distributions");
2001
double cumulativeCount = 0;
2003
for(int i = 0; i <= k_MarginResolution; i++) {
2004
if (m_MarginCounts[i] != 0) {
2005
cumulativeCount += m_MarginCounts[i];
2006
margin = (double)i * 2.0 / k_MarginResolution - 1.0;
2007
result = result + Utils.doubleToString(margin, 7, 3) + ' '
2008
+ Utils.doubleToString(cumulativeCount * 100
2009
/ m_WithClass, 7, 3) + '\n';
2010
} else if (i == 0) {
2011
result = Utils.doubleToString(-1.0, 7, 3) + ' '
2012
+ Utils.doubleToString(0, 7, 3) + '\n';
2020
* Calls toSummaryString() with no title and no complexity stats
2022
* @return a summary description of the classifier evaluation
2024
public String toSummaryString() {
2026
return toSummaryString("", false);
2030
* Calls toSummaryString() with a default title.
2032
* @param printComplexityStatistics if true, complexity statistics are
2034
* @return the summary string
2036
public String toSummaryString(boolean printComplexityStatistics) {
2038
return toSummaryString("=== Summary ===\n", printComplexityStatistics);
2042
* Outputs the performance statistics in summary form. Lists
2043
* number (and percentage) of instances classified correctly,
2044
* incorrectly and unclassified. Outputs the total number of
2045
* instances classified, and the number of instances (if any)
2046
* that had no class value provided.
2048
* @param title the title for the statistics
2049
* @param printComplexityStatistics if true, complexity statistics are
2051
* @return the summary as a String
2053
public String toSummaryString(String title,
2054
boolean printComplexityStatistics) {
2056
StringBuffer text = new StringBuffer();
2058
if (printComplexityStatistics && m_NoPriors) {
2059
printComplexityStatistics = false;
2060
System.err.println("Priors disabled, cannot print complexity statistics!");
2063
text.append(title + "\n");
2065
if (m_WithClass > 0) {
2066
if (m_ClassIsNominal) {
2068
text.append("Correctly Classified Instances ");
2069
text.append(Utils.doubleToString(correct(), 12, 4) + " " +
2070
Utils.doubleToString(pctCorrect(),
2072
text.append("Incorrectly Classified Instances ");
2073
text.append(Utils.doubleToString(incorrect(), 12, 4) + " " +
2074
Utils.doubleToString(pctIncorrect(),
2076
text.append("Kappa statistic ");
2077
text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");
2079
if (m_CostMatrix != null) {
2080
text.append("Total Cost ");
2081
text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
2082
text.append("Average Cost ");
2083
text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
2085
if (printComplexityStatistics) {
2086
text.append("K&B Relative Info Score ");
2087
text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4)
2089
text.append("K&B Information Score ");
2090
text.append(Utils.doubleToString(KBInformation(), 12, 4)
2092
text.append(Utils.doubleToString(KBMeanInformation(), 12, 4)
2093
+ " bits/instance\n");
2096
text.append("Correlation coefficient ");
2097
text.append(Utils.doubleToString(correlationCoefficient(), 12 , 4) +
2100
if (printComplexityStatistics) {
2101
text.append("Class complexity | order 0 ");
2102
text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4)
2104
text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4)
2105
+ " bits/instance\n");
2106
text.append("Class complexity | scheme ");
2107
text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4)
2109
text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4)
2110
+ " bits/instance\n");
2111
text.append("Complexity improvement (Sf) ");
2112
text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
2113
text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4)
2114
+ " bits/instance\n");
2117
text.append("Mean absolute error ");
2118
text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4)
2120
text.append("Root mean squared error ");
2122
doubleToString(rootMeanSquaredError(), 12, 4)
2125
text.append("Relative absolute error ");
2126
text.append(Utils.doubleToString(relativeAbsoluteError(),
2128
text.append("Root relative squared error ");
2129
text.append(Utils.doubleToString(rootRelativeSquaredError(),
2133
if (Utils.gr(unclassified(), 0)) {
2134
text.append("UnClassified Instances ");
2135
text.append(Utils.doubleToString(unclassified(), 12,4) + " " +
2136
Utils.doubleToString(pctUnclassified(),
2139
text.append("Total Number of Instances ");
2140
text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
2141
if (m_MissingClass > 0) {
2142
text.append("Ignored Class Unknown Instances ");
2143
text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
2145
} catch (Exception ex) {
2146
// Should never occur since the class is known to be nominal
2148
System.err.println("Arggh - Must be a bug in Evaluation class");
2151
return text.toString();
2155
* Calls toMatrixString() with a default title.
2157
* @return the confusion matrix as a string
2158
* @throws Exception if the class is numeric
2160
public String toMatrixString() throws Exception {
2162
return toMatrixString("=== Confusion Matrix ===\n");
2166
* Outputs the performance statistics as a classification confusion
2167
* matrix. For each class value, shows the distribution of
2168
* predicted class values.
2170
* @param title the title for the confusion matrix
2171
* @return the confusion matrix as a String
2172
* @throws Exception if the class is numeric
2174
public String toMatrixString(String title) throws Exception {
2176
StringBuffer text = new StringBuffer();
2177
char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
2178
'k','l','m','n','o','p','q','r','s','t',
2179
'u','v','w','x','y','z'};
2181
boolean fractional = false;
2183
if (!m_ClassIsNominal) {
2184
throw new Exception("Evaluation: No confusion matrix possible!");
2187
// Find the maximum value in the matrix
2188
// and check for fractional display requirement
2190
for(int i = 0; i < m_NumClasses; i++) {
2191
for(int j = 0; j < m_NumClasses; j++) {
2192
double current = m_ConfusionMatrix[i][j];
2196
if (current > maxval) {
2199
double fract = current - Math.rint(current);
2201
&& ((Math.log(fract) / Math.log(10)) >= -2)) {
2207
IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)
2208
+ (fractional ? 3 : 0)),
2209
(int)(Math.log(m_NumClasses) /
2210
Math.log(IDChars.length)));
2211
text.append(title).append("\n");
2212
for(int i = 0; i < m_NumClasses; i++) {
2214
text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
2217
text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
2220
text.append(" <-- classified as\n");
2221
for(int i = 0; i< m_NumClasses; i++) {
2222
for(int j = 0; j < m_NumClasses; j++) {
2223
text.append(" ").append(
2224
Utils.doubleToString(m_ConfusionMatrix[i][j],
2226
(fractional ? 2 : 0)));
2228
text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
2229
.append(" = ").append(m_ClassNames[i]).append("\n");
2231
return text.toString();
2235
* Generates a breakdown of the accuracy for each class (with default title),
2236
* incorporating various information-retrieval statistics, such as
2237
* true/false positive rate, precision/recall/F-Measure. Should be
2238
* useful for ROC curves, recall/precision curves.
2240
* @return the statistics presented as a string
2241
* @throws Exception if class is not nominal
2243
public String toClassDetailsString() throws Exception {
2245
return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
2249
* Generates a breakdown of the accuracy for each class,
2250
* incorporating various information-retrieval statistics, such as
2251
* true/false positive rate, precision/recall/F-Measure. Should be
2252
* useful for ROC curves, recall/precision curves.
2254
* @param title the title to prepend the stats string with
2255
* @return the statistics presented as a string
2256
* @throws Exception if class is not nominal
2258
public String toClassDetailsString(String title) throws Exception {
2260
if (!m_ClassIsNominal) {
2261
throw new Exception("Evaluation: No confusion matrix possible!");
2263
StringBuffer text = new StringBuffer(title
2264
+ "\nTP Rate FP Rate"
2265
+ " Precision Recall"
2266
+ " F-Measure ROC Area Class\n");
2267
for(int i = 0; i < m_NumClasses; i++) {
2268
text.append(Utils.doubleToString(truePositiveRate(i), 7, 3))
2270
text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
2272
text.append(Utils.doubleToString(precision(i), 7, 3))
2274
text.append(Utils.doubleToString(recall(i), 7, 3))
2276
text.append(Utils.doubleToString(fMeasure(i), 7, 3))
2278
double rocVal = areaUnderROC(i);
2279
if (Instance.isMissingValue(rocVal)) {
2283
text.append(Utils.doubleToString(rocVal, 7, 3))
2286
text.append(m_ClassNames[i]).append('\n');
2288
return text.toString();
2292
* Calculate the number of true positives with respect to a particular class.
2293
* This is defined as<p/>
2295
* correctly classified positives
2298
* @param classIndex the index of the class to consider as "positive"
2299
* @return the true positive rate
2301
public double numTruePositives(int classIndex) {
2304
for (int j = 0; j < m_NumClasses; j++) {
2305
if (j == classIndex) {
2306
correct += m_ConfusionMatrix[classIndex][j];
2313
* Calculate the true positive rate with respect to a particular class.
2314
* This is defined as<p/>
2316
* correctly classified positives
2317
* ------------------------------
2321
* @param classIndex the index of the class to consider as "positive"
2322
* @return the true positive rate
2324
public double truePositiveRate(int classIndex) {
2326
double correct = 0, total = 0;
2327
for (int j = 0; j < m_NumClasses; j++) {
2328
if (j == classIndex) {
2329
correct += m_ConfusionMatrix[classIndex][j];
2331
total += m_ConfusionMatrix[classIndex][j];
2336
return correct / total;
2340
* Calculate the number of true negatives with respect to a particular class.
2341
* This is defined as<p/>
2343
* correctly classified negatives
2346
* @param classIndex the index of the class to consider as "positive"
2347
* @return the true positive rate
2349
public double numTrueNegatives(int classIndex) {
2352
for (int i = 0; i < m_NumClasses; i++) {
2353
if (i != classIndex) {
2354
for (int j = 0; j < m_NumClasses; j++) {
2355
if (j != classIndex) {
2356
correct += m_ConfusionMatrix[i][j];
2365
* Calculate the true negative rate with respect to a particular class.
2366
* This is defined as<p/>
2368
* correctly classified negatives
2369
* ------------------------------
2373
* @param classIndex the index of the class to consider as "positive"
2374
* @return the true positive rate
2376
public double trueNegativeRate(int classIndex) {
2378
double correct = 0, total = 0;
2379
for (int i = 0; i < m_NumClasses; i++) {
2380
if (i != classIndex) {
2381
for (int j = 0; j < m_NumClasses; j++) {
2382
if (j != classIndex) {
2383
correct += m_ConfusionMatrix[i][j];
2385
total += m_ConfusionMatrix[i][j];
2392
return correct / total;
2396
* Calculate number of false positives with respect to a particular class.
2397
* This is defined as<p/>
2399
* incorrectly classified negatives
2402
* @param classIndex the index of the class to consider as "positive"
2403
* @return the false positive rate
2405
public double numFalsePositives(int classIndex) {
2407
double incorrect = 0;
2408
for (int i = 0; i < m_NumClasses; i++) {
2409
if (i != classIndex) {
2410
for (int j = 0; j < m_NumClasses; j++) {
2411
if (j == classIndex) {
2412
incorrect += m_ConfusionMatrix[i][j];
2421
* Calculate the false positive rate with respect to a particular class.
2422
* This is defined as<p/>
2424
* incorrectly classified negatives
2425
* --------------------------------
2429
* @param classIndex the index of the class to consider as "positive"
2430
* @return the false positive rate
2432
public double falsePositiveRate(int classIndex) {
2434
double incorrect = 0, total = 0;
2435
for (int i = 0; i < m_NumClasses; i++) {
2436
if (i != classIndex) {
2437
for (int j = 0; j < m_NumClasses; j++) {
2438
if (j == classIndex) {
2439
incorrect += m_ConfusionMatrix[i][j];
2441
total += m_ConfusionMatrix[i][j];
2448
return incorrect / total;
2452
* Calculate number of false negatives with respect to a particular class.
2453
* This is defined as<p/>
2455
* incorrectly classified positives
2458
* @param classIndex the index of the class to consider as "positive"
2459
* @return the false positive rate
2461
public double numFalseNegatives(int classIndex) {
2463
double incorrect = 0;
2464
for (int i = 0; i < m_NumClasses; i++) {
2465
if (i == classIndex) {
2466
for (int j = 0; j < m_NumClasses; j++) {
2467
if (j != classIndex) {
2468
incorrect += m_ConfusionMatrix[i][j];
2477
* Calculate the false negative rate with respect to a particular class.
2478
* This is defined as<p/>
2480
* incorrectly classified positives
2481
* --------------------------------
2485
* @param classIndex the index of the class to consider as "positive"
2486
* @return the false positive rate
2488
public double falseNegativeRate(int classIndex) {
2490
double incorrect = 0, total = 0;
2491
for (int i = 0; i < m_NumClasses; i++) {
2492
if (i == classIndex) {
2493
for (int j = 0; j < m_NumClasses; j++) {
2494
if (j != classIndex) {
2495
incorrect += m_ConfusionMatrix[i][j];
2497
total += m_ConfusionMatrix[i][j];
2504
return incorrect / total;
2508
* Calculate the recall with respect to a particular class.
2509
* This is defined as<p/>
2511
* correctly classified positives
2512
* ------------------------------
2515
* (Which is also the same as the truePositiveRate.)
2517
* @param classIndex the index of the class to consider as "positive"
2518
* @return the recall
2520
public double recall(int classIndex) {
2522
return truePositiveRate(classIndex);
2526
* Calculate the precision with respect to a particular class.
2527
* This is defined as<p/>
2529
* correctly classified positives
2530
* ------------------------------
2531
* total predicted as positive
2534
* @param classIndex the index of the class to consider as "positive"
2535
* @return the precision
2537
public double precision(int classIndex) {
2539
double correct = 0, total = 0;
2540
for (int i = 0; i < m_NumClasses; i++) {
2541
if (i == classIndex) {
2542
correct += m_ConfusionMatrix[i][classIndex];
2544
total += m_ConfusionMatrix[i][classIndex];
2549
return correct / total;
2553
* Calculate the F-Measure with respect to a particular class.
2554
* This is defined as<p/>
2556
* 2 * recall * precision
2557
* ----------------------
2558
* recall + precision
2561
* @param classIndex the index of the class to consider as "positive"
2562
* @return the F-Measure
2564
public double fMeasure(int classIndex) {
2566
double precision = precision(classIndex);
2567
double recall = recall(classIndex);
2568
if ((precision + recall) == 0) {
2571
return 2 * precision * recall / (precision + recall);
2575
* Sets the class prior probabilities
2577
* @param train the training instances used to determine
2578
* the prior probabilities
2579
* @throws Exception if the class attribute of the instances is not
2582
public void setPriors(Instances train) throws Exception {
2585
if (!m_ClassIsNominal) {
2587
m_NumTrainClassVals = 0;
2588
m_TrainClassVals = null;
2589
m_TrainClassWeights = null;
2590
m_PriorErrorEstimator = null;
2591
m_ErrorEstimator = null;
2593
for (int i = 0; i < train.numInstances(); i++) {
2594
Instance currentInst = train.instance(i);
2595
if (!currentInst.classIsMissing()) {
2596
addNumericTrainClass(currentInst.classValue(),
2597
currentInst.weight());
2602
for (int i = 0; i < m_NumClasses; i++) {
2603
m_ClassPriors[i] = 1;
2605
m_ClassPriorsSum = m_NumClasses;
2606
for (int i = 0; i < train.numInstances(); i++) {
2607
if (!train.instance(i).classIsMissing()) {
2608
m_ClassPriors[(int)train.instance(i).classValue()] +=
2609
train.instance(i).weight();
2610
m_ClassPriorsSum += train.instance(i).weight();
2617
* Get the current weighted class counts
2619
* @return the weighted class counts
2621
public double [] getClassPriors() {
2622
return m_ClassPriors;
2626
* Updates the class prior probabilities (when incrementally
2629
* @param instance the new training instance seen
2630
* @throws Exception if the class of the instance is not
2633
public void updatePriors(Instance instance) throws Exception {
2634
if (!instance.classIsMissing()) {
2635
if (!m_ClassIsNominal) {
2636
if (!instance.classIsMissing()) {
2637
addNumericTrainClass(instance.classValue(),
2641
m_ClassPriors[(int)instance.classValue()] +=
2643
m_ClassPriorsSum += instance.weight();
2649
* disables the use of priors, e.g., in case of de-serialized schemes
2650
* that have no access to the original training set, but are evaluated
2653
public void useNoPriors() {
2658
* Tests whether the current evaluation object is equal to another
2661
* @param obj the object to compare against
2662
* @return true if the two objects are equal
2664
public boolean equals(Object obj) {
2666
if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
2669
Evaluation cmp = (Evaluation) obj;
2670
if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
2671
if (m_NumClasses != cmp.m_NumClasses) return false;
2673
if (m_Incorrect != cmp.m_Incorrect) return false;
2674
if (m_Correct != cmp.m_Correct) return false;
2675
if (m_Unclassified != cmp.m_Unclassified) return false;
2676
if (m_MissingClass != cmp.m_MissingClass) return false;
2677
if (m_WithClass != cmp.m_WithClass) return false;
2679
if (m_SumErr != cmp.m_SumErr) return false;
2680
if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
2681
if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
2682
if (m_SumClass != cmp.m_SumClass) return false;
2683
if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
2684
if (m_SumPredicted != cmp.m_SumPredicted) return false;
2685
if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
2686
if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;
2688
if (m_ClassIsNominal) {
2689
for (int i = 0; i < m_NumClasses; i++) {
2690
for (int j = 0; j < m_NumClasses; j++) {
2691
if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
2702
* Prints the predictions for the given dataset into a String variable.
2704
* @param classifier the classifier to use
2705
* @param train the training data
2706
* @param testSource the test set
2707
* @param classIndex the class index (1-based), if -1 ot does not
2708
* override the class index is stored in the data
2709
* file (by using the last attribute)
2710
* @param attributesToOutput the indices of the attributes to output
2711
* @return the generated predictions for the attribute range
2712
* @throws Exception if test file cannot be opened
2714
protected static String printClassifications(Classifier classifier,
2716
DataSource testSource,
2718
Range attributesToOutput) throws Exception {
2720
return printClassifications(
2721
classifier, train, testSource, classIndex, attributesToOutput, false);
2725
* Prints the predictions for the given dataset into a String variable.
2727
* @param classifier the classifier to use
2728
* @param train the training data
2729
* @param testSource the test set
2730
* @param classIndex the class index (1-based), if -1 ot does not
2731
* override the class index is stored in the data
2732
* file (by using the last attribute)
2733
* @param attributesToOutput the indices of the attributes to output
2734
* @param printDistribution prints the complete distribution for nominal
2735
* classes, not just the predicted value
2736
* @return the generated predictions for the attribute range
2737
* @throws Exception if test file cannot be opened
2739
protected static String printClassifications(Classifier classifier,
2741
DataSource testSource,
2743
Range attributesToOutput,
2744
boolean printDistribution) throws Exception {
2746
StringBuffer text = new StringBuffer();
2747
if (testSource != null) {
2748
Instances test = testSource.getStructure();
2749
if (classIndex != -1) {
2750
test.setClassIndex(classIndex - 1);
2752
if (test.classIndex() == -1)
2753
test.setClassIndex(test.numAttributes() - 1);
2757
if (test.classAttribute().isNominal())
2758
if (printDistribution)
2759
text.append(" inst# actual predicted error distribution");
2761
text.append(" inst# actual predicted error prediction");
2763
text.append(" inst# actual predicted error");
2764
if (attributesToOutput != null) {
2765
attributesToOutput.setUpper(test.numAttributes() - 1);
2767
boolean first = true;
2768
for (int i = 0; i < test.numAttributes(); i++) {
2769
if (i == test.classIndex())
2772
if (attributesToOutput.isInRange(i)) {
2775
text.append(test.attribute(i).name());
2783
// print predictions
2786
test = testSource.getStructure(test.classIndex());
2787
while (testSource.hasMoreElements(test)) {
2788
Instance inst = testSource.nextElement(test);
2791
classifier, inst, i, attributesToOutput, printDistribution));
2795
return text.toString();
2799
* returns the prediction made by the classifier as a string
2801
* @param classifier the classifier to use
2802
* @param inst the instance to generate text from
2803
* @param instNum the index in the dataset
2804
* @param attributesToOutput the indices of the attributes to output
2805
* @param printDistribution prints the complete distribution for nominal
2806
* classes, not just the predicted value
2807
* @return the generated text
2808
* @throws Exception if something goes wrong
2809
* @see #printClassifications(Classifier, Instances, String, int, Range, boolean)
2811
protected static String predictionText(Classifier classifier,
2814
Range attributesToOutput,
2815
boolean printDistribution)
2818
StringBuffer result = new StringBuffer();
2822
Instance withMissing = (Instance)inst.copy();
2823
withMissing.setDataset(inst.dataset());
2824
double predValue = ((Classifier)classifier).classifyInstance(withMissing);
2827
result.append(Utils.padLeft("" + (instNum+1), 6));
2829
if (inst.dataset().classAttribute().isNumeric()) {
2831
if (inst.classIsMissing())
2832
result.append(" " + Utils.padLeft("?", width));
2834
result.append(" " + Utils.doubleToString(inst.classValue(), width, prec));
2836
if (Instance.isMissingValue(predValue))
2837
result.append(" " + Utils.padLeft("?", width));
2839
result.append(" " + Utils.doubleToString(predValue, width, prec));
2841
if (Instance.isMissingValue(predValue) || inst.classIsMissing())
2842
result.append(" " + Utils.padLeft("?", width));
2844
result.append(" " + Utils.doubleToString(predValue - inst.classValue(), width, prec));
2847
result.append(" " + Utils.padLeft(((int) inst.classValue()+1) + ":" + inst.toString(inst.classIndex()), width));
2849
if (Instance.isMissingValue(predValue))
2850
result.append(" " + Utils.padLeft("?", width));
2852
result.append(" " + Utils.padLeft(((int) predValue+1) + ":" + inst.dataset().classAttribute().value((int)predValue), width));
2854
if ((int) predValue+1 != (int) inst.classValue()+1)
2855
result.append(" " + " + ");
2857
result.append(" " + " ");
2858
// prediction/distribution
2859
if (printDistribution) {
2860
if (Instance.isMissingValue(predValue)) {
2861
result.append(" " + "?");
2865
double[] dist = classifier.distributionForInstance(withMissing);
2866
for (int n = 0; n < dist.length; n++) {
2869
if (n == (int) predValue)
2871
result.append(Utils.doubleToString(dist[n], prec));
2876
if (Instance.isMissingValue(predValue))
2877
result.append(" " + "?");
2879
result.append(" " + Utils.doubleToString(classifier.distributionForInstance(withMissing) [(int)predValue], prec));
2884
result.append(" " + attributeValuesString(withMissing, attributesToOutput) + "\n");
2886
return result.toString();
2890
* Builds a string listing the attribute values in a specified range of indices,
2891
* separated by commas and enclosed in brackets.
2893
* @param instance the instance to print the values from
2894
* @param attRange the range of the attributes to list
2895
* @return a string listing values of the attributes in the range
2897
protected static String attributeValuesString(Instance instance, Range attRange) {
2898
StringBuffer text = new StringBuffer();
2899
if (attRange != null) {
2900
boolean firstOutput = true;
2901
attRange.setUpper(instance.numAttributes() - 1);
2902
for (int i=0; i<instance.numAttributes(); i++)
2903
if (attRange.isInRange(i) && i != instance.classIndex()) {
2904
if (firstOutput) text.append("(");
2905
else text.append(",");
2906
text.append(instance.toString(i));
2907
firstOutput = false;
2909
if (!firstOutput) text.append(")");
2911
return text.toString();
2915
* Make up the help string giving all the command line options
2917
* @param classifier the classifier to include options for
2918
* @return a string detailing the valid command line options
2920
protected static String makeOptionString(Classifier classifier) {
2922
StringBuffer optionsText = new StringBuffer("");
2925
optionsText.append("\n\nGeneral options:\n\n");
2926
optionsText.append("-t <name of training file>\n");
2927
optionsText.append("\tSets training file.\n");
2928
optionsText.append("-T <name of test file>\n");
2929
optionsText.append("\tSets test file. If missing, a cross-validation will be performed\n");
2930
optionsText.append("\ton the training data.\n");
2931
optionsText.append("-c <class index>\n");
2932
optionsText.append("\tSets index of class attribute (default: last).\n");
2933
optionsText.append("-x <number of folds>\n");
2934
optionsText.append("\tSets number of folds for cross-validation (default: 10).\n");
2935
optionsText.append("-no-cv\n");
2936
optionsText.append("\tDo not perform any cross validation.\n");
2937
optionsText.append("-split-percentage <percentage>\n");
2938
optionsText.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
2939
optionsText.append("-preserve-order\n");
2940
optionsText.append("\tPreserves the order in the percentage split.\n");
2941
optionsText.append("-s <random number seed>\n");
2942
optionsText.append("\tSets random number seed for cross-validation or percentage split\n");
2943
optionsText.append("\t(default: 1).\n");
2944
optionsText.append("-m <name of file with cost matrix>\n");
2945
optionsText.append("\tSets file with cost matrix.\n");
2946
optionsText.append("-l <name of input file>\n");
2947
optionsText.append("\tSets model input file. In case the filename ends with '.xml',\n");
2948
optionsText.append("\tthe options are loaded from the XML file.\n");
2949
optionsText.append("-d <name of output file>\n");
2950
optionsText.append("\tSets model output file. In case the filename ends with '.xml',\n");
2951
optionsText.append("\tonly the options are saved to the XML file, not the model.\n");
2952
optionsText.append("-v\n");
2953
optionsText.append("\tOutputs no statistics for training data.\n");
2954
optionsText.append("-o\n");
2955
optionsText.append("\tOutputs statistics only, not the classifier.\n");
2956
optionsText.append("-i\n");
2957
optionsText.append("\tOutputs detailed information-retrieval");
2958
optionsText.append(" statistics for each class.\n");
2959
optionsText.append("-k\n");
2960
optionsText.append("\tOutputs information-theoretic statistics.\n");
2961
optionsText.append("-p <attribute range>\n");
2962
optionsText.append("\tOnly outputs predictions for test instances (or the train\n"
2963
+ "\tinstances if no test instances provided), along with attributes\n"
2964
+ "\t(0 for none).\n");
2965
optionsText.append("-distribution\n");
2966
optionsText.append("\tOutputs the distribution instead of only the prediction\n");
2967
optionsText.append("\tin conjunction with the '-p' option (only nominal classes).\n");
2968
optionsText.append("-r\n");
2969
optionsText.append("\tOnly outputs cumulative margin distribution.\n");
2970
if (classifier instanceof Sourcable) {
2971
optionsText.append("-z <class name>\n");
2972
optionsText.append("\tOnly outputs the source representation"
2973
+ " of the classifier,\n\tgiving it the supplied"
2976
if (classifier instanceof Drawable) {
2977
optionsText.append("-g\n");
2978
optionsText.append("\tOnly outputs the graph representation"
2979
+ " of the classifier.\n");
2981
optionsText.append("-xml filename | xml-string\n");
2982
optionsText.append("\tRetrieves the options from the XML-data instead of the "
2983
+ "command line.\n");
2984
optionsText.append("-threshold-file <file>\n");
2985
optionsText.append("\tThe file to save the threshold data to.\n"
2986
+ "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
2987
+ "\tformat or '.csv' for CSV.\n");
2988
optionsText.append("-threshold-label <label>\n");
2989
optionsText.append("\tThe class label to determine the threshold data for\n"
2990
+ "\t(default is the first label)\n");
2992
// Get scheme-specific options
2993
if (classifier instanceof OptionHandler) {
2994
optionsText.append("\nOptions specific to "
2995
+ classifier.getClass().getName()
2997
Enumeration enu = ((OptionHandler)classifier).listOptions();
2998
while (enu.hasMoreElements()) {
2999
Option option = (Option) enu.nextElement();
3000
optionsText.append(option.synopsis() + '\n');
3001
optionsText.append(option.description() + "\n");
3004
return optionsText.toString();
3008
* Method for generating indices for the confusion matrix.
3010
* @param num integer to format
3011
* @param IDChars the characters to use
3012
* @param IDWidth the width of the entry
3013
* @return the formatted integer as a string
3015
protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
3017
char ID [] = new char [IDWidth];
3020
for(i = IDWidth - 1; i >=0; i--) {
3021
ID[i] = IDChars[num % IDChars.length];
3022
num = num / IDChars.length - 1;
3027
for(i--; i >= 0; i--) {
3031
return new String(ID);
3035
* Convert a single prediction into a probability distribution
3036
* with all zero probabilities except the predicted value which
3037
* has probability 1.0;
3039
* @param predictedClass the index of the predicted class
3040
* @return the probability distribution
3042
protected double [] makeDistribution(double predictedClass) {
3044
double [] result = new double [m_NumClasses];
3045
if (Instance.isMissingValue(predictedClass)) {
3048
if (m_ClassIsNominal) {
3049
result[(int)predictedClass] = 1.0;
3051
result[0] = predictedClass;
3057
* Updates all the statistics about a classifiers performance for
3058
* the current test instance.
3060
* @param predictedDistribution the probabilities assigned to
3062
* @param instance the instance to be classified
3063
* @throws Exception if the class of the instance is not
3066
protected void updateStatsForClassifier(double [] predictedDistribution,
3070
int actualClass = (int)instance.classValue();
3072
if (!instance.classIsMissing()) {
3073
updateMargins(predictedDistribution, actualClass, instance.weight());
3075
// Determine the predicted class (doesn't detect multiple
3077
int predictedClass = -1;
3078
double bestProb = 0.0;
3079
for(int i = 0; i < m_NumClasses; i++) {
3080
if (predictedDistribution[i] > bestProb) {
3082
bestProb = predictedDistribution[i];
3086
m_WithClass += instance.weight();
3088
// Determine misclassification cost
3089
if (m_CostMatrix != null) {
3090
if (predictedClass < 0) {
3091
// For missing predictions, we assume the worst possible cost.
3092
// This is pretty harsh.
3093
// Perhaps we could take the negative of the cost of a correct
3094
// prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
3095
// although often this will be zero
3096
m_TotalCost += instance.weight()
3097
* m_CostMatrix.getMaxCost(actualClass, instance);
3099
m_TotalCost += instance.weight()
3100
* m_CostMatrix.getElement(actualClass, predictedClass,
3105
// Update counts when no class was predicted
3106
if (predictedClass < 0) {
3107
m_Unclassified += instance.weight();
3111
double predictedProb = Math.max(MIN_SF_PROB,
3112
predictedDistribution[actualClass]);
3113
double priorProb = Math.max(MIN_SF_PROB,
3114
m_ClassPriors[actualClass]
3115
/ m_ClassPriorsSum);
3116
if (predictedProb >= priorProb) {
3117
m_SumKBInfo += (Utils.log2(predictedProb) -
3118
Utils.log2(priorProb))
3119
* instance.weight();
3121
m_SumKBInfo -= (Utils.log2(1.0-predictedProb) -
3122
Utils.log2(1.0-priorProb))
3123
* instance.weight();
3126
m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
3127
m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
3129
updateNumericScores(predictedDistribution,
3130
makeDistribution(instance.classValue()),
3133
// Update other stats
3134
m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
3135
if (predictedClass != actualClass) {
3136
m_Incorrect += instance.weight();
3138
m_Correct += instance.weight();
3141
m_MissingClass += instance.weight();
3146
* Updates all the statistics about a predictors performance for
3147
* the current test instance.
3149
* @param predictedValue the numeric value the classifier predicts
3150
* @param instance the instance to be classified
3151
* @throws Exception if the class of the instance is not
3154
protected void updateStatsForPredictor(double predictedValue,
3158
if (!instance.classIsMissing()){
3161
m_WithClass += instance.weight();
3162
if (Instance.isMissingValue(predictedValue)) {
3163
m_Unclassified += instance.weight();
3166
m_SumClass += instance.weight() * instance.classValue();
3167
m_SumSqrClass += instance.weight() * instance.classValue()
3168
* instance.classValue();
3169
m_SumClassPredicted += instance.weight()
3170
* instance.classValue() * predictedValue;
3171
m_SumPredicted += instance.weight() * predictedValue;
3172
m_SumSqrPredicted += instance.weight() * predictedValue * predictedValue;
3174
if (m_ErrorEstimator == null) {
3175
setNumericPriorsFromBuffer();
3177
double predictedProb = Math.max(m_ErrorEstimator.getProbability(
3179
- instance.classValue()),
3181
double priorProb = Math.max(m_PriorErrorEstimator.getProbability(
3182
instance.classValue()),
3185
m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
3186
m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
3187
m_ErrorEstimator.addValue(predictedValue - instance.classValue(),
3190
updateNumericScores(makeDistribution(predictedValue),
3191
makeDistribution(instance.classValue()),
3195
m_MissingClass += instance.weight();
3199
* Update the cumulative record of classification margins
3201
* @param predictedDistribution the probability distribution predicted for
3202
* the current instance
3203
* @param actualClass the index of the actual instance class
3204
* @param weight the weight assigned to the instance
3206
protected void updateMargins(double [] predictedDistribution,
3207
int actualClass, double weight) {
3209
double probActual = predictedDistribution[actualClass];
3210
double probNext = 0;
3212
for(int i = 0; i < m_NumClasses; i++)
3213
if ((i != actualClass) &&
3214
(predictedDistribution[i] > probNext))
3215
probNext = predictedDistribution[i];
3217
double margin = probActual - probNext;
3218
int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
3219
m_MarginCounts[bin] += weight;
3223
* Update the numeric accuracy measures. For numeric classes, the
3224
* accuracy is between the actual and predicted class values. For
3225
* nominal classes, the accuracy is between the actual and
3226
* predicted class probabilities.
3228
* @param predicted the predicted values
3229
* @param actual the actual value
3230
* @param weight the weight associated with this prediction
3232
protected void updateNumericScores(double [] predicted,
3233
double [] actual, double weight) {
3236
double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
3237
double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
3238
for(int i = 0; i < m_NumClasses; i++) {
3239
diff = predicted[i] - actual[i];
3241
sumAbsErr += Math.abs(diff);
3242
sumSqrErr += diff * diff;
3243
diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
3244
sumPriorAbsErr += Math.abs(diff);
3245
sumPriorSqrErr += diff * diff;
3247
m_SumErr += weight * sumErr / m_NumClasses;
3248
m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
3249
m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
3250
m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
3251
m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
3255
* Adds a numeric (non-missing) training class value and weight to
3256
* the buffer of stored values.
3258
* @param classValue the class value
3259
* @param weight the instance weight
3261
protected void addNumericTrainClass(double classValue, double weight) {
3263
if (m_TrainClassVals == null) {
3264
m_TrainClassVals = new double [100];
3265
m_TrainClassWeights = new double [100];
3267
if (m_NumTrainClassVals == m_TrainClassVals.length) {
3268
double [] temp = new double [m_TrainClassVals.length * 2];
3269
System.arraycopy(m_TrainClassVals, 0,
3270
temp, 0, m_TrainClassVals.length);
3271
m_TrainClassVals = temp;
3273
temp = new double [m_TrainClassWeights.length * 2];
3274
System.arraycopy(m_TrainClassWeights, 0,
3275
temp, 0, m_TrainClassWeights.length);
3276
m_TrainClassWeights = temp;
3278
m_TrainClassVals[m_NumTrainClassVals] = classValue;
3279
m_TrainClassWeights[m_NumTrainClassVals] = weight;
3280
m_NumTrainClassVals++;
3284
* Sets up the priors for numeric class attributes from the
3285
* training class values that have been seen so far.
3287
protected void setNumericPriorsFromBuffer() {
3289
double numPrecision = 0.01; // Default value
3290
if (m_NumTrainClassVals > 1) {
3291
double [] temp = new double [m_NumTrainClassVals];
3292
System.arraycopy(m_TrainClassVals, 0, temp, 0, m_NumTrainClassVals);
3293
int [] index = Utils.sort(temp);
3294
double lastVal = temp[index[0]];
3295
double deltaSum = 0;
3297
for (int i = 1; i < temp.length; i++) {
3298
double current = temp[index[i]];
3299
if (current != lastVal) {
3300
deltaSum += current - lastVal;
3306
numPrecision = deltaSum / distinct;
3309
m_PriorErrorEstimator = new KernelEstimator(numPrecision);
3310
m_ErrorEstimator = new KernelEstimator(numPrecision);
3311
m_ClassPriors[0] = m_ClassPriorsSum = 0;
3312
for (int i = 0; i < m_NumTrainClassVals; i++) {
3313
m_ClassPriors[0] += m_TrainClassVals[i] * m_TrainClassWeights[i];
3314
m_ClassPriorsSum += m_TrainClassWeights[i];
3315
m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
3316
m_TrainClassWeights[i]);