diff options
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 193 | ||||
| -rw-r--r-- | talpid-wireguard/src/mtu_detection.rs | 193 |
2 files changed, 209 insertions, 177 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index ee801559f7..78843570be 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -26,8 +26,6 @@ use talpid_routing as routing; use talpid_routing::{self, RequiredRoute}; #[cfg(not(windows))] use talpid_tunnel::tun_provider; -#[cfg(not(target_os = "android"))] -use talpid_tunnel::IPV4_HEADER_SIZE; use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use ipnetwork::IpNetwork; @@ -59,6 +57,9 @@ pub(crate) mod wireguard_kernel; #[cfg(windows)] mod wireguard_nt; +#[cfg(not(target_os = "android"))] +mod mtu_detection; + #[cfg(wireguard_go)] use self::wireguard_go::WgGoTunnel; @@ -73,19 +74,6 @@ pub enum Error { #[error(display = "Failed to setup routing")] SetupRoutingError(#[error(source)] talpid_routing::Error), - /// Failed to set MTU - #[error(display = "Failed to detect MTU because every ping was dropped.")] - MtuDetectionAllDropped, - - /// Failed to set MTU - #[error(display = "Failed to detect MTU because of unexpected ping error.")] - MtuDetectionPingError(#[error(source)] surge_ping::SurgeError), - - /// Failed to set MTU - #[cfg(target_os = "macos")] - #[error(display = "Failed to set buffer size")] - MtuSetBufferSize(#[error(source)] nix::Error), - /// Tunnel timed out #[error(display = "Tunnel timed out")] TimeoutError, @@ -396,45 +384,25 @@ impl WireguardMonitor { #[cfg(not(target_os = "android"))] if detect_mtu { - let iface_name_clone = iface_name.clone(); + let config = config.clone(); + let iface_name = iface_name.clone(); tokio::task::spawn(async move { - log::debug!("Starting MTU detection"); - let verified_mtu = match auto_mtu_detection( + if let Err(e) = mtu_detection::automatic_mtu_correction( gateway, - #[cfg(any(target_os = "macos", target_os = "linux"))] - iface_name_clone.clone(), + iface_name, config.mtu, + #[cfg(windows)] + config.ipv6_gateway.is_some(), ) .await { - Ok(mtu) => mtu, - Err(e) => { - log::error!("{}", e.display_chain_with_msg("Failed to detect MTU")); - return; - } - }; - - if verified_mtu != config.mtu { - log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu); - #[cfg(any(target_os = "linux", target_os = "macos"))] - let res = unix::set_mtu(&iface_name_clone, verified_mtu); - #[cfg(windows)] - let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then( - |luid| { - talpid_windows::net::set_mtu( - luid, - verified_mtu as u32, - config.ipv6_gateway.is_some(), - ) - }, + log::error!( + "{}", + e.display_chain_with_msg( + "Failed to automatically adjust MTU based on dropped packets" + ) ); - - if let Err(e) = res { - log::error!("{}", e.display_chain_with_msg("Failed to set MTU")) - }; - } else { - log::debug!("MTU {verified_mtu} verified to not drop packets"); - } + }; }); } let mut connectivity_monitor = tokio::task::spawn_blocking(move || { @@ -956,7 +924,7 @@ impl WireguardMonitor { #[cfg(any(target_os = "linux", target_os = "macos"))] fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute { - use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE}; + use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE}; if !config.is_multihop() { route @@ -1009,135 +977,6 @@ impl WireguardMonitor { } } -/// Detects the maximum MTU that does not cause dropped packets. -/// -/// The detection works by sending evenly spread out range of pings between 576 and the given -/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout. -#[cfg(not(target_os = "android"))] -async fn auto_mtu_detection( - gateway: std::net::Ipv4Addr, - #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String, - current_mtu: u16, -) -> Result<u16> { - use futures::{future, stream::FuturesUnordered, TryStreamExt}; - use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError}; - use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU}; - use tokio_stream::StreamExt; - - /// Max time to wait for any ping, when this expires, we give up and throw an error. - const PING_TIMEOUT: Duration = Duration::from_secs(10); - /// Max time to wait after the first ping arrives. Every ping after this timeout is considered - /// dropped, so we return the largest collected packet size. - const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2); - - let step_size = 20; - let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size); - - let config_builder = Config::builder().kind(surge_ping::ICMP::V4); - #[cfg(any(target_os = "macos", target_os = "linux"))] - let config_builder = config_builder.interface(&iface_name); - let client = Client::new(&config_builder.build()).unwrap(); - // For macos, the default socket receive buffer size seems to be too small to handle the data we - // are sending here. The consequence will be dropped packets causing the MTU detection to set a - // low value. Here we manually increase this value, which fixes the problem. - // TODO: Make sure this fix is not needed for any other target OS - #[cfg(target_os = "macos")] - { - use nix::sys::socket::{setsockopt, sockopt}; - let fd = client.get_socket().get_native_sock(); - let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum(); - setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?; - setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?; - } - - let payload_buf = vec![0; current_mtu as usize]; - - let mut ping_stream = linspace - .iter() - .enumerate() - .map(|(i, &mtu)| { - let client = client.clone(); - let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize; - let payload = &payload_buf[0..payload_size]; - async move { - log::trace!("Sending ICMP ping of total size {mtu}"); - client - .pinger(IpAddr::V4(gateway), PingIdentifier(0)) - .await - .timeout(PING_TIMEOUT) - .ping(PingSequence(i as u16), payload) - .await - } - }) - .collect::<FuturesUnordered<_>>() - .map_ok(|(packet, _rtt)| { - let surge_ping::IcmpPacket::V4(packet) = packet else { - unreachable!("ICMP ping response was not of IPv4 type"); - }; - let size = packet.get_size() as u16 + IPV4_HEADER_SIZE; - log::trace!("Got ICMP ping response of total size {size}"); - debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]); - size - }); - - let first_ping_size = ping_stream - .next() - .await - .expect("At least one pings should be sent") - // Short-circuit and return on error - .map_err(|e| match e { - // If the first ping we get back timed out, then all of them did - SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped, - // Unexpected error type - e => Error::MtuDetectionPingError(e), - })?; - - ping_stream - .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout - .map_while(|res| res.ok()) // Stop waiting for pings after this timeout - .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping - .await - .map_err(Error::MtuDetectionPingError) -} - -/// Creates a linear spacing of MTU values with the given step size. Always includes the given end -/// points. -#[cfg(not(target_os = "android"))] -fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> { - assert!(mtu_min < mtu_max); - assert!(step_size < mtu_max); - assert_ne!(step_size, 0); - - let second_mtu = (mtu_min + 1).next_multiple_of(step_size); - let in_between = (second_mtu..mtu_max).step_by(step_size as usize); - - let mut ret = Vec::with_capacity(in_between.clone().count() + 2); - ret.push(mtu_min); - ret.extend(in_between); - ret.push(mtu_max); - ret -} - -#[cfg(all(test, not(target_os = "android")))] -mod tests { - use crate::mtu_spacing; - use proptest::prelude::*; - - proptest! { - #[test] - fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) { - let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size); - - prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1); - prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1); - prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len()); - let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]); - prop_assert!(diffs.all(|diff| diff <= step_size)); - - } - } -} - #[derive(Debug)] enum CloseMsg { Stop, diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs new file mode 100644 index 0000000000..139dfbc9b5 --- /dev/null +++ b/talpid-wireguard/src/mtu_detection.rs @@ -0,0 +1,193 @@ +use std::{io, net::IpAddr, time::Duration}; + +use futures::{future, stream::FuturesUnordered, TryStreamExt}; +use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError}; +use talpid_tunnel::{ICMP_HEADER_SIZE, IPV4_HEADER_SIZE, MIN_IPV4_MTU}; +use tokio_stream::StreamExt; + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to set MTU on the active tunnel + #[error(display = "Failed to set MTU on the active tunnel")] + SetMtu(#[error(source)] io::Error), + + /// Failed to set MTU + #[error(display = "Failed to detect MTU because every ping was dropped.")] + MtuDetectionAllDropped, + + /// Failed to set MTU + #[error(display = "Failed to detect MTU because of unexpected ping error.")] + MtuDetectionPing(#[error(source)] surge_ping::SurgeError), + + /// Failed to set MTU + #[error( + display = "Failed to detect MTU because of an IO error when setting up the ping socket." + )] + MtuDetectionSetupSocket(#[error(source)] io::Error), + + /// Failed to set MTU + #[cfg(target_os = "macos")] + #[error(display = "Failed to set buffer size")] + MtuSetBufferSize(#[error(source)] nix::Error), +} +/// Verify that the current MTU doesn't cause dropped packets, otherwise lower it to the +/// largest value which doesn't. +/// +/// Note: This does not take fragmentation into account, so it should only be used as an extra +/// safety measure after the normal MTU calculation using header sizes and safety margins. +pub async fn automatic_mtu_correction( + gateway: std::net::Ipv4Addr, + iface_name: String, + current_tunnel_mtu: u16, + #[cfg(windows)] ipv6: bool, +) -> Result<(), Error> { + log::debug!("Starting MTU detection"); + let verified_mtu = detect_mtu( + gateway, + #[cfg(any(target_os = "macos", target_os = "linux"))] + iface_name.clone(), + current_tunnel_mtu, + ) + .await?; + + if verified_mtu != current_tunnel_mtu { + log::warn!("Lowering MTU from {} to {verified_mtu}", current_tunnel_mtu); + + #[cfg(any(target_os = "linux", target_os = "macos"))] + crate::unix::set_mtu(&iface_name, verified_mtu).map_err(Error::SetMtu)?; + #[cfg(windows)] + talpid_windows::net::luid_from_alias(iface_name) + .and_then(|luid| talpid_windows::net::set_mtu(luid, verified_mtu as u32, ipv6)) + .map_err(Error::SetMtu)?; + } else { + log::debug!("MTU {verified_mtu} verified to not drop packets"); + }; + Ok(()) +} + +/// Detects the maximum MTU that does not cause dropped packets. +/// +/// The detection works by sending evenly spread out range of pings between 576 and the given +/// current tunnel MTU, and returning the maximum packet size that was returned within a +/// timeout. +async fn detect_mtu( + gateway: std::net::Ipv4Addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String, + current_mtu: u16, +) -> Result<u16, Error> { + /// Max time to wait for any ping, when this expires, we give up and throw an error. + const PING_TIMEOUT: Duration = Duration::from_secs(10); + /// Max time to wait after the first ping arrives. Every ping after this timeout is + /// considered dropped, so we return the largest collected packet size. + const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2); + + let step_size = 20; + let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size); + + let config_builder = Config::builder().kind(surge_ping::ICMP::V4); + #[cfg(any(target_os = "macos", target_os = "linux"))] + let config_builder = config_builder.interface(&iface_name); + let client = Client::new(&config_builder.build()).map_err(Error::MtuDetectionSetupSocket)?; + + // For macos, the default socket receive buffer size seems to be too small to handle the + // data we are sending here. The consequence will be dropped packets causing the MTU + // detection to set a low value. Here we manually increase this value, which fixes + // the problem. + // TODO: Make sure this fix is not needed for any other target OS + #[cfg(target_os = "macos")] + { + use nix::sys::socket::{setsockopt, sockopt}; + let fd = client.get_socket().get_native_sock(); + let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum(); + setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?; + setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?; + } + + let payload_buf = vec![0; current_mtu as usize]; + + let mut ping_stream = linspace + .iter() + .enumerate() + .map(|(i, &mtu)| { + let client = client.clone(); + let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize; + let payload = &payload_buf[0..payload_size]; + async move { + log::trace!("Sending ICMP ping of total size {mtu}"); + client + .pinger(IpAddr::V4(gateway), PingIdentifier(0)) + .await + .timeout(PING_TIMEOUT) + .ping(PingSequence(i as u16), payload) + .await + } + }) + .collect::<FuturesUnordered<_>>() + .map_ok(|(packet, _rtt)| { + let surge_ping::IcmpPacket::V4(packet) = packet else { + unreachable!("ICMP ping response was not of IPv4 type"); + }; + let size = u16::try_from(packet.get_size()).expect("ICMP packet size should fit in 16") + + IPV4_HEADER_SIZE; + log::trace!("Got ICMP ping response of total size {size}"); + debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]); + size + }); + + let first_ping_size = ping_stream + .next() + .await + .expect("At least one pings should be sent") + // Short-circuit and return on error + .map_err(|e| match e { + // If the first ping we get back timed out, then all of them did + SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped, + // Unexpected error type + e => Error::MtuDetectionPing(e), + })?; + + ping_stream + .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout + .map_while(|res| res.ok()) // Stop waiting for pings after this timeout + .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping + .await + .map_err(Error::MtuDetectionPing) +} + +/// Creates a linear spacing of MTU values with the given step size. Always includes the given +/// end points. +fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> { + assert!(mtu_min < mtu_max); + assert!(step_size < mtu_max); + assert_ne!(step_size, 0); + + let second_mtu = (mtu_min + 1).next_multiple_of(step_size); + let in_between = (second_mtu..mtu_max).step_by(step_size as usize); + + let mut ret = Vec::with_capacity(in_between.clone().count() + 2); + ret.push(mtu_min); + ret.extend(in_between); + ret.push(mtu_max); + ret +} + +#[cfg(test)] +mod tests { + use super::mtu_spacing; + use proptest::prelude::*; + + proptest! { + #[test] + fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) { + let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size); + + prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1); + prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1); + prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len()); + let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]); + prop_assert!(diffs.all(|diff| diff <= step_size)); + + } + } +} |
