diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-09-15 13:39:29 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-09-15 13:39:29 +0200 |
| commit | 301503dcd59185a5f8bb7d9741dcef21319023b6 (patch) | |
| tree | 51cbf0ab59253931e01c5daadb8465654dc7c264 | |
| parent | e5b9e8ddea9ce8f70f412d5a7b1f70c675161f30 (diff) | |
| parent | 6856f00b982726051099706a0159e84aac93e607 (diff) | |
| download | mullvadvpn-301503dcd59185a5f8bb7d9741dcef21319023b6.tar.xz mullvadvpn-301503dcd59185a5f8bb7d9741dcef21319023b6.zip | |
Merge branch 'pause-automatic-requests'
| -rw-r--r-- | mullvad-daemon/src/account.rs | 120 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 93 | ||||
| -rw-r--r-- | mullvad-daemon/src/relays.rs | 51 | ||||
| -rw-r--r-- | mullvad-daemon/src/version_check.rs | 17 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 56 | ||||
| -rw-r--r-- | mullvad-rpc/src/availability.rs | 127 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 24 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 59 | ||||
| -rw-r--r-- | mullvad-types/src/account.rs | 7 | ||||
| -rw-r--r-- | talpid-core/src/offline/android.rs | 26 | ||||
| -rw-r--r-- | talpid-core/src/offline/linux.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/offline/macos.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/offline/mod.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/offline/windows.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 19 |
15 files changed, 549 insertions, 110 deletions
diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs new file mode 100644 index 0000000000..88b996743a --- /dev/null +++ b/mullvad-daemon/src/account.rs @@ -0,0 +1,120 @@ +use chrono::{DateTime, Utc}; +use futures::future::{abortable, AbortHandle}; +use mullvad_rpc::{ + availability::ApiAvailabilityHandle, + rest::{self, MullvadRestHandle}, + AccountsProxy, +}; +use mullvad_types::account::{AccountToken, VoucherSubmission}; +use std::time::Duration; +use talpid_core::future_retry::{retry_future_with_backoff, ExponentialBackoff, Jittered}; + +const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4); +const RETRY_INTERVAL_FACTOR: u32 = 5; +const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); + + +pub struct Account(()); + +#[derive(Clone)] +pub struct AccountHandle { + api_availability: ApiAvailabilityHandle, + initial_check_abort_handle: AbortHandle, + pub proxy: AccountsProxy, +} + +impl AccountHandle { + pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { + let result = self.proxy.get_expiry(token).await; + if handle_expiry_result_inner(&result, &self.api_availability) { + self.initial_check_abort_handle.abort(); + } + result + } + + pub async fn submit_voucher( + &mut self, + account_token: AccountToken, + voucher: String, + ) -> Result<VoucherSubmission, rest::Error> { + let result = self.proxy.submit_voucher(account_token, voucher).await; + if result.is_ok() { + self.initial_check_abort_handle.abort(); + self.api_availability.resume(); + } + result + } +} + +impl Account { + pub fn new( + runtime: tokio::runtime::Handle, + rpc_handle: MullvadRestHandle, + token: Option<String>, + api_availability: ApiAvailabilityHandle, + ) -> AccountHandle { + let accounts_proxy = AccountsProxy::new(rpc_handle); + api_availability.pause(); + + let api_availability_copy = api_availability.clone(); + let accounts_proxy_copy = accounts_proxy.clone(); + + let (future, initial_check_abort_handle) = abortable(async move { + let token = if let Some(token) = token { + token + } else { + api_availability.pause(); + return; + }; + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) + .max_delay(RETRY_INTERVAL_MAX), + ); + let future_generator = move || { + let wait_online = api_availability.wait_online(); + let expiry_fut = accounts_proxy.get_expiry(token.clone()); + let api_availability_copy = api_availability.clone(); + async move { + let _ = wait_online.await; + handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) + } + }; + let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; + let retry_future = + retry_future_with_backoff(future_generator, should_retry, retry_strategy); + retry_future.await; + }); + runtime.spawn(future); + + AccountHandle { + api_availability: api_availability_copy, + initial_check_abort_handle, + proxy: accounts_proxy_copy, + } + } +} + +fn handle_expiry_result_inner( + result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>, + api_availability: &ApiAvailabilityHandle, +) -> bool { + match result { + Ok(_expiry) if *_expiry >= chrono::Utc::now() => { + api_availability.resume(); + true + } + Ok(_expiry) => { + api_availability.pause(); + true + } + Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => { + if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH { + api_availability.pause(); + return true; + } + false + } + Err(_) => false, + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 7394604f87..818c75b908 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -5,6 +5,7 @@ extern crate serde; +mod account; pub mod account_history; pub mod exception_logging; mod geoip; @@ -25,7 +26,7 @@ use futures::{ SinkExt, StreamExt, }; use log::{debug, error, info, warn}; -use mullvad_rpc::AccountsProxy; +use mullvad_rpc::availability::ApiAvailabilityHandle; use mullvad_types::{ account::{AccountData, AccountToken, VoucherSubmission}, endpoint::MullvadEndpoint, @@ -105,6 +106,9 @@ pub enum Error { #[error(display = "REST request failed")] RestError(#[error(source)] mullvad_rpc::rest::Error), + #[error(display = "API availability check failed")] + ApiCheckError(#[error(source)] mullvad_rpc::availability::Error), + #[error(display = "Unable to load account history")] LoadAccountHistory(#[error(source)] account_history::Error), @@ -509,7 +513,7 @@ pub struct Daemon<L: EventListener> { event_listener: L, settings: SettingsPersister, account_history: account_history::AccountHistory, - accounts_proxy: AccountsProxy, + account: account::AccountHandle, rpc_runtime: mullvad_rpc::MullvadRpcRuntime, rpc_handle: mullvad_rpc::rest::MullvadRestHandle, wireguard_key_manager: wireguard::KeyManager, @@ -539,7 +543,6 @@ where ) -> Result<Self, Error> { let (tunnel_state_machine_shutdown_tx, tunnel_state_machine_shutdown_signal) = oneshot::channel(); - let runtime = tokio::runtime::Handle::current(); let (internal_event_tx, internal_event_rx) = command_channel.destructure(); @@ -571,17 +574,19 @@ where .await .map_err(Error::InitRpcFactory)?; let rpc_handle = rpc_runtime.mullvad_rest_handle(); + let api_availability = rpc_runtime.availability_handle(); let relay_list_listener = event_listener.clone(); let on_relay_list_update = move |relay_list: &RelayList| { relay_list_listener.notify_relay_list(relay_list.clone()); }; - let mut relay_selector = relays::RelaySelector::new( + let relay_selector = relays::RelaySelector::new( rpc_handle.clone(), on_relay_list_update, &resource_dir, &cache_dir, + api_availability.clone(), ); @@ -594,6 +599,7 @@ where let app_version_info = version_check::load_cache(&cache_dir).await; let (version_updater, version_updater_handle) = version_check::VersionUpdater::new( rpc_handle.clone(), + api_availability.clone(), cache_dir.clone(), internal_event_tx.to_specialized_sender(), app_version_info.clone(), @@ -667,8 +673,10 @@ where vec![] }; + let (offline_state_tx, offline_state_rx) = mpsc::unbounded(); + let tunnel_command_tx = tunnel_state_machine::spawn( - runtime, + runtime.clone(), tunnel_state_machine::InitialTunnelState { allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, @@ -683,6 +691,7 @@ where resource_dir, cache_dir.clone(), internal_event_tx.to_specialized_sender(), + offline_state_tx, tunnel_state_machine_shutdown_tx, #[cfg(target_os = "android")] android_context, @@ -690,6 +699,8 @@ where .await .map_err(Error::TunnelError)?; + Self::forward_offline_state(&runtime, api_availability.clone(), offline_state_rx).await; + let tsm_api_address_change_tx = Arc::downgrade(&tunnel_command_tx); tokio::spawn(async move { while let Some(address_change) = address_change_rx.next().await { @@ -701,11 +712,25 @@ where } }); - let wireguard_key_manager = - wireguard::KeyManager::new(internal_event_tx.clone(), rpc_handle.clone()); + let wireguard_key_manager = wireguard::KeyManager::new( + internal_event_tx.clone(), + api_availability.clone(), + rpc_handle.clone(), + ); + + let account = account::Account::new( + runtime, + rpc_handle.clone(), + settings.get_account_token(), + api_availability.clone(), + ); // Attempt to download a fresh relay list - relay_selector.update().await; + let mut relay_handle = relay_selector.updater_handle(); + relay_handle + .update_relay_list_deferred() + .await + .expect("Relay list updated thread has stopped unexpectedly"); let mut daemon = Daemon { tunnel_command_tx, @@ -721,8 +746,8 @@ where event_listener, settings, account_history, + account, rpc_runtime, - accounts_proxy: AccountsProxy::new(rpc_handle.clone()), rpc_handle, wireguard_key_manager, version_updater_handle, @@ -1418,7 +1443,7 @@ where async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) { let daemon_tx = self.tx.clone(); - let future = self.accounts_proxy.create_account(); + let future = self.account.proxy.create_account(); tokio::spawn(async move { match future.await { @@ -1437,17 +1462,20 @@ where tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>, account_token: AccountToken, ) { - let expiry_fut = self.accounts_proxy.get_expiry(account_token); - let rpc_call = async { - let result = expiry_fut.await.map(|expiry| AccountData { expiry }); - Self::oneshot_send(tx, result, "account data"); - }; - tokio::spawn(rpc_call); + let account = self.account.clone(); + tokio::spawn(async move { + let result = account.check_expiry(account_token).await; + Self::oneshot_send( + tx, + result.map(|expiry| AccountData { expiry }), + "account data", + ); + }); } async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) { if let Some(account_token) = self.settings.get_account_token() { - let future = self.accounts_proxy.get_www_auth_token(account_token); + let future = self.account.proxy.get_www_auth_token(account_token); let rpc_call = async { Self::oneshot_send( tx, @@ -1471,15 +1499,17 @@ where voucher: String, ) { if let Some(account_token) = self.settings.get_account_token() { - let future = self.accounts_proxy.submit_voucher(account_token, voucher); - let rpc_call = async { + let mut account = self.account.clone(); + tokio::spawn(async move { Self::oneshot_send( tx, - future.await.map_err(Error::RestError), + account + .submit_voucher(account_token, voucher) + .await + .map_err(Error::RestError), "submit_voucher response", ); - }; - tokio::spawn(rpc_call); + }); } else { Self::oneshot_send(tx, Err(Error::NoAccountToken), "submit_voucher response"); } @@ -2252,6 +2282,7 @@ where } Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys), Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), + Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), } } @@ -2291,6 +2322,7 @@ where let result = match verification_rpc.await { Ok(is_valid) => Ok(is_valid), Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), + Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), Err(wireguard::Error::TooManyKeys) => return, }; Self::oneshot_send(tx, result, "verify_wireguard_key response"); @@ -2351,6 +2383,23 @@ where Some(bypass_tx) } + async fn forward_offline_state( + runtime: &tokio::runtime::Handle, + api_availability: ApiAvailabilityHandle, + mut offline_state_rx: mpsc::UnboundedReceiver<bool>, + ) { + let initial_state = offline_state_rx + .next() + .await + .expect("missing initial offline state"); + api_availability.set_offline(initial_state); + runtime.spawn(async move { + while let Some(is_offline) = offline_state_rx.next().await { + api_availability.set_offline(is_offline); + } + }); + } + /// Set the target state of the client. If it changed trigger the operations needed to /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. diff --git a/mullvad-daemon/src/relays.rs b/mullvad-daemon/src/relays.rs index ad1a038817..607724a534 100644 --- a/mullvad-daemon/src/relays.rs +++ b/mullvad-daemon/src/relays.rs @@ -9,7 +9,7 @@ use futures::{ }; use ipnetwork::IpNetwork; use log::{debug, error, info, warn}; -use mullvad_rpc::{rest::MullvadRestHandle, RelayListProxy}; +use mullvad_rpc::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, RelayListProxy}; use mullvad_types::{ endpoint::MullvadEndpoint, location::Location, @@ -187,6 +187,7 @@ impl RelaySelector { on_update: impl Fn(&RelayList) + Send + 'static, resource_dir: &Path, cache_dir: &Path, + api_availability: ApiAvailabilityHandle, ) -> Self { let cache_path = cache_dir.join(RELAYS_FILENAME); let resource_path = resource_dir.join(RELAYS_FILENAME); @@ -211,6 +212,7 @@ impl RelaySelector { cache_path, parsed_relays.clone(), Box::new(on_update), + api_availability, ); @@ -232,6 +234,10 @@ impl RelaySelector { } } + pub fn updater_handle(&self) -> RelayListUpdaterHandle { + self.updater.as_ref().unwrap().clone() + } + /// Returns all countries and cities. The cities in the object returned does not have any /// relays in them. pub fn get_locations(&mut self) -> RelayList { @@ -986,13 +992,20 @@ impl RelaySelector { #[derive(Clone)] pub struct RelayListUpdaterHandle { - tx: mpsc::Sender<()>, + tx: mpsc::Sender<bool>, } impl RelayListUpdaterHandle { - async fn update_relay_list(&mut self) -> Result<(), Error> { + pub async fn update_relay_list(&mut self) -> Result<(), Error> { self.tx - .send(()) + .send(false) + .await + .map_err(|_| Error::DownloaderShutDown) + } + + pub async fn update_relay_list_deferred(&mut self) -> Result<(), Error> { + self.tx + .send(true) .await .map_err(|_| Error::DownloaderShutDown) } @@ -1004,6 +1017,7 @@ struct RelayListUpdater { parsed_relays: Arc<Mutex<ParsedRelays>>, on_update: Box<dyn Fn(&RelayList) + Send + 'static>, earliest_next_try: Instant, + api_availability: ApiAvailabilityHandle, } impl RelayListUpdater { @@ -1012,6 +1026,7 @@ impl RelayListUpdater { cache_path: PathBuf, parsed_relays: Arc<Mutex<ParsedRelays>>, on_update: Box<dyn Fn(&RelayList) + Send + 'static>, + api_availability: ApiAvailabilityHandle, ) -> RelayListUpdaterHandle { let (tx, cmd_rx) = mpsc::channel(1); let service = rpc_handle.service(); @@ -1022,6 +1037,7 @@ impl RelayListUpdater { parsed_relays, on_update, earliest_next_try: Instant::now() + UPDATE_INTERVAL, + api_availability, }; service.spawn(updater.run(cmd_rx)); @@ -1029,7 +1045,7 @@ impl RelayListUpdater { RelayListUpdaterHandle { tx } } - async fn run(mut self, mut cmd_rx: mpsc::Receiver<()>) { + async fn run(mut self, mut cmd_rx: mpsc::Receiver<bool>) { let mut check_interval = tokio_stream::wrappers::IntervalStream::new( tokio::time::interval(UPDATE_CHECK_INTERVAL), ) @@ -1040,7 +1056,7 @@ impl RelayListUpdater { _check_update = check_interval.next() => { if download_future.is_terminated() && self.should_update() { let tag = self.parsed_relays.lock().tag().map(|tag| tag.to_string()); - download_future = Box::pin(Self::download_relay_list(self.rpc_client.clone(), tag).fuse()); + download_future = Box::pin(Self::download_relay_list(self.api_availability.clone(), self.rpc_client.clone(), tag).fuse()); self.earliest_next_try = Instant::now() + UPDATE_INTERVAL; } }, @@ -1052,9 +1068,14 @@ impl RelayListUpdater { cmd = cmd_rx.next() => { match cmd { - Some(_) => { + Some(defer) => { let tag = self.parsed_relays.lock().tag().map(|tag| tag.to_string()); - self.consume_new_relay_list(self.rpc_client.relay_list(tag).await).await; + if defer { + let download_future = Self::download_relay_list(self.api_availability.clone(), self.rpc_client.clone(), tag); + self.consume_new_relay_list(download_future.await).await; + } else { + self.consume_new_relay_list(self.rpc_client.relay_list(tag).await.map_err(mullvad_rpc::Error::from)).await; + } }, None => { log::error!("Relay list updater shutting down"); @@ -1069,7 +1090,7 @@ impl RelayListUpdater { async fn consume_new_relay_list( &mut self, - result: Result<Option<RelayList>, mullvad_rpc::rest::Error>, + result: Result<Option<RelayList>, mullvad_rpc::Error>, ) { match result { Ok(Some(relay_list)) => { @@ -1103,10 +1124,18 @@ impl RelayListUpdater { } fn download_relay_list( + api_handle: ApiAvailabilityHandle, rpc_handle: RelayListProxy, tag: Option<String>, - ) -> impl Future<Output = Result<Option<RelayList>, mullvad_rpc::rest::Error>> + 'static { - let download_futures = move || rpc_handle.relay_list(tag.clone()); + ) -> impl Future<Output = Result<Option<RelayList>, mullvad_rpc::Error>> + 'static { + let download_futures = move || { + let available = api_handle.wait_available(); + let req = rpc_handle.relay_list(tag.clone()); + async move { + available.await?; + req.await.map_err(mullvad_rpc::Error::from) + } + }; let exponential_backoff = ExponentialBackoff::new(EXPONENTIAL_BACKOFF_INITIAL, EXPONENTIAL_BACKOFF_FACTOR) diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index d0d5c13d8c..b06466bb1e 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -3,7 +3,7 @@ use crate::{ DaemonEventSender, }; use futures::{channel::mpsc, stream::FusedStream, FutureExt, SinkExt, StreamExt, TryFutureExt}; -use mullvad_rpc::{rest::MullvadRestHandle, AppVersionProxy}; +use mullvad_rpc::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, AppVersionProxy}; use mullvad_types::version::{AppVersionInfo, ParsedAppVersion}; use serde::{Deserialize, Serialize}; use std::{ @@ -78,6 +78,9 @@ pub enum Error { #[error(display = "Failed to check the latest app version")] Download(#[error(source)] mullvad_rpc::rest::Error), + #[error(display = "API availability check failed")] + ApiCheck(#[error(source)] mullvad_rpc::availability::Error), + #[error(display = "Clearing version check cache due to a version mismatch")] CacheVersionMismatch, } @@ -92,6 +95,7 @@ pub(crate) struct VersionUpdater { next_update_time: Instant, show_beta_releases: bool, rx: Option<mpsc::Receiver<VersionUpdaterCommand>>, + availability_handle: ApiAvailabilityHandle, } #[derive(Clone)] @@ -133,6 +137,7 @@ impl VersionUpdaterHandle { impl VersionUpdater { pub fn new( mut rpc_handle: MullvadRestHandle, + availability_handle: ApiAvailabilityHandle, cache_dir: PathBuf, update_sender: DaemonEventSender<AppVersionInfo>, last_app_version_info: Option<AppVersionInfo>, @@ -154,6 +159,7 @@ impl VersionUpdater { next_update_time: Instant::now(), show_beta_releases, rx: Some(rx), + availability_handle, }, VersionUpdaterHandle { tx }, ) @@ -162,15 +168,20 @@ impl VersionUpdater { fn create_update_future( &self, ) -> impl Future<Output = Result<mullvad_rpc::AppVersionResponse, Error>> + Send + 'static { + let api_handle = self.availability_handle.clone(); let version_proxy = self.version_proxy.clone(); let platform_version = self.platform_version.clone(); let download_future_factory = move || { - let response = version_proxy.version_check( + let when_available = api_handle.wait_available(); + let request = version_proxy.version_check( PRODUCT_VERSION.to_owned(), PLATFORM, platform_version.clone(), ); - response.map_err(Error::Download) + async move { + when_available.await.map_err(Error::ApiCheck)?; + request.await.map_err(Error::Download) + } }; let should_retry = |result: &Result<_, _>| -> bool { result.is_err() }; diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs index 1c548d3e8d..c0c260de29 100644 --- a/mullvad-daemon/src/wireguard.rs +++ b/mullvad-daemon/src/wireguard.rs @@ -1,6 +1,9 @@ use crate::{DaemonEventSender, InternalDaemonEvent}; use chrono::offset::Utc; -use mullvad_rpc::rest::{Error as RestError, MullvadRestHandle}; +use mullvad_rpc::{ + availability::ApiAvailabilityHandle, + rest::{Error as RestError, MullvadRestHandle}, +}; use mullvad_types::account::AccountToken; pub use mullvad_types::wireguard::*; use std::{future::Future, pin::Pin, time::Duration}; @@ -31,6 +34,8 @@ const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); pub enum Error { #[error(display = "Unexpected HTTP request error")] RestError(#[error(source)] mullvad_rpc::rest::Error), + #[error(display = "API availability check was interrupted")] + ApiCheckError(#[error(source)] mullvad_rpc::availability::Error), #[error(display = "Account already has maximum number of keys")] TooManyKeys, } @@ -39,6 +44,7 @@ pub type Result<T> = std::result::Result<T, Error>; pub struct KeyManager { daemon_tx: DaemonEventSender, + availability_handle: ApiAvailabilityHandle, http_handle: MullvadRestHandle, current_job: Option<AbortHandle>, @@ -47,9 +53,14 @@ pub struct KeyManager { } impl KeyManager { - pub(crate) fn new(daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle) -> Self { + pub(crate) fn new( + daemon_tx: DaemonEventSender, + availability_handle: ApiAvailabilityHandle, + http_handle: MullvadRestHandle, + ) -> Self { Self { daemon_tx, + availability_handle, http_handle, current_job: None, abort_scheduler_tx: None, @@ -164,11 +175,22 @@ impl KeyManager { let mut inner_future_generator = self.push_future_generator(account.clone(), private_key, timeout); + let availability_handle = self.availability_handle.clone(); + let future_generator = move || { + let wait_available = availability_handle.wait_available(); let fut = inner_future_generator(); let error_tx = error_tx.clone(); let error_account = error_account.clone(); async move { + let error_account_copy = error_account.clone(); + wait_available.await.map_err(|error| { + let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( + error_account_copy, + Err(Error::ApiCheckError(error)), + ))); + false + })?; let response = fut.await; match response { Ok(addresses) => Ok(addresses), @@ -299,6 +321,7 @@ impl KeyManager { async fn create_automatic_rotation( daemon_tx: DaemonEventSender, + availability_handle: ApiAvailabilityHandle, http_handle: MullvadRestHandle, mut public_key: PublicKey, rotation_interval_secs: u64, @@ -306,14 +329,20 @@ impl KeyManager { ) { tokio::time::sleep(ROTATION_START_DELAY).await; - let rotate_key_for_account = move |old_key: &PublicKey| { - Self::rotate_key( - daemon_tx.clone(), - http_handle.clone(), - account_token.clone(), - old_key.clone(), - ) - }; + let rotate_key_for_account = + move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> { + let wait_available = availability_handle.wait_available(); + let rotate = Self::rotate_key( + daemon_tx.clone(), + http_handle.clone(), + account_token.clone(), + old_key.clone(), + ); + Box::pin(async move { + wait_available.await?; + rotate.await + }) + }; loop { Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await; @@ -341,12 +370,12 @@ impl KeyManager { http_handle: MullvadRestHandle, account_token: AccountToken, old_key: PublicKey, - ) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> { + ) -> impl Future<Output = Result<PublicKey>> { let new_key = PrivateKey::new_from_random(); let rpc_result = Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key); - Box::pin(async move { + async move { match rpc_result.await { Ok(data) => { // Update account data @@ -365,7 +394,7 @@ impl KeyManager { } Err(unknown) => Err(unknown), } - }) + } } async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey> @@ -403,6 +432,7 @@ impl KeyManager { // Schedule cancellable series of repeating rotation tasks let fut = Self::create_automatic_rotation( self.daemon_tx.clone(), + self.availability_handle.clone(), self.http_handle.clone(), public_key, self.auto_rotation_interval.as_duration().as_secs(), diff --git a/mullvad-rpc/src/availability.rs b/mullvad-rpc/src/availability.rs new file mode 100644 index 0000000000..227bc0cd35 --- /dev/null +++ b/mullvad-rpc/src/availability.rs @@ -0,0 +1,127 @@ +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; +use tokio::sync::broadcast; + + +const CHANNEL_CAPACITY: usize = 100; + + +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The [`ApiAvailability`] instance was dropped, or the receiver lagged behind. + #[error(display = "API availability instance was dropped")] + Interrupted(#[error(source)] broadcast::error::RecvError), +} + + +#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)] +pub struct State { + pause_automatic: bool, + offline: bool, +} + +impl State { + pub fn is_paused(&self) -> bool { + self.pause_automatic + } + + pub fn is_offline(&self) -> bool { + self.offline + } + + pub fn is_available(&self) -> bool { + !self.is_paused() && !self.is_offline() + } +} + +pub struct ApiAvailability { + state: Arc<Mutex<State>>, + tx: broadcast::Sender<State>, +} + +impl ApiAvailability { + pub fn new(initial_state: State) -> Self { + let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY); + let state = Arc::new(Mutex::new(initial_state)); + ApiAvailability { state, tx } + } + + pub fn get_state(&self) -> State { + *self.state.lock().unwrap() + } + + pub fn handle(&self) -> ApiAvailabilityHandle { + ApiAvailabilityHandle { + state: self.state.clone(), + tx: self.tx.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ApiAvailabilityHandle { + state: Arc<Mutex<State>>, + tx: broadcast::Sender<State>, +} + +impl ApiAvailabilityHandle { + pub fn pause(&self) { + let mut state = self.state.lock().unwrap(); + if !state.pause_automatic { + state.pause_automatic = true; + let _ = self.tx.send(*state); + } + } + + pub fn resume(&self) { + let mut state = self.state.lock().unwrap(); + if state.pause_automatic { + state.pause_automatic = false; + let _ = self.tx.send(*state); + } + } + + pub fn set_offline(&self, offline: bool) { + let mut state = self.state.lock().unwrap(); + if state.offline != offline { + state.offline = offline; + let _ = self.tx.send(*state); + } + } + + pub fn get_state(&self) -> State { + *self.state.lock().unwrap() + } + + pub fn wait_available(&self) -> impl Future<Output = Result<(), Error>> { + self.wait_for_state(|state| state.is_available()) + } + + pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> { + self.wait_for_state(|state| !state.is_offline()) + } + + fn wait_for_state( + &self, + state_ready: impl Fn(State) -> bool, + ) -> impl Future<Output = Result<(), Error>> { + let mut rx = self.tx.subscribe(); + let state = self.state.clone(); + + async move { + let current_state = { *state.lock().unwrap() }; + if state_ready(current_state) { + return Ok(()); + } + + loop { + let new_state = rx.recv().await?; + if state_ready(new_state) { + return Ok(()); + } + } + } + } +} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 96260775f8..098ed2f0b4 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -18,6 +18,8 @@ use std::{ use talpid_types::{net::wireguard, ErrorExt}; +pub mod availability; +use availability::{ApiAvailability, ApiAvailabilityHandle}; pub mod rest; mod https_client_with_sni; @@ -41,6 +43,9 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER"; /// Error code returned by the Mullvad API if the account token is invalid. pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT"; +/// Error code returned by the Mullvad API if the account token is missing or invalid. +pub const INVALID_AUTH: &str = "INVALID_AUTH"; + const API_HOST: &str = "api.mullvad.net"; pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt"; const API_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78)); @@ -51,6 +56,7 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443); pub struct MullvadRpcRuntime { handle: tokio::runtime::Handle, pub address_cache: AddressCache, + api_availability: availability::ApiAvailability, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } @@ -62,6 +68,9 @@ pub enum Error { #[error(display = "Failed to load address cache")] AddressCacheError(#[error(source)] address_cache::Error), + + #[error(display = "API availability check failed")] + ApiCheckError(#[error(source)] availability::Error), } impl MullvadRpcRuntime { @@ -74,6 +83,7 @@ impl MullvadRpcRuntime { None, Arc::new(Box::new(|_| Ok(()))), )?, + api_availability: ApiAvailability::new(availability::State::default()), #[cfg(target_os = "android")] socket_bypass_tx: None, }) @@ -139,6 +149,7 @@ impl MullvadRpcRuntime { Ok(MullvadRpcRuntime { handle, address_cache, + api_availability: ApiAvailability::new(availability::State::default()), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -156,6 +167,7 @@ impl MullvadRpcRuntime { let service = rest::RequestService::new( https_connector, self.handle.clone(), + self.api_availability.handle(), self.address_cache.clone(), ); let handle = service.handle(); @@ -172,7 +184,12 @@ impl MullvadRpcRuntime { Some("app".to_owned()), ); - rest::MullvadRestHandle::new(service, factory, self.address_cache.clone()) + rest::MullvadRestHandle::new( + service, + factory, + self.address_cache.clone(), + self.availability_handle(), + ) } /// Returns a new request service handle @@ -183,8 +200,13 @@ impl MullvadRpcRuntime { pub fn handle(&mut self) -> &mut tokio::runtime::Handle { &mut self.handle } + + pub fn availability_handle(&self) -> ApiAvailabilityHandle { + self.api_availability.handle() + } } +#[derive(Clone)] pub struct AccountsProxy { handle: rest::MullvadRestHandle, } diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index be4c1dc990..77bf06fd55 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,6 +1,6 @@ use crate::{ - address_cache::AddressCache, https_client_with_sni::HttpsConnectorWithSni, - tcp_stream::TcpStreamHandle, + address_cache::AddressCache, availability::ApiAvailabilityHandle, + https_client_with_sni::HttpsConnectorWithSni, tcp_stream::TcpStreamHandle, }; use futures::{ channel::{mpsc, oneshot}, @@ -22,6 +22,7 @@ use std::{ str::FromStr, time::{Duration, Instant}, }; +use talpid_types::ErrorExt; use tokio::runtime::Handle; pub use hyper::StatusCode; @@ -84,6 +85,7 @@ pub(crate) struct RequestService { handle: Handle, next_id: u64, in_flight_requests: BTreeMap<u64, AbortHandle>, + api_availability: ApiAvailabilityHandle, address_cache: AddressCache, } @@ -92,6 +94,7 @@ impl RequestService { pub fn new( mut connector: HttpsConnectorWithSni, handle: Handle, + api_availability: ApiAvailabilityHandle, address_cache: AddressCache, ) -> RequestService { let (command_tx, command_rx) = mpsc::channel(1); @@ -99,7 +102,6 @@ impl RequestService { connector.set_service_tx(command_tx.clone()); let client = Client::builder().build(connector); - Self { command_tx, command_rx, @@ -108,6 +110,7 @@ impl RequestService { in_flight_requests: BTreeMap::new(), next_id: 0, handle, + api_availability, address_cache, } } @@ -134,6 +137,7 @@ impl RequestService { abortable(self.client.request(hyper_request).map_err(Error::from)); let address_cache = self.address_cache.clone(); let handle = self.handle.clone(); + let api_availability = self.api_availability.clone(); let future = async move { let response = @@ -146,20 +150,25 @@ impl RequestService { if let Err(err) = &response { match err { Error::HyperError(_) | Error::TimeoutError(_) => { - log::error!("HTTP request failed: {}", err); - let current_address = address_cache.peek_address(); - if current_address == host_addr - && address_cache.has_tried_current_address() - { - handle.spawn(async move { - address_cache.select_new_address().await; - let new_address = address_cache.peek_address(); - log::error!( - "Request failed using address {}. Trying next API address: {}", - current_address, - new_address, - ); - }); + log::error!( + "{}", + err.display_chain_with_msg("HTTP request failed") + ); + if !api_availability.get_state().is_offline() { + let current_address = address_cache.peek_address(); + if current_address == host_addr + && address_cache.has_tried_current_address() + { + handle.spawn(async move { + address_cache.select_new_address().await; + let new_address = address_cache.peek_address(); + log::error!( + "Request failed using address {}. Trying next API address: {}", + current_address, + new_address, + ); + }); + } } } _ => (), @@ -594,6 +603,7 @@ pub async fn handle_error_response<T>(response: Response) -> Result<T> { pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, + availability: ApiAvailabilityHandle, } impl MullvadRestHandle { @@ -601,8 +611,13 @@ impl MullvadRestHandle { service: RequestServiceHandle, factory: RequestFactory, address_cache: AddressCache, + availability: ApiAvailabilityHandle, ) -> Self { - let handle = Self { service, factory }; + let handle = Self { + service, + factory, + availability, + }; handle.spawn_api_address_fetcher(address_cache); handle @@ -610,10 +625,11 @@ impl MullvadRestHandle { fn spawn_api_address_fetcher(&self, address_cache: AddressCache) { let handle = self.clone(); + let availability = self.availability.clone(); self.service.spawn(async move { // always start the fetch after 15 minutes - let api_proxy = crate::ApiProxy { handle }; + let api_proxy = crate::ApiProxy::new(handle); let mut next_check = Instant::now() + API_IP_CHECK_DELAY; let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL; @@ -624,6 +640,11 @@ impl MullvadRestHandle { loop { interval.tick().await; if next_check < Instant::now() { + if let Err(error) = availability.wait_available().await { + log::error!("Failed while waiting for API: {}", error); + next_check = next_error_check(); + continue; + } match api_proxy.clone().get_api_addrs().await { Ok(new_addrs) => { log::debug!("Fetched new API addresses {:?}, will fetch again in {} hours", new_addrs, API_IP_CHECK_INTERVAL.as_secs() / ( 60 * 60 )); diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs index 0acb0941aa..b5479640e6 100644 --- a/mullvad-types/src/account.rs +++ b/mullvad-types/src/account.rs @@ -15,6 +15,13 @@ pub struct AccountData { pub expiry: DateTime<Utc>, } +impl AccountData { + /// Return true if the account has no time left. + pub fn is_expired(&self) -> bool { + self.expiry >= Utc::now() + } +} + /// Data structure that's returned from successful invocation of the mullvad API's /// `/v1/submit-voucher` RPC. #[derive(Deserialize, Serialize, Debug)] diff --git a/talpid-core/src/offline/android.rs b/talpid-core/src/offline/android.rs index fefe2556cf..65f0e7cf58 100644 --- a/talpid-core/src/offline/android.rs +++ b/talpid-core/src/offline/android.rs @@ -1,4 +1,3 @@ -use crate::tunnel_state_machine::TunnelCommand; use futures::channel::mpsc::UnboundedSender; use jnix::{ jni::{ @@ -44,10 +43,14 @@ pub struct MonitorHandle { jvm: Arc<JavaVM>, class: GlobalRef, object: GlobalRef, + _sender: Arc<UnboundedSender<bool>>, } impl MonitorHandle { - pub fn new(android_context: AndroidContext) -> Result<Self, Error> { + pub fn new( + android_context: AndroidContext, + sender: Arc<UnboundedSender<bool>>, + ) -> Result<Self, Error> { let env = JnixEnv::from( android_context .jvm @@ -93,6 +96,7 @@ impl MonitorHandle { jvm: android_context.jvm, class, object, + _sender: sender, }) } @@ -128,7 +132,7 @@ impl MonitorHandle { } } - fn set_sender(&self, sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<(), Error> { + fn set_sender(&self, sender: Weak<UnboundedSender<bool>>) -> Result<(), Error> { let sender_ptr = Box::new(sender); let sender_address = Box::into_raw(sender_ptr) as jlong; @@ -181,10 +185,10 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnec sender_address: jlong, ) { let sender_ref = Box::leak(unsafe { get_sender_from_address(sender_address) }); - let tunnel_command = TunnelCommand::IsOffline(is_connected == JNI_FALSE); + let is_offline = is_connected == JNI_FALSE; if let Some(sender) = sender_ref.upgrade() { - if sender.unbounded_send(tunnel_command).is_err() { + if sender.unbounded_send(is_offline).is_err() { log::warn!("Failed to send offline change event"); } } @@ -201,17 +205,19 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySende let _ = unsafe { get_sender_from_address(sender_address) }; } -unsafe fn get_sender_from_address(address: jlong) -> Box<Weak<UnboundedSender<TunnelCommand>>> { - Box::from_raw(address as *mut Weak<UnboundedSender<TunnelCommand>>) +unsafe fn get_sender_from_address(address: jlong) -> Box<Weak<UnboundedSender<bool>>> { + Box::from_raw(address as *mut Weak<UnboundedSender<bool>>) } pub async fn spawn_monitor( - sender: Weak<UnboundedSender<TunnelCommand>>, + sender: UnboundedSender<bool>, android_context: AndroidContext, ) -> Result<MonitorHandle, Error> { - let monitor_handle = MonitorHandle::new(android_context)?; + let sender = Arc::new(sender); + let weak_sender = Arc::downgrade(&sender); + let monitor_handle = MonitorHandle::new(android_context, sender)?; - monitor_handle.set_sender(sender)?; + monitor_handle.set_sender(weak_sender)?; Ok(monitor_handle) } diff --git a/talpid-core/src/offline/linux.rs b/talpid-core/src/offline/linux.rs index ceaa864cc7..f9e137853b 100644 --- a/talpid-core/src/offline/linux.rs +++ b/talpid-core/src/offline/linux.rs @@ -1,11 +1,8 @@ -use crate::{ - routing::{self, RouteManagerHandle}, - tunnel_state_machine::TunnelCommand, -}; +use crate::routing::{self, RouteManagerHandle}; use futures::{channel::mpsc::UnboundedSender, StreamExt}; use std::{ net::{IpAddr, Ipv4Addr}, - sync::Weak, + sync::Arc, }; use talpid_types::ErrorExt; @@ -20,6 +17,7 @@ pub enum Error { pub struct MonitorHandle { route_manager: RouteManagerHandle, + _notify_tx: Arc<UnboundedSender<bool>>, } // Mullvad API's public IP address, correct at the time of writing, but any public IP address will @@ -42,7 +40,7 @@ impl MonitorHandle { } pub async fn spawn_monitor( - sender: Weak<UnboundedSender<TunnelCommand>>, + notify_tx: UnboundedSender<bool>, route_manager: RouteManagerHandle, ) -> Result<MonitorHandle> { let mut is_offline = public_ip_unreachable(&route_manager).await?; @@ -52,8 +50,11 @@ pub async fn spawn_monitor( .await .map_err(Error::RouteManagerError)?; + let notify_tx = Arc::new(notify_tx); + let sender = Arc::downgrade(¬ify_tx); let monitor_handle = MonitorHandle { route_manager: route_manager.clone(), + _notify_tx: notify_tx, }; tokio::spawn(async move { @@ -71,7 +72,7 @@ pub async fn spawn_monitor( }); if new_offline_state != is_offline { is_offline = new_offline_state; - let _ = sender.unbounded_send(TunnelCommand::IsOffline(is_offline)); + let _ = sender.unbounded_send(is_offline); } } None => return, diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs index 2569fa06c6..3e374cf29c 100644 --- a/talpid-core/src/offline/macos.rs +++ b/talpid-core/src/offline/macos.rs @@ -1,4 +1,3 @@ -use crate::tunnel_state_machine::TunnelCommand; use futures::channel::mpsc::UnboundedSender; use std::{ net::{Ipv4Addr, SocketAddr}, @@ -39,7 +38,9 @@ pub enum Error { InitializationError, } -pub struct MonitorHandle; +pub struct MonitorHandle { + _notify_tx: Arc<UnboundedSender<bool>>, +} impl MonitorHandle { /// Host is considered to be offline if the IPv4 internet is considered to be unreachable by the @@ -54,10 +55,10 @@ impl MonitorHandle { } } -pub async fn spawn_monitor( - sender: Weak<UnboundedSender<TunnelCommand>>, -) -> Result<MonitorHandle, Error> { +pub async fn spawn_monitor(notify_tx: UnboundedSender<bool>) -> Result<MonitorHandle, Error> { let (result_tx, result_rx) = mpsc::channel(); + let notify_tx = Arc::new(notify_tx); + let sender = Arc::downgrade(¬ify_tx); thread::spawn(move || { let mut reachability_ref = SCNetworkReachability::from(ipv4_internet()); let store = SCDynamicStoreBuilder::new("talpid-offline-watcher").build(); @@ -108,7 +109,9 @@ pub async fn spawn_monitor( }); let _ = result_rx.recv().map_err(|_| Error::InitializationError)??; - Ok(MonitorHandle {}) + Ok(MonitorHandle { + _notify_tx: notify_tx, + }) } fn ipv4_internet() -> SocketAddr { @@ -170,7 +173,7 @@ fn iface_is_physical(iface: &SCNetworkInterface) -> bool { #[derive(Clone)] struct OfflineStateContext { - sender: Weak<UnboundedSender<TunnelCommand>>, + sender: Weak<UnboundedSender<bool>>, is_offline: Arc<AtomicBool>, } @@ -182,8 +185,7 @@ impl OfflineStateContext { fn new_state(&self, is_offline: bool) { if self.is_offline.swap(is_offline, Ordering::SeqCst) != is_offline { if let Some(sender) = self.sender.upgrade() { - let cmd = TunnelCommand::IsOffline(is_offline); - let _ = sender.unbounded_send(cmd); + let _ = sender.unbounded_send(is_offline); } } } diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs index ac8b10e222..b713e0ba69 100644 --- a/talpid-core/src/offline/mod.rs +++ b/talpid-core/src/offline/mod.rs @@ -1,8 +1,6 @@ #[cfg(target_os = "linux")] use crate::routing::RouteManagerHandle; -use crate::tunnel_state_machine::TunnelCommand; use futures::channel::mpsc::UnboundedSender; -use std::sync::Weak; #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; @@ -43,7 +41,7 @@ impl MonitorHandle { } pub async fn spawn_monitor( - sender: Weak<UnboundedSender<TunnelCommand>>, + sender: UnboundedSender<bool>, #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<MonitorHandle, Error> { diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs index d9e5c7782d..84d83fd7cd 100644 --- a/talpid-core/src/offline/windows.rs +++ b/talpid-core/src/offline/windows.rs @@ -1,4 +1,4 @@ -use crate::{logging::windows::log_sink, tunnel_state_machine::TunnelCommand, winnet}; +use crate::{logging::windows::log_sink, winnet}; use futures::channel::mpsc::UnboundedSender; use parking_lot::Mutex; use std::{ @@ -49,16 +49,18 @@ pub struct BroadcastListener { thread_handle: RawHandle, thread_id: DWORD, _system_state: Arc<Mutex<SystemState>>, + _notify_tx: Arc<UnboundedSender<bool>>, } unsafe impl Send for BroadcastListener {} impl BroadcastListener { - pub fn start(sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<Self, Error> { + pub fn start(notify_tx: UnboundedSender<bool>) -> Result<Self, Error> { + let notify_tx = Arc::new(notify_tx); let mut system_state = Arc::new(Mutex::new(SystemState { network_connectivity: None, suspended: false, - daemon_channel: sender, + notify_tx: Arc::downgrade(¬ify_tx), })); let power_broadcast_state_ref = system_state.clone(); @@ -95,6 +97,7 @@ impl BroadcastListener { thread_handle: real_handle, thread_id: unsafe { GetThreadId(real_handle) }, _system_state: system_state, + _notify_tx: notify_tx, }) } @@ -229,7 +232,7 @@ enum StateChange { struct SystemState { network_connectivity: Option<bool>, suspended: bool, - daemon_channel: Weak<UnboundedSender<TunnelCommand>>, + notify_tx: Weak<UnboundedSender<bool>>, } impl SystemState { @@ -247,10 +250,8 @@ impl SystemState { let new_state = self.is_offline_currently(); if old_state != new_state { - if let Some(daemon_channel) = self.daemon_channel.upgrade() { - if let Err(e) = daemon_channel - .unbounded_send(TunnelCommand::IsOffline(new_state.unwrap_or(false))) - { + if let Some(notify_tx) = self.notify_tx.upgrade() { + if let Err(e) = notify_tx.unbounded_send(new_state.unwrap_or(false)) { log::error!("Failed to send new offline state to daemon: {}", e); } } @@ -264,9 +265,7 @@ impl SystemState { pub type MonitorHandle = BroadcastListener; -pub async fn spawn_monitor( - sender: Weak<UnboundedSender<TunnelCommand>>, -) -> Result<MonitorHandle, Error> { +pub async fn spawn_monitor(sender: UnboundedSender<bool>) -> Result<MonitorHandle, Error> { BroadcastListener::start(sender) } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 0b745ac6d2..c4e5e5ef0c 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -104,6 +104,7 @@ pub async fn spawn( resource_dir: PathBuf, cache_dir: impl AsRef<Path> + Send + 'static, state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static, + offline_state_listener: mpsc::UnboundedSender<bool>, shutdown_tx: oneshot::Sender<()>, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { @@ -128,6 +129,7 @@ pub async fn spawn( runtime.clone(), initial_settings, weak_command_tx, + offline_state_listener, tunnel_parameters_generator, tun_provider, log_dir, @@ -216,6 +218,7 @@ impl TunnelStateMachine { runtime: tokio::runtime::Handle, settings: InitialTunnelState, command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, + offline_state_tx: mpsc::UnboundedSender<bool>, tunnel_parameters_generator: impl TunnelParametersGenerator, tun_provider: TunProvider, log_dir: Option<PathBuf>, @@ -247,8 +250,21 @@ impl TunnelStateMachine { .map_err(Error::InitRouteManagerError)?, ) .map_err(Error::InitDnsMonitorError)?; + + let (offline_tx, mut offline_rx) = mpsc::unbounded(); + let initial_offline_state_tx = offline_state_tx.clone(); + tokio::spawn(async move { + while let Some(offline) = offline_rx.next().await { + if let Some(tx) = command_tx.upgrade() { + let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline)); + } else { + break; + } + let _ = offline_state_tx.unbounded_send(offline); + } + }); let mut offline_monitor = offline::spawn_monitor( - command_tx, + offline_tx, #[cfg(target_os = "linux")] route_manager .handle() @@ -259,6 +275,7 @@ impl TunnelStateMachine { .await .map_err(Error::OfflineMonitorError)?; let is_offline = offline_monitor.is_offline().await; + let _ = initial_offline_state_tx.unbounded_send(is_offline); #[cfg(windows)] split_tunnel |
