diff options
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 29 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 26 |
2 files changed, 31 insertions, 24 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index e5eabbbd29..a62979b63f 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -83,8 +83,6 @@ use tokio::io; #[path = "wireguard.rs"] mod wireguard; -const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); - /// Timeout for first WireGuard key pushing const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5); @@ -547,8 +545,7 @@ pub struct Daemon<L: EventListener> { last_generated_entry_relay: Option<Relay>, app_version_info: Option<AppVersionInfo>, shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>, - /// oneshot channel that completes once the tunnel state machine has been shut down - tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>, + tunnel_state_machine_handle: tunnel_state_machine::JoinHandle, #[cfg(target_os = "windows")] volume_update_tx: mpsc::UnboundedSender<()>, } @@ -572,8 +569,6 @@ where exclusion_gid::set_exclusion_gid().map_err(Error::GroupIdError)? }; - let (tunnel_state_machine_shutdown_tx, tunnel_state_machine_shutdown_signal) = - oneshot::channel(); let runtime = tokio::runtime::Handle::current(); let (internal_event_tx, internal_event_rx) = command_channel.destructure(); @@ -630,7 +625,7 @@ where let (offline_state_tx, offline_state_rx) = mpsc::unbounded(); #[cfg(target_os = "windows")] let (volume_update_tx, volume_update_rx) = mpsc::unbounded(); - let tunnel_command_tx = tunnel_state_machine::spawn( + let (tunnel_command_tx, tunnel_state_machine_handle) = tunnel_state_machine::spawn( tunnel_state_machine::InitialTunnelState { allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, @@ -645,7 +640,6 @@ where resource_dir.clone(), internal_event_tx.to_specialized_sender(), offline_state_tx, - tunnel_state_machine_shutdown_tx, #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "macos")] @@ -746,7 +740,7 @@ where last_generated_entry_relay: None, app_version_info, shutdown_tasks: vec![], - tunnel_state_machine_shutdown_signal, + tunnel_state_machine_handle, #[cfg(target_os = "windows")] volume_update_tx, }; @@ -842,20 +836,13 @@ where } async fn finalize(self) { - let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal) = + let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_handle) = self.shutdown(); for future in shutdown_tasks { future.await; } - let shutdown_signal = tokio::time::timeout( - TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, - tunnel_state_machine_shutdown_signal, - ); - match shutdown_signal.await { - Ok(_) => log::info!("Tunnel state machine shut down"), - Err(_) => log::error!("Tunnel state machine did not shut down gracefully"), - } + tunnel_state_machine_handle.try_join().await; mem::drop(event_listener); mem::drop(rpc_runtime); @@ -876,13 +863,13 @@ where L, Vec<Pin<Box<dyn Future<Output = ()>>>>, mullvad_rpc::MullvadRpcRuntime, - oneshot::Receiver<()>, + tunnel_state_machine::JoinHandle, ) { let Daemon { event_listener, mut shutdown_tasks, rpc_runtime, - tunnel_state_machine_shutdown_signal, + tunnel_state_machine_handle, target_state, .. } = self; @@ -893,7 +880,7 @@ where event_listener, shutdown_tasks, rpc_runtime, - tunnel_state_machine_shutdown_signal, + tunnel_state_machine_handle, ) } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index cd8fba4091..7296195672 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -36,6 +36,7 @@ use std::{ net::IpAddr, path::PathBuf, sync::{Arc, Mutex}, + time::Duration, }; #[cfg(target_os = "android")] use talpid_types::{android::AndroidContext, ErrorExt}; @@ -44,6 +45,8 @@ use talpid_types::{ tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, }; +const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + /// Errors that can happen when setting up or using the state machine. #[derive(err_derive::Error, Debug)] pub enum Error { @@ -108,11 +111,10 @@ pub async fn spawn( resource_dir: PathBuf, state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static, offline_state_listener: mpsc::UnboundedSender<bool>, - shutdown_tx: oneshot::Sender<()>, #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "macos")] exclusion_gid: u32, #[cfg(target_os = "android")] android_context: AndroidContext, -) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { +) -> Result<(Arc<mpsc::UnboundedSender<TunnelCommand>>, JoinHandle), Error> { let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); @@ -125,6 +127,8 @@ pub async fn spawn( initial_settings.dns_servers.clone(), ); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let weak_command_tx = Arc::downgrade(&command_tx); let state_machine = TunnelStateMachine::new( initial_settings, @@ -151,7 +155,7 @@ pub async fn spawn( } }); - Ok(command_tx) + Ok((command_tx, JoinHandle { shutdown_rx })) } /// Representation of external commands for the tunnel state machine. @@ -580,3 +584,19 @@ state_wrapper! { Error(ErrorState), } } + +/// Handle used to wait for the tunnel state machine to shut down. +pub struct JoinHandle { + shutdown_rx: oneshot::Receiver<()>, +} + +impl JoinHandle { + /// Waits for the tunnel state machine to shut down. + /// This may fail after a timeout of `TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT`. + pub async fn try_join(self) { + match tokio::time::timeout(TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, self.shutdown_rx).await { + Ok(_) => log::info!("Tunnel state machine shut down"), + Err(_) => log::error!("Tunnel state machine did not shut down gracefully"), + } + } +} |
