View Javadoc

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math.linear;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math.MathRuntimeException;
23  
24  
25  /**
26   * Calculates the QR-decomposition of a matrix.
27   * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
28   * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
29   * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
30   * <p>This class compute the decomposition using Householder reflectors.</p>
31   * <p>For efficiency purposes, the decomposition in packed form is transposed.
32   * This allows inner loop to iterate inside rows, which is much more cache-efficient
33   * in Java.</p>
34   *
35   * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
36   * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
37   *
38   * @version $Revision: 799857 $ $Date: 2009-08-01 09:07:12 -0400 (Sat, 01 Aug 2009) $
39   * @since 1.2
40   */
41  public class QRDecompositionImpl implements QRDecomposition {
42  
43      /**
44       * A packed TRANSPOSED representation of the QR decomposition.
45       * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
46       * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
47       * from which an explicit form of Q can be recomputed if desired.</p>
48       */
49      private double[][] qrt;
50  
51      /** The diagonal elements of R. */
52      private double[] rDiag;
53  
54      /** Cached value of Q. */
55      private RealMatrix cachedQ;
56  
57      /** Cached value of QT. */
58      private RealMatrix cachedQT;
59  
60      /** Cached value of R. */
61      private RealMatrix cachedR;
62  
63      /** Cached value of H. */
64      private RealMatrix cachedH;
65  
66      /**
67       * Calculates the QR-decomposition of the given matrix. 
68       * @param matrix The matrix to decompose.
69       */
70      public QRDecompositionImpl(RealMatrix matrix) {
71  
72          final int m = matrix.getRowDimension();
73          final int n = matrix.getColumnDimension();
74          qrt = matrix.transpose().getData();
75          rDiag = new double[Math.min(m, n)];
76          cachedQ  = null;
77          cachedQT = null;
78          cachedR  = null;
79          cachedH  = null;
80  
81          /*
82           * The QR decomposition of a matrix A is calculated using Householder
83           * reflectors by repeating the following operations to each minor
84           * A(minor,minor) of A:
85           */
86          for (int minor = 0; minor < Math.min(m, n); minor++) {
87  
88              final double[] qrtMinor = qrt[minor];
89  
90              /*
91               * Let x be the first column of the minor, and a^2 = |x|^2.
92               * x will be in the positions qr[minor][minor] through qr[m][minor].
93               * The first column of the transformed minor will be (a,0,0,..)'
94               * The sign of a is chosen to be opposite to the sign of the first
95               * component of x. Let's find a:
96               */
97              double xNormSqr = 0;
98              for (int row = minor; row < m; row++) {
99                  final double c = qrtMinor[row];
100                 xNormSqr += c * c;
101             }
102             final double a = (qrtMinor[minor] > 0) ? -Math.sqrt(xNormSqr) : Math.sqrt(xNormSqr);
103             rDiag[minor] = a;
104 
105             if (a != 0.0) {
106 
107                 /*
108                  * Calculate the normalized reflection vector v and transform
109                  * the first column. We know the norm of v beforehand: v = x-ae
110                  * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
111                  * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
112                  * Here <x, e> is now qr[minor][minor].
113                  * v = x-ae is stored in the column at qr:
114                  */
115                 qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
116 
117                 /*
118                  * Transform the rest of the columns of the minor:
119                  * They will be transformed by the matrix H = I-2vv'/|v|^2.
120                  * If x is a column vector of the minor, then
121                  * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
122                  * Therefore the transformation is easily calculated by
123                  * subtracting the column vector (2<x,v>/|v|^2)v from x.
124                  *
125                  * Let 2<x,v>/|v|^2 = alpha. From above we have
126                  * |v|^2 = -2a*(qr[minor][minor]), so
127                  * alpha = -<x,v>/(a*qr[minor][minor])
128                  */
129                 for (int col = minor+1; col < n; col++) {
130                     final double[] qrtCol = qrt[col];
131                     double alpha = 0;
132                     for (int row = minor; row < m; row++) {
133                         alpha -= qrtCol[row] * qrtMinor[row];
134                     }
135                     alpha /= a * qrtMinor[minor];
136 
137                     // Subtract the column vector alpha*v from x.
138                     for (int row = minor; row < m; row++) {
139                         qrtCol[row] -= alpha * qrtMinor[row];
140                     }
141                 }
142             }
143         }
144     }
145 
146     /** {@inheritDoc} */
147     public RealMatrix getR() {
148 
149         if (cachedR == null) {
150 
151             // R is supposed to be m x n
152             final int n = qrt.length;
153             final int m = qrt[0].length;
154             cachedR = MatrixUtils.createRealMatrix(m, n);
155 
156             // copy the diagonal from rDiag and the upper triangle of qr
157             for (int row = Math.min(m, n) - 1; row >= 0; row--) {
158                 cachedR.setEntry(row, row, rDiag[row]);
159                 for (int col = row + 1; col < n; col++) {
160                     cachedR.setEntry(row, col, qrt[col][row]);
161                 }
162             }
163 
164         }
165 
166         // return the cached matrix
167         return cachedR;
168 
169     }
170 
171     /** {@inheritDoc} */
172     public RealMatrix getQ() {
173         if (cachedQ == null) {
174             cachedQ = getQT().transpose();
175         }
176         return cachedQ;
177     }
178 
179     /** {@inheritDoc} */
180     public RealMatrix getQT() {
181 
182         if (cachedQT == null) {
183 
184             // QT is supposed to be m x m
185             final int n = qrt.length;
186             final int m = qrt[0].length;
187             cachedQT = MatrixUtils.createRealMatrix(m, m);
188 
189             /* 
190              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then 
191              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in 
192              * succession to the result 
193              */ 
194             for (int minor = m - 1; minor >= Math.min(m, n); minor--) {
195                 cachedQT.setEntry(minor, minor, 1.0);
196             }
197 
198             for (int minor = Math.min(m, n)-1; minor >= 0; minor--){
199                 final double[] qrtMinor = qrt[minor];
200                 cachedQT.setEntry(minor, minor, 1.0);
201                 if (qrtMinor[minor] != 0.0) {
202                     for (int col = minor; col < m; col++) {
203                         double alpha = 0;
204                         for (int row = minor; row < m; row++) {
205                             alpha -= cachedQT.getEntry(col, row) * qrtMinor[row];
206                         }
207                         alpha /= rDiag[minor] * qrtMinor[minor];
208 
209                         for (int row = minor; row < m; row++) {
210                             cachedQT.addToEntry(col, row, -alpha * qrtMinor[row]);
211                         }
212                     }
213                 }
214             }
215 
216         }
217 
218         // return the cached matrix
219         return cachedQT;
220 
221     }
222 
223     /** {@inheritDoc} */
224     public RealMatrix getH() {
225 
226         if (cachedH == null) {
227 
228             final int n = qrt.length;
229             final int m = qrt[0].length;
230             cachedH = MatrixUtils.createRealMatrix(m, n);
231             for (int i = 0; i < m; ++i) {
232                 for (int j = 0; j < Math.min(i + 1, n); ++j) {
233                     cachedH.setEntry(i, j, qrt[j][i] / -rDiag[j]);
234                 }
235             }
236 
237         }
238 
239         // return the cached matrix
240         return cachedH;
241 
242     }
243 
244     /** {@inheritDoc} */
245     public DecompositionSolver getSolver() {
246         return new Solver(qrt, rDiag);
247     }
248 
249     /** Specialized solver. */
250     private static class Solver implements DecompositionSolver {
251     
252         /**
253          * A packed TRANSPOSED representation of the QR decomposition.
254          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
255          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
256          * from which an explicit form of Q can be recomputed if desired.</p>
257          */
258         private final double[][] qrt;
259 
260         /** The diagonal elements of R. */
261         private final double[] rDiag;
262 
263         /**
264          * Build a solver from decomposed matrix.
265          * @param qrt packed TRANSPOSED representation of the QR decomposition
266          * @param rDiag diagonal elements of R
267          */
268         private Solver(final double[][] qrt, final double[] rDiag) {
269             this.qrt   = qrt;
270             this.rDiag = rDiag;
271         }
272 
273         /** {@inheritDoc} */
274         public boolean isNonSingular() {
275 
276             for (double diag : rDiag) {
277                 if (diag == 0) {
278                     return false;
279                 }
280             }
281             return true;
282 
283         }
284 
285         /** {@inheritDoc} */
286         public double[] solve(double[] b)
287         throws IllegalArgumentException, InvalidMatrixException {
288 
289             final int n = qrt.length;
290             final int m = qrt[0].length;
291             if (b.length != m) {
292                 throw MathRuntimeException.createIllegalArgumentException(
293                         "vector length mismatch: got {0} but expected {1}",
294                         b.length, m);
295             }
296             if (!isNonSingular()) {
297                 throw new SingularMatrixException();
298             }
299 
300             final double[] x = new double[n];
301             final double[] y = b.clone();
302 
303             // apply Householder transforms to solve Q.y = b
304             for (int minor = 0; minor < Math.min(m, n); minor++) {
305 
306                 final double[] qrtMinor = qrt[minor];
307                 double dotProduct = 0;
308                 for (int row = minor; row < m; row++) {
309                     dotProduct += y[row] * qrtMinor[row];
310                 }
311                 dotProduct /= rDiag[minor] * qrtMinor[minor];
312 
313                 for (int row = minor; row < m; row++) {
314                     y[row] += dotProduct * qrtMinor[row];
315                 }
316 
317             }
318 
319             // solve triangular system R.x = y
320             for (int row = rDiag.length - 1; row >= 0; --row) {
321                 y[row] /= rDiag[row];
322                 final double yRow   = y[row];
323                 final double[] qrtRow = qrt[row];
324                 x[row] = yRow;
325                 for (int i = 0; i < row; i++) {
326                     y[i] -= yRow * qrtRow[i];
327                 }
328             }
329 
330             return x;
331 
332         }
333 
334         /** {@inheritDoc} */
335         public RealVector solve(RealVector b)
336         throws IllegalArgumentException, InvalidMatrixException {
337             try {
338                 return solve((ArrayRealVector) b);
339             } catch (ClassCastException cce) {
340                 return new ArrayRealVector(solve(b.getData()), false);
341             }
342         }
343 
344         /** Solve the linear equation A &times; X = B.
345          * <p>The A matrix is implicit here. It is </p>
346          * @param b right-hand side of the equation A &times; X = B
347          * @return a vector X that minimizes the two norm of A &times; X - B
348          * @throws IllegalArgumentException if matrices dimensions don't match
349          * @throws InvalidMatrixException if decomposed matrix is singular
350          */
351         public ArrayRealVector solve(ArrayRealVector b)
352         throws IllegalArgumentException, InvalidMatrixException {
353             return new ArrayRealVector(solve(b.getDataRef()), false);
354         }
355 
356         /** {@inheritDoc} */
357         public RealMatrix solve(RealMatrix b)
358         throws IllegalArgumentException, InvalidMatrixException {
359 
360             final int n = qrt.length;
361             final int m = qrt[0].length;
362             if (b.getRowDimension() != m) {
363                 throw MathRuntimeException.createIllegalArgumentException(
364                         "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
365                         b.getRowDimension(), b.getColumnDimension(), m, "n");
366             }
367             if (!isNonSingular()) {
368                 throw new SingularMatrixException();
369             }
370 
371             final int columns        = b.getColumnDimension();
372             final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
373             final int cBlocks        = (columns + blockSize - 1) / blockSize;
374             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
375             final double[][] y       = new double[b.getRowDimension()][blockSize];
376             final double[]   alpha   = new double[blockSize];
377 
378             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
379                 final int kStart = kBlock * blockSize;
380                 final int kEnd   = Math.min(kStart + blockSize, columns);
381                 final int kWidth = kEnd - kStart;
382 
383                 // get the right hand side vector
384                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
385 
386                 // apply Householder transforms to solve Q.y = b
387                 for (int minor = 0; minor < Math.min(m, n); minor++) {
388                     final double[] qrtMinor = qrt[minor];
389                     final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]); 
390 
391                     Arrays.fill(alpha, 0, kWidth, 0.0);
392                     for (int row = minor; row < m; ++row) {
393                         final double   d    = qrtMinor[row];
394                         final double[] yRow = y[row];
395                         for (int k = 0; k < kWidth; ++k) {
396                             alpha[k] += d * yRow[k];
397                         }
398                     }
399                     for (int k = 0; k < kWidth; ++k) {
400                         alpha[k] *= factor;
401                     }
402 
403                     for (int row = minor; row < m; ++row) {
404                         final double   d    = qrtMinor[row];
405                         final double[] yRow = y[row];
406                         for (int k = 0; k < kWidth; ++k) {
407                             yRow[k] += alpha[k] * d;
408                         }
409                     }
410 
411                 }
412 
413                 // solve triangular system R.x = y
414                 for (int j = rDiag.length - 1; j >= 0; --j) {
415                     final int      jBlock = j / blockSize;
416                     final int      jStart = jBlock * blockSize;
417                     final double   factor = 1.0 / rDiag[j];
418                     final double[] yJ     = y[j];
419                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
420                     for (int k = 0, index = (j - jStart) * kWidth; k < kWidth; ++k, ++index) {
421                         yJ[k]        *= factor;
422                         xBlock[index] = yJ[k];
423                     }
424 
425                     final double[] qrtJ = qrt[j];
426                     for (int i = 0; i < j; ++i) {
427                         final double rIJ  = qrtJ[i];
428                         final double[] yI = y[i];
429                         for (int k = 0; k < kWidth; ++k) {
430                             yI[k] -= yJ[k] * rIJ;
431                         }
432                     }
433 
434                 }
435 
436             }
437 
438             return new BlockRealMatrix(n, columns, xBlocks, false);
439 
440         }
441 
442         /** {@inheritDoc} */
443         public RealMatrix getInverse()
444         throws InvalidMatrixException {
445             return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
446         }
447 
448     }
449 
450 }