summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2020-09-02 12:23:28 +0100
committerEmīls <emils@mullvad.net>2020-09-02 12:23:28 +0100
commit58ae8def1aba534b36dbb1e053debcb258c8d3e1 (patch)
treed259b09a2c777ddd9e24c7db0722bd7979417d5d
parent21a39a53d108c6e90fcb73b225ebeda8123f5b9b (diff)
parent0dc0a6634adb50fd95ac06aaaa280a47c89754f4 (diff)
downloadmullvadvpn-58ae8def1aba534b36dbb1e053debcb258c8d3e1.tar.xz
mullvadvpn-58ae8def1aba534b36dbb1e053debcb258c8d3e1.zip
Merge branch 'linux-use-wg-kernel-module'
-rw-r--r--CHANGELOG.md4
-rw-r--r--Cargo.lock95
-rw-r--r--docs/security.md7
-rw-r--r--talpid-core/Cargo.toml14
-rw-r--r--talpid-core/src/dns/linux/network_manager.rs14
-rw-r--r--talpid-core/src/firewall/linux.rs19
-rw-r--r--talpid-core/src/firewall/mod.rs8
-rw-r--r--talpid-core/src/linux.rs3
-rw-r--r--talpid-core/src/process/openvpn.rs7
-rw-r--r--talpid-core/src/routing/linux.rs25
-rw-r--r--talpid-core/src/routing/unix.rs6
-rw-r--r--talpid-core/src/tunnel/wireguard/config.rs8
-rw-r--r--talpid-core/src/tunnel/wireguard/connectivity_check.rs4
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs41
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs4
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs479
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs135
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs99
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs899
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs35
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs21
-rw-r--r--talpid-types/src/net/mod.rs19
22 files changed, 1835 insertions, 111 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index db17cb937d..6cd0464823 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,10 @@ Line wrap the file at 100 chars. Th
## [Unreleased]
+### Added
+#### Linux
+- Add support for WireGuard's kernel module if it's loaded.
+
### Fixed
#### Windows
diff --git a/Cargo.lock b/Cargo.lock
index d530e53bd9..e32b8c32cb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -182,6 +182,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "bytes"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "iovec 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "bytes"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1331,56 +1340,56 @@ dependencies = [
[[package]]
name = "netlink-packet-core"
-version = "0.1.0"
+version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
+ "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "netlink-packet-route"
-version = "0.2.1"
+version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
+ "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)",
"bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "netlink-packet-utils"
-version = "0.1.1"
+version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
+ "anyhow 1.0.32 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
+ "thiserror 1.0.20 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "netlink-proto"
-version = "0.2.1"
+version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"bytes 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)",
- "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio-util 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "netlink-sys"
-version = "0.2.0"
+version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1436,6 +1445,17 @@ dependencies = [
]
[[package]]
+name = "nix"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "cc 1.0.59 (registry+https://github.com/rust-lang/crates.io-index)",
+ "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
+ "libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "notify"
version = "4.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2024,14 +2044,14 @@ dependencies = [
[[package]]
name = "rtnetlink"
-version = "0.2.2"
+version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "thiserror 1.0.20 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -2398,6 +2418,7 @@ version = "0.1.0"
dependencies = [
"async-stream 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
+ "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono 0.4.15 (registry+https://github.com/rust-lang/crates.io-index)",
"dbus 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -2413,11 +2434,13 @@ dependencies = [
"libc 0.2.76 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)",
"mnl 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"nftnl 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "nix 0.17.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "nix 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)",
"notify 4.0.15 (registry+https://github.com/rust-lang/crates.io-index)",
"openvpn-plugin 0.3.0 (git+https://github.com/mullvad/openvpn-plugin-rs?branch=auth-failed-event)",
"os_pipe 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -2431,13 +2454,14 @@ dependencies = [
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)",
"resolv-conf 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)",
- "rtnetlink 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rtnetlink 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"shell-escape 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"socket2 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)",
"system-configuration 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"talpid-types 0.1.0",
"tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)",
+ "tokio-io 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
"tonic 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"tonic-build 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"triggered 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -2579,6 +2603,16 @@ dependencies = [
]
[[package]]
+name = "tokio-io"
+version = "0.1.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)",
+ "log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "tokio-macros"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3219,6 +3253,7 @@ dependencies = [
"checksum blake2b_simd 0.5.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d8fb2d74254a3a0b5cac33ac9f8ed0e44aa50378d9dbb2e5d83bd21ed1dc2c8a"
"checksum bumpalo 3.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2e8c087f005730276d1096a652e92a8bacee2e2472bcc9715a74d2bec38b5820"
"checksum byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de"
+"checksum bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c"
"checksum bytes 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38"
"checksum cc 1.0.59 (registry+https://github.com/rust-lang/crates.io-index)" = "66120af515773fb005778dc07c261bd201ec8ce50bd6e7144c927753fe013381"
"checksum cesu8 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
@@ -3327,15 +3362,16 @@ dependencies = [
"checksum multimap 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d8883adfde9756c1d30b0f519c9b8c502a94b41ac62f696453c37c7fc0a958ce"
"checksum natord 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "308d96db8debc727c3fd9744aac51751243420e46edf401010908da7f8d5e57c"
"checksum net2 0.2.34 (registry+https://github.com/rust-lang/crates.io-index)" = "2ba7c918ac76704fb42afcbbb43891e72731f3dcca3bef2a19786297baf14af7"
-"checksum netlink-packet-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9cdae99aa0db00bffb58886de3cdba39d07164cec467867f162827872e3ed957"
-"checksum netlink-packet-route 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "472467595d208fb94b0e8f70b1155c14b51eeb784713b4845344b501a8297a11"
-"checksum netlink-packet-utils 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "6785792d3020aad7c392caea9a9d687ac84625d262d11f8e3c1ee959272150dc"
-"checksum netlink-proto 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "844a78a78bee85b99686973856e57ce339ef2490660305d26e35bb74a672ad15"
-"checksum netlink-sys 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "aee128bb9bcc04f426d9b5e0bf3077726776b5b41770a3b2e4db5f52295625bf"
+"checksum netlink-packet-core 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5fa0ae27e4832438fa054230e7075f69d0fa464dd335c2be3343cb481c0e8113"
+"checksum netlink-packet-route 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c46727055595605d4f625e633e5e9bd7296e7c79ea701aaf2dd53b5062cd5aa3"
+"checksum netlink-packet-utils 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2ce628faa6689198d3db4f68e6165a5ba02a8e0a5fe741cca9c1b7856bab6a66"
+"checksum netlink-proto 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fba96b2619706b19de0d6620c4479c0ba2c8d293164e55494f02a3bfdaffd36a"
+"checksum netlink-sys 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d322b2ea918a35471492c831381c18bf8afaefeb2ecddc04ef2de6d166ee0fb8"
"checksum nftnl 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b7528eff501558f9f892c5001e945b0d7e980cb464a7969101c94e18481c4563"
"checksum nftnl-sys 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fe241d8ce673ef755c8d2b8717cd74990d4e0a61d437792054750ce9a35743d0"
"checksum nix 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3b2e0b4f3320ed72aaedb9a5ac838690a8047c7b275da22711fddff4f8a14229"
"checksum nix 0.17.0 (registry+https://github.com/rust-lang/crates.io-index)" = "50e4785f2c3b7589a0d0c1dd60285e1188adac4006e8abd6dd578e1567027363"
+"checksum nix 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83450fe6a6142ddd95fb064b746083fc4ef1705fe81f64a64e1d4b39f54a1055"
"checksum notify 4.0.15 (registry+https://github.com/rust-lang/crates.io-index)" = "80ae4a7688d1fab81c5bf19c64fc8db920be8d519ce6336ed4e7efe024724dbd"
"checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
@@ -3402,7 +3438,7 @@ dependencies = [
"checksum resolv-conf 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)" = "11834e137f3b14e309437a8276714eed3a80d1ef894869e510f2c0c0b98b9f4a"
"checksum ring 0.16.15 (registry+https://github.com/rust-lang/crates.io-index)" = "952cd6b98c85bbc30efa1ba5783b8abf12fec8b3287ffa52605b9432313e34e4"
"checksum rs-release 0.1.7 (git+https://github.com/mullvad/rs-release?branch=snailquote-unescape)" = "<none>"
-"checksum rtnetlink 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f802e53265ca90edd3cfc59a3ceb2c30d655da6a1a6b954684d51977633ef5fd"
+"checksum rtnetlink 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d3a3e80f0a3ac6877e56c430f41a581a1c8e95c2bd979a84f90270df9808feed"
"checksum rust-argon2 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2bc8af4bda8e1ff4932523b94d3dd20ee30a87232323eda55903ffd71d2fb017"
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
"checksum rustc-serialize 0.3.24 (registry+https://github.com/rust-lang/crates.io-index)" = "dcf128d1287d2ea9d80910b5f1120d0b8eede3fbf1abe91c40d39ea7d51e6fda"
@@ -3454,6 +3490,7 @@ dependencies = [
"checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14"
"checksum time 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
"checksum tokio 0.2.22 (registry+https://github.com/rust-lang/crates.io-index)" = "5d34ca54d84bf2b5b4d7d31e901a8464f7b60ac145a284fba25ceb801f2ddccd"
+"checksum tokio-io 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674"
"checksum tokio-macros 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c3acc6aa564495a0f2e1d59fab677cd7f81a19994cfc7f3ad0e64301560389"
"checksum tokio-rustls 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "15cb62a0d2770787abc96e99c1cd98fcf17f94959f3af63ca85bdfb203f051b4"
"checksum tokio-util 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "571da51182ec208780505a32528fc5512a8fe1443ab960b3f2f3ef093cd16930"
diff --git a/docs/security.md b/docs/security.md
index 7fc5369bd3..786cf99d8b 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -124,8 +124,11 @@ VPN tunnel is allowed on all interfaces, together with responses to this outgoin
First hop means the bridge server if one is used, otherwise the VPN server directly.
This IP+port+protocol combination should only be allowed for the process establishing the
VPN tunnel, or only administrator level processes, depending on what the platform firewall
-allows restricting. On Windows the rule only allows processes from binaries in certain paths.
-On Linux and macOS the rule only allows packets from processes running as `root`.
+allows restricting. On Windows the rule only allows processes from binaries in certain paths. macOS
+the rule only allows packets from processes running as `root`. On Linux, the rule only allows
+packets that have the mark `0x6d6f6c65` set: setting a firewall mark on traffic requires elevated
+privileges when using tunnels that support marking traffic, otherwise the rule is the same as on
+macOS: the packet needs to originate from a process running as `root`.
This process/user check is important to not allow unprivileged programs
to leak packets to this IP outside the tunnel, as those packets can be fingerprinted.
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 5222c5d26d..9828bbedf6 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -39,7 +39,8 @@ tonic = "0.3.1"
prost = "0.6"
[target.'cfg(unix)'.dependencies]
-nix = "0.17"
+nix = "0.18"
+tokio-io = "0.1"
[target.'cfg(target_os = "android")'.dependencies]
@@ -52,10 +53,13 @@ failure = "0.1"
notify = "4.0"
resolv-conf = "0.6.1"
async-stream = "0.2"
-rtnetlink = "0.2"
-netlink-packet-route = "0.2"
-netlink-proto = "0.2"
-netlink-sys = "0.2"
+rtnetlink = "0.3"
+netlink-packet-core = "0.2"
+netlink-packet-utils = "0.2"
+netlink-packet-route = "0.3"
+netlink-proto = "0.4"
+netlink-sys = "0.3"
+byteorder = "1"
futures = { package = "futures", version = "0.3" }
nftnl = { version = "0.5", features = ["nftnl-1-1-0"] }
mnl = { version = "0.2.0", features = ["mnl-1-0-4"] }
diff --git a/talpid-core/src/dns/linux/network_manager.rs b/talpid-core/src/dns/linux/network_manager.rs
index 0e517e07f1..a8c4efccd3 100644
--- a/talpid-core/src/dns/linux/network_manager.rs
+++ b/talpid-core/src/dns/linux/network_manager.rs
@@ -194,6 +194,12 @@ impl NetworkManager {
.get(NM_DEVICE, "Ip6Config")
.map_err(Error::Dbus)?;
+ let device_addresses6: Vec<(Vec<u8>, u32, Vec<u8>)> = self
+ .dbus_connection
+ .with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS)
+ .get(NM_IP6_CONFIG, "Addresses")
+ .map_err(Error::Dbus)?;
+
let device_routes6: Vec<(Vec<u8>, u32, Vec<u8>, u32)> = self
.dbus_connection
.with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS)
@@ -209,6 +215,7 @@ impl NetworkManager {
ipv6_settings.insert("route-metric", Variant(Box::new(0u32)));
ipv6_settings.insert("routes", Variant(Box::new(device_routes6)));
ipv6_settings.insert("route-data", Variant(Box::new(device_route6_data)));
+ ipv6_settings.insert("addresses", Variant(Box::new(device_addresses6)));
}
let mut settings_backup =
@@ -248,6 +255,13 @@ impl NetworkManager {
Self::update_dns_config(&mut settings, "ipv6", v6_dns);
}
+ if let Some(wg_config) = settings.get_mut("wireguard") {
+ wg_config.insert(
+ "fwmark",
+ Variant(Box::new(crate::linux::TUNNEL_FW_MARK) as Box<dyn RefArg>),
+ );
+ }
+
self.reapply_settings(&device, settings, version_id)?;
self.device = Some(device);
diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs
index 3a68d3734d..f4a3e07378 100644
--- a/talpid-core/src/firewall/linux.rs
+++ b/talpid-core/src/firewall/linux.rs
@@ -445,9 +445,11 @@ impl<'a> PolicyBatch<'a> {
peer_endpoint,
pingable_hosts,
allow_lan,
+ use_fwmark,
} => {
self.add_allow_icmp_pingable_hosts(&pingable_hosts);
- self.add_allow_endpoint_rules(peer_endpoint);
+ self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark);
+
// Important to block DNS after allow relay rule (so the relay can operate
// over port 53) but before allow LAN (so DNS does not leak to the LAN)
self.add_drop_dns_rule();
@@ -457,8 +459,9 @@ impl<'a> PolicyBatch<'a> {
peer_endpoint,
tunnel,
allow_lan,
+ use_fwmark,
} => {
- self.add_allow_endpoint_rules(peer_endpoint);
+ self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark);
self.add_allow_dns_rules(tunnel, TransportProtocol::Udp)?;
self.add_allow_dns_rules(tunnel, TransportProtocol::Tcp)?;
// Important to block DNS *before* we allow the tunnel and allow LAN. So DNS
@@ -492,7 +495,7 @@ impl<'a> PolicyBatch<'a> {
Ok(())
}
- fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) {
+ fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) {
let mut in_rule = Rule::new(&self.in_chain);
check_endpoint(&mut in_rule, End::Src, endpoint);
@@ -504,11 +507,15 @@ impl<'a> PolicyBatch<'a> {
self.batch.add(&in_rule, nftnl::MsgType::Add);
-
let mut out_rule = Rule::new(&self.out_chain);
check_endpoint(&mut out_rule, End::Dst, endpoint);
- out_rule.add_expr(&nft_expr!(meta skuid));
- out_rule.add_expr(&nft_expr!(cmp == 0u32));
+ if use_fwmark {
+ out_rule.add_expr(&nft_expr!(meta mark));
+ out_rule.add_expr(&nft_expr!(cmp == crate::linux::TUNNEL_FW_MARK));
+ } else {
+ out_rule.add_expr(&nft_expr!(meta skuid));
+ out_rule.add_expr(&nft_expr!(cmp == 0u32));
+ }
add_verdict(&mut out_rule, &Verdict::Accept);
self.batch.add(&out_rule, nftnl::MsgType::Add);
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index 6b65dd5b70..4d8d3f0459 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -97,6 +97,10 @@ pub enum FirewallPolicy {
/// A process that is allowed to send packets to the relay.
#[cfg(windows)]
relay_client: PathBuf,
+ /// Whether rule for allowing traffic to endpoint should match a firewall mark or match on
+ /// root UID.
+ #[cfg(target_os = "linux")]
+ use_fwmark: bool,
},
/// Allow traffic only to server and over tunnel interface
@@ -110,6 +114,10 @@ pub enum FirewallPolicy {
/// A process that is allowed to send packets to the relay.
#[cfg(windows)]
relay_client: PathBuf,
+ /// Whether rule for allowing traffic to endpoint should match a firewall mark or match on
+ /// root UID.
+ #[cfg(target_os = "linux")]
+ use_fwmark: bool,
},
/// Block all network traffic in and out from the computer.
diff --git a/talpid-core/src/linux.rs b/talpid-core/src/linux.rs
index 05655bf4be..47d0714813 100644
--- a/talpid-core/src/linux.rs
+++ b/talpid-core/src/linux.rs
@@ -25,3 +25,6 @@ pub enum IfaceIndexLookupError {
#[error(display = "Failed to get index for interface {}", _0)]
InterfaceLookupError(String, #[error(source)] io::Error),
}
+
+// b"mole" is [ 0x6d, 0x6f 0x6c, 0x65 ]
+pub const TUNNEL_FW_MARK: u32 = 0x6d6f6c65;
diff --git a/talpid-core/src/process/openvpn.rs b/talpid-core/src/process/openvpn.rs
index e07be4fa16..7922ba952c 100644
--- a/talpid-core/src/process/openvpn.rs
+++ b/talpid-core/src/process/openvpn.rs
@@ -249,6 +249,13 @@ impl OpenVpnCommand {
args.extend(Self::tls_cipher_arguments().iter().map(OsString::from));
args.extend(self.proxy_arguments().iter().map(OsString::from));
+ #[cfg(target_os = "linux")]
+ args.extend(
+ ["--mark", &crate::linux::TUNNEL_FW_MARK.to_string()]
+ .iter()
+ .map(OsString::from),
+ );
+
args
}
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index 566e1c027b..c7be9c6bff 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -52,7 +52,7 @@ pub enum Error {
BindError(#[error(source)] io::Error),
#[error(display = "Netlink error")]
- NetlinkError(#[error(source)] failure::Compat<rtnetlink::Error>),
+ NetlinkError(#[error(source)] rtnetlink::Error),
#[error(display = "Route without a valid node")]
InvalidRoute,
@@ -379,7 +379,6 @@ impl RouteManagerImpl {
while let Some(route) = route_request
.try_next()
.await
- .map_err(failure::Fail::compat)
.map_err(Error::NetlinkError)?
{
if route.header.destination_prefix_length == 0 {
@@ -394,12 +393,7 @@ impl RouteManagerImpl {
async fn initialize_link_map(handle: &rtnetlink::Handle) -> Result<BTreeMap<u32, String>> {
let mut link_map = BTreeMap::new();
let mut link_request = handle.link().get().execute();
- while let Some(link) = link_request
- .try_next()
- .await
- .map_err(failure::Fail::compat)
- .map_err(Error::NetlinkError)?
- {
+ while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? {
if let Some((idx, link_name)) = Self::map_iface_name_to_idx(link) {
link_map.insert(idx, link_name);
}
@@ -536,7 +530,7 @@ impl RouteManagerImpl {
Route::new(best_node, required_route.destination).table(required_route.table_id);
if let Err(e) = self.delete_route(&route).await {
if let Error::NetlinkError(err) = &e {
- if let rtnetlink::ErrorKind::NetlinkError(msg) = err.get_ref().kind() {
+ if let rtnetlink::Error::NetlinkError(msg) = err {
// -3 means that the route doesn't exist anymore anyway
if msg.code == -3 {
continue;
@@ -551,7 +545,7 @@ impl RouteManagerImpl {
for route in self.added_routes.drain().collect::<Vec<_>>().iter() {
if let Err(e) = self.delete_route(&route).await {
if let Error::NetlinkError(err) = &e {
- if let rtnetlink::ErrorKind::NetlinkError(msg) = err.get_ref().kind() {
+ if let rtnetlink::Error::NetlinkError(msg) = err {
// -3 means that the route doesn't exist anymore anyway
if msg.code == -3 {
continue;
@@ -785,7 +779,6 @@ impl RouteManagerImpl {
.del(route_message)
.execute()
.await
- .map_err(failure::Fail::compat)
.map_err(Error::NetlinkError)
}
@@ -849,17 +842,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
- let mut response = self
- .handle
- .request(req)
- .map_err(failure::Fail::compat)
- .map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(err) = message.payload {
- let compat_err =
- failure::Fail::compat(rtnetlink::ErrorKind::NetlinkError(err).into());
- return Err(Error::NetlinkError(compat_err));
+ return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err)));
}
}
self.added_routes.insert(route.clone());
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index a59ee6ed27..503e69bc35 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -194,6 +194,12 @@ impl RouteManager {
}
}
+ /// Exposes runtime handle
+ #[cfg(target_os = "linux")]
+ pub fn runtime_handle(&self) -> tokio::runtime::Handle {
+ self.runtime.handle().clone()
+ }
+
/// Route DNS requests through the tunnel interface.
#[cfg(target_os = "linux")]
pub fn route_exclusions_dns(
diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs
index 4f0e851b1e..87c0d2c0f0 100644
--- a/talpid-core/src/tunnel/wireguard/config.rs
+++ b/talpid-core/src/tunnel/wireguard/config.rs
@@ -17,6 +17,9 @@ pub struct Config {
pub ipv6_gateway: Option<Ipv6Addr>,
/// Maximum transmission unit for the tunnel
pub mtu: u16,
+ /// Firewall mark
+ #[cfg(target_os = "linux")]
+ pub fwmark: u32,
}
const DEFAULT_MTU: u16 = 1380;
@@ -96,6 +99,8 @@ impl Config {
ipv4_gateway: connection_config.ipv4_gateway,
ipv6_gateway,
mtu,
+ #[cfg(target_os = "linux")]
+ fwmark: crate::linux::TUNNEL_FW_MARK,
})
}
@@ -108,6 +113,9 @@ impl Config {
.add("private_key", self.tunnel.private_key.to_bytes().as_ref())
.add("listen_port", "0");
+ #[cfg(target_os = "linux")]
+ wg_conf.add("fwmark", self.fwmark.to_string().as_str());
+
wg_conf.add("replace_peers", "true");
for peer in &self.peers {
diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
index 111edd772a..8a604fcf6f 100644
--- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs
+++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
@@ -514,8 +514,8 @@ mod test {
}
impl Tunnel for MockTunnel {
- fn get_interface_name(&self) -> &str {
- "mock-tunnel"
+ fn get_interface_name(&self) -> String {
+ "mock-tunnel".to_string()
}
fn stop(self: Box<Self>) -> Result<(), TunnelError> {
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 6324d13e80..987274acb5 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -16,6 +16,8 @@ mod connectivity_check;
mod logging;
mod stats;
mod wireguard_go;
+#[cfg(target_os = "linux")]
+mod wireguard_kernel;
use self::wireguard_go::WgGoTunnel;
@@ -62,12 +64,7 @@ impl WireguardMonitor {
tun_provider: &mut TunProvider,
route_manager: &mut routing::RouteManager,
) -> Result<WireguardMonitor> {
- let tunnel = Box::new(WgGoTunnel::start_tunnel(
- &config,
- log_path,
- tun_provider,
- Self::get_tunnel_routes(config),
- )?);
+ let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?;
let iface_name = tunnel.get_interface_name().to_string();
route_manager
.add_routes(Self::get_routes(&iface_name, &config))
@@ -93,7 +90,7 @@ impl WireguardMonitor {
let close_sender = monitor.close_msg_sender.clone();
let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
gateway,
- iface_name,
+ iface_name.to_string(),
Arc::downgrade(&monitor.tunnel),
pinger_rx,
)?;
@@ -125,6 +122,34 @@ impl WireguardMonitor {
Ok(monitor)
}
+ #[cfg_attr(not(target_os = "linux"), allow(unused_variables))]
+ fn open_tunnel(
+ config: &Config,
+ log_path: Option<&Path>,
+ tun_provider: &mut TunProvider,
+ route_manager: &mut routing::RouteManager,
+ ) -> Result<Box<dyn Tunnel>> {
+ #[cfg(target_os = "linux")]
+ match wireguard_kernel::KernelTunnel::new(route_manager.runtime_handle(), config) {
+ Ok(tunnel) => {
+ return Ok(Box::new(tunnel));
+ }
+ Err(err) => {
+ log::error!(
+ "Failed to setup kernel WireGuard device, falling back to userspace: {}",
+ err
+ );
+ }
+ };
+
+ Ok(Box::new(WgGoTunnel::start_tunnel(
+ &config,
+ log_path,
+ tun_provider,
+ Self::get_tunnel_routes(config),
+ )?))
+ }
+
/// Returns a close handle for the tunnel
pub fn close_handle(&self) -> CloseHandle {
CloseHandle {
@@ -228,7 +253,7 @@ impl CloseHandle {
}
pub(crate) trait Tunnel: Send {
- fn get_interface_name(&self) -> &str;
+ fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
fn get_tunnel_stats(&self) -> std::result::Result<stats::Stats, TunnelError>;
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 761276dcdf..b4d437f046 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -283,8 +283,8 @@ impl Drop for WgGoTunnel {
}
impl Tunnel for WgGoTunnel {
- fn get_interface_name(&self) -> &str {
- &self.interface_name
+ fn get_interface_name(&self) -> String {
+ self.interface_name.clone()
}
fn get_tunnel_stats(&self) -> Result<Stats> {
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
new file mode 100644
index 0000000000..b1fa3a0f65
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
@@ -0,0 +1,479 @@
+use super::{stats::Stats, Config, Tunnel, TunnelError};
+use futures::future::{abortable, AbortHandle};
+use netlink_packet_core::{constants::*, NetlinkDeserializable};
+use netlink_packet_route::{
+ rtnl::{
+ address::nlas::Nla as AddressNla,
+ link::nlas::{Info, InfoKind, Nla as LinkNla},
+ AddressMessage, LinkMessage, RtnlMessage, RT_SCOPE_UNIVERSE,
+ },
+ NetlinkMessage, NetlinkPayload,
+};
+use netlink_packet_utils::DecodeError;
+use netlink_proto::{
+ sys::{Protocol, SocketAddr},
+ ConnectionHandle, Error as NetlinkError,
+};
+use std::{ffi::CString, net::IpAddr};
+use tokio::stream::StreamExt;
+
+mod parsers;
+
+mod wg_message;
+use wg_message::{DeviceMessage, DeviceNla, PeerNla};
+mod nl_message;
+use nl_message::{ControlNla, NetlinkControlMessage};
+
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ #[error(display = "Failed to decode netlink message")]
+ DecodeError(#[error(source)] DecodeError),
+
+ #[error(display = "Failed to execute netlink control request")]
+ NetlinkControlMessageError(#[error(source)] nl_message::Error),
+
+ #[error(display = "Failed to open netlink socket")]
+ NetlinkSocketError(#[error(source)] std::io::Error),
+
+ #[error(display = "Failed to send netlink control request")]
+ NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
+
+ #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")]
+ WireguardNetlinkInterfaceUnavailable,
+
+ #[error(display = "Unknown WireGuard command _0")]
+ UnnkownWireguardCommmand(u8),
+
+ #[error(display = "Received no response")]
+ NoResponse,
+
+ #[error(display = "Received truncated message")]
+ Truncated,
+
+ #[error(display = "WireGuard device does not exist")]
+ NoDevice,
+
+ #[error(display = "Failed to get config: _0")]
+ WgGetConfError(netlink_packet_core::error::ErrorMessage),
+
+ #[error(display = "Failed to apply config: _0")]
+ WgSetConfError(netlink_packet_core::error::ErrorMessage),
+
+ #[error(display = "Interface name too long")]
+ InterfaceNameError,
+
+ #[error(display = "Send request error")]
+ SendRequestError(#[error(source)] NetlinkError<DeviceMessage>),
+
+ #[error(display = "Create device error")]
+ NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error),
+
+ #[error(display = "Add IP to device error")]
+ NetlinkSetIpError(rtnetlink::Error),
+
+ #[error(display = "Failed to delete device")]
+ DeleteDeviceError(#[error(source)] rtnetlink::Error),
+}
+
+pub struct KernelTunnel {
+ interface_index: u32,
+ netlink_connections: Handle,
+ tokio_handle: tokio::runtime::Handle,
+}
+
+const MULLVAD_INTERFACE_NAME: &str = "wg-mullvad";
+
+impl KernelTunnel {
+ pub fn new(tokio_handle: tokio::runtime::Handle, config: &Config) -> Result<Self, Error> {
+ tokio_handle.clone().block_on(async {
+ let mut netlink_connections = Handle::connect().await?;
+ let interface_index = netlink_connections
+ .create_device(MULLVAD_INTERFACE_NAME.to_string(), config.mtu as u32)
+ .await?;
+
+ let mut tunnel = Self {
+ interface_index,
+ netlink_connections,
+ tokio_handle,
+ };
+
+ if let Err(err) = tunnel.setup(config).await {
+ if let Err(teardown_err) = tunnel
+ .netlink_connections
+ .delete_device(interface_index)
+ .await
+ {
+ log::error!(
+ "Failed to tear down WireGuard interface after failing to apply config: {}",
+ teardown_err
+ );
+ }
+ return Err(err);
+ }
+
+
+ Ok(tunnel)
+ })
+ }
+
+ async fn setup(&mut self, config: &Config) -> Result<(), Error> {
+ self.netlink_connections
+ .wg_handle
+ .set_config(self.interface_index, config)
+ .await?;
+
+ for tunnel_ip in config.tunnel.addresses.iter() {
+ self.netlink_connections
+ .set_ip_address(self.interface_index, *tunnel_ip)
+ .await?;
+ }
+
+ Ok(())
+ }
+}
+
+impl Tunnel for KernelTunnel {
+ fn get_interface_name(&self) -> String {
+ let mut wg = self.netlink_connections.wg_handle.clone();
+ let result = self.tokio_handle.block_on(async move {
+ let device = wg.get_by_index(self.interface_index).await?;
+ for nla in device.nlas {
+ if let DeviceNla::IfName(name) = nla {
+ return Ok(name);
+ }
+ }
+ return Err(Error::Truncated);
+ });
+
+ match result {
+ Ok(name) => name.to_string_lossy().to_string(),
+ Err(err) => {
+ log::error!("Failed to deduce interface name at runtime, will attempt to use the default name. {}", err);
+ MULLVAD_INTERFACE_NAME.to_string()
+ }
+ }
+ }
+
+ fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError> {
+ let Self {
+ mut netlink_connections,
+ interface_index,
+ tokio_handle,
+ } = *self;
+ tokio_handle.block_on(async move {
+ if let Err(err) = netlink_connections.delete_device(interface_index).await {
+ log::error!("Failed to remove WireGuard device - {}", err);
+ Err(TunnelError::FatalStartWireguardError)
+ } else {
+ Ok(())
+ }
+ })
+ }
+
+ fn get_tunnel_stats(&self) -> std::result::Result<Stats, TunnelError> {
+ let mut wg = self.netlink_connections.wg_handle.clone();
+ let interface_index = self.interface_index;
+ let result = self.tokio_handle.block_on(async move {
+ let device = wg.get_by_index(interface_index).await.map_err(|err| {
+ log::error!("Failed to fetch WireGuard device config: {}", err);
+ TunnelError::GetConfigError
+ })?;
+
+ // iterate over device attributes
+ let mut tx_bytes = 0;
+ let mut rx_bytes = 0;
+ for nla in device.nlas {
+ if let DeviceNla::Peers(peers) = nla {
+ // iterate over all peer attributes
+ let peer_iter = peers.iter().map(|peer| peer.0.as_slice()).flatten();
+
+ for peer_nla in peer_iter {
+ match peer_nla {
+ PeerNla::TxBytes(bytes) => tx_bytes += *bytes,
+ PeerNla::RxBytes(bytes) => rx_bytes += *bytes,
+ _ => continue,
+ };
+ }
+ }
+ }
+
+ Ok(Stats { tx_bytes, rx_bytes })
+ });
+
+ result
+ }
+}
+
+
+#[derive(Debug)]
+pub struct Handle {
+ wg_handle: WireguardConnection,
+ route_handle: rtnetlink::Handle,
+ wg_abort_handle: AbortHandle,
+ route_abort_handle: AbortHandle,
+ message_type: u16,
+}
+
+
+impl Handle {
+ pub async fn connect() -> Result<Self, Error> {
+ let message_type = Self::get_wireguard_message_type().await?;
+ let (conn, wireguard_connection, _messages) =
+ netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?;
+ let wg_handle = WireguardConnection {
+ message_type,
+ connection: wireguard_connection,
+ };
+ let (abortable_connection, wg_abort_handle) = abortable(conn);
+ tokio::spawn(abortable_connection);
+ let (conn, route_handle, _messages) =
+ rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?;
+ let (abortable_connection, route_abort_handle) = abortable(conn);
+ tokio::spawn(abortable_connection);
+
+
+ Ok(Self {
+ wg_handle,
+ route_handle,
+ message_type,
+ wg_abort_handle,
+ route_abort_handle,
+ })
+ }
+
+ async fn get_wireguard_message_type() -> Result<u16, Error> {
+ let (conn, mut handle, _messages) =
+ netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?;
+ let (conn, abort_handle) = abortable(conn);
+ tokio::spawn(conn);
+
+ let result = async move {
+ let mut message: NetlinkMessage<NetlinkControlMessage> =
+ NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap())
+ .map_err(Error::NetlinkControlMessageError)?
+ .into();
+
+ message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
+
+ let mut req = handle
+ .request(message, SocketAddr::new(0, 0))
+ .map_err(Error::NetlinkRequestError)?;
+ let response = req.next().await;
+ if let Some(response) = response {
+ if let NetlinkPayload::InnerMessage(msg) = response.payload {
+ for nla in msg.nlas.into_iter() {
+ if let ControlNla::FamilyId(id) = nla {
+ return Ok(id);
+ }
+ }
+ }
+ }
+ Err(Error::WireguardNetlinkInterfaceUnavailable)
+ }
+ .await;
+
+ abort_handle.abort();
+ result
+ }
+
+ // create a wireguard device with the given name.
+ pub async fn create_device(&mut self, name: String, mtu: u32) -> Result<u32, Error> {
+ let mut message = LinkMessage::default();
+
+ // set link to be up
+ message.header.flags = netlink_packet_route::IFF_UP;
+ // message.header.change_mask = netlink_packet_route::IFF_UP;
+ // set link name
+ message.nlas.push(LinkNla::IfName(name.clone()));
+ // set link MTU
+ message.nlas.push(LinkNla::Mtu(mtu));
+ // set link type
+ message
+ .nlas
+ .push(LinkNla::Info(vec![Info::Kind(InfoKind::Other(
+ "wireguard".to_string(),
+ ))]));
+
+ let mut add_request = NetlinkMessage::from(RtnlMessage::NewLink(message));
+ add_request.header.flags =
+ NLM_F_REQUEST | NLM_F_ACK | NLM_F_REPLACE | NLM_F_CREATE | NLM_F_MATCH;
+ let mut response = self
+ .route_handle
+ .request(add_request)
+ .map_err(Error::NetlinkCreateDeviceError)?;
+ while let Some(response_message) = response.next().await {
+ if let NetlinkPayload::Error(err) = response_message.payload {
+ // if the device exists, verify that it's a wireguard device
+ if -err.code != libc::EEXIST {
+ return Err(Error::NetlinkCreateDeviceError(
+ rtnetlink::Error::NetlinkError(err),
+ ));
+ }
+ }
+ }
+
+ // fetch interface index of new device
+ let new_device = self.wg_handle.get_by_name(name).await?;
+ for nla in new_device.nlas {
+ if let DeviceNla::IfIndex(index) = nla {
+ return Ok(index);
+ }
+ }
+
+
+ Err(Error::NoDevice)
+ }
+
+ pub async fn set_ip_address(&mut self, index: u32, addr: IpAddr) -> Result<(), Error> {
+ let address_message = add_ip_addr_message(index, addr);
+ let mut request = NetlinkMessage::from(RtnlMessage::NewAddress(address_message));
+ request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
+
+
+ let mut response = self
+ .route_handle
+ .request(request)
+ .map_err(Error::NetlinkSetIpError)?;
+ while let Some(response_message) = response.next().await {
+ consume_netlink_error(response_message, Error::NetlinkSetIpError)?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn delete_device(&mut self, index: u32) -> Result<(), Error> {
+ let mut link_message = LinkMessage::default();
+ link_message.header.index = index;
+
+ let mut request = NetlinkMessage::from(RtnlMessage::DelLink(link_message));
+ request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+
+ let mut response = self
+ .route_handle
+ .request(request)
+ .map_err(Error::DeleteDeviceError)?;
+ while let Some(message) = response.next().await {
+ consume_netlink_error(message, Error::DeleteDeviceError)?;
+ }
+
+ Ok(())
+ }
+}
+
+impl Drop for Handle {
+ fn drop(&mut self) {
+ self.wg_abort_handle.abort();
+ self.route_abort_handle.abort();
+ }
+}
+
+#[derive(Debug, Clone)]
+struct WireguardConnection {
+ connection: ConnectionHandle<DeviceMessage>,
+ message_type: u16,
+}
+
+impl WireguardConnection {
+ pub async fn get_by_name(&mut self, name: String) -> Result<DeviceMessage, Error> {
+ self.fetch_device(DeviceMessage::get_by_name(self.message_type, name)?)
+ .await
+ }
+
+ pub async fn get_by_index(&mut self, index: u32) -> Result<DeviceMessage, Error> {
+ self.fetch_device(DeviceMessage::get_by_index(self.message_type, index))
+ .await
+ }
+
+ pub async fn fetch_device(
+ &mut self,
+ device_message: DeviceMessage,
+ ) -> Result<DeviceMessage, Error> {
+ let mut netlink_message = NetlinkMessage::from(device_message);
+ netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
+
+ let mut response = self
+ .connection
+ .request(netlink_message, SocketAddr::new(0, 0))
+ .map_err(Error::SendRequestError)?;
+ match response.next().await {
+ Some(received_message) => match received_message.payload {
+ NetlinkPayload::InnerMessage(inner) => Ok(inner),
+ NetlinkPayload::Error(err) => {
+ if err.code == -libc::ENODEV {
+ Err(Error::NoDevice)
+ } else {
+ Err(Error::WgGetConfError(err))
+ }
+ }
+ anything_else => {
+ log::error!("Received unexpected response - {:?}", anything_else);
+ Err(Error::NoResponse)
+ }
+ },
+ None => Err(Error::NoResponse),
+ }
+ }
+
+ pub async fn set_config(&mut self, interface_index: u32, config: &Config) -> Result<(), Error> {
+ let message = DeviceMessage::reset_config(self.message_type, interface_index, config);
+ let mut netlink_message = NetlinkMessage::from(message);
+ netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
+
+ let mut request = self
+ .connection
+ .request(netlink_message, SocketAddr::new(0, 0))
+ .map_err(Error::SendRequestError)?;
+
+ while let Some(response) = request.next().await {
+ if let NetlinkPayload::Error(err) = response.payload {
+ return Err(Error::WgSetConfError(err));
+ }
+ }
+ Ok(())
+ }
+}
+
+
+fn consume_netlink_error<
+ T,
+ I: NetlinkDeserializable<T> + Clone + Eq + std::fmt::Debug,
+ F: Fn(rtnetlink::Error) -> Error,
+>(
+ message: NetlinkMessage<I>,
+ err_constructor: F,
+) -> Result<(), Error> {
+ if let NetlinkPayload::Error(err) = message.payload {
+ return Err(err_constructor(rtnetlink::Error::NetlinkError(err)));
+ }
+ Ok(())
+}
+
+// the built-in support for adding addresses is too helpful, so a simple AddressMessage with a
+// single Address nla is created
+fn add_ip_addr_message(if_index: u32, addr: IpAddr) -> AddressMessage {
+ let prefix_len = if addr.is_ipv4() { 32 } else { 128 };
+ let mut message = AddressMessage::default();
+ message.header.prefix_len = prefix_len;
+ message.header.index = if_index;
+ message.header.scope = RT_SCOPE_UNIVERSE;
+
+ match addr {
+ IpAddr::V4(ipv4) => {
+ message.header.family = libc::AF_INET as u8;
+ let ip_bytes = ipv4.octets().to_vec();
+
+ message.nlas.push(AddressNla::Address(ip_bytes.clone()));
+ message.nlas.push(AddressNla::Local(ip_bytes));
+ }
+ IpAddr::V6(ipv6) => {
+ message.header.family = libc::AF_INET6 as u8;
+ message
+ .nlas
+ .push(AddressNla::Address(ipv6.octets().to_vec()));
+ }
+ };
+
+ message
+}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs
new file mode 100644
index 0000000000..7fc8be8304
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs
@@ -0,0 +1,135 @@
+use super::parsers;
+use byteorder::{ByteOrder, NativeEndian};
+use netlink_packet_core::{
+ NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable,
+};
+use netlink_packet_utils::{
+ nla::{Nla, NlaBuffer, NlasIterator},
+ traits::{Emitable, Parseable},
+ DecodeError,
+};
+use std::{ffi::CString, io::Write, mem};
+
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "Family name too long")]
+ FamilyNameTooLong,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct NetlinkControlMessage {
+ cmd: u8,
+ version: u8,
+ pub nlas: Vec<ControlNla>,
+}
+
+impl NetlinkControlMessage {
+ pub fn get_netlink_family_id(name: CString) -> Result<Self, Error> {
+ if name.as_bytes_with_nul().len() > (libc::GENL_NAMSIZ as usize) {
+ return Err(Error::FamilyNameTooLong);
+ }
+ Ok(Self {
+ nlas: vec![ControlNla::FamilyName(name)],
+ cmd: libc::CTRL_CMD_GETFAMILY as u8,
+ version: 1,
+ })
+ }
+}
+
+
+impl NetlinkSerializable<NetlinkControlMessage> for NetlinkControlMessage {
+ fn message_type(&self) -> u16 {
+ libc::GENL_ID_CTRL as u16
+ }
+
+ fn buffer_len(&self) -> usize {
+ mem::size_of::<libc::genlmsghdr>() + self.nlas.as_slice().buffer_len()
+ }
+
+ fn serialize(&self, mut buffer: &mut [u8]) {
+ let _ = buffer.write(&[self.cmd, self.version, 0u8, 0u8]).unwrap();
+ self.nlas.as_slice().emit(&mut buffer);
+ }
+}
+
+impl Into<NetlinkPayload<NetlinkControlMessage>> for NetlinkControlMessage {
+ fn into(self) -> NetlinkPayload<NetlinkControlMessage> {
+ NetlinkPayload::InnerMessage(self)
+ }
+}
+
+impl NetlinkDeserializable<NetlinkControlMessage> for NetlinkControlMessage {
+ type Error = DecodeError;
+ fn deserialize(
+ _header: &NetlinkHeader,
+ payload: &[u8],
+ ) -> Result<NetlinkControlMessage, Self::Error> {
+ // skip the genlmsghdr
+ let (cmd, version) = parsers::parse_genlmsghdr(payload)?;
+ let nla_buffer = &payload[mem::size_of::<libc::genlmsghdr>()..];
+ let nlas = NlasIterator::new(nla_buffer)
+ .map(|buffer| ControlNla::parse(&buffer?))
+ .collect::<Result<Vec<_>, DecodeError>>()?;
+
+ Ok(NetlinkControlMessage { nlas, cmd, version })
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum ControlNla {
+ FamilyName(CString),
+ FamilyId(u16),
+ Unknown(u16, Vec<u8>),
+}
+
+impl Nla for ControlNla {
+ fn value_len(&self) -> usize {
+ use ControlNla::*;
+ match self {
+ FamilyName(name) => name.as_bytes_with_nul().len(),
+ FamilyId(_id) => 2,
+ Unknown(_, buffer) => buffer.len(),
+ }
+ }
+
+ fn kind(&self) -> u16 {
+ use ControlNla::*;
+ match self {
+ FamilyName(_) => libc::CTRL_ATTR_FAMILY_NAME as u16,
+ FamilyId(_) => libc::CTRL_ATTR_FAMILY_ID as u16,
+ Unknown(kind, _) => *kind,
+ }
+ }
+
+ fn emit_value(&self, mut buffer: &mut [u8]) {
+ use ControlNla::*;
+ match self {
+ FamilyName(name) => {
+ let _ = buffer.write(name.as_bytes()).unwrap();
+ }
+ FamilyId(id) => {
+ NativeEndian::write_u16(buffer, *id);
+ }
+
+ Unknown(_, value) => {
+ let _ = buffer.write(value).unwrap();
+ }
+ }
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized + std::fmt::Debug> Parseable<NlaBuffer<&'a T>>
+ for ControlNla
+{
+ fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ let nla = match buf.kind() as i32 {
+ libc::CTRL_ATTR_FAMILY_NAME => {
+ ControlNla::FamilyName(parsers::parse_cstring(buf.value())?)
+ }
+ libc::CTRL_ATTR_FAMILY_ID => ControlNla::FamilyId(parsers::parse_u16(buf.value())?),
+ _unknown_kind => ControlNla::Unknown(buf.kind(), buf.value().to_vec()),
+ };
+ Ok(nla)
+ }
+}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs
new file mode 100644
index 0000000000..b34c82d342
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs
@@ -0,0 +1,99 @@
+use byteorder::{ByteOrder, NativeEndian};
+use nix::sys::{socket::InetAddr, time::TimeSpec};
+use std::{
+ ffi::{CStr, CString},
+ mem,
+ net::IpAddr,
+};
+
+pub use netlink_packet_utils::parsers::*;
+use netlink_packet_utils::DecodeError;
+
+pub fn parse_ip_addr(bytes: &[u8]) -> Result<IpAddr, DecodeError> {
+ if bytes.len() == 4 {
+ let mut ipv4_bytes = [0u8; 4];
+ ipv4_bytes.copy_from_slice(bytes);
+ Ok(IpAddr::from(ipv4_bytes))
+ } else if bytes.len() == 16 {
+ let mut ipv6_bytes = [0u8; 16];
+ ipv6_bytes.copy_from_slice(bytes);
+ Ok(IpAddr::from(ipv6_bytes))
+ } else {
+ log::error!("Expected either 4 or 16 bytes, got {} bytes", bytes.len());
+ Err(format!("Invalid bytes for IP address: {:?}", bytes).into())
+ }
+}
+
+pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> {
+ match buffer.len() {
+ 32 => {
+ let mut key = [0u8; 32];
+ key.clone_from_slice(buffer);
+ Ok(key)
+ }
+ anything_else => Err(format!("Unexpected length of key: {}", anything_else).into()),
+ }
+}
+
+pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> {
+ if buffer.len() != mem::size_of::<libc::sockaddr_in6>()
+ && buffer.len() != mem::size_of::<libc::sockaddr_in>()
+ {
+ return Err(format!(
+ "Unexpected length for sockaddr_in: {}, expected {} or {}",
+ buffer.len(),
+ mem::size_of::<libc::sockaddr_in6>(),
+ mem::size_of::<libc::sockaddr_in>()
+ )
+ .into());
+ }
+ let ptr = buffer.as_ptr();
+ const AF_INET: u16 = libc::AF_INET as u16;
+ const AF_INET6: u16 = libc::AF_INET6 as u16;
+
+ match NativeEndian::read_u16(buffer) {
+ AF_INET => unsafe {
+ let sockaddr: *const libc::sockaddr_in = ptr as *const _;
+ Ok(InetAddr::V4(*sockaddr).into())
+ },
+ AF_INET6 => unsafe {
+ let sockaddr: *const libc::sockaddr_in6 = ptr as *const _;
+ Ok(InetAddr::V6(*sockaddr))
+ },
+ unexpected_addr_family => {
+ Err(format!("Unexpected address family: {}", unexpected_addr_family).into())
+ }
+ }
+}
+
+pub fn parse_timespec(buffer: &[u8]) -> Result<TimeSpec, DecodeError> {
+ if buffer.len() != mem::size_of::<libc::timespec>() {
+ return Err(format!("Unexpected size for timespec: {}", buffer.len()).into());
+ }
+
+ Ok(TimeSpec::from(libc::timespec {
+ tv_sec: NativeEndian::read_i64(buffer),
+ // TODO: become compatible with 32-bit systems maybe?
+ tv_nsec: NativeEndian::read_i64(buffer),
+ }))
+}
+
+pub fn parse_cstring(buffer: &[u8]) -> Result<CString, DecodeError> {
+ Ok(CStr::from_bytes_with_nul(buffer)
+ .map_err(|err| format!("{}", err))?
+ .into())
+}
+
+pub fn parse_genlmsghdr(buffer: &[u8]) -> Result<(u8, u8), DecodeError> {
+ const GENLMSGHDR_SIZE: usize = mem::size_of::<libc::genlmsghdr>();
+ if buffer.len() < GENLMSGHDR_SIZE {
+ return Err(format!(
+ "Expected at least {}, got {}",
+ GENLMSGHDR_SIZE,
+ buffer.len()
+ )
+ .into());
+ }
+
+ Ok((buffer[0], buffer[1]))
+}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
new file mode 100644
index 0000000000..1e587712cc
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
@@ -0,0 +1,899 @@
+use super::{super::config::Config, parsers, Error};
+use byteorder::{ByteOrder, NativeEndian};
+use ipnetwork::IpNetwork;
+use netlink_packet_core::{
+ NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable,
+};
+use netlink_packet_utils::{
+ nla::{Nla, NlaBuffer, NlasIterator, NLA_F_NESTED},
+ traits::{Emitable, Parseable},
+ DecodeError,
+};
+use nix::sys::{socket::InetAddr, time::TimeSpec};
+use std::{ffi::CString, io::Write, mem, net::IpAddr};
+
+/// WireGuard netlink constants
+mod constants {
+ #![allow(dead_code)]
+ pub const WG_GENL_VERSION: u8 = 1;
+
+ /// Command constants
+ pub const WG_CMD_GET_DEVICE: u8 = 0;
+ pub const WG_CMD_SET_DEVICE: u8 = 1;
+
+ // wgdevice_flag
+ pub const WGDEVICE_F_REPLACE_PEERS: u32 = 1 << 0;
+
+ // wgdevice_attribute
+ pub const WGDEVICE_A_UNSPEC: u16 = 0;
+ pub const WGDEVICE_A_IFINDEX: u16 = 1;
+ pub const WGDEVICE_A_IFNAME: u16 = 2;
+ pub const WGDEVICE_A_PRIVATE_KEY: u16 = 3;
+ pub const WGDEVICE_A_PUBLIC_KEY: u16 = 4;
+ pub const WGDEVICE_A_FLAGS: u16 = 5;
+ pub const WGDEVICE_A_LISTEN_PORT: u16 = 6;
+ pub const WGDEVICE_A_FWMARK: u16 = 7;
+ pub const WGDEVICE_A_PEERS: u16 = 8;
+
+ // wgpeer_flag
+ pub const WGPEER_F_REMOVE_ME: u32 = 1 << 0;
+ pub const WGPEER_F_REPLACE_ALLOWEDIPS: u32 = 1 << 1;
+ pub const WGPEER_F_UPDATE_ONLY: u32 = 1 << 2;
+
+ // wgpeer_attribute
+ pub const WGPEER_A_UNSPEC: u16 = 0;
+ pub const WGPEER_A_PUBLIC_KEY: u16 = 1;
+ pub const WGPEER_A_PRESHARED_KEY: u16 = 2;
+ pub const WGPEER_A_FLAGS: u16 = 3;
+ pub const WGPEER_A_ENDPOINT: u16 = 4;
+ pub const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 5;
+ pub const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 6;
+ pub const WGPEER_A_RX_BYTES: u16 = 7;
+ pub const WGPEER_A_TX_BYTES: u16 = 8;
+ pub const WGPEER_A_ALLOWEDIPS: u16 = 9;
+ pub const WGPEER_A_PROTOCOL_VERSION: u16 = 10;
+
+ // wgallowedip_attribute
+ pub const WGALLOWEDIP_A_UNSPEC: u16 = 0;
+ pub const WGALLOWEDIP_A_FAMILY: u16 = 1;
+ pub const WGALLOWEDIP_A_IPADDR: u16 = 2;
+ pub const WGALLOWEDIP_A_CIDR_MASK: u16 = 3;
+}
+
+use constants::*;
+pub use constants::{WG_CMD_GET_DEVICE, WG_CMD_SET_DEVICE};
+
+type PrivateKey = [u8; 32];
+type PublicKey = [u8; 32];
+type PresharedKey = [u8; 32];
+
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct DeviceMessage {
+ pub nlas: Vec<DeviceNla>,
+ pub message_type: u16,
+ pub command: u8,
+}
+
+impl DeviceMessage {
+ pub fn reset_config(message_type: u16, interface_index: u32, config: &Config) -> DeviceMessage {
+ let mut peers = vec![];
+
+ for peer in config.peers.iter() {
+ let peer_endpoint = InetAddr::from_std(&peer.endpoint);
+ let allowed_ips = peer.allowed_ips.iter().map(From::from).collect();
+ peers.push(PeerMessage(vec![
+ PeerNla::PublicKey(*peer.public_key.as_bytes()),
+ PeerNla::Endpoint(peer_endpoint),
+ PeerNla::AllowedIps(allowed_ips),
+ PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
+ ]));
+ }
+
+ let nlas = vec![
+ DeviceNla::IfIndex(interface_index),
+ DeviceNla::ListenPort(0),
+ DeviceNla::Fwmark(crate::linux::TUNNEL_FW_MARK),
+ DeviceNla::PrivateKey(config.tunnel.private_key.to_bytes()),
+ DeviceNla::Flags(WGDEVICE_F_REPLACE_PEERS),
+ DeviceNla::Peers(peers),
+ ];
+
+
+ Self {
+ nlas,
+ message_type,
+ command: WG_CMD_SET_DEVICE,
+ }
+ }
+
+ pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> {
+ let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?;
+ if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ {
+ return Err(Error::InterfaceNameError);
+ }
+
+ Ok(Self {
+ message_type,
+ nlas: vec![DeviceNla::IfName(c_name)],
+ command: WG_CMD_GET_DEVICE,
+ })
+ }
+
+ pub fn get_by_index(message_type: u16, index: u32) -> Self {
+ Self {
+ message_type,
+ nlas: vec![DeviceNla::IfIndex(index)],
+ command: WG_CMD_GET_DEVICE,
+ }
+ }
+
+ // All WireGuard netlink messages should start with a libc::genlmsghdr, for which the first
+ // byte contains the command.
+ fn read_genlmsghdr(buff: &[u8]) -> Result<u8, Error> {
+ if buff.len() < mem::size_of::<libc::genlmsghdr>() {
+ return Err(Error::Truncated);
+ }
+
+ let cmd = buff[0];
+ if cmd == WG_CMD_GET_DEVICE || cmd == WG_CMD_SET_DEVICE {
+ Ok(cmd)
+ } else {
+ Err(Error::UnnkownWireguardCommmand(cmd))
+ }
+ }
+}
+
+impl NetlinkSerializable<DeviceMessage> for DeviceMessage {
+ fn message_type(&self) -> u16 {
+ self.message_type
+ }
+
+ fn buffer_len(&self) -> usize {
+ // add the genlmsghdr
+ mem::size_of::<libc::genlmsghdr>() +
+ // size of all of the NLAs
+ self.nlas.as_slice().buffer_len()
+ }
+
+ fn serialize(&self, mut buffer: &mut [u8]) {
+ let command_buf = [self.command, WG_GENL_VERSION, 0u8, 0u8];
+ let _ = buffer.write(&command_buf).unwrap();
+ self.nlas.as_slice().emit(&mut buffer)
+ }
+}
+impl Into<NetlinkPayload<DeviceMessage>> for DeviceMessage {
+ fn into(self) -> NetlinkPayload<DeviceMessage> {
+ NetlinkPayload::InnerMessage(self)
+ }
+}
+
+impl NetlinkDeserializable<DeviceMessage> for DeviceMessage {
+ type Error = Error;
+ fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result<DeviceMessage, Self::Error> {
+ let command = Self::read_genlmsghdr(payload)?;
+ let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..];
+ let mut nlas = vec![];
+ for buf in NlasIterator::new(new_payload) {
+ nlas.push(
+ DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?,
+ );
+ }
+
+ Ok(DeviceMessage {
+ nlas,
+ command,
+ message_type: header.message_type,
+ })
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum DeviceNla {
+ IfIndex(u32),
+ IfName(CString),
+ Flags(u32),
+ PrivateKey(PrivateKey),
+ PublicKey(PublicKey),
+ ListenPort(u16),
+ Fwmark(u32),
+ Peers(Vec<PeerMessage>),
+ Unspec(Vec<u8>),
+}
+
+impl Nla for DeviceNla {
+ fn value_len(&self) -> usize {
+ use DeviceNla::*;
+ match self {
+ IfIndex(_) | Fwmark(_) | Flags(_) => 4,
+ IfName(name) => name.as_bytes_with_nul().len(),
+ PrivateKey(key) | PublicKey(key) => key.len(),
+ ListenPort(_) => 2,
+ Peers(peers) => peers.as_slice().buffer_len(),
+ Unspec(payload) => payload.len(),
+ }
+ }
+
+ fn kind(&self) -> u16 {
+ use DeviceNla::*;
+ match self {
+ IfIndex(_) => WGDEVICE_A_IFINDEX,
+ IfName(_) => WGDEVICE_A_IFNAME,
+ PrivateKey(_) => WGDEVICE_A_PRIVATE_KEY,
+ PublicKey(_) => WGDEVICE_A_PUBLIC_KEY,
+ Flags(_) => WGDEVICE_A_FLAGS,
+ ListenPort(_) => WGDEVICE_A_LISTEN_PORT,
+ Fwmark(_) => WGDEVICE_A_FWMARK,
+ Peers(_) => WGDEVICE_A_PEERS | NLA_F_NESTED,
+ Unspec(_) => WGDEVICE_A_UNSPEC,
+ }
+ }
+
+ fn emit_value(&self, mut buffer: &mut [u8]) {
+ use DeviceNla::*;
+ match self {
+ IfIndex(value) | Fwmark(value) | Flags(value) => {
+ NativeEndian::write_u32(buffer, *value)
+ }
+ IfName(interface_name) => {
+ let _ = buffer
+ .write(interface_name.as_bytes_with_nul())
+ .expect("Failed to write interface name");
+ }
+ PrivateKey(key) | PublicKey(key) => {
+ let _ = buffer.write(key).expect("Failed to write key");
+ }
+ ListenPort(port) => NativeEndian::write_u16(buffer, *port),
+ Peers(peers) => {
+ peers.as_slice().emit(buffer);
+ }
+ Unspec(payload) => {
+ let _ = buffer.write(&payload).expect("Failed to write ");
+ }
+ }
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized + core::fmt::Debug> Parseable<NlaBuffer<&'a T>>
+ for DeviceNla
+{
+ fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ use DeviceNla::*;
+ let value = buf.value();
+ let kind = buf.kind();
+ let nla = match kind {
+ WGDEVICE_A_IFINDEX => IfIndex(parsers::parse_u32(value)?),
+ WGDEVICE_A_IFNAME => IfName(parsers::parse_cstring(value)?),
+ WGDEVICE_A_PRIVATE_KEY => PrivateKey(parsers::parse_wg_key(value)?.into()),
+ WGDEVICE_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()),
+ WGDEVICE_A_FLAGS => Flags(parsers::parse_u32(value)?),
+ WGDEVICE_A_LISTEN_PORT => ListenPort(parsers::parse_u16(value)?),
+ WGDEVICE_A_FWMARK => Fwmark(parsers::parse_u32(value)?),
+ WGDEVICE_A_PEERS => {
+ let peers = NlasIterator::new(value)
+ .map(|nla_bytes| {
+ let buf = nla_bytes?;
+ let val = buf.value();
+ PeerMessage::parse(&val)
+ })
+ .collect::<Result<Vec<PeerMessage>, DecodeError>>()?;
+ Peers(peers)
+ }
+ WGDEVICE_A_UNSPEC => Unspec(value.to_vec()),
+ _ => {
+ return Err(format!("Unexpected device attribute kind: {}", buf.kind()).into());
+ }
+ };
+ Ok(nla)
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct PeerMessage(pub Vec<PeerNla>);
+
+impl PeerMessage {
+ fn parse(payload: &[u8]) -> Result<Self, DecodeError> {
+ let mut nlas = vec![];
+
+ let nla_iter = NlasIterator::new(&payload);
+ for buffer in nla_iter {
+ nlas.push(PeerNla::parse(&buffer?)?)
+ }
+ Ok(Self(nlas))
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerMessage {
+ fn parse(payload: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ Ok(Self(
+ NlasIterator::new(&payload.into_inner())
+ .map(|buffer| PeerNla::parse(&buffer?))
+ .collect::<Result<Vec<PeerNla>, DecodeError>>()?,
+ ))
+ }
+}
+
+impl Nla for PeerMessage {
+ fn value_len(&self) -> usize {
+ self.0.as_slice().buffer_len()
+ }
+
+ fn kind(&self) -> u16 {
+ NLA_F_NESTED
+ }
+
+ fn emit_value(&self, buffer: &mut [u8]) {
+ self.0.as_slice().emit(buffer);
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum PeerNla {
+ Unspec(Vec<u8>),
+ PublicKey(PublicKey),
+ PresharedKey(PresharedKey),
+ Flags(u32),
+ Endpoint(InetAddr),
+ PersistentKeepaliveInterval(u16),
+ LastHandshakeTime(TimeSpec),
+ RxBytes(u64),
+ TxBytes(u64),
+ AllowedIps(Vec<AllowedIpMessage>),
+ ProtocolVersion(u32),
+}
+
+impl Nla for PeerNla {
+ fn value_len(&self) -> usize {
+ use PeerNla::*;
+ match self {
+ PublicKey(key) | PresharedKey(key) => key.len(),
+ Endpoint(endpoint) => match &endpoint {
+ InetAddr::V4(_) => mem::size_of::<libc::sockaddr_in>(),
+ InetAddr::V6(_) => mem::size_of::<libc::sockaddr_in6>(),
+ },
+ PersistentKeepaliveInterval(_) => 2,
+ LastHandshakeTime(_) => mem::size_of::<libc::timespec>(),
+ RxBytes(_) | TxBytes(_) => 8,
+ AllowedIps(ips) => ips.as_slice().buffer_len(),
+ Flags(_) | ProtocolVersion(_) => 4,
+ Unspec(payload) => payload.len(),
+ }
+ }
+
+ fn kind(&self) -> u16 {
+ use PeerNla::*;
+ match self {
+ PublicKey(_) => WGPEER_A_PUBLIC_KEY,
+ PresharedKey(_) => WGPEER_A_PRESHARED_KEY,
+ Flags(_) => WGPEER_A_FLAGS,
+ Endpoint(_) => WGPEER_A_ENDPOINT,
+ PersistentKeepaliveInterval(_) => WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
+ LastHandshakeTime(_) => WGPEER_A_LAST_HANDSHAKE_TIME,
+ RxBytes(_) => WGPEER_A_RX_BYTES,
+ TxBytes(_) => WGPEER_A_TX_BYTES,
+ AllowedIps(_) => WGPEER_A_ALLOWEDIPS | NLA_F_NESTED,
+ ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION,
+ Unspec(_) => WGPEER_A_UNSPEC,
+ }
+ }
+
+ fn emit_value(&self, mut buffer: &mut [u8]) {
+ use PeerNla::*;
+ match self {
+ PublicKey(key) | PresharedKey(key) => {
+ let _ = buffer.write(key).expect("Buffer too small for a key");
+ }
+ Flags(value) | ProtocolVersion(value) => NativeEndian::write_u32(buffer, *value),
+ Endpoint(endpoint) => match &endpoint {
+ InetAddr::V4(sockaddr_in) => {
+ let slice = unsafe { struct_as_slice(sockaddr_in) };
+ buffer
+ .write(slice)
+ .expect("Buffer too small for sockaddr_in");
+ }
+ InetAddr::V6(sockaddr_in6) => {
+ buffer
+ .write(unsafe { struct_as_slice(sockaddr_in6) })
+ .expect("Buffer too small for sockaddr_in6");
+ }
+ },
+ PersistentKeepaliveInterval(interval) => {
+ NativeEndian::write_u16(buffer, *interval);
+ }
+ LastHandshakeTime(last_handshake) => {
+ let timespec: &libc::timespec = last_handshake.as_ref();
+ buffer
+ .write(unsafe { struct_as_slice(timespec) })
+ .expect("Buffer too small for timespec");
+ }
+ RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes),
+ AllowedIps(ips) => ips.as_slice().emit(buffer),
+ Unspec(payload) => {
+ let _ = buffer
+ .write(&payload)
+ .expect("Buffer too small for unspecified payload");
+ }
+ }
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerNla {
+ fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ use PeerNla::*;
+ let value = buf.value();
+ let nla = match buf.kind() {
+ WGPEER_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()),
+ WGPEER_A_PRESHARED_KEY => PresharedKey(parsers::parse_wg_key(value)?.into()),
+ WGPEER_A_FLAGS => Flags(parsers::parse_u32(value)?),
+ WGPEER_A_ENDPOINT => Endpoint(parsers::parse_inet_sockaddr(value)?),
+ WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => {
+ PersistentKeepaliveInterval(parsers::parse_u16(value)?)
+ }
+
+ WGPEER_A_LAST_HANDSHAKE_TIME => LastHandshakeTime(parsers::parse_timespec(value)?),
+ WGPEER_A_RX_BYTES => RxBytes(parsers::parse_u64(value)?),
+ WGPEER_A_TX_BYTES => TxBytes(parsers::parse_u64(value)?),
+ WGPEER_A_ALLOWEDIPS => {
+ let nlas = NlasIterator::new(value)
+ .map(|nla_buffer| AllowedIpMessage::parse(&nla_buffer?))
+ .collect::<Result<Vec<_>, DecodeError>>()?;
+
+ AllowedIps(nlas)
+ }
+ WGPEER_A_PROTOCOL_VERSION => ProtocolVersion(parsers::parse_u32(value)?),
+ WGPEER_A_UNSPEC => Unspec(value.to_vec()),
+ _ => {
+ return Err(format!("Unexpected peer attribute kind: {}", buf.kind()).into());
+ }
+ };
+ Ok(nla)
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct AllowedIpMessage(Vec<AllowedIpNla>);
+
+impl From<&IpNetwork> for AllowedIpMessage {
+ fn from(ip: &IpNetwork) -> Self {
+ use AllowedIpNla::*;
+ let address_family = if ip.is_ipv4() {
+ libc::AF_INET
+ } else {
+ libc::AF_INET6
+ };
+
+ AllowedIpMessage(vec![
+ AddressFamily(address_family as u16),
+ CidrMask(ip.prefix()),
+ IpAddr(ip.ip().into()),
+ ])
+ }
+}
+
+impl Nla for AllowedIpMessage {
+ fn value_len(&self) -> usize {
+ self.0.as_slice().buffer_len()
+ }
+
+ fn kind(&self) -> u16 {
+ NLA_F_NESTED
+ }
+
+ fn emit_value(&self, buffer: &mut [u8]) {
+ self.0.as_slice().emit(buffer);
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpMessage {
+ fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ let nlas = NlasIterator::new(buf.value())
+ .map(|buffer| AllowedIpNla::parse(&buffer?))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(AllowedIpMessage(nlas))
+ }
+}
+
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum AllowedIpNla {
+ AddressFamily(u16),
+ IpAddr(IpAddr),
+ CidrMask(u8),
+ Unspec(Vec<u8>),
+}
+
+impl Nla for AllowedIpNla {
+ fn value_len(&self) -> usize {
+ use AllowedIpNla::*;
+ match &self {
+ AddressFamily(_) => 2,
+ IpAddr(addr) => ip_addr_to_bytes(addr).len(),
+ CidrMask(_) => 1,
+ Unspec(payload) => payload.len(),
+ }
+ }
+
+ fn kind(&self) -> u16 {
+ use AllowedIpNla::*;
+ match &self {
+ AddressFamily(_) => WGALLOWEDIP_A_FAMILY,
+ IpAddr(_) => WGALLOWEDIP_A_IPADDR,
+ CidrMask(_) => WGALLOWEDIP_A_CIDR_MASK,
+ Unspec(_) => WGALLOWEDIP_A_UNSPEC,
+ }
+ }
+
+ fn emit_value(&self, mut buffer: &mut [u8]) {
+ use AllowedIpNla::*;
+ match self {
+ AddressFamily(af) => {
+ NativeEndian::write_u16(buffer, *af);
+ }
+ IpAddr(ip_addr) => {
+ buffer
+ .write(&ip_addr_to_bytes(ip_addr))
+ .expect("Buffer too small for AllowedIpNla::IpAddr");
+ }
+ CidrMask(cidr_mask) => buffer[0] = *cidr_mask,
+ Unspec(payload) => {
+ let _ = buffer
+ .write(&payload)
+ .expect("Buffer too small for unspec payload");
+ }
+ }
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpNla {
+ fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
+ use AllowedIpNla::*;
+ let value = buf.value();
+ let nla = match buf.kind() {
+ WGALLOWEDIP_A_FAMILY => AddressFamily(parsers::parse_u16(value)?),
+ WGALLOWEDIP_A_IPADDR => IpAddr(parsers::parse_ip_addr(value)?),
+ WGALLOWEDIP_A_CIDR_MASK => CidrMask(parsers::parse_u8(value)?),
+ WGALLOWEDIP_A_UNSPEC => Unspec(value.to_vec()),
+ _ => Err(format!(
+ "Unexpected allowed IP attribute kind: {}",
+ buf.kind()
+ ))?,
+ };
+ Ok(nla)
+ }
+}
+
+unsafe fn struct_as_slice<T: Sized>(t: &T) -> &[u8] {
+ let s = mem::size_of::<T>();
+ let ptr = t as *const T as *const u8;
+ std::slice::from_raw_parts(ptr, s)
+}
+
+fn ip_addr_to_bytes(addr: &IpAddr) -> Vec<u8> {
+ match addr {
+ IpAddr::V4(addr) => addr.octets().to_vec(),
+ IpAddr::V6(addr) => addr.octets().to_vec(),
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use nix::sys::time::TimeValLike;
+ use std::net::Ipv4Addr;
+
+
+ #[test]
+ fn deserialize_netlink_message() {
+ #[rustfmt::skip]
+ let payload = vec![
+ 0x00, 0x01, 0x00, 0x00,
+ // 6 bytes of WGDEVICE_A_LISTEN_PORT 51820 + 2 bytes of padding
+ 0x06, 0x00, 0x06, 0x00, 0x6c, 0xca, 0x00, 0x00,
+ // 8 bytes of WGDEVICE_A_FWMARK 0
+ 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 8 bytes of WGDEVIEC_A_IFINDEX 320
+ 0x08, 0x00, 0x01, 0x00, 0x40, 0x01, 0x00, 0x00,
+ // 12 bytes of WGDEVICE_A_IFNAME "wg-test\0"
+ 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x00,
+ // 36 bytes of WGDEVICE_A_PRIVATE_KEY OEf0rWXfVRarrw8nNbTBxkk3NTu8GjRKrbMW1aFH/H0=
+ 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16, 0xab, 0xaf,
+ 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a, 0x34, 0x4a,
+ 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d,
+ // 36 bytes of WGDEVICE_A_PUBLIC_KEY Ztqy3r8VO1N8tHwpWwqGx1S6G9o12BRdy1JESr2OYzs=
+ 0x24, 0x00, 0x04, 0x00, 0x66, 0xda, 0xb2, 0xde, 0xbf, 0x15, 0x3b, 0x53, 0x7c, 0xb4,
+ 0x7c, 0x29, 0x5b, 0x0a, 0x86, 0xc7, 0x54, 0xba, 0x1b, 0xda, 0x35, 0xd8, 0x14, 0x5d,
+ 0xcb, 0x52, 0x44, 0x4a, 0xbd, 0x8e, 0x63, 0x3b,
+ // 380 bytes of WGDEVICE_A_PEERS
+ 0x7c, 0x01, 0x08, 0x80,
+ // 188 bytes of WGPEER attributes
+ 0xbc, 0x00, 0x00, 0x80,
+ // 36 bytes of WGPEER_A_PUBLIC_KEY IOBEBReIZ+XOOyLn14vW7FBRuweaxfskq5wwSZEvhjY=
+ 0x24, 0x00, 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5,
+ 0xce, 0x3b, 0x22, 0xe7, 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07,
+ 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c, 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36,
+ // 36 bytes of WGPEER_A_PRESHARED_KEY (all zeroes)
+ 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME 0
+ 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL 0
+ 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 12 bytes of WGPEER_A_TX_BYTES 0
+ 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 12 bytes of WGPEER_A_RX_BYTES 0
+ 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 8 bytes of WGPEER_A_PROTOCOL_VERSION 1
+ 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00,
+ // 20 bytes of WGPEER_A_ENDPOINT 192.168.39.2:9797
+ 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x01,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 32 bytes of WGPEER_A_ALLOWEDIPS
+ 0x20, 0x00, 0x09, 0x80,
+ // 28 bytes of WGALLOWDIP_A_*
+ 0x1c, 0x00,0x00, 0x80,
+ // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32
+ 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00,
+ // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4)
+ 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00,
+ // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.1
+ 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x01,
+ // 188 bvytes of WGPEER attributes
+ 0xbc, 0x00, 0x00, 0x80,
+ // 36 bytes of WGPEER_A_PUBLIC_KEY
+ 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7,
+ 0xc2, 0x9d, 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e,
+ 0xd9, 0x38, 0x0c, 0x92, 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c,
+ // 36 bytes of WGPEER_A_PRESHARED_KEY
+ 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME
+ 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL + 2 bytes of padding
+ 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 12 bytes of WGPEER_A_TX_BYTES
+ 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 12 bytes of WGPEER_A_RX_BYTES
+ 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 8 bytes of WGPEER_A_PROTOCOL_VERSION
+ 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00,
+ // 20 bytes of WGPEER_A_ENDPOINT
+ 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ // 32 bytes of WGPEER_A_ALLOWEDIPS
+ 0x20, 0x00, 0x09, 0x80,
+ // 28 bytes of WGALLOWDIP_A_*
+ 0x1c, 0x00, 0x00, 0x80,
+ // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32
+ 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00,
+ // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4)
+ 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00,
+ // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.2
+ 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x02,
+ ];
+ let header = NetlinkHeader {
+ length: payload.len() as u32,
+ message_type: 0,
+ flags: 0,
+ sequence_number: 0,
+ port_number: 0,
+ };
+ let message = DeviceMessage::deserialize(&header, &payload).unwrap();
+
+ let mut serialized_message = vec![0u8; payload.len()];
+
+ message.serialize(&mut serialized_message);
+
+ assert_eq!(message, sample_get_message());
+ assert_eq!(&payload, &serialized_message)
+ }
+
+ fn sample_get_message() -> DeviceMessage {
+ use AllowedIpNla::*;
+ use DeviceNla::*;
+ use PeerNla::*;
+
+ let if_name = CString::new(b"wg-test".to_vec()).unwrap();
+
+ let peer_1 = PeerMessage(
+ [
+ PeerNla::PublicKey([
+ 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80,
+ 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54,
+ ]),
+ PeerNla::PresharedKey([
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0,
+ ]),
+ LastHandshakeTime(TimeSpec::seconds(0)),
+ PersistentKeepaliveInterval(0),
+ TxBytes(0),
+ RxBytes(0),
+ ProtocolVersion(1),
+ Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())),
+ AllowedIps(
+ [AllowedIpMessage(
+ [
+ CidrMask(32),
+ AddressFamily(2),
+ IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()),
+ ]
+ .to_vec(),
+ )]
+ .to_vec()
+ .to_vec(),
+ ),
+ ]
+ .to_vec(),
+ );
+
+ let peer_2 = PeerMessage(
+ [
+ PeerNla::PublicKey([
+ 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24,
+ 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44,
+ ]),
+ PresharedKey([
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0,
+ ]),
+ LastHandshakeTime(TimeSpec::seconds(0)),
+ PersistentKeepaliveInterval(0),
+ TxBytes(0),
+ RxBytes(0),
+ ProtocolVersion(1),
+ Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())),
+ AllowedIps(
+ [AllowedIpMessage(
+ vec![
+ CidrMask(32),
+ AddressFamily(2),
+ IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()),
+ ]
+ .to_vec(),
+ )]
+ .to_vec(),
+ ),
+ ]
+ .to_vec(),
+ );
+
+ DeviceMessage {
+ command: WG_CMD_GET_DEVICE,
+ message_type: 0,
+ nlas: [
+ ListenPort(51820),
+ Fwmark(0),
+ IfIndex(320),
+ IfName(if_name),
+ PrivateKey([
+ 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73,
+ 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125,
+ ]),
+ DeviceNla::PublicKey([
+ 102, 218, 178, 222, 191, 21, 59, 83, 124, 180, 124, 41, 91, 10, 134, 199, 84,
+ 186, 27, 218, 53, 216, 20, 93, 203, 82, 68, 74, 189, 142, 99, 59,
+ ]),
+ Peers([peer_1, peer_2].to_vec()),
+ ]
+ .to_vec(),
+ }
+ }
+
+ pub fn sample_set_message() -> DeviceMessage {
+ use AllowedIpNla::*;
+ use DeviceNla::*;
+ use PeerNla::*;
+
+ let if_name = CString::new("wg-test".to_string()).unwrap();
+
+ let peer_1 = PeerMessage(
+ [
+ PeerNla::PublicKey([
+ 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80,
+ 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54,
+ ]),
+ Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())),
+ PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
+ AllowedIps(
+ [AllowedIpMessage(
+ [
+ AddressFamily(2),
+ IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()),
+ CidrMask(32),
+ ]
+ .to_vec(),
+ )]
+ .to_vec()
+ .to_vec(),
+ ),
+ ]
+ .to_vec(),
+ );
+
+ let peer_2 = PeerMessage(
+ [
+ PeerNla::PublicKey([
+ 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24,
+ 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44,
+ ]),
+ Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())),
+ PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
+ AllowedIps(
+ [AllowedIpMessage(
+ vec![
+ AddressFamily(2),
+ IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()),
+ CidrMask(32),
+ ]
+ .to_vec(),
+ )]
+ .to_vec(),
+ ),
+ ]
+ .to_vec(),
+ );
+
+ DeviceMessage {
+ command: WG_CMD_SET_DEVICE,
+ message_type: 0,
+ nlas: [
+ IfName(if_name),
+ PrivateKey([
+ 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73,
+ 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125,
+ ]),
+ ListenPort(51820),
+ Peers([peer_1, peer_2].to_vec()),
+ ]
+ .to_vec(),
+ }
+ }
+
+
+ #[test]
+ fn serialize_netlink_message() {
+ let expected_payload: &[u8] = &[
+ 0x01, 0x01, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73,
+ 0x74, 0x00, 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16,
+ 0xab, 0xaf, 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a,
+ 0x34, 0x4a, 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d, 0x06, 0x00, 0x06, 0x00,
+ 0x6c, 0xca, 0x00, 0x00, 0xcc, 0x00, 0x08, 0x80, 0x64, 0x00, 0x00, 0x80, 0x24, 0x00,
+ 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5, 0xce, 0x3b, 0x22, 0xe7,
+ 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07, 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c,
+ 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45,
+ 0xc0, 0xa8, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
+ 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80,
+ 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8,
+ 0x27, 0x01, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x80,
+ 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7, 0xc2, 0x9d,
+ 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e, 0xd9, 0x38, 0x0c, 0x92,
+ 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00,
+ 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00,
+ 0x00, 0x80, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00,
+ 0xc0, 0xa8, 0x27, 0x02, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00,
+ ];
+
+ let mut message = sample_set_message();
+ message.command = WG_CMD_SET_DEVICE;
+
+
+ let mut payload_buffer = vec![0u8; message.buffer_len()];
+ message.serialize(&mut payload_buffer);
+ let header = NetlinkHeader {
+ length: payload_buffer.len() as u32,
+ message_type: 0,
+ flags: 0,
+ sequence_number: 0,
+ port_number: 0,
+ };
+ let deserialized_device = DeviceMessage::deserialize(&header, &payload_buffer).unwrap();
+
+ assert_eq!(message, deserialized_device);
+ assert_eq!(payload_buffer, expected_payload);
+ }
+}
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 43595a4e79..dd259a87d9 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -11,7 +11,7 @@ use futures01::{
Async, Future, Stream,
};
use talpid_types::{
- net::{Endpoint, TunnelParameters},
+ net::TunnelParameters,
tunnel::{ErrorStateCause, FirewallPolicyError},
BoxedError, ErrorExt,
};
@@ -52,19 +52,7 @@ impl ConnectedState {
&self,
shared_values: &mut SharedTunnelStateValues,
) -> Result<(), FirewallPolicyError> {
- // If a proxy is specified we need to pass it on as the peer endpoint.
- let peer_endpoint = self.get_endpoint_from_params();
-
- let policy = FirewallPolicy::Connected {
- peer_endpoint,
- tunnel: self.metadata.clone(),
- allow_lan: shared_values.allow_lan,
- #[cfg(windows)]
- relay_client: TunnelMonitor::get_relay_client(
- &shared_values.resource_dir,
- &self.tunnel_parameters,
- ),
- };
+ let policy = self.get_firewall_policy(shared_values);
shared_values
.firewall
.apply_policy(policy)
@@ -85,13 +73,18 @@ impl ConnectedState {
})
}
- fn get_endpoint_from_params(&self) -> Endpoint {
- match self.tunnel_parameters {
- TunnelParameters::OpenVpn(ref params) => match params.proxy {
- Some(ref proxy_settings) => proxy_settings.get_endpoint().endpoint,
- None => params.config.endpoint,
- },
- TunnelParameters::Wireguard(ref params) => params.connection.get_endpoint(),
+ fn get_firewall_policy(&self, shared_values: &SharedTunnelStateValues) -> FirewallPolicy {
+ FirewallPolicy::Connected {
+ peer_endpoint: self.tunnel_parameters.get_next_hop_endpoint(),
+ tunnel: self.metadata.clone(),
+ allow_lan: shared_values.allow_lan,
+ #[cfg(windows)]
+ relay_client: TunnelMonitor::get_relay_client(
+ &shared_values.resource_dir,
+ &self.tunnel_parameters,
+ ),
+ #[cfg(target_os = "linux")]
+ use_fwmark: self.tunnel_parameters.get_proxy_endpoint().is_none(),
}
}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index d206b34a23..859ef52eb0 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -22,7 +22,7 @@ use std::{
time::{Duration, Instant},
};
use talpid_types::{
- net::{openvpn, TunnelParameters},
+ net::TunnelParameters,
tunnel::{ErrorStateCause, FirewallPolicyError},
ErrorExt,
};
@@ -48,13 +48,7 @@ impl ConnectingState {
shared_values: &mut SharedTunnelStateValues,
params: &TunnelParameters,
) -> Result<(), FirewallPolicyError> {
- let proxy = &get_openvpn_proxy_settings(&params);
- let endpoint = params.get_tunnel_endpoint().endpoint;
-
- let peer_endpoint = match proxy {
- Some(proxy_settings) => proxy_settings.get_endpoint().endpoint,
- None => endpoint,
- };
+ let peer_endpoint = params.get_next_hop_endpoint();
let policy = FirewallPolicy::Connecting {
peer_endpoint,
@@ -62,6 +56,8 @@ impl ConnectingState {
allow_lan: shared_values.allow_lan,
#[cfg(windows)]
relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, &params),
+ #[cfg(target_os = "linux")]
+ use_fwmark: params.get_proxy_endpoint().is_none(),
};
shared_values
.firewall
@@ -312,15 +308,6 @@ impl ConnectingState {
}
}
-fn get_openvpn_proxy_settings(
- tunnel_parameters: &TunnelParameters,
-) -> &Option<openvpn::ProxySettings> {
- match tunnel_parameters {
- TunnelParameters::OpenVpn(ref config) => &config.proxy,
- _ => &None,
- }
-}
-
fn should_retry(error: &tunnel::Error) -> bool {
#[cfg(not(windows))]
use tunnel::wireguard::{Error, TunnelError};
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 97292fd381..15bf33a5bb 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -37,6 +37,25 @@ impl TunnelParameters {
}
}
+ // Returns the endpoint that will be connected to
+ pub fn get_next_hop_endpoint(&self) -> Endpoint {
+ match self {
+ TunnelParameters::OpenVpn(params) => params
+ .proxy
+ .as_ref()
+ .map(|proxy| proxy.get_endpoint().endpoint)
+ .unwrap_or(params.config.endpoint),
+ TunnelParameters::Wireguard(params) => params.connection.get_endpoint(),
+ }
+ }
+
+ pub fn get_proxy_endpoint(&self) -> Option<openvpn::ProxySettings> {
+ match self {
+ TunnelParameters::OpenVpn(params) => params.proxy.clone(),
+ _ => None,
+ }
+ }
+
pub fn get_generic_options(&self) -> &GenericTunnelOptions {
match &self {
TunnelParameters::OpenVpn(params) => &params.generic_options,