summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-02-15 10:21:59 +0100
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-02-15 10:21:59 +0100
commit02240d69574e0d373c71d109e7ce4f006b916d2e (patch)
tree02c86a5d58153997fd3a1773abf6555f2083d274
parent05a78732fc38c241a7465adca3f906d5f20c1ad6 (diff)
parentb226b8e634f4b9d7dff2f67c2752aa02bdd25a13 (diff)
downloadmullvadvpn-02240d69574e0d373c71d109e7ce4f006b916d2e.tar.xz
mullvadvpn-02240d69574e0d373c71d109e7ce4f006b916d2e.zip
Merge branch 'fix/consider-direct-access-method-first'
-rw-r--r--mullvad-api/src/bin/relay_list.rs9
-rw-r--r--mullvad-api/src/lib.rs46
-rw-r--r--mullvad-api/src/rest.rs9
-rw-r--r--mullvad-daemon/src/access_method.rs8
-rw-r--r--mullvad-daemon/src/api.rs34
-rw-r--r--mullvad-daemon/src/lib.rs33
-rw-r--r--mullvad-problem-report/src/lib.rs9
-rw-r--r--mullvad-setup/src/main.rs9
-rw-r--r--mullvad-types/src/access_method.rs2
-rw-r--r--test/test-manager/src/tests/account.rs20
-rw-r--r--test/test-manager/src/tests/install.rs9
11 files changed, 86 insertions, 102 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index c016b4c8a1..e395d8ae5f 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -11,11 +11,10 @@ async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
- let relay_list_request = RelayListProxy::new(
- runtime
- .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat())
- .await,
- )
+ let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(
+ ApiConnectionMode::Direct,
+ ApiConnectionMode::Direct.into_repeat(),
+ ))
.relay_list(None)
.await;
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 26d6b9758c..17e80b66c0 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -382,9 +382,10 @@ impl Runtime {
}
/// Creates a new request service and returns a handle to it.
- async fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
+ fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
&self,
sni_hostname: Option<String>,
+ initial_connection_mode: ApiConnectionMode,
proxy_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
@@ -392,28 +393,26 @@ impl Runtime {
sni_hostname,
self.api_availability.handle(),
self.address_cache.clone(),
+ initial_connection_mode,
proxy_provider,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
- .await
}
/// Returns a request factory initialized to create requests for the master API
- pub async fn mullvad_rest_handle<
- T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
- >(
+ pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
&self,
+ initial_connection_mode: ApiConnectionMode,
proxy_provider: T,
) -> rest::MullvadRestHandle {
- let service = self
- .new_request_service(
- Some(API.host().to_string()),
- proxy_provider,
- #[cfg(target_os = "android")]
- self.socket_bypass_tx.clone(),
- )
- .await;
+ let service = self.new_request_service(
+ Some(API.host().to_string()),
+ initial_connection_mode,
+ proxy_provider,
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
+ );
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(API.host(), Some(token_store));
@@ -426,15 +425,14 @@ impl Runtime {
}
/// This is only to be used in test code
- pub async fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
- let service = self
- .new_request_service(
- Some(hostname.clone()),
- futures::stream::repeat(ApiConnectionMode::Direct),
- #[cfg(target_os = "android")]
- self.socket_bypass_tx.clone(),
- )
- .await;
+ pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
+ let service = self.new_request_service(
+ Some(hostname.clone()),
+ ApiConnectionMode::Direct,
+ futures::stream::repeat(ApiConnectionMode::Direct),
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
+ );
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(hostname, Some(token_store));
@@ -447,14 +445,14 @@ impl Runtime {
}
/// Returns a new request service handle
- pub async fn rest_handle(&self) -> rest::RequestServiceHandle {
+ pub fn rest_handle(&self) -> rest::RequestServiceHandle {
self.new_request_service(
None,
+ ApiConnectionMode::Direct,
ApiConnectionMode::Direct.into_repeat(),
#[cfg(target_os = "android")]
None,
)
- .await
}
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 0560642bb0..ca63f16c1f 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -131,11 +131,12 @@ pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {
impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
/// Constructs a new request service.
- pub async fn spawn(
+ pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
- mut proxy_config_provider: T,
+ initial_connection_mode: ApiConnectionMode,
+ proxy_config_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
@@ -145,9 +146,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
socket_bypass_tx.clone(),
);
- if let Some(config) = proxy_config_provider.next().await {
- connector_handle.set_connection_mode(config);
- }
+ connector_handle.set_connection_mode(initial_connection_mode);
let (command_tx, command_rx) = mpsc::unbounded();
let client = Client::builder().build(connector);
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index 51bf6c1ea5..664fce6bfe 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -260,14 +260,10 @@ where
/// Create an [`ApiProxy`] which will perform all REST requests against one
/// specific endpoint `proxy_provider`.
- pub async fn create_limited_api_proxy(
- &mut self,
- proxy_provider: ApiConnectionMode,
- ) -> ApiProxy {
+ pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy {
let rest_handle = self
.api_runtime
- .mullvad_rest_handle(proxy_provider.into_repeat())
- .await;
+ .mullvad_rest_handle(proxy_provider, futures::stream::empty());
ApiProxy::new(rest_handle)
}
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index a493b532d1..5db03c2008 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -243,7 +243,8 @@ impl AccessModeSelector {
) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();
- let (index, next) = Self::get_next_inner(0, &access_method_settings);
+ // Always start looking from the position of `Direct`.
+ let (index, next) = Self::select_next_active(0, &access_method_settings);
let initial_connection_mode =
Self::resolve_inner(next, &relay_selector, &address_cache).await;
@@ -396,25 +397,28 @@ impl AccessModeSelector {
if let Some(access_method) = self.set.take() {
access_method
} else {
- let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings);
- self.index = index;
+ let (next_index, next) =
+ Self::select_next_active(self.index + 1, &self.access_method_settings);
+ self.index = next_index;
next
}
}
- fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
- let xs: Vec<_> = access_methods.iter().collect();
- for offset in 1..=access_methods.cardinality() {
- let index = (start + offset) % access_methods.cardinality();
- if let Some(&candidate) = xs.get(index) {
- if candidate.enabled {
- return (index, candidate.clone());
- }
- }
- }
- (0, access_methods.direct().clone())
+ /// Find the next access method to use.
+ ///
+ /// * `start`: From which point in `access_methods` to start the search.
+ /// * `access_methods`: The search space.
+ fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
+ access_methods
+ .iter()
+ .cloned()
+ .enumerate()
+ .cycle()
+ .skip(start)
+ .take(access_methods.cardinality())
+ .find(|(_index, access_method)| access_method.enabled())
+ .unwrap_or_else(|| (0, access_methods.direct().clone()))
}
-
fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 00f4613ffb..ef4bcb86e0 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -717,9 +717,15 @@ where
.await
.map_err(Error::ApiConnectionModeError)?;
- let api_handle = api_runtime
- .mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream()))
- .await;
+ let initial_connection_mode = connection_modes_handler
+ .get_current()
+ .await
+ .map_err(Error::ApiConnectionModeError)?;
+
+ let api_handle = api_runtime.mullvad_rest_handle(
+ initial_connection_mode.connection_mode,
+ Box::pin(connection_modes_handler.clone().into_stream()),
+ );
let migration_complete = if let Some(migration_data) = migration_data {
migrations::migrate_device(
@@ -787,11 +793,6 @@ where
let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned());
});
- let initial_api_endpoint = connection_modes_handler
- .get_current()
- .await
- .map_err(Error::ApiConnectionModeError)?
- .endpoint;
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
@@ -800,7 +801,7 @@ where
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options),
- allowed_endpoint: initial_api_endpoint,
+ allowed_endpoint: initial_connection_mode.endpoint,
reset_firewall: *target_state != TargetState::Secured,
#[cfg(windows)]
exclude_paths,
@@ -851,7 +852,7 @@ where
relay_list_updater.update().await;
let location_handler = GeoIpHandler::new(
- api_runtime.rest_handle().await,
+ api_runtime.rest_handle(),
internal_event_tx.clone().to_specialized_sender(),
);
@@ -1248,9 +1249,7 @@ where
GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx),
SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await,
TestApiAccessMethodById(tx, method) => self.on_test_api_access_method(tx, method).await,
- TestCustomApiAccessMethod(tx, proxy) => {
- self.on_test_proxy_as_access_method(tx, proxy).await
- }
+ TestCustomApiAccessMethod(tx, proxy) => self.on_test_proxy_as_access_method(tx, proxy),
IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx),
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
@@ -2478,7 +2477,7 @@ where
});
}
- async fn on_test_proxy_as_access_method(
+ fn on_test_proxy_as_access_method(
&mut self,
tx: ResponseTx<bool, Error>,
proxy: talpid_types::net::proxy::CustomProxy,
@@ -2487,7 +2486,7 @@ where
use talpid_types::net::AllowedEndpoint;
let connection_mode = ApiConnectionMode::Proxied(ProxyConfig::from(proxy.clone()));
- let api_proxy = self.create_limited_api_proxy(connection_mode.clone()).await;
+ let api_proxy = self.create_limited_api_proxy(connection_mode.clone());
let proxy_endpoint = AllowedEndpoint {
endpoint: proxy.get_remote_endpoint().endpoint,
clients: api::allowed_clients(&connection_mode),
@@ -2533,9 +2532,7 @@ where
}
};
- let api_proxy = self
- .create_limited_api_proxy(test_subject.connection_mode)
- .await;
+ let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode);
let daemon_event_sender = self.tx.to_specialized_sender();
let access_method_selector = self.connection_modes_handler.clone();
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 1f687b4570..bcd820bef5 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -299,14 +299,9 @@ async fn send_problem_report_inner(
.await
.map_err(Error::CreateRpcClientError)?;
+ let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await;
let api_client = mullvad_api::ProblemReportProxy::new(
- api_runtime
- .mullvad_rest_handle(
- ApiConnectionMode::try_from_cache(cache_dir)
- .await
- .into_repeat(),
- )
- .await,
+ api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
);
for _attempt in 0..MAX_SEND_ATTEMPTS {
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index f89baeb049..cf93b2d039 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -159,14 +159,9 @@ async fn remove_device() -> Result<(), Error> {
.await
.map_err(Error::RpcInitializationError)?;
+ let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await;
let proxy = mullvad_api::DevicesProxy::new(
- api_runtime
- .mullvad_rest_handle(
- ApiConnectionMode::try_from_cache(&cache_path)
- .await
- .into_repeat(),
- )
- .await,
+ api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
);
let device_removal = retry_future(
diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs
index c0f648f25f..e8f6bd4a5d 100644
--- a/mullvad-types/src/access_method.rs
+++ b/mullvad-types/src/access_method.rs
@@ -83,7 +83,7 @@ impl Settings {
}
/// Iterate over references of built-in & custom access methods.
- pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> {
+ pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> + Clone {
use std::iter::once;
once(&self.direct)
.chain(once(&self.mullvad_bridges))
diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs
index 95b54b6432..1eeeb8c170 100644
--- a/test/test-manager/src/tests/account.rs
+++ b/test/test-manager/src/tests/account.rs
@@ -23,7 +23,7 @@ pub async fn test_login(
// Instruct daemon to log in
//
- clear_devices(&new_device_client().await)
+ clear_devices(&new_device_client())
.await
.expect("failed to clear devices");
@@ -65,7 +65,7 @@ pub async fn test_too_many_devices(
) -> Result<(), Error> {
log::info!("Using up all devices");
- let device_client = new_device_client().await;
+ let device_client = new_device_client();
const MAX_ATTEMPTS: usize = 15;
@@ -151,7 +151,7 @@ pub async fn test_revoked_device(
log::debug!("Removing current device");
- let device_client = new_device_client().await;
+ let device_client = new_device_client();
retry_if_throttled(|| {
device_client.remove(TEST_CONFIG.account_number.clone(), device_id.clone())
})
@@ -217,9 +217,10 @@ pub async fn clear_devices(device_client: &DevicesProxy) -> Result<(), mullvad_a
Ok(())
}
-pub async fn new_device_client() -> DevicesProxy {
- let api_endpoint = mullvad_api::ApiEndpoint::from_env_vars();
+pub fn new_device_client() -> DevicesProxy {
+ use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, API};
+ let api_endpoint = ApiEndpoint::from_env_vars();
let api_host = format!("api.{}", TEST_CONFIG.mullvad_host);
let api_address = format!("{api_host}:443")
.to_socket_addrs()
@@ -228,7 +229,7 @@ pub async fn new_device_client() -> DevicesProxy {
.unwrap();
// Override the API endpoint to use the one specified in the test config
- let _ = mullvad_api::API.override_init(mullvad_api::ApiEndpoint {
+ let _ = API.override_init(ApiEndpoint {
host: Some(api_host),
address: Some(api_address),
..api_endpoint
@@ -236,9 +237,10 @@ pub async fn new_device_client() -> DevicesProxy {
let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("failed to create api runtime");
- let rest_handle = api
- .mullvad_rest_handle(mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat())
- .await;
+ let rest_handle = api.mullvad_rest_handle(
+ ApiConnectionMode::Direct,
+ ApiConnectionMode::Direct.into_repeat(),
+ );
DevicesProxy::new(rest_handle)
}
diff --git a/test/test-manager/src/tests/install.rs b/test/test-manager/src/tests/install.rs
index 614b9d9bb9..8f2f2cdff2 100644
--- a/test/test-manager/src/tests/install.rs
+++ b/test/test-manager/src/tests/install.rs
@@ -49,7 +49,7 @@ pub async fn test_upgrade_app(ctx: TestContext, rpc: ServiceClient) -> Result<()
return Err(Error::DaemonNotRunning);
}
- super::account::clear_devices(&super::account::new_device_client().await)
+ super::account::clear_devices(&super::account::new_device_client())
.await
.expect("failed to clear devices");
@@ -227,10 +227,9 @@ pub async fn test_uninstall_app(
}
// verify that device was removed
- let devices =
- super::account::list_devices_with_retries(&super::account::new_device_client().await)
- .await
- .expect("failed to list devices");
+ let devices = super::account::list_devices_with_retries(&super::account::new_device_client())
+ .await
+ .expect("failed to list devices");
assert!(
!devices.iter().any(|device| device.id == uninstalled_device),