~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/filters/unsupervised/attribute/Standardize.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    Standardize.java
 
19
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.filters.unsupervised.attribute;
 
24
 
 
25
import weka.core.Capabilities;
 
26
import weka.core.Instance;
 
27
import weka.core.Instances;
 
28
import weka.core.SparseInstance;
 
29
import weka.core.Utils;
 
30
import weka.core.Capabilities.Capability;
 
31
import weka.filters.Sourcable;
 
32
import weka.filters.UnsupervisedFilter;
 
33
 
 
34
/** 
 
35
 <!-- globalinfo-start -->
 
36
 * Standardizes all numeric attributes in the given dataset to have zero mean and unit variance (apart from the class attribute, if set).
 
37
 * <p/>
 
38
 <!-- globalinfo-end -->
 
39
 *
 
40
 <!-- options-start -->
 
41
 * Valid options are: <p/>
 
42
 * 
 
43
 * <pre> -unset-class-temporarily
 
44
 *  Unsets the class index temporarily before the filter is
 
45
 *  applied to the data.
 
46
 *  (default: no)</pre>
 
47
 * 
 
48
 <!-- options-end -->
 
49
 * 
 
50
 * @author Eibe Frank (eibe@cs.waikato.ac.nz) 
 
51
 * @version $Revision: 1.11 $
 
52
 */
 
53
public class Standardize 
 
54
  extends PotentialClassIgnorer 
 
55
  implements UnsupervisedFilter, Sourcable {
 
56
  
 
57
  /** for serialization */
 
58
  static final long serialVersionUID = -6830769026855053281L;
 
59
 
 
60
  /** The means */
 
61
  private double [] m_Means;
 
62
  
 
63
  /** The variances */
 
64
  private double [] m_StdDevs;
 
65
 
 
66
  /**
 
67
   * Returns a string describing this filter
 
68
   *
 
69
   * @return a description of the filter suitable for
 
70
   * displaying in the explorer/experimenter gui
 
71
   */
 
72
  public String globalInfo() {
 
73
 
 
74
    return "Standardizes all numeric attributes in the given dataset "
 
75
      + "to have zero mean and unit variance (apart from the class attribute, if set).";
 
76
  }
 
77
 
 
78
  /** 
 
79
   * Returns the Capabilities of this filter.
 
80
   *
 
81
   * @return            the capabilities of this object
 
82
   * @see               Capabilities
 
83
   */
 
84
  public Capabilities getCapabilities() {
 
85
    Capabilities result = super.getCapabilities();
 
86
 
 
87
    // attributes
 
88
    result.enableAllAttributes();
 
89
    result.enable(Capability.MISSING_VALUES);
 
90
    
 
91
    // class
 
92
    result.enableAllClasses();
 
93
    result.enable(Capability.MISSING_CLASS_VALUES);
 
94
    result.enable(Capability.NO_CLASS);
 
95
    
 
96
    return result;
 
97
  }
 
98
 
 
99
  /**
 
100
   * Sets the format of the input instances.
 
101
   *
 
102
   * @param instanceInfo an Instances object containing the input 
 
103
   * instance structure (any instances contained in the object are 
 
104
   * ignored - only the structure is required).
 
105
   * @return true if the outputFormat may be collected immediately
 
106
   * @throws Exception if the input format can't be set 
 
107
   * successfully
 
108
   */
 
109
  public boolean setInputFormat(Instances instanceInfo) 
 
110
       throws Exception {
 
111
 
 
112
    super.setInputFormat(instanceInfo);
 
113
    setOutputFormat(instanceInfo);
 
114
    m_Means = m_StdDevs = null;
 
115
    return true;
 
116
  }
 
117
 
 
118
  /**
 
119
   * Input an instance for filtering. Filter requires all
 
120
   * training instances be read before producing output.
 
121
   *
 
122
   * @param instance the input instance
 
123
   * @return true if the filtered instance may now be
 
124
   * collected with output().
 
125
   * @throws IllegalStateException if no input format has been set.
 
126
   */
 
127
  public boolean input(Instance instance) throws Exception {
 
128
 
 
129
    if (getInputFormat() == null) {
 
130
      throw new IllegalStateException("No input instance format defined");
 
131
    }
 
132
    if (m_NewBatch) {
 
133
      resetQueue();
 
134
      m_NewBatch = false;
 
135
    }
 
136
    if (m_Means == null) {
 
137
      bufferInput(instance);
 
138
      return false;
 
139
    } else {
 
140
      convertInstance(instance);
 
141
      return true;
 
142
    }
 
143
  }
 
144
 
 
145
  /**
 
146
   * Signify that this batch of input to the filter is finished. 
 
147
   * If the filter requires all instances prior to filtering,
 
148
   * output() may now be called to retrieve the filtered instances.
 
149
   *
 
150
   * @return true if there are instances pending output
 
151
   * @exception Exception if an error occurs
 
152
   * @exception IllegalStateException if no input structure has been defined
 
153
   */
 
154
  public boolean batchFinished() throws Exception {
 
155
 
 
156
    if (getInputFormat() == null) {
 
157
      throw new IllegalStateException("No input instance format defined");
 
158
    }
 
159
    if (m_Means == null) {
 
160
      Instances input = getInputFormat();
 
161
      m_Means = new double[input.numAttributes()];
 
162
      m_StdDevs = new double[input.numAttributes()];
 
163
      for (int i = 0; i < input.numAttributes(); i++) {
 
164
        if (input.attribute(i).isNumeric() &&
 
165
            (input.classIndex() != i)) {
 
166
          m_Means[i] = input.meanOrMode(i);
 
167
          m_StdDevs[i] = Math.sqrt(input.variance(i));
 
168
        }
 
169
      }
 
170
 
 
171
      // Convert pending input instances
 
172
      for(int i = 0; i < input.numInstances(); i++) {
 
173
        convertInstance(input.instance(i));
 
174
      }
 
175
    } 
 
176
    // Free memory
 
177
    flushInput();
 
178
 
 
179
    m_NewBatch = true;
 
180
    return (numPendingOutput() != 0);
 
181
  }
 
182
 
 
183
  /**
 
184
   * Convert a single instance over. The converted instance is 
 
185
   * added to the end of the output queue.
 
186
   *
 
187
   * @param instance the instance to convert
 
188
   * @exception Exception if an error occurs
 
189
   */
 
190
  private void convertInstance(Instance instance) throws Exception {
 
191
  
 
192
    Instance inst = null;
 
193
    if (instance instanceof SparseInstance) {
 
194
      double[] newVals = new double[instance.numAttributes()];
 
195
      int[] newIndices = new int[instance.numAttributes()];
 
196
      double[] vals = instance.toDoubleArray();
 
197
      int ind = 0;
 
198
      for (int j = 0; j < instance.numAttributes(); j++) {
 
199
        double value;
 
200
        if (instance.attribute(j).isNumeric() &&
 
201
            (!Instance.isMissingValue(vals[j])) &&
 
202
            (getInputFormat().classIndex() != j)) {
 
203
          
 
204
          // Just subtract the mean if the standard deviation is zero
 
205
          if (m_StdDevs[j] > 0) { 
 
206
            value = (vals[j] - m_Means[j]) / m_StdDevs[j];
 
207
          } else {
 
208
            value = vals[j] - m_Means[j];
 
209
          }
 
210
          if (Double.isNaN(value)) {
 
211
            throw new Exception("A NaN value was generated "
 
212
                                + "while standardizing attribute " 
 
213
                                + instance.attribute(j).name());
 
214
          }
 
215
          if (value != 0.0) {
 
216
            newVals[ind] = value;
 
217
            newIndices[ind] = j;
 
218
            ind++;
 
219
          }
 
220
        } else {
 
221
          value = vals[j];
 
222
          if (value != 0.0) {
 
223
            newVals[ind] = value;
 
224
            newIndices[ind] = j;
 
225
            ind++;
 
226
          }
 
227
        }
 
228
      } 
 
229
      double[] tempVals = new double[ind];
 
230
      int[] tempInd = new int[ind];
 
231
      System.arraycopy(newVals, 0, tempVals, 0, ind);
 
232
      System.arraycopy(newIndices, 0, tempInd, 0, ind);
 
233
      inst = new SparseInstance(instance.weight(), tempVals, tempInd,
 
234
                                instance.numAttributes());
 
235
    } else {
 
236
      double[] vals = instance.toDoubleArray();
 
237
      for (int j = 0; j < getInputFormat().numAttributes(); j++) {
 
238
        if (instance.attribute(j).isNumeric() &&
 
239
            (!Instance.isMissingValue(vals[j])) &&
 
240
            (getInputFormat().classIndex() != j)) {
 
241
          
 
242
          // Just subtract the mean if the standard deviation is zero
 
243
          if (m_StdDevs[j] > 0) { 
 
244
            vals[j] = (vals[j] - m_Means[j]) / m_StdDevs[j];
 
245
          } else {
 
246
            vals[j] = (vals[j] - m_Means[j]);
 
247
          }
 
248
          if (Double.isNaN(vals[j])) {
 
249
            throw new Exception("A NaN value was generated "
 
250
                                + "while standardizing attribute " 
 
251
                                + instance.attribute(j).name());
 
252
          }
 
253
        }
 
254
      } 
 
255
      inst = new Instance(instance.weight(), vals);
 
256
    }
 
257
    inst.setDataset(instance.dataset());
 
258
    push(inst);
 
259
  }
 
260
  
 
261
  /**
 
262
   * Returns a string that describes the filter as source. The
 
263
   * filter will be contained in a class with the given name (there may
 
264
   * be auxiliary classes),
 
265
   * and will contain two methods with these signatures:
 
266
   * <pre><code>
 
267
   * // converts one row
 
268
   * public static Object[] filter(Object[] i);
 
269
   * // converts a full dataset (first dimension is row index)
 
270
   * public static Object[][] filter(Object[][] i);
 
271
   * </code></pre>
 
272
   * where the array <code>i</code> contains elements that are either
 
273
   * Double, String, with missing values represented as null. The generated
 
274
   * code is public domain and comes with no warranty.
 
275
   *
 
276
   * @param className   the name that should be given to the source class.
 
277
   * @param data        the dataset used for initializing the filter
 
278
   * @return            the object source described by a string
 
279
   * @throws Exception  if the source can't be computed
 
280
   */
 
281
  public String toSource(String className, Instances data) throws Exception {
 
282
    StringBuffer        result;
 
283
    boolean[]           process;
 
284
    int                 i;
 
285
    
 
286
    result = new StringBuffer();
 
287
    
 
288
    // determine what attributes were processed
 
289
    process = new boolean[data.numAttributes()];
 
290
    for (i = 0; i < data.numAttributes(); i++) {
 
291
      process[i] = (data.attribute(i).isNumeric() && (i != data.classIndex()));
 
292
    }
 
293
    
 
294
    result.append("class " + className + " {\n");
 
295
    result.append("\n");
 
296
    result.append("  /** lists which attributes will be processed */\n");
 
297
    result.append("  protected final static boolean[] PROCESS = new boolean[]{" + Utils.arrayToString(process) + "};\n");
 
298
    result.append("\n");
 
299
    result.append("  /** the computed means */\n");
 
300
    result.append("  protected final static double[] MEANS = new double[]{" + Utils.arrayToString(m_Means) + "};\n");
 
301
    result.append("\n");
 
302
    result.append("  /** the computed standard deviations */\n");
 
303
    result.append("  protected final static double[] STDEVS = new double[]{" + Utils.arrayToString(m_StdDevs) + "};\n");
 
304
    result.append("\n");
 
305
    result.append("  /**\n");
 
306
    result.append("   * filters a single row\n");
 
307
    result.append("   * \n");
 
308
    result.append("   * @param i the row to process\n");
 
309
    result.append("   * @return the processed row\n");
 
310
    result.append("   */\n");
 
311
    result.append("  public static Object[] filter(Object[] i) {\n");
 
312
    result.append("    Object[] result;\n");
 
313
    result.append("\n");
 
314
    result.append("    result = new Object[i.length];\n");
 
315
    result.append("    for (int n = 0; n < i.length; n++) {\n");
 
316
    result.append("      if (PROCESS[n] && (i[n] != null)) {\n");
 
317
    result.append("        if (STDEVS[n] > 0)\n");
 
318
    result.append("          result[n] = (((Double) i[n]) - MEANS[n]) / STDEVS[n];\n");
 
319
    result.append("        else\n");
 
320
    result.append("          result[n] = ((Double) i[n]) - MEANS[n];\n");
 
321
    result.append("      }\n");
 
322
    result.append("      else {\n");
 
323
    result.append("        result[n] = i[n];\n");
 
324
    result.append("      }\n");
 
325
    result.append("    }\n");
 
326
    result.append("\n");
 
327
    result.append("    return result;\n");
 
328
    result.append("  }\n");
 
329
    result.append("\n");
 
330
    result.append("  /**\n");
 
331
    result.append("   * filters multiple rows\n");
 
332
    result.append("   * \n");
 
333
    result.append("   * @param i the rows to process\n");
 
334
    result.append("   * @return the processed rows\n");
 
335
    result.append("   */\n");
 
336
    result.append("  public static Object[][] filter(Object[][] i) {\n");
 
337
    result.append("    Object[][] result;\n");
 
338
    result.append("\n");
 
339
    result.append("    result = new Object[i.length][];\n");
 
340
    result.append("    for (int n = 0; n < i.length; n++) {\n");
 
341
    result.append("      result[n] = filter(i[n]);\n");
 
342
    result.append("    }\n");
 
343
    result.append("\n");
 
344
    result.append("    return result;\n");
 
345
    result.append("  }\n");
 
346
    result.append("}\n");
 
347
    
 
348
    return result.toString();
 
349
  }
 
350
 
 
351
  /**
 
352
   * Main method for testing this class.
 
353
   *
 
354
   * @param argv should contain arguments to the filter: 
 
355
   * use -h for help
 
356
   */
 
357
  public static void main(String [] argv) {
 
358
    runFilter(new Standardize(), argv);
 
359
  }
 
360
}