001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math.stat.regression; 018 019 import static org.junit.Assert.assertEquals; 020 021 import org.apache.commons.math.TestUtils; 022 import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor; 023 import org.apache.commons.math.linear.MatrixUtils; 024 import org.apache.commons.math.linear.MatrixVisitorException; 025 import org.apache.commons.math.linear.RealMatrix; 026 import org.apache.commons.math.linear.Array2DRowRealMatrix; 027 import org.junit.Before; 028 import org.junit.Test; 029 030 public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest { 031 032 private double[] y; 033 private double[][] x; 034 035 @Before 036 @Override 037 public void setUp(){ 038 y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0}; 039 x = new double[6][]; 040 x[0] = new double[]{1.0, 0, 0, 0, 0, 0}; 041 x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0}; 042 x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0}; 043 x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0}; 044 x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0}; 045 x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0}; 046 super.setUp(); 047 } 048 049 @Override 050 protected OLSMultipleLinearRegression createRegression() { 051 OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 052 regression.newSampleData(y, x); 053 return regression; 054 } 055 056 @Override 057 protected int getNumberOfRegressors() { 058 return x[0].length; 059 } 060 061 @Override 062 protected int getSampleSize() { 063 return y.length; 064 } 065 066 @Test(expected=IllegalArgumentException.class) 067 public void cannotAddXSampleData() { 068 createRegression().newSampleData(new double[]{}, null); 069 } 070 071 @Test(expected=IllegalArgumentException.class) 072 public void cannotAddNullYSampleData() { 073 createRegression().newSampleData(null, new double[][]{}); 074 } 075 076 @Test(expected=IllegalArgumentException.class) 077 public void cannotAddSampleDataWithSizeMismatch() { 078 double[] y = new double[]{1.0, 2.0}; 079 double[][] x = new double[1][]; 080 x[0] = new double[]{1.0, 0}; 081 createRegression().newSampleData(y, x); 082 } 083 084 @Test 085 public void testPerfectFit() { 086 double[] betaHat = regression.estimateRegressionParameters(); 087 TestUtils.assertEquals(betaHat, 088 new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 }, 089 1e-14); 090 double[] residuals = regression.estimateResiduals(); 091 TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d}, 092 1e-14); 093 RealMatrix errors = 094 new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false); 095 final double[] s = { 1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0 }; 096 RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length); 097 referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { 098 @Override 099 public double visit(int row, int column, double value) 100 throws MatrixVisitorException { 101 if (row == 0) { 102 return s[column]; 103 } 104 double x = s[row] * s[column]; 105 return (row == column) ? 2 * x : x; 106 } 107 }); 108 assertEquals(0.0, 109 errors.subtract(referenceVariance).getNorm(), 110 5.0e-16 * referenceVariance.getNorm()); 111 } 112 113 114 /** 115 * Test Longley dataset against certified values provided by NIST. 116 * Data Source: J. Longley (1967) "An Appraisal of Least Squares 117 * Programs for the Electronic Computer from the Point of View of the User" 118 * Journal of the American Statistical Association, vol. 62. September, 119 * pp. 819-841. 120 * 121 * Certified values (and data) are from NIST: 122 * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat 123 */ 124 @Test 125 public void testLongly() { 126 // Y values are first, then independent vars 127 // Each row is one observation 128 double[] design = new double[] { 129 60323,83.0,234289,2356,1590,107608,1947, 130 61122,88.5,259426,2325,1456,108632,1948, 131 60171,88.2,258054,3682,1616,109773,1949, 132 61187,89.5,284599,3351,1650,110929,1950, 133 63221,96.2,328975,2099,3099,112075,1951, 134 63639,98.1,346999,1932,3594,113270,1952, 135 64989,99.0,365385,1870,3547,115094,1953, 136 63761,100.0,363112,3578,3350,116219,1954, 137 66019,101.2,397469,2904,3048,117388,1955, 138 67857,104.6,419180,2822,2857,118734,1956, 139 68169,108.4,442769,2936,2798,120445,1957, 140 66513,110.8,444546,4681,2637,121950,1958, 141 68655,112.6,482704,3813,2552,123366,1959, 142 69564,114.2,502601,3931,2514,125368,1960, 143 69331,115.7,518173,4806,2572,127852,1961, 144 70551,116.9,554894,4007,2827,130081,1962 145 }; 146 147 // Transform to Y and X required by interface 148 int nobs = 16; 149 int nvars = 6; 150 151 // Estimate the model 152 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); 153 model.newSampleData(design, nobs, nvars); 154 155 // Check expected beta values from NIST 156 double[] betaHat = model.estimateRegressionParameters(); 157 TestUtils.assertEquals(betaHat, 158 new double[]{-3482258.63459582, 15.0618722713733, 159 -0.358191792925910E-01,-2.02022980381683, 160 -1.03322686717359,-0.511041056535807E-01, 161 1829.15146461355}, 2E-8); // 162 163 // Check expected residuals from R 164 double[] residuals = model.estimateResiduals(); 165 TestUtils.assertEquals(residuals, new double[]{ 166 267.340029759711,-94.0139423988359,46.28716775752924, 167 -410.114621930906,309.7145907602313,-249.3112153297231, 168 -164.0489563956039,-13.18035686637081,14.30477260005235, 169 455.394094551857,-17.26892711483297,-39.0550425226967, 170 -155.5499735953195,-85.6713080421283,341.9315139607727, 171 -206.7578251937366}, 172 1E-8); 173 174 // Check standard errors from NIST 175 double[] errors = model.estimateRegressionParametersStandardErrors(); 176 TestUtils.assertEquals(new double[] {890420.383607373, 177 84.9149257747669, 178 0.334910077722432E-01, 179 0.488399681651699, 180 0.214274163161675, 181 0.226073200069370, 182 455.478499142212}, errors, 1E-6); 183 } 184 185 /** 186 * Test R Swiss fertility dataset against R. 187 * Data Source: R datasets package 188 */ 189 @Test 190 public void testSwissFertility() { 191 double[] design = new double[] { 192 80.2,17.0,15,12,9.96, 193 83.1,45.1,6,9,84.84, 194 92.5,39.7,5,5,93.40, 195 85.8,36.5,12,7,33.77, 196 76.9,43.5,17,15,5.16, 197 76.1,35.3,9,7,90.57, 198 83.8,70.2,16,7,92.85, 199 92.4,67.8,14,8,97.16, 200 82.4,53.3,12,7,97.67, 201 82.9,45.2,16,13,91.38, 202 87.1,64.5,14,6,98.61, 203 64.1,62.0,21,12,8.52, 204 66.9,67.5,14,7,2.27, 205 68.9,60.7,19,12,4.43, 206 61.7,69.3,22,5,2.82, 207 68.3,72.6,18,2,24.20, 208 71.7,34.0,17,8,3.30, 209 55.7,19.4,26,28,12.11, 210 54.3,15.2,31,20,2.15, 211 65.1,73.0,19,9,2.84, 212 65.5,59.8,22,10,5.23, 213 65.0,55.1,14,3,4.52, 214 56.6,50.9,22,12,15.14, 215 57.4,54.1,20,6,4.20, 216 72.5,71.2,12,1,2.40, 217 74.2,58.1,14,8,5.23, 218 72.0,63.5,6,3,2.56, 219 60.5,60.8,16,10,7.72, 220 58.3,26.8,25,19,18.46, 221 65.4,49.5,15,8,6.10, 222 75.5,85.9,3,2,99.71, 223 69.3,84.9,7,6,99.68, 224 77.3,89.7,5,2,100.00, 225 70.5,78.2,12,6,98.96, 226 79.4,64.9,7,3,98.22, 227 65.0,75.9,9,9,99.06, 228 92.2,84.6,3,3,99.46, 229 79.3,63.1,13,13,96.83, 230 70.4,38.4,26,12,5.62, 231 65.7,7.7,29,11,13.79, 232 72.7,16.7,22,13,11.22, 233 64.4,17.6,35,32,16.92, 234 77.6,37.6,15,7,4.97, 235 67.6,18.7,25,7,8.65, 236 35.0,1.2,37,53,42.34, 237 44.7,46.6,16,29,50.43, 238 42.8,27.7,22,29,58.33 239 }; 240 241 // Transform to Y and X required by interface 242 int nobs = 47; 243 int nvars = 4; 244 245 // Estimate the model 246 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); 247 model.newSampleData(design, nobs, nvars); 248 249 // Check expected beta values from R 250 double[] betaHat = model.estimateRegressionParameters(); 251 TestUtils.assertEquals(betaHat, 252 new double[]{91.05542390271397, 253 -0.22064551045715, 254 -0.26058239824328, 255 -0.96161238456030, 256 0.12441843147162}, 1E-12); 257 258 // Check expected residuals from R 259 double[] residuals = model.estimateResiduals(); 260 TestUtils.assertEquals(residuals, new double[]{ 261 7.1044267859730512,1.6580347433531366, 262 4.6944952770029644,8.4548022690166160,13.6547432343186212, 263 -9.3586864458500774,7.5822446330520386,15.5568995563859289, 264 0.8113090736598980,7.1186762732484308,7.4251378771228724, 265 2.6761316873234109,0.8351584810309354,7.1769991119615177, 266 -3.8746753206299553,-3.1337779476387251,-0.1412575244091504, 267 1.1186809170469780,-6.3588097346816594,3.4039270429434074, 268 2.3374058329820175,-7.9272368576900503,-7.8361010968497959, 269 -11.2597369269357070,0.9445333697827101,6.6544245101380328, 270 -0.9146136301118665,-4.3152449403848570,-4.3536932047009183, 271 -3.8907885169304661,-6.3027643926302188,-7.8308982189289091, 272 -3.1792280015332750,-6.7167298771158226,-4.8469946718041754, 273 -10.6335664353633685,11.1031134362036958,6.0084032641811733, 274 5.4326230830188482,-7.2375578629692230,2.1671550814448222, 275 15.0147574652763112,4.8625103516321015,-7.1597256413907706, 276 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063}, 277 1E-12); 278 279 // Check standard errors from R 280 double[] errors = model.estimateRegressionParametersStandardErrors(); 281 TestUtils.assertEquals(new double[] {6.94881329475087, 282 0.07360008972340, 283 0.27410957467466, 284 0.19454551679325, 285 0.03726654773803}, errors, 1E-10); 286 } 287 288 /** 289 * Test hat matrix computation 290 * 291 * @throws Exception 292 */ 293 @Test 294 public void testHat() throws Exception { 295 296 /* 297 * This example is from "The Hat Matrix in Regression and ANOVA", 298 * David C. Hoaglin and Roy E. Welsch, 299 * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. 300 * 301 */ 302 double[] design = new double[] { 303 11.14, .499, 11.1, 304 12.74, .558, 8.9, 305 13.13, .604, 8.8, 306 11.51, .441, 8.9, 307 12.38, .550, 8.8, 308 12.60, .528, 9.9, 309 11.13, .418, 10.7, 310 11.7, .480, 10.5, 311 11.02, .406, 10.5, 312 11.41, .467, 10.7 313 }; 314 315 int nobs = 10; 316 int nvars = 2; 317 318 // Estimate the model 319 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); 320 model.newSampleData(design, nobs, nvars); 321 322 RealMatrix hat = model.calculateHat(); 323 324 // Reference data is upper half of symmetric hat matrix 325 double[] referenceData = new double[] { 326 .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242, 327 .242, .292, .136, .243, .128, -.041, .033, -.035, .004, 328 .417, -.019, .273, .187, -.126, .044, -.153, .004, 329 .604, .197, -.038, .168, -.022, .275, -.028, 330 .252, .111, -.030, .019, -.010, -.010, 331 .148, .042, .117, .012, .111, 332 .262, .145, .277, .174, 333 .154, .120, .168, 334 .315, .148, 335 .187 336 }; 337 338 // Check against reference data and verify symmetry 339 int k = 0; 340 for (int i = 0; i < 10; i++) { 341 for (int j = i; j < 10; j++) { 342 assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3); 343 assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12); 344 k++; 345 } 346 } 347 348 /* 349 * Verify that residuals computed using the hat matrix are close to 350 * what we get from direct computation, i.e. r = (I - H) y 351 */ 352 double[] residuals = model.estimateResiduals(); 353 RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); 354 double[] hatResiduals = I.subtract(hat).operate(model.Y).getData(); 355 TestUtils.assertEquals(residuals, hatResiduals, 10e-12); 356 } 357 }