001package ca.uhn.hl7v2.hoh.sockets;
002
003import java.io.IOException;
004import java.io.InputStream;
005import java.net.InetSocketAddress;
006import java.net.ServerSocket;
007import java.net.Socket;
008import java.util.Arrays;
009import javax.net.ssl.SSLHandshakeException;
010import javax.net.ssl.SSLServerSocket;
011
012import ca.uhn.hl7v2.hoh.util.RandomServerPortProvider;
013import org.junit.Before;
014import org.junit.Test;
015import org.mortbay.jetty.Server;
016import org.mortbay.jetty.security.SslSelectChannelConnector;
017
018import static org.junit.Assert.assertEquals;
019import static org.junit.Assert.fail;
020
021public class CustomCertificateTlsSocketFactoryTest {
022
023        private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(CustomCertificateTlsSocketFactoryTest.class);
024
025        private int myPort;
026
027        @Before
028        public void before() {
029                myPort = RandomServerPortProvider.findFreePort();
030        }
031
032        @Test
033        public void testConnectToNonTrustedSocket() throws IOException, InterruptedException {
034
035                CustomCertificateTlsSocketFactory badServer = createTrustedServerSocketFactory();
036                Receiver receiver = new Receiver(badServer);
037                receiver.start();
038                Thread.sleep(500);
039
040                try {
041
042                        CustomCertificateTlsSocketFactory goodClient = createNonTrustedClientSocketFactory();
043                        Socket client = goodClient.createClientSocket();
044                        client.connect(new InetSocketAddress("localhost", myPort));
045
046                        client.getOutputStream().write("HELLO WORLD".getBytes());
047                        fail();
048
049                } catch (SSLHandshakeException e) {
050
051                }
052        }
053
054        @Test
055        public void testConnectToTrustedSocket() throws IOException, InterruptedException {
056
057                CustomCertificateTlsSocketFactory goodServer = createTrustedServerSocketFactory();
058                Receiver receiver = new Receiver(goodServer);
059                receiver.start();
060                Thread.sleep(500);
061
062                CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
063                goodClient.setKeystoreFilename("src/test/resources/truststore.jks");
064                // goodClient.setKeystorePassphrase("changeit");
065                Socket client = goodClient.createClientSocket();
066                client.connect(new InetSocketAddress("localhost", myPort));
067
068                client.getOutputStream().write("HELLO WORLD".getBytes());
069                client.close();
070
071                Thread.sleep(500);
072                String expected = "HELLO WORLD";
073                String actual = receiver.myString;
074                assertEquals(expected, actual);
075
076        }
077
078        public static CustomCertificateTlsSocketFactory createNonTrustedClientSocketFactory() {
079                CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
080                goodClient.setKeystoreFilename("src/test/resources/truststore2.jks");
081                goodClient.setKeystorePassphrase("trustpassword");
082                return goodClient;
083        }
084
085        public static StandardSocketFactory createNonSslServerSocketFactory() {
086                StandardSocketFactory goodClient = new StandardSocketFactory();
087                return goodClient;
088        }
089        
090        public static CustomCertificateTlsSocketFactory createTrustedClientSocketFactory() {
091                CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
092                goodClient.setKeystoreFilename("src/test/resources/truststore.jks");
093//              goodClient.setKeystorePassphrase("trustpassword");
094                return goodClient;
095        }
096
097        public static CustomCertificateTlsSocketFactory createTrustedServerSocketFactory() {
098                CustomCertificateTlsSocketFactory goodServer = new CustomCertificateTlsSocketFactory();
099                goodServer.setKeystoreFilename("src/test/resources/keystore.jks");
100                goodServer.setKeystorePassphrase("changeit");
101                return goodServer;
102        }
103
104        public static void main(String[] args) throws Exception {
105
106                Server s = new Server();
107
108                SslSelectChannelConnector ssl = new SslSelectChannelConnector();
109                ssl.setKeystore("src/test/resources/keystore.jks");
110                ssl.setPassword("changeit");
111                ssl.setKeyPassword("changeit");
112                ssl.setPort(60647);
113
114                s.addConnector(ssl);
115                s.start();
116        }
117
118        private class Receiver extends Thread {
119
120                private ISocketFactory myFactory;
121                private ServerSocket myServer;
122                private String myString;
123
124                public Receiver(ISocketFactory theFactory) {
125                        myFactory = theFactory;
126                }
127
128                @Override
129                public void run() {
130                        try {
131
132                                ourLog.info("Listening on port {}", myPort);
133
134                                myServer = myFactory.createServerSocket();
135                                myServer.bind(new InetSocketAddress(myPort));
136                                myServer.setSoTimeout(3000);
137
138                                if (myServer instanceof SSLServerSocket) {
139                                        SSLServerSocket ss = (SSLServerSocket) myServer;
140                                        ourLog.info(Arrays.asList(ss.getEnabledCipherSuites()).toString());
141                                }
142                                
143                                Socket socket = myServer.accept();
144                                socket.setSoTimeout(2000);
145
146                                InputStream is = socket.getInputStream();
147                                StringBuilder b = new StringBuilder();
148                                for (;;) {
149                                        int next = is.read();
150                                        if (next == -1) {
151                                                break;
152                                        } else {
153                                                b.append((char) next);
154                                                ourLog.info("Received: " + b);
155                                        }
156                                }
157
158                                myString = b.toString();
159                        } catch (Throwable e) {
160                                ourLog.error("Failed", e);
161                                fail(e.getMessage());
162                        } finally {
163                                if (myServer != null) {
164                                        try {
165                                                myServer.close();
166                                        } catch (Exception e) {
167                                                e.printStackTrace();
168                                        }
169                                }
170                        }
171                }
172
173        }
174
175}