001package ca.uhn.hl7v2.hoh.sockets; 002 003import java.io.IOException; 004import java.io.InputStream; 005import java.net.InetSocketAddress; 006import java.net.ServerSocket; 007import java.net.Socket; 008import java.util.Arrays; 009import javax.net.ssl.SSLHandshakeException; 010import javax.net.ssl.SSLServerSocket; 011 012import ca.uhn.hl7v2.hoh.util.RandomServerPortProvider; 013import org.junit.Before; 014import org.junit.Test; 015import org.mortbay.jetty.Server; 016import org.mortbay.jetty.security.SslSelectChannelConnector; 017 018import static org.junit.Assert.assertEquals; 019import static org.junit.Assert.fail; 020 021public class CustomCertificateTlsSocketFactoryTest { 022 023 private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(CustomCertificateTlsSocketFactoryTest.class); 024 025 private int myPort; 026 027 @Before 028 public void before() { 029 myPort = RandomServerPortProvider.findFreePort(); 030 } 031 032 @Test 033 public void testConnectToNonTrustedSocket() throws IOException, InterruptedException { 034 035 CustomCertificateTlsSocketFactory badServer = createTrustedServerSocketFactory(); 036 Receiver receiver = new Receiver(badServer); 037 receiver.start(); 038 Thread.sleep(500); 039 040 try { 041 042 CustomCertificateTlsSocketFactory goodClient = createNonTrustedClientSocketFactory(); 043 Socket client = goodClient.createClientSocket(); 044 client.connect(new InetSocketAddress("localhost", myPort)); 045 046 client.getOutputStream().write("HELLO WORLD".getBytes()); 047 fail(); 048 049 } catch (SSLHandshakeException e) { 050 051 } 052 } 053 054 @Test 055 public void testConnectToTrustedSocket() throws IOException, InterruptedException { 056 057 CustomCertificateTlsSocketFactory goodServer = createTrustedServerSocketFactory(); 058 Receiver receiver = new Receiver(goodServer); 059 receiver.start(); 060 Thread.sleep(500); 061 062 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory(); 063 goodClient.setKeystoreFilename("src/test/resources/truststore.jks"); 064 // goodClient.setKeystorePassphrase("changeit"); 065 Socket client = goodClient.createClientSocket(); 066 client.connect(new InetSocketAddress("localhost", myPort)); 067 068 client.getOutputStream().write("HELLO WORLD".getBytes()); 069 client.close(); 070 071 Thread.sleep(500); 072 String expected = "HELLO WORLD"; 073 String actual = receiver.myString; 074 assertEquals(expected, actual); 075 076 } 077 078 public static CustomCertificateTlsSocketFactory createNonTrustedClientSocketFactory() { 079 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory(); 080 goodClient.setKeystoreFilename("src/test/resources/truststore2.jks"); 081 goodClient.setKeystorePassphrase("trustpassword"); 082 return goodClient; 083 } 084 085 public static StandardSocketFactory createNonSslServerSocketFactory() { 086 StandardSocketFactory goodClient = new StandardSocketFactory(); 087 return goodClient; 088 } 089 090 public static CustomCertificateTlsSocketFactory createTrustedClientSocketFactory() { 091 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory(); 092 goodClient.setKeystoreFilename("src/test/resources/truststore.jks"); 093// goodClient.setKeystorePassphrase("trustpassword"); 094 return goodClient; 095 } 096 097 public static CustomCertificateTlsSocketFactory createTrustedServerSocketFactory() { 098 CustomCertificateTlsSocketFactory goodServer = new CustomCertificateTlsSocketFactory(); 099 goodServer.setKeystoreFilename("src/test/resources/keystore.jks"); 100 goodServer.setKeystorePassphrase("changeit"); 101 return goodServer; 102 } 103 104 public static void main(String[] args) throws Exception { 105 106 Server s = new Server(); 107 108 SslSelectChannelConnector ssl = new SslSelectChannelConnector(); 109 ssl.setKeystore("src/test/resources/keystore.jks"); 110 ssl.setPassword("changeit"); 111 ssl.setKeyPassword("changeit"); 112 ssl.setPort(60647); 113 114 s.addConnector(ssl); 115 s.start(); 116 } 117 118 private class Receiver extends Thread { 119 120 private ISocketFactory myFactory; 121 private ServerSocket myServer; 122 private String myString; 123 124 public Receiver(ISocketFactory theFactory) { 125 myFactory = theFactory; 126 } 127 128 @Override 129 public void run() { 130 try { 131 132 ourLog.info("Listening on port {}", myPort); 133 134 myServer = myFactory.createServerSocket(); 135 myServer.bind(new InetSocketAddress(myPort)); 136 myServer.setSoTimeout(3000); 137 138 if (myServer instanceof SSLServerSocket) { 139 SSLServerSocket ss = (SSLServerSocket) myServer; 140 ourLog.info(Arrays.asList(ss.getEnabledCipherSuites()).toString()); 141 } 142 143 Socket socket = myServer.accept(); 144 socket.setSoTimeout(2000); 145 146 InputStream is = socket.getInputStream(); 147 StringBuilder b = new StringBuilder(); 148 for (;;) { 149 int next = is.read(); 150 if (next == -1) { 151 break; 152 } else { 153 b.append((char) next); 154 ourLog.info("Received: " + b); 155 } 156 } 157 158 myString = b.toString(); 159 } catch (Throwable e) { 160 ourLog.error("Failed", e); 161 fail(e.getMessage()); 162 } finally { 163 if (myServer != null) { 164 try { 165 myServer.close(); 166 } catch (Exception e) { 167 e.printStackTrace(); 168 } 169 } 170 } 171 } 172 173 } 174 175}