001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization.direct;
019    
020    import static org.junit.Assert.assertEquals;
021    import static org.junit.Assert.assertNotNull;
022    import static org.junit.Assert.assertNull;
023    import static org.junit.Assert.assertTrue;
024    import static org.junit.Assert.fail;
025    
026    import org.apache.commons.math.ConvergenceException;
027    import org.apache.commons.math.FunctionEvaluationException;
028    import org.apache.commons.math.MathException;
029    import org.apache.commons.math.MaxEvaluationsExceededException;
030    import org.apache.commons.math.MaxIterationsExceededException;
031    import org.apache.commons.math.analysis.MultivariateRealFunction;
032    import org.apache.commons.math.analysis.MultivariateVectorialFunction;
033    import org.apache.commons.math.linear.Array2DRowRealMatrix;
034    import org.apache.commons.math.linear.RealMatrix;
035    import org.apache.commons.math.optimization.GoalType;
036    import org.apache.commons.math.optimization.LeastSquaresConverter;
037    import org.apache.commons.math.optimization.OptimizationException;
038    import org.apache.commons.math.optimization.RealPointValuePair;
039    import org.apache.commons.math.optimization.SimpleRealPointChecker;
040    import org.apache.commons.math.optimization.SimpleScalarValueChecker;
041    import org.junit.Test;
042    
043    public class NelderMeadTest {
044    
045      @Test
046      public void testFunctionEvaluationExceptions() {
047          MultivariateRealFunction wrong =
048              new MultivariateRealFunction() {
049                private static final long serialVersionUID = 4751314470965489371L;
050                public double value(double[] x) throws FunctionEvaluationException {
051                    if (x[0] < 0) {
052                        throw new FunctionEvaluationException(x, "{0}", "oops");
053                    } else if (x[0] > 1) {
054                        throw new FunctionEvaluationException(new RuntimeException("oops"), x);
055                    } else {
056                        return x[0] * (1 - x[0]);
057                    }
058                }
059          };
060          try {
061              NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
062              optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
063              fail("an exception should have been thrown");
064          } catch (FunctionEvaluationException ce) {
065              // expected behavior
066              assertNull(ce.getCause());
067          } catch (Exception e) {
068              fail("wrong exception caught: " + e.getMessage());
069          } 
070          try {
071              NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
072              optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
073              fail("an exception should have been thrown");
074          } catch (FunctionEvaluationException ce) {
075              // expected behavior
076              assertNotNull(ce.getCause());
077          } catch (Exception e) {
078              fail("wrong exception caught: " + e.getMessage());
079          } 
080      }
081    
082      @Test
083      public void testMinimizeMaximize()
084          throws FunctionEvaluationException, ConvergenceException {
085    
086          // the following function has 4 local extrema:
087          final double xM        = -3.841947088256863675365;
088          final double yM        = -1.391745200270734924416;
089          final double xP        =  0.2286682237349059125691;
090          final double yP        = -yM;
091          final double valueXmYm =  0.2373295333134216789769; // local  maximum
092          final double valueXmYp = -valueXmYm;                // local  minimum
093          final double valueXpYm = -0.7290400707055187115322; // global minimum
094          final double valueXpYp = -valueXpYm;                // global maximum
095          MultivariateRealFunction fourExtrema = new MultivariateRealFunction() {
096              private static final long serialVersionUID = -7039124064449091152L;
097              public double value(double[] variables) throws FunctionEvaluationException {
098                  final double x = variables[0];
099                  final double y = variables[1];
100                  return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y));
101              }
102          };
103    
104          NelderMead optimizer = new NelderMead();
105          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30));
106          optimizer.setMaxIterations(100);
107          optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
108          RealPointValuePair optimum;
109    
110          // minimization
111          optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
112          assertEquals(xM,        optimum.getPoint()[0], 2.0e-7);
113          assertEquals(yP,        optimum.getPoint()[1], 2.0e-5);
114          assertEquals(valueXmYp, optimum.getValue(),    6.0e-12);
115          assertTrue(optimizer.getEvaluations() > 60);
116          assertTrue(optimizer.getEvaluations() < 90);
117    
118          optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
119          assertEquals(xP,        optimum.getPoint()[0], 5.0e-6);
120          assertEquals(yM,        optimum.getPoint()[1], 6.0e-6);
121          assertEquals(valueXpYm, optimum.getValue(),    1.0e-11);              
122          assertTrue(optimizer.getEvaluations() > 60);
123          assertTrue(optimizer.getEvaluations() < 90);
124    
125          // maximization
126          optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
127          assertEquals(xM,        optimum.getPoint()[0], 1.0e-5);
128          assertEquals(yM,        optimum.getPoint()[1], 3.0e-6);
129          assertEquals(valueXmYm, optimum.getValue(),    3.0e-12);
130          assertTrue(optimizer.getEvaluations() > 60);
131          assertTrue(optimizer.getEvaluations() < 90);
132    
133          optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
134          assertEquals(xP,        optimum.getPoint()[0], 4.0e-6);
135          assertEquals(yP,        optimum.getPoint()[1], 5.0e-6);
136          assertEquals(valueXpYp, optimum.getValue(),    7.0e-12);
137          assertTrue(optimizer.getEvaluations() > 60);
138          assertTrue(optimizer.getEvaluations() < 90);
139    
140      }
141    
142      @Test
143      public void testRosenbrock()
144        throws FunctionEvaluationException, ConvergenceException {
145    
146        Rosenbrock rosenbrock = new Rosenbrock();
147        NelderMead optimizer = new NelderMead();
148        optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1, 1.0e-3));
149        optimizer.setMaxIterations(100);
150        optimizer.setStartConfiguration(new double[][] {
151                { -1.2,  1.0 }, { 0.9, 1.2 } , {  3.5, -2.3 }
152        });
153        RealPointValuePair optimum =
154            optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
155    
156        assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
157        assertTrue(optimizer.getEvaluations() > 40);
158        assertTrue(optimizer.getEvaluations() < 50);
159        assertTrue(optimum.getValue() < 8.0e-4);
160    
161      }
162    
163      @Test
164      public void testPowell()
165        throws FunctionEvaluationException, ConvergenceException {
166    
167        Powell powell = new Powell();
168        NelderMead optimizer = new NelderMead();
169        optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
170        optimizer.setMaxIterations(200);
171        RealPointValuePair optimum =
172          optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
173        assertEquals(powell.getCount(), optimizer.getEvaluations());
174        assertTrue(optimizer.getEvaluations() > 110);
175        assertTrue(optimizer.getEvaluations() < 130);
176        assertTrue(optimum.getValue() < 2.0e-3);
177    
178      }
179    
180      @Test
181      public void testLeastSquares1()
182      throws FunctionEvaluationException, ConvergenceException {
183    
184          final RealMatrix factors =
185              new Array2DRowRealMatrix(new double[][] {
186                  { 1.0, 0.0 },
187                  { 0.0, 1.0 }
188              }, false);
189          LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
190              public double[] value(double[] variables) {
191                  return factors.operate(variables);
192              }
193          }, new double[] { 2.0, -3.0 });
194          NelderMead optimizer = new NelderMead();
195          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
196          optimizer.setMaxIterations(200);
197          RealPointValuePair optimum =
198              optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
199          assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5);
200          assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4);
201          assertTrue(optimizer.getEvaluations() > 60);
202          assertTrue(optimizer.getEvaluations() < 80);
203          assertTrue(optimum.getValue() < 1.0e-6);
204      }
205    
206      @Test
207      public void testLeastSquares2()
208      throws FunctionEvaluationException, ConvergenceException {
209    
210          final RealMatrix factors =
211              new Array2DRowRealMatrix(new double[][] {
212                  { 1.0, 0.0 },
213                  { 0.0, 1.0 }
214              }, false);
215          LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
216              public double[] value(double[] variables) {
217                  return factors.operate(variables);
218              }
219          }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 });
220          NelderMead optimizer = new NelderMead();
221          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
222          optimizer.setMaxIterations(200);
223          RealPointValuePair optimum =
224              optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
225          assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5);
226          assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
227          assertTrue(optimizer.getEvaluations() > 60);
228          assertTrue(optimizer.getEvaluations() < 80);
229          assertTrue(optimum.getValue() < 1.0e-6);
230      }
231    
232      @Test
233      public void testLeastSquares3()
234      throws FunctionEvaluationException, ConvergenceException {
235    
236          final RealMatrix factors =
237              new Array2DRowRealMatrix(new double[][] {
238                  { 1.0, 0.0 },
239                  { 0.0, 1.0 }
240              }, false);
241          LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
242              public double[] value(double[] variables) {
243                  return factors.operate(variables);
244              }
245          }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] {
246              { 1.0, 1.2 }, { 1.2, 2.0 }
247          }));
248          NelderMead optimizer = new NelderMead();
249          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
250          optimizer.setMaxIterations(200);
251          RealPointValuePair optimum =
252              optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
253          assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3);
254          assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
255          assertTrue(optimizer.getEvaluations() > 60);
256          assertTrue(optimizer.getEvaluations() < 80);
257          assertTrue(optimum.getValue() < 1.0e-6);
258      }
259    
260      @Test(expected = MaxIterationsExceededException.class)
261      public void testMaxIterations() throws MathException {
262          try {
263              Powell powell = new Powell();
264              NelderMead optimizer = new NelderMead();
265              optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
266              optimizer.setMaxIterations(20);
267              optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
268          } catch (OptimizationException oe) {
269              if (oe.getCause() instanceof ConvergenceException) {
270                  throw (ConvergenceException) oe.getCause();
271              }
272              throw oe;
273          }
274      }
275    
276      @Test(expected = MaxEvaluationsExceededException.class)
277      public void testMaxEvaluations() throws MathException {
278          try {
279              Powell powell = new Powell();
280              NelderMead optimizer = new NelderMead();
281              optimizer.setConvergenceChecker(new SimpleRealPointChecker(-1.0, 1.0e-3));
282              optimizer.setMaxEvaluations(20);
283              optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
284          } catch (FunctionEvaluationException fee) {
285              if (fee.getCause() instanceof ConvergenceException) {
286                  throw (ConvergenceException) fee.getCause();
287              }
288              throw fee;
289          }
290      }
291    
292      private static class Rosenbrock implements MultivariateRealFunction {
293    
294          private int count;
295    
296          public Rosenbrock() {
297              count = 0;
298          }
299    
300          public double value(double[] x) throws FunctionEvaluationException {
301              ++count;
302              double a = x[1] - x[0] * x[0];
303              double b = 1.0 - x[0];
304              return 100 * a * a + b * b;
305          }
306    
307          public int getCount() {
308              return count;
309          }
310    
311      }
312    
313      private static class Powell implements MultivariateRealFunction {
314    
315          private int count;
316    
317          public Powell() {
318              count = 0;
319          }
320    
321          public double value(double[] x) throws FunctionEvaluationException {
322              ++count;
323              double a = x[0] + 10 * x[1];
324              double b = x[2] - x[3];
325              double c = x[1] - 2 * x[2];
326              double d = x[0] - x[3];
327              return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
328          }
329    
330          public int getCount() {
331              return count;
332          }
333    
334      }
335    
336    }