001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.activemq.transport.stomp;
018    
019    import java.io.DataInput;
020    import java.io.DataInputStream;
021    import java.io.DataOutput;
022    import java.io.DataOutputStream;
023    import java.io.IOException;
024    import java.util.HashMap;
025    import java.util.Iterator;
026    import java.util.Map;
027    
028    import org.apache.activemq.util.ByteArrayInputStream;
029    import org.apache.activemq.util.ByteArrayOutputStream;
030    import org.apache.activemq.util.ByteSequence;
031    import org.apache.activemq.wireformat.WireFormat;
032    
033    /**
034     * Implements marshalling and unmarsalling the <a
035     * href="http://stomp.codehaus.org/">Stomp</a> protocol.
036     */
037    public class StompWireFormat implements WireFormat {
038    
039        private static final byte[] NO_DATA = new byte[] {};
040        private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
041    
042        private static final int MAX_COMMAND_LENGTH = 1024;
043        private static final int MAX_HEADER_LENGTH = 1024 * 10;
044        private static final int MAX_HEADERS = 1000;
045        private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
046    
047        private int version = 1;
048    
049        public ByteSequence marshal(Object command) throws IOException {
050            ByteArrayOutputStream baos = new ByteArrayOutputStream();
051            DataOutputStream dos = new DataOutputStream(baos);
052            marshal(command, dos);
053            dos.close();
054            return baos.toByteSequence();
055        }
056    
057        public Object unmarshal(ByteSequence packet) throws IOException {
058            ByteArrayInputStream stream = new ByteArrayInputStream(packet);
059            DataInputStream dis = new DataInputStream(stream);
060            return unmarshal(dis);
061        }
062    
063        public void marshal(Object command, DataOutput os) throws IOException {
064            StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
065    
066            StringBuffer buffer = new StringBuffer();
067            buffer.append(stomp.getAction());
068            buffer.append(Stomp.NEWLINE);
069    
070            // Output the headers.
071            for (Iterator iter = stomp.getHeaders().entrySet().iterator(); iter.hasNext();) {
072                Map.Entry entry = (Map.Entry)iter.next();
073                buffer.append(entry.getKey());
074                buffer.append(Stomp.Headers.SEPERATOR);
075                buffer.append(entry.getValue());
076                buffer.append(Stomp.NEWLINE);
077            }
078    
079            // Add a newline to seperate the headers from the content.
080            buffer.append(Stomp.NEWLINE);
081    
082            os.write(buffer.toString().getBytes("UTF-8"));
083            os.write(stomp.getContent());
084            os.write(END_OF_FRAME);
085        }
086    
087        public Object unmarshal(DataInput in) throws IOException {
088    
089            try {
090                String action = null;
091    
092                // skip white space to next real action line
093                while (true) {
094                    action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
095                    if (action == null) {
096                        throw new IOException("connection was closed");
097                    } else {
098                        action = action.trim();
099                        if (action.length() > 0) {
100                            break;
101                        }
102                    }
103                }
104    
105                // Parse the headers
106                HashMap<String, String> headers = new HashMap<String, String>(25);
107                while (true) {
108                    String line = readLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
109                    if (line != null && line.trim().length() > 0) {
110    
111                        if (headers.size() > MAX_HEADERS) {
112                            throw new ProtocolException("The maximum number of headers was exceeded", true);
113                        }
114    
115                        try {
116                            int seperatorIndex = line.indexOf(Stomp.Headers.SEPERATOR);
117                            String name = line.substring(0, seperatorIndex).trim();
118                            String value = line.substring(seperatorIndex + 1, line.length()).trim();
119                            headers.put(name, value);
120                        } catch (Exception e) {
121                            throw new ProtocolException("Unable to parser header line [" + line + "]", true);
122                        }
123                    } else {
124                        break;
125                    }
126                }
127    
128                // Read in the data part.
129                byte[] data = NO_DATA;
130                String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
131                if (contentLength != null) {
132    
133                    // Bless the client, he's telling us how much data to read in.
134                    int length;
135                    try {
136                        length = Integer.parseInt(contentLength.trim());
137                    } catch (NumberFormatException e) {
138                        throw new ProtocolException("Specified content-length is not a valid integer", true);
139                    }
140    
141                    if (length > MAX_DATA_LENGTH) {
142                        throw new ProtocolException("The maximum data length was exceeded", true);
143                    }
144    
145                    data = new byte[length];
146                    in.readFully(data);
147    
148                    if (in.readByte() != 0) {
149                        throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
150                    }
151    
152                } else {
153    
154                    // We don't know how much to read.. data ends when we hit a 0
155                    byte b;
156                    ByteArrayOutputStream baos = null;
157                    while ((b = in.readByte()) != 0) {
158    
159                        if (baos == null) {
160                            baos = new ByteArrayOutputStream();
161                        } else if (baos.size() > MAX_DATA_LENGTH) {
162                            throw new ProtocolException("The maximum data length was exceeded", true);
163                        }
164    
165                        baos.write(b);
166                    }
167    
168                    if (baos != null) {
169                        baos.close();
170                        data = baos.toByteArray();
171                    }
172    
173                }
174    
175                return new StompFrame(action, headers, data);
176    
177            } catch (ProtocolException e) {
178                return new StompFrameError(e);
179            }
180    
181        }
182    
183        private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
184            byte b;
185            ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
186            while ((b = in.readByte()) != '\n') {
187                if (baos.size() > maxLength) {
188                    throw new ProtocolException(errorMessage, true);
189                }
190                baos.write(b);
191            }
192            baos.close();
193            ByteSequence sequence = baos.toByteSequence();
194            return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8");
195        }
196    
197        public int getVersion() {
198            return version;
199        }
200    
201        public void setVersion(int version) {
202            this.version = version;
203        }
204    
205    }