001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math.optimization.direct; 019 020 import static org.junit.Assert.assertEquals; 021 import static org.junit.Assert.assertNotNull; 022 import static org.junit.Assert.assertNull; 023 import static org.junit.Assert.assertTrue; 024 import static org.junit.Assert.fail; 025 026 import org.apache.commons.math.ConvergenceException; 027 import org.apache.commons.math.FunctionEvaluationException; 028 import org.apache.commons.math.MathException; 029 import org.apache.commons.math.MaxEvaluationsExceededException; 030 import org.apache.commons.math.MaxIterationsExceededException; 031 import org.apache.commons.math.analysis.MultivariateRealFunction; 032 import org.apache.commons.math.analysis.MultivariateVectorialFunction; 033 import org.apache.commons.math.linear.Array2DRowRealMatrix; 034 import org.apache.commons.math.linear.RealMatrix; 035 import org.apache.commons.math.optimization.GoalType; 036 import org.apache.commons.math.optimization.LeastSquaresConverter; 037 import org.apache.commons.math.optimization.OptimizationException; 038 import org.apache.commons.math.optimization.RealPointValuePair; 039 import org.apache.commons.math.optimization.SimpleRealPointChecker; 040 import org.apache.commons.math.optimization.SimpleScalarValueChecker; 041 import org.junit.Test; 042 043 public class NelderMeadTest { 044 045 @Test 046 public void testFunctionEvaluationExceptions() { 047 MultivariateRealFunction wrong = 048 new MultivariateRealFunction() { 049 private static final long serialVersionUID = 4751314470965489371L; 050 public double value(double[] x) throws FunctionEvaluationException { 051 if (x[0] < 0) { 052 throw new FunctionEvaluationException(x, "{0}", "oops"); 053 } else if (x[0] > 1) { 054 throw new FunctionEvaluationException(new RuntimeException("oops"), x); 055 } else { 056 return x[0] * (1 - x[0]); 057 } 058 } 059 }; 060 try { 061 NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6); 062 optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 }); 063 fail("an exception should have been thrown"); 064 } catch (FunctionEvaluationException ce) { 065 // expected behavior 066 assertNull(ce.getCause()); 067 } catch (Exception e) { 068 fail("wrong exception caught: " + e.getMessage()); 069 } 070 try { 071 NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6); 072 optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 }); 073 fail("an exception should have been thrown"); 074 } catch (FunctionEvaluationException ce) { 075 // expected behavior 076 assertNotNull(ce.getCause()); 077 } catch (Exception e) { 078 fail("wrong exception caught: " + e.getMessage()); 079 } 080 } 081 082 @Test 083 public void testMinimizeMaximize() 084 throws FunctionEvaluationException, ConvergenceException { 085 086 // the following function has 4 local extrema: 087 final double xM = -3.841947088256863675365; 088 final double yM = -1.391745200270734924416; 089 final double xP = 0.2286682237349059125691; 090 final double yP = -yM; 091 final double valueXmYm = 0.2373295333134216789769; // local maximum 092 final double valueXmYp = -valueXmYm; // local minimum 093 final double valueXpYm = -0.7290400707055187115322; // global minimum 094 final double valueXpYp = -valueXpYm; // global maximum 095 MultivariateRealFunction fourExtrema = new MultivariateRealFunction() { 096 private static final long serialVersionUID = -7039124064449091152L; 097 public double value(double[] variables) throws FunctionEvaluationException { 098 final double x = variables[0]; 099 final double y = variables[1]; 100 return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y)); 101 } 102 }; 103 104 NelderMead optimizer = new NelderMead(); 105 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30)); 106 optimizer.setMaxIterations(100); 107 optimizer.setStartConfiguration(new double[] { 0.2, 0.2 }); 108 RealPointValuePair optimum; 109 110 // minimization 111 optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 }); 112 assertEquals(xM, optimum.getPoint()[0], 2.0e-7); 113 assertEquals(yP, optimum.getPoint()[1], 2.0e-5); 114 assertEquals(valueXmYp, optimum.getValue(), 6.0e-12); 115 assertTrue(optimizer.getEvaluations() > 60); 116 assertTrue(optimizer.getEvaluations() < 90); 117 118 optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 }); 119 assertEquals(xP, optimum.getPoint()[0], 5.0e-6); 120 assertEquals(yM, optimum.getPoint()[1], 6.0e-6); 121 assertEquals(valueXpYm, optimum.getValue(), 1.0e-11); 122 assertTrue(optimizer.getEvaluations() > 60); 123 assertTrue(optimizer.getEvaluations() < 90); 124 125 // maximization 126 optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 }); 127 assertEquals(xM, optimum.getPoint()[0], 1.0e-5); 128 assertEquals(yM, optimum.getPoint()[1], 3.0e-6); 129 assertEquals(valueXmYm, optimum.getValue(), 3.0e-12); 130 assertTrue(optimizer.getEvaluations() > 60); 131 assertTrue(optimizer.getEvaluations() < 90); 132 133 optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 }); 134 assertEquals(xP, optimum.getPoint()[0], 4.0e-6); 135 assertEquals(yP, optimum.getPoint()[1], 5.0e-6); 136 assertEquals(valueXpYp, optimum.getValue(), 7.0e-12); 137 assertTrue(optimizer.getEvaluations() > 60); 138 assertTrue(optimizer.getEvaluations() < 90); 139 140 } 141 142 @Test 143 public void testRosenbrock() 144 throws FunctionEvaluationException, ConvergenceException { 145 146 Rosenbrock rosenbrock = new Rosenbrock(); 147 NelderMead optimizer = new NelderMead(); 148 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1, 1.0e-3)); 149 optimizer.setMaxIterations(100); 150 optimizer.setStartConfiguration(new double[][] { 151 { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 } 152 }); 153 RealPointValuePair optimum = 154 optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 }); 155 156 assertEquals(rosenbrock.getCount(), optimizer.getEvaluations()); 157 assertTrue(optimizer.getEvaluations() > 40); 158 assertTrue(optimizer.getEvaluations() < 50); 159 assertTrue(optimum.getValue() < 8.0e-4); 160 161 } 162 163 @Test 164 public void testPowell() 165 throws FunctionEvaluationException, ConvergenceException { 166 167 Powell powell = new Powell(); 168 NelderMead optimizer = new NelderMead(); 169 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3)); 170 optimizer.setMaxIterations(200); 171 RealPointValuePair optimum = 172 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); 173 assertEquals(powell.getCount(), optimizer.getEvaluations()); 174 assertTrue(optimizer.getEvaluations() > 110); 175 assertTrue(optimizer.getEvaluations() < 130); 176 assertTrue(optimum.getValue() < 2.0e-3); 177 178 } 179 180 @Test 181 public void testLeastSquares1() 182 throws FunctionEvaluationException, ConvergenceException { 183 184 final RealMatrix factors = 185 new Array2DRowRealMatrix(new double[][] { 186 { 1.0, 0.0 }, 187 { 0.0, 1.0 } 188 }, false); 189 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { 190 public double[] value(double[] variables) { 191 return factors.operate(variables); 192 } 193 }, new double[] { 2.0, -3.0 }); 194 NelderMead optimizer = new NelderMead(); 195 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); 196 optimizer.setMaxIterations(200); 197 RealPointValuePair optimum = 198 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); 199 assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5); 200 assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4); 201 assertTrue(optimizer.getEvaluations() > 60); 202 assertTrue(optimizer.getEvaluations() < 80); 203 assertTrue(optimum.getValue() < 1.0e-6); 204 } 205 206 @Test 207 public void testLeastSquares2() 208 throws FunctionEvaluationException, ConvergenceException { 209 210 final RealMatrix factors = 211 new Array2DRowRealMatrix(new double[][] { 212 { 1.0, 0.0 }, 213 { 0.0, 1.0 } 214 }, false); 215 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { 216 public double[] value(double[] variables) { 217 return factors.operate(variables); 218 } 219 }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 }); 220 NelderMead optimizer = new NelderMead(); 221 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); 222 optimizer.setMaxIterations(200); 223 RealPointValuePair optimum = 224 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); 225 assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5); 226 assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4); 227 assertTrue(optimizer.getEvaluations() > 60); 228 assertTrue(optimizer.getEvaluations() < 80); 229 assertTrue(optimum.getValue() < 1.0e-6); 230 } 231 232 @Test 233 public void testLeastSquares3() 234 throws FunctionEvaluationException, ConvergenceException { 235 236 final RealMatrix factors = 237 new Array2DRowRealMatrix(new double[][] { 238 { 1.0, 0.0 }, 239 { 0.0, 1.0 } 240 }, false); 241 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() { 242 public double[] value(double[] variables) { 243 return factors.operate(variables); 244 } 245 }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] { 246 { 1.0, 1.2 }, { 1.2, 2.0 } 247 })); 248 NelderMead optimizer = new NelderMead(); 249 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6)); 250 optimizer.setMaxIterations(200); 251 RealPointValuePair optimum = 252 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 }); 253 assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3); 254 assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4); 255 assertTrue(optimizer.getEvaluations() > 60); 256 assertTrue(optimizer.getEvaluations() < 80); 257 assertTrue(optimum.getValue() < 1.0e-6); 258 } 259 260 @Test(expected = MaxIterationsExceededException.class) 261 public void testMaxIterations() throws MathException { 262 try { 263 Powell powell = new Powell(); 264 NelderMead optimizer = new NelderMead(); 265 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3)); 266 optimizer.setMaxIterations(20); 267 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); 268 } catch (OptimizationException oe) { 269 if (oe.getCause() instanceof ConvergenceException) { 270 throw (ConvergenceException) oe.getCause(); 271 } 272 throw oe; 273 } 274 } 275 276 @Test(expected = MaxEvaluationsExceededException.class) 277 public void testMaxEvaluations() throws MathException { 278 try { 279 Powell powell = new Powell(); 280 NelderMead optimizer = new NelderMead(); 281 optimizer.setConvergenceChecker(new SimpleRealPointChecker(-1.0, 1.0e-3)); 282 optimizer.setMaxEvaluations(20); 283 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); 284 } catch (FunctionEvaluationException fee) { 285 if (fee.getCause() instanceof ConvergenceException) { 286 throw (ConvergenceException) fee.getCause(); 287 } 288 throw fee; 289 } 290 } 291 292 private static class Rosenbrock implements MultivariateRealFunction { 293 294 private int count; 295 296 public Rosenbrock() { 297 count = 0; 298 } 299 300 public double value(double[] x) throws FunctionEvaluationException { 301 ++count; 302 double a = x[1] - x[0] * x[0]; 303 double b = 1.0 - x[0]; 304 return 100 * a * a + b * b; 305 } 306 307 public int getCount() { 308 return count; 309 } 310 311 } 312 313 private static class Powell implements MultivariateRealFunction { 314 315 private int count; 316 317 public Powell() { 318 count = 0; 319 } 320 321 public double value(double[] x) throws FunctionEvaluationException { 322 ++count; 323 double a = x[0] + 10 * x[1]; 324 double b = x[2] - x[3]; 325 double c = x[1] - 2 * x[2]; 326 double d = x[0] - x[3]; 327 return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d; 328 } 329 330 public int getCount() { 331 return count; 332 } 333 334 } 335 336 }