summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-11-26 10:43:32 +0100
committerDavid Lönnhager <david.l@mullvad.net>2021-11-26 10:43:32 +0100
commit1ed0cc515f5da09dcf04e79af2c50757484a2fe8 (patch)
treec3c2eb8cd31abf59b00ddbf8ad02fefc79f403f7
parent9bfca257349dad9cde2aac9b1516e71f33f25a23 (diff)
parent8d1144c76a81d405f2c032b9b5333ed293142f3b (diff)
downloadmullvadvpn-1ed0cc515f5da09dcf04e79af2c50757484a2fe8.tar.xz
mullvadvpn-1ed0cc515f5da09dcf04e79af2c50757484a2fe8.zip
Merge branch 'remove-tokio-handles'
-rw-r--r--mullvad-daemon/src/lib.rs9
-rw-r--r--mullvad-problem-report/src/lib.rs1
-rw-r--r--mullvad-rpc/src/lib.rs2
-rw-r--r--mullvad-setup/src/main.rs11
-rw-r--r--talpid-core/src/dns/macos.rs4
-rw-r--r--talpid-core/src/routing/unix.rs15
-rw-r--r--talpid-core/src/routing/windows.rs14
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs42
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs18
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs57
10 files changed, 77 insertions, 96 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 9d64a67bea..80a8061fc8 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -629,7 +629,6 @@ where
};
let mut rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
- runtime.clone(),
Some(&resource_dir),
&cache_dir,
true,
@@ -649,7 +648,6 @@ where
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
let tunnel_command_tx = tunnel_state_machine::spawn(
- runtime.clone(),
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
@@ -691,7 +689,7 @@ where
let rpc_handle = rpc_runtime.mullvad_rest_handle();
- Self::forward_offline_state(&runtime, api_availability.clone(), offline_state_rx).await;
+ Self::forward_offline_state(api_availability.clone(), offline_state_rx).await;
let relay_list_listener = event_listener.clone();
let on_relay_list_update = move |relay_list: &RelayList| {
@@ -2427,7 +2425,7 @@ where
) -> Option<mpsc::Sender<mullvad_rpc::SocketBypassRequest>> {
let (bypass_tx, mut bypass_rx) = mpsc::channel(1);
let daemon_tx = event_sender.to_specialized_sender();
- tokio::runtime::Handle::current().spawn(async move {
+ tokio::spawn(async move {
while let Some((raw_fd, done_tx)) = bypass_rx.next().await {
if let Err(_) = daemon_tx.send(DaemonCommand::BypassSocket(raw_fd, done_tx)) {
log::error!("Can't send socket bypass request to daemon");
@@ -2439,7 +2437,6 @@ where
}
async fn forward_offline_state(
- runtime: &tokio::runtime::Handle,
api_availability: ApiAvailabilityHandle,
mut offline_state_rx: mpsc::UnboundedReceiver<bool>,
) {
@@ -2448,7 +2445,7 @@ where
.await
.expect("missing initial offline state");
api_availability.set_offline(initial_state);
- runtime.spawn(async move {
+ tokio::spawn(async move {
while let Some(is_offline) = offline_state_rx.next().await {
api_availability.set_offline(is_offline);
}
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 9a5bca1a61..56bcb0cdfb 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -279,7 +279,6 @@ pub fn send_problem_report(
let mut rpc_manager = runtime
.block_on(mullvad_rpc::MullvadRpcRuntime::with_cache(
- runtime.handle().clone(),
None,
cache_dir,
false,
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 61654cab1c..956db21e9d 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -128,12 +128,12 @@ impl MullvadRpcRuntime {
/// Try to use the cache directory first, and fall back on the resource directory
/// if it fails.
pub async fn with_cache(
- handle: tokio::runtime::Handle,
resource_dir: Option<&Path>,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
+ let handle = tokio::runtime::Handle::current();
#[cfg(feature = "api-override")]
if *DISABLE_ADDRESS_ROTATION {
return Self::new_inner(
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index eb09513a1f..8b9df4b154 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -173,14 +173,9 @@ async fn remove_wireguard_key() -> Result<(), Error> {
if let Some(token) = settings.get_account_token() {
if let Some(wg_data) = settings.get_wireguard() {
- let mut rpc_runtime = MullvadRpcRuntime::with_cache(
- tokio::runtime::Handle::current(),
- None,
- &cache_path,
- false,
- )
- .await
- .map_err(Error::RpcInitializationError)?;
+ let mut rpc_runtime = MullvadRpcRuntime::with_cache(None, &cache_path, false)
+ .await
+ .map_err(Error::RpcInitializationError)?;
let mut key_proxy =
mullvad_rpc::WireguardKeyProxy::new(rpc_runtime.mullvad_rest_handle());
retry_future_n(
diff --git a/talpid-core/src/dns/macos.rs b/talpid-core/src/dns/macos.rs
index b58e47eada..efb4e8c816 100644
--- a/talpid-core/src/dns/macos.rs
+++ b/talpid-core/src/dns/macos.rs
@@ -132,6 +132,10 @@ pub struct DnsMonitor {
state: Arc<Mutex<Option<State>>>,
}
+/// SAFETY: The `SCDynamicStore` can be sent to other threads since it doesn't share mutable state
+/// with anything else.
+unsafe impl Send for DnsMonitor {}
+
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index 989ec7ad24..aa1cf66f6b 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -173,16 +173,13 @@ impl RouteManager {
/// Constructs a RouteManager and applies the required routes.
/// Takes a set of network destinations and network nodes as an argument, and applies said
/// routes.
- pub async fn new(
- runtime: tokio::runtime::Handle,
- required_routes: HashSet<RequiredRoute>,
- ) -> Result<Self, Error> {
+ pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> {
let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = imp::RouteManagerImpl::new(required_routes).await?;
- runtime.spawn(manager.run(manage_rx));
+ tokio::spawn(manager.run(manage_rx));
Ok(Self {
- runtime,
+ runtime: tokio::runtime::Handle::current(),
manage_tx: Some(manage_tx),
})
}
@@ -259,12 +256,6 @@ impl RouteManager {
Err(Error::RouteManagerDown)
}
}
-
- /// Exposes runtime handle
- #[cfg(target_os = "linux")]
- pub fn runtime_handle(&self) -> tokio::runtime::Handle {
- self.runtime.clone()
- }
}
impl Drop for RouteManager {
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index be812ff99a..ec17d4feae 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -36,7 +36,6 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Manages routes by calling into WinNet
pub struct RouteManager {
- runtime: tokio::runtime::Handle,
manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
}
@@ -66,19 +65,15 @@ pub enum RouteManagerCommand {
impl RouteManager {
/// Creates a new route manager that will apply the provided routes and ensure they exist until
/// it's stopped.
- pub async fn new(
- runtime: tokio::runtime::Handle,
- required_routes: HashSet<RequiredRoute>,
- ) -> Result<Self> {
+ pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
if !winnet::activate_routing_manager() {
return Err(Error::FailedToStartManager);
}
let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = Self {
- runtime: runtime.clone(),
manage_tx: Some(manage_tx),
};
- runtime.spawn(RouteManager::listen(manage_rx));
+ tokio::spawn(RouteManager::listen(manage_rx));
manager.add_routes(required_routes).await?;
Ok(manager)
@@ -93,11 +88,6 @@ impl RouteManager {
}
}
- /// Retrieve handle for the tokio runtime.
- pub fn runtime_handle(&self) -> tokio::runtime::Handle {
- self.runtime.clone()
- }
-
async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) {
while let Some(command) = manage_rx.next().await {
match command {
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index d9aa5cb286..d541d7c3f2 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -100,12 +100,36 @@ pub struct SplitTunnel {
runtime: tokio::runtime::Handle,
request_tx: RequestTx,
event_thread: Option<std::thread::JoinHandle<()>>,
- quit_event: RawHandle,
+ quit_event: Arc<QuitEvent>,
_route_change_callback: Option<WinNetCallbackHandle>,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
async_path_update_in_progress: Arc<AtomicBool>,
}
+struct QuitEvent(RawHandle);
+
+unsafe impl Send for QuitEvent {}
+unsafe impl Sync for QuitEvent {}
+
+impl QuitEvent {
+ fn new() -> Self {
+ Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) })
+ }
+
+ fn set_event(&self) -> io::Result<()> {
+ if unsafe { SetEvent(self.0) } == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+}
+
+impl Drop for QuitEvent {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.0) };
+ }
+}
+
enum Request {
SetPaths(Vec<OsString>),
RegisterIps(
@@ -123,7 +147,7 @@ const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
struct EventThreadContext {
handle: Arc<driver::DeviceHandle>,
event_overlapped: OVERLAPPED,
- quit_event: RawHandle,
+ quit_event: Arc<QuitEvent>,
}
unsafe impl Send for EventThreadContext {}
@@ -142,12 +166,12 @@ impl SplitTunnel {
return Err(Error::EventThreadError(io::Error::last_os_error()));
}
- let quit_event = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+ let quit_event = Arc::new(QuitEvent::new());
let event_context = EventThreadContext {
handle: handle.clone(),
event_overlapped,
- quit_event,
+ quit_event: quit_event.clone(),
};
let event_thread = std::thread::spawn(move || {
@@ -161,11 +185,11 @@ impl SplitTunnel {
let event_objects = [
event_context.event_overlapped.hEvent,
- event_context.quit_event,
+ event_context.quit_event.0,
];
loop {
- if unsafe { WaitForSingleObject(event_context.quit_event, 0) == WAIT_OBJECT_0 } {
+ if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } {
// Quit event was signaled
break;
}
@@ -213,7 +237,7 @@ impl SplitTunnel {
continue;
};
- if event_context.quit_event == event_objects[signaled_index as usize] {
+ if event_context.quit_event.0 == event_objects[signaled_index as usize] {
// Quit event was signaled
break;
}
@@ -288,7 +312,6 @@ impl SplitTunnel {
log::debug!("Stopping split tunnel event thread");
unsafe { CloseHandle(event_context.event_overlapped.hEvent) };
- unsafe { CloseHandle(event_context.quit_event) };
});
Ok(SplitTunnel {
@@ -520,8 +543,7 @@ impl SplitTunnel {
impl Drop for SplitTunnel {
fn drop(&mut self) {
if let Some(_event_thread) = self.event_thread.take() {
- if unsafe { SetEvent(self.quit_event) } == 0 {
- let error = io::Error::last_os_error();
+ if let Err(error) = self.quit_event.set_event() {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close ST event thread")
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index a3fa426ca2..e2889166a7 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -184,8 +184,14 @@ impl WireguardMonitor {
}
}
- let tunnel =
- Self::open_tunnel(&config, log_path, resource_dir, tun_provider, route_manager)?;
+ let tunnel = Self::open_tunnel(
+ runtime.clone(),
+ &config,
+ log_path,
+ resource_dir,
+ tun_provider,
+ route_manager,
+ )?;
let iface_name = tunnel.get_interface_name().to_string();
#[cfg(windows)]
let iface_luid = tunnel.get_interface_luid();
@@ -310,6 +316,7 @@ impl WireguardMonitor {
#[allow(unused_variables)]
fn open_tunnel(
+ runtime: tokio::runtime::Handle,
config: &Config,
log_path: Option<&Path>,
resource_dir: &Path,
@@ -319,10 +326,7 @@ impl WireguardMonitor {
#[cfg(target_os = "linux")]
if !*FORCE_USERSPACE_WIREGUARD {
if crate::dns::will_use_nm() {
- match wireguard_kernel::NetworkManagerTunnel::new(
- route_manager.runtime_handle(),
- config,
- ) {
+ match wireguard_kernel::NetworkManagerTunnel::new(runtime, config) {
Ok(tunnel) => {
log::debug!("Using NetworkManager to use kernel WireGuard implementation");
return Ok(Box::new(tunnel));
@@ -337,7 +341,7 @@ impl WireguardMonitor {
}
};
} else {
- match wireguard_kernel::NetlinkTunnel::new(route_manager.runtime_handle(), config) {
+ match wireguard_kernel::NetlinkTunnel::new(runtime, config) {
Ok(tunnel) => {
log::debug!("Using kernel WireGuard implementation");
return Ok(Box::new(tunnel));
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index fbc3e05622..ef80293399 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -30,13 +30,7 @@ use futures::{
};
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
-use std::{
- collections::HashSet,
- io,
- net::IpAddr,
- path::PathBuf,
- sync::{mpsc as sync_mpsc, Arc},
-};
+use std::{collections::HashSet, io, net::IpAddr, path::PathBuf, sync::Arc};
#[cfg(target_os = "android")]
use talpid_types::{android::AndroidContext, ErrorExt};
use talpid_types::{
@@ -97,7 +91,6 @@ pub struct InitialTunnelState {
/// Spawn the tunnel state machine thread, returning a channel for sending tunnel commands.
pub async fn spawn(
- runtime: tokio::runtime::Handle,
initial_settings: InitialTunnelState,
tunnel_parameters_generator: impl TunnelParametersGenerator,
log_dir: Option<PathBuf>,
@@ -121,43 +114,28 @@ pub async fn spawn(
initial_settings.dns_servers.clone(),
);
- let (startup_result_tx, startup_result_rx) = sync_mpsc::channel();
let weak_command_tx = Arc::downgrade(&command_tx);
- std::thread::spawn(move || {
- let state_machine = runtime.block_on(TunnelStateMachine::new(
- runtime.clone(),
- initial_settings,
- weak_command_tx,
- offline_state_listener,
- tunnel_parameters_generator,
- tun_provider,
- log_dir,
- resource_dir,
- command_rx,
- #[cfg(target_os = "android")]
- android_context,
- ));
- let state_machine = match state_machine {
- Ok(state_machine) => {
- startup_result_tx.send(Ok(())).unwrap();
- state_machine
- }
- Err(error) => {
- startup_result_tx.send(Err(error)).unwrap();
- return;
- }
- };
+ let state_machine = TunnelStateMachine::new(
+ initial_settings,
+ weak_command_tx,
+ offline_state_listener,
+ tunnel_parameters_generator,
+ tun_provider,
+ log_dir,
+ resource_dir,
+ command_rx,
+ #[cfg(target_os = "android")]
+ android_context,
+ )
+ .await?;
+ tokio::task::spawn_blocking(move || {
state_machine.run(state_change_listener);
-
if shutdown_tx.send(()).is_err() {
log::error!("Can't send shutdown completion to daemon");
}
});
- startup_result_rx
- .recv()
- .expect("Failed to start tunnel state machine thread")?;
Ok(command_tx)
}
@@ -213,7 +191,6 @@ struct TunnelStateMachine {
impl TunnelStateMachine {
async fn new(
- runtime: tokio::runtime::Handle,
settings: InitialTunnelState,
command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
offline_state_tx: mpsc::UnboundedSender<bool>,
@@ -224,6 +201,8 @@ impl TunnelStateMachine {
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Self, Error> {
+ let runtime = tokio::runtime::Handle::current();
+
#[cfg(windows)]
let split_tunnel = split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone())
.map_err(Error::InitSplitTunneling)?;
@@ -235,7 +214,7 @@ impl TunnelStateMachine {
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
- let route_manager = RouteManager::new(runtime.clone(), HashSet::new())
+ let route_manager = RouteManager::new(HashSet::new())
.await
.map_err(Error::InitRouteManagerError)?;
let dns_monitor = DnsMonitor::new(