diff options
| author | Kalle Lindström <karl.lindstrom@mullvad.net> | 2024-11-15 13:37:00 +0100 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-11-22 17:42:38 +0100 |
| commit | 84e023366e3a6b2ff5e44a2b376e4bb1f5574f22 (patch) | |
| tree | 28d918c03fe5b3aed254310485f00d8970759ebf | |
| parent | 29a6778daeb1699e725094ae061d9ee81ad2450b (diff) | |
| download | mullvadvpn-84e023366e3a6b2ff5e44a2b376e4bb1f5574f22.tar.xz mullvadvpn-84e023366e3a6b2ff5e44a2b376e4bb1f5574f22.zip | |
Check that the tunnel can serve traffic after starting a new tunnel
- Split up "ConnectivityCheck" into more descriptive types and collect
them in a new `connectivity` module.
- Fix: allow Wireguard-Go tunnel setup to be cancelled
- Use retry param in connectivity check
| -rw-r--r-- | talpid-wireguard/src/connectivity/check.rs (renamed from talpid-wireguard/src/connectivity_check.rs) | 597 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/constants.rs | 22 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/error.rs | 14 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/mock.rs | 133 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/mod.rs | 13 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/monitor.rs | 174 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/pinger/android.rs (renamed from talpid-wireguard/src/ping_monitor/android.rs) | 0 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/pinger/icmp.rs (renamed from talpid-wireguard/src/ping_monitor/icmp.rs) | 0 | ||||
| -rw-r--r-- | talpid-wireguard/src/connectivity/pinger/mod.rs (renamed from talpid-wireguard/src/ping_monitor/mod.rs) | 0 | ||||
| -rw-r--r-- | talpid-wireguard/src/ephemeral.rs | 19 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 127 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 68 |
12 files changed, 678 insertions, 489 deletions
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity/check.rs index 4b6e0f9810..527931563b 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity/check.rs @@ -1,55 +1,17 @@ -use crate::{ - ping_monitor::{new_pinger, Pinger}, - stats::StatsMap, - TunnelType, -}; -use std::{ - cmp, - net::Ipv4Addr, - sync::{mpsc, Weak}, - time::{Duration, Instant}, -}; -use tokio::sync::Mutex; +use std::cmp; +use std::net::Ipv4Addr; +use std::sync::mpsc; +use std::time::{Duration, Instant}; -#[cfg(target_os = "android")] -use super::Tunnel; -use super::TunnelError; - -/// Sleep time used when initially establishing connectivity -const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); -/// Sleep time used when checking if an established connection is still working. -const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); - -/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is -/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic -/// is received. -const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5); -/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will -/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received. -const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120); -/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this -/// timeout is reached, it is assumed that the connection is lost. -const PING_TIMEOUT: Duration = Duration::from_secs(15); -/// Timeout for receiving traffic when establishing a connection. -const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4); -/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt. -const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; -/// Maximum timeout for establishing a connection. 
-const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT; -/// Number of seconds to wait between sending ICMP packets -const SECONDS_PER_PING: Duration = Duration::from_secs(3); - -/// Connectivity monitor errors -#[derive(thiserror::Error, Debug)] -pub enum Error { - /// Failed to read tunnel's configuration - #[error("Failed to read tunnel's configuration")] - ConfigReadError(TunnelError), +use super::constants::*; +use super::error::Error; +use super::pinger; - /// Failed to send ping - #[error("Ping monitor failed")] - PingError(#[from] crate::ping_monitor::Error), -} +use crate::stats::StatsMap; +#[cfg(target_os = "android")] +use crate::Tunnel; +use crate::{TunnelError, TunnelType}; +use pinger::Pinger; /// Verifies if a connection to a tunnel is working. /// The connectivity monitor is biased to receiving traffic - it is expected that all outgoing @@ -73,61 +35,126 @@ pub enum Error { /// /// Once a connection established, a connection is only considered broken once the connectivity /// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. -pub struct ConnectivityMonitor { - /// Tunnel implementation - tunnel_handle: Weak<Mutex<Option<TunnelType>>>, +pub struct Check<Strategy = Timeout> { conn_state: ConnState, - initial_ping_timestamp: Option<Instant>, - num_pings_sent: u32, - pinger: Box<dyn Pinger>, + ping_state: PingState, + strategy: Strategy, + retry_attempt: u32, +} + +// Define the type state of [Check] +pub(crate) trait Strategy { + fn should_shut_down(&mut self, timeout: Duration) -> bool; +} + +/// An uncancellable [Check] that will run [Check::establish_connectivity] until +/// completion or until it times out. +pub struct Timeout; + +impl Strategy for Timeout { + /// The Timeout strategy cannot receive shut down signals so this function always returns false. 
+ fn should_shut_down(&mut self, _timeout: Duration) -> bool { + false + } +} + +/// A cancellable [Check] may be cancelled before it will time out by sending +/// a signal on the channel returned by [Check::with_cancellation]. Otherwise, +/// it behaves as [Timeout]. +pub struct Cancellable { close_receiver: mpsc::Receiver<()>, } -impl ConnectivityMonitor { - pub(super) fn new( +impl Strategy for Cancellable { + /// Returns true if monitor should be shut down + fn should_shut_down(&mut self, timeout: Duration) -> bool { + match self.close_receiver.recv_timeout(timeout) { + Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, + Err(mpsc::RecvTimeoutError::Timeout) => false, + } + } +} + +impl Check<Timeout> { + pub fn new( addr: Ipv4Addr, #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, - tunnel_handle: Weak<Mutex<Option<TunnelType>>>, - close_receiver: mpsc::Receiver<()>, - ) -> Result<Self, Error> { - let pinger = new_pinger( - addr, - #[cfg(any(target_os = "macos", target_os = "linux"))] - interface, - ) - .map_err(Error::PingError)?; + retry_attempt: u32, + ) -> Result<Check<Timeout>, Error> { + Ok(Check { + conn_state: ConnState::new(Instant::now(), Default::default()), + ping_state: PingState::new( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + )?, + strategy: Timeout, + retry_attempt, + }) + } - let now = Instant::now(); + /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping + /// the returned channel. 
+ pub fn with_cancellation(self) -> (Check<Cancellable>, mpsc::Sender<()>) { + let (cancellation_tx, cancellation_rx) = mpsc::channel(); + let check = Check { + conn_state: self.conn_state, + ping_state: self.ping_state, + strategy: Cancellable { + close_receiver: cancellation_rx, + }, + retry_attempt: self.retry_attempt, + }; + (check, cancellation_tx) + } - Ok(Self { - tunnel_handle, - conn_state: ConnState::new(now, Default::default()), - initial_ping_timestamp: None, - num_pings_sent: 0, - pinger, - close_receiver, - }) + #[cfg(test)] + /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy, + /// see [Check::with_cancellation]. + pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self { + Check { + conn_state, + ping_state, + strategy: Timeout, + retry_attempt: 0, + } } +} +impl<S: Strategy> Check<S> { // checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is // successful at the start of a connection. - pub(super) fn establish_connectivity(&mut self, retry_attempt: u32) -> Result<bool, Error> { + pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result<bool, Error> { // Send initial ping to prod WireGuard into connecting. 
- self.pinger.send_icmp().map_err(Error::PingError)?; + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; self.establish_connectivity_inner( - retry_attempt, + self.retry_attempt, ESTABLISH_TIMEOUT, ESTABLISH_TIMEOUT_MULTIPLIER, MAX_ESTABLISH_TIMEOUT, + tunnel_handle, ) } + pub(crate) fn reset(&mut self, current_iteration: Instant) { + self.ping_state.reset(); + self.conn_state.reset_after_suspension(current_iteration); + } + + pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool { + self.strategy.should_shut_down(timeout) + } + fn establish_connectivity_inner( &mut self, retry_attempt: u32, timeout_initial: Duration, timeout_multiplier: u32, max_timeout: Duration, + tunnel_handle: &TunnelType, ) -> Result<bool, Error> { if self.conn_state.connected() { return Ok(true); @@ -140,7 +167,7 @@ impl ConnectivityMonitor { let start = Instant::now(); while start.elapsed() < check_timeout { - if self.check_connectivity_interval(Instant::now(), check_timeout)? { + if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? { return Ok(true); } if self.should_shut_down(DELAY_ON_INITIAL_SETUP) { @@ -150,46 +177,13 @@ impl ConnectivityMonitor { Ok(false) } - pub(super) fn run(&mut self) -> Result<(), Error> { - self.wait_loop(REGULAR_LOOP_SLEEP) - } - - /// Returns true if monitor should be shut down - fn should_shut_down(&mut self, timeout: Duration) -> bool { - match self.close_receiver.recv_timeout(timeout) { - Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, - Err(mpsc::RecvTimeoutError::Timeout) => false, - } - } - - fn wait_loop(&mut self, iter_delay: Duration) -> Result<(), Error> { - let mut last_iteration = Instant::now(); - while !self.should_shut_down(iter_delay) { - let mut current_iteration = Instant::now(); - let time_slept = current_iteration - last_iteration; - if time_slept < (iter_delay * 2) { - if !self.check_connectivity(Instant::now())? 
{ - return Ok(()); - } - - let end = Instant::now(); - if end - current_iteration > Duration::from_secs(1) { - current_iteration = end; - } - } else { - // Loop was suspended for too long, so it's safer to assume that the host still has - // connectivity. - self.reset_pinger(); - self.conn_state.reset_after_suspension(current_iteration); - } - last_iteration = current_iteration; - } - Ok(()) - } - /// Returns true if connection is established - fn check_connectivity(&mut self, now: Instant) -> Result<bool, Error> { - self.check_connectivity_interval(now, PING_TIMEOUT) + pub(crate) fn check_connectivity( + &mut self, + now: Instant, + tunnel_handle: &TunnelType, + ) -> Result<bool, Error> { + self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle) } /// Returns true if connection is established @@ -197,19 +191,18 @@ impl ConnectivityMonitor { &mut self, now: Instant, timeout: Duration, + tunnel_handle: &TunnelType, ) -> Result<bool, Error> { - match self.get_stats() { + match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? { None => Ok(false), Some(new_stats) => { - let new_stats = new_stats?; - if self.conn_state.update(now, new_stats) { - self.reset_pinger(); + self.ping_state.reset(); return Ok(true); } self.maybe_send_ping(now)?; - Ok(!self.ping_timed_out(timeout) && self.conn_state.connected()) + Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected()) } } } @@ -218,19 +211,14 @@ impl ConnectivityMonitor { /// calls will also return None. /// /// NOTE: will panic if called from within a tokio runtime. - fn get_stats(&self) -> Option<Result<StatsMap, Error>> { - self.tunnel_handle - .upgrade()? 
- .blocking_lock() - .as_ref() - .and_then(|tunnel| match tunnel.get_tunnel_stats() { - Ok(stats) if stats.is_empty() => { - log::error!("Tunnel unexpectedly shut down"); - None - } - Ok(stats) => Some(Ok(stats)), - Err(error) => Some(Err(Error::ConfigReadError(error))), - }) + fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> { + let stats = tunnel_handle.get_tunnel_stats()?; + if stats.is_empty() { + log::error!("Tunnel unexpectedly shut down"); + Ok(None) + } else { + Ok(Some(stats)) + } } fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> { @@ -239,20 +227,55 @@ impl ConnectivityMonitor { // 3 seconds. if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out()) && self + .ping_state .initial_ping_timestamp .map(|initial_ping_timestamp| { - initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING + initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent + < SECONDS_PER_PING }) .unwrap_or(true) { - self.pinger.send_icmp().map_err(Error::PingError)?; - if self.initial_ping_timestamp.is_none() { - self.initial_ping_timestamp = Some(now); + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; + if self.ping_state.initial_ping_timestamp.is_none() { + self.ping_state.initial_ping_timestamp = Some(now); } - self.num_pings_sent += 1; + self.ping_state.num_pings_sent += 1; } Ok(()) } +} + +pub(super) struct PingState { + initial_ping_timestamp: Option<Instant>, + num_pings_sent: u32, + pinger: Box<dyn Pinger>, +} + +impl PingState { + pub(super) fn new( + addr: Ipv4Addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, + ) -> Result<Self, Error> { + let pinger = pinger::new_pinger( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + ) + .map_err(Error::PingError)?; + + Ok(Self::new_with(pinger)) + } + + pub(super) fn new_with(pinger: Box<dyn Pinger>) -> Self { + Self { + initial_ping_timestamp: None, + 
num_pings_sent: 0, + pinger, + } + } fn ping_timed_out(&self, timeout: Duration) -> bool { self.initial_ping_timestamp @@ -261,14 +284,14 @@ impl ConnectivityMonitor { } /// Reset timeouts - assume that the last time bytes were received is now. - fn reset_pinger(&mut self) { + fn reset(&mut self) { self.initial_ping_timestamp = None; self.num_pings_sent = 0; self.pinger.reset(); } } -enum ConnState { +pub(super) enum ConnState { Connecting { start: Instant, stats: StatsMap, @@ -401,21 +424,8 @@ impl ConnState { #[cfg(test)] mod test { - use futures::Future; - use super::*; - use crate::{ - config::Config, - stats::{self, Stats}, - Tunnel, - }; - use std::{ - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - }; + use crate::connectivity::mock::*; /// Test if a newly created ConnState won't have timed out or consider itself connected #[test] @@ -521,300 +531,76 @@ mod test { assert!(!conn_state.traffic_timed_out()); } - #[derive(Default)] - struct MockPinger { - on_send_ping: Option<Box<dyn FnMut() + Send>>, - } - - impl Pinger for MockPinger { - fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> { - if let Some(callback) = self.on_send_ping.as_mut() { - (callback)(); - } - Ok(()) - } - } - - struct MockTunnel { - on_get_stats: Box<dyn Fn() -> Result<stats::StatsMap, TunnelError> + Send>, - } - - impl MockTunnel { - const PEER: [u8; 32] = [0u8; 32]; - - fn new<F: Fn() -> Result<stats::StatsMap, TunnelError> + Send + 'static>(f: F) -> Self { - Self { - on_get_stats: Box::new(f), - } - } - - fn always_incrementing() -> Self { - let mut map = stats::StatsMap::new(); - map.insert( - Self::PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let peers = std::sync::Mutex::new(map); - Self { - on_get_stats: Box::new(move || { - let mut peers = peers.lock().unwrap(); - for traffic in peers.values_mut() { - traffic.tx_bytes += 1; - traffic.rx_bytes += 1; - } - Ok(peers.clone()) - }), - } - } - - fn never_incrementing() -> Self { - 
Self { - on_get_stats: Box::new(|| { - let mut map = stats::StatsMap::new(); - map.insert( - Self::PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - Ok(map) - }), - } - } - - #[allow(clippy::type_complexity)] - fn into_locked( - self, - ) -> ( - Arc<Mutex<Option<Box<dyn Tunnel>>>>, - Weak<Mutex<Option<Box<dyn Tunnel>>>>, - ) { - let dyn_tunnel: Box<dyn Tunnel> = Box::new(self); - let arc = Arc::new(Mutex::new(Some(dyn_tunnel))); - let weak_ref = Arc::downgrade(&arc); - (arc, weak_ref) - } - } - - impl Tunnel for MockTunnel { - fn get_interface_name(&self) -> String { - "mock-tunnel".to_string() - } - - fn stop(self: Box<Self>) -> Result<(), TunnelError> { - Ok(()) - } - - fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> { - (self.on_get_stats)() - } - - fn set_config( - &mut self, - _config: Config, - ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { - Box::pin(async { Ok(()) }) - } - - #[cfg(daita)] - fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { - Ok(()) - } - } - - fn mock_monitor( - now: Instant, - pinger: Box<dyn Pinger>, - tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, - close_receiver: mpsc::Receiver<()>, - ) -> ConnectivityMonitor { - ConnectivityMonitor { - conn_state: ConnState::new(now, Default::default()), - initial_ping_timestamp: None, - num_pings_sent: 0, - pinger, - close_receiver, - tunnel_handle, - } - } - - fn connected_state(timestamp: Instant) -> ConnState { - const PEER: [u8; 32] = [0u8; 32]; - let mut stats = stats::StatsMap::new(); - stats.insert( - PEER, - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - ConnState::Connected { - rx_timestamp: timestamp, - tx_timestamp: timestamp, - stats, - } - } - #[test] /// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is /// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`. 
fn test_ping_times_out() { - let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now .checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10)) .unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut checker = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established - monitor.conn_state = connected_state(start); + checker.conn_state = connected_state(start); // A ping was sent to verify connectivity - monitor.maybe_send_ping(start).unwrap(); - assert!(!monitor.check_connectivity(now).unwrap()) + checker.maybe_send_ping(start).unwrap(); + assert!(!checker.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. fn test_no_connection_on_start() { - let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut monitor = mock_checker(start, Box::new(pinger)); - assert!(!monitor.check_connectivity(now).unwrap()) + assert!(!monitor.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. 
fn test_connection_works() { - let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); - let (_tx, rx) = mpsc::channel(); + let tunnel = MockTunnel::always_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + let mut monitor = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established monitor.conn_state = connected_state(start); - assert!(monitor.check_connectivity(now).unwrap()) - } - - #[test] - /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, - /// and it shuts down properly. - fn test_wait_loop() { - let (result_tx, result_rx) = mpsc::channel(); - let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); - let pinger = MockPinger::default(); - let (stop_tx, stop_rx) = mpsc::channel(); - std::thread::spawn(move || { - let now = Instant::now(); - let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); - - let start_result = monitor.establish_connectivity(0); - result_tx.send(start_result).unwrap(); - - let result = monitor.run().map(|_| true); - result_tx.send(result).unwrap(); - }); - - std::thread::sleep(Duration::from_secs(1)); - assert!(result_rx.try_recv().unwrap().unwrap()); - stop_tx.send(()).unwrap(); - std::thread::sleep(Duration::from_secs(1)); - assert!(result_rx.try_recv().unwrap().is_ok()); - } - - #[test] - /// Verify that the connectivity monitor detects the tunnel timing out after no longer than - /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. 
- fn test_wait_loop_timeout() { - let should_stop = Arc::new(AtomicBool::new(false)); - let should_stop_inner = should_stop.clone(); - - let mut map = stats::StatsMap::new(); - map.insert( - [0u8; 32], - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let tunnel_stats = std::sync::Mutex::new(map); - - let pinger = MockPinger::default(); - let (_tunnel_anchor, tunnel) = MockTunnel::new(move || { - let mut tunnel_stats = tunnel_stats.lock().unwrap(); - if !should_stop_inner.load(Ordering::SeqCst) { - for traffic in tunnel_stats.values_mut() { - traffic.rx_bytes += 1; - } - } - for traffic in tunnel_stats.values_mut() { - traffic.tx_bytes += 1; - } - Ok(tunnel_stats.clone()) - }) - .into_locked(); - - let (result_tx, result_rx) = mpsc::channel(); - - let (_stop_tx, stop_rx) = mpsc::channel(); - std::thread::spawn(move || { - let now = Instant::now(); - let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); - let start_result = monitor.establish_connectivity(0); - result_tx.send(start_result).unwrap(); - let end_result = monitor.run().map(|_| true); - result_tx.send(end_result).expect("Failed to send result"); - }); - assert!(result_rx - .recv_timeout(Duration::from_secs(1)) - .unwrap() - .unwrap()); - should_stop.store(true, Ordering::SeqCst); - assert!(result_rx - .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) - .unwrap() - .is_ok()); + assert!(monitor.check_connectivity(now, &tunnel).unwrap()) } #[test] /// Verify that the timeout for setting up a tunnel works as expected. 
fn test_establish_timeout() { - let mut tunnel_stats = stats::StatsMap::new(); - tunnel_stats.insert( - [0u8; 32], - stats::Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - let pinger = MockPinger::default(); - let (_tunnel_anchor, tunnel) = - MockTunnel::new(move || Ok(tunnel_stats.clone())).into_locked(); + let tunnel = { + let mut tunnel_stats = StatsMap::new(); + tunnel_stats.insert( + [0u8; 32], + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed() + }; let (result_tx, result_rx) = mpsc::channel(); - let (_stop_tx, stop_rx) = mpsc::channel(); std::thread::spawn(move || { let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); + let mut monitor = mock_checker(start, Box::new(pinger)); const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500); @@ -827,6 +613,7 @@ mod test { ESTABLISH_TIMEOUT, ESTABLISH_TIMEOUT_MULTIPLIER, MAX_ESTABLISH_TIMEOUT, + &tunnel, )) .unwrap(); } diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs new file mode 100644 index 0000000000..a8d6752ddd --- /dev/null +++ b/talpid-wireguard/src/connectivity/constants.rs @@ -0,0 +1,22 @@ +use std::time::Duration; + +/// Sleep time used when initially establishing connectivity +pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); +/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is +/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic +/// is received. +pub(crate) const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5); +/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will +/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received. 
+pub(crate) const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120); +/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this +/// timeout is reached, it is assumed that the connection is lost. +pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(15); +/// Timeout for receiving traffic when establishing a connection. +pub(crate) const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4); +/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt. +pub(crate) const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; +/// Maximum timeout for establishing a connection. +pub(crate) const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT; +/// Number of seconds to wait between sending ICMP packets +pub(crate) const SECONDS_PER_PING: Duration = Duration::from_secs(3); diff --git a/talpid-wireguard/src/connectivity/error.rs b/talpid-wireguard/src/connectivity/error.rs new file mode 100644 index 0000000000..9e8c98a751 --- /dev/null +++ b/talpid-wireguard/src/connectivity/error.rs @@ -0,0 +1,14 @@ +use super::pinger; +use crate::TunnelError; + +/// Connectivity monitor errors +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Failed to read tunnel's configuration + #[error("Failed to read tunnel's configuration")] + ConfigReadError(TunnelError), + + /// Failed to send ping + #[error("Ping failed")] + PingError(#[from] pinger::Error), +} diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs new file mode 100644 index 0000000000..892f3966ea --- /dev/null +++ b/talpid-wireguard/src/connectivity/mock.rs @@ -0,0 +1,133 @@ +use std::future::Future; +use std::pin::Pin; +use std::time::Instant; + +use super::check::{ConnState, PingState, Timeout}; +use super::pinger; +use super::Check; + +use crate::{Config, Tunnel, TunnelError}; +use pinger::Pinger; + +// Convenient re-exports +pub use crate::stats::{Stats, StatsMap}; + +#[derive(Default)] +pub(crate) struct MockPinger { 
+ on_send_ping: Option<Box<dyn FnMut() + Send>>, +} + +pub(crate) struct MockTunnel { + on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send>, +} + +pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> Check<Timeout> { + let conn_state = ConnState::new(now, Default::default()); + let ping_state = PingState::new_with(pinger); + Check::mock(conn_state, ping_state) +} + +pub fn connected_state(timestamp: Instant) -> ConnState { + const PEER: [u8; 32] = [0u8; 32]; + let mut stats = StatsMap::new(); + stats.insert( + PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + ConnState::Connected { + rx_timestamp: timestamp, + tx_timestamp: timestamp, + stats, + } +} + +impl MockTunnel { + const PEER: [u8; 32] = [0u8; 32]; + + pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + 'static>(f: F) -> Self { + Self { + on_get_stats: Box::new(f), + } + } + + /// Convert self to the more general [TunnelType]. + pub fn boxed(self) -> Box<dyn Tunnel> { + Box::new(self) + } + + pub fn always_incrementing() -> Self { + let mut map = StatsMap::new(); + map.insert( + Self::PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + let peers = std::sync::Mutex::new(map); + Self { + on_get_stats: Box::new(move || { + let mut peers = peers.lock().unwrap(); + for traffic in peers.values_mut() { + traffic.tx_bytes += 1; + traffic.rx_bytes += 1; + } + Ok(peers.clone()) + }), + } + } + + pub fn never_incrementing() -> Self { + Self { + on_get_stats: Box::new(|| { + let mut map = StatsMap::new(); + map.insert( + Self::PEER, + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + Ok(map) + }), + } + } +} + +impl Tunnel for MockTunnel { + fn get_interface_name(&self) -> String { + "mock-tunnel".to_string() + } + + fn stop(self: Box<Self>) -> Result<(), TunnelError> { + Ok(()) + } + + fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> { + (self.on_get_stats)() + } + + fn set_config( + &mut self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = 
std::result::Result<(), TunnelError>> + Send>> { + Box::pin(async { Ok(()) }) + } + + #[cfg(daita)] + fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + Ok(()) + } +} + +impl Pinger for MockPinger { + fn send_icmp(&mut self) -> Result<(), pinger::Error> { + if let Some(callback) = self.on_send_ping.as_mut() { + (callback)(); + } + Ok(()) + } +} diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs new file mode 100644 index 0000000000..512d8715f1 --- /dev/null +++ b/talpid-wireguard/src/connectivity/mod.rs @@ -0,0 +1,13 @@ +mod check; +mod constants; +mod error; +#[cfg(test)] +mod mock; +mod monitor; +mod pinger; + +#[cfg(target_os = "android")] +pub use check::Cancellable; +pub use check::Check; +pub use error::Error; +pub use monitor::Monitor; diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs new file mode 100644 index 0000000000..583b8d9589 --- /dev/null +++ b/talpid-wireguard/src/connectivity/monitor.rs @@ -0,0 +1,174 @@ +use std::{ + sync::Weak, + time::{Duration, Instant}, +}; + +use tokio::sync::Mutex; + +use crate::TunnelType; + +use super::check::{Cancellable, Check}; +use super::error::Error; + +/// Sleep time used when checking if an established connection is still working. 
+const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); + +pub struct Monitor { + connectivity_check: Check<Cancellable>, +} + +impl Monitor { + pub fn init(connectivity_check: Check<Cancellable>) -> Self { + Self { connectivity_check } + } + + pub fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> { + self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle) + } + + fn wait_loop( + mut self, + iter_delay: Duration, + tunnel_handle: Weak<Mutex<Option<TunnelType>>>, + ) -> Result<(), Error> { + let mut last_iteration = Instant::now(); + while !self.connectivity_check.should_shut_down(iter_delay) { + let mut current_iteration = Instant::now(); + let time_slept = current_iteration - last_iteration; + if time_slept < (iter_delay * 2) { + let Some(tunnel) = tunnel_handle.upgrade() else { + return Ok(()); + }; + let lock = tunnel.blocking_lock(); + let Some(tunnel) = lock.as_ref() else { + return Ok(()); + }; + + if !self + .connectivity_check + .check_connectivity(Instant::now(), tunnel)? + { + return Ok(()); + } + drop(lock); + + let end = Instant::now(); + if end - current_iteration > Duration::from_secs(1) { + current_iteration = end; + } + } else { + // Loop was suspended for too long, so it's safer to assume that the host still has + // connectivity. + self.connectivity_check.reset(current_iteration); + } + last_iteration = current_iteration; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + // TODO: Port to async + tokio to reduce cost of testing? + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc; + use std::sync::Arc; + use std::time::Duration; + use std::time::Instant; + + use tokio::sync::Mutex; + + use crate::connectivity::constants::*; + use crate::connectivity::mock::*; + + #[test] + /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, + /// and it shuts down properly. 
+ fn test_wait_loop() { + use std::sync::mpsc; + let (result_tx, result_rx) = mpsc::channel(); + let tunnel = MockTunnel::always_incrementing().boxed(); + let pinger = MockPinger::default(); + let (mut checker, stop_tx) = { + let now = Instant::now(); + let start = now.checked_sub(Duration::from_secs(1)).unwrap(); + mock_checker(start, Box::new(pinger)).with_cancellation() + }; + std::thread::spawn(move || { + let start_result = checker.establish_connectivity(&tunnel); + result_tx.send(start_result).unwrap(); + // Pointer dance + let tunnel = Arc::new(Mutex::new(Some(tunnel))); + let _tunnel = Arc::downgrade(&tunnel); + let result = Monitor::init(checker).run(_tunnel).map(|_| true); + result_tx.send(result).unwrap(); + }); + + std::thread::sleep(Duration::from_secs(1)); + assert!(result_rx.try_recv().unwrap().unwrap()); + stop_tx.send(()).unwrap(); + std::thread::sleep(Duration::from_secs(1)); + assert!(result_rx.try_recv().unwrap().is_ok()); + } + + #[test] + /// Verify that the connectivity monitor detects the tunnel timing out after no longer than + /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. 
+ fn test_wait_loop_timeout() { + let should_stop = Arc::new(AtomicBool::new(false)); + let should_stop_inner = should_stop.clone(); + + let mut map = StatsMap::new(); + map.insert( + [0u8; 32], + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + let tunnel_stats = std::sync::Mutex::new(map); + + let pinger = MockPinger::default(); + let tunnel = MockTunnel::new(move || { + let mut tunnel_stats = tunnel_stats.lock().unwrap(); + if !should_stop_inner.load(Ordering::SeqCst) { + for traffic in tunnel_stats.values_mut() { + traffic.rx_bytes += 1; + } + } + for traffic in tunnel_stats.values_mut() { + traffic.tx_bytes += 1; + } + Ok(tunnel_stats.clone()) + }) + .boxed(); + + let (result_tx, result_rx) = mpsc::channel(); + + std::thread::spawn(move || { + let (mut checker, _cancellation_token) = { + let now = Instant::now(); + let start = now.checked_sub(Duration::from_secs(1)).unwrap(); + mock_checker(start, Box::new(pinger)).with_cancellation() + }; + let start_result = checker.establish_connectivity(&tunnel); + result_tx.send(start_result).unwrap(); + // Pointer dance + let _tunnel = Arc::new(Mutex::new(Some(tunnel))); + let tunnel = Arc::downgrade(&_tunnel); + let end_result = Monitor::init(checker).run(tunnel).map(|_| true); + result_tx.send(end_result).expect("Failed to send result"); + }); + assert!(result_rx + .recv_timeout(Duration::from_secs(1)) + .unwrap() + .unwrap()); + should_stop.store(true, Ordering::SeqCst); + assert!(result_rx + .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) + .unwrap() + .is_ok()); + } +} diff --git a/talpid-wireguard/src/ping_monitor/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs index 00ad4d8fd3..00ad4d8fd3 100644 --- a/talpid-wireguard/src/ping_monitor/android.rs +++ b/talpid-wireguard/src/connectivity/pinger/android.rs diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs index 0e5d739425..0e5d739425 100644 --- 
a/talpid-wireguard/src/ping_monitor/icmp.rs +++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs diff --git a/talpid-wireguard/src/ping_monitor/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs index ef2394f1b7..ef2394f1b7 100644 --- a/talpid-wireguard/src/ping_monitor/mod.rs +++ b/talpid-wireguard/src/connectivity/pinger/mod.rs diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index 127836863a..a9283fcb2e 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -130,6 +130,7 @@ async fn config_ephemeral_peers_inner( &tun_provider, ) .await?; + let entry_psk = request_ephemeral_peer( retry_attempt, &entry_config, @@ -199,17 +200,17 @@ async fn reconfigure_tunnel( .await .map_err(CloseMsg::ObfuscatorFailed)?; } + { + let mut shared_tunnel = tunnel.lock().await; + let tunnel = shared_tunnel.take().expect("tunnel was None"); - let mut lock = tunnel.lock().await; - - let tunnel = lock.take().expect("tunnel was None"); - - let new_tunnel = tunnel - .better_set_config(&config) - .map_err(Error::TunnelError) - .map_err(CloseMsg::SetupError)?; + let updated_tunnel = tunnel + .set_config(&config) + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; - *lock = Some(new_tunnel); + *shared_tunnel = Some(updated_tunnel); + } Ok(config) } diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 0f05314e72..7c93b39f18 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -35,11 +35,10 @@ use tokio::sync::Mutex as AsyncMutex; /// WireGuard config data-types pub mod config; -mod connectivity_check; +mod connectivity; mod ephemeral; mod logging; mod obfuscation; -mod ping_monitor; mod stats; #[cfg(wireguard_go)] mod wireguard_go; @@ -88,7 +87,7 @@ pub enum Error { /// Failed to set up connectivity monitor #[error("Connectivity monitor failed")] - ConnectivityMonitorError(#[source] connectivity_check::Error), + ConnectivityMonitorError(#[source] 
connectivity::Error), /// Failed while negotiating ephemeral peer #[error("Failed while negotiating ephemeral peer")] @@ -216,8 +215,17 @@ impl WireguardMonitor { let obfuscator = Arc::new(AsyncMutex::new(obfuscator)); + let gateway = config.ipv4_gateway; + let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new( + gateway, + #[cfg(any(target_os = "macos", target_os = "linux"))] + iface_name.clone(), + args.retry_attempt, + ) + .map_err(Error::ConnectivityMonitorError)? + .with_cancellation(); + let event_callback = Box::new(on_event.clone()); - let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), @@ -227,16 +235,6 @@ impl WireguardMonitor { obfuscator, }; - let gateway = config.ipv4_gateway; - let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( - gateway, - #[cfg(any(target_os = "macos", target_os = "linux"))] - iface_name.clone(), - Arc::downgrade(&monitor.tunnel), - pinger_rx, - ) - .map_err(Error::ConnectivityMonitorError)?; - let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); @@ -321,8 +319,12 @@ impl WireguardMonitor { }); } - let mut connectivity_monitor = tokio::task::spawn_blocking(move || { - match connectivity_monitor.establish_connectivity(args.retry_attempt) { + let cloned_tunnel = Arc::clone(&tunnel); + + let connectivity_check = tokio::task::spawn_blocking(move || { + let lock = cloned_tunnel.blocking_lock(); + let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly"); + match connectivity_monitor.establish_connectivity(tunnel) { Ok(true) => Ok(connectivity_monitor), Ok(false) => { log::warn!("Timeout while checking tunnel connection"); @@ -350,8 +352,11 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); (on_event)(TunnelEvent::Up(metadata)).await; + let 
monitored_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { - if let Err(error) = connectivity_monitor.run() { + if let Err(error) = + connectivity::Monitor::init(connectivity_check).run(monitored_tunnel) + { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -423,6 +428,12 @@ impl WireguardMonitor { } let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; + + let (connectivity_check, pinger_tx) = + connectivity::Check::new(config.ipv4_gateway, args.retry_attempt) + .map_err(Error::ConnectivityMonitorError)? + .with_cancellation(); + let tunnel = Self::open_wireguard_go_tunnel( &config, log_path, @@ -432,11 +443,11 @@ impl WireguardMonitor { // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. should_negotiate_ephemeral_peer, + connectivity_check, )?; + let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); - - let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::clone(&tunnel), @@ -446,60 +457,18 @@ impl WireguardMonitor { obfuscator: Arc::new(AsyncMutex::new(obfuscator)), }; - let gateway = config.ipv4_gateway; - let connectivity_monitor = connectivity_check::ConnectivityMonitor::new( - gateway, - Arc::downgrade(&tunnel), - pinger_rx, - ) - .map_err(Error::ConnectivityMonitorError)?; - let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); let tunnel_fut = async move { let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender; let obfuscator = moved_obfuscator; - let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor)); let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), 
allowed_traffic)) .await; - let handle_ping = |ping_result: std::result::Result< - bool, - connectivity_check::Error, - >| match ping_result { - Ok(true) => Ok(()), - Ok(false) => { - log::warn!("Timeout while checking tunnel connection"); - Err(CloseMsg::PingErr) - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check tunnel connection") - ); - Err(CloseMsg::PingErr) - } - }; - - // Prepare a closure which pings inside the tunnel when executed. - let ping = || { - let connectivity_monitor_arc = connectivity_monitor.clone(); - let retry_attempt = args.retry_attempt; - move || { - let ping_result = connectivity_monitor_arc - .lock() - .unwrap() - .establish_connectivity(retry_attempt); - handle_ping(ping_result) - } - }; - if should_negotiate_ephemeral_peer { - // Ping before negotiating the ephemeral peer to make sure that the tunnel works. - tokio::task::spawn_blocking(ping()).await.unwrap()?; let ephemeral_obfs_sender = close_obfs_sender.clone(); ephemeral::config_ephemeral_peers( @@ -513,21 +482,31 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event.clone())(TunnelEvent::InterfaceUp( + args.on_event.clone()(TunnelEvent::InterfaceUp( metadata, Self::allowed_traffic_after_tunnel_config(), )) .await; } - // Make sure the tunnel works (after potentially having negotiated an ephemeral peer). 
- tokio::task::spawn_blocking(ping()).await.unwrap()?; - let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event.clone())(TunnelEvent::Up(metadata)).await; + args.on_event.clone()(TunnelEvent::Up(metadata)).await; + + // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it + let connectivity_check = { + let mut tunnel_lock = tunnel.lock().await; + let Some(tunnel) = tunnel_lock.as_mut() else { + log::debug!("Tunnel is no longer running"); + return Err::<Infallible, CloseMsg>(CloseMsg::PingErr); + }; + tunnel + .take_checker() + .expect("connectivity checker unexpectedly dropped") + }; tokio::task::spawn_blocking(move || { - if let Err(error) = connectivity_monitor.lock().unwrap().run() { + let tunnel = Arc::downgrade(&tunnel); + if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -748,6 +727,9 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(target_os = "android")] gateway_only: bool, + #[cfg(target_os = "android")] connectivity_check: connectivity::Check< + connectivity::Cancellable, + >, ) -> Result<WgGoTunnel> { let routes = config .get_tunnel_destinations() @@ -786,6 +768,7 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, + connectivity_check, ) .map_err(Error::TunnelError)? } else { @@ -797,6 +780,7 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, + connectivity_check, ) .map_err(Error::TunnelError)? }; @@ -1103,6 +1087,15 @@ pub enum TunnelError { #[cfg(daita)] #[error("Failed to start DAITA - tunnel implemenation does not support DAITA")] DaitaNotSupported, + + /// [connectivity] error. 
+ #[error(transparent)] + Connectivity(#[from] Box<connectivity::Error>), + + /// Tunnel seemingly does not serve any traffic + #[cfg(target_os = "android")] + #[error("Tunnel seemingly does not serve any traffic")] + TunnelUp, } #[cfg(target_os = "linux")] diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 6734475c4b..d283758f3b 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -6,6 +6,8 @@ use super::{ }; #[cfg(target_os = "linux")] use crate::config::MULLVAD_INTERFACE_NAME; +#[cfg(target_os = "android")] +use crate::connectivity; use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; #[cfg(daita)] @@ -75,7 +77,7 @@ impl WgGoTunnel { &self.0 } - fn to_state_mut(&mut self) -> &mut WgGoTunnelState { + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { &mut self.0 } } @@ -96,14 +98,17 @@ impl WgGoTunnel { } } - fn to_state_mut(&mut self) -> &mut WgGoTunnelState { + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { match self { WgGoTunnel::Multihop(state) => state, WgGoTunnel::Singlehop(state) => state, } } - pub fn better_set_config(self, config: &Config) -> Result<Self> { + pub fn set_config(mut self, config: &Config) -> Result<Self> { + let connectivity_checker = self + .take_checker() + .expect("connectivity checker unexpectedly dropped"); let state = self.as_state(); let log_path = state._logging_context.path.clone(); let tun_provider = Arc::clone(&state.tun_provider); @@ -120,6 +125,7 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, + connectivity_checker, ) } WgGoTunnel::Singlehop(state) if config.is_multihop() => { @@ -131,6 +137,7 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, + connectivity_checker, ) } WgGoTunnel::Singlehop(mut state) => { @@ -163,6 +170,13 @@ pub(crate) struct WgGoTunnelState { resource_dir: PathBuf, #[cfg(daita)] config: Config, + // HACK: Check is not Clone, so we have to pass 
this around .. + // This is conceptually the connection between this Tunnel and the currently running + // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of + // a new Tunnel during the "ensure_connectivity" phase. This field should be removed + // as soon as we implement a better way to cancel Check asynchronously. + #[cfg(target_os = "android")] + connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>, } impl WgGoTunnelState { @@ -288,6 +302,7 @@ impl WgGoTunnel { tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check<connectivity::Cancellable>, ) -> Result<Self> { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -310,7 +325,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - Ok(WgGoTunnel::Singlehop(WgGoTunnelState { + let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -320,7 +335,14 @@ impl WgGoTunnel { resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), - })) + connectivity_checker: None, + }); + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. 
+ tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + + Ok(tunnel) } pub fn start_multihop_tunnel( @@ -330,6 +352,7 @@ impl WgGoTunnel { tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check<connectivity::Cancellable>, ) -> Result<Self> { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -368,7 +391,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - Ok(WgGoTunnel::Multihop(WgGoTunnelState { + let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -378,7 +401,14 @@ impl WgGoTunnel { resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), - })) + connectivity_checker: None, + }); + + // HACK: Check if the tunnel is working by sending a ping in the tunnel. + tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + + Ok(tunnel) } fn bypass_tunnel_sockets( @@ -393,6 +423,28 @@ impl WgGoTunnel { Ok(()) } + + pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> { + self.as_state_mut().connectivity_checker.take() + } + + /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve + /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. 
+ fn ensure_tunnel_is_running( + &self, + checker: &mut connectivity::Check<connectivity::Cancellable>, + ) -> Result<()> { + let connectivity_err = |e| TunnelError::Connectivity(Box::new(e)); + let connection_established = checker + .establish_connectivity(self) + .map_err(connectivity_err)?; + + // Timed out + if !connection_established { + return Err(TunnelError::TunnelUp); + } + Ok(()) + } } impl Tunnel for WgGoTunnel { @@ -418,7 +470,7 @@ impl Tunnel for WgGoTunnel { &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> { - Box::pin(async move { self.to_state_mut().set_config(config) }) + Box::pin(async move { self.as_state_mut().set_config(config) }) } #[cfg(daita)] |
