summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-09-05 10:17:09 +0200
committerDavid Lönnhager <david.l@mullvad.net>2024-04-16 14:43:15 +0200
commitaf96a710398870587df9e07ee6f5afd16b8d9888 (patch)
tree7a5dfc3ee3a388c74c5a144983b5dfba8a057902
parent99ae0b436f173b576343111cac38d6bec4ce2487 (diff)
downloadmullvadvpn-af96a710398870587df9e07ee6f5afd16b8d9888.tar.xz
mullvadvpn-af96a710398870587df9e07ee6f5afd16b8d9888.zip
Add DAITA Windows client and updated tuncfg
-rw-r--r--Cargo.lock135
-rw-r--r--Cargo.toml1
-rw-r--r--dist-assets/maybenot_machines4
-rw-r--r--docs/architecture.md2
-rw-r--r--gui/tasks/distribution.js1
-rw-r--r--mullvad-api/src/relay_list.rs3
-rw-r--r--mullvad-cli/src/cmds/relay.rs2
-rw-r--r--mullvad-cli/src/cmds/tunnel.rs29
-rw-r--r--mullvad-cli/src/format.rs13
-rw-r--r--mullvad-daemon/src/lib.rs55
-rw-r--r--mullvad-daemon/src/management_interface.rs19
-rw-r--r--mullvad-management-interface/proto/management_interface.proto10
-rw-r--r--mullvad-management-interface/src/client.rs12
-rw-r--r--mullvad-management-interface/src/types/conversions/custom_tunnel.rs2
-rw-r--r--mullvad-management-interface/src/types/conversions/net.rs6
-rw-r--r--mullvad-management-interface/src/types/conversions/relay_list.rs2
-rw-r--r--mullvad-management-interface/src/types/conversions/settings.rs11
-rw-r--r--mullvad-management-interface/src/types/conversions/wireguard.rs18
-rw-r--r--mullvad-relay-selector/src/relay_selector/detailer.rs9
-rw-r--r--mullvad-relay-selector/src/relay_selector/mod.rs16
-rw-r--r--mullvad-relay-selector/tests/relay_selector.rs6
-rw-r--r--mullvad-types/src/relay_list.rs4
-rw-r--r--mullvad-types/src/wireguard.rs13
-rw-r--r--talpid-core/src/tunnel/mod.rs1
-rw-r--r--talpid-tunnel-config-client/build.rs2
-rw-r--r--talpid-tunnel-config-client/examples/psk-exchange.rs6
-rw-r--r--talpid-tunnel-config-client/examples/tuncfg-server.rs105
-rw-r--r--talpid-tunnel-config-client/proto/ephemeralpeer.proto85
-rw-r--r--talpid-tunnel-config-client/proto/tunnel_config.proto83
-rw-r--r--talpid-tunnel-config-client/src/classic_mceliece.rs5
-rw-r--r--talpid-tunnel-config-client/src/kyber.rs5
-rw-r--r--talpid-tunnel-config-client/src/lib.rs156
-rw-r--r--talpid-types/src/net/mod.rs6
-rw-r--r--talpid-types/src/net/wireguard.rs7
-rw-r--r--talpid-wireguard/Cargo.toml2
-rw-r--r--talpid-wireguard/src/config.rs9
-rw-r--r--talpid-wireguard/src/connectivity_check.rs5
-rw-r--r--talpid-wireguard/src/lib.rs103
-rw-r--r--talpid-wireguard/src/wireguard_nt/daita.rs450
-rw-r--r--talpid-wireguard/src/wireguard_nt/mod.rs (renamed from talpid-wireguard/src/wireguard_nt.rs)172
40 files changed, 1331 insertions, 244 deletions
diff --git a/Cargo.lock b/Cargo.lock
index cc54dcb168..eedb5dd480 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -18,6 +18,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
+name = "adler32"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
+
+[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -53,6 +59,18 @@ dependencies = [
]
[[package]]
+name = "ahash"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+ "version_check",
+ "zerocopy",
+]
+
+[[package]]
name = "aho-corasick"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -583,6 +601,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
+name = "core2"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
name = "cpufeatures"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -592,6 +619,15 @@ dependencies = [
]
[[package]]
+name = "crc32fast"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -715,6 +751,12 @@ dependencies = [
]
[[package]]
+name = "dary_heap"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca"
+
+[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1219,6 +1261,15 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
+dependencies = [
+ "ahash",
+]
+
+[[package]]
+name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
@@ -1757,6 +1808,30 @@ dependencies = [
]
[[package]]
+name = "libflate"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf"
+dependencies = [
+ "adler32",
+ "core2",
+ "crc32fast",
+ "dary_heap",
+ "libflate_lz77",
+]
+
+[[package]]
+name = "libflate_lz77"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524"
+dependencies = [
+ "core2",
+ "hashbrown 0.13.2",
+ "rle-decode-fast",
+]
+
+[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1836,6 +1911,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef"
[[package]]
+name = "maybenot"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7cc2e64fe3f5fb1e247110a9a408449eff2259cc272cf57bad6f161e801ac962"
+dependencies = [
+ "byteorder",
+ "hex",
+ "libflate",
+ "rand 0.8.5",
+ "rand_distr",
+ "ring",
+ "serde",
+ "simple-error",
+]
+
+[[package]]
name = "md-5"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2974,6 +3065,16 @@ dependencies = [
]
[[package]]
+name = "rand_distr"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
+dependencies = [
+ "num-traits",
+ "rand 0.8.5",
+]
+
+[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3100,6 +3201,12 @@ dependencies = [
]
[[package]]
+name = "rle-decode-fast"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422"
+
+[[package]]
name = "rs-release"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3468,6 +3575,12 @@ dependencies = [
]
[[package]]
+name = "simple-error"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8542b68b8800c3cda649d2c72d688b6907b30f1580043135d61669d4aad1c175"
+
+[[package]]
name = "simple-signal"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3870,6 +3983,7 @@ dependencies = [
name = "talpid-wireguard"
version = "0.0.0"
dependencies = [
+ "base64 0.13.1",
"bitflags 1.3.2",
"byteorder",
"chrono",
@@ -3880,6 +3994,7 @@ dependencies = [
"ipnetwork",
"libc",
"log",
+ "maybenot",
"netlink-packet-core",
"netlink-packet-route",
"netlink-packet-utils",
@@ -4822,6 +4937,26 @@ dependencies = [
]
[[package]]
+name = "zerocopy"
+version = "0.7.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be"
+dependencies = [
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.7.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
name = "zeroize"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index b0c28f016a..65391bcf37 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,7 @@ members = [
"ios/MullvadREST/Transport/Shadowsocks/shadowsocks-proxy",
"ios/TunnelObfuscation/tunnel-obfuscator-proxy",
"mullvad-api",
+ "mullvad-daemon",
"mullvad-cli",
"mullvad-daemon",
"mullvad-exclude",
diff --git a/dist-assets/maybenot_machines b/dist-assets/maybenot_machines
new file mode 100644
index 0000000000..e7bde4f764
--- /dev/null
+++ b/dist-assets/maybenot_machines
@@ -0,0 +1,4 @@
+789cd5cfbb0900200c04d08b833886adb889389f5bb9801be811acb58ae2837ce02010c158b070555c9538b6377a64dbb0ceff242c20b79038507dd169fbede9f629bf6f021efa1b66
+789ccd934b48024118c777357ae9250aa14b8542085d2a0f264233160975a80c440f5b04455deaa050143d0e81d8835ea708f6d2a9a2430511d2a136118bcaa854422a93c44d93a243121844b3204b8b87acadf0c7f0c137f3cdccf7e08f635cc200190eb81099af106235eda133a95a4f4df5dcbfad2eb640b4c921135330e42ba1585066951456b311b3f6d8ca4254ae415f212f356efba0928e44d9642df27769b79162fd9ff1ccf3fef7b9b19982cee963405699ee1c136e107bf0ae19d6fdff9ec76770ac8bc1a9a34847c986d9a8858391d28af2a20b76622824f55125186dedd813d9c2207bd22c382f76f1aed070ad7a896bbdc09317b48c1121deefa51f599855e23b7d1519a1eb044134414662573e1d546d1df6efd6ce401cdb3f5aaadb56e8a9b69d7ab140d60c0924a601b7086929019a54ef66e5b02cd7c376a86199d63cc57fbb63c9daf18f2b9805e6d5f49c9a0e249dff158df6919c0232003a87321e2f0ff86be903a6c37ad5
+789ced956d28435118c7cfdde63544f9424292489222493a77214962de92248992bc535e4221ad2549920f5e22e5ada525496b1faee52d9fc6e4255bb3469b96661f282379ee1d574bb7257c50fb7d38dd73efd3e9dcf33fe77708e4880943e3002180c6197c74b6d09ce6dd24a2a003a4530dc6e715ab4d44dea8222e9f8c6292878ee0a1afb598ca18ae876f0212ca1c2110b19f932c1156eee2e99d70596bed360e1294f8c5adeb71aa2e6f64b251f965721f840d99455ba15adc92e60d13d173d6b9f87bf8e87e690f7224a8aedba804cfc842366979b60f4f23731342d228a67c54f910ad6093a2739eb1e8f05c70bc36d2fd827dcf857bff44fa70bd19c356b21ee772ef8cef220d29ab4ed519d8f196e95f119eb37d179f78a0177af1ab4c78d602918b0ba8d7934e3a73d28717eb2119881542114be05a96be3bf0143f46d3d15f73ae682f4eaa90676bb0adaf4ed13eaec2200490889ab3dec54fe123469fbe45e4e2142dde3092c951524a32a734b103ec5de55f9ab261ca27db6ac7ee7ac4f9ef9e47088ced55531300679a2056e9cde07f8507a9c3f9cc4d3566464938f8f5e418af4835d87e9370dbde9339cc066cf79173a7fc77de00293aaf16
+789ce5914b4802511486ef75c8405a84b4aa8810029b8236b9a8b03b5af6801e24512d24225ab409a2550521152e24c455b831706a216ea4402c5a0d1404465022980d990d5941140c328bc81eb7910686bb89dc047e1c0edc737e2e87ff8740cd23c24d05a4702b1e0a685737bb36fc36a6ef76b92adc6be5a2b9f7d1abee6a0b5e1680009a75c99d79730cf5e065623185fcc7f5bb0bb329e2a81f663a6e863dbe341a9aae9c688b6408ddd6a51bd719a2e880860e0868dc29e13a2174a503050a9ed6304792d1e1ddb63378a8420bd6b9186bfbb033721834a9f84e0af75fe191d8bdced6b8e2787b594b3093bd57de7fe5b3e2d9600df3c8b9e47e491b85a2fffb7f68c1296b8be6aec798b987b790a86f22722807f990181f140618d7d3c8419dcbc4e1a14ca3c3ab5933d31600a19294087d0dfae63b549bbf88f0c124e1d854f6dcc0bf0a484e5db78f569069f2b03f41e84a8d2f5e5a7583
diff --git a/docs/architecture.md b/docs/architecture.md
index 1deba30706..78305a8c6a 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -200,7 +200,7 @@ WireGuard tunnel to the relay and deriving the PSK within the tunnel.
The PSK is stored in memory on the relay and the client, along with a new client generated ephemeral
WireGuard key. Subsequently, a new tunnel is created using the new WireGuard key and the PSK,
ensuring that the tunnel is quantum-resistant.
-See the [protocol definition file](../talpid-tunnel-config-client/proto/tunnel_config.proto) for
+See the [protocol definition file](../talpid-tunnel-config-client/proto/ephemeralpeer.proto) for
more details on the protocol.
#### Quantum-resistant tunnels & Multihop
diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js
index 99c8fb87ef..7087d71301 100644
--- a/gui/tasks/distribution.js
+++ b/gui/tasks/distribution.js
@@ -152,6 +152,7 @@ const config = {
from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard-nt/mullvad-wireguard.dll'),
to: '.',
},
+ { from: distAssets('maybenot_machines'), to: '.' },
],
},
diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs
index deaf29ef10..73b387a8d8 100644
--- a/mullvad-api/src/relay_list.rs
+++ b/mullvad-api/src/relay_list.rs
@@ -303,6 +303,8 @@ struct WireGuardRelay {
#[serde(flatten)]
relay: Relay,
public_key: wireguard::PublicKey,
+ #[serde(default)]
+ daita: bool,
}
impl WireGuardRelay {
@@ -312,6 +314,7 @@ impl WireGuardRelay {
location,
relay_list::RelayEndpointData::Wireguard(relay_list::WireguardRelayEndpointData {
public_key: self.public_key,
+ daita: self.daita,
}),
)
}
diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs
index 7ef60d758c..f022402a83 100644
--- a/mullvad-cli/src/cmds/relay.rs
+++ b/mullvad-cli/src/cmds/relay.rs
@@ -542,6 +542,8 @@ impl Relay {
allowed_ips: all_of_the_internet(),
endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port),
psk: None,
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
},
exit_peer: None,
ipv4_gateway,
diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs
index 19d5c1a3c9..77338ee336 100644
--- a/mullvad-cli/src/cmds/tunnel.rs
+++ b/mullvad-cli/src/cmds/tunnel.rs
@@ -1,6 +1,8 @@
use anyhow::Result;
use clap::Subcommand;
use mullvad_management_interface::MullvadProxyClient;
+#[cfg(target_os = "windows")]
+use mullvad_types::wireguard::DaitaSettings;
use mullvad_types::{
constraints::Constraint,
wireguard::{QuantumResistantState, RotationInterval, DEFAULT_ROTATION_INTERVAL},
@@ -38,6 +40,10 @@ pub enum TunnelOptions {
/// Configure quantum-resistant key exchange
#[arg(long)]
quantum_resistant: Option<QuantumResistantState>,
+ /// Configure whether to enable DAITA
+ #[cfg(target_os = "windows")]
+ #[arg(long)]
+ daita: Option<BooleanOption>,
/// The key rotation interval. Number of hours, or 'any'
#[arg(long)]
rotation_interval: Option<Constraint<RotationInterval>>,
@@ -95,6 +101,9 @@ impl Tunnel {
tunnel_options.wireguard.quantum_resistant,
);
+ #[cfg(target_os = "windows")]
+ print_option!("DAITA", tunnel_options.wireguard.daita.enabled);
+
let key = rpc.get_wireguard_key().await?;
print_option!("Public key", key.key,);
print_option!(format_args!(
@@ -129,10 +138,20 @@ impl Tunnel {
TunnelOptions::Wireguard {
mtu,
quantum_resistant,
+ #[cfg(target_os = "windows")]
+ daita,
rotation_interval,
rotate_key,
} => {
- Self::handle_wireguard(mtu, quantum_resistant, rotation_interval, rotate_key).await
+ Self::handle_wireguard(
+ mtu,
+ quantum_resistant,
+ #[cfg(target_os = "windows")]
+ daita,
+ rotation_interval,
+ rotate_key,
+ )
+ .await
}
TunnelOptions::Ipv6 { state } => Self::handle_ipv6(state).await,
}
@@ -159,6 +178,7 @@ impl Tunnel {
async fn handle_wireguard(
mtu: Option<Constraint<u16>>,
quantum_resistant: Option<QuantumResistantState>,
+ #[cfg(target_os = "windows")] daita: Option<BooleanOption>,
rotation_interval: Option<Constraint<RotationInterval>>,
rotate_key: Option<RotateKey>,
) -> Result<()> {
@@ -174,6 +194,13 @@ impl Tunnel {
println!("Quantum resistant setting has been updated");
}
+ #[cfg(target_os = "windows")]
+ if let Some(daita) = daita {
+ rpc.set_daita_settings(DaitaSettings { enabled: *daita })
+ .await?;
+ println!("DAITA setting has been updated");
+ }
+
if let Some(interval) = rotation_interval {
match interval {
Constraint::Only(interval) => {
diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs
index e605efbe3b..0d6ea0b0c0 100644
--- a/mullvad-cli/src/format.rs
+++ b/mullvad-cli/src/format.rs
@@ -174,6 +174,17 @@ fn format_relay_connection(
"\nQuantum resistant tunnel: no"
};
+ #[cfg(target_os = "windows")]
+ let daita = if !verbose {
+ ""
+ } else if endpoint.daita {
+ "\nDAITA: yes"
+ } else {
+ "\nDAITA: no"
+ };
+ #[cfg(not(target_os = "windows"))]
+ let daita = "";
+
let mut bridge_type = String::new();
let mut obfuscator_type = String::new();
if verbose {
@@ -186,7 +197,7 @@ fn format_relay_connection(
}
format!(
- "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{bridge_type}{obfuscator_type}",
+ "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{daita}{bridge_type}{obfuscator_type}",
first_hop = first_hop.unwrap_or_default(),
bridge = bridge.unwrap_or_default(),
obfuscator = obfuscator.unwrap_or_default(),
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 365e82efeb..2733482a4c 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -37,9 +37,11 @@ use futures::{
StreamExt,
};
use geoip::GeoIpHandler;
-use mullvad_relay_selector::{RelaySelector, SelectorConfig};
+use mullvad_relay_selector::{AdditionalRelayConstraints, AdditionalWireguardConstraints, RelaySelector, SelectorConfig};
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
+#[cfg(target_os = "windows")]
+use mullvad_types::wireguard::DaitaSettings;
use mullvad_types::{
access_method::{AccessMethod, AccessMethodSetting},
account::{AccountData, AccountToken, VoucherSubmission},
@@ -256,6 +258,9 @@ pub enum DaemonCommand {
SetEnableIpv6(ResponseTx<(), settings::Error>, bool),
/// Set whether to enable PQ PSK exchange in the tunnel
SetQuantumResistantTunnel(ResponseTx<(), settings::Error>, QuantumResistantState),
+ /// Set DAITA settings for the tunnel
+ #[cfg(target_os = "windows")]
+ SetDaitaSettings(ResponseTx<(), settings::Error>, DaitaSettings),
/// Set DNS options or servers to use
SetDnsOptions(ResponseTx<(), settings::Error>, DnsOptions),
/// Set override options to use for a given relay
@@ -1242,6 +1247,10 @@ where
self.on_set_quantum_resistant_tunnel(tx, quantum_resistant_state)
.await
}
+ #[cfg(target_os = "windows")]
+ SetDaitaSettings(tx, daita_settings) => {
+ self.on_set_daita_settings(tx, daita_settings).await
+ }
SetDnsOptions(tx, dns_servers) => self.on_set_dns_options(tx, dns_servers).await,
SetRelayOverride(tx, relay_override) => {
self.on_set_relay_override(tx, relay_override).await
@@ -2259,6 +2268,40 @@ where
}
}
+ #[cfg(target_os = "windows")]
+ async fn on_set_daita_settings(
+ &mut self,
+ tx: ResponseTx<(), settings::Error>,
+ daita_settings: DaitaSettings,
+ ) {
+ match self
+ .settings
+ .update(|settings| settings.tunnel_options.wireguard.daita = daita_settings)
+ .await
+ {
+ Ok(settings_changed) => {
+ Self::oneshot_send(tx, Ok(()), "set_daita_settings response");
+ if settings_changed {
+ self.parameters_generator
+ .set_tunnel_options(&self.settings.tunnel_options)
+ .await;
+ self.event_listener
+ .notify_settings(self.settings.to_settings());
+ self.relay_selector
+ .set_config(new_selector_config(&self.settings));
+ if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) {
+ log::info!("Reconnecting because DAITA settings changed");
+ self.reconnect_tunnel();
+ }
+ }
+ }
+ Err(e) => {
+ log::error!("{}", e.display_chain_with_msg("Unable to save settings"));
+ Self::oneshot_send(tx, Err(e), "set_daita_settings response");
+ }
+ }
+ }
+
async fn on_set_dns_options(
&mut self,
tx: ResponseTx<(), settings::Error>,
@@ -2780,8 +2823,18 @@ impl DaemonShutdownHandle {
}
fn new_selector_config(settings: &Settings) -> SelectorConfig {
+ let additional_constraints = AdditionalRelayConstraints {
+ wireguard: AdditionalWireguardConstraints {
+ #[cfg(target_os = "windows")]
+ daita: settings.tunnel_options.wireguard.daita.enabled,
+ #[cfg(not(target_os = "windows"))]
+ daita: false,
+ },
+ };
+
SelectorConfig {
relay_settings: settings.relay_settings.clone(),
+ additional_constraints,
bridge_state: settings.bridge_state,
bridge_settings: settings.bridge_settings.clone(),
obfuscation_settings: settings.obfuscation_settings.clone(),
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 55573f87a7..ce203e7ba8 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -325,6 +325,25 @@ impl ManagementService for ManagementServiceImpl {
Ok(Response::new(()))
}
+ #[cfg(target_os = "windows")]
+ async fn set_daita_settings(
+ &self,
+ request: Request<types::DaitaSettings>,
+ ) -> ServiceResult<()> {
+ let state = mullvad_types::wireguard::DaitaSettings::from(request.into_inner());
+
+ log::debug!("set_daita_settings({state:?})");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::SetDaitaSettings(tx, state))?;
+ self.wait_for_result(rx).await?.map(Response::new)?;
+ Ok(Response::new(()))
+ }
+
+ #[cfg(not(target_os = "windows"))]
+ async fn set_daita_settings(&self, _: Request<types::DaitaSettings>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
+
#[cfg(not(target_os = "android"))]
async fn set_dns_options(&self, request: Request<types::DnsOptions>) -> ServiceResult<()> {
let options = DnsOptions::try_from(request.into_inner()).map_err(map_protobuf_type_err)?;
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 8af34aa8af..31d39db306 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -43,6 +43,7 @@ service ManagementService {
rpc SetWireguardMtu(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {}
rpc SetEnableIpv6(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
rpc SetQuantumResistantTunnel(QuantumResistantState) returns (google.protobuf.Empty) {}
+ rpc SetDaitaSettings(DaitaSettings) returns (google.protobuf.Empty) {}
rpc SetDnsOptions(DnsOptions) returns (google.protobuf.Empty) {}
rpc SetRelayOverride(RelayOverride) returns (google.protobuf.Empty) {}
rpc ClearAllRelayOverrides(google.protobuf.Empty) returns (google.protobuf.Empty) {}
@@ -220,6 +221,7 @@ message TunnelEndpoint {
ObfuscationEndpoint obfuscation = 6;
Endpoint entry_endpoint = 7;
TunnelMetadata tunnel_metadata = 8;
+ bool daita = 9;
}
enum ObfuscationType {
@@ -494,12 +496,15 @@ message QuantumResistantState {
State state = 1;
}
+message DaitaSettings { bool enabled = 1; }
+
message TunnelOptions {
message OpenvpnOptions { optional uint32 mssfix = 1; }
message WireguardOptions {
optional uint32 mtu = 1;
google.protobuf.Duration rotation_interval = 2;
QuantumResistantState quantum_resistant = 4;
+ DaitaSettings daita = 5;
}
message GenericOptions { bool enable_ipv6 = 1; }
@@ -584,7 +589,10 @@ message Relay {
Location location = 11;
}
-message WireguardRelayEndpointData { bytes public_key = 1; }
+message WireguardRelayEndpointData {
+ bytes public_key = 1;
+ bool daita = 2;
+}
message Location {
string country = 1;
diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs
index 8cf3ac495d..04304a19ec 100644
--- a/mullvad-management-interface/src/client.rs
+++ b/mullvad-management-interface/src/client.rs
@@ -2,6 +2,8 @@
use crate::types;
use futures::{Stream, StreamExt};
+#[cfg(target_os = "windows")]
+use mullvad_types::wireguard::DaitaSettings;
use mullvad_types::{
access_method::{self, AccessMethod, AccessMethodSetting},
account::{AccountData, AccountToken, VoucherSubmission},
@@ -344,6 +346,16 @@ impl MullvadProxyClient {
Ok(())
}
+ #[cfg(target_os = "windows")]
+ pub async fn set_daita_settings(&mut self, settings: DaitaSettings) -> Result<()> {
+ let settings = types::DaitaSettings::from(settings);
+ self.0
+ .set_daita_settings(settings)
+ .await
+ .map_err(Error::Rpc)?;
+ Ok(())
+ }
+
pub async fn set_dns_options(&mut self, options: DnsOptions) -> Result<()> {
let options = types::DnsOptions::from(&options);
self.0.set_dns_options(options).await.map_err(Error::Rpc)?;
diff --git a/mullvad-management-interface/src/types/conversions/custom_tunnel.rs b/mullvad-management-interface/src/types/conversions/custom_tunnel.rs
index 8a4408ec83..2445ec3292 100644
--- a/mullvad-management-interface/src/types/conversions/custom_tunnel.rs
+++ b/mullvad-management-interface/src/types/conversions/custom_tunnel.rs
@@ -91,6 +91,8 @@ impl TryFrom<proto::ConnectionConfig> for mullvad_types::ConnectionConfig {
allowed_ips,
endpoint,
psk: None,
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
},
exit_peer: None,
ipv4_gateway,
diff --git a/mullvad-management-interface/src/types/conversions/net.rs b/mullvad-management-interface/src/types/conversions/net.rs
index 80648cbf8d..3557a6a636 100644
--- a/mullvad-management-interface/src/types/conversions/net.rs
+++ b/mullvad-management-interface/src/types/conversions/net.rs
@@ -40,6 +40,10 @@ impl From<talpid_types::net::TunnelEndpoint> for proto::TunnelEndpoint {
tunnel_metadata: endpoint
.tunnel_interface
.map(|tunnel_interface| proto::TunnelMetadata { tunnel_interface }),
+ #[cfg(target_os = "windows")]
+ daita: endpoint.daita,
+ #[cfg(not(target_os = "windows"))]
+ daita: false,
}
}
}
@@ -123,6 +127,8 @@ impl TryFrom<proto::TunnelEndpoint> for talpid_types::net::TunnelEndpoint {
tunnel_interface: endpoint
.tunnel_metadata
.map(|tunnel_metadata| tunnel_metadata.tunnel_interface),
+ #[cfg(target_os = "windows")]
+ daita: endpoint.daita,
})
}
}
diff --git a/mullvad-management-interface/src/types/conversions/relay_list.rs b/mullvad-management-interface/src/types/conversions/relay_list.rs
index 32aee834af..4e0a363702 100644
--- a/mullvad-management-interface/src/types/conversions/relay_list.rs
+++ b/mullvad-management-interface/src/types/conversions/relay_list.rs
@@ -122,6 +122,7 @@ impl From<mullvad_types::relay_list::Relay> for proto::Relay {
"mullvad_daemon.management_interface/WireguardRelayEndpointData",
proto::WireguardRelayEndpointData {
public_key: data.public_key.as_bytes().to_vec(),
+ daita: data.daita,
},
)),
_ => None,
@@ -236,6 +237,7 @@ impl TryFrom<proto::Relay> for mullvad_types::relay_list::Relay {
MullvadEndpointData::Wireguard(
mullvad_types::relay_list::WireguardRelayEndpointData {
public_key: bytes_to_pubkey(&data.public_key)?,
+ daita: data.daita,
},
)
}
diff --git a/mullvad-management-interface/src/types/conversions/settings.rs b/mullvad-management-interface/src/types/conversions/settings.rs
index a4d6313158..857f32d991 100644
--- a/mullvad-management-interface/src/types/conversions/settings.rs
+++ b/mullvad-management-interface/src/types/conversions/settings.rs
@@ -97,6 +97,10 @@ impl From<&mullvad_types::settings::TunnelOptions> for proto::TunnelOptions {
.expect("Failed to convert std::time::Duration to prost_types::Duration for tunnel_options.wireguard.rotation_interval")
}),
quantum_resistant: Some(proto::QuantumResistantState::from(options.wireguard.quantum_resistant)),
+ #[cfg(target_os = "windows")]
+ daita: Some(proto::DaitaSettings::from(options.wireguard.daita.clone())),
+ #[cfg(not(target_os = "windows"))]
+ daita: None,
}),
generic: Some(proto::tunnel_options::GenericOptions {
enable_ipv6: options.generic.enable_ipv6,
@@ -282,6 +286,13 @@ impl TryFrom<proto::TunnelOptions> for mullvad_types::settings::TunnelOptions {
.ok_or(FromProtobufTypeError::InvalidArgument(
"missing quantum resistant state",
))??,
+ #[cfg(target_os = "windows")]
+ daita: wireguard_options
+ .daita
+ .map(mullvad_types::wireguard::DaitaSettings::from)
+ .ok_or(FromProtobufTypeError::InvalidArgument(
+ "missing daita settings",
+ ))?,
},
generic: net::GenericTunnelOptions {
enable_ipv6: generic_options.enable_ipv6,
diff --git a/mullvad-management-interface/src/types/conversions/wireguard.rs b/mullvad-management-interface/src/types/conversions/wireguard.rs
index f35a8c7216..4a4341339c 100644
--- a/mullvad-management-interface/src/types/conversions/wireguard.rs
+++ b/mullvad-management-interface/src/types/conversions/wireguard.rs
@@ -71,3 +71,21 @@ impl TryFrom<proto::QuantumResistantState> for mullvad_types::wireguard::Quantum
}
}
}
+
+#[cfg(target_os = "windows")]
+impl From<mullvad_types::wireguard::DaitaSettings> for proto::DaitaSettings {
+ fn from(settings: mullvad_types::wireguard::DaitaSettings) -> Self {
+ proto::DaitaSettings {
+ enabled: settings.enabled,
+ }
+ }
+}
+
+#[cfg(target_os = "windows")]
+impl From<proto::DaitaSettings> for mullvad_types::wireguard::DaitaSettings {
+ fn from(settings: proto::DaitaSettings) -> Self {
+ mullvad_types::wireguard::DaitaSettings {
+ enabled: settings.enabled,
+ }
+ }
+}
diff --git a/mullvad-relay-selector/src/relay_selector/detailer.rs b/mullvad-relay-selector/src/relay_selector/detailer.rs
index 807903eb39..3dc1cc903e 100644
--- a/mullvad-relay-selector/src/relay_selector/detailer.rs
+++ b/mullvad-relay-selector/src/relay_selector/detailer.rs
@@ -84,6 +84,9 @@ fn wireguard_singlehop_endpoint(
allowed_ips: all_of_the_internet(),
// This will be filled in later, not the relay selector's problem
psk: None,
+ // This will be filled in later
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
};
Ok(MullvadWireguardEndpoint {
peer: peer_config,
@@ -122,6 +125,9 @@ fn wireguard_multihop_endpoint(
allowed_ips: all_of_the_internet(),
// This will be filled in later, not the relay selector's problem
psk: None,
+ // This will be filled in later
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
};
let entry_endpoint = {
@@ -137,6 +143,9 @@ fn wireguard_multihop_endpoint(
allowed_ips: vec![IpNetwork::from(exit.endpoint.ip())],
// This will be filled in later
psk: None,
+ // This will be filled in later
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
};
Ok(MullvadWireguardEndpoint {
diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs
index fbf3ac1c4b..7ec2fc02bb 100644
--- a/mullvad-relay-selector/src/relay_selector/mod.rs
+++ b/mullvad-relay-selector/src/relay_selector/mod.rs
@@ -99,6 +99,7 @@ pub struct RelaySelector {
pub struct SelectorConfig {
// Normal relay settings
pub relay_settings: RelaySettings,
+ pub additional_constraints: AdditionalRelayConstraints,
pub custom_lists: CustomListsSettings,
pub relay_overrides: Vec<RelayOverride>,
// Wireguard specific data
@@ -108,6 +109,20 @@ pub struct SelectorConfig {
pub bridge_settings: BridgeSettings,
}
+/// Extra relay constraints not specified in `relay_settings`.
+#[derive(Default, Debug, Clone, Eq, PartialEq)]
+pub struct AdditionalRelayConstraints {
+ pub wireguard: AdditionalWireguardConstraints,
+}
+
+/// Constraints to use when selecting WireGuard servers
+#[derive(Default, Debug, Clone, Eq, PartialEq)]
+pub struct AdditionalWireguardConstraints {
+ /// If true, select WireGuard relays that support DAITA. If false, select any
+ /// server.
+ pub daita: bool,
+}
+
/// Values which affect the choice of relay but are only known at runtime.
#[derive(Clone, Debug)]
pub struct RuntimeParameters {
@@ -273,6 +288,7 @@ impl Default for SelectorConfig {
let default_settings = Settings::default();
SelectorConfig {
relay_settings: default_settings.relay_settings,
+ additional_constraints: default_settings.additional_constraints,
bridge_settings: default_settings.bridge_settings,
obfuscation_settings: default_settings.obfuscation_settings,
bridge_state: default_settings.bridge_state,
diff --git a/mullvad-relay-selector/tests/relay_selector.rs b/mullvad-relay-selector/tests/relay_selector.rs
index ed3546c62e..40efae24e2 100644
--- a/mullvad-relay-selector/tests/relay_selector.rs
+++ b/mullvad-relay-selector/tests/relay_selector.rs
@@ -54,6 +54,7 @@ static RELAYS: Lazy<RelayList> = Lazy::new(|| RelayList {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
@@ -71,6 +72,7 @@ static RELAYS: Lazy<RelayList> = Lazy::new(|| RelayList {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
@@ -414,6 +416,7 @@ fn test_wireguard_entry() {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
@@ -431,6 +434,7 @@ fn test_wireguard_entry() {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
@@ -932,6 +936,7 @@ fn test_include_in_country() {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
@@ -949,6 +954,7 @@ fn test_include_in_country() {
"BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
)
.unwrap(),
+ daita: false,
}),
location: None,
},
diff --git a/mullvad-types/src/relay_list.rs b/mullvad-types/src/relay_list.rs
index 77cc621728..09ab312a5f 100644
--- a/mullvad-types/src/relay_list.rs
+++ b/mullvad-types/src/relay_list.rs
@@ -129,6 +129,7 @@ impl PartialEq for Relay {
/// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
/// # )
/// # .unwrap(),
+ /// # daita: false,
/// # }),
/// # location: None,
/// };
@@ -232,6 +233,9 @@ struct PortRange {
pub struct WireguardRelayEndpointData {
/// Public key used by the relay peer
pub public_key: wireguard::PublicKey,
+ /// Whether the server supports DAITA
+ #[serde(default)]
+ pub daita: bool,
}
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs
index c85860a776..d8e9bd403b 100644
--- a/mullvad-types/src/wireguard.rs
+++ b/mullvad-types/src/wireguard.rs
@@ -55,6 +55,12 @@ impl FromStr for QuantumResistantState {
#[error("Not a valid state")]
pub struct QuantumResistantStateParseError;
+#[cfg(target_os = "windows")]
+#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
+pub struct DaitaSettings {
+ pub enabled: bool,
+}
+
/// Contains account specific wireguard data
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct WireguardData {
@@ -201,6 +207,9 @@ pub struct TunnelOptions {
pub mtu: Option<u16>,
/// Obtain a PSK using the relay config client.
pub quantum_resistant: QuantumResistantState,
+ /// Configure DAITA
+ #[cfg(target_os = "windows")]
+ pub daita: DaitaSettings,
/// Interval used for automatic key rotation
#[cfg_attr(target_os = "android", jnix(skip))]
pub rotation_interval: Option<RotationInterval>,
@@ -212,6 +221,8 @@ impl Default for TunnelOptions {
TunnelOptions {
mtu: None,
quantum_resistant: QuantumResistantState::Auto,
+ #[cfg(target_os = "windows")]
+ daita: DaitaSettings::default(),
rotation_interval: None,
}
}
@@ -226,6 +237,8 @@ impl TunnelOptions {
QuantumResistantState::On => true,
QuantumResistantState::Off => false,
},
+ #[cfg(target_os = "windows")]
+ daita: self.daita.enabled,
}
}
}
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 0daa8b996c..ce860d0ed6 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -177,7 +177,6 @@ impl TunnelMonitor {
let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?;
let monitor = talpid_wireguard::WireguardMonitor::start(
config,
- params.options.quantum_resistant,
#[cfg(not(target_os = "android"))]
detect_mtu,
log.as_deref(),
diff --git a/talpid-tunnel-config-client/build.rs b/talpid-tunnel-config-client/build.rs
index 129cfd7154..aeb21fe009 100644
--- a/talpid-tunnel-config-client/build.rs
+++ b/talpid-tunnel-config-client/build.rs
@@ -1,3 +1,3 @@
fn main() {
- tonic_build::compile_protos("proto/tunnel_config.proto").unwrap();
+ tonic_build::compile_protos("proto/ephemeralpeer.proto").unwrap();
}
diff --git a/talpid-tunnel-config-client/examples/psk-exchange.rs b/talpid-tunnel-config-client/examples/psk-exchange.rs
index 87200d0336..e7b34ab851 100644
--- a/talpid-tunnel-config-client/examples/psk-exchange.rs
+++ b/talpid-tunnel-config-client/examples/psk-exchange.rs
@@ -18,14 +18,16 @@ async fn main() {
let pubkey = PublicKey::from_base64(pubkey_string.trim()).expect("Invalid public key");
let private_key = PrivateKey::new_from_random();
- let psk = talpid_tunnel_config_client::push_pq_key(
+ let ephemeral_peer = talpid_tunnel_config_client::request_ephemeral_peer(
tuncfg_server_ip,
pubkey,
private_key.public_key(),
+ true,
+ false,
)
.await
.unwrap();
println!("private key: {private_key:?}");
- println!("psk: {psk:?}");
+ println!("psk: {:?}", ephemeral_peer.psk);
}
diff --git a/talpid-tunnel-config-client/examples/tuncfg-server.rs b/talpid-tunnel-config-client/examples/tuncfg-server.rs
index 4d29d764a2..928587e968 100644
--- a/talpid-tunnel-config-client/examples/tuncfg-server.rs
+++ b/talpid-tunnel-config-client/examples/tuncfg-server.rs
@@ -1,80 +1,93 @@
-//! A server implementation of the tuncfg PskExchangeV1 RPC to test
-//! the client side implementation.
+//! A server implementation of the tuncfg RegisterPeerV1 RPC to test
+//! the client side implementation of PQ.
#[allow(clippy::derive_partial_eq_without_eq)]
mod proto {
- tonic::include_proto!("tunnel_config");
+ tonic::include_proto!("ephemeralpeer");
}
use classic_mceliece_rust::{PublicKey, CRYPTO_PUBLICKEYBYTES};
use proto::{
- post_quantum_secure_server::{PostQuantumSecure, PostQuantumSecureServer},
- PskRequestV1, PskResponseV1,
+ ephemeral_peer_server::{EphemeralPeer, EphemeralPeerServer},
+ EphemeralPeerRequestV1, EphemeralPeerResponseV1, PostQuantumResponseV1,
};
use talpid_types::net::wireguard::PresharedKey;
use tonic::{transport::Server, Request, Response, Status};
#[derive(Debug, Default)]
-pub struct PostQuantumSecureImpl {}
+pub struct EphemeralPeerImpl {}
#[tonic::async_trait]
-impl PostQuantumSecure for PostQuantumSecureImpl {
- async fn psk_exchange_v1(
+impl EphemeralPeer for EphemeralPeerImpl {
+ async fn register_peer_v1(
&self,
- request: Request<PskRequestV1>,
- ) -> Result<Response<PskResponseV1>, Status> {
+ request: Request<EphemeralPeerRequestV1>,
+ ) -> Result<Response<EphemeralPeerResponseV1>, Status> {
let mut rng = rand::thread_rng();
let request = request.into_inner();
- println!("wg_pubkey: {:?}", request.wg_pubkey);
- println!("wg_psk_pubkey: {:?}", request.wg_psk_pubkey);
+ println!("wg_parent_pubkey: {:?}", request.wg_parent_pubkey);
+ println!(
+ "wg_ephemeral_peer_pubkey: {:?}",
+ request.wg_ephemeral_peer_pubkey
+ );
+ println!("daita (no-op): {:?}", request.daita);
- // The ciphertexts that will be returned to the client
- let mut ciphertexts = Vec::new();
- // The final PSK that is computed by XORing together all the KEM outputs.
- let mut psk_data = Box::new([0u8; 32]);
+ let post_quantum = if let Some(post_quantum) = request.post_quantum {
+ // The ciphertexts that will be returned to the client
+ let mut ciphertexts = Vec::new();
- for kem_pubkey in request.kem_pubkeys {
- println!("\tKEM algorithm: {}", kem_pubkey.algorithm_name);
- let (ciphertext, shared_secret) = match kem_pubkey.algorithm_name.as_str() {
- "Classic-McEliece-460896f-round3" => {
- let key_data: [u8; CRYPTO_PUBLICKEYBYTES] =
- kem_pubkey.key_data.as_slice().try_into().unwrap();
- let public_key = PublicKey::from(&key_data);
- let (ciphertext, shared_secret) =
- classic_mceliece_rust::encapsulate_boxed(&public_key, &mut rng);
- (ciphertext.as_array().to_vec(), *shared_secret.as_array())
- }
- "Kyber1024" => {
- let public_key = kem_pubkey.key_data.as_slice();
- let (ciphertext, shared_secret) =
- pqc_kyber::encapsulate(public_key, &mut rng).unwrap();
- (ciphertext.to_vec(), shared_secret)
- }
- name => panic!("Unsupported KEM algorithm: {name}"),
- };
+ // The final PSK that is computed by XORing together all the KEM outputs.
+ let mut psk_data = Box::new([0u8; 32]);
+
+ for kem_pubkey in post_quantum.kem_pubkeys {
+ println!("\tKEM algorithm: {}", kem_pubkey.algorithm_name);
+ let (ciphertext, shared_secret) = match kem_pubkey.algorithm_name.as_str() {
+ "Classic-McEliece-460896f-round3" => {
+ let key_data: [u8; CRYPTO_PUBLICKEYBYTES] =
+ kem_pubkey.key_data.as_slice().try_into().unwrap();
+ let public_key = PublicKey::from(&key_data);
+ let (ciphertext, shared_secret) =
+ classic_mceliece_rust::encapsulate_boxed(&public_key, &mut rng);
+ (ciphertext.as_array().to_vec(), *shared_secret.as_array())
+ }
+ "Kyber1024" => {
+ let public_key = kem_pubkey.key_data.as_slice();
+ let (ciphertext, shared_secret) =
+ pqc_kyber::encapsulate(public_key, &mut rng).unwrap();
+ (ciphertext.to_vec(), shared_secret)
+ }
+ name => panic!("Unsupported KEM algorithm: {name}"),
+ };
- ciphertexts.push(ciphertext);
- println!("\tshared secret: {shared_secret:?}");
- for (psk_byte, shared_secret_byte) in psk_data.iter_mut().zip(shared_secret.iter()) {
- *psk_byte ^= shared_secret_byte;
+ ciphertexts.push(ciphertext);
+ println!("\tshared secret: {shared_secret:?}");
+ for (psk_byte, shared_secret_byte) in psk_data.iter_mut().zip(shared_secret.iter())
+ {
+ *psk_byte ^= shared_secret_byte;
+ }
}
- }
- let psk = PresharedKey::from(psk_data);
- println!("psk: {psk:?}");
- println!("==============================================");
- Ok(Response::new(PskResponseV1 { ciphertexts }))
+ let psk = PresharedKey::from(psk_data);
+ println!("psk: {psk:?}");
+ println!("==============================================");
+
+ Some(PostQuantumResponseV1 { ciphertexts })
+ } else {
+ None
+ };
+
+ Ok(Response::new(EphemeralPeerResponseV1 { post_quantum }))
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "127.0.0.1:1337".parse()?;
- let server = PostQuantumSecureImpl::default();
+ let server = EphemeralPeerImpl::default();
Server::builder()
- .add_service(PostQuantumSecureServer::new(server))
+ .add_service(EphemeralPeerServer::new(server))
.serve(addr)
.await?;
diff --git a/talpid-tunnel-config-client/proto/ephemeralpeer.proto b/talpid-tunnel-config-client/proto/ephemeralpeer.proto
new file mode 100644
index 0000000000..bb49eb5598
--- /dev/null
+++ b/talpid-tunnel-config-client/proto/ephemeralpeer.proto
@@ -0,0 +1,85 @@
+syntax = "proto3";
+
+option go_package = "github.com/mullvad/wg-manager/tuncfg/api/ephemeralpeer";
+
+package ephemeralpeer;
+
+service EphemeralPeer {
+ // Derive an ephemeral peer with one or several options enabled, such as PQ or DAITA.
+ //
+ // The VPN server associates the ephemeral peer with the peer who performed the exchange. Any
+ // already existing ephemeral peer for the normal peer is replaced. Each normal peer can have
+ // at most one ephemeral peer.
+ //
+ // The ephemeral peer is mutually exclusive to the normal peer. The server keeps both peers in
+ // memory, but only one of them is loaded into WireGuard at any point in time. A handshake from
+ // the normal peer unloads the corresponding ephemeral peer from WireGuard and vice versa.
+ //
+ // A new peer is negotiated to avoid a premature break of the tunnel used for negotiation.
+ // A tunnel would break prematurely if configuration such as preshared key were applied before the
+ // normal peer received the server's response. This cannot occur now because the client decides
+ // when to switch to the ephemeral tunnel. This design also allows the client to switch back to
+ // using a non-ephemeral tunnel at any point.
+ //
+ // The server gives no guarantees how long the ephemeral peer will be valid and working when it's
+ // no longer in use. The client should negotiate a new ephemeral peer every time it establishes a
+ // new tunnel to the server.
+ //
+ // The request from the VPN client should contain:
+ // * `wg_parent_pubkey` - The public key used by the current tunnel (that the request travels
+ // inside).
+ // * `wg_ephemeral_peer_pubkey` - A newly generated ephemeral WireGuard public key for the
+ // ephemeral peer. The server will associate the new configuration with this key.
+ // * One or more requests for different types of options. See the individual messages for more
+ // information. If a request is provided, a corresponding response may be returned in the
+ // server's response.
+ rpc RegisterPeerV1(EphemeralPeerRequestV1) returns (EphemeralPeerResponseV1) {}
+}
+
+message EphemeralPeerRequestV1 {
+ bytes wg_parent_pubkey = 1;
+ bytes wg_ephemeral_peer_pubkey = 2;
+ PostQuantumRequestV1 post_quantum = 3;
+ DaitaRequestV1 daita = 4;
+}
+
+// The v1 request supports exactly two algorithms.
+// The algorithms can appear soletary or in mixed order:
+// - "Classic-McEliece-460896f", but explicitly identified as "Classic-McEliece-460896f-round3"
+// - "Kyber1024"
+message PostQuantumRequestV1 { repeated KemPubkeyV1 kem_pubkeys = 1; }
+
+message KemPubkeyV1 {
+ string algorithm_name = 1;
+ bytes key_data = 2;
+}
+
+message DaitaRequestV1 { bool activate_daita = 1; }
+
+message EphemeralPeerResponseV1 {
+ // The response from the VPN server contains:
+ // * `ciphertexts` - A list of the ciphertexts (the encapsulated shared secrets) for all
+ // public keys in `kem_pubkeys` in the request, in the same order as in the request.
+ //
+ // # Deriving the WireGuard PSK
+ //
+ // The PSK to be used in WireGuard's preshared-key field is computed by XORing the resulting
+ // shared secrets of all the KEM algorithms. All currently supported and planned to be
+ // supported algorithms output 32 bytes, so this is trivial.
+ //
+ // Since the PSK provided to WireGuard is directly fed into a HKDF, it is not important that
+ // the entropy in the PSK is uniformly distributed. The actual keys used for encrypting the
+ // data channel will have uniformly distributed entropy anyway, thanks to the HKDF.
+ // But even if that was not true, since both CME and Kyber run SHAKE256 as the last step
+ // of their internal key derivation, the output they produce are uniformly distributed.
+ //
+ // If we later want to support another type of KEM that produce longer or shorter output,
+ // we can hash that secret into a 32 byte hash before proceeding to the XOR step.
+ //
+ // Mixing with XOR (A = B ^ C) is fine since nothing about A is revealed even if one of B or C
+ // is known. Both B *and* C must be known to compute any bit in A. This means all involved
+ // KEM algorithms must be broken before the PSK can be computed by an attacker.
+ PostQuantumResponseV1 post_quantum = 1;
+}
+
+message PostQuantumResponseV1 { repeated bytes ciphertexts = 1; }
diff --git a/talpid-tunnel-config-client/proto/tunnel_config.proto b/talpid-tunnel-config-client/proto/tunnel_config.proto
deleted file mode 100644
index e6e4b73d97..0000000000
--- a/talpid-tunnel-config-client/proto/tunnel_config.proto
+++ /dev/null
@@ -1,83 +0,0 @@
-syntax = "proto3";
-
-option go_package = "github.com/mullvad/wg-manager/server/tuncfg";
-
-package tunnel_config;
-
-service PostQuantumSecure {
- // Allows deriving a preshared key (PSK) using one or multiple PQ-secure key-encapsulation
- // mechanisms (KEM). The preshared key is added to WireGuard's preshared-key field in a new
- // ephemeral peer (PQ-peer). This makes the tunnel resistant towards attacks using
- // quantum computers.
- //
- // The VPN server associates the PQ-peer with the peer who performed the exchange. Any
- // already existing PQ-peer for the normal peer is replaced. Each normal peer can have
- // at most one PQ-peer.
- //
- // The PQ-peer is mutually exclusive to the normal peer. The server keeps both peers in memory,
- // but only one of them is loaded into WireGuard at any point in time. A handshake from the
- // normal peer unloads the corresponding PQ-peer from WireGuard and vice versa.
- //
- // A new peer is negotiated for PQ to avoid a premature break of the tunnel used for negotiation.
- // A tunnel would break prematurely if the preshared key is applied before the normal peer
- // received the server's contribution to the KEM exchange. This cannot occur now because
- // the client decides when to switch to the PQ-secure tunnel. This design also allows
- // the client to switch back to using a non-PQ-secure tunnel at any point.
- //
- // The negotiated PQ-peer is ephemeral. The server gives no guarantees how long it will be
- // valid and working. The client should negotiate a new PQ-peer every time it establishes a new
- // tunnel to the server.
- //
- // The full exchange requires just a single request-response round trip between the VPN client
- // and the VPN server.
- //
- // # Request-response format
- //
- // The request from the VPN client contains:
- // * `wg_pubkey` - The public key used by the current tunnel (that the request travels inside).
- // * `wg_psk_pubkey` - A newly generated ephemeral WireGuard public key for the PQ-peer.
- // The server will associate the derived PSK with this public key.
- // * `kem_pubkeys` - A list describing the KEM algorithms. Must have at least one entry.
- // The same KEM must not be listed more than once. Each list item contains:
- // * `algorithm_name` - The name of the KEM, including which variant. Should be the same
- // name/format that `liboqs` uses.
- // * `key_data` - The client's public key for this KEM. Will be used by the server to
- // encapsulate the shared secret for this KEM.
- //
- // The response from the VPN server contains:
- // * `ciphertexts` - A list of the ciphertexts (the encapsulated shared secrets) for all
- // public keys in `kem_pubkeys` in the request, in the same order as in the request.
- //
- // # Deriving the WireGuard PSK
- //
- // The PSK to be used in WireGuard's preshared-key field is computed by XORing the resulting
- // shared secrets of all the KEM algorithms. All currently supported and planned to be
- // supported algorithms output 32 bytes, so this is trivial.
- //
- // Since the PSK provided to WireGuard is directly fed into a HKDF, it is not important that
- // the entropy in the PSK is uniformly distributed. The actual keys used for encrypting the
- // data channel will have uniformly distributed entropy anyway, thanks to the HKDF.
- // But even if that was not true, since both CME and Kyber run SHAKE256 as the last step
- // of their internal key derivation, the output they produce are uniformly distributed.
- //
- // If we later want to support another type of KEM that produce longer or shorter output,
- // we can hash that secret into a 32 byte hash before proceeding to the XOR step.
- //
- // Mixing with XOR (A = B ^ C) is fine since nothing about A is revealed even if one of B or C
- // is known. Both B *and* C must be known to compute any bit in A. This means all involved
- // KEM algorithms must be broken before the PSK can be computed by an attacker.
- rpc PskExchangeV1(PskRequestV1) returns (PskResponseV1) {}
-}
-
-message PskRequestV1 {
- bytes wg_pubkey = 1;
- bytes wg_psk_pubkey = 2;
- repeated KemPubkeyV1 kem_pubkeys = 3;
-}
-
-message KemPubkeyV1 {
- string algorithm_name = 1;
- bytes key_data = 2;
-}
-
-message PskResponseV1 { repeated bytes ciphertexts = 1; }
diff --git a/talpid-tunnel-config-client/src/classic_mceliece.rs b/talpid-tunnel-config-client/src/classic_mceliece.rs
index a2d5dc0be2..2036bc3fc7 100644
--- a/talpid-tunnel-config-client/src/classic_mceliece.rs
+++ b/talpid-tunnel-config-client/src/classic_mceliece.rs
@@ -1,6 +1,5 @@
-use classic_mceliece_rust::{
- keypair_boxed, Ciphertext, PublicKey, SecretKey, SharedSecret, CRYPTO_CIPHERTEXTBYTES,
-};
+use classic_mceliece_rust::{keypair_boxed, Ciphertext, CRYPTO_CIPHERTEXTBYTES};
+pub use classic_mceliece_rust::{PublicKey, SecretKey, SharedSecret};
/// The `keypair_boxed` function needs just under 1 MiB of stack in debug
/// builds. Even though it probably works to run it directly on the main
diff --git a/talpid-tunnel-config-client/src/kyber.rs b/talpid-tunnel-config-client/src/kyber.rs
index 5654b2e10b..003c88dc48 100644
--- a/talpid-tunnel-config-client/src/kyber.rs
+++ b/talpid-tunnel-config-client/src/kyber.rs
@@ -1,6 +1,5 @@
-use pqc_kyber::{SecretKey, KYBER_CIPHERTEXTBYTES};
-
-pub use pqc_kyber::{keypair, KyberError};
+use pqc_kyber::KYBER_CIPHERTEXTBYTES;
+pub use pqc_kyber::{keypair, KyberError, SecretKey};
/// Use the strongest variant of Kyber. It is fast and the keys are small, so there is no practical
/// benefit of going with anything lower.
diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs
index 2c7b5e58f3..e89b1be42d 100644
--- a/talpid-tunnel-config-client/src/lib.rs
+++ b/talpid-tunnel-config-client/src/lib.rs
@@ -13,7 +13,7 @@ mod kyber;
#[allow(clippy::derive_partial_eq_without_eq)]
mod proto {
- tonic::include_proto!("tunnel_config");
+ tonic::include_proto!("ephemeralpeer");
}
use libc::setsockopt;
@@ -34,6 +34,7 @@ use sys::*;
pub enum Error {
GrpcConnectError(tonic::transport::Error),
GrpcError(tonic::Status),
+ MissingCiphertexts,
InvalidCiphertextLength {
algorithm: &'static str,
actual: usize,
@@ -51,6 +52,7 @@ impl std::fmt::Display for Error {
match self {
GrpcConnectError(_) => "Failed to connect to config service".fmt(f),
GrpcError(status) => write!(f, "RPC failed: {status}"),
+ MissingCiphertexts => write!(f, "Found no ciphertexts in response"),
InvalidCiphertextLength {
algorithm,
actual,
@@ -77,7 +79,7 @@ impl std::error::Error for Error {
}
}
-type RelayConfigService = proto::post_quantum_secure_client::PostQuantumSecureClient<Channel>;
+type RelayConfigService = proto::ephemeral_peer_client::EphemeralPeerClient<Channel>;
/// Port used by the tunnel config service.
pub const CONFIG_SERVICE_PORT: u16 = 1337;
@@ -93,72 +95,118 @@ pub const CONFIG_SERVICE_PORT: u16 = 1337;
/// handshake to work even if there is fragmentation.
const CONFIG_CLIENT_MTU: u16 = 576;
-/// Generates a new WireGuard key pair and negotiates a PSK with the relay in a PQ-safe
-/// manner. This creates a peer on the relay with the new WireGuard pubkey and PSK,
-/// which can then be used to establish a PQ-safe tunnel to the relay.
-// TODO: consider binding to the tunnel interface here, on non-windows platforms
-pub async fn push_pq_key(
+pub struct EphemeralPeer {
+ pub psk: Option<PresharedKey>,
+}
+
+/// Negotiate a short-lived peer with a PQ-safe PSK or with DAITA enabled.
+pub async fn request_ephemeral_peer(
+ service_address: IpAddr,
+ parent_pubkey: PublicKey,
+ ephemeral_pubkey: PublicKey,
+ enable_post_quantum: bool,
+ enable_daita: bool,
+) -> Result<EphemeralPeer, Error> {
+ request_ephemeral_peer_with_opts(
+ service_address,
+ parent_pubkey,
+ ephemeral_pubkey,
+ enable_post_quantum,
+ enable_daita,
+ )
+ .await
+}
+
+pub async fn request_ephemeral_peer_with_opts(
service_address: IpAddr,
- wg_pubkey: PublicKey,
- wg_psk_pubkey: PublicKey,
-) -> Result<PresharedKey, Error> {
- let (cme_kem_pubkey, cme_kem_secret) = classic_mceliece::generate_keys().await;
- let kyber_keypair = kyber::keypair(&mut rand::thread_rng());
+ parent_pubkey: PublicKey,
+ ephemeral_pubkey: PublicKey,
+ enable_post_quantum: bool,
+ enable_daita: bool,
+) -> Result<EphemeralPeer, Error> {
+ let (pq_request, kem_secrets) = if enable_post_quantum {
+ let (cme_kem_pubkey, cme_kem_secret) = classic_mceliece::generate_keys().await;
+ let kyber_keypair = kyber::keypair(&mut rand::thread_rng());
+
+ (
+ Some(proto::PostQuantumRequestV1 {
+ kem_pubkeys: vec![
+ proto::KemPubkeyV1 {
+ algorithm_name: classic_mceliece::ALGORITHM_NAME.to_owned(),
+ key_data: cme_kem_pubkey.as_array().to_vec(),
+ },
+ proto::KemPubkeyV1 {
+ algorithm_name: kyber::ALGORITHM_NAME.to_owned(),
+ key_data: kyber_keypair.public.to_vec(),
+ },
+ ],
+ }),
+ Some((cme_kem_secret, kyber_keypair.secret)),
+ )
+ } else {
+ (None, None)
+ };
+
+ let daita = Some(proto::DaitaRequestV1 {
+ activate_daita: enable_daita,
+ });
let mut client = new_client(service_address).await?;
let response = client
- .psk_exchange_v1(proto::PskRequestV1 {
- wg_pubkey: wg_pubkey.as_bytes().to_vec(),
- wg_psk_pubkey: wg_psk_pubkey.as_bytes().to_vec(),
- kem_pubkeys: vec![
- proto::KemPubkeyV1 {
- algorithm_name: classic_mceliece::ALGORITHM_NAME.to_owned(),
- key_data: cme_kem_pubkey.as_array().to_vec(),
- },
- proto::KemPubkeyV1 {
- algorithm_name: kyber::ALGORITHM_NAME.to_owned(),
- key_data: kyber_keypair.public.to_vec(),
- },
- ],
+ .register_peer_v1(proto::EphemeralPeerRequestV1 {
+ wg_parent_pubkey: parent_pubkey.as_bytes().to_vec(),
+ wg_ephemeral_peer_pubkey: ephemeral_pubkey.as_bytes().to_vec(),
+ post_quantum: pq_request,
+ daita,
})
.await
.map_err(Error::GrpcError)?;
- let ciphertexts = response.into_inner().ciphertexts;
+ let psk = if let Some((cme_kem_secret, kyber_secret)) = kem_secrets {
+ let ciphertexts = response
+ .into_inner()
+ .post_quantum
+ .ok_or(Error::MissingCiphertexts)?
+ .ciphertexts;
- // Unpack the ciphertexts into one per KEM without needing to access them by index.
- let [cme_ciphertext, kyber_ciphertext] = <&[Vec<u8>; 2]>::try_from(ciphertexts.as_slice())
- .map_err(|_| Error::InvalidCiphertextCount {
- actual: ciphertexts.len(),
- })?;
+ // Unpack the ciphertexts into one per KEM without needing to access them by index.
+ let [cme_ciphertext, kyber_ciphertext] = <&[Vec<u8>; 2]>::try_from(ciphertexts.as_slice())
+ .map_err(|_| Error::InvalidCiphertextCount {
+ actual: ciphertexts.len(),
+ })?;
- // Store the PSK data on the heap. So it can be passed around and then zeroized on drop without
- // being stored in a bunch of places on the stack.
- let mut psk_data = Box::new([0u8; 32]);
+ // Store the PSK data on the heap. So it can be passed around and then zeroized on drop without
+ // being stored in a bunch of places on the stack.
+ let mut psk_data = Box::new([0u8; 32]);
- // Decapsulate Classic McEliece and mix into PSK
- {
- let mut shared_secret = classic_mceliece::decapsulate(&cme_kem_secret, cme_ciphertext)?;
- xor_assign(&mut psk_data, shared_secret.as_array());
+ // Decapsulate Classic McEliece and mix into PSK
+ {
+ let mut shared_secret = classic_mceliece::decapsulate(&cme_kem_secret, cme_ciphertext)?;
+ xor_assign(&mut psk_data, shared_secret.as_array());
- // This should happen automatically due to `SharedSecret` implementing ZeroizeOnDrop. But
- // doing it explicitly provides a stronger guarantee that it's not accidentally
- // removed.
- shared_secret.zeroize();
- }
- // Decapsulate Kyber and mix into PSK
- {
- let mut shared_secret = kyber::decapsulate(kyber_keypair.secret, kyber_ciphertext)?;
- xor_assign(&mut psk_data, &shared_secret);
+ // This should happen automatically due to `SharedSecret` implementing ZeroizeOnDrop. But
+ // doing it explicitly provides a stronger guarantee that it's not accidentally
+ // removed.
+ shared_secret.zeroize();
+ }
+ // Decapsulate Kyber and mix into PSK
+ {
+ let mut shared_secret = kyber::decapsulate(kyber_secret, kyber_ciphertext)?;
+ xor_assign(&mut psk_data, &shared_secret);
- // The shared secret is sadly stored in an array on the stack. So we can't get any
- // guarantees that it's not copied around on the stack. The best we can do here
- // is to zero out the version we have and hope the compiler optimizes out copies.
- // https://github.com/Argyle-Software/kyber/issues/59
- shared_secret.zeroize();
- }
+ // The shared secret is sadly stored in an array on the stack. So we can't get any
+ // guarantees that it's not copied around on the stack. The best we can do here
+ // is to zero out the version we have and hope the compiler optimizes out copies.
+ // https://github.com/Argyle-Software/kyber/issues/59
+ shared_secret.zeroize();
+ }
+
+ Some(PresharedKey::from(psk_data))
+ } else {
+ None
+ };
- Ok(PresharedKey::from(psk_data))
+ Ok(EphemeralPeer { psk })
}
/// Performs `dst = dst ^ src`.
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 2b9f0a7366..a17d8ceb5f 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -40,6 +40,8 @@ impl TunnelParameters {
obfuscation: None,
entry_endpoint: None,
tunnel_interface: None,
+ #[cfg(target_os = "windows")]
+ daita: false,
},
TunnelParameters::Wireguard(params) => TunnelEndpoint {
tunnel_type: TunnelType::Wireguard,
@@ -55,6 +57,8 @@ impl TunnelParameters {
.get_exit_endpoint()
.map(|_| params.connection.get_endpoint()),
tunnel_interface: None,
+ #[cfg(target_os = "windows")]
+ daita: params.options.daita,
},
}
}
@@ -183,6 +187,8 @@ pub struct TunnelEndpoint {
pub entry_endpoint: Option<Endpoint>,
#[cfg_attr(target_os = "android", jnix(skip))]
pub tunnel_interface: Option<String>,
+ #[cfg(target_os = "windows")]
+ pub daita: bool,
}
impl fmt::Display for TunnelEndpoint {
diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs
index db7b2da3a9..f7212236e2 100644
--- a/talpid-types/src/net/wireguard.rs
+++ b/talpid-types/src/net/wireguard.rs
@@ -60,6 +60,10 @@ pub struct PeerConfig {
/// ephemeral and living in memory only.
#[serde(skip)]
pub psk: Option<PresharedKey>,
+ /// Enable constant packet sizes for `entry_peer``
+ #[cfg(target_os = "windows")]
+ #[serde(skip)]
+ pub constant_packet_size: bool,
}
#[derive(Clone, Eq, PartialEq, Deserialize, Serialize, Debug)]
@@ -76,6 +80,9 @@ pub struct TunnelOptions {
pub mtu: Option<u16>,
/// Perform PQ-safe PSK exchange when connecting
pub quantum_resistant: bool,
+ /// Enable DAITA during tunnel config
+ #[cfg(target_os = "windows")]
+ pub daita: bool,
}
/// Wireguard x25519 private key
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index 1fc6e13b3a..c2562d212c 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -12,6 +12,7 @@ workspace = true
[dependencies]
thiserror = { workspace = true }
+base64 = "0.13"
futures = "0.3.15"
hex = "0.4"
ipnetwork = "0.16"
@@ -54,6 +55,7 @@ talpid-dbus = { path = "../talpid-dbus" }
bitflags = "1.2"
talpid-windows = { path = "../talpid-windows" }
widestring = "1.0"
+maybenot = "1.0"
# TODO: Figure out which features are needed and which are not
[target.'cfg(windows)'.dependencies.windows-sys]
diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs
index 29328eb681..f10a0e4859 100644
--- a/talpid-wireguard/src/config.rs
+++ b/talpid-wireguard/src/config.rs
@@ -28,6 +28,10 @@ pub struct Config {
pub enable_ipv6: bool,
/// Obfuscator config to be used for reaching the relay.
pub obfuscator_config: Option<ObfuscatorConfig>,
+ /// Enable quantum-resistant PSK exchange
+ pub quantum_resistant: bool,
+ /// Enable DAITA
+ pub daita: bool,
}
/// Configuration errors
@@ -92,6 +96,11 @@ impl Config {
#[cfg(target_os = "linux")]
enable_ipv6: generic_options.enable_ipv6,
obfuscator_config: obfuscator_config.to_owned(),
+ quantum_resistant: wg_options.quantum_resistant,
+ #[cfg(target_os = "windows")]
+ daita: wg_options.daita,
+ #[cfg(not(target_os = "windows"))]
+ daita: false,
};
for peer in config.peers_mut() {
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs
index e820a3eb73..70f88e6872 100644
--- a/talpid-wireguard/src/connectivity_check.rs
+++ b/talpid-wireguard/src/connectivity_check.rs
@@ -612,6 +612,11 @@ mod test {
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
Box::pin(async { Ok(()) })
}
+
+ #[cfg(target_os = "windows")]
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
+ Ok(())
+ }
}
fn mock_monitor(
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 72d6a31566..bdef52343e 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -260,7 +260,6 @@ impl WireguardMonitor {
+ 'static,
>(
mut config: Config,
- psk_negotiation: bool,
#[cfg(not(target_os = "android"))] detect_mtu: bool,
log_path: Option<&Path>,
args: TunnelArgs<'_, F>,
@@ -283,12 +282,12 @@ impl WireguardMonitor {
log_path,
args.resource_dir,
args.tun_provider.clone(),
+ #[cfg(target_os = "android")]
+ config.quantum_resistant,
#[cfg(target_os = "windows")]
args.route_manager.clone(),
#[cfg(target_os = "windows")]
setup_done_tx,
- #[cfg(target_os = "android")]
- psk_negotiation,
)?;
let iface_name = tunnel.get_interface_name();
@@ -336,7 +335,7 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- let allowed_traffic = if psk_negotiation {
+ let allowed_traffic = if config.quantum_resistant || config.daita {
AllowedTunnelTraffic::One(Endpoint::new(
config.ipv4_gateway,
talpid_tunnel_config_client::CONFIG_SERVICE_PORT,
@@ -365,16 +364,16 @@ impl WireguardMonitor {
.map_err(Error::SetupRoutingError)
.map_err(CloseMsg::SetupError)?;
- let psk_obfs_sender = close_obfs_sender.clone();
- if psk_negotiation {
- Self::psk_negotiation(
+ let ephemeral_obfs_sender = close_obfs_sender.clone();
+ if config.quantum_resistant || config.daita {
+ Self::config_ephemeral_peers(
&tunnel,
&mut config,
args.retry_attempt,
args.on_event.clone(),
&iface_name,
obfuscator.clone(),
- psk_obfs_sender,
+ ephemeral_obfs_sender,
#[cfg(target_os = "android")]
args.tun_provider,
)
@@ -386,6 +385,14 @@ impl WireguardMonitor {
let config = config.clone();
let iface_name = iface_name.clone();
tokio::task::spawn(async move {
+ #[cfg(target_os = "windows")]
+ if config.daita {
+ // TODO: For now, we assume the MTU during the tunnel lifetime.
+ // We could instead poke maybenot whenever we detect changes to it.
+ log::warn!("MTU detection is not supported with DAITA. Skipping");
+ return;
+ }
+
if let Err(e) = mtu_detection::automatic_mtu_correction(
gateway,
iface_name,
@@ -465,7 +472,7 @@ impl WireguardMonitor {
}
#[allow(clippy::too_many_arguments)]
- async fn psk_negotiation<F>(
+ async fn config_ephemeral_peers<F>(
tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
config: &mut Config,
retry_attempt: u32,
@@ -482,7 +489,7 @@ impl WireguardMonitor {
+ Clone
+ 'static,
{
- let wg_psk_privkey = PrivateKey::new_from_random();
+ let ephemeral_private_key = PrivateKey::new_from_random();
let close_obfs_sender = close_obfs_sender.clone();
let allowed_traffic = Endpoint::new(
@@ -507,11 +514,17 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(iface_name, config);
(on_event)(TunnelEvent::InterfaceUp(metadata, allowed_traffic.clone())).await;
- let exit_psk =
- Self::perform_psk_negotiation(retry_attempt, config, wg_psk_privkey.public_key())
- .await?;
+ let exit_should_have_daita = config.daita && !config.is_multihop();
+ let exit_psk = Self::request_ephemeral_peer(
+ retry_attempt,
+ config,
+ ephemeral_private_key.public_key(),
+ config.quantum_resistant,
+ exit_should_have_daita,
+ )
+ .await?;
- log::debug!("Successfully exchanged PSK with exit peer");
+ log::debug!("Retrieved ephemeral peer");
if config.is_multihop() {
// Set up tunnel to lead to entry
@@ -531,22 +544,27 @@ impl WireguardMonitor {
&tun_provider,
)
.await?;
- let entry_psk = Some(
- Self::perform_psk_negotiation(
- retry_attempt,
- &entry_config,
- wg_psk_privkey.public_key(),
- )
- .await?,
- );
+ let entry_psk = Self::request_ephemeral_peer(
+ retry_attempt,
+ &entry_config,
+ ephemeral_private_key.public_key(),
+ config.quantum_resistant,
+ config.daita,
+ )
+ .await?;
log::debug!("Successfully exchanged PSK with entry peer");
config.entry_peer.psk = entry_psk;
}
- config.exit_peer_mut().psk = Some(exit_psk);
+ config.exit_peer_mut().psk = exit_psk;
+ #[cfg(target_os = "windows")]
+ if config.daita {
+ log::trace!("Enabling constant packet size for entry peer");
+ config.entry_peer.constant_packet_size = true;
+ }
- config.tunnel.private_key = wg_psk_privkey;
+ config.tunnel.private_key = ephemeral_private_key;
*config = Self::reconfigure_tunnel(
tunnel,
@@ -557,6 +575,19 @@ impl WireguardMonitor {
&tun_provider,
)
.await?;
+
+ #[cfg(target_os = "windows")]
+ if config.daita {
+ // Start local DAITA machines
+ let mut tunnel = tunnel.lock().unwrap();
+ if let Some(tunnel) = tunnel.as_mut() {
+ tunnel
+ .start_daita()
+ .map_err(Error::TunnelError)
+ .map_err(CloseMsg::SetupError)?;
+ }
+ }
+
let metadata = Self::tunnel_metadata(iface_name, config);
(on_event)(TunnelEvent::InterfaceUp(
metadata,
@@ -678,12 +709,14 @@ impl WireguardMonitor {
Ok(())
}
- async fn perform_psk_negotiation(
+ async fn request_ephemeral_peer(
retry_attempt: u32,
config: &Config,
wg_psk_pubkey: PublicKey,
- ) -> std::result::Result<PresharedKey, CloseMsg> {
- log::debug!("Performing PQ-safe PSK exchange");
+ enable_pq: bool,
+ enable_daita: bool,
+ ) -> std::result::Result<Option<PresharedKey>, CloseMsg> {
+ log::debug!("Requesting ephemeral peer");
let timeout = std::cmp::min(
MAX_PSK_EXCHANGE_TIMEOUT,
@@ -691,23 +724,25 @@ impl WireguardMonitor {
.saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)),
);
- let psk = tokio::time::timeout(
+ let ephemeral = tokio::time::timeout(
timeout,
- talpid_tunnel_config_client::push_pq_key(
+ talpid_tunnel_config_client::request_ephemeral_peer_with_opts(
IpAddr::from(config.ipv4_gateway),
config.tunnel.private_key.public_key(),
wg_psk_pubkey,
+ enable_pq,
+ enable_daita,
),
)
.await
.map_err(|_timeout_err| {
- log::warn!("Timeout while negotiating PSK");
+ log::warn!("Timeout while negotiating ephemeral peer");
CloseMsg::PskNegotiationTimeout
})?
.map_err(Error::PskNegotiationError)
.map_err(CloseMsg::SetupError)?;
- Ok(psk)
+ Ok(ephemeral.psk)
}
#[allow(unused_variables)]
@@ -717,7 +752,7 @@ impl WireguardMonitor {
log_path: Option<&Path>,
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
- #[cfg(target_os = "android")] psk_negotiation: bool,
+ #[cfg(target_os = "android")] gateway_only: bool,
#[cfg(windows)] route_manager: crate::routing::RouteManagerHandle,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Box<dyn Tunnel>> {
@@ -771,7 +806,7 @@ impl WireguardMonitor {
Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes);
#[cfg(target_os = "android")]
- let config = Self::patch_allowed_ips(config, psk_negotiation);
+ let config = Self::patch_allowed_ips(config, gateway_only);
#[cfg(target_os = "linux")]
log::debug!("Using userspace WireGuard implementation");
@@ -994,6 +1029,8 @@ pub(crate) trait Tunnel: Send {
&self,
_config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>;
+ #[cfg(target_os = "windows")]
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError>;
}
/// Errors to be returned from WireGuard implementations, namely implementers of the Tunnel trait
diff --git a/talpid-wireguard/src/wireguard_nt/daita.rs b/talpid-wireguard/src/wireguard_nt/daita.rs
new file mode 100644
index 0000000000..75d6ebaa4d
--- /dev/null
+++ b/talpid-wireguard/src/wireguard_nt/daita.rs
@@ -0,0 +1,450 @@
+use super::WIREGUARD_KEY_LENGTH;
+use maybenot::framework::MachineId;
+use once_cell::sync::OnceCell;
+use std::{collections::HashMap, fs, io, path::Path, time::Duration};
+use std::{os::windows::prelude::RawHandle, sync::Arc};
+use talpid_types::net::wireguard::PublicKey;
+use tokio::task::JoinHandle;
+use windows_sys::Win32::Foundation::BOOLEAN;
+use windows_sys::Win32::{
+ Foundation::ERROR_NO_MORE_ITEMS,
+ System::Threading::{WaitForMultipleObjects, WaitForSingleObject, INFINITE},
+};
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ /// Failed to find maybenot machines
+ #[error("Failed to enumerate maybenot machines")]
+ EnumerateMachines(#[source] io::Error),
+ /// Failed to parse maybenot machine
+ #[error("Failed to parse maybenot machine \"{0}\"")]
+ InvalidMachine(String),
+ /// Failed to initialize quit event
+ #[error("Failed to initialize quit event")]
+ InitializeQuitEvent(#[source] io::Error),
+ /// Failed to initialize machinist handle
+ #[error("Failed to initialize machinist handle")]
+ InitializeHandle(#[source] io::Error),
+ /// Failed to initialize maybenot framework
+ #[error("Failed to initialize maybenot framework: {0}")]
+ InitializeMaybenot(String),
+}
+
+// See DAITA_EVENT_TYPE:
+// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h
+#[repr(C)]
+#[derive(Debug)]
+#[allow(dead_code)]
+pub enum EventType {
+ NonpaddingSent,
+ NonpaddingReceived,
+ PaddingSent,
+ PaddingReceived,
+}
+
+// See DAITA_EVENT:
+// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h
+#[repr(C)]
+#[derive(Debug)]
+pub struct Event {
+ pub peer: [u8; WIREGUARD_KEY_LENGTH],
+ pub event_type: EventType,
+ pub xmit_bytes: u16,
+ pub user_context: usize,
+}
+
+// See DAITA_ACTION_TYPE:
+// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h
+#[repr(C)]
+pub enum ActionType {
+ InjectPadding,
+}
+
+// See DAITA_PADDING_ACTION:
+// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct PaddingAction {
+ pub byte_count: u16,
+ pub replace: BOOLEAN,
+}
+
+// See DAITA_ACTION:
+// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h
+#[repr(C)]
+pub struct Action {
+ pub peer: [u8; WIREGUARD_KEY_LENGTH],
+ pub action_type: ActionType,
+ pub payload: ActionPayload,
+ pub user_context: usize,
+}
+
+#[repr(C)]
+pub union ActionPayload {
+ pub padding: PaddingAction,
+}
+
+/// Maximum number of events that can be stored in the underlying buffer
+const EVENTS_CAPACITY: usize = 1000;
+/// Maximum number of actions that can be stored in the underlying buffer
+const ACTIONS_CAPACITY: usize = 1000;
+
+pub mod bindings {
+ use super::*;
+ use windows_sys::Win32::Foundation::BOOL;
+
+ pub type WireGuardDaitaActivateFn = unsafe extern "stdcall" fn(
+ adapter: RawHandle,
+ events_capacity: usize,
+ actions_capacity: usize,
+ ) -> BOOL;
+ pub type WireGuardDaitaEventDataAvailableEventFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle) -> RawHandle;
+ pub type WireGuardDaitaReceiveEventsFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, events: *mut Event) -> usize;
+ pub type WireGuardDaitaSendActionFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, action: *const Action) -> BOOL;
+}
+
+#[derive(Debug)]
+pub struct Session {
+ adapter: Arc<super::WgNtAdapter>,
+}
+
+impl Session {
+ /// Call `WireGuardDaitaActivate` for an existing WireGuard interface
+ pub(super) fn from_adapter(adapter: Arc<super::WgNtAdapter>) -> io::Result<Session> {
+ // SAFETY: `WgNtAdapter` has a valid adapter handle
+ unsafe {
+ adapter
+ .dll_handle
+ .daita_activate(adapter.handle, EVENTS_CAPACITY, ACTIONS_CAPACITY)
+ }?;
+ Ok(Self { adapter })
+ }
+
+ pub fn receive_events<'a>(
+ &self,
+ buffer: &'a mut [Event; EVENTS_CAPACITY],
+ ) -> io::Result<&'a [Event]> {
+ let num_events = unsafe {
+ // SAFETY: The adapter is valid, and the buffer is large enough to accommodate all
+ // events.
+ self.adapter
+ .dll_handle
+ .daita_receive_events(self.adapter.handle, buffer.as_mut_ptr())?
+ };
+ Ok(unsafe { std::slice::from_raw_parts(buffer.as_ptr(), num_events) })
+ }
+
+ pub fn send_action(&self, action: &Action) -> io::Result<()> {
+ // SAFETY: The adapter is valid
+ unsafe {
+ self.adapter
+ .dll_handle
+ .daita_send_action(self.adapter.handle, action)
+ }
+ }
+
+ pub fn event_data_available_event(&self) -> RawHandle {
+ // SAFETY: The adapter is valid
+ // This never fails when there's a DAITA session
+ unsafe {
+ self.adapter
+ .dll_handle
+ .daita_event_data_available_event(self.adapter.handle)
+ .unwrap()
+ }
+ }
+}
+
+fn maybenot_event_from_event(
+ event: &Event,
+ machine_ids: &MachineMap,
+ override_size: Option<u16>,
+) -> Option<maybenot::framework::TriggerEvent> {
+ let xmit_bytes = override_size.unwrap_or(event.xmit_bytes);
+ match event.event_type {
+ EventType::PaddingReceived => Some(maybenot::framework::TriggerEvent::PaddingRecv {
+ bytes_recv: xmit_bytes,
+ }),
+ EventType::NonpaddingSent => Some(maybenot::framework::TriggerEvent::NonPaddingSent {
+ bytes_sent: xmit_bytes,
+ }),
+ EventType::NonpaddingReceived => Some(maybenot::framework::TriggerEvent::NonPaddingRecv {
+ bytes_recv: xmit_bytes,
+ }),
+ EventType::PaddingSent => Some(maybenot::framework::TriggerEvent::PaddingSent {
+ bytes_sent: xmit_bytes,
+ machine: machine_ids.get_machine_id(event.user_context)?.to_owned(),
+ }),
+ }
+}
+
+/// Handle for a set of DAITA machines.
+/// Note: `close` is NOT called implicitly when this is dropped.
+pub struct MachinistHandle {
+ quit_event: talpid_windows::sync::Event,
+}
+
+impl MachinistHandle {
+ fn new(quit_event: &talpid_windows::sync::Event) -> io::Result<MachinistHandle> {
+ Ok(MachinistHandle {
+ quit_event: quit_event.duplicate()?,
+ })
+ }
+
+ /// Signal quit event
+ pub fn close(&self) -> io::Result<()> {
+ self.quit_event.set()
+ }
+}
+
+pub struct Machinist {
+ daita: Arc<Session>,
+ machine_ids: MachineMap,
+ machine_tasks: HashMap<usize, JoinHandle<()>>,
+ tokio_handle: tokio::runtime::Handle,
+ quit_event: talpid_windows::sync::Event,
+ peer: PublicKey,
+ override_size: Option<u16>,
+}
+
+// TODO: This is silly. Let me use the raw ID of MachineId, please.
+struct MachineMap {
+ id_to_num: HashMap<MachineId, usize>,
+ num_to_id: HashMap<usize, MachineId>,
+}
+
+impl MachineMap {
+ fn new() -> Self {
+ Self {
+ id_to_num: HashMap::new(),
+ num_to_id: HashMap::new(),
+ }
+ }
+
+ fn get_or_create_raw_id(&mut self, machine_id: MachineId) -> usize {
+ *self.id_to_num.entry(machine_id).or_insert_with(|| {
+ let raw_id = self.num_to_id.len();
+ self.num_to_id.insert(raw_id, machine_id);
+ raw_id
+ })
+ }
+
+ fn get_machine_id(&self, raw_id: usize) -> Option<&MachineId> {
+ self.num_to_id.get(&raw_id)
+ }
+}
+
+impl Machinist {
+ /// Spawn an actor that handles scheduling of Maybenot actions and forwards DAITA events to the framework.
+ pub fn spawn(
+ resource_dir: &Path,
+ daita: Session,
+ peer: PublicKey,
+ mtu: u16,
+ ) -> std::result::Result<MachinistHandle, Error> {
+ const MAX_PADDING_BYTES: f64 = 0.0;
+ const MAX_BLOCKING_BYTES: f64 = 0.0;
+
+ static MAYBENOT_MACHINES: OnceCell<Vec<maybenot::machine::Machine>> = OnceCell::new();
+
+ let machines = MAYBENOT_MACHINES.get_or_try_init(|| {
+ let path = resource_dir.join("maybenot_machines");
+ log::debug!("Reading maybenot machines from {}", path.display());
+
+ let mut machines = vec![];
+ let machines_str = fs::read_to_string(path).map_err(Error::EnumerateMachines)?;
+ for machine_str in machines_str.lines() {
+ let machine_str = machine_str.trim();
+ if matches!(machine_str.chars().next(), None | Some('#')) {
+ continue;
+ }
+ log::debug!("Adding maybenot machine: {machine_str}");
+ machines.push(
+ machine_str
+ .parse::<maybenot::machine::Machine>()
+ .map_err(|_error| Error::InvalidMachine(machine_str.to_owned()))?,
+ );
+ }
+ Ok(machines)
+ })?;
+
+ let quit_event =
+ talpid_windows::sync::Event::new(true, false).map_err(Error::InitializeQuitEvent)?;
+ let handle = MachinistHandle::new(&quit_event).map_err(Error::InitializeHandle)?;
+
+ let framework = maybenot::framework::Framework::new(
+ machines.clone(),
+ MAX_PADDING_BYTES,
+ MAX_BLOCKING_BYTES,
+ mtu,
+ std::time::Instant::now(),
+ )
+ .map_err(|error| Error::InitializeMaybenot(error.to_string()))?;
+
+ let daita = Arc::new(daita);
+ let tokio_handle = tokio::runtime::Handle::current();
+
+ std::thread::spawn(move || {
+ Self {
+ daita,
+ machine_ids: MachineMap::new(),
+ machine_tasks: HashMap::new(),
+ tokio_handle,
+ quit_event,
+ peer,
+ // TODO: We're assuming that constant packet size is always enabled here
+ override_size: Some(mtu),
+ }
+ .event_loop(framework);
+ });
+
+ Ok(handle)
+ }
+
+ fn event_loop(
+ mut self,
+ mut framework: maybenot::framework::Framework<Vec<maybenot::machine::Machine>>,
+ ) {
+ use windows_sys::Win32::Foundation::WAIT_OBJECT_0;
+
+ loop {
+ if unsafe { WaitForSingleObject(self.quit_event.as_raw(), 0) } == WAIT_OBJECT_0 {
+ break;
+ }
+
+ let events = match self.wait_for_events() {
+ Ok(events) => {
+ if events.is_empty() {
+ break;
+ }
+ events
+ }
+ Err(error) => {
+ log::error!("Error while waiting for DAITA events: {error}");
+ break;
+ }
+ };
+
+ for action in framework.trigger_events(&events, std::time::Instant::now()) {
+ self.handle_action(action);
+ }
+ }
+
+ log::debug!("Stopped DAITA event loop");
+ }
+
+ fn handle_action(&mut self, action: &maybenot::framework::Action) {
+ match *action {
+ maybenot::framework::Action::Cancel { machine } => {
+ let raw_id = self.machine_ids.get_or_create_raw_id(machine);
+
+ // Drop all scheduled actions for a given machine
+ if let Some(task) = self.machine_tasks.get_mut(&raw_id) {
+ task.abort();
+ }
+ }
+ maybenot::framework::Action::InjectPadding {
+ timeout,
+ size,
+ machine,
+ replace,
+ ..
+ } => {
+ let peer = self.peer.clone();
+
+ let raw_id = self.machine_ids.get_or_create_raw_id(machine);
+ self.machine_tasks.entry(raw_id).and_modify(|f| f.abort());
+
+ let action = Action {
+ peer: *peer.as_bytes(),
+ action_type: ActionType::InjectPadding,
+ user_context: raw_id,
+ payload: ActionPayload {
+ padding: PaddingAction {
+ byte_count: size,
+ replace: if replace { 1 } else { 0 },
+ },
+ },
+ };
+
+ if timeout == Duration::ZERO {
+ if let Err(error) = self.daita.send_action(&action) {
+ log::error!("Failed to send DAITA action: {error}");
+ }
+ } else {
+ // Schedule action on the tokio runtime
+ let daita = Arc::downgrade(&self.daita);
+ let task = self.tokio_handle.spawn(async move {
+ tokio::time::sleep(timeout).await;
+
+ let Some(daita) = daita.upgrade() else { return };
+
+ if let Err(error) = daita.send_action(&action) {
+ log::error!("Failed to send DAITA action: {error}");
+ }
+ });
+ self.machine_tasks.insert(raw_id, task);
+ }
+ }
+ maybenot::framework::Action::BlockOutgoing { .. } => {}
+ }
+ }
+
+ /// Take all events from the ring buffer while there are any left.
+ /// If there are no events available, wait for events to arrive.
+ /// Otherwise, break and return a non-zero number of events to be processed.
+ /// If the quit event was signaled, this returns an empty vector.
+ fn wait_for_events(&mut self) -> io::Result<Vec<maybenot::framework::TriggerEvent>> {
+ use windows_sys::Win32::Foundation::WAIT_OBJECT_0;
+
+ let wait_events = [
+ self.quit_event.as_raw(),
+ self.daita.event_data_available_event() as isize,
+ ];
+
+ let mut event_buffer: [Event; EVENTS_CAPACITY] = unsafe { std::mem::zeroed() };
+
+ loop {
+ match self.daita.receive_events(&mut event_buffer) {
+ Ok(events) => {
+ let converted_events: Vec<_> = events
+ .iter()
+ .filter(|event| &event.peer == self.peer.as_bytes())
+ .filter_map(|event| {
+ maybenot_event_from_event(event, &self.machine_ids, self.override_size)
+ })
+ .collect();
+ if !converted_events.is_empty() {
+ return Ok(converted_events);
+ }
+ // Try again if we only received events for irrelevant peers
+ }
+ Err(error) => {
+ if error.raw_os_error() == Some(ERROR_NO_MORE_ITEMS as i32) {
+ let wait_result = unsafe {
+ WaitForMultipleObjects(
+ u32::try_from(wait_events.len()).unwrap(),
+ wait_events.as_ptr(),
+ 0,
+ INFINITE,
+ )
+ };
+
+ if wait_result == WAIT_OBJECT_0 {
+ // Quit event signaled
+ break Ok(vec![]);
+ }
+ if wait_result == WAIT_OBJECT_0 + 1 {
+ // Event object signaled -- try to receive more events
+ continue;
+ }
+ }
+ break Err(std::io::Error::last_os_error());
+ }
+ }
+ }
+ }
+}
diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt/mod.rs
index 10cc45b384..6acd6dc710 100644
--- a/talpid-wireguard/src/wireguard_nt.rs
+++ b/talpid-wireguard/src/wireguard_nt/mod.rs
@@ -9,14 +9,14 @@ use futures::SinkExt;
use ipnetwork::IpNetwork;
use once_cell::sync::{Lazy, OnceCell};
use std::{
- ffi::CStr,
+ ffi::{c_uchar, CStr},
fmt,
future::Future,
- io, mem,
- mem::MaybeUninit,
+ io,
+ mem::{self, MaybeUninit},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
os::windows::io::RawHandle,
- path::Path,
+ path::{Path, PathBuf},
pin::Pin,
ptr,
sync::{Arc, Mutex},
@@ -38,6 +38,8 @@ use windows_sys::{
},
};
+mod daita;
+
static WG_NT_DLL: OnceCell<WgNtDll> = OnceCell::new();
static ADAPTER_TYPE: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap());
static ADAPTER_ALIAS: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap());
@@ -159,12 +161,23 @@ pub enum Error {
/// Failed to parse data returned by the driver
#[error("Failed to parse data returned by wireguard-nt")]
InvalidConfigData,
+
+ /// DAITA machinist failed
+ #[error("Failed to enable DAITA on tunnel device")]
+ EnableTunnelDaita(#[source] io::Error),
+
+ /// DAITA machinist failed
+ #[error("Failed to initialize DAITA machinist")]
+ InitializeMachinist(#[source] daita::Error),
}
pub struct WgNtTunnel {
+ resource_dir: PathBuf,
+ config: Arc<Mutex<Config>>,
device: Option<Arc<WgNtAdapter>>,
interface_name: String,
setup_handle: tokio::task::JoinHandle<()>,
+ daita_handle: Option<daita::MachinistHandle>,
_logger_handle: LoggerHandle,
}
@@ -305,6 +318,7 @@ bitflags! {
const REPLACE_ALLOWED_IPS = 0b00100000;
const REMOVE = 0b01000000;
const UPDATE = 0b10000000;
+ const HAS_CONSTANT_PACKET_SIZE = 0b100000000;
}
}
@@ -322,6 +336,7 @@ struct WgPeer {
rx_bytes: u64,
last_handshake: u64,
allowed_ips_count: u32,
+ constant_packet_size: c_uchar,
}
#[derive(Clone, Copy)]
@@ -446,7 +461,7 @@ impl WgNtTunnel {
let device = Some(device2.clone());
let setup_future = setup_ip_listener(
- device2,
+ device2.clone(),
u32::from(config.mtu),
config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()),
);
@@ -457,17 +472,50 @@ impl WgNtTunnel {
});
Ok(WgNtTunnel {
+ resource_dir: resource_dir.to_owned(),
+ config: Arc::new(Mutex::new(config.clone())),
device,
interface_name,
setup_handle,
+ daita_handle: None,
_logger_handle: logger_handle,
})
}
fn stop_tunnel(&mut self) {
self.setup_handle.abort();
+ if let Some(daita_handle) = self.daita_handle.take() {
+ let _ = daita_handle.close();
+ }
let _ = self.device.take();
}
+
+ fn spawn_machinist(&mut self) -> Result<()> {
+ if let Some(handle) = self.daita_handle.take() {
+ log::info!("Stopping previous DAITA machines");
+ let _ = handle.close();
+ }
+
+ let Some(device) = self.device.clone() else {
+ log::debug!("Tunnel is stopped; not starting machines");
+ return Ok(());
+ };
+
+ let config = self.config.lock().unwrap();
+
+ log::info!("Initializing DAITA for wireguard device");
+ let session = daita::Session::from_adapter(device).map_err(Error::EnableTunnelDaita)?;
+ self.daita_handle = Some(
+ daita::Machinist::spawn(
+ &self.resource_dir,
+ session,
+ config.entry_peer.public_key.clone(),
+ config.mtu,
+ )
+ .map_err(Error::InitializeMachinist)?,
+ );
+ Ok(())
+ }
}
async fn setup_ip_listener(device: Arc<WgNtAdapter>, mtu: u32, has_ipv6: bool) -> Result<()> {
@@ -622,6 +670,11 @@ struct WgNtDll {
func_set_adapter_state: WireGuardSetStateFn,
func_set_logger: WireGuardSetLoggerFn,
func_set_adapter_logging: WireGuardSetAdapterLoggingFn,
+
+ func_daita_activate: daita::bindings::WireGuardDaitaActivateFn,
+ func_daita_event_data_available_event: daita::bindings::WireGuardDaitaEventDataAvailableEventFn,
+ func_daita_receive_events: daita::bindings::WireGuardDaitaReceiveEventsFn,
+ func_daita_send_action: daita::bindings::WireGuardDaitaSendActionFn,
}
unsafe impl Send for WgNtDll {}
@@ -694,6 +747,30 @@ impl WgNtDll {
CStr::from_bytes_with_nul(b"WireGuardSetAdapterLogging\0").unwrap(),
)?) as *const _ as *const _)
},
+ func_daita_activate: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardDaitaActivate\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_daita_event_data_available_event: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardDaitaEventDataAvailableEvent\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_daita_receive_events: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardDaitaReceiveEvents\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_daita_send_action: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardDaitaSendAction\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
})
}
@@ -790,6 +867,52 @@ impl WgNtDll {
}
Ok(())
}
+
+ pub unsafe fn daita_activate(
+ &self,
+ adapter: RawHandle,
+ events_capacity: usize,
+ actions_capacity: usize,
+ ) -> io::Result<()> {
+ if (self.func_daita_activate)(adapter, events_capacity, actions_capacity) == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+
+ pub unsafe fn daita_event_data_available_event(
+ &self,
+ adapter: RawHandle,
+ ) -> io::Result<RawHandle> {
+ let ready_event = (self.func_daita_event_data_available_event)(adapter);
+ if ready_event.is_null() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(ready_event)
+ }
+
+ pub unsafe fn daita_receive_events(
+ &self,
+ adapter: RawHandle,
+ events: *mut daita::Event,
+ ) -> io::Result<usize> {
+ let num_events = (self.func_daita_receive_events)(adapter, events);
+ if num_events == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(num_events)
+ }
+
+ pub unsafe fn daita_send_action(
+ &self,
+ adapter: RawHandle,
+ action: *const daita::Action,
+ ) -> io::Result<()> {
+ if (self.func_daita_send_action)(adapter, action) == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
}
impl Drop for WgNtDll {
@@ -816,11 +939,13 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> {
buffer.extend(as_uninit_byte_slice(&header));
for peer in config.peers() {
- let flags = if peer.psk.is_some() {
- WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT
- } else {
- WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT
- };
+ let mut flags = WgPeerFlag::HAS_PUBLIC_KEY
+ | WgPeerFlag::HAS_ENDPOINT
+ | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE;
+ if peer.psk.is_some() {
+ flags |= WgPeerFlag::HAS_PRESHARED_KEY;
+ }
+ let constant_packet_size = if peer.constant_packet_size { 1 } else { 0 };
let wg_peer = WgPeer {
flags,
reserved: 0,
@@ -836,6 +961,7 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> {
rx_bytes: 0,
last_handshake: 0,
allowed_ips_count: u32::try_from(peer.allowed_ips.len()).unwrap(),
+ constant_packet_size,
};
buffer.extend(as_uninit_byte_slice(&wg_peer));
@@ -960,13 +1086,16 @@ impl Tunnel for WgNtTunnel {
config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
let device = self.device.clone();
+ let current_config = self.config.clone();
Box::pin(async move {
let Some(device) = device else {
log::error!("Failed to set config: No tunnel device");
return Err(super::TunnelError::SetConfigError);
};
- device.set_config(&config).map_err(|error| {
+ let mut current_config = current_config.lock().unwrap();
+ *current_config = config;
+ device.set_config(&current_config).map_err(|error| {
log::error!(
"{}",
error.display_chain_with_msg("Failed to set wg-nt tunnel config")
@@ -975,6 +1104,16 @@ impl Tunnel for WgNtTunnel {
})
})
}
+
+ fn start_daita(&mut self) -> std::result::Result<(), crate::TunnelError> {
+ self.spawn_machinist().map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to start DAITA for wg-nt tunnel")
+ );
+ super::TunnelError::SetConfigError
+ })
+ }
}
pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8>] {
@@ -984,7 +1123,6 @@ pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8
#[cfg(test)]
mod tests {
use super::*;
- use once_cell::sync::Lazy;
use talpid_types::net::wireguard;
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
@@ -1009,12 +1147,16 @@ mod tests {
allowed_ips: vec!["1.3.3.0/24".parse().unwrap()],
endpoint: "1.2.3.4:1234".parse().unwrap(),
psk: None,
+ #[cfg(target_os = "windows")]
+ constant_packet_size: false,
},
exit_peer: None,
ipv4_gateway: "0.0.0.0".parse().unwrap(),
ipv6_gateway: None,
mtu: 0,
obfuscator_config: None,
+ daita: false,
+ quantum_resistant: false,
});
static WG_STRUCT_CONFIG: Lazy<Interface> = Lazy::new(|| Interface {
@@ -1026,7 +1168,9 @@ mod tests {
peers_count: 1,
},
p0: WgPeer {
- flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT,
+ flags: WgPeerFlag::HAS_PUBLIC_KEY
+ | WgPeerFlag::HAS_ENDPOINT
+ | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE,
reserved: 0,
public_key: *WG_PUBLIC_KEY.as_bytes(),
preshared_key: [0; WIREGUARD_KEY_LENGTH],
@@ -1039,6 +1183,8 @@ mod tests {
rx_bytes: 0,
last_handshake: 0,
allowed_ips_count: 1,
+ #[cfg(target_os = "windows")]
+ constant_packet_size: 0,
},
p0_allowed_ip_0: WgAllowedIp {
address: WgIpAddr::from("1.3.3.0".parse::<Ipv4Addr>().unwrap()),