summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-12-05 11:01:44 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-12-12 16:45:05 +0100
commitdb9c774d36b2d8180acd01705f44c805f7dd2276 (patch)
tree4773713f56b381bdbad9256397f48070fc9a97ac
parent95a954b7016d45a415388acf6ed489d03686c700 (diff)
downloadmullvadvpn-db9c774d36b2d8180acd01705f44c805f7dd2276.tar.xz
mullvadvpn-db9c774d36b2d8180acd01705f44c805f7dd2276.zip
Add unit tests for device check
-rw-r--r--mullvad-daemon/src/device/mod.rs228
1 files changed, 196 insertions, 32 deletions
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs
index a73d77cbf6..2668e995ee 100644
--- a/mullvad-daemon/src/device/mod.rs
+++ b/mullvad-daemon/src/device/mod.rs
@@ -44,8 +44,8 @@ const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10);
/// How long to wait on logout (device removal) before letting it continue as a background task.
const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2);
-/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts
-/// to set up a WireGuard tunnel.
+/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` attempt to set up
+/// a WireGuard tunnel.
const WG_DEVICE_CHECK_THRESHOLD: usize = 2;
#[derive(err_derive::Error, Debug, Clone)]
@@ -1232,7 +1232,7 @@ impl DeviceCacher {
/// after multiple attempts.
pub(crate) struct TunnelStateChangeHandler {
manager: AccountManagerHandle,
- check_validity: Arc<AtomicBool>,
+ no_more_retries: Arc<AtomicBool>,
wg_retry_attempt: usize,
}
@@ -1240,51 +1240,215 @@ impl TunnelStateChangeHandler {
pub fn new(manager: AccountManagerHandle) -> Self {
Self {
manager,
- check_validity: Arc::new(AtomicBool::new(true)),
+ no_more_retries: Arc::new(AtomicBool::new(false)),
wg_retry_attempt: 0,
}
}
+ /// Handle state transitions and optionally check the device/account validity. This should be
+ /// called during every tunnel state transition.
pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) {
+ let handle = self.manager.clone();
+
+ let wg_attempt = self.wg_retry_attempt;
+ self.wg_retry_attempt = Self::next_retry_attempt(new_state, self.wg_retry_attempt);
+
+ if self.wg_retry_attempt > wg_attempt {
+ tokio::spawn(Self::maybe_check_validity(
+ wg_attempt,
+ self.no_more_retries.clone(),
+ move || Self::check_validity(handle),
+ ));
+ }
+ }
+
+ /// Return an incremented count for `retry_attempt` if this is another WireGuard connection
+ /// attempt, and zero the counter when leaving the connecting loop.
+ fn next_retry_attempt(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize {
match new_state {
TunnelStateTransition::Connecting(endpoint) => {
- if endpoint.tunnel_type != TunnelType::Wireguard {
- return;
- }
- self.wg_retry_attempt = self.wg_retry_attempt.wrapping_add(1);
- if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 {
- let handle = self.manager.clone();
- let check_validity = self.check_validity.clone();
- tokio::spawn(async move {
- if !check_validity.swap(false, Ordering::SeqCst) {
- return;
- }
- if let Err(error) = Self::check_validity(handle).await {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to check device or account validity"
- )
- );
- if error.is_network_error() || error.is_aborted() {
- check_validity.store(true, Ordering::SeqCst);
- }
- }
- });
+ if endpoint.tunnel_type == TunnelType::Wireguard {
+ retry_attempt.wrapping_add(1)
+ } else {
+ retry_attempt
}
}
TunnelStateTransition::Error(_)
| TunnelStateTransition::Connected(_)
- | TunnelStateTransition::Disconnected => {
- self.check_validity.store(true, Ordering::SeqCst);
- self.wg_retry_attempt = 0;
+ | TunnelStateTransition::Disconnected => 0,
+ _ => retry_attempt,
+ }
+ }
+
+ /// Run `validate` when connecting to a WireGuard server, on certain retry attempts.
+ /// If `no_more_retries` is true, no further checks are made. `no_more_retries` is reset
+ /// on the first connection attempt.
+ ///
+ /// This returns whether the device/account validity ran.
+ async fn maybe_check_validity<Validate, ValidateResult>(
+ wg_attempt: usize,
+ no_more_retries: Arc<AtomicBool>,
+ validate: Validate,
+ ) -> bool
+ where
+ Validate: FnOnce() -> ValidateResult + Send + 'static,
+ ValidateResult: Future<Output = Result<(), Error>> + Send + 'static,
+ {
+ if wg_attempt == 0 {
+ // Starting a new connecting loop, so reset the retry state
+ no_more_retries.store(false, Ordering::SeqCst);
+ }
+
+ if !Self::should_check_validity_on_attempt(wg_attempt) {
+ return false;
+ }
+ if no_more_retries.swap(true, Ordering::SeqCst) {
+ // We've either already received the device state or we've given up
+ return false;
+ }
+ match validate().await {
+ Ok(()) => true,
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check device or account validity")
+ );
+ if Self::should_continue_retries(error) {
+ // If the request failed due to a network error, we should continue
+ // retrying. We give up otherwise, because it means we have a known result or
+ // the API returned some error.
+ no_more_retries.store(false, Ordering::SeqCst);
+ }
+ true
}
- _ => (),
}
}
- pub async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> {
+ async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> {
handle.validate_device().await?;
handle.check_expiry().await.map(|_expiry| ())
}
+
+ fn should_check_validity_on_attempt(wg_attempt: usize) -> bool {
+ wg_attempt % WG_DEVICE_CHECK_THRESHOLD == WG_DEVICE_CHECK_THRESHOLD - 1
+ }
+
+ fn should_continue_retries(err: Error) -> bool {
+ err.is_network_error() || err.is_aborted()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::TunnelStateChangeHandler;
+ use super::{Error, WG_DEVICE_CHECK_THRESHOLD};
+ use mullvad_relay_selector::RelaySelector;
+ use std::sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+ };
+ use talpid_types::net::TunnelType;
+
+ const TIMEOUT_ERROR: Error = Error::OtherRestError(mullvad_api::rest::Error::TimeoutError);
+
+ /// Starting a new connection loop should resume device validity checks
+ #[tokio::test]
+ async fn test_device_check_reset() {
+ let no_more_retries = Arc::new(AtomicBool::new(true));
+
+ TunnelStateChangeHandler::maybe_check_validity(0, no_more_retries.clone(), || async {
+ Ok(())
+ })
+ .await;
+
+ assert!(
+ !no_more_retries.load(Ordering::SeqCst),
+ "expected retry state to be reset on first connection attempt"
+ );
+ }
+
+ /// Retries should stop when a device check succeeds
+ #[tokio::test]
+ async fn test_device_check_on_success() {
+ const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1;
+ assert!(TunnelStateChangeHandler::should_check_validity_on_attempt(
+ ATTEMPT
+ ));
+
+ let no_more_retries = Arc::new(AtomicBool::new(false));
+
+ let check_ran = TunnelStateChangeHandler::maybe_check_validity(
+ ATTEMPT,
+ no_more_retries.clone(),
+ || async { Ok(()) },
+ )
+ .await;
+
+ assert!(check_ran, "expected device check to run");
+
+ let check_ran = TunnelStateChangeHandler::maybe_check_validity(
+ ATTEMPT,
+ no_more_retries.clone(),
+ || async { Ok(()) },
+ )
+ .await;
+
+ assert!(
+ !check_ran,
+ "expected device check to give up after successful check"
+ );
+ }
+
+ /// Retries should continue when a network error occurs
+ #[tokio::test]
+ async fn test_device_check_on_network_error() {
+ const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1;
+ assert!(TunnelStateChangeHandler::should_check_validity_on_attempt(
+ ATTEMPT
+ ));
+
+ let no_more_retries = Arc::new(AtomicBool::new(false));
+
+ let check_ran = TunnelStateChangeHandler::maybe_check_validity(
+ ATTEMPT,
+ no_more_retries.clone(),
+ || async { Err(TIMEOUT_ERROR) },
+ )
+ .await;
+
+ assert!(check_ran, "expected device check to occur");
+
+ let check_ran = TunnelStateChangeHandler::maybe_check_validity(
+ ATTEMPT,
+ no_more_retries.clone(),
+ || async { Err(TIMEOUT_ERROR) },
+ )
+ .await;
+
+ assert!(
+ check_ran,
+ "expected device check to continue after a network error"
+ );
+ }
+
+ /// Test whether the relay selector selects wireguard often enough, given no special
+ /// constraints, to verify that the device is valid
+ #[test]
+ fn test_validates_by_default() {
+ for attempt in 0.. {
+ let should_validate =
+ TunnelStateChangeHandler::should_check_validity_on_attempt(attempt);
+ let (_, _, tunnel_type) =
+ RelaySelector::preferred_tunnel_constraints(attempt.try_into().unwrap());
+ assert_eq!(
+ tunnel_type,
+ TunnelType::Wireguard,
+ "failed on attempt {attempt}"
+ );
+ if should_validate {
+ // Now that we've triggered a device check, we can give up
+ break;
+ }
+ }
+ }
}