summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-wireguard/Cargo.toml1
-rw-r--r--talpid-wireguard/src/mtu_detection.rs248
2 files changed, 210 insertions, 39 deletions
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index de7a6d4b19..5770fdb51f 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -82,3 +82,4 @@ features = [
[dev-dependencies]
proptest = "1.4"
+tokio = { workspace = true, features = ["time", "test-util"] } \ No newline at end of file
diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs
index 55523da08b..5132705719 100644
--- a/talpid-wireguard/src/mtu_detection.rs
+++ b/talpid-wireguard/src/mtu_detection.rs
@@ -1,6 +1,6 @@
use std::{io, net::IpAddr, time::Duration};
-use futures::{future, stream::FuturesUnordered, TryStreamExt};
+use futures::{future, stream::FuturesUnordered, Future, TryStreamExt};
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU};
use tokio_stream::StreamExt;
@@ -11,23 +11,31 @@ pub enum Error {
#[error("Failed to set MTU on the active tunnel")]
SetMtu(#[source] io::Error),
- /// Failed to set MTU
- #[error("Failed to detect MTU because every ping was dropped.")]
+ /// Failed to detect MTU because every ping was dropped
+ #[error("Failed to detect MTU because all pings timed out.")]
MtuDetectionAllDropped,
- /// Failed to set MTU
+ /// Failed to detect MTU because of unexpected ping error
#[error("Failed to detect MTU because of unexpected ping error.")]
- MtuDetectionPing(#[source] surge_ping::SurgeError),
+ MtuDetectionUnexpected(#[source] surge_ping::SurgeError),
- /// Failed to set MTU
+ /// Failed to detect MTU because of an IO error when setting up the ping socket
#[error("Failed to detect MTU because of an IO error when setting up the ping socket.")]
MtuDetectionSetupSocket(#[source] io::Error),
- /// Failed to set MTU
+ /// Failed to set buffer size
#[cfg(target_os = "macos")]
#[error("Failed to set buffer size")]
MtuSetBufferSize(#[source] nix::Error),
}
+
+/// Max time to wait for any ping, when this expires, we give up and throw an error.
+const PING_TIMEOUT: Duration = Duration::from_secs(10);
+/// Max time to wait after the first ping arrives. Every ping after this timeout is
+/// considered dropped, so we return the largest collected packet size.
+const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
+const MTU_STEP_SIZE: u16 = 20;
+
/// Verify that the current MTU doesn't cause dropped packets, otherwise lower it to the
/// largest value which doesn't.
///
@@ -89,14 +97,7 @@ async fn detect_mtu(
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
current_mtu: u16,
) -> Result<u16, Error> {
- /// Max time to wait for any ping, when this expires, we give up and throw an error.
- const PING_TIMEOUT: Duration = Duration::from_secs(10);
- /// Max time to wait after the first ping arrives. Every ping after this timeout is
- /// considered dropped, so we return the largest collected packet size.
- const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
-
- let step_size = 20;
- let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
+ let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, MTU_STEP_SIZE);
let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
#[cfg(any(target_os = "macos", target_os = "linux"))]
@@ -107,7 +108,7 @@ async fn detect_mtu(
// data we are sending here. The consequence will be dropped packets causing the MTU
// detection to set a low value. Here we manually increase this value, which fixes
// the problem.
- // TODO: Make sure this fix is not needed for any other target OS
+ // NOTE: If pings drop on other unix platforms too, then enable this fix for them
#[cfg(target_os = "macos")]
{
use nix::sys::socket::{setsockopt, sockopt};
@@ -117,37 +118,55 @@ async fn detect_mtu(
setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
}
+ // Shared buffer to reduce allocations
let payload_buf = vec![0; current_mtu as usize];
- let mut ping_stream = linspace
- .iter()
+ // Send a ping for each MTU in the linspace
+ let ping_stream = linspace
+ .into_iter()
.enumerate()
- .map(|(i, &mtu)| {
+ .map(|(sequence, mtu)| {
let client = client.clone();
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
let payload = &payload_buf[0..payload_size];
+ // Return a future that sends a ping of size MTU, receives the result, and returns the
+ // validated MTU
async move {
log::trace!("Sending ICMP ping of total size {mtu}");
- client
+ let (packet, _duration) = client
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
.await
.timeout(PING_TIMEOUT)
- .ping(PingSequence(i as u16), payload)
- .await
+ .ping(PingSequence(sequence as u16), payload)
+ .await?;
+
+ // Validate the received ping response
+ {
+ let surge_ping::IcmpPacket::V4(packet) = packet else {
+ unreachable!("ICMP ping response was not of IPv4 type");
+ };
+ let size = u16::try_from(packet.get_size())
+ .expect("ICMP packet size should fit in u16")
+ + IPV4_HEADER_SIZE;
+ log::trace!("Got ICMP ping response of total size {size}");
+ debug_assert_eq!(
+ size, mtu,
+ "Ping response should be of identical size to request"
+ );
+ }
+ Ok(mtu)
}
})
- .collect::<FuturesUnordered<_>>()
- .map_ok(|(packet, _rtt)| {
- let surge_ping::IcmpPacket::V4(packet) = packet else {
- unreachable!("ICMP ping response was not of IPv4 type");
- };
- let size = u16::try_from(packet.get_size()).expect("ICMP packet size should fit in 16")
- + IPV4_HEADER_SIZE;
- log::trace!("Got ICMP ping response of total size {size}");
- debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
- size
- });
+ .collect::<FuturesUnordered<_>>();
+ max_ping_size(ping_stream).await
+}
+
+/// Consumes a stream of pings, and returns the largest packet size within [`PING_OFFSET_TIMEOUT`]
+/// from the first ping response. Short circuits on errors.
+async fn max_ping_size(
+ mut ping_stream: FuturesUnordered<impl Future<Output = Result<u16, SurgeError>>>,
+) -> Result<u16, Error> {
let first_ping_size = ping_stream
.next()
.await
@@ -157,15 +176,15 @@ async fn detect_mtu(
// If the first ping we get back timed out, then all of them did
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
// Unexpected error type
- e => Error::MtuDetectionPing(e),
+ e => Error::MtuDetectionUnexpected(e),
})?;
ping_stream
- .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
- .map_while(|res| res.ok()) // Stop waiting for pings after this timeout
+ .timeout(PING_OFFSET_TIMEOUT) // Start the timeout after the first ping has arrived
+ .map_while(|res| res.ok()) // Stop waiting for more pings after this timeout
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
.await
- .map_err(Error::MtuDetectionPing)
+ .map_err(Error::MtuDetectionUnexpected)
}
/// Creates a linear spacing of MTU values with the given step size. Always includes the given
@@ -187,20 +206,171 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
#[cfg(test)]
mod tests {
- use super::mtu_spacing;
+ use super::*;
use proptest::prelude::*;
proptest! {
#[test]
- fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
+ fn mtu_spacing_properties(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);
+ // The MTU linspace should contain the end points exactly once
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
+ // It should be allocated with no wasted capacity
prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
+ // The spacing should be no greater than step size
let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
prop_assert!(diffs.all(|diff| diff <= step_size));
}
}
+
+ /// Tests for the timeout behavior described by [`PING_OFFSET_TIMEOUT`] and [`PING_TIMEOUT`].
+ ///
+ /// Note that time is mocked using [`tokio::time::pause`]. When all current tasks are sleeping,
+ /// the clock will auto advance until the next one wakes up,
+ /// see <https://docs.rs/tokio/latest/tokio/time/fn.pause.html#auto-advance> for details.
+ mod timeout {
+ use super::*;
+ use rand::{distributions::Uniform, thread_rng};
+ use std::pin::Pin;
+ use tokio::test;
+
+ // Convenience functions for creating dynamic ping futures, required by `FuturesUnordered`
+ // to manipulate the outcome and delay of mocked pings individually
+
+ /// Ping response that is available immediately
+ fn ready_ping<T: Send + 'static>(x: T) -> Pin<Box<dyn Future<Output = T>>> {
+ Box::pin(future::ready(x))
+ }
+
+ /// Ping response that is available immediately and wraps result in Ok()
+ fn ok_ping<T: Send + 'static, E: Send + 'static>(
+ t: T,
+ ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> {
+ ready_ping(Ok(t))
+ }
+
+ /// Ping response that is available immediately and wraps result in Err()
+ fn err_ping<T: Send + 'static, E: Send + 'static>(
+ e: E,
+ ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> {
+ ready_ping(Err(e))
+ }
+
+ /// Ping response that is delayed
+ fn delayed_ping<R: Send + 'static + Unpin>(
+ ret: R,
+ duration: Duration,
+ ) -> Pin<Box<dyn Future<Output = R>>> {
+ Box::pin(async move {
+ tokio::time::sleep(duration).await;
+ ret
+ })
+ }
+
+ /// The largest ping size should be chosen if all of them return, regardless of return
+ /// order.
+ #[test(start_paused = true)]
+ async fn all_pings_ok() {
+ let mut rng = thread_rng();
+ // Random delay for each ping, but within PING_OFFSET_TIMEOUT of the first
+ let uniform = Uniform::new(Duration::ZERO, PING_OFFSET_TIMEOUT);
+ let pings = (0..=100)
+ .map(|p| delayed_ping(Ok(p), rng.sample(uniform)))
+ .collect();
+ let max = max_ping_size(pings).await.unwrap();
+ assert_eq!(max, 100);
+ }
+
+ /// If pings arrive later than [`PING_OFFSET_TIMEOUT`] after the first ping, they should be
+ /// filtered out. The largest response before that point is chosen.
+ #[test(start_paused = true)]
+ async fn ping_timeout() {
+ let mut pings = FuturesUnordered::new();
+ let ok_pings = (0..=50).map(ok_ping);
+ pings.extend(ok_pings);
+ let dropped_pings = (51..=100)
+ .map(|p| delayed_ping(Ok(p), PING_OFFSET_TIMEOUT + Duration::from_secs(1)));
+ pings.extend(dropped_pings);
+
+ let max = max_ping_size(pings).await.unwrap();
+ assert_eq!(max, 50);
+ }
+
+ /// The [`PING_OFFSET_TIMEOUT`] is counted from the return of the first ping, not from the
+ /// function call. Test that if all pings arrive after PING_OFFSET_TIMEOUT, but close to
+ /// each other in time, the largest return value is chosen as normal.
+ #[test(start_paused = true)]
+ async fn delay_first_ping() {
+ let mut rng = thread_rng();
+ // Random delay for each ping, but within PING_OFFSET_TIMEOUT of the first and no sooner
+ // than 5s
+ let uniform = Uniform::new(
+ Duration::from_secs(5),
+ Duration::from_secs(5) + PING_OFFSET_TIMEOUT,
+ );
+ let pings = (0..=100)
+ .map(|p| delayed_ping(Ok(p), rng.sample(uniform)))
+ .collect();
+ let max = max_ping_size(pings).await.unwrap();
+ assert_eq!(max, 100);
+ }
+
+ /// If an unknown error type occurs, the MTU detection is aborted and that error is
+ /// propagated, even if some ping response came back ok.
+ #[test(start_paused = true)]
+ async fn unknown_error() {
+ let pings = FuturesUnordered::new();
+ pings.push(ok_ping(0));
+ pings.push(ok_ping(100));
+ pings.push(err_ping(SurgeError::NetworkError));
+
+ let e = max_ping_size(pings).await.unwrap_err();
+ assert!(matches!(
+ e,
+ Error::MtuDetectionUnexpected(SurgeError::NetworkError)
+ ));
+ }
+
+ /// An error of type [`SurgeError::Timeout`] signals that the total [`PING_TIMEOUT`] has
+ /// been reached. If this happens to the first ping we consider alls pings timed
+ /// out.
+ #[test(start_paused = true)]
+ async fn all_dropped() {
+ let pings = FuturesUnordered::new();
+ pings.push(delayed_ping(
+ Err(SurgeError::Timeout {
+ seq: PingSequence(0),
+ }),
+ PING_TIMEOUT,
+ ));
+ pings.push(delayed_ping(Ok(100), PING_TIMEOUT + Duration::from_secs(1)));
+
+ let e = max_ping_size(pings).await.unwrap_err();
+ assert!(matches!(e, Error::MtuDetectionAllDropped));
+ }
+
+ /// In the rare case that [`PING_TIMEOUT`] triggers before [`PING_OFFSET_TIMEOUT`], even
+ /// though some of the ping responses have come back, we still consider it abnormal
+ /// and choose to return an error instead of trusting result.
+ #[test(start_paused = true)]
+ async fn max_timeout_error() {
+ let pings = FuturesUnordered::new();
+ pings.push(delayed_ping(Ok(0), PING_TIMEOUT - Duration::from_secs(1)));
+ pings.push(delayed_ping(
+ Err(SurgeError::Timeout {
+ seq: PingSequence(0),
+ }),
+ PING_TIMEOUT,
+ ));
+
+ let e = max_ping_size(pings).await.unwrap_err();
+ assert!(matches!(
+ e,
+ Error::MtuDetectionUnexpected(SurgeError::Timeout { seq: _ })
+ ));
+ }
+ }
}