001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math.stat.descriptive; 018 019 020 import java.util.Locale; 021 022 import junit.framework.Test; 023 import junit.framework.TestCase; 024 import junit.framework.TestSuite; 025 026 import org.apache.commons.math.DimensionMismatchException; 027 import org.apache.commons.math.TestUtils; 028 import org.apache.commons.math.stat.descriptive.moment.Mean; 029 030 /** 031 * Test cases for the {@link MultivariateSummaryStatistics} class. 032 * 033 * @version $Revision: 797744 $ $Date: 2009-07-25 07:09:14 -0400 (Sat, 25 Jul 2009) $ 034 */ 035 036 public class MultivariateSummaryStatisticsTest extends TestCase { 037 038 public MultivariateSummaryStatisticsTest(String name) { 039 super(name); 040 } 041 042 public static Test suite() { 043 TestSuite suite = new TestSuite(MultivariateSummaryStatisticsTest.class); 044 suite.setName("MultivariateSummaryStatistics tests"); 045 return suite; 046 } 047 048 protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) { 049 return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected); 050 } 051 052 public void testSetterInjection() throws Exception { 053 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); 054 u.setMeanImpl(new StorelessUnivariateStatistic[] { 055 new sumMean(), new sumMean() 056 }); 057 u.addValue(new double[] { 1, 2 }); 058 u.addValue(new double[] { 3, 4 }); 059 assertEquals(4, u.getMean()[0], 1E-14); 060 assertEquals(6, u.getMean()[1], 1E-14); 061 u.clear(); 062 u.addValue(new double[] { 1, 2 }); 063 u.addValue(new double[] { 3, 4 }); 064 assertEquals(4, u.getMean()[0], 1E-14); 065 assertEquals(6, u.getMean()[1], 1E-14); 066 u.clear(); 067 u.setMeanImpl(new StorelessUnivariateStatistic[] { 068 new Mean(), new Mean() 069 }); // OK after clear 070 u.addValue(new double[] { 1, 2 }); 071 u.addValue(new double[] { 3, 4 }); 072 assertEquals(2, u.getMean()[0], 1E-14); 073 assertEquals(3, u.getMean()[1], 1E-14); 074 assertEquals(2, u.getDimension()); 075 } 076 077 public void testSetterIllegalState() throws Exception { 078 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); 079 u.addValue(new double[] { 1, 2 }); 080 u.addValue(new double[] { 3, 4 }); 081 try { 082 u.setMeanImpl(new StorelessUnivariateStatistic[] { 083 new sumMean(), new sumMean() 084 }); 085 fail("Expecting IllegalStateException"); 086 } catch (IllegalStateException ex) { 087 // expected 088 } 089 } 090 091 public void testToString() throws DimensionMismatchException { 092 MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true); 093 stats.addValue(new double[] {1, 3}); 094 stats.addValue(new double[] {2, 2}); 095 stats.addValue(new double[] {3, 1}); 096 Locale d = Locale.getDefault(); 097 Locale.setDefault(Locale.US); 098 assertEquals("MultivariateSummaryStatistics:\n" + 099 "n: 3\n" + 100 "min: 1.0, 1.0\n" + 101 "max: 3.0, 3.0\n" + 102 "mean: 2.0, 2.0\n" + 103 "geometric mean: 1.817..., 1.817...\n" + 104 "sum of squares: 14.0, 14.0\n" + 105 "sum of logarithms: 1.791..., 1.791...\n" + 106 "standard deviation: 1.0, 1.0\n" + 107 "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}\n", 108 stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1...")); 109 Locale.setDefault(d); 110 } 111 112 public void testShuffledStatistics() throws DimensionMismatchException { 113 // the purpose of this test is only to check the get/set methods 114 // we are aware shuffling statistics like this is really not 115 // something sensible to do in production ... 116 MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true); 117 MultivariateSummaryStatistics shuffled = createMultivariateSummaryStatistics(2, true); 118 119 StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl(); 120 shuffled.setGeoMeanImpl(shuffled.getMeanImpl()); 121 shuffled.setMeanImpl(shuffled.getMaxImpl()); 122 shuffled.setMaxImpl(shuffled.getMinImpl()); 123 shuffled.setMinImpl(shuffled.getSumImpl()); 124 shuffled.setSumImpl(shuffled.getSumsqImpl()); 125 shuffled.setSumsqImpl(shuffled.getSumLogImpl()); 126 shuffled.setSumLogImpl(tmp); 127 128 for (int i = 100; i > 0; --i) { 129 reference.addValue(new double[] {i, i}); 130 shuffled.addValue(new double[] {i, i}); 131 } 132 133 TestUtils.assertEquals(reference.getMean(), shuffled.getGeometricMean(), 1.0e-10); 134 TestUtils.assertEquals(reference.getMax(), shuffled.getMean(), 1.0e-10); 135 TestUtils.assertEquals(reference.getMin(), shuffled.getMax(), 1.0e-10); 136 TestUtils.assertEquals(reference.getSum(), shuffled.getMin(), 1.0e-10); 137 TestUtils.assertEquals(reference.getSumSq(), shuffled.getSum(), 1.0e-10); 138 TestUtils.assertEquals(reference.getSumLog(), shuffled.getSumSq(), 1.0e-10); 139 TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(), 1.0e-10); 140 141 } 142 143 /** 144 * Bogus mean implementation to test setter injection. 145 * Returns the sum instead of the mean. 146 */ 147 static class sumMean implements StorelessUnivariateStatistic { 148 private double sum = 0; 149 private long n = 0; 150 public double evaluate(double[] values, int begin, int length) { 151 return 0; 152 } 153 public double evaluate(double[] values) { 154 return 0; 155 } 156 public void clear() { 157 sum = 0; 158 n = 0; 159 } 160 public long getN() { 161 return n; 162 } 163 public double getResult() { 164 return sum; 165 } 166 public void increment(double d) { 167 sum += d; 168 n++; 169 } 170 public void incrementAll(double[] values, int start, int length) { 171 } 172 public void incrementAll(double[] values) { 173 } 174 public StorelessUnivariateStatistic copy() { 175 return new sumMean(); 176 } 177 } 178 179 public void testDimension() { 180 try { 181 createMultivariateSummaryStatistics(2, true).addValue(new double[3]); 182 } catch (DimensionMismatchException dme) { 183 // expected behavior 184 } catch (Exception e) { 185 fail("wrong exception caught"); 186 } 187 } 188 189 /** test stats */ 190 public void testStats() throws DimensionMismatchException { 191 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); 192 assertEquals(0, u.getN()); 193 u.addValue(new double[] { 1, 2 }); 194 u.addValue(new double[] { 2, 3 }); 195 u.addValue(new double[] { 2, 3 }); 196 u.addValue(new double[] { 3, 4 }); 197 assertEquals( 4, u.getN()); 198 assertEquals( 8, u.getSum()[0], 1.0e-10); 199 assertEquals(12, u.getSum()[1], 1.0e-10); 200 assertEquals(18, u.getSumSq()[0], 1.0e-10); 201 assertEquals(38, u.getSumSq()[1], 1.0e-10); 202 assertEquals( 1, u.getMin()[0], 1.0e-10); 203 assertEquals( 2, u.getMin()[1], 1.0e-10); 204 assertEquals( 3, u.getMax()[0], 1.0e-10); 205 assertEquals( 4, u.getMax()[1], 1.0e-10); 206 assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10); 207 assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10); 208 assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10); 209 assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10); 210 assertEquals( 2, u.getMean()[0], 1.0e-10); 211 assertEquals( 3, u.getMean()[1], 1.0e-10); 212 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10); 213 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10); 214 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10); 215 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10); 216 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10); 217 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10); 218 u.clear(); 219 assertEquals(0, u.getN()); 220 } 221 222 public void testN0andN1Conditions() throws Exception { 223 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true); 224 assertTrue(Double.isNaN(u.getMean()[0])); 225 assertTrue(Double.isNaN(u.getStandardDeviation()[0])); 226 227 /* n=1 */ 228 u.addValue(new double[] { 1 }); 229 assertEquals(1.0, u.getMean()[0], 1.0e-10); 230 assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10); 231 assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10); 232 233 /* n=2 */ 234 u.addValue(new double[] { 2 }); 235 assertTrue(u.getStandardDeviation()[0] > 0); 236 237 } 238 239 public void testNaNContracts() throws DimensionMismatchException { 240 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true); 241 assertTrue(Double.isNaN(u.getMean()[0])); 242 assertTrue(Double.isNaN(u.getMin()[0])); 243 assertTrue(Double.isNaN(u.getStandardDeviation()[0])); 244 assertTrue(Double.isNaN(u.getGeometricMean()[0])); 245 246 u.addValue(new double[] { 1.0 }); 247 assertFalse(Double.isNaN(u.getMean()[0])); 248 assertFalse(Double.isNaN(u.getMin()[0])); 249 assertFalse(Double.isNaN(u.getStandardDeviation()[0])); 250 assertFalse(Double.isNaN(u.getGeometricMean()[0])); 251 252 } 253 254 public void testSerialization() throws DimensionMismatchException { 255 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); 256 // Empty test 257 TestUtils.checkSerializedEquality(u); 258 MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u); 259 assertEquals(u, s); 260 261 // Add some data 262 u.addValue(new double[] { 2d, 1d }); 263 u.addValue(new double[] { 1d, 1d }); 264 u.addValue(new double[] { 3d, 1d }); 265 u.addValue(new double[] { 4d, 1d }); 266 u.addValue(new double[] { 5d, 1d }); 267 268 // Test again 269 TestUtils.checkSerializedEquality(u); 270 s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u); 271 assertEquals(u, s); 272 273 } 274 275 public void testEqualsAndHashCode() throws DimensionMismatchException { 276 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true); 277 MultivariateSummaryStatistics t = null; 278 int emptyHash = u.hashCode(); 279 assertTrue(u.equals(u)); 280 assertFalse(u.equals(t)); 281 assertFalse(u.equals(Double.valueOf(0))); 282 t = createMultivariateSummaryStatistics(2, true); 283 assertTrue(t.equals(u)); 284 assertTrue(u.equals(t)); 285 assertEquals(emptyHash, t.hashCode()); 286 287 // Add some data to u 288 u.addValue(new double[] { 2d, 1d }); 289 u.addValue(new double[] { 1d, 1d }); 290 u.addValue(new double[] { 3d, 1d }); 291 u.addValue(new double[] { 4d, 1d }); 292 u.addValue(new double[] { 5d, 1d }); 293 assertFalse(t.equals(u)); 294 assertFalse(u.equals(t)); 295 assertTrue(u.hashCode() != t.hashCode()); 296 297 //Add data in same order to t 298 t.addValue(new double[] { 2d, 1d }); 299 t.addValue(new double[] { 1d, 1d }); 300 t.addValue(new double[] { 3d, 1d }); 301 t.addValue(new double[] { 4d, 1d }); 302 t.addValue(new double[] { 5d, 1d }); 303 assertTrue(t.equals(u)); 304 assertTrue(u.equals(t)); 305 assertEquals(u.hashCode(), t.hashCode()); 306 307 // Clear and make sure summaries are indistinguishable from empty summary 308 u.clear(); 309 t.clear(); 310 assertTrue(t.equals(u)); 311 assertTrue(u.equals(t)); 312 assertEquals(emptyHash, t.hashCode()); 313 assertEquals(emptyHash, u.hashCode()); 314 } 315 316 }