const std = @import("std"); const Allocator = std.mem.Allocator; const ArrayList = std.ArrayList; const AutoHashMapUnmanaged = std.AutoHashMapUnmanaged; const Io = std.Io; const Dir = Io.Dir; const Group = Io.Group; const IpAddress = std.Io.net.IpAddress; const Mutex = Io.Mutex; const Queue = Io.Queue; const Stream = std.Io.net.Stream; pub const Client = @import("./Server/Client.zig"); pub const message = @import("./Server/message.zig"); const parse = message.parse; const MessageType = message.Control; const Message = message.Message; const ServerInfo = Message.ServerInfo; const Msgs = Client.Msgs; const Server = @This(); const builtin = @import("builtin"); const Subscription = struct { subject: []const u8, client_id: u128, sid: []const u8, queue_group: ?[]const u8, queue_lock: *Mutex, queue: *Queue(u8), fn deinit(self: Subscription, alloc: Allocator) void { alloc.free(self.subject); alloc.free(self.sid); if (self.queue_group) |g| alloc.free(g); } fn send(self: *Subscription, io: Io, buf: []u8, bytes: []const []const u8) !void { var w: std.Io.Writer = .fixed(buf); for (bytes) |chunk| { w.writeAll(chunk) catch unreachable; } try self.queue_lock.lock(io); defer self.queue_lock.unlock(io); try self.queue.putAll(io, w.buffered()); } }; const eql = std.mem.eql; const log = std.log.scoped(.zits); const panic = std.debug.panic; info: ServerInfo, subs_lock: Mutex = .init, subscriptions: ArrayList(Subscription) = .empty, pub fn deinit(server: *Server, io: Io, alloc: Allocator) void { server.subs_lock.lockUncancelable(io); for (server.subscriptions.items) |sub| { sub.deinit(alloc); } // TODO drain subscription queues server.subscriptions.deinit(alloc); server.subs_lock.unlock(io); } pub fn start(server: *Server, io: Io, gpa: Allocator) !void { var tcp_server = try IpAddress.listen(try IpAddress.parse( server.info.host, server.info.port, ), io, .{}); defer tcp_server.deinit(io); log.debug("Server headers: {s}", .{if (server.info.headers) "true" else "false"}); log.debug("Server max payload: {d}", .{server.info.max_payload}); log.info("Server ID: {s}", .{server.info.server_id}); log.info("Server name: {s}", .{server.info.server_name}); log.info("Server listening on {s}:{d}", .{ server.info.host, server.info.port }); var client_group: Group = .init; defer client_group.cancel(io); const read_buffer_size, const write_buffer_size = getBufferSizes(io); log.debug("read buf: {d} write buf: {d}", .{ read_buffer_size, write_buffer_size }); const rand_io: std.Random.IoSource = .{ .io = io }; const rand: std.Random = rand_io.interface(); var id = rand.int(u128); while (true) : (id = rand.int(u128)) { log.debug("Accepting next client", .{}); const stream = try tcp_server.accept(io); log.debug("Accepted connection {s}", .{idToStr(id)}); _ = client_group.concurrent(io, handleConnectionInfallible, .{ server, gpa, io, rand, id, stream, read_buffer_size, write_buffer_size, }) catch { log.err("Could not start concurrent handler for {s}", .{idToStr(id)}); stream.close(io); }; } } fn removeClient(server: *Server, io: Io, allocator: Allocator, id: u128) void { server.subs_lock.lockUncancelable(io); defer server.subs_lock.unlock(io); const len = server.subscriptions.items.len; for (0..len) |from_end| { const i = len - from_end - 1; const sub = server.subscriptions.items[i]; if (sub.client_id == id) { sub.deinit(allocator); _ = server.subscriptions.swapRemove(i); } } } fn handleConnectionInfallible( server: *Server, server_allocator: Allocator, io: Io, rand: std.Random, id: u128, stream: Stream, r_buf_size: usize, w_buf_size: usize, ) !void { handleConnection(server, server_allocator, io, rand, id, stream, r_buf_size, w_buf_size) catch |err| switch (err) { error.Canceled => return error.Canceled, error.ClientDisconnected => log.debug("Client {s} disconnected", .{idToStr(id)}), else => log.err("Failed processing client {s}: {t}", .{ idToStr(id), err }), }; } fn handleConnection( server: *Server, server_allocator: Allocator, io: Io, rand: std.Random, id: u128, stream: Stream, r_buf_size: usize, w_buf_size: usize, ) !void { defer stream.close(io); var dba: std.heap.DebugAllocator(.{}) = .init; dba.backing_allocator = server_allocator; defer _ = dba.deinit(); const alloc = if (builtin.mode == .Debug or builtin.mode == .ReleaseSafe) dba.allocator() else server_allocator; // Set up client writer const w_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), w_buf_size); defer alloc.free(w_buffer); var writer = stream.writer(io, w_buffer); const out = &writer.interface; // Set up client reader const r_buffer: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), r_buf_size); defer alloc.free(r_buffer); var reader = stream.reader(io, r_buffer); const in = &reader.interface; // Set up buffer queue const qbuf: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 256 * 1024 * 1024); defer alloc.free(qbuf); var recv_queue: Queue(u8) = .init(qbuf); defer recv_queue.close(io); const msg_write_buf: []u8 = try alloc.alignedAlloc(u8, .fromByteUnits(std.atomic.cache_line), 1 * 1024 * 1024); defer alloc.free(msg_write_buf); // Create client var client: Client = .init(null, &recv_queue, in, out); defer client.deinit(server_allocator); defer server.removeClient(io, server_allocator, id); // Do initial handshake with client _ = try out.write("INFO "); try std.json.Stringify.value(server.info, .{}, out); _ = try out.write("\r\n"); try out.flush(); var client_task = try io.concurrent(Client.start, .{ &client, io }); defer client_task.cancel(io) catch {}; while (client.next()) |ctrl| { switch (ctrl) { .PING => { // Respond to ping with pong. try client.recv_queue_write_lock.lock(io); defer client.recv_queue_write_lock.unlock(io); _ = try client.from_client.take(2); // throw out \r\n try client.recv_queue.putAll(io, "PONG\r\n"); }, .PUB => { @branchHint(.likely); // log.debug("received a pub msg", .{}); server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .@"pub") catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, }; }, .HPUB => { @branchHint(.likely); server.publishMessage(io, rand, server_allocator, msg_write_buf, &client, .hpub) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, }; }, .SUB => { server.subscribe(io, server_allocator, &client, id) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, }; }, .UNSUB => { server.unsubscribe(io, server_allocator, client.from_client, id) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, }; }, .CONNECT => { if (client.connect) |*current| { current.deinit(server_allocator); } client.connect = parse.connect(server_allocator, client.from_client) catch |err| switch (err) { error.ReadFailed => return reader.err.?, error.EndOfStream => return error.ClientDisconnected, else => |e| return e, }; }, else => |e| { panic("Unimplemented message: {any}\n", .{e}); }, } } else |err| switch (err) { error.EndOfStream => return error.ClientDisconnected, error.ReadFailed => switch (reader.err.?) { error.ConnectionResetByPeer => return error.ClientDisconnected, else => |e| return e, }, else => |e| return e, } } fn subjectMatches(sub_subject: []const u8, pub_subject: []const u8) bool { // TODO: assert that sub_subject and pub_subject are valid. var sub_iter = std.mem.splitScalar(u8, sub_subject, '.'); var pub_iter = std.mem.splitScalar(u8, pub_subject, '.'); while (sub_iter.next()) |st| { const pt = pub_iter.next() orelse return false; if (eql(u8, st, ">")) return true; if (!eql(u8, st, "*") and !eql(u8, st, pt)) { return false; } } return pub_iter.next() == null; } test subjectMatches { const expect = std.testing.expect; try expect(subjectMatches("foo", "foo")); try expect(!subjectMatches("foo", "bar")); try expect(subjectMatches("foo.*", "foo.bar")); try expect(!subjectMatches("foo.*", "foo")); try expect(!subjectMatches("foo.>", "foo")); // the wildcard subscriptions foo.*.quux and foo.> both match foo.bar.quux, but only the latter matches foo.bar.baz. try expect(subjectMatches("foo.*.quux", "foo.bar.quux")); try expect(subjectMatches("foo.>", "foo.bar.quux")); try expect(!subjectMatches("foo.*.quux", "foo.bar.baz")); try expect(subjectMatches("foo.>", "foo.bar.baz")); } fn publishMessage( server: *Server, io: Io, rand: std.Random, alloc: Allocator, msg_write_buf: []u8, source_client: *Client, comptime pub_or_hpub: enum { @"pub", hpub }, ) !void { defer if (source_client.connect) |c| { if (c.verbose) { if (source_client.recv_queue_write_lock.lock(io)) |_| { defer source_client.recv_queue_write_lock.unlock(io); source_client.recv_queue.putAll(io, "+OK\r\n") catch {}; } else |_| {} } }; const hpubmsg = switch (pub_or_hpub) { .@"pub" => {}, .hpub => try parse.hpub(source_client.from_client), }; const msg: Message.Pub = switch (pub_or_hpub) { .@"pub" => try parse.@"pub"(source_client.from_client), .hpub => hpubmsg.@"pub", }; try server.subs_lock.lock(io); defer server.subs_lock.unlock(io); var published_queue_groups: ArrayList([]const u8) = .empty; defer published_queue_groups.deinit(alloc); subs: for (0..server.subscriptions.items.len) |i| { var subscription = server.subscriptions.items[i]; if (subjectMatches(subscription.subject, msg.subject)) { if (subscription.queue_group) |sg| { for (published_queue_groups.items) |g| { if (eql(u8, g, sg)) { continue :subs; } } // Don't republish to the same queue try published_queue_groups.append(alloc, sg); } // The rest of this loop is setting up a slice of byte slices to simultaneously // send to the underlying queue. // Each "chunk" is a section of the message to be sent. // The chunk_count starts off at the minimum number of chunks per message, and // then increases as branches add additional chunks. // The msg_chunks_buf.len is the maximum number of chunks in a message. // We can appendAssumeCapacity because it is a programmer error to append // more than max_msg_chunks. // If we need to append more chunks, this value should be increased. // The reason for this strategy is to avoid any intermediary allocations between // the publishers read buffer, and the subscribers write buffer. const min_msg_chunks, const max_msg_chunks = .{ 7, 10 }; var chunk_count: usize = min_msg_chunks; var msg_chunks_buf: [max_msg_chunks][]const u8 = undefined; var msg_chunks: ArrayList([]const u8) = .initBuffer(&msg_chunks_buf); switch (pub_or_hpub) { .@"pub" => _ = msg_chunks.appendAssumeCapacity("MSG "), .hpub => _ = msg_chunks.appendAssumeCapacity("HMSG "), } msg_chunks.appendAssumeCapacity(msg.subject); msg_chunks.appendAssumeCapacity(" "); msg_chunks.appendAssumeCapacity(subscription.sid); msg_chunks.appendAssumeCapacity(" "); if (msg.reply_to) |reply_to| { chunk_count += 2; msg_chunks.appendAssumeCapacity(reply_to); msg_chunks.appendAssumeCapacity(" "); } switch (pub_or_hpub) { .hpub => { chunk_count += 1; var hlen_buf: [std.fmt.count("{d} ", .{std.math.maxInt(usize)})]u8 = undefined; msg_chunks.appendAssumeCapacity( std.fmt.bufPrint(&hlen_buf, "{d} ", .{hpubmsg.header_bytes}) catch unreachable, ); }, else => {}, } var len_buf: [std.fmt.count("{d}\r\n", .{std.math.maxInt(usize)})]u8 = undefined; msg_chunks.appendAssumeCapacity( std.fmt.bufPrint(&len_buf, "{d}\r\n", .{msg.payload.len - 2}) catch unreachable, ); msg_chunks.appendAssumeCapacity(msg.payload); subscription.send(io, msg_write_buf, msg_chunks.items[0..chunk_count]) catch |err| switch (err) { error.Closed => {}, error.Canceled => |e| return e, }; } } rand.shuffle(Subscription, server.subscriptions.items); } fn subscribe( server: *Server, io: Io, gpa: Allocator, client: *Client, id: u128, ) !void { const msg = try parse.sub(client.from_client); try server.subs_lock.lock(io); defer server.subs_lock.unlock(io); const subject = try gpa.dupe(u8, msg.subject); errdefer gpa.free(subject); const sid = try gpa.dupe(u8, msg.sid); errdefer gpa.free(sid); const queue_group = if (msg.queue_group) |q| try gpa.dupe(u8, q) else null; errdefer if (queue_group) |q| gpa.free(q); try server.subscriptions.append(gpa, .{ .subject = subject, .client_id = id, .sid = sid, .queue_group = queue_group, .queue_lock = &client.recv_queue_write_lock, .queue = client.recv_queue, }); log.debug("Client {s} subscribed to {s}", .{ idToStr(id), msg.subject }); } fn unsubscribe( server: *Server, io: Io, gpa: Allocator, senders_reader: *std.Io.Reader, id: u128, ) !void { const msg = try parse.unsub(senders_reader); try server.subs_lock.lock(io); defer server.subs_lock.unlock(io); const len = server.subscriptions.items.len; for (0..len) |from_end| { const i = len - from_end - 1; const sub = server.subscriptions.items[i]; if (sub.client_id == id and eql(u8, sub.sid, msg.sid)) { log.debug("Client {s} unsubscribed from {s}", .{ idToStr(id), server.subscriptions.items[i].subject }); sub.deinit(gpa); _ = server.subscriptions.swapRemove(i); break; } } } // TODO: The probed system value is too low. // Setting the value higher leads to higher throughput. // Find a more appropriate value. // It should be the probed value at a minimum. /// Probes the system for an appropriate buffer size. /// Try to match the kernel socket buffers to maximize /// the amount of data we push through each syscall. fn getBufferSizes(io: Io) @Tuple(&.{ usize, usize }) { const default_size = 4 * 1024; const default = .{ default_size, default_size }; // if (true) return default; const dir = Dir.openDirAbsolute(io, "/proc/sys/net/core", .{}) catch { log.warn("couldn't open /proc/sys/net/core", .{}); return default; }; var buf: [64]u8 = undefined; const rmem_max = readBufferSize(io, dir, "rmem_max", &buf, default_size) * 2; const wmem_max = readBufferSize(io, dir, "wmem_max", &buf, default_size) * 2; return .{ rmem_max, wmem_max }; } fn readBufferSize(io: Io, dir: anytype, filename: []const u8, buf: []u8, default: usize) usize { const bytes = dir.readFile(io, filename, buf) catch |err| { log.err("couldn't open {s}: {any}", .{ filename, err }); return default; }; return std.fmt.parseUnsigned(usize, bytes[0 .. bytes.len - 1], 10) catch |err| { log.err("couldn't parse {s}: {any}", .{ bytes[0 .. bytes.len - 1], err }); return default; }; } const uuid_len = 36; fn idToStr(in: u128) [uuid_len]u8 { // Extract segments using bit shifting and casting const part1: u32 = @intCast(in >> 96); const part2: u16 = @intCast((in >> 80) & 0xFFFF); const part3: u16 = @intCast((in >> 64) & 0xFFFF); const part4: u16 = @intCast((in >> 48) & 0xFFFF); const part5: u64 = @intCast(in & 0xFFFFFFFFFFFF); var res: [uuid_len]u8 = undefined; // bufPrint returns a slice of the buffer; we ignore it as we return the whole array _ = std.fmt.bufPrint(&res, "{x:0>8}-{x:0>4}-{x:0>4}-{x:0>4}-{x:0>12}", .{ part1, part2, part3, part4, part5, }) catch unreachable; // unreachable because the buffer size is guaranteed return res; } pub const default_id = "server-id-123"; pub const default_name = "Zits Server";