1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.server.standard;
18
19 import java.net.InetSocketAddress;
20 import java.security.Principal;
21 import java.util.ArrayList;
22 import java.util.List;
23 import java.util.Map;
24 import javax.servlet.ServletContext;
25 import javax.servlet.http.HttpServletRequest;
26 import javax.servlet.http.HttpServletResponse;
27 import javax.websocket.Endpoint;
28 import javax.websocket.Extension;
29 import javax.websocket.WebSocketContainer;
30 import javax.websocket.server.ServerContainer;
31
32 import org.apache.commons.logging.Log;
33 import org.apache.commons.logging.LogFactory;
34
35 import org.springframework.http.HttpHeaders;
36 import org.springframework.http.server.ServerHttpRequest;
37 import org.springframework.http.server.ServerHttpResponse;
38 import org.springframework.http.server.ServletServerHttpRequest;
39 import org.springframework.http.server.ServletServerHttpResponse;
40 import org.springframework.util.Assert;
41 import org.springframework.web.socket.WebSocketExtension;
42 import org.springframework.web.socket.WebSocketHandler;
43 import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter;
44 import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
45 import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
46 import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
47 import org.springframework.web.socket.server.HandshakeFailureException;
48 import org.springframework.web.socket.server.RequestUpgradeStrategy;
49
50
51
52
53
54
55
56
57 public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy {
58
59 protected final Log logger = LogFactory.getLog(getClass());
60
61 private volatile List<WebSocketExtension> extensions;
62
63
64 protected ServerContainer getContainer(HttpServletRequest request) {
65 ServletContext servletContext = request.getServletContext();
66 String attrName = "javax.websocket.server.ServerContainer";
67 ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName);
68 Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
69 "Are you running in a Servlet container that supports JSR-356?");
70 return container;
71 }
72
73 protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
74 Assert.isTrue(request instanceof ServletServerHttpRequest);
75 return ((ServletServerHttpRequest) request).getServletRequest();
76 }
77
78 protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
79 Assert.isTrue(response instanceof ServletServerHttpResponse);
80 return ((ServletServerHttpResponse) response).getServletResponse();
81 }
82
83
84 @Override
85 public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
86 if (this.extensions == null) {
87 HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
88 this.extensions = getInstalledExtensions(getContainer(servletRequest));
89 }
90 return this.extensions;
91 }
92
93 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
94 List<WebSocketExtension> result = new ArrayList<WebSocketExtension>();
95 for (Extension ext : container.getInstalledExtensions()) {
96 result.add(new StandardToWebSocketExtensionAdapter(ext));
97 }
98 return result;
99 }
100
101
102 @Override
103 public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
104 String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
105 WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException {
106
107 HttpHeaders headers = request.getHeaders();
108 InetSocketAddress localAddr = request.getLocalAddress();
109 InetSocketAddress remoteAddr = request.getRemoteAddress();
110
111 StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user);
112 StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session);
113
114 List<Extension> extensions = new ArrayList<Extension>();
115 for (WebSocketExtension extension : selectedExtensions) {
116 extensions.add(new WebSocketToStandardExtensionAdapter(extension));
117 }
118
119 upgradeInternal(request, response, selectedProtocol, extensions, endpoint);
120 }
121
122 protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
123 String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint)
124 throws HandshakeFailureException;
125
126 }