diff options
| -rw-r--r-- | android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt | 3 | ||||
| -rw-r--r-- | android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt | 17 | ||||
| -rw-r--r-- | android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt | 29 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 4 | ||||
| -rw-r--r-- | talpid-routing/src/unix/android.rs | 52 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 20 | ||||
| -rw-r--r-- | talpid-tunnel/src/tun_provider/android/mod.rs | 86 | ||||
| -rw-r--r-- | talpid-tunnel/src/tun_provider/mod.rs | 12 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/check.rs | 12 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 18 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 96 |
11 files changed, 246 insertions, 103 deletions
diff --git a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt index 27e7658a11..e3faaf3884 100644 --- a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt @@ -34,6 +34,9 @@ class TalpidVpnServiceFallbackDnsTest { every { talpidVpnService.prepareVpnSafe() } returns Prepared.right() builderMockk = mockk<VpnService.Builder>() + every { talpidVpnService getProperty "connectivityListener" } returns + mockk<ConnectivityListener>(relaxed = true) + mockkConstructor(VpnService.Builder::class) every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index 9c82d62251..b702a39a6e 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -10,6 +10,7 @@ import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.map @@ -18,7 +19,7 @@ import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.launch import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent import net.mullvad.talpid.util.RawNetworkState @@ -31,29 +32,30 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) val isConnected get() = _isConnected.value - private lateinit var _currentNetworkState: StateFlow<NetworkState?> + private val _mutableNetworkState = MutableStateFlow<NetworkState?>(null) private val resetNetworkState: Channel<Unit> = Channel() // Used by JNI val currentDefaultNetworkState: NetworkState? - get() = _currentNetworkState.value + get() = _mutableNetworkState.value // Used by JNI val currentDnsServers: ArrayList<InetAddress> - get() = _currentNetworkState.value?.dnsServers ?: ArrayList() + get() = _mutableNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { // Consider implementing retry logic for the flows below, because registering a listener on // the default network may fail if the network on Android 11 // https://issuetracker.google.com/issues/175055271?pli=1 - _currentNetworkState = + scope.launch { merge( connectivityManager.defaultRawNetworkStateFlow(), resetNetworkState.receiveAsFlow().map { null }, ) .map { it?.toNetworkState() } .onEach { notifyDefaultNetworkChange(it) } - .stateIn(scope, SharingStarted.Eagerly, null) + .collect(_mutableNetworkState) + } _isConnected = hasInternetCapability() @@ -70,8 +72,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) * know the last known values not to be correct anymore. */ fun invalidateNetworkStateCache() { - // TODO remove runBlocking - runBlocking { resetNetworkState.send(Unit) } + _mutableNetworkState.value = null } private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> = diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index a143df6132..1457ff35f4 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -57,34 +57,22 @@ open class TalpidVpnService : LifecycleVpnService() { // Used by JNI fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { - val tunStatus = activeTunStatus - - if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - tunStatus - } else { - openTunImpl(config) + createTun(config).merge().also { + currentTunConfig = config + activeTunStatus = it } } // Used by JNI - fun openTunForced(config: TunConfig): CreateTunResult = - synchronized(this) { openTunImpl(config) } - - // Used by JNI - fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + fun closeTun(): Unit = + synchronized(this) { + connectivityListener.invalidateNetworkStateCache() + activeTunStatus = null + } // Used by JNI fun bypass(socket: Int): Boolean = protect(socket) - private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config).merge() - - currentTunConfig = config - activeTunStatus = newTunStatus - - return newTunStatus - } - private fun createTun( config: TunConfig ): Either<CreateTunResult.Error, CreateTunResult.Success> = either { @@ -123,6 +111,7 @@ open class TalpidVpnService : LifecycleVpnService() { builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } + connectivityListener.invalidateNetworkStateCache() val vpnInterfaceFd = builder .establishSafe() diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 4faef9860f..cb06540b0f 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -114,6 +114,10 @@ impl ConnectingState { ErrorStateCause::SetFirewallPolicyError(error), ) } else { + // HACK: On Android, DNS is part of creating the VPN interface, this call + // ensures that the vpn_config is prepared with correct DNS servers in case they + // previously set to something else, e.g. in the case of blocking. This call + // should probably be part of start_tunnel call. #[cfg(target_os = "android")] { shared_values.prepare_tun_config(false); diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs index 137e69c1de..4907d34c97 100644 --- a/talpid-routing/src/unix/android.rs +++ b/talpid-routing/src/unix/android.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::ops::{ControlFlow, Not}; +use std::ops::ControlFlow; use std::sync::Mutex; use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; @@ -51,7 +51,7 @@ pub struct RouteManagerImpl { last_state: Option<NetworkState>, /// Clients waiting on response to [RouteManagerCommand::WaitForRoutes]. - waiting_for_routes: Vec<oneshot::Sender<()>>, + waiting_for_routes: Vec<(oneshot::Sender<()>, Vec<Route>)>, } impl RouteManagerImpl { @@ -64,7 +64,7 @@ impl RouteManagerImpl { // Try to poll for the current network state at startup. // This will most likely be null, but it covers the edge case where a NetworkState - // update has been emitted before we anyone starts to listen for route updates some + // update has been emitted before anyone starts to listen for route updates some // time in the future (when connecting). let last_state = match current_network_state(android_context) { Ok(initial_state) => initial_state, @@ -105,12 +105,19 @@ impl RouteManagerImpl { // update the last known NetworkState self.last_state = network_state; - if has_routes(self.last_state.as_ref()) { - // notify waiting clients that routes exist - for client in self.waiting_for_routes.drain(..) { - let _ = client.send(()); - } - } + // notify waiting clients that routes exist + self.waiting_for_routes = self + .waiting_for_routes + .into_iter() + .filter_map(|(client, expected_routes)| { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { + let _ = client.send(()); + None + } else { + Some((client, expected_routes)) + } + }) + .collect(); } } } @@ -126,31 +133,42 @@ impl RouteManagerImpl { let _ = tx.send(()); return ControlFlow::Break(()); } - RouteManagerCommand::WaitForRoutes(response_tx) => { + RouteManagerCommand::WaitForRoutes(response_tx, expected_routes) => { // check if routes have already been configured on the Android system. // otherwise, register a listener for network state changes. // routes may come in at any moment in the future. - if has_routes(self.last_state.as_ref()) { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { let _ = response_tx.send(()); } else { - self.waiting_for_routes.push(response_tx); + self.waiting_for_routes.push((response_tx, expected_routes)); } } + RouteManagerCommand::ClearRouteCache(tx) => { + self.clear_route_cache(); + let _ = tx.send(()); + } } ControlFlow::Continue(()) } + + fn clear_route_cache(&mut self) { + self.last_state = None; + } } -/// Check whether the [NetworkState] contains any routes. +/// Check whether the [NetworkState] contains expected routes. /// -/// Since we are the ones telling Android what routes to set, we make the assumption that: -/// If any routes exist whatsoever, they are the the routes we specified. -fn has_routes(state: Option<&NetworkState>) -> bool { +/// Matches the routes reported from Android and checks if all the routes we expect to be there is +/// present. +fn has_routes(state: Option<&NetworkState>, expected_routes: Vec<Route>) -> bool { let Some(network_state) = state else { return false; }; - configured_routes(network_state).is_empty().not() + + let routes = configured_routes(network_state); + + routes.is_superset(&HashSet::from_iter(expected_routes)) } fn configured_routes(state: &NetworkState) -> HashSet<Route> { diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 5aedc9626e..551f572b9f 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -37,6 +37,8 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(target_os = "android")] +use crate::Route; #[cfg(any(target_os = "macos", target_os = "linux"))] pub use imp::Error as PlatformError; @@ -103,7 +105,8 @@ pub(crate) enum RouteManagerCommand { #[cfg(target_os = "android")] #[derive(Debug)] pub(crate) enum RouteManagerCommand { - WaitForRoutes(oneshot::Sender<()>), + ClearRouteCache(oneshot::Sender<()>), + WaitForRoutes(oneshot::Sender<()>, Vec<Route>), Shutdown(oneshot::Sender<()>), } @@ -215,7 +218,7 @@ impl RouteManagerHandle { /// This function is guaranteed to *not* wait for longer than 2 seconds. /// Please, see the implementation of this function for further details. #[cfg(target_os = "android")] - pub async fn wait_for_routes(&self) -> Result<(), Error> { + pub async fn wait_for_routes(&self, expect_routes: Vec<Route>) -> Result<(), Error> { use std::time::Duration; use tokio::time::timeout; /// Maximum time to wait for routes to come up. The expected mean time is low (~200 ms), but @@ -224,7 +227,7 @@ impl RouteManagerHandle { let (result_tx, result_rx) = oneshot::channel(); self.tx - .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx)) + .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx, expect_routes)) .map_err(|_| Error::RouteManagerDown)?; timeout(WAIT_FOR_ROUTES_TIMEOUT, result_rx) @@ -247,6 +250,17 @@ impl RouteManagerHandle { Ok(()) } + /// (Android) Clear the cached routes + #[cfg(target_os = "android")] + pub async fn clear_route_cache(&self) -> Result<(), Error> { + let (result_tx, result_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::ClearRouteCache(result_tx)) + .map_err(|_| Error::RouteManagerDown)?; + let _ = result_rx.await; + Ok(()) + } + /// Listen for non-tunnel default route changes. #[cfg(target_os = "macos")] pub async fn default_route_listener( diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index f285b4a64c..d4adc6ba36 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -16,6 +16,7 @@ use std::{ os::unix::io::{AsRawFd, RawFd}, sync::Arc, }; +use talpid_routing::Route; use talpid_types::net::{ALLOWED_LAN_MULTICAST_NETS, ALLOWED_LAN_NETS}; use talpid_types::{android::AndroidContext, ErrorExt}; @@ -59,12 +60,15 @@ pub enum Error { OtherAlwaysOnApp { app_name: String }, } +type TunnelCache = Option<(VpnServiceConfig, RawFd)>; + /// Factory of tunnel devices on Android. pub struct AndroidTunProvider { jvm: Arc<JavaVM>, class: GlobalRef, object: GlobalRef, config: TunConfig, + current_tunnel: TunnelCache, } impl AndroidTunProvider { @@ -83,6 +87,7 @@ impl AndroidTunProvider { class: talpid_vpn_service_class, object: context.vpn_service, config, + current_tunnel: None, } } @@ -92,40 +97,59 @@ impl AndroidTunProvider { &mut self.config } - /// Open a tunnel with the current configuration. + /// Returns an open tunnel with the current configuration, if a tunnel already exists with the + /// corresponding VpnTunConfig it returns a cached copy. pub fn open_tun(&mut self) -> Result<VpnServiceTun, Error> { - self.open_tun_inner("openTun") - } + let config = VpnServiceConfig::new(self.config.clone()); - /// Open a tunnel with the current configuration. - /// Force recreation even if the tunnel config hasn't changed. - pub fn open_tun_forced(&mut self) -> Result<VpnServiceTun, Error> { - self.open_tun_inner("openTunForced") - } + let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } + .map_err(Error::CloneJavaVm)?; - /// Open a tunnel with the current configuration. - fn open_tun_inner(&mut self, get_tun_func_name: &'static str) -> Result<VpnServiceTun, Error> { - let tun_fd = self.open_tun_fd(get_tun_func_name)?; + // If we are recreating the same tunnel we return the same file descriptor to avoid calling + // open_tun in android since it may cause leaks. + if let Some((vpn_service_config, raw_fd)) = &self.current_tunnel { + if vpn_service_config == &config { + return Ok(VpnServiceTun { + tunnel: *raw_fd, + is_new_tunnel: false, + jvm, + class: self.class.clone(), + object: self.object.clone(), + }); + } + } + + self.open_tun_forced() + } + /// Returns an open tunnel with the current configuration + pub fn open_tun_forced(&mut self) -> Result<VpnServiceTun, Error> { + let config = VpnServiceConfig::new(self.config.clone()); let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } .map_err(Error::CloneJavaVm)?; + let raw_fd = self.open_tun_fd(config.clone())?; + + // Cache the current tunnel + self.current_tunnel = Some((config, raw_fd)); + Ok(VpnServiceTun { - tunnel: tun_fd, + tunnel: raw_fd, + is_new_tunnel: true, jvm, class: self.class.clone(), object: self.object.clone(), }) } - fn open_tun_fd(&self, get_tun_func_name: &'static str) -> Result<RawFd, Error> { - let config = VpnServiceConfig::new(self.config.clone()); + // Opens a tunnel in Android with the provided VpnServiceConfig. + fn open_tun_fd(&mut self, config: VpnServiceConfig) -> Result<RawFd, Error> { + let method_name = "openTun"; let env = self.env()?; let java_config = config.into_java(&env); - let result = self.call_method( - get_tun_func_name, + method_name, "(Lnet/mullvad/talpid/model/TunConfig;)Lnet/mullvad/talpid/model/CreateTunResult;", JavaType::Object("net/mullvad/talpid/model/CreateTunResult".to_owned()), &[JValue::Object(java_config.as_obj())], @@ -134,7 +158,7 @@ impl AndroidTunProvider { match result { JValue::Object(result) => CreateTunResult::from_java(&env, result).into(), value => Err(Error::InvalidMethodResult( - get_tun_func_name, + method_name, format!("{:?}", value), )), } @@ -153,11 +177,14 @@ impl AndroidTunProvider { Err(error) => Some(error), }; - if let Some(error) = error { - log::error!( + match error { + Some(error) => log::error!( "{}", error.display_chain_with_msg("Failed to close the tunnel") - ); + ), + + // Remove the cache of config + None => self.current_tunnel = None, } } @@ -188,6 +215,14 @@ impl AndroidTunProvider { } } + pub fn real_routes(&self) -> Vec<Route> { + self.config + .real_routes() + .into_iter() + .map(Route::new) + .collect() + } + fn call_method( &self, name: &'static str, @@ -221,7 +256,7 @@ impl AndroidTunProvider { /// Configuration to use for VpnService #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(class_name = "net.mullvad.talpid.model.TunConfig")] -struct VpnServiceConfig { +pub struct VpnServiceConfig { /// IP addresses for the tunnel interface. pub addresses: Vec<IpAddr>, @@ -318,7 +353,7 @@ impl VpnServiceConfig { #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(package = "net.mullvad.talpid.model")] -struct InetNetwork { +pub struct InetNetwork { address: IpAddr, prefix: i16, } @@ -332,9 +367,16 @@ impl From<IpNetwork> for InetNetwork { } } +impl From<&InetNetwork> for IpNetwork { + fn from(inet_network: &InetNetwork) -> Self { + IpNetwork::new(inet_network.address, inet_network.prefix as u8).unwrap() + } +} + /// Handle to a tunnel device on Android. pub struct VpnServiceTun { tunnel: RawFd, + pub is_new_tunnel: bool, jvm: JavaVM, class: GlobalRef, object: GlobalRef, diff --git a/talpid-tunnel/src/tun_provider/mod.rs b/talpid-tunnel/src/tun_provider/mod.rs index 1bf4e1abb4..d49b76147b 100644 --- a/talpid-tunnel/src/tun_provider/mod.rs +++ b/talpid-tunnel/src/tun_provider/mod.rs @@ -1,3 +1,5 @@ +#[cfg(target_os = "android")] +use crate::tun_provider::imp::VpnServiceConfig; use cfg_if::cfg_if; use ipnetwork::IpNetwork; use std::{ @@ -73,6 +75,16 @@ impl TunConfig { } servers } + + /// Routes to configure for the tunnel. + #[cfg(target_os = "android")] + pub fn real_routes(&self) -> Vec<IpNetwork> { + VpnServiceConfig::new(self.clone()) + .routes + .iter() + .map(IpNetwork::from) + .collect() + } } /// Return a tunnel configuration that routes all traffic inside the tunnel. diff --git a/talpid-wireguard/src/connectivity/check.rs b/talpid-wireguard/src/connectivity/check.rs index a5ac9cbeef..9c029948bf 100644 --- a/talpid-wireguard/src/connectivity/check.rs +++ b/talpid-wireguard/src/connectivity/check.rs @@ -184,7 +184,17 @@ impl Check { { return Ok(true); } - tokio::time::sleep(Duration::from_millis(20)).await; + // Calling get_stats has an unwanted effect of possibly causing segmentation fault, + // stacktrace hints towards Garbage Collector failing. The cause has yet not been + // determined, it could be because some dangling pointer, bug inside WG-go or + // something else. So for now we avoid spamming get_config too much since it lowers + // the risk of crash happening. + // + // The value was previously set to 20 ms, depending on when we called + // establish_connectivity, this caused the crash to reliably occur. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) + tokio::time::sleep(Duration::from_millis(100)).await; } }; diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index fe1a848e9a..96efed3431 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -28,6 +28,8 @@ use talpid_tunnel::{ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, }; +#[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; #[cfg(daita)] use talpid_tunnel_config_client::DaitaSettings; use talpid_types::{ @@ -434,6 +436,7 @@ impl WireguardMonitor { &config, log_path, args.tun_provider.clone(), + args.route_manager, // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. @@ -465,13 +468,6 @@ impl WireguardMonitor { .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; - // Wait for routes to come up - args.route_manager - .wait_for_routes() - .await - .map_err(Error::SetupRoutingError) - .map_err(CloseMsg::SetupError)?; - if should_negotiate_ephemeral_peer { let ephemeral_obfs_sender = close_obfs_sender.clone(); @@ -743,12 +739,13 @@ impl WireguardMonitor { config: &Config, log_path: Option<&Path>, #[cfg(unix)] tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(target_os = "android")] route_manager: RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(target_os = "android")] gateway_only: bool, #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, ) -> Result<WgGoTunnel> { - #[cfg(unix)] + #[cfg(all(unix, not(target_os = "android")))] let routes = config .get_tunnel_destinations() .flat_map(Self::replace_default_prefixes); @@ -780,7 +777,7 @@ impl WireguardMonitor { exit_peer, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -791,7 +788,7 @@ impl WireguardMonitor { &config, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -969,6 +966,7 @@ impl WireguardMonitor { } /// Replace default (0-prefix) routes with more specific routes. + #[cfg(not(target_os = "android"))] fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> { #[cfg(windows)] if network.prefix() == 0 { diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 813490899a..73b22a452f 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,5 +1,7 @@ #[cfg(target_os = "android")] use super::config; +#[cfg(target_os = "android")] +use super::Error; use super::{ stats::{Stats, StatsMap}, Config, Tunnel, TunnelError, @@ -9,7 +11,7 @@ use crate::config::MULLVAD_INTERFACE_NAME; #[cfg(target_os = "android")] use crate::connectivity; use crate::logging::{clean_up_logging, initialize_logging}; -#[cfg(unix)] +#[cfg(all(unix, not(target_os = "android")))] use ipnetwork::IpNetwork; #[cfg(daita)] use std::ffi::CString; @@ -23,6 +25,8 @@ use std::{ pin::Pin, }; #[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; +#[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; #[cfg(not(target_os = "windows"))] use talpid_tunnel::tun_provider::{Tun, TunProvider}; @@ -115,7 +119,7 @@ impl WgGoTunnel { let log_path = state._logging_context.path.clone(); let cancel_receiver = state.cancel_receiver.clone(); let tun_provider = Arc::clone(&state.tun_provider); - let routes = config.get_tunnel_destinations(); + let route_manager = state.route_manager.clone(); match self { WgGoTunnel::Multihop(state) if !config.is_multihop() => { @@ -124,7 +128,7 @@ impl WgGoTunnel { config, log_path.as_deref(), tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -136,22 +140,19 @@ impl WgGoTunnel { &config.exit_peer.clone().unwrap().clone(), log_path.as_deref(), tun_provider, - routes, + route_manager, cancel_receiver, ) .await } WgGoTunnel::Singlehop(mut state) => { state.set_config(config.clone())?; - // HACK: Check if the tunnel is working by sending a ping in the tunnel. let new_state = WgGoTunnel::Singlehop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } WgGoTunnel::Multihop(mut state) => { state.set_config(config.clone())?; let new_state = WgGoTunnel::Multihop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } } @@ -173,6 +174,8 @@ pub(crate) struct WgGoTunnelState { _logging_context: LoggingContext, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(target_os = "android")] + route_manager: RouteManagerHandle, #[cfg(daita)] config: Config, /// This is used to cancel the connectivity checks that occur when toggling multihop @@ -348,7 +351,7 @@ impl WgGoTunnel { fn get_tunnel( tun_provider: Arc<Mutex<TunProvider>>, config: &Config, - routes: impl Iterator<Item = IpNetwork>, + #[cfg(not(target_os = "android"))] routes: impl Iterator<Item = IpNetwork>, ) -> Result<(Tun, RawFd)> { let mut last_error = None; let mut tun_provider = tun_provider.lock().unwrap(); @@ -362,12 +365,17 @@ impl WgGoTunnel { tun_config.ipv4_gateway = config.ipv4_gateway; tun_config.ipv6_gateway = config.ipv6_gateway; tun_config.mtu = config.mtu; - tun_config.routes = if cfg!(target_os = "android") { - // Route everything into the tunnel and have wireguard-go act as a firewall. - vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] - } else { - routes.collect() - }; + + // Route everything into the tunnel and have wireguard-go act as a firewall. + #[cfg(not(target_os = "android"))] + { + tun_config.routes = routes.collect(); + } + + #[cfg(target_os = "android")] + { + tun_config.routes = vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]; + } for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { let tunnel_device = tun_provider @@ -395,11 +403,16 @@ impl WgGoTunnel { config: &Config, log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, - routes: impl Iterator<Item = IpNetwork>, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result<Self> { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + route_manager + .clear_route_cache() + .await + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config)?; + let is_new_tunnel = tunnel_device.is_new_tunnel; let interface_name: String = tunnel_device .interface_name() @@ -427,12 +440,21 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver, }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. For other platforms + // this is done in the tunnel_fut in WireguardMonitor.start, however that caused it to crash + // in GO on Android. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -443,11 +465,16 @@ impl WgGoTunnel { exit_peer: &PeerConfig, log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, - routes: impl Iterator<Item = IpNetwork>, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result<Self> { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + route_manager + .clear_route_cache() + .await + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config)?; + let is_new_tunnel = tunnel_device.is_new_tunnel; let interface_name: String = tunnel_device .interface_name() @@ -491,12 +518,21 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver: cancel_receiver.clone(), }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. For other platforms + // this is done in the tunnel_fut in WireguardMonitor.start, however that caused it to crash + // in GO on Android. + // + // Tracked by DROID-1825 (Investigate GO crash issue with runtime.GC()) tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -517,6 +553,22 @@ impl WgGoTunnel { /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. + async fn wait_for_routes(&self) -> Result<()> { + let state = self.as_state(); + + let expected_routes = state.tun_provider.lock().unwrap().real_routes(); + + // Wait for routes to come up + state + .route_manager + .clone() + .wait_for_routes(expected_routes) + .await + .map_err(Error::SetupRoutingError) + .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; + + Ok(()) + } async fn ensure_tunnel_is_running(&self) -> Result<()> { let state = self.as_state(); let addr = state.config.ipv4_gateway; |
