2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* EnsembleSelection.java
19
* Copyright (C) 2006 David Michael
23
package weka.classifiers.meta.ensembleSelection;
25
import weka.classifiers.Classifier;
26
import weka.classifiers.EnsembleLibraryModel;
27
import weka.core.Instance;
28
import weka.core.Instances;
29
import weka.core.Utils;
32
import java.io.FileInputStream;
33
import java.io.FileOutputStream;
34
import java.io.IOException;
35
import java.io.ObjectInputStream;
36
import java.io.ObjectOutput;
37
import java.io.ObjectOutputStream;
38
import java.io.Serializable;
39
import java.io.UnsupportedEncodingException;
40
import java.util.Date;
41
import java.util.zip.Adler32;
44
* This class represents a library model that is used for EnsembleSelection. At
45
* this level the concept of cross validation is abstracted away. This class
46
* keeps track of the performance statistics and bookkeeping information for its
47
* "model type" across all the CV folds. By "model type", I mean the
48
* combination of both the Classifier type (e.g. J48), and its set of parameters
49
* (e.g. -C 0.5 -X 1 -Y 5). So for example, if you are using 5 fold cross
50
* validation, this model will keep an array of classifiers[] of length 5 and
51
* will keep track of their performances accordingly. This class also has
52
* methods to deal with serializing all of this information into the .elm file
53
* that will represent this model.
55
* Also it is worth mentioning that another important function of this class is
56
* to track all of the dataset information that was used to create this model.
57
* This is because we want to protect users from doing foreseeably bad things.
58
* e.g., trying to build an ensemble for a dataset with models that were trained
59
* on the wrong partitioning of the dataset. This could lead to artificially high
60
* performance due to the fact that instances used for the test set to gauge
61
* performance could have accidentally been used to train the base classifiers.
62
* So in a nutshell, we are preventing people from unintentionally "cheating" by
63
* enforcing that the seed, #folds, validation ratio, and the checksum of the
64
* Instances.toString() method ALL match exactly. Otherwise we throw an
67
* @author Robert Jung (mrbobjung@gmail.com)
68
* @version $Revision: 1.1 $
70
public class EnsembleSelectionLibraryModel
71
extends EnsembleLibraryModel
72
implements Serializable {
75
* This is the serialVersionUID that SHOULD stay the same so that future
76
* modified versions of this class will be backwards compatible with older
79
private static final long serialVersionUID = -6426075459862947640L;
81
/** The default file extension for ensemble library models */
82
public static final String FILE_EXTENSION = ".elm";
85
private Classifier[] m_models = null;
87
/** The seed that was used to create this model */
91
* The checksum of the instances.arff object that was used to create this
94
private String m_checksum;
96
/** The validation ratio that was used to create this model */
97
private double m_validationRatio;
100
* The number of folds, or number of CV models that was used to create this
106
* The .elm file name that this model should be saved/loaded to/from
108
private String m_fileName;
111
* The debug flag as propagated from the main EnsembleSelection class.
113
public transient boolean m_Debug = true;
116
* the validation predictions of this model. First index for the instance.
117
* third is for the class (we use distributionForInstance).
119
private double[][] m_validationPredictions = null; // = new double[0][0];
122
* Default Constructor
124
public EnsembleSelectionLibraryModel() {
128
* Constructor for LibraryModel
130
* @param classifier the classifier to use
131
* @param seed the random seed value
132
* @param checksum the checksum
133
* @param validationRatio the ratio to use
134
* @param folds the number of folds to use
136
public EnsembleSelectionLibraryModel(Classifier classifier, int seed,
137
String checksum, double validationRatio, int folds) {
142
m_checksum = checksum;
143
m_validationRatio = validationRatio;
149
* This is used to propagate the m_Debug flag of the EnsembleSelection
150
* classifier to this class. There are things we would want to print out
153
* @param debug if true additional information is output
155
public void setDebug(boolean debug) {
160
* Returns the average of the prediction of the models across all folds.
162
* @param instance the instance to get predictions for
163
* @return the average prediction
164
* @throws Exception if something goes wrong
166
public double[] getAveragePrediction(Instance instance) throws Exception {
168
// Return the average prediction from all classifiers that make up
170
double average[] = new double[instance.numClasses()];
171
for (int i = 0; i < m_folds; ++i) {
172
// Some models alter the instance (MultiLayerPerceptron), so we need
174
Instance temp_instance = (Instance) instance.copy();
175
double[] pred = getFoldPrediction(temp_instance, i);
177
// Some models have bugs whereby they can return a null
179
// array (again, MultiLayerPerceptron). We return null, and this
180
// should be handled above in EnsembleSelection.
181
System.err.println("Null validation predictions given: "
182
+ getStringRepresentation());
186
// The first time through the loop, just use the first returned
187
// prediction array. Just a simple optimization.
190
// For the rest, add the prediction to the average array.
191
for (int j = 0; j < pred.length; ++j) {
192
average[j] += pred[j];
196
if (instance.classAttribute().isNominal()) {
197
// Normalize predictions for classes to add up to 1.
198
Utils.normalize(average);
200
average[0] /= m_folds;
208
* @param classifier the classifier to use
210
public EnsembleSelectionLibraryModel(Classifier classifier) {
215
* Returns prediction of the classifier for the specified fold.
218
* instance for which to make a prediction.
220
* fold number of the classifier to use.
221
* @return the prediction for the classes
222
* @throws Exception if prediction fails
224
public double[] getFoldPrediction(Instance instance, int fold)
227
return m_models[fold].distributionForInstance(instance);
231
* Creates the model. If there are n folds, it constructs n classifiers
232
* using the current Classifier class and options. If the model has already
233
* been created or loaded, starts fresh.
235
* @param data the data to work with
236
* @param hillclimbData the data for hillclimbing
237
* @param dataDirectoryName the directory to use
238
* @param algorithm the type of algorithm
239
* @throws Exception if something goes wrong
241
public void createModel(Instances[] data, Instances[] hillclimbData,
242
String dataDirectoryName, int algorithm) throws Exception {
244
String modelFileName = getFileName(getStringRepresentation());
246
File modelFile = new File(dataDirectoryName, modelFileName);
248
String relativePath = (new File(dataDirectoryName)).getName()
249
+ File.separatorChar + modelFileName;
250
// if (m_Debug) System.out.println("setting relative path to:
252
setFileName(relativePath);
254
if (!modelFile.exists()) {
256
Date startTime = new Date();
258
String lockFileName = EnsembleSelectionLibraryModel
259
.getFileName(getStringRepresentation());
260
lockFileName = lockFileName.substring(0, lockFileName.length() - 3)
262
File lockFile = new File(dataDirectoryName, lockFileName);
264
if (lockFile.exists()) {
266
System.out.println("Detected lock file. Skipping: "
268
throw new Exception("Lock File Detected: " + lockFile.getName());
270
} else { // if (algorithm ==
271
// EnsembleSelection.ALGORITHM_BUILD_LIBRARY) {
272
// This lock file lets other computers that might be sharing the
274
// system that this model is already being trained so they know
276
// and train other models.
278
if (lockFile.createNewFile()) {
282
.println("lock file created: " + lockFileName);
285
System.out.println("Creating model in locked mode: "
286
+ modelFile.getPath());
288
m_models = new Classifier[m_folds];
289
for (int i = 0; i < m_folds; ++i) {
292
m_models[i] = Classifier.forName(getModelClass()
294
m_models[i].setOptions(getOptions());
295
} catch (Exception e) {
296
throw new Exception("Invalid Options: "
302
for (int i = 0; i < m_folds; ++i) {
305
} catch (Exception e) {
306
throw new Exception("Could not Train: "
310
Date endTime = new Date();
311
int diff = (int) (endTime.getTime() - startTime.getTime());
313
// We don't need the actual model for hillclimbing. To save
317
// if (!invalidModels.contains(model)) {
318
// EnsembleLibraryModel.saveModel(dataDirectory.getPath(),
320
// model.releaseModel();
323
System.out.println("Train time for " + modelFileName
328
.println("Generating validation set predictions");
330
startTime = new Date();
333
for (int i = 0; i < m_folds; ++i) {
334
total += hillclimbData[i].numInstances();
337
m_validationPredictions = new double[total][];
340
for (int i = 0; i < m_folds; ++i) {
341
for (int j = 0; j < hillclimbData[i].numInstances(); ++j) {
342
Instance temp = (Instance) hillclimbData[i]
343
.instance(j).copy();// new
344
// Instance(m_hillclimbData[i].instance(j));
345
// must copy the instance because SOME classifiers
346
// (I'm not pointing fingers...
347
// MULTILAYERPERCEPTRON)
348
// change the instance!
350
m_validationPredictions[preds_index] = getFoldPrediction(
353
if (m_validationPredictions[preds_index] == null) {
355
"Null validation predictions given: "
356
+ getStringRepresentation());
363
endTime = new Date();
364
diff = (int) (endTime.getTime() - startTime.getTime());
366
// if (m_Debug) System.out.println("Generated a validation
367
// set array of size: "+m_validationPredictions.length);
370
.println("Time to create validation predictions was: "
373
EnsembleSelectionLibraryModel.saveModel(dataDirectoryName,
377
System.out.println("deleting lock file: "
385
.println("Could not create lock file. Skipping: "
388
"Could not create lock file. Skipping: "
389
+ lockFile.getName());
396
// This branch is responsible for loading a model from a .elm file
399
System.out.println("Loading model: " + modelFile.getPath());
400
// now we need to check to see if the model is valid, if so then
402
Date startTime = new Date();
404
EnsembleSelectionLibraryModel newModel = loadModel(modelFile
407
if (!newModel.getStringRepresentation().equals(
408
getStringRepresentation()))
409
throw new EnsembleModelMismatchException(
410
"String representations "
411
+ newModel.getStringRepresentation() + " and "
412
+ getStringRepresentation() + " not equal");
414
if (!newModel.getChecksum().equals(getChecksum()))
415
throw new EnsembleModelMismatchException("Checksums "
416
+ newModel.getChecksum() + " and " + getChecksum()
419
if (newModel.getSeed() != getSeed())
420
throw new EnsembleModelMismatchException("Seeds "
421
+ newModel.getSeed() + " and " + getSeed()
424
if (newModel.getFolds() != getFolds())
425
throw new EnsembleModelMismatchException("Folds "
426
+ newModel.getFolds() + " and " + getFolds()
429
if (newModel.getValidationRatio() != getValidationRatio())
430
throw new EnsembleModelMismatchException("Validation Ratios "
431
+ newModel.getValidationRatio() + " and "
432
+ getValidationRatio() + " not equal");
434
// setFileName(modelFileName);
436
m_models = newModel.getModels();
437
m_validationPredictions = newModel.getValidationPredictions();
439
Date endTime = new Date();
440
int diff = (int) (endTime.getTime() - startTime.getTime());
442
System.out.println("Time to load " + modelFileName + " was: "
448
* The purpose of this method is to "rehydrate" the classifier object for
449
* this library model from the filesystem.
451
* @param workingDirectory the working directory to use
453
public void rehydrateModel(String workingDirectory) {
455
if (m_models == null) {
457
File file = new File(workingDirectory, m_fileName);
460
System.out.println("Rehydrating Model: " + file.getPath());
461
EnsembleSelectionLibraryModel model = EnsembleSelectionLibraryModel
462
.loadModel(file.getPath());
464
m_models = model.getModels();
470
* Releases the model from memory. TODO - need to be saving these so we can
471
* retrieve them later!!
473
public void releaseModel() {
475
* if (m_unsaved) { saveModel(); }
481
* Train the classifier for the specified fold on the given data
483
* @param trainData the data to train with
484
* @param fold the fold number
485
* @throws Exception if something goes wrong, e.g., out of memory
487
public void train(Instances trainData, int fold) throws Exception {
488
if (m_models != null) {
491
// OK, this is it... this is the point where our code surrenders
492
// to the weka classifiers.
493
m_models[fold].buildClassifier(trainData);
494
} catch (Throwable t) {
495
m_models[fold] = null;
497
"Exception caught while training: (null could mean out of memory)"
502
throw new Exception("Cannot train: model was null");
503
// TODO: throw Exception?
510
* @param seed the seed value
512
public void setSeed(int seed) {
519
* @return the seed value
521
public int getSeed() {
526
* Sets the validation set ratio (only meaningful if folds == 1)
528
* @param validationRatio the new ratio
530
public void setValidationRatio(double validationRatio) {
531
m_validationRatio = validationRatio;
535
* get validationRatio
537
* @return the current ratio
539
public double getValidationRatio() {
540
return m_validationRatio;
544
* Set the number of folds for cross validation. The number of folds also
545
* indicates how many classifiers will be built to represent this model.
547
* @param folds the number of folds to use
549
public void setFolds(int folds) {
554
* get the number of folds
556
* @return the current number of folds
558
public int getFolds() {
565
* @param instancesChecksum the new checksum
567
public void setChecksum(String instancesChecksum) {
568
m_checksum = instancesChecksum;
574
* @return the current checksum
576
public String getChecksum() {
581
* Returns the array of classifiers
583
* @return the current models
585
public Classifier[] getModels() {
590
* Sets the .elm file name for this library model
592
* @param fileName the new filename
594
public void setFileName(String fileName) {
595
m_fileName = fileName;
599
* Gets a checksum for the string defining this classifier. This is used to
600
* preserve uniqueness in the classifier names.
602
* @param string the classifier definition
603
* @return the checksum string
605
public static String getStringChecksum(String string) {
607
String checksumString = null;
611
Adler32 checkSummer = new Adler32();
613
byte[] utf8 = string.toString().getBytes("UTF8");
616
checkSummer.update(utf8);
617
checksumString = Long.toHexString(checkSummer.getValue());
619
} catch (UnsupportedEncodingException e) {
620
// TODO Auto-generated catch block
624
return checksumString;
628
* The purpose of this method is to get an appropriate file name for a model
629
* based on its string representation of a model. All generated filenames
630
* are limited to less than 128 characters and all of them will end with a
631
* 32 bit Adler-32 checksum value of their string representation to try to maintain
632
* some uniqueness of file names.
634
* @param stringRepresentation string representation of model
635
* @return unique filename
637
public static String getFileName(String stringRepresentation) {
639
// Get rid of space and quote marks (windows doesn't like them)
640
String fileName = stringRepresentation.trim().replace(' ', '_')
643
if (fileName.length() > 115) {
645
fileName = fileName.substring(0, 115);
649
fileName += getStringChecksum(stringRepresentation)
650
+ EnsembleSelectionLibraryModel.FILE_EXTENSION;
656
* Saves the given model to the specified file.
658
* @param directory the directory to save the model to
659
* @param model the model to save
661
public static void saveModel(String directory,
662
EnsembleSelectionLibraryModel model) {
665
String fileName = getFileName(model.getStringRepresentation());
667
File file = new File(directory, fileName);
669
// System.out.println("Saving model: "+file.getPath());
671
// model.setFileName(new String(file.getPath()));
673
// Serialize to a file
674
ObjectOutput out = new ObjectOutputStream(
675
new FileOutputStream(file));
676
out.writeObject(model);
680
} catch (IOException e) {
687
* loads the specified model
689
* @param modelFilePath the path of the model
692
public static EnsembleSelectionLibraryModel loadModel(String modelFilePath) {
694
EnsembleSelectionLibraryModel model = null;
698
File file = new File(modelFilePath);
700
ObjectInputStream in = new ObjectInputStream(new FileInputStream(
703
model = (EnsembleSelectionLibraryModel) in.readObject();
707
} catch (ClassNotFoundException e) {
711
} catch (IOException e) {
721
* Problems persist in this code so we left it commented out. The intent was
722
* to create the methods necessary for custom serialization to allow for
723
* forwards/backwards compatibility of .elm files across multiple versions
724
* of this classifier. The main problem however is that these methods do not
725
* appear to be called. I'm not sure what the problem is, but this would be
726
* a great feature. If anyone is a seasoned veteran of this serialization
727
* stuff, please help!
729
* private void writeObject(ObjectOutputStream stream) throws IOException {
730
* //stream.defaultWriteObject(); //stream.writeObject(b);
732
* //first serialize the LibraryModel fields
734
* //super.writeObject(stream);
736
* //now serialize the LibraryModel fields
738
* stream.writeObject(m_Classifier);
740
* stream.writeObject(m_DescriptionText);
742
* stream.writeObject(m_ErrorText);
744
* stream.writeObject(new Boolean(m_OptionsWereValid));
746
* stream.writeObject(m_StringRepresentation);
748
* stream.writeObject(m_models);
751
* //now serialize the EnsembleLibraryModel fields //stream.writeObject(new
754
* stream.writeObject(new Integer(m_seed));
756
* stream.writeObject(m_checksum);
758
* stream.writeObject(new Double(m_validationRatio));
760
* stream.writeObject(new Integer(m_folds));
762
* stream.writeObject(m_fileName);
764
* stream.writeObject(new Boolean(m_isTrained));
767
* if (m_validationPredictions == null) {
770
* if (m_Debug) System.out.println("Saving
771
* "+m_validationPredictions.length+" indexed array");
772
* stream.writeObject(m_validationPredictions);
775
* private void readObject(ObjectInputStream stream) throws IOException,
776
* ClassNotFoundException { //stream.defaultReadObject(); //b = (String)
777
* stream.readObject();
779
* //super.readObject(stream);
781
* //deserialize the LibraryModel fields m_Classifier =
782
* (Classifier)stream.readObject();
784
* m_DescriptionText = (String)stream.readObject();
786
* m_ErrorText = (String)stream.readObject();
788
* m_OptionsWereValid = ((Boolean)stream.readObject()).booleanValue();
790
* m_StringRepresentation = (String)stream.readObject();
794
* //now deserialize the EnsembleLibraryModel fields m_models =
795
* (Classifier[])stream.readObject();
797
* m_seed = ((Integer)stream.readObject()).intValue();
799
* m_checksum = (String)stream.readObject();
801
* m_validationRatio = ((Double)stream.readObject()).doubleValue();
803
* m_folds = ((Integer)stream.readObject()).intValue();
805
* m_fileName = (String)stream.readObject();
807
* m_isTrained = ((Boolean)stream.readObject()).booleanValue();
809
* m_validationPredictions = (double[][])stream.readObject();
811
* if (m_Debug) System.out.println("Loaded
812
* "+m_validationPredictions.length+" indexed array"); }
817
* getter for validation predictions
819
* @return the current validation predictions
821
public double[][] getValidationPredictions() {
822
return m_validationPredictions;
826
* setter for validation predictions
828
* @param predictions the new validation predictions
830
public void setValidationPredictions(double[][] predictions) {
832
System.out.println("Saving validation array of size "
833
+ predictions.length);
834
m_validationPredictions = new double[predictions.length][];
835
System.arraycopy(predictions, 0, m_validationPredictions, 0,
b'\\ No newline at end of file'