summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobby Zambito <contact@robbyzambito.me>2026-01-12 19:18:45 -0500
committerRobby Zambito <contact@robbyzambito.me>2026-01-12 22:21:48 -0500
commit1ee0f9263cd8e2984237ff34ae4625e8aaf2680c (patch)
treee91d4fa1f6c890f02cf9fbaebe3912e3774777d2
parentc5ad98adc6ebea6627fc08c2c16324610a8a97e0 (diff)
Properly handle large messagesHEADdev
add max bytes setting
-rw-r--r--src/Server.zig54
-rw-r--r--src/Server/parse.zig109
-rw-r--r--src/subcommand/serve.zig23
3 files changed, 148 insertions, 38 deletions
diff --git a/src/Server.zig b/src/Server.zig
index e0aca28..0259d8f 100644
--- a/src/Server.zig
+++ b/src/Server.zig
@@ -39,8 +39,17 @@ const Subscription = struct {
if (self.queue_group) |g| alloc.free(g);
}
- fn send(self: *Subscription, io: Io, buf: []u8, bytes: []const []const u8) !void {
- var w: std.Io.Writer = .fixed(buf);
+ fn send(self: *Subscription, io: Io, hot_buf: *align(std.atomic.cache_line) [512]u8, buf: []u8, bytes: []const []const u8) !void {
+ const total_len = blk: {
+ var total_len: usize = 0;
+ for (bytes) |chunk| {
+ total_len += chunk.len;
+ }
+ break :blk total_len;
+ };
+ log.debug("Payload len: {d}", .{bytes[bytes.len - 1].len});
+ var w: std.Io.Writer = .fixed(if (total_len <= hot_buf.len) hot_buf else buf);
+ log.debug("Using buffer size: {d}", .{w.buffer.len});
for (bytes) |chunk| {
w.writeAll(chunk) catch unreachable;
}
@@ -171,8 +180,7 @@ fn handleConnection(
const out = &writer.interface;
// Set up client reader
- _ = r_buf_size;
- const r_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 64 * 1024 * 1024);
+ const r_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), r_buf_size);
defer alloc.free(r_buffer);
var reader = stream.reader(io, r_buffer);
const in = &reader.interface;
@@ -183,7 +191,8 @@ fn handleConnection(
var recv_queue: Queue(u8) = .init(qbuf);
defer recv_queue.close(io);
- const msg_write_buf: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 1 * 1024 * 1024);
+ var hot_msg_buf: [std.atomic.cache_line * 4]u8 align(std.atomic.cache_line) = undefined;
+ const msg_write_buf: []u8 = try alloc.alloc(u8, server.info.max_payload);
defer alloc.free(msg_write_buf);
// Create client
@@ -212,7 +221,15 @@ fn handleConnection(
.PUB => {
@branchHint(.likely);
const before = try clock.now(io);
- server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .@"pub") catch |err| switch (err) {
+ server.publishMessage(
+ io,
+ rand,
+ server_allocator,
+ &hot_msg_buf,
+ msg_write_buf,
+ &client,
+ .@"pub",
+ ) catch |err| switch (err) {
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
@@ -225,7 +242,15 @@ fn handleConnection(
.HPUB => {
@branchHint(.likely);
const before = try clock.now(io);
- server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .hpub) catch |err| switch (err) {
+ server.publishMessage(
+ io,
+ rand,
+ server_allocator,
+ &hot_msg_buf,
+ msg_write_buf,
+ &client,
+ .hpub,
+ ) catch |err| switch (err) {
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
@@ -312,6 +337,7 @@ fn publishMessage(
io: Io,
rand: std.Random,
alloc: Allocator,
+ hot_write_buf: *align(std.atomic.cache_line) [512]u8,
msg_write_buf: []u8,
source_client: *Client,
comptime pub_or_hpub: enum { @"pub", hpub },
@@ -325,13 +351,16 @@ fn publishMessage(
}
};
+ var big_msg_arena_allocator: std.heap.ArenaAllocator = .init(alloc);
+ defer big_msg_arena_allocator.deinit();
+
const hpubmsg = switch (pub_or_hpub) {
.@"pub" => {},
- .hpub => try parse.hpub(source_client.from_client),
+ .hpub => try parse.hpub(source_client.from_client, &big_msg_arena_allocator),
};
const msg: Message.Pub = switch (pub_or_hpub) {
- .@"pub" => try parse.@"pub"(source_client.from_client),
+ .@"pub" => try parse.@"pub"(source_client.from_client, &big_msg_arena_allocator),
.hpub => hpubmsg.@"pub",
};
@@ -399,7 +428,12 @@ fn publishMessage(
);
msg_chunks.appendAssumeCapacity(msg.payload);
- subscription.send(io, msg_write_buf, msg_chunks.items[0..chunk_count]) catch |err| switch (err) {
+ subscription.send(
+ io,
+ hot_write_buf,
+ msg_write_buf,
+ msg_chunks.items[0..chunk_count],
+ ) catch |err| switch (err) {
error.Closed => {},
error.Canceled => |e| return e,
};
diff --git a/src/Server/parse.zig b/src/Server/parse.zig
index 6f8281b..f47b671 100644
--- a/src/Server/parse.zig
+++ b/src/Server/parse.zig
@@ -51,6 +51,12 @@ pub fn control(in: *Reader) Error!message.Control {
break :blk min_len;
};
std.debug.assert(in.buffer.len >= longest_ctrl);
+ if (in.seek == in.end) {
+ // If there is nothing in the read buffer, reset it to start from the beginning.
+ // This will minimize rebases.
+ in.seek = 0;
+ in.end = 0;
+ }
// Wait until at least the enough text to parse the shortest control value is available
try in.fill(3);
while (true) {
@@ -60,10 +66,10 @@ pub fn control(in: *Reader) Error!message.Control {
in.toss(str.len);
return ctrl;
} else if (str.len >= longest_ctrl) {
+ log.debug("ctrl too long: '{s}'\tbytes: {d}", .{ str, str.len });
return error.InvalidStream;
}
}
- log.debug("filling more in control.", .{});
try in.fillMore();
}
}
@@ -121,7 +127,9 @@ test control {
/// The return value is owned by the reader passed to this function.
/// Operations that modify the readers buffer invalidates this value.
-pub fn @"pub"(in: *Reader) Error!Message.Pub {
+/// The arena_allocator is used to store the payload if it can't fit
+/// in the readers buffer.
+pub fn @"pub"(in: *Reader, arena_allocator: *std.heap.ArenaAllocator) (error{OutOfMemory} || Error)!Message.Pub {
// TODO: Add pedantic option.
// See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1
@@ -140,29 +148,53 @@ pub fn @"pub"(in: *Reader) Error!Message.Pub {
continue;
}
if (in.buffered()[iter.index] == '\r') {
- const bytes = parseUnsigned(usize, second, 10) catch return error.InvalidStream;
- log.debug("received len: {d}", .{in.buffered().len});
- log.debug("headers len: {d}\tbytes: {d}", .{ iter.index, bytes });
- log.debug("buffer len: {d}", .{in.buffer.len});
- if (in.buffered().len < iter.index + bytes + "\r\n".len + "\r\n".len) {
- try in.fill(iter.index + bytes + "\r\n".len + "\r\n".len);
- continue;
- }
- in.toss(iter.index + "\r\n".len);
- return .{
- .subject = subject,
- .reply_to = null,
- .payload = in.take(bytes + 2) catch unreachable,
+ const bytes = parseUnsigned(usize, second, 10) catch {
+ log.debug("pub can't parse bytes: '{s}'", .{second});
+ return error.InvalidStream;
};
+ // if we can fit the payload and the headers in our read buffer
+ // reference the read buffer.
+ if (in.buffer.len > iter.index + bytes + "\r\n".len + "\r\n".len) {
+ // TODO: Can we use >=?
+ if (in.buffered().len < iter.index + bytes + "\r\n".len + "\r\n".len) {
+ try in.fill(iter.index + bytes + "\r\n".len + "\r\n".len);
+ continue;
+ }
+ in.toss(iter.index + "\r\n".len);
+ return .{
+ .subject = subject,
+ .reply_to = null,
+ .payload = in.take(bytes + 2) catch unreachable,
+ };
+ } else {
+ // else alloc the payload
+ const alloc = arena_allocator.allocator();
+ // We have to dupe the subject because we will not retain it in the read buffer
+ // as we accumulate the payload.
+ const subject_alloc = try alloc.dupe(u8, subject);
+ in.toss(iter.index + "\r\n".len);
+ const payload = try in.readAlloc(alloc, bytes + 2);
+ return .{
+ .subject = subject_alloc,
+ .reply_to = null,
+ .payload = payload,
+ };
+ }
+
+ if (in.end - in.seek - "\r\n".len > bytes) {} else {}
}
switch (in.buffered()[iter.index]) {
'\t', ' ' => {
const reply_to = second;
- const bytes = parseUnsigned(usize, iter.next() orelse {
+ const third = iter.next() orelse {
try in.fillMore();
continue;
- }, 10) catch return error.InvalidStream;
+ };
+ const bytes = parseUnsigned(usize, third, 10) catch {
+ log.debug("pub can't parse bytes (with reply): '{s}'", .{third});
+ return error.InvalidStream;
+ };
if (in.buffered().len == iter.index or in.buffered()[iter.index] != '\r') {
try in.fillMore();
@@ -417,7 +449,10 @@ pub fn unsub(in: *Reader) Error!Message.Unsub {
}
if (iter.next()) |max_msgs_str| {
if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
- const max_msgs = parseUnsigned(usize, max_msgs_str, 10) catch return error.InvalidStream;
+ const max_msgs = parseUnsigned(usize, max_msgs_str, 10) catch {
+ log.debug("unsub can't parse max_msgs: '{s}'", .{max_msgs_str});
+ return error.InvalidStream;
+ };
if (in.buffered().len < iter.index + 2) {
try in.fill(iter.index + 2);
@@ -443,7 +478,10 @@ pub fn unsub(in: *Reader) Error!Message.Unsub {
const sid = iter.next() orelse return error.EndOfStream;
const max_msgs = if (iter.next()) |max_msgs_str| blk: {
log.debug("max_msgs: {any}", .{max_msgs_str});
- break :blk parseUnsigned(usize, max_msgs_str, 10) catch return error.InvalidStream;
+ break :blk parseUnsigned(usize, max_msgs_str, 10) catch {
+ log.debug("unsub can't parse bytes (eos): '{s}'", .{max_msgs_str});
+ return error.InvalidStream;
+ };
} else null;
return .{
.sid = sid,
@@ -529,7 +567,10 @@ test unsub {
/// The return value is owned by the reader passed to this function.
/// Operations that modify the readers buffer invalidates this value.
-pub fn hpub(in: *Reader) Error!Message.HPub {
+/// The arena_allocator is used to store the payload if it can't fit
+/// in the readers buffer.
+pub fn hpub(in: *Reader, arena_allocator: *std.heap.ArenaAllocator) (error{OutOfMemory} || Error)!Message.HPub {
+ _ = arena_allocator;
// TODO: Add pedantic option.
// See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1
while (true) {
@@ -543,8 +584,14 @@ pub fn hpub(in: *Reader) Error!Message.HPub {
const header_bytes_str = second;
const total_bytes_str = third;
- const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch return error.InvalidStream;
- const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch return error.InvalidStream;
+ const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch {
+ log.debug("hpub can't parse header bytes: '{s}'", .{header_bytes_str});
+ return error.InvalidStream;
+ };
+ const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch {
+ log.debug("hpub can't parse total bytes: '{s}'", .{header_bytes_str});
+ return error.InvalidStream;
+ };
if (in.buffered().len < iter.index + total_bytes + 4) {
try in.fill(iter.index + total_bytes + 4);
@@ -568,8 +615,14 @@ pub fn hpub(in: *Reader) Error!Message.HPub {
const header_bytes_str = third;
if (iter.next()) |total_bytes_str| {
if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
- const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch return error.InvalidStream;
- const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch return error.InvalidStream;
+ const header_bytes = parseUnsigned(usize, header_bytes_str, 10) catch {
+ log.debug("hpub can't parse header bytes (with reply): '{s}'", .{header_bytes_str});
+ return error.InvalidStream;
+ };
+ const total_bytes = parseUnsigned(usize, total_bytes_str, 10) catch {
+ log.debug("hpub can't parse total bytes (with reply): '{s}'", .{header_bytes_str});
+ return error.InvalidStream;
+ };
if (in.buffered().len < iter.index + total_bytes + 4) {
try in.fill(iter.index + total_bytes + 4);
@@ -664,14 +717,18 @@ pub fn connect(alloc: Allocator, in: *Reader) (error{OutOfMemory} || Error)!Mess
connect_allocator,
connect_str,
.{ .allocate = .alloc_always },
- ) catch return error.InvalidStream;
+ ) catch {
+ log.debug("connect can't parse json body: '{s}'", .{connect_str});
+ return error.InvalidStream;
+ };
return res.dupe(alloc);
}
inline fn expectStreamBytes(reader: *Reader, expected: []const u8) !void {
if (!std.mem.eql(u8, try reader.take(expected.len), expected)) {
- @branchHint(.unlikely);
+ @branchHint(.cold);
+ log.debug("expectStreamBytes wrong bytes", .{});
return error.InvalidStream;
}
}
diff --git a/src/subcommand/serve.zig b/src/subcommand/serve.zig
index e6bb648..9cb35bb 100644
--- a/src/subcommand/serve.zig
+++ b/src/subcommand/serve.zig
@@ -30,7 +30,7 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void {
.server_id = Server.default_id,
.server_name = Server.default_name,
.version = "zits-master",
- .max_payload = 1048576,
+ .max_payload = 1 * 1024 * 1024,
.headers = true,
};
@@ -54,6 +54,18 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void {
return;
}
},
+ .max_bytes => {
+ i += 1;
+ if (args.len > i) {
+ server_config.max_payload = std.fmt.parseUnsigned(usize, args[i], 10) catch {
+ std.log.err("Could not parse max bytes: {s}", .{args[i]});
+ return;
+ };
+ } else {
+ std.log.err("Must specify max bytes with {s}", .{args[i - 1]});
+ return;
+ }
+ },
}
}
}
@@ -90,16 +102,23 @@ pub fn main(alloc: Allocator, outer_io: Io, args: []const [:0]const u8) !void {
std.log.info("Goodbye", .{});
}
-const help = "serve help\n";
+const help =
+ \\--port/-p <n> Specify port
+ \\--help/-h Get help.
+ \\--max-bytes <n> Specify max message byte size
+ \\
+;
const to_flag: std.StaticStringMap(Flag) = .initComptime(.{
.{ "-p", .port },
.{ "--port", .port },
.{ "-h", .help },
.{ "--help", .help },
+ .{ "--max-bytes", .max_bytes },
});
const Flag = enum {
port,
help,
+ max_bytes,
};