diff options
| -rw-r--r-- | mullvad-daemon/src/device/service.rs | 55 | ||||
| -rw-r--r-- | mullvad-daemon/src/version_check.rs | 12 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 10 | ||||
| -rw-r--r-- | talpid-core/src/future_retry.rs | 56 |
4 files changed, 64 insertions, 69 deletions
diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs index 74dc15b94f..9c967e4413 100644 --- a/mullvad-daemon/src/device/service.rs +++ b/mullvad-daemon/src/device/service.rs @@ -15,12 +15,10 @@ use mullvad_api::{ rest::{self, Error as RestError, MullvadRestHandle}, AccountsProxy, DevicesProxy, }; -use talpid_core::future_retry::{ - constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered, -}; -const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; -const RETRY_ACTION_MAX_RETRIES: usize = 2; - +use talpid_core::future_retry::{retry_future, ConstantInterval, ExponentialBackoff, Jittered}; +/// Retry strategy used for user-initiated actions that require immediate feedback +const RETRY_ACTION_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(3)); +/// Retry strategy used for background tasks const RETRY_BACKOFF_STRATEGY: Jittered<ExponentialBackoff> = Jittered::jitter( ExponentialBackoff::new(Duration::from_secs(4), 5) .max_delay(Some(Duration::from_secs(24 * 60 * 60))), @@ -52,11 +50,10 @@ impl DeviceService { let api_handle = self.api_availability.clone(); let token_copy = account_token.clone(); async move { - let (device, addresses) = retry_future_n( + let (device, addresses) = retry_future( move || proxy.create(token_copy.clone(), pubkey.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -123,11 +120,10 @@ impl DeviceService { ) -> Result<(), Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.remove(token.clone(), device.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -165,11 +161,10 @@ impl DeviceService { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); let pubkey = private_key.public_key(); - let addresses = retry_future_n( + let addresses = retry_future( move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -218,11 +213,10 @@ impl DeviceService { pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.list(token.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error) @@ -247,11 +241,10 @@ impl DeviceService { pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result<Device, Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.get(token.clone(), device.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error) @@ -269,11 +262,10 @@ impl AccountService { pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> { let mut proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.create_account(), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) } @@ -283,22 +275,20 @@ impl AccountService { ) -> impl Future<Output = Result<String, rest::Error>> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.get_www_auth_token(account.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) } pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - let result = retry_future_n( + let result = retry_future( move || proxy.get_expiry(token.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await; if handle_expiry_result_inner(&result, &self.api_availability) { @@ -318,11 +308,10 @@ impl AccountService { ) -> Result<VoucherSubmission, Error> { let mut proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - let result = retry_future_n( + let result = retry_future( move || proxy.submit_voucher(account_token.clone(), voucher.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await; if result.is_ok() { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index b488b0b324..dfe0a26b5f 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -15,7 +15,7 @@ use std::{ str::FromStr, time::Duration, }; -use talpid_core::mpsc::Sender; +use talpid_core::{future_retry::ConstantInterval, mpsc::Sender}; use talpid_types::ErrorExt; use tokio::fs::{self, File}; @@ -31,9 +31,8 @@ const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(15); const UPDATE_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24); /// Wait this long until next try if an update failed const UPDATE_INTERVAL_ERROR: Duration = Duration::from_secs(60 * 60 * 6); -/// Retry interval for `RunVersionCheck`. -const IMMEDIATE_UPDATE_INTERVAL_ERROR: Duration = Duration::ZERO; -const IMMEDIATE_UPDATE_MAX_RETRIES: usize = 2; +/// Retry strategy for `RunVersionCheck`. +const IMMEDIATE_RETRY_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(3)); #[cfg(target_os = "linux")] const PLATFORM: &str = "linux"; @@ -194,11 +193,10 @@ impl VersionUpdater { .map_err(Error::Download) }; - Box::pin(talpid_core::future_retry::retry_future_n( + Box::pin(talpid_core::future_retry::retry_future( download_future_factory, move |result| Self::should_retry_immediate(result, &api_handle), - std::iter::repeat(IMMEDIATE_UPDATE_INTERVAL_ERROR), - IMMEDIATE_UPDATE_MAX_RETRIES, + IMMEDIATE_RETRY_STRATEGY, )) } diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index a349554ca6..bcae459442 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -6,15 +6,14 @@ use once_cell::sync::Lazy; use std::{path::PathBuf, process, str::FromStr, time::Duration}; use talpid_core::{ firewall::{self, Firewall}, - future_retry::{constant_interval, retry_future_n}, + future_retry::{retry_future, ConstantInterval}, }; use talpid_types::ErrorExt; static APP_VERSION: Lazy<ParsedAppVersion> = Lazy::new(|| ParsedAppVersion::from_str(mullvad_version::VERSION).unwrap()); -const KEY_RETRY_INTERVAL: Duration = Duration::ZERO; -const KEY_RETRY_MAX_RETRIES: usize = 4; +const DEVICE_REMOVAL_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(5)); #[repr(i32)] enum ExitStatus { @@ -171,14 +170,13 @@ async fn remove_device() -> Result<(), Error> { .await, ); - let device_removal = retry_future_n( + let device_removal = retry_future( move || proxy.remove(device.account_token.clone(), device.device.id.clone()), move |result| match result { Err(error) => error.is_network_error(), _ => false, }, - constant_interval(KEY_RETRY_INTERVAL), - KEY_RETRY_MAX_RETRIES, + DEVICE_REMOVAL_STRATEGY, ) .await; diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs index f5dc7a8f72..f7b68a3f2d 100644 --- a/talpid-core/src/future_retry.rs +++ b/talpid-core/src/future_retry.rs @@ -2,23 +2,6 @@ use rand::{distributions::OpenClosed01, Rng}; use std::{future::Future, ops::Deref, time::Duration}; use talpid_time::sleep; -/// Convenience function that works like [`retry_future`] but limits the number -/// of retries to `max_retries`. -pub async fn retry_future_n< - F: FnMut() -> O + 'static, - R: FnMut(&T) -> bool + 'static, - D: Iterator<Item = Duration> + 'static, - O: Future<Output = T>, - T, ->( - factory: F, - should_retry: R, - delays: D, - max_retries: usize, -) -> T { - retry_future(factory, should_retry, delays.take(max_retries)).await -} - /// Retries a future until it should stop as determined by the retry function, or when /// the iterator returns `None`. pub async fn retry_future< @@ -44,9 +27,36 @@ pub async fn retry_future< } } -/// Returns an iterator that repeats the same interval. -pub fn constant_interval(interval: Duration) -> impl Iterator<Item = Duration> { - std::iter::repeat(interval) +/// Iterator that repeats the same interval, with an optional maximum no. of attempts. +pub struct ConstantInterval { + interval: Duration, + attempt: usize, + max_attempts: Option<usize>, +} + +impl ConstantInterval { + /// Creates a `ConstantInterval` that repeats `interval`, at most `max_attempts` times. + pub const fn new(interval: Duration, max_attempts: Option<usize>) -> ConstantInterval { + ConstantInterval { + interval, + attempt: 0, + max_attempts, + } + } +} + +impl Iterator for ConstantInterval { + type Item = Duration; + + fn next(&mut self) -> Option<Duration> { + if let Some(max_attempts) = self.max_attempts { + if self.attempt >= max_attempts { + return None; + } + } + self.attempt = self.attempt.saturating_add(1); + Some(self.interval) + } } /// Provides an exponential back-off timer to delay the next retry of a failed operation. @@ -212,12 +222,12 @@ mod test { let retry_interval_max = Duration::from_secs(24 * 60 * 60); tokio::time::pause(); - let _ = retry_future_n( + let _ = retry_future( || async { 0 }, |_| true, ExponentialBackoff::new(retry_interval_initial, retry_interval_factor) - .max_delay(Some(retry_interval_max)), - 5, + .max_delay(Some(retry_interval_max)) + .take(5), ) .await; } |
