diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-03-25 13:02:04 +0100 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-03-27 09:27:41 +0100 |
| commit | d8b29fdb81a6d1cd4b6668ae5fef3af581c13082 (patch) | |
| tree | ef5fe80468182cb436225e9bb914e577d93cbd2c | |
| parent | 9cceb079580928f5c7ccaf365988461d3fdebd62 (diff) | |
| download | mullvadvpn-d8b29fdb81a6d1cd4b6668ae5fef3af581c13082.tar.xz mullvadvpn-d8b29fdb81a6d1cd4b6668ae5fef3af581c13082.zip | |
Reset device check counter less often
Only reset the device check counter if the daemon successfully connects
to a WireGuard relay. Otherwise, the counter is either persisted through
multiple tunnel connections (OpenVPN) or incremented on successive
failures to connect to a WireGuard relay.
| -rw-r--r-- | mullvad-daemon/src/device/mod.rs | 288 |
1 files changed, 145 insertions, 143 deletions
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs index 0493976a41..16a5a37460 100644 --- a/mullvad-daemon/src/device/mod.rs +++ b/mullvad-daemon/src/device/mod.rs @@ -25,7 +25,11 @@ use std::{ time::{Duration, SystemTime}, }; use talpid_core::mpsc::Sender; -use talpid_types::{net::TunnelType, tunnel::TunnelStateTransition, ErrorExt}; +use talpid_types::{ + net::{TunnelEndpoint, TunnelType}, + tunnel::TunnelStateTransition, + ErrorExt, +}; use tokio::{ fs, io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, @@ -1239,7 +1243,7 @@ impl DeviceCacher { /// after multiple attempts. pub(crate) struct TunnelStateChangeHandler { manager: AccountManagerHandle, - no_more_retries: Arc<AtomicBool>, + can_retry: Arc<AtomicBool>, wg_retry_attempt: usize, } @@ -1247,7 +1251,7 @@ impl TunnelStateChangeHandler { pub fn new(manager: AccountManagerHandle) -> Self { Self { manager, - no_more_retries: Arc::new(AtomicBool::new(false)), + can_retry: Arc::new(AtomicBool::new(true)), wg_retry_attempt: 0, } } @@ -1255,120 +1259,146 @@ impl TunnelStateChangeHandler { /// Handle state transitions and optionally check the device/account validity. This should be /// called during every tunnel state transition. 
pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) { - let handle = self.manager.clone(); - - let wg_attempt = self.wg_retry_attempt; - self.wg_retry_attempt = Self::next_retry_attempt(new_state, self.wg_retry_attempt); - - if self.wg_retry_attempt > wg_attempt { - tokio::spawn(Self::maybe_check_validity( - wg_attempt, - self.no_more_retries.clone(), - move || Self::check_validity(handle), + self.wg_retry_attempt = Self::update_retry_counter(new_state, self.wg_retry_attempt); + Self::update_retry_bool(new_state, self.can_retry.clone()); + // Check if a device-check should be triggered + if Self::should_check_device_validity(self.wg_retry_attempt, self.can_retry.clone()) { + let handle = self.manager.clone(); + tokio::spawn(Self::check_device_validity( + self.can_retry.clone(), + move || Self::check_device_validity_inner(handle), )); } } + /// Run `validate` when connecting to a WireGuard server. + /// + /// # Note + /// `can_retry` is reset on network errors. Otherwise, it is set to `true` as to not + /// immediately trigger new device checks. + async fn check_device_validity<Validate, ValidateResult>( + can_retry: Arc<AtomicBool>, + validate: Validate, + ) where + Validate: FnOnce() -> ValidateResult + Send, + ValidateResult: Future<Output = Result<(), Error>> + Send, + { + // Log any error + let result = validate().await.inspect_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check device or account validity") + ) + }); + // Update `can_retry` based on the result of `validate` + match result { + // If the request failed due to a network error, we should continue + // retrying. + Err(ref error) if Self::should_continue_retries(error) => { + can_retry.store(true, Ordering::SeqCst); + } + // Otherwise we give up, because it means we have a known result or + // the API returned some error. 
+ _ => (), + } + } + /// Return an incremented count for `retry_attempt` if this is another WireGuard connection - /// attempt, and zero the counter when leaving the connecting loop. - fn next_retry_attempt(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize { + /// attempt, otherwise `retry_attempt` is returned. + /// + /// Reset to the counter to `0` when we manage to successfully connect to a Wireguard relay. + fn update_retry_counter(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize { + let wireguard = + |endpoint: &TunnelEndpoint| matches!(endpoint.tunnel_type, TunnelType::Wireguard); + match new_state { - TunnelStateTransition::Connecting(endpoint) => { - if endpoint.tunnel_type == TunnelType::Wireguard { - retry_attempt.wrapping_add(1) - } else { - retry_attempt - } + // Increment the counter if this is another Wireguard attempt + TunnelStateTransition::Connecting(endpoint) if wireguard(endpoint) => { + retry_attempt.wrapping_add(1) } - TunnelStateTransition::Error(_) - | TunnelStateTransition::Connected(_) - | TunnelStateTransition::Disconnected { .. } => 0, + // Only reset the counter if we managed to connect to a Wireguard relay + TunnelStateTransition::Connected(endpoint) if wireguard(endpoint) => 0, + // Any other state transition doesn't affect the counter _ => retry_attempt, } } - /// Run `validate` when connecting to a WireGuard server, on certain retry attempts. - /// If `no_more_retries` is true, no further checks are made. `no_more_retries` is reset - /// on the first connection attempt. + /// Check if `new_state` breaks a connecting-loop. If so, the retry state `can_retry` is reset + /// (i.e. set to `true`). /// - /// This returns whether the device/account validity ran. 
- async fn maybe_check_validity<Validate, ValidateResult>( - wg_attempt: usize, - no_more_retries: Arc<AtomicBool>, - validate: Validate, - ) -> bool - where - Validate: FnOnce() -> ValidateResult + Send + 'static, - ValidateResult: Future<Output = Result<(), Error>> + Send + 'static, - { - if wg_attempt == 0 { - // Starting a new connecting loop, so reset the retry state - no_more_retries.store(false, Ordering::SeqCst); - } - - if !Self::should_check_validity_on_attempt(wg_attempt) { - return false; - } - if no_more_retries.swap(true, Ordering::SeqCst) { - // We've either already received the device state or we've given up - return false; - } - match validate().await { - Ok(()) => true, - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check device or account validity") - ); - if Self::should_continue_retries(error) { - // If the request failed due to a network error, we should continue - // retrying. We give up otherwise, because it means we have a known result or - // the API returned some error. - no_more_retries.store(false, Ordering::SeqCst); - } - true + /// # Note + /// The following state transition counts as breaking a connecting-loop: `Connected`, + /// `Disconnected` and `Error`. + fn update_retry_bool(new_state: &TunnelStateTransition, can_retry: Arc<AtomicBool>) { + match new_state { + TunnelStateTransition::Disconnected { .. 
} + | TunnelStateTransition::Connected(_) + | TunnelStateTransition::Error(_) => { + can_retry.store(true, Ordering::SeqCst); } - } + _ => {} + }; } - async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> { + async fn check_device_validity_inner(handle: AccountManagerHandle) -> Result<(), Error> { handle.validate_device().await?; handle.check_expiry().await.map(|_expiry| ()) } - fn should_check_validity_on_attempt(wg_attempt: usize) -> bool { - wg_attempt % WG_DEVICE_CHECK_THRESHOLD == WG_DEVICE_CHECK_THRESHOLD - 1 + /// Check if a device check is due + fn should_check_device_validity( + wireguard_retry_attempt: usize, + can_retry: Arc<AtomicBool>, + ) -> bool { + Self::should_check_device_validity_on_attempt(wireguard_retry_attempt) + && can_retry.swap(false, Ordering::SeqCst) + } + + /// Check if a device check should be triggered based on the current `wireguard_retry_attempt` + const fn should_check_device_validity_on_attempt(wireguard_retry_attempt: usize) -> bool { + // Incorporate a debounce effect where every `WG_DEVICE_CHECK_THRESHOLD` attempt should be + // able to trigger a device check. 
+ wireguard_retry_attempt > 0 && (wireguard_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0) } - fn should_continue_retries(err: Error) -> bool { + fn should_continue_retries(err: &Error) -> bool { err.is_network_error() || err.is_aborted() } } #[cfg(test)] mod test { - use super::{Error, TunnelStateChangeHandler, WG_DEVICE_CHECK_THRESHOLD}; - use mullvad_relay_selector::RelaySelector; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; - use talpid_types::net::TunnelType; + use talpid_types::tunnel::TunnelStateTransition; + + use super::{Error, TunnelStateChangeHandler, WG_DEVICE_CHECK_THRESHOLD}; const TIMEOUT_ERROR: Error = Error::OtherRestError(mullvad_api::rest::Error::TimeoutError); - /// Starting a new connection loop should resume device validity checks - #[tokio::test] - async fn test_device_check_reset() { - let no_more_retries = Arc::new(AtomicBool::new(true)); + /// Verify that a device check is triggered 'when expected', i.e. when the current attempt + /// has reached the threshold as specified by [`WG_DEVICE_CHECK_THRESHOLD`] + #[test] + fn test_device_check_by_retry_attempt() { + assert!( + TunnelStateChangeHandler::should_check_device_validity_on_attempt( + WG_DEVICE_CHECK_THRESHOLD + ) + ); + } - TunnelStateChangeHandler::maybe_check_validity(0, no_more_retries.clone(), || async { - Ok(()) - }) - .await; + /// Starting a new connection loop should resume device validity checks + #[test] + fn test_device_check_reset() { + let can_retry = Arc::new(AtomicBool::new(false)); + // Transitioning to the 'Disconnected' state counts as breaking the 'connection loop' + let new_tunnel_state = TunnelStateTransition::Disconnected { locked_down: false }; + TunnelStateChangeHandler::update_retry_bool(&new_tunnel_state, can_retry.clone()); assert!( - !no_more_retries.load(Ordering::SeqCst), + can_retry.load(Ordering::SeqCst), "expected retry state to be reset on first connection attempt" ); } @@ -1376,31 +1406,24 @@ mod test { /// Retries should stop when a 
device check succeeds #[tokio::test] async fn test_device_check_on_success() { - const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1; - assert!(TunnelStateChangeHandler::should_check_validity_on_attempt( - ATTEMPT - )); - - let no_more_retries = Arc::new(AtomicBool::new(false)); + let can_retry = Arc::new(AtomicBool::new(true)); - let check_ran = TunnelStateChangeHandler::maybe_check_validity( - ATTEMPT, - no_more_retries.clone(), - || async { Ok(()) }, - ) - .await; - - assert!(check_ran, "expected device check to run"); - - let check_ran = TunnelStateChangeHandler::maybe_check_validity( - ATTEMPT, - no_more_retries.clone(), - || async { Ok(()) }, - ) - .await; + let did_run = TunnelStateChangeHandler::should_check_device_validity( + WG_DEVICE_CHECK_THRESHOLD, + can_retry.clone(), + ); + assert!(did_run, "expected device check to run"); + // Manually trigger the device check and verify that we still can try to perform a device + // check + TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async { Ok(()) }) + .await; + let did_run = TunnelStateChangeHandler::should_check_device_validity( + WG_DEVICE_CHECK_THRESHOLD, + can_retry.clone(), + ); assert!( - !check_ran, + !did_run, "expected device check to give up after successful check" ); } @@ -1408,53 +1431,32 @@ mod test { /// Retries should continue when a network error occurs #[tokio::test] async fn test_device_check_on_network_error() { - const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1; - assert!(TunnelStateChangeHandler::should_check_validity_on_attempt( - ATTEMPT - )); - - let no_more_retries = Arc::new(AtomicBool::new(false)); - - let check_ran = TunnelStateChangeHandler::maybe_check_validity( - ATTEMPT, - no_more_retries.clone(), - || async { Err(TIMEOUT_ERROR) }, - ) - .await; + let can_retry = Arc::new(AtomicBool::new(true)); - assert!(check_ran, "expected device check to occur"); - - let check_ran = TunnelStateChangeHandler::maybe_check_validity( - ATTEMPT, - 
no_more_retries.clone(), - || async { Err(TIMEOUT_ERROR) }, - ) + // Run the check with a (simulated) network error - verify that `can_retry` is still true + // afterwards, indicating that a device check may still be performed + TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async { + Err(TIMEOUT_ERROR) + }) .await; assert!( - check_ran, + can_retry.load(Ordering::SeqCst), "expected device check to continue after a network error" ); - } - /// Test whether the relay selector selects wireguard often enough, given no special - /// constraints, to verify that the device is valid - #[test] - fn test_validates_by_default() { - for attempt in 0.. { - let should_validate = - TunnelStateChangeHandler::should_check_validity_on_attempt(attempt); - let (_, _, tunnel_type) = - RelaySelector::preferred_tunnel_constraints(attempt.try_into().unwrap()); - assert_eq!( - tunnel_type, - TunnelType::Wireguard, - "failed on attempt {attempt}" - ); - if should_validate { - // Now that we've triggered a device check, we can give up - break; - } - } + // Re-run the check without a network error - verify that `can_retry` is no longer true + TunnelStateChangeHandler::should_check_device_validity( + WG_DEVICE_CHECK_THRESHOLD, + can_retry.clone(), + ); + + TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async { Ok(()) }) + .await; + + assert!( + !can_retry.load(Ordering::SeqCst), + "device check should no longer happen after successful check" + ); } } |
