View Javadoc
1   /*
2    * Copyright 2002-2015 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket.config.annotation;
18  
19  import java.util.ArrayList;
20  import java.util.Arrays;
21  import java.util.List;
22  
23  import org.junit.Before;
24  import org.junit.Test;
25  import org.mockito.Mockito;
26  
27  import org.springframework.scheduling.TaskScheduler;
28  import org.springframework.web.socket.WebSocketHandler;
29  import org.springframework.web.socket.handler.TextWebSocketHandler;
30  import org.springframework.web.socket.server.HandshakeHandler;
31  import org.springframework.web.socket.server.HandshakeInterceptor;
32  import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
33  import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
34  import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
35  import org.springframework.web.socket.sockjs.SockJsService;
36  import org.springframework.web.socket.sockjs.transport.TransportType;
37  import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
38  import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
39  
40  import static org.junit.Assert.*;
41  
42  /**
43   * Test fixture for
44   * {@link org.springframework.web.socket.config.annotation.AbstractWebSocketHandlerRegistration}.
45   *
46   * @author Rossen Stoyanchev
47   */
48  public class WebSocketHandlerRegistrationTests {
49  
50  	private TestWebSocketHandlerRegistration registration;
51  
52  	private TaskScheduler taskScheduler;
53  
54  
55  	@Before
56  	public void setup() {
57  		this.taskScheduler = Mockito.mock(TaskScheduler.class);
58  		this.registration = new TestWebSocketHandlerRegistration(taskScheduler);
59  	}
60  
61  	@Test
62  	public void minimal() {
63  		WebSocketHandler handler = new TextWebSocketHandler();
64  		this.registration.addHandler(handler, "/foo", "/bar");
65  
66  		List<Mapping> mappings = this.registration.getMappings();
67  		assertEquals(2, mappings.size());
68  
69  		Mapping m1 = mappings.get(0);
70  		assertEquals(handler, m1.webSocketHandler);
71  		assertEquals("/foo", m1.path);
72  		assertEquals(1, m1.interceptors.length);
73  		assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
74  
75  		Mapping m2 = mappings.get(1);
76  		assertEquals(handler, m2.webSocketHandler);
77  		assertEquals("/bar", m2.path);
78  		assertEquals(1, m2.interceptors.length);
79  		assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
80  	}
81  
82  	@Test
83  	public void interceptors() {
84  		WebSocketHandler handler = new TextWebSocketHandler();
85  		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
86  
87  		this.registration.addHandler(handler, "/foo").addInterceptors(interceptor);
88  
89  		List<Mapping> mappings = this.registration.getMappings();
90  		assertEquals(1, mappings.size());
91  
92  		Mapping mapping = mappings.get(0);
93  		assertEquals(handler, mapping.webSocketHandler);
94  		assertEquals("/foo", mapping.path);
95  		assertEquals(2, mapping.interceptors.length);
96  		assertEquals(interceptor, mapping.interceptors[0]);
97  		assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
98  	}
99  
100 	@Test
101 	public void emptyAllowedOrigin() {
102 		WebSocketHandler handler = new TextWebSocketHandler();
103 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
104 
105 		this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins();
106 
107 		List<Mapping> mappings = this.registration.getMappings();
108 		assertEquals(1, mappings.size());
109 
110 		Mapping mapping = mappings.get(0);
111 		assertEquals(handler, mapping.webSocketHandler);
112 		assertEquals("/foo", mapping.path);
113 		assertEquals(2, mapping.interceptors.length);
114 		assertEquals(interceptor, mapping.interceptors[0]);
115 		assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
116 	}
117 
118 	@Test
119 	public void interceptorsWithAllowedOrigins() {
120 		WebSocketHandler handler = new TextWebSocketHandler();
121 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
122 
123 		this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("http://mydomain1.com");
124 
125 		List<Mapping> mappings = this.registration.getMappings();
126 		assertEquals(1, mappings.size());
127 
128 		Mapping mapping = mappings.get(0);
129 		assertEquals(handler, mapping.webSocketHandler);
130 		assertEquals("/foo", mapping.path);
131 		assertEquals(2, mapping.interceptors.length);
132 		assertEquals(interceptor, mapping.interceptors[0]);
133 		assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
134 	}
135 
136 	@Test
137 	public void interceptorsPassedToSockJsRegistration() {
138 		WebSocketHandler handler = new TextWebSocketHandler();
139 		HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
140 
141 		this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
142 				.setAllowedOrigins("http://mydomain1.com").withSockJS();
143 
144 		List<Mapping> mappings = this.registration.getMappings();
145 		assertEquals(1, mappings.size());
146 
147 		Mapping mapping = mappings.get(0);
148 		assertEquals(handler, mapping.webSocketHandler);
149 		assertEquals("/foo/**", mapping.path);
150 		assertNotNull(mapping.sockJsService);
151 		assertEquals(Arrays.asList("http://mydomain1.com"),
152 				mapping.sockJsService.getAllowedOrigins());
153 		List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
154 		assertEquals(interceptor, interceptors.get(0));
155 		assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
156 	}
157 
158 	@Test
159 	public void handshakeHandler() {
160 		WebSocketHandler handler = new TextWebSocketHandler();
161 		HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
162 
163 		this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler);
164 
165 		List<Mapping> mappings = this.registration.getMappings();
166 		assertEquals(1, mappings.size());
167 
168 		Mapping mapping = mappings.get(0);
169 		assertEquals(handler, mapping.webSocketHandler);
170 		assertEquals("/foo", mapping.path);
171 		assertSame(handshakeHandler, mapping.handshakeHandler);
172 	}
173 
174 	@Test
175 	public void handshakeHandlerPassedToSockJsRegistration() {
176 		WebSocketHandler handler = new TextWebSocketHandler();
177 		HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
178 
179 		this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS();
180 
181 		List<Mapping> mappings = this.registration.getMappings();
182 		assertEquals(1, mappings.size());
183 
184 		Mapping mapping = mappings.get(0);
185 		assertEquals(handler, mapping.webSocketHandler);
186 		assertEquals("/foo/**", mapping.path);
187 		assertNotNull(mapping.sockJsService);
188 
189 		WebSocketTransportHandler transportHandler =
190 				(WebSocketTransportHandler) mapping.sockJsService.getTransportHandlers().get(TransportType.WEBSOCKET);
191 		assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
192 	}
193 
194 
195 	private static class TestWebSocketHandlerRegistration  extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
196 
197 		public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
198 			super(sockJsTaskScheduler);
199 		}
200 
201 		@Override
202 		protected List<Mapping> createMappings() {
203 			return new ArrayList<>();
204 		}
205 
206 		@Override
207 		protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
208 				WebSocketHandler wsHandler, String pathPattern) {
209 
210 			mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
211 		}
212 
213 		@Override
214 		protected void addWebSocketHandlerMapping(List<Mapping> mappings, WebSocketHandler handler,
215 				HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path) {
216 
217 			mappings.add(new Mapping(handler, path, handshakeHandler, interceptors));
218 		}
219 	}
220 
221 	private static class Mapping {
222 
223 		private final WebSocketHandler webSocketHandler;
224 
225 		private final String path;
226 
227 		private final HandshakeHandler handshakeHandler;
228 
229 		private final HandshakeInterceptor[] interceptors;
230 
231 		private final DefaultSockJsService sockJsService;
232 
233 
234 		public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
235 			this.webSocketHandler = handler;
236 			this.path = path;
237 			this.handshakeHandler = null;
238 			this.interceptors = null;
239 			this.sockJsService = (DefaultSockJsService) sockJsService;
240 		}
241 
242 		public Mapping(WebSocketHandler h, String path, HandshakeHandler hh, HandshakeInterceptor[] interceptors) {
243 			this.webSocketHandler = h;
244 			this.path = path;
245 			this.handshakeHandler = hh;
246 			this.interceptors = interceptors;
247 			this.sockJsService = null;
248 		}
249 	}
250 
251 }