View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.server.standard;
18  
19  import java.lang.reflect.Constructor;
20  import java.util.Arrays;
21  import java.util.Collections;
22  import java.util.List;
23  import java.util.Set;
24  import java.util.concurrent.ConcurrentHashMap;
25  import javax.servlet.http.HttpServletRequest;
26  import javax.servlet.http.HttpServletResponse;
27  import javax.websocket.Decoder;
28  import javax.websocket.Encoder;
29  import javax.websocket.Endpoint;
30  import javax.websocket.Extension;
31  import javax.websocket.server.ServerEndpointConfig;
32  
33  import io.undertow.server.HttpServerExchange;
34  import io.undertow.server.HttpUpgradeListener;
35  import io.undertow.servlet.api.InstanceFactory;
36  import io.undertow.servlet.api.InstanceHandle;
37  import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
38  import io.undertow.util.PathTemplate;
39  import io.undertow.websockets.core.WebSocketChannel;
40  import io.undertow.websockets.core.WebSocketVersion;
41  import io.undertow.websockets.core.protocol.Handshake;
42  import io.undertow.websockets.jsr.ConfiguredServerEndpoint;
43  import io.undertow.websockets.jsr.EncodingFactory;
44  import io.undertow.websockets.jsr.EndpointSessionHandler;
45  import io.undertow.websockets.jsr.ServerWebSocketContainer;
46  import io.undertow.websockets.jsr.annotated.AnnotatedEndpointFactory;
47  import io.undertow.websockets.jsr.handshake.HandshakeUtil;
48  import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake;
49  import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake;
50  import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake;
51  import org.xnio.StreamConnection;
52  
53  import org.springframework.http.server.ServerHttpRequest;
54  import org.springframework.http.server.ServerHttpResponse;
55  import org.springframework.util.ClassUtils;
56  import org.springframework.web.socket.server.HandshakeFailureException;
57  
58  
59  /**
60   * A {@link org.springframework.web.socket.server.RequestUpgradeStrategy} for use
61   * with WildFly and its underlying Undertow web server.
62   *
63   * @author Rossen Stoyanchev
64   * @since 4.0.1
65   */
66  public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
67  
68  	private static final Constructor<ServletWebSocketHttpExchange> exchangeConstructor;
69  
70  	private static final Constructor<ConfiguredServerEndpoint> endpointConstructor;
71  
72  	private static final boolean undertow10Present;
73  
74  	private static final boolean undertow11Present;
75  
76  	static {
77  		Class<ServletWebSocketHttpExchange> exchangeType = ServletWebSocketHttpExchange.class;
78  		Class<?>[] exchangeParamTypes = new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class, Set.class};
79  		if (ClassUtils.hasConstructor(exchangeType, exchangeParamTypes)) {
80  			exchangeConstructor = ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes);
81  			undertow10Present = false;
82  		}
83  		else {
84  			exchangeParamTypes = new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class};
85  			exchangeConstructor = ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes);
86  			undertow10Present = true;
87  		}
88  
89  		Class<ConfiguredServerEndpoint> endpointType = ConfiguredServerEndpoint.class;
90  		Class<?>[] endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class,
91  				PathTemplate.class, EncodingFactory.class, AnnotatedEndpointFactory.class};
92  		if (ClassUtils.hasConstructor(endpointType, endpointParamTypes)) {
93  			endpointConstructor = ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes);
94  			undertow11Present = true;
95  		}
96  		else {
97  			endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class,
98  					PathTemplate.class, EncodingFactory.class};
99  			endpointConstructor = ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes);
100 			undertow11Present = false;
101 		}
102 	}
103 
104 	private static final String[] supportedVersions = new String[] {
105 			WebSocketVersion.V13.toHttpHeaderValue(),
106 			WebSocketVersion.V08.toHttpHeaderValue(),
107 			WebSocketVersion.V07.toHttpHeaderValue()
108 	};
109 
110 
111 	private final Set<WebSocketChannel> peerConnections;
112 
113 
114 	public UndertowRequestUpgradeStrategy() {
115 		if (undertow10Present) {
116 			this.peerConnections = null;
117 		}
118 		else {
119 			this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>());
120 		}
121 	}
122 
123 
124 	@Override
125 	public String[] getSupportedVersions() {
126 		return supportedVersions;
127 	}
128 
129 	@Override
130 	protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
131 			String selectedProtocol, List<Extension> selectedExtensions, final Endpoint endpoint)
132 			throws HandshakeFailureException {
133 
134 		HttpServletRequest servletRequest = getHttpServletRequest(request);
135 		HttpServletResponse servletResponse = getHttpServletResponse(response);
136 
137 		final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse);
138 		exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap());
139 
140 		ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest);
141 		final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer);
142 
143 		final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint(
144 				selectedProtocol, selectedExtensions, endpoint, servletRequest);
145 
146 		final Handshake handshake = getHandshakeToUse(exchange, configuredServerEndpoint);
147 
148 		exchange.upgradeChannel(new HttpUpgradeListener() {
149 			@Override
150 			public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) {
151 				WebSocketChannel channel = handshake.createChannel(exchange, connection, exchange.getBufferPool());
152 				if (peerConnections != null) {
153 					peerConnections.add(channel);
154 				}
155 				endpointSessionHandler.onConnect(exchange, channel);
156 			}
157 		});
158 
159 		handshake.handshake(exchange);
160 	}
161 
162 	private ServletWebSocketHttpExchange createHttpExchange(HttpServletRequest request, HttpServletResponse response) {
163 		try {
164 			return (this.peerConnections != null ?
165 					exchangeConstructor.newInstance(request, response, this.peerConnections) :
166 					exchangeConstructor.newInstance(request, response));
167 		}
168 		catch (Exception ex) {
169 			throw new HandshakeFailureException("Failed to instantiate ServletWebSocketHttpExchange", ex);
170 		}
171 	}
172 
173 	private Handshake getHandshakeToUse(ServletWebSocketHttpExchange exchange, ConfiguredServerEndpoint endpoint) {
174 		Handshake handshake = new JsrHybi13Handshake(endpoint);
175 		if (handshake.matches(exchange)) {
176 			return handshake;
177 		}
178 		handshake = new JsrHybi08Handshake(endpoint);
179 		if (handshake.matches(exchange)) {
180 			return handshake;
181 		}
182 		handshake = new JsrHybi07Handshake(endpoint);
183 		if (handshake.matches(exchange)) {
184 			return handshake;
185 		}
186 		// Should never occur
187 		throw new HandshakeFailureException("No matching Undertow Handshake found: " + exchange.getRequestHeaders());
188 	}
189 
190 	private ConfiguredServerEndpoint createConfiguredServerEndpoint(String selectedProtocol,
191 			List<Extension> selectedExtensions, Endpoint endpoint, HttpServletRequest servletRequest) {
192 
193 		String path = servletRequest.getRequestURI();  // shouldn't matter
194 		ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration(path, endpoint);
195 		endpointRegistration.setSubprotocols(Arrays.asList(selectedProtocol));
196 		endpointRegistration.setExtensions(selectedExtensions);
197 
198 		EncodingFactory encodingFactory = new EncodingFactory(
199 				Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
200 				Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap(),
201 				Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
202 				Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap());
203 		try {
204 			return undertow11Present ?
205 					endpointConstructor.newInstance(endpointRegistration,
206 						new EndpointInstanceFactory(endpoint), null, encodingFactory, null) :
207 					endpointConstructor.newInstance(endpointRegistration,
208 						new EndpointInstanceFactory(endpoint), null, encodingFactory);
209 		}
210 		catch (Exception ex) {
211 			throw new HandshakeFailureException("Failed to instantiate ConfiguredServerEndpoint", ex);
212 		}
213 	}
214 
215 
216 	private static class EndpointInstanceFactory implements InstanceFactory<Endpoint> {
217 
218 		private final Endpoint endpoint;
219 
220 		public EndpointInstanceFactory(Endpoint endpoint) {
221 			this.endpoint = endpoint;
222 		}
223 
224 		@Override
225 		public InstanceHandle<Endpoint> createInstance() throws InstantiationException {
226 			return new InstanceHandle<Endpoint>() {
227 				@Override
228 				public Endpoint getInstance() {
229 					return endpoint;
230 				}
231 				@Override
232 				public void release() {
233 				}
234 			};
235 		}
236 	}
237 
238 }