Diffstat (limited to 'src/main.zig')
-rw-r--r--   src/main.zig   181
1 file changed, 154 insertions, 27 deletions
diff --git a/src/main.zig b/src/main.zig
index 9682522..c6a8e76 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -124,8 +124,8 @@ pub fn main(init: std.process.Init) !void {
         return;
     }
 
-    var init_con_buf: [SaprusClient.max_payload_len]u8 = undefined;
-    var w: Writer = .fixed(&init_con_buf);
+    var con_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
+    var w: Writer = .fixed(&con_buf);
     try w.print("{b64}", .{flags.connect.?});
 
     if (flags.connect != null) {
@@ -142,6 +142,8 @@ pub fn main(init: std.process.Init) !void {
 
         log.debug("Connection started", .{});
 
+        var connection_writer: ConnectionWriter = .init(init.io, &connection, &con_buf);
+
         next_message: while (true) {
             var res_buf: [2048]u8 = undefined;
             try client.socket.setTimeout(if (is_debug) 60 else 600, 0);
@@ -160,34 +162,12 @@ pub fn main(init: std.process.Init) !void {
                 var child = std.process.spawn(init.io, .{
                     .argv = &.{ "bash", "-c", connection_payload },
                     .stdout = .pipe,
-                    .stderr = .pipe,
                 }) catch continue;
 
-                var child_stdout: std.ArrayList(u8) = .empty;
-                defer child_stdout.deinit(init.gpa);
-                var child_stderr: std.ArrayList(u8) = .empty;
-                defer child_stderr.deinit(init.gpa);
+                var child_output_buf: [SaprusClient.max_payload_len]u8 = undefined;
+                var child_output_reader = child.stdout.?.reader(init.io, &child_output_buf);
 
-                child.collectOutput(init.gpa, &child_stdout, &child_stderr, std.math.maxInt(usize)) catch |err| {
-                    log.debug("Failed to collect output: {t}", .{err});
-                    continue;
-                };
-                _ = try child.wait(init.io);
-
-                var cmd_output_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
-                var cmd_output: Writer = .fixed(&cmd_output_buf);
-
-                var cmd_output_window_iter = std.mem.window(u8, child_stdout.items, SaprusClient.max_payload_len, SaprusClient.max_payload_len);
-                while (cmd_output_window_iter.next()) |chunk| {
-                    cmd_output.end = 0;
-                    // Unreachable because the cmd_output_buf is twice the size of the chunk.
-                    cmd_output.print("{b64}", .{chunk}) catch unreachable;
-                    connection.send(init.io, cmd_output.buffered()) catch |err| {
-                        log.debug("Failed to send connection chunk: {t}", .{err});
-                        continue :next_message;
-                    };
-                    try init.io.sleep(.fromMilliseconds(40), .boot);
-                }
+                _ = child_output_reader.interface.stream(&connection_writer.interface, .limited(SaprusClient.max_payload_len * 10)) catch continue :next_message;
             }
         }
     }
@@ -195,6 +175,153 @@ pub fn main(init: std.process.Init) !void {
     unreachable;
 }
+const ConnectionWriter = struct {
+    connection: *zaprus.Connection,
+    io: std.Io,
+    interface: Writer,
+    err: ?anyerror,
+
+    pub fn init(io: std.Io, connection: *zaprus.Connection, buf: []u8) ConnectionWriter {
+        return .{
+            .connection = connection,
+            .io = io,
+            .interface = .{
+                .vtable = &.{
+                    .drain = drain,
+                },
+                .buffer = buf,
+            },
+            .err = null,
+        };
+    }
+
+    pub fn drain(io_w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
+        _ = splat;
+        const self: *ConnectionWriter = @alignCast(@fieldParentPtr("interface", io_w));
+        var res: usize = 0;
+
+        // Get buffered data from the writer
+        const buffered = io_w.buffered();
+        var buf_offset: usize = 0;
+
+        // Process buffered data in chunks
+        while (buf_offset < buffered.len) {
+            const chunk_size = @min(SaprusClient.max_payload_len, buffered.len - buf_offset);
+            const chunk = buffered[buf_offset..][0..chunk_size];
+
+            // Base64 encode the chunk
+            var encoded_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
+            const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len);
+            const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk);
+
+            // Send encoded chunk
+            self.connection.send(self.io, encoded[0..encoded_len]) catch |err| {
+                self.err = err;
+                return error.WriteFailed;
+            };
+            self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo");
+
+            buf_offset += chunk_size;
+            res += chunk_size;
+        }
+
+        // Process data slices
+        for (data) |slice| {
+            var slice_offset: usize = 0;
+
+            while (slice_offset < slice.len) {
+                const chunk_size = @min(SaprusClient.max_payload_len, slice.len - slice_offset);
+                const chunk = slice[slice_offset..][0..chunk_size];
+
+                // Base64 encode the chunk
+                var encoded_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
+                const encoded_len = std.base64.standard.Encoder.calcSize(chunk.len);
+                const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk);
+
+                // Send encoded chunk
+                self.connection.send(self.io, encoded[0..encoded_len]) catch |err| {
+                    self.err = err;
+                    return error.WriteFailed;
+                };
+                self.io.sleep(.fromMilliseconds(40), .boot) catch @panic("honk shoo");
+
+                slice_offset += chunk_size;
+                res += chunk_size;
+            }
+        }
+
+        return res;
+    }
+};
+
+// const ConnectionWriter = struct {
+//     connection: *zaprus.Connection,
+//     io: std.Io,
+//     interface: Writer,
+//     err: ?anyerror,
+
+//     pub fn init(io: std.Io, connection: *zaprus.Connection) ConnectionWriter {
+//         return .{
+//             .connection = connection,
+//             .io = io,
+//             .interface = .{},
+//         };
+//     }

+//     pub fn drain(io_w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
+//         var res: usize = 0;
+//         const w: *ConnectionWriter = @alignCast(@fieldParentPtr("interface", io_w));
+//         var buffered_reader: std.Io.Reader = .fixed(io_w.buffered());
+//         const io = w.io;

+//         // Collect the output in chunks
+//         var output_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
+//         var output_writer: Writer = .fixed(&output_buf);
+//         while (buffered_reader.end - buffered_reader.seek > SaprusClient.max_payload_len) {
+//             output_writer.end = 0;
+//             output_writer.print("{b64}", .{&buffered_reader.takeArray(SaprusClient.max_payload_len)});
+//             self.connection.send(io, output_writer.buffered()) catch |err| {
+//                 self.err = err;
+//                 return error.WriteFailed;
+//             };
+//             res += SaprusClient.max_payload_len;
+//         }
+//         // accumulate the remainder of buffered and the data slices before writing b64 to the output_writer
+//         var output_acc_buf: [SaprusClient.max_payload_len]u8 = undefined;
+//         var output_acc_w: Writer = .fixed(&output_acc_buf);

+//         // We can write the rest of buffered_reader to the output_writer because we know after
+//         // the previous loop the maximum length of the remaining data is SaprusClient.max_payload_len.
+//         output_writer.end = 0;
+//         res += output_acc_w.write(buffered_reader.buffered()) catch unreachable;

+//         for (data[0 .. data.len - 1]) |chunk| {
+//             if (chunk.len < SaprusClient.max_payload_len - output_acc_w.end) {
+//                 res += output_acc_w.write(chunk) catch unreachable;
+//                 continue;
+//             }
+//             var chunk_reader: std.Io.Reader = .fixed(chunk);
+//             while (chunk_reader.end - chunk_reader.seek > 0) {
+//                 res += chunk_reader.stream(
+//                     &output_acc_w,
+//                     .limited(SaprusClient.max_payload_len - output_acc_w.end),
+//                 ) catch unreachable;
+//                 if (SaprusClient.max_payload_len - output_acc_w.end == 0) {
+//                     output_writer.print("{b64}", .{output_acc_w.buffered()});
+//                     output_acc_w.end = 0;
+//                     self.connection.send(io, output_writer.buffered()) catch |err| {
+//                         self.err = err;
+//                         return error.WriteFailed;
+//                     };
+//                     output_writer.end = 0;
+//                 }
+//             }
+//         }

+//         return res;
+//     }
+// };
+
 
 fn parseDest(in: ?[]const u8) [4]u8 {
     if (in) |dest| {
         if (dest.len <= 4) {
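
The core of the new `ConnectionWriter.drain` is the same loop applied twice: split the pending bytes into payload-sized chunks, base64-encode each chunk into a scratch buffer, and hand the encoded text to the connection. Below is a minimal standalone sketch of that chunking pattern, where `max_payload_len` and the `send` callback are hypothetical stand-ins (not the project's `SaprusClient.max_payload_len` or `zaprus.Connection.send`):

```zig
const std = @import("std");

// Hypothetical stand-in for SaprusClient.max_payload_len.
const max_payload_len: usize = 1024;

// Split `input` into chunks of at most max_payload_len bytes, base64-encode
// each chunk into a stack buffer, and pass the encoded text to `send`.
fn sendChunked(input: []const u8, send: *const fn ([]const u8) anyerror!void) !void {
    var offset: usize = 0;
    while (offset < input.len) {
        const chunk_len = @min(max_payload_len, input.len - offset);
        const chunk = input[offset..][0..chunk_len];

        // Base64 output is ceil(n/3)*4 bytes, so 2x the payload size is plenty.
        var encoded_buf: [max_payload_len * 2]u8 = undefined;
        const encoded = std.base64.standard.Encoder.encode(&encoded_buf, chunk);
        try send(encoded);

        offset += chunk_len;
    }
}

test "every input byte is encoded and sent exactly once" {
    const Sink = struct {
        var decoded_total: usize = 0;
        fn send(encoded: []const u8) anyerror!void {
            decoded_total += try std.base64.standard.Decoder.calcSizeForSlice(encoded);
        }
    };
    const input = [_]u8{'x'} ** 3000;
    try sendChunked(&input, &Sink.send);
    try std.testing.expectEqual(@as(usize, input.len), Sink.decoded_total);
}
```

Base64 expands data by a factor of 4/3 plus padding, which is why both this sketch and the committed `drain` can get away with a `max_payload_len * 2` scratch buffer per encoded chunk.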
