~ubuntu-branches/ubuntu/vivid/weka/vivid

« back to all changes in this revision

Viewing changes to weka/experiment/CostSensitiveClassifierSplitEvaluator.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    CostSensitiveClassifierSplitEvaluator.java
 
19
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
 
 
24
package weka.experiment;
 
25
 
 
26
import weka.classifiers.Classifier;
 
27
import weka.classifiers.CostMatrix;
 
28
import weka.classifiers.Evaluation;
 
29
import weka.core.AdditionalMeasureProducer;
 
30
import weka.core.Attribute;
 
31
import weka.core.Instance;
 
32
import weka.core.Instances;
 
33
import weka.core.Option;
 
34
import weka.core.Summarizable;
 
35
import weka.core.Utils;
 
36
 
 
37
import java.io.BufferedReader;
 
38
import java.io.ByteArrayOutputStream;
 
39
import java.io.File;
 
40
import java.io.FileReader;
 
41
import java.io.ObjectOutputStream;
 
42
import java.lang.management.ManagementFactory;
 
43
import java.lang.management.ThreadMXBean;
 
44
import java.util.Enumeration;
 
45
import java.util.Vector;
 
46
 
 
47
/**
 
48
 <!-- globalinfo-start -->
 
49
 * SplitEvaluator that produces results for a classification scheme on a nominal class attribute, including weighted misclassification costs.
 
50
 * <p/>
 
51
 <!-- globalinfo-end -->
 
52
 *
 
53
 <!-- options-start -->
 
54
 * Valid options are: <p/>
 
55
 * 
 
56
 * <pre> -W &lt;class name&gt;
 
57
 *  The full class name of the classifier.
 
58
 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 
59
 * 
 
60
 * <pre> -C &lt;index&gt;
 
61
 *  The index of the class for which IR statistics
 
62
 *  are to be output. (default 1)</pre>
 
63
 * 
 
64
 * <pre> -I &lt;index&gt;
 
65
 *  The index of an attribute to output in the
 
66
 *  results. This attribute should identify an
 
67
 *  instance in order to know which instances are
 
68
 *  in the test set of a cross validation. if 0
 
69
 *  no output (default 0).</pre>
 
70
 * 
 
71
 * <pre> -P
 
72
 *  Add target and prediction columns to the result
 
73
 *  for each fold.</pre>
 
74
 * 
 
75
 * <pre> 
 
76
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 
77
 * </pre>
 
78
 * 
 
79
 * <pre> -D
 
80
 *  If set, classifier is run in debug mode and
 
81
 *  may output additional info to the console</pre>
 
82
 * 
 
83
 * <pre> -D &lt;directory&gt;
 
84
 *  Name of a directory to search for cost files when loading
 
85
 *  costs on demand (default current directory).</pre>
 
86
 * 
 
87
 <!-- options-end -->
 
88
 *
 
89
 * All options after -- will be passed to the classifier.
 
90
 *
 
91
 * @author Len Trigg (len@reeltwo.com)
 
92
 * @version $Revision: 1.17 $
 
93
 */
 
94
public class CostSensitiveClassifierSplitEvaluator 
 
95
  extends ClassifierSplitEvaluator { 
 
96
 
 
97
  /** for serialization */
 
98
  static final long serialVersionUID = -8069566663019501276L;
 
99
 
 
100
  /** 
 
101
   * The directory used when loading cost files on demand, null indicates
 
102
   * current directory 
 
103
   */
 
104
  protected File m_OnDemandDirectory = new File(System.getProperty("user.dir"));
 
105
 
 
106
  /** The length of a result */
 
107
  private static final int RESULT_SIZE = 31;
 
108
 
 
109
  /**
 
110
   * Returns a string describing this split evaluator
 
111
   * @return a description of the split evaluator suitable for
 
112
   * displaying in the explorer/experimenter gui
 
113
   */
 
114
  public String globalInfo() {
 
115
    return " SplitEvaluator that produces results for a classification scheme "
 
116
      +"on a nominal class attribute, including weighted misclassification "
 
117
      +"costs.";
 
118
  }
 
119
 
 
120
  /**
 
121
   * Returns an enumeration describing the available options..
 
122
   *
 
123
   * @return an enumeration of all the available options.
 
124
   */
 
125
  public Enumeration listOptions() {
 
126
 
 
127
    Vector newVector = new Vector(1);
 
128
    Enumeration enu = super.listOptions();
 
129
    while (enu.hasMoreElements()) {
 
130
      newVector.addElement(enu.nextElement());
 
131
    }
 
132
 
 
133
    newVector.addElement(new Option(
 
134
              "\tName of a directory to search for cost files when loading\n"
 
135
              +"\tcosts on demand (default current directory).",
 
136
              "D", 1, "-D <directory>"));
 
137
 
 
138
    return newVector.elements();
 
139
  }
 
140
 
 
141
  /**
 
142
   * Parses a given list of options. <p/>
 
143
   *
 
144
   <!-- options-start -->
 
145
   * Valid options are: <p/>
 
146
   * 
 
147
   * <pre> -W &lt;class name&gt;
 
148
   *  The full class name of the classifier.
 
149
   *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 
150
   * 
 
151
   * <pre> -C &lt;index&gt;
 
152
   *  The index of the class for which IR statistics
 
153
   *  are to be output. (default 1)</pre>
 
154
   * 
 
155
   * <pre> -I &lt;index&gt;
 
156
   *  The index of an attribute to output in the
 
157
   *  results. This attribute should identify an
 
158
   *  instance in order to know which instances are
 
159
   *  in the test set of a cross validation. if 0
 
160
   *  no output (default 0).</pre>
 
161
   * 
 
162
   * <pre> -P
 
163
   *  Add target and prediction columns to the result
 
164
   *  for each fold.</pre>
 
165
   * 
 
166
   * <pre> 
 
167
   * Options specific to classifier weka.classifiers.rules.ZeroR:
 
168
   * </pre>
 
169
   * 
 
170
   * <pre> -D
 
171
   *  If set, classifier is run in debug mode and
 
172
   *  may output additional info to the console</pre>
 
173
   * 
 
174
   * <pre> -D &lt;directory&gt;
 
175
   *  Name of a directory to search for cost files when loading
 
176
   *  costs on demand (default current directory).</pre>
 
177
   * 
 
178
   <!-- options-end -->
 
179
   *
 
180
   * All options after -- will be passed to the classifier.
 
181
   *
 
182
   * @param options the list of options as an array of strings
 
183
   * @throws Exception if an option is not supported
 
184
   */
 
185
  public void setOptions(String[] options) throws Exception {
 
186
    
 
187
    String demandDir = Utils.getOption('D', options);
 
188
    if (demandDir.length() != 0) {
 
189
      setOnDemandDirectory(new File(demandDir));
 
190
    }
 
191
 
 
192
    super.setOptions(options);
 
193
  }
 
194
 
 
195
  /**
 
196
   * Gets the current settings of the Classifier.
 
197
   *
 
198
   * @return an array of strings suitable for passing to setOptions
 
199
   */
 
200
  public String [] getOptions() {
 
201
 
 
202
    String [] superOptions = super.getOptions();
 
203
    String [] options = new String [superOptions.length + 3];
 
204
    int current = 0;
 
205
 
 
206
    options[current++] = "-D";
 
207
    options[current++] = "" + getOnDemandDirectory();
 
208
 
 
209
    System.arraycopy(superOptions, 0, options, current, 
 
210
                     superOptions.length);
 
211
    current += superOptions.length;
 
212
    while (current < options.length) {
 
213
      options[current++] = "";
 
214
    }
 
215
    return options;
 
216
  }
 
217
 
 
218
  /**
 
219
   * Returns the tip text for this property
 
220
   * @return tip text for this property suitable for
 
221
   * displaying in the explorer/experimenter gui
 
222
   */
 
223
  public String onDemandDirectoryTipText() {
 
224
    return "The directory to look in for cost files. This directory will be "
 
225
      +"searched for cost files when loading on demand.";
 
226
  }
 
227
 
 
228
  /**
 
229
   * Returns the directory that will be searched for cost files when
 
230
   * loading on demand.
 
231
   *
 
232
   * @return The cost file search directory.
 
233
   */
 
234
  public File getOnDemandDirectory() {
 
235
 
 
236
    return m_OnDemandDirectory;
 
237
  }
 
238
 
 
239
  /**
 
240
   * Sets the directory that will be searched for cost files when
 
241
   * loading on demand.
 
242
   *
 
243
   * @param newDir The cost file search directory.
 
244
   */
 
245
  public void setOnDemandDirectory(File newDir) {
 
246
 
 
247
    if (newDir.isDirectory()) {
 
248
      m_OnDemandDirectory = newDir;
 
249
    } else {
 
250
      m_OnDemandDirectory = new File(newDir.getParent());
 
251
    }
 
252
  }
 
253
 
 
254
  /**
 
255
   * Gets the data types of each of the result columns produced for a 
 
256
   * single run. The number of result fields must be constant
 
257
   * for a given SplitEvaluator.
 
258
   *
 
259
   * @return an array containing objects of the type of each result column. 
 
260
   * The objects should be Strings, or Doubles.
 
261
   */
 
262
  public Object [] getResultTypes() {
 
263
    int addm = (m_AdditionalMeasures != null) 
 
264
      ? m_AdditionalMeasures.length 
 
265
      : 0;
 
266
    Object [] resultTypes = new Object[RESULT_SIZE+addm];
 
267
    Double doub = new Double(0);
 
268
    int current = 0;
 
269
    resultTypes[current++] = doub;
 
270
    resultTypes[current++] = doub;
 
271
 
 
272
    resultTypes[current++] = doub;
 
273
    resultTypes[current++] = doub;
 
274
    resultTypes[current++] = doub;
 
275
    resultTypes[current++] = doub;
 
276
    resultTypes[current++] = doub;
 
277
    resultTypes[current++] = doub;
 
278
    resultTypes[current++] = doub;
 
279
    resultTypes[current++] = doub;
 
280
 
 
281
    resultTypes[current++] = doub;
 
282
    resultTypes[current++] = doub;
 
283
    resultTypes[current++] = doub;
 
284
    resultTypes[current++] = doub;
 
285
 
 
286
    resultTypes[current++] = doub;
 
287
    resultTypes[current++] = doub;
 
288
    resultTypes[current++] = doub;
 
289
    resultTypes[current++] = doub;
 
290
    resultTypes[current++] = doub;
 
291
    resultTypes[current++] = doub;
 
292
 
 
293
    resultTypes[current++] = doub;
 
294
    resultTypes[current++] = doub;
 
295
    resultTypes[current++] = doub;
 
296
 
 
297
    // Timing stats
 
298
    resultTypes[current++] = doub;
 
299
    resultTypes[current++] = doub;
 
300
    resultTypes[current++] = doub;
 
301
    resultTypes[current++] = doub;
 
302
    
 
303
    // sizes
 
304
    resultTypes[current++] = doub;
 
305
    resultTypes[current++] = doub;
 
306
    resultTypes[current++] = doub;
 
307
    
 
308
    resultTypes[current++] = "";
 
309
 
 
310
    // add any additional measures
 
311
    for (int i=0;i<addm;i++) {
 
312
      resultTypes[current++] = doub;
 
313
    }
 
314
    if (current != RESULT_SIZE+addm) {
 
315
      throw new Error("ResultTypes didn't fit RESULT_SIZE");
 
316
    }
 
317
    return resultTypes;
 
318
  }
 
319
 
 
320
  /**
 
321
   * Gets the names of each of the result columns produced for a single run.
 
322
   * The number of result fields must be constant
 
323
   * for a given SplitEvaluator.
 
324
   *
 
325
   * @return an array containing the name of each result column
 
326
   */
 
327
  public String [] getResultNames() {
 
328
    int addm = (m_AdditionalMeasures != null) 
 
329
      ? m_AdditionalMeasures.length 
 
330
      : 0;
 
331
    String [] resultNames = new String[RESULT_SIZE+addm];
 
332
    int current = 0;
 
333
    resultNames[current++] = "Number_of_training_instances";
 
334
    resultNames[current++] = "Number_of_testing_instances";
 
335
 
 
336
    // Basic performance stats - right vs wrong
 
337
    resultNames[current++] = "Number_correct";
 
338
    resultNames[current++] = "Number_incorrect";
 
339
    resultNames[current++] = "Number_unclassified";
 
340
    resultNames[current++] = "Percent_correct";
 
341
    resultNames[current++] = "Percent_incorrect";
 
342
    resultNames[current++] = "Percent_unclassified";
 
343
    resultNames[current++] = "Total_cost";
 
344
    resultNames[current++] = "Average_cost";
 
345
 
 
346
    // Sensitive stats - certainty of predictions
 
347
    resultNames[current++] = "Mean_absolute_error";
 
348
    resultNames[current++] = "Root_mean_squared_error";
 
349
    resultNames[current++] = "Relative_absolute_error";
 
350
    resultNames[current++] = "Root_relative_squared_error";
 
351
 
 
352
    // SF stats
 
353
    resultNames[current++] = "SF_prior_entropy";
 
354
    resultNames[current++] = "SF_scheme_entropy";
 
355
    resultNames[current++] = "SF_entropy_gain";
 
356
    resultNames[current++] = "SF_mean_prior_entropy";
 
357
    resultNames[current++] = "SF_mean_scheme_entropy";
 
358
    resultNames[current++] = "SF_mean_entropy_gain";
 
359
 
 
360
    // K&B stats
 
361
    resultNames[current++] = "KB_information";
 
362
    resultNames[current++] = "KB_mean_information";
 
363
    resultNames[current++] = "KB_relative_information";
 
364
 
 
365
    // Timing stats
 
366
    resultNames[current++] = "Elapsed_Time_training";
 
367
    resultNames[current++] = "Elapsed_Time_testing";
 
368
    resultNames[current++] = "UserCPU_Time_training";
 
369
    resultNames[current++] = "UserCPU_Time_testing";
 
370
 
 
371
    // sizes
 
372
    resultNames[current++] = "Serialized_Model_Size";
 
373
    resultNames[current++] = "Serialized_Train_Set_Size";
 
374
    resultNames[current++] = "Serialized_Test_Set_Size";
 
375
 
 
376
    // Classifier defined extras
 
377
    resultNames[current++] = "Summary";
 
378
    // add any additional measures
 
379
    for (int i=0;i<addm;i++) {
 
380
      resultNames[current++] = m_AdditionalMeasures[i];
 
381
    }
 
382
    if (current != RESULT_SIZE+addm) {
 
383
      throw new Error("ResultNames didn't fit RESULT_SIZE");
 
384
    }
 
385
    return resultNames;
 
386
  }
 
387
 
 
388
  /**
 
389
   * Gets the results for the supplied train and test datasets. Now performs
 
390
   * a deep copy of the classifier before it is built and evaluated (just in case
 
391
   * the classifier is not initialized properly in buildClassifier()).
 
392
   *
 
393
   * @param train the training Instances.
 
394
   * @param test the testing Instances.
 
395
   * @return the results stored in an array. The objects stored in
 
396
   * the array may be Strings, Doubles, or null (for the missing value).
 
397
   * @throws Exception if a problem occurs while getting the results
 
398
   */
 
399
  public Object [] getResult(Instances train, Instances test)
 
400
  throws Exception {
 
401
    
 
402
    if (train.classAttribute().type() != Attribute.NOMINAL) {
 
403
      throw new Exception("Class attribute is not nominal!");
 
404
    }
 
405
    if (m_Template == null) {
 
406
      throw new Exception("No classifier has been specified");
 
407
    }
 
408
    ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
 
409
    boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
 
410
    if(!thMonitor.isThreadCpuTimeEnabled())
 
411
      thMonitor.setThreadCpuTimeEnabled(true);
 
412
    
 
413
    int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
 
414
    Object [] result = new Object[RESULT_SIZE+addm];
 
415
    long thID = Thread.currentThread().getId();
 
416
    long CPUStartTime=-1, trainCPUTimeElapsed=-1, testCPUTimeElapsed=-1,
 
417
         trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed;    
 
418
    
 
419
    String costName = train.relationName() + CostMatrix.FILE_EXTENSION;
 
420
    File costFile = new File(getOnDemandDirectory(), costName);
 
421
    if (!costFile.exists()) {
 
422
      throw new Exception("On-demand cost file doesn't exist: " + costFile);
 
423
    }
 
424
    CostMatrix costMatrix = new CostMatrix(new BufferedReader(
 
425
    new FileReader(costFile)));
 
426
    
 
427
    Evaluation eval = new Evaluation(train, costMatrix);    
 
428
    m_Classifier = Classifier.makeCopy(m_Template);
 
429
    
 
430
    trainTimeStart = System.currentTimeMillis();
 
431
    if(canMeasureCPUTime)
 
432
      CPUStartTime = thMonitor.getThreadUserTime(thID);
 
433
    m_Classifier.buildClassifier(train);
 
434
    if(canMeasureCPUTime)
 
435
      trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
 
436
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
 
437
    testTimeStart = System.currentTimeMillis();
 
438
    if(canMeasureCPUTime)
 
439
      CPUStartTime = thMonitor.getThreadUserTime(thID);
 
440
    eval.evaluateModel(m_Classifier, test);
 
441
    if(canMeasureCPUTime)
 
442
      testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
 
443
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
 
444
    thMonitor = null;
 
445
    
 
446
    m_result = eval.toSummaryString();
 
447
    // The results stored are all per instance -- can be multiplied by the
 
448
    // number of instances to get absolute numbers
 
449
    int current = 0;
 
450
    result[current++] = new Double(train.numInstances());
 
451
    result[current++] = new Double(eval.numInstances());
 
452
    
 
453
    result[current++] = new Double(eval.correct());
 
454
    result[current++] = new Double(eval.incorrect());
 
455
    result[current++] = new Double(eval.unclassified());
 
456
    result[current++] = new Double(eval.pctCorrect());
 
457
    result[current++] = new Double(eval.pctIncorrect());
 
458
    result[current++] = new Double(eval.pctUnclassified());
 
459
    result[current++] = new Double(eval.totalCost());
 
460
    result[current++] = new Double(eval.avgCost());
 
461
    
 
462
    result[current++] = new Double(eval.meanAbsoluteError());
 
463
    result[current++] = new Double(eval.rootMeanSquaredError());
 
464
    result[current++] = new Double(eval.relativeAbsoluteError());
 
465
    result[current++] = new Double(eval.rootRelativeSquaredError());
 
466
    
 
467
    result[current++] = new Double(eval.SFPriorEntropy());
 
468
    result[current++] = new Double(eval.SFSchemeEntropy());
 
469
    result[current++] = new Double(eval.SFEntropyGain());
 
470
    result[current++] = new Double(eval.SFMeanPriorEntropy());
 
471
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
 
472
    result[current++] = new Double(eval.SFMeanEntropyGain());
 
473
    
 
474
    // K&B stats
 
475
    result[current++] = new Double(eval.KBInformation());
 
476
    result[current++] = new Double(eval.KBMeanInformation());
 
477
    result[current++] = new Double(eval.KBRelativeInformation());
 
478
    
 
479
    // Timing stats
 
480
    result[current++] = new Double(trainTimeElapsed / 1000.0);
 
481
    result[current++] = new Double(testTimeElapsed / 1000.0);
 
482
    if(canMeasureCPUTime) {
 
483
      result[current++] = new Double((trainCPUTimeElapsed/1000000.0) / 1000.0);
 
484
      result[current++] = new Double((testCPUTimeElapsed /1000000.0) / 1000.0);
 
485
    }
 
486
    else {
 
487
      result[current++] = new Double(Instance.missingValue());
 
488
      result[current++] = new Double(Instance.missingValue());
 
489
    }
 
490
    
 
491
    // sizes
 
492
    ByteArrayOutputStream bastream = new ByteArrayOutputStream();
 
493
    ObjectOutputStream oostream = new ObjectOutputStream(bastream);
 
494
    oostream.writeObject(m_Classifier);
 
495
    result[current++] = new Double(bastream.size());
 
496
    bastream = new ByteArrayOutputStream();
 
497
    oostream = new ObjectOutputStream(bastream);
 
498
    oostream.writeObject(train);
 
499
    result[current++] = new Double(bastream.size());
 
500
    bastream = new ByteArrayOutputStream();
 
501
    oostream = new ObjectOutputStream(bastream);
 
502
    oostream.writeObject(test);
 
503
    result[current++] = new Double(bastream.size());
 
504
    
 
505
    if (m_Classifier instanceof Summarizable) {
 
506
      result[current++] = ((Summarizable)m_Classifier).toSummaryString();
 
507
    } else {
 
508
      result[current++] = null;
 
509
    }
 
510
    
 
511
    for (int i=0;i<addm;i++) {
 
512
      if (m_doesProduce[i]) {
 
513
        try {
 
514
          double dv = ((AdditionalMeasureProducer)m_Classifier).
 
515
          getMeasure(m_AdditionalMeasures[i]);
 
516
          if (!Instance.isMissingValue(dv)) {
 
517
            Double value = new Double(dv);
 
518
            result[current++] = value;
 
519
          } else {
 
520
            result[current++] = null;
 
521
          }
 
522
        } catch (Exception ex) {
 
523
          System.err.println(ex);
 
524
        }
 
525
      } else {
 
526
        result[current++] = null;
 
527
      }
 
528
    }
 
529
    
 
530
    if (current != RESULT_SIZE+addm) {
 
531
      throw new Error("Results didn't fit RESULT_SIZE");
 
532
    }
 
533
    return result;
 
534
  }
 
535
 
 
536
  /**
 
537
   * Returns a text description of the split evaluator.
 
538
   *
 
539
   * @return a text description of the split evaluator.
 
540
   */
 
541
  public String toString() {
 
542
 
 
543
    String result = "CostSensitiveClassifierSplitEvaluator: ";
 
544
    if (m_Template == null) {
 
545
      return result + "<null> classifier";
 
546
    }
 
547
    return result + m_Template.getClass().getName() + " " 
 
548
      + m_ClassifierOptions + "(version " + m_ClassifierVersion + ")";
 
549
  }
 
550
} // CostSensitiveClassifierSplitEvaluator