diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-03-09 12:23:36 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-14 12:08:56 +0100 |
| commit | 3f83f94aa63c8980e49a2c485e977a14d4cf6066 (patch) | |
| tree | a462f2d85e9bc7b861cebfede4909b20886cc805 | |
| parent | 13e369418b930f4500c8a3c7f72fea1ff666ee82 (diff) | |
| download | mullvadvpn-3f83f94aa63c8980e49a2c485e977a14d4cf6066.tar.xz mullvadvpn-3f83f94aa63c8980e49a2c485e977a14d4cf6066.zip | |
Move device validity check to its own type
| -rw-r--r-- | mullvad-daemon/src/device.rs | 66 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 53 |
2 files changed, 70 insertions, 49 deletions
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs index 24848ff68f..6009455a63 100644 --- a/mullvad-daemon/src/device.rs +++ b/mullvad-daemon/src/device.rs @@ -17,13 +17,21 @@ use mullvad_types::{ use std::{ future::Future, path::Path, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::{Duration, SystemTime}, }; use talpid_core::{ future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered}, mpsc::Sender, }; -use talpid_types::{net::wireguard::PrivateKey, ErrorExt}; +use talpid_types::{ + net::{wireguard::PrivateKey, TunnelType}, + tunnel::TunnelStateTransition, + ErrorExt, +}; use tokio::{ fs, io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, @@ -49,6 +57,10 @@ const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10); /// How long to wait on logout (device removal) before letting it continue as a background task. const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2); +/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts +/// to set up a WireGuard tunnel. +const WG_DEVICE_CHECK_THRESHOLD: usize = 3; + #[derive(err_derive::Error, Debug)] pub enum Error { #[error(display = "The account already has a maximum number of devices")] @@ -1075,3 +1087,55 @@ fn retry_strategy() -> Jittered<ExponentialBackoff> { .max_delay(RETRY_BACKOFF_INTERVAL_MAX), ) } + +/// Checks if the current device is valid if a WireGuard tunnel cannot be set up +/// after multiple attempts. +pub(crate) struct TunnelStateChangeHandler { + manager: AccountManagerHandle, + check_validity: Arc<AtomicBool>, + wg_retry_attempt: usize, +} + +impl TunnelStateChangeHandler { + pub fn new(manager: AccountManagerHandle) -> Self { + Self { + manager, + check_validity: Arc::new(AtomicBool::new(true)), + wg_retry_attempt: 0, + } + } + + pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) { + match new_state { + TunnelStateTransition::Connecting(endpoint) => { + if endpoint.tunnel_type != TunnelType::Wireguard { + return; + } + self.wg_retry_attempt += 1; + if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 { + let handle = self.manager.clone(); + let check_validity = self.check_validity.clone(); + tokio::spawn(async move { + if !check_validity.swap(false, Ordering::SeqCst) { + return; + } + if let Err(error) = handle.validate_device().await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check device validity") + ); + if error.is_network_error() { + check_validity.store(true, Ordering::SeqCst); + } + } + }); + } + } + TunnelStateTransition::Connected(_) | TunnelStateTransition::Disconnected => { + self.check_validity.store(true, Ordering::SeqCst); + self.wg_retry_attempt = 0; + } + _ => (), + } + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index ddb84bf4a0..64c7107d7f 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -63,10 +63,7 @@ use std::{ net::{IpAddr, Ipv4Addr}, path::PathBuf, pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - mpsc as sync_mpsc, Arc, Weak, - }, + sync::{mpsc as sync_mpsc, Arc, Weak}, time::Duration, }; #[cfg(any(target_os = "linux", windows))] @@ -92,9 +89,6 @@ use tokio::io; /// Delay between generating a new WireGuard key and reconnecting const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60); -/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` attempts. -const WG_DEVICE_CHECK_THRESHOLD: usize = 3; - /// When we want to block certain contents with the help of DNS server side, /// we compute the resolver IP to use based on these constants. The last /// byte can be ORed together to combine multiple block lists. @@ -576,9 +570,8 @@ pub struct Daemon<L: EventListener> { event_listener: L, settings: SettingsPersister, account_history: account_history::AccountHistory, + device_checker: device::TunnelStateChangeHandler, account_manager: device::AccountManagerHandle, - wg_retry_attempt: usize, - wg_check_validity: Arc<AtomicBool>, rpc_runtime: mullvad_rpc::MullvadRpcRuntime, rpc_handle: mullvad_rpc::rest::MullvadRestHandle, version_updater_handle: version_check::VersionUpdaterHandle, @@ -782,9 +775,8 @@ where event_listener, settings, account_history, + device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()), account_manager, - wg_retry_attempt: 0, - wg_check_validity: Arc::new(AtomicBool::new(true)), rpc_runtime, rpc_handle, version_updater_handle, @@ -949,6 +941,8 @@ where ) { self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition) .await; + self.device_checker + .handle_state_transition(&tunnel_state_transition); let tunnel_state = match tunnel_state_transition { TunnelStateTransition::Disconnected => TunnelState::Disconnected, @@ -966,8 +960,6 @@ where TunnelStateTransition::Error(error_state) => TunnelState::Error(error_state), }; - self.maybe_validate_device(&tunnel_state); - if !tunnel_state.is_connected() { // Cancel reconnects except when entering the connected state. // Exempt the latter because a reconnect scheduled while connecting should not be @@ -1015,41 +1007,6 @@ where }; } - /// Check whether the device is valid after a number of failed connection attempts. - fn maybe_validate_device(&mut self, tunnel_state: &TunnelState) { - match tunnel_state { - TunnelState::Connecting { endpoint, .. } => { - if endpoint.tunnel_type != TunnelType::Wireguard { - return; - } - self.wg_retry_attempt += 1; - if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 { - let handle = self.account_manager.clone(); - let check_validity = self.wg_check_validity.clone(); - tokio::spawn(async move { - if !check_validity.swap(false, Ordering::SeqCst) { - return; - } - if let Err(error) = handle.validate_device().await { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check device validity") - ); - if error.is_network_error() { - check_validity.store(true, Ordering::SeqCst); - } - } - }); - } - } - TunnelState::Connected { .. } | TunnelState::Disconnected => { - self.wg_check_validity.store(true, Ordering::SeqCst); - self.wg_retry_attempt = 0; - } - _ => (), - } - } - async fn handle_generate_tunnel_parameters( &mut self, tunnel_parameters_tx: &sync_mpsc::Sender< |
