1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.sockjs.client;
18
19 import java.io.IOException;
20 import java.net.URI;
21 import java.security.Principal;
22 import java.util.Map;
23 import java.util.concurrent.ConcurrentHashMap;
24
25 import org.apache.commons.logging.Log;
26 import org.apache.commons.logging.LogFactory;
27
28 import org.springframework.http.HttpHeaders;
29 import org.springframework.util.Assert;
30 import org.springframework.util.concurrent.SettableListenableFuture;
31 import org.springframework.web.socket.CloseStatus;
32 import org.springframework.web.socket.TextMessage;
33 import org.springframework.web.socket.WebSocketHandler;
34 import org.springframework.web.socket.WebSocketMessage;
35 import org.springframework.web.socket.WebSocketSession;
36 import org.springframework.web.socket.sockjs.frame.SockJsFrame;
37 import org.springframework.web.socket.sockjs.frame.SockJsFrameType;
38 import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
39
40
41
42
43
44
45
46
47
48
49 public abstract class AbstractClientSockJsSession implements WebSocketSession {
50
51 protected final Log logger = LogFactory.getLog(getClass());
52
53
54 private final TransportRequest request;
55
56 private final WebSocketHandler webSocketHandler;
57
58 private final SettableListenableFuture<WebSocketSession> connectFuture;
59
60
61 private final Map<String, Object> attributes = new ConcurrentHashMap<String, Object>();
62
63 private volatile State state = State.NEW;
64
65 private volatile CloseStatus closeStatus;
66
67
68 protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler,
69 SettableListenableFuture<WebSocketSession> connectFuture) {
70
71 Assert.notNull(request, "'request' is required");
72 Assert.notNull(handler, "'handler' is required");
73 Assert.notNull(connectFuture, "'connectFuture' is required");
74 this.request = request;
75 this.webSocketHandler = handler;
76 this.connectFuture = connectFuture;
77 }
78
79
80 @Override
81 public String getId() {
82 return this.request.getSockJsUrlInfo().getSessionId();
83 }
84
85 @Override
86 public URI getUri() {
87 return this.request.getSockJsUrlInfo().getSockJsUrl();
88 }
89
90 @Override
91 public HttpHeaders getHandshakeHeaders() {
92 return this.request.getHandshakeHeaders();
93 }
94
95 @Override
96 public Map<String, Object> getAttributes() {
97 return this.attributes;
98 }
99
100 @Override
101 public Principal getPrincipal() {
102 return this.request.getUser();
103 }
104
105 public SockJsMessageCodec getMessageCodec() {
106 return this.request.getMessageCodec();
107 }
108
109 public WebSocketHandler getWebSocketHandler() {
110 return this.webSocketHandler;
111 }
112
113
114
115
116
117
118
119 Runnable getTimeoutTask() {
120 return new Runnable() {
121 @Override
122 public void run() {
123 closeInternal(new CloseStatus(2007, "Transport timed out"));
124 }
125 };
126 }
127
128 @Override
129 public boolean isOpen() {
130 return State.OPEN.equals(this.state);
131 }
132
133 public boolean isDisconnected() {
134 return (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state));
135 }
136
137 @Override
138 public final void sendMessage(WebSocketMessage<?> message) throws IOException {
139 Assert.state(State.OPEN.equals(this.state), this + " is not open, current state=" + this.state);
140 Assert.isInstanceOf(TextMessage.class, message, this + " supports text messages only.");
141 String payload = ((TextMessage) message).getPayload();
142 payload = getMessageCodec().encode(new String[] { payload });
143 payload = payload.substring(1);
144 message = new TextMessage(payload);
145 if (logger.isTraceEnabled()) {
146 logger.trace("Sending message " + message + " in " + this);
147 }
148 sendInternal((TextMessage) message);
149 }
150
151 protected abstract void sendInternal(TextMessage textMessage) throws IOException;
152
153 @Override
154 public final void close() throws IOException {
155 close(CloseStatus.NORMAL);
156 }
157
158 @Override
159 public final void close(CloseStatus status) {
160 Assert.isTrue(status != null && isUserSetStatus(status), "Invalid close status: " + status);
161 if (logger.isDebugEnabled()) {
162 logger.debug("Closing session with " + status + " in " + this);
163 }
164 closeInternal(status);
165 }
166
167 private boolean isUserSetStatus(CloseStatus status) {
168 return (status.getCode() == 1000 || (status.getCode() >= 3000 && status.getCode() <= 4999));
169 }
170
171 protected void closeInternal(CloseStatus status) {
172 if (this.state == null) {
173 logger.warn("Ignoring close since connect() was never invoked");
174 return;
175 }
176 if (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)) {
177 logger.debug("Ignoring close (already closing or closed), current state=" + this.state);
178 return;
179 }
180 this.state = State.CLOSING;
181 this.closeStatus = status;
182 try {
183 disconnect(status);
184 }
185 catch (Throwable ex) {
186 if (logger.isErrorEnabled()) {
187 logger.error("Failed to close " + this, ex);
188 }
189 }
190 }
191
192 protected abstract void disconnect(CloseStatus status) throws IOException;
193
194 public void handleFrame(String payload) {
195 SockJsFrame frame = new SockJsFrame(payload);
196 if (SockJsFrameType.OPEN.equals(frame.getType())) {
197 handleOpenFrame();
198 }
199 else if (SockJsFrameType.MESSAGE.equals(frame.getType())) {
200 handleMessageFrame(frame);
201 }
202 else if (SockJsFrameType.CLOSE.equals(frame.getType())) {
203 handleCloseFrame(frame);
204 }
205 else if (SockJsFrameType.HEARTBEAT.equals(frame.getType())) {
206 if (logger.isTraceEnabled()) {
207 logger.trace("Received heartbeat in " + this);
208 }
209 }
210 else {
211
212 throw new IllegalStateException("Unknown SockJS frame type " + frame + " in " + this);
213 }
214 }
215
216 private void handleOpenFrame() {
217 if (logger.isDebugEnabled()) {
218 logger.debug("Processing SockJS open frame in " + this);
219 }
220 if (State.NEW.equals(state)) {
221 this.state = State.OPEN;
222 try {
223 this.webSocketHandler.afterConnectionEstablished(this);
224 this.connectFuture.set(this);
225 }
226 catch (Throwable ex) {
227 if (logger.isErrorEnabled()) {
228 Class<?> type = this.webSocketHandler.getClass();
229 logger.error(type + ".afterConnectionEstablished threw exception in " + this, ex);
230 }
231 }
232 }
233 else {
234 if (logger.isDebugEnabled()) {
235 logger.debug("Open frame received in " + getId() + " but we're not" +
236 "connecting (current state=" + this.state + "). The server might " +
237 "have been restarted and lost track of the session.");
238 }
239 closeInternal(new CloseStatus(1006, "Server lost session"));
240 }
241 }
242
243 private void handleMessageFrame(SockJsFrame frame) {
244 if (!isOpen()) {
245 if (logger.isErrorEnabled()) {
246 logger.error("Ignoring received message due to state=" + this.state + " in " + this);
247 }
248 return;
249 }
250 String[] messages;
251 try {
252 messages = getMessageCodec().decode(frame.getFrameData());
253 }
254 catch (IOException ex) {
255 if (logger.isErrorEnabled()) {
256 logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex);
257 }
258 closeInternal(CloseStatus.BAD_DATA);
259 return;
260 }
261 if (logger.isTraceEnabled()) {
262 logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this);
263 }
264 for (String message : messages) {
265 try {
266 if (isOpen()) {
267 this.webSocketHandler.handleMessage(this, new TextMessage(message));
268 }
269 }
270 catch (Throwable ex) {
271 Class<?> type = this.webSocketHandler.getClass();
272 logger.error(type + ".handleMessage threw an exception on " + frame + " in " + this, ex);
273 }
274 }
275 }
276
277 private void handleCloseFrame(SockJsFrame frame) {
278 CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE;
279 try {
280 String[] data = getMessageCodec().decode(frame.getFrameData());
281 if (data.length == 2) {
282 closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]);
283 }
284 if (logger.isDebugEnabled()) {
285 logger.debug("Processing SockJS close frame with " + closeStatus + " in " + this);
286 }
287 }
288 catch (IOException ex) {
289 if (logger.isErrorEnabled()) {
290 logger.error("Failed to decode data for " + frame + " in " + this, ex);
291 }
292 }
293 closeInternal(closeStatus);
294 }
295
296 public void handleTransportError(Throwable error) {
297 try {
298 if (logger.isErrorEnabled()) {
299 logger.error("Transport error in " + this, error);
300 }
301 this.webSocketHandler.handleTransportError(this, error);
302 }
303 catch (Exception ex) {
304 Class<?> type = this.webSocketHandler.getClass();
305 if (logger.isErrorEnabled()) {
306 logger.error(type + ".handleTransportError threw an exception", ex);
307 }
308 }
309 }
310
311 public void afterTransportClosed(CloseStatus closeStatus) {
312 this.closeStatus = (this.closeStatus != null ? this.closeStatus : closeStatus);
313 Assert.state(this.closeStatus != null, "CloseStatus not available");
314
315 if (logger.isDebugEnabled()) {
316 logger.debug("Transport closed with " + this.closeStatus + " in " + this);
317 }
318
319 this.state = State.CLOSED;
320 try {
321 this.webSocketHandler.afterConnectionClosed(this, this.closeStatus);
322 }
323 catch (Exception ex) {
324 if (logger.isErrorEnabled()) {
325 Class<?> type = this.webSocketHandler.getClass();
326 logger.error(type + ".afterConnectionClosed threw an exception", ex);
327 }
328 }
329 }
330
331 @Override
332 public String toString() {
333 return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]";
334 }
335
336
337 private enum State { NEW, OPEN, CLOSING, CLOSED }
338
339 }