summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api/src
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-24 10:04:59 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-25 13:08:28 +0200
commitaeab100ebeb46564696dd228acec4773aeaf6684 (patch)
tree30c585053a2cb067aaaafcd756614bf224fbc4e4 /mullvad-api/src
parentbea8e150eea2a3f17511a2a094f5ed6e0844c001 (diff)
downloadmullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.tar.xz
mullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.zip
Consolidate two mutexes into one
Diffstat (limited to 'mullvad-api/src')
-rw-r--r--mullvad-api/src/availability.rs274
-rw-r--r--mullvad-api/src/lib.rs8
-rw-r--r--mullvad-api/src/rest.rs16
3 files changed, 149 insertions, 149 deletions
diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs
index ba33836b16..339aca8bca 100644
--- a/mullvad-api/src/availability.rs
+++ b/mullvad-api/src/availability.rs
@@ -1,12 +1,10 @@
use std::{
future::Future,
- sync::{Arc, Mutex},
+ sync::{Arc, Mutex, MutexGuard},
time::Duration,
};
use tokio::sync::broadcast;
-const CHANNEL_CAPACITY: usize = 100;
-
/// Pause background requests if [ApiAvailabilityHandle::reset_inactivity_timer] hasn't been
/// called for this long.
const INACTIVITY_TIME: Duration = Duration::from_secs(3 * 24 * 60 * 60);
@@ -26,182 +24,108 @@ pub struct State {
inactive: bool,
}
+#[derive(Clone, Debug)]
+pub struct ApiAvailability(Arc<Mutex<ApiAvailabilityState>>);
+
+#[derive(Debug)]
+struct ApiAvailabilityState {
+ tx: broadcast::Sender<State>,
+ state: State,
+ inactivity_timer: Option<tokio::task::JoinHandle<()>>,
+}
+
impl State {
- pub fn is_suspended(&self) -> bool {
+ pub const fn is_suspended(&self) -> bool {
self.suspended
}
- pub fn is_background_paused(&self) -> bool {
+ pub const fn is_background_paused(&self) -> bool {
self.offline || self.pause_background || self.suspended || self.inactive
}
- pub fn is_offline(&self) -> bool {
+ pub const fn is_offline(&self) -> bool {
self.offline
}
}
-pub struct ApiAvailability {
- state: Arc<Mutex<State>>,
- tx: broadcast::Sender<State>,
-
- inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
impl ApiAvailability {
- pub fn new(initial_state: State) -> Self {
- let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
- let state = Arc::new(Mutex::new(initial_state));
+ const CHANNEL_CAPACITY: usize = 100;
- let availability = ApiAvailability {
- state,
+ pub fn new(initial_state: State) -> Self {
+ let (tx, _rx) = broadcast::channel(ApiAvailability::CHANNEL_CAPACITY);
+ let inner = ApiAvailabilityState {
+ state: initial_state,
+ inactivity_timer: None,
tx,
- inactivity_timer: Arc::new(Mutex::new(None)),
};
- availability.handle().reset_inactivity_timer();
- availability
+ let handle = ApiAvailability(Arc::new(Mutex::new(inner)));
+ // Start an inactivity timer
+ handle.reset_inactivity_timer();
+ handle
}
- pub fn get_state(&self) -> State {
- *self.state.lock().unwrap()
+ fn acquire(&self) -> MutexGuard<'_, ApiAvailabilityState> {
+ self.0.lock().unwrap()
}
- pub fn handle(&self) -> ApiAvailabilityHandle {
- ApiAvailabilityHandle {
- state: self.state.clone(),
- tx: self.tx.clone(),
- inactivity_timer: self.inactivity_timer.clone(),
- }
- }
-}
-
-impl Drop for ApiAvailability {
- fn drop(&mut self) {
- if let Some(timer) = self.inactivity_timer.lock().unwrap().take() {
- timer.abort();
- }
- }
-}
-
-#[derive(Clone, Debug)]
-pub struct ApiAvailabilityHandle {
- state: Arc<Mutex<State>>,
- tx: broadcast::Sender<State>,
- inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
-impl ApiAvailabilityHandle {
/// Reset task that automatically pauses API requests due inactivity,
/// starting it if it's not currently running.
pub fn reset_inactivity_timer(&self) {
- log::trace!("Restarting API inactivity check");
-
- let self_ = self.clone();
-
- let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
- if let Some(timer) = inactivity_timer.take() {
- timer.abort();
- }
-
- self.set_active();
-
- *inactivity_timer = Some(tokio::spawn(async move {
+ let mut inner = self.0.lock().unwrap();
+ log::debug!("Restarting API inactivity check");
+ inner.stop_inactivity_timer();
+ let availability_handle = self.clone();
+ inner.inactivity_timer = Some(tokio::spawn(async move {
talpid_time::sleep(INACTIVITY_TIME).await;
- self_.set_inactive();
+ availability_handle.set_inactive();
}));
+ inner.set_active();
}
/// Stops timer that pauses API requests due to inactivity.
pub fn stop_inactivity_timer(&self) {
- log::trace!("Stopping API inactivity check");
-
- let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
- if let Some(timer) = inactivity_timer.take() {
- timer.abort();
- }
- self.set_active();
- }
-
- fn inactivity_timer_running(&self) -> bool {
- self.inactivity_timer.lock().unwrap().is_some()
- }
-
- pub fn suspend(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.suspended {
- log::debug!("Suspending API requests");
-
- state.suspended = true;
- let _ = self.tx.send(*state);
- }
- }
-
- pub fn unsuspend(&self) {
- let mut state = self.state.lock().unwrap();
- if state.suspended {
- log::debug!("Unsuspending API requests");
-
- state.suspended = false;
- let _ = self.tx.send(*state);
- }
+ self.acquire().stop_inactivity_timer();
}
pub fn pause_background(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.pause_background {
- log::debug!("Pausing background API requests");
-
- state.pause_background = true;
- let _ = self.tx.send(*state);
- }
+ self.acquire().pause_background();
}
pub fn resume_background(&self) {
- if self.inactivity_timer_running() {
+ let should_reset = {
+ let mut inner = self.acquire();
+ inner.pause_background();
+ inner.inactivity_timer_running()
+ };
+ // Note: It is important that we do not hold on to the Mutex when calling `reset_inactivity_timer()`.
+ if should_reset {
self.reset_inactivity_timer();
}
-
- let mut state = self.state.lock().unwrap();
- if state.pause_background {
- log::debug!("Resuming background API requests");
- state.pause_background = false;
- let _ = self.tx.send(*state);
- }
}
- fn set_inactive(&self) {
- let mut state = self.state.lock().unwrap();
- if !state.inactive {
- log::debug!("Pausing background API requests due to inactivity");
- state.inactive = true;
- let _ = self.tx.send(*state);
- }
+ pub fn suspend(&self) {
+ self.acquire().suspend()
}
- fn set_active(&self) {
- let mut state = self.state.lock().unwrap();
- if state.inactive {
- log::debug!("Resuming background API requests due to activity");
- state.inactive = false;
- let _ = self.tx.send(*state);
- }
+ pub fn unsuspend(&self) {
+ self.acquire().unsuspend();
}
pub fn set_offline(&self, offline: bool) {
- let mut state = self.state.lock().unwrap();
- if state.offline != offline {
- if offline {
- log::debug!("Pausing API requests due to being offline");
- } else {
- log::debug!("Resuming API requests due to coming online");
- }
+ self.acquire().set_offline(offline);
+ }
- state.offline = offline;
- let _ = self.tx.send(*state);
- }
+ fn set_inactive(&self) {
+ self.acquire().set_inactive();
+ }
+
+ /// Check if the host is offline
+ pub fn is_offline(&self) -> bool {
+ self.get_state().is_offline()
}
- pub fn get_state(&self) -> State {
- *self.state.lock().unwrap()
+ fn get_state(&self) -> State {
+ self.acquire().state
}
pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> {
@@ -236,12 +160,12 @@ impl ApiAvailabilityHandle {
&self,
state_ready: impl Fn(State) -> bool,
) -> impl Future<Output = Result<(), Error>> {
- let mut rx = self.tx.subscribe();
- let state = self.state.clone();
+ let mut rx = { self.acquire().tx.subscribe() };
+ let handle = self.clone();
async move {
- let current_state = { *state.lock().unwrap() };
- if state_ready(current_state) {
+ let state = handle.get_state();
+ if state_ready(state) {
return Ok(());
}
@@ -254,3 +178,79 @@ impl ApiAvailabilityHandle {
}
}
}
+
+impl ApiAvailabilityState {
+ fn suspend(&mut self) {
+ if !self.state.suspended {
+ log::trace!("Suspending API requests");
+ self.state.suspended = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn unsuspend(&mut self) {
+ if self.state.suspended {
+ log::trace!("Unsuspending API requests");
+ self.state.suspended = false;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn set_inactive(&mut self) {
+ log::trace!("Settings state to inactive");
+ if !self.state.inactive {
+ log::debug!("Pausing background API requests due to inactivity");
+ self.state.inactive = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn set_active(&mut self) {
+ log::trace!("Settings state to active");
+ if self.state.inactive {
+ log::debug!("Resuming background API requests due to activity");
+ self.state.inactive = false;
+ let _ = self.tx.send(self.state).inspect_err(|send_err| {
+ log::debug!("All receivers of state updates have been dropped");
+ log::debug!("{send_err}");
+ });
+ }
+ }
+
+ fn set_offline(&mut self, offline: bool) {
+ if offline {
+ log::debug!("Pausing API requests due to being offline");
+ } else {
+ log::debug!("Resuming API requests due to coming online");
+ }
+ if self.state.offline != offline {
+ self.state.offline = offline;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn pause_background(&mut self) {
+ if !self.state.pause_background {
+ log::debug!("Pausing background API requests");
+ self.state.pause_background = true;
+ let _ = self.tx.send(self.state);
+ }
+ }
+
+ fn stop_inactivity_timer(&mut self) {
+ log::debug!("Stopping API inactivity check");
+ if let Some(timer) = self.inactivity_timer.take() {
+ timer.abort();
+ }
+ }
+
+ const fn inactivity_timer_running(&self) -> bool {
+ self.inactivity_timer.is_some()
+ }
+}
+
+impl Drop for ApiAvailabilityState {
+ fn drop(&mut self) {
+ self.stop_inactivity_timer();
+ }
+}
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 87b6e3d656..8add11d30a 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -21,7 +21,7 @@ use std::{
use talpid_types::ErrorExt;
pub mod availability;
-use availability::{ApiAvailability, ApiAvailabilityHandle};
+use availability::ApiAvailability;
pub mod rest;
mod abortable_stream;
@@ -414,7 +414,7 @@ impl Runtime {
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
- self.api_availability.handle(),
+ self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
#[cfg(target_os = "android")]
@@ -467,8 +467,8 @@ impl Runtime {
&mut self.handle
}
- pub fn availability_handle(&self) -> ApiAvailabilityHandle {
- self.api_availability.handle()
+ pub fn availability_handle(&self) -> ApiAvailability {
+ self.api_availability.clone()
}
pub fn address_cache(&self) -> &AddressCache {
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 238d73206a..bbcef79903 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -3,7 +3,7 @@ pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
access::AccessTokenStore,
address_cache::AddressCache,
- availability::ApiAvailabilityHandle,
+ availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
};
@@ -122,14 +122,14 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> {
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
connection_mode_provider: T,
connection_mode_generation: usize,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
}
impl<T: ConnectionModeProvider + 'static> RequestService<T> {
/// Constructs a new request service.
pub fn spawn(
sni_hostname: Option<String>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
@@ -218,7 +218,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
// Switch API endpoint if the request failed due to a network error
if let Err(err) = &response {
- if err.is_network_error() && !api_availability.get_state().is_offline() {
+ if err.is_network_error() && !api_availability.is_offline() {
log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
if let Some(tx) = tx {
let _ = tx.unbounded_send(RequestCommand::NextApiConfig(
@@ -339,7 +339,7 @@ impl Request {
async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
self,
hyper_client: hyper::Client<C>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
) -> Result<Response> {
let timeout = self.timeout;
let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
@@ -351,7 +351,7 @@ impl Request {
async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>(
mut self,
hyper_client: hyper::Client<C>,
- api_availability: ApiAvailabilityHandle,
+ api_availability: ApiAvailability,
) -> Result<Response> {
let _ = api_availability.wait_for_unsuspend().await;
@@ -605,14 +605,14 @@ async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
- pub availability: ApiAvailabilityHandle,
+ pub availability: ApiAvailability,
}
impl MullvadRestHandle {
pub(crate) fn new(
service: RequestServiceHandle,
factory: RequestFactory,
- availability: ApiAvailabilityHandle,
+ availability: ApiAvailability,
) -> Self {
Self {
service,