~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/trees/j48/NBTreeModelSelection.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    NBTreeModelSelection.java
 
19
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers.trees.j48;
 
24
 
 
25
import weka.core.Attribute;
 
26
import weka.core.Instances;
 
27
import weka.core.Utils;
 
28
 
 
29
import java.util.Enumeration;
 
30
 
 
31
/**
 
32
 * Class for selecting a NB tree split.
 
33
 *
 
34
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 
35
 * @version $Revision: 1.4 $
 
36
 */
 
37
public class NBTreeModelSelection
 
38
  extends ModelSelection {
 
39
 
 
40
  /** for serialization */
 
41
  private static final long serialVersionUID = 990097748931976704L;
 
42
 
 
43
  /** Minimum number of objects in interval. */
 
44
  private int m_minNoObj;               
 
45
 
 
46
  /** All the training data */
 
47
  private Instances m_allData; // 
 
48
 
 
49
  /**
 
50
   * Initializes the split selection method with the given parameters.
 
51
   *
 
52
   * @param minNoObj minimum number of instances that have to occur in at least two
 
53
   * subsets induced by split
 
54
   * @param allData FULL training dataset (necessary for
 
55
   * selection of split points).
 
56
   */
 
57
  public NBTreeModelSelection(int minNoObj, Instances allData) {
 
58
    m_minNoObj = minNoObj;
 
59
    m_allData = allData;
 
60
  }
 
61
 
 
62
  /**
 
63
   * Sets reference to training data to null.
 
64
   */
 
65
  public void cleanup() {
 
66
 
 
67
    m_allData = null;
 
68
  }
 
69
 
 
70
  /**
 
71
   * Selects NBTree-type split for the given dataset.
 
72
   */
 
73
  public final ClassifierSplitModel selectModel(Instances data){
 
74
 
 
75
    double globalErrors = 0;
 
76
 
 
77
    double minResult;
 
78
    double currentResult;
 
79
    NBTreeSplit [] currentModel;
 
80
    NBTreeSplit bestModel = null;
 
81
    NBTreeNoSplit noSplitModel = null;
 
82
    int validModels = 0;
 
83
    boolean multiVal = true;
 
84
    Distribution checkDistribution;
 
85
    Attribute attribute;
 
86
    double sumOfWeights;
 
87
    int i;
 
88
    
 
89
    try{
 
90
      // build the global model at this node
 
91
      noSplitModel = new NBTreeNoSplit();
 
92
      noSplitModel.buildClassifier(data);
 
93
      if (data.numInstances() < 5) {
 
94
        return noSplitModel;
 
95
      }
 
96
 
 
97
      // evaluate it
 
98
      globalErrors = noSplitModel.getErrors();
 
99
      if (globalErrors == 0) {
 
100
        return noSplitModel;
 
101
      }
 
102
 
 
103
      // Check if all Instances belong to one class or if not
 
104
      // enough Instances to split.
 
105
      checkDistribution = new Distribution(data);
 
106
      if (Utils.sm(checkDistribution.total(), m_minNoObj) ||
 
107
          Utils.eq(checkDistribution.total(),
 
108
                   checkDistribution.perClass(checkDistribution.maxClass()))) {
 
109
        return noSplitModel;
 
110
      }
 
111
 
 
112
      // Check if all attributes are nominal and have a 
 
113
      // lot of values.
 
114
      if (m_allData != null) {
 
115
        Enumeration enu = data.enumerateAttributes();
 
116
        while (enu.hasMoreElements()) {
 
117
          attribute = (Attribute) enu.nextElement();
 
118
          if ((attribute.isNumeric()) ||
 
119
              (Utils.sm((double)attribute.numValues(),
 
120
                        (0.3*(double)m_allData.numInstances())))){
 
121
            multiVal = false;
 
122
            break;
 
123
          }
 
124
        }
 
125
      }
 
126
 
 
127
      currentModel = new NBTreeSplit[data.numAttributes()];
 
128
      sumOfWeights = data.sumOfWeights();
 
129
 
 
130
      // For each attribute.
 
131
      for (i = 0; i < data.numAttributes(); i++){
 
132
        
 
133
        // Apart from class attribute.
 
134
        if (i != (data).classIndex()){
 
135
          
 
136
          // Get models for current attribute.
 
137
          currentModel[i] = new NBTreeSplit(i,m_minNoObj,sumOfWeights);
 
138
          currentModel[i].setGlobalModel(noSplitModel);
 
139
          currentModel[i].buildClassifier(data);
 
140
          
 
141
          // Check if useful split for current attribute
 
142
          // exists and check for enumerated attributes with 
 
143
          // a lot of values.
 
144
          if (currentModel[i].checkModel()){
 
145
            validModels++;
 
146
          }
 
147
        } else {
 
148
          currentModel[i] = null;
 
149
        }
 
150
      }
 
151
      
 
152
      // Check if any useful split was found.
 
153
      if (validModels == 0) {
 
154
        return noSplitModel;
 
155
      }
 
156
      
 
157
     // Find "best" attribute to split on.
 
158
      minResult = globalErrors;
 
159
      for (i=0;i<data.numAttributes();i++){
 
160
        if ((i != (data).classIndex()) &&
 
161
            (currentModel[i].checkModel())) {
 
162
          /*  System.err.println("Errors for "+data.attribute(i).name()+" "+
 
163
              currentModel[i].getErrors()); */
 
164
          if (currentModel[i].getErrors() < minResult) {
 
165
            bestModel = currentModel[i];
 
166
            minResult = currentModel[i].getErrors();
 
167
          }
 
168
        }
 
169
      }
 
170
      //      System.exit(1);
 
171
      // Check if useful split was found.
 
172
      
 
173
 
 
174
      if (((globalErrors - minResult) / globalErrors) < 0.05) {
 
175
        return noSplitModel;
 
176
      }
 
177
      
 
178
      /*      if (bestModel == null) {
 
179
        System.err.println("This shouldn't happen! glob : "+globalErrors+
 
180
                           " minRes : "+minResult);
 
181
        System.exit(1);
 
182
        } */
 
183
      // Set the global model for the best split
 
184
      //      bestModel.setGlobalModel(noSplitModel);
 
185
 
 
186
      return bestModel;
 
187
    }catch(Exception e){
 
188
      e.printStackTrace();
 
189
    }
 
190
    return null;
 
191
  }
 
192
 
 
193
  /**
 
194
   * Selects NBTree-type split for the given dataset.
 
195
   */
 
196
  public final ClassifierSplitModel selectModel(Instances train, Instances test) {
 
197
 
 
198
    return selectModel(train);
 
199
  }
 
200
}