From 51008cd7e17d7e30b43107140781a72f10a58830 Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Tue, 18 Nov 2025 13:49:39 -0500 Subject: --- src/main.zig | 92 ++++------------------- src/root.zig | 24 +----- src/server/client.zig | 18 +++++ src/server/message_parser.zig | 168 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 101 deletions(-) create mode 100644 src/server/client.zig create mode 100644 src/server/message_parser.zig diff --git a/src/main.zig b/src/main.zig index 6ebd11e..bcb2d98 100644 --- a/src/main.zig +++ b/src/main.zig @@ -2,6 +2,8 @@ const std = @import("std"); const zits = @import("zits"); const clap = @import("clap"); +const MessageType = zits.MessageParser.MessageType; + const SubCommands = enum { help, serve, @@ -135,12 +137,12 @@ fn handleConnection(io: std.Io, stream: std.Io.net.Stream, info: ServerInfo) voi var writer = stream.writer(io, &w_buffer); const out = &writer.interface; - var r_buffer: [1024]u8 = undefined; + var r_buffer: [8192]u8 = undefined; var reader = stream.reader(io, &r_buffer); const in = &reader.interface; processClient(in, out, info) catch |err| { - std.debug.print("Error processing client: {}\n", .{err}); + std.debug.panic("Error processing client: {}\n", .{err}); }; // var stdout_buffer: [1024]u8 = undefined; @@ -171,102 +173,36 @@ fn handleConnection(io: std.Io, stream: std.Io.net.Stream, info: ServerInfo) voi fn processClient(in: *std.Io.Reader, out: *std.Io.Writer, info: ServerInfo) !void { try writeInfo(out, info); - const ClientState = struct { - verbose: bool = false, - pedantic: bool = false, - tls_required: bool = false, - auth_token: ?[]const u8 = null, - user: ?[]const u8 = null, - pass: ?[]const u8 = null, - name: ?[]const u8 = null, - lang: []const u8, - version: []const u8, - protocol: u32, - echo: ?bool = null, - sig: ?[]const u8 = null, - jwt: ?[]const u8 = null, - no_responders: ?bool = null, - headers: ?bool = null, - nkey: ?[]const u8 = null, - }; - - const MessageType = enum { - info, - connect, - @"pub", - hpub, - sub, - unsub, - msg, - hmsg, - ping, - pong, - @"+ok", - @"-err", - - // fn parse(input: []u8) !MessageType { - // // if (std.mem.eql(u8, "INFO", input)) return .info; - // if (std.mem.eql(u8, "CONNECT", input)) return .connect; - // if (std.mem.eql(u8, "PUB", input)) return .@"pub"; - // if (std.mem.eql(u8, "HPUB", input)) return .hpub; - // if (std.mem.eql(u8, "SUB", input)) return .sub; - // if (std.mem.eql(u8, "UNSUB", input)) return .unsub; - // // if (std.mem.eql(u8, "MSG", input)) return .msg; - // // if (std.mem.eql(u8, "HMSG", input)) return .hmsg; - // if (std.mem.eql(u8, "PING", input)) return .ping; - // if (std.mem.eql(u8, "PONG", input)) return .pong; - // // if (std.mem.eql(u8, "@"+OK"", input)) return .@"+ok"; - // // if (std.mem.eql(u8, "@"-ERR"", input)) return .@"-err"; - // return error.InvalidMessageType; - // } - - const client_types = std.StaticStringMap(@This()).initComptime( - .{ - // {"INFO", .info}, - .{ "CONNECT", .connect }, - .{ "PUB", .@"pub" }, - .{ "HPUB", .hpub }, - .{ "SUB", .sub }, - .{ "UNSUB", .unsub }, - // {"MSG", .msg}, - // {"HMSG", .hmsg}, - .{ "PING", .ping }, - .{ "PONG", .pong }, - // {"+OK", .@"+ok"}, - // {"-ERR", .@"-err"}, - }, - ); - fn parse(input: []u8) !@This() { - return client_types.get(input) orelse return error.InvalidMessageType; - } - }; - - const initial_message_type = try MessageType.parse((in.takeDelimiter(' ') catch return error.InvalidMessageType) orelse return error.InvalidMessageType); + const initial_message_type = MessageType.parse((in.takeDelimiter(' ') catch return error.InvalidMessageType) orelse "") orelse return error.InvalidMessageType; if (initial_message_type != .connect) return error.InvalidMessageType; // move this inside client_state declaration - var json_parse_buf: [1024]u8 = undefined; + var json_parse_buf: [4096]u8 = undefined; var json_parse_alloc_fb: std.heap.FixedBufferAllocator = std.heap.FixedBufferAllocator.init(&json_parse_buf); var json_parse_alloc = json_parse_alloc_fb.allocator(); var json_reader: std.json.Reader = .init(json_parse_alloc, in); std.debug.print("buffered:{s}\n", .{in.buffered()}); - var client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); + // var client_state = try std.json.parseFromSliceLeaky(ClientState, json_parse_alloc, in.buffered(), .{}); + // in.toss(in.buffered().len); + + // var client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); - std.debug.print("client_state: {any}\n", .{client_state}); + const client_state = 0; + std.debug.print("client_state: {}\n", .{client_state}); while (true) { // Rebase the next message to the start of the buffer // in.rebase(in.buffer.len); - const next_message_type = try MessageType.parse((in.takeDelimiter(' ') catch return error.InvalidMessageType) orelse return error.InvalidMessageType); + const next_message_type = MessageType.parse((in.takeDelimiter(' ') catch return error.InvalidMessageType) orelse "") orelse return error.InvalidMessageType; switch (next_message_type) { .connect => { json_parse_alloc_fb = std.heap.FixedBufferAllocator.init(&json_parse_buf); json_parse_alloc = json_parse_alloc_fb.allocator(); json_reader = .init(json_parse_alloc, in); - client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); + // client_state = try std.json.parseFromTokenSourceLeaky(ClientState, json_parse_alloc, &json_reader, .{}); std.debug.print("client_state: {any}\n", .{client_state}); }, else => |msg| std.debug.print("received {}\n", .{msg}), diff --git a/src/root.zig b/src/root.zig index 94c7cd0..12ebb2b 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,23 +1 @@ -//! By convention, root.zig is the root source file when making a library. -const std = @import("std"); - -pub fn bufferedPrint() !void { - // Stdout is for the actual output of your application, for example if you - // are implementing gzip, then only the compressed bytes should be sent to - // stdout, not any debugging messages. - var stdout_buffer: [1024]u8 = undefined; - var stdout_writer = std.fs.File.stdout().writer(&stdout_buffer); - const stdout = &stdout_writer.interface; - - try stdout.print("Run `zig build test` to run the tests.\n", .{}); - - try stdout.flush(); // Don't forget to flush! -} - -pub fn add(a: i32, b: i32) i32 { - return a + b; -} - -test "basic add functionality" { - try std.testing.expect(add(3, 7) == 10); -} +pub const MessageParser = @import("server/message_parser.zig"); diff --git a/src/server/client.zig b/src/server/client.zig new file mode 100644 index 0000000..8b49b89 --- /dev/null +++ b/src/server/client.zig @@ -0,0 +1,18 @@ +const ClientState = struct { + verbose: bool = false, + pedantic: bool = false, + tls_required: bool = false, + auth_token: ?[]const u8 = null, + user: ?[]const u8 = null, + pass: ?[]const u8 = null, + name: ?[]const u8 = null, + lang: []const u8, + version: []const u8, + protocol: u32, + echo: ?bool = null, + sig: ?[]const u8 = null, + jwt: ?[]const u8 = null, + no_responders: ?bool = null, + headers: ?bool = null, + nkey: ?[]const u8 = null, +}; diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig new file mode 100644 index 0000000..c8359d5 --- /dev/null +++ b/src/server/message_parser.zig @@ -0,0 +1,168 @@ +const std = @import("std"); + +pub const MessageType = enum { + info, + connect, + @"pub", + hpub, + sub, + unsub, + msg, + hmsg, + ping, + pong, + @"+ok", + @"-err", + + fn parseMemEql(input: []const u8) ?MessageType { + // if (std.mem.eql(u8, "INFO", input)) return .info; + if (std.mem.eql(u8, "CONNECT", input)) return .connect; + if (std.mem.eql(u8, "PUB", input)) return .@"pub"; + if (std.mem.eql(u8, "HPUB", input)) return .hpub; + if (std.mem.eql(u8, "SUB", input)) return .sub; + if (std.mem.eql(u8, "UNSUB", input)) return .unsub; + // if (std.mem.eql(u8, "MSG", input)) return .msg; + // if (std.mem.eql(u8, "HMSG", input)) return .hmsg; + if (std.mem.eql(u8, "PING", input)) return .ping; + if (std.mem.eql(u8, "PONG", input)) return .pong; + // if (std.mem.eql(u8, "@"+OK"", input)) return .@"+ok"; + // if (std.mem.eql(u8, "@"-ERR"", input)) return .@"-err"; + return error.InvalidMessageType; + } + + const client_types = std.StaticStringMap(MessageType).initComptime( + .{ + // {"INFO", .info}, + .{ "CONNECT", .connect }, + .{ "PUB", .@"pub" }, + .{ "HPUB", .hpub }, + .{ "SUB", .sub }, + .{ "UNSUB", .unsub }, + // {"MSG", .msg}, + // {"HMSG", .hmsg}, + .{ "PING", .ping }, + .{ "PONG", .pong }, + // {"+OK", .@"+ok"}, + // {"-ERR", .@"-err"}, + }, + ); + fn parseStaticStringMap(input: []const u8) ?MessageType { + std.debug.print("input: '{s}'\n", .{input}); + return client_types.get(input); + } + + pub const parse = parseStaticStringMap; +}; + +const Message = union(MessageType) { + info: void, + + connect: Connect, + @"pub": Pub, + hpub: void, + sub: void, + unsub: void, + msg: void, + hmsg: void, + ping, + pong, + @"+ok": void, + @"-err": void, + const Connect = struct { + verbose: bool = false, + pedantic: bool = false, + tls_required: bool = false, + auth_token: ?[]const u8 = null, + user: ?[]const u8 = null, + pass: ?[]const u8 = null, + name: ?[]const u8 = null, + lang: []const u8, + version: []const u8, + protocol: u32, + echo: ?bool = null, + sig: ?[]const u8 = null, + jwt: ?[]const u8 = null, + no_responders: ?bool = null, + headers: ?bool = null, + nkey: ?[]const u8 = null, + }; + const Pub = struct { + subject: []const u8, + reply_to: ?[]const u8, + bytes: usize, + payload: []const u8, + }; +}; + +fn parseJsonMessage(T: type, alloc: std.mem.Allocator, in: *std.Io.Reader) !T { + var json_reader: std.json.Reader = .init(alloc, in); + defer json_reader.deinit(); + + return std.json.parseFromTokenSourceLeaky(T, alloc, &json_reader, .{}); +} + +fn parsePub(in: *std.Io.Reader) !Message.Pub { + const subject = (try in.takeDelimiter(' ')) orelse return error.EndOfStream; + const next = (try in.takeDelimiter(' ')) orelse return error.EndOfStream; + var reply_to: ?[]const u8 = null; + const bytes = std.fmt.parseUnsigned(usize, next, 10) catch blk: { + reply_to = next; + break :blk try std.fmt.parseUnsigned(usize, (try in.takeDelimiter(' ')) orelse return error.EndOfStream, 10); + }; + // in.toss(2); // CRLF + const payload = try in.take(bytes); + + return .{ + .subject = subject, + .reply_to = reply_to, + .bytes = bytes, + .payload = payload, + }; +} + +// try returning error in debug mode, only null in release? +pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { + const message_type: MessageType = blk: { + const word: []const u8 = (in.takeDelimiter(' ') catch return null) orelse return null; + std.debug.print("word: {s}\n", .{word}); + break :blk MessageType.parse(word) orelse return null; + }; + // defer in.toss(2); // CRLF + return switch (message_type) { + .connect => .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }, + .@"pub" => .{ .@"pub" = parsePub(in) catch return null }, + .ping => .{ .ping = {} }, + else => null, + }; +} + +test parseNextMessage { + const input = + \\CONNECT {"verbose":false,"pedantic":false,"tls_required":false,"name":"NATS CLI Version v0.2.4","lang":"go","version":"1.43.0","protocol":1,"echo":true,"headers":true,"no_responders":true} + ; + var reader: std.Io.Reader = .fixed(input); + var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); + defer arena.deinit(); + const gpa = arena.allocator(); + const msg: ?Message = parseNextMessage(gpa, &reader); + const expected: ?Message = .{ .connect = .{ + .verbose = false, + .pedantic = false, + .tls_required = false, + .name = "NATS CLI Version v0.2.4", + .lang = "go", + .version = "1.43.0", + .protocol = 1, + .echo = true, + .headers = true, + .no_responders = true, + } }; + try std.testing.expect(msg != null); + try std.testing.expectEqualDeep(msg, expected); +} + +test "MessageType.parse performance" { + // Measure perf for parseMemEql + // Measure perf for parseStaticStringMap + // assert parse = fastest perf +} -- cgit