summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-09-15 13:39:29 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-09-15 13:39:29 +0200
commit301503dcd59185a5f8bb7d9741dcef21319023b6 (patch)
tree51cbf0ab59253931e01c5daadb8465654dc7c264
parente5b9e8ddea9ce8f70f412d5a7b1f70c675161f30 (diff)
parent6856f00b982726051099706a0159e84aac93e607 (diff)
downloadmullvadvpn-301503dcd59185a5f8bb7d9741dcef21319023b6.tar.xz
mullvadvpn-301503dcd59185a5f8bb7d9741dcef21319023b6.zip
Merge branch 'pause-automatic-requests'
-rw-r--r--mullvad-daemon/src/account.rs120
-rw-r--r--mullvad-daemon/src/lib.rs93
-rw-r--r--mullvad-daemon/src/relays.rs51
-rw-r--r--mullvad-daemon/src/version_check.rs17
-rw-r--r--mullvad-daemon/src/wireguard.rs56
-rw-r--r--mullvad-rpc/src/availability.rs127
-rw-r--r--mullvad-rpc/src/lib.rs24
-rw-r--r--mullvad-rpc/src/rest.rs59
-rw-r--r--mullvad-types/src/account.rs7
-rw-r--r--talpid-core/src/offline/android.rs26
-rw-r--r--talpid-core/src/offline/linux.rs15
-rw-r--r--talpid-core/src/offline/macos.rs20
-rw-r--r--talpid-core/src/offline/mod.rs4
-rw-r--r--talpid-core/src/offline/windows.rs21
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs19
15 files changed, 549 insertions, 110 deletions
diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs
new file mode 100644
index 0000000000..88b996743a
--- /dev/null
+++ b/mullvad-daemon/src/account.rs
@@ -0,0 +1,120 @@
+use chrono::{DateTime, Utc};
+use futures::future::{abortable, AbortHandle};
+use mullvad_rpc::{
+ availability::ApiAvailabilityHandle,
+ rest::{self, MullvadRestHandle},
+ AccountsProxy,
+};
+use mullvad_types::account::{AccountToken, VoucherSubmission};
+use std::time::Duration;
+use talpid_core::future_retry::{retry_future_with_backoff, ExponentialBackoff, Jittered};
+
+const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
+const RETRY_INTERVAL_FACTOR: u32 = 5;
+const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
+
+
+pub struct Account(());
+
+#[derive(Clone)]
+pub struct AccountHandle {
+ api_availability: ApiAvailabilityHandle,
+ initial_check_abort_handle: AbortHandle,
+ pub proxy: AccountsProxy,
+}
+
+impl AccountHandle {
+ pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
+ let result = self.proxy.get_expiry(token).await;
+ if handle_expiry_result_inner(&result, &self.api_availability) {
+ self.initial_check_abort_handle.abort();
+ }
+ result
+ }
+
+ pub async fn submit_voucher(
+ &mut self,
+ account_token: AccountToken,
+ voucher: String,
+ ) -> Result<VoucherSubmission, rest::Error> {
+ let result = self.proxy.submit_voucher(account_token, voucher).await;
+ if result.is_ok() {
+ self.initial_check_abort_handle.abort();
+ self.api_availability.resume();
+ }
+ result
+ }
+}
+
+impl Account {
+ pub fn new(
+ runtime: tokio::runtime::Handle,
+ rpc_handle: MullvadRestHandle,
+ token: Option<String>,
+ api_availability: ApiAvailabilityHandle,
+ ) -> AccountHandle {
+ let accounts_proxy = AccountsProxy::new(rpc_handle);
+ api_availability.pause();
+
+ let api_availability_copy = api_availability.clone();
+ let accounts_proxy_copy = accounts_proxy.clone();
+
+ let (future, initial_check_abort_handle) = abortable(async move {
+ let token = if let Some(token) = token {
+ token
+ } else {
+ api_availability.pause();
+ return;
+ };
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
+ .max_delay(RETRY_INTERVAL_MAX),
+ );
+ let future_generator = move || {
+ let wait_online = api_availability.wait_online();
+ let expiry_fut = accounts_proxy.get_expiry(token.clone());
+ let api_availability_copy = api_availability.clone();
+ async move {
+ let _ = wait_online.await;
+ handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy)
+ }
+ };
+ let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
+ let retry_future =
+ retry_future_with_backoff(future_generator, should_retry, retry_strategy);
+ retry_future.await;
+ });
+ runtime.spawn(future);
+
+ AccountHandle {
+ api_availability: api_availability_copy,
+ initial_check_abort_handle,
+ proxy: accounts_proxy_copy,
+ }
+ }
+}
+
+fn handle_expiry_result_inner(
+ result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>,
+ api_availability: &ApiAvailabilityHandle,
+) -> bool {
+ match result {
+ Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
+ api_availability.resume();
+ true
+ }
+ Ok(_expiry) => {
+ api_availability.pause();
+ true
+ }
+ Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => {
+ if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH {
+ api_availability.pause();
+ return true;
+ }
+ false
+ }
+ Err(_) => false,
+ }
+}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 7394604f87..818c75b908 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -5,6 +5,7 @@
extern crate serde;
+mod account;
pub mod account_history;
pub mod exception_logging;
mod geoip;
@@ -25,7 +26,7 @@ use futures::{
SinkExt, StreamExt,
};
use log::{debug, error, info, warn};
-use mullvad_rpc::AccountsProxy;
+use mullvad_rpc::availability::ApiAvailabilityHandle;
use mullvad_types::{
account::{AccountData, AccountToken, VoucherSubmission},
endpoint::MullvadEndpoint,
@@ -105,6 +106,9 @@ pub enum Error {
#[error(display = "REST request failed")]
RestError(#[error(source)] mullvad_rpc::rest::Error),
+ #[error(display = "API availability check failed")]
+ ApiCheckError(#[error(source)] mullvad_rpc::availability::Error),
+
#[error(display = "Unable to load account history")]
LoadAccountHistory(#[error(source)] account_history::Error),
@@ -509,7 +513,7 @@ pub struct Daemon<L: EventListener> {
event_listener: L,
settings: SettingsPersister,
account_history: account_history::AccountHistory,
- accounts_proxy: AccountsProxy,
+ account: account::AccountHandle,
rpc_runtime: mullvad_rpc::MullvadRpcRuntime,
rpc_handle: mullvad_rpc::rest::MullvadRestHandle,
wireguard_key_manager: wireguard::KeyManager,
@@ -539,7 +543,6 @@ where
) -> Result<Self, Error> {
let (tunnel_state_machine_shutdown_tx, tunnel_state_machine_shutdown_signal) =
oneshot::channel();
-
let runtime = tokio::runtime::Handle::current();
let (internal_event_tx, internal_event_rx) = command_channel.destructure();
@@ -571,17 +574,19 @@ where
.await
.map_err(Error::InitRpcFactory)?;
let rpc_handle = rpc_runtime.mullvad_rest_handle();
+ let api_availability = rpc_runtime.availability_handle();
let relay_list_listener = event_listener.clone();
let on_relay_list_update = move |relay_list: &RelayList| {
relay_list_listener.notify_relay_list(relay_list.clone());
};
- let mut relay_selector = relays::RelaySelector::new(
+ let relay_selector = relays::RelaySelector::new(
rpc_handle.clone(),
on_relay_list_update,
&resource_dir,
&cache_dir,
+ api_availability.clone(),
);
@@ -594,6 +599,7 @@ where
let app_version_info = version_check::load_cache(&cache_dir).await;
let (version_updater, version_updater_handle) = version_check::VersionUpdater::new(
rpc_handle.clone(),
+ api_availability.clone(),
cache_dir.clone(),
internal_event_tx.to_specialized_sender(),
app_version_info.clone(),
@@ -667,8 +673,10 @@ where
vec![]
};
+ let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
+
let tunnel_command_tx = tunnel_state_machine::spawn(
- runtime,
+ runtime.clone(),
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
@@ -683,6 +691,7 @@ where
resource_dir,
cache_dir.clone(),
internal_event_tx.to_specialized_sender(),
+ offline_state_tx,
tunnel_state_machine_shutdown_tx,
#[cfg(target_os = "android")]
android_context,
@@ -690,6 +699,8 @@ where
.await
.map_err(Error::TunnelError)?;
+ Self::forward_offline_state(&runtime, api_availability.clone(), offline_state_rx).await;
+
let tsm_api_address_change_tx = Arc::downgrade(&tunnel_command_tx);
tokio::spawn(async move {
while let Some(address_change) = address_change_rx.next().await {
@@ -701,11 +712,25 @@ where
}
});
- let wireguard_key_manager =
- wireguard::KeyManager::new(internal_event_tx.clone(), rpc_handle.clone());
+ let wireguard_key_manager = wireguard::KeyManager::new(
+ internal_event_tx.clone(),
+ api_availability.clone(),
+ rpc_handle.clone(),
+ );
+
+ let account = account::Account::new(
+ runtime,
+ rpc_handle.clone(),
+ settings.get_account_token(),
+ api_availability.clone(),
+ );
// Attempt to download a fresh relay list
- relay_selector.update().await;
+ let mut relay_handle = relay_selector.updater_handle();
+ relay_handle
+ .update_relay_list_deferred()
+ .await
+ .expect("Relay list updated thread has stopped unexpectedly");
let mut daemon = Daemon {
tunnel_command_tx,
@@ -721,8 +746,8 @@ where
event_listener,
settings,
account_history,
+ account,
rpc_runtime,
- accounts_proxy: AccountsProxy::new(rpc_handle.clone()),
rpc_handle,
wireguard_key_manager,
version_updater_handle,
@@ -1418,7 +1443,7 @@ where
async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) {
let daemon_tx = self.tx.clone();
- let future = self.accounts_proxy.create_account();
+ let future = self.account.proxy.create_account();
tokio::spawn(async move {
match future.await {
@@ -1437,17 +1462,20 @@ where
tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>,
account_token: AccountToken,
) {
- let expiry_fut = self.accounts_proxy.get_expiry(account_token);
- let rpc_call = async {
- let result = expiry_fut.await.map(|expiry| AccountData { expiry });
- Self::oneshot_send(tx, result, "account data");
- };
- tokio::spawn(rpc_call);
+ let account = self.account.clone();
+ tokio::spawn(async move {
+ let result = account.check_expiry(account_token).await;
+ Self::oneshot_send(
+ tx,
+ result.map(|expiry| AccountData { expiry }),
+ "account data",
+ );
+ });
}
async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) {
if let Some(account_token) = self.settings.get_account_token() {
- let future = self.accounts_proxy.get_www_auth_token(account_token);
+ let future = self.account.proxy.get_www_auth_token(account_token);
let rpc_call = async {
Self::oneshot_send(
tx,
@@ -1471,15 +1499,17 @@ where
voucher: String,
) {
if let Some(account_token) = self.settings.get_account_token() {
- let future = self.accounts_proxy.submit_voucher(account_token, voucher);
- let rpc_call = async {
+ let mut account = self.account.clone();
+ tokio::spawn(async move {
Self::oneshot_send(
tx,
- future.await.map_err(Error::RestError),
+ account
+ .submit_voucher(account_token, voucher)
+ .await
+ .map_err(Error::RestError),
"submit_voucher response",
);
- };
- tokio::spawn(rpc_call);
+ });
} else {
Self::oneshot_send(tx, Err(Error::NoAccountToken), "submit_voucher response");
}
@@ -2252,6 +2282,7 @@ where
}
Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys),
Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
+ Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
}
}
@@ -2291,6 +2322,7 @@ where
let result = match verification_rpc.await {
Ok(is_valid) => Ok(is_valid),
Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
+ Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
Err(wireguard::Error::TooManyKeys) => return,
};
Self::oneshot_send(tx, result, "verify_wireguard_key response");
@@ -2351,6 +2383,23 @@ where
Some(bypass_tx)
}
+ async fn forward_offline_state(
+ runtime: &tokio::runtime::Handle,
+ api_availability: ApiAvailabilityHandle,
+ mut offline_state_rx: mpsc::UnboundedReceiver<bool>,
+ ) {
+ let initial_state = offline_state_rx
+ .next()
+ .await
+ .expect("missing initial offline state");
+ api_availability.set_offline(initial_state);
+ runtime.spawn(async move {
+ while let Some(is_offline) = offline_state_rx.next().await {
+ api_availability.set_offline(is_offline);
+ }
+ });
+ }
+
/// Set the target state of the client. If it changed trigger the operations needed to
/// progress towards that state.
/// Returns a bool representing whether or not a state change was initiated.
diff --git a/mullvad-daemon/src/relays.rs b/mullvad-daemon/src/relays.rs
index ad1a038817..607724a534 100644
--- a/mullvad-daemon/src/relays.rs
+++ b/mullvad-daemon/src/relays.rs
@@ -9,7 +9,7 @@ use futures::{
};
use ipnetwork::IpNetwork;
use log::{debug, error, info, warn};
-use mullvad_rpc::{rest::MullvadRestHandle, RelayListProxy};
+use mullvad_rpc::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, RelayListProxy};
use mullvad_types::{
endpoint::MullvadEndpoint,
location::Location,
@@ -187,6 +187,7 @@ impl RelaySelector {
on_update: impl Fn(&RelayList) + Send + 'static,
resource_dir: &Path,
cache_dir: &Path,
+ api_availability: ApiAvailabilityHandle,
) -> Self {
let cache_path = cache_dir.join(RELAYS_FILENAME);
let resource_path = resource_dir.join(RELAYS_FILENAME);
@@ -211,6 +212,7 @@ impl RelaySelector {
cache_path,
parsed_relays.clone(),
Box::new(on_update),
+ api_availability,
);
@@ -232,6 +234,10 @@ impl RelaySelector {
}
}
+ pub fn updater_handle(&self) -> RelayListUpdaterHandle {
+ self.updater.as_ref().unwrap().clone()
+ }
+
/// Returns all countries and cities. The cities in the object returned does not have any
/// relays in them.
pub fn get_locations(&mut self) -> RelayList {
@@ -986,13 +992,20 @@ impl RelaySelector {
#[derive(Clone)]
pub struct RelayListUpdaterHandle {
- tx: mpsc::Sender<()>,
+ tx: mpsc::Sender<bool>,
}
impl RelayListUpdaterHandle {
- async fn update_relay_list(&mut self) -> Result<(), Error> {
+ pub async fn update_relay_list(&mut self) -> Result<(), Error> {
self.tx
- .send(())
+ .send(false)
+ .await
+ .map_err(|_| Error::DownloaderShutDown)
+ }
+
+ pub async fn update_relay_list_deferred(&mut self) -> Result<(), Error> {
+ self.tx
+ .send(true)
.await
.map_err(|_| Error::DownloaderShutDown)
}
@@ -1004,6 +1017,7 @@ struct RelayListUpdater {
parsed_relays: Arc<Mutex<ParsedRelays>>,
on_update: Box<dyn Fn(&RelayList) + Send + 'static>,
earliest_next_try: Instant,
+ api_availability: ApiAvailabilityHandle,
}
impl RelayListUpdater {
@@ -1012,6 +1026,7 @@ impl RelayListUpdater {
cache_path: PathBuf,
parsed_relays: Arc<Mutex<ParsedRelays>>,
on_update: Box<dyn Fn(&RelayList) + Send + 'static>,
+ api_availability: ApiAvailabilityHandle,
) -> RelayListUpdaterHandle {
let (tx, cmd_rx) = mpsc::channel(1);
let service = rpc_handle.service();
@@ -1022,6 +1037,7 @@ impl RelayListUpdater {
parsed_relays,
on_update,
earliest_next_try: Instant::now() + UPDATE_INTERVAL,
+ api_availability,
};
service.spawn(updater.run(cmd_rx));
@@ -1029,7 +1045,7 @@ impl RelayListUpdater {
RelayListUpdaterHandle { tx }
}
- async fn run(mut self, mut cmd_rx: mpsc::Receiver<()>) {
+ async fn run(mut self, mut cmd_rx: mpsc::Receiver<bool>) {
let mut check_interval = tokio_stream::wrappers::IntervalStream::new(
tokio::time::interval(UPDATE_CHECK_INTERVAL),
)
@@ -1040,7 +1056,7 @@ impl RelayListUpdater {
_check_update = check_interval.next() => {
if download_future.is_terminated() && self.should_update() {
let tag = self.parsed_relays.lock().tag().map(|tag| tag.to_string());
- download_future = Box::pin(Self::download_relay_list(self.rpc_client.clone(), tag).fuse());
+ download_future = Box::pin(Self::download_relay_list(self.api_availability.clone(), self.rpc_client.clone(), tag).fuse());
self.earliest_next_try = Instant::now() + UPDATE_INTERVAL;
}
},
@@ -1052,9 +1068,14 @@ impl RelayListUpdater {
cmd = cmd_rx.next() => {
match cmd {
- Some(_) => {
+ Some(defer) => {
let tag = self.parsed_relays.lock().tag().map(|tag| tag.to_string());
- self.consume_new_relay_list(self.rpc_client.relay_list(tag).await).await;
+ if defer {
+ let download_future = Self::download_relay_list(self.api_availability.clone(), self.rpc_client.clone(), tag);
+ self.consume_new_relay_list(download_future.await).await;
+ } else {
+ self.consume_new_relay_list(self.rpc_client.relay_list(tag).await.map_err(mullvad_rpc::Error::from)).await;
+ }
},
None => {
log::error!("Relay list updater shutting down");
@@ -1069,7 +1090,7 @@ impl RelayListUpdater {
async fn consume_new_relay_list(
&mut self,
- result: Result<Option<RelayList>, mullvad_rpc::rest::Error>,
+ result: Result<Option<RelayList>, mullvad_rpc::Error>,
) {
match result {
Ok(Some(relay_list)) => {
@@ -1103,10 +1124,18 @@ impl RelayListUpdater {
}
fn download_relay_list(
+ api_handle: ApiAvailabilityHandle,
rpc_handle: RelayListProxy,
tag: Option<String>,
- ) -> impl Future<Output = Result<Option<RelayList>, mullvad_rpc::rest::Error>> + 'static {
- let download_futures = move || rpc_handle.relay_list(tag.clone());
+ ) -> impl Future<Output = Result<Option<RelayList>, mullvad_rpc::Error>> + 'static {
+ let download_futures = move || {
+ let available = api_handle.wait_available();
+ let req = rpc_handle.relay_list(tag.clone());
+ async move {
+ available.await?;
+ req.await.map_err(mullvad_rpc::Error::from)
+ }
+ };
let exponential_backoff =
ExponentialBackoff::new(EXPONENTIAL_BACKOFF_INITIAL, EXPONENTIAL_BACKOFF_FACTOR)
diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs
index d0d5c13d8c..b06466bb1e 100644
--- a/mullvad-daemon/src/version_check.rs
+++ b/mullvad-daemon/src/version_check.rs
@@ -3,7 +3,7 @@ use crate::{
DaemonEventSender,
};
use futures::{channel::mpsc, stream::FusedStream, FutureExt, SinkExt, StreamExt, TryFutureExt};
-use mullvad_rpc::{rest::MullvadRestHandle, AppVersionProxy};
+use mullvad_rpc::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, AppVersionProxy};
use mullvad_types::version::{AppVersionInfo, ParsedAppVersion};
use serde::{Deserialize, Serialize};
use std::{
@@ -78,6 +78,9 @@ pub enum Error {
#[error(display = "Failed to check the latest app version")]
Download(#[error(source)] mullvad_rpc::rest::Error),
+ #[error(display = "API availability check failed")]
+ ApiCheck(#[error(source)] mullvad_rpc::availability::Error),
+
#[error(display = "Clearing version check cache due to a version mismatch")]
CacheVersionMismatch,
}
@@ -92,6 +95,7 @@ pub(crate) struct VersionUpdater {
next_update_time: Instant,
show_beta_releases: bool,
rx: Option<mpsc::Receiver<VersionUpdaterCommand>>,
+ availability_handle: ApiAvailabilityHandle,
}
#[derive(Clone)]
@@ -133,6 +137,7 @@ impl VersionUpdaterHandle {
impl VersionUpdater {
pub fn new(
mut rpc_handle: MullvadRestHandle,
+ availability_handle: ApiAvailabilityHandle,
cache_dir: PathBuf,
update_sender: DaemonEventSender<AppVersionInfo>,
last_app_version_info: Option<AppVersionInfo>,
@@ -154,6 +159,7 @@ impl VersionUpdater {
next_update_time: Instant::now(),
show_beta_releases,
rx: Some(rx),
+ availability_handle,
},
VersionUpdaterHandle { tx },
)
@@ -162,15 +168,20 @@ impl VersionUpdater {
fn create_update_future(
&self,
) -> impl Future<Output = Result<mullvad_rpc::AppVersionResponse, Error>> + Send + 'static {
+ let api_handle = self.availability_handle.clone();
let version_proxy = self.version_proxy.clone();
let platform_version = self.platform_version.clone();
let download_future_factory = move || {
- let response = version_proxy.version_check(
+ let when_available = api_handle.wait_available();
+ let request = version_proxy.version_check(
PRODUCT_VERSION.to_owned(),
PLATFORM,
platform_version.clone(),
);
- response.map_err(Error::Download)
+ async move {
+ when_available.await.map_err(Error::ApiCheck)?;
+ request.await.map_err(Error::Download)
+ }
};
let should_retry = |result: &Result<_, _>| -> bool { result.is_err() };
diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs
index 1c548d3e8d..c0c260de29 100644
--- a/mullvad-daemon/src/wireguard.rs
+++ b/mullvad-daemon/src/wireguard.rs
@@ -1,6 +1,9 @@
use crate::{DaemonEventSender, InternalDaemonEvent};
use chrono::offset::Utc;
-use mullvad_rpc::rest::{Error as RestError, MullvadRestHandle};
+use mullvad_rpc::{
+ availability::ApiAvailabilityHandle,
+ rest::{Error as RestError, MullvadRestHandle},
+};
use mullvad_types::account::AccountToken;
pub use mullvad_types::wireguard::*;
use std::{future::Future, pin::Pin, time::Duration};
@@ -31,6 +34,8 @@ const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
pub enum Error {
#[error(display = "Unexpected HTTP request error")]
RestError(#[error(source)] mullvad_rpc::rest::Error),
+ #[error(display = "API availability check was interrupted")]
+ ApiCheckError(#[error(source)] mullvad_rpc::availability::Error),
#[error(display = "Account already has maximum number of keys")]
TooManyKeys,
}
@@ -39,6 +44,7 @@ pub type Result<T> = std::result::Result<T, Error>;
pub struct KeyManager {
daemon_tx: DaemonEventSender,
+ availability_handle: ApiAvailabilityHandle,
http_handle: MullvadRestHandle,
current_job: Option<AbortHandle>,
@@ -47,9 +53,14 @@ pub struct KeyManager {
}
impl KeyManager {
- pub(crate) fn new(daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle) -> Self {
+ pub(crate) fn new(
+ daemon_tx: DaemonEventSender,
+ availability_handle: ApiAvailabilityHandle,
+ http_handle: MullvadRestHandle,
+ ) -> Self {
Self {
daemon_tx,
+ availability_handle,
http_handle,
current_job: None,
abort_scheduler_tx: None,
@@ -164,11 +175,22 @@ impl KeyManager {
let mut inner_future_generator =
self.push_future_generator(account.clone(), private_key, timeout);
+ let availability_handle = self.availability_handle.clone();
+
let future_generator = move || {
+ let wait_available = availability_handle.wait_available();
let fut = inner_future_generator();
let error_tx = error_tx.clone();
let error_account = error_account.clone();
async move {
+ let error_account_copy = error_account.clone();
+ wait_available.await.map_err(|error| {
+ let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent((
+ error_account_copy,
+ Err(Error::ApiCheckError(error)),
+ )));
+ false
+ })?;
let response = fut.await;
match response {
Ok(addresses) => Ok(addresses),
@@ -299,6 +321,7 @@ impl KeyManager {
async fn create_automatic_rotation(
daemon_tx: DaemonEventSender,
+ availability_handle: ApiAvailabilityHandle,
http_handle: MullvadRestHandle,
mut public_key: PublicKey,
rotation_interval_secs: u64,
@@ -306,14 +329,20 @@ impl KeyManager {
) {
tokio::time::sleep(ROTATION_START_DELAY).await;
- let rotate_key_for_account = move |old_key: &PublicKey| {
- Self::rotate_key(
- daemon_tx.clone(),
- http_handle.clone(),
- account_token.clone(),
- old_key.clone(),
- )
- };
+ let rotate_key_for_account =
+ move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> {
+ let wait_available = availability_handle.wait_available();
+ let rotate = Self::rotate_key(
+ daemon_tx.clone(),
+ http_handle.clone(),
+ account_token.clone(),
+ old_key.clone(),
+ );
+ Box::pin(async move {
+ wait_available.await?;
+ rotate.await
+ })
+ };
loop {
Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await;
@@ -341,12 +370,12 @@ impl KeyManager {
http_handle: MullvadRestHandle,
account_token: AccountToken,
old_key: PublicKey,
- ) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> {
+ ) -> impl Future<Output = Result<PublicKey>> {
let new_key = PrivateKey::new_from_random();
let rpc_result =
Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key);
- Box::pin(async move {
+ async move {
match rpc_result.await {
Ok(data) => {
// Update account data
@@ -365,7 +394,7 @@ impl KeyManager {
}
Err(unknown) => Err(unknown),
}
- })
+ }
}
async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey>
@@ -403,6 +432,7 @@ impl KeyManager {
// Schedule cancellable series of repeating rotation tasks
let fut = Self::create_automatic_rotation(
self.daemon_tx.clone(),
+ self.availability_handle.clone(),
self.http_handle.clone(),
public_key,
self.auto_rotation_interval.as_duration().as_secs(),
diff --git a/mullvad-rpc/src/availability.rs b/mullvad-rpc/src/availability.rs
new file mode 100644
index 0000000000..227bc0cd35
--- /dev/null
+++ b/mullvad-rpc/src/availability.rs
@@ -0,0 +1,127 @@
+use std::{
+ future::Future,
+ sync::{Arc, Mutex},
+};
+use tokio::sync::broadcast;
+
+
+const CHANNEL_CAPACITY: usize = 100;
+
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// The [`ApiAvailability`] instance was dropped, or the receiver lagged behind.
+ #[error(display = "API availability instance was dropped")]
+ Interrupted(#[error(source)] broadcast::error::RecvError),
+}
+
+
+#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)]
+pub struct State {
+ pause_automatic: bool,
+ offline: bool,
+}
+
+impl State {
+ pub fn is_paused(&self) -> bool {
+ self.pause_automatic
+ }
+
+ pub fn is_offline(&self) -> bool {
+ self.offline
+ }
+
+ pub fn is_available(&self) -> bool {
+ !self.is_paused() && !self.is_offline()
+ }
+}
+
+pub struct ApiAvailability {
+ state: Arc<Mutex<State>>,
+ tx: broadcast::Sender<State>,
+}
+
+impl ApiAvailability {
+ pub fn new(initial_state: State) -> Self {
+ let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
+ let state = Arc::new(Mutex::new(initial_state));
+ ApiAvailability { state, tx }
+ }
+
+ pub fn get_state(&self) -> State {
+ *self.state.lock().unwrap()
+ }
+
+ pub fn handle(&self) -> ApiAvailabilityHandle {
+ ApiAvailabilityHandle {
+ state: self.state.clone(),
+ tx: self.tx.clone(),
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ApiAvailabilityHandle {
+ state: Arc<Mutex<State>>,
+ tx: broadcast::Sender<State>,
+}
+
+impl ApiAvailabilityHandle {
+ pub fn pause(&self) {
+ let mut state = self.state.lock().unwrap();
+ if !state.pause_automatic {
+ state.pause_automatic = true;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn resume(&self) {
+ let mut state = self.state.lock().unwrap();
+ if state.pause_automatic {
+ state.pause_automatic = false;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn set_offline(&self, offline: bool) {
+ let mut state = self.state.lock().unwrap();
+ if state.offline != offline {
+ state.offline = offline;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn get_state(&self) -> State {
+ *self.state.lock().unwrap()
+ }
+
+ pub fn wait_available(&self) -> impl Future<Output = Result<(), Error>> {
+ self.wait_for_state(|state| state.is_available())
+ }
+
+ pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> {
+ self.wait_for_state(|state| !state.is_offline())
+ }
+
+ fn wait_for_state(
+ &self,
+ state_ready: impl Fn(State) -> bool,
+ ) -> impl Future<Output = Result<(), Error>> {
+ let mut rx = self.tx.subscribe();
+ let state = self.state.clone();
+
+ async move {
+ let current_state = { *state.lock().unwrap() };
+ if state_ready(current_state) {
+ return Ok(());
+ }
+
+ loop {
+ let new_state = rx.recv().await?;
+ if state_ready(new_state) {
+ return Ok(());
+ }
+ }
+ }
+ }
+}
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 96260775f8..098ed2f0b4 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -18,6 +18,8 @@ use std::{
use talpid_types::{net::wireguard, ErrorExt};
+pub mod availability;
+use availability::{ApiAvailability, ApiAvailabilityHandle};
pub mod rest;
mod https_client_with_sni;
@@ -41,6 +43,9 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER";
/// Error code returned by the Mullvad API if the account token is invalid.
pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT";
+/// Error code returned by the Mullvad API if the account token is missing or invalid.
+pub const INVALID_AUTH: &str = "INVALID_AUTH";
+
const API_HOST: &str = "api.mullvad.net";
pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt";
const API_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78));
@@ -51,6 +56,7 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443);
pub struct MullvadRpcRuntime {
handle: tokio::runtime::Handle,
pub address_cache: AddressCache,
+ api_availability: availability::ApiAvailability,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -62,6 +68,9 @@ pub enum Error {
#[error(display = "Failed to load address cache")]
AddressCacheError(#[error(source)] address_cache::Error),
+
+ #[error(display = "API availability check failed")]
+ ApiCheckError(#[error(source)] availability::Error),
}
impl MullvadRpcRuntime {
@@ -74,6 +83,7 @@ impl MullvadRpcRuntime {
None,
Arc::new(Box::new(|_| Ok(()))),
)?,
+ api_availability: ApiAvailability::new(availability::State::default()),
#[cfg(target_os = "android")]
socket_bypass_tx: None,
})
@@ -139,6 +149,7 @@ impl MullvadRpcRuntime {
Ok(MullvadRpcRuntime {
handle,
address_cache,
+ api_availability: ApiAvailability::new(availability::State::default()),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
@@ -156,6 +167,7 @@ impl MullvadRpcRuntime {
let service = rest::RequestService::new(
https_connector,
self.handle.clone(),
+ self.api_availability.handle(),
self.address_cache.clone(),
);
let handle = service.handle();
@@ -172,7 +184,12 @@ impl MullvadRpcRuntime {
Some("app".to_owned()),
);
- rest::MullvadRestHandle::new(service, factory, self.address_cache.clone())
+ rest::MullvadRestHandle::new(
+ service,
+ factory,
+ self.address_cache.clone(),
+ self.availability_handle(),
+ )
}
/// Returns a new request service handle
@@ -183,8 +200,13 @@ impl MullvadRpcRuntime {
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
+
+ pub fn availability_handle(&self) -> ApiAvailabilityHandle {
+ self.api_availability.handle()
+ }
}
+#[derive(Clone)]
pub struct AccountsProxy {
handle: rest::MullvadRestHandle,
}
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index be4c1dc990..77bf06fd55 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,6 +1,6 @@
use crate::{
- address_cache::AddressCache, https_client_with_sni::HttpsConnectorWithSni,
- tcp_stream::TcpStreamHandle,
+ address_cache::AddressCache, availability::ApiAvailabilityHandle,
+ https_client_with_sni::HttpsConnectorWithSni, tcp_stream::TcpStreamHandle,
};
use futures::{
channel::{mpsc, oneshot},
@@ -22,6 +22,7 @@ use std::{
str::FromStr,
time::{Duration, Instant},
};
+use talpid_types::ErrorExt;
use tokio::runtime::Handle;
pub use hyper::StatusCode;
@@ -84,6 +85,7 @@ pub(crate) struct RequestService {
handle: Handle,
next_id: u64,
in_flight_requests: BTreeMap<u64, AbortHandle>,
+ api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
}
@@ -92,6 +94,7 @@ impl RequestService {
pub fn new(
mut connector: HttpsConnectorWithSni,
handle: Handle,
+ api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
) -> RequestService {
let (command_tx, command_rx) = mpsc::channel(1);
@@ -99,7 +102,6 @@ impl RequestService {
connector.set_service_tx(command_tx.clone());
let client = Client::builder().build(connector);
-
Self {
command_tx,
command_rx,
@@ -108,6 +110,7 @@ impl RequestService {
in_flight_requests: BTreeMap::new(),
next_id: 0,
handle,
+ api_availability,
address_cache,
}
}
@@ -134,6 +137,7 @@ impl RequestService {
abortable(self.client.request(hyper_request).map_err(Error::from));
let address_cache = self.address_cache.clone();
let handle = self.handle.clone();
+ let api_availability = self.api_availability.clone();
let future = async move {
let response =
@@ -146,20 +150,25 @@ impl RequestService {
if let Err(err) = &response {
match err {
Error::HyperError(_) | Error::TimeoutError(_) => {
- log::error!("HTTP request failed: {}", err);
- let current_address = address_cache.peek_address();
- if current_address == host_addr
- && address_cache.has_tried_current_address()
- {
- handle.spawn(async move {
- address_cache.select_new_address().await;
- let new_address = address_cache.peek_address();
- log::error!(
- "Request failed using address {}. Trying next API address: {}",
- current_address,
- new_address,
- );
- });
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("HTTP request failed")
+ );
+ if !api_availability.get_state().is_offline() {
+ let current_address = address_cache.peek_address();
+ if current_address == host_addr
+ && address_cache.has_tried_current_address()
+ {
+ handle.spawn(async move {
+ address_cache.select_new_address().await;
+ let new_address = address_cache.peek_address();
+ log::error!(
+ "Request failed using address {}. Trying next API address: {}",
+ current_address,
+ new_address,
+ );
+ });
+ }
}
}
_ => (),
@@ -594,6 +603,7 @@ pub async fn handle_error_response<T>(response: Response) -> Result<T> {
pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
+ availability: ApiAvailabilityHandle,
}
impl MullvadRestHandle {
@@ -601,8 +611,13 @@ impl MullvadRestHandle {
service: RequestServiceHandle,
factory: RequestFactory,
address_cache: AddressCache,
+ availability: ApiAvailabilityHandle,
) -> Self {
- let handle = Self { service, factory };
+ let handle = Self {
+ service,
+ factory,
+ availability,
+ };
handle.spawn_api_address_fetcher(address_cache);
handle
@@ -610,10 +625,11 @@ impl MullvadRestHandle {
fn spawn_api_address_fetcher(&self, address_cache: AddressCache) {
let handle = self.clone();
+ let availability = self.availability.clone();
self.service.spawn(async move {
// always start the fetch after 15 minutes
- let api_proxy = crate::ApiProxy { handle };
+ let api_proxy = crate::ApiProxy::new(handle);
let mut next_check = Instant::now() + API_IP_CHECK_DELAY;
let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL;
@@ -624,6 +640,11 @@ impl MullvadRestHandle {
loop {
interval.tick().await;
if next_check < Instant::now() {
+ if let Err(error) = availability.wait_available().await {
+ log::error!("Failed while waiting for API: {}", error);
+ next_check = next_error_check();
+ continue;
+ }
match api_proxy.clone().get_api_addrs().await {
Ok(new_addrs) => {
log::debug!("Fetched new API addresses {:?}, will fetch again in {} hours", new_addrs, API_IP_CHECK_INTERVAL.as_secs() / ( 60 * 60 ));
diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs
index 0acb0941aa..b5479640e6 100644
--- a/mullvad-types/src/account.rs
+++ b/mullvad-types/src/account.rs
@@ -15,6 +15,13 @@ pub struct AccountData {
pub expiry: DateTime<Utc>,
}
+impl AccountData {
+ /// Return true if the account has no time left.
+ pub fn is_expired(&self) -> bool {
+ self.expiry >= Utc::now()
+ }
+}
+
/// Data structure that's returned from successful invocation of the mullvad API's
/// `/v1/submit-voucher` RPC.
#[derive(Deserialize, Serialize, Debug)]
diff --git a/talpid-core/src/offline/android.rs b/talpid-core/src/offline/android.rs
index fefe2556cf..65f0e7cf58 100644
--- a/talpid-core/src/offline/android.rs
+++ b/talpid-core/src/offline/android.rs
@@ -1,4 +1,3 @@
-use crate::tunnel_state_machine::TunnelCommand;
use futures::channel::mpsc::UnboundedSender;
use jnix::{
jni::{
@@ -44,10 +43,14 @@ pub struct MonitorHandle {
jvm: Arc<JavaVM>,
class: GlobalRef,
object: GlobalRef,
+ _sender: Arc<UnboundedSender<bool>>,
}
impl MonitorHandle {
- pub fn new(android_context: AndroidContext) -> Result<Self, Error> {
+ pub fn new(
+ android_context: AndroidContext,
+ sender: Arc<UnboundedSender<bool>>,
+ ) -> Result<Self, Error> {
let env = JnixEnv::from(
android_context
.jvm
@@ -93,6 +96,7 @@ impl MonitorHandle {
jvm: android_context.jvm,
class,
object,
+ _sender: sender,
})
}
@@ -128,7 +132,7 @@ impl MonitorHandle {
}
}
- fn set_sender(&self, sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<(), Error> {
+ fn set_sender(&self, sender: Weak<UnboundedSender<bool>>) -> Result<(), Error> {
let sender_ptr = Box::new(sender);
let sender_address = Box::into_raw(sender_ptr) as jlong;
@@ -181,10 +185,10 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnec
sender_address: jlong,
) {
let sender_ref = Box::leak(unsafe { get_sender_from_address(sender_address) });
- let tunnel_command = TunnelCommand::IsOffline(is_connected == JNI_FALSE);
+ let is_offline = is_connected == JNI_FALSE;
if let Some(sender) = sender_ref.upgrade() {
- if sender.unbounded_send(tunnel_command).is_err() {
+ if sender.unbounded_send(is_offline).is_err() {
log::warn!("Failed to send offline change event");
}
}
@@ -201,17 +205,19 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySende
let _ = unsafe { get_sender_from_address(sender_address) };
}
-unsafe fn get_sender_from_address(address: jlong) -> Box<Weak<UnboundedSender<TunnelCommand>>> {
- Box::from_raw(address as *mut Weak<UnboundedSender<TunnelCommand>>)
+unsafe fn get_sender_from_address(address: jlong) -> Box<Weak<UnboundedSender<bool>>> {
+ Box::from_raw(address as *mut Weak<UnboundedSender<bool>>)
}
pub async fn spawn_monitor(
- sender: Weak<UnboundedSender<TunnelCommand>>,
+ sender: UnboundedSender<bool>,
android_context: AndroidContext,
) -> Result<MonitorHandle, Error> {
- let monitor_handle = MonitorHandle::new(android_context)?;
+ let sender = Arc::new(sender);
+ let weak_sender = Arc::downgrade(&sender);
+ let monitor_handle = MonitorHandle::new(android_context, sender)?;
- monitor_handle.set_sender(sender)?;
+ monitor_handle.set_sender(weak_sender)?;
Ok(monitor_handle)
}
diff --git a/talpid-core/src/offline/linux.rs b/talpid-core/src/offline/linux.rs
index ceaa864cc7..f9e137853b 100644
--- a/talpid-core/src/offline/linux.rs
+++ b/talpid-core/src/offline/linux.rs
@@ -1,11 +1,8 @@
-use crate::{
- routing::{self, RouteManagerHandle},
- tunnel_state_machine::TunnelCommand,
-};
+use crate::routing::{self, RouteManagerHandle};
use futures::{channel::mpsc::UnboundedSender, StreamExt};
use std::{
net::{IpAddr, Ipv4Addr},
- sync::Weak,
+ sync::Arc,
};
use talpid_types::ErrorExt;
@@ -20,6 +17,7 @@ pub enum Error {
pub struct MonitorHandle {
route_manager: RouteManagerHandle,
+ _notify_tx: Arc<UnboundedSender<bool>>,
}
// Mullvad API's public IP address, correct at the time of writing, but any public IP address will
@@ -42,7 +40,7 @@ impl MonitorHandle {
}
pub async fn spawn_monitor(
- sender: Weak<UnboundedSender<TunnelCommand>>,
+ notify_tx: UnboundedSender<bool>,
route_manager: RouteManagerHandle,
) -> Result<MonitorHandle> {
let mut is_offline = public_ip_unreachable(&route_manager).await?;
@@ -52,8 +50,11 @@ pub async fn spawn_monitor(
.await
.map_err(Error::RouteManagerError)?;
+ let notify_tx = Arc::new(notify_tx);
+ let sender = Arc::downgrade(&notify_tx);
let monitor_handle = MonitorHandle {
route_manager: route_manager.clone(),
+ _notify_tx: notify_tx,
};
tokio::spawn(async move {
@@ -71,7 +72,7 @@ pub async fn spawn_monitor(
});
if new_offline_state != is_offline {
is_offline = new_offline_state;
- let _ = sender.unbounded_send(TunnelCommand::IsOffline(is_offline));
+ let _ = sender.unbounded_send(is_offline);
}
}
None => return,
diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs
index 2569fa06c6..3e374cf29c 100644
--- a/talpid-core/src/offline/macos.rs
+++ b/talpid-core/src/offline/macos.rs
@@ -1,4 +1,3 @@
-use crate::tunnel_state_machine::TunnelCommand;
use futures::channel::mpsc::UnboundedSender;
use std::{
net::{Ipv4Addr, SocketAddr},
@@ -39,7 +38,9 @@ pub enum Error {
InitializationError,
}
-pub struct MonitorHandle;
+pub struct MonitorHandle {
+ _notify_tx: Arc<UnboundedSender<bool>>,
+}
impl MonitorHandle {
/// Host is considered to be offline if the IPv4 internet is considered to be unreachable by the
@@ -54,10 +55,10 @@ impl MonitorHandle {
}
}
-pub async fn spawn_monitor(
- sender: Weak<UnboundedSender<TunnelCommand>>,
-) -> Result<MonitorHandle, Error> {
+pub async fn spawn_monitor(notify_tx: UnboundedSender<bool>) -> Result<MonitorHandle, Error> {
let (result_tx, result_rx) = mpsc::channel();
+ let notify_tx = Arc::new(notify_tx);
+ let sender = Arc::downgrade(&notify_tx);
thread::spawn(move || {
let mut reachability_ref = SCNetworkReachability::from(ipv4_internet());
let store = SCDynamicStoreBuilder::new("talpid-offline-watcher").build();
@@ -108,7 +109,9 @@ pub async fn spawn_monitor(
});
let _ = result_rx.recv().map_err(|_| Error::InitializationError)??;
- Ok(MonitorHandle {})
+ Ok(MonitorHandle {
+ _notify_tx: notify_tx,
+ })
}
fn ipv4_internet() -> SocketAddr {
@@ -170,7 +173,7 @@ fn iface_is_physical(iface: &SCNetworkInterface) -> bool {
#[derive(Clone)]
struct OfflineStateContext {
- sender: Weak<UnboundedSender<TunnelCommand>>,
+ sender: Weak<UnboundedSender<bool>>,
is_offline: Arc<AtomicBool>,
}
@@ -182,8 +185,7 @@ impl OfflineStateContext {
fn new_state(&self, is_offline: bool) {
if self.is_offline.swap(is_offline, Ordering::SeqCst) != is_offline {
if let Some(sender) = self.sender.upgrade() {
- let cmd = TunnelCommand::IsOffline(is_offline);
- let _ = sender.unbounded_send(cmd);
+ let _ = sender.unbounded_send(is_offline);
}
}
}
diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs
index ac8b10e222..b713e0ba69 100644
--- a/talpid-core/src/offline/mod.rs
+++ b/talpid-core/src/offline/mod.rs
@@ -1,8 +1,6 @@
#[cfg(target_os = "linux")]
use crate::routing::RouteManagerHandle;
-use crate::tunnel_state_machine::TunnelCommand;
use futures::channel::mpsc::UnboundedSender;
-use std::sync::Weak;
#[cfg(target_os = "android")]
use talpid_types::android::AndroidContext;
@@ -43,7 +41,7 @@ impl MonitorHandle {
}
pub async fn spawn_monitor(
- sender: Weak<UnboundedSender<TunnelCommand>>,
+ sender: UnboundedSender<bool>,
#[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<MonitorHandle, Error> {
diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs
index d9e5c7782d..84d83fd7cd 100644
--- a/talpid-core/src/offline/windows.rs
+++ b/talpid-core/src/offline/windows.rs
@@ -1,4 +1,4 @@
-use crate::{logging::windows::log_sink, tunnel_state_machine::TunnelCommand, winnet};
+use crate::{logging::windows::log_sink, winnet};
use futures::channel::mpsc::UnboundedSender;
use parking_lot::Mutex;
use std::{
@@ -49,16 +49,18 @@ pub struct BroadcastListener {
thread_handle: RawHandle,
thread_id: DWORD,
_system_state: Arc<Mutex<SystemState>>,
+ _notify_tx: Arc<UnboundedSender<bool>>,
}
unsafe impl Send for BroadcastListener {}
impl BroadcastListener {
- pub fn start(sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<Self, Error> {
+ pub fn start(notify_tx: UnboundedSender<bool>) -> Result<Self, Error> {
+ let notify_tx = Arc::new(notify_tx);
let mut system_state = Arc::new(Mutex::new(SystemState {
network_connectivity: None,
suspended: false,
- daemon_channel: sender,
+ notify_tx: Arc::downgrade(&notify_tx),
}));
let power_broadcast_state_ref = system_state.clone();
@@ -95,6 +97,7 @@ impl BroadcastListener {
thread_handle: real_handle,
thread_id: unsafe { GetThreadId(real_handle) },
_system_state: system_state,
+ _notify_tx: notify_tx,
})
}
@@ -229,7 +232,7 @@ enum StateChange {
struct SystemState {
network_connectivity: Option<bool>,
suspended: bool,
- daemon_channel: Weak<UnboundedSender<TunnelCommand>>,
+ notify_tx: Weak<UnboundedSender<bool>>,
}
impl SystemState {
@@ -247,10 +250,8 @@ impl SystemState {
let new_state = self.is_offline_currently();
if old_state != new_state {
- if let Some(daemon_channel) = self.daemon_channel.upgrade() {
- if let Err(e) = daemon_channel
- .unbounded_send(TunnelCommand::IsOffline(new_state.unwrap_or(false)))
- {
+ if let Some(notify_tx) = self.notify_tx.upgrade() {
+ if let Err(e) = notify_tx.unbounded_send(new_state.unwrap_or(false)) {
log::error!("Failed to send new offline state to daemon: {}", e);
}
}
@@ -264,9 +265,7 @@ impl SystemState {
pub type MonitorHandle = BroadcastListener;
-pub async fn spawn_monitor(
- sender: Weak<UnboundedSender<TunnelCommand>>,
-) -> Result<MonitorHandle, Error> {
+pub async fn spawn_monitor(sender: UnboundedSender<bool>) -> Result<MonitorHandle, Error> {
BroadcastListener::start(sender)
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 0b745ac6d2..c4e5e5ef0c 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -104,6 +104,7 @@ pub async fn spawn(
resource_dir: PathBuf,
cache_dir: impl AsRef<Path> + Send + 'static,
state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static,
+ offline_state_listener: mpsc::UnboundedSender<bool>,
shutdown_tx: oneshot::Sender<()>,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
@@ -128,6 +129,7 @@ pub async fn spawn(
runtime.clone(),
initial_settings,
weak_command_tx,
+ offline_state_listener,
tunnel_parameters_generator,
tun_provider,
log_dir,
@@ -216,6 +218,7 @@ impl TunnelStateMachine {
runtime: tokio::runtime::Handle,
settings: InitialTunnelState,
command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ offline_state_tx: mpsc::UnboundedSender<bool>,
tunnel_parameters_generator: impl TunnelParametersGenerator,
tun_provider: TunProvider,
log_dir: Option<PathBuf>,
@@ -247,8 +250,21 @@ impl TunnelStateMachine {
.map_err(Error::InitRouteManagerError)?,
)
.map_err(Error::InitDnsMonitorError)?;
+
+ let (offline_tx, mut offline_rx) = mpsc::unbounded();
+ let initial_offline_state_tx = offline_state_tx.clone();
+ tokio::spawn(async move {
+ while let Some(offline) = offline_rx.next().await {
+ if let Some(tx) = command_tx.upgrade() {
+ let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline));
+ } else {
+ break;
+ }
+ let _ = offline_state_tx.unbounded_send(offline);
+ }
+ });
let mut offline_monitor = offline::spawn_monitor(
- command_tx,
+ offline_tx,
#[cfg(target_os = "linux")]
route_manager
.handle()
@@ -259,6 +275,7 @@ impl TunnelStateMachine {
.await
.map_err(Error::OfflineMonitorError)?;
let is_offline = offline_monitor.is_offline().await;
+ let _ = initial_offline_state_tx.unbounded_send(is_offline);
#[cfg(windows)]
split_tunnel