summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-02-27 11:04:16 +0100
committerSebastian Holmin <sebastian.holmin@mullvad.net>2024-02-27 12:30:51 +0100
commit82b4467b0592919bff14fb527f7722d5c99d0dfd (patch)
tree51fd099545414e0f442a4b385fcf3ba66ba3234e
parentfb42a23a39a9c0a1a65fd8a347a985783b3c4bf2 (diff)
downloadmullvadvpn-82b4467b0592919bff14fb527f7722d5c99d0dfd.tar.xz
mullvadvpn-82b4467b0592919bff14fb527f7722d5c99d0dfd.zip
Make `max_ping_size` only take `FuturesUnordered` instead of being generic
-rw-r--r--talpid-wireguard/src/mtu_detection.rs95
1 files changed, 64 insertions, 31 deletions
diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs
index 6166c2b4b9..b4b114cdc5 100644
--- a/talpid-wireguard/src/mtu_detection.rs
+++ b/talpid-wireguard/src/mtu_detection.rs
@@ -1,9 +1,8 @@
use std::{io, net::IpAddr, time::Duration};
-use futures::{future, stream::FuturesUnordered, Stream, TryStreamExt};
+use futures::{future, stream::FuturesUnordered, Future, TryStreamExt};
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU};
-use tokio::pin;
use tokio_stream::StreamExt;
#[derive(thiserror::Error, Debug)]
@@ -166,10 +165,9 @@ async fn detect_mtu(
/// Consumes a stream of pings, and returns the largest packet size within a given timeout from the
/// first ping response. Short circuits on errors.
async fn max_ping_size(
- ping_stream: impl Stream<Item = Result<u16, SurgeError>>,
+ mut ping_stream: FuturesUnordered<impl Future<Output = Result<u16, SurgeError>>>,
ping_offset_timeout: Duration,
) -> Result<u16, Error> {
- pin!(ping_stream);
let first_ping_size = ping_stream
.next()
.await
@@ -209,8 +207,12 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
#[cfg(test)]
mod tests {
+ use std::{
+ marker::{Send, Unpin},
+ pin::Pin,
+ };
+
use super::*;
- use futures::{stream, StreamExt};
use proptest::prelude::*;
proptest! {
@@ -230,39 +232,65 @@ mod tests {
}
}
+ fn ready_ping<T: Send + 'static>(x: T) -> Pin<Box<dyn Future<Output = T>>> {
+ Box::pin(future::ready(x))
+ }
+
+ fn ok_ping<T: Send + 'static, E: Send + 'static>(
+ x: T,
+ ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> {
+ ready_ping(Ok(x))
+ }
+
+ fn err_ping<T: Send + 'static, E: Send + 'static>(
+ e: E,
+ ) -> Pin<Box<dyn Future<Output = Result<T, E>>>> {
+ ready_ping(Err(e))
+ }
+
+ fn delayed_ping<T: Send + 'static + Unpin>(
+ x: T,
+ duration: Duration,
+ ) -> Pin<Box<dyn Future<Output = T>>> {
+ Box::pin(async move {
+ tokio::time::sleep(duration).await;
+ x
+ })
+ }
+
/// The largest ping size should be chosen if all of them return, regardless of return order.
#[tokio::test]
async fn all_pings_ok() {
- let pings = stream::iter((0..=100).rev().map(Ok));
+ let pings = (0..=100).rev().map(ok_ping).collect();
let max = max_ping_size(pings, Duration::from_millis(10))
.await
.unwrap();
- assert_eq!(max, 100,);
+ assert_eq!(max, 100);
}
/// If one ping times out, all the following are considered timed out too. The largest response
/// before that point is chosen.
#[tokio::test]
async fn ping_timeout() {
- let pings = stream::iter((0..=99).map(Ok)).chain(stream::once(async {
- futures::future::pending::<()>().await;
- Ok(100)
- }));
- let max = max_ping_size(pings, Duration::from_millis(10))
+ let mut pings = FuturesUnordered::new();
+ let early_pings = (0..=50).map(ok_ping);
+ pings.extend(early_pings);
+ let late_pings = (51..=100).map(|p| delayed_ping(Ok(p), Duration::from_millis(10)));
+ pings.extend(late_pings);
+
+ let max = max_ping_size(pings, Duration::from_millis(5))
.await
.unwrap();
- assert_eq!(max, 99);
+ assert_eq!(max, 50);
}
/// The [`PING_OFFSET_TIMEOUT`] is counted from the return of the first ping, not from the
/// function call.
#[tokio::test]
async fn delay_first_ping() {
- let pings = stream::once(async {
- tokio::time::sleep(Duration::from_millis(10)).await;
- Ok(0)
- })
- .chain(stream::iter((0..=100).map(Ok)));
+ let pings = (0..=100)
+ .map(|p| delayed_ping(Ok(p), Duration::from_millis(10)))
+ .collect();
let max = max_ping_size(pings, Duration::from_millis(5))
.await
.unwrap();
@@ -273,7 +301,11 @@ mod tests {
/// even if some ping response came back ok.
#[tokio::test]
async fn unknown_error() {
- let pings = stream::iter([Ok(0), Err(SurgeError::NetworkError), Ok(10)]);
+ let pings = FuturesUnordered::new();
+ pings.push(ok_ping(0));
+ pings.push(err_ping(SurgeError::NetworkError));
+ pings.push(ok_ping(10));
+
let e = max_ping_size(pings, Duration::from_millis(10))
.await
.unwrap_err();
@@ -287,13 +319,12 @@ mod tests {
/// reached. If this happens to the first ping we consider alls pings timed out.
#[tokio::test]
async fn all_dropped() {
- let pings = stream::iter([
- Err(SurgeError::Timeout {
- seq: PingSequence(0),
- }),
- Ok(0),
- Ok(10),
- ]);
+ let pings = FuturesUnordered::new();
+ pings.push(err_ping(SurgeError::Timeout {
+ seq: PingSequence(0),
+ }));
+ pings.push(delayed_ping(Ok(10), Duration::from_millis(10)));
+
let e = max_ping_size(pings, Duration::from_millis(10))
.await
.unwrap_err();
@@ -305,14 +336,16 @@ mod tests {
/// return an error instead of trusting result.
#[tokio::test]
async fn max_timeout_error() {
- let pings = stream::iter([
- Ok(0),
+ let pings = FuturesUnordered::new();
+ pings.push(delayed_ping(Ok(0), Duration::from_millis(9)));
+ pings.push(delayed_ping(
Err(SurgeError::Timeout {
seq: PingSequence(0),
}),
- Ok(10),
- ]);
- let e = max_ping_size(pings, Duration::from_millis(10))
+ Duration::from_millis(10),
+ ));
+
+ let e = max_ping_size(pings, Duration::from_millis(5))
.await
.unwrap_err();
assert!(matches!(