~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/meta/ensembleSelection/EnsembleSelectionLibraryModel.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    EnsembleSelection.java
 
19
 *    Copyright (C) 2006 David Michael
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers.meta.ensembleSelection;
 
24
 
 
25
import weka.classifiers.Classifier;
 
26
import weka.classifiers.EnsembleLibraryModel;
 
27
import weka.core.Instance;
 
28
import weka.core.Instances;
 
29
import weka.core.Utils;
 
30
 
 
31
import java.io.File;
 
32
import java.io.FileInputStream;
 
33
import java.io.FileOutputStream;
 
34
import java.io.IOException;
 
35
import java.io.ObjectInputStream;
 
36
import java.io.ObjectOutput;
 
37
import java.io.ObjectOutputStream;
 
38
import java.io.Serializable;
 
39
import java.io.UnsupportedEncodingException;
 
40
import java.util.Date;
 
41
import java.util.zip.Adler32;
 
42
 
 
43
/**
 
44
 * This class represents a library model that is used for EnsembleSelection. At
 
45
 * this level the concept of cross validation is abstracted away. This class
 
46
 * keeps track of the performance statistics and bookkeeping information for its
 
47
 * "model type" accross all the CV folds. By "model type", I mean the
 
48
 * combination of both the Classifier type (e.g. J48), and its set of parameters
 
49
 * (e.g. -C 0.5 -X 1 -Y 5). So for example, if you are using 5 fold cross
 
50
 * validaiton, this model will keep an array of classifiers[] of length 5 and
 
51
 * will keep track of their performances accordingly. This class also has
 
52
 * methods to deal with serializing all of this information into the .elm file
 
53
 * that will represent this model.
 
54
 * <p/>
 
55
 * Also it is worth mentioning that another important function of this class is
 
56
 * to track all of the dataset information that was used to create this model.
 
57
 * This is because we want to protect users from doing foreseeably bad things.
 
58
 * e.g., trying to build an ensemble for a dataset with models that were trained
 
59
 * on the wrong partitioning of the dataset. This could lead to artificially high
 
60
 * performance due to the fact that instances used for the test set to gauge
 
61
 * performance could have accidentally been used to train the base classifiers.
 
62
 * So in a nutshell, we are preventing people from unintentionally "cheating" by
 
63
 * enforcing that the seed, #folds, validation ration, and the checksum of the 
 
64
 * Instances.toString() method ALL match exactly.  Otherwise we throw an 
 
65
 * exception.
 
66
 * 
 
67
 * @author  Robert Jung (mrbobjung@gmail.com)
 
68
 * @version $Revision: 1.1 $ 
 
69
 */
 
70
public class EnsembleSelectionLibraryModel
 
71
  extends EnsembleLibraryModel
 
72
  implements Serializable {
 
73
  
 
74
  /**
 
75
   * This is the serialVersionUID that SHOULD stay the same so that future
 
76
   * modified versions of this class will be backwards compatible with older
 
77
   * model versions.
 
78
   */
 
79
  private static final long serialVersionUID = -6426075459862947640L;
 
80
  
 
81
  /** The default file extension for ensemble library models */
 
82
  public static final String FILE_EXTENSION = ".elm";
 
83
  
 
84
  /** the models */
 
85
  private Classifier[] m_models = null;
 
86
  
 
87
  /** The seed that was used to create this model */
 
88
  private int m_seed;
 
89
  
 
90
  /**
 
91
   * The checksum of the instances.arff object that was used to create this
 
92
   * model
 
93
   */
 
94
  private String m_checksum;
 
95
  
 
96
  /** The validation ratio that was used to create this model */
 
97
  private double m_validationRatio;
 
98
  
 
99
  /**
 
100
   * The number of folds, or number of CV models that was used to create this
 
101
   * "model"
 
102
   */
 
103
  private int m_folds;
 
104
  
 
105
  /**
 
106
   * The .elm file name that this model should be saved/loaded to/from
 
107
   */
 
108
  private String m_fileName;
 
109
  
 
110
  /**
 
111
   * The debug flag as propagated from the main EnsembleSelection class.
 
112
   */
 
113
  public transient boolean m_Debug = true;
 
114
  
 
115
  /**
 
116
   * the validation predictions of this model. First index for the instance.
 
117
   * third is for the class (we use distributionForInstance).
 
118
   */
 
119
  private double[][] m_validationPredictions = null; // = new double[0][0];
 
120
  
 
121
  /**
 
122
   * Default Constructor
 
123
   */
 
124
  public EnsembleSelectionLibraryModel() {
 
125
  }
 
126
  
 
127
  /**
 
128
   * Constructor for LibaryModel
 
129
   * 
 
130
   * @param classifier          the classifier to use
 
131
   * @param seed                the random seed value
 
132
   * @param checksum            the checksum
 
133
   * @param validationRatio     the ration to use
 
134
   * @param folds               the number of folds to use
 
135
   */
 
136
  public EnsembleSelectionLibraryModel(Classifier classifier, int seed,
 
137
      String checksum, double validationRatio, int folds) {
 
138
    
 
139
    super(classifier);
 
140
    
 
141
    m_seed = seed;
 
142
    m_checksum = checksum;
 
143
    m_validationRatio = validationRatio;
 
144
    m_models = null;
 
145
    m_folds = folds;
 
146
  }
 
147
  
 
148
  /**
 
149
   * This is used to propagate the m_Debug flag of the EnsembleSelection
 
150
   * classifier to this class. There are things we would want to print out
 
151
   * here also.
 
152
   * 
 
153
   * @param debug       if true additional information is output
 
154
   */
 
155
  public void setDebug(boolean debug) {
 
156
    m_Debug = debug;
 
157
  }
 
158
  
 
159
  /**
 
160
   * Returns the average of the prediction of the models across all folds.
 
161
   * 
 
162
   * @param instance    the instance to get predictions for
 
163
   * @return            the average prediction
 
164
   * @throws Exception  if something goes wrong
 
165
   */
 
166
  public double[] getAveragePrediction(Instance instance) throws Exception {
 
167
    
 
168
    // Return the average prediction from all classifiers that make up
 
169
    // this model.
 
170
    double average[] = new double[instance.numClasses()];
 
171
    for (int i = 0; i < m_folds; ++i) {
 
172
      // Some models alter the instance (MultiLayerPerceptron), so we need
 
173
      // to copy it.
 
174
      Instance temp_instance = (Instance) instance.copy();
 
175
      double[] pred = getFoldPrediction(temp_instance, i);
 
176
      if (pred == null) {
 
177
        // Some models have bugs whereby they can return a null
 
178
        // prediction
 
179
        // array (again, MultiLayerPerceptron). We return null, and this
 
180
        // should be handled above in EnsembleSelection.
 
181
        System.err.println("Null validation predictions given: "
 
182
            + getStringRepresentation());
 
183
        return null;
 
184
      }
 
185
      if (i == 0) {
 
186
        // The first time through the loop, just use the first returned
 
187
        // prediction array. Just a simple optimization.
 
188
        average = pred;
 
189
      } else {
 
190
        // For the rest, add the prediction to the average array.
 
191
        for (int j = 0; j < pred.length; ++j) {
 
192
          average[j] += pred[j];
 
193
        }
 
194
      }
 
195
    }
 
196
    if (instance.classAttribute().isNominal()) {
 
197
      // Normalize predictions for classes to add up to 1.
 
198
      Utils.normalize(average);
 
199
    } else {
 
200
      average[0] /= m_folds;
 
201
    }
 
202
    return average;
 
203
  }
 
204
  
 
205
  /**
 
206
   * Basic Constructor
 
207
   * 
 
208
   * @param classifier  the classifier to use
 
209
   */
 
210
  public EnsembleSelectionLibraryModel(Classifier classifier) {
 
211
    super(classifier);
 
212
  }
 
213
  
 
214
  /**
 
215
   * Returns prediction of the classifier for the specified fold.
 
216
   * 
 
217
   * @param instance
 
218
   *            instance for which to make a prediction.
 
219
   * @param fold
 
220
   *            fold number of the classifier to use.
 
221
   * @return the prediction for the classes
 
222
   * @throws Exception if prediction fails
 
223
   */
 
224
  public double[] getFoldPrediction(Instance instance, int fold)
 
225
    throws Exception {
 
226
    
 
227
    return m_models[fold].distributionForInstance(instance);
 
228
  }
 
229
  
 
230
  /**
 
231
   * Creates the model. If there are n folds, it constructs n classifiers
 
232
   * using the current Classifier class and options. If the model has already
 
233
   * been created or loaded, starts fresh.
 
234
   * 
 
235
   * @param data                the data to work with
 
236
   * @param hillclimbData       the data for hillclimbing
 
237
   * @param dataDirectoryName   the directory to use
 
238
   * @param algorithm           the type of algorithm
 
239
   * @throws Exception          if something goeds wrong
 
240
   */
 
241
  public void createModel(Instances[] data, Instances[] hillclimbData,
 
242
      String dataDirectoryName, int algorithm) throws Exception {
 
243
    
 
244
    String modelFileName = getFileName(getStringRepresentation());
 
245
    
 
246
    File modelFile = new File(dataDirectoryName, modelFileName);
 
247
    
 
248
    String relativePath = (new File(dataDirectoryName)).getName()
 
249
    + File.separatorChar + modelFileName;
 
250
    // if (m_Debug) System.out.println("setting relative path to:
 
251
    // "+relativePath);
 
252
    setFileName(relativePath);
 
253
    
 
254
    if (!modelFile.exists()) {
 
255
      
 
256
      Date startTime = new Date();
 
257
      
 
258
      String lockFileName = EnsembleSelectionLibraryModel
 
259
      .getFileName(getStringRepresentation());
 
260
      lockFileName = lockFileName.substring(0, lockFileName.length() - 3)
 
261
      + "LCK";
 
262
      File lockFile = new File(dataDirectoryName, lockFileName);
 
263
      
 
264
      if (lockFile.exists()) {
 
265
        if (m_Debug)
 
266
          System.out.println("Detected lock file.  Skipping: "
 
267
              + lockFileName);
 
268
        throw new Exception("Lock File Detected: " + lockFile.getName());
 
269
        
 
270
      } else { // if (algorithm ==
 
271
        // EnsembleSelection.ALGORITHM_BUILD_LIBRARY) {
 
272
        // This lock file lets other computers that might be sharing the
 
273
        // same file
 
274
        // system that this model is already being trained so they know
 
275
        // to move ahead
 
276
        // and train other models.
 
277
        
 
278
        if (lockFile.createNewFile()) {
 
279
          
 
280
          if (m_Debug)
 
281
            System.out
 
282
            .println("lock file created: " + lockFileName);
 
283
          
 
284
          if (m_Debug)
 
285
            System.out.println("Creating model in locked mode: "
 
286
                + modelFile.getPath());
 
287
          
 
288
          m_models = new Classifier[m_folds];
 
289
          for (int i = 0; i < m_folds; ++i) {
 
290
            
 
291
            try {
 
292
              m_models[i] = Classifier.forName(getModelClass()
 
293
                  .getName(), null);
 
294
              m_models[i].setOptions(getOptions());
 
295
            } catch (Exception e) {
 
296
              throw new Exception("Invalid Options: "
 
297
                  + e.getMessage());
 
298
            }
 
299
          }
 
300
          
 
301
          try {
 
302
            for (int i = 0; i < m_folds; ++i) {
 
303
              train(data[i], i);
 
304
            }
 
305
          } catch (Exception e) {
 
306
            throw new Exception("Could not Train: "
 
307
                + e.getMessage());
 
308
          }
 
309
          
 
310
          Date endTime = new Date();
 
311
          int diff = (int) (endTime.getTime() - startTime.getTime());
 
312
          
 
313
          // We don't need the actual model for hillclimbing. To save
 
314
          // memory, release
 
315
          // it.
 
316
          
 
317
          // if (!invalidModels.contains(model)) {
 
318
          // EnsembleLibraryModel.saveModel(dataDirectory.getPath(),
 
319
          // model);
 
320
          // model.releaseModel();
 
321
          // }
 
322
          if (m_Debug)
 
323
            System.out.println("Train time for " + modelFileName
 
324
                + " was: " + diff);
 
325
          
 
326
          if (m_Debug)
 
327
            System.out
 
328
            .println("Generating validation set predictions");
 
329
          
 
330
          startTime = new Date();
 
331
          
 
332
          int total = 0;
 
333
          for (int i = 0; i < m_folds; ++i) {
 
334
            total += hillclimbData[i].numInstances();
 
335
          }
 
336
          
 
337
          m_validationPredictions = new double[total][];
 
338
          
 
339
          int preds_index = 0;
 
340
          for (int i = 0; i < m_folds; ++i) {
 
341
            for (int j = 0; j < hillclimbData[i].numInstances(); ++j) {
 
342
              Instance temp = (Instance) hillclimbData[i]
 
343
                                                       .instance(j).copy();// new
 
344
              // Instance(m_hillclimbData[i].instance(j));
 
345
              // must copy the instance because SOME classifiers
 
346
              // (I'm not pointing fingers...
 
347
              // MULTILAYERPERCEPTRON)
 
348
              // change the instance!
 
349
              
 
350
              m_validationPredictions[preds_index] = getFoldPrediction(
 
351
                  temp, i);
 
352
              
 
353
              if (m_validationPredictions[preds_index] == null) {
 
354
                throw new Exception(
 
355
                    "Null validation predictions given: "
 
356
                    + getStringRepresentation());
 
357
              }
 
358
              
 
359
              ++preds_index;
 
360
            }
 
361
          }
 
362
          
 
363
          endTime = new Date();
 
364
          diff = (int) (endTime.getTime() - startTime.getTime());
 
365
          
 
366
          // if (m_Debug) System.out.println("Generated a validation
 
367
          // set array of size: "+m_validationPredictions.length);
 
368
          if (m_Debug)
 
369
            System.out
 
370
            .println("Time to create validation predictions was: "
 
371
                + diff);
 
372
          
 
373
          EnsembleSelectionLibraryModel.saveModel(dataDirectoryName,
 
374
              this);
 
375
          
 
376
          if (m_Debug)
 
377
            System.out.println("deleting lock file: "
 
378
                + lockFileName);
 
379
          lockFile.delete();
 
380
          
 
381
        } else {
 
382
          
 
383
          if (m_Debug)
 
384
            System.out
 
385
            .println("Could not create lock file.  Skipping: "
 
386
                + lockFileName);
 
387
          throw new Exception(
 
388
              "Could not create lock file.  Skipping: "
 
389
              + lockFile.getName());
 
390
          
 
391
        }
 
392
        
 
393
      }
 
394
      
 
395
    } else {
 
396
      // This branch is responsible for loading a model from a .elm file
 
397
      
 
398
      if (m_Debug)
 
399
        System.out.println("Loading model: " + modelFile.getPath());
 
400
      // now we need to check to see if the model is valid, if so then
 
401
      // load it
 
402
      Date startTime = new Date();
 
403
      
 
404
      EnsembleSelectionLibraryModel newModel = loadModel(modelFile
 
405
          .getPath());
 
406
      
 
407
      if (!newModel.getStringRepresentation().equals(
 
408
          getStringRepresentation()))
 
409
        throw new EnsembleModelMismatchException(
 
410
            "String representations "
 
411
            + newModel.getStringRepresentation() + " and "
 
412
            + getStringRepresentation() + " not equal");
 
413
      
 
414
      if (!newModel.getChecksum().equals(getChecksum()))
 
415
        throw new EnsembleModelMismatchException("Checksums "
 
416
            + newModel.getChecksum() + " and " + getChecksum()
 
417
            + " not equal");
 
418
      
 
419
      if (newModel.getSeed() != getSeed())
 
420
        throw new EnsembleModelMismatchException("Seeds "
 
421
            + newModel.getSeed() + " and " + getSeed()
 
422
            + " not equal");
 
423
      
 
424
      if (newModel.getFolds() != getFolds())
 
425
        throw new EnsembleModelMismatchException("Folds "
 
426
            + newModel.getFolds() + " and " + getFolds()
 
427
            + " not equal");
 
428
      
 
429
      if (newModel.getValidationRatio() != getValidationRatio())
 
430
        throw new EnsembleModelMismatchException("Validation Ratios "
 
431
            + newModel.getValidationRatio() + " and "
 
432
            + getValidationRatio() + " not equal");
 
433
      
 
434
      // setFileName(modelFileName);
 
435
      
 
436
      m_models = newModel.getModels();
 
437
      m_validationPredictions = newModel.getValidationPredictions();
 
438
      
 
439
      Date endTime = new Date();
 
440
      int diff = (int) (endTime.getTime() - startTime.getTime());
 
441
      if (m_Debug)
 
442
        System.out.println("Time to load " + modelFileName + " was: "
 
443
            + diff);
 
444
    }
 
445
  }
 
446
  
 
447
  /**
 
448
   * The purpose of this method is to "rehydrate" the classifier object fot
 
449
   * this library model from the filesystem.
 
450
   * 
 
451
   * @param workingDirectory    the working directory to use
 
452
   */
 
453
  public void rehydrateModel(String workingDirectory) {
 
454
    
 
455
    if (m_models == null) {
 
456
      
 
457
      File file = new File(workingDirectory, m_fileName);
 
458
      
 
459
      if (m_Debug)
 
460
        System.out.println("Rehydrating Model: " + file.getPath());
 
461
      EnsembleSelectionLibraryModel model = EnsembleSelectionLibraryModel
 
462
      .loadModel(file.getPath());
 
463
      
 
464
      m_models = model.getModels();
 
465
      
 
466
    }
 
467
  }
 
468
  
 
469
  /**
 
470
   * Releases the model from memory. TODO - need to be saving these so we can
 
471
   * retrieve them later!!
 
472
   */
 
473
  public void releaseModel() {
 
474
    /*
 
475
     * if (m_unsaved) { saveModel(); }
 
476
     */
 
477
    m_models = null;
 
478
  }
 
479
  
 
480
  /** 
 
481
   * Train the classifier for the specified fold on the given data
 
482
   * 
 
483
   * @param trainData   the data to train with
 
484
   * @param fold        the fold number
 
485
   * @throws Exception  if something goes wrong, e.g., out of memory
 
486
   */
 
487
  public void train(Instances trainData, int fold) throws Exception {
 
488
    if (m_models != null) {
 
489
      
 
490
      try {
 
491
        // OK, this is it... this is the point where our code surrenders
 
492
        // to the weka classifiers.
 
493
        m_models[fold].buildClassifier(trainData);
 
494
      } catch (Throwable t) {
 
495
        m_models[fold] = null;
 
496
        throw new Exception(
 
497
            "Exception caught while training: (null could mean out of memory)"
 
498
            + t.getMessage());
 
499
      }
 
500
      
 
501
    } else {
 
502
      throw new Exception("Cannot train: model was null");
 
503
      // TODO: throw Exception?
 
504
    }
 
505
  }
 
506
  
 
507
  /**
 
508
   * Set the seed
 
509
   * 
 
510
   * @param seed        the seed value
 
511
   */
 
512
  public void setSeed(int seed) {
 
513
    m_seed = seed;
 
514
  }
 
515
  
 
516
  /**
 
517
   * Get the seed
 
518
   * 
 
519
   * @return the seed value
 
520
   */
 
521
  public int getSeed() {
 
522
    return m_seed;
 
523
  }
 
524
  
 
525
  /**
 
526
   * Sets the validation set ratio (only meaningful if folds == 1)
 
527
   * 
 
528
   * @param validationRatio     the new ration
 
529
   */
 
530
  public void setValidationRatio(double validationRatio) {
 
531
    m_validationRatio = validationRatio;
 
532
  }
 
533
  
 
534
  /**
 
535
   * get validationRatio
 
536
   * 
 
537
   * @return            the current ratio
 
538
   */
 
539
  public double getValidationRatio() {
 
540
    return m_validationRatio;
 
541
  }
 
542
  
 
543
  /**
 
544
   * Set the number of folds for cross validation. The number of folds also
 
545
   * indicates how many classifiers will be built to represent this model.
 
546
   * 
 
547
   * @param folds       the number of folds to use
 
548
   */
 
549
  public void setFolds(int folds) {
 
550
    m_folds = folds;
 
551
  }
 
552
  
 
553
  /**
 
554
   * get the number of folds
 
555
   * 
 
556
   * @return            the current number of folds
 
557
   */
 
558
  public int getFolds() {
 
559
    return m_folds;
 
560
  }
 
561
  
 
562
  /**
 
563
   * set the checksum
 
564
   * 
 
565
   * @param instancesChecksum   the new checksum
 
566
   */
 
567
  public void setChecksum(String instancesChecksum) {
 
568
    m_checksum = instancesChecksum;
 
569
  }
 
570
  
 
571
  /**
 
572
   * get the checksum
 
573
   * 
 
574
   * @return            the current checksum
 
575
   */
 
576
  public String getChecksum() {
 
577
    return m_checksum;
 
578
  }
 
579
  
 
580
  /**
 
581
   * Returs the array of classifiers
 
582
   * 
 
583
   * @return            the current models
 
584
   */
 
585
  public Classifier[] getModels() {
 
586
    return m_models;
 
587
  }
 
588
  
 
589
  /**
 
590
   * Sets the .elm file name for this library model
 
591
   * 
 
592
   * @param fileName    the new filename
 
593
   */
 
594
  public void setFileName(String fileName) {
 
595
    m_fileName = fileName;
 
596
  }
 
597
  
 
598
  /**
 
599
   * Gets a checksum for the string defining this classifier. This is used to
 
600
   * preserve uniqueness in the classifier names.
 
601
   * 
 
602
   * @param string      the classifier definition
 
603
   * @return            the checksum string
 
604
   */
 
605
  public static String getStringChecksum(String string) {
 
606
    
 
607
    String checksumString = null;
 
608
    
 
609
    try {
 
610
      
 
611
      Adler32 checkSummer = new Adler32();
 
612
      
 
613
      byte[] utf8 = string.toString().getBytes("UTF8");
 
614
      ;
 
615
      
 
616
      checkSummer.update(utf8);
 
617
      checksumString = Long.toHexString(checkSummer.getValue());
 
618
      
 
619
    } catch (UnsupportedEncodingException e) {
 
620
      // TODO Auto-generated catch block
 
621
      e.printStackTrace();
 
622
    }
 
623
    
 
624
    return checksumString;
 
625
  }
 
626
  
 
627
  /**
 
628
   * The purpose of this method is to get an appropriate file name for a model
 
629
   * based on its string representation of a model. All generated filenames
 
630
   * are limited to less than 128 characters and all of them will end with a
 
631
   * 64 bit checksum value of their string representation to try to maintain
 
632
   * some uniqueness of file names.
 
633
   * 
 
634
   * @param stringRepresentation        string representation of model
 
635
   * @return                            unique filename
 
636
   */
 
637
  public static String getFileName(String stringRepresentation) {
 
638
    
 
639
    // Get rid of space and quote marks(windows doesn't lke them)
 
640
    String fileName = stringRepresentation.trim().replace(' ', '_')
 
641
    .replace('"', '_');
 
642
    
 
643
    if (fileName.length() > 115) {
 
644
      
 
645
      fileName = fileName.substring(0, 115);
 
646
      
 
647
    }
 
648
    
 
649
    fileName += getStringChecksum(stringRepresentation)
 
650
    + EnsembleSelectionLibraryModel.FILE_EXTENSION;
 
651
    
 
652
    return fileName;
 
653
  }
 
654
  
 
655
  /**
 
656
   * Saves the given model to the specified file.
 
657
   * 
 
658
   * @param directory   the directory to save the model to
 
659
   * @param model       the model to save
 
660
   */
 
661
  public static void saveModel(String directory,
 
662
      EnsembleSelectionLibraryModel model) {
 
663
    
 
664
    try {
 
665
      String fileName = getFileName(model.getStringRepresentation());
 
666
      
 
667
      File file = new File(directory, fileName);
 
668
      
 
669
      // System.out.println("Saving model: "+file.getPath());
 
670
      
 
671
      // model.setFileName(new String(file.getPath()));
 
672
      
 
673
      // Serialize to a file
 
674
      ObjectOutput out = new ObjectOutputStream(
 
675
          new FileOutputStream(file));
 
676
      out.writeObject(model);
 
677
      
 
678
      out.close();
 
679
      
 
680
    } catch (IOException e) {
 
681
      
 
682
      e.printStackTrace();
 
683
    }
 
684
  }
 
685
  
 
686
  /**
 
687
   * loads the specified model
 
688
   * 
 
689
   * @param modelFilePath       the path of the model
 
690
   * @return                    the model
 
691
   */
 
692
  public static EnsembleSelectionLibraryModel loadModel(String modelFilePath) {
 
693
    
 
694
    EnsembleSelectionLibraryModel model = null;
 
695
    
 
696
    try {
 
697
      
 
698
      File file = new File(modelFilePath);
 
699
      
 
700
      ObjectInputStream in = new ObjectInputStream(new FileInputStream(
 
701
          file));
 
702
      
 
703
      model = (EnsembleSelectionLibraryModel) in.readObject();
 
704
      
 
705
      in.close();
 
706
      
 
707
    } catch (ClassNotFoundException e) {
 
708
      
 
709
      e.printStackTrace();
 
710
      
 
711
    } catch (IOException e) {
 
712
      
 
713
      e.printStackTrace();
 
714
      
 
715
    }
 
716
    
 
717
    return model;
 
718
  }
 
719
  
 
720
  /*
 
721
   * Problems persist in this code so we left it commented out. The intent was
 
722
   * to create the methods necessary for custom serialization to allow for
 
723
   * forwards/backwards compatability of .elm files accross multiple versions
 
724
   * of this classifier. The main problem however is that these methods do not
 
725
   * appear to be called. I'm not sure what the problem is, but this would be
 
726
   * a great feature. If anyone is a seasoned veteran of this serialization
 
727
   * stuff, please help!
 
728
   * 
 
729
   * private void writeObject(ObjectOutputStream stream) throws IOException {
 
730
   * //stream.defaultWriteObject(); //stream.writeObject(b);
 
731
   * 
 
732
   * //first serialize the LibraryModel fields
 
733
   * 
 
734
   * //super.writeObject(stream);
 
735
   * 
 
736
   * //now serialize the LibraryModel fields
 
737
   * 
 
738
   * stream.writeObject(m_Classifier);
 
739
   * 
 
740
   * stream.writeObject(m_DescriptionText);
 
741
   * 
 
742
   * stream.writeObject(m_ErrorText);
 
743
   * 
 
744
   * stream.writeObject(new Boolean(m_OptionsWereValid));
 
745
   * 
 
746
   * stream.writeObject(m_StringRepresentation);
 
747
   * 
 
748
   * stream.writeObject(m_models);
 
749
   * 
 
750
   * 
 
751
   * //now serialize the EnsembleLibraryModel fields //stream.writeObject(new
 
752
   * String("blah"));
 
753
   * 
 
754
   * stream.writeObject(new Integer(m_seed));
 
755
   * 
 
756
   * stream.writeObject(m_checksum);
 
757
   * 
 
758
   * stream.writeObject(new Double(m_validationRatio));
 
759
   * 
 
760
   * stream.writeObject(new Integer(m_folds));
 
761
   * 
 
762
   * stream.writeObject(m_fileName);
 
763
   * 
 
764
   * stream.writeObject(new Boolean(m_isTrained));
 
765
   * 
 
766
   * 
 
767
   * if (m_validationPredictions == null) {
 
768
   *  }
 
769
   * 
 
770
   * if (m_Debug) System.out.println("Saving
 
771
   * "+m_validationPredictions.length+" indexed array");
 
772
   * stream.writeObject(m_validationPredictions);
 
773
   *  }
 
774
   * 
 
775
   * private void readObject(ObjectInputStream stream) throws IOException,
 
776
   * ClassNotFoundException { //stream.defaultReadObject(); //b = (String)
 
777
   * stream.readObject();
 
778
   * 
 
779
   * //super.readObject(stream);
 
780
   * 
 
781
   * //deserialize the LibraryModel fields m_Classifier =
 
782
   * (Classifier)stream.readObject();
 
783
   * 
 
784
   * m_DescriptionText = (String)stream.readObject();
 
785
   * 
 
786
   * m_ErrorText = (String)stream.readObject();
 
787
   * 
 
788
   * m_OptionsWereValid = ((Boolean)stream.readObject()).booleanValue();
 
789
   * 
 
790
   * m_StringRepresentation = (String)stream.readObject();
 
791
   * 
 
792
   * 
 
793
   * 
 
794
   * //now deserialize the EnsembleLibraryModel fields m_models =
 
795
   * (Classifier[])stream.readObject();
 
796
   * 
 
797
   * m_seed = ((Integer)stream.readObject()).intValue();
 
798
   * 
 
799
   * m_checksum = (String)stream.readObject();
 
800
   * 
 
801
   * m_validationRatio = ((Double)stream.readObject()).doubleValue();
 
802
   * 
 
803
   * m_folds = ((Integer)stream.readObject()).intValue();
 
804
   * 
 
805
   * m_fileName = (String)stream.readObject();
 
806
   * 
 
807
   * m_isTrained = ((Boolean)stream.readObject()).booleanValue();
 
808
   * 
 
809
   * m_validationPredictions = (double[][])stream.readObject();
 
810
   * 
 
811
   * if (m_Debug) System.out.println("Loaded
 
812
   * "+m_validationPredictions.length+" indexed array"); }
 
813
   * 
 
814
   */
 
815
  
 
816
  /**
 
817
   * getter for validation predictions
 
818
   * 
 
819
   * @return            the current validation predictions
 
820
   */
 
821
  public double[][] getValidationPredictions() {
 
822
    return m_validationPredictions;
 
823
  }
 
824
  
 
825
  /**
 
826
   * setter for validation predictions
 
827
   * 
 
828
   * @param predictions the new validation predictions
 
829
   */
 
830
  public void setValidationPredictions(double[][] predictions) {
 
831
    if (m_Debug)
 
832
      System.out.println("Saving validation array of size "
 
833
          + predictions.length);
 
834
    m_validationPredictions = new double[predictions.length][];
 
835
    System.arraycopy(predictions, 0, m_validationPredictions, 0,
 
836
        predictions.length);
 
837
  }
 
838
}
 
 
b'\\ No newline at end of file'