From b87412ee66197d4c89f1fbf93b32fe63ed1c63ab Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Tue, 6 Jan 2026 18:45:17 -0500 Subject: Restructuring Add a bunch of tests for the client --- src/Server.zig | 422 +++++++++++++++ src/Server/Client.zig | 240 +++++++++ src/Server/message_parser.zig | 1141 +++++++++++++++++++++++++++++++++++++++++ src/main.zig | 26 +- src/root.zig | 4 +- src/server/Client.zig | 253 --------- src/server/Server.zig | 417 --------------- src/server/main.zig | 66 --- src/server/message_parser.zig | 1141 ----------------------------------------- src/subcommand/server.zig | 66 +++ 10 files changed, 1895 insertions(+), 1881 deletions(-) create mode 100644 src/Server.zig create mode 100644 src/Server/Client.zig create mode 100644 src/Server/message_parser.zig delete mode 100644 src/server/Client.zig delete mode 100644 src/server/Server.zig delete mode 100644 src/server/main.zig delete mode 100644 src/server/message_parser.zig create mode 100644 src/subcommand/server.zig diff --git a/src/Server.zig b/src/Server.zig new file mode 100644 index 0000000..e7d00b1 --- /dev/null +++ b/src/Server.zig @@ -0,0 +1,422 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ArrayList = std.ArrayList; +const AutoHashMapUnmanaged = std.AutoHashMapUnmanaged; + +const Io = std.Io; +const Dir = Io.Dir; +const Group = Io.Group; +const IpAddress = std.Io.net.IpAddress; +const Mutex = Io.Mutex; +const Queue = Io.Queue; +const Stream = std.Io.net.Stream; + +pub const Client = @import("./Server/Client.zig"); + +const message_parser = @import("./Server/message_parser.zig"); + +pub const MessageType = message_parser.MessageType; +pub const Message = message_parser.Message; +const ServerInfo = Message.ServerInfo; + +const Msgs = Client.Msgs; +const Server = @This(); + +const builtin = @import("builtin"); + +pub const Subscription = struct { + subject: []const u8, + client_id: usize, + sid: []const u8, + queue_group: ?[]const u8, + queue: *Queue(Msgs), + // used to alloc messages in the queue + alloc: Allocator, + + fn deinit(self: Subscription, alloc: Allocator) void { + alloc.free(self.subject); + alloc.free(self.sid); + if (self.queue_group) |g| alloc.free(g); + } +}; + +const eql = std.mem.eql; +const log = std.log.scoped(.zits); +const panic = std.debug.panic; + +info: ServerInfo, +clients: AutoHashMapUnmanaged(usize, *Client) = .empty, + +subs_lock: Mutex = .init, +subscriptions: ArrayList(Subscription) = .empty, + +pub fn deinit(server: *Server, io: Io, alloc: Allocator) void { + server.subs_lock.lockUncancelable(io); + defer server.subs_lock.unlock(io); + for (server.subscriptions.items) |sub| { + sub.deinit(alloc); + } + // TODO drain subscription queues + server.subscriptions.deinit(alloc); + server.clients.deinit(alloc); +} + +pub fn start(server: *Server, io: Io, gpa: Allocator) !void { + var tcp_server = try IpAddress.listen(try IpAddress.parse( + server.info.host, + server.info.port, + ), io, .{}); + defer tcp_server.deinit(io); + log.debug("Server headers: {s}", .{if (server.info.headers) "true" else "false"}); + log.debug("Server max payload: {d}", .{server.info.max_payload}); + log.info("Server ID: {s}", .{server.info.server_id}); + log.info("Server name: {s}", .{server.info.server_name}); + log.info("Server listening on {s}:{d}", .{ server.info.host, server.info.port }); + + var client_group: Group = .init; + defer client_group.cancel(io); + + const read_buffer_size, const write_buffer_size = getBufferSizes(io); + log.debug("read buf: {d} write buf: {d}", .{ read_buffer_size, write_buffer_size }); + + var id: usize = 0; + while (true) : (id +%= 1) { + if (server.clients.contains(id)) continue; + log.debug("Accepting next client", .{}); + const stream = try tcp_server.accept(io); + log.debug("Accepted connection {d}", .{id}); + _ = client_group.concurrent(io, handleConnectionInfallible, .{ + server, + gpa, + io, + id, + stream, + read_buffer_size, + write_buffer_size, + }) catch { + log.err("Could not start concurrent handler for {d}", .{id}); + stream.close(io); + }; + } +} + +fn addClient(server: *Server, allocator: Allocator, id: usize, client: *Client) !void { + try server.clients.put(allocator, id, client); +} + +fn removeClient(server: *Server, io: Io, allocator: Allocator, id: usize) void { + server.subs_lock.lockUncancelable(io); + defer server.subs_lock.unlock(io); + if (server.clients.remove(id)) { + const len = server.subscriptions.items.len; + for (0..len) |from_end| { + const i = len - from_end - 1; + const sub = server.subscriptions.items[i]; + if (sub.client_id == id) { + sub.deinit(allocator); + _ = server.subscriptions.swapRemove(i); + } + } + } +} + +fn handleConnectionInfallible( + server: *Server, + server_allocator: Allocator, + io: Io, + id: usize, + stream: Stream, + r_buf_size: usize, + w_buf_size: usize, +) !void { + handleConnection(server, server_allocator, io, id, stream, r_buf_size, w_buf_size) catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => log.err("Failed processing client {d}: {any}", .{ id, err }), + }; +} + +fn handleConnection( + server: *Server, + server_allocator: Allocator, + io: Io, + id: usize, + stream: Stream, + r_buf_size: usize, + w_buf_size: usize, +) !void { + defer stream.close(io); + + var dba: std.heap.DebugAllocator(.{}) = .init; + dba.backing_allocator = server_allocator; + defer _ = dba.deinit(); + const alloc = if (builtin.mode == .Debug or builtin.mode == .ReleaseSafe) + dba.allocator() + else + server_allocator; + + // Set up client writer + const w_buffer: []u8 = try alloc.alloc(u8, w_buf_size); + defer alloc.free(w_buffer); + var writer = stream.writer(io, w_buffer); + const out = &writer.interface; + + // Set up client reader + const r_buffer: []u8 = try alloc.alloc(u8, r_buf_size); + defer alloc.free(r_buffer); + var reader = stream.reader(io, r_buffer); + const in = &reader.interface; + + // Set up buffer queue + const qbuf: []Message = try alloc.alloc(Message, 16); + defer alloc.free(qbuf); + var recv_queue: Queue(Message) = .init(qbuf); + defer recv_queue.close(io); + + const mbuf: []Msgs = try alloc.alloc(Msgs, w_buf_size / @sizeOf(Msgs)); + defer alloc.free(mbuf); + var msgs_queue: Queue(Msgs) = .init(mbuf); + defer { + msgs_queue.close(io); + while (msgs_queue.getOne(io)) |msg| { + switch (msg) { + .MSG => |m| m.deinit(alloc), + .HMSG => |h| h.deinit(alloc), + } + } else |_| {} + } + + // Create client + var client: Client = .init(null, alloc, &recv_queue, &msgs_queue, in, out); + defer client.deinit(server_allocator); + + try server.addClient(server_allocator, id, &client); + defer server.removeClient(io, server_allocator, id); + + // Do initial handshake with client + // try recv_queue.putOne(io, .PONG); + try recv_queue.putOne(io, .{ .INFO = server.info }); + + var client_task = try io.concurrent(Client.start, .{ &client, io }); + defer client_task.cancel(io) catch {}; + + // Messages are owned by the server after they are received from the client + while (client.next(server_allocator)) |msg| { + switch (msg) { + .PING => { + // Respond to ping with pong. + try client.send(io, .PONG); + }, + .PUB => |pb| { + @branchHint(.likely); + defer pb.deinit(server_allocator); + try server.publishMessage(io, server_allocator, &client, msg); + }, + .HPUB => |hp| { + @branchHint(.likely); + defer hp.deinit(server_allocator); + try server.publishMessage(io, server_allocator, &client, msg); + }, + .SUB => |sub| { + defer sub.deinit(server_allocator); + try server.subscribe(io, server_allocator, client, id, sub); + }, + .UNSUB => |unsub| { + defer unsub.deinit(server_allocator); + try server.unsubscribe(io, server_allocator, id, unsub); + }, + .CONNECT => |connect| { + if (client.connect) |*current| { + current.deinit(server_allocator); + } + client.connect = connect; + }, + else => |e| { + panic("Unimplemented message: {any}\n", .{e}); + }, + } + } else |err| switch (err) { + error.EndOfStream, error.ReadFailed => { + log.debug("Client {d} disconnected", .{id}); + return error.Canceled; + }, + else => { + return err; + }, + } +} + +fn subjectMatches(sub_subject: []const u8, pub_subject: []const u8) bool { + // TODO: assert that sub_subject and pub_subject are valid. + var sub_iter = std.mem.splitScalar(u8, sub_subject, '.'); + var pub_iter = std.mem.splitScalar(u8, pub_subject, '.'); + + while (sub_iter.next()) |st| { + const pt = pub_iter.next() orelse return false; + + if (eql(u8, st, ">")) return true; + + if (!eql(u8, st, "*") and !eql(u8, st, pt)) { + return false; + } + } + + return pub_iter.next() == null; +} + +test subjectMatches { + const expect = std.testing.expect; + try expect(subjectMatches("foo", "foo")); + try expect(!subjectMatches("foo", "bar")); + + try expect(subjectMatches("foo.*", "foo.bar")); + try expect(!subjectMatches("foo.*", "foo")); + try expect(!subjectMatches("foo.>", "foo")); + + // the wildcard subscriptions foo.*.quux and foo.> both match foo.bar.quux, but only the latter matches foo.bar.baz. + try expect(subjectMatches("foo.*.quux", "foo.bar.quux")); + try expect(subjectMatches("foo.>", "foo.bar.quux")); + try expect(!subjectMatches("foo.*.quux", "foo.bar.baz")); + try expect(subjectMatches("foo.>", "foo.bar.baz")); +} + +fn publishMessage( + server: *Server, + io: Io, + alloc: Allocator, + source_client: *Client, + msg: Message, +) !void { + defer if (source_client.connect) |c| { + if (c.verbose) { + source_client.send(io, .@"+OK") catch {}; + } + }; + + const subject = switch (msg) { + .PUB => |pb| pb.subject, + .HPUB => |hp| hp.@"pub".subject, + else => unreachable, + }; + try server.subs_lock.lock(io); + defer server.subs_lock.unlock(io); + var published_queue_groups: ArrayList([]const u8) = .empty; + defer published_queue_groups.deinit(alloc); + var published_queue_sub_idxs: ArrayList(usize) = .empty; + defer published_queue_sub_idxs.deinit(alloc); + + subs: for (0..server.subscriptions.items.len) |i| { + const subscription = server.subscriptions.items[i]; + if (subjectMatches(subscription.subject, subject)) { + if (subscription.queue_group) |sg| { + for (published_queue_groups.items) |g| { + if (eql(u8, g, sg)) { + continue :subs; + } + } + // Don't republish to the same queue + try published_queue_groups.append(alloc, sg); + // Move this index to the end of the subscription list, + // to prioritize other subscriptions in the queue next time. + try published_queue_sub_idxs.append(alloc, i); + } + switch (msg) { + .PUB => |pb| { + try subscription.queue.putOne(io, .{ + .MSG = try pb.toMsg(subscription.alloc, subscription.sid), + }); + }, + .HPUB => |hp| { + try subscription.queue.putOne(io, .{ + .HMSG = try hp.toHMsg(subscription.alloc, subscription.sid), + }); + }, + else => unreachable, + } + } + } + + for (0..published_queue_sub_idxs.items.len) |from_end| { + const i = published_queue_sub_idxs.items.len - from_end - 1; + server.subscriptions.appendAssumeCapacity(server.subscriptions.orderedRemove(i)); + } +} + +fn subscribe( + server: *Server, + io: Io, + gpa: Allocator, + client: Client, + id: usize, + msg: Message.Sub, +) !void { + try server.subs_lock.lock(io); + defer server.subs_lock.unlock(io); + const subject = try gpa.dupe(u8, msg.subject); + errdefer gpa.free(subject); + const sid = try gpa.dupe(u8, msg.sid); + errdefer gpa.free(sid); + const queue_group = if (msg.queue_group) |q| try gpa.dupe(u8, q) else null; + errdefer if (queue_group) |q| gpa.free(q); + try server.subscriptions.append(gpa, .{ + .subject = subject, + .client_id = id, + .sid = sid, + .queue_group = queue_group, + .queue = client.msg_queue, + .alloc = client.alloc, + }); +} + +fn unsubscribe( + server: *Server, + io: Io, + gpa: Allocator, + id: usize, + msg: Message.Unsub, +) !void { + try server.subs_lock.lock(io); + defer server.subs_lock.unlock(io); + const len = server.subscriptions.items.len; + for (0..len) |from_end| { + const i = len - from_end - 1; + const sub = server.subscriptions.items[i]; + if (sub.client_id == id and eql(u8, sub.sid, msg.sid)) { + sub.deinit(gpa); + _ = server.subscriptions.swapRemove(i); + } + } +} + +const parseUnsigned = std.fmt.parseUnsigned; + +fn getBufferSizes(io: Io) struct { usize, usize } { + const default_size = 4 * 1024; + const default = .{ default_size, default_size }; + + const dir = Dir.openDirAbsolute(io, "/proc/sys/net/core", .{}) catch { + log.warn("couldn't open /proc/sys/net/core", .{}); + return default; + }; + + var buf: [64]u8 = undefined; + + const rmem_max = readBufferSize(io, dir, "rmem_max", &buf, default_size); + const wmem_max = readBufferSize(io, dir, "wmem_max", &buf, default_size); + + return .{ rmem_max, wmem_max }; +} + +fn readBufferSize(io: Io, dir: anytype, filename: []const u8, buf: []u8, default: usize) usize { + const bytes = dir.readFile(io, filename, buf) catch |err| { + log.err("couldn't open {s}: {any}", .{ filename, err }); + return default; + }; + + return parseUnsigned(usize, bytes[0 .. bytes.len - 1], 10) catch |err| { + log.err("couldn't parse {s}: {any}", .{ bytes[0 .. bytes.len - 1], err }); + return default; + }; +} + +pub const default_id = "server-id-123"; +pub const default_name = "Zits Server"; diff --git a/src/Server/Client.zig b/src/Server/Client.zig new file mode 100644 index 0000000..dff3534 --- /dev/null +++ b/src/Server/Client.zig @@ -0,0 +1,240 @@ +const Message = @import("message_parser.zig").Message; +const std = @import("std"); +const Queue = std.Io.Queue; + +const Client = @This(); + +pub const Msgs = union(enum) { + MSG: Message.Msg, + HMSG: Message.HMsg, +}; + +connect: ?Message.Connect, +// Used to own messages that we receive in our queues. +alloc: std.mem.Allocator, + +// Messages for this client to receive. +recv_queue: *Queue(Message), +msg_queue: *Queue(Msgs), + +from_client: *std.Io.Reader, +to_client: *std.Io.Writer, + +pub fn init( + connect: ?Message.Connect, + alloc: std.mem.Allocator, + recv_queue: *Queue(Message), + msg_queue: *Queue(Msgs), + in: *std.Io.Reader, + out: *std.Io.Writer, +) Client { + return .{ + .connect = connect, + .alloc = alloc, + .recv_queue = recv_queue, + .msg_queue = msg_queue, + .from_client = in, + .to_client = out, + }; +} + +pub fn deinit(self: *Client, alloc: std.mem.Allocator) void { + if (self.connect) |c| { + c.deinit(alloc); + } + self.* = undefined; +} + +pub fn start(self: *Client, io: std.Io) !void { + var msgs_buf: [1024]Msgs = undefined; + + var recv_msgs_task = io.concurrent(Queue(Msgs).get, .{ self.msg_queue, io, &msgs_buf, 1 }) catch @panic("Concurrency unavailable"); + errdefer _ = recv_msgs_task.cancel(io) catch {}; + + var recv_proto_task = io.concurrent(Queue(Message).getOne, .{ self.recv_queue, io }) catch unreachable; + errdefer _ = recv_proto_task.cancel(io) catch {}; + + while (true) { + switch (try io.select(.{ .msgs = &recv_msgs_task, .proto = &recv_proto_task })) { + .msgs => |len_err| { + @branchHint(.likely); + const msgs = msgs_buf[0..try len_err]; + for (0..msgs.len) |i| { + const msg = msgs[i]; + defer switch (msg) { + .MSG => |m| m.deinit(self.alloc), + .HMSG => |h| h.deinit(self.alloc), + }; + errdefer for (msgs[i + 1 ..]) |mg| switch (mg) { + .MSG => |m| { + m.deinit(self.alloc); + }, + .HMSG => |h| { + h.deinit(self.alloc); + }, + }; + switch (msg) { + .MSG => |m| { + try self.to_client.print( + "MSG {s} {s} {s} {d}\r\n", + .{ + m.subject, + m.sid, + m.reply_to orelse "", + m.payload.len, + }, + ); + try m.payload.write(self.to_client); + try self.to_client.print("\r\n", .{}); + }, + .HMSG => |hmsg| { + try self.to_client.print("HMSG {s} {s} {s} {d} {d}\r\n", .{ + hmsg.msg.subject, + hmsg.msg.sid, + hmsg.msg.reply_to orelse "", + hmsg.header_bytes, + hmsg.msg.payload.len, + }); + try hmsg.msg.payload.write(self.to_client); + try self.to_client.print("\r\n", .{}); + }, + } + } + recv_msgs_task = io.concurrent(Queue(Msgs).get, .{ self.msg_queue, io, &msgs_buf, 1 }) catch unreachable; + }, + .proto => |msg_err| { + @branchHint(.unlikely); + const msg = try msg_err; + switch (msg) { + .@"+OK" => { + _ = try self.to_client.write("+OK\r\n"); + }, + .PONG => { + _ = try self.to_client.write("PONG\r\n"); + }, + .INFO => |info| { + _ = try self.to_client.write("INFO "); + try std.json.Stringify.value(info, .{}, self.to_client); + _ = try self.to_client.write("\r\n"); + }, + .@"-ERR" => |s| { + _ = try self.to_client.print("-ERR '{s}'\r\n", .{s}); + }, + else => |m| { + std.debug.panic("unimplemented write: {any}\n", .{m}); + }, + } + recv_proto_task = io.concurrent(Queue(Message).getOne, .{ self.recv_queue, io }) catch unreachable; + }, + } + try self.to_client.flush(); + } +} + +pub fn send(self: *Client, io: std.Io, msg: Message) !void { + switch (msg) { + .MSG => |m| try self.msg_queue.putOne(io, .{ .MSG = m }), + .HMSG => |m| try self.msg_queue.putOne(io, .{ .HMSG = m }), + else => try self.recv_queue.putOne(io, msg), + } +} + +test send { + const io = std.testing.io; + const gpa = std.testing.allocator; + + var to_client: std.Io.Writer = .fixed(blk: { + var buf: [1024]u8 = undefined; + break :blk &buf; + }); + var recv_queue: Queue(Message) = .init(&.{}); + var msgs_queue: Queue(Msgs) = .init(blk: { + var buf: [1]Msgs = undefined; + break :blk &buf; + }); + var client: Client = .init(null, gpa, &recv_queue, &msgs_queue, undefined, &to_client); + defer client.deinit(gpa); + + var c_task = try io.concurrent(Client.start, .{ &client, io }); + defer c_task.cancel(io) catch {}; + + { + try client.send(io, .PONG); + // Wait for the concurrent client task to write to the writer + try io.sleep(.fromMilliseconds(1), .awake); + try std.testing.expectEqualSlices(u8, "PONG\r\n", to_client.buffered()); + } + + to_client.end = 0; + + { + const payload = "payload"; + const msg: Message.Msg = .{ + .sid = "1", + .subject = "subject", + .reply_to = "reply", + .payload = .{ + .len = payload.len, + .short = blk: { + var buf: [128]u8 = undefined; + @memcpy(buf[0..payload.len], payload); + break :blk buf; + }, + .long = null, + }, + }; + try client.send(io, .{ + // msg must be owned by the allocator the client uses + .MSG = try msg.dupe(gpa), + }); + try io.sleep(.fromMilliseconds(1), .awake); + try std.testing.expectEqualSlices(u8, "MSG subject 1 reply 7\r\npayload\r\n", to_client.buffered()); + } +} + +pub fn next(self: *Client, allocator: std.mem.Allocator) !Message { + return Message.next(allocator, self.from_client); +} + +test next { + const gpa = std.testing.allocator; + + var from_client: std.Io.Reader = .fixed( + "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_r" ++ + "equired\":false,\"name\":\"NATS CLI Version v0.2." ++ + "4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"prot" ++ + "ocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\n" ++ + "PING\r\n", + ); + + var client: Client = .init(null, undefined, undefined, undefined, &from_client, undefined); + + { + // Simulate stream + + { + const msg = try client.next(gpa); + try std.testing.expectEqual(.CONNECT, std.meta.activeTag(msg)); + defer msg.CONNECT.deinit(gpa); + try std.testing.expectEqualDeep(Message{ + .CONNECT = .{ + .verbose = false, + .pedantic = false, + .tls_required = false, + .name = "NATS CLI Version v0.2.4", + .lang = "go", + .version = "1.43.0", + .protocol = 1, + .echo = true, + .headers = true, + .no_responders = true, + }, + }, msg); + } + + { + const msg = try client.next(gpa); + try std.testing.expectEqual(.PING, std.meta.activeTag(msg)); + } + } +} diff --git a/src/Server/message_parser.zig b/src/Server/message_parser.zig new file mode 100644 index 0000000..fd1b5b1 --- /dev/null +++ b/src/Server/message_parser.zig @@ -0,0 +1,1141 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; +const ArrayList = std.ArrayList; +const StaticStringMap = std.StaticStringMap; + +const Io = std.Io; +const Writer = Io.Writer; +const AllocatingWriter = Writer.Allocating; +const Reader = Io.Reader; + +const ascii = std.ascii; +const isDigit = std.ascii.isDigit; +const isUpper = std.ascii.isUpper; +const isWhitespace = std.ascii.isWhitespace; + +const parseUnsigned = std.fmt.parseUnsigned; + +const log = std.log; + +pub const Payload = struct { + len: u32, + short: [128]u8, + long: ?[]u8, + + pub fn read(alloc: Allocator, in: *Reader, bytes: usize) !Payload { + var res: Payload = .{ + .len = @intCast(bytes), + .short = undefined, + .long = null, + }; + + try in.readSliceAll(res.short[0..@min(bytes, res.short.len)]); + if (bytes > res.short.len) { + const long = try alloc.alloc(u8, bytes - res.short.len); + errdefer alloc.free(long); + try in.readSliceAll(long); + res.long = long; + } + return res; + } + + pub fn write(self: Payload, out: *Writer) !void { + std.debug.assert(out.buffer.len >= self.short.len); + std.debug.assert(self.len <= self.short.len or self.long != null); + try out.writeAll(self.short[0..@min(self.len, self.short.len)]); + if (self.long) |l| { + try out.writeAll(l); + } + } + + pub fn deinit(self: Payload, alloc: Allocator) void { + if (self.long) |l| { + alloc.free(l); + } + } + + pub fn dupe(self: Payload, alloc: Allocator) !Payload { + var res = self; + if (self.long) |l| { + res.long = try alloc.dupe(u8, l); + } + errdefer if (res.long) |l| alloc.free(l); + return res; + } +}; + +pub const MessageType = @typeInfo(Message).@"union".tag_type.?; + +pub const Message = union(enum) { + INFO: ServerInfo, + CONNECT: Connect, + PUB: Pub, + HPUB: HPub, + SUB: Sub, + UNSUB: Unsub, + MSG: Msg, + HMSG: HMsg, + PING, + PONG, + @"+OK": void, + @"-ERR": []const u8, + pub const ServerInfo = struct { + /// The unique identifier of the NATS server. + server_id: []const u8, + /// The name of the NATS server. + server_name: []const u8, + /// The version of NATS. + version: []const u8, + /// The version of golang the NATS server was built with. + go: []const u8 = "0.0.0", + /// The IP address used to start the NATS server, + /// by default this will be 0.0.0.0 and can be + /// configured with -client_advertise host:port. + host: []const u8 = "0.0.0.0", + /// The port number the NATS server is configured + /// to listen on. + port: u16 = 4222, + /// Whether the server supports headers. + headers: bool = false, + /// Maximum payload size, in bytes, that the server + /// will accept from the client. + max_payload: u64, + /// An integer indicating the protocol version of + /// the server. The server version 1.2.0 sets this + /// to 1 to indicate that it supports the "Echo" + /// feature. + proto: u32 = 1, + }; + pub const Connect = struct { + verbose: bool = false, + pedantic: bool = false, + tls_required: bool = false, + auth_token: ?[]const u8 = null, + user: ?[]const u8 = null, + pass: ?[]const u8 = null, + name: ?[]const u8 = null, + lang: []const u8, + version: []const u8, + protocol: u32, + echo: ?bool = null, + sig: ?[]const u8 = null, + jwt: ?[]const u8 = null, + no_responders: ?bool = null, + headers: ?bool = null, + nkey: ?[]const u8 = null, + + pub fn deinit(self: Connect, alloc: Allocator) void { + if (self.auth_token) |a| alloc.free(a); + if (self.user) |u| alloc.free(u); + if (self.pass) |p| alloc.free(p); + if (self.name) |n| alloc.free(n); + alloc.free(self.lang); + alloc.free(self.version); + if (self.sig) |s| alloc.free(s); + if (self.jwt) |j| alloc.free(j); + if (self.nkey) |n| alloc.free(n); + } + + pub fn dupe(self: Connect, alloc: Allocator) !Connect { + var res = self; + res.auth_token = if (self.auth_token) |a| try alloc.dupe(u8, a) else null; + errdefer if (res.auth_token) |a| alloc.free(a); + res.user = if (self.user) |u| try alloc.dupe(u8, u) else null; + errdefer if (res.user) |u| alloc.free(u); + res.pass = if (self.pass) |p| try alloc.dupe(u8, p) else null; + errdefer if (res.pass) |p| alloc.free(p); + res.name = if (self.name) |n| try alloc.dupe(u8, n) else null; + errdefer if (res.name) |n| alloc.free(n); + res.lang = try alloc.dupe(u8, self.lang); + errdefer alloc.free(res.lang); + res.version = try alloc.dupe(u8, self.version); + errdefer alloc.free(res.version); + res.sig = if (self.sig) |s| try alloc.dupe(u8, s) else null; + errdefer if (res.sig) |s| alloc.free(s); + res.jwt = if (self.jwt) |j| try alloc.dupe(u8, j) else null; + errdefer if (res.jwt) |j| alloc.free(j); + res.nkey = if (self.nkey) |n| try alloc.dupe(u8, n) else null; + errdefer if (res.nkey) |n| alloc.free(n); + return res; + } + }; + pub const Pub = struct { + /// The destination subject to publish to. + subject: []const u8, + /// The reply subject that subscribers can use to send a response back to the publisher/requestor. + reply_to: ?[]const u8 = null, + /// The message payload data. + payload: Payload, + + pub fn deinit(self: Pub, alloc: Allocator) void { + alloc.free(self.subject); + self.payload.deinit(alloc); + if (self.reply_to) |r| alloc.free(r); + } + + pub fn toMsg(self: Pub, alloc: Allocator, sid: []const u8) !Msg { + const res: Msg = .{ + .subject = self.subject, + .sid = sid, + .reply_to = self.reply_to, + .payload = self.payload, + }; + return res.dupe(alloc); + } + }; + pub const HPub = struct { + header_bytes: usize, + @"pub": Pub, + + pub fn deinit(self: HPub, alloc: Allocator) void { + self.@"pub".deinit(alloc); + } + + pub fn toHMsg(self: HPub, alloc: Allocator, sid: []const u8) !HMsg { + return .{ + .header_bytes = self.header_bytes, + .msg = try self.@"pub".toMsg(alloc, sid), + }; + } + }; + + pub const HMsg = struct { + header_bytes: usize, + msg: Msg, + + pub fn deinit(self: HMsg, alloc: Allocator) void { + self.msg.deinit(alloc); + } + + pub fn dupe(self: HMsg, alloc: Allocator) !HMsg { + var res = self; + res.msg = try self.msg.dupe(alloc); + errdefer alloc.free(res.msg); + return res; + } + }; + pub const Sub = struct { + /// The subject name to subscribe to. + subject: []const u8, + /// If specified, the subscriber will join this queue group. + queue_group: ?[]const u8, + /// A unique alphanumeric subscription ID, generated by the client. + sid: []const u8, + + pub fn deinit(self: Sub, alloc: Allocator) void { + alloc.free(self.subject); + alloc.free(self.sid); + if (self.queue_group) |q| alloc.free(q); + } + }; + pub const Unsub = struct { + /// The unique alphanumeric subscription ID of the subject to unsubscribe from. + sid: []const u8, + /// A number of messages to wait for before automatically unsubscribing. + max_msgs: ?usize = null, + + pub fn deinit(self: Unsub, alloc: Allocator) void { + alloc.free(self.sid); + } + }; + pub const Msg = struct { + subject: []const u8, + sid: []const u8, + reply_to: ?[]const u8, + payload: Payload, + + pub fn deinit(self: Msg, alloc: Allocator) void { + alloc.free(self.subject); + alloc.free(self.sid); + if (self.reply_to) |r| alloc.free(r); + self.payload.deinit(alloc); + } + + pub fn dupe(self: Msg, alloc: Allocator) !Msg { + var res: Msg = undefined; + res.subject = try alloc.dupe(u8, self.subject); + errdefer alloc.free(res.subject); + res.sid = try alloc.dupe(u8, self.sid); + errdefer alloc.free(res.sid); + res.reply_to = if (self.reply_to) |r| try alloc.dupe(u8, r) else null; + errdefer if (res.reply_to) |r| alloc.free(r); + res.payload = try self.payload.dupe(alloc); + errdefer alloc.free(res.payload); + return res; + } + }; + + const client_types = StaticStringMap(MessageType).initComptime( + .{ + // {"INFO", .info}, + .{ @tagName(.CONNECT), .CONNECT }, + .{ @tagName(.PUB), .PUB }, + .{ @tagName(.HPUB), .HPUB }, + .{ @tagName(.SUB), .SUB }, + .{ @tagName(.UNSUB), .UNSUB }, + // {"MSG", .msg}, + // {"HMSG", .hmsg}, + .{ @tagName(.PING), .PING }, + .{ @tagName(.PONG), .PONG }, + // {"+OK", .@"+ok"}, + // {"-ERR", .@"-err"}, + }, + ); + fn parseStaticStringMap(input: []const u8) ?MessageType { + return client_types.get(input); + } + + pub const parse = parseStaticStringMap; + + /// An error should be handled by cleaning up this connection. + pub fn next(alloc: Allocator, in: *Reader) !Message { + var operation_string: ArrayList(u8) = blk: { + comptime var buf_len = 0; + comptime { + for (client_types.keys()) |key| { + buf_len = @max(buf_len, key.len); + } + } + var buf: [buf_len]u8 = undefined; + break :blk .initBuffer(&buf); + }; + + while (in.peekByte()) |byte| { + if (isUpper(byte)) { + try operation_string.appendBounded(byte); + in.toss(1); + } else break; + } else |err| return err; + + const operation = parse(operation_string.items) orelse { + log.err("Invalid operation: '{s}'", .{operation_string.items}); + return error.InvalidOperation; + }; + + errdefer log.err("Failed to parse {s}", .{operation_string.items}); + + switch (operation) { + .CONNECT => { + return parseConnect(alloc, in); + }, + .PUB => { + @branchHint(.likely); + return parsePub(alloc, in); + }, + .HPUB => { + @branchHint(.likely); + return parseHPub(alloc, in); + }, + .PING => { + try expectStreamBytes(in, "\r\n"); + return .PING; + }, + .PONG => { + try expectStreamBytes(in, "\r\n"); + return .PONG; + }, + .SUB => { + return parseSub(alloc, in); + }, + .UNSUB => { + return parseUnsub(alloc, in); + }, + else => |msg| std.debug.panic("Not implemented: {}\n", .{msg}), + } + } +}; + +fn parseConnect(alloc: Allocator, in: *Reader) !Message { + // for storing the json string + var connect_string_writer_allocating: AllocatingWriter = .init(alloc); + defer connect_string_writer_allocating.deinit(); + var connect_string_writer = &connect_string_writer_allocating.writer; + + // for parsing the json string + var connect_arena_allocator: ArenaAllocator = .init(alloc); + defer connect_arena_allocator.deinit(); + const connect_allocator = connect_arena_allocator.allocator(); + + try in.discardAll(1); // throw away space + + // Should read the next JSON object to the fixed buffer writer. + _ = try in.streamDelimiter(connect_string_writer, '}'); + try connect_string_writer.writeByte('}'); + try expectStreamBytes(in, "}\r\n"); // discard '}\r\n' + + const connect_str = try connect_string_writer_allocating.toOwnedSlice(); + defer alloc.free(connect_str); + // TODO: should be CONNECTION allocator + const res = try std.json.parseFromSliceLeaky( + Message.Connect, + connect_allocator, + connect_str, + .{ .allocate = .alloc_always }, + ); + + return .{ .CONNECT = try res.dupe(alloc) }; +} + +fn parseSub(alloc: Allocator, in: *Reader) !Message { + try in.discardAll(1); // throw away space + const subject = try readSubject(alloc, in, .sub); + + const States = enum { + before_second, + in_second, + after_second, + in_third, + in_end, + }; + + var second: ArrayList(u8) = .empty; + errdefer second.deinit(alloc); + var third: ?ArrayList(u8) = null; + errdefer if (third) |*t| t.deinit(alloc); + + sw: switch (@as(States, .before_second)) { + .before_second => { + const byte = try in.peekByte(); + if (isWhitespace(byte)) { + in.toss(1); + continue :sw .before_second; + } + continue :sw .in_second; + }, + .in_second => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .after_second; + }, + .after_second => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } else if (isWhitespace(byte)) { + in.toss(1); + continue :sw .after_second; + } + third = .empty; + continue :sw .in_third; + }, + .in_third => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try third.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .in_end; + }, + .in_end => { + try expectStreamBytes(in, "\r\n"); + }, + } + + return .{ + .SUB = .{ + .subject = subject, + .queue_group = if (third) |_| try second.toOwnedSlice(alloc) else null, + .sid = if (third) |*t| try t.toOwnedSlice(alloc) else try second.toOwnedSlice(alloc), + }, + }; +} + +test parseSub { + const alloc = std.testing.allocator; + const expectEqualDeep = std.testing.expectEqualDeep; + { + var in: Reader = .fixed(" foo 1\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "foo", + .queue_group = null, + .sid = "1", + }, + }, + res, + ); + } + { + var in: Reader = .fixed(" foo 1\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "foo", + .queue_group = null, + .sid = "1", + }, + }, + res, + ); + } + { + var in: Reader = .fixed(" foo q 1\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "foo", + .queue_group = "q", + .sid = "1", + }, + }, + res, + ); + } + { + var in: Reader = .fixed(" 1 q 1\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "1", + .queue_group = "q", + .sid = "1", + }, + }, + res, + ); + } + { + var in: Reader = .fixed(" $SRV.PING 4\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "$SRV.PING", + .queue_group = null, + .sid = "4", + }, + }, + res, + ); + } + { + var in: Reader = .fixed(" foo.echo q 10\r\n"); + var res = try parseSub(alloc, &in); + defer res.SUB.deinit(alloc); + try expectEqualDeep( + Message{ + .SUB = .{ + .subject = "foo.echo", + .queue_group = "q", + .sid = "10", + }, + }, + res, + ); + } +} + +fn parseUnsub(alloc: Allocator, in: *Reader) !Message { + const States = enum { + before_first, + in_first, + after_first, + in_second, + in_end, + }; + + var first: ArrayList(u8) = .empty; + errdefer first.deinit(alloc); + var second: ?ArrayList(u8) = null; + defer if (second) |*s| s.deinit(alloc); + + sw: switch (@as(States, .before_first)) { + .before_first => { + const byte = try in.peekByte(); + if (isWhitespace(byte)) { + in.toss(1); + continue :sw .before_first; + } + continue :sw .in_first; + }, + .in_first => { + const byte = try in.peekByte(); + if (!isWhitespace(byte)) { + try first.append(alloc, byte); + in.toss(1); + continue :sw .in_first; + } + continue :sw .after_first; + }, + .after_first => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } else if (isWhitespace(byte)) { + in.toss(1); + continue :sw .after_first; + } + second = .empty; + continue :sw .in_second; + }, + .in_second => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } + try second.?.append(alloc, byte); + in.toss(1); + continue :sw .in_second; + }, + .in_end => { + try expectStreamBytes(in, "\r\n"); + }, + } + + return .{ + .UNSUB = .{ + .sid = try first.toOwnedSlice(alloc), + .max_msgs = if (second) |s| try parseUnsigned(usize, s.items, 10) else null, + }, + }; +} + +test parseUnsub { + const alloc = std.testing.allocator; + const expectEqualDeep = std.testing.expectEqualDeep; + const expectEqual = std.testing.expectEqual; + { + var in: Reader = .fixed(" 1\r\n"); + var res = try parseUnsub(alloc, &in); + defer res.UNSUB.deinit(alloc); + try expectEqualDeep( + Message{ + .UNSUB = .{ + .sid = "1", + .max_msgs = null, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } + + { + var in: Reader = .fixed(" 1 1\r\n"); + var res = try parseUnsub(alloc, &in); + defer res.UNSUB.deinit(alloc); + try expectEqualDeep( + Message{ + .UNSUB = .{ + .sid = "1", + .max_msgs = 1, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } +} + +fn parsePub(alloc: Allocator, in: *Reader) !Message { + try in.discardAll(1); // throw away space + + // Parse subject + const subject: []const u8 = try readSubject(alloc, in, .@"pub"); + errdefer alloc.free(subject); + + const States = enum { + before_second, + in_second, + after_second, + in_third, + in_end, + }; + + var second: ArrayList(u8) = .empty; + defer second.deinit(alloc); + var third: ?ArrayList(u8) = null; + defer if (third) |*t| t.deinit(alloc); + + sw: switch (@as(States, .before_second)) { + .before_second => { + // Drop whitespace + const byte = try in.peekByte(); + if (isWhitespace(byte)) { + in.toss(1); + continue :sw .before_second; + } + continue :sw .in_second; + }, + .in_second => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .after_second; + }, + .after_second => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } else if (isWhitespace(byte)) { + in.toss(1); + continue :sw .after_second; + } + third = .empty; + continue :sw .in_third; + }, + .in_third => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try third.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .in_end; + }, + .in_end => { + try expectStreamBytes(in, "\r\n"); + }, + } + + const reply_to: ?[]const u8, const bytes: usize = + if (third) |t| .{ + try alloc.dupe(u8, second.items), + try parseUnsigned(usize, t.items, 10), + } else .{ + null, + try parseUnsigned(usize, second.items, 10), + }; + + const payload: Payload = try .read(alloc, in, bytes); + errdefer payload.deinit(alloc); + try expectStreamBytes(in, "\r\n"); + + return .{ + .PUB = .{ + .subject = subject, + .payload = payload, + .reply_to = reply_to, + }, + }; +} + +test parsePub { + const alloc = std.testing.allocator; + const expectEqualDeep = std.testing.expectEqualDeep; + const expectEqual = std.testing.expectEqual; + { + var in: Reader = .fixed(" foo 3\r\nbar\r\n"); + var res = try parsePub(alloc, &in); + defer res.PUB.deinit(alloc); + try expectEqualDeep( + Message{ + .PUB = .{ + .subject = "foo", + .reply_to = null, + .payload = .{ + .len = 3, + .short = blk: { + var s: [128]u8 = undefined; + @memcpy(s[0..3], "bar"); + break :blk s; + }, + .long = null, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } + + { + var in: Reader = .fixed(" foo reply.to 3\r\nbar\r\n"); + var res = try parsePub(alloc, &in); + defer res.PUB.deinit(alloc); + try expectEqualDeep( + Message{ + .PUB = .{ + .subject = "foo", + .reply_to = "reply.to", + .payload = .{ + .len = 3, + .short = blk: { + var s: [128]u8 = undefined; + @memcpy(s[0..3], "bar"); + break :blk s; + }, + .long = null, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } + + // numeric reply subject + { + var in: Reader = .fixed(" foo 5 3\r\nbar\r\n"); + var res = try parsePub(alloc, &in); + defer res.PUB.deinit(alloc); + try expectEqualDeep( + Message{ + .PUB = .{ + .subject = "foo", + .reply_to = "5", + .payload = .{ + .len = 3, + .short = blk: { + var s: [128]u8 = undefined; + @memcpy(s[0..3], "bar"); + break :blk s; + }, + .long = null, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } +} + +fn parseHPub(alloc: Allocator, in: *Reader) !Message { + try in.discardAll(1); // throw away space + + // Parse subject + const subject: []const u8 = try readSubject(alloc, in, .@"pub"); + errdefer alloc.free(subject); + + const States = enum { + before_second, + in_second, + after_second, + in_third, + after_third, + in_fourth, + in_end, + }; + + var second: ArrayList(u8) = .empty; + defer second.deinit(alloc); + var third: ArrayList(u8) = .empty; + defer third.deinit(alloc); + var fourth: ?ArrayList(u8) = null; + defer if (fourth) |*f| f.deinit(alloc); + + sw: switch (@as(States, .before_second)) { + .before_second => { + // Drop whitespace + const byte = try in.peekByte(); + if (isWhitespace(byte)) { + in.toss(1); + continue :sw .before_second; + } + continue :sw .in_second; + }, + .in_second => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .after_second; + }, + .after_second => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } else if (isWhitespace(byte)) { + in.toss(1); + continue :sw .after_second; + } + third = .empty; + continue :sw .in_third; + }, + .in_third => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try third.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .after_third; + }, + .after_third => { + const byte = try in.peekByte(); + if (byte == '\r') { + continue :sw .in_end; + } else if (isWhitespace(byte)) { + in.toss(1); + continue :sw .after_third; + } + fourth = .empty; + continue :sw .in_fourth; + }, + .in_fourth => { + for (1..in.buffer.len) |i| { + try in.fill(i + 1); + if (isWhitespace(in.buffered()[i])) { + @memcpy(try fourth.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); + in.toss(i); + break; + } + } else return error.EndOfStream; + continue :sw .in_end; + }, + .in_end => { + try expectStreamBytes(in, "\r\n"); + }, + } + + const reply_to: ?[]const u8, const header_bytes: usize, const total_bytes: usize = + if (fourth) |f| .{ + try alloc.dupe(u8, second.items), + try parseUnsigned(usize, third.items, 10), + try parseUnsigned(usize, f.items, 10), + } else .{ + null, + try parseUnsigned(usize, second.items, 10), + try parseUnsigned(usize, third.items, 10), + }; + + const payload: Payload = try .read(alloc, in, total_bytes); + errdefer payload.deinit(alloc); + try expectStreamBytes(in, "\r\n"); + + return .{ + .HPUB = .{ + .header_bytes = header_bytes, + .@"pub" = .{ + .subject = subject, + .payload = payload, + .reply_to = reply_to, + }, + }, + }; +} + +test parseHPub { + const alloc = std.testing.allocator; + const expectEqualDeep = std.testing.expectEqualDeep; + const expectEqual = std.testing.expectEqual; + { + var in: Reader = .fixed(" foo 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); + var res = try parseHPub(alloc, &in); + defer res.HPUB.deinit(alloc); + try expectEqualDeep( + Message{ + .HPUB = .{ + .header_bytes = 22, + .@"pub" = .{ + .subject = "foo", + .reply_to = null, + .payload = .{ + .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, + .short = blk: { + var s: [128]u8 = undefined; + const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; + @memcpy(s[0..str.len], str); + break :blk s; + }, + .long = null, + }, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } + + { + var in: Reader = .fixed(" foo reply.to 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); + var res = try parseHPub(alloc, &in); + defer res.HPUB.deinit(alloc); + try expectEqualDeep( + Message{ + .HPUB = .{ + .header_bytes = 22, + .@"pub" = .{ + .subject = "foo", + .reply_to = "reply.to", + .payload = .{ + .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, + .short = blk: { + var s: [128]u8 = undefined; + const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; + @memcpy(s[0..str.len], str); + break :blk s; + }, + .long = null, + }, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } + + { + var in: Reader = .fixed(" foo 6 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); + var res = try parseHPub(alloc, &in); + defer res.HPUB.deinit(alloc); + try expectEqualDeep( + Message{ + .HPUB = .{ + .header_bytes = 22, + .@"pub" = .{ + .subject = "foo", + .reply_to = "6", + .payload = .{ + .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, + .short = blk: { + var s: [128]u8 = undefined; + const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; + @memcpy(s[0..str.len], str); + break :blk s; + }, + .long = null, + }, + }, + }, + }, + res, + ); + try expectEqual(0, in.buffered().len); + } +} + +fn readSubject(alloc: Allocator, in: *Reader, comptime pub_or_sub: enum { @"pub", sub }) ![]const u8 { + var subject_list: ArrayList(u8) = .empty; + errdefer subject_list.deinit(alloc); + + // Handle the first character + { + const byte = try in.takeByte(); + if (isWhitespace(byte) or byte == '.' or (pub_or_sub == .@"pub" and (byte == '*' or byte == '>'))) + return error.InvalidStream; + + try subject_list.append(alloc, byte); + } + + switch (pub_or_sub) { + .sub => { + while (in.takeByte()) |byte| { + if (isWhitespace(byte)) break; + if (byte == '.') { + const next_byte = try in.peekByte(); + if (next_byte == '.' or isWhitespace(next_byte)) + return error.InvalidStream; + } else if (byte == '>') { + const next_byte = try in.takeByte(); + if (!isWhitespace(next_byte)) + return error.InvalidStream; + } else if (byte == '*') { + const next_byte = try in.peekByte(); + if (next_byte != '.' and !isWhitespace(next_byte)) + return error.InvalidStream; + } + try subject_list.append(alloc, byte); + } else |err| return err; + }, + .@"pub" => { + while (in.takeByte()) |byte| { + if (isWhitespace(byte)) break; + if (byte == '*' or byte == '>') return error.InvalidStream; + if (byte == '.') { + const next_byte = try in.peekByte(); + if (next_byte == '.' or isWhitespace(next_byte)) + return error.InvalidStream; + } + try subject_list.append(alloc, byte); + } else |err| return err; + }, + } + + return subject_list.toOwnedSlice(alloc); +} + +inline fn expectStreamBytes(reader: *Reader, expected: []const u8) !void { + if (!std.mem.eql(u8, try reader.take(expected.len), expected)) { + @branchHint(.unlikely); + return error.InvalidStream; + } +} + +test "parsing a stream" { + const alloc = std.testing.allocator; + const expectEqualDeep = std.testing.expectEqualDeep; + const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":fa" ++ + "lse,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43" ++ + ".0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r" ++ + "\nPUB hi 3\r\nfoo\r\n"; + var reader: Reader = .fixed(input); + var arena: ArenaAllocator = .init(alloc); + defer arena.deinit(); + const gpa = arena.allocator(); + + { + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ + .CONNECT = .{ + .verbose = false, + .pedantic = false, + .tls_required = false, + .name = "NATS CLI Version v0.2.4", + .lang = "go", + .version = "1.43.0", + .protocol = 1, + .echo = true, + .headers = true, + .no_responders = true, + }, + }; + + try expectEqualDeep(expected, msg); + } + { + const msg: Message = try Message.next(gpa, &reader); + const expected: Message = .{ + .PUB = .{ + .subject = "hi", + .payload = .{ + .len = 3, + .short = blk: { + var s: [128]u8 = undefined; + const str = "foo"; + @memcpy(s[0..str.len], str); + break :blk s; + }, + .long = null, + }, + }, + }; + try expectEqualDeep(expected, msg); + } +} diff --git a/src/main.zig b/src/main.zig index 47992af..a413fba 100644 --- a/src/main.zig +++ b/src/main.zig @@ -6,6 +6,8 @@ const yazap = @import("yazap"); const Message = zits.MessageParser.Message; const Server = zits.Server; +const serverSubcommand = @import("./subcommand/server.zig").main; + pub fn main() !void { var dba: std.heap.DebugAllocator(.{}) = .init; defer _ = dba.deinit(); @@ -67,7 +69,7 @@ pub fn main() !void { info.server_name = name; } - try @import("./server/main.zig").main(gpa, info); + try serverSubcommand(gpa, info); return; } else if (matches.subcommandMatches("pub")) |_| { std.debug.print("Unimplemented\n", .{}); @@ -76,3 +78,25 @@ pub fn main() !void { try app.displayHelp(io); } + +pub const std_options: std.Options = .{ + // By default, in safe build modes, the standard library will attach a segfault handler to the program to + // print a helpful stack trace if a segmentation fault occurs. Here, we can disable this, or even enable + // it in unsafe build modes. + .enable_segfault_handler = true, + // This is the logging function used by `std.log`. + .logFn = myLogFn, +}; + +fn myLogFn( + comptime level: std.log.Level, + comptime scope: @EnumLiteral(), + comptime format: []const u8, + args: anytype, +) void { + if (scope == .zits) { + std.log.defaultLog(level, std.log.default_log_scope, format, args); + } else { + std.log.defaultLog(level, scope, format, args); + } +} diff --git a/src/root.zig b/src/root.zig index 49631cb..d4c7cd8 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,3 +1 @@ -const MessageParser = @import("server/message_parser.zig"); - -pub const Server = @import("server/Server.zig"); +pub const Server = @import("Server.zig"); diff --git a/src/server/Client.zig b/src/server/Client.zig deleted file mode 100644 index 690cabf..0000000 --- a/src/server/Client.zig +++ /dev/null @@ -1,253 +0,0 @@ -const Message = @import("message_parser.zig").Message; -const std = @import("std"); -const Queue = std.Io.Queue; - -const Client = @This(); - -pub const Msgs = union(enum) { - MSG: Message.Msg, - HMSG: Message.HMsg, -}; - -connect: ?Message.Connect, -// Used to own messages that we receive in our queues. -alloc: std.mem.Allocator, - -// Messages for this client to receive. -recv_queue: *Queue(Message), -msg_queue: *Queue(Msgs), - -from_client: *std.Io.Reader, -to_client: *std.Io.Writer, - -pub fn init( - connect: ?Message.Connect, - alloc: std.mem.Allocator, - recv_queue: *Queue(Message), - msg_queue: *Queue(Msgs), - in: *std.Io.Reader, - out: *std.Io.Writer, -) Client { - return .{ - .connect = connect, - .alloc = alloc, - .recv_queue = recv_queue, - .msg_queue = msg_queue, - .from_client = in, - .to_client = out, - }; -} - -pub fn deinit(self: *Client, alloc: std.mem.Allocator) void { - if (self.connect) |c| { - c.deinit(alloc); - } - self.* = undefined; -} - -pub fn start(self: *Client, io: std.Io) !void { - var msgs_buf: [1024]Msgs = undefined; - - var recv_msgs_task = io.concurrent(Queue(Msgs).get, .{ self.msg_queue, io, &msgs_buf, 1 }) catch @panic("Concurrency unavailable"); - errdefer _ = recv_msgs_task.cancel(io) catch {}; - - var recv_proto_task = io.concurrent(Queue(Message).getOne, .{ self.recv_queue, io }) catch unreachable; - errdefer _ = recv_proto_task.cancel(io) catch {}; - - while (true) { - switch (try io.select(.{ .msgs = &recv_msgs_task, .proto = &recv_proto_task })) { - .msgs => |len_err| { - @branchHint(.likely); - const msgs = msgs_buf[0..try len_err]; - for (0..msgs.len) |i| { - const msg = msgs[i]; - defer switch (msg) { - .MSG => |m| m.deinit(self.alloc), - .HMSG => |h| h.deinit(self.alloc), - }; - errdefer for (msgs[i + 1 ..]) |mg| switch (mg) { - .MSG => |m| { - m.deinit(self.alloc); - }, - .HMSG => |h| { - h.deinit(self.alloc); - }, - }; - switch (msg) { - .MSG => |m| { - try self.to_client.print( - "MSG {s} {s} {s} {d}\r\n", - .{ - m.subject, - m.sid, - m.reply_to orelse "", - m.payload.len, - }, - ); - try m.payload.write(self.to_client); - try self.to_client.print("\r\n", .{}); - }, - .HMSG => |hmsg| { - try self.to_client.print("HMSG {s} {s} {s} {d} {d}\r\n", .{ - hmsg.msg.subject, - hmsg.msg.sid, - hmsg.msg.reply_to orelse "", - hmsg.header_bytes, - hmsg.msg.payload.len, - }); - try hmsg.msg.payload.write(self.to_client); - try self.to_client.print("\r\n", .{}); - }, - } - } - recv_msgs_task = io.concurrent(Queue(Msgs).get, .{ self.msg_queue, io, &msgs_buf, 1 }) catch unreachable; - }, - .proto => |msg_err| { - @branchHint(.unlikely); - const msg = try msg_err; - switch (msg) { - .@"+OK" => { - _ = try self.to_client.write("+OK\r\n"); - }, - .PONG => { - _ = try self.to_client.write("PONG\r\n"); - }, - .INFO => |info| { - _ = try self.to_client.write("INFO "); - try std.json.Stringify.value(info, .{}, self.to_client); - _ = try self.to_client.write("\r\n"); - }, - .@"-ERR" => |s| { - _ = try self.to_client.print("-ERR '{s}'\r\n", .{s}); - }, - else => |m| { - std.debug.panic("unimplemented write: {any}\n", .{m}); - }, - } - recv_proto_task = io.concurrent(Queue(Message).getOne, .{ self.recv_queue, io }) catch unreachable; - }, - } - try self.to_client.flush(); - } -} - -pub fn send(self: *Client, io: std.Io, msg: Message) !void { - try self.recv_queue.putOne(io, msg); -} - -pub fn next(self: *Client, allocator: std.mem.Allocator) !Message { - return Message.next(allocator, self.from_client); -} - -test { - const io = std.testing.io; - const gpa = std.testing.allocator; - - var from_client: std.Io.Reader = .fixed( - "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\n" ++ - "PING\r\n", - ); - var from_client_buf: [1024]Message = undefined; - var from_client_queue: std.Io.Queue(Message) = .init(&from_client_buf); - - { - // Simulate stream - while (Message.next(gpa, &from_client)) |msg| { - try from_client_queue.putOne(io, msg); - } else |err| switch (err) { - error.EndOfStream => from_client_queue.close(io), - else => return err, - } - - while (from_client_queue.getOne(io)) |msg| { - switch (msg) { - .connect => |*c| { - std.debug.print("Message: {any}\n", .{msg}); - c.deinit(gpa); - }, - else => { - std.debug.print("Message: {any}\n", .{msg}); - }, - } - } else |_| {} - } - - from_client_queue = .init(&from_client_buf); - // Reset the reader to process it again. - from_client.seek = 0; - - // { - // const SemiClient = struct { - // q: std.Io.Queue(Message), - - // fn parseClientInput(self: *@This(), ioh: std.Io, in: *std.Io.Reader) void { - // defer std.debug.print("done parse\n", .{}); - // while (Message.next(gpa, in)) |msg| { - // self.q.putOne(ioh, msg) catch return; - // } else |_| {} - // } - - // fn next(self: *@This(), ioh: std.Io) !Message { - // return self.q.getOne(ioh); - // } - - // fn printAll(self: *@This(), ioh: std.Io) void { - // defer std.debug.print("done print\n", .{}); - // while (self.next(ioh)) |*msg| { - // std.debug.print("Client msg: {any}\n", .{msg}); - // switch (msg.*) { - // .connect => |c| { - // c.deinit(gpa); - // }, - // else => {}, - // } - // } else |_| {} - // } - // }; - - // var c: SemiClient = .{ .q = from_client_queue }; - // var group: std.Io.Group = .init; - // defer group.wait(io); - - // group.concurrent(io, SemiClient.printAll, .{ &c, io }) catch { - // @panic("could not start printAll\n"); - // }; - - // group.concurrent(io, SemiClient.parseClientInput, .{ &c, io, &from_client }) catch { - // @panic("could not start printAll\n"); - // }; - // } - - //////// - - // const connect = (Message.next(gpa, &from_client) catch unreachable).connect; - - // var to_client_alloc: std.Io.Writer.Allocating = .init(gpa); - // defer to_client_alloc.deinit(); - // var to_client = to_client_alloc.writer; - - // var client: ClientState = try .init(io, gpa, 0, connect, &from_client, &to_client); - // defer client.deinit(gpa); - - // { - // var get_next = io.concurrent(ClientState.next, .{ &client, io }) catch unreachable; - // defer if (get_next.cancel(io)) |_| {} else |_| @panic("fail"); - - // var timeout = io.concurrent(std.Io.sleep, .{ io, .fromMilliseconds(1000), .awake }) catch unreachable; - // defer timeout.cancel(io) catch {}; - - // switch (try io.select(.{ - // .get_next = &get_next, - // .timeout = &timeout, - // })) { - // .get_next => |next| { - // std.debug.print("next is {any}\n", .{next}); - // try std.testing.expect((next catch |err| return err) == .ping); - // }, - // .timeout => { - // std.debug.print("reached timeout\n", .{}); - // return error.TestUnexpectedResult; - // }, - // } - // } -} diff --git a/src/server/Server.zig b/src/server/Server.zig deleted file mode 100644 index 18214ae..0000000 --- a/src/server/Server.zig +++ /dev/null @@ -1,417 +0,0 @@ -const std = @import("std"); -const Allocator = std.mem.Allocator; -const ArrayList = std.ArrayList; -const AutoHashMapUnmanaged = std.AutoHashMapUnmanaged; - -const Io = std.Io; -const Dir = Io.Dir; -const Group = Io.Group; -const IpAddress = std.Io.net.IpAddress; -const Mutex = Io.Mutex; -const Queue = Io.Queue; -const Stream = std.Io.net.Stream; - -const message_parser = @import("./message_parser.zig"); -pub const MessageType = message_parser.MessageType; -pub const Message = message_parser.Message; -const ServerInfo = Message.ServerInfo; -pub const Client = @import("./Client.zig"); -const Msgs = Client.Msgs; -const Server = @This(); - -const builtin = @import("builtin"); -const safe_build = builtin.mode == .Debug or builtin.mode == .ReleaseSafe; - -pub const Subscription = struct { - subject: []const u8, - client_id: usize, - sid: []const u8, - queue_group: ?[]const u8, - queue: *Queue(Msgs), - // used to alloc messages in the queue - alloc: Allocator, - - fn deinit(self: Subscription, alloc: Allocator) void { - alloc.free(self.subject); - alloc.free(self.sid); - if (self.queue_group) |g| alloc.free(g); - } -}; - -const eql = std.mem.eql; -const log = std.log; -const panic = std.debug.panic; - -info: ServerInfo, -clients: AutoHashMapUnmanaged(usize, *Client) = .empty, - -subs_lock: Mutex = .init, -subscriptions: ArrayList(Subscription) = .empty, - -pub fn deinit(server: *Server, io: Io, alloc: Allocator) void { - server.subs_lock.lockUncancelable(io); - defer server.subs_lock.unlock(io); - for (server.subscriptions.items) |sub| { - sub.deinit(alloc); - } - // TODO drain subscription queues - server.subscriptions.deinit(alloc); - server.clients.deinit(alloc); -} - -pub fn start(server: *Server, io: Io, gpa: Allocator) !void { - var tcp_server = try IpAddress.listen(try IpAddress.parse( - server.info.host, - server.info.port, - ), io, .{}); - defer tcp_server.deinit(io); - log.debug("Server headers: {s}", .{if (server.info.headers) "true" else "false"}); - log.debug("Server max payload: {d}", .{server.info.max_payload}); - log.info("Server ID: {s}", .{server.info.server_id}); - log.info("Server name: {s}", .{server.info.server_name}); - log.info("Server listening on {s}:{d}", .{ server.info.host, server.info.port }); - - var client_group: Group = .init; - defer client_group.cancel(io); - - const read_buffer_size, const write_buffer_size = getBufferSizes(io); - log.debug("read buf: {d} write buf: {d}", .{ read_buffer_size, write_buffer_size }); - - var id: usize = 0; - while (true) : (id +%= 1) { - if (server.clients.contains(id)) continue; - log.debug("Accepting next client", .{}); - const stream = try tcp_server.accept(io); - log.debug("Accepted connection {d}", .{id}); - _ = client_group.concurrent(io, handleConnectionInfallible, .{ - server, - gpa, - io, - id, - stream, - read_buffer_size, - write_buffer_size, - }) catch { - log.err("Could not start concurrent handler for {d}", .{id}); - stream.close(io); - }; - } -} - -fn addClient(server: *Server, allocator: Allocator, id: usize, client: *Client) !void { - try server.clients.put(allocator, id, client); -} - -fn removeClient(server: *Server, io: Io, allocator: Allocator, id: usize) void { - server.subs_lock.lockUncancelable(io); - defer server.subs_lock.unlock(io); - if (server.clients.remove(id)) { - const len = server.subscriptions.items.len; - for (0..len) |from_end| { - const i = len - from_end - 1; - const sub = server.subscriptions.items[i]; - if (sub.client_id == id) { - sub.deinit(allocator); - _ = server.subscriptions.swapRemove(i); - } - } - } -} - -fn handleConnectionInfallible( - server: *Server, - server_allocator: Allocator, - io: Io, - id: usize, - stream: Stream, - r_buf_size: usize, - w_buf_size: usize, -) !void { - handleConnection(server, server_allocator, io, id, stream, r_buf_size, w_buf_size) catch |err| switch (err) { - error.Canceled => return error.Canceled, - else => log.err("Failed processing client {d}: {any}", .{ id, err }), - }; -} - -fn handleConnection( - server: *Server, - server_allocator: Allocator, - io: Io, - id: usize, - stream: Stream, - r_buf_size: usize, - w_buf_size: usize, -) !void { - defer stream.close(io); - - var dba: std.heap.DebugAllocator(.{}) = .init; - dba.backing_allocator = server_allocator; - defer _ = dba.deinit(); - const alloc = if (safe_build) dba.allocator() else server_allocator; - - // Set up client writer - const w_buffer: []u8 = try alloc.alloc(u8, w_buf_size); - defer alloc.free(w_buffer); - var writer = stream.writer(io, w_buffer); - const out = &writer.interface; - - // Set up client reader - const r_buffer: []u8 = try alloc.alloc(u8, r_buf_size); - defer alloc.free(r_buffer); - var reader = stream.reader(io, r_buffer); - const in = &reader.interface; - - // Set up buffer queue - const qbuf: []Message = try alloc.alloc(Message, 16); - defer alloc.free(qbuf); - var recv_queue: Queue(Message) = .init(qbuf); - defer recv_queue.close(io); - - const mbuf: []Msgs = try alloc.alloc(Msgs, w_buf_size / @sizeOf(Msgs)); - defer alloc.free(mbuf); - var msgs_queue: Queue(Msgs) = .init(mbuf); - defer { - msgs_queue.close(io); - while (msgs_queue.getOne(io)) |msg| { - switch (msg) { - .MSG => |m| m.deinit(alloc), - .HMSG => |h| h.deinit(alloc), - } - } else |_| {} - } - - // Create client - var client: Client = .init(null, alloc, &recv_queue, &msgs_queue, in, out); - defer client.deinit(server_allocator); - - try server.addClient(server_allocator, id, &client); - defer server.removeClient(io, server_allocator, id); - - // Do initial handshake with client - // try recv_queue.putOne(io, .PONG); - try recv_queue.putOne(io, .{ .INFO = server.info }); - - var client_task = try io.concurrent(Client.start, .{ &client, io }); - defer client_task.cancel(io) catch {}; - - // Messages are owned by the server after they are received from the client - while (client.next(server_allocator)) |msg| { - switch (msg) { - .PING => { - // Respond to ping with pong. - try client.send(io, .PONG); - }, - .PUB => |pb| { - @branchHint(.likely); - defer pb.deinit(server_allocator); - try server.publishMessage(io, server_allocator, &client, msg); - }, - .HPUB => |hp| { - @branchHint(.likely); - defer hp.deinit(server_allocator); - try server.publishMessage(io, server_allocator, &client, msg); - }, - .SUB => |sub| { - defer sub.deinit(server_allocator); - try server.subscribe(io, server_allocator, client, id, sub); - }, - .UNSUB => |unsub| { - defer unsub.deinit(server_allocator); - try server.unsubscribe(io, server_allocator, id, unsub); - }, - .CONNECT => |connect| { - if (client.connect) |*current| { - current.deinit(server_allocator); - } - client.connect = connect; - }, - else => |e| { - panic("Unimplemented message: {any}\n", .{e}); - }, - } - } else |err| switch (err) { - error.EndOfStream, error.ReadFailed => { - log.debug("Client {d} disconnected", .{id}); - return error.Canceled; - }, - else => { - return err; - }, - } -} - -fn subjectMatches(sub_subject: []const u8, pub_subject: []const u8) bool { - // TODO: assert that sub_subject and pub_subject are valid. - var sub_iter = std.mem.splitScalar(u8, sub_subject, '.'); - var pub_iter = std.mem.splitScalar(u8, pub_subject, '.'); - - while (sub_iter.next()) |st| { - const pt = pub_iter.next() orelse return false; - - if (eql(u8, st, ">")) return true; - - if (!eql(u8, st, "*") and !eql(u8, st, pt)) { - return false; - } - } - - return pub_iter.next() == null; -} - -test subjectMatches { - const expect = std.testing.expect; - try expect(subjectMatches("foo", "foo")); - try expect(!subjectMatches("foo", "bar")); - - try expect(subjectMatches("foo.*", "foo.bar")); - try expect(!subjectMatches("foo.*", "foo")); - try expect(!subjectMatches("foo.>", "foo")); - - // the wildcard subscriptions foo.*.quux and foo.> both match foo.bar.quux, but only the latter matches foo.bar.baz. - try expect(subjectMatches("foo.*.quux", "foo.bar.quux")); - try expect(subjectMatches("foo.>", "foo.bar.quux")); - try expect(!subjectMatches("foo.*.quux", "foo.bar.baz")); - try expect(subjectMatches("foo.>", "foo.bar.baz")); -} - -fn publishMessage( - server: *Server, - io: Io, - alloc: Allocator, - source_client: *Client, - msg: Message, -) !void { - defer if (source_client.connect) |c| { - if (c.verbose) { - source_client.send(io, .@"+OK") catch {}; - } - }; - - const subject = switch (msg) { - .PUB => |pb| pb.subject, - .HPUB => |hp| hp.@"pub".subject, - else => unreachable, - }; - try server.subs_lock.lock(io); - defer server.subs_lock.unlock(io); - var published_queue_groups: ArrayList([]const u8) = .empty; - defer published_queue_groups.deinit(alloc); - var published_queue_sub_idxs: ArrayList(usize) = .empty; - defer published_queue_sub_idxs.deinit(alloc); - - subs: for (0..server.subscriptions.items.len) |i| { - const subscription = server.subscriptions.items[i]; - if (subjectMatches(subscription.subject, subject)) { - if (subscription.queue_group) |sg| { - for (published_queue_groups.items) |g| { - if (eql(u8, g, sg)) { - continue :subs; - } - } - // Don't republish to the same queue - try published_queue_groups.append(alloc, sg); - // Move this index to the end of the subscription list, - // to prioritize other subscriptions in the queue next time. - try published_queue_sub_idxs.append(alloc, i); - } - switch (msg) { - .PUB => |pb| { - try subscription.queue.putOne(io, .{ - .MSG = try pb.toMsg(subscription.alloc, subscription.sid), - }); - }, - .HPUB => |hp| { - try subscription.queue.putOne(io, .{ - .HMSG = try hp.toHMsg(subscription.alloc, subscription.sid), - }); - }, - else => unreachable, - } - } - } - - for (0..published_queue_sub_idxs.items.len) |from_end| { - const i = published_queue_sub_idxs.items.len - from_end - 1; - server.subscriptions.appendAssumeCapacity(server.subscriptions.orderedRemove(i)); - } -} - -fn subscribe( - server: *Server, - io: Io, - gpa: Allocator, - client: Client, - id: usize, - msg: Message.Sub, -) !void { - try server.subs_lock.lock(io); - defer server.subs_lock.unlock(io); - const subject = try gpa.dupe(u8, msg.subject); - errdefer gpa.free(subject); - const sid = try gpa.dupe(u8, msg.sid); - errdefer gpa.free(sid); - const queue_group = if (msg.queue_group) |q| try gpa.dupe(u8, q) else null; - errdefer if (queue_group) |q| gpa.free(q); - try server.subscriptions.append(gpa, .{ - .subject = subject, - .client_id = id, - .sid = sid, - .queue_group = queue_group, - .queue = client.msg_queue, - .alloc = client.alloc, - }); -} - -fn unsubscribe( - server: *Server, - io: Io, - gpa: Allocator, - id: usize, - msg: Message.Unsub, -) !void { - try server.subs_lock.lock(io); - defer server.subs_lock.unlock(io); - const len = server.subscriptions.items.len; - for (0..len) |from_end| { - const i = len - from_end - 1; - const sub = server.subscriptions.items[i]; - if (sub.client_id == id and eql(u8, sub.sid, msg.sid)) { - sub.deinit(gpa); - _ = server.subscriptions.swapRemove(i); - } - } -} - -const parseUnsigned = std.fmt.parseUnsigned; - -fn getBufferSizes(io: Io) struct { usize, usize } { - const default_size = 4 * 1024; - const default = .{ default_size, default_size }; - - const dir = Dir.openDirAbsolute(io, "/proc/sys/net/core", .{}) catch { - log.warn("couldn't open /proc/sys/net/core", .{}); - return default; - }; - - var buf: [64]u8 = undefined; - - const rmem_max = readBufferSize(io, dir, "rmem_max", &buf, default_size); - const wmem_max = readBufferSize(io, dir, "wmem_max", &buf, default_size); - - return .{ rmem_max, wmem_max }; -} - -fn readBufferSize(io: Io, dir: anytype, filename: []const u8, buf: []u8, default: usize) usize { - const bytes = dir.readFile(io, filename, buf) catch |err| { - log.err("couldn't open {s}: {any}", .{ filename, err }); - return default; - }; - - return parseUnsigned(usize, bytes[0 .. bytes.len - 1], 10) catch |err| { - log.err("couldn't parse {s}: {any}", .{ bytes[0 .. bytes.len - 1], err }); - return default; - }; -} - -pub const default_id = "server-id-123"; -pub const default_name = "Zits Server"; diff --git a/src/server/main.zig b/src/server/main.zig deleted file mode 100644 index 1aaf572..0000000 --- a/src/server/main.zig +++ /dev/null @@ -1,66 +0,0 @@ -const std = @import("std"); -const Allocator = std.mem.Allocator; -const AtomicValue = std.atomic.Value; -const DebugAllocator = std.heap.DebugAllocator; -const Sigaction = std.posix.Sigaction; - -const Io = std.Io; -const Threaded = Io.Threaded; - -const builtin = @import("builtin"); - -const zits = @import("zits"); -const Message = zits.Server.Message; -const ServerInfo = Message.ServerInfo; - -const Server = zits.Server; - -const safe_build = builtin.mode == .Debug or builtin.mode == .ReleaseSafe; - -var keep_running = AtomicValue(bool).init(true); - -fn handleSigInt(sig: std.os.linux.SIG) callconv(.c) void { - _ = sig; - keep_running.store(false, .monotonic); -} - -pub fn main(outer_alloc: Allocator, server_config: ServerInfo) !void { - // Configure the signal action - const act = Sigaction{ - .handler = .{ .handler = handleSigInt }, - .mask = std.posix.sigemptyset(), - .flags = 0, - }; - - // Register the handler for SIGINT (Ctrl+C) - std.posix.sigaction(std.posix.SIG.INT, &act, null); - - { - var dba: DebugAllocator(.{}) = .init; - dba.backing_allocator = outer_alloc; - defer _ = dba.deinit(); - const alloc = if (safe_build) dba.allocator() else outer_alloc; - - var threaded: Threaded = .init(alloc, .{}); - defer threaded.deinit(); - const io = threaded.io(); - - var server: Server = .{ - .info = server_config, - }; - defer server.deinit(io, alloc); - - var server_task = try io.concurrent(Server.start, .{ &server, io, alloc }); - defer server_task.cancel(io) catch {}; - - // Block until Ctrl+C - while (keep_running.load(.monotonic)) { - try io.sleep(.fromMilliseconds(1), .awake); - } - - std.debug.print("\n", .{}); - std.log.info("Shutting down...", .{}); - server_task.cancel(io) catch {}; - } - std.log.info("Goodbye", .{}); -} diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig deleted file mode 100644 index fd1b5b1..0000000 --- a/src/server/message_parser.zig +++ /dev/null @@ -1,1141 +0,0 @@ -const std = @import("std"); -const Allocator = std.mem.Allocator; -const ArenaAllocator = std.heap.ArenaAllocator; -const ArrayList = std.ArrayList; -const StaticStringMap = std.StaticStringMap; - -const Io = std.Io; -const Writer = Io.Writer; -const AllocatingWriter = Writer.Allocating; -const Reader = Io.Reader; - -const ascii = std.ascii; -const isDigit = std.ascii.isDigit; -const isUpper = std.ascii.isUpper; -const isWhitespace = std.ascii.isWhitespace; - -const parseUnsigned = std.fmt.parseUnsigned; - -const log = std.log; - -pub const Payload = struct { - len: u32, - short: [128]u8, - long: ?[]u8, - - pub fn read(alloc: Allocator, in: *Reader, bytes: usize) !Payload { - var res: Payload = .{ - .len = @intCast(bytes), - .short = undefined, - .long = null, - }; - - try in.readSliceAll(res.short[0..@min(bytes, res.short.len)]); - if (bytes > res.short.len) { - const long = try alloc.alloc(u8, bytes - res.short.len); - errdefer alloc.free(long); - try in.readSliceAll(long); - res.long = long; - } - return res; - } - - pub fn write(self: Payload, out: *Writer) !void { - std.debug.assert(out.buffer.len >= self.short.len); - std.debug.assert(self.len <= self.short.len or self.long != null); - try out.writeAll(self.short[0..@min(self.len, self.short.len)]); - if (self.long) |l| { - try out.writeAll(l); - } - } - - pub fn deinit(self: Payload, alloc: Allocator) void { - if (self.long) |l| { - alloc.free(l); - } - } - - pub fn dupe(self: Payload, alloc: Allocator) !Payload { - var res = self; - if (self.long) |l| { - res.long = try alloc.dupe(u8, l); - } - errdefer if (res.long) |l| alloc.free(l); - return res; - } -}; - -pub const MessageType = @typeInfo(Message).@"union".tag_type.?; - -pub const Message = union(enum) { - INFO: ServerInfo, - CONNECT: Connect, - PUB: Pub, - HPUB: HPub, - SUB: Sub, - UNSUB: Unsub, - MSG: Msg, - HMSG: HMsg, - PING, - PONG, - @"+OK": void, - @"-ERR": []const u8, - pub const ServerInfo = struct { - /// The unique identifier of the NATS server. - server_id: []const u8, - /// The name of the NATS server. - server_name: []const u8, - /// The version of NATS. - version: []const u8, - /// The version of golang the NATS server was built with. - go: []const u8 = "0.0.0", - /// The IP address used to start the NATS server, - /// by default this will be 0.0.0.0 and can be - /// configured with -client_advertise host:port. - host: []const u8 = "0.0.0.0", - /// The port number the NATS server is configured - /// to listen on. - port: u16 = 4222, - /// Whether the server supports headers. - headers: bool = false, - /// Maximum payload size, in bytes, that the server - /// will accept from the client. - max_payload: u64, - /// An integer indicating the protocol version of - /// the server. The server version 1.2.0 sets this - /// to 1 to indicate that it supports the "Echo" - /// feature. - proto: u32 = 1, - }; - pub const Connect = struct { - verbose: bool = false, - pedantic: bool = false, - tls_required: bool = false, - auth_token: ?[]const u8 = null, - user: ?[]const u8 = null, - pass: ?[]const u8 = null, - name: ?[]const u8 = null, - lang: []const u8, - version: []const u8, - protocol: u32, - echo: ?bool = null, - sig: ?[]const u8 = null, - jwt: ?[]const u8 = null, - no_responders: ?bool = null, - headers: ?bool = null, - nkey: ?[]const u8 = null, - - pub fn deinit(self: Connect, alloc: Allocator) void { - if (self.auth_token) |a| alloc.free(a); - if (self.user) |u| alloc.free(u); - if (self.pass) |p| alloc.free(p); - if (self.name) |n| alloc.free(n); - alloc.free(self.lang); - alloc.free(self.version); - if (self.sig) |s| alloc.free(s); - if (self.jwt) |j| alloc.free(j); - if (self.nkey) |n| alloc.free(n); - } - - pub fn dupe(self: Connect, alloc: Allocator) !Connect { - var res = self; - res.auth_token = if (self.auth_token) |a| try alloc.dupe(u8, a) else null; - errdefer if (res.auth_token) |a| alloc.free(a); - res.user = if (self.user) |u| try alloc.dupe(u8, u) else null; - errdefer if (res.user) |u| alloc.free(u); - res.pass = if (self.pass) |p| try alloc.dupe(u8, p) else null; - errdefer if (res.pass) |p| alloc.free(p); - res.name = if (self.name) |n| try alloc.dupe(u8, n) else null; - errdefer if (res.name) |n| alloc.free(n); - res.lang = try alloc.dupe(u8, self.lang); - errdefer alloc.free(res.lang); - res.version = try alloc.dupe(u8, self.version); - errdefer alloc.free(res.version); - res.sig = if (self.sig) |s| try alloc.dupe(u8, s) else null; - errdefer if (res.sig) |s| alloc.free(s); - res.jwt = if (self.jwt) |j| try alloc.dupe(u8, j) else null; - errdefer if (res.jwt) |j| alloc.free(j); - res.nkey = if (self.nkey) |n| try alloc.dupe(u8, n) else null; - errdefer if (res.nkey) |n| alloc.free(n); - return res; - } - }; - pub const Pub = struct { - /// The destination subject to publish to. - subject: []const u8, - /// The reply subject that subscribers can use to send a response back to the publisher/requestor. - reply_to: ?[]const u8 = null, - /// The message payload data. - payload: Payload, - - pub fn deinit(self: Pub, alloc: Allocator) void { - alloc.free(self.subject); - self.payload.deinit(alloc); - if (self.reply_to) |r| alloc.free(r); - } - - pub fn toMsg(self: Pub, alloc: Allocator, sid: []const u8) !Msg { - const res: Msg = .{ - .subject = self.subject, - .sid = sid, - .reply_to = self.reply_to, - .payload = self.payload, - }; - return res.dupe(alloc); - } - }; - pub const HPub = struct { - header_bytes: usize, - @"pub": Pub, - - pub fn deinit(self: HPub, alloc: Allocator) void { - self.@"pub".deinit(alloc); - } - - pub fn toHMsg(self: HPub, alloc: Allocator, sid: []const u8) !HMsg { - return .{ - .header_bytes = self.header_bytes, - .msg = try self.@"pub".toMsg(alloc, sid), - }; - } - }; - - pub const HMsg = struct { - header_bytes: usize, - msg: Msg, - - pub fn deinit(self: HMsg, alloc: Allocator) void { - self.msg.deinit(alloc); - } - - pub fn dupe(self: HMsg, alloc: Allocator) !HMsg { - var res = self; - res.msg = try self.msg.dupe(alloc); - errdefer alloc.free(res.msg); - return res; - } - }; - pub const Sub = struct { - /// The subject name to subscribe to. - subject: []const u8, - /// If specified, the subscriber will join this queue group. - queue_group: ?[]const u8, - /// A unique alphanumeric subscription ID, generated by the client. - sid: []const u8, - - pub fn deinit(self: Sub, alloc: Allocator) void { - alloc.free(self.subject); - alloc.free(self.sid); - if (self.queue_group) |q| alloc.free(q); - } - }; - pub const Unsub = struct { - /// The unique alphanumeric subscription ID of the subject to unsubscribe from. - sid: []const u8, - /// A number of messages to wait for before automatically unsubscribing. - max_msgs: ?usize = null, - - pub fn deinit(self: Unsub, alloc: Allocator) void { - alloc.free(self.sid); - } - }; - pub const Msg = struct { - subject: []const u8, - sid: []const u8, - reply_to: ?[]const u8, - payload: Payload, - - pub fn deinit(self: Msg, alloc: Allocator) void { - alloc.free(self.subject); - alloc.free(self.sid); - if (self.reply_to) |r| alloc.free(r); - self.payload.deinit(alloc); - } - - pub fn dupe(self: Msg, alloc: Allocator) !Msg { - var res: Msg = undefined; - res.subject = try alloc.dupe(u8, self.subject); - errdefer alloc.free(res.subject); - res.sid = try alloc.dupe(u8, self.sid); - errdefer alloc.free(res.sid); - res.reply_to = if (self.reply_to) |r| try alloc.dupe(u8, r) else null; - errdefer if (res.reply_to) |r| alloc.free(r); - res.payload = try self.payload.dupe(alloc); - errdefer alloc.free(res.payload); - return res; - } - }; - - const client_types = StaticStringMap(MessageType).initComptime( - .{ - // {"INFO", .info}, - .{ @tagName(.CONNECT), .CONNECT }, - .{ @tagName(.PUB), .PUB }, - .{ @tagName(.HPUB), .HPUB }, - .{ @tagName(.SUB), .SUB }, - .{ @tagName(.UNSUB), .UNSUB }, - // {"MSG", .msg}, - // {"HMSG", .hmsg}, - .{ @tagName(.PING), .PING }, - .{ @tagName(.PONG), .PONG }, - // {"+OK", .@"+ok"}, - // {"-ERR", .@"-err"}, - }, - ); - fn parseStaticStringMap(input: []const u8) ?MessageType { - return client_types.get(input); - } - - pub const parse = parseStaticStringMap; - - /// An error should be handled by cleaning up this connection. - pub fn next(alloc: Allocator, in: *Reader) !Message { - var operation_string: ArrayList(u8) = blk: { - comptime var buf_len = 0; - comptime { - for (client_types.keys()) |key| { - buf_len = @max(buf_len, key.len); - } - } - var buf: [buf_len]u8 = undefined; - break :blk .initBuffer(&buf); - }; - - while (in.peekByte()) |byte| { - if (isUpper(byte)) { - try operation_string.appendBounded(byte); - in.toss(1); - } else break; - } else |err| return err; - - const operation = parse(operation_string.items) orelse { - log.err("Invalid operation: '{s}'", .{operation_string.items}); - return error.InvalidOperation; - }; - - errdefer log.err("Failed to parse {s}", .{operation_string.items}); - - switch (operation) { - .CONNECT => { - return parseConnect(alloc, in); - }, - .PUB => { - @branchHint(.likely); - return parsePub(alloc, in); - }, - .HPUB => { - @branchHint(.likely); - return parseHPub(alloc, in); - }, - .PING => { - try expectStreamBytes(in, "\r\n"); - return .PING; - }, - .PONG => { - try expectStreamBytes(in, "\r\n"); - return .PONG; - }, - .SUB => { - return parseSub(alloc, in); - }, - .UNSUB => { - return parseUnsub(alloc, in); - }, - else => |msg| std.debug.panic("Not implemented: {}\n", .{msg}), - } - } -}; - -fn parseConnect(alloc: Allocator, in: *Reader) !Message { - // for storing the json string - var connect_string_writer_allocating: AllocatingWriter = .init(alloc); - defer connect_string_writer_allocating.deinit(); - var connect_string_writer = &connect_string_writer_allocating.writer; - - // for parsing the json string - var connect_arena_allocator: ArenaAllocator = .init(alloc); - defer connect_arena_allocator.deinit(); - const connect_allocator = connect_arena_allocator.allocator(); - - try in.discardAll(1); // throw away space - - // Should read the next JSON object to the fixed buffer writer. - _ = try in.streamDelimiter(connect_string_writer, '}'); - try connect_string_writer.writeByte('}'); - try expectStreamBytes(in, "}\r\n"); // discard '}\r\n' - - const connect_str = try connect_string_writer_allocating.toOwnedSlice(); - defer alloc.free(connect_str); - // TODO: should be CONNECTION allocator - const res = try std.json.parseFromSliceLeaky( - Message.Connect, - connect_allocator, - connect_str, - .{ .allocate = .alloc_always }, - ); - - return .{ .CONNECT = try res.dupe(alloc) }; -} - -fn parseSub(alloc: Allocator, in: *Reader) !Message { - try in.discardAll(1); // throw away space - const subject = try readSubject(alloc, in, .sub); - - const States = enum { - before_second, - in_second, - after_second, - in_third, - in_end, - }; - - var second: ArrayList(u8) = .empty; - errdefer second.deinit(alloc); - var third: ?ArrayList(u8) = null; - errdefer if (third) |*t| t.deinit(alloc); - - sw: switch (@as(States, .before_second)) { - .before_second => { - const byte = try in.peekByte(); - if (isWhitespace(byte)) { - in.toss(1); - continue :sw .before_second; - } - continue :sw .in_second; - }, - .in_second => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .after_second; - }, - .after_second => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } else if (isWhitespace(byte)) { - in.toss(1); - continue :sw .after_second; - } - third = .empty; - continue :sw .in_third; - }, - .in_third => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try third.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .in_end; - }, - .in_end => { - try expectStreamBytes(in, "\r\n"); - }, - } - - return .{ - .SUB = .{ - .subject = subject, - .queue_group = if (third) |_| try second.toOwnedSlice(alloc) else null, - .sid = if (third) |*t| try t.toOwnedSlice(alloc) else try second.toOwnedSlice(alloc), - }, - }; -} - -test parseSub { - const alloc = std.testing.allocator; - const expectEqualDeep = std.testing.expectEqualDeep; - { - var in: Reader = .fixed(" foo 1\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "foo", - .queue_group = null, - .sid = "1", - }, - }, - res, - ); - } - { - var in: Reader = .fixed(" foo 1\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "foo", - .queue_group = null, - .sid = "1", - }, - }, - res, - ); - } - { - var in: Reader = .fixed(" foo q 1\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "foo", - .queue_group = "q", - .sid = "1", - }, - }, - res, - ); - } - { - var in: Reader = .fixed(" 1 q 1\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "1", - .queue_group = "q", - .sid = "1", - }, - }, - res, - ); - } - { - var in: Reader = .fixed(" $SRV.PING 4\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "$SRV.PING", - .queue_group = null, - .sid = "4", - }, - }, - res, - ); - } - { - var in: Reader = .fixed(" foo.echo q 10\r\n"); - var res = try parseSub(alloc, &in); - defer res.SUB.deinit(alloc); - try expectEqualDeep( - Message{ - .SUB = .{ - .subject = "foo.echo", - .queue_group = "q", - .sid = "10", - }, - }, - res, - ); - } -} - -fn parseUnsub(alloc: Allocator, in: *Reader) !Message { - const States = enum { - before_first, - in_first, - after_first, - in_second, - in_end, - }; - - var first: ArrayList(u8) = .empty; - errdefer first.deinit(alloc); - var second: ?ArrayList(u8) = null; - defer if (second) |*s| s.deinit(alloc); - - sw: switch (@as(States, .before_first)) { - .before_first => { - const byte = try in.peekByte(); - if (isWhitespace(byte)) { - in.toss(1); - continue :sw .before_first; - } - continue :sw .in_first; - }, - .in_first => { - const byte = try in.peekByte(); - if (!isWhitespace(byte)) { - try first.append(alloc, byte); - in.toss(1); - continue :sw .in_first; - } - continue :sw .after_first; - }, - .after_first => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } else if (isWhitespace(byte)) { - in.toss(1); - continue :sw .after_first; - } - second = .empty; - continue :sw .in_second; - }, - .in_second => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } - try second.?.append(alloc, byte); - in.toss(1); - continue :sw .in_second; - }, - .in_end => { - try expectStreamBytes(in, "\r\n"); - }, - } - - return .{ - .UNSUB = .{ - .sid = try first.toOwnedSlice(alloc), - .max_msgs = if (second) |s| try parseUnsigned(usize, s.items, 10) else null, - }, - }; -} - -test parseUnsub { - const alloc = std.testing.allocator; - const expectEqualDeep = std.testing.expectEqualDeep; - const expectEqual = std.testing.expectEqual; - { - var in: Reader = .fixed(" 1\r\n"); - var res = try parseUnsub(alloc, &in); - defer res.UNSUB.deinit(alloc); - try expectEqualDeep( - Message{ - .UNSUB = .{ - .sid = "1", - .max_msgs = null, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } - - { - var in: Reader = .fixed(" 1 1\r\n"); - var res = try parseUnsub(alloc, &in); - defer res.UNSUB.deinit(alloc); - try expectEqualDeep( - Message{ - .UNSUB = .{ - .sid = "1", - .max_msgs = 1, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } -} - -fn parsePub(alloc: Allocator, in: *Reader) !Message { - try in.discardAll(1); // throw away space - - // Parse subject - const subject: []const u8 = try readSubject(alloc, in, .@"pub"); - errdefer alloc.free(subject); - - const States = enum { - before_second, - in_second, - after_second, - in_third, - in_end, - }; - - var second: ArrayList(u8) = .empty; - defer second.deinit(alloc); - var third: ?ArrayList(u8) = null; - defer if (third) |*t| t.deinit(alloc); - - sw: switch (@as(States, .before_second)) { - .before_second => { - // Drop whitespace - const byte = try in.peekByte(); - if (isWhitespace(byte)) { - in.toss(1); - continue :sw .before_second; - } - continue :sw .in_second; - }, - .in_second => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .after_second; - }, - .after_second => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } else if (isWhitespace(byte)) { - in.toss(1); - continue :sw .after_second; - } - third = .empty; - continue :sw .in_third; - }, - .in_third => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try third.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .in_end; - }, - .in_end => { - try expectStreamBytes(in, "\r\n"); - }, - } - - const reply_to: ?[]const u8, const bytes: usize = - if (third) |t| .{ - try alloc.dupe(u8, second.items), - try parseUnsigned(usize, t.items, 10), - } else .{ - null, - try parseUnsigned(usize, second.items, 10), - }; - - const payload: Payload = try .read(alloc, in, bytes); - errdefer payload.deinit(alloc); - try expectStreamBytes(in, "\r\n"); - - return .{ - .PUB = .{ - .subject = subject, - .payload = payload, - .reply_to = reply_to, - }, - }; -} - -test parsePub { - const alloc = std.testing.allocator; - const expectEqualDeep = std.testing.expectEqualDeep; - const expectEqual = std.testing.expectEqual; - { - var in: Reader = .fixed(" foo 3\r\nbar\r\n"); - var res = try parsePub(alloc, &in); - defer res.PUB.deinit(alloc); - try expectEqualDeep( - Message{ - .PUB = .{ - .subject = "foo", - .reply_to = null, - .payload = .{ - .len = 3, - .short = blk: { - var s: [128]u8 = undefined; - @memcpy(s[0..3], "bar"); - break :blk s; - }, - .long = null, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } - - { - var in: Reader = .fixed(" foo reply.to 3\r\nbar\r\n"); - var res = try parsePub(alloc, &in); - defer res.PUB.deinit(alloc); - try expectEqualDeep( - Message{ - .PUB = .{ - .subject = "foo", - .reply_to = "reply.to", - .payload = .{ - .len = 3, - .short = blk: { - var s: [128]u8 = undefined; - @memcpy(s[0..3], "bar"); - break :blk s; - }, - .long = null, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } - - // numeric reply subject - { - var in: Reader = .fixed(" foo 5 3\r\nbar\r\n"); - var res = try parsePub(alloc, &in); - defer res.PUB.deinit(alloc); - try expectEqualDeep( - Message{ - .PUB = .{ - .subject = "foo", - .reply_to = "5", - .payload = .{ - .len = 3, - .short = blk: { - var s: [128]u8 = undefined; - @memcpy(s[0..3], "bar"); - break :blk s; - }, - .long = null, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } -} - -fn parseHPub(alloc: Allocator, in: *Reader) !Message { - try in.discardAll(1); // throw away space - - // Parse subject - const subject: []const u8 = try readSubject(alloc, in, .@"pub"); - errdefer alloc.free(subject); - - const States = enum { - before_second, - in_second, - after_second, - in_third, - after_third, - in_fourth, - in_end, - }; - - var second: ArrayList(u8) = .empty; - defer second.deinit(alloc); - var third: ArrayList(u8) = .empty; - defer third.deinit(alloc); - var fourth: ?ArrayList(u8) = null; - defer if (fourth) |*f| f.deinit(alloc); - - sw: switch (@as(States, .before_second)) { - .before_second => { - // Drop whitespace - const byte = try in.peekByte(); - if (isWhitespace(byte)) { - in.toss(1); - continue :sw .before_second; - } - continue :sw .in_second; - }, - .in_second => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try second.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .after_second; - }, - .after_second => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } else if (isWhitespace(byte)) { - in.toss(1); - continue :sw .after_second; - } - third = .empty; - continue :sw .in_third; - }, - .in_third => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try third.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .after_third; - }, - .after_third => { - const byte = try in.peekByte(); - if (byte == '\r') { - continue :sw .in_end; - } else if (isWhitespace(byte)) { - in.toss(1); - continue :sw .after_third; - } - fourth = .empty; - continue :sw .in_fourth; - }, - .in_fourth => { - for (1..in.buffer.len) |i| { - try in.fill(i + 1); - if (isWhitespace(in.buffered()[i])) { - @memcpy(try fourth.?.addManyAsSlice(alloc, i), in.buffered()[0..i]); - in.toss(i); - break; - } - } else return error.EndOfStream; - continue :sw .in_end; - }, - .in_end => { - try expectStreamBytes(in, "\r\n"); - }, - } - - const reply_to: ?[]const u8, const header_bytes: usize, const total_bytes: usize = - if (fourth) |f| .{ - try alloc.dupe(u8, second.items), - try parseUnsigned(usize, third.items, 10), - try parseUnsigned(usize, f.items, 10), - } else .{ - null, - try parseUnsigned(usize, second.items, 10), - try parseUnsigned(usize, third.items, 10), - }; - - const payload: Payload = try .read(alloc, in, total_bytes); - errdefer payload.deinit(alloc); - try expectStreamBytes(in, "\r\n"); - - return .{ - .HPUB = .{ - .header_bytes = header_bytes, - .@"pub" = .{ - .subject = subject, - .payload = payload, - .reply_to = reply_to, - }, - }, - }; -} - -test parseHPub { - const alloc = std.testing.allocator; - const expectEqualDeep = std.testing.expectEqualDeep; - const expectEqual = std.testing.expectEqual; - { - var in: Reader = .fixed(" foo 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); - var res = try parseHPub(alloc, &in); - defer res.HPUB.deinit(alloc); - try expectEqualDeep( - Message{ - .HPUB = .{ - .header_bytes = 22, - .@"pub" = .{ - .subject = "foo", - .reply_to = null, - .payload = .{ - .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, - .short = blk: { - var s: [128]u8 = undefined; - const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; - @memcpy(s[0..str.len], str); - break :blk s; - }, - .long = null, - }, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } - - { - var in: Reader = .fixed(" foo reply.to 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); - var res = try parseHPub(alloc, &in); - defer res.HPUB.deinit(alloc); - try expectEqualDeep( - Message{ - .HPUB = .{ - .header_bytes = 22, - .@"pub" = .{ - .subject = "foo", - .reply_to = "reply.to", - .payload = .{ - .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, - .short = blk: { - var s: [128]u8 = undefined; - const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; - @memcpy(s[0..str.len], str); - break :blk s; - }, - .long = null, - }, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } - - { - var in: Reader = .fixed(" foo 6 22 33\r\nNATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!\r\n"); - var res = try parseHPub(alloc, &in); - defer res.HPUB.deinit(alloc); - try expectEqualDeep( - Message{ - .HPUB = .{ - .header_bytes = 22, - .@"pub" = .{ - .subject = "foo", - .reply_to = "6", - .payload = .{ - .len = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!".len, - .short = blk: { - var s: [128]u8 = undefined; - const str = "NATS/1.0\r\nBar: Baz\r\n\r\nHello NATS!"; - @memcpy(s[0..str.len], str); - break :blk s; - }, - .long = null, - }, - }, - }, - }, - res, - ); - try expectEqual(0, in.buffered().len); - } -} - -fn readSubject(alloc: Allocator, in: *Reader, comptime pub_or_sub: enum { @"pub", sub }) ![]const u8 { - var subject_list: ArrayList(u8) = .empty; - errdefer subject_list.deinit(alloc); - - // Handle the first character - { - const byte = try in.takeByte(); - if (isWhitespace(byte) or byte == '.' or (pub_or_sub == .@"pub" and (byte == '*' or byte == '>'))) - return error.InvalidStream; - - try subject_list.append(alloc, byte); - } - - switch (pub_or_sub) { - .sub => { - while (in.takeByte()) |byte| { - if (isWhitespace(byte)) break; - if (byte == '.') { - const next_byte = try in.peekByte(); - if (next_byte == '.' or isWhitespace(next_byte)) - return error.InvalidStream; - } else if (byte == '>') { - const next_byte = try in.takeByte(); - if (!isWhitespace(next_byte)) - return error.InvalidStream; - } else if (byte == '*') { - const next_byte = try in.peekByte(); - if (next_byte != '.' and !isWhitespace(next_byte)) - return error.InvalidStream; - } - try subject_list.append(alloc, byte); - } else |err| return err; - }, - .@"pub" => { - while (in.takeByte()) |byte| { - if (isWhitespace(byte)) break; - if (byte == '*' or byte == '>') return error.InvalidStream; - if (byte == '.') { - const next_byte = try in.peekByte(); - if (next_byte == '.' or isWhitespace(next_byte)) - return error.InvalidStream; - } - try subject_list.append(alloc, byte); - } else |err| return err; - }, - } - - return subject_list.toOwnedSlice(alloc); -} - -inline fn expectStreamBytes(reader: *Reader, expected: []const u8) !void { - if (!std.mem.eql(u8, try reader.take(expected.len), expected)) { - @branchHint(.unlikely); - return error.InvalidStream; - } -} - -test "parsing a stream" { - const alloc = std.testing.allocator; - const expectEqualDeep = std.testing.expectEqualDeep; - const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":fa" ++ - "lse,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43" ++ - ".0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r" ++ - "\nPUB hi 3\r\nfoo\r\n"; - var reader: Reader = .fixed(input); - var arena: ArenaAllocator = .init(alloc); - defer arena.deinit(); - const gpa = arena.allocator(); - - { - const msg: Message = try Message.next(gpa, &reader); - const expected: Message = .{ - .CONNECT = .{ - .verbose = false, - .pedantic = false, - .tls_required = false, - .name = "NATS CLI Version v0.2.4", - .lang = "go", - .version = "1.43.0", - .protocol = 1, - .echo = true, - .headers = true, - .no_responders = true, - }, - }; - - try expectEqualDeep(expected, msg); - } - { - const msg: Message = try Message.next(gpa, &reader); - const expected: Message = .{ - .PUB = .{ - .subject = "hi", - .payload = .{ - .len = 3, - .short = blk: { - var s: [128]u8 = undefined; - const str = "foo"; - @memcpy(s[0..str.len], str); - break :blk s; - }, - .long = null, - }, - }, - }; - try expectEqualDeep(expected, msg); - } -} diff --git a/src/subcommand/server.zig b/src/subcommand/server.zig new file mode 100644 index 0000000..1aaf572 --- /dev/null +++ b/src/subcommand/server.zig @@ -0,0 +1,66 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const AtomicValue = std.atomic.Value; +const DebugAllocator = std.heap.DebugAllocator; +const Sigaction = std.posix.Sigaction; + +const Io = std.Io; +const Threaded = Io.Threaded; + +const builtin = @import("builtin"); + +const zits = @import("zits"); +const Message = zits.Server.Message; +const ServerInfo = Message.ServerInfo; + +const Server = zits.Server; + +const safe_build = builtin.mode == .Debug or builtin.mode == .ReleaseSafe; + +var keep_running = AtomicValue(bool).init(true); + +fn handleSigInt(sig: std.os.linux.SIG) callconv(.c) void { + _ = sig; + keep_running.store(false, .monotonic); +} + +pub fn main(outer_alloc: Allocator, server_config: ServerInfo) !void { + // Configure the signal action + const act = Sigaction{ + .handler = .{ .handler = handleSigInt }, + .mask = std.posix.sigemptyset(), + .flags = 0, + }; + + // Register the handler for SIGINT (Ctrl+C) + std.posix.sigaction(std.posix.SIG.INT, &act, null); + + { + var dba: DebugAllocator(.{}) = .init; + dba.backing_allocator = outer_alloc; + defer _ = dba.deinit(); + const alloc = if (safe_build) dba.allocator() else outer_alloc; + + var threaded: Threaded = .init(alloc, .{}); + defer threaded.deinit(); + const io = threaded.io(); + + var server: Server = .{ + .info = server_config, + }; + defer server.deinit(io, alloc); + + var server_task = try io.concurrent(Server.start, .{ &server, io, alloc }); + defer server_task.cancel(io) catch {}; + + // Block until Ctrl+C + while (keep_running.load(.monotonic)) { + try io.sleep(.fromMilliseconds(1), .awake); + } + + std.debug.print("\n", .{}); + std.log.info("Shutting down...", .{}); + server_task.cancel(io) catch {}; + } + std.log.info("Goodbye", .{}); +} -- cgit