summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-02-15 19:47:07 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-02-16 16:37:37 +0100
commite471d0739446279b01022090ac4457fe337ca598 (patch)
treea67ac87a161cae956cde7bc4023cfd12beba5a6a
parentc8a3a3be92098cf64bc9269b3c4791e41c3b500d (diff)
downloadmullvadvpn-e471d0739446279b01022090ac4457fe337ca598.tar.xz
mullvadvpn-e471d0739446279b01022090ac4457fe337ca598.zip
Refactor API access methods
-rw-r--r--mullvad-api/src/bin/relay_list.rs10
-rw-r--r--mullvad-api/src/lib.rs25
-rw-r--r--mullvad-api/src/proxy.rs42
-rw-r--r--mullvad-api/src/rest.rs65
-rw-r--r--mullvad-daemon/src/access_method.rs164
-rw-r--r--mullvad-daemon/src/api.rs236
-rw-r--r--mullvad-daemon/src/lib.rs46
-rw-r--r--mullvad-problem-report/src/lib.rs2
-rw-r--r--mullvad-setup/src/main.rs2
-rw-r--r--mullvad-types/src/access_method.rs11
-rw-r--r--test/test-manager/src/tests/account.rs5
11 files changed, 295 insertions, 313 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index e395d8ae5f..8cb615d77f 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -11,12 +11,10 @@ async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
- let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(
- ApiConnectionMode::Direct,
- ApiConnectionMode::Direct.into_repeat(),
- ))
- .relay_list(None)
- .await;
+ let relay_list_request =
+ RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()))
+ .relay_list(None)
+ .await;
let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index dad6cdf706..6114bec90a 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -1,6 +1,5 @@
#[cfg(target_os = "android")]
use futures::channel::mpsc;
-use futures::Stream;
use hyper::Method;
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
@@ -8,7 +7,7 @@ use mullvad_types::{
account::{AccountData, AccountToken, VoucherSubmission},
version::AppVersion,
};
-use proxy::ApiConnectionMode;
+use proxy::{ApiConnectionMode, ConnectionModeProvider};
use std::{
cell::Cell,
collections::BTreeMap,
@@ -408,34 +407,30 @@ impl Runtime {
}
/// Creates a new request service and returns a handle to it.
- fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
+ fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
- initial_connection_mode: ApiConnectionMode,
- proxy_provider: T,
+ connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.handle(),
self.address_cache.clone(),
- initial_connection_mode,
- proxy_provider,
+ connection_mode_provider,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}
/// Returns a request factory initialized to create requests for the master API
- pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
+ pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
- initial_connection_mode: ApiConnectionMode,
- proxy_provider: T,
+ connection_mode_provider: T,
) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(API.host().to_string()),
- initial_connection_mode,
- proxy_provider,
+ connection_mode_provider,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -454,8 +449,7 @@ impl Runtime {
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(hostname.clone()),
- ApiConnectionMode::Direct,
- futures::stream::repeat(ApiConnectionMode::Direct),
+ ApiConnectionMode::Direct.into_provider(),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -474,8 +468,7 @@ impl Runtime {
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
self.new_request_service(
None,
- ApiConnectionMode::Direct,
- ApiConnectionMode::Direct.into_repeat(),
+ ApiConnectionMode::Direct.into_provider(),
#[cfg(target_os = "android")]
None,
)
diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs
index 2b4821ba64..0915d1d23c 100644
--- a/mullvad-api/src/proxy.rs
+++ b/mullvad-api/src/proxy.rs
@@ -1,4 +1,3 @@
-use futures::Stream;
use hyper::client::connect::Connected;
use serde::{Deserialize, Serialize};
use std::{
@@ -18,6 +17,41 @@ use tokio::{
const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json";
+pub trait ConnectionModeProvider: Send {
+ /// Initial connection mode
+ fn initial(&self) -> ApiConnectionMode;
+
+ /// Request a new connection mode from the provider
+ fn rotate(&self) -> impl std::future::Future<Output = ()> + Send;
+
+ /// Receive changes to the connection mode, announced by the provider
+ fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send;
+}
+
+pub struct StaticConnectionModeProvider {
+ mode: ApiConnectionMode,
+}
+
+impl StaticConnectionModeProvider {
+ pub fn new(mode: ApiConnectionMode) -> Self {
+ Self { mode }
+ }
+}
+
+impl ConnectionModeProvider for StaticConnectionModeProvider {
+ fn initial(&self) -> ApiConnectionMode {
+ self.mode.clone()
+ }
+
+ fn rotate(&self) -> impl std::future::Future<Output = ()> + Send {
+ futures::future::ready(())
+ }
+
+ fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send {
+ futures::future::pending()
+ }
+}
+
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum ApiConnectionMode {
/// Connect directly to the target.
@@ -153,10 +187,8 @@ impl ApiConnectionMode {
*self != ApiConnectionMode::Direct
}
- /// Convenience function that returns a stream that repeats
- /// this config forever.
- pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> {
- futures::stream::repeat(self)
+ pub fn into_provider(self) -> StaticConnectionModeProvider {
+ StaticConnectionModeProvider::new(self)
}
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index ca63f16c1f..158d84f01b 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -5,12 +5,11 @@ use crate::{
address_cache::AddressCache,
availability::ApiAvailabilityHandle,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
- proxy::ApiConnectionMode,
+ proxy::ConnectionModeProvider,
};
use futures::{
channel::{mpsc, oneshot},
stream::StreamExt,
- Stream,
};
use hyper::{
client::{connect::Connect, Client},
@@ -120,23 +119,22 @@ impl Error {
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
/// requests
-pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {
+pub(crate) struct RequestService<T: ConnectionModeProvider> {
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
- proxy_config_provider: T,
+ connection_mode_provider: T,
api_availability: ApiAvailabilityHandle,
}
-impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
+impl<T: ConnectionModeProvider + 'static> RequestService<T> {
/// Constructs a new request service.
pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
- initial_connection_mode: ApiConnectionMode,
- proxy_config_provider: T,
+ connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
@@ -146,7 +144,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
socket_bypass_tx.clone(),
);
- connector_handle.set_connection_mode(initial_connection_mode);
+ connector_handle.set_connection_mode(connection_mode_provider.initial());
let (command_tx, command_rx) = mpsc::unbounded();
let client = Client::builder().build(connector);
@@ -158,7 +156,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
command_rx,
connector_handle,
client,
- proxy_config_provider,
+ connection_mode_provider,
api_availability,
};
let handle = RequestServiceHandle { tx: command_tx };
@@ -166,6 +164,27 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
handle
}
+ async fn into_future(mut self) {
+ loop {
+ tokio::select! {
+ new_mode = self.connection_mode_provider.receive() => {
+ let Some(new_mode) = new_mode else {
+ break;
+ };
+ self.connector_handle.set_connection_mode(new_mode);
+ }
+ command = self.command_rx.next() => {
+ let Some(command) = command else {
+ break;
+ };
+
+ self.process_command(command).await;
+ }
+ }
+ }
+ self.connector_handle.reset();
+ }
+
async fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
@@ -174,11 +193,8 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
RequestCommand::Reset => {
self.connector_handle.reset();
}
- RequestCommand::NextApiConfig(completion_tx) => {
- if let Some(connection_mode) = self.proxy_config_provider.next().await {
- self.connector_handle.set_connection_mode(connection_mode);
- }
- let _ = completion_tx.send(Ok(()));
+ RequestCommand::NextApiConfig => {
+ self.connection_mode_provider.rotate().await;
}
}
}
@@ -201,8 +217,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
if err.is_network_error() && !api_availability.get_state().is_offline() {
log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
if let Some(tx) = tx {
- let (completion_tx, _completion_rx) = oneshot::channel();
- let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
+ let _ = tx.unbounded_send(RequestCommand::NextApiConfig);
}
}
}
@@ -210,13 +225,6 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
let _ = completion_tx.send(response);
});
}
-
- async fn into_future(mut self) {
- while let Some(command) = self.command_rx.next().await {
- self.process_command(command).await;
- }
- self.connector_handle.reset();
- }
}
#[derive(Clone)]
@@ -239,15 +247,6 @@ impl RequestServiceHandle {
.map_err(|_| Error::RestServiceDown)?;
completion_rx.await.map_err(|_| Error::RestServiceDown)?
}
-
- /// Forcibly update the connection mode.
- pub async fn next_api_endpoint(&self) -> Result<()> {
- let (completion_tx, completion_rx) = oneshot::channel();
- self.tx
- .unbounded_send(RequestCommand::NextApiConfig(completion_tx))
- .map_err(|_| Error::RestServiceDown)?;
- completion_rx.await.map_err(|_| Error::RestServiceDown)?
- }
}
#[derive(Debug)]
@@ -257,7 +256,7 @@ pub(crate) enum RequestCommand {
oneshot::Sender<std::result::Result<Response, Error>>,
),
Reset,
- NextApiConfig(oneshot::Sender<std::result::Result<(), Error>>),
+ NextApiConfig,
}
/// A REST request that is sent to the RequestService to be executed.
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index 664fce6bfe..793d82bb5c 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -1,8 +1,4 @@
-use crate::{
- api,
- settings::{self, MadeChanges},
- Daemon, EventListener,
-};
+use crate::{api, settings, Daemon, EventListener};
use mullvad_api::{proxy::ApiConnectionMode, rest, ApiProxy};
use mullvad_types::{
access_method::{self, AccessMethod, AccessMethodSetting},
@@ -17,9 +13,6 @@ pub enum Error {
/// Can not find access method
#[error(display = "Cannot find custom access method {}", _0)]
NoSuchMethod(access_method::Id),
- /// Access method could not be rotate
- #[error(display = "Access method could not be rotated")]
- RotationFailed,
/// Some error occured in the daemon's state of handling
/// [`AccessMethodSetting`]s & [`ApiConnectionMode`]s
#[error(display = "Error occured when handling connection settings & details")]
@@ -54,42 +47,21 @@ where
let id = access_method_setting.get_id();
self.settings
.update(|settings| settings.api_access_methods.append(access_method_setting))
- .await
- .map(|did_change| self.notify_on_change(did_change))
- .map(|_| id)
- .map_err(Error::Settings)
+ .await?;
+ Ok(id)
}
/// Remove a [`AccessMethodSetting`] from the daemon's saved settings.
- ///
- /// If the [`AccessMethodSetting`] which is currently in use happens to be
- /// removed, the daemon should force a rotation of the active API endpoint.
pub async fn remove_access_method(
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
- let did_change = self
- .settings
+ self.settings
.try_update(|settings| -> Result<(), Error> {
settings.api_access_methods.remove(&access_method)?;
Ok(())
})
- .await
- .map_err(Error::Settings)?;
-
- self.notify_on_change(did_change);
- // If the currently active access method is removed, a new access
- // method should be selected.
- //
- // Notice the ordering here: It is important that the current method is
- // removed before we pick a new access method. The `remove` function
- // will ensure that atleast one access method is enabled after the
- // removal. If the currently active access method is removed, some other
- // method is enabled before we pick the next access method to use.
- if self.is_in_use(access_method.clone()).await? {
- self.force_api_endpoint_rotation().await?;
- }
-
+ .await?;
Ok(())
}
@@ -110,19 +82,18 @@ where
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
- let mut access_method = self.get_api_access_method(access_method)?;
- // Toggle the enabled status if needed
- if !access_method.enabled() {
- access_method.enable();
- self.update_access_method_inner(access_method.clone())
- .await?
- }
- // Set `access_method` as the next access method to use
- self.connection_modes_handler
- .set_access_method(access_method)
+ self.settings
+ .update(|settings| {
+ settings.api_access_methods.update(
+ |setting| setting.get_id() == access_method,
+ |setting| setting.enable(),
+ );
+ })
.await?;
- // Force a rotation of Access Methods
- self.force_api_endpoint_rotation().await
+ self.access_mode_handler
+ .use_access_method(access_method)
+ .await?;
+ Ok(())
}
pub fn get_api_access_method(
@@ -140,88 +111,28 @@ where
/// Updates a [`AccessMethodSetting`] by replacing the existing entry with
/// the argument `access_method_update`. if an entry with a matching
/// [`access_method::Id`] is found.
- ///
- /// If the currently active [`AccessMethodSetting`] is updated, the daemon
- /// will automatically use this updated [`AccessMethodSetting`] when
- /// performing subsequent API calls.
pub async fn update_access_method(
&mut self,
access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
- self.update_access_method_inner(access_method_update.clone())
- .await?;
-
- if self.is_in_use(access_method_update.get_id()).await? {
- if access_method_update.disabled() {
- // If the currently active access method is updated & disabled
- // we should select the next access method
- self.force_api_endpoint_rotation().await?;
- } else {
- // If the currently active access method is just updated, we
- // need to re-set it after updating the settings
- self.use_api_access_method(access_method_update.get_id())
- .await?;
- }
- }
-
- Ok(())
- }
-
- /// Updates a [`AccessMethodSetting`] by replacing the existing entry with
- /// the argument `access_method_update`. if an entry with a matching
- /// [`access_method::Id`] is found.
- ///
- /// This inner function does not perform any kind of check to see if the
- /// existing, in-use setting needs to be re-set.
- async fn update_access_method_inner(
- &mut self,
- access_method_update: AccessMethodSetting,
- ) -> Result<(), Error> {
- let settings_update = |settings: &mut Settings| {
- let target = access_method_update.get_id();
- settings.api_access_methods.update(
- |access_method| access_method.get_id() == target,
- |_| access_method_update,
- );
- };
-
self.settings
- .update(settings_update)
- .await
- .map(|did_change| self.notify_on_change(did_change))
- .map_err(Error::Settings)?;
+ .update(|settings: &mut Settings| {
+ let target = access_method_update.get_id();
+ settings.api_access_methods.update(
+ |access_method| access_method.get_id() == target,
+ |method| *method = access_method_update,
+ );
+ })
+ .await?;
Ok(())
}
- /// Check if some access method is the same as the currently active one.
- ///
- /// This can be useful for invalidating stale states.
- async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> {
- Ok(access_method == self.get_current_access_method().await?.get_id())
- }
-
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
- self.connection_modes_handler
- .get_current()
- .await
- .map(|current| current.setting)
- .map_err(Error::ApiService)
- }
-
- /// Change which [`AccessMethodSetting`] which will be used as the Mullvad
- /// API endpoint.
- async fn force_api_endpoint_rotation(&self) -> Result<(), Error> {
- self.api_handle
- .service()
- .next_api_endpoint()
- .await
- .map_err(|error| {
- log::error!("Failed to rotate API endpoint: {}", error);
- Error::RotationFailed
- })
+ let current = self.access_mode_handler.get_current().await?;
+ Ok(current.setting)
}
/// Test if the API is reachable via `proxy`.
@@ -259,11 +170,11 @@ where
}
/// Create an [`ApiProxy`] which will perform all REST requests against one
- /// specific endpoint `proxy_provider`.
- pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy {
+ /// specific endpoint `connection_mode`.
+ pub fn create_limited_api_proxy(&mut self, connection_mode: ApiConnectionMode) -> ApiProxy {
let rest_handle = self
.api_runtime
- .mullvad_rest_handle(proxy_provider, futures::stream::empty());
+ .mullvad_rest_handle(connection_mode.into_provider());
ApiProxy::new(rest_handle)
}
@@ -273,21 +184,6 @@ where
/// * Returns `Ok(false)` if the API returned an unexpected result
/// * Returns `Err(..)` if the API could not be reached
async fn perform_api_request(api_proxy: ApiProxy) -> Result<bool, Error> {
- api_proxy.api_addrs_available().await.map_err(Error::Rest)
- }
-
- /// If settings were changed due to an update, notify all listeners.
- fn notify_on_change(&mut self, settings_changed: MadeChanges) -> &mut Self {
- if settings_changed {
- self.event_listener
- .notify_settings(self.settings.to_settings());
-
- let handle = self.connection_modes_handler.clone();
- let new_access_methods = self.settings.api_access_methods.clone();
- tokio::spawn(async move {
- let _ = handle.update_access_methods(new_access_methods).await;
- });
- };
- self
+ Ok(api_proxy.api_addrs_available().await?)
}
}
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index 1e6ba296a3..ccdf9e7bf3 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -8,16 +8,16 @@ use crate::DaemonCommand;
use crate::DaemonEventSender;
use futures::{
channel::{mpsc, oneshot},
- Stream, StreamExt,
+ StreamExt,
};
use mullvad_api::{
availability::ApiAvailabilityHandle,
- proxy::{ApiConnectionMode, ProxyConfig},
+ proxy::{ApiConnectionMode, ConnectionModeProvider, ProxyConfig},
AddressCache,
};
use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{
- AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings,
+ AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Id, Settings,
};
use std::{net::SocketAddr, path::PathBuf};
use talpid_core::mpsc::Sender;
@@ -27,8 +27,8 @@ use talpid_types::net::{
pub enum Message {
Get(ResponseTx<ResolvedConnectionMode>),
- Set(ResponseTx<()>, AccessMethodSetting),
- Next(ResponseTx<ApiConnectionMode>),
+ Use(ResponseTx<()>, Id),
+ Rotate(ResponseTx<ApiConnectionMode>),
Update(ResponseTx<()>, Settings),
Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting),
}
@@ -113,8 +113,8 @@ impl std::fmt::Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Message::Get(_) => f.write_str("Get"),
- Message::Set(..) => f.write_str("Set"),
- Message::Next(_) => f.write_str("Next"),
+ Message::Use(..) => f.write_str("Set"),
+ Message::Rotate(_) => f.write_str("Rotate"),
Message::Update(..) => f.write_str("Update"),
Message::Resolve(..) => f.write_str("Resolve"),
}
@@ -159,8 +159,8 @@ impl AccessModeSelectorHandle {
})
}
- pub async fn set_access_method(&self, value: AccessMethodSetting) -> Result<()> {
- self.send_command(|tx| Message::Set(tx, value))
+ pub async fn use_access_method(&self, value: Id) -> Result<()> {
+ self.send_command(|tx| Message::Use(tx, value))
.await
.map_err(|err| {
log::debug!("Failed to set new access method!");
@@ -186,30 +186,51 @@ impl AccessModeSelectorHandle {
})
}
- pub async fn next(&self) -> Result<ApiConnectionMode> {
- self.send_command(Message::Next).await.map_err(|err| {
+ pub async fn rotate(&self) -> Result<ApiConnectionMode> {
+ self.send_command(Message::Rotate).await.map_err(|err| {
log::debug!("Failed while getting the next access method");
err
})
}
+}
- /// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the
- /// associated [`AccessModeSelector`].
- ///
- /// Calling `next` on this stream will poll for the next access method,
- /// which will be lazily produced (on-demand rather than speculatively).
- pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> {
- futures::stream::unfold(self, |handle| async move {
- match handle.next().await {
- Ok(connection_mode) => Some((connection_mode, handle)),
- // End this stream in case of failure in `next`. `next` should
- // not fail if the actor is in a good state.
- Err(_) => None,
- }
+pub struct AccessModeConnectionModeProvider {
+ initial: ApiConnectionMode,
+ handle: AccessModeSelectorHandle,
+ change_rx: mpsc::UnboundedReceiver<ApiConnectionMode>,
+}
+
+impl AccessModeConnectionModeProvider {
+ fn new(
+ handle: AccessModeSelectorHandle,
+ initial_connection_mode: ApiConnectionMode,
+ change_rx: mpsc::UnboundedReceiver<ApiConnectionMode>,
+ ) -> Result<Self> {
+ Ok(Self {
+ initial: initial_connection_mode,
+ handle,
+ change_rx,
})
}
}
+impl ConnectionModeProvider for AccessModeConnectionModeProvider {
+ fn initial(&self) -> ApiConnectionMode {
+ self.initial.clone()
+ }
+
+ fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send {
+ self.change_rx.next()
+ }
+
+ fn rotate(&self) -> impl std::future::Future<Output = ()> + Send {
+ let handle = self.handle.clone();
+ async move {
+ handle.rotate().await.ok();
+ }
+ }
+}
+
/// A small actor which takes care of handling the logic around rotating
/// connection modes to be used for Mullvad API request.
///
@@ -226,28 +247,40 @@ pub struct AccessModeSelector {
access_method_settings: Settings,
address_cache: AddressCache,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
+ connection_mode_provider_sender: mpsc::UnboundedSender<ApiConnectionMode>,
current: ResolvedConnectionMode,
/// `index` is used to keep track of the [`AccessMethodSetting`] to use.
index: usize,
- /// `set` is used to set the next [`AccessMethodSetting`] to use.
- set: Option<AccessMethodSetting>,
}
impl AccessModeSelector {
pub(crate) async fn spawn(
cache_dir: PathBuf,
relay_selector: RelaySelector,
- access_method_settings: Settings,
+ #[cfg_attr(not(feature = "api-override"), allow(unused_mut))]
+ mut access_method_settings: Settings,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
address_cache: AddressCache,
- ) -> Result<AccessModeSelectorHandle> {
+ ) -> Result<(AccessModeSelectorHandle, AccessModeConnectionModeProvider)> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();
+ #[cfg(feature = "api-override")]
+ {
+ if mullvad_api::API.force_direct {
+ access_method_settings
+ .update(|setting| setting.is_direct(), |setting| setting.enable());
+ }
+ }
+
// Always start looking from the position of `Direct`.
- let (index, next) = Self::select_next_active(0, &access_method_settings);
+ let (index, next) = Self::find_next_active(0, &access_method_settings);
let initial_connection_mode =
Self::resolve_inner(next, &relay_selector, &address_cache).await;
+ let (change_tx, change_rx) = mpsc::unbounded();
+
+ let api_connection_mode = initial_connection_mode.connection_mode.clone();
+
let selector = AccessModeSelector {
cmd_rx,
cache_dir,
@@ -255,14 +288,19 @@ impl AccessModeSelector {
access_method_settings,
address_cache,
access_method_event_sender,
+ connection_mode_provider_sender: change_tx,
current: initial_connection_mode,
index,
- set: None,
};
tokio::spawn(selector.into_future());
- Ok(AccessModeSelectorHandle { cmd_tx })
+ let handle = AccessModeSelectorHandle { cmd_tx };
+
+ let connection_mode_provider =
+ AccessModeConnectionModeProvider::new(handle.clone(), api_connection_mode, change_rx)?;
+
+ Ok((handle, connection_mode_provider))
}
async fn into_future(mut self) {
@@ -270,9 +308,9 @@ impl AccessModeSelector {
log::trace!("Processing {cmd} command");
let execution = match cmd {
Message::Get(tx) => self.on_get_access_method(tx),
- Message::Set(tx, value) => self.on_set_access_method(tx, value),
- Message::Next(tx) => self.on_next_connection_mode(tx).await,
- Message::Update(tx, values) => self.on_update_access_methods(tx, values),
+ Message::Use(tx, id) => self.on_use_access_method(tx, id).await,
+ Message::Rotate(tx) => self.on_next_connection_mode(tx).await,
+ Message::Update(tx, values) => self.on_update_access_methods(tx, values).await,
Message::Resolve(tx, setting) => self.on_resolve_access_method(tx, setting).await,
};
match execution {
@@ -297,26 +335,32 @@ impl AccessModeSelector {
self.reply(tx, self.current.clone())
}
- fn on_set_access_method(
- &mut self,
- tx: ResponseTx<()>,
- value: AccessMethodSetting,
- ) -> Result<()> {
- self.set_access_method(value);
+ async fn on_use_access_method(&mut self, tx: ResponseTx<()>, id: Id) -> Result<()> {
+ self.use_access_method(id).await;
self.reply(tx, ())
}
- /// Set the next access method to be returned by the [`Stream`] produced by
- /// calling `into_stream`.
- fn set_access_method(&mut self, value: AccessMethodSetting) {
- if let Some(index) = self
- .access_method_settings
- .iter()
- .position(|access_method| access_method.get_id() == value.get_id())
+ /// Set and announce the specified access method as the current one.
+ async fn use_access_method(&mut self, id: Id) {
+ #[cfg(feature = "api-override")]
{
- self.index = index;
- self.set = Some(value);
+ if mullvad_api::API.force_direct {
+ log::debug!("API proxies are disabled");
+ return;
+ }
}
+
+ let Some((index, method)) = self
+ .access_method_settings
+ .iter()
+ .enumerate()
+ .find(|(_, access_method)| access_method.get_id() == id)
+ else {
+ return;
+ };
+
+ self.index = index;
+ self.set_current(method.to_owned()).await;
}
async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
@@ -327,22 +371,8 @@ impl AccessModeSelector {
async fn next_connection_mode(&mut self) -> Result<ApiConnectionMode> {
#[cfg(feature = "api-override")]
{
- use mullvad_api::API;
- if API.force_direct {
+ if mullvad_api::API.force_direct {
log::debug!("API proxies are disabled");
- let endpoint = resolve_allowed_endpoint(
- &ApiConnectionMode::Direct,
- // Note that the address cache *should* be initialized with
- // the overridden API endpoint, so we can simply fetch the
- // endpoint address from it.
- self.address_cache.get_address().await,
- );
- let daemon_sender = self.access_method_event_sender.clone();
- tokio::spawn(async move {
- let _ = AccessMethodEvent::Allow { endpoint }
- .send(daemon_sender)
- .await;
- });
return Ok(ApiConnectionMode::Direct);
}
@@ -352,12 +382,16 @@ impl AccessModeSelector {
);
}
- let access_method = self.get_next();
- log::info!(
- "A new API access method has been selected: {name}",
- name = access_method.name
- );
+ let (next_index, next) =
+ Self::find_next_active(self.index + 1, &self.access_method_settings);
+ self.index = next_index;
+ self.set_current(next).await;
+ Ok(self.current.connection_mode.clone())
+ }
+
+ async fn set_current(&mut self, access_method: AccessMethodSetting) {
let resolved = self.resolve(access_method).await;
+
// Note: If the daemon is busy waiting for a call to this function
// to complete while we wait for the daemon to fully handle this
// `NewAccessMethodEvent`, then we find ourselves in a deadlock.
@@ -386,26 +420,24 @@ impl AccessModeSelector {
}
});
+ // Notify REST client
+ let _ = self
+ .connection_mode_provider_sender
+ .unbounded_send(resolved.connection_mode.clone());
+
self.current = resolved;
- Ok(self.current.connection_mode.clone())
- }
- fn get_next(&mut self) -> AccessMethodSetting {
- if let Some(access_method) = self.set.take() {
- access_method
- } else {
- let (next_index, next) =
- Self::select_next_active(self.index + 1, &self.access_method_settings);
- self.index = next_index;
- next
- }
+ log::info!(
+ "A new API access method has been selected: {name}",
+ name = self.current.setting.name
+ );
}
/// Find the next access method to use.
///
/// * `start`: From which point in `access_methods` to start the search.
/// * `access_methods`: The search space.
- fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
+ fn find_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
access_methods
.iter()
.cloned()
@@ -416,26 +448,46 @@ impl AccessModeSelector {
.find(|(_index, access_method)| access_method.enabled())
.unwrap_or_else(|| (0, access_methods.direct().clone()))
}
- fn on_update_access_methods(
+
+ async fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
access_methods: Settings,
) -> Result<()> {
- self.update_access_methods(access_methods);
+ self.update_access_methods(access_methods).await?;
self.reply(tx, ())
}
- fn update_access_methods(&mut self, access_methods: Settings) {
- let removed_active = !access_methods
+ async fn update_access_methods(&mut self, access_methods: Settings) -> Result<()> {
+ self.access_method_settings = access_methods;
+
+ let new_current = self
+ .access_method_settings
.iter()
- .any(|access_method| access_method.get_id() == self.current.setting.get_id());
- if removed_active {
- // A new access mehtod will suddenly have the same index as the one
- // we are removing, but we want it to still be a candidate. A minor
- // hack to achieve this is to simply decrement the current index.
- self.index = self.index.saturating_sub(1);
+ .enumerate()
+ .find(|(_, access_method)| access_method.get_id() == self.current.setting.get_id());
+
+ match new_current {
+ Some((index, new_current)) => {
+ // If the current method was modified, announce changes
+ self.index = index;
+ if self.current.setting != *new_current {
+ if new_current.enabled() {
+ self.set_current(new_current.to_owned()).await;
+ } else {
+ self.next_connection_mode().await?;
+ }
+ }
+ }
+ None => {
+ // Current method was removed: A new access method will suddenly have the same index as the one
+ // we are removing, but we want it to still be a candidate. A minor
+ // hack to achieve this is to simply decrement the current index.
+ self.index = self.index.saturating_sub(1);
+ self.next_connection_mode().await?;
+ }
}
- self.access_method_settings = access_methods;
+ Ok(())
}
pub async fn on_resolve_access_method(
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index ef4bcb86e0..e446735df0 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -630,7 +630,7 @@ pub struct Daemon<L: EventListener> {
account_history: account_history::AccountHistory,
device_checker: device::TunnelStateChangeHandler,
account_manager: device::AccountManagerHandle,
- connection_modes_handler: api::AccessModeSelectorHandle,
+ access_mode_handler: api::AccessModeSelectorHandle,
api_runtime: mullvad_api::Runtime,
api_handle: mullvad_api::rest::MullvadRestHandle,
version_updater_handle: version_check::VersionUpdaterHandle,
@@ -707,7 +707,7 @@ where
.set_config(new_selector_config(settings));
});
- let connection_modes_handler = api::AccessModeSelector::spawn(
+ let (access_mode_handler, access_mode_provider) = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
settings.api_access_methods.clone(),
@@ -717,15 +717,16 @@ where
.await
.map_err(Error::ApiConnectionModeError)?;
- let initial_connection_mode = connection_modes_handler
- .get_current()
- .await
- .map_err(Error::ApiConnectionModeError)?;
+ let api_handle = api_runtime.mullvad_rest_handle(access_mode_provider);
- let api_handle = api_runtime.mullvad_rest_handle(
- initial_connection_mode.connection_mode,
- Box::pin(connection_modes_handler.clone().into_stream()),
- );
+ let access_method_handle = access_mode_handler.clone();
+ settings.register_change_listener(move |settings| {
+ let handle = access_method_handle.clone();
+ let new_access_methods = settings.api_access_methods.clone();
+ tokio::spawn(async move {
+ let _ = handle.update_access_methods(new_access_methods).await;
+ });
+ });
let migration_complete = if let Some(migration_data) = migration_data {
migrations::migrate_device(
@@ -801,7 +802,11 @@ where
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options),
- allowed_endpoint: initial_connection_mode.endpoint,
+ allowed_endpoint: access_mode_handler
+ .get_current()
+ .await
+ .map_err(Error::ApiConnectionModeError)?
+ .endpoint,
reset_firewall: *target_state != TargetState::Secured,
#[cfg(windows)]
exclude_paths,
@@ -874,7 +879,7 @@ where
account_history,
device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()),
account_manager,
- connection_modes_handler,
+ access_mode_handler,
api_runtime,
api_handle,
version_updater_handle,
@@ -2117,9 +2122,12 @@ where
{
Ok(settings_changes) => {
if settings_changes {
- if let Err(error) = self.api_handle.service().next_api_endpoint().await {
- log::error!("Failed to rotate API endpoint: {}", error);
- }
+ let access_mode_handler = self.access_mode_handler.clone();
+ tokio::spawn(async move {
+ if let Err(error) = access_mode_handler.rotate().await {
+ log::error!("Failed to rotate API endpoint: {error}");
+ }
+ });
self.reconnect_tunnel();
};
Self::oneshot_send(tx, Ok(()), "set_bridge_settings");
@@ -2466,7 +2474,7 @@ where
}
fn on_get_current_api_access_method(&mut self, tx: ResponseTx<AccessMethodSetting, Error>) {
- let handle = self.connection_modes_handler.clone();
+ let handle = self.access_mode_handler.clone();
tokio::spawn(async move {
let result = handle
.get_current()
@@ -2493,7 +2501,7 @@ where
};
let daemon_event_sender = self.tx.to_specialized_sender();
- let access_method_selector = self.connection_modes_handler.clone();
+ let access_method_selector = self.access_mode_handler.clone();
tokio::spawn(async move {
let result = Self::test_access_method(
proxy_endpoint,
@@ -2524,7 +2532,7 @@ where
}
};
- let test_subject = match self.connection_modes_handler.resolve(access_method).await {
+ let test_subject = match self.access_mode_handler.resolve(access_method).await {
Ok(test_subject) => test_subject,
Err(err) => {
reply(Err(Error::ApiConnectionModeError(err)));
@@ -2534,7 +2542,7 @@ where
let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode);
let daemon_event_sender = self.tx.to_specialized_sender();
- let access_method_selector = self.connection_modes_handler.clone();
+ let access_method_selector = self.access_mode_handler.clone();
tokio::spawn(async move {
let result = Self::test_access_method(
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index bcd820bef5..05502bcc46 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -301,7 +301,7 @@ async fn send_problem_report_inner(
let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await;
let api_client = mullvad_api::ProblemReportProxy::new(
- api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
+ api_runtime.mullvad_rest_handle(connection_mode.into_provider()),
);
for _attempt in 0..MAX_SEND_ATTEMPTS {
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index 4b14319414..e361d41a2b 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -160,7 +160,7 @@ async fn remove_device() -> Result<(), Error> {
let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await;
let proxy = mullvad_api::DevicesProxy::new(
- api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
+ api_runtime.mullvad_rest_handle(connection_mode.into_provider()),
);
let device_removal = retry_future(
diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs
index e8f6bd4a5d..fb365101be 100644
--- a/mullvad-types/src/access_method.rs
+++ b/mullvad-types/src/access_method.rs
@@ -62,11 +62,11 @@ impl Settings {
pub fn update(
&mut self,
predicate: impl Fn(&AccessMethodSetting) -> bool,
- f: impl FnOnce(&AccessMethodSetting) -> AccessMethodSetting,
+ f: impl FnOnce(&mut AccessMethodSetting),
) -> bool {
let mut updated = false;
if let Some(access_method) = self.iter_mut().find(|setting| predicate(setting)) {
- *access_method = f(access_method);
+ f(access_method);
updated = true;
}
self.ensure_consistent_state();
@@ -241,6 +241,13 @@ impl AccessMethodSetting {
self.as_custom().is_none()
}
+ pub fn is_direct(&self) -> bool {
+ matches!(
+ self.access_method,
+ AccessMethod::BuiltIn(BuiltInAccessMethod::Direct)
+ )
+ }
+
/// Set an API access method to be enabled.
pub fn enable(&mut self) {
self.enabled = true;
diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs
index 1eeeb8c170..56bcadafbd 100644
--- a/test/test-manager/src/tests/account.rs
+++ b/test/test-manager/src/tests/account.rs
@@ -237,10 +237,7 @@ pub fn new_device_client() -> DevicesProxy {
let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("failed to create api runtime");
- let rest_handle = api.mullvad_rest_handle(
- ApiConnectionMode::Direct,
- ApiConnectionMode::Direct.into_repeat(),
- );
+ let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider());
DevicesProxy::new(rest_handle)
}