1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.optimization.direct;
19
20 import java.util.Comparator;
21
22 import org.apache.commons.math.FunctionEvaluationException;
23 import org.apache.commons.math.optimization.OptimizationException;
24 import org.apache.commons.math.optimization.RealPointValuePair;
25
26
27
28
29
30
31
32
33 public class NelderMead extends DirectSearchOptimizer {
34
35
36 private final double rho;
37
38
39 private final double khi;
40
41
42 private final double gamma;
43
44
45 private final double sigma;
46
47
48
49
50
51 public NelderMead() {
52 this.rho = 1.0;
53 this.khi = 2.0;
54 this.gamma = 0.5;
55 this.sigma = 0.5;
56 }
57
58
59
60
61
62
63
64 public NelderMead(final double rho, final double khi,
65 final double gamma, final double sigma) {
66 this.rho = rho;
67 this.khi = khi;
68 this.gamma = gamma;
69 this.sigma = sigma;
70 }
71
72
73 @Override
74 protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
75 throws FunctionEvaluationException, OptimizationException {
76
77 incrementIterationsCounter();
78
79
80 final int n = simplex.length - 1;
81
82
83 final RealPointValuePair best = simplex[0];
84 final RealPointValuePair secondBest = simplex[n-1];
85 final RealPointValuePair worst = simplex[n];
86 final double[] xWorst = worst.getPointRef();
87
88
89
90 final double[] centroid = new double[n];
91 for (int i = 0; i < n; ++i) {
92 final double[] x = simplex[i].getPointRef();
93 for (int j = 0; j < n; ++j) {
94 centroid[j] += x[j];
95 }
96 }
97 final double scaling = 1.0 / n;
98 for (int j = 0; j < n; ++j) {
99 centroid[j] *= scaling;
100 }
101
102
103 final double[] xR = new double[n];
104 for (int j = 0; j < n; ++j) {
105 xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
106 }
107 final RealPointValuePair reflected = new RealPointValuePair(xR, evaluate(xR), false);
108
109 if ((comparator.compare(best, reflected) <= 0) &&
110 (comparator.compare(reflected, secondBest) < 0)) {
111
112
113 replaceWorstPoint(reflected, comparator);
114
115 } else if (comparator.compare(reflected, best) < 0) {
116
117
118 final double[] xE = new double[n];
119 for (int j = 0; j < n; ++j) {
120 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
121 }
122 final RealPointValuePair expanded = new RealPointValuePair(xE, evaluate(xE), false);
123
124 if (comparator.compare(expanded, reflected) < 0) {
125
126 replaceWorstPoint(expanded, comparator);
127 } else {
128
129 replaceWorstPoint(reflected, comparator);
130 }
131
132 } else {
133
134 if (comparator.compare(reflected, worst) < 0) {
135
136
137 final double[] xC = new double[n];
138 for (int j = 0; j < n; ++j) {
139 xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
140 }
141 final RealPointValuePair outContracted = new RealPointValuePair(xC, evaluate(xC), false);
142
143 if (comparator.compare(outContracted, reflected) <= 0) {
144
145 replaceWorstPoint(outContracted, comparator);
146 return;
147 }
148
149 } else {
150
151
152 final double[] xC = new double[n];
153 for (int j = 0; j < n; ++j) {
154 xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
155 }
156 final RealPointValuePair inContracted = new RealPointValuePair(xC, evaluate(xC), false);
157
158 if (comparator.compare(inContracted, worst) < 0) {
159
160 replaceWorstPoint(inContracted, comparator);
161 return;
162 }
163
164 }
165
166
167 final double[] xSmallest = simplex[0].getPointRef();
168 for (int i = 1; i < simplex.length; ++i) {
169 final double[] x = simplex[i].getPoint();
170 for (int j = 0; j < n; ++j) {
171 x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
172 }
173 simplex[i] = new RealPointValuePair(x, Double.NaN, false);
174 }
175 evaluateSimplex(comparator);
176
177 }
178
179 }
180
181 }