1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.client.standard;
18
19 import java.net.InetAddress;
20 import java.net.InetSocketAddress;
21 import java.net.URI;
22 import java.net.UnknownHostException;
23 import java.util.ArrayList;
24 import java.util.List;
25 import java.util.Locale;
26 import java.util.Map;
27 import java.util.concurrent.Callable;
28 import javax.websocket.ClientEndpointConfig;
29 import javax.websocket.ClientEndpointConfig.Configurator;
30 import javax.websocket.ContainerProvider;
31 import javax.websocket.Endpoint;
32 import javax.websocket.Extension;
33 import javax.websocket.HandshakeResponse;
34 import javax.websocket.WebSocketContainer;
35
36 import org.springframework.core.task.AsyncListenableTaskExecutor;
37 import org.springframework.core.task.SimpleAsyncTaskExecutor;
38 import org.springframework.core.task.TaskExecutor;
39 import org.springframework.http.HttpHeaders;
40 import org.springframework.lang.UsesJava7;
41 import org.springframework.util.Assert;
42 import org.springframework.util.concurrent.ListenableFuture;
43 import org.springframework.util.concurrent.ListenableFutureTask;
44 import org.springframework.web.socket.WebSocketExtension;
45 import org.springframework.web.socket.WebSocketHandler;
46 import org.springframework.web.socket.WebSocketSession;
47 import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
48 import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
49 import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
50 import org.springframework.web.socket.client.AbstractWebSocketClient;
51
52
53
54
55
56
57
58
59 public class StandardWebSocketClient extends AbstractWebSocketClient {
60
61 private final WebSocketContainer webSocketContainer;
62
63 private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
64
65
66
67
68
69
70
71 public StandardWebSocketClient() {
72 this.webSocketContainer = ContainerProvider.getWebSocketContainer();
73 }
74
75
76
77
78
79
80
81 public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
82 Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
83 this.webSocketContainer = webSocketContainer;
84 }
85
86
87
88
89
90
91
92
93 public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
94 this.taskExecutor = taskExecutor;
95 }
96
97
98
99
100 public AsyncListenableTaskExecutor getTaskExecutor() {
101 return this.taskExecutor;
102 }
103
104
105 @Override
106 protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
107 HttpHeaders headers, final URI uri, List<String> protocols,
108 List<WebSocketExtension> extensions, Map<String, Object> attributes) {
109
110 int port = getPort(uri);
111 InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
112 InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
113
114 final StandardWebSocketSession session = new StandardWebSocketSession(headers,
115 attributes, localAddress, remoteAddress);
116
117 final ClientEndpointConfig.Builder configBuilder = ClientEndpointConfig.Builder.create();
118 configBuilder.configurator(new StandardWebSocketClientConfigurator(headers));
119 configBuilder.preferredSubprotocols(protocols);
120 configBuilder.extensions(adaptExtensions(extensions));
121 final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
122
123 Callable<WebSocketSession> connectTask = new Callable<WebSocketSession>() {
124 @Override
125 public WebSocketSession call() throws Exception {
126 webSocketContainer.connectToServer(endpoint, configBuilder.build(), uri);
127 return session;
128 }
129 };
130
131 if (this.taskExecutor != null) {
132 return this.taskExecutor.submitListenable(connectTask);
133 }
134 else {
135 ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<WebSocketSession>(connectTask);
136 task.run();
137 return task;
138 }
139 }
140
141 private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
142 List<Extension> result = new ArrayList<Extension>();
143 for (WebSocketExtension extension : extensions) {
144 result.add(new WebSocketToStandardExtensionAdapter(extension));
145 }
146 return result;
147 }
148
149 @UsesJava7
150 private InetAddress getLocalHost() {
151 try {
152 return InetAddress.getLocalHost();
153 }
154 catch (UnknownHostException ex) {
155 return InetAddress.getLoopbackAddress();
156 }
157 }
158
159 private int getPort(URI uri) {
160 if (uri.getPort() == -1) {
161 String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
162 return ("wss".equals(scheme) ? 443 : 80);
163 }
164 return uri.getPort();
165 }
166
167
168 private class StandardWebSocketClientConfigurator extends Configurator {
169
170 private final HttpHeaders headers;
171
172 public StandardWebSocketClientConfigurator(HttpHeaders headers) {
173 this.headers = headers;
174 }
175
176 @Override
177 public void beforeRequest(Map<String, List<String>> requestHeaders) {
178 requestHeaders.putAll(this.headers);
179 if (logger.isTraceEnabled()) {
180 logger.trace("Handshake request headers: " + requestHeaders);
181 }
182 }
183 @Override
184 public void afterResponse(HandshakeResponse response) {
185 if (logger.isTraceEnabled()) {
186 logger.trace("Handshake response headers: " + response.getHeaders());
187 }
188 }
189 }
190
191 }