1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.web.socket.messaging;
18
19 import java.io.IOException;
20 import java.nio.ByteBuffer;
21 import java.security.Principal;
22 import java.util.Arrays;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.Set;
26 import java.util.concurrent.ConcurrentHashMap;
27 import java.util.concurrent.atomic.AtomicInteger;
28
29 import org.apache.commons.logging.Log;
30 import org.apache.commons.logging.LogFactory;
31
32 import org.springframework.context.ApplicationEvent;
33 import org.springframework.context.ApplicationEventPublisher;
34 import org.springframework.context.ApplicationEventPublisherAware;
35 import org.springframework.messaging.Message;
36 import org.springframework.messaging.MessageChannel;
37 import org.springframework.messaging.simp.SimpAttributes;
38 import org.springframework.messaging.simp.SimpAttributesContextHolder;
39 import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
40 import org.springframework.messaging.simp.SimpMessageType;
41 import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
42 import org.springframework.messaging.simp.stomp.StompCommand;
43 import org.springframework.messaging.simp.stomp.StompDecoder;
44 import org.springframework.messaging.simp.stomp.StompEncoder;
45 import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
46 import org.springframework.messaging.simp.user.DestinationUserNameProvider;
47 import org.springframework.messaging.simp.user.UserSessionRegistry;
48 import org.springframework.messaging.support.AbstractMessageChannel;
49 import org.springframework.messaging.support.ChannelInterceptor;
50 import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
51 import org.springframework.messaging.support.MessageBuilder;
52 import org.springframework.messaging.support.MessageHeaderAccessor;
53 import org.springframework.messaging.support.MessageHeaderInitializer;
54 import org.springframework.util.Assert;
55 import org.springframework.util.MimeTypeUtils;
56 import org.springframework.web.socket.BinaryMessage;
57 import org.springframework.web.socket.CloseStatus;
58 import org.springframework.web.socket.TextMessage;
59 import org.springframework.web.socket.WebSocketMessage;
60 import org.springframework.web.socket.WebSocketSession;
61 import org.springframework.web.socket.handler.SessionLimitExceededException;
62 import org.springframework.web.socket.handler.WebSocketSessionDecorator;
63 import org.springframework.web.socket.sockjs.transport.SockJsSession;
64
65
66
67
68
69
70
71
72
73 public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
74
75
76
77
78
79
80
81 public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16 * 1024 + 256;
82
83
84
85
86
87 public static final String CONNECTED_USER_HEADER = "user-name";
88
89 private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);
90
91 private static final byte[] EMPTY_PAYLOAD = new byte[0];
92
93
94 private int messageSizeLimit = 64 * 1024;
95
96 private UserSessionRegistry userSessionRegistry;
97
98 private final StompEncoder stompEncoder = new StompEncoder();
99
100 private final StompDecoder stompDecoder = new StompDecoder();
101
102 private final Map<String, BufferingStompDecoder> decoders = new ConcurrentHashMap<String, BufferingStompDecoder>();
103
104 private MessageHeaderInitializer headerInitializer;
105
106 private Boolean immutableMessageInterceptorPresent;
107
108 private ApplicationEventPublisher eventPublisher;
109
110 private final Stats stats = new Stats();
111
112
113
114
115
116
117
118
119
120
121
122
123 public void setMessageSizeLimit(int messageSizeLimit) {
124 this.messageSizeLimit = messageSizeLimit;
125 }
126
127
128
129
130
131
132 public int getMessageSizeLimit() {
133 return this.messageSizeLimit;
134 }
135
136
137
138
139
140 public void setUserSessionRegistry(UserSessionRegistry registry) {
141 this.userSessionRegistry = registry;
142 }
143
144
145
146
147 public UserSessionRegistry getUserSessionRegistry() {
148 return this.userSessionRegistry;
149 }
150
151
152
153
154
155
156
157
158 public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
159 this.headerInitializer = headerInitializer;
160 this.stompDecoder.setHeaderInitializer(headerInitializer);
161 }
162
163
164
165
166 public MessageHeaderInitializer getHeaderInitializer() {
167 return this.headerInitializer;
168 }
169
170 @Override
171 public List<String> getSupportedProtocols() {
172 return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
173 }
174
175 @Override
176 public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
177 this.eventPublisher = applicationEventPublisher;
178 }
179
180
181
182
183 public String getStatsInfo() {
184 return this.stats.toString();
185 }
186
187
188
189
190
191 public void handleMessageFromClient(WebSocketSession session,
192 WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {
193
194 List<Message<byte[]>> messages;
195 try {
196 ByteBuffer byteBuffer;
197 if (webSocketMessage instanceof TextMessage) {
198 byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
199 }
200 else if (webSocketMessage instanceof BinaryMessage) {
201 byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
202 }
203 else {
204 return;
205 }
206
207 BufferingStompDecoder decoder = this.decoders.get(session.getId());
208 if (decoder == null) {
209 throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
210 }
211
212 messages = decoder.decode(byteBuffer);
213 if (messages.isEmpty()) {
214 if (logger.isTraceEnabled()) {
215 logger.trace("Incomplete STOMP frame content received in session " +
216 session + ", bufferSize=" + decoder.getBufferSize() +
217 ", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
218 }
219 return;
220 }
221 }
222 catch (Throwable ex) {
223 if (logger.isErrorEnabled()) {
224 logger.error("Failed to parse " + webSocketMessage +
225 " in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
226 }
227 sendErrorMessage(session, ex);
228 return;
229 }
230
231 for (Message<byte[]> message : messages) {
232 try {
233 StompHeaderAccessor headerAccessor =
234 MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
235
236 if (logger.isTraceEnabled()) {
237 logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
238 }
239
240 headerAccessor.setSessionId(session.getId());
241 headerAccessor.setSessionAttributes(session.getAttributes());
242 headerAccessor.setUser(session.getPrincipal());
243 if (!detectImmutableMessageInterceptor(outputChannel)) {
244 headerAccessor.setImmutable();
245 }
246
247 if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
248 this.stats.incrementConnectCount();
249 }
250 else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
251 this.stats.incrementDisconnectCount();
252 }
253
254 try {
255 SimpAttributesContextHolder.setAttributesFromMessage(message);
256 if (this.eventPublisher != null) {
257 if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
258 publishEvent(new SessionConnectEvent(this, message));
259 }
260 else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
261 publishEvent(new SessionSubscribeEvent(this, message));
262 }
263 else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
264 publishEvent(new SessionUnsubscribeEvent(this, message));
265 }
266 }
267 outputChannel.send(message);
268 }
269 finally {
270 SimpAttributesContextHolder.resetAttributes();
271 }
272 }
273 catch (Throwable ex) {
274 logger.error("Failed to send client message to application via MessageChannel" +
275 " in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
276 sendErrorMessage(session, ex);
277
278 }
279 }
280 }
281
282 private boolean detectImmutableMessageInterceptor(MessageChannel channel) {
283 if (this.immutableMessageInterceptorPresent != null) {
284 return this.immutableMessageInterceptorPresent;
285 }
286 if (channel instanceof AbstractMessageChannel) {
287 for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) {
288 if (interceptor instanceof ImmutableMessageChannelInterceptor) {
289 this.immutableMessageInterceptorPresent = true;
290 return true;
291 }
292 }
293 }
294 this.immutableMessageInterceptorPresent = false;
295 return false;
296 }
297
298 private void publishEvent(ApplicationEvent event) {
299 try {
300 this.eventPublisher.publishEvent(event);
301 }
302 catch (Throwable ex) {
303 logger.error("Error publishing " + event + ".", ex);
304 }
305 }
306
307 protected void sendErrorMessage(WebSocketSession session, Throwable error) {
308 StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
309 headerAccessor.setMessage(error.getMessage());
310 byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
311 try {
312 session.sendMessage(new TextMessage(bytes));
313 }
314 catch (Throwable ex) {
315
316 logger.debug("Failed to send STOMP ERROR to client.", ex);
317 }
318 }
319
320
321
322
323 @SuppressWarnings("unchecked")
324 @Override
325 public void handleMessageToClient(WebSocketSession session, Message<?> message) {
326 if (!(message.getPayload() instanceof byte[])) {
327 logger.error("Expected byte[] payload. Ignoring " + message + ".");
328 return;
329 }
330 StompHeaderAccessor stompAccessor = getStompHeaderAccessor(message);
331 StompCommand command = stompAccessor.getCommand();
332 if (StompCommand.MESSAGE.equals(command)) {
333 if (stompAccessor.getSubscriptionId() == null) {
334 logger.warn("No STOMP \"subscription\" header in " + message);
335 }
336 String origDestination = stompAccessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
337 if (origDestination != null) {
338 stompAccessor = toMutableAccessor(stompAccessor, message);
339 stompAccessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
340 stompAccessor.setDestination(origDestination);
341 }
342 }
343 else if (StompCommand.CONNECTED.equals(command)) {
344 this.stats.incrementConnectedCount();
345 stompAccessor = afterStompSessionConnected(message, stompAccessor, session);
346 if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) {
347 try {
348 SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
349 SimpAttributesContextHolder.setAttributes(simpAttributes);
350 publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
351 }
352 finally {
353 SimpAttributesContextHolder.resetAttributes();
354 }
355 }
356 }
357 try {
358 byte[] payload = (byte[]) message.getPayload();
359 byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload);
360
361 boolean useBinary = (payload.length > 0 && !(session instanceof SockJsSession) &&
362 MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType()));
363
364 if (useBinary) {
365 session.sendMessage(new BinaryMessage(bytes));
366 }
367 else {
368 session.sendMessage(new TextMessage(bytes));
369 }
370 }
371 catch (SessionLimitExceededException ex) {
372
373 throw ex;
374 }
375 catch (Throwable ex) {
376
377 logger.debug("Failed to send WebSocket message to client in session " + session.getId() + ".", ex);
378 command = StompCommand.ERROR;
379 }
380 finally {
381 if (StompCommand.ERROR.equals(command)) {
382 try {
383 session.close(CloseStatus.PROTOCOL_ERROR);
384 }
385 catch (IOException ex) {
386
387 }
388 }
389 }
390 }
391
392 private StompHeaderAccessor getStompHeaderAccessor(Message<?> message) {
393 MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
394 if (accessor == null) {
395
396 throw new IllegalStateException("No header accessor in " + message + ".");
397 }
398 StompHeaderAccessor stompAccessor;
399 if (accessor instanceof StompHeaderAccessor) {
400 stompAccessor = (StompHeaderAccessor) accessor;
401 }
402 else if (accessor instanceof SimpMessageHeaderAccessor) {
403 stompAccessor = StompHeaderAccessor.wrap(message);
404 if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
405 stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
406 }
407 else if (SimpMessageType.DISCONNECT_ACK.equals(stompAccessor.getMessageType())) {
408 stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
409 stompAccessor.setMessage("Session closed.");
410 }
411 else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
412 stompAccessor.updateStompCommandAsServerMessage();
413 }
414 }
415 else {
416
417 throw new IllegalStateException(
418 "Unexpected header accessor type: " + accessor.getClass() + " in " + message + ".");
419 }
420 return stompAccessor;
421 }
422
423
424
425
426
427 private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor connectAckHeaders) {
428 String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
429 Message<?> message = (Message<?>) connectAckHeaders.getHeader(name);
430 Assert.notNull(message, "Original STOMP CONNECT not found in " + connectAckHeaders);
431 StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
432 String version;
433 Set<String> acceptVersions = connectHeaders.getAcceptVersion();
434 if (acceptVersions.contains("1.2")) {
435 version = "1.2";
436 }
437 else if (acceptVersions.contains("1.1")) {
438 version = "1.1";
439 }
440 else if (acceptVersions.isEmpty()) {
441 version = null;
442 }
443 else {
444 throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'");
445 }
446 StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
447 connectedHeaders.setVersion(version);
448 connectedHeaders.setHeartbeat(0, 0);
449 return connectedHeaders;
450 }
451
452 protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
453 return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
454 }
455
456 private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
457 WebSocketSession session) {
458
459 Principal principal = session.getPrincipal();
460 if (principal != null) {
461 accessor = toMutableAccessor(accessor, message);
462 accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
463 if (this.userSessionRegistry != null) {
464 String userName = getSessionRegistryUserName(principal);
465 this.userSessionRegistry.registerSessionId(userName, session.getId());
466 }
467 }
468 long[] heartbeat = accessor.getHeartbeat();
469 if (heartbeat[1] > 0) {
470 session = WebSocketSessionDecorator.unwrap(session);
471 if (session instanceof SockJsSession) {
472 ((SockJsSession) session).disableHeartbeat();
473 }
474 }
475 return accessor;
476 }
477
478 private String getSessionRegistryUserName(Principal principal) {
479 String userName = principal.getName();
480 if (principal instanceof DestinationUserNameProvider) {
481 userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
482 }
483 return userName;
484 }
485
486 @Override
487 public String resolveSessionId(Message<?> message) {
488 return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
489 }
490
491 @Override
492 public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
493 if (session.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
494 session.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
495 }
496 this.decoders.put(session.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit()));
497 }
498
499 @Override
500 public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
501 this.decoders.remove(session.getId());
502 Principal principal = session.getPrincipal();
503 if (principal != null && this.userSessionRegistry != null) {
504 String userName = getSessionRegistryUserName(principal);
505 this.userSessionRegistry.unregisterSessionId(userName, session.getId());
506 }
507 Message<byte[]> message = createDisconnectMessage(session);
508 SimpAttributes simpAttributes = SimpAttributes.fromMessage(message);
509 try {
510 SimpAttributesContextHolder.setAttributes(simpAttributes);
511 if (this.eventPublisher != null) {
512 publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus));
513 }
514 outputChannel.send(message);
515 }
516 finally {
517 SimpAttributesContextHolder.resetAttributes();
518 simpAttributes.sessionCompleted();
519 }
520 }
521
522 private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
523 StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
524 if (getHeaderInitializer() != null) {
525 getHeaderInitializer().initHeaders(headerAccessor);
526 }
527 headerAccessor.setSessionId(session.getId());
528 headerAccessor.setSessionAttributes(session.getAttributes());
529 headerAccessor.setUser(session.getPrincipal());
530 return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
531 }
532
533 @Override
534 public String toString() {
535 return "StompSubProtocolHandler" + getSupportedProtocols();
536 }
537
538 private class Stats {
539
540 private final AtomicInteger connect = new AtomicInteger();
541
542 private final AtomicInteger connected = new AtomicInteger();
543
544 private final AtomicInteger disconnect = new AtomicInteger();
545
546
547 public void incrementConnectCount() {
548 this.connect.incrementAndGet();
549 }
550
551 public void incrementConnectedCount() {
552 this.connected.incrementAndGet();
553 }
554
555 public void incrementDisconnectCount() {
556 this.disconnect.incrementAndGet();
557 }
558
559
560 public String toString() {
561 return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" +
562 this.connected.get() + ")-DISCONNECT(" + this.disconnect.get() + ")";
563 }
564 }
565
566 }