~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/meta/nestedDichotomies/DataNearBalancedND.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribsute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    DataNearBalancedND.java
 
19
 *    Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
package weka.classifiers.meta.nestedDichotomies;
 
23
 
 
24
import weka.classifiers.Classifier;
 
25
import weka.classifiers.RandomizableSingleClassifierEnhancer;
 
26
import weka.classifiers.meta.FilteredClassifier;
 
27
import weka.core.Capabilities;
 
28
import weka.core.Instance;
 
29
import weka.core.Instances;
 
30
import weka.core.Range;
 
31
import weka.core.TechnicalInformation;
 
32
import weka.core.TechnicalInformationHandler;
 
33
import weka.core.Utils;
 
34
import weka.core.Capabilities.Capability;
 
35
import weka.core.TechnicalInformation.Field;
 
36
import weka.core.TechnicalInformation.Type;
 
37
import weka.filters.Filter;
 
38
import weka.filters.unsupervised.attribute.MakeIndicator;
 
39
import weka.filters.unsupervised.instance.RemoveWithValues;
 
40
 
 
41
import java.util.Hashtable;
 
42
import java.util.Random;
 
43
 
 
44
 
 
45
/**
 
46
 <!-- globalinfo-start -->
 
47
 * A meta classifier for handling multi-class datasets with 2-class classifiers by building a random data-balanced tree structure.<br/>
 
48
 * <br/>
 
49
 * For more info, check<br/>
 
50
 * <br/>
 
51
 * Lin Dong, Eibe Frank, Stefan Kramer: Ensembles of Balanced Nested Dichotomies for Multi-class Problems. In: PKDD, 84-95, 2005.<br/>
 
52
 * <br/>
 
53
 * Eibe Frank, Stefan Kramer: Ensembles of nested dichotomies for multi-class problems. In: Twenty-first International Conference on Machine Learning, 2004.
 
54
 * <p/>
 
55
 <!-- globalinfo-end -->
 
56
 *
 
57
 <!-- technical-bibtex-start -->
 
58
 * BibTeX:
 
59
 * <pre>
 
60
 * &#64;inproceedings{Dong2005,
 
61
 *    author = {Lin Dong and Eibe Frank and Stefan Kramer},
 
62
 *    booktitle = {PKDD},
 
63
 *    pages = {84-95},
 
64
 *    publisher = {Springer},
 
65
 *    title = {Ensembles of Balanced Nested Dichotomies for Multi-class Problems},
 
66
 *    year = {2005}
 
67
 * }
 
68
 * 
 
69
 * &#64;inproceedings{Frank2004,
 
70
 *    author = {Eibe Frank and Stefan Kramer},
 
71
 *    booktitle = {Twenty-first International Conference on Machine Learning},
 
72
 *    publisher = {ACM},
 
73
 *    title = {Ensembles of nested dichotomies for multi-class problems},
 
74
 *    year = {2004}
 
75
 * }
 
76
 * </pre>
 
77
 * <p/>
 
78
 <!-- technical-bibtex-end -->
 
79
 *
 
80
 <!-- options-start -->
 
81
 * Valid options are: <p/>
 
82
 * 
 
83
 * <pre> -S &lt;num&gt;
 
84
 *  Random number seed.
 
85
 *  (default 1)</pre>
 
86
 * 
 
87
 * <pre> -D
 
88
 *  If set, classifier is run in debug mode and
 
89
 *  may output additional info to the console</pre>
 
90
 * 
 
91
 * <pre> -W
 
92
 *  Full name of base classifier.
 
93
 *  (default: weka.classifiers.trees.J48)</pre>
 
94
 * 
 
95
 * <pre> 
 
96
 * Options specific to classifier weka.classifiers.trees.J48:
 
97
 * </pre>
 
98
 * 
 
99
 * <pre> -U
 
100
 *  Use unpruned tree.</pre>
 
101
 * 
 
102
 * <pre> -C &lt;pruning confidence&gt;
 
103
 *  Set confidence threshold for pruning.
 
104
 *  (default 0.25)</pre>
 
105
 * 
 
106
 * <pre> -M &lt;minimum number of instances&gt;
 
107
 *  Set minimum number of instances per leaf.
 
108
 *  (default 2)</pre>
 
109
 * 
 
110
 * <pre> -R
 
111
 *  Use reduced error pruning.</pre>
 
112
 * 
 
113
 * <pre> -N &lt;number of folds&gt;
 
114
 *  Set number of folds for reduced error
 
115
 *  pruning. One fold is used as pruning set.
 
116
 *  (default 3)</pre>
 
117
 * 
 
118
 * <pre> -B
 
119
 *  Use binary splits only.</pre>
 
120
 * 
 
121
 * <pre> -S
 
122
 *  Don't perform subtree raising.</pre>
 
123
 * 
 
124
 * <pre> -L
 
125
 *  Do not clean up after the tree has been built.</pre>
 
126
 * 
 
127
 * <pre> -A
 
128
 *  Laplace smoothing for predicted probabilities.</pre>
 
129
 * 
 
130
 * <pre> -Q &lt;seed&gt;
 
131
 *  Seed for random data shuffling (default 1).</pre>
 
132
 * 
 
133
 <!-- options-end -->
 
134
 *
 
135
 * @author Lin Dong
 
136
 * @author Eibe Frank
 
137
 */
 
138
public class DataNearBalancedND 
 
139
  extends RandomizableSingleClassifierEnhancer
 
140
  implements TechnicalInformationHandler {
 
141
 
 
142
  /** for serialization */
 
143
  static final long serialVersionUID = 5117477294209496368L;
 
144
  
 
145
  /** The filtered classifier in which the base classifier is wrapped. */
 
146
  protected FilteredClassifier m_FilteredClassifier;
 
147
    
 
148
  /** The hashtable for this node. */
 
149
  protected Hashtable m_classifiers=new Hashtable();
 
150
 
 
151
  /** The first successor */
 
152
  protected DataNearBalancedND m_FirstSuccessor = null;
 
153
 
 
154
  /** The second successor */
 
155
  protected DataNearBalancedND m_SecondSuccessor = null;
 
156
  
 
157
  /** The classes that are grouped together at the current node */
 
158
  protected Range m_Range = null;
 
159
    
 
160
  /** Is Hashtable given from END? */
 
161
  protected boolean m_hashtablegiven = false;
 
162
    
 
163
  /**
 
164
   * Constructor.
 
165
   */
 
166
  public DataNearBalancedND() {
 
167
    
 
168
    m_Classifier = new weka.classifiers.trees.J48();
 
169
  }
 
170
  
 
171
  /**
 
172
   * String describing default classifier.
 
173
   * 
 
174
   * @return the default classifier classname
 
175
   */
 
176
  protected String defaultClassifierString() {
 
177
    
 
178
    return "weka.classifiers.trees.J48";
 
179
  }
 
180
 
 
181
  /**
 
182
   * Returns an instance of a TechnicalInformation object, containing 
 
183
   * detailed information about the technical background of this class,
 
184
   * e.g., paper reference or book this class is based on.
 
185
   * 
 
186
   * @return the technical information about this class
 
187
   */
 
188
  public TechnicalInformation getTechnicalInformation() {
 
189
    TechnicalInformation        result;
 
190
    TechnicalInformation        additional;
 
191
    
 
192
    result = new TechnicalInformation(Type.INPROCEEDINGS);
 
193
    result.setValue(Field.AUTHOR, "Lin Dong and Eibe Frank and Stefan Kramer");
 
194
    result.setValue(Field.TITLE, "Ensembles of Balanced Nested Dichotomies for Multi-class Problems");
 
195
    result.setValue(Field.BOOKTITLE, "PKDD");
 
196
    result.setValue(Field.YEAR, "2005");
 
197
    result.setValue(Field.PAGES, "84-95");
 
198
    result.setValue(Field.PUBLISHER, "Springer");
 
199
 
 
200
    additional = result.add(Type.INPROCEEDINGS);
 
201
    additional.setValue(Field.AUTHOR, "Eibe Frank and Stefan Kramer");
 
202
    additional.setValue(Field.TITLE, "Ensembles of nested dichotomies for multi-class problems");
 
203
    additional.setValue(Field.BOOKTITLE, "Twenty-first International Conference on Machine Learning");
 
204
    additional.setValue(Field.YEAR, "2004");
 
205
    additional.setValue(Field.PUBLISHER, "ACM");
 
206
    
 
207
    return result;
 
208
  }
 
209
 
 
210
  /**
 
211
   * Set hashtable from END.
 
212
   * 
 
213
   * @param table the hashtable to use
 
214
   */
 
215
  public void setHashtable(Hashtable table) {
 
216
 
 
217
    m_hashtablegiven = true;
 
218
    m_classifiers = table;
 
219
  }
 
220
    
 
221
  /**
 
222
   * Generates a classifier for the current node and proceeds recursively.
 
223
   *
 
224
   * @param data contains the (multi-class) instances
 
225
   * @param classes contains the indices of the classes that are present
 
226
   * @param rand the random number generator to use
 
227
   * @param classifier the classifier to use
 
228
   * @param table the Hashtable to use
 
229
   * @param instsNumAllClasses
 
230
   * @throws Exception if anything goes worng
 
231
   */
 
232
  private void generateClassifierForNode(Instances data, Range classes,
 
233
                                         Random rand, Classifier classifier, Hashtable table,
 
234
                                         double[] instsNumAllClasses) 
 
235
    throws Exception {
 
236
        
 
237
    // Get the indices
 
238
    int[] indices = classes.getSelection();
 
239
 
 
240
    // Randomize the order of the indices
 
241
    for (int j = indices.length - 1; j > 0; j--) {
 
242
      int randPos = rand.nextInt(j + 1);
 
243
      int temp = indices[randPos];
 
244
      indices[randPos] = indices[j];
 
245
      indices[j] = temp;
 
246
    }
 
247
 
 
248
    // Pick the classes for the current split
 
249
    double total = 0;
 
250
    for (int j = 0; j < indices.length; j++) {
 
251
      total += instsNumAllClasses[indices[j]];
 
252
    }
 
253
    double halfOfTotal = total / 2;
 
254
        
 
255
    // Go through the list of classes until the either the left or
 
256
    // right subset exceeds half the total weight
 
257
    double sumLeft = 0, sumRight = 0;
 
258
    int i = 0, j = indices.length - 1;
 
259
    do {
 
260
      if (i == j) {
 
261
        if (rand.nextBoolean()) {
 
262
          sumLeft += instsNumAllClasses[indices[i++]];
 
263
        } else {
 
264
          sumRight += instsNumAllClasses[indices[j--]];
 
265
        }
 
266
      } else {
 
267
        sumLeft += instsNumAllClasses[indices[i++]];
 
268
        sumRight += instsNumAllClasses[indices[j--]];
 
269
      }
 
270
    } while (Utils.sm(sumLeft, halfOfTotal) && Utils.sm(sumRight, halfOfTotal));
 
271
 
 
272
    int first = 0, second = 0;
 
273
    if (!Utils.sm(sumLeft, halfOfTotal)) {
 
274
      first = i;
 
275
    } else {
 
276
      first = j + 1;
 
277
    }
 
278
    second = indices.length - first;
 
279
 
 
280
    int[] firstInds = new int[first];
 
281
    int[] secondInds = new int[second];
 
282
    System.arraycopy(indices, 0, firstInds, 0, first);
 
283
    System.arraycopy(indices, first, secondInds, 0, second);
 
284
        
 
285
    // Sort the indices (important for hash key)!
 
286
    int[] sortedFirst = Utils.sort(firstInds);
 
287
    int[] sortedSecond = Utils.sort(secondInds);
 
288
    int[] firstCopy = new int[first];
 
289
    int[] secondCopy = new int[second];
 
290
       for (int k = 0; k < sortedFirst.length; k++) {
 
291
      firstCopy[k] = firstInds[sortedFirst[k]];
 
292
    }
 
293
    firstInds = firstCopy;
 
294
    for (int k = 0; k < sortedSecond.length; k++) {
 
295
      secondCopy[k] = secondInds[sortedSecond[k]];
 
296
    }
 
297
    secondInds = secondCopy;
 
298
                
 
299
    // Unify indices to improve hashing
 
300
    if (firstInds[0] > secondInds[0]) {
 
301
      int[] help = secondInds;
 
302
      secondInds = firstInds;
 
303
      firstInds = help;
 
304
      int help2 = second;
 
305
      second = first;
 
306
      first = help2;
 
307
    }
 
308
 
 
309
    m_Range = new Range(Range.indicesToRangeList(firstInds));
 
310
    m_Range.setUpper(data.numClasses() - 1);
 
311
 
 
312
    Range secondRange = new Range(Range.indicesToRangeList(secondInds));
 
313
    secondRange.setUpper(data.numClasses() - 1);
 
314
       
 
315
    // Change the class labels and build the classifier
 
316
    MakeIndicator filter = new MakeIndicator();
 
317
    filter.setAttributeIndex("" + (data.classIndex() + 1));
 
318
    filter.setValueIndices(m_Range.getRanges());
 
319
    filter.setNumeric(false);
 
320
    filter.setInputFormat(data);
 
321
    m_FilteredClassifier = new FilteredClassifier();
 
322
    if (data.numInstances() > 0) {
 
323
      m_FilteredClassifier.setClassifier(Classifier.makeCopies(classifier, 1)[0]);
 
324
    } else {
 
325
      m_FilteredClassifier.setClassifier(new weka.classifiers.rules.ZeroR());
 
326
    }
 
327
    m_FilteredClassifier.setFilter(filter);
 
328
 
 
329
    // Save reference to hash table at current node
 
330
    m_classifiers=table;
 
331
        
 
332
    if (!m_classifiers.containsKey( getString(firstInds) + "|" + getString(secondInds))) {
 
333
      m_FilteredClassifier.buildClassifier(data);
 
334
      m_classifiers.put(getString(firstInds) + "|" + getString(secondInds), m_FilteredClassifier);
 
335
    } else {
 
336
      m_FilteredClassifier=(FilteredClassifier)m_classifiers.get(getString(firstInds) + "|" + 
 
337
                                                                 getString(secondInds));        
 
338
    }
 
339
                                
 
340
    // Create two successors if necessary
 
341
    m_FirstSuccessor = new DataNearBalancedND();
 
342
    if (first == 1) {
 
343
      m_FirstSuccessor.m_Range = m_Range;
 
344
    } else {
 
345
      RemoveWithValues rwv = new RemoveWithValues();
 
346
      rwv.setInvertSelection(true);
 
347
      rwv.setNominalIndices(m_Range.getRanges());
 
348
      rwv.setAttributeIndex("" + (data.classIndex() + 1));
 
349
      rwv.setInputFormat(data);
 
350
      Instances firstSubset = Filter.useFilter(data, rwv);
 
351
      m_FirstSuccessor.generateClassifierForNode(firstSubset, m_Range, 
 
352
                                                 rand, classifier, m_classifiers,
 
353
                                                 instsNumAllClasses);
 
354
    }
 
355
    m_SecondSuccessor = new DataNearBalancedND();
 
356
    if (second == 1) {
 
357
      m_SecondSuccessor.m_Range = secondRange;
 
358
    } else {
 
359
      RemoveWithValues rwv = new RemoveWithValues();
 
360
      rwv.setInvertSelection(true);
 
361
      rwv.setNominalIndices(secondRange.getRanges());
 
362
      rwv.setAttributeIndex("" + (data.classIndex() + 1));
 
363
      rwv.setInputFormat(data);
 
364
      Instances secondSubset = Filter.useFilter(data, rwv);
 
365
      m_SecondSuccessor = new DataNearBalancedND();
 
366
      
 
367
      m_SecondSuccessor.generateClassifierForNode(secondSubset, secondRange, 
 
368
                                                  rand, classifier, m_classifiers,
 
369
                                                  instsNumAllClasses);
 
370
    }
 
371
  }
 
372
 
 
373
  /**
 
374
   * Returns default capabilities of the classifier.
 
375
   *
 
376
   * @return      the capabilities of this classifier
 
377
   */
 
378
  public Capabilities getCapabilities() {
 
379
    Capabilities result = super.getCapabilities();
 
380
 
 
381
    // class
 
382
    result.disableAllClasses();
 
383
    result.enable(Capability.NOMINAL_CLASS);
 
384
    result.enable(Capability.MISSING_CLASS_VALUES);
 
385
 
 
386
    // instances
 
387
    result.setMinimumNumberInstances(1);
 
388
    
 
389
    return result;
 
390
  }
 
391
    
 
392
  /**
 
393
   * Builds tree recursively.
 
394
   *
 
395
   * @param data contains the (multi-class) instances
 
396
   * @throws Exception if the building fails
 
397
   */
 
398
  public void buildClassifier(Instances data) throws Exception {
 
399
 
 
400
    // can classifier handle the data?
 
401
    getCapabilities().testWithFail(data);
 
402
 
 
403
    // remove instances with missing class
 
404
    data = new Instances(data);
 
405
    data.deleteWithMissingClass();
 
406
    
 
407
    Random random = data.getRandomNumberGenerator(m_Seed);
 
408
        
 
409
    if (!m_hashtablegiven) {
 
410
      m_classifiers = new Hashtable();
 
411
    }
 
412
        
 
413
    // Check which classes are present in the
 
414
    // data and construct initial list of classes
 
415
    boolean[] present = new boolean[data.numClasses()];
 
416
    for (int i = 0; i < data.numInstances(); i++) {
 
417
      present[(int)data.instance(i).classValue()] = true;
 
418
    }
 
419
    StringBuffer list = new StringBuffer();
 
420
    for (int i = 0; i < present.length; i++) {
 
421
      if (present[i]) {
 
422
        if (list.length() > 0) {
 
423
          list.append(",");
 
424
        }
 
425
        list.append(i + 1);
 
426
      }
 
427
    }
 
428
 
 
429
    // Determine the number of instances in each class
 
430
    double[] instsNum = new double[data.numClasses()];
 
431
    for (int i = 0; i < data.numInstances(); i++) {
 
432
      instsNum[(int)data.instance(i).classValue()] += data.instance(i).weight();
 
433
    }
 
434
      
 
435
    Range newRange = new Range(list.toString());
 
436
    newRange.setUpper(data.numClasses() - 1);
 
437
        
 
438
    generateClassifierForNode(data, newRange, random, m_Classifier, m_classifiers, instsNum);
 
439
  }
 
440
    
 
441
  /**
 
442
   * Predicts the class distribution for a given instance
 
443
   *
 
444
   * @param inst the (multi-class) instance to be classified
 
445
   * @return the class distribution
 
446
   * @throws Exception if computing fails
 
447
   */
 
448
  public double[] distributionForInstance(Instance inst) throws Exception {
 
449
        
 
450
    double[] newDist = new double[inst.numClasses()];
 
451
    if (m_FirstSuccessor == null) {
 
452
      for (int i = 0; i < inst.numClasses(); i++) {
 
453
        if (m_Range.isInRange(i)) {
 
454
          newDist[i] = 1;
 
455
        }
 
456
      }
 
457
      return newDist;
 
458
    } else {
 
459
      double[] firstDist = m_FirstSuccessor.distributionForInstance(inst);
 
460
      double[] secondDist = m_SecondSuccessor.distributionForInstance(inst);
 
461
      double[] dist = m_FilteredClassifier.distributionForInstance(inst);
 
462
      for (int i = 0; i < inst.numClasses(); i++) {
 
463
        if ((firstDist[i] > 0) && (secondDist[i] > 0)) {
 
464
          System.err.println("Panik!!");
 
465
        }
 
466
        if (m_Range.isInRange(i)) {
 
467
          newDist[i] = dist[1] * firstDist[i];
 
468
        } else {
 
469
          newDist[i] = dist[0] * secondDist[i];
 
470
        }
 
471
      }
 
472
      if  (!Utils.eq(Utils.sum(newDist), 1)) {
 
473
        System.err.println(Utils.sum(newDist));
 
474
        for (int j = 0; j < dist.length; j++) {
 
475
          System.err.print(dist[j] + " ");
 
476
        }
 
477
        System.err.println();
 
478
        for (int j = 0; j < newDist.length; j++) {
 
479
          System.err.print(newDist[j] + " ");
 
480
        }
 
481
        System.err.println();
 
482
        System.err.println(inst);
 
483
        System.err.println(m_FilteredClassifier);
 
484
        //System.err.println(m_Data);
 
485
        System.err.println("bad");
 
486
      }
 
487
      return newDist;
 
488
    }
 
489
  }
 
490
    
 
491
  /**
 
492
   * Returns the list of indices as a string.
 
493
   * 
 
494
   * @param indices the indices to return as string
 
495
   * @return the indices as string
 
496
   */
 
497
  public String getString(int [] indices) {
 
498
 
 
499
    StringBuffer string = new StringBuffer();
 
500
    for (int i = 0; i < indices.length; i++) {
 
501
      if (i > 0) {
 
502
        string.append(',');
 
503
      }
 
504
      string.append(indices[i]);
 
505
    }
 
506
    return string.toString();
 
507
  }
 
508
        
 
509
  /**
 
510
   * @return a description of the classifier suitable for
 
511
   * displaying in the explorer/experimenter gui
 
512
   */
 
513
  public String globalInfo() {
 
514
            
 
515
    return 
 
516
        "A meta classifier for handling multi-class datasets with 2-class "
 
517
      + "classifiers by building a random data-balanced tree structure.\n\n"
 
518
      + "For more info, check\n\n"
 
519
      + getTechnicalInformation().toString();
 
520
  }
 
521
        
 
522
  /**
 
523
   * Outputs the classifier as a string.
 
524
   * 
 
525
   * @return a string representation of the classifier
 
526
   */
 
527
  public String toString() {
 
528
            
 
529
    if (m_classifiers == null) {
 
530
      return "DataNearBalancedND: No model built yet.";
 
531
    }
 
532
    StringBuffer text = new StringBuffer();
 
533
    text.append("DataNearBalancedND");
 
534
    treeToString(text, 0);
 
535
            
 
536
    return text.toString();
 
537
  }
 
538
        
 
539
  /**
 
540
   * Returns string description of the tree.
 
541
   * 
 
542
   * @param text the buffer to add the node to
 
543
   * @param nn the node number
 
544
   * @return the next node number
 
545
   */
 
546
  private int treeToString(StringBuffer text, int nn) {
 
547
            
 
548
    nn++;
 
549
    text.append("\n\nNode number: " + nn + "\n\n");
 
550
    if (m_FilteredClassifier != null) {
 
551
      text.append(m_FilteredClassifier);
 
552
    } else {
 
553
      text.append("null");
 
554
    }
 
555
    if (m_FirstSuccessor != null) {
 
556
      nn = m_FirstSuccessor.treeToString(text, nn);
 
557
      nn = m_SecondSuccessor.treeToString(text, nn);
 
558
    }
 
559
    return nn;
 
560
  }
 
561
        
 
562
  /**
 
563
   * Main method for testing this class.
 
564
   *
 
565
   * @param argv the options
 
566
   */
 
567
  public static void main(String [] argv) {
 
568
    runClassifier(new DataNearBalancedND(), argv);
 
569
  }
 
570
}
 
571