summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-18 14:15:09 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-09-18 14:15:09 +0200
commit63c1abe65257b3357bbca696ad832968e6ecab44 (patch)
treec1663932f2fea79df9a4d4783edc3460abd53333
parenta8c1ca3474cb7741debf5ade3fdbd21de8dcc660 (diff)
parent0264abaf2801709bb9e78e533b8873a0ee3ae6dd (diff)
downloadmullvadvpn-63c1abe65257b3357bbca696ad832968e6ecab44.tar.xz
mullvadvpn-63c1abe65257b3357bbca696ad832968e6ecab44.zip
Merge branch 'timeout-negotiating-ephemeral-peer-des-1238'
-rw-r--r--talpid-wireguard/src/lib.rs214
1 files changed, 190 insertions, 24 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index a477bea80b..0c918e1fc7 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -20,7 +20,7 @@ use std::{
};
#[cfg(target_os = "linux")]
use std::{env, sync::LazyLock};
-use talpid_routing as routing;
+#[cfg(not(target_os = "android"))]
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
@@ -264,6 +264,7 @@ async fn maybe_create_obfuscator(
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
+ #[cfg(not(target_os = "android"))]
pub fn start<
F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
@@ -272,7 +273,7 @@ impl WireguardMonitor {
+ 'static,
>(
mut config: Config,
- #[cfg(not(target_os = "android"))] detect_mtu: bool,
+ detect_mtu: bool,
log_path: Option<&Path>,
args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
@@ -294,8 +295,6 @@ impl WireguardMonitor {
log_path,
args.resource_dir,
args.tun_provider.clone(),
- #[cfg(target_os = "android")]
- config.quantum_resistant,
#[cfg(target_os = "windows")]
args.route_manager.clone(),
#[cfg(target_os = "windows")]
@@ -303,15 +302,6 @@ impl WireguardMonitor {
)?;
let iface_name = tunnel.get_interface_name();
- #[cfg(target_os = "android")]
- if let Some(remote_socket_fd) = obfuscator.as_ref().map(|obfs| obfs.remote_socket_fd()) {
- // Exclude remote obfuscation socket or bridge
- log::debug!("Excluding remote socket fd from the tunnel");
- if let Err(error) = args.tun_provider.lock().unwrap().bypass(remote_socket_fd) {
- log::error!("Failed to exclude remote socket fd: {error}");
- }
- }
-
let obfuscator = Arc::new(AsyncMutex::new(obfuscator));
let event_callback = Box::new(on_event.clone());
@@ -376,8 +366,6 @@ impl WireguardMonitor {
args.retry_attempt,
obfuscator.clone(),
ephemeral_obfs_sender,
- #[cfg(target_os = "android")]
- args.tun_provider,
)
.await?;
@@ -389,7 +377,6 @@ impl WireguardMonitor {
.await;
}
- #[cfg(not(target_os = "android"))]
if detect_mtu {
let config = config.clone();
let iface_name = iface_name.clone();
@@ -420,6 +407,7 @@ impl WireguardMonitor {
};
});
}
+
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
match connectivity_monitor.establish_connectivity(args.retry_attempt) {
Ok(true) => Ok(connectivity_monitor),
@@ -480,6 +468,177 @@ impl WireguardMonitor {
Ok(monitor)
}
+ /// Starts a WireGuard tunnel with the given config
+ ///
+ /// This differs from [`start`] on other platforms in multiple ways. Here is a list of some
+ /// notable differences:
+ /// - A ping is sent between the Wireguard-GO tunnel is started and an ephemeral peer is
+ /// negotiated. There seems to be a race condition between starting the tunnel and the tunnel
+ /// being ready to serve traffic.
+ /// - No routes are configured on android.
+ #[cfg(target_os = "android")]
+ pub fn start<
+ F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ + Send
+ + Sync
+ + Clone
+ + 'static,
+ >(
+ mut config: Config,
+ log_path: Option<&Path>,
+ args: TunnelArgs<'_, F>,
+ ) -> Result<WireguardMonitor> {
+ let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita;
+ let tunnel = Self::open_tunnel(
+ args.runtime.clone(),
+ &config,
+ log_path,
+ args.resource_dir,
+ args.tun_provider.clone(),
+ // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs
+ // that we only allows traffic to/from the gateway. This is only needed on Android
+ // since we lack a firewall there.
+ should_negotiate_ephemeral_peer,
+ )?;
+
+ let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel();
+ let obfuscator = args.runtime.block_on(maybe_create_obfuscator(
+ &mut config,
+ close_obfs_sender.clone(),
+ ))?;
+
+ if let Some(remote_socket_fd) = obfuscator.as_ref().map(|obfs| obfs.remote_socket_fd()) {
+ // Exclude remote obfuscation socket or bridge
+ log::debug!("Excluding remote socket fd from the tunnel");
+ if let Err(error) = args.tun_provider.lock().unwrap().bypass(remote_socket_fd) {
+ log::error!("Failed to exclude remote socket fd: {error}");
+ }
+ }
+
+ let iface_name = tunnel.get_interface_name();
+
+ let (pinger_tx, pinger_rx) = sync_mpsc::channel();
+ let monitor = WireguardMonitor {
+ runtime: args.runtime.clone(),
+ tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
+ event_callback: Box::new(args.on_event.clone()),
+ close_msg_receiver: close_obfs_listener,
+ pinger_stop_sender: pinger_tx,
+ obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
+ };
+
+ let gateway = config.ipv4_gateway;
+ let connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
+ gateway,
+ Arc::downgrade(&monitor.tunnel),
+ pinger_rx,
+ )
+ .map_err(Error::ConnectivityMonitorError)?;
+
+ let moved_tunnel = monitor.tunnel.clone();
+ let moved_close_obfs_sender = close_obfs_sender.clone();
+ let moved_obfuscator = monitor.obfuscator.clone();
+ let tunnel_fut = async move {
+ let tunnel = moved_tunnel;
+ let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender;
+ let obfuscator = moved_obfuscator;
+ let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor));
+
+ let metadata = Self::tunnel_metadata(&iface_name, &config);
+ let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
+ (args.on_event.clone())(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ .await;
+
+ let handle_ping = |ping_result: std::result::Result<
+ bool,
+ connectivity_check::Error,
+ >| match ping_result {
+ Ok(true) => Ok(()),
+ Ok(false) => {
+ log::warn!("Timeout while checking tunnel connection");
+ Err(CloseMsg::PingErr)
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check tunnel connection")
+ );
+ Err(CloseMsg::PingErr)
+ }
+ };
+
+ // Prepare a closure which pings inside the tunnel when executed.
+ let ping = || {
+ let connectivity_monitor_arc = connectivity_monitor.clone();
+ let retry_attempt = args.retry_attempt;
+ move || {
+ let ping_result = connectivity_monitor_arc
+ .lock()
+ .unwrap()
+ .establish_connectivity(retry_attempt);
+ handle_ping(ping_result)
+ }
+ };
+
+ if should_negotiate_ephemeral_peer {
+ // Ping before negotiating the ephemeral peer to make sure that the tunnel works.
+ tokio::task::spawn_blocking(ping()).await.unwrap()?;
+ let ephemeral_obfs_sender = close_obfs_sender.clone();
+ Self::config_ephemeral_peers(
+ &tunnel,
+ &mut config,
+ args.retry_attempt,
+ obfuscator.clone(),
+ ephemeral_obfs_sender,
+ args.tun_provider,
+ )
+ .await?;
+
+ let metadata = Self::tunnel_metadata(&iface_name, &config);
+ (args.on_event.clone())(TunnelEvent::InterfaceUp(
+ metadata,
+ Self::allowed_traffic_after_tunnel_config(),
+ ))
+ .await;
+ }
+
+ // Make sure the tunnel works (after potentially having negotiated an ephemeral peer).
+ tokio::task::spawn_blocking(ping()).await.unwrap()?;
+
+ let metadata = Self::tunnel_metadata(&iface_name, &config);
+ (args.on_event.clone())(TunnelEvent::Up(metadata)).await;
+
+ tokio::task::spawn_blocking(move || {
+ if let Err(error) = connectivity_monitor.lock().unwrap().run() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Connectivity monitor failed")
+ );
+ }
+ })
+ .await
+ .unwrap();
+
+ Err::<Infallible, CloseMsg>(CloseMsg::PingErr)
+ };
+
+ let close_sender = close_obfs_sender.clone();
+ let monitor_handle = tokio::spawn(async move {
+ // This is safe to unwrap because the future resolves to `Result<Infallible, E>`.
+ let close_msg = tunnel_fut.await.unwrap_err();
+ let _ = close_sender.send(close_msg);
+ });
+
+ tokio::spawn(async move {
+ if args.tunnel_close_rx.await.is_ok() {
+ monitor_handle.abort();
+ let _ = close_obfs_sender.send(CloseMsg::Stop);
+ }
+ });
+
+ Ok(monitor)
+ }
+
fn allowed_traffic_during_tunnel_config(config: &Config) -> AllowedTunnelTraffic {
// During ephemeral peer negotiation, only allow traffic to the config service.
if config.quantum_resistant || config.daita {
@@ -754,7 +913,7 @@ impl WireguardMonitor {
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] gateway_only: bool,
- #[cfg(windows)] route_manager: crate::routing::RouteManagerHandle,
+ #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Box<dyn Tunnel>> {
log::debug!("Tunnel MTU: {}", config.mtu);
@@ -894,6 +1053,7 @@ impl WireguardMonitor {
/// Returns routes to the peer endpoints (through the physical interface).
#[cfg_attr(target_os = "linux", allow(unused_variables))]
+ #[cfg(not(target_os = "android"))]
fn get_endpoint_routes(endpoints: &[IpAddr]) -> impl Iterator<Item = RequiredRoute> + '_ {
#[cfg(target_os = "linux")]
{
@@ -904,37 +1064,42 @@ impl WireguardMonitor {
endpoints.iter().map(|ip| {
RequiredRoute::new(
ipnetwork::IpNetwork::from(*ip),
- routing::NetNode::DefaultNode,
+ talpid_routing::NetNode::DefaultNode,
)
})
}
#[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
- fn get_tunnel_nodes(iface_name: &str, config: &Config) -> (routing::Node, routing::Node) {
+ #[cfg(not(target_os = "android"))]
+ fn get_tunnel_nodes(
+ iface_name: &str,
+ config: &Config,
+ ) -> (talpid_routing::Node, talpid_routing::Node) {
#[cfg(windows)]
{
- let v4 = routing::Node::new(config.ipv4_gateway.into(), iface_name.to_string());
+ let v4 = talpid_routing::Node::new(config.ipv4_gateway.into(), iface_name.to_string());
let v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() {
- routing::Node::new((*ipv6_gateway).into(), iface_name.to_string())
+ talpid_routing::Node::new((*ipv6_gateway).into(), iface_name.to_string())
} else {
- routing::Node::device(iface_name.to_string())
+ talpid_routing::Node::device(iface_name.to_string())
};
(v4, v6)
}
#[cfg(not(windows))]
{
- let node = routing::Node::device(iface_name.to_string());
+ let node = talpid_routing::Node::device(iface_name.to_string());
(node.clone(), node)
}
}
/// Return routes for all allowed IPs, as well as the gateway, except 0.0.0.0/0.
+ #[cfg(not(target_os = "android"))]
fn get_pre_tunnel_routes<'a>(
iface_name: &str,
config: &'a Config,
) -> impl Iterator<Item = RequiredRoute> + 'a {
- let gateway_node = routing::Node::device(iface_name.to_string());
+ let gateway_node = talpid_routing::Node::device(iface_name.to_string());
let gateway_routes = std::iter::once(RequiredRoute::new(
ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(),
gateway_node.clone(),
@@ -965,6 +1130,7 @@ impl WireguardMonitor {
}
/// Return any 0.0.0.0/0 routes specified by the allowed IPs.
+ #[cfg(not(target_os = "android"))]
fn get_post_tunnel_routes<'a>(
iface_name: &str,
config: &'a Config,