diff options
16 files changed, 194 insertions, 59 deletions
diff --git a/ios/MullvadMockData/MullvadREST/AccessMethodRepository+Stub.swift b/ios/MullvadMockData/MullvadREST/AccessMethodRepository+Stub.swift index 6e7dbb93ef..148927711f 100644 --- a/ios/MullvadMockData/MullvadREST/AccessMethodRepository+Stub.swift +++ b/ios/MullvadMockData/MullvadREST/AccessMethodRepository+Stub.swift @@ -28,7 +28,7 @@ public struct AccessMethodRepositoryStub: AccessMethodRepositoryDataSource, @unc passthroughSubject.value } - public func saveLastReachable(_ method: PersistentAccessMethod) {} + public func requestAccessMethod(_ method: PersistentAccessMethod) {} public func fetchLastReachable() -> PersistentAccessMethod { directAccess diff --git a/ios/MullvadREST/Transport/AccessMethodIterator.swift b/ios/MullvadREST/Transport/AccessMethodIterator.swift index 91de54bbd7..d1672ad8fb 100644 --- a/ios/MullvadREST/Transport/AccessMethodIterator.swift +++ b/ios/MullvadREST/Transport/AccessMethodIterator.swift @@ -8,10 +8,13 @@ import Combine import Foundation +import MullvadLogging import MullvadSettings import MullvadTypes final class AccessMethodIterator: @unchecked Sendable, SwiftConnectionModeProviding { + private let logger = Logger(label: "AccessMethodIterator") + private let dataSource: AccessMethodRepositoryDataSource private var index = 0 @@ -47,13 +50,14 @@ final class AccessMethodIterator: @unchecked Sendable, SwiftConnectionModeProvid index = firstIndex } - dataSource.saveLastReachable(pick()) + let newAccessMethod = pick() + dataSource.requestAccessMethod(newAccessMethod) } func rotate() { let (partial, isOverflow) = index.addingReportingOverflow(1) index = isOverflow ? 0 : partial - dataSource.saveLastReachable(pick()) + dataSource.requestAccessMethod(pick()) } func pick() -> PersistentAccessMethod { diff --git a/ios/MullvadRustRuntime/MullvadAccessMethodReceiver.swift b/ios/MullvadRustRuntime/MullvadAccessMethodReceiver.swift index c683a10af6..bb94759717 100644 --- a/ios/MullvadRustRuntime/MullvadAccessMethodReceiver.swift +++ b/ios/MullvadRustRuntime/MullvadAccessMethodReceiver.swift @@ -17,12 +17,12 @@ public class MullvadAccessMethodReceiver { public init( apiContext: MullvadApiContext, accessMethodsDataSource: AnyPublisher<[PersistentAccessMethod], Never>, - lastReachableDataSource: AnyPublisher<PersistentAccessMethod, Never> + requestDataSource: AnyPublisher<PersistentAccessMethod, Never> ) { self.apiContext = apiContext - lastReachableDataSource.sink { [weak self] in - self?.saveLastReachable($0) + requestDataSource.sink { [weak self] latestReachable in + self?.saveLastReachable(latestReachable) } .store(in: &cancellables) diff --git a/ios/MullvadRustRuntime/MullvadApiContext.swift b/ios/MullvadRustRuntime/MullvadApiContext.swift index 7600b79e70..d2d289b951 100644 --- a/ios/MullvadRustRuntime/MullvadApiContext.swift +++ b/ios/MullvadRustRuntime/MullvadApiContext.swift @@ -6,18 +6,28 @@ // Copyright © 2025 Mullvad VPN AB. All rights reserved. // +import MullvadSettings import MullvadTypes -public struct MullvadApiContext: @unchecked Sendable { +func onAccessChangeCallback(selfPtr: UnsafeRawPointer?, bytes: UnsafePointer<UInt8>?) { + guard let selfPtr, let bytes else { return } + let context = Unmanaged<MullvadApiContext>.fromOpaque(selfPtr).takeUnretainedValue() + + let uuid = NSUUID(uuidBytes: bytes) as UUID + context.accessMethodChangeListener?.accessMethodChangedTo(uuid) +} + +public class MullvadApiContext: @unchecked Sendable { enum Error: Swift.Error { case failedToConstructApiClient } - public let context: SwiftApiContext + public private(set) var context: SwiftApiContext! private let shadowsocksBridgeProvider: SwiftShadowsocksBridgeProviding! private let shadowsocksBridgeProviderWrapper: SwiftShadowsocksLoaderWrapper! private let addressCacheWrapper: SwiftAddressCacheWrapper! private let addressCacheProvider: AddressCacheProviding! + public var accessMethodChangeListener: MullvadAccessMethodChangeListening? public init( host: String, @@ -36,6 +46,7 @@ public struct MullvadApiContext: @unchecked Sendable { self.addressCacheProvider = defaultAddressCache self.addressCacheWrapper = iniSwiftAddressCacheWrapper(provider: defaultAddressCache) + let selfPtr = Unmanaged.passUnretained(self).toOpaque() context = switch disableTls { case true: mullvad_api_init_new_tls_disabled( @@ -44,7 +55,9 @@ public struct MullvadApiContext: @unchecked Sendable { domain, shadowsocksBridgeProviderWrapper, accessMethodWrapper, - addressCacheWrapper + addressCacheWrapper, + onAccessChangeCallback, + selfPtr ) case false: mullvad_api_init_new( @@ -53,7 +66,9 @@ public struct MullvadApiContext: @unchecked Sendable { domain, shadowsocksBridgeProviderWrapper, accessMethodWrapper, - addressCacheWrapper + addressCacheWrapper, + onAccessChangeCallback, + selfPtr ) } diff --git a/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h b/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h index 751e58aae0..c8443c1fcf 100644 --- a/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h +++ b/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h @@ -169,7 +169,10 @@ struct SwiftApiContext mullvad_api_init_new_tls_disabled(const char *host, const char *domain, struct SwiftShadowsocksLoaderWrapper bridge_provider, struct SwiftAccessMethodSettingsWrapper settings_provider, - struct SwiftAddressCacheWrapper address_cache); + struct SwiftAddressCacheWrapper address_cache, + void (*access_method_change_callback)(const void*, + const uint8_t*), + const void *access_method_change_context); /** * # Safety @@ -180,6 +183,14 @@ struct SwiftApiContext mullvad_api_init_new_tls_disabled(const char *host, * `address` must be a pointer to a null terminated string representing a socket address through which * the Mullvad API can be reached directly. * + * address_method_change_callback is a function with the C calling convention which will be called + * whenever the access method changes with a user-specified opaque pointer and a pointer to the bytes + * of the access method's UUID. Note that this callback must remain valid for the lifetime of the + * program. + * + * access_method_change_context is the pointer passed verbatim to the callback. It is not dereferenced + * by the Rust code, but remains opaque. + * * If a context cannot be constructed this function will panic since the call site would not be able * to proceed in a meaningful way anyway. * @@ -190,7 +201,10 @@ struct SwiftApiContext mullvad_api_init_new(const char *host, const char *domain, struct SwiftShadowsocksLoaderWrapper bridge_provider, struct SwiftAccessMethodSettingsWrapper settings_provider, - struct SwiftAddressCacheWrapper address_cache); + struct SwiftAddressCacheWrapper address_cache, + void (*access_method_change_callback)(const void*, + const uint8_t*), + const void *access_method_change_context); /** * # Safety @@ -212,7 +226,10 @@ struct SwiftApiContext mullvad_api_init_inner(const char *host, bool disable_tls, struct SwiftShadowsocksLoaderWrapper bridge_provider, struct SwiftAccessMethodSettingsWrapper settings_provider, - struct SwiftAddressCacheWrapper address_cache); + struct SwiftAddressCacheWrapper address_cache, + void (*access_method_change_callback)(const void*, + const uint8_t*), + const void *access_method_change_context); /** * Converts parameters into a `Box<AccessMethodSetting>` raw representation that diff --git a/ios/MullvadSettings/AccessMethodRepository.swift b/ios/MullvadSettings/AccessMethodRepository.swift index 72cfc0668e..07e1ad2a61 100644 --- a/ios/MullvadSettings/AccessMethodRepository.swift +++ b/ios/MullvadSettings/AccessMethodRepository.swift @@ -46,26 +46,40 @@ public class AccessMethodRepository: AccessMethodRepositoryProtocol, @unchecked accessMethodsSubject.eraseToAnyPublisher() } - private let lastReachableAccessMethodSubject: CurrentValueSubject<PersistentAccessMethod, Never> - public var lastReachableAccessMethodPublisher: AnyPublisher<PersistentAccessMethod, Never> { - lastReachableAccessMethodSubject.eraseToAnyPublisher() + private let requestAccessMethodSubject: PassthroughSubject<PersistentAccessMethod, Never> + public var requestAccessMethodPublisher: AnyPublisher<PersistentAccessMethod, Never> { + requestAccessMethodSubject.eraseToAnyPublisher() + } + + private let currentAccessMethodSubject: CurrentValueSubject<PersistentAccessMethod, Never> + public var currentAccessMethodPublisher: AnyPublisher<PersistentAccessMethod, Never> { + currentAccessMethodSubject.eraseToAnyPublisher() } public var directAccess: PersistentAccessMethod { direct } + private var cancellables: Set<Combine.AnyCancellable> = [] + public init() { accessMethodsSubject = CurrentValueSubject([]) - lastReachableAccessMethodSubject = CurrentValueSubject(direct) + requestAccessMethodSubject = PassthroughSubject() + currentAccessMethodSubject = CurrentValueSubject(direct) addDefaultsMethods() accessMethodsSubject.send(fetchAll()) - lastReachableAccessMethodSubject.send(fetchLastReachable()) + requestAccessMethodSubject.send(fetchLastReachable()) + + currentAccessMethodPublisher + .removeDuplicates() + .sink { [weak self] currentAccessMethod in + self?.saveCurrentAccessMethod(currentAccessMethod) + }.store(in: &cancellables) } - public func save(_ method: PersistentAccessMethod) { + public func save(_ method: PersistentAccessMethod, notifyingAPI: Bool = false) { var methodStore = readApiAccessMethodStore() var method = method @@ -79,19 +93,24 @@ public class AccessMethodRepository: AccessMethodRepositoryProtocol, @unchecked do { try writeApiAccessMethodStore(methodStore) - accessMethodsSubject.send(methodStore.accessMethods) + if notifyingAPI { + accessMethodsSubject.send(methodStore.accessMethods) + } } catch { logger.error("Could not save access method: \(method) \nError: \(error)") } } - public func saveLastReachable(_ method: PersistentAccessMethod) { + public func requestAccessMethod(_ method: PersistentAccessMethod) { + requestAccessMethodSubject.send(method) + } + + private func saveCurrentAccessMethod(_ method: PersistentAccessMethod) { var methodStore = readApiAccessMethodStore() methodStore.lastReachableAccessMethod = method do { try writeApiAccessMethodStore(methodStore) - lastReachableAccessMethodSubject.send(method) } catch { logger.error("Could not save last reachable access method: \(method) \nError: \(error)") } @@ -175,3 +194,17 @@ public class AccessMethodRepository: AccessMethodRepositoryProtocol, @unchecked SettingsParser(decoder: JSONDecoder(), encoder: JSONEncoder()) } } + +extension AccessMethodRepository: MullvadAccessMethodChangeListening { + public func accessMethodChangedTo(_ uuid: UUID) { + guard let method = accessMethodsSubject.value.first(where: { $0.id == uuid }) else { + logger.warning("Change reported to method with unknown ID: \(uuid)") + return + } + + Task { + print("Mullvad API changed access method to \(method.name)") + currentAccessMethodSubject.send(method) + } + } +} diff --git a/ios/MullvadSettings/AccessMethodRepositoryProtocol.swift b/ios/MullvadSettings/AccessMethodRepositoryProtocol.swift index 35f97442f5..d44f009911 100644 --- a/ios/MullvadSettings/AccessMethodRepositoryProtocol.swift +++ b/ios/MullvadSettings/AccessMethodRepositoryProtocol.swift @@ -21,7 +21,7 @@ public protocol AccessMethodRepositoryDataSource: Sendable { func fetchAll() -> [PersistentAccessMethod] /// Save last reachable access method to the persistent store. - func saveLastReachable(_ method: PersistentAccessMethod) + func requestAccessMethod(_ method: PersistentAccessMethod) /// Fetch last reachable access method from the persistent store. func fetchLastReachable() -> PersistentAccessMethod @@ -29,11 +29,11 @@ public protocol AccessMethodRepositoryDataSource: Sendable { public protocol AccessMethodRepositoryProtocol: AccessMethodRepositoryDataSource { /// Publisher that propagates a snapshot of last reachable access method upon modifications. - var lastReachableAccessMethodPublisher: AnyPublisher<PersistentAccessMethod, Never> { get } + var currentAccessMethodPublisher: AnyPublisher<PersistentAccessMethod, Never> { get } /// Add new access method. /// - Parameter method: persistent access method model. - func save(_ method: PersistentAccessMethod) + func save(_ method: PersistentAccessMethod, notifyingAPI: Bool) /// Delete access method by id. /// - Parameter id: an access method id. diff --git a/ios/MullvadSettings/MullvadAccessMethodChangeListening.swift b/ios/MullvadSettings/MullvadAccessMethodChangeListening.swift new file mode 100644 index 0000000000..679d5e66fb --- /dev/null +++ b/ios/MullvadSettings/MullvadAccessMethodChangeListening.swift @@ -0,0 +1,12 @@ +// +// MullvadAccessMethodChangeListening.swift +// MullvadVPN +// +// Created by Andrew Bulhak on 2025-07-03. +// Copyright © 2025 Mullvad VPN AB. All rights reserved. +// + +// A protocol that listens for notifications of when the current access method has changed. It receives only the UUID of the new method. +public protocol MullvadAccessMethodChangeListening: AnyObject { + func accessMethodChangedTo(_ uuid: UUID) +} diff --git a/ios/MullvadVPN.xcodeproj/project.pbxproj b/ios/MullvadVPN.xcodeproj/project.pbxproj index dcc97b34e1..9cd08dd97d 100644 --- a/ios/MullvadVPN.xcodeproj/project.pbxproj +++ b/ios/MullvadVPN.xcodeproj/project.pbxproj @@ -53,6 +53,7 @@ 4424CDD32CDBD4A6009D8C9F /* SingleChoiceList.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4424CDD22CDBD4A6009D8C9F /* SingleChoiceList.swift */; }; 447F3D8A2CDE1853006E3462 /* ShadowsocksObfuscationSettingsViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 447F3D882CDE1852006E3462 /* ShadowsocksObfuscationSettingsViewModel.swift */; }; 447F3D8B2CDE1853006E3462 /* ShadowsocksObfuscationSettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 447F3D892CDE1853006E3462 /* ShadowsocksObfuscationSettingsView.swift */; }; + 4483EC372E26A53D007E5473 /* MullvadAccessMethodChangeListening.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4483EC352E2693D5007E5473 /* MullvadAccessMethodChangeListening.swift */; }; 449275422C3570CA000526DE /* ICMP.swift in Sources */ = {isa = PBXBuildFile; fileRef = 449275412C3570CA000526DE /* ICMP.swift */; }; 4495ECD12D0B170700A7358B /* UDPOverTCPObfuscationSettingsPage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4495ECD02D0B16F700A7358B /* UDPOverTCPObfuscationSettingsPage.swift */; }; 4495ECD52D131A4800A7358B /* ShadowsocksObfuscationSettingsPage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4495ECD42D131A3E00A7358B /* ShadowsocksObfuscationSettingsPage.swift */; }; @@ -1648,6 +1649,7 @@ 4424CDD22CDBD4A6009D8C9F /* SingleChoiceList.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SingleChoiceList.swift; sourceTree = "<group>"; }; 447F3D882CDE1852006E3462 /* ShadowsocksObfuscationSettingsViewModel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShadowsocksObfuscationSettingsViewModel.swift; sourceTree = "<group>"; }; 447F3D892CDE1853006E3462 /* ShadowsocksObfuscationSettingsView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShadowsocksObfuscationSettingsView.swift; sourceTree = "<group>"; }; + 4483EC352E2693D5007E5473 /* MullvadAccessMethodChangeListening.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MullvadAccessMethodChangeListening.swift; sourceTree = "<group>"; }; 449275412C3570CA000526DE /* ICMP.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ICMP.swift; sourceTree = "<group>"; }; 449275432C3C3029000526DE /* TunnelPinger.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TunnelPinger.swift; sourceTree = "<group>"; }; 4495ECD02D0B16F700A7358B /* UDPOverTCPObfuscationSettingsPage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UDPOverTCPObfuscationSettingsPage.swift; sourceTree = "<group>"; }; @@ -3785,6 +3787,7 @@ 06410DFD292CE18F00AFC18C /* KeychainSettingsStore.swift */, 068CE5732927B7A400A068BB /* Migration.swift */, A9D96B192A8247C100A5C673 /* MigrationManager.swift */, + 4483EC352E2693D5007E5473 /* MullvadAccessMethodChangeListening.swift */, 58B2FDD52AA71D2A003EB5C6 /* MullvadSettings.h */, F0E61CA92BF2911D000C4A95 /* MultihopSettings.swift */, 44DD7D2C2B74E44A0005F67F /* QuantumResistanceSettings.swift */, @@ -6207,6 +6210,7 @@ A93181A12B727ED700E341D2 /* TunnelSettingsV4.swift in Sources */, 58FE25BF2AA72311003D1918 /* MigrationManager.swift in Sources */, 58B2FDEF2AA720C4003EB5C6 /* ApplicationTarget.swift in Sources */, + 4483EC372E26A53D007E5473 /* MullvadAccessMethodChangeListening.swift in Sources */, A988DF272ADE86ED00D807EF /* WireGuardObfuscationSettings.swift in Sources */, 58B2FDDE2AA71D5C003EB5C6 /* Migration.swift in Sources */, F05769BB2C6661EE00D9778B /* TunnelSettingsStrategy.swift in Sources */, diff --git a/ios/MullvadVPN/AppDelegate.swift b/ios/MullvadVPN/AppDelegate.swift index 2590e2f1e7..9758a60fc0 100644 --- a/ios/MullvadVPN/AppDelegate.swift +++ b/ios/MullvadVPN/AppDelegate.swift @@ -116,8 +116,9 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD accessMethodReceiver = MullvadAccessMethodReceiver( apiContext: apiContext, accessMethodsDataSource: accessMethodRepository.accessMethodsPublisher, - lastReachableDataSource: accessMethodRepository.lastReachableAccessMethodPublisher + requestDataSource: accessMethodRepository.requestAccessMethodPublisher ) + apiContext.accessMethodChangeListener = accessMethodRepository setUpProxies(containerURL: containerURL) let backgroundTaskProvider = BackgroundTaskProvider( diff --git a/ios/MullvadVPN/Coordinators/Settings/APIAccess/Edit/EditAccessMethodInteractor.swift b/ios/MullvadVPN/Coordinators/Settings/APIAccess/Edit/EditAccessMethodInteractor.swift index ac7a9125ff..453e2cb992 100644 --- a/ios/MullvadVPN/Coordinators/Settings/APIAccess/Edit/EditAccessMethodInteractor.swift +++ b/ios/MullvadVPN/Coordinators/Settings/APIAccess/Edit/EditAccessMethodInteractor.swift @@ -39,7 +39,7 @@ struct EditAccessMethodInteractor: EditAccessMethodInteractorProtocol { func saveAccessMethod() { guard let persistentMethod = try? subject.value.intoPersistentAccessMethod() else { return } - repository.save(persistentMethod) + repository.save(persistentMethod, notifyingAPI: true) checkIfSwitchCanBeToggled() } @@ -47,7 +47,7 @@ struct EditAccessMethodInteractor: EditAccessMethodInteractorProtocol { repository.delete(id: subject.value.id) // Enable direct access if all methods are disabled if repository.fetchAll().count(where: { $0.isEnabled }) == 0 { - repository.save(repository.directAccess) + repository.save(repository.directAccess, notifyingAPI: true) } } diff --git a/ios/MullvadVPN/Coordinators/Settings/APIAccess/List/ListAccessMethodInteractor.swift b/ios/MullvadVPN/Coordinators/Settings/APIAccess/List/ListAccessMethodInteractor.swift index 702922b9c0..cfd5ef9eed 100644 --- a/ios/MullvadVPN/Coordinators/Settings/APIAccess/List/ListAccessMethodInteractor.swift +++ b/ios/MullvadVPN/Coordinators/Settings/APIAccess/List/ListAccessMethodInteractor.swift @@ -28,7 +28,7 @@ struct ListAccessMethodInteractor: ListAccessMethodInteractorProtocol { } var itemInUsePublisher: AnyPublisher<ListAccessMethodItem?, Never> { - repository.lastReachableAccessMethodPublisher + repository.currentAccessMethodPublisher .receive(on: RunLoop.main) .map { $0.toListItem() } .eraseToAnyPublisher() diff --git a/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift b/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift index e949a36709..cf405b9701 100644 --- a/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift +++ b/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift @@ -273,7 +273,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable { accessMethodReceiver = MullvadAccessMethodReceiver( apiContext: apiContext, accessMethodsDataSource: accessMethodRepository.accessMethodsPublisher, - lastReachableDataSource: accessMethodRepository.lastReachableAccessMethodPublisher + requestDataSource: accessMethodRepository.requestAccessMethodPublisher ) encryptedDNSTransport = EncryptedDNSTransport(urlSession: urlSession) diff --git a/mullvad-api/src/access_mode.rs b/mullvad-api/src/access_mode.rs index 666488f59b..51c37fff8a 100644 --- a/mullvad-api/src/access_mode.rs +++ b/mullvad-api/src/access_mode.rs @@ -234,7 +234,6 @@ pub struct AccessModeSelector<B: AccessMethodResolver> { cmd_rx: mpsc::UnboundedReceiver<Message>, method_resolver: B, access_method_settings: Settings, - #[cfg(not(target_os = "ios"))] access_method_event_sender: mpsc::UnboundedSender<(AccessMethodEvent, oneshot::Sender<()>)>, connection_mode_provider_sender: mpsc::UnboundedSender<ApiConnectionMode>, current: ResolvedConnectionMode, @@ -248,10 +247,7 @@ impl<B: AccessMethodResolver + 'static> AccessModeSelector<B> { #[cfg_attr(not(feature = "api-override"), allow(unused_mut))] mut access_method_settings: Settings, #[cfg(feature = "api-override")] api_endpoint: ApiEndpoint, - #[cfg(not(target_os = "ios"))] access_method_event_sender: mpsc::UnboundedSender<( - AccessMethodEvent, - oneshot::Sender<()>, - )>, + access_method_event_sender: mpsc::UnboundedSender<(AccessMethodEvent, oneshot::Sender<()>)>, ) -> Result<(AccessModeSelectorHandle, AccessModeConnectionModeProvider)> { let (cmd_tx, cmd_rx) = mpsc::unbounded(); @@ -277,7 +273,6 @@ impl<B: AccessMethodResolver + 'static> AccessModeSelector<B> { cmd_rx, method_resolver, access_method_settings, - #[cfg(not(target_os = "ios"))] access_method_event_sender, connection_mode_provider_sender: change_tx, current: initial_connection_mode, @@ -385,13 +380,7 @@ impl<B: AccessMethodResolver + 'static> AccessModeSelector<B> { async fn set_current(&mut self, access_method: AccessMethodSetting) { let resolved = Self::resolve_with_default(&access_method, &mut self.method_resolver).await; - #[cfg(not(target_os = "ios"))] - self.notify_daemon(&resolved); - - // Notify REST client - let _ = self - .connection_mode_provider_sender - .unbounded_send(resolved.connection_mode.clone()); + self.notify_connection_mode(resolved.clone()); self.current = resolved; @@ -401,8 +390,7 @@ impl<B: AccessMethodResolver + 'static> AccessModeSelector<B> { ); } - #[cfg(not(target_os = "ios"))] - fn notify_daemon(&mut self, resolved: &ResolvedConnectionMode) { + fn notify_connection_mode(&mut self, resolved: ResolvedConnectionMode) { // Note: If the daemon is busy waiting for a call to this function // to complete while we wait for the daemon to fully handle this // `NewAccessMethodEvent`, then we find ourselves in a deadlock. @@ -410,21 +398,21 @@ impl<B: AccessMethodResolver + 'static> AccessModeSelector<B> { // `MullvadRestHandle`, which will call and await `next` on a Stream // created from this `AccessModeSelector` instance. As such, the // completion channel is discarded in this instance. - let setting = resolved.setting.clone(); - #[cfg(not(target_os = "android"))] - let endpoint = resolved.endpoint.clone(); + let access_method_event = AccessMethodEvent::New { + setting: resolved.setting, + connection_mode: resolved.connection_mode.clone(), + #[cfg(not(target_os = "android"))] + endpoint: resolved.endpoint, + }; let sender = self.access_method_event_sender.clone(); - let connection_mode = resolved.connection_mode.clone(); tokio::spawn(async move { - let _ = AccessMethodEvent::New { - setting, - connection_mode, - #[cfg(not(target_os = "android"))] - endpoint, - } - .send(sender) - .await; + let _ = access_method_event.send(sender).await; }); + + // Notify REST client + let _ = self + .connection_mode_provider_sender + .unbounded_send(resolved.connection_mode); } /// Find the next access method to use. diff --git a/mullvad-ios/src/api_client/mod.rs b/mullvad-ios/src/api_client/mod.rs index dfa25d0b44..e47385717a 100644 --- a/mullvad-ios/src/api_client/mod.rs +++ b/mullvad-ios/src/api_client/mod.rs @@ -1,12 +1,16 @@ -use std::{ffi::c_char, future::Future, sync::Arc}; +use std::{ffi::c_char, ffi::c_void, future::Future, sync::Arc}; use crate::get_string; use access_method_resolver::SwiftAccessMethodResolver; use access_method_settings::SwiftAccessMethodSettingsWrapper; use address_cache_provider::SwiftAddressCacheWrapper; +use futures::{ + StreamExt, + channel::{mpsc, oneshot}, +}; use mullvad_api::{ ApiEndpoint, Runtime, - access_mode::{AccessModeSelector, AccessModeSelectorHandle}, + access_mode::{AccessMethodEvent, AccessModeSelector, AccessModeSelectorHandle}, rest::{self, MullvadRestHandle}, }; use mullvad_encrypted_dns_proxy::state::EncryptedDnsProxyState; @@ -85,6 +89,13 @@ impl ApiContext { } } +/// An opaque pointer that exists only to be passed from the caller to a callback through the ABI +struct ForeignPtr { + ptr: *const c_void, +} +/// allow this to be passed across thread boundaries +unsafe impl Send for ForeignPtr {} + /// Called by Swift to set the available access methods #[unsafe(no_mangle)] pub unsafe extern "C" fn mullvad_api_update_access_methods( @@ -138,6 +149,8 @@ pub extern "C" fn mullvad_api_init_new_tls_disabled( bridge_provider: SwiftShadowsocksLoaderWrapper, settings_provider: SwiftAccessMethodSettingsWrapper, address_cache: SwiftAddressCacheWrapper, + access_method_change_callback: Option<unsafe extern "C" fn(*const c_void, *const u8)>, + access_method_change_context: *const c_void, ) -> SwiftApiContext { mullvad_api_init_inner( host, @@ -147,6 +160,8 @@ pub extern "C" fn mullvad_api_init_new_tls_disabled( bridge_provider, settings_provider, address_cache, + access_method_change_callback, + access_method_change_context, ) } @@ -158,6 +173,14 @@ pub extern "C" fn mullvad_api_init_new_tls_disabled( /// `address` must be a pointer to a null terminated string representing a socket address through which /// the Mullvad API can be reached directly. /// +/// address_method_change_callback is a function with the C calling convention which will be called +/// whenever the access method changes with a user-specified opaque pointer and a pointer to the bytes +/// of the access method's UUID. Note that this callback must remain valid for the lifetime of the +/// program. +/// +/// access_method_change_context is the pointer passed verbatim to the callback. It is not dereferenced +/// by the Rust code, but remains opaque. +/// /// If a context cannot be constructed this function will panic since the call site would not be able /// to proceed in a meaningful way anyway. /// @@ -170,6 +193,8 @@ pub extern "C" fn mullvad_api_init_new( bridge_provider: SwiftShadowsocksLoaderWrapper, settings_provider: SwiftAccessMethodSettingsWrapper, address_cache: SwiftAddressCacheWrapper, + access_method_change_callback: Option<unsafe extern "C" fn(*const c_void, *const u8)>, + access_method_change_context: *const c_void, ) -> SwiftApiContext { #[cfg(feature = "api-override")] return mullvad_api_init_inner( @@ -180,6 +205,8 @@ pub extern "C" fn mullvad_api_init_new( bridge_provider, settings_provider, address_cache, + access_method_change_callback, + access_method_change_context, ); #[cfg(not(feature = "api-override"))] mullvad_api_init_inner( @@ -189,6 +216,8 @@ pub extern "C" fn mullvad_api_init_new( bridge_provider, settings_provider, address_cache, + access_method_change_callback, + access_method_change_context, ) } @@ -213,6 +242,8 @@ pub extern "C" fn mullvad_api_init_inner( bridge_provider: SwiftShadowsocksLoaderWrapper, settings_provider: SwiftAccessMethodSettingsWrapper, address_cache: SwiftAddressCacheWrapper, + access_method_change_callback: Option<unsafe extern "C" fn(*const c_void, *const u8)>, + access_method_change_context: *const c_void, ) -> SwiftApiContext { // Safety: See notes for `get_string` let (host, address, domain) = @@ -245,16 +276,42 @@ pub extern "C" fn mullvad_api_init_inner( address_cache, ); + let access_method_change_ctx: ForeignPtr = ForeignPtr { + ptr: access_method_change_context, + }; let api_context = tokio_handle.clone().block_on(async move { + let (tx, mut rx) = mpsc::unbounded::<(AccessMethodEvent, oneshot::Sender<()>)>(); let (access_mode_handler, access_mode_provider) = AccessModeSelector::spawn( method_resolver, access_method_settings, #[cfg(feature = "api-override")] endpoint.clone(), + tx, ) .await .expect("Could now spawn AccessModeSelector"); + // SAFETY: The callback is expected to be called from the Swift side + if let Some(callback) = access_method_change_callback { + tokio::spawn(async move { + let access_method_change_ctx = access_method_change_ctx; + while let Some((event, _sender)) = rx.next().await { + let AccessMethodEvent::New { + setting, + connection_mode: _, + endpoint: _, + } = event + else { + continue; + }; + let uuid = setting.get_id(); + let uuid_bytes = uuid.as_bytes(); + // SAFETY: The callback is expected to be safe to call + unsafe { callback(access_method_change_ctx.ptr, uuid_bytes.as_ptr()) }; + } + }); + } + // It is imperative that the REST runtime is created within an async context, otherwise // ApiAvailability panics. let api_client = mullvad_api::Runtime::new(tokio_handle, &endpoint); diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs index 4f1229d126..f823e5d83d 100644 --- a/mullvad-types/src/access_method.rs +++ b/mullvad-types/src/access_method.rs @@ -195,6 +195,10 @@ impl Id { use std::str::FromStr; uuid::Uuid::from_str(&id).ok().map(Self) } + + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } } impl std::fmt::Display for Id { |
