From 826da348a51c0650394e564850e9a0c65c1cfeea Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Tue, 2 Dec 2025 22:37:50 -0500 Subject: --- src/server/client.zig | 115 +++++++++++++++++++++++++----- src/server/main.zig | 64 +++++++++++++++-- src/server/message_parser.zig | 159 ++++++++++++++++++++++++------------------ src/server/test.zig | 3 + 4 files changed, 249 insertions(+), 92 deletions(-) create mode 100644 src/server/test.zig (limited to 'src/server') diff --git a/src/server/client.zig b/src/server/client.zig index c8a9239..b4bc55b 100644 --- a/src/server/client.zig +++ b/src/server/client.zig @@ -24,7 +24,7 @@ pub const ClientState = struct { connect: Message.Connect, in: *std.Io.Reader, out: *std.Io.Writer, - ) ClientState { + ) !ClientState { var res: ClientState = .{ .id = id, .connect = connect, @@ -37,12 +37,10 @@ pub const ClientState = struct { }; res.send_queue = .init(&res.send_queue_buffer); res.recv_queue = .init(&res.recv_queue_buffer); - const write_task = io.async(processWrite, .{ &res, io, out }); - // @compileLog(@TypeOf(write_task)); - const read_task = io.async(processRead, .{ &res, io, allocator, in }); - // @compileLog(@TypeOf(read_task)); - res.write_task = write_task; - res.read_task = read_task; + // res.send_queue = .init(&.{}); + // res.recv_queue = .init(&.{}); + res.write_task = try io.concurrent(processWrite, .{ &res, io, out }); + res.read_task = try io.concurrent(processRead, .{ &res, io, allocator, in }); return res; } @@ -53,13 +51,23 @@ pub const ClientState = struct { out: *std.Io.Writer, ) void { while (true) { - const message = self.recv_queue.getOne(io) catch break; + const message = self.recv_queue.getOne(io) catch continue; switch (message) { - .@"+ok" => writeOk(out) catch break, - .pong => writePong(out) catch break, - .info => |info| writeInfo(out, info) catch break, - .msg => |m| writeMsg(out, m) catch break, - else => std.debug.panic("unimplemented write", .{}), + .@"+ok" => { + writeOk(out) catch break; + }, + .pong => { + writePong(out) catch break; + }, + .info => |info| { + writeInfo(out, info) catch break; + }, + .msg => |m| { + writeMsg(out, m) catch break; + }, + else => { + std.debug.panic("unimplemented write", .{}); + }, } } } @@ -70,19 +78,27 @@ pub const ClientState = struct { allocator: std.mem.Allocator, in: *std.Io.Reader, ) void { + io.sleep(.fromMilliseconds(100), .real) catch @panic("couldn't sleep"); while (true) { + std.debug.print("waiting for message\n", .{}); const next_message = Message.next(allocator, in) catch |err| switch (err) { - error.EndOfStream => { - break; - }, + error.EndOfStream => break, else => { std.debug.panic("guh: {any}\n", .{err}); break; // return err; }, }; - self.send_queue.putOne(io, next_message) catch break; + std.debug.print("got message {}\n", .{next_message}); + // std.debug.print("queue: {any}\n", .{self.send_queue}); + self.send_queue.putOneUncancelable(io, next_message); //catch { + // std.debug.print("in catch\n\n\n", .{}); + + // std.debug.print("queue: {any}\n", .{self.send_queue}); + // }; } + + std.debug.print("no more messages\n", .{}); } pub fn deinit(self: *ClientState, alloc: std.mem.Allocator) void { @@ -102,7 +118,10 @@ pub const ClientState = struct { return (try self.recv_queue.put(io, &.{msg}, 0)) > 0; } - pub fn next(self: *ClientState, io: std.Io) std.Io.Cancelable!Message { + pub fn next(self: *ClientState, io: std.Io) !Message { + std.debug.print("in client awaiting next message\n", .{}); + errdefer std.debug.print("actually it was canceled\n", .{}); + defer std.debug.print("client returning next message!\n", .{}); return self.send_queue.getOne(io); } }; @@ -113,11 +132,13 @@ fn writeOk(out: *std.Io.Writer) !void { } fn writePong(out: *std.Io.Writer) !void { + std.debug.print("writing pong\n", .{}); _ = try out.write("PONG\r\n"); try out.flush(); } pub fn writeInfo(out: *std.Io.Writer, info: Message.ServerInfo) !void { + std.debug.print("writing info: {any}\n", .{info}); _ = try out.write("INFO "); try std.json.Stringify.value(info, .{}, out); _ = try out.write("\r\n"); @@ -138,3 +159,61 @@ fn writeMsg(out: *std.Io.Writer, msg: Message.Msg) !void { ); try out.flush(); } + +test { + const io = std.testing.io; + const gpa = std.testing.allocator; + + var from_client: std.Io.Reader = .fixed( + "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\n" ++ + "PING\r\n", + ); + var from_client_buf: [1024]Message = undefined; + var from_client_queue: std.Io.Queue(Message) = .init(&from_client_buf); + + while (Message.next(gpa, &from_client)) |msg| { + try from_client_queue.putOne(io, msg); + } else |_| {} + + for (0..2) |_| { + var msg = try from_client_queue.getOne(io); + std.debug.print("Message: {any}\n", .{msg}); + switch (msg) { + .connect => |*c| { + c.deinit(); + }, + else => {}, + } + } + + // const connect = (Message.next(gpa, &from_client) catch unreachable).connect; + + // var to_client_alloc: std.Io.Writer.Allocating = .init(gpa); + // defer to_client_alloc.deinit(); + // var to_client = to_client_alloc.writer; + + // var client: ClientState = try .init(io, gpa, 0, connect, &from_client, &to_client); + // defer client.deinit(gpa); + + // { + // var get_next = io.concurrent(ClientState.next, .{ &client, io }) catch unreachable; + // defer if (get_next.cancel(io)) |_| {} else |_| @panic("fail"); + + // var timeout = io.concurrent(std.Io.sleep, .{ io, .fromMilliseconds(1000), .awake }) catch unreachable; + // defer timeout.cancel(io) catch {}; + + // switch (try io.select(.{ + // .get_next = &get_next, + // .timeout = &timeout, + // })) { + // .get_next => |next| { + // std.debug.print("next is {any}\n", .{next}); + // try std.testing.expect((next catch |err| return err) == .ping); + // }, + // .timeout => { + // std.debug.print("reached timeout\n", .{}); + // return error.TestUnexpectedResult; + // }, + // } + // } +} diff --git a/src/server/main.zig b/src/server/main.zig index d90bea8..7f9b9a3 100644 --- a/src/server/main.zig +++ b/src/server/main.zig @@ -29,7 +29,10 @@ pub fn main(gpa: std.mem.Allocator, server_config: ServerInfo) !void { while (true) : (id +%= 1) { if (server.clients.contains(id)) continue; const stream = try tcp_server.accept(io); - _ = io.async(handleConnection, .{ &server, gpa, io, id, stream }); + _ = io.concurrent(handleConnection, .{ &server, gpa, io, id, stream }) catch { + std.debug.print("could not start concurrent handler for {d}\n", .{id}); + stream.close(io); + }; } } @@ -66,7 +69,7 @@ fn handleConnection( var connect_arena: std.heap.ArenaAllocator = .init(allocator); defer connect_arena.deinit(); const connect = (Message.next(connect_arena.allocator(), in) catch return).connect; - var client_state: ClientState = .init(io, allocator, id, connect, in, out); + var client_state: ClientState = try .init(io, allocator, id, connect, in, out); try server.addClient(allocator, id, client_state); defer server.removeClient(allocator, id); @@ -133,16 +136,24 @@ fn subscribe(server: *Server, gpa: std.mem.Allocator, id: usize, msg: Message.Su try server.subscriptions.put(gpa, msg.subject, subs_for_subject); } -fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_state: *ClientState) !void { +pub fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_state: *ClientState) !void { + defer std.debug.print("done processing client??\n", .{}); defer client_state.deinit(gpa); + std.debug.print("processing client: {d}\n", .{client_state.id}); while (true) { - switch (try client_state.next(io)) { + std.debug.print("awaiting next message from client\n", .{}); + switch (client_state.next(io)) { .ping => { + std.debug.print("got a ping! sending a pong.\n", .{}); for (0..5) |_| { - if (try client_state.send(io, .pong)) break; + if (try client_state.send(io, .pong)) { + std.debug.print("sent pong\n", .{}); + break; + } + std.debug.print("trying to send a pong again.\n", .{}); } else { - std.debug.print("could not pong to client {}\n", .{client_state.id}); + std.debug.print("could not pong to client {d}\n", .{client_state.id}); } }, .@"pub" => |msg| { @@ -158,6 +169,8 @@ fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_sta std.debug.panic("Unimplemented message: {any}\n", .{msg}); }, } + + std.debug.print("processed message from client\n", .{}); } // while (!io.cancelRequested()) { @@ -247,3 +260,42 @@ pub fn createId() []const u8 { pub fn createName() []const u8 { return "SERVERNAME"; } + +// TESTING + +// fn initTestServer() Server { +// return .{ +// .info = .{ +// .server_id = "ABCD", +// .server_name = "test server", +// .version = "0.1.2", +// .max_payload = 1234, +// }, +// }; +// } + +// fn initTestClient( +// io: std.Io, +// allocator: std.mem.Allocator, +// id: usize, +// data_from: []const u8, +// ) !struct { +// Client, +// *std.Io.Reader, +// *std.Io.Writer, +// } { +// return .init(io, allocator, id, .{}, in, out); +// } + +// test { +// const gpa = std.testing.allocator; +// const io = std.testing.io; + +// const server = initTestServer(); +// const client: Client = .init( +// io, +// gpa, +// 1, +// .{}, +// ); +// } diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig index bee11bc..e791d14 100644 --- a/src/server/message_parser.zig +++ b/src/server/message_parser.zig @@ -31,9 +31,9 @@ pub const MessageType = enum { } }; -pub const Message = union(enum) { +pub const Message = union(MessageType) { info: ServerInfo, - connect: Connect, + connect: AllocatedConnect, @"pub": Pub, hpub: void, sub: Sub, @@ -71,6 +71,14 @@ pub const Message = union(enum) { /// feature. proto: u32 = 1, }; + pub const AllocatedConnect = struct { + allocator: std.heap.ArenaAllocator, + connect: Connect, + + pub fn deinit(self: *AllocatedConnect) void { + self.allocator.deinit(); + } + }; pub const Connect = struct { verbose: bool = false, pedantic: bool = false, @@ -136,8 +144,20 @@ pub const Message = union(enum) { /// An error should be handled by cleaning up this connection. pub fn next(alloc: std.mem.Allocator, in: *std.Io.Reader) !Message { + // errdefer |err| { + // std.debug.print("Error occurred: {}\n", .{err}); + + // // Get the error return trace + // if (@errorReturnTrace()) |trace| { + // std.debug.print("Error return trace:\n", .{}); + // std.debug.dumpStackTrace(trace); + // } else { + // std.debug.print("No error return trace available\n", .{}); + // } + // } + var operation_string: std.ArrayList(u8) = blk: { - var buf: ["CONTINUE".len]u8 = undefined; + var buf: ["CONTINUE".len + 1]u8 = undefined; break :blk .initBuffer(&buf); }; @@ -149,15 +169,15 @@ pub const Message = union(enum) { } else |err| return err; const operation = parse(operation_string.items) orelse { - std.debug.print("operation: '{s}'\n", .{operation_string.items}); - std.debug.print("buffered: '{s}'", .{in.buffered()}); return error.InvalidOperation; }; switch (operation) { .connect => { // TODO: should be ARENA allocator - var connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(alloc, 1024); + var connect_arena_allocator: std.heap.ArenaAllocator = .init(alloc); + const connect_allocator = connect_arena_allocator.allocator(); + const connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(connect_allocator, 1024); var connect_string_writer = connect_string_writer_allocating.writer; try in.discardAll(1); // throw away space @@ -167,9 +187,9 @@ pub const Message = union(enum) { std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' // TODO: should be CONNECTION allocator - const res = try std.json.parseFromSliceLeaky(Connect, alloc, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); + const res = try std.json.parseFromSliceLeaky(Connect, connect_allocator, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); - return .{ .connect = res }; + return .{ .connect = .{ .allocator = connect_arena_allocator, .connect = res } }; }, .@"pub" => { try in.discardAll(1); // throw away space @@ -311,67 +331,70 @@ fn parsePub(in: *std.Io.Reader) !Message.Pub { } // try returning error in debug mode, only null in release? -pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { - const message_type: MessageType = blk: { - var word: ["CONNECT".len]u8 = undefined; - var len: usize = 0; - for (&word, 0..) |*b, i| { - const byte = in.takeByte() catch return null; - if (std.ascii.isUpper(byte)) { - b.* = byte; - len = i + 1; - } else break; - } - - break :blk Message.parse(word[0..len]) orelse return null; - }; - - // defer in.toss(2); // CRLF - return switch (message_type) { - .connect => blk: { - const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }; - - break :blk value; - }, - .@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) }, - .ping => .ping, - else => null, - }; -} - -test parseNextMessage { - const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n"; - var reader: std.Io.Reader = .fixed(input); - var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); - defer arena.deinit(); - const gpa = arena.allocator(); +// pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { +// const message_type: MessageType = blk: { +// var word: ["CONNECT".len]u8 = undefined; +// var len: usize = 0; +// for (&word, 0..) |*b, i| { +// const byte = in.takeByte() catch return null; +// if (std.ascii.isUpper(byte)) { +// b.* = byte; +// len = i + 1; +// } else break; +// } + +// break :blk Message.parse(word[0..len]) orelse return null; +// }; + +// // defer in.toss(2); // CRLF +// return switch (message_type) { +// .connect => blk: { +// const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }; + +// break :blk value; +// }, +// .@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) }, +// .ping => .ping, +// else => null, +// }; +// } - { - const msg: Message = try Message.next(gpa, &reader); - const expected: Message = .{ .connect = .{ - .verbose = false, - .pedantic = false, - .tls_required = false, - .name = "NATS CLI Version v0.2.4", - .lang = "go", - .version = "1.43.0", - .protocol = 1, - .echo = true, - .headers = true, - .no_responders = true, - } }; - - try std.testing.expectEqualDeep(expected, msg); - } - { - const msg: Message = try Message.next(gpa, &reader); - const expected: Message = .{ .@"pub" = .{ - .subject = "hi", - .payload = "foo", - } }; - try std.testing.expectEqualDeep(expected, msg); - } -} +// test parseNextMessage { +// const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n"; +// var reader: std.Io.Reader = .fixed(input); +// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); +// defer arena.deinit(); +// const gpa = arena.allocator(); + +// { +// const msg: Message = try Message.next(gpa, &reader); +// const expected: Message = .{ .connect = .{ +// .connect = .{ +// .verbose = false, +// .pedantic = false, +// .tls_required = false, +// .name = try gpa.dupe(u8, "NATS CLI Version v0.2.4"), +// .lang = try gpa.dupe(u8, "go"), +// .version = try gpa.dupe(u8, "1.43.0"), +// .protocol = 1, +// .echo = true, +// .headers = true, +// .no_responders = true, +// }, +// .allocator = arena, +// } }; + +// try std.testing.expectEqualDeep(expected, msg); +// } +// { +// const msg: Message = try Message.next(gpa, &reader); +// const expected: Message = .{ .@"pub" = .{ +// .subject = "hi", +// .payload = "foo", +// } }; +// try std.testing.expectEqualDeep(expected, msg); +// } +// } // test "MessageType.parse performance" { // // Measure perf for parseMemEql diff --git a/src/server/test.zig b/src/server/test.zig new file mode 100644 index 0000000..e89e354 --- /dev/null +++ b/src/server/test.zig @@ -0,0 +1,3 @@ +const std = @import("std"); +const Server = @import("./main.zig"); +const Client = @import("./client.zig").ClientState; -- cgit