diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-02-15 19:47:07 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-02-16 16:37:37 +0100 |
| commit | e471d0739446279b01022090ac4457fe337ca598 (patch) | |
| tree | a67ac87a161cae956cde7bc4023cfd12beba5a6a | |
| parent | c8a3a3be92098cf64bc9269b3c4791e41c3b500d (diff) | |
| download | mullvadvpn-e471d0739446279b01022090ac4457fe337ca598.tar.xz mullvadvpn-e471d0739446279b01022090ac4457fe337ca598.zip | |
Refactor API access methods
| -rw-r--r-- | mullvad-api/src/bin/relay_list.rs | 10 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 25 | ||||
| -rw-r--r-- | mullvad-api/src/proxy.rs | 42 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 65 | ||||
| -rw-r--r-- | mullvad-daemon/src/access_method.rs | 164 | ||||
| -rw-r--r-- | mullvad-daemon/src/api.rs | 236 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 46 | ||||
| -rw-r--r-- | mullvad-problem-report/src/lib.rs | 2 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 2 | ||||
| -rw-r--r-- | mullvad-types/src/access_method.rs | 11 | ||||
| -rw-r--r-- | test/test-manager/src/tests/account.rs | 5 |
11 files changed, 295 insertions, 313 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index e395d8ae5f..8cb615d77f 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -11,12 +11,10 @@ async fn main() { let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("Failed to load runtime"); - let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle( - ApiConnectionMode::Direct, - ApiConnectionMode::Direct.into_repeat(), - )) - .relay_list(None) - .await; + let relay_list_request = + RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider())) + .relay_list(None) + .await; let relay_list = match relay_list_request { Ok(relay_list) => relay_list, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index dad6cdf706..6114bec90a 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -1,6 +1,5 @@ #[cfg(target_os = "android")] use futures::channel::mpsc; -use futures::Stream; use hyper::Method; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; @@ -8,7 +7,7 @@ use mullvad_types::{ account::{AccountData, AccountToken, VoucherSubmission}, version::AppVersion, }; -use proxy::ApiConnectionMode; +use proxy::{ApiConnectionMode, ConnectionModeProvider}; use std::{ cell::Cell, collections::BTreeMap, @@ -408,34 +407,30 @@ impl Runtime { } /// Creates a new request service and returns a handle to it. - fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>( + fn new_request_service<T: ConnectionModeProvider + 'static>( &self, sni_hostname: Option<String>, - initial_connection_mode: ApiConnectionMode, - proxy_provider: T, + connection_mode_provider: T, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> rest::RequestServiceHandle { rest::RequestService::spawn( sni_hostname, self.api_availability.handle(), self.address_cache.clone(), - initial_connection_mode, - proxy_provider, + connection_mode_provider, #[cfg(target_os = "android")] socket_bypass_tx, ) } /// Returns a request factory initialized to create requests for the master API - pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>( + pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>( &self, - initial_connection_mode: ApiConnectionMode, - proxy_provider: T, + connection_mode_provider: T, ) -> rest::MullvadRestHandle { let service = self.new_request_service( Some(API.host().to_string()), - initial_connection_mode, - proxy_provider, + connection_mode_provider, #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -454,8 +449,7 @@ impl Runtime { pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { let service = self.new_request_service( Some(hostname.clone()), - ApiConnectionMode::Direct, - futures::stream::repeat(ApiConnectionMode::Direct), + ApiConnectionMode::Direct.into_provider(), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -474,8 +468,7 @@ impl Runtime { pub fn rest_handle(&self) -> rest::RequestServiceHandle { self.new_request_service( None, - ApiConnectionMode::Direct, - ApiConnectionMode::Direct.into_repeat(), + ApiConnectionMode::Direct.into_provider(), #[cfg(target_os = "android")] None, ) diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs index 2b4821ba64..0915d1d23c 100644 --- a/mullvad-api/src/proxy.rs +++ b/mullvad-api/src/proxy.rs @@ -1,4 +1,3 @@ -use futures::Stream; use hyper::client::connect::Connected; use serde::{Deserialize, Serialize}; use std::{ @@ -18,6 +17,41 @@ use tokio::{ const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json"; +pub trait ConnectionModeProvider: Send { + /// Initial connection mode + fn initial(&self) -> ApiConnectionMode; + + /// Request a new connection mode from the provider + fn rotate(&self) -> impl std::future::Future<Output = ()> + Send; + + /// Receive changes to the connection mode, announced by the provider + fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send; +} + +pub struct StaticConnectionModeProvider { + mode: ApiConnectionMode, +} + +impl StaticConnectionModeProvider { + pub fn new(mode: ApiConnectionMode) -> Self { + Self { mode } + } +} + +impl ConnectionModeProvider for StaticConnectionModeProvider { + fn initial(&self) -> ApiConnectionMode { + self.mode.clone() + } + + fn rotate(&self) -> impl std::future::Future<Output = ()> + Send { + futures::future::ready(()) + } + + fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send { + futures::future::pending() + } +} + #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] pub enum ApiConnectionMode { /// Connect directly to the target. @@ -153,10 +187,8 @@ impl ApiConnectionMode { *self != ApiConnectionMode::Direct } - /// Convenience function that returns a stream that repeats - /// this config forever. - pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> { - futures::stream::repeat(self) + pub fn into_provider(self) -> StaticConnectionModeProvider { + StaticConnectionModeProvider::new(self) } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index ca63f16c1f..158d84f01b 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -5,12 +5,11 @@ use crate::{ address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, - proxy::ApiConnectionMode, + proxy::ConnectionModeProvider, }; use futures::{ channel::{mpsc, oneshot}, stream::StreamExt, - Stream, }; use hyper::{ client::{connect::Connect, Client}, @@ -120,23 +119,22 @@ impl Error { /// A service that executes HTTP requests, allowing for on-demand termination of all in-flight /// requests -pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> { +pub(crate) struct RequestService<T: ConnectionModeProvider> { command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>, command_rx: mpsc::UnboundedReceiver<RequestCommand>, connector_handle: HttpsConnectorWithSniHandle, client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, - proxy_config_provider: T, + connection_mode_provider: T, api_availability: ApiAvailabilityHandle, } -impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> { +impl<T: ConnectionModeProvider + 'static> RequestService<T> { /// Constructs a new request service. pub fn spawn( sni_hostname: Option<String>, api_availability: ApiAvailabilityHandle, address_cache: AddressCache, - initial_connection_mode: ApiConnectionMode, - proxy_config_provider: T, + connection_mode_provider: T, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( @@ -146,7 +144,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic socket_bypass_tx.clone(), ); - connector_handle.set_connection_mode(initial_connection_mode); + connector_handle.set_connection_mode(connection_mode_provider.initial()); let (command_tx, command_rx) = mpsc::unbounded(); let client = Client::builder().build(connector); @@ -158,7 +156,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic command_rx, connector_handle, client, - proxy_config_provider, + connection_mode_provider, api_availability, }; let handle = RequestServiceHandle { tx: command_tx }; @@ -166,6 +164,27 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic handle } + async fn into_future(mut self) { + loop { + tokio::select! { + new_mode = self.connection_mode_provider.receive() => { + let Some(new_mode) = new_mode else { + break; + }; + self.connector_handle.set_connection_mode(new_mode); + } + command = self.command_rx.next() => { + let Some(command) = command else { + break; + }; + + self.process_command(command).await; + } + } + } + self.connector_handle.reset(); + } + async fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { @@ -174,11 +193,8 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic RequestCommand::Reset => { self.connector_handle.reset(); } - RequestCommand::NextApiConfig(completion_tx) => { - if let Some(connection_mode) = self.proxy_config_provider.next().await { - self.connector_handle.set_connection_mode(connection_mode); - } - let _ = completion_tx.send(Ok(())); + RequestCommand::NextApiConfig => { + self.connection_mode_provider.rotate().await; } } } @@ -201,8 +217,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic if err.is_network_error() && !api_availability.get_state().is_offline() { log::error!("{}", err.display_chain_with_msg("HTTP request failed")); if let Some(tx) = tx { - let (completion_tx, _completion_rx) = oneshot::channel(); - let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx)); + let _ = tx.unbounded_send(RequestCommand::NextApiConfig); } } } @@ -210,13 +225,6 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic let _ = completion_tx.send(response); }); } - - async fn into_future(mut self) { - while let Some(command) = self.command_rx.next().await { - self.process_command(command).await; - } - self.connector_handle.reset(); - } } #[derive(Clone)] @@ -239,15 +247,6 @@ impl RequestServiceHandle { .map_err(|_| Error::RestServiceDown)?; completion_rx.await.map_err(|_| Error::RestServiceDown)? } - - /// Forcibly update the connection mode. - pub async fn next_api_endpoint(&self) -> Result<()> { - let (completion_tx, completion_rx) = oneshot::channel(); - self.tx - .unbounded_send(RequestCommand::NextApiConfig(completion_tx)) - .map_err(|_| Error::RestServiceDown)?; - completion_rx.await.map_err(|_| Error::RestServiceDown)? - } } #[derive(Debug)] @@ -257,7 +256,7 @@ pub(crate) enum RequestCommand { oneshot::Sender<std::result::Result<Response, Error>>, ), Reset, - NextApiConfig(oneshot::Sender<std::result::Result<(), Error>>), + NextApiConfig, } /// A REST request that is sent to the RequestService to be executed. diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 664fce6bfe..793d82bb5c 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -1,8 +1,4 @@ -use crate::{ - api, - settings::{self, MadeChanges}, - Daemon, EventListener, -}; +use crate::{api, settings, Daemon, EventListener}; use mullvad_api::{proxy::ApiConnectionMode, rest, ApiProxy}; use mullvad_types::{ access_method::{self, AccessMethod, AccessMethodSetting}, @@ -17,9 +13,6 @@ pub enum Error { /// Can not find access method #[error(display = "Cannot find custom access method {}", _0)] NoSuchMethod(access_method::Id), - /// Access method could not be rotate - #[error(display = "Access method could not be rotated")] - RotationFailed, /// Some error occured in the daemon's state of handling /// [`AccessMethodSetting`]s & [`ApiConnectionMode`]s #[error(display = "Error occured when handling connection settings & details")] @@ -54,42 +47,21 @@ where let id = access_method_setting.get_id(); self.settings .update(|settings| settings.api_access_methods.append(access_method_setting)) - .await - .map(|did_change| self.notify_on_change(did_change)) - .map(|_| id) - .map_err(Error::Settings) + .await?; + Ok(id) } /// Remove a [`AccessMethodSetting`] from the daemon's saved settings. - /// - /// If the [`AccessMethodSetting`] which is currently in use happens to be - /// removed, the daemon should force a rotation of the active API endpoint. pub async fn remove_access_method( &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - let did_change = self - .settings + self.settings .try_update(|settings| -> Result<(), Error> { settings.api_access_methods.remove(&access_method)?; Ok(()) }) - .await - .map_err(Error::Settings)?; - - self.notify_on_change(did_change); - // If the currently active access method is removed, a new access - // method should be selected. - // - // Notice the ordering here: It is important that the current method is - // removed before we pick a new access method. The `remove` function - // will ensure that atleast one access method is enabled after the - // removal. If the currently active access method is removed, some other - // method is enabled before we pick the next access method to use. - if self.is_in_use(access_method.clone()).await? { - self.force_api_endpoint_rotation().await?; - } - + .await?; Ok(()) } @@ -110,19 +82,18 @@ where &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - let mut access_method = self.get_api_access_method(access_method)?; - // Toggle the enabled status if needed - if !access_method.enabled() { - access_method.enable(); - self.update_access_method_inner(access_method.clone()) - .await? - } - // Set `access_method` as the next access method to use - self.connection_modes_handler - .set_access_method(access_method) + self.settings + .update(|settings| { + settings.api_access_methods.update( + |setting| setting.get_id() == access_method, + |setting| setting.enable(), + ); + }) .await?; - // Force a rotation of Access Methods - self.force_api_endpoint_rotation().await + self.access_mode_handler + .use_access_method(access_method) + .await?; + Ok(()) } pub fn get_api_access_method( @@ -140,88 +111,28 @@ where /// Updates a [`AccessMethodSetting`] by replacing the existing entry with /// the argument `access_method_update`. if an entry with a matching /// [`access_method::Id`] is found. - /// - /// If the currently active [`AccessMethodSetting`] is updated, the daemon - /// will automatically use this updated [`AccessMethodSetting`] when - /// performing subsequent API calls. pub async fn update_access_method( &mut self, access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - self.update_access_method_inner(access_method_update.clone()) - .await?; - - if self.is_in_use(access_method_update.get_id()).await? { - if access_method_update.disabled() { - // If the currently active access method is updated & disabled - // we should select the next access method - self.force_api_endpoint_rotation().await?; - } else { - // If the currently active access method is just updated, we - // need to re-set it after updating the settings - self.use_api_access_method(access_method_update.get_id()) - .await?; - } - } - - Ok(()) - } - - /// Updates a [`AccessMethodSetting`] by replacing the existing entry with - /// the argument `access_method_update`. if an entry with a matching - /// [`access_method::Id`] is found. - /// - /// This inner function does not perform any kind of check to see if the - /// existing, in-use setting needs to be re-set. - async fn update_access_method_inner( - &mut self, - access_method_update: AccessMethodSetting, - ) -> Result<(), Error> { - let settings_update = |settings: &mut Settings| { - let target = access_method_update.get_id(); - settings.api_access_methods.update( - |access_method| access_method.get_id() == target, - |_| access_method_update, - ); - }; - self.settings - .update(settings_update) - .await - .map(|did_change| self.notify_on_change(did_change)) - .map_err(Error::Settings)?; + .update(|settings: &mut Settings| { + let target = access_method_update.get_id(); + settings.api_access_methods.update( + |access_method| access_method.get_id() == target, + |method| *method = access_method_update, + ); + }) + .await?; Ok(()) } - /// Check if some access method is the same as the currently active one. - /// - /// This can be useful for invalidating stale states. - async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> { - Ok(access_method == self.get_current_access_method().await?.get_id()) - } - /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> { - self.connection_modes_handler - .get_current() - .await - .map(|current| current.setting) - .map_err(Error::ApiService) - } - - /// Change which [`AccessMethodSetting`] which will be used as the Mullvad - /// API endpoint. - async fn force_api_endpoint_rotation(&self) -> Result<(), Error> { - self.api_handle - .service() - .next_api_endpoint() - .await - .map_err(|error| { - log::error!("Failed to rotate API endpoint: {}", error); - Error::RotationFailed - }) + let current = self.access_mode_handler.get_current().await?; + Ok(current.setting) } /// Test if the API is reachable via `proxy`. @@ -259,11 +170,11 @@ where } /// Create an [`ApiProxy`] which will perform all REST requests against one - /// specific endpoint `proxy_provider`. - pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy { + /// specific endpoint `connection_mode`. + pub fn create_limited_api_proxy(&mut self, connection_mode: ApiConnectionMode) -> ApiProxy { let rest_handle = self .api_runtime - .mullvad_rest_handle(proxy_provider, futures::stream::empty()); + .mullvad_rest_handle(connection_mode.into_provider()); ApiProxy::new(rest_handle) } @@ -273,21 +184,6 @@ where /// * Returns `Ok(false)` if the API returned an unexpected result /// * Returns `Err(..)` if the API could not be reached async fn perform_api_request(api_proxy: ApiProxy) -> Result<bool, Error> { - api_proxy.api_addrs_available().await.map_err(Error::Rest) - } - - /// If settings were changed due to an update, notify all listeners. - fn notify_on_change(&mut self, settings_changed: MadeChanges) -> &mut Self { - if settings_changed { - self.event_listener - .notify_settings(self.settings.to_settings()); - - let handle = self.connection_modes_handler.clone(); - let new_access_methods = self.settings.api_access_methods.clone(); - tokio::spawn(async move { - let _ = handle.update_access_methods(new_access_methods).await; - }); - }; - self + Ok(api_proxy.api_addrs_available().await?) } } diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index 1e6ba296a3..ccdf9e7bf3 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -8,16 +8,16 @@ use crate::DaemonCommand; use crate::DaemonEventSender; use futures::{ channel::{mpsc, oneshot}, - Stream, StreamExt, + StreamExt, }; use mullvad_api::{ availability::ApiAvailabilityHandle, - proxy::{ApiConnectionMode, ProxyConfig}, + proxy::{ApiConnectionMode, ConnectionModeProvider, ProxyConfig}, AddressCache, }; use mullvad_relay_selector::RelaySelector; use mullvad_types::access_method::{ - AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings, + AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Id, Settings, }; use std::{net::SocketAddr, path::PathBuf}; use talpid_core::mpsc::Sender; @@ -27,8 +27,8 @@ use talpid_types::net::{ pub enum Message { Get(ResponseTx<ResolvedConnectionMode>), - Set(ResponseTx<()>, AccessMethodSetting), - Next(ResponseTx<ApiConnectionMode>), + Use(ResponseTx<()>, Id), + Rotate(ResponseTx<ApiConnectionMode>), Update(ResponseTx<()>, Settings), Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting), } @@ -113,8 +113,8 @@ impl std::fmt::Display for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Message::Get(_) => f.write_str("Get"), - Message::Set(..) => f.write_str("Set"), - Message::Next(_) => f.write_str("Next"), + Message::Use(..) => f.write_str("Set"), + Message::Rotate(_) => f.write_str("Rotate"), Message::Update(..) => f.write_str("Update"), Message::Resolve(..) => f.write_str("Resolve"), } @@ -159,8 +159,8 @@ impl AccessModeSelectorHandle { }) } - pub async fn set_access_method(&self, value: AccessMethodSetting) -> Result<()> { - self.send_command(|tx| Message::Set(tx, value)) + pub async fn use_access_method(&self, value: Id) -> Result<()> { + self.send_command(|tx| Message::Use(tx, value)) .await .map_err(|err| { log::debug!("Failed to set new access method!"); @@ -186,30 +186,51 @@ impl AccessModeSelectorHandle { }) } - pub async fn next(&self) -> Result<ApiConnectionMode> { - self.send_command(Message::Next).await.map_err(|err| { + pub async fn rotate(&self) -> Result<ApiConnectionMode> { + self.send_command(Message::Rotate).await.map_err(|err| { log::debug!("Failed while getting the next access method"); err }) } +} - /// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the - /// associated [`AccessModeSelector`]. - /// - /// Calling `next` on this stream will poll for the next access method, - /// which will be lazily produced (on-demand rather than speculatively). - pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> { - futures::stream::unfold(self, |handle| async move { - match handle.next().await { - Ok(connection_mode) => Some((connection_mode, handle)), - // End this stream in case of failure in `next`. `next` should - // not fail if the actor is in a good state. - Err(_) => None, - } +pub struct AccessModeConnectionModeProvider { + initial: ApiConnectionMode, + handle: AccessModeSelectorHandle, + change_rx: mpsc::UnboundedReceiver<ApiConnectionMode>, +} + +impl AccessModeConnectionModeProvider { + fn new( + handle: AccessModeSelectorHandle, + initial_connection_mode: ApiConnectionMode, + change_rx: mpsc::UnboundedReceiver<ApiConnectionMode>, + ) -> Result<Self> { + Ok(Self { + initial: initial_connection_mode, + handle, + change_rx, }) } } +impl ConnectionModeProvider for AccessModeConnectionModeProvider { + fn initial(&self) -> ApiConnectionMode { + self.initial.clone() + } + + fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send { + self.change_rx.next() + } + + fn rotate(&self) -> impl std::future::Future<Output = ()> + Send { + let handle = self.handle.clone(); + async move { + handle.rotate().await.ok(); + } + } +} + /// A small actor which takes care of handling the logic around rotating /// connection modes to be used for Mullvad API request. /// @@ -226,28 +247,40 @@ pub struct AccessModeSelector { access_method_settings: Settings, address_cache: AddressCache, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, + connection_mode_provider_sender: mpsc::UnboundedSender<ApiConnectionMode>, current: ResolvedConnectionMode, /// `index` is used to keep track of the [`AccessMethodSetting`] to use. index: usize, - /// `set` is used to set the next [`AccessMethodSetting`] to use. - set: Option<AccessMethodSetting>, } impl AccessModeSelector { pub(crate) async fn spawn( cache_dir: PathBuf, relay_selector: RelaySelector, - access_method_settings: Settings, + #[cfg_attr(not(feature = "api-override"), allow(unused_mut))] + mut access_method_settings: Settings, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, address_cache: AddressCache, - ) -> Result<AccessModeSelectorHandle> { + ) -> Result<(AccessModeSelectorHandle, AccessModeConnectionModeProvider)> { let (cmd_tx, cmd_rx) = mpsc::unbounded(); + #[cfg(feature = "api-override")] + { + if mullvad_api::API.force_direct { + access_method_settings + .update(|setting| setting.is_direct(), |setting| setting.enable()); + } + } + // Always start looking from the position of `Direct`. - let (index, next) = Self::select_next_active(0, &access_method_settings); + let (index, next) = Self::find_next_active(0, &access_method_settings); let initial_connection_mode = Self::resolve_inner(next, &relay_selector, &address_cache).await; + let (change_tx, change_rx) = mpsc::unbounded(); + + let api_connection_mode = initial_connection_mode.connection_mode.clone(); + let selector = AccessModeSelector { cmd_rx, cache_dir, @@ -255,14 +288,19 @@ impl AccessModeSelector { access_method_settings, address_cache, access_method_event_sender, + connection_mode_provider_sender: change_tx, current: initial_connection_mode, index, - set: None, }; tokio::spawn(selector.into_future()); - Ok(AccessModeSelectorHandle { cmd_tx }) + let handle = AccessModeSelectorHandle { cmd_tx }; + + let connection_mode_provider = + AccessModeConnectionModeProvider::new(handle.clone(), api_connection_mode, change_rx)?; + + Ok((handle, connection_mode_provider)) } async fn into_future(mut self) { @@ -270,9 +308,9 @@ impl AccessModeSelector { log::trace!("Processing {cmd} command"); let execution = match cmd { Message::Get(tx) => self.on_get_access_method(tx), - Message::Set(tx, value) => self.on_set_access_method(tx, value), - Message::Next(tx) => self.on_next_connection_mode(tx).await, - Message::Update(tx, values) => self.on_update_access_methods(tx, values), + Message::Use(tx, id) => self.on_use_access_method(tx, id).await, + Message::Rotate(tx) => self.on_next_connection_mode(tx).await, + Message::Update(tx, values) => self.on_update_access_methods(tx, values).await, Message::Resolve(tx, setting) => self.on_resolve_access_method(tx, setting).await, }; match execution { @@ -297,26 +335,32 @@ impl AccessModeSelector { self.reply(tx, self.current.clone()) } - fn on_set_access_method( - &mut self, - tx: ResponseTx<()>, - value: AccessMethodSetting, - ) -> Result<()> { - self.set_access_method(value); + async fn on_use_access_method(&mut self, tx: ResponseTx<()>, id: Id) -> Result<()> { + self.use_access_method(id).await; self.reply(tx, ()) } - /// Set the next access method to be returned by the [`Stream`] produced by - /// calling `into_stream`. - fn set_access_method(&mut self, value: AccessMethodSetting) { - if let Some(index) = self - .access_method_settings - .iter() - .position(|access_method| access_method.get_id() == value.get_id()) + /// Set and announce the specified access method as the current one. + async fn use_access_method(&mut self, id: Id) { + #[cfg(feature = "api-override")] { - self.index = index; - self.set = Some(value); + if mullvad_api::API.force_direct { + log::debug!("API proxies are disabled"); + return; + } } + + let Some((index, method)) = self + .access_method_settings + .iter() + .enumerate() + .find(|(_, access_method)| access_method.get_id() == id) + else { + return; + }; + + self.index = index; + self.set_current(method.to_owned()).await; } async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> { @@ -327,22 +371,8 @@ impl AccessModeSelector { async fn next_connection_mode(&mut self) -> Result<ApiConnectionMode> { #[cfg(feature = "api-override")] { - use mullvad_api::API; - if API.force_direct { + if mullvad_api::API.force_direct { log::debug!("API proxies are disabled"); - let endpoint = resolve_allowed_endpoint( - &ApiConnectionMode::Direct, - // Note that the address cache *should* be initialized with - // the overridden API endpoint, so we can simply fetch the - // endpoint address from it. - self.address_cache.get_address().await, - ); - let daemon_sender = self.access_method_event_sender.clone(); - tokio::spawn(async move { - let _ = AccessMethodEvent::Allow { endpoint } - .send(daemon_sender) - .await; - }); return Ok(ApiConnectionMode::Direct); } @@ -352,12 +382,16 @@ impl AccessModeSelector { ); } - let access_method = self.get_next(); - log::info!( - "A new API access method has been selected: {name}", - name = access_method.name - ); + let (next_index, next) = + Self::find_next_active(self.index + 1, &self.access_method_settings); + self.index = next_index; + self.set_current(next).await; + Ok(self.current.connection_mode.clone()) + } + + async fn set_current(&mut self, access_method: AccessMethodSetting) { let resolved = self.resolve(access_method).await; + // Note: If the daemon is busy waiting for a call to this function // to complete while we wait for the daemon to fully handle this // `NewAccessMethodEvent`, then we find ourselves in a deadlock. @@ -386,26 +420,24 @@ impl AccessModeSelector { } }); + // Notify REST client + let _ = self + .connection_mode_provider_sender + .unbounded_send(resolved.connection_mode.clone()); + self.current = resolved; - Ok(self.current.connection_mode.clone()) - } - fn get_next(&mut self) -> AccessMethodSetting { - if let Some(access_method) = self.set.take() { - access_method - } else { - let (next_index, next) = - Self::select_next_active(self.index + 1, &self.access_method_settings); - self.index = next_index; - next - } + log::info!( + "A new API access method has been selected: {name}", + name = self.current.setting.name + ); } /// Find the next access method to use. /// /// * `start`: From which point in `access_methods` to start the search. /// * `access_methods`: The search space. - fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { + fn find_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { access_methods .iter() .cloned() @@ -416,26 +448,46 @@ impl AccessModeSelector { .find(|(_index, access_method)| access_method.enabled()) .unwrap_or_else(|| (0, access_methods.direct().clone())) } - fn on_update_access_methods( + + async fn on_update_access_methods( &mut self, tx: ResponseTx<()>, access_methods: Settings, ) -> Result<()> { - self.update_access_methods(access_methods); + self.update_access_methods(access_methods).await?; self.reply(tx, ()) } - fn update_access_methods(&mut self, access_methods: Settings) { - let removed_active = !access_methods + async fn update_access_methods(&mut self, access_methods: Settings) -> Result<()> { + self.access_method_settings = access_methods; + + let new_current = self + .access_method_settings .iter() - .any(|access_method| access_method.get_id() == self.current.setting.get_id()); - if removed_active { - // A new access mehtod will suddenly have the same index as the one - // we are removing, but we want it to still be a candidate. A minor - // hack to achieve this is to simply decrement the current index. - self.index = self.index.saturating_sub(1); + .enumerate() + .find(|(_, access_method)| access_method.get_id() == self.current.setting.get_id()); + + match new_current { + Some((index, new_current)) => { + // If the current method was modified, announce changes + self.index = index; + if self.current.setting != *new_current { + if new_current.enabled() { + self.set_current(new_current.to_owned()).await; + } else { + self.next_connection_mode().await?; + } + } + } + None => { + // Current method was removed: A new access method will suddenly have the same index as the one + // we are removing, but we want it to still be a candidate. A minor + // hack to achieve this is to simply decrement the current index. + self.index = self.index.saturating_sub(1); + self.next_connection_mode().await?; + } } - self.access_method_settings = access_methods; + Ok(()) } pub async fn on_resolve_access_method( diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index ef4bcb86e0..e446735df0 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -630,7 +630,7 @@ pub struct Daemon<L: EventListener> { account_history: account_history::AccountHistory, device_checker: device::TunnelStateChangeHandler, account_manager: device::AccountManagerHandle, - connection_modes_handler: api::AccessModeSelectorHandle, + access_mode_handler: api::AccessModeSelectorHandle, api_runtime: mullvad_api::Runtime, api_handle: mullvad_api::rest::MullvadRestHandle, version_updater_handle: version_check::VersionUpdaterHandle, @@ -707,7 +707,7 @@ where .set_config(new_selector_config(settings)); }); - let connection_modes_handler = api::AccessModeSelector::spawn( + let (access_mode_handler, access_mode_provider) = api::AccessModeSelector::spawn( cache_dir.clone(), relay_selector.clone(), settings.api_access_methods.clone(), @@ -717,15 +717,16 @@ where .await .map_err(Error::ApiConnectionModeError)?; - let initial_connection_mode = connection_modes_handler - .get_current() - .await - .map_err(Error::ApiConnectionModeError)?; + let api_handle = api_runtime.mullvad_rest_handle(access_mode_provider); - let api_handle = api_runtime.mullvad_rest_handle( - initial_connection_mode.connection_mode, - Box::pin(connection_modes_handler.clone().into_stream()), - ); + let access_method_handle = access_mode_handler.clone(); + settings.register_change_listener(move |settings| { + let handle = access_method_handle.clone(); + let new_access_methods = settings.api_access_methods.clone(); + tokio::spawn(async move { + let _ = handle.update_access_methods(new_access_methods).await; + }); + }); let migration_complete = if let Some(migration_data) = migration_data { migrations::migrate_device( @@ -801,7 +802,11 @@ where allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options), - allowed_endpoint: initial_connection_mode.endpoint, + allowed_endpoint: access_mode_handler + .get_current() + .await + .map_err(Error::ApiConnectionModeError)? + .endpoint, reset_firewall: *target_state != TargetState::Secured, #[cfg(windows)] exclude_paths, @@ -874,7 +879,7 @@ where account_history, device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()), account_manager, - connection_modes_handler, + access_mode_handler, api_runtime, api_handle, version_updater_handle, @@ -2117,9 +2122,12 @@ where { Ok(settings_changes) => { if settings_changes { - if let Err(error) = self.api_handle.service().next_api_endpoint().await { - log::error!("Failed to rotate API endpoint: {}", error); - } + let access_mode_handler = self.access_mode_handler.clone(); + tokio::spawn(async move { + if let Err(error) = access_mode_handler.rotate().await { + log::error!("Failed to rotate API endpoint: {error}"); + } + }); self.reconnect_tunnel(); }; Self::oneshot_send(tx, Ok(()), "set_bridge_settings"); @@ -2466,7 +2474,7 @@ where } fn on_get_current_api_access_method(&mut self, tx: ResponseTx<AccessMethodSetting, Error>) { - let handle = self.connection_modes_handler.clone(); + let handle = self.access_mode_handler.clone(); tokio::spawn(async move { let result = handle .get_current() @@ -2493,7 +2501,7 @@ where }; let daemon_event_sender = self.tx.to_specialized_sender(); - let access_method_selector = self.connection_modes_handler.clone(); + let access_method_selector = self.access_mode_handler.clone(); tokio::spawn(async move { let result = Self::test_access_method( proxy_endpoint, @@ -2524,7 +2532,7 @@ where } }; - let test_subject = match self.connection_modes_handler.resolve(access_method).await { + let test_subject = match self.access_mode_handler.resolve(access_method).await { Ok(test_subject) => test_subject, Err(err) => { reply(Err(Error::ApiConnectionModeError(err))); @@ -2534,7 +2542,7 @@ where let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode); let daemon_event_sender = self.tx.to_specialized_sender(); - let access_method_selector = self.connection_modes_handler.clone(); + let access_method_selector = self.access_mode_handler.clone(); tokio::spawn(async move { let result = Self::test_access_method( diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index bcd820bef5..05502bcc46 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -301,7 +301,7 @@ async fn send_problem_report_inner( let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await; let api_client = mullvad_api::ProblemReportProxy::new( - api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()), + api_runtime.mullvad_rest_handle(connection_mode.into_provider()), ); for _attempt in 0..MAX_SEND_ATTEMPTS { diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 4b14319414..e361d41a2b 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -160,7 +160,7 @@ async fn remove_device() -> Result<(), Error> { let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await; let proxy = mullvad_api::DevicesProxy::new( - api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()), + api_runtime.mullvad_rest_handle(connection_mode.into_provider()), ); let device_removal = retry_future( diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs index e8f6bd4a5d..fb365101be 100644 --- a/mullvad-types/src/access_method.rs +++ b/mullvad-types/src/access_method.rs @@ -62,11 +62,11 @@ impl Settings { pub fn update( &mut self, predicate: impl Fn(&AccessMethodSetting) -> bool, - f: impl FnOnce(&AccessMethodSetting) -> AccessMethodSetting, + f: impl FnOnce(&mut AccessMethodSetting), ) -> bool { let mut updated = false; if let Some(access_method) = self.iter_mut().find(|setting| predicate(setting)) { - *access_method = f(access_method); + f(access_method); updated = true; } self.ensure_consistent_state(); @@ -241,6 +241,13 @@ impl AccessMethodSetting { self.as_custom().is_none() } + pub fn is_direct(&self) -> bool { + matches!( + self.access_method, + AccessMethod::BuiltIn(BuiltInAccessMethod::Direct) + ) + } + /// Set an API access method to be enabled. pub fn enable(&mut self) { self.enabled = true; diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 1eeeb8c170..56bcadafbd 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -237,10 +237,7 @@ pub fn new_device_client() -> DevicesProxy { let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("failed to create api runtime"); - let rest_handle = api.mullvad_rest_handle( - ApiConnectionMode::Direct, - ApiConnectionMode::Direct.into_repeat(), - ); + let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()); DevicesProxy::new(rest_handle) } |
