View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.client.standard;
18  
19  import java.net.InetAddress;
20  import java.net.InetSocketAddress;
21  import java.net.URI;
22  import java.net.UnknownHostException;
23  import java.util.ArrayList;
24  import java.util.List;
25  import java.util.Locale;
26  import java.util.Map;
27  import java.util.concurrent.Callable;
28  import javax.websocket.ClientEndpointConfig;
29  import javax.websocket.ClientEndpointConfig.Configurator;
30  import javax.websocket.ContainerProvider;
31  import javax.websocket.Endpoint;
32  import javax.websocket.Extension;
33  import javax.websocket.HandshakeResponse;
34  import javax.websocket.WebSocketContainer;
35  
36  import org.springframework.core.task.AsyncListenableTaskExecutor;
37  import org.springframework.core.task.SimpleAsyncTaskExecutor;
38  import org.springframework.core.task.TaskExecutor;
39  import org.springframework.http.HttpHeaders;
40  import org.springframework.lang.UsesJava7;
41  import org.springframework.util.Assert;
42  import org.springframework.util.concurrent.ListenableFuture;
43  import org.springframework.util.concurrent.ListenableFutureTask;
44  import org.springframework.web.socket.WebSocketExtension;
45  import org.springframework.web.socket.WebSocketHandler;
46  import org.springframework.web.socket.WebSocketSession;
47  import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
48  import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
49  import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
50  import org.springframework.web.socket.client.AbstractWebSocketClient;
51  
52  /**
53   * Initiates WebSocket requests to a WebSocket server programmatically
54   * through the standard Java WebSocket API.
55   *
56   * @author Rossen Stoyanchev
57   * @since 4.0
58   */
59  public class StandardWebSocketClient extends AbstractWebSocketClient {
60  
61  	private final WebSocketContainer webSocketContainer;
62  
63  	private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
64  
65  
66  	/**
67  	 * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()}
68  	 * to obtain a (new) {@link WebSocketContainer} instance. Also see constructor
69  	 * accepting existing {@code WebSocketContainer} instance.
70  	 */
71  	public StandardWebSocketClient() {
72  		this.webSocketContainer = ContainerProvider.getWebSocketContainer();
73  	}
74  
75  	/**
76  	 * Constructor accepting an existing {@link WebSocketContainer} instance.
77  	 * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java
78  	 * configuration, use {@code ContainerProvider.getWebSocketContainer()} to obtain
79  	 * the {@code WebSocketContainer} instance.
80  	 */
81  	public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
82  		Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
83  		this.webSocketContainer = webSocketContainer;
84  	}
85  
86  
87  	/**
88  	 * Set an {@link AsyncListenableTaskExecutor} to use when opening connections.
89  	 * If this property is set to {@code null}, calls to  any of the
90  	 * {@code doHandshake} methods will block until the connection is established.
91  	 * <p>By default, an instance of {@code SimpleAsyncTaskExecutor} is used.
92  	 */
93  	public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
94  		this.taskExecutor = taskExecutor;
95  	}
96  
97  	/**
98  	 * Return the configured {@link TaskExecutor}.
99  	 */
100 	public AsyncListenableTaskExecutor getTaskExecutor() {
101 		return this.taskExecutor;
102 	}
103 
104 
105 	@Override
106 	protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
107 			HttpHeaders headers, final URI uri, List<String> protocols,
108 			List<WebSocketExtension> extensions, Map<String, Object> attributes) {
109 
110 		int port = getPort(uri);
111 		InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
112 		InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
113 
114 		final StandardWebSocketSession session = new StandardWebSocketSession(headers,
115 				attributes, localAddress, remoteAddress);
116 
117 		final ClientEndpointConfig.Builder configBuilder = ClientEndpointConfig.Builder.create();
118 		configBuilder.configurator(new StandardWebSocketClientConfigurator(headers));
119 		configBuilder.preferredSubprotocols(protocols);
120 		configBuilder.extensions(adaptExtensions(extensions));
121 		final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
122 
123 		Callable<WebSocketSession> connectTask = new Callable<WebSocketSession>() {
124 			@Override
125 			public WebSocketSession call() throws Exception {
126 				webSocketContainer.connectToServer(endpoint, configBuilder.build(), uri);
127 				return session;
128 			}
129 		};
130 
131 		if (this.taskExecutor != null) {
132 			return this.taskExecutor.submitListenable(connectTask);
133 		}
134 		else {
135 			ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<WebSocketSession>(connectTask);
136 			task.run();
137 			return task;
138 		}
139 	}
140 
141 	private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
142 		List<Extension> result = new ArrayList<Extension>();
143 		for (WebSocketExtension extension : extensions) {
144 			result.add(new WebSocketToStandardExtensionAdapter(extension));
145 		}
146 		return result;
147 	}
148 
149 	@UsesJava7  // fallback to InetAddress.getLoopbackAddress()
150 	private InetAddress getLocalHost() {
151 		try {
152 			return InetAddress.getLocalHost();
153 		}
154 		catch (UnknownHostException ex) {
155 			return InetAddress.getLoopbackAddress();
156 		}
157 	}
158 
159 	private int getPort(URI uri) {
160 		if (uri.getPort() == -1) {
161 	        String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
162 	        return ("wss".equals(scheme) ? 443 : 80);
163 		}
164 		return uri.getPort();
165 	}
166 
167 
168 	private class StandardWebSocketClientConfigurator extends Configurator {
169 
170 		private final HttpHeaders headers;
171 
172 		public StandardWebSocketClientConfigurator(HttpHeaders headers) {
173 			this.headers = headers;
174 		}
175 
176 		@Override
177 		public void beforeRequest(Map<String, List<String>> requestHeaders) {
178 			requestHeaders.putAll(this.headers);
179 			if (logger.isTraceEnabled()) {
180 				logger.trace("Handshake request headers: " + requestHeaders);
181 			}
182 		}
183 		@Override
184 		public void afterResponse(HandshakeResponse response) {
185 			if (logger.isTraceEnabled()) {
186 				logger.trace("Handshake response headers: " + response.getHeaders());
187 			}
188 		}
189 	}
190 
191 }