diff options
| author | David Göransson <david.goransson@mullvad.net> | 2025-02-20 11:09:42 +0100 |
|---|---|---|
| committer | David Göransson <david.goransson@mullvad.net> | 2025-02-25 12:02:00 +0100 |
| commit | b63bc866946795be36a617adf65c8c6db071b05d (patch) | |
| tree | c93db6a693a01510653723812d0985c94f6cd910 | |
| parent | a473a917e1bfad3c7d9baa1a948eacb5096455aa (diff) | |
| download | mullvadvpn-b63bc866946795be36a617adf65c8c6db071b05d.tar.xz mullvadvpn-b63bc866946795be36a617adf65c8c6db071b05d.zip | |
Reduce open_tun calls (Establish)
Each call to Establish opens a window for leaks on android. By only
invoking Establish if the VpnConfig if any of the input has changed and
reusing it otherwise we avoid many of these leaks. This commit also
waits for android to report back that the routes have been created to
ping and verify connectivity to avoid pings going outside the tunnel.
| -rw-r--r-- | android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt | 3 | ||||
| -rw-r--r-- | android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt | 17 | ||||
| -rw-r--r-- | android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt | 29 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 4 | ||||
| -rw-r--r-- | talpid-routing/src/unix/android.rs | 52 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 20 | ||||
| -rw-r--r-- | talpid-tunnel/src/tun_provider/android/mod.rs | 86 | ||||
| -rw-r--r-- | talpid-tunnel/src/tun_provider/mod.rs | 12 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/check.rs | 12 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 18 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 96 |
11 files changed, 246 insertions, 103 deletions
diff --git a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt index 27e7658a11..e3faaf3884 100644 --- a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt @@ -34,6 +34,9 @@ class TalpidVpnServiceFallbackDnsTest { every { talpidVpnService.prepareVpnSafe() } returns Prepared.right() builderMockk = mockk<VpnService.Builder>() + every { talpidVpnService getProperty "connectivityListener" } returns + mockk<ConnectivityListener>(relaxed = true) + mockkConstructor(VpnService.Builder::class) every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index 9c82d62251..b702a39a6e 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -10,6 +10,7 @@ import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.map @@ -18,7 +19,7 @@ import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.launch import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent import net.mullvad.talpid.util.RawNetworkState @@ -31,29 +32,30 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) val isConnected get() = _isConnected.value - private lateinit var _currentNetworkState: StateFlow<NetworkState?> + private val _mutableNetworkState = MutableStateFlow<NetworkState?>(null) private val resetNetworkState: Channel<Unit> = Channel() // Used by JNI val currentDefaultNetworkState: NetworkState? - get() = _currentNetworkState.value + get() = _mutableNetworkState.value // Used by JNI val currentDnsServers: ArrayList<InetAddress> - get() = _currentNetworkState.value?.dnsServers ?: ArrayList() + get() = _mutableNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { // Consider implementing retry logic for the flows below, because registering a listener on // the default network may fail if the network on Android 11 // https://issuetracker.google.com/issues/175055271?pli=1 - _currentNetworkState = + scope.launch { merge( connectivityManager.defaultRawNetworkStateFlow(), resetNetworkState.receiveAsFlow().map { null }, ) .map { it?.toNetworkState() } .onEach { notifyDefaultNetworkChange(it) } - .stateIn(scope, SharingStarted.Eagerly, null) + .collect(_mutableNetworkState) + } _isConnected = hasInternetCapability() @@ -70,8 +72,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) * know the last known values not to be correct anymore. */ fun invalidateNetworkStateCache() { - // TODO remove runBlocking - runBlocking { resetNetworkState.send(Unit) } + _mutableNetworkState.value = null } private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> = diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index a143df6132..1457ff35f4 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -57,34 +57,22 @@ open class TalpidVpnService : LifecycleVpnService() { // Used by JNI fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { - val tunStatus = activeTunStatus - - if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - tunStatus - } else { - openTunImpl(config) + createTun(config).merge().also { + currentTunConfig = config + activeTunStatus = it } } // Used by JNI - fun openTunForced(config: TunConfig): CreateTunResult = - synchronized(this) { openTunImpl(config) } - - // Used by JNI - fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + fun closeTun(): Unit = + synchronized(this) { + connectivityListener.invalidateNetworkStateCache() + activeTunStatus = null + } // Used by JNI fun bypass(socket: Int): Boolean = protect(socket) - private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config).merge() - - currentTunConfig = config - activeTunStatus = newTunStatus - - return newTunStatus - } - private fun createTun( config: TunConfig ): Either<CreateTunResult.Error, CreateTunResult.Success> = either { @@ -123,6 +111,7 @@ open class TalpidVpnService : LifecycleVpnService() { builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } + connectivityListener.invalidateNetworkStateCache() val vpnInterfaceFd = builder .establishSafe() diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 4faef9860f..cb06540b0f 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -114,6 +114,10 @@ impl ConnectingState { ErrorStateCause::SetFirewallPolicyError(error), ) } else { + // HACK: On Android, DNS is part of creating the VPN interface, this call + // ensures that the vpn_config is prepared with correct DNS servers in case they + // previously set to something else, e.g. in the case of blocking. This call + // should probably be part of start_tunnel call. #[cfg(target_os = "android")] { shared_values.prepare_tun_config(false); diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs index 137e69c1de..4907d34c97 100644 --- a/talpid-routing/src/unix/android.rs +++ b/talpid-routing/src/unix/android.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::ops::{ControlFlow, Not}; +use std::ops::ControlFlow; use std::sync::Mutex; use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; @@ -51,7 +51,7 @@ pub struct RouteManagerImpl { last_state: Option<NetworkState>, /// Clients waiting on response to [RouteManagerCommand::WaitForRoutes]. - waiting_for_routes: Vec<oneshot::Sender<()>>, + waiting_for_routes: Vec<(oneshot::Sender<()>, Vec<Route>)>, } impl RouteManagerImpl { @@ -64,7 +64,7 @@ impl RouteManagerImpl { // Try to poll for the current network state at startup. // This will most likely be null, but it covers the edge case where a NetworkState - // update has been emitted before we anyone starts to listen for route updates some + // update has been emitted before anyone starts to listen for route updates some // time in the future (when connecting). let last_state = match current_network_state(android_context) { Ok(initial_state) => initial_state, @@ -105,12 +105,19 @@ impl RouteManagerImpl { // update the last known NetworkState self.last_state = network_state; - if has_routes(self.last_state.as_ref()) { - // notify waiting clients that routes exist - for client in self.waiting_for_routes.drain(..) { - let _ = client.send(()); - } - } + // notify waiting clients that routes exist + self.waiting_for_routes = self + .waiting_for_routes + .into_iter() + .filter_map(|(client, expected_routes)| { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { + let _ = client.send(()); + None + } else { + Some((client, expected_routes)) + } + }) + .collect(); } } } @@ -126,31 +133,42 @@ impl RouteManagerImpl { let _ = tx.send(()); return ControlFlow::Break(()); } - RouteManagerCommand::WaitForRoutes(response_tx) => { + RouteManagerCommand::WaitForRoutes(response_tx, expected_routes) => { // check if routes have already been configured on the Android system. // otherwise, register a listener for network state changes. // routes may come in at any moment in the future. - if has_routes(self.last_state.as_ref()) { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { let _ = response_tx.send(()); } else { - self.waiting_for_routes.push(response_tx); + self.waiting_for_routes.push((response_tx, expected_routes)); } } + RouteManagerCommand::ClearRouteCache(tx) => { + self.clear_route_cache(); + let _ = tx.send(()); + } } ControlFlow::Continue(()) } + + fn clear_route_cache(&mut self) { + self.last_state = None; + } } -/// Check whether the [NetworkState] contains any routes. +/// Check whether the [NetworkState] contains expected routes. /// -/// Since we are the ones telling Android what routes to set, we make the assumption that: -/// If any routes exist whatsoever, they are the the routes we specified. -fn has_routes(state: Option<&NetworkState>) -> bool { +/// Matches the routes reported from Android and checks if all the routes we expect to be there is +/// present. +fn has_routes(state: Option<&NetworkState>, expected_routes: Vec<Route>) -> bool { let Some(network_state) = state else { return false; }; - configured_routes(network_state).is_empty().not() + + let routes = configured_routes(network_state); + + routes.is_superset(&HashSet::from_iter(expected_routes)) } fn configured_routes(state: &NetworkState) -> HashSet<Route> { diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 5aedc9626e..551f572b9f 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -37,6 +37,8 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(target_os = "android")] +use crate::Route; #[cfg(any(target_os = "macos", target_os = "linux"))] pub use imp::Error as PlatformError; @@ -103,7 +105,8 @@ pub(crate) enum RouteManagerCommand { #[cfg(target_os = "android")] #[derive(Debug)] pub(crate) enum RouteManagerCommand { - WaitForRoutes(oneshot::Sender<()>), + ClearRouteCache(oneshot::Sender<()>), + WaitForRoutes(oneshot::Sender<()>, Vec<Route>), Shutdown(oneshot::Sender<()>), } @@ -215,7 +218,7 @@ impl RouteManagerHandle { /// This function is guaranteed to *not* wait for longer than 2 seconds. /// Please, see the implementation of this function for further details. #[cfg(target_os = "android")] - pub async fn wait_for_routes(&self) -> Result<(), Error> { + pub async fn wait_for_routes(&self, expect_routes: Vec<Route>) -> Result<(), Error> { use std::time::Duration; use tokio::time::timeout; /// Maximum time to wait for routes to come up. The expected mean time is low (~200 ms), but @@ -224,7 +227,7 @@ impl RouteManagerHandle { let (result_tx, result_rx) = oneshot::channel(); self.tx - .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx)) + .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx, expect_routes)) .map_err(|_| Error::RouteManagerDown)?; timeout(WAIT_FOR_ROUTES_TIMEOUT, result_rx) @@ -247,6 +250,17 @@ impl RouteManagerHandle { Ok(()) } + /// (Android) Clear the cached routes + #[cfg(target_os = "android")] + pub async fn clear_route_cache(&self) -> Result<(), Error> { + let (result_tx, result_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::ClearRouteCache(result_tx)) + .map_err(|_| Error::RouteManagerDown)?; + let _ = result_rx.await; + Ok(()) + } + /// Listen for non-tunnel default route changes. #[cfg(target_os = "macos")] pub async fn default_route_listener( diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index f285b4a64c..d4adc6ba36 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -16,6 +16,7 @@ use std::{ os::unix::io::{AsRawFd, RawFd}, sync::Arc, }; +use talpid_routing::Route; use talpid_types::net::{ALLOWED_LAN_MULTICAST_NETS, ALLOWED_LAN_NETS}; use talpid_types::{android::AndroidContext, ErrorExt}; @@ -59,12 +60,15 @@ pub enum Error { OtherAlwaysOnApp { app_name: String }, } +type TunnelCache = Option<(VpnServiceConfig, RawFd)>; + /// Factory of tunnel devices on Android. pub struct AndroidTunProvider { jvm: Arc<JavaVM>, class: GlobalRef, object: GlobalRef, config: TunConfig, + current_tunnel: TunnelCache, } impl AndroidTunProvider { @@ -83,6 +87,7 @@ impl AndroidTunProvider { class: talpid_vpn_service_class, object: context.vpn_service, config, + current_tunnel: None, } } @@ -92,40 +97,59 @@ impl AndroidTunProvider { &mut self.config } - /// Open a tunnel with the current configuration. + /// Returns an open tunnel with the current configuration, if a tunnel already exists with the + /// corresponding VpnTunConfig it returns a cached copy. pub fn open_tun(&mut self) -> Result<VpnServiceTun, Error> { - self.open_tun_inner("openTun") - } + let config = VpnServiceConfig::new(self.config.clone()); - /// Open a tunnel with the current configuration. - /// Force recreation even if the tunnel config hasn't changed. - pub fn open_tun_forced(&mut self) -> Result<VpnServiceTun, Error> { - self.open_tun_inner("openTunForced") - } + let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } + .map_err(Error::CloneJavaVm)?; - /// Open a tunnel with the current configuration. - fn open_tun_inner(&mut self, get_tun_func_name: &'static str) -> Result<VpnServiceTun, Error> { - let tun_fd = self.open_tun_fd(get_tun_func_name)?; + // If we are recreating the same tunnel we return the same file descriptor to avoid calling + // open_tun in android since it may cause leaks. + if let Some((vpn_service_config, raw_fd)) = &self.current_tunnel { + if vpn_service_config == &config { + return Ok(VpnServiceTun { + tunnel: *raw_fd, + is_new_tunnel: false, + jvm, + class: self.class.clone(), + object: self.object.clone(), + }); + } + } + + self.open_tun_forced() + } + /// Returns an open tunnel with the current configuration + pub fn open_tun_forced(&mut self) -> Result<VpnServiceTun, Error> { + let config = VpnServiceConfig::new(self.config.clone()); let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } .map_err(Error::CloneJavaVm)?; + let raw_fd = self.open_tun_fd(config.clone())?; + + // Cache the current tunnel + self.current_tunnel = Some((config, raw_fd)); + Ok(VpnServiceTun { - tunnel: tun_fd, + tunnel: raw_fd, + is_new_tunnel: true, jvm, class: self.class.clone(), object: self.object.clone(), }) } - fn open_tun_fd(&self, get_tun_func_name: &'static str) -> Result<RawFd, Error> { - let config = VpnServiceConfig::new(self.config.clone()); + // Opens a tunnel in Android with the provided VpnServiceConfig. + fn open_tun_fd(&mut self, config: VpnServiceConfig) -> Result<RawFd, Error> { + let method_name = "openTun"; let env = self.env()?; let java_config = config.into_java(&env); - let result = self.call_method( - get_tun_func_name, + method_name, "(Lnet/mullvad/talpid/model/TunConfig;)Lnet/mullvad/talpid/model/CreateTunResult;", JavaType::Object("net/mullvad/talpid/model/CreateTunResult".to_owned()), &[JValue::Object(java_config.as_obj())], @@ -134,7 +158,7 @@ impl AndroidTunProvider { match result { JValue::Object(result) => CreateTunResult::from_java(&env, result).into(), value => Err(Error::InvalidMethodResult( - get_tun_func_name, + method_name, format!("{:?}", value), )), } @@ -153,11 +177,14 @@ impl AndroidTunProvider { Err(error) => Some(error), }; - if let Some(error) = error { - log::error!( + match error { + Some(error) => log::error!( "{}", error.display_chain_with_msg("Failed to close the tunnel") - ); + ), + + // Remove the cache of config + None => self.current_tunnel = None, } } @@ -188,6 +215,14 @@ impl AndroidTunProvider { } } + pub fn real_routes(&self) -> Vec<Route> { + self.config + .real_routes() + .into_iter() + .map(Route::new) + .collect() + } + fn call_method( &self, name: &'static str, @@ -221,7 +256,7 @@ impl AndroidTunProvider { /// Configuration to use for VpnService #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(class_name = "net.mullvad.talpid.model.TunConfig")] -struct VpnServiceConfig { +pub struct VpnServiceConfig { /// IP addresses for the tunnel interface. pub addresses: Vec<IpAddr>, @@ -318,7 +353,7 @@ impl VpnServiceConfig { #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(package = "net.mullvad.talpid.model")] -struct InetNetwork { +pub struct InetNetwork { address: IpAddr, prefix: i16, } @@ -332,9 +367,16 @@ impl From<IpNetwork> for InetNetwork { } } +impl From<&InetNetwork> for IpNetwork { + fn from(inet_network: &InetNetwork) -> Self { + IpNetwork::new(inet_network.address, inet_network.prefix as u8).unwrap() + } +} + /// Handle to a tunnel device on Android. pub struct VpnServiceTun { tunnel: RawFd, + pub is_new_tunnel: bool, jvm: JavaVM, class: GlobalRef, object: GlobalRef, diff --git a/talpid-tunnel/src/tun_provider/mod.rs b/talpid-tunnel/src/tun_provider/mod.rs index 1bf4e1abb4..d49b76147b 100644 --- a/talpid-tunnel/src/tun_provider/mod.rs +++ b/talpid-tunnel/src/tun_provider/mod.rs @@ -1,3 +1,5 @@ +#[cfg(target_os = "android")] +use crate::tun_provider::imp::VpnServiceConfig; use cfg_if::cfg_if; use ipnetwork::IpNetwork; use std::{ @@ -73,6 +75,16 @@ impl TunConfig { } servers } + + /// Routes to configure for the tunnel. + #[cfg(target_os = "android")] + pub fn real_routes(&self) -> Vec<IpNetwork> { + VpnServiceConfig::new(self.clone()) + .routes + .iter() + .map(IpNetwork::from) + .collect() + } } /// Return a tunnel configuration that routes all traffic inside the tunnel. diff --git a/talpid-wireguard/src/connectivity/check.rs b/talpid-wireguard/src/connectivity/check.rs index a5ac9cbeef..9c029948bf 100644 --- a/talpid-wireguard/src/connectivity/check.rs +++ b/talpid-wireguard/src/connectivity/check.rs @@ -184,7 +184,17 @@ impl Check { { return Ok(true); } - tokio::time::sleep(Duration::from_millis(20)).await; + // Calling get_stats has an unwanted effect of possibly causing segmentation fault, + // stacktrace hints towards Garbage Collector failing. The cause has yet not been + // determined, it could be because some dangling pointer, bug inside WG-go or + // something else. So for now we avoid spamming get_config too much since it lowers + // the risk of crash happening. + // + // The value was previously set to 20 ms, depending on when we called + // establish_connectivity, this caused the crash to reliably occur. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) + tokio::time::sleep(Duration::from_millis(100)).await; } }; diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index fe1a848e9a..96efed3431 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -28,6 +28,8 @@ use talpid_tunnel::{ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, }; +#[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; #[cfg(daita)] use talpid_tunnel_config_client::DaitaSettings; use talpid_types::{ @@ -434,6 +436,7 @@ impl WireguardMonitor { &config, log_path, args.tun_provider.clone(), + args.route_manager, // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. @@ -465,13 +468,6 @@ impl WireguardMonitor { .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; - // Wait for routes to come up - args.route_manager - .wait_for_routes() - .await - .map_err(Error::SetupRoutingError) - .map_err(CloseMsg::SetupError)?; - if should_negotiate_ephemeral_peer { let ephemeral_obfs_sender = close_obfs_sender.clone(); @@ -743,12 +739,13 @@ impl WireguardMonitor { config: &Config, log_path: Option<&Path>, #[cfg(unix)] tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(target_os = "android")] route_manager: RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(target_os = "android")] gateway_only: bool, #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, ) -> Result<WgGoTunnel> { - #[cfg(unix)] + #[cfg(all(unix, not(target_os = "android")))] let routes = config .get_tunnel_destinations() .flat_map(Self::replace_default_prefixes); @@ -780,7 +777,7 @@ impl WireguardMonitor { exit_peer, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -791,7 +788,7 @@ impl WireguardMonitor { &config, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -969,6 +966,7 @@ impl WireguardMonitor { } /// Replace default (0-prefix) routes with more specific routes. + #[cfg(not(target_os = "android"))] fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> { #[cfg(windows)] if network.prefix() == 0 { diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 813490899a..73b22a452f 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,5 +1,7 @@ #[cfg(target_os = "android")] use super::config; +#[cfg(target_os = "android")] +use super::Error; use super::{ stats::{Stats, StatsMap}, Config, Tunnel, TunnelError, @@ -9,7 +11,7 @@ use crate::config::MULLVAD_INTERFACE_NAME; #[cfg(target_os = "android")] use crate::connectivity; use crate::logging::{clean_up_logging, initialize_logging}; -#[cfg(unix)] +#[cfg(all(unix, not(target_os = "android")))] use ipnetwork::IpNetwork; #[cfg(daita)] use std::ffi::CString; @@ -23,6 +25,8 @@ use std::{ pin::Pin, }; #[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; +#[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; #[cfg(not(target_os = "windows"))] use talpid_tunnel::tun_provider::{Tun, TunProvider}; @@ -115,7 +119,7 @@ impl WgGoTunnel { let log_path = state._logging_context.path.clone(); let cancel_receiver = state.cancel_receiver.clone(); let tun_provider = Arc::clone(&state.tun_provider); - let routes = config.get_tunnel_destinations(); + let route_manager = state.route_manager.clone(); match self { WgGoTunnel::Multihop(state) if !config.is_multihop() => { @@ -124,7 +128,7 @@ impl WgGoTunnel { config, log_path.as_deref(), tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -136,22 +140,19 @@ impl WgGoTunnel { &config.exit_peer.clone().unwrap().clone(), log_path.as_deref(), tun_provider, - routes, + route_manager, cancel_receiver, ) .await } WgGoTunnel::Singlehop(mut state) => { state.set_config(config.clone())?; - // HACK: Check if the tunnel is working by sending a ping in the tunnel. let new_state = WgGoTunnel::Singlehop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } WgGoTunnel::Multihop(mut state) => { state.set_config(config.clone())?; let new_state = WgGoTunnel::Multihop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } } @@ -173,6 +174,8 @@ pub(crate) struct WgGoTunnelState { _logging_context: LoggingContext, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(target_os = "android")] + route_manager: RouteManagerHandle, #[cfg(daita)] config: Config, /// This is used to cancel the connectivity checks that occur when toggling multihop @@ -348,7 +351,7 @@ impl WgGoTunnel { fn get_tunnel( tun_provider: Arc<Mutex<TunProvider>>, config: &Config, - routes: impl Iterator<Item = IpNetwork>, + #[cfg(not(target_os = "android"))] routes: impl Iterator<Item = IpNetwork>, ) -> Result<(Tun, RawFd)> { let mut last_error = None; let mut tun_provider = tun_provider.lock().unwrap(); @@ -362,12 +365,17 @@ impl WgGoTunnel { tun_config.ipv4_gateway = config.ipv4_gateway; tun_config.ipv6_gateway = config.ipv6_gateway; tun_config.mtu = config.mtu; - tun_config.routes = if cfg!(target_os = "android") { - // Route everything into the tunnel and have wireguard-go act as a firewall. - vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] - } else { - routes.collect() - }; + + // Route everything into the tunnel and have wireguard-go act as a firewall. + #[cfg(not(target_os = "android"))] + { + tun_config.routes = routes.collect(); + } + + #[cfg(target_os = "android")] + { + tun_config.routes = vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]; + } for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { let tunnel_device = tun_provider @@ -395,11 +403,16 @@ impl WgGoTunnel { config: &Config, log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, - routes: impl Iterator<Item = IpNetwork>, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result<Self> { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + route_manager + .clear_route_cache() + .await + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config)?; + let is_new_tunnel = tunnel_device.is_new_tunnel; let interface_name: String = tunnel_device .interface_name() @@ -427,12 +440,21 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver, }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. For other platforms + // this is done in the tunnel_fut in WireguardMonitor.start, however that caused it to crash + // in GO on Android. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -443,11 +465,16 @@ impl WgGoTunnel { exit_peer: &PeerConfig, log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, - routes: impl Iterator<Item = IpNetwork>, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result<Self> { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + route_manager + .clear_route_cache() + .await + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config)?; + let is_new_tunnel = tunnel_device.is_new_tunnel; let interface_name: String = tunnel_device .interface_name() @@ -491,12 +518,21 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver: cancel_receiver.clone(), }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. For other platforms + // this is done in the tunnel_fut in WireguardMonitor.start, however that caused it to crash + // in GO on Android. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -517,6 +553,22 @@ impl WgGoTunnel { /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. + async fn wait_for_routes(&self) -> Result<()> { + let state = self.as_state(); + + let expected_routes = state.tun_provider.lock().unwrap().real_routes(); + + // Wait for routes to come up + state + .route_manager + .clone() + .wait_for_routes(expected_routes) + .await + .map_err(Error::SetupRoutingError) + .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; + + Ok(()) + } async fn ensure_tunnel_is_running(&self) -> Result<()> { let state = self.as_state(); let addr = state.config.ipv4_gateway; |
