View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.linear;
19  
20  import java.lang.reflect.Array;
21  
22  import org.apache.commons.math.Field;
23  import org.apache.commons.math.FieldElement;
24  import org.apache.commons.math.MathRuntimeException;
25  
26  /**
27   * Calculates the LUP-decomposition of a square matrix.
28   * <p>The LUP-decomposition of a matrix A consists of three matrices
29   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
30   * upper triangular and P is a permutation matrix. All matrices are
31   * m&times;m.</p>
32   * <p>Since {@link FieldElement field elements} do not provide an ordering
33   * operator, the permutation matrix is computed here only in order to avoid
34   * a zero pivot element, no attempt is done to get the largest pivot element.</p>
35   *
36   * @param <T> the type of the field elements
37   * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $
38   * @since 2.0
39   */
40  public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
41  
42      /** Field to which the elements belong. */
43      private final Field<T> field;
44  
45      /** Entries of LU decomposition. */
46      private T lu[][];
47  
48      /** Pivot permutation associated with LU decomposition */
49      private int[] pivot;
50  
51      /** Parity of the permutation associated with the LU decomposition */
52      private boolean even;
53  
54      /** Singularity indicator. */
55      private boolean singular;
56  
57      /** Cached value of L. */
58      private FieldMatrix<T> cachedL;
59  
60      /** Cached value of U. */
61      private FieldMatrix<T> cachedU;
62  
63      /** Cached value of P. */
64      private FieldMatrix<T> cachedP;
65  
66      /**
67       * Calculates the LU-decomposition of the given matrix. 
68       * @param matrix The matrix to decompose.
69       * @exception NonSquareMatrixException if matrix is not square
70       */
71      public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
72          throws NonSquareMatrixException {
73  
74          if (!matrix.isSquare()) {
75              throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
76          }
77  
78          final int m = matrix.getColumnDimension();
79          field = matrix.getField();
80          lu = matrix.getData();
81          pivot = new int[m];
82          cachedL = null;
83          cachedU = null;
84          cachedP = null;
85  
86          // Initialize permutation array and parity
87          for (int row = 0; row < m; row++) {
88              pivot[row] = row;
89          }
90          even     = true;
91          singular = false;
92  
93          // Loop over columns
94          for (int col = 0; col < m; col++) {
95  
96              T sum = field.getZero();
97  
98              // upper
99              for (int row = 0; row < col; row++) {
100                 final T[] luRow = lu[row];
101                 sum = luRow[col];
102                 for (int i = 0; i < row; i++) {
103                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
104                 }
105                 luRow[col] = sum;
106             }
107 
108             // lower
109             int nonZero = col; // permutation row
110             for (int row = col; row < m; row++) {
111                 final T[] luRow = lu[row];
112                 sum = luRow[col];
113                 for (int i = 0; i < col; i++) {
114                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
115                 }
116                 luRow[col] = sum;
117 
118                 if (lu[nonZero][col].equals(field.getZero())) {
119                     // try to select a better permutation choice
120                     ++nonZero;
121                 }
122             }
123 
124             // Singularity check
125             if (nonZero >= m) {
126                 singular = true;
127                 return;
128             }
129 
130             // Pivot if necessary
131             if (nonZero != col) {
132                 T tmp = field.getZero();
133                 for (int i = 0; i < m; i++) {
134                     tmp = lu[nonZero][i];
135                     lu[nonZero][i] = lu[col][i];
136                     lu[col][i] = tmp;
137                 }
138                 int temp = pivot[nonZero];
139                 pivot[nonZero] = pivot[col];
140                 pivot[col] = temp;
141                 even = !even;
142             }
143 
144             // Divide the lower elements by the "winning" diagonal elt.
145             final T luDiag = lu[col][col];
146             for (int row = col + 1; row < m; row++) {
147                 final T[] luRow = lu[row];
148                 luRow[col] = luRow[col].divide(luDiag);
149             }
150         }
151 
152     }
153 
154     /** {@inheritDoc} */
155     public FieldMatrix<T> getL() {
156         if ((cachedL == null) && !singular) {
157             final int m = pivot.length;
158             cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
159             for (int i = 0; i < m; ++i) {
160                 final T[] luI = lu[i];
161                 for (int j = 0; j < i; ++j) {
162                     cachedL.setEntry(i, j, luI[j]);
163                 }
164                 cachedL.setEntry(i, i, field.getOne());
165             }
166         }
167         return cachedL;
168     }
169 
170     /** {@inheritDoc} */
171     public FieldMatrix<T> getU() {
172         if ((cachedU == null) && !singular) {
173             final int m = pivot.length;
174             cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
175             for (int i = 0; i < m; ++i) {
176                 final T[] luI = lu[i];
177                 for (int j = i; j < m; ++j) {
178                     cachedU.setEntry(i, j, luI[j]);
179                 }
180             }
181         }
182         return cachedU;
183     }
184 
185     /** {@inheritDoc} */
186     public FieldMatrix<T> getP() {
187         if ((cachedP == null) && !singular) {
188             final int m = pivot.length;
189             cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
190             for (int i = 0; i < m; ++i) {
191                 cachedP.setEntry(i, pivot[i], field.getOne());
192             }
193         }
194         return cachedP;
195     }
196 
197     /** {@inheritDoc} */
198     public int[] getPivot() {
199         return pivot.clone();
200     }
201 
202     /** {@inheritDoc} */
203     public T getDeterminant() {
204         if (singular) {
205             return field.getZero();
206         } else {
207             final int m = pivot.length;
208             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
209             for (int i = 0; i < m; i++) {
210                 determinant = determinant.multiply(lu[i][i]);
211             }
212             return determinant;
213         }
214     }
215 
216     /** {@inheritDoc} */
217     public FieldDecompositionSolver<T> getSolver() {
218         return new Solver<T>(field, lu, pivot, singular);
219     }
220 
221     /** Specialized solver. */
222     private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
223 
224         /** Serializable version identifier. */
225         private static final long serialVersionUID = -6353105415121373022L;
226 
227         /** Field to which the elements belong. */
228         private final Field<T> field;
229 
230         /** Entries of LU decomposition. */
231         private final T lu[][];
232 
233         /** Pivot permutation associated with LU decomposition. */
234         private final int[] pivot;
235 
236         /** Singularity indicator. */
237         private final boolean singular;
238 
239         /**
240          * Build a solver from decomposed matrix.
241          * @param field field to which the matrix elements belong
242          * @param lu entries of LU decomposition
243          * @param pivot pivot permutation associated with LU decomposition
244          * @param singular singularity indicator
245          */
246         private Solver(final Field<T> field, final T[][] lu,
247                        final int[] pivot, final boolean singular) {
248             this.field    = field;
249             this.lu       = lu;
250             this.pivot    = pivot;
251             this.singular = singular;
252         }
253 
254         /** {@inheritDoc} */
255         public boolean isNonSingular() {
256             return !singular;
257         }
258 
259         /** {@inheritDoc} */
260         @SuppressWarnings("unchecked")
261         public T[] solve(T[] b)
262             throws IllegalArgumentException, InvalidMatrixException {
263 
264             final int m = pivot.length;
265             if (b.length != m) {
266                 throw MathRuntimeException.createIllegalArgumentException(
267                         "vector length mismatch: got {0} but expected {1}",
268                         b.length, m);
269             }
270             if (singular) {
271                 throw new SingularMatrixException();
272             }
273 
274             final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
275 
276             // Apply permutations to b
277             for (int row = 0; row < m; row++) {
278                 bp[row] = b[pivot[row]];
279             }
280 
281             // Solve LY = b
282             for (int col = 0; col < m; col++) {
283                 final T bpCol = bp[col];
284                 for (int i = col + 1; i < m; i++) {
285                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
286                 }
287             }
288 
289             // Solve UX = Y
290             for (int col = m - 1; col >= 0; col--) {
291                 bp[col] = bp[col].divide(lu[col][col]);
292                 final T bpCol = bp[col];
293                 for (int i = 0; i < col; i++) {
294                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
295                 }
296             }
297 
298             return bp;
299 
300         }
301 
302         /** {@inheritDoc} */
303         @SuppressWarnings("unchecked")
304         public FieldVector<T> solve(FieldVector<T> b)
305             throws IllegalArgumentException, InvalidMatrixException {
306             try {
307                 return solve((ArrayFieldVector<T>) b);
308             } catch (ClassCastException cce) {
309 
310                 final int m = pivot.length;
311                 if (b.getDimension() != m) {
312                     throw MathRuntimeException.createIllegalArgumentException(
313                             "vector length mismatch: got {0} but expected {1}",
314                             b.getDimension(), m);
315                 }
316                 if (singular) {
317                     throw new SingularMatrixException();
318                 }
319 
320                 final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
321 
322                 // Apply permutations to b
323                 for (int row = 0; row < m; row++) {
324                     bp[row] = b.getEntry(pivot[row]);
325                 }
326 
327                 // Solve LY = b
328                 for (int col = 0; col < m; col++) {
329                     final T bpCol = bp[col];
330                     for (int i = col + 1; i < m; i++) {
331                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
332                     }
333                 }
334 
335                 // Solve UX = Y
336                 for (int col = m - 1; col >= 0; col--) {
337                     bp[col] = bp[col].divide(lu[col][col]);
338                     final T bpCol = bp[col];
339                     for (int i = 0; i < col; i++) {
340                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
341                     }
342                 }
343 
344                 return new ArrayFieldVector<T>(bp, false);
345 
346             }
347         }
348 
349         /** Solve the linear equation A &times; X = B.
350          * <p>The A matrix is implicit here. It is </p>
351          * @param b right-hand side of the equation A &times; X = B
352          * @return a vector X such that A &times; X = B
353          * @exception IllegalArgumentException if matrices dimensions don't match
354          * @exception InvalidMatrixException if decomposed matrix is singular
355          */
356         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
357             throws IllegalArgumentException, InvalidMatrixException {
358             return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
359         }
360 
361         /** {@inheritDoc} */
362         @SuppressWarnings("unchecked")
363         public FieldMatrix<T> solve(FieldMatrix<T> b)
364             throws IllegalArgumentException, InvalidMatrixException {
365 
366             final int m = pivot.length;
367             if (b.getRowDimension() != m) {
368                 throw MathRuntimeException.createIllegalArgumentException(
369                         "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
370                         b.getRowDimension(), b.getColumnDimension(), m, "n");
371             }
372             if (singular) {
373                 throw new SingularMatrixException();
374             }
375 
376             final int nColB = b.getColumnDimension();
377 
378             // Apply permutations to b
379             final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
380             for (int row = 0; row < m; row++) {
381                 final T[] bpRow = bp[row];
382                 final int pRow = pivot[row];
383                 for (int col = 0; col < nColB; col++) {
384                     bpRow[col] = b.getEntry(pRow, col);
385                 }
386             }
387 
388             // Solve LY = b
389             for (int col = 0; col < m; col++) {
390                 final T[] bpCol = bp[col];
391                 for (int i = col + 1; i < m; i++) {
392                     final T[] bpI = bp[i];
393                     final T luICol = lu[i][col];
394                     for (int j = 0; j < nColB; j++) {
395                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
396                     }
397                 }
398             }
399 
400             // Solve UX = Y
401             for (int col = m - 1; col >= 0; col--) {
402                 final T[] bpCol = bp[col];
403                 final T luDiag = lu[col][col];
404                 for (int j = 0; j < nColB; j++) {
405                     bpCol[j] = bpCol[j].divide(luDiag);
406                 }
407                 for (int i = 0; i < col; i++) {
408                     final T[] bpI = bp[i];
409                     final T luICol = lu[i][col];
410                     for (int j = 0; j < nColB; j++) {
411                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
412                     }
413                 }
414             }
415 
416             return new Array2DRowFieldMatrix<T>(bp, false);
417 
418         }
419 
420         /** {@inheritDoc} */
421         public FieldMatrix<T> getInverse() throws InvalidMatrixException {
422             final int m = pivot.length;
423             final T one = field.getOne();
424             FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
425             for (int i = 0; i < m; ++i) {
426                 identity.setEntry(i, i, one);
427             }
428             return solve(identity);
429         }
430 
431     }
432 
433 }