diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-11-05 16:20:19 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-12-12 09:54:11 +0100 |
| commit | f1ea8f1b344a0fcb33a7c13b66f2bd9b952252d4 (patch) | |
| tree | b6fba254dac58e09b7c09e54d9dc4631b4c3917d | |
| parent | 8580f9acb75679c7995996c6049ad6f0541d2215 (diff) | |
| download | mullvadvpn-f1ea8f1b344a0fcb33a7c13b66f2bd9b952252d4.tar.xz mullvadvpn-f1ea8f1b344a0fcb33a7c13b66f2bd9b952252d4.zip | |
Use dynamic DAITA machines in wireguard-go
Co-authored-by: Markus Pettersson <markus.pettersson@mullvad.net>
| -rw-r--r-- | talpid-tunnel-config-client/src/lib.rs | 72 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/mock.rs | 5 | ||||
| -rw-r--r-- | talpid-wireguard/src/ephemeral.rs | 36 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 21 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 36 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs | 3 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs | 3 | ||||
| -rw-r--r-- | wireguard-go-rs/libwg/libwg_daita.go | 7 | ||||
| -rw-r--r-- | wireguard-go-rs/src/lib.rs | 6 |
9 files changed, 103 insertions, 86 deletions
diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs index fecd9de32b..f7d559f641 100644 --- a/talpid-tunnel-config-client/src/lib.rs +++ b/talpid-tunnel-config-client/src/lib.rs @@ -22,6 +22,7 @@ mod proto { tonic::include_proto!("ephemeralpeer"); } +#[cfg(all(unix, not(target_os = "ios")))] const DAITA_VERSION: u32 = 2; #[derive(Debug)] @@ -87,6 +88,7 @@ pub const CONFIG_SERVICE_PORT: u16 = 1337; pub struct EphemeralPeer { pub psk: Option<PresharedKey>, + #[cfg(all(unix, not(target_os = "ios")))] pub daita: Option<DaitaSettings>, } @@ -136,9 +138,16 @@ pub async fn request_ephemeral_peer_with( wg_parent_pubkey: parent_pubkey.as_bytes().to_vec(), wg_ephemeral_peer_pubkey: ephemeral_pubkey.as_bytes().to_vec(), post_quantum: pq_request, + #[cfg(any(windows, target_os = "ios"))] + daita: Some(proto::DaitaRequestV1 { + activate_daita: enable_daita, + }), + #[cfg(any(windows, target_os = "ios"))] + daita_v2: None, + #[cfg(all(unix, not(target_os = "ios")))] daita: None, + #[cfg(all(unix, not(target_os = "ios")))] daita_v2: enable_daita.then(|| proto::DaitaRequestV2 { - // TODO level: i32::from(proto::DaitaLevel::LevelDefault), platform: i32::from(get_platform()), version: DAITA_VERSION, @@ -192,43 +201,42 @@ pub async fn request_ephemeral_peer_with( None }; - let daita = response.daita.map(|daita| DaitaSettings { - client_machines: daita.client_machines, - max_padding_frac: daita.max_padding_frac, - max_blocking_frac: daita.max_blocking_frac, - }); - if daita.is_none() && enable_daita { - return Err(Error::MissingDaitaResponse); - } - - Ok(EphemeralPeer { psk, daita }) -} - -fn get_platform() -> proto::DaitaPlatform { - #[cfg(windows)] - { - proto::DaitaPlatform::WindowsNative - } - - #[cfg(target_os = "linux")] - { - proto::DaitaPlatform::LinuxWgGo - } - - #[cfg(target_os = "macos")] + #[cfg(all(unix, not(target_os = "ios")))] { - proto::DaitaPlatform::MacosWgGo + let daita = response.daita.map(|daita| DaitaSettings { + client_machines: daita.client_machines, + max_padding_frac: daita.max_padding_frac, + max_blocking_frac: daita.max_blocking_frac, + }); + if daita.is_none() && enable_daita { + return Err(Error::MissingDaitaResponse); + } + Ok(EphemeralPeer { psk, daita }) } - #[cfg(target_os = "android")] + #[cfg(any(windows, target_os = "ios"))] { - proto::DaitaPlatform::AndroidWgGo + Ok(EphemeralPeer { psk }) } +} - #[cfg(target_os = "ios")] - { - proto::DaitaPlatform::IosWgGo - } +#[cfg(all(unix, not(target_os = "ios")))] +const fn get_platform() -> proto::DaitaPlatform { + use proto::DaitaPlatform; + const PLATFORM: DaitaPlatform = if cfg!(target_os = "windows") { + DaitaPlatform::WindowsNative + } else if cfg!(target_os = "linux") { + DaitaPlatform::LinuxWgGo + } else if cfg!(target_os = "macos") { + DaitaPlatform::MacosWgGo + } else if cfg!(target_os = "android") { + DaitaPlatform::AndroidWgGo + } else if cfg!(target_os = "ios") { + DaitaPlatform::IosWgGo + } else { + panic!("This platform does not support DAITA V2") + }; + PLATFORM } async fn post_quantum_secrets() -> ( diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs index 892f3966ea..eea3004bfc 100644 --- a/talpid-wireguard/src/connectivity/mock.rs +++ b/talpid-wireguard/src/connectivity/mock.rs @@ -118,7 +118,10 @@ impl Tunnel for MockTunnel { } #[cfg(daita)] - fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + fn start_daita( + &mut self, + #[cfg(not(target_os = "windows"))] _: talpid_tunnel_config_client::DaitaSettings, + ) -> std::result::Result<(), TunnelError> { Ok(()) } } diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index a9283fcb2e..31f3957253 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -16,7 +16,8 @@ use std::{ use talpid_tunnel::tun_provider::TunProvider; use ipnetwork::IpNetwork; -use talpid_types::net::wireguard::{PresharedKey, PrivateKey, PublicKey}; +use talpid_tunnel_config_client::EphemeralPeer; +use talpid_types::net::wireguard::{PrivateKey, PublicKey}; use tokio::sync::Mutex as AsyncMutex; const INITIAL_PSK_EXCHANGE_TIMEOUT: Duration = Duration::from_secs(8); @@ -100,7 +101,7 @@ async fn config_ephemeral_peers_inner( let close_obfs_sender = close_obfs_sender.clone(); let exit_should_have_daita = config.daita && !config.is_multihop(); - let exit_psk = request_ephemeral_peer( + let exit_ephemeral_peer = request_ephemeral_peer( retry_attempt, config, ephemeral_private_key.public_key(), @@ -109,6 +110,9 @@ async fn config_ephemeral_peers_inner( ) .await?; + #[cfg(not(target_os = "windows"))] + let mut daita = exit_ephemeral_peer.daita; + log::debug!("Retrieved ephemeral peer"); if config.is_multihop() { @@ -130,8 +134,7 @@ async fn config_ephemeral_peers_inner( &tun_provider, ) .await?; - - let entry_psk = request_ephemeral_peer( + let entry_ephemeral_peer = request_ephemeral_peer( retry_attempt, &entry_config, ephemeral_private_key.public_key(), @@ -141,10 +144,14 @@ async fn config_ephemeral_peers_inner( .await?; log::debug!("Successfully exchanged PSK with entry peer"); - config.entry_peer.psk = entry_psk; + config.entry_peer.psk = entry_ephemeral_peer.psk; + #[cfg(not(target_os = "windows"))] + { + daita = entry_ephemeral_peer.daita; + } } - config.exit_peer_mut().psk = exit_psk; + config.exit_peer_mut().psk = exit_ephemeral_peer.psk; #[cfg(daita)] if config.daita { log::trace!("Enabling constant packet size for entry peer"); @@ -165,9 +172,22 @@ async fn config_ephemeral_peers_inner( #[cfg(daita)] if config.daita { + #[cfg(not(target_os = "windows"))] + let Some(daita) = daita + else { + unreachable!("missing DAITA settings"); + }; + // Start local DAITA machines let mut tunnel = tunnel.lock().await; if let Some(tunnel) = tunnel.as_mut() { + #[cfg(not(target_os = "windows"))] + tunnel + .start_daita(daita) + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + + #[cfg(target_os = "windows")] tunnel .start_daita() .map_err(Error::TunnelError) @@ -254,7 +274,7 @@ async fn request_ephemeral_peer( wg_psk_pubkey: PublicKey, enable_pq: bool, enable_daita: bool, -) -> std::result::Result<Option<PresharedKey>, CloseMsg> { +) -> std::result::Result<EphemeralPeer, CloseMsg> { log::debug!("Requesting ephemeral peer"); let timeout = std::cmp::min( @@ -281,5 +301,5 @@ async fn request_ephemeral_peer( .map_err(Error::EphemeralPeerNegotiationError) .map_err(CloseMsg::SetupError)?; - Ok(ephemeral.psk) + Ok(ephemeral) } diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index c3cf9a554f..2d282c6315 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -28,6 +28,8 @@ use talpid_tunnel::{ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, }; +#[cfg(not(target_os = "windows"))] +use talpid_tunnel_config_client::DaitaSettings; use talpid_types::{ net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol}, BoxedError, ErrorExt, @@ -196,6 +198,7 @@ impl WireguardMonitor { args.runtime.clone(), &config, log_path, + #[cfg(target_os = "windows")] args.resource_dir, args.tun_provider.clone(), #[cfg(target_os = "windows")] @@ -426,7 +429,6 @@ impl WireguardMonitor { let tunnel = Self::open_wireguard_go_tunnel( &config, log_path, - args.resource_dir, args.tun_provider.clone(), // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs // that we only allows traffic to/from the gateway. This is only needed on Android @@ -634,7 +636,7 @@ impl WireguardMonitor { runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, - resource_dir: &Path, + #[cfg(windows)] resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, @@ -646,8 +648,7 @@ impl WireguardMonitor { // If DAITA is enabled, wireguard-go has to be used. if config.daita { let tunnel = - Self::open_wireguard_go_tunnel(config, log_path, resource_dir, tun_provider) - .map(Box::new)?; + Self::open_wireguard_go_tunnel(config, log_path, tun_provider).map(Box::new)?; return Ok(tunnel); } @@ -699,8 +700,6 @@ impl WireguardMonitor { let tunnel = Self::open_wireguard_go_tunnel( config, log_path, - #[cfg(daita)] - resource_dir, tun_provider, #[cfg(target_os = "android")] gateway_only, @@ -715,7 +714,6 @@ impl WireguardMonitor { fn open_wireguard_go_tunnel( config: &Config, log_path: Option<&Path>, - #[cfg(daita)] resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(target_os = "android")] gateway_only: bool, #[cfg(target_os = "android")] connectivity_check: connectivity::Check< @@ -733,8 +731,6 @@ impl WireguardMonitor { log_path, tun_provider, routes, - #[cfg(daita)] - resource_dir, ) .map_err(Error::TunnelError)?; @@ -757,8 +753,6 @@ impl WireguardMonitor { log_path, tun_provider, routes, - #[cfg(daita)] - resource_dir, connectivity_check, ) .map_err(Error::TunnelError)? @@ -769,8 +763,6 @@ impl WireguardMonitor { log_path, tun_provider, routes, - #[cfg(daita)] - resource_dir, connectivity_check, ) .map_err(Error::TunnelError)? @@ -994,6 +986,9 @@ pub(crate) trait Tunnel: Send { ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'a>>; #[cfg(daita)] /// A [`Tunnel`] capable of using DAITA. + #[cfg(not(target_os = "windows"))] + fn start_daita(&mut self, settings: DaitaSettings) -> std::result::Result<(), TunnelError>; + #[cfg(target_os = "windows")] fn start_daita(&mut self) -> std::result::Result<(), TunnelError>; } diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index e2072c291e..c0775ac339 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -13,19 +13,18 @@ use crate::connectivity; use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; #[cfg(daita)] -use once_cell::sync::OnceCell; -#[cfg(daita)] -use std::{ffi::CString, fs, path::PathBuf}; +use std::ffi::CString; use std::{ future::Future, os::unix::io::{AsRawFd, RawFd}, - path::Path, + path::{Path, PathBuf}, pin::Pin, sync::{Arc, Mutex}, }; #[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; use talpid_tunnel::tun_provider::{Tun, TunProvider}; +use talpid_tunnel_config_client::DaitaSettings; #[cfg(target_os = "android")] use talpid_types::net::wireguard::PeerConfig; use talpid_types::BoxedError; @@ -115,8 +114,6 @@ impl WgGoTunnel { let log_path = state._logging_context.path.clone(); let tun_provider = Arc::clone(&state.tun_provider); let routes = config.get_tunnel_destinations(); - #[cfg(daita)] - let resource_dir = state.resource_dir.clone(); match self { WgGoTunnel::Multihop(state) if !config.is_multihop() => { @@ -126,7 +123,6 @@ impl WgGoTunnel { log_path.as_deref(), tun_provider, routes, - &resource_dir, connectivity_checker, ) } @@ -138,7 +134,6 @@ impl WgGoTunnel { log_path.as_deref(), tun_provider, routes, - &resource_dir, connectivity_checker, ) } @@ -169,8 +164,6 @@ pub(crate) struct WgGoTunnelState { #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, #[cfg(daita)] - resource_dir: PathBuf, - #[cfg(daita)] config: Config, // HACK: Check is not Clone, so we have to pass this around .. // This is conceptually the connection between this Tunnel and the currently running @@ -224,7 +217,6 @@ impl WgGoTunnel { log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, - #[cfg(daita)] resource_dir: &Path, ) -> Result<Self> { let (tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?; @@ -251,8 +243,6 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, #[cfg(daita)] - resource_dir: resource_dir.to_owned(), - #[cfg(daita)] config: config.clone(), })) } @@ -303,7 +293,6 @@ impl WgGoTunnel { log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, - #[cfg(daita)] resource_dir: &Path, mut connectivity_check: connectivity::Check<connectivity::Cancellable>, ) -> Result<Self> { let (mut tunnel_device, tunnel_fd) = @@ -334,8 +323,6 @@ impl WgGoTunnel { _logging_context: logging_context, tun_provider, #[cfg(daita)] - resource_dir: resource_dir.to_owned(), - #[cfg(daita)] config: config.clone(), connectivity_checker: None, }); @@ -353,7 +340,6 @@ impl WgGoTunnel { log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, - #[cfg(daita)] resource_dir: &Path, mut connectivity_check: connectivity::Check<connectivity::Cancellable>, ) -> Result<Self> { let (mut tunnel_device, tunnel_fd) = @@ -400,8 +386,6 @@ impl WgGoTunnel { _logging_context: logging_context, tun_provider, #[cfg(daita)] - resource_dir: resource_dir.to_owned(), - #[cfg(daita)] config: config.clone(), connectivity_checker: None, }); @@ -477,20 +461,22 @@ impl Tunnel for WgGoTunnel { } #[cfg(daita)] - fn start_daita(&mut self) -> Result<()> { - static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new(); - let machines = MAYBENOT_MACHINES - .get_or_try_init(|| load_maybenot_machines(&self.as_state().resource_dir))?; - + fn start_daita(&mut self, settings: DaitaSettings) -> Result<()> { log::info!("Initializing DAITA for wireguard device"); let config = &self.as_state().config; let peer_public_key = &config.entry_peer.public_key; + let machines = settings.client_machines.join("\n"); + let machines = + CString::new(machines).map_err(|err| TunnelError::StartDaita(Box::new(err)))?; + self.as_state() .tunnel_handle .activate_daita( peer_public_key.as_bytes(), - machines, + &machines, + settings.max_padding_frac, + settings.max_blocking_frac, DAITA_EVENTS_CAPACITY, DAITA_ACTIONS_CAPACITY, ) diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs index 52d8616c02..8b84b3769d 100644 --- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs @@ -1,6 +1,7 @@ use std::pin::Pin; use futures::Future; +use talpid_tunnel_config_client::DaitaSettings; use crate::config::MULLVAD_INTERFACE_NAME; @@ -131,7 +132,7 @@ impl Tunnel for NetlinkTunnel { } /// Outright fail to start - this tunnel type does not support DAITA. - fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + fn start_daita(&mut self, _: DaitaSettings) -> std::result::Result<(), TunnelError> { Err(TunnelError::DaitaNotSupported) } } diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs index 76cfbcaddd..070e3d1ee9 100644 --- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs @@ -13,6 +13,7 @@ use talpid_dbus::{ WireguardTunnel, }, }; +use talpid_tunnel_config_client::DaitaSettings; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -114,7 +115,7 @@ impl Tunnel for NetworkManagerTunnel { } /// Outright fail to start - this tunnel type does not support DAITA. - fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + fn start_daita(&mut self, _: DaitaSettings) -> std::result::Result<(), TunnelError> { Err(TunnelError::DaitaNotSupported) } } diff --git a/wireguard-go-rs/libwg/libwg_daita.go b/wireguard-go-rs/libwg/libwg_daita.go index fbfceec8f0..b73be376a3 100644 --- a/wireguard-go-rs/libwg/libwg_daita.go +++ b/wireguard-go-rs/libwg/libwg_daita.go @@ -19,11 +19,8 @@ import ( "golang.zx2c4.com/wireguard/device" ) -const maxPaddingBytes = 0.0 -const maxBlockingBytes = 0.0 - //export wgActivateDaita -func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C.char, eventsCapacity C.uint32_t, actionsCapacity C.uint32_t) C.int32_t { +func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C.char, maxPaddingFrac C.double, maxBlockingFrac C.double, eventsCapacity C.uint32_t, actionsCapacity C.uint32_t) C.int32_t { tunnel, err := tunnels.Get(int32(tunnelHandle)) if err != nil { @@ -46,7 +43,7 @@ func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C. return ERROR_UNKNOWN_PEER } - if !peer.EnableDaita(goStringFixed((*C.char)(machines)), uint(eventsCapacity), uint(actionsCapacity), maxPaddingBytes, maxBlockingBytes) { + if !peer.EnableDaita(goStringFixed((*C.char)(machines)), uint(eventsCapacity), uint(actionsCapacity), float64(maxPaddingFrac), float64(maxBlockingFrac)) { return ERROR_ENABLE_DAITA } diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs index 851fd47b9f..3e75506f61 100644 --- a/wireguard-go-rs/src/lib.rs +++ b/wireguard-go-rs/src/lib.rs @@ -191,6 +191,8 @@ impl Tunnel { &self, peer_public_key: &[u8; 32], machines: &CStr, + max_padding_frac: f64, + max_blocking_frac: f64, events_capacity: u32, actions_capacity: u32, ) -> Result<(), Error> { @@ -200,6 +202,8 @@ impl Tunnel { self.handle, peer_public_key.as_ptr(), machines.as_ptr(), + max_padding_frac, + max_blocking_frac, events_capacity, actions_capacity, ) @@ -342,6 +346,8 @@ mod ffi { tunnel_handle: i32, peer_public_key: *const u8, machines: *const c_char, + max_padding_frac: f64, + max_blocking_frac: f64, events_capacity: u32, actions_capacity: u32, ) -> i32; |
