From fe166d21060ee541d1d053da3a85144c7b269120 Mon Sep 17 00:00:00 2001 From: Robby Zambito Date: Sun, 12 Oct 2025 17:12:58 -0400 Subject: Start breaking out net logic to NetWriter --- src/Client.zig | 202 ++++++++++++----------------------------------------- src/NetWriter.zig | 205 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.zig | 19 +++-- src/root.zig | 1 + 4 files changed, 264 insertions(+), 163 deletions(-) create mode 100644 src/NetWriter.zig (limited to 'src') diff --git a/src/Client.zig b/src/Client.zig index 24dbcc0..7de50f1 100644 --- a/src/Client.zig +++ b/src/Client.zig @@ -1,7 +1,6 @@ const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '='); const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '='); -rand: Random, writer: *std.Io.Writer, const Self = @This(); @@ -9,15 +8,7 @@ const Self = @This(); const max_message_size = 2048; pub fn init(writer: *std.Io.Writer) !Self { - var prng = Random.DefaultPrng.init(blk: { - var seed: u64 = undefined; - try posix.getrandom(mem.asBytes(&seed)); - break :blk seed; - }); - const rand = prng.random(); - return .{ - .rand = rand, .writer = writer, }; } @@ -31,40 +22,8 @@ pub fn deinit(self: *Self) void { fn broadcastInitialInterestMessage(self: *Self, msg_bytes: []u8) !void { const writer = self.writer; - const total_len = @sizeOf(EthernetHeaders) + @sizeOf(IpHeaders) + @sizeOf(UdpHeaders) + msg_bytes.len; - // Ensure the writer is in a valid state - std.debug.assert(writer.buffer.len >= total_len); - _ = writer.consumeAll(); - - const ether_headers: EthernetHeaders = .{ - .dest_mac = .{ 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff }, - .src_mac = blk: { - var output_bytes: [6]u8 = undefined; - output_bytes[0] = 0xee; - self.rand.bytes(output_bytes[1..]); - break :blk output_bytes; - }, - .ether_type = 0x0800, - }; - - const ip_headers: IpHeaders = .{ - .total_length = @intCast(total_len - 92), - .ttl = 0x64, - .protocol = 0x11, - .src_ip = .{ 0xff, 0x02, 0x03, 0x04 }, - .dest_ip = .{ 0xff, 0xff, 0xff, 0xff }, - }; - - const udp_headers: UdpHeaders = .{ - .src_port = 0xbbbb, - .dest_port = 8888, - .length = @intCast(msg_bytes.len), - }; - - try ether_headers.write(writer); - try ip_headers.write(writer); - try udp_headers.write(writer); + std.debug.assert(writer.buffer.len - writer.end >= msg_bytes.len); // Saprus const msg_target_bytes = try writer.writableSlice(msg_bytes.len); @@ -118,137 +77,68 @@ pub fn sendRelay(self: *Self, payload: []const u8, dest: [4]u8) !void { try self.broadcastInitialInterestMessage(msg_bytes); } -fn randomPort(self: Self) u16 { - return self.rand.intRangeAtMost(u16, 1024, 65000); -} +// pub fn sendInitialConnection( +// self: Self, +// payload: []const u8, +// output_bytes: []u8, +// initial_port: u16, +// ) !*align(1) SaprusMessage { +// const dest_port = self.randomPort(); +// const msg_bytes = output_bytes[0..try SaprusMessage.calcSize( +// .connection, +// base64Enc.calcSize(payload.len), +// )]; +// const msg: *align(1) SaprusMessage = .init(.connection, msg_bytes); -pub fn sendInitialConnection( - self: Self, - payload: []const u8, - output_bytes: []u8, - initial_port: u16, -) !*align(1) SaprusMessage { - const dest_port = self.randomPort(); - const msg_bytes = output_bytes[0..try SaprusMessage.calcSize( - .connection, - base64Enc.calcSize(payload.len), - )]; - const msg: *align(1) SaprusMessage = .init(.connection, msg_bytes); +// const connection = (try msg.getSaprusTypePayload()).connection; +// connection.src_port = initial_port; +// connection.dest_port = dest_port; +// _ = base64Enc.encode(connection.getPayload(), payload); - const connection = (try msg.getSaprusTypePayload()).connection; - connection.src_port = initial_port; - connection.dest_port = dest_port; - _ = base64Enc.encode(connection.getPayload(), payload); +// try broadcastSaprusMessage(msg_bytes, 8888); - try broadcastSaprusMessage(msg_bytes, 8888); +// return msg; +// } - return msg; -} - -pub fn connect(self: Self, payload: []const u8) !?SaprusConnection { - const initial_port = self.randomPort(); - - var initial_conn_res: ?*align(1) SaprusMessage = null; +// pub fn connect(self: Self, payload: []const u8) !?SaprusConnection { +// const initial_port = self.randomPort(); - var sock = try network.Socket.create(.ipv4, .udp); - defer sock.close(); +// var initial_conn_res: ?*align(1) SaprusMessage = null; - // Bind to 255.255.255.255:8888 - const bind_addr = network.EndPoint{ - .address = network.Address{ .ipv4 = network.Address.IPv4.broadcast }, - .port = 8888, - }; +// var sock = try network.Socket.create(.ipv4, .udp); +// defer sock.close(); - // timeout 1s - try sock.setReadTimeout(1 * std.time.us_per_s); - try sock.bind(bind_addr); +// // Bind to 255.255.255.255:8888 +// const bind_addr = network.EndPoint{ +// .address = network.Address{ .ipv4 = network.Address.IPv4.broadcast }, +// .port = 8888, +// }; - var sent_msg_bytes: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined; - const msg = try self.sendInitialConnection(payload, &sent_msg_bytes, initial_port); +// // timeout 1s +// try sock.setReadTimeout(1 * std.time.us_per_s); +// try sock.bind(bind_addr); - var response_buf: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined; - _ = try sock.receive(&response_buf); // Ignore message that I sent. - const len = try sock.receive(&response_buf); +// var sent_msg_bytes: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined; +// const msg = try self.sendInitialConnection(payload, &sent_msg_bytes, initial_port); - initial_conn_res = try .networkBytesAsValue(response_buf[0..len]); +// var response_buf: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined; +// _ = try sock.receive(&response_buf); // Ignore message that I sent. +// const len = try sock.receive(&response_buf); - // Complete handshake after awaiting response - try broadcastSaprusMessage(msg.asBytes(), self.randomPort()); +// initial_conn_res = try .networkBytesAsValue(response_buf[0..len]); - if (false) { - return initial_conn_res.?; - } - return null; -} +// // Complete handshake after awaiting response +// try broadcastSaprusMessage(msg.asBytes(), self.randomPort()); -const EthernetHeaders = struct { - dest_mac: @Vector(6, u8), - - src_mac: @Vector(6, u8), - - ether_type: u16, - - fn write(hdr: @This(), writer: *std.Io.Writer) !void { - try writer.writeInt(u48, @bitCast(hdr.dest_mac), .big); - try writer.writeInt(u48, @bitCast(hdr.src_mac), .big); - try writer.writeInt(u16, hdr.ether_type, .big); - } -}; - -const IpHeaders = struct { - _: u8 = 0x45, - // ip_version: u4, - // header_length: u4 = 0, - type_of_service: u8 = 0, - total_length: u16 = 0x04, - - identification: u16 = 0, - __: u16 = 0x0, - // ethernet_flags: u3 = 0, - // fragment_offset: u13 = 0, - ttl: u8 = 0, - protocol: u8 = 0, - - header_checksum: @Vector(2, u8) = .{ 0, 0 }, - src_ip: @Vector(4, u8), - - dest_ip: @Vector(4, u8), - - fn write(hdr: @This(), writer: *std.Io.Writer) !void { - try writer.writeInt(u8, 0x45, .big); // ip version and header length - try writer.writeByte(hdr.type_of_service); - try writer.writeInt(u16, hdr.total_length, .big); - try writer.writeInt(u16, hdr.identification, .big); - try writer.writeInt(u16, 0x00, .big); // ethernet flags and fragment offset - try writer.writeByte(hdr.ttl); - try writer.writeByte(hdr.protocol); - try writer.writeInt(u16, @bitCast(hdr.header_checksum), .big); - try writer.writeInt(u32, @bitCast(hdr.src_ip), .big); - try writer.writeInt(u32, @bitCast(hdr.dest_ip), .big); - } -}; - -const UdpHeaders = packed struct { - src_port: u16, - - dest_port: u16, - length: u16, - checksum: @Vector(2, u8) = .{ 0, 0 }, - - fn write(hdr: @This(), writer: *std.Io.Writer) !void { - try writer.writeInt(u16, hdr.src_port, .big); - try writer.writeInt(u16, hdr.dest_port, .big); - try writer.writeInt(u16, hdr.length, .big); - try writer.writeInt(u16, @bitCast(hdr.checksum), .big); - } -}; +// if (false) { +// return initial_conn_res.?; +// } +// return null; +// } const SaprusMessage = @import("message.zig").Message; const SaprusConnection = @import("Connection.zig"); const std = @import("std"); -const Random = std.Random; -const posix = std.posix; -const mem = std.mem; const network = @import("network"); diff --git a/src/NetWriter.zig b/src/NetWriter.zig new file mode 100644 index 0000000..b08ccfe --- /dev/null +++ b/src/NetWriter.zig @@ -0,0 +1,205 @@ +//! Wraps a writer with UDP headers. +//! This is useful for wrapping RawSocket Writer with appropriate headers. + +rand: Random, +wrapped: *Writer, +interface: Writer, + +pub fn init(w: *Writer) !NetWriter { + std.debug.assert(w.buffer.len > @sizeOf(EthernetHeaders) + @sizeOf(IpHeaders) + @sizeOf(UdpHeaders)); + + var prng = Random.DefaultPrng.init(blk: { + var seed: u64 = undefined; + try posix.getrandom(mem.asBytes(&seed)); + break :blk seed; + }); + + return .{ + .rand = prng.random(), + .wrapped = w, + .interface = .{ + .vtable = &.{ + .drain = drain, + .flush = flush, + }, + .buffer = &.{}, + }, + }; +} + +fn drain(io_w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { + const w: *NetWriter = @alignCast(@fieldParentPtr("interface", io_w)); + + var res: usize = 0; + + if (io_w.end == 0) { + const ether_headers: EthernetHeaders = .{ + .dest_mac = .{ 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff }, + .src_mac = blk: { + var output_bytes: [6]u8 = undefined; + output_bytes[0] = 0xee; + w.rand.bytes(output_bytes[1..]); + break :blk output_bytes; + }, + .ether_type = 0x0800, + }; + + const ip_headers: IpHeaders = .{ + .total_length = @intCast(res - 92), + .ttl = 0x64, + .protocol = 0x11, + .src_ip = .{ 0xff, 0x02, 0x03, 0x04 }, + .dest_ip = .{ 0xff, 0xff, 0xff, 0xff }, + }; + + const udp_headers: UdpHeaders = .{ + .src_port = 0xbbbb, + .dest_port = 8888, + .length = @intCast(res), + }; + + res += try ether_headers.write(w.wrapped); + res += try ip_headers.write(w.wrapped); + res += try udp_headers.write(w.wrapped); + } + + res += try w.wrapped.writeSplat(data, splat); + return res; +} + +fn flush(io_w: *Writer) Writer.Error!void { + const w: *NetWriter = @alignCast(@fieldParentPtr("interface", io_w)); + try w.wrapped.flush(); +} + +const EthernetHeaders = struct { + dest_mac: @Vector(6, u8), + + src_mac: @Vector(6, u8), + + ether_type: u16, + + fn write(hdr: EthernetHeaders, writer: *std.Io.Writer) Writer.Error!usize { + comptime var res: usize = 0; + + res += @sizeOf(u48); + try writer.writeInt(u48, @bitCast(hdr.dest_mac), .big); + + res += @sizeOf(u48); + try writer.writeInt(u48, @bitCast(hdr.src_mac), .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.ether_type, .big); + + return res; + } + + const byte_len = @bitSizeOf(EthernetHeaders) / 8; + + fn bytes(hdr: EthernetHeaders) [byte_len]u8 { + var res: [byte_len]u8 = undefined; + hdr.write(Writer.fixed(&res)) catch unreachable; + } +}; + +const IpHeaders = struct { + // ip_version: u4, + // header_length: u4 = 0, + type_of_service: u8 = 0, + total_length: u16 = 0x04, + + identification: u16 = 0, + // ethernet_flags: u3 = 0, + // fragment_offset: u13 = 0, + ttl: u8 = 0, + protocol: u8 = 0, + + header_checksum: @Vector(2, u8) = .{ 0, 0 }, + src_ip: @Vector(4, u8), + + dest_ip: @Vector(4, u8), + + fn write(hdr: @This(), writer: *std.Io.Writer) Writer.Error!usize { + comptime var res: usize = 0; + + res += @sizeOf(u8); + try writer.writeInt(u8, 0x45, .big); // ip version and header length + + res += @sizeOf(u8); + try writer.writeByte(hdr.type_of_service); + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.total_length, .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.identification, .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, 0x00, .big); // ethernet flags and fragment offset + + res += @sizeOf(u8); + try writer.writeByte(hdr.ttl); + + res += @sizeOf(u8); + try writer.writeByte(hdr.protocol); + + res += @sizeOf(u16); + try writer.writeInt(u16, @bitCast(hdr.header_checksum), .big); + + res += @sizeOf(u32); + try writer.writeInt(u32, @bitCast(hdr.src_ip), .big); + + res += @sizeOf(u32); + try writer.writeInt(u32, @bitCast(hdr.dest_ip), .big); + + return res; + } + + const byte_len = @bitSizeOf(IpHeaders) / 8; + + fn bytes(hdr: IpHeaders) [byte_len]u8 { + var res: [byte_len]u8 = undefined; + hdr.write(Writer.fixed(&res)) catch unreachable; + } +}; + +const UdpHeaders = packed struct { + src_port: u16, + + dest_port: u16, + length: u16, + checksum: @Vector(2, u8) = .{ 0, 0 }, + + fn write(hdr: @This(), writer: *std.Io.Writer) Writer.Error!usize { + comptime var res: usize = 0; + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.src_port, .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.dest_port, .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, hdr.length, .big); + + res += @sizeOf(u16); + try writer.writeInt(u16, @bitCast(hdr.checksum), .big); + + return res; + } + + const byte_len = @bitSizeOf(UdpHeaders) / 8; + + fn bytes(hdr: UdpHeaders) [byte_len]u8 { + var res: [byte_len]u8 = undefined; + hdr.write(Writer.fixed(&res)) catch unreachable; + } +}; + +const std = @import("std"); +const Random = std.Random; +const posix = std.posix; +const Writer = std.Io.Writer; +const mem = std.mem; + +const NetWriter = @This(); diff --git a/src/main.zig b/src/main.zig index 07a7a10..1266675 100644 --- a/src/main.zig +++ b/src/main.zig @@ -47,8 +47,9 @@ pub fn main() !void { } var sock_buffer: [2048]u8 = undefined; - var rawSocketWriter: RawSocketWriter = try .init("enp7s0", &sock_buffer); // /proc/net/dev - var client = try SaprusClient.init(&rawSocketWriter.interface); + var raw_socket_writer: RawSocketWriter = try .init("enp7s0", &sock_buffer); // /proc/net/dev + var net_writer: NetWriter = try .init(&raw_socket_writer.interface); + var client = try SaprusClient.init(&net_writer.interface); defer client.deinit(); if (res.args.relay) |r| { @@ -60,11 +61,14 @@ pub fn main() !void { // std.debug.print("Sent: {s}\n", .{r}); return; } else if (res.args.connect) |c| { - _ = client.connect(if (c.len > 0) c else "Hello darkness my old friend") catch |err| switch (err) { - error.WouldBlock => null, - else => return err, - }; - return; + if (false) { + _ = client.connect(if (c.len > 0) c else "Hello darkness my old friend") catch |err| switch (err) { + error.WouldBlock => null, + else => return err, + }; + return; + } + @panic("Not implemented"); } return clap.helpToFile(.stderr(), clap.Help, ¶ms, .{}); @@ -93,5 +97,6 @@ const zaprus = @import("zaprus"); const SaprusClient = zaprus.Client; const SaprusMessage = zaprus.Message; const RawSocketWriter = zaprus.RawSocketWriter; +const NetWriter = zaprus.NetWriter; const clap = @import("clap"); diff --git a/src/root.zig b/src/root.zig index b7c2795..bcf9415 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,6 +1,7 @@ pub const Client = @import("Client.zig"); pub const Connection = @import("Connection.zig"); pub const RawSocketWriter = @import("RawSocketWriter.zig"); +pub const NetWriter = @import("NetWriter.zig"); const msg = @import("message.zig"); -- cgit