1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.messaging.simp.stomp;
18
19 import java.io.ByteArrayOutputStream;
20 import java.nio.ByteBuffer;
21 import java.nio.charset.Charset;
22 import java.util.ArrayList;
23 import java.util.List;
24
25 import org.apache.commons.logging.Log;
26 import org.apache.commons.logging.LogFactory;
27
28 import org.springframework.messaging.Message;
29 import org.springframework.messaging.support.MessageBuilder;
30 import org.springframework.messaging.support.MessageHeaderInitializer;
31 import org.springframework.messaging.support.NativeMessageHeaderAccessor;
32 import org.springframework.util.InvalidMimeTypeException;
33 import org.springframework.util.MultiValueMap;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 public class StompDecoder {
49
50 static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
51
52 static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
53
54 private static final Log logger = LogFactory.getLog(StompDecoder.class);
55
56
57 private MessageHeaderInitializer headerInitializer;
58
59
60
61
62
63
64
65 public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
66 this.headerInitializer = headerInitializer;
67 }
68
69
70
71
72 public MessageHeaderInitializer getHeaderInitializer() {
73 return this.headerInitializer;
74 }
75
76
77
78
79
80
81
82
83
84
85
86 public List<Message<byte[]>> decode(ByteBuffer buffer) {
87 return decode(buffer, null);
88 }
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109 public List<Message<byte[]>> decode(ByteBuffer buffer, MultiValueMap<String, String> partialMessageHeaders) {
110 List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
111 while (buffer.hasRemaining()) {
112 Message<byte[]> message = decodeMessage(buffer, partialMessageHeaders);
113 if (message != null) {
114 messages.add(message);
115 }
116 else {
117 break;
118 }
119 }
120 return messages;
121 }
122
123
124
125
126 private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
127 Message<byte[]> decodedMessage = null;
128 skipLeadingEol(buffer);
129 buffer.mark();
130
131 String command = readCommand(buffer);
132 if (command.length() > 0) {
133 StompHeaderAccessor headerAccessor = null;
134 byte[] payload = null;
135 if (buffer.remaining() > 0) {
136 StompCommand stompCommand = StompCommand.valueOf(command);
137 headerAccessor = StompHeaderAccessor.create(stompCommand);
138 initHeaders(headerAccessor);
139 readHeaders(buffer, headerAccessor);
140 payload = readPayload(buffer, headerAccessor);
141 }
142 if (payload != null) {
143 if (payload.length > 0 && !headerAccessor.getCommand().isBodyAllowed()) {
144 throw new StompConversionException(headerAccessor.getCommand() +
145 " shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
146 }
147 headerAccessor.updateSimpMessageHeadersFromStompHeaders();
148 headerAccessor.setLeaveMutable(true);
149 decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
150 if (logger.isTraceEnabled()) {
151 logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
152 }
153 }
154 else {
155 if (logger.isTraceEnabled()) {
156 logger.trace("Incomplete frame, resetting input buffer...");
157 }
158 if (headers != null && headerAccessor != null) {
159 String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
160 @SuppressWarnings("unchecked")
161 MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
162 if (map != null) {
163 headers.putAll(map);
164 }
165 }
166 buffer.reset();
167 }
168 }
169 else {
170 StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
171 initHeaders(headerAccessor);
172 headerAccessor.setLeaveMutable(true);
173 decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
174 if (logger.isTraceEnabled()) {
175 logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(null));
176 }
177 }
178
179 return decodedMessage;
180 }
181
182 private void initHeaders(StompHeaderAccessor headerAccessor) {
183 MessageHeaderInitializer initializer = getHeaderInitializer();
184 if (initializer != null) {
185 initializer.initHeaders(headerAccessor);
186 }
187 }
188
189
190
191
192
193 protected void skipLeadingEol(ByteBuffer buffer) {
194 while (true) {
195 if (!tryConsumeEndOfLine(buffer)) {
196 break;
197 }
198 }
199 }
200
201 private String readCommand(ByteBuffer buffer) {
202 ByteArrayOutputStream command = new ByteArrayOutputStream(256);
203 while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
204 command.write(buffer.get());
205 }
206 return new String(command.toByteArray(), UTF8_CHARSET);
207 }
208
209 private void readHeaders(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
210 while (true) {
211 ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
212 while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
213 headerStream.write(buffer.get());
214 }
215 if (headerStream.size() > 0) {
216 String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
217 int colonIndex = header.indexOf(':');
218 if (colonIndex <= 0 || colonIndex == header.length() - 1) {
219 if (buffer.remaining() > 0) {
220 throw new StompConversionException("Illegal header: '" + header +
221 "'. A header must be of the form <name>:<value>.");
222 }
223 }
224 else {
225 String headerName = unescape(header.substring(0, colonIndex));
226 String headerValue = unescape(header.substring(colonIndex + 1));
227 try {
228 headerAccessor.addNativeHeader(headerName, headerValue);
229 }
230 catch (InvalidMimeTypeException ex) {
231 if (buffer.remaining() > 0) {
232 throw ex;
233 }
234 }
235 }
236 }
237 else {
238 break;
239 }
240 }
241 }
242
243
244
245
246
247 private String unescape(String inString) {
248 StringBuilder sb = new StringBuilder(inString.length());
249 int pos = 0;
250 int index = inString.indexOf("\\");
251
252 while (index >= 0) {
253 sb.append(inString.substring(pos, index));
254 if (index + 1 >= inString.length()) {
255 throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
256 }
257 Character c = inString.charAt(index + 1);
258 if (c == 'r') {
259 sb.append('\r');
260 }
261 else if (c == 'n') {
262 sb.append('\n');
263 }
264 else if (c == 'c') {
265 sb.append(':');
266 }
267 else if (c == '\\') {
268 sb.append('\\');
269 }
270 else {
271
272 throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
273 }
274 pos = index + 2;
275 index = inString.indexOf("\\", pos);
276 }
277
278 sb.append(inString.substring(pos));
279 return sb.toString();
280 }
281
282 private byte[] readPayload(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
283 Integer contentLength;
284 try {
285 contentLength = headerAccessor.getContentLength();
286 }
287 catch (NumberFormatException ex) {
288 logger.warn("Ignoring invalid content-length: '" + headerAccessor);
289 contentLength = null;
290 }
291
292 if (contentLength != null && contentLength >= 0) {
293 if (buffer.remaining() > contentLength) {
294 byte[] payload = new byte[contentLength];
295 buffer.get(payload);
296 if (buffer.get() != 0) {
297 throw new StompConversionException("Frame must be terminated with a null octet");
298 }
299 return payload;
300 }
301 else {
302 return null;
303 }
304 }
305 else {
306 ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
307 while (buffer.remaining() > 0) {
308 byte b = buffer.get();
309 if (b == 0) {
310 return payload.toByteArray();
311 }
312 else {
313 payload.write(b);
314 }
315 }
316 }
317 return null;
318 }
319
320
321
322
323
324 private boolean tryConsumeEndOfLine(ByteBuffer buffer) {
325 if (buffer.remaining() > 0) {
326 byte b = buffer.get();
327 if (b == '\n') {
328 return true;
329 }
330 else if (b == '\r') {
331 if (buffer.remaining() > 0 && buffer.get() == '\n') {
332 return true;
333 }
334 else {
335 throw new StompConversionException("'\\r' must be followed by '\\n'");
336 }
337 }
338 buffer.position(buffer.position() - 1);
339 }
340 return false;
341 }
342
343 }