diff options
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 51 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 15 | ||||
| -rw-r--r-- | talpid-openvpn/src/lib.rs | 83 | ||||
| -rw-r--r-- | talpid-tunnel/src/lib.rs | 33 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 77 |
5 files changed, 115 insertions, 144 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index db2d83346a..63bd01c57a 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -9,8 +9,13 @@ use talpid_tunnel::tun_provider; pub use talpid_tunnel::{TunnelArgs, TunnelEvent, TunnelMetadata}; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; -use talpid_types::net::{wireguard as wireguard_types, TunnelParameters}; -use talpid_types::tunnel::ErrorStateCause; +use talpid_types::{ + net::{wireguard as wireguard_types, TunnelParameters}, + tunnel::ErrorStateCause, +}; + +#[cfg(not(target_os = "android"))] +use talpid_tunnel::EventHook; const OPENVPN_LOG_FILENAME: &str = "openvpn.log"; const WIREGUARD_LOG_FILENAME: &str = "wireguard.log"; @@ -122,18 +127,11 @@ impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` /// on tunnel state changes. #[cfg_attr(any(target_os = "android", windows), allow(unused_variables))] - pub fn start<L>( + pub fn start( tunnel_parameters: &TunnelParameters, log_dir: &Option<path::PathBuf>, - args: TunnelArgs<'_, L>, - ) -> Result<Self> - where - L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Clone - + Sync - + 'static, - { + args: TunnelArgs<'_>, + ) -> Result<Self> { Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?; let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?; @@ -143,7 +141,7 @@ impl TunnelMonitor { config, log_file, args.resource_dir, - args.on_event, + args.event_hook, args.tunnel_close_rx, args.route_manager, )), @@ -176,21 +174,14 @@ impl TunnelMonitor { } } - fn start_wireguard_tunnel<L>( + fn start_wireguard_tunnel( #[cfg(not(any(target_os = "linux", target_os = "windows")))] params: &wireguard_types::TunnelParameters, #[cfg(any(target_os = "linux", target_os = "windows"))] params: &wireguard_types::TunnelParameters, log: Option<path::PathBuf>, - args: TunnelArgs<'_, L>, - ) -> Result<Self> - where - L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + Clone - + 'static, - { + args: TunnelArgs<'_>, + ) -> Result<Self> { let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), @@ -198,22 +189,16 @@ impl TunnelMonitor { } #[cfg(not(target_os = "android"))] - async fn start_openvpn_tunnel<L>( + async fn start_openvpn_tunnel( config: &openvpn_types::TunnelParameters, log: Option<path::PathBuf>, resource_dir: &path::Path, - on_event: L, + event_hook: EventHook, tunnel_close_rx: oneshot::Receiver<()>, route_manager: RouteManagerHandle, - ) -> Result<Self> - where - L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - { + ) -> Result<Self> { let monitor = talpid_openvpn::OpenVpnMonitor::start( - on_event, + event_hook, config, log, resource_dir, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7387d51c03..53ef61475e 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -19,7 +19,9 @@ use std::{ time::{Duration, Instant}, }; use talpid_routing::RouteManagerHandle; -use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; +use talpid_tunnel::{ + tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, +}; use talpid_types::{ net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters}, tunnel::{ErrorStateCause, FirewallPolicyError}, @@ -214,14 +216,7 @@ impl ConnectingState { retry_attempt: u32, ) -> Self { let (event_tx, event_rx) = mpsc::unbounded(); - let on_tunnel_event = - move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> { - let (tx, rx) = oneshot::channel(); - let _ = event_tx.unbounded_send((event, tx)); - Box::pin(async move { - let _ = rx.await; - }) - }; + let event_hook = EventHook::new(event_tx); let route_manager = route_manager.clone(); let log_dir = log_dir.clone(); @@ -238,7 +233,7 @@ impl ConnectingState { let args = TunnelArgs { runtime, resource_dir: &resource_dir, - on_event: on_tunnel_event, + event_hook, tunnel_close_rx, tun_provider, retry_attempt, diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 421c2076ef..16d0d0e4fc 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -19,7 +19,7 @@ use std::{ }; #[cfg(target_os = "linux")] use talpid_routing::RequiredRoute; -use talpid_tunnel::TunnelEvent; +use talpid_tunnel::EventHook; use talpid_types::{ net::{openvpn, proxy::CustomProxy}, ErrorExt, @@ -245,19 +245,13 @@ impl WintunContextImpl { impl OpenVpnMonitor<OpenVpnCommand> { /// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given /// path. - pub async fn start<L>( - on_event: L, + pub async fn start( + event_hook: EventHook, params: &openvpn::TunnelParameters, log_path: Option<PathBuf>, resource_dir: &Path, route_manager: talpid_routing::RouteManagerHandle, - ) -> Result<Self> - where - L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - { + ) -> Result<Self> { let user_pass_file = Self::create_credentials_file(¶ms.config.username, ¶ms.config.password) .map_err(Error::CredentialsWriteError)?; @@ -308,7 +302,7 @@ impl OpenVpnMonitor<OpenVpnCommand> { cmd, openvpn_init_args, event_server::OpenvpnEventProxyImpl { - on_event, + event_hook, user_pass_file_path: user_pass_file_path.clone(), proxy_auth_file_path: proxy_auth_file_path.clone(), abort_server_tx: event_server_abort_tx, @@ -777,7 +771,7 @@ mod event_server { pin::Pin, task::{Context, Poll}, }; - use talpid_tunnel::TunnelMetadata; + use talpid_tunnel::{EventHook, TunnelMetadata}; #[cfg(any(target_os = "macos", target_os = "windows"))] use talpid_types::net::proxy::CustomProxy; use talpid_types::ErrorExt; @@ -808,15 +802,8 @@ mod event_server { } /// Implements a gRPC service used to process events sent to by OpenVPN. - pub struct OpenvpnEventProxyImpl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - > { - pub on_event: L, + pub struct OpenvpnEventProxyImpl { + pub event_hook: EventHook, pub user_pass_file_path: super::PathBuf, pub proxy_auth_file_path: Option<super::PathBuf>, pub abort_server_tx: triggered::Trigger, @@ -827,26 +814,19 @@ mod event_server { pub ipv6_enabled: bool, } - impl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) - -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - > OpenvpnEventProxyImpl<L> - { + impl OpenvpnEventProxyImpl { async fn up_inner( &self, request: Request<EventDetails>, ) -> std::result::Result<Response<()>, tonic::Status> { let env = request.into_inner().env; - (self.on_event)(talpid_tunnel::TunnelEvent::InterfaceUp( - Self::get_tunnel_metadata(&env)?, - talpid_types::net::AllowedTunnelTraffic::All, - )) - .await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::InterfaceUp( + Self::get_tunnel_metadata(&env)?, + talpid_types::net::AllowedTunnelTraffic::All, + )) + .await; Ok(Response::new(())) } @@ -916,7 +896,10 @@ mod event_server { return Err(tonic::Status::failed_precondition("Failed to add routes")); } - (self.on_event)(talpid_tunnel::TunnelEvent::Up(metadata)).await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::Up(metadata)) + .await; Ok(Response::new(())) } @@ -970,25 +953,18 @@ mod event_server { } #[tonic::async_trait] - impl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) - -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - > OpenvpnEventProxy for OpenvpnEventProxyImpl<L> - { + impl OpenvpnEventProxy for OpenvpnEventProxyImpl { async fn auth_failed( &self, request: Request<EventDetails>, ) -> std::result::Result<Response<()>, tonic::Status> { let env = request.into_inner().env; - (self.on_event)(talpid_tunnel::TunnelEvent::AuthFailed( - env.get("auth_failed_reason").cloned(), - )) - .await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::AuthFailed( + env.get("auth_failed_reason").cloned(), + )) + .await; Ok(Response::new(())) } @@ -1014,7 +990,10 @@ mod event_server { &self, _request: Request<EventDetails>, ) -> std::result::Result<Response<()>, tonic::Status> { - (self.on_event)(talpid_tunnel::TunnelEvent::Down).await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::Down) + .await; Ok(Response::new(())) } } diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs index 53a746d102..ddffb0d302 100644 --- a/talpid-tunnel/src/lib.rs +++ b/talpid-tunnel/src/lib.rs @@ -9,7 +9,13 @@ use std::{ pub mod network_interface; pub mod tun_provider; -use futures::{channel::oneshot, future::BoxFuture}; +use futures::{ + channel::{ + mpsc::UnboundedSender, + oneshot::{self, Sender}, + }, + SinkExt, +}; use talpid_routing::RouteManagerHandle; use talpid_types::net::AllowedTunnelTraffic; use tun_provider::TunProvider; @@ -28,16 +34,13 @@ pub const MIN_IPV4_MTU: u16 = 576; pub const MIN_IPV6_MTU: u16 = 1280; /// Arguments for creating a tunnel. -pub struct TunnelArgs<'a, L> -where - L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, -{ +pub struct TunnelArgs<'a> { /// Tokio runtime handle. pub runtime: tokio::runtime::Handle, /// Resource directory path. pub resource_dir: &'a Path, /// Callback function called when an event happens. - pub on_event: L, + pub event_hook: EventHook, /// Receiver oneshot channel for closing the tunnel. pub tunnel_close_rx: oneshot::Receiver<()>, /// Mutex to tunnel provider. @@ -48,6 +51,24 @@ where pub route_manager: RouteManagerHandle, } +#[derive(Clone)] +pub struct EventHook { + event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>, +} + +impl EventHook { + pub fn new(event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>) -> Self { + Self { event_tx } + } + + pub async fn on_event(&mut self, event: TunnelEvent) { + let (tx, rx) = oneshot::channel::<()>(); + if let Ok(()) = self.event_tx.send((event, tx)).await { + let _ = rx.await; + } + } +} + /// Information about a VPN tunnel. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct TunnelMetadata { diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 06726f9c69..c3cf9a554f 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -5,7 +5,7 @@ use self::config::Config; #[cfg(windows)] use futures::channel::mpsc; -use futures::future::{BoxFuture, Future}; +use futures::future::Future; use obfuscation::ObfuscatorHandle; #[cfg(target_os = "android")] use std::borrow::Cow; @@ -24,11 +24,12 @@ use std::{env, sync::LazyLock}; use talpid_routing::{self, RequiredRoute}; #[cfg(not(windows))] use talpid_tunnel::tun_provider; -use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; +use talpid_tunnel::{ + tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, +}; -use talpid_types::net::wireguard::TunnelParameters; use talpid_types::{ - net::{AllowedTunnelTraffic, Endpoint, TransportProtocol}, + net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol}, BoxedError, ErrorExt, }; use tokio::sync::Mutex as AsyncMutex; @@ -60,7 +61,6 @@ type TunnelType = Box<dyn Tunnel>; type TunnelType = WgGoTunnel; type Result<T> = std::result::Result<T, Error>; -type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; /// Errors that can happen in the Wireguard tunnel monitor. #[derive(thiserror::Error, Debug)] @@ -141,7 +141,7 @@ pub struct WireguardMonitor { /// Tunnel implementation tunnel: Arc<AsyncMutex<Option<TunnelType>>>, /// Callback to signal tunnel events - event_callback: EventCallback, + event_hook: EventHook, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, pinger_stop_sender: sync_mpsc::Sender<()>, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, @@ -158,19 +158,11 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock<bool> = LazyLock::new(|| { impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config #[cfg(not(target_os = "android"))] - pub fn start< - F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + Clone - + 'static, - >( + pub fn start( params: &TunnelParameters, log_path: Option<&Path>, - args: TunnelArgs<'_, F>, + args: TunnelArgs<'_>, ) -> Result<WireguardMonitor> { - let on_event = args.on_event.clone(); - #[cfg(any(target_os = "windows", target_os = "linux"))] let desired_mtu = args .runtime @@ -225,16 +217,16 @@ impl WireguardMonitor { .map_err(Error::ConnectivityMonitorError)? .with_cancellation(); - let event_callback = Box::new(on_event.clone()); let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), - event_callback, + event_hook: args.event_hook.clone(), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator, }; + let mut event_hook = args.event_hook.clone(); let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); @@ -249,7 +241,9 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; + event_hook + .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + .await; // Add non-default routes before establishing the tunnel. #[cfg(target_os = "linux")] @@ -281,11 +275,12 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (on_event)(TunnelEvent::InterfaceUp( - metadata, - Self::allowed_traffic_after_tunnel_config(), - )) - .await; + event_hook + .on_event(TunnelEvent::InterfaceUp( + metadata, + Self::allowed_traffic_after_tunnel_config(), + )) + .await; } if detect_mtu { @@ -350,7 +345,7 @@ impl WireguardMonitor { .map_err(CloseMsg::SetupError)?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (on_event)(TunnelEvent::Up(metadata)).await; + event_hook.on_event(TunnelEvent::Up(metadata)).await; let monitored_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { @@ -395,16 +390,10 @@ impl WireguardMonitor { /// being ready to serve traffic. /// - No routes are configured on android. #[cfg(target_os = "android")] - pub fn start< - F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + Clone - + 'static, - >( + pub fn start( params: &TunnelParameters, log_path: Option<&Path>, - args: TunnelArgs<'_, F>, + args: TunnelArgs<'_>, ) -> Result<WireguardMonitor> { let desired_mtu = get_desired_mtu(params); let mut config = @@ -448,10 +437,11 @@ impl WireguardMonitor { let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); + let mut event_hook = args.event_hook; let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::clone(&tunnel), - event_callback: Box::new(args.on_event.clone()), + event_hook: event_hook.clone(), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator: Arc::new(AsyncMutex::new(obfuscator)), @@ -465,7 +455,8 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + event_hook + .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; if should_negotiate_ephemeral_peer { @@ -482,15 +473,16 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - args.on_event.clone()(TunnelEvent::InterfaceUp( - metadata, - Self::allowed_traffic_after_tunnel_config(), - )) - .await; + event_hook + .on_event(TunnelEvent::InterfaceUp( + metadata, + Self::allowed_traffic_after_tunnel_config(), + )) + .await; } let metadata = Self::tunnel_metadata(&iface_name, &config); - args.on_event.clone()(TunnelEvent::Up(metadata)).await; + event_hook.on_event(TunnelEvent::Up(metadata)).await; // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it let connectivity_check = { @@ -568,7 +560,6 @@ impl WireguardMonitor { /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. /// Used to block traffic to other destinations while connecting on Android. - /// #[cfg(target_os = "android")] fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> { if gateway_only { @@ -803,7 +794,7 @@ impl WireguardMonitor { let _ = self.pinger_stop_sender.send(()); self.runtime - .block_on((self.event_callback)(TunnelEvent::Down)); + .block_on(self.event_hook.on_event(TunnelEvent::Down)); self.stop_tunnel(); |
