2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
23
package weka.classifiers;
25
import weka.core.AttributeExpression;
26
import weka.core.Instance;
27
import weka.core.Instances;
28
import weka.core.Matrix;
29
import weka.core.Utils;
31
import java.io.LineNumberReader;
32
import java.io.Reader;
33
import java.io.Serializable;
34
import java.io.StreamTokenizer;
35
import java.io.Writer;
36
import java.util.Random;
37
import java.util.StringTokenizer;
40
* Class for storing and manipulating a misclassification cost matrix.
41
* The element at position i,j in the matrix is the penalty for classifying
42
* an instance of (actual) class i as class j. Cost values can be fixed or
43
* computed on a per-instance basis (cost sensitive evaluation only)
44
* from the value of an attribute or an expression involving
48
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
49
* @version $Revision: 1.15 $
51
public class CostMatrix implements Serializable {
53
/** for serialization */
54
private static final long serialVersionUID = -1973792250544554965L;
58
/** [rows][columns] */
59
protected Object [][] m_matrix;
61
/** The default file extension for cost matrix files */
62
public static String FILE_EXTENSION = ".cost";
65
* Creates a default cost matrix of a particular size.
66
* All diagonal values will be 0 and all non-diagonal values 1.
68
* @param numOfClasses the number of classes that the cost matrix holds.
70
public CostMatrix(int numOfClasses) {
71
m_size = numOfClasses;
76
* Creates a cost matrix that is a copy of another.
78
* @param toCopy the matrix to copy.
80
public CostMatrix(CostMatrix toCopy) {
83
for (int i = 0; i < m_size; i++) {
84
for (int j = 0; j < m_size; j++) {
85
setCell(i, j, toCopy.getCell(i, j));
91
* Initializes the matrix with 0 on the diagonal and 1 everywhere else.
93
public void initialize() {
94
m_matrix = new Object[m_size][m_size];
95
for (int i = 0; i < m_size; i++) {
96
for (int j = 0; j < m_size; j++) {
97
setCell(i, j, i == j ? new Double(0.0) : new Double(1.0));
103
* The number of rows (and columns)
104
* @return the size of the matrix
112
* @return the number of columns
114
public int numColumns() {
120
* @return the number of rows
122
public int numRows() {
126
private boolean replaceStrings() throws Exception {
127
boolean nonDouble = false;
129
for (int i = 0; i < m_size; i++) {
130
for (int j = 0; j < m_size; j++) {
131
if (getCell(i, j) instanceof String) {
132
AttributeExpression temp = new AttributeExpression();
133
temp.convertInfixToPostfix((String)getCell(i, j));
136
} else if (getCell(i, j) instanceof AttributeExpression) {
146
* Applies the cost matrix to a set of instances. If a random number generator is
147
* supplied the instances will be resampled, otherwise they will be reweighted.
148
* Adapted from code once sitting in Instances.java
150
* @param data the instances to reweight.
151
* @param random a random number generator for resampling, if null then instances are
153
* @return a new dataset reflecting the cost of misclassification.
154
* @exception Exception if the data has no class or the matrix is inappropriate.
156
public Instances applyCostMatrix(Instances data, Random random)
159
if (replaceStrings()) {
160
// could reweight in the two class case
161
throw new Exception("Can't resample/reweight instances using "
162
+"non-fixed cost values!");
165
double sumOfWeightFactors = 0, sumOfMissClassWeights,
167
double [] weightOfInstancesInClass, weightFactor, weightOfInstances;
170
if (data.classIndex() < 0) {
171
throw new Exception("Class index is not set!");
174
if (size() != data.numClasses()) {
175
throw new Exception("Misclassification cost matrix has "+
179
weightFactor = new double[data.numClasses()];
180
weightOfInstancesInClass = new double[data.numClasses()];
181
for (int j = 0; j < data.numInstances(); j++) {
182
weightOfInstancesInClass[(int)data.instance(j).classValue()] +=
183
data.instance(j).weight();
185
sumOfWeights = Utils.sum(weightOfInstancesInClass);
187
// normalize the matrix if not already
188
for (int i=0; i< m_size; i++) {
189
if (!Utils.eq(((Double)getCell(i, i)).doubleValue(), 0)) {
190
CostMatrix normMatrix = new CostMatrix(this);
191
normMatrix.normalize();
192
return normMatrix.applyCostMatrix(data, random);
196
for (int i = 0; i < data.numClasses(); i++) {
197
// Using Kai Ming Ting's formula for deriving weights for
198
// the classes and Breiman's heuristic for multiclass
201
sumOfMissClassWeights = 0;
202
for (int j = 0; j < data.numClasses(); j++) {
203
if (Utils.sm(((Double)getCell(i,j)).doubleValue(),0)) {
204
throw new Exception("Neg. weights in misclassification "+
207
sumOfMissClassWeights
208
+= ((Double)getCell(i,j)).doubleValue();
210
weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
211
sumOfWeightFactors += sumOfMissClassWeights *
212
weightOfInstancesInClass[i];
214
for (int i = 0; i < data.numClasses(); i++) {
215
weightFactor[i] /= sumOfWeightFactors;
219
weightOfInstances = new double[data.numInstances()];
220
for (int i = 0; i < data.numInstances(); i++) {
221
weightOfInstances[i] = data.instance(i).weight()*
222
weightFactor[(int)data.instance(i).classValue()];
225
// Change instances weight or do resampling
226
if (random != null) {
227
return data.resampleWithWeights(random, weightOfInstances);
229
Instances instances = new Instances(data);
230
for (int i = 0; i < data.numInstances(); i++) {
231
instances.instance(i).setWeight(weightOfInstances[i]);
238
* Calculates the expected misclassification cost for each possible class value,
239
* given class probability estimates.
241
* @param classProbs the class probability estimates.
242
* @return the expected costs.
243
* @exception Exception if the wrong number of class probabilities is supplied.
245
public double[] expectedCosts(double[] classProbs) throws Exception {
247
if (classProbs.length != m_size) {
248
throw new Exception("Length of probability estimates don't "
249
+"match cost matrix");
252
double[] costs = new double[m_size];
254
for (int x = 0; x < m_size; x++) {
255
for (int y = 0; y < m_size; y++) {
256
Object element = getCell(y, x);
257
if (!(element instanceof Double)) {
258
throw new Exception("Can't use non-fixed costs in "
259
+"computing expected costs.");
261
costs[x] += classProbs[y] * ((Double)element).doubleValue();
269
* Calculates the expected misclassification cost for each possible class value,
270
* given class probability estimates.
272
* @param classProbs the class probability estimates.
273
* @param inst the current instance for which the class probabilities
274
* apply. Is used for computing any non-fixed cost values.
275
* @return the expected costs.
276
* @exception Exception if something goes wrong
278
public double[] expectedCosts(double [] classProbs,
279
Instance inst) throws Exception {
281
if (classProbs.length != m_size) {
282
throw new Exception("Length of probability estimates don't "
283
+"match cost matrix");
286
if (!replaceStrings()) {
287
return expectedCosts(classProbs);
290
double[] costs = new double[m_size];
292
for (int x = 0; x < m_size; x++) {
293
for (int y = 0; y < m_size; y++) {
294
Object element = getCell(y, x);
296
if (!(element instanceof Double)) {
298
((AttributeExpression)element).evaluateExpression(inst);
300
costVal = ((Double)element).doubleValue();
302
costs[x] += classProbs[y] * costVal;
310
* Gets the maximum cost for a particular class value.
312
* @param classVal the class value.
313
* @return the maximum cost.
314
* @exception Exception if cost matrix contains non-fixed
317
public double getMaxCost(int classVal) throws Exception {
319
double maxCost = Double.NEGATIVE_INFINITY;
321
for (int i = 0; i < m_size; i++) {
322
Object element = getCell(classVal, i);
323
if (!(element instanceof Double)) {
324
throw new Exception("Can't use non-fixed costs when "
325
+"getting max cost.");
327
double cost = ((Double)element).doubleValue();
328
if (cost > maxCost) maxCost = cost;
335
* Gets the maximum cost for a particular class value, evaluating any non-fixed costs against the supplied instance.
337
* @param classVal the class value.
338
* @return the maximum cost.
339
* @exception Exception if cost matrix contains non-fixed
342
public double getMaxCost(int classVal, Instance inst)
345
if (!replaceStrings()) {
346
return getMaxCost(classVal);
349
double maxCost = Double.NEGATIVE_INFINITY;
351
for (int i = 0; i < m_size; i++) {
352
Object element = getCell(classVal, i);
353
if (!(element instanceof Double)) {
355
((AttributeExpression)element).evaluateExpression(inst);
357
cost = ((Double)element).doubleValue();
359
if (cost > maxCost) maxCost = cost;
367
* Normalizes the matrix so that the diagonal contains zeros.
370
public void normalize() {
372
for (int y=0; y<m_size; y++) {
373
double diag = ((Double)getCell(y, y)).doubleValue();
374
for (int x=0; x<m_size; x++) {
375
setCell(x, y, new Double(((Double)getCell(x, y)).
376
doubleValue() - diag));
382
* Loads a cost matrix in the old format from a reader. Adapted from code once sitting
385
* @param reader the reader to get the values from.
386
* @exception Exception if the matrix cannot be read correctly.
388
public void readOldFormat(Reader reader) throws Exception {
390
StreamTokenizer tokenizer;
392
double firstIndex, secondIndex, weight;
394
tokenizer = new StreamTokenizer(reader);
398
tokenizer.commentChar('%');
399
tokenizer.eolIsSignificant(true);
400
while (StreamTokenizer.TT_EOF !=
401
(currentToken = tokenizer.nextToken())) {
404
if (currentToken == StreamTokenizer.TT_EOL) {
408
// Get index of first class.
409
if (currentToken != StreamTokenizer.TT_NUMBER) {
410
throw new Exception("Only numbers and comments allowed "+
413
firstIndex = tokenizer.nval;
414
if (!Utils.eq((double)(int)firstIndex,firstIndex)) {
415
throw new Exception("First number in line has to be "+
416
"index of a class!");
418
if ((int)firstIndex >= size()) {
419
throw new Exception("Class index out of range!");
422
// Get index of second class.
423
if (StreamTokenizer.TT_EOF ==
424
(currentToken = tokenizer.nextToken())) {
425
throw new Exception("Premature end of file!");
427
if (currentToken == StreamTokenizer.TT_EOL) {
428
throw new Exception("Premature end of line!");
430
if (currentToken != StreamTokenizer.TT_NUMBER) {
431
throw new Exception("Only numbers and comments allowed "+
434
secondIndex = tokenizer.nval;
435
if (!Utils.eq((double)(int)secondIndex,secondIndex)) {
436
throw new Exception("Second number in line has to be "+
437
"index of a class!");
439
if ((int)secondIndex >= size()) {
440
throw new Exception("Class index out of range!");
442
if ((int)secondIndex == (int)firstIndex) {
443
throw new Exception("Diagonal of cost matrix non-zero!");
447
if (StreamTokenizer.TT_EOF ==
448
(currentToken = tokenizer.nextToken())) {
449
throw new Exception("Premature end of file!");
451
if (currentToken == StreamTokenizer.TT_EOL) {
452
throw new Exception("Premature end of line!");
454
if (currentToken != StreamTokenizer.TT_NUMBER) {
455
throw new Exception("Only numbers and comments allowed "+
458
weight = tokenizer.nval;
459
if (!Utils.gr(weight,0)) {
460
throw new Exception("Only positive weights allowed!");
462
setCell((int)firstIndex, (int)secondIndex,
468
* Reads a matrix from a reader. The first non-blank, non-comment line in the file should
469
* contain the number of rows and columns. Subsequent lines
470
* contain elements of the matrix.
471
* (FracPete: taken from old weka.core.Matrix class)
473
* @param reader the reader containing the matrix
474
* @throws Exception if an error occurs
475
* @see #write(Writer)
477
public CostMatrix(Reader reader) throws Exception {
478
LineNumberReader lnr = new LineNumberReader(reader);
482
while ((line = lnr.readLine()) != null) {
485
if (line.startsWith("%")) {
489
StringTokenizer st = new StringTokenizer(line);
490
// Ignore blank lines
491
if (!st.hasMoreTokens()) {
495
if (currentRow < 0) {
496
int rows = Integer.parseInt(st.nextToken());
497
if (!st.hasMoreTokens()) {
498
throw new Exception("Line " + lnr.getLineNumber()
499
+ ": expected number of columns");
502
int cols = Integer.parseInt(st.nextToken());
504
throw new Exception("Trying to create a non-square cost "
507
// m_matrix = new Object[rows][cols];
514
if (currentRow == m_size) {
515
throw new Exception("Line " + lnr.getLineNumber()
516
+ ": too many rows provided");
519
for (int i = 0; i < m_size; i++) {
520
if (!st.hasMoreTokens()) {
521
throw new Exception("Line " + lnr.getLineNumber()
522
+ ": too few matrix elements provided");
525
String nextTok = st.nextToken();
526
// try to parse as a double first
529
val = new Double(nextTok);
530
double value = val.doubleValue();
531
} catch (Exception ex) {
535
setCell(currentRow, i, nextTok);
537
setCell(currentRow, i, val);
544
if (currentRow == -1) {
545
throw new Exception("Line " + lnr.getLineNumber()
546
+ ": expected number of rows");
547
} else if (currentRow != m_size) {
548
throw new Exception("Line " + lnr.getLineNumber()
549
+ ": too few rows provided");
554
* Writes out a matrix. The format can be read via the
555
* CostMatrix(Reader) constructor.
556
* (FracPete: taken from old weka.core.Matrix class)
558
* @param w the output Writer
559
* @throws Exception if an error occurs
561
public void write(Writer w) throws Exception {
562
w.write("% Rows\tColumns\n");
563
w.write("" + m_size + "\t" + m_size + "\n");
564
w.write("% Matrix elements\n");
565
for(int i = 0; i < m_size; i++) {
566
for(int j = 0; j < m_size; j++) {
567
w.write("" + getCell(i, j) + "\t");
575
* Converts the matrix into a single-line Matlab string: the matrix is enclosed
576
* in square brackets, rows are separated by semicolons and single cells by
577
* blanks, e.g., [1 2; 3 4].
578
* @return the matrix in Matlab single line format
580
public String toMatlab() {
585
result = new StringBuffer();
589
for (i = 0; i < m_size; i++) {
594
for (n = 0; n < m_size; n++) {
598
result.append(getCell(i, n));
604
return result.toString();
608
* Set the value of a particular cell in the matrix
610
* @param rowIndex the row
611
* @param columnIndex the column
612
* @param value the value to set
614
public final void setCell(int rowIndex, int columnIndex,
616
m_matrix[rowIndex][columnIndex] = value;
620
* Return the contents of a particular cell. Note: this
621
* method returns the Object stored at a particular cell.
623
* @param rowIndex the row
624
* @param columnIndex the column
625
* @return the value at the cell
627
public final Object getCell(int rowIndex, int columnIndex) {
628
return m_matrix[rowIndex][columnIndex];
632
* Return the value of a cell as a double (for legacy code)
634
* @param rowIndex the row
635
* @param columnIndex the column
636
* @return the value at a particular cell as a double
637
* @exception Exception if the value is not a double
639
public final double getElement(int rowIndex, int columnIndex)
641
if (!(m_matrix[rowIndex][columnIndex] instanceof Double)) {
642
throw new Exception("Cost matrix contains non-fixed costs!");
644
return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue();
648
* Return the value of a cell as a double. Computes the
649
* value for non-fixed costs using the supplied Instance
651
* @param rowIndex the row
652
* @param columnIndex the column
653
* @return the value from a particular cell
654
* @exception Exception if something goes wrong
656
public final double getElement(int rowIndex, int columnIndex,
657
Instance inst) throws Exception {
659
if (m_matrix[rowIndex][columnIndex] instanceof Double) {
660
return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue();
661
} else if (m_matrix[rowIndex][columnIndex] instanceof String) {
665
return ((AttributeExpression)m_matrix[rowIndex][columnIndex]).
666
evaluateExpression(inst);
670
* Set the value of a cell as a double
672
* @param rowIndex the row
673
* @param columnIndex the column
674
* @param value the value (double) to set
676
public final void setElement(int rowIndex, int columnIndex,
678
m_matrix[rowIndex][columnIndex] = new Double(value);
682
* creates a matrix from the given Matlab string.
683
* @param matlab the matrix in matlab format
684
* @return the matrix represented by the given string
686
public static Matrix parseMatlab(String matlab) throws Exception {
687
return Matrix.parseMatlab(matlab);
691
* Converts a matrix to a string.
692
* (FracPete: taken from old weka.core.Matrix class)
694
* @return the converted string
696
public String toString() {
697
// Determine the width required for the maximum element,
698
// and check for fractional display requirement.
700
boolean fractional = false;
701
Object element = null;
703
int widthExpression = 0;
704
for (int i = 0; i < size(); i++) {
705
for (int j = 0; j < size(); j++) {
706
element = getCell(i, j);
707
if (element instanceof Double) {
708
double current = ((Double)element).doubleValue();
712
if (current > maxval)
714
double fract = Math.abs(current - Math.rint(current));
716
&& ((Math.log(fract) / Math.log(10)) >= -2)) {
720
if (element.toString().length() > widthExpression) {
721
widthExpression = element.toString().length();
727
widthNumber = (int)(Math.log(maxval) / Math.log(10)
728
+ (fractional ? 4 : 1));
731
int width = (widthNumber > widthExpression)
735
StringBuffer text = new StringBuffer();
736
for (int i = 0; i < size(); i++) {
737
for (int j = 0; j < size(); j++) {
738
element = getCell(i, j);
739
if (element instanceof Double) {
741
append(Utils.doubleToString(((Double)element).
743
width, (fractional ? 2 : 0)));
745
int diff = width - element.toString().length();
749
String temp = Utils.padLeft(element.toString(),
750
element.toString().length()+left);
751
temp = Utils.padRight(temp, width);
752
text.append(" ").append(temp);
755
append(element.toString());
762
return text.toString();