summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-24 10:04:59 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-25 13:08:28 +0200
commitaeab100ebeb46564696dd228acec4773aeaf6684 (patch)
tree30c585053a2cb067aaaafcd756614bf224fbc4e4
parentbea8e150eea2a3f17511a2a094f5ed6e0844c001 (diff)
downloadmullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.tar.xz
mullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.zip
Consolidate two mutexes into one
-rw-r--r--mullvad-api/src/availability.rs274
-rw-r--r--mullvad-api/src/lib.rs8
-rw-r--r--mullvad-api/src/rest.rs16
-rw-r--r--mullvad-daemon/src/api.rs6
-rw-r--r--mullvad-daemon/src/device/service.rs16
-rw-r--r--mullvad-daemon/src/relay_list/mod.rs6
-rw-r--r--mullvad-daemon/src/version_check.rs8
7 files changed, 167 insertions, 167 deletions
diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs
index ba33836b16..339aca8bca 100644
--- a/mullvad-api/src/availability.rs
+++ b/mullvad-api/src/availability.rs
@@ -1,12 +1,10 @@
use std::{
future::Future,
- sync::{Arc, Mutex},
+ sync::{Arc, Mutex, MutexGuard},
time::Duration,
};
use tokio::sync::broadcast;
-const CHANNEL_CAPACITY: usize = 100;
-
/// Pause background requests if [ApiAvailabilityHandle::reset_inactivity_timer] hasn't been
/// called for this long.
const INACTIVITY_TIME: Duration = Duration::from_secs(3 * 24 * 60 * 60);
@@ -26,182 +24,108 @@ pub struct State {
inactive: bool,
}
+#[derive(Clone, Debug)]
+pub struct ApiAvailability(Arc<Mutex<ApiAvailabilityState>>);
+
+#[derive(Debug)]
+struct ApiAvailabilityState {
+ tx: broadcast::Sender<State>,
+ state: State,
+ inactivity_timer: Option<tokio::task::JoinHandle<()>>,
+}
+
impl State {
- pub fn is_suspended(&self) -> bool {
+ pub const fn is_suspended(&self) -> bool {
self.suspended
}
- pub fn is_background_paused(&self) -> bool {
+ pub const fn is_background_paused(&self) -> bool {
self.offline || self.pause_background || self.suspended || self.inactive
}
- pub fn is_offline(&self) -> bool {
+ pub const fn is_offline(&self) -> bool {
self.offline
}
}
-pub struct ApiAvailability {
- state: Arc<Mutex<State>>,
- tx: broadcast::Sender<State>,
-
- inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
impl ApiAvailability {
- pub fn new(initial_state: State) -> Self {
- let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
- let state = Arc::new(Mutex::new(initial_state));
+ const CHANNEL_CAPACITY: usize = 100;
- let availability = ApiAvailability {
- state,
+ pub fn new(initial_state: State) -> Self {
+ let (tx, _rx) = broadcast::channel(ApiAvailability::CHANNEL_CAPACITY);
+ let inner = ApiAvailabilityState {
+ state: initial_state,
+ inactivity_timer: None,
tx,
- inactivity_timer: Arc::new(Mutex::new(None)),
};
- availability.handle().reset_inactivity_timer();
- availability
+ let handle = ApiAvailability(Arc::new(Mutex::new(inner)));
+ // Start an inactivity timer
+ handle.reset_inactivity_timer();
+ handle
}
- pub fn get_state(&self) -> State {
- *self.state.lock().unwrap()
+ fn acquire(&self) -> MutexGuard<'_, ApiAvailabilityState> {
+ self.0.lock().unwrap()
}
- pub fn handle(&self) -> ApiAvailabilityHandle {
- ApiAvailabilityHandle {
- state: self.state.clone(),
- tx: self.tx.clone(),
- inactivity_timer: self.inactivity_timer.clone(),
- }
- }
-}
-
-impl Drop for ApiAvailability {
- fn drop(&mut self) {
- if let Some(timer) = self.inactivity_timer.lock().unwrap().take() {
- timer.abort();
- }
- }
-}
-
-#[derive(Clone, Debug)]
-pub struct ApiAvailabilityHandle {
- state: Arc<Mutex<State>>,
- tx: broadcast::Sender<State>,
- inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
-impl ApiAvailabilityHandle {
/// Reset task that automatically pauses API requests due inactivity,
/// starting it if it's not currently running.
pub fn reset_inactivity_timer(&self) {
- log::trace!("Restarting API inactivity check");
-
- let self_ = self.clone();
-
- let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
- if let Some(timer) = inactivity_timer.take() {
- timer.abort();
- }
-
- self.set_active();
-
- *inactivity_timer = Some(tokio::spawn(async move {
+ let mut inner = self.0.lock().unwrap();
+ log::debug!("Restarting API inactivity check");
+ inner.stop_inactivity_timer();
+ let availability_handle = self.clone();
+ inner.inactivity_timer = Some(tokio::spawn(async move {
talpid_time::sleep(INACTIVITY_TIME).await;
- self_.set_inactive();
+ availability_handle.set_inactive();
}));
+ inner.set_active();
}
/// Stops timer that pauses API requests due to inactivity.
pub fn stop_inactivity_timer(&self) {
- log::trace!("Stopping API inactivity check");
-
- let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
- if let Some(timer) = inactivity_timer.take() {
- timer.abort();
- }
- self.set_active();
- }
-
- fn inactivity_timer_running(&self) -> bool {
- self.inactivity_timer.lock().unwrap().is_some()
- }
-
- pub fn suspend(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.suspended {
- log::debug!("Suspending API requests");
-
- state.suspended = true;
- let _ = self.tx.send(*state);
- }
- }
-
- pub fn unsuspend(&self) {
- let mut state = self.state.lock().unwrap();
- if state.suspended {
- log::debug!("Unsuspending API requests");
-
- state.suspended = false;
- let _ = self.tx.send(*state);
- }
+ self.acquire().stop_inactivity_timer();
}
pub fn pause_background(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.pause_background {
- log::debug!("Pausing background API requests");
-
- state.pause_background = true;
- let _ = self.tx.send(*state);
- }
+ self.acquire().pause_background();
}
pub fn resume_background(&self) {
- if self.inactivity_timer_running() {
+ let should_reset = {
+ let mut inner = self.acquire();
+ inner.pause_background();
+ inner.inactivity_timer_running()
+ };
+ // Note: It is important that we do not hold on to the Mutex when calling `reset_inactivity_timer()`.
+ if should_reset {
self.reset_inactivity_timer();
}
-
- let mut state = self.state.lock().unwrap();
- if state.pause_background {
- log::debug!("Resuming background API requests");
- state.pause_background = false;
- let _ = self.tx.send(*state);
- }
}
- fn set_inactive(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.inactive {
- log::debug!("Pausing background API requests due to inactivity");
- state.inactive = true;
- let _ = self.tx.send(*state);
- }
+ pub fn suspend(&self) {
+ self.acquire().suspend()
}
- fn set_active(&self) {
- let mut state = self.state.lock().unwrap();
- if state.inactive {
- log::debug!("Resuming background API requests due to activity");
- state.inactive = false;
- let _ = self.tx.send(*state);
- }
+ pub fn unsuspend(&self) {
+ self.acquire().unsuspend();
}
pub fn set_offline(&self, offline: bool) {
- let mut state = self.state.lock().unwrap();
- if state.offline != offline {
- if offline {
- log::debug!("Pausing API requests due to being offline");
- } else {
- log::debug!("Resuming API requests due to coming online");
- }
+ self.acquire().set_offline(offline);
+ }
- state.offline = offline;
- let _ = self.tx.send(*state);
- }
+ fn set_inactive(&self) {
+ self.acquire().set_inactive();
+ }
+
+ /// Check if the host is offline
+ pub fn is_offline(&self) -> bool {
+ self.get_state().is_offline()
}
- pub fn get_state(&self) -> State {
- *self.state.lock().unwrap()
+ fn get_state(&self) -> State {
+ self.acquire().state
}
pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> {
@@ -236,12 +160,12 @@ impl ApiAvailabilityHandle {
&self,
state_ready: impl Fn(State) -> bool,
) -> impl Future<Output = Result<(), Error>> {
- let mut rx = self.tx.subscribe();
- let state = self.state.clone();
+ let mut rx = { self.acquire().tx.subscribe() };
+ let handle = self.clone();
async move {
- let current_state = { *state.lock().unwrap() };
- if state_ready(current_state) {
+ let state = handle.get_state();
+ if state_ready(state) {
return Ok(());
}
@@ -254,3 +178,79 @@ impl ApiAvailabilityHandle {
}
}
}
+
+impl ApiAvailabilityState {
+ fn suspend(&mut self) {
+ if !self.state.suspended {
+ log::trace!("Suspending API requests");
+ self.state.suspended = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn unsuspend(&mut self) {
+ if self.state.suspended {
+ log::trace!("Unsuspending API requests");
+ self.state.suspended = false;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn set_inactive(&mut self) {
+ log::trace!("Settings state to inactive");
+ if !self.state.inactive {
+ log::debug!("Pausing background API requests due to inactivity");
+ self.state.inactive = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn set_active(&mut self) {
+ log::trace!("Settings state to active");
+ if self.state.inactive {
+ log::debug!("Resuming background API requests due to activity");
+ self.state.inactive = false;
+ let _ = self.tx.send(self.state).inspect_err(|send_err| {
+ log::debug!("All receivers of state updates have been dropped");
+ log::debug!("{send_err}");
+ });
+ }
+ }
+
+ fn set_offline(&mut self, offline: bool) {
+ if offline {
+ log::debug!("Pausing API requests due to being offline");
+ } else {
+ log::debug!("Resuming API requests due to coming online");
+ }
+ if self.state.offline != offline {
+ self.state.offline = offline;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn pause_background(&mut self) {
+ if !self.state.pause_background {
+ log::debug!("Pausing background API requests");
+ self.state.pause_background = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn stop_inactivity_timer(&mut self) {
+ log::debug!("Stopping API inactivity check");
+ if let Some(timer) = self.inactivity_timer.take() {
+ timer.abort();
+ }
+ }
+
+ const fn inactivity_timer_running(&self) -> bool {
+ self.inactivity_timer.is_some()
+ }
+}
+
+impl Drop for ApiAvailabilityState {
+ fn drop(&mut self) {
+ self.stop_inactivity_timer();
+ }
+}
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 87b6e3d656..8add11d30a 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -21,7 +21,7 @@ use std::{
use talpid_types::ErrorExt;
pub mod availability;
-use availability::{ApiAvailability, ApiAvailabilityHandle};
+use availability::ApiAvailability;
pub mod rest;
mod abortable_stream;
@@ -414,7 +414,7 @@ impl Runtime {
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
- self.api_availability.handle(),
+ self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
#[cfg(target_os = "android")]
@@ -467,8 +467,8 @@ impl Runtime {
&mut self.handle
}
- pub fn availability_handle(&self) -> ApiAvailabilityHandle {
- self.api_availability.handle()
+ pub fn availability_handle(&self) -> ApiAvailability {
+ self.api_availability.clone()
}
pub fn address_cache(&self) -> &AddressCache {
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 238d73206a..bbcef79903 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -3,7 +3,7 @@ pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
access::AccessTokenStore,
address_cache::AddressCache,
- availability::ApiAvailabilityHandle,
+ availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
};
@@ -122,14 +122,14 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> {
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
connection_mode_provider: T,
connection_mode_generation: usize,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
}
impl<T: ConnectionModeProvider + 'static> RequestService<T> {
/// Constructs a new request service.
pub fn spawn(
sni_hostname: Option<String>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
@@ -218,7 +218,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
// Switch API endpoint if the request failed due to a network error
if let Err(err) = &response {
- if err.is_network_error() && !api_availability.get_state().is_offline() {
+ if err.is_network_error() && !api_availability.is_offline() {
log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
if let Some(tx) = tx {
let _ = tx.unbounded_send(RequestCommand::NextApiConfig(
@@ -339,7 +339,7 @@ impl Request {
async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
self,
hyper_client: hyper::Client<C>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
) -> Result<Response> {
let timeout = self.timeout;
let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
@@ -351,7 +351,7 @@ impl Request {
async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>(
mut self,
hyper_client: hyper::Client<C>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
) -> Result<Response> {
let _ = api_availability.wait_for_unsuspend().await;
@@ -605,14 +605,14 @@ async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
- pub availability: ApiAvailabilityHandle,
+ pub availability: ApiAvailability,
}
impl MullvadRestHandle {
pub(crate) fn new(
service: RequestServiceHandle,
factory: RequestFactory,
- availability: ApiAvailabilityHandle,
+ availability: ApiAvailability,
) -> Self {
Self {
service,
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index acfdbb7664..ac54382a57 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -11,7 +11,7 @@ use futures::{
StreamExt,
};
use mullvad_api::{
- availability::ApiAvailabilityHandle,
+ availability::ApiAvailability,
proxy::{ApiConnectionMode, ConnectionModeProvider, ProxyConfig},
AddressCache,
};
@@ -578,9 +578,9 @@ pub fn allowed_clients(connection_mode: &ApiConnectionMode) -> AllowedClients {
}
}
-/// Forwards the received values from `offline_state_rx` to the [`ApiAvailabilityHandle`].
+/// Forwards the received values from `offline_state_rx` to the [`ApiAvailability`].
pub(crate) fn forward_offline_state(
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
mut offline_state_rx: mpsc::UnboundedReceiver<Connectivity>,
) {
tokio::spawn(async move {
diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs
index c4c949ceba..093dd14f67 100644
--- a/mullvad-daemon/src/device/service.rs
+++ b/mullvad-daemon/src/device/service.rs
@@ -13,7 +13,7 @@ use talpid_types::net::wireguard::PrivateKey;
use super::{Error, PrivateAccountAndDevice, PrivateDevice};
use mullvad_api::{
- availability::ApiAvailabilityHandle,
+ availability::ApiAvailability,
rest::{self, MullvadRestHandle},
AccountsProxy, DevicesProxy,
};
@@ -28,12 +28,12 @@ const RETRY_BACKOFF_STRATEGY: Jittered<ExponentialBackoff> = Jittered::jitter(
#[derive(Clone)]
pub struct DeviceService {
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
proxy: DevicesProxy,
}
impl DeviceService {
- pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
+ pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailability) -> Self {
Self {
proxy: DevicesProxy::new(handle),
api_availability,
@@ -255,7 +255,7 @@ impl DeviceService {
#[derive(Clone)]
pub struct AccountService {
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
initial_check_abort_handle: AbortHandle,
proxy: AccountsProxy,
}
@@ -368,7 +368,7 @@ impl AccountService {
pub fn spawn_account_service(
api_handle: MullvadRestHandle,
token: Option<String>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
) -> AccountService {
let accounts_proxy = AccountsProxy::new(api_handle);
api_availability.pause_background();
@@ -403,7 +403,7 @@ pub fn spawn_account_service(
fn handle_account_data_result(
result: &Result<AccountData, rest::Error>,
- api_availability: &ApiAvailabilityHandle,
+ api_availability: &ApiAvailability,
) -> bool {
match result {
Ok(_data) if _data.expiry >= chrono::Utc::now() => {
@@ -425,9 +425,9 @@ fn handle_account_data_result(
}
}
-fn should_retry<T>(result: &Result<T, rest::Error>, api_handle: &ApiAvailabilityHandle) -> bool {
+fn should_retry<T>(result: &Result<T, rest::Error>, api_handle: &ApiAvailability) -> bool {
match result {
- Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
+ Err(error) if error.is_network_error() => !api_handle.is_offline(),
_ => false,
}
}
diff --git a/mullvad-daemon/src/relay_list/mod.rs b/mullvad-daemon/src/relay_list/mod.rs
index 2b4be3db54..99fa60df57 100644
--- a/mullvad-daemon/src/relay_list/mod.rs
+++ b/mullvad-daemon/src/relay_list/mod.rs
@@ -11,7 +11,7 @@ use std::{
};
use tokio::fs::File;
-use mullvad_api::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, RelayListProxy};
+use mullvad_api::{availability::ApiAvailability, rest::MullvadRestHandle, RelayListProxy};
use mullvad_relay_selector::RelaySelector;
use mullvad_types::relay_list::RelayList;
use talpid_future::retry::{retry_future, ExponentialBackoff, Jittered};
@@ -68,7 +68,7 @@ pub struct RelayListUpdater {
relay_selector: RelaySelector,
on_update: Box<dyn Fn(&RelayList) + Send + 'static>,
last_check: SystemTime,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
}
impl RelayListUpdater {
@@ -163,7 +163,7 @@ impl RelayListUpdater {
}
fn download_relay_list(
- api_handle: ApiAvailabilityHandle,
+ api_handle: ApiAvailability,
proxy: RelayListProxy,
tag: Option<String>,
) -> impl Future<Output = Result<Option<RelayList>, mullvad_api::Error>> + 'static {
diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs
index 4aae30f574..ed50eb6ff5 100644
--- a/mullvad-daemon/src/version_check.rs
+++ b/mullvad-daemon/src/version_check.rs
@@ -4,7 +4,7 @@ use futures::{
future::{BoxFuture, FusedFuture},
FutureExt, SinkExt, StreamExt, TryFutureExt,
};
-use mullvad_api::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, AppVersionProxy};
+use mullvad_api::{availability::ApiAvailability, rest::MullvadRestHandle, AppVersionProxy};
use mullvad_types::version::{AppVersionInfo, ParsedAppVersion};
use serde::{Deserialize, Serialize};
use std::{
@@ -149,7 +149,7 @@ impl VersionUpdaterHandle {
impl VersionUpdater {
pub async fn spawn(
mut api_handle: MullvadRestHandle,
- availability_handle: ApiAvailabilityHandle,
+ availability_handle: ApiAvailability,
cache_dir: PathBuf,
update_sender: DaemonEventSender<AppVersionInfo>,
show_beta_releases: bool,
@@ -413,7 +413,7 @@ impl UpdateContext {
#[derive(Clone)]
struct ApiContext {
- api_handle: ApiAvailabilityHandle,
+ api_handle: ApiAvailability,
version_proxy: AppVersionProxy,
platform_version: String,
}
@@ -435,7 +435,7 @@ fn do_version_check(
// retry immediately on network errors (unless we're offline)
let should_retry_immediate = move |result: &Result<_, Error>| {
if let Err(Error::Download(error)) = result {
- error.is_network_error() && !api.api_handle.get_state().is_offline()
+ error.is_network_error() && !api.api_handle.is_offline()
} else {
false
}