From c6dfcc541d14a934ab739cb56f4e11882f46e9ea Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Tue, 25 Nov 2025 13:25:12 -0500 Subject: Can handle PUB --- src/main.zig | 116 ++++++++------------ src/server/message_parser.zig | 239 +++++++++++++++++++++++++++--------------- 2 files changed, 197 insertions(+), 158 deletions(-) (limited to 'src') diff --git a/src/main.zig b/src/main.zig index fb6be1c..e24bbcb 100644 --- a/src/main.zig +++ b/src/main.zig @@ -2,8 +2,7 @@ const std = @import("std"); const zits = @import("zits"); const clap = @import("clap"); -const MessageType = zits.MessageParser.MessageType; -const parseNextMessage = zits.MessageParser.parseNextMessage; +const Message = zits.MessageParser.Message; const SubCommands = enum { help, @@ -27,7 +26,7 @@ const main_params = clap.parseParamsComptime( const MainArgs = clap.ResultEx(clap.Help, &main_params, main_parsers); pub fn main() !void { - var dba = std.heap.DebugAllocator(.{}){}; + var dba: std.heap.DebugAllocator(.{}) = .init; defer _ = dba.deinit(); const gpa = dba.allocator(); @@ -109,30 +108,27 @@ fn serverMain(gpa: std.mem.Allocator, iter: *std.process.ArgIterator, main_args: .max_payload = 1048576, }; - // const info: ServerInfo = .{ - // .server_id = "foo", - // .server_name = "bar", - // .version = "6.9.0", - // .max_payload = 6969, - // }; - - var server = try std.Io.net.IpAddress.listen(.{ - .ip4 = .{ - .bytes = .{ 0, 0, 0, 0 }, - .port = info.port, + var server = try std.Io.net.IpAddress.listen( + .{ + .ip4 = .{ + .bytes = .{ 0, 0, 0, 0 }, + .port = info.port, + }, }, - }, io, .{}); + io, + .{}, + ); defer server.deinit(io); var group: std.Io.Group = .init; defer group.wait(io); for (0..5) |_| { const stream = try server.accept(io); - group.async(io, handleConnection, .{ io, stream, info }); + group.async(io, handleConnection, .{ gpa, io, stream, info }); } } -fn handleConnection(io: std.Io, stream: std.Io.net.Stream, info: ServerInfo) void { +fn handleConnection(allocator: std.mem.Allocator, io: std.Io, stream: std.Io.net.Stream, info: ServerInfo) void { defer stream.close(io); var w_buffer: [1024]u8 = undefined; var writer = stream.writer(io, &w_buffer); @@ -142,76 +138,52 @@ fn handleConnection(io: std.Io, stream: std.Io.net.Stream, info: ServerInfo) voi var reader = stream.reader(io, &r_buffer); const in = &reader.interface; - processClient(in, out, info) catch |err| { + processClient(allocator, in, out, info) catch |err| { std.debug.panic("Error processing client: {}\n", .{err}); }; - - // var stdout_buffer: [1024]u8 = undefined; - // const stdout_file = std.fs.File.stdout(); - // var stdout_file_writer = stdout_file.writer(&stdout_buffer); - // const stdout_writer = &stdout_file_writer.interface; - - // var timeout = io.async(std.Io.sleep, .{ io, .fromSeconds(10), .real }); - // defer timeout.cancel(io) catch {}; - - // var user_res = io.async(std.Io.Reader.streamRemaining, .{ in, stdout_writer }); - // defer _ = user_res.cancel(io) catch {}; - - // switch (io.select(.{ - // .timeout = &timeout, - // .data = &user_res, - // }) catch unreachable) { - // .timeout => std.debug.print("timeout\n", .{}), - // .data => |_| { - // stdout_writer.flush() catch |err| { - // std.debug.print("Could not flush stdout: {}\n", .{err}); - // }; - // // std.debug.print("received data {any}\n", .{d}); - // }, - // } } -fn processClient(in: *std.Io.Reader, out: *std.Io.Writer, info: ServerInfo) !void { +fn processClient(gpa: std.mem.Allocator, in: *std.Io.Reader, out: *std.Io.Writer, info: ServerInfo) !void { try writeInfo(out, info); - // move this inside client_state declaration - var json_parse_buf: [4096]u8 = undefined; - var json_parse_alloc_fb: std.heap.FixedBufferAllocator = std.heap.FixedBufferAllocator.init(&json_parse_buf); - var json_parse_alloc = json_parse_alloc_fb.allocator(); - var json_reader: std.json.Reader = .init(json_parse_alloc, in); - - // var client_state = try std.json.parseFromSliceLeaky(ClientState, json_parse_alloc, in.buffered(), .{}); - // in.toss(in.buffered().len); - - // var client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); - - const client_state = 0; - std.debug.print("client_state: {}\n", .{client_state}); + var client_state_arena: std.heap.ArenaAllocator = .init(gpa); + defer client_state_arena.deinit(); + const client_state = (try Message.next(client_state_arena.allocator(), in)).connect; + _ = client_state; + var message_parsing_arena: std.heap.ArenaAllocator = .init(gpa); + defer message_parsing_arena.deinit(); + const message_parsing_allocator = message_parsing_arena.allocator(); while (true) { - const next_message_type = parseNextMessage(json_parse_alloc, in) orelse return; - - switch (next_message_type) { + defer _ = message_parsing_arena.reset(.retain_capacity); + const next_message = Message.next(message_parsing_allocator, in) catch |err| { + switch (err) { + error.EndOfStream => { + break; + }, + else => { + return err; + }, + } + }; + switch (next_message) { .connect => |connect| { - std.debug.print("connect: {s}\n", .{connect.name orelse "\"\""}); - json_parse_alloc_fb = .init(&json_parse_buf); - json_parse_alloc = json_parse_alloc_fb.allocator(); - json_reader = .init(json_parse_alloc, in); - // client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); - std.debug.print("client_state: {any}\n", .{client_state}); + std.debug.panic("Connection message after already connected: {any}\n", .{connect}); }, - .ping => writePong(out) catch |err| { - std.debug.panic("failed to pong: {any}\n", .{err}); - }, - else => |msg| std.debug.print("received {}\n", .{msg}), + .ping => try writePong(out), + .@"pub" => try writeOk(out), + else => |msg| std.debug.panic("Message type not implemented: {any}\n", .{msg}), } } } +fn writeOk(out: *std.Io.Writer) !void { + _ = try out.write("+OK\r\n"); + try out.flush(); +} + fn writePong(out: *std.Io.Writer) !void { - std.debug.print("in writePong\n", .{}); - _ = try out.write("PONG"); - _ = try out.write("\r\n"); + _ = try out.write("PONG\r\n"); try out.flush(); } diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig index 2bc5286..75e13d2 100644 --- a/src/server/message_parser.zig +++ b/src/server/message_parser.zig @@ -29,31 +29,9 @@ pub const MessageType = enum { // if (std.mem.eql(u8, "@"-ERR"", input)) return .@"-err"; return error.InvalidMessageType; } - - const client_types = std.StaticStringMap(MessageType).initComptime( - .{ - // {"INFO", .info}, - .{ "CONNECT", .connect }, - .{ "PUB", .@"pub" }, - .{ "HPUB", .hpub }, - .{ "SUB", .sub }, - .{ "UNSUB", .unsub }, - // {"MSG", .msg}, - // {"HMSG", .hmsg}, - .{ "PING", .ping }, - .{ "PONG", .pong }, - // {"+OK", .@"+ok"}, - // {"-ERR", .@"-err"}, - }, - ); - fn parseStaticStringMap(input: []const u8) ?MessageType { - return client_types.get(input); - } - - pub const parse = parseStaticStringMap; }; -const Message = union(MessageType) { +pub const Message = union(enum) { info: void, connect: Connect, @@ -88,9 +66,144 @@ const Message = union(MessageType) { const Pub = struct { subject: []const u8, reply_to: ?[]const u8 = null, - bytes: usize, payload: []const u8, }; + + const client_types = std.StaticStringMap(MessageType).initComptime( + .{ + // {"INFO", .info}, + .{ "CONNECT", .connect }, + .{ "PUB", .@"pub" }, + .{ "HPUB", .hpub }, + .{ "SUB", .sub }, + .{ "UNSUB", .unsub }, + // {"MSG", .msg}, + // {"HMSG", .hmsg}, + .{ "PING", .ping }, + .{ "PONG", .pong }, + // {"+OK", .@"+ok"}, + // {"-ERR", .@"-err"}, + }, + ); + fn parseStaticStringMap(input: []const u8) ?MessageType { + return client_types.get(input); + } + + pub const parse = parseStaticStringMap; + + /// An error should be handled by cleaning up this connection. + pub fn next(alloc: std.mem.Allocator, in: *std.Io.Reader) !Message { + var operation_string: std.ArrayList(u8) = blk: { + var buf: ["CONTINUE".len]u8 = undefined; + break :blk .initBuffer(&buf); + }; + + while (in.peekByte()) |byte| { + if (std.ascii.isUpper(byte)) { + try operation_string.appendBounded(byte); + try in.discardAll(1); + } else break; + } else |err| return err; + + const operation = parse(operation_string.items) orelse { + std.debug.print("operation: '{s}'\n", .{operation_string.items}); + std.debug.print("buffered: '{s}'", .{in.buffered()}); + return error.InvalidOperation; + }; + + switch (operation) { + .connect => { + // TODO: should be ARENA allocator + var connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(alloc, 1024); + var connect_string_writer = connect_string_writer_allocating.writer; + try in.discardAll(1); // throw away space + + // Should read the next JSON object to the fixed buffer writer. + _ = try in.streamDelimiter(&connect_string_writer, '}'); + try connect_string_writer.writeByte('}'); + std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' + + // TODO: should be CONNECTION allocator + const res = try std.json.parseFromSliceLeaky(Connect, alloc, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); + + return .{ .connect = res }; + }, + .@"pub" => { + try in.discardAll(1); // throw away space + + // Parse subject + const subject: []const u8 = blk: { + // TODO: should be ARENA allocator + var subject_list: std.ArrayList(u8) = try .initCapacity(alloc, 1024); + + // Handle the first character + { + const byte = try in.takeByte(); + if (byte == '.' or std.ascii.isWhitespace(byte)) + return error.InvalidSubject; + + try subject_list.append(alloc, byte); + } + + while (in.takeByte() catch null) |byte| { + if (std.ascii.isWhitespace(byte)) break; + if (std.ascii.isAscii(byte)) { + if (byte == '.') { + const next_byte = try in.peekByte(); + if (next_byte == '.' or std.ascii.isWhitespace(next_byte)) + return error.InvalidSubject; + } + try subject_list.append(alloc, byte); + } + } else return error.InvalidStream; + break :blk subject_list.items; + }; + + // Parse byte count + const byte_count = blk: { + var byte_count_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); + while (in.takeByte() catch null) |byte| { + if (std.ascii.isWhitespace(byte)) { + std.debug.assert(byte == '\r'); + std.debug.assert(try in.takeByte() == '\n'); + break; + } + + if (std.ascii.isDigit(byte)) { + try byte_count_list.append(alloc, byte); + } else { + return error.InvalidStream; + } + } else return error.InvalidStream; + + break :blk try std.fmt.parseUnsigned(u64, byte_count_list.items, 10); + }; + + const payload = blk: { + const bytes = try alloc.alloc(u8, byte_count); + try in.readSliceAll(bytes); + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + break :blk bytes; + }; + + std.debug.print("buffer: '{s}'\n", .{in.buffered()}); + // return std.debug.panic("not implemented", .{}); + return .{ .@"pub" = .{ + .subject = subject, + .payload = payload, + } }; + }, + .ping => { + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + return .ping; + }, + .pong => { + std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + return .pong; + }, + else => |msg| std.debug.panic("Not implemented: {}\n", .{msg}), + } + } }; fn parseJsonMessage(T: type, alloc: std.mem.Allocator, in: *std.Io.Reader) !T { @@ -104,7 +217,6 @@ fn parseJsonMessage(T: type, alloc: std.mem.Allocator, in: *std.Io.Reader) !T { fn parsePub(in: *std.Io.Reader) !Message.Pub { const subject = (try in.takeDelimiter(' ')) orelse return error.EndOfStream; const next = (try in.takeDelimiter('\r')) orelse return error.EndOfStream; - std.debug.print("next: '{s}'\n", .{next}); var reply_to: ?[]const u8 = null; const bytes = std.fmt.parseUnsigned(usize, next, 10) catch blk: { reply_to = next; @@ -134,16 +246,15 @@ pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { len = i + 1; } else break; } - std.debug.print("word: '{s}'\n", .{word[0..len]}); - break :blk MessageType.parse(word[0..len]) orelse return null; + + break :blk Message.parse(word[0..len]) orelse return null; }; - std.debug.print("buffered: '{s}'\n", .{in.buffered()}); + // defer in.toss(2); // CRLF return switch (message_type) { .connect => blk: { const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }; - std.debug.print("value: {s}\n", .{value.?.connect.name.?}); - std.debug.print("buffered: '{d}'\n", .{in.buffered().len}); + break :blk value; }, .@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) }, @@ -152,47 +263,6 @@ pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { }; } -// test parseNextMessage { -// { -// const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\naftertheendoftheinput"; -// var reader: std.Io.Reader = .fixed(input); -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); -// const gpa = arena.allocator(); -// const msg: ?Message = parseNextMessage(gpa, &reader); -// const expected: ?Message = .{ .connect = .{ -// .verbose = false, -// .pedantic = false, -// .tls_required = false, -// .name = "NATS CLI Version v0.2.4", -// .lang = "go", -// .version = "1.43.0", -// .protocol = 1, -// .echo = true, -// .headers = true, -// .no_responders = true, -// } }; -// try std.testing.expect(msg != null); -// try std.testing.expectEqualDeep(msg, expected); -// } -// { -// const input = "PUB hi 3\r\nfoo\r\n"; -// var reader: std.Io.Reader = .fixed(input); -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); -// const gpa = arena.allocator(); -// const msg: ?Message = parseNextMessage(gpa, &reader); -// std.debug.print("msg: {any}\n", .{msg}); -// const expected: ?Message = .{ .@"pub" = .{ -// .subject = "hi", -// .bytes = 3, -// .payload = "foo", -// } }; -// try std.testing.expect(msg != null); -// try std.testing.expectEqualDeep(msg, expected); -// } -// } - test parseNextMessage { const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n"; var reader: std.Io.Reader = .fixed(input); @@ -201,8 +271,8 @@ test parseNextMessage { const gpa = arena.allocator(); { - const msg: ?Message = parseNextMessage(gpa, &reader); - const expected: ?Message = .{ .connect = .{ + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ .connect = .{ .verbose = false, .pedantic = false, .tls_required = false, @@ -214,24 +284,21 @@ test parseNextMessage { .headers = true, .no_responders = true, } }; - try std.testing.expect(msg != null); - try std.testing.expectEqualDeep(msg, expected); + + try std.testing.expectEqualDeep(expected, msg); } { - const msg: ?Message = parseNextMessage(gpa, &reader); - std.debug.print("msg: {any}\n", .{msg}); - const expected: ?Message = .{ .@"pub" = .{ + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ .@"pub" = .{ .subject = "hi", - .bytes = 3, .payload = "foo", } }; - try std.testing.expect(msg != null); - try std.testing.expectEqualDeep(msg, expected); + try std.testing.expectEqualDeep(expected, msg); } } -test "MessageType.parse performance" { - // Measure perf for parseMemEql - // Measure perf for parseStaticStringMap - // assert parse = fastest perf -} +// test "MessageType.parse performance" { +// // Measure perf for parseMemEql +// // Measure perf for parseStaticStringMap +// // assert parse = fastest perf +// } -- cgit