diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-02-08 10:39:11 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-02-14 11:00:00 +0100 |
| commit | 6453491bb01bf4ad31a5377a9f16358bf185bc3b (patch) | |
| tree | 8b24b86182165ae48e366f464b89b38c44262fe9 | |
| parent | 4100a6c46d51e7735a028a0dd4b6dff7c4201638 (diff) | |
| download | mullvadvpn-6453491bb01bf4ad31a5377a9f16358bf185bc3b.tar.xz mullvadvpn-6453491bb01bf4ad31a5377a9f16358bf185bc3b.zip | |
Set up tunnel monitor in separate thread
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 81 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 57 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 30 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 18 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 206 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnecting_state.rs | 25 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/error_state.rs | 7 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 25 |
10 files changed, 248 insertions, 223 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 0ed39d7bee..ca6665114f 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,9 +1,10 @@ use self::tun_provider::TunProvider; -use crate::{logging, routing::RouteManager}; +use crate::{logging, routing::RouteManagerHandle}; +use futures::channel::oneshot; use std::{ - io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, + sync::{Arc, Mutex}, }; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; @@ -104,9 +105,10 @@ impl TunnelMonitor { log_dir: &Option<PathBuf>, resource_dir: &Path, on_event: L, - tun_provider: &mut TunProvider, - route_manager: &mut RouteManager, + tun_provider: Arc<Mutex<TunProvider>>, + route_manager: RouteManagerHandle, retry_attempt: u32, + tunnel_close_rx: oneshot::Receiver<()>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -120,9 +122,15 @@ impl TunnelMonitor { match tunnel_parameters { #[cfg(not(target_os = "android"))] - TunnelParameters::OpenVpn(config) => { - Self::start_openvpn_tunnel(&config, log_file, resource_dir, on_event, route_manager) - } + TunnelParameters::OpenVpn(config) => Self::start_openvpn_tunnel( + &config, + log_file, + resource_dir, + on_event, + tunnel_close_rx, + #[cfg(target_os = "linux")] + route_manager, + ), #[cfg(target_os = "android")] TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform), @@ -135,6 +143,7 @@ impl TunnelMonitor { tun_provider, route_manager, retry_attempt, + tunnel_close_rx, ), } } @@ -165,9 +174,10 @@ impl TunnelMonitor { log: Option<PathBuf>, resource_dir: &Path, on_event: L, - tun_provider: &mut TunProvider, - route_manager: &mut RouteManager, + tun_provider: Arc<Mutex<TunProvider>>, + route_manager: RouteManagerHandle, retry_attempt: u32, + tunnel_close_rx: oneshot::Receiver<()>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -186,6 +196,7 @@ impl TunnelMonitor { tun_provider, route_manager, retry_attempt, + tunnel_close_rx, )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), @@ -198,7 +209,8 @@ impl TunnelMonitor { log: Option<PathBuf>, resource_dir: &Path, on_event: L, - route_manager: &mut RouteManager, + tunnel_close_rx: oneshot::Receiver<()>, + #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -206,8 +218,15 @@ impl TunnelMonitor { + Sync + 'static, { - let monitor = - openvpn::OpenVpnMonitor::start(on_event, config, log, resource_dir, route_manager)?; + let monitor = openvpn::OpenVpnMonitor::start( + on_event, + config, + log, + resource_dir, + tunnel_close_rx, + #[cfg(target_os = "linux")] + route_manager, + )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::OpenVpn(monitor), }) @@ -263,42 +282,12 @@ impl TunnelMonitor { } } - /// Creates a handle to this monitor, allowing the tunnel to be closed while some other - /// thread - /// is blocked in `wait`. - pub fn close_handle(&self) -> CloseHandle { - self.monitor.close_handle() - } - /// Consumes the monitor and blocks until the tunnel exits or there is an error. pub fn wait(self) -> Result<()> { self.monitor.wait().map_err(Error::from) } } -/// A handle to a `TunnelMonitor` -pub enum CloseHandle { - #[cfg(not(target_os = "android"))] - /// OpenVpn close handle - OpenVpn(openvpn::OpenVpnCloseHandle), - /// Wireguard close handle - Wireguard(wireguard::CloseHandle), -} - -impl CloseHandle { - /// Closes the underlying tunnel, making the `TunnelMonitor::wait` method return. - pub fn close(self) -> io::Result<()> { - match self { - #[cfg(not(target_os = "android"))] - CloseHandle::OpenVpn(handle) => handle.close(), - CloseHandle::Wireguard(mut handle) => { - handle.close(); - Ok(()) - } - } - } -} - enum InternalTunnelMonitor { #[cfg(not(target_os = "android"))] OpenVpn(openvpn::OpenVpnMonitor), @@ -306,14 +295,6 @@ enum InternalTunnelMonitor { } impl InternalTunnelMonitor { - fn close_handle(&self) -> CloseHandle { - match self { - #[cfg(not(target_os = "android"))] - InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()), - InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()), - } - } - fn wait(self) -> Result<()> { match self { #[cfg(not(target_os = "android"))] diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index ffca275f68..446ce13d79 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -1,6 +1,6 @@ use super::TunnelEvent; #[cfg(target_os = "linux")] -use crate::routing::RequiredRoute; +use crate::routing::{self, RequiredRoute}; use crate::{ mktemp, process::{ @@ -8,8 +8,8 @@ use crate::{ stoppable_process::StoppableProcess, }, proxy::{self, ProxyMonitor, ProxyResourceData}, - routing, }; +use futures::channel::oneshot; #[cfg(windows)] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -65,11 +65,6 @@ pub enum Error { #[error(display = "Failed to initialize the tokio runtime")] RuntimeError(#[error(source)] io::Error), - /// Failed to set up routing. - #[cfg(target_os = "linux")] - #[error(display = "Failed to setup routing")] - SetupRoutingError(#[error(source)] routing::Error), - /// Unable to start, wait for or kill the OpenVPN process. #[error(display = "Error in OpenVPN process management: {}", _0)] ChildProcessError(&'static str, #[error(source)] io::Error), @@ -254,8 +249,8 @@ impl OpenVpnMonitor<OpenVpnCommand> { params: &openvpn::TunnelParameters, log_path: Option<PathBuf>, resource_dir: &Path, - #[cfg(target_os = "linux")] route_manager: &mut routing::RouteManager, - #[cfg(not(target_os = "linux"))] _route_manager: &mut routing::RouteManager, + tunnel_close_rx: oneshot::Receiver<()>, + #[cfg(target_os = "linux")] route_manager: routing::RouteManagerHandle, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -323,8 +318,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { #[cfg(target_os = "linux")] let ipv6_enabled = params.generic_options.enable_ipv6; - #[cfg(target_os = "linux")] - let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); @@ -338,7 +331,7 @@ impl OpenVpnMonitor<OpenVpnCommand> { proxy_auth_file_path: proxy_auth_file_path.clone(), abort_server_tx: event_server_abort_tx, #[cfg(target_os = "linux")] - route_manager_handle, + route_manager_handle: route_manager, #[cfg(target_os = "linux")] ipv6_enabled, }, @@ -347,6 +340,7 @@ impl OpenVpnMonitor<OpenVpnCommand> { user_pass_file, proxy_auth_file, proxy_monitor, + tunnel_close_rx, #[cfg(windows)] Box::new(WintunContextImpl { adapter: wintun_adapter, @@ -379,6 +373,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { user_pass_file: mktemp::TempFile, proxy_auth_file: Option<mktemp::TempFile>, proxy_monitor: Option<Box<dyn ProxyMonitor>>, + tunnel_close_rx: oneshot::Receiver<()>, #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where @@ -424,7 +419,9 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { )); let spawn_task = runtime.spawn(spawn_task); - Ok(OpenVpnMonitor { + let handle = runtime.handle().clone(); + + let monitor = OpenVpnMonitor { spawn_task: Some(spawn_task), abort_spawn, child: Arc::new(Mutex::new(None)), @@ -439,7 +436,25 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { #[cfg(windows)] _wintun: wintun, - }) + }; + + let close_handle = monitor.close_handle(); + handle.spawn(async move { + if tunnel_close_rx.await.is_ok() { + tokio::task::spawn_blocking(move || { + if let Err(error) = close_handle.close() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to close the tunnel") + ); + } + }) + .await + .expect("close handle panic"); + } + }); + + Ok(monitor) } async fn prepare_process( @@ -457,7 +472,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { /// Creates a handle to this monitor, allowing the tunnel to be closed while some other /// thread is blocked in `wait`. - pub fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> { + fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> { OpenVpnCloseHandle { child: self.child.clone(), abort_spawn: self.abort_spawn.clone(), @@ -1212,6 +1227,7 @@ mod tests { fn sets_plugin() { let builder = TestOpenVpnBuilder::default(); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let _ = OpenVpnMonitor::new_internal( builder.clone(), event_server_abort_tx, @@ -1222,6 +1238,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ); @@ -1235,6 +1252,7 @@ mod tests { fn sets_log() { let builder = TestOpenVpnBuilder::default(); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let _ = OpenVpnMonitor::new_internal( builder.clone(), event_server_abort_tx, @@ -1245,6 +1263,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ); @@ -1259,6 +1278,7 @@ mod tests { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(0)); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let testee = OpenVpnMonitor::new_internal( builder, event_server_abort_tx, @@ -1269,6 +1289,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ) @@ -1281,6 +1302,7 @@ mod tests { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let testee = OpenVpnMonitor::new_internal( builder, event_server_abort_tx, @@ -1291,6 +1313,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ) @@ -1303,6 +1326,7 @@ mod tests { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let testee = OpenVpnMonitor::new_internal( builder, event_server_abort_tx, @@ -1313,6 +1337,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ) @@ -1325,6 +1350,7 @@ mod tests { fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let (_close_tx, close_rx) = oneshot::channel(); let result = OpenVpnMonitor::new_internal( builder, event_server_abort_tx, @@ -1335,6 +1361,7 @@ mod tests { TempFile::new(), None, None, + close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), ) diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 404b2dec68..8b5e0cfa4a 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -2,10 +2,10 @@ use self::config::Config; #[cfg(not(windows))] use super::tun_provider; use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; -use crate::routing::{self, RequiredRoute}; -use futures::future::abortable; +use crate::routing::{self, RequiredRoute, RouteManagerHandle}; #[cfg(windows)] use futures::{channel::mpsc, StreamExt}; +use futures::{channel::oneshot, future::abortable}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -168,9 +168,10 @@ impl WireguardMonitor { log_path: Option<&Path>, resource_dir: &Path, on_event: F, - tun_provider: &mut TunProvider, - route_manager: &mut routing::RouteManager, + tun_provider: Arc<Mutex<TunProvider>>, + route_manager: RouteManagerHandle, retry_attempt: u32, + tunnel_close_rx: oneshot::Receiver<()>, ) -> Result<WireguardMonitor> { let mut tcp_proxies = vec![]; let mut endpoint_addrs = vec![]; @@ -194,7 +195,6 @@ impl WireguardMonitor { log_path, resource_dir, tun_provider, - route_manager, #[cfg(target_os = "windows")] setup_done_tx, )?; @@ -224,8 +224,6 @@ impl WireguardMonitor { ) .map_err(Error::ConnectivityMonitorError)?; - let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; - let metadata = Self::tunnel_metadata(&iface_name, &config); tokio::spawn(async move { @@ -258,7 +256,7 @@ impl WireguardMonitor { } #[cfg(target_os = "linux")] - route_handle + route_manager .create_routing_rules(config.enable_ipv6) .await .map_err(Error::SetupRoutingError)?; @@ -266,7 +264,7 @@ impl WireguardMonitor { let routes = Self::get_in_tunnel_routes(&iface_name, &config) .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs)); - route_handle + route_manager .add_routes(routes.collect()) .await .map_err(Error::SetupRoutingError) @@ -304,6 +302,13 @@ impl WireguardMonitor { let _ = close_sender.send(CloseMsg::PingErr); }); + let mut close_handle = monitor.close_handle(); + tokio::spawn(async move { + if tunnel_close_rx.await.is_ok() { + close_handle.close(); + } + }); + Ok(monitor) } @@ -313,8 +318,7 @@ impl WireguardMonitor { config: &Config, log_path: Option<&Path>, resource_dir: &Path, - tun_provider: &mut TunProvider, - route_manager: &mut routing::RouteManager, + tun_provider: Arc<Mutex<TunProvider>>, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { #[cfg(target_os = "linux")] @@ -384,8 +388,6 @@ impl WireguardMonitor { #[cfg(not(windows))] Self::get_tunnel_destinations(config), #[cfg(windows)] - route_manager, - #[cfg(windows)] setup_done_tx, ) .map_err(Error::TunnelError)?, @@ -393,7 +395,7 @@ impl WireguardMonitor { } /// Returns a close handle for the tunnel - pub fn close_handle(&self) -> CloseHandle { + fn close_handle(&self) -> CloseHandle { CloseHandle { chan: self.close_msg_sender.clone(), } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index e53b0cfbb3..28666b6506 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -2,8 +2,6 @@ use super::{ stats::{Stats, StatsMap}, Config, Tunnel, TunnelError, }; -#[cfg(windows)] -use crate::routing; #[cfg(not(windows))] use crate::tunnel::tun_provider::TunProvider; use crate::tunnel::wireguard::logging::{ @@ -43,6 +41,9 @@ type Result<T> = std::result::Result<T, TunnelError>; use crate::winnet; #[cfg(not(target_os = "windows"))] +use std::sync::{Arc, Mutex}; + +#[cfg(not(target_os = "windows"))] const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; struct LoggingContext(u32); @@ -73,7 +74,7 @@ impl WgGoTunnel { pub fn start_tunnel( config: &Config, log_path: Option<&Path>, - tun_provider: &mut TunProvider, + tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, ) -> Result<Self> { #[cfg_attr(not(target_os = "android"), allow(unused_mut))] @@ -114,12 +115,13 @@ impl WgGoTunnel { pub fn start_tunnel( config: &Config, log_path: Option<&Path>, - route_manager: &mut routing::RouteManager, mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Self> { - let route_callback_handle = route_manager - .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()) - .ok(); + let route_callback_handle = winnet::add_default_route_change_callback( + Some(WgGoTunnel::default_route_changed_callback), + (), + ) + .ok(); if route_callback_handle.is_none() { log::warn!("Failed to register default route callback"); } @@ -275,13 +277,15 @@ impl WgGoTunnel { #[cfg(not(target_os = "windows"))] fn get_tunnel( - tun_provider: &mut TunProvider, + tun_provider: Arc<Mutex<TunProvider>>, config: &Config, routes: impl Iterator<Item = IpNetwork>, ) -> Result<(Tun, RawFd)> { let mut last_error = None; let tunnel_config = Self::create_tunnel_config(config, routes); + let mut tun_provider = tun_provider.lock().unwrap(); + for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { let tunnel_device = tun_provider .get_tun(tunnel_config.clone()) diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 546f9e92ab..52c410fdf6 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -5,7 +5,7 @@ use super::{ }; use crate::{ firewall::FirewallPolicy, - tunnel::{CloseHandle, TunnelEvent, TunnelMetadata}, + tunnel::{TunnelEvent, TunnelMetadata}, }; use cfg_if::cfg_if; use futures::{ @@ -33,7 +33,7 @@ pub struct ConnectedStateBootstrap { pub tunnel_events: TunnelEventsReceiver, pub tunnel_parameters: TunnelParameters, pub tunnel_close_event: TunnelCloseEvent, - pub close_handle: Option<CloseHandle>, + pub tunnel_close_tx: oneshot::Sender<()>, } /// The tunnel is up and working. @@ -42,7 +42,7 @@ pub struct ConnectedState { tunnel_events: TunnelEventsReceiver, tunnel_parameters: TunnelParameters, tunnel_close_event: TunnelCloseEvent, - close_handle: Option<CloseHandle>, + tunnel_close_tx: oneshot::Sender<()>, } impl ConnectedState { @@ -52,7 +52,7 @@ impl ConnectedState { tunnel_events: bootstrap.tunnel_events, tunnel_parameters: bootstrap.tunnel_parameters, tunnel_close_event: bootstrap.tunnel_close_event, - close_handle: bootstrap.close_handle, + tunnel_close_tx: bootstrap.tunnel_close_tx, } } @@ -173,7 +173,11 @@ impl ConnectedState { EventConsequence::NewState(DisconnectingState::enter( shared_values, - (self.close_handle, self.tunnel_close_event, after_disconnect), + ( + self.tunnel_close_tx, + self.tunnel_close_event, + after_disconnect, + ), )) } @@ -328,7 +332,7 @@ impl TunnelState for ConnectedState { DisconnectingState::enter( shared_values, ( - connected_state.close_handle, + connected_state.tunnel_close_tx, connected_state.tunnel_close_event, AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), ), @@ -338,7 +342,7 @@ impl TunnelState for ConnectedState { DisconnectingState::enter( shared_values, ( - connected_state.close_handle, + connected_state.tunnel_close_tx, connected_state.tunnel_close_event, AfterDisconnect::Block(ErrorStateCause::SetDnsError), ), diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 2ae1924988..3c3be4c7f1 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -6,9 +6,7 @@ use super::{ use crate::{ firewall::FirewallPolicy, routing::RouteManager, - tunnel::{ - self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor, - }, + tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor}, }; use cfg_if::cfg_if; use futures::{ @@ -18,6 +16,7 @@ use futures::{ }; use std::{ path::{Path, PathBuf}, + sync::{Arc, Mutex}, thread, time::{Duration, Instant}, }; @@ -49,7 +48,7 @@ pub struct ConnectingState { tunnel_parameters: TunnelParameters, tunnel_metadata: Option<TunnelMetadata>, tunnel_close_event: TunnelCloseEvent, - close_handle: Option<CloseHandle>, + tunnel_close_tx: oneshot::Sender<()>, retry_attempt: u32, } @@ -95,10 +94,10 @@ impl ConnectingState { parameters: TunnelParameters, log_dir: &Option<PathBuf>, resource_dir: &Path, - tun_provider: &mut TunProvider, + tun_provider: Arc<Mutex<TunProvider>>, route_manager: &mut RouteManager, retry_attempt: u32, - ) -> crate::tunnel::Result<Self> { + ) -> Self { let (event_tx, event_rx) = mpsc::unbounded(); let on_tunnel_event = move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> { @@ -109,45 +108,86 @@ impl ConnectingState { }) }; - let monitor = TunnelMonitor::start( - runtime, - ¶meters, - log_dir, - resource_dir, - on_tunnel_event, - tun_provider, - route_manager, - retry_attempt, - )?; - let close_handle = Some(monitor.close_handle()); - let tunnel_close_event = - Self::spawn_tunnel_monitor_wait_thread(Some(monitor), retry_attempt); - - Ok(ConnectingState { - tunnel_events: event_rx.fuse(), - tunnel_parameters: parameters, - tunnel_metadata: None, - tunnel_close_event, - close_handle, - retry_attempt, - }) - } + let route_manager_handle = route_manager.handle(); + let log_dir = log_dir.clone(); + let resource_dir = resource_dir.to_path_buf(); - fn spawn_tunnel_monitor_wait_thread( - tunnel_monitor: Option<TunnelMonitor>, - retry_attempt: u32, - ) -> TunnelCloseEvent { + let (tunnel_close_tx, tunnel_close_rx) = oneshot::channel(); let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel(); - thread::spawn(move || { + let tunnel_parameters = parameters.clone(); + + tokio::task::spawn_blocking(move || { let start = Instant::now(); - let block_reason = if let Some(monitor) = tunnel_monitor { - let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt); - log::debug!("Tunnel monitor exited with block reason: {:?}", reason); - reason - } else { - None + let route_manager_handle = match route_manager_handle { + Ok(handle) => handle, + Err(error) => { + if tunnel_close_event_tx + .send(Some(ErrorStateCause::StartTunnelError)) + .is_err() + { + log::warn!( + "Tunnel state machine stopped before receiving tunnel closed event" + ); + } + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain route monitor handle") + ); + return; + } + }; + + let block_reason = match TunnelMonitor::start( + runtime, + &tunnel_parameters, + &log_dir, + &resource_dir, + on_tunnel_event, + tun_provider, + route_manager_handle, + retry_attempt, + tunnel_close_rx, + ) { + Ok(monitor) => { + let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt); + log::debug!("Tunnel monitor exited with block reason: {:?}", reason); + reason + } + Err(error) if should_retry(&error, retry_attempt) => { + log::warn!( + "{}", + error.display_chain_with_msg( + "Retrying to connect after failing to start tunnel" + ) + ); + None + } + Err(error) => { + log::error!("{}", error.display_chain_with_msg("Failed to start tunnel")); + let block_reason = match error { + tunnel::Error::EnableIpv6Error => ErrorStateCause::Ipv6Unavailable, + #[cfg(target_os = "android")] + tunnel::Error::WireguardTunnelMonitoringError( + tunnel::wireguard::Error::TunnelError( + tunnel::wireguard::TunnelError::SetupTunnelDeviceError( + tun_provider::Error::PermissionDenied, + ), + ), + ) => ErrorStateCause::VpnPermissionDenied, + #[cfg(target_os = "android")] + tunnel::Error::WireguardTunnelMonitoringError( + tunnel::wireguard::Error::TunnelError( + tunnel::wireguard::TunnelError::SetupTunnelDeviceError( + tun_provider::Error::InvalidDnsServers(addresses), + ), + ), + ) => ErrorStateCause::InvalidDnsServers(addresses), + _ => ErrorStateCause::StartTunnelError, + }; + Some(block_reason) + } }; if block_reason.is_none() { @@ -163,7 +203,14 @@ impl ConnectingState { log::trace!("Tunnel monitor thread exit"); }); - tunnel_close_event_rx.fuse() + ConnectingState { + tunnel_events: event_rx.fuse(), + tunnel_parameters: parameters, + tunnel_metadata: None, + tunnel_close_event: tunnel_close_event_rx.fuse(), + tunnel_close_tx, + retry_attempt, + } } fn wait_for_tunnel_monitor( @@ -205,7 +252,7 @@ impl ConnectingState { tunnel_events: self.tunnel_events, tunnel_parameters: self.tunnel_parameters, tunnel_close_event: self.tunnel_close_event, - close_handle: self.close_handle, + tunnel_close_tx: self.tunnel_close_tx, } } @@ -234,7 +281,11 @@ impl ConnectingState { EventConsequence::NewState(DisconnectingState::enter( shared_values, - (self.close_handle, self.tunnel_close_event, after_disconnect), + ( + self.tunnel_close_tx, + self.tunnel_close_event, + after_disconnect, + ), )) } @@ -512,7 +563,9 @@ impl TunnelState for ConnectingState { #[cfg(target_os = "android")] { if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 { - if let Err(error) = shared_values.tun_provider.create_tun() { + if let Err(error) = + { shared_values.tun_provider.lock().unwrap().create_tun() } + { log::error!( "{}", error.display_chain_with_msg("Failed to recreate tun device") @@ -521,69 +574,20 @@ impl TunnelState for ConnectingState { } } - match Self::start_tunnel( + let connecting_state = Self::start_tunnel( shared_values.runtime.clone(), tunnel_parameters, &shared_values.log_dir, &shared_values.resource_dir, - &mut shared_values.tun_provider, + shared_values.tun_provider.clone(), &mut shared_values.route_manager, retry_attempt, - ) { - Ok(connecting_state) => { - let params = connecting_state.tunnel_parameters.clone(); - ( - TunnelStateWrapper::from(connecting_state), - TunnelStateTransition::Connecting(params.get_tunnel_endpoint()), - ) - } - Err(error) => { - if should_retry(&error, retry_attempt) { - log::warn!( - "{}", - error.display_chain_with_msg( - "Retrying to connect after failing to start tunnel" - ) - ); - DisconnectingState::enter( - shared_values, - ( - None, - Self::spawn_tunnel_monitor_wait_thread(None, retry_attempt), - AfterDisconnect::Reconnect(retry_attempt + 1), - ), - ) - } else { - log::error!( - "{}", - error.display_chain_with_msg("Failed to start tunnel") - ); - let block_reason = match error { - tunnel::Error::EnableIpv6Error => { - ErrorStateCause::Ipv6Unavailable - } - #[cfg(target_os = "android")] - tunnel::Error::WireguardTunnelMonitoringError( - tunnel::wireguard::Error::TunnelError( - tunnel::wireguard::TunnelError::SetupTunnelDeviceError( - tun_provider::Error::PermissionDenied, - ), - ), - ) => ErrorStateCause::VpnPermissionDenied, - #[cfg(target_os = "android")] - tunnel::Error::WireguardTunnelMonitoringError( - tunnel::wireguard::Error::TunnelError( - tunnel::wireguard::TunnelError::SetupTunnelDeviceError( - tun_provider::Error::InvalidDnsServers(addresses), - ), - ), - ) => ErrorStateCause::InvalidDnsServers(addresses), - _ => ErrorStateCause::StartTunnelError, - }; - ErrorState::enter(shared_values, block_reason) - } - } - } + ); + let params = connecting_state.tunnel_parameters.clone(); + ( + TunnelStateWrapper::from(connecting_state), + TunnelStateTransition::Connecting(params.get_tunnel_endpoint()), + ) } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 3682accc0b..6d0af09aee 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -120,7 +120,7 @@ impl TunnelState for DisconnectedState { #[cfg(target_os = "linux")] shared_values.reset_connectivity_check(); #[cfg(target_os = "android")] - shared_values.tun_provider.close_tun(); + shared_values.tun_provider.lock().unwrap().close_tun(); ( TunnelStateWrapper::from(DisconnectedState), diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 8f6f6ae68b..2d3444f44a 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -3,13 +3,8 @@ use super::{ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; -use crate::tunnel::CloseHandle; -use futures::{future::FusedFuture, StreamExt}; -use std::thread; -use talpid_types::{ - tunnel::{ActionAfterDisconnect, ErrorStateCause}, - ErrorExt, -}; +use futures::{channel::oneshot, future::FusedFuture, StreamExt}; +use talpid_types::tunnel::{ActionAfterDisconnect, ErrorStateCause}; /// This state is active from when we manually trigger a tunnel kill until the tunnel wait /// operation (TunnelExit) returned. @@ -175,23 +170,13 @@ impl DisconnectingState { } impl TunnelState for DisconnectingState { - type Bootstrap = (Option<CloseHandle>, TunnelCloseEvent, AfterDisconnect); + type Bootstrap = (oneshot::Sender<()>, TunnelCloseEvent, AfterDisconnect); fn enter( _: &mut SharedTunnelStateValues, - (close_handle, tunnel_close_event, after_disconnect): Self::Bootstrap, + (tunnel_close_tx, tunnel_close_event, after_disconnect): Self::Bootstrap, ) -> (TunnelStateWrapper, TunnelStateTransition) { - if let Some(close_handle) = close_handle { - thread::spawn(move || { - if let Err(error) = close_handle.close() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to close the tunnel") - ); - } - }); - } - + let _ = tunnel_close_tx.send(()); let action_after_disconnect = after_disconnect.action(); ( diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index a501b21f92..5464acac12 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -51,7 +51,12 @@ impl ErrorState { /// Returns true if a new tunnel device was successfully created. #[cfg(target_os = "android")] fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool { - match shared_values.tun_provider.create_blocking_tun() { + match shared_values + .tun_provider + .lock() + .unwrap() + .create_blocking_tun() + { Ok(()) => true, Err(error) => { log::error!( diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index e22975a827..8c4446f3f3 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -30,7 +30,13 @@ use futures::{ }; #[cfg(target_os = "android")] use std::os::unix::io::RawFd; -use std::{collections::HashSet, io, net::IpAddr, path::PathBuf, sync::Arc}; +use std::{ + collections::HashSet, + io, + net::IpAddr, + path::PathBuf, + sync::{Arc, Mutex}, +}; #[cfg(target_os = "android")] use talpid_types::{android::AndroidContext, ErrorExt}; use talpid_types::{ @@ -294,7 +300,7 @@ impl TunnelStateMachine { dns_servers: settings.dns_servers, allowed_endpoint: settings.allowed_endpoint, tunnel_parameters_generator: Box::new(tunnel_parameters_generator), - tun_provider, + tun_provider: Arc::new(Mutex::new(tun_provider)), log_dir, resource_dir, #[cfg(target_os = "linux")] @@ -383,7 +389,7 @@ struct SharedTunnelStateValues { /// The generator of new `TunnelParameter`s tunnel_parameters_generator: Box<dyn TunnelParametersGenerator>, /// The provider of tunnel devices. - tun_provider: TunProvider, + tun_provider: Arc<Mutex<TunProvider>>, /// Directory to store tunnel log file. log_dir: Option<PathBuf>, /// Resource directory path. @@ -405,7 +411,7 @@ impl SharedTunnelStateValues { #[cfg(target_os = "android")] { - if let Err(error) = self.tun_provider.set_allow_lan(allow_lan) { + if let Err(error) = self.tun_provider.lock().unwrap().set_allow_lan(allow_lan) { log::error!( "{}", error.display_chain_with_msg(&format!( @@ -425,6 +431,8 @@ impl SharedTunnelStateValues { if self.allowed_endpoint != endpoint { #[cfg(target_os = "android")] self.tun_provider + .lock() + .unwrap() .set_allowed_endpoint(endpoint.endpoint.address.ip()); self.allowed_endpoint = endpoint; @@ -444,7 +452,12 @@ impl SharedTunnelStateValues { #[cfg(target_os = "android")] { - if let Err(error) = self.tun_provider.set_dns_servers(dns_servers) { + if let Err(error) = self + .tun_provider + .lock() + .unwrap() + .set_dns_servers(dns_servers) + { log::error!( "{}", error.display_chain_with_msg( @@ -489,7 +502,7 @@ impl SharedTunnelStateValues { #[cfg(target_os = "android")] pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) { - if let Err(err) = self.tun_provider.bypass(fd) { + if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) { log::error!("Failed to bypass socket {}", err); } let _ = tx.send(()); |
