1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.server.standard;
18
19 import java.io.IOException;
20 import java.lang.reflect.Constructor;
21 import java.lang.reflect.Method;
22 import java.net.URI;
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.List;
26 import java.util.Random;
27 import javax.servlet.ServletException;
28 import javax.servlet.http.HttpServletRequest;
29 import javax.servlet.http.HttpServletResponse;
30 import javax.websocket.DeploymentException;
31 import javax.websocket.Endpoint;
32 import javax.websocket.EndpointConfig;
33 import javax.websocket.Extension;
34 import javax.websocket.WebSocketContainer;
35
36 import org.glassfish.tyrus.core.ComponentProviderService;
37 import org.glassfish.tyrus.core.RequestContext;
38 import org.glassfish.tyrus.core.TyrusEndpoint;
39 import org.glassfish.tyrus.core.TyrusEndpointWrapper;
40 import org.glassfish.tyrus.core.TyrusUpgradeResponse;
41 import org.glassfish.tyrus.core.TyrusWebSocketEngine;
42 import org.glassfish.tyrus.core.Version;
43 import org.glassfish.tyrus.core.WebSocketApplication;
44 import org.glassfish.tyrus.server.TyrusServerContainer;
45 import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;
46
47 import org.springframework.beans.DirectFieldAccessor;
48 import org.springframework.http.HttpHeaders;
49 import org.springframework.http.server.ServerHttpRequest;
50 import org.springframework.http.server.ServerHttpResponse;
51 import org.springframework.util.ReflectionUtils;
52 import org.springframework.util.StringUtils;
53 import org.springframework.web.socket.WebSocketExtension;
54 import org.springframework.web.socket.server.HandshakeFailureException;
55
56
57
58
59
60
61
62
63
64
65 public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
66
67 private static final Random random = new Random();
68
69 private final ComponentProviderService componentProvider = ComponentProviderService.create();
70
71
72 @Override
73 public String[] getSupportedVersions() {
74 return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions());
75 }
76
77 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
78 try {
79 return super.getInstalledExtensions(container);
80 }
81 catch (UnsupportedOperationException ex) {
82 return new ArrayList<WebSocketExtension>(0);
83 }
84 }
85
86 @Override
87 public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
88 String selectedProtocol, List<Extension> extensions, Endpoint endpoint)
89 throws HandshakeFailureException {
90
91 HttpServletRequest servletRequest = getHttpServletRequest(request);
92 HttpServletResponse servletResponse = getHttpServletResponse(response);
93
94 TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest);
95 TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine();
96 Object tyrusEndpoint = null;
97
98 try {
99
100 String path = "/" + random.nextLong();
101 tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine);
102 getEndpointHelper().register(engine, tyrusEndpoint);
103
104 HttpHeaders headers = request.getHeaders();
105 RequestContext requestContext = createRequestContext(servletRequest, path, headers);
106 TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
107 UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
108
109 switch (upgradeInfo.getStatus()) {
110 case SUCCESS:
111 if (logger.isTraceEnabled()) {
112 logger.trace("Successful upgrade: " + upgradeResponse.getHeaders());
113 }
114 handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse);
115 break;
116 case HANDSHAKE_FAILED:
117
118 throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI());
119 case NOT_APPLICABLE:
120
121 throw new HandshakeFailureException("Unexpected handshake mapping failure: " + request.getURI());
122 }
123 }
124 catch (Exception ex) {
125 throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex);
126 }
127 finally {
128 if (tyrusEndpoint != null) {
129 getEndpointHelper().unregister(engine, tyrusEndpoint);
130 }
131 }
132 }
133
134 private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, String protocol,
135 List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine)
136 throws DeploymentException {
137
138 ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint);
139 endpointConfig.setSubprotocols(Arrays.asList(protocol));
140 endpointConfig.setExtensions(extensions);
141 return getEndpointHelper().createdEndpoint(endpointConfig, this.componentProvider, container, engine);
142 }
143
144 private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) {
145 RequestContext context =
146 RequestContext.Builder.create()
147 .requestURI(URI.create(endpointPath))
148 .userPrincipal(request.getUserPrincipal())
149 .secure(request.isSecure())
150
151 .build();
152 for (String header : headers.keySet()) {
153 context.getHeaders().put(header, headers.get(header));
154 }
155 return context;
156 }
157
158
159 protected abstract TyrusEndpointHelper getEndpointHelper();
160
161 protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response,
162 UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException;
163
164
165
166
167
168 protected interface TyrusEndpointHelper {
169
170 Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
171 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException;
172
173 void register(TyrusWebSocketEngine engine, Object endpoint);
174
175 void unregister(TyrusWebSocketEngine engine, Object endpoint);
176 }
177
178
179 protected static class Tyrus17EndpointHelper implements TyrusEndpointHelper {
180
181 private static final Constructor<?> constructor;
182
183 private static final Method registerMethod;
184
185 private static final Method unRegisterMethod;
186
187 static {
188 try {
189 constructor = getEndpointConstructor();
190 registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class);
191 unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class);
192 ReflectionUtils.makeAccessible(registerMethod);
193 }
194 catch (Exception ex) {
195 throw new IllegalStateException("No compatible Tyrus version found", ex);
196 }
197 }
198
199 private static Constructor<?> getEndpointConstructor() {
200 for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) {
201 Class<?>[] types = current.getParameterTypes();
202 if (types[0].equals(Endpoint.class) && types[1].equals(EndpointConfig.class)) {
203 return current;
204 }
205 }
206 throw new IllegalStateException("No compatible Tyrus version found");
207 }
208
209
210 @Override
211 public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
212 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {
213
214 DirectFieldAccessor accessor = new DirectFieldAccessor(engine);
215 Object sessionListener = accessor.getPropertyValue("sessionListener");
216 Object clusterContext = accessor.getPropertyValue("clusterContext");
217 try {
218 return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
219 "/", registration.getConfigurator(), sessionListener, clusterContext, null);
220 }
221 catch (Exception ex) {
222 throw new HandshakeFailureException("Failed to register " + registration, ex);
223 }
224 }
225
226 @Override
227 public void register(TyrusWebSocketEngine engine, Object endpoint) {
228 try {
229 registerMethod.invoke(engine, endpoint);
230 }
231 catch (Exception ex) {
232 throw new HandshakeFailureException("Failed to register " + endpoint, ex);
233 }
234 }
235
236 @Override
237 public void unregister(TyrusWebSocketEngine engine, Object endpoint) {
238 try {
239 unRegisterMethod.invoke(engine, endpoint);
240 }
241 catch (Exception ex) {
242 throw new HandshakeFailureException("Failed to unregister " + endpoint, ex);
243 }
244 }
245 }
246
247
248 protected static class Tyrus135EndpointHelper implements TyrusEndpointHelper {
249
250 private static final Method registerMethod;
251
252 static {
253 try {
254 registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", WebSocketApplication.class);
255 ReflectionUtils.makeAccessible(registerMethod);
256 }
257 catch (Exception ex) {
258 throw new IllegalStateException("No compatible Tyrus version found", ex);
259 }
260 }
261
262 @Override
263 public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
264 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {
265
266 TyrusEndpointWrapper endpointWrapper = new TyrusEndpointWrapper(registration.getEndpoint(),
267 registration, provider, container, "/", registration.getConfigurator());
268
269 return new TyrusEndpoint(endpointWrapper);
270 }
271
272 @Override
273 public void register(TyrusWebSocketEngine engine, Object endpoint) {
274 try {
275 registerMethod.invoke(engine, endpoint);
276 }
277 catch (Exception ex) {
278 throw new HandshakeFailureException("Failed to register " + endpoint, ex);
279 }
280 }
281
282 @Override
283 public void unregister(TyrusWebSocketEngine engine, Object endpoint) {
284 engine.unregister((TyrusEndpoint) endpoint);
285 }
286 }
287
288 }