diff options
| author | Emīls <emils@mullvad.net> | 2020-09-02 12:23:28 +0100 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2020-09-02 12:23:28 +0100 |
| commit | 58ae8def1aba534b36dbb1e053debcb258c8d3e1 (patch) | |
| tree | d259b09a2c777ddd9e24c7db0722bd7979417d5d | |
| parent | 21a39a53d108c6e90fcb73b225ebeda8123f5b9b (diff) | |
| parent | 0dc0a6634adb50fd95ac06aaaa280a47c89754f4 (diff) | |
| download | mullvadvpn-58ae8def1aba534b36dbb1e053debcb258c8d3e1.tar.xz mullvadvpn-58ae8def1aba534b36dbb1e053debcb258c8d3e1.zip | |
Merge branch 'linux-use-wg-kernel-module'
22 files changed, 1835 insertions, 111 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index db17cb937d..6cd0464823 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,10 @@ Line wrap the file at 100 chars. Th ## [Unreleased] +### Added +#### Linux +- Add support for WireGuard's kernel module if it's loaded. + ### Fixed #### Windows diff --git a/Cargo.lock b/Cargo.lock index d530e53bd9..e32b8c32cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,6 +182,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "bytes" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "iovec 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "bytes" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1331,56 +1340,56 @@ dependencies = [ [[package]] name = "netlink-packet-core" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ + "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", - "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "netlink-packet-route" -version = "0.2.1" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ + "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)", "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", - "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "netlink-packet-utils" -version = "0.1.1" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ + "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", - "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)", + "thiserror 1.0.20 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "netlink-proto" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "bytes 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)", - "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)", "tokio-util 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "netlink-sys" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1436,6 +1445,17 @@ dependencies = [ ] [[package]] +name = "nix" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "cc 1.0.59 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "notify" version = "4.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2024,14 +2044,14 @@ dependencies = [ [[package]] name = "rtnetlink" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "thiserror 1.0.20 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -2398,6 +2418,7 @@ version = "0.1.0" dependencies = [ "async-stream 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", + "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", "chrono 0.4.15 (registry+https://github.com/rust-lang/crates.io-index)", "dbus 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2413,11 +2434,13 @@ dependencies = [ "libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)", "mnl 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "nftnl 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "nix 0.17.0 (registry+https://github.com/rust-lang/crates.io-index)", + "nix 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)", "notify 4.0.15 (registry+https://github.com/rust-lang/crates.io-index)", "openvpn-plugin 0.3.0 (git+https://github.com/mullvad/openvpn-plugin-rs?branch=auth-failed-event)", "os_pipe 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2431,13 +2454,14 @@ dependencies = [ "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", "resolv-conf 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)", - "rtnetlink 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rtnetlink 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "shell-escape 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", "socket2 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", "system-configuration 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "talpid-types 0.1.0", "tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)", + "tokio-io 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", "tonic 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "tonic-build 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "triggered 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2579,6 +2603,16 @@ dependencies = [ ] [[package]] +name = "tokio-io" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)", + "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "tokio-macros" version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3219,6 +3253,7 @@ dependencies = [ "checksum blake2b_simd 0.5.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d8fb2d74254a3a0b5cac33ac9f8ed0e44aa50378d9dbb2e5d83bd21ed1dc2c8a" "checksum bumpalo 3.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2e8c087f005730276d1096a652e92a8bacee2e2472bcc9715a74d2bec38b5820" "checksum byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de" +"checksum bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" "checksum bytes 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" "checksum cc 1.0.59 (registry+https://github.com/rust-lang/crates.io-index)" = "66120af515773fb005778dc07c261bd201ec8ce50bd6e7144c927753fe013381" "checksum cesu8 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" @@ -3327,15 +3362,16 @@ dependencies = [ "checksum multimap 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d8883adfde9756c1d30b0f519c9b8c502a94b41ac62f696453c37c7fc0a958ce" "checksum natord 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "308d96db8debc727c3fd9744aac51751243420e46edf401010908da7f8d5e57c" "checksum net2 0.2.34 (registry+https://github.com/rust-lang/crates.io-index)" = "2ba7c918ac76704fb42afcbbb43891e72731f3dcca3bef2a19786297baf14af7" -"checksum netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9cdae99aa0db00bffb58886de3cdba39d07164cec467867f162827872e3ed957" -"checksum netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "472467595d208fb94b0e8f70b1155c14b51eeb784713b4845344b501a8297a11" -"checksum netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "6785792d3020aad7c392caea9a9d687ac84625d262d11f8e3c1ee959272150dc" -"checksum netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "844a78a78bee85b99686973856e57ce339ef2490660305d26e35bb74a672ad15" -"checksum netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "aee128bb9bcc04f426d9b5e0bf3077726776b5b41770a3b2e4db5f52295625bf" +"checksum netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5fa0ae27e4832438fa054230e7075f69d0fa464dd335c2be3343cb481c0e8113" +"checksum netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c46727055595605d4f625e633e5e9bd7296e7c79ea701aaf2dd53b5062cd5aa3" +"checksum netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2ce628faa6689198d3db4f68e6165a5ba02a8e0a5fe741cca9c1b7856bab6a66" +"checksum netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fba96b2619706b19de0d6620c4479c0ba2c8d293164e55494f02a3bfdaffd36a" +"checksum netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d322b2ea918a35471492c831381c18bf8afaefeb2ecddc04ef2de6d166ee0fb8" "checksum nftnl 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b7528eff501558f9f892c5001e945b0d7e980cb464a7969101c94e18481c4563" "checksum nftnl-sys 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fe241d8ce673ef755c8d2b8717cd74990d4e0a61d437792054750ce9a35743d0" "checksum nix 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3b2e0b4f3320ed72aaedb9a5ac838690a8047c7b275da22711fddff4f8a14229" "checksum nix 0.17.0 (registry+https://github.com/rust-lang/crates.io-index)" = "50e4785f2c3b7589a0d0c1dd60285e1188adac4006e8abd6dd578e1567027363" +"checksum nix 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83450fe6a6142ddd95fb064b746083fc4ef1705fe81f64a64e1d4b39f54a1055" "checksum notify 4.0.15 (registry+https://github.com/rust-lang/crates.io-index)" = "80ae4a7688d1fab81c5bf19c64fc8db920be8d519ce6336ed4e7efe024724dbd" "checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b" "checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" @@ -3402,7 +3438,7 @@ dependencies = [ "checksum resolv-conf 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)" = "11834e137f3b14e309437a8276714eed3a80d1ef894869e510f2c0c0b98b9f4a" "checksum ring 0.16.15 (registry+https://github.com/rust-lang/crates.io-index)" = "952cd6b98c85bbc30efa1ba5783b8abf12fec8b3287ffa52605b9432313e34e4" "checksum rs-release 0.1.7 (git+https://github.com/mullvad/rs-release?branch=snailquote-unescape)" = "<none>" -"checksum rtnetlink 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f802e53265ca90edd3cfc59a3ceb2c30d655da6a1a6b954684d51977633ef5fd" +"checksum rtnetlink 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d3a3e80f0a3ac6877e56c430f41a581a1c8e95c2bd979a84f90270df9808feed" "checksum rust-argon2 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2bc8af4bda8e1ff4932523b94d3dd20ee30a87232323eda55903ffd71d2fb017" "checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" "checksum rustc-serialize 0.3.24 (registry+https://github.com/rust-lang/crates.io-index)" = "dcf128d1287d2ea9d80910b5f1120d0b8eede3fbf1abe91c40d39ea7d51e6fda" @@ -3454,6 +3490,7 @@ dependencies = [ "checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" "checksum time 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" "checksum tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)" = "5d34ca54d84bf2b5b4d7d31e901a8464f7b60ac145a284fba25ceb801f2ddccd" +"checksum tokio-io 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674" "checksum tokio-macros 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c3acc6aa564495a0f2e1d59fab677cd7f81a19994cfc7f3ad0e64301560389" "checksum tokio-rustls 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "15cb62a0d2770787abc96e99c1cd98fcf17f94959f3af63ca85bdfb203f051b4" "checksum tokio-util 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "571da51182ec208780505a32528fc5512a8fe1443ab960b3f2f3ef093cd16930" diff --git a/docs/security.md b/docs/security.md index 7fc5369bd3..786cf99d8b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -124,8 +124,11 @@ VPN tunnel is allowed on all interfaces, together with responses to this outgoin First hop means the bridge server if one is used, otherwise the VPN server directly. This IP+port+protocol combination should only be allowed for the process establishing the VPN tunnel, or only administrator level processes, depending on what the platform firewall -allows restricting. On Windows the rule only allows processes from binaries in certain paths. -On Linux and macOS the rule only allows packets from processes running as `root`. +allows restricting. On Windows the rule only allows processes from binaries in certain paths. macOS +the rule only allows packets from processes running as `root`. On Linux, the rule only allows +packets that have the mark `0x6d6f6c65` set: setting a firewall mark on traffic requires elevated +privileges when using tunnels that support marking traffic, otherwise the rule is the same as on +macOS: the packet needs to originate from a process running as `root`. This process/user check is important to not allow unprivileged programs to leak packets to this IP outside the tunnel, as those packets can be fingerprinted. diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 5222c5d26d..9828bbedf6 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -39,7 +39,8 @@ tonic = "0.3.1" prost = "0.6" [target.'cfg(unix)'.dependencies] -nix = "0.17" +nix = "0.18" +tokio-io = "0.1" [target.'cfg(target_os = "android")'.dependencies] @@ -52,10 +53,13 @@ failure = "0.1" notify = "4.0" resolv-conf = "0.6.1" async-stream = "0.2" -rtnetlink = "0.2" -netlink-packet-route = "0.2" -netlink-proto = "0.2" -netlink-sys = "0.2" +rtnetlink = "0.3" +netlink-packet-core = "0.2" +netlink-packet-utils = "0.2" +netlink-packet-route = "0.3" +netlink-proto = "0.4" +netlink-sys = "0.3" +byteorder = "1" futures = { package = "futures", version = "0.3" } nftnl = { version = "0.5", features = ["nftnl-1-1-0"] } mnl = { version = "0.2.0", features = ["mnl-1-0-4"] } diff --git a/talpid-core/src/dns/linux/network_manager.rs b/talpid-core/src/dns/linux/network_manager.rs index 0e517e07f1..a8c4efccd3 100644 --- a/talpid-core/src/dns/linux/network_manager.rs +++ b/talpid-core/src/dns/linux/network_manager.rs @@ -194,6 +194,12 @@ impl NetworkManager { .get(NM_DEVICE, "Ip6Config") .map_err(Error::Dbus)?; + let device_addresses6: Vec<(Vec<u8>, u32, Vec<u8>)> = self + .dbus_connection + .with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS) + .get(NM_IP6_CONFIG, "Addresses") + .map_err(Error::Dbus)?; + let device_routes6: Vec<(Vec<u8>, u32, Vec<u8>, u32)> = self .dbus_connection .with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS) @@ -209,6 +215,7 @@ impl NetworkManager { ipv6_settings.insert("route-metric", Variant(Box::new(0u32))); ipv6_settings.insert("routes", Variant(Box::new(device_routes6))); ipv6_settings.insert("route-data", Variant(Box::new(device_route6_data))); + ipv6_settings.insert("addresses", Variant(Box::new(device_addresses6))); } let mut settings_backup = @@ -248,6 +255,13 @@ impl NetworkManager { Self::update_dns_config(&mut settings, "ipv6", v6_dns); } + if let Some(wg_config) = settings.get_mut("wireguard") { + wg_config.insert( + "fwmark", + Variant(Box::new(crate::linux::TUNNEL_FW_MARK) as Box<dyn RefArg>), + ); + } + self.reapply_settings(&device, settings, version_id)?; self.device = Some(device); diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 3a68d3734d..f4a3e07378 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -445,9 +445,11 @@ impl<'a> PolicyBatch<'a> { peer_endpoint, pingable_hosts, allow_lan, + use_fwmark, } => { self.add_allow_icmp_pingable_hosts(&pingable_hosts); - self.add_allow_endpoint_rules(peer_endpoint); + self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); + // Important to block DNS after allow relay rule (so the relay can operate // over port 53) but before allow LAN (so DNS does not leak to the LAN) self.add_drop_dns_rule(); @@ -457,8 +459,9 @@ impl<'a> PolicyBatch<'a> { peer_endpoint, tunnel, allow_lan, + use_fwmark, } => { - self.add_allow_endpoint_rules(peer_endpoint); + self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); self.add_allow_dns_rules(tunnel, TransportProtocol::Udp)?; self.add_allow_dns_rules(tunnel, TransportProtocol::Tcp)?; // Important to block DNS *before* we allow the tunnel and allow LAN. So DNS @@ -492,7 +495,7 @@ impl<'a> PolicyBatch<'a> { Ok(()) } - fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) { + fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) { let mut in_rule = Rule::new(&self.in_chain); check_endpoint(&mut in_rule, End::Src, endpoint); @@ -504,11 +507,15 @@ impl<'a> PolicyBatch<'a> { self.batch.add(&in_rule, nftnl::MsgType::Add); - let mut out_rule = Rule::new(&self.out_chain); check_endpoint(&mut out_rule, End::Dst, endpoint); - out_rule.add_expr(&nft_expr!(meta skuid)); - out_rule.add_expr(&nft_expr!(cmp == 0u32)); + if use_fwmark { + out_rule.add_expr(&nft_expr!(meta mark)); + out_rule.add_expr(&nft_expr!(cmp == crate::linux::TUNNEL_FW_MARK)); + } else { + out_rule.add_expr(&nft_expr!(meta skuid)); + out_rule.add_expr(&nft_expr!(cmp == 0u32)); + } add_verdict(&mut out_rule, &Verdict::Accept); self.batch.add(&out_rule, nftnl::MsgType::Add); diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index 6b65dd5b70..4d8d3f0459 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -97,6 +97,10 @@ pub enum FirewallPolicy { /// A process that is allowed to send packets to the relay. #[cfg(windows)] relay_client: PathBuf, + /// Whether rule for allowing traffic to endpoint should match a firewall mark or match on + /// root UID. + #[cfg(target_os = "linux")] + use_fwmark: bool, }, /// Allow traffic only to server and over tunnel interface @@ -110,6 +114,10 @@ pub enum FirewallPolicy { /// A process that is allowed to send packets to the relay. #[cfg(windows)] relay_client: PathBuf, + /// Whether rule for allowing traffic to endpoint should match a firewall mark or match on + /// root UID. + #[cfg(target_os = "linux")] + use_fwmark: bool, }, /// Block all network traffic in and out from the computer. diff --git a/talpid-core/src/linux.rs b/talpid-core/src/linux.rs index 05655bf4be..47d0714813 100644 --- a/talpid-core/src/linux.rs +++ b/talpid-core/src/linux.rs @@ -25,3 +25,6 @@ pub enum IfaceIndexLookupError { #[error(display = "Failed to get index for interface {}", _0)] InterfaceLookupError(String, #[error(source)] io::Error), } + +// b"mole" is [ 0x6d, 0x6f 0x6c, 0x65 ] +pub const TUNNEL_FW_MARK: u32 = 0x6d6f6c65; diff --git a/talpid-core/src/process/openvpn.rs b/talpid-core/src/process/openvpn.rs index e07be4fa16..7922ba952c 100644 --- a/talpid-core/src/process/openvpn.rs +++ b/talpid-core/src/process/openvpn.rs @@ -249,6 +249,13 @@ impl OpenVpnCommand { args.extend(Self::tls_cipher_arguments().iter().map(OsString::from)); args.extend(self.proxy_arguments().iter().map(OsString::from)); + #[cfg(target_os = "linux")] + args.extend( + ["--mark", &crate::linux::TUNNEL_FW_MARK.to_string()] + .iter() + .map(OsString::from), + ); + args } diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 566e1c027b..c7be9c6bff 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -52,7 +52,7 @@ pub enum Error { BindError(#[error(source)] io::Error), #[error(display = "Netlink error")] - NetlinkError(#[error(source)] failure::Compat<rtnetlink::Error>), + NetlinkError(#[error(source)] rtnetlink::Error), #[error(display = "Route without a valid node")] InvalidRoute, @@ -379,7 +379,6 @@ impl RouteManagerImpl { while let Some(route) = route_request .try_next() .await - .map_err(failure::Fail::compat) .map_err(Error::NetlinkError)? { if route.header.destination_prefix_length == 0 { @@ -394,12 +393,7 @@ impl RouteManagerImpl { async fn initialize_link_map(handle: &rtnetlink::Handle) -> Result<BTreeMap<u32, String>> { let mut link_map = BTreeMap::new(); let mut link_request = handle.link().get().execute(); - while let Some(link) = link_request - .try_next() - .await - .map_err(failure::Fail::compat) - .map_err(Error::NetlinkError)? - { + while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? { if let Some((idx, link_name)) = Self::map_iface_name_to_idx(link) { link_map.insert(idx, link_name); } @@ -536,7 +530,7 @@ impl RouteManagerImpl { Route::new(best_node, required_route.destination).table(required_route.table_id); if let Err(e) = self.delete_route(&route).await { if let Error::NetlinkError(err) = &e { - if let rtnetlink::ErrorKind::NetlinkError(msg) = err.get_ref().kind() { + if let rtnetlink::Error::NetlinkError(msg) = err { // -3 means that the route doesn't exist anymore anyway if msg.code == -3 { continue; @@ -551,7 +545,7 @@ impl RouteManagerImpl { for route in self.added_routes.drain().collect::<Vec<_>>().iter() { if let Err(e) = self.delete_route(&route).await { if let Error::NetlinkError(err) = &e { - if let rtnetlink::ErrorKind::NetlinkError(msg) = err.get_ref().kind() { + if let rtnetlink::Error::NetlinkError(msg) = err { // -3 means that the route doesn't exist anymore anyway if msg.code == -3 { continue; @@ -785,7 +779,6 @@ impl RouteManagerImpl { .del(route_message) .execute() .await - .map_err(failure::Fail::compat) .map_err(Error::NetlinkError) } @@ -849,17 +842,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self - .handle - .request(req) - .map_err(failure::Fail::compat) - .map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(err) = message.payload { - let compat_err = - failure::Fail::compat(rtnetlink::ErrorKind::NetlinkError(err).into()); - return Err(Error::NetlinkError(compat_err)); + return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err))); } } self.added_routes.insert(route.clone()); diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index a59ee6ed27..503e69bc35 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -194,6 +194,12 @@ impl RouteManager { } } + /// Exposes runtime handle + #[cfg(target_os = "linux")] + pub fn runtime_handle(&self) -> tokio::runtime::Handle { + self.runtime.handle().clone() + } + /// Route DNS requests through the tunnel interface. #[cfg(target_os = "linux")] pub fn route_exclusions_dns( diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs index 4f0e851b1e..87c0d2c0f0 100644 --- a/talpid-core/src/tunnel/wireguard/config.rs +++ b/talpid-core/src/tunnel/wireguard/config.rs @@ -17,6 +17,9 @@ pub struct Config { pub ipv6_gateway: Option<Ipv6Addr>, /// Maximum transmission unit for the tunnel pub mtu: u16, + /// Firewall mark + #[cfg(target_os = "linux")] + pub fwmark: u32, } const DEFAULT_MTU: u16 = 1380; @@ -96,6 +99,8 @@ impl Config { ipv4_gateway: connection_config.ipv4_gateway, ipv6_gateway, mtu, + #[cfg(target_os = "linux")] + fwmark: crate::linux::TUNNEL_FW_MARK, }) } @@ -108,6 +113,9 @@ impl Config { .add("private_key", self.tunnel.private_key.to_bytes().as_ref()) .add("listen_port", "0"); + #[cfg(target_os = "linux")] + wg_conf.add("fwmark", self.fwmark.to_string().as_str()); + wg_conf.add("replace_peers", "true"); for peer in &self.peers { diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index 111edd772a..8a604fcf6f 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -514,8 +514,8 @@ mod test { } impl Tunnel for MockTunnel { - fn get_interface_name(&self) -> &str { - "mock-tunnel" + fn get_interface_name(&self) -> String { + "mock-tunnel".to_string() } fn stop(self: Box<Self>) -> Result<(), TunnelError> { diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 6324d13e80..987274acb5 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -16,6 +16,8 @@ mod connectivity_check; mod logging; mod stats; mod wireguard_go; +#[cfg(target_os = "linux")] +mod wireguard_kernel; use self::wireguard_go::WgGoTunnel; @@ -62,12 +64,7 @@ impl WireguardMonitor { tun_provider: &mut TunProvider, route_manager: &mut routing::RouteManager, ) -> Result<WireguardMonitor> { - let tunnel = Box::new(WgGoTunnel::start_tunnel( - &config, - log_path, - tun_provider, - Self::get_tunnel_routes(config), - )?); + let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?; let iface_name = tunnel.get_interface_name().to_string(); route_manager .add_routes(Self::get_routes(&iface_name, &config)) @@ -93,7 +90,7 @@ impl WireguardMonitor { let close_sender = monitor.close_msg_sender.clone(); let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( gateway, - iface_name, + iface_name.to_string(), Arc::downgrade(&monitor.tunnel), pinger_rx, )?; @@ -125,6 +122,34 @@ impl WireguardMonitor { Ok(monitor) } + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] + fn open_tunnel( + config: &Config, + log_path: Option<&Path>, + tun_provider: &mut TunProvider, + route_manager: &mut routing::RouteManager, + ) -> Result<Box<dyn Tunnel>> { + #[cfg(target_os = "linux")] + match wireguard_kernel::KernelTunnel::new(route_manager.runtime_handle(), config) { + Ok(tunnel) => { + return Ok(Box::new(tunnel)); + } + Err(err) => { + log::error!( + "Failed to setup kernel WireGuard device, falling back to userspace: {}", + err + ); + } + }; + + Ok(Box::new(WgGoTunnel::start_tunnel( + &config, + log_path, + tun_provider, + Self::get_tunnel_routes(config), + )?)) + } + /// Returns a close handle for the tunnel pub fn close_handle(&self) -> CloseHandle { CloseHandle { @@ -228,7 +253,7 @@ impl CloseHandle { } pub(crate) trait Tunnel: Send { - fn get_interface_name(&self) -> &str; + fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::Stats, TunnelError>; } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 761276dcdf..b4d437f046 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -283,8 +283,8 @@ impl Drop for WgGoTunnel { } impl Tunnel for WgGoTunnel { - fn get_interface_name(&self) -> &str { - &self.interface_name + fn get_interface_name(&self) -> String { + self.interface_name.clone() } fn get_tunnel_stats(&self) -> Result<Stats> { diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs new file mode 100644 index 0000000000..b1fa3a0f65 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs @@ -0,0 +1,479 @@ +use super::{stats::Stats, Config, Tunnel, TunnelError}; +use futures::future::{abortable, AbortHandle}; +use netlink_packet_core::{constants::*, NetlinkDeserializable}; +use netlink_packet_route::{ + rtnl::{ + address::nlas::Nla as AddressNla, + link::nlas::{Info, InfoKind, Nla as LinkNla}, + AddressMessage, LinkMessage, RtnlMessage, RT_SCOPE_UNIVERSE, + }, + NetlinkMessage, NetlinkPayload, +}; +use netlink_packet_utils::DecodeError; +use netlink_proto::{ + sys::{Protocol, SocketAddr}, + ConnectionHandle, Error as NetlinkError, +}; +use std::{ffi::CString, net::IpAddr}; +use tokio::stream::StreamExt; + +mod parsers; + +mod wg_message; +use wg_message::{DeviceMessage, DeviceNla, PeerNla}; +mod nl_message; +use nl_message::{ControlNla, NetlinkControlMessage}; + + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + #[error(display = "Failed to decode netlink message")] + DecodeError(#[error(source)] DecodeError), + + #[error(display = "Failed to execute netlink control request")] + NetlinkControlMessageError(#[error(source)] nl_message::Error), + + #[error(display = "Failed to open netlink socket")] + NetlinkSocketError(#[error(source)] std::io::Error), + + #[error(display = "Failed to send netlink control request")] + NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), + + #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")] + WireguardNetlinkInterfaceUnavailable, + + #[error(display = "Unknown WireGuard command _0")] + UnnkownWireguardCommmand(u8), + + #[error(display = "Received no response")] + NoResponse, + + #[error(display = "Received truncated message")] + Truncated, + + #[error(display = "WireGuard device does not exist")] + NoDevice, + + #[error(display = "Failed to get config: _0")] + WgGetConfError(netlink_packet_core::error::ErrorMessage), + + #[error(display = "Failed to apply config: _0")] + WgSetConfError(netlink_packet_core::error::ErrorMessage), + + #[error(display = "Interface name too long")] + InterfaceNameError, + + #[error(display = "Send request error")] + SendRequestError(#[error(source)] NetlinkError<DeviceMessage>), + + #[error(display = "Create device error")] + NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error), + + #[error(display = "Add IP to device error")] + NetlinkSetIpError(rtnetlink::Error), + + #[error(display = "Failed to delete device")] + DeleteDeviceError(#[error(source)] rtnetlink::Error), +} + +pub struct KernelTunnel { + interface_index: u32, + netlink_connections: Handle, + tokio_handle: tokio::runtime::Handle, +} + +const MULLVAD_INTERFACE_NAME: &str = "wg-mullvad"; + +impl KernelTunnel { + pub fn new(tokio_handle: tokio::runtime::Handle, config: &Config) -> Result<Self, Error> { + tokio_handle.clone().block_on(async { + let mut netlink_connections = Handle::connect().await?; + let interface_index = netlink_connections + .create_device(MULLVAD_INTERFACE_NAME.to_string(), config.mtu as u32) + .await?; + + let mut tunnel = Self { + interface_index, + netlink_connections, + tokio_handle, + }; + + if let Err(err) = tunnel.setup(config).await { + if let Err(teardown_err) = tunnel + .netlink_connections + .delete_device(interface_index) + .await + { + log::error!( + "Failed to tear down WireGuard interface after failing to apply config: {}", + teardown_err + ); + } + return Err(err); + } + + + Ok(tunnel) + }) + } + + async fn setup(&mut self, config: &Config) -> Result<(), Error> { + self.netlink_connections + .wg_handle + .set_config(self.interface_index, config) + .await?; + + for tunnel_ip in config.tunnel.addresses.iter() { + self.netlink_connections + .set_ip_address(self.interface_index, *tunnel_ip) + .await?; + } + + Ok(()) + } +} + +impl Tunnel for KernelTunnel { + fn get_interface_name(&self) -> String { + let mut wg = self.netlink_connections.wg_handle.clone(); + let result = self.tokio_handle.block_on(async move { + let device = wg.get_by_index(self.interface_index).await?; + for nla in device.nlas { + if let DeviceNla::IfName(name) = nla { + return Ok(name); + } + } + return Err(Error::Truncated); + }); + + match result { + Ok(name) => name.to_string_lossy().to_string(), + Err(err) => { + log::error!("Failed to deduce interface name at runtime, will attempt to use the default name. {}", err); + MULLVAD_INTERFACE_NAME.to_string() + } + } + } + + fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError> { + let Self { + mut netlink_connections, + interface_index, + tokio_handle, + } = *self; + tokio_handle.block_on(async move { + if let Err(err) = netlink_connections.delete_device(interface_index).await { + log::error!("Failed to remove WireGuard device - {}", err); + Err(TunnelError::FatalStartWireguardError) + } else { + Ok(()) + } + }) + } + + fn get_tunnel_stats(&self) -> std::result::Result<Stats, TunnelError> { + let mut wg = self.netlink_connections.wg_handle.clone(); + let interface_index = self.interface_index; + let result = self.tokio_handle.block_on(async move { + let device = wg.get_by_index(interface_index).await.map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::GetConfigError + })?; + + // iterate over device attributes + let mut tx_bytes = 0; + let mut rx_bytes = 0; + for nla in device.nlas { + if let DeviceNla::Peers(peers) = nla { + // iterate over all peer attributes + let peer_iter = peers.iter().map(|peer| peer.0.as_slice()).flatten(); + + for peer_nla in peer_iter { + match peer_nla { + PeerNla::TxBytes(bytes) => tx_bytes += *bytes, + PeerNla::RxBytes(bytes) => rx_bytes += *bytes, + _ => continue, + }; + } + } + } + + Ok(Stats { tx_bytes, rx_bytes }) + }); + + result + } +} + + +#[derive(Debug)] +pub struct Handle { + wg_handle: WireguardConnection, + route_handle: rtnetlink::Handle, + wg_abort_handle: AbortHandle, + route_abort_handle: AbortHandle, + message_type: u16, +} + + +impl Handle { + pub async fn connect() -> Result<Self, Error> { + let message_type = Self::get_wireguard_message_type().await?; + let (conn, wireguard_connection, _messages) = + netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?; + let wg_handle = WireguardConnection { + message_type, + connection: wireguard_connection, + }; + let (abortable_connection, wg_abort_handle) = abortable(conn); + tokio::spawn(abortable_connection); + let (conn, route_handle, _messages) = + rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?; + let (abortable_connection, route_abort_handle) = abortable(conn); + tokio::spawn(abortable_connection); + + + Ok(Self { + wg_handle, + route_handle, + message_type, + wg_abort_handle, + route_abort_handle, + }) + } + + async fn get_wireguard_message_type() -> Result<u16, Error> { + let (conn, mut handle, _messages) = + netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?; + let (conn, abort_handle) = abortable(conn); + tokio::spawn(conn); + + let result = async move { + let mut message: NetlinkMessage<NetlinkControlMessage> = + NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap()) + .map_err(Error::NetlinkControlMessageError)? + .into(); + + message.header.flags = NLM_F_REQUEST | NLM_F_ACK; + + let mut req = handle + .request(message, SocketAddr::new(0, 0)) + .map_err(Error::NetlinkRequestError)?; + let response = req.next().await; + if let Some(response) = response { + if let NetlinkPayload::InnerMessage(msg) = response.payload { + for nla in msg.nlas.into_iter() { + if let ControlNla::FamilyId(id) = nla { + return Ok(id); + } + } + } + } + Err(Error::WireguardNetlinkInterfaceUnavailable) + } + .await; + + abort_handle.abort(); + result + } + + // create a wireguard device with the given name. + pub async fn create_device(&mut self, name: String, mtu: u32) -> Result<u32, Error> { + let mut message = LinkMessage::default(); + + // set link to be up + message.header.flags = netlink_packet_route::IFF_UP; + // message.header.change_mask = netlink_packet_route::IFF_UP; + // set link name + message.nlas.push(LinkNla::IfName(name.clone())); + // set link MTU + message.nlas.push(LinkNla::Mtu(mtu)); + // set link type + message + .nlas + .push(LinkNla::Info(vec![Info::Kind(InfoKind::Other( + "wireguard".to_string(), + ))])); + + let mut add_request = NetlinkMessage::from(RtnlMessage::NewLink(message)); + add_request.header.flags = + NLM_F_REQUEST | NLM_F_ACK | NLM_F_REPLACE | NLM_F_CREATE | NLM_F_MATCH; + let mut response = self + .route_handle + .request(add_request) + .map_err(Error::NetlinkCreateDeviceError)?; + while let Some(response_message) = response.next().await { + if let NetlinkPayload::Error(err) = response_message.payload { + // if the device exists, verify that it's a wireguard device + if -err.code != libc::EEXIST { + return Err(Error::NetlinkCreateDeviceError( + rtnetlink::Error::NetlinkError(err), + )); + } + } + } + + // fetch interface index of new device + let new_device = self.wg_handle.get_by_name(name).await?; + for nla in new_device.nlas { + if let DeviceNla::IfIndex(index) = nla { + return Ok(index); + } + } + + + Err(Error::NoDevice) + } + + pub async fn set_ip_address(&mut self, index: u32, addr: IpAddr) -> Result<(), Error> { + let address_message = add_ip_addr_message(index, addr); + let mut request = NetlinkMessage::from(RtnlMessage::NewAddress(address_message)); + request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; + + + let mut response = self + .route_handle + .request(request) + .map_err(Error::NetlinkSetIpError)?; + while let Some(response_message) = response.next().await { + consume_netlink_error(response_message, Error::NetlinkSetIpError)?; + } + + Ok(()) + } + + pub async fn delete_device(&mut self, index: u32) -> Result<(), Error> { + let mut link_message = LinkMessage::default(); + link_message.header.index = index; + + let mut request = NetlinkMessage::from(RtnlMessage::DelLink(link_message)); + request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + + let mut response = self + .route_handle + .request(request) + .map_err(Error::DeleteDeviceError)?; + while let Some(message) = response.next().await { + consume_netlink_error(message, Error::DeleteDeviceError)?; + } + + Ok(()) + } +} + +impl Drop for Handle { + fn drop(&mut self) { + self.wg_abort_handle.abort(); + self.route_abort_handle.abort(); + } +} + +#[derive(Debug, Clone)] +struct WireguardConnection { + connection: ConnectionHandle<DeviceMessage>, + message_type: u16, +} + +impl WireguardConnection { + pub async fn get_by_name(&mut self, name: String) -> Result<DeviceMessage, Error> { + self.fetch_device(DeviceMessage::get_by_name(self.message_type, name)?) + .await + } + + pub async fn get_by_index(&mut self, index: u32) -> Result<DeviceMessage, Error> { + self.fetch_device(DeviceMessage::get_by_index(self.message_type, index)) + .await + } + + pub async fn fetch_device( + &mut self, + device_message: DeviceMessage, + ) -> Result<DeviceMessage, Error> { + let mut netlink_message = NetlinkMessage::from(device_message); + netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; + + let mut response = self + .connection + .request(netlink_message, SocketAddr::new(0, 0)) + .map_err(Error::SendRequestError)?; + match response.next().await { + Some(received_message) => match received_message.payload { + NetlinkPayload::InnerMessage(inner) => Ok(inner), + NetlinkPayload::Error(err) => { + if err.code == -libc::ENODEV { + Err(Error::NoDevice) + } else { + Err(Error::WgGetConfError(err)) + } + } + anything_else => { + log::error!("Received unexpected response - {:?}", anything_else); + Err(Error::NoResponse) + } + }, + None => Err(Error::NoResponse), + } + } + + pub async fn set_config(&mut self, interface_index: u32, config: &Config) -> Result<(), Error> { + let message = DeviceMessage::reset_config(self.message_type, interface_index, config); + let mut netlink_message = NetlinkMessage::from(message); + netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK; + + let mut request = self + .connection + .request(netlink_message, SocketAddr::new(0, 0)) + .map_err(Error::SendRequestError)?; + + while let Some(response) = request.next().await { + if let NetlinkPayload::Error(err) = response.payload { + return Err(Error::WgSetConfError(err)); + } + } + Ok(()) + } +} + + +fn consume_netlink_error< + T, + I: NetlinkDeserializable<T> + Clone + Eq + std::fmt::Debug, + F: Fn(rtnetlink::Error) -> Error, +>( + message: NetlinkMessage<I>, + err_constructor: F, +) -> Result<(), Error> { + if let NetlinkPayload::Error(err) = message.payload { + return Err(err_constructor(rtnetlink::Error::NetlinkError(err))); + } + Ok(()) +} + +// the built-in support for adding addresses is too helpful, so a simple AddressMessage with a +// single Address nla is created +fn add_ip_addr_message(if_index: u32, addr: IpAddr) -> AddressMessage { + let prefix_len = if addr.is_ipv4() { 32 } else { 128 }; + let mut message = AddressMessage::default(); + message.header.prefix_len = prefix_len; + message.header.index = if_index; + message.header.scope = RT_SCOPE_UNIVERSE; + + match addr { + IpAddr::V4(ipv4) => { + message.header.family = libc::AF_INET as u8; + let ip_bytes = ipv4.octets().to_vec(); + + message.nlas.push(AddressNla::Address(ip_bytes.clone())); + message.nlas.push(AddressNla::Local(ip_bytes)); + } + IpAddr::V6(ipv6) => { + message.header.family = libc::AF_INET6 as u8; + message + .nlas + .push(AddressNla::Address(ipv6.octets().to_vec())); + } + }; + + message +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs new file mode 100644 index 0000000000..7fc8be8304 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs @@ -0,0 +1,135 @@ +use super::parsers; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_core::{ + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, +}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator}, + traits::{Emitable, Parseable}, + DecodeError, +}; +use std::{ffi::CString, io::Write, mem}; + + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Family name too long")] + FamilyNameTooLong, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct NetlinkControlMessage { + cmd: u8, + version: u8, + pub nlas: Vec<ControlNla>, +} + +impl NetlinkControlMessage { + pub fn get_netlink_family_id(name: CString) -> Result<Self, Error> { + if name.as_bytes_with_nul().len() > (libc::GENL_NAMSIZ as usize) { + return Err(Error::FamilyNameTooLong); + } + Ok(Self { + nlas: vec![ControlNla::FamilyName(name)], + cmd: libc::CTRL_CMD_GETFAMILY as u8, + version: 1, + }) + } +} + + +impl NetlinkSerializable<NetlinkControlMessage> for NetlinkControlMessage { + fn message_type(&self) -> u16 { + libc::GENL_ID_CTRL as u16 + } + + fn buffer_len(&self) -> usize { + mem::size_of::<libc::genlmsghdr>() + self.nlas.as_slice().buffer_len() + } + + fn serialize(&self, mut buffer: &mut [u8]) { + let _ = buffer.write(&[self.cmd, self.version, 0u8, 0u8]).unwrap(); + self.nlas.as_slice().emit(&mut buffer); + } +} + +impl Into<NetlinkPayload<NetlinkControlMessage>> for NetlinkControlMessage { + fn into(self) -> NetlinkPayload<NetlinkControlMessage> { + NetlinkPayload::InnerMessage(self) + } +} + +impl NetlinkDeserializable<NetlinkControlMessage> for NetlinkControlMessage { + type Error = DecodeError; + fn deserialize( + _header: &NetlinkHeader, + payload: &[u8], + ) -> Result<NetlinkControlMessage, Self::Error> { + // skip the genlmsghdr + let (cmd, version) = parsers::parse_genlmsghdr(payload)?; + let nla_buffer = &payload[mem::size_of::<libc::genlmsghdr>()..]; + let nlas = NlasIterator::new(nla_buffer) + .map(|buffer| ControlNla::parse(&buffer?)) + .collect::<Result<Vec<_>, DecodeError>>()?; + + Ok(NetlinkControlMessage { nlas, cmd, version }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ControlNla { + FamilyName(CString), + FamilyId(u16), + Unknown(u16, Vec<u8>), +} + +impl Nla for ControlNla { + fn value_len(&self) -> usize { + use ControlNla::*; + match self { + FamilyName(name) => name.as_bytes_with_nul().len(), + FamilyId(_id) => 2, + Unknown(_, buffer) => buffer.len(), + } + } + + fn kind(&self) -> u16 { + use ControlNla::*; + match self { + FamilyName(_) => libc::CTRL_ATTR_FAMILY_NAME as u16, + FamilyId(_) => libc::CTRL_ATTR_FAMILY_ID as u16, + Unknown(kind, _) => *kind, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use ControlNla::*; + match self { + FamilyName(name) => { + let _ = buffer.write(name.as_bytes()).unwrap(); + } + FamilyId(id) => { + NativeEndian::write_u16(buffer, *id); + } + + Unknown(_, value) => { + let _ = buffer.write(value).unwrap(); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized + std::fmt::Debug> Parseable<NlaBuffer<&'a T>> + for ControlNla +{ + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + let nla = match buf.kind() as i32 { + libc::CTRL_ATTR_FAMILY_NAME => { + ControlNla::FamilyName(parsers::parse_cstring(buf.value())?) + } + libc::CTRL_ATTR_FAMILY_ID => ControlNla::FamilyId(parsers::parse_u16(buf.value())?), + _unknown_kind => ControlNla::Unknown(buf.kind(), buf.value().to_vec()), + }; + Ok(nla) + } +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs new file mode 100644 index 0000000000..b34c82d342 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs @@ -0,0 +1,99 @@ +use byteorder::{ByteOrder, NativeEndian}; +use nix::sys::{socket::InetAddr, time::TimeSpec}; +use std::{ + ffi::{CStr, CString}, + mem, + net::IpAddr, +}; + +pub use netlink_packet_utils::parsers::*; +use netlink_packet_utils::DecodeError; + +pub fn parse_ip_addr(bytes: &[u8]) -> Result<IpAddr, DecodeError> { + if bytes.len() == 4 { + let mut ipv4_bytes = [0u8; 4]; + ipv4_bytes.copy_from_slice(bytes); + Ok(IpAddr::from(ipv4_bytes)) + } else if bytes.len() == 16 { + let mut ipv6_bytes = [0u8; 16]; + ipv6_bytes.copy_from_slice(bytes); + Ok(IpAddr::from(ipv6_bytes)) + } else { + log::error!("Expected either 4 or 16 bytes, got {} bytes", bytes.len()); + Err(format!("Invalid bytes for IP address: {:?}", bytes).into()) + } +} + +pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> { + match buffer.len() { + 32 => { + let mut key = [0u8; 32]; + key.clone_from_slice(buffer); + Ok(key) + } + anything_else => Err(format!("Unexpected length of key: {}", anything_else).into()), + } +} + +pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> { + if buffer.len() != mem::size_of::<libc::sockaddr_in6>() + && buffer.len() != mem::size_of::<libc::sockaddr_in>() + { + return Err(format!( + "Unexpected length for sockaddr_in: {}, expected {} or {}", + buffer.len(), + mem::size_of::<libc::sockaddr_in6>(), + mem::size_of::<libc::sockaddr_in>() + ) + .into()); + } + let ptr = buffer.as_ptr(); + const AF_INET: u16 = libc::AF_INET as u16; + const AF_INET6: u16 = libc::AF_INET6 as u16; + + match NativeEndian::read_u16(buffer) { + AF_INET => unsafe { + let sockaddr: *const libc::sockaddr_in = ptr as *const _; + Ok(InetAddr::V4(*sockaddr).into()) + }, + AF_INET6 => unsafe { + let sockaddr: *const libc::sockaddr_in6 = ptr as *const _; + Ok(InetAddr::V6(*sockaddr)) + }, + unexpected_addr_family => { + Err(format!("Unexpected address family: {}", unexpected_addr_family).into()) + } + } +} + +pub fn parse_timespec(buffer: &[u8]) -> Result<TimeSpec, DecodeError> { + if buffer.len() != mem::size_of::<libc::timespec>() { + return Err(format!("Unexpected size for timespec: {}", buffer.len()).into()); + } + + Ok(TimeSpec::from(libc::timespec { + tv_sec: NativeEndian::read_i64(buffer), + // TODO: become compatible with 32-bit systems maybe? + tv_nsec: NativeEndian::read_i64(buffer), + })) +} + +pub fn parse_cstring(buffer: &[u8]) -> Result<CString, DecodeError> { + Ok(CStr::from_bytes_with_nul(buffer) + .map_err(|err| format!("{}", err))? + .into()) +} + +pub fn parse_genlmsghdr(buffer: &[u8]) -> Result<(u8, u8), DecodeError> { + const GENLMSGHDR_SIZE: usize = mem::size_of::<libc::genlmsghdr>(); + if buffer.len() < GENLMSGHDR_SIZE { + return Err(format!( + "Expected at least {}, got {}", + GENLMSGHDR_SIZE, + buffer.len() + ) + .into()); + } + + Ok((buffer[0], buffer[1])) +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs new file mode 100644 index 0000000000..1e587712cc --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -0,0 +1,899 @@ +use super::{super::config::Config, parsers, Error}; +use byteorder::{ByteOrder, NativeEndian}; +use ipnetwork::IpNetwork; +use netlink_packet_core::{ + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, +}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator, NLA_F_NESTED}, + traits::{Emitable, Parseable}, + DecodeError, +}; +use nix::sys::{socket::InetAddr, time::TimeSpec}; +use std::{ffi::CString, io::Write, mem, net::IpAddr}; + +/// WireGuard netlink constants +mod constants { + #![allow(dead_code)] + pub const WG_GENL_VERSION: u8 = 1; + + /// Command constants + pub const WG_CMD_GET_DEVICE: u8 = 0; + pub const WG_CMD_SET_DEVICE: u8 = 1; + + // wgdevice_flag + pub const WGDEVICE_F_REPLACE_PEERS: u32 = 1 << 0; + + // wgdevice_attribute + pub const WGDEVICE_A_UNSPEC: u16 = 0; + pub const WGDEVICE_A_IFINDEX: u16 = 1; + pub const WGDEVICE_A_IFNAME: u16 = 2; + pub const WGDEVICE_A_PRIVATE_KEY: u16 = 3; + pub const WGDEVICE_A_PUBLIC_KEY: u16 = 4; + pub const WGDEVICE_A_FLAGS: u16 = 5; + pub const WGDEVICE_A_LISTEN_PORT: u16 = 6; + pub const WGDEVICE_A_FWMARK: u16 = 7; + pub const WGDEVICE_A_PEERS: u16 = 8; + + // wgpeer_flag + pub const WGPEER_F_REMOVE_ME: u32 = 1 << 0; + pub const WGPEER_F_REPLACE_ALLOWEDIPS: u32 = 1 << 1; + pub const WGPEER_F_UPDATE_ONLY: u32 = 1 << 2; + + // wgpeer_attribute + pub const WGPEER_A_UNSPEC: u16 = 0; + pub const WGPEER_A_PUBLIC_KEY: u16 = 1; + pub const WGPEER_A_PRESHARED_KEY: u16 = 2; + pub const WGPEER_A_FLAGS: u16 = 3; + pub const WGPEER_A_ENDPOINT: u16 = 4; + pub const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 5; + pub const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 6; + pub const WGPEER_A_RX_BYTES: u16 = 7; + pub const WGPEER_A_TX_BYTES: u16 = 8; + pub const WGPEER_A_ALLOWEDIPS: u16 = 9; + pub const WGPEER_A_PROTOCOL_VERSION: u16 = 10; + + // wgallowedip_attribute + pub const WGALLOWEDIP_A_UNSPEC: u16 = 0; + pub const WGALLOWEDIP_A_FAMILY: u16 = 1; + pub const WGALLOWEDIP_A_IPADDR: u16 = 2; + pub const WGALLOWEDIP_A_CIDR_MASK: u16 = 3; +} + +use constants::*; +pub use constants::{WG_CMD_GET_DEVICE, WG_CMD_SET_DEVICE}; + +type PrivateKey = [u8; 32]; +type PublicKey = [u8; 32]; +type PresharedKey = [u8; 32]; + + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DeviceMessage { + pub nlas: Vec<DeviceNla>, + pub message_type: u16, + pub command: u8, +} + +impl DeviceMessage { + pub fn reset_config(message_type: u16, interface_index: u32, config: &Config) -> DeviceMessage { + let mut peers = vec![]; + + for peer in config.peers.iter() { + let peer_endpoint = InetAddr::from_std(&peer.endpoint); + let allowed_ips = peer.allowed_ips.iter().map(From::from).collect(); + peers.push(PeerMessage(vec![ + PeerNla::PublicKey(*peer.public_key.as_bytes()), + PeerNla::Endpoint(peer_endpoint), + PeerNla::AllowedIps(allowed_ips), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + ])); + } + + let nlas = vec![ + DeviceNla::IfIndex(interface_index), + DeviceNla::ListenPort(0), + DeviceNla::Fwmark(crate::linux::TUNNEL_FW_MARK), + DeviceNla::PrivateKey(config.tunnel.private_key.to_bytes()), + DeviceNla::Flags(WGDEVICE_F_REPLACE_PEERS), + DeviceNla::Peers(peers), + ]; + + + Self { + nlas, + message_type, + command: WG_CMD_SET_DEVICE, + } + } + + pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> { + let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?; + if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ { + return Err(Error::InterfaceNameError); + } + + Ok(Self { + message_type, + nlas: vec![DeviceNla::IfName(c_name)], + command: WG_CMD_GET_DEVICE, + }) + } + + pub fn get_by_index(message_type: u16, index: u32) -> Self { + Self { + message_type, + nlas: vec![DeviceNla::IfIndex(index)], + command: WG_CMD_GET_DEVICE, + } + } + + // All WireGuard netlink messages should start with a libc::genlmsghdr, for which the first + // byte contains the command. + fn read_genlmsghdr(buff: &[u8]) -> Result<u8, Error> { + if buff.len() < mem::size_of::<libc::genlmsghdr>() { + return Err(Error::Truncated); + } + + let cmd = buff[0]; + if cmd == WG_CMD_GET_DEVICE || cmd == WG_CMD_SET_DEVICE { + Ok(cmd) + } else { + Err(Error::UnnkownWireguardCommmand(cmd)) + } + } +} + +impl NetlinkSerializable<DeviceMessage> for DeviceMessage { + fn message_type(&self) -> u16 { + self.message_type + } + + fn buffer_len(&self) -> usize { + // add the genlmsghdr + mem::size_of::<libc::genlmsghdr>() + + // size of all of the NLAs + self.nlas.as_slice().buffer_len() + } + + fn serialize(&self, mut buffer: &mut [u8]) { + let command_buf = [self.command, WG_GENL_VERSION, 0u8, 0u8]; + let _ = buffer.write(&command_buf).unwrap(); + self.nlas.as_slice().emit(&mut buffer) + } +} +impl Into<NetlinkPayload<DeviceMessage>> for DeviceMessage { + fn into(self) -> NetlinkPayload<DeviceMessage> { + NetlinkPayload::InnerMessage(self) + } +} + +impl NetlinkDeserializable<DeviceMessage> for DeviceMessage { + type Error = Error; + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result<DeviceMessage, Self::Error> { + let command = Self::read_genlmsghdr(payload)?; + let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..]; + let mut nlas = vec![]; + for buf in NlasIterator::new(new_payload) { + nlas.push( + DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?, + ); + } + + Ok(DeviceMessage { + nlas, + command, + message_type: header.message_type, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum DeviceNla { + IfIndex(u32), + IfName(CString), + Flags(u32), + PrivateKey(PrivateKey), + PublicKey(PublicKey), + ListenPort(u16), + Fwmark(u32), + Peers(Vec<PeerMessage>), + Unspec(Vec<u8>), +} + +impl Nla for DeviceNla { + fn value_len(&self) -> usize { + use DeviceNla::*; + match self { + IfIndex(_) | Fwmark(_) | Flags(_) => 4, + IfName(name) => name.as_bytes_with_nul().len(), + PrivateKey(key) | PublicKey(key) => key.len(), + ListenPort(_) => 2, + Peers(peers) => peers.as_slice().buffer_len(), + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use DeviceNla::*; + match self { + IfIndex(_) => WGDEVICE_A_IFINDEX, + IfName(_) => WGDEVICE_A_IFNAME, + PrivateKey(_) => WGDEVICE_A_PRIVATE_KEY, + PublicKey(_) => WGDEVICE_A_PUBLIC_KEY, + Flags(_) => WGDEVICE_A_FLAGS, + ListenPort(_) => WGDEVICE_A_LISTEN_PORT, + Fwmark(_) => WGDEVICE_A_FWMARK, + Peers(_) => WGDEVICE_A_PEERS | NLA_F_NESTED, + Unspec(_) => WGDEVICE_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use DeviceNla::*; + match self { + IfIndex(value) | Fwmark(value) | Flags(value) => { + NativeEndian::write_u32(buffer, *value) + } + IfName(interface_name) => { + let _ = buffer + .write(interface_name.as_bytes_with_nul()) + .expect("Failed to write interface name"); + } + PrivateKey(key) | PublicKey(key) => { + let _ = buffer.write(key).expect("Failed to write key"); + } + ListenPort(port) => NativeEndian::write_u16(buffer, *port), + Peers(peers) => { + peers.as_slice().emit(buffer); + } + Unspec(payload) => { + let _ = buffer.write(&payload).expect("Failed to write "); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized + core::fmt::Debug> Parseable<NlaBuffer<&'a T>> + for DeviceNla +{ + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use DeviceNla::*; + let value = buf.value(); + let kind = buf.kind(); + let nla = match kind { + WGDEVICE_A_IFINDEX => IfIndex(parsers::parse_u32(value)?), + WGDEVICE_A_IFNAME => IfName(parsers::parse_cstring(value)?), + WGDEVICE_A_PRIVATE_KEY => PrivateKey(parsers::parse_wg_key(value)?.into()), + WGDEVICE_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()), + WGDEVICE_A_FLAGS => Flags(parsers::parse_u32(value)?), + WGDEVICE_A_LISTEN_PORT => ListenPort(parsers::parse_u16(value)?), + WGDEVICE_A_FWMARK => Fwmark(parsers::parse_u32(value)?), + WGDEVICE_A_PEERS => { + let peers = NlasIterator::new(value) + .map(|nla_bytes| { + let buf = nla_bytes?; + let val = buf.value(); + PeerMessage::parse(&val) + }) + .collect::<Result<Vec<PeerMessage>, DecodeError>>()?; + Peers(peers) + } + WGDEVICE_A_UNSPEC => Unspec(value.to_vec()), + _ => { + return Err(format!("Unexpected device attribute kind: {}", buf.kind()).into()); + } + }; + Ok(nla) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct PeerMessage(pub Vec<PeerNla>); + +impl PeerMessage { + fn parse(payload: &[u8]) -> Result<Self, DecodeError> { + let mut nlas = vec![]; + + let nla_iter = NlasIterator::new(&payload); + for buffer in nla_iter { + nlas.push(PeerNla::parse(&buffer?)?) + } + Ok(Self(nlas)) + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerMessage { + fn parse(payload: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + Ok(Self( + NlasIterator::new(&payload.into_inner()) + .map(|buffer| PeerNla::parse(&buffer?)) + .collect::<Result<Vec<PeerNla>, DecodeError>>()?, + )) + } +} + +impl Nla for PeerMessage { + fn value_len(&self) -> usize { + self.0.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.0.as_slice().emit(buffer); + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum PeerNla { + Unspec(Vec<u8>), + PublicKey(PublicKey), + PresharedKey(PresharedKey), + Flags(u32), + Endpoint(InetAddr), + PersistentKeepaliveInterval(u16), + LastHandshakeTime(TimeSpec), + RxBytes(u64), + TxBytes(u64), + AllowedIps(Vec<AllowedIpMessage>), + ProtocolVersion(u32), +} + +impl Nla for PeerNla { + fn value_len(&self) -> usize { + use PeerNla::*; + match self { + PublicKey(key) | PresharedKey(key) => key.len(), + Endpoint(endpoint) => match &endpoint { + InetAddr::V4(_) => mem::size_of::<libc::sockaddr_in>(), + InetAddr::V6(_) => mem::size_of::<libc::sockaddr_in6>(), + }, + PersistentKeepaliveInterval(_) => 2, + LastHandshakeTime(_) => mem::size_of::<libc::timespec>(), + RxBytes(_) | TxBytes(_) => 8, + AllowedIps(ips) => ips.as_slice().buffer_len(), + Flags(_) | ProtocolVersion(_) => 4, + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use PeerNla::*; + match self { + PublicKey(_) => WGPEER_A_PUBLIC_KEY, + PresharedKey(_) => WGPEER_A_PRESHARED_KEY, + Flags(_) => WGPEER_A_FLAGS, + Endpoint(_) => WGPEER_A_ENDPOINT, + PersistentKeepaliveInterval(_) => WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, + LastHandshakeTime(_) => WGPEER_A_LAST_HANDSHAKE_TIME, + RxBytes(_) => WGPEER_A_RX_BYTES, + TxBytes(_) => WGPEER_A_TX_BYTES, + AllowedIps(_) => WGPEER_A_ALLOWEDIPS | NLA_F_NESTED, + ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION, + Unspec(_) => WGPEER_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use PeerNla::*; + match self { + PublicKey(key) | PresharedKey(key) => { + let _ = buffer.write(key).expect("Buffer too small for a key"); + } + Flags(value) | ProtocolVersion(value) => NativeEndian::write_u32(buffer, *value), + Endpoint(endpoint) => match &endpoint { + InetAddr::V4(sockaddr_in) => { + let slice = unsafe { struct_as_slice(sockaddr_in) }; + buffer + .write(slice) + .expect("Buffer too small for sockaddr_in"); + } + InetAddr::V6(sockaddr_in6) => { + buffer + .write(unsafe { struct_as_slice(sockaddr_in6) }) + .expect("Buffer too small for sockaddr_in6"); + } + }, + PersistentKeepaliveInterval(interval) => { + NativeEndian::write_u16(buffer, *interval); + } + LastHandshakeTime(last_handshake) => { + let timespec: &libc::timespec = last_handshake.as_ref(); + buffer + .write(unsafe { struct_as_slice(timespec) }) + .expect("Buffer too small for timespec"); + } + RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes), + AllowedIps(ips) => ips.as_slice().emit(buffer), + Unspec(payload) => { + let _ = buffer + .write(&payload) + .expect("Buffer too small for unspecified payload"); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerNla { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use PeerNla::*; + let value = buf.value(); + let nla = match buf.kind() { + WGPEER_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()), + WGPEER_A_PRESHARED_KEY => PresharedKey(parsers::parse_wg_key(value)?.into()), + WGPEER_A_FLAGS => Flags(parsers::parse_u32(value)?), + WGPEER_A_ENDPOINT => Endpoint(parsers::parse_inet_sockaddr(value)?), + WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => { + PersistentKeepaliveInterval(parsers::parse_u16(value)?) + } + + WGPEER_A_LAST_HANDSHAKE_TIME => LastHandshakeTime(parsers::parse_timespec(value)?), + WGPEER_A_RX_BYTES => RxBytes(parsers::parse_u64(value)?), + WGPEER_A_TX_BYTES => TxBytes(parsers::parse_u64(value)?), + WGPEER_A_ALLOWEDIPS => { + let nlas = NlasIterator::new(value) + .map(|nla_buffer| AllowedIpMessage::parse(&nla_buffer?)) + .collect::<Result<Vec<_>, DecodeError>>()?; + + AllowedIps(nlas) + } + WGPEER_A_PROTOCOL_VERSION => ProtocolVersion(parsers::parse_u32(value)?), + WGPEER_A_UNSPEC => Unspec(value.to_vec()), + _ => { + return Err(format!("Unexpected peer attribute kind: {}", buf.kind()).into()); + } + }; + Ok(nla) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct AllowedIpMessage(Vec<AllowedIpNla>); + +impl From<&IpNetwork> for AllowedIpMessage { + fn from(ip: &IpNetwork) -> Self { + use AllowedIpNla::*; + let address_family = if ip.is_ipv4() { + libc::AF_INET + } else { + libc::AF_INET6 + }; + + AllowedIpMessage(vec![ + AddressFamily(address_family as u16), + CidrMask(ip.prefix()), + IpAddr(ip.ip().into()), + ]) + } +} + +impl Nla for AllowedIpMessage { + fn value_len(&self) -> usize { + self.0.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.0.as_slice().emit(buffer); + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpMessage { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + let nlas = NlasIterator::new(buf.value()) + .map(|buffer| AllowedIpNla::parse(&buffer?)) + .collect::<Result<Vec<_>, _>>()?; + Ok(AllowedIpMessage(nlas)) + } +} + + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum AllowedIpNla { + AddressFamily(u16), + IpAddr(IpAddr), + CidrMask(u8), + Unspec(Vec<u8>), +} + +impl Nla for AllowedIpNla { + fn value_len(&self) -> usize { + use AllowedIpNla::*; + match &self { + AddressFamily(_) => 2, + IpAddr(addr) => ip_addr_to_bytes(addr).len(), + CidrMask(_) => 1, + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use AllowedIpNla::*; + match &self { + AddressFamily(_) => WGALLOWEDIP_A_FAMILY, + IpAddr(_) => WGALLOWEDIP_A_IPADDR, + CidrMask(_) => WGALLOWEDIP_A_CIDR_MASK, + Unspec(_) => WGALLOWEDIP_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use AllowedIpNla::*; + match self { + AddressFamily(af) => { + NativeEndian::write_u16(buffer, *af); + } + IpAddr(ip_addr) => { + buffer + .write(&ip_addr_to_bytes(ip_addr)) + .expect("Buffer too small for AllowedIpNla::IpAddr"); + } + CidrMask(cidr_mask) => buffer[0] = *cidr_mask, + Unspec(payload) => { + let _ = buffer + .write(&payload) + .expect("Buffer too small for unspec payload"); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpNla { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use AllowedIpNla::*; + let value = buf.value(); + let nla = match buf.kind() { + WGALLOWEDIP_A_FAMILY => AddressFamily(parsers::parse_u16(value)?), + WGALLOWEDIP_A_IPADDR => IpAddr(parsers::parse_ip_addr(value)?), + WGALLOWEDIP_A_CIDR_MASK => CidrMask(parsers::parse_u8(value)?), + WGALLOWEDIP_A_UNSPEC => Unspec(value.to_vec()), + _ => Err(format!( + "Unexpected allowed IP attribute kind: {}", + buf.kind() + ))?, + }; + Ok(nla) + } +} + +unsafe fn struct_as_slice<T: Sized>(t: &T) -> &[u8] { + let s = mem::size_of::<T>(); + let ptr = t as *const T as *const u8; + std::slice::from_raw_parts(ptr, s) +} + +fn ip_addr_to_bytes(addr: &IpAddr) -> Vec<u8> { + match addr { + IpAddr::V4(addr) => addr.octets().to_vec(), + IpAddr::V6(addr) => addr.octets().to_vec(), + } +} + +#[cfg(test)] +mod test { + use super::*; + use nix::sys::time::TimeValLike; + use std::net::Ipv4Addr; + + + #[test] + fn deserialize_netlink_message() { + #[rustfmt::skip] + let payload = vec![ + 0x00, 0x01, 0x00, 0x00, + // 6 bytes of WGDEVICE_A_LISTEN_PORT 51820 + 2 bytes of padding + 0x06, 0x00, 0x06, 0x00, 0x6c, 0xca, 0x00, 0x00, + // 8 bytes of WGDEVICE_A_FWMARK 0 + 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGDEVIEC_A_IFINDEX 320 + 0x08, 0x00, 0x01, 0x00, 0x40, 0x01, 0x00, 0x00, + // 12 bytes of WGDEVICE_A_IFNAME "wg-test\0" + 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x00, + // 36 bytes of WGDEVICE_A_PRIVATE_KEY OEf0rWXfVRarrw8nNbTBxkk3NTu8GjRKrbMW1aFH/H0= + 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16, 0xab, 0xaf, + 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a, 0x34, 0x4a, + 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d, + // 36 bytes of WGDEVICE_A_PUBLIC_KEY Ztqy3r8VO1N8tHwpWwqGx1S6G9o12BRdy1JESr2OYzs= + 0x24, 0x00, 0x04, 0x00, 0x66, 0xda, 0xb2, 0xde, 0xbf, 0x15, 0x3b, 0x53, 0x7c, 0xb4, + 0x7c, 0x29, 0x5b, 0x0a, 0x86, 0xc7, 0x54, 0xba, 0x1b, 0xda, 0x35, 0xd8, 0x14, 0x5d, + 0xcb, 0x52, 0x44, 0x4a, 0xbd, 0x8e, 0x63, 0x3b, + // 380 bytes of WGDEVICE_A_PEERS + 0x7c, 0x01, 0x08, 0x80, + // 188 bytes of WGPEER attributes + 0xbc, 0x00, 0x00, 0x80, + // 36 bytes of WGPEER_A_PUBLIC_KEY IOBEBReIZ+XOOyLn14vW7FBRuweaxfskq5wwSZEvhjY= + 0x24, 0x00, 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5, + 0xce, 0x3b, 0x22, 0xe7, 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07, + 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c, 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36, + // 36 bytes of WGPEER_A_PRESHARED_KEY (all zeroes) + 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME 0 + 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL 0 + 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_TX_BYTES 0 + 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_RX_BYTES 0 + 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGPEER_A_PROTOCOL_VERSION 1 + 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_ENDPOINT 192.168.39.2:9797 + 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 32 bytes of WGPEER_A_ALLOWEDIPS + 0x20, 0x00, 0x09, 0x80, + // 28 bytes of WGALLOWDIP_A_* + 0x1c, 0x00,0x00, 0x80, + // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32 + 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4) + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, + // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.1 + 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x01, + // 188 bvytes of WGPEER attributes + 0xbc, 0x00, 0x00, 0x80, + // 36 bytes of WGPEER_A_PUBLIC_KEY + 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7, + 0xc2, 0x9d, 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e, + 0xd9, 0x38, 0x0c, 0x92, 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c, + // 36 bytes of WGPEER_A_PRESHARED_KEY + 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME + 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL + 2 bytes of padding + 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_TX_BYTES + 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_RX_BYTES + 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGPEER_A_PROTOCOL_VERSION + 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_ENDPOINT + 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 32 bytes of WGPEER_A_ALLOWEDIPS + 0x20, 0x00, 0x09, 0x80, + // 28 bytes of WGALLOWDIP_A_* + 0x1c, 0x00, 0x00, 0x80, + // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32 + 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4) + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, + // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.2 + 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x02, + ]; + let header = NetlinkHeader { + length: payload.len() as u32, + message_type: 0, + flags: 0, + sequence_number: 0, + port_number: 0, + }; + let message = DeviceMessage::deserialize(&header, &payload).unwrap(); + + let mut serialized_message = vec![0u8; payload.len()]; + + message.serialize(&mut serialized_message); + + assert_eq!(message, sample_get_message()); + assert_eq!(&payload, &serialized_message) + } + + fn sample_get_message() -> DeviceMessage { + use AllowedIpNla::*; + use DeviceNla::*; + use PeerNla::*; + + let if_name = CString::new(b"wg-test".to_vec()).unwrap(); + + let peer_1 = PeerMessage( + [ + PeerNla::PublicKey([ + 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80, + 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54, + ]), + PeerNla::PresharedKey([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + LastHandshakeTime(TimeSpec::seconds(0)), + PersistentKeepaliveInterval(0), + TxBytes(0), + RxBytes(0), + ProtocolVersion(1), + Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())), + AllowedIps( + [AllowedIpMessage( + [ + CidrMask(32), + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()), + ] + .to_vec(), + )] + .to_vec() + .to_vec(), + ), + ] + .to_vec(), + ); + + let peer_2 = PeerMessage( + [ + PeerNla::PublicKey([ + 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24, + 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44, + ]), + PresharedKey([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + LastHandshakeTime(TimeSpec::seconds(0)), + PersistentKeepaliveInterval(0), + TxBytes(0), + RxBytes(0), + ProtocolVersion(1), + Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())), + AllowedIps( + [AllowedIpMessage( + vec![ + CidrMask(32), + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()), + ] + .to_vec(), + )] + .to_vec(), + ), + ] + .to_vec(), + ); + + DeviceMessage { + command: WG_CMD_GET_DEVICE, + message_type: 0, + nlas: [ + ListenPort(51820), + Fwmark(0), + IfIndex(320), + IfName(if_name), + PrivateKey([ + 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73, + 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125, + ]), + DeviceNla::PublicKey([ + 102, 218, 178, 222, 191, 21, 59, 83, 124, 180, 124, 41, 91, 10, 134, 199, 84, + 186, 27, 218, 53, 216, 20, 93, 203, 82, 68, 74, 189, 142, 99, 59, + ]), + Peers([peer_1, peer_2].to_vec()), + ] + .to_vec(), + } + } + + pub fn sample_set_message() -> DeviceMessage { + use AllowedIpNla::*; + use DeviceNla::*; + use PeerNla::*; + + let if_name = CString::new("wg-test".to_string()).unwrap(); + + let peer_1 = PeerMessage( + [ + PeerNla::PublicKey([ + 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80, + 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54, + ]), + Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + AllowedIps( + [AllowedIpMessage( + [ + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()), + CidrMask(32), + ] + .to_vec(), + )] + .to_vec() + .to_vec(), + ), + ] + .to_vec(), + ); + + let peer_2 = PeerMessage( + [ + PeerNla::PublicKey([ + 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24, + 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44, + ]), + Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + AllowedIps( + [AllowedIpMessage( + vec![ + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()), + CidrMask(32), + ] + .to_vec(), + )] + .to_vec(), + ), + ] + .to_vec(), + ); + + DeviceMessage { + command: WG_CMD_SET_DEVICE, + message_type: 0, + nlas: [ + IfName(if_name), + PrivateKey([ + 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73, + 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125, + ]), + ListenPort(51820), + Peers([peer_1, peer_2].to_vec()), + ] + .to_vec(), + } + } + + + #[test] + fn serialize_netlink_message() { + let expected_payload: &[u8] = &[ + 0x01, 0x01, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73, + 0x74, 0x00, 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16, + 0xab, 0xaf, 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a, + 0x34, 0x4a, 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d, 0x06, 0x00, 0x06, 0x00, + 0x6c, 0xca, 0x00, 0x00, 0xcc, 0x00, 0x08, 0x80, 0x64, 0x00, 0x00, 0x80, 0x24, 0x00, + 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5, 0xce, 0x3b, 0x22, 0xe7, + 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07, 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c, + 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, + 0xc0, 0xa8, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, + 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, + 0x27, 0x01, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x80, + 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7, 0xc2, 0x9d, + 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e, 0xd9, 0x38, 0x0c, 0x92, + 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, + 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, + 0x00, 0x80, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, + 0xc0, 0xa8, 0x27, 0x02, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + ]; + + let mut message = sample_set_message(); + message.command = WG_CMD_SET_DEVICE; + + + let mut payload_buffer = vec![0u8; message.buffer_len()]; + message.serialize(&mut payload_buffer); + let header = NetlinkHeader { + length: payload_buffer.len() as u32, + message_type: 0, + flags: 0, + sequence_number: 0, + port_number: 0, + }; + let deserialized_device = DeviceMessage::deserialize(&header, &payload_buffer).unwrap(); + + assert_eq!(message, deserialized_device); + assert_eq!(payload_buffer, expected_payload); + } +} diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 43595a4e79..dd259a87d9 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -11,7 +11,7 @@ use futures01::{ Async, Future, Stream, }; use talpid_types::{ - net::{Endpoint, TunnelParameters}, + net::TunnelParameters, tunnel::{ErrorStateCause, FirewallPolicyError}, BoxedError, ErrorExt, }; @@ -52,19 +52,7 @@ impl ConnectedState { &self, shared_values: &mut SharedTunnelStateValues, ) -> Result<(), FirewallPolicyError> { - // If a proxy is specified we need to pass it on as the peer endpoint. - let peer_endpoint = self.get_endpoint_from_params(); - - let policy = FirewallPolicy::Connected { - peer_endpoint, - tunnel: self.metadata.clone(), - allow_lan: shared_values.allow_lan, - #[cfg(windows)] - relay_client: TunnelMonitor::get_relay_client( - &shared_values.resource_dir, - &self.tunnel_parameters, - ), - }; + let policy = self.get_firewall_policy(shared_values); shared_values .firewall .apply_policy(policy) @@ -85,13 +73,18 @@ impl ConnectedState { }) } - fn get_endpoint_from_params(&self) -> Endpoint { - match self.tunnel_parameters { - TunnelParameters::OpenVpn(ref params) => match params.proxy { - Some(ref proxy_settings) => proxy_settings.get_endpoint().endpoint, - None => params.config.endpoint, - }, - TunnelParameters::Wireguard(ref params) => params.connection.get_endpoint(), + fn get_firewall_policy(&self, shared_values: &SharedTunnelStateValues) -> FirewallPolicy { + FirewallPolicy::Connected { + peer_endpoint: self.tunnel_parameters.get_next_hop_endpoint(), + tunnel: self.metadata.clone(), + allow_lan: shared_values.allow_lan, + #[cfg(windows)] + relay_client: TunnelMonitor::get_relay_client( + &shared_values.resource_dir, + &self.tunnel_parameters, + ), + #[cfg(target_os = "linux")] + use_fwmark: self.tunnel_parameters.get_proxy_endpoint().is_none(), } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index d206b34a23..859ef52eb0 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -22,7 +22,7 @@ use std::{ time::{Duration, Instant}, }; use talpid_types::{ - net::{openvpn, TunnelParameters}, + net::TunnelParameters, tunnel::{ErrorStateCause, FirewallPolicyError}, ErrorExt, }; @@ -48,13 +48,7 @@ impl ConnectingState { shared_values: &mut SharedTunnelStateValues, params: &TunnelParameters, ) -> Result<(), FirewallPolicyError> { - let proxy = &get_openvpn_proxy_settings(¶ms); - let endpoint = params.get_tunnel_endpoint().endpoint; - - let peer_endpoint = match proxy { - Some(proxy_settings) => proxy_settings.get_endpoint().endpoint, - None => endpoint, - }; + let peer_endpoint = params.get_next_hop_endpoint(); let policy = FirewallPolicy::Connecting { peer_endpoint, @@ -62,6 +56,8 @@ impl ConnectingState { allow_lan: shared_values.allow_lan, #[cfg(windows)] relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, ¶ms), + #[cfg(target_os = "linux")] + use_fwmark: params.get_proxy_endpoint().is_none(), }; shared_values .firewall @@ -312,15 +308,6 @@ impl ConnectingState { } } -fn get_openvpn_proxy_settings( - tunnel_parameters: &TunnelParameters, -) -> &Option<openvpn::ProxySettings> { - match tunnel_parameters { - TunnelParameters::OpenVpn(ref config) => &config.proxy, - _ => &None, - } -} - fn should_retry(error: &tunnel::Error) -> bool { #[cfg(not(windows))] use tunnel::wireguard::{Error, TunnelError}; diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 97292fd381..15bf33a5bb 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -37,6 +37,25 @@ impl TunnelParameters { } } + // Returns the endpoint that will be connected to + pub fn get_next_hop_endpoint(&self) -> Endpoint { + match self { + TunnelParameters::OpenVpn(params) => params + .proxy + .as_ref() + .map(|proxy| proxy.get_endpoint().endpoint) + .unwrap_or(params.config.endpoint), + TunnelParameters::Wireguard(params) => params.connection.get_endpoint(), + } + } + + pub fn get_proxy_endpoint(&self) -> Option<openvpn::ProxySettings> { + match self { + TunnelParameters::OpenVpn(params) => params.proxy.clone(), + _ => None, + } + } + pub fn get_generic_options(&self) -> &GenericTunnelOptions { match &self { TunnelParameters::OpenVpn(params) => ¶ms.generic_options, |
