View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.client;
18  
import java.net.URI;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.util.UriComponentsBuilder;
37  
38  /**
39   * Abstract base class for {@link WebSocketClient} implementations.
40   *
41   * @author Rossen Stoyanchev
42   * @since 4.0
43   */
44  public abstract class AbstractWebSocketClient implements WebSocketClient {
45  
46  	protected final Log logger = LogFactory.getLog(getClass());
47  
48  	private static final Set<String> specialHeaders = new HashSet<String>();
49  
50  	static {
51  		specialHeaders.add("cache-control");
52  		specialHeaders.add("connection");
53  		specialHeaders.add("host");
54  		specialHeaders.add("sec-websocket-extensions");
55  		specialHeaders.add("sec-websocket-key");
56  		specialHeaders.add("sec-websocket-protocol");
57  		specialHeaders.add("sec-websocket-version");
58  		specialHeaders.add("pragma");
59  		specialHeaders.add("upgrade");
60  	}
61  
62  
63  	@Override
64  	public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
65  			String uriTemplate, Object... uriVars) {
66  
67  		Assert.notNull(uriTemplate, "uriTemplate must not be null");
68  		URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
69  		return doHandshake(webSocketHandler, null, uri);
70  	}
71  
72  	@Override
73  	public final ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
74  			WebSocketHttpHeaders headers, URI uri) {
75  
76  		Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
77  		assertUri(uri);
78  
79  		if (logger.isDebugEnabled()) {
80  			logger.debug("Connecting to " + uri);
81  		}
82  
83  		HttpHeaders headersToUse = new HttpHeaders();
84  		if (headers != null) {
85  			for (String header : headers.keySet()) {
86  				if (!specialHeaders.contains(header.toLowerCase())) {
87  					headersToUse.put(header, headers.get(header));
88  				}
89  			}
90  		}
91  
92  		List<String> subProtocols = ((headers != null) && (headers.getSecWebSocketProtocol() != null)) ?
93  				headers.getSecWebSocketProtocol() : Collections.<String>emptyList();
94  
95  		List<WebSocketExtension> extensions = ((headers != null) && (headers.getSecWebSocketExtensions() != null)) ?
96  				headers.getSecWebSocketExtensions() : Collections.<WebSocketExtension>emptyList();
97  
98  		return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, extensions,
99  				Collections.<String, Object>emptyMap());
100 	}
101 
102 	protected void assertUri(URI uri) {
103 		Assert.notNull(uri, "uri must not be null");
104 		String scheme = uri.getScheme();
105 		Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme);
106 	}
107 
108 	/**
109 	 * Perform the actual handshake to establish a connection to the server.
110 	 *
111 	 * @param webSocketHandler the client-side handler for WebSocket messages
112 	 * @param headers HTTP headers to use for the handshake, with unwanted (forbidden)
113 	 * headers filtered out, never {@code null}
114 	 * @param uri the target URI for the handshake, never {@code null}
115 	 * @param subProtocols requested sub-protocols, or an empty list
116 	 * @param extensions requested WebSocket extensions, or an empty list
117 	 * @param attributes attributes to associate with the WebSocketSession, i.e. via
118 	 * {@link WebSocketSession#getAttributes()}; currently always an empty map.
119 	 *
120 	 * @return the established WebSocket session wrapped in a ListenableFuture.
121 	 */
122 	protected abstract ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
123 			HttpHeaders headers, URI uri, List<String> subProtocols, List<WebSocketExtension> extensions,
124 			Map<String, Object> attributes);
125 
126 }