/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math.stat.regression;

import org.apache.commons.math.MathRuntimeException;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.ArrayRealVector;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>OLS assumes that the covariance matrix of the error terms is diagonal,
 * with constant variance on the diagonal:</p>
 * <p>
 * u ~ N(0, σ<sup>2</sup>I)
 * </p>
 *
 * <p>The regression coefficients, b, satisfy the normal equations:</p>
 * <p>
 * X<sup>T</sup> X b = X<sup>T</sup> y
 * </p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the X matrix. (See {@link QRDecompositionImpl} for details on the
 * decomposition algorithm.)
 * </p>
 * <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/>
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y <br/>
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R b = Q<sup>T</sup> y
 * </p>
 * <p>Given Q and R, the last equation is solved by back-substitution.</p>
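 *
 * <p>A minimal usage sketch (the sample values below are invented purely for
 * illustration, and the estimation methods shown are inherited from
 * {@link AbstractMultipleLinearRegression}):</p>
 * <pre>
 * OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
 * double[] y = new double[] {11d, 12d, 13d, 14d, 15d, 16d};
 * double[][] x = new double[][] {
 *     {0d, 0d, 0d}, {2d, 0d, 0d}, {0d, 3d, 0d},
 *     {0d, 0d, 3d}, {1d, 1d, 1d}, {2d, 2d, 2d}};
 * regression.newSampleData(y, x);
 * double[] beta = regression.estimateRegressionParameters();
 * double[] residuals = regression.estimateResiduals();
 * </pre>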
 *
 * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws IllegalArgumentException if the x and y array data are not
     *         compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     *
     * Computes and caches QR decomposition of the X matrix
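     *
     * <p>For example, assuming the flat, row-major layout documented for the
     * inherited method (each observation contributes its y value followed by
     * its regressor values), a hypothetical call for 3 observations on 2
     * regressors might look like this, where <code>regression</code> is an
     * instance of this class:</p>
     * <pre>
     * double[] data = new double[] {
     *     1d, 2d, 3d,     // y = 1, x = (2, 3)
     *     4d, 5d, 6d,     // y = 4, x = (5, 6)
     *     9d, 8d, 7d};    // y = 9, x = (8, 7)
     * regression.newSampleData(data, 3, 2);
     * </pre>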
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>.
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
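     * <p>For example, the diagonal of the returned matrix gives the leverage
     * of each observation. A hypothetical usage sketch, assuming
     * <code>regression</code> is an instance of this class that already holds
     * sample data:</p>
     * <pre>
     * RealMatrix hat = regression.calculateHat();
     * double leverageOfFirstObservation = hat.getEntry(0, 0);
     * </pre>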
     *
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * Loads new x sample data, overriding any previous sample
     *
     * @param x the [n,k] array representing the x sample
     */
    @Override
    protected void newXSampleData(double[][] x) {
        this.X = new Array2DRowRealMatrix(x);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * Calculates regression coefficients using OLS.
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
    }

    /**
     * <p>Calculates the variance of the beta estimates by OLS.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Since X = QR and Q is orthogonal, (X<sup>T</sup>X)<sup>-1</sup>
     * = (R<sup>T</sup>R)<sup>-1</sup> = R<sup>-1</sup>(R<sup>T</sup>)<sup>-1</sup>.
     * This implementation computes the latter product using only the top
     * p rows of R, where p = the length of the beta vector.</p>
     *
     * @return The beta variance
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = X.getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
        return Rinv.multiply(Rinv.transpose());
    }

    /**
     * <p>Calculates the variance of y by OLS.
     * </p>
     * <p>Var(y) = Tr(u<sup>T</sup>u) / (n - k), where u is the vector of
     * residuals, n is the number of observations and k is the number of
     * regressors.
     * </p>
     * @return The Y variance
     */
    @Override
    protected double calculateYVariance() {
        RealVector residuals = calculateResiduals();
        return residuals.dotProduct(residuals) /
               (X.getRowDimension() - X.getColumnDimension());
    }

    /** TODO: Find a home for the following methods in the linear package */

    /**
     * <p>Uses back substitution to solve the system</p>
     *
     * <p>coefficients X = constants</p>
     *
     * <p>coefficients must be upper-triangular and constants must be a
     * column vector.  The solution is returned as a column vector.</p>
     *
     * <p>The number of columns in coefficients determines the length
     * of the returned solution vector.  If constants has more rows than
     * coefficients has columns, excess rows are ignored.
     * Similarly, extra (zero) rows in coefficients are ignored.</p>
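     *
     * <p>For illustration, with the hypothetical inputs
     * coefficients = [[2, 1], [0, 4]] and constants = (3, 8), back substitution
     * first gives x<sub>1</sub> = 8 / 4 = 2 and then
     * x<sub>0</sub> = (3 - 1 &times; 2) / 2 = 0.5.</p>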
     *
     * @param coefficients upper-triangular coefficients matrix
     * @param constants column RHS constants vector
     * @return solution vector
     *
     */
    private static RealVector solveUpperTriangular(RealMatrix coefficients,
                                                   RealVector constants) {
        checkUpperTriangular(coefficients, 1E-12);
        int length = coefficients.getColumnDimension();
        double[] x = new double[length];
        for (int i = 0; i < length; i++) {
            int index = length - 1 - i;
            double sum = 0;
            for (int j = index + 1; j < length; j++) {
                sum += coefficients.getEntry(index, j) * x[j];
            }
            x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
        }
        return new ArrayRealVector(x);
    }

    /**
     * <p>Check if a matrix is upper-triangular.</p>
     *
     * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
     *
     * @param m matrix to check
     * @param epsilon maximum allowable absolute value for elements below
     *        the main diagonal
     *
     * @throws IllegalArgumentException if m is not upper-triangular
     */
    private static void checkUpperTriangular(RealMatrix m, double epsilon) {
        int nCols = m.getColumnDimension();
        int nRows = m.getRowDimension();
        for (int r = 0; r < nRows; r++) {
            int bound = Math.min(r, nCols);
            for (int c = 0; c < bound; c++) {
                if (Math.abs(m.getEntry(r, c)) > epsilon) {
                    throw MathRuntimeException.createIllegalArgumentException(
                            "matrix is not upper-triangular, entry ({0}, {1}) = {2} is too large",
                            r, c, m.getEntry(r, c));
                }
            }
        }
    }
}