diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Client.zig | 32 | ||||
-rw-r--r-- | src/message.zig | 58 |
2 files changed, 47 insertions, 43 deletions
diff --git a/src/Client.zig b/src/Client.zig index 0a1cd99..2a7eaaf 100644 --- a/src/Client.zig +++ b/src/Client.zig @@ -48,12 +48,17 @@ fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void { } pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void { - const msg: *SaprusMessage = try .init( - allocator, - .relay, - @intCast(base64Enc.calcSize(payload.len)), + const msg_bytes = try allocator.alignedAlloc( + u8, + @alignOf(SaprusMessage), + try SaprusMessage.lengthForPayloadLength( + .relay, + base64Enc.calcSize(payload.len), + ), ); - defer msg.deinit(allocator); + defer allocator.free(msg_bytes); + const msg: *SaprusMessage = .init(.relay, msg_bytes); + const relay = (try msg.getSaprusTypePayload()).relay; relay.dest = dest; _ = base64Enc.encode(relay.getPayload(), payload); @@ -72,8 +77,14 @@ fn randomPort() u16 { pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage { const dest_port = randomPort(); - const msg: *SaprusMessage = try .init(allocator, .connection, @intCast(payload.len)); - defer msg.deinit(allocator); + const msg_bytes = try allocator.alignedAlloc( + u8, + @alignOf(SaprusMessage), + try SaprusMessage.lengthForPayloadLength(.connection, payload.len), + ); + defer allocator.free(msg_bytes); + const msg: *SaprusMessage = .init(.connection, msg_bytes); + const connection = (try msg.getSaprusTypePayload()).connection; connection.src_port = initial_port; connection.dest_port = dest_port; @@ -90,8 +101,7 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection { initial_port = r.intRangeAtMost(u16, 1024, 65000); } else unreachable; - var initial_conn_res: ?SaprusMessage = null; - errdefer if (initial_conn_res) |*c| c.deinit(allocator); + var initial_conn_res: ?*SaprusMessage = null; var sock = try network.Socket.create(.ipv4, .udp); defer sock.close(); @@ -112,8 +122,8 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection { _ = try sock.receive(&response_buf); // Ignore message that I sent. const len = try sock.receive(&response_buf); - std.debug.print("response bytes: {x}\n", .{response_buf}); - initial_conn_res = (try SaprusMessage.bytesAsValue(response_buf[0..len])).*; + std.debug.print("response bytes: {x}\n", .{response_buf[0..len]}); + initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]); // Complete handshake after awaiting response try broadcastSaprusMessage(msg, randomPort()); diff --git a/src/message.zig b/src/message.zig index 3d69031..622a2d5 100644 --- a/src/message.zig +++ b/src/message.zig @@ -20,9 +20,11 @@ pub const ConnectionOptions = packed struct(u8) { opt8: bool = false, }; -pub const Error = error{ +pub const MessageTypeError = error{ NotImplementedSaprusType, UnknownSaprusType, +}; +pub const MessageParseError = MessageTypeError || error{ InvalidMessage, }; @@ -75,23 +77,25 @@ pub const Message = packed struct { length: u16, bytes: void = {}, - pub fn init(allocator: Allocator, comptime @"type": PacketType, payload_len: u16) !*Self { - const header_size = @sizeOf(switch (@"type") { - .relay => Relay, - .connection => Connection, - .file_transfer => return Error.NotImplementedSaprusType, - else => return Error.UnknownSaprusType, - }); - const size = payload_len + @sizeOf(Self) + header_size; - const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size); + /// Takes a byte slice, and returns a Message struct backed by the slice. + /// This properly initializes the top level headers within the slice. + pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self { + std.debug.assert(bytes.len >= @sizeOf(Self)); const res: *Self = @ptrCast(bytes.ptr); res.type = @"type"; - res.length = payload_len + header_size; + res.length = @intCast(bytes.len - @sizeOf(Self)); return res; } - pub fn deinit(self: *Self, allocator: Allocator) void { - allocator.free(self.asBytes()); + pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 { + std.debug.assert(payload_len < std.math.maxInt(u16)); + const header_size = @sizeOf(switch (@"type") { + .relay => Relay, + .connection => Connection, + .file_transfer => return MessageTypeError.NotImplementedSaprusType, + else => return MessageTypeError.UnknownSaprusType, + }); + return @intCast(payload_len + @sizeOf(Self) + header_size); } fn getRelay(self: *Self) *align(1) Relay { @@ -101,7 +105,7 @@ pub const Message = packed struct { return std.mem.bytesAsValue(Connection, &self.bytes); } - pub fn getSaprusTypePayload(self: *Self) Error!(union(PacketType) { + pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) { relay: *align(1) Relay, file_transfer: void, connection: *align(1) Connection, @@ -109,12 +113,12 @@ pub const Message = packed struct { return switch (self.type) { .relay => .{ .relay = self.getRelay() }, .connection => .{ .connection = self.getConnection() }, - .file_transfer => Error.NotImplementedSaprusType, - else => Error.UnknownSaprusType, + .file_transfer => MessageTypeError.NotImplementedSaprusType, + else => MessageTypeError.UnknownSaprusType, }; } - pub fn nativeFromNetworkEndian(self: *Self) Error!void { + pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void { self.type = @enumFromInt(bigToNative( @typeInfo(@TypeOf(self.type)).@"enum".tag_type, @intFromEnum(self.type), @@ -137,12 +141,12 @@ pub const Message = packed struct { } } - pub fn networkFromNativeEndian(self: *Self) Error!void { + pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void { try switch (try self.getSaprusTypePayload()) { .relay => {}, .connection => |*con| con.*.networkFromNativeEndian(), - .file_transfer => Error.NotImplementedSaprusType, - else => Error.UnknownSaprusType, + .file_transfer => MessageTypeError.NotImplementedSaprusType, + else => MessageTypeError.UnknownSaprusType, }; self.type = @enumFromInt(nativeToBig( @typeInfo(@TypeOf(self.type)).@"enum".tag_type, @@ -151,18 +155,8 @@ pub const Message = packed struct { self.length = nativeToBig(@TypeOf(self.length), self.length); } - pub fn bytesAsValue(bytes: SelfBytes) !*Self { - const res = std.mem.bytesAsValue(Self, bytes); - return switch (res.type) { - .relay, .connection => if (bytes.len == res.length + @sizeOf(Self)) - res - else - Error.InvalidMessage, - .file_transfer => Error.NotImplementedSaprusType, - else => Error.UnknownSaprusType, - }; - } - + /// Deprecated. + /// If I need the bytes, I should just pass around the slice that is backing this to begin with. pub fn asBytes(self: *Self) SelfBytes { const size = @sizeOf(Self) + self.length; return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size]; |