001    /*
002     * CDDL HEADER START
003     *
004     * The contents of this file are subject to the terms of the
005     * Common Development and Distribution License, Version 1.0 only
006     * (the "License").  You may not use this file except in compliance
007     * with the License.
008     *
009     * You can obtain a copy of the license at
010     * trunk/opends/resource/legal-notices/OpenDS.LICENSE
011     * or https://OpenDS.dev.java.net/OpenDS.LICENSE.
012     * See the License for the specific language governing permissions
013     * and limitations under the License.
014     *
015     * When distributing Covered Code, include this CDDL HEADER in each
016     * file and include the License file at
017     * trunk/opends/resource/legal-notices/OpenDS.LICENSE.  If applicable,
018     * add the following below this CDDL HEADER, with the fields enclosed
019     * by brackets "[]" replaced with your own identifying information:
020     *      Portions Copyright [yyyy] [name of copyright owner]
021     *
022     * CDDL HEADER END
023     *
024     *
025     *      Copyright 2008 Sun Microsystems, Inc.
026     */
027    
028    package org.opends.admin.ads.util;
029    
030    import java.io.IOException;
031    import java.net.Socket;
032    import java.net.InetAddress;
033    import java.util.Map;
034    import java.util.HashMap;
035    
036    import java.security.GeneralSecurityException;
037    
038    import javax.net.SocketFactory;
039    import javax.net.ssl.KeyManager;
040    import javax.net.ssl.SSLContext;
041    import javax.net.ssl.SSLSocketFactory;
042    import javax.net.ssl.SSLKeyException;
043    import javax.net.ssl.TrustManager;
044    
045    /**
046     * An implementation of SSLSocketFactory.
047     */
048    public class TrustedSocketFactory extends SSLSocketFactory
049    {
050      private static Map<Thread, TrustManager> hmTrustManager =
051        new HashMap<Thread, TrustManager>();
052      private static Map<Thread, KeyManager> hmKeyManager =
053        new HashMap<Thread, KeyManager>();
054    
055      private static Map<TrustManager, SocketFactory> hmDefaultFactoryTm =
056        new HashMap<TrustManager, SocketFactory>();
057      private static Map<KeyManager, SocketFactory> hmDefaultFactoryKm =
058        new HashMap<KeyManager, SocketFactory>();
059    
060      private SSLSocketFactory innerFactory;
061      private TrustManager trustManager;
062      private KeyManager   keyManager;
063    
064      /**
065       * Constructor of the TrustedSocketFactory.
066       * @param trustManager the trust manager to use.
067       * @param keyManager   the key manager to use.
068       */
069      public TrustedSocketFactory(TrustManager trustManager, KeyManager keyManager)
070      {
071        this.trustManager = trustManager;
072        this.keyManager   = keyManager;
073      }
074    
075      /**
076       * Sets the provided trust and key manager for the operations in the
077       * current thread.
078       *
079       * @param trustManager
080       *          the trust manager to use.
081       * @param keyManager
082       *          the key manager to use.
083       */
084      public static synchronized void setCurrentThreadTrustManager(
085          TrustManager trustManager, KeyManager keyManager)
086      {
087        setThreadTrustManager(trustManager, Thread.currentThread());
088        setThreadKeyManager  (keyManager, Thread.currentThread());
089      }
090    
091      /**
092       * Sets the provided trust manager for the operations in the provided thread.
093       * @param trustManager the trust manager to use.
094       * @param thread the thread where we want to use the provided trust manager.
095       */
096      public static synchronized void setThreadTrustManager(
097          TrustManager trustManager, Thread thread)
098      {
099        TrustManager currentTrustManager = hmTrustManager.get(thread);
100        if (currentTrustManager != null) {
101          hmDefaultFactoryTm.remove(currentTrustManager);
102          hmTrustManager.remove(thread);
103        }
104        if (trustManager != null) {
105          hmTrustManager.put(thread, trustManager);
106        }
107      }
108    
109      /**
110       * Sets the provided key manager for the operations in the provided thread.
111       * @param keyManager the key manager to use.
112       * @param thread the thread where we want to use the provided key manager.
113       */
114      public static synchronized void setThreadKeyManager(
115          KeyManager keyManager, Thread thread)
116      {
117        KeyManager currentKeyManager = hmKeyManager.get(thread);
118        if (currentKeyManager != null) {
119          hmDefaultFactoryKm.remove(currentKeyManager);
120          hmKeyManager.remove(thread);
121        }
122        if (keyManager != null) {
123          hmKeyManager.put(thread, keyManager);
124        }
125      }
126    
127      //
128      // SocketFactory implementation
129      //
130      /**
131       * Returns the default SSL socket factory. The default
132       * implementation can be changed by setting the value of the
133       * "ssl.SocketFactory.provider" security property (in the Java
134       * security properties file) to the desired class. If SSL has not
135       * been configured properly for this virtual machine, the factory
136       * will be inoperative (reporting instantiation exceptions).
137       *
138       * @return the default SocketFactory
139       */
140      public static synchronized SocketFactory getDefault()
141      {
142        Thread currentThread = Thread.currentThread();
143        TrustManager trustManager = hmTrustManager.get(currentThread);
144        KeyManager   keyManager   = hmKeyManager.get(currentThread);
145        SocketFactory result;
146    
147        if (trustManager == null)
148        {
149          if (keyManager == null)
150          {
151            result = new TrustedSocketFactory(null,null);
152          }
153          else
154          {
155            result = hmDefaultFactoryKm.get(keyManager);
156            if (result == null)
157            {
158              result = new TrustedSocketFactory(null,keyManager);
159              hmDefaultFactoryKm.put(keyManager, result);
160            }
161          }
162        }
163        else
164        {
165          if (keyManager == null)
166          {
167            result = hmDefaultFactoryTm.get(trustManager);
168            if (result == null)
169            {
170              result = new TrustedSocketFactory(trustManager, null);
171              hmDefaultFactoryTm.put(trustManager, result);
172            }
173          }
174          else
175          {
176            SocketFactory tmsf = hmDefaultFactoryTm.get(trustManager);
177            SocketFactory kmsf = hmDefaultFactoryKm.get(keyManager);
178            if ( tmsf == null || kmsf == null)
179            {
180              result = new TrustedSocketFactory(trustManager, keyManager);
181              hmDefaultFactoryTm.put(trustManager, result);
182              hmDefaultFactoryKm.put(keyManager, result);
183            }
184            else
185            if ( !tmsf.equals(kmsf) )
186            {
187              result = new TrustedSocketFactory(trustManager, keyManager);
188              hmDefaultFactoryTm.put(trustManager, result);
189              hmDefaultFactoryKm.put(keyManager, result);
190            }
191            else
192            {
193              result = tmsf ;
194            }
195          }
196        }
197    
198        return result;
199      }
200    
201      /**
202       * {@inheritDoc}
203       */
204      public Socket createSocket(InetAddress address, int port) throws IOException {
205        return getInnerFactory().createSocket(address, port);
206      }
207    
208      /**
209       * {@inheritDoc}
210       */
211      public Socket createSocket(InetAddress address, int port,
212          InetAddress clientAddress, int clientPort) throws IOException
213      {
214        return getInnerFactory().createSocket(address, port, clientAddress,
215            clientPort);
216      }
217    
218      /**
219       * {@inheritDoc}
220       */
221      public Socket createSocket(String host, int port) throws IOException
222      {
223        return getInnerFactory().createSocket(host, port);
224      }
225    
226      /**
227       * {@inheritDoc}
228       */
229      public Socket createSocket(String host, int port, InetAddress clientHost,
230          int clientPort) throws IOException
231      {
232        return getInnerFactory().createSocket(host, port, clientHost, clientPort);
233      }
234    
235      /**
236       * {@inheritDoc}
237       */
238      public Socket createSocket(Socket s, String host, int port, boolean autoClose)
239      throws IOException
240      {
241        return getInnerFactory().createSocket(s, host, port, autoClose);
242      }
243    
244      /**
245       * {@inheritDoc}
246       */
247      public String[] getDefaultCipherSuites()
248      {
249        try
250        {
251          return getInnerFactory().getDefaultCipherSuites();
252        }
253        catch(IOException x)
254        {
255          return new String[0];
256        }
257      }
258    
259      /**
260       * {@inheritDoc}
261       */
262      public String[] getSupportedCipherSuites()
263      {
264        try
265        {
266          return getInnerFactory().getSupportedCipherSuites();
267        }
268        catch(IOException x)
269        {
270          return new String[0];
271        }
272      }
273    
274    
275      //
276      // Private
277      //
278    
279      private SSLSocketFactory getInnerFactory() throws IOException {
280        if (innerFactory == null)
281        {
282          String algorithm = "TLSv1";
283          SSLKeyException xx;
284          KeyManager[] km = null;
285          TrustManager[] tm = null;
286    
287          try {
288            SSLContext sslCtx = SSLContext.getInstance(algorithm);
289            if (trustManager != null)
290            {
291              tm = new TrustManager[] { trustManager };
292            }
293            if (keyManager != null)
294            {
295              km = new KeyManager[] { keyManager };
296            }
297            sslCtx.init(km, tm, new java.security.SecureRandom() );
298            innerFactory = sslCtx.getSocketFactory();
299          }
300          catch(GeneralSecurityException x) {
301            xx = new SSLKeyException("Failed to create SSLContext for " +
302                algorithm);
303            xx.initCause(x);
304            throw xx;
305          }
306        }
307        return innerFactory;
308      }
309    }
310