diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/server/client.zig | 24 | ||||
| -rw-r--r-- | src/server/main.zig | 228 | ||||
| -rw-r--r-- | src/server/message_parser.zig | 123 |
3 files changed, 178 insertions, 197 deletions
diff --git a/src/server/client.zig b/src/server/client.zig index 458bbbb..ed1d33e 100644 --- a/src/server/client.zig +++ b/src/server/client.zig @@ -42,7 +42,6 @@ pub const ClientState = struct { ) void { while (true) { const message = self.recv_queue.getOne(io) catch break; - std.debug.print("got message in write loop to send to client: {any}\n", .{message}); switch (message) { .@"+ok" => { writeOk(self.to_client) catch break; @@ -54,28 +53,25 @@ pub const ClientState = struct { writeInfo(self.to_client, info) catch break; }, .msg => |m| { - if (writeMsg(self.to_client, m)) |_| { - @branchHint(.likely); - } else |_| { - @branchHint(.unlikely); - break; - } + writeMsg(self.to_client, m) catch break; }, else => { std.debug.panic("unimplemented write", .{}); }, } } + self.task.cancel(io); } pub fn deinit(self: *ClientState, io: std.Io, allocator: std.mem.Allocator) void { self.task.cancel(io); self.connect.deinit(); - allocator.destroy(self.recv_queue); + _ = allocator; + // allocator.destroy(self.recv_queue); } /// Return true if the value was put in the clients buffer to process, else false. - pub fn send(self: *ClientState, io: std.Io, msg: Message) std.Io.Cancelable!bool { + pub fn send(self: *ClientState, io: std.Io, msg: Message) (std.Io.Cancelable || std.Io.QueueClosedError)!bool { try self.recv_queue.putOne(io, msg); return true; } @@ -95,14 +91,11 @@ fn writeOk(out: *std.Io.Writer) !void { } fn writePong(out: *std.Io.Writer) !void { - std.debug.print("out pointer: {*}\n", .{out}); - std.debug.print("writing pong\n", .{}); _ = try out.write("PONG\r\n"); try out.flush(); } pub fn writeInfo(out: *std.Io.Writer, info: Message.ServerInfo) !void { - std.debug.print("writing info: {any}\n", .{info}); _ = try out.write("INFO "); try std.json.Stringify.value(info, .{}, out); _ = try out.write("\r\n"); @@ -110,7 +103,6 @@ pub fn writeInfo(out: *std.Io.Writer, info: Message.ServerInfo) !void { } fn writeMsg(out: *std.Io.Writer, msg: Message.Msg) !void { - std.debug.print("PRINTING MESSAGE\n\n\n\n", .{}); try out.print( "MSG {s} {s} {s} {d}\r\n{s}\r\n", .{ @@ -136,6 +128,7 @@ test { var from_client_queue: std.Io.Queue(Message) = .init(&from_client_buf); { + // Simulate stream while (Message.next(gpa, &from_client)) |msg| { switch (msg) { .eos => { @@ -146,7 +139,10 @@ test { try from_client_queue.putOne(io, msg); }, } - } else |_| {} + } else |err| switch (err) { + error.EndOfStream => try from_client_queue.close(io), + else => return err, + } while (from_client_queue.getOne(io)) |msg| { switch (msg) { diff --git a/src/server/main.zig b/src/server/main.zig index 2ebf96d..bb09179 100644 --- a/src/server/main.zig +++ b/src/server/main.zig @@ -10,7 +10,24 @@ clients: std.AutoHashMapUnmanaged(usize, *ClientState) = .empty, /// Map of subjects to a map of (client ID => SID) subscriptions: std.StringHashMapUnmanaged(std.AutoHashMapUnmanaged(usize, []const u8)) = .empty, +var keep_running = std.atomic.Value(bool).init(true); + +fn handleSigInt(sig: std.os.linux.SIG) callconv(.c) void { + _ = sig; + keep_running.store(false, .monotonic); +} + pub fn main(gpa: std.mem.Allocator, server_config: ServerInfo) !void { + // Configure the signal action + // const act = std.posix.Sigaction{ + // .handler = .{ .handler = handleSigInt }, + // .mask = std.posix.sigemptyset(), + // .flags = 0, + // }; + + // // Register the handler for SIGINT (Ctrl+C) + // std.posix.sigaction(std.posix.SIG.INT, &act, null); + var server: Server = .{ .info = server_config, }; @@ -26,7 +43,8 @@ pub fn main(gpa: std.mem.Allocator, server_config: ServerInfo) !void { defer tcp_server.deinit(io); var id: usize = 0; - while (true) : (id +%= 1) { + // Run until SIGINT is handled, then exit gracefully + while (keep_running.load(.monotonic)) : (id +%= 1) { std.debug.print("in server loop\n", .{}); if (server.clients.contains(id)) continue; const stream = try tcp_server.accept(io); @@ -36,6 +54,8 @@ pub fn main(gpa: std.mem.Allocator, server_config: ServerInfo) !void { stream.close(io); }; } + + std.debug.print("Exiting gracefully\n", .{}); } fn addClient(server: *Server, allocator: std.mem.Allocator, id: usize, client: *ClientState) !void { @@ -53,13 +73,20 @@ fn removeClient(server: *Server, allocator: std.mem.Allocator, id: usize) void { fn handleConnection( server: *Server, - allocator: std.mem.Allocator, + server_allocator: std.mem.Allocator, io: std.Io, id: usize, stream: std.Io.net.Stream, ) !void { + var client_allocator: std.heap.DebugAllocator(.{}) = .init; + client_allocator.backing_allocator = server_allocator; + defer { + std.debug.print("deinitializing debug allocator\n", .{}); + _ = client_allocator.deinit(); + } + const allocator = client_allocator.allocator(); defer stream.close(io); - var w_buffer: [1024]u8 = undefined; + var w_buffer: [4096]u8 = undefined; var writer = stream.writer(io, &w_buffer); const out = &writer.interface; @@ -79,155 +106,38 @@ fn handleConnection( try server.addClient(allocator, id, &client_state); defer server.removeClient(allocator, id); - // defer { - // server.clients.lockPointers(); - // server.clients.remove(allocator, id); - // server.clients.unlockPointers(); - // server.subscriptions.lockPointers(); - // var sub_iter = server.subscriptions.iterator(); - // var to_free: std.ArrayList(usize) = .empty; - // defer to_free.deinit(allocator); - // while (sub_iter.next()) |sub| { - // while (std.simd.firstIndexOfValue(sub.value_ptr.*, id)) |i| { - // sub.value_ptr.*.orderedRemove(i); - // } - // if (sub.value_ptr.items.len == 0) { - // to_free.append(allocator, sub.index); - // } - // } - // server.subscriptions.orderedRemoveAtMany(allocator, to_free.items); - // server.subscriptions.unlockPointers(); - // } - - { - defer std.debug.print("done processing client??\n", .{}); - std.debug.print("processing client: {d}\n", .{client_state.id}); - - std.debug.print("awaiting next message from client\n", .{}); - while (client_state.next(allocator)) |msg| { - std.debug.print("message from client!: {any}\n", .{msg}); - switch (msg) { - .ping => { - std.debug.print("got a ping! sending a pong.\n", .{}); - - std.debug.print("recv queue in server loop: {*}\n", .{&client_state.recv_queue}); - // @import("./client.zig").writePong(out) catch return; - for (0..5) |_| { - if (try client_state.send(io, .pong)) { - std.debug.print("sent pong\n", .{}); - break; - } - std.debug.print("trying to send a pong again.\n", .{}); - } else { - std.debug.print("could not pong to client {d}\n", .{client_state.id}); - } - }, - .@"pub" => |@"pub"| { - std.debug.print("pub: {any}\n", .{@"pub"}); - try server.publishMessage(io, @"pub"); - if (client_state.connect.connect.verbose) { - std.debug.print("server IS sending +ok\n", .{}); - _ = try client_state.send(io, .@"+ok"); - } else { - std.debug.print("server NOT sending +ok\n", .{}); + while (client_state.next(allocator)) |msg| { + switch (msg) { + .ping => { + // Respond to ping with pong. + for (0..5) |_| { + if (try client_state.send(io, .pong)) { + break; } - }, - .sub => |sub| { - try server.subscribe(allocator, client_state.id, sub); - }, - .eos => { - break; - }, - else => |e| { - std.debug.panic("Unimplemented message: {any}\n", .{e}); - }, - } - - std.debug.print("processed message from client\n", .{}); - std.debug.print("awaiting next message from client\n", .{}); - } else |_| {} - - // while (!io.cancelRequested()) { - // if (client_state.send_queue.getOne(io)) |msg| { - // switch (msg) { - // // Respond to ping with pong. - // .ping => { - // try client_state.recv_queue.putOne(io, .pong); - // }, - // .@"pub" => |p| { - // std.debug.print("subs (in pub): {any}\n", .{server.subscriptions}); - // std.debug.print("subs size: {d}\n", .{server.subscriptions.size}); - // std.debug.print("subs subjects:\n", .{}); - // var key_iter = server.subscriptions.keyIterator(); - // while (key_iter.next()) |k| { - // std.debug.print("- {s}\n", .{k.*}); - // } else std.debug.print("<none>\n", .{}); - // std.debug.print("pub subject: '{s}'\n", .{p.subject}); - // std.debug.print("pub: {any}\n", .{p}); - // errdefer _ = client_state.recv_queue.put(io, &.{.@"-err"}, 1) catch {}; - // // Just publishing to exact matches right now. - // // TODO: Publish to pattern matching subjects. - // if (server.subscriptions.get(p.subject)) |subs| { - // var subs_iter = subs.iterator(); - // while (subs_iter.next()) |sub| { - // var client = server.clients.get(sub.key_ptr.*) orelse std.debug.panic("Trying to pub to a client that doesn't exist!\n", .{}); - // std.debug.print("{d} is pubbing to {d}\n", .{ client_state.id, client.id }); - // if ((try client.recv_queue.put( - // io, - // &.{ - // .{ - // .msg = .{ - // .subject = p.subject, - // .sid = sub.value_ptr.*, - // .reply_to = p.reply_to, - // .payload = p.payload, - // }, - // }, - // }, - // 0, - // )) > 0) { - // std.debug.print("published message!\n", .{}); - // } else { - // std.debug.print("skipped publishing for some reason\n", .{}); - // } - // } - // try client_state.recv_queue.putOne(io, .@"+ok"); - // } else { - // std.debug.print("no subs with subject\n", .{}); - // } - // }, - // .sub => |s| { - // var subscribers = try server.subscriptions.getOrPut(gpa, s.subject); - // if (!subscribers.found_existing) { - // subscribers.value_ptr.* = .empty; - // } - // try subscribers.value_ptr.*.put(gpa, client_state.id, s.sid); - - // std.debug.print("subs: {any}\n", .{server.subscriptions}); - // }, - // .info => |info| { - // std.debug.panic("got an info message? : {any}\n", .{info}); - // }, - // else => |m| { - // std.debug.panic("Unimplemented: {any}\n", .{m}); - // }, - // } - // } else |err| return err; - // } - - // while (true) { - // switch (next_message) { - // .connect => |connect| { - // std.debug.panic("Connection message after already connected: {any}\n", .{connect}); - // }, - // .ping => try writePong(out), - // .@"pub" => try writeOk(out), - // else => |msg| std.debug.panic("Message type not implemented: {any}\n", .{msg}), - // } - // } + } else {} + }, + .@"pub" => |@"pub"| { + try server.publishMessage(io, @"pub"); + if (client_state.connect.connect.verbose) { + _ = try client_state.send(io, .@"+ok"); + } + }, + .sub => |sub| { + try server.subscribe(allocator, client_state.id, sub); + }, + .unsub => |unsub| { + try server.unsubscribe(client_state.id, unsub); + }, + else => |e| { + std.debug.panic("Unimplemented message: {any}\n", .{e}); + }, + } + } else |err| { + // This is probably going to be normal on disconnect + std.debug.print("Ran into error in client process loop: {}\n", .{err}); } - client_state.task.await(io); + // client_state.task.await(io); } // // Result is owned by the caller @@ -262,11 +172,31 @@ fn publishMessage(server: *Server, io: std.Io, msg: Message.Pub) !void { } fn subscribe(server: *Server, gpa: std.mem.Allocator, id: usize, msg: Message.Sub) !void { + std.debug.print("Recieved SUBSCRIBE message: {any}\n\n", .{msg}); var subs_for_subject: std.AutoHashMapUnmanaged(usize, []const u8) = if (server.subscriptions.fetchRemove(msg.subject)) |s| s.value else .empty; try subs_for_subject.put(gpa, id, msg.sid); try server.subscriptions.put(gpa, msg.subject, subs_for_subject); } +fn unsubscribe(server: *Server, id: usize, msg: Message.Unsub) !void { + // Get the subscription in subscriptions by looping over all the subjects, + // and getting the SID for that subject for the current client ID. + // If the SID matches, remove the kv for the client ID from subscriptions for that subject. + // If the value for that subject is empty, remove the subject. + var subscriptions_iter = server.subscriptions.iterator(); + while (subscriptions_iter.next()) |*subs_for_sub| { + if (subs_for_sub.value_ptr.get(id)) |client_sub| { + if (std.mem.eql(u8, client_sub, msg.sid)) { + _ = subs_for_sub.value_ptr.*.remove(id); + if (subs_for_sub.value_ptr.count() == 0) { + _ = server.subscriptions.remove(subs_for_sub.key_ptr.*); + } + break; + } + } + } +} + pub fn createId() []const u8 { return "SERVERID"; } diff --git a/src/server/message_parser.zig b/src/server/message_parser.zig index 5669d2a..9fd490c 100644 --- a/src/server/message_parser.zig +++ b/src/server/message_parser.zig @@ -13,7 +13,6 @@ pub const MessageType = enum { pong, @"+ok", @"-err", - eos, fn parseMemEql(input: []const u8) ?MessageType { // if (std.mem.eql(u8, "INFO", input)) return .info; @@ -38,16 +37,13 @@ pub const Message = union(MessageType) { @"pub": Pub, hpub: void, sub: Sub, - unsub: void, + unsub: Unsub, msg: Msg, hmsg: void, ping, pong, @"+ok": void, @"-err": void, - // Not an actual NATS message, but used to signal end of stream was reached in the input, - // and we should close the reader. - eos: void, pub const ServerInfo = struct { /// The unique identifier of the NATS server. server_id: []const u8, @@ -117,6 +113,12 @@ pub const Message = union(MessageType) { /// A unique alphanumeric subscription ID, generated by the client. sid: []const u8, }; + pub const Unsub = struct { + /// The unique alphanumeric subscription ID of the subject to unsubscribe from. + sid: []const u8, + /// A number of messages to wait for before automatically unsubscribing. + max_msgs: ?usize = null, + }; pub const Msg = struct { subject: []const u8, sid: []const u8, @@ -168,16 +170,9 @@ pub const Message = union(MessageType) { while (in.peekByte()) |byte| { if (std.ascii.isUpper(byte)) { try operation_string.appendBounded(byte); - try in.discardAll(1); + in.toss(1); } else break; - } else |err| switch (err) { - error.EndOfStream => { - return .{ .eos = {} }; - }, - else => { - return err; - }, - } + } else |err| return err; const operation = parse(operation_string.items) orelse { return error.InvalidOperation; @@ -195,7 +190,7 @@ pub const Message = union(MessageType) { // Should read the next JSON object to the fixed buffer writer. _ = try in.streamDelimiter(&connect_string_writer, '}'); try connect_string_writer.writeByte('}'); - std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' + try assertStreamBytes(in, "}\r\n"); // discard '}\r\n' // TODO: should be CONNECTION allocator const res = try std.json.parseFromSliceLeaky(Connect, connect_allocator, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); @@ -211,27 +206,27 @@ pub const Message = union(MessageType) { // Parse byte count const byte_count = blk: { var byte_count_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); - while (in.takeByte() catch null) |byte| { + while (in.peekByte()) |byte| { if (std.ascii.isWhitespace(byte)) { - std.debug.assert(byte == '\r'); - std.debug.assert(try in.takeByte() == '\n'); + try assertStreamBytes(in, "\r\n"); break; } + defer in.toss(1); if (std.ascii.isDigit(byte)) { try byte_count_list.append(alloc, byte); } else { return error.InvalidStream; } - } else return error.InvalidStream; + } else |err| return err; - break :blk try std.fmt.parseUnsigned(u64, byte_count_list.items, 10); + break :blk try std.fmt.parseUnsigned(usize, byte_count_list.items, 10); }; const payload = blk: { const bytes = try alloc.alloc(u8, byte_count); try in.readSliceAll(bytes); - std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); break :blk bytes; }; @@ -243,35 +238,38 @@ pub const Message = union(MessageType) { }; }, .ping => { - std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .ping; }, .pong => { - std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .pong; }, .sub => { - std.debug.assert(std.ascii.isWhitespace(try in.takeByte())); + if (!std.ascii.isWhitespace(try in.takeByte())) { + @branchHint(.unlikely); + return error.InvalidStream; + } const subject = try readSubject(alloc, in); const second = blk: { // Drop whitespace - while (in.peekByte() catch null) |byte| { + while (in.peekByte()) |byte| { if (std.ascii.isWhitespace(byte)) { in.toss(1); } else break; - } else return error.InvalidStream; + } else |err| return err; var acc: std.ArrayList(u8) = try .initCapacity(alloc, 32); - while (in.takeByte() catch null) |byte| { + while (in.takeByte()) |byte| { if (std.ascii.isWhitespace(byte)) break; try acc.append(alloc, byte); - } else return error.InvalidStream; + } else |err| return err; break :blk try acc.toOwnedSlice(alloc); }; const queue_group = if ((try in.peekByte()) != '\r') second else null; const sid = if (queue_group) |_| try in.takeDelimiterExclusive('\r') else second; - std.debug.assert(std.mem.eql(u8, try in.take(2), "\r\n")); + try assertStreamBytes(in, "\r\n"); return .{ .sub = .{ .subject = subject, @@ -280,6 +278,57 @@ pub const Message = union(MessageType) { }, }; }, + .unsub => { + if (!std.ascii.isWhitespace(try in.takeByte())) { + @branchHint(.unlikely); + return error.InvalidStream; + } + // Parse byte count + const sid = blk: { + var acc: std.ArrayList(u8) = try .initCapacity(alloc, 8); + while (in.peekByte()) |byte| { + if (std.ascii.isWhitespace(byte)) break; + try acc.append(alloc, byte); + in.toss(1); + } else |err| return err; + break :blk try acc.toOwnedSlice(alloc); + }; + + if ((try in.peekByte()) == '\r') { + try assertStreamBytes(in, "\r\n"); + return .{ + .unsub = .{ + .sid = sid, + }, + }; + } else if (std.ascii.isWhitespace(try in.peekByte())) { + in.toss(1); + const max_msgs = blk: { + var max_msgs_list: std.ArrayList(u8) = try .initCapacity(alloc, 64); + while (in.peekByte()) |byte| { + if (std.ascii.isWhitespace(byte)) { + try assertStreamBytes(in, "\r\n"); + break; + } + + if (std.ascii.isDigit(byte)) { + try max_msgs_list.append(alloc, byte); + } else { + return error.InvalidStream; + } + } else |err| return err; + + break :blk try std.fmt.parseUnsigned(usize, max_msgs_list.items, 10); + }; + + return .{ + .unsub = .{ + .sid = sid, + .max_msgs = max_msgs, + }, + }; + } else return error.InvalidStream; + }, else => |msg| std.debug.panic("Not implemented: {}\n", .{msg}), } } @@ -292,14 +341,13 @@ fn readSubject(alloc: std.mem.Allocator, in: *std.Io.Reader) ![]const u8 { // Handle the first character { const byte = try in.takeByte(); - std.debug.assert(!std.ascii.isWhitespace(byte)); - if (byte == '.') - return error.InvalidSubject; + if (std.ascii.isWhitespace(byte) or byte == '.') + return error.InvalidStream; try subject_list.append(alloc, byte); } - while (in.takeByte() catch null) |byte| { + while (in.takeByte()) |byte| { if (std.ascii.isWhitespace(byte)) break; if (std.ascii.isAscii(byte)) { if (byte == '.') { @@ -309,7 +357,7 @@ fn readSubject(alloc: std.mem.Allocator, in: *std.Io.Reader) ![]const u8 { } try subject_list.append(alloc, byte); } - } else return error.InvalidStream; + } else |err| return err; return subject_list.toOwnedSlice(alloc); } @@ -341,6 +389,13 @@ fn parsePub(in: *std.Io.Reader) !Message.Pub { }; } +inline fn assertStreamBytes(reader: *std.Io.Reader, expected: []const u8) !void { + if (!std.mem.eql(u8, try reader.take(expected.len), expected)) { + @branchHint(.unlikely); + return error.InvalidStream; + } +} + // try returning error in debug mode, only null in release? // pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { // const message_type: MessageType = blk: { |
