diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-02-27 11:04:16 +0100 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-02-27 12:30:51 +0100 |
| commit | 82b4467b0592919bff14fb527f7722d5c99d0dfd (patch) | |
| tree | 51fd099545414e0f442a4b385fcf3ba66ba3234e | |
| parent | fb42a23a39a9c0a1a65fd8a347a985783b3c4bf2 (diff) | |
| download | mullvadvpn-82b4467b0592919bff14fb527f7722d5c99d0dfd.tar.xz mullvadvpn-82b4467b0592919bff14fb527f7722d5c99d0dfd.zip | |
Make `max_ping_size` only take `FuturesUnordered` instead of being generic
| -rw-r--r-- | talpid-wireguard/src/mtu_detection.rs | 95 |
1 files changed, 64 insertions, 31 deletions
diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs index 6166c2b4b9..b4b114cdc5 100644 --- a/talpid-wireguard/src/mtu_detection.rs +++ b/talpid-wireguard/src/mtu_detection.rs @@ -1,9 +1,8 @@ use std::{io, net::IpAddr, time::Duration}; -use futures::{future, stream::FuturesUnordered, Stream, TryStreamExt}; +use futures::{future, stream::FuturesUnordered, Future, TryStreamExt}; use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError}; use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU}; -use tokio::pin; use tokio_stream::StreamExt; #[derive(thiserror::Error, Debug)] @@ -166,10 +165,9 @@ async fn detect_mtu( /// Consumes a stream of pings, and returns the largest packet size within a given timeout from the /// first ping response. Short circuits on errors. async fn max_ping_size( - ping_stream: impl Stream<Item = Result<u16, SurgeError>>, + mut ping_stream: FuturesUnordered<impl Future<Output = Result<u16, SurgeError>>>, ping_offset_timeout: Duration, ) -> Result<u16, Error> { - pin!(ping_stream); let first_ping_size = ping_stream .next() .await @@ -209,8 +207,12 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> { #[cfg(test)] mod tests { + use std::{ + marker::{Send, Unpin}, + pin::Pin, + }; + use super::*; - use futures::{stream, StreamExt}; use proptest::prelude::*; proptest! { @@ -230,39 +232,65 @@ mod tests { } } + fn ready_ping<T: Send + 'static>(x: T) -> Pin<Box<dyn Future<Output = T>>> { + Box::pin(future::ready(x)) + } + + fn ok_ping<T: Send + 'static, E: Send + 'static>( + x: T, + ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> { + ready_ping(Ok(x)) + } + + fn err_ping<T: Send + 'static, E: Send + 'static>( + e: E, + ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> { + ready_ping(Err(e)) + } + + fn delayed_ping<T: Send + 'static + Unpin>( + x: T, + duration: Duration, + ) -> Pin<Box<dyn Future<Output = T>>> { + Box::pin(async move { + tokio::time::sleep(duration).await; + x + }) + } + /// The largest ping size should be chosen if all of them return, regardless of return order. #[tokio::test] async fn all_pings_ok() { - let pings = stream::iter((0..=100).rev().map(Ok)); + let pings = (0..=100).rev().map(ok_ping).collect(); let max = max_ping_size(pings, Duration::from_millis(10)) .await .unwrap(); - assert_eq!(max, 100,); + assert_eq!(max, 100); } /// If one ping times out, all the following are considered timed out too. The largest response /// before that point is chosen. #[tokio::test] async fn ping_timeout() { - let pings = stream::iter((0..=99).map(Ok)).chain(stream::once(async { - futures::future::pending::<()>().await; - Ok(100) - })); - let max = max_ping_size(pings, Duration::from_millis(10)) + let mut pings = FuturesUnordered::new(); + let early_pings = (0..=50).map(ok_ping); + pings.extend(early_pings); + let late_pings = (51..=100).map(|p| delayed_ping(Ok(p), Duration::from_millis(10))); + pings.extend(late_pings); + + let max = max_ping_size(pings, Duration::from_millis(5)) .await .unwrap(); - assert_eq!(max, 99); + assert_eq!(max, 50); } /// The [`PING_OFFSET_TIMEOUT`] is counted from the return of the first ping, not from the /// function call. #[tokio::test] async fn delay_first_ping() { - let pings = stream::once(async { - tokio::time::sleep(Duration::from_millis(10)).await; - Ok(0) - }) - .chain(stream::iter((0..=100).map(Ok))); + let pings = (0..=100) + .map(|p| delayed_ping(Ok(p), Duration::from_millis(10))) + .collect(); let max = max_ping_size(pings, Duration::from_millis(5)) .await .unwrap(); @@ -273,7 +301,11 @@ mod tests { /// even if some ping response came back ok. #[tokio::test] async fn unknown_error() { - let pings = stream::iter([Ok(0), Err(SurgeError::NetworkError), Ok(10)]); + let pings = FuturesUnordered::new(); + pings.push(ok_ping(0)); + pings.push(err_ping(SurgeError::NetworkError)); + pings.push(ok_ping(10)); + let e = max_ping_size(pings, Duration::from_millis(10)) .await .unwrap_err(); @@ -287,13 +319,12 @@ mod tests { /// reached. If this happens to the first ping we consider alls pings timed out. #[tokio::test] async fn all_dropped() { - let pings = stream::iter([ - Err(SurgeError::Timeout { - seq: PingSequence(0), - }), - Ok(0), - Ok(10), - ]); + let pings = FuturesUnordered::new(); + pings.push(err_ping(SurgeError::Timeout { + seq: PingSequence(0), + })); + pings.push(delayed_ping(Ok(10), Duration::from_millis(10))); + let e = max_ping_size(pings, Duration::from_millis(10)) .await .unwrap_err(); @@ -305,14 +336,16 @@ mod tests { /// return an error instead of trusting result. #[tokio::test] async fn max_timeout_error() { - let pings = stream::iter([ - Ok(0), + let pings = FuturesUnordered::new(); + pings.push(delayed_ping(Ok(0), Duration::from_millis(9))); + pings.push(delayed_ping( Err(SurgeError::Timeout { seq: PingSequence(0), }), - Ok(10), - ]); - let e = max_ping_size(pings, Duration::from_millis(10)) + Duration::from_millis(10), + )); + + let e = max_ping_size(pings, Duration::from_millis(5)) .await .unwrap_err(); assert!(matches!( |
