/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    PredictionAppender.java
 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
 *
 */
package weka.gui.beans;
25
import weka.clusterers.DensityBasedClusterer;
26
import weka.core.Instance;
27
import weka.core.Instances;
29
import java.awt.BorderLayout;
30
import java.beans.EventSetDescriptor;
31
import java.io.Serializable;
32
import java.util.Enumeration;
33
import java.util.Vector;
35
import javax.swing.JPanel;
/**
 * Bean that can accept batch or incremental classifier events (as well as
 * batch clusterer events) and produce dataset or instance events which
 * contain instances with predictions appended.
 *
 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
 * @version $Revision: 1.15 $
 */
public class PredictionAppender
47
implements DataSource, TrainingSetProducer, TestSetProducer, Visible, BeanCommon,
48
EventConstraints, BatchClassifierListener,
49
IncrementalClassifierListener, BatchClustererListener, Serializable {
51
/** for serialization */
52
private static final long serialVersionUID = -2987740065058976673L;
55
* Objects listenening for dataset events
57
protected Vector m_dataSourceListeners = new Vector();
60
* Objects listening for instances events
62
protected Vector m_instanceListeners = new Vector();
65
* Objects listening for training set events
67
protected Vector m_trainingSetListeners = new Vector();;
70
* Objects listening for test set events
72
protected Vector m_testSetListeners = new Vector();
75
* Non null if this object is a target for any events.
77
protected Object m_listenee = null;
80
* Format of instances to be produced.
82
protected Instances m_format;
84
protected BeanVisual m_visual =
85
new BeanVisual("PredictionAppender",
86
BeanVisual.ICON_PATH+"PredictionAppender.gif",
87
BeanVisual.ICON_PATH+"PredictionAppender_animated.gif");
90
* Append classifier's predicted probabilities (if the class is discrete
91
* and the classifier is a distribution classifier)
93
protected boolean m_appendProbabilities;
95
protected transient weka.gui.Logger m_logger;
98
* Global description of this bean
100
* @return a <code>String</code> value
102
public String globalInfo() {
103
return "Accepts batch or incremental classifier events and "
104
+"produces a new data set with classifier predictions appended.";
/**
 * Creates a new <code>PredictionAppender</code> instance and places its
 * visual in the centre of a <code>BorderLayout</code>.
 */
public PredictionAppender() {
  setLayout(new BorderLayout());
  add(m_visual, BorderLayout.CENTER);
}
116
* Return a tip text suitable for displaying in a GUI
118
* @return a <code>String</code> value
120
public String appendPredictedProbabilitiesTipText() {
121
return "append probabilities rather than labels for discrete class "
126
* Return true if predicted probabilities are to be appended rather
129
* @return a <code>boolean</code> value
131
public boolean getAppendPredictedProbabilities() {
132
return m_appendProbabilities;
136
* Set whether to append predicted probabilities rather than
137
* class value (for discrete class data sets)
139
* @param ap a <code>boolean</code> value
141
public void setAppendPredictedProbabilities(boolean ap) {
142
m_appendProbabilities = ap;
146
* Add a training set listener
148
* @param tsl a <code>TrainingSetListener</code> value
150
public void addTrainingSetListener(TrainingSetListener tsl) {
151
// TODO Auto-generated method stub
152
m_trainingSetListeners.addElement(tsl);
153
// pass on any format that we might have determined so far
154
if (m_format != null) {
155
TrainingSetEvent e = new TrainingSetEvent(this, m_format);
156
tsl.acceptTrainingSet(e);
161
* Remove a training set listener
163
* @param tsl a <code>TrainingSetListener</code> value
165
public void removeTrainingSetListener(TrainingSetListener tsl) {
166
m_trainingSetListeners.removeElement(tsl);
170
* Add a test set listener
172
* @param tsl a <code>TestSetListener</code> value
174
public void addTestSetListener(TestSetListener tsl) {
175
m_testSetListeners.addElement(tsl);
176
// pass on any format that we might have determined so far
177
if (m_format != null) {
178
TestSetEvent e = new TestSetEvent(this, m_format);
179
tsl.acceptTestSet(e);
184
* Remove a test set listener
186
* @param tsl a <code>TestSetListener</code> value
188
public void removeTestSetListener(TestSetListener tsl) {
189
m_testSetListeners.removeElement(tsl);
193
* Add a datasource listener
195
* @param dsl a <code>DataSourceListener</code> value
197
public synchronized void addDataSourceListener(DataSourceListener dsl) {
198
m_dataSourceListeners.addElement(dsl);
199
// pass on any format that we might have determined so far
200
if (m_format != null) {
201
DataSetEvent e = new DataSetEvent(this, m_format);
202
dsl.acceptDataSet(e);
207
* Remove a datasource listener
209
* @param dsl a <code>DataSourceListener</code> value
211
public synchronized void removeDataSourceListener(DataSourceListener dsl) {
212
m_dataSourceListeners.remove(dsl);
216
* Add an instance listener
218
* @param dsl a <code>InstanceListener</code> value
220
public synchronized void addInstanceListener(InstanceListener dsl) {
221
m_instanceListeners.addElement(dsl);
222
// pass on any format that we might have determined so far
223
if (m_format != null) {
224
InstanceEvent e = new InstanceEvent(this, m_format);
225
dsl.acceptInstance(e);
230
* Remove an instance listener
232
* @param dsl a <code>InstanceListener</code> value
234
public synchronized void removeInstanceListener(InstanceListener dsl) {
235
m_instanceListeners.remove(dsl);
239
* Set the visual for this data source
241
* @param newVisual a <code>BeanVisual</code> value
243
public void setVisual(BeanVisual newVisual) {
244
m_visual = newVisual;
248
* Get the visual being used by this data source.
251
public BeanVisual getVisual() {
256
* Use the default images for a data source
259
public void useDefaultVisual() {
260
m_visual.loadIcons(BeanVisual.ICON_PATH+"PredictionAppender.gif",
261
BeanVisual.ICON_PATH+"PredictionAppender_animated.gif");
264
protected InstanceEvent m_instanceEvent;
265
protected double [] m_instanceVals;
269
* Accept and process an incremental classifier event
271
* @param e an <code>IncrementalClassifierEvent</code> value
273
public void acceptClassifier(IncrementalClassifierEvent e) {
274
weka.classifiers.Classifier classifier = e.getClassifier();
275
Instance currentI = e.getCurrentInstance();
276
int status = e.getStatus();
278
if (status == IncrementalClassifierEvent.NEW_BATCH) {
279
oldNumAtts = e.getStructure().numAttributes();
281
oldNumAtts = currentI.dataset().numAttributes();
283
if (status == IncrementalClassifierEvent.NEW_BATCH) {
284
m_instanceEvent = new InstanceEvent(this, null, 0);
285
// create new header structure
286
Instances oldStructure = new Instances(e.getStructure(), 0);
287
//String relationNameModifier = oldStructure.relationName()
288
//+"_with predictions";
289
String relationNameModifier = "_with predictions";
290
//+"_with predictions";
291
if (!m_appendProbabilities
292
|| oldStructure.classAttribute().isNumeric()) {
294
m_format = makeDataSetClass(oldStructure, classifier,
295
relationNameModifier);
296
m_instanceVals = new double [m_format.numAttributes()];
297
} catch (Exception ex) {
298
ex.printStackTrace();
301
} else if (m_appendProbabilities) {
304
makeDataSetProbabilities(oldStructure, classifier,
305
relationNameModifier);
306
m_instanceVals = new double [m_format.numAttributes()];
307
} catch (Exception ex) {
308
ex.printStackTrace();
312
// Pass on the structure
313
m_instanceEvent.setStructure(m_format);
314
notifyInstanceAvailable(m_instanceEvent);
320
// process the actual instance
321
for (int i = 0; i < oldNumAtts; i++) {
322
m_instanceVals[i] = currentI.value(i);
324
if (!m_appendProbabilities
325
|| currentI.dataset().classAttribute().isNumeric()) {
327
classifier.classifyInstance(currentI);
328
m_instanceVals[m_instanceVals.length - 1] = predClass;
329
} else if (m_appendProbabilities) {
330
double [] preds = classifier.distributionForInstance(currentI);
331
for (int i = oldNumAtts; i < m_instanceVals.length; i++) {
332
m_instanceVals[i] = preds[i-oldNumAtts];
335
} catch (Exception ex) {
336
ex.printStackTrace();
339
newInst = new Instance(currentI.weight(), m_instanceVals);
340
newInst.setDataset(m_format);
341
m_instanceEvent.setInstance(newInst);
342
m_instanceEvent.setStatus(status);
344
notifyInstanceAvailable(m_instanceEvent);
347
if (status == IncrementalClassifierEvent.BATCH_FINISHED) {
349
// m_incrementalStructure = null;
350
m_instanceVals = null;
351
m_instanceEvent = null;
356
* Accept and process a batch classifier event
358
* @param e a <code>BatchClassifierEvent</code> value
360
public void acceptClassifier(BatchClassifierEvent e) {
361
if (m_dataSourceListeners.size() > 0
362
|| m_trainingSetListeners.size() > 0
363
|| m_testSetListeners.size() > 0) {
364
Instances testSet = e.getTestSet().getDataSet();
365
Instances trainSet = e.getTrainSet().getDataSet();
366
int setNum = e.getSetNumber();
367
int maxNum = e.getMaxSetNumber();
369
weka.classifiers.Classifier classifier = e.getClassifier();
370
String relationNameModifier = "_set_"+e.getSetNumber()+"_of_"
371
+e.getMaxSetNumber();
372
if (!m_appendProbabilities || testSet.classAttribute().isNumeric()) {
374
Instances newTestSetInstances = makeDataSetClass(testSet, classifier,
375
relationNameModifier);
376
Instances newTrainingSetInstances = makeDataSetClass(trainSet, classifier,
377
relationNameModifier);
379
if (m_trainingSetListeners.size() > 0) {
380
TrainingSetEvent tse = new TrainingSetEvent(this,
381
new Instances(newTrainingSetInstances, 0));
382
tse.m_setNumber = setNum;
383
tse.m_maxSetNumber = maxNum;
384
notifyTrainingSetAvailable(tse);
385
// fill in predicted values
386
for (int i = 0; i < trainSet.numInstances(); i++) {
388
classifier.classifyInstance(trainSet.instance(i));
389
newTrainingSetInstances.instance(i).setValue(newTrainingSetInstances.numAttributes()-1,
392
tse = new TrainingSetEvent(this,
393
newTrainingSetInstances);
394
tse.m_setNumber = setNum;
395
tse.m_maxSetNumber = maxNum;
396
notifyTrainingSetAvailable(tse);
399
if (m_testSetListeners.size() > 0) {
400
TestSetEvent tse = new TestSetEvent(this,
401
new Instances(newTestSetInstances, 0));
402
tse.m_setNumber = setNum;
403
tse.m_maxSetNumber = maxNum;
404
notifyTestSetAvailable(tse);
406
if (m_dataSourceListeners.size() > 0) {
407
notifyDataSetAvailable(new DataSetEvent(this, new Instances(newTestSetInstances,0)));
409
if (e.getTestSet().isStructureOnly()) {
410
m_format = newTestSetInstances;
412
if (m_dataSourceListeners.size() > 0 || m_testSetListeners.size() > 0) {
413
// fill in predicted values
414
for (int i = 0; i < testSet.numInstances(); i++) {
416
classifier.classifyInstance(testSet.instance(i));
417
newTestSetInstances.instance(i).setValue(newTestSetInstances.numAttributes()-1,
422
if (m_testSetListeners.size() > 0) {
423
TestSetEvent tse = new TestSetEvent(this, newTestSetInstances);
424
tse.m_setNumber = setNum;
425
tse.m_maxSetNumber = maxNum;
426
notifyTestSetAvailable(tse);
428
if (m_dataSourceListeners.size() > 0) {
429
notifyDataSetAvailable(new DataSetEvent(this, newTestSetInstances));
432
} catch (Exception ex) {
433
ex.printStackTrace();
436
if (m_appendProbabilities) {
438
Instances newTestSetInstances =
439
makeDataSetProbabilities(testSet,
440
classifier,relationNameModifier);
441
Instances newTrainingSetInstances =
442
makeDataSetProbabilities(trainSet,
443
classifier,relationNameModifier);
444
if (m_trainingSetListeners.size() > 0) {
445
TrainingSetEvent tse = new TrainingSetEvent(this,
446
new Instances(newTrainingSetInstances, 0));
447
tse.m_setNumber = setNum;
448
tse.m_maxSetNumber = maxNum;
449
notifyTrainingSetAvailable(tse);
450
// fill in predicted probabilities
451
for (int i = 0; i < trainSet.numInstances(); i++) {
452
double [] preds = classifier.
453
distributionForInstance(trainSet.instance(i));
454
for (int j = 0; j < trainSet.classAttribute().numValues(); j++) {
455
newTrainingSetInstances.instance(i).setValue(trainSet.numAttributes()+j,
459
tse = new TrainingSetEvent(this,
460
newTrainingSetInstances);
461
tse.m_setNumber = setNum;
462
tse.m_maxSetNumber = maxNum;
463
notifyTrainingSetAvailable(tse);
465
if (m_testSetListeners.size() > 0) {
466
TestSetEvent tse = new TestSetEvent(this,
467
new Instances(newTestSetInstances, 0));
468
tse.m_setNumber = setNum;
469
tse.m_maxSetNumber = maxNum;
470
notifyTestSetAvailable(tse);
472
if (m_dataSourceListeners.size() > 0) {
473
notifyDataSetAvailable(new DataSetEvent(this, new Instances(newTestSetInstances,0)));
475
if (e.getTestSet().isStructureOnly()) {
476
m_format = newTestSetInstances;
478
if (m_dataSourceListeners.size() > 0 || m_testSetListeners.size() > 0) {
479
// fill in predicted probabilities
480
for (int i = 0; i < testSet.numInstances(); i++) {
481
double [] preds = classifier.
482
distributionForInstance(testSet.instance(i));
483
for (int j = 0; j < testSet.classAttribute().numValues(); j++) {
484
newTestSetInstances.instance(i).setValue(testSet.numAttributes()+j,
491
if (m_testSetListeners.size() > 0) {
492
TestSetEvent tse = new TestSetEvent(this, newTestSetInstances);
493
tse.m_setNumber = setNum;
494
tse.m_maxSetNumber = maxNum;
495
notifyTestSetAvailable(tse);
497
if (m_dataSourceListeners.size() > 0) {
498
notifyDataSetAvailable(new DataSetEvent(this, newTestSetInstances));
500
} catch (Exception ex) {
501
ex.printStackTrace();
509
* Accept and process a batch classifier event
511
* @param e a <code>BatchClassifierEvent</code> value
513
public void acceptClusterer(BatchClustererEvent e) {
514
if (m_dataSourceListeners.size() > 0) {
515
if(e.getTestSet().isStructureOnly())
517
Instances testSet = e.getTestSet().getDataSet();
518
weka.clusterers.Clusterer clusterer = e.getClusterer();
520
if(e.getTestOrTrain()==0)
524
String relationNameModifier = "_"+test+"_"+e.getSetNumber()+"_of_"
525
+e.getMaxSetNumber();
526
if (!m_appendProbabilities || !(clusterer instanceof DensityBasedClusterer)) {
527
if(m_appendProbabilities && !(clusterer instanceof DensityBasedClusterer)){
528
System.err.println("Only density based clusterers can append probabilities. Instead cluster will be assigned for each instance.");
529
if (m_logger != null) {
530
m_logger.logMessage("Only density based clusterers can append probabilities. Instead cluster will be assigned for each instance.");
534
Instances newInstances = makeClusterDataSetClass(testSet, clusterer,
535
relationNameModifier);
536
notifyDataSetAvailable(new DataSetEvent(this, new Instances(newInstances,0)));
538
// fill in predicted values
539
for (int i = 0; i < testSet.numInstances(); i++) {
541
clusterer.clusterInstance(testSet.instance(i));
542
newInstances.instance(i).setValue(newInstances.numAttributes()-1,
546
notifyDataSetAvailable(new DataSetEvent(this, newInstances));
548
} catch (Exception ex) {
549
ex.printStackTrace();
554
Instances newInstances =
555
makeClusterDataSetProbabilities(testSet,
556
clusterer,relationNameModifier);
557
notifyDataSetAvailable(new DataSetEvent(this, new Instances(newInstances,0)));
559
// fill in predicted probabilities
560
for (int i = 0; i < testSet.numInstances(); i++) {
561
double [] probs = clusterer.
562
distributionForInstance(testSet.instance(i));
563
for (int j = 0; j < clusterer.numberOfClusters(); j++) {
564
newInstances.instance(i).setValue(testSet.numAttributes()+j,
569
notifyDataSetAvailable(new DataSetEvent(this, newInstances));
570
} catch (Exception ex) {
571
ex.printStackTrace();
578
makeDataSetProbabilities(Instances format,
579
weka.classifiers.Classifier classifier,
580
String relationNameModifier)
582
String classifierName = classifier.getClass().getName();
583
classifierName = classifierName.
584
substring(classifierName.lastIndexOf('.')+1, classifierName.length());
585
int numOrigAtts = format.numAttributes();
586
Instances newInstances = new Instances(format);
587
for (int i = 0; i < format.classAttribute().numValues(); i++) {
588
weka.filters.unsupervised.attribute.Add addF = new
589
weka.filters.unsupervised.attribute.Add();
590
addF.setAttributeIndex("last");
591
addF.setAttributeName(classifierName+"_prob_"+format.classAttribute().value(i));
592
addF.setInputFormat(newInstances);
593
newInstances = weka.filters.Filter.useFilter(newInstances, addF);
595
newInstances.setRelationName(format.relationName()+relationNameModifier);
599
private Instances makeDataSetClass(Instances format,
600
weka.classifiers.Classifier classifier,
601
String relationNameModifier)
604
weka.filters.unsupervised.attribute.Add addF = new
605
weka.filters.unsupervised.attribute.Add();
606
addF.setAttributeIndex("last");
607
String classifierName = classifier.getClass().getName();
608
classifierName = classifierName.
609
substring(classifierName.lastIndexOf('.')+1, classifierName.length());
610
addF.setAttributeName("class_predicted_by: "+classifierName);
611
if (format.classAttribute().isNominal()) {
612
String classLabels = "";
613
Enumeration enu = format.classAttribute().enumerateValues();
614
classLabels += (String)enu.nextElement();
615
while (enu.hasMoreElements()) {
616
classLabels += ","+(String)enu.nextElement();
618
addF.setNominalLabels(classLabels);
620
addF.setInputFormat(format);
623
Instances newInstances =
624
weka.filters.Filter.useFilter(format, addF);
625
newInstances.setRelationName(format.relationName()+relationNameModifier);
630
makeClusterDataSetProbabilities(Instances format,
631
weka.clusterers.Clusterer clusterer,
632
String relationNameModifier)
634
int numOrigAtts = format.numAttributes();
635
Instances newInstances = new Instances(format);
636
for (int i = 0; i < clusterer.numberOfClusters(); i++) {
637
weka.filters.unsupervised.attribute.Add addF = new
638
weka.filters.unsupervised.attribute.Add();
639
addF.setAttributeIndex("last");
640
addF.setAttributeName("prob_cluster"+i);
641
addF.setInputFormat(newInstances);
642
newInstances = weka.filters.Filter.useFilter(newInstances, addF);
644
newInstances.setRelationName(format.relationName()+relationNameModifier);
648
private Instances makeClusterDataSetClass(Instances format,
649
weka.clusterers.Clusterer clusterer,
650
String relationNameModifier)
653
weka.filters.unsupervised.attribute.Add addF = new
654
weka.filters.unsupervised.attribute.Add();
655
addF.setAttributeIndex("last");
656
String clustererName = clusterer.getClass().getName();
657
clustererName = clustererName.
658
substring(clustererName.lastIndexOf('.')+1, clustererName.length());
659
addF.setAttributeName("assigned_cluster: "+clustererName);
660
//if (format.classAttribute().isNominal()) {
661
String clusterLabels = "0";
662
/*Enumeration enu = format.classAttribute().enumerateValues();
663
clusterLabels += (String)enu.nextElement();
664
while (enu.hasMoreElements()) {
665
clusterLabels += ","+(String)enu.nextElement();
667
for(int i = 1; i <= clusterer.numberOfClusters()-1; i++)
668
clusterLabels += ","+i;
669
addF.setNominalLabels(clusterLabels);
671
addF.setInputFormat(format);
674
Instances newInstances =
675
weka.filters.Filter.useFilter(format, addF);
676
newInstances.setRelationName(format.relationName()+relationNameModifier);
681
* Notify all instance listeners that an instance is available
683
* @param e an <code>InstanceEvent</code> value
685
protected void notifyInstanceAvailable(InstanceEvent e) {
687
synchronized (this) {
688
l = (Vector)m_instanceListeners.clone();
692
for(int i = 0; i < l.size(); i++) {
693
((InstanceListener)l.elementAt(i)).acceptInstance(e);
699
* Notify all Data source listeners that a data set is available
701
* @param e a <code>DataSetEvent</code> value
703
protected void notifyDataSetAvailable(DataSetEvent e) {
705
synchronized (this) {
706
l = (Vector)m_dataSourceListeners.clone();
710
for(int i = 0; i < l.size(); i++) {
711
((DataSourceListener)l.elementAt(i)).acceptDataSet(e);
717
* Notify all test set listeners that a test set is available
719
* @param e a <code>TestSetEvent</code> value
721
protected void notifyTestSetAvailable(TestSetEvent e) {
723
synchronized (this) {
724
l = (Vector)m_testSetListeners.clone();
728
for(int i = 0; i < l.size(); i++) {
729
((TestSetListener)l.elementAt(i)).acceptTestSet(e);
735
* Notify all test set listeners that a test set is available
737
* @param e a <code>TestSetEvent</code> value
739
protected void notifyTrainingSetAvailable(TrainingSetEvent e) {
741
synchronized (this) {
742
l = (Vector)m_trainingSetListeners.clone();
746
for(int i = 0; i < l.size(); i++) {
747
((TrainingSetListener)l.elementAt(i)).acceptTrainingSet(e);
755
* @param logger a <code>weka.gui.Logger</code> value
757
public void setLog(weka.gui.Logger logger) {
762
// cant really do anything meaningful here
766
* Returns true if, at this time,
767
* the object will accept a connection according to the supplied
770
* @param eventName the event
771
* @return true if the object will accept a connection
773
public boolean connectionAllowed(String eventName) {
774
return (m_listenee == null);
778
* Returns true if, at this time,
779
* the object will accept a connection according to the supplied
782
* @param esd the EventSetDescriptor
783
* @return true if the object will accept a connection
785
public boolean connectionAllowed(EventSetDescriptor esd) {
786
return connectionAllowed(esd.getName());
790
* Notify this object that it has been registered as a listener with
791
* a source with respect to the supplied event name
794
* @param source the source with which this object has been registered as
797
public synchronized void connectionNotification(String eventName,
799
if (connectionAllowed(eventName)) {
805
* Notify this object that it has been deregistered as a listener with
806
* a source with respect to the supplied event name
808
* @param eventName the event name
809
* @param source the source with which this object has been registered as
812
public synchronized void disconnectionNotification(String eventName,
814
if (m_listenee == source) {
816
m_format = null; // assume any calculated instance format if now invalid
821
* Returns true, if at the current time, the named event could
822
* be generated. Assumes that supplied event names are names of
823
* events that could be generated by this bean.
825
* @param eventName the name of the event in question
826
* @return true if the named event could be generated at this point in
829
public boolean eventGeneratable(String eventName) {
830
if (m_listenee == null) {
834
if (m_listenee instanceof EventConstraints) {
835
if (eventName.equals("instance")) {
836
if (!((EventConstraints)m_listenee).
837
eventGeneratable("incrementalClassifier")) {
841
if (eventName.equals("dataSet")
842
|| eventName.equals("trainingSet")
843
|| eventName.equals("testSet")) {
844
if (((EventConstraints)m_listenee).
845
eventGeneratable("batchClassifier")) {
848
if (((EventConstraints)m_listenee).eventGeneratable("batchClusterer")) {