1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.config.annotation;
18
19 import java.util.Arrays;
20 import java.util.List;
21 import java.util.Map;
22
23 import org.junit.Before;
24 import org.junit.Test;
25
26 import org.springframework.messaging.MessageChannel;
27 import org.springframework.messaging.SubscribableChannel;
28 import org.springframework.scheduling.TaskScheduler;
29 import org.springframework.util.MultiValueMap;
30 import org.springframework.web.HttpRequestHandler;
31 import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
32 import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
33 import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
34 import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
35 import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
36 import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
37 import org.springframework.web.socket.sockjs.transport.TransportHandler;
38 import org.springframework.web.socket.sockjs.transport.TransportType;
39 import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
40 import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
41
42 import static org.junit.Assert.*;
43 import static org.mockito.Mockito.*;
44
45
46
47
48
49
50
51 public class WebMvcStompWebSocketEndpointRegistrationTests {
52
53 private SubProtocolWebSocketHandler handler;
54
55 private TaskScheduler scheduler;
56
57
58 @Before
59 public void setup() {
60 this.handler = new SubProtocolWebSocketHandler(mock(MessageChannel.class), mock(SubscribableChannel.class));
61 this.scheduler = mock(TaskScheduler.class);
62 }
63
64 @Test
65 public void minimalRegistration() {
66 WebMvcStompWebSocketEndpointRegistration registration =
67 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
68
69 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
70 assertEquals(1, mappings.size());
71
72 Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
73 assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
74 assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size());
75 assertEquals(Arrays.asList("/foo"), entry.getValue());
76 }
77
78 @Test
79 public void allowedOrigins() {
80 WebMvcStompWebSocketEndpointRegistration registration =
81 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
82
83 registration.setAllowedOrigins();
84
85 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
86 assertEquals(1, mappings.size());
87 WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
88 assertNotNull(requestHandler.getWebSocketHandler());
89 assertEquals(1, requestHandler.getHandshakeInterceptors().size());
90 assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
91 }
92
93 @Test
94 public void sameOrigin() {
95 WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
96
97 registration.setAllowedOrigins();
98
99 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
100 assertEquals(1, mappings.size());
101 WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
102 assertNotNull(requestHandler.getWebSocketHandler());
103 assertEquals(1, requestHandler.getHandshakeInterceptors().size());
104 assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
105 }
106
107 @Test
108 public void allowedOriginsWithSockJsService() {
109 WebMvcStompWebSocketEndpointRegistration registration =
110 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
111
112 String origin = "http://mydomain.com";
113 registration.setAllowedOrigins(origin).withSockJS();
114
115 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
116 assertEquals(1, mappings.size());
117 SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
118 assertNotNull(requestHandler.getSockJsService());
119 DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
120 assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
121 assertFalse(sockJsService.shouldSuppressCors());
122
123 registration =
124 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
125 registration.withSockJS().setAllowedOrigins(origin);
126 mappings = registration.getMappings();
127 assertEquals(1, mappings.size());
128 requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
129 assertNotNull(requestHandler.getSockJsService());
130 sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
131 assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
132 assertFalse(sockJsService.shouldSuppressCors());
133 }
134
135 @Test
136 public void disableCorsWithSockJsService() {
137 WebMvcStompWebSocketEndpointRegistration registration =
138 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
139
140 registration.withSockJS().setSupressCors(true);
141
142 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
143 assertEquals(1, mappings.size());
144 SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
145 assertNotNull(requestHandler.getSockJsService());
146 DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
147 assertTrue(sockJsService.shouldSuppressCors());
148 }
149
150 @Test
151 public void handshakeHandlerAndInterceptor() {
152 WebMvcStompWebSocketEndpointRegistration registration =
153 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
154
155 DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
156 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
157
158 registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);
159
160 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
161 assertEquals(1, mappings.size());
162
163 Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
164 assertEquals(Arrays.asList("/foo"), entry.getValue());
165
166 WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
167 assertNotNull(requestHandler.getWebSocketHandler());
168 assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
169 assertEquals(2, requestHandler.getHandshakeInterceptors().size());
170 assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
171 assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
172 }
173
174 @Test
175 public void handshakeHandlerAndInterceptorWithAllowedOrigins() {
176 WebMvcStompWebSocketEndpointRegistration registration =
177 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
178
179 DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
180 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
181 String origin = "http://mydomain.com";
182 registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin);
183
184 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
185 assertEquals(1, mappings.size());
186
187 Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
188 assertEquals(Arrays.asList("/foo"), entry.getValue());
189
190 WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
191 assertNotNull(requestHandler.getWebSocketHandler());
192 assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
193 assertEquals(2, requestHandler.getHandshakeInterceptors().size());
194 assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
195 assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
196 }
197
198 @Test
199 public void handshakeHandlerInterceptorWithSockJsService() {
200 WebMvcStompWebSocketEndpointRegistration registration =
201 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
202
203 DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
204 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
205
206 registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS();
207
208 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
209 assertEquals(1, mappings.size());
210
211 Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
212 assertEquals(Arrays.asList("/foo/**"), entry.getValue());
213
214 SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
215 assertNotNull(requestHandler.getWebSocketHandler());
216
217 DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
218 assertNotNull(sockJsService);
219
220 Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
221 WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
222 assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
223 assertEquals(2, sockJsService.getHandshakeInterceptors().size());
224 assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
225 assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass());
226 }
227
228 @Test
229 public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() {
230 WebMvcStompWebSocketEndpointRegistration registration =
231 new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
232
233 DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
234 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
235 String origin = "http://mydomain.com";
236
237 registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS();
238
239 MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
240 assertEquals(1, mappings.size());
241
242 Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
243 assertEquals(Arrays.asList("/foo/**"), entry.getValue());
244
245 SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
246 assertNotNull(requestHandler.getWebSocketHandler());
247
248 DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
249 assertNotNull(sockJsService);
250
251 Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
252 WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
253 assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
254 assertEquals(2, sockJsService.getHandshakeInterceptors().size());
255 assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
256 assertEquals(OriginHandshakeInterceptor.class,
257 sockJsService.getHandshakeInterceptors().get(1).getClass());
258 assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
259 }
260
261 }