/*
 * Copyright 2002-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16  
17  package org.springframework.web.socket.messaging;
18  
19  import java.io.IOException;
20  import java.nio.ByteBuffer;
21  import java.security.Principal;
22  import java.util.Arrays;
23  import java.util.List;
24  import java.util.Map;
25  import java.util.Set;
26  import java.util.concurrent.ConcurrentHashMap;
27  import java.util.concurrent.atomic.AtomicInteger;
28  
29  import org.apache.commons.logging.Log;
30  import org.apache.commons.logging.LogFactory;
31  
32  import org.springframework.context.ApplicationEvent;
33  import org.springframework.context.ApplicationEventPublisher;
34  import org.springframework.context.ApplicationEventPublisherAware;
35  import org.springframework.messaging.Message;
36  import org.springframework.messaging.MessageChannel;
37  import org.springframework.messaging.simp.SimpAttributes;
38  import org.springframework.messaging.simp.SimpAttributesContextHolder;
39  import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
40  import org.springframework.messaging.simp.SimpMessageType;
41  import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
42  import org.springframework.messaging.simp.stomp.StompCommand;
43  import org.springframework.messaging.simp.stomp.StompDecoder;
44  import org.springframework.messaging.simp.stomp.StompEncoder;
45  import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
46  import org.springframework.messaging.simp.user.DestinationUserNameProvider;
47  import org.springframework.messaging.simp.user.UserSessionRegistry;
48  import org.springframework.messaging.support.AbstractMessageChannel;
49  import org.springframework.messaging.support.ChannelInterceptor;
50  import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
51  import org.springframework.messaging.support.MessageBuilder;
52  import org.springframework.messaging.support.MessageHeaderAccessor;
53  import org.springframework.messaging.support.MessageHeaderInitializer;
54  import org.springframework.util.Assert;
55  import org.springframework.util.MimeTypeUtils;
56  import org.springframework.web.socket.BinaryMessage;
57  import org.springframework.web.socket.CloseStatus;
58  import org.springframework.web.socket.TextMessage;
59  import org.springframework.web.socket.WebSocketMessage;
60  import org.springframework.web.socket.WebSocketSession;
61  import org.springframework.web.socket.handler.SessionLimitExceededException;
62  import org.springframework.web.socket.handler.WebSocketSessionDecorator;
63  import org.springframework.web.socket.sockjs.transport.SockJsSession;
64  
65  /**
66   * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2
67   * of the STOMP specification.
68   *
69   * @author Rossen Stoyanchev
70   * @author Andy Wilkinson
71   * @since 4.0
72   */
73  public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
74  
75  	/**
76  	 * This handler supports assembling large STOMP messages split into multiple
77  	 * WebSocket messages and STOMP clients (like stomp.js) indeed split large STOMP
78  	 * messages at 16K boundaries. Therefore the WebSocket server input message
79  	 * buffer size must allow 16K at least plus a little extra for SockJS framing.
80  	 */
81  	public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16 * 1024 + 256;
82  
83  	/**
84  	 * The name of the header set on the CONNECTED frame indicating the name
85  	 * of the user authenticated on the WebSocket session.
86  	 */
87  	public static final String CONNECTED_USER_HEADER = "user-name";
88  
89  	private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);
90  
91  	private static final byte[] EMPTY_PAYLOAD = new byte[0];
92  
93  
94  	private int messageSizeLimit = 64 * 1024;
95  
96  	private UserSessionRegistry userSessionRegistry;
97  
98  	private final StompEncoder stompEncoder = new StompEncoder();
99  
100 	private final StompDecoder stompDecoder = new StompDecoder();
101 
102 	private final Map<String, BufferingStompDecoder> decoders = new ConcurrentHashMap<String, BufferingStompDecoder>();
103 
104 	private MessageHeaderInitializer headerInitializer;
105 
106 	private Boolean immutableMessageInterceptorPresent;
107 
108 	private ApplicationEventPublisher eventPublisher;
109 
110 	private final Stats stats = new Stats();
111 
112 
113 	/**
114 	 * Configure the maximum size allowed for an incoming STOMP message.
115 	 * Since a STOMP message can be received in multiple WebSocket messages,
116 	 * buffering may be required and therefore it is necessary to know the maximum
117 	 * allowed message size.
118 	 *
119 	 * <p>By default this property is set to 64K.
120 	 *
121 	 * @since 4.0.3
122 	 */
123 	public void setMessageSizeLimit(int messageSizeLimit) {
124 		this.messageSizeLimit = messageSizeLimit;
125 	}
126 
127 	/**
128 	 * Get the configured message buffer size limit in bytes.
129 	 *
130 	 * @since 4.0.3
131 	 */
132 	public int getMessageSizeLimit() {
133 		return this.messageSizeLimit;
134 	}
135 
136 	/**
137 	 * Provide a registry with which to register active user session ids.
138 	 * @see org.springframework.messaging.simp.user.UserDestinationMessageHandler
139 	 */
140 	public void setUserSessionRegistry(UserSessionRegistry registry) {
141 		this.userSessionRegistry = registry;
142 	}
143 
144 	/**
145 	 * @return the configured UserSessionRegistry.
146 	 */
147 	public UserSessionRegistry getUserSessionRegistry() {
148 		return this.userSessionRegistry;
149 	}
150 
151 	/**
152 	 * Configure a {@link MessageHeaderInitializer} to apply to the headers of all
153 	 * messages created from decoded STOMP frames and other messages sent to the
154 	 * client inbound channel.
155 	 *
156 	 * <p>By default this property is not set.
157 	 */
158 	public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
159 		this.headerInitializer = headerInitializer;
160 		this.stompDecoder.setHeaderInitializer(headerInitializer);
161 	}
162 
163 	/**
164 	 * @return the configured header initializer.
165 	 */
166 	public MessageHeaderInitializer getHeaderInitializer() {
167 		return this.headerInitializer;
168 	}
169 
170 	@Override
171 	public List<String> getSupportedProtocols() {
172 		return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
173 	}
174 
175 	@Override
176 	public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
177 		this.eventPublisher = applicationEventPublisher;
178 	}
179 
180 	/**
181 	 * Return a String describing internal state and counters.
182 	 */
183 	public String getStatsInfo() {
184 		return this.stats.toString();
185 	}
186 
187 
188 	/**
189 	 * Handle incoming WebSocket messages from clients.
190 	 */
191 	public void handleMessageFromClient(WebSocketSession session,
192 			WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {
193 
194 		List<Message<byte[]>> messages;
195 		try {
196 			ByteBuffer byteBuffer;
197 			if (webSocketMessage instanceof TextMessage) {
198 				byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
199 			}
200 			else if (webSocketMessage instanceof BinaryMessage) {
201 				byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
202 			}
203 			else {
204 				return;
205 			}
206 
207 			BufferingStompDecoder decoder = this.decoders.get(session.getId());
208 			if (decoder == null) {
209 				throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
210 			}
211 
212 			messages = decoder.decode(byteBuffer);
213 			if (messages.isEmpty()) {
214 				if (logger.isTraceEnabled()) {
215 					logger.trace("Incomplete STOMP frame content received in session " +
216 							session + ", bufferSize=" + decoder.getBufferSize() +
217 							", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
218 				}
219 				return;
220 			}
221 		}
222 		catch (Throwable ex) {
223 			if (logger.isErrorEnabled()) {
224 				logger.error("Failed to parse " + webSocketMessage +
225 						" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
226 			}
227 			sendErrorMessage(session, ex);
228 			return;
229 		}
230 
231 		for (Message<byte[]> message : messages) {
232 			try {
233 				StompHeaderAccessor headerAccessor =
234 						MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
235 
236 				if (logger.isTraceEnabled()) {
237 					logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
238 				}
239 
240 				headerAccessor.setSessionId(session.getId());
241 				headerAccessor.setSessionAttributes(session.getAttributes());
242 				headerAccessor.setUser(session.getPrincipal());
243 				if (!detectImmutableMessageInterceptor(outputChannel)) {
244 					headerAccessor.setImmutable();
245 				}
246 
247 				if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
248 					this.stats.incrementConnectCount();
249 				}
250 				else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
251 					this.stats.incrementDisconnectCount();
252 				}
253 
254 				try {
255 					SimpAttributesContextHolder.setAttributesFromMessage(message);
256 					if (this.eventPublisher != null) {
257 						if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
258 							publishEvent(new SessionConnectEvent(this, message));
259 						}
260 						else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
261 							publishEvent(new SessionSubscribeEvent(this, message));
262 						}
263 						else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
264 							publishEvent(new SessionUnsubscribeEvent(this, message));
265 						}
266 					}
267 					outputChannel.send(message);
268 				}
269 				finally {
270 					SimpAttributesContextHolder.resetAttributes();
271 				}
272 			}
273 			catch (Throwable ex) {
274 				logger.error("Failed to send client message to application via MessageChannel" +
275 						" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
276 				sendErrorMessage(session, ex);
277 
278 			}
279 		}
280 	}
281 
282 	private boolean detectImmutableMessageInterceptor(MessageChannel channel) {
283 		if (this.immutableMessageInterceptorPresent != null) {
284 			return this.immutableMessageInterceptorPresent;
285 		}
286 		if (channel instanceof AbstractMessageChannel) {
287 			for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) {
288 				if (interceptor instanceof ImmutableMessageChannelInterceptor) {
289 					this.immutableMessageInterceptorPresent = true;
290 					return true;
291 				}
292 			}
293 		}
294 		this.immutableMessageInterceptorPresent = false;
295 		return false;
296 	}
297 
298 	private void publishEvent(ApplicationEvent event) {
299 		try {
300 			this.eventPublisher.publishEvent(event);
301 		}
302 		catch (Throwable ex) {
303 			logger.error("Error publishing " + event + ".", ex);
304 		}
305 	}
306 
307 	protected void sendErrorMessage(WebSocketSession session, Throwable error) {
308 		StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
309 		headerAccessor.setMessage(error.getMessage());
310 		byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
311 		try {
312 			session.sendMessage(new TextMessage(bytes));
313 		}
314 		catch (Throwable ex) {
315 			// Could be part of normal workflow (e.g. browser tab closed)
316 			logger.debug("Failed to send STOMP ERROR to client.", ex);
317 		}
318 	}
319 
320 	/**
321 	 * Handle STOMP messages going back out to WebSocket clients.
322 	 */
323 	@SuppressWarnings("unchecked")
324 	@Override
325 	public void handleMessageToClient(WebSocketSession session, Message<?> message) {
326 		if (!(message.getPayload() instanceof byte[])) {
327 			logger.error("Expected byte[] payload. Ignoring " + message + ".");
328 			return;
329 		}
330 		StompHeaderAccessor stompAccessor = getStompHeaderAccessor(message);
331 		StompCommand command = stompAccessor.getCommand();
332 		if (StompCommand.MESSAGE.equals(command)) {
333 			if (stompAccessor.getSubscriptionId() == null) {
334 				logger.warn("No STOMP \"subscription\" header in " + message);
335 			}
336 			String origDestination = stompAccessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
337 			if (origDestination != null) {
338 				stompAccessor = toMutableAccessor(stompAccessor, message);
339 				stompAccessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
340 				stompAccessor.setDestination(origDestination);
341 			}
342 		}
343 		else if (StompCommand.CONNECTED.equals(command)) {
344 			this.stats.incrementConnectedCount();
345 			stompAccessor = afterStompSessionConnected(message, stompAccessor, session);
346 			if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) {
347 				try {
348 					SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
349 					SimpAttributesContextHolder.setAttributes(simpAttributes);
350 					publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
351 				}
352 				finally {
353 					SimpAttributesContextHolder.resetAttributes();
354 				}
355 			}
356 		}
357 		try {
358 			byte[] payload = (byte[]) message.getPayload();
359 			byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload);
360 
361 			boolean useBinary = (payload.length > 0 && !(session instanceof SockJsSession) &&
362 					MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType()));
363 
364 			if (useBinary) {
365 				session.sendMessage(new BinaryMessage(bytes));
366 			}
367 			else {
368 				session.sendMessage(new TextMessage(bytes));
369 			}
370 		}
371 		catch (SessionLimitExceededException ex) {
372 			// Bad session, just get out
373 			throw ex;
374 		}
375 		catch (Throwable ex) {
376 			// Could be part of normal workflow (e.g. browser tab closed)
377 			logger.debug("Failed to send WebSocket message to client in session " + session.getId() + ".", ex);
378 			command = StompCommand.ERROR;
379 		}
380 		finally {
381 			if (StompCommand.ERROR.equals(command)) {
382 				try {
383 					session.close(CloseStatus.PROTOCOL_ERROR);
384 				}
385 				catch (IOException ex) {
386 					// Ignore
387 				}
388 			}
389 		}
390 	}
391 
392 	private  StompHeaderAccessor getStompHeaderAccessor(Message<?> message) {
393 		MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
394 		if (accessor == null) {
395 			// Shouldn't happen (only broker broadcasts directly to clients)
396 			throw new IllegalStateException("No header accessor in " + message + ".");
397 		}
398 		StompHeaderAccessor stompAccessor;
399 		if (accessor instanceof StompHeaderAccessor) {
400 			stompAccessor = (StompHeaderAccessor) accessor;
401 		}
402 		else if (accessor instanceof SimpMessageHeaderAccessor) {
403 			stompAccessor = StompHeaderAccessor.wrap(message);
404 			if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
405 				stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
406 			}
407 			else if (SimpMessageType.DISCONNECT_ACK.equals(stompAccessor.getMessageType())) {
408 				stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
409 				stompAccessor.setMessage("Session closed.");
410 			}
411 			else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
412 				stompAccessor.updateStompCommandAsServerMessage();
413 			}
414 		}
415 		else {
416 			// Shouldn't happen (only broker broadcasts directly to clients)
417 			throw new IllegalStateException(
418 					"Unexpected header accessor type: " + accessor.getClass() + " in " + message + ".");
419 		}
420 		return stompAccessor;
421 	}
422 
423 	/**
424 	 * The simple broker produces {@code SimpMessageType.CONNECT_ACK} that's not STOMP
425 	 * specific and needs to be turned into a STOMP CONNECTED frame.
426 	 */
427 	private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor connectAckHeaders) {
428 		String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
429 		Message<?> message = (Message<?>) connectAckHeaders.getHeader(name);
430 		Assert.notNull(message, "Original STOMP CONNECT not found in " + connectAckHeaders);
431 		StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
432 		String version;
433 		Set<String> acceptVersions = connectHeaders.getAcceptVersion();
434 		if (acceptVersions.contains("1.2")) {
435 			version = "1.2";
436 		}
437 		else if (acceptVersions.contains("1.1")) {
438 			version = "1.1";
439 		}
440 		else if (acceptVersions.isEmpty()) {
441 			version = null;
442 		}
443 		else {
444 			throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'");
445 		}
446 		StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
447 		connectedHeaders.setVersion(version);
448 		connectedHeaders.setHeartbeat(0, 0); // not supported
449 		return connectedHeaders;
450 	}
451 
452 	protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
453 		return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
454 	}
455 
456 	private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
457 			WebSocketSession session) {
458 
459 		Principal principal = session.getPrincipal();
460 		if (principal != null) {
461 			accessor = toMutableAccessor(accessor, message);
462 			accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
463 			if (this.userSessionRegistry != null) {
464 				String userName = getSessionRegistryUserName(principal);
465 				this.userSessionRegistry.registerSessionId(userName, session.getId());
466 			}
467 		}
468 		long[] heartbeat = accessor.getHeartbeat();
469 		if (heartbeat[1] > 0) {
470 			session = WebSocketSessionDecorator.unwrap(session);
471 			if (session instanceof SockJsSession) {
472 				((SockJsSession) session).disableHeartbeat();
473 			}
474 		}
475 		return accessor;
476 	}
477 
478 	private String getSessionRegistryUserName(Principal principal) {
479 		String userName = principal.getName();
480 		if (principal instanceof DestinationUserNameProvider) {
481 			userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
482 		}
483 		return userName;
484 	}
485 
486 	@Override
487 	public String resolveSessionId(Message<?> message) {
488 		return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
489 	}
490 
491 	@Override
492 	public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
493 		if (session.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
494 			session.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
495 		}
496 		this.decoders.put(session.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit()));
497 	}
498 
499 	@Override
500 	public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
501 		this.decoders.remove(session.getId());
502 		Principal principal = session.getPrincipal();
503 		if (principal != null && this.userSessionRegistry != null) {
504 			String userName = getSessionRegistryUserName(principal);
505 			this.userSessionRegistry.unregisterSessionId(userName, session.getId());
506 		}
507 		Message<byte[]> message = createDisconnectMessage(session);
508 		SimpAttributes simpAttributes = SimpAttributes.fromMessage(message);
509 		try {
510 			SimpAttributesContextHolder.setAttributes(simpAttributes);
511 			if (this.eventPublisher != null) {
512 				publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus));
513 			}
514 			outputChannel.send(message);
515 		}
516 		finally {
517 			SimpAttributesContextHolder.resetAttributes();
518 			simpAttributes.sessionCompleted();
519 		}
520 	}
521 
522 	private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
523 		StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
524 		if (getHeaderInitializer() != null) {
525 			getHeaderInitializer().initHeaders(headerAccessor);
526 		}
527 		headerAccessor.setSessionId(session.getId());
528 		headerAccessor.setSessionAttributes(session.getAttributes());
529 		headerAccessor.setUser(session.getPrincipal());
530 		return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
531 	}
532 
533 	@Override
534 	public String toString() {
535 		return "StompSubProtocolHandler" + getSupportedProtocols();
536 	}
537 
538 	private class Stats {
539 
540 		private final AtomicInteger connect = new AtomicInteger();
541 
542 		private final AtomicInteger connected = new AtomicInteger();
543 
544 		private final AtomicInteger disconnect = new AtomicInteger();
545 
546 
547 		public void incrementConnectCount() {
548 			this.connect.incrementAndGet();
549 		}
550 
551 		public void incrementConnectedCount() {
552 			this.connected.incrementAndGet();
553 		}
554 
555 		public void incrementDisconnectCount() {
556 			this.disconnect.incrementAndGet();
557 		}
558 
559 
560 		public String toString() {
561 			return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" +
562 					this.connected.get() + ")-DISCONNECT(" + this.disconnect.get() + ")";
563 		}
564 	}
565 
566 }