summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-wireguard/src/connectivity/check.rs (renamed from talpid-wireguard/src/connectivity_check.rs)597
-rw-r--r--talpid-wireguard/src/connectivity/constants.rs22
-rw-r--r--talpid-wireguard/src/connectivity/error.rs14
-rw-r--r--talpid-wireguard/src/connectivity/mock.rs133
-rw-r--r--talpid-wireguard/src/connectivity/mod.rs13
-rw-r--r--talpid-wireguard/src/connectivity/monitor.rs174
-rw-r--r--talpid-wireguard/src/connectivity/pinger/android.rs (renamed from talpid-wireguard/src/ping_monitor/android.rs)0
-rw-r--r--talpid-wireguard/src/connectivity/pinger/icmp.rs (renamed from talpid-wireguard/src/ping_monitor/icmp.rs)0
-rw-r--r--talpid-wireguard/src/connectivity/pinger/mod.rs (renamed from talpid-wireguard/src/ping_monitor/mod.rs)0
-rw-r--r--talpid-wireguard/src/ephemeral.rs19
-rw-r--r--talpid-wireguard/src/lib.rs127
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs68
12 files changed, 678 insertions, 489 deletions
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity/check.rs
index 4b6e0f9810..527931563b 100644
--- a/talpid-wireguard/src/connectivity_check.rs
+++ b/talpid-wireguard/src/connectivity/check.rs
@@ -1,55 +1,17 @@
-use crate::{
- ping_monitor::{new_pinger, Pinger},
- stats::StatsMap,
- TunnelType,
-};
-use std::{
- cmp,
- net::Ipv4Addr,
- sync::{mpsc, Weak},
- time::{Duration, Instant},
-};
-use tokio::sync::Mutex;
+use std::cmp;
+use std::net::Ipv4Addr;
+use std::sync::mpsc;
+use std::time::{Duration, Instant};
-#[cfg(target_os = "android")]
-use super::Tunnel;
-use super::TunnelError;
-
-/// Sleep time used when initially establishing connectivity
-const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50);
-/// Sleep time used when checking if an established connection is still working.
-const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1);
-
-/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is
-/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic
-/// is received.
-const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5);
-/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will
-/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received.
-const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120);
-/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this
-/// timeout is reached, it is assumed that the connection is lost.
-const PING_TIMEOUT: Duration = Duration::from_secs(15);
-/// Timeout for receiving traffic when establishing a connection.
-const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4);
-/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt.
-const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
-/// Maximum timeout for establishing a connection.
-const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT;
-/// Number of seconds to wait between sending ICMP packets
-const SECONDS_PER_PING: Duration = Duration::from_secs(3);
-
-/// Connectivity monitor errors
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- /// Failed to read tunnel's configuration
- #[error("Failed to read tunnel's configuration")]
- ConfigReadError(TunnelError),
+use super::constants::*;
+use super::error::Error;
+use super::pinger;
- /// Failed to send ping
- #[error("Ping monitor failed")]
- PingError(#[from] crate::ping_monitor::Error),
-}
+use crate::stats::StatsMap;
+#[cfg(target_os = "android")]
+use crate::Tunnel;
+use crate::{TunnelError, TunnelType};
+use pinger::Pinger;
/// Verifies if a connection to a tunnel is working.
/// The connectivity monitor is biased to receiving traffic - it is expected that all outgoing
@@ -73,61 +35,126 @@ pub enum Error {
///
/// Once a connection established, a connection is only considered broken once the connectivity
/// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`.
-pub struct ConnectivityMonitor {
- /// Tunnel implementation
- tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
+pub struct Check<Strategy = Timeout> {
conn_state: ConnState,
- initial_ping_timestamp: Option<Instant>,
- num_pings_sent: u32,
- pinger: Box<dyn Pinger>,
+ ping_state: PingState,
+ strategy: Strategy,
+ retry_attempt: u32,
+}
+
+// Define the type state of [Check]
+pub(crate) trait Strategy {
+ fn should_shut_down(&mut self, timeout: Duration) -> bool;
+}
+
+/// An uncancellable [Check] that will run [Check::establish_connectivity] until
+/// completion or until it times out.
+pub struct Timeout;
+
+impl Strategy for Timeout {
+ /// The Timeout strategy cannot receive shut down signals so this function always returns false.
+ fn should_shut_down(&mut self, _timeout: Duration) -> bool {
+ false
+ }
+}
+
+/// A cancellable [Check] may be cancelled before it will time out by sending
+/// a signal on the channel returned by [Check::with_cancellation]. Otherwise,
+/// it behaves as [Timeout].
+pub struct Cancellable {
close_receiver: mpsc::Receiver<()>,
}
-impl ConnectivityMonitor {
- pub(super) fn new(
+impl Strategy for Cancellable {
+ /// Returns true if monitor should be shut down
+ fn should_shut_down(&mut self, timeout: Duration) -> bool {
+ match self.close_receiver.recv_timeout(timeout) {
+ Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
+ Err(mpsc::RecvTimeoutError::Timeout) => false,
+ }
+ }
+}
+
+impl Check<Timeout> {
+ pub fn new(
addr: Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
- tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
- close_receiver: mpsc::Receiver<()>,
- ) -> Result<Self, Error> {
- let pinger = new_pinger(
- addr,
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- interface,
- )
- .map_err(Error::PingError)?;
+ retry_attempt: u32,
+ ) -> Result<Check<Timeout>, Error> {
+ Ok(Check {
+ conn_state: ConnState::new(Instant::now(), Default::default()),
+ ping_state: PingState::new(
+ addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ interface,
+ )?,
+ strategy: Timeout,
+ retry_attempt,
+ })
+ }
- let now = Instant::now();
+ /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping
+ /// the returned channel.
+ pub fn with_cancellation(self) -> (Check<Cancellable>, mpsc::Sender<()>) {
+ let (cancellation_tx, cancellation_rx) = mpsc::channel();
+ let check = Check {
+ conn_state: self.conn_state,
+ ping_state: self.ping_state,
+ strategy: Cancellable {
+ close_receiver: cancellation_rx,
+ },
+ retry_attempt: self.retry_attempt,
+ };
+ (check, cancellation_tx)
+ }
- Ok(Self {
- tunnel_handle,
- conn_state: ConnState::new(now, Default::default()),
- initial_ping_timestamp: None,
- num_pings_sent: 0,
- pinger,
- close_receiver,
- })
+ #[cfg(test)]
+ /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy,
+ /// see [Check::with_cancellation].
+ pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self {
+ Check {
+ conn_state,
+ ping_state,
+ strategy: Timeout,
+ retry_attempt: 0,
+ }
}
+}
+impl<S: Strategy> Check<S> {
// checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is
// successful at the start of a connection.
- pub(super) fn establish_connectivity(&mut self, retry_attempt: u32) -> Result<bool, Error> {
+ pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result<bool, Error> {
// Send initial ping to prod WireGuard into connecting.
- self.pinger.send_icmp().map_err(Error::PingError)?;
+ self.ping_state
+ .pinger
+ .send_icmp()
+ .map_err(Error::PingError)?;
self.establish_connectivity_inner(
- retry_attempt,
+ self.retry_attempt,
ESTABLISH_TIMEOUT,
ESTABLISH_TIMEOUT_MULTIPLIER,
MAX_ESTABLISH_TIMEOUT,
+ tunnel_handle,
)
}
+ pub(crate) fn reset(&mut self, current_iteration: Instant) {
+ self.ping_state.reset();
+ self.conn_state.reset_after_suspension(current_iteration);
+ }
+
+ pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool {
+ self.strategy.should_shut_down(timeout)
+ }
+
fn establish_connectivity_inner(
&mut self,
retry_attempt: u32,
timeout_initial: Duration,
timeout_multiplier: u32,
max_timeout: Duration,
+ tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
if self.conn_state.connected() {
return Ok(true);
@@ -140,7 +167,7 @@ impl ConnectivityMonitor {
let start = Instant::now();
while start.elapsed() < check_timeout {
- if self.check_connectivity_interval(Instant::now(), check_timeout)? {
+ if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? {
return Ok(true);
}
if self.should_shut_down(DELAY_ON_INITIAL_SETUP) {
@@ -150,46 +177,13 @@ impl ConnectivityMonitor {
Ok(false)
}
- pub(super) fn run(&mut self) -> Result<(), Error> {
- self.wait_loop(REGULAR_LOOP_SLEEP)
- }
-
- /// Returns true if monitor should be shut down
- fn should_shut_down(&mut self, timeout: Duration) -> bool {
- match self.close_receiver.recv_timeout(timeout) {
- Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
- Err(mpsc::RecvTimeoutError::Timeout) => false,
- }
- }
-
- fn wait_loop(&mut self, iter_delay: Duration) -> Result<(), Error> {
- let mut last_iteration = Instant::now();
- while !self.should_shut_down(iter_delay) {
- let mut current_iteration = Instant::now();
- let time_slept = current_iteration - last_iteration;
- if time_slept < (iter_delay * 2) {
- if !self.check_connectivity(Instant::now())? {
- return Ok(());
- }
-
- let end = Instant::now();
- if end - current_iteration > Duration::from_secs(1) {
- current_iteration = end;
- }
- } else {
- // Loop was suspended for too long, so it's safer to assume that the host still has
- // connectivity.
- self.reset_pinger();
- self.conn_state.reset_after_suspension(current_iteration);
- }
- last_iteration = current_iteration;
- }
- Ok(())
- }
-
/// Returns true if connection is established
- fn check_connectivity(&mut self, now: Instant) -> Result<bool, Error> {
- self.check_connectivity_interval(now, PING_TIMEOUT)
+ pub(crate) fn check_connectivity(
+ &mut self,
+ now: Instant,
+ tunnel_handle: &TunnelType,
+ ) -> Result<bool, Error> {
+ self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle)
}
/// Returns true if connection is established
@@ -197,19 +191,18 @@ impl ConnectivityMonitor {
&mut self,
now: Instant,
timeout: Duration,
+ tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
- match self.get_stats() {
+ match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? {
None => Ok(false),
Some(new_stats) => {
- let new_stats = new_stats?;
-
if self.conn_state.update(now, new_stats) {
- self.reset_pinger();
+ self.ping_state.reset();
return Ok(true);
}
self.maybe_send_ping(now)?;
- Ok(!self.ping_timed_out(timeout) && self.conn_state.connected())
+ Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected())
}
}
}
@@ -218,19 +211,14 @@ impl ConnectivityMonitor {
/// calls will also return None.
///
/// NOTE: will panic if called from within a tokio runtime.
- fn get_stats(&self) -> Option<Result<StatsMap, Error>> {
- self.tunnel_handle
- .upgrade()?
- .blocking_lock()
- .as_ref()
- .and_then(|tunnel| match tunnel.get_tunnel_stats() {
- Ok(stats) if stats.is_empty() => {
- log::error!("Tunnel unexpectedly shut down");
- None
- }
- Ok(stats) => Some(Ok(stats)),
- Err(error) => Some(Err(Error::ConfigReadError(error))),
- })
+ fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> {
+ let stats = tunnel_handle.get_tunnel_stats()?;
+ if stats.is_empty() {
+ log::error!("Tunnel unexpectedly shut down");
+ Ok(None)
+ } else {
+ Ok(Some(stats))
+ }
}
fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> {
@@ -239,20 +227,55 @@ impl ConnectivityMonitor {
// 3 seconds.
if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out())
&& self
+ .ping_state
.initial_ping_timestamp
.map(|initial_ping_timestamp| {
- initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING
+ initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent
+ < SECONDS_PER_PING
})
.unwrap_or(true)
{
- self.pinger.send_icmp().map_err(Error::PingError)?;
- if self.initial_ping_timestamp.is_none() {
- self.initial_ping_timestamp = Some(now);
+ self.ping_state
+ .pinger
+ .send_icmp()
+ .map_err(Error::PingError)?;
+ if self.ping_state.initial_ping_timestamp.is_none() {
+ self.ping_state.initial_ping_timestamp = Some(now);
}
- self.num_pings_sent += 1;
+ self.ping_state.num_pings_sent += 1;
}
Ok(())
}
+}
+
+pub(super) struct PingState {
+ initial_ping_timestamp: Option<Instant>,
+ num_pings_sent: u32,
+ pinger: Box<dyn Pinger>,
+}
+
+impl PingState {
+ pub(super) fn new(
+ addr: Ipv4Addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
+ ) -> Result<Self, Error> {
+ let pinger = pinger::new_pinger(
+ addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ interface,
+ )
+ .map_err(Error::PingError)?;
+
+ Ok(Self::new_with(pinger))
+ }
+
+ pub(super) fn new_with(pinger: Box<dyn Pinger>) -> Self {
+ Self {
+ initial_ping_timestamp: None,
+ num_pings_sent: 0,
+ pinger,
+ }
+ }
fn ping_timed_out(&self, timeout: Duration) -> bool {
self.initial_ping_timestamp
@@ -261,14 +284,14 @@ impl ConnectivityMonitor {
}
/// Reset timeouts - assume that the last time bytes were received is now.
- fn reset_pinger(&mut self) {
+ fn reset(&mut self) {
self.initial_ping_timestamp = None;
self.num_pings_sent = 0;
self.pinger.reset();
}
}
-enum ConnState {
+pub(super) enum ConnState {
Connecting {
start: Instant,
stats: StatsMap,
@@ -401,21 +424,8 @@ impl ConnState {
#[cfg(test)]
mod test {
- use futures::Future;
-
use super::*;
- use crate::{
- config::Config,
- stats::{self, Stats},
- Tunnel,
- };
- use std::{
- pin::Pin,
- sync::{
- atomic::{AtomicBool, Ordering},
- Arc,
- },
- };
+ use crate::connectivity::mock::*;
/// Test if a newly created ConnState won't have timed out or consider itself connected
#[test]
@@ -521,300 +531,76 @@ mod test {
assert!(!conn_state.traffic_timed_out());
}
- #[derive(Default)]
- struct MockPinger {
- on_send_ping: Option<Box<dyn FnMut() + Send>>,
- }
-
- impl Pinger for MockPinger {
- fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> {
- if let Some(callback) = self.on_send_ping.as_mut() {
- (callback)();
- }
- Ok(())
- }
- }
-
- struct MockTunnel {
- on_get_stats: Box<dyn Fn() -> Result<stats::StatsMap, TunnelError> + Send>,
- }
-
- impl MockTunnel {
- const PEER: [u8; 32] = [0u8; 32];
-
- fn new<F: Fn() -> Result<stats::StatsMap, TunnelError> + Send + 'static>(f: F) -> Self {
- Self {
- on_get_stats: Box::new(f),
- }
- }
-
- fn always_incrementing() -> Self {
- let mut map = stats::StatsMap::new();
- map.insert(
- Self::PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- let peers = std::sync::Mutex::new(map);
- Self {
- on_get_stats: Box::new(move || {
- let mut peers = peers.lock().unwrap();
- for traffic in peers.values_mut() {
- traffic.tx_bytes += 1;
- traffic.rx_bytes += 1;
- }
- Ok(peers.clone())
- }),
- }
- }
-
- fn never_incrementing() -> Self {
- Self {
- on_get_stats: Box::new(|| {
- let mut map = stats::StatsMap::new();
- map.insert(
- Self::PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- Ok(map)
- }),
- }
- }
-
- #[allow(clippy::type_complexity)]
- fn into_locked(
- self,
- ) -> (
- Arc<Mutex<Option<Box<dyn Tunnel>>>>,
- Weak<Mutex<Option<Box<dyn Tunnel>>>>,
- ) {
- let dyn_tunnel: Box<dyn Tunnel> = Box::new(self);
- let arc = Arc::new(Mutex::new(Some(dyn_tunnel)));
- let weak_ref = Arc::downgrade(&arc);
- (arc, weak_ref)
- }
- }
-
- impl Tunnel for MockTunnel {
- fn get_interface_name(&self) -> String {
- "mock-tunnel".to_string()
- }
-
- fn stop(self: Box<Self>) -> Result<(), TunnelError> {
- Ok(())
- }
-
- fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> {
- (self.on_get_stats)()
- }
-
- fn set_config(
- &mut self,
- _config: Config,
- ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
- Box::pin(async { Ok(()) })
- }
-
- #[cfg(daita)]
- fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
- Ok(())
- }
- }
-
- fn mock_monitor(
- now: Instant,
- pinger: Box<dyn Pinger>,
- tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>,
- close_receiver: mpsc::Receiver<()>,
- ) -> ConnectivityMonitor {
- ConnectivityMonitor {
- conn_state: ConnState::new(now, Default::default()),
- initial_ping_timestamp: None,
- num_pings_sent: 0,
- pinger,
- close_receiver,
- tunnel_handle,
- }
- }
-
- fn connected_state(timestamp: Instant) -> ConnState {
- const PEER: [u8; 32] = [0u8; 32];
- let mut stats = stats::StatsMap::new();
- stats.insert(
- PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- ConnState::Connected {
- rx_timestamp: timestamp,
- tx_timestamp: timestamp,
- stats,
- }
- }
-
#[test]
/// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is
/// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`.
fn test_ping_times_out() {
- let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now
.checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10))
.unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut checker = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
- monitor.conn_state = connected_state(start);
+ checker.conn_state = connected_state(start);
// A ping was sent to verify connectivity
- monitor.maybe_send_ping(start).unwrap();
- assert!(!monitor.check_connectivity(now).unwrap())
+ checker.maybe_send_ping(start).unwrap();
+ assert!(!checker.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
fn test_no_connection_on_start() {
- let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
- assert!(!monitor.check_connectivity(now).unwrap())
+ assert!(!monitor.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
fn test_connection_works() {
- let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::always_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
monitor.conn_state = connected_state(start);
- assert!(monitor.check_connectivity(now).unwrap())
- }
-
- #[test]
- /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
- /// and it shuts down properly.
- fn test_wait_loop() {
- let (result_tx, result_rx) = mpsc::channel();
- let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
- let pinger = MockPinger::default();
- let (stop_tx, stop_rx) = mpsc::channel();
- std::thread::spawn(move || {
- let now = Instant::now();
- let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
-
- let start_result = monitor.establish_connectivity(0);
- result_tx.send(start_result).unwrap();
-
- let result = monitor.run().map(|_| true);
- result_tx.send(result).unwrap();
- });
-
- std::thread::sleep(Duration::from_secs(1));
- assert!(result_rx.try_recv().unwrap().unwrap());
- stop_tx.send(()).unwrap();
- std::thread::sleep(Duration::from_secs(1));
- assert!(result_rx.try_recv().unwrap().is_ok());
- }
-
- #[test]
- /// Verify that the connectivity monitor detects the tunnel timing out after no longer than
- /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
- fn test_wait_loop_timeout() {
- let should_stop = Arc::new(AtomicBool::new(false));
- let should_stop_inner = should_stop.clone();
-
- let mut map = stats::StatsMap::new();
- map.insert(
- [0u8; 32],
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- let tunnel_stats = std::sync::Mutex::new(map);
-
- let pinger = MockPinger::default();
- let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {
- let mut tunnel_stats = tunnel_stats.lock().unwrap();
- if !should_stop_inner.load(Ordering::SeqCst) {
- for traffic in tunnel_stats.values_mut() {
- traffic.rx_bytes += 1;
- }
- }
- for traffic in tunnel_stats.values_mut() {
- traffic.tx_bytes += 1;
- }
- Ok(tunnel_stats.clone())
- })
- .into_locked();
-
- let (result_tx, result_rx) = mpsc::channel();
-
- let (_stop_tx, stop_rx) = mpsc::channel();
- std::thread::spawn(move || {
- let now = Instant::now();
- let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
- let start_result = monitor.establish_connectivity(0);
- result_tx.send(start_result).unwrap();
- let end_result = monitor.run().map(|_| true);
- result_tx.send(end_result).expect("Failed to send result");
- });
- assert!(result_rx
- .recv_timeout(Duration::from_secs(1))
- .unwrap()
- .unwrap());
- should_stop.store(true, Ordering::SeqCst);
- assert!(result_rx
- .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
- .unwrap()
- .is_ok());
+ assert!(monitor.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that the timeout for setting up a tunnel works as expected.
fn test_establish_timeout() {
- let mut tunnel_stats = stats::StatsMap::new();
- tunnel_stats.insert(
- [0u8; 32],
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
-
let pinger = MockPinger::default();
- let (_tunnel_anchor, tunnel) =
- MockTunnel::new(move || Ok(tunnel_stats.clone())).into_locked();
+ let tunnel = {
+ let mut tunnel_stats = StatsMap::new();
+ tunnel_stats.insert(
+ [0u8; 32],
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed()
+ };
let (result_tx, result_rx) = mpsc::channel();
- let (_stop_tx, stop_rx) = mpsc::channel();
std::thread::spawn(move || {
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500);
@@ -827,6 +613,7 @@ mod test {
ESTABLISH_TIMEOUT,
ESTABLISH_TIMEOUT_MULTIPLIER,
MAX_ESTABLISH_TIMEOUT,
+ &tunnel,
))
.unwrap();
}
diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs
new file mode 100644
index 0000000000..a8d6752ddd
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/constants.rs
@@ -0,0 +1,22 @@
+use std::time::Duration;
+
+/// Sleep time used when initially establishing connectivity
+pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50);
+/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is
+/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic
+/// is received.
+pub(crate) const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5);
+/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will
+/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received.
+pub(crate) const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120);
+/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this
+/// timeout is reached, it is assumed that the connection is lost.
+pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(15);
+/// Timeout for receiving traffic when establishing a connection.
+pub(crate) const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4);
+/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt.
+pub(crate) const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
+/// Maximum timeout for establishing a connection.
+pub(crate) const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT;
+/// Number of seconds to wait between sending ICMP packets
+pub(crate) const SECONDS_PER_PING: Duration = Duration::from_secs(3);
diff --git a/talpid-wireguard/src/connectivity/error.rs b/talpid-wireguard/src/connectivity/error.rs
new file mode 100644
index 0000000000..9e8c98a751
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/error.rs
@@ -0,0 +1,14 @@
+use super::pinger;
+use crate::TunnelError;
+
+/// Connectivity monitor errors
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+ /// Failed to read tunnel's configuration
+ #[error("Failed to read tunnel's configuration")]
+ ConfigReadError(TunnelError),
+
+ /// Failed to send ping
+ #[error("Ping failed")]
+ PingError(#[from] pinger::Error),
+}
diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs
new file mode 100644
index 0000000000..892f3966ea
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/mock.rs
@@ -0,0 +1,133 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::time::Instant;
+
+use super::check::{ConnState, PingState, Timeout};
+use super::pinger;
+use super::Check;
+
+use crate::{Config, Tunnel, TunnelError};
+use pinger::Pinger;
+
+// Convenient re-exports
+pub use crate::stats::{Stats, StatsMap};
+
+#[derive(Default)]
+pub(crate) struct MockPinger {
+ on_send_ping: Option<Box<dyn FnMut() + Send>>,
+}
+
+pub(crate) struct MockTunnel {
+ on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send>,
+}
+
+pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> Check<Timeout> {
+ let conn_state = ConnState::new(now, Default::default());
+ let ping_state = PingState::new_with(pinger);
+ Check::mock(conn_state, ping_state)
+}
+
+pub fn connected_state(timestamp: Instant) -> ConnState {
+ const PEER: [u8; 32] = [0u8; 32];
+ let mut stats = StatsMap::new();
+ stats.insert(
+ PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ ConnState::Connected {
+ rx_timestamp: timestamp,
+ tx_timestamp: timestamp,
+ stats,
+ }
+}
+
+impl MockTunnel {
+ const PEER: [u8; 32] = [0u8; 32];
+
+ pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + 'static>(f: F) -> Self {
+ Self {
+ on_get_stats: Box::new(f),
+ }
+ }
+
+ /// Convert self to the more general [TunnelType].
+ pub fn boxed(self) -> Box<dyn Tunnel> {
+ Box::new(self)
+ }
+
+ pub fn always_incrementing() -> Self {
+ let mut map = StatsMap::new();
+ map.insert(
+ Self::PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ let peers = std::sync::Mutex::new(map);
+ Self {
+ on_get_stats: Box::new(move || {
+ let mut peers = peers.lock().unwrap();
+ for traffic in peers.values_mut() {
+ traffic.tx_bytes += 1;
+ traffic.rx_bytes += 1;
+ }
+ Ok(peers.clone())
+ }),
+ }
+ }
+
+ pub fn never_incrementing() -> Self {
+ Self {
+ on_get_stats: Box::new(|| {
+ let mut map = StatsMap::new();
+ map.insert(
+ Self::PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ Ok(map)
+ }),
+ }
+ }
+}
+
+impl Tunnel for MockTunnel {
+ fn get_interface_name(&self) -> String {
+ "mock-tunnel".to_string()
+ }
+
+ fn stop(self: Box<Self>) -> Result<(), TunnelError> {
+ Ok(())
+ }
+
+ fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> {
+ (self.on_get_stats)()
+ }
+
+ fn set_config(
+ &mut self,
+ _config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
+ Box::pin(async { Ok(()) })
+ }
+
+ #[cfg(daita)]
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
+ Ok(())
+ }
+}
+
+impl Pinger for MockPinger {
+ fn send_icmp(&mut self) -> Result<(), pinger::Error> {
+ if let Some(callback) = self.on_send_ping.as_mut() {
+ (callback)();
+ }
+ Ok(())
+ }
+}
diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs
new file mode 100644
index 0000000000..512d8715f1
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/mod.rs
@@ -0,0 +1,13 @@
+mod check;
+mod constants;
+mod error;
+#[cfg(test)]
+mod mock;
+mod monitor;
+mod pinger;
+
+#[cfg(target_os = "android")]
+pub use check::Cancellable;
+pub use check::Check;
+pub use error::Error;
+pub use monitor::Monitor;
diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs
new file mode 100644
index 0000000000..583b8d9589
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/monitor.rs
@@ -0,0 +1,174 @@
+use std::{
+ sync::Weak,
+ time::{Duration, Instant},
+};
+
+use tokio::sync::Mutex;
+
+use crate::TunnelType;
+
+use super::check::{Cancellable, Check};
+use super::error::Error;
+
+/// Sleep time used when checking if an established connection is still working.
+const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1);
+
+pub struct Monitor {
+ connectivity_check: Check<Cancellable>,
+}
+
+impl Monitor {
+ pub fn init(connectivity_check: Check<Cancellable>) -> Self {
+ Self { connectivity_check }
+ }
+
+ pub fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> {
+ self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle)
+ }
+
+ fn wait_loop(
+ mut self,
+ iter_delay: Duration,
+ tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
+ ) -> Result<(), Error> {
+ let mut last_iteration = Instant::now();
+ while !self.connectivity_check.should_shut_down(iter_delay) {
+ let mut current_iteration = Instant::now();
+ let time_slept = current_iteration - last_iteration;
+ if time_slept < (iter_delay * 2) {
+ let Some(tunnel) = tunnel_handle.upgrade() else {
+ return Ok(());
+ };
+ let lock = tunnel.blocking_lock();
+ let Some(tunnel) = lock.as_ref() else {
+ return Ok(());
+ };
+
+ if !self
+ .connectivity_check
+ .check_connectivity(Instant::now(), tunnel)?
+ {
+ return Ok(());
+ }
+ drop(lock);
+
+ let end = Instant::now();
+ if end - current_iteration > Duration::from_secs(1) {
+ current_iteration = end;
+ }
+ } else {
+ // Loop was suspended for too long, so it's safer to assume that the host still has
+ // connectivity.
+ self.connectivity_check.reset(current_iteration);
+ }
+ last_iteration = current_iteration;
+ }
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ // TODO: Port to async + tokio to reduce cost of testing?
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::mpsc;
+ use std::sync::Arc;
+ use std::time::Duration;
+ use std::time::Instant;
+
+ use tokio::sync::Mutex;
+
+ use crate::connectivity::constants::*;
+ use crate::connectivity::mock::*;
+
+ #[test]
+ /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
+ /// and it shuts down properly.
+ fn test_wait_loop() {
+ use std::sync::mpsc;
+ let (result_tx, result_rx) = mpsc::channel();
+ let tunnel = MockTunnel::always_incrementing().boxed();
+ let pinger = MockPinger::default();
+ let (mut checker, stop_tx) = {
+ let now = Instant::now();
+ let start = now.checked_sub(Duration::from_secs(1)).unwrap();
+ mock_checker(start, Box::new(pinger)).with_cancellation()
+ };
+ std::thread::spawn(move || {
+ let start_result = checker.establish_connectivity(&tunnel);
+ result_tx.send(start_result).unwrap();
+ // Pointer dance
+ let tunnel = Arc::new(Mutex::new(Some(tunnel)));
+ let _tunnel = Arc::downgrade(&tunnel);
+ let result = Monitor::init(checker).run(_tunnel).map(|_| true);
+ result_tx.send(result).unwrap();
+ });
+
+ std::thread::sleep(Duration::from_secs(1));
+ assert!(result_rx.try_recv().unwrap().unwrap());
+ stop_tx.send(()).unwrap();
+ std::thread::sleep(Duration::from_secs(1));
+ assert!(result_rx.try_recv().unwrap().is_ok());
+ }
+
+ #[test]
+ /// Verify that the connectivity monitor detects the tunnel timing out after no longer than
+ /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
+ fn test_wait_loop_timeout() {
+ let should_stop = Arc::new(AtomicBool::new(false));
+ let should_stop_inner = should_stop.clone();
+
+ let mut map = StatsMap::new();
+ map.insert(
+ [0u8; 32],
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ let tunnel_stats = std::sync::Mutex::new(map);
+
+ let pinger = MockPinger::default();
+ let tunnel = MockTunnel::new(move || {
+ let mut tunnel_stats = tunnel_stats.lock().unwrap();
+ if !should_stop_inner.load(Ordering::SeqCst) {
+ for traffic in tunnel_stats.values_mut() {
+ traffic.rx_bytes += 1;
+ }
+ }
+ for traffic in tunnel_stats.values_mut() {
+ traffic.tx_bytes += 1;
+ }
+ Ok(tunnel_stats.clone())
+ })
+ .boxed();
+
+ let (result_tx, result_rx) = mpsc::channel();
+
+ std::thread::spawn(move || {
+ let (mut checker, _cancellation_token) = {
+ let now = Instant::now();
+ let start = now.checked_sub(Duration::from_secs(1)).unwrap();
+ mock_checker(start, Box::new(pinger)).with_cancellation()
+ };
+ let start_result = checker.establish_connectivity(&tunnel);
+ result_tx.send(start_result).unwrap();
+ // Pointer dance
+ let _tunnel = Arc::new(Mutex::new(Some(tunnel)));
+ let tunnel = Arc::downgrade(&_tunnel);
+ let end_result = Monitor::init(checker).run(tunnel).map(|_| true);
+ result_tx.send(end_result).expect("Failed to send result");
+ });
+ assert!(result_rx
+ .recv_timeout(Duration::from_secs(1))
+ .unwrap()
+ .unwrap());
+ should_stop.store(true, Ordering::SeqCst);
+ assert!(result_rx
+ .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
+ .unwrap()
+ .is_ok());
+ }
+}
diff --git a/talpid-wireguard/src/ping_monitor/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs
index 00ad4d8fd3..00ad4d8fd3 100644
--- a/talpid-wireguard/src/ping_monitor/android.rs
+++ b/talpid-wireguard/src/connectivity/pinger/android.rs
diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs
index 0e5d739425..0e5d739425 100644
--- a/talpid-wireguard/src/ping_monitor/icmp.rs
+++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs
diff --git a/talpid-wireguard/src/ping_monitor/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs
index ef2394f1b7..ef2394f1b7 100644
--- a/talpid-wireguard/src/ping_monitor/mod.rs
+++ b/talpid-wireguard/src/connectivity/pinger/mod.rs
diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs
index 127836863a..a9283fcb2e 100644
--- a/talpid-wireguard/src/ephemeral.rs
+++ b/talpid-wireguard/src/ephemeral.rs
@@ -130,6 +130,7 @@ async fn config_ephemeral_peers_inner(
&tun_provider,
)
.await?;
+
let entry_psk = request_ephemeral_peer(
retry_attempt,
&entry_config,
@@ -199,17 +200,17 @@ async fn reconfigure_tunnel(
.await
.map_err(CloseMsg::ObfuscatorFailed)?;
}
+ {
+ let mut shared_tunnel = tunnel.lock().await;
+ let tunnel = shared_tunnel.take().expect("tunnel was None");
- let mut lock = tunnel.lock().await;
-
- let tunnel = lock.take().expect("tunnel was None");
-
- let new_tunnel = tunnel
- .better_set_config(&config)
- .map_err(Error::TunnelError)
- .map_err(CloseMsg::SetupError)?;
+ let updated_tunnel = tunnel
+ .set_config(&config)
+ .map_err(Error::TunnelError)
+ .map_err(CloseMsg::SetupError)?;
- *lock = Some(new_tunnel);
+ *shared_tunnel = Some(updated_tunnel);
+ }
Ok(config)
}
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 0f05314e72..7c93b39f18 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -35,11 +35,10 @@ use tokio::sync::Mutex as AsyncMutex;
/// WireGuard config data-types
pub mod config;
-mod connectivity_check;
+mod connectivity;
mod ephemeral;
mod logging;
mod obfuscation;
-mod ping_monitor;
mod stats;
#[cfg(wireguard_go)]
mod wireguard_go;
@@ -88,7 +87,7 @@ pub enum Error {
/// Failed to set up connectivity monitor
#[error("Connectivity monitor failed")]
- ConnectivityMonitorError(#[source] connectivity_check::Error),
+ ConnectivityMonitorError(#[source] connectivity::Error),
/// Failed while negotiating ephemeral peer
#[error("Failed while negotiating ephemeral peer")]
@@ -216,8 +215,17 @@ impl WireguardMonitor {
let obfuscator = Arc::new(AsyncMutex::new(obfuscator));
+ let gateway = config.ipv4_gateway;
+ let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new(
+ gateway,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ iface_name.clone(),
+ args.retry_attempt,
+ )
+ .map_err(Error::ConnectivityMonitorError)?
+ .with_cancellation();
+
let event_callback = Box::new(on_event.clone());
- let (pinger_tx, pinger_rx) = sync_mpsc::channel();
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
@@ -227,16 +235,6 @@ impl WireguardMonitor {
obfuscator,
};
- let gateway = config.ipv4_gateway;
- let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
- gateway,
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- iface_name.clone(),
- Arc::downgrade(&monitor.tunnel),
- pinger_rx,
- )
- .map_err(Error::ConnectivityMonitorError)?;
-
let moved_tunnel = monitor.tunnel.clone();
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
@@ -321,8 +319,12 @@ impl WireguardMonitor {
});
}
- let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
- match connectivity_monitor.establish_connectivity(args.retry_attempt) {
+ let cloned_tunnel = Arc::clone(&tunnel);
+
+ let connectivity_check = tokio::task::spawn_blocking(move || {
+ let lock = cloned_tunnel.blocking_lock();
+ let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly");
+ match connectivity_monitor.establish_connectivity(tunnel) {
Ok(true) => Ok(connectivity_monitor),
Ok(false) => {
log::warn!("Timeout while checking tunnel connection");
@@ -350,8 +352,11 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
(on_event)(TunnelEvent::Up(metadata)).await;
+ let monitored_tunnel = Arc::downgrade(&tunnel);
tokio::task::spawn_blocking(move || {
- if let Err(error) = connectivity_monitor.run() {
+ if let Err(error) =
+ connectivity::Monitor::init(connectivity_check).run(monitored_tunnel)
+ {
log::error!(
"{}",
error.display_chain_with_msg("Connectivity monitor failed")
@@ -423,6 +428,12 @@ impl WireguardMonitor {
}
let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita;
+
+ let (connectivity_check, pinger_tx) =
+ connectivity::Check::new(config.ipv4_gateway, args.retry_attempt)
+ .map_err(Error::ConnectivityMonitorError)?
+ .with_cancellation();
+
let tunnel = Self::open_wireguard_go_tunnel(
&config,
log_path,
@@ -432,11 +443,11 @@ impl WireguardMonitor {
// that we only allows traffic to/from the gateway. This is only needed on Android
// since we lack a firewall there.
should_negotiate_ephemeral_peer,
+ connectivity_check,
)?;
+
let iface_name = tunnel.get_interface_name();
let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
-
- let (pinger_tx, pinger_rx) = sync_mpsc::channel();
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::clone(&tunnel),
@@ -446,60 +457,18 @@ impl WireguardMonitor {
obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
};
- let gateway = config.ipv4_gateway;
- let connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
- gateway,
- Arc::downgrade(&tunnel),
- pinger_rx,
- )
- .map_err(Error::ConnectivityMonitorError)?;
-
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
let tunnel_fut = async move {
let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender;
let obfuscator = moved_obfuscator;
- let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor));
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
.await;
- let handle_ping = |ping_result: std::result::Result<
- bool,
- connectivity_check::Error,
- >| match ping_result {
- Ok(true) => Ok(()),
- Ok(false) => {
- log::warn!("Timeout while checking tunnel connection");
- Err(CloseMsg::PingErr)
- }
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to check tunnel connection")
- );
- Err(CloseMsg::PingErr)
- }
- };
-
- // Prepare a closure which pings inside the tunnel when executed.
- let ping = || {
- let connectivity_monitor_arc = connectivity_monitor.clone();
- let retry_attempt = args.retry_attempt;
- move || {
- let ping_result = connectivity_monitor_arc
- .lock()
- .unwrap()
- .establish_connectivity(retry_attempt);
- handle_ping(ping_result)
- }
- };
-
if should_negotiate_ephemeral_peer {
- // Ping before negotiating the ephemeral peer to make sure that the tunnel works.
- tokio::task::spawn_blocking(ping()).await.unwrap()?;
let ephemeral_obfs_sender = close_obfs_sender.clone();
ephemeral::config_ephemeral_peers(
@@ -513,21 +482,31 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event.clone())(TunnelEvent::InterfaceUp(
+ args.on_event.clone()(TunnelEvent::InterfaceUp(
metadata,
Self::allowed_traffic_after_tunnel_config(),
))
.await;
}
- // Make sure the tunnel works (after potentially having negotiated an ephemeral peer).
- tokio::task::spawn_blocking(ping()).await.unwrap()?;
-
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event.clone())(TunnelEvent::Up(metadata)).await;
+ args.on_event.clone()(TunnelEvent::Up(metadata)).await;
+
+ // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
+ let connectivity_check = {
+ let mut tunnel_lock = tunnel.lock().await;
+ let Some(tunnel) = tunnel_lock.as_mut() else {
+ log::debug!("Tunnel is no longer running");
+ return Err::<Infallible, CloseMsg>(CloseMsg::PingErr);
+ };
+ tunnel
+ .take_checker()
+ .expect("connectivity checker unexpectedly dropped")
+ };
tokio::task::spawn_blocking(move || {
- if let Err(error) = connectivity_monitor.lock().unwrap().run() {
+ let tunnel = Arc::downgrade(&tunnel);
+ if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) {
log::error!(
"{}",
error.display_chain_with_msg("Connectivity monitor failed")
@@ -748,6 +727,9 @@ impl WireguardMonitor {
#[cfg(daita)] resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] gateway_only: bool,
+ #[cfg(target_os = "android")] connectivity_check: connectivity::Check<
+ connectivity::Cancellable,
+ >,
) -> Result<WgGoTunnel> {
let routes = config
.get_tunnel_destinations()
@@ -786,6 +768,7 @@ impl WireguardMonitor {
routes,
#[cfg(daita)]
resource_dir,
+ connectivity_check,
)
.map_err(Error::TunnelError)?
} else {
@@ -797,6 +780,7 @@ impl WireguardMonitor {
routes,
#[cfg(daita)]
resource_dir,
+ connectivity_check,
)
.map_err(Error::TunnelError)?
};
@@ -1103,6 +1087,15 @@ pub enum TunnelError {
#[cfg(daita)]
#[error("Failed to start DAITA - tunnel implemenation does not support DAITA")]
DaitaNotSupported,
+
+ /// [connectivity] error.
+ #[error(transparent)]
+ Connectivity(#[from] Box<connectivity::Error>),
+
+ /// Tunnel seemingly does not serve any traffic
+ #[cfg(target_os = "android")]
+ #[error("Tunnel seemingly does not serve any traffic")]
+ TunnelUp,
}
#[cfg(target_os = "linux")]
diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index 6734475c4b..d283758f3b 100644
--- a/talpid-wireguard/src/wireguard_go/mod.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -6,6 +6,8 @@ use super::{
};
#[cfg(target_os = "linux")]
use crate::config::MULLVAD_INTERFACE_NAME;
+#[cfg(target_os = "android")]
+use crate::connectivity;
use crate::logging::{clean_up_logging, initialize_logging};
use ipnetwork::IpNetwork;
#[cfg(daita)]
@@ -75,7 +77,7 @@ impl WgGoTunnel {
&self.0
}
- fn to_state_mut(&mut self) -> &mut WgGoTunnelState {
+ fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
&mut self.0
}
}
@@ -96,14 +98,17 @@ impl WgGoTunnel {
}
}
- fn to_state_mut(&mut self) -> &mut WgGoTunnelState {
+ fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
match self {
WgGoTunnel::Multihop(state) => state,
WgGoTunnel::Singlehop(state) => state,
}
}
- pub fn better_set_config(self, config: &Config) -> Result<Self> {
+ pub fn set_config(mut self, config: &Config) -> Result<Self> {
+ let connectivity_checker = self
+ .take_checker()
+ .expect("connectivity checker unexpectedly dropped");
let state = self.as_state();
let log_path = state._logging_context.path.clone();
let tun_provider = Arc::clone(&state.tun_provider);
@@ -120,6 +125,7 @@ impl WgGoTunnel {
tun_provider,
routes,
&resource_dir,
+ connectivity_checker,
)
}
WgGoTunnel::Singlehop(state) if config.is_multihop() => {
@@ -131,6 +137,7 @@ impl WgGoTunnel {
tun_provider,
routes,
&resource_dir,
+ connectivity_checker,
)
}
WgGoTunnel::Singlehop(mut state) => {
@@ -163,6 +170,13 @@ pub(crate) struct WgGoTunnelState {
resource_dir: PathBuf,
#[cfg(daita)]
config: Config,
+ // HACK: Check is not Clone, so we have to pass this around ..
+ // This is conceptually the connection between this Tunnel and the currently running
+ // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of
+ // a new Tunnel during the "ensure_connectivity" phase. This field should be removed
+ // as soon as we implement a better way to cancel Check asynchronously.
+ #[cfg(target_os = "android")]
+ connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>,
}
impl WgGoTunnelState {
@@ -288,6 +302,7 @@ impl WgGoTunnel {
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
+ mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
@@ -310,7 +325,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;
- Ok(WgGoTunnel::Singlehop(WgGoTunnelState {
+ let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
@@ -320,7 +335,14 @@ impl WgGoTunnel {
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
config: config.clone(),
- }))
+ connectivity_checker: None,
+ });
+
+ // HACK: Check if the tunnel is working by sending a ping in the tunnel.
+ tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
+ tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+
+ Ok(tunnel)
}
pub fn start_multihop_tunnel(
@@ -330,6 +352,7 @@ impl WgGoTunnel {
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
+ mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
@@ -368,7 +391,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;
- Ok(WgGoTunnel::Multihop(WgGoTunnelState {
+ let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
@@ -378,7 +401,14 @@ impl WgGoTunnel {
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
config: config.clone(),
- }))
+ connectivity_checker: None,
+ });
+
+ // HACK: Check if the tunnel is working by sending a ping in the tunnel.
+ tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
+ tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+
+ Ok(tunnel)
}
fn bypass_tunnel_sockets(
@@ -393,6 +423,28 @@ impl WgGoTunnel {
Ok(())
}
+
+ pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> {
+ self.as_state_mut().connectivity_checker.take()
+ }
+
+ /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve
+ /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out.
+ fn ensure_tunnel_is_running(
+ &self,
+ checker: &mut connectivity::Check<connectivity::Cancellable>,
+ ) -> Result<()> {
+ let connectivity_err = |e| TunnelError::Connectivity(Box::new(e));
+ let connection_established = checker
+ .establish_connectivity(self)
+ .map_err(connectivity_err)?;
+
+ // Timed out
+ if !connection_established {
+ return Err(TunnelError::TunnelUp);
+ }
+ Ok(())
+ }
}
impl Tunnel for WgGoTunnel {
@@ -418,7 +470,7 @@ impl Tunnel for WgGoTunnel {
&mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
- Box::pin(async move { self.to_state_mut().set_config(config) })
+ Box::pin(async move { self.as_state_mut().set_config(config) })
}
#[cfg(daita)]