diff options
Diffstat (limited to 'src/server/message_parser.zig')
| -rw-r--r-- | src/server/message_parser.zig | 239 |
1 files changed, 153 insertions, 86 deletions
diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig index 2bc5286..75e13d2 100644 --- a/src/server/message_parser.zig +++ b/src/server/message_parser.zig @@ -29,31 +29,9 @@ pub const MessageType = enum { // if (std.mem.eql(u8, "@"-ERR"", input)) return .@"-err"; return error.InvalidMessageType; } - - const client_types = std.StaticStringMap(MessageType).initComptime( - .{ - // {"INFO", .info}, - .{ "CONNECT", .connect }, - .{ "PUB", .@"pub" }, - .{ "HPUB", .hpub }, - .{ "SUB", .sub }, - .{ "UNSUB", .unsub }, - // {"MSG", .msg}, - // {"HMSG", .hmsg}, - .{ "PING", .ping }, - .{ "PONG", .pong }, - // {"+OK", .@"+ok"}, - // {"-ERR", .@"-err"}, - }, - ); - fn parseStaticStringMap(input: []const u8) ?MessageType { - return client_types.get(input); - } - - pub const parse = parseStaticStringMap; }; -const Message = union(MessageType) { +pub const Message = union(enum) { info: void, connect: Connect, @@ -88,9 +66,144 @@ const Message = union(MessageType) { const Pub = struct { subject: []const u8, reply_to: ?[]const u8 = null, - bytes: usize, payload: []const u8, }; + + const client_types = std.StaticStringMap(MessageType).initComptime( + .{ + // {"INFO", .info}, + .{ "CONNECT", .connect }, + .{ "PUB", .@"pub" }, + .{ "HPUB", .hpub }, + .{ "SUB", .sub }, + .{ "UNSUB", .unsub }, + // {"MSG", .msg}, + // {"HMSG", .hmsg}, + .{ "PING", .ping }, + .{ "PONG", .pong }, + // {"+OK", .@"+ok"}, + // {"-ERR", .@"-err"}, + }, + ); + fn parseStaticStringMap(input: []const u8) ?MessageType { + return client_types.get(input); + } + + pub const parse = parseStaticStringMap; + + /// An error should be handled by cleaning up this connection. + pub fn next(alloc: std.mem.Allocator, in: *std.Io.Reader) !Message { + var operation_string: std.ArrayList(u8) = blk: { + var buf: ["CONTINUE".len]u8 = undefined; + break :blk .initBuffer(&buf); + }; + + while (in.peekByte()) |byte| { + if (std.ascii.isUpper(byte)) { + try operation_string.appendBounded(byte); + try in.discardAll(1); + } else break; + } else |err| return err; + + const operation = parse(operation_string.items) orelse { + std.debug.print("operation: '{s}'\n", .{operation_string.items}); + std.debug.print("buffered: '{s}'", .{in.buffered()}); + return error.InvalidOperation; + }; + + switch (operation) { + .connect => { + // TODO: should be ARENA allocator + var connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(alloc, 1024); + var connect_string_writer = connect_string_writer_allocating.writer; + try in.discardAll(1); // throw away space + + // Should read the next JSON object to the fixed buffer writer. + _ = try in.streamDelimiter(&connect_string_writer, '}'); + try connect_string_writer.writeByte('}'); + std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' + + // TODO: should be CONNECTION allocator + const res = try std.json.parseFromSliceLeaky(Connect, alloc, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); + + return .{ .connect = res }; + }, + .@"pub" => { + try in.discardAll(1); // throw away space + + // Parse subject + const subject: []const u8 = blk: { + // TODO: should be ARENA allocator + var subject_list: std.ArrayList(u8) = try .initCapacity(alloc, 1024); + + // Handle the first character + { + const byte = try in.takeByte(); + if (byte == '.' or std.ascii.isWhitespace(byte)) + return error.InvalidSubject; + + try subject_list.append(alloc, byte); + } + + while (in.takeByte() catch null) |byte| { + if (std.ascii.isWhitespace(byte)) break; + if (std.ascii.isAscii(byte)) { + if (byte == '.') { + const next_byte = try in.peekByte(); + if (next_byte == '.' or std.ascii.isWhitespace(next_byte)) + return error.InvalidSubject; + } + try subject_list.append(alloc, byte); + } + } else return error.InvalidStream; + break :blk subject_list.items; + }; + + // Parse byte count + const byte_count = blk: { + var byte_count_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); + while (in.takeByte() catch null) |byte| { + if (std.ascii.isWhitespace(byte)) { + std.debug.assert(byte == '\r'); + std.debug.assert(try in.takeByte() == '\n'); + break; + } + + if (std.ascii.isDigit(byte)) { + try byte_count_list.append(alloc, byte); + } else { + return error.InvalidStream; + } + } else return error.InvalidStream; + + break :blk try std.fmt.parseUnsigned(u64, byte_count_list.items, 10); + }; + + const payload = blk: { + const bytes = try alloc.alloc(u8, byte_count); + try in.readSliceAll(bytes); + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + break :blk bytes; + }; + + std.debug.print("buffer: '{s}'\n", .{in.buffered()}); + // return std.debug.panic("not implemented", .{}); + return .{ .@"pub" = .{ + .subject = subject, + .payload = payload, + } }; + }, + .ping => { + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + return .ping; + }, + .pong => { + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + return .pong; + }, + else => |msg| std.debug.panic("Not implemented: {}\n", .{msg}), + } + } }; fn parseJsonMessage(T: type, alloc: std.mem.Allocator, in: *std.Io.Reader) !T { @@ -104,7 +217,6 @@ fn parseJsonMessage(T: type, alloc: std.mem.Allocator, in: *std.Io.Reader) !T { fn parsePub(in: *std.Io.Reader) !Message.Pub { const subject = (try in.takeDelimiter(' ')) orelse return error.EndOfStream; const next = (try in.takeDelimiter('\r')) orelse return error.EndOfStream; - std.debug.print("next: '{s}'\n", .{next}); var reply_to: ?[]const u8 = null; const bytes = std.fmt.parseUnsigned(usize, next, 10) catch blk: { reply_to = next; @@ -134,16 +246,15 @@ pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { len = i + 1; } else break; } - std.debug.print("word: '{s}'\n", .{word[0..len]}); - break :blk MessageType.parse(word[0..len]) orelse return null; + + break :blk Message.parse(word[0..len]) orelse return null; }; - std.debug.print("buffered: '{s}'\n", .{in.buffered()}); + // defer in.toss(2); // CRLF return switch (message_type) { .connect => blk: { const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }; - std.debug.print("value: {s}\n", .{value.?.connect.name.?}); - std.debug.print("buffered: '{d}'\n", .{in.buffered().len}); + break :blk value; }, .@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) }, @@ -152,47 +263,6 @@ pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { }; } -// test parseNextMessage { -// { -// const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\naftertheendoftheinput"; -// var reader: std.Io.Reader = .fixed(input); -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); -// const gpa = arena.allocator(); -// const msg: ?Message = parseNextMessage(gpa, &reader); -// const expected: ?Message = .{ .connect = .{ -// .verbose = false, -// .pedantic = false, -// .tls_required = false, -// .name = "NATS CLI Version v0.2.4", -// .lang = "go", -// .version = "1.43.0", -// .protocol = 1, -// .echo = true, -// .headers = true, -// .no_responders = true, -// } }; -// try std.testing.expect(msg != null); -// try std.testing.expectEqualDeep(msg, expected); -// } -// { -// const input = "PUB hi 3\r\nfoo\r\n"; -// var reader: std.Io.Reader = .fixed(input); -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); -// const gpa = arena.allocator(); -// const msg: ?Message = parseNextMessage(gpa, &reader); -// std.debug.print("msg: {any}\n", .{msg}); -// const expected: ?Message = .{ .@"pub" = .{ -// .subject = "hi", -// .bytes = 3, -// .payload = "foo", -// } }; -// try std.testing.expect(msg != null); -// try std.testing.expectEqualDeep(msg, expected); -// } -// } - test parseNextMessage { const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n"; var reader: std.Io.Reader = .fixed(input); @@ -201,8 +271,8 @@ test parseNextMessage { const gpa = arena.allocator(); { - const msg: ?Message = parseNextMessage(gpa, &reader); - const expected: ?Message = .{ .connect = .{ + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ .connect = .{ .verbose = false, .pedantic = false, .tls_required = false, @@ -214,24 +284,21 @@ test parseNextMessage { .headers = true, .no_responders = true, } }; - try std.testing.expect(msg != null); - try std.testing.expectEqualDeep(msg, expected); + + try std.testing.expectEqualDeep(expected, msg); } { - const msg: ?Message = parseNextMessage(gpa, &reader); - std.debug.print("msg: {any}\n", .{msg}); - const expected: ?Message = .{ .@"pub" = .{ + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ .@"pub" = .{ .subject = "hi", - .bytes = 3, .payload = "foo", } }; - try std.testing.expect(msg != null); - try std.testing.expectEqualDeep(msg, expected); + try std.testing.expectEqualDeep(expected, msg); } } -test "MessageType.parse performance" { - // Measure perf for parseMemEql - // Measure perf for parseStaticStringMap - // assert parse = fastest perf -} +// test "MessageType.parse performance" { +// // Measure perf for parseMemEql +// // Measure perf for parseStaticStringMap +// // assert parse = fastest perf +// } |
