summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-01-31 11:26:22 +0100
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-01-31 12:46:16 +0100
commit1f1928d6542177bb57433f22161b8929de09decf (patch)
treeae7f45918c9b77a221d305b5ee3dcdc20460c6c9
parent62728e3b7faf156b7e1527faef65ee3de105dfaf (diff)
downloadmullvadvpn-1f1928d6542177bb57433f22161b8929de09decf.tar.xz
mullvadvpn-1f1928d6542177bb57433f22161b8929de09decf.zip
If the current access method is disabled or removed, select the next available
If the current access method is disabled, select the next available access method from the daemon settings.
-rw-r--r--mullvad-daemon/src/access_method.rs76
-rw-r--r--mullvad-daemon/src/api.rs166
-rw-r--r--mullvad-daemon/src/lib.rs7
3 files changed, 98 insertions, 151 deletions
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs
index e28ba21793..51bf6c1ea5 100644
--- a/mullvad-daemon/src/access_method.rs
+++ b/mullvad-daemon/src/access_method.rs
@@ -114,7 +114,8 @@ where
// Toggle the enabled status if needed
if !access_method.enabled() {
access_method.enable();
- self.update_access_method_inner(&access_method).await?
+ self.update_access_method_inner(access_method.clone())
+ .await?
}
// Set `access_method` as the next access method to use
self.connection_modes_handler
@@ -130,7 +131,8 @@ where
) -> Result<AccessMethodSetting, Error> {
self.settings
.api_access_methods
- .find_by_id(&access_method)
+ .iter()
+ .find(|setting| setting.get_id() == access_method)
.ok_or(Error::NoSuchMethod(access_method))
.cloned()
}
@@ -146,14 +148,20 @@ where
&mut self,
access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
- self.update_access_method_inner(&access_method_update)
+ self.update_access_method_inner(access_method_update.clone())
.await?;
- // If the currently active access method is updated, we need to re-set
- // it after updating the settings.
- if access_method_update.get_id() == self.get_current_access_method().await?.get_id() {
- self.use_api_access_method(access_method_update.get_id())
- .await?;
+ if self.is_in_use(access_method_update.get_id()).await? {
+ if access_method_update.disabled() {
+ // If the currently active access method is updated & disabled
+ // we should select the next access method
+ self.force_api_endpoint_rotation().await?;
+ } else {
+ // If the currently active access method is just updated, we
+ // need to re-set it after updating the settings
+ self.use_api_access_method(access_method_update.get_id())
+ .await?;
+ }
}
Ok(())
@@ -167,33 +175,14 @@ where
/// existing, in-use setting needs to be re-set.
async fn update_access_method_inner(
&mut self,
- access_method_update: &AccessMethodSetting,
+ access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
- let access_method_update_moved = access_method_update.clone();
let settings_update = |settings: &mut Settings| {
- if let Some(access_method) = settings
- .api_access_methods
- .find_by_id_mut(&access_method_update_moved.get_id())
- {
- *access_method = access_method_update_moved;
- // We have to be a bit careful. If the update is about to
- // disable the last remaining enabled access method, we would
- // cause an inconsistent state in the daemon's settings.
- // Therefore, we have to explicitly safeguard against this by.
- // In that case, we should re-enable the `Direct` access method.
- if settings.api_access_methods.collect_enabled().is_empty() {
- if let Some(direct) = settings.api_access_methods.get_direct() {
- direct.enabled = true;
- } else {
- // If the `Direct` access method does not exist within the
- // settings for some reason, the settings are in an
- // inconsistent state. We don't have much choice but to
- // reset these settings to their default value.
- log::warn!("The built-in access methods can not be found. This might be due to a corrupt settings file");
- settings.api_access_methods = access_method::Settings::default();
- }
- }
- }
+ let target = access_method_update.get_id();
+ settings.api_access_methods.update(
+ |access_method| access_method.get_id() == target,
+ |_| access_method_update,
+ );
};
self.settings
@@ -205,6 +194,13 @@ where
Ok(())
}
+ /// Check if some access method is the same as the currently active one.
+ ///
+ /// This can be useful for invalidating stale states.
+ async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> {
+ Ok(access_method == self.get_current_access_method().await?.get_id())
+ }
+
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
@@ -291,19 +287,9 @@ where
.notify_settings(self.settings.to_settings());
let handle = self.connection_modes_handler.clone();
- let new_access_methods = self.settings.api_access_methods.collect_enabled();
+ let new_access_methods = self.settings.api_access_methods.clone();
tokio::spawn(async move {
- match handle.update_access_methods(new_access_methods).await {
- Ok(_) => (),
- Err(api::Error::NoAccessMethods) | Err(_) => {
- // `access_methods` was empty! This implies that the user
- // disabled all access methods. If we ever get into this
- // state, we should default to using the direct access
- // method.
- let default = access_method::Settings::direct();
- handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods");
- }
- }
+ let _ = handle.update_access_methods(new_access_methods).await;
});
};
self
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index 9a4614dca5..588068810e 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -16,7 +16,9 @@ use mullvad_api::{
AddressCache,
};
use mullvad_relay_selector::RelaySelector;
-use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
+use mullvad_types::access_method::{
+ AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings,
+};
use std::{net::SocketAddr, path::PathBuf};
use talpid_core::mpsc::Sender;
use talpid_types::net::{AllowedClients, AllowedEndpoint, Endpoint, TransportProtocol};
@@ -25,7 +27,7 @@ pub enum Message {
Get(ResponseTx<ResolvedConnectionMode>),
Set(ResponseTx<()>, AccessMethodSetting),
Next(ResponseTx<ApiConnectionMode>),
- Update(ResponseTx<()>, Vec<AccessMethodSetting>),
+ Update(ResponseTx<()>, Settings),
Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting),
}
@@ -164,8 +166,8 @@ impl AccessModeSelectorHandle {
})
}
- pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> {
- self.send_command(|tx| Message::Update(tx, values))
+ pub async fn update_access_methods(&self, access_methods: Settings) -> Result<()> {
+ self.send_command(|tx| Message::Update(tx, access_methods))
.await
.map_err(|err| {
log::debug!("Failed to switch to a new set of access methods");
@@ -213,54 +215,46 @@ impl AccessModeSelectorHandle {
/// connection mode. The API can be connected to either directly (i.e.,
/// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`])
/// or via any supported custom proxy protocol
-/// ([`api_access_methods::ObfuscationProtocol`]).
-///
-/// The strategy for determining the next [`ApiConnectionMode`] is handled by
-/// [`ConnectionModesIterator`].
+/// ([`talpid_types::net::proxy::CustomProxy`]).
pub struct AccessModeSelector {
cmd_rx: mpsc::UnboundedReceiver<Message>,
cache_dir: PathBuf,
/// Used for selecting a Bridge when the `Mullvad Bridges` access method is used.
relay_selector: RelaySelector,
- connection_modes: ConnectionModesIterator,
+ access_method_settings: Settings,
address_cache: AddressCache,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
current: ResolvedConnectionMode,
+ /// `index` is used to keep track of the [`AccessMethodSetting`] to use.
+ index: usize,
+ /// `set` is used to set the next [`AccessMethodSetting`] to use.
+ set: Option<AccessMethodSetting>,
}
impl AccessModeSelector {
pub(crate) async fn spawn(
cache_dir: PathBuf,
relay_selector: RelaySelector,
- connection_modes: Vec<AccessMethodSetting>,
+ access_method_settings: Settings,
access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>,
address_cache: AddressCache,
) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();
- let mut connection_modes =
- ConnectionModesIterator::new(connection_modes).unwrap_or_else(|_| {
- // No settings seem to have been found. Default to using the the
- // direct access method.
- let default = mullvad_types::access_method::Settings::direct();
- ConnectionModesIterator::new(vec![default]).expect(
- "Failed to create the data structure responsible for managing access methods",
- )
- });
-
- let initial_connection_mode = {
- let next = connection_modes.next().ok_or(Error::NoAccessMethods)?;
- Self::resolve_inner(next, &relay_selector, &address_cache).await
- };
+ let (index, next) = Self::get_next_inner(0, &access_method_settings);
+ let initial_connection_mode =
+ Self::resolve_inner(next, &relay_selector, &address_cache).await;
let selector = AccessModeSelector {
cmd_rx,
cache_dir,
relay_selector,
- connection_modes,
+ access_method_settings,
address_cache,
access_method_event_sender,
current: initial_connection_mode,
+ index,
+ set: None,
};
tokio::spawn(selector.into_future());
@@ -312,7 +306,14 @@ impl AccessModeSelector {
/// Set the next access method to be returned by the [`Stream`] produced by
/// calling `into_stream`.
fn set_access_method(&mut self, value: AccessMethodSetting) {
- self.connection_modes.set_access_method(value);
+ if let Some(index) = self
+ .access_method_settings
+ .iter()
+ .position(|access_method| access_method.get_id() == value.get_id())
+ {
+ self.index = index;
+ self.set = Some(value);
+ }
}
async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
@@ -351,7 +352,7 @@ impl AccessModeSelector {
);
}
- let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?;
+ let access_method = self.get_next();
log::info!(
"A new API access method has been selected: {name}",
name = access_method.name
@@ -388,17 +389,50 @@ impl AccessModeSelector {
self.current = resolved;
Ok(self.current.connection_mode.clone())
}
+
+ fn get_next(&mut self) -> AccessMethodSetting {
+ if let Some(access_method) = self.set.take() {
+ access_method
+ } else {
+ let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings);
+ self.index = index;
+ next
+ }
+ }
+
+ fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
+ let xs: Vec<_> = access_methods.iter().collect();
+ for offset in 1..=access_methods.cardinality() {
+ let index = (start + offset) % access_methods.cardinality();
+ if let Some(&candidate) = xs.get(index) {
+ if candidate.enabled {
+ return (index, candidate.clone());
+ }
+ }
+ }
+ (0, access_methods.direct().clone())
+ }
+
fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
- values: Vec<AccessMethodSetting>,
+ access_methods: Settings,
) -> Result<()> {
- self.update_access_methods(values)?;
+ self.update_access_methods(access_methods);
self.reply(tx, ())
}
- fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> {
- self.connection_modes.update_access_methods(values)
+ fn update_access_methods(&mut self, access_methods: Settings) {
+ let removed_active = !access_methods
+ .iter()
+ .any(|access_method| access_method.get_id() == self.current.setting.get_id());
+ if removed_active {
+ // A new access mehtod will suddenly have the same index as the one
+ // we are removing, but we want it to still be a candidate. A minor
+ // hack to achieve this is to simply decrement the current index.
+ self.index = self.index.saturating_sub(1);
+ }
+ self.access_method_settings = access_methods;
}
pub async fn on_resolve_access_method(
@@ -455,76 +489,6 @@ fn resolve_connection_mode(
}
}
-/// An iterator which will always produce an [`AccessMethod`].
-///
-/// Safety: It is always safe to [`unwrap`] after calling [`next`] on a
-/// [`std::iter::Cycle`], so thereby it is safe to always call [`unwrap`] on a
-/// [`ConnectionModesIterator`].
-///
-/// [`unwrap`]: Option::unwrap
-/// [`next`]: std::iter::Iterator::next
-pub struct ConnectionModesIterator {
- available_modes: Box<dyn Iterator<Item = AccessMethodSetting> + Send>,
- next: Option<AccessMethodSetting>,
- current: AccessMethodSetting,
-}
-
-impl ConnectionModesIterator {
- pub fn new(
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<ConnectionModesIterator, Error> {
- let mut iterator = Self::new_iterator(access_methods)?;
- Ok(Self {
- next: None,
- current: iterator.next().ok_or(Error::NoAccessMethods)?,
- available_modes: iterator,
- })
- }
-
- /// Set the next [`AccessMethod`] to be returned from this iterator.
- pub fn set_access_method(&mut self, next: AccessMethodSetting) {
- self.next = Some(next);
- }
-
- /// Update the collection of [`AccessMethod`] which this iterator will
- /// return.
- pub fn update_access_methods(
- &mut self,
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<(), Error> {
- self.available_modes = Self::new_iterator(access_methods)?;
- Ok(())
- }
-
- /// Create a cyclic iterator of [`AccessMethodSetting`]s.
- ///
- /// If the `access_methods` argument is an empty vector, an [`Error`] is
- /// returned.
- fn new_iterator(
- access_methods: Vec<AccessMethodSetting>,
- ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> {
- if access_methods.is_empty() {
- Err(Error::NoAccessMethods)
- } else {
- Ok(Box::new(access_methods.into_iter().cycle()))
- }
- }
-}
-
-impl Iterator for ConnectionModesIterator {
- type Item = AccessMethodSetting;
-
- fn next(&mut self) -> Option<Self::Item> {
- let next = self
- .next
- .take()
- .or_else(|| self.available_modes.next())
- .unwrap();
- self.current = next.clone();
- Some(next)
- }
-}
-
pub fn resolve_allowed_endpoint(
connection_mode: &ApiConnectionMode,
fallback: SocketAddr,
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index d54598b284..3c852a1e8d 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -709,15 +709,12 @@ where
.set_config(new_selector_config(settings));
});
- let connection_modes = settings.api_access_methods.collect_enabled();
- let connection_modes_address_cache = api_runtime.address_cache.clone();
-
let connection_modes_handler = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
- connection_modes,
+ settings.api_access_methods.clone(),
internal_event_tx.to_specialized_sender(),
- connection_modes_address_cache.clone(),
+ api_runtime.address_cache.clone().clone(),
)
.await
.map_err(Error::ApiConnectionModeError)?;