diff options
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/connectivity_check.rs | 240 |
1 files changed, 238 insertions, 2 deletions
diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index 73cdff92e4..111edd772a 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -342,8 +342,15 @@ impl ConnState { #[cfg(test)] mod test { - use super::{ConnState, Stats, BYTES_RX_TIMEOUT, TRAFFIC_TIMEOUT}; - use std::time::{Duration, Instant}; + use super::*; + use crate::tunnel::wireguard::{stats, TunnelError}; + use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, + time::{Duration, Instant}, + }; /// Test if a newly created ConnState won't have timed out or consider itself connected #[test] @@ -440,4 +447,233 @@ mod test { assert!(conn_state.rx_timed_out()); assert!(!conn_state.traffic_timed_out()); } + + #[derive(Default)] + struct MockPinger { + on_send_ping: Option<Box<dyn FnMut() + Send>>, + } + + impl Pinger for MockPinger { + fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> { + if let Some(callback) = self.on_send_ping.as_mut() { + (callback)(); + } + Ok(()) + } + } + + struct MockTunnel { + on_get_stats: Box<dyn Fn() -> Result<stats::Stats, TunnelError> + Send>, + } + + impl MockTunnel { + fn new<F: Fn() -> Result<stats::Stats, TunnelError> + Send + 'static>(f: F) -> Self { + Self { + on_get_stats: Box::new(f), + } + } + + fn always_incrementing() -> Self { + let traffic = Mutex::new(stats::Stats { + tx_bytes: 0, + rx_bytes: 0, + }); + Self { + on_get_stats: Box::new(move || { + let mut traffic = traffic.lock().unwrap(); + traffic.tx_bytes += 1; + traffic.rx_bytes += 1; + + Ok(*traffic) + }), + } + } + + fn never_incrementing() -> Self { + Self { + on_get_stats: Box::new(|| { + Ok(stats::Stats { + tx_bytes: 0, + rx_bytes: 0, + }) + }), + } + } + + fn into_locked( + self, + ) -> ( + Arc<Mutex<Option<Box<dyn Tunnel>>>>, + Weak<Mutex<Option<Box<dyn Tunnel>>>>, + ) { + let dyn_tunnel: Box<dyn Tunnel> = Box::new(self); + let arc = Arc::new(Mutex::new(Some(dyn_tunnel))); + let weak_ref = Arc::downgrade(&arc); + (arc, weak_ref) + } + } + + impl Tunnel for MockTunnel { + fn get_interface_name(&self) -> &str { + "mock-tunnel" + } + + fn stop(self: Box<Self>) -> Result<(), TunnelError> { + Ok(()) + } + + fn get_tunnel_stats(&self) -> Result<stats::Stats, TunnelError> { + (self.on_get_stats)() + } + } + + fn mock_monitor( + now: Instant, + pinger: Box<dyn Pinger>, + tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, + close_receiver: mpsc::Receiver<()>, + ) -> ConnectivityMonitor { + ConnectivityMonitor { + conn_state: ConnState::new(now, Default::default()), + initial_ping_timestamp: None, + num_pings_sent: 0, + pinger, + close_receiver, + tunnel_handle, + } + } + + fn connected_state(timestamp: Instant) -> ConnState { + ConnState::Connected { + rx_timestamp: timestamp, + tx_timestamp: timestamp, + stats: stats::Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + } + } + + + #[test] + /// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is + /// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`. + fn test_ping_times_out() { + let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); + let (_tx, rx) = mpsc::channel(); + let pinger = MockPinger::default(); + let now = Instant::now(); + let start = now - (BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10)); + let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + + // Mock the state - connectivity has been established + monitor.conn_state = connected_state(start); + // A ping was sent to verify connectivity + monitor.maybe_send_ping(start).unwrap(); + assert!(!monitor.check_connectivity(now).unwrap()) + } + + #[test] + /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is + /// flowing constantly. + fn test_no_connection_on_start() { + let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked(); + let (_tx, rx) = mpsc::channel(); + let pinger = MockPinger::default(); + let now = Instant::now(); + let start = now - Duration::from_secs(1); + let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + + assert!(!monitor.check_connectivity(now).unwrap()) + } + + #[test] + /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is + /// flowing constantly. + fn test_connection_works() { + let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); + let (_tx, rx) = mpsc::channel(); + let pinger = MockPinger::default(); + let now = Instant::now(); + let start = now - Duration::from_secs(1); + let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx); + + // Mock the state - connectivity has been established + monitor.conn_state = connected_state(start); + + assert!(monitor.check_connectivity(now).unwrap()) + } + + #[test] + /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, + /// and it shuts down properly. + fn test_wait_loop() { + let (result_tx, result_rx) = mpsc::channel(); + let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked(); + let pinger = MockPinger::default(); + let (stop_tx, stop_rx) = mpsc::channel(); + std::thread::spawn(move || { + let now = Instant::now(); + let start = now - Duration::from_secs(1); + let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); + + let start_result = monitor.establish_connectivity(); + result_tx.send(start_result).unwrap(); + + let result = monitor.run().map(|_| true); + result_tx.send(result).unwrap(); + }); + + std::thread::sleep(Duration::from_secs(1)); + assert_eq!(true, result_rx.try_recv().unwrap().unwrap()); + stop_tx.send(()).unwrap(); + std::thread::sleep(Duration::from_secs(1)); + assert!(result_rx.try_recv().unwrap().is_ok()); + } + + #[test] + /// Verify that the connectivity monitor detects the tunnel timing out after no longer than + /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. + fn test_wait_loop_timeout() { + let should_stop = Arc::new(AtomicBool::new(false)); + let should_stop_inner = should_stop.clone(); + + let tunnel_stats = Mutex::new(stats::Stats { + rx_bytes: 0, + tx_bytes: 0, + }); + + let pinger = MockPinger::default(); + let (_tunnel_anchor, tunnel) = MockTunnel::new(move || { + let mut tunnel_stats = tunnel_stats.lock().unwrap(); + if !should_stop_inner.load(Ordering::SeqCst) { + tunnel_stats.rx_bytes += 1; + } + tunnel_stats.tx_bytes += 1; + Ok(tunnel_stats.clone()) + }) + .into_locked(); + + let (result_tx, result_rx) = mpsc::channel(); + + let (_stop_tx, stop_rx) = mpsc::channel(); + std::thread::spawn(move || { + let now = Instant::now(); + let start = now - Duration::from_secs(1); + let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx); + let start_result = monitor.establish_connectivity(); + result_tx.send(start_result).unwrap(); + let end_result = monitor.run().map(|_| true); + result_tx.send(end_result).expect("Failed to send result"); + }); + assert!(result_rx + .recv_timeout(Duration::from_secs(1)) + .unwrap() + .unwrap()); + should_stop.store(true, Ordering::SeqCst); + assert!(result_rx + .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) + .unwrap() + .is_ok()); + } } |
