diff options
Diffstat (limited to 'src/message.zig')
-rw-r--r-- | src/message.zig | 157 |
1 files changed, 118 insertions, 39 deletions
diff --git a/src/message.zig b/src/message.zig index 3d69031..8e43c3a 100644 --- a/src/message.zig +++ b/src/message.zig @@ -26,6 +26,124 @@ pub const Error = error{ InvalidMessage, }; +pub fn MessageNew(comptime packet_type: PacketType) type { + comptime { + if (packet_type == .file_transfer) + @compileError("File transfer not implemented"); + if (packet_type != .relay and packet_type != .connection) + @compileError("Unkown message type"); + } + + return packed struct { + const Self = @This(); + const SelfBytes = []align(@alignOf(Self)) u8; + + const Relay = struct { + pub fn getPayload(self: *Self) []u8 { + return @as([*]align(@alignOf(Self)) u8, @ptrCast(&self.payload))[0 .. self.length - 4]; + } + }; + const Connection = packed struct { + pub fn getPayload(self: Self) []u8 { + return @as([*]u8, &self.payload)[0 .. self.length - 4]; + } + }; + + type: PacketType = packet_type, + length: u16, + + // Relay + dest: if (packet_type == .relay) @Vector(4, u8) else void, + + // Connection + src_port: if (packet_type == .connection) u16 else void, // random number > 1024 + dest_port: if (packet_type == .connection) u16 else void, // random number > 1024 + seq_num: if (packet_type == .connection) u32 else void, + msg_id: if (packet_type == .connection) u32 else void, + reserved: if (packet_type == .connection) u8 else void, + options: if (packet_type == .connection) ConnectionOptions else void = if (packet_type == .connection) .{} else {}, + + // Relay or Connection + payload: switch (packet_type) { + .relay, .connection => void, + else => noreturn, + }, + + pub usingnamespace switch (packet_type) { + .relay => Relay, + .connection => Connection, + .file_transfer => @compileError("File Transfer message type not implemented"), + else => @compileError("Unknown message type"), + }; + + pub fn init(allocator: Allocator, payload_len: u16) !*Self { + const size = payload_len + @sizeOf(Self); + const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size); + const res: *Self = @ptrCast(bytes.ptr); + res.type = packet_type; + res.length = payload_len; + return res; + } + + pub fn deinit(self: *Self, allocator: Allocator) void { + allocator.free(self.asBytes()); + } + + pub fn nativeFromNetworkEndian(self: *Self) void { + self.type = @enumFromInt(bigToNative( + @typeInfo(@TypeOf(self.type)).@"enum".tag_type, + @intFromEnum(self.type), + )); + self.length = bigToNative(@TypeOf(self.length), self.length); + + if (packet_type == .connection) { + self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port); + self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port); + self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num); + self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id); + } + } + + pub fn networkFromNativeEndian(self: *Self) void { + self.type = @enumFromInt(bigToNative( + @typeInfo(@TypeOf(self.type)).@"enum".tag_type, + @intFromEnum(self.type), + )); + self.length = bigToNative(@TypeOf(self.length), self.length); + + if (packet_type == .connection) { + self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port); + self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port); + self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num); + self.msg_id = nativeToBig(@TypeOf(self.msg_id), self.msg_id); + } + } + + pub fn asBytes(self: *Self) SelfBytes { + const size = @sizeOf(Self) + self.length; + return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size]; + } + }; +} + +test MessageNew { + comptime for (@typeInfo(MessageNew(.connection)).@"struct".decls) |field| { + @compileLog(field); + }; +} + +// pub fn bytesAsMessage(bytes: []const u8) !*Self { +// const res = std.mem.bytesAsValue(Self, bytes); +// return switch (res.type) { +// .relay, .connection => if (bytes.len == res.length + @sizeOf(Self)) +// res +// else +// Error.InvalidMessage, +// .file_transfer => Error.NotImplementedSaprusType, +// else => Error.UnknownSaprusType, +// }; +// } + // ZERO COPY STUFF // &payload could be a void value that is treated as a pointer to a [*]u8 /// All Saprus messages @@ -216,45 +334,6 @@ const asBytes = std.mem.asBytes; const nativeToBig = std.mem.nativeToBig; const bigToNative = std.mem.bigToNative; -test "Round trip Relay toBytes and fromBytes" { - const gpa = std.testing.allocator; - const msg = Message{ - .relay = .{ - .header = .{ .dest = .{ 255, 255, 255, 255 } }, - .payload = "Hello darkness my old friend", - }, - }; - - const to_bytes = try msg.toBytes(gpa); - defer gpa.free(to_bytes); - - const from_bytes = try Message.fromBytes(to_bytes, gpa); - defer from_bytes.deinit(gpa); - - try std.testing.expectEqualDeep(msg, from_bytes); -} - -test "Round trip Connection toBytes and fromBytes" { - const gpa = std.testing.allocator; - const msg = Message{ - .connection = .{ - .header = .{ - .src_port = 0, - .dest_port = 0, - }, - .payload = "Hello darkness my old friend", - }, - }; - - const to_bytes = try msg.toBytes(gpa); - defer gpa.free(to_bytes); - - const from_bytes = try Message.fromBytes(to_bytes, gpa); - defer from_bytes.deinit(gpa); - - try std.testing.expectEqualDeep(msg, from_bytes); -} - test { std.testing.refAllDeclsRecursive(@This()); } |