diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-09-24 10:04:59 +0200 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-09-25 13:08:28 +0200 |
| commit | aeab100ebeb46564696dd228acec4773aeaf6684 (patch) | |
| tree | 30c585053a2cb067aaaafcd756614bf224fbc4e4 /mullvad-api/src | |
| parent | bea8e150eea2a3f17511a2a094f5ed6e0844c001 (diff) | |
| download | mullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.tar.xz mullvadvpn-aeab100ebeb46564696dd228acec4773aeaf6684.zip | |
Consolidate two mutexes into one
Diffstat (limited to 'mullvad-api/src')
| -rw-r--r-- | mullvad-api/src/availability.rs | 274 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 8 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 16 |
3 files changed, 149 insertions, 149 deletions
diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs index ba33836b16..339aca8bca 100644 --- a/mullvad-api/src/availability.rs +++ b/mullvad-api/src/availability.rs @@ -1,12 +1,10 @@ use std::{ future::Future, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, MutexGuard}, time::Duration, }; use tokio::sync::broadcast; -const CHANNEL_CAPACITY: usize = 100; - /// Pause background requests if [ApiAvailabilityHandle::reset_inactivity_timer] hasn't been /// called for this long. const INACTIVITY_TIME: Duration = Duration::from_secs(3 * 24 * 60 * 60); @@ -26,182 +24,108 @@ pub struct State { inactive: bool, } +#[derive(Clone, Debug)] +pub struct ApiAvailability(Arc<Mutex<ApiAvailabilityState>>); + +#[derive(Debug)] +struct ApiAvailabilityState { + tx: broadcast::Sender<State>, + state: State, + inactivity_timer: Option<tokio::task::JoinHandle<()>>, +} + impl State { - pub fn is_suspended(&self) -> bool { + pub const fn is_suspended(&self) -> bool { self.suspended } - pub fn is_background_paused(&self) -> bool { + pub const fn is_background_paused(&self) -> bool { self.offline || self.pause_background || self.suspended || self.inactive } - pub fn is_offline(&self) -> bool { + pub const fn is_offline(&self) -> bool { self.offline } } -pub struct ApiAvailability { - state: Arc<Mutex<State>>, - tx: broadcast::Sender<State>, - - inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>, -} - impl ApiAvailability { - pub fn new(initial_state: State) -> Self { - let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY); - let state = Arc::new(Mutex::new(initial_state)); + const CHANNEL_CAPACITY: usize = 100; - let availability = ApiAvailability { - state, + pub fn new(initial_state: State) -> Self { + let (tx, _rx) = broadcast::channel(ApiAvailability::CHANNEL_CAPACITY); + let inner = ApiAvailabilityState { + state: initial_state, + inactivity_timer: None, tx, - inactivity_timer: Arc::new(Mutex::new(None)), }; - availability.handle().reset_inactivity_timer(); - availability + let handle = ApiAvailability(Arc::new(Mutex::new(inner))); + // Start an inactivity timer + handle.reset_inactivity_timer(); + handle } - pub fn get_state(&self) -> State { - *self.state.lock().unwrap() + fn acquire(&self) -> MutexGuard<'_, ApiAvailabilityState> { + self.0.lock().unwrap() } - pub fn handle(&self) -> ApiAvailabilityHandle { - ApiAvailabilityHandle { - state: self.state.clone(), - tx: self.tx.clone(), - inactivity_timer: self.inactivity_timer.clone(), - } - } -} - -impl Drop for ApiAvailability { - fn drop(&mut self) { - if let Some(timer) = self.inactivity_timer.lock().unwrap().take() { - timer.abort(); - } - } -} - -#[derive(Clone, Debug)] -pub struct ApiAvailabilityHandle { - state: Arc<Mutex<State>>, - tx: broadcast::Sender<State>, - inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>, -} - -impl ApiAvailabilityHandle { /// Reset task that automatically pauses API requests due inactivity, /// starting it if it's not currently running. pub fn reset_inactivity_timer(&self) { - log::trace!("Restarting API inactivity check"); - - let self_ = self.clone(); - - let mut inactivity_timer = self.inactivity_timer.lock().unwrap(); - if let Some(timer) = inactivity_timer.take() { - timer.abort(); - } - - self.set_active(); - - *inactivity_timer = Some(tokio::spawn(async move { + let mut inner = self.0.lock().unwrap(); + log::debug!("Restarting API inactivity check"); + inner.stop_inactivity_timer(); + let availability_handle = self.clone(); + inner.inactivity_timer = Some(tokio::spawn(async move { talpid_time::sleep(INACTIVITY_TIME).await; - self_.set_inactive(); + availability_handle.set_inactive(); })); + inner.set_active(); } /// Stops timer that pauses API requests due to inactivity. pub fn stop_inactivity_timer(&self) { - log::trace!("Stopping API inactivity check"); - - let mut inactivity_timer = self.inactivity_timer.lock().unwrap(); - if let Some(timer) = inactivity_timer.take() { - timer.abort(); - } - self.set_active(); - } - - fn inactivity_timer_running(&self) -> bool { - self.inactivity_timer.lock().unwrap().is_some() - } - - pub fn suspend(&self) { - let mut state = self.state.lock().unwrap(); - if !state.suspended { - log::debug!("Suspending API requests"); - - state.suspended = true; - let _ = self.tx.send(*state); - } - } - - pub fn unsuspend(&self) { - let mut state = self.state.lock().unwrap(); - if state.suspended { - log::debug!("Unsuspending API requests"); - - state.suspended = false; - let _ = self.tx.send(*state); - } + self.acquire().stop_inactivity_timer(); } pub fn pause_background(&self) { - let mut state = self.state.lock().unwrap(); - if !state.pause_background { - log::debug!("Pausing background API requests"); - - state.pause_background = true; - let _ = self.tx.send(*state); - } + self.acquire().pause_background(); } pub fn resume_background(&self) { - if self.inactivity_timer_running() { + let should_reset = { + let mut inner = self.acquire(); + inner.pause_background(); + inner.inactivity_timer_running() + }; + // Note: It is important that we do not hold on to the Mutex when calling `reset_inactivity_timer()`. + if should_reset { self.reset_inactivity_timer(); } - - let mut state = self.state.lock().unwrap(); - if state.pause_background { - log::debug!("Resuming background API requests"); - state.pause_background = false; - let _ = self.tx.send(*state); - } } - fn set_inactive(&self) { - let mut state = self.state.lock().unwrap(); - if !state.inactive { - log::debug!("Pausing background API requests due to inactivity"); - state.inactive = true; - let _ = self.tx.send(*state); - } + pub fn suspend(&self) { + self.acquire().suspend() } - fn set_active(&self) { - let mut state = self.state.lock().unwrap(); - if state.inactive { - log::debug!("Resuming background API requests due to activity"); - state.inactive = false; - let _ = self.tx.send(*state); - } + pub fn unsuspend(&self) { + self.acquire().unsuspend(); } pub fn set_offline(&self, offline: bool) { - let mut state = self.state.lock().unwrap(); - if state.offline != offline { - if offline { - log::debug!("Pausing API requests due to being offline"); - } else { - log::debug!("Resuming API requests due to coming online"); - } + self.acquire().set_offline(offline); + } - state.offline = offline; - let _ = self.tx.send(*state); - } + fn set_inactive(&self) { + self.acquire().set_inactive(); + } + + /// Check if the host is offline + pub fn is_offline(&self) -> bool { + self.get_state().is_offline() } - pub fn get_state(&self) -> State { - *self.state.lock().unwrap() + fn get_state(&self) -> State { + self.acquire().state } pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> { @@ -236,12 +160,12 @@ impl ApiAvailabilityHandle { &self, state_ready: impl Fn(State) -> bool, ) -> impl Future<Output = Result<(), Error>> { - let mut rx = self.tx.subscribe(); - let state = self.state.clone(); + let mut rx = { self.acquire().tx.subscribe() }; + let handle = self.clone(); async move { - let current_state = { *state.lock().unwrap() }; - if state_ready(current_state) { + let state = handle.get_state(); + if state_ready(state) { return Ok(()); } @@ -254,3 +178,79 @@ impl ApiAvailabilityHandle { } } } + +impl ApiAvailabilityState { + fn suspend(&mut self) { + if !self.state.suspended { + log::trace!("Suspending API requests"); + self.state.suspended = true; + let _ = self.tx.send(self.state); + } + } + + fn unsuspend(&mut self) { + if self.state.suspended { + log::trace!("Unsuspending API requests"); + self.state.suspended = false; + let _ = self.tx.send(self.state); + } + } + + fn set_inactive(&mut self) { + log::trace!("Settings state to inactive"); + if !self.state.inactive { + log::debug!("Pausing background API requests due to inactivity"); + self.state.inactive = true; + let _ = self.tx.send(self.state); + } + } + + fn set_active(&mut self) { + log::trace!("Settings state to active"); + if self.state.inactive { + log::debug!("Resuming background API requests due to activity"); + self.state.inactive = false; + let _ = self.tx.send(self.state).inspect_err(|send_err| { + log::debug!("All receivers of state updates have been dropped"); + log::debug!("{send_err}"); + }); + } + } + + fn set_offline(&mut self, offline: bool) { + if offline { + log::debug!("Pausing API requests due to being offline"); + } else { + log::debug!("Resuming API requests due to coming online"); + } + if self.state.offline != offline { + self.state.offline = offline; + let _ = self.tx.send(self.state); + } + } + + fn pause_background(&mut self) { + if !self.state.pause_background { + log::debug!("Pausing background API requests"); + self.state.pause_background = true; + let _ = self.tx.send(self.state); + } + } + + fn stop_inactivity_timer(&mut self) { + log::debug!("Stopping API inactivity check"); + if let Some(timer) = self.inactivity_timer.take() { + timer.abort(); + } + } + + const fn inactivity_timer_running(&self) -> bool { + self.inactivity_timer.is_some() + } +} + +impl Drop for ApiAvailabilityState { + fn drop(&mut self) { + self.stop_inactivity_timer(); + } +} diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 87b6e3d656..8add11d30a 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -21,7 +21,7 @@ use std::{ use talpid_types::ErrorExt; pub mod availability; -use availability::{ApiAvailability, ApiAvailabilityHandle}; +use availability::ApiAvailability; pub mod rest; mod abortable_stream; @@ -414,7 +414,7 @@ impl Runtime { ) -> rest::RequestServiceHandle { rest::RequestService::spawn( sni_hostname, - self.api_availability.handle(), + self.api_availability.clone(), self.address_cache.clone(), connection_mode_provider, #[cfg(target_os = "android")] @@ -467,8 +467,8 @@ impl Runtime { &mut self.handle } - pub fn availability_handle(&self) -> ApiAvailabilityHandle { - self.api_availability.handle() + pub fn availability_handle(&self) -> ApiAvailability { + self.api_availability.clone() } pub fn address_cache(&self) -> &AddressCache { diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 238d73206a..bbcef79903 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -3,7 +3,7 @@ pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ access::AccessTokenStore, address_cache::AddressCache, - availability::ApiAvailabilityHandle, + availability::ApiAvailability, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, proxy::ConnectionModeProvider, }; @@ -122,14 +122,14 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> { client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, connection_mode_provider: T, connection_mode_generation: usize, - api_availability: ApiAvailabilityHandle, + api_availability: ApiAvailability, } impl<T: ConnectionModeProvider + 'static> RequestService<T> { /// Constructs a new request service. pub fn spawn( sni_hostname: Option<String>, - api_availability: ApiAvailabilityHandle, + api_availability: ApiAvailability, address_cache: AddressCache, connection_mode_provider: T, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, @@ -218,7 +218,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { // Switch API endpoint if the request failed due to a network error if let Err(err) = &response { - if err.is_network_error() && !api_availability.get_state().is_offline() { + if err.is_network_error() && !api_availability.is_offline() { log::error!("{}", err.display_chain_with_msg("HTTP request failed")); if let Some(tx) = tx { let _ = tx.unbounded_send(RequestCommand::NextApiConfig( @@ -339,7 +339,7 @@ impl Request { async fn into_future<C: Connect + Clone + Send + Sync + 'static>( self, hyper_client: hyper::Client<C>, - api_availability: ApiAvailabilityHandle, + api_availability: ApiAvailability, ) -> Result<Response> { let timeout = self.timeout; let inner_fut = self.into_future_without_timeout(hyper_client, api_availability); @@ -351,7 +351,7 @@ impl Request { async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>( mut self, hyper_client: hyper::Client<C>, - api_availability: ApiAvailabilityHandle, + api_availability: ApiAvailability, ) -> Result<Response> { let _ = api_availability.wait_for_unsuspend().await; @@ -605,14 +605,14 @@ async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, - pub availability: ApiAvailabilityHandle, + pub availability: ApiAvailability, } impl MullvadRestHandle { pub(crate) fn new( service: RequestServiceHandle, factory: RequestFactory, - availability: ApiAvailabilityHandle, + availability: ApiAvailability, ) -> Self { Self { service, |
