summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-api/src/lib.rs12
-rw-r--r--mullvad-api/src/rest.rs4
-rw-r--r--mullvad-cli/src/cmds/api_access.rs43
-rw-r--r--mullvad-daemon/src/access_method.rs172
-rw-r--r--mullvad-daemon/src/api.rs292
-rw-r--r--mullvad-daemon/src/lib.rs96
-rw-r--r--mullvad-daemon/src/management_interface.rs8
-rw-r--r--mullvad-management-interface/proto/management_interface.proto4
-rw-r--r--mullvad-management-interface/src/client.rs14
-rw-r--r--mullvad-management-interface/src/types/conversions/net.rs21
-rw-r--r--mullvad-types/src/access_method.rs8
11 files changed, 447 insertions, 227 deletions
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index c0024b22ee..c8765ec2b2 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -620,4 +620,16 @@ impl ApiProxy {
let response = self.handle.service.request(request).await?;
response.deserialize().await
}
+
+ /// Check the availablility of `{APP_URL_PREFIX}/api-addrs`.
+ pub async fn api_addrs_available(&self) -> Result<bool, rest::Error> {
+ let request = self
+ .handle
+ .factory
+ .head(&format!("{APP_URL_PREFIX}/api-addrs"))?
+ .expected_status(&[StatusCode::OK]);
+
+ let response = self.handle.service.request(request).await?;
+ Ok(response.status().is_success())
+ }
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 2484bec64b..559ddd4b4e 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -524,6 +524,10 @@ impl RequestFactory {
self.request(path, Method::DELETE)
}
+ pub fn head(&self, path: &str) -> Result<Request> {
+ self.request(path, Method::HEAD)
+ }
+
pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> {
self.json_request(Method::POST, path, body)
}
diff --git a/mullvad-cli/src/cmds/api_access.rs b/mullvad-cli/src/cmds/api_access.rs
index 9ad481c4ef..c6e01c52d6 100644
--- a/mullvad-cli/src/cmds/api_access.rs
+++ b/mullvad-cli/src/cmds/api_access.rs
@@ -182,25 +182,16 @@ impl ApiAccess {
/// Test an access method to see if it successfully reaches the Mullvad API.
async fn test(item: SelectItem) -> Result<()> {
let mut rpc = MullvadProxyClient::new().await?;
- // Retrieve the currently used access method. We will reset to this
- // after we are done testing.
- let previous_access_method = rpc.get_current_api_access_method().await?;
let access_method = Self::get_access_method(&mut rpc, &item).await?;
println!("Testing access method \"{}\"", access_method.name);
- rpc.set_access_method(access_method.get_id()).await?;
- // Make the daemon perform an network request which involves talking to the Mullvad API.
- let result = match rpc.get_api_addresses().await {
- Ok(_) => {
+ match rpc.test_api_access_method(access_method.get_id()).await {
+ Ok(true) => {
println!("Success!");
Ok(())
}
- Err(_) => Err(anyhow!("Could not reach the Mullvad API")),
- };
- // In any case, switch back to the previous access method.
- rpc.set_access_method(previous_access_method.get_id())
- .await?;
- result
+ Ok(false) | Err(_) => Err(anyhow!("Could not reach the Mullvad API.")),
+ }
}
/// Try to use of a specific [`AccessMethodSetting`] for subsequent calls to
@@ -217,30 +208,24 @@ impl ApiAccess {
/// configured ones.
async fn set(item: SelectItem) -> Result<()> {
let mut rpc = MullvadProxyClient::new().await?;
- let previous_access_method = rpc.get_current_api_access_method().await?;
let mut new_access_method = Self::get_access_method(&mut rpc, &item).await?;
+ let current_access_method = rpc.get_current_api_access_method().await?;
// Try to reach the API with the newly selected access method.
+ rpc.test_api_access_method(new_access_method.get_id())
+ .await
+ .map_err(|_| {
+ anyhow!("Could not reach the Mullvad API using access method \"{}\". Rolling back to \"{}\"", new_access_method.get_name(), current_access_method.get_name())
+ })?
+
+ ;
+ // If the test succeeded, the new access method should be used from now on.
rpc.set_access_method(new_access_method.get_id()).await?;
- match rpc.get_api_addresses().await {
- Ok(_) => (),
- Err(_) => {
- // Roll-back to the previous access method
- rpc.set_access_method(previous_access_method.get_id())
- .await?;
- return Err(anyhow!(
- "Could not reach the Mullvad API using access method \"{}\"",
- new_access_method.get_name(),
- ));
- }
- };
- // It worked! Let the daemon keep using this access method.
- let display_name = new_access_method.get_name();
+ println!("Using access method \"{}\"", new_access_method.get_name());
// Toggle the enabled status if needed
if !new_access_method.enabled() {
new_access_method.enable();
rpc.update_access_method(new_access_method).await?;
}
- println!("Using access method \"{}\"", display_name);
Ok(())
}
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index 4584aa374a..7d9d3dba95 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -1,7 +1,9 @@
use crate::{
+ api::{self, AccessModeSelectorHandle},
settings::{self, MadeChanges},
Daemon, EventListener,
};
+use mullvad_api::rest::{self, MullvadRestHandle};
use mullvad_types::{
access_method::{self, AccessMethod, AccessMethodSetting},
settings::Settings,
@@ -18,13 +20,15 @@ pub enum Error {
/// Can not find access method
#[error(display = "Cannot find custom access method {}", _0)]
NoSuchMethod(access_method::Id),
- /// Can not find *any* access method. This should never happen. If it does,
- /// the user should do a factory reset.
- #[error(display = "No access methods are configured")]
- NoMethodsExist,
/// Access method could not be rotate
#[error(display = "Access method could not be rotated")]
RotationError,
+ /// Some error occured in the daemon's state of handling
+ /// [`AccessMethodSetting`]s & [`ApiConnectionMode`]s.
+ #[error(display = "Error occured when handling connection settings & details")]
+ ConnectionMode(#[error(source)] api::Error),
+ #[error(display = "API endpoint rotation failed")]
+ RestError(#[error(source)] rest::Error),
/// Access methods settings error
#[error(display = "Settings error")]
Settings(#[error(source)] settings::Error),
@@ -81,7 +85,9 @@ where
Some(api_access_method) => {
if api_access_method.is_builtin() {
Err(Error::RemoveBuiltIn)
- } else if api_access_method.get_id() == self.get_current_access_method()?.get_id() {
+ } else if api_access_method.get_id()
+ == self.get_current_access_method().await?.get_id()
+ {
Ok(Command::Rotate)
} else {
Ok(Command::Nothing)
@@ -108,15 +114,10 @@ where
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
- let access_method = self
- .settings
- .api_access_methods
- .find(&access_method)
- .ok_or(Error::NoSuchMethod(access_method))?;
- {
- let mut connection_modes = self.connection_modes.lock().unwrap();
- connection_modes.set_access_method(access_method.clone());
- }
+ let access_method = self.get_api_access_method(access_method)?;
+ self.connection_modes_handler
+ .set_access_method(access_method)
+ .await?;
// Force a rotation of Access Methods.
//
// This is not a call to `process_command` due to the restrictions on
@@ -124,6 +125,17 @@ where
self.force_api_endpoint_rotation().await
}
+ pub fn get_api_access_method(
+ &mut self,
+ access_method: access_method::Id,
+ ) -> Result<AccessMethodSetting, Error> {
+ self.settings
+ .api_access_methods
+ .find(&access_method)
+ .ok_or(Error::NoSuchMethod(access_method))
+ .cloned()
+ }
+
/// "Updates" an [`AccessMethodSetting`] by replacing the existing entry
/// with the argument `access_method_update` if an existing entry with
/// matching [`access_method::Id`] is found.
@@ -140,7 +152,7 @@ where
// in the daemon's settings. Therefore, we have to safeguard against
// this by explicitly checking for & disallow any update which would
// cause the last enabled access method to become disabled.
- let current = self.get_current_access_method()?;
+ let current = self.get_current_access_method().await?;
let mut command = Command::Nothing;
let settings_update = |settings: &mut Settings| {
if let Some(access_method) = settings
@@ -165,9 +177,8 @@ where
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
- pub fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
- let connections_modes = self.connection_modes.lock().unwrap();
- Ok(connections_modes.peek())
+ pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
+ Ok(self.connection_modes_handler.get_access_method().await?)
}
/// Change which [`AccessMethodSetting`] which will be used to figure out
@@ -189,29 +200,21 @@ where
self.event_listener
.notify_settings(self.settings.to_settings());
- let access_methods: Vec<_> = self
- .settings
- .api_access_methods
- .access_method_settings
- .iter()
- .filter(|api_access_method| api_access_method.enabled())
- .cloned()
- .collect();
-
- let mut connection_modes = self.connection_modes.lock().unwrap();
- match connection_modes.update_access_methods(access_methods) {
- Ok(_) => (),
- Err(crate::api::Error::NoAccessMethods) => {
- // `access_methods` was empty! This implies that the user
- // disabled all access methods. If we ever get into this
- // state, we should default to using the direct access
- // method.
- let default = access_method::Settings::direct();
- connection_modes
- .update_access_methods(vec![default])
- .expect("Failed to create the data structure responsible for managing access methods");
+ let handle = self.connection_modes_handler.clone();
+ let new_access_methods = self.settings.api_access_methods.collect_enabled();
+ tokio::spawn(async move {
+ match handle.update_access_methods(new_access_methods).await {
+ Ok(_) => (),
+ Err(api::Error::NoAccessMethods) | Err(_) => {
+ // `access_methods` was empty! This implies that the user
+ // disabled all access methods. If we ever get into this
+ // state, we should default to using the direct access
+ // method.
+ let default = access_method::Settings::direct();
+ handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods");
+ }
}
- }
+ });
};
self
}
@@ -225,3 +228,92 @@ where
}
}
}
+
+/// Try to reach the Mullvad API using a specific access method, returning
+/// an [`Error`] in the case where the test fails to reach the API.
+///
+/// Ephemerally sets a new access method (associated with `access_method`)
+/// to be used for subsequent API calls, before performing an API call and
+/// switching back to the previously active access method. The previous
+/// access method is *always* reset.
+pub async fn test_access_method(
+ new_access_method: AccessMethodSetting,
+ access_mode_selector: AccessModeSelectorHandle,
+ rest_handle: MullvadRestHandle,
+) -> Result<bool, Error> {
+ // Setup test
+ let previous_access_method = access_mode_selector
+ .get_access_method()
+ .await
+ .map_err(Error::ConnectionMode)?;
+
+ let method_under_test = new_access_method.clone();
+ access_mode_selector
+ .set_access_method(new_access_method)
+ .await
+ .map_err(Error::ConnectionMode)?;
+
+ // We need to perform a rotation of API endpoint after a set action
+ let rotation_handle = rest_handle.clone();
+ rotation_handle
+ .service()
+ .next_api_endpoint()
+ .await
+ .map_err(|err| {
+ log::error!("Failed to rotate API endpoint: {err}");
+ Error::RestError(err)
+ })?;
+
+ // Set up the reset
+ //
+ // In case the API call fails, the next API endpoint will
+ // automatically be selected, which means that we need to set up
+ // with the previous API endpoint beforehand.
+ access_mode_selector
+ .set_access_method(previous_access_method)
+ .await
+ .map_err(|err| {
+ log::error!(
+ "Could not reset to previous access
+ method after API reachability test was carried out. This should only
+ happen if the previous access method was removed in the meantime."
+ );
+ Error::ConnectionMode(err)
+ })?;
+
+ // Perform test
+ //
+ // Send a HEAD request to some Mullvad API endpoint. We issue a HEAD
+ // request because we are *only* concerned with if we get a reply from
+ // the API, and not with the actual data that the endpoint returns.
+ let result = mullvad_api::ApiProxy::new(rest_handle)
+ .api_addrs_available()
+ .await
+ .map_err(Error::RestError)?;
+
+ // We need to perform a rotation of API endpoint after a set action
+ // Note that this will be done automatically if the API call fails,
+ // so it only has to be done if the call succeeded ..
+ if result {
+ rotation_handle
+ .service()
+ .next_api_endpoint()
+ .await
+ .map_err(|err| {
+ log::error!("Failed to rotate API endpoint: {err}");
+ Error::RestError(err)
+ })?;
+ }
+
+ log::info!(
+ "The result of testing {method:?} is {result}",
+ method = method_under_test.access_method,
+ result = if result {
+ "success".to_string()
+ } else {
+ "failed".to_string()
+ }
+ );
+
+ Ok(result)
+}
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index d5099ae74a..2da307ff5f 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -1,8 +1,14 @@
+//! This module is responsible for enabling custom [`AccessMethodSetting`]s to
+//! be used when connecting to the Mullvad API. In practice this means
+//! converting [`AccessMethodSetting`]s to connection details as encoded by
+//! [`ApiConnectionMode`], which in turn is used by `mullvad-api` for
+//! establishing connections when performing API requests.
#[cfg(target_os = "android")]
use crate::{DaemonCommand, DaemonEventSender};
use futures::{
channel::{mpsc, oneshot},
- Future, Stream, StreamExt,
+ stream::unfold,
+ Stream, StreamExt,
};
use mullvad_api::{
availability::ApiAvailabilityHandle,
@@ -13,109 +19,239 @@ use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{self, AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
use std::{
path::PathBuf,
- pin::Pin,
sync::{Arc, Mutex, Weak},
- task::Poll,
};
#[cfg(target_os = "android")]
use talpid_core::mpsc::Sender;
use talpid_core::tunnel_state_machine::TunnelCommand;
-use talpid_types::{
- net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint},
- ErrorExt,
-};
+use talpid_types::net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint};
+
+pub enum Message {
+ Get(ResponseTx<AccessMethodSetting>),
+ Set(ResponseTx<()>, AccessMethodSetting),
+ Next(ResponseTx<ApiConnectionMode>),
+ Update(ResponseTx<()>, Vec<AccessMethodSetting>),
+}
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "No access methods were provided.")]
+ NoAccessMethods,
+ #[error(display = "AccessModeSelector is not receiving any messages.")]
+ SendFailed(#[error(source)] mpsc::TrySendError<Message>),
+ #[error(display = "AccessModeSelector is not receiving any messages.")]
+ OneshotSendFailed,
+ #[error(display = "AccessModeSelector is not responding.")]
+ NotRunning(#[error(source)] oneshot::Canceled),
+}
+
+type ResponseTx<T> = oneshot::Sender<Result<T>>;
+type Result<T> = std::result::Result<T, Error>;
+
+/// A channel for sending [`Message`] commands to a running
+/// [`AccessModeSelector`].
+#[derive(Clone)]
+pub struct AccessModeSelectorHandle {
+ cmd_tx: mpsc::UnboundedSender<Message>,
+}
+
+impl AccessModeSelectorHandle {
+ async fn send_command<T>(&self, make_cmd: impl FnOnce(ResponseTx<T>) -> Message) -> Result<T> {
+ let (tx, rx) = oneshot::channel();
+ self.cmd_tx
+ .unbounded_send(make_cmd(tx))
+ .map_err(Error::SendFailed)?;
+ rx.await.map_err(Error::NotRunning)?
+ }
+
+ pub async fn get_access_method(&self) -> Result<AccessMethodSetting> {
+ self.send_command(Message::Get).await.map_err(|err| {
+ log::error!("Failed to get current access method!");
+ err
+ })
+ }
+
+ pub async fn set_access_method(&self, value: AccessMethodSetting) -> Result<()> {
+ self.send_command(|tx| Message::Set(tx, value))
+ .await
+ .map_err(|err| {
+ log::error!("Failed to set new access method!");
+ err
+ })
+ }
-/// A stream that returns the next API connection mode to use for reaching the API.
+ pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> {
+ self.send_command(|tx| Message::Update(tx, values))
+ .await
+ .map_err(|err| {
+ log::error!("Failed to update new access methods!");
+ err
+ })
+ }
+
+ pub async fn next(&self) -> Result<ApiConnectionMode> {
+ self.send_command(Message::Next).await.map_err(|err| {
+ log::error!("Failed to update new access methods!");
+ err
+ })
+ }
+
+ /// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the
+ /// associated [`AccessModeSelector`].
+ ///
+ /// Practically converts the handle to a listener for when the
+ /// currently valid connection modes changes.
+ pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> {
+ unfold(self, |handle| async move {
+ match handle.next().await {
+ Ok(connection_mode) => Some((connection_mode, handle)),
+ // End this stream in case of failure in `next`. `next` should
+ // not fail if the actor is in a good state.
+ Err(_) => None,
+ }
+ })
+ }
+}
+
+/// A small actor which takes care of handling the logic around rotating
+/// connection modes to be used for Mullvad API request.
///
-/// When `mullvad-api` fails to contact the API, it requests a new connection
-/// mode. The API can be connected to either directly (i.e.,
+/// When `mullvad-api` fails to contact the API, it will request a new
+/// connection mode. The API can be connected to either directly (i.e.,
/// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`])
-/// or via any supported custom proxy protocol ([`api_access_methods::ObfuscationProtocol`]).
+/// or via any supported custom proxy protocol
+/// ([`api_access_methods::ObfuscationProtocol`]).
///
/// The strategy for determining the next [`ApiConnectionMode`] is handled by
/// [`ConnectionModesIterator`].
-pub struct ApiConnectionModeProvider {
+pub struct AccessModeSelector {
+ cmd_rx: mpsc::UnboundedReceiver<Message>,
cache_dir: PathBuf,
/// Used for selecting a Bridge when the `Mullvad Bridges` access method is used.
relay_selector: RelaySelector,
- current_task: Option<Pin<Box<dyn Future<Output = ApiConnectionMode> + Send>>>,
- connection_modes: Arc<Mutex<ConnectionModesIterator>>,
+ connection_modes: ConnectionModesIterator,
}
-impl Stream for ApiConnectionModeProvider {
- type Item = ApiConnectionMode;
+impl AccessModeSelector {
+ pub fn spawn(
+ cache_dir: PathBuf,
+ relay_selector: RelaySelector,
+ connection_modes: Vec<AccessMethodSetting>,
+ ) -> AccessModeSelectorHandle {
+ let (cmd_tx, cmd_rx) = mpsc::unbounded();
- fn poll_next(
- mut self: Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- // Poll the current task
- if let Some(task) = self.current_task.as_mut() {
- return match task.as_mut().poll(cx) {
- Poll::Ready(mode) => {
- self.current_task = None;
- Poll::Ready(Some(mode))
- }
- Poll::Pending => Poll::Pending,
+ let connection_modes = match ConnectionModesIterator::new(connection_modes) {
+ Ok(provider) => provider,
+ Err(Error::NoAccessMethods) | Err(_) => {
+ // No settings seem to have been found. Default to using the the
+ // direct access method.
+ let default = mullvad_types::access_method::Settings::direct();
+ ConnectionModesIterator::new(vec![default]).expect(
+ "Failed to create the data structure responsible for managing access methods",
+ )
+ }
+ };
+ let selector = AccessModeSelector {
+ cmd_rx,
+ cache_dir,
+ relay_selector,
+ connection_modes,
+ };
+ tokio::spawn(selector.into_future());
+ AccessModeSelectorHandle { cmd_tx }
+ }
+
+ async fn into_future(mut self) {
+ while let Some(cmd) = self.cmd_rx.next().await {
+ let execution = match cmd {
+ Message::Get(tx) => self.on_get_access_method(tx),
+ Message::Set(tx, value) => self.on_set_access_method(tx, value),
+ Message::Next(tx) => self.on_next_connection_mode(tx),
+ Message::Update(tx, values) => self.on_update_access_methods(tx, values),
};
+ match execution {
+ Ok(_) => (),
+ Err(err) => {
+ log::trace!(
+ "AccessModeSelector is going down due to {error}",
+ error = err
+ );
+ break;
+ }
+ }
}
+ }
- let connection_mode = self.new_connection_mode();
+ fn reply<T>(&self, tx: ResponseTx<T>, value: T) -> Result<()> {
+ tx.send(Ok(value)).map_err(|_| Error::OneshotSendFailed)?;
+ Ok(())
+ }
- let cache_dir = self.cache_dir.clone();
- self.current_task = Some(Box::pin(async move {
- if let Err(error) = connection_mode.save(&cache_dir).await {
- log::debug!(
- "{}",
- error.display_chain_with_msg("Failed to save API endpoint")
- );
- }
- connection_mode
- }));
+ fn on_get_access_method(&mut self, tx: ResponseTx<AccessMethodSetting>) -> Result<()> {
+ let value = self.get_access_method();
+ self.reply(tx, value)
+ }
- self.poll_next(cx)
+ fn get_access_method(&mut self) -> AccessMethodSetting {
+ self.connection_modes.peek()
}
-}
-impl ApiConnectionModeProvider {
- pub(crate) fn new(
- cache_dir: PathBuf,
- relay_selector: RelaySelector,
- connection_modes: Vec<AccessMethodSetting>,
- ) -> Result<Self, Error> {
- let connection_modes_iterator = ConnectionModesIterator::new(connection_modes)?;
- Ok(Self {
- cache_dir,
- relay_selector,
- current_task: None,
- connection_modes: Arc::new(Mutex::new(connection_modes_iterator)),
- })
+ fn on_set_access_method(
+ &mut self,
+ tx: ResponseTx<()>,
+ value: AccessMethodSetting,
+ ) -> Result<()> {
+ self.set_access_method(value);
+ self.reply(tx, ())
}
- /// Return a pointer to the underlying iterator over [`AccessMethod`].
- /// Having access to this iterator allow you to influence , e.g. by calling
- /// [`ConnectionModesIterator::set_access_method()`] or
- /// [`ConnectionModesIterator::update_access_methods()`].
- pub(crate) fn handle(&self) -> Arc<Mutex<ConnectionModesIterator>> {
- self.connection_modes.clone()
+ fn set_access_method(&mut self, value: AccessMethodSetting) {
+ self.connection_modes.set_access_method(value);
}
- /// Return a new connection mode to be used for the API connection.
- fn new_connection_mode(&mut self) -> ApiConnectionMode {
- log::debug!("Rotating Access mode!");
- let access_method = {
- let mut access_methods_picker = self.connection_modes.lock().unwrap();
- access_methods_picker
- .next()
- .map(|access_method_setting| access_method_setting.access_method)
- .unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct))
- };
+ fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
+ let next = self.next_connection_mode();
+ // Save the new connection mode to cache!
+ {
+ let cache_dir = self.cache_dir.clone();
+ let next = next.clone();
+ tokio::spawn(async move {
+ if next.save(&cache_dir).await.is_err() {
+ log::warn!(
+ "Failed to save {connection_mode} to cache",
+ connection_mode = next
+ )
+ }
+ });
+ }
+ self.reply(tx, next)
+ }
+
+ fn next_connection_mode(&mut self) -> ApiConnectionMode {
+ let access_method = self
+ .connection_modes
+ .next()
+ .map(|access_method_setting| access_method_setting.access_method)
+ .unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct));
let connection_mode = self.from(access_method);
- log::info!("New API connection mode selected: {}", connection_mode);
+ log::info!("New API connection mode selected: {connection_mode}");
connection_mode
}
+ fn on_update_access_methods(
+ &mut self,
+ tx: ResponseTx<()>,
+ values: Vec<AccessMethodSetting>,
+ ) -> Result<()> {
+ self.update_access_methods(values)?;
+ self.reply(tx, ())
+ }
+
+ fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> {
+ self.connection_modes.update_access_methods(values)
+ }
+
/// Ad-hoc version of [`std::convert::From::from`], but since some
/// [`ApiConnectionMode`]s require extra logic/data from
/// [`ApiConnectionModeProvider`] the standard [`std::convert::From`] trait
@@ -172,14 +308,10 @@ pub struct ConnectionModesIterator {
current: AccessMethodSetting,
}
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- #[error(display = "No access methods were provided.")]
- NoAccessMethods,
-}
-
impl ConnectionModesIterator {
- pub fn new(access_methods: Vec<AccessMethodSetting>) -> Result<ConnectionModesIterator, Error> {
+ pub fn new(
+ access_methods: Vec<AccessMethodSetting>,
+ ) -> std::result::Result<ConnectionModesIterator, Error> {
let mut iterator = Self::new_iterator(access_methods)?;
Ok(Self {
next: None,
@@ -197,7 +329,7 @@ impl ConnectionModesIterator {
pub fn update_access_methods(
&mut self,
access_methods: Vec<AccessMethodSetting>,
- ) -> Result<(), Error> {
+ ) -> std::result::Result<(), Error> {
self.available_modes = Self::new_iterator(access_methods)?;
Ok(())
}
@@ -208,7 +340,7 @@ impl ConnectionModesIterator {
/// returned.
fn new_iterator(
access_methods: Vec<AccessMethodSetting>,
- ) -> Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> {
+ ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> {
if access_methods.is_empty() {
Err(Error::NoAccessMethods)
} else {
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 65f2c56654..0288f9d8c4 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -66,7 +66,7 @@ use std::{
mem,
path::PathBuf,
pin::Pin,
- sync::{Arc, Mutex, Weak},
+ sync::{Arc, Weak},
time::Duration,
};
#[cfg(any(target_os = "linux", windows))]
@@ -179,6 +179,9 @@ pub enum Error {
#[error(display = "Access method error")]
AccessMethodError(#[error(source)] access_method::Error),
+ #[error(display = "API connection mode error")]
+ ApiConnectionModeError(#[error(source)] api::Error),
+
#[cfg(target_os = "macos")]
#[error(display = "Failed to set exclusion group")]
GroupIdError(#[error(source)] io::Error),
@@ -293,8 +296,8 @@ pub enum DaemonCommand {
UpdateApiAccessMethod(ResponseTx<(), Error>, AccessMethodSetting),
/// Get the currently used API access method
GetCurrentAccessMethod(ResponseTx<AccessMethodSetting, Error>),
- /// Get the addresses of all known API endpoints
- GetApiAddresses(ResponseTx<Vec<std::net::SocketAddr>, Error>),
+ /// Test an API access method
+ TestApiAccessMethod(ResponseTx<bool, Error>, mullvad_types::access_method::Id),
/// Get information about the currently running and latest app versions
GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>),
/// Return whether the daemon is performing post-upgrade tasks
@@ -602,7 +605,7 @@ pub struct Daemon<L: EventListener> {
account_history: account_history::AccountHistory,
device_checker: device::TunnelStateChangeHandler,
account_manager: device::AccountManagerHandle,
- connection_modes: Arc<Mutex<api::ConnectionModesIterator>>,
+ connection_modes_handler: api::AccessModeSelectorHandle,
api_runtime: mullvad_api::Runtime,
api_handle: mullvad_api::rest::MullvadRestHandle,
version_updater_handle: version_check::VersionUpdaterHandle,
@@ -680,38 +683,19 @@ where
.set_config(new_selector_config(settings));
});
- let proxy_provider = match api::ApiConnectionModeProvider::new(
+ let connection_modes = settings.api_access_methods.collect_enabled();
+
+ let connection_modes_handler = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
- settings
- .api_access_methods
- .access_method_settings
- .iter()
- // We only care about the access methods which are set to 'enabled' by the user.
- .filter(|api_access_method| api_access_method.enabled())
- .cloned()
- .collect(),
- ) {
- Ok(provider) => provider,
- Err(api::Error::NoAccessMethods) => {
- // No settings seem to have been found. Default to using the the
- // direct access method.
- let default = mullvad_types::access_method::Settings::direct();
- api::ApiConnectionModeProvider::new(
- cache_dir.clone(),
- relay_selector.clone(),
- vec![default],
- )
- .expect(
- "Failed to create the data structure responsible for managing access methods",
- )
- }
- };
-
- let connection_modes = proxy_provider.handle();
+ connection_modes,
+ );
let api_handle = api_runtime
- .mullvad_rest_handle(proxy_provider, endpoint_updater.callback())
+ .mullvad_rest_handle(
+ Box::pin(connection_modes_handler.clone().into_stream()),
+ endpoint_updater.callback(),
+ )
.await;
let migration_complete = if let Some(migration_data) = migration_data {
@@ -861,7 +845,7 @@ where
account_history,
device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()),
account_manager,
- connection_modes,
+ connection_modes_handler,
api_runtime,
api_handle,
version_updater_handle,
@@ -1151,7 +1135,7 @@ where
UpdateApiAccessMethod(tx, method) => self.on_update_api_access_method(tx, method).await,
GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx),
SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await,
- GetApiAddresses(tx) => self.on_get_api_addresses(tx).await,
+ TestApiAccessMethod(tx, method) => self.on_test_api_access_method(tx, method),
IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx),
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
@@ -2375,17 +2359,45 @@ where
}
fn on_get_current_api_access_method(&mut self, tx: ResponseTx<AccessMethodSetting, Error>) {
- let result = self
- .get_current_access_method()
- .map_err(Error::AccessMethodError);
- Self::oneshot_send(tx, result, "get_current_api_access_method response");
+ let handle = self.connection_modes_handler.clone();
+ tokio::spawn(async move {
+ let result = handle
+ .get_access_method()
+ .await
+ .map_err(Error::ApiConnectionModeError);
+ Self::oneshot_send(tx, result, "get_current_api_access_method response");
+ });
}
- async fn on_get_api_addresses(&mut self, tx: ResponseTx<Vec<std::net::SocketAddr>, Error>) {
- let api_proxy = mullvad_api::ApiProxy::new(self.api_handle.clone());
- let result = api_proxy.get_api_addrs().await.map_err(Error::RestError);
+ fn on_test_api_access_method(
+ &mut self,
+ tx: ResponseTx<bool, Error>,
+ access_method: mullvad_types::access_method::Id,
+ ) {
+ // NOTE: Preferably we would block all new API calls until the test is
+ // done and the previous access method is reset. Otherwise we run the
+ // risk of errounously triggering a rotation of the currently in-use
+ // access method.
+ let api_handle = self.api_handle.clone();
+ let handle = self.connection_modes_handler.clone();
+ let access_method_lookup = self
+ .get_api_access_method(access_method)
+ .map_err(Error::AccessMethodError);
- Self::oneshot_send(tx, result, "on_get_api_adressess response");
+ match access_method_lookup {
+ Ok(access_method) => {
+ tokio::spawn(async move {
+ let result =
+ access_method::test_access_method(access_method, handle, api_handle)
+ .await
+ .map_err(Error::AccessMethodError);
+ Self::oneshot_send(tx, result, "on_test_api_access_method response");
+ });
+ }
+ Err(err) => {
+ Self::oneshot_send(tx, Err(err), "on_test_api_access_method response");
+ }
+ }
}
fn on_get_settings(&self, tx: oneshot::Sender<Settings>) {
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index f042a923e5..c194825a34 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -693,13 +693,13 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_daemon_error)
}
- async fn get_api_addresses(&self, _: Request<()>) -> ServiceResult<types::ApiAddresses> {
- log::debug!("get_api_addresses");
+ async fn test_api_access_method(&self, request: Request<types::Uuid>) -> ServiceResult<bool> {
+ log::debug!("test_api_access_method");
let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::GetApiAddresses(tx))?;
+ let api_access_method = mullvad_types::access_method::Id::try_from(request.into_inner())?;
+ self.send_command_to_daemon(DaemonCommand::TestApiAccessMethod(tx, api_access_method))?;
self.wait_for_result(rx)
.await?
- .map(types::ApiAddresses::from)
.map(Response::new)
.map_err(map_daemon_error)
}
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index a27698f317..d66707b79f 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -22,7 +22,6 @@ service ManagementService {
rpc GetCurrentVersion(google.protobuf.Empty) returns (google.protobuf.StringValue) {}
rpc GetVersionInfo(google.protobuf.Empty) returns (AppVersionInfo) {}
- rpc GetApiAddresses(google.protobuf.Empty) returns (ApiAddresses) {}
rpc IsPerformingPostUpgrade(google.protobuf.Empty) returns (google.protobuf.BoolValue) {}
@@ -82,6 +81,7 @@ service ManagementService {
rpc SetApiAccessMethod(UUID) returns (google.protobuf.Empty) {}
rpc UpdateApiAccessMethod(AccessMethodSetting) returns (google.protobuf.Empty) {}
rpc GetCurrentApiAccessMethod(google.protobuf.Empty) returns (AccessMethodSetting) {}
+ rpc TestApiAccessMethod(UUID) returns (google.protobuf.BoolValue) {}
// Split tunneling (Linux)
rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {}
@@ -110,8 +110,6 @@ message AccountData { google.protobuf.Timestamp expiry = 1; }
message AccountHistory { google.protobuf.StringValue token = 1; }
-message ApiAddresses { repeated google.protobuf.StringValue api_addresses = 1; }
-
message VoucherSubmission {
uint64 seconds_added = 1;
google.protobuf.Timestamp new_expiry = 2;
diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs
index 64ee088d18..1c9d80b2e8 100644
--- a/mullvad-management-interface/src/client.rs
+++ b/mullvad-management-interface/src/client.rs
@@ -208,15 +208,13 @@ impl MullvadProxyClient {
})
}
- pub async fn get_api_addresses(&mut self) -> Result<Vec<std::net::SocketAddr>> {
- self.0
- .get_api_addresses(())
+ pub async fn test_api_access_method(&mut self, id: access_method::Id) -> Result<bool> {
+ let result = self
+ .0
+ .test_api_access_method(types::Uuid::from(id))
.await
- .map_err(Error::Rpc)
- .map(tonic::Response::into_inner)
- .and_then(|api_addresses| {
- Vec::<std::net::SocketAddr>::try_from(api_addresses).map_err(Error::InvalidResponse)
- })
+ .map_err(Error::Rpc)?;
+ Ok(result.into_inner())
}
pub async fn update_relay_locations(&mut self) -> Result<()> {
diff --git a/mullvad-management-interface/src/types/conversions/net.rs b/mullvad-management-interface/src/types/conversions/net.rs
index 7b24a8f2b4..e2df5553a0 100644
--- a/mullvad-management-interface/src/types/conversions/net.rs
+++ b/mullvad-management-interface/src/types/conversions/net.rs
@@ -163,27 +163,6 @@ impl From<proto::IpVersion> for talpid_types::net::IpVersion {
}
}
-impl TryFrom<proto::ApiAddresses> for Vec<SocketAddr> {
- type Error = FromProtobufTypeError;
-
- fn try_from(value: proto::ApiAddresses) -> Result<Self, Self::Error> {
- value
- .api_addresses
- .iter()
- .map(|api_address| api_address.parse::<SocketAddr>())
- .collect::<Result<_, _>>()
- .map_err(|_| FromProtobufTypeError::InvalidArgument("Invalid socket address"))
- }
-}
-
-impl From<Vec<SocketAddr>> for proto::ApiAddresses {
- fn from(value: Vec<SocketAddr>) -> Self {
- Self {
- api_addresses: value.iter().map(SocketAddr::to_string).collect(),
- }
- }
-}
-
pub fn try_tunnel_type_from_i32(
tunnel_type: i32,
) -> Result<talpid_types::net::TunnelType, FromProtobufTypeError> {
diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs
index fe4b2507ed..7afaf94dfc 100644
--- a/mullvad-types/src/access_method.rs
+++ b/mullvad-types/src/access_method.rs
@@ -61,6 +61,14 @@ impl Settings {
let method = BuiltInAccessMethod::Bridge;
AccessMethodSetting::new(method.canonical_name(), true, AccessMethod::from(method))
}
+
+ /// Retrieve all [`AccessMethodSetting`]s which are enabled.
+ pub fn collect_enabled(&self) -> Vec<AccessMethodSetting> {
+ self.cloned()
+ .into_iter()
+ .filter(|access_method| access_method.enabled)
+ .collect()
+ }
}
impl Default for Settings {