summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/tunnel/wireguard/connectivity_check.rs240
1 files changed, 238 insertions, 2 deletions
diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
index 73cdff92e4..111edd772a 100644
--- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs
+++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
@@ -342,8 +342,15 @@ impl ConnState {
#[cfg(test)]
mod test {
- use super::{ConnState, Stats, BYTES_RX_TIMEOUT, TRAFFIC_TIMEOUT};
- use std::time::{Duration, Instant};
+ use super::*;
+ use crate::tunnel::wireguard::{stats, TunnelError};
+ use std::{
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc, Mutex,
+ },
+ time::{Duration, Instant},
+ };
/// Test if a newly created ConnState won't have timed out or consider itself connected
#[test]
@@ -440,4 +447,233 @@ mod test {
assert!(conn_state.rx_timed_out());
assert!(!conn_state.traffic_timed_out());
}
+
+ #[derive(Default)]
+ struct MockPinger {
+ on_send_ping: Option<Box<dyn FnMut() + Send>>,
+ }
+
+ impl Pinger for MockPinger {
+ fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> {
+ if let Some(callback) = self.on_send_ping.as_mut() {
+ (callback)();
+ }
+ Ok(())
+ }
+ }
+
+ struct MockTunnel {
+ on_get_stats: Box<dyn Fn() -> Result<stats::Stats, TunnelError> + Send>,
+ }
+
+ impl MockTunnel {
+ fn new<F: Fn() -> Result<stats::Stats, TunnelError> + Send + 'static>(f: F) -> Self {
+ Self {
+ on_get_stats: Box::new(f),
+ }
+ }
+
+ fn always_incrementing() -> Self {
+ let traffic = Mutex::new(stats::Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ });
+ Self {
+ on_get_stats: Box::new(move || {
+ let mut traffic = traffic.lock().unwrap();
+ traffic.tx_bytes += 1;
+ traffic.rx_bytes += 1;
+
+ Ok(*traffic)
+ }),
+ }
+ }
+
+ fn never_incrementing() -> Self {
+ Self {
+ on_get_stats: Box::new(|| {
+ Ok(stats::Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ })
+ }),
+ }
+ }
+
+ fn into_locked(
+ self,
+ ) -> (
+ Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ Weak<Mutex<Option<Box<dyn Tunnel>>>>,
+ ) {
+ let dyn_tunnel: Box<dyn Tunnel> = Box::new(self);
+ let arc = Arc::new(Mutex::new(Some(dyn_tunnel)));
+ let weak_ref = Arc::downgrade(&arc);
+ (arc, weak_ref)
+ }
+ }
+
+ impl Tunnel for MockTunnel {
+ fn get_interface_name(&self) -> &str {
+ "mock-tunnel"
+ }
+
+ fn stop(self: Box<Self>) -> Result<(), TunnelError> {
+ Ok(())
+ }
+
+ fn get_tunnel_stats(&self) -> Result<stats::Stats, TunnelError> {
+ (self.on_get_stats)()
+ }
+ }
+
+ fn mock_monitor(
+ now: Instant,
+ pinger: Box<dyn Pinger>,
+ tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>,
+ close_receiver: mpsc::Receiver<()>,
+ ) -> ConnectivityMonitor {
+ ConnectivityMonitor {
+ conn_state: ConnState::new(now, Default::default()),
+ initial_ping_timestamp: None,
+ num_pings_sent: 0,
+ pinger,
+ close_receiver,
+ tunnel_handle,
+ }
+ }
+
+ fn connected_state(timestamp: Instant) -> ConnState {
+ ConnState::Connected {
+ rx_timestamp: timestamp,
+ tx_timestamp: timestamp,
+ stats: stats::Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ }
+ }
+
+
+ #[test]
+ /// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is
+ /// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`.
+ fn test_ping_times_out() {
+ let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
+ let (_tx, rx) = mpsc::channel();
+ let pinger = MockPinger::default();
+ let now = Instant::now();
+ let start = now - (BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10));
+ let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+
+ // Mock the state - connectivity has been established
+ monitor.conn_state = connected_state(start);
+ // A ping was sent to verify connectivity
+ monitor.maybe_send_ping(start).unwrap();
+ assert!(!monitor.check_connectivity(now).unwrap())
+ }
+
+ #[test]
+ /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
+ /// flowing constantly.
+ fn test_no_connection_on_start() {
+ let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
+ let (_tx, rx) = mpsc::channel();
+ let pinger = MockPinger::default();
+ let now = Instant::now();
+ let start = now - Duration::from_secs(1);
+ let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+
+ assert!(!monitor.check_connectivity(now).unwrap())
+ }
+
+ #[test]
+ /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
+ /// flowing constantly.
+ fn test_connection_works() {
+ let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
+ let (_tx, rx) = mpsc::channel();
+ let pinger = MockPinger::default();
+ let now = Instant::now();
+ let start = now - Duration::from_secs(1);
+ let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+
+ // Mock the state - connectivity has been established
+ monitor.conn_state = connected_state(start);
+
+ assert!(monitor.check_connectivity(now).unwrap())
+ }
+
+ #[test]
+ /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
+ /// and it shuts down properly.
+ fn test_wait_loop() {
+ let (result_tx, result_rx) = mpsc::channel();
+ let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
+ let pinger = MockPinger::default();
+ let (stop_tx, stop_rx) = mpsc::channel();
+ std::thread::spawn(move || {
+ let now = Instant::now();
+ let start = now - Duration::from_secs(1);
+ let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
+
+ let start_result = monitor.establish_connectivity();
+ result_tx.send(start_result).unwrap();
+
+ let result = monitor.run().map(|_| true);
+ result_tx.send(result).unwrap();
+ });
+
+ std::thread::sleep(Duration::from_secs(1));
+ assert_eq!(true, result_rx.try_recv().unwrap().unwrap());
+ stop_tx.send(()).unwrap();
+ std::thread::sleep(Duration::from_secs(1));
+ assert!(result_rx.try_recv().unwrap().is_ok());
+ }
+
+ #[test]
+ /// Verify that the connectivity monitor detects the tunnel timing out after no longer than
+ /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
+ fn test_wait_loop_timeout() {
+ let should_stop = Arc::new(AtomicBool::new(false));
+ let should_stop_inner = should_stop.clone();
+
+ let tunnel_stats = Mutex::new(stats::Stats {
+ rx_bytes: 0,
+ tx_bytes: 0,
+ });
+
+ let pinger = MockPinger::default();
+ let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {
+ let mut tunnel_stats = tunnel_stats.lock().unwrap();
+ if !should_stop_inner.load(Ordering::SeqCst) {
+ tunnel_stats.rx_bytes += 1;
+ }
+ tunnel_stats.tx_bytes += 1;
+ Ok(tunnel_stats.clone())
+ })
+ .into_locked();
+
+ let (result_tx, result_rx) = mpsc::channel();
+
+ let (_stop_tx, stop_rx) = mpsc::channel();
+ std::thread::spawn(move || {
+ let now = Instant::now();
+ let start = now - Duration::from_secs(1);
+ let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
+ let start_result = monitor.establish_connectivity();
+ result_tx.send(start_result).unwrap();
+ let end_result = monitor.run().map(|_| true);
+ result_tx.send(end_result).expect("Failed to send result");
+ });
+ assert!(result_rx
+ .recv_timeout(Duration::from_secs(1))
+ .unwrap()
+ .unwrap());
+ should_stop.store(true, Ordering::SeqCst);
+ assert!(result_rx
+ .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
+ .unwrap()
+ .is_ok());
+ }
}