View Javadoc
1   /*
2    * Copyright 2002-2015 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.web.socket;
18  
19  
import static org.junit.Assert.*;

import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
44  
45  
46  /**
47   * Client and server-side WebSocket integration tests.
48   *
49   * @author Rossen Stoyanchev
50   */
51  @RunWith(Parameterized.class)
52  public class WebSocketIntegrationTests extends  AbstractWebSocketIntegrationTests {
53  
54  	@Parameterized.Parameters
55  	public static Iterable<Object[]> arguments() {
56  		return Arrays.asList(new Object[][] {
57  				{new JettyWebSocketTestServer(), new JettyWebSocketClient()},
58  				{new TomcatWebSocketTestServer(), new StandardWebSocketClient()},
59  				{new UndertowTestServer(), new JettyWebSocketClient()}
60  		});
61  	}
62  
63  
64  	@Override
65  	protected Class<?>[] getAnnotatedConfigClasses() {
66  		return new Class<?>[] { TestConfig.class };
67  	}
68  
69  	@Test
70  	public void subProtocolNegotiation() throws Exception {
71  		WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
72  		headers.setSecWebSocketProtocol("foo");
73  		URI url = new URI(getWsBaseUrl() + "/ws");
74  		WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get();
75  		assertEquals("foo", session.getAcceptedProtocol());
76  		session.close();
77  	}
78  
79  	// SPR-12727
80  
81  	@Test
82  	public void unsolicitedPongWithEmptyPayload() throws Exception {
83  		TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
84  		serverHandler.setWaitMessageCount(1);
85  
86  		String url = getWsBaseUrl() + "/ws";
87  		WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get();
88  		session.sendMessage(new PongMessage());
89  
90  		serverHandler.await();
91  		assertNull(serverHandler.getTransportError());
92  		assertEquals(1, serverHandler.getReceivedMessages().size());
93  		assertEquals(PongMessage.class, serverHandler.getReceivedMessages().get(0).getClass());
94  	}
95  
96  
97  	@Configuration
98  	@EnableWebSocket
99  	static class TestConfig implements WebSocketConfigurer {
100 
101 		@Autowired
102 		private DefaultHandshakeHandler handshakeHandler;  // can't rely on classpath for server detection
103 
104 		@Override
105 		public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
106 			this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz");
107 			registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler);
108 		}
109 
110 		@Bean
111 		public TestWebSocketHandler handler() {
112 			return new TestWebSocketHandler();
113 		}
114 
115 	}
116 
117 	private static class TestWebSocketHandler extends AbstractWebSocketHandler {
118 
119 		private List<WebSocketMessage> receivedMessages = new ArrayList<>();
120 
121 		private int waitMessageCount;
122 
123 		private final CountDownLatch latch = new CountDownLatch(1);
124 
125 		private Throwable transportError;
126 
127 
128 		public void setWaitMessageCount(int waitMessageCount) {
129 			this.waitMessageCount = waitMessageCount;
130 		}
131 
132 		public List<WebSocketMessage> getReceivedMessages() {
133 			return this.receivedMessages;
134 		}
135 
136 		public Throwable getTransportError() {
137 			return this.transportError;
138 		}
139 
140 		@Override
141 		public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
142 			this.receivedMessages.add(message);
143 			if (this.receivedMessages.size() >= this.waitMessageCount) {
144 				this.latch.countDown();
145 			}
146 		}
147 
148 		@Override
149 		public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
150 			this.transportError = exception;
151 			this.latch.countDown();
152 		}
153 
154 		public void await() throws InterruptedException {
155 			this.latch.await(5, TimeUnit.SECONDS);
156 		}
157 	}
158 
159 }