~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/rules/part/ClassifierDecList.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    ClassifierDecList.java
 
19
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers.rules.part;
 
24
 
 
25
import weka.classifiers.trees.j48.ClassifierSplitModel;
 
26
import weka.classifiers.trees.j48.Distribution;
 
27
import weka.classifiers.trees.j48.EntropySplitCrit;
 
28
import weka.classifiers.trees.j48.ModelSelection;
 
29
import weka.classifiers.trees.j48.NoSplit;
 
30
import weka.core.Instance;
 
31
import weka.core.Instances;
 
32
import weka.core.Utils;
 
33
 
 
34
import java.io.Serializable;
 
35
 
 
36
/**
 
37
 * Class for handling a rule (partial tree) for a decision list.
 
38
 *
 
39
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 
40
 * @version $Revision: 1.12 $
 
41
 */
 
42
public class ClassifierDecList
 
43
  implements Serializable {
 
44
 
 
45
  /** for serialization */
 
46
  private static final long serialVersionUID = 7284358349711992497L;
 
47
 
 
48
  /** Minimum number of objects */
 
49
  protected int m_minNumObj;
 
50
 
 
51
  /** To compute the entropy. */
 
52
  protected static EntropySplitCrit m_splitCrit = new EntropySplitCrit();
 
53
 
 
54
  /** The model selection method. */
 
55
  protected ModelSelection m_toSelectModel;   
 
56
 
 
57
  /** Local model at node. */  
 
58
  protected ClassifierSplitModel m_localModel; 
 
59
 
 
60
  /** References to sons. */
 
61
  protected ClassifierDecList [] m_sons;       
 
62
  
 
63
  /** True if node is leaf. */
 
64
  protected boolean m_isLeaf;   
 
65
 
 
66
  /** True if node is empty. */
 
67
  protected boolean m_isEmpty;                 
 
68
 
 
69
  /** The training instances. */
 
70
  protected Instances m_train;                 
 
71
 
 
72
  /** The pruning instances. */ 
 
73
  protected Distribution m_test;               
 
74
 
 
75
  /** Which son to expand? */  
 
76
  protected int indeX;         
 
77
 
 
78
  /**
 
79
   * Constructor - just calls constructor of class DecList.
 
80
   */
 
81
  public ClassifierDecList(ModelSelection toSelectLocModel, int minNum){
 
82
 
 
83
    m_toSelectModel = toSelectLocModel;
 
84
    m_minNumObj = minNum;
 
85
   }
 
86
  
 
87
  /**
 
88
   * Method for building a pruned partial tree.
 
89
   *
 
90
   * @exception Exception if something goes wrong
 
91
   */
 
92
  public void buildRule(Instances data) throws Exception {
 
93
    
 
94
    buildDecList(data, false);
 
95
 
 
96
    cleanup(new Instances(data, 0));
 
97
  }
 
98
 
 
99
  /**
 
100
   * Builds the partial tree without hold out set.
 
101
   *
 
102
   * @exception Exception if something goes wrong
 
103
   */
 
104
  public void buildDecList(Instances data, boolean leaf) throws Exception {
 
105
    
 
106
    Instances [] localInstances,localPruneInstances;
 
107
    int index,ind;
 
108
    int i,j;
 
109
    double sumOfWeights;
 
110
    NoSplit noSplit;
 
111
    
 
112
    m_train = null;
 
113
    m_test = null;
 
114
    m_isLeaf = false;
 
115
    m_isEmpty = false;
 
116
    m_sons = null;
 
117
    indeX = 0;
 
118
    sumOfWeights = data.sumOfWeights();
 
119
    noSplit = new NoSplit (new Distribution((Instances)data));
 
120
    if (leaf)
 
121
      m_localModel = noSplit;
 
122
    else
 
123
      m_localModel = m_toSelectModel.selectModel(data);
 
124
    if (m_localModel.numSubsets() > 1) {
 
125
      localInstances = m_localModel.split(data);
 
126
      data = null;
 
127
      m_sons = new ClassifierDecList [m_localModel.numSubsets()];
 
128
      i = 0;
 
129
      do {
 
130
        i++;
 
131
        ind = chooseIndex();
 
132
        if (ind == -1) {
 
133
          for (j = 0; j < m_sons.length; j++) 
 
134
            if (m_sons[j] == null)
 
135
              m_sons[j] = getNewDecList(localInstances[j],true);
 
136
          if (i < 2) {
 
137
            m_localModel = noSplit;
 
138
            m_isLeaf = true;
 
139
            m_sons = null;
 
140
            if (Utils.eq(sumOfWeights,0))
 
141
              m_isEmpty = true;
 
142
            return;
 
143
          }
 
144
          ind = 0;
 
145
          break;
 
146
        } else 
 
147
          m_sons[ind] = getNewDecList(localInstances[ind],false);
 
148
      } while ((i < m_sons.length) && (m_sons[ind].m_isLeaf));
 
149
      
 
150
      // Choose rule
 
151
      indeX = chooseLastIndex();
 
152
    }else{
 
153
      m_isLeaf = true;
 
154
      if (Utils.eq(sumOfWeights, 0))
 
155
        m_isEmpty = true;
 
156
    }
 
157
  }
 
158
 
 
159
  /** 
 
160
   * Classifies an instance.
 
161
   *
 
162
   * @exception Exception if something goes wrong
 
163
   */
 
164
  public double classifyInstance(Instance instance)
 
165
       throws Exception {
 
166
 
 
167
    double maxProb = -1;
 
168
    double currentProb;
 
169
    int maxIndex = 0;
 
170
    int j;
 
171
 
 
172
    for (j = 0; j < instance.numClasses();
 
173
         j++){
 
174
      currentProb = getProbs(j,instance,1);
 
175
      if (Utils.gr(currentProb,maxProb)){
 
176
        maxIndex = j;
 
177
        maxProb = currentProb;
 
178
      }
 
179
    }
 
180
    if (Utils.eq(maxProb,0))
 
181
      return -1.0;
 
182
    else
 
183
      return (double)maxIndex;
 
184
  }
 
185
 
 
186
  /** 
 
187
   * Returns class probabilities for a weighted instance.
 
188
   *
 
189
   * @exception Exception if something goes wrong
 
190
   */
 
191
  public final double [] distributionForInstance(Instance instance) 
 
192
       throws Exception {
 
193
                
 
194
 
 
195
    double [] doubles =
 
196
      new double[instance.numClasses()];
 
197
 
 
198
    for (int i = 0; i < doubles.length; i++)
 
199
      doubles[i] = getProbs(i,instance,1);
 
200
    
 
201
    return doubles;
 
202
  }
 
203
  
 
204
  /**
 
205
   * Returns the weight a rule assigns to an instance.
 
206
   *
 
207
   * @exception Exception if something goes wrong
 
208
   */
 
209
  public double weight(Instance instance) throws Exception {
 
210
 
 
211
    int subset;
 
212
 
 
213
    if (m_isLeaf)
 
214
      return 1;
 
215
    subset = m_localModel.whichSubset(instance);
 
216
    if (subset == -1)
 
217
      return (m_localModel.weights(instance))[indeX]*
 
218
        m_sons[indeX].weight(instance);
 
219
    if (subset == indeX)
 
220
      return m_sons[indeX].weight(instance);
 
221
    return 0;
 
222
  }
 
223
 
 
224
  /**
 
225
   * Cleanup in order to save memory.
 
226
   */
 
227
  public final void cleanup(Instances justHeaderInfo) {
 
228
 
 
229
    m_train = justHeaderInfo;
 
230
    m_test = null;
 
231
    if (!m_isLeaf)
 
232
      for (int i = 0; i < m_sons.length; i++)
 
233
        if (m_sons[i] != null)
 
234
          m_sons[i].cleanup(justHeaderInfo);
 
235
  }
 
236
 
 
237
  /**
 
238
   * Prints rules.
 
239
   */
 
240
  public String toString(){
 
241
 
 
242
    try {
 
243
      StringBuffer text;
 
244
      
 
245
      text = new StringBuffer();
 
246
      if (m_isLeaf){
 
247
        text.append(": ");
 
248
        text.append(m_localModel.dumpLabel(0,m_train)+"\n");
 
249
      }else{
 
250
      dumpDecList(text);
 
251
      //dumpTree(0,text);
 
252
      }
 
253
      return text.toString();
 
254
    } catch (Exception e) {
 
255
      return "Can't print rule.";
 
256
    }
 
257
  }
 
258
 
 
259
  /**
 
260
   * Returns a newly created tree.
 
261
   *
 
262
   * @exception Exception if something goes wrong
 
263
   */
 
264
  protected ClassifierDecList getNewDecList(Instances train, boolean leaf) 
 
265
    throws Exception {
 
266
         
 
267
    ClassifierDecList newDecList = new ClassifierDecList(m_toSelectModel,
 
268
                                                         m_minNumObj);
 
269
    newDecList.buildDecList(train,leaf);
 
270
    
 
271
    return newDecList;
 
272
  }
 
273
 
 
274
  /**
 
275
   * Method for choosing a subset to expand.
 
276
   */
 
277
  public final int chooseIndex() {
 
278
    
 
279
    int minIndex = -1;
 
280
    double estimated, min = Double.MAX_VALUE;
 
281
    int i, j;
 
282
 
 
283
    for (i = 0; i < m_sons.length; i++)
 
284
      if (son(i) == null) {
 
285
        if (Utils.sm(localModel().distribution().perBag(i),
 
286
                     (double)m_minNumObj))
 
287
          estimated = Double.MAX_VALUE;
 
288
        else{
 
289
          estimated = 0;
 
290
          for (j = 0; j < localModel().distribution().numClasses(); j++) 
 
291
            estimated -= m_splitCrit.logFunc(localModel().distribution().
 
292
                                     perClassPerBag(i,j));
 
293
          estimated += m_splitCrit.logFunc(localModel().distribution().
 
294
                                   perBag(i));
 
295
          estimated /= localModel().distribution().perBag(i);
 
296
        }
 
297
        if (Utils.smOrEq(estimated,0))
 
298
          return i;
 
299
        if (Utils.sm(estimated,min)) {
 
300
          min = estimated;
 
301
          minIndex = i;
 
302
        }
 
303
      }
 
304
 
 
305
    return minIndex;
 
306
  }
 
307
  
 
308
  /**
 
309
   * Choose last index (ie. choose rule).
 
310
   */
 
311
  public final int chooseLastIndex() {
 
312
    
 
313
    int minIndex = 0;
 
314
    double estimated, min = Double.MAX_VALUE;
 
315
    
 
316
    if (!m_isLeaf) 
 
317
      for (int i = 0; i < m_sons.length; i++)
 
318
        if (son(i) != null) {
 
319
          if (Utils.grOrEq(localModel().distribution().perBag(i),
 
320
                           (double)m_minNumObj)) {
 
321
            estimated = son(i).getSizeOfBranch();
 
322
            if (Utils.sm(estimated,min)) {
 
323
              min = estimated;
 
324
              minIndex = i;
 
325
            }
 
326
          }
 
327
        }
 
328
 
 
329
    return minIndex;
 
330
  }
 
331
 
 
332
  /**
 
333
   * Returns the number of instances covered by a branch
 
334
   */
 
335
  protected double getSizeOfBranch() {
 
336
    
 
337
    if (m_isLeaf) {
 
338
      return -localModel().distribution().total();
 
339
    } else
 
340
      return son(indeX).getSizeOfBranch();
 
341
  }
 
342
 
 
343
  /**
 
344
   * Help method for printing tree structure.
 
345
   */
 
346
  private void dumpDecList(StringBuffer text) throws Exception {
 
347
    
 
348
    text.append(m_localModel.leftSide(m_train));
 
349
    text.append(m_localModel.rightSide(indeX, m_train));
 
350
    if (m_sons[indeX].m_isLeaf){
 
351
      text.append(": ");
 
352
      text.append(m_localModel.dumpLabel(indeX,m_train)+"\n");
 
353
    }else{
 
354
      text.append(" AND\n");
 
355
      m_sons[indeX].dumpDecList(text);
 
356
    }
 
357
  }
 
358
 
 
359
  /**
 
360
   * Dumps the partial tree (only used for debugging)
 
361
   *
 
362
   * @exception Exception Exception if something goes wrong
 
363
   */
 
364
  private void dumpTree(int depth,StringBuffer text)
 
365
       throws Exception {
 
366
    
 
367
    int i,j;
 
368
    
 
369
    for (i=0;i<m_sons.length;i++){
 
370
      text.append("\n");;
 
371
      for (j=0;j<depth;j++)
 
372
        text.append("|   ");
 
373
      text.append(m_localModel.leftSide(m_train));
 
374
      text.append(m_localModel.rightSide(i, m_train));
 
375
      if (m_sons[i] == null)
 
376
        text.append("null");
 
377
      else if (m_sons[i].m_isLeaf){
 
378
        text.append(": ");
 
379
        text.append(m_localModel.dumpLabel(i,m_train));
 
380
      }else
 
381
        m_sons[i].dumpTree(depth+1,text);
 
382
    }
 
383
  }
 
384
 
 
385
  /**
 
386
   * Help method for computing class probabilities of 
 
387
   * a given instance.
 
388
   *
 
389
   * @exception Exception Exception if something goes wrong
 
390
   */
 
391
  private double getProbs(int classIndex,Instance instance,
 
392
                          double weight) throws Exception {
 
393
    
 
394
    double [] weights;
 
395
    int treeIndex;
 
396
 
 
397
    if (m_isLeaf) {
 
398
      return weight * localModel().classProb(classIndex, instance, -1);
 
399
    } else {
 
400
      treeIndex = localModel().whichSubset(instance);
 
401
      if (treeIndex == -1) {
 
402
        weights = localModel().weights(instance);
 
403
        return son(indeX).getProbs(classIndex, instance, 
 
404
                                   weights[indeX] * weight);
 
405
      }else{
 
406
        if (treeIndex == indeX) {
 
407
          return son(indeX).getProbs(classIndex, instance, weight);
 
408
        } else {
 
409
          return 0;
 
410
        }
 
411
      }
 
412
    }
 
413
  }
 
414
 
 
415
  /**
 
416
   * Method just exists to make program easier to read.
 
417
   */
 
418
  protected ClassifierSplitModel localModel(){
 
419
    
 
420
    return (ClassifierSplitModel)m_localModel;
 
421
  }
 
422
 
 
423
  /**
 
424
   * Method just exists to make program easier to read.
 
425
   */
 
426
  protected ClassifierDecList son(int index){
 
427
    
 
428
    return m_sons[index];
 
429
  }
 
430
}
 
431
 
 
432
 
 
433
 
 
434
 
 
435