~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/gui/visualize/ThresholdVisualizePanel.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    ThresholdVisualizePanel.java
 
19
 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.gui.visualize;
 
24
 
 
25
import weka.classifiers.Classifier;
 
26
import weka.classifiers.evaluation.EvaluationUtils;
 
27
import weka.classifiers.evaluation.ThresholdCurve;
 
28
import weka.core.FastVector;
 
29
import weka.core.Instances;
 
30
import weka.core.SingleIndex;
 
31
import weka.core.Utils;
 
32
 
 
33
import java.awt.BorderLayout;
 
34
import java.awt.event.ActionEvent;
 
35
import java.awt.event.ActionListener;
 
36
import java.awt.event.WindowAdapter;
 
37
import java.awt.event.WindowEvent;
 
38
import java.io.BufferedReader;
 
39
import java.io.FileReader;
 
40
 
 
41
import javax.swing.BorderFactory;
 
42
import javax.swing.JFrame;
 
43
import javax.swing.border.TitledBorder;
 
44
 
 
45
/** 
 
46
 * This panel is a VisualizePanel, with the added ablility to display the
 
47
 * area under the ROC curve if an ROC curve is chosen.
 
48
 *
 
49
 * @author Dale Fletcher (dale@cs.waikato.ac.nz)
 
50
 * @author FracPete (fracpete at waikato dot ac dot nz)
 
51
 * @version $Revision: 1.5 $
 
52
 */
 
53
public class ThresholdVisualizePanel 
 
54
  extends VisualizePanel {
 
55
 
 
56
  /** for serialization */
 
57
  private static final long serialVersionUID = 3070002211779443890L;
 
58
 
 
59
  /** The string to add to the Plot Border. */
 
60
  private String m_ROCString="";
 
61
 
 
62
  /** Original border text */
 
63
  private String m_savePanelBorderText;
 
64
 
 
65
  /**
 
66
   * default constructor
 
67
   */
 
68
  public ThresholdVisualizePanel() {
 
69
    super();
 
70
 
 
71
    // Save the current border text
 
72
    TitledBorder tb=(TitledBorder) m_plotSurround.getBorder();
 
73
    m_savePanelBorderText = tb.getTitle();
 
74
  }
 
75
  
 
76
  /**
 
77
   * Set the string with ROC area
 
78
   * @param str ROC area string to add to border
 
79
   */  
 
80
  public void setROCString(String str) {
 
81
    m_ROCString=str;
 
82
  }
 
83
 
 
84
  /**
 
85
   * This extracts the ROC area string 
 
86
   * @return ROC area string 
 
87
   */
 
88
  public String getROCString() {
 
89
    return m_ROCString;
 
90
  }
 
91
 
 
92
  /**
 
93
   * This overloads VisualizePanel's setUpComboBoxes to add 
 
94
   * ActionListeners to watch for when the X/Y Axis comboboxes
 
95
   * are changed. 
 
96
   * @param inst a set of instances with data for plotting
 
97
   */
 
98
  public void setUpComboBoxes(Instances inst) {
 
99
    super.setUpComboBoxes(inst);
 
100
 
 
101
    m_XCombo.addActionListener(new ActionListener() {
 
102
        public void actionPerformed(ActionEvent e) {
 
103
          setBorderText();
 
104
        }
 
105
    });
 
106
    m_YCombo.addActionListener(new ActionListener() {
 
107
        public void actionPerformed(ActionEvent e) {
 
108
          setBorderText();
 
109
        }
 
110
    });
 
111
 
 
112
    // Just in case the default is ROC
 
113
    setBorderText();
 
114
  }
 
115
 
 
116
  /**
 
117
   * This checks the current selected X/Y Axis comboBoxes to see if 
 
118
   * an ROC graph is selected. If so, add the ROC area string to the
 
119
   * plot border, otherwise display the original border text.
 
120
   */
 
121
  private void setBorderText() {
 
122
 
 
123
    String xs = m_XCombo.getSelectedItem().toString();
 
124
    String ys = m_YCombo.getSelectedItem().toString();
 
125
 
 
126
    if (xs.equals("X: False Positive Rate (Num)") && ys.equals("Y: True Positive Rate (Num)"))   {
 
127
        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText+" "+m_ROCString)));
 
128
    } else
 
129
        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText))); 
 
130
  }
 
131
 
 
132
  /**
 
133
   * displays the previously saved instances
 
134
   * 
 
135
   * @param insts       the instances to display
 
136
   * @throws Exception  if display is not possible
 
137
   */
 
138
  protected void openVisibleInstances(Instances insts) throws Exception {
 
139
    super.openVisibleInstances(insts);
 
140
 
 
141
    setROCString(
 
142
        "(Area under ROC = " 
 
143
        + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");
 
144
    
 
145
    setBorderText();
 
146
  }
 
147
  
 
148
  /**
 
149
   * Starts the ThresholdVisualizationPanel with parameters from the command line. <p/>
 
150
   * 
 
151
   * Valid options are: <p/>
 
152
   *  -h <br/>
 
153
   *  lists all the commandline parameters <p/>
 
154
   *  
 
155
   *  -t file <br/>
 
156
   *  Dataset to process with given classifier. <p/>
 
157
   *  
 
158
   *  -W classname <br/>
 
159
   *  Full classname of classifier to run.<br/>
 
160
   *  Options after '--' are passed to the classifier. <br/>
 
161
   *  (default weka.classifiers.functions.Logistic) <p/>
 
162
   *  
 
163
   *  -r number <br/>
 
164
   *  The number of runs to perform (default 2). <p/>
 
165
   *  
 
166
   *  -x number <br/>
 
167
   *  The number of Cross-validation folds (default 10). <p/>
 
168
   *  
 
169
   *  -l file <br/>
 
170
   *  Previously saved threshold curve ARFF file. <p/>
 
171
   *
 
172
   * @param args optional commandline parameters
 
173
   */
 
174
  public static void main(String [] args) {
 
175
    Instances           inst;
 
176
    Classifier          classifier;
 
177
    int                 runs;
 
178
    int                 folds;
 
179
    String              tmpStr;
 
180
    boolean             compute;
 
181
    Instances           result;
 
182
    String[]            options;
 
183
    SingleIndex         classIndex;
 
184
    SingleIndex         valueIndex;
 
185
    int                 seed;
 
186
    
 
187
    inst       = null;
 
188
    classifier = null;
 
189
    runs       = 2;
 
190
    folds      = 10;
 
191
    compute    = true;
 
192
    result     = null;
 
193
    classIndex = null;
 
194
    valueIndex = null;
 
195
    seed       = 1;
 
196
    
 
197
    try {
 
198
      // help?
 
199
      if (Utils.getFlag('h', args)) {
 
200
        System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n");
 
201
        System.out.println("-h\n\tThis help.");
 
202
        System.out.println("-t <file>\n\tDataset to process with given classifier.");
 
203
        System.out.println("-c <num>\n\tThe class index. first and last are valid, too (default: last).");
 
204
        System.out.println("-C <num>\n\tThe index of the class value to get the the curve for (default: first).");
 
205
        System.out.println("-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)");
 
206
        System.out.println("-r <number>\n\tThe number of runs to perform (default: 1).");
 
207
        System.out.println("-x <number>\n\tThe number of Cross-validation folds (default: 10).");
 
208
        System.out.println("-S <number>\n\tThe seed value for randomizing the data (default: 1).");
 
209
        System.out.println("-l <file>\n\tPreviously saved threshold curve ARFF file.");
 
210
        return;
 
211
      }
 
212
      
 
213
      // regular options
 
214
      tmpStr = Utils.getOption('l', args);
 
215
      if (tmpStr.length() != 0) {
 
216
        result = new Instances(new BufferedReader(new FileReader(tmpStr)));
 
217
        compute = false;
 
218
      }
 
219
      
 
220
      if (compute) {
 
221
        tmpStr = Utils.getOption('r', args);
 
222
        if (tmpStr.length() != 0)
 
223
          runs = Integer.parseInt(tmpStr);
 
224
        else
 
225
          runs = 1;
 
226
        
 
227
        tmpStr = Utils.getOption('x', args);
 
228
        if (tmpStr.length() != 0)
 
229
          folds = Integer.parseInt(tmpStr);
 
230
        else
 
231
          folds = 10;
 
232
        
 
233
        tmpStr = Utils.getOption('S', args);
 
234
        if (tmpStr.length() != 0)
 
235
          seed = Integer.parseInt(tmpStr);
 
236
        else
 
237
          seed = 1;
 
238
        
 
239
        tmpStr = Utils.getOption('t', args);
 
240
        if (tmpStr.length() != 0) {
 
241
          inst = new Instances(new BufferedReader(new FileReader(tmpStr)));
 
242
          inst.setClassIndex(inst.numAttributes() - 1);
 
243
        }
 
244
        
 
245
        tmpStr = Utils.getOption('W', args);
 
246
        if (tmpStr.length() != 0) {
 
247
          options = Utils.partitionOptions(args);
 
248
        }
 
249
        else {
 
250
          tmpStr = weka.classifiers.functions.Logistic.class.getName();
 
251
          options = new String[0];
 
252
        }
 
253
        classifier = Classifier.forName(tmpStr, options);
 
254
        
 
255
        tmpStr = Utils.getOption('c', args);
 
256
        if (tmpStr.length() != 0)
 
257
          classIndex = new SingleIndex(tmpStr);
 
258
        else
 
259
          classIndex = new SingleIndex("last");
 
260
        
 
261
        tmpStr = Utils.getOption('C', args);
 
262
        if (tmpStr.length() != 0)
 
263
          valueIndex = new SingleIndex(tmpStr);
 
264
        else
 
265
          valueIndex = new SingleIndex("first");
 
266
      }
 
267
      
 
268
      // compute if necessary
 
269
      if (compute) {
 
270
        if (classIndex != null) {
 
271
          classIndex.setUpper(inst.numAttributes() - 1);
 
272
          inst.setClassIndex(classIndex.getIndex());
 
273
        }
 
274
        else {
 
275
          inst.setClassIndex(inst.numAttributes() - 1);
 
276
        }
 
277
        
 
278
        if (valueIndex != null) {
 
279
          valueIndex.setUpper(inst.classAttribute().numValues() - 1);
 
280
        }
 
281
        
 
282
        ThresholdCurve tc = new ThresholdCurve();
 
283
        EvaluationUtils eu = new EvaluationUtils();
 
284
        FastVector predictions = new FastVector();
 
285
        for (int i = 0; i < runs; i++) {
 
286
          eu.setSeed(seed + i);
 
287
          predictions.appendElements(eu.getCVPredictions(classifier, inst, folds));
 
288
        }
 
289
        
 
290
        if (valueIndex != null)
 
291
          result = tc.getCurve(predictions, valueIndex.getIndex());
 
292
        else
 
293
          result = tc.getCurve(predictions);
 
294
      }
 
295
      
 
296
      // setup GUI
 
297
      ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
 
298
      vmc.setROCString("(Area under ROC = " + 
 
299
          Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
 
300
      if (compute)     
 
301
        vmc.setName(
 
302
            result.relationName() 
 
303
            + ". (Class value " + inst.classAttribute().value(valueIndex.getIndex()) + ")");
 
304
      else
 
305
        vmc.setName(
 
306
            result.relationName()
 
307
            + " (display only)");
 
308
      PlotData2D tempd = new PlotData2D(result);
 
309
      tempd.setPlotName(result.relationName());
 
310
      tempd.addInstanceNumberAttribute();
 
311
      vmc.addPlot(tempd);
 
312
      
 
313
      String plotName = vmc.getName(); 
 
314
      final JFrame jf = new JFrame("Weka Classifier Visualize: "+plotName);
 
315
      jf.setSize(500,400);
 
316
      jf.getContentPane().setLayout(new BorderLayout());
 
317
      
 
318
      jf.getContentPane().add(vmc, BorderLayout.CENTER);
 
319
      jf.addWindowListener(new WindowAdapter() {
 
320
        public void windowClosing(WindowEvent e) {
 
321
          jf.dispose();
 
322
        }
 
323
      });
 
324
      
 
325
      jf.setVisible(true);
 
326
    }
 
327
    catch (Exception e) {
 
328
      e.printStackTrace();
 
329
    }
 
330
  }
 
331
}
 
332
 
 
333
 
 
334
 
 
335
 
 
336
 
 
337
 
 
338
 
 
339