diff options
| author | Kalle Lindström <karl.lindstrom@mullvad.net> | 2024-10-29 11:33:01 +0100 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-11-22 17:42:38 +0100 |
| commit | 40a9dc2269d408e5ae819fad44a19eec7aa7965a (patch) | |
| tree | 43a5cf66f7d397fe4a7500e5cd51d828c77e1fad | |
| parent | 9de3f60296a8b19aba1f69ff86f2e876c1f7f9ee (diff) | |
| download | mullvadvpn-40a9dc2269d408e5ae819fad44a19eec7aa7965a.tar.xz mullvadvpn-40a9dc2269d408e5ae819fad44a19eec7aa7965a.zip | |
Add multihop negotiation with ephemeral peers
Use `WgGoTunnel` directly on Android because a specialized implemenation
of `set_config` has to be used.
| -rw-r--r-- | talpid-wireguard/src/config.rs | 7 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity_check.rs | 6 | ||||
| -rw-r--r-- | talpid-wireguard/src/ephemeral.rs | 69 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 74 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 270 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs | 1 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs | 1 |
7 files changed, 298 insertions, 130 deletions
diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index 5326427d13..ae5194e200 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -185,6 +185,13 @@ impl Config { .into_iter() .chain(std::iter::once(&mut self.entry_peer)) } + + /// Return routes for all allowed IPs. + pub fn get_tunnel_destinations(&self) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { + self.peers() + .flat_map(|peer| peer.allowed_ips.iter()) + .cloned() + } } enum ConfValue<'a> { diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index 608002d1a6..35096a0ce2 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -1,6 +1,7 @@ use crate::{ ping_monitor::{new_pinger, Pinger}, stats::StatsMap, + TunnelType, }; use std::{ cmp, @@ -71,7 +72,8 @@ pub enum Error { /// Once a connection established, a connection is only considered broken once the connectivity /// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. pub struct ConnectivityMonitor { - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, + /// Tunnel implementation + tunnel_handle: Weak<Mutex<Option<TunnelType>>>, conn_state: ConnState, initial_ping_timestamp: Option<Instant>, num_pings_sent: u32, @@ -83,7 +85,7 @@ impl ConnectivityMonitor { pub(super) fn new( addr: Ipv4Addr, #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, + tunnel_handle: Weak<Mutex<Option<TunnelType>>>, close_receiver: mpsc::Receiver<()>, ) -> Result<Self, Error> { let pinger = new_pinger( diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index 5440a142f6..07f611c842 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -1,7 +1,7 @@ //! This module takes care of obtaining ephemeral peers, updating the WireGuard configuration and //! restarting obfuscation and WG tunnels when necessary. -use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, Tunnel}; +use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, Tunnel, TunnelType}; #[cfg(target_os = "android")] use std::sync::Mutex; use std::{ @@ -22,7 +22,7 @@ const PSK_EXCHANGE_TIMEOUT_MULTIPLIER: u32 = 2; #[cfg(windows)] pub async fn config_ephemeral_peers( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, @@ -64,15 +64,14 @@ fn try_set_ipv4_mtu(alias: &str, mtu: u16) { } } -#[cfg(not(windows))] pub async fn config_ephemeral_peers( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, -) -> std::result::Result<(), CloseMsg> { +) -> Result<(), CloseMsg> { config_ephemeral_peers_inner( tunnel, config, @@ -86,16 +85,17 @@ pub async fn config_ephemeral_peers( } async fn config_ephemeral_peers_inner( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, -) -> std::result::Result<(), CloseMsg> { +) -> Result<(), CloseMsg> { let ephemeral_private_key = PrivateKey::new_from_random(); let close_obfs_sender = close_obfs_sender.clone(); + // NOTE: this might be the entry? let exit_should_have_daita = config.daita && !config.is_multihop(); let exit_psk = request_ephemeral_peer( retry_attempt, @@ -111,6 +111,7 @@ async fn config_ephemeral_peers_inner( if config.is_multihop() { // Set up tunnel to lead to entry let mut entry_tun_config = config.clone(); + entry_tun_config.exit_peer = None; entry_tun_config .entry_peer .allowed_ips @@ -158,6 +159,8 @@ async fn config_ephemeral_peers_inner( ) .await?; + log::info!("Config: {config:#?}"); + #[cfg(daita)] if config.daita { // Start local DAITA machines @@ -173,15 +176,16 @@ async fn config_ephemeral_peers_inner( Ok(()) } +#[cfg(target_os = "android")] /// Reconfigures the tunnel to use the provided config while potentially modifying the config /// and restarting the obfuscation provider. Returns the new config used by the new tunnel. async fn reconfigure_tunnel( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, mut config: Config, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, - #[cfg(target_os = "android")] tun_provider: &Arc<Mutex<TunProvider>>, -) -> std::result::Result<Config, CloseMsg> { + tun_provider: &Arc<Mutex<TunProvider>>, +) -> Result<Config, CloseMsg> { let mut obfs_guard = obfuscator.lock().await; if let Some(obfuscator_handle) = obfs_guard.take() { obfuscator_handle.abort(); @@ -195,16 +199,45 @@ async fn reconfigure_tunnel( .map_err(CloseMsg::ObfuscatorFailed)?; } - let mut tunnel = tunnel.lock().await; + let mut lock = tunnel.lock().await; - let set_config_future = tunnel - .as_mut() - .map(|tunnel| tunnel.set_config(config.clone())); + let tunnel = lock.take().expect("tunnel was None"); - if let Some(f) = set_config_future { - f.await - .map_err(Error::TunnelError) - .map_err(CloseMsg::SetupError)?; + let new_tunnel = tunnel.better_set_config(&config).unwrap(); + + *lock = Some(new_tunnel); + Ok(config) +} + +#[cfg(not(target_os = "android"))] +/// Reconfigures the tunnel to use the provided config while potentially modifying the config +/// and restarting the obfuscation provider. Returns the new config used by the new tunnel. +async fn reconfigure_tunnel( + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, + mut config: Config, + obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, + close_obfs_sender: sync_mpsc::Sender<CloseMsg>, +) -> Result<Config, CloseMsg> { + let mut obfs_guard = obfuscator.lock().await; + if let Some(obfuscator_handle) = obfs_guard.take() { + obfuscator_handle.abort(); + *obfs_guard = super::obfuscation::apply_obfuscation_config(&mut config, close_obfs_sender) + .await + .map_err(CloseMsg::ObfuscatorFailed)?; + } + + { + let mut tunnel = tunnel.lock().await; + + let set_config_future = tunnel + .as_mut() + .map(|tunnel| tunnel.set_config(config.clone())); + + if let Some(f) = set_config_future { + f.await + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } } Ok(config) diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 27509e00b6..948cbc0129 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -54,6 +54,12 @@ mod mtu_detection; #[cfg(wireguard_go)] use self::wireguard_go::WgGoTunnel; +// TODO: Document why we have a type alias ! +#[cfg(not(target_os = "android"))] +type TunnelType = Box<dyn Tunnel>; +#[cfg(target_os = "android")] +type TunnelType = WgGoTunnel; + type Result<T> = std::result::Result<T, Error>; type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; @@ -134,7 +140,7 @@ impl Error { pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation - tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: Arc<AsyncMutex<Option<TunnelType>>>, /// Callback to signal tunnel events event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, @@ -396,8 +402,8 @@ impl WireguardMonitor { args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { let desired_mtu = get_desired_mtu(params); - let mut config = crate::config::Config::from_parameters(params, desired_mtu) - .map_err(Error::WireguardConfigError)?; + let mut config = + Config::from_parameters(params, desired_mtu).map_err(Error::WireguardConfigError)?; let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel(); // Start obfuscation server and patch the WireGuard config to point the endpoint to it. @@ -417,8 +423,7 @@ impl WireguardMonitor { } let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; - let tunnel = Self::open_tunnel( - args.runtime.clone(), + let tunnel = Self::open_wireguard_go_tunnel( &config, log_path, args.resource_dir, @@ -428,13 +433,13 @@ impl WireguardMonitor { // since we lack a firewall there. should_negotiate_ephemeral_peer, )?; - let iface_name = tunnel.get_interface_name(); + let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), - tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), + tunnel: Arc::clone(&tunnel), event_callback: Box::new(args.on_event.clone()), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, @@ -444,23 +449,21 @@ impl WireguardMonitor { let gateway = config.ipv4_gateway; let connectivity_monitor = connectivity_check::ConnectivityMonitor::new( gateway, - Arc::downgrade(&monitor.tunnel), + Arc::downgrade(&tunnel), pinger_rx, ) .map_err(Error::ConnectivityMonitorError)?; - let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); let tunnel_fut = async move { - let tunnel = moved_tunnel; let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender; let obfuscator = moved_obfuscator; let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor)); let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - (args.on_event.clone())(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; let handle_ping = |ping_result: std::result::Result< @@ -498,6 +501,7 @@ impl WireguardMonitor { // Ping before negotiating the ephemeral peer to make sure that the tunnel works. tokio::task::spawn_blocking(ping()).await.unwrap()?; let ephemeral_obfs_sender = close_obfs_sender.clone(); + ephemeral::config_ephemeral_peers( &tunnel, &mut config, @@ -585,6 +589,8 @@ impl WireguardMonitor { /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. /// Used to block traffic to other destinations while connecting on Android. + /// + /// TODO: This might need some patchin' now when multihop is a thing. #[cfg(target_os = "android")] fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> { if gateway_only { @@ -654,16 +660,16 @@ impl WireguardMonitor { } #[allow(unused_variables)] + #[cfg(not(target_os = "android"))] fn open_tunnel( runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, - #[cfg(target_os = "android")] gateway_only: bool, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, - ) -> Result<Box<dyn Tunnel>> { + ) -> Result<TunnelType> { log::debug!("Tunnel MTU: {}", config.mtu); #[cfg(target_os = "linux")] @@ -751,9 +757,33 @@ impl WireguardMonitor { let exit_config = wireguard_go::exit_config(&config); + let should_negotiate_with_ephemeral_peer = gateway_only; #[cfg(target_os = "android")] - let tunnel = if exit_config.is_some() { - WgGoTunnel::start_multihop_tunnel( + let tunnel = match exit_config { + // Android uses multihop implemented in Mullvad's wireguard-go fork. When negotiating + // with an ephemeral peer, this multihop strategy require us to restart the tunnel + // every time we want to reconfigure it. As such, we will actually start a multihop + // tunnel at a later stage, after we have negotiated with the first ephemeral peer. + // At this point, when the tunnel *is first started*, we establish a regular, singlehop + // tunnel to where the ephemeral peer resides. + // + // TODO: Refer to `docs/architecture.md` for details on how to use multihop + PQ. + Some(_exit_config) if should_negotiate_with_ephemeral_peer => { + WgGoTunnel::start_multihop_tunnel( + #[allow(clippy::needless_borrow)] + // TODO: Check if `entry` should be used instead ?? + &config, + log_path, + tun_provider, + routes, + #[cfg(daita)] + resource_dir, + ) + .map_err(Error::TunnelError)? + } + // If we don't need to negotiate with an ephemeral peer, we may simply start a multihop + // tunnel from the get-go. + Some(_exit_config) => WgGoTunnel::start_multihop_tunnel( #[allow(clippy::needless_borrow)] &config, log_path, @@ -762,9 +792,8 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir, ) - .map_err(Error::TunnelError)? - } else { - WgGoTunnel::start_tunnel( + .map_err(Error::TunnelError)?, + None => WgGoTunnel::start_tunnel( #[allow(clippy::needless_borrow)] &config, log_path, @@ -773,7 +802,7 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir, ) - .map_err(Error::TunnelError)? + .map_err(Error::TunnelError)?, }; #[cfg(not(target_os = "android"))] @@ -956,12 +985,10 @@ impl WireguardMonitor { } } + // TODO: Remove? /// Return routes for all allowed IPs. fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { - config - .peers() - .flat_map(|peer| peer.allowed_ips.iter()) - .cloned() + config.get_tunnel_destinations() } /// Replace default (0-prefix) routes with more specific routes. @@ -1001,6 +1028,7 @@ enum CloseMsg { ObfuscatorFailed(Error), } +#[allow(unused)] pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 74cee0c768..c2543ed917 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -44,12 +44,88 @@ impl Drop for LoggingContext { } } -pub struct WgGoTunnel { +pub enum WgGoTunnel { + Multihop(WgGoTunnelState), + Singlehop(WgGoTunnelState), +} + +impl WgGoTunnel { + fn into_state(self) -> WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + fn as_state(&self) -> &WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + fn to_state_mut(&mut self) -> &mut WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + pub fn better_set_config(self, config: &Config) -> Result<Self> { + let state = self.as_state(); + let log_path = state._log_path.clone(); + let tun_provider = Arc::clone(&state.tun_provider); + let routes = config.get_tunnel_destinations(); + #[cfg(daita)] + let resource_dir = state.resource_dir.clone(); + + match self { + WgGoTunnel::Multihop(state) if !config.is_multihop() => { + // Important! + state.stop().unwrap(); + Self::start_tunnel( + config, + log_path.as_deref(), + tun_provider, + routes, + &resource_dir, + ) + } + WgGoTunnel::Singlehop(state) if config.is_multihop() => { + state.stop().unwrap(); + Self::start_multihop_tunnel( + config, + log_path.as_deref(), + tun_provider, + routes, + &resource_dir, + ) + } + WgGoTunnel::Singlehop(mut state) => { + state.set_config(config.clone())?; + Ok(WgGoTunnel::Singlehop(state)) + } + WgGoTunnel::Multihop(mut state) => { + state.set_config(config.clone())?; + Ok(WgGoTunnel::Multihop(state)) + } + } + } + + pub fn stop(self) -> Result<()> { + self.into_state().stop() + } +} + +// TODO: Does this need to be pub? +pub struct WgGoTunnelState { interface_name: String, tunnel_handle: wireguard_go_rs::Tunnel, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped _tunnel_device: Tun, + // HACK: Don't use this. Only sometimes. ;-) + _log_path: Option<PathBuf>, // context that maps to fs::File instance, used with logging callback _logging_context: LoggingContext, #[cfg(target_os = "android")] @@ -60,6 +136,42 @@ pub struct WgGoTunnel { config: Config, } +impl WgGoTunnelState { + fn stop(self) -> Result<()> { + self.tunnel_handle + .turn_off() + .map_err(|e| TunnelError::StopWireguardError(Box::new(e))) + } + + fn set_config(&mut self, config: Config) -> Result<()> { + let wg_config_str = config.to_userspace_format(); + + self.tunnel_handle + .set_config(&wg_config_str) + .map_err(|_| TunnelError::SetConfigError)?; + + #[cfg(target_os = "android")] + let tun_provider = self.tun_provider.clone(); + + // When reapplying the config, the endpoint socket may be discarded + // and needs to be excluded again + #[cfg(target_os = "android")] + { + let socket_v4 = self.tunnel_handle.get_socket_v4(); + let socket_v6 = self.tunnel_handle.get_socket_v6(); + let mut provider = tun_provider.lock().unwrap(); + provider + .bypass(socket_v4) + .map_err(super::TunnelError::BypassError)?; + provider + .bypass(socket_v6) + .map_err(super::TunnelError::BypassError)?; + } + + Ok(()) + } +} + // TODO: move into impl of Config pub(crate) fn exit_config(multihop_config: &Config) -> Option<Config> { let mut exit_config = multihop_config.clone(); @@ -68,7 +180,7 @@ pub(crate) fn exit_config(multihop_config: &Config) -> Option<Config> { } // TODO: move into impl of Config -fn entry_config(multihop_config: &Config) -> Config { +pub(crate) fn entry_config(multihop_config: &Config) -> Config { let mut entry_config = multihop_config.clone(); entry_config.exit_peer = None; entry_config @@ -116,7 +228,7 @@ impl WgGoTunnel { ) .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; - Ok(WgGoTunnel { + Ok(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -128,7 +240,47 @@ impl WgGoTunnel { }) } - #[cfg(target_os = "android")] + fn get_tunnel( + tun_provider: Arc<Mutex<TunProvider>>, + config: &Config, + routes: impl Iterator<Item = IpNetwork>, + ) -> Result<(Tun, RawFd)> { + let mut last_error = None; + let mut tun_provider = tun_provider.lock().unwrap(); + + let tun_config = tun_provider.config_mut(); + #[cfg(target_os = "linux")] + { + tun_config.name = Some(MULLVAD_INTERFACE_NAME.to_string()); + } + tun_config.addresses = config.tunnel.addresses.clone(); + tun_config.ipv4_gateway = config.ipv4_gateway; + tun_config.ipv6_gateway = config.ipv6_gateway; + tun_config.routes = routes.collect(); + tun_config.mtu = config.mtu; + + for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { + let tunnel_device = tun_provider + .open_tun() + .map_err(TunnelError::SetupTunnelDevice)?; + + match nix::unistd::dup(tunnel_device.as_raw_fd()) { + Ok(fd) => return Ok((tunnel_device, fd)), + #[cfg(not(target_os = "macos"))] + Err(error @ nix::errno::Errno::EBADFD) => last_error = Some(error), + Err(error @ nix::errno::Errno::EBADF) => last_error = Some(error), + Err(error) => return Err(TunnelError::FdDuplicationError(error)), + } + } + + Err(TunnelError::FdDuplicationError( + last_error.expect("Should be collected in loop"), + )) + } +} + +#[cfg(target_os = "android")] +impl WgGoTunnel { pub fn start_tunnel( config: &Config, log_path: Option<&Path>, @@ -142,6 +294,7 @@ impl WgGoTunnel { let interface_name: String = tunnel_device.interface_name().to_string(); let wg_config_str = config.to_userspace_format(); + let _log_path = log_path; let logging_context = initialize_logging(log_path) .map(LoggingContext) .map_err(TunnelError::LoggingError)?; @@ -157,20 +310,20 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - Ok(WgGoTunnel { + Ok(WgGoTunnel::Singlehop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, _logging_context: logging_context, + _log_path: _log_path.map(|log_path| log_path.to_owned()), tun_provider: tun_provider_clone, #[cfg(daita)] resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), - }) + })) } - #[cfg(target_os = "android")] pub fn start_multihop_tunnel( config: &Config, log_path: Option<&Path>, @@ -183,6 +336,7 @@ impl WgGoTunnel { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?; let interface_name: String = tunnel_device.interface_name().to_string(); + let _log_path = log_path; let logging_context = initialize_logging(log_path) .map(LoggingContext) .map_err(TunnelError::LoggingError)?; @@ -209,20 +363,20 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - Ok(WgGoTunnel { + Ok(WgGoTunnel::Multihop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, _logging_context: logging_context, + _log_path: _log_path.map(|log_path| log_path.to_owned()), tun_provider: tun_provider_clone, #[cfg(daita)] resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), - }) + })) } - #[cfg(target_os = "android")] fn bypass_tunnel_sockets( handle: &wireguard_go_rs::Tunnel, tunnel_device: &mut Tun, @@ -235,53 +389,20 @@ impl WgGoTunnel { Ok(()) } - - fn get_tunnel( - tun_provider: Arc<Mutex<TunProvider>>, - config: &Config, - routes: impl Iterator<Item = IpNetwork>, - ) -> Result<(Tun, RawFd)> { - let mut last_error = None; - let mut tun_provider = tun_provider.lock().unwrap(); - - let tun_config = tun_provider.config_mut(); - #[cfg(target_os = "linux")] - { - tun_config.name = Some(MULLVAD_INTERFACE_NAME.to_string()); - } - tun_config.addresses = config.tunnel.addresses.clone(); - tun_config.ipv4_gateway = config.ipv4_gateway; - tun_config.ipv6_gateway = config.ipv6_gateway; - tun_config.routes = routes.collect(); - tun_config.mtu = config.mtu; - - for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { - let tunnel_device = tun_provider - .open_tun() - .map_err(TunnelError::SetupTunnelDevice)?; - - match nix::unistd::dup(tunnel_device.as_raw_fd()) { - Ok(fd) => return Ok((tunnel_device, fd)), - #[cfg(not(target_os = "macos"))] - Err(error @ nix::errno::Errno::EBADFD) => last_error = Some(error), - Err(error @ nix::errno::Errno::EBADF) => last_error = Some(error), - Err(error) => return Err(TunnelError::FdDuplicationError(error)), - } - } - - Err(TunnelError::FdDuplicationError( - last_error.expect("Should be collected in loop"), - )) - } } impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> String { - self.interface_name.clone() + self.as_state().interface_name.clone() + } + + fn stop(self: Box<Self>) -> Result<()> { + self.into_state().stop() } fn get_tunnel_stats(&self) -> Result<StatsMap> { - self.tunnel_handle + self.as_state() + .tunnel_handle .get_config(|cstr| { Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) }) @@ -289,54 +410,29 @@ impl Tunnel for WgGoTunnel { .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) } - fn stop(self: Box<Self>) -> Result<()> { - self.tunnel_handle - .turn_off() - .map_err(|e| TunnelError::StopWireguardError(Box::new(e))) - } - fn set_config( &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> { - Box::pin(async move { - let wg_config_str = config.to_userspace_format(); - - self.tunnel_handle - .set_config(&wg_config_str) - .map_err(|_| TunnelError::SetConfigError)?; - - #[cfg(target_os = "android")] - let tun_provider = self.tun_provider.clone(); - - // When reapplying the config, the endpoint socket may be discarded - // and needs to be excluded again - #[cfg(target_os = "android")] - { - let socket_v4 = self.tunnel_handle.get_socket_v4(); - let socket_v6 = self.tunnel_handle.get_socket_v6(); - let mut provider = tun_provider.lock().unwrap(); - provider - .bypass(socket_v4) - .map_err(super::TunnelError::BypassError)?; - provider - .bypass(socket_v6) - .map_err(super::TunnelError::BypassError)?; - } - - Ok(()) - }) + Box::pin(async move { self.to_state_mut().set_config(config) }) } #[cfg(daita)] fn start_daita(&mut self) -> Result<()> { static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new(); - let machines = - MAYBENOT_MACHINES.get_or_try_init(|| load_maybenot_machines(&self.resource_dir))?; + let machines = MAYBENOT_MACHINES + .get_or_try_init(|| load_maybenot_machines(&self.as_state().resource_dir))?; log::info!("Initializing DAITA for wireguard device"); - let peer_public_key = &self.config.entry_peer.public_key; - self.tunnel_handle + let config = &self.as_state().config; + + let peer_public_key = match config.exit_peer.as_ref() { + Some(exit) => &exit.public_key, + None => &config.entry_peer.public_key, + }; + + self.as_state() + .tunnel_handle .activate_daita( peer_public_key.as_bytes(), machines, diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs index 52d8616c02..69b7687579 100644 --- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs @@ -1,3 +1,4 @@ +use std::any::Any; use std::pin::Pin; use futures::Future; diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs index 76cfbcaddd..84230c22a1 100644 --- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs @@ -5,6 +5,7 @@ use super::{ Config, Error as WgKernelError, Handle, Tunnel, TunnelError, }; use futures::Future; +use std::any::Any; use std::{collections::HashMap, pin::Pin}; use talpid_dbus::{ dbus, |
