View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.config.annotation;
18  
19  import java.util.Arrays;
20  import java.util.List;
21  import java.util.Map;
22  
23  import org.junit.Before;
24  import org.junit.Test;
25  
26  import org.springframework.messaging.MessageChannel;
27  import org.springframework.messaging.SubscribableChannel;
28  import org.springframework.scheduling.TaskScheduler;
29  import org.springframework.util.MultiValueMap;
30  import org.springframework.web.HttpRequestHandler;
31  import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
32  import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
33  import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
34  import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
35  import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
36  import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
37  import org.springframework.web.socket.sockjs.transport.TransportHandler;
38  import org.springframework.web.socket.sockjs.transport.TransportType;
39  import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
40  import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
41  
42  import static org.junit.Assert.*;
43  import static org.mockito.Mockito.*;
44  
45  /**
46   * Test fixture for
47   * {@link org.springframework.web.socket.config.annotation.WebMvcStompWebSocketEndpointRegistration}.
48   *
49   * @author Rossen Stoyanchev
50   */
51  public class WebMvcStompWebSocketEndpointRegistrationTests {
52  
53  	private SubProtocolWebSocketHandler handler;
54  
55  	private TaskScheduler scheduler;
56  
57  
58  	@Before
59  	public void setup() {
60  		this.handler = new SubProtocolWebSocketHandler(mock(MessageChannel.class), mock(SubscribableChannel.class));
61  		this.scheduler = mock(TaskScheduler.class);
62  	}
63  
64  	@Test
65  	public void minimalRegistration() {
66  		WebMvcStompWebSocketEndpointRegistration registration =
67  				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
68  
69  		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
70  		assertEquals(1, mappings.size());
71  
72  		Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
73  		assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
74  		assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size());
75  		assertEquals(Arrays.asList("/foo"), entry.getValue());
76  	}
77  
78  	@Test
79  	public void allowedOrigins() {
80  		WebMvcStompWebSocketEndpointRegistration registration =
81  				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
82  
83  		registration.setAllowedOrigins();
84  
85  		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
86  		assertEquals(1, mappings.size());
87  		WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
88  		assertNotNull(requestHandler.getWebSocketHandler());
89  		assertEquals(1, requestHandler.getHandshakeInterceptors().size());
90  		assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
91  	}
92  
93  	@Test
94  	public void sameOrigin() {
95  		WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
96  
97  		registration.setAllowedOrigins();
98  
99  		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
100 		assertEquals(1, mappings.size());
101 		WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
102 		assertNotNull(requestHandler.getWebSocketHandler());
103 		assertEquals(1, requestHandler.getHandshakeInterceptors().size());
104 		assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
105 	}
106 
107 	@Test
108 	public void allowedOriginsWithSockJsService() {
109 		WebMvcStompWebSocketEndpointRegistration registration =
110 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
111 
112 		String origin = "http://mydomain.com";
113 		registration.setAllowedOrigins(origin).withSockJS();
114 
115 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
116 		assertEquals(1, mappings.size());
117 		SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
118 		assertNotNull(requestHandler.getSockJsService());
119 		DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
120 		assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
121 		assertFalse(sockJsService.shouldSuppressCors());
122 
123 		registration =
124 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
125 		registration.withSockJS().setAllowedOrigins(origin);
126 		mappings = registration.getMappings();
127 		assertEquals(1, mappings.size());
128 		requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
129 		assertNotNull(requestHandler.getSockJsService());
130 		sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
131 		assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
132 		assertFalse(sockJsService.shouldSuppressCors());
133 	}
134 
135 	@Test  // SPR-12283
136 	public void disableCorsWithSockJsService() {
137 		WebMvcStompWebSocketEndpointRegistration registration =
138 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
139 
140 		registration.withSockJS().setSupressCors(true);
141 
142 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
143 		assertEquals(1, mappings.size());
144 		SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
145 		assertNotNull(requestHandler.getSockJsService());
146 		DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
147 		assertTrue(sockJsService.shouldSuppressCors());
148 	}
149 
150 	@Test
151 	public void handshakeHandlerAndInterceptor() {
152 		WebMvcStompWebSocketEndpointRegistration registration =
153 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
154 
155 		DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
156 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
157 
158 		registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);
159 
160 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
161 		assertEquals(1, mappings.size());
162 
163 		Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
164 		assertEquals(Arrays.asList("/foo"), entry.getValue());
165 
166 		WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
167 		assertNotNull(requestHandler.getWebSocketHandler());
168 		assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
169 		assertEquals(2, requestHandler.getHandshakeInterceptors().size());
170 		assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
171 		assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
172 	}
173 
174 	@Test
175 	public void handshakeHandlerAndInterceptorWithAllowedOrigins() {
176 		WebMvcStompWebSocketEndpointRegistration registration =
177 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
178 
179 		DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
180 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
181 		String origin = "http://mydomain.com";
182 		registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin);
183 
184 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
185 		assertEquals(1, mappings.size());
186 
187 		Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
188 		assertEquals(Arrays.asList("/foo"), entry.getValue());
189 
190 		WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
191 		assertNotNull(requestHandler.getWebSocketHandler());
192 		assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
193 		assertEquals(2, requestHandler.getHandshakeInterceptors().size());
194 		assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
195 		assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
196 	}
197 
198 	@Test
199 	public void handshakeHandlerInterceptorWithSockJsService() {
200 		WebMvcStompWebSocketEndpointRegistration registration =
201 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
202 
203 		DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
204 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
205 
206 		registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS();
207 
208 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
209 		assertEquals(1, mappings.size());
210 
211 		Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
212 		assertEquals(Arrays.asList("/foo/**"), entry.getValue());
213 
214 		SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
215 		assertNotNull(requestHandler.getWebSocketHandler());
216 
217 		DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
218 		assertNotNull(sockJsService);
219 
220 		Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
221 		WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
222 		assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
223 		assertEquals(2, sockJsService.getHandshakeInterceptors().size());
224 		assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
225 		assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass());
226 	}
227 
228 	@Test
229 	public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() {
230 		WebMvcStompWebSocketEndpointRegistration registration =
231 				new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
232 
233 		DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
234 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
235 		String origin = "http://mydomain.com";
236 
237 		registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS();
238 
239 		MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
240 		assertEquals(1, mappings.size());
241 
242 		Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
243 		assertEquals(Arrays.asList("/foo/**"), entry.getValue());
244 
245 		SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
246 		assertNotNull(requestHandler.getWebSocketHandler());
247 
248 		DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
249 		assertNotNull(sockJsService);
250 
251 		Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
252 		WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
253 		assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
254 		assertEquals(2, sockJsService.getHandshakeInterceptors().size());
255 		assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
256 		assertEquals(OriginHandshakeInterceptor.class,
257 				sockJsService.getHandshakeInterceptors().get(1).getClass());
258 		assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
259 	}
260 
261 }