~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/CostMatrix.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    CostMatrix.java
 
19
 *    Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.classifiers;
 
24
 
 
25
import weka.core.AttributeExpression;
 
26
import weka.core.Instance;
 
27
import weka.core.Instances;
 
28
import weka.core.Matrix;
 
29
import weka.core.Utils;
 
30
 
 
31
import java.io.LineNumberReader;
 
32
import java.io.Reader;
 
33
import java.io.Serializable;
 
34
import java.io.StreamTokenizer;
 
35
import java.io.Writer;
 
36
import java.util.Random;
 
37
import java.util.StringTokenizer;
 
38
 
 
39
/**
 
40
 * Class for storing and manipulating a misclassification cost matrix.
 
41
 * The element at position i,j in the matrix is the penalty for classifying
 
42
 * an instance of class j as class i. Cost values can be fixed or
 
43
 * computed on a per-instance basis (cost sensitive evaluation only) 
 
44
 * from the value of an attribute or an expression involving 
 
45
 * attribute(s).
 
46
 *
 
47
 * @author Mark Hall
 
48
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 
49
 * @version $Revision: 1.15 $
 
50
 */
 
51
public class CostMatrix implements Serializable {
 
52
 
 
53
  /** for serialization */
 
54
  private static final long serialVersionUID = -1973792250544554965L;
 
55
  
 
56
  private int m_size;
 
57
 
 
58
  /** [rows][columns] */
 
59
  protected Object [][] m_matrix;
 
60
 
 
61
  /** The deafult file extension for cost matrix files */
 
62
  public static String FILE_EXTENSION = ".cost";
 
63
 
 
64
  /**
 
65
   * Creates a default cost matrix of a particular size. 
 
66
   * All diagonal values will be 0 and all non-diagonal values 1.
 
67
   *
 
68
   * @param numOfClasses the number of classes that the cost matrix holds.
 
69
   */  
 
70
  public CostMatrix(int numOfClasses) {
 
71
    m_size = numOfClasses;
 
72
    initialize();
 
73
  }
 
74
 
 
75
  /**
 
76
   * Creates a cost matrix that is a copy of another.
 
77
   *
 
78
   * @param toCopy the matrix to copy.
 
79
   */
 
80
  public CostMatrix(CostMatrix toCopy) {
 
81
    this(toCopy.size());
 
82
 
 
83
    for (int i = 0; i < m_size; i++) {
 
84
      for (int j = 0; j < m_size; j++) {
 
85
        setCell(i, j, toCopy.getCell(i, j));
 
86
      }
 
87
    }
 
88
  }
 
89
 
 
90
  /**
 
91
   * Initializes the matrix
 
92
   */
 
93
  public void initialize() {
 
94
    m_matrix = new Object[m_size][m_size];
 
95
    for (int i = 0; i < m_size; i++) {
 
96
      for (int j = 0; j < m_size; j++) {
 
97
        setCell(i, j, i == j ? new Double(0.0) : new Double(1.0));
 
98
      }
 
99
    }
 
100
  }
 
101
 
 
102
  /**
 
103
   * The number of rows (and columns)
 
104
   * @return the size of the matrix
 
105
   */
 
106
  public int size() {
 
107
    return m_size;
 
108
  }
 
109
 
 
110
  /**
 
111
   * Same as size
 
112
   * @return the number of columns
 
113
   */
 
114
  public int numColumns() {
 
115
    return size();
 
116
  }
 
117
 
 
118
  /**
 
119
   * Same as size
 
120
   * @return the number of rows
 
121
   */
 
122
  public int numRows() {
 
123
    return size();
 
124
  }
 
125
 
 
126
  private boolean replaceStrings() throws Exception {
 
127
    boolean nonDouble = false;
 
128
 
 
129
    for (int i = 0; i < m_size; i++) {
 
130
      for (int j = 0; j < m_size; j++) {
 
131
        if (getCell(i, j) instanceof String) {
 
132
          AttributeExpression temp = new AttributeExpression();
 
133
          temp.convertInfixToPostfix((String)getCell(i, j));
 
134
          setCell(i, j, temp);
 
135
          nonDouble = true;
 
136
        } else if (getCell(i, j) instanceof AttributeExpression) {
 
137
          nonDouble = true;
 
138
        }
 
139
      }
 
140
    }
 
141
    
 
142
    return nonDouble;
 
143
  }
 
144
 
 
145
  /**
 
146
   * Applies the cost matrix to a set of instances. If a random number generator is 
 
147
   * supplied the instances will be resampled, otherwise they will be rewighted. 
 
148
   * Adapted from code once sitting in Instances.java
 
149
   *
 
150
   * @param data the instances to reweight.
 
151
   * @param random a random number generator for resampling, if null then instances are
 
152
   * rewighted.
 
153
   * @return a new dataset reflecting the cost of misclassification.
 
154
   * @exception Exception if the data has no class or the matrix in inappropriate.
 
155
   */
 
156
  public Instances applyCostMatrix(Instances data, Random random)
 
157
    throws Exception {
 
158
    
 
159
    if (replaceStrings()) {
 
160
      // could reweight in the two class case
 
161
      throw new Exception("Can't resample/reweight instances using "
 
162
                          +"non-fixed cost values!");
 
163
    }
 
164
 
 
165
    double sumOfWeightFactors = 0, sumOfMissClassWeights,
 
166
      sumOfWeights;
 
167
    double [] weightOfInstancesInClass, weightFactor, weightOfInstances;
 
168
    Instances newData;
 
169
    
 
170
    if (data.classIndex() < 0) {
 
171
      throw new Exception("Class index is not set!");
 
172
    }
 
173
 
 
174
    if (size() != data.numClasses()) { 
 
175
      throw new Exception("Misclassification cost matrix has "+
 
176
                          "wrong format!");
 
177
    }
 
178
 
 
179
    weightFactor = new double[data.numClasses()];
 
180
    weightOfInstancesInClass = new double[data.numClasses()];
 
181
    for (int j = 0; j < data.numInstances(); j++) {
 
182
      weightOfInstancesInClass[(int)data.instance(j).classValue()] += 
 
183
        data.instance(j).weight();
 
184
    }
 
185
    sumOfWeights = Utils.sum(weightOfInstancesInClass);
 
186
 
 
187
    // normalize the matrix if not already
 
188
    for (int i=0; i< m_size; i++) {
 
189
      if (!Utils.eq(((Double)getCell(i, i)).doubleValue(), 0)) {
 
190
        CostMatrix normMatrix = new CostMatrix(this);
 
191
        normMatrix.normalize();
 
192
        return normMatrix.applyCostMatrix(data, random);
 
193
      }
 
194
    }
 
195
 
 
196
    for (int i = 0; i < data.numClasses(); i++) {      
 
197
      // Using Kai Ming Ting's formula for deriving weights for 
 
198
      // the classes and Breiman's heuristic for multiclass 
 
199
      // problems.
 
200
 
 
201
      sumOfMissClassWeights = 0;
 
202
      for (int j = 0; j < data.numClasses(); j++) {
 
203
        if (Utils.sm(((Double)getCell(i,j)).doubleValue(),0)) {
 
204
          throw new Exception("Neg. weights in misclassification "+
 
205
                              "cost matrix!"); 
 
206
        }
 
207
        sumOfMissClassWeights 
 
208
          += ((Double)getCell(i,j)).doubleValue();
 
209
      }
 
210
      weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
 
211
      sumOfWeightFactors += sumOfMissClassWeights * 
 
212
        weightOfInstancesInClass[i];
 
213
    }
 
214
    for (int i = 0; i < data.numClasses(); i++) {
 
215
      weightFactor[i] /= sumOfWeightFactors;
 
216
    }
 
217
 
 
218
    // Store new weights
 
219
    weightOfInstances = new double[data.numInstances()];
 
220
    for (int i = 0; i < data.numInstances(); i++) {
 
221
      weightOfInstances[i] = data.instance(i).weight()*
 
222
        weightFactor[(int)data.instance(i).classValue()];
 
223
    }
 
224
 
 
225
    // Change instances weight or do resampling
 
226
    if (random != null) {
 
227
      return data.resampleWithWeights(random, weightOfInstances);
 
228
    } else { 
 
229
      Instances instances = new Instances(data);
 
230
      for (int i = 0; i < data.numInstances(); i++) {
 
231
        instances.instance(i).setWeight(weightOfInstances[i]);
 
232
      }
 
233
      return instances;
 
234
    }
 
235
  }
 
236
 
 
237
  /**
 
238
   * Calculates the expected misclassification cost for each possible class value,
 
239
   * given class probability estimates. 
 
240
   *
 
241
   * @param classProbs the class probability estimates.
 
242
   * @return the expected costs.
 
243
   * @exception Exception if the wrong number of class probabilities is supplied.
 
244
   */
 
245
  public double[] expectedCosts(double[] classProbs) throws Exception {
 
246
 
 
247
    if (classProbs.length != m_size) { 
 
248
      throw new Exception("Length of probability estimates don't "
 
249
                          +"match cost matrix");
 
250
    }
 
251
 
 
252
    double[] costs = new double[m_size];
 
253
 
 
254
    for (int x = 0; x < m_size; x++) {
 
255
      for (int y = 0; y < m_size; y++) {
 
256
        Object element = getCell(y, x);
 
257
        if (!(element instanceof Double)) {
 
258
          throw new Exception("Can't use non-fixed costs in "
 
259
                              +"computing expected costs.");
 
260
        }
 
261
        costs[x] += classProbs[y] * ((Double)element).doubleValue();
 
262
      }
 
263
    }
 
264
 
 
265
    return costs;
 
266
  }
 
267
 
 
268
  /**
 
269
   * Calculates the expected misclassification cost for each possible class value,
 
270
   * given class probability estimates. 
 
271
   *
 
272
   * @param classProbs the class probability estimates.
 
273
   * @param inst the current instance for which the class probabilites
 
274
   * apply. Is used for computing any non-fixed cost values.
 
275
   * @return the expected costs.
 
276
   * @exception Exception if something goes wrong
 
277
   */
 
278
  public double[] expectedCosts(double [] classProbs, 
 
279
                                Instance inst) throws Exception {
 
280
 
 
281
    if (classProbs.length != m_size) { 
 
282
      throw new Exception("Length of probability estimates don't "
 
283
                          +"match cost matrix");
 
284
    }
 
285
 
 
286
    if (!replaceStrings()) {
 
287
      return expectedCosts(classProbs);
 
288
    }
 
289
    
 
290
    double[] costs = new double[m_size];
 
291
 
 
292
    for (int x = 0; x < m_size; x++) {
 
293
      for (int y = 0; y < m_size; y++) {
 
294
        Object element = getCell(y, x);
 
295
        double costVal;
 
296
        if (!(element instanceof Double)) {
 
297
          costVal = 
 
298
            ((AttributeExpression)element).evaluateExpression(inst);
 
299
        } else {
 
300
          costVal = ((Double)element).doubleValue();
 
301
        }
 
302
        costs[x] += classProbs[y] * costVal;
 
303
      }
 
304
    }
 
305
 
 
306
    return costs;
 
307
  }
 
308
 
 
309
  /**
 
310
   * Gets the maximum cost for a particular class value.
 
311
   *
 
312
   * @param classVal the class value.
 
313
   * @return the maximum cost.
 
314
   * @exception Exception if cost matrix contains non-fixed
 
315
   * costs
 
316
   */
 
317
  public double getMaxCost(int classVal) throws Exception {
 
318
 
 
319
    double maxCost = Double.NEGATIVE_INFINITY;
 
320
 
 
321
    for (int i = 0; i < m_size; i++) {
 
322
      Object element = getCell(classVal, i);
 
323
      if (!(element instanceof Double)) {
 
324
          throw new Exception("Can't use non-fixed costs when "
 
325
                              +"getting max cost.");
 
326
      }
 
327
      double cost = ((Double)element).doubleValue();
 
328
      if (cost > maxCost) maxCost = cost;
 
329
    }
 
330
 
 
331
    return maxCost;
 
332
  }
 
333
 
 
334
  /**
 
335
   * Gets the maximum cost for a particular class value.
 
336
   *
 
337
   * @param classVal the class value.
 
338
   * @return the maximum cost.
 
339
   * @exception Exception if cost matrix contains non-fixed
 
340
   * costs
 
341
   */
 
342
  public double getMaxCost(int classVal, Instance inst) 
 
343
    throws Exception {
 
344
 
 
345
    if (!replaceStrings()) {
 
346
      return getMaxCost(classVal);
 
347
    }
 
348
 
 
349
    double maxCost = Double.NEGATIVE_INFINITY;
 
350
    double cost;
 
351
    for (int i = 0; i < m_size; i++) {
 
352
      Object element = getCell(classVal, i);
 
353
      if (!(element instanceof Double)) {
 
354
        cost = 
 
355
          ((AttributeExpression)element).evaluateExpression(inst);
 
356
      } else {
 
357
        cost = ((Double)element).doubleValue();
 
358
      }
 
359
      if (cost > maxCost) maxCost = cost;
 
360
    }
 
361
 
 
362
    return maxCost;
 
363
  }
 
364
 
 
365
 
 
366
  /**
 
367
   * Normalizes the matrix so that the diagonal contains zeros.
 
368
   *
 
369
   */
 
370
  public void normalize() {
 
371
 
 
372
    for (int y=0; y<m_size; y++) {
 
373
      double diag = ((Double)getCell(y, y)).doubleValue();
 
374
      for (int x=0; x<m_size; x++) {
 
375
        setCell(x, y, new Double(((Double)getCell(x, y)).
 
376
                                    doubleValue() - diag));
 
377
      }
 
378
    }
 
379
  }
 
380
 
 
381
  /**
 
382
   * Loads a cost matrix in the old format from a reader. Adapted from code once sitting 
 
383
   * in Instances.java
 
384
   *
 
385
   * @param reader the reader to get the values from.
 
386
   * @exception Exception if the matrix cannot be read correctly.
 
387
   */  
 
388
  public void readOldFormat(Reader reader) throws Exception {
 
389
 
 
390
    StreamTokenizer tokenizer;
 
391
    int currentToken;
 
392
    double firstIndex, secondIndex, weight;
 
393
 
 
394
    tokenizer = new StreamTokenizer(reader);
 
395
 
 
396
    initialize();
 
397
 
 
398
    tokenizer.commentChar('%');
 
399
    tokenizer.eolIsSignificant(true);
 
400
    while (StreamTokenizer.TT_EOF != 
 
401
           (currentToken = tokenizer.nextToken())) {
 
402
 
 
403
      // Skip empty lines 
 
404
      if (currentToken == StreamTokenizer.TT_EOL) {
 
405
        continue;
 
406
      }
 
407
 
 
408
      // Get index of first class.
 
409
      if (currentToken != StreamTokenizer.TT_NUMBER) {
 
410
        throw new Exception("Only numbers and comments allowed "+
 
411
                            "in cost file!");
 
412
      }
 
413
      firstIndex = tokenizer.nval;
 
414
      if (!Utils.eq((double)(int)firstIndex,firstIndex)) {
 
415
        throw new Exception("First number in line has to be "+
 
416
                            "index of a class!");
 
417
      }
 
418
      if ((int)firstIndex >= size()) {
 
419
        throw new Exception("Class index out of range!");
 
420
      }
 
421
 
 
422
      // Get index of second class.
 
423
      if (StreamTokenizer.TT_EOF == 
 
424
          (currentToken = tokenizer.nextToken())) {
 
425
        throw new Exception("Premature end of file!");
 
426
      }
 
427
      if (currentToken == StreamTokenizer.TT_EOL) {
 
428
        throw new Exception("Premature end of line!");
 
429
      }
 
430
      if (currentToken != StreamTokenizer.TT_NUMBER) {
 
431
        throw new Exception("Only numbers and comments allowed "+
 
432
                            "in cost file!");
 
433
      }
 
434
      secondIndex = tokenizer.nval;
 
435
      if (!Utils.eq((double)(int)secondIndex,secondIndex)) {
 
436
        throw new Exception("Second number in line has to be "+
 
437
                            "index of a class!");
 
438
      }
 
439
      if ((int)secondIndex >= size()) {
 
440
        throw new Exception("Class index out of range!");
 
441
      }
 
442
      if ((int)secondIndex == (int)firstIndex) {
 
443
        throw new Exception("Diagonal of cost matrix non-zero!");
 
444
      }
 
445
 
 
446
      // Get cost factor.
 
447
      if (StreamTokenizer.TT_EOF == 
 
448
          (currentToken = tokenizer.nextToken())) {
 
449
        throw new Exception("Premature end of file!");
 
450
      }
 
451
      if (currentToken == StreamTokenizer.TT_EOL) {
 
452
        throw new Exception("Premature end of line!");
 
453
      }
 
454
      if (currentToken != StreamTokenizer.TT_NUMBER) {
 
455
        throw new Exception("Only numbers and comments allowed "+
 
456
                            "in cost file!");
 
457
      }
 
458
      weight = tokenizer.nval;
 
459
      if (!Utils.gr(weight,0)) {
 
460
        throw new Exception("Only positive weights allowed!");
 
461
      }
 
462
      setCell((int)firstIndex, (int)secondIndex, 
 
463
                 new Double(weight));
 
464
    }
 
465
  }
 
466
 
 
467
  /**
 
468
   * Reads a matrix from a reader. The first line in the file should
 
469
   * contain the number of rows and columns. Subsequent lines
 
470
   * contain elements of the matrix. 
 
471
   * (FracPete: taken from old weka.core.Matrix class)
 
472
   *
 
473
   * @param     reader the reader containing the matrix
 
474
   * @throws    Exception if an error occurs
 
475
   * @see       #write(Writer)
 
476
   */
 
477
  public CostMatrix(Reader reader) throws Exception {
 
478
    LineNumberReader lnr = new LineNumberReader(reader);
 
479
    String line;
 
480
    int currentRow = -1;
 
481
 
 
482
    while ((line = lnr.readLine()) != null) {
 
483
 
 
484
      // Comments
 
485
      if (line.startsWith("%")) {  
 
486
        continue;
 
487
      }
 
488
      
 
489
      StringTokenizer st = new StringTokenizer(line);
 
490
      // Ignore blank lines
 
491
      if (!st.hasMoreTokens()) {
 
492
        continue;
 
493
      }
 
494
 
 
495
      if (currentRow < 0) {
 
496
        int rows = Integer.parseInt(st.nextToken());
 
497
        if (!st.hasMoreTokens()) {
 
498
          throw new Exception("Line " + lnr.getLineNumber() 
 
499
              + ": expected number of columns");
 
500
        }
 
501
 
 
502
        int cols = Integer.parseInt(st.nextToken());
 
503
        if (rows != cols) {
 
504
          throw new Exception("Trying to create a non-square cost "
 
505
                              +"matrix");
 
506
        }
 
507
        //        m_matrix = new Object[rows][cols];
 
508
        m_size = rows;
 
509
        initialize();
 
510
        currentRow++;
 
511
        continue;
 
512
 
 
513
      } else {
 
514
        if (currentRow == m_size) {
 
515
          throw new Exception("Line " + lnr.getLineNumber() 
 
516
              + ": too many rows provided");
 
517
        }
 
518
 
 
519
        for (int i = 0; i < m_size; i++) {
 
520
          if (!st.hasMoreTokens()) {
 
521
            throw new Exception("Line " + lnr.getLineNumber() 
 
522
                + ": too few matrix elements provided");
 
523
          }
 
524
 
 
525
          String nextTok = st.nextToken();
 
526
          // try to parse as a double first
 
527
          Double val = null;
 
528
          try {
 
529
            val = new Double(nextTok);
 
530
            double value = val.doubleValue();
 
531
          } catch (Exception ex) {
 
532
            val = null;
 
533
          }
 
534
          if (val == null) {
 
535
            setCell(currentRow, i, nextTok);
 
536
          } else {
 
537
            setCell(currentRow, i, val);
 
538
          }
 
539
        }
 
540
        currentRow++;
 
541
      }
 
542
    }
 
543
    
 
544
    if (currentRow == -1) {
 
545
      throw new Exception("Line " + lnr.getLineNumber() 
 
546
                          + ": expected number of rows");
 
547
    } else if (currentRow != m_size) {
 
548
      throw new Exception("Line " + lnr.getLineNumber() 
 
549
                          + ": too few rows provided");
 
550
    }
 
551
  }
 
552
 
 
553
  /**
 
554
   * Writes out a matrix. The format can be read via the 
 
555
   * CostMatrix(Reader) constructor.
 
556
   * (FracPete: taken from old weka.core.Matrix class)
 
557
   *
 
558
   * @param     w the output Writer
 
559
   * @throws    Exception if an error occurs
 
560
   */
 
561
  public void write(Writer w) throws Exception {
 
562
    w.write("% Rows\tColumns\n");
 
563
    w.write("" + m_size + "\t" + m_size + "\n");
 
564
    w.write("% Matrix elements\n");
 
565
    for(int i = 0; i < m_size; i++) {
 
566
      for(int j = 0; j < m_size; j++) {
 
567
        w.write("" + getCell(i, j) + "\t");
 
568
      }
 
569
      w.write("\n");
 
570
    }
 
571
    w.flush();
 
572
  }
 
573
 
 
574
  /**
 
575
   * converts the Matrix into a single line Matlab string: matrix is enclosed 
 
576
   * by parentheses, rows are separated by semicolon and single cells by
 
577
   * blanks, e.g., [1 2; 3 4].
 
578
   * @return      the matrix in Matlab single line format
 
579
   */
 
580
  public String toMatlab() {
 
581
    StringBuffer      result;
 
582
    int               i;
 
583
    int               n;
 
584
 
 
585
    result = new StringBuffer();
 
586
 
 
587
    result.append("[");
 
588
 
 
589
    for (i = 0; i < m_size; i++) {
 
590
      if (i > 0) {
 
591
        result.append("; ");
 
592
      }
 
593
      
 
594
      for (n = 0; n < m_size; n++) {
 
595
        if (n > 0) {
 
596
          result.append(" ");
 
597
        }
 
598
        result.append(getCell(i, n));
 
599
      }
 
600
    }
 
601
    
 
602
    result.append("]");
 
603
 
 
604
    return result.toString();
 
605
  }
 
606
 
 
607
  /**
 
608
   * Set the value of a particular cell in the matrix
 
609
   *
 
610
   * @param rowIndex the row
 
611
   * @param columnIndex the column
 
612
   * @param value the value to set
 
613
   */
 
614
  public final void setCell(int rowIndex, int columnIndex,
 
615
                               Object value) {
 
616
    m_matrix[rowIndex][columnIndex] = value;
 
617
  }
 
618
 
 
619
  /**
 
620
   * Return the contents of a particular cell. Note: this
 
621
   * method returns the Object stored at a particular cell.
 
622
   * 
 
623
   * @param rowIndex the row
 
624
   * @param columnIndex the column
 
625
   * @return the value at the cell
 
626
   */
 
627
  public final Object getCell(int rowIndex, int columnIndex) {
 
628
    return m_matrix[rowIndex][columnIndex];
 
629
  }
 
630
 
 
631
  /**
 
632
   * Return the value of a cell as a double (for legacy code)
 
633
   *
 
634
   * @param rowIndex the row
 
635
   * @param columnIndex the column
 
636
   * @return the value at a particular cell as a double
 
637
   * @exception Exception if the value is not a double
 
638
   */
 
639
  public final double getElement(int rowIndex, int columnIndex)
 
640
    throws Exception {
 
641
    if (!(m_matrix[rowIndex][columnIndex] instanceof Double)) {
 
642
      throw new Exception("Cost matrix contains non-fixed costs!");
 
643
    }
 
644
    return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue();
 
645
  }
 
646
 
 
647
  /**
 
648
   * Return the value of a cell as a double. Computes the
 
649
   * value for non-fixed costs using the supplied Instance
 
650
   *
 
651
   * @param rowIndex the row
 
652
   * @param columnIndex the column
 
653
   * @return the value from a particular cell
 
654
   * @exception Exception if something goes wrong
 
655
   */
 
656
  public final double getElement(int rowIndex, int columnIndex,
 
657
                                 Instance inst) throws Exception {
 
658
 
 
659
    if (m_matrix[rowIndex][columnIndex] instanceof Double) {
 
660
      return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue();
 
661
    } else if (m_matrix[rowIndex][columnIndex] instanceof String) {
 
662
      replaceStrings();
 
663
    }
 
664
 
 
665
    return ((AttributeExpression)m_matrix[rowIndex][columnIndex]).
 
666
      evaluateExpression(inst);
 
667
  }
 
668
 
 
669
  /**
 
670
   * Set the value of a cell as a double
 
671
   *
 
672
   * @param rowIndex the row
 
673
   * @param columnIndex the column
 
674
   * @param value the value (double) to set
 
675
   */
 
676
  public final void setElement(int rowIndex, int columnIndex,
 
677
                               double value) {
 
678
    m_matrix[rowIndex][columnIndex] = new Double(value);
 
679
  }
 
680
 
 
681
  /**
 
682
   * creates a matrix from the given Matlab string.
 
683
   * @param matlab  the matrix in matlab format
 
684
   * @return        the matrix represented by the given string
 
685
   */
 
686
  public static Matrix parseMatlab(String matlab) throws Exception {
 
687
    return Matrix.parseMatlab(matlab);
 
688
  }
 
689
 
 
690
  /** 
 
691
   * Converts a matrix to a string.
 
692
   * (FracPete: taken from old weka.core.Matrix class)
 
693
   *
 
694
   * @return    the converted string
 
695
   */
 
696
  public String toString() {
 
697
    // Determine the width required for the maximum element,
 
698
    // and check for fractional display requirement.
 
699
    double maxval = 0;
 
700
    boolean fractional = false;
 
701
    Object element = null;
 
702
    int widthNumber = 0;
 
703
    int widthExpression = 0;
 
704
    for (int i = 0; i < size(); i++) {
 
705
      for (int j = 0; j < size(); j++) {
 
706
        element = getCell(i, j);
 
707
        if (element instanceof Double) {
 
708
          double current = ((Double)element).doubleValue();
 
709
       
 
710
          if (current < 0)
 
711
            current *= -11;
 
712
          if (current > maxval)
 
713
            maxval = current;
 
714
          double fract = Math.abs(current - Math.rint(current));
 
715
          if (!fractional
 
716
              && ((Math.log(fract) / Math.log(10)) >= -2)) {
 
717
            fractional = true;
 
718
          }
 
719
        } else {
 
720
          if (element.toString().length() > widthExpression) {
 
721
            widthExpression = element.toString().length();
 
722
          }
 
723
        }
 
724
      }
 
725
    }
 
726
    if (maxval > 0) {
 
727
      widthNumber = (int)(Math.log(maxval) / Math.log(10) 
 
728
                          + (fractional ? 4 : 1));
 
729
    }
 
730
 
 
731
    int width = (widthNumber > widthExpression)
 
732
      ? widthNumber
 
733
      : widthExpression;
 
734
 
 
735
    StringBuffer text = new StringBuffer();   
 
736
    for (int i = 0; i < size(); i++) {
 
737
      for (int j = 0; j < size(); j++) {
 
738
        element = getCell(i, j);
 
739
        if (element instanceof Double) {
 
740
          text.append(" ").
 
741
            append(Utils.doubleToString(((Double)element).
 
742
                                        doubleValue(),
 
743
                                        width, (fractional ? 2 : 0)));
 
744
        } else {
 
745
          int diff = width - element.toString().length();
 
746
          if (diff > 0) {
 
747
            int left = diff % 2;
 
748
            left += diff / 2;
 
749
            String temp = Utils.padLeft(element.toString(),
 
750
                            element.toString().length()+left);
 
751
            temp = Utils.padRight(temp, width);
 
752
            text.append(" ").append(temp);
 
753
          } else {
 
754
            text.append(" ").
 
755
              append(element.toString());
 
756
          }
 
757
        }
 
758
      }
 
759
      text.append("\n");
 
760
    }
 
761
 
 
762
    return text.toString();
 
763
  } 
 
764
}