summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-12-02 14:05:56 +0100
committerSebastian Holmin <sebastian.holmin@mullvad.net>2024-12-02 14:05:56 +0100
commit3ed024c53913678e1a557bd66b5bef5be774367e (patch)
tree7de740e7e03b2ceac92c41b3e38dc18e6e81df8d
parentdb019b768532baecea3f32a0bb5679c8b551074b (diff)
parentca17aee44561ebd23321c82f1f0d05476d96d49d (diff)
downloadmullvadvpn-3ed024c53913678e1a557bd66b5bef5be774367e.tar.xz
mullvadvpn-3ed024c53913678e1a557bd66b5bef5be774367e.zip
Merge branch 'remove-event-callback'
-rw-r--r--talpid-core/src/tunnel/mod.rs51
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs15
-rw-r--r--talpid-openvpn/src/lib.rs83
-rw-r--r--talpid-tunnel/src/lib.rs33
-rw-r--r--talpid-wireguard/src/lib.rs77
5 files changed, 115 insertions, 144 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index db2d83346a..63bd01c57a 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -9,8 +9,13 @@ use talpid_tunnel::tun_provider;
pub use talpid_tunnel::{TunnelArgs, TunnelEvent, TunnelMetadata};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
-use talpid_types::net::{wireguard as wireguard_types, TunnelParameters};
-use talpid_types::tunnel::ErrorStateCause;
+use talpid_types::{
+ net::{wireguard as wireguard_types, TunnelParameters},
+ tunnel::ErrorStateCause,
+};
+
+#[cfg(not(target_os = "android"))]
+use talpid_tunnel::EventHook;
const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
@@ -122,18 +127,11 @@ impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
/// on tunnel state changes.
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
- pub fn start<L>(
+ pub fn start(
tunnel_parameters: &TunnelParameters,
log_dir: &Option<path::PathBuf>,
- args: TunnelArgs<'_, L>,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Clone
- + Sync
- + 'static,
- {
+ args: TunnelArgs<'_>,
+ ) -> Result<Self> {
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;
@@ -143,7 +141,7 @@ impl TunnelMonitor {
config,
log_file,
args.resource_dir,
- args.on_event,
+ args.event_hook,
args.tunnel_close_rx,
args.route_manager,
)),
@@ -176,21 +174,14 @@ impl TunnelMonitor {
}
}
- fn start_wireguard_tunnel<L>(
+ fn start_wireguard_tunnel(
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
params: &wireguard_types::TunnelParameters,
#[cfg(any(target_os = "linux", target_os = "windows"))]
params: &wireguard_types::TunnelParameters,
log: Option<path::PathBuf>,
- args: TunnelArgs<'_, L>,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- {
+ args: TunnelArgs<'_>,
+ ) -> Result<Self> {
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
@@ -198,22 +189,16 @@ impl TunnelMonitor {
}
#[cfg(not(target_os = "android"))]
- async fn start_openvpn_tunnel<L>(
+ async fn start_openvpn_tunnel(
config: &openvpn_types::TunnelParameters,
log: Option<path::PathBuf>,
resource_dir: &path::Path,
- on_event: L,
+ event_hook: EventHook,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- {
+ ) -> Result<Self> {
let monitor = talpid_openvpn::OpenVpnMonitor::start(
- on_event,
+ event_hook,
config,
log,
resource_dir,
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 7387d51c03..53ef61475e 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -19,7 +19,9 @@ use std::{
time::{Duration, Instant},
};
use talpid_routing::RouteManagerHandle;
-use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
+use talpid_tunnel::{
+ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata,
+};
use talpid_types::{
net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters},
tunnel::{ErrorStateCause, FirewallPolicyError},
@@ -214,14 +216,7 @@ impl ConnectingState {
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
- let on_tunnel_event =
- move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
- let (tx, rx) = oneshot::channel();
- let _ = event_tx.unbounded_send((event, tx));
- Box::pin(async move {
- let _ = rx.await;
- })
- };
+ let event_hook = EventHook::new(event_tx);
let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
@@ -238,7 +233,7 @@ impl ConnectingState {
let args = TunnelArgs {
runtime,
resource_dir: &resource_dir,
- on_event: on_tunnel_event,
+ event_hook,
tunnel_close_rx,
tun_provider,
retry_attempt,
diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs
index 421c2076ef..16d0d0e4fc 100644
--- a/talpid-openvpn/src/lib.rs
+++ b/talpid-openvpn/src/lib.rs
@@ -19,7 +19,7 @@ use std::{
};
#[cfg(target_os = "linux")]
use talpid_routing::RequiredRoute;
-use talpid_tunnel::TunnelEvent;
+use talpid_tunnel::EventHook;
use talpid_types::{
net::{openvpn, proxy::CustomProxy},
ErrorExt,
@@ -245,19 +245,13 @@ impl WintunContextImpl {
impl OpenVpnMonitor<OpenVpnCommand> {
/// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given
/// path.
- pub async fn start<L>(
- on_event: L,
+ pub async fn start(
+ event_hook: EventHook,
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
resource_dir: &Path,
route_manager: talpid_routing::RouteManagerHandle,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- {
+ ) -> Result<Self> {
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
.map_err(Error::CredentialsWriteError)?;
@@ -308,7 +302,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
cmd,
openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
- on_event,
+ event_hook,
user_pass_file_path: user_pass_file_path.clone(),
proxy_auth_file_path: proxy_auth_file_path.clone(),
abort_server_tx: event_server_abort_tx,
@@ -777,7 +771,7 @@ mod event_server {
pin::Pin,
task::{Context, Poll},
};
- use talpid_tunnel::TunnelMetadata;
+ use talpid_tunnel::{EventHook, TunnelMetadata};
#[cfg(any(target_os = "macos", target_os = "windows"))]
use talpid_types::net::proxy::CustomProxy;
use talpid_types::ErrorExt;
@@ -808,15 +802,8 @@ mod event_server {
}
/// Implements a gRPC service used to process events sent to by OpenVPN.
- pub struct OpenvpnEventProxyImpl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- > {
- pub on_event: L,
+ pub struct OpenvpnEventProxyImpl {
+ pub event_hook: EventHook,
pub user_pass_file_path: super::PathBuf,
pub proxy_auth_file_path: Option<super::PathBuf>,
pub abort_server_tx: triggered::Trigger,
@@ -827,26 +814,19 @@ mod event_server {
pub ipv6_enabled: bool,
}
- impl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- )
- -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- > OpenvpnEventProxyImpl<L>
- {
+ impl OpenvpnEventProxyImpl {
async fn up_inner(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
- (self.on_event)(talpid_tunnel::TunnelEvent::InterfaceUp(
- Self::get_tunnel_metadata(&env)?,
- talpid_types::net::AllowedTunnelTraffic::All,
- ))
- .await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::InterfaceUp(
+ Self::get_tunnel_metadata(&env)?,
+ talpid_types::net::AllowedTunnelTraffic::All,
+ ))
+ .await;
Ok(Response::new(()))
}
@@ -916,7 +896,10 @@ mod event_server {
return Err(tonic::Status::failed_precondition("Failed to add routes"));
}
- (self.on_event)(talpid_tunnel::TunnelEvent::Up(metadata)).await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::Up(metadata))
+ .await;
Ok(Response::new(()))
}
@@ -970,25 +953,18 @@ mod event_server {
}
#[tonic::async_trait]
- impl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- )
- -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- > OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
- {
+ impl OpenvpnEventProxy for OpenvpnEventProxyImpl {
async fn auth_failed(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
- (self.on_event)(talpid_tunnel::TunnelEvent::AuthFailed(
- env.get("auth_failed_reason").cloned(),
- ))
- .await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::AuthFailed(
+ env.get("auth_failed_reason").cloned(),
+ ))
+ .await;
Ok(Response::new(()))
}
@@ -1014,7 +990,10 @@ mod event_server {
&self,
_request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
- (self.on_event)(talpid_tunnel::TunnelEvent::Down).await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::Down)
+ .await;
Ok(Response::new(()))
}
}
diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs
index 53a746d102..ddffb0d302 100644
--- a/talpid-tunnel/src/lib.rs
+++ b/talpid-tunnel/src/lib.rs
@@ -9,7 +9,13 @@ use std::{
pub mod network_interface;
pub mod tun_provider;
-use futures::{channel::oneshot, future::BoxFuture};
+use futures::{
+ channel::{
+ mpsc::UnboundedSender,
+ oneshot::{self, Sender},
+ },
+ SinkExt,
+};
use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
@@ -28,16 +34,13 @@ pub const MIN_IPV4_MTU: u16 = 576;
pub const MIN_IPV6_MTU: u16 = 1280;
/// Arguments for creating a tunnel.
-pub struct TunnelArgs<'a, L>
-where
- L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
-{
+pub struct TunnelArgs<'a> {
/// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
/// Resource directory path.
pub resource_dir: &'a Path,
/// Callback function called when an event happens.
- pub on_event: L,
+ pub event_hook: EventHook,
/// Receiver oneshot channel for closing the tunnel.
pub tunnel_close_rx: oneshot::Receiver<()>,
/// Mutex to tunnel provider.
@@ -48,6 +51,24 @@ where
pub route_manager: RouteManagerHandle,
}
+#[derive(Clone)]
+pub struct EventHook {
+ event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>,
+}
+
+impl EventHook {
+ pub fn new(event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>) -> Self {
+ Self { event_tx }
+ }
+
+ pub async fn on_event(&mut self, event: TunnelEvent) {
+ let (tx, rx) = oneshot::channel::<()>();
+ if let Ok(()) = self.event_tx.send((event, tx)).await {
+ let _ = rx.await;
+ }
+ }
+}
+
/// Information about a VPN tunnel.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct TunnelMetadata {
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 06726f9c69..c3cf9a554f 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -5,7 +5,7 @@
use self::config::Config;
#[cfg(windows)]
use futures::channel::mpsc;
-use futures::future::{BoxFuture, Future};
+use futures::future::Future;
use obfuscation::ObfuscatorHandle;
#[cfg(target_os = "android")]
use std::borrow::Cow;
@@ -24,11 +24,12 @@ use std::{env, sync::LazyLock};
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
-use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
+use talpid_tunnel::{
+ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata,
+};
-use talpid_types::net::wireguard::TunnelParameters;
use talpid_types::{
- net::{AllowedTunnelTraffic, Endpoint, TransportProtocol},
+ net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol},
BoxedError, ErrorExt,
};
use tokio::sync::Mutex as AsyncMutex;
@@ -60,7 +61,6 @@ type TunnelType = Box<dyn Tunnel>;
type TunnelType = WgGoTunnel;
type Result<T> = std::result::Result<T, Error>;
-type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
/// Errors that can happen in the Wireguard tunnel monitor.
#[derive(thiserror::Error, Debug)]
@@ -141,7 +141,7 @@ pub struct WireguardMonitor {
/// Tunnel implementation
tunnel: Arc<AsyncMutex<Option<TunnelType>>>,
/// Callback to signal tunnel events
- event_callback: EventCallback,
+ event_hook: EventHook,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
@@ -158,19 +158,11 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock<bool> = LazyLock::new(|| {
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
#[cfg(not(target_os = "android"))]
- pub fn start<
- F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- >(
+ pub fn start(
params: &TunnelParameters,
log_path: Option<&Path>,
- args: TunnelArgs<'_, F>,
+ args: TunnelArgs<'_>,
) -> Result<WireguardMonitor> {
- let on_event = args.on_event.clone();
-
#[cfg(any(target_os = "windows", target_os = "linux"))]
let desired_mtu = args
.runtime
@@ -225,16 +217,16 @@ impl WireguardMonitor {
.map_err(Error::ConnectivityMonitorError)?
.with_cancellation();
- let event_callback = Box::new(on_event.clone());
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
- event_callback,
+ event_hook: args.event_hook.clone(),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator,
};
+ let mut event_hook = args.event_hook.clone();
let moved_tunnel = monitor.tunnel.clone();
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
@@ -249,7 +241,9 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ .await;
// Add non-default routes before establishing the tunnel.
#[cfg(target_os = "linux")]
@@ -281,11 +275,12 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (on_event)(TunnelEvent::InterfaceUp(
- metadata,
- Self::allowed_traffic_after_tunnel_config(),
- ))
- .await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(
+ metadata,
+ Self::allowed_traffic_after_tunnel_config(),
+ ))
+ .await;
}
if detect_mtu {
@@ -350,7 +345,7 @@ impl WireguardMonitor {
.map_err(CloseMsg::SetupError)?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (on_event)(TunnelEvent::Up(metadata)).await;
+ event_hook.on_event(TunnelEvent::Up(metadata)).await;
let monitored_tunnel = Arc::downgrade(&tunnel);
tokio::task::spawn_blocking(move || {
@@ -395,16 +390,10 @@ impl WireguardMonitor {
/// being ready to serve traffic.
/// - No routes are configured on android.
#[cfg(target_os = "android")]
- pub fn start<
- F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- >(
+ pub fn start(
params: &TunnelParameters,
log_path: Option<&Path>,
- args: TunnelArgs<'_, F>,
+ args: TunnelArgs<'_>,
) -> Result<WireguardMonitor> {
let desired_mtu = get_desired_mtu(params);
let mut config =
@@ -448,10 +437,11 @@ impl WireguardMonitor {
let iface_name = tunnel.get_interface_name();
let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
+ let mut event_hook = args.event_hook;
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::clone(&tunnel),
- event_callback: Box::new(args.on_event.clone()),
+ event_hook: event_hook.clone(),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
@@ -465,7 +455,8 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
.await;
if should_negotiate_ephemeral_peer {
@@ -482,15 +473,16 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- args.on_event.clone()(TunnelEvent::InterfaceUp(
- metadata,
- Self::allowed_traffic_after_tunnel_config(),
- ))
- .await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(
+ metadata,
+ Self::allowed_traffic_after_tunnel_config(),
+ ))
+ .await;
}
let metadata = Self::tunnel_metadata(&iface_name, &config);
- args.on_event.clone()(TunnelEvent::Up(metadata)).await;
+ event_hook.on_event(TunnelEvent::Up(metadata)).await;
// HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
let connectivity_check = {
@@ -568,7 +560,6 @@ impl WireguardMonitor {
/// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true.
/// Used to block traffic to other destinations while connecting on Android.
- ///
#[cfg(target_os = "android")]
fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> {
if gateway_only {
@@ -803,7 +794,7 @@ impl WireguardMonitor {
let _ = self.pinger_stop_sender.send(());
self.runtime
- .block_on((self.event_callback)(TunnelEvent::Down));
+ .block_on(self.event_hook.on_event(TunnelEvent::Down));
self.stop_tunnel();