package common; import java.io.*; import java.net.*; import java.rmi.server.*; import javax.net.ssl.*; import java.security.KeyStore; /** * Class RMISSLClientSocketFactory is similar in its purpose to the * default JDK javax.rmi.ssl.SslRMIClientSocketFactory but differs in * its ability to initilalize key sources in non-standard way. * * Mich like javax.rmi.ssl.SslRMIClientSocketFactory, and for the same * reason, RMISSLClientSocketFactory does not initilalize the * internals at constructor call time deferring that to createSocket * time: because the RMI client socket factory is created on the * server side, where that initialization is a priori meaningless, * unless both server and client run in the same JVM. We could * possibly override readObject() to force this initialization, but it * might not be a good idea to actually mix this with possible * deserialization problems. So contrarily to what we do for the * server side, the initialization of the SSLSocketFactory will be * delayed until the first time createSocket() is called - note that * the default SSLSocketFactory might already have been initialized * anyway if someone in the JVM already called * SSLSocketFactory.getDefault(). */ public class RMISSLClientSocketFactory implements RMIClientSocketFactory, Serializable { public static boolean TRACE = false; static void out (String msg) { System.out.println("RMISSL Client Socket Factory: " + msg); } // Note that after a client connects to a server, a factory matching // to what was obtained from the server will be initialized. The // value of isDefaultKM will match to what is on the server. For // instance, the client may initially use rmis_ssl_dflt but the // server may use the prog variant forcing the client to switch to // that, and if the credentials are on the path, this will continue // OK. boolean isDefaultKM; public RMISSLClientSocketFactory (boolean isDefaultKM) { this.isDefaultKM = isDefaultKM; if (TRACE) out("Creating RMISSLClientSocketFactory." + (isDefaultKM ? " with default KM" : "with custom KM")); } public RMISSLClientSocketFactory () { this(true); } public SSLSocketFactory initFactory() { if (TRACE) out("Enter initFactory. isDefaultKM = " + isDefaultKM); // if (TRACE) new Exception("initFactory's stack").printStackTrace(System.out); SSLSocketFactory factory = null; if (isDefaultKM) { // this method of obtaining factory and socket assumes the default // Key and Trust Store factory managers (see JSSE) specified as // -Djavax.net.ssl.keyStore=keystore // -Djavax.net.ssl.keyStorePassword=password // -Djavax.net.ssl.trustStore=truststore // -Djavax.net.ssl.trustStorePassword=trustword // when invoking this code if (TRACE) out("javax.net.ssl.keyStore: " + System.getProperty("javax.net.ssl.keyStore")); String keyStorePassword = System.getProperty("javax.net.ssl.keyStorePassword"); if (TRACE) out("javax.net.ssl.keyStorePassword: " + keyStorePassword); if (TRACE) out("javax.net.ssl.trustStore: " + System.getProperty("javax.net.ssl.trustStore")); String trustStorePassword = System.getProperty("javax.net.ssl.trustStorePassword"); if (TRACE) out("javax.net.ssl.trustStorePassword: " + trustStorePassword); if (TRACE) out("createSocket: Get default socket factory..."); factory = (SSLSocketFactory)SSLSocketFactory.getDefault(); if (TRACE) out("Done get SSLSocketFactory " + factory + " hash: " + factory.hashCode()); } else { try { SSLContext ctx; KeyStore keystore, truststore; // explicitly set up key manager to do client authentication and // explicitly set up trust manager to do server authentication KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); InputStream pwis = this.getClass().getResourceAsStream("/tmp/client_input"); if (TRACE) out("-- pwis: " + pwis); BufferedReader reader = new BufferedReader(new InputStreamReader(pwis)); String passphrase_str = reader.readLine(); if (TRACE) out("passphrase_str: " + passphrase_str); char[] passphrase = passphrase_str.toCharArray(); if (TRACE) out("passphrase is array of len: " + passphrase.length); reader.close(); if (TRACE) out("Create keystore ..."); keystore = KeyStore.getInstance("JKS"); if (TRACE) out("Load keystore ..."); String ksname = KeytoolAttrs.clientKeyStore(); InputStream ksis = this.getClass().getResourceAsStream(ksname); if (TRACE) out("*store: " + ksname + " stream: " + ksis); keystore.load(ksis, passphrase); keystore.load(new FileInputStream(KeytoolAttrs.clientKeyStore()), passphrase); if (TRACE) out("Done load keystore " + keystore); if (TRACE) out("Create truststore ..."); truststore = KeyStore.getInstance("JKS"); if (TRACE) out("Load truststore ..."); ksname = KeytoolAttrs.clientTrustStore(); // ksis = new FileInputStream(ksname); ksis = this.getClass().getResourceAsStream(ksname); if (TRACE) out("*store: " + ksname + " stream: " + ksis); truststore.load(ksis, passphrase); if (TRACE) out("Done load truststore " + truststore); kmf.init(keystore, passphrase); tmf.init(truststore); if (TRACE) out("Create SSLContext ..."); ctx = SSLContext.getInstance("TLS"); ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); if (TRACE) out("Done create SSLContext " + ctx); if (TRACE) out("Get SSLSocketFactory ..."); factory = ctx.getSocketFactory(); if (TRACE) out("Done get SSLSocketFactory " + factory); if (TRACE) out("Compare that with the default socket factory..."); SSLSocketFactory factory_dflt = (SSLSocketFactory)SSLSocketFactory.getDefault(); if (TRACE) out("Done get SSLSocketFactory " + factory_dflt + " hash: " + factory_dflt.hashCode()); } catch (Exception e) { if (TRACE) { out("RMISSLClientSocketFactory: got " + e); e.printStackTrace(); } //throw new IOException("cause:"+e); } } return factory; } public Socket createSocket(String host, int port) throws IOException { // SSLSocketFactory factory = (SSLSocketFactory)SSLSocketFactory.getDefault(); SSLSocketFactory factory = initFactory(); if (TRACE) out("createSocket: Done get SSLSocketFactory " + factory); if (TRACE) out("Use " + factory + " to create SSLSocket on port " + port); SSLSocket s = (SSLSocket)factory.createSocket(host, port); if (TRACE) out("Got SSLSocket " + s + " class:" + s.getClass()); return s; } public int hashCode() { return getClass().hashCode(); } public boolean equals(Object obj) { if (TRACE) out("Enter equals"); if (obj == this) { return true; } else if (obj == null || getClass() != obj.getClass()) { return false; } return true; } }