2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* DataNearBalancedND.java
19
* Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
22
package weka.classifiers.meta.nestedDichotomies;
24
import weka.classifiers.Classifier;
25
import weka.classifiers.RandomizableSingleClassifierEnhancer;
26
import weka.classifiers.meta.FilteredClassifier;
27
import weka.core.Capabilities;
28
import weka.core.Instance;
29
import weka.core.Instances;
30
import weka.core.Range;
31
import weka.core.TechnicalInformation;
32
import weka.core.TechnicalInformationHandler;
33
import weka.core.Utils;
34
import weka.core.Capabilities.Capability;
35
import weka.core.TechnicalInformation.Field;
36
import weka.core.TechnicalInformation.Type;
37
import weka.filters.Filter;
38
import weka.filters.unsupervised.attribute.MakeIndicator;
39
import weka.filters.unsupervised.instance.RemoveWithValues;
41
import java.util.Hashtable;
42
import java.util.Random;
46
<!-- globalinfo-start -->
47
* A meta classifier for handling multi-class datasets with 2-class classifiers by building a random data-balanced tree structure.<br/>
49
* For more info, check<br/>
51
* Lin Dong, Eibe Frank, Stefan Kramer: Ensembles of Balanced Nested Dichotomies for Multi-class Problems. In: PKDD, 84-95, 2005.<br/>
53
* Eibe Frank, Stefan Kramer: Ensembles of nested dichotomies for multi-class problems. In: Twenty-first International Conference on Machine Learning, 2004.
55
<!-- globalinfo-end -->
57
<!-- technical-bibtex-start -->
60
* @inproceedings{Dong2005,
61
* author = {Lin Dong and Eibe Frank and Stefan Kramer},
64
* publisher = {Springer},
65
* title = {Ensembles of Balanced Nested Dichotomies for Multi-class Problems},
69
* @inproceedings{Frank2004,
70
* author = {Eibe Frank and Stefan Kramer},
71
* booktitle = {Twenty-first International Conference on Machine Learning},
73
* title = {Ensembles of nested dichotomies for multi-class problems},
78
<!-- technical-bibtex-end -->
80
<!-- options-start -->
81
* Valid options are: <p/>
83
* <pre> -S <num>
88
* If set, classifier is run in debug mode and
89
* may output additional info to the console</pre>
92
* Full name of base classifier.
93
* (default: weka.classifiers.trees.J48)</pre>
96
* Options specific to classifier weka.classifiers.trees.J48:
100
* Use unpruned tree.</pre>
102
* <pre> -C <pruning confidence>
103
* Set confidence threshold for pruning.
104
* (default 0.25)</pre>
106
* <pre> -M <minimum number of instances>
107
* Set minimum number of instances per leaf.
111
* Use reduced error pruning.</pre>
113
* <pre> -N <number of folds>
114
* Set number of folds for reduced error
115
* pruning. One fold is used as pruning set.
119
* Use binary splits only.</pre>
122
* Don't perform subtree raising.</pre>
125
* Do not clean up after the tree has been built.</pre>
128
* Laplace smoothing for predicted probabilities.</pre>
130
* <pre> -Q <seed>
131
* Seed for random data shuffling (default 1).</pre>
138
public class DataNearBalancedND
139
extends RandomizableSingleClassifierEnhancer
140
implements TechnicalInformationHandler {
142
/** for serialization */
143
static final long serialVersionUID = 5117477294209496368L;
145
/** The filtered classifier in which the base classifier is wrapped. */
146
protected FilteredClassifier m_FilteredClassifier;
148
/** The hashtable for this node. */
149
protected Hashtable m_classifiers=new Hashtable();
151
/** The first successor */
152
protected DataNearBalancedND m_FirstSuccessor = null;
154
/** The second successor */
155
protected DataNearBalancedND m_SecondSuccessor = null;
157
/** The classes that are grouped together at the current node */
158
protected Range m_Range = null;
160
/** Is Hashtable given from END? */
161
protected boolean m_hashtablegiven = false;
166
public DataNearBalancedND() {
168
m_Classifier = new weka.classifiers.trees.J48();
172
* String describing default classifier.
174
* @return the default classifier classname
176
protected String defaultClassifierString() {
178
return "weka.classifiers.trees.J48";
182
* Returns an instance of a TechnicalInformation object, containing
183
* detailed information about the technical background of this class,
184
* e.g., paper reference or book this class is based on.
186
* @return the technical information about this class
188
public TechnicalInformation getTechnicalInformation() {
189
TechnicalInformation result;
190
TechnicalInformation additional;
192
result = new TechnicalInformation(Type.INPROCEEDINGS);
193
result.setValue(Field.AUTHOR, "Lin Dong and Eibe Frank and Stefan Kramer");
194
result.setValue(Field.TITLE, "Ensembles of Balanced Nested Dichotomies for Multi-class Problems");
195
result.setValue(Field.BOOKTITLE, "PKDD");
196
result.setValue(Field.YEAR, "2005");
197
result.setValue(Field.PAGES, "84-95");
198
result.setValue(Field.PUBLISHER, "Springer");
200
additional = result.add(Type.INPROCEEDINGS);
201
additional.setValue(Field.AUTHOR, "Eibe Frank and Stefan Kramer");
202
additional.setValue(Field.TITLE, "Ensembles of nested dichotomies for multi-class problems");
203
additional.setValue(Field.BOOKTITLE, "Twenty-first International Conference on Machine Learning");
204
additional.setValue(Field.YEAR, "2004");
205
additional.setValue(Field.PUBLISHER, "ACM");
211
* Set hashtable from END.
213
* @param table the hashtable to use
215
public void setHashtable(Hashtable table) {
217
m_hashtablegiven = true;
218
m_classifiers = table;
222
* Generates a classifier for the current node and proceeds recursively.
224
* @param data contains the (multi-class) instances
225
* @param classes contains the indices of the classes that are present
226
* @param rand the random number generator to use
227
* @param classifier the classifier to use
228
* @param table the Hashtable to use
229
* @param instsNumAllClasses
230
* @throws Exception if anything goes worng
232
private void generateClassifierForNode(Instances data, Range classes,
233
Random rand, Classifier classifier, Hashtable table,
234
double[] instsNumAllClasses)
238
int[] indices = classes.getSelection();
240
// Randomize the order of the indices
241
for (int j = indices.length - 1; j > 0; j--) {
242
int randPos = rand.nextInt(j + 1);
243
int temp = indices[randPos];
244
indices[randPos] = indices[j];
248
// Pick the classes for the current split
250
for (int j = 0; j < indices.length; j++) {
251
total += instsNumAllClasses[indices[j]];
253
double halfOfTotal = total / 2;
255
// Go through the list of classes until the either the left or
256
// right subset exceeds half the total weight
257
double sumLeft = 0, sumRight = 0;
258
int i = 0, j = indices.length - 1;
261
if (rand.nextBoolean()) {
262
sumLeft += instsNumAllClasses[indices[i++]];
264
sumRight += instsNumAllClasses[indices[j--]];
267
sumLeft += instsNumAllClasses[indices[i++]];
268
sumRight += instsNumAllClasses[indices[j--]];
270
} while (Utils.sm(sumLeft, halfOfTotal) && Utils.sm(sumRight, halfOfTotal));
272
int first = 0, second = 0;
273
if (!Utils.sm(sumLeft, halfOfTotal)) {
278
second = indices.length - first;
280
int[] firstInds = new int[first];
281
int[] secondInds = new int[second];
282
System.arraycopy(indices, 0, firstInds, 0, first);
283
System.arraycopy(indices, first, secondInds, 0, second);
285
// Sort the indices (important for hash key)!
286
int[] sortedFirst = Utils.sort(firstInds);
287
int[] sortedSecond = Utils.sort(secondInds);
288
int[] firstCopy = new int[first];
289
int[] secondCopy = new int[second];
290
for (int k = 0; k < sortedFirst.length; k++) {
291
firstCopy[k] = firstInds[sortedFirst[k]];
293
firstInds = firstCopy;
294
for (int k = 0; k < sortedSecond.length; k++) {
295
secondCopy[k] = secondInds[sortedSecond[k]];
297
secondInds = secondCopy;
299
// Unify indices to improve hashing
300
if (firstInds[0] > secondInds[0]) {
301
int[] help = secondInds;
302
secondInds = firstInds;
309
m_Range = new Range(Range.indicesToRangeList(firstInds));
310
m_Range.setUpper(data.numClasses() - 1);
312
Range secondRange = new Range(Range.indicesToRangeList(secondInds));
313
secondRange.setUpper(data.numClasses() - 1);
315
// Change the class labels and build the classifier
316
MakeIndicator filter = new MakeIndicator();
317
filter.setAttributeIndex("" + (data.classIndex() + 1));
318
filter.setValueIndices(m_Range.getRanges());
319
filter.setNumeric(false);
320
filter.setInputFormat(data);
321
m_FilteredClassifier = new FilteredClassifier();
322
if (data.numInstances() > 0) {
323
m_FilteredClassifier.setClassifier(Classifier.makeCopies(classifier, 1)[0]);
325
m_FilteredClassifier.setClassifier(new weka.classifiers.rules.ZeroR());
327
m_FilteredClassifier.setFilter(filter);
329
// Save reference to hash table at current node
332
if (!m_classifiers.containsKey( getString(firstInds) + "|" + getString(secondInds))) {
333
m_FilteredClassifier.buildClassifier(data);
334
m_classifiers.put(getString(firstInds) + "|" + getString(secondInds), m_FilteredClassifier);
336
m_FilteredClassifier=(FilteredClassifier)m_classifiers.get(getString(firstInds) + "|" +
337
getString(secondInds));
340
// Create two successors if necessary
341
m_FirstSuccessor = new DataNearBalancedND();
343
m_FirstSuccessor.m_Range = m_Range;
345
RemoveWithValues rwv = new RemoveWithValues();
346
rwv.setInvertSelection(true);
347
rwv.setNominalIndices(m_Range.getRanges());
348
rwv.setAttributeIndex("" + (data.classIndex() + 1));
349
rwv.setInputFormat(data);
350
Instances firstSubset = Filter.useFilter(data, rwv);
351
m_FirstSuccessor.generateClassifierForNode(firstSubset, m_Range,
352
rand, classifier, m_classifiers,
355
m_SecondSuccessor = new DataNearBalancedND();
357
m_SecondSuccessor.m_Range = secondRange;
359
RemoveWithValues rwv = new RemoveWithValues();
360
rwv.setInvertSelection(true);
361
rwv.setNominalIndices(secondRange.getRanges());
362
rwv.setAttributeIndex("" + (data.classIndex() + 1));
363
rwv.setInputFormat(data);
364
Instances secondSubset = Filter.useFilter(data, rwv);
365
m_SecondSuccessor = new DataNearBalancedND();
367
m_SecondSuccessor.generateClassifierForNode(secondSubset, secondRange,
368
rand, classifier, m_classifiers,
374
* Returns default capabilities of the classifier.
376
* @return the capabilities of this classifier
378
public Capabilities getCapabilities() {
379
Capabilities result = super.getCapabilities();
382
result.disableAllClasses();
383
result.enable(Capability.NOMINAL_CLASS);
384
result.enable(Capability.MISSING_CLASS_VALUES);
387
result.setMinimumNumberInstances(1);
393
* Builds tree recursively.
395
* @param data contains the (multi-class) instances
396
* @throws Exception if the building fails
398
public void buildClassifier(Instances data) throws Exception {
400
// can classifier handle the data?
401
getCapabilities().testWithFail(data);
403
// remove instances with missing class
404
data = new Instances(data);
405
data.deleteWithMissingClass();
407
Random random = data.getRandomNumberGenerator(m_Seed);
409
if (!m_hashtablegiven) {
410
m_classifiers = new Hashtable();
413
// Check which classes are present in the
414
// data and construct initial list of classes
415
boolean[] present = new boolean[data.numClasses()];
416
for (int i = 0; i < data.numInstances(); i++) {
417
present[(int)data.instance(i).classValue()] = true;
419
StringBuffer list = new StringBuffer();
420
for (int i = 0; i < present.length; i++) {
422
if (list.length() > 0) {
429
// Determine the number of instances in each class
430
double[] instsNum = new double[data.numClasses()];
431
for (int i = 0; i < data.numInstances(); i++) {
432
instsNum[(int)data.instance(i).classValue()] += data.instance(i).weight();
435
Range newRange = new Range(list.toString());
436
newRange.setUpper(data.numClasses() - 1);
438
generateClassifierForNode(data, newRange, random, m_Classifier, m_classifiers, instsNum);
442
* Predicts the class distribution for a given instance
444
* @param inst the (multi-class) instance to be classified
445
* @return the class distribution
446
* @throws Exception if computing fails
448
public double[] distributionForInstance(Instance inst) throws Exception {
450
double[] newDist = new double[inst.numClasses()];
451
if (m_FirstSuccessor == null) {
452
for (int i = 0; i < inst.numClasses(); i++) {
453
if (m_Range.isInRange(i)) {
459
double[] firstDist = m_FirstSuccessor.distributionForInstance(inst);
460
double[] secondDist = m_SecondSuccessor.distributionForInstance(inst);
461
double[] dist = m_FilteredClassifier.distributionForInstance(inst);
462
for (int i = 0; i < inst.numClasses(); i++) {
463
if ((firstDist[i] > 0) && (secondDist[i] > 0)) {
464
System.err.println("Panik!!");
466
if (m_Range.isInRange(i)) {
467
newDist[i] = dist[1] * firstDist[i];
469
newDist[i] = dist[0] * secondDist[i];
472
if (!Utils.eq(Utils.sum(newDist), 1)) {
473
System.err.println(Utils.sum(newDist));
474
for (int j = 0; j < dist.length; j++) {
475
System.err.print(dist[j] + " ");
477
System.err.println();
478
for (int j = 0; j < newDist.length; j++) {
479
System.err.print(newDist[j] + " ");
481
System.err.println();
482
System.err.println(inst);
483
System.err.println(m_FilteredClassifier);
484
//System.err.println(m_Data);
485
System.err.println("bad");
492
* Returns the list of indices as a string.
494
* @param indices the indices to return as string
495
* @return the indices as string
497
public String getString(int [] indices) {
499
StringBuffer string = new StringBuffer();
500
for (int i = 0; i < indices.length; i++) {
504
string.append(indices[i]);
506
return string.toString();
510
* @return a description of the classifier suitable for
511
* displaying in the explorer/experimenter gui
513
public String globalInfo() {
516
"A meta classifier for handling multi-class datasets with 2-class "
517
+ "classifiers by building a random data-balanced tree structure.\n\n"
518
+ "For more info, check\n\n"
519
+ getTechnicalInformation().toString();
523
* Outputs the classifier as a string.
525
* @return a string representation of the classifier
527
public String toString() {
529
if (m_classifiers == null) {
530
return "DataNearBalancedND: No model built yet.";
532
StringBuffer text = new StringBuffer();
533
text.append("DataNearBalancedND");
534
treeToString(text, 0);
536
return text.toString();
540
* Returns string description of the tree.
542
* @param text the buffer to add the node to
543
* @param nn the node number
544
* @return the next node number
546
private int treeToString(StringBuffer text, int nn) {
549
text.append("\n\nNode number: " + nn + "\n\n");
550
if (m_FilteredClassifier != null) {
551
text.append(m_FilteredClassifier);
555
if (m_FirstSuccessor != null) {
556
nn = m_FirstSuccessor.treeToString(text, nn);
557
nn = m_SecondSuccessor.treeToString(text, nn);
563
* Main method for testing this class.
565
* @param argv the options
567
public static void main(String [] argv) {
568
runClassifier(new DataNearBalancedND(), argv);