diff options
| -rw-r--r-- | mullvad-api/src/lib.rs | 12 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 4 | ||||
| -rw-r--r-- | mullvad-cli/src/cmds/api_access.rs | 43 | ||||
| -rw-r--r-- | mullvad-daemon/src/access_method.rs | 172 | ||||
| -rw-r--r-- | mullvad-daemon/src/api.rs | 292 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 96 | ||||
| -rw-r--r-- | mullvad-daemon/src/management_interface.rs | 8 | ||||
| -rw-r--r-- | mullvad-management-interface/proto/management_interface.proto | 4 | ||||
| -rw-r--r-- | mullvad-management-interface/src/client.rs | 14 | ||||
| -rw-r--r-- | mullvad-management-interface/src/types/conversions/net.rs | 21 | ||||
| -rw-r--r-- | mullvad-types/src/access_method.rs | 8 |
11 files changed, 447 insertions, 227 deletions
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index c0024b22ee..c8765ec2b2 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -620,4 +620,16 @@ impl ApiProxy { let response = self.handle.service.request(request).await?; response.deserialize().await } + + /// Check the availablility of `{APP_URL_PREFIX}/api-addrs`. + pub async fn api_addrs_available(&self) -> Result<bool, rest::Error> { + let request = self + .handle + .factory + .head(&format!("{APP_URL_PREFIX}/api-addrs"))? + .expected_status(&[StatusCode::OK]); + + let response = self.handle.service.request(request).await?; + Ok(response.status().is_success()) + } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 2484bec64b..559ddd4b4e 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -524,6 +524,10 @@ impl RequestFactory { self.request(path, Method::DELETE) } + pub fn head(&self, path: &str) -> Result<Request> { + self.request(path, Method::HEAD) + } + pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> { self.json_request(Method::POST, path, body) } diff --git a/mullvad-cli/src/cmds/api_access.rs b/mullvad-cli/src/cmds/api_access.rs index 9ad481c4ef..c6e01c52d6 100644 --- a/mullvad-cli/src/cmds/api_access.rs +++ b/mullvad-cli/src/cmds/api_access.rs @@ -182,25 +182,16 @@ impl ApiAccess { /// Test an access method to see if it successfully reaches the Mullvad API. async fn test(item: SelectItem) -> Result<()> { let mut rpc = MullvadProxyClient::new().await?; - // Retrieve the currently used access method. We will reset to this - // after we are done testing. - let previous_access_method = rpc.get_current_api_access_method().await?; let access_method = Self::get_access_method(&mut rpc, &item).await?; println!("Testing access method \"{}\"", access_method.name); - rpc.set_access_method(access_method.get_id()).await?; - // Make the daemon perform an network request which involves talking to the Mullvad API. - let result = match rpc.get_api_addresses().await { - Ok(_) => { + match rpc.test_api_access_method(access_method.get_id()).await { + Ok(true) => { println!("Success!"); Ok(()) } - Err(_) => Err(anyhow!("Could not reach the Mullvad API")), - }; - // In any case, switch back to the previous access method. - rpc.set_access_method(previous_access_method.get_id()) - .await?; - result + Ok(false) | Err(_) => Err(anyhow!("Could not reach the Mullvad API.")), + } } /// Try to use of a specific [`AccessMethodSetting`] for subsequent calls to @@ -217,30 +208,24 @@ impl ApiAccess { /// configured ones. async fn set(item: SelectItem) -> Result<()> { let mut rpc = MullvadProxyClient::new().await?; - let previous_access_method = rpc.get_current_api_access_method().await?; let mut new_access_method = Self::get_access_method(&mut rpc, &item).await?; + let current_access_method = rpc.get_current_api_access_method().await?; // Try to reach the API with the newly selected access method. + rpc.test_api_access_method(new_access_method.get_id()) + .await + .map_err(|_| { + anyhow!("Could not reach the Mullvad API using access method \"{}\". Rolling back to \"{}\"", new_access_method.get_name(), current_access_method.get_name()) + })? + + ; + // If the test succeeded, the new access method should be used from now on. rpc.set_access_method(new_access_method.get_id()).await?; - match rpc.get_api_addresses().await { - Ok(_) => (), - Err(_) => { - // Roll-back to the previous access method - rpc.set_access_method(previous_access_method.get_id()) - .await?; - return Err(anyhow!( - "Could not reach the Mullvad API using access method \"{}\"", - new_access_method.get_name(), - )); - } - }; - // It worked! Let the daemon keep using this access method. - let display_name = new_access_method.get_name(); + println!("Using access method \"{}\"", new_access_method.get_name()); // Toggle the enabled status if needed if !new_access_method.enabled() { new_access_method.enable(); rpc.update_access_method(new_access_method).await?; } - println!("Using access method \"{}\"", display_name); Ok(()) } diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 4584aa374a..7d9d3dba95 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -1,7 +1,9 @@ use crate::{ + api::{self, AccessModeSelectorHandle}, settings::{self, MadeChanges}, Daemon, EventListener, }; +use mullvad_api::rest::{self, MullvadRestHandle}; use mullvad_types::{ access_method::{self, AccessMethod, AccessMethodSetting}, settings::Settings, @@ -18,13 +20,15 @@ pub enum Error { /// Can not find access method #[error(display = "Cannot find custom access method {}", _0)] NoSuchMethod(access_method::Id), - /// Can not find *any* access method. This should never happen. If it does, - /// the user should do a factory reset. - #[error(display = "No access methods are configured")] - NoMethodsExist, /// Access method could not be rotate #[error(display = "Access method could not be rotated")] RotationError, + /// Some error occured in the daemon's state of handling + /// [`AccessMethodSetting`]s & [`ApiConnectionMode`]s. + #[error(display = "Error occured when handling connection settings & details")] + ConnectionMode(#[error(source)] api::Error), + #[error(display = "API endpoint rotation failed")] + RestError(#[error(source)] rest::Error), /// Access methods settings error #[error(display = "Settings error")] Settings(#[error(source)] settings::Error), @@ -81,7 +85,9 @@ where Some(api_access_method) => { if api_access_method.is_builtin() { Err(Error::RemoveBuiltIn) - } else if api_access_method.get_id() == self.get_current_access_method()?.get_id() { + } else if api_access_method.get_id() + == self.get_current_access_method().await?.get_id() + { Ok(Command::Rotate) } else { Ok(Command::Nothing) @@ -108,15 +114,10 @@ where &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - let access_method = self - .settings - .api_access_methods - .find(&access_method) - .ok_or(Error::NoSuchMethod(access_method))?; - { - let mut connection_modes = self.connection_modes.lock().unwrap(); - connection_modes.set_access_method(access_method.clone()); - } + let access_method = self.get_api_access_method(access_method)?; + self.connection_modes_handler + .set_access_method(access_method) + .await?; // Force a rotation of Access Methods. // // This is not a call to `process_command` due to the restrictions on @@ -124,6 +125,17 @@ where self.force_api_endpoint_rotation().await } + pub fn get_api_access_method( + &mut self, + access_method: access_method::Id, + ) -> Result<AccessMethodSetting, Error> { + self.settings + .api_access_methods + .find(&access_method) + .ok_or(Error::NoSuchMethod(access_method)) + .cloned() + } + /// "Updates" an [`AccessMethodSetting`] by replacing the existing entry /// with the argument `access_method_update` if an existing entry with /// matching [`access_method::Id`] is found. @@ -140,7 +152,7 @@ where // in the daemon's settings. Therefore, we have to safeguard against // this by explicitly checking for & disallow any update which would // cause the last enabled access method to become disabled. - let current = self.get_current_access_method()?; + let current = self.get_current_access_method().await?; let mut command = Command::Nothing; let settings_update = |settings: &mut Settings| { if let Some(access_method) = settings @@ -165,9 +177,8 @@ where /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. - pub fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> { - let connections_modes = self.connection_modes.lock().unwrap(); - Ok(connections_modes.peek()) + pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> { + Ok(self.connection_modes_handler.get_access_method().await?) } /// Change which [`AccessMethodSetting`] which will be used to figure out @@ -189,29 +200,21 @@ where self.event_listener .notify_settings(self.settings.to_settings()); - let access_methods: Vec<_> = self - .settings - .api_access_methods - .access_method_settings - .iter() - .filter(|api_access_method| api_access_method.enabled()) - .cloned() - .collect(); - - let mut connection_modes = self.connection_modes.lock().unwrap(); - match connection_modes.update_access_methods(access_methods) { - Ok(_) => (), - Err(crate::api::Error::NoAccessMethods) => { - // `access_methods` was empty! This implies that the user - // disabled all access methods. If we ever get into this - // state, we should default to using the direct access - // method. - let default = access_method::Settings::direct(); - connection_modes - .update_access_methods(vec![default]) - .expect("Failed to create the data structure responsible for managing access methods"); + let handle = self.connection_modes_handler.clone(); + let new_access_methods = self.settings.api_access_methods.collect_enabled(); + tokio::spawn(async move { + match handle.update_access_methods(new_access_methods).await { + Ok(_) => (), + Err(api::Error::NoAccessMethods) | Err(_) => { + // `access_methods` was empty! This implies that the user + // disabled all access methods. If we ever get into this + // state, we should default to using the direct access + // method. + let default = access_method::Settings::direct(); + handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods"); + } } - } + }); }; self } @@ -225,3 +228,92 @@ where } } } + +/// Try to reach the Mullvad API using a specific access method, returning +/// an [`Error`] in the case where the test fails to reach the API. +/// +/// Ephemerally sets a new access method (associated with `access_method`) +/// to be used for subsequent API calls, before performing an API call and +/// switching back to the previously active access method. The previous +/// access method is *always* reset. +pub async fn test_access_method( + new_access_method: AccessMethodSetting, + access_mode_selector: AccessModeSelectorHandle, + rest_handle: MullvadRestHandle, +) -> Result<bool, Error> { + // Setup test + let previous_access_method = access_mode_selector + .get_access_method() + .await + .map_err(Error::ConnectionMode)?; + + let method_under_test = new_access_method.clone(); + access_mode_selector + .set_access_method(new_access_method) + .await + .map_err(Error::ConnectionMode)?; + + // We need to perform a rotation of API endpoint after a set action + let rotation_handle = rest_handle.clone(); + rotation_handle + .service() + .next_api_endpoint() + .await + .map_err(|err| { + log::error!("Failed to rotate API endpoint: {err}"); + Error::RestError(err) + })?; + + // Set up the reset + // + // In case the API call fails, the next API endpoint will + // automatically be selected, which means that we need to set up + // with the previous API endpoint beforehand. + access_mode_selector + .set_access_method(previous_access_method) + .await + .map_err(|err| { + log::error!( + "Could not reset to previous access + method after API reachability test was carried out. This should only + happen if the previous access method was removed in the meantime." + ); + Error::ConnectionMode(err) + })?; + + // Perform test + // + // Send a HEAD request to some Mullvad API endpoint. We issue a HEAD + // request because we are *only* concerned with if we get a reply from + // the API, and not with the actual data that the endpoint returns. + let result = mullvad_api::ApiProxy::new(rest_handle) + .api_addrs_available() + .await + .map_err(Error::RestError)?; + + // We need to perform a rotation of API endpoint after a set action + // Note that this will be done automatically if the API call fails, + // so it only has to be done if the call succeeded .. + if result { + rotation_handle + .service() + .next_api_endpoint() + .await + .map_err(|err| { + log::error!("Failed to rotate API endpoint: {err}"); + Error::RestError(err) + })?; + } + + log::info!( + "The result of testing {method:?} is {result}", + method = method_under_test.access_method, + result = if result { + "success".to_string() + } else { + "failed".to_string() + } + ); + + Ok(result) +} diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index d5099ae74a..2da307ff5f 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -1,8 +1,14 @@ +//! This module is responsible for enabling custom [`AccessMethodSetting`]s to +//! be used when connecting to the Mullvad API. In practice this means +//! converting [`AccessMethodSetting`]s to connection details as encoded by +//! [`ApiConnectionMode`], which in turn is used by `mullvad-api` for +//! establishing connections when performing API requests. #[cfg(target_os = "android")] use crate::{DaemonCommand, DaemonEventSender}; use futures::{ channel::{mpsc, oneshot}, - Future, Stream, StreamExt, + stream::unfold, + Stream, StreamExt, }; use mullvad_api::{ availability::ApiAvailabilityHandle, @@ -13,109 +19,239 @@ use mullvad_relay_selector::RelaySelector; use mullvad_types::access_method::{self, AccessMethod, AccessMethodSetting, BuiltInAccessMethod}; use std::{ path::PathBuf, - pin::Pin, sync::{Arc, Mutex, Weak}, - task::Poll, }; #[cfg(target_os = "android")] use talpid_core::mpsc::Sender; use talpid_core::tunnel_state_machine::TunnelCommand; -use talpid_types::{ - net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint}, - ErrorExt, -}; +use talpid_types::net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint}; + +pub enum Message { + Get(ResponseTx<AccessMethodSetting>), + Set(ResponseTx<()>, AccessMethodSetting), + Next(ResponseTx<ApiConnectionMode>), + Update(ResponseTx<()>, Vec<AccessMethodSetting>), +} + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "No access methods were provided.")] + NoAccessMethods, + #[error(display = "AccessModeSelector is not receiving any messages.")] + SendFailed(#[error(source)] mpsc::TrySendError<Message>), + #[error(display = "AccessModeSelector is not receiving any messages.")] + OneshotSendFailed, + #[error(display = "AccessModeSelector is not responding.")] + NotRunning(#[error(source)] oneshot::Canceled), +} + +type ResponseTx<T> = oneshot::Sender<Result<T>>; +type Result<T> = std::result::Result<T, Error>; + +/// A channel for sending [`Message`] commands to a running +/// [`AccessModeSelector`]. +#[derive(Clone)] +pub struct AccessModeSelectorHandle { + cmd_tx: mpsc::UnboundedSender<Message>, +} + +impl AccessModeSelectorHandle { + async fn send_command<T>(&self, make_cmd: impl FnOnce(ResponseTx<T>) -> Message) -> Result<T> { + let (tx, rx) = oneshot::channel(); + self.cmd_tx + .unbounded_send(make_cmd(tx)) + .map_err(Error::SendFailed)?; + rx.await.map_err(Error::NotRunning)? + } + + pub async fn get_access_method(&self) -> Result<AccessMethodSetting> { + self.send_command(Message::Get).await.map_err(|err| { + log::error!("Failed to get current access method!"); + err + }) + } + + pub async fn set_access_method(&self, value: AccessMethodSetting) -> Result<()> { + self.send_command(|tx| Message::Set(tx, value)) + .await + .map_err(|err| { + log::error!("Failed to set new access method!"); + err + }) + } -/// A stream that returns the next API connection mode to use for reaching the API. + pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> { + self.send_command(|tx| Message::Update(tx, values)) + .await + .map_err(|err| { + log::error!("Failed to update new access methods!"); + err + }) + } + + pub async fn next(&self) -> Result<ApiConnectionMode> { + self.send_command(Message::Next).await.map_err(|err| { + log::error!("Failed to update new access methods!"); + err + }) + } + + /// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the + /// associated [`AccessModeSelector`]. + /// + /// Practically converts the handle to a listener for when the + /// currently valid connection modes changes. + pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> { + unfold(self, |handle| async move { + match handle.next().await { + Ok(connection_mode) => Some((connection_mode, handle)), + // End this stream in case of failure in `next`. `next` should + // not fail if the actor is in a good state. + Err(_) => None, + } + }) + } +} + +/// A small actor which takes care of handling the logic around rotating +/// connection modes to be used for Mullvad API request. /// -/// When `mullvad-api` fails to contact the API, it requests a new connection -/// mode. The API can be connected to either directly (i.e., +/// When `mullvad-api` fails to contact the API, it will request a new +/// connection mode. The API can be connected to either directly (i.e., /// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`]) -/// or via any supported custom proxy protocol ([`api_access_methods::ObfuscationProtocol`]). +/// or via any supported custom proxy protocol +/// ([`api_access_methods::ObfuscationProtocol`]). /// /// The strategy for determining the next [`ApiConnectionMode`] is handled by /// [`ConnectionModesIterator`]. -pub struct ApiConnectionModeProvider { +pub struct AccessModeSelector { + cmd_rx: mpsc::UnboundedReceiver<Message>, cache_dir: PathBuf, /// Used for selecting a Bridge when the `Mullvad Bridges` access method is used. relay_selector: RelaySelector, - current_task: Option<Pin<Box<dyn Future<Output = ApiConnectionMode> + Send>>>, - connection_modes: Arc<Mutex<ConnectionModesIterator>>, + connection_modes: ConnectionModesIterator, } -impl Stream for ApiConnectionModeProvider { - type Item = ApiConnectionMode; +impl AccessModeSelector { + pub fn spawn( + cache_dir: PathBuf, + relay_selector: RelaySelector, + connection_modes: Vec<AccessMethodSetting>, + ) -> AccessModeSelectorHandle { + let (cmd_tx, cmd_rx) = mpsc::unbounded(); - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<Option<Self::Item>> { - // Poll the current task - if let Some(task) = self.current_task.as_mut() { - return match task.as_mut().poll(cx) { - Poll::Ready(mode) => { - self.current_task = None; - Poll::Ready(Some(mode)) - } - Poll::Pending => Poll::Pending, + let connection_modes = match ConnectionModesIterator::new(connection_modes) { + Ok(provider) => provider, + Err(Error::NoAccessMethods) | Err(_) => { + // No settings seem to have been found. Default to using the the + // direct access method. + let default = mullvad_types::access_method::Settings::direct(); + ConnectionModesIterator::new(vec![default]).expect( + "Failed to create the data structure responsible for managing access methods", + ) + } + }; + let selector = AccessModeSelector { + cmd_rx, + cache_dir, + relay_selector, + connection_modes, + }; + tokio::spawn(selector.into_future()); + AccessModeSelectorHandle { cmd_tx } + } + + async fn into_future(mut self) { + while let Some(cmd) = self.cmd_rx.next().await { + let execution = match cmd { + Message::Get(tx) => self.on_get_access_method(tx), + Message::Set(tx, value) => self.on_set_access_method(tx, value), + Message::Next(tx) => self.on_next_connection_mode(tx), + Message::Update(tx, values) => self.on_update_access_methods(tx, values), }; + match execution { + Ok(_) => (), + Err(err) => { + log::trace!( + "AccessModeSelector is going down due to {error}", + error = err + ); + break; + } + } } + } - let connection_mode = self.new_connection_mode(); + fn reply<T>(&self, tx: ResponseTx<T>, value: T) -> Result<()> { + tx.send(Ok(value)).map_err(|_| Error::OneshotSendFailed)?; + Ok(()) + } - let cache_dir = self.cache_dir.clone(); - self.current_task = Some(Box::pin(async move { - if let Err(error) = connection_mode.save(&cache_dir).await { - log::debug!( - "{}", - error.display_chain_with_msg("Failed to save API endpoint") - ); - } - connection_mode - })); + fn on_get_access_method(&mut self, tx: ResponseTx<AccessMethodSetting>) -> Result<()> { + let value = self.get_access_method(); + self.reply(tx, value) + } - self.poll_next(cx) + fn get_access_method(&mut self) -> AccessMethodSetting { + self.connection_modes.peek() } -} -impl ApiConnectionModeProvider { - pub(crate) fn new( - cache_dir: PathBuf, - relay_selector: RelaySelector, - connection_modes: Vec<AccessMethodSetting>, - ) -> Result<Self, Error> { - let connection_modes_iterator = ConnectionModesIterator::new(connection_modes)?; - Ok(Self { - cache_dir, - relay_selector, - current_task: None, - connection_modes: Arc::new(Mutex::new(connection_modes_iterator)), - }) + fn on_set_access_method( + &mut self, + tx: ResponseTx<()>, + value: AccessMethodSetting, + ) -> Result<()> { + self.set_access_method(value); + self.reply(tx, ()) } - /// Return a pointer to the underlying iterator over [`AccessMethod`]. - /// Having access to this iterator allow you to influence , e.g. by calling - /// [`ConnectionModesIterator::set_access_method()`] or - /// [`ConnectionModesIterator::update_access_methods()`]. - pub(crate) fn handle(&self) -> Arc<Mutex<ConnectionModesIterator>> { - self.connection_modes.clone() + fn set_access_method(&mut self, value: AccessMethodSetting) { + self.connection_modes.set_access_method(value); } - /// Return a new connection mode to be used for the API connection. - fn new_connection_mode(&mut self) -> ApiConnectionMode { - log::debug!("Rotating Access mode!"); - let access_method = { - let mut access_methods_picker = self.connection_modes.lock().unwrap(); - access_methods_picker - .next() - .map(|access_method_setting| access_method_setting.access_method) - .unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct)) - }; + fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> { + let next = self.next_connection_mode(); + // Save the new connection mode to cache! + { + let cache_dir = self.cache_dir.clone(); + let next = next.clone(); + tokio::spawn(async move { + if next.save(&cache_dir).await.is_err() { + log::warn!( + "Failed to save {connection_mode} to cache", + connection_mode = next + ) + } + }); + } + self.reply(tx, next) + } + + fn next_connection_mode(&mut self) -> ApiConnectionMode { + let access_method = self + .connection_modes + .next() + .map(|access_method_setting| access_method_setting.access_method) + .unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct)); let connection_mode = self.from(access_method); - log::info!("New API connection mode selected: {}", connection_mode); + log::info!("New API connection mode selected: {connection_mode}"); connection_mode } + fn on_update_access_methods( + &mut self, + tx: ResponseTx<()>, + values: Vec<AccessMethodSetting>, + ) -> Result<()> { + self.update_access_methods(values)?; + self.reply(tx, ()) + } + + fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> { + self.connection_modes.update_access_methods(values) + } + /// Ad-hoc version of [`std::convert::From::from`], but since some /// [`ApiConnectionMode`]s require extra logic/data from /// [`ApiConnectionModeProvider`] the standard [`std::convert::From`] trait @@ -172,14 +308,10 @@ pub struct ConnectionModesIterator { current: AccessMethodSetting, } -#[derive(err_derive::Error, Debug)] -pub enum Error { - #[error(display = "No access methods were provided.")] - NoAccessMethods, -} - impl ConnectionModesIterator { - pub fn new(access_methods: Vec<AccessMethodSetting>) -> Result<ConnectionModesIterator, Error> { + pub fn new( + access_methods: Vec<AccessMethodSetting>, + ) -> std::result::Result<ConnectionModesIterator, Error> { let mut iterator = Self::new_iterator(access_methods)?; Ok(Self { next: None, @@ -197,7 +329,7 @@ impl ConnectionModesIterator { pub fn update_access_methods( &mut self, access_methods: Vec<AccessMethodSetting>, - ) -> Result<(), Error> { + ) -> std::result::Result<(), Error> { self.available_modes = Self::new_iterator(access_methods)?; Ok(()) } @@ -208,7 +340,7 @@ impl ConnectionModesIterator { /// returned. fn new_iterator( access_methods: Vec<AccessMethodSetting>, - ) -> Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> { + ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> { if access_methods.is_empty() { Err(Error::NoAccessMethods) } else { diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 65f2c56654..0288f9d8c4 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -66,7 +66,7 @@ use std::{ mem, path::PathBuf, pin::Pin, - sync::{Arc, Mutex, Weak}, + sync::{Arc, Weak}, time::Duration, }; #[cfg(any(target_os = "linux", windows))] @@ -179,6 +179,9 @@ pub enum Error { #[error(display = "Access method error")] AccessMethodError(#[error(source)] access_method::Error), + #[error(display = "API connection mode error")] + ApiConnectionModeError(#[error(source)] api::Error), + #[cfg(target_os = "macos")] #[error(display = "Failed to set exclusion group")] GroupIdError(#[error(source)] io::Error), @@ -293,8 +296,8 @@ pub enum DaemonCommand { UpdateApiAccessMethod(ResponseTx<(), Error>, AccessMethodSetting), /// Get the currently used API access method GetCurrentAccessMethod(ResponseTx<AccessMethodSetting, Error>), - /// Get the addresses of all known API endpoints - GetApiAddresses(ResponseTx<Vec<std::net::SocketAddr>, Error>), + /// Test an API access method + TestApiAccessMethod(ResponseTx<bool, Error>, mullvad_types::access_method::Id), /// Get information about the currently running and latest app versions GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>), /// Return whether the daemon is performing post-upgrade tasks @@ -602,7 +605,7 @@ pub struct Daemon<L: EventListener> { account_history: account_history::AccountHistory, device_checker: device::TunnelStateChangeHandler, account_manager: device::AccountManagerHandle, - connection_modes: Arc<Mutex<api::ConnectionModesIterator>>, + connection_modes_handler: api::AccessModeSelectorHandle, api_runtime: mullvad_api::Runtime, api_handle: mullvad_api::rest::MullvadRestHandle, version_updater_handle: version_check::VersionUpdaterHandle, @@ -680,38 +683,19 @@ where .set_config(new_selector_config(settings)); }); - let proxy_provider = match api::ApiConnectionModeProvider::new( + let connection_modes = settings.api_access_methods.collect_enabled(); + + let connection_modes_handler = api::AccessModeSelector::spawn( cache_dir.clone(), relay_selector.clone(), - settings - .api_access_methods - .access_method_settings - .iter() - // We only care about the access methods which are set to 'enabled' by the user. - .filter(|api_access_method| api_access_method.enabled()) - .cloned() - .collect(), - ) { - Ok(provider) => provider, - Err(api::Error::NoAccessMethods) => { - // No settings seem to have been found. Default to using the the - // direct access method. - let default = mullvad_types::access_method::Settings::direct(); - api::ApiConnectionModeProvider::new( - cache_dir.clone(), - relay_selector.clone(), - vec![default], - ) - .expect( - "Failed to create the data structure responsible for managing access methods", - ) - } - }; - - let connection_modes = proxy_provider.handle(); + connection_modes, + ); let api_handle = api_runtime - .mullvad_rest_handle(proxy_provider, endpoint_updater.callback()) + .mullvad_rest_handle( + Box::pin(connection_modes_handler.clone().into_stream()), + endpoint_updater.callback(), + ) .await; let migration_complete = if let Some(migration_data) = migration_data { @@ -861,7 +845,7 @@ where account_history, device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()), account_manager, - connection_modes, + connection_modes_handler, api_runtime, api_handle, version_updater_handle, @@ -1151,7 +1135,7 @@ where UpdateApiAccessMethod(tx, method) => self.on_update_api_access_method(tx, method).await, GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx), SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await, - GetApiAddresses(tx) => self.on_get_api_addresses(tx).await, + TestApiAccessMethod(tx, method) => self.on_test_api_access_method(tx, method), IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx), GetCurrentVersion(tx) => self.on_get_current_version(tx), #[cfg(not(target_os = "android"))] @@ -2375,17 +2359,45 @@ where } fn on_get_current_api_access_method(&mut self, tx: ResponseTx<AccessMethodSetting, Error>) { - let result = self - .get_current_access_method() - .map_err(Error::AccessMethodError); - Self::oneshot_send(tx, result, "get_current_api_access_method response"); + let handle = self.connection_modes_handler.clone(); + tokio::spawn(async move { + let result = handle + .get_access_method() + .await + .map_err(Error::ApiConnectionModeError); + Self::oneshot_send(tx, result, "get_current_api_access_method response"); + }); } - async fn on_get_api_addresses(&mut self, tx: ResponseTx<Vec<std::net::SocketAddr>, Error>) { - let api_proxy = mullvad_api::ApiProxy::new(self.api_handle.clone()); - let result = api_proxy.get_api_addrs().await.map_err(Error::RestError); + fn on_test_api_access_method( + &mut self, + tx: ResponseTx<bool, Error>, + access_method: mullvad_types::access_method::Id, + ) { + // NOTE: Preferably we would block all new API calls until the test is + // done and the previous access method is reset. Otherwise we run the + // risk of errounously triggering a rotation of the currently in-use + // access method. + let api_handle = self.api_handle.clone(); + let handle = self.connection_modes_handler.clone(); + let access_method_lookup = self + .get_api_access_method(access_method) + .map_err(Error::AccessMethodError); - Self::oneshot_send(tx, result, "on_get_api_adressess response"); + match access_method_lookup { + Ok(access_method) => { + tokio::spawn(async move { + let result = + access_method::test_access_method(access_method, handle, api_handle) + .await + .map_err(Error::AccessMethodError); + Self::oneshot_send(tx, result, "on_test_api_access_method response"); + }); + } + Err(err) => { + Self::oneshot_send(tx, Err(err), "on_test_api_access_method response"); + } + } } fn on_get_settings(&self, tx: oneshot::Sender<Settings>) { diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index f042a923e5..c194825a34 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -693,13 +693,13 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_daemon_error) } - async fn get_api_addresses(&self, _: Request<()>) -> ServiceResult<types::ApiAddresses> { - log::debug!("get_api_addresses"); + async fn test_api_access_method(&self, request: Request<types::Uuid>) -> ServiceResult<bool> { + log::debug!("test_api_access_method"); let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::GetApiAddresses(tx))?; + let api_access_method = mullvad_types::access_method::Id::try_from(request.into_inner())?; + self.send_command_to_daemon(DaemonCommand::TestApiAccessMethod(tx, api_access_method))?; self.wait_for_result(rx) .await? - .map(types::ApiAddresses::from) .map(Response::new) .map_err(map_daemon_error) } diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index a27698f317..d66707b79f 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -22,7 +22,6 @@ service ManagementService { rpc GetCurrentVersion(google.protobuf.Empty) returns (google.protobuf.StringValue) {} rpc GetVersionInfo(google.protobuf.Empty) returns (AppVersionInfo) {} - rpc GetApiAddresses(google.protobuf.Empty) returns (ApiAddresses) {} rpc IsPerformingPostUpgrade(google.protobuf.Empty) returns (google.protobuf.BoolValue) {} @@ -82,6 +81,7 @@ service ManagementService { rpc SetApiAccessMethod(UUID) returns (google.protobuf.Empty) {} rpc UpdateApiAccessMethod(AccessMethodSetting) returns (google.protobuf.Empty) {} rpc GetCurrentApiAccessMethod(google.protobuf.Empty) returns (AccessMethodSetting) {} + rpc TestApiAccessMethod(UUID) returns (google.protobuf.BoolValue) {} // Split tunneling (Linux) rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {} @@ -110,8 +110,6 @@ message AccountData { google.protobuf.Timestamp expiry = 1; } message AccountHistory { google.protobuf.StringValue token = 1; } -message ApiAddresses { repeated google.protobuf.StringValue api_addresses = 1; } - message VoucherSubmission { uint64 seconds_added = 1; google.protobuf.Timestamp new_expiry = 2; diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 64ee088d18..1c9d80b2e8 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -208,15 +208,13 @@ impl MullvadProxyClient { }) } - pub async fn get_api_addresses(&mut self) -> Result<Vec<std::net::SocketAddr>> { - self.0 - .get_api_addresses(()) + pub async fn test_api_access_method(&mut self, id: access_method::Id) -> Result<bool> { + let result = self + .0 + .test_api_access_method(types::Uuid::from(id)) .await - .map_err(Error::Rpc) - .map(tonic::Response::into_inner) - .and_then(|api_addresses| { - Vec::<std::net::SocketAddr>::try_from(api_addresses).map_err(Error::InvalidResponse) - }) + .map_err(Error::Rpc)?; + Ok(result.into_inner()) } pub async fn update_relay_locations(&mut self) -> Result<()> { diff --git a/mullvad-management-interface/src/types/conversions/net.rs b/mullvad-management-interface/src/types/conversions/net.rs index 7b24a8f2b4..e2df5553a0 100644 --- a/mullvad-management-interface/src/types/conversions/net.rs +++ b/mullvad-management-interface/src/types/conversions/net.rs @@ -163,27 +163,6 @@ impl From<proto::IpVersion> for talpid_types::net::IpVersion { } } -impl TryFrom<proto::ApiAddresses> for Vec<SocketAddr> { - type Error = FromProtobufTypeError; - - fn try_from(value: proto::ApiAddresses) -> Result<Self, Self::Error> { - value - .api_addresses - .iter() - .map(|api_address| api_address.parse::<SocketAddr>()) - .collect::<Result<_, _>>() - .map_err(|_| FromProtobufTypeError::InvalidArgument("Invalid socket address")) - } -} - -impl From<Vec<SocketAddr>> for proto::ApiAddresses { - fn from(value: Vec<SocketAddr>) -> Self { - Self { - api_addresses: value.iter().map(SocketAddr::to_string).collect(), - } - } -} - pub fn try_tunnel_type_from_i32( tunnel_type: i32, ) -> Result<talpid_types::net::TunnelType, FromProtobufTypeError> { diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs index fe4b2507ed..7afaf94dfc 100644 --- a/mullvad-types/src/access_method.rs +++ b/mullvad-types/src/access_method.rs @@ -61,6 +61,14 @@ impl Settings { let method = BuiltInAccessMethod::Bridge; AccessMethodSetting::new(method.canonical_name(), true, AccessMethod::from(method)) } + + /// Retrieve all [`AccessMethodSetting`]s which are enabled. + pub fn collect_enabled(&self) -> Vec<AccessMethodSetting> { + self.cloned() + .into_iter() + .filter(|access_method| access_method.enabled) + .collect() + } } impl Default for Settings { |
