diff options
| -rw-r--r-- | mullvad-daemon/src/access_method.rs | 76 | ||||
| -rw-r--r-- | mullvad-daemon/src/api.rs | 166 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 7 |
3 files changed, 98 insertions, 151 deletions
diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index e28ba21793..51bf6c1ea5 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -114,7 +114,8 @@ where // Toggle the enabled status if needed if !access_method.enabled() { access_method.enable(); - self.update_access_method_inner(&access_method).await? + self.update_access_method_inner(access_method.clone()) + .await? } // Set `access_method` as the next access method to use self.connection_modes_handler @@ -130,7 +131,8 @@ where ) -> Result<AccessMethodSetting, Error> { self.settings .api_access_methods - .find_by_id(&access_method) + .iter() + .find(|setting| setting.get_id() == access_method) .ok_or(Error::NoSuchMethod(access_method)) .cloned() } @@ -146,14 +148,20 @@ where &mut self, access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - self.update_access_method_inner(&access_method_update) + self.update_access_method_inner(access_method_update.clone()) .await?; - // If the currently active access method is updated, we need to re-set - // it after updating the settings. - if access_method_update.get_id() == self.get_current_access_method().await?.get_id() { - self.use_api_access_method(access_method_update.get_id()) - .await?; + if self.is_in_use(access_method_update.get_id()).await? { + if access_method_update.disabled() { + // If the currently active access method is updated & disabled + // we should select the next access method + self.force_api_endpoint_rotation().await?; + } else { + // If the currently active access method is just updated, we + // need to re-set it after updating the settings + self.use_api_access_method(access_method_update.get_id()) + .await?; + } } Ok(()) @@ -167,33 +175,14 @@ where /// existing, in-use setting needs to be re-set. async fn update_access_method_inner( &mut self, - access_method_update: &AccessMethodSetting, + access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - let access_method_update_moved = access_method_update.clone(); let settings_update = |settings: &mut Settings| { - if let Some(access_method) = settings - .api_access_methods - .find_by_id_mut(&access_method_update_moved.get_id()) - { - *access_method = access_method_update_moved; - // We have to be a bit careful. If the update is about to - // disable the last remaining enabled access method, we would - // cause an inconsistent state in the daemon's settings. - // Therefore, we have to explicitly safeguard against this by. - // In that case, we should re-enable the `Direct` access method. - if settings.api_access_methods.collect_enabled().is_empty() { - if let Some(direct) = settings.api_access_methods.get_direct() { - direct.enabled = true; - } else { - // If the `Direct` access method does not exist within the - // settings for some reason, the settings are in an - // inconsistent state. We don't have much choice but to - // reset these settings to their default value. - log::warn!("The built-in access methods can not be found. This might be due to a corrupt settings file"); - settings.api_access_methods = access_method::Settings::default(); - } - } - } + let target = access_method_update.get_id(); + settings.api_access_methods.update( + |access_method| access_method.get_id() == target, + |_| access_method_update, + ); }; self.settings @@ -205,6 +194,13 @@ where Ok(()) } + /// Check if some access method is the same as the currently active one. + /// + /// This can be useful for invalidating stale states. + async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> { + Ok(access_method == self.get_current_access_method().await?.get_id()) + } + /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> { @@ -291,19 +287,9 @@ where .notify_settings(self.settings.to_settings()); let handle = self.connection_modes_handler.clone(); - let new_access_methods = self.settings.api_access_methods.collect_enabled(); + let new_access_methods = self.settings.api_access_methods.clone(); tokio::spawn(async move { - match handle.update_access_methods(new_access_methods).await { - Ok(_) => (), - Err(api::Error::NoAccessMethods) | Err(_) => { - // `access_methods` was empty! This implies that the user - // disabled all access methods. If we ever get into this - // state, we should default to using the direct access - // method. - let default = access_method::Settings::direct(); - handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods"); - } - } + let _ = handle.update_access_methods(new_access_methods).await; }); }; self diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index 9a4614dca5..588068810e 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -16,7 +16,9 @@ use mullvad_api::{ AddressCache, }; use mullvad_relay_selector::RelaySelector; -use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod}; +use mullvad_types::access_method::{ + AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings, +}; use std::{net::SocketAddr, path::PathBuf}; use talpid_core::mpsc::Sender; use talpid_types::net::{AllowedClients, AllowedEndpoint, Endpoint, TransportProtocol}; @@ -25,7 +27,7 @@ pub enum Message { Get(ResponseTx<ResolvedConnectionMode>), Set(ResponseTx<()>, AccessMethodSetting), Next(ResponseTx<ApiConnectionMode>), - Update(ResponseTx<()>, Vec<AccessMethodSetting>), + Update(ResponseTx<()>, Settings), Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting), } @@ -164,8 +166,8 @@ impl AccessModeSelectorHandle { }) } - pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> { - self.send_command(|tx| Message::Update(tx, values)) + pub async fn update_access_methods(&self, access_methods: Settings) -> Result<()> { + self.send_command(|tx| Message::Update(tx, access_methods)) .await .map_err(|err| { log::debug!("Failed to switch to a new set of access methods"); @@ -213,54 +215,46 @@ impl AccessModeSelectorHandle { /// connection mode. The API can be connected to either directly (i.e., /// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`]) /// or via any supported custom proxy protocol -/// ([`api_access_methods::ObfuscationProtocol`]). -/// -/// The strategy for determining the next [`ApiConnectionMode`] is handled by -/// [`ConnectionModesIterator`]. +/// ([`talpid_types::net::proxy::CustomProxy`]). pub struct AccessModeSelector { cmd_rx: mpsc::UnboundedReceiver<Message>, cache_dir: PathBuf, /// Used for selecting a Bridge when the `Mullvad Bridges` access method is used. relay_selector: RelaySelector, - connection_modes: ConnectionModesIterator, + access_method_settings: Settings, address_cache: AddressCache, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, current: ResolvedConnectionMode, + /// `index` is used to keep track of the [`AccessMethodSetting`] to use. + index: usize, + /// `set` is used to set the next [`AccessMethodSetting`] to use. + set: Option<AccessMethodSetting>, } impl AccessModeSelector { pub(crate) async fn spawn( cache_dir: PathBuf, relay_selector: RelaySelector, - connection_modes: Vec<AccessMethodSetting>, + access_method_settings: Settings, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, address_cache: AddressCache, ) -> Result<AccessModeSelectorHandle> { let (cmd_tx, cmd_rx) = mpsc::unbounded(); - let mut connection_modes = - ConnectionModesIterator::new(connection_modes).unwrap_or_else(|_| { - // No settings seem to have been found. Default to using the the - // direct access method. - let default = mullvad_types::access_method::Settings::direct(); - ConnectionModesIterator::new(vec![default]).expect( - "Failed to create the data structure responsible for managing access methods", - ) - }); - - let initial_connection_mode = { - let next = connection_modes.next().ok_or(Error::NoAccessMethods)?; - Self::resolve_inner(next, &relay_selector, &address_cache).await - }; + let (index, next) = Self::get_next_inner(0, &access_method_settings); + let initial_connection_mode = + Self::resolve_inner(next, &relay_selector, &address_cache).await; let selector = AccessModeSelector { cmd_rx, cache_dir, relay_selector, - connection_modes, + access_method_settings, address_cache, access_method_event_sender, current: initial_connection_mode, + index, + set: None, }; tokio::spawn(selector.into_future()); @@ -312,7 +306,14 @@ impl AccessModeSelector { /// Set the next access method to be returned by the [`Stream`] produced by /// calling `into_stream`. fn set_access_method(&mut self, value: AccessMethodSetting) { - self.connection_modes.set_access_method(value); + if let Some(index) = self + .access_method_settings + .iter() + .position(|access_method| access_method.get_id() == value.get_id()) + { + self.index = index; + self.set = Some(value); + } } async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> { @@ -351,7 +352,7 @@ impl AccessModeSelector { ); } - let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?; + let access_method = self.get_next(); log::info!( "A new API access method has been selected: {name}", name = access_method.name @@ -388,17 +389,50 @@ impl AccessModeSelector { self.current = resolved; Ok(self.current.connection_mode.clone()) } + + fn get_next(&mut self) -> AccessMethodSetting { + if let Some(access_method) = self.set.take() { + access_method + } else { + let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings); + self.index = index; + next + } + } + + fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { + let xs: Vec<_> = access_methods.iter().collect(); + for offset in 1..=access_methods.cardinality() { + let index = (start + offset) % access_methods.cardinality(); + if let Some(&candidate) = xs.get(index) { + if candidate.enabled { + return (index, candidate.clone()); + } + } + } + (0, access_methods.direct().clone()) + } + fn on_update_access_methods( &mut self, tx: ResponseTx<()>, - values: Vec<AccessMethodSetting>, + access_methods: Settings, ) -> Result<()> { - self.update_access_methods(values)?; + self.update_access_methods(access_methods); self.reply(tx, ()) } - fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> { - self.connection_modes.update_access_methods(values) + fn update_access_methods(&mut self, access_methods: Settings) { + let removed_active = !access_methods + .iter() + .any(|access_method| access_method.get_id() == self.current.setting.get_id()); + if removed_active { + // A new access mehtod will suddenly have the same index as the one + // we are removing, but we want it to still be a candidate. A minor + // hack to achieve this is to simply decrement the current index. + self.index = self.index.saturating_sub(1); + } + self.access_method_settings = access_methods; } pub async fn on_resolve_access_method( @@ -455,76 +489,6 @@ fn resolve_connection_mode( } } -/// An iterator which will always produce an [`AccessMethod`]. -/// -/// Safety: It is always safe to [`unwrap`] after calling [`next`] on a -/// [`std::iter::Cycle`], so thereby it is safe to always call [`unwrap`] on a -/// [`ConnectionModesIterator`]. -/// -/// [`unwrap`]: Option::unwrap -/// [`next`]: std::iter::Iterator::next -pub struct ConnectionModesIterator { - available_modes: Box<dyn Iterator<Item = AccessMethodSetting> + Send>, - next: Option<AccessMethodSetting>, - current: AccessMethodSetting, -} - -impl ConnectionModesIterator { - pub fn new( - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<ConnectionModesIterator, Error> { - let mut iterator = Self::new_iterator(access_methods)?; - Ok(Self { - next: None, - current: iterator.next().ok_or(Error::NoAccessMethods)?, - available_modes: iterator, - }) - } - - /// Set the next [`AccessMethod`] to be returned from this iterator. - pub fn set_access_method(&mut self, next: AccessMethodSetting) { - self.next = Some(next); - } - - /// Update the collection of [`AccessMethod`] which this iterator will - /// return. - pub fn update_access_methods( - &mut self, - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<(), Error> { - self.available_modes = Self::new_iterator(access_methods)?; - Ok(()) - } - - /// Create a cyclic iterator of [`AccessMethodSetting`]s. - /// - /// If the `access_methods` argument is an empty vector, an [`Error`] is - /// returned. - fn new_iterator( - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> { - if access_methods.is_empty() { - Err(Error::NoAccessMethods) - } else { - Ok(Box::new(access_methods.into_iter().cycle())) - } - } -} - -impl Iterator for ConnectionModesIterator { - type Item = AccessMethodSetting; - - fn next(&mut self) -> Option<Self::Item> { - let next = self - .next - .take() - .or_else(|| self.available_modes.next()) - .unwrap(); - self.current = next.clone(); - Some(next) - } -} - pub fn resolve_allowed_endpoint( connection_mode: &ApiConnectionMode, fallback: SocketAddr, diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d54598b284..3c852a1e8d 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -709,15 +709,12 @@ where .set_config(new_selector_config(settings)); }); - let connection_modes = settings.api_access_methods.collect_enabled(); - let connection_modes_address_cache = api_runtime.address_cache.clone(); - let connection_modes_handler = api::AccessModeSelector::spawn( cache_dir.clone(), relay_selector.clone(), - connection_modes, + settings.api_access_methods.clone(), internal_event_tx.to_specialized_sender(), - connection_modes_address_cache.clone(), + api_runtime.address_cache.clone().clone(), ) .await .map_err(Error::ApiConnectionModeError)?; |
