View Javadoc

1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *  
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *  
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License. 
18   *  
19   */
20  package org.apache.directory.server.kerberos.shared.crypto.encryption;
21  
22  
23  /**
24   * An implementation of the n-fold algorithm, as required by RFC 3961,
25   * "Encryption and Checksum Specifications for Kerberos 5."
26   * 
27   * "To n-fold a number X, replicate the input value to a length that
28   * is the least common multiple of n and the length of X.  Before
29   * each repetition, the input is rotated to the right by 13 bit
30   * positions.  The successive n-bit chunks are added together using
31   * 1's-complement addition (that is, with end-around carry) to yield
32   * a n-bit result."
33   * 
34   * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a>
35   * @version $Rev$, $Date$
36   */
public class NFold
{
    /**
     * N-fold the given data to an output of <code>n</code> bits, as defined in
     * RFC 3961, section 5.1.
     * <p>
     * The input is replicated, each copy rotated a further 13 bits to the right,
     * until its total length is the least common multiple of <code>n</code> and
     * the input bit length. The successive <code>n</code>-bit chunks of that
     * replicated string are then added together using one's-complement
     * (end-around carry) addition.
     *
     * @param n The length of the result, in bits.  Must be a positive multiple of 8.
     * @param data The data to n-fold.  Must not be empty.
     * @return The n-folded data, <code>n / 8</code> bytes long.
     * @throws IllegalArgumentException if <code>n</code> is not a positive multiple
     *         of 8, or if <code>data</code> is empty.
     */
    public static byte[] nFold( int n, byte[] data )
    {
        if ( n <= 0 || n % 8 != 0 )
        {
            throw new IllegalArgumentException( "The n-fold output length must be a positive multiple of 8 bits: "
                + n );
        }

        if ( data.length == 0 )
        {
            throw new IllegalArgumentException( "Cannot n-fold empty data" );
        }

        int inputBits = data.length * 8;
        int lcm = getLcm( n, inputBits );
        int copies = lcm / inputBits;

        // Build the replicated input: copy i is the input rotated right by 13 * i bits.
        // The rotation is reduced modulo the input length up front so it cannot overflow.
        byte[] replicated = new byte[lcm / 8];

        for ( int i = 0; i < copies; i++ )
        {
            int rotation = ( int ) ( ( 13L * i ) % inputBits );
            byte[] rotated = rotateRight( data, inputBits, rotation );
            System.arraycopy( rotated, 0, replicated, i * data.length, data.length );
        }

        // Add the successive n-bit chunks together with end-around carry.
        int outBytes = n / 8;
        byte[] chunk = new byte[outBytes];
        byte[] result = new byte[outBytes];

        for ( int offset = 0; offset < replicated.length; offset += outBytes )
        {
            System.arraycopy( replicated, offset, chunk, 0, outBytes );
            result = sum( result, chunk, n );
        }

        return result;
    }


    /**
     * For 2 non-negative numbers, return the least-common multiple.
     * <p>
     * The result is computed as <code>(n1 / gcd) * n2</code> rather than
     * <code>(n1 * n2) / gcd</code>, so that the intermediate product cannot
     * overflow when the LCM itself fits in an <code>int</code>.
     *
     * @param n1 The first number.
     * @param n2 The second number.
     * @return The least-common multiple, or 0 if either number is 0.
     * @throws IllegalArgumentException if either number is negative.
     * @throws ArithmeticException if the least-common multiple does not fit in an <code>int</code>.
     */
    protected static int getLcm( int n1, int n2 )
    {
        if ( n1 < 0 || n2 < 0 )
        {
            throw new IllegalArgumentException( "LCM arguments must be non-negative: " + n1 + ", " + n2 );
        }

        if ( n1 == 0 || n2 == 0 )
        {
            return 0;
        }

        // Euclid's algorithm for the greatest common divisor.
        int a = n1;
        int b = n2;

        while ( b != 0 )
        {
            int remainder = a % b;
            a = b;
            b = remainder;
        }

        long lcm = ( long ) ( n1 / a ) * n2;

        if ( lcm > Integer.MAX_VALUE )
        {
            throw new ArithmeticException( "LCM of " + n1 + " and " + n2 + " overflows an int" );
        }

        return ( int ) lcm;
    }


    /**
     * Right-rotate the first <code>len</code> bits of the given byte array.
     * Bits are numbered from the most significant bit of byte 0, so rotating
     * right moves bit <code>i</code> to position <code>(i + step) % len</code>.
     *
     * @param in The byte array to right-rotate.
     * @param len The number of bits to rotate.
     * @param step The number of bit positions to rotate by (non-negative).
     * @return A new byte array holding the right-rotated bits.
     */
    private static byte[] rotateRight( byte[] in, int len, int step )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];
        int shift = step % len;

        for ( int i = 0; i < len; i++ )
        {
            setBit( out, ( i + shift ) % len, getBit( in, i ) );
        }

        return out;
    }


    /**
     * Perform one's complement addition (addition with end-around carry) of the
     * first <code>len</code> bits of two byte arrays.  Note that for purposes of
     * n-folding, we do not actually complement the result of the addition.
     * <p>
     * Bit 0 is the most significant bit of byte 0, and bit <code>len - 1</code>
     * is the least significant bit of the numbers being added.
     *
     * @param n1 The first number.
     * @param n2 The second number.
     * @param len The number of bits to sum.
     * @return The sum with end-around carry, <code>(len - 1) / 8 + 1</code> bytes long.
     */
    protected static byte[] sum( byte[] n1, byte[] n2, int len )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];
        int carry = 0;

        // Ripple-carry addition from the least significant bit up to the most significant.
        for ( int i = len - 1; i >= 0; i-- )
        {
            int bitSum = getBit( n1, i ) + getBit( n2, i ) + carry;
            setBit( out, i, bitSum & 1 );
            carry = bitSum >> 1;
        }

        // End-around carry: a carry out of the top bit is added back in at the
        // lowest bit (bit len - 1).  This second addition can never carry again,
        // because out is at most 2^len - 2 at this point.
        if ( carry == 1 )
        {
            byte[] one = new byte[numOfBytes];
            setBit( one, len - 1, 1 );
            out = sum( out, one, len );
        }

        return out;
    }


    /**
     * Get a bit from a byte array at a given position.  Position 0 is the most
     * significant bit of byte 0.
     *
     * @param data The data to get the bit from.
     * @param pos The position to get the bit at.
     * @return The value of the bit (0 or 1).
     */
    private static int getBit( byte[] data, int pos )
    {
        int posByte = pos / 8;
        int posBit = pos % 8;

        return ( data[posByte] >> ( 7 - posBit ) ) & 0x01;
    }


    /**
     * Set a bit in a byte array at a given position.  Position 0 is the most
     * significant bit of byte 0.
     *
     * @param data The data to set the bit in.
     * @param pos The position of the bit to set.
     * @param val The value to set the bit to: 0 clears the bit, any other value sets it.
     */
    private static void setBit( byte[] data, int pos, int val )
    {
        int posByte = pos / 8;
        int mask = 0x80 >>> ( pos % 8 );

        if ( val == 0 )
        {
            data[posByte] = ( byte ) ( data[posByte] & ~mask );
        }
        else
        {
            data[posByte] = ( byte ) ( data[posByte] | mask );
        }
    }
}