summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-03-27 09:49:06 +0100
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-03-27 09:49:06 +0100
commit0317724fa86cd6b734c9688fa25640e37e0d429a (patch)
treeef5fe80468182cb436225e9bb914e577d93cbd2c
parent9cceb079580928f5c7ccaf365988461d3fdebd62 (diff)
parentd8b29fdb81a6d1cd4b6668ae5fef3af581c13082 (diff)
downloadmullvadvpn-0317724fa86cd6b734c9688fa25640e37e0d429a.tar.xz
mullvadvpn-0317724fa86cd6b734c9688fa25640e37e0d429a.zip
Merge branch 'decouple-device-check-from-retry-order-des-715'
-rw-r--r--mullvad-daemon/src/device/mod.rs288
1 files changed, 145 insertions, 143 deletions
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs
index 0493976a41..16a5a37460 100644
--- a/mullvad-daemon/src/device/mod.rs
+++ b/mullvad-daemon/src/device/mod.rs
@@ -25,7 +25,11 @@ use std::{
time::{Duration, SystemTime},
};
use talpid_core::mpsc::Sender;
-use talpid_types::{net::TunnelType, tunnel::TunnelStateTransition, ErrorExt};
+use talpid_types::{
+ net::{TunnelEndpoint, TunnelType},
+ tunnel::TunnelStateTransition,
+ ErrorExt,
+};
use tokio::{
fs,
io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
@@ -1239,7 +1243,7 @@ impl DeviceCacher {
/// after multiple attempts.
pub(crate) struct TunnelStateChangeHandler {
manager: AccountManagerHandle,
- no_more_retries: Arc<AtomicBool>,
+ can_retry: Arc<AtomicBool>,
wg_retry_attempt: usize,
}
@@ -1247,7 +1251,7 @@ impl TunnelStateChangeHandler {
pub fn new(manager: AccountManagerHandle) -> Self {
Self {
manager,
- no_more_retries: Arc::new(AtomicBool::new(false)),
+ can_retry: Arc::new(AtomicBool::new(true)),
wg_retry_attempt: 0,
}
}
@@ -1255,120 +1259,146 @@ impl TunnelStateChangeHandler {
/// Handle state transitions and optionally check the device/account validity. This should be
/// called during every tunnel state transition.
pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) {
- let handle = self.manager.clone();
-
- let wg_attempt = self.wg_retry_attempt;
- self.wg_retry_attempt = Self::next_retry_attempt(new_state, self.wg_retry_attempt);
-
- if self.wg_retry_attempt > wg_attempt {
- tokio::spawn(Self::maybe_check_validity(
- wg_attempt,
- self.no_more_retries.clone(),
- move || Self::check_validity(handle),
+ self.wg_retry_attempt = Self::update_retry_counter(new_state, self.wg_retry_attempt);
+ Self::update_retry_bool(new_state, self.can_retry.clone());
+ // Check if a device-check should be triggered
+ if Self::should_check_device_validity(self.wg_retry_attempt, self.can_retry.clone()) {
+ let handle = self.manager.clone();
+ tokio::spawn(Self::check_device_validity(
+ self.can_retry.clone(),
+ move || Self::check_device_validity_inner(handle),
));
}
}
+ /// Run `validate` when connecting to a WireGuard server.
+ ///
+ /// # Note
+ /// `can_retry` is reset on network errors. Otherwise, it is set to `true` as to not
+ /// immediately trigger new device checks.
+ async fn check_device_validity<Validate, ValidateResult>(
+ can_retry: Arc<AtomicBool>,
+ validate: Validate,
+ ) where
+ Validate: FnOnce() -> ValidateResult + Send,
+ ValidateResult: Future<Output = Result<(), Error>> + Send,
+ {
+ // Log any error
+ let result = validate().await.inspect_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check device or account validity")
+ )
+ });
+ // Update `can_retry` based on the result of `validate`
+ match result {
+ // If the request failed due to a network error, we should continue
+ // retrying.
+ Err(ref error) if Self::should_continue_retries(error) => {
+ can_retry.store(true, Ordering::SeqCst);
+ }
+ // Otherwise we give up, because it means we have a known result or
+ // the API returned some error.
+ _ => (),
+ }
+ }
+
/// Return an incremented count for `retry_attempt` if this is another WireGuard connection
- /// attempt, and zero the counter when leaving the connecting loop.
- fn next_retry_attempt(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize {
+ /// attempt, otherwise `retry_attempt` is returned.
+ ///
+ /// Reset to the counter to `0` when we manage to successfully connect to a Wireguard relay.
+ fn update_retry_counter(new_state: &TunnelStateTransition, retry_attempt: usize) -> usize {
+ let wireguard =
+ |endpoint: &TunnelEndpoint| matches!(endpoint.tunnel_type, TunnelType::Wireguard);
+
match new_state {
- TunnelStateTransition::Connecting(endpoint) => {
- if endpoint.tunnel_type == TunnelType::Wireguard {
- retry_attempt.wrapping_add(1)
- } else {
- retry_attempt
- }
+ // Increment the counter if this is another Wireguard attempt
+ TunnelStateTransition::Connecting(endpoint) if wireguard(endpoint) => {
+ retry_attempt.wrapping_add(1)
}
- TunnelStateTransition::Error(_)
- | TunnelStateTransition::Connected(_)
- | TunnelStateTransition::Disconnected { .. } => 0,
+ // Only reset the counter if we managed to connect to a Wireguard relay
+ TunnelStateTransition::Connected(endpoint) if wireguard(endpoint) => 0,
+ // Any other state transition doesn't affect the counter
_ => retry_attempt,
}
}
- /// Run `validate` when connecting to a WireGuard server, on certain retry attempts.
- /// If `no_more_retries` is true, no further checks are made. `no_more_retries` is reset
- /// on the first connection attempt.
+ /// Check if `new_state` breaks a connecting-loop. If so, the retry state `can_retry` is reset
+ /// (i.e. set to `true`).
///
- /// This returns whether the device/account validity ran.
- async fn maybe_check_validity<Validate, ValidateResult>(
- wg_attempt: usize,
- no_more_retries: Arc<AtomicBool>,
- validate: Validate,
- ) -> bool
- where
- Validate: FnOnce() -> ValidateResult + Send + 'static,
- ValidateResult: Future<Output = Result<(), Error>> + Send + 'static,
- {
- if wg_attempt == 0 {
- // Starting a new connecting loop, so reset the retry state
- no_more_retries.store(false, Ordering::SeqCst);
- }
-
- if !Self::should_check_validity_on_attempt(wg_attempt) {
- return false;
- }
- if no_more_retries.swap(true, Ordering::SeqCst) {
- // We've either already received the device state or we've given up
- return false;
- }
- match validate().await {
- Ok(()) => true,
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to check device or account validity")
- );
- if Self::should_continue_retries(error) {
- // If the request failed due to a network error, we should continue
- // retrying. We give up otherwise, because it means we have a known result or
- // the API returned some error.
- no_more_retries.store(false, Ordering::SeqCst);
- }
- true
+ /// # Note
+ /// The following state transition counts as breaking a connecting-loop: `Connected`,
+ /// `Disconnected` and `Error`.
+ fn update_retry_bool(new_state: &TunnelStateTransition, can_retry: Arc<AtomicBool>) {
+ match new_state {
+ TunnelStateTransition::Disconnected { .. }
+ | TunnelStateTransition::Connected(_)
+ | TunnelStateTransition::Error(_) => {
+ can_retry.store(true, Ordering::SeqCst);
}
- }
+ _ => {}
+ };
}
- async fn check_validity(handle: AccountManagerHandle) -> Result<(), Error> {
+ async fn check_device_validity_inner(handle: AccountManagerHandle) -> Result<(), Error> {
handle.validate_device().await?;
handle.check_expiry().await.map(|_expiry| ())
}
- fn should_check_validity_on_attempt(wg_attempt: usize) -> bool {
- wg_attempt % WG_DEVICE_CHECK_THRESHOLD == WG_DEVICE_CHECK_THRESHOLD - 1
+ /// Check if a device check is due
+ fn should_check_device_validity(
+ wireguard_retry_attempt: usize,
+ can_retry: Arc<AtomicBool>,
+ ) -> bool {
+ Self::should_check_device_validity_on_attempt(wireguard_retry_attempt)
+ && can_retry.swap(false, Ordering::SeqCst)
+ }
+
+ /// Check if a device check should be triggered based on the current `wireguard_retry_attempt`
+ const fn should_check_device_validity_on_attempt(wireguard_retry_attempt: usize) -> bool {
+ // Incorporate a debounce effect where every `WG_DEVICE_CHECK_THRESHOLD` attempt should be
+ // able to trigger a device check.
+ wireguard_retry_attempt > 0 && (wireguard_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0)
}
- fn should_continue_retries(err: Error) -> bool {
+ fn should_continue_retries(err: &Error) -> bool {
err.is_network_error() || err.is_aborted()
}
}
#[cfg(test)]
mod test {
- use super::{Error, TunnelStateChangeHandler, WG_DEVICE_CHECK_THRESHOLD};
- use mullvad_relay_selector::RelaySelector;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
- use talpid_types::net::TunnelType;
+ use talpid_types::tunnel::TunnelStateTransition;
+
+ use super::{Error, TunnelStateChangeHandler, WG_DEVICE_CHECK_THRESHOLD};
const TIMEOUT_ERROR: Error = Error::OtherRestError(mullvad_api::rest::Error::TimeoutError);
- /// Starting a new connection loop should resume device validity checks
- #[tokio::test]
- async fn test_device_check_reset() {
- let no_more_retries = Arc::new(AtomicBool::new(true));
+ /// Verify that a device check is triggered 'when expected', i.e. when the current attempt
+ /// has reached the threshold as specified by [`WG_DEVICE_CHECK_THRESHOLD`]
+ #[test]
+ fn test_device_check_by_retry_attempt() {
+ assert!(
+ TunnelStateChangeHandler::should_check_device_validity_on_attempt(
+ WG_DEVICE_CHECK_THRESHOLD
+ )
+ );
+ }
- TunnelStateChangeHandler::maybe_check_validity(0, no_more_retries.clone(), || async {
- Ok(())
- })
- .await;
+ /// Starting a new connection loop should resume device validity checks
+ #[test]
+ fn test_device_check_reset() {
+ let can_retry = Arc::new(AtomicBool::new(false));
+ // Transitioning to the 'Disconnected' state counts as breaking the 'connection loop'
+ let new_tunnel_state = TunnelStateTransition::Disconnected { locked_down: false };
+ TunnelStateChangeHandler::update_retry_bool(&new_tunnel_state, can_retry.clone());
assert!(
- !no_more_retries.load(Ordering::SeqCst),
+ can_retry.load(Ordering::SeqCst),
"expected retry state to be reset on first connection attempt"
);
}
@@ -1376,31 +1406,24 @@ mod test {
/// Retries should stop when a device check succeeds
#[tokio::test]
async fn test_device_check_on_success() {
- const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1;
- assert!(TunnelStateChangeHandler::should_check_validity_on_attempt(
- ATTEMPT
- ));
-
- let no_more_retries = Arc::new(AtomicBool::new(false));
+ let can_retry = Arc::new(AtomicBool::new(true));
- let check_ran = TunnelStateChangeHandler::maybe_check_validity(
- ATTEMPT,
- no_more_retries.clone(),
- || async { Ok(()) },
- )
- .await;
-
- assert!(check_ran, "expected device check to run");
-
- let check_ran = TunnelStateChangeHandler::maybe_check_validity(
- ATTEMPT,
- no_more_retries.clone(),
- || async { Ok(()) },
- )
- .await;
+ let did_run = TunnelStateChangeHandler::should_check_device_validity(
+ WG_DEVICE_CHECK_THRESHOLD,
+ can_retry.clone(),
+ );
+ assert!(did_run, "expected device check to run");
+ // Manually trigger the device check and verify that we still can try to perform a device
+ // check
+ TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async { Ok(()) })
+ .await;
+ let did_run = TunnelStateChangeHandler::should_check_device_validity(
+ WG_DEVICE_CHECK_THRESHOLD,
+ can_retry.clone(),
+ );
assert!(
- !check_ran,
+ !did_run,
"expected device check to give up after successful check"
);
}
@@ -1408,53 +1431,32 @@ mod test {
/// Retries should continue when a network error occurs
#[tokio::test]
async fn test_device_check_on_network_error() {
- const ATTEMPT: usize = WG_DEVICE_CHECK_THRESHOLD - 1;
- assert!(TunnelStateChangeHandler::should_check_validity_on_attempt(
- ATTEMPT
- ));
-
- let no_more_retries = Arc::new(AtomicBool::new(false));
-
- let check_ran = TunnelStateChangeHandler::maybe_check_validity(
- ATTEMPT,
- no_more_retries.clone(),
- || async { Err(TIMEOUT_ERROR) },
- )
- .await;
+ let can_retry = Arc::new(AtomicBool::new(true));
- assert!(check_ran, "expected device check to occur");
-
- let check_ran = TunnelStateChangeHandler::maybe_check_validity(
- ATTEMPT,
- no_more_retries.clone(),
- || async { Err(TIMEOUT_ERROR) },
- )
+ // Run the check with a (simulated) network error - verify that `can_retry` is still true
+ // afterwards, indicating that a device check may still be performed
+ TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async {
+ Err(TIMEOUT_ERROR)
+ })
.await;
assert!(
- check_ran,
+ can_retry.load(Ordering::SeqCst),
"expected device check to continue after a network error"
);
- }
- /// Test whether the relay selector selects wireguard often enough, given no special
- /// constraints, to verify that the device is valid
- #[test]
- fn test_validates_by_default() {
- for attempt in 0.. {
- let should_validate =
- TunnelStateChangeHandler::should_check_validity_on_attempt(attempt);
- let (_, _, tunnel_type) =
- RelaySelector::preferred_tunnel_constraints(attempt.try_into().unwrap());
- assert_eq!(
- tunnel_type,
- TunnelType::Wireguard,
- "failed on attempt {attempt}"
- );
- if should_validate {
- // Now that we've triggered a device check, we can give up
- break;
- }
- }
+ // Re-run the check without a network error - verify that `can_retry` is no longer true
+ TunnelStateChangeHandler::should_check_device_validity(
+ WG_DEVICE_CHECK_THRESHOLD,
+ can_retry.clone(),
+ );
+
+ TunnelStateChangeHandler::check_device_validity(can_retry.clone(), || async { Ok(()) })
+ .await;
+
+ assert!(
+ !can_retry.load(Ordering::SeqCst),
+ "device check should no longer happen after successful check"
+ );
}
}