diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-09-05 10:17:09 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-04-16 14:43:15 +0200 |
| commit | af96a710398870587df9e07ee6f5afd16b8d9888 (patch) | |
| tree | 7a5dfc3ee3a388c74c5a144983b5dfba8a057902 | |
| parent | 99ae0b436f173b576343111cac38d6bec4ce2487 (diff) | |
| download | mullvadvpn-af96a710398870587df9e07ee6f5afd16b8d9888.tar.xz mullvadvpn-af96a710398870587df9e07ee6f5afd16b8d9888.zip | |
Add DAITA Windows client and updated tuncfg
40 files changed, 1331 insertions, 244 deletions
diff --git a/Cargo.lock b/Cargo.lock index cc54dcb168..eedb5dd480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,6 +18,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + +[[package]] name = "aead" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -53,6 +59,18 @@ dependencies = [ ] [[package]] +name = "ahash" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] name = "aho-corasick" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -583,6 +601,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] + +[[package]] name = "cpufeatures" version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -592,6 +619,15 @@ dependencies = [ ] [[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + +[[package]] name = "crossbeam-channel" version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -715,6 +751,12 @@ dependencies = [ ] [[package]] +name = "dary_heap" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" + +[[package]] name = "dashmap" version = "5.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1219,6 +1261,15 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" @@ -1757,6 +1808,30 @@ dependencies = [ ] [[package]] +name = "libflate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +dependencies = [ + "core2", + "hashbrown 0.13.2", + "rle-decode-fast", +] + +[[package]] name = "libm" version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1836,6 +1911,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef" [[package]] +name = "maybenot" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cc2e64fe3f5fb1e247110a9a408449eff2259cc272cf57bad6f161e801ac962" +dependencies = [ + "byteorder", + "hex", + "libflate", + "rand 0.8.5", + "rand_distr", + "ring", + "serde", + "simple-error", +] + +[[package]] name = "md-5" version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2974,6 +3065,16 @@ dependencies = [ ] [[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] name = "rand_hc" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3100,6 +3201,12 @@ dependencies = [ ] [[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + +[[package]] name = "rs-release" version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3468,6 +3575,12 @@ dependencies = [ ] [[package]] +name = "simple-error" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8542b68b8800c3cda649d2c72d688b6907b30f1580043135d61669d4aad1c175" + +[[package]] name = "simple-signal" version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3870,6 +3983,7 @@ dependencies = [ name = "talpid-wireguard" version = "0.0.0" dependencies = [ + "base64 0.13.1", "bitflags 1.3.2", "byteorder", "chrono", @@ -3880,6 +3994,7 @@ dependencies = [ "ipnetwork", "libc", "log", + "maybenot", "netlink-packet-core", "netlink-packet-route", "netlink-packet-utils", @@ -4822,6 +4937,26 @@ dependencies = [ ] [[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + +[[package]] name = "zeroize" version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index b0c28f016a..65391bcf37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "ios/MullvadREST/Transport/Shadowsocks/shadowsocks-proxy", "ios/TunnelObfuscation/tunnel-obfuscator-proxy", "mullvad-api", + "mullvad-daemon", "mullvad-cli", "mullvad-daemon", "mullvad-exclude", diff --git a/dist-assets/maybenot_machines b/dist-assets/maybenot_machines new file mode 100644 index 0000000000..e7bde4f764 --- /dev/null +++ b/dist-assets/maybenot_machines @@ -0,0 +1,4 @@ +789cd5cfbb0900200c04d08b833886adb889389f5bb9801be811acb58ae2837ce02010c158b070555c9538b6377a64dbb0ceff242c20b79038507dd169fbede9f629bf6f021efa1b66 +789ccd934b48024118c777357ae9250aa14b8542085d2a0f264233160975a80c440f5b04455deaa050143d0e81d8835ea708f6d2a9a2430511d2a136118bcaa854422a93c44d93a243121844b3204b8b87acadf0c7f0c137f3cdccf7e08f635cc200190eb81099af106235eda133a95a4f4df5dcbfad2eb640b4c921135330e42ba1585066951456b311b3f6d8ca4254ae415f212f356efba0928e44d9642df27769b79162fd9ff1ccf3fef7b9b19982cee963405699ee1c136e107bf0ae19d6fdff9ec76770ac8bc1a9a34847c986d9a8858391d28af2a20b76622824f55125186dedd813d9c2207bd22c382f76f1aed070ad7a896bbdc09317b48c1121deefa51f599855e23b7d1519a1eb044134414662573e1d546d1df6efd6ce401cdb3f5aaadb56e8a9b69d7ab140d60c0924a601b7086929019a54ef66e5b02cd7c376a86199d63cc57fbb63c9daf18f2b9805e6d5f49c9a0e249dff158df6919c0232003a87321e2f0ff86be903a6c37ad5 +789ced956d28435118c7cfdde63544f9424292489222493a77214962de92248992bc535e4221ad2549920f5e22e5ada525496b1faee52d9fc6e4255bb3469b96661f282379ee1d574bb7257c50fb7d38dd73efd3e9dcf33fe77708e4880943e3002180c6197c74b6d09ce6dd24a2a003a4530dc6e715ab4d44dea8222e9f8c6292878ee0a1afb598ca18ae876f0212ca1c2110b19f932c1156eee2e99d70596bed360e1294f8c5adeb71aa2e6f64b251f965721f840d99455ba15adc92e60d13d173d6b9f87bf8e87e690f7224a8aedba804cfc842366979b60f4f23731342d228a67c54f910ad6093a2739eb1e8f05c70bc36d2fd827dcf857bff44fa70bd19c356b21ee772ef8cef220d29ab4ed519d8f196e95f119eb37d179f78a0177af1ab4c78d602918b0ba8d7934e3a73d28717eb2119881542114be05a96be3bf0143f46d3d15f73ae682f4eaa90676bb0adaf4ed13eaec2200490889ab3dec54fe123469fbe45e4e2142dde3092c951524a32a734b103ec5de55f9ab261ca27db6ac7ee7ac4f9ef9e47088ced55531300679a2056e9cde07f8507a9c3f9cc4d3566464938f8f5e418af4835d87e9370dbde9339cc066cf79173a7fc77de00293aaf16 +789ce5914b4802511486ef75c8405a84b4aa8810029b8236b9a8b03b5af6801e24512d24225ab409a2550521152e24c455b831706a216ea4402c5a0d1404465022980d990d5941140c328bc81eb7910686bb89dc047e1c0edc737e2e87ff8740cd23c24d05a4702b1e0a685737bb36fc36a6ef76b92adc6be5a2b9f7d1abee6a0b5e1680009a75c99d79730cf5e065623185fcc7f5bb0bb329e2a81f663a6e863dbe341a9aae9c688b6408ddd6a51bd719a2e880860e0868dc29e13a2174a503050a9ed6304792d1e1ddb63378a8420bd6b9186bfbb033721834a9f84e0af75fe191d8bdced6b8e2787b594b3093bd57de7fe5b3e2d9600df3c8b9e47e491b85a2fffb7f68c1296b8be6aec798b987b790a86f22722807f990181f140618d7d3c8419dcbc4e1a14ca3c3ab5933d31600a19294087d0dfae63b549bbf88f0c124e1d854f6dcc0bf0a484e5db78f569069f2b03f41e84a8d2f5e5a7583 diff --git a/docs/architecture.md b/docs/architecture.md index 1deba30706..78305a8c6a 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -200,7 +200,7 @@ WireGuard tunnel to the relay and deriving the PSK within the tunnel. The PSK is stored in memory on the relay and the client, along with a new client generated ephemeral WireGuard key. Subsequently, a new tunnel is created using the new WireGuard key and the PSK, ensuring that the tunnel is quantum-resistant. -See the [protocol definition file](../talpid-tunnel-config-client/proto/tunnel_config.proto) for +See the [protocol definition file](../talpid-tunnel-config-client/proto/ephemeralpeer.proto) for more details on the protocol. #### Quantum-resistant tunnels & Multihop diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js index 99c8fb87ef..7087d71301 100644 --- a/gui/tasks/distribution.js +++ b/gui/tasks/distribution.js @@ -152,6 +152,7 @@ const config = { from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard-nt/mullvad-wireguard.dll'), to: '.', }, + { from: distAssets('maybenot_machines'), to: '.' }, ], }, diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index deaf29ef10..73b387a8d8 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -303,6 +303,8 @@ struct WireGuardRelay { #[serde(flatten)] relay: Relay, public_key: wireguard::PublicKey, + #[serde(default)] + daita: bool, } impl WireGuardRelay { @@ -312,6 +314,7 @@ impl WireGuardRelay { location, relay_list::RelayEndpointData::Wireguard(relay_list::WireguardRelayEndpointData { public_key: self.public_key, + daita: self.daita, }), ) } diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs index 7ef60d758c..f022402a83 100644 --- a/mullvad-cli/src/cmds/relay.rs +++ b/mullvad-cli/src/cmds/relay.rs @@ -542,6 +542,8 @@ impl Relay { allowed_ips: all_of_the_internet(), endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port), psk: None, + #[cfg(target_os = "windows")] + constant_packet_size: false, }, exit_peer: None, ipv4_gateway, diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs index 19d5c1a3c9..77338ee336 100644 --- a/mullvad-cli/src/cmds/tunnel.rs +++ b/mullvad-cli/src/cmds/tunnel.rs @@ -1,6 +1,8 @@ use anyhow::Result; use clap::Subcommand; use mullvad_management_interface::MullvadProxyClient; +#[cfg(target_os = "windows")] +use mullvad_types::wireguard::DaitaSettings; use mullvad_types::{ constraints::Constraint, wireguard::{QuantumResistantState, RotationInterval, DEFAULT_ROTATION_INTERVAL}, @@ -38,6 +40,10 @@ pub enum TunnelOptions { /// Configure quantum-resistant key exchange #[arg(long)] quantum_resistant: Option<QuantumResistantState>, + /// Configure whether to enable DAITA + #[cfg(target_os = "windows")] + #[arg(long)] + daita: Option<BooleanOption>, /// The key rotation interval. Number of hours, or 'any' #[arg(long)] rotation_interval: Option<Constraint<RotationInterval>>, @@ -95,6 +101,9 @@ impl Tunnel { tunnel_options.wireguard.quantum_resistant, ); + #[cfg(target_os = "windows")] + print_option!("DAITA", tunnel_options.wireguard.daita.enabled); + let key = rpc.get_wireguard_key().await?; print_option!("Public key", key.key,); print_option!(format_args!( @@ -129,10 +138,20 @@ impl Tunnel { TunnelOptions::Wireguard { mtu, quantum_resistant, + #[cfg(target_os = "windows")] + daita, rotation_interval, rotate_key, } => { - Self::handle_wireguard(mtu, quantum_resistant, rotation_interval, rotate_key).await + Self::handle_wireguard( + mtu, + quantum_resistant, + #[cfg(target_os = "windows")] + daita, + rotation_interval, + rotate_key, + ) + .await } TunnelOptions::Ipv6 { state } => Self::handle_ipv6(state).await, } @@ -159,6 +178,7 @@ impl Tunnel { async fn handle_wireguard( mtu: Option<Constraint<u16>>, quantum_resistant: Option<QuantumResistantState>, + #[cfg(target_os = "windows")] daita: Option<BooleanOption>, rotation_interval: Option<Constraint<RotationInterval>>, rotate_key: Option<RotateKey>, ) -> Result<()> { @@ -174,6 +194,13 @@ impl Tunnel { println!("Quantum resistant setting has been updated"); } + #[cfg(target_os = "windows")] + if let Some(daita) = daita { + rpc.set_daita_settings(DaitaSettings { enabled: *daita }) + .await?; + println!("DAITA setting has been updated"); + } + if let Some(interval) = rotation_interval { match interval { Constraint::Only(interval) => { diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs index e605efbe3b..0d6ea0b0c0 100644 --- a/mullvad-cli/src/format.rs +++ b/mullvad-cli/src/format.rs @@ -174,6 +174,17 @@ fn format_relay_connection( "\nQuantum resistant tunnel: no" }; + #[cfg(target_os = "windows")] + let daita = if !verbose { + "" + } else if endpoint.daita { + "\nDAITA: yes" + } else { + "\nDAITA: no" + }; + #[cfg(not(target_os = "windows"))] + let daita = ""; + let mut bridge_type = String::new(); let mut obfuscator_type = String::new(); if verbose { @@ -186,7 +197,7 @@ fn format_relay_connection( } format!( - "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{bridge_type}{obfuscator_type}", + "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{daita}{bridge_type}{obfuscator_type}", first_hop = first_hop.unwrap_or_default(), bridge = bridge.unwrap_or_default(), obfuscator = obfuscator.unwrap_or_default(), diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 365e82efeb..2733482a4c 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -37,9 +37,11 @@ use futures::{ StreamExt, }; use geoip::GeoIpHandler; -use mullvad_relay_selector::{RelaySelector, SelectorConfig}; +use mullvad_relay_selector::{AdditionalRelayConstraints, AdditionalWireguardConstraints, RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; +#[cfg(target_os = "windows")] +use mullvad_types::wireguard::DaitaSettings; use mullvad_types::{ access_method::{AccessMethod, AccessMethodSetting}, account::{AccountData, AccountToken, VoucherSubmission}, @@ -256,6 +258,9 @@ pub enum DaemonCommand { SetEnableIpv6(ResponseTx<(), settings::Error>, bool), /// Set whether to enable PQ PSK exchange in the tunnel SetQuantumResistantTunnel(ResponseTx<(), settings::Error>, QuantumResistantState), + /// Set DAITA settings for the tunnel + #[cfg(target_os = "windows")] + SetDaitaSettings(ResponseTx<(), settings::Error>, DaitaSettings), /// Set DNS options or servers to use SetDnsOptions(ResponseTx<(), settings::Error>, DnsOptions), /// Set override options to use for a given relay @@ -1242,6 +1247,10 @@ where self.on_set_quantum_resistant_tunnel(tx, quantum_resistant_state) .await } + #[cfg(target_os = "windows")] + SetDaitaSettings(tx, daita_settings) => { + self.on_set_daita_settings(tx, daita_settings).await + } SetDnsOptions(tx, dns_servers) => self.on_set_dns_options(tx, dns_servers).await, SetRelayOverride(tx, relay_override) => { self.on_set_relay_override(tx, relay_override).await @@ -2259,6 +2268,40 @@ where } } + #[cfg(target_os = "windows")] + async fn on_set_daita_settings( + &mut self, + tx: ResponseTx<(), settings::Error>, + daita_settings: DaitaSettings, + ) { + match self + .settings + .update(|settings| settings.tunnel_options.wireguard.daita = daita_settings) + .await + { + Ok(settings_changed) => { + Self::oneshot_send(tx, Ok(()), "set_daita_settings response"); + if settings_changed { + self.parameters_generator + .set_tunnel_options(&self.settings.tunnel_options) + .await; + self.event_listener + .notify_settings(self.settings.to_settings()); + self.relay_selector + .set_config(new_selector_config(&self.settings)); + if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) { + log::info!("Reconnecting because DAITA settings changed"); + self.reconnect_tunnel(); + } + } + } + Err(e) => { + log::error!("{}", e.display_chain_with_msg("Unable to save settings")); + Self::oneshot_send(tx, Err(e), "set_daita_settings response"); + } + } + } + async fn on_set_dns_options( &mut self, tx: ResponseTx<(), settings::Error>, @@ -2780,8 +2823,18 @@ impl DaemonShutdownHandle { } fn new_selector_config(settings: &Settings) -> SelectorConfig { + let additional_constraints = AdditionalRelayConstraints { + wireguard: AdditionalWireguardConstraints { + #[cfg(target_os = "windows")] + daita: settings.tunnel_options.wireguard.daita.enabled, + #[cfg(not(target_os = "windows"))] + daita: false, + }, + }; + SelectorConfig { relay_settings: settings.relay_settings.clone(), + additional_constraints, bridge_state: settings.bridge_state, bridge_settings: settings.bridge_settings.clone(), obfuscation_settings: settings.obfuscation_settings.clone(), diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 55573f87a7..ce203e7ba8 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -325,6 +325,25 @@ impl ManagementService for ManagementServiceImpl { Ok(Response::new(())) } + #[cfg(target_os = "windows")] + async fn set_daita_settings( + &self, + request: Request<types::DaitaSettings>, + ) -> ServiceResult<()> { + let state = mullvad_types::wireguard::DaitaSettings::from(request.into_inner()); + + log::debug!("set_daita_settings({state:?})"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::SetDaitaSettings(tx, state))?; + self.wait_for_result(rx).await?.map(Response::new)?; + Ok(Response::new(())) + } + + #[cfg(not(target_os = "windows"))] + async fn set_daita_settings(&self, _: Request<types::DaitaSettings>) -> ServiceResult<()> { + Ok(Response::new(())) + } + #[cfg(not(target_os = "android"))] async fn set_dns_options(&self, request: Request<types::DnsOptions>) -> ServiceResult<()> { let options = DnsOptions::try_from(request.into_inner()).map_err(map_protobuf_type_err)?; diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 8af34aa8af..31d39db306 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -43,6 +43,7 @@ service ManagementService { rpc SetWireguardMtu(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {} rpc SetEnableIpv6(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} rpc SetQuantumResistantTunnel(QuantumResistantState) returns (google.protobuf.Empty) {} + rpc SetDaitaSettings(DaitaSettings) returns (google.protobuf.Empty) {} rpc SetDnsOptions(DnsOptions) returns (google.protobuf.Empty) {} rpc SetRelayOverride(RelayOverride) returns (google.protobuf.Empty) {} rpc ClearAllRelayOverrides(google.protobuf.Empty) returns (google.protobuf.Empty) {} @@ -220,6 +221,7 @@ message TunnelEndpoint { ObfuscationEndpoint obfuscation = 6; Endpoint entry_endpoint = 7; TunnelMetadata tunnel_metadata = 8; + bool daita = 9; } enum ObfuscationType { @@ -494,12 +496,15 @@ message QuantumResistantState { State state = 1; } +message DaitaSettings { bool enabled = 1; } + message TunnelOptions { message OpenvpnOptions { optional uint32 mssfix = 1; } message WireguardOptions { optional uint32 mtu = 1; google.protobuf.Duration rotation_interval = 2; QuantumResistantState quantum_resistant = 4; + DaitaSettings daita = 5; } message GenericOptions { bool enable_ipv6 = 1; } @@ -584,7 +589,10 @@ message Relay { Location location = 11; } -message WireguardRelayEndpointData { bytes public_key = 1; } +message WireguardRelayEndpointData { + bytes public_key = 1; + bool daita = 2; +} message Location { string country = 1; diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 8cf3ac495d..04304a19ec 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -2,6 +2,8 @@ use crate::types; use futures::{Stream, StreamExt}; +#[cfg(target_os = "windows")] +use mullvad_types::wireguard::DaitaSettings; use mullvad_types::{ access_method::{self, AccessMethod, AccessMethodSetting}, account::{AccountData, AccountToken, VoucherSubmission}, @@ -344,6 +346,16 @@ impl MullvadProxyClient { Ok(()) } + #[cfg(target_os = "windows")] + pub async fn set_daita_settings(&mut self, settings: DaitaSettings) -> Result<()> { + let settings = types::DaitaSettings::from(settings); + self.0 + .set_daita_settings(settings) + .await + .map_err(Error::Rpc)?; + Ok(()) + } + pub async fn set_dns_options(&mut self, options: DnsOptions) -> Result<()> { let options = types::DnsOptions::from(&options); self.0.set_dns_options(options).await.map_err(Error::Rpc)?; diff --git a/mullvad-management-interface/src/types/conversions/custom_tunnel.rs b/mullvad-management-interface/src/types/conversions/custom_tunnel.rs index 8a4408ec83..2445ec3292 100644 --- a/mullvad-management-interface/src/types/conversions/custom_tunnel.rs +++ b/mullvad-management-interface/src/types/conversions/custom_tunnel.rs @@ -91,6 +91,8 @@ impl TryFrom<proto::ConnectionConfig> for mullvad_types::ConnectionConfig { allowed_ips, endpoint, psk: None, + #[cfg(target_os = "windows")] + constant_packet_size: false, }, exit_peer: None, ipv4_gateway, diff --git a/mullvad-management-interface/src/types/conversions/net.rs b/mullvad-management-interface/src/types/conversions/net.rs index 80648cbf8d..3557a6a636 100644 --- a/mullvad-management-interface/src/types/conversions/net.rs +++ b/mullvad-management-interface/src/types/conversions/net.rs @@ -40,6 +40,10 @@ impl From<talpid_types::net::TunnelEndpoint> for proto::TunnelEndpoint { tunnel_metadata: endpoint .tunnel_interface .map(|tunnel_interface| proto::TunnelMetadata { tunnel_interface }), + #[cfg(target_os = "windows")] + daita: endpoint.daita, + #[cfg(not(target_os = "windows"))] + daita: false, } } } @@ -123,6 +127,8 @@ impl TryFrom<proto::TunnelEndpoint> for talpid_types::net::TunnelEndpoint { tunnel_interface: endpoint .tunnel_metadata .map(|tunnel_metadata| tunnel_metadata.tunnel_interface), + #[cfg(target_os = "windows")] + daita: endpoint.daita, }) } } diff --git a/mullvad-management-interface/src/types/conversions/relay_list.rs b/mullvad-management-interface/src/types/conversions/relay_list.rs index 32aee834af..4e0a363702 100644 --- a/mullvad-management-interface/src/types/conversions/relay_list.rs +++ b/mullvad-management-interface/src/types/conversions/relay_list.rs @@ -122,6 +122,7 @@ impl From<mullvad_types::relay_list::Relay> for proto::Relay { "mullvad_daemon.management_interface/WireguardRelayEndpointData", proto::WireguardRelayEndpointData { public_key: data.public_key.as_bytes().to_vec(), + daita: data.daita, }, )), _ => None, @@ -236,6 +237,7 @@ impl TryFrom<proto::Relay> for mullvad_types::relay_list::Relay { MullvadEndpointData::Wireguard( mullvad_types::relay_list::WireguardRelayEndpointData { public_key: bytes_to_pubkey(&data.public_key)?, + daita: data.daita, }, ) } diff --git a/mullvad-management-interface/src/types/conversions/settings.rs b/mullvad-management-interface/src/types/conversions/settings.rs index a4d6313158..857f32d991 100644 --- a/mullvad-management-interface/src/types/conversions/settings.rs +++ b/mullvad-management-interface/src/types/conversions/settings.rs @@ -97,6 +97,10 @@ impl From<&mullvad_types::settings::TunnelOptions> for proto::TunnelOptions { .expect("Failed to convert std::time::Duration to prost_types::Duration for tunnel_options.wireguard.rotation_interval") }), quantum_resistant: Some(proto::QuantumResistantState::from(options.wireguard.quantum_resistant)), + #[cfg(target_os = "windows")] + daita: Some(proto::DaitaSettings::from(options.wireguard.daita.clone())), + #[cfg(not(target_os = "windows"))] + daita: None, }), generic: Some(proto::tunnel_options::GenericOptions { enable_ipv6: options.generic.enable_ipv6, @@ -282,6 +286,13 @@ impl TryFrom<proto::TunnelOptions> for mullvad_types::settings::TunnelOptions { .ok_or(FromProtobufTypeError::InvalidArgument( "missing quantum resistant state", ))??, + #[cfg(target_os = "windows")] + daita: wireguard_options + .daita + .map(mullvad_types::wireguard::DaitaSettings::from) + .ok_or(FromProtobufTypeError::InvalidArgument( + "missing daita settings", + ))?, }, generic: net::GenericTunnelOptions { enable_ipv6: generic_options.enable_ipv6, diff --git a/mullvad-management-interface/src/types/conversions/wireguard.rs b/mullvad-management-interface/src/types/conversions/wireguard.rs index f35a8c7216..4a4341339c 100644 --- a/mullvad-management-interface/src/types/conversions/wireguard.rs +++ b/mullvad-management-interface/src/types/conversions/wireguard.rs @@ -71,3 +71,21 @@ impl TryFrom<proto::QuantumResistantState> for mullvad_types::wireguard::Quantum } } } + +#[cfg(target_os = "windows")] +impl From<mullvad_types::wireguard::DaitaSettings> for proto::DaitaSettings { + fn from(settings: mullvad_types::wireguard::DaitaSettings) -> Self { + proto::DaitaSettings { + enabled: settings.enabled, + } + } +} + +#[cfg(target_os = "windows")] +impl From<proto::DaitaSettings> for mullvad_types::wireguard::DaitaSettings { + fn from(settings: proto::DaitaSettings) -> Self { + mullvad_types::wireguard::DaitaSettings { + enabled: settings.enabled, + } + } +} diff --git a/mullvad-relay-selector/src/relay_selector/detailer.rs b/mullvad-relay-selector/src/relay_selector/detailer.rs index 807903eb39..3dc1cc903e 100644 --- a/mullvad-relay-selector/src/relay_selector/detailer.rs +++ b/mullvad-relay-selector/src/relay_selector/detailer.rs @@ -84,6 +84,9 @@ fn wireguard_singlehop_endpoint( allowed_ips: all_of_the_internet(), // This will be filled in later, not the relay selector's problem psk: None, + // This will be filled in later + #[cfg(target_os = "windows")] + constant_packet_size: false, }; Ok(MullvadWireguardEndpoint { peer: peer_config, @@ -122,6 +125,9 @@ fn wireguard_multihop_endpoint( allowed_ips: all_of_the_internet(), // This will be filled in later, not the relay selector's problem psk: None, + // This will be filled in later + #[cfg(target_os = "windows")] + constant_packet_size: false, }; let entry_endpoint = { @@ -137,6 +143,9 @@ fn wireguard_multihop_endpoint( allowed_ips: vec![IpNetwork::from(exit.endpoint.ip())], // This will be filled in later psk: None, + // This will be filled in later + #[cfg(target_os = "windows")] + constant_packet_size: false, }; Ok(MullvadWireguardEndpoint { diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs index fbf3ac1c4b..7ec2fc02bb 100644 --- a/mullvad-relay-selector/src/relay_selector/mod.rs +++ b/mullvad-relay-selector/src/relay_selector/mod.rs @@ -99,6 +99,7 @@ pub struct RelaySelector { pub struct SelectorConfig { // Normal relay settings pub relay_settings: RelaySettings, + pub additional_constraints: AdditionalRelayConstraints, pub custom_lists: CustomListsSettings, pub relay_overrides: Vec<RelayOverride>, // Wireguard specific data @@ -108,6 +109,20 @@ pub struct SelectorConfig { pub bridge_settings: BridgeSettings, } +/// Extra relay constraints not specified in `relay_settings`. +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct AdditionalRelayConstraints { + pub wireguard: AdditionalWireguardConstraints, +} + +/// Constraints to use when selecting WireGuard servers +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct AdditionalWireguardConstraints { + /// If true, select WireGuard relays that support DAITA. If false, select any + /// server. + pub daita: bool, +} + /// Values which affect the choice of relay but are only known at runtime. #[derive(Clone, Debug)] pub struct RuntimeParameters { @@ -273,6 +288,7 @@ impl Default for SelectorConfig { let default_settings = Settings::default(); SelectorConfig { relay_settings: default_settings.relay_settings, + additional_constraints: default_settings.additional_constraints, bridge_settings: default_settings.bridge_settings, obfuscation_settings: default_settings.obfuscation_settings, bridge_state: default_settings.bridge_state, diff --git a/mullvad-relay-selector/tests/relay_selector.rs b/mullvad-relay-selector/tests/relay_selector.rs index ed3546c62e..40efae24e2 100644 --- a/mullvad-relay-selector/tests/relay_selector.rs +++ b/mullvad-relay-selector/tests/relay_selector.rs @@ -54,6 +54,7 @@ static RELAYS: Lazy<RelayList> = Lazy::new(|| RelayList { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, @@ -71,6 +72,7 @@ static RELAYS: Lazy<RelayList> = Lazy::new(|| RelayList { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, @@ -414,6 +416,7 @@ fn test_wireguard_entry() { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, @@ -431,6 +434,7 @@ fn test_wireguard_entry() { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, @@ -932,6 +936,7 @@ fn test_include_in_country() { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, @@ -949,6 +954,7 @@ fn test_include_in_country() { "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", ) .unwrap(), + daita: false, }), location: None, }, diff --git a/mullvad-types/src/relay_list.rs b/mullvad-types/src/relay_list.rs index 77cc621728..09ab312a5f 100644 --- a/mullvad-types/src/relay_list.rs +++ b/mullvad-types/src/relay_list.rs @@ -129,6 +129,7 @@ impl PartialEq for Relay { /// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", /// # ) /// # .unwrap(), + /// # daita: false, /// # }), /// # location: None, /// }; @@ -232,6 +233,9 @@ struct PortRange { pub struct WireguardRelayEndpointData { /// Public key used by the relay peer pub public_key: wireguard::PublicKey, + /// Whether the server supports DAITA + #[serde(default)] + pub daita: bool, } #[derive(Debug, Default, Clone, Deserialize, Serialize)] diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs index c85860a776..d8e9bd403b 100644 --- a/mullvad-types/src/wireguard.rs +++ b/mullvad-types/src/wireguard.rs @@ -55,6 +55,12 @@ impl FromStr for QuantumResistantState { #[error("Not a valid state")] pub struct QuantumResistantStateParseError; +#[cfg(target_os = "windows")] +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] +pub struct DaitaSettings { + pub enabled: bool, +} + /// Contains account specific wireguard data #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct WireguardData { @@ -201,6 +207,9 @@ pub struct TunnelOptions { pub mtu: Option<u16>, /// Obtain a PSK using the relay config client. pub quantum_resistant: QuantumResistantState, + /// Configure DAITA + #[cfg(target_os = "windows")] + pub daita: DaitaSettings, /// Interval used for automatic key rotation #[cfg_attr(target_os = "android", jnix(skip))] pub rotation_interval: Option<RotationInterval>, @@ -212,6 +221,8 @@ impl Default for TunnelOptions { TunnelOptions { mtu: None, quantum_resistant: QuantumResistantState::Auto, + #[cfg(target_os = "windows")] + daita: DaitaSettings::default(), rotation_interval: None, } } @@ -226,6 +237,8 @@ impl TunnelOptions { QuantumResistantState::On => true, QuantumResistantState::Off => false, }, + #[cfg(target_os = "windows")] + daita: self.daita.enabled, } } } diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 0daa8b996c..ce860d0ed6 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -177,7 +177,6 @@ impl TunnelMonitor { let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?; let monitor = talpid_wireguard::WireguardMonitor::start( config, - params.options.quantum_resistant, #[cfg(not(target_os = "android"))] detect_mtu, log.as_deref(), diff --git a/talpid-tunnel-config-client/build.rs b/talpid-tunnel-config-client/build.rs index 129cfd7154..aeb21fe009 100644 --- a/talpid-tunnel-config-client/build.rs +++ b/talpid-tunnel-config-client/build.rs @@ -1,3 +1,3 @@ fn main() { - tonic_build::compile_protos("proto/tunnel_config.proto").unwrap(); + tonic_build::compile_protos("proto/ephemeralpeer.proto").unwrap(); } diff --git a/talpid-tunnel-config-client/examples/psk-exchange.rs b/talpid-tunnel-config-client/examples/psk-exchange.rs index 87200d0336..e7b34ab851 100644 --- a/talpid-tunnel-config-client/examples/psk-exchange.rs +++ b/talpid-tunnel-config-client/examples/psk-exchange.rs @@ -18,14 +18,16 @@ async fn main() { let pubkey = PublicKey::from_base64(pubkey_string.trim()).expect("Invalid public key"); let private_key = PrivateKey::new_from_random(); - let psk = talpid_tunnel_config_client::push_pq_key( + let ephemeral_peer = talpid_tunnel_config_client::request_ephemeral_peer( tuncfg_server_ip, pubkey, private_key.public_key(), + true, + false, ) .await .unwrap(); println!("private key: {private_key:?}"); - println!("psk: {psk:?}"); + println!("psk: {:?}", ephemeral_peer.psk); } diff --git a/talpid-tunnel-config-client/examples/tuncfg-server.rs b/talpid-tunnel-config-client/examples/tuncfg-server.rs index 4d29d764a2..928587e968 100644 --- a/talpid-tunnel-config-client/examples/tuncfg-server.rs +++ b/talpid-tunnel-config-client/examples/tuncfg-server.rs @@ -1,80 +1,93 @@ -//! A server implementation of the tuncfg PskExchangeV1 RPC to test -//! the client side implementation. +//! A server implementation of the tuncfg RegisterPeerV1 RPC to test +//! the client side implementation of PQ. #[allow(clippy::derive_partial_eq_without_eq)] mod proto { - tonic::include_proto!("tunnel_config"); + tonic::include_proto!("ephemeralpeer"); } use classic_mceliece_rust::{PublicKey, CRYPTO_PUBLICKEYBYTES}; use proto::{ - post_quantum_secure_server::{PostQuantumSecure, PostQuantumSecureServer}, - PskRequestV1, PskResponseV1, + ephemeral_peer_server::{EphemeralPeer, EphemeralPeerServer}, + EphemeralPeerRequestV1, EphemeralPeerResponseV1, PostQuantumResponseV1, }; use talpid_types::net::wireguard::PresharedKey; use tonic::{transport::Server, Request, Response, Status}; #[derive(Debug, Default)] -pub struct PostQuantumSecureImpl {} +pub struct EphemeralPeerImpl {} #[tonic::async_trait] -impl PostQuantumSecure for PostQuantumSecureImpl { - async fn psk_exchange_v1( +impl EphemeralPeer for EphemeralPeerImpl { + async fn register_peer_v1( &self, - request: Request<PskRequestV1>, - ) -> Result<Response<PskResponseV1>, Status> { + request: Request<EphemeralPeerRequestV1>, + ) -> Result<Response<EphemeralPeerResponseV1>, Status> { let mut rng = rand::thread_rng(); let request = request.into_inner(); - println!("wg_pubkey: {:?}", request.wg_pubkey); - println!("wg_psk_pubkey: {:?}", request.wg_psk_pubkey); + println!("wg_parent_pubkey: {:?}", request.wg_parent_pubkey); + println!( + "wg_ephemeral_peer_pubkey: {:?}", + request.wg_ephemeral_peer_pubkey + ); + println!("daita (no-op): {:?}", request.daita); - // The ciphertexts that will be returned to the client - let mut ciphertexts = Vec::new(); - // The final PSK that is computed by XORing together all the KEM outputs. - let mut psk_data = Box::new([0u8; 32]); + let post_quantum = if let Some(post_quantum) = request.post_quantum { + // The ciphertexts that will be returned to the client + let mut ciphertexts = Vec::new(); - for kem_pubkey in request.kem_pubkeys { - println!("\tKEM algorithm: {}", kem_pubkey.algorithm_name); - let (ciphertext, shared_secret) = match kem_pubkey.algorithm_name.as_str() { - "Classic-McEliece-460896f-round3" => { - let key_data: [u8; CRYPTO_PUBLICKEYBYTES] = - kem_pubkey.key_data.as_slice().try_into().unwrap(); - let public_key = PublicKey::from(&key_data); - let (ciphertext, shared_secret) = - classic_mceliece_rust::encapsulate_boxed(&public_key, &mut rng); - (ciphertext.as_array().to_vec(), *shared_secret.as_array()) - } - "Kyber1024" => { - let public_key = kem_pubkey.key_data.as_slice(); - let (ciphertext, shared_secret) = - pqc_kyber::encapsulate(public_key, &mut rng).unwrap(); - (ciphertext.to_vec(), shared_secret) - } - name => panic!("Unsupported KEM algorithm: {name}"), - }; + // The final PSK that is computed by XORing together all the KEM outputs. + let mut psk_data = Box::new([0u8; 32]); + + for kem_pubkey in post_quantum.kem_pubkeys { + println!("\tKEM algorithm: {}", kem_pubkey.algorithm_name); + let (ciphertext, shared_secret) = match kem_pubkey.algorithm_name.as_str() { + "Classic-McEliece-460896f-round3" => { + let key_data: [u8; CRYPTO_PUBLICKEYBYTES] = + kem_pubkey.key_data.as_slice().try_into().unwrap(); + let public_key = PublicKey::from(&key_data); + let (ciphertext, shared_secret) = + classic_mceliece_rust::encapsulate_boxed(&public_key, &mut rng); + (ciphertext.as_array().to_vec(), *shared_secret.as_array()) + } + "Kyber1024" => { + let public_key = kem_pubkey.key_data.as_slice(); + let (ciphertext, shared_secret) = + pqc_kyber::encapsulate(public_key, &mut rng).unwrap(); + (ciphertext.to_vec(), shared_secret) + } + name => panic!("Unsupported KEM algorithm: {name}"), + }; - ciphertexts.push(ciphertext); - println!("\tshared secret: {shared_secret:?}"); - for (psk_byte, shared_secret_byte) in psk_data.iter_mut().zip(shared_secret.iter()) { - *psk_byte ^= shared_secret_byte; + ciphertexts.push(ciphertext); + println!("\tshared secret: {shared_secret:?}"); + for (psk_byte, shared_secret_byte) in psk_data.iter_mut().zip(shared_secret.iter()) + { + *psk_byte ^= shared_secret_byte; + } } - } - let psk = PresharedKey::from(psk_data); - println!("psk: {psk:?}"); - println!("=============================================="); - Ok(Response::new(PskResponseV1 { ciphertexts })) + let psk = PresharedKey::from(psk_data); + println!("psk: {psk:?}"); + println!("=============================================="); + + Some(PostQuantumResponseV1 { ciphertexts }) + } else { + None + }; + + Ok(Response::new(EphemeralPeerResponseV1 { post_quantum })) } } #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { let addr = "127.0.0.1:1337".parse()?; - let server = PostQuantumSecureImpl::default(); + let server = EphemeralPeerImpl::default(); Server::builder() - .add_service(PostQuantumSecureServer::new(server)) + .add_service(EphemeralPeerServer::new(server)) .serve(addr) .await?; diff --git a/talpid-tunnel-config-client/proto/ephemeralpeer.proto b/talpid-tunnel-config-client/proto/ephemeralpeer.proto new file mode 100644 index 0000000000..bb49eb5598 --- /dev/null +++ b/talpid-tunnel-config-client/proto/ephemeralpeer.proto @@ -0,0 +1,85 @@ +syntax = "proto3"; + +option go_package = "github.com/mullvad/wg-manager/tuncfg/api/ephemeralpeer"; + +package ephemeralpeer; + +service EphemeralPeer { + // Derive an ephemeral peer with one or several options enabled, such as PQ or DAITA. + // + // The VPN server associates the ephemeral peer with the peer who performed the exchange. Any + // already existing ephemeral peer for the normal peer is replaced. Each normal peer can have + // at most one ephemeral peer. + // + // The ephemeral peer is mutually exclusive to the normal peer. The server keeps both peers in + // memory, but only one of them is loaded into WireGuard at any point in time. A handshake from + // the normal peer unloads the corresponding ephemeral peer from WireGuard and vice versa. + // + // A new peer is negotiated to avoid a premature break of the tunnel used for negotiation. + // A tunnel would break prematurely if configuration such as preshared key were applied before the + // normal peer received the server's response. This cannot occur now because the client decides + // when to switch to the ephemeral tunnel. This design also allows the client to switch back to + // using a non-ephemeral tunnel at any point. + // + // The server gives no guarantees how long the ephemeral peer will be valid and working when it's + // no longer in use. The client should negotiate a new ephemeral peer every time it establishes a + // new tunnel to the server. + // + // The request from the VPN client should contain: + // * `wg_parent_pubkey` - The public key used by the current tunnel (that the request travels + // inside). + // * `wg_ephemeral_peer_pubkey` - A newly generated ephemeral WireGuard public key for the + // ephemeral peer. The server will associate the new configuration with this key. + // * One or more requests for different types of options. See the individual messages for more + // information. If a request is provided, a corresponding response may be returned in the + // server's response. + rpc RegisterPeerV1(EphemeralPeerRequestV1) returns (EphemeralPeerResponseV1) {} +} + +message EphemeralPeerRequestV1 { + bytes wg_parent_pubkey = 1; + bytes wg_ephemeral_peer_pubkey = 2; + PostQuantumRequestV1 post_quantum = 3; + DaitaRequestV1 daita = 4; +} + +// The v1 request supports exactly two algorithms. +// The algorithms can appear soletary or in mixed order: +// - "Classic-McEliece-460896f", but explicitly identified as "Classic-McEliece-460896f-round3" +// - "Kyber1024" +message PostQuantumRequestV1 { repeated KemPubkeyV1 kem_pubkeys = 1; } + +message KemPubkeyV1 { + string algorithm_name = 1; + bytes key_data = 2; +} + +message DaitaRequestV1 { bool activate_daita = 1; } + +message EphemeralPeerResponseV1 { + // The response from the VPN server contains: + // * `ciphertexts` - A list of the ciphertexts (the encapsulated shared secrets) for all + // public keys in `kem_pubkeys` in the request, in the same order as in the request. + // + // # Deriving the WireGuard PSK + // + // The PSK to be used in WireGuard's preshared-key field is computed by XORing the resulting + // shared secrets of all the KEM algorithms. All currently supported and planned to be + // supported algorithms output 32 bytes, so this is trivial. + // + // Since the PSK provided to WireGuard is directly fed into a HKDF, it is not important that + // the entropy in the PSK is uniformly distributed. The actual keys used for encrypting the + // data channel will have uniformly distributed entropy anyway, thanks to the HKDF. + // But even if that was not true, since both CME and Kyber run SHAKE256 as the last step + // of their internal key derivation, the output they produce are uniformly distributed. + // + // If we later want to support another type of KEM that produce longer or shorter output, + // we can hash that secret into a 32 byte hash before proceeding to the XOR step. + // + // Mixing with XOR (A = B ^ C) is fine since nothing about A is revealed even if one of B or C + // is known. Both B *and* C must be known to compute any bit in A. This means all involved + // KEM algorithms must be broken before the PSK can be computed by an attacker. + PostQuantumResponseV1 post_quantum = 1; +} + +message PostQuantumResponseV1 { repeated bytes ciphertexts = 1; } diff --git a/talpid-tunnel-config-client/proto/tunnel_config.proto b/talpid-tunnel-config-client/proto/tunnel_config.proto deleted file mode 100644 index e6e4b73d97..0000000000 --- a/talpid-tunnel-config-client/proto/tunnel_config.proto +++ /dev/null @@ -1,83 +0,0 @@ -syntax = "proto3"; - -option go_package = "github.com/mullvad/wg-manager/server/tuncfg"; - -package tunnel_config; - -service PostQuantumSecure { - // Allows deriving a preshared key (PSK) using one or multiple PQ-secure key-encapsulation - // mechanisms (KEM). The preshared key is added to WireGuard's preshared-key field in a new - // ephemeral peer (PQ-peer). This makes the tunnel resistant towards attacks using - // quantum computers. - // - // The VPN server associates the PQ-peer with the peer who performed the exchange. Any - // already existing PQ-peer for the normal peer is replaced. Each normal peer can have - // at most one PQ-peer. - // - // The PQ-peer is mutually exclusive to the normal peer. The server keeps both peers in memory, - // but only one of them is loaded into WireGuard at any point in time. A handshake from the - // normal peer unloads the corresponding PQ-peer from WireGuard and vice versa. - // - // A new peer is negotiated for PQ to avoid a premature break of the tunnel used for negotiation. - // A tunnel would break prematurely if the preshared key is applied before the normal peer - // received the server's contribution to the KEM exchange. This cannot occur now because - // the client decides when to switch to the PQ-secure tunnel. This design also allows - // the client to switch back to using a non-PQ-secure tunnel at any point. - // - // The negotiated PQ-peer is ephemeral. The server gives no guarantees how long it will be - // valid and working. The client should negotiate a new PQ-peer every time it establishes a new - // tunnel to the server. - // - // The full exchange requires just a single request-response round trip between the VPN client - // and the VPN server. - // - // # Request-response format - // - // The request from the VPN client contains: - // * `wg_pubkey` - The public key used by the current tunnel (that the request travels inside). - // * `wg_psk_pubkey` - A newly generated ephemeral WireGuard public key for the PQ-peer. - // The server will associate the derived PSK with this public key. - // * `kem_pubkeys` - A list describing the KEM algorithms. Must have at least one entry. - // The same KEM must not be listed more than once. Each list item contains: - // * `algorithm_name` - The name of the KEM, including which variant. Should be the same - // name/format that `liboqs` uses. - // * `key_data` - The client's public key for this KEM. Will be used by the server to - // encapsulate the shared secret for this KEM. - // - // The response from the VPN server contains: - // * `ciphertexts` - A list of the ciphertexts (the encapsulated shared secrets) for all - // public keys in `kem_pubkeys` in the request, in the same order as in the request. - // - // # Deriving the WireGuard PSK - // - // The PSK to be used in WireGuard's preshared-key field is computed by XORing the resulting - // shared secrets of all the KEM algorithms. All currently supported and planned to be - // supported algorithms output 32 bytes, so this is trivial. - // - // Since the PSK provided to WireGuard is directly fed into a HKDF, it is not important that - // the entropy in the PSK is uniformly distributed. The actual keys used for encrypting the - // data channel will have uniformly distributed entropy anyway, thanks to the HKDF. - // But even if that was not true, since both CME and Kyber run SHAKE256 as the last step - // of their internal key derivation, the output they produce are uniformly distributed. - // - // If we later want to support another type of KEM that produce longer or shorter output, - // we can hash that secret into a 32 byte hash before proceeding to the XOR step. - // - // Mixing with XOR (A = B ^ C) is fine since nothing about A is revealed even if one of B or C - // is known. Both B *and* C must be known to compute any bit in A. This means all involved - // KEM algorithms must be broken before the PSK can be computed by an attacker. - rpc PskExchangeV1(PskRequestV1) returns (PskResponseV1) {} -} - -message PskRequestV1 { - bytes wg_pubkey = 1; - bytes wg_psk_pubkey = 2; - repeated KemPubkeyV1 kem_pubkeys = 3; -} - -message KemPubkeyV1 { - string algorithm_name = 1; - bytes key_data = 2; -} - -message PskResponseV1 { repeated bytes ciphertexts = 1; } diff --git a/talpid-tunnel-config-client/src/classic_mceliece.rs b/talpid-tunnel-config-client/src/classic_mceliece.rs index a2d5dc0be2..2036bc3fc7 100644 --- a/talpid-tunnel-config-client/src/classic_mceliece.rs +++ b/talpid-tunnel-config-client/src/classic_mceliece.rs @@ -1,6 +1,5 @@ -use classic_mceliece_rust::{ - keypair_boxed, Ciphertext, PublicKey, SecretKey, SharedSecret, CRYPTO_CIPHERTEXTBYTES, -}; +use classic_mceliece_rust::{keypair_boxed, Ciphertext, CRYPTO_CIPHERTEXTBYTES}; +pub use classic_mceliece_rust::{PublicKey, SecretKey, SharedSecret}; /// The `keypair_boxed` function needs just under 1 MiB of stack in debug /// builds. Even though it probably works to run it directly on the main diff --git a/talpid-tunnel-config-client/src/kyber.rs b/talpid-tunnel-config-client/src/kyber.rs index 5654b2e10b..003c88dc48 100644 --- a/talpid-tunnel-config-client/src/kyber.rs +++ b/talpid-tunnel-config-client/src/kyber.rs @@ -1,6 +1,5 @@ -use pqc_kyber::{SecretKey, KYBER_CIPHERTEXTBYTES}; - -pub use pqc_kyber::{keypair, KyberError}; +use pqc_kyber::KYBER_CIPHERTEXTBYTES; +pub use pqc_kyber::{keypair, KyberError, SecretKey}; /// Use the strongest variant of Kyber. It is fast and the keys are small, so there is no practical /// benefit of going with anything lower. diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs index 2c7b5e58f3..e89b1be42d 100644 --- a/talpid-tunnel-config-client/src/lib.rs +++ b/talpid-tunnel-config-client/src/lib.rs @@ -13,7 +13,7 @@ mod kyber; #[allow(clippy::derive_partial_eq_without_eq)] mod proto { - tonic::include_proto!("tunnel_config"); + tonic::include_proto!("ephemeralpeer"); } use libc::setsockopt; @@ -34,6 +34,7 @@ use sys::*; pub enum Error { GrpcConnectError(tonic::transport::Error), GrpcError(tonic::Status), + MissingCiphertexts, InvalidCiphertextLength { algorithm: &'static str, actual: usize, @@ -51,6 +52,7 @@ impl std::fmt::Display for Error { match self { GrpcConnectError(_) => "Failed to connect to config service".fmt(f), GrpcError(status) => write!(f, "RPC failed: {status}"), + MissingCiphertexts => write!(f, "Found no ciphertexts in response"), InvalidCiphertextLength { algorithm, actual, @@ -77,7 +79,7 @@ impl std::error::Error for Error { } } -type RelayConfigService = proto::post_quantum_secure_client::PostQuantumSecureClient<Channel>; +type RelayConfigService = proto::ephemeral_peer_client::EphemeralPeerClient<Channel>; /// Port used by the tunnel config service. pub const CONFIG_SERVICE_PORT: u16 = 1337; @@ -93,72 +95,118 @@ pub const CONFIG_SERVICE_PORT: u16 = 1337; /// handshake to work even if there is fragmentation. const CONFIG_CLIENT_MTU: u16 = 576; -/// Generates a new WireGuard key pair and negotiates a PSK with the relay in a PQ-safe -/// manner. This creates a peer on the relay with the new WireGuard pubkey and PSK, -/// which can then be used to establish a PQ-safe tunnel to the relay. -// TODO: consider binding to the tunnel interface here, on non-windows platforms -pub async fn push_pq_key( +pub struct EphemeralPeer { + pub psk: Option<PresharedKey>, +} + +/// Negotiate a short-lived peer with a PQ-safe PSK or with DAITA enabled. +pub async fn request_ephemeral_peer( + service_address: IpAddr, + parent_pubkey: PublicKey, + ephemeral_pubkey: PublicKey, + enable_post_quantum: bool, + enable_daita: bool, +) -> Result<EphemeralPeer, Error> { + request_ephemeral_peer_with_opts( + service_address, + parent_pubkey, + ephemeral_pubkey, + enable_post_quantum, + enable_daita, + ) + .await +} + +pub async fn request_ephemeral_peer_with_opts( service_address: IpAddr, - wg_pubkey: PublicKey, - wg_psk_pubkey: PublicKey, -) -> Result<PresharedKey, Error> { - let (cme_kem_pubkey, cme_kem_secret) = classic_mceliece::generate_keys().await; - let kyber_keypair = kyber::keypair(&mut rand::thread_rng()); + parent_pubkey: PublicKey, + ephemeral_pubkey: PublicKey, + enable_post_quantum: bool, + enable_daita: bool, +) -> Result<EphemeralPeer, Error> { + let (pq_request, kem_secrets) = if enable_post_quantum { + let (cme_kem_pubkey, cme_kem_secret) = classic_mceliece::generate_keys().await; + let kyber_keypair = kyber::keypair(&mut rand::thread_rng()); + + ( + Some(proto::PostQuantumRequestV1 { + kem_pubkeys: vec![ + proto::KemPubkeyV1 { + algorithm_name: classic_mceliece::ALGORITHM_NAME.to_owned(), + key_data: cme_kem_pubkey.as_array().to_vec(), + }, + proto::KemPubkeyV1 { + algorithm_name: kyber::ALGORITHM_NAME.to_owned(), + key_data: kyber_keypair.public.to_vec(), + }, + ], + }), + Some((cme_kem_secret, kyber_keypair.secret)), + ) + } else { + (None, None) + }; + + let daita = Some(proto::DaitaRequestV1 { + activate_daita: enable_daita, + }); let mut client = new_client(service_address).await?; let response = client - .psk_exchange_v1(proto::PskRequestV1 { - wg_pubkey: wg_pubkey.as_bytes().to_vec(), - wg_psk_pubkey: wg_psk_pubkey.as_bytes().to_vec(), - kem_pubkeys: vec![ - proto::KemPubkeyV1 { - algorithm_name: classic_mceliece::ALGORITHM_NAME.to_owned(), - key_data: cme_kem_pubkey.as_array().to_vec(), - }, - proto::KemPubkeyV1 { - algorithm_name: kyber::ALGORITHM_NAME.to_owned(), - key_data: kyber_keypair.public.to_vec(), - }, - ], + .register_peer_v1(proto::EphemeralPeerRequestV1 { + wg_parent_pubkey: parent_pubkey.as_bytes().to_vec(), + wg_ephemeral_peer_pubkey: ephemeral_pubkey.as_bytes().to_vec(), + post_quantum: pq_request, + daita, }) .await .map_err(Error::GrpcError)?; - let ciphertexts = response.into_inner().ciphertexts; + let psk = if let Some((cme_kem_secret, kyber_secret)) = kem_secrets { + let ciphertexts = response + .into_inner() + .post_quantum + .ok_or(Error::MissingCiphertexts)? + .ciphertexts; - // Unpack the ciphertexts into one per KEM without needing to access them by index. - let [cme_ciphertext, kyber_ciphertext] = <&[Vec<u8>; 2]>::try_from(ciphertexts.as_slice()) - .map_err(|_| Error::InvalidCiphertextCount { - actual: ciphertexts.len(), - })?; + // Unpack the ciphertexts into one per KEM without needing to access them by index. + let [cme_ciphertext, kyber_ciphertext] = <&[Vec<u8>; 2]>::try_from(ciphertexts.as_slice()) + .map_err(|_| Error::InvalidCiphertextCount { + actual: ciphertexts.len(), + })?; - // Store the PSK data on the heap. So it can be passed around and then zeroized on drop without - // being stored in a bunch of places on the stack. - let mut psk_data = Box::new([0u8; 32]); + // Store the PSK data on the heap. So it can be passed around and then zeroized on drop without + // being stored in a bunch of places on the stack. + let mut psk_data = Box::new([0u8; 32]); - // Decapsulate Classic McEliece and mix into PSK - { - let mut shared_secret = classic_mceliece::decapsulate(&cme_kem_secret, cme_ciphertext)?; - xor_assign(&mut psk_data, shared_secret.as_array()); + // Decapsulate Classic McEliece and mix into PSK + { + let mut shared_secret = classic_mceliece::decapsulate(&cme_kem_secret, cme_ciphertext)?; + xor_assign(&mut psk_data, shared_secret.as_array()); - // This should happen automatically due to `SharedSecret` implementing ZeroizeOnDrop. But - // doing it explicitly provides a stronger guarantee that it's not accidentally - // removed. - shared_secret.zeroize(); - } - // Decapsulate Kyber and mix into PSK - { - let mut shared_secret = kyber::decapsulate(kyber_keypair.secret, kyber_ciphertext)?; - xor_assign(&mut psk_data, &shared_secret); + // This should happen automatically due to `SharedSecret` implementing ZeroizeOnDrop. But + // doing it explicitly provides a stronger guarantee that it's not accidentally + // removed. + shared_secret.zeroize(); + } + // Decapsulate Kyber and mix into PSK + { + let mut shared_secret = kyber::decapsulate(kyber_secret, kyber_ciphertext)?; + xor_assign(&mut psk_data, &shared_secret); - // The shared secret is sadly stored in an array on the stack. So we can't get any - // guarantees that it's not copied around on the stack. The best we can do here - // is to zero out the version we have and hope the compiler optimizes out copies. - // https://github.com/Argyle-Software/kyber/issues/59 - shared_secret.zeroize(); - } + // The shared secret is sadly stored in an array on the stack. So we can't get any + // guarantees that it's not copied around on the stack. The best we can do here + // is to zero out the version we have and hope the compiler optimizes out copies. + // https://github.com/Argyle-Software/kyber/issues/59 + shared_secret.zeroize(); + } + + Some(PresharedKey::from(psk_data)) + } else { + None + }; - Ok(PresharedKey::from(psk_data)) + Ok(EphemeralPeer { psk }) } /// Performs `dst = dst ^ src`. diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 2b9f0a7366..a17d8ceb5f 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -40,6 +40,8 @@ impl TunnelParameters { obfuscation: None, entry_endpoint: None, tunnel_interface: None, + #[cfg(target_os = "windows")] + daita: false, }, TunnelParameters::Wireguard(params) => TunnelEndpoint { tunnel_type: TunnelType::Wireguard, @@ -55,6 +57,8 @@ impl TunnelParameters { .get_exit_endpoint() .map(|_| params.connection.get_endpoint()), tunnel_interface: None, + #[cfg(target_os = "windows")] + daita: params.options.daita, }, } } @@ -183,6 +187,8 @@ pub struct TunnelEndpoint { pub entry_endpoint: Option<Endpoint>, #[cfg_attr(target_os = "android", jnix(skip))] pub tunnel_interface: Option<String>, + #[cfg(target_os = "windows")] + pub daita: bool, } impl fmt::Display for TunnelEndpoint { diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index db7b2da3a9..f7212236e2 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -60,6 +60,10 @@ pub struct PeerConfig { /// ephemeral and living in memory only. #[serde(skip)] pub psk: Option<PresharedKey>, + /// Enable constant packet sizes for `entry_peer`` + #[cfg(target_os = "windows")] + #[serde(skip)] + pub constant_packet_size: bool, } #[derive(Clone, Eq, PartialEq, Deserialize, Serialize, Debug)] @@ -76,6 +80,9 @@ pub struct TunnelOptions { pub mtu: Option<u16>, /// Perform PQ-safe PSK exchange when connecting pub quantum_resistant: bool, + /// Enable DAITA during tunnel config + #[cfg(target_os = "windows")] + pub daita: bool, } /// Wireguard x25519 private key diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index 1fc6e13b3a..c2562d212c 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] thiserror = { workspace = true } +base64 = "0.13" futures = "0.3.15" hex = "0.4" ipnetwork = "0.16" @@ -54,6 +55,7 @@ talpid-dbus = { path = "../talpid-dbus" } bitflags = "1.2" talpid-windows = { path = "../talpid-windows" } widestring = "1.0" +maybenot = "1.0" # TODO: Figure out which features are needed and which are not [target.'cfg(windows)'.dependencies.windows-sys] diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index 29328eb681..f10a0e4859 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -28,6 +28,10 @@ pub struct Config { pub enable_ipv6: bool, /// Obfuscator config to be used for reaching the relay. pub obfuscator_config: Option<ObfuscatorConfig>, + /// Enable quantum-resistant PSK exchange + pub quantum_resistant: bool, + /// Enable DAITA + pub daita: bool, } /// Configuration errors @@ -92,6 +96,11 @@ impl Config { #[cfg(target_os = "linux")] enable_ipv6: generic_options.enable_ipv6, obfuscator_config: obfuscator_config.to_owned(), + quantum_resistant: wg_options.quantum_resistant, + #[cfg(target_os = "windows")] + daita: wg_options.daita, + #[cfg(not(target_os = "windows"))] + daita: false, }; for peer in config.peers_mut() { diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index e820a3eb73..70f88e6872 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -612,6 +612,11 @@ mod test { ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { Box::pin(async { Ok(()) }) } + + #[cfg(target_os = "windows")] + fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + Ok(()) + } } fn mock_monitor( diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 72d6a31566..bdef52343e 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -260,7 +260,6 @@ impl WireguardMonitor { + 'static, >( mut config: Config, - psk_negotiation: bool, #[cfg(not(target_os = "android"))] detect_mtu: bool, log_path: Option<&Path>, args: TunnelArgs<'_, F>, @@ -283,12 +282,12 @@ impl WireguardMonitor { log_path, args.resource_dir, args.tun_provider.clone(), + #[cfg(target_os = "android")] + config.quantum_resistant, #[cfg(target_os = "windows")] args.route_manager.clone(), #[cfg(target_os = "windows")] setup_done_tx, - #[cfg(target_os = "android")] - psk_negotiation, )?; let iface_name = tunnel.get_interface_name(); @@ -336,7 +335,7 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - let allowed_traffic = if psk_negotiation { + let allowed_traffic = if config.quantum_resistant || config.daita { AllowedTunnelTraffic::One(Endpoint::new( config.ipv4_gateway, talpid_tunnel_config_client::CONFIG_SERVICE_PORT, @@ -365,16 +364,16 @@ impl WireguardMonitor { .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; - let psk_obfs_sender = close_obfs_sender.clone(); - if psk_negotiation { - Self::psk_negotiation( + let ephemeral_obfs_sender = close_obfs_sender.clone(); + if config.quantum_resistant || config.daita { + Self::config_ephemeral_peers( &tunnel, &mut config, args.retry_attempt, args.on_event.clone(), &iface_name, obfuscator.clone(), - psk_obfs_sender, + ephemeral_obfs_sender, #[cfg(target_os = "android")] args.tun_provider, ) @@ -386,6 +385,14 @@ impl WireguardMonitor { let config = config.clone(); let iface_name = iface_name.clone(); tokio::task::spawn(async move { + #[cfg(target_os = "windows")] + if config.daita { + // TODO: For now, we assume the MTU during the tunnel lifetime. + // We could instead poke maybenot whenever we detect changes to it. + log::warn!("MTU detection is not supported with DAITA. Skipping"); + return; + } + if let Err(e) = mtu_detection::automatic_mtu_correction( gateway, iface_name, @@ -465,7 +472,7 @@ impl WireguardMonitor { } #[allow(clippy::too_many_arguments)] - async fn psk_negotiation<F>( + async fn config_ephemeral_peers<F>( tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>, config: &mut Config, retry_attempt: u32, @@ -482,7 +489,7 @@ impl WireguardMonitor { + Clone + 'static, { - let wg_psk_privkey = PrivateKey::new_from_random(); + let ephemeral_private_key = PrivateKey::new_from_random(); let close_obfs_sender = close_obfs_sender.clone(); let allowed_traffic = Endpoint::new( @@ -507,11 +514,17 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(iface_name, config); (on_event)(TunnelEvent::InterfaceUp(metadata, allowed_traffic.clone())).await; - let exit_psk = - Self::perform_psk_negotiation(retry_attempt, config, wg_psk_privkey.public_key()) - .await?; + let exit_should_have_daita = config.daita && !config.is_multihop(); + let exit_psk = Self::request_ephemeral_peer( + retry_attempt, + config, + ephemeral_private_key.public_key(), + config.quantum_resistant, + exit_should_have_daita, + ) + .await?; - log::debug!("Successfully exchanged PSK with exit peer"); + log::debug!("Retrieved ephemeral peer"); if config.is_multihop() { // Set up tunnel to lead to entry @@ -531,22 +544,27 @@ impl WireguardMonitor { &tun_provider, ) .await?; - let entry_psk = Some( - Self::perform_psk_negotiation( - retry_attempt, - &entry_config, - wg_psk_privkey.public_key(), - ) - .await?, - ); + let entry_psk = Self::request_ephemeral_peer( + retry_attempt, + &entry_config, + ephemeral_private_key.public_key(), + config.quantum_resistant, + config.daita, + ) + .await?; log::debug!("Successfully exchanged PSK with entry peer"); config.entry_peer.psk = entry_psk; } - config.exit_peer_mut().psk = Some(exit_psk); + config.exit_peer_mut().psk = exit_psk; + #[cfg(target_os = "windows")] + if config.daita { + log::trace!("Enabling constant packet size for entry peer"); + config.entry_peer.constant_packet_size = true; + } - config.tunnel.private_key = wg_psk_privkey; + config.tunnel.private_key = ephemeral_private_key; *config = Self::reconfigure_tunnel( tunnel, @@ -557,6 +575,19 @@ impl WireguardMonitor { &tun_provider, ) .await?; + + #[cfg(target_os = "windows")] + if config.daita { + // Start local DAITA machines + let mut tunnel = tunnel.lock().unwrap(); + if let Some(tunnel) = tunnel.as_mut() { + tunnel + .start_daita() + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } + } + let metadata = Self::tunnel_metadata(iface_name, config); (on_event)(TunnelEvent::InterfaceUp( metadata, @@ -678,12 +709,14 @@ impl WireguardMonitor { Ok(()) } - async fn perform_psk_negotiation( + async fn request_ephemeral_peer( retry_attempt: u32, config: &Config, wg_psk_pubkey: PublicKey, - ) -> std::result::Result<PresharedKey, CloseMsg> { - log::debug!("Performing PQ-safe PSK exchange"); + enable_pq: bool, + enable_daita: bool, + ) -> std::result::Result<Option<PresharedKey>, CloseMsg> { + log::debug!("Requesting ephemeral peer"); let timeout = std::cmp::min( MAX_PSK_EXCHANGE_TIMEOUT, @@ -691,23 +724,25 @@ impl WireguardMonitor { .saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)), ); - let psk = tokio::time::timeout( + let ephemeral = tokio::time::timeout( timeout, - talpid_tunnel_config_client::push_pq_key( + talpid_tunnel_config_client::request_ephemeral_peer_with_opts( IpAddr::from(config.ipv4_gateway), config.tunnel.private_key.public_key(), wg_psk_pubkey, + enable_pq, + enable_daita, ), ) .await .map_err(|_timeout_err| { - log::warn!("Timeout while negotiating PSK"); + log::warn!("Timeout while negotiating ephemeral peer"); CloseMsg::PskNegotiationTimeout })? .map_err(Error::PskNegotiationError) .map_err(CloseMsg::SetupError)?; - Ok(psk) + Ok(ephemeral.psk) } #[allow(unused_variables)] @@ -717,7 +752,7 @@ impl WireguardMonitor { log_path: Option<&Path>, resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, - #[cfg(target_os = "android")] psk_negotiation: bool, + #[cfg(target_os = "android")] gateway_only: bool, #[cfg(windows)] route_manager: crate::routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { @@ -771,7 +806,7 @@ impl WireguardMonitor { Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes); #[cfg(target_os = "android")] - let config = Self::patch_allowed_ips(config, psk_negotiation); + let config = Self::patch_allowed_ips(config, gateway_only); #[cfg(target_os = "linux")] log::debug!("Using userspace WireGuard implementation"); @@ -994,6 +1029,8 @@ pub(crate) trait Tunnel: Send { &self, _config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>; + #[cfg(target_os = "windows")] + fn start_daita(&mut self) -> std::result::Result<(), TunnelError>; } /// Errors to be returned from WireGuard implementations, namely implementers of the Tunnel trait diff --git a/talpid-wireguard/src/wireguard_nt/daita.rs b/talpid-wireguard/src/wireguard_nt/daita.rs new file mode 100644 index 0000000000..75d6ebaa4d --- /dev/null +++ b/talpid-wireguard/src/wireguard_nt/daita.rs @@ -0,0 +1,450 @@ +use super::WIREGUARD_KEY_LENGTH; +use maybenot::framework::MachineId; +use once_cell::sync::OnceCell; +use std::{collections::HashMap, fs, io, path::Path, time::Duration}; +use std::{os::windows::prelude::RawHandle, sync::Arc}; +use talpid_types::net::wireguard::PublicKey; +use tokio::task::JoinHandle; +use windows_sys::Win32::Foundation::BOOLEAN; +use windows_sys::Win32::{ + Foundation::ERROR_NO_MORE_ITEMS, + System::Threading::{WaitForMultipleObjects, WaitForSingleObject, INFINITE}, +}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Failed to find maybenot machines + #[error("Failed to enumerate maybenot machines")] + EnumerateMachines(#[source] io::Error), + /// Failed to parse maybenot machine + #[error("Failed to parse maybenot machine \"{0}\"")] + InvalidMachine(String), + /// Failed to initialize quit event + #[error("Failed to initialize quit event")] + InitializeQuitEvent(#[source] io::Error), + /// Failed to initialize machinist handle + #[error("Failed to initialize machinist handle")] + InitializeHandle(#[source] io::Error), + /// Failed to initialize maybenot framework + #[error("Failed to initialize maybenot framework: {0}")] + InitializeMaybenot(String), +} + +// See DAITA_EVENT_TYPE: +// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h +#[repr(C)] +#[derive(Debug)] +#[allow(dead_code)] +pub enum EventType { + NonpaddingSent, + NonpaddingReceived, + PaddingSent, + PaddingReceived, +} + +// See DAITA_EVENT: +// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h +#[repr(C)] +#[derive(Debug)] +pub struct Event { + pub peer: [u8; WIREGUARD_KEY_LENGTH], + pub event_type: EventType, + pub xmit_bytes: u16, + pub user_context: usize, +} + +// See DAITA_ACTION_TYPE: +// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h +#[repr(C)] +pub enum ActionType { + InjectPadding, +} + +// See DAITA_PADDING_ACTION: +// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct PaddingAction { + pub byte_count: u16, + pub replace: BOOLEAN, +} + +// See DAITA_ACTION: +// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h +#[repr(C)] +pub struct Action { + pub peer: [u8; WIREGUARD_KEY_LENGTH], + pub action_type: ActionType, + pub payload: ActionPayload, + pub user_context: usize, +} + +#[repr(C)] +pub union ActionPayload { + pub padding: PaddingAction, +} + +/// Maximum number of events that can be stored in the underlying buffer +const EVENTS_CAPACITY: usize = 1000; +/// Maximum number of actions that can be stored in the underlying buffer +const ACTIONS_CAPACITY: usize = 1000; + +pub mod bindings { + use super::*; + use windows_sys::Win32::Foundation::BOOL; + + pub type WireGuardDaitaActivateFn = unsafe extern "stdcall" fn( + adapter: RawHandle, + events_capacity: usize, + actions_capacity: usize, + ) -> BOOL; + pub type WireGuardDaitaEventDataAvailableEventFn = + unsafe extern "stdcall" fn(adapter: RawHandle) -> RawHandle; + pub type WireGuardDaitaReceiveEventsFn = + unsafe extern "stdcall" fn(adapter: RawHandle, events: *mut Event) -> usize; + pub type WireGuardDaitaSendActionFn = + unsafe extern "stdcall" fn(adapter: RawHandle, action: *const Action) -> BOOL; +} + +#[derive(Debug)] +pub struct Session { + adapter: Arc<super::WgNtAdapter>, +} + +impl Session { + /// Call `WireGuardDaitaActivate` for an existing WireGuard interface + pub(super) fn from_adapter(adapter: Arc<super::WgNtAdapter>) -> io::Result<Session> { + // SAFETY: `WgNtAdapter` has a valid adapter handle + unsafe { + adapter + .dll_handle + .daita_activate(adapter.handle, EVENTS_CAPACITY, ACTIONS_CAPACITY) + }?; + Ok(Self { adapter }) + } + + pub fn receive_events<'a>( + &self, + buffer: &'a mut [Event; EVENTS_CAPACITY], + ) -> io::Result<&'a [Event]> { + let num_events = unsafe { + // SAFETY: The adapter is valid, and the buffer is large enough to accommodate all + // events. + self.adapter + .dll_handle + .daita_receive_events(self.adapter.handle, buffer.as_mut_ptr())? + }; + Ok(unsafe { std::slice::from_raw_parts(buffer.as_ptr(), num_events) }) + } + + pub fn send_action(&self, action: &Action) -> io::Result<()> { + // SAFETY: The adapter is valid + unsafe { + self.adapter + .dll_handle + .daita_send_action(self.adapter.handle, action) + } + } + + pub fn event_data_available_event(&self) -> RawHandle { + // SAFETY: The adapter is valid + // This never fails when there's a DAITA session + unsafe { + self.adapter + .dll_handle + .daita_event_data_available_event(self.adapter.handle) + .unwrap() + } + } +} + +fn maybenot_event_from_event( + event: &Event, + machine_ids: &MachineMap, + override_size: Option<u16>, +) -> Option<maybenot::framework::TriggerEvent> { + let xmit_bytes = override_size.unwrap_or(event.xmit_bytes); + match event.event_type { + EventType::PaddingReceived => Some(maybenot::framework::TriggerEvent::PaddingRecv { + bytes_recv: xmit_bytes, + }), + EventType::NonpaddingSent => Some(maybenot::framework::TriggerEvent::NonPaddingSent { + bytes_sent: xmit_bytes, + }), + EventType::NonpaddingReceived => Some(maybenot::framework::TriggerEvent::NonPaddingRecv { + bytes_recv: xmit_bytes, + }), + EventType::PaddingSent => Some(maybenot::framework::TriggerEvent::PaddingSent { + bytes_sent: xmit_bytes, + machine: machine_ids.get_machine_id(event.user_context)?.to_owned(), + }), + } +} + +/// Handle for a set of DAITA machines. +/// Note: `close` is NOT called implicitly when this is dropped. +pub struct MachinistHandle { + quit_event: talpid_windows::sync::Event, +} + +impl MachinistHandle { + fn new(quit_event: &talpid_windows::sync::Event) -> io::Result<MachinistHandle> { + Ok(MachinistHandle { + quit_event: quit_event.duplicate()?, + }) + } + + /// Signal quit event + pub fn close(&self) -> io::Result<()> { + self.quit_event.set() + } +} + +pub struct Machinist { + daita: Arc<Session>, + machine_ids: MachineMap, + machine_tasks: HashMap<usize, JoinHandle<()>>, + tokio_handle: tokio::runtime::Handle, + quit_event: talpid_windows::sync::Event, + peer: PublicKey, + override_size: Option<u16>, +} + +// TODO: This is silly. Let me use the raw ID of MachineId, please. +struct MachineMap { + id_to_num: HashMap<MachineId, usize>, + num_to_id: HashMap<usize, MachineId>, +} + +impl MachineMap { + fn new() -> Self { + Self { + id_to_num: HashMap::new(), + num_to_id: HashMap::new(), + } + } + + fn get_or_create_raw_id(&mut self, machine_id: MachineId) -> usize { + *self.id_to_num.entry(machine_id).or_insert_with(|| { + let raw_id = self.num_to_id.len(); + self.num_to_id.insert(raw_id, machine_id); + raw_id + }) + } + + fn get_machine_id(&self, raw_id: usize) -> Option<&MachineId> { + self.num_to_id.get(&raw_id) + } +} + +impl Machinist { + /// Spawn an actor that handles scheduling of Maybenot actions and forwards DAITA events to the framework. + pub fn spawn( + resource_dir: &Path, + daita: Session, + peer: PublicKey, + mtu: u16, + ) -> std::result::Result<MachinistHandle, Error> { + const MAX_PADDING_BYTES: f64 = 0.0; + const MAX_BLOCKING_BYTES: f64 = 0.0; + + static MAYBENOT_MACHINES: OnceCell<Vec<maybenot::machine::Machine>> = OnceCell::new(); + + let machines = MAYBENOT_MACHINES.get_or_try_init(|| { + let path = resource_dir.join("maybenot_machines"); + log::debug!("Reading maybenot machines from {}", path.display()); + + let mut machines = vec![]; + let machines_str = fs::read_to_string(path).map_err(Error::EnumerateMachines)?; + for machine_str in machines_str.lines() { + let machine_str = machine_str.trim(); + if matches!(machine_str.chars().next(), None | Some('#')) { + continue; + } + log::debug!("Adding maybenot machine: {machine_str}"); + machines.push( + machine_str + .parse::<maybenot::machine::Machine>() + .map_err(|_error| Error::InvalidMachine(machine_str.to_owned()))?, + ); + } + Ok(machines) + })?; + + let quit_event = + talpid_windows::sync::Event::new(true, false).map_err(Error::InitializeQuitEvent)?; + let handle = MachinistHandle::new(&quit_event).map_err(Error::InitializeHandle)?; + + let framework = maybenot::framework::Framework::new( + machines.clone(), + MAX_PADDING_BYTES, + MAX_BLOCKING_BYTES, + mtu, + std::time::Instant::now(), + ) + .map_err(|error| Error::InitializeMaybenot(error.to_string()))?; + + let daita = Arc::new(daita); + let tokio_handle = tokio::runtime::Handle::current(); + + std::thread::spawn(move || { + Self { + daita, + machine_ids: MachineMap::new(), + machine_tasks: HashMap::new(), + tokio_handle, + quit_event, + peer, + // TODO: We're assuming that constant packet size is always enabled here + override_size: Some(mtu), + } + .event_loop(framework); + }); + + Ok(handle) + } + + fn event_loop( + mut self, + mut framework: maybenot::framework::Framework<Vec<maybenot::machine::Machine>>, + ) { + use windows_sys::Win32::Foundation::WAIT_OBJECT_0; + + loop { + if unsafe { WaitForSingleObject(self.quit_event.as_raw(), 0) } == WAIT_OBJECT_0 { + break; + } + + let events = match self.wait_for_events() { + Ok(events) => { + if events.is_empty() { + break; + } + events + } + Err(error) => { + log::error!("Error while waiting for DAITA events: {error}"); + break; + } + }; + + for action in framework.trigger_events(&events, std::time::Instant::now()) { + self.handle_action(action); + } + } + + log::debug!("Stopped DAITA event loop"); + } + + fn handle_action(&mut self, action: &maybenot::framework::Action) { + match *action { + maybenot::framework::Action::Cancel { machine } => { + let raw_id = self.machine_ids.get_or_create_raw_id(machine); + + // Drop all scheduled actions for a given machine + if let Some(task) = self.machine_tasks.get_mut(&raw_id) { + task.abort(); + } + } + maybenot::framework::Action::InjectPadding { + timeout, + size, + machine, + replace, + .. + } => { + let peer = self.peer.clone(); + + let raw_id = self.machine_ids.get_or_create_raw_id(machine); + self.machine_tasks.entry(raw_id).and_modify(|f| f.abort()); + + let action = Action { + peer: *peer.as_bytes(), + action_type: ActionType::InjectPadding, + user_context: raw_id, + payload: ActionPayload { + padding: PaddingAction { + byte_count: size, + replace: if replace { 1 } else { 0 }, + }, + }, + }; + + if timeout == Duration::ZERO { + if let Err(error) = self.daita.send_action(&action) { + log::error!("Failed to send DAITA action: {error}"); + } + } else { + // Schedule action on the tokio runtime + let daita = Arc::downgrade(&self.daita); + let task = self.tokio_handle.spawn(async move { + tokio::time::sleep(timeout).await; + + let Some(daita) = daita.upgrade() else { return }; + + if let Err(error) = daita.send_action(&action) { + log::error!("Failed to send DAITA action: {error}"); + } + }); + self.machine_tasks.insert(raw_id, task); + } + } + maybenot::framework::Action::BlockOutgoing { .. } => {} + } + } + + /// Take all events from the ring buffer while there are any left. + /// If there are no events available, wait for events to arrive. + /// Otherwise, break and return a non-zero number of events to be processed. + /// If the quit event was signaled, this returns an empty vector. + fn wait_for_events(&mut self) -> io::Result<Vec<maybenot::framework::TriggerEvent>> { + use windows_sys::Win32::Foundation::WAIT_OBJECT_0; + + let wait_events = [ + self.quit_event.as_raw(), + self.daita.event_data_available_event() as isize, + ]; + + let mut event_buffer: [Event; EVENTS_CAPACITY] = unsafe { std::mem::zeroed() }; + + loop { + match self.daita.receive_events(&mut event_buffer) { + Ok(events) => { + let converted_events: Vec<_> = events + .iter() + .filter(|event| &event.peer == self.peer.as_bytes()) + .filter_map(|event| { + maybenot_event_from_event(event, &self.machine_ids, self.override_size) + }) + .collect(); + if !converted_events.is_empty() { + return Ok(converted_events); + } + // Try again if we only received events for irrelevant peers + } + Err(error) => { + if error.raw_os_error() == Some(ERROR_NO_MORE_ITEMS as i32) { + let wait_result = unsafe { + WaitForMultipleObjects( + u32::try_from(wait_events.len()).unwrap(), + wait_events.as_ptr(), + 0, + INFINITE, + ) + }; + + if wait_result == WAIT_OBJECT_0 { + // Quit event signaled + break Ok(vec![]); + } + if wait_result == WAIT_OBJECT_0 + 1 { + // Event object signaled -- try to receive more events + continue; + } + } + break Err(std::io::Error::last_os_error()); + } + } + } + } +} diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt/mod.rs index 10cc45b384..6acd6dc710 100644 --- a/talpid-wireguard/src/wireguard_nt.rs +++ b/talpid-wireguard/src/wireguard_nt/mod.rs @@ -9,14 +9,14 @@ use futures::SinkExt; use ipnetwork::IpNetwork; use once_cell::sync::{Lazy, OnceCell}; use std::{ - ffi::CStr, + ffi::{c_uchar, CStr}, fmt, future::Future, - io, mem, - mem::MaybeUninit, + io, + mem::{self, MaybeUninit}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, os::windows::io::RawHandle, - path::Path, + path::{Path, PathBuf}, pin::Pin, ptr, sync::{Arc, Mutex}, @@ -38,6 +38,8 @@ use windows_sys::{ }, }; +mod daita; + static WG_NT_DLL: OnceCell<WgNtDll> = OnceCell::new(); static ADAPTER_TYPE: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); static ADAPTER_ALIAS: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); @@ -159,12 +161,23 @@ pub enum Error { /// Failed to parse data returned by the driver #[error("Failed to parse data returned by wireguard-nt")] InvalidConfigData, + + /// DAITA machinist failed + #[error("Failed to enable DAITA on tunnel device")] + EnableTunnelDaita(#[source] io::Error), + + /// DAITA machinist failed + #[error("Failed to initialize DAITA machinist")] + InitializeMachinist(#[source] daita::Error), } pub struct WgNtTunnel { + resource_dir: PathBuf, + config: Arc<Mutex<Config>>, device: Option<Arc<WgNtAdapter>>, interface_name: String, setup_handle: tokio::task::JoinHandle<()>, + daita_handle: Option<daita::MachinistHandle>, _logger_handle: LoggerHandle, } @@ -305,6 +318,7 @@ bitflags! { const REPLACE_ALLOWED_IPS = 0b00100000; const REMOVE = 0b01000000; const UPDATE = 0b10000000; + const HAS_CONSTANT_PACKET_SIZE = 0b100000000; } } @@ -322,6 +336,7 @@ struct WgPeer { rx_bytes: u64, last_handshake: u64, allowed_ips_count: u32, + constant_packet_size: c_uchar, } #[derive(Clone, Copy)] @@ -446,7 +461,7 @@ impl WgNtTunnel { let device = Some(device2.clone()); let setup_future = setup_ip_listener( - device2, + device2.clone(), u32::from(config.mtu), config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()), ); @@ -457,17 +472,50 @@ impl WgNtTunnel { }); Ok(WgNtTunnel { + resource_dir: resource_dir.to_owned(), + config: Arc::new(Mutex::new(config.clone())), device, interface_name, setup_handle, + daita_handle: None, _logger_handle: logger_handle, }) } fn stop_tunnel(&mut self) { self.setup_handle.abort(); + if let Some(daita_handle) = self.daita_handle.take() { + let _ = daita_handle.close(); + } let _ = self.device.take(); } + + fn spawn_machinist(&mut self) -> Result<()> { + if let Some(handle) = self.daita_handle.take() { + log::info!("Stopping previous DAITA machines"); + let _ = handle.close(); + } + + let Some(device) = self.device.clone() else { + log::debug!("Tunnel is stopped; not starting machines"); + return Ok(()); + }; + + let config = self.config.lock().unwrap(); + + log::info!("Initializing DAITA for wireguard device"); + let session = daita::Session::from_adapter(device).map_err(Error::EnableTunnelDaita)?; + self.daita_handle = Some( + daita::Machinist::spawn( + &self.resource_dir, + session, + config.entry_peer.public_key.clone(), + config.mtu, + ) + .map_err(Error::InitializeMachinist)?, + ); + Ok(()) + } } async fn setup_ip_listener(device: Arc<WgNtAdapter>, mtu: u32, has_ipv6: bool) -> Result<()> { @@ -622,6 +670,11 @@ struct WgNtDll { func_set_adapter_state: WireGuardSetStateFn, func_set_logger: WireGuardSetLoggerFn, func_set_adapter_logging: WireGuardSetAdapterLoggingFn, + + func_daita_activate: daita::bindings::WireGuardDaitaActivateFn, + func_daita_event_data_available_event: daita::bindings::WireGuardDaitaEventDataAvailableEventFn, + func_daita_receive_events: daita::bindings::WireGuardDaitaReceiveEventsFn, + func_daita_send_action: daita::bindings::WireGuardDaitaSendActionFn, } unsafe impl Send for WgNtDll {} @@ -694,6 +747,30 @@ impl WgNtDll { CStr::from_bytes_with_nul(b"WireGuardSetAdapterLogging\0").unwrap(), )?) as *const _ as *const _) }, + func_daita_activate: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardDaitaActivate\0").unwrap(), + )?) as *const _ as *const _) + }, + func_daita_event_data_available_event: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardDaitaEventDataAvailableEvent\0").unwrap(), + )?) as *const _ as *const _) + }, + func_daita_receive_events: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardDaitaReceiveEvents\0").unwrap(), + )?) as *const _ as *const _) + }, + func_daita_send_action: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardDaitaSendAction\0").unwrap(), + )?) as *const _ as *const _) + }, }) } @@ -790,6 +867,52 @@ impl WgNtDll { } Ok(()) } + + pub unsafe fn daita_activate( + &self, + adapter: RawHandle, + events_capacity: usize, + actions_capacity: usize, + ) -> io::Result<()> { + if (self.func_daita_activate)(adapter, events_capacity, actions_capacity) == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } + + pub unsafe fn daita_event_data_available_event( + &self, + adapter: RawHandle, + ) -> io::Result<RawHandle> { + let ready_event = (self.func_daita_event_data_available_event)(adapter); + if ready_event.is_null() { + return Err(io::Error::last_os_error()); + } + Ok(ready_event) + } + + pub unsafe fn daita_receive_events( + &self, + adapter: RawHandle, + events: *mut daita::Event, + ) -> io::Result<usize> { + let num_events = (self.func_daita_receive_events)(adapter, events); + if num_events == 0 { + return Err(io::Error::last_os_error()); + } + Ok(num_events) + } + + pub unsafe fn daita_send_action( + &self, + adapter: RawHandle, + action: *const daita::Action, + ) -> io::Result<()> { + if (self.func_daita_send_action)(adapter, action) == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } } impl Drop for WgNtDll { @@ -816,11 +939,13 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { buffer.extend(as_uninit_byte_slice(&header)); for peer in config.peers() { - let flags = if peer.psk.is_some() { - WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT - } else { - WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT - }; + let mut flags = WgPeerFlag::HAS_PUBLIC_KEY + | WgPeerFlag::HAS_ENDPOINT + | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE; + if peer.psk.is_some() { + flags |= WgPeerFlag::HAS_PRESHARED_KEY; + } + let constant_packet_size = if peer.constant_packet_size { 1 } else { 0 }; let wg_peer = WgPeer { flags, reserved: 0, @@ -836,6 +961,7 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { rx_bytes: 0, last_handshake: 0, allowed_ips_count: u32::try_from(peer.allowed_ips.len()).unwrap(), + constant_packet_size, }; buffer.extend(as_uninit_byte_slice(&wg_peer)); @@ -960,13 +1086,16 @@ impl Tunnel for WgNtTunnel { config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { let device = self.device.clone(); + let current_config = self.config.clone(); Box::pin(async move { let Some(device) = device else { log::error!("Failed to set config: No tunnel device"); return Err(super::TunnelError::SetConfigError); }; - device.set_config(&config).map_err(|error| { + let mut current_config = current_config.lock().unwrap(); + *current_config = config; + device.set_config(¤t_config).map_err(|error| { log::error!( "{}", error.display_chain_with_msg("Failed to set wg-nt tunnel config") @@ -975,6 +1104,16 @@ impl Tunnel for WgNtTunnel { }) }) } + + fn start_daita(&mut self) -> std::result::Result<(), crate::TunnelError> { + self.spawn_machinist().map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to start DAITA for wg-nt tunnel") + ); + super::TunnelError::SetConfigError + }) + } } pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8>] { @@ -984,7 +1123,6 @@ pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8 #[cfg(test)] mod tests { use super::*; - use once_cell::sync::Lazy; use talpid_types::net::wireguard; #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -1009,12 +1147,16 @@ mod tests { allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], endpoint: "1.2.3.4:1234".parse().unwrap(), psk: None, + #[cfg(target_os = "windows")] + constant_packet_size: false, }, exit_peer: None, ipv4_gateway: "0.0.0.0".parse().unwrap(), ipv6_gateway: None, mtu: 0, obfuscator_config: None, + daita: false, + quantum_resistant: false, }); static WG_STRUCT_CONFIG: Lazy<Interface> = Lazy::new(|| Interface { @@ -1026,7 +1168,9 @@ mod tests { peers_count: 1, }, p0: WgPeer { - flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, + flags: WgPeerFlag::HAS_PUBLIC_KEY + | WgPeerFlag::HAS_ENDPOINT + | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE, reserved: 0, public_key: *WG_PUBLIC_KEY.as_bytes(), preshared_key: [0; WIREGUARD_KEY_LENGTH], @@ -1039,6 +1183,8 @@ mod tests { rx_bytes: 0, last_handshake: 0, allowed_ips_count: 1, + #[cfg(target_os = "windows")] + constant_packet_size: 0, }, p0_allowed_ip_0: WgAllowedIp { address: WgIpAddr::from("1.3.3.0".parse::<Ipv4Addr>().unwrap()), |
