1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.client;
18
import java.net.URI;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.util.UriComponentsBuilder;
37
38
39
40
41
42
43
44 public abstract class AbstractWebSocketClient implements WebSocketClient {
45
46 protected final Log logger = LogFactory.getLog(getClass());
47
48 private static final Set<String> specialHeaders = new HashSet<String>();
49
50 static {
51 specialHeaders.add("cache-control");
52 specialHeaders.add("connection");
53 specialHeaders.add("host");
54 specialHeaders.add("sec-websocket-extensions");
55 specialHeaders.add("sec-websocket-key");
56 specialHeaders.add("sec-websocket-protocol");
57 specialHeaders.add("sec-websocket-version");
58 specialHeaders.add("pragma");
59 specialHeaders.add("upgrade");
60 }
61
62
63 @Override
64 public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
65 String uriTemplate, Object... uriVars) {
66
67 Assert.notNull(uriTemplate, "uriTemplate must not be null");
68 URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
69 return doHandshake(webSocketHandler, null, uri);
70 }
71
72 @Override
73 public final ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
74 WebSocketHttpHeaders headers, URI uri) {
75
76 Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
77 assertUri(uri);
78
79 if (logger.isDebugEnabled()) {
80 logger.debug("Connecting to " + uri);
81 }
82
83 HttpHeaders headersToUse = new HttpHeaders();
84 if (headers != null) {
85 for (String header : headers.keySet()) {
86 if (!specialHeaders.contains(header.toLowerCase())) {
87 headersToUse.put(header, headers.get(header));
88 }
89 }
90 }
91
92 List<String> subProtocols = ((headers != null) && (headers.getSecWebSocketProtocol() != null)) ?
93 headers.getSecWebSocketProtocol() : Collections.<String>emptyList();
94
95 List<WebSocketExtension> extensions = ((headers != null) && (headers.getSecWebSocketExtensions() != null)) ?
96 headers.getSecWebSocketExtensions() : Collections.<WebSocketExtension>emptyList();
97
98 return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, extensions,
99 Collections.<String, Object>emptyMap());
100 }
101
102 protected void assertUri(URI uri) {
103 Assert.notNull(uri, "uri must not be null");
104 String scheme = uri.getScheme();
105 Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme);
106 }
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122 protected abstract ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
123 HttpHeaders headers, URI uri, List<String> subProtocols, List<WebSocketExtension> extensions,
124 Map<String, Object> attributes);
125
126 }