From 1ee0f9263cd8e2984237ff34ae4625e8aaf2680c Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Mon, 12 Jan 2026 19:18:45 -0500 Subject: Properly handle large messages add max bytes setting --- src/Server.zig | 54 ++++++++++++++++++----- src/Server/parse.zig | 109 ++++++++++++++++++++++++++++++++++++----------- src/subcommand/serve.zig | 23 +++++++++- 3 files changed, 148 insertions(+), 38 deletions(-) (limited to 'src') diff --git a/src/Server.zig b/src/Server.zig index e0aca28..0259d8f 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -39,8 +39,17 @@ const Subscription = struct { if (self.queue_group) |g| alloc.free(g); } - fn send(self: *Subscription, io: Io, buf: []u8, bytes: []const []const u8) !void { - var w: std.Io.Writer = .fixed(buf); + fn send(self: *Subscription, io: Io, hot_buf: *align(std.atomic.cache_line) [512]u8, buf: []u8, bytes: []const []const u8) !void { + const total_len = blk: { + var total_len: usize = 0; + for (bytes) |chunk| { + total_len += chunk.len; + } + break :blk total_len; + }; + log.debug("Payload len: {d}", .{bytes[bytes.len - 1].len}); + var w: std.Io.Writer = .fixed(if (total_len <= hot_buf.len) hot_buf else buf); + log.debug("Using buffer size: {d}", .{w.buffer.len}); for (bytes) |chunk| { w.writeAll(chunk) catch unreachable; } @@ -171,8 +180,7 @@ fn handleConnection( const out = &writer.interface; // Set up client reader - _ = r_buf_size; - const r_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 64 * 1024 * 1024); + const r_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), r_buf_size); defer alloc.free(r_buffer); var reader = stream.reader(io, r_buffer); const in = &reader.interface; @@ -183,7 +191,8 @@ fn handleConnection( var recv_queue: Queue(u8) = .init(qbuf); defer recv_queue.close(io); - const msg_write_buf: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 1 * 1024 * 1024); + var hot_msg_buf: [std.atomic.cache_line * 4]u8 align(std.atomic.cache_line) = undefined; + const msg_write_buf: []u8 = try alloc.alloc(u8, server.info.max_payload); defer alloc.free(msg_write_buf); // Create client @@ -212,7 +221,15 @@ fn handleConnection( .PUB => { @branchHint(.likely); const before = try clock.now(io); - server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .@"pub") catch |err| switch (err) { + server.publishMessage( + io, + rand, + server_allocator, + &hot_msg_buf, + msg_write_buf, + &client, + .@"pub", + ) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, @@ -225,7 +242,15 @@ fn handleConnection( .HPUB => { @branchHint(.likely); const before = try clock.now(io); - server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .hpub) catch |err| switch (err) { + server.publishMessage( + io, + rand, + server_allocator, + &hot_msg_buf, + msg_write_buf, + &client, + .hpub, + ) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, @@ -312,6 +337,7 @@ fn publishMessage( io: Io, rand: std.Random, alloc: Allocator, + hot_write_buf: *align(std.atomic.cache_line) [512]u8, msg_write_buf: []u8, source_client: *Client, comptime pub_or_hpub: enum { @"pub", hpub }, @@ -325,13 +351,16 @@ fn publishMessage( } }; + var big_msg_arena_allocator: std.heap.ArenaAllocator = .init(alloc); + defer big_msg_arena_allocator.deinit(); + const hpubmsg = switch (pub_or_hpub) { .@"pub" => {}, - .hpub => try parse.hpub(source_client.from_client), + .hpub => try parse.hpub(source_client.from_client, &big_msg_arena_allocator), }; const msg: Message.Pub = switch (pub_or_hpub) { - .@"pub" => try parse.@"pub"(source_client.from_client), + .@"pub" => try parse.@"pub"(source_client.from_client, &big_msg_arena_allocator), .hpub => hpubmsg.@"pub", }; @@ -399,7 +428,12 @@ fn publishMessage( ); msg_chunks.appendAssumeCapacity(msg.payload); - subscription.send(io, msg_write_buf, msg_chunks.items[0..chunk_count]) catch |err| switch (err) { + subscription.send( + io, + hot_write_buf, + msg_write_buf, + msg_chunks.items[0..chunk_count], + ) catch |err| switch (err) { error.Closed => {}, error.Canceled => |e| return e, }; diff --git a/src/Server/parse.zig b/src/Server/parse.zig index 6f8281b..f47b671 100644 --- a/src/Server/parse.zig +++ b/src/Server/parse.zig @@ -51,6 +51,12 @@ pub fn control(in: *Reader) Error!message.Control { break :blk min_len; }; std.debug.assert(in.buffer.len >= longest_ctrl); + if (in.seek == in.end) { + // If there is nothing in the read buffer, reset it to start from the beginning. + // This will minimize rebases. + in.seek = 0; + in.end = 0; + } // Wait until at least the enough text to parse the shortest control value is available try in.fill(3); while (true) { @@ -60,10 +66,10 @@ pub fn control(in: *Reader) Error!message.Control { in.toss(str.len); return ctrl; } else if (str.len >= longest_ctrl) { + log.debug("ctrl too long: '{s}'\tbytes: {d}", .{ str, str.len }); return error.InvalidStream; } } - log.debug("filling more in control.", .{}); try in.fillMore(); } } @@ -121,7 +127,9 @@ test control { /// The return value is owned by the reader passed to this function. /// Operations that modify the readers buffer invalidates this value. -pub fn @"pub"(in: *Reader) Error!Message.Pub { +/// The arena_allocator is used to store the payload if it can't fit +/// in the readers buffer. +pub fn @"pub"(in: *Reader, arena_allocator: *std.heap.ArenaAllocator) (error{OutOfMemory} || Error)!Message.Pub { // TODO: Add pedantic option. // See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1 @@ -140,29 +148,53 @@ pub fn @"pub"(in: *Reader) Error!Message.Pub { continue; } if (in.buffered()[iter.index] == '\r') { - const bytes = parseUnsigned(usize, second, 10) catch return error.InvalidStream; - log.debug("received len: {d}", .{in.buffered().len}); - log.debug("headers len: {d}\tbytes: {d}", .{ iter.index, bytes }); - log.debug("buffer len: {d}", .{in.buffer.len}); - if (in.buffered().len < iter.index + bytes + "\r\n".len + "\r\n".len) { - try in.fill(iter.index + bytes + "\r\n".len + "\r\n".len); - continue; - } - in.toss(iter.index + "\r\n".len); - return .{ - .subject = subject, - .reply_to = null, - .payload = in.take(bytes + 2) catch unreachable, + const bytes = parseUnsigned(usize, second, 10) catch { + log.debug("pub can't parse bytes: '{s}'", .{second}); + return error.InvalidStream; }; + // if we can fit the payload and the headers in our read buffer + // reference the read buffer. + if (in.buffer.len > iter.index + bytes + "\r\n".len + "\r\n".len) { + // TODO: Can we use >=? + if (in.buffered().len < iter.index + bytes + "\r\n".len + "\r\n".len) { + try in.fill(iter.index + bytes + "\r\n".len + "\r\n".len); + continue; + } + in.toss(iter.index + "\r\n".len); + return .{ + .subject = subject, + .reply_to = null, + .payload = in.take(bytes + 2) catch unreachable, + }; + } else { + // else alloc the payload + const alloc = arena_allocator.allocator(); + // We have to dupe the subject because we will not retain it in the read buffer + // as we accumulate the payload. + const subject_alloc = try alloc.dupe(u8, subject); + in.toss(iter.index + "\r\n".len); + const payload = try in.readAlloc(alloc, bytes + 2); + return .{ + .subject = subject_alloc, + .reply_to = null, + .payload = payload, + }; + } + + if (in.end - in.seek - "\r\n".len > bytes) {} else {} } switch (in.buffered()[iter.index]) { '\t', ' ' => { const reply_to = second; - const bytes = parseUnsigned(usize, iter.next() orelse { + const third = iter.next() orelse { try in.fillMore(); continue; - }, 10) catch return error.InvalidStream; + }; + const bytes = parseUnsigned(usize, third, 10) catch { + log.debug("pub can't parse bytes (with reply): '{s}'", .{third}); + return error.InvalidStream; + }; if (in.buffered().len == iter.index or in.buffered()[iter.index] != '\r') { try in.fillMore(); @@ -417,7 +449,10 @@ pub fn unsub(in: *Reader) Error!Message.Unsub { } if (iter.next()) |max_msgs_str| { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') { - const max_msgs = parseUnsigned(usize, max_msgs_str, 10) catch return error.InvalidStream; + const max_msgs = parseUnsigned(usize, max_msgs_str, 10) catch { + log.debug("unsub can't parse max_msgs: '{s}'", .{max_msgs_str}); + return error.InvalidStream; + }; if (in.buffered().len < iter.index + 2) { try in.fill(iter.index + 2); @@ -443,7 +478,10 @@ pub fn unsub(in: *Reader) Error!Message.Unsub { const sid = iter.next() orelse return error.EndOfStream; const max_msgs = if (iter.next()) |max_msgs_str| blk: { log.debug("max_msgs: {any}", .{max_msgs_str}); - break :blk parseUnsigned(usize, max_msgs_str, 10) catch return error.InvalidStream; + break :blk parseUnsigned(usize, max_msgs_str, 10) catch { + log.debug("unsub can't parse bytes (eos): '{s}'", .{max_msgs_str}); + return error.InvalidStream; + }; } else null; return .{ .sid = sid, @@ -529,7 +567,10 @@ test unsub { /// The return value is owned by the reader passed to this function. /// Operations that modify the readers buffer invalidates this value. -pub fn hpub(in: *Reader) Error!Message.HPub { +/// The arena_allocator is used to store the payload if it can't fit +/// in the readers buffer. +pub fn hpub(in: *Reader, arena_allocator: *std.heap.ArenaAllocator) (error{OutOfMemory} || Error)!Message.HPub { + _ = arena_allocator; // TODO: Add pedantic option. // See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1 while (true) { @@ -543,8 +584,14 @@ pub fn hpub(in: *Reader) Error!Message.HPub { const header_bytes_str = second; const total_bytes_str = third; - const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch return error.InvalidStream; - const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch return error.InvalidStream; + const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch { + log.debug("hpub can't parse header bytes: '{s}'", .{header_bytes_str}); + return error.InvalidStream; + }; + const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch { + log.debug("hpub can't parse total bytes: '{s}'", .{header_bytes_str}); + return error.InvalidStream; + }; if (in.buffered().len < iter.index + total_bytes + 4) { try in.fill(iter.index + total_bytes + 4); @@ -568,8 +615,14 @@ pub fn hpub(in: *Reader) Error!Message.HPub { const header_bytes_str = third; if (iter.next()) |total_bytes_str| { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') { - const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch return error.InvalidStream; - const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch return error.InvalidStream; + const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch { + log.debug("hpub can't parse header bytes (with reply): '{s}'", .{header_bytes_str}); + return error.InvalidStream; + }; + const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch { + log.debug("hpub can't parse total bytes (with reply): '{s}'", .{header_bytes_str}); + return error.InvalidStream; + }; if (in.buffered().len < iter.index + total_bytes + 4) { try in.fill(iter.index + total_bytes + 4); @@ -664,14 +717,18 @@ pub fn connect(alloc: Allocator, in: *Reader) (error{OutOfMemory} || Error)!Mess connect_allocator, connect_str, .{ .allocate = .alloc_always }, - ) catch return error.InvalidStream; + ) catch { + log.debug("connect can't parse json body: '{s}'", .{connect_str}); + return error.InvalidStream; + }; return res.dupe(alloc); } inline fn expectStreamBytes(reader: *Reader, expected: []const u8) !void { if (!std.mem.eql(u8, try reader.take(expected.len), expected)) { - @branchHint(.unlikely); + @branchHint(.cold); + log.debug("expectStreamBytes wrong bytes", .{}); return error.InvalidStream; } } diff --git a/src/subcommand/serve.zig b/src/subcommand/serve.zig index e6bb648..9cb35bb 100644 --- a/src/subcommand/serve.zig +++ b/src/subcommand/serve.zig @@ -30,7 +30,7 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void { .server_id = Server.default_id, .server_name = Server.default_name, .version = "zits-master", - .max_payload = 1048576, + .max_payload = 1 * 1024 * 1024, .headers = true, }; @@ -54,6 +54,18 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void { return; } }, + .max_bytes => { + i += 1; + if (args.len > i) { + server_config.max_payload = std.fmt.parseUnsigned(usize, args[i], 10) catch { + std.log.err("Could not parse max bytes: {s}", .{args[i]}); + return; + }; + } else { + std.log.err("Must specify max bytes with {s}", .{args[i - 1]}); + return; + } + }, } } } @@ -90,16 +102,23 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void { std.log.info("Goodbye", .{}); } -const help = "serve help\n"; +const help = + \\--port/-p Specify port + \\--help/-h Get help. + \\--max-bytes Specify max message byte size + \\ +; const to_flag: std.StaticStringMap(Flag) = .initComptime(.{ .{ "-p", .port }, .{ "--port", .port }, .{ "-h", .help }, .{ "--help", .help }, + .{ "--max-bytes", .max_bytes }, }); const Flag = enum { port, help, + max_bytes, }; -- cgit