2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
23
package weka.classifiers.trees.j48;
25
import weka.core.Capabilities;
26
import weka.core.CapabilitiesHandler;
27
import weka.core.Drawable;
28
import weka.core.Instance;
29
import weka.core.Instances;
30
import weka.core.Utils;
32
import java.io.Serializable;
35
* Class for handling a tree structure used for
38
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
39
* @version $Revision: 1.21 $
41
public class ClassifierTree
42
implements Drawable, Serializable, CapabilitiesHandler {
44
/** for serialization */
45
static final long serialVersionUID = -8722249377542734193L;
47
/** The model selection method. */
48
protected ModelSelection m_toSelectModel;
50
/** Local model at node. */
51
protected ClassifierSplitModel m_localModel;
53
/** References to sons. */
54
protected ClassifierTree [] m_sons;
56
/** True if node is leaf. */
57
protected boolean m_isLeaf;
59
/** True if node is empty. */
60
protected boolean m_isEmpty;
62
/** The training instances. */
63
protected Instances m_train;
65
/** The pruning instances. */
66
protected Distribution m_test;
68
/** The id for the node. */
72
* For getting a unique ID when outputting the tree (hashcode isn't
75
private static long PRINTED_NODES = 0;
78
* Gets the next unique node ID.
80
* @return the next unique node ID.
82
protected static long nextID() {
84
return PRINTED_NODES ++;
88
* Resets the unique node ID counter (e.g.
89
* between repeated separate print types)
91
protected static void resetID() {
/**
 * Constructor.
 *
 * @param toSelectLocModel selection method for local splitting model
 */
public ClassifierTree(ModelSelection toSelectLocModel) {

  m_toSelectModel = toSelectLocModel;
}
105
* Returns default capabilities of the classifier tree.
107
* @return the capabilities of this classifier tree
109
public Capabilities getCapabilities() {
110
return new Capabilities(this);
114
* Method for building a classifier tree.
116
* @param data the data to build the tree from
117
* @throws Exception if something goes wrong
119
public void buildClassifier(Instances data) throws Exception {
121
// can classifier tree handle the data?
122
getCapabilities().testWithFail(data);
124
// remove instances with missing class
125
data = new Instances(data);
126
data.deleteWithMissingClass();
128
buildTree(data, false);
132
* Builds the tree structure.
134
* @param data the data for which the tree structure is to be
136
* @param keepData is training data to be kept?
137
* @throws Exception if something goes wrong
139
public void buildTree(Instances data, boolean keepData) throws Exception {
141
Instances [] localInstances;
150
m_localModel = m_toSelectModel.selectModel(data);
151
if (m_localModel.numSubsets() > 1) {
152
localInstances = m_localModel.split(data);
154
m_sons = new ClassifierTree [m_localModel.numSubsets()];
155
for (int i = 0; i < m_sons.length; i++) {
156
m_sons[i] = getNewTree(localInstances[i]);
157
localInstances[i] = null;
161
if (Utils.eq(data.sumOfWeights(), 0))
168
* Builds the tree structure with hold out set
170
* @param train the data for which the tree structure is to be
172
* @param test the test data for potential pruning
173
* @param keepData is training Data to be kept?
174
* @throws Exception if something goes wrong
176
public void buildTree(Instances train, Instances test, boolean keepData)
179
Instances [] localTrain, localTest;
188
m_localModel = m_toSelectModel.selectModel(train, test);
189
m_test = new Distribution(test, m_localModel);
190
if (m_localModel.numSubsets() > 1) {
191
localTrain = m_localModel.split(train);
192
localTest = m_localModel.split(test);
194
m_sons = new ClassifierTree [m_localModel.numSubsets()];
195
for (i=0;i<m_sons.length;i++) {
196
m_sons[i] = getNewTree(localTrain[i], localTest[i]);
197
localTrain[i] = null;
202
if (Utils.eq(train.sumOfWeights(), 0))
209
* Classifies an instance.
211
* @param instance the instance to classify
212
* @return the classification
213
* @throws Exception if something goes wrong
215
public double classifyInstance(Instance instance)
223
for (j = 0; j < instance.numClasses(); j++) {
224
currentProb = getProbs(j, instance, 1);
225
if (Utils.gr(currentProb,maxProb)) {
227
maxProb = currentProb;
231
return (double)maxIndex;
235
* Cleanup in order to save memory.
237
* @param justHeaderInfo
239
public final void cleanup(Instances justHeaderInfo) {
241
m_train = justHeaderInfo;
244
for (int i = 0; i < m_sons.length; i++)
245
m_sons[i].cleanup(justHeaderInfo);
249
* Returns class probabilities for a weighted instance.
251
* @param instance the instance to get the distribution for
252
* @param useLaplace whether to use laplace or not
253
* @return the distribution
254
* @throws Exception if something goes wrong
256
public final double [] distributionForInstance(Instance instance,
260
double [] doubles = new double[instance.numClasses()];
262
for (int i = 0; i < doubles.length; i++) {
264
doubles[i] = getProbs(i, instance, 1);
266
doubles[i] = getProbsLaplace(i, instance, 1);
274
* Assigns a uniqe id to every node in the tree.
276
* @param lastID the last ID that was assign
277
* @return the new current ID
279
public int assignIDs(int lastID) {
281
int currLastID = lastID + 1;
284
if (m_sons != null) {
285
for (int i = 0; i < m_sons.length; i++) {
286
currLastID = m_sons[i].assignIDs(currLastID);
293
* Returns the type of graph this classifier
295
* @return Drawable.TREE
297
public int graphType() {
298
return Drawable.TREE;
302
* Returns graph describing the tree.
304
* @throws Exception if something goes wrong
305
* @return the tree as graph
307
public String graph() throws Exception {
309
StringBuffer text = new StringBuffer();
312
text.append("digraph J48Tree {\n");
314
text.append("N" + m_id
316
m_localModel.dumpLabel(0,m_train) + "\" " +
317
"shape=box style=filled ");
318
if (m_train != null && m_train.numInstances() > 0) {
319
text.append("data =\n" + m_train + "\n");
325
text.append("N" + m_id
327
m_localModel.leftSide(m_train) + "\" ");
328
if (m_train != null && m_train.numInstances() > 0) {
329
text.append("data =\n" + m_train + "\n");
336
return text.toString() +"}\n";
340
* Returns tree in prefix order.
342
* @throws Exception if something goes wrong
343
* @return the prefix order
345
public String prefix() throws Exception {
349
text = new StringBuffer();
351
text.append("["+m_localModel.dumpLabel(0,m_train)+"]");
356
return text.toString();
360
* Returns source code for the tree as an if-then statement. The
361
* class is assigned to variable "p", and assumes the tested
362
* instance is named "i". The results are returned as two stringbuffers:
363
* a section of code for assignment of the class, and a section of
364
* code containing support code (eg: other support methods).
366
* @param className the classname that this static classifier has
367
* @return an array containing two stringbuffers, the first string containing
368
* assignment code, and the second containing source for support code.
369
* @throws Exception if something goes wrong
371
public StringBuffer [] toSource(String className) throws Exception {
373
StringBuffer [] result = new StringBuffer [2];
375
result[0] = new StringBuffer(" p = "
376
+ m_localModel.distribution().maxClass(0) + ";\n");
377
result[1] = new StringBuffer("");
379
StringBuffer text = new StringBuffer();
380
StringBuffer atEnd = new StringBuffer();
382
long printID = ClassifierTree.nextID();
384
text.append(" static double N")
385
.append(Integer.toHexString(m_localModel.hashCode()) + printID)
386
.append("(Object []i) {\n")
387
.append(" double p = Double.NaN;\n");
390
.append(m_localModel.sourceExpression(-1, m_train))
393
.append(m_localModel.distribution().maxClass(0))
396
for (int i = 0; i < m_sons.length; i++) {
397
text.append("else if (" + m_localModel.sourceExpression(i, m_train)
399
if (m_sons[i].m_isLeaf) {
401
+ m_localModel.distribution().maxClass(i) + ";\n");
403
StringBuffer [] sub = m_sons[i].toSource(className);
405
atEnd.append(sub[1]);
408
if (i == m_sons.length - 1) {
413
text.append(" return p;\n }\n");
415
result[0] = new StringBuffer(" p = " + className + ".N");
416
result[0].append(Integer.toHexString(m_localModel.hashCode()) + printID)
418
result[1] = text.append(atEnd);
424
* Returns number of leaves in tree structure.
426
* @return the number of leaves
428
public int numLeaves() {
436
for (i=0;i<m_sons.length;i++)
437
num = num+m_sons[i].numLeaves();
443
* Returns number of nodes in tree structure.
445
* @return the number of nodes
447
public int numNodes() {
453
for (i=0;i<m_sons.length;i++)
454
no = no+m_sons[i].numNodes();
460
* Prints tree structure.
462
* @return the tree structure
464
public String toString() {
467
StringBuffer text = new StringBuffer();
471
text.append(m_localModel.dumpLabel(0,m_train));
474
text.append("\n\nNumber of Leaves : \t"+numLeaves()+"\n");
475
text.append("\nSize of the tree : \t"+numNodes()+"\n");
477
return text.toString();
478
} catch (Exception e) {
479
return "Can't print classification tree.";
484
* Returns a newly created tree.
486
* @param data the training data
487
* @return the generated tree
488
* @throws Exception if something goes wrong
490
protected ClassifierTree getNewTree(Instances data) throws Exception {
492
ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
493
newTree.buildTree(data, false);
499
* Returns a newly created tree.
501
* @param train the training data
502
* @param test the pruning data.
503
* @return the generated tree
504
* @throws Exception if something goes wrong
506
protected ClassifierTree getNewTree(Instances train, Instances test)
509
ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
510
newTree.buildTree(train, test, false);
516
* Help method for printing tree structure.
518
* @param depth the current depth
519
* @param text for outputting the structure
520
* @throws Exception if something goes wrong
522
private void dumpTree(int depth, StringBuffer text)
527
for (i=0;i<m_sons.length;i++) {
529
for (j=0;j<depth;j++)
531
text.append(m_localModel.leftSide(m_train));
532
text.append(m_localModel.rightSide(i, m_train));
533
if (m_sons[i].m_isLeaf) {
535
text.append(m_localModel.dumpLabel(i,m_train));
537
m_sons[i].dumpTree(depth+1,text);
542
* Help method for printing tree structure as a graph.
544
* @param text for outputting the tree
545
* @throws Exception if something goes wrong
547
private void graphTree(StringBuffer text) throws Exception {
549
for (int i = 0; i < m_sons.length; i++) {
550
text.append("N" + m_id
552
"N" + m_sons[i].m_id +
553
" [label=\"" + m_localModel.rightSide(i,m_train).trim() +
555
if (m_sons[i].m_isLeaf) {
556
text.append("N" + m_sons[i].m_id +
557
" [label=\""+m_localModel.dumpLabel(i,m_train)+"\" "+
558
"shape=box style=filled ");
559
if (m_train != null && m_train.numInstances() > 0) {
560
text.append("data =\n" + m_sons[i].m_train + "\n");
565
text.append("N" + m_sons[i].m_id +
566
" [label=\""+m_sons[i].m_localModel.leftSide(m_train) +
568
if (m_train != null && m_train.numInstances() > 0) {
569
text.append("data =\n" + m_sons[i].m_train + "\n");
573
m_sons[i].graphTree(text);
579
* Prints the tree in prefix form
581
* @param text the buffer to output the prefix form to
582
* @throws Exception if something goes wrong
584
private void prefixTree(StringBuffer text) throws Exception {
587
text.append(m_localModel.leftSide(m_train)+":");
588
for (int i = 0; i < m_sons.length; i++) {
592
text.append(m_localModel.rightSide(i, m_train));
594
for (int i = 0; i < m_sons.length; i++) {
595
if (m_sons[i].m_isLeaf) {
597
text.append(m_localModel.dumpLabel(i,m_train));
600
m_sons[i].prefixTree(text);
607
* Help method for computing class probabilities of
610
* @param classIndex the class index
611
* @param instance the instance to compute the probabilities for
612
* @param weight the weight to use
613
* @return the laplace probs
614
* @throws Exception if something goes wrong
616
private double getProbsLaplace(int classIndex, Instance instance, double weight)
622
return weight * localModel().classProbLaplace(classIndex, instance, -1);
624
int treeIndex = localModel().whichSubset(instance);
625
if (treeIndex == -1) {
626
double[] weights = localModel().weights(instance);
627
for (int i = 0; i < m_sons.length; i++) {
628
if (!son(i).m_isEmpty) {
629
prob += son(i).getProbsLaplace(classIndex, instance,
630
weights[i] * weight);
635
if (son(treeIndex).m_isEmpty) {
636
return weight * localModel().classProbLaplace(classIndex, instance,
639
return son(treeIndex).getProbsLaplace(classIndex, instance, weight);
646
* Help method for computing class probabilities of
649
* @param classIndex the class index
650
* @param instance the instance to compute the probabilities for
651
* @param weight the weight to use
653
* @throws Exception if something goes wrong
655
private double getProbs(int classIndex, Instance instance, double weight)
661
return weight * localModel().classProb(classIndex, instance, -1);
663
int treeIndex = localModel().whichSubset(instance);
664
if (treeIndex == -1) {
665
double[] weights = localModel().weights(instance);
666
for (int i = 0; i < m_sons.length; i++) {
667
if (!son(i).m_isEmpty) {
668
prob += son(i).getProbs(classIndex, instance,
669
weights[i] * weight);
674
if (son(treeIndex).m_isEmpty) {
675
return weight * localModel().classProb(classIndex, instance,
678
return son(treeIndex).getProbs(classIndex, instance, weight);
685
* Method just exists to make program easier to read.
687
private ClassifierSplitModel localModel() {
689
return (ClassifierSplitModel)m_localModel;
693
* Method just exists to make program easier to read.
695
private ClassifierTree son(int index) {
697
return (ClassifierTree)m_sons[index];