summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-02-21 14:48:41 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-02-22 10:00:33 +0100
commitd1dc73964dfad4b42b059bd0d89b7a1c5308a161 (patch)
tree3c54a037002a11e39d219feac56a049156f5b164
parentc1198ab23f359b2457b7946f865667e88775ca91 (diff)
downloadmullvadvpn-d1dc73964dfad4b42b059bd0d89b7a1c5308a161.tar.xz
mullvadvpn-d1dc73964dfad4b42b059bd0d89b7a1c5308a161.zip
Use tunnel state machine handle instead of plain oneshot
-rw-r--r--mullvad-daemon/src/lib.rs29
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs26
2 files changed, 31 insertions, 24 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index e5eabbbd29..a62979b63f 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -83,8 +83,6 @@ use tokio::io;
#[path = "wireguard.rs"]
mod wireguard;
-const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
-
/// Timeout for first WireGuard key pushing
const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5);
@@ -547,8 +545,7 @@ pub struct Daemon<L: EventListener> {
last_generated_entry_relay: Option<Relay>,
app_version_info: Option<AppVersionInfo>,
shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
- /// oneshot channel that completes once the tunnel state machine has been shut down
- tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>,
+ tunnel_state_machine_handle: tunnel_state_machine::JoinHandle,
#[cfg(target_os = "windows")]
volume_update_tx: mpsc::UnboundedSender<()>,
}
@@ -572,8 +569,6 @@ where
exclusion_gid::set_exclusion_gid().map_err(Error::GroupIdError)?
};
- let (tunnel_state_machine_shutdown_tx, tunnel_state_machine_shutdown_signal) =
- oneshot::channel();
let runtime = tokio::runtime::Handle::current();
let (internal_event_tx, internal_event_rx) = command_channel.destructure();
@@ -630,7 +625,7 @@ where
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
- let tunnel_command_tx = tunnel_state_machine::spawn(
+ let (tunnel_command_tx, tunnel_state_machine_handle) = tunnel_state_machine::spawn(
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
@@ -645,7 +640,6 @@ where
resource_dir.clone(),
internal_event_tx.to_specialized_sender(),
offline_state_tx,
- tunnel_state_machine_shutdown_tx,
#[cfg(target_os = "windows")]
volume_update_rx,
#[cfg(target_os = "macos")]
@@ -746,7 +740,7 @@ where
last_generated_entry_relay: None,
app_version_info,
shutdown_tasks: vec![],
- tunnel_state_machine_shutdown_signal,
+ tunnel_state_machine_handle,
#[cfg(target_os = "windows")]
volume_update_tx,
};
@@ -842,20 +836,13 @@ where
}
async fn finalize(self) {
- let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal) =
+ let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_handle) =
self.shutdown();
for future in shutdown_tasks {
future.await;
}
- let shutdown_signal = tokio::time::timeout(
- TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT,
- tunnel_state_machine_shutdown_signal,
- );
- match shutdown_signal.await {
- Ok(_) => log::info!("Tunnel state machine shut down"),
- Err(_) => log::error!("Tunnel state machine did not shut down gracefully"),
- }
+ tunnel_state_machine_handle.try_join().await;
mem::drop(event_listener);
mem::drop(rpc_runtime);
@@ -876,13 +863,13 @@ where
L,
Vec<Pin<Box<dyn Future<Output = ()>>>>,
mullvad_rpc::MullvadRpcRuntime,
- oneshot::Receiver<()>,
+ tunnel_state_machine::JoinHandle,
) {
let Daemon {
event_listener,
mut shutdown_tasks,
rpc_runtime,
- tunnel_state_machine_shutdown_signal,
+ tunnel_state_machine_handle,
target_state,
..
} = self;
@@ -893,7 +880,7 @@ where
event_listener,
shutdown_tasks,
rpc_runtime,
- tunnel_state_machine_shutdown_signal,
+ tunnel_state_machine_handle,
)
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index cd8fba4091..7296195672 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -36,6 +36,7 @@ use std::{
net::IpAddr,
path::PathBuf,
sync::{Arc, Mutex},
+ time::Duration,
};
#[cfg(target_os = "android")]
use talpid_types::{android::AndroidContext, ErrorExt};
@@ -44,6 +45,8 @@ use talpid_types::{
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
};
+const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
+
/// Errors that can happen when setting up or using the state machine.
#[derive(err_derive::Error, Debug)]
pub enum Error {
@@ -108,11 +111,10 @@ pub async fn spawn(
resource_dir: PathBuf,
state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static,
offline_state_listener: mpsc::UnboundedSender<bool>,
- shutdown_tx: oneshot::Sender<()>,
#[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
-) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
+) -> Result<(Arc<mpsc::UnboundedSender<TunnelCommand>>, JoinHandle), Error> {
let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
@@ -125,6 +127,8 @@ pub async fn spawn(
initial_settings.dns_servers.clone(),
);
+ let (shutdown_tx, shutdown_rx) = oneshot::channel();
+
let weak_command_tx = Arc::downgrade(&command_tx);
let state_machine = TunnelStateMachine::new(
initial_settings,
@@ -151,7 +155,7 @@ pub async fn spawn(
}
});
- Ok(command_tx)
+ Ok((command_tx, JoinHandle { shutdown_rx }))
}
/// Representation of external commands for the tunnel state machine.
@@ -580,3 +584,19 @@ state_wrapper! {
Error(ErrorState),
}
}
+
+/// Handle used to wait for the tunnel state machine to shut down.
+pub struct JoinHandle {
+ shutdown_rx: oneshot::Receiver<()>,
+}
+
+impl JoinHandle {
+ /// Waits for the tunnel state machine to shut down.
+ /// This may fail after a timeout of `TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT`.
+ pub async fn try_join(self) {
+ match tokio::time::timeout(TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, self.shutdown_rx).await {
+ Ok(_) => log::info!("Tunnel state machine shut down"),
+ Err(_) => log::error!("Tunnel state machine did not shut down gracefully"),
+ }
+ }
+}