~ubuntu-branches/ubuntu/maverick/commons-math/maverick

« back to all changes in this revision

Viewing changes to src/test/java/org/apache/commons/math/optimization/direct/NelderMeadTest.java

  • Committer: Bazaar Package Importer
  • Author(s): Damien Raude-Morvan
  • Date: 2009-08-22 01:13:25 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20090822011325-hi4peq1ua5weguwn
Tags: 2.0-1
* New upstream release.
* Set Maintainer field to Debian Java Team
* Add myself as Uploaders
* Switch to Quilt patch system:
  - Refresh all patchs
  - Remove B-D on dpatch, Add B-D on quilt
  - Include patchsys-quilt.mk in debian/rules
* Bump Standards-Version to 3.8.3:
  - Add a README.source to describe patch system
* Maven POMs:
  - Add a Build-Depends-Indep dependency on maven-repo-helper
  - Use mh_installpom and mh_installjar to install the POM and the jar to the
    Maven repository
* Use default-jdk/jre:
  - Depends on java5-runtime-headless
  - Build-Depends on default-jdk
  - Use /usr/lib/jvm/default-java as JAVA_HOME
* Move api documentation to /usr/share/doc/libcommons-math-java/api
* Build-Depends on junit4 instead of junit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * Licensed to the Apache Software Foundation (ASF) under one or more
 
3
 * contributor license agreements.  See the NOTICE file distributed with
 
4
 * this work for additional information regarding copyright ownership.
 
5
 * The ASF licenses this file to You under the Apache License, Version 2.0
 
6
 * (the "License"); you may not use this file except in compliance with
 
7
 * the License.  You may obtain a copy of the License at
 
8
 *
 
9
 *      http://www.apache.org/licenses/LICENSE-2.0
 
10
 *
 
11
 * Unless required by applicable law or agreed to in writing, software
 
12
 * distributed under the License is distributed on an "AS IS" BASIS,
 
13
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
14
 * See the License for the specific language governing permissions and
 
15
 * limitations under the License.
 
16
 */
 
17
 
 
18
package org.apache.commons.math.optimization.direct;
 
19
 
 
20
import static org.junit.Assert.assertEquals;
 
21
import static org.junit.Assert.assertNotNull;
 
22
import static org.junit.Assert.assertNull;
 
23
import static org.junit.Assert.assertTrue;
 
24
import static org.junit.Assert.fail;
 
25
 
 
26
import org.apache.commons.math.ConvergenceException;
 
27
import org.apache.commons.math.FunctionEvaluationException;
 
28
import org.apache.commons.math.MathException;
 
29
import org.apache.commons.math.MaxEvaluationsExceededException;
 
30
import org.apache.commons.math.MaxIterationsExceededException;
 
31
import org.apache.commons.math.analysis.MultivariateRealFunction;
 
32
import org.apache.commons.math.analysis.MultivariateVectorialFunction;
 
33
import org.apache.commons.math.linear.Array2DRowRealMatrix;
 
34
import org.apache.commons.math.linear.RealMatrix;
 
35
import org.apache.commons.math.optimization.GoalType;
 
36
import org.apache.commons.math.optimization.LeastSquaresConverter;
 
37
import org.apache.commons.math.optimization.OptimizationException;
 
38
import org.apache.commons.math.optimization.RealPointValuePair;
 
39
import org.apache.commons.math.optimization.SimpleRealPointChecker;
 
40
import org.apache.commons.math.optimization.SimpleScalarValueChecker;
 
41
import org.junit.Test;
 
42
 
 
43
public class NelderMeadTest {
 
44
 
 
45
  @Test
 
46
  public void testFunctionEvaluationExceptions() {
 
47
      MultivariateRealFunction wrong =
 
48
          new MultivariateRealFunction() {
 
49
            private static final long serialVersionUID = 4751314470965489371L;
 
50
            public double value(double[] x) throws FunctionEvaluationException {
 
51
                if (x[0] < 0) {
 
52
                    throw new FunctionEvaluationException(x, "{0}", "oops");
 
53
                } else if (x[0] > 1) {
 
54
                    throw new FunctionEvaluationException(new RuntimeException("oops"), x);
 
55
                } else {
 
56
                    return x[0] * (1 - x[0]);
 
57
                }
 
58
            }
 
59
      };
 
60
      try {
 
61
          NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
 
62
          optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
 
63
          fail("an exception should have been thrown");
 
64
      } catch (FunctionEvaluationException ce) {
 
65
          // expected behavior
 
66
          assertNull(ce.getCause());
 
67
      } catch (Exception e) {
 
68
          fail("wrong exception caught: " + e.getMessage());
 
69
      } 
 
70
      try {
 
71
          NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
 
72
          optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
 
73
          fail("an exception should have been thrown");
 
74
      } catch (FunctionEvaluationException ce) {
 
75
          // expected behavior
 
76
          assertNotNull(ce.getCause());
 
77
      } catch (Exception e) {
 
78
          fail("wrong exception caught: " + e.getMessage());
 
79
      } 
 
80
  }
 
81
 
 
82
  @Test
 
83
  public void testMinimizeMaximize()
 
84
      throws FunctionEvaluationException, ConvergenceException {
 
85
 
 
86
      // the following function has 4 local extrema:
 
87
      final double xM        = -3.841947088256863675365;
 
88
      final double yM        = -1.391745200270734924416;
 
89
      final double xP        =  0.2286682237349059125691;
 
90
      final double yP        = -yM;
 
91
      final double valueXmYm =  0.2373295333134216789769; // local  maximum
 
92
      final double valueXmYp = -valueXmYm;                // local  minimum
 
93
      final double valueXpYm = -0.7290400707055187115322; // global minimum
 
94
      final double valueXpYp = -valueXpYm;                // global maximum
 
95
      MultivariateRealFunction fourExtrema = new MultivariateRealFunction() {
 
96
          private static final long serialVersionUID = -7039124064449091152L;
 
97
          public double value(double[] variables) throws FunctionEvaluationException {
 
98
              final double x = variables[0];
 
99
              final double y = variables[1];
 
100
              return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y));
 
101
          }
 
102
      };
 
103
 
 
104
      NelderMead optimizer = new NelderMead();
 
105
      optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30));
 
106
      optimizer.setMaxIterations(100);
 
107
      optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
 
108
      RealPointValuePair optimum;
 
109
 
 
110
      // minimization
 
111
      optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
 
112
      assertEquals(xM,        optimum.getPoint()[0], 2.0e-7);
 
113
      assertEquals(yP,        optimum.getPoint()[1], 2.0e-5);
 
114
      assertEquals(valueXmYp, optimum.getValue(),    6.0e-12);
 
115
      assertTrue(optimizer.getEvaluations() > 60);
 
116
      assertTrue(optimizer.getEvaluations() < 90);
 
117
 
 
118
      optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
 
119
      assertEquals(xP,        optimum.getPoint()[0], 5.0e-6);
 
120
      assertEquals(yM,        optimum.getPoint()[1], 6.0e-6);
 
121
      assertEquals(valueXpYm, optimum.getValue(),    1.0e-11);              
 
122
      assertTrue(optimizer.getEvaluations() > 60);
 
123
      assertTrue(optimizer.getEvaluations() < 90);
 
124
 
 
125
      // maximization
 
126
      optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
 
127
      assertEquals(xM,        optimum.getPoint()[0], 1.0e-5);
 
128
      assertEquals(yM,        optimum.getPoint()[1], 3.0e-6);
 
129
      assertEquals(valueXmYm, optimum.getValue(),    3.0e-12);
 
130
      assertTrue(optimizer.getEvaluations() > 60);
 
131
      assertTrue(optimizer.getEvaluations() < 90);
 
132
 
 
133
      optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
 
134
      assertEquals(xP,        optimum.getPoint()[0], 4.0e-6);
 
135
      assertEquals(yP,        optimum.getPoint()[1], 5.0e-6);
 
136
      assertEquals(valueXpYp, optimum.getValue(),    7.0e-12);
 
137
      assertTrue(optimizer.getEvaluations() > 60);
 
138
      assertTrue(optimizer.getEvaluations() < 90);
 
139
 
 
140
  }
 
141
 
 
142
  @Test
 
143
  public void testRosenbrock()
 
144
    throws FunctionEvaluationException, ConvergenceException {
 
145
 
 
146
    Rosenbrock rosenbrock = new Rosenbrock();
 
147
    NelderMead optimizer = new NelderMead();
 
148
    optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1, 1.0e-3));
 
149
    optimizer.setMaxIterations(100);
 
150
    optimizer.setStartConfiguration(new double[][] {
 
151
            { -1.2,  1.0 }, { 0.9, 1.2 } , {  3.5, -2.3 }
 
152
    });
 
153
    RealPointValuePair optimum =
 
154
        optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
 
155
 
 
156
    assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
 
157
    assertTrue(optimizer.getEvaluations() > 40);
 
158
    assertTrue(optimizer.getEvaluations() < 50);
 
159
    assertTrue(optimum.getValue() < 8.0e-4);
 
160
 
 
161
  }
 
162
 
 
163
  @Test
 
164
  public void testPowell()
 
165
    throws FunctionEvaluationException, ConvergenceException {
 
166
 
 
167
    Powell powell = new Powell();
 
168
    NelderMead optimizer = new NelderMead();
 
169
    optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
 
170
    optimizer.setMaxIterations(200);
 
171
    RealPointValuePair optimum =
 
172
      optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
 
173
    assertEquals(powell.getCount(), optimizer.getEvaluations());
 
174
    assertTrue(optimizer.getEvaluations() > 110);
 
175
    assertTrue(optimizer.getEvaluations() < 130);
 
176
    assertTrue(optimum.getValue() < 2.0e-3);
 
177
 
 
178
  }
 
179
 
 
180
  @Test
 
181
  public void testLeastSquares1()
 
182
  throws FunctionEvaluationException, ConvergenceException {
 
183
 
 
184
      final RealMatrix factors =
 
185
          new Array2DRowRealMatrix(new double[][] {
 
186
              { 1.0, 0.0 },
 
187
              { 0.0, 1.0 }
 
188
          }, false);
 
189
      LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
 
190
          public double[] value(double[] variables) {
 
191
              return factors.operate(variables);
 
192
          }
 
193
      }, new double[] { 2.0, -3.0 });
 
194
      NelderMead optimizer = new NelderMead();
 
195
      optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
 
196
      optimizer.setMaxIterations(200);
 
197
      RealPointValuePair optimum =
 
198
          optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
 
199
      assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5);
 
200
      assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4);
 
201
      assertTrue(optimizer.getEvaluations() > 60);
 
202
      assertTrue(optimizer.getEvaluations() < 80);
 
203
      assertTrue(optimum.getValue() < 1.0e-6);
 
204
  }
 
205
 
 
206
  @Test
 
207
  public void testLeastSquares2()
 
208
  throws FunctionEvaluationException, ConvergenceException {
 
209
 
 
210
      final RealMatrix factors =
 
211
          new Array2DRowRealMatrix(new double[][] {
 
212
              { 1.0, 0.0 },
 
213
              { 0.0, 1.0 }
 
214
          }, false);
 
215
      LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
 
216
          public double[] value(double[] variables) {
 
217
              return factors.operate(variables);
 
218
          }
 
219
      }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 });
 
220
      NelderMead optimizer = new NelderMead();
 
221
      optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
 
222
      optimizer.setMaxIterations(200);
 
223
      RealPointValuePair optimum =
 
224
          optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
 
225
      assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5);
 
226
      assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
 
227
      assertTrue(optimizer.getEvaluations() > 60);
 
228
      assertTrue(optimizer.getEvaluations() < 80);
 
229
      assertTrue(optimum.getValue() < 1.0e-6);
 
230
  }
 
231
 
 
232
  @Test
 
233
  public void testLeastSquares3()
 
234
  throws FunctionEvaluationException, ConvergenceException {
 
235
 
 
236
      final RealMatrix factors =
 
237
          new Array2DRowRealMatrix(new double[][] {
 
238
              { 1.0, 0.0 },
 
239
              { 0.0, 1.0 }
 
240
          }, false);
 
241
      LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
 
242
          public double[] value(double[] variables) {
 
243
              return factors.operate(variables);
 
244
          }
 
245
      }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] {
 
246
          { 1.0, 1.2 }, { 1.2, 2.0 }
 
247
      }));
 
248
      NelderMead optimizer = new NelderMead();
 
249
      optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
 
250
      optimizer.setMaxIterations(200);
 
251
      RealPointValuePair optimum =
 
252
          optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
 
253
      assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3);
 
254
      assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
 
255
      assertTrue(optimizer.getEvaluations() > 60);
 
256
      assertTrue(optimizer.getEvaluations() < 80);
 
257
      assertTrue(optimum.getValue() < 1.0e-6);
 
258
  }
 
259
 
 
260
  @Test(expected = MaxIterationsExceededException.class)
 
261
  public void testMaxIterations() throws MathException {
 
262
      try {
 
263
          Powell powell = new Powell();
 
264
          NelderMead optimizer = new NelderMead();
 
265
          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
 
266
          optimizer.setMaxIterations(20);
 
267
          optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
 
268
      } catch (OptimizationException oe) {
 
269
          if (oe.getCause() instanceof ConvergenceException) {
 
270
              throw (ConvergenceException) oe.getCause();
 
271
          }
 
272
          throw oe;
 
273
      }
 
274
  }
 
275
 
 
276
  @Test(expected = MaxEvaluationsExceededException.class)
 
277
  public void testMaxEvaluations() throws MathException {
 
278
      try {
 
279
          Powell powell = new Powell();
 
280
          NelderMead optimizer = new NelderMead();
 
281
          optimizer.setConvergenceChecker(new SimpleRealPointChecker(-1.0, 1.0e-3));
 
282
          optimizer.setMaxEvaluations(20);
 
283
          optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
 
284
      } catch (FunctionEvaluationException fee) {
 
285
          if (fee.getCause() instanceof ConvergenceException) {
 
286
              throw (ConvergenceException) fee.getCause();
 
287
          }
 
288
          throw fee;
 
289
      }
 
290
  }
 
291
 
 
292
  private static class Rosenbrock implements MultivariateRealFunction {
 
293
 
 
294
      private int count;
 
295
 
 
296
      public Rosenbrock() {
 
297
          count = 0;
 
298
      }
 
299
 
 
300
      public double value(double[] x) throws FunctionEvaluationException {
 
301
          ++count;
 
302
          double a = x[1] - x[0] * x[0];
 
303
          double b = 1.0 - x[0];
 
304
          return 100 * a * a + b * b;
 
305
      }
 
306
 
 
307
      public int getCount() {
 
308
          return count;
 
309
      }
 
310
 
 
311
  }
 
312
 
 
313
  private static class Powell implements MultivariateRealFunction {
 
314
 
 
315
      private int count;
 
316
 
 
317
      public Powell() {
 
318
          count = 0;
 
319
      }
 
320
 
 
321
      public double value(double[] x) throws FunctionEvaluationException {
 
322
          ++count;
 
323
          double a = x[0] + 10 * x[1];
 
324
          double b = x[2] - x[3];
 
325
          double c = x[1] - 2 * x[2];
 
326
          double d = x[0] - x[3];
 
327
          return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
 
328
      }
 
329
 
 
330
      public int getCount() {
 
331
          return count;
 
332
      }
 
333
 
 
334
  }
 
335
 
 
336
}