diff options
| author | Robby Zambito <contact@robbyzambito.me> | 2025-05-10 11:44:57 -0400 | 
|---|---|---|
| committer | Robby Zambito <contact@robbyzambito.me> | 2025-05-10 21:46:53 -0400 | 
| commit | 245dab49098247289e09b03baf279e48c9340f48 (patch) | |
| tree | 599fb499aabd960217586409f2f51b8666e2b58f | |
| parent | cde5c3626cbcb5baa7b2ac9c815face628438dfc (diff) | |
Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
| -rw-r--r-- | src/Client.zig | 32 | ||||
| -rw-r--r-- | src/message.zig | 58 | 
2 files changed, 47 insertions, 43 deletions
| diff --git a/src/Client.zig b/src/Client.zig index 0a1cd99..2a7eaaf 100644 --- a/src/Client.zig +++ b/src/Client.zig @@ -48,12 +48,17 @@ fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void {  }  pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void { -    const msg: *SaprusMessage = try .init( -        allocator, -        .relay, -        @intCast(base64Enc.calcSize(payload.len)), +    const msg_bytes = try allocator.alignedAlloc( +        u8, +        @alignOf(SaprusMessage), +        try SaprusMessage.lengthForPayloadLength( +            .relay, +            base64Enc.calcSize(payload.len), +        ),      ); -    defer msg.deinit(allocator); +    defer allocator.free(msg_bytes); +    const msg: *SaprusMessage = .init(.relay, msg_bytes); +      const relay = (try msg.getSaprusTypePayload()).relay;      relay.dest = dest;      _ = base64Enc.encode(relay.getPayload(), payload); @@ -72,8 +77,14 @@ fn randomPort() u16 {  pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage {      const dest_port = randomPort(); -    const msg: *SaprusMessage = try .init(allocator, .connection, @intCast(payload.len)); -    defer msg.deinit(allocator); +    const msg_bytes = try allocator.alignedAlloc( +        u8, +        @alignOf(SaprusMessage), +        try SaprusMessage.lengthForPayloadLength(.connection, payload.len), +    ); +    defer allocator.free(msg_bytes); +    const msg: *SaprusMessage = .init(.connection, msg_bytes); +      const connection = (try msg.getSaprusTypePayload()).connection;      connection.src_port = initial_port;      connection.dest_port = dest_port; @@ -90,8 +101,7 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {          initial_port = r.intRangeAtMost(u16, 1024, 65000);      } else unreachable; -    var initial_conn_res: ?SaprusMessage = null; -    errdefer if (initial_conn_res) |*c| c.deinit(allocator); +    var initial_conn_res: ?*SaprusMessage = null;      var sock = try network.Socket.create(.ipv4, .udp);      defer sock.close(); @@ -112,8 +122,8 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {      _ = try sock.receive(&response_buf); // Ignore message that I sent.      const len = try sock.receive(&response_buf); -    std.debug.print("response bytes: {x}\n", .{response_buf}); -    initial_conn_res = (try SaprusMessage.bytesAsValue(response_buf[0..len])).*; +    std.debug.print("response bytes: {x}\n", .{response_buf[0..len]}); +    initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]);      // Complete handshake after awaiting response      try broadcastSaprusMessage(msg, randomPort()); diff --git a/src/message.zig b/src/message.zig index 3d69031..622a2d5 100644 --- a/src/message.zig +++ b/src/message.zig @@ -20,9 +20,11 @@ pub const ConnectionOptions = packed struct(u8) {      opt8: bool = false,  }; -pub const Error = error{ +pub const MessageTypeError = error{      NotImplementedSaprusType,      UnknownSaprusType, +}; +pub const MessageParseError = MessageTypeError || error{      InvalidMessage,  }; @@ -75,23 +77,25 @@ pub const Message = packed struct {      length: u16,      bytes: void = {}, -    pub fn init(allocator: Allocator, comptime @"type": PacketType, payload_len: u16) !*Self { -        const header_size = @sizeOf(switch (@"type") { -            .relay => Relay, -            .connection => Connection, -            .file_transfer => return Error.NotImplementedSaprusType, -            else => return Error.UnknownSaprusType, -        }); -        const size = payload_len + @sizeOf(Self) + header_size; -        const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size); +    /// Takes a byte slice, and returns a Message struct backed by the slice. +    /// This properly initializes the top level headers within the slice. +    pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self { +        std.debug.assert(bytes.len >= @sizeOf(Self));          const res: *Self = @ptrCast(bytes.ptr);          res.type = @"type"; -        res.length = payload_len + header_size; +        res.length = @intCast(bytes.len - @sizeOf(Self));          return res;      } -    pub fn deinit(self: *Self, allocator: Allocator) void { -        allocator.free(self.asBytes()); +    pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 { +        std.debug.assert(payload_len < std.math.maxInt(u16)); +        const header_size = @sizeOf(switch (@"type") { +            .relay => Relay, +            .connection => Connection, +            .file_transfer => return MessageTypeError.NotImplementedSaprusType, +            else => return MessageTypeError.UnknownSaprusType, +        }); +        return @intCast(payload_len + @sizeOf(Self) + header_size);      }      fn getRelay(self: *Self) *align(1) Relay { @@ -101,7 +105,7 @@ pub const Message = packed struct {          return std.mem.bytesAsValue(Connection, &self.bytes);      } -    pub fn getSaprusTypePayload(self: *Self) Error!(union(PacketType) { +    pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {          relay: *align(1) Relay,          file_transfer: void,          connection: *align(1) Connection, @@ -109,12 +113,12 @@ pub const Message = packed struct {          return switch (self.type) {              .relay => .{ .relay = self.getRelay() },              .connection => .{ .connection = self.getConnection() }, -            .file_transfer => Error.NotImplementedSaprusType, -            else => Error.UnknownSaprusType, +            .file_transfer => MessageTypeError.NotImplementedSaprusType, +            else => MessageTypeError.UnknownSaprusType,          };      } -    pub fn nativeFromNetworkEndian(self: *Self) Error!void { +    pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {          self.type = @enumFromInt(bigToNative(              @typeInfo(@TypeOf(self.type)).@"enum".tag_type,              @intFromEnum(self.type), @@ -137,12 +141,12 @@ pub const Message = packed struct {          }      } -    pub fn networkFromNativeEndian(self: *Self) Error!void { +    pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {          try switch (try self.getSaprusTypePayload()) {              .relay => {},              .connection => |*con| con.*.networkFromNativeEndian(), -            .file_transfer => Error.NotImplementedSaprusType, -            else => Error.UnknownSaprusType, +            .file_transfer => MessageTypeError.NotImplementedSaprusType, +            else => MessageTypeError.UnknownSaprusType,          };          self.type = @enumFromInt(nativeToBig(              @typeInfo(@TypeOf(self.type)).@"enum".tag_type, @@ -151,18 +155,8 @@ pub const Message = packed struct {          self.length = nativeToBig(@TypeOf(self.length), self.length);      } -    pub fn bytesAsValue(bytes: SelfBytes) !*Self { -        const res = std.mem.bytesAsValue(Self, bytes); -        return switch (res.type) { -            .relay, .connection => if (bytes.len == res.length + @sizeOf(Self)) -                res -            else -                Error.InvalidMessage, -            .file_transfer => Error.NotImplementedSaprusType, -            else => Error.UnknownSaprusType, -        }; -    } - +    /// Deprecated. +    /// If I need the bytes, I should just pass around the slice that is backing this to begin with.      pub fn asBytes(self: *Self) SelfBytes {          const size = @sizeOf(Self) + self.length;          return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size]; | 
