~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/trees/j48/GraftSplit.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *  GraftSplit.java
 
19
 *  Copyright (C) 2007 Geoff Webb & Janice Boughton
 
20
 *  a split object for nodes added to a tree during grafting.
 
21
 *  (used in classifier J48g).
 
22
 */
 
23
 
 
24
package weka.classifiers.trees.j48;
 
25
 
 
26
import weka.core.*;
 
27
 
 
28
/**
 
29
 * Class implementing a split for nodes added to a tree during grafting.
 
30
 *
 
31
 * @author Janice Boughton (jrbought@infotech.monash.edu.au)
 
32
 * @version $Revision 1.0 $
 
33
 */
 
34
public class GraftSplit extends ClassifierSplitModel implements Comparable {
 
35
 
 
36
  /** the distribution for graft values, from cases in atbop */
 
37
  private Distribution m_graftdistro;
 
38
        
 
39
  /** the attribute we are splitting on */
 
40
  private int m_attIndex;
 
41
 
 
42
  /** value of split point (if numeric attribute) */
 
43
  private double m_splitPoint;
 
44
 
 
45
  /** dominant class of the subset specified by m_testType */
 
46
  private int m_maxClass;
 
47
 
 
48
  /** dominant class of the subset not specified by m_testType */
 
49
  private int m_otherLeafMaxClass;
 
50
 
 
51
  /** laplace value of the subset specified by m_testType for m_maxClass */
 
52
  private double m_laplace;
 
53
 
 
54
  /** leaf for the subset specified by m_testType */
 
55
  private Distribution m_leafdistro;
 
56
 
 
57
  /** 
 
58
   * type of test:
 
59
   * 0: <= test
 
60
   * 1: > test
 
61
   * 2: = test
 
62
   * 3: != test
 
63
   */
 
64
  private int m_testType;
 
65
 
 
66
 
 
67
  /**
 
68
   * constructor
 
69
   *
 
70
   * @param a the attribute to split on
 
71
   * @param v the value of a where split occurs
 
72
   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !)
 
73
   * @param c the class to label the leaf node pointed to by test as.
 
74
   * @param l the laplace value (needed when sorting GraftSplits)
 
75
   */
 
76
  public GraftSplit(int a, double v, int t, double c, double l) {
 
77
 
 
78
    m_attIndex = a;
 
79
    m_splitPoint = v;
 
80
    m_testType = t;
 
81
    m_maxClass = (int)c;
 
82
    m_laplace = l;
 
83
  }
 
84
 
 
85
 
 
86
  /**
 
87
   * constructor
 
88
   *
 
89
   * @param a the attribute to split on
 
90
   * @param v the value of a where split occurs
 
91
   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !=)
 
92
   * @param oC the class to label the leaf node not pointed to by test as.
 
93
   * @param counts the distribution for this split
 
94
   */
 
95
  public GraftSplit(int a, double v, int t, double oC, double [][] counts)
 
96
                                                           throws Exception {
 
97
    m_attIndex = a;
 
98
    m_splitPoint = v;
 
99
    m_testType = t;
 
100
    m_otherLeafMaxClass = (int)oC;
 
101
 
 
102
    // only deal with binary cuts (<= and >; = and !=)
 
103
    m_numSubsets = 2;
 
104
 
 
105
    // which subset are we looking at for the graft?
 
106
    int subset = subsetOfInterest();  // this is the subset for m_leaf
 
107
 
 
108
    // create graft distribution, based on counts
 
109
    m_distribution = new Distribution(counts);
 
110
 
 
111
    // create a distribution object for m_leaf
 
112
    double [][] lcounts = new double[1][m_distribution.numClasses()];
 
113
    for(int c = 0; c < lcounts[0].length; c++) {
 
114
       lcounts[0][c] = counts[subset][c];
 
115
    }
 
116
    m_leafdistro = new Distribution(lcounts);
 
117
 
 
118
    // set the max class
 
119
    m_maxClass = m_distribution.maxClass(subset);
 
120
 
 
121
    // set the laplace value (assumes binary class) for subset of interest
 
122
    m_laplace = (m_distribution.perClassPerBag(subset, m_maxClass) + 1.0) 
 
123
               / (m_distribution.perBag(subset) + 2.0);
 
124
  }
 
125
 
 
126
 
 
127
  /**
 
128
   * deletes the cases in data that belong to leaf pointed to by
 
129
   * the test (i.e. the subset of interest).  this is useful so
 
130
   * the instances belonging to that leaf aren't passed down the
 
131
   * other branch.
 
132
   *
 
133
   * @param data the instances to delete from
 
134
   */
 
135
  public void deleteGraftedCases(Instances data) {
 
136
 
 
137
    int subOfInterest = subsetOfInterest();
 
138
    for(int x = 0; x < data.numInstances(); x++) {
 
139
       if(whichSubset(data.instance(x)) == subOfInterest) {
 
140
          data.delete(x--);
 
141
       }
 
142
    }
 
143
  }
 
144
 
 
145
 
 
146
  /**
 
147
   * builds m_graftdistro using the passed data
 
148
   *
 
149
   * @param data the instances to use when creating the distribution
 
150
   */
 
151
  public void buildClassifier(Instances data) throws Exception {
 
152
 
 
153
    // distribution for the graft, not counting cases in atbop, only orig leaf
 
154
    m_graftdistro = new Distribution(2, data.numClasses());
 
155
 
 
156
    // which subset are we looking at for the graft?
 
157
    int subset = subsetOfInterest();  // this is the subset for m_leaf
 
158
 
 
159
    double thisNodeCount = 0;
 
160
    double knownCases = 0;
 
161
    boolean allKnown = true;
 
162
    // populate distribution
 
163
    for(int x = 0; x < data.numInstances(); x++) {
 
164
       Instance instance = data.instance(x);
 
165
       if(instance.isMissing(m_attIndex)) {
 
166
          allKnown = false;
 
167
          continue;
 
168
       }
 
169
       knownCases += instance.weight();
 
170
       int subst = whichSubset(instance);
 
171
       if(subst == -1)
 
172
          continue;
 
173
       m_graftdistro.add(subst, instance);
 
174
       if(subst == subset) {  // instance belongs at m_leaf
 
175
          thisNodeCount += instance.weight();
 
176
       }
 
177
    }
 
178
    double factor = (knownCases == 0) ? (1.0 / (double)2.0)
 
179
                                      : (thisNodeCount / knownCases);
 
180
    if(!allKnown) {
 
181
       for(int x = 0; x < data.numInstances(); x++) {
 
182
          if(data.instance(x).isMissing(m_attIndex)) {
 
183
             Instance instance = data.instance(x);
 
184
             int subst = whichSubset(instance);
 
185
             if(subst == -1)
 
186
                continue;
 
187
             instance.setWeight(instance.weight() * factor);
 
188
             m_graftdistro.add(subst, instance);
 
189
          }
 
190
       }
 
191
    }
 
192
 
 
193
    // if there are no cases at the leaf, make sure the desired
 
194
    // class is chosen, by setting counts to 0.01
 
195
    if(m_graftdistro.perBag(subset) == 0) {
 
196
       double [] counts = new double[data.numClasses()];
 
197
       counts[m_maxClass] = 0.01;
 
198
       m_graftdistro.add(subset, counts);
 
199
    }
 
200
    if(m_graftdistro.perBag((subset == 0) ? 1 : 0) == 0) {
 
201
       double [] counts = new double[data.numClasses()];
 
202
       counts[(int)m_otherLeafMaxClass] = 0.01;
 
203
       m_graftdistro.add((subset == 0) ? 1 : 0, counts);
 
204
    }
 
205
  }
 
206
 
 
207
 
 
208
  /**
 
209
   * @return the NoSplit object for the leaf pointed to by m_testType branch
 
210
   */
 
211
  public NoSplit getLeaf() {
 
212
    return new NoSplit(m_leafdistro);
 
213
  }
 
214
 
 
215
 
 
216
  /**
 
217
   * @return the NoSplit object for the leaf not pointed to by m_testType branch
 
218
   */
 
219
  public NoSplit getOtherLeaf() {
 
220
 
 
221
    // the bag (subset) that isn't pointed to by m_testType branch
 
222
    int bag = (subsetOfInterest() == 0) ? 1 : 0;
 
223
 
 
224
    double [][] counts = new double[1][m_graftdistro.numClasses()];
 
225
    double totals = 0;
 
226
    for(int c = 0; c < counts[0].length; c++) {
 
227
       counts[0][c] = m_graftdistro.perClassPerBag(bag, c);
 
228
       totals += counts[0][c];
 
229
    }
 
230
    // if empty, make sure proper class gets chosen
 
231
    if(totals == 0) {
 
232
       counts[0][m_otherLeafMaxClass] += 0.01;
 
233
    }
 
234
    return new NoSplit(new Distribution(counts));
 
235
  }
 
236
 
 
237
 
 
238
  /**
 
239
   * Prints label for subset index of instances (eg class).
 
240
   *
 
241
   * @param index the bag to dump label for
 
242
   * @param data to get attribute names and such
 
243
   * @return the label as a string
 
244
   * @exception Exception if something goes wrong
 
245
   */
 
246
  public final String dumpLabelG(int index, Instances data) throws Exception {
 
247
 
 
248
    StringBuffer text;
 
249
 
 
250
    text = new StringBuffer();
 
251
    text.append(((Instances)data).classAttribute().
 
252
       value((index==subsetOfInterest()) ? m_maxClass : m_otherLeafMaxClass));
 
253
    text.append(" ("+Utils.roundDouble(m_graftdistro.perBag(index),1));
 
254
    if(Utils.gr(m_graftdistro.numIncorrect(index),0))
 
255
       text.append("/"
 
256
        +Utils.roundDouble(m_graftdistro.numIncorrect(index),2));
 
257
 
 
258
    // show the graft values, only if this is subsetOfInterest()
 
259
    if(index == subsetOfInterest()) {
 
260
       text.append("|"+Utils.roundDouble(m_distribution.perBag(index),2));
 
261
       if(Utils.gr(m_distribution.numIncorrect(index),0))
 
262
          text.append("/"
 
263
             +Utils.roundDouble(m_distribution.numIncorrect(index),2));
 
264
    }
 
265
    text.append(")");
 
266
    return text.toString();
 
267
  }
 
268
 
 
269
 
 
270
  /**
 
271
   * @return the subset that is specified by the test type
 
272
   */
 
273
  public int subsetOfInterest() {
 
274
    if(m_testType == 2)
 
275
       return 0;
 
276
    if(m_testType == 3)
 
277
       return 1;
 
278
    return m_testType;
 
279
  }
 
280
 
 
281
 
 
282
  /**
 
283
   * @return the number of positive cases in the subset of interest
 
284
   */
 
285
  public double positivesForSubsetOfInterest() {
 
286
    return (m_distribution.perClassPerBag(subsetOfInterest(), m_maxClass));
 
287
  }
 
288
 
 
289
 
 
290
  /**
 
291
   * @param subset the subset to get the positives for
 
292
   * @return the number of positive cases in the specified subset
 
293
   */
 
294
  public double positives(int subset) {
 
295
    return (m_distribution.perClassPerBag(subset, 
 
296
                                    m_distribution.maxClass(subset)));
 
297
  }
 
298
 
 
299
 
 
300
  /**
 
301
   * @return the number of instances in the subset of interest
 
302
   */
 
303
  public double totalForSubsetOfInterest() {
 
304
    return (m_distribution.perBag(subsetOfInterest()));
 
305
  }
 
306
 
 
307
  
 
308
  /**
 
309
   * @param subset the index of the bag to get the total for
 
310
   * @return the number of instances in the subset
 
311
   */
 
312
  public double totalForSubset(int subset) {
 
313
    return (m_distribution.perBag(subset));
 
314
  }
 
315
 
 
316
 
 
317
  /**
 
318
   * Prints left side of condition satisfied by instances.
 
319
   *
 
320
   * @param data the data.
 
321
   */
 
322
  public String leftSide(Instances data) {
 
323
    return data.attribute(m_attIndex).name();
 
324
  }
 
325
 
 
326
 
 
327
  /**
 
328
   * @return the index of the attribute to split on
 
329
   */ 
 
330
  public int attribute() {
 
331
    return m_attIndex;
 
332
  }
 
333
 
 
334
 
 
335
  /**
 
336
   * Prints condition satisfied by instances in subset index.
 
337
   */
 
338
  public final String rightSide(int index, Instances data) {
 
339
 
 
340
    StringBuffer text;
 
341
 
 
342
    text = new StringBuffer();
 
343
    if(data.attribute(m_attIndex).isNominal())
 
344
       if(index == 0)
 
345
          text.append(" = "+
 
346
                      data.attribute(m_attIndex).value((int)m_splitPoint));
 
347
       else
 
348
          text.append(" != "+
 
349
                      data.attribute(m_attIndex).value((int)m_splitPoint));
 
350
    else
 
351
       if(index == 0)
 
352
          text.append(" <= "+
 
353
                      Utils.doubleToString(m_splitPoint,6));
 
354
       else
 
355
          text.append(" > "+
 
356
                      Utils.doubleToString(m_splitPoint,6));
 
357
    return text.toString();
 
358
  }
 
359
 
 
360
 
 
361
  /**
 
362
   * Returns a string containing java source code equivalent to the test
 
363
   * made at this node. The instance being tested is called "i".
 
364
   *
 
365
   * @param index index of the nominal value tested
 
366
   * @param data the data containing instance structure info
 
367
   * @return a value of type 'String'
 
368
   */
 
369
  public final String sourceExpression(int index, Instances data) {
 
370
 
 
371
    StringBuffer expr = null;
 
372
    if(index < 0) {
 
373
       return "i[" + m_attIndex + "] == null";
 
374
    }
 
375
    if(data.attribute(m_attIndex).isNominal()) {
 
376
       if(index == 0)
 
377
          expr = new StringBuffer("i[");
 
378
       else
 
379
          expr = new StringBuffer("!i[");
 
380
       expr.append(m_attIndex).append("]");
 
381
       expr.append(".equals(\"").append(data.attribute(m_attIndex)
 
382
                                      .value((int)m_splitPoint)).append("\")");
 
383
    } else {
 
384
       expr = new StringBuffer("((Double) i[");
 
385
       expr.append(m_attIndex).append("])");
 
386
       if(index == 0) {
 
387
          expr.append(".doubleValue() <= ").append(m_splitPoint);
 
388
       } else {
 
389
          expr.append(".doubleValue() > ").append(m_splitPoint);
 
390
       }
 
391
    }
 
392
    return expr.toString();
 
393
  }
 
394
 
 
395
 
 
396
  /**
 
397
   * @param instance the instance to produce the weights for
 
398
   * @return a double array of weights, null if only belongs to one subset
 
399
   */
 
400
  public double [] weights(Instance instance) {
 
401
 
 
402
    double [] weights;
 
403
    int i;
 
404
 
 
405
    if(instance.isMissing(m_attIndex)) {
 
406
       weights = new double [m_numSubsets];
 
407
       for(i=0;i<m_numSubsets;i++) {
 
408
          weights [i] = m_graftdistro.perBag(i)/m_graftdistro.total();
 
409
       }
 
410
       return weights;
 
411
    } else {
 
412
       return null;
 
413
    }
 
414
  }
 
415
 
 
416
 
 
417
  /**
 
418
   * @param instance the instance for which to determine the subset
 
419
   * @return an int indicating the subset this instance belongs to
 
420
   */
 
421
  public int whichSubset(Instance instance) {
 
422
 
 
423
    if(instance.isMissing(m_attIndex))
 
424
       return -1;
 
425
 
 
426
    if(instance.attribute(m_attIndex).isNominal()) {
 
427
       // in the case of nominal, m_splitPoint is the = value, all else is !=
 
428
       if(instance.value(m_attIndex) == m_splitPoint)
 
429
          return 0;
 
430
       else
 
431
          return 1;
 
432
    } else {
 
433
       if(Utils.smOrEq(instance.value(m_attIndex), m_splitPoint))
 
434
          return 0;
 
435
       else
 
436
          return 1;
 
437
    }
 
438
  }
 
439
 
 
440
 
 
441
  /**
 
442
   * @return the value of the split point
 
443
   */
 
444
  public double splitPoint() {
 
445
    return m_splitPoint;
 
446
  }
 
447
 
 
448
  /**
 
449
   * @return the dominate class for the subset of interest
 
450
   */
 
451
  public int maxClassForSubsetOfInterest() {
 
452
    return m_maxClass;
 
453
  }
 
454
 
 
455
  /**
 
456
   * @return the laplace value for maxClass of subset of interest
 
457
   */
 
458
  public double laplaceForSubsetOfInterest() {
 
459
    return m_laplace;
 
460
  }
 
461
 
 
462
  /**
 
463
   * returns the test type
 
464
   * @return value of testtype
 
465
   */
 
466
  public int testType() {
 
467
    return m_testType;
 
468
  }
 
469
 
 
470
  /**
 
471
   * method needed for sorting a collection of GraftSplits by laplace value
 
472
   * @param g the graft split to compare to this one
 
473
   * @return -1, 0, or 1 if this GraftSplit laplace is <, = or > than that of g
 
474
   */
 
475
  public int compareTo(Object g) {
 
476
 
 
477
    if(m_laplace > ((GraftSplit)g).laplaceForSubsetOfInterest())
 
478
       return 1;
 
479
    if(m_laplace < ((GraftSplit)g).laplaceForSubsetOfInterest())
 
480
       return -1;
 
481
    return 0;
 
482
  }
 
483
 
 
484
  /**
 
485
   * returns the probability for instance for the specified class
 
486
   * @param classIndex the index of the class
 
487
   * @param instance the instance to get the probability for
 
488
   * @param theSubset the subset
 
489
   */
 
490
  public final double classProb(int classIndex, Instance instance, 
 
491
                            int theSubset) throws Exception {
 
492
 
 
493
    if (theSubset <= -1) {
 
494
       double [] weights = weights(instance);
 
495
       if (weights == null) {
 
496
          return m_distribution.prob(classIndex);
 
497
       } else {
 
498
          double prob = 0;
 
499
          for (int i = 0; i < weights.length; i++) {
 
500
             prob += weights[i] * m_distribution.prob(classIndex, i);
 
501
          }
 
502
          return prob;
 
503
       }
 
504
    } else {
 
505
       if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
 
506
          return m_distribution.prob(classIndex, theSubset);
 
507
       } else {
 
508
          return m_distribution.prob(classIndex);
 
509
       }
 
510
    }
 
511
  }
 
512
 
 
513
 
 
514
  /**
 
515
   * method for returning information about this GraftSplit
 
516
   * @param data instances for determining names of attributes and values
 
517
   * @return a string showing this GraftSplit's information
 
518
   */
 
519
  public String toString(Instances data) {
 
520
 
 
521
    String theTest;
 
522
    if(m_testType == 0)
 
523
       theTest = " <= ";
 
524
    else if(m_testType == 1)
 
525
       theTest = " > ";
 
526
    else if(m_testType == 2)
 
527
       theTest = " = ";
 
528
    else
 
529
       theTest = " != ";
 
530
 
 
531
    if(data.attribute(m_attIndex).isNominal())
 
532
       theTest += data.attribute(m_attIndex).value((int)m_splitPoint);
 
533
    else
 
534
       theTest += Double.toString(m_splitPoint);
 
535
 
 
536
    return data.attribute(m_attIndex).name() + theTest
 
537
           + " (" + Double.toString(m_laplace) + ") --> " 
 
538
           + data.attribute(data.classIndex()).value(m_maxClass);
 
539
  }
 
540
}