aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Client.zig17
-rw-r--r--src/Connection.zig49
-rw-r--r--src/RawSocket.zig25
-rw-r--r--src/c_api.zig2
-rw-r--r--src/main.zig5
-rw-r--r--src/message.zig32
-rw-r--r--src/root.zig1
7 files changed, 102 insertions, 29 deletions
diff --git a/src/Client.zig b/src/Client.zig
index a8170a5..1709cab 100644
--- a/src/Client.zig
+++ b/src/Client.zig
@@ -100,7 +100,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection {
var connection: SaprusMessage = .{
.connection = .{
.src = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
- .dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
+ .dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)), // Ignored, but good noise
.seq = undefined,
.id = undefined,
.payload = payload,
@@ -108,7 +108,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection {
};
log.debug("Setting bpf filter to port {}", .{connection.connection.src});
- self.socket.attachSaprusPortFilter(connection.connection.src) catch |err| {
+ self.socket.attachSaprusPortFilter(null, connection.connection.src) catch |err| {
log.err("Failed to set port filter: {t}", .{err});
return err;
};
@@ -131,7 +131,17 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection {
log.debug("Awaiting handshake response", .{});
// Ignore response from sentinel, just accept that we got one.
- _ = try self.socket.receive(&res_buf);
+ const full_handshake_res = try self.socket.receive(&res_buf);
+ const handshake_res = saprusParse(full_handshake_res[42..]) catch |err| {
+ log.err("Parse error: {t}", .{err});
+ return err;
+ };
+ self.socket.attachSaprusPortFilter(handshake_res.connection.src, handshake_res.connection.dest) catch |err| {
+ log.err("Failed to set port filter: {t}", .{err});
+ return err;
+ };
+ connection.connection.dest = handshake_res.connection.src;
+ connection_bytes = connection.toBytes(&connection_buf);
headers.udp.dst_port = udp_dest_port;
headers.ip.id = rand.int(u16);
@@ -153,6 +163,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection {
const RawSocket = @import("./RawSocket.zig");
const SaprusMessage = @import("message.zig").Message;
+const saprusParse = @import("message.zig").parse;
const SaprusConnection = @import("Connection.zig");
const EthIpUdp = @import("./EthIpUdp.zig").EthIpUdp;
diff --git a/src/Connection.zig b/src/Connection.zig
index 90109af..bb81c38 100644
--- a/src/Connection.zig
+++ b/src/Connection.zig
@@ -28,25 +28,50 @@ pub fn init(socket: RawSocket, headers: EthIpUdp, connection: SaprusMessage) Con
};
}
-pub fn next(self: Connection, io: Io, buf: []u8) ![]const u8 {
- _ = io;
- log.debug("Awaiting connection message", .{});
- const res = try self.socket.receive(buf);
- log.debug("Received {} byte connection message", .{res.len});
- const msg: SaprusMessage = try .parse(res[42..]);
- const connection_res = msg.connection;
-
- log.debug("Payload was {s}", .{connection_res.payload});
-
- return connection_res.payload;
+// 'p' as base64
+const pong = "cA==";
+
+pub fn next(self: *Connection, io: Io, buf: []u8) ![]const u8 {
+ while (true) {
+ log.debug("Awaiting connection message", .{});
+ const res = try self.socket.receive(buf);
+ log.debug("Received {} byte connection message", .{res.len});
+ const msg = SaprusMessage.parse(res[42..]) catch |err| {
+ log.err("Failed to parse next message: {t}\n{x}\n{x}", .{ err, res[0..], res[42..] });
+ return err;
+ };
+
+ switch (msg) {
+ .connection => |con_res| {
+ if (try con_res.management()) |mgt| {
+ log.debug("Received management message {t}", .{mgt});
+ switch (mgt) {
+ .ping => {
+ log.debug("Sending pong", .{});
+ try self.send(io, .{ .management = true }, pong);
+ log.debug("Sent pong message", .{});
+ },
+ else => |m| log.debug("Received management message that I don't know how to handle: {t}", .{m}),
+ }
+ } else {
+ log.debug("Payload was {s}", .{con_res.payload});
+ return con_res.payload;
+ }
+ },
+ else => |m| {
+ std.debug.panic("Expected connection message, instead got {x}. This means there is an error with the BPF.", .{@intFromEnum(m)});
+ },
+ }
+ }
}
-pub fn send(self: *Connection, io: Io, buf: []const u8) !void {
+pub fn send(self: *Connection, io: Io, options: SaprusMessage.Connection.Options, buf: []const u8) !void {
const io_source: std.Random.IoSource = .{ .io = io };
const rand = io_source.interface();
log.debug("Sending connection message", .{});
+ self.connection.connection.options = options;
self.connection.connection.payload = buf;
var connection_bytes_buf: [2048]u8 = undefined;
const connection_bytes = self.connection.toBytes(&connection_bytes_buf);
diff --git a/src/RawSocket.zig b/src/RawSocket.zig
index 5732ce9..e43a8e4 100644
--- a/src/RawSocket.zig
+++ b/src/RawSocket.zig
@@ -133,7 +133,7 @@ pub fn receive(self: RawSocket, buf: []u8) ![]u8 {
return buf[0..len];
}
-pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void {
+pub fn attachSaprusPortFilter(self: RawSocket, incoming_src_port: ?u16, incoming_dest_port: u16) !void {
const BPF = std.os.linux.BPF;
// BPF instruction structure for classic BPF
const SockFilter = extern struct {
@@ -149,11 +149,26 @@ pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void {
};
// Build the filter program
- const filter = [_]SockFilter{
+ const filter = if (incoming_src_port) |inc_src| &[_]SockFilter{
// Load 2 bytes at offset 46 (absolute)
.{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 46 },
+ // Jump if equal to port (skip 1 if true, skip 0 if false)
+ .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 1, .jf = 0, .k = @as(u32, inc_src) },
+ // Return 0x0 (fail)
+ .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0x0 },
+ // Load 2 bytes at offset 48 (absolute)
+ .{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 48 },
+ // Jump if equal to port (skip 0 if true, skip 1 if false)
+ .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, incoming_dest_port) },
+ // Return 0xffff (pass)
+ .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0xffff },
+ // Return 0x0 (fail)
+ .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0x0 },
+ } else &[_]SockFilter{
+ // Load 2 bytes at offset 48 (absolute)
+ .{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 48 },
// Jump if equal to port (skip 0 if true, skip 1 if false)
- .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, port) },
+ .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, incoming_dest_port) },
// Return 0xffff (pass)
.{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0xffff },
// Return 0x0 (fail)
@@ -161,8 +176,8 @@ pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void {
};
const fprog = SockFprog{
- .len = filter.len,
- .filter = &filter,
+ .len = @intCast(filter.len),
+ .filter = filter.ptr,
};
// Attach filter to socket using setsockopt
diff --git a/src/c_api.zig b/src/c_api.zig
index 7f10c45..c2f3190 100644
--- a/src/c_api.zig
+++ b/src/c_api.zig
@@ -99,6 +99,6 @@ export fn zaprus_connection_send(
const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection));
const zc = c orelse return 1;
- zc.send(io, payload[0..payload_len]) catch return 1;
+ zc.send(io, .{}, payload[0..payload_len]) catch return 1;
return 0;
}
diff --git a/src/main.zig b/src/main.zig
index 734357b..10dca33 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -191,6 +191,7 @@ pub fn main(init: std.process.Init) !void {
error.SymLinkLoop,
error.SystemResources,
=> blk: {
+ log.debug("Trying to execute command directly: {s}", .{connection_payload});
var argv_buf: [128][]const u8 = undefined;
var argv: ArrayList([]const u8) = .initBuffer(&argv_buf);
var payload_iter = std.mem.splitAny(u8, connection_payload, " \t\n");
@@ -229,7 +230,7 @@ pub fn main(init: std.process.Init) !void {
error.EndOfStream => {
cmd_output.print("{b64}", .{child_output_reader.interface.buffered()}) catch unreachable;
if (cmd_output.end > 0) {
- connection.send(init.io, cmd_output.buffered()) catch |e| {
+ connection.send(init.io, .{}, cmd_output.buffered()) catch |e| {
log.debug("Failed to send connection chunk: {t}", .{e});
continue :next_message;
};
@@ -238,7 +239,7 @@ pub fn main(init: std.process.Init) !void {
},
};
cmd_output.print("{b64}", .{try child_output_reader.interface.takeArray(child_output_buf.len)}) catch unreachable;
- connection.send(init.io, cmd_output.buffered()) catch |err| {
+ connection.send(init.io, .{}, cmd_output.buffered()) catch |err| {
log.debug("Failed to send connection chunk: {t}", .{err});
continue :next_message;
};
diff --git a/src/message.zig b/src/message.zig
index e8ef268..0c1410d 100644
--- a/src/message.zig
+++ b/src/message.zig
@@ -169,11 +169,11 @@ const Connection = struct {
seq: u32,
id: u32,
reserved: u8 = undefined,
- options: Options = undefined,
+ options: Options = .{},
payload: []const u8,
- /// Reserved option values.
- /// Currently unused.
+ /// Option values.
+ /// Currently used!
pub const Options = packed struct(u8) {
opt1: bool = false,
opt2: bool = false,
@@ -182,7 +182,7 @@ const Connection = struct {
opt5: bool = false,
opt6: bool = false,
opt7: bool = false,
- opt8: bool = false,
+ management: bool = false,
};
/// Asserts that buf is large enough to fit the connection message.
@@ -199,6 +199,28 @@ const Connection = struct {
out.writeAll(self.payload) catch unreachable;
return out.buffered();
}
+
+ /// If the current message is a management message, return what kind.
+ /// Else return null.
+ pub fn management(self: Connection) MessageParseError!?Management {
+ const b64_dec = std.base64.standard.Decoder;
+ if (self.options.management) {
+ var buf: [1]u8 = undefined;
+ _ = b64_dec.decode(&buf, self.payload) catch return error.InvalidMessage;
+
+ return switch (buf[0]) {
+ 'P' => .ping,
+ 'p' => .pong,
+ else => error.UnknownSaprusType,
+ };
+ }
+ return null;
+ }
+
+ pub const Management = enum {
+ ping,
+ pong,
+ };
};
test "Round trip" {
@@ -223,5 +245,5 @@ const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
test {
- std.testing.refAllDeclsRecursive(@This());
+ std.testing.refAllDecls(@This());
}
diff --git a/src/root.zig b/src/root.zig
index c469021..aa78565 100644
--- a/src/root.zig
+++ b/src/root.zig
@@ -19,7 +19,6 @@ pub const Connection = @import("Connection.zig");
const msg = @import("message.zig");
-pub const PacketType = msg.PacketType;
pub const MessageTypeError = msg.MessageTypeError;
pub const MessageParseError = msg.MessageParseError;
pub const Message = msg.Message;