1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.config.annotation;
18
19 import java.util.ArrayList;
20 import java.util.Arrays;
21 import java.util.List;
22
23 import org.junit.Before;
24 import org.junit.Test;
25 import org.mockito.Mockito;
26
27 import org.springframework.scheduling.TaskScheduler;
28 import org.springframework.web.socket.WebSocketHandler;
29 import org.springframework.web.socket.handler.TextWebSocketHandler;
30 import org.springframework.web.socket.server.HandshakeHandler;
31 import org.springframework.web.socket.server.HandshakeInterceptor;
32 import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
33 import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
34 import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
35 import org.springframework.web.socket.sockjs.SockJsService;
36 import org.springframework.web.socket.sockjs.transport.TransportType;
37 import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
38 import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
39
40 import static org.junit.Assert.*;
41
42
43
44
45
46
47
48 public class WebSocketHandlerRegistrationTests {
49
50 private TestWebSocketHandlerRegistration registration;
51
52 private TaskScheduler taskScheduler;
53
54
55 @Before
56 public void setup() {
57 this.taskScheduler = Mockito.mock(TaskScheduler.class);
58 this.registration = new TestWebSocketHandlerRegistration(taskScheduler);
59 }
60
61 @Test
62 public void minimal() {
63 WebSocketHandler handler = new TextWebSocketHandler();
64 this.registration.addHandler(handler, "/foo", "/bar");
65
66 List<Mapping> mappings = this.registration.getMappings();
67 assertEquals(2, mappings.size());
68
69 Mapping m1 = mappings.get(0);
70 assertEquals(handler, m1.webSocketHandler);
71 assertEquals("/foo", m1.path);
72 assertEquals(1, m1.interceptors.length);
73 assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
74
75 Mapping m2 = mappings.get(1);
76 assertEquals(handler, m2.webSocketHandler);
77 assertEquals("/bar", m2.path);
78 assertEquals(1, m2.interceptors.length);
79 assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
80 }
81
82 @Test
83 public void interceptors() {
84 WebSocketHandler handler = new TextWebSocketHandler();
85 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
86
87 this.registration.addHandler(handler, "/foo").addInterceptors(interceptor);
88
89 List<Mapping> mappings = this.registration.getMappings();
90 assertEquals(1, mappings.size());
91
92 Mapping mapping = mappings.get(0);
93 assertEquals(handler, mapping.webSocketHandler);
94 assertEquals("/foo", mapping.path);
95 assertEquals(2, mapping.interceptors.length);
96 assertEquals(interceptor, mapping.interceptors[0]);
97 assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
98 }
99
100 @Test
101 public void emptyAllowedOrigin() {
102 WebSocketHandler handler = new TextWebSocketHandler();
103 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
104
105 this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins();
106
107 List<Mapping> mappings = this.registration.getMappings();
108 assertEquals(1, mappings.size());
109
110 Mapping mapping = mappings.get(0);
111 assertEquals(handler, mapping.webSocketHandler);
112 assertEquals("/foo", mapping.path);
113 assertEquals(2, mapping.interceptors.length);
114 assertEquals(interceptor, mapping.interceptors[0]);
115 assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
116 }
117
118 @Test
119 public void interceptorsWithAllowedOrigins() {
120 WebSocketHandler handler = new TextWebSocketHandler();
121 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
122
123 this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("http://mydomain1.com");
124
125 List<Mapping> mappings = this.registration.getMappings();
126 assertEquals(1, mappings.size());
127
128 Mapping mapping = mappings.get(0);
129 assertEquals(handler, mapping.webSocketHandler);
130 assertEquals("/foo", mapping.path);
131 assertEquals(2, mapping.interceptors.length);
132 assertEquals(interceptor, mapping.interceptors[0]);
133 assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
134 }
135
136 @Test
137 public void interceptorsPassedToSockJsRegistration() {
138 WebSocketHandler handler = new TextWebSocketHandler();
139 HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
140
141 this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
142 .setAllowedOrigins("http://mydomain1.com").withSockJS();
143
144 List<Mapping> mappings = this.registration.getMappings();
145 assertEquals(1, mappings.size());
146
147 Mapping mapping = mappings.get(0);
148 assertEquals(handler, mapping.webSocketHandler);
149 assertEquals("/foo/**", mapping.path);
150 assertNotNull(mapping.sockJsService);
151 assertEquals(Arrays.asList("http://mydomain1.com"),
152 mapping.sockJsService.getAllowedOrigins());
153 List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
154 assertEquals(interceptor, interceptors.get(0));
155 assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
156 }
157
158 @Test
159 public void handshakeHandler() {
160 WebSocketHandler handler = new TextWebSocketHandler();
161 HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
162
163 this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler);
164
165 List<Mapping> mappings = this.registration.getMappings();
166 assertEquals(1, mappings.size());
167
168 Mapping mapping = mappings.get(0);
169 assertEquals(handler, mapping.webSocketHandler);
170 assertEquals("/foo", mapping.path);
171 assertSame(handshakeHandler, mapping.handshakeHandler);
172 }
173
174 @Test
175 public void handshakeHandlerPassedToSockJsRegistration() {
176 WebSocketHandler handler = new TextWebSocketHandler();
177 HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
178
179 this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS();
180
181 List<Mapping> mappings = this.registration.getMappings();
182 assertEquals(1, mappings.size());
183
184 Mapping mapping = mappings.get(0);
185 assertEquals(handler, mapping.webSocketHandler);
186 assertEquals("/foo/**", mapping.path);
187 assertNotNull(mapping.sockJsService);
188
189 WebSocketTransportHandler transportHandler =
190 (WebSocketTransportHandler) mapping.sockJsService.getTransportHandlers().get(TransportType.WEBSOCKET);
191 assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
192 }
193
194
195 private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
196
197 public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
198 super(sockJsTaskScheduler);
199 }
200
201 @Override
202 protected List<Mapping> createMappings() {
203 return new ArrayList<>();
204 }
205
206 @Override
207 protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
208 WebSocketHandler wsHandler, String pathPattern) {
209
210 mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
211 }
212
213 @Override
214 protected void addWebSocketHandlerMapping(List<Mapping> mappings, WebSocketHandler handler,
215 HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path) {
216
217 mappings.add(new Mapping(handler, path, handshakeHandler, interceptors));
218 }
219 }
220
221 private static class Mapping {
222
223 private final WebSocketHandler webSocketHandler;
224
225 private final String path;
226
227 private final HandshakeHandler handshakeHandler;
228
229 private final HandshakeInterceptor[] interceptors;
230
231 private final DefaultSockJsService sockJsService;
232
233
234 public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
235 this.webSocketHandler = handler;
236 this.path = path;
237 this.handshakeHandler = null;
238 this.interceptors = null;
239 this.sockJsService = (DefaultSockJsService) sockJsService;
240 }
241
242 public Mapping(WebSocketHandler h, String path, HandshakeHandler hh, HandshakeInterceptor[] interceptors) {
243 this.webSocketHandler = h;
244 this.path = path;
245 this.handshakeHandler = hh;
246 this.interceptors = interceptors;
247 this.sockJsService = null;
248 }
249 }
250
251 }