1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.linear;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math.linear.InvalidMatrixException;
23  import org.apache.commons.math.linear.MatrixUtils;
24  import org.apache.commons.math.linear.RealMatrix;
25  import org.apache.commons.math.linear.TriDiagonalTransformer;
26  
27  import junit.framework.Test;
28  import junit.framework.TestCase;
29  import junit.framework.TestSuite;
30  
31  public class TriDiagonalTransformerTest extends TestCase {
32  
33      private double[][] testSquare5 = {
34              { 1, 2, 3, 1, 1 },
35              { 2, 1, 1, 3, 1 },
36              { 3, 1, 1, 1, 2 },
37              { 1, 3, 1, 2, 1 },
38              { 1, 1, 2, 1, 3 }
39      };
40  
41      private double[][] testSquare3 = {
42              { 1, 3, 4 },
43              { 3, 2, 2 },
44              { 4, 2, 0 }
45      };
46  
47      public TriDiagonalTransformerTest(String name) {
48          super(name);
49      }
50  
51      public void testNonSquare() {
52          try {
53              new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
54              fail("an exception should have been thrown");
55          } catch (InvalidMatrixException ime) {
56              // expected behavior
57          } catch (Exception e) {
58              fail("wrong exception caught");
59          }
60      }
61  
62      public void testAEqualQTQt() {
63          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
64          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
65      }
66  
67      private void checkAEqualQTQt(RealMatrix matrix) {
68          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
69          RealMatrix q  = transformer.getQ();
70          RealMatrix qT = transformer.getQT();
71          RealMatrix t  = transformer.getT();
72          double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm();
73          assertEquals(0, norm, 4.0e-15);
74      }
75  
76      public void testNoAccessBelowDiagonal() {
77          checkNoAccessBelowDiagonal(testSquare5);
78          checkNoAccessBelowDiagonal(testSquare3);
79      }
80  
81      private void checkNoAccessBelowDiagonal(double[][] data) {
82          double[][] modifiedData = new double[data.length][];
83          for (int i = 0; i < data.length; ++i) {
84              modifiedData[i] = data[i].clone();
85              Arrays.fill(modifiedData[i], 0, i, Double.NaN);
86          }
87          RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
88          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
89          RealMatrix q  = transformer.getQ();
90          RealMatrix qT = transformer.getQT();
91          RealMatrix t  = transformer.getT();
92          double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm();
93          assertEquals(0, norm, 4.0e-15);
94      }
95  
96      public void testQOrthogonal() {
97          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
98          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
99      }
100 
101     public void testQTOrthogonal() {
102         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
103         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
104     }
105 
106     private void checkOrthogonal(RealMatrix m) {
107         RealMatrix mTm = m.transpose().multiply(m);
108         RealMatrix id  = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
109         assertEquals(0, mTm.subtract(id).getNorm(), 1.0e-15);        
110     }
111 
112     public void testTTriDiagonal() {
113         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
114         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
115     }
116 
117     private void checkTriDiagonal(RealMatrix m) {
118         final int rows = m.getRowDimension();
119         final int cols = m.getColumnDimension();
120         for (int i = 0; i < rows; ++i) {
121             for (int j = 0; j < cols; ++j) {
122                 if ((i < j - 1) || (i > j + 1)) {
123                     assertEquals(0, m.getEntry(i, j), 1.0e-16);
124                 }                    
125             }
126         }
127     }
128 
129     public void testMatricesValues5() {
130         checkMatricesValues(testSquare5,
131                             new double[][] {
132                                 { 1.0,  0.0,                 0.0,                  0.0,                   0.0 },
133                                 { 0.0, -0.5163977794943222,  0.016748280772542083, 0.839800693771262,     0.16669620021405473 },
134                                 { 0.0, -0.7745966692414833, -0.4354553000860955,  -0.44989322880603355,  -0.08930153582895772 },
135                                 { 0.0, -0.2581988897471611,  0.6364346693566014,  -0.30263204032131164,   0.6608313651342882 },
136                                 { 0.0, -0.2581988897471611,  0.6364346693566009,  -0.027289660803112598, -0.7263191580755246 }
137                             },
138                             new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
139                             new double[] { -Math.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
140     }
141 
142     public void testMatricesValues3() {
143         checkMatricesValues(testSquare3,
144                             new double[][] {
145                                 {  1.0,  0.0,  0.0 },
146                                 {  0.0, -0.6,  0.8 },
147                                 {  0.0, -0.8, -0.6 },
148                             },
149                             new double[] { 1, 2.64, -0.64 },
150                             new double[] { -5, -1.52 });
151     }
152 
153     private void checkMatricesValues(double[][] matrix, double[][] qRef,
154                                      double[] mainDiagnonal,
155                                      double[] secondaryDiagonal) {
156         TriDiagonalTransformer transformer =
157             new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
158 
159         // check values against known references
160         RealMatrix q = transformer.getQ();
161         assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm(), 1.0e-14);
162 
163         RealMatrix t = transformer.getT();
164         double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
165         for (int i = 0; i < mainDiagnonal.length; ++i) {
166             tData[i][i] = mainDiagnonal[i];
167             if (i > 0) {
168                 tData[i][i - 1] = secondaryDiagonal[i - 1];
169             }
170             if (i < secondaryDiagonal.length) {
171                 tData[i][i + 1] = secondaryDiagonal[i];
172             }
173         }
174         assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm(), 1.0e-14);
175 
176         // check the same cached instance is returned the second time
177         assertTrue(q == transformer.getQ());
178         assertTrue(t == transformer.getT());
179         
180     }
181 
182     public static Test suite() {
183         return new TestSuite(TriDiagonalTransformerTest.class);
184     }
185 
186 }