summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-01-08 09:56:05 +0100
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-01-08 09:56:05 +0100
commit9d98bd30203a11728f9d90f444136af94bf32ea9 (patch)
treee23e3d00918d1931e9e43aee70f3143193fda4c2
parent1bcf0ef542af225747fcaa3a6fcebd0819d82f0c (diff)
parent32ba86f5359f0ea0f6cc930ad4db6a1c1f6f693a (diff)
downloadmullvadvpn-9d98bd30203a11728f9d90f444136af94bf32ea9.tar.xz
mullvadvpn-9d98bd30203a11728f9d90f444136af94bf32ea9.zip
Merge branch 'broadcast-api-access-method-changes'
-rw-r--r--mullvad-api/src/bin/relay_list.rs2
-rw-r--r--mullvad-api/src/lib.rs20
-rw-r--r--mullvad-api/src/proxy.rs38
-rw-r--r--mullvad-api/src/rest.rs49
-rw-r--r--mullvad-cli/src/cmds/status.rs5
-rw-r--r--mullvad-daemon/src/access_method.rs175
-rw-r--r--mullvad-daemon/src/api.rs424
-rw-r--r--mullvad-daemon/src/lib.rs167
-rw-r--r--mullvad-daemon/src/management_interface.rs12
-rw-r--r--mullvad-jni/src/jni_event_listener.rs6
-rw-r--r--mullvad-management-interface/proto/management_interface.proto1
-rw-r--r--mullvad-management-interface/src/client.rs6
-rw-r--r--mullvad-management-interface/src/types/conversions/access_method.rs10
-rw-r--r--mullvad-problem-report/src/lib.rs1
-rw-r--r--mullvad-setup/src/main.rs1
-rw-r--r--test/test-manager/src/tests/account.rs5
16 files changed, 485 insertions, 437 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index ffb65c28b2..c016b4c8a1 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -13,7 +13,7 @@ async fn main() {
let relay_list_request = RelayListProxy::new(
runtime
- .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat(), |_| async { true })
+ .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat())
.await,
)
.relay_list(None)
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index ae7929deec..237ed100d4 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -18,7 +18,7 @@ use std::{
path::Path,
sync::OnceLock,
};
-use talpid_types::{net::AllowedEndpoint, ErrorExt};
+use talpid_types::ErrorExt;
pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
@@ -216,19 +216,6 @@ pub enum Error {
ApiCheckError(#[error(source)] availability::Error),
}
-/// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`.
-/// It should return a future that determines whether to reject the new endpoint or not.
-pub trait ApiEndpointUpdateCallback: Fn(AllowedEndpoint) -> Self::AcceptedNewEndpoint {
- type AcceptedNewEndpoint: Future<Output = bool> + Send;
-}
-
-impl<U, T: Future<Output = bool> + Send> ApiEndpointUpdateCallback for U
-where
- U: Fn(AllowedEndpoint) -> T,
-{
- type AcceptedNewEndpoint = T;
-}
-
impl Runtime {
/// Create a new `Runtime`.
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
@@ -305,7 +292,6 @@ impl Runtime {
&self,
sni_hostname: Option<String>,
proxy_provider: T,
- new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
@@ -313,7 +299,6 @@ impl Runtime {
self.api_availability.handle(),
self.address_cache.clone(),
proxy_provider,
- new_address_callback,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
@@ -326,13 +311,11 @@ impl Runtime {
>(
&self,
proxy_provider: T,
- new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
) -> rest::MullvadRestHandle {
let service = self
.new_request_service(
Some(API.host.clone()),
proxy_provider,
- new_address_callback,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
)
@@ -353,7 +336,6 @@ impl Runtime {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_repeat(),
- |_| async { true },
#[cfg(target_os = "android")]
None,
)
diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs
index 3c7d071d92..2b4821ba64 100644
--- a/mullvad-api/src/proxy.rs
+++ b/mullvad-api/src/proxy.rs
@@ -8,7 +8,7 @@ use std::{
task::{self, Poll},
};
use talpid_types::{
- net::{proxy, AllowedClients, Endpoint, TransportProtocol},
+ net::{proxy, Endpoint, TransportProtocol},
ErrorExt,
};
use tokio::{
@@ -70,6 +70,16 @@ impl fmt::Display for ProxyConfig {
}
}
+impl From<proxy::CustomProxy> for ProxyConfig {
+ fn from(value: proxy::CustomProxy) -> Self {
+ match value {
+ proxy::CustomProxy::Shadowsocks(shadowsocks) => ProxyConfig::Shadowsocks(shadowsocks),
+ proxy::CustomProxy::Socks5Local(socks) => ProxyConfig::Socks5Local(socks),
+ proxy::CustomProxy::Socks5Remote(socks) => ProxyConfig::Socks5Remote(socks),
+ }
+ }
+}
+
impl ApiConnectionMode {
/// Reads the proxy config from `CURRENT_CONFIG_FILENAME`.
/// This returns `ApiConnectionMode::Direct` if reading from disk fails for any reason.
@@ -139,32 +149,6 @@ impl ApiConnectionMode {
}
}
- #[cfg(unix)]
- pub fn allowed_clients(&self) -> AllowedClients {
- match self {
- ApiConnectionMode::Proxied(ProxyConfig::Socks5Local(_)) => AllowedClients::All,
- ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => AllowedClients::Root,
- }
- }
-
- #[cfg(windows)]
- pub fn allowed_clients(&self) -> AllowedClients {
- match self {
- ApiConnectionMode::Proxied(ProxyConfig::Socks5Local(_)) => AllowedClients::all(),
- ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => {
- let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
- vec![
- daemon_exe
- .parent()
- .expect("missing executable parent directory")
- .join("mullvad-problem-report.exe"),
- daemon_exe,
- ]
- .into()
- }
- }
- }
-
pub fn is_proxy(&self) -> bool {
*self != ApiConnectionMode::Direct
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 6332c1266e..9f1e88a751 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -24,10 +24,7 @@ use std::{
sync::{Arc, Weak},
time::Duration,
};
-use talpid_types::{
- net::{AllowedEndpoint, Endpoint, TransportProtocol},
- ErrorExt,
-};
+use talpid_types::ErrorExt;
#[cfg(feature = "api-override")]
use crate::API;
@@ -123,36 +120,24 @@ impl Error {
}
}
-use super::ApiEndpointUpdateCallback;
-
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
/// requests
-pub(crate) struct RequestService<
- T: Stream<Item = ApiConnectionMode>,
- F: ApiEndpointUpdateCallback + Send,
-> {
+pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
proxy_config_provider: T,
- new_address_callback: F,
- address_cache: AddressCache,
api_availability: ApiAvailabilityHandle,
}
-impl<
- T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
- F: ApiEndpointUpdateCallback + Send + Sync + 'static,
- > RequestService<T, F>
-{
+impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
/// Constructs a new request service.
pub async fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
mut proxy_config_provider: T,
- new_address_callback: F,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
@@ -184,8 +169,6 @@ impl<
connector_handle,
client,
proxy_config_provider,
- new_address_callback,
- address_cache,
api_availability,
};
let handle = RequestServiceHandle { tx: command_tx };
@@ -203,26 +186,14 @@ impl<
}
RequestCommand::NextApiConfig(completion_tx) => {
#[cfg(feature = "api-override")]
- if API.force_direct_connection {
- log::debug!("Ignoring API connection mode");
- let _ = completion_tx.send(Ok(()));
- return;
- }
+ let force_direct_connection = API.force_direct_connection;
+ #[cfg(not(feature = "api-override"))]
+ let force_direct_connection = false;
- if let Some(new_config) = self.proxy_config_provider.next().await {
- let endpoint = match new_config.get_endpoint() {
- Some(endpoint) => endpoint,
- None => Endpoint::from_socket_address(
- self.address_cache.get_address().await,
- TransportProtocol::Tcp,
- ),
- };
- let clients = new_config.allowed_clients();
- let allowed_endpoint = AllowedEndpoint { endpoint, clients };
- // Switch to new connection mode unless rejected by address change callback
- if (self.new_address_callback)(allowed_endpoint).await {
- self.connector_handle.set_connection_mode(new_config);
- }
+ if force_direct_connection {
+ log::debug!("Ignoring API connection mode");
+ } else if let Some(connection_mode) = self.proxy_config_provider.next().await {
+ self.connector_handle.set_connection_mode(connection_mode);
}
let _ = completion_tx.send(Ok(()));
diff --git a/mullvad-cli/src/cmds/status.rs b/mullvad-cli/src/cmds/status.rs
index 5bed82b4c0..15b0d10dfe 100644
--- a/mullvad-cli/src/cmds/status.rs
+++ b/mullvad-cli/src/cmds/status.rs
@@ -75,6 +75,11 @@ impl Status {
println!("Remove device event: {device:#?}");
}
}
+ DaemonEvent::NewAccessMethod(access_method) => {
+ if args.debug {
+ println!("New access method: {access_method:#?}");
+ }
+ }
}
}
Ok(())
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index 3bea2c4c7d..4bdbd17f15 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -1,9 +1,9 @@
use crate::{
- api::{self, AccessModeSelectorHandle},
+ api,
settings::{self, MadeChanges},
Daemon, EventListener,
};
-use mullvad_api::rest::{self, MullvadRestHandle};
+use mullvad_api::rest;
use mullvad_types::{
access_method::{self, AccessMethod, AccessMethodSetting},
settings::Settings,
@@ -34,18 +34,6 @@ pub enum Error {
Settings(#[error(source)] settings::Error),
}
-/// A tiny datastructure used for signaling whether the daemon should force a
-/// rotation of the currently used [`AccessMethodSetting`] or not, and if so:
-/// how it should do it.
-pub enum Command {
- /// There is no need to force a rotation of [`AccessMethodSetting`]
- Nothing,
- /// Select the next available [`AccessMethodSetting`], whichever that is
- Rotate,
- /// Select the [`AccessMethodSetting`] with a certain [`access_method::Id`]
- Set(access_method::Id),
-}
-
impl<L> Daemon<L>
where
L: EventListener + Clone + Send + 'static,
@@ -79,30 +67,29 @@ where
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
- // Make sure that we are not trying to remove a built-in API access
- // method
- let command = match self.settings.api_access_methods.find_by_id(&access_method) {
- Some(api_access_method) => {
- if api_access_method.is_builtin() {
- Err(Error::RemoveBuiltIn)
- } else if api_access_method.get_id()
- == self.get_current_access_method().await?.get_id()
- {
- Ok(Command::Rotate)
- } else {
- Ok(Command::Nothing)
- }
+ match self.settings.api_access_methods.find_by_id(&access_method) {
+ // Make sure that we are not trying to remove a built-in API access
+ // method
+ Some(api_access_method) if api_access_method.is_builtin() => {
+ return Err(Error::RemoveBuiltIn)
+ }
+ // If the currently active access method is removed, a new access
+ // method should trigger
+ Some(api_access_method)
+ if api_access_method.get_id()
+ == self.get_current_access_method().await?.get_id() =>
+ {
+ self.force_api_endpoint_rotation().await?;
}
- None => Ok(Command::Nothing),
- }?;
+ _ => (),
+ }
self.settings
.update(|settings| settings.api_access_methods.remove(&access_method))
.await
.map(|did_change| self.notify_on_change(did_change))
- .map_err(Error::Settings)?
- .process_command(command)
- .await
+ .map(|_| ())
+ .map_err(Error::Settings)
}
/// Set a [`AccessMethodSetting`] as the current API access method.
@@ -119,9 +106,6 @@ where
.set_access_method(access_method)
.await?;
// Force a rotation of Access Methods.
- //
- // This is not a call to `process_command` due to the restrictions on
- // recursively calling async functions.
self.force_api_endpoint_rotation().await
}
@@ -150,7 +134,8 @@ where
// If the currently active access method is updated, we need to re-set
// it after updating the settings.
let current = self.get_current_access_method().await?;
- let mut command = Command::Nothing;
+ // If the currently active access method is updated, we need to re-set it.
+ let mut refresh = None;
let settings_update = |settings: &mut Settings| {
let access_methods = &mut settings.api_access_methods;
if let Some(access_method) =
@@ -158,7 +143,7 @@ where
{
*access_method = access_method_update;
if access_method.get_id() == current.get_id() {
- command = Command::Set(access_method.get_id())
+ refresh = Some(access_method.get_id())
}
// We have to be a bit careful. If we are about to disable the last
// remaining enabled access method, we would cause an inconsistent state
@@ -185,19 +170,25 @@ where
.update(settings_update)
.await
.map(|did_change| self.notify_on_change(did_change))
- .map_err(Error::Settings)?
- .process_command(command)
- .await
+ .map_err(Error::Settings)?;
+ if let Some(id) = refresh {
+ self.set_api_access_method(id).await?;
+ }
+ Ok(())
}
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
- Ok(self.connection_modes_handler.get_access_method().await?)
+ self.connection_modes_handler
+ .get_current()
+ .await
+ .map(|current| current.setting)
+ .map_err(Error::ConnectionMode)
}
- /// Change which [`AccessMethodSetting`] which will be used to figure out
- /// the Mullvad API endpoint.
+ /// Change which [`AccessMethodSetting`] which will be used as the Mullvad
+ /// API endpoint.
async fn force_api_endpoint_rotation(&self) -> Result<(), Error> {
self.api_handle
.service()
@@ -233,102 +224,4 @@ where
};
self
}
-
- /// The semantics of the [`Command`] datastructure.
- async fn process_command(&mut self, command: Command) -> Result<(), Error> {
- match command {
- Command::Nothing => Ok(()),
- Command::Rotate => self.force_api_endpoint_rotation().await,
- Command::Set(id) => self.set_api_access_method(id).await,
- }
- }
-}
-
-/// Try to reach the Mullvad API using a specific access method, returning
-/// an [`Error`] in the case where the test fails to reach the API.
-///
-/// Ephemerally sets a new access method (associated with `access_method`)
-/// to be used for subsequent API calls, before performing an API call and
-/// switching back to the previously active access method. The previous
-/// access method is *always* reset.
-pub async fn test_access_method(
- new_access_method: AccessMethodSetting,
- access_mode_selector: AccessModeSelectorHandle,
- rest_handle: MullvadRestHandle,
-) -> Result<bool, Error> {
- // Setup test
- let previous_access_method = access_mode_selector
- .get_access_method()
- .await
- .map_err(Error::ConnectionMode)?;
-
- let method_under_test = new_access_method.clone();
- access_mode_selector
- .set_access_method(new_access_method)
- .await
- .map_err(Error::ConnectionMode)?;
-
- // We need to perform a rotation of API endpoint after a set action
- let rotation_handle = rest_handle.clone();
- rotation_handle
- .service()
- .next_api_endpoint()
- .await
- .map_err(|err| {
- log::error!("Failed to rotate API endpoint: {err}");
- Error::Rest(err)
- })?;
-
- // Set up the reset
- //
- // In case the API call fails, the next API endpoint will
- // automatically be selected, which means that we need to set up
- // with the previous API endpoint beforehand.
- access_mode_selector
- .set_access_method(previous_access_method)
- .await
- .map_err(|err| {
- log::error!(
- "Could not reset to previous access
- method after API reachability test was carried out. This should only
- happen if the previous access method was removed in the meantime."
- );
- Error::ConnectionMode(err)
- })?;
-
- // Perform test
- //
- // Send a HEAD request to some Mullvad API endpoint. We issue a HEAD
- // request because we are *only* concerned with if we get a reply from
- // the API, and not with the actual data that the endpoint returns.
- let result = mullvad_api::ApiProxy::new(rest_handle)
- .api_addrs_available()
- .await
- .map_err(Error::Rest)?;
-
- // We need to perform a rotation of API endpoint after a set action
- // Note that this will be done automatically if the API call fails,
- // so it only has to be done if the call succeeded ..
- if result {
- rotation_handle
- .service()
- .next_api_endpoint()
- .await
- .map_err(|err| {
- log::error!("Failed to rotate API endpoint: {err}");
- Error::Rest(err)
- })?;
- }
-
- log::info!(
- "The result of testing {method:?} is {result}",
- method = method_under_test.access_method,
- result = if result {
- "success".to_string()
- } else {
- "failed".to_string()
- }
- );
-
- Ok(result)
}
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index efb8f3088d..3f6f1747a6 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -4,33 +4,120 @@
//! [`ApiConnectionMode`], which in turn is used by `mullvad-api` for
//! establishing connections when performing API requests.
#[cfg(target_os = "android")]
-use crate::{DaemonCommand, DaemonEventSender};
+use crate::DaemonCommand;
+use crate::DaemonEventSender;
use futures::{
- channel::{mpsc, oneshot},
- stream::unfold,
+ channel::{
+ mpsc,
+ oneshot::{self, Canceled},
+ },
Stream, StreamExt,
};
use mullvad_api::{
availability::ApiAvailabilityHandle,
proxy::{ApiConnectionMode, ProxyConfig},
- ApiEndpointUpdateCallback,
+ AddressCache,
};
use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
-use std::{
- path::PathBuf,
- sync::{Arc, Mutex, Weak},
-};
-#[cfg(target_os = "android")]
+use std::{net::SocketAddr, path::PathBuf};
use talpid_core::mpsc::Sender;
-use talpid_core::tunnel_state_machine::TunnelCommand;
-use talpid_types::net::{AllowedEndpoint, Endpoint};
+use talpid_types::net::{AllowedClients, AllowedEndpoint, Endpoint, TransportProtocol};
pub enum Message {
- Get(ResponseTx<AccessMethodSetting>),
+ Get(ResponseTx<ResolvedConnectionMode>),
Set(ResponseTx<()>, AccessMethodSetting),
Next(ResponseTx<ApiConnectionMode>),
Update(ResponseTx<()>, Vec<AccessMethodSetting>),
+ Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting),
+}
+
+/// A [`NewAccessMethodEvent`] is emitted when the active access method changes,
+/// which happens in any of the following two scenarios:
+///
+/// * When a [`mullvad_api::rest::RequestService`] requests a new
+/// [`ApiConnectionMode`] from the running [`AccessModeSelector`]. This will
+/// lead to a [`crate::InternalDaemonEvent::AccessMethodEvent`] being sent to
+/// the daemon, which in turn will notify all clients about the new access
+/// method.
+///
+/// * When testing if some [`AccessMethodSetting`] can be used to reach the
+/// Mullvad API. In this scenario, the currently active access method will
+/// temporarily change (approximately for the duration of 1 API call). Since
+/// this is just an internal test which should be opaque to any client, it
+/// should not produce any unwanted noise and as such it is *not* broadcasted
+/// after the daemon is done processing this [`NewAccessMethodEvent`].
+pub struct NewAccessMethodEvent {
+ /// The new active [`AccessMethodSetting`].
+ pub setting: AccessMethodSetting,
+ /// The endpoint which represents how to connect to the Mullvad API and
+ /// which clients are allowed to initiate such a connection.
+ pub endpoint: AllowedEndpoint,
+ /// If the daemon should notify clients about the new access method.
+ ///
+ /// Defaults to `true`.
+ pub announce: bool,
+}
+
+impl NewAccessMethodEvent {
+ /// Create a new [`NewAccessMethodEvent`] for the daemon to process. A
+ /// [`oneshot::Receiver`] can be used to await the daemon while it finishes
+ /// handling the new event.
+ pub fn new(setting: AccessMethodSetting, endpoint: AllowedEndpoint) -> NewAccessMethodEvent {
+ NewAccessMethodEvent {
+ setting,
+ endpoint,
+ announce: true,
+ }
+ }
+
+ /// Whether the daemon should notify clients about the new access method or
+ /// not.
+ ///
+ /// * If `announce` is set to `true` the daemon will broadcast this event to
+ /// clients.
+ /// * If `announce` is set to `false` the daemon will *not* broadcast this
+ /// event.
+ pub fn announce(mut self, announce: bool) -> Self {
+ self.announce = announce;
+ self
+ }
+
+ /// Send an internal daemon event which will punch a hole in the firewall
+ /// for the connection mode we are testing.
+ ///
+ /// Returns the channel on which the daemon will send a message over when it
+ /// is done applying the firewall changes.
+ pub(crate) async fn send(
+ self,
+ daemon_event_sender: DaemonEventSender<(NewAccessMethodEvent, oneshot::Sender<()>)>,
+ ) -> std::result::Result<(), Canceled> {
+ // It is up to the daemon to actually allow traffic to/from `api_endpoint`
+ // by updating the firewall. This [`oneshot::Sender`] allows the daemon to
+ // communicate when that action is done.
+ let (update_finished_tx, update_finished_rx) = oneshot::channel();
+ let _ = daemon_event_sender.send((self, update_finished_tx));
+ // Wait for the daemon to finish processing `event`.
+ update_finished_rx.await
+ }
+}
+
+/// This struct represent a concrete API endpoint (in the form of an
+/// [`ApiConnectionMode`] and [`AllowedEndpoint`]) which has been derived from
+/// some [`AccessMethodSetting`] (most likely the currently active access
+/// method). These logically related values are sometimes useful to group
+/// together into one value, which is encoded by [`ResolvedConnectionMode`].
+#[derive(Clone)]
+pub struct ResolvedConnectionMode {
+ /// The connection strategy to be used by the `mullvad-api` crate when
+ /// initializing API requests.
+ pub connection_mode: ApiConnectionMode,
+ /// The actual endpoint of the Mullvad API and which clients should be
+ /// allowed to initialize a connection to this endpoint.
+ pub endpoint: AllowedEndpoint,
+ /// This is the [`AccessMethodSetting`] which resolved into
+ /// `connection_mode` and `endpoint`.
+ pub setting: AccessMethodSetting,
}
#[derive(err_derive::Error, Debug)]
@@ -52,6 +139,7 @@ impl std::fmt::Display for Message {
Message::Set(..) => f.write_str("Set"),
Message::Next(_) => f.write_str("Next"),
Message::Update(..) => f.write_str("Update"),
+ Message::Resolve(..) => f.write_str("Resolve"),
}
}
}
@@ -83,13 +171,11 @@ pub struct AccessModeSelectorHandle {
impl AccessModeSelectorHandle {
async fn send_command<T>(&self, make_cmd: impl FnOnce(ResponseTx<T>) -> Message) -> Result<T> {
let (tx, rx) = oneshot::channel();
- self.cmd_tx
- .unbounded_send(make_cmd(tx))
- .map_err(Error::SendFailed)?;
+ self.cmd_tx.unbounded_send(make_cmd(tx))?;
rx.await.map_err(Error::NotRunning)?
}
- pub async fn get_access_method(&self) -> Result<AccessMethodSetting> {
+ pub async fn get_current(&self) -> Result<ResolvedConnectionMode> {
self.send_command(Message::Get).await.map_err(|err| {
log::debug!("Failed to get current access method!");
err
@@ -114,6 +200,15 @@ impl AccessModeSelectorHandle {
})
}
+ pub async fn resolve(&self, setting: AccessMethodSetting) -> Result<ResolvedConnectionMode> {
+ self.send_command(|tx| Message::Resolve(tx, setting))
+ .await
+ .map_err(|err| {
+ log::error!("Failed to update new access methods!");
+ err
+ })
+ }
+
pub async fn next(&self) -> Result<ApiConnectionMode> {
self.send_command(Message::Next).await.map_err(|err| {
log::debug!("Failed while getting the next access method");
@@ -124,10 +219,10 @@ impl AccessModeSelectorHandle {
/// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the
/// associated [`AccessModeSelector`].
///
- /// Practically converts the handle to a listener for when the
- /// currently valid connection modes changes.
+ /// Calling `next` on this stream will poll for the next access method,
+ /// which will be lazily produced (on-demand rather than speculatively).
pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> {
- unfold(self, |handle| async move {
+ futures::stream::unfold(self, |handle| async move {
match handle.next().await {
Ok(connection_mode) => Some((connection_mode, handle)),
// End this stream in case of failure in `next`. `next` should
@@ -155,35 +250,49 @@ pub struct AccessModeSelector {
/// Used for selecting a Bridge when the `Mullvad Bridges` access method is used.
relay_selector: RelaySelector,
connection_modes: ConnectionModesIterator,
+ address_cache: AddressCache,
+ access_method_event_sender: DaemonEventSender<(NewAccessMethodEvent, oneshot::Sender<()>)>,
+ current: ResolvedConnectionMode,
}
impl AccessModeSelector {
- pub fn spawn(
+ pub(crate) async fn spawn(
cache_dir: PathBuf,
relay_selector: RelaySelector,
connection_modes: Vec<AccessMethodSetting>,
- ) -> AccessModeSelectorHandle {
+ access_method_event_sender: DaemonEventSender<(NewAccessMethodEvent, oneshot::Sender<()>)>,
+ address_cache: AddressCache,
+ ) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();
- let connection_modes = match ConnectionModesIterator::new(connection_modes) {
- Ok(provider) => provider,
- Err(Error::NoAccessMethods) | Err(_) => {
+ let mut connection_modes =
+ ConnectionModesIterator::new(connection_modes).unwrap_or_else(|_| {
// No settings seem to have been found. Default to using the the
// direct access method.
let default = mullvad_types::access_method::Settings::direct();
ConnectionModesIterator::new(vec![default]).expect(
"Failed to create the data structure responsible for managing access methods",
)
- }
+ });
+
+ let initial_connection_mode = {
+ let next = connection_modes.next().ok_or(Error::NoAccessMethods)?;
+ Self::resolve_inner(next, &relay_selector, &address_cache).await
};
+
let selector = AccessModeSelector {
cmd_rx,
cache_dir,
relay_selector,
connection_modes,
+ address_cache,
+ access_method_event_sender,
+ current: initial_connection_mode,
};
+
tokio::spawn(selector.into_future());
- AccessModeSelectorHandle { cmd_tx }
+
+ Ok(AccessModeSelectorHandle { cmd_tx })
}
async fn into_future(mut self) {
@@ -192,8 +301,9 @@ impl AccessModeSelector {
let execution = match cmd {
Message::Get(tx) => self.on_get_access_method(tx),
Message::Set(tx, value) => self.on_set_access_method(tx, value),
- Message::Next(tx) => self.on_next_connection_mode(tx),
+ Message::Next(tx) => self.on_next_connection_mode(tx).await,
Message::Update(tx, values) => self.on_update_access_methods(tx, values),
+ Message::Resolve(tx, setting) => self.on_resolve_access_method(tx, setting).await,
};
match execution {
Ok(_) => (),
@@ -213,13 +323,8 @@ impl AccessModeSelector {
Ok(())
}
- fn on_get_access_method(&mut self, tx: ResponseTx<AccessMethodSetting>) -> Result<()> {
- let value = self.get_access_method();
- self.reply(tx, value)
- }
-
- fn get_access_method(&mut self) -> AccessMethodSetting {
- self.connection_modes.peek()
+ fn on_get_access_method(&mut self, tx: ResponseTx<ResolvedConnectionMode>) -> Result<()> {
+ self.reply(tx, self.current.clone())
}
fn on_set_access_method(
@@ -231,40 +336,55 @@ impl AccessModeSelector {
self.reply(tx, ())
}
+ /// Set the next access method to be returned by the [`Stream`] produced by
+ /// calling `into_stream`.
fn set_access_method(&mut self, value: AccessMethodSetting) {
self.connection_modes.set_access_method(value);
}
- fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
- let next = self.next_connection_mode();
- // Save the new connection mode to cache!
- {
- let cache_dir = self.cache_dir.clone();
- let next = next.clone();
- tokio::spawn(async move {
- if next.save(&cache_dir).await.is_err() {
- log::warn!(
- "Failed to save {connection_mode} to cache",
- connection_mode = next
- )
- }
- });
- }
+ async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
+ let next = self.next_connection_mode().await?;
self.reply(tx, next)
}
- fn next_connection_mode(&mut self) -> ApiConnectionMode {
- let access_method = self
- .connection_modes
- .next()
- .map(|access_method_setting| access_method_setting.access_method)
- .unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct));
+ async fn next_connection_mode(&mut self) -> Result<ApiConnectionMode> {
+ let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?;
+ log::info!(
+ "A new API access method has been selected: {name}",
+ name = access_method.name
+ );
+ let resolved = self.resolve(access_method).await;
+ // Note: If the daemon is busy waiting for a call to this function
+ // to complete while we wait for the daemon to fully handle this
+ // `NewAccessMethodEvent`, then we find ourselves in a deadlock.
+ // This can happen during daemon startup when spawning a new
+ // `MullvadRestHandle`, which will call and await `next` on a Stream
+ // created from this `AccessModeSelector` instance. As such, the
+ // completion channel is discarded in this instance.
+ let setting = resolved.setting.clone();
+ let endpoint = resolved.endpoint.clone();
+ let daemon_sender = self.access_method_event_sender.clone();
+ tokio::spawn(async move {
+ let _ = NewAccessMethodEvent::new(setting, endpoint)
+ .send(daemon_sender)
+ .await;
+ });
- let connection_mode = self.from(access_method);
- log::info!("New API connection mode selected: {connection_mode}");
- connection_mode
- }
+ // Save the new connection mode to cache!
+ let cache_dir = self.cache_dir.clone();
+ let new_connection_mode = resolved.connection_mode.clone();
+ tokio::spawn(async move {
+ if new_connection_mode.save(&cache_dir).await.is_err() {
+ log::warn!(
+ "Failed to save {connection_mode} to cache",
+ connection_mode = new_connection_mode
+ )
+ }
+ });
+ self.current = resolved;
+ Ok(self.current.connection_mode.clone())
+ }
fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
@@ -278,51 +398,60 @@ impl AccessModeSelector {
self.connection_modes.update_access_methods(values)
}
- /// Ad-hoc version of [`std::convert::From::from`], but since some
- /// [`ApiConnectionMode`]s require extra logic/data from
- /// [`ApiConnectionModeProvider`] the standard [`std::convert::From`] trait
- /// can not be implemented.
- fn from(&mut self, access_method: AccessMethod) -> ApiConnectionMode {
- use talpid_types::net::proxy;
- match access_method {
- AccessMethod::BuiltIn(access_method) => match access_method {
- BuiltInAccessMethod::Direct => ApiConnectionMode::Direct,
- BuiltInAccessMethod::Bridge => self
- .relay_selector
- .get_bridge_forced()
- .and_then(|settings| match settings {
- proxy::CustomProxy::Shadowsocks(ss_settings) => {
- let ss_settings: proxy::Shadowsocks = proxy::Shadowsocks::new(
- ss_settings.endpoint,
- ss_settings.cipher,
- ss_settings.password,
- );
- Some(ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(
- ss_settings,
- )))
- }
- _ => {
- log::error!("Received unexpected proxy settings type");
- None
- }
- })
- .unwrap_or(ApiConnectionMode::Direct),
- },
- AccessMethod::Custom(access_method) => match access_method {
- proxy::CustomProxy::Shadowsocks(shadowsocks_config) => {
- ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(shadowsocks_config))
- }
- proxy::CustomProxy::Socks5Local(socks_config) => {
- ApiConnectionMode::Proxied(ProxyConfig::Socks5Local(socks_config))
- }
- proxy::CustomProxy::Socks5Remote(socks_config) => {
- ApiConnectionMode::Proxied(ProxyConfig::Socks5Remote(socks_config))
- }
- },
+ pub async fn on_resolve_access_method(
+ &mut self,
+ tx: ResponseTx<ResolvedConnectionMode>,
+ setting: AccessMethodSetting,
+ ) -> Result<()> {
+ let reply = self.resolve(setting).await;
+ self.reply(tx, reply)
+ }
+
+ async fn resolve(&mut self, access_method: AccessMethodSetting) -> ResolvedConnectionMode {
+ Self::resolve_inner(access_method, &self.relay_selector, &self.address_cache).await
+ }
+
+ async fn resolve_inner(
+ access_method: AccessMethodSetting,
+ relay_selector: &RelaySelector,
+ address_cache: &AddressCache,
+ ) -> ResolvedConnectionMode {
+ let connection_mode =
+ resolve_connection_mode(access_method.access_method.clone(), relay_selector);
+ let endpoint =
+ resolve_allowed_endpoint(&connection_mode, address_cache.get_address().await);
+ ResolvedConnectionMode {
+ connection_mode,
+ endpoint,
+ setting: access_method,
}
}
}
+/// Ad-hoc version of [`std::convert::From::from`], but since some
+/// [`ApiConnectionMode`]s require extra logic/data from [`RelaySelector`] to be
+/// instantiated the standard [`std::convert::From`] trait can not be
+/// implemented.
+fn resolve_connection_mode(
+ access_method: AccessMethod,
+ relay_selector: &RelaySelector,
+) -> ApiConnectionMode {
+ match access_method {
+ AccessMethod::BuiltIn(BuiltInAccessMethod::Direct) => ApiConnectionMode::Direct,
+ AccessMethod::BuiltIn(BuiltInAccessMethod::Bridge) => relay_selector
+ .get_bridge_forced()
+ .map(ProxyConfig::from)
+ .map(ApiConnectionMode::Proxied)
+ .unwrap_or_else(|| {
+ log::error!(
+ "Received unexpected proxy settings type. Defaulting to direct API connection"
+ );
+ ApiConnectionMode::Direct
+ }),
+ AccessMethod::Custom(config) => ApiConnectionMode::Proxied(ProxyConfig::from(config)),
+ }
+}
+
/// An iterator which will always produce an [`AccessMethod`].
///
/// Safety: It is always safe to [`unwrap`] after calling [`next`] on a
@@ -353,6 +482,7 @@ impl ConnectionModesIterator {
pub fn set_access_method(&mut self, next: AccessMethodSetting) {
self.next = Some(next);
}
+
/// Update the collection of [`AccessMethod`] which this iterator will
/// return.
pub fn update_access_methods(
@@ -376,11 +506,6 @@ impl ConnectionModesIterator {
Ok(Box::new(access_methods.into_iter().cycle()))
}
}
-
- /// Look at the currently active [`AccessMethod`]
- pub fn peek(&self) -> AccessMethodSetting {
- self.current.clone()
- }
}
impl Iterator for ConnectionModesIterator {
@@ -397,73 +522,44 @@ impl Iterator for ConnectionModesIterator {
}
}
-/// Notifies the tunnel state machine that the API (real or proxied) endpoint has
-/// changed. [ApiEndpointUpdaterHandle::callback()] creates a callback that may
-/// be passed to the `mullvad-api` runtime.
-pub(super) struct ApiEndpointUpdaterHandle {
- tunnel_cmd_tx: Arc<Mutex<Option<Weak<mpsc::UnboundedSender<TunnelCommand>>>>>,
+pub fn resolve_allowed_endpoint(
+ connection_mode: &ApiConnectionMode,
+ fallback: SocketAddr,
+) -> AllowedEndpoint {
+ let endpoint = match connection_mode.get_endpoint() {
+ Some(endpoint) => endpoint,
+ None => Endpoint::from_socket_address(fallback, TransportProtocol::Tcp),
+ };
+ let clients = allowed_clients(connection_mode);
+ AllowedEndpoint { endpoint, clients }
}
-impl ApiEndpointUpdaterHandle {
- pub fn new() -> Self {
- Self {
- tunnel_cmd_tx: Arc::new(Mutex::new(None)),
- }
- }
-
- pub fn set_tunnel_command_tx(&self, tunnel_cmd_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>) {
- *self.tunnel_cmd_tx.lock().unwrap() = Some(tunnel_cmd_tx);
+#[cfg(unix)]
+pub fn allowed_clients(connection_mode: &ApiConnectionMode) -> AllowedClients {
+ match connection_mode {
+ ApiConnectionMode::Proxied(ProxyConfig::Socks5Local(_)) => AllowedClients::All,
+ ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => AllowedClients::Root,
}
+}
- pub fn callback(&self) -> impl ApiEndpointUpdateCallback {
- let tunnel_tx = self.tunnel_cmd_tx.clone();
- move |allowed_endpoint: AllowedEndpoint| {
- let inner_tx = tunnel_tx.clone();
- async move {
- let tunnel_tx = if let Some(tunnel_tx) = { inner_tx.lock().unwrap().as_ref() }
- .and_then(|tx: &Weak<mpsc::UnboundedSender<TunnelCommand>>| tx.upgrade())
- {
- tunnel_tx
- } else {
- log::error!("Rejecting allowed endpoint: Tunnel state machine is not running");
- return false;
- };
- let (result_tx, result_rx) = oneshot::channel();
- let _ = tunnel_tx.unbounded_send(TunnelCommand::AllowEndpoint(
- allowed_endpoint.clone(),
- result_tx,
- ));
- // Wait for the firewall policy to be updated.
- let _ = result_rx.await;
- log::debug!(
- "API endpoint: {endpoint}",
- endpoint = allowed_endpoint.endpoint
- );
- true
- }
+#[cfg(windows)]
+pub fn allowed_clients(connection_mode: &ApiConnectionMode) -> AllowedClients {
+ match connection_mode {
+ ApiConnectionMode::Proxied(ProxyConfig::Socks5Local(_)) => AllowedClients::all(),
+ ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => {
+ let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
+ vec![
+ daemon_exe
+ .parent()
+ .expect("missing executable parent directory")
+ .join("mullvad-problem-report.exe"),
+ daemon_exe,
+ ]
+ .into()
}
}
}
-pub(super) fn get_allowed_endpoint(endpoint: Endpoint) -> AllowedEndpoint {
- #[cfg(unix)]
- let clients = talpid_types::net::AllowedClients::Root;
- #[cfg(windows)]
- let clients = {
- let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
- vec![
- daemon_exe
- .parent()
- .expect("missing executable parent directory")
- .join("mullvad-problem-report.exe"),
- daemon_exe,
- ]
- .into()
- };
-
- AllowedEndpoint { endpoint, clients }
-}
-
pub(crate) fn forward_offline_state(
api_availability: ApiAvailabilityHandle,
mut offline_state_rx: mpsc::UnboundedReceiver<bool>,
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index f365496baf..0d02a068c4 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -27,6 +27,7 @@ pub mod version;
mod version_check;
use crate::target_state::PersistentTargetState;
+use api::NewAccessMethodEvent;
use device::{AccountEvent, PrivateAccountAndDevice, PrivateDeviceEvent};
use futures::{
channel::{mpsc, oneshot},
@@ -369,6 +370,11 @@ pub(crate) enum InternalDaemonEvent {
NewAppVersionInfo(AppVersionInfo),
/// Sent when a device is updated in any way (key rotation, login, logout, etc.).
DeviceEvent(AccountEvent),
+ /// Sent when access methods are changed in any way (new active access method).
+ AccessMethodEvent {
+ event: NewAccessMethodEvent,
+ endpoint_active_tx: oneshot::Sender<()>,
+ },
/// Handles updates from versions without devices.
DeviceMigrationEvent(Result<PrivateAccountAndDevice, device::Error>),
/// A geographical location has has been received from am.i.mullvad.net
@@ -408,6 +414,15 @@ impl From<AccountEvent> for InternalDaemonEvent {
}
}
+impl From<(NewAccessMethodEvent, oneshot::Sender<()>)> for InternalDaemonEvent {
+ fn from(event: (NewAccessMethodEvent, oneshot::Sender<()>)) -> Self {
+ InternalDaemonEvent::AccessMethodEvent {
+ event: event.0,
+ endpoint_active_tx: event.1,
+ }
+ }
+}
+
#[derive(Clone, Debug, Eq, PartialEq)]
enum DaemonExecutionState {
Running,
@@ -590,6 +605,9 @@ pub trait EventListener {
/// Notify that a device was revoked using `RemoveDevice`.
fn notify_remove_device_event(&self, event: RemoveDeviceEvent);
+
+ /// Notify that the api access method changed.
+ fn notify_new_access_method_event(&self, new_access_method: AccessMethodSetting);
}
pub struct Daemon<L: EventListener> {
@@ -654,8 +672,6 @@ where
let api_availability = api_runtime.availability_handle();
api_availability.suspend();
- let endpoint_updater = api::ApiEndpointUpdaterHandle::new();
-
let migration_data = migrations::migrate_all(&cache_dir, &settings_dir)
.await
.unwrap_or_else(|error| {
@@ -687,18 +703,20 @@ where
});
let connection_modes = settings.api_access_methods.collect_enabled();
+ let connection_modes_address_cache = api_runtime.address_cache.clone();
let connection_modes_handler = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
connection_modes,
- );
+ internal_event_tx.to_specialized_sender(),
+ connection_modes_address_cache.clone(),
+ )
+ .await
+ .map_err(Error::ApiConnectionModeError)?;
let api_handle = api_runtime
- .mullvad_rest_handle(
- Box::pin(connection_modes_handler.clone().into_stream()),
- endpoint_updater.callback(),
- )
+ .mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream()))
.await;
let migration_complete = if let Some(migration_data) = migration_data {
@@ -750,11 +768,6 @@ where
vec![]
};
- let initial_api_endpoint =
- api::get_allowed_endpoint(talpid_types::net::Endpoint::from_socket_address(
- api_runtime.address_cache.get_address().await,
- talpid_types::net::TransportProtocol::Tcp,
- ));
let parameters_generator = tunnel::ParametersGenerator::new(
account_manager.clone(),
relay_selector.clone(),
@@ -772,6 +785,11 @@ where
let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned());
});
+ let initial_api_endpoint = connection_modes_handler
+ .get_current()
+ .await
+ .map_err(Error::ApiConnectionModeError)?
+ .endpoint;
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
@@ -803,9 +821,6 @@ where
.await
.map_err(Error::TunnelError)?;
- endpoint_updater
- .set_tunnel_command_tx(Arc::downgrade(tunnel_state_machine_handle.command_tx()));
-
api::forward_offline_state(api_availability.clone(), offline_state_rx);
let relay_list_listener = event_listener.clone();
@@ -962,6 +977,10 @@ where
self.handle_new_app_version_info(app_version_info);
}
DeviceEvent(event) => self.handle_device_event(event).await,
+ AccessMethodEvent {
+ event,
+ endpoint_active_tx,
+ } => self.handle_access_method_event(event, endpoint_active_tx),
DeviceMigrationEvent(event) => self.handle_device_migration_event(event),
LocationEvent(location_data) => self.handle_location_event(location_data),
#[cfg(windows)]
@@ -1218,7 +1237,7 @@ where
UpdateApiAccessMethod(tx, method) => self.on_update_api_access_method(tx, method).await,
GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx),
SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await,
- TestApiAccessMethod(tx, method) => self.on_test_api_access_method(tx, method),
+ TestApiAccessMethod(tx, method) => self.on_test_api_access_method(tx, method).await,
IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx),
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
@@ -1317,6 +1336,32 @@ where
}
}
+ fn handle_access_method_event(
+ &mut self,
+ event: NewAccessMethodEvent,
+ endpoint_active_tx: oneshot::Sender<()>,
+ ) {
+ // Update the firewall to exempt a new API endpoint.
+ let (completion_tx, completion_rx) = oneshot::channel();
+ self.send_tunnel_command(TunnelCommand::AllowEndpoint(event.endpoint, completion_tx));
+ // If the `NewAccessMethodEvent` should be announced to any client
+ // listening for updates of the currently active access method, we need
+ // to clone the handle to the broadcaster of such events. The
+ // announcement should be made after the firewall policy has been
+ // updated, since the new access method will be useless before then.
+ let event_listener = self.event_listener.clone();
+ tokio::spawn(async move {
+ // Wait for the firewall policy to be updated.
+ let _ = completion_rx.await;
+ // Let the emitter of this event know that the firewall has been updated.
+ let _ = endpoint_active_tx.send(());
+ // Notify clients about the change if necessary.
+ if event.announce {
+ event_listener.notify_new_access_method_event(event.setting);
+ }
+ });
+ }
+
fn handle_device_migration_event(
&mut self,
result: Result<PrivateAccountAndDevice, device::Error>,
@@ -2405,42 +2450,88 @@ where
let handle = self.connection_modes_handler.clone();
tokio::spawn(async move {
let result = handle
- .get_access_method()
+ .get_current()
.await
+ .map(|current| current.setting)
.map_err(Error::ApiConnectionModeError);
Self::oneshot_send(tx, result, "get_current_api_access_method response");
});
}
- fn on_test_api_access_method(
+ async fn on_test_api_access_method(
&mut self,
tx: ResponseTx<bool, Error>,
access_method: mullvad_types::access_method::Id,
) {
- // NOTE: Preferably we would block all new API calls until the test is
- // done and the previous access method is reset. Otherwise we run the
- // risk of errounously triggering a rotation of the currently in-use
- // access method.
- let api_handle = self.api_handle.clone();
- let handle = self.connection_modes_handler.clone();
- let access_method_lookup = self
- .get_api_access_method(access_method)
- .map_err(Error::AccessMethodError);
+ let reply =
+ |response| Self::oneshot_send(tx, response, "on_test_api_access_method response");
- match access_method_lookup {
- Ok(access_method) => {
- tokio::spawn(async move {
- let result =
- access_method::test_access_method(access_method, handle, api_handle)
- .await
- .map_err(Error::AccessMethodError);
- Self::oneshot_send(tx, result, "on_test_api_access_method response");
- });
+ let access_method = match self.get_api_access_method(access_method) {
+ Ok(x) => x,
+ Err(err) => {
+ reply(Err(Error::AccessMethodError(err)));
+ return;
}
+ };
+
+ let test_subject = match self.connection_modes_handler.resolve(access_method).await {
+ Ok(test_subject) => test_subject,
Err(err) => {
- Self::oneshot_send(tx, Err(err), "on_test_api_access_method response");
+ reply(Err(Error::ApiConnectionModeError(err)));
+ return;
}
- }
+ };
+
+ let test_subject_name = test_subject.setting.name.clone();
+ let proxy_provider = test_subject.connection_mode.clone().into_repeat();
+ let rest_handle = self.api_runtime.mullvad_rest_handle(proxy_provider).await;
+ let api_proxy = mullvad_api::ApiProxy::new(rest_handle);
+ let daemon_event_sender = self.tx.to_specialized_sender();
+ let access_method_selector = self.connection_modes_handler.clone();
+
+ tokio::spawn(async move {
+ let result = async move {
+ // Send an internal daemon event which will punch a hole in the firewall
+ // for the connection mode we are testing.
+ let _ = api::NewAccessMethodEvent::new(test_subject.setting, test_subject.endpoint)
+ .announce(false)
+ .send(daemon_event_sender.clone())
+ .await;
+
+ // Send a HEAD request to some Mullvad API endpoint. We issue a HEAD
+ // request because we are *only* concerned with if we get a reply from
+ // the API, and not with the actual data that the endpoint returns.
+ let result = api_proxy
+ .api_addrs_available()
+ .await
+ .map_err(Error::RestError);
+
+ // Tell the daemon to reset the hole we just punched to whatever was in
+ // place before.
+ let active = access_method_selector
+ .get_current()
+ .await
+ .map_err(Error::ApiConnectionModeError)?;
+ let _ = api::NewAccessMethodEvent::new(active.setting, active.endpoint)
+ .announce(false)
+ .send(daemon_event_sender.clone())
+ .await;
+
+ result
+ }
+ .await;
+
+ log::debug!(
+ "API access method {method} {verdict}",
+ method = test_subject_name,
+ verdict = match result {
+ Ok(true) => "could successfully connect to the Mullvad API",
+ _ => "could not connect to the Mullvad API",
+ }
+ );
+
+ reply(result);
+ });
}
fn on_get_settings(&self, tx: oneshot::Sender<Settings>) {
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 6985af685c..8b8d840cfb 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -1009,6 +1009,18 @@ impl EventListener for ManagementInterfaceEventBroadcaster {
)),
})
}
+
+ fn notify_new_access_method_event(
+ &self,
+ new_access_method: mullvad_types::access_method::AccessMethodSetting,
+ ) {
+ log::debug!("Broadcasting access method event");
+ self.notify(types::DaemonEvent {
+ event: Some(daemon_event::Event::NewAccessMethod(
+ types::AccessMethodSetting::from(new_access_method),
+ )),
+ })
+ }
}
impl ManagementInterfaceEventBroadcaster {
diff --git a/mullvad-jni/src/jni_event_listener.rs b/mullvad-jni/src/jni_event_listener.rs
index 2aeb8320e3..4567622975 100644
--- a/mullvad-jni/src/jni_event_listener.rs
+++ b/mullvad-jni/src/jni_event_listener.rs
@@ -7,6 +7,7 @@ use jnix::{
};
use mullvad_daemon::EventListener;
use mullvad_types::{
+ access_method::AccessMethodSetting,
device::{DeviceEvent, RemoveDeviceEvent},
relay_list::RelayList,
settings::Settings,
@@ -71,6 +72,11 @@ impl EventListener for JniEventListener {
fn notify_remove_device_event(&self, event: RemoveDeviceEvent) {
let _ = self.0.send(Event::RemoveDevice(event));
}
+
+ // TODO: Implement this function when API access methods is implemented in
+ // the Android app.
+ #[allow(dead_code, unused_variables)]
+ fn notify_new_access_method_event(&self, access_method: AccessMethodSetting) {}
}
struct JniEventHandler<'env> {
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 268e70327d..6978e5a666 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -594,6 +594,7 @@ message DaemonEvent {
AppVersionInfo version_info = 4;
DeviceEvent device = 5;
RemoveDeviceEvent remove_device = 6;
+ AccessMethodSetting new_access_method = 7;
}
}
diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs
index 85048e29e7..2c0ca2fecf 100644
--- a/mullvad-management-interface/src/client.rs
+++ b/mullvad-management-interface/src/client.rs
@@ -37,6 +37,7 @@ pub enum DaemonEvent {
AppVersionInfo(AppVersionInfo),
Device(DeviceEvent),
RemoveDevice(RemoveDeviceEvent),
+ NewAccessMethod(AccessMethodSetting),
}
impl TryFrom<types::daemon_event::Event> for DaemonEvent {
@@ -62,6 +63,11 @@ impl TryFrom<types::daemon_event::Event> for DaemonEvent {
types::daemon_event::Event::RemoveDevice(event) => RemoveDeviceEvent::try_from(event)
.map(DaemonEvent::RemoveDevice)
.map_err(Error::InvalidResponse),
+ types::daemon_event::Event::NewAccessMethod(event) => {
+ AccessMethodSetting::try_from(event)
+ .map(DaemonEvent::NewAccessMethod)
+ .map_err(Error::InvalidResponse)
+ }
}
}
}
diff --git a/mullvad-management-interface/src/types/conversions/access_method.rs b/mullvad-management-interface/src/types/conversions/access_method.rs
index 368528f5bc..a5aa84d374 100644
--- a/mullvad-management-interface/src/types/conversions/access_method.rs
+++ b/mullvad-management-interface/src/types/conversions/access_method.rs
@@ -252,14 +252,20 @@ mod data {
}
}
- impl From<Id> for proto::Uuid {
- fn from(value: Id) -> Self {
+ impl From<&Id> for proto::Uuid {
+ fn from(value: &Id) -> Self {
proto::Uuid {
value: value.to_string(),
}
}
}
+ impl From<Id> for proto::Uuid {
+ fn from(value: Id) -> Self {
+ proto::Uuid::from(&value)
+ }
+ }
+
impl TryFrom<proto::Uuid> for Id {
type Error = FromProtobufTypeError;
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 707b1d4e8b..1f687b4570 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -305,7 +305,6 @@ async fn send_problem_report_inner(
ApiConnectionMode::try_from_cache(cache_dir)
.await
.into_repeat(),
- |_| async { true },
)
.await,
);
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index bcae459442..f89baeb049 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -165,7 +165,6 @@ async fn remove_device() -> Result<(), Error> {
ApiConnectionMode::try_from_cache(&cache_path)
.await
.into_repeat(),
- |_| async { true },
)
.await,
);
diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs
index 78eb42adc4..8b30f85768 100644
--- a/test/test-manager/src/tests/account.rs
+++ b/test/test-manager/src/tests/account.rs
@@ -237,10 +237,7 @@ pub async fn new_device_client() -> DevicesProxy {
let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("failed to create api runtime");
let rest_handle = api
- .mullvad_rest_handle(
- mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat(),
- |_| async { true },
- )
+ .mullvad_rest_handle(mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat())
.await;
DevicesProxy::new(rest_handle)
}