diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-09-18 14:15:09 +0200 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-09-18 14:15:09 +0200 |
| commit | 63c1abe65257b3357bbca696ad832968e6ecab44 (patch) | |
| tree | c1663932f2fea79df9a4d4783edc3460abd53333 | |
| parent | a8c1ca3474cb7741debf5ade3fdbd21de8dcc660 (diff) | |
| parent | 0264abaf2801709bb9e78e533b8873a0ee3ae6dd (diff) | |
| download | mullvadvpn-63c1abe65257b3357bbca696ad832968e6ecab44.tar.xz mullvadvpn-63c1abe65257b3357bbca696ad832968e6ecab44.zip | |
Merge branch 'timeout-negotiating-ephemeral-peer-des-1238'
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 214 |
1 files changed, 190 insertions, 24 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index a477bea80b..0c918e1fc7 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -20,7 +20,7 @@ use std::{ }; #[cfg(target_os = "linux")] use std::{env, sync::LazyLock}; -use talpid_routing as routing; +#[cfg(not(target_os = "android"))] use talpid_routing::{self, RequiredRoute}; #[cfg(not(windows))] use talpid_tunnel::tun_provider; @@ -264,6 +264,7 @@ async fn maybe_create_obfuscator( impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config + #[cfg(not(target_os = "android"))] pub fn start< F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + Send @@ -272,7 +273,7 @@ impl WireguardMonitor { + 'static, >( mut config: Config, - #[cfg(not(target_os = "android"))] detect_mtu: bool, + detect_mtu: bool, log_path: Option<&Path>, args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { @@ -294,8 +295,6 @@ impl WireguardMonitor { log_path, args.resource_dir, args.tun_provider.clone(), - #[cfg(target_os = "android")] - config.quantum_resistant, #[cfg(target_os = "windows")] args.route_manager.clone(), #[cfg(target_os = "windows")] @@ -303,15 +302,6 @@ impl WireguardMonitor { )?; let iface_name = tunnel.get_interface_name(); - #[cfg(target_os = "android")] - if let Some(remote_socket_fd) = obfuscator.as_ref().map(|obfs| obfs.remote_socket_fd()) { - // Exclude remote obfuscation socket or bridge - log::debug!("Excluding remote socket fd from the tunnel"); - if let Err(error) = args.tun_provider.lock().unwrap().bypass(remote_socket_fd) { - log::error!("Failed to exclude remote socket fd: {error}"); - } - } - let obfuscator = Arc::new(AsyncMutex::new(obfuscator)); let event_callback = Box::new(on_event.clone()); @@ -376,8 +366,6 @@ impl WireguardMonitor { args.retry_attempt, obfuscator.clone(), ephemeral_obfs_sender, - #[cfg(target_os = "android")] - args.tun_provider, ) .await?; @@ -389,7 +377,6 @@ impl WireguardMonitor { .await; } - #[cfg(not(target_os = "android"))] if detect_mtu { let config = config.clone(); let iface_name = iface_name.clone(); @@ -420,6 +407,7 @@ impl WireguardMonitor { }; }); } + let mut connectivity_monitor = tokio::task::spawn_blocking(move || { match connectivity_monitor.establish_connectivity(args.retry_attempt) { Ok(true) => Ok(connectivity_monitor), @@ -480,6 +468,177 @@ impl WireguardMonitor { Ok(monitor) } + /// Starts a WireGuard tunnel with the given config + /// + /// This differs from [`start`] on other platforms in multiple ways. Here is a list of some + /// notable differences: + /// - A ping is sent between the Wireguard-GO tunnel is started and an ephemeral peer is + /// negotiated. There seems to be a race condition between starting the tunnel and the tunnel + /// being ready to serve traffic. + /// - No routes are configured on android. + #[cfg(target_os = "android")] + pub fn start< + F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + + Send + + Sync + + Clone + + 'static, + >( + mut config: Config, + log_path: Option<&Path>, + args: TunnelArgs<'_, F>, + ) -> Result<WireguardMonitor> { + let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; + let tunnel = Self::open_tunnel( + args.runtime.clone(), + &config, + log_path, + args.resource_dir, + args.tun_provider.clone(), + // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs + // that we only allows traffic to/from the gateway. This is only needed on Android + // since we lack a firewall there. + should_negotiate_ephemeral_peer, + )?; + + let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel(); + let obfuscator = args.runtime.block_on(maybe_create_obfuscator( + &mut config, + close_obfs_sender.clone(), + ))?; + + if let Some(remote_socket_fd) = obfuscator.as_ref().map(|obfs| obfs.remote_socket_fd()) { + // Exclude remote obfuscation socket or bridge + log::debug!("Excluding remote socket fd from the tunnel"); + if let Err(error) = args.tun_provider.lock().unwrap().bypass(remote_socket_fd) { + log::error!("Failed to exclude remote socket fd: {error}"); + } + } + + let iface_name = tunnel.get_interface_name(); + + let (pinger_tx, pinger_rx) = sync_mpsc::channel(); + let monitor = WireguardMonitor { + runtime: args.runtime.clone(), + tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), + event_callback: Box::new(args.on_event.clone()), + close_msg_receiver: close_obfs_listener, + pinger_stop_sender: pinger_tx, + obfuscator: Arc::new(AsyncMutex::new(obfuscator)), + }; + + let gateway = config.ipv4_gateway; + let connectivity_monitor = connectivity_check::ConnectivityMonitor::new( + gateway, + Arc::downgrade(&monitor.tunnel), + pinger_rx, + ) + .map_err(Error::ConnectivityMonitorError)?; + + let moved_tunnel = monitor.tunnel.clone(); + let moved_close_obfs_sender = close_obfs_sender.clone(); + let moved_obfuscator = monitor.obfuscator.clone(); + let tunnel_fut = async move { + let tunnel = moved_tunnel; + let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender; + let obfuscator = moved_obfuscator; + let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor)); + + let metadata = Self::tunnel_metadata(&iface_name, &config); + let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); + (args.on_event.clone())(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + .await; + + let handle_ping = |ping_result: std::result::Result< + bool, + connectivity_check::Error, + >| match ping_result { + Ok(true) => Ok(()), + Ok(false) => { + log::warn!("Timeout while checking tunnel connection"); + Err(CloseMsg::PingErr) + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check tunnel connection") + ); + Err(CloseMsg::PingErr) + } + }; + + // Prepare a closure which pings inside the tunnel when executed. + let ping = || { + let connectivity_monitor_arc = connectivity_monitor.clone(); + let retry_attempt = args.retry_attempt; + move || { + let ping_result = connectivity_monitor_arc + .lock() + .unwrap() + .establish_connectivity(retry_attempt); + handle_ping(ping_result) + } + }; + + if should_negotiate_ephemeral_peer { + // Ping before negotiating the ephemeral peer to make sure that the tunnel works. + tokio::task::spawn_blocking(ping()).await.unwrap()?; + let ephemeral_obfs_sender = close_obfs_sender.clone(); + Self::config_ephemeral_peers( + &tunnel, + &mut config, + args.retry_attempt, + obfuscator.clone(), + ephemeral_obfs_sender, + args.tun_provider, + ) + .await?; + + let metadata = Self::tunnel_metadata(&iface_name, &config); + (args.on_event.clone())(TunnelEvent::InterfaceUp( + metadata, + Self::allowed_traffic_after_tunnel_config(), + )) + .await; + } + + // Make sure the tunnel works (after potentially having negotiated an ephemeral peer). + tokio::task::spawn_blocking(ping()).await.unwrap()?; + + let metadata = Self::tunnel_metadata(&iface_name, &config); + (args.on_event.clone())(TunnelEvent::Up(metadata)).await; + + tokio::task::spawn_blocking(move || { + if let Err(error) = connectivity_monitor.lock().unwrap().run() { + log::error!( + "{}", + error.display_chain_with_msg("Connectivity monitor failed") + ); + } + }) + .await + .unwrap(); + + Err::<Infallible, CloseMsg>(CloseMsg::PingErr) + }; + + let close_sender = close_obfs_sender.clone(); + let monitor_handle = tokio::spawn(async move { + // This is safe to unwrap because the future resolves to `Result<Infallible, E>`. + let close_msg = tunnel_fut.await.unwrap_err(); + let _ = close_sender.send(close_msg); + }); + + tokio::spawn(async move { + if args.tunnel_close_rx.await.is_ok() { + monitor_handle.abort(); + let _ = close_obfs_sender.send(CloseMsg::Stop); + } + }); + + Ok(monitor) + } + fn allowed_traffic_during_tunnel_config(config: &Config) -> AllowedTunnelTraffic { // During ephemeral peer negotiation, only allow traffic to the config service. if config.quantum_resistant || config.daita { @@ -754,7 +913,7 @@ impl WireguardMonitor { resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(target_os = "android")] gateway_only: bool, - #[cfg(windows)] route_manager: crate::routing::RouteManagerHandle, + #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { log::debug!("Tunnel MTU: {}", config.mtu); @@ -894,6 +1053,7 @@ impl WireguardMonitor { /// Returns routes to the peer endpoints (through the physical interface). #[cfg_attr(target_os = "linux", allow(unused_variables))] + #[cfg(not(target_os = "android"))] fn get_endpoint_routes(endpoints: &[IpAddr]) -> impl Iterator<Item = RequiredRoute> + '_ { #[cfg(target_os = "linux")] { @@ -904,37 +1064,42 @@ impl WireguardMonitor { endpoints.iter().map(|ip| { RequiredRoute::new( ipnetwork::IpNetwork::from(*ip), - routing::NetNode::DefaultNode, + talpid_routing::NetNode::DefaultNode, ) }) } #[cfg_attr(not(target_os = "windows"), allow(unused_variables))] - fn get_tunnel_nodes(iface_name: &str, config: &Config) -> (routing::Node, routing::Node) { + #[cfg(not(target_os = "android"))] + fn get_tunnel_nodes( + iface_name: &str, + config: &Config, + ) -> (talpid_routing::Node, talpid_routing::Node) { #[cfg(windows)] { - let v4 = routing::Node::new(config.ipv4_gateway.into(), iface_name.to_string()); + let v4 = talpid_routing::Node::new(config.ipv4_gateway.into(), iface_name.to_string()); let v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() { - routing::Node::new((*ipv6_gateway).into(), iface_name.to_string()) + talpid_routing::Node::new((*ipv6_gateway).into(), iface_name.to_string()) } else { - routing::Node::device(iface_name.to_string()) + talpid_routing::Node::device(iface_name.to_string()) }; (v4, v6) } #[cfg(not(windows))] { - let node = routing::Node::device(iface_name.to_string()); + let node = talpid_routing::Node::device(iface_name.to_string()); (node.clone(), node) } } /// Return routes for all allowed IPs, as well as the gateway, except 0.0.0.0/0. + #[cfg(not(target_os = "android"))] fn get_pre_tunnel_routes<'a>( iface_name: &str, config: &'a Config, ) -> impl Iterator<Item = RequiredRoute> + 'a { - let gateway_node = routing::Node::device(iface_name.to_string()); + let gateway_node = talpid_routing::Node::device(iface_name.to_string()); let gateway_routes = std::iter::once(RequiredRoute::new( ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), gateway_node.clone(), @@ -965,6 +1130,7 @@ impl WireguardMonitor { } /// Return any 0.0.0.0/0 routes specified by the allowed IPs. + #[cfg(not(target_os = "android"))] fn get_post_tunnel_routes<'a>( iface_name: &str, config: &'a Config, |
