summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-daemon/src/access_method.rs76
-rw-r--r--mullvad-daemon/src/api.rs166
-rw-r--r--mullvad-daemon/src/lib.rs7
3 files changed, 98 insertions, 151 deletions
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index e28ba21793..51bf6c1ea5 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -114,7 +114,8 @@ where
// Toggle the enabled status if needed
if !access_method.enabled() {
access_method.enable();
- self.update_access_method_inner(&access_method).await?
+ self.update_access_method_inner(access_method.clone())
+ .await?
}
// Set `access_method` as the next access method to use
self.connection_modes_handler
@@ -130,7 +131,8 @@ where
) -> Result<AccessMethodSetting, Error> {
self.settings
.api_access_methods
- .find_by_id(&access_method)
+ .iter()
+ .find(|setting| setting.get_id() == access_method)
.ok_or(Error::NoSuchMethod(access_method))
.cloned()
}
@@ -146,14 +148,20 @@ where
&mut self,
access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
- self.update_access_method_inner(&access_method_update)
+ self.update_access_method_inner(access_method_update.clone())
.await?;
- // If the currently active access method is updated, we need to re-set
- // it after updating the settings.
- if access_method_update.get_id() == self.get_current_access_method().await?.get_id() {
- self.use_api_access_method(access_method_update.get_id())
- .await?;
+ if self.is_in_use(access_method_update.get_id()).await? {
+ if access_method_update.disabled() {
+ // If the currently active access method is updated & disabled
+ // we should select the next access method
+ self.force_api_endpoint_rotation().await?;
+ } else {
+ // If the currently active access method is just updated, we
+ // need to re-set it after updating the settings
+ self.use_api_access_method(access_method_update.get_id())
+ .await?;
+ }
}
Ok(())
@@ -167,33 +175,14 @@ where
/// existing, in-use setting needs to be re-set.
async fn update_access_method_inner(
&mut self,
- access_method_update: &AccessMethodSetting,
+ access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
- let access_method_update_moved = access_method_update.clone();
let settings_update = |settings: &mut Settings| {
- if let Some(access_method) = settings
- .api_access_methods
- .find_by_id_mut(&access_method_update_moved.get_id())
- {
- *access_method = access_method_update_moved;
- // We have to be a bit careful. If the update is about to
- // disable the last remaining enabled access method, we would
- // cause an inconsistent state in the daemon's settings.
- // Therefore, we have to explicitly safeguard against this by.
- // In that case, we should re-enable the `Direct` access method.
- if settings.api_access_methods.collect_enabled().is_empty() {
- if let Some(direct) = settings.api_access_methods.get_direct() {
- direct.enabled = true;
- } else {
- // If the `Direct` access method does not exist within the
- // settings for some reason, the settings are in an
- // inconsistent state. We don't have much choice but to
- // reset these settings to their default value.
- log::warn!("The built-in access methods can not be found. This might be due to a corrupt settings file");
- settings.api_access_methods = access_method::Settings::default();
- }
- }
- }
+ let target = access_method_update.get_id();
+ settings.api_access_methods.update(
+ |access_method| access_method.get_id() == target,
+ |_| access_method_update,
+ );
};
self.settings
@@ -205,6 +194,13 @@ where
Ok(())
}
+ /// Check if some access method is the same as the currently active one.
+ ///
+ /// This can be useful for invalidating stale states.
+ async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> {
+ Ok(access_method == self.get_current_access_method().await?.get_id())
+ }
+
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
@@ -291,19 +287,9 @@ where
.notify_settings(self.settings.to_settings());
let handle = self.connection_modes_handler.clone();
- let new_access_methods = self.settings.api_access_methods.collect_enabled();
+ let new_access_methods = self.settings.api_access_methods.clone();
tokio::spawn(async move {
- match handle.update_access_methods(new_access_methods).await {
- Ok(_) => (),
- Err(api::Error::NoAccessMethods) | Err(_) => {
- // `access_methods` was empty! This implies that the user
- // disabled all access methods. If we ever get into this
- // state, we should default to using the direct access
- // method.
- let default = access_method::Settings::direct();
- handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods");
- }
- }
+ let _ = handle.update_access_methods(new_access_methods).await;
});
};
self
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index 9a4614dca5..588068810e 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -16,7 +16,9 @@ use mullvad_api::{
AddressCache,
};
use mullvad_relay_selector::RelaySelector;
-use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
+use mullvad_types::access_method::{
+ AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings,
+};
use std::{net::SocketAddr, path::PathBuf};
use talpid_core::mpsc::Sender;
use talpid_types::net::{AllowedClients, AllowedEndpoint, Endpoint, TransportProtocol};
@@ -25,7 +27,7 @@ pub enum Message {
Get(ResponseTx<ResolvedConnectionMode>),
Set(ResponseTx<()>, AccessMethodSetting),
Next(ResponseTx<ApiConnectionMode>),
- Update(ResponseTx<()>, Vec<AccessMethodSetting>),
+ Update(ResponseTx<()>, Settings),
Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting),
}
@@ -164,8 +166,8 @@ impl AccessModeSelectorHandle {
})
}
- pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> {
- self.send_command(|tx| Message::Update(tx, values))
+ pub async fn update_access_methods(&self, access_methods: Settings) -> Result<()> {
+ self.send_command(|tx| Message::Update(tx, access_methods))
.await
.map_err(|err| {
log::debug!("Failed to switch to a new set of access methods");
@@ -213,54 +215,46 @@ impl AccessModeSelectorHandle {
/// connection mode. The API can be connected to either directly (i.e.,
/// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`])
/// or via any supported custom proxy protocol
-/// ([`api_access_methods::ObfuscationProtocol`]).
-///
-/// The strategy for determining the next [`ApiConnectionMode`] is handled by
-/// [`ConnectionModesIterator`].
+/// ([`talpid_types::net::proxy::CustomProxy`]).
pub struct AccessModeSelector {
cmd_rx: mpsc::UnboundedReceiver<Message>,
cache_dir: PathBuf,
/// Used for selecting a Bridge when the `Mullvad Bridges` access method is used.
relay_selector: RelaySelector,
- connection_modes: ConnectionModesIterator,
+ access_method_settings: Settings,
address_cache: AddressCache,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
current: ResolvedConnectionMode,
+ /// `index` is used to keep track of the [`AccessMethodSetting`] to use.
+ index: usize,
+ /// `set` is used to set the next [`AccessMethodSetting`] to use.
+ set: Option<AccessMethodSetting>,
}
impl AccessModeSelector {
pub(crate) async fn spawn(
cache_dir: PathBuf,
relay_selector: RelaySelector,
- connection_modes: Vec<AccessMethodSetting>,
+ access_method_settings: Settings,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
address_cache: AddressCache,
) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();
- let mut connection_modes =
- ConnectionModesIterator::new(connection_modes).unwrap_or_else(|_| {
- // No settings seem to have been found. Default to using the the
- // direct access method.
- let default = mullvad_types::access_method::Settings::direct();
- ConnectionModesIterator::new(vec![default]).expect(
- "Failed to create the data structure responsible for managing access methods",
- )
- });
-
- let initial_connection_mode = {
- let next = connection_modes.next().ok_or(Error::NoAccessMethods)?;
- Self::resolve_inner(next, &relay_selector, &address_cache).await
- };
+ let (index, next) = Self::get_next_inner(0, &access_method_settings);
+ let initial_connection_mode =
+ Self::resolve_inner(next, &relay_selector, &address_cache).await;
let selector = AccessModeSelector {
cmd_rx,
cache_dir,
relay_selector,
- connection_modes,
+ access_method_settings,
address_cache,
access_method_event_sender,
current: initial_connection_mode,
+ index,
+ set: None,
};
tokio::spawn(selector.into_future());
@@ -312,7 +306,14 @@ impl AccessModeSelector {
/// Set the next access method to be returned by the [`Stream`] produced by
/// calling `into_stream`.
fn set_access_method(&mut self, value: AccessMethodSetting) {
- self.connection_modes.set_access_method(value);
+ if let Some(index) = self
+ .access_method_settings
+ .iter()
+ .position(|access_method| access_method.get_id() == value.get_id())
+ {
+ self.index = index;
+ self.set = Some(value);
+ }
}
async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
@@ -351,7 +352,7 @@ impl AccessModeSelector {
);
}
- let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?;
+ let access_method = self.get_next();
log::info!(
"A new API access method has been selected: {name}",
name = access_method.name
@@ -388,17 +389,50 @@ impl AccessModeSelector {
self.current = resolved;
Ok(self.current.connection_mode.clone())
}
+
+ fn get_next(&mut self) -> AccessMethodSetting {
+ if let Some(access_method) = self.set.take() {
+ access_method
+ } else {
+ let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings);
+ self.index = index;
+ next
+ }
+ }
+
+ fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
+ let xs: Vec<_> = access_methods.iter().collect();
+ for offset in 1..=access_methods.cardinality() {
+ let index = (start + offset) % access_methods.cardinality();
+ if let Some(&candidate) = xs.get(index) {
+ if candidate.enabled {
+ return (index, candidate.clone());
+ }
+ }
+ }
+ (0, access_methods.direct().clone())
+ }
+
fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
- values: Vec<AccessMethodSetting>,
+ access_methods: Settings,
) -> Result<()> {
- self.update_access_methods(values)?;
+ self.update_access_methods(access_methods);
self.reply(tx, ())
}
- fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> {
- self.connection_modes.update_access_methods(values)
+ fn update_access_methods(&mut self, access_methods: Settings) {
+ let removed_active = !access_methods
+ .iter()
+ .any(|access_method| access_method.get_id() == self.current.setting.get_id());
+ if removed_active {
+ // A new access mehtod will suddenly have the same index as the one
+ // we are removing, but we want it to still be a candidate. A minor
+ // hack to achieve this is to simply decrement the current index.
+ self.index = self.index.saturating_sub(1);
+ }
+ self.access_method_settings = access_methods;
}
pub async fn on_resolve_access_method(
@@ -455,76 +489,6 @@ fn resolve_connection_mode(
}
}
-/// An iterator which will always produce an [`AccessMethod`].
-///
-/// Safety: It is always safe to [`unwrap`] after calling [`next`] on a
-/// [`std::iter::Cycle`], so thereby it is safe to always call [`unwrap`] on a
-/// [`ConnectionModesIterator`].
-///
-/// [`unwrap`]: Option::unwrap
-/// [`next`]: std::iter::Iterator::next
-pub struct ConnectionModesIterator {
- available_modes: Box<dyn Iterator<Item = AccessMethodSetting> + Send>,
- next: Option<AccessMethodSetting>,
- current: AccessMethodSetting,
-}
-
-impl ConnectionModesIterator {
- pub fn new(
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<ConnectionModesIterator, Error> {
- let mut iterator = Self::new_iterator(access_methods)?;
- Ok(Self {
- next: None,
- current: iterator.next().ok_or(Error::NoAccessMethods)?,
- available_modes: iterator,
- })
- }
-
- /// Set the next [`AccessMethod`] to be returned from this iterator.
- pub fn set_access_method(&mut self, next: AccessMethodSetting) {
- self.next = Some(next);
- }
-
- /// Update the collection of [`AccessMethod`] which this iterator will
- /// return.
- pub fn update_access_methods(
- &mut self,
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<(), Error> {
- self.available_modes = Self::new_iterator(access_methods)?;
- Ok(())
- }
-
- /// Create a cyclic iterator of [`AccessMethodSetting`]s.
- ///
- /// If the `access_methods` argument is an empty vector, an [`Error`] is
- /// returned.
- fn new_iterator(
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> {
- if access_methods.is_empty() {
- Err(Error::NoAccessMethods)
- } else {
- Ok(Box::new(access_methods.into_iter().cycle()))
- }
- }
-}
-
-impl Iterator for ConnectionModesIterator {
- type Item = AccessMethodSetting;
-
- fn next(&mut self) -> Option<Self::Item> {
- let next = self
- .next
- .take()
- .or_else(|| self.available_modes.next())
- .unwrap();
- self.current = next.clone();
- Some(next)
- }
-}
-
pub fn resolve_allowed_endpoint(
connection_mode: &ApiConnectionMode,
fallback: SocketAddr,
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index d54598b284..3c852a1e8d 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -709,15 +709,12 @@ where
.set_config(new_selector_config(settings));
});
- let connection_modes = settings.api_access_methods.collect_enabled();
- let connection_modes_address_cache = api_runtime.address_cache.clone();
-
let connection_modes_handler = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
- connection_modes,
+ settings.api_access_methods.clone(),
internal_event_tx.to_specialized_sender(),
- connection_modes_address_cache.clone(),
+ api_runtime.address_cache.clone().clone(),
)
.await
.map_err(Error::ApiConnectionModeError)?;