diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-02-13 11:39:01 +0100 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-02-27 12:30:51 +0100 |
| commit | fb42a23a39a9c0a1a65fd8a347a985783b3c4bf2 (patch) | |
| tree | b72483c955bfe0af6c693180fd72b8b466fb938c | |
| parent | 9223ccf44e25369221bf753e12c92a5668dac5fe (diff) | |
| download | mullvadvpn-fb42a23a39a9c0a1a65fd8a347a985783b3c4bf2.tar.xz mullvadvpn-fb42a23a39a9c0a1a65fd8a347a985783b3c4bf2.zip | |
Add unit test for MTU detection
Split MTU detection into an inner pure function `max_ping_sized`
and an outer function `detect_mtu` and add unit and prop-testing to
the non io-dependent parts.
| -rw-r--r-- | talpid-wireguard/src/mtu_detection.rs | 195 |
1 files changed, 156 insertions, 39 deletions
diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs index 55523da08b..6166c2b4b9 100644 --- a/talpid-wireguard/src/mtu_detection.rs +++ b/talpid-wireguard/src/mtu_detection.rs @@ -1,8 +1,9 @@ use std::{io, net::IpAddr, time::Duration}; -use futures::{future, stream::FuturesUnordered, TryStreamExt}; +use futures::{future, stream::FuturesUnordered, Stream, TryStreamExt}; use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError}; use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU}; +use tokio::pin; use tokio_stream::StreamExt; #[derive(thiserror::Error, Debug)] @@ -11,23 +12,31 @@ pub enum Error { #[error("Failed to set MTU on the active tunnel")] SetMtu(#[source] io::Error), - /// Failed to set MTU - #[error("Failed to detect MTU because every ping was dropped.")] + /// Failed to detect MTU because every ping was dropped + #[error("Failed to detect MTU because all pings timed out.")] MtuDetectionAllDropped, - /// Failed to set MTU + /// Failed to detect MTU because of unexpected ping error #[error("Failed to detect MTU because of unexpected ping error.")] - MtuDetectionPing(#[source] surge_ping::SurgeError), + MtuDetectionUnexpected(#[source] surge_ping::SurgeError), - /// Failed to set MTU + /// Failed to detect MTU because of an IO error when setting up the ping socket #[error("Failed to detect MTU because of an IO error when setting up the ping socket.")] MtuDetectionSetupSocket(#[source] io::Error), - /// Failed to set MTU + /// Failed to set buffer size #[cfg(target_os = "macos")] #[error("Failed to set buffer size")] MtuSetBufferSize(#[source] nix::Error), } + +/// Max time to wait for any ping, when this expires, we give up and throw an error. +const PING_TIMEOUT: Duration = Duration::from_secs(10); +/// Max time to wait after the first ping arrives. Every ping after this timeout is +/// considered dropped, so we return the largest collected packet size. +const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2); +const MTU_STEP_SIZE: u16 = 20; + /// Verify that the current MTU doesn't cause dropped packets, otherwise lower it to the /// largest value which doesn't. /// @@ -89,14 +98,7 @@ async fn detect_mtu( #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String, current_mtu: u16, ) -> Result<u16, Error> { - /// Max time to wait for any ping, when this expires, we give up and throw an error. - const PING_TIMEOUT: Duration = Duration::from_secs(10); - /// Max time to wait after the first ping arrives. Every ping after this timeout is - /// considered dropped, so we return the largest collected packet size. - const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2); - - let step_size = 20; - let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size); + let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, MTU_STEP_SIZE); let config_builder = Config::builder().kind(surge_ping::ICMP::V4); #[cfg(any(target_os = "macos", target_os = "linux"))] @@ -107,7 +109,7 @@ async fn detect_mtu( // data we are sending here. The consequence will be dropped packets causing the MTU // detection to set a low value. Here we manually increase this value, which fixes // the problem. - // TODO: Make sure this fix is not needed for any other target OS + // NOTE: If pings drop on other unix platforms too, then enable this fix for them #[cfg(target_os = "macos")] { use nix::sys::socket::{setsockopt, sockopt}; @@ -117,37 +119,57 @@ async fn detect_mtu( setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?; } + // Shared buffer to reduce allocations let payload_buf = vec![0; current_mtu as usize]; - let mut ping_stream = linspace - .iter() + // Send a ping for each MTU in the linspace + let ping_stream = linspace + .into_iter() .enumerate() - .map(|(i, &mtu)| { + .map(|(sequence, mtu)| { let client = client.clone(); let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize; let payload = &payload_buf[0..payload_size]; + // Return a future that sends a ping of size MTU, receives the result, and returns the + // validated MTU async move { log::trace!("Sending ICMP ping of total size {mtu}"); - client + let (packet, _duration) = client .pinger(IpAddr::V4(gateway), PingIdentifier(0)) .await .timeout(PING_TIMEOUT) - .ping(PingSequence(i as u16), payload) - .await + .ping(PingSequence(sequence as u16), payload) + .await?; + + // Validate the received ping response + { + let surge_ping::IcmpPacket::V4(packet) = packet else { + unreachable!("ICMP ping response was not of IPv4 type"); + }; + let size = u16::try_from(packet.get_size()) + .expect("ICMP packet size should fit in u16") + + IPV4_HEADER_SIZE; + log::trace!("Got ICMP ping response of total size {size}"); + debug_assert_eq!( + size, mtu, + "Ping response should be of identical size to request" + ); + } + Ok(mtu) } }) - .collect::<FuturesUnordered<_>>() - .map_ok(|(packet, _rtt)| { - let surge_ping::IcmpPacket::V4(packet) = packet else { - unreachable!("ICMP ping response was not of IPv4 type"); - }; - let size = u16::try_from(packet.get_size()).expect("ICMP packet size should fit in 16") - + IPV4_HEADER_SIZE; - log::trace!("Got ICMP ping response of total size {size}"); - debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]); - size - }); + .collect::<FuturesUnordered<_>>(); + max_ping_size(ping_stream, PING_OFFSET_TIMEOUT).await +} + +/// Consumes a stream of pings, and returns the largest packet size within a given timeout from the +/// first ping response. Short circuits on errors. +async fn max_ping_size( + ping_stream: impl Stream<Item = Result<u16, SurgeError>>, + ping_offset_timeout: Duration, +) -> Result<u16, Error> { + pin!(ping_stream); let first_ping_size = ping_stream .next() .await @@ -157,15 +179,15 @@ async fn detect_mtu( // If the first ping we get back timed out, then all of them did SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped, // Unexpected error type - e => Error::MtuDetectionPing(e), + e => Error::MtuDetectionUnexpected(e), })?; ping_stream - .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout - .map_while(|res| res.ok()) // Stop waiting for pings after this timeout + .timeout(ping_offset_timeout) // Start the timeout after the first ping has arrived + .map_while(|res| res.ok()) // Stop waiting for more pings after this timeout .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping .await - .map_err(Error::MtuDetectionPing) + .map_err(Error::MtuDetectionUnexpected) } /// Creates a linear spacing of MTU values with the given step size. Always includes the given @@ -187,20 +209,115 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> { #[cfg(test)] mod tests { - use super::mtu_spacing; + use super::*; + use futures::{stream, StreamExt}; use proptest::prelude::*; proptest! { #[test] - fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) { + fn mtu_spacing_properties(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) { let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size); + // The MTU linspace should contain the end points exactly once prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1); prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1); + // It should be allocated with no wasted capacity prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len()); + // The spacing should be no greater than step size let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]); prop_assert!(diffs.all(|diff| diff <= step_size)); } } + + /// The largest ping size should be chosen if all of them return, regardless of return order. + #[tokio::test] + async fn all_pings_ok() { + let pings = stream::iter((0..=100).rev().map(Ok)); + let max = max_ping_size(pings, Duration::from_millis(10)) + .await + .unwrap(); + assert_eq!(max, 100,); + } + + /// If one ping times out, all the following are considered timed out too. The largest response + /// before that point is chosen. + #[tokio::test] + async fn ping_timeout() { + let pings = stream::iter((0..=99).map(Ok)).chain(stream::once(async { + futures::future::pending::<()>().await; + Ok(100) + })); + let max = max_ping_size(pings, Duration::from_millis(10)) + .await + .unwrap(); + assert_eq!(max, 99); + } + + /// The [`PING_OFFSET_TIMEOUT`] is counted from the return of the first ping, not from the + /// function call. + #[tokio::test] + async fn delay_first_ping() { + let pings = stream::once(async { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(0) + }) + .chain(stream::iter((0..=100).map(Ok))); + let max = max_ping_size(pings, Duration::from_millis(5)) + .await + .unwrap(); + assert_eq!(max, 100); + } + + /// If an unknown error type occurs, the MTU detection is aborted and that error is propagated, + /// even if some ping response came back ok. + #[tokio::test] + async fn unknown_error() { + let pings = stream::iter([Ok(0), Err(SurgeError::NetworkError), Ok(10)]); + let e = max_ping_size(pings, Duration::from_millis(10)) + .await + .unwrap_err(); + assert!(matches!( + e, + Error::MtuDetectionUnexpected(SurgeError::NetworkError) + )); + } + + /// An error of type [`SurgeError::Timeout`] signals that the total [`PING_TIMEOUT`] has been + /// reached. If this happens to the first ping we consider alls pings timed out. + #[tokio::test] + async fn all_dropped() { + let pings = stream::iter([ + Err(SurgeError::Timeout { + seq: PingSequence(0), + }), + Ok(0), + Ok(10), + ]); + let e = max_ping_size(pings, Duration::from_millis(10)) + .await + .unwrap_err(); + assert!(matches!(e, Error::MtuDetectionAllDropped)); + } + + /// In the rare case that [`PING_TIMEOUT`] triggers before [`PING_OFFSET_TIMEOUT`], even though + /// some of the ping responses have come back, we still consider it abnormal and choose to + /// return an error instead of trusting result. + #[tokio::test] + async fn max_timeout_error() { + let pings = stream::iter([ + Ok(0), + Err(SurgeError::Timeout { + seq: PingSequence(0), + }), + Ok(10), + ]); + let e = max_ping_size(pings, Duration::from_millis(10)) + .await + .unwrap_err(); + assert!(matches!( + e, + Error::MtuDetectionUnexpected(SurgeError::Timeout { seq: _ }) + )); + } } |
