diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-11-22 17:54:43 +0100 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-11-22 17:54:43 +0100 |
| commit | b1737f5543ed8896c45f28a4e37c022991f22adf (patch) | |
| tree | 6ab95a5148c2c1aff752e3ebcc2801085f064b34 | |
| parent | b2d3287552a6530901e7e954daa5bb446307f672 (diff) | |
| parent | 3516c3c5f987a47b922670aa6d6f34c8c864af8a (diff) | |
| download | mullvadvpn-b1737f5543ed8896c45f28a4e37c022991f22adf.tar.xz mullvadvpn-b1737f5543ed8896c45f28a4e37c022991f22adf.zip | |
Merge branch 'implement-wgturnonmultihop-for-android-droid-1365'
22 files changed, 1378 insertions, 636 deletions
diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs index 550d1955a0..86f0b300e5 100644 --- a/mullvad-relay-selector/src/relay_selector/mod.rs +++ b/mullvad-relay-selector/src/relay_selector/mod.rs @@ -722,12 +722,6 @@ impl RelaySelector { custom_lists: &CustomListsSettings, parsed_relays: &ParsedRelays, ) -> Result<WireguardConfig, Error> { - // TODO: Remove when Android gets support for multihop. - if cfg!(target_os = "android") { - let relay = Self::get_wireguard_singlehop_config(query, custom_lists, parsed_relays) - .ok_or(Error::NoRelay)?; - return Ok(WireguardConfig::from(relay)); - } let inner = if query.singlehop() { match Self::get_wireguard_singlehop_config(query, custom_lists, parsed_relays) { Some(exit) => WireguardConfig::from(exit), diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 1ec8ba46c5..e53b3fa54a 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -435,15 +435,26 @@ impl AllowedClients { } } +/// What [`Endpoint`]s to allow the client to send traffic to and receive from. +/// +/// In some cases we want to restrict what IP addresses the client may communicate with even +/// inside of the tunnel, for example while negotiating a PQ-safe PSK with an ephemeral peer. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum AllowedTunnelTraffic { + /// Block all traffic inside the tunnel. None, + /// Allow all traffic inside the tunnel. This is the normal mode of operation. All, + /// Only allow communication with this specific endpoint. This will usually be a relay during a + /// short amount of time. One(Endpoint), + /// Only allow communication with these two specific endpoints. The intended use case for this + /// is while negotiating for example a PSK with both the entry & exit relays in a multihop setup. Two(Endpoint, Endpoint), } impl AllowedTunnelTraffic { + /// Do we currently allow traffic to all endpoints? pub fn all(&self) -> bool { matches!(self, AllowedTunnelTraffic::All) } diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index 5326427d13..0469273545 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -3,6 +3,7 @@ use std::{ ffi::CString, net::{Ipv4Addr, Ipv6Addr}, }; +use talpid_types::net::wireguard::{PeerConfig, PrivateKey}; use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions}; /// Name to use for the tunnel device @@ -121,38 +122,12 @@ impl Config { /// Returns a CString with the appropriate config for WireGuard-go // TODO: Consider outputting both overriding and additive configs pub fn to_userspace_format(&self) -> CString { - // the order of insertion matters, public key entry denotes a new peer entry - let mut wg_conf = WgConfigBuffer::new(); - wg_conf - .add::<&[u8]>("private_key", self.tunnel.private_key.to_bytes().as_ref()) - .add("listen_port", "0"); - - #[cfg(target_os = "linux")] - if let Some(fwmark) = &self.fwmark { - wg_conf.add("fwmark", fwmark.to_string().as_str()); - } - - wg_conf.add("replace_peers", "true"); - - for peer in self.peers() { - wg_conf - .add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref()) - .add("endpoint", peer.endpoint.to_string().as_str()) - .add("replace_allowed_ips", "true"); - if let Some(ref psk) = peer.psk { - wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref()); - } - for addr in &peer.allowed_ips { - wg_conf.add("allowed_ip", addr.to_string().as_str()); - } - #[cfg(daita)] - if peer.constant_packet_size { - wg_conf.add("constant_packet_size", "true"); - } - } - - let bytes = wg_conf.into_config(); - CString::new(bytes).expect("null bytes inside config") + userspace_format( + &self.tunnel.private_key, + self.peers(), + #[cfg(target_os = "linux")] + self.fwmark, + ) } /// Return whether the config connects to an exit peer from another remote peer. @@ -185,6 +160,13 @@ impl Config { .into_iter() .chain(std::iter::once(&mut self.entry_peer)) } + + /// Return routes for all allowed IPs. + pub fn get_tunnel_destinations(&self) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { + self.peers() + .flat_map(|peer| peer.allowed_ips.iter()) + .cloned() + } } enum ConfValue<'a> { @@ -235,3 +217,48 @@ impl WgConfigBuffer { self.buf } } + +/// Returns a CString with the appropriate config for WireGuard-go +#[allow(single_use_lifetimes)] +pub fn userspace_format<'a>( + private_key: &PrivateKey, + peers: impl Iterator<Item = &'a PeerConfig>, + #[cfg(target_os = "linux")] fwmark: Option<u32>, +) -> CString { + // the order of insertion matters, public key entry denotes a new peer entry + let mut wg_conf = WgConfigBuffer::new(); + wg_conf + .add::<&[u8]>("private_key", private_key.to_bytes().as_ref()) + .add("listen_port", "0"); + + #[cfg(target_os = "linux")] + if let Some(fwmark) = fwmark { + wg_conf.add("fwmark", fwmark.to_string().as_str()); + } + + wg_conf.add("replace_peers", "true"); + + for peer in peers { + write_peer_to_config(&mut wg_conf, peer) + } + + let bytes = wg_conf.into_config(); + CString::new(bytes).expect("null bytes inside config") +} + +fn write_peer_to_config(wg_conf: &mut WgConfigBuffer, peer: &PeerConfig) { + wg_conf + .add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref()) + .add("endpoint", peer.endpoint.to_string().as_str()) + .add("replace_allowed_ips", "true"); + if let Some(ref psk) = peer.psk { + wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref()); + } + for addr in &peer.allowed_ips { + wg_conf.add("allowed_ip", addr.to_string().as_str()); + } + #[cfg(daita)] + if peer.constant_packet_size { + wg_conf.add("constant_packet_size", "true"); + } +} diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity/check.rs index 608002d1a6..527931563b 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity/check.rs @@ -1,52 +1,17 @@ -use crate::{ - ping_monitor::{new_pinger, Pinger}, - stats::StatsMap, -}; -use std::{ - cmp, - net::Ipv4Addr, - sync::{mpsc, Weak}, - time::{Duration, Instant}, -}; -use tokio::sync::Mutex; +use std::cmp; +use std::net::Ipv4Addr; +use std::sync::mpsc; +use std::time::{Duration, Instant}; -use super::{Tunnel, TunnelError}; +use super::constants::*; +use super::error::Error; +use super::pinger; -/// Sleep time used when initially establishing connectivity -const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); -/// Sleep time used when checking if an established connection is still working. -const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); - -/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is -/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic -/// is received. -const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5); -/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will -/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received. -const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120); -/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this -/// timeout is reached, it is assumed that the connection is lost. -const PING_TIMEOUT: Duration = Duration::from_secs(15); -/// Timeout for receiving traffic when establishing a connection. -const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4); -/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt. -const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; -/// Maximum timeout for establishing a connection. -const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT; -/// Number of seconds to wait between sending ICMP packets -const SECONDS_PER_PING: Duration = Duration::from_secs(3); - -/// Connectivity monitor errors -#[derive(thiserror::Error, Debug)] -pub enum Error { - /// Failed to read tunnel's configuration - #[error("Failed to read tunnel's configuration")] - ConfigReadError(TunnelError), - - /// Failed to send ping - #[error("Ping monitor failed")] - PingError(#[from] crate::ping_monitor::Error), -} +use crate::stats::StatsMap; +#[cfg(target_os = "android")] +use crate::Tunnel; +use crate::{TunnelError, TunnelType}; +use pinger::Pinger; /// Verifies if a connection to a tunnel is working. /// The connectivity monitor is biased to receiving traffic - it is expected that all outgoing @@ -70,60 +35,126 @@ pub enum Error { /// /// Once a connection established, a connection is only considered broken once the connectivity /// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. -pub struct ConnectivityMonitor { - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, +pub struct Check<Strategy = Timeout> { conn_state: ConnState, - initial_ping_timestamp: Option<Instant>, - num_pings_sent: u32, - pinger: Box<dyn Pinger>, + ping_state: PingState, + strategy: Strategy, + retry_attempt: u32, +} + +// Define the type state of [Check] +pub(crate) trait Strategy { + fn should_shut_down(&mut self, timeout: Duration) -> bool; +} + +/// An uncancellable [Check] that will run [Check::establish_connectivity] until +/// completion or until it times out. +pub struct Timeout; + +impl Strategy for Timeout { + /// The Timeout strategy cannot receive shut down signals so this function always returns false. + fn should_shut_down(&mut self, _timeout: Duration) -> bool { + false + } +} + +/// A cancellable [Check] may be cancelled before it will time out by sending +/// a signal on the channel returned by [Check::with_cancellation]. Otherwise, +/// it behaves as [Timeout]. +pub struct Cancellable { close_receiver: mpsc::Receiver<()>, } -impl ConnectivityMonitor { - pub(super) fn new( +impl Strategy for Cancellable { + /// Returns true if monitor should be shut down + fn should_shut_down(&mut self, timeout: Duration) -> bool { + match self.close_receiver.recv_timeout(timeout) { + Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, + Err(mpsc::RecvTimeoutError::Timeout) => false, + } + } +} + +impl Check<Timeout> { + pub fn new( addr: Ipv4Addr, #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, - close_receiver: mpsc::Receiver<()>, - ) -> Result<Self, Error> { - let pinger = new_pinger( - addr, - #[cfg(any(target_os = "macos", target_os = "linux"))] - interface, - ) - .map_err(Error::PingError)?; + retry_attempt: u32, + ) -> Result<Check<Timeout>, Error> { + Ok(Check { + conn_state: ConnState::new(Instant::now(), Default::default()), + ping_state: PingState::new( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + )?, + strategy: Timeout, + retry_attempt, + }) + } - let now = Instant::now(); + /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping + /// the returned channel. + pub fn with_cancellation(self) -> (Check<Cancellable>, mpsc::Sender<()>) { + let (cancellation_tx, cancellation_rx) = mpsc::channel(); + let check = Check { + conn_state: self.conn_state, + ping_state: self.ping_state, + strategy: Cancellable { + close_receiver: cancellation_rx, + }, + retry_attempt: self.retry_attempt, + }; + (check, cancellation_tx) + } - Ok(Self { - tunnel_handle, - conn_state: ConnState::new(now, Default::default()), - initial_ping_timestamp: None, - num_pings_sent: 0, - pinger, - close_receiver, - }) + #[cfg(test)] + /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy, + /// see [Check::with_cancellation]. + pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self { + Check { + conn_state, + ping_state, + strategy: Timeout, + retry_attempt: 0, + } } +} +impl<S: Strategy> Check<S> { // checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is // successful at the start of a connection. - pub(super) fn establish_connectivity(&mut self, retry_attempt: u32) -> Result<bool, Error> { + pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result<bool, Error> { // Send initial ping to prod WireGuard into connecting. - self.pinger.send_icmp().map_err(Error::PingError)?; + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; self.establish_connectivity_inner( - retry_attempt, + self.retry_attempt, ESTABLISH_TIMEOUT, ESTABLISH_TIMEOUT_MULTIPLIER, MAX_ESTABLISH_TIMEOUT, + tunnel_handle, ) } + pub(crate) fn reset(&mut self, current_iteration: Instant) { + self.ping_state.reset(); + self.conn_state.reset_after_suspension(current_iteration); + } + + pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool { + self.strategy.should_shut_down(timeout) + } + fn establish_connectivity_inner( &mut self, retry_attempt: u32, timeout_initial: Duration, timeout_multiplier: u32, max_timeout: Duration, + tunnel_handle: &TunnelType, ) -> Result<bool, Error> { if self.conn_state.connected() { return Ok(true); @@ -136,7 +167,7 @@ impl ConnectivityMonitor { let start = Instant::now(); while start.elapsed() < check_timeout { - if self.check_connectivity_interval(Instant::now(), check_timeout)? { + if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? { return Ok(true); } if self.should_shut_down(DELAY_ON_INITIAL_SETUP) { @@ -146,46 +177,13 @@ impl ConnectivityMonitor { Ok(false) } - pub(super) fn run(&mut self) -> Result<(), Error> { - self.wait_loop(REGULAR_LOOP_SLEEP) - } - - /// Returns true if monitor should be shut down - fn should_shut_down(&mut self, timeout: Duration) -> bool { - match self.close_receiver.recv_timeout(timeout) { - Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, - Err(mpsc::RecvTimeoutError::Timeout) => false, - } - } - - fn wait_loop(&mut self, iter_delay: Duration) -> Result<(), Error> { - let mut last_iteration = Instant::now(); - while !self.should_shut_down(iter_delay) { - let mut current_iteration = Instant::now(); - let time_slept = current_iteration - last_iteration; - if time_slept < (iter_delay * 2) { - if !self.check_connectivity(Instant::now())? { - return Ok(()); - } - - let end = Instant::now(); - if end - current_iteration > Duration::from_secs(1) { - current_iteration = end; - } - } else { - // Loop was suspended for too long, so it's safer to assume that the host still has - // connectivity. - self.reset_pinger(); - self.conn_state.reset_after_suspension(current_iteration); - } - last_iteration = current_iteration; - } - Ok(()) - } - /// Returns true if connection is established - fn check_connectivity(&mut self, now: Instant) -> Result<bool, Error> { - self.check_connectivity_interval(now, PING_TIMEOUT) + pub(crate) fn check_connectivity( + &mut self, + now: Instant, + tunnel_handle: &TunnelType, + ) -> Result<bool, Error> { + self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle) } /// Returns true if connection is established @@ -193,19 +191,18 @@ impl ConnectivityMonitor { &mut self, now: Instant, timeout: Duration, + tunnel_handle: &TunnelType, ) -> Result<bool, Error> { - match self.get_stats() { + match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? { None => Ok(false), Some(new_stats) => { - let new_stats = new_stats?; - if self.conn_state.update(now, new_stats) { - self.reset_pinger(); + self.ping_state.reset(); return Ok(true); } self.maybe_send_ping(now)?; - Ok(!self.ping_timed_out(timeout) && self.conn_state.connected()) + Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected()) } } } @@ -214,19 +211,14 @@ impl ConnectivityMonitor { /// calls will also return None. /// /// NOTE: will panic if called from within a tokio runtime. - fn get_stats(&self) -> Option<Result<StatsMap, Error>> { - self.tunnel_handle - .upgrade()? - .blocking_lock() - .as_ref() - .and_then(|tunnel| match tunnel.get_tunnel_stats() { - Ok(stats) if stats.is_empty() => { - log::error!("Tunnel unexpectedly shut down"); - None - } - Ok(stats) => Some(Ok(stats)), - Err(error) => Some(Err(Error::ConfigReadError(error))), - }) + fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> { + let stats = tunnel_handle.get_tunnel_stats()?; + if stats.is_empty() { + log::error!("Tunnel unexpectedly shut down"); + Ok(None) + } else { + Ok(Some(stats)) + } } fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> { @@ -235,20 +227,55 @@ impl ConnectivityMonitor { // 3 seconds. if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out()) && self + .ping_state .initial_ping_timestamp .map(|initial_ping_timestamp| { - initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING + initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent + < SECONDS_PER_PING }) .unwrap_or(true) { - self.pinger.send_icmp().map_err(Error::PingError)?; - if self.initial_ping_timestamp.is_none() { - self.initial_ping_timestamp = Some(now); + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; + if self.ping_state.initial_ping_timestamp.is_none() { + self.ping_state.initial_ping_timestamp = Some(now); } - self.num_pings_sent += 1; + self.ping_state.num_pings_sent += 1; } Ok(()) } +} + +pub(super) struct PingState { + initial_ping_timestamp: Option<Instant>, + num_pings_sent: u32, + pinger: Box<dyn Pinger>, +} + +impl PingState { + pub(super) fn new( + addr: Ipv4Addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, + ) -> Result<Self, Error> { + let pinger = pinger::new_pinger( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + ) + .map_err(Error::PingError)?; + + Ok(Self::new_with(pinger)) + } + + pub(super) fn new_with(pinger: Box<dyn Pinger>) -> Self { + Self { + initial_ping_timestamp: None, + num_pings_sent: 0, + pinger, + } + } fn ping_timed_out(&self, timeout: Duration) -> bool { self.initial_ping_timestamp @@ -257,14 +284,14 @@ impl ConnectivityMonitor { } /// Reset timeouts - assume that the last time bytes were received is now. - fn reset_pinger(&mut self) { + fn reset(&mut self) { self.initial_ping_timestamp = None; self.num_pings_sent = 0; self.pinger.reset(); } } -enum ConnState { +pub(super) enum ConnState { Connecting { start: Instant, stats: StatsMap, @@ -397,21 +424,8 @@ impl ConnState { #[cfg(test)] mod test { - use futures::Future; - use super::*; - use crate::{ - config::Config, - stats::{self, Stats}, - Tunnel, - }; - use std::{ - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - }; + use crate::connectivity::mock::*; /// Test if a newly created ConnState won't have timed out or consider itself connected #[test] @@ -517,300 +531,76 @@ mod test { assert!(!conn_state.traffic_timed_out()); } - #[derive(Default)] - struct MockPinger { - on_send_ping: Option<Box<dyn FnMut() + Send>>, - } - - impl Pinger for MockPinger { - fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> { - if let Some(callback) = self.on_send_ping.as_mut() { - (callback)(); - } - Ok(()) - } - } - - struct MockTunnel { - on_get_stats: Box<dyn Fn() -> Result<stats::StatsMap, TunnelError> + Send>, - } - - impl MockTunnel { - const PEER: [u8; 32] = [0u8; 32]; - - fn new<F: Fn() -> Result<stats::StatsMap, TunnelError> + Send + 'static>(f: F) -> Self { - Self { - on_get_stats: Box::new(f), - } - } - - fn always_incrementing() -> Self { - let mut map = stats::StatsMap::new(); - map.insert( - Self::PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let peers = std::sync::Mutex::new(map); - Self { - on_get_stats: Box::new(move || { - let mut peers = peers.lock().unwrap(); - for traffic in peers.values_mut() { - traffic.tx_bytes += 1; - traffic.rx_bytes += 1; - } - Ok(peers.clone()) - }), - } - } - - fn never_incrementing() -> Self { - Self { - on_get_stats: Box::new(|| { - let mut map = stats::StatsMap::new(); - map.insert( - Self::PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - Ok(map) - }), - } - } - - #[allow(clippy::type_complexity)] - fn into_locked( - self, - ) -> ( - Arc<Mutex<Option<Box<dyn Tunnel>>>>, - Weak<Mutex<Option<Box<dyn Tunnel>>>>, - ) { - let dyn_tunnel: Box<dyn Tunnel> = Box::new(self); - let arc = Arc::new(Mutex::new(Some(dyn_tunnel))); - let weak_ref = Arc::downgrade(&arc); - (arc, weak_ref) - } - } - - impl Tunnel for MockTunnel { - fn get_interface_name(&self) -> String { - "mock-tunnel".to_string() - } - - fn stop(self: Box<Self>) -> Result<(), TunnelError> { - Ok(()) - } - - fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> { - (self.on_get_stats)() - } - - fn set_config( - &mut self, - _config: Config, - ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { - Box::pin(async { Ok(()) }) - } - - #[cfg(daita)] - fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { - Ok(()) - } - } - - fn mock_monitor( - now: Instant, - pinger: Box<dyn Pinger>, - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, - close_receiver: mpsc::Receiver<()>, - ) -> ConnectivityMonitor { - ConnectivityMonitor { - conn_state: ConnState::new(now, Default::default()), - initial_ping_timestamp: None, - num_pings_sent: 0, - pinger, - close_receiver, - tunnel_handle, - } - } - - fn connected_state(timestamp: Instant) -> ConnState { - const PEER: [u8; 32] = [0u8; 32]; - let mut stats = stats::StatsMap::new(); - stats.insert( - PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - ConnState::Connected { - rx_timestamp: timestamp, - tx_timestamp: timestamp, - stats, - } - } - #[test] /// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is /// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`. fn test_ping_times_out() { - let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now .checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10)) .unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut checker = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established - monitor.conn_state = connected_state(start); + checker.conn_state = connected_state(start); // A ping was sent to verify connectivity - monitor.maybe_send_ping(start).unwrap(); - assert!(!monitor.check_connectivity(now).unwrap()) + checker.maybe_send_ping(start).unwrap(); + assert!(!checker.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. fn test_no_connection_on_start() { - let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut monitor = mock_checker(start, Box::new(pinger)); - assert!(!monitor.check_connectivity(now).unwrap()) + assert!(!monitor.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. fn test_connection_works() { - let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::always_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut monitor = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established monitor.conn_state = connected_state(start); - assert!(monitor.check_connectivity(now).unwrap()) - } - - #[test] - /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, - /// and it shuts down properly. - fn test_wait_loop() { - let (result_tx, result_rx) = mpsc::channel(); - let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); - let pinger = MockPinger::default(); - let (stop_tx, stop_rx) = mpsc::channel(); - std::thread::spawn(move || { - let now = Instant::now(); - let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); - - let start_result = monitor.establish_connectivity(0); - result_tx.send(start_result).unwrap(); - - let result = monitor.run().map(|_| true); - result_tx.send(result).unwrap(); - }); - - std::thread::sleep(Duration::from_secs(1)); - assert!(result_rx.try_recv().unwrap().unwrap()); - stop_tx.send(()).unwrap(); - std::thread::sleep(Duration::from_secs(1)); - assert!(result_rx.try_recv().unwrap().is_ok()); - } - - #[test] - /// Verify that the connectivity monitor detects the tunnel timing out after no longer than - /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. - fn test_wait_loop_timeout() { - let should_stop = Arc::new(AtomicBool::new(false)); - let should_stop_inner = should_stop.clone(); - - let mut map = stats::StatsMap::new(); - map.insert( - [0u8; 32], - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let tunnel_stats = std::sync::Mutex::new(map); - - let pinger = MockPinger::default(); - let (_tunnel_anchor, tunnel) = MockTunnel::new(move || { - let mut tunnel_stats = tunnel_stats.lock().unwrap(); - if !should_stop_inner.load(Ordering::SeqCst) { - for traffic in tunnel_stats.values_mut() { - traffic.rx_bytes += 1; - } - } - for traffic in tunnel_stats.values_mut() { - traffic.tx_bytes += 1; - } - Ok(tunnel_stats.clone()) - }) - .into_locked(); - - let (result_tx, result_rx) = mpsc::channel(); - - let (_stop_tx, stop_rx) = mpsc::channel(); - std::thread::spawn(move || { - let now = Instant::now(); - let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); - let start_result = monitor.establish_connectivity(0); - result_tx.send(start_result).unwrap(); - let end_result = monitor.run().map(|_| true); - result_tx.send(end_result).expect("Failed to send result"); - }); - assert!(result_rx - .recv_timeout(Duration::from_secs(1)) - .unwrap() - .unwrap()); - should_stop.store(true, Ordering::SeqCst); - assert!(result_rx - .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) - .unwrap() - .is_ok()); + assert!(monitor.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that the timeout for setting up a tunnel works as expected. fn test_establish_timeout() { - let mut tunnel_stats = stats::StatsMap::new(); - tunnel_stats.insert( - [0u8; 32], - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let pinger = MockPinger::default(); - let (_tunnel_anchor, tunnel) = - MockTunnel::new(move || Ok(tunnel_stats.clone())).into_locked(); + let tunnel = { + let mut tunnel_stats = StatsMap::new(); + tunnel_stats.insert( + [0u8; 32], + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed() + }; let (result_tx, result_rx) = mpsc::channel(); - let (_stop_tx, stop_rx) = mpsc::channel(); std::thread::spawn(move || { let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); + let mut monitor = mock_checker(start, Box::new(pinger)); const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500); @@ -823,6 +613,7 @@ mod test { ESTABLISH_TIMEOUT, ESTABLISH_TIMEOUT_MULTIPLIER, MAX_ESTABLISH_TIMEOUT, + &tunnel, )) .unwrap(); } diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs new file mode 100644 index 0000000000..a8d6752ddd --- /dev/null +++ b/talpid-wireguard/src/connectivity/constants.rs @@ -0,0 +1,22 @@ +use std::time::Duration; + +/// Sleep time used when initially establishing connectivity +pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); +/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is +/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic +/// is received. +pub(crate) const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5); +/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will +/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received. +pub(crate) const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120); +/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this +/// timeout is reached, it is assumed that the connection is lost. +pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(15); +/// Timeout for receiving traffic when establishing a connection. +pub(crate) const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4); +/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt. +pub(crate) const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; +/// Maximum timeout for establishing a connection. +pub(crate) const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT; +/// Number of seconds to wait between sending ICMP packets +pub(crate) const SECONDS_PER_PING: Duration = Duration::from_secs(3); diff --git a/talpid-wireguard/src/connectivity/error.rs b/talpid-wireguard/src/connectivity/error.rs new file mode 100644 index 0000000000..9e8c98a751 --- /dev/null +++ b/talpid-wireguard/src/connectivity/error.rs @@ -0,0 +1,14 @@ +use super::pinger; +use crate::TunnelError; + +/// Connectivity monitor errors +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Failed to read tunnel's configuration + #[error("Failed to read tunnel's configuration")] + ConfigReadError(TunnelError), + + /// Failed to send ping + #[error("Ping failed")] + PingError(#[from] pinger::Error), +} diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs new file mode 100644 index 0000000000..892f3966ea --- /dev/null +++ b/talpid-wireguard/src/connectivity/mock.rs @@ -0,0 +1,133 @@ +use std::future::Future; +use std::pin::Pin; +use std::time::Instant; + +use super::check::{ConnState, PingState, Timeout}; +use super::pinger; +use super::Check; + +use crate::{Config, Tunnel, TunnelError}; +use pinger::Pinger; + +// Convenient re-exports +pub use crate::stats::{Stats, StatsMap}; + +#[derive(Default)] +pub(crate) struct MockPinger { + on_send_ping: Option<Box<dyn FnMut() + Send>>, +} + +pub(crate) struct MockTunnel { + on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send>, +} + +pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> Check<Timeout> { + let conn_state = ConnState::new(now, Default::default()); + let ping_state = PingState::new_with(pinger); + Check::mock(conn_state, ping_state) +} + +pub fn connected_state(timestamp: Instant) -> ConnState { + const PEER: [u8; 32] = [0u8; 32]; + let mut stats = StatsMap::new(); + stats.insert( + PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + ConnState::Connected { + rx_timestamp: timestamp, + tx_timestamp: timestamp, + stats, + } +} + +impl MockTunnel { + const PEER: [u8; 32] = [0u8; 32]; + + pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + 'static>(f: F) -> Self { + Self { + on_get_stats: Box::new(f), + } + } + + /// Convert self to the more general [TunnelType]. + pub fn boxed(self) -> Box<dyn Tunnel> { + Box::new(self) + } + + pub fn always_incrementing() -> Self { + let mut map = StatsMap::new(); + map.insert( + Self::PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + let peers = std::sync::Mutex::new(map); + Self { + on_get_stats: Box::new(move || { + let mut peers = peers.lock().unwrap(); + for traffic in peers.values_mut() { + traffic.tx_bytes += 1; + traffic.rx_bytes += 1; + } + Ok(peers.clone()) + }), + } + } + + pub fn never_incrementing() -> Self { + Self { + on_get_stats: Box::new(|| { + let mut map = StatsMap::new(); + map.insert( + Self::PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + Ok(map) + }), + } + } +} + +impl Tunnel for MockTunnel { + fn get_interface_name(&self) -> String { + "mock-tunnel".to_string() + } + + fn stop(self: Box<Self>) -> Result<(), TunnelError> { + Ok(()) + } + + fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> { + (self.on_get_stats)() + } + + fn set_config( + &mut self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { + Box::pin(async { Ok(()) }) + } + + #[cfg(daita)] + fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + Ok(()) + } +} + +impl Pinger for MockPinger { + fn send_icmp(&mut self) -> Result<(), pinger::Error> { + if let Some(callback) = self.on_send_ping.as_mut() { + (callback)(); + } + Ok(()) + } +} diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs new file mode 100644 index 0000000000..512d8715f1 --- /dev/null +++ b/talpid-wireguard/src/connectivity/mod.rs @@ -0,0 +1,13 @@ +mod check; +mod constants; +mod error; +#[cfg(test)] +mod mock; +mod monitor; +mod pinger; + +#[cfg(target_os = "android")] +pub use check::Cancellable; +pub use check::Check; +pub use error::Error; +pub use monitor::Monitor; diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs new file mode 100644 index 0000000000..583b8d9589 --- /dev/null +++ b/talpid-wireguard/src/connectivity/monitor.rs @@ -0,0 +1,174 @@ +use std::{ + sync::Weak, + time::{Duration, Instant}, +}; + +use tokio::sync::Mutex; + +use crate::TunnelType; + +use super::check::{Cancellable, Check}; +use super::error::Error; + +/// Sleep time used when checking if an established connection is still working. +const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); + +pub struct Monitor { + connectivity_check: Check<Cancellable>, +} + +impl Monitor { + pub fn init(connectivity_check: Check<Cancellable>) -> Self { + Self { connectivity_check } + } + + pub fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> { + self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle) + } + + fn wait_loop( + mut self, + iter_delay: Duration, + tunnel_handle: Weak<Mutex<Option<TunnelType>>>, + ) -> Result<(), Error> { + let mut last_iteration = Instant::now(); + while !self.connectivity_check.should_shut_down(iter_delay) { + let mut current_iteration = Instant::now(); + let time_slept = current_iteration - last_iteration; + if time_slept < (iter_delay * 2) { + let Some(tunnel) = tunnel_handle.upgrade() else { + return Ok(()); + }; + let lock = tunnel.blocking_lock(); + let Some(tunnel) = lock.as_ref() else { + return Ok(()); + }; + + if !self + .connectivity_check + .check_connectivity(Instant::now(), tunnel)? + { + return Ok(()); + } + drop(lock); + + let end = Instant::now(); + if end - current_iteration > Duration::from_secs(1) { + current_iteration = end; + } + } else { + // Loop was suspended for too long, so it's safer to assume that the host still has + // connectivity. + self.connectivity_check.reset(current_iteration); + } + last_iteration = current_iteration; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + // TODO: Port to async + tokio to reduce cost of testing? + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc; + use std::sync::Arc; + use std::time::Duration; + use std::time::Instant; + + use tokio::sync::Mutex; + + use crate::connectivity::constants::*; + use crate::connectivity::mock::*; + + #[test] + /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, + /// and it shuts down properly. + fn test_wait_loop() { + use std::sync::mpsc; + let (result_tx, result_rx) = mpsc::channel(); + let tunnel = MockTunnel::always_incrementing().boxed(); + let pinger = MockPinger::default(); + let (mut checker, stop_tx) = { + let now = Instant::now(); + let start = now.checked_sub(Duration::from_secs(1)).unwrap(); + mock_checker(start, Box::new(pinger)).with_cancellation() + }; + std::thread::spawn(move || { + let start_result = checker.establish_connectivity(&tunnel); + result_tx.send(start_result).unwrap(); + // Pointer dance + let tunnel = Arc::new(Mutex::new(Some(tunnel))); + let _tunnel = Arc::downgrade(&tunnel); + let result = Monitor::init(checker).run(_tunnel).map(|_| true); + result_tx.send(result).unwrap(); + }); + + std::thread::sleep(Duration::from_secs(1)); + assert!(result_rx.try_recv().unwrap().unwrap()); + stop_tx.send(()).unwrap(); + std::thread::sleep(Duration::from_secs(1)); + assert!(result_rx.try_recv().unwrap().is_ok()); + } + + #[test] + /// Verify that the connectivity monitor detects the tunnel timing out after no longer than + /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. + fn test_wait_loop_timeout() { + let should_stop = Arc::new(AtomicBool::new(false)); + let should_stop_inner = should_stop.clone(); + + let mut map = StatsMap::new(); + map.insert( + [0u8; 32], + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + let tunnel_stats = std::sync::Mutex::new(map); + + let pinger = MockPinger::default(); + let tunnel = MockTunnel::new(move || { + let mut tunnel_stats = tunnel_stats.lock().unwrap(); + if !should_stop_inner.load(Ordering::SeqCst) { + for traffic in tunnel_stats.values_mut() { + traffic.rx_bytes += 1; + } + } + for traffic in tunnel_stats.values_mut() { + traffic.tx_bytes += 1; + } + Ok(tunnel_stats.clone()) + }) + .boxed(); + + let (result_tx, result_rx) = mpsc::channel(); + + std::thread::spawn(move || { + let (mut checker, _cancellation_token) = { + let now = Instant::now(); + let start = now.checked_sub(Duration::from_secs(1)).unwrap(); + mock_checker(start, Box::new(pinger)).with_cancellation() + }; + let start_result = checker.establish_connectivity(&tunnel); + result_tx.send(start_result).unwrap(); + // Pointer dance + let _tunnel = Arc::new(Mutex::new(Some(tunnel))); + let tunnel = Arc::downgrade(&_tunnel); + let end_result = Monitor::init(checker).run(tunnel).map(|_| true); + result_tx.send(end_result).expect("Failed to send result"); + }); + assert!(result_rx + .recv_timeout(Duration::from_secs(1)) + .unwrap() + .unwrap()); + should_stop.store(true, Ordering::SeqCst); + assert!(result_rx + .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) + .unwrap() + .is_ok()); + } +} diff --git a/talpid-wireguard/src/ping_monitor/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs index 00ad4d8fd3..00ad4d8fd3 100644 --- a/talpid-wireguard/src/ping_monitor/android.rs +++ b/talpid-wireguard/src/connectivity/pinger/android.rs diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs index 0e5d739425..0e5d739425 100644 --- a/talpid-wireguard/src/ping_monitor/icmp.rs +++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs diff --git a/talpid-wireguard/src/ping_monitor/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs index ef2394f1b7..ef2394f1b7 100644 --- a/talpid-wireguard/src/ping_monitor/mod.rs +++ b/talpid-wireguard/src/connectivity/pinger/mod.rs diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index 5440a142f6..a9283fcb2e 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -1,7 +1,10 @@ //! This module takes care of obtaining ephemeral peers, updating the WireGuard configuration and //! restarting obfuscation and WG tunnels when necessary. -use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, Tunnel}; +#[cfg(target_os = "android")] // On Android, the Tunnel trait is not imported by default. +use super::Tunnel; +use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, TunnelType}; + #[cfg(target_os = "android")] use std::sync::Mutex; use std::{ @@ -22,7 +25,7 @@ const PSK_EXCHANGE_TIMEOUT_MULTIPLIER: u32 = 2; #[cfg(windows)] pub async fn config_ephemeral_peers( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, @@ -66,13 +69,13 @@ fn try_set_ipv4_mtu(alias: &str, mtu: u16) { #[cfg(not(windows))] pub async fn config_ephemeral_peers( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, -) -> std::result::Result<(), CloseMsg> { +) -> Result<(), CloseMsg> { config_ephemeral_peers_inner( tunnel, config, @@ -86,13 +89,13 @@ pub async fn config_ephemeral_peers( } async fn config_ephemeral_peers_inner( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, config: &mut Config, retry_attempt: u32, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, -) -> std::result::Result<(), CloseMsg> { +) -> Result<(), CloseMsg> { let ephemeral_private_key = PrivateKey::new_from_random(); let close_obfs_sender = close_obfs_sender.clone(); @@ -111,6 +114,7 @@ async fn config_ephemeral_peers_inner( if config.is_multihop() { // Set up tunnel to lead to entry let mut entry_tun_config = config.clone(); + entry_tun_config.exit_peer = None; entry_tun_config .entry_peer .allowed_ips @@ -126,6 +130,7 @@ async fn config_ephemeral_peers_inner( &tun_provider, ) .await?; + let entry_psk = request_ephemeral_peer( retry_attempt, &entry_config, @@ -173,15 +178,16 @@ async fn config_ephemeral_peers_inner( Ok(()) } +#[cfg(target_os = "android")] /// Reconfigures the tunnel to use the provided config while potentially modifying the config /// and restarting the obfuscation provider. Returns the new config used by the new tunnel. async fn reconfigure_tunnel( - tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, mut config: Config, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, - #[cfg(target_os = "android")] tun_provider: &Arc<Mutex<TunProvider>>, -) -> std::result::Result<Config, CloseMsg> { + tun_provider: &Arc<Mutex<TunProvider>>, +) -> Result<Config, CloseMsg> { let mut obfs_guard = obfuscator.lock().await; if let Some(obfuscator_handle) = obfs_guard.take() { obfuscator_handle.abort(); @@ -194,17 +200,49 @@ async fn reconfigure_tunnel( .await .map_err(CloseMsg::ObfuscatorFailed)?; } + { + let mut shared_tunnel = tunnel.lock().await; + let tunnel = shared_tunnel.take().expect("tunnel was None"); - let mut tunnel = tunnel.lock().await; - - let set_config_future = tunnel - .as_mut() - .map(|tunnel| tunnel.set_config(config.clone())); - - if let Some(f) = set_config_future { - f.await + let updated_tunnel = tunnel + .set_config(&config) .map_err(Error::TunnelError) .map_err(CloseMsg::SetupError)?; + + *shared_tunnel = Some(updated_tunnel); + } + Ok(config) +} + +#[cfg(not(target_os = "android"))] +/// Reconfigures the tunnel to use the provided config while potentially modifying the config +/// and restarting the obfuscation provider. Returns the new config used by the new tunnel. +async fn reconfigure_tunnel( + tunnel: &Arc<AsyncMutex<Option<TunnelType>>>, + mut config: Config, + obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, + close_obfs_sender: sync_mpsc::Sender<CloseMsg>, +) -> Result<Config, CloseMsg> { + let mut obfs_guard = obfuscator.lock().await; + if let Some(obfuscator_handle) = obfs_guard.take() { + obfuscator_handle.abort(); + *obfs_guard = super::obfuscation::apply_obfuscation_config(&mut config, close_obfs_sender) + .await + .map_err(CloseMsg::ObfuscatorFailed)?; + } + + { + let mut tunnel = tunnel.lock().await; + + let set_config_future = tunnel + .as_mut() + .map(|tunnel| tunnel.set_config(config.clone())); + + if let Some(f) = set_config_future { + f.await + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } } Ok(config) diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index d1e09ff570..7c93b39f18 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -35,11 +35,10 @@ use tokio::sync::Mutex as AsyncMutex; /// WireGuard config data-types pub mod config; -mod connectivity_check; +mod connectivity; mod ephemeral; mod logging; mod obfuscation; -mod ping_monitor; mod stats; #[cfg(wireguard_go)] mod wireguard_go; @@ -54,6 +53,12 @@ mod mtu_detection; #[cfg(wireguard_go)] use self::wireguard_go::WgGoTunnel; +// On android we only have Wireguard Go tunnel +#[cfg(not(target_os = "android"))] +type TunnelType = Box<dyn Tunnel>; +#[cfg(target_os = "android")] +type TunnelType = WgGoTunnel; + type Result<T> = std::result::Result<T, Error>; type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; @@ -82,7 +87,7 @@ pub enum Error { /// Failed to set up connectivity monitor #[error("Connectivity monitor failed")] - ConnectivityMonitorError(#[source] connectivity_check::Error), + ConnectivityMonitorError(#[source] connectivity::Error), /// Failed while negotiating ephemeral peer #[error("Failed while negotiating ephemeral peer")] @@ -134,7 +139,7 @@ impl Error { pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation - tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, + tunnel: Arc<AsyncMutex<Option<TunnelType>>>, /// Callback to signal tunnel events event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, @@ -210,8 +215,17 @@ impl WireguardMonitor { let obfuscator = Arc::new(AsyncMutex::new(obfuscator)); + let gateway = config.ipv4_gateway; + let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new( + gateway, + #[cfg(any(target_os = "macos", target_os = "linux"))] + iface_name.clone(), + args.retry_attempt, + ) + .map_err(Error::ConnectivityMonitorError)? + .with_cancellation(); + let event_callback = Box::new(on_event.clone()); - let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), @@ -221,16 +235,6 @@ impl WireguardMonitor { obfuscator, }; - let gateway = config.ipv4_gateway; - let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( - gateway, - #[cfg(any(target_os = "macos", target_os = "linux"))] - iface_name.clone(), - Arc::downgrade(&monitor.tunnel), - pinger_rx, - ) - .map_err(Error::ConnectivityMonitorError)?; - let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); @@ -315,8 +319,12 @@ impl WireguardMonitor { }); } - let mut connectivity_monitor = tokio::task::spawn_blocking(move || { - match connectivity_monitor.establish_connectivity(args.retry_attempt) { + let cloned_tunnel = Arc::clone(&tunnel); + + let connectivity_check = tokio::task::spawn_blocking(move || { + let lock = cloned_tunnel.blocking_lock(); + let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly"); + match connectivity_monitor.establish_connectivity(tunnel) { Ok(true) => Ok(connectivity_monitor), Ok(false) => { log::warn!("Timeout while checking tunnel connection"); @@ -344,8 +352,11 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); (on_event)(TunnelEvent::Up(metadata)).await; + let monitored_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { - if let Err(error) = connectivity_monitor.run() { + if let Err(error) = + connectivity::Monitor::init(connectivity_check).run(monitored_tunnel) + { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -396,8 +407,8 @@ impl WireguardMonitor { args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { let desired_mtu = get_desired_mtu(params); - let mut config = crate::config::Config::from_parameters(params, desired_mtu) - .map_err(Error::WireguardConfigError)?; + let mut config = + Config::from_parameters(params, desired_mtu).map_err(Error::WireguardConfigError)?; let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel(); // Start obfuscation server and patch the WireGuard config to point the endpoint to it. @@ -417,8 +428,13 @@ impl WireguardMonitor { } let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; - let tunnel = Self::open_tunnel( - args.runtime.clone(), + + let (connectivity_check, pinger_tx) = + connectivity::Check::new(config.ipv4_gateway, args.retry_attempt) + .map_err(Error::ConnectivityMonitorError)? + .with_cancellation(); + + let tunnel = Self::open_wireguard_go_tunnel( &config, log_path, args.resource_dir, @@ -427,77 +443,34 @@ impl WireguardMonitor { // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. should_negotiate_ephemeral_peer, + connectivity_check, )?; let iface_name = tunnel.get_interface_name(); - - let (pinger_tx, pinger_rx) = sync_mpsc::channel(); + let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); let monitor = WireguardMonitor { runtime: args.runtime.clone(), - tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), + tunnel: Arc::clone(&tunnel), event_callback: Box::new(args.on_event.clone()), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator: Arc::new(AsyncMutex::new(obfuscator)), }; - let gateway = config.ipv4_gateway; - let connectivity_monitor = connectivity_check::ConnectivityMonitor::new( - gateway, - Arc::downgrade(&monitor.tunnel), - pinger_rx, - ) - .map_err(Error::ConnectivityMonitorError)?; - - let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); let tunnel_fut = async move { - let tunnel = moved_tunnel; let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender; let obfuscator = moved_obfuscator; - let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor)); let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - (args.on_event.clone())(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; - let handle_ping = |ping_result: std::result::Result< - bool, - connectivity_check::Error, - >| match ping_result { - Ok(true) => Ok(()), - Ok(false) => { - log::warn!("Timeout while checking tunnel connection"); - Err(CloseMsg::PingErr) - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check tunnel connection") - ); - Err(CloseMsg::PingErr) - } - }; - - // Prepare a closure which pings inside the tunnel when executed. - let ping = || { - let connectivity_monitor_arc = connectivity_monitor.clone(); - let retry_attempt = args.retry_attempt; - move || { - let ping_result = connectivity_monitor_arc - .lock() - .unwrap() - .establish_connectivity(retry_attempt); - handle_ping(ping_result) - } - }; - if should_negotiate_ephemeral_peer { - // Ping before negotiating the ephemeral peer to make sure that the tunnel works. - tokio::task::spawn_blocking(ping()).await.unwrap()?; let ephemeral_obfs_sender = close_obfs_sender.clone(); + ephemeral::config_ephemeral_peers( &tunnel, &mut config, @@ -509,21 +482,31 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event.clone())(TunnelEvent::InterfaceUp( + args.on_event.clone()(TunnelEvent::InterfaceUp( metadata, Self::allowed_traffic_after_tunnel_config(), )) .await; } - // Make sure the tunnel works (after potentially having negotiated an ephemeral peer). - tokio::task::spawn_blocking(ping()).await.unwrap()?; - let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event.clone())(TunnelEvent::Up(metadata)).await; + args.on_event.clone()(TunnelEvent::Up(metadata)).await; + + // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it + let connectivity_check = { + let mut tunnel_lock = tunnel.lock().await; + let Some(tunnel) = tunnel_lock.as_mut() else { + log::debug!("Tunnel is no longer running"); + return Err::<Infallible, CloseMsg>(CloseMsg::PingErr); + }; + tunnel + .take_checker() + .expect("connectivity checker unexpectedly dropped") + }; tokio::task::spawn_blocking(move || { - if let Err(error) = connectivity_monitor.lock().unwrap().run() { + let tunnel = Arc::downgrade(&tunnel); + if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -585,6 +568,7 @@ impl WireguardMonitor { /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. /// Used to block traffic to other destinations while connecting on Android. + /// #[cfg(target_os = "android")] fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> { if gateway_only { @@ -654,16 +638,16 @@ impl WireguardMonitor { } #[allow(unused_variables)] + #[cfg(not(target_os = "android"))] fn open_tunnel( runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, - #[cfg(target_os = "android")] gateway_only: bool, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, - ) -> Result<Box<dyn Tunnel>> { + ) -> Result<TunnelType> { log::debug!("Tunnel MTU: {}", config.mtu); #[cfg(target_os = "linux")] @@ -743,12 +727,15 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(target_os = "android")] gateway_only: bool, + #[cfg(target_os = "android")] connectivity_check: connectivity::Check< + connectivity::Cancellable, + >, ) -> Result<WgGoTunnel> { - let routes = Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes); - - #[cfg(target_os = "android")] - let config = Self::patch_allowed_ips(config, gateway_only); + let routes = config + .get_tunnel_destinations() + .flat_map(Self::replace_default_prefixes); + #[cfg(not(target_os = "android"))] let tunnel = WgGoTunnel::start_tunnel( #[allow(clippy::needless_borrow)] &config, @@ -760,6 +747,44 @@ impl WireguardMonitor { ) .map_err(Error::TunnelError)?; + // Android uses multihop implemented in Mullvad's wireguard-go fork. When negotiating + // with an ephemeral peer, this multihop strategy require us to restart the tunnel + // every time we want to reconfigure it. As such, we will actually start a multihop + // tunnel at a later stage, after we have negotiated with the first ephemeral peer. + // At this point, when the tunnel *is first started*, we establish a regular, singlehop + // tunnel to where the ephemeral peer resides. + // + // Refer to `docs/architecture.md` for details on how to use multihop + PQ. + #[cfg(target_os = "android")] + let config = Self::patch_allowed_ips(config, gateway_only); + + #[cfg(target_os = "android")] + let tunnel = if let Some(exit_peer) = &config.exit_peer { + WgGoTunnel::start_multihop_tunnel( + &config, + exit_peer, + log_path, + tun_provider, + routes, + #[cfg(daita)] + resource_dir, + connectivity_check, + ) + .map_err(Error::TunnelError)? + } else { + WgGoTunnel::start_tunnel( + #[allow(clippy::needless_borrow)] + &config, + log_path, + tun_provider, + routes, + #[cfg(daita)] + resource_dir, + connectivity_check, + ) + .map_err(Error::TunnelError)? + }; + Ok(tunnel) } @@ -865,7 +890,8 @@ impl WireguardMonitor { gateway_routes.map(|route| Self::apply_route_mtu_for_multihop(route, config)); let routes = gateway_routes.chain( - Self::get_tunnel_destinations(config) + config + .get_tunnel_destinations() .filter(|allowed_ip| allowed_ip.prefix() != 0) .map(move |allowed_ip| { if allowed_ip.is_ipv4() { @@ -886,7 +912,8 @@ impl WireguardMonitor { config: &'a Config, ) -> impl Iterator<Item = RequiredRoute> + 'a { let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); - let iter = Self::get_tunnel_destinations(config) + let iter = config + .get_tunnel_destinations() .filter(|allowed_ip| allowed_ip.prefix() == 0) .flat_map(Self::replace_default_prefixes) .map(move |allowed_ip| { @@ -928,14 +955,6 @@ impl WireguardMonitor { } } - /// Return routes for all allowed IPs. - fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { - config - .peers() - .flat_map(|peer| peer.allowed_ips.iter()) - .cloned() - } - /// Replace default (0-prefix) routes with more specific routes. fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> { #[cfg(windows)] @@ -973,6 +992,7 @@ enum CloseMsg { ObfuscatorFailed(Error), } +#[allow(unused)] pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; @@ -1067,6 +1087,15 @@ pub enum TunnelError { #[cfg(daita)] #[error("Failed to start DAITA - tunnel implemenation does not support DAITA")] DaitaNotSupported, + + /// [connectivity] error. + #[error(transparent)] + Connectivity(#[from] Box<connectivity::Error>), + + /// Tunnel seemingly does not serve any traffic + #[cfg(target_os = "android")] + #[error("Tunnel seemingly does not serve any traffic")] + TunnelUp, } #[cfg(target_os = "linux")] diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 25ebb45a38..d283758f3b 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,3 +1,14 @@ +#[cfg(target_os = "android")] +use super::config; +use super::{ + stats::{Stats, StatsMap}, + Config, Tunnel, TunnelError, +}; +#[cfg(target_os = "linux")] +use crate::config::MULLVAD_INTERFACE_NAME; +#[cfg(target_os = "android")] +use crate::connectivity; +use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; #[cfg(daita)] use once_cell::sync::OnceCell; @@ -13,16 +24,10 @@ use std::{ #[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; use talpid_tunnel::tun_provider::{Tun, TunProvider}; +#[cfg(target_os = "android")] +use talpid_types::net::wireguard::PeerConfig; use talpid_types::BoxedError; -use super::{ - stats::{Stats, StatsMap}, - Config, Tunnel, TunnelError, -}; -#[cfg(target_os = "linux")] -use crate::config::MULLVAD_INTERFACE_NAME; -use crate::logging::{clean_up_logging, initialize_logging}; - const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; /// Maximum number of events that can be stored in the underlying buffer @@ -35,21 +40,129 @@ const DAITA_ACTIONS_CAPACITY: u32 = 1000; type Result<T> = std::result::Result<T, TunnelError>; -struct LoggingContext(u64); +struct LoggingContext { + ordinal: u64, + #[allow(dead_code)] + path: Option<PathBuf>, +} + +impl LoggingContext { + fn new(ordinal: u64, path: Option<PathBuf>) -> Self { + LoggingContext { ordinal, path } + } +} impl Drop for LoggingContext { fn drop(&mut self) { - clean_up_logging(self.0); + clean_up_logging(self.ordinal); + } +} + +#[cfg(not(target_os = "android"))] +pub struct WgGoTunnel(WgGoTunnelState); + +#[cfg(target_os = "android")] +pub enum WgGoTunnel { + Multihop(WgGoTunnelState), + Singlehop(WgGoTunnelState), +} + +#[cfg(not(target_os = "android"))] +impl WgGoTunnel { + fn into_state(self) -> WgGoTunnelState { + self.0 + } + + fn as_state(&self) -> &WgGoTunnelState { + &self.0 + } + + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { + &mut self.0 + } +} + +#[cfg(target_os = "android")] +impl WgGoTunnel { + fn into_state(self) -> WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + fn as_state(&self) -> &WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { + match self { + WgGoTunnel::Multihop(state) => state, + WgGoTunnel::Singlehop(state) => state, + } + } + + pub fn set_config(mut self, config: &Config) -> Result<Self> { + let connectivity_checker = self + .take_checker() + .expect("connectivity checker unexpectedly dropped"); + let state = self.as_state(); + let log_path = state._logging_context.path.clone(); + let tun_provider = Arc::clone(&state.tun_provider); + let routes = config.get_tunnel_destinations(); + #[cfg(daita)] + let resource_dir = state.resource_dir.clone(); + + match self { + WgGoTunnel::Multihop(state) if !config.is_multihop() => { + state.stop()?; + Self::start_tunnel( + config, + log_path.as_deref(), + tun_provider, + routes, + &resource_dir, + connectivity_checker, + ) + } + WgGoTunnel::Singlehop(state) if config.is_multihop() => { + state.stop()?; + Self::start_multihop_tunnel( + config, + &config.exit_peer.clone().unwrap().clone(), + log_path.as_deref(), + tun_provider, + routes, + &resource_dir, + connectivity_checker, + ) + } + WgGoTunnel::Singlehop(mut state) => { + state.set_config(config.clone())?; + Ok(WgGoTunnel::Singlehop(state)) + } + WgGoTunnel::Multihop(mut state) => { + state.set_config(config.clone())?; + Ok(WgGoTunnel::Multihop(state)) + } + } + } + + pub fn stop(self) -> Result<()> { + self.into_state().stop() } } -pub struct WgGoTunnel { +pub(crate) struct WgGoTunnelState { interface_name: String, tunnel_handle: wireguard_go_rs::Tunnel, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped _tunnel_device: Tun, - // context that maps to fs::File instance, used with logging callback + // context that maps to fs::File instance and stores the file path, used with logging callback _logging_context: LoggingContext, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, @@ -57,9 +170,53 @@ pub struct WgGoTunnel { resource_dir: PathBuf, #[cfg(daita)] config: Config, + // HACK: Check is not Clone, so we have to pass this around .. + // This is conceptually the connection between this Tunnel and the currently running + // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of + // a new Tunnel during the "ensure_connectivity" phase. This field should be removed + // as soon as we implement a better way to cancel Check asynchronously. + #[cfg(target_os = "android")] + connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>, +} + +impl WgGoTunnelState { + fn stop(self) -> Result<()> { + self.tunnel_handle + .turn_off() + .map_err(|e| TunnelError::StopWireguardError(Box::new(e))) + } + + fn set_config(&mut self, config: Config) -> Result<()> { + let wg_config_str = config.to_userspace_format(); + + self.tunnel_handle + .set_config(&wg_config_str) + .map_err(|_| TunnelError::SetConfigError)?; + + #[cfg(target_os = "android")] + let tun_provider = self.tun_provider.clone(); + + // When reapplying the config, the endpoint socket may be discarded + // and needs to be excluded again + #[cfg(target_os = "android")] + { + let socket_v4 = self.tunnel_handle.get_socket_v4(); + let socket_v6 = self.tunnel_handle.get_socket_v6(); + let mut provider = tun_provider.lock().unwrap(); + provider + .bypass(socket_v4) + .map_err(super::TunnelError::BypassError)?; + provider + .bypass(socket_v6) + .map_err(super::TunnelError::BypassError)?; + } + + Ok(()) + } } impl WgGoTunnel { + #[cfg(not(target_os = "android"))] pub fn start_tunnel( config: &Config, log_path: Option<&Path>, @@ -67,60 +224,35 @@ impl WgGoTunnel { routes: impl Iterator<Item = IpNetwork>, #[cfg(daita)] resource_dir: &Path, ) -> Result<Self> { - #[cfg(target_os = "android")] - let tun_provider_clone = tun_provider.clone(); - - #[cfg_attr(not(target_os = "android"), allow(unused_mut))] - let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?; + let (tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?; let interface_name: String = tunnel_device.interface_name().to_string(); let wg_config_str = config.to_userspace_format(); let logging_context = initialize_logging(log_path) - .map(LoggingContext) + .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned))) .map_err(TunnelError::LoggingError)?; - #[cfg(not(target_os = "android"))] let mtu = config.mtu as isize; + let handle = wireguard_go_rs::Tunnel::turn_on( - #[cfg(not(target_os = "android"))] mtu, &wg_config_str, tunnel_fd, Some(logging::wg_go_logging_callback), - logging_context.0, + logging_context.ordinal, ) .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; - #[cfg(target_os = "android")] - Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) - .map_err(TunnelError::BypassError)?; - - Ok(WgGoTunnel { + Ok(WgGoTunnel(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, _logging_context: logging_context, - #[cfg(target_os = "android")] - tun_provider: tun_provider_clone, #[cfg(daita)] resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), - }) - } - - #[cfg(target_os = "android")] - fn bypass_tunnel_sockets( - handle: &wireguard_go_rs::Tunnel, - tunnel_device: &mut Tun, - ) -> std::result::Result<(), TunProviderError> { - let socket_v4 = handle.get_socket_v4(); - let socket_v6 = handle.get_socket_v6(); - - tunnel_device.bypass(socket_v4)?; - tunnel_device.bypass(socket_v6)?; - - Ok(()) + })) } fn get_tunnel( @@ -162,13 +294,171 @@ impl WgGoTunnel { } } +#[cfg(target_os = "android")] +impl WgGoTunnel { + pub fn start_tunnel( + config: &Config, + log_path: Option<&Path>, + tun_provider: Arc<Mutex<TunProvider>>, + routes: impl Iterator<Item = IpNetwork>, + #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check<connectivity::Cancellable>, + ) -> Result<Self> { + let (mut tunnel_device, tunnel_fd) = + Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + + let interface_name: String = tunnel_device.interface_name().to_string(); + let logging_context = initialize_logging(log_path) + .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned))) + .map_err(TunnelError::LoggingError)?; + + let wg_config_str = config.to_userspace_format(); + + let handle = wireguard_go_rs::Tunnel::turn_on( + &wg_config_str, + tunnel_fd, + Some(logging::wg_go_logging_callback), + logging_context.ordinal, + ) + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) + .map_err(TunnelError::BypassError)?; + + let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { + interface_name, + tunnel_handle: handle, + _tunnel_device: tunnel_device, + _logging_context: logging_context, + tun_provider, + #[cfg(daita)] + resource_dir: resource_dir.to_owned(), + #[cfg(daita)] + config: config.clone(), + connectivity_checker: None, + }); + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. + tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + + Ok(tunnel) + } + + pub fn start_multihop_tunnel( + config: &Config, + exit_peer: &PeerConfig, + log_path: Option<&Path>, + tun_provider: Arc<Mutex<TunProvider>>, + routes: impl Iterator<Item = IpNetwork>, + #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check<connectivity::Cancellable>, + ) -> Result<Self> { + let (mut tunnel_device, tunnel_fd) = + Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + + let interface_name: String = tunnel_device.interface_name().to_string(); + let logging_context = initialize_logging(log_path) + .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned))) + .map_err(TunnelError::LoggingError)?; + + let entry_config_str = config::userspace_format( + &config.tunnel.private_key, + std::iter::once(&config.entry_peer), + ); + + let exit_config_str = + config::userspace_format(&config.tunnel.private_key, std::iter::once(exit_peer)); + + let private_ip = config + .tunnel + .addresses + .iter() + .find(|addr| addr.is_ipv4()) + .map(|addr| CString::new(addr.to_string()).unwrap()) + .ok_or(TunnelError::SetConfigError)?; + + let handle = wireguard_go_rs::Tunnel::turn_on_multihop( + &exit_config_str, + &entry_config_str, + &private_ip, + tunnel_fd, + Some(logging::wg_go_logging_callback), + logging_context.ordinal, + ) + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; + + Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) + .map_err(TunnelError::BypassError)?; + + let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState { + interface_name, + tunnel_handle: handle, + _tunnel_device: tunnel_device, + _logging_context: logging_context, + tun_provider, + #[cfg(daita)] + resource_dir: resource_dir.to_owned(), + #[cfg(daita)] + config: config.clone(), + connectivity_checker: None, + }); + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. + tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + + Ok(tunnel) + } + + fn bypass_tunnel_sockets( + handle: &wireguard_go_rs::Tunnel, + tunnel_device: &mut Tun, + ) -> std::result::Result<(), TunProviderError> { + let socket_v4 = handle.get_socket_v4(); + let socket_v6 = handle.get_socket_v6(); + + tunnel_device.bypass(socket_v4)?; + tunnel_device.bypass(socket_v6)?; + + Ok(()) + } + + pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> { + self.as_state_mut().connectivity_checker.take() + } + + /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve + /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. + fn ensure_tunnel_is_running( + &self, + checker: &mut connectivity::Check<connectivity::Cancellable>, + ) -> Result<()> { + let connectivity_err = |e| TunnelError::Connectivity(Box::new(e)); + let connection_established = checker + .establish_connectivity(self) + .map_err(connectivity_err)?; + + // Timed out + if !connection_established { + return Err(TunnelError::TunnelUp); + } + Ok(()) + } +} + impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> String { - self.interface_name.clone() + self.as_state().interface_name.clone() + } + + fn stop(self: Box<Self>) -> Result<()> { + self.into_state().stop() } fn get_tunnel_stats(&self) -> Result<StatsMap> { - self.tunnel_handle + self.as_state() + .tunnel_handle .get_config(|cstr| { Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) }) @@ -176,54 +466,25 @@ impl Tunnel for WgGoTunnel { .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) } - fn stop(self: Box<Self>) -> Result<()> { - self.tunnel_handle - .turn_off() - .map_err(|e| TunnelError::StopWireguardError(Box::new(e))) - } - fn set_config( &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> { - Box::pin(async move { - let wg_config_str = config.to_userspace_format(); - - self.tunnel_handle - .set_config(&wg_config_str) - .map_err(|_| TunnelError::SetConfigError)?; - - #[cfg(target_os = "android")] - let tun_provider = self.tun_provider.clone(); - - // When reapplying the config, the endpoint socket may be discarded - // and needs to be excluded again - #[cfg(target_os = "android")] - { - let socket_v4 = self.tunnel_handle.get_socket_v4(); - let socket_v6 = self.tunnel_handle.get_socket_v6(); - let mut provider = tun_provider.lock().unwrap(); - provider - .bypass(socket_v4) - .map_err(super::TunnelError::BypassError)?; - provider - .bypass(socket_v6) - .map_err(super::TunnelError::BypassError)?; - } - - Ok(()) - }) + Box::pin(async move { self.as_state_mut().set_config(config) }) } #[cfg(daita)] fn start_daita(&mut self) -> Result<()> { static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new(); - let machines = - MAYBENOT_MACHINES.get_or_try_init(|| load_maybenot_machines(&self.resource_dir))?; + let machines = MAYBENOT_MACHINES + .get_or_try_init(|| load_maybenot_machines(&self.as_state().resource_dir))?; log::info!("Initializing DAITA for wireguard device"); - let peer_public_key = &self.config.entry_peer.public_key; - self.tunnel_handle + let config = &self.as_state().config; + let peer_public_key = &config.entry_peer.public_key; + + self.as_state() + .tunnel_handle .activate_daita( peer_public_key.as_bytes(), machines, diff --git a/wireguard-go-rs/libwg/go.mod b/wireguard-go-rs/libwg/go.mod index 76627dcb7f..8f463a64ad 100644 --- a/wireguard-go-rs/libwg/go.mod +++ b/wireguard-go-rs/libwg/go.mod @@ -4,13 +4,16 @@ go 1.21 require ( golang.org/x/sys v0.19.0 - golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 + golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a67 ) require ( + github.com/google/btree v1.0.1 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/net v0.24.0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) replace golang.zx2c4.com/wireguard => ./wireguard-go diff --git a/wireguard-go-rs/libwg/go.sum b/wireguard-go-rs/libwg/go.sum index b41c5842d1..d04296cf67 100644 --- a/wireguard-go-rs/libwg/go.sum +++ b/wireguard-go-rs/libwg/go.sum @@ -1,8 +1,14 @@ +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/wireguard-go-rs/libwg/libwg.go b/wireguard-go-rs/libwg/libwg.go index 6cfbd0ba55..5dcc9141b2 100644 --- a/wireguard-go-rs/libwg/libwg.go +++ b/wireguard-go-rs/libwg/libwg.go @@ -65,6 +65,9 @@ func wgTurnOff(tunnelHandle int32) { return } tunnel.Device.Close() + if tunnel.EntryDevice != nil { + tunnel.EntryDevice.Close() + } } // Calling twice convinces the GC to release NOW. runtime.GC() diff --git a/wireguard-go-rs/libwg/libwg_android.go b/wireguard-go-rs/libwg/libwg_android.go index d623b7711d..caca9b04d0 100644 --- a/wireguard-go-rs/libwg/libwg_android.go +++ b/wireguard-go-rs/libwg/libwg_android.go @@ -11,6 +11,8 @@ import "C" import ( "bufio" + "errors" + "net/netip" "strings" "unsafe" @@ -19,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/multihoptun" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/logging" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer" @@ -29,6 +32,12 @@ import ( type LogSink = unsafe.Pointer type LogContext = C.uint64_t +type tunnelHandle struct { + exit *device.Device + entry *device.Device + logger *device.Logger +} + //export wgTurnOn func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t { logger := logging.NewLogger(logSink, logging.LogContext(logContext)) @@ -77,13 +86,166 @@ func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) return C.int32_t(handle) } +//export wgTurnOnMultihop +func wgTurnOnMultihop(cExitSettings *C.char, cEntrySettings *C.char, privateIp *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t { + logger := logging.NewLogger(logSink, logging.LogContext(logContext)) + if cExitSettings == nil { + logger.Errorf("cExitSettings is null\n") + return ERROR_INVALID_ARGUMENT + } + exitSettings := goStringFixed(cExitSettings) + + if cEntrySettings == nil { + logger.Errorf("cEntrySettings is null\n") + return ERROR_INVALID_ARGUMENT + } + entrySettings := goStringFixed(cEntrySettings) + + exitEndpoint := parseEndpointFromConfig(exitSettings) + + if exitEndpoint == nil { + logger.Errorf("exitEndpoint is null\n") + return ERROR_INVALID_ARGUMENT + } + + // Set up a two tunnel devices: One 'fake' device for the exit relay and one 'real' device for the entry relay + + tunDevice, _, err := tun.CreateUnmonitoredTUNFromFD(fd) + if err != nil { + logger.Errorf("%s\n", err) + unix.Close(fd) + if err.Error() == "bad file descriptor" { + return ERROR_INTERMITTENT_FAILURE + } + return ERROR_GENERAL_FAILURE + } + + ip, err := netip.ParseAddr(goStringFixed(privateIp)) + if err != nil { + logger.Errorf("%s\n", err) + tunDevice.Close() + return ERROR_INVALID_ARGUMENT + } + + mtu, err := tunDevice.MTU() + if err != nil { + logger.Errorf("%s\n", err) + tunDevice.Close() + return ERROR_GENERAL_FAILURE + } + + singleTunMtu := mtu - 80 //Internet mtu - Wireguard header size - ipv4 UDP header + singletun := multihoptun.NewMultihopTun(ip, exitEndpoint.Addr(), exitEndpoint.Port(), singleTunMtu) + + entryDevice := device.NewDevice(&singletun, conn.NewStdNetBind(), logger) + exitDevice := device.NewDevice(tunDevice, singletun.Binder(), logger) + + setErr := entryDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(entrySettings))) + if setErr != nil { + logger.Errorf("%s\n", setErr) + exitDevice.Close() + entryDevice.Close() + return ERROR_INTERMITTENT_FAILURE + } + + entryDevice.DisableSomeRoamingForBrokenMobileSemantics() + + setErr = exitDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(exitSettings))) + if setErr != nil { + logger.Errorf("%s\n", setErr) + exitDevice.Close() + entryDevice.Close() + return ERROR_INTERMITTENT_FAILURE + } + + exitDevice.DisableSomeRoamingForBrokenMobileSemantics() + + exitDevice.Up() + entryDevice.Up() + + // Create the stuff that needs + + context := tunnelcontainer.Context{ + Device: exitDevice, + EntryDevice: entryDevice, + Logger: logger, + } + + handle, err := tunnels.Insert(context) + if err != nil { + logger.Errorf("%s\n", err) + entryDevice.Close() + exitDevice.Close() + return ERROR_GENERAL_FAILURE + } + + return C.int32_t(handle) + +} + +func addTunnelFromDevice(exitDev *device.Device, entryDev *device.Device, exitSettings string, entrySettings string, logger *device.Logger) (*tunnelHandle, error) { + err := bringUpDevice(exitDev, exitSettings, logger) + if err != nil { + return nil, errors.New("Could not bring up exit device") // errBadWgConfig + } + + if entryDev != nil { + err = bringUpDevice(entryDev, entrySettings, logger) + if err != nil { + exitDev.Close() + return nil, errors.New("Could not bring up entry device") + } + } + + return &tunnelHandle{exitDev, entryDev, logger}, nil +} + +func bringUpDevice(dev *device.Device, settings string, logger *device.Logger) error { + err := dev.IpcSet(settings) + if err != nil { + logger.Errorf("Unable to set IPC settings: %v", err) + dev.Close() + return err + } + + dev.Up() + logger.Verbosef("Device started") + return nil +} + +// Parse a wireguard config and return the first endpoint address it finds and +// parses successfully.gi b +func parseEndpointFromConfig(config string) *netip.AddrPort { + scanner := bufio.NewScanner(strings.NewReader(config)) + for scanner.Scan() { + line := scanner.Text() + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + + if key == "endpoint" { + endpoint, err := netip.ParseAddrPort(value) + if err == nil { + return &endpoint + } + } + + } + return nil +} + //export wgGetSocketV4 func wgGetSocketV4(tunnelHandle int32) C.int32_t { tunnel, err := tunnels.Get(tunnelHandle) if err != nil { return ERROR_UNKNOWN_TUNNEL } - peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd) + device := tunnel.EntryDevice + if device == nil { + device = tunnel.Device + } + peek := device.Bind().(conn.PeekLookAtSocketFd) fd, err := peek.PeekLookAtSocketFd4() if err != nil { return ERROR_GENERAL_FAILURE @@ -97,7 +259,11 @@ func wgGetSocketV6(tunnelHandle int32) C.int32_t { if err != nil { return ERROR_UNKNOWN_TUNNEL } - peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd) + device := tunnel.EntryDevice + if device == nil { + device = tunnel.Device + } + peek := device.Bind().(conn.PeekLookAtSocketFd) fd, err := peek.PeekLookAtSocketFd6() if err != nil { return ERROR_GENERAL_FAILURE diff --git a/wireguard-go-rs/libwg/libwg_daita.go b/wireguard-go-rs/libwg/libwg_daita.go index 3b1fedda4c..fbfceec8f0 100644 --- a/wireguard-go-rs/libwg/libwg_daita.go +++ b/wireguard-go-rs/libwg/libwg_daita.go @@ -32,7 +32,15 @@ func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C. var publicKey device.NoisePublicKey copy(publicKey[:], C.GoBytes(unsafe.Pointer(peerPubkey), device.NoisePublicKeySize)) - peer := tunnel.Device.LookupPeer(publicKey) + + var peer *device.Peer + if tunnel.EntryDevice != nil { + // TODO: Document me + peer = tunnel.EntryDevice.LookupPeer(publicKey) + } else { + // TODO: Document me + peer = tunnel.Device.LookupPeer(publicKey) + } if peer == nil { return ERROR_UNKNOWN_PEER diff --git a/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go index 91291dcf4b..79eacc2a17 100644 --- a/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go +++ b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go @@ -16,6 +16,7 @@ import ( type Context struct { Device *device.Device + EntryDevice *device.Device Uapi net.Listener Logger *device.Logger } diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs index a77b48c0bd..851fd47b9f 100644 --- a/wireguard-go-rs/src/lib.rs +++ b/wireguard-go-rs/src/lib.rs @@ -8,10 +8,10 @@ #![cfg(unix)] -use core::slice; -use std::{ +use core::{ ffi::{c_char, CStr}, mem::{ManuallyDrop, MaybeUninit}, + slice, }; use util::OnDrop; use zeroize::Zeroize; @@ -105,6 +105,37 @@ impl Tunnel { result_from_code(code) } + /// Special function for android multihop since that behavior is different from desktop + /// and android non-multihop. + /// + /// The `logging_callback` let's you provide a Rust function that receives any logging output + /// from wireguard-go. `logging_context` is a value that will be passed to each invocation of + /// `logging_callback`. + #[cfg(target_os = "android")] + pub fn turn_on_multihop( + exit_settings: &CStr, + entry_settings: &CStr, + private_ip: &CStr, + device: Fd, + logging_callback: Option<LoggingCallback>, + logging_context: LoggingContext, + ) -> Result<Self, Error> { + // SAFETY: pointer is valid for the the lifetime of this function + let code = unsafe { + ffi::wgTurnOnMultihop( + exit_settings.as_ptr(), + entry_settings.as_ptr(), + private_ip.as_ptr(), + device, + logging_callback, + logging_context, + ) + }; + + result_from_code(code)?; + Ok(Tunnel { handle: code }) + } + /// Get the config of the WireGuard interface and make it available in the provided function. /// /// This takes a function to make sure the cstr get's zeroed and freed afterwards. @@ -180,12 +211,14 @@ impl Tunnel { /// Get the file descriptor of the tunnel IPv4 socket. #[cfg(target_os = "android")] pub fn get_socket_v4(&self) -> Fd { + // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. unsafe { ffi::wgGetSocketV4(self.handle) } } /// Get the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] pub fn get_socket_v6(&self) -> Fd { + // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. unsafe { ffi::wgGetSocketV6(self.handle) } } } @@ -257,6 +290,21 @@ mod ffi { logging_context: LoggingContext, ) -> i32; + /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors + /// for the tunnel device and logging. + /// + /// Positive return values are tunnel handles for this specific wireguard tunnel instance. + /// Negative return values signify errors. + #[cfg(target_os = "android")] + pub fn wgTurnOnMultihop( + exit_settings: *const c_char, + entry_settings: *const c_char, + private_ip: *const c_char, + fd: Fd, + logging_callback: Option<LoggingCallback>, + logging_context: LoggingContext, + ) -> i32; + /// Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. /// /// Negative return values signify errors. |
