diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-01-31 13:11:59 +0100 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-01-31 13:11:59 +0100 |
| commit | 0d9abfd26b0d0e2151eaebaae1bd01f2f439466b (patch) | |
| tree | 6f7c6c6fd29bcf69cd5906a56a14c3cdeeab3d45 | |
| parent | 9b9da3ebbd90d52cec84c17edb4d99a472fd9a61 (diff) | |
| parent | 87cb3b634dd7b9a5e173637dff956f9544aaccd3 (diff) | |
| download | mullvadvpn-0d9abfd26b0d0e2151eaebaae1bd01f2f439466b.tar.xz mullvadvpn-0d9abfd26b0d0e2151eaebaae1bd01f2f439466b.zip | |
Merge branch 'correctly-handle-removal-disabling-of-access-method-des-561'
20 files changed, 755 insertions, 337 deletions
diff --git a/gui/.eslintrc.js b/gui/.eslintrc.js index 05bdf4cd45..4cf1921910 100644 --- a/gui/.eslintrc.js +++ b/gui/.eslintrc.js @@ -28,6 +28,10 @@ const namingConvention = [ format: ['camelCase'], }, { + selector: 'typeProperty', + format: ['camelCase'], + }, + { selector: 'typeLike', format: ['PascalCase'], }, @@ -103,6 +107,6 @@ module.exports = { '@typescript-eslint/no-use-before-define': 'off', '@typescript-eslint/explicit-module-boundary-types': 'off', '@typescript-eslint/no-non-null-assertion': 'off', - 'react/prop-types': 'off' + 'react/prop-types': 'off', }, }; diff --git a/gui/src/main/daemon-rpc.ts b/gui/src/main/daemon-rpc.ts index 681eb0a454..6290aa7a92 100644 --- a/gui/src/main/daemon-rpc.ts +++ b/gui/src/main/daemon-rpc.ts @@ -1163,7 +1163,7 @@ function convertFromSettings(settings: grpcTypes.Settings): ISettings | undefine const splitTunnel = settingsObject.splitTunnel ?? { enableExclusions: false, appsList: [] }; const obfuscationSettings = convertFromObfuscationSettings(settingsObject.obfuscationSettings); const customLists = convertFromCustomListSettings(settings.getCustomLists()); - const apiAccessMethods = convertFromApiAccessMethodSettings(settings.getApiAccessMethods()); + const apiAccessMethods = convertFromApiAccessMethodSettings(settings.getApiAccessMethods()!); return { ...settings.toObject(), bridgeState, @@ -1893,14 +1893,25 @@ function convertToSocksAuth(authentication: SocksAuth): grpcTypes.SocksAuth { } function convertFromApiAccessMethodSettings( - accessMethods?: grpcTypes.ApiAccessMethodSettings, + accessMethods: grpcTypes.ApiAccessMethodSettings, ): ApiAccessMethodSettings { - return ( + const direct = convertFromApiAccessMethodSetting( + ensureExists(accessMethods.getDirect(), "no 'Direct' access method was found"), + ); + const bridges = convertFromApiAccessMethodSetting( + ensureExists(accessMethods.getMullvadBridges(), "no 'Mullvad Bridges' access method was found"), + ); + const custom = accessMethods - ?.getAccessMethodSettingsList() + .getCustomList() .filter((setting) => setting.hasId() && setting.hasAccessMethod()) - .map(convertFromApiAccessMethodSetting) ?? [] - ); + .map(convertFromApiAccessMethodSetting) ?? []; + + return { + direct, + mullvadBridges: bridges, + custom, + }; } function convertFromApiAccessMethodSetting( diff --git a/gui/src/main/default-settings.ts b/gui/src/main/default-settings.ts index 4510a71896..7f7656a56a 100644 --- a/gui/src/main/default-settings.ts +++ b/gui/src/main/default-settings.ts @@ -1,4 +1,9 @@ -import { ISettings, ObfuscationType, Ownership } from '../shared/daemon-rpc-types'; +import { + ApiAccessMethodSettings, + ISettings, + ObfuscationType, + Ownership, +} from '../shared/daemon-rpc-types'; export function getDefaultSettings(): ISettings { return { @@ -71,6 +76,26 @@ export function getDefaultSettings(): ISettings { }, }, customLists: [], - apiAccessMethods: [], + apiAccessMethods: getDefaultApiAccessMethods(), + }; +} + +export function getDefaultApiAccessMethods(): ApiAccessMethodSettings { + // 'id's are UUIDs generated by the daemon when an access method is created, + // and as such we can't provide a good default value for them. + return { + direct: { + id: '', + name: 'Direct', + enabled: true, + type: 'direct', + }, + mullvadBridges: { + id: '', + name: 'Mullvad Bridges', + enabled: false, + type: 'bridges', + }, + custom: [], }; } diff --git a/gui/src/renderer/components/ApiAccessMethods.tsx b/gui/src/renderer/components/ApiAccessMethods.tsx index c8a6f98b19..0774fcc320 100644 --- a/gui/src/renderer/components/ApiAccessMethods.tsx +++ b/gui/src/renderer/components/ApiAccessMethods.tsx @@ -125,11 +125,20 @@ export default function ApiAccessMethods() { <StyledSettingsContent> <Cell.Group> - {methods.map((method) => ( + <ApiAccessMethod + method={methods.direct} + inUse={methods.direct.id === currentMethod?.id} + /> + <ApiAccessMethod + method={methods.mullvadBridges} + inUse={methods.mullvadBridges.id === currentMethod?.id} + /> + {methods.custom.map((method) => ( <ApiAccessMethod key={method.id} method={method} inUse={method.id === currentMethod?.id} + custom /> ))} </Cell.Group> @@ -150,6 +159,7 @@ export default function ApiAccessMethods() { interface ApiAccessMethodProps { method: AccessMethodSetting; inUse: boolean; + custom?: boolean; } function ApiAccessMethod(props: ApiAccessMethodProps) { @@ -186,8 +196,8 @@ function ApiAccessMethod(props: ApiAccessMethodProps) { } }, [testApiAccessMethod, props.method.id]); - const menuItems = useMemo<Array<ContextMenuItem>>( - () => [ + const menuItems = useMemo<Array<ContextMenuItem>>(() => { + const items: Array<ContextMenuItem> = [ { type: 'item' as const, label: 'Use', @@ -195,28 +205,30 @@ function ApiAccessMethod(props: ApiAccessMethodProps) { onClick: setApiAccessMethod, }, { type: 'item' as const, label: 'Test', onClick: () => testApiAccessMethod(props.method.id) }, - // Edit and Delete shouldn't be available for direct and bridges. - ...(props.method.type === 'direct' || props.method.type === 'bridges' - ? [] - : [ - { type: 'separator' as const }, - { - type: 'item' as const, - label: 'Edit', - onClick: () => - history.push( - generateRoutePath(RoutePath.editApiAccessMethods, { id: props.method.id }), - ), - }, - { - type: 'item' as const, - label: 'Delete', - onClick: showRemoveConfirmation, - }, - ]), - ], - [props.method.id, props.inUse, setApiAccessMethod, testApiAccessMethod, history.push], - ); + ]; + + // Edit and Delete shouldn't be available for direct and bridges. + if (props.custom) { + items.push( + { type: 'separator' as const }, + { + type: 'item' as const, + label: 'Edit', + onClick: () => + history.push( + generateRoutePath(RoutePath.editApiAccessMethods, { id: props.method.id }), + ), + }, + { + type: 'item' as const, + label: 'Delete', + onClick: showRemoveConfirmation, + }, + ); + } + + return items; + }, [props.method.id, props.inUse, setApiAccessMethod, testApiAccessMethod, history.push]); return ( <Cell.Row> diff --git a/gui/src/renderer/components/EditApiAccessMethod.tsx b/gui/src/renderer/components/EditApiAccessMethod.tsx index c56329214b..ecb633ae98 100644 --- a/gui/src/renderer/components/EditApiAccessMethod.tsx +++ b/gui/src/renderer/components/EditApiAccessMethod.tsx @@ -55,7 +55,10 @@ function AccessMethodForm() { // Use id in url to figure out which method is to be edited. undefined means this is a new method. const { id } = useParams<{ id: string | undefined }>(); - const method = methods.find((method) => method.id === id); + // Ugly way of iterating over all access methods, but it works. + const method = [methods.direct, methods.mullvadBridges, ...methods.custom].find( + (method) => method.id === id, + ); const updatedMethod = useRef<NewAccessMethodSetting | undefined>(method); const updateMethod = useCallback( diff --git a/gui/src/renderer/components/Switch.tsx b/gui/src/renderer/components/Switch.tsx index 595bc422a2..f87bf2a52b 100644 --- a/gui/src/renderer/components/Switch.tsx +++ b/gui/src/renderer/components/Switch.tsx @@ -5,7 +5,9 @@ import { colors } from '../../config.json'; interface IProps { id?: string; + // eslint-disable-next-line @typescript-eslint/naming-convention 'aria-labelledby'?: string; + // eslint-disable-next-line @typescript-eslint/naming-convention 'aria-describedby'?: string; isOn: boolean; onChange?: (isOn: boolean) => void; diff --git a/gui/src/renderer/components/TransitionContainer.tsx b/gui/src/renderer/components/TransitionContainer.tsx index 3e7007e3c8..9773c64512 100644 --- a/gui/src/renderer/components/TransitionContainer.tsx +++ b/gui/src/renderer/components/TransitionContainer.tsx @@ -48,6 +48,7 @@ interface StyledTransitionContentProps { export const StyledTransitionContent = styled.div.attrs< StyledTransitionContentProps, + // eslint-disable-next-line @typescript-eslint/naming-convention { 'data-testid': string } >({ 'data-testid': 'transition-content', diff --git a/gui/src/renderer/redux/settings/reducers.ts b/gui/src/renderer/redux/settings/reducers.ts index cdff32e22c..bb971f896a 100644 --- a/gui/src/renderer/redux/settings/reducers.ts +++ b/gui/src/renderer/redux/settings/reducers.ts @@ -1,3 +1,4 @@ +import { getDefaultApiAccessMethods } from '../../../main/default-settings'; import { IWindowsApplication } from '../../../shared/application-types'; import { AccessMethodSetting, @@ -178,7 +179,7 @@ const initialState: ISettingsReduxState = { }, }, customLists: [], - apiAccessMethods: [], + apiAccessMethods: getDefaultApiAccessMethods(), currentApiAccessMethod: undefined, }; diff --git a/gui/src/shared/daemon-rpc-types.ts b/gui/src/shared/daemon-rpc-types.ts index 71f6936804..f048549b7a 100644 --- a/gui/src/shared/daemon-rpc-types.ts +++ b/gui/src/shared/daemon-rpc-types.ts @@ -533,7 +533,11 @@ export type AccessMethodSetting = NewAccessMethodSetting & { id: string; }; -export type ApiAccessMethodSettings = Array<AccessMethodSetting>; +export type ApiAccessMethodSettings = { + direct: AccessMethodSetting; + mullvadBridges: AccessMethodSetting; + custom: Array<AccessMethodSetting>; +}; export function parseSocketAddress(socketAddrStr: string): ISocketAddress { const re = new RegExp(/(.+):(\d+)$/); diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 835061c77b..51bf6c1ea5 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -14,9 +14,6 @@ pub enum Error { /// Can not add access method #[error(display = "Cannot add custom access method")] Add, - /// Can not remove built-in access method - #[error(display = "Cannot remove built-in access method")] - RemoveBuiltIn, /// Can not find access method #[error(display = "Cannot find custom access method {}", _0)] NoSuchMethod(access_method::Id), @@ -30,6 +27,9 @@ pub enum Error { /// A REST request failed #[error(display = "Reset request failed")] Rest(#[error(source)] rest::Error), + /// Something went wrong in the [`access_method`](mod@access_method) module. + #[error(display = "Access method error")] + AccessMethod(#[error(source)] access_method::Error), /// Access methods settings error #[error(display = "Settings error")] Settings(#[error(source)] settings::Error), @@ -68,29 +68,29 @@ where &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - match self.settings.api_access_methods.find_by_id(&access_method) { - // Make sure that we are not trying to remove a built-in API access - // method - Some(api_access_method) if api_access_method.is_builtin() => { - return Err(Error::RemoveBuiltIn) - } - // If the currently active access method is removed, a new access - // method should trigger - Some(api_access_method) - if api_access_method.get_id() - == self.get_current_access_method().await?.get_id() => - { - self.force_api_endpoint_rotation().await?; - } - _ => (), + let did_change = self + .settings + .try_update(|settings| -> Result<(), Error> { + settings.api_access_methods.remove(&access_method)?; + Ok(()) + }) + .await + .map_err(Error::Settings)?; + + self.notify_on_change(did_change); + // If the currently active access method is removed, a new access + // method should be selected. + // + // Notice the ordering here: It is important that the current method is + // removed before we pick a new access method. The `remove` function + // will ensure that atleast one access method is enabled after the + // removal. If the currently active access method is removed, some other + // method is enabled before we pick the next access method to use. + if self.is_in_use(access_method.clone()).await? { + self.force_api_endpoint_rotation().await?; } - self.settings - .update(|settings| settings.api_access_methods.remove(&access_method)) - .await - .map(|did_change| self.notify_on_change(did_change)) - .map(|_| ()) - .map_err(Error::Settings) + Ok(()) } /// Select an [`AccessMethodSetting`] as the current API access method. @@ -114,7 +114,8 @@ where // Toggle the enabled status if needed if !access_method.enabled() { access_method.enable(); - self.update_access_method_inner(&access_method).await? + self.update_access_method_inner(access_method.clone()) + .await? } // Set `access_method` as the next access method to use self.connection_modes_handler @@ -130,7 +131,8 @@ where ) -> Result<AccessMethodSetting, Error> { self.settings .api_access_methods - .find_by_id(&access_method) + .iter() + .find(|setting| setting.get_id() == access_method) .ok_or(Error::NoSuchMethod(access_method)) .cloned() } @@ -146,14 +148,20 @@ where &mut self, access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - self.update_access_method_inner(&access_method_update) + self.update_access_method_inner(access_method_update.clone()) .await?; - // If the currently active access method is updated, we need to re-set - // it after updating the settings. - if access_method_update.get_id() == self.get_current_access_method().await?.get_id() { - self.use_api_access_method(access_method_update.get_id()) - .await?; + if self.is_in_use(access_method_update.get_id()).await? { + if access_method_update.disabled() { + // If the currently active access method is updated & disabled + // we should select the next access method + self.force_api_endpoint_rotation().await?; + } else { + // If the currently active access method is just updated, we + // need to re-set it after updating the settings + self.use_api_access_method(access_method_update.get_id()) + .await?; + } } Ok(()) @@ -167,33 +175,14 @@ where /// existing, in-use setting needs to be re-set. async fn update_access_method_inner( &mut self, - access_method_update: &AccessMethodSetting, + access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - let access_method_update_moved = access_method_update.clone(); let settings_update = |settings: &mut Settings| { - if let Some(access_method) = settings - .api_access_methods - .find_by_id_mut(&access_method_update_moved.get_id()) - { - *access_method = access_method_update_moved; - // We have to be a bit careful. If the update is about to - // disable the last remaining enabled access method, we would - // cause an inconsistent state in the daemon's settings. - // Therefore, we have to explicitly safeguard against this by. - // In that case, we should re-enable the `Direct` access method. - if settings.api_access_methods.collect_enabled().is_empty() { - if let Some(direct) = settings.api_access_methods.get_direct() { - direct.enabled = true; - } else { - // If the `Direct` access method does not exist within the - // settings for some reason, the settings are in an - // inconsistent state. We don't have much choice but to - // reset these settings to their default value. - log::warn!("The built-in access methods can not be found. This might be due to a corrupt settings file"); - settings.api_access_methods = access_method::Settings::default(); - } - } - } + let target = access_method_update.get_id(); + settings.api_access_methods.update( + |access_method| access_method.get_id() == target, + |_| access_method_update, + ); }; self.settings @@ -205,6 +194,13 @@ where Ok(()) } + /// Check if some access method is the same as the currently active one. + /// + /// This can be useful for invalidating stale states. + async fn is_in_use(&self, access_method: access_method::Id) -> Result<bool, Error> { + Ok(access_method == self.get_current_access_method().await?.get_id()) + } + /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> { @@ -291,19 +287,9 @@ where .notify_settings(self.settings.to_settings()); let handle = self.connection_modes_handler.clone(); - let new_access_methods = self.settings.api_access_methods.collect_enabled(); + let new_access_methods = self.settings.api_access_methods.clone(); tokio::spawn(async move { - match handle.update_access_methods(new_access_methods).await { - Ok(_) => (), - Err(api::Error::NoAccessMethods) | Err(_) => { - // `access_methods` was empty! This implies that the user - // disabled all access methods. If we ever get into this - // state, we should default to using the direct access - // method. - let default = access_method::Settings::direct(); - handle.update_access_methods(vec![default]).await.expect("Failed to create the data structure responsible for managing access methods"); - } - } + let _ = handle.update_access_methods(new_access_methods).await; }); }; self diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index 9a4614dca5..588068810e 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -16,7 +16,9 @@ use mullvad_api::{ AddressCache, }; use mullvad_relay_selector::RelaySelector; -use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod}; +use mullvad_types::access_method::{ + AccessMethod, AccessMethodSetting, BuiltInAccessMethod, Settings, +}; use std::{net::SocketAddr, path::PathBuf}; use talpid_core::mpsc::Sender; use talpid_types::net::{AllowedClients, AllowedEndpoint, Endpoint, TransportProtocol}; @@ -25,7 +27,7 @@ pub enum Message { Get(ResponseTx<ResolvedConnectionMode>), Set(ResponseTx<()>, AccessMethodSetting), Next(ResponseTx<ApiConnectionMode>), - Update(ResponseTx<()>, Vec<AccessMethodSetting>), + Update(ResponseTx<()>, Settings), Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting), } @@ -164,8 +166,8 @@ impl AccessModeSelectorHandle { }) } - pub async fn update_access_methods(&self, values: Vec<AccessMethodSetting>) -> Result<()> { - self.send_command(|tx| Message::Update(tx, values)) + pub async fn update_access_methods(&self, access_methods: Settings) -> Result<()> { + self.send_command(|tx| Message::Update(tx, access_methods)) .await .map_err(|err| { log::debug!("Failed to switch to a new set of access methods"); @@ -213,54 +215,46 @@ impl AccessModeSelectorHandle { /// connection mode. The API can be connected to either directly (i.e., /// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`]) /// or via any supported custom proxy protocol -/// ([`api_access_methods::ObfuscationProtocol`]). -/// -/// The strategy for determining the next [`ApiConnectionMode`] is handled by -/// [`ConnectionModesIterator`]. +/// ([`talpid_types::net::proxy::CustomProxy`]). pub struct AccessModeSelector { cmd_rx: mpsc::UnboundedReceiver<Message>, cache_dir: PathBuf, /// Used for selecting a Bridge when the `Mullvad Bridges` access method is used. relay_selector: RelaySelector, - connection_modes: ConnectionModesIterator, + access_method_settings: Settings, address_cache: AddressCache, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, current: ResolvedConnectionMode, + /// `index` is used to keep track of the [`AccessMethodSetting`] to use. + index: usize, + /// `set` is used to set the next [`AccessMethodSetting`] to use. + set: Option<AccessMethodSetting>, } impl AccessModeSelector { pub(crate) async fn spawn( cache_dir: PathBuf, relay_selector: RelaySelector, - connection_modes: Vec<AccessMethodSetting>, + access_method_settings: Settings, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, address_cache: AddressCache, ) -> Result<AccessModeSelectorHandle> { let (cmd_tx, cmd_rx) = mpsc::unbounded(); - let mut connection_modes = - ConnectionModesIterator::new(connection_modes).unwrap_or_else(|_| { - // No settings seem to have been found. Default to using the the - // direct access method. - let default = mullvad_types::access_method::Settings::direct(); - ConnectionModesIterator::new(vec![default]).expect( - "Failed to create the data structure responsible for managing access methods", - ) - }); - - let initial_connection_mode = { - let next = connection_modes.next().ok_or(Error::NoAccessMethods)?; - Self::resolve_inner(next, &relay_selector, &address_cache).await - }; + let (index, next) = Self::get_next_inner(0, &access_method_settings); + let initial_connection_mode = + Self::resolve_inner(next, &relay_selector, &address_cache).await; let selector = AccessModeSelector { cmd_rx, cache_dir, relay_selector, - connection_modes, + access_method_settings, address_cache, access_method_event_sender, current: initial_connection_mode, + index, + set: None, }; tokio::spawn(selector.into_future()); @@ -312,7 +306,14 @@ impl AccessModeSelector { /// Set the next access method to be returned by the [`Stream`] produced by /// calling `into_stream`. fn set_access_method(&mut self, value: AccessMethodSetting) { - self.connection_modes.set_access_method(value); + if let Some(index) = self + .access_method_settings + .iter() + .position(|access_method| access_method.get_id() == value.get_id()) + { + self.index = index; + self.set = Some(value); + } } async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> { @@ -351,7 +352,7 @@ impl AccessModeSelector { ); } - let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?; + let access_method = self.get_next(); log::info!( "A new API access method has been selected: {name}", name = access_method.name @@ -388,17 +389,50 @@ impl AccessModeSelector { self.current = resolved; Ok(self.current.connection_mode.clone()) } + + fn get_next(&mut self) -> AccessMethodSetting { + if let Some(access_method) = self.set.take() { + access_method + } else { + let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings); + self.index = index; + next + } + } + + fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { + let xs: Vec<_> = access_methods.iter().collect(); + for offset in 1..=access_methods.cardinality() { + let index = (start + offset) % access_methods.cardinality(); + if let Some(&candidate) = xs.get(index) { + if candidate.enabled { + return (index, candidate.clone()); + } + } + } + (0, access_methods.direct().clone()) + } + fn on_update_access_methods( &mut self, tx: ResponseTx<()>, - values: Vec<AccessMethodSetting>, + access_methods: Settings, ) -> Result<()> { - self.update_access_methods(values)?; + self.update_access_methods(access_methods); self.reply(tx, ()) } - fn update_access_methods(&mut self, values: Vec<AccessMethodSetting>) -> Result<()> { - self.connection_modes.update_access_methods(values) + fn update_access_methods(&mut self, access_methods: Settings) { + let removed_active = !access_methods + .iter() + .any(|access_method| access_method.get_id() == self.current.setting.get_id()); + if removed_active { + // A new access mehtod will suddenly have the same index as the one + // we are removing, but we want it to still be a candidate. A minor + // hack to achieve this is to simply decrement the current index. + self.index = self.index.saturating_sub(1); + } + self.access_method_settings = access_methods; } pub async fn on_resolve_access_method( @@ -455,76 +489,6 @@ fn resolve_connection_mode( } } -/// An iterator which will always produce an [`AccessMethod`]. -/// -/// Safety: It is always safe to [`unwrap`] after calling [`next`] on a -/// [`std::iter::Cycle`], so thereby it is safe to always call [`unwrap`] on a -/// [`ConnectionModesIterator`]. -/// -/// [`unwrap`]: Option::unwrap -/// [`next`]: std::iter::Iterator::next -pub struct ConnectionModesIterator { - available_modes: Box<dyn Iterator<Item = AccessMethodSetting> + Send>, - next: Option<AccessMethodSetting>, - current: AccessMethodSetting, -} - -impl ConnectionModesIterator { - pub fn new( - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<ConnectionModesIterator, Error> { - let mut iterator = Self::new_iterator(access_methods)?; - Ok(Self { - next: None, - current: iterator.next().ok_or(Error::NoAccessMethods)?, - available_modes: iterator, - }) - } - - /// Set the next [`AccessMethod`] to be returned from this iterator. - pub fn set_access_method(&mut self, next: AccessMethodSetting) { - self.next = Some(next); - } - - /// Update the collection of [`AccessMethod`] which this iterator will - /// return. - pub fn update_access_methods( - &mut self, - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<(), Error> { - self.available_modes = Self::new_iterator(access_methods)?; - Ok(()) - } - - /// Create a cyclic iterator of [`AccessMethodSetting`]s. - /// - /// If the `access_methods` argument is an empty vector, an [`Error`] is - /// returned. - fn new_iterator( - access_methods: Vec<AccessMethodSetting>, - ) -> std::result::Result<Box<dyn Iterator<Item = AccessMethodSetting> + Send>, Error> { - if access_methods.is_empty() { - Err(Error::NoAccessMethods) - } else { - Ok(Box::new(access_methods.into_iter().cycle())) - } - } -} - -impl Iterator for ConnectionModesIterator { - type Item = AccessMethodSetting; - - fn next(&mut self) -> Option<Self::Item> { - let next = self - .next - .take() - .or_else(|| self.available_modes.next()) - .unwrap(); - self.current = next.clone(); - Some(next) - } -} - pub fn resolve_allowed_endpoint( connection_mode: &ApiConnectionMode, fallback: SocketAddr, diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d54598b284..00f4613ffb 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -280,8 +280,6 @@ pub enum DaemonCommand { DeleteCustomList(ResponseTx<(), Error>, mullvad_types::custom_list::Id), /// Update a custom list with a given id UpdateCustomList(ResponseTx<(), Error>, CustomList), - /// Get API access methods - GetApiAccessMethods(ResponseTx<Vec<AccessMethodSetting>, Error>), /// Add API access methods AddApiAccessMethod( ResponseTx<mullvad_types::access_method::Id, Error>, @@ -709,15 +707,12 @@ where .set_config(new_selector_config(settings)); }); - let connection_modes = settings.api_access_methods.collect_enabled(); - let connection_modes_address_cache = api_runtime.address_cache.clone(); - let connection_modes_handler = api::AccessModeSelector::spawn( cache_dir.clone(), relay_selector.clone(), - connection_modes, + settings.api_access_methods.clone(), internal_event_tx.to_specialized_sender(), - connection_modes_address_cache.clone(), + api_runtime.address_cache.clone().clone(), ) .await .map_err(Error::ApiConnectionModeError)?; @@ -1244,7 +1239,6 @@ where DeleteCustomList(tx, id) => self.on_delete_custom_list(tx, id).await, UpdateCustomList(tx, update) => self.on_update_custom_list(tx, update).await, GetVersionInfo(tx) => self.on_get_version_info(tx), - GetApiAccessMethods(tx) => self.on_get_api_access_methods(tx), AddApiAccessMethod(tx, name, enabled, access_method) => { self.on_add_access_method(tx, name, enabled, access_method) .await @@ -2422,11 +2416,6 @@ where Self::oneshot_send(tx, result, "update_custom_list response"); } - fn on_get_api_access_methods(&mut self, tx: ResponseTx<Vec<AccessMethodSetting>, Error>) { - let result = Ok(self.settings.api_access_methods.cloned()); - Self::oneshot_send(tx, result, "get_api_access_methods response"); - } - async fn on_add_access_method( &mut self, tx: ResponseTx<mullvad_types::access_method::Id, Error>, diff --git a/mullvad-daemon/src/migrations/v7.rs b/mullvad-daemon/src/migrations/v7.rs index efd3193624..4614cedce9 100644 --- a/mullvad-daemon/src/migrations/v7.rs +++ b/mullvad-daemon/src/migrations/v7.rs @@ -75,6 +75,10 @@ pub struct ShadowsocksProxySettings { /// that instead of having a Socks5 and Shadowsocks variant instead has a Socks5Local, Socks5Remote /// and Shadowsocks variant. /// +/// The predefined access methods "Direct" and "Mullvad Bridges" are now stored as distinct keys in +/// the api_access_methods settings, separating them from user-defined access methods in the settings +/// datastructure. +/// /// We also take the oppertunity to rename a couple of fields that relate to proxy types. /// We rename /// - shadowsocks.peer to shadowsocks.endpoint @@ -141,6 +145,62 @@ fn migrate_api_access_settings(settings: &mut serde_json::Value) -> Result<()> { } } + // Step 1. Rename { "api_access_methods": { "access_method_settings": .. } } to { "api_access_methods": { "custom": .. } }. + // Step 2. Collect all of the built-in methods from { "api_access_methods": { "custom": [ .. ] } }. + // Step 3. Remove all of the built-in methods from { "api_access_methods": { "custom": [ .. ] } }. + // Step 4. Add the collected built-in methods from step 2 to { "api_access_methods": { .. } } under some appropriate key. + if let Some(access_method_settings) = settings + .get_mut("api_access_methods") + .and_then(serde_json::value::Value::as_object_mut) + { + // Step 1. + rename_map_field(access_method_settings, "access_method_settings", "custom")?; + + if let Some(access_method_settings_list) = access_method_settings + .get_mut("custom") + .and_then(serde_json::value::Value::as_array_mut) + { + // Step 2. + let built_ins: Vec<_> = access_method_settings_list + .iter() + .filter(|value| { + value + .get("access_method") + .and_then(|value| value.get("built_in")) + .is_some() + }) + .cloned() + .collect(); + + // Step 3. + for built_in in built_ins.iter() { + access_method_settings_list + .retain(|access_method| access_method.get("id") != built_in.get("id")); + } + + // Step 4. + // Note that the only supported built-in access methods at this time + // are "Direct" and "Mullvad Bridges", so we may discard anything + // else. + let built_ins: Vec<_> = built_ins + .into_iter() + .filter_map(|built_in| { + match built_in + .get("access_method") + .and_then(|value| value.get("built_in")) + .and_then(|value| value.as_str()) + { + Some("direct") => Some(("direct".to_string(), built_in)), + Some("bridge") => Some(("mullvad_bridges".to_string(), built_in)), + Some(_) | None => None, + } + }) + .collect(); + + access_method_settings.extend(built_ins); + } + } + Ok(()) } @@ -228,15 +288,24 @@ fn extract_str(opt: Option<&serde_json::Value>) -> Result<&str> { .ok_or(Error::InvalidSettingsContent) } -fn rename_field(object: &mut serde_json::Value, old_name: &str, new_name: &str) -> Result<()> { - object[new_name] = object +fn rename_field(value: &mut serde_json::Value, old_name: &str, new_name: &str) -> Result<()> { + value + .as_object_mut() + .ok_or(Error::InvalidSettingsContent) + .and_then(|object| rename_map_field(object, old_name, new_name)) +} + +fn rename_map_field( + object: &mut serde_json::Map<String, serde_json::Value>, + old_name: &str, + new_name: &str, +) -> Result<()> { + let old_value = object .get(old_name) .ok_or(Error::InvalidSettingsContent)? .clone(); - object - .as_object_mut() - .ok_or(Error::InvalidSettingsContent)? - .remove(old_name); + let _ = object.insert(new_name.to_string(), old_value); + object.remove(old_name); Ok(()) } @@ -475,23 +544,23 @@ mod test { } }, "api_access_methods": { - "access_method_settings": [ - { - "id": "8cbdcfc8-fa7b-41de-8d12-26fa37439f89", - "name": "Direct", - "enabled": true, - "access_method": { - "built_in": "direct" - } - }, - { - "id": "1d0d8891-dbb3-4439-a8f7-0e7d742ddbe4", - "name": "Mullvad Bridges", - "enabled": true, - "access_method": { - "built_in": "bridge" - } - }, + "direct": { + "id": "8cbdcfc8-fa7b-41de-8d12-26fa37439f89", + "name": "Direct", + "enabled": true, + "access_method": { + "built_in": "direct" + } + }, + "mullvad_bridges": { + "id": "1d0d8891-dbb3-4439-a8f7-0e7d742ddbe4", + "name": "Mullvad Bridges", + "enabled": true, + "access_method": { + "built_in": "bridge" + } + }, + "custom": [ { "id": "1aaff7ab-e09f-4c03-af02-765e41943a7b", "name": "localsox", @@ -851,7 +920,7 @@ mod test { r#" { "api_access_methods": { - "access_method_settings": [ + "custom": [ { "id": "5eb9b2ee-f764-47c8-8111-ee95910d0099", "name": "mysocks", @@ -880,7 +949,6 @@ mod test { #[test] fn test_api_access_methods_custom_socks5_remote() { - println!("wew"); let mut pre: serde_json::Value = serde_json::from_str( r#" { @@ -910,7 +978,7 @@ mod test { r#" { "api_access_methods": { - "access_method_settings": [ + "custom": [ { "id": "8e377232-8a53-4414-8b8f-f487227aaedb", "name": "remotesox", @@ -965,7 +1033,7 @@ mod test { r#" { "api_access_methods": { - "access_method_settings": [ + "custom": [ { "id": "74e5c659-acdd-4cad-a632-a25bf63c20e2", "name": "remotess", @@ -989,4 +1057,195 @@ mod test { migrate_api_access_settings(&mut pre).unwrap(); assert_eq!(pre, post); } + + #[test] + fn test_api_access_methods_extract_direct() { + let mut pre: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "access_method_settings": [ + { + "id": "8cbdcfc8-fa7b-41de-8d12-26fa37439f89", + "name": "Direct", + "enabled": true, + "access_method": { + "built_in": "direct" + } + } + ] + } +} +"#, + ) + .unwrap(); + + let post: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "direct": { + "id": "8cbdcfc8-fa7b-41de-8d12-26fa37439f89", + "name": "Direct", + "enabled": true, + "access_method": { + "built_in": "direct" + } + }, + "custom": [] + } +} +"#, + ) + .unwrap(); + + migrate_api_access_settings(&mut pre).unwrap(); + assert_eq!(pre, post); + } + + #[test] + fn test_api_access_methods_extract_mullvad_bridges() { + let mut pre: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "access_method_settings": [ + { + "id": "1d0d8891-dbb3-4439-a8f7-0e7d742ddbe4", + "name": "Mullvad Bridges", + "enabled": true, + "access_method": { + "built_in": "bridge" + } + } + ] + } +} +"#, + ) + .unwrap(); + + let post: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "mullvad_bridges": { + "id": "1d0d8891-dbb3-4439-a8f7-0e7d742ddbe4", + "name": "Mullvad Bridges", + "enabled": true, + "access_method": { + "built_in": "bridge" + } + }, + "custom": [] + } +} +"#, + ) + .unwrap(); + + migrate_api_access_settings(&mut pre).unwrap(); + assert_eq!(pre, post); + } + + #[test] + fn test_api_access_methods_do_not_extract_custom_methods() { + let mut pre: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "access_method_settings": [ + { + "id": "1aaff7ab-e09f-4c03-af02-765e41943a7b", + "name": "localsox", + "enabled": false, + "access_method": { + "custom": { + "socks5": { + "local": { + "remote_endpoint": { + "address": "1.3.3.7:1080", + "protocol": "tcp" + }, + "local_port": 1079 + } + } + } + } + } + ] + } +} +"#, + ) + .unwrap(); + + let post: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "custom": [ + { + "id": "1aaff7ab-e09f-4c03-af02-765e41943a7b", + "name": "localsox", + "enabled": false, + "access_method": { + "custom": { + "socks5_local": { + "remote_endpoint": { + "address": "1.3.3.7:1080", + "protocol": "tcp" + }, + "local_port": 1079 + } + } + } + } + ] + } +} +"#, + ) + .unwrap(); + + migrate_api_access_settings(&mut pre).unwrap(); + assert_eq!(pre, post); + } + + #[test] + fn test_api_access_methods_extract_corrupt_built_in() { + let mut pre: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "access_method_settings": [ + { + "id": "1d0d8891-dbb3-4439-a8f7-0e7d742ddbe4", + "name": "Mullvad Bridges", + "enabled": true, + "access_method": { + "built_in": "some_other_alternative" + } + } + ] + } +} +"#, + ) + .unwrap(); + + let post: serde_json::Value = serde_json::from_str( + r#" +{ + "api_access_methods": { + "custom": [] + } +} +"#, + ) + .unwrap(); + + migrate_api_access_settings(&mut pre).unwrap(); + assert_eq!(pre, post); + } } diff --git a/mullvad-daemon/src/settings/mod.rs b/mullvad-daemon/src/settings/mod.rs index 47a2edc989..d411bce2b4 100644 --- a/mullvad-daemon/src/settings/mod.rs +++ b/mullvad-daemon/src/settings/mod.rs @@ -38,6 +38,9 @@ pub enum Error { #[error(display = "Unable to write settings to {}", _0)] WriteError(String, #[error(source)] io::Error), + + #[error(display = "Failed to apply settings update")] + UpdateFailed(Box<dyn std::error::Error + Send + Sync>), } /// Converts an [Error] to a management interface status @@ -50,7 +53,7 @@ impl From<Error> for mullvad_management_interface::Status { Error::DeleteError(..) | Error::WriteError(..) | Error::ReadError(..) => { Status::new(Code::FailedPrecondition, error.to_string()) } - Error::SerializeError(..) | Error::ParseError(..) => { + Error::SerializeError(..) | Error::ParseError(..) | Error::UpdateFailed(..) => { Status::new(Code::Internal, error.to_string()) } } @@ -224,18 +227,82 @@ impl SettingsPersister { settings } - /// Edit the settings in a closure, and write the changes, if any, to disk. + /// Edit the settings in a closure and write the changes to disk. + /// + /// # On success + /// + /// Returns a boolean indicating whether any settings were changed. + /// + /// # On failure + /// + /// If the settings could not be written to disk, all changes are rolled + /// back, and an error is returned. + /// + /// # Note /// - /// On success, the function returns a boolean indicating whether any settings were changed. - /// If the settings could not be written to disk, all changes are rolled back, and an error is - /// returned. + /// If no settings were changed, no I/O will be performed. pub async fn update( &mut self, update_fn: impl FnOnce(&mut Settings), ) -> Result<MadeChanges, Error> { + self.try_update(|settings| -> Result<(), Error> { + update_fn(settings); + Ok(()) + }) + .await + } + + /// Edit the settings in a closure, and write the changes to disk. + /// + /// # On success + /// + /// Returns a boolean indicating whether any settings were changed. + /// + /// # On failure + /// + /// `try_update` may fail in two scenarios + /// + /// ## The settings could not be written to disk + /// + /// In this case, all changes are rolled back and an error is returned. + /// + /// ## `update_fn` failed + /// + /// If `update_fn` were to fail the error will be propagated through the + /// [`Error::UpdateFailed`] error variant. Since the error will be boxed, it + /// has to be downcasted at runtime using [`Box::downcast`] in case you want + /// to inspect the error closer. + /// + /// ```ignore + /// #[derive(Debug, err_derive::Error)] + /// pub enum MyError { + /// #[error(display = "Failed for this reason: {:?}", _0)] + /// Failed(String), + /// } + /// + /// let settings = Settings::default_settings(); + /// let err = settings.try_update(|settings| { + /// // Perform some update on the settings + /// settings.allow_lan = !settings.allow_lan; + /// // Fail the update procedure due to some error + /// Err(MyError::Failed("No particular reason".to_string())) + /// }); + /// + /// matches!(err, Error::UpdateFailed(_)) ; + /// assert_eq!(settings, Settings::default_settings()) + /// ``` + pub async fn try_update<E>( + &mut self, + update_fn: impl FnOnce(&mut Settings) -> Result<(), E>, + ) -> Result<MadeChanges, Error> + where + E: std::error::Error + Send + Sync + 'static, + { let mut new_settings = self.settings.clone(); - update_fn(&mut new_settings); + update_fn(&mut new_settings) + .map_err(Box::from) + .map_err(Error::UpdateFailed)?; if self.settings == new_settings { return Ok(false); diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 21de18df8e..7fbdb6eba6 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -380,7 +380,11 @@ message NewAccessMethodSetting { AccessMethod access_method = 3; } -message ApiAccessMethodSettings { repeated AccessMethodSetting access_method_settings = 1; } +message ApiAccessMethodSettings { + AccessMethodSetting direct = 1; + AccessMethodSetting mullvad_bridges = 2; + repeated AccessMethodSetting custom = 3; +} message Settings { RelaySettings relay_settings = 1; diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 52dbd72b47..f30b613171 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -176,19 +176,20 @@ impl MullvadProxyClient { } pub async fn get_api_access_methods(&mut self) -> Result<Vec<AccessMethodSetting>> { - self.0 + let access_method_settings = self + .0 .get_settings(()) .await .map_err(Error::Rpc)? .into_inner() .api_access_methods - .ok_or(Error::ApiAccessMethodSettingsNotFound)? - .access_method_settings - .into_iter() - .map(|api_access_method| { - AccessMethodSetting::try_from(api_access_method).map_err(Error::InvalidResponse) - }) - .collect() + .ok_or(Error::ApiAccessMethodSettingsNotFound) + .and_then(|access_method_settings| { + access_method::Settings::try_from(access_method_settings) + .map_err(Error::InvalidResponse) + })?; + + Ok(access_method_settings.iter().cloned().collect()) } pub async fn get_api_access_method( diff --git a/mullvad-management-interface/src/types/conversions/access_method.rs b/mullvad-management-interface/src/types/conversions/access_method.rs index ec5681d74d..d9758a571c 100644 --- a/mullvad-management-interface/src/types/conversions/access_method.rs +++ b/mullvad-management-interface/src/types/conversions/access_method.rs @@ -5,35 +5,49 @@ mod settings { use crate::types::{proto, FromProtobufTypeError}; use mullvad_types::access_method; - impl From<&access_method::Settings> for proto::ApiAccessMethodSettings { - fn from(settings: &access_method::Settings) -> Self { + impl From<access_method::Settings> for proto::ApiAccessMethodSettings { + fn from(settings: access_method::Settings) -> Self { Self { - access_method_settings: settings - .access_method_settings - .iter() - .map(|method| method.clone().into()) + direct: Some(settings.direct().clone().into()), + mullvad_bridges: Some(settings.mullvad_bridges().clone().into()), + custom: settings + .iter_custom() + .cloned() + .map(|method| method.into()) .collect(), } } } - impl From<access_method::Settings> for proto::ApiAccessMethodSettings { - fn from(settings: access_method::Settings) -> Self { - proto::ApiAccessMethodSettings::from(&settings) - } - } - impl TryFrom<proto::ApiAccessMethodSettings> for access_method::Settings { type Error = FromProtobufTypeError; fn try_from(settings: proto::ApiAccessMethodSettings) -> Result<Self, Self::Error> { - Ok(Self { - access_method_settings: settings - .access_method_settings - .iter() - .map(access_method::AccessMethodSetting::try_from) - .collect::<Result<Vec<access_method::AccessMethodSetting>, _>>()?, - }) + let direct = settings + .direct + .ok_or(FromProtobufTypeError::InvalidArgument( + "Could not deserialize Direct Access Method from protobuf", + )) + .and_then(access_method::AccessMethodSetting::try_from)?; + + let mullvad_bridges = settings + .mullvad_bridges + .ok_or(FromProtobufTypeError::InvalidArgument( + "Could not deserialize Mullvad Bridges Access Method from protobuf", + )) + .and_then(access_method::AccessMethodSetting::try_from)?; + + let custom = settings + .custom + .iter() + .map(access_method::AccessMethodSetting::try_from) + .collect::<Result<Vec<_>, _>>()?; + + Ok(access_method::Settings::new( + direct, + mullvad_bridges, + custom, + )) } } } diff --git a/mullvad-management-interface/src/types/conversions/settings.rs b/mullvad-management-interface/src/types/conversions/settings.rs index a7c4bcd78c..a4d6313158 100644 --- a/mullvad-management-interface/src/types/conversions/settings.rs +++ b/mullvad-management-interface/src/types/conversions/settings.rs @@ -43,7 +43,7 @@ impl From<&mullvad_types::settings::Settings> for proto::Settings { settings.custom_lists.clone(), )), api_access_methods: Some(proto::ApiAccessMethodSettings::from( - &settings.api_access_methods, + settings.api_access_methods.clone(), )), relay_overrides: settings .relay_overrides diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs index 73ab671c9c..80ba7ab69c 100644 --- a/mullvad-types/src/access_method.rs +++ b/mullvad-types/src/access_method.rs @@ -1,102 +1,162 @@ -use std::str::FromStr; - use serde::{Deserialize, Serialize}; use talpid_types::net::proxy::{CustomProxy, Shadowsocks, Socks5Local, Socks5Remote}; -/// Dttings for API access methods. +/// Settings for API access methods. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Settings { - pub access_method_settings: Vec<AccessMethodSetting>, + direct: AccessMethodSetting, + mullvad_bridges: AccessMethodSetting, + /// Custom API access methods. + custom: Vec<AccessMethodSetting>, } impl Settings { + pub fn new( + direct: AccessMethodSetting, + mullvad_bridges: AccessMethodSetting, + custom: Vec<AccessMethodSetting>, + ) -> Settings { + Settings { + direct, + mullvad_bridges, + custom, + } + } + /// Append an [`AccessMethod`] to the end of `api_access_methods`. pub fn append(&mut self, api_access_method: AccessMethodSetting) { - self.access_method_settings.push(api_access_method) + self.custom.push(api_access_method) } /// Remove an [`ApiAccessMethod`] from `api_access_methods`. - pub fn remove(&mut self, api_access_method: &Id) { - self.retain(|method| method.get_id() != *api_access_method) + /// + /// This function will return an error if a built-in API access is about to + /// be removed. + pub fn remove(&mut self, api_access_method: &Id) -> Result<(), Error> { + let maybe_setting = self + .custom + .iter() + .find(|setting| setting.get_id() == *api_access_method); + + match maybe_setting { + Some(x) => match x.access_method { + AccessMethod::BuiltIn(ref built_in) => Err(Error::RemoveBuiltin { + attempted: built_in.clone(), + }), + AccessMethod::Custom(_) => { + self.custom + .retain(|method| method.get_id() != *api_access_method); + self.ensure_consistent_state(); + Ok(()) + } + }, + None => Ok(()), + } } - /// Search for any [`AccessMethod`] in `api_access_methods` which matches `predicate`. - pub fn find<P>(&self, predicate: P) -> Option<&AccessMethodSetting> - where - P: Fn(&AccessMethodSetting) -> bool, - { - self.access_method_settings - .iter() - .find(|api_access_method| predicate(api_access_method)) + /// Update an existing [`AccessMethodSetting`] chosen by `predicate`, in a + /// closure `f`, saving the result to `self`. + /// + /// Returns a bool to indicate whether some [`AccessMethodSetting`] was + /// updated. + pub fn update( + &mut self, + predicate: impl Fn(&AccessMethodSetting) -> bool, + f: impl FnOnce(&AccessMethodSetting) -> AccessMethodSetting, + ) -> bool { + let mut updated = false; + if let Some(access_method) = self.iter_mut().find(|setting| predicate(setting)) { + *access_method = f(access_method); + updated = true; + } + self.ensure_consistent_state(); + + updated } - /// Search for any [`AccessMethod`] in `api_access_methods`. - pub fn find_mut<P>(&mut self, predicate: P) -> Option<&mut AccessMethodSetting> - where - P: Fn(&AccessMethodSetting) -> bool, - { - self.access_method_settings - .iter_mut() - .find(|api_access_method| predicate(api_access_method)) + /// Check that `self` contains atleast one enabled access methods. If not, + /// the `Direct` access method is re-enabled. + fn ensure_consistent_state(&mut self) { + if self.collect_enabled().is_empty() { + self.direct.enable(); + } } - /// Search for a particular [`AccessMethod`] in `api_access_methods`. - pub fn find_by_id(&self, element: &Id) -> Option<&AccessMethodSetting> { - self.find(|api_access_method| *element == api_access_method.get_id()) + // TODO(markus): This can surely be removed. + /// Retrieve all [`AccessMethodSetting`]s which are enabled. + pub fn collect_enabled(&self) -> Vec<AccessMethodSetting> { + self.iter() + .filter(|access_method| access_method.enabled) + .cloned() + .collect() } - /// Search for a particular [`AccessMethod`] in `api_access_methods`. - pub fn find_by_id_mut(&mut self, element: &Id) -> Option<&mut AccessMethodSetting> { - self.find_mut(|api_access_method| *element == api_access_method.get_id()) + /// Iterate over references of built-in & custom access methods. + pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> { + use std::iter::once; + once(&self.direct) + .chain(once(&self.mullvad_bridges)) + .chain(&self.custom) } - /// Equivalent to [`Vec::retain`]. - pub fn retain<F>(&mut self, f: F) - where - F: FnMut(&AccessMethodSetting) -> bool, - { - self.access_method_settings.retain(f) + /// Iterate over mutable references of built-in & custom access methods. + fn iter_mut(&mut self) -> impl Iterator<Item = &mut AccessMethodSetting> { + use std::iter::once; + once(&mut self.direct) + .chain(once(&mut self.mullvad_bridges)) + .chain(&mut self.custom) } - /// Clone the content of `api_access_methods`. - pub fn cloned(&self) -> Vec<AccessMethodSetting> { - self.access_method_settings.clone() + /// Iterate over references of custom access methods. + pub fn iter_custom(&self) -> impl Iterator<Item = &AccessMethodSetting> { + self.custom.iter() } - /// Get a reference to the `Direct` access method instance of this [`Settings`]. - pub fn get_direct(&mut self) -> Option<&mut AccessMethodSetting> { - self.find_mut(|access_method| { - access_method.access_method == BuiltInAccessMethod::Direct.into() - }) + /// Return the total number of access methods. + /// This counts both enabled and disabled [`AccessMethodSetting`]s. + pub fn cardinality(&self) -> usize { + 1 + // 'Direct' + 1 + // 'Mullvad bridges' + self.custom.len() } - pub fn direct() -> AccessMethodSetting { + pub fn direct(&self) -> &AccessMethodSetting { + &self.direct + } + + pub fn mullvad_bridges(&self) -> &AccessMethodSetting { + &self.mullvad_bridges + } + + // TODO(markus): This can probably be made private + pub fn create_direct() -> AccessMethodSetting { let method = BuiltInAccessMethod::Direct; AccessMethodSetting::new(method.canonical_name(), true, AccessMethod::from(method)) } - pub fn mullvad_bridges() -> AccessMethodSetting { + fn create_mullvad_bridges() -> AccessMethodSetting { let method = BuiltInAccessMethod::Bridge; AccessMethodSetting::new(method.canonical_name(), true, AccessMethod::from(method)) } - - /// Retrieve all [`AccessMethodSetting`]s which are enabled. - pub fn collect_enabled(&self) -> Vec<AccessMethodSetting> { - self.cloned() - .into_iter() - .filter(|access_method| access_method.enabled) - .collect() - } } impl Default for Settings { fn default() -> Self { Self { - access_method_settings: vec![Settings::direct(), Settings::mullvad_bridges()], + direct: Settings::create_direct(), + mullvad_bridges: Settings::create_mullvad_bridges(), + custom: vec![], } } } +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// Built-in access methods can not be removed + #[error(display = "Cannot remove built-in access method {}", attempted)] + RemoveBuiltin { attempted: BuiltInAccessMethod }, +} + /// API Access Method datastructure /// /// Mirrors the protobuf definition @@ -120,6 +180,7 @@ impl Id { /// Tries to parse a UUID from a raw String. If it is successful, an /// [`Id`] is instantiated. pub fn from_string(id: String) -> Option<Self> { + use std::str::FromStr; uuid::Uuid::from_str(&id).ok().map(Self) } } @@ -178,6 +239,10 @@ impl AccessMethodSetting { self.enabled } + pub fn disabled(&self) -> bool { + !self.enabled + } + pub fn as_custom(&self) -> Option<&CustomProxy> { self.access_method.as_custom() } @@ -223,6 +288,12 @@ impl BuiltInAccessMethod { } } +impl std::fmt::Display for BuiltInAccessMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.canonical_name()) + } +} + impl From<BuiltInAccessMethod> for AccessMethod { fn from(value: BuiltInAccessMethod) -> Self { AccessMethod::BuiltIn(value) diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs index e6886c9e3c..af8c11e8d7 100644 --- a/mullvad-types/src/settings/mod.rs +++ b/mullvad-types/src/settings/mod.rs @@ -79,7 +79,7 @@ pub struct Settings { /// All of the custom relay lists #[cfg_attr(target_os = "android", jnix(skip))] pub custom_lists: CustomListsSettings, - /// API access methods. + /// API access methods #[cfg_attr(target_os = "android", jnix(skip))] pub api_access_methods: access_method::Settings, /// If the daemon should allow communication with private (LAN) networks. |
