summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/tunnel/mod.rs75
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs24
-rw-r--r--talpid-openvpn/src/lib.rs64
-rw-r--r--talpid-tunnel/src/lib.rs35
-rw-r--r--talpid-wireguard/src/lib.rs74
5 files changed, 135 insertions, 137 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 5a12b67eb3..63bd01c57a 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -14,6 +14,9 @@ use talpid_types::{
tunnel::ErrorStateCause,
};
+#[cfg(not(target_os = "android"))]
+use talpid_tunnel::EventHook;
+
const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
@@ -115,23 +118,19 @@ impl Error {
}
/// Abstraction for monitoring a generic VPN tunnel.
-pub struct TunnelMonitor<F> {
- monitor: InternalTunnelMonitor<F>,
+pub struct TunnelMonitor {
+ monitor: InternalTunnelMonitor,
}
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
-impl<L, F> TunnelMonitor<L>
-where
- L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()> + Send + 'static,
-{
+impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
/// on tunnel state changes.
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
pub fn start(
tunnel_parameters: &TunnelParameters,
log_dir: &Option<path::PathBuf>,
- args: TunnelArgs<'_, L, F>,
+ args: TunnelArgs<'_>,
) -> Result<Self> {
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;
@@ -142,7 +141,7 @@ where
config,
log_file,
args.resource_dir,
- args.on_event,
+ args.event_hook,
args.tunnel_close_rx,
args.route_manager,
)),
@@ -155,13 +154,33 @@ where
}
}
+ /// Returns a path to an executable that communicates with relay servers.
+ /// Returns `None` if the executable is unknown.
+ #[cfg(windows)]
+ pub fn get_relay_client(
+ resource_dir: &path::Path,
+ params: &TunnelParameters,
+ ) -> Option<path::PathBuf> {
+ use talpid_types::net::proxy::CustomProxy;
+
+ let resource_dir = resource_dir.to_path_buf();
+ match params {
+ TunnelParameters::OpenVpn(params) => match &params.proxy {
+ Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
+ Some(CustomProxy::Socks5Local(_)) => None,
+ Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
+ },
+ _ => Some(std::env::current_exe().unwrap()),
+ }
+ }
+
fn start_wireguard_tunnel(
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
params: &wireguard_types::TunnelParameters,
#[cfg(any(target_os = "linux", target_os = "windows"))]
params: &wireguard_types::TunnelParameters,
log: Option<path::PathBuf>,
- args: TunnelArgs<'_, L, F>,
+ args: TunnelArgs<'_>,
) -> Result<Self> {
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
@@ -174,12 +193,12 @@ where
config: &openvpn_types::TunnelParameters,
log: Option<path::PathBuf>,
resource_dir: &path::Path,
- on_event: L,
+ event_hook: EventHook,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
) -> Result<Self> {
let monitor = talpid_openvpn::OpenVpnMonitor::start(
- on_event,
+ event_hook,
config,
log,
resource_dir,
@@ -255,39 +274,13 @@ where
}
}
-impl TunnelMonitor<()> {
- /// Returns a path to an executable that communicates with relay servers.
- /// Returns `None` if the executable is unknown.
- #[cfg(windows)]
- pub fn get_relay_client(
- resource_dir: &path::Path,
- params: &TunnelParameters,
- ) -> Option<path::PathBuf> {
- use talpid_types::net::proxy::CustomProxy;
-
- let resource_dir = resource_dir.to_path_buf();
- match params {
- TunnelParameters::OpenVpn(params) => match &params.proxy {
- Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
- Some(CustomProxy::Socks5Local(_)) => None,
- Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
- },
- _ => Some(std::env::current_exe().unwrap()),
- }
- }
-}
-
-enum InternalTunnelMonitor<F> {
+enum InternalTunnelMonitor {
#[cfg(not(target_os = "android"))]
OpenVpn(talpid_openvpn::OpenVpnMonitor),
- Wireguard(talpid_wireguard::WireguardMonitor<F>),
+ Wireguard(talpid_wireguard::WireguardMonitor),
}
-impl<L, F> InternalTunnelMonitor<L>
-where
- L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()> + Send + 'static,
-{
+impl InternalTunnelMonitor {
fn wait(self) -> Result<()> {
#[cfg(not(target_os = "android"))]
let handle = tokio::runtime::Handle::current();
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 1f1ea1be4f..53ef61475e 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -19,7 +19,9 @@ use std::{
time::{Duration, Instant},
};
use talpid_routing::RouteManagerHandle;
-use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
+use talpid_tunnel::{
+ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata,
+};
use talpid_types::{
net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters},
tunnel::{ErrorStateCause, FirewallPolicyError},
@@ -214,13 +216,7 @@ impl ConnectingState {
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
- let on_tunnel_event = move |event| {
- let (tx, rx) = oneshot::channel();
- let _ = event_tx.unbounded_send((event, tx));
- async move {
- let _ = rx.await;
- }
- };
+ let event_hook = EventHook::new(event_tx);
let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
@@ -237,7 +233,7 @@ impl ConnectingState {
let args = TunnelArgs {
runtime,
resource_dir: &resource_dir,
- on_event: on_tunnel_event,
+ event_hook,
tunnel_close_rx,
tun_provider,
retry_attempt,
@@ -289,14 +285,10 @@ impl ConnectingState {
}
}
- fn wait_for_tunnel_monitor<L, F>(
- tunnel_monitor: TunnelMonitor<L>,
+ fn wait_for_tunnel_monitor(
+ tunnel_monitor: TunnelMonitor,
retry_attempt: u32,
- ) -> Option<ErrorStateCause>
- where
- L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()> + Send + 'static,
- {
+ ) -> Option<ErrorStateCause> {
match tunnel_monitor.wait() {
Ok(_) => None,
Err(error) => match error {
diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs
index 0939b2e343..16d0d0e4fc 100644
--- a/talpid-openvpn/src/lib.rs
+++ b/talpid-openvpn/src/lib.rs
@@ -19,7 +19,7 @@ use std::{
};
#[cfg(target_os = "linux")]
use talpid_routing::RequiredRoute;
-use talpid_tunnel::TunnelEvent;
+use talpid_tunnel::EventHook;
use talpid_types::{
net::{openvpn, proxy::CustomProxy},
ErrorExt,
@@ -245,17 +245,13 @@ impl WintunContextImpl {
impl OpenVpnMonitor<OpenVpnCommand> {
/// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given
/// path.
- pub async fn start<L, F>(
- on_event: L,
+ pub async fn start(
+ event_hook: EventHook,
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
resource_dir: &Path,
route_manager: talpid_routing::RouteManagerHandle,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()> + Send + 'static,
- {
+ ) -> Result<Self> {
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
.map_err(Error::CredentialsWriteError)?;
@@ -306,7 +302,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
cmd,
openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
- on_event,
+ event_hook,
user_pass_file_path: user_pass_file_path.clone(),
proxy_auth_file_path: proxy_auth_file_path.clone(),
abort_server_tx: event_server_abort_tx,
@@ -775,7 +771,7 @@ mod event_server {
pin::Pin,
task::{Context, Poll},
};
- use talpid_tunnel::TunnelMetadata;
+ use talpid_tunnel::{EventHook, TunnelMetadata};
#[cfg(any(target_os = "macos", target_os = "windows"))]
use talpid_types::net::proxy::CustomProxy;
use talpid_types::ErrorExt;
@@ -806,8 +802,8 @@ mod event_server {
}
/// Implements a gRPC service used to process events sent to by OpenVPN.
- pub struct OpenvpnEventProxyImpl<L> {
- pub on_event: L,
+ pub struct OpenvpnEventProxyImpl {
+ pub event_hook: EventHook,
pub user_pass_file_path: super::PathBuf,
pub proxy_auth_file_path: Option<super::PathBuf>,
pub abort_server_tx: triggered::Trigger,
@@ -818,21 +814,19 @@ mod event_server {
pub ipv6_enabled: bool,
}
- impl<
- L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()>,
- > OpenvpnEventProxyImpl<L>
- {
+ impl OpenvpnEventProxyImpl {
async fn up_inner(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
- (self.on_event)(talpid_tunnel::TunnelEvent::InterfaceUp(
- Self::get_tunnel_metadata(&env)?,
- talpid_types::net::AllowedTunnelTraffic::All,
- ))
- .await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::InterfaceUp(
+ Self::get_tunnel_metadata(&env)?,
+ talpid_types::net::AllowedTunnelTraffic::All,
+ ))
+ .await;
Ok(Response::new(()))
}
@@ -902,7 +896,10 @@ mod event_server {
return Err(tonic::Status::failed_precondition("Failed to add routes"));
}
- (self.on_event)(talpid_tunnel::TunnelEvent::Up(metadata)).await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::Up(metadata))
+ .await;
Ok(Response::new(()))
}
@@ -956,20 +953,18 @@ mod event_server {
}
#[tonic::async_trait]
- impl<
- L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: std::future::Future<Output = ()> + 'static + Send,
- > OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
- {
+ impl OpenvpnEventProxy for OpenvpnEventProxyImpl {
async fn auth_failed(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
- (self.on_event)(talpid_tunnel::TunnelEvent::AuthFailed(
- env.get("auth_failed_reason").cloned(),
- ))
- .await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::AuthFailed(
+ env.get("auth_failed_reason").cloned(),
+ ))
+ .await;
Ok(Response::new(()))
}
@@ -995,7 +990,10 @@ mod event_server {
&self,
_request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
- (self.on_event)(talpid_tunnel::TunnelEvent::Down).await;
+ self.event_hook
+ .clone()
+ .on_event(talpid_tunnel::TunnelEvent::Down)
+ .await;
Ok(Response::new(()))
}
}
diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs
index 9fe6e0074e..ddffb0d302 100644
--- a/talpid-tunnel/src/lib.rs
+++ b/talpid-tunnel/src/lib.rs
@@ -1,5 +1,4 @@
use std::{
- future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::Path,
sync::{Arc, Mutex},
@@ -10,7 +9,13 @@ use std::{
pub mod network_interface;
pub mod tun_provider;
-use futures::channel::oneshot;
+use futures::{
+ channel::{
+ mpsc::UnboundedSender,
+ oneshot::{self, Sender},
+ },
+ SinkExt,
+};
use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
@@ -29,17 +34,13 @@ pub const MIN_IPV4_MTU: u16 = 576;
pub const MIN_IPV6_MTU: u16 = 1280;
/// Arguments for creating a tunnel.
-pub struct TunnelArgs<'a, L, F>
-where
- L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
- F: Future<Output = ()>,
-{
+pub struct TunnelArgs<'a> {
/// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
/// Resource directory path.
pub resource_dir: &'a Path,
/// Callback function called when an event happens.
- pub on_event: L,
+ pub event_hook: EventHook,
/// Receiver oneshot channel for closing the tunnel.
pub tunnel_close_rx: oneshot::Receiver<()>,
/// Mutex to tunnel provider.
@@ -50,6 +51,24 @@ where
pub route_manager: RouteManagerHandle,
}
+#[derive(Clone)]
+pub struct EventHook {
+ event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>,
+}
+
+impl EventHook {
+ pub fn new(event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>) -> Self {
+ Self { event_tx }
+ }
+
+ pub async fn on_event(&mut self, event: TunnelEvent) {
+ let (tx, rx) = oneshot::channel::<()>();
+ if let Ok(()) = self.event_tx.send((event, tx)).await {
+ let _ = rx.await;
+ }
+ }
+}
+
/// Information about a VPN tunnel.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct TunnelMetadata {
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index fa52717e50..c3cf9a554f 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -24,7 +24,9 @@ use std::{env, sync::LazyLock};
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
-use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
+use talpid_tunnel::{
+ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata,
+};
use talpid_types::{
net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol},
@@ -134,12 +136,12 @@ impl Error {
}
/// Spawns and monitors a wireguard tunnel
-pub struct WireguardMonitor<F> {
+pub struct WireguardMonitor {
runtime: tokio::runtime::Handle,
/// Tunnel implementation
tunnel: Arc<AsyncMutex<Option<TunnelType>>>,
/// Callback to signal tunnel events
- event_callback: F,
+ event_hook: EventHook,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
@@ -153,18 +155,14 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock<bool> = LazyLock::new(|| {
.unwrap_or(false)
});
-impl<F, Fut> WireguardMonitor<F>
-where
- F: (Fn(TunnelEvent) -> Fut) + Send + Sync + Clone + 'static,
- Fut: Future<Output = ()> + Send,
-{
+impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
#[cfg(not(target_os = "android"))]
pub fn start(
params: &TunnelParameters,
log_path: Option<&Path>,
- args: TunnelArgs<'_, F, Fut>,
- ) -> Result<WireguardMonitor<F>> {
+ args: TunnelArgs<'_>,
+ ) -> Result<WireguardMonitor> {
#[cfg(any(target_os = "windows", target_os = "linux"))]
let desired_mtu = args
.runtime
@@ -222,12 +220,13 @@ where
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
- event_callback: args.on_event.clone(),
+ event_hook: args.event_hook.clone(),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator,
};
+ let mut event_hook = args.event_hook.clone();
let moved_tunnel = monitor.tunnel.clone();
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
@@ -242,7 +241,9 @@ where
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- (args.on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ .await;
// Add non-default routes before establishing the tunnel.
#[cfg(target_os = "linux")]
@@ -274,11 +275,12 @@ where
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event)(TunnelEvent::InterfaceUp(
- metadata,
- Self::allowed_traffic_after_tunnel_config(),
- ))
- .await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(
+ metadata,
+ Self::allowed_traffic_after_tunnel_config(),
+ ))
+ .await;
}
if detect_mtu {
@@ -343,7 +345,7 @@ where
.map_err(CloseMsg::SetupError)?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event)(TunnelEvent::Up(metadata)).await;
+ event_hook.on_event(TunnelEvent::Up(metadata)).await;
let monitored_tunnel = Arc::downgrade(&tunnel);
tokio::task::spawn_blocking(move || {
@@ -388,16 +390,10 @@ where
/// being ready to serve traffic.
/// - No routes are configured on android.
#[cfg(target_os = "android")]
- pub fn start<
- F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- >(
+ pub fn start(
params: &TunnelParameters,
log_path: Option<&Path>,
- args: TunnelArgs<'_, F>,
+ args: TunnelArgs<'_>,
) -> Result<WireguardMonitor> {
let desired_mtu = get_desired_mtu(params);
let mut config =
@@ -441,10 +437,11 @@ where
let iface_name = tunnel.get_interface_name();
let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
+ let mut event_hook = args.event_hook;
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::clone(&tunnel),
- event_callback: Box::new(args.on_event.clone()),
+ event_hook: event_hook.clone(),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
@@ -458,7 +455,8 @@ where
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
.await;
if should_negotiate_ephemeral_peer {
@@ -475,15 +473,16 @@ where
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- args.on_event.clone()(TunnelEvent::InterfaceUp(
- metadata,
- Self::allowed_traffic_after_tunnel_config(),
- ))
- .await;
+ event_hook
+ .on_event(TunnelEvent::InterfaceUp(
+ metadata,
+ Self::allowed_traffic_after_tunnel_config(),
+ ))
+ .await;
}
let metadata = Self::tunnel_metadata(&iface_name, &config);
- args.on_event.clone()(TunnelEvent::Up(metadata)).await;
+ event_hook.on_event(TunnelEvent::Up(metadata)).await;
// HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
let connectivity_check = {
@@ -795,7 +794,7 @@ where
let _ = self.pinger_stop_sender.send(());
self.runtime
- .block_on((self.event_callback)(TunnelEvent::Down));
+ .block_on(self.event_hook.on_event(TunnelEvent::Down));
self.stop_tunnel();
@@ -902,10 +901,7 @@ where
fn get_post_tunnel_routes<'a>(
iface_name: &str,
config: &'a Config,
- ) -> impl Iterator<Item = RequiredRoute> + 'a
- where
- Fut: 'a,
- {
+ ) -> impl Iterator<Item = RequiredRoute> + 'a {
let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config);
let iter = config
.get_tunnel_destinations()