summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-12-18 17:32:14 +0100
committerDavid Lönnhager <david.l@mullvad.net>2025-01-23 10:02:34 +0100
commite20a0b2be52155354d1ba43881c53bf19f50d5a6 (patch)
tree6fc60d537025ac30bfb58ac28c77231bc75a78b8
parentd0e32be44b8ea13b22eca74d4e37399b2cec7803 (diff)
downloadmullvadvpn-e20a0b2be52155354d1ba43881c53bf19f50d5a6.tar.xz
mullvadvpn-e20a0b2be52155354d1ba43881c53bf19f50d5a6.zip
Refactor connectivity check to be async
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-wireguard/Cargo.toml1
-rw-r--r--talpid-wireguard/src/connectivity/check.rs370
-rw-r--r--talpid-wireguard/src/connectivity/constants.rs2
-rw-r--r--talpid-wireguard/src/connectivity/mock.rs18
-rw-r--r--talpid-wireguard/src/connectivity/mod.rs4
-rw-r--r--talpid-wireguard/src/connectivity/monitor.rs171
-rw-r--r--talpid-wireguard/src/connectivity/pinger/android.rs67
-rw-r--r--talpid-wireguard/src/connectivity/pinger/icmp.rs57
-rw-r--r--talpid-wireguard/src/connectivity/pinger/mod.rs5
-rw-r--r--talpid-wireguard/src/ephemeral.rs1
-rw-r--r--talpid-wireguard/src/lib.rs198
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs95
-rw-r--r--talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs17
-rw-r--r--talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs23
-rw-r--r--talpid-wireguard/src/wireguard_nt/mod.rs18
16 files changed, 549 insertions, 499 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 13e2e0d7e6..2bf2110f63 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4678,6 +4678,7 @@ dependencies = [
name = "talpid-wireguard"
version = "0.0.0"
dependencies = [
+ "async-trait",
"bitflags 1.3.2",
"byteorder",
"chrono",
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index e02bf874d2..3a19f5a70a 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -11,6 +11,7 @@ rust-version.workspace = true
workspace = true
[dependencies]
+async-trait = "0.1"
thiserror = { workspace = true }
futures = { workspace = true }
hex = "0.4"
diff --git a/talpid-wireguard/src/connectivity/check.rs b/talpid-wireguard/src/connectivity/check.rs
index 702ce97f2d..a5ac9cbeef 100644
--- a/talpid-wireguard/src/connectivity/check.rs
+++ b/talpid-wireguard/src/connectivity/check.rs
@@ -1,7 +1,9 @@
-use std::cmp;
use std::net::Ipv4Addr;
-use std::sync::mpsc;
-use std::time::{Duration, Instant};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::broadcast;
+use tokio::time::Instant;
use super::constants::*;
use super::error::Error;
@@ -35,52 +37,70 @@ use pinger::Pinger;
///
/// Once a connection established, a connection is only considered broken once the connectivity
/// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`.
-pub struct Check<Strategy = Timeout> {
+pub struct Check {
conn_state: ConnState,
ping_state: PingState,
- strategy: Strategy,
+ cancel_receiver: CancelReceiver,
retry_attempt: u32,
}
-// Define the type state of [Check]
-pub(crate) trait Strategy {
- fn should_shut_down(&mut self, timeout: Duration) -> bool;
+/// A handle that can be used to shut down the connectivity monitor.
+/// The monitor will also be shut down if all tokens are dropped.
+#[derive(Debug, Clone)]
+pub struct CancelToken {
+ closed: Arc<AtomicBool>,
+ tx: broadcast::Sender<()>,
}
-/// An uncancellable [Check] that will run [Check::establish_connectivity] until
-/// completion or until it times out.
-pub struct Timeout;
+/// A handle that can be passed to a [Check]. The corresponding [CancelToken] causes the [Check] to
+/// be stopped. Any [CancelToken] will cancel all receivers
+#[derive(Debug)]
+pub struct CancelReceiver {
+ closed: Arc<AtomicBool>,
+ rx: broadcast::Receiver<()>,
+}
-impl Strategy for Timeout {
- /// The Timeout strategy cannot receive shut down signals so this function always returns false.
- fn should_shut_down(&mut self, _timeout: Duration) -> bool {
- false
+impl CancelReceiver {
+ fn closed(&self) -> bool {
+ self.closed.load(Ordering::SeqCst)
}
}
-/// A cancellable [Check] may be cancelled before it will time out by sending
-/// a signal on the channel returned by [Check::with_cancellation]. Otherwise,
-/// it behaves as [Timeout].
-pub struct Cancellable {
- close_receiver: mpsc::Receiver<()>,
+impl Clone for CancelReceiver {
+ fn clone(&self) -> Self {
+ Self {
+ closed: self.closed.clone(),
+ rx: self.rx.resubscribe(),
+ }
+ }
}
-impl Strategy for Cancellable {
- /// Returns true if monitor should be shut down
- fn should_shut_down(&mut self, timeout: Duration) -> bool {
- match self.close_receiver.recv_timeout(timeout) {
- Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
- Err(mpsc::RecvTimeoutError::Timeout) => false,
- }
+impl CancelToken {
+ pub fn new() -> (Self, CancelReceiver) {
+ let (tx, rx) = broadcast::channel(1);
+ let closed = Arc::new(AtomicBool::new(false));
+ (
+ CancelToken {
+ closed: closed.clone(),
+ tx,
+ },
+ CancelReceiver { closed, rx },
+ )
+ }
+
+ pub fn close(&self) {
+ self.closed.store(true, Ordering::SeqCst);
+ let _ = self.tx.send(());
}
}
-impl Check<Timeout> {
+impl Check {
pub fn new(
addr: Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
retry_attempt: u32,
- ) -> Result<Check<Timeout>, Error> {
+ cancel_receiver: CancelReceiver,
+ ) -> Result<Check, Error> {
Ok(Check {
conn_state: ConnState::new(Instant::now(), Default::default()),
ping_state: PingState::new(
@@ -88,47 +108,37 @@ impl Check<Timeout> {
#[cfg(any(target_os = "macos", target_os = "linux"))]
interface,
)?,
- strategy: Timeout,
retry_attempt,
+ cancel_receiver,
})
}
- /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping
- /// the returned channel.
- pub fn with_cancellation(self) -> (Check<Cancellable>, mpsc::Sender<()>) {
- let (cancellation_tx, cancellation_rx) = mpsc::channel();
- let check = Check {
- conn_state: self.conn_state,
- ping_state: self.ping_state,
- strategy: Cancellable {
- close_receiver: cancellation_rx,
- },
- retry_attempt: self.retry_attempt,
- };
- (check, cancellation_tx)
- }
-
#[cfg(test)]
- /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy,
- /// see [Check::with_cancellation].
- pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self {
- Check {
- conn_state,
- ping_state,
- strategy: Timeout,
- retry_attempt: 0,
- }
+ /// Create a new [Check] with a custom initial state.
+ pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> (Self, CancelToken) {
+ let (cancel_token, cancel_receiver) = CancelToken::new();
+ (
+ Check {
+ conn_state,
+ ping_state,
+ retry_attempt: 0,
+ cancel_receiver,
+ },
+ cancel_token,
+ )
}
-}
-impl<S: Strategy> Check<S> {
// checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is
// successful at the start of a connection.
- pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result<bool, Error> {
+ pub async fn establish_connectivity(
+ &mut self,
+ tunnel_handle: &TunnelType,
+ ) -> Result<bool, Error> {
// Send initial ping to prod WireGuard into connecting.
self.ping_state
.pinger
.send_icmp()
+ .await
.map_err(Error::PingError)?;
self.establish_connectivity_inner(
self.retry_attempt,
@@ -137,18 +147,15 @@ impl<S: Strategy> Check<S> {
MAX_ESTABLISH_TIMEOUT,
tunnel_handle,
)
+ .await
}
- pub(crate) fn reset(&mut self, current_iteration: Instant) {
- self.ping_state.reset();
+ pub(crate) async fn reset(&mut self, current_iteration: Instant) {
+ self.ping_state.reset().await;
self.conn_state.reset_after_suspension(current_iteration);
}
- pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool {
- self.strategy.should_shut_down(timeout)
- }
-
- fn establish_connectivity_inner(
+ async fn establish_connectivity_inner(
&mut self,
retry_attempt: u32,
timeout_initial: Duration,
@@ -160,59 +167,96 @@ impl<S: Strategy> Check<S> {
return Ok(true);
}
- let check_timeout = cmp::min(
- max_timeout,
- timeout_initial.saturating_mul(timeout_multiplier.saturating_pow(retry_attempt)),
- );
+ let check_timeout = max_timeout
+ .min(timeout_initial.saturating_mul(timeout_multiplier.saturating_pow(retry_attempt)));
- let start = Instant::now();
- while start.elapsed() < check_timeout {
- if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? {
- return Ok(true);
+ // Begin polling tunnel traffic stats periodically
+ let poll_check = async {
+ loop {
+ if Self::check_connectivity_interval(
+ &mut self.conn_state,
+ &mut self.ping_state,
+ Instant::now(),
+ check_timeout,
+ tunnel_handle,
+ )
+ .await?
+ {
+ return Ok(true);
+ }
+ tokio::time::sleep(Duration::from_millis(20)).await;
}
- if self.should_shut_down(DELAY_ON_INITIAL_SETUP) {
- return Ok(false);
+ };
+
+ let timeout = tokio::time::sleep(check_timeout);
+
+ tokio::select! {
+ // Tunnel status polling returned a result
+ result = poll_check => {
+ result
+ }
+
+ // Cancel token signal
+ _ = self.cancel_receiver.rx.recv() => {
+ Ok(false)
+ }
+
+ // Give up if the timeout is hit
+ _ = timeout => {
+ Ok(false)
}
}
- Ok(false)
+ }
+
+ pub(crate) fn should_shut_down(&self) -> bool {
+ self.cancel_receiver.closed()
}
/// Returns true if connection is established
- pub(crate) fn check_connectivity(
+ pub(crate) async fn check_connectivity(
&mut self,
now: Instant,
tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
- self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle)
+ Self::check_connectivity_interval(
+ &mut self.conn_state,
+ &mut self.ping_state,
+ now,
+ PING_TIMEOUT,
+ tunnel_handle,
+ )
+ .await
}
/// Returns true if connection is established
- fn check_connectivity_interval(
- &mut self,
+ async fn check_connectivity_interval(
+ conn_state: &mut ConnState,
+ ping_state: &mut PingState,
now: Instant,
timeout: Duration,
tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
- match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? {
+ match Self::get_stats(tunnel_handle)
+ .await
+ .map_err(Error::ConfigReadError)?
+ {
None => Ok(false),
Some(new_stats) => {
- if self.conn_state.update(now, new_stats) {
- self.ping_state.reset();
+ if conn_state.update(now, new_stats) {
+ ping_state.reset().await;
return Ok(true);
}
- self.maybe_send_ping(now)?;
- Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected())
+ Self::maybe_send_ping(conn_state, ping_state, now).await?;
+ Ok(!ping_state.ping_timed_out(timeout) && conn_state.connected())
}
}
}
/// If None is returned, then the underlying tunnel has already been closed and all subsequent
/// calls will also return None.
- ///
- /// NOTE: will panic if called from within a tokio runtime.
- fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> {
- let stats = tunnel_handle.get_tunnel_stats()?;
+ async fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> {
+ let stats = tunnel_handle.get_tunnel_stats().await?;
if stats.is_empty() {
log::error!("Tunnel unexpectedly shut down");
Ok(None)
@@ -221,28 +265,31 @@ impl<S: Strategy> Check<S> {
}
}
- fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> {
+ async fn maybe_send_ping(
+ conn_state: &mut ConnState,
+ ping_state: &mut PingState,
+ now: Instant,
+ ) -> Result<(), Error> {
// Only send out a ping if we haven't received a byte in a while or no traffic has flowed
// in the last 2 minutes, but if a ping already has been sent out, only send one out every
// 3 seconds.
- if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out())
- && self
- .ping_state
+ if (conn_state.rx_timed_out() || conn_state.traffic_timed_out())
+ && ping_state
.initial_ping_timestamp
.map(|initial_ping_timestamp| {
- initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent
- < SECONDS_PER_PING
+ initial_ping_timestamp.elapsed() / ping_state.num_pings_sent < SECONDS_PER_PING
})
.unwrap_or(true)
{
- self.ping_state
+ ping_state
.pinger
.send_icmp()
+ .await
.map_err(Error::PingError)?;
- if self.ping_state.initial_ping_timestamp.is_none() {
- self.ping_state.initial_ping_timestamp = Some(now);
+ if ping_state.initial_ping_timestamp.is_none() {
+ ping_state.initial_ping_timestamp = Some(now);
}
- self.ping_state.num_pings_sent += 1;
+ ping_state.num_pings_sent += 1;
}
Ok(())
}
@@ -284,10 +331,10 @@ impl PingState {
}
/// Reset timeouts - assume that the last time bytes were received is now.
- fn reset(&mut self) {
+ async fn reset(&mut self) {
self.initial_ping_timestamp = None;
self.num_pings_sent = 0;
- self.pinger.reset();
+ self.pinger.reset().await;
}
}
@@ -420,6 +467,8 @@ impl ConnState {
#[cfg(test)]
mod test {
+ use tokio::sync::mpsc;
+
use super::*;
use crate::connectivity::mock::*;
@@ -527,100 +576,115 @@ mod test {
assert!(!conn_state.traffic_timed_out());
}
- #[test]
+ #[tokio::test]
/// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is
/// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`.
- fn test_ping_times_out() {
+ async fn test_ping_times_out() {
let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now
.checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10))
.unwrap();
- let mut checker = mock_checker(start, Box::new(pinger));
+ let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
checker.conn_state = connected_state(start);
// A ping was sent to verify connectivity
- checker.maybe_send_ping(start).unwrap();
- assert!(!checker.check_connectivity(now, &tunnel).unwrap())
+ Check::maybe_send_ping(&mut checker.conn_state, &mut checker.ping_state, start)
+ .await
+ .unwrap();
+ assert!(!checker.check_connectivity(now, &tunnel).await.unwrap())
}
- #[test]
+ #[tokio::test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
- fn test_no_connection_on_start() {
+ async fn test_no_connection_on_start() {
let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_checker(start, Box::new(pinger));
+ let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger));
- assert!(!monitor.check_connectivity(now, &tunnel).unwrap())
+ assert!(!checker.check_connectivity(now, &tunnel).await.unwrap())
}
- #[test]
+ #[tokio::test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
- fn test_connection_works() {
+ async fn test_connection_works() {
let tunnel = MockTunnel::always_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_checker(start, Box::new(pinger));
+ let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
- monitor.conn_state = connected_state(start);
+ checker.conn_state = connected_state(start);
- assert!(monitor.check_connectivity(now, &tunnel).unwrap())
+ assert!(checker.check_connectivity(now, &tunnel).await.unwrap())
}
- #[test]
+ #[tokio::test(start_paused = true)]
/// Verify that the timeout for setting up a tunnel works as expected.
- fn test_establish_timeout() {
- let pinger = MockPinger::default();
- let tunnel = {
- let mut tunnel_stats = StatsMap::new();
- tunnel_stats.insert(
- [0u8; 32],
- Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed()
- };
+ async fn test_establish_timeout() {
+ const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
+ const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500);
+ const MAX_ESTABLISH_TIMEOUT: Duration = Duration::from_secs(2);
- let (result_tx, result_rx) = mpsc::channel();
+ let (result_tx, mut result_rx) = mpsc::channel(1);
- std::thread::spawn(move || {
+ tokio::spawn(async move {
+ let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_checker(start, Box::new(pinger));
+ let (mut monitor, _cancel_token) = mock_checker(start, Box::new(pinger));
- const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
- const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500);
- const MAX_ESTABLISH_TIMEOUT: Duration = Duration::from_secs(2);
+ let tunnel = {
+ let mut tunnel_stats = StatsMap::new();
+ tunnel_stats.insert(
+ [0u8; 32],
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed()
+ };
- for attempt in 0..4 {
- result_tx
- .send(monitor.establish_connectivity_inner(
- attempt,
- ESTABLISH_TIMEOUT,
- ESTABLISH_TIMEOUT_MULTIPLIER,
- MAX_ESTABLISH_TIMEOUT,
- &tunnel,
- ))
- .unwrap();
- }
+ result_tx
+ .send(
+ monitor
+ .establish_connectivity_inner(
+ 0,
+ ESTABLISH_TIMEOUT,
+ ESTABLISH_TIMEOUT_MULTIPLIER,
+ MAX_ESTABLISH_TIMEOUT,
+ &tunnel,
+ )
+ .await,
+ )
+ .await
+ .unwrap();
});
- let err = DELAY_ON_INITIAL_SETUP + Duration::from_millis(350);
- let assert_rx = |recv_timeout: Duration| {
- assert!(!result_rx.recv_timeout(recv_timeout + err).unwrap().unwrap());
- };
- assert_rx(Duration::from_millis(500));
- assert_rx(Duration::from_secs(1));
- assert_rx(Duration::from_secs(2));
- assert_rx(Duration::from_secs(2));
+
+ tokio::time::timeout(
+ ESTABLISH_TIMEOUT - Duration::from_millis(100),
+ result_rx.recv(),
+ )
+ .await
+ .expect_err("expected timeout");
+
+ // Should assume no connectivity after timeout
+ let connected = tokio::time::timeout(
+ ESTABLISH_TIMEOUT + Duration::from_millis(100),
+ result_rx.recv(),
+ )
+ .await
+ .expect("expected no timeout")
+ .unwrap()
+ .unwrap();
+ assert!(!connected);
}
}
diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs
index a8d6752ddd..28c8acf1a5 100644
--- a/talpid-wireguard/src/connectivity/constants.rs
+++ b/talpid-wireguard/src/connectivity/constants.rs
@@ -1,7 +1,5 @@
use std::time::Duration;
-/// Sleep time used when initially establishing connectivity
-pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50);
/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is
/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic
/// is received.
diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs
index eea3004bfc..5b7c98b183 100644
--- a/talpid-wireguard/src/connectivity/mock.rs
+++ b/talpid-wireguard/src/connectivity/mock.rs
@@ -1,8 +1,8 @@
use std::future::Future;
use std::pin::Pin;
-use std::time::Instant;
+use tokio::time::Instant;
-use super::check::{ConnState, PingState, Timeout};
+use super::check::{CancelToken, ConnState, PingState};
use super::pinger;
use super::Check;
@@ -14,14 +14,14 @@ pub use crate::stats::{Stats, StatsMap};
#[derive(Default)]
pub(crate) struct MockPinger {
- on_send_ping: Option<Box<dyn FnMut() + Send>>,
+ on_send_ping: Option<Box<dyn FnMut() + Send + Sync>>,
}
pub(crate) struct MockTunnel {
- on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send>,
+ on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send + Sync>,
}
-pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> Check<Timeout> {
+pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> (Check, CancelToken) {
let conn_state = ConnState::new(now, Default::default());
let ping_state = PingState::new_with(pinger);
Check::mock(conn_state, ping_state)
@@ -47,7 +47,7 @@ pub fn connected_state(timestamp: Instant) -> ConnState {
impl MockTunnel {
const PEER: [u8; 32] = [0u8; 32];
- pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + 'static>(f: F) -> Self {
+ pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + Sync + 'static>(f: F) -> Self {
Self {
on_get_stats: Box::new(f),
}
@@ -97,6 +97,7 @@ impl MockTunnel {
}
}
+#[async_trait::async_trait]
impl Tunnel for MockTunnel {
fn get_interface_name(&self) -> String {
"mock-tunnel".to_string()
@@ -106,7 +107,7 @@ impl Tunnel for MockTunnel {
Ok(())
}
- fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> {
+ async fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> {
(self.on_get_stats)()
}
@@ -126,8 +127,9 @@ impl Tunnel for MockTunnel {
}
}
+#[async_trait::async_trait]
impl Pinger for MockPinger {
- fn send_icmp(&mut self) -> Result<(), pinger::Error> {
+ async fn send_icmp(&mut self) -> Result<(), pinger::Error> {
if let Some(callback) = self.on_send_ping.as_mut() {
(callback)();
}
diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs
index 512d8715f1..2da555ad45 100644
--- a/talpid-wireguard/src/connectivity/mod.rs
+++ b/talpid-wireguard/src/connectivity/mod.rs
@@ -7,7 +7,7 @@ mod monitor;
mod pinger;
#[cfg(target_os = "android")]
-pub use check::Cancellable;
-pub use check::Check;
+pub use check::CancelReceiver;
+pub use check::{CancelToken, Check};
pub use error::Error;
pub use monitor::Monitor;
diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs
index 583b8d9589..1272b43f4d 100644
--- a/talpid-wireguard/src/connectivity/monitor.rs
+++ b/talpid-wireguard/src/connectivity/monitor.rs
@@ -1,69 +1,73 @@
-use std::{
- sync::Weak,
- time::{Duration, Instant},
-};
+use std::{sync::Weak, time::Duration};
use tokio::sync::Mutex;
+use tokio::time::{Instant, MissedTickBehavior};
use crate::TunnelType;
-use super::check::{Cancellable, Check};
+use super::check::Check;
use super::error::Error;
/// Sleep time used when checking if an established connection is still working.
const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1);
+/// Reset the checker if the last check occurred this long ago
+const SUSPEND_TIMEOUT: Duration = Duration::from_secs(6);
+
pub struct Monitor {
- connectivity_check: Check<Cancellable>,
+ connectivity_check: Check,
}
impl Monitor {
- pub fn init(connectivity_check: Check<Cancellable>) -> Self {
+ pub fn init(connectivity_check: Check) -> Self {
Self { connectivity_check }
}
- pub fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> {
- self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle)
- }
-
- fn wait_loop(
+ pub async fn run(
mut self,
- iter_delay: Duration,
tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
) -> Result<(), Error> {
- let mut last_iteration = Instant::now();
- while !self.connectivity_check.should_shut_down(iter_delay) {
- let mut current_iteration = Instant::now();
- let time_slept = current_iteration - last_iteration;
- if time_slept < (iter_delay * 2) {
- let Some(tunnel) = tunnel_handle.upgrade() else {
- return Ok(());
- };
- let lock = tunnel.blocking_lock();
- let Some(tunnel) = lock.as_ref() else {
- return Ok(());
- };
+ let mut last_check = Instant::now();
- if !self
- .connectivity_check
- .check_connectivity(Instant::now(), tunnel)?
- {
- return Ok(());
- }
- drop(lock);
+ let mut interval = tokio::time::interval(REGULAR_LOOP_SLEEP);
+ interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
- let end = Instant::now();
- if end - current_iteration > Duration::from_secs(1) {
- current_iteration = end;
- }
- } else {
- // Loop was suspended for too long, so it's safer to assume that the host still has
- // connectivity.
- self.connectivity_check.reset(current_iteration);
+ loop {
+ if self.connectivity_check.should_shut_down() {
+ return Ok(());
+ }
+
+ let now = Instant::now();
+ let time_slept = now - last_check;
+ last_check = now;
+
+ if time_slept >= SUSPEND_TIMEOUT {
+ self.connectivity_check.reset(now).await;
+ } else if !self.tunnel_exists_and_is_connected(&tunnel_handle).await? {
+ return Ok(());
}
- last_iteration = current_iteration;
+
+ interval.tick().await;
}
- Ok(())
+ }
+
+ async fn tunnel_exists_and_is_connected(
+ &mut self,
+ tunnel_handle: &Weak<Mutex<Option<TunnelType>>>,
+ ) -> Result<bool, Error> {
+ let Some(tunnel) = tunnel_handle.upgrade() else {
+ // Tunnel closed
+ return Ok(false);
+ };
+ let lock = tunnel.lock().await;
+ let Some(tunnel) = lock.as_ref() else {
+ // Tunnel closed
+ return Ok(false);
+ };
+
+ self.connectivity_check
+ .check_connectivity(Instant::now(), tunnel)
+ .await
}
}
@@ -71,54 +75,52 @@ impl Monitor {
mod test {
use super::*;
- // TODO: Port to async + tokio to reduce cost of testing?
use std::sync::atomic::{AtomicBool, Ordering};
- use std::sync::mpsc;
use std::sync::Arc;
use std::time::Duration;
- use std::time::Instant;
+ use tokio::sync::mpsc;
use tokio::sync::Mutex;
use crate::connectivity::constants::*;
use crate::connectivity::mock::*;
- #[test]
+ #[tokio::test(start_paused = true)]
/// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
/// and it shuts down properly.
- fn test_wait_loop() {
- use std::sync::mpsc;
- let (result_tx, result_rx) = mpsc::channel();
+ async fn test_wait_loop() {
+ let (result_tx, mut result_rx) = mpsc::channel(1);
let tunnel = MockTunnel::always_incrementing().boxed();
let pinger = MockPinger::default();
let (mut checker, stop_tx) = {
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- mock_checker(start, Box::new(pinger)).with_cancellation()
+ mock_checker(start, Box::new(pinger))
};
- std::thread::spawn(move || {
- let start_result = checker.establish_connectivity(&tunnel);
- result_tx.send(start_result).unwrap();
+
+ tokio::spawn(async move {
+ let start_result = checker.establish_connectivity(&tunnel).await;
+ result_tx.send(start_result).await.unwrap();
// Pointer dance
let tunnel = Arc::new(Mutex::new(Some(tunnel)));
let _tunnel = Arc::downgrade(&tunnel);
- let result = Monitor::init(checker).run(_tunnel).map(|_| true);
- result_tx.send(result).unwrap();
+ let result = Monitor::init(checker).run(_tunnel).await.map(|_| true);
+ result_tx.send(result).await.unwrap();
});
- std::thread::sleep(Duration::from_secs(1));
+ tokio::time::sleep(Duration::from_secs(1)).await;
assert!(result_rx.try_recv().unwrap().unwrap());
- stop_tx.send(()).unwrap();
- std::thread::sleep(Duration::from_secs(1));
+ stop_tx.close();
+ tokio::time::sleep(Duration::from_secs(2)).await;
assert!(result_rx.try_recv().unwrap().is_ok());
}
- #[test]
+ #[tokio::test(start_paused = true)]
/// Verify that the connectivity monitor detects the tunnel timing out after no longer than
/// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
- fn test_wait_loop_timeout() {
- let should_stop = Arc::new(AtomicBool::new(false));
- let should_stop_inner = should_stop.clone();
+ async fn test_wait_loop_timeout() {
+ let stop_bytes_rx = Arc::new(AtomicBool::new(false));
+ let stop_bytes_rx_inner = stop_bytes_rx.clone();
let mut map = StatsMap::new();
map.insert(
@@ -133,7 +135,7 @@ mod test {
let pinger = MockPinger::default();
let tunnel = MockTunnel::new(move || {
let mut tunnel_stats = tunnel_stats.lock().unwrap();
- if !should_stop_inner.load(Ordering::SeqCst) {
+ if !stop_bytes_rx_inner.load(Ordering::SeqCst) {
for traffic in tunnel_stats.values_mut() {
traffic.rx_bytes += 1;
}
@@ -145,30 +147,41 @@ mod test {
})
.boxed();
- let (result_tx, result_rx) = mpsc::channel();
+ let (result_tx, mut result_rx) = mpsc::channel(1);
- std::thread::spawn(move || {
+ tokio::spawn(async move {
let (mut checker, _cancellation_token) = {
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- mock_checker(start, Box::new(pinger)).with_cancellation()
+ mock_checker(start, Box::new(pinger))
};
- let start_result = checker.establish_connectivity(&tunnel);
- result_tx.send(start_result).unwrap();
+ let start_result = checker.establish_connectivity(&tunnel).await;
+ result_tx.send(start_result).await.unwrap();
// Pointer dance
let _tunnel = Arc::new(Mutex::new(Some(tunnel)));
let tunnel = Arc::downgrade(&_tunnel);
- let end_result = Monitor::init(checker).run(tunnel).map(|_| true);
- result_tx.send(end_result).expect("Failed to send result");
+ let end_result = Monitor::init(checker).run(tunnel).await.map(|_| true);
+ result_tx
+ .send(end_result)
+ .await
+ .expect("Failed to send result");
});
- assert!(result_rx
- .recv_timeout(Duration::from_secs(1))
- .unwrap()
- .unwrap());
- should_stop.store(true, Ordering::SeqCst);
- assert!(result_rx
- .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
- .unwrap()
- .is_ok());
+
+ assert!(
+ tokio::time::timeout(Duration::from_secs(1), result_rx.recv())
+ .await
+ .unwrap()
+ .unwrap()
+ .unwrap()
+ );
+ stop_bytes_rx.store(true, Ordering::SeqCst);
+ assert!(tokio::time::timeout(
+ BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2),
+ result_rx.recv()
+ )
+ .await
+ .unwrap()
+ .unwrap()
+ .is_ok());
}
}
diff --git a/talpid-wireguard/src/connectivity/pinger/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs
index 00ad4d8fd3..34e28f8891 100644
--- a/talpid-wireguard/src/connectivity/pinger/android.rs
+++ b/talpid-wireguard/src/connectivity/pinger/android.rs
@@ -1,4 +1,9 @@
-use std::{io, net::Ipv4Addr};
+use std::net::Ipv4Addr;
+use std::process::Stdio;
+use std::time::Duration;
+
+use tokio::io;
+use tokio::process::{Child, Command};
/// Pinger errors
#[derive(thiserror::Error, Debug)]
@@ -15,7 +20,7 @@ pub enum Error {
/// A pinger that sends ICMP requests without waiting for responses
pub struct Pinger {
addr: Ipv4Addr,
- processes: Vec<duct::Handle>,
+ processes: Vec<Child>,
}
impl Pinger {
@@ -28,60 +33,40 @@ impl Pinger {
}
fn try_deplete_process_list(&mut self) {
- self.processes.retain(|child| {
- match child.try_wait() {
- // child has terminated, doesn't have to be retained
- Ok(Some(_)) => false,
- _ => true,
- }
+ self.processes.retain_mut(|child| {
+ // retain non-terminated children
+ matches!(child.try_wait(), Err(_) | Ok(None))
});
}
}
+#[async_trait::async_trait]
impl super::Pinger for Pinger {
// Send an ICMP packet without waiting for a reply
- fn send_icmp(&mut self) -> Result<(), Error> {
+ async fn send_icmp(&mut self) -> Result<(), Error> {
self.try_deplete_process_list();
- let cmd = ping_cmd(self.addr, 1);
- let handle = cmd.start().map_err(Error::PingError)?;
- self.processes.push(handle);
+ let child = ping_cmd(self.addr, Duration::from_secs(1)).map_err(Error::PingError)?;
+ self.processes.push(child);
Ok(())
}
- fn reset(&mut self) {
- let processes = std::mem::take(&mut self.processes);
- for proc in processes {
- if proc
- .try_wait()
- .map(|maybe_stopped| maybe_stopped.is_none())
- .unwrap_or(false)
- {
- if let Err(err) = proc.kill() {
- log::error!("Failed to kill ping process: {}", err);
- }
- }
- }
+ async fn reset(&mut self) {
+ self.processes.clear();
}
}
-impl Drop for Pinger {
- fn drop(&mut self) {
- for child in self.processes.iter_mut() {
- if let Err(e) = child.kill() {
- log::error!("Failed to kill ping process: {}", e);
- }
- }
- }
-}
+fn ping_cmd(ip: Ipv4Addr, timeout: Duration) -> io::Result<Child> {
+ let mut cmd = Command::new("ping");
-fn ping_cmd(ip: Ipv4Addr, timeout_secs: u16) -> duct::Expression {
- let timeout_secs = timeout_secs.to_string();
+ let timeout_secs = timeout.as_secs().to_string();
let ip = ip.to_string();
- let args = ["-n", "-i", "1", "-w", &timeout_secs, &ip];
+ cmd.args(["-n", "-i", "1", "-w", &timeout_secs, &ip]);
+
+ cmd.stdin(Stdio::null())
+ .stdout(Stdio::null())
+ .stderr(Stdio::null())
+ .kill_on_drop(true);
- duct::cmd("ping", args)
- .stdin_null()
- .stdout_null()
- .unchecked()
+ cmd.spawn()
}
diff --git a/talpid-wireguard/src/connectivity/pinger/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs
index 0e5d739425..b17ee7ddc1 100644
--- a/talpid-wireguard/src/connectivity/pinger/icmp.rs
+++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs
@@ -1,11 +1,11 @@
use byteorder::{NetworkEndian, WriteBytesExt};
use rand::Rng;
use socket2::{Domain, Protocol, Socket, Type};
+use tokio::net::UdpSocket;
use std::{
io::{self, Write},
net::{Ipv4Addr, SocketAddr},
- thread,
time::Duration,
};
@@ -30,6 +30,10 @@ pub enum Error {
#[error("Failed to write to socket")]
Write(#[source] io::Error),
+ /// Failed to convert to tokio socket
+ #[error("Failed to convert to tokio socket")]
+ ConvertSocket(#[source] io::Error),
+
/// Failed to get device index
#[cfg(target_os = "macos")]
#[error("Failed to obtain device index")]
@@ -43,16 +47,12 @@ pub enum Error {
/// ICMP buffer too small
#[error("ICMP message buffer too small")]
BufferTooSmall,
-
- /// Interface name contains null bytes
- #[error("Interface name contains a null byte")]
- InterfaceNameContainsNull,
}
type Result<T> = std::result::Result<T, Error>;
pub struct Pinger {
- sock: Socket,
+ sock: UdpSocket,
addr: SocketAddr,
id: u16,
seq: u16,
@@ -76,6 +76,9 @@ impl Pinger {
#[cfg(target_os = "macos")]
Self::set_device_index(&sock, &interface_name)?;
+ let sock =
+ UdpSocket::from_std(std::net::UdpSocket::from(sock)).map_err(Error::ConvertSocket)?;
+
Ok(Self {
sock,
addr,
@@ -96,25 +99,19 @@ impl Pinger {
Ok(())
}
- fn send_ping_request(&mut self, message: &[u8], destination: SocketAddr) -> Result<()> {
+ async fn send_ping_request(&mut self, message: &[u8], destination: SocketAddr) -> Result<()> {
let mut tries = 0;
- let mut result = Ok(());
- while tries < SEND_RETRY_ATTEMPTS {
- match self.sock.send_to(message, &destination.into()) {
- Ok(_) => {
- return Ok(());
- }
- Err(err) => {
- if Some(10065) != err.raw_os_error() {
- return Err(Error::Write(err));
- }
- result = Err(Error::Write(err));
- }
+ loop {
+ let Err(error) = self.sock.send_to(message, destination).await else {
+ return Ok(());
+ };
+ if tries >= SEND_RETRY_ATTEMPTS || !should_retry_send(&error) {
+ return Err(Error::Write(error));
}
- thread::sleep(Duration::from_secs(1));
+
+ tokio::time::sleep(Duration::from_secs(1)).await;
tries += 1;
}
- result
}
fn construct_icmpv4_packet(&mut self, buffer: &mut [u8]) -> Result<()> {
@@ -125,11 +122,25 @@ impl Pinger {
}
}
+#[cfg(windows)]
+fn should_retry_send(err: &io::Error) -> bool {
+ // Winsock error for when there is no route
+ // NOTE: It's unclear if we need to check this on Windows anymore, or why specifically on Windows
+ const WSAEHOSTUNREACH: i32 = 10065;
+ err.raw_os_error() == Some(WSAEHOSTUNREACH)
+}
+
+#[cfg(unix)]
+fn should_retry_send(_err: &io::Error) -> bool {
+ false
+}
+
+#[async_trait::async_trait]
impl super::Pinger for Pinger {
- fn send_icmp(&mut self) -> Result<()> {
+ async fn send_icmp(&mut self) -> Result<()> {
let mut message = [0u8; 50];
self.construct_icmpv4_packet(&mut message)?;
- self.send_ping_request(&message, self.addr)
+ self.send_ping_request(&message, self.addr).await
}
}
diff --git a/talpid-wireguard/src/connectivity/pinger/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs
index ef2394f1b7..10875afb8a 100644
--- a/talpid-wireguard/src/connectivity/pinger/mod.rs
+++ b/talpid-wireguard/src/connectivity/pinger/mod.rs
@@ -9,11 +9,12 @@ mod imp;
pub use imp::Error;
/// Trait for sending ICMP requests to get some traffic from a remote server
+#[async_trait::async_trait]
pub trait Pinger: Send {
/// Sends an ICMP packet
- fn send_icmp(&mut self) -> Result<(), Error>;
+ async fn send_icmp(&mut self) -> Result<(), Error>;
/// Clears all resources used by the pinger.
- fn reset(&mut self) {}
+ async fn reset(&mut self) {}
}
/// Create a new pinger
diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs
index 31f3957253..1df0820014 100644
--- a/talpid-wireguard/src/ephemeral.rs
+++ b/talpid-wireguard/src/ephemeral.rs
@@ -226,6 +226,7 @@ async fn reconfigure_tunnel(
let updated_tunnel = tunnel
.set_config(&config)
+ .await
.map_err(Error::TunnelError)
.map_err(CloseMsg::SetupError)?;
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index aed06be788..1b377e2f4f 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -145,7 +145,7 @@ pub struct WireguardMonitor {
/// Callback to signal tunnel events
event_hook: EventHook,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
- pinger_stop_sender: sync_mpsc::Sender<()>,
+ pinger_stop_sender: connectivity::CancelToken,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
}
@@ -211,21 +211,22 @@ impl WireguardMonitor {
let obfuscator = Arc::new(AsyncMutex::new(obfuscator));
let gateway = config.ipv4_gateway;
- let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new(
+ let (cancel_token, cancel_receiver) = connectivity::CancelToken::new();
+ let mut connectivity_monitor = connectivity::Check::new(
gateway,
#[cfg(any(target_os = "macos", target_os = "linux"))]
iface_name.clone(),
args.retry_attempt,
+ cancel_receiver,
)
- .map_err(Error::ConnectivityMonitorError)?
- .with_cancellation();
+ .map_err(Error::ConnectivityMonitorError)?;
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
event_hook: args.event_hook.clone(),
close_msg_receiver: close_obfs_listener,
- pinger_stop_sender: pinger_tx,
+ pinger_stop_sender: cancel_token,
obfuscator,
};
@@ -281,13 +282,7 @@ impl WireguardMonitor {
// timing out on Windows for 2024.9-beta1. These verbose data usage logs are
// a temporary measure to help us understand the issue. They can be removed
// if the issue is resolved.
- if let Err(err) =
- tokio::task::spawn_blocking(move || log_tunnel_data_usage(&config, &tunnel))
- .await
- {
- log::error!("Failed to log tunnel data during setup phase");
- log::error!("{err}");
- }
+ log_tunnel_data_usage(&config, &tunnel).await;
return Err(e);
}
@@ -331,28 +326,26 @@ impl WireguardMonitor {
});
}
- let cloned_tunnel = Arc::clone(&tunnel);
-
- let connectivity_check = tokio::task::spawn_blocking(move || {
- let lock = cloned_tunnel.blocking_lock();
- let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly");
- match connectivity_monitor.establish_connectivity(tunnel) {
- Ok(true) => Ok(connectivity_monitor),
- Ok(false) => {
- log::warn!("Timeout while checking tunnel connection");
- Err(CloseMsg::PingErr)
- }
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to check tunnel connection")
- );
- Err(CloseMsg::PingErr)
- }
+ let lock = tunnel.lock().await;
+ let borrowed_tun = lock.as_ref().expect("The tunnel was dropped unexpectedly");
+ match connectivity_monitor
+ .establish_connectivity(borrowed_tun)
+ .await
+ {
+ Ok(true) => Ok(()),
+ Ok(false) => {
+ log::warn!("Timeout while checking tunnel connection");
+ Err(CloseMsg::PingErr)
}
- })
- .await
- .unwrap()?;
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check tunnel connection")
+ );
+ Err(CloseMsg::PingErr)
+ }
+ }?;
+ drop(lock);
// Add any default route(s) that may exist.
args.route_manager
@@ -364,19 +357,15 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
event_hook.on_event(TunnelEvent::Up(metadata)).await;
- let monitored_tunnel = Arc::downgrade(&tunnel);
- tokio::task::spawn_blocking(move || {
- if let Err(error) =
- connectivity::Monitor::init(connectivity_check).run(monitored_tunnel)
- {
- log::error!(
- "{}",
- error.display_chain_with_msg("Connectivity monitor failed")
- );
- }
- })
- .await
- .unwrap();
+ if let Err(error) = connectivity::Monitor::init(connectivity_monitor)
+ .run(Arc::downgrade(&tunnel))
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Connectivity monitor failed")
+ );
+ }
Err::<Infallible, CloseMsg>(CloseMsg::PingErr)
};
@@ -435,12 +424,15 @@ impl WireguardMonitor {
let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita;
- let (connectivity_check, pinger_tx) =
- connectivity::Check::new(config.ipv4_gateway, args.retry_attempt)
- .map_err(Error::ConnectivityMonitorError)?
- .with_cancellation();
+ let (cancel_token, cancel_receiver) = connectivity::CancelToken::new();
+ let connectivity_check = connectivity::Check::new(
+ config.ipv4_gateway,
+ args.retry_attempt,
+ cancel_receiver.clone(),
+ )
+ .map_err(Error::ConnectivityMonitorError)?;
- let tunnel = Self::open_wireguard_go_tunnel(
+ let tunnel = args.runtime.block_on(Self::open_wireguard_go_tunnel(
&config,
log_path,
args.tun_provider.clone(),
@@ -448,8 +440,8 @@ impl WireguardMonitor {
// that we only allows traffic to/from the gateway. This is only needed on Android
// since we lack a firewall there.
should_negotiate_ephemeral_peer,
- connectivity_check,
- )?;
+ cancel_receiver,
+ ))?;
let iface_name = tunnel.get_interface_name();
let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
@@ -459,7 +451,7 @@ impl WireguardMonitor {
tunnel: Arc::clone(&tunnel),
event_hook: event_hook.clone(),
close_msg_receiver: close_obfs_listener,
- pinger_stop_sender: pinger_tx,
+ pinger_stop_sender: cancel_token,
obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
};
@@ -492,13 +484,7 @@ impl WireguardMonitor {
// timing out on Windows for 2024.9-beta1. These verbose data usage logs are
// a temporary measure to help us understand the issue. They can be removed
// if the issue is resolved.
- if let Err(err) =
- tokio::task::spawn_blocking(move || log_tunnel_data_usage(&config, &tunnel))
- .await
- {
- log::error!("Failed to log tunnel data during setup phase");
- log::error!("{err}");
- }
+ log_tunnel_data_usage(&config, &tunnel).await;
return Err(e);
}
@@ -514,29 +500,15 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
event_hook.on_event(TunnelEvent::Up(metadata)).await;
- // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
- let connectivity_check = {
- let mut tunnel_lock = tunnel.lock().await;
- let Some(tunnel) = tunnel_lock.as_mut() else {
- log::debug!("Tunnel is no longer running");
- return Err::<Infallible, CloseMsg>(CloseMsg::PingErr);
- };
- tunnel
- .take_checker()
- .expect("connectivity checker unexpectedly dropped")
- };
-
- tokio::task::spawn_blocking(move || {
- let tunnel = Arc::downgrade(&tunnel);
- if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) {
- log::error!(
- "{}",
- error.display_chain_with_msg("Connectivity monitor failed")
- );
- }
- })
- .await
- .unwrap();
+ if let Err(error) = connectivity::Monitor::init(connectivity_check)
+ .run(Arc::downgrade(&tunnel))
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Connectivity monitor failed")
+ );
+ }
Err::<Infallible, CloseMsg>(CloseMsg::PingErr)
};
@@ -675,13 +647,18 @@ impl WireguardMonitor {
if !*FORCE_USERSPACE_WIREGUARD {
// If DAITA is enabled, wireguard-go has to be used.
if config.daita {
- let tunnel =
- Self::open_wireguard_go_tunnel(config, log_path, tun_provider).map(Box::new)?;
+ let tunnel = runtime
+ .block_on(Self::open_wireguard_go_tunnel(
+ config,
+ log_path,
+ tun_provider,
+ ))
+ .map(Box::new)?;
return Ok(tunnel);
}
if will_nm_manage_dns() {
- match wireguard_kernel::NetworkManagerTunnel::new(runtime, config) {
+ match wireguard_kernel::NetworkManagerTunnel::new(runtime.clone(), config) {
Ok(tunnel) => {
log::debug!("Using NetworkManager to use kernel WireGuard implementation");
return Ok(Box::new(tunnel));
@@ -696,7 +673,7 @@ impl WireguardMonitor {
}
};
} else {
- match wireguard_kernel::NetlinkTunnel::new(runtime, config) {
+ match wireguard_kernel::NetlinkTunnel::new(runtime.clone(), config) {
Ok(tunnel) => {
log::debug!("Using kernel WireGuard implementation");
return Ok(Box::new(tunnel));
@@ -725,28 +702,28 @@ impl WireguardMonitor {
#[cfg(target_os = "linux")]
log::debug!("Using userspace WireGuard implementation");
- let tunnel = Self::open_wireguard_go_tunnel(
- config,
- log_path,
- tun_provider,
- #[cfg(target_os = "android")]
- gateway_only,
- )
- .map(Box::new)?;
+ let tunnel = runtime
+ .block_on(Self::open_wireguard_go_tunnel(
+ config,
+ log_path,
+ tun_provider,
+ #[cfg(target_os = "android")]
+ gateway_only,
+ ))
+ .map(Box::new)?;
Ok(tunnel)
}
}
/// Configure and start a Wireguard-go tunnel.
#[cfg(wireguard_go)]
- fn open_wireguard_go_tunnel(
+ #[allow(clippy::unused_async)]
+ async fn open_wireguard_go_tunnel(
config: &Config,
log_path: Option<&Path>,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] gateway_only: bool,
- #[cfg(target_os = "android")] connectivity_check: connectivity::Check<
- connectivity::Cancellable,
- >,
+ #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver,
) -> Result<WgGoTunnel> {
let routes = config
.get_tunnel_destinations()
@@ -781,8 +758,9 @@ impl WireguardMonitor {
log_path,
tun_provider,
routes,
- connectivity_check,
+ cancel_receiver,
)
+ .await
.map_err(Error::TunnelError)?
} else {
WgGoTunnel::start_tunnel(
@@ -791,8 +769,9 @@ impl WireguardMonitor {
log_path,
tun_provider,
routes,
- connectivity_check,
+ cancel_receiver,
)
+ .await
.map_err(Error::TunnelError)?
};
@@ -811,7 +790,7 @@ impl WireguardMonitor {
Err(_) => Ok(()),
};
- let _ = self.pinger_stop_sender.send(());
+ self.pinger_stop_sender.close();
self.runtime
.block_on(self.event_hook.on_event(TunnelEvent::Down));
@@ -997,10 +976,10 @@ impl WireguardMonitor {
///
/// This will log the amount of outgoing and incoming data to and from the exit (and entry) relay
/// so far.
-fn log_tunnel_data_usage(config: &Config, tunnel: &Arc<AsyncMutex<Option<TunnelType>>>) {
- let tunnel = tunnel.blocking_lock();
+async fn log_tunnel_data_usage(config: &Config, tunnel: &Arc<AsyncMutex<Option<TunnelType>>>) {
+ let tunnel = tunnel.lock().await;
let Some(tunnel) = &*tunnel else { return };
- let Ok(tunnel_stats) = tunnel.get_tunnel_stats() else {
+ let Ok(tunnel_stats) = tunnel.get_tunnel_stats().await else {
return;
};
if let Some(stats) = config
@@ -1028,12 +1007,11 @@ enum CloseMsg {
}
#[allow(unused)]
-pub(crate) trait Tunnel: Send {
+#[async_trait::async_trait]
+pub(crate) trait Tunnel: Send + Sync {
fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
- /// # Note
- /// This function should *not* be called from within an async context.
- fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
+ async fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
fn set_config<'a>(
&'a mut self,
_config: Config,
diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index ede5ed97f5..418462e7d1 100644
--- a/talpid-wireguard/src/wireguard_go/mod.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -1,7 +1,5 @@
#[cfg(target_os = "android")]
use super::config;
-#[cfg(target_os = "android")]
-use super::Error;
use super::{
stats::{Stats, StatsMap},
Config, Tunnel, TunnelError,
@@ -106,32 +104,26 @@ impl WgGoTunnel {
}
}
- pub fn set_config(self, config: &Config) -> Result<Self> {
+ pub async fn set_config(self, config: &Config) -> Result<Self> {
let state = self.as_state();
let log_path = state._logging_context.path.clone();
+ let cancel_receiver = state.cancel_receiver.clone();
let tun_provider = Arc::clone(&state.tun_provider);
let routes = config.get_tunnel_destinations();
match self {
- WgGoTunnel::Multihop(mut state) if !config.is_multihop() => {
- let connectivity_checker = state
- .connectivity_checker
- .take()
- .expect("connectivity checker unexpectedly dropped");
+ WgGoTunnel::Multihop(state) if !config.is_multihop() => {
state.stop()?;
Self::start_tunnel(
config,
log_path.as_deref(),
tun_provider,
routes,
- connectivity_checker,
+ cancel_receiver,
)
+ .await
}
- WgGoTunnel::Singlehop(mut state) if config.is_multihop() => {
- let connectivity_checker = state
- .connectivity_checker
- .take()
- .expect("connectivity checker unexpectedly dropped");
+ WgGoTunnel::Singlehop(state) if config.is_multihop() => {
state.stop()?;
Self::start_multihop_tunnel(
config,
@@ -139,8 +131,9 @@ impl WgGoTunnel {
log_path.as_deref(),
tun_provider,
routes,
- connectivity_checker,
+ cancel_receiver,
)
+ .await
}
WgGoTunnel::Singlehop(mut state) => {
state.set_config(config.clone())?;
@@ -170,13 +163,9 @@ pub(crate) struct WgGoTunnelState {
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(daita)]
config: Config,
- // HACK: Check is not Clone, so we have to pass this around ..
- // This is conceptually the connection between this Tunnel and the currently running
- // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of
- // a new Tunnel during the "ensure_connectivity" phase. This field should be removed
- // as soon as we implement a better way to cancel Check asynchronously.
+ /// This is used to cancel the connectivity checks that occur when toggling multihop
#[cfg(target_os = "android")]
- connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>,
+ cancel_receiver: connectivity::CancelReceiver,
}
impl WgGoTunnelState {
@@ -295,12 +284,12 @@ impl WgGoTunnel {
#[cfg(target_os = "android")]
impl WgGoTunnel {
- pub fn start_tunnel(
+ pub async fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
- mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
+ cancel_receiver: connectivity::CancelReceiver,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
@@ -325,7 +314,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;
- let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
+ let tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
@@ -333,23 +322,22 @@ impl WgGoTunnel {
tun_provider,
#[cfg(daita)]
config: config.clone(),
- connectivity_checker: None,
+ cancel_receiver,
});
// HACK: Check if the tunnel is working by sending a ping in the tunnel.
- tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
- tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+ tunnel.ensure_tunnel_is_running().await?;
Ok(tunnel)
}
- pub fn start_multihop_tunnel(
+ pub async fn start_multihop_tunnel(
config: &Config,
exit_peer: &PeerConfig,
log_path: Option<&Path>,
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
- mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
+ cancel_receiver: connectivity::CancelReceiver,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
@@ -390,7 +378,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;
- let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
+ let tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
@@ -398,12 +386,11 @@ impl WgGoTunnel {
tun_provider,
#[cfg(daita)]
config: config.clone(),
- connectivity_checker: None,
+ cancel_receiver: cancel_receiver.clone(),
});
// HACK: Check if the tunnel is working by sending a ping in the tunnel.
- tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
- tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+ tunnel.ensure_tunnel_is_running().await?;
Ok(tunnel)
}
@@ -421,30 +408,34 @@ impl WgGoTunnel {
Ok(())
}
- pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> {
- self.as_state_mut().connectivity_checker.take()
- }
-
/// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve
/// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out.
- fn ensure_tunnel_is_running(
- &self,
- checker: &mut connectivity::Check<connectivity::Cancellable>,
- ) -> Result<()> {
- let connection_established = checker
+ async fn ensure_tunnel_is_running(&self) -> Result<()> {
+ let state = self.as_state();
+ let addr = state.config.ipv4_gateway;
+ let cancel_receiver = state.cancel_receiver.clone();
+ let mut check = connectivity::Check::new(addr, 0, cancel_receiver)
+ .map_err(|err| TunnelError::RecoverableStartWireguardError(Box::new(err)))?;
+
+ // TODO: retry attempt?
+
+ let connection_established = check
.establish_connectivity(self)
+ .await
.map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?;
// Timed out
if !connection_established {
return Err(TunnelError::RecoverableStartWireguardError(Box::new(
- Error::TimeoutError,
+ super::Error::TimeoutError,
)));
}
+
Ok(())
}
}
+#[async_trait::async_trait]
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> String {
self.as_state().interface_name.clone()
@@ -454,14 +445,16 @@ impl Tunnel for WgGoTunnel {
self.into_state().stop()
}
- fn get_tunnel_stats(&self) -> Result<StatsMap> {
- self.as_state()
- .tunnel_handle
- .get_config(|cstr| {
- Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8"))
- })
- .ok_or(TunnelError::GetConfigError)?
- .map_err(|error| TunnelError::StatsError(BoxedError::new(error)))
+ async fn get_tunnel_stats(&self) -> Result<StatsMap> {
+ tokio::task::block_in_place(|| {
+ self.as_state()
+ .tunnel_handle
+ .get_config(|cstr| {
+ Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8"))
+ })
+ .ok_or(TunnelError::GetConfigError)?
+ .map_err(|error| TunnelError::StatsError(BoxedError::new(error)))
+ })
}
fn set_config(
diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
index 8b84b3769d..86285d80a2 100644
--- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
+++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
@@ -65,6 +65,7 @@ impl NetlinkTunnel {
}
}
+#[async_trait::async_trait]
impl Tunnel for NetlinkTunnel {
fn get_interface_name(&self) -> String {
let mut wg = self.netlink_connections.wg_handle.clone();
@@ -103,16 +104,14 @@ impl Tunnel for NetlinkTunnel {
})
}
- fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, TunnelError> {
- let mut wg = self.netlink_connections.wg_handle.clone();
+ async fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, TunnelError> {
let interface_index = self.interface_index;
- self.tokio_handle.block_on(async move {
- let device = wg.get_by_index(interface_index).await.map_err(|err| {
- log::error!("Failed to fetch WireGuard device config: {}", err);
- TunnelError::GetConfigError
- })?;
- Ok(Stats::parse_device_message(&device))
- })
+ let mut wg = self.netlink_connections.wg_handle.clone();
+ let device = wg.get_by_index(interface_index).await.map_err(|err| {
+ log::error!("Failed to fetch WireGuard device config: {}", err);
+ TunnelError::GetConfigError
+ })?;
+ Ok(Stats::parse_device_message(&device))
}
fn set_config(
diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
index 070e3d1ee9..ba3bca14be 100644
--- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
+++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
@@ -28,7 +28,6 @@ pub struct NetworkManagerTunnel {
network_manager: NetworkManager,
tunnel: Option<WireguardTunnel>,
netlink_connections: Handle,
- tokio_handle: tokio::runtime::Handle,
interface_name: String,
}
@@ -58,12 +57,12 @@ impl NetworkManagerTunnel {
network_manager,
tunnel: Some(tunnel),
netlink_connections,
- tokio_handle,
interface_name,
})
}
}
+#[async_trait::async_trait]
impl Tunnel for NetworkManagerTunnel {
fn get_interface_name(&self) -> String {
self.interface_name.clone()
@@ -82,18 +81,16 @@ impl Tunnel for NetworkManagerTunnel {
}
}
- fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, TunnelError> {
+ async fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, TunnelError> {
let mut wg = self.netlink_connections.wg_handle.clone();
- self.tokio_handle.block_on(async move {
- let device = wg
- .get_by_name(self.interface_name.clone())
- .await
- .map_err(|err| {
- log::error!("Failed to fetch WireGuard device config: {}", err);
- TunnelError::GetConfigError
- })?;
- Ok(Stats::parse_device_message(&device))
- })
+ let device = wg
+ .get_by_name(self.interface_name.clone())
+ .await
+ .map_err(|err| {
+ log::error!("Failed to fetch WireGuard device config: {}", err);
+ TunnelError::GetConfigError
+ })?;
+ Ok(Stats::parse_device_message(&device))
}
fn set_config(
diff --git a/talpid-wireguard/src/wireguard_nt/mod.rs b/talpid-wireguard/src/wireguard_nt/mod.rs
index fefb7879e9..9243425cde 100644
--- a/talpid-wireguard/src/wireguard_nt/mod.rs
+++ b/talpid-wireguard/src/wireguard_nt/mod.rs
@@ -1037,13 +1037,20 @@ unsafe fn deserialize_config(
Ok((interface, peers))
}
+#[async_trait::async_trait]
impl Tunnel for WgNtTunnel {
fn get_interface_name(&self) -> String {
self.interface_name.clone()
}
- fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> {
- if let Some(ref device) = self.device {
+ async fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> {
+ let Some(ref device) = self.device else {
+ log::error!("Failed to obtain tunnel stats as device no longer exists");
+ return Err(super::TunnelError::GetConfigError);
+ };
+
+ let device = device.clone();
+ tokio::task::spawn_blocking(move || {
let mut map = StatsMap::new();
let (_interface, peers) = device.get_config().map_err(|error| {
log::error!(
@@ -1062,10 +1069,9 @@ impl Tunnel for WgNtTunnel {
);
}
Ok(map)
- } else {
- log::error!("Failed to obtain tunnel stats as device no longer exists");
- Err(super::TunnelError::GetConfigError)
- }
+ })
+ .await
+ .unwrap()
}
fn stop(mut self: Box<Self>) -> std::result::Result<(), super::TunnelError> {