~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/mi/MISMO.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * MISMO.java
 
19
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers.mi;
 
24
 
 
25
import weka.classifiers.Classifier;
 
26
import weka.classifiers.functions.Logistic;
 
27
import weka.classifiers.functions.supportVector.Kernel;
 
28
import weka.classifiers.functions.supportVector.SMOset;
 
29
import weka.classifiers.mi.supportVector.MIPolyKernel;
 
30
import weka.core.Attribute;
 
31
import weka.core.Capabilities;
 
32
import weka.core.FastVector;
 
33
import weka.core.Instance;
 
34
import weka.core.Instances;
 
35
import weka.core.MultiInstanceCapabilitiesHandler;
 
36
import weka.core.Option;
 
37
import weka.core.OptionHandler;
 
38
import weka.core.SelectedTag;
 
39
import weka.core.SerializedObject;
 
40
import weka.core.Tag;
 
41
import weka.core.TechnicalInformation;
 
42
import weka.core.TechnicalInformationHandler;
 
43
import weka.core.Utils;
 
44
import weka.core.WeightedInstancesHandler;
 
45
import weka.core.Capabilities.Capability;
 
46
import weka.core.TechnicalInformation.Field;
 
47
import weka.core.TechnicalInformation.Type;
 
48
import weka.filters.Filter;
 
49
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
 
50
import weka.filters.unsupervised.attribute.NominalToBinary;
 
51
import weka.filters.unsupervised.attribute.Normalize;
 
52
import weka.filters.unsupervised.attribute.PropositionalToMultiInstance;
 
53
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
 
54
import weka.filters.unsupervised.attribute.Standardize;
 
55
 
 
56
import java.io.Serializable;
 
57
import java.util.Enumeration;
 
58
import java.util.Random;
 
59
import java.util.Vector;
 
60
 
 
61
/**
 
62
 <!-- globalinfo-start -->
 
63
 * Implements John Platt's sequential minimal optimization algorithm for training a support vector classifier.<br/>
 
64
 * <br/>
 
65
 * This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes by default. (In that case the coefficients in the output are based on the normalized data, not the original data --- this is important for interpreting the classifier.)<br/>
 
66
 * <br/>
 
67
 * Multi-class problems are solved using pairwise classification.<br/>
 
68
 * <br/>
 
69
 * To obtain proper probability estimates, use the option that fits logistic regression models to the outputs of the support vector machine. In the multi-class case the predicted probabilities are coupled using Hastie and Tibshirani's pairwise coupling method.<br/>
 
70
 * <br/>
 
71
 * Note: for improved speed normalization should be turned off when operating on SparseInstances.<br/>
 
72
 * <br/>
 
73
 * For more information on the SMO algorithm, see<br/>
 
74
 * <br/>
 
75
 * J. Platt: Machines using Sequential Minimal Optimization. In B. Schoelkopf and C. Burges and A. Smola, editors, Advances in Kernel Methods - Support Vector Learning, 1998.<br/>
 
76
 * <br/>
 
77
 * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy (2001). Improvements to Platt's SMO Algorithm for SVM Classifier Design. Neural Computation. 13(3):637-649.
 
78
 * <p/>
 
79
 <!-- globalinfo-end -->
 
80
 *
 
81
 <!-- technical-bibtex-start -->
 
82
 * BibTeX:
 
83
 * <pre>
 
84
 * &#64;incollection{Platt1998,
 
85
 *    author = {J. Platt},
 
86
 *    booktitle = {Advances in Kernel Methods - Support Vector Learning},
 
87
 *    editor = {B. Schoelkopf and C. Burges and A. Smola},
 
88
 *    publisher = {MIT Press},
 
89
 *    title = {Machines using Sequential Minimal Optimization},
 
90
 *    year = {1998}
 
91
 * }
 
92
 * 
 
93
 * &#64;article{Keerthi2001,
 
94
 *    author = {S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy},
 
95
 *    journal = {Neural Computation},
 
96
 *    number = {3},
 
97
 *    pages = {637-649},
 
98
 *    title = {Improvements to Platt's SMO Algorithm for SVM Classifier Design},
 
99
 *    volume = {13},
 
100
 *    year = {2001}
 
101
 * }
 
102
 * </pre>
 
103
 * <p/>
 
104
 <!-- technical-bibtex-end -->
 
105
 *
 
106
 <!-- options-start -->
 
107
 * Valid options are: <p/>
 
108
 * 
 
109
 * <pre> -D
 
110
 *  If set, classifier is run in debug mode and
 
111
 *  may output additional info to the console</pre>
 
112
 * 
 
113
 * <pre> -no-checks
 
114
 *  Turns off all checks - use with caution!
 
115
 *  Turning them off assumes that data is purely numeric, doesn't
 
116
 *  contain any missing values, and has a nominal class. Turning them
 
117
 *  off also means that no header information will be stored if the
 
118
 *  machine is linear. Finally, it also assumes that no instance has
 
119
 *  a weight equal to 0.
 
120
 *  (default: checks on)</pre>
 
121
 * 
 
122
 * <pre> -C &lt;double&gt;
 
123
 *  The complexity constant C. (default 1)</pre>
 
124
 * 
 
125
 * <pre> -N
 
126
 *  Whether to 0=normalize/1=standardize/2=neither.
 
127
 *  (default 0=normalize)</pre>
 
128
 * 
 
129
 * <pre> -I
 
130
 *  Use MIminimax feature space. </pre>
 
131
 * 
 
132
 * <pre> -L &lt;double&gt;
 
133
 *  The tolerance parameter. (default 1.0e-3)</pre>
 
134
 * 
 
135
 * <pre> -P &lt;double&gt;
 
136
 *  The epsilon for round-off error. (default 1.0e-12)</pre>
 
137
 * 
 
138
 * <pre> -M
 
139
 *  Fit logistic models to SVM outputs. </pre>
 
140
 * 
 
141
 * <pre> -V &lt;double&gt;
 
142
 *  The number of folds for the internal cross-validation. 
 
143
 *  (default -1, use training data)</pre>
 
144
 * 
 
145
 * <pre> -W &lt;double&gt;
 
146
 *  The random number seed. (default 1)</pre>
 
147
 * 
 
148
 * <pre> -K &lt;classname and parameters&gt;
 
149
 *  The Kernel to use.
 
150
 *  (default: weka.classifiers.functions.supportVector.PolyKernel)</pre>
 
151
 * 
 
152
 * <pre> 
 
153
 * Options specific to kernel weka.classifiers.mi.supportVector.MIPolyKernel:
 
154
 * </pre>
 
155
 * 
 
156
 * <pre> -D
 
157
 *  Enables debugging output (if available) to be printed.
 
158
 *  (default: off)</pre>
 
159
 * 
 
160
 * <pre> -no-checks
 
161
 *  Turns off all checks - use with caution!
 
162
 *  (default: checks on)</pre>
 
163
 * 
 
164
 * <pre> -C &lt;num&gt;
 
165
 *  The size of the cache (a prime number), 0 for full cache and 
 
166
 *  -1 to turn it off.
 
167
 *  (default: 250007)</pre>
 
168
 * 
 
169
 * <pre> -E &lt;num&gt;
 
170
 *  The Exponent to use.
 
171
 *  (default: 1.0)</pre>
 
172
 * 
 
173
 * <pre> -L
 
174
 *  Use lower-order terms.
 
175
 *  (default: no)</pre>
 
176
 * 
 
177
 <!-- options-end -->
 
178
 *
 
179
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 
180
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 
181
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 
182
 * @author Lin Dong (ld21@cs.waikato.ac.nz) (code for adapting to MI data)
 
183
 * @version $Revision: 1.5 $ 
 
184
 */
 
185
public class MISMO 
 
186
  extends Classifier 
 
187
  implements WeightedInstancesHandler, MultiInstanceCapabilitiesHandler,
 
188
             TechnicalInformationHandler {
 
189
 
 
190
  /** for serialization */
 
191
  static final long serialVersionUID = -5834036950143719712L;
 
192
  
 
193
  /**
 
194
   * Returns a string describing classifier
 
195
   * @return a description suitable for
 
196
   * displaying in the explorer/experimenter gui
 
197
   */
 
198
  public String globalInfo() {
 
199
 
 
200
    return  "Implements John Platt's sequential minimal optimization "
 
201
      + "algorithm for training a support vector classifier.\n\n"
 
202
      + "This implementation globally replaces all missing values and "
 
203
      + "transforms nominal attributes into binary ones. It also "
 
204
      + "normalizes all attributes by default. (In that case the coefficients "
 
205
      + "in the output are based on the normalized data, not the "
 
206
      + "original data --- this is important for interpreting the classifier.)\n\n"
 
207
      + "Multi-class problems are solved using pairwise classification.\n\n"
 
208
      + "To obtain proper probability estimates, use the option that fits "
 
209
      + "logistic regression models to the outputs of the support vector "
 
210
      + "machine. In the multi-class case the predicted probabilities "
 
211
      + "are coupled using Hastie and Tibshirani's pairwise coupling "
 
212
      + "method.\n\n"
 
213
      + "Note: for improved speed normalization should be turned off when "
 
214
      + "operating on SparseInstances.\n\n"
 
215
      + "For more information on the SMO algorithm, see\n\n"
 
216
      + getTechnicalInformation().toString();
 
217
  }
 
218
 
 
219
  /**
 
220
   * Returns an instance of a TechnicalInformation object, containing 
 
221
   * detailed information about the technical background of this class,
 
222
   * e.g., paper reference or book this class is based on.
 
223
   * 
 
224
   * @return the technical information about this class
 
225
   */
 
226
  public TechnicalInformation getTechnicalInformation() {
 
227
    TechnicalInformation        result;
 
228
    TechnicalInformation        additional;
 
229
    
 
230
    result = new TechnicalInformation(Type.INCOLLECTION);
 
231
    result.setValue(Field.AUTHOR, "J. Platt");
 
232
    result.setValue(Field.YEAR, "1998");
 
233
    result.setValue(Field.TITLE, "Machines using Sequential Minimal Optimization");
 
234
    result.setValue(Field.BOOKTITLE, "Advances in Kernel Methods - Support Vector Learning");
 
235
    result.setValue(Field.EDITOR, "B. Schoelkopf and C. Burges and A. Smola");
 
236
    result.setValue(Field.PUBLISHER, "MIT Press");
 
237
    
 
238
    additional = result.add(Type.ARTICLE);
 
239
    additional.setValue(Field.AUTHOR, "S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy");
 
240
    additional.setValue(Field.YEAR, "2001");
 
241
    additional.setValue(Field.TITLE, "Improvements to Platt's SMO Algorithm for SVM Classifier Design");
 
242
    additional.setValue(Field.JOURNAL, "Neural Computation");
 
243
    additional.setValue(Field.VOLUME, "13");
 
244
    additional.setValue(Field.NUMBER, "3");
 
245
    additional.setValue(Field.PAGES, "637-649");
 
246
    
 
247
    return result;
 
248
  }
 
249
 
 
250
  /**
 
251
   * Class for building a binary support vector machine.
 
252
   */
 
253
  protected class BinaryMISMO 
 
254
    implements Serializable {
 
255
 
 
256
    /** for serialization */
 
257
    static final long serialVersionUID = -7107082483475433531L;
 
258
    
 
259
    /** The Lagrange multipliers. */
 
260
    protected double[] m_alpha;
 
261
 
 
262
    /** The thresholds. */
 
263
    protected double m_b, m_bLow, m_bUp;
 
264
 
 
265
    /** The indices for m_bLow and m_bUp */
 
266
    protected int m_iLow, m_iUp;
 
267
 
 
268
    /** The training data. */
 
269
    protected Instances m_data;
 
270
 
 
271
    /** Weight vector for linear machine. */
 
272
    protected double[] m_weights;
 
273
 
 
274
    /** Variables to hold weight vector in sparse form.
 
275
      (To reduce storage requirements.) */
 
276
    protected double[] m_sparseWeights;
 
277
    protected int[] m_sparseIndices;
 
278
 
 
279
    /** Kernel to use **/
 
280
    protected Kernel m_kernel;
 
281
 
 
282
    /** The transformed class values. */
 
283
    protected double[] m_class;
 
284
 
 
285
    /** The current set of errors for all non-bound examples. */
 
286
    protected double[] m_errors;
 
287
 
 
288
    /* The five different sets used by the algorithm. */
 
289
    /** {i: 0 < m_alpha[i] < C} */
 
290
    protected SMOset m_I0;
 
291
    /** {i: m_class[i] = 1, m_alpha[i] = 0} */
 
292
    protected SMOset m_I1; 
 
293
    /** {i: m_class[i] = -1, m_alpha[i] = C} */
 
294
    protected SMOset m_I2; 
 
295
    /** {i: m_class[i] = 1, m_alpha[i] = C} */
 
296
    protected SMOset m_I3; 
 
297
    /** {i: m_class[i] = -1, m_alpha[i] = 0} */
 
298
    protected SMOset m_I4; 
 
299
 
 
300
    /** The set of support vectors {i: 0 < m_alpha[i]} */
 
301
    protected SMOset m_supportVectors;
 
302
 
 
303
    /** Stores logistic regression model for probability estimate */
 
304
    protected Logistic m_logistic = null;
 
305
 
 
306
    /** Stores the weight of the training instances */
 
307
    protected double m_sumOfWeights = 0;
 
308
 
 
309
    /**
 
310
     * Fits logistic regression model to SVM outputs analogue
 
311
     * to John Platt's method.  
 
312
     *
 
313
     * @param insts the set of training instances
 
314
     * @param cl1 the first class' index
 
315
     * @param cl2 the second class' index
 
316
     * @param numFolds the number of folds for cross-validation
 
317
     * @param random the random number generator for cross-validation
 
318
     * @throws Exception if the sigmoid can't be fit successfully
 
319
     */
 
320
    protected void fitLogistic(Instances insts, int cl1, int cl2,
 
321
        int numFolds, Random random) 
 
322
      throws Exception {
 
323
 
 
324
      // Create header of instances object
 
325
      FastVector atts = new FastVector(2);
 
326
      atts.addElement(new Attribute("pred"));
 
327
      FastVector attVals = new FastVector(2);
 
328
      attVals.addElement(insts.classAttribute().value(cl1));
 
329
      attVals.addElement(insts.classAttribute().value(cl2));
 
330
      atts.addElement(new Attribute("class", attVals));
 
331
      Instances data = new Instances("data", atts, insts.numInstances());
 
332
      data.setClassIndex(1);
 
333
 
 
334
      // Collect data for fitting the logistic model
 
335
      if (numFolds <= 0) {
 
336
 
 
337
        // Use training data
 
338
        for (int j = 0; j < insts.numInstances(); j++) {
 
339
          Instance inst = insts.instance(j);
 
340
          double[] vals = new double[2];
 
341
          vals[0] = SVMOutput(-1, inst);
 
342
          if (inst.classValue() == cl2) {
 
343
            vals[1] = 1;
 
344
          }
 
345
          data.add(new Instance(inst.weight(), vals));
 
346
        }
 
347
      } else {
 
348
 
 
349
        // Check whether number of folds too large
 
350
        if (numFolds > insts.numInstances()) {
 
351
          numFolds = insts.numInstances();
 
352
        }
 
353
 
 
354
        // Make copy of instances because we will shuffle them around
 
355
        insts = new Instances(insts);
 
356
 
 
357
        // Perform three-fold cross-validation to collect
 
358
        // unbiased predictions
 
359
        insts.randomize(random);
 
360
        insts.stratify(numFolds);
 
361
        for (int i = 0; i < numFolds; i++) {
 
362
          Instances train = insts.trainCV(numFolds, i, random);
 
363
          SerializedObject so = new SerializedObject(this);
 
364
          BinaryMISMO smo = (BinaryMISMO)so.getObject();
 
365
          smo.buildClassifier(train, cl1, cl2, false, -1, -1);
 
366
          Instances test = insts.testCV(numFolds, i);
 
367
          for (int j = 0; j < test.numInstances(); j++) {
 
368
            double[] vals = new double[2];
 
369
            vals[0] = smo.SVMOutput(-1, test.instance(j));
 
370
            if (test.instance(j).classValue() == cl2) {
 
371
              vals[1] = 1;
 
372
            }
 
373
            data.add(new Instance(test.instance(j).weight(), vals));
 
374
          }
 
375
        }
 
376
      }
 
377
 
 
378
      // Build logistic regression model
 
379
      m_logistic = new Logistic();
 
380
      m_logistic.buildClassifier(data);
 
381
    }
 
382
    
 
383
    /**
 
384
     * sets the kernel to use
 
385
     * 
 
386
     * @param value     the kernel to use
 
387
     */
 
388
    public void setKernel(Kernel value) {
 
389
      m_kernel = value;
 
390
    }
 
391
    
 
392
    /**
 
393
     * Returns the kernel to use
 
394
     * 
 
395
     * @return          the current kernel
 
396
     */
 
397
    public Kernel getKernel() {
 
398
      return m_kernel;
 
399
    }
 
400
 
 
401
    /**
 
402
     * Method for building the binary classifier.
 
403
     *
 
404
     * @param insts the set of training instances
 
405
     * @param cl1 the first class' index
 
406
     * @param cl2 the second class' index
 
407
     * @param fitLogistic true if logistic model is to be fit
 
408
     * @param numFolds number of folds for internal cross-validation
 
409
     * @param randomSeed seed value for random number generator for cross-validation
 
410
     * @throws Exception if the classifier can't be built successfully
 
411
     */
 
412
    protected void buildClassifier(Instances insts, int cl1, int cl2,
 
413
        boolean fitLogistic, int numFolds,
 
414
        int randomSeed) throws Exception {
 
415
 
 
416
      // Initialize some variables
 
417
      m_bUp = -1; m_bLow = 1; m_b = 0; 
 
418
      m_alpha = null; m_data = null; m_weights = null; m_errors = null;
 
419
      m_logistic = null; m_I0 = null; m_I1 = null; m_I2 = null;
 
420
      m_I3 = null; m_I4 = null; m_sparseWeights = null; m_sparseIndices = null;
 
421
 
 
422
      // Store the sum of weights
 
423
      m_sumOfWeights = insts.sumOfWeights();
 
424
 
 
425
      // Set class values
 
426
      m_class = new double[insts.numInstances()];
 
427
      m_iUp = -1; m_iLow = -1;
 
428
      for (int i = 0; i < m_class.length; i++) {
 
429
        if ((int) insts.instance(i).classValue() == cl1) {
 
430
          m_class[i] = -1; m_iLow = i;
 
431
        } else if ((int) insts.instance(i).classValue() == cl2) {
 
432
          m_class[i] = 1; m_iUp = i;
 
433
        } else {
 
434
          throw new Exception ("This should never happen!");
 
435
        }
 
436
      }
 
437
 
 
438
      // Check whether one or both classes are missing
 
439
      if ((m_iUp == -1) || (m_iLow == -1)) {
 
440
        if (m_iUp != -1) {
 
441
          m_b = -1;
 
442
        } else if (m_iLow != -1) {
 
443
          m_b = 1;
 
444
        } else {
 
445
          m_class = null;
 
446
          return;
 
447
        }
 
448
        m_supportVectors = new SMOset(0);
 
449
        m_alpha = new double[0];
 
450
        m_class = new double[0];
 
451
 
 
452
        // Fit sigmoid if requested
 
453
        if (fitLogistic) {
 
454
          fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
 
455
        }
 
456
        return;
 
457
      }
 
458
 
 
459
      // Set the reference to the data
 
460
      m_data = insts;
 
461
      m_weights = null;
 
462
 
 
463
      // Initialize alpha array to zero
 
464
      m_alpha = new double[m_data.numInstances()];
 
465
 
 
466
      // Initialize sets
 
467
      m_supportVectors = new SMOset(m_data.numInstances());
 
468
      m_I0 = new SMOset(m_data.numInstances());
 
469
      m_I1 = new SMOset(m_data.numInstances());
 
470
      m_I2 = new SMOset(m_data.numInstances());
 
471
      m_I3 = new SMOset(m_data.numInstances());
 
472
      m_I4 = new SMOset(m_data.numInstances());
 
473
 
 
474
      // Clean out some instance variables
 
475
      m_sparseWeights = null;
 
476
      m_sparseIndices = null;
 
477
 
 
478
      // Initialize error cache
 
479
      m_errors = new double[m_data.numInstances()];
 
480
      m_errors[m_iLow] = 1; m_errors[m_iUp] = -1;
 
481
 
 
482
      // Initialize kernel
 
483
      m_kernel.buildKernel(m_data);
 
484
 
 
485
      // Build up I1 and I4
 
486
      for (int i = 0; i < m_class.length; i++ ) {
 
487
        if (m_class[i] == 1) {
 
488
          m_I1.insert(i);
 
489
        } else {
 
490
          m_I4.insert(i);
 
491
        }
 
492
      }
 
493
 
 
494
      // Loop to find all the support vectors
 
495
      int numChanged = 0;
 
496
      boolean examineAll = true;
 
497
      while ((numChanged > 0) || examineAll) {
 
498
        numChanged = 0;
 
499
        if (examineAll) {
 
500
          for (int i = 0; i < m_alpha.length; i++) {
 
501
            if (examineExample(i)) {
 
502
              numChanged++;
 
503
            }
 
504
          }
 
505
        } else {
 
506
 
 
507
          // This code implements Modification 1 from Keerthi et al.'s paper
 
508
          for (int i = 0; i < m_alpha.length; i++) {
 
509
            if ((m_alpha[i] > 0) &&  
 
510
                (m_alpha[i] < m_C * m_data.instance(i).weight())) {
 
511
              if (examineExample(i)) {
 
512
                numChanged++;
 
513
              }
 
514
 
 
515
              // Is optimality on unbound vectors obtained?
 
516
              if (m_bUp > m_bLow - 2 * m_tol) {
 
517
                numChanged = 0;
 
518
                break;
 
519
              }
 
520
                }
 
521
          }
 
522
 
 
523
          //This is the code for Modification 2 from Keerthi et al.'s paper
 
524
          /*boolean innerLoopSuccess = true; 
 
525
            numChanged = 0;
 
526
            while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
 
527
            innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
 
528
            }*/
 
529
        }
 
530
 
 
531
        if (examineAll) {
 
532
          examineAll = false;
 
533
        } else if (numChanged == 0) {
 
534
          examineAll = true;
 
535
        }
 
536
      }
 
537
 
 
538
      // Set threshold
 
539
      m_b = (m_bLow + m_bUp) / 2.0;
 
540
 
 
541
      // Save memory
 
542
      m_kernel.clean(); 
 
543
 
 
544
      m_errors = null;
 
545
      m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;
 
546
 
 
547
      // Fit sigmoid if requested
 
548
      if (fitLogistic) {
 
549
        fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
 
550
      }
 
551
 
 
552
    }
 
553
 
 
554
    /**
 
555
     * Computes SVM output for given instance.
 
556
     *
 
557
     * @param index the instance for which output is to be computed
 
558
     * @param inst the instance 
 
559
     * @return the output of the SVM for the given instance
 
560
     * @throws Exception if something goes wrong
 
561
     */
 
562
    protected double SVMOutput(int index, Instance inst) throws Exception {
 
563
 
 
564
      double result = 0;
 
565
 
 
566
      for (int i = m_supportVectors.getNext(-1); i != -1; 
 
567
          i = m_supportVectors.getNext(i)) {
 
568
        result += m_class[i] * m_alpha[i] * m_kernel.eval(index, i, inst);
 
569
      }
 
570
      result -= m_b;
 
571
 
 
572
      return result;
 
573
    }
 
574
 
 
575
    /**
 
576
     * Prints out the classifier.
 
577
     *
 
578
     * @return a description of the classifier as a string
 
579
     */
 
580
    public String toString() {
 
581
 
 
582
      StringBuffer text = new StringBuffer();
 
583
      int printed = 0;
 
584
 
 
585
      if ((m_alpha == null) && (m_sparseWeights == null)) {
 
586
        return "BinaryMISMO: No model built yet.\n";
 
587
      }
 
588
      try {
 
589
        text.append("BinaryMISMO\n\n");
 
590
 
 
591
        for (int i = 0; i < m_alpha.length; i++) {
 
592
          if (m_supportVectors.contains(i)) {
 
593
            double val = m_alpha[i];
 
594
            if (m_class[i] == 1) {
 
595
              if (printed > 0) {
 
596
                text.append(" + ");
 
597
              }
 
598
            } else {
 
599
              text.append(" - ");
 
600
            }
 
601
            text.append(Utils.doubleToString(val, 12, 4) 
 
602
                + " * <");
 
603
            for (int j = 0; j < m_data.numAttributes(); j++) {
 
604
              if (j != m_data.classIndex()) {
 
605
                text.append(m_data.instance(i).toString(j));
 
606
              }
 
607
              if (j != m_data.numAttributes() - 1) {
 
608
                text.append(" ");
 
609
              }
 
610
            }
 
611
            text.append("> * X]\n");
 
612
            printed++;
 
613
          }
 
614
        }
 
615
 
 
616
        if (m_b > 0) {
 
617
          text.append(" - " + Utils.doubleToString(m_b, 12, 4));
 
618
        } else {
 
619
          text.append(" + " + Utils.doubleToString(-m_b, 12, 4));
 
620
        }
 
621
 
 
622
        text.append("\n\nNumber of support vectors: " + 
 
623
            m_supportVectors.numElements());
 
624
        int numEval = 0;
 
625
        int numCacheHits = -1;
 
626
        if(m_kernel != null)
 
627
        {
 
628
          numEval = m_kernel.numEvals();
 
629
          numCacheHits = m_kernel.numCacheHits();
 
630
        }
 
631
        text.append("\n\nNumber of kernel evaluations: " + numEval);
 
632
        if (numCacheHits >= 0 && numEval > 0)
 
633
        {
 
634
          double hitRatio = 1 - numEval*1.0/(numCacheHits+numEval);
 
635
          text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3).trim() + "% cached)");
 
636
        }
 
637
 
 
638
      } catch (Exception e) {
 
639
        e.printStackTrace();
 
640
 
 
641
        return "Can't print BinaryMISMO classifier.";
 
642
      }
 
643
 
 
644
      return text.toString();
 
645
    }
 
646
 
 
647
    /**
 
648
     * Examines instance.
 
649
     *
 
650
     * @param i2 index of instance to examine
 
651
     * @return true if examination was successfull
 
652
     * @throws Exception if something goes wrong
 
653
     */
 
654
    protected boolean examineExample(int i2) throws Exception {
 
655
 
 
656
      double y2, F2;
 
657
      int i1 = -1;
 
658
 
 
659
      y2 = m_class[i2];
 
660
      if (m_I0.contains(i2)) {
 
661
        F2 = m_errors[i2];
 
662
      } else { 
 
663
        F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
 
664
        m_errors[i2] = F2;
 
665
 
 
666
        // Update thresholds
 
667
        if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {
 
668
          m_bUp = F2; m_iUp = i2;
 
669
        } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {
 
670
          m_bLow = F2; m_iLow = i2;
 
671
        }
 
672
      }
 
673
 
 
674
      // Check optimality using current bLow and bUp and, if
 
675
      // violated, find an index i1 to do joint optimization
 
676
      // with i2...
 
677
      boolean optimal = true;
 
678
      if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {
 
679
        if (m_bLow - F2 > 2 * m_tol) {
 
680
          optimal = false; i1 = m_iLow;
 
681
        }
 
682
      }
 
683
      if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {
 
684
        if (F2 - m_bUp > 2 * m_tol) {
 
685
          optimal = false; i1 = m_iUp;
 
686
        }
 
687
      }
 
688
      if (optimal) {
 
689
        return false;
 
690
      }
 
691
 
 
692
      // For i2 unbound choose the better i1...
 
693
      if (m_I0.contains(i2)) {
 
694
        if (m_bLow - F2 > F2 - m_bUp) {
 
695
          i1 = m_iLow;
 
696
        } else {
 
697
          i1 = m_iUp;
 
698
        }
 
699
      }
 
700
      if (i1 == -1) {
 
701
        throw new Exception("This should never happen!");
 
702
      }
 
703
      return takeStep(i1, i2, F2);
 
704
    }
 
705
 
 
706
    /**
 
707
     * Method solving for the Lagrange multipliers for
 
708
     * two instances.
 
709
     *
 
710
     * @param i1 index of the first instance
 
711
     * @param i2 index of the second instance
 
712
     * @param F2
 
713
     * @return true if multipliers could be found
 
714
     * @throws Exception if something goes wrong
 
715
     */
 
716
    protected boolean takeStep(int i1, int i2, double F2) throws Exception {
 
717
 
 
718
      double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
 
719
             a1, a2, f1, f2, v1, v2, Lobj, Hobj;
 
720
      double C1 = m_C * m_data.instance(i1).weight();
 
721
      double C2 = m_C * m_data.instance(i2).weight();
 
722
 
 
723
      // Don't do anything if the two instances are the same
 
724
      if (i1 == i2) {
 
725
        return false;
 
726
      }
 
727
 
 
728
      // Initialize variables
 
729
      alph1 = m_alpha[i1]; alph2 = m_alpha[i2];
 
730
      y1 = m_class[i1]; y2 = m_class[i2];
 
731
      F1 = m_errors[i1];
 
732
      s = y1 * y2;
 
733
 
 
734
      // Find the constraints on a2
 
735
      if (y1 != y2) {
 
736
        L = Math.max(0, alph2 - alph1); 
 
737
        H = Math.min(C2, C1 + alph2 - alph1);
 
738
      } else {
 
739
        L = Math.max(0, alph1 + alph2 - C1);
 
740
        H = Math.min(C2, alph1 + alph2);
 
741
      }
 
742
      if (L >= H) {
 
743
        return false;
 
744
      }
 
745
 
 
746
      // Compute second derivative of objective function
 
747
      k11 = m_kernel.eval(i1, i1, m_data.instance(i1));
 
748
      k12 = m_kernel.eval(i1, i2, m_data.instance(i1));
 
749
      k22 = m_kernel.eval(i2, i2, m_data.instance(i2));
 
750
      eta = 2 * k12 - k11 - k22;
 
751
 
 
752
      // Check if second derivative is negative
 
753
      if (eta < 0) {
 
754
 
 
755
        // Compute unconstrained maximum
 
756
        a2 = alph2 - y2 * (F1 - F2) / eta;
 
757
 
 
758
        // Compute constrained maximum
 
759
        if (a2 < L) {
 
760
          a2 = L;
 
761
        } else if (a2 > H) {
 
762
          a2 = H;
 
763
        }
 
764
      } else {
 
765
 
 
766
        // Look at endpoints of diagonal
 
767
        f1 = SVMOutput(i1, m_data.instance(i1));
 
768
        f2 = SVMOutput(i2, m_data.instance(i2));
 
769
        v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; 
 
770
        v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; 
 
771
        double gamma = alph1 + s * alph2;
 
772
        Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 
 
773
          0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - 
 
774
          y1 * (gamma - s * L) * v1 - y2 * L * v2;
 
775
        Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 
 
776
          0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - 
 
777
          y1 * (gamma - s * H) * v1 - y2 * H * v2;
 
778
        if (Lobj > Hobj + m_eps) {
 
779
          a2 = L;
 
780
        } else if (Lobj < Hobj - m_eps) {
 
781
          a2 = H;
 
782
        } else {
 
783
          a2 = alph2;
 
784
        }
 
785
      }
 
786
      if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
 
787
        return false;
 
788
      }
 
789
 
 
790
      // To prevent precision problems
 
791
      if (a2 > C2 - m_Del * C2) {
 
792
        a2 = C2;
 
793
      } else if (a2 <= m_Del * C2) {
 
794
        a2 = 0;
 
795
      }
 
796
 
 
797
      // Recompute a1
 
798
      a1 = alph1 + s * (alph2 - a2);
 
799
 
 
800
      // To prevent precision problems
 
801
      if (a1 > C1 - m_Del * C1) {
 
802
        a1 = C1;
 
803
      } else if (a1 <= m_Del * C1) {
 
804
        a1 = 0;
 
805
      }
 
806
 
 
807
      // Update sets
 
808
      if (a1 > 0) {
 
809
        m_supportVectors.insert(i1);
 
810
      } else {
 
811
        m_supportVectors.delete(i1);
 
812
      }
 
813
      if ((a1 > 0) && (a1 < C1)) {
 
814
        m_I0.insert(i1);
 
815
      } else {
 
816
        m_I0.delete(i1);
 
817
      }
 
818
      if ((y1 == 1) && (a1 == 0)) {
 
819
        m_I1.insert(i1);
 
820
      } else {
 
821
        m_I1.delete(i1);
 
822
      }
 
823
      if ((y1 == -1) && (a1 == C1)) {
 
824
        m_I2.insert(i1);
 
825
      } else {
 
826
        m_I2.delete(i1);
 
827
      }
 
828
      if ((y1 == 1) && (a1 == C1)) {
 
829
        m_I3.insert(i1);
 
830
      } else {
 
831
        m_I3.delete(i1);
 
832
      }
 
833
      if ((y1 == -1) && (a1 == 0)) {
 
834
        m_I4.insert(i1);
 
835
      } else {
 
836
        m_I4.delete(i1);
 
837
      }
 
838
      if (a2 > 0) {
 
839
        m_supportVectors.insert(i2);
 
840
      } else {
 
841
        m_supportVectors.delete(i2);
 
842
      }
 
843
      if ((a2 > 0) && (a2 < C2)) {
 
844
        m_I0.insert(i2);
 
845
      } else {
 
846
        m_I0.delete(i2);
 
847
      }
 
848
      if ((y2 == 1) && (a2 == 0)) {
 
849
        m_I1.insert(i2);
 
850
      } else {
 
851
        m_I1.delete(i2);
 
852
      }
 
853
      if ((y2 == -1) && (a2 == C2)) {
 
854
        m_I2.insert(i2);
 
855
      } else {
 
856
        m_I2.delete(i2);
 
857
      }
 
858
      if ((y2 == 1) && (a2 == C2)) {
 
859
        m_I3.insert(i2);
 
860
      } else {
 
861
        m_I3.delete(i2);
 
862
      }
 
863
      if ((y2 == -1) && (a2 == 0)) {
 
864
        m_I4.insert(i2);
 
865
      } else {
 
866
        m_I4.delete(i2);
 
867
      }
 
868
 
 
869
      // Update error cache using new Lagrange multipliers
 
870
      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
 
871
        if ((j != i1) && (j != i2)) {
 
872
          m_errors[j] += 
 
873
            y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + 
 
874
            y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2));
 
875
        }
 
876
      }
 
877
 
 
878
      // Update error cache for i1 and i2
 
879
      m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
 
880
      m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
 
881
 
 
882
      // Update array with Lagrange multipliers
 
883
      m_alpha[i1] = a1;
 
884
      m_alpha[i2] = a2;
 
885
 
 
886
      // Update thresholds
 
887
      m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;
 
888
      m_iLow = -1; m_iUp = -1;
 
889
      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
 
890
        if (m_errors[j] < m_bUp) {
 
891
          m_bUp = m_errors[j]; m_iUp = j;
 
892
        }
 
893
        if (m_errors[j] > m_bLow) {
 
894
          m_bLow = m_errors[j]; m_iLow = j;
 
895
        }
 
896
      }
 
897
      if (!m_I0.contains(i1)) {
 
898
        if (m_I3.contains(i1) || m_I4.contains(i1)) {
 
899
          if (m_errors[i1] > m_bLow) {
 
900
            m_bLow = m_errors[i1]; m_iLow = i1;
 
901
          } 
 
902
        } else {
 
903
          if (m_errors[i1] < m_bUp) {
 
904
            m_bUp = m_errors[i1]; m_iUp = i1;
 
905
          }
 
906
        }
 
907
      }
 
908
      if (!m_I0.contains(i2)) {
 
909
        if (m_I3.contains(i2) || m_I4.contains(i2)) {
 
910
          if (m_errors[i2] > m_bLow) {
 
911
            m_bLow = m_errors[i2]; m_iLow = i2;
 
912
          }
 
913
        } else {
 
914
          if (m_errors[i2] < m_bUp) {
 
915
            m_bUp = m_errors[i2]; m_iUp = i2;
 
916
          }
 
917
        }
 
918
      }
 
919
      if ((m_iLow == -1) || (m_iUp == -1)) {
 
920
        throw new Exception("This should never happen!");
 
921
      }
 
922
 
 
923
      // Made some progress.
 
924
      return true;
 
925
    }
 
926
 
 
927
    /**
 
928
     * Quick and dirty check whether the quadratic programming problem is solved.
 
929
     * 
 
930
     * @throws Exception if something goes wrong
 
931
     */
 
932
    protected void checkClassifier() throws Exception {
 
933
 
 
934
      double sum = 0;
 
935
      for (int i = 0; i < m_alpha.length; i++) {
 
936
        if (m_alpha[i] > 0) {
 
937
          sum += m_class[i] * m_alpha[i];
 
938
        }
 
939
      }
 
940
      System.err.println("Sum of y(i) * alpha(i): " + sum);
 
941
 
 
942
      for (int i = 0; i < m_alpha.length; i++) {
 
943
        double output = SVMOutput(i, m_data.instance(i));
 
944
        if (Utils.eq(m_alpha[i], 0)) {
 
945
          if (Utils.sm(m_class[i] * output, 1)) {
 
946
            System.err.println("KKT condition 1 violated: " + m_class[i] * output);
 
947
          }
 
948
        } 
 
949
        if (Utils.gr(m_alpha[i], 0) && 
 
950
            Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) {
 
951
          if (!Utils.eq(m_class[i] * output, 1)) {
 
952
            System.err.println("KKT condition 2 violated: " + m_class[i] * output);
 
953
          }
 
954
            } 
 
955
        if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) {
 
956
          if (Utils.gr(m_class[i] * output, 1)) {
 
957
            System.err.println("KKT condition 3 violated: " + m_class[i] * output);
 
958
          }
 
959
        } 
 
960
      }
 
961
    }  
 
962
  }
 
963
 
 
964
  /** Normalize training data */
 
965
  public static final int FILTER_NORMALIZE = 0;
 
966
  /** Standardize training data */
 
967
  public static final int FILTER_STANDARDIZE = 1;
 
968
  /** No normalization/standardization */
 
969
  public static final int FILTER_NONE = 2;
 
970
  /** The filter to apply to the training data */
 
971
  public static final Tag [] TAGS_FILTER = {
 
972
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
 
973
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
 
974
    new Tag(FILTER_NONE, "No normalization/standardization"),
 
975
  };
 
976
 
 
977
  /** The binary classifier(s) */
 
978
  protected BinaryMISMO[][] m_classifiers = null;
 
979
 
 
980
  /** The complexity parameter. */
 
981
  protected double m_C = 1.0;
 
982
 
 
983
  /** Epsilon for rounding. */
 
984
  protected double m_eps = 1.0e-12;
 
985
 
 
986
  /** Tolerance for accuracy of result. */
 
987
  protected double m_tol = 1.0e-3;
 
988
 
 
989
  /** Whether to normalize/standardize/neither */
 
990
  protected int m_filterType = FILTER_NORMALIZE;
 
991
 
 
992
  /** Use MIMinimax feature space?  */
 
993
  protected boolean m_minimax = false;   
 
994
 
 
995
  /** The filter used to make attributes numeric. */
 
996
  protected NominalToBinary m_NominalToBinary;
 
997
 
 
998
  /** The filter used to standardize/normalize all values. */
 
999
  protected Filter m_Filter = null;
 
1000
 
 
1001
  /** The filter used to get rid of missing values. */
 
1002
  protected ReplaceMissingValues m_Missing;
 
1003
 
 
1004
  /** The class index from the training data */
 
1005
  protected int m_classIndex = -1;
 
1006
 
 
1007
  /** The class attribute */
 
1008
  protected Attribute m_classAttribute;
 
1009
  
 
1010
  /** Kernel to use **/
 
1011
  protected Kernel m_kernel = new MIPolyKernel();
 
1012
 
 
1013
  /** Turn off all checks and conversions? Turning them off assumes
 
1014
    that data is purely numeric, doesn't contain any missing values,
 
1015
    and has a nominal class. Turning them off also means that
 
1016
    no header information will be stored if the machine is linear. 
 
1017
    Finally, it also assumes that no instance has a weight equal to 0.*/
 
1018
  protected boolean m_checksTurnedOff;
 
1019
 
 
1020
  /** Precision constant for updating sets */
 
1021
  protected static double m_Del = 1000 * Double.MIN_VALUE;
 
1022
 
 
1023
  /** Whether logistic models are to be fit */
 
1024
  protected boolean m_fitLogisticModels = false;
 
1025
 
 
1026
  /** The number of folds for the internal cross-validation */
 
1027
  protected int m_numFolds = -1;
 
1028
 
 
1029
  /** The random number seed  */
 
1030
  protected int m_randomSeed = 1;
 
1031
 
 
1032
  /**
 
1033
   * Turns off checks for missing values, etc. Use with caution.
 
1034
   */
 
1035
  public void turnChecksOff() {
 
1036
 
 
1037
    m_checksTurnedOff = true;
 
1038
  }
 
1039
 
 
1040
  /**
 
1041
   * Turns on checks for missing values, etc.
 
1042
   */
 
1043
  public void turnChecksOn() {
 
1044
 
 
1045
    m_checksTurnedOff = false;
 
1046
  }
 
1047
 
 
1048
  /**
 
1049
   * Returns default capabilities of the classifier.
 
1050
   *
 
1051
   * @return      the capabilities of this classifier
 
1052
   */
 
1053
  public Capabilities getCapabilities() {
 
1054
    Capabilities result = getKernel().getCapabilities();
 
1055
    result.setOwner(this);
 
1056
 
 
1057
    // attributes
 
1058
    result.enable(Capability.NOMINAL_ATTRIBUTES);
 
1059
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
 
1060
    result.enable(Capability.MISSING_VALUES);
 
1061
 
 
1062
    // class
 
1063
    result.disableAllClasses();
 
1064
    result.disableAllClassDependencies();
 
1065
    result.enable(Capability.NOMINAL_CLASS);
 
1066
    result.enable(Capability.MISSING_CLASS_VALUES);
 
1067
    
 
1068
    // other
 
1069
    result.enable(Capability.ONLY_MULTIINSTANCE);
 
1070
    
 
1071
    return result;
 
1072
  }
 
1073
 
 
1074
  /**
 
1075
   * Returns the capabilities of this multi-instance classifier for the
 
1076
   * relational data.
 
1077
   *
 
1078
   * @return            the capabilities of this object
 
1079
   * @see               Capabilities
 
1080
   */
 
1081
  public Capabilities getMultiInstanceCapabilities() {
 
1082
    Capabilities result = ((MultiInstanceCapabilitiesHandler) getKernel()).getMultiInstanceCapabilities();
 
1083
    result.setOwner(this);
 
1084
 
 
1085
    // attribute
 
1086
    result.enableAllAttributeDependencies();
 
1087
    // with NominalToBinary we can also handle nominal attributes, but only
 
1088
    // if the kernel can handle numeric attributes
 
1089
    if (result.handles(Capability.NUMERIC_ATTRIBUTES))
 
1090
      result.enable(Capability.NOMINAL_ATTRIBUTES);
 
1091
    result.enable(Capability.MISSING_VALUES);
 
1092
    
 
1093
    return result;
 
1094
  }
 
1095
 
 
1096
  /**
 
1097
   * Method for building the classifier. Implements a one-against-one
 
1098
   * wrapper for multi-class problems.
 
1099
   *
 
1100
   * @param insts the set of training instances
 
1101
   * @throws Exception if the classifier can't be built successfully
 
1102
   */
 
1103
  public void buildClassifier(Instances insts) throws Exception {
 
1104
    if (!m_checksTurnedOff) {
 
1105
      // can classifier handle the data?
 
1106
      getCapabilities().testWithFail(insts);
 
1107
 
 
1108
      // remove instances with missing class
 
1109
      insts = new Instances(insts);
 
1110
      insts.deleteWithMissingClass();
 
1111
 
 
1112
      /* Removes all the instances with weight equal to 0.
 
1113
         MUST be done since condition (8) of Keerthi's paper 
 
1114
         is made with the assertion Ci > 0 (See equation (3a). */
 
1115
      Instances data = new Instances(insts, insts.numInstances());
 
1116
      for(int i = 0; i < insts.numInstances(); i++){
 
1117
        if(insts.instance(i).weight() > 0)
 
1118
          data.add(insts.instance(i));
 
1119
      }
 
1120
      if (data.numInstances() == 0) {
 
1121
        throw new Exception("No training instances left after removing " + 
 
1122
            "instance with either a weight null or a missing class!");
 
1123
      }
 
1124
      insts = data;     
 
1125
    }
 
1126
 
 
1127
    // filter data
 
1128
    if (!m_checksTurnedOff) 
 
1129
      m_Missing = new ReplaceMissingValues();
 
1130
    else 
 
1131
      m_Missing = null;
 
1132
 
 
1133
    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) {
 
1134
      boolean onlyNumeric = true;
 
1135
      if (!m_checksTurnedOff) {
 
1136
        for (int i = 0; i < insts.numAttributes(); i++) {
 
1137
          if (i != insts.classIndex()) {
 
1138
            if (!insts.attribute(i).isNumeric()) {
 
1139
              onlyNumeric = false;
 
1140
              break;
 
1141
            }
 
1142
          }
 
1143
        }
 
1144
      }
 
1145
      
 
1146
      if (!onlyNumeric) {
 
1147
        m_NominalToBinary = new NominalToBinary();
 
1148
        // exclude the bag attribute
 
1149
        m_NominalToBinary.setAttributeIndices("2-last");
 
1150
      }
 
1151
      else {
 
1152
        m_NominalToBinary = null;
 
1153
      }
 
1154
    }
 
1155
    else {
 
1156
      m_NominalToBinary = null;
 
1157
    }
 
1158
 
 
1159
    if (m_filterType == FILTER_STANDARDIZE) 
 
1160
      m_Filter = new Standardize();
 
1161
    else if (m_filterType == FILTER_NORMALIZE)
 
1162
      m_Filter = new Normalize();
 
1163
    else 
 
1164
      m_Filter = null;
 
1165
 
 
1166
 
 
1167
    Instances transformedInsts;
 
1168
    Filter convertToProp = new MultiInstanceToPropositional();
 
1169
    Filter convertToMI = new PropositionalToMultiInstance();
 
1170
 
 
1171
    //transform the data into single-instance format
 
1172
    if (m_minimax){ 
 
1173
      /* using SimpleMI class minimax transform method. 
 
1174
         this method transforms the multi-instance dataset into minmax feature space (single-instance) */
 
1175
      SimpleMI transMinimax = new SimpleMI();
 
1176
      transMinimax.setTransformMethod(
 
1177
          new SelectedTag(
 
1178
            SimpleMI.TRANSFORMMETHOD_MINIMAX, SimpleMI.TAGS_TRANSFORMMETHOD));
 
1179
      transformedInsts = transMinimax.transform(insts);
 
1180
    }
 
1181
    else { 
 
1182
      convertToProp.setInputFormat(insts);
 
1183
      transformedInsts=Filter.useFilter(insts, convertToProp);
 
1184
    }
 
1185
 
 
1186
    if (m_Missing != null) {
 
1187
      m_Missing.setInputFormat(transformedInsts);
 
1188
      transformedInsts = Filter.useFilter(transformedInsts, m_Missing); 
 
1189
    }
 
1190
 
 
1191
    if (m_NominalToBinary != null) { 
 
1192
      m_NominalToBinary.setInputFormat(transformedInsts);
 
1193
      transformedInsts = Filter.useFilter(transformedInsts, m_NominalToBinary); 
 
1194
    }
 
1195
 
 
1196
    if (m_Filter != null) {
 
1197
      m_Filter.setInputFormat(transformedInsts);
 
1198
      transformedInsts = Filter.useFilter(transformedInsts, m_Filter); 
 
1199
    }
 
1200
 
 
1201
    // convert the single-instance format to multi-instance format
 
1202
    convertToMI.setInputFormat(transformedInsts);
 
1203
    insts = Filter.useFilter( transformedInsts, convertToMI);
 
1204
 
 
1205
    m_classIndex = insts.classIndex();
 
1206
    m_classAttribute = insts.classAttribute();
 
1207
 
 
1208
    // Generate subsets representing each class
 
1209
    Instances[] subsets = new Instances[insts.numClasses()];
 
1210
    for (int i = 0; i < insts.numClasses(); i++) {
 
1211
      subsets[i] = new Instances(insts, insts.numInstances());
 
1212
    }
 
1213
    for (int j = 0; j < insts.numInstances(); j++) {
 
1214
      Instance inst = insts.instance(j);
 
1215
      subsets[(int)inst.classValue()].add(inst);
 
1216
    }
 
1217
    for (int i = 0; i < insts.numClasses(); i++) {
 
1218
      subsets[i].compactify();
 
1219
    }
 
1220
 
 
1221
    // Build the binary classifiers
 
1222
    Random rand = new Random(m_randomSeed);
 
1223
    m_classifiers = new BinaryMISMO[insts.numClasses()][insts.numClasses()];
 
1224
    for (int i = 0; i < insts.numClasses(); i++) {
 
1225
      for (int j = i + 1; j < insts.numClasses(); j++) {
 
1226
        m_classifiers[i][j] = new BinaryMISMO();  
 
1227
        m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel()));
 
1228
        Instances data = new Instances(insts, insts.numInstances());
 
1229
        for (int k = 0; k < subsets[i].numInstances(); k++) {
 
1230
          data.add(subsets[i].instance(k));
 
1231
        }
 
1232
        for (int k = 0; k < subsets[j].numInstances(); k++) {
 
1233
          data.add(subsets[j].instance(k));
 
1234
        }  
 
1235
        data.compactify(); 
 
1236
        data.randomize(rand);
 
1237
        m_classifiers[i][j].buildClassifier(data, i, j, 
 
1238
            m_fitLogisticModels,
 
1239
            m_numFolds, m_randomSeed);
 
1240
      }
 
1241
    } 
 
1242
 
 
1243
  }
 
1244
 
 
1245
  /**
 
1246
   * Estimates class probabilities for given instance.
 
1247
   * 
 
1248
   * @param inst the instance to compute the distribution for
 
1249
   * @return the class probabilities
 
1250
   * @throws Exception if computation fails
 
1251
   */
 
1252
  public double[] distributionForInstance(Instance inst) throws Exception { 
 
1253
 
 
1254
    //convert instance into instances
 
1255
    Instances insts = new Instances(inst.dataset(), 0);
 
1256
    insts.add(inst);
 
1257
 
 
1258
    //transform the data into single-instance format
 
1259
    Filter convertToProp = new MultiInstanceToPropositional();
 
1260
    Filter convertToMI = new PropositionalToMultiInstance();
 
1261
 
 
1262
    if (m_minimax){ // using minimax feature space
 
1263
      SimpleMI transMinimax = new SimpleMI();
 
1264
      transMinimax.setTransformMethod(
 
1265
          new SelectedTag(
 
1266
            SimpleMI.TRANSFORMMETHOD_MINIMAX, SimpleMI.TAGS_TRANSFORMMETHOD));
 
1267
      insts = transMinimax.transform (insts);
 
1268
    }
 
1269
    else{
 
1270
      convertToProp.setInputFormat(insts);
 
1271
      insts=Filter.useFilter( insts, convertToProp);
 
1272
    }
 
1273
 
 
1274
    // Filter instances 
 
1275
    if (m_Missing!=null) 
 
1276
      insts = Filter.useFilter(insts, m_Missing); 
 
1277
 
 
1278
    if (m_Filter!=null)
 
1279
      insts = Filter.useFilter(insts, m_Filter);     
 
1280
 
 
1281
    // convert the single-instance format to multi-instance format
 
1282
    convertToMI.setInputFormat(insts);
 
1283
    insts=Filter.useFilter( insts, convertToMI);
 
1284
 
 
1285
    inst = insts.instance(0);  
 
1286
 
 
1287
    if (!m_fitLogisticModels) {
 
1288
      double[] result = new double[inst.numClasses()];
 
1289
      for (int i = 0; i < inst.numClasses(); i++) {
 
1290
        for (int j = i + 1; j < inst.numClasses(); j++) {
 
1291
          if ((m_classifiers[i][j].m_alpha != null) || 
 
1292
              (m_classifiers[i][j].m_sparseWeights != null)) {
 
1293
            double output = m_classifiers[i][j].SVMOutput(-1, inst);
 
1294
            if (output > 0) {
 
1295
              result[j] += 1;
 
1296
            } else {
 
1297
              result[i] += 1;
 
1298
            }
 
1299
              }
 
1300
        } 
 
1301
      }
 
1302
      Utils.normalize(result);
 
1303
      return result;
 
1304
    } else {
 
1305
 
 
1306
      // We only need to do pairwise coupling if there are more
 
1307
      // then two classes.
 
1308
      if (inst.numClasses() == 2) {
 
1309
        double[] newInst = new double[2];
 
1310
        newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);
 
1311
        newInst[1] = Instance.missingValue();
 
1312
        return m_classifiers[0][1].m_logistic.
 
1313
          distributionForInstance(new Instance(1, newInst));
 
1314
      }
 
1315
      double[][] r = new double[inst.numClasses()][inst.numClasses()];
 
1316
      double[][] n = new double[inst.numClasses()][inst.numClasses()];
 
1317
      for (int i = 0; i < inst.numClasses(); i++) {
 
1318
        for (int j = i + 1; j < inst.numClasses(); j++) {
 
1319
          if ((m_classifiers[i][j].m_alpha != null) || 
 
1320
              (m_classifiers[i][j].m_sparseWeights != null)) {
 
1321
            double[] newInst = new double[2];
 
1322
            newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);
 
1323
            newInst[1] = Instance.missingValue();
 
1324
            r[i][j] = m_classifiers[i][j].m_logistic.
 
1325
              distributionForInstance(new Instance(1, newInst))[0];
 
1326
            n[i][j] = m_classifiers[i][j].m_sumOfWeights;
 
1327
              }
 
1328
        }
 
1329
      }
 
1330
      return pairwiseCoupling(n, r);
 
1331
    }
 
1332
  }
 
1333
 
 
1334
  /**
 
1335
   * Implements pairwise coupling.
 
1336
   *
 
1337
   * @param n the sum of weights used to train each model
 
1338
   * @param r the probability estimate from each model
 
1339
   * @return the coupled estimates
 
1340
   */
 
1341
  public double[] pairwiseCoupling(double[][] n, double[][] r) {
 
1342
 
 
1343
    // Initialize p and u array
 
1344
    double[] p = new double[r.length];
 
1345
    for (int i =0; i < p.length; i++) {
 
1346
      p[i] = 1.0 / (double)p.length;
 
1347
    }
 
1348
    double[][] u = new double[r.length][r.length];
 
1349
    for (int i = 0; i < r.length; i++) {
 
1350
      for (int j = i + 1; j < r.length; j++) {
 
1351
        u[i][j] = 0.5;
 
1352
      }
 
1353
    }
 
1354
 
 
1355
    // firstSum doesn't change
 
1356
    double[] firstSum = new double[p.length];
 
1357
    for (int i = 0; i < p.length; i++) {
 
1358
      for (int j = i + 1; j < p.length; j++) {
 
1359
        firstSum[i] += n[i][j] * r[i][j];
 
1360
        firstSum[j] += n[i][j] * (1 - r[i][j]);
 
1361
      }
 
1362
    }
 
1363
 
 
1364
    // Iterate until convergence
 
1365
    boolean changed;
 
1366
    do {
 
1367
      changed = false;
 
1368
      double[] secondSum = new double[p.length];
 
1369
      for (int i = 0; i < p.length; i++) {
 
1370
        for (int j = i + 1; j < p.length; j++) {
 
1371
          secondSum[i] += n[i][j] * u[i][j];
 
1372
          secondSum[j] += n[i][j] * (1 - u[i][j]);
 
1373
        }
 
1374
      }
 
1375
      for (int i = 0; i < p.length; i++) {
 
1376
        if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
 
1377
          if (p[i] > 0) {
 
1378
            changed = true;
 
1379
          }
 
1380
          p[i] = 0;
 
1381
        } else {
 
1382
          double factor = firstSum[i] / secondSum[i];
 
1383
          double pOld = p[i];
 
1384
          p[i] *= factor;
 
1385
          if (Math.abs(pOld - p[i]) > 1.0e-3) {
 
1386
            changed = true;
 
1387
          }
 
1388
        }
 
1389
      }
 
1390
      Utils.normalize(p);
 
1391
      for (int i = 0; i < r.length; i++) {
 
1392
        for (int j = i + 1; j < r.length; j++) {
 
1393
          u[i][j] = p[i] / (p[i] + p[j]);
 
1394
        }
 
1395
      }
 
1396
    } while (changed);
 
1397
    return p;
 
1398
  }
 
1399
 
 
1400
  /**
 
1401
   * Returns the weights in sparse format.
 
1402
   * 
 
1403
   * @return the weights in sparse format
 
1404
   */
 
1405
  public double [][][] sparseWeights() {
 
1406
 
 
1407
    int numValues = m_classAttribute.numValues();
 
1408
    double [][][] sparseWeights = new double[numValues][numValues][];
 
1409
 
 
1410
    for (int i = 0; i < numValues; i++) {
 
1411
      for (int j = i + 1; j < numValues; j++) {
 
1412
        sparseWeights[i][j] = m_classifiers[i][j].m_sparseWeights;
 
1413
      }
 
1414
    }
 
1415
 
 
1416
    return sparseWeights;
 
1417
  }
 
1418
 
 
1419
  /**
 
1420
   * Returns the indices in sparse format.
 
1421
   * 
 
1422
   * @return the indices in sparse format
 
1423
   */
 
1424
  public int [][][] sparseIndices() {
 
1425
 
 
1426
    int numValues = m_classAttribute.numValues();
 
1427
    int [][][] sparseIndices = new int[numValues][numValues][];
 
1428
 
 
1429
    for (int i = 0; i < numValues; i++) {
 
1430
      for (int j = i + 1; j < numValues; j++) {
 
1431
        sparseIndices[i][j] = m_classifiers[i][j].m_sparseIndices;
 
1432
      }
 
1433
    }
 
1434
 
 
1435
    return sparseIndices;
 
1436
  }
 
1437
 
 
1438
  /**
 
1439
   * Returns the bias of each binary SMO.
 
1440
   * 
 
1441
   * @return the bias of each binary SMO
 
1442
   */
 
1443
  public double [][] bias() {
 
1444
 
 
1445
    int numValues = m_classAttribute.numValues();
 
1446
    double [][] bias = new double[numValues][numValues];
 
1447
 
 
1448
    for (int i = 0; i < numValues; i++) {
 
1449
      for (int j = i + 1; j < numValues; j++) {
 
1450
        bias[i][j] = m_classifiers[i][j].m_b;
 
1451
      }
 
1452
    }
 
1453
 
 
1454
    return bias;
 
1455
  }
 
1456
 
 
1457
  /**
 
1458
   * Returns the number of values of the class attribute.
 
1459
   * 
 
1460
   * @return the number values of the class attribute
 
1461
   */
 
1462
  public int numClassAttributeValues() {
 
1463
 
 
1464
    return m_classAttribute.numValues();
 
1465
  }
 
1466
 
 
1467
  /**
 
1468
   * Returns the names of the class attributes.
 
1469
   * 
 
1470
   * @return the names of the class attributes
 
1471
   */
 
1472
  public String[] classAttributeNames() {
 
1473
 
 
1474
    int numValues = m_classAttribute.numValues();
 
1475
 
 
1476
    String[] classAttributeNames = new String[numValues];
 
1477
 
 
1478
    for (int i = 0; i < numValues; i++) {
 
1479
      classAttributeNames[i] = m_classAttribute.value(i);
 
1480
    }
 
1481
 
 
1482
    return classAttributeNames;
 
1483
  }
 
1484
 
 
1485
  /**
 
1486
   * Returns the attribute names.
 
1487
   * 
 
1488
   * @return the attribute names
 
1489
   */
 
1490
  public String[][][] attributeNames() {
 
1491
 
 
1492
    int numValues = m_classAttribute.numValues();
 
1493
    String[][][] attributeNames = new String[numValues][numValues][];
 
1494
 
 
1495
    for (int i = 0; i < numValues; i++) {
 
1496
      for (int j = i + 1; j < numValues; j++) {
 
1497
        int numAttributes = m_classifiers[i][j].m_data.numAttributes();
 
1498
        String[] attrNames = new String[numAttributes];
 
1499
        for (int k = 0; k < numAttributes; k++) {
 
1500
          attrNames[k] = m_classifiers[i][j].m_data.attribute(k).name();
 
1501
        }
 
1502
        attributeNames[i][j] = attrNames;          
 
1503
      }
 
1504
    }
 
1505
    return attributeNames;
 
1506
  }
 
1507
 
 
1508
  /**
 
1509
   * Returns an enumeration describing the available options.
 
1510
   *
 
1511
   * @return an enumeration of all the available options.
 
1512
   */
 
1513
  public Enumeration listOptions() {
 
1514
 
 
1515
    Vector result = new Vector();
 
1516
 
 
1517
    Enumeration enm = super.listOptions();
 
1518
    while (enm.hasMoreElements())
 
1519
      result.addElement(enm.nextElement());
 
1520
 
 
1521
    result.addElement(new Option(
 
1522
        "\tTurns off all checks - use with caution!\n"
 
1523
        + "\tTurning them off assumes that data is purely numeric, doesn't\n"
 
1524
        + "\tcontain any missing values, and has a nominal class. Turning them\n"
 
1525
        + "\toff also means that no header information will be stored if the\n"
 
1526
        + "\tmachine is linear. Finally, it also assumes that no instance has\n"
 
1527
        + "\ta weight equal to 0.\n"
 
1528
        + "\t(default: checks on)",
 
1529
        "no-checks", 0, "-no-checks"));
 
1530
 
 
1531
    result.addElement(new Option(
 
1532
          "\tThe complexity constant C. (default 1)",
 
1533
          "C", 1, "-C <double>"));
 
1534
    
 
1535
    result.addElement(new Option(
 
1536
          "\tWhether to 0=normalize/1=standardize/2=neither.\n" 
 
1537
          + "\t(default 0=normalize)",
 
1538
          "N", 1, "-N"));
 
1539
    
 
1540
    result.addElement(new Option(
 
1541
          "\tUse MIminimax feature space. ",
 
1542
          "I", 0, "-I"));
 
1543
    
 
1544
    result.addElement(new Option(
 
1545
          "\tThe tolerance parameter. (default 1.0e-3)",
 
1546
          "L", 1, "-L <double>"));
 
1547
    
 
1548
    result.addElement(new Option(
 
1549
          "\tThe epsilon for round-off error. (default 1.0e-12)",
 
1550
          "P", 1, "-P <double>"));
 
1551
    
 
1552
    result.addElement(new Option(
 
1553
          "\tFit logistic models to SVM outputs. ",
 
1554
          "M", 0, "-M"));
 
1555
    
 
1556
    result.addElement(new Option(
 
1557
          "\tThe number of folds for the internal cross-validation. \n"
 
1558
          + "\t(default -1, use training data)",
 
1559
          "V", 1, "-V <double>"));
 
1560
    
 
1561
    result.addElement(new Option(
 
1562
          "\tThe random number seed. (default 1)",
 
1563
          "W", 1, "-W <double>"));
 
1564
    
 
1565
    result.addElement(new Option(
 
1566
        "\tThe Kernel to use.\n"
 
1567
        + "\t(default: weka.classifiers.functions.supportVector.PolyKernel)",
 
1568
        "K", 1, "-K <classname and parameters>"));
 
1569
 
 
1570
    result.addElement(new Option(
 
1571
        "",
 
1572
        "", 0, "\nOptions specific to kernel "
 
1573
        + getKernel().getClass().getName() + ":"));
 
1574
    
 
1575
    enm = ((OptionHandler) getKernel()).listOptions();
 
1576
    while (enm.hasMoreElements())
 
1577
      result.addElement(enm.nextElement());
 
1578
 
 
1579
    return result.elements();
 
1580
  }
 
1581
 
 
1582
  /**
 
1583
   * Parses a given list of options. <p/>
 
1584
   * 
 
1585
   <!-- options-start -->
 
1586
   * Valid options are: <p/>
 
1587
   * 
 
1588
   * <pre> -D
 
1589
   *  If set, classifier is run in debug mode and
 
1590
   *  may output additional info to the console</pre>
 
1591
   * 
 
1592
   * <pre> -no-checks
 
1593
   *  Turns off all checks - use with caution!
 
1594
   *  Turning them off assumes that data is purely numeric, doesn't
 
1595
   *  contain any missing values, and has a nominal class. Turning them
 
1596
   *  off also means that no header information will be stored if the
 
1597
   *  machine is linear. Finally, it also assumes that no instance has
 
1598
   *  a weight equal to 0.
 
1599
   *  (default: checks on)</pre>
 
1600
   * 
 
1601
   * <pre> -C &lt;double&gt;
 
1602
   *  The complexity constant C. (default 1)</pre>
 
1603
   * 
 
1604
   * <pre> -N
 
1605
   *  Whether to 0=normalize/1=standardize/2=neither.
 
1606
   *  (default 0=normalize)</pre>
 
1607
   * 
 
1608
   * <pre> -I
 
1609
   *  Use MIminimax feature space. </pre>
 
1610
   * 
 
1611
   * <pre> -L &lt;double&gt;
 
1612
   *  The tolerance parameter. (default 1.0e-3)</pre>
 
1613
   * 
 
1614
   * <pre> -P &lt;double&gt;
 
1615
   *  The epsilon for round-off error. (default 1.0e-12)</pre>
 
1616
   * 
 
1617
   * <pre> -M
 
1618
   *  Fit logistic models to SVM outputs. </pre>
 
1619
   * 
 
1620
   * <pre> -V &lt;double&gt;
 
1621
   *  The number of folds for the internal cross-validation. 
 
1622
   *  (default -1, use training data)</pre>
 
1623
   * 
 
1624
   * <pre> -W &lt;double&gt;
 
1625
   *  The random number seed. (default 1)</pre>
 
1626
   * 
 
1627
   * <pre> -K &lt;classname and parameters&gt;
 
1628
   *  The Kernel to use.
 
1629
   *  (default: weka.classifiers.functions.supportVector.PolyKernel)</pre>
 
1630
   * 
 
1631
   * <pre> 
 
1632
   * Options specific to kernel weka.classifiers.mi.supportVector.MIPolyKernel:
 
1633
   * </pre>
 
1634
   * 
 
1635
   * <pre> -D
 
1636
   *  Enables debugging output (if available) to be printed.
 
1637
   *  (default: off)</pre>
 
1638
   * 
 
1639
   * <pre> -no-checks
 
1640
   *  Turns off all checks - use with caution!
 
1641
   *  (default: checks on)</pre>
 
1642
   * 
 
1643
   * <pre> -C &lt;num&gt;
 
1644
   *  The size of the cache (a prime number), 0 for full cache and 
 
1645
   *  -1 to turn it off.
 
1646
   *  (default: 250007)</pre>
 
1647
   * 
 
1648
   * <pre> -E &lt;num&gt;
 
1649
   *  The Exponent to use.
 
1650
   *  (default: 1.0)</pre>
 
1651
   * 
 
1652
   * <pre> -L
 
1653
   *  Use lower-order terms.
 
1654
   *  (default: no)</pre>
 
1655
   * 
 
1656
   <!-- options-end -->
 
1657
   *
 
1658
   * @param options the list of options as an array of strings
 
1659
   * @throws Exception if an option is not supported 
 
1660
   */
 
1661
  public void setOptions(String[] options) throws Exception {
 
1662
    String      tmpStr;
 
1663
    String[]    tmpOptions;
 
1664
    
 
1665
    setChecksTurnedOff(Utils.getFlag("no-checks", options));
 
1666
 
 
1667
    tmpStr = Utils.getOption('C', options);
 
1668
    if (tmpStr.length() != 0)
 
1669
      setC(Double.parseDouble(tmpStr));
 
1670
    else
 
1671
      setC(1.0);
 
1672
 
 
1673
    tmpStr = Utils.getOption('L', options);
 
1674
    if (tmpStr.length() != 0)
 
1675
      setToleranceParameter(Double.parseDouble(tmpStr));
 
1676
    else
 
1677
      setToleranceParameter(1.0e-3);
 
1678
    
 
1679
    tmpStr = Utils.getOption('P', options);
 
1680
    if (tmpStr.length() != 0)
 
1681
      setEpsilon(new Double(tmpStr));
 
1682
    else
 
1683
      setEpsilon(1.0e-12);
 
1684
 
 
1685
    setMinimax(Utils.getFlag('I', options));
 
1686
 
 
1687
    tmpStr = Utils.getOption('N', options);
 
1688
    if (tmpStr.length() != 0)
 
1689
      setFilterType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_FILTER));
 
1690
    else
 
1691
      setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER));
 
1692
    
 
1693
    setBuildLogisticModels(Utils.getFlag('M', options));
 
1694
    
 
1695
    tmpStr = Utils.getOption('V', options);
 
1696
    if (tmpStr.length() != 0)
 
1697
      m_numFolds = Integer.parseInt(tmpStr);
 
1698
    else
 
1699
      m_numFolds = -1;
 
1700
 
 
1701
    tmpStr = Utils.getOption('W', options);
 
1702
    if (tmpStr.length() != 0)
 
1703
      setRandomSeed(Integer.parseInt(tmpStr));
 
1704
    else
 
1705
      setRandomSeed(1);
 
1706
 
 
1707
    tmpStr     = Utils.getOption('K', options);
 
1708
    tmpOptions = Utils.splitOptions(tmpStr);
 
1709
    if (tmpOptions.length != 0) {
 
1710
      tmpStr        = tmpOptions[0];
 
1711
      tmpOptions[0] = "";
 
1712
      setKernel(Kernel.forName(tmpStr, tmpOptions));
 
1713
    }
 
1714
    
 
1715
    super.setOptions(options);
 
1716
  }
 
1717
 
 
1718
  /**
 
1719
   * Gets the current settings of the classifier.
 
1720
   *
 
1721
   * @return an array of strings suitable for passing to setOptions
 
1722
   */
 
1723
  public String[] getOptions() {
 
1724
    int       i;
 
1725
    Vector    result;
 
1726
    String[]  options;
 
1727
 
 
1728
    result = new Vector();
 
1729
    options = super.getOptions();
 
1730
    for (i = 0; i < options.length; i++)
 
1731
      result.add(options[i]);
 
1732
 
 
1733
    if (getChecksTurnedOff())
 
1734
      result.add("-no-checks");
 
1735
 
 
1736
    result.add("-C"); 
 
1737
    result.add("" + getC());
 
1738
    
 
1739
    result.add("-L");
 
1740
    result.add("" + getToleranceParameter());
 
1741
    
 
1742
    result.add("-P");
 
1743
    result.add("" + getEpsilon());
 
1744
    
 
1745
    result.add("-N");
 
1746
    result.add("" + m_filterType);
 
1747
    
 
1748
    if (getMinimax())
 
1749
      result.add("-I");
 
1750
 
 
1751
    if (getBuildLogisticModels())
 
1752
      result.add("-M");
 
1753
    
 
1754
    result.add("-V");
 
1755
    result.add("" + getNumFolds());
 
1756
    
 
1757
    result.add("-W");
 
1758
    result.add("" + getRandomSeed());
 
1759
    
 
1760
    result.add("-K");
 
1761
    result.add("" + getKernel().getClass().getName() + " " + Utils.joinOptions(getKernel().getOptions()));
 
1762
    
 
1763
    return (String[]) result.toArray(new String[result.size()]);          
 
1764
  }
 
1765
 
 
1766
  /**
 
1767
   * Disables or enables the checks (which could be time-consuming). Use with
 
1768
   * caution!
 
1769
   * 
 
1770
   * @param value       if true turns off all checks
 
1771
   */
 
1772
  public void setChecksTurnedOff(boolean value) {
 
1773
    if (value)
 
1774
      turnChecksOff();
 
1775
    else
 
1776
      turnChecksOn();
 
1777
  }
 
1778
  
 
1779
  /**
 
1780
   * Returns whether the checks are turned off or not.
 
1781
   * 
 
1782
   * @return            true if the checks are turned off
 
1783
   */
 
1784
  public boolean getChecksTurnedOff() {
 
1785
    return m_checksTurnedOff;
 
1786
  }
 
1787
 
 
1788
  /**
 
1789
   * Returns the tip text for this property
 
1790
   * 
 
1791
   * @return            tip text for this property suitable for
 
1792
   *                    displaying in the explorer/experimenter gui
 
1793
   */
 
1794
  public String checksTurnedOffTipText() {
 
1795
    return "Turns time-consuming checks off - use with caution.";
 
1796
  }
 
1797
  
 
1798
  /**
 
1799
   * Returns the tip text for this property
 
1800
   * 
 
1801
   * @return            tip text for this property suitable for
 
1802
   *                    displaying in the explorer/experimenter gui
 
1803
   */
 
1804
  public String kernelTipText() {
 
1805
    return "The kernel to use.";
 
1806
  }
 
1807
 
 
1808
  /**
 
1809
   * Gets the kernel to use.
 
1810
   *
 
1811
   * @return            the kernel
 
1812
   */
 
1813
  public Kernel getKernel() {
 
1814
    return m_kernel;
 
1815
  }
 
1816
    
 
1817
  /**
 
1818
   * Sets the kernel to use.
 
1819
   *
 
1820
   * @param value       the kernel
 
1821
   */
 
1822
  public void setKernel(Kernel value) {
 
1823
    if (!(value instanceof MultiInstanceCapabilitiesHandler))
 
1824
      throw new IllegalArgumentException(
 
1825
          "Kernel must be able to handle multi-instance data!\n"
 
1826
          + "(This one does not implement " + MultiInstanceCapabilitiesHandler.class.getName() + ")");
 
1827
    
 
1828
    m_kernel = value;
 
1829
  }
 
1830
 
 
1831
  /**
 
1832
   * Returns the tip text for this property
 
1833
   * @return tip text for this property suitable for
 
1834
   * displaying in the explorer/experimenter gui
 
1835
   */
 
1836
  public String cTipText() {
 
1837
    return "The complexity parameter C.";
 
1838
  }
 
1839
 
 
1840
  /**
 
1841
   * Get the value of C.
 
1842
   *
 
1843
   * @return Value of C.
 
1844
   */
 
1845
  public double getC() {
 
1846
 
 
1847
    return m_C;
 
1848
  }
 
1849
 
 
1850
  /**
 
1851
   * Set the value of C.
 
1852
   *
 
1853
   * @param v  Value to assign to C.
 
1854
   */
 
1855
  public void setC(double v) {
 
1856
 
 
1857
    m_C = v;
 
1858
  }
 
1859
 
 
1860
  /**
 
1861
   * Returns the tip text for this property
 
1862
   * @return tip text for this property suitable for
 
1863
   * displaying in the explorer/experimenter gui
 
1864
   */
 
1865
  public String toleranceParameterTipText() {
 
1866
    return "The tolerance parameter (shouldn't be changed).";
 
1867
  }
 
1868
 
 
1869
  /**
 
1870
   * Get the value of tolerance parameter.
 
1871
   * @return Value of tolerance parameter.
 
1872
   */
 
1873
  public double getToleranceParameter() {
 
1874
 
 
1875
    return m_tol;
 
1876
  }
 
1877
 
 
1878
  /**
 
1879
   * Set the value of tolerance parameter.
 
1880
   * @param v  Value to assign to tolerance parameter.
 
1881
   */
 
1882
  public void setToleranceParameter(double v) {
 
1883
 
 
1884
    m_tol = v;
 
1885
  }
 
1886
 
 
1887
  /**
 
1888
   * Returns the tip text for this property
 
1889
   * @return tip text for this property suitable for
 
1890
   * displaying in the explorer/experimenter gui
 
1891
   */
 
1892
  public String epsilonTipText() {
 
1893
    return "The epsilon for round-off error (shouldn't be changed).";
 
1894
  }
 
1895
 
 
1896
  /**
 
1897
   * Get the value of epsilon.
 
1898
   * @return Value of epsilon.
 
1899
   */
 
1900
  public double getEpsilon() {
 
1901
 
 
1902
    return m_eps;
 
1903
  }
 
1904
 
 
1905
  /**
 
1906
   * Set the value of epsilon.
 
1907
   * @param v  Value to assign to epsilon.
 
1908
   */
 
1909
  public void setEpsilon(double v) {
 
1910
 
 
1911
    m_eps = v;
 
1912
  }
 
1913
 
 
1914
  /**
 
1915
   * Returns the tip text for this property
 
1916
   * @return tip text for this property suitable for
 
1917
   * displaying in the explorer/experimenter gui
 
1918
   */
 
1919
  public String filterTypeTipText() {
 
1920
    return "Determines how/if the data will be transformed.";
 
1921
  }
 
1922
 
 
1923
  /**
 
1924
   * Gets how the training data will be transformed. Will be one of
 
1925
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
 
1926
   *
 
1927
   * @return the filtering mode
 
1928
   */
 
1929
  public SelectedTag getFilterType() {
 
1930
 
 
1931
    return new SelectedTag(m_filterType, TAGS_FILTER);
 
1932
  }
 
1933
 
 
1934
  /**
 
1935
   * Sets how the training data will be transformed. Should be one of
 
1936
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
 
1937
   *
 
1938
   * @param newType the new filtering mode
 
1939
   */
 
1940
  public void setFilterType(SelectedTag newType) {
 
1941
 
 
1942
    if (newType.getTags() == TAGS_FILTER) {
 
1943
      m_filterType = newType.getSelectedTag().getID();
 
1944
    }
 
1945
  }
 
1946
 
 
1947
  /**
 
1948
   * Returns the tip text for this property
 
1949
   *
 
1950
   * @return tip text for this property suitable for
 
1951
   * displaying in the explorer/experimenter gui
 
1952
   */
 
1953
  public String minimaxTipText() {
 
1954
    return "Whether the MIMinimax feature space is to be used.";
 
1955
  }
 
1956
 
 
1957
  /**
 
1958
   * Check if the MIMinimax feature space is to be used.
 
1959
   * @return true if minimax
 
1960
   */
 
1961
  public boolean getMinimax() {
 
1962
 
 
1963
    return m_minimax;
 
1964
  }
 
1965
 
 
1966
  /**
 
1967
   * Set if the MIMinimax feature space is to be used.
 
1968
   * @param v  true if RBF
 
1969
   */
 
1970
  public void setMinimax(boolean v) {
 
1971
    m_minimax = v;
 
1972
  }
 
1973
 
 
1974
  /**
 
1975
   * Returns the tip text for this property
 
1976
   * @return tip text for this property suitable for
 
1977
   * displaying in the explorer/experimenter gui
 
1978
   */
 
1979
  public String buildLogisticModelsTipText() {
 
1980
    return "Whether to fit logistic models to the outputs (for proper "
 
1981
      + "probability estimates).";
 
1982
  }
 
1983
 
 
1984
  /**
 
1985
   * Get the value of buildLogisticModels.
 
1986
   *
 
1987
   * @return Value of buildLogisticModels.
 
1988
   */
 
1989
  public boolean getBuildLogisticModels() {
 
1990
 
 
1991
    return m_fitLogisticModels;
 
1992
  }
 
1993
 
 
1994
  /**
 
1995
   * Set the value of buildLogisticModels.
 
1996
   *
 
1997
   * @param newbuildLogisticModels Value to assign to buildLogisticModels.
 
1998
   */
 
1999
  public void setBuildLogisticModels(boolean newbuildLogisticModels) {
 
2000
 
 
2001
    m_fitLogisticModels = newbuildLogisticModels;
 
2002
  }
 
2003
 
 
2004
  /**
 
2005
   * Returns the tip text for this property
 
2006
   * @return tip text for this property suitable for
 
2007
   * displaying in the explorer/experimenter gui
 
2008
   */
 
2009
  public String numFoldsTipText() {
 
2010
    return "The number of folds for cross-validation used to generate "
 
2011
      + "training data for logistic models (-1 means use training data).";
 
2012
  }
 
2013
 
 
2014
  /**
 
2015
   * Get the value of numFolds.
 
2016
   *
 
2017
   * @return Value of numFolds.
 
2018
   */
 
2019
  public int getNumFolds() {
 
2020
 
 
2021
    return m_numFolds;
 
2022
  }
 
2023
 
 
2024
  /**
 
2025
   * Set the value of numFolds.
 
2026
   *
 
2027
   * @param newnumFolds Value to assign to numFolds.
 
2028
   */
 
2029
  public void setNumFolds(int newnumFolds) {
 
2030
 
 
2031
    m_numFolds = newnumFolds;
 
2032
  }
 
2033
 
 
2034
  /**
 
2035
   * Returns the tip text for this property
 
2036
   * @return tip text for this property suitable for
 
2037
   * displaying in the explorer/experimenter gui
 
2038
   */
 
2039
  public String randomSeedTipText() {
 
2040
    return "Random number seed for the cross-validation.";
 
2041
  }
 
2042
 
 
2043
  /**
 
2044
   * Get the value of randomSeed.
 
2045
   *
 
2046
   * @return Value of randomSeed.
 
2047
   */
 
2048
  public int getRandomSeed() {
 
2049
 
 
2050
    return m_randomSeed;
 
2051
  }
 
2052
 
 
2053
  /**
 
2054
   * Set the value of randomSeed.
 
2055
   *
 
2056
   * @param newrandomSeed Value to assign to randomSeed.
 
2057
   */
 
2058
  public void setRandomSeed(int newrandomSeed) {
 
2059
 
 
2060
    m_randomSeed = newrandomSeed;
 
2061
  }
 
2062
 
 
2063
  /**
 
2064
   * Prints out the classifier.
 
2065
   *
 
2066
   * @return a description of the classifier as a string
 
2067
   */
 
2068
  public String toString() {
 
2069
 
 
2070
    StringBuffer text = new StringBuffer();
 
2071
 
 
2072
    if ((m_classAttribute == null)) {
 
2073
      return "SMO: No model built yet.";
 
2074
    }
 
2075
    try {
 
2076
      text.append("SMO\n\n");
 
2077
      for (int i = 0; i < m_classAttribute.numValues(); i++) {
 
2078
        for (int j = i + 1; j < m_classAttribute.numValues(); j++) {
 
2079
          text.append("Classifier for classes: " + 
 
2080
              m_classAttribute.value(i) + ", " +
 
2081
              m_classAttribute.value(j) + "\n\n");
 
2082
          text.append(m_classifiers[i][j]);
 
2083
          if (m_fitLogisticModels) {
 
2084
            text.append("\n\n");
 
2085
            if ( m_classifiers[i][j].m_logistic == null) {
 
2086
              text.append("No logistic model has been fit.\n");
 
2087
            } else {
 
2088
              text.append(m_classifiers[i][j].m_logistic);
 
2089
            }
 
2090
          }
 
2091
          text.append("\n\n");
 
2092
        }
 
2093
      }
 
2094
    } catch (Exception e) {
 
2095
      return "Can't print SMO classifier.";
 
2096
    }
 
2097
 
 
2098
    return text.toString();
 
2099
  }
 
2100
 
 
2101
  /**
 
2102
   * Main method for testing this class.
 
2103
   * 
 
2104
   * @param argv the commandline parameters
 
2105
   */
 
2106
  public static void main(String[] argv) {
 
2107
    runClassifier(new MISMO(), argv);
 
2108
  }
 
2109
}