/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 Haijian Shi
 */
package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
47
<!-- globalinfo-start -->
48
* Class implementing minimal cost-complexity pruning.<br/>
49
* Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.<br/>
51
* For more information, see:<br/>
53
* Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
55
<!-- globalinfo-end -->
57
<!-- technical-bibtex-start -->
60
* @book{Breiman1984,
61
* address = {Belmont, California},
62
* author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
63
* publisher = {Wadsworth International Group},
64
* title = {Classification and Regression Trees},
69
<!-- technical-bibtex-end -->
71
<!-- options-start -->
72
* Valid options are: <p/>
74
* <pre> -S <num>
79
* If set, classifier is run in debug mode and
80
* may output additional info to the console</pre>
82
* <pre> -M <min no>
83
* The minimal number of instances at the terminal nodes.
86
* <pre> -N <num folds>
87
* The number of folds used in the minimal cost-complexity pruning.
91
* Don't use the minimal cost-complexity pruning.
92
* (default yes).</pre>
95
* Don't use the heuristic method for binary split.
96
* (default true).</pre>
99
* Use 1 SE rule to make pruning decision.
100
* (default no).</pre>
103
* Percentage of training data size (0-1].
108
* @author Haijian Shi (hs69@cs.waikato.ac.nz)
109
* @version $Revision: 1.3 $
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;
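
  /*
   * A minimal programmatic-usage sketch (the ARFF file name is hypothetical;
   * everything else is the standard Weka API):
   *
   *   Instances data = new Instances(new java.io.BufferedReader(
   *       new java.io.FileReader("train.arff")));
   *   data.setClassIndex(data.numAttributes() - 1);
   *   SimpleCart cart = new SimpleCart();
   *   cart.setNumFoldsPruning(5);   // folds for cost-complexity pruning
   *   cart.buildClassifier(data);
   *   System.out.println(cart);     // prints the pruned tree
   */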

  /** Training data. */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is a leaf. */
  protected double m_ClassValue;

  /** Class attribute of the data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified by the model (when the subtree rooted at this node is replaced by a leaf). */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the subtree rooted at this node. */
  protected double m_numIncorrectTree;

  /** Indicates whether the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distribution of the leaf node (or temporary leaf node in minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1 SE rule to select the final tree. */
  protected boolean m_UseOneSE = false;

  /** Training data size (proportion of the full training set, 0-1]. */
  protected double m_SizePer = 1;

  /**
   * Return a description suitable for displaying in the explorer/experimenter.
   *
   * @return a description suitable for displaying in the
   * explorer/experimenter
   */
  public String globalInfo() {
    return
      "Class implementing minimal cost-complexity pruning.\n"
      + "Note when dealing with missing values, use \"fractional "
      + "instances\" method instead of surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   *
   * @param data the training instances
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);
    data = new Instances(data);
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

      makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
          totalWeight,m_minNumObj, m_Heuristic);
      return;
    }

    // minimal cost-complexity pruning
    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow a tree on the training set and estimate its error on the test set
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);

      makeTree(train, train.numInstances(),sortedIndices,weights,classProbs,
          totalWeight,m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

    // build tree using all the data
    makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
        totalWeight,m_minNumObj, m_Heuristic);

    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++){
      // compute midpoint alphas
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error/m_numFoldsPruning;
    }

    // find the subtree with the smallest cross-validated error
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    // 1 SE rule to choose expansion
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError+oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    //"unprune" final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);
  }

  /**
   * Make binary decision tree recursively.
   *
   * @param data the training instances
   * @param totalInstances total number of instances
   * @param sortedIndices sorted indices of the instances
   * @param weights weights of the instances
   * @param classProbs class probabilities
   * @param totalWeight total weight of instances
   * @param minNumObj minimal number of instances at leaf nodes
   * @param useHeuristic whether to use heuristic search for nominal attributes in multi-class problems
   * @throws Exception if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
      boolean useHeuristic) throws Exception{

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0){
      m_isLeaf = true;
      m_ClassValue = Instance.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i==data.classIndex()) continue;
      if (att.isNumeric()) {
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data);
      } else {
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i=0; i<sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance)inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // Check if node does not contain enough instances, or if it can not be split,
    // or if it is pure. If so, make a leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex]==0 ||
        props[attIndex][0]==0 || props[attIndex][1]==0) {
      makeLeaf(data);
      return;
    }

    m_Props = props[attIndex];
    int[][][] subsetIndices = new int[2][data.numAttributes()][0];
    double[][][] subsetWeights = new double[2][data.numAttributes()][0];

    if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];
    else m_SplitString = splitString[attIndex];

    splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
        m_SplitString, sortedIndices, weights, data);

    // If the split would produce a node with fewer than the minimal number of instances,
    // make the node a leaf node.
    if (subsetIndices[0][attIndex].length<minNumObj ||
        subsetIndices[1][attIndex].length<minNumObj) {
      makeLeaf(data);
      return;
    }

    // Otherwise, split the node.
    m_isLeaf = false;
    m_Successors = new SimpleCart[2];
    for (int i = 0; i < 2; i++) {
      m_Successors[i] = new SimpleCart();
      m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
          subsetWeights[i],dists[attIndex][i], totalSubsetWeights[attIndex][i],
          minNumObj, useHeuristic);
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha the cost-complexity parameter
   * @throws Exception if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    // determine training error of pruned subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // only prune if its alpha is not larger than the given alpha
      if (nodeToPrune.m_Alpha > alpha) {
        break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update tree errors and alphas
      modelErrors();
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold.
   *
   * @param alphas array to hold the generated alpha-values
   * @param errors array to hold the corresponding error estimates
   * @param test test set of that fold (to obtain error estimates)
   * @return the iteration of the pruning
   * @throws Exception if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
    throws Exception {

    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    //alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
        iteration--;
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }

      preAlpha = nodeToPrune.m_Alpha;

      //update errors/alphas
      modelErrors();
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    //set last alpha 1 to indicate end
    alphas[iteration + 1] = 1.0;

    return iteration;
  }
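
  /*
   * Note (an added summary, not part of the original code): the alpha-values
   * recorded above form a non-decreasing sequence starting at 0 (the
   * unpruned tree); the final sentinel value 1.0 is what lets the per-fold
   * lookup in buildClassifier, "while (alphas[k][l] <= alpha) l++;", stop
   * at the end of the recorded sequence.
   */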

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because the tree does not have to be fit again.
   */
  protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given numeric attribute.
   *
   * @param props proportions of the two branches for each attribute
   * @param dists class distributions of the two branches for each attribute
   * @param att numeric attribute to split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of the two branches of the split on this attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @return the best split point for the given numeric attribute
   * @throws Exception if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // index that separates instances with and without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      if (inst.value(att) > currSplit) {

        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          //tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;

          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;

          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0,
                dist[j].length);
          }
        }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    dists[attIndex] = dist;

    return splitPoint;
  }
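
  /*
   * Note (an added summary, not part of the original code): candidate
   * thresholds are midpoints between consecutive distinct attribute values,
   * rounded to 5 decimal places via Math.rint(x*100000)/100000; instances
   * with missing values are distributed across both branches in proportion
   * to the non-missing class totals (the "fractional instances" method
   * mentioned in globalInfo()).
   */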

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given nominal attribute.
   *
   * @param props proportions of the two branches for each attribute
   * @param dists class distributions of the two branches for each attribute
   * @param att nominal attribute to split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of the two branches of the split on this attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @param useHeuristic whether to use heuristic search
   * @return the best split subset (as a string) for the given nominal attribute
   * @throws Exception if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data, boolean useHeuristic)
    throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j=0; j<numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)] ++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values for which the class frequency is not 0
    int nonEmpty = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) nonEmpty ++;
    }

    // attribute values for which the class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values for which the class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]==0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    // fewer than two non-empty values: no split is possible
    if (nonEmpty<=1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses()==2) {

      // First, for attribute values whose class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j=0; j<nonEmpty; j++) {
        for (int k=0; k<2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j=0; j<nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j=0; j<nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum==0) pClass0[j]=0;
        else pClass0[j] = valDist[j][0]/distSum;
      }

      // sort categories according to the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j=0; j<nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find a subset of attribute values that maximizes the Gini decrease

      // for the attribute values whose class frequency is not 0
      String tempStr = "";
      for (int j=0; j<nonEmpty-1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr="(" + sortedValues[j] + ")";
        else tempStr += "|"+ "(" + sortedValues[j] + ")";
        for (int i=0; i<sortedIndices.length;i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf
              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk=0; kk<2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk=0; kk<2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
                dist[jj].length);
          }
        }
      }
    }

    // multi-class problems - exhaustive search
    else if (!useHeuristic || nonEmpty<=4) {

      // First, for attribute values whose class frequency is not zero
      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
        String tempStr = "";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        for (int j=nonEmpty-1; j>=0; j--) {
          mod = bit10%2; // convert from decimal to binary
          if (mod==1) {
            if (tempStr.equals("")) tempStr = "("+nonEmptyValues[j]+")";
            else tempStr += "|" + "("+nonEmptyValues[j]+")";
          }
          bit10 = bit10/2;
        }
        for (int j=0; j<sortedIndices.length;j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[j], 0, dist[j], 0,
                dist[j].length);
          }
        }
      }
    }

    // heuristic search to solve multi-class problems
    else {

      // First, for attribute values whose class frequency is not zero
      int n = nonEmpty; // number of non-empty attribute values
      int k = data.numClasses(); // number of classes of the data
      double[][] P = new double[n][k]; // class probability matrix
      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
      double[] meanClass = new double[k]; // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;

      for (int j=0; j<numInstances; j++) {
        Instance inst = (Instance)data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i=0; i<nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i=0; i<P.length; i++) {
        for (int j=0; j<P[0].length; j++) {
          if (numInstancesValue[i]==0) P[i][j]=0;
          else P[i][j]/=numInstancesValue[i];
        }
      }

      //calculate the vector of mean class probability
      for (int i=0; i<meanClass.length; i++) {
        meanClass[i]/=numInstances;
      }

      // calculate the covariance matrix
      double[][] covariance = new double[k][k];
      for (int i1=0; i1<k; i1++) {
        for (int i2=0; i2<k; i2++) {
          double element = 0;
          for (int j=0; j<n; j++) {
            element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
              *numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen =
        new weka.core.matrix.EigenvalueDecomposition(matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find index of the largest eigenvalue
      int index = 0;
      double largest = eigenValues[0];
      for (int i=1; i<eigenValues.length; i++) {
        if (eigenValues[i]>largest) {
          index = i;
          largest = eigenValues[i];
        }
      }

      // calculate the first principal component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i=0; i<FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component scores
      double[] Sa = new double[n];
      for (int i=0; i<Sa.length; i++) {
        Sa[i] = 0;
        for (int j=0; j<k; j++) {
          Sa[i] += FPC[j]*P[i][j];
        }
      }

      // sort categories according to Sa(s)
      double[] pCopy = new double[n];
      System.arraycopy(Sa,0,pCopy,0,n);
      String[] sortedValues = new String[n];

      for (int j=0; j<n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // for the attribute values whose class frequency is not 0
      String tempStr = "";

      for (int j=0; j<nonEmpty-1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr="(" + sortedValues[j] + ")";
        else tempStr += "|"+ "(" + sortedValues[j] + ")";
        for (int i=0; i<sortedIndices.length;i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf
              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk=0; kk<2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk=0; kk<2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
                dist[jj].length);
          }
        }
      }
    }

    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }

    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then, send the attribute values whose class frequency is 0 to the
    // most frequent branch
    for (int j=0; j<empty; j++) {
      if (props[attIndex][0]>=props[attIndex][1]) {
        if (bestSplitString.equals("")) bestSplitString = "(" + emptyValues[j] + ")";
        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
      }
    }

    // clean Gini gain for the attribute
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;

    dists[attIndex] = dist;
    return bestSplitString;
  }
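
  /*
   * Note (an added summary, not part of the original code): for multi-class
   * problems with many values, the exhaustive search over 2^(nonEmpty-1)
   * subsets above is replaced by Breiman's heuristic: the attribute values
   * are ordered by their score on the first principal component of the
   * weighted covariance matrix of the per-value class-probability vectors,
   * and only the nonEmpty-1 ordered splits are evaluated, exactly as in the
   * two-class case.
   */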

  /**
   * Split data into two subsets and store sorted indices and weights for the
   * two successor nodes.
   *
   * @param subsetIndices sorted indices of instances for each attribute
   * for the two successor nodes
   * @param subsetWeights weights of instances for each attribute for
   * the two successor nodes
   * @param att attribute the split is based on
   * @param splitPoint split point of the split if att is numeric
   * @param splitStr split subset of the split if att is nominal
   * @param sortedIndices sorted indices of the instances to be split
   * @param weights weights of the instances to be split
   * @param data training data
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
      double[][] weights, Instances data) throws Exception {

    int j;
    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i==data.classIndex()) continue;
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split instance up
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric()) {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute
            if (splitStr.indexOf
                ("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
              subset = 0;
            } else subset = 1;
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      for (int k = 0; k < 2; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }
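
  /*
   * Note (an added summary, not part of the original code): this is the
   * "fractional instances" treatment of missing values mentioned in
   * globalInfo(): an instance whose split value is missing is sent down
   * both branches with its weight multiplied by the branch proportions in
   * m_Props, instead of using CART's surrogate splits.
   */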

  /**
   * Updates the numIncorrectModel field for all nodes when subtree (to be
   * pruned) is rooted. This is needed for calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void modelErrors() throws Exception{
    Evaluation eval = new Evaluation(m_train);

    if (!m_isLeaf) {
      m_isLeaf = true; //temporarily make leaf

      // calculate distribution for evaluation
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();

      m_isLeaf = false;

      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].modelErrors();

    } else {
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();
    }
  }

  /**
   * Updates the numIncorrectTree field for all nodes. This is needed for
   * calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
    } else {
      m_numIncorrectTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        m_Successors[i].treeErrors();
        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
      }
    }
  }

  /**
   * Updates the alpha field for all nodes.
   *
   * @throws Exception if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (!m_isLeaf) {
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
      if (errorDiff <=0) {
        //split increases training error (should not normally happen).
        //prune it instantly.
        makeLeaf(m_train);
        m_Alpha = Double.MAX_VALUE;
      } else {
        errorDiff /= m_totalTrainInstances;
        m_Alpha = errorDiff / (double)(numLeaves() - 1);
        long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
        m_Alpha = (double)alphaLong/Math.pow(10,10);
        for (int i = 0; i < m_Successors.length; i++) {
          m_Successors[i].calculateAlphas();
        }
      }
    } else {
      //alpha = infinite for leaves (do not want to prune)
      m_Alpha = Double.MAX_VALUE;
    }
  }
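
  /*
   * Note (an added summary, with a worked example): following Breiman et
   * al., the node's alpha is alpha(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1),
   * where R(t) is the resubstitution error of the node treated as a leaf and
   * R(T_t) that of the subtree rooted at t, both as fractions of the
   * training set. E.g. a node that misclassifies 10 instances as a leaf but
   * only 4 as a 3-leaf subtree, out of 100 training instances, gets
   * alpha = ((10 - 4) / 100) / (3 - 1) = 0.03.
   */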

  /**
   * Find the node with minimal alpha value. If two nodes have the same alpha,
   * choose the one with more leaf nodes.
   *
   * @param nodeList list of inner nodes
   * @return the node to be pruned
   */
  protected SimpleCart nodeToPrune(Vector nodeList) {
    if (nodeList.size()==0) return null;
    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
    double baseAlpha = returnNode.m_Alpha;
    for (int i=1; i<nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
        baseAlpha = node.m_Alpha;
        returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break tie
        if (node.numLeaves()>returnNode.numLeaves()) {
          returnNode = node;
        }
      }
    }
    return returnNode;
  }

  /**
   * Compute sorted indices, weights and class probabilities for a given
   * dataset. Return total weights of the data at the node.
   *
   * @param data training data
   * @param sortedIndices sorted indices of instances at the node
   * @param weights weights of instances at the node
   * @param classProbs class probabilities at the node
   * @return total weights of instances at the node
   * @throws Exception if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
      double[] classProbs) throws Exception {

    // Create array of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j==data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {

        // Handling nominal attributes. Putting indices of
        // instances with missing values at the end.
        sortedIndices[j] = new int[data.numInstances()];
        int count = 0;
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
      } else {

        // Sorted indices are computed for numeric attributes
        // (missing-value instances are put at the end)
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          vals[i] = inst.value(j);
        }
        sortedIndices[j] = Utils.sort(vals);
        for (int i = 0; i < data.numInstances(); i++) {
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
        }
      }
    }

    // Compute initial class counts
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }
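
  /*
   * Note (an added summary, not part of the original code): sortedIndices[j]
   * and weights[j] are parallel arrays. Utils.sort returns the permutation
   * of instance indices that sorts vals in ascending order (with missing
   * values, per the comment above, ending up at the end), so the weights
   * must be fetched in that same permuted order.
   */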

  /**
   * Compute and return the Gini gain for the given distributions of a node
   * and its successor nodes.
   *
   * @param parentDist class distributions of the parent node
   * @param childDist class distributions of the successor nodes
   * @return the Gini gain computed
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight==0) return 0;

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0],leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight/totalWeight*leftGini -
      rightWeight/totalWeight*rightGini;
  }

  /**
   * Compute and return the Gini index for a given distribution of a node.
   *
   * @param dist class distribution
   * @param total total weight of the class distribution
   * @return Gini index of the class distribution
   */
  protected double computeGini(double[] dist, double total) {
    if (total==0) return 0;
    double val = 0;
    for (int i=0; i<dist.length; i++) {
      val += (dist[i]/total)*(dist[i]/total);
    }
    return 1.0 - val;
  }
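
  /*
   * Note (an added summary, with a small worked example): computeGini
   * returns Gini(t) = 1 - sum_i p_i^2, and computeGiniGain returns
   * Gini(parent) - w_L/w * Gini(left) - w_R/w * Gini(right). For a parent
   * distribution [6,4] split into [4,0] and [2,4]:
   * Gini(parent) = 1 - (0.6^2 + 0.4^2) = 0.48, Gini(left) = 0,
   * Gini(right) = 1 - ((1/3)^2 + (2/3)^2) = 4/9, so the gain is
   * 0.48 - 0.4*0 - 0.6*(4/9) = 0.2133... .
   */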

  /**
   * Computes class probabilities for the instance using the decision tree.
   *
   * @param instance the instance for which class probabilities are to be computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
    throws Exception {
    if (!m_isLeaf) {

      // value of split attribute is missing
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];

        for (int i = 0; i < m_Successors.length; i++) {
          double[] help =
            m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }

      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
        if (m_SplitString.indexOf("(" +
            m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
          return m_Successors[0].distributionForInstance(instance);
        else return m_Successors[1].distributionForInstance(instance);
      }

      // split attribute is numeric
      else {
        if (instance.value(m_Attribute) < m_SplitValue)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }
    }

    else return m_ClassProbs;
  }

  /**
   * Make the node a leaf node.
   *
   * @param data training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue=Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }

    return "CART Decision Tree\n" + toString(0)+"\n\n"
      +"Number of Leaf Nodes: "+numLeaves()+"\n\n" +
      "Size of the Tree: "+numNodes();
  }

  /**
   * Outputs a tree at a certain level.
   *
   * @param level the level at which the tree is to be printed
   * @return a tree at a certain level
   */
  protected String toString(int level) {

    StringBuffer text = new StringBuffer();

    if (m_Attribute == null) {
      if (Instance.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/
          100.0;
        double wrongNum = (int)((Utils.sum(m_Distribution) -
            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int) m_ClassValue)+ str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("|  ");
        }
        if (j==0) {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "=" + m_SplitString);
        } else {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "!=" + m_SplitString);
        }
        text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }

  /**
   * Compute size of the tree.
   *
   * @return size of the tree
   */
  public int numNodes() {
    if (m_isLeaf) return 1;
    int size = 1;
    for (int i=0;i<m_Successors.length;i++) {
      size+=m_Successors[i].numNodes();
    }
    return size;
  }

  /**
   * Method to count the number of inner nodes in the tree.
   *
   * @return the number of inner nodes
   */
  public int numInnerNodes(){
    if (m_Attribute==null) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_Successors.length; i++)
      numNodes += m_Successors[i].numInnerNodes();
    return numNodes;
  }

  /**
   * Return a list of all inner nodes in the tree.
   *
   * @return the list of all inner nodes
   */
  protected Vector getInnerNodes(){
    Vector nodeList = new Vector();
    fillInnerNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree.
   *
   * @param nodeList the list to be filled
   */
  protected void fillInnerNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   *
   * @return number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) return 1;
    int size = 0;
    for (int i=0;i<m_Successors.length;i++) {
      size+=m_Successors[i].numLeaves();
    }
    return size;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result;
    Enumeration en;

    result = new Vector();

    en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    result.addElement(new Option(
        "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)",
        "M", 1, "-M <min no>"));

    result.addElement(new Option(
        "\tThe number of folds used in the minimal cost-complexity pruning.\n"
        + "\t(default 5)",
        "N", 1, "-N <num folds>"));

    result.addElement(new Option(
        "\tDon't use the minimal cost-complexity pruning.\n"
        + "\t(default yes).",
        "U", 0, "-U"));

    result.addElement(new Option(
        "\tDon't use the heuristic method for binary split.\n"
        + "\t(default true).",
        "H", 0, "-H"));

    result.addElement(new Option(
        "\tUse 1 SE rule to make pruning decision.\n"
        + "\t(default no).",
        "A", 0, "-A"));

    result.addElement(new Option(
        "\tPercentage of training data size (0-1].\n"
        + "\t(default 1).",
        "C", 1, "-C <size percentage>"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -M &lt;min no&gt;
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)</pre>
   *
   * <pre> -N &lt;num folds&gt;
   *  The number of folds used in the minimal cost-complexity pruning.
   *  (default 5)</pre>
   *
   * <pre> -U
   *  Don't use the minimal cost-complexity pruning.
   *  (default yes).</pre>
   *
   * <pre> -H
   *  Don't use the heuristic method for binary split.
   *  (default true).</pre>
   *
   * <pre> -A
   *  Use 1 SE rule to make pruning decision.
   *  (default no).</pre>
   *
   * <pre> -C &lt;size percentage&gt;
   *  Percentage of training data size (0-1].
   *  (default 1).</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    super.setOptions(options);

    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0)
      setMinNumObj(Double.parseDouble(tmpStr));
    else
      setMinNumObj(2);

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length()!=0)
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    else
      setNumFoldsPruning(5);

    setUsePrune(!Utils.getFlag('U',options));
    setHeuristic(!Utils.getFlag('H',options));
    setUseOneSE(Utils.getFlag('A',options));

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length()!=0)
      setSizePer(Double.parseDouble(tmpStr));
    else
      setSizePer(1);

    Utils.checkForRemainingOptions(options);
  }
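
  /*
   * Example (hypothetical values): setOptions(new String[]{"-M", "3",
   * "-N", "10", "-A"}) grows a tree with at least 3 instances per terminal
   * node, prunes with 10-fold cost-complexity pruning and applies the
   * 1 SE rule when selecting the final tree.
   */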

  /**
   * Gets the current settings of the classifier.
   *
   * @return the current setting of the classifier
   */
  public String[] getOptions() {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    result.add("-M");
    result.add("" + getMinNumObj());

    result.add("-N");
    result.add("" + getNumFoldsPruning());

    if (!getUsePrune())
      result.add("-U");

    if (!getHeuristic())
      result.add("-H");

    if (getUseOneSE())
      result.add("-A");

    result.add("-C");
    result.add("" + getSizePer());

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Return an enumeration of the measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector result = new Vector();

    result.addElement("measureTreeSize");

    return result.elements();
  }

  /**
   * Return the tree size.
   *
   * @return the tree size
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
          + " not supported (Cart pruning)");
    }
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "The minimal number of observations at the terminal nodes (default 2).";
  }

  /**
   * Set minimal number of instances at the terminal nodes.
   *
   * @param value minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Get minimal number of instances at the terminal nodes.
   *
   * @return minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    return "The number of folds in the internal cross-validation (default 5).";
  }

  /**
   * Set number of folds in internal cross-validation.
   *
   * @param value number of folds in internal cross-validation.
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get number of folds in internal cross-validation.
   *
   * @return number of folds in internal cross-validation.
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Return the tip text for this property
   *
   * @return tip text for this property suitable for displaying in
   * the explorer/experimenter gui.
   */
  public String usePruneTipText() {
    return "Use minimal cost-complexity pruning (default yes).";
  }

  /**
   * Set whether to use minimal cost-complexity pruning.
   *
   * @param value whether to use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /**
   * Get whether minimal cost-complexity pruning is used.
   *
   * @return whether minimal cost-complexity pruning is used
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return
      "If heuristic search is used for binary split for nominal attributes "
      + "in multi-class problems (default yes).";
  }

  /**
   * Set whether to use heuristic search for nominal attributes in multi-class problems.
   *
   * @param value whether to use heuristic search for nominal attributes in
   * multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Get whether heuristic search is used for nominal attributes in multi-class problems.
   *
   * @return whether heuristic search is used for nominal attributes in
   * multi-class problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1 SE rule to make pruning decision.";
  }

  /**
   * Set whether to use the 1 SE rule to choose the final model.
   *
   * @param value whether to use the 1 SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1 SE rule is used to choose the final model.
   *
   * @return whether the 1 SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The percentage of the training set size (0-1, 0 not included).";
  }

  /**
   * Set training set size.
   *
   * @param value training set size
   */
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1))
      System.err.println(
          "The percentage of the training set size must be in range 0 to 1 "
          + "(0 not included) - ignored!");
    else
      m_SizePer = value;
  }

  /**
   * Get training set size.
   *
   * @return training set size
   */
  public double getSizePer() {
    return m_SizePer;
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}
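
// Example command line (the file path is hypothetical; -t is the standard
// Weka training-file option, the remaining flags are defined above):
//   java weka.classifiers.trees.SimpleCart -t train.arff -M 2 -N 5 -A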