2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* ThresholdVisualizePanel.java
19
* Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
23
package weka.gui.visualize;
25
import weka.classifiers.Classifier;
26
import weka.classifiers.evaluation.EvaluationUtils;
27
import weka.classifiers.evaluation.ThresholdCurve;
28
import weka.core.FastVector;
29
import weka.core.Instances;
30
import weka.core.SingleIndex;
31
import weka.core.Utils;
33
import java.awt.BorderLayout;
34
import java.awt.event.ActionEvent;
35
import java.awt.event.ActionListener;
36
import java.awt.event.WindowAdapter;
37
import java.awt.event.WindowEvent;
38
import java.io.BufferedReader;
39
import java.io.FileReader;
41
import javax.swing.BorderFactory;
42
import javax.swing.JFrame;
43
import javax.swing.border.TitledBorder;
46
* This panel is a VisualizePanel, with the added ability to display the
47
* area under the ROC curve if an ROC curve is chosen.
49
* @author Dale Fletcher (dale@cs.waikato.ac.nz)
50
* @author FracPete (fracpete at waikato dot ac dot nz)
51
* @version $Revision: 1.5 $
53
public class ThresholdVisualizePanel
54
extends VisualizePanel {
56
/** for serialization */
57
private static final long serialVersionUID = 3070002211779443890L;
59
/** The string to add to the Plot Border. */
60
private String m_ROCString="";
62
/** Original border text */
63
private String m_savePanelBorderText;
68
public ThresholdVisualizePanel() {
71
// Save the current border text
72
TitledBorder tb=(TitledBorder) m_plotSurround.getBorder();
73
m_savePanelBorderText = tb.getTitle();
77
* Set the string with ROC area
78
* @param str ROC area string to add to border
80
public void setROCString(String str) {
85
* This extracts the ROC area string
86
* @return ROC area string
88
public String getROCString() {
93
* This overrides VisualizePanel's setUpComboBoxes to add
94
* ActionListeners to watch for when the X/Y Axis comboboxes
96
* @param inst a set of instances with data for plotting
98
public void setUpComboBoxes(Instances inst) {
99
super.setUpComboBoxes(inst);
101
m_XCombo.addActionListener(new ActionListener() {
102
public void actionPerformed(ActionEvent e) {
106
m_YCombo.addActionListener(new ActionListener() {
107
public void actionPerformed(ActionEvent e) {
112
// Just in case the default is ROC
117
* This checks the current selected X/Y Axis comboBoxes to see if
118
* an ROC graph is selected. If so, add the ROC area string to the
119
* plot border, otherwise display the original border text.
121
private void setBorderText() {
123
String xs = m_XCombo.getSelectedItem().toString();
124
String ys = m_YCombo.getSelectedItem().toString();
126
if (xs.equals("X: False Positive Rate (Num)") && ys.equals("Y: True Positive Rate (Num)")) {
127
m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText+" "+m_ROCString)));
129
m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText)));
133
* displays the previously saved instances
135
* @param insts the instances to display
136
* @throws Exception if display is not possible
138
protected void openVisibleInstances(Instances insts) throws Exception {
139
super.openVisibleInstances(insts);
143
+ Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");
149
* Starts the ThresholdVisualizePanel with parameters from the command line. <p/>
151
* Valid options are: <p/>
153
* lists all the commandline parameters <p/>
156
* Dataset to process with given classifier. <p/>
159
* Full classname of classifier to run.<br/>
160
* Options after '--' are passed to the classifier. <br/>
161
* (default weka.classifiers.functions.Logistic) <p/>
164
* The number of runs to perform (default 2). <p/>
167
* The number of Cross-validation folds (default 10). <p/>
170
* Previously saved threshold curve ARFF file. <p/>
172
* @param args optional commandline parameters
174
public static void main(String [] args) {
176
Classifier classifier;
183
SingleIndex classIndex;
184
SingleIndex valueIndex;
199
if (Utils.getFlag('h', args)) {
200
System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n");
201
System.out.println("-h\n\tThis help.");
202
System.out.println("-t <file>\n\tDataset to process with given classifier.");
203
System.out.println("-c <num>\n\tThe class index. first and last are valid, too (default: last).");
204
System.out.println("-C <num>\n\tThe index of the class value to get the the curve for (default: first).");
205
System.out.println("-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)");
206
System.out.println("-r <number>\n\tThe number of runs to perform (default: 1).");
207
System.out.println("-x <number>\n\tThe number of Cross-validation folds (default: 10).");
208
System.out.println("-S <number>\n\tThe seed value for randomizing the data (default: 1).");
209
System.out.println("-l <file>\n\tPreviously saved threshold curve ARFF file.");
214
tmpStr = Utils.getOption('l', args);
215
if (tmpStr.length() != 0) {
216
result = new Instances(new BufferedReader(new FileReader(tmpStr)));
221
tmpStr = Utils.getOption('r', args);
222
if (tmpStr.length() != 0)
223
runs = Integer.parseInt(tmpStr);
227
tmpStr = Utils.getOption('x', args);
228
if (tmpStr.length() != 0)
229
folds = Integer.parseInt(tmpStr);
233
tmpStr = Utils.getOption('S', args);
234
if (tmpStr.length() != 0)
235
seed = Integer.parseInt(tmpStr);
239
tmpStr = Utils.getOption('t', args);
240
if (tmpStr.length() != 0) {
241
inst = new Instances(new BufferedReader(new FileReader(tmpStr)));
242
inst.setClassIndex(inst.numAttributes() - 1);
245
tmpStr = Utils.getOption('W', args);
246
if (tmpStr.length() != 0) {
247
options = Utils.partitionOptions(args);
250
tmpStr = weka.classifiers.functions.Logistic.class.getName();
251
options = new String[0];
253
classifier = Classifier.forName(tmpStr, options);
255
tmpStr = Utils.getOption('c', args);
256
if (tmpStr.length() != 0)
257
classIndex = new SingleIndex(tmpStr);
259
classIndex = new SingleIndex("last");
261
tmpStr = Utils.getOption('C', args);
262
if (tmpStr.length() != 0)
263
valueIndex = new SingleIndex(tmpStr);
265
valueIndex = new SingleIndex("first");
268
// compute if necessary
270
if (classIndex != null) {
271
classIndex.setUpper(inst.numAttributes() - 1);
272
inst.setClassIndex(classIndex.getIndex());
275
inst.setClassIndex(inst.numAttributes() - 1);
278
if (valueIndex != null) {
279
valueIndex.setUpper(inst.classAttribute().numValues() - 1);
282
ThresholdCurve tc = new ThresholdCurve();
283
EvaluationUtils eu = new EvaluationUtils();
284
FastVector predictions = new FastVector();
285
for (int i = 0; i < runs; i++) {
286
eu.setSeed(seed + i);
287
predictions.appendElements(eu.getCVPredictions(classifier, inst, folds));
290
if (valueIndex != null)
291
result = tc.getCurve(predictions, valueIndex.getIndex());
293
result = tc.getCurve(predictions);
297
ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
298
vmc.setROCString("(Area under ROC = " +
299
Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
302
result.relationName()
303
+ ". (Class value " + inst.classAttribute().value(valueIndex.getIndex()) + ")");
306
result.relationName()
307
+ " (display only)");
308
PlotData2D tempd = new PlotData2D(result);
309
tempd.setPlotName(result.relationName());
310
tempd.addInstanceNumberAttribute();
313
String plotName = vmc.getName();
314
final JFrame jf = new JFrame("Weka Classifier Visualize: "+plotName);
316
jf.getContentPane().setLayout(new BorderLayout());
318
jf.getContentPane().add(vmc, BorderLayout.CENTER);
319
jf.addWindowListener(new WindowAdapter() {
320
public void windowClosing(WindowEvent e) {
327
catch (Exception e) {