2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* EnsembleSelection.java
19
* Copyright (C) 2006 David Michael
23
package weka.classifiers.meta;
25
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.meta.ensembleSelection.EnsembleMetricHelper;
import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibrary;
import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibraryModel;
import weka.classifiers.meta.ensembleSelection.ModelBag;
import weka.classifiers.trees.REPTree;
import weka.classifiers.xml.XMLClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.xml.KOML;
import weka.core.xml.XMLOptions;
import weka.core.xml.XMLSerialization;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
72
<!-- globalinfo-start -->
73
* Combines several classifiers using the ensemble selection method. For more information, see: Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, Ensemble Selection from Libraries of Models, The International Conference on Machine Learning (ICML'04), 2004. Implemented in Weka by Bob Jung and David Michael.
75
<!-- globalinfo-end -->
77
<!-- technical-bibtex-start -->
80
* @inproceedings{RichCaruana2004,
81
* author = {Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes},
82
* booktitle = {21st International Conference on Machine Learning},
83
* title = {Ensemble Selection from Libraries of Models},
88
<!-- technical-bibtex-end -->
90
* Our implementation of ensemble selection is a bit different from the other
91
* classifiers because we assume that the list of models to be trained is too
92
* large to fit in memory and that our base classifiers will need to be
93
* serialized to the file system (in the directory listed in the "workingDirectory"
94
* option). We have adopted the term "model library" for this large set of
95
* classifiers keeping in line with the original paper.
98
* If you are planning to use this classifier, we highly recommend you take a
99
* quick look at our FAQ/tutorial on the WIKI. There are a few things that
100
* are unique to this classifier that could trip you up. Otherwise, this
101
* method is a great way to get really great classifier performance without
102
* having to do too much parameter tuning. What is nice is that in the worst
103
* case you get a nice summary of how a large number of diverse models
104
* performed on your data set.
107
* This class relies on the package weka.classifiers.meta.ensembleSelection.
110
* When run from the Explorer or another GUI, the classifier depends on the
111
* package weka.gui.libraryEditor.
114
<!-- options-start -->
115
* Valid options are: <p/>
117
* <pre> -L </path/to/modelLibrary>
118
* Specifies the Model Library File, continuing the list of all models.</pre>
120
* <pre> -W </path/to/working/directory>
121
* Specifies the Working Directory, where all models will be stored.</pre>
123
* <pre> -B <numModelBags>
124
* Set the number of bags, i.e., number of iterations to run
125
* the ensemble selection algorithm.</pre>
127
* <pre> -E <modelRatio>
128
* Set the ratio of library models that will be randomly chosen
129
* to populate each bag of models.</pre>
131
* <pre> -V <validationRatio>
132
* Set the ratio of the training data set that will be reserved
133
* for validation.</pre>
135
* <pre> -H <hillClimbIterations>
136
* Set the number of hillclimbing iterations to be performed
137
* on each model bag.</pre>
139
* <pre> -I <sortInitialization>
140
* Set the the ratio of the ensemble library that the sort
141
* initialization algorithm will be able to choose from while
142
* initializing the ensemble for each model bag</pre>
144
* <pre> -X <numFolds>
145
* Sets the number of cross-validation folds.</pre>
147
* <pre> -P <hillclimbMettric>
148
* Specify the metric that will be used for model selection
149
* during the hillclimbing algorithm.
151
* accuracy, rmse, roc, precision, recall, fscore, all</pre>
153
* <pre> -A <algorithm>
154
* Specifies the algorithm to be used for ensemble selection.
155
* Valid algorithms are:
156
* "forward" (default) for forward selection.
157
* "backward" for backward elimination.
158
* "both" for both forward and backward elimination.
159
* "best" to simply print out top performer from the
161
* "library" to only train the models in the ensemble
165
* Flag whether or not models can be selected more than once
166
* for an ensemble.</pre>
169
* Whether sort initialization greedily stops adding models
170
* when performance degrades.</pre>
173
* Flag for verbose output. Prints out performance of all
174
* selected models.</pre>
176
* <pre> -S <num>
177
* Random number seed.
181
* If set, classifier is run in debug mode and
182
* may output additional info to the console</pre>
186
* @author Robert Jung
187
* @author David Michael
188
* @version $Revision: 1.3 $
190
public class EnsembleSelection
191
extends RandomizableClassifier
192
implements TechnicalInformationHandler {
194
/** for serialization */
195
private static final long serialVersionUID = -1744155148765058511L;
198
* The Library of models, from which we can select our ensemble. Usually
199
* loaded from a model list file (.mlf or .model.xml) using the -L
200
* command-line option.
202
protected EnsembleSelectionLibrary m_library = new EnsembleSelectionLibrary();
205
* List of models chosen by EnsembleSelection. Populated by buildClassifier.
207
protected EnsembleSelectionLibraryModel[] m_chosen_models = null;
210
* An array of weights for the chosen models. Elements are parallel to those
211
* in m_chosen_models. That is, m_chosen_model_weights[i] is the weight
212
* associated with the model at m_chosen_models[i].
214
protected int[] m_chosen_model_weights = null;
216
/** Total weight of all chosen models. */
217
protected int m_total_weight = 0;
220
* ratio of library models that will be randomly chosen to be used for each
223
protected double m_modelRatio = 0.5;
226
* Indicates the fraction of the given training set that should be used for
227
* hillclimbing/validation. This fraction is set aside and not used for
228
* training. It is assumed that any loaded models were also not trained on
229
* set-aside data. (If the same percentage and random seed were used
230
* previously to train the models in the library, this will work as expected -
231
* i.e., those models will be valid)
233
protected double m_validationRatio = 0.25;
235
/** defines metrics that can be chosen for hillclimbing */
236
public static final Tag[] TAGS_METRIC = {
237
new Tag(EnsembleMetricHelper.METRIC_ACCURACY, "Optimize with Accuracy"),
238
new Tag(EnsembleMetricHelper.METRIC_RMSE, "Optimize with RMSE"),
239
new Tag(EnsembleMetricHelper.METRIC_ROC, "Optimize with ROC"),
240
new Tag(EnsembleMetricHelper.METRIC_PRECISION, "Optimize with precision"),
241
new Tag(EnsembleMetricHelper.METRIC_RECALL, "Optimize with recall"),
242
new Tag(EnsembleMetricHelper.METRIC_FSCORE, "Optimize with fscore"),
243
new Tag(EnsembleMetricHelper.METRIC_ALL, "Optimize with all metrics"), };
246
* The "enumeration" of the algorithms we can use. Forward - forward
247
* selection. For hillclimb iterations,
249
public static final int ALGORITHM_FORWARD = 0;
251
public static final int ALGORITHM_BACKWARD = 1;
253
public static final int ALGORITHM_FORWARD_BACKWARD = 2;
255
public static final int ALGORITHM_BEST = 3;
257
public static final int ALGORITHM_BUILD_LIBRARY = 4;
259
/** defines metrics that can be chosen for hillclimbing */
260
public static final Tag[] TAGS_ALGORITHM = {
261
new Tag(ALGORITHM_FORWARD, "Forward selection"),
262
new Tag(ALGORITHM_BACKWARD, "Backward elimation"),
263
new Tag(ALGORITHM_FORWARD_BACKWARD, "Forward Selection + Backward Elimination"),
264
new Tag(ALGORITHM_BEST, "Best model"),
265
new Tag(ALGORITHM_BUILD_LIBRARY, "Build Library Only") };
268
* this specifies the number of "Ensembl-X" directories that are allowed to
269
* be created in the users home directory where X is the number of the
272
private static final int MAX_DEFAULT_DIRECTORIES = 1000;
275
* The name of the Model Library File (if one is specified) which lists
276
* models from which ensemble selection will choose. This is only used when
277
* run from the command-line, as otherwise m_library is responsible for
280
protected String m_modelLibraryFileName = null;
283
* The number of "model bags". Using 1 is equivalent to no bagging at all.
285
protected int m_numModelBags = 10;
287
/** The metric for which the ensemble will be optimized. */
288
protected int m_hillclimbMetric = EnsembleMetricHelper.METRIC_RMSE;
290
/** The algorithm used for ensemble selection. */
291
protected int m_algorithm = ALGORITHM_FORWARD;
294
* number of hillclimbing iterations for the ensemble selection algorithm
296
protected int m_hillclimbIterations = 100;
298
/** ratio of library models to be used for sort initialization */
299
protected double m_sortInitializationRatio = 1.0;
302
* specifies whether or not the ensemble algorithm is allowed to include a
303
* specific model in the library more than once in each ensemble
305
protected boolean m_replacement = true;
308
* specifies whether we use "greedy" sort initialization. If false, we
309
* simply add the best m_sortInitializationRatio models of the bag blindly.
310
* If true, we add the best models in order up to m_sortInitializationRatio
311
* until adding the next model would not help performance.
313
protected boolean m_greedySortInitialization = true;
316
* Specifies whether or not we will output metrics for all models
318
protected boolean m_verboseOutput = false;
321
* Hash map of cached predictions. The key is a stringified Instance. Each
322
* entry is a 2d array, first indexed by classifier index (i.e., the one
323
* used in m_chosen_model). The second index is the usual "distribution"
324
* index across classes.
326
protected Map m_cachedPredictions = null;
329
* This string will store the working directory where all models , temporary
330
* prediction values, and modellist logs are to be built and stored.
332
protected File m_workingDirectory = new File(getDefaultWorkingDirectory());
335
* Indicates the number of folds for cross-validation. A value of 1
336
* indicates there is no cross-validation. Cross validation is done in the
337
* "embedded" fashion described by Caruana, Niculescu, and Munson
338
* (unpublished work - tech report forthcoming)
340
protected int m_NumFolds = 1;
343
* Returns a string describing classifier
345
* @return a description suitable for displaying in the
346
* explorer/experimenter gui
348
public String globalInfo() {
350
return "Combines several classifiers using the ensemble "
351
+ "selection method. For more information, see: "
352
+ "Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, "
353
+ "Ensemble Selection from Libraries of Models, "
354
+ "The International Conference on Machine Learning (ICML'04), 2004. "
355
+ "Implemented in Weka by Bob Jung and David Michael.";
359
* Returns an enumeration describing the available options.
361
* @return an enumeration of all the available options.
363
public Enumeration listOptions() {
364
Vector result = new Vector();
366
result.addElement(new Option(
367
"\tSpecifies the Model Library File, continuing the list of all models.",
368
"L", 1, "-L </path/to/modelLibrary>"));
370
result.addElement(new Option(
371
"\tSpecifies the Working Directory, where all models will be stored.",
372
"W", 1, "-W </path/to/working/directory>"));
374
result.addElement(new Option(
375
"\tSet the number of bags, i.e., number of iterations to run \n"
376
+ "\tthe ensemble selection algorithm.",
377
"B", 1, "-B <numModelBags>"));
379
result.addElement(new Option(
380
"\tSet the ratio of library models that will be randomly chosen \n"
381
+ "\tto populate each bag of models.",
382
"E", 1, "-E <modelRatio>"));
384
result.addElement(new Option(
385
"\tSet the ratio of the training data set that will be reserved \n"
386
+ "\tfor validation.",
387
"V", 1, "-V <validationRatio>"));
389
result.addElement(new Option(
390
"\tSet the number of hillclimbing iterations to be performed \n"
391
+ "\ton each model bag.",
392
"H", 1, "-H <hillClimbIterations>"));
394
result.addElement(new Option(
395
"\tSet the the ratio of the ensemble library that the sort \n"
396
+ "\tinitialization algorithm will be able to choose from while \n"
397
+ "\tinitializing the ensemble for each model bag",
398
"I", 1, "-I <sortInitialization>"));
400
result.addElement(new Option(
401
"\tSets the number of cross-validation folds.",
402
"X", 1, "-X <numFolds>"));
404
result.addElement(new Option(
405
"\tSpecify the metric that will be used for model selection \n"
406
+ "\tduring the hillclimbing algorithm.\n"
407
+ "\tValid metrics are: \n"
408
+ "\t\taccuracy, rmse, roc, precision, recall, fscore, all",
409
"P", 1, "-P <hillclimbMettric>"));
411
result.addElement(new Option(
412
"\tSpecifies the algorithm to be used for ensemble selection. \n"
413
+ "\tValid algorithms are:\n"
414
+ "\t\t\"forward\" (default) for forward selection.\n"
415
+ "\t\t\"backward\" for backward elimination.\n"
416
+ "\t\t\"both\" for both forward and backward elimination.\n"
417
+ "\t\t\"best\" to simply print out top performer from the \n"
418
+ "\t\t ensemble library\n"
419
+ "\t\t\"library\" to only train the models in the ensemble \n"
421
"A", 1, "-A <algorithm>"));
423
result.addElement(new Option(
424
"\tFlag whether or not models can be selected more than once \n"
425
+ "\tfor an ensemble.",
428
result.addElement(new Option(
429
"\tWhether sort initialization greedily stops adding models \n"
430
+ "\twhen performance degrades.",
433
result.addElement(new Option(
434
"\tFlag for verbose output. Prints out performance of all \n"
435
+ "\tselected models.",
438
// TODO - Add more options here
439
Enumeration enu = super.listOptions();
440
while (enu.hasMoreElements()) {
441
result.addElement(enu.nextElement());
444
return result.elements();
448
* We return true for basically everything except for Missing class values,
449
* because we can't really answer for all the models in our library. If any of
450
* them don't work with the supplied data then we just trap the exception.
452
* @return the capabilities of this classifier
454
public Capabilities getCapabilities() {
455
Capabilities result = super.getCapabilities(); // returns the object
457
// weka.classifiers.Classifier
460
result.enable(Capability.NOMINAL_ATTRIBUTES);
461
result.enable(Capability.NUMERIC_ATTRIBUTES);
462
result.enable(Capability.DATE_ATTRIBUTES);
463
result.enable(Capability.MISSING_VALUES);
464
result.enable(Capability.BINARY_ATTRIBUTES);
467
result.enable(Capability.NOMINAL_CLASS);
468
result.enable(Capability.NUMERIC_CLASS);
469
result.enable(Capability.BINARY_CLASS);
475
<!-- options-start -->
476
* Valid options are: <p/>
478
* <pre> -L </path/to/modelLibrary>
479
* Specifies the Model Library File, continuing the list of all models.</pre>
481
* <pre> -W </path/to/working/directory>
482
* Specifies the Working Directory, where all models will be stored.</pre>
484
* <pre> -B <numModelBags>
485
* Set the number of bags, i.e., number of iterations to run
486
* the ensemble selection algorithm.</pre>
488
* <pre> -E <modelRatio>
489
* Set the ratio of library models that will be randomly chosen
490
* to populate each bag of models.</pre>
492
* <pre> -V <validationRatio>
493
* Set the ratio of the training data set that will be reserved
494
* for validation.</pre>
496
* <pre> -H <hillClimbIterations>
497
* Set the number of hillclimbing iterations to be performed
498
* on each model bag.</pre>
500
* <pre> -I <sortInitialization>
501
* Set the the ratio of the ensemble library that the sort
502
* initialization algorithm will be able to choose from while
503
* initializing the ensemble for each model bag</pre>
505
* <pre> -X <numFolds>
506
* Sets the number of cross-validation folds.</pre>
508
* <pre> -P <hillclimbMettric>
509
* Specify the metric that will be used for model selection
510
* during the hillclimbing algorithm.
512
* accuracy, rmse, roc, precision, recall, fscore, all</pre>
514
* <pre> -A <algorithm>
515
* Specifies the algorithm to be used for ensemble selection.
516
* Valid algorithms are:
517
* "forward" (default) for forward selection.
518
* "backward" for backward elimination.
519
* "both" for both forward and backward elimination.
520
* "best" to simply print out top performer from the
522
* "library" to only train the models in the ensemble
526
* Flag whether or not models can be selected more than once
527
* for an ensemble.</pre>
530
* Whether sort initialization greedily stops adding models
531
* when performance degrades.</pre>
534
* Flag for verbose output. Prints out performance of all
535
* selected models.</pre>
537
* <pre> -S <num>
538
* Random number seed.
542
* If set, classifier is run in debug mode and
543
* may output additional info to the console</pre>
548
* the list of options as an array of strings
550
* if an option is not supported
552
public void setOptions(String[] options) throws Exception {
555
tmpStr = Utils.getOption('L', options);
556
if (tmpStr.length() != 0) {
557
m_modelLibraryFileName = tmpStr;
558
m_library = new EnsembleSelectionLibrary(m_modelLibraryFileName);
560
setLibrary(new EnsembleSelectionLibrary());
561
// setLibrary(new Library(super.m_Classifiers));
564
tmpStr = Utils.getOption('W', options);
565
if (tmpStr.length() != 0 && validWorkingDirectory(tmpStr)) {
566
m_workingDirectory = new File(tmpStr);
568
m_workingDirectory = new File(getDefaultWorkingDirectory());
570
m_library.setWorkingDirectory(m_workingDirectory);
572
tmpStr = Utils.getOption('E', options);
573
if (tmpStr.length() != 0) {
574
setModelRatio(Double.parseDouble(tmpStr));
579
tmpStr = Utils.getOption('V', options);
580
if (tmpStr.length() != 0) {
581
setValidationRatio(Double.parseDouble(tmpStr));
583
setValidationRatio(0.25);
586
tmpStr = Utils.getOption('B', options);
587
if (tmpStr.length() != 0) {
588
setNumModelBags(Integer.parseInt(tmpStr));
593
tmpStr = Utils.getOption('H', options);
594
if (tmpStr.length() != 0) {
595
setHillclimbIterations(Integer.parseInt(tmpStr));
597
setHillclimbIterations(100);
600
tmpStr = Utils.getOption('I', options);
601
if (tmpStr.length() != 0) {
602
setSortInitializationRatio(Double.parseDouble(tmpStr));
604
setSortInitializationRatio(1.0);
607
tmpStr = Utils.getOption('X', options);
608
if (tmpStr.length() != 0) {
609
setNumFolds(Integer.parseInt(tmpStr));
614
setReplacement(Utils.getFlag('R', options));
616
setGreedySortInitialization(Utils.getFlag('G', options));
618
setVerboseOutput(Utils.getFlag('O', options));
620
tmpStr = Utils.getOption('P', options);
621
// if (hillclimbMetricString.length() != 0) {
623
if (tmpStr.toLowerCase().equals("accuracy")) {
624
setHillclimbMetric(new SelectedTag(
625
EnsembleMetricHelper.METRIC_ACCURACY, TAGS_METRIC));
626
} else if (tmpStr.toLowerCase().equals("rmse")) {
627
setHillclimbMetric(new SelectedTag(
628
EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
629
} else if (tmpStr.toLowerCase().equals("roc")) {
630
setHillclimbMetric(new SelectedTag(
631
EnsembleMetricHelper.METRIC_ROC, TAGS_METRIC));
632
} else if (tmpStr.toLowerCase().equals("precision")) {
633
setHillclimbMetric(new SelectedTag(
634
EnsembleMetricHelper.METRIC_PRECISION, TAGS_METRIC));
635
} else if (tmpStr.toLowerCase().equals("recall")) {
636
setHillclimbMetric(new SelectedTag(
637
EnsembleMetricHelper.METRIC_RECALL, TAGS_METRIC));
638
} else if (tmpStr.toLowerCase().equals("fscore")) {
639
setHillclimbMetric(new SelectedTag(
640
EnsembleMetricHelper.METRIC_FSCORE, TAGS_METRIC));
641
} else if (tmpStr.toLowerCase().equals("all")) {
642
setHillclimbMetric(new SelectedTag(
643
EnsembleMetricHelper.METRIC_ALL, TAGS_METRIC));
645
setHillclimbMetric(new SelectedTag(
646
EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
649
tmpStr = Utils.getOption('A', options);
650
if (tmpStr.toLowerCase().equals("forward")) {
651
setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
652
} else if (tmpStr.toLowerCase().equals("backward")) {
653
setAlgorithm(new SelectedTag(ALGORITHM_BACKWARD, TAGS_ALGORITHM));
654
} else if (tmpStr.toLowerCase().equals("both")) {
655
setAlgorithm(new SelectedTag(ALGORITHM_FORWARD_BACKWARD, TAGS_ALGORITHM));
656
} else if (tmpStr.toLowerCase().equals("forward")) {
657
setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
658
} else if (tmpStr.toLowerCase().equals("best")) {
659
setAlgorithm(new SelectedTag(ALGORITHM_BEST, TAGS_ALGORITHM));
660
} else if (tmpStr.toLowerCase().equals("library")) {
661
setAlgorithm(new SelectedTag(ALGORITHM_BUILD_LIBRARY, TAGS_ALGORITHM));
663
setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
666
super.setOptions(options);
668
m_library.setDebug(m_Debug);
673
* Gets the current settings of the Classifier.
675
* @return an array of strings suitable for passing to setOptions
677
public String[] getOptions() {
682
result = new Vector();
684
if (m_library.getModelListFile() != null) {
686
result.add("" + m_library.getModelListFile());
689
if (!m_workingDirectory.equals("")) {
691
result.add("" + getWorkingDirectory());
695
switch (getHillclimbMetric().getSelectedTag().getID()) {
696
case (EnsembleMetricHelper.METRIC_ACCURACY):
697
result.add("accuracy");
699
case (EnsembleMetricHelper.METRIC_RMSE):
702
case (EnsembleMetricHelper.METRIC_ROC):
705
case (EnsembleMetricHelper.METRIC_PRECISION):
706
result.add("precision");
708
case (EnsembleMetricHelper.METRIC_RECALL):
709
result.add("recall");
711
case (EnsembleMetricHelper.METRIC_FSCORE):
712
result.add("fscore");
714
case (EnsembleMetricHelper.METRIC_ALL):
720
switch (getAlgorithm().getSelectedTag().getID()) {
721
case (ALGORITHM_FORWARD):
722
result.add("forward");
724
case (ALGORITHM_BACKWARD):
725
result.add("backward");
727
case (ALGORITHM_FORWARD_BACKWARD):
730
case (ALGORITHM_BEST):
733
case (ALGORITHM_BUILD_LIBRARY):
734
result.add("library");
739
result.add("" + getNumModelBags());
741
result.add("" + getValidationRatio());
743
result.add("" + getModelRatio());
745
result.add("" + getHillclimbIterations());
747
result.add("" + getSortInitializationRatio());
749
result.add("" + getNumFolds());
753
if (m_greedySortInitialization)
758
options = super.getOptions();
759
for (i = 0; i < options.length; i++)
760
result.add(options[i]);
762
return (String[]) result.toArray(new String[result.size()]);
766
* Returns the tip text for this property
768
* @return tip text for this property suitable for displaying in the
769
* explorer/experimenter gui
771
public String numFoldsTipText() {
772
return "The number of folds used for cross-validation.";
776
* Gets the number of folds for the cross-validation.
778
* @return the number of folds for the cross-validation
780
public int getNumFolds() {
785
* Sets the number of folds for the cross-validation.
788
* the number of folds for the cross-validation
790
* if parameter illegal
792
public void setNumFolds(int numFolds) throws Exception {
794
throw new IllegalArgumentException(
795
"EnsembleSelection: Number of cross-validation "
796
+ "folds must be positive.");
798
m_NumFolds = numFolds;
802
* Returns the tip text for this property
804
* @return tip text for this property suitable for displaying in the
805
* explorer/experimenter gui
807
public String libraryTipText() {
808
return "An ensemble library.";
812
* Gets the ensemble library.
814
* @return the ensemble library
816
public EnsembleSelectionLibrary getLibrary() {
821
* Sets the ensemble library.
824
* the ensemble library
826
public void setLibrary(EnsembleSelectionLibrary newLibrary) {
827
m_library = newLibrary;
828
m_library.setDebug(m_Debug);
832
* Returns the tip text for this property
834
* @return tip text for this property suitable for displaying in the
835
* explorer/experimenter gui
837
public String modelRatioTipText() {
838
return "The ratio of library models that will be randomly chosen to be used for each iteration.";
842
* Get the value of modelRatio.
844
* @return Value of modelRatio.
846
public double getModelRatio() {
851
* Set the value of modelRatio.
854
* Value to assign to modelRatio.
856
public void setModelRatio(double v) {
861
* Returns the tip text for this property
863
* @return tip text for this property suitable for displaying in the
864
* explorer/experimenter gui
866
public String validationRatioTipText() {
867
return "The ratio of the training data set that will be reserved for validation.";
871
* Get the value of validationRatio.
873
* @return Value of validationRatio.
875
public double getValidationRatio() {
876
return m_validationRatio;
880
* Set the value of validationRatio.
883
* Value to assign to validationRatio.
885
public void setValidationRatio(double v) {
886
m_validationRatio = v;
890
* Returns the tip text for this property
892
* @return tip text for this property suitable for displaying in the
893
* explorer/experimenter gui
895
public String hillclimbMetricTipText() {
896
return "the metric that will be used to optimizer the chosen ensemble..";
900
* Gets the hill climbing metric. Will be one of METRIC_ACCURACY,
901
* METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
904
* @return the hillclimbMetric
906
public SelectedTag getHillclimbMetric() {
907
return new SelectedTag(m_hillclimbMetric, TAGS_METRIC);
911
* Sets the hill climbing metric. Will be one of METRIC_ACCURACY,
912
* METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
916
* the new hillclimbMetric
918
public void setHillclimbMetric(SelectedTag newType) {
919
if (newType.getTags() == TAGS_METRIC) {
920
m_hillclimbMetric = newType.getSelectedTag().getID();
925
* Returns the tip text for this property
927
* @return tip text for this property suitable for displaying in the
928
* explorer/experimenter gui
930
public String algorithmTipText() {
931
return "the algorithm used to optimizer the ensemble";
937
* @return the algorithm
939
public SelectedTag getAlgorithm() {
940
return new SelectedTag(m_algorithm, TAGS_ALGORITHM);
944
* Sets the Algorithm to use
949
public void setAlgorithm(SelectedTag newType) {
950
if (newType.getTags() == TAGS_ALGORITHM) {
951
m_algorithm = newType.getSelectedTag().getID();
956
* Returns the tip text for this property
958
* @return tip text for this property suitable for displaying in the
959
* explorer/experimenter gui
961
public String hillclimbIterationsTipText() {
962
return "The number of hillclimbing iterations for the ensemble selection algorithm.";
966
* Gets the number of hillclimbIterations.
968
* @return the number of hillclimbIterations
970
public int getHillclimbIterations() {
971
return m_hillclimbIterations;
975
* Sets the number of hillclimbIterations.
978
* the number of hillclimbIterations
980
* if parameter illegal
982
public void setHillclimbIterations(int n) throws Exception {
984
throw new IllegalArgumentException(
985
"EnsembleSelection: Number of hillclimb iterations "
986
+ "must be positive.");
988
m_hillclimbIterations = n;
992
* Returns the tip text for this property
994
* @return tip text for this property suitable for displaying in the
995
* explorer/experimenter gui
997
public String numModelBagsTipText() {
998
return "The number of \"model bags\" used in the ensemble selection algorithm.";
1002
* Gets numModelBags.
1004
* @return numModelBags
1006
public int getNumModelBags() {
1007
return m_numModelBags;
1011
* Sets numModelBags.
1014
* the new value for numModelBags
1016
* if parameter illegal
1018
public void setNumModelBags(int n) throws Exception {
1020
throw new IllegalArgumentException(
1021
"EnsembleSelection: Number of model bags "
1022
+ "must be positive.");
1028
* Returns the tip text for this property
1030
* @return tip text for this property suitable for displaying in the
1031
* explorer/experimenter gui
1033
public String sortInitializationRatioTipText() {
1034
return "The ratio of library models to be used for sort initialization.";
1038
* Get the value of sortInitializationRatio.
1040
* @return Value of sortInitializationRatio.
1042
public double getSortInitializationRatio() {
1043
return m_sortInitializationRatio;
1047
* Set the value of sortInitializationRatio.
1050
* Value to assign to sortInitializationRatio.
1052
public void setSortInitializationRatio(double v) {
1053
m_sortInitializationRatio = v;
1057
* Returns the tip text for this property
1059
* @return tip text for this property suitable for displaying in the
1060
* explorer/experimenter gui
1062
public String replacementTipText() {
1063
return "Whether models in the library can be included more than once in an ensemble.";
1067
* Get the value of replacement.
1069
* @return Value of replacement.
1071
public boolean getReplacement() {
1072
return m_replacement;
1076
* Set the value of replacement.
1078
* @param newReplacement
1079
* Value to assign to replacement.
1081
public void setReplacement(boolean newReplacement) {
1082
m_replacement = newReplacement;
1086
* Returns the tip text for this property
1088
* @return tip text for this property suitable for displaying in the
1089
* explorer/experimenter gui
1091
public String greedySortInitializationTipText() {
1092
return "Whether sort initialization greedily stops adding models when performance degrades.";
1096
* Get the value of greedySortInitialization.
1098
* @return Value of replacement.
1100
public boolean getGreedySortInitialization() {
1101
return m_greedySortInitialization;
1105
* Set the value of greedySortInitialization.
1107
* @param newGreedySortInitialization
1108
* Value to assign to replacement.
1110
public void setGreedySortInitialization(boolean newGreedySortInitialization) {
1111
m_greedySortInitialization = newGreedySortInitialization;
1115
* Returns the tip text for this property
1117
* @return tip text for this property suitable for displaying in the
1118
* explorer/experimenter gui
1120
public String verboseOutputTipText() {
1121
return "Whether metrics are printed for each model.";
1125
* Get the value of verboseOutput.
1127
* @return Value of verboseOutput.
1129
public boolean getVerboseOutput() {
1130
return m_verboseOutput;
1134
* Set the value of verboseOutput.
1136
* @param newVerboseOutput
1137
* Value to assign to verboseOutput.
1139
public void setVerboseOutput(boolean newVerboseOutput) {
1140
m_verboseOutput = newVerboseOutput;
1144
* Returns the tip text for this property
1146
* @return tip text for this property suitable for displaying in the
1147
* explorer/experimenter gui
1149
public String workingDirectoryTipText() {
1150
return "The working directory of the ensemble - where trained models will be stored.";
1154
* Get the value of working directory.
1156
* @return Value of working directory.
1158
public File getWorkingDirectory() {
1159
return m_workingDirectory;
1163
* Set the value of working directory.
1165
* @param newWorkingDirectory directory Value.
1167
public void setWorkingDirectory(File newWorkingDirectory) {
1169
System.out.println("working directory changed to: "
1170
+ newWorkingDirectory);
1172
m_library.setWorkingDirectory(newWorkingDirectory);
1174
m_workingDirectory = newWorkingDirectory;
1178
* buildClassifier selects a classifier from the set of classifiers by
1179
* minimising error on the training data.
1181
* @param trainData the training data to be used for generating the boosted
1183
* @throws Exception if the classifier could not be built successfully
1185
public void buildClassifier(Instances trainData) throws Exception {
// NOTE(review): this chunk is interleaved with stray line-number tokens and
// several original lines (closing braces, "System.out" receivers, statement
// continuations) are missing - extraction damage; restore this method from
// the upstream file before compiling.
1187
getCapabilities().testWithFail(trainData);
1189
// First we need to make sure that some library models
1190
// were specified. If not, then use the default list
1191
if (m_library.m_Models.size() == 0) {
1194
.println("WARNING: No library file specified. Using some default models.");
1196
.println("You should specify a model list with -L <file> from the command line.");
1198
.println("Or edit the list directly with the LibraryEditor from the GUI");
1200
for (int i = 0; i < 10; i++) {
1202
REPTree tree = new REPTree();
1204
m_library.addModel(new EnsembleSelectionLibraryModel(tree));
1210
if (m_library == null) {
1211
m_library = new EnsembleSelectionLibrary();
1212
m_library.setDebug(m_Debug);
1215
m_library.setNumFolds(getNumFolds());
1216
m_library.setValidationRatio(getValidationRatio());
1217
// train all untrained models, and set "data" to the hillclimbing set.
1218
Instances data = m_library.trainAll(trainData, m_workingDirectory.getAbsolutePath(),
1220
// We cache the hillclimb predictions from all of the models in
1221
// the library so that we can evaluate their performances when we
1223
// in various ways (without needing to keep the classifiers in memory).
1224
double predictions[][][] = m_library.getHillclimbPredictions();
1225
int numModels = predictions.length;
1226
int modelWeights[] = new int[numModels];
1228
Random rand = new Random(m_Seed);
1230
if (m_algorithm == ALGORITHM_BUILD_LIBRARY) {
1233
} else if (m_algorithm == ALGORITHM_BEST) {
1234
// If we want to choose the best model, just make a model bag that
1235
// includes all the models, then sort initialize to find the 1 that
1237
ModelBag model_bag = new ModelBag(predictions, 1.0, m_Debug);
1238
int[] modelPicked = model_bag.sortInitialize(1, false, data,
1240
// Then give it a weight of 1, while all others remain 0.
1241
modelWeights[modelPicked[0]] = 1;
1245
System.out.println("Starting hillclimbing algorithm: "
1248
for (int i = 0; i < getNumModelBags(); ++i) {
1249
// For the number of bags,
1251
System.out.println("Starting on ensemble bag: " + i);
1252
// Create a new bag of the appropriate size
1253
ModelBag modelBag = new ModelBag(predictions, getModelRatio(),
1256
modelBag.shuffle(rand);
1257
if (getSortInitializationRatio() > 0.0) {
1258
// Sort initialize, if the ratio greater than 0.
1259
modelBag.sortInitialize((int) (getSortInitializationRatio()
1260
* getModelRatio() * numModels),
1261
getGreedySortInitialization(), data,
1265
if (m_algorithm == ALGORITHM_BACKWARD) {
1266
// If we're doing backwards elimination, we just give all
1268
// a weight of 1 initially. If the # of hillclimb iterations
1269
// is too high, we'll end up with just one model in the end
1270
// (we never delete all models from a bag). TODO - it might
1272
// smarter to base this weight off of how many models we
1274
modelBag.weightAll(1); // for now at least, I'm just
1277
// Now the bag is initialized, and we're ready to hillclimb.
1278
for (int j = 0; j < getHillclimbIterations(); ++j) {
1279
if (m_algorithm == ALGORITHM_FORWARD) {
1280
modelBag.forwardSelect(getReplacement(), data,
1282
} else if (m_algorithm == ALGORITHM_BACKWARD) {
1283
modelBag.backwardEliminate(data, m_hillclimbMetric);
1284
} else if (m_algorithm == ALGORITHM_FORWARD_BACKWARD) {
1285
modelBag.forwardSelectOrBackwardEliminate(
1286
getReplacement(), data, m_hillclimbMetric);
1289
// Now that we've done all the hillclimbing steps, we can just
1291
// the model weights that the bag determined, and add them to
1294
int[] bagWeights = modelBag.getModelWeights();
1295
for (int j = 0; j < bagWeights.length; ++j) {
1296
modelWeights[j] += bagWeights[j];
1300
// Now we've done the hard work of actually learning the ensemble. Now
1301
// we set up the appropriate data structures so that Ensemble Selection
1303
// make predictions for future test examples.
1304
Set modelNames = m_library.getModelNames();
1305
String[] modelNamesArray = new String[m_library.size()];
1306
Iterator iter = modelNames.iterator();
1307
// libraryIndex indexes over all the models in the library (not just
1309
// which we chose for the ensemble).
1310
int libraryIndex = 0;
1311
// chosenModels will count the total number of models which were
1313
// by EnsembleSelection (those that have non-zero weight).
1314
int chosenModels = 0;
1315
while (iter.hasNext()) {
1316
// Note that we have to be careful of order. Our model_weights array
1317
// is in the same order as our list of models in m_library.
1319
// Get the name of the model,
1320
modelNamesArray[libraryIndex] = (String) iter.next();
1322
int weightOfModel = modelWeights[libraryIndex++];
1323
m_total_weight += weightOfModel;
1324
if (weightOfModel > 0) {
1325
// If the model was chosen at least once, increment the
1326
// number of chosen models.
1330
if (m_verboseOutput) {
1331
// Output every model and its performance with respect to the
1334
ModelBag bag = new ModelBag(predictions, 1.0, m_Debug);
1335
int modelIndexes[] = bag.sortInitialize(modelNamesArray.length,
1336
false, data, m_hillclimbMetric);
1337
double modelPerformance[] = bag.getIndividualPerformance(data,
1339
for (int i = 0; i < modelIndexes.length; ++i) {
1340
// TODO - Could do this in a more readable way.
1341
System.out.println("" + modelPerformance[i] + " "
1342
+ modelNamesArray[modelIndexes[i]]);
1345
// We're now ready to build our array of the models which were chosen
1346
// and their associated weights.
1347
m_chosen_models = new EnsembleSelectionLibraryModel[chosenModels];
1348
m_chosen_model_weights = new int[chosenModels];
1351
// chosenIndex indexes over the models which were chosen by
1352
// EnsembleSelection
1353
// (those which have non-zero weight).
1354
int chosenIndex = 0;
1355
iter = m_library.getModels().iterator();
// NOTE(review): no reset of libraryIndex to 0 is visible before this second
// pass over the library - presumably it happens on one of the missing lines
// (orig 1349-1350); verify against the upstream source.
1356
while (iter.hasNext()) {
1357
int weightOfModel = modelWeights[libraryIndex++];
1359
EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) iter
1362
if (weightOfModel > 0) {
1363
// If the model was chosen at least once, add it to our array
1364
// of chosen models and weights.
1365
m_chosen_models[chosenIndex] = model;
1366
m_chosen_model_weights[chosenIndex] = weightOfModel;
1367
// Note that the EnsembleSelectionLibraryModel may not be
1369
// that is, its classifier(s) may be null pointers. That's okay
1371
// we'll "rehydrate" them later, if and when we need to.
1378
* Calculates the class membership probabilities for the given test instance.
1380
* @param instance the instance to be classified
1381
* @return predicted class probability distribution
1382
* @throws Exception if instance could not be classified
1385
public double[] distributionForInstance(Instance instance) throws Exception {
1386
String stringInstance = instance.toString();
1387
double cachedPreds[][] = null;
1389
if (m_cachedPredictions != null) {
1390
// If we have any cached predictions (i.e., if cachePredictions was
1391
// called), look for a cached set of predictions for this instance.
1392
if (m_cachedPredictions.containsKey(stringInstance)) {
1393
cachedPreds = (double[][]) m_cachedPredictions.get(stringInstance);
1396
double[] prediction = new double[instance.numClasses()];
1397
for (int i = 0; i < prediction.length; ++i) {
1398
prediction[i] = 0.0;
1401
// Now do a weighted average of the predictions of each of our models.
1402
for (int i = 0; i < m_chosen_models.length; ++i) {
1403
double[] predictionForThisModel = null;
1404
if (cachedPreds == null) {
1405
// If there are no predictions cached, we'll load the model's
1406
// classifier(s) in to memory and get the predictions.
1407
m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
1408
predictionForThisModel = m_chosen_models[i].getAveragePrediction(instance);
1409
// We could release the model here to save memory, but we assume
1410
// that there is enough available since we're not using the
1411
// prediction caching functionality. If we load and release a
1413
// every time we need to get a prediction for an instance, it
1415
// prohibitively slow.
1417
// If it's cached, just get it from the array of cached preds
1418
// for this instance.
1419
predictionForThisModel = cachedPreds[i];
1421
// We have encountered a bug where MultilayerPerceptron returns a
1423
// prediction array. If that happens, we just don't count that model
1425
// our ensemble prediction.
1426
if (predictionForThisModel != null) {
1427
// Okay, the model returned a valid prediction array, so we'll
1428
// add the appropriate fraction of this model's prediction.
1429
for (int j = 0; j < prediction.length; ++j) {
1430
prediction[j] += m_chosen_model_weights[i] * predictionForThisModel[j] / m_total_weight;
1434
// normalize to add up to 1.
1435
if (instance.classAttribute().isNominal()) {
1436
if (Utils.sum(prediction) > 0)
1437
Utils.normalize(prediction);
1443
* This function tests whether or not a given path is appropriate for being
1444
* the working directory. Specifically, we care that we can write to the
1445
* path and that it doesn't point to a "non-directory" file handle.
1447
* @param dir the directory to test
1448
* @return true if the directory is valid
1450
private boolean validWorkingDirectory(String dir) {
1452
boolean valid = false;
1454
File f = new File((dir));
1457
if (f.isDirectory() && f.canWrite())
1469
* This method tries to find a reasonable path name for the ensemble working
1470
* directory where models and files will be stored.
1473
* @return true if m_workingDirectory now has a valid file name
1475
public static String getDefaultWorkingDirectory() {
1477
String defaultDirectory = new String("");
1479
boolean success = false;
1483
while (i < MAX_DEFAULT_DIRECTORIES && !success) {
1485
File f = new File(System.getProperty("user.home"), "Ensemble-" + i);
1487
if (!f.exists() && f.getParentFile().canWrite()) {
1488
defaultDirectory = f.getPath();
1496
defaultDirectory = new String("");
1497
// should we print an error or something?
1500
return defaultDirectory;
1504
* Output a representation of this classifier
1506
* @return a string representation of the classifier
1508
public String toString() {
1509
// We just print out the models which were selected, and the number
1510
// of times each was selected.
1511
String result = new String();
1512
if (m_chosen_models != null) {
1513
for (int i = 0; i < m_chosen_models.length; ++i) {
1514
result += m_chosen_model_weights[i];
1515
result += " " + m_chosen_models[i].getStringRepresentation()
1519
result = "No models selected.";
1525
* Cache predictions for the individual base classifiers in the ensemble
1526
* with respect to the given dataset. This is used so that when testing a
1527
* large ensemble on a test set, we don't have to keep the models in memory.
1529
* @param test The instances for which to cache predictions.
1530
* @throws Exception if something goes wrong
1532
private void cachePredictions(Instances test) throws Exception {
// NOTE(review): this chunk is interleaved with stray line-number tokens and
// several original lines (closing braces, println string continuations) are
// missing - extraction damage; restore this method from the upstream file
// before compiling.
1533
m_cachedPredictions = new HashMap();
1534
Evaluation evalModel = null;
1535
Instances originalInstances = null;
1536
// If the verbose flag is set, we'll also print out the performances of
1537
// all the individual models w.r.t. this test set while we're at it.
1538
boolean printModelPerformances = getVerboseOutput();
1539
if (printModelPerformances) {
1540
// To get performances, we need to keep the class attribute.
1541
originalInstances = new Instances(test);
1544
// For each model, we'll go through the dataset and get predictions.
1545
// The idea is we want to only have one model in memory at a time, so
1547
// load one model in to memory, get all its predictions, and add them to
1549
// hash map. Then we can release it from memory and move on to the next.
1550
for (int i = 0; i < m_chosen_models.length; ++i) {
1551
if (printModelPerformances) {
1552
// If we're going to print predictions, we need to make a new
1553
// Evaluation object.
1554
evalModel = new Evaluation(originalInstances);
1557
Date startTime = new Date();
1559
// Load the model in to memory.
1560
m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
1561
// Now loop through all the instances and get the model's
1563
for (int j = 0; j < test.numInstances(); ++j) {
1564
Instance currentInstance = test.instance(j);
1565
// When we're looking for a cached prediction later, we'll only
1566
// have the non-class attributes, so we set the class missing
1568
// in order to make the string match up properly.
1569
currentInstance.setClassMissing();
1570
String stringInstance = currentInstance.toString();
1572
// When we come in here with the first model, the instance will
1574
// yet be part of the map.
1575
if (!m_cachedPredictions.containsKey(stringInstance)) {
1576
// The instance isn't in the map yet, so add it.
1577
// For each instance, we store a two-dimensional array - the
1579
// index is over all the models in the ensemble, and the
1581
// index is over the (i.e., typical prediction array).
1582
int predSize = test.classAttribute().isNumeric() ? 1 : test
1583
.classAttribute().numValues();
1584
double predictionArray[][] = new double[m_chosen_models.length][predSize];
1585
m_cachedPredictions.put(stringInstance, predictionArray);
1587
// Get the array from the map which is associated with this
1589
double predictions[][] = (double[][]) m_cachedPredictions
1590
.get(stringInstance);
1591
// And add our model's prediction for it.
1592
predictions[i] = m_chosen_models[i].getAveragePrediction(test
1595
if (printModelPerformances) {
1596
evalModel.evaluateModelOnceAndRecordPrediction(
1597
predictions[i], originalInstances.instance(j));
1600
// Now we're done with model #i, so we can release it.
1601
m_chosen_models[i].releaseModel();
1603
Date endTime = new Date();
1604
long diff = endTime.getTime() - startTime.getTime();
1607
System.out.println("Test time for "
1608
+ m_chosen_models[i].getStringRepresentation()
1611
if (printModelPerformances) {
1612
String output = new String(m_chosen_models[i]
1613
.getStringRepresentation()
1615
output += "\tRMSE:" + evalModel.rootMeanSquaredError();
1616
output += "\tACC:" + evalModel.pctCorrect();
1617
if (test.numClasses() == 2) {
1618
// For multiclass problems, we could print these too, but
1620
// not clear which class we should use in that case... so
1622
// we only print these metrics for binary classification
1624
output += "\tROC:" + evalModel.areaUnderROC(1);
1625
output += "\tPREC:" + evalModel.precision(1);
1626
output += "\tFSCR:" + evalModel.fMeasure(1);
1628
System.out.println(output);
1634
* Return the technical information. There is actually another
1635
* paper that describes our current method of CV for this classifier
1636
* TODO: Cite Technical report when published
1638
* @return the technical information about this class
1640
public TechnicalInformation getTechnicalInformation() {
1642
TechnicalInformation result;
1644
result = new TechnicalInformation(Type.INPROCEEDINGS);
1645
result.setValue(Field.AUTHOR, "Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes");
1646
result.setValue(Field.TITLE, "Ensemble Selection from Libraries of Models");
1647
result.setValue(Field.BOOKTITLE, "21st International Conference on Machine Learning");
1648
result.setValue(Field.YEAR, "2004");
1654
* Executes the classifier from commandline.
1657
* should contain the following arguments: -t training file [-T
1658
* test file] [-c class index]
1660
public static void main(String[] argv) {
1664
String options[] = (String[]) argv.clone();
1666
// do we get the input from XML instead of normal parameters?
1667
String xml = Utils.getOption("xml", options);
1668
if (!xml.equals(""))
1669
options = new XMLOptions(xml).toArray();
1671
String trainFileName = Utils.getOption('t', options);
1672
String objectInputFileName = Utils.getOption('l', options);
1673
String testFileName = Utils.getOption('T', options);
1675
if (testFileName.length() != 0 && objectInputFileName.length() != 0
1676
&& trainFileName.length() == 0) {
1678
System.out.println("Caching predictions");
1680
EnsembleSelection classifier = null;
1682
BufferedReader testReader = new BufferedReader(new FileReader(
1685
// Set up the Instances Object
1687
int classIndex = -1;
1688
String classIndexString = Utils.getOption('c', options);
1689
if (classIndexString.length() != 0) {
1690
classIndex = Integer.parseInt(classIndexString);
1693
test = new Instances(testReader, 1);
1694
if (classIndex != -1) {
1695
test.setClassIndex(classIndex - 1);
1697
test.setClassIndex(test.numAttributes() - 1);
1699
if (classIndex > test.numAttributes()) {
1700
throw new Exception("Index of class attribute too large.");
1703
while (test.readInstance(testReader)) {
1708
// Now yoink the EnsembleSelection Object from the fileSystem
1710
InputStream is = new FileInputStream(objectInputFileName);
1711
if (objectInputFileName.endsWith(".gz")) {
1712
is = new GZIPInputStream(is);
1716
if (!(objectInputFileName.endsWith("UpdateableClassifier.koml") && KOML
1718
ObjectInputStream objectInputStream = new ObjectInputStream(
1720
classifier = (EnsembleSelection) objectInputStream
1722
objectInputStream.close();
1724
BufferedInputStream xmlInputStream = new BufferedInputStream(
1726
classifier = (EnsembleSelection) KOML.read(xmlInputStream);
1727
xmlInputStream.close();
1730
String workingDir = Utils.getOption('W', argv);
1731
if (!workingDir.equals("")) {
1732
classifier.setWorkingDirectory(new File(workingDir));
1735
classifier.setDebug(Utils.getFlag('D', argv));
1736
classifier.setVerboseOutput(Utils.getFlag('O', argv));
1738
classifier.cachePredictions(test);
1740
// Now we write the model back out to the file system.
1741
String objectOutputFileName = objectInputFileName;
1742
OutputStream os = new FileOutputStream(objectOutputFileName);
1744
if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
1745
.endsWith(".koml") && KOML.isPresent()))) {
1746
if (objectOutputFileName.endsWith(".gz")) {
1747
os = new GZIPOutputStream(os);
1749
ObjectOutputStream objectOutputStream = new ObjectOutputStream(
1751
objectOutputStream.writeObject(classifier);
1752
objectOutputStream.flush();
1753
objectOutputStream.close();
1757
BufferedOutputStream xmlOutputStream = new BufferedOutputStream(
1759
if (objectOutputFileName.endsWith(".xml")) {
1760
XMLSerialization xmlSerial = new XMLClassifier();
1761
xmlSerial.write(xmlOutputStream, classifier);
1763
// whether KOML is present has already been checked
1764
// if not present -> ".koml" is interpreted as binary - see
1766
if (objectOutputFileName.endsWith(".koml")) {
1767
KOML.write(xmlOutputStream, classifier);
1769
xmlOutputStream.close();
1774
System.out.println(Evaluation.evaluateModel(
1775
new EnsembleSelection(), argv));
1777
} catch (Exception e) {
1778
if ( (e.getMessage() != null)
1779
&& (e.getMessage().indexOf("General options") == -1) )
1780
e.printStackTrace();
1782
System.err.println(e.getMessage());