diff options
Diffstat (limited to 'src/server')
| -rw-r--r-- | src/server/message_parser.zig | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig index f99dfcb..9fd490c 100644 --- a/src/server/message_parser.zig +++ b/src/server/message_parser.zig @@ -170,7 +170,7 @@ pub const Message = union(MessageType) { while (in.peekByte()) |byte| { if (std.ascii.isUpper(byte)) { try operation_string.appendBounded(byte); - try in.discardAll(1); + in.toss(1); } else break; } else |err| return err; @@ -190,7 +190,7 @@ pub const Message = union(MessageType) { // Should read the next JSON object to the fixed buffer writer. _ = try in.streamDelimiter(&connect_string_writer, '}'); try connect_string_writer.writeByte('}'); - std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' + try assertStreamBytes(in, "}\r\n"); // discard '}\r\n' // TODO: should be CONNECTION allocator const res = try std.json.parseFromSliceLeaky(Connect, connect_allocator, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); @@ -206,12 +206,12 @@ pub const Message = union(MessageType) { // Parse byte count const byte_count = blk: { var byte_count_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); - while (in.takeByte()) |byte| { + while (in.peekByte()) |byte| { if (std.ascii.isWhitespace(byte)) { - std.debug.assert(byte == '\r'); - std.debug.assert(try in.takeByte() == '\n'); + try assertStreamBytes(in, "\r\n"); break; } + defer in.toss(1); if (std.ascii.isDigit(byte)) { try byte_count_list.append(alloc, byte); @@ -226,7 +226,7 @@ pub const Message = union(MessageType) { const payload = blk: { const bytes = try alloc.alloc(u8, byte_count); try in.readSliceAll(bytes); - try assertStreamBytes(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); break :blk bytes; }; @@ -238,15 +238,18 @@ pub const Message = union(MessageType) { }; }, .ping => { - try assertStreamBytes(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .ping; }, .pong => { - try assertStreamBytes(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .pong; }, .sub => { - try assertStreamBytes(std.ascii.isWhitespace(try in.takeByte())); + if (!std.ascii.isWhitespace(try in.takeByte())) { + @branchHint(.unlikely); + return error.InvalidStream; + } const subject = try readSubject(alloc, in); const second = blk: { // Drop whitespace @@ -266,8 +269,7 @@ pub const Message = union(MessageType) { }; const queue_group = if ((try in.peekByte()) != '\r') second else null; const sid = if (queue_group) |_| try in.takeDelimiterExclusive('\r') else second; - std.debug.print("SID is '{s}'\n", .{sid}); - std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .{ .sub = .{ .subject = subject, @@ -277,7 +279,10 @@ pub const Message = union(MessageType) { }; }, .unsub => { - try assertStreamBytes(std.ascii.isWhitespace(try in.takeByte())); + if (!std.ascii.isWhitespace(try in.takeByte())) { + @branchHint(.unlikely); + return error.InvalidStream; + } // Parse byte count const sid = blk: { var acc: std.ArrayList(u8) = try .initCapacity(alloc, 8); @@ -290,7 +295,7 @@ pub const Message = union(MessageType) { }; if ((try in.peekByte()) == '\r') { - try assertStreamBytes(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .{ .unsub = .{ .sid = sid, @@ -300,10 +305,9 @@ pub const Message = union(MessageType) { in.toss(1); const max_msgs = blk: { var max_msgs_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); - while (in.takeByte()) |byte| { + while (in.peekByte()) |byte| { if (std.ascii.isWhitespace(byte)) { - std.debug.assert(byte == '\r'); - std.debug.assert(try in.takeByte() == '\n'); + try assertStreamBytes(in, "\r\n"); break; } @@ -337,9 +341,8 @@ fn readSubject(alloc: std.mem.Allocator, in: *std.Io.Reader) ![]const u8 { // Handle the first character { const byte = try in.takeByte(); - std.debug.assert(!std.ascii.isWhitespace(byte)); - if (byte == '.') - return error.InvalidSubject; + if (std.ascii.isWhitespace(byte) or byte == '.') + return error.InvalidStream; try subject_list.append(alloc, byte); } @@ -386,8 +389,8 @@ fn parsePub(in: *std.Io.Reader) !Message.Pub { }; } -inline fn assertStreamBytes(cond: bool) !void { - if (!cond) { +inline fn assertStreamBytes(reader: *std.Io.Reader, expected: []const u8) !void { + if (!std.mem.eql(u8, try reader.take(expected.len), expected)) { @branchHint(.unlikely); return error.InvalidStream; } |
