View Javadoc
1   /*
2    * Copyright 2002-2014 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.messaging.simp.stomp;
18  
19  import java.io.ByteArrayOutputStream;
20  import java.nio.ByteBuffer;
21  import java.nio.charset.Charset;
22  import java.util.ArrayList;
23  import java.util.List;
24  
25  import org.apache.commons.logging.Log;
26  import org.apache.commons.logging.LogFactory;
27  
28  import org.springframework.messaging.Message;
29  import org.springframework.messaging.support.MessageBuilder;
30  import org.springframework.messaging.support.MessageHeaderInitializer;
31  import org.springframework.messaging.support.NativeMessageHeaderAccessor;
32  import org.springframework.util.InvalidMimeTypeException;
33  import org.springframework.util.MultiValueMap;
34  
35  /**
36   * Decodes one or more STOMP frames contained in a {@link ByteBuffer}.
37   *
38   * <p>An attempt is made to read all complete STOMP frames from the buffer, which
39   * could be zero, one, or more. If there is any left-over content, i.e. an incomplete
40   * STOMP frame, at the end the buffer is reset to point to the beginning of the
41   * partial content. The caller is then responsible for dealing with that
42   * incomplete content by buffering until there is more input available.
43   *
44   * @author Andy Wilkinson
45   * @author Rossen Stoyanchev
46   * @since 4.0
47   */
48  public class StompDecoder {
49  
50  	static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
51  
52  	static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
53  
54  	private static final Log logger = LogFactory.getLog(StompDecoder.class);
55  
56  
57  	private MessageHeaderInitializer headerInitializer;
58  
59  
60  	/**
61  	 * Configure a
62  	 * {@link org.springframework.messaging.support.MessageHeaderInitializer MessageHeaderInitializer}
63  	 * to apply to the headers of {@link Message}s from decoded STOMP frames.
64  	 */
65  	public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
66  		this.headerInitializer = headerInitializer;
67  	}
68  
69  	/**
70  	 * Return the configured {@code MessageHeaderInitializer}, if any.
71  	 */
72  	public MessageHeaderInitializer getHeaderInitializer() {
73  		return this.headerInitializer;
74  	}
75  
76  
77  	/**
78  	 * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
79  	 * list of {@link Message}s. If the input buffer contains partial STOMP frame
80  	 * content, or additional content with a partial STOMP frame, the buffer is
81  	 * reset and {@code null} is returned.
82  	 * @param buffer the buffer to decode the STOMP frame from
83  	 * @return the decoded messages, or an empty list if none
84  	 * @throws StompConversionException raised in case of decoding issues
85  	 */
86  	public List<Message<byte[]>> decode(ByteBuffer buffer) {
87  		return decode(buffer, null);
88  	}
89  
90  	/**
91  	 * Decodes one or more STOMP frames from the given {@code buffer} and returns
92  	 * a list of {@link Message}s.
93  	 * <p>If the given ByteBuffer contains only partial STOMP frame content and no
94  	 * complete STOMP frames, an empty list is returned, and the buffer is reset to
95  	 * to where it was.
96  	 * <p>If the buffer contains one ore more STOMP frames, those are returned and
97  	 * the buffer reset to point to the beginning of the unused partial content.
98  	 * <p>The output partialMessageHeaders map is used to store successfully parsed
99  	 * headers in case of partial content. The caller can then check if a
100 	 * "content-length" header was read, which helps to determine how much more
101 	 * content is needed before the next attempt to decode.
102 	 * @param buffer the buffer to decode the STOMP frame from
103 	 * @param partialMessageHeaders an empty output map that will store the last
104 	 * successfully parsed partialMessageHeaders in case of partial message content
105 	 * in cases where the partial buffer ended with a partial STOMP frame
106 	 * @return the decoded messages, or an empty list if none
107 	 * @throws StompConversionException raised in case of decoding issues
108 	 */
109 	public List<Message<byte[]>> decode(ByteBuffer buffer, MultiValueMap<String, String> partialMessageHeaders) {
110 		List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
111 		while (buffer.hasRemaining()) {
112 			Message<byte[]> message = decodeMessage(buffer, partialMessageHeaders);
113 			if (message != null) {
114 				messages.add(message);
115 			}
116 			else {
117 				break;
118 			}
119 		}
120 		return messages;
121 	}
122 
123 	/**
124 	 * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
125 	 */
126 	private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
127 		Message<byte[]> decodedMessage = null;
128 		skipLeadingEol(buffer);
129 		buffer.mark();
130 
131 		String command = readCommand(buffer);
132 		if (command.length() > 0) {
133 			StompHeaderAccessor headerAccessor = null;
134 			byte[] payload = null;
135 			if (buffer.remaining() > 0) {
136 				StompCommand stompCommand = StompCommand.valueOf(command);
137 				headerAccessor = StompHeaderAccessor.create(stompCommand);
138 				initHeaders(headerAccessor);
139 				readHeaders(buffer, headerAccessor);
140 				payload = readPayload(buffer, headerAccessor);
141 			}
142 			if (payload != null) {
143 				if (payload.length > 0 && !headerAccessor.getCommand().isBodyAllowed()) {
144 					throw new StompConversionException(headerAccessor.getCommand() +
145 							" shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
146 				}
147 				headerAccessor.updateSimpMessageHeadersFromStompHeaders();
148 				headerAccessor.setLeaveMutable(true);
149 				decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
150 				if (logger.isTraceEnabled()) {
151 					logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
152 				}
153 			}
154 			else {
155 				if (logger.isTraceEnabled()) {
156 					logger.trace("Incomplete frame, resetting input buffer...");
157 				}
158 				if (headers != null && headerAccessor != null) {
159 					String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
160 					@SuppressWarnings("unchecked")
161 					MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
162 					if (map != null) {
163 						headers.putAll(map);
164 					}
165 				}
166 				buffer.reset();
167 			}
168 		}
169 		else {
170 			StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
171 			initHeaders(headerAccessor);
172 			headerAccessor.setLeaveMutable(true);
173 			decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
174 			if (logger.isTraceEnabled()) {
175 				logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(null));
176 			}
177 		}
178 
179 		return decodedMessage;
180 	}
181 
182 	private void initHeaders(StompHeaderAccessor headerAccessor) {
183 		MessageHeaderInitializer initializer = getHeaderInitializer();
184 		if (initializer != null) {
185 			initializer.initHeaders(headerAccessor);
186 		}
187 	}
188 
189 	/**
190 	 * Skip one ore more EOL characters at the start of the given ByteBuffer.
191 	 * Those are STOMP heartbeat frames.
192 	 */
193 	protected void skipLeadingEol(ByteBuffer buffer) {
194 		while (true) {
195 			if (!tryConsumeEndOfLine(buffer)) {
196 				break;
197 			}
198 		}
199 	}
200 
201 	private String readCommand(ByteBuffer buffer) {
202 		ByteArrayOutputStream command = new ByteArrayOutputStream(256);
203 		while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
204 			command.write(buffer.get());
205 		}
206 		return new String(command.toByteArray(), UTF8_CHARSET);
207 	}
208 
209 	private void readHeaders(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
210 		while (true) {
211 			ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
212 			while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
213 				headerStream.write(buffer.get());
214 			}
215 			if (headerStream.size() > 0) {
216 				String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
217 				int colonIndex = header.indexOf(':');
218 				if (colonIndex <= 0 || colonIndex == header.length() - 1) {
219 					if (buffer.remaining() > 0) {
220 						throw new StompConversionException("Illegal header: '" + header +
221 								"'. A header must be of the form <name>:<value>.");
222 					}
223 				}
224 				else {
225 					String headerName = unescape(header.substring(0, colonIndex));
226 					String headerValue = unescape(header.substring(colonIndex + 1));
227 					try {
228 						headerAccessor.addNativeHeader(headerName, headerValue);
229 					}
230 					catch (InvalidMimeTypeException ex) {
231 						if (buffer.remaining() > 0) {
232 							throw ex;
233 						}
234 					}
235 				}
236 			}
237 			else {
238 				break;
239 			}
240 		}
241 	}
242 
243 	/**
244 	 * See STOMP Spec 1.2:
245 	 * <a href="http://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
246 	 */
247 	private String unescape(String inString) {
248 		StringBuilder sb = new StringBuilder(inString.length());
249 		int pos = 0;  // position in the old string
250 		int index = inString.indexOf("\\");
251 
252 		while (index >= 0) {
253 			sb.append(inString.substring(pos, index));
254 			if (index + 1 >= inString.length()) {
255 				throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
256 			}
257 			Character c = inString.charAt(index + 1);
258 			if (c == 'r') {
259 				sb.append('\r');
260 			}
261 			else if (c == 'n') {
262 				sb.append('\n');
263 			}
264 			else if (c == 'c') {
265 				sb.append(':');
266 			}
267 			else if (c == '\\') {
268 				sb.append('\\');
269 			}
270 			else {
271 				// should never happen
272 				throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
273 			}
274 			pos = index + 2;
275 			index = inString.indexOf("\\", pos);
276 		}
277 
278 		sb.append(inString.substring(pos));
279 		return sb.toString();
280 	}
281 
282 	private byte[] readPayload(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
283 		Integer contentLength;
284 		try {
285 			contentLength = headerAccessor.getContentLength();
286 		}
287 		catch (NumberFormatException ex) {
288 			logger.warn("Ignoring invalid content-length: '" + headerAccessor);
289 			contentLength = null;
290 		}
291 
292 		if (contentLength != null && contentLength >= 0) {
293 			if (buffer.remaining() > contentLength) {
294 				byte[] payload = new byte[contentLength];
295 				buffer.get(payload);
296 				if (buffer.get() != 0) {
297 					throw new StompConversionException("Frame must be terminated with a null octet");
298 				}
299 				return payload;
300 			}
301 			else {
302 				return null;
303 			}
304 		}
305 		else {
306 			ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
307 			while (buffer.remaining() > 0) {
308 				byte b = buffer.get();
309 				if (b == 0) {
310 					return payload.toByteArray();
311 				}
312 				else {
313 					payload.write(b);
314 				}
315 			}
316 		}
317 		return null;
318 	}
319 
320 	/**
321 	 * Try to read an EOL incrementing the buffer position if successful.
322 	 * @return whether an EOL was consumed
323 	 */
324 	private boolean tryConsumeEndOfLine(ByteBuffer buffer) {
325 		if (buffer.remaining() > 0) {
326 			byte b = buffer.get();
327 			if (b == '\n') {
328 				return true;
329 			}
330 			else if (b == '\r') {
331 				if (buffer.remaining() > 0 && buffer.get() == '\n') {
332 					return true;
333 				}
334 				else {
335 					throw new StompConversionException("'\\r' must be followed by '\\n'");
336 				}
337 			}
338 			buffer.position(buffer.position() - 1);
339 		}
340 		return false;
341 	}
342 
343 }