1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket;
18
19
20 import static org.junit.Assert.*;
21
22 import java.net.URI;
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.List;
26 import java.util.concurrent.CountDownLatch;
27 import java.util.concurrent.TimeUnit;
28
29 import org.junit.Test;
30 import org.junit.runner.RunWith;
31 import org.junit.runners.Parameterized;
32
33 import org.springframework.beans.factory.annotation.Autowired;
34 import org.springframework.context.annotation.Bean;
35 import org.springframework.context.annotation.Configuration;
36 import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
37 import org.springframework.web.socket.client.standard.StandardWebSocketClient;
38 import org.springframework.web.socket.config.annotation.EnableWebSocket;
39 import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
40 import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
41 import org.springframework.web.socket.handler.AbstractWebSocketHandler;
42 import org.springframework.web.socket.handler.TextWebSocketHandler;
43 import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
44
45
46
47
48
49
50
51 @RunWith(Parameterized.class)
52 public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
53
54 @Parameterized.Parameters
55 public static Iterable<Object[]> arguments() {
56 return Arrays.asList(new Object[][] {
57 {new JettyWebSocketTestServer(), new JettyWebSocketClient()},
58 {new TomcatWebSocketTestServer(), new StandardWebSocketClient()},
59 {new UndertowTestServer(), new JettyWebSocketClient()}
60 });
61 }
62
63
64 @Override
65 protected Class<?>[] getAnnotatedConfigClasses() {
66 return new Class<?>[] { TestConfig.class };
67 }
68
69 @Test
70 public void subProtocolNegotiation() throws Exception {
71 WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
72 headers.setSecWebSocketProtocol("foo");
73 URI url = new URI(getWsBaseUrl() + "/ws");
74 WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get();
75 assertEquals("foo", session.getAcceptedProtocol());
76 session.close();
77 }
78
79
80
81 @Test
82 public void unsolicitedPongWithEmptyPayload() throws Exception {
83 TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
84 serverHandler.setWaitMessageCount(1);
85
86 String url = getWsBaseUrl() + "/ws";
87 WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get();
88 session.sendMessage(new PongMessage());
89
90 serverHandler.await();
91 assertNull(serverHandler.getTransportError());
92 assertEquals(1, serverHandler.getReceivedMessages().size());
93 assertEquals(PongMessage.class, serverHandler.getReceivedMessages().get(0).getClass());
94 }
95
96
97 @Configuration
98 @EnableWebSocket
99 static class TestConfig implements WebSocketConfigurer {
100
101 @Autowired
102 private DefaultHandshakeHandler handshakeHandler;
103
104 @Override
105 public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
106 this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz");
107 registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler);
108 }
109
110 @Bean
111 public TestWebSocketHandler handler() {
112 return new TestWebSocketHandler();
113 }
114
115 }
116
117 private static class TestWebSocketHandler extends AbstractWebSocketHandler {
118
119 private List<WebSocketMessage> receivedMessages = new ArrayList<>();
120
121 private int waitMessageCount;
122
123 private final CountDownLatch latch = new CountDownLatch(1);
124
125 private Throwable transportError;
126
127
128 public void setWaitMessageCount(int waitMessageCount) {
129 this.waitMessageCount = waitMessageCount;
130 }
131
132 public List<WebSocketMessage> getReceivedMessages() {
133 return this.receivedMessages;
134 }
135
136 public Throwable getTransportError() {
137 return this.transportError;
138 }
139
140 @Override
141 public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
142 this.receivedMessages.add(message);
143 if (this.receivedMessages.size() >= this.waitMessageCount) {
144 this.latch.countDown();
145 }
146 }
147
148 @Override
149 public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
150 this.transportError = exception;
151 this.latch.countDown();
152 }
153
154 public void await() throws InterruptedException {
155 this.latch.await(5, TimeUnit.SECONDS);
156 }
157 }
158
159 }