diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-11-26 10:43:32 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-11-26 10:43:32 +0100 |
| commit | 1ed0cc515f5da09dcf04e79af2c50757484a2fe8 (patch) | |
| tree | c3c2eb8cd31abf59b00ddbf8ad02fefc79f403f7 | |
| parent | 9bfca257349dad9cde2aac9b1516e71f33f25a23 (diff) | |
| parent | 8d1144c76a81d405f2c032b9b5333ed293142f3b (diff) | |
| download | mullvadvpn-1ed0cc515f5da09dcf04e79af2c50757484a2fe8.tar.xz mullvadvpn-1ed0cc515f5da09dcf04e79af2c50757484a2fe8.zip | |
Merge branch 'remove-tokio-handles'
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 9 | ||||
| -rw-r--r-- | mullvad-problem-report/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 2 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 11 | ||||
| -rw-r--r-- | talpid-core/src/dns/macos.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/routing/unix.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 14 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 42 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 18 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 57 |
10 files changed, 77 insertions, 96 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 9d64a67bea..80a8061fc8 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -629,7 +629,6 @@ where }; let mut rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( - runtime.clone(), Some(&resource_dir), &cache_dir, true, @@ -649,7 +648,6 @@ where let (offline_state_tx, offline_state_rx) = mpsc::unbounded(); let tunnel_command_tx = tunnel_state_machine::spawn( - runtime.clone(), tunnel_state_machine::InitialTunnelState { allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, @@ -691,7 +689,7 @@ where let rpc_handle = rpc_runtime.mullvad_rest_handle(); - Self::forward_offline_state(&runtime, api_availability.clone(), offline_state_rx).await; + Self::forward_offline_state(api_availability.clone(), offline_state_rx).await; let relay_list_listener = event_listener.clone(); let on_relay_list_update = move |relay_list: &RelayList| { @@ -2427,7 +2425,7 @@ where ) -> Option<mpsc::Sender<mullvad_rpc::SocketBypassRequest>> { let (bypass_tx, mut bypass_rx) = mpsc::channel(1); let daemon_tx = event_sender.to_specialized_sender(); - tokio::runtime::Handle::current().spawn(async move { + tokio::spawn(async move { while let Some((raw_fd, done_tx)) = bypass_rx.next().await { if let Err(_) = daemon_tx.send(DaemonCommand::BypassSocket(raw_fd, done_tx)) { log::error!("Can't send socket bypass request to daemon"); @@ -2439,7 +2437,6 @@ where } async fn forward_offline_state( - runtime: &tokio::runtime::Handle, api_availability: ApiAvailabilityHandle, mut offline_state_rx: mpsc::UnboundedReceiver<bool>, ) { @@ -2448,7 +2445,7 @@ where .await .expect("missing initial offline state"); api_availability.set_offline(initial_state); - runtime.spawn(async move { + tokio::spawn(async move { while let Some(is_offline) = offline_state_rx.next().await { api_availability.set_offline(is_offline); } diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 9a5bca1a61..56bcb0cdfb 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -279,7 +279,6 @@ pub fn send_problem_report( let mut rpc_manager = runtime .block_on(mullvad_rpc::MullvadRpcRuntime::with_cache( - runtime.handle().clone(), None, cache_dir, false, diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 61654cab1c..956db21e9d 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -128,12 +128,12 @@ impl MullvadRpcRuntime { /// Try to use the cache directory first, and fall back on the resource directory /// if it fails. pub async fn with_cache( - handle: tokio::runtime::Handle, resource_dir: Option<&Path>, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { + let handle = tokio::runtime::Handle::current(); #[cfg(feature = "api-override")] if *DISABLE_ADDRESS_ROTATION { return Self::new_inner( diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index eb09513a1f..8b9df4b154 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -173,14 +173,9 @@ async fn remove_wireguard_key() -> Result<(), Error> { if let Some(token) = settings.get_account_token() { if let Some(wg_data) = settings.get_wireguard() { - let mut rpc_runtime = MullvadRpcRuntime::with_cache( - tokio::runtime::Handle::current(), - None, - &cache_path, - false, - ) - .await - .map_err(Error::RpcInitializationError)?; + let mut rpc_runtime = MullvadRpcRuntime::with_cache(None, &cache_path, false) + .await + .map_err(Error::RpcInitializationError)?; let mut key_proxy = mullvad_rpc::WireguardKeyProxy::new(rpc_runtime.mullvad_rest_handle()); retry_future_n( diff --git a/talpid-core/src/dns/macos.rs b/talpid-core/src/dns/macos.rs index b58e47eada..efb4e8c816 100644 --- a/talpid-core/src/dns/macos.rs +++ b/talpid-core/src/dns/macos.rs @@ -132,6 +132,10 @@ pub struct DnsMonitor { state: Arc<Mutex<Option<State>>>, } +/// SAFETY: The `SCDynamicStore` can be sent to other threads since it doesn't share mutable state +/// with anything else. +unsafe impl Send for DnsMonitor {} + impl super::DnsMonitorT for DnsMonitor { type Error = Error; diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index 989ec7ad24..aa1cf66f6b 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -173,16 +173,13 @@ impl RouteManager { /// Constructs a RouteManager and applies the required routes. /// Takes a set of network destinations and network nodes as an argument, and applies said /// routes. - pub async fn new( - runtime: tokio::runtime::Handle, - required_routes: HashSet<RequiredRoute>, - ) -> Result<Self, Error> { + pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> { let (manage_tx, manage_rx) = mpsc::unbounded(); let manager = imp::RouteManagerImpl::new(required_routes).await?; - runtime.spawn(manager.run(manage_rx)); + tokio::spawn(manager.run(manage_rx)); Ok(Self { - runtime, + runtime: tokio::runtime::Handle::current(), manage_tx: Some(manage_tx), }) } @@ -259,12 +256,6 @@ impl RouteManager { Err(Error::RouteManagerDown) } } - - /// Exposes runtime handle - #[cfg(target_os = "linux")] - pub fn runtime_handle(&self) -> tokio::runtime::Handle { - self.runtime.clone() - } } impl Drop for RouteManager { diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index be812ff99a..ec17d4feae 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -36,7 +36,6 @@ pub type Result<T> = std::result::Result<T, Error>; /// Manages routes by calling into WinNet pub struct RouteManager { - runtime: tokio::runtime::Handle, manage_tx: Option<UnboundedSender<RouteManagerCommand>>, } @@ -66,19 +65,15 @@ pub enum RouteManagerCommand { impl RouteManager { /// Creates a new route manager that will apply the provided routes and ensure they exist until /// it's stopped. - pub async fn new( - runtime: tokio::runtime::Handle, - required_routes: HashSet<RequiredRoute>, - ) -> Result<Self> { + pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { if !winnet::activate_routing_manager() { return Err(Error::FailedToStartManager); } let (manage_tx, manage_rx) = mpsc::unbounded(); let manager = Self { - runtime: runtime.clone(), manage_tx: Some(manage_tx), }; - runtime.spawn(RouteManager::listen(manage_rx)); + tokio::spawn(RouteManager::listen(manage_rx)); manager.add_routes(required_routes).await?; Ok(manager) @@ -93,11 +88,6 @@ impl RouteManager { } } - /// Retrieve handle for the tokio runtime. - pub fn runtime_handle(&self) -> tokio::runtime::Handle { - self.runtime.clone() - } - async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) { while let Some(command) = manage_rx.next().await { match command { diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index d9aa5cb286..d541d7c3f2 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -100,12 +100,36 @@ pub struct SplitTunnel { runtime: tokio::runtime::Handle, request_tx: RequestTx, event_thread: Option<std::thread::JoinHandle<()>>, - quit_event: RawHandle, + quit_event: Arc<QuitEvent>, _route_change_callback: Option<WinNetCallbackHandle>, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, async_path_update_in_progress: Arc<AtomicBool>, } +struct QuitEvent(RawHandle); + +unsafe impl Send for QuitEvent {} +unsafe impl Sync for QuitEvent {} + +impl QuitEvent { + fn new() -> Self { + Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }) + } + + fn set_event(&self) -> io::Result<()> { + if unsafe { SetEvent(self.0) } == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } +} + +impl Drop for QuitEvent { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + enum Request { SetPaths(Vec<OsString>), RegisterIps( @@ -123,7 +147,7 @@ const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); struct EventThreadContext { handle: Arc<driver::DeviceHandle>, event_overlapped: OVERLAPPED, - quit_event: RawHandle, + quit_event: Arc<QuitEvent>, } unsafe impl Send for EventThreadContext {} @@ -142,12 +166,12 @@ impl SplitTunnel { return Err(Error::EventThreadError(io::Error::last_os_error())); } - let quit_event = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; + let quit_event = Arc::new(QuitEvent::new()); let event_context = EventThreadContext { handle: handle.clone(), event_overlapped, - quit_event, + quit_event: quit_event.clone(), }; let event_thread = std::thread::spawn(move || { @@ -161,11 +185,11 @@ impl SplitTunnel { let event_objects = [ event_context.event_overlapped.hEvent, - event_context.quit_event, + event_context.quit_event.0, ]; loop { - if unsafe { WaitForSingleObject(event_context.quit_event, 0) == WAIT_OBJECT_0 } { + if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } { // Quit event was signaled break; } @@ -213,7 +237,7 @@ impl SplitTunnel { continue; }; - if event_context.quit_event == event_objects[signaled_index as usize] { + if event_context.quit_event.0 == event_objects[signaled_index as usize] { // Quit event was signaled break; } @@ -288,7 +312,6 @@ impl SplitTunnel { log::debug!("Stopping split tunnel event thread"); unsafe { CloseHandle(event_context.event_overlapped.hEvent) }; - unsafe { CloseHandle(event_context.quit_event) }; }); Ok(SplitTunnel { @@ -520,8 +543,7 @@ impl SplitTunnel { impl Drop for SplitTunnel { fn drop(&mut self) { if let Some(_event_thread) = self.event_thread.take() { - if unsafe { SetEvent(self.quit_event) } == 0 { - let error = io::Error::last_os_error(); + if let Err(error) = self.quit_event.set_event() { log::error!( "{}", error.display_chain_with_msg("Failed to close ST event thread") diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index a3fa426ca2..e2889166a7 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -184,8 +184,14 @@ impl WireguardMonitor { } } - let tunnel = - Self::open_tunnel(&config, log_path, resource_dir, tun_provider, route_manager)?; + let tunnel = Self::open_tunnel( + runtime.clone(), + &config, + log_path, + resource_dir, + tun_provider, + route_manager, + )?; let iface_name = tunnel.get_interface_name().to_string(); #[cfg(windows)] let iface_luid = tunnel.get_interface_luid(); @@ -310,6 +316,7 @@ impl WireguardMonitor { #[allow(unused_variables)] fn open_tunnel( + runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, resource_dir: &Path, @@ -319,10 +326,7 @@ impl WireguardMonitor { #[cfg(target_os = "linux")] if !*FORCE_USERSPACE_WIREGUARD { if crate::dns::will_use_nm() { - match wireguard_kernel::NetworkManagerTunnel::new( - route_manager.runtime_handle(), - config, - ) { + match wireguard_kernel::NetworkManagerTunnel::new(runtime, config) { Ok(tunnel) => { log::debug!("Using NetworkManager to use kernel WireGuard implementation"); return Ok(Box::new(tunnel)); @@ -337,7 +341,7 @@ impl WireguardMonitor { } }; } else { - match wireguard_kernel::NetlinkTunnel::new(route_manager.runtime_handle(), config) { + match wireguard_kernel::NetlinkTunnel::new(runtime, config) { Ok(tunnel) => { log::debug!("Using kernel WireGuard implementation"); return Ok(Box::new(tunnel)); diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index fbc3e05622..ef80293399 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -30,13 +30,7 @@ use futures::{ }; #[cfg(target_os = "android")] use std::os::unix::io::RawFd; -use std::{ - collections::HashSet, - io, - net::IpAddr, - path::PathBuf, - sync::{mpsc as sync_mpsc, Arc}, -}; +use std::{collections::HashSet, io, net::IpAddr, path::PathBuf, sync::Arc}; #[cfg(target_os = "android")] use talpid_types::{android::AndroidContext, ErrorExt}; use talpid_types::{ @@ -97,7 +91,6 @@ pub struct InitialTunnelState { /// Spawn the tunnel state machine thread, returning a channel for sending tunnel commands. pub async fn spawn( - runtime: tokio::runtime::Handle, initial_settings: InitialTunnelState, tunnel_parameters_generator: impl TunnelParametersGenerator, log_dir: Option<PathBuf>, @@ -121,43 +114,28 @@ pub async fn spawn( initial_settings.dns_servers.clone(), ); - let (startup_result_tx, startup_result_rx) = sync_mpsc::channel(); let weak_command_tx = Arc::downgrade(&command_tx); - std::thread::spawn(move || { - let state_machine = runtime.block_on(TunnelStateMachine::new( - runtime.clone(), - initial_settings, - weak_command_tx, - offline_state_listener, - tunnel_parameters_generator, - tun_provider, - log_dir, - resource_dir, - command_rx, - #[cfg(target_os = "android")] - android_context, - )); - let state_machine = match state_machine { - Ok(state_machine) => { - startup_result_tx.send(Ok(())).unwrap(); - state_machine - } - Err(error) => { - startup_result_tx.send(Err(error)).unwrap(); - return; - } - }; + let state_machine = TunnelStateMachine::new( + initial_settings, + weak_command_tx, + offline_state_listener, + tunnel_parameters_generator, + tun_provider, + log_dir, + resource_dir, + command_rx, + #[cfg(target_os = "android")] + android_context, + ) + .await?; + tokio::task::spawn_blocking(move || { state_machine.run(state_change_listener); - if shutdown_tx.send(()).is_err() { log::error!("Can't send shutdown completion to daemon"); } }); - startup_result_rx - .recv() - .expect("Failed to start tunnel state machine thread")?; Ok(command_tx) } @@ -213,7 +191,6 @@ struct TunnelStateMachine { impl TunnelStateMachine { async fn new( - runtime: tokio::runtime::Handle, settings: InitialTunnelState, command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, offline_state_tx: mpsc::UnboundedSender<bool>, @@ -224,6 +201,8 @@ impl TunnelStateMachine { commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Self, Error> { + let runtime = tokio::runtime::Handle::current(); + #[cfg(windows)] let split_tunnel = split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone()) .map_err(Error::InitSplitTunneling)?; @@ -235,7 +214,7 @@ impl TunnelStateMachine { }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; - let route_manager = RouteManager::new(runtime.clone(), HashSet::new()) + let route_manager = RouteManager::new(HashSet::new()) .await .map_err(Error::InitRouteManagerError)?; let dns_monitor = DnsMonitor::new( |
