001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math.stat.regression; 018 019 import org.apache.commons.math.MathRuntimeException; 020 import org.apache.commons.math.linear.LUDecompositionImpl; 021 import org.apache.commons.math.linear.QRDecomposition; 022 import org.apache.commons.math.linear.QRDecompositionImpl; 023 import org.apache.commons.math.linear.RealMatrix; 024 import org.apache.commons.math.linear.Array2DRowRealMatrix; 025 import org.apache.commons.math.linear.RealVector; 026 import org.apache.commons.math.linear.ArrayRealVector; 027 028 /** 029 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 030 * multiple linear regression model.</p> 031 * 032 * <p>OLS assumes the covariance matrix of the error to be diagonal and with 033 * equal variance.</p> 034 * <p> 035 * u ~ N(0, σ<sup>2</sup>I) 036 * </p> 037 * 038 * <p>The regression coefficients, b, satisfy the normal equations: 039 * <p> 040 * X<sup>T</sup> X b = X<sup>T</sup> y 041 * </p> 042 * 043 * <p>To solve the normal equations, this implementation uses QR decomposition 044 * of the X matrix. (See {@link QRDecompositionImpl} for details on the 045 * decomposition algorithm.) 046 * </p> 047 * <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/> 048 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y <br/> 049 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/> 050 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/> 051 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/> 052 * R b = Q<sup>T</sup> y 053 * </p> 054 * Given Q and R, the last equation is solved by back-subsitution.</p> 055 * 056 * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $ 057 * @since 2.0 058 */ 059 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { 060 061 /** Cached QR decomposition of X matrix */ 062 private QRDecomposition qr = null; 063 064 /** 065 * Loads model x and y sample data, overriding any previous sample. 066 * 067 * Computes and caches QR decomposition of the X matrix. 068 * @param y the [n,1] array representing the y sample 069 * @param x the [n,k] array representing the x sample 070 * @throws IllegalArgumentException if the x and y array data are not 071 * compatible for the regression 072 */ 073 public void newSampleData(double[] y, double[][] x) { 074 validateSampleData(x, y); 075 newYSampleData(y); 076 newXSampleData(x); 077 } 078 079 /** 080 * {@inheritDoc} 081 * 082 * Computes and caches QR decomposition of the X matrix 083 */ 084 @Override 085 public void newSampleData(double[] data, int nobs, int nvars) { 086 super.newSampleData(data, nobs, nvars); 087 qr = new QRDecompositionImpl(X); 088 } 089 090 /** 091 * <p>Compute the "hat" matrix. 092 * </p> 093 * <p>The hat matrix is defined in terms of the design matrix X 094 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup> 095 * </p> 096 * <p>The implementation here uses the QR decomposition to compute the 097 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the 098 * p-dimensional identity matrix augmented by 0's. This computational 099 * formula is from "The Hat Matrix in Regression and ANOVA", 100 * David C. Hoaglin and Roy E. Welsch, 101 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. 102 * 103 * @return the hat matrix 104 */ 105 public RealMatrix calculateHat() { 106 // Create augmented identity matrix 107 RealMatrix Q = qr.getQ(); 108 final int p = qr.getR().getColumnDimension(); 109 final int n = Q.getColumnDimension(); 110 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n); 111 double[][] augIData = augI.getDataRef(); 112 for (int i = 0; i < n; i++) { 113 for (int j =0; j < n; j++) { 114 if (i == j && i < p) { 115 augIData[i][j] = 1d; 116 } else { 117 augIData[i][j] = 0d; 118 } 119 } 120 } 121 122 // Compute and return Hat matrix 123 return Q.multiply(augI).multiply(Q.transpose()); 124 } 125 126 /** 127 * Loads new x sample data, overriding any previous sample 128 * 129 * @param x the [n,k] array representing the x sample 130 */ 131 @Override 132 protected void newXSampleData(double[][] x) { 133 this.X = new Array2DRowRealMatrix(x); 134 qr = new QRDecompositionImpl(X); 135 } 136 137 /** 138 * Calculates regression coefficients using OLS. 139 * 140 * @return beta 141 */ 142 @Override 143 protected RealVector calculateBeta() { 144 return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y)); 145 } 146 147 /** 148 * <p>Calculates the variance on the beta by OLS. 149 * </p> 150 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup> 151 * </p> 152 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup> 153 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of 154 * R included, where p = the length of the beta vector.</p> 155 * 156 * @return The beta variance 157 */ 158 @Override 159 protected RealMatrix calculateBetaVariance() { 160 int p = X.getColumnDimension(); 161 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1); 162 RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse(); 163 return Rinv.multiply(Rinv.transpose()); 164 } 165 166 167 /** 168 * <p>Calculates the variance on the Y by OLS. 169 * </p> 170 * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k) 171 * </p> 172 * @return The Y variance 173 */ 174 @Override 175 protected double calculateYVariance() { 176 RealVector residuals = calculateResiduals(); 177 return residuals.dotProduct(residuals) / 178 (X.getRowDimension() - X.getColumnDimension()); 179 } 180 181 /** TODO: Find a home for the following methods in the linear package */ 182 183 /** 184 * <p>Uses back substitution to solve the system</p> 185 * 186 * <p>coefficients X = constants</p> 187 * 188 * <p>coefficients must upper-triangular and constants must be a column 189 * matrix. The solution is returned as a column matrix.</p> 190 * 191 * <p>The number of columns in coefficients determines the length 192 * of the returned solution vector (column matrix). If constants 193 * has more rows than coefficients has columns, excess rows are ignored. 194 * Similarly, extra (zero) rows in coefficients are ignored</p> 195 * 196 * @param coefficients upper-triangular coefficients matrix 197 * @param constants column RHS constants vector 198 * @return solution matrix as a column vector 199 * 200 */ 201 private static RealVector solveUpperTriangular(RealMatrix coefficients, 202 RealVector constants) { 203 checkUpperTriangular(coefficients, 1E-12); 204 int length = coefficients.getColumnDimension(); 205 double x[] = new double[length]; 206 for (int i = 0; i < length; i++) { 207 int index = length - 1 - i; 208 double sum = 0; 209 for (int j = index + 1; j < length; j++) { 210 sum += coefficients.getEntry(index, j) * x[j]; 211 } 212 x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index); 213 } 214 return new ArrayRealVector(x); 215 } 216 217 /** 218 * <p>Check if a matrix is upper-triangular.</p> 219 * 220 * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p> 221 * 222 * @param m matrix to check 223 * @param epsilon maximum allowable absolute value for elements below 224 * the main diagonal 225 * 226 * @throws IllegalArgumentException if m is not upper-triangular 227 */ 228 private static void checkUpperTriangular(RealMatrix m, double epsilon) { 229 int nCols = m.getColumnDimension(); 230 int nRows = m.getRowDimension(); 231 for (int r = 0; r < nRows; r++) { 232 int bound = Math.min(r, nCols); 233 for (int c = 0; c < bound; c++) { 234 if (Math.abs(m.getEntry(r, c)) > epsilon) { 235 throw MathRuntimeException.createIllegalArgumentException( 236 "matrix is not upper-triangular, entry ({0}, {1}) = {2} is too large", 237 r, c, m.getEntry(r, c)); 238 } 239 } 240 } 241 } 242 }