View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.server.standard;
18  
19  import java.net.InetSocketAddress;
20  import java.security.Principal;
21  import java.util.ArrayList;
22  import java.util.List;
23  import java.util.Map;
24  import javax.servlet.ServletContext;
25  import javax.servlet.http.HttpServletRequest;
26  import javax.servlet.http.HttpServletResponse;
27  import javax.websocket.Endpoint;
28  import javax.websocket.Extension;
29  import javax.websocket.WebSocketContainer;
30  import javax.websocket.server.ServerContainer;
31  
32  import org.apache.commons.logging.Log;
33  import org.apache.commons.logging.LogFactory;
34  
35  import org.springframework.http.HttpHeaders;
36  import org.springframework.http.server.ServerHttpRequest;
37  import org.springframework.http.server.ServerHttpResponse;
38  import org.springframework.http.server.ServletServerHttpRequest;
39  import org.springframework.http.server.ServletServerHttpResponse;
40  import org.springframework.util.Assert;
41  import org.springframework.web.socket.WebSocketExtension;
42  import org.springframework.web.socket.WebSocketHandler;
43  import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter;
44  import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
45  import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
46  import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
47  import org.springframework.web.socket.server.HandshakeFailureException;
48  import org.springframework.web.socket.server.RequestUpgradeStrategy;
49  
50  /**
51   * A base class for {@link RequestUpgradeStrategy} implementations that build
52   * on the standard WebSocket API for Java (JSR-356).
53   *
54   * @author Rossen Stoyanchev
55   * @since 4.0
56   */
57  public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy {
58  
59  	protected final Log logger = LogFactory.getLog(getClass());
60  
61  	private volatile List<WebSocketExtension> extensions;
62  
63  
64  	protected ServerContainer getContainer(HttpServletRequest request) {
65  		ServletContext servletContext = request.getServletContext();
66  		String attrName = "javax.websocket.server.ServerContainer";
67  		ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName);
68  		Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
69  				"Are you running in a Servlet container that supports JSR-356?");
70  		return container;
71  	}
72  
73  	protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
74  		Assert.isTrue(request instanceof ServletServerHttpRequest);
75  		return ((ServletServerHttpRequest) request).getServletRequest();
76  	}
77  
78  	protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
79  		Assert.isTrue(response instanceof ServletServerHttpResponse);
80  		return ((ServletServerHttpResponse) response).getServletResponse();
81  	}
82  
83  
84  	@Override
85  	public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
86  		if (this.extensions == null) {
87  			HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
88  			this.extensions = getInstalledExtensions(getContainer(servletRequest));
89  		}
90  		return this.extensions;
91  	}
92  
93  	protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
94  		List<WebSocketExtension> result = new ArrayList<WebSocketExtension>();
95  		for (Extension ext : container.getInstalledExtensions()) {
96  			result.add(new StandardToWebSocketExtensionAdapter(ext));
97  		}
98  		return result;
99  	}
100 
101 
102 	@Override
103 	public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
104 			String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
105 			WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException {
106 
107 		HttpHeaders headers = request.getHeaders();
108 		InetSocketAddress localAddr = request.getLocalAddress();
109 		InetSocketAddress remoteAddr = request.getRemoteAddress();
110 
111 		StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user);
112 		StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session);
113 
114 		List<Extension> extensions = new ArrayList<Extension>();
115 		for (WebSocketExtension extension : selectedExtensions) {
116 			extensions.add(new WebSocketToStandardExtensionAdapter(extension));
117 		}
118 
119 		upgradeInternal(request, response, selectedProtocol, extensions, endpoint);
120 	}
121 
122 	protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
123 			String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint)
124 			throws HandshakeFailureException;
125 
126 }