From 018ea4761cd8a61ef3b45d80b68808b06f4bf6f2 Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Sun, 25 Jan 2026 13:22:08 -0500 Subject: Move connection writer into connection Make chunk size a part of the connection type --- src/Client.zig | 2 +- src/Connection.zig | 182 +++++++++++++++++++++++++++++++++++++++-------------- src/c_api.zig | 8 +-- src/main.zig | 86 ++----------------------- 4 files changed, 144 insertions(+), 134 deletions(-) diff --git a/src/Client.zig b/src/Client.zig index ae9ca66..8390ea3 100644 --- a/src/Client.zig +++ b/src/Client.zig @@ -137,7 +137,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection { const RawSocket = @import("./RawSocket.zig"); const SaprusMessage = @import("message.zig").Message; -const SaprusConnection = @import("Connection.zig"); +const SaprusConnection = @import("Connection.zig").Default; const EthIpUdp = @import("./EthIpUdp.zig").EthIpUdp; const std = @import("std"); diff --git a/src/Connection.zig b/src/Connection.zig index 95805de..fd201e9 100644 --- a/src/Connection.zig +++ b/src/Connection.zig @@ -1,57 +1,143 @@ -socket: RawSocket, -headers: EthIpUdp, -connection: SaprusMessage, - -const Connection = @This(); - -pub fn init(socket: RawSocket, headers: EthIpUdp, connection: SaprusMessage) Connection { - return .{ - .socket = socket, - .headers = headers, - .connection = connection, +pub fn Chunked(comptime cs: usize) type { + return struct { + socket: RawSocket, + headers: EthIpUdp, + connection: SaprusMessage, + + const Self = @This(); + + pub const chunk_size = cs; + + pub fn init(socket: RawSocket, headers: EthIpUdp, connection: SaprusMessage) Self { + return .{ + .socket = socket, + .headers = headers, + .connection = connection, + }; + } + + pub fn next(self: Self, io: Io, buf: []u8) ![]const u8 { + _ = io; + log.debug("Awaiting connection message", .{}); + const res = try self.socket.receive(buf); + log.debug("Received {} byte connection message", .{res.len}); + const msg: SaprusMessage = try .parse(res[42..]); + const connection_res = msg.connection; + + log.debug("Payload was {s}", .{connection_res.payload}); + + return connection_res.payload; + } + + pub fn send(self: *Self, io: Io, buf: []const u8) !void { + const io_source: std.Random.IoSource = .{ .io = io }; + const rand = io_source.interface(); + + log.debug("Sending connection message", .{}); + + self.connection.connection.payload = buf; + var connection_bytes_buf: [2048]u8 = undefined; + const connection_bytes = self.connection.toBytes(&connection_bytes_buf); + + self.headers.ip.id = rand.int(u16); + self.headers.setPayloadLen(connection_bytes.len); + + var msg_buf: [2048]u8 = undefined; + var msg_w: Io.Writer = .fixed(&msg_buf); + try msg_w.writeAll(&self.headers.toBytes()); + try msg_w.writeAll(connection_bytes); + const full_msg = msg_w.buffered(); + + try self.socket.send(full_msg); + + log.debug("Sent {} byte connection message", .{full_msg.len}); + } + + pub const Writer = struct { + connection: *Self, + io: Io, + interface: Io.Writer, + err: ?anyerror, + + pub fn init(io: Io, connection: *Self, buf: []u8) Writer { + return .{ + .connection = connection, + .io = io, + .interface = .{ + .vtable = &.{ + .drain = drain, + }, + .buffer = buf, + }, + .err = null, + }; + } + + pub fn drain(io_w: *Io.Writer, data: []const []const u8, splat: usize) Io.Writer.Error!usize { + _ = splat; + const self: *Writer = @alignCast(@fieldParentPtr("interface", io_w)); + var res: usize = 0; + + // Get buffered data from the writer + const buffered = io_w.buffered(); + var buf_offset: usize = 0; + + // Process buffered data in chunks + while (buf_offset < buffered.len) { + const current_chunk_size = @min(chunk_size, buffered.len - buf_offset); + const chunk = buffered[buf_offset..][0..current_chunk_size]; + + // Base64 encode the chunk + var encoded_buf: [chunk_size * 2]u8 = undefined; + const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len); + const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk); + + // Send encoded chunk + self.connection.send(self.io, encoded[0..encoded_len]) catch |err| { + self.err = err; + return error.WriteFailed; + }; + self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo"); + + buf_offset += current_chunk_size; + res += current_chunk_size; + } + + // Process data slices + for (data) |slice| { + var slice_offset: usize = 0; + + while (slice_offset < slice.len) { + const current_chunk_size = @min(chunk_size, slice.len - slice_offset); + const chunk = slice[slice_offset..][0..current_chunk_size]; + + // Base64 encode the chunk + var encoded_buf: [chunk_size * 2]u8 = undefined; + const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len); + const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk); + + // Send encoded chunk + self.connection.send(self.io, encoded[0..encoded_len]) catch |err| { + self.err = err; + return error.WriteFailed; + }; + self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo"); + + slice_offset += current_chunk_size; + res += current_chunk_size; + } + } + + return res; + } + }; }; } -pub fn next(self: Connection, io: Io, buf: []u8) ![]const u8 { - _ = io; - log.debug("Awaiting connection message", .{}); - const res = try self.socket.receive(buf); - log.debug("Received {} byte connection message", .{res.len}); - const msg: SaprusMessage = try .parse(res[42..]); - const connection_res = msg.connection; - - log.debug("Payload was {s}", .{connection_res.payload}); - - return connection_res.payload; -} - -pub fn send(self: *Connection, io: Io, buf: []const u8) !void { - const io_source: std.Random.IoSource = .{ .io = io }; - const rand = io_source.interface(); - - log.debug("Sending connection message", .{}); - - self.connection.connection.payload = buf; - var connection_bytes_buf: [2048]u8 = undefined; - const connection_bytes = self.connection.toBytes(&connection_bytes_buf); - - self.headers.ip.id = rand.int(u16); - self.headers.setPayloadLen(connection_bytes.len); - - var msg_buf: [2048]u8 = undefined; - var msg_w: Writer = .fixed(&msg_buf); - try msg_w.writeAll(&self.headers.toBytes()); - try msg_w.writeAll(connection_bytes); - const full_msg = msg_w.buffered(); - - try self.socket.send(full_msg); - - log.debug("Sent {} byte connection message", .{full_msg.len}); -} +pub const Default = Chunked(RawSocket.max_payload_len); const std = @import("std"); const Io = std.Io; -const Writer = std.Io.Writer; const log = std.log; diff --git a/src/c_api.zig b/src/c_api.zig index 964f399..830eb85 100644 --- a/src/c_api.zig +++ b/src/c_api.zig @@ -46,7 +46,7 @@ export fn zaprus_connect( const c: ?*zaprus.Client = @ptrCast(@alignCast(client)); const zc = c orelse return null; - const connection = alloc.create(zaprus.Connection) catch return null; + const connection = alloc.create(zaprus.Connection.Default) catch return null; connection.* = zc.connect(io, payload[0..payload_len]) catch { alloc.destroy(connection); return null; @@ -55,7 +55,7 @@ export fn zaprus_connect( } export fn zaprus_deinit_connection(connection: ?*ZaprusConnection) void { - const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection)); + const c: ?*zaprus.Connection.Default = @ptrCast(@alignCast(connection)); if (c) |zc| { alloc.destroy(zc); } @@ -67,7 +67,7 @@ export fn zaprus_connection_next( capacity: usize, out_len: *usize, ) c_int { - const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection)); + const c: ?*zaprus.Connection.Default = @ptrCast(@alignCast(connection)); const zc = c orelse return 1; const result = zc.next(io, out[0..capacity]) catch return 1; @@ -80,7 +80,7 @@ export fn zaprus_connection_send( payload: [*c]const u8, payload_len: usize, ) c_int { - const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection)); + const c: ?*zaprus.Connection.Default = @ptrCast(@alignCast(connection)); const zc = c orelse return 1; zc.send(io, payload[0..payload_len]) catch return 1; diff --git a/src/main.zig b/src/main.zig index c6a8e76..b4d1977 100644 --- a/src/main.zig +++ b/src/main.zig @@ -142,7 +142,7 @@ pub fn main(init: std.process.Init) !void { log.debug("Connection started", .{}); - var connection_writer: ConnectionWriter = .init(init.io, &connection, &con_buf); + var connection_writer: zaprus.Connection.Default.Writer = .init(init.io, &connection, &con_buf); next_message: while (true) { var res_buf: [2048]u8 = undefined; @@ -167,7 +167,10 @@ pub fn main(init: std.process.Init) !void { var child_output_buf: [SaprusClient.max_payload_len]u8 = undefined; var child_output_reader = child.stdout.?.reader(init.io, &child_output_buf); - _ = child_output_reader.interface.stream(&connection_writer.interface, .limited(SaprusClient.max_payload_len * 10)) catch continue :next_message; + _ = child_output_reader.interface.stream( + &connection_writer.interface, + .limited(@TypeOf(connection_writer.connection.*).chunk_size * 10), + ) catch continue :next_message; } } } @@ -175,85 +178,6 @@ pub fn main(init: std.process.Init) !void { unreachable; } -const ConnectionWriter = struct { - connection: *zaprus.Connection, - io: std.Io, - interface: Writer, - err: ?anyerror, - - pub fn init(io: std.Io, connection: *zaprus.Connection, buf: []u8) ConnectionWriter { - return .{ - .connection = connection, - .io = io, - .interface = .{ - .vtable = &.{ - .drain = drain, - }, - .buffer = buf, - }, - .err = null, - }; - } - - pub fn drain(io_w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { - _ = splat; - const self: *ConnectionWriter = @alignCast(@fieldParentPtr("interface", io_w)); - var res: usize = 0; - - // Get buffered data from the writer - const buffered = io_w.buffered(); - var buf_offset: usize = 0; - - // Process buffered data in chunks - while (buf_offset < buffered.len) { - const chunk_size = @min(SaprusClient.max_payload_len, buffered.len - buf_offset); - const chunk = buffered[buf_offset..][0..chunk_size]; - - // Base64 encode the chunk - var encoded_buf: [SaprusClient.max_payload_len * 2]u8 = undefined; - const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len); - const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk); - - // Send encoded chunk - self.connection.send(self.io, encoded[0..encoded_len]) catch |err| { - self.err = err; - return error.WriteFailed; - }; - self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo"); - - buf_offset += chunk_size; - res += chunk_size; - } - - // Process data slices - for (data) |slice| { - var slice_offset: usize = 0; - - while (slice_offset < slice.len) { - const chunk_size = @min(SaprusClient.max_payload_len, slice.len - slice_offset); - const chunk = slice[slice_offset..][0..chunk_size]; - - // Base64 encode the chunk - var encoded_buf: [SaprusClient.max_payload_len * 2]u8 = undefined; - const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len); - const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk); - - // Send encoded chunk - self.connection.send(self.io, encoded[0..encoded_len]) catch |err| { - self.err = err; - return error.WriteFailed; - }; - self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo"); - - slice_offset += chunk_size; - res += chunk_size; - } - } - - return res; - } -}; - // const ConnectionWriter = struct { // connection: *zaprus.Connection, // io: std.Io, -- cgit