diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-12-05 11:01:44 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-12-12 16:45:05 +0100 |
| commit | db9c774d36b2d8180acd01705f44c805f7dd2276 (patch) | |
| tree | 4773713f56b381bdbad9256397f48070fc9a97ac | |
| parent | 95a954b7016d45a415388acf6ed489d03686c700 (diff) | |
| download | mullvadvpn-db9c774d36b2d8180acd01705f44c805f7dd2276.tar.xz mullvadvpn-db9c774d36b2d8180acd01705f44c805f7dd2276.zip | |
Add unit tests for device check
| -rw-r--r-- | mullvad-daemon/src/device/mod.rs | 228 |
1 files changed, 196 insertions, 32 deletions
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs index a73d77cbf6..2668e995ee 100644 --- a/mullvad-daemon/src/device/mod.rs +++ b/mullvad-daemon/src/device/mod.rs @@ -44,8 +44,8 @@ const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10); /// How long to wait on logout (device removal) before letting it continue as a background task. const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2); -/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts -/// to set up a WireGuard tunnel. +/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` attempt to set up +/// a WireGuard tunnel. const WG_DEVICE_CHECK_THRESHOLD: usize = 2; #[derive(err_derive::Error, Debug, Clone)] @@ -1232,7 +1232,7 @@ impl DeviceCacher { /// after multiple attempts. pub(crate) struct TunnelStateChangeHandler { manager: AccountManagerHandle, - check_validity: Arc<AtomicBool>, + no_more_retries: Arc<AtomicBool>, wg_retry_attempt: usize, } @@ -1240,51 +1240,215 @@ impl TunnelStateChangeHandler { pub fn new(manager: AccountManagerHandle) -> Self { Self { manager, - check_validity: Arc::new(AtomicBool::new(true)), + no_more_retries: Arc::new(AtomicBool::new(false)), wg_retry_attempt: 0, } } + /// Handle state transitions and optionally check the device/account validity. This should be + /// called during every tunnel state transition. pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) { + let handle = self.manager.clone(); + + let wg_attempt = self.wg_retry_attempt; + self.wg_retry_attempt = Self::next_retry_attempt(new_state, self.wg_retry_attempt); + + if self.wg_retry_attempt > wg_attempt { + tokio::spawn(Self::maybe_check_validity( + wg_attempt, + self.no_more_retries.clone(), + move || Self::check_validity(handle), + )); + } + } + + /// Return an incremented count for `retry_attempt` if this is another WireGuard connection + /// attempt, and zero the counter when leaving the connecting loop. + fn next_retry_attempt(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize { match new_state { TunnelStateTransition::Connecting(endpoint) => { - if endpoint.tunnel_type != TunnelType::Wireguard { - return; - } - self.wg_retry_attempt = self.wg_retry_attempt.wrapping_add(1); - if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 { - let handle = self.manager.clone(); - let check_validity = self.check_validity.clone(); - tokio::spawn(async move { - if !check_validity.swap(false, Ordering::SeqCst) { - return; - } - if let Err(error) = Self::check_validity(handle).await { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to check device or account validity" - ) - ); - if error.is_network_error() || error.is_aborted() { - check_validity.store(true, Ordering::SeqCst); - } - } - }); + if endpoint.tunnel_type == TunnelType::Wireguard { + retry_attempt.wrapping_add(1) + } else { + retry_attempt } } TunnelStateTransition::Error(_) | TunnelStateTransition::Connected(_) - | TunnelStateTransition::Disconnected => { - self.check_validity.store(true, Ordering::SeqCst); - self.wg_retry_attempt = 0; + | TunnelStateTransition::Disconnected => 0, + _ => retry_attempt, + } + } + + /// Run `validate` when connecting to a WireGuard server, on certain retry attempts. + /// If `no_more_retries` is true, no further checks are made. `no_more_retries` is reset + /// on the first connection attempt. + /// + /// This returns whether the device/account validity ran. + async fn maybe_check_validity<Validate, ValidateResult>( + wg_attempt: usize, + no_more_retries: Arc<AtomicBool>, + validate: Validate, + ) -> bool + where + Validate: FnOnce() -> ValidateResult + Send + 'static, + ValidateResult: Future<Output = Result<(), Error>> + Send + 'static, + { + if wg_attempt == 0 { + // Starting a new connecting loop, so reset the retry state + no_more_retries.store(false, Ordering::SeqCst); + } + + if !Self::should_check_validity_on_attempt(wg_attempt) { + return false; + } + if no_more_retries.swap(true, Ordering::SeqCst) { + // We've either already received the device state or we've given up + return false; + } + match validate().await { + Ok(()) => true, + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check device or account validity") + ); + if Self::should_continue_retries(error) { + // If the request failed due to a network error, we should continue + // retrying. We give up otherwise, because it means we have a known result or + // the API returned some error. + no_more_retries.store(false, Ordering::SeqCst); + } + true } - _ => (), } } - pub async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> { + async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> { handle.validate_device().await?; handle.check_expiry().await.map(|_expiry| ()) } + + fn should_check_validity_on_attempt(wg_attempt: usize) -> bool { + wg_attempt % WG_DEVICE_CHECK_THRESHOLD == WG_DEVICE_CHECK_THRESHOLD - 1 + } + + fn should_continue_retries(err: Error) -> bool { + err.is_network_error() || err.is_aborted() + } +} + +#[cfg(test)] +mod test { + use super::TunnelStateChangeHandler; + use super::{Error, WG_DEVICE_CHECK_THRESHOLD}; + use mullvad_relay_selector::RelaySelector; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + use talpid_types::net::TunnelType; + + const TIMEOUT_ERROR: Error = Error::OtherRestError(mullvad_api::rest::Error::TimeoutError); + + /// Starting a new connection loop should resume device validity checks + #[tokio::test] + async fn test_device_check_reset() { + let no_more_retries = Arc::new(AtomicBool::new(true)); + + TunnelStateChangeHandler::maybe_check_validity(0, no_more_retries.clone(), || async { + Ok(()) + }) + .await; + + assert!( + !no_more_retries.load(Ordering::SeqCst), + "expected retry state to be reset on first connection attempt" + ); + } + + /// Retries should stop when a device check succeeds + #[tokio::test] + async fn test_device_check_on_success() { + const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1; + assert!(TunnelStateChangeHandler::should_check_validity_on_attempt( + ATTEMPT + )); + + let no_more_retries = Arc::new(AtomicBool::new(false)); + + let check_ran = TunnelStateChangeHandler::maybe_check_validity( + ATTEMPT, + no_more_retries.clone(), + || async { Ok(()) }, + ) + .await; + + assert!(check_ran, "expected device check to run"); + + let check_ran = TunnelStateChangeHandler::maybe_check_validity( + ATTEMPT, + no_more_retries.clone(), + || async { Ok(()) }, + ) + .await; + + assert!( + !check_ran, + "expected device check to give up after successful check" + ); + } + + /// Retries should continue when a network error occurs + #[tokio::test] + async fn test_device_check_on_network_error() { + const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1; + assert!(TunnelStateChangeHandler::should_check_validity_on_attempt( + ATTEMPT + )); + + let no_more_retries = Arc::new(AtomicBool::new(false)); + + let check_ran = TunnelStateChangeHandler::maybe_check_validity( + ATTEMPT, + no_more_retries.clone(), + || async { Err(TIMEOUT_ERROR) }, + ) + .await; + + assert!(check_ran, "expected device check to occur"); + + let check_ran = TunnelStateChangeHandler::maybe_check_validity( + ATTEMPT, + no_more_retries.clone(), + || async { Err(TIMEOUT_ERROR) }, + ) + .await; + + assert!( + check_ran, + "expected device check to continue after a network error" + ); + } + + /// Test whether the relay selector selects wireguard often enough, given no special + /// constraints, to verify that the device is valid + #[test] + fn test_validates_by_default() { + for attempt in 0.. { + let should_validate = + TunnelStateChangeHandler::should_check_validity_on_attempt(attempt); + let (_, _, tunnel_type) = + RelaySelector::preferred_tunnel_constraints(attempt.try_into().unwrap()); + assert_eq!( + tunnel_type, + TunnelType::Wireguard, + "failed on attempt {attempt}" + ); + if should_validate { + // Now that we've triggered a device check, we can give up + break; + } + } + } } |
