summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-wireguard/src/lib.rs193
-rw-r--r--talpid-wireguard/src/mtu_detection.rs193
2 files changed, 209 insertions, 177 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index ee801559f7..78843570be 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -26,8 +26,6 @@ use talpid_routing as routing;
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
-#[cfg(not(target_os = "android"))]
-use talpid_tunnel::IPV4_HEADER_SIZE;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use ipnetwork::IpNetwork;
@@ -59,6 +57,9 @@ pub(crate) mod wireguard_kernel;
#[cfg(windows)]
mod wireguard_nt;
+#[cfg(not(target_os = "android"))]
+mod mtu_detection;
+
#[cfg(wireguard_go)]
use self::wireguard_go::WgGoTunnel;
@@ -73,19 +74,6 @@ pub enum Error {
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] talpid_routing::Error),
- /// Failed to set MTU
- #[error(display = "Failed to detect MTU because every ping was dropped.")]
- MtuDetectionAllDropped,
-
- /// Failed to set MTU
- #[error(display = "Failed to detect MTU because of unexpected ping error.")]
- MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),
-
- /// Failed to set MTU
- #[cfg(target_os = "macos")]
- #[error(display = "Failed to set buffer size")]
- MtuSetBufferSize(#[error(source)] nix::Error),
-
/// Tunnel timed out
#[error(display = "Tunnel timed out")]
TimeoutError,
@@ -396,45 +384,25 @@ impl WireguardMonitor {
#[cfg(not(target_os = "android"))]
if detect_mtu {
- let iface_name_clone = iface_name.clone();
+ let config = config.clone();
+ let iface_name = iface_name.clone();
tokio::task::spawn(async move {
- log::debug!("Starting MTU detection");
- let verified_mtu = match auto_mtu_detection(
+ if let Err(e) = mtu_detection::automatic_mtu_correction(
gateway,
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- iface_name_clone.clone(),
+ iface_name,
config.mtu,
+ #[cfg(windows)]
+ config.ipv6_gateway.is_some(),
)
.await
{
- Ok(mtu) => mtu,
- Err(e) => {
- log::error!("{}", e.display_chain_with_msg("Failed to detect MTU"));
- return;
- }
- };
-
- if verified_mtu != config.mtu {
- log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
- #[cfg(any(target_os = "linux", target_os = "macos"))]
- let res = unix::set_mtu(&iface_name_clone, verified_mtu);
- #[cfg(windows)]
- let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
- |luid| {
- talpid_windows::net::set_mtu(
- luid,
- verified_mtu as u32,
- config.ipv6_gateway.is_some(),
- )
- },
+ log::error!(
+ "{}",
+ e.display_chain_with_msg(
+ "Failed to automatically adjust MTU based on dropped packets"
+ )
);
-
- if let Err(e) = res {
- log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
- };
- } else {
- log::debug!("MTU {verified_mtu} verified to not drop packets");
- }
+ };
});
}
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
@@ -956,7 +924,7 @@ impl WireguardMonitor {
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
- use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
+ use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
if !config.is_multihop() {
route
@@ -1009,135 +977,6 @@ impl WireguardMonitor {
}
}
-/// Detects the maximum MTU that does not cause dropped packets.
-///
-/// The detection works by sending evenly spread out range of pings between 576 and the given
-/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
-#[cfg(not(target_os = "android"))]
-async fn auto_mtu_detection(
- gateway: std::net::Ipv4Addr,
- #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
- current_mtu: u16,
-) -> Result<u16> {
- use futures::{future, stream::FuturesUnordered, TryStreamExt};
- use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
- use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
- use tokio_stream::StreamExt;
-
- /// Max time to wait for any ping, when this expires, we give up and throw an error.
- const PING_TIMEOUT: Duration = Duration::from_secs(10);
- /// Max time to wait after the first ping arrives. Every ping after this timeout is considered
- /// dropped, so we return the largest collected packet size.
- const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
-
- let step_size = 20;
- let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
-
- let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- let config_builder = config_builder.interface(&iface_name);
- let client = Client::new(&config_builder.build()).unwrap();
- // For macos, the default socket receive buffer size seems to be too small to handle the data we
- // are sending here. The consequence will be dropped packets causing the MTU detection to set a
- // low value. Here we manually increase this value, which fixes the problem.
- // TODO: Make sure this fix is not needed for any other target OS
- #[cfg(target_os = "macos")]
- {
- use nix::sys::socket::{setsockopt, sockopt};
- let fd = client.get_socket().get_native_sock();
- let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum();
- setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
- setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
- }
-
- let payload_buf = vec![0; current_mtu as usize];
-
- let mut ping_stream = linspace
- .iter()
- .enumerate()
- .map(|(i, &mtu)| {
- let client = client.clone();
- let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
- let payload = &payload_buf[0..payload_size];
- async move {
- log::trace!("Sending ICMP ping of total size {mtu}");
- client
- .pinger(IpAddr::V4(gateway), PingIdentifier(0))
- .await
- .timeout(PING_TIMEOUT)
- .ping(PingSequence(i as u16), payload)
- .await
- }
- })
- .collect::<FuturesUnordered<_>>()
- .map_ok(|(packet, _rtt)| {
- let surge_ping::IcmpPacket::V4(packet) = packet else {
- unreachable!("ICMP ping response was not of IPv4 type");
- };
- let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
- log::trace!("Got ICMP ping response of total size {size}");
- debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
- size
- });
-
- let first_ping_size = ping_stream
- .next()
- .await
- .expect("At least one pings should be sent")
- // Short-circuit and return on error
- .map_err(|e| match e {
- // If the first ping we get back timed out, then all of them did
- SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
- // Unexpected error type
- e => Error::MtuDetectionPingError(e),
- })?;
-
- ping_stream
- .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
- .map_while(|res| res.ok()) // Stop waiting for pings after this timeout
- .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
- .await
- .map_err(Error::MtuDetectionPingError)
-}
-
-/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
-/// points.
-#[cfg(not(target_os = "android"))]
-fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
- assert!(mtu_min < mtu_max);
- assert!(step_size < mtu_max);
- assert_ne!(step_size, 0);
-
- let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
- let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
-
- let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
- ret.push(mtu_min);
- ret.extend(in_between);
- ret.push(mtu_max);
- ret
-}
-
-#[cfg(all(test, not(target_os = "android")))]
-mod tests {
- use crate::mtu_spacing;
- use proptest::prelude::*;
-
- proptest! {
- #[test]
- fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
- let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);
-
- prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
- prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
- prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
- let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
- prop_assert!(diffs.all(|diff| diff <= step_size));
-
- }
- }
-}
-
#[derive(Debug)]
enum CloseMsg {
Stop,
diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs
new file mode 100644
index 0000000000..139dfbc9b5
--- /dev/null
+++ b/talpid-wireguard/src/mtu_detection.rs
@@ -0,0 +1,193 @@
+use std::{io, net::IpAddr, time::Duration};
+
+use futures::{future, stream::FuturesUnordered, TryStreamExt};
+use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
+use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU};
+use tokio_stream::StreamExt;
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to set MTU on the active tunnel
+ #[error(display = "Failed to set MTU on the active tunnel")]
+ SetMtu(#[error(source)] io::Error),
+
+ /// Failed to set MTU
+ #[error(display = "Failed to detect MTU because every ping was dropped.")]
+ MtuDetectionAllDropped,
+
+ /// Failed to set MTU
+ #[error(display = "Failed to detect MTU because of unexpected ping error.")]
+ MtuDetectionPing(#[error(source)] surge_ping::SurgeError),
+
+ /// Failed to set MTU
+ #[error(
+ display = "Failed to detect MTU because of an IO error when setting up the ping socket."
+ )]
+ MtuDetectionSetupSocket(#[error(source)] io::Error),
+
+ /// Failed to set MTU
+ #[cfg(target_os = "macos")]
+ #[error(display = "Failed to set buffer size")]
+ MtuSetBufferSize(#[error(source)] nix::Error),
+}
+/// Verify that the current MTU doesn't cause dropped packets, otherwise lower it to the
+/// largest value which doesn't.
+///
+/// Note: This does not take fragmentation into account, so it should only be used as an extra
+/// safety measure after the normal MTU calculation using header sizes and safety margins.
+pub async fn automatic_mtu_correction(
+ gateway: std::net::Ipv4Addr,
+ iface_name: String,
+ current_tunnel_mtu: u16,
+ #[cfg(windows)] ipv6: bool,
+) -> Result<(), Error> {
+ log::debug!("Starting MTU detection");
+ let verified_mtu = detect_mtu(
+ gateway,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ iface_name.clone(),
+ current_tunnel_mtu,
+ )
+ .await?;
+
+ if verified_mtu != current_tunnel_mtu {
+ log::warn!("Lowering MTU from {} to {verified_mtu}", current_tunnel_mtu);
+
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ crate::unix::set_mtu(&iface_name, verified_mtu).map_err(Error::SetMtu)?;
+ #[cfg(windows)]
+ talpid_windows::net::luid_from_alias(iface_name)
+ .and_then(|luid| talpid_windows::net::set_mtu(luid, verified_mtu as u32, ipv6))
+ .map_err(Error::SetMtu)?;
+ } else {
+ log::debug!("MTU {verified_mtu} verified to not drop packets");
+ };
+ Ok(())
+}
+
+/// Detects the maximum MTU that does not cause dropped packets.
+///
+/// The detection works by sending evenly spread out range of pings between 576 and the given
+/// current tunnel MTU, and returning the maximum packet size that was returned within a
+/// timeout.
+async fn detect_mtu(
+ gateway: std::net::Ipv4Addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
+ current_mtu: u16,
+) -> Result<u16, Error> {
+ /// Max time to wait for any ping, when this expires, we give up and throw an error.
+ const PING_TIMEOUT: Duration = Duration::from_secs(10);
+ /// Max time to wait after the first ping arrives. Every ping after this timeout is
+ /// considered dropped, so we return the largest collected packet size.
+ const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
+
+ let step_size = 20;
+ let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
+
+ let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ let config_builder = config_builder.interface(&iface_name);
+ let client = Client::new(&config_builder.build()).map_err(Error::MtuDetectionSetupSocket)?;
+
+ // For macos, the default socket receive buffer size seems to be too small to handle the
+ // data we are sending here. The consequence will be dropped packets causing the MTU
+ // detection to set a low value. Here we manually increase this value, which fixes
+ // the problem.
+ // TODO: Make sure this fix is not needed for any other target OS
+ #[cfg(target_os = "macos")]
+ {
+ use nix::sys::socket::{setsockopt, sockopt};
+ let fd = client.get_socket().get_native_sock();
+ let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum();
+ setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
+ setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
+ }
+
+ let payload_buf = vec![0; current_mtu as usize];
+
+ let mut ping_stream = linspace
+ .iter()
+ .enumerate()
+ .map(|(i, &mtu)| {
+ let client = client.clone();
+ let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
+ let payload = &payload_buf[0..payload_size];
+ async move {
+ log::trace!("Sending ICMP ping of total size {mtu}");
+ client
+ .pinger(IpAddr::V4(gateway), PingIdentifier(0))
+ .await
+ .timeout(PING_TIMEOUT)
+ .ping(PingSequence(i as u16), payload)
+ .await
+ }
+ })
+ .collect::<FuturesUnordered<_>>()
+ .map_ok(|(packet, _rtt)| {
+ let surge_ping::IcmpPacket::V4(packet) = packet else {
+ unreachable!("ICMP ping response was not of IPv4 type");
+ };
+ let size = u16::try_from(packet.get_size()).expect("ICMP packet size should fit in 16")
+ + IPV4_HEADER_SIZE;
+ log::trace!("Got ICMP ping response of total size {size}");
+ debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
+ size
+ });
+
+ let first_ping_size = ping_stream
+ .next()
+ .await
+ .expect("At least one pings should be sent")
+ // Short-circuit and return on error
+ .map_err(|e| match e {
+ // If the first ping we get back timed out, then all of them did
+ SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
+ // Unexpected error type
+ e => Error::MtuDetectionPing(e),
+ })?;
+
+ ping_stream
+ .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
+ .map_while(|res| res.ok()) // Stop waiting for pings after this timeout
+ .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
+ .await
+ .map_err(Error::MtuDetectionPing)
+}
+
+/// Creates a linear spacing of MTU values with the given step size. Always includes the given
+/// end points.
+fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
+ assert!(mtu_min < mtu_max);
+ assert!(step_size < mtu_max);
+ assert_ne!(step_size, 0);
+
+ let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
+ let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
+
+ let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
+ ret.push(mtu_min);
+ ret.extend(in_between);
+ ret.push(mtu_max);
+ ret
+}
+
+#[cfg(test)]
+mod tests {
+ use super::mtu_spacing;
+ use proptest::prelude::*;
+
+ proptest! {
+ #[test]
+ fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
+ let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);
+
+ prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
+ prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
+ prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
+ let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
+ prop_assert!(diffs.all(|diff| diff <= step_size));
+
+ }
+ }
+}