~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/trees/SimpleCart.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * SimpleCart.java
 
19
 * Copyright (C) 2007 Haijian Shi
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers.trees;
 
24
 
 
25
import weka.classifiers.Evaluation;
 
26
import weka.classifiers.RandomizableClassifier;
 
27
import weka.core.AdditionalMeasureProducer;
 
28
import weka.core.Attribute;
 
29
import weka.core.Capabilities;
 
30
import weka.core.Instance;
 
31
import weka.core.Instances;
 
32
import weka.core.Option;
 
33
import weka.core.TechnicalInformation;
 
34
import weka.core.TechnicalInformationHandler;
 
35
import weka.core.Utils;
 
36
import weka.core.Capabilities.Capability;
 
37
import weka.core.TechnicalInformation.Field;
 
38
import weka.core.TechnicalInformation.Type;
 
39
import weka.core.matrix.Matrix;
 
40
 
 
41
import java.util.Arrays;
 
42
import java.util.Enumeration;
 
43
import java.util.Random;
 
44
import java.util.Vector;
 
45
 
 
46
/**
 
47
 <!-- globalinfo-start -->
 
48
 * Class implementing minimal cost-complexity pruning.<br/>
 
49
 * Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.<br/>
 
50
 * <br/>
 
51
 * For more information, see:<br/>
 
52
 * <br/>
 
53
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
 
54
 * <p/>
 
55
 <!-- globalinfo-end -->        
 
56
 *
 
57
 <!-- technical-bibtex-start -->
 
58
 * BibTeX:
 
59
 * <pre>
 
60
 * &#64;book{Breiman1984,
 
61
 *    address = {Belmont, California},
 
62
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 
63
 *    publisher = {Wadsworth International Group},
 
64
 *    title = {Classification and Regression Trees},
 
65
 *    year = {1984}
 
66
 * }
 
67
 * </pre>
 
68
 * <p/>
 
69
 <!-- technical-bibtex-end -->
 
70
 *
 
71
 <!-- options-start -->
 
72
 * Valid options are: <p/>
 
73
 * 
 
74
 * <pre> -S &lt;num&gt;
 
75
 *  Random number seed.
 
76
 *  (default 1)</pre>
 
77
 * 
 
78
 * <pre> -D
 
79
 *  If set, classifier is run in debug mode and
 
80
 *  may output additional info to the console</pre>
 
81
 * 
 
82
 * <pre> -M &lt;min no&gt;
 
83
 *  The minimal number of instances at the terminal nodes.
 
84
 *  (default 2)</pre>
 
85
 * 
 
86
 * <pre> -N &lt;num folds&gt;
 
87
 *  The number of folds used in the minimal cost-complexity pruning.
 
88
 *  (default 5)</pre>
 
89
 * 
 
90
 * <pre> -U
 
91
 *  Don't use the minimal cost-complexity pruning.
 
92
 *  (default yes).</pre>
 
93
 * 
 
94
 * <pre> -H
 
95
 *  Don't use the heuristic method for binary split.
 
96
 *  (default true).</pre>
 
97
 * 
 
98
 * <pre> -A
 
99
 *  Use 1 SE rule to make pruning decision.
 
100
 *  (default no).</pre>
 
101
 * 
 
102
 * <pre> -C
 
103
 *  Percentage of training data size (0-1].
 
104
 *  (default 1).</pre>
 
105
 * 
 
106
 <!-- options-end -->
 
107
 *
 
108
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 
109
 * @version $Revision: 1.3 $
 
110
 */
 
111
public class SimpleCart
 
112
  extends RandomizableClassifier
 
113
  implements AdditionalMeasureProducer, TechnicalInformationHandler {
 
114
 
 
115
  /** For serialization.         */
 
116
  private static final long serialVersionUID = 4154189200352566053L;
 
117
 
 
118
  /** Training data.  */
 
119
  protected Instances m_train;
 
120
 
 
121
  /** Successor nodes. */
 
122
  protected SimpleCart[] m_Successors;
 
123
 
 
124
  /** Attribute used to split data. */
 
125
  protected Attribute m_Attribute;
 
126
 
 
127
  /** Split point for a numeric attribute. */
 
128
  protected double m_SplitValue;
 
129
 
 
130
  /** Split subset used to split data for nominal attributes. */
 
131
  protected String m_SplitString;
 
132
 
 
133
  /** Class value if the node is leaf. */
 
134
  protected double m_ClassValue;
 
135
 
 
136
  /** Class attriubte of data. */
 
137
  protected Attribute m_ClassAttribute;
 
138
 
 
139
  /** Minimum number of instances in at the terminal nodes. */
 
140
  protected double m_minNumObj = 2;
 
141
 
 
142
  /** Number of folds for minimal cost-complexity pruning. */
 
143
  protected int m_numFoldsPruning = 5;
 
144
 
 
145
  /** Alpha-value (for pruning) at the node. */
 
146
  protected double m_Alpha;
 
147
 
 
148
  /** Number of training examples misclassified by the model (subtree rooted). */
 
149
  protected double m_numIncorrectModel;
 
150
 
 
151
  /** Number of training examples misclassified by the model (subtree not rooted). */
 
152
  protected double m_numIncorrectTree;
 
153
 
 
154
  /** Indicate if the node is a leaf node. */
 
155
  protected boolean m_isLeaf;
 
156
 
 
157
  /** If use minimal cost-compexity pruning. */
 
158
  protected boolean m_Prune = true;
 
159
 
 
160
  /** Total number of instances used to build the classifier. */
 
161
  protected int m_totalTrainInstances;
 
162
 
 
163
  /** Proportion for each branch. */
 
164
  protected double[] m_Props;
 
165
 
 
166
  /** Class probabilities. */
 
167
  protected double[] m_ClassProbs = null;
 
168
 
 
169
  /** Distributions of leaf node (or temporary leaf node in minimal cost-complexity pruning) */
 
170
  protected double[] m_Distribution;
 
171
 
 
172
  /** If use huristic search for nominal attributes in multi-class problems (default true). */
 
173
  protected boolean m_Heuristic = true;
 
174
 
 
175
  /** If use the 1SE rule to make final decision tree. */
 
176
  protected boolean m_UseOneSE = false;
 
177
 
 
178
  /** Training data size. */
 
179
  protected double m_SizePer = 1;
 
180
 
 
181
  /**
 
182
   * Return a description suitable for displaying in the explorer/experimenter.
 
183
   * 
 
184
   * @return            a description suitable for displaying in the 
 
185
   *                    explorer/experimenter
 
186
   */
 
187
  public String globalInfo() {
 
188
    return  
 
189
        "Class implementing minimal cost-complexity pruning.\n"
 
190
      + "Note when dealing with missing values, use \"fractional "
 
191
      + "instances\" method instead of surrogate split method.\n\n"
 
192
      + "For more information, see:\n\n"
 
193
      + getTechnicalInformation().toString();
 
194
  }
 
195
 
 
196
  /**
 
197
   * Returns an instance of a TechnicalInformation object, containing 
 
198
   * detailed information about the technical background of this class,
 
199
   * e.g., paper reference or book this class is based on.
 
200
   * 
 
201
   * @return            the technical information about this class
 
202
   */
 
203
  public TechnicalInformation getTechnicalInformation() {
 
204
    TechnicalInformation        result;
 
205
    
 
206
    result = new TechnicalInformation(Type.BOOK);
 
207
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
 
208
    result.setValue(Field.YEAR, "1984");
 
209
    result.setValue(Field.TITLE, "Classification and Regression Trees");
 
210
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
 
211
    result.setValue(Field.ADDRESS, "Belmont, California");
 
212
    
 
213
    return result;
 
214
  }
 
215
 
 
216
  /**
 
217
   * Returns default capabilities of the classifier.
 
218
   * 
 
219
   * @return            the capabilities of this classifier
 
220
   */
 
221
  public Capabilities getCapabilities() {
 
222
    Capabilities result = super.getCapabilities();
 
223
 
 
224
    // attributes
 
225
    result.enable(Capability.NOMINAL_ATTRIBUTES);
 
226
    result.enable(Capability.NUMERIC_ATTRIBUTES);
 
227
    result.enable(Capability.MISSING_VALUES);
 
228
 
 
229
    // class
 
230
    result.enable(Capability.NOMINAL_CLASS);
 
231
 
 
232
    return result;
 
233
  }
 
234
 
 
235
  /**
 
236
   * Build the classifier.
 
237
   * 
 
238
   * @param data        the training instances
 
239
   * @throws Exception  if something goes wrong
 
240
   */
 
241
  public void buildClassifier(Instances data) throws Exception {
 
242
 
 
243
    getCapabilities().testWithFail(data);
 
244
    data = new Instances(data);        
 
245
    data.deleteWithMissingClass();
 
246
 
 
247
    // unpruned CART decision tree
 
248
    if (!m_Prune) {
 
249
 
 
250
      // calculate sorted indices and weights, and compute initial class counts.
 
251
      int[][] sortedIndices = new int[data.numAttributes()][0];
 
252
      double[][] weights = new double[data.numAttributes()][0];
 
253
      double[] classProbs = new double[data.numClasses()];
 
254
      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
 
255
 
 
256
      makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
 
257
          totalWeight,m_minNumObj, m_Heuristic);
 
258
      return;
 
259
    }
 
260
 
 
261
    Random random = new Random(m_Seed);
 
262
    Instances cvData = new Instances(data);
 
263
    cvData.randomize(random);
 
264
    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
 
265
    cvData.stratify(m_numFoldsPruning);
 
266
 
 
267
    double[][] alphas = new double[m_numFoldsPruning][];
 
268
    double[][] errors = new double[m_numFoldsPruning][];
 
269
 
 
270
    // calculate errors and alphas for each fold
 
271
    for (int i = 0; i < m_numFoldsPruning; i++) {
 
272
 
 
273
      //for every fold, grow tree on training set and fix error on test set.
 
274
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
 
275
      Instances test = cvData.testCV(m_numFoldsPruning, i);
 
276
 
 
277
      // calculate sorted indices and weights, and compute initial class counts for each fold
 
278
      int[][] sortedIndices = new int[train.numAttributes()][0];
 
279
      double[][] weights = new double[train.numAttributes()][0];
 
280
      double[] classProbs = new double[train.numClasses()];
 
281
      double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);
 
282
 
 
283
      makeTree(train, train.numInstances(),sortedIndices,weights,classProbs,
 
284
          totalWeight,m_minNumObj, m_Heuristic);
 
285
 
 
286
      int numNodes = numInnerNodes();
 
287
      alphas[i] = new double[numNodes + 2];
 
288
      errors[i] = new double[numNodes + 2];
 
289
 
 
290
      // prune back and log alpha-values and errors on test set
 
291
      prune(alphas[i], errors[i], test);
 
292
    }
 
293
 
 
294
    // calculate sorted indices and weights, and compute initial class counts on all training instances
 
295
    int[][] sortedIndices = new int[data.numAttributes()][0];
 
296
    double[][] weights = new double[data.numAttributes()][0];
 
297
    double[] classProbs = new double[data.numClasses()];
 
298
    double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
 
299
 
 
300
    //build tree using all the data
 
301
    makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
 
302
        totalWeight,m_minNumObj, m_Heuristic);
 
303
 
 
304
    int numNodes = numInnerNodes();
 
305
 
 
306
    double[] treeAlphas = new double[numNodes + 2];
 
307
 
 
308
    // prune back and log alpha-values
 
309
    int iterations = prune(treeAlphas, null, null);
 
310
 
 
311
    double[] treeErrors = new double[numNodes + 2];
 
312
 
 
313
    // for each pruned subtree, find the cross-validated error
 
314
    for (int i = 0; i <= iterations; i++){
 
315
      //compute midpoint alphas
 
316
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
 
317
      double error = 0;
 
318
      for (int k = 0; k < m_numFoldsPruning; k++) {
 
319
        int l = 0;
 
320
        while (alphas[k][l] <= alpha) l++;
 
321
        error += errors[k][l - 1];
 
322
      }
 
323
      treeErrors[i] = error/m_numFoldsPruning;
 
324
    }
 
325
 
 
326
    // find best alpha
 
327
    int best = -1;
 
328
    double bestError = Double.MAX_VALUE;
 
329
    for (int i = iterations; i >= 0; i--) {
 
330
      if (treeErrors[i] < bestError) {
 
331
        bestError = treeErrors[i];
 
332
        best = i;
 
333
      }
 
334
    }
 
335
 
 
336
    // 1 SE rule to choose expansion
 
337
    if (m_UseOneSE) {
 
338
      double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
 
339
      for (int i = iterations; i >= 0; i--) {
 
340
        if (treeErrors[i] <= bestError+oneSE) {
 
341
          best = i;
 
342
          break;
 
343
        }
 
344
      }
 
345
    }
 
346
 
 
347
    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);
 
348
 
 
349
    //"unprune" final tree (faster than regrowing it)
 
350
    unprune();
 
351
    prune(bestAlpha);        
 
352
  }
 
353
 
 
354
  /**
 
355
   * Make binary decision tree recursively.
 
356
   * 
 
357
   * @param data                the training instances
 
358
   * @param totalInstances      total number of instances
 
359
   * @param sortedIndices       sorted indices of the instances
 
360
   * @param weights             weights of the instances
 
361
   * @param classProbs          class probabilities
 
362
   * @param totalWeight         total weight of instances
 
363
   * @param minNumObj           minimal number of instances at leaf nodes
 
364
   * @param useHeuristic        if use heuristic search for nominal attributes in multi-class problem
 
365
   * @throws Exception          if something goes wrong
 
366
   */
 
367
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
 
368
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
 
369
      boolean useHeuristic) throws Exception{
 
370
 
 
371
    // if no instances have reached this node (normally won't happen)
 
372
    if (totalWeight == 0){
 
373
      m_Attribute = null;
 
374
      m_ClassValue = Instance.missingValue();
 
375
      m_Distribution = new double[data.numClasses()];
 
376
      return;
 
377
    }
 
378
 
 
379
    m_totalTrainInstances = totalInstances;
 
380
    m_isLeaf = true;
 
381
 
 
382
    m_ClassProbs = new double[classProbs.length];
 
383
    m_Distribution = new double[classProbs.length];
 
384
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
 
385
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
 
386
    if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);
 
387
 
 
388
    // Compute class distributions and value of splitting
 
389
    // criterion for each attribute
 
390
    double[][][] dists = new double[data.numAttributes()][0][0];
 
391
    double[][] props = new double[data.numAttributes()][0];
 
392
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
 
393
    double[] splits = new double[data.numAttributes()];
 
394
    String[] splitString = new String[data.numAttributes()];
 
395
    double[] giniGains = new double[data.numAttributes()];
 
396
 
 
397
    // for each attribute find split information
 
398
    for (int i = 0; i < data.numAttributes(); i++) {
 
399
      Attribute att = data.attribute(i);
 
400
      if (i==data.classIndex()) continue;
 
401
      if (att.isNumeric()) {
 
402
        // numeric attribute
 
403
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
 
404
            weights[i], totalSubsetWeights, giniGains, data);
 
405
      } else {
 
406
        // nominal attribute
 
407
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
 
408
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
 
409
      }
 
410
    }
 
411
 
 
412
    // Find best attribute (split with maximum Gini gain)
 
413
    int attIndex = Utils.maxIndex(giniGains);
 
414
    m_Attribute = data.attribute(attIndex);
 
415
 
 
416
    m_train = new Instances(data, sortedIndices[attIndex].length);
 
417
    for (int i=0; i<sortedIndices[attIndex].length; i++) {
 
418
      Instance inst = data.instance(sortedIndices[attIndex][i]);
 
419
      Instance instCopy = (Instance)inst.copy();
 
420
      instCopy.setWeight(weights[attIndex][i]);
 
421
      m_train.add(instCopy);
 
422
    }
 
423
 
 
424
    // Check if node does not contain enough instances, or if it can not be split,
 
425
    // or if it is pure. If does, make leaf.
 
426
    if (totalWeight < 2 * minNumObj || giniGains[attIndex]==0 ||
 
427
        props[attIndex][0]==0 || props[attIndex][1]==0) {
 
428
      makeLeaf(data);
 
429
    }
 
430
 
 
431
    else {            
 
432
      m_Props = props[attIndex];
 
433
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
 
434
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];
 
435
 
 
436
      // numeric split
 
437
      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];
 
438
 
 
439
      // nominal split
 
440
      else m_SplitString = splitString[attIndex];
 
441
 
 
442
      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
 
443
          m_SplitString, sortedIndices, weights, data);
 
444
 
 
445
      // If split of the node results in a node with less than minimal number of isntances, 
 
446
      // make the node leaf node.
 
447
      if (subsetIndices[0][attIndex].length<minNumObj ||
 
448
          subsetIndices[1][attIndex].length<minNumObj) {
 
449
        makeLeaf(data);
 
450
        return;
 
451
      }
 
452
 
 
453
      // Otherwise, split the node.
 
454
      m_isLeaf = false;
 
455
      m_Successors = new SimpleCart[2];
 
456
      for (int i = 0; i < 2; i++) {
 
457
        m_Successors[i] = new SimpleCart();
 
458
        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
 
459
            subsetWeights[i],dists[attIndex][i], totalSubsetWeights[attIndex][i],
 
460
            minNumObj, useHeuristic);
 
461
      }
 
462
    }
 
463
  }
 
464
 
 
465
  /**
 
466
   * Prunes the original tree using the CART pruning scheme, given a 
 
467
   * cost-complexity parameter alpha.
 
468
   * 
 
469
   * @param alpha       the cost-complexity parameter
 
470
   * @throws Exception  if something goes wrong
 
471
   */
 
472
  public void prune(double alpha) throws Exception {
 
473
 
 
474
    Vector nodeList;
 
475
 
 
476
    // determine training error of pruned subtrees (both with and without replacing a subtree),
 
477
    // and calculate alpha-values from them
 
478
    modelErrors();
 
479
    treeErrors();
 
480
    calculateAlphas();
 
481
 
 
482
    // get list of all inner nodes in the tree
 
483
    nodeList = getInnerNodes();
 
484
 
 
485
    boolean prune = (nodeList.size() > 0);
 
486
    double preAlpha = Double.MAX_VALUE;
 
487
    while (prune) {
 
488
 
 
489
      // select node with minimum alpha
 
490
      SimpleCart nodeToPrune = nodeToPrune(nodeList);
 
491
 
 
492
      // want to prune if its alpha is smaller than alpha
 
493
      if (nodeToPrune.m_Alpha > alpha) {
 
494
        break;
 
495
      }
 
496
 
 
497
      nodeToPrune.makeLeaf(nodeToPrune.m_train);
 
498
 
 
499
      // normally would not happen
 
500
      if (nodeToPrune.m_Alpha==preAlpha) {
 
501
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
 
502
        treeErrors();
 
503
        calculateAlphas();
 
504
        nodeList = getInnerNodes();
 
505
        prune = (nodeList.size() > 0);
 
506
        continue;
 
507
      }
 
508
      preAlpha = nodeToPrune.m_Alpha;
 
509
 
 
510
      //update tree errors and alphas
 
511
      treeErrors();
 
512
      calculateAlphas();
 
513
 
 
514
      nodeList = getInnerNodes();
 
515
      prune = (nodeList.size() > 0);
 
516
    }
 
517
  }
 
518
 
 
519
  /**
 
520
   * Method for performing one fold in the cross-validation of minimal 
 
521
   * cost-complexity pruning. Generates a sequence of alpha-values with error 
 
522
   * estimates for the corresponding (partially pruned) trees, given the test 
 
523
   * set of that fold.
 
524
   *
 
525
   * @param alphas      array to hold the generated alpha-values
 
526
   * @param errors      array to hold the corresponding error estimates
 
527
   * @param test        test set of that fold (to obtain error estimates)
 
528
   * @return            the iteration of the pruning
 
529
   * @throws Exception  if something goes wrong
 
530
   */
 
531
  public int prune(double[] alphas, double[] errors, Instances test) 
 
532
    throws Exception {
 
533
 
 
534
    Vector nodeList;
 
535
 
 
536
    // determine training error of subtrees (both with and without replacing a subtree), 
 
537
    // and calculate alpha-values from them
 
538
    modelErrors();
 
539
    treeErrors();
 
540
    calculateAlphas();
 
541
 
 
542
    // get list of all inner nodes in the tree
 
543
    nodeList = getInnerNodes();
 
544
 
 
545
    boolean prune = (nodeList.size() > 0);
 
546
 
 
547
    //alpha_0 is always zero (unpruned tree)
 
548
    alphas[0] = 0;
 
549
 
 
550
    Evaluation eval;
 
551
 
 
552
    // error of unpruned tree
 
553
    if (errors != null) {
 
554
      eval = new Evaluation(test);
 
555
      eval.evaluateModel(this, test);
 
556
      errors[0] = eval.errorRate();
 
557
    }
 
558
 
 
559
    int iteration = 0;
 
560
    double preAlpha = Double.MAX_VALUE;
 
561
    while (prune) {
 
562
 
 
563
      iteration++;
 
564
 
 
565
      // get node with minimum alpha
 
566
      SimpleCart nodeToPrune = nodeToPrune(nodeList);
 
567
 
 
568
      // do not set m_sons null, want to unprune
 
569
      nodeToPrune.m_isLeaf = true;
 
570
 
 
571
      // normally would not happen
 
572
      if (nodeToPrune.m_Alpha==preAlpha) {
 
573
        iteration--;
 
574
        treeErrors();
 
575
        calculateAlphas();
 
576
        nodeList = getInnerNodes();
 
577
        prune = (nodeList.size() > 0);
 
578
        continue;
 
579
      }
 
580
 
 
581
      // get alpha-value of node
 
582
      alphas[iteration] = nodeToPrune.m_Alpha;
 
583
 
 
584
      // log error
 
585
      if (errors != null) {
 
586
        eval = new Evaluation(test);
 
587
        eval.evaluateModel(this, test);
 
588
        errors[iteration] = eval.errorRate();
 
589
      }
 
590
      preAlpha = nodeToPrune.m_Alpha;
 
591
 
 
592
      //update errors/alphas
 
593
      treeErrors();
 
594
      calculateAlphas();
 
595
 
 
596
      nodeList = getInnerNodes();
 
597
      prune = (nodeList.size() > 0);
 
598
    }
 
599
 
 
600
    //set last alpha 1 to indicate end
 
601
    alphas[iteration + 1] = 1.0;
 
602
    return iteration;
 
603
  }
 
604
 
 
605
  /**
 
606
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
 
607
   * Faster than re-growing the tree because CART do not have to be fit again.
 
608
   */
 
609
  protected void unprune() {
 
610
    if (m_Successors != null) {
 
611
      m_isLeaf = false;
 
612
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
 
613
    }
 
614
  }
 
615
 
 
616
  /**
 
617
   * Compute distributions, proportions and total weights of two successor 
 
618
   * nodes for a given numeric attribute.
 
619
   * 
 
620
   * @param props               proportions of each two branches for each attribute
 
621
   * @param dists               class distributions of two branches for each attribute
 
622
   * @param att                 numeric att split on
 
623
   * @param sortedIndices       sorted indices of instances for the attirubte
 
624
   * @param weights             weights of instances for the attirbute
 
625
   * @param subsetWeights       total weight of two branches split based on the attribute
 
626
   * @param giniGains           Gini gains for each attribute 
 
627
   * @param data                training instances
 
628
   * @return                    Gini gain the given numeric attribute
 
629
   * @throws Exception          if something goes wrong
 
630
   */
 
631
  protected double numericDistribution(double[][] props, double[][][] dists,
 
632
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
 
633
      double[] giniGains, Instances data)
 
634
    throws Exception {
 
635
 
 
636
    double splitPoint = Double.NaN;
 
637
    double[][] dist = null;
 
638
    int numClasses = data.numClasses();
 
639
    int i; // differ instances with or without missing values
 
640
 
 
641
    double[][] currDist = new double[2][numClasses];
 
642
    dist = new double[2][numClasses];
 
643
 
 
644
    // Move all instances without missing values into second subset
 
645
    double[] parentDist = new double[numClasses];
 
646
    int missingStart = 0;
 
647
    for (int j = 0; j < sortedIndices.length; j++) {
 
648
      Instance inst = data.instance(sortedIndices[j]);
 
649
      if (!inst.isMissing(att)) {
 
650
        missingStart ++;
 
651
        currDist[1][(int)inst.classValue()] += weights[j];
 
652
      }
 
653
      parentDist[(int)inst.classValue()] += weights[j];
 
654
    }
 
655
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
 
656
 
 
657
    // Try all possible split points
 
658
    double currSplit = data.instance(sortedIndices[0]).value(att);
 
659
    double currGiniGain;
 
660
    double bestGiniGain = -Double.MAX_VALUE;
 
661
 
 
662
    for (i = 0; i < sortedIndices.length; i++) {
 
663
      Instance inst = data.instance(sortedIndices[i]);
 
664
      if (inst.isMissing(att)) {
 
665
        break;
 
666
      }
 
667
      if (inst.value(att) > currSplit) {
 
668
 
 
669
        double[][] tempDist = new double[2][numClasses];
 
670
        for (int k=0; k<2; k++) {
 
671
          //tempDist[k] = currDist[k];
 
672
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
 
673
        }
 
674
 
 
675
        double[] tempProps = new double[2];
 
676
        for (int k=0; k<2; k++) {
 
677
          tempProps[k] = Utils.sum(tempDist[k]);
 
678
        }
 
679
 
 
680
        if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);
 
681
 
 
682
        // split missing values
 
683
        int index = missingStart;
 
684
        while (index < sortedIndices.length) {
 
685
          Instance insta = data.instance(sortedIndices[index]);
 
686
          for (int j = 0; j < 2; j++) {
 
687
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
 
688
          }
 
689
          index++;
 
690
        }
 
691
 
 
692
        currGiniGain = computeGiniGain(parentDist,tempDist);
 
693
 
 
694
        if (currGiniGain > bestGiniGain) {
 
695
          bestGiniGain = currGiniGain;
 
696
 
 
697
          // clean split point
 
698
          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0; 
 
699
 
 
700
          for (int j = 0; j < currDist.length; j++) {
 
701
            System.arraycopy(tempDist[j], 0, dist[j], 0,
 
702
                dist[j].length);
 
703
          }
 
704
        }
 
705
      }
 
706
      currSplit = inst.value(att);
 
707
      currDist[0][(int)inst.classValue()] += weights[i];
 
708
      currDist[1][(int)inst.classValue()] -= weights[i];
 
709
    }
 
710
 
 
711
    // Compute weights
 
712
    int attIndex = att.index();
 
713
    props[attIndex] = new double[2];
 
714
    for (int k = 0; k < 2; k++) {
 
715
      props[attIndex][k] = Utils.sum(dist[k]);
 
716
    }
 
717
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);
 
718
 
 
719
    // Compute subset weights
 
720
    subsetWeights[attIndex] = new double[2];
 
721
    for (int j = 0; j < 2; j++) {
 
722
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
 
723
    }
 
724
 
 
725
    // clean Gini gain
 
726
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
 
727
    dists[attIndex] = dist;
 
728
 
 
729
    return splitPoint;
 
730
  }
 
731
 
 
732
  /**
 
733
   * Compute distributions, proportions and total weights of two successor 
 
734
   * nodes for a given nominal attribute.
 
735
   * 
 
736
   * @param props               proportions of each two branches for each attribute
 
737
   * @param dists               class distributions of two branches for each attribute
 
738
   * @param att                 numeric att split on
 
739
   * @param sortedIndices       sorted indices of instances for the attirubte
 
740
   * @param weights             weights of instances for the attirbute
 
741
   * @param subsetWeights       total weight of two branches split based on the attribute
 
742
   * @param giniGains           Gini gains for each attribute 
 
743
   * @param data                training instances
 
744
   * @param useHeuristic        if use heuristic search
 
745
   * @return                    Gini gain for the given nominal attribute
 
746
   * @throws Exception          if something goes wrong
 
747
   */
 
748
  protected String nominalDistribution(double[][] props, double[][][] dists,
 
749
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
 
750
      double[] giniGains, Instances data, boolean useHeuristic)
 
751
    throws Exception {
 
752
 
 
753
    String[] values = new String[att.numValues()];
 
754
    int numCat = values.length; // number of values of the attribute
 
755
    int numClasses = data.numClasses();
 
756
 
 
757
    String bestSplitString = "";
 
758
    double bestGiniGain = -Double.MAX_VALUE;
 
759
 
 
760
    // class frequency for each value
 
761
    int[] classFreq = new int[numCat];
 
762
    for (int j=0; j<numCat; j++) classFreq[j] = 0;
 
763
 
 
764
    double[] parentDist = new double[numClasses];
 
765
    double[][] currDist = new double[2][numClasses];
 
766
    double[][] dist = new double[2][numClasses];
 
767
    int missingStart = 0;
 
768
 
 
769
    for (int i = 0; i < sortedIndices.length; i++) {
 
770
      Instance inst = data.instance(sortedIndices[i]);
 
771
      if (!inst.isMissing(att)) {
 
772
        missingStart++;
 
773
        classFreq[(int)inst.value(att)] ++;
 
774
      }
 
775
      parentDist[(int)inst.classValue()] += weights[i];
 
776
    }
 
777
 
 
778
    // count the number of values that class frequency is not 0
 
779
    int nonEmpty = 0;
 
780
    for (int j=0; j<numCat; j++) {
 
781
      if (classFreq[j]!=0) nonEmpty ++;
 
782
    }
 
783
 
 
784
    // attribute values that class frequency is not 0
 
785
    String[] nonEmptyValues = new String[nonEmpty];
 
786
    int nonEmptyIndex = 0;
 
787
    for (int j=0; j<numCat; j++) {
 
788
      if (classFreq[j]!=0) {
 
789
        nonEmptyValues[nonEmptyIndex] = att.value(j);
 
790
        nonEmptyIndex ++;
 
791
      }
 
792
    }
 
793
 
 
794
    // attribute values that class frequency is 0
 
795
    int empty = numCat - nonEmpty;
 
796
    String[] emptyValues = new String[empty];
 
797
    int emptyIndex = 0;
 
798
    for (int j=0; j<numCat; j++) {
 
799
      if (classFreq[j]==0) {
 
800
        emptyValues[emptyIndex] = att.value(j);
 
801
        emptyIndex ++;
 
802
      }
 
803
    }
 
804
 
 
805
    if (nonEmpty<=1) {
 
806
      giniGains[att.index()] = 0;
 
807
      return "";
 
808
    }
 
809
 
 
810
    // for tow-class probloms
 
811
    if (data.numClasses()==2) {
 
812
 
 
813
      //// Firstly, for attribute values which class frequency is not zero
 
814
 
 
815
      // probability of class 0 for each attribute value
 
816
      double[] pClass0 = new double[nonEmpty];
 
817
      // class distribution for each attribute value
 
818
      double[][] valDist = new double[nonEmpty][2];
 
819
 
 
820
      for (int j=0; j<nonEmpty; j++) {
 
821
        for (int k=0; k<2; k++) {
 
822
          valDist[j][k] = 0;
 
823
        }
 
824
      }
 
825
 
 
826
      for (int i = 0; i < sortedIndices.length; i++) {
 
827
        Instance inst = data.instance(sortedIndices[i]);
 
828
        if (inst.isMissing(att)) {
 
829
          break;
 
830
        }
 
831
 
 
832
        for (int j=0; j<nonEmpty; j++) {
 
833
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
 
834
            valDist[j][(int)inst.classValue()] += inst.weight();
 
835
            break;
 
836
          }
 
837
        }
 
838
      }
 
839
 
 
840
      for (int j=0; j<nonEmpty; j++) {
 
841
        double distSum = Utils.sum(valDist[j]);
 
842
        if (distSum==0) pClass0[j]=0;
 
843
        else pClass0[j] = valDist[j][0]/distSum;
 
844
      }
 
845
 
 
846
      // sort category according to the probability of the first class
 
847
      String[] sortedValues = new String[nonEmpty];
 
848
      for (int j=0; j<nonEmpty; j++) {
 
849
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
 
850
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
 
851
      }
 
852
 
 
853
      // Find a subset of attribute values that maximize Gini decrease
 
854
 
 
855
      // for the attribute values that class frequency is not 0
 
856
      String tempStr = "";
 
857
 
 
858
      for (int j=0; j<nonEmpty-1; j++) {
 
859
        currDist = new double[2][numClasses];
 
860
        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
 
861
        else tempStr += "|"+ "(" + sortedValues[j] + ")";
 
862
        for (int i=0; i<sortedIndices.length;i++) {
 
863
          Instance inst = data.instance(sortedIndices[i]);
 
864
          if (inst.isMissing(att)) {
 
865
            break;
 
866
          }
 
867
 
 
868
          if (tempStr.indexOf
 
869
              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
 
870
            currDist[0][(int)inst.classValue()] += weights[i];
 
871
          } else currDist[1][(int)inst.classValue()] += weights[i];
 
872
        }
 
873
 
 
874
        double[][] tempDist = new double[2][numClasses];
 
875
        for (int kk=0; kk<2; kk++) {
 
876
          tempDist[kk] = currDist[kk];
 
877
        }
 
878
 
 
879
        double[] tempProps = new double[2];
 
880
        for (int kk=0; kk<2; kk++) {
 
881
          tempProps[kk] = Utils.sum(tempDist[kk]);
 
882
        }
 
883
 
 
884
        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
 
885
 
 
886
        // split missing values
 
887
        int mstart = missingStart;
 
888
        while (mstart < sortedIndices.length) {
 
889
          Instance insta = data.instance(sortedIndices[mstart]);
 
890
          for (int jj = 0; jj < 2; jj++) {
 
891
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
 
892
          }
 
893
          mstart++;
 
894
        }
 
895
 
 
896
        double currGiniGain = computeGiniGain(parentDist,tempDist);
 
897
 
 
898
        if (currGiniGain>bestGiniGain) {
 
899
          bestGiniGain = currGiniGain;
 
900
          bestSplitString = tempStr;
 
901
          for (int jj = 0; jj < 2; jj++) {
 
902
            //dist[jj] = new double[currDist[jj].length];
 
903
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
 
904
                dist[jj].length);
 
905
          }
 
906
        }
 
907
      }
 
908
    }
 
909
 
 
910
    // multi-class problems - exhaustive search
 
911
    else if (!useHeuristic || nonEmpty<=4) {
 
912
 
 
913
      // Firstly, for attribute values which class frequency is not zero
 
914
      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
 
915
        String tempStr="";
 
916
        currDist = new double[2][numClasses];
 
917
        int mod;
 
918
        int bit10 = i;
 
919
        for (int j=nonEmpty-1; j>=0; j--) {
 
920
          mod = bit10%2; // convert from 10bit to 2bit
 
921
          if (mod==1) {
 
922
            if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
 
923
            else tempStr += "|" + "("+nonEmptyValues[j]+")";
 
924
          }
 
925
          bit10 = bit10/2;
 
926
        }
 
927
        for (int j=0; j<sortedIndices.length;j++) {
 
928
          Instance inst = data.instance(sortedIndices[j]);
 
929
          if (inst.isMissing(att)) {
 
930
            break;
 
931
          }
 
932
 
 
933
          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
 
934
            currDist[0][(int)inst.classValue()] += weights[j];
 
935
          } else currDist[1][(int)inst.classValue()] += weights[j];
 
936
        }
 
937
 
 
938
        double[][] tempDist = new double[2][numClasses];
 
939
        for (int k=0; k<2; k++) {
 
940
          tempDist[k] = currDist[k];
 
941
        }
 
942
 
 
943
        double[] tempProps = new double[2];
 
944
        for (int k=0; k<2; k++) {
 
945
          tempProps[k] = Utils.sum(tempDist[k]);
 
946
        }
 
947
 
 
948
        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
 
949
 
 
950
        // split missing values
 
951
        int index = missingStart;
 
952
        while (index < sortedIndices.length) {
 
953
          Instance insta = data.instance(sortedIndices[index]);
 
954
          for (int j = 0; j < 2; j++) {
 
955
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
 
956
          }
 
957
          index++;
 
958
        }
 
959
 
 
960
        double currGiniGain = computeGiniGain(parentDist,tempDist);
 
961
 
 
962
        if (currGiniGain>bestGiniGain) {
 
963
          bestGiniGain = currGiniGain;
 
964
          bestSplitString = tempStr;
 
965
          for (int j = 0; j < 2; j++) {
 
966
            //dist[jj] = new double[currDist[jj].length];
 
967
            System.arraycopy(tempDist[j], 0, dist[j], 0,
 
968
                dist[j].length);
 
969
          }
 
970
        }
 
971
      }
 
972
    }
 
973
 
 
974
    // huristic search to solve multi-classes problems
 
975
    else {
 
976
      // Firstly, for attribute values which class frequency is not zero
 
977
      int n = nonEmpty;
 
978
      int k = data.numClasses();  // number of classes of the data
 
979
      double[][] P = new double[n][k];      // class probability matrix
 
980
      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
 
981
      double[] meanClass = new double[k];   // vector of mean class probability
 
982
      int numInstances = data.numInstances(); // total number of instances
 
983
 
 
984
      // initialize the vector of mean class probability
 
985
      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;
 
986
 
 
987
      for (int j=0; j<numInstances; j++) {
 
988
        Instance inst = (Instance)data.instance(j);
 
989
        int valueIndex = 0; // attribute value index in nonEmptyValues
 
990
        for (int i=0; i<nonEmpty; i++) {
 
991
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
 
992
            valueIndex = i;
 
993
            break;
 
994
          }
 
995
        }
 
996
        P[valueIndex][(int)inst.classValue()]++;
 
997
        numInstancesValue[valueIndex]++;
 
998
        meanClass[(int)inst.classValue()]++;
 
999
      }
 
1000
 
 
1001
      // calculate the class probability matrix
 
1002
      for (int i=0; i<P.length; i++) {
 
1003
        for (int j=0; j<P[0].length; j++) {
 
1004
          if (numInstancesValue[i]==0) P[i][j]=0;
 
1005
          else P[i][j]/=numInstancesValue[i];
 
1006
        }
 
1007
      }
 
1008
 
 
1009
      //calculate the vector of mean class probability
 
1010
      for (int i=0; i<meanClass.length; i++) {
 
1011
        meanClass[i]/=numInstances;
 
1012
      }
 
1013
 
 
1014
      // calculate the covariance matrix
 
1015
      double[][] covariance = new double[k][k];
 
1016
      for (int i1=0; i1<k; i1++) {
 
1017
        for (int i2=0; i2<k; i2++) {
 
1018
          double element = 0;
 
1019
          for (int j=0; j<n; j++) {
 
1020
            element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
 
1021
            *numInstancesValue[j];
 
1022
          }
 
1023
          covariance[i1][i2] = element;
 
1024
        }
 
1025
      }
 
1026
 
 
1027
      Matrix matrix = new Matrix(covariance);
 
1028
      weka.core.matrix.EigenvalueDecomposition eigen =
 
1029
        new weka.core.matrix.EigenvalueDecomposition(matrix);
 
1030
      double[] eigenValues = eigen.getRealEigenvalues();
 
1031
 
 
1032
      // find index of the largest eigenvalue
 
1033
      int index=0;
 
1034
      double largest = eigenValues[0];
 
1035
      for (int i=1; i<eigenValues.length; i++) {
 
1036
        if (eigenValues[i]>largest) {
 
1037
          index=i;
 
1038
          largest = eigenValues[i];
 
1039
        }
 
1040
      }
 
1041
 
 
1042
      // calculate the first principle component
 
1043
      double[] FPC = new double[k];
 
1044
      Matrix eigenVector = eigen.getV();
 
1045
      double[][] vectorArray = eigenVector.getArray();
 
1046
      for (int i=0; i<FPC.length; i++) {
 
1047
        FPC[i] = vectorArray[i][index];
 
1048
      }
 
1049
 
 
1050
      // calculate the first principle component scores
 
1051
      //System.out.println("the first principle component scores: ");
 
1052
      double[] Sa = new double[n];
 
1053
      for (int i=0; i<Sa.length; i++) {
 
1054
        Sa[i]=0;
 
1055
        for (int j=0; j<k; j++) {
 
1056
          Sa[i] += FPC[j]*P[i][j];
 
1057
        }
 
1058
      }
 
1059
 
 
1060
      // sort category according to Sa(s)
 
1061
      double[] pCopy = new double[n];
 
1062
      System.arraycopy(Sa,0,pCopy,0,n);
 
1063
      String[] sortedValues = new String[n];
 
1064
      Arrays.sort(Sa);
 
1065
 
 
1066
      for (int j=0; j<n; j++) {
 
1067
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
 
1068
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
 
1069
      }
 
1070
 
 
1071
      // for the attribute values that class frequency is not 0
 
1072
      String tempStr = "";
 
1073
 
 
1074
      for (int j=0; j<nonEmpty-1; j++) {
 
1075
        currDist = new double[2][numClasses];
 
1076
        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
 
1077
        else tempStr += "|"+ "(" + sortedValues[j] + ")";
 
1078
        for (int i=0; i<sortedIndices.length;i++) {
 
1079
          Instance inst = data.instance(sortedIndices[i]);
 
1080
          if (inst.isMissing(att)) {
 
1081
            break;
 
1082
          }
 
1083
 
 
1084
          if (tempStr.indexOf
 
1085
              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
 
1086
            currDist[0][(int)inst.classValue()] += weights[i];
 
1087
          } else currDist[1][(int)inst.classValue()] += weights[i];
 
1088
        }
 
1089
 
 
1090
        double[][] tempDist = new double[2][numClasses];
 
1091
        for (int kk=0; kk<2; kk++) {
 
1092
          tempDist[kk] = currDist[kk];
 
1093
        }
 
1094
 
 
1095
        double[] tempProps = new double[2];
 
1096
        for (int kk=0; kk<2; kk++) {
 
1097
          tempProps[kk] = Utils.sum(tempDist[kk]);
 
1098
        }
 
1099
 
 
1100
        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
 
1101
 
 
1102
        // split missing values
 
1103
        int mstart = missingStart;
 
1104
        while (mstart < sortedIndices.length) {
 
1105
          Instance insta = data.instance(sortedIndices[mstart]);
 
1106
          for (int jj = 0; jj < 2; jj++) {
 
1107
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
 
1108
          }
 
1109
          mstart++;
 
1110
        }
 
1111
 
 
1112
        double currGiniGain = computeGiniGain(parentDist,tempDist);
 
1113
 
 
1114
        if (currGiniGain>bestGiniGain) {
 
1115
          bestGiniGain = currGiniGain;
 
1116
          bestSplitString = tempStr;
 
1117
          for (int jj = 0; jj < 2; jj++) {
 
1118
            //dist[jj] = new double[currDist[jj].length];
 
1119
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
 
1120
                dist[jj].length);
 
1121
          }
 
1122
        }
 
1123
      }
 
1124
    }
 
1125
 
 
1126
    // Compute weights
 
1127
    int attIndex = att.index();        
 
1128
    props[attIndex] = new double[2];
 
1129
    for (int k = 0; k < 2; k++) {
 
1130
      props[attIndex][k] = Utils.sum(dist[k]);
 
1131
    }
 
1132
 
 
1133
    if (!(Utils.sum(props[attIndex]) > 0)) {
 
1134
      for (int k = 0; k < props[attIndex].length; k++) {
 
1135
        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
 
1136
      }
 
1137
    } else {
 
1138
      Utils.normalize(props[attIndex]);
 
1139
    }
 
1140
 
 
1141
 
 
1142
    // Compute subset weights
 
1143
    subsetWeights[attIndex] = new double[2];
 
1144
    for (int j = 0; j < 2; j++) {
 
1145
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
 
1146
    }
 
1147
 
 
1148
    // Then, for the attribute values that class frequency is 0, split it into the
 
1149
    // most frequent branch
 
1150
    for (int j=0; j<empty; j++) {
 
1151
      if (props[attIndex][0]>=props[attIndex][1]) {
 
1152
        if (bestSplitString=="") bestSplitString = "(" + emptyValues[j] + ")";
 
1153
        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
 
1154
      }
 
1155
    }
 
1156
 
 
1157
    // clean Gini gain for the attribute
 
1158
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
 
1159
 
 
1160
    dists[attIndex] = dist;
 
1161
    return bestSplitString;
 
1162
  }
 
1163
 
 
1164
 
 
1165
  /**
 
1166
   * Split data into two subsets and store sorted indices and weights for two
 
1167
   * successor nodes.
 
1168
   * 
 
1169
   * @param subsetIndices       sorted indecis of instances for each attribute 
 
1170
   *                            for two successor node
 
1171
   * @param subsetWeights       weights of instances for each attribute for 
 
1172
   *                            two successor node
 
1173
   * @param att                 attribute the split based on
 
1174
   * @param splitPoint          split point the split based on if att is numeric
 
1175
   * @param splitStr            split subset the split based on if att is nominal
 
1176
   * @param sortedIndices       sorted indices of the instances to be split
 
1177
   * @param weights             weights of the instances to bes split
 
1178
   * @param data                training data
 
1179
   * @throws Exception          if something goes wrong  
 
1180
   */
 
1181
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
 
1182
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
 
1183
      double[][] weights, Instances data) throws Exception {
 
1184
 
 
1185
    int j;
 
1186
    // For each attribute
 
1187
    for (int i = 0; i < data.numAttributes(); i++) {
 
1188
      if (i==data.classIndex()) continue;
 
1189
      int[] num = new int[2];
 
1190
      for (int k = 0; k < 2; k++) {
 
1191
        subsetIndices[k][i] = new int[sortedIndices[i].length];
 
1192
        subsetWeights[k][i] = new double[weights[i].length];
 
1193
      }
 
1194
 
 
1195
      for (j = 0; j < sortedIndices[i].length; j++) {
 
1196
        Instance inst = data.instance(sortedIndices[i][j]);
 
1197
        if (inst.isMissing(att)) {
 
1198
          // Split instance up
 
1199
          for (int k = 0; k < 2; k++) {
 
1200
            if (m_Props[k] > 0) {
 
1201
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
 
1202
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
 
1203
              num[k]++;
 
1204
            }
 
1205
          }
 
1206
        } else {
 
1207
          int subset;
 
1208
          if (att.isNumeric())  {
 
1209
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
 
1210
          } else { // nominal attribute
 
1211
            if (splitStr.indexOf
 
1212
                ("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
 
1213
              subset = 0;
 
1214
            } else subset = 1;
 
1215
          }
 
1216
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
 
1217
          subsetWeights[subset][i][num[subset]] = weights[i][j];
 
1218
          num[subset]++;
 
1219
        }
 
1220
      }
 
1221
 
 
1222
      // Trim arrays
 
1223
      for (int k = 0; k < 2; k++) {
 
1224
        int[] copy = new int[num[k]];
 
1225
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
 
1226
        subsetIndices[k][i] = copy;
 
1227
        double[] copyWeights = new double[num[k]];
 
1228
        System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
 
1229
        subsetWeights[k][i] = copyWeights;
 
1230
      }
 
1231
    }
 
1232
  }
 
1233
 
 
1234
  /**
 
1235
   * Updates the numIncorrectModel field for all nodes when subtree (to be 
 
1236
   * pruned) is rooted. This is needed for calculating the alpha-values.
 
1237
   * 
 
1238
   * @throws Exception  if something goes wrong
 
1239
   */
 
1240
  public void modelErrors() throws Exception{
 
1241
    Evaluation eval = new Evaluation(m_train);
 
1242
 
 
1243
    if (!m_isLeaf) {
 
1244
      m_isLeaf = true; //temporarily make leaf
 
1245
 
 
1246
      // calculate distribution for evaluation
 
1247
      eval.evaluateModel(this, m_train);
 
1248
      m_numIncorrectModel = eval.incorrect();
 
1249
 
 
1250
      m_isLeaf = false;
 
1251
 
 
1252
      for (int i = 0; i < m_Successors.length; i++)
 
1253
        m_Successors[i].modelErrors();
 
1254
 
 
1255
    } else {
 
1256
      eval.evaluateModel(this, m_train);
 
1257
      m_numIncorrectModel = eval.incorrect();
 
1258
    }       
 
1259
  }
 
1260
 
 
1261
  /**
 
1262
   * Updates the numIncorrectTree field for all nodes. This is needed for
 
1263
   * calculating the alpha-values.
 
1264
   * 
 
1265
   * @throws Exception  if something goes wrong
 
1266
   */
 
1267
  public void treeErrors() throws Exception {
 
1268
    if (m_isLeaf) {
 
1269
      m_numIncorrectTree = m_numIncorrectModel;
 
1270
    } else {
 
1271
      m_numIncorrectTree = 0;
 
1272
      for (int i = 0; i < m_Successors.length; i++) {
 
1273
        m_Successors[i].treeErrors();
 
1274
        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
 
1275
      }
 
1276
    }
 
1277
  }
 
1278
 
 
1279
  /**
 
1280
   * Updates the alpha field for all nodes.
 
1281
   * 
 
1282
   * @throws Exception  if something goes wrong
 
1283
   */
 
1284
  public void calculateAlphas() throws Exception {
 
1285
 
 
1286
    if (!m_isLeaf) {
 
1287
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
 
1288
      if (errorDiff <=0) {
 
1289
        //split increases training error (should not normally happen).
 
1290
        //prune it instantly.
 
1291
        makeLeaf(m_train);
 
1292
        m_Alpha = Double.MAX_VALUE;
 
1293
      } else {
 
1294
        //compute alpha
 
1295
        errorDiff /= m_totalTrainInstances;
 
1296
        m_Alpha = errorDiff / (double)(numLeaves() - 1);
 
1297
        long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
 
1298
        m_Alpha = (double)alphaLong/Math.pow(10,10);
 
1299
        for (int i = 0; i < m_Successors.length; i++) {
 
1300
          m_Successors[i].calculateAlphas();
 
1301
        }
 
1302
      }
 
1303
    } else {
 
1304
      //alpha = infinite for leaves (do not want to prune)
 
1305
      m_Alpha = Double.MAX_VALUE;
 
1306
    }
 
1307
  }
 
1308
 
 
1309
  /**
 
1310
   * Find the node with minimal alpha value. If two nodes have the same alpha, 
 
1311
   * choose the one with more leave nodes.
 
1312
   * 
 
1313
   * @param nodeList    list of inner nodes
 
1314
   * @return            the node to be pruned
 
1315
   */
 
1316
  protected SimpleCart nodeToPrune(Vector nodeList) {
 
1317
    if (nodeList.size()==0) return null;
 
1318
    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
 
1319
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
 
1320
    double baseAlpha = returnNode.m_Alpha;
 
1321
    for (int i=1; i<nodeList.size(); i++) {
 
1322
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
 
1323
      if (node.m_Alpha < baseAlpha) {
 
1324
        baseAlpha = node.m_Alpha;
 
1325
        returnNode = node;
 
1326
      } else if (node.m_Alpha == baseAlpha) { // break tie
 
1327
        if (node.numLeaves()>returnNode.numLeaves()) {
 
1328
          returnNode = node;
 
1329
        }
 
1330
      }
 
1331
    }
 
1332
    return returnNode;
 
1333
  }
 
1334
 
 
1335
  /**
 
1336
   * Compute sorted indices, weights and class probabilities for a given 
 
1337
   * dataset. Return total weights of the data at the node.
 
1338
   * 
 
1339
   * @param data                training data
 
1340
   * @param sortedIndices       sorted indices of instances at the node
 
1341
   * @param weights             weights of instances at the node
 
1342
   * @param classProbs          class probabilities at the node
 
1343
   * @return total              weights of instances at the node
 
1344
   * @throws Exception          if something goes wrong
 
1345
   */
 
1346
  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
 
1347
      double[] classProbs) throws Exception {
 
1348
 
 
1349
    // Create array of sorted indices and weights
 
1350
    double[] vals = new double[data.numInstances()];
 
1351
    for (int j = 0; j < data.numAttributes(); j++) {
 
1352
      if (j==data.classIndex()) continue;
 
1353
      weights[j] = new double[data.numInstances()];
 
1354
 
 
1355
      if (data.attribute(j).isNominal()) {
 
1356
 
 
1357
        // Handling nominal attributes. Putting indices of
 
1358
        // instances with missing values at the end.
 
1359
        sortedIndices[j] = new int[data.numInstances()];
 
1360
        int count = 0;
 
1361
        for (int i = 0; i < data.numInstances(); i++) {
 
1362
          Instance inst = data.instance(i);
 
1363
          if (!inst.isMissing(j)) {
 
1364
            sortedIndices[j][count] = i;
 
1365
            weights[j][count] = inst.weight();
 
1366
            count++;
 
1367
          }
 
1368
        }
 
1369
        for (int i = 0; i < data.numInstances(); i++) {
 
1370
          Instance inst = data.instance(i);
 
1371
          if (inst.isMissing(j)) {
 
1372
            sortedIndices[j][count] = i;
 
1373
            weights[j][count] = inst.weight();
 
1374
            count++;
 
1375
          }
 
1376
        }
 
1377
      } else {
 
1378
 
 
1379
        // Sorted indices are computed for numeric attributes
 
1380
        // missing values instances are put to end 
 
1381
        for (int i = 0; i < data.numInstances(); i++) {
 
1382
          Instance inst = data.instance(i);
 
1383
          vals[i] = inst.value(j);
 
1384
        }
 
1385
        sortedIndices[j] = Utils.sort(vals);
 
1386
        for (int i = 0; i < data.numInstances(); i++) {
 
1387
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
 
1388
        }
 
1389
      }
 
1390
    }
 
1391
 
 
1392
    // Compute initial class counts
 
1393
    double totalWeight = 0;
 
1394
    for (int i = 0; i < data.numInstances(); i++) {
 
1395
      Instance inst = data.instance(i);
 
1396
      classProbs[(int)inst.classValue()] += inst.weight();
 
1397
      totalWeight += inst.weight();
 
1398
    }
 
1399
 
 
1400
    return totalWeight;
 
1401
  }
 
1402
 
 
1403
  /**
 
1404
   * Compute and return gini gain for given distributions of a node and its 
 
1405
   * successor nodes.
 
1406
   * 
 
1407
   * @param parentDist  class distributions of parent node
 
1408
   * @param childDist   class distributions of successor nodes
 
1409
   * @return            Gini gain computed
 
1410
   */
 
1411
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
 
1412
    double totalWeight = Utils.sum(parentDist);
 
1413
    if (totalWeight==0) return 0;
 
1414
 
 
1415
    double leftWeight = Utils.sum(childDist[0]);
 
1416
    double rightWeight = Utils.sum(childDist[1]);
 
1417
 
 
1418
    double parentGini = computeGini(parentDist, totalWeight);
 
1419
    double leftGini = computeGini(childDist[0],leftWeight);
 
1420
    double rightGini = computeGini(childDist[1], rightWeight);
 
1421
 
 
1422
    return parentGini - leftWeight/totalWeight*leftGini -
 
1423
    rightWeight/totalWeight*rightGini;
 
1424
  }
 
1425
 
 
1426
  /**
 
1427
   * Compute and return gini index for a given distribution of a node.
 
1428
   * 
 
1429
   * @param dist        class distributions
 
1430
   * @param total       class distributions
 
1431
   * @return            Gini index of the class distributions
 
1432
   */
 
1433
  protected double computeGini(double[] dist, double total) {
 
1434
    if (total==0) return 0;
 
1435
    double val = 0;
 
1436
    for (int i=0; i<dist.length; i++) {
 
1437
      val += (dist[i]/total)*(dist[i]/total);
 
1438
    }
 
1439
    return 1- val;
 
1440
  }
 
1441
 
 
1442
  /**
 
1443
   * Computes class probabilities for instance using the decision tree.
 
1444
   * 
 
1445
   * @param instance    the instance for which class probabilities is to be computed
 
1446
   * @return            the class probabilities for the given instance
 
1447
   * @throws Exception  if something goes wrong
 
1448
   */
 
1449
  public double[] distributionForInstance(Instance instance)
 
1450
  throws Exception {
 
1451
    if (!m_isLeaf) {
 
1452
      // value of split attribute is missing
 
1453
      if (instance.isMissing(m_Attribute)) {
 
1454
        double[] returnedDist = new double[m_ClassProbs.length];
 
1455
 
 
1456
        for (int i = 0; i < m_Successors.length; i++) {
 
1457
          double[] help =
 
1458
            m_Successors[i].distributionForInstance(instance);
 
1459
          if (help != null) {
 
1460
            for (int j = 0; j < help.length; j++) {
 
1461
              returnedDist[j] += m_Props[i] * help[j];
 
1462
            }
 
1463
          }
 
1464
        }
 
1465
        return returnedDist;
 
1466
      }
 
1467
 
 
1468
      // split attribute is nonimal
 
1469
      else if (m_Attribute.isNominal()) {
 
1470
        if (m_SplitString.indexOf("(" +
 
1471
            m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
 
1472
          return  m_Successors[0].distributionForInstance(instance);
 
1473
        else return  m_Successors[1].distributionForInstance(instance);
 
1474
      }
 
1475
 
 
1476
      // split attribute is numeric
 
1477
      else {
 
1478
        if (instance.value(m_Attribute) < m_SplitValue)
 
1479
          return m_Successors[0].distributionForInstance(instance);
 
1480
        else
 
1481
          return m_Successors[1].distributionForInstance(instance);
 
1482
      }
 
1483
    }
 
1484
 
 
1485
    // leaf node
 
1486
    else return m_ClassProbs;
 
1487
  }
 
1488
 
 
1489
  /**
 
1490
   * Make the node leaf node.
 
1491
   * 
 
1492
   * @param data        trainging data
 
1493
   */
 
1494
  protected void makeLeaf(Instances data) {
 
1495
    m_Attribute = null;
 
1496
    m_isLeaf = true;
 
1497
    m_ClassValue=Utils.maxIndex(m_ClassProbs);
 
1498
    m_ClassAttribute = data.classAttribute();
 
1499
  }
 
1500
 
 
1501
  /**
 
1502
   * Prints the decision tree using the protected toString method from below.
 
1503
   * 
 
1504
   * @return            a textual description of the classifier
 
1505
   */
 
1506
  public String toString() {
 
1507
    if ((m_ClassProbs == null) && (m_Successors == null)) {
 
1508
      return "CART Tree: No model built yet.";
 
1509
    }
 
1510
 
 
1511
    return "CART Decision Tree\n" + toString(0)+"\n\n"
 
1512
    +"Number of Leaf Nodes: "+numLeaves()+"\n\n" +
 
1513
    "Size of the Tree: "+numNodes();
 
1514
  }
 
1515
 
 
1516
  /**
 
1517
   * Outputs a tree at a certain level.
 
1518
   * 
 
1519
   * @param level       the level at which the tree is to be printed
 
1520
   * @return            a tree at a certain level
 
1521
   */
 
1522
  protected String toString(int level) {
 
1523
 
 
1524
    StringBuffer text = new StringBuffer();
 
1525
    // if leaf nodes
 
1526
    if (m_Attribute == null) {
 
1527
      if (Instance.isMissingValue(m_ClassValue)) {
 
1528
        text.append(": null");
 
1529
      } else {
 
1530
        double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/
 
1531
        100.0;
 
1532
        double wrongNum = (int)((Utils.sum(m_Distribution) -
 
1533
            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
 
1534
        String str = "("  + correctNum + "/" + wrongNum + ")";
 
1535
        text.append(": " + m_ClassAttribute.value((int) m_ClassValue)+ str);
 
1536
      }
 
1537
    } else {
 
1538
      for (int j = 0; j < 2; j++) {
 
1539
        text.append("\n");
 
1540
        for (int i = 0; i < level; i++) {
 
1541
          text.append("|  ");
 
1542
        }
 
1543
        if (j==0) {
 
1544
          if (m_Attribute.isNumeric())
 
1545
            text.append(m_Attribute.name() + " < " + m_SplitValue);
 
1546
          else
 
1547
            text.append(m_Attribute.name() + "=" + m_SplitString);
 
1548
        } else {
 
1549
          if (m_Attribute.isNumeric())
 
1550
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
 
1551
          else
 
1552
            text.append(m_Attribute.name() + "!=" + m_SplitString);
 
1553
        }
 
1554
        text.append(m_Successors[j].toString(level + 1));
 
1555
      }
 
1556
    }
 
1557
    return text.toString();
 
1558
  }
 
1559
 
 
1560
  /**
 
1561
   * Compute size of the tree.
 
1562
   * 
 
1563
   * @return            size of the tree
 
1564
   */
 
1565
  public int numNodes() {
 
1566
    if (m_isLeaf) {
 
1567
      return 1;
 
1568
    } else {
 
1569
      int size =1;
 
1570
      for (int i=0;i<m_Successors.length;i++) {
 
1571
        size+=m_Successors[i].numNodes();
 
1572
      }
 
1573
      return size;
 
1574
    }
 
1575
  }
 
1576
 
 
1577
  /**
 
1578
   * Method to count the number of inner nodes in the tree.
 
1579
   * 
 
1580
   * @return            the number of inner nodes
 
1581
   */
 
1582
  public int numInnerNodes(){
 
1583
    if (m_Attribute==null) return 0;
 
1584
    int numNodes = 1;
 
1585
    for (int i = 0; i < m_Successors.length; i++)
 
1586
      numNodes += m_Successors[i].numInnerNodes();
 
1587
    return numNodes;
 
1588
  }
 
1589
 
 
1590
  /**
 
1591
   * Return a list of all inner nodes in the tree.
 
1592
   * 
 
1593
   * @return            the list of all inner nodes
 
1594
   */
 
1595
  protected Vector getInnerNodes(){
 
1596
    Vector nodeList = new Vector();
 
1597
    fillInnerNodes(nodeList);
 
1598
    return nodeList;
 
1599
  }
 
1600
 
 
1601
  /**
 
1602
   * Fills a list with all inner nodes in the tree.
 
1603
   * 
 
1604
   * @param nodeList    the list to be filled
 
1605
   */
 
1606
  protected void fillInnerNodes(Vector nodeList) {
 
1607
    if (!m_isLeaf) {
 
1608
      nodeList.add(this);
 
1609
      for (int i = 0; i < m_Successors.length; i++)
 
1610
        m_Successors[i].fillInnerNodes(nodeList);
 
1611
    }
 
1612
  }
 
1613
 
 
1614
  /**
 
1615
   * Compute number of leaf nodes.
 
1616
   * 
 
1617
   * @return            number of leaf nodes
 
1618
   */
 
1619
  public int numLeaves() {
 
1620
    if (m_isLeaf) return 1;
 
1621
    else {
 
1622
      int size=0;
 
1623
      for (int i=0;i<m_Successors.length;i++) {
 
1624
        size+=m_Successors[i].numLeaves();
 
1625
      }
 
1626
      return size;
 
1627
    }
 
1628
  }
 
1629
 
 
1630
  /**
 
1631
   * Returns an enumeration describing the available options.
 
1632
   *
 
1633
   * @return            an enumeration of all the available options.
 
1634
   */
 
1635
  public Enumeration listOptions() {
 
1636
    Vector      result;
 
1637
    Enumeration en;
 
1638
    
 
1639
    result = new Vector();
 
1640
    
 
1641
    en = super.listOptions();
 
1642
    while (en.hasMoreElements())
 
1643
      result.addElement(en.nextElement());
 
1644
 
 
1645
    result.addElement(new Option(
 
1646
        "\tThe minimal number of instances at the terminal nodes.\n" 
 
1647
        + "\t(default 2)",
 
1648
        "M", 1, "-M <min no>"));
 
1649
    
 
1650
    result.addElement(new Option(
 
1651
        "\tThe number of folds used in the minimal cost-complexity pruning.\n"
 
1652
        + "\t(default 5)",
 
1653
        "N", 1, "-N <num folds>"));
 
1654
    
 
1655
    result.addElement(new Option(
 
1656
        "\tDon't use the minimal cost-complexity pruning.\n"
 
1657
        + "\t(default yes).",
 
1658
        "U", 0, "-U"));
 
1659
    
 
1660
    result.addElement(new Option(
 
1661
        "\tDon't use the heuristic method for binary split.\n"
 
1662
        + "\t(default true).",
 
1663
        "H", 0, "-H"));
 
1664
    
 
1665
    result.addElement(new Option(
 
1666
        "\tUse 1 SE rule to make pruning decision.\n"
 
1667
        + "\t(default no).",
 
1668
        "A", 0, "-A"));
 
1669
    
 
1670
    result.addElement(new Option(
 
1671
        "\tPercentage of training data size (0-1].\n" 
 
1672
        + "\t(default 1).",
 
1673
        "C", 1, "-C"));
 
1674
 
 
1675
    return result.elements();
 
1676
  }
 
1677
 
 
1678
  /**
 
1679
   * Parses a given list of options. <p/>
 
1680
   * 
 
1681
   <!-- options-start -->
 
1682
   * Valid options are: <p/>
 
1683
   * 
 
1684
   * <pre> -S &lt;num&gt;
 
1685
   *  Random number seed.
 
1686
   *  (default 1)</pre>
 
1687
   * 
 
1688
   * <pre> -D
 
1689
   *  If set, classifier is run in debug mode and
 
1690
   *  may output additional info to the console</pre>
 
1691
   * 
 
1692
   * <pre> -M &lt;min no&gt;
 
1693
   *  The minimal number of instances at the terminal nodes.
 
1694
   *  (default 2)</pre>
 
1695
   * 
 
1696
   * <pre> -N &lt;num folds&gt;
 
1697
   *  The number of folds used in the minimal cost-complexity pruning.
 
1698
   *  (default 5)</pre>
 
1699
   * 
 
1700
   * <pre> -U
 
1701
   *  Don't use the minimal cost-complexity pruning.
 
1702
   *  (default yes).</pre>
 
1703
   * 
 
1704
   * <pre> -H
 
1705
   *  Don't use the heuristic method for binary split.
 
1706
   *  (default true).</pre>
 
1707
   * 
 
1708
   * <pre> -A
 
1709
   *  Use 1 SE rule to make pruning decision.
 
1710
   *  (default no).</pre>
 
1711
   * 
 
1712
   * <pre> -C
 
1713
   *  Percentage of training data size (0-1].
 
1714
   *  (default 1).</pre>
 
1715
   * 
 
1716
   <!-- options-end -->
 
1717
   * 
 
1718
   * @param options the list of options as an array of strings
 
1719
   * @throws Exception if an options is not supported
 
1720
   */
 
1721
  public void setOptions(String[] options) throws Exception {
 
1722
    String      tmpStr;
 
1723
    
 
1724
    super.setOptions(options);
 
1725
    
 
1726
    tmpStr = Utils.getOption('M', options);
 
1727
    if (tmpStr.length() != 0)
 
1728
      setMinNumObj(Double.parseDouble(tmpStr));
 
1729
    else
 
1730
      setMinNumObj(2);
 
1731
 
 
1732
    tmpStr = Utils.getOption('N', options);
 
1733
    if (tmpStr.length()!=0)
 
1734
      setNumFoldsPruning(Integer.parseInt(tmpStr));
 
1735
    else
 
1736
      setNumFoldsPruning(5);
 
1737
 
 
1738
    setUsePrune(!Utils.getFlag('U',options));
 
1739
    setHeuristic(!Utils.getFlag('H',options));
 
1740
    setUseOneSE(Utils.getFlag('A',options));
 
1741
 
 
1742
    tmpStr = Utils.getOption('C', options);
 
1743
    if (tmpStr.length()!=0)
 
1744
      setSizePer(Double.parseDouble(tmpStr));
 
1745
    else
 
1746
      setSizePer(1);
 
1747
 
 
1748
    Utils.checkForRemainingOptions(options);
 
1749
  }
 
1750
 
 
1751
  /**
 
1752
   * Gets the current settings of the classifier.
 
1753
   * 
 
1754
   * @return            the current setting of the classifier
 
1755
   */
 
1756
  public String[] getOptions() {
 
1757
    int         i;
 
1758
    Vector      result;
 
1759
    String[]    options;
 
1760
 
 
1761
    result = new Vector();
 
1762
 
 
1763
    options = super.getOptions();
 
1764
    for (i = 0; i < options.length; i++)
 
1765
      result.add(options[i]);
 
1766
 
 
1767
    result.add("-M");
 
1768
    result.add("" + getMinNumObj());
 
1769
    
 
1770
    result.add("-N");
 
1771
    result.add("" + getNumFoldsPruning());
 
1772
    
 
1773
    if (!getUsePrune())
 
1774
      result.add("-U");
 
1775
    
 
1776
    if (!getHeuristic())
 
1777
      result.add("-H");
 
1778
    
 
1779
    if (getUseOneSE())
 
1780
      result.add("-A");
 
1781
    
 
1782
    result.add("-C");
 
1783
    result.add("" + getSizePer());
 
1784
 
 
1785
    return (String[]) result.toArray(new String[result.size()]);          
 
1786
  }
 
1787
 
 
1788
  /**
 
1789
   * Return an enumeration of the measure names.
 
1790
   * 
 
1791
   * @return            an enumeration of the measure names
 
1792
   */
 
1793
  public Enumeration enumerateMeasures() {
 
1794
    Vector result = new Vector();
 
1795
    
 
1796
    result.addElement("measureTreeSize");
 
1797
    
 
1798
    return result.elements();
 
1799
  }
 
1800
 
 
1801
  /**
 
1802
   * Return number of tree size.
 
1803
   * 
 
1804
   * @return            number of tree size
 
1805
   */
 
1806
  public double measureTreeSize() {
 
1807
    return numNodes();
 
1808
  }
 
1809
 
 
1810
  /**
 
1811
   * Returns the value of the named measure.
 
1812
   * 
 
1813
   * @param additionalMeasureName       the name of the measure to query for its value
 
1814
   * @return                            the value of the named measure
 
1815
   * @throws IllegalArgumentException   if the named measure is not supported
 
1816
   */
 
1817
  public double getMeasure(String additionalMeasureName) {
 
1818
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
 
1819
      return measureTreeSize();
 
1820
    } else {
 
1821
      throw new IllegalArgumentException(additionalMeasureName
 
1822
          + " not supported (Cart pruning)");
 
1823
    }
 
1824
  }
 
1825
 
 
1826
  /**
 
1827
   * Returns the tip text for this property
 
1828
   * 
 
1829
   * @return            tip text for this property suitable for
 
1830
   *                    displaying in the explorer/experimenter gui
 
1831
   */
 
1832
  public String minNumObjTipText() {
 
1833
    return "The minimal number of observations at the terminal nodes (default 2).";
 
1834
  }
 
1835
 
 
1836
  /**
 
1837
   * Set minimal number of instances at the terminal nodes.
 
1838
   * 
 
1839
   * @param value       minimal number of instances at the terminal nodes
 
1840
   */
 
1841
  public void setMinNumObj(double value) {
 
1842
    m_minNumObj = value;
 
1843
  }
 
1844
 
 
1845
  /**
 
1846
   * Get minimal number of instances at the terminal nodes.
 
1847
   * 
 
1848
   * @return            minimal number of instances at the terminal nodes
 
1849
   */
 
1850
  public double getMinNumObj() {
 
1851
    return m_minNumObj;
 
1852
  }
 
1853
 
 
1854
  /**
 
1855
   * Returns the tip text for this property
 
1856
   * 
 
1857
   * @return            tip text for this property suitable for
 
1858
   *                    displaying in the explorer/experimenter gui
 
1859
   */
 
1860
  public String numFoldsPruningTipText() {
 
1861
    return "The number of folds in the internal cross-validation (default 5).";
 
1862
  }
 
1863
 
 
1864
  /** 
 
1865
   * Set number of folds in internal cross-validation.
 
1866
   * 
 
1867
   * @param value       number of folds in internal cross-validation.
 
1868
   */
 
1869
  public void setNumFoldsPruning(int value) {
 
1870
    m_numFoldsPruning = value;
 
1871
  }
 
1872
 
 
1873
  /**
 
1874
   * Set number of folds in internal cross-validation.
 
1875
   * 
 
1876
   * @return            number of folds in internal cross-validation.
 
1877
   */
 
1878
  public int getNumFoldsPruning() {
 
1879
    return m_numFoldsPruning;
 
1880
  }
 
1881
 
 
1882
  /**
 
1883
   * Return the tip text for this property
 
1884
   * 
 
1885
   * @return            tip text for this property suitable for displaying in 
 
1886
   *                    the explorer/experimenter gui.
 
1887
   */
 
1888
  public String usePruneTipText() {
 
1889
    return "Use minimal cost-complexity pruning (default yes).";
 
1890
  }
 
1891
 
 
1892
  /** 
 
1893
   * Set if use minimal cost-complexity pruning.
 
1894
   * 
 
1895
   * @param value       if use minimal cost-complexity pruning
 
1896
   */
 
1897
  public void setUsePrune(boolean value) {
 
1898
    m_Prune = value;
 
1899
  }
 
1900
 
 
1901
  /** 
 
1902
   * Get if use minimal cost-complexity pruning.
 
1903
   * 
 
1904
   * @return            if use minimal cost-complexity pruning
 
1905
   */
 
1906
  public boolean getUsePrune() {
 
1907
    return m_Prune;
 
1908
  }
 
1909
 
 
1910
  /**
 
1911
   * Returns the tip text for this property
 
1912
   * 
 
1913
   * @return            tip text for this property suitable for
 
1914
   *                    displaying in the explorer/experimenter gui.
 
1915
   */
 
1916
  public String heuristicTipText() {
 
1917
    return 
 
1918
        "If heuristic search is used for binary split for nominal attributes "
 
1919
      + "in multi-class problems (default yes).";
 
1920
  }
 
1921
 
 
1922
  /**
 
1923
   * Set if use heuristic search for nominal attributes in multi-class problems.
 
1924
   * 
 
1925
   * @param value       if use heuristic search for nominal attributes in 
 
1926
   *                    multi-class problems
 
1927
   */
 
1928
  public void setHeuristic(boolean value) {
 
1929
    m_Heuristic = value;
 
1930
  }
 
1931
 
 
1932
  /** 
 
1933
   * Get if use heuristic search for nominal attributes in multi-class problems.
 
1934
   * 
 
1935
   * @return            if use heuristic search for nominal attributes in 
 
1936
   *                    multi-class problems
 
1937
   */
 
1938
  public boolean getHeuristic() {return m_Heuristic;}
 
1939
 
 
1940
  /**
 
1941
   * Returns the tip text for this property
 
1942
   * 
 
1943
   * @return            tip text for this property suitable for
 
1944
   *                    displaying in the explorer/experimenter gui.
 
1945
   */
 
1946
  public String useOneSETipText() {
 
1947
    return "Use the 1SE rule to make pruning decisoin.";
 
1948
  }
 
1949
 
 
1950
  /** 
 
1951
   * Set if use the 1SE rule to choose final model.
 
1952
   * 
 
1953
   * @param value       if use the 1SE rule to choose final model
 
1954
   */
 
1955
  public void setUseOneSE(boolean value) {
 
1956
    m_UseOneSE = value;
 
1957
  }
 
1958
 
 
1959
  /**
 
1960
   * Get if use the 1SE rule to choose final model.
 
1961
   * 
 
1962
   * @return            if use the 1SE rule to choose final model
 
1963
   */
 
1964
  public boolean getUseOneSE() {
 
1965
    return m_UseOneSE;
 
1966
  }
 
1967
 
 
1968
  /**
 
1969
   * Returns the tip text for this property
 
1970
   * 
 
1971
   * @return            tip text for this property suitable for
 
1972
   *                    displaying in the explorer/experimenter gui.
 
1973
   */
 
1974
  public String sizePerTipText() {
 
1975
    return "The percentage of the training set size (0-1, 0 not included).";
 
1976
  }
 
1977
 
 
1978
  /** 
 
1979
   * Set training set size.
 
1980
   * 
 
1981
   * @param value       training set size
 
1982
   */  
 
1983
  public void setSizePer(double value) {
 
1984
    if ((value <= 0) || (value > 1))
 
1985
      System.err.println(
 
1986
          "The percentage of the training set size must be in range 0 to 1 "
 
1987
          + "(0 not included) - ignored!");
 
1988
    else
 
1989
      m_SizePer = value;
 
1990
  }
 
1991
 
 
1992
  /**
 
1993
   * Get training set size.
 
1994
   * 
 
1995
   * @return            training set size
 
1996
   */
 
1997
  public double getSizePer() {
 
1998
    return m_SizePer;
 
1999
  }
 
2000
 
 
2001
  /**
 
2002
   * Main method.
 
2003
   * @param args the options for the classifier
 
2004
   */
 
2005
  public static void main(String[] args) {
 
2006
    runClassifier(new SimpleCart(), args);
 
2007
  }
 
2008
}