summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-11-28 09:42:54 +0100
committerSebastian Holmin <sebastian.holmin@mullvad.net>2024-12-02 13:39:37 +0100
commit96bceb620d05b2a9d45c11fe9f42f7d90d5eb30a (patch)
tree6d17c761b8ebea7e732d9067e1a03fdb1cc82d2e
parentdb019b768532baecea3f32a0bb5679c8b551074b (diff)
downloadmullvadvpn-96bceb620d05b2a9d45c11fe9f42f7d90d5eb30a.tar.xz
mullvadvpn-96bceb620d05b2a9d45c11fe9f42f7d90d5eb30a.zip
Replace dyn fn with generic
-rw-r--r--talpid-core/src/tunnel/mod.rs104
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs25
-rw-r--r--talpid-openvpn/src/lib.rs35
-rw-r--r--talpid-tunnel/src/lib.rs8
-rw-r--r--talpid-wireguard/src/lib.rs45
5 files changed, 95 insertions, 122 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index db2d83346a..5a12b67eb3 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -9,8 +9,10 @@ use talpid_tunnel::tun_provider;
pub use talpid_tunnel::{TunnelArgs, TunnelEvent, TunnelMetadata};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
-use talpid_types::net::{wireguard as wireguard_types, TunnelParameters};
-use talpid_types::tunnel::ErrorStateCause;
+use talpid_types::{
+ net::{wireguard as wireguard_types, TunnelParameters},
+ tunnel::ErrorStateCause,
+};
const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
@@ -113,27 +115,24 @@ impl Error {
}
/// Abstraction for monitoring a generic VPN tunnel.
-pub struct TunnelMonitor {
- monitor: InternalTunnelMonitor,
+pub struct TunnelMonitor<F> {
+ monitor: InternalTunnelMonitor<F>,
}
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
-impl TunnelMonitor {
+impl<L, F> TunnelMonitor<L>
+where
+ L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()> + Send + 'static,
+{
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
/// on tunnel state changes.
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
- pub fn start<L>(
+ pub fn start(
tunnel_parameters: &TunnelParameters,
log_dir: &Option<path::PathBuf>,
- args: TunnelArgs<'_, L>,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Clone
- + Sync
- + 'static,
- {
+ args: TunnelArgs<'_, L, F>,
+ ) -> Result<Self> {
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;
@@ -156,41 +155,14 @@ impl TunnelMonitor {
}
}
- /// Returns a path to an executable that communicates with relay servers.
- /// Returns `None` if the executable is unknown.
- #[cfg(windows)]
- pub fn get_relay_client(
- resource_dir: &path::Path,
- params: &TunnelParameters,
- ) -> Option<path::PathBuf> {
- use talpid_types::net::proxy::CustomProxy;
-
- let resource_dir = resource_dir.to_path_buf();
- match params {
- TunnelParameters::OpenVpn(params) => match &params.proxy {
- Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
- Some(CustomProxy::Socks5Local(_)) => None,
- Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
- },
- _ => Some(std::env::current_exe().unwrap()),
- }
- }
-
- fn start_wireguard_tunnel<L>(
+ fn start_wireguard_tunnel(
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
params: &wireguard_types::TunnelParameters,
#[cfg(any(target_os = "linux", target_os = "windows"))]
params: &wireguard_types::TunnelParameters,
log: Option<path::PathBuf>,
- args: TunnelArgs<'_, L>,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- {
+ args: TunnelArgs<'_, L, F>,
+ ) -> Result<Self> {
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
@@ -198,20 +170,14 @@ impl TunnelMonitor {
}
#[cfg(not(target_os = "android"))]
- async fn start_openvpn_tunnel<L>(
+ async fn start_openvpn_tunnel(
config: &openvpn_types::TunnelParameters,
log: Option<path::PathBuf>,
resource_dir: &path::Path,
on_event: L,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
- ) -> Result<Self>
- where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- {
+ ) -> Result<Self> {
let monitor = talpid_openvpn::OpenVpnMonitor::start(
on_event,
config,
@@ -289,13 +255,39 @@ impl TunnelMonitor {
}
}
-enum InternalTunnelMonitor {
+impl TunnelMonitor<()> {
+ /// Returns a path to an executable that communicates with relay servers.
+ /// Returns `None` if the executable is unknown.
+ #[cfg(windows)]
+ pub fn get_relay_client(
+ resource_dir: &path::Path,
+ params: &TunnelParameters,
+ ) -> Option<path::PathBuf> {
+ use talpid_types::net::proxy::CustomProxy;
+
+ let resource_dir = resource_dir.to_path_buf();
+ match params {
+ TunnelParameters::OpenVpn(params) => match &params.proxy {
+ Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
+ Some(CustomProxy::Socks5Local(_)) => None,
+ Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
+ },
+ _ => Some(std::env::current_exe().unwrap()),
+ }
+ }
+}
+
+enum InternalTunnelMonitor<F> {
#[cfg(not(target_os = "android"))]
OpenVpn(talpid_openvpn::OpenVpnMonitor),
- Wireguard(talpid_wireguard::WireguardMonitor),
+ Wireguard(talpid_wireguard::WireguardMonitor<F>),
}
-impl InternalTunnelMonitor {
+impl<L, F> InternalTunnelMonitor<L>
+where
+ L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()> + Send + 'static,
+{
fn wait(self) -> Result<()> {
#[cfg(not(target_os = "android"))]
let handle = tokio::runtime::Handle::current();
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 7387d51c03..1f1ea1be4f 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -214,14 +214,13 @@ impl ConnectingState {
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
- let on_tunnel_event =
- move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
- let (tx, rx) = oneshot::channel();
- let _ = event_tx.unbounded_send((event, tx));
- Box::pin(async move {
- let _ = rx.await;
- })
- };
+ let on_tunnel_event = move |event| {
+ let (tx, rx) = oneshot::channel();
+ let _ = event_tx.unbounded_send((event, tx));
+ async move {
+ let _ = rx.await;
+ }
+ };
let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
@@ -290,10 +289,14 @@ impl ConnectingState {
}
}
- fn wait_for_tunnel_monitor(
- tunnel_monitor: TunnelMonitor,
+ fn wait_for_tunnel_monitor<L, F>(
+ tunnel_monitor: TunnelMonitor<L>,
retry_attempt: u32,
- ) -> Option<ErrorStateCause> {
+ ) -> Option<ErrorStateCause>
+ where
+ L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()> + Send + 'static,
+ {
match tunnel_monitor.wait() {
Ok(_) => None,
Err(error) => match error {
diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs
index 421c2076ef..0939b2e343 100644
--- a/talpid-openvpn/src/lib.rs
+++ b/talpid-openvpn/src/lib.rs
@@ -245,7 +245,7 @@ impl WintunContextImpl {
impl OpenVpnMonitor<OpenVpnCommand> {
/// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given
/// path.
- pub async fn start<L>(
+ pub async fn start<L, F>(
on_event: L,
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
@@ -253,10 +253,8 @@ impl OpenVpnMonitor<OpenVpnCommand> {
route_manager: talpid_routing::RouteManagerHandle,
) -> Result<Self>
where
- L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
+ L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()> + Send + 'static,
{
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
@@ -808,14 +806,7 @@ mod event_server {
}
/// Implements a gRPC service used to process events sent to by OpenVPN.
- pub struct OpenvpnEventProxyImpl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- > {
+ pub struct OpenvpnEventProxyImpl<L> {
pub on_event: L,
pub user_pass_file_path: super::PathBuf,
pub proxy_auth_file_path: Option<super::PathBuf>,
@@ -828,13 +819,8 @@ mod event_server {
}
impl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- )
- -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
+ L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()>,
> OpenvpnEventProxyImpl<L>
{
async fn up_inner(
@@ -971,13 +957,8 @@ mod event_server {
#[tonic::async_trait]
impl<
- L: (Fn(
- talpid_tunnel::TunnelEvent,
- )
- -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
+ L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: std::future::Future<Output = ()> + 'static + Send,
> OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
{
async fn auth_failed(
diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs
index 53a746d102..9fe6e0074e 100644
--- a/talpid-tunnel/src/lib.rs
+++ b/talpid-tunnel/src/lib.rs
@@ -1,4 +1,5 @@
use std::{
+ future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::Path,
sync::{Arc, Mutex},
@@ -9,7 +10,7 @@ use std::{
pub mod network_interface;
pub mod tun_provider;
-use futures::{channel::oneshot, future::BoxFuture};
+use futures::channel::oneshot;
use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
@@ -28,9 +29,10 @@ pub const MIN_IPV4_MTU: u16 = 576;
pub const MIN_IPV6_MTU: u16 = 1280;
/// Arguments for creating a tunnel.
-pub struct TunnelArgs<'a, L>
+pub struct TunnelArgs<'a, L, F>
where
- L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
+ L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
+ F: Future<Output = ()>,
{
/// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 06726f9c69..fa52717e50 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -5,7 +5,7 @@
use self::config::Config;
#[cfg(windows)]
use futures::channel::mpsc;
-use futures::future::{BoxFuture, Future};
+use futures::future::Future;
use obfuscation::ObfuscatorHandle;
#[cfg(target_os = "android")]
use std::borrow::Cow;
@@ -26,9 +26,8 @@ use talpid_routing::{self, RequiredRoute};
use talpid_tunnel::tun_provider;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
-use talpid_types::net::wireguard::TunnelParameters;
use talpid_types::{
- net::{AllowedTunnelTraffic, Endpoint, TransportProtocol},
+ net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol},
BoxedError, ErrorExt,
};
use tokio::sync::Mutex as AsyncMutex;
@@ -60,7 +59,6 @@ type TunnelType = Box<dyn Tunnel>;
type TunnelType = WgGoTunnel;
type Result<T> = std::result::Result<T, Error>;
-type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
/// Errors that can happen in the Wireguard tunnel monitor.
#[derive(thiserror::Error, Debug)]
@@ -136,12 +134,12 @@ impl Error {
}
/// Spawns and monitors a wireguard tunnel
-pub struct WireguardMonitor {
+pub struct WireguardMonitor<F> {
runtime: tokio::runtime::Handle,
/// Tunnel implementation
tunnel: Arc<AsyncMutex<Option<TunnelType>>>,
/// Callback to signal tunnel events
- event_callback: EventCallback,
+ event_callback: F,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
@@ -155,22 +153,18 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock<bool> = LazyLock::new(|| {
.unwrap_or(false)
});
-impl WireguardMonitor {
+impl<F, Fut> WireguardMonitor<F>
+where
+ F: (Fn(TunnelEvent) -> Fut) + Send + Sync + Clone + 'static,
+ Fut: Future<Output = ()> + Send,
+{
/// Starts a WireGuard tunnel with the given config
#[cfg(not(target_os = "android"))]
- pub fn start<
- F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + Clone
- + 'static,
- >(
+ pub fn start(
params: &TunnelParameters,
log_path: Option<&Path>,
- args: TunnelArgs<'_, F>,
- ) -> Result<WireguardMonitor> {
- let on_event = args.on_event.clone();
-
+ args: TunnelArgs<'_, F, Fut>,
+ ) -> Result<WireguardMonitor<F>> {
#[cfg(any(target_os = "windows", target_os = "linux"))]
let desired_mtu = args
.runtime
@@ -225,11 +219,10 @@ impl WireguardMonitor {
.map_err(Error::ConnectivityMonitorError)?
.with_cancellation();
- let event_callback = Box::new(on_event.clone());
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
- event_callback,
+ event_callback: args.on_event.clone(),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator,
@@ -249,7 +242,7 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
+ (args.on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
// Add non-default routes before establishing the tunnel.
#[cfg(target_os = "linux")]
@@ -281,7 +274,7 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (on_event)(TunnelEvent::InterfaceUp(
+ (args.on_event)(TunnelEvent::InterfaceUp(
metadata,
Self::allowed_traffic_after_tunnel_config(),
))
@@ -350,7 +343,7 @@ impl WireguardMonitor {
.map_err(CloseMsg::SetupError)?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (on_event)(TunnelEvent::Up(metadata)).await;
+ (args.on_event)(TunnelEvent::Up(metadata)).await;
let monitored_tunnel = Arc::downgrade(&tunnel);
tokio::task::spawn_blocking(move || {
@@ -568,7 +561,6 @@ impl WireguardMonitor {
/// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true.
/// Used to block traffic to other destinations while connecting on Android.
- ///
#[cfg(target_os = "android")]
fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> {
if gateway_only {
@@ -910,7 +902,10 @@ impl WireguardMonitor {
fn get_post_tunnel_routes<'a>(
iface_name: &str,
config: &'a Config,
- ) -> impl Iterator<Item = RequiredRoute> + 'a {
+ ) -> impl Iterator<Item = RequiredRoute> + 'a
+ where
+ Fut: 'a,
+ {
let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config);
let iter = config
.get_tunnel_destinations()