summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-02-08 15:54:48 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-02-08 15:54:48 +0100
commit0afb7cc9502ecbd073f1909aab39d9570db5903d (patch)
treebb014bec50c05cfe24b4c88f78ef2b28ed263d26
parentbc51bf6295f0c7fca273d720c624fbfee305e910 (diff)
parenteeca931f565dd7699e8d7f1209c65b1520834373 (diff)
downloadmullvadvpn-0afb7cc9502ecbd073f1909aab39d9570db5903d.tar.xz
mullvadvpn-0afb7cc9502ecbd073f1909aab39d9570db5903d.zip
Merge branch 'mtu-detection-linux'
-rw-r--r--Cargo.lock183
-rw-r--r--talpid-core/Cargo.toml3
-rw-r--r--talpid-core/src/future_retry.rs17
-rw-r--r--talpid-core/src/tunnel/mod.rs76
-rw-r--r--talpid-tunnel/src/lib.rs15
-rw-r--r--talpid-wireguard/Cargo.toml4
-rw-r--r--talpid-wireguard/src/config.rs18
-rw-r--r--talpid-wireguard/src/connectivity_check.rs2
-rw-r--r--talpid-wireguard/src/lib.rs167
-rw-r--r--talpid-wireguard/src/ping_monitor/icmp.rs2
-rw-r--r--talpid-wireguard/src/unix.rs41
-rw-r--r--talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs2
12 files changed, 441 insertions, 89 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 1481155dc6..5f37db970b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -290,6 +290,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
+name = "bit-set"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+
+[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1165,6 +1180,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
[[package]]
+name = "glob"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+
+[[package]]
name = "h2"
version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1630,6 +1651,12 @@ dependencies = [
]
[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
+
+[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2189,6 +2216,12 @@ dependencies = [
]
[[package]]
+name = "no-std-net"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65"
+
+[[package]]
name = "notify"
version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2214,6 +2247,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2"
dependencies = [
"autocfg",
+ "libm",
]
[[package]]
@@ -2535,6 +2569,48 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4503fa043bf02cee09a9582e9554b4c6403b2ef55e4612e96561d294419429f8"
[[package]]
+name = "pnet_base"
+version = "0.33.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "872e46346144ebf35219ccaa64b1dffacd9c6f188cd7d012bd6977a2a838f42e"
+dependencies = [
+ "no-std-net",
+]
+
+[[package]]
+name = "pnet_macros"
+version = "0.33.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2a780e80005c2e463ec25a6e9f928630049a10b43945fea83207207d4a7606f4"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "regex",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "pnet_macros_support"
+version = "0.33.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d932134f32efd7834eb8b16d42418dac87086347d1bc7d142370ef078582bc"
+dependencies = [
+ "pnet_base",
+]
+
+[[package]]
+name = "pnet_packet"
+version = "0.33.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8bde678bbd85cb1c2d99dc9fc596e57f03aa725f84f3168b0eaf33eeccb41706"
+dependencies = [
+ "glob",
+ "pnet_base",
+ "pnet_macros",
+ "pnet_macros_support",
+]
+
+[[package]]
name = "poly1305"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2617,6 +2693,26 @@ dependencies = [
]
[[package]]
+name = "proptest"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf"
+dependencies = [
+ "bit-set",
+ "bit-vec",
+ "bitflags 2.4.0",
+ "lazy_static",
+ "num-traits",
+ "rand 0.8.5",
+ "rand_chacha 0.3.1",
+ "rand_xorshift",
+ "regex-syntax 0.8.2",
+ "rusty-fork",
+ "tempfile",
+ "unarray",
+]
+
+[[package]]
name = "prost"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2687,26 +2783,6 @@ dependencies = [
]
[[package]]
-name = "quickcheck"
-version = "1.0.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
-dependencies = [
- "rand 0.8.5",
-]
-
-[[package]]
-name = "quickcheck_macros"
-version = "1.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b22a693222d716a9587786f37ac3f6b4faedb5b80c23914e7303ff5a1d8016e9"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 1.0.109",
-]
-
-[[package]]
name = "quote"
version = "1.0.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2787,6 +2863,15 @@ dependencies = [
]
[[package]]
+name = "rand_xorshift"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f"
+dependencies = [
+ "rand_core 0.6.4",
+]
+
+[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2824,7 +2909,7 @@ dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
- "regex-syntax",
+ "regex-syntax 0.7.5",
]
[[package]]
@@ -2835,7 +2920,7 @@ checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795"
dependencies = [
"aho-corasick",
"memchr",
- "regex-syntax",
+ "regex-syntax 0.7.5",
]
[[package]]
@@ -2845,6 +2930,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]]
+name = "regex-syntax"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
+
+[[package]]
name = "resolv-conf"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2975,6 +3066,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
[[package]]
+name = "rusty-fork"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f"
+dependencies = [
+ "fnv",
+ "quick-error",
+ "tempfile",
+ "wait-timeout",
+]
+
+[[package]]
name = "ryu"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3363,6 +3466,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
+name = "surge-ping"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af341b2be485d647b5dc4cfb2da99efac35b5c95748a08fb7233480fedc5ead3"
+dependencies = [
+ "hex",
+ "parking_lot",
+ "pnet_packet",
+ "rand 0.8.5",
+ "socket2 0.5.3",
+ "thiserror",
+ "tokio",
+ "tracing",
+]
+
+[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3445,8 +3564,7 @@ dependencies = [
"once_cell",
"parking_lot",
"pfctl",
- "quickcheck",
- "quickcheck_macros",
+ "proptest",
"rand 0.8.5",
"resolv-conf",
"subslice",
@@ -3656,9 +3774,11 @@ dependencies = [
"nix 0.23.2",
"once_cell",
"parking_lot",
+ "proptest",
"rand 0.8.5",
"rtnetlink",
"socket2 0.5.3",
+ "surge-ping",
"talpid-dbus",
"talpid-routing",
"talpid-tunnel",
@@ -4152,6 +4272,12 @@ dependencies = [
]
[[package]]
+name = "unarray"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94"
+
+[[package]]
name = "unicode-bidi"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4229,6 +4355,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
+name = "wait-timeout"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6"
+dependencies = [
+ "libc",
+]
+
+[[package]]
name = "walkdir"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 4d2a54fc3f..1ec70876ff 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -94,6 +94,5 @@ features = [
tonic-build = { workspace = true, default-features = false, features = ["transport", "prost"] }
[dev-dependencies]
-quickcheck = { version = "1.0", default-features = false }
-quickcheck_macros = "1.0"
+proptest = "1.4"
tokio = { workspace = true, features = [ "test-util" ] }
diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs
index 197042e353..ee23de312f 100644
--- a/talpid-core/src/future_retry.rs
+++ b/talpid-core/src/future_retry.rs
@@ -153,6 +153,7 @@ fn apply_jitter(duration: Duration, jitter: f64) -> Duration {
#[cfg(test)]
mod test {
use super::*;
+ use proptest::prelude::*;
#[test]
fn test_constant_interval() {
@@ -220,13 +221,15 @@ mod test {
assert_eq!(apply_jitter(second, 1.0), second);
}
- #[quickcheck_macros::quickcheck]
- fn test_jitter(millis: u64, jitter: u64) {
- let max_num = 2u64.checked_pow(f64::MANTISSA_DIGITS).unwrap();
- let jitter = (jitter % max_num) as f64 / (max_num as f64);
- let unjittered_duration = Duration::from_millis(millis);
- let jittered_duration = apply_jitter(unjittered_duration, jitter);
- assert!(jittered_duration <= unjittered_duration);
+ proptest! {
+ #[test]
+ fn test_jitter(millis: u64, jitter: u64) {
+ let max_num = 2u64.checked_pow(f64::MANTISSA_DIGITS).unwrap();
+ let jitter = (jitter % max_num) as f64 / (max_num as f64);
+ let unjittered_duration = Duration::from_millis(millis);
+ let jittered_duration = apply_jitter(unjittered_duration, jitter);
+ prop_assert!(jittered_duration <= unjittered_duration);
+ }
}
// NOTE: The test is disabled because the clock does not advance.
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 3e6df61f76..17ad2915d8 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -17,6 +17,14 @@ use talpid_wireguard;
const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
+/// Set the MTU to the lowest possible whilst still allowing for IPv6 to help with wireless
+/// carriers that do a lot of encapsulation.
+const DEFAULT_MTU: u16 = if cfg!(target_os = "android") {
+ 1280
+} else {
+ 1380
+};
+
/// Results from operations in the tunnel module.
pub type Result<T> = std::result::Result<T, Error>;
@@ -154,13 +162,29 @@ impl TunnelMonitor {
+ Clone
+ 'static,
{
+ let default_mtu = DEFAULT_MTU;
+
#[cfg(any(target_os = "linux", target_os = "windows"))]
- args.runtime
- .block_on(Self::assign_mtu(&args.route_manager, params));
- let config = talpid_wireguard::config::Config::from_parameters(params)?;
+ // Detects the MTU of the device and sets the default tunnel MTU to that minus headers and
+ // the safety margin
+ let default_mtu = args
+ .runtime
+ .block_on(
+ args.route_manager
+ .get_mtu_for_route(params.connection.peer.endpoint.ip()),
+ )
+ .map(|mtu| Self::clamp_mtu(params, mtu))
+ .unwrap_or(default_mtu);
+
+ #[cfg(target_os = "linux")]
+ let detect_mtu = params.options.mtu.is_none();
+
+ let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?;
let monitor = talpid_wireguard::WireguardMonitor::start(
config,
params.options.quantum_resistant,
+ #[cfg(target_os = "linux")]
+ detect_mtu,
log.as_deref(),
args,
)?;
@@ -169,58 +193,36 @@ impl TunnelMonitor {
})
}
- /// Set the MTU in the tunnel parameters based on the inputted device MTU and some
- /// calculations. `peer_mtu` is the detected device MTU.
+ /// Calculates and appropriate tunnel MTU based on the given peer MTU minus header sizes
#[cfg(any(target_os = "linux", target_os = "windows"))]
- fn set_mtu(params: &mut wireguard_types::TunnelParameters, peer_mtu: u16) {
+ fn clamp_mtu(params: &wireguard_types::TunnelParameters, peer_mtu: u16) -> u16 {
+ use talpid_tunnel::{
+ IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, MIN_IPV4_MTU, MIN_IPV6_MTU, WIREGUARD_HEADER_SIZE,
+ };
// Some users experience fragmentation issues even when we take the interface MTU and
// subtract the header sizes. This is likely due to some program that they use which does
// not change the interface MTU but adds its own header onto the outgoing packets. For this
// reason we subtract some extra bytes from our MTU in order to give other programs some
// safety margin.
const MTU_SAFETY_MARGIN: u16 = 60;
- const IPV4_HEADER_SIZE: u16 = 20;
- const IPV6_HEADER_SIZE: u16 = 40;
- const WIREGUARD_HEADER_SIZE: u16 = 40;
+
let total_header_size = WIREGUARD_HEADER_SIZE
+ match params.connection.peer.endpoint.is_ipv6() {
false => IPV4_HEADER_SIZE,
true => IPV6_HEADER_SIZE,
};
+
// The largest peer MTU that we allow
- const MAX_PEER_MTU: u16 = 1500 - MTU_SAFETY_MARGIN;
- // The minimum allowed MTU size for our tunnel in IPv6 is 1280 and 576 for IPv4
- const MIN_IPV4_MTU: u16 = 576;
- const MIN_IPV6_MTU: u16 = 1280;
+ let max_peer_mtu: u16 = 1500 - MTU_SAFETY_MARGIN - total_header_size;
+
let min_mtu = match params.generic_options.enable_ipv6 {
false => MIN_IPV4_MTU,
true => MIN_IPV6_MTU,
};
- let tunnel_mtu = peer_mtu
- .saturating_sub(total_header_size)
- .clamp(min_mtu, MAX_PEER_MTU - total_header_size);
- params.options.mtu = Some(tunnel_mtu);
- }
- /// Detects the MTU of the device, calculates what the virtual device MTU should be and sets
- /// that in the tunnel parameters.
- #[cfg(any(target_os = "linux", target_os = "windows"))]
- async fn assign_mtu(
- route_manager: &RouteManagerHandle,
- params: &mut wireguard_types::TunnelParameters,
- ) {
- // Only calculate the mtu automatically if the user has not set any
- if params.options.mtu.is_none() {
- match route_manager
- .get_mtu_for_route(params.connection.peer.endpoint.ip())
- .await
- {
- Ok(mtu) => Self::set_mtu(params, mtu),
- Err(e) => {
- log::error!("Could not get the MTU for route {}", e);
- }
- }
- }
+ peer_mtu
+ .saturating_sub(total_header_size)
+ .clamp(min_mtu, max_peer_mtu)
}
#[cfg(not(target_os = "android"))]
diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs
index 8a916c668d..8ce3dd2d0d 100644
--- a/talpid-tunnel/src/lib.rs
+++ b/talpid-tunnel/src/lib.rs
@@ -14,12 +14,25 @@ use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
+/// Size of IPv4 header in bytes
+pub const IPV4_HEADER_SIZE: u16 = 20;
+/// Size of IPv6 header in bytes
+pub const IPV6_HEADER_SIZE: u16 = 40;
+/// Size of wireguard header in bytes
+pub const WIREGUARD_HEADER_SIZE: u16 = 40;
+/// Size of ICMP header in bytes
+pub const ICMP_HEADER_SIZE: u16 = 8;
+/// Smallest allowed MTU for IPv4 in bytes
+pub const MIN_IPV4_MTU: u16 = 576;
+/// Smallest allowed MTU for IPv6 in bytes
+pub const MIN_IPV6_MTU: u16 = 1280;
+
/// Arguments for creating a tunnel.
pub struct TunnelArgs<'a, L>
where
L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
{
- /// Toktio runtime handle.
+ /// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
/// Resource directory path.
pub resource_dir: &'a Path,
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index 5c5a14a3fa..d28ccca2ae 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -28,6 +28,7 @@ chrono = { workspace = true, features = ["clock"] }
tokio = { workspace = true, features = ["process", "rt-multi-thread", "fs"] }
tunnel-obfuscation = { path = "../tunnel-obfuscation" }
rand = "0.8.5"
+surge-ping = "0.8.0"
[target.'cfg(target_os="android")'.dependencies]
duct = "0.13"
@@ -78,3 +79,6 @@ features = [
"Win32_UI_Shell",
"Win32_UI_WindowsAndMessaging",
]
+
+[dev-dependencies]
+proptest = "1.4"
diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs
index 0e462102b2..b30e9053fc 100644
--- a/talpid-wireguard/src/config.rs
+++ b/talpid-wireguard/src/config.rs
@@ -30,14 +30,6 @@ pub struct Config {
pub obfuscator_config: Option<ObfuscatorConfig>,
}
-/// Set the MTU to the lowest possible whilst still allowing for IPv6 to help with wireless
-/// carriers that do a lot of encapsulation.
-const DEFAULT_MTU: u16 = if cfg!(target_os = "android") {
- 1280
-} else {
- 1380
-};
-
/// Configuration errors
#[derive(err_derive::Error, Debug)]
pub enum Error {
@@ -52,12 +44,16 @@ pub enum Error {
impl Config {
/// Constructs a Config from parameters
- pub fn from_parameters(params: &wireguard::TunnelParameters) -> Result<Config, Error> {
+ pub fn from_parameters(
+ params: &wireguard::TunnelParameters,
+ default_mtu: u16,
+ ) -> Result<Config, Error> {
Self::new(
&params.connection,
&params.options,
&params.generic_options,
&params.obfuscation,
+ default_mtu,
)
}
@@ -67,9 +63,11 @@ impl Config {
wg_options: &wireguard::TunnelOptions,
generic_options: &GenericTunnelOptions,
obfuscator_config: &Option<ObfuscatorConfig>,
+ default_mtu: u16,
) -> Result<Config, Error> {
let mut tunnel = connection.tunnel.clone();
- let mtu = wg_options.mtu.unwrap_or(DEFAULT_MTU);
+
+ let mtu = wg_options.mtu.unwrap_or(default_mtu);
if tunnel.addresses.is_empty() {
return Err(Error::InvalidTunnelIpError);
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs
index 8f44b9b1c6..a7b09778d9 100644
--- a/talpid-wireguard/src/connectivity_check.rs
+++ b/talpid-wireguard/src/connectivity_check.rs
@@ -62,7 +62,7 @@ pub enum Error {
///
/// The connectivity monitor will start sending pings and start the countdown to `PING_TIMEOUT` in
/// the following cases:
-/// - In case that we have observed a bump in the outgoing traffic but no coressponding incoming
+/// - In case that we have observed a bump in the outgoing traffic but no corresponding incoming
/// traffic for longer than `BYTES_RX_TIMEOUT`, then the monitor will start pinging.
/// - In case that no increase in outgoing or incoming traffic has been observed for longer than
/// `TRAFFIC_TIMEOUT`, then the monitor will start pinging as well.
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index e8a63d0b1b..09a0fc929a 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -42,12 +42,17 @@ use tunnel_obfuscation::{
create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings,
};
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
+
/// WireGuard config data-types
pub mod config;
mod connectivity_check;
mod logging;
mod ping_monitor;
mod stats;
+#[cfg(target_os = "linux")]
+mod unix;
#[cfg(wireguard_go)]
mod wireguard_go;
#[cfg(target_os = "linux")]
@@ -69,6 +74,14 @@ pub enum Error {
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] talpid_routing::Error),
+ /// Failed to set MTU
+ #[error(display = "Failed to detect MTU because every ping was dropped.")]
+ MtuDetectionAllDropped,
+
+ /// Failed to set MTU
+ #[error(display = "Failed to detect MTU because of unexpected ping error.")]
+ MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),
+
/// Tunnel timed out
#[error(display = "Tunnel timed out")]
TimeoutError,
@@ -257,6 +270,7 @@ impl WireguardMonitor {
>(
mut config: Config,
psk_negotiation: bool,
+ #[cfg(target_os = "linux")] detect_mtu: bool,
log_path: Option<&Path>,
args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
@@ -375,7 +389,36 @@ impl WireguardMonitor {
)
.await?;
}
+ #[cfg(target_os = "linux")]
+ if detect_mtu {
+ let iface_name_clone = iface_name.clone();
+ tokio::task::spawn(async move {
+ log::debug!("Starting MTU detection");
+ let verified_mtu = match auto_mtu_detection(
+ gateway,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ iface_name_clone.clone(),
+ config.mtu,
+ )
+ .await
+ {
+ Ok(mtu) => mtu,
+ Err(e) => {
+ log::error!("{}", e.display_chain_with_msg("Failed to detect MTU"));
+ return;
+ }
+ };
+ if verified_mtu != config.mtu {
+ log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
+ if let Err(e) = unix::set_mtu(&iface_name_clone, verified_mtu) {
+ log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
+ };
+ } else {
+ log::debug!("MTU {verified_mtu} verified to not drop packets");
+ }
+ });
+ }
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
match connectivity_monitor.establish_connectivity(args.retry_attempt) {
Ok(true) => Ok(connectivity_monitor),
@@ -898,16 +941,11 @@ impl WireguardMonitor {
} else {
// Set route MTU by subtracting the WireGuard overhead from the tunnel MTU. Plus
// some margin to make room for padding bytes.
- // TODO: Move consts to shared location
- const IPV4_HEADER_SIZE: u16 = 20;
- const IPV6_HEADER_SIZE: u16 = 40;
- const WIREGUARD_HEADER_SIZE: u16 = 40;
- const PADDING_BYTES_MARGIN: u16 = 15;
-
let ip_overhead = match route.prefix.is_ipv4() {
true => IPV4_HEADER_SIZE,
false => IPV6_HEADER_SIZE,
};
+ const PADDING_BYTES_MARGIN: u16 = 15;
let mtu = config.mtu - ip_overhead - WIREGUARD_HEADER_SIZE - PADDING_BYTES_MARGIN;
route.mtu(mtu)
@@ -949,6 +987,123 @@ impl WireguardMonitor {
}
}
+/// Detects the maximum MTU that does not cause dropped packets.
+///
+/// The detection works by sending evenly spread out range of pings between 576 and the given
+/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
+#[cfg(target_os = "linux")]
+async fn auto_mtu_detection(
+ gateway: std::net::Ipv4Addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
+ current_mtu: u16,
+) -> Result<u16> {
+ use futures::{future, stream::FuturesUnordered, TryStreamExt};
+ use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
+ use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
+ use tokio_stream::StreamExt;
+
+ /// Max time to wait for any ping, when this expires, we give up and throw an error.
+ const PING_TIMEOUT: Duration = Duration::from_secs(10);
+ /// Max time to wait after the first ping arrives. Every ping after this timeout is considered
+ /// dropped, so we return the largest collected packet size.
+ const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
+
+ let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ let config_builder = config_builder.interface(&iface_name);
+ let client = Client::new(&config_builder.build()).unwrap();
+
+ let step_size = 20;
+ let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
+
+ let payload_buf = vec![0; current_mtu as usize];
+
+ let mut ping_stream = linspace
+ .iter()
+ .enumerate()
+ .map(|(i, &mtu)| {
+ let client = client.clone();
+ let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
+ let payload = &payload_buf[0..payload_size];
+ async move {
+ log::trace!("Sending ICMP ping of total size {mtu}");
+ client
+ .pinger(IpAddr::V4(gateway), PingIdentifier(0))
+ .await
+ .timeout(PING_TIMEOUT)
+ .ping(PingSequence(i as u16), payload)
+ .await
+ }
+ })
+ .collect::<FuturesUnordered<_>>()
+ .map_ok(|(packet, _rtt)| {
+ let surge_ping::IcmpPacket::V4(packet) = packet else {
+ unreachable!("ICMP ping response was not of IPv4 type");
+ };
+ let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
+ log::trace!("Got ICMP ping response of total size {size}");
+ debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
+ size
+ });
+
+ let first_ping_size = ping_stream
+ .next()
+ .await
+ .expect("At least one pings should be sent")
+ // Short-circuit and return on error
+ .map_err(|e| match e {
+ // If the first ping we get back timed out, then all of them did
+ SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
+ // Unexpected error type
+ e => Error::MtuDetectionPingError(e),
+ })?;
+
+ ping_stream
+ .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
+ .map_while(|res| res.ok()) // Stop waiting for pings after this timeout
+ .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
+ .await
+ .map_err(Error::MtuDetectionPingError)
+}
+
+/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
+/// points.
+#[cfg(target_os = "linux")]
+fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
+ assert!(mtu_min < mtu_max);
+ assert!(step_size < mtu_max);
+ assert_ne!(step_size, 0);
+
+ let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
+ let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
+
+ let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
+ ret.push(mtu_min);
+ ret.extend(in_between);
+ ret.push(mtu_max);
+ ret
+}
+
+#[cfg(all(test, target_os = "linux"))]
+mod tests {
+ use crate::mtu_spacing;
+ use proptest::prelude::*;
+
+ proptest! {
+ #[test]
+ fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
+ let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);
+
+ prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
+ prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
+ prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
+ let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
+ prop_assert!(diffs.all(|diff| diff <= step_size));
+
+ }
+ }
+}
+
#[derive(Debug)]
enum CloseMsg {
Stop,
diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/ping_monitor/icmp.rs
index a0afc8a98c..ad31349799 100644
--- a/talpid-wireguard/src/ping_monitor/icmp.rs
+++ b/talpid-wireguard/src/ping_monitor/icmp.rs
@@ -1,6 +1,7 @@
use byteorder::{NetworkEndian, WriteBytesExt};
use rand::Rng;
use socket2::{Domain, Protocol, Socket, Type};
+
use std::{
io::{self, Write},
net::{Ipv4Addr, SocketAddr},
@@ -59,6 +60,7 @@ pub struct Pinger {
}
impl Pinger {
+ /// Creates a new `Pinger`.
pub fn new(
addr: Ipv4Addr,
#[cfg(not(target_os = "windows"))] interface_name: String,
diff --git a/talpid-wireguard/src/unix.rs b/talpid-wireguard/src/unix.rs
new file mode 100644
index 0000000000..bef057a042
--- /dev/null
+++ b/talpid-wireguard/src/unix.rs
@@ -0,0 +1,41 @@
+use std::{io, os::fd::AsRawFd};
+
+use socket2::Domain;
+use talpid_types::ErrorExt;
+
+pub fn set_mtu(interface_name: &str, mtu: u16) -> Result<(), io::Error> {
+ debug_assert_ne!(
+ interface_name, "eth0",
+ "Should be name of mullvad tunnel interface, e.g. 'wg0-mullvad'"
+ );
+
+ let sock = socket2::Socket::new(
+ Domain::IPV4,
+ socket2::Type::STREAM,
+ Some(socket2::Protocol::TCP),
+ )?;
+
+ let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
+ if interface_name.len() >= ifr.ifr_name.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Interface name too long",
+ ));
+ }
+
+ unsafe {
+ std::ptr::copy_nonoverlapping(
+ interface_name.as_ptr() as *const i8,
+ &mut ifr.ifr_name as *mut _,
+ interface_name.len(),
+ )
+ };
+ ifr.ifr_ifru.ifru_mtu = mtu as i32;
+
+ if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU, &ifr) } < 0 {
+ let e = std::io::Error::last_os_error();
+ log::error!("{}", e.display_chain_with_msg("SIOCSIFMTU failed"));
+ return Err(e);
+ }
+ Ok(())
+}
diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
index b1159bb6de..579bcde65a 100644
--- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
+++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
@@ -122,7 +122,7 @@ impl Tunnel for NetlinkTunnel {
wg.set_config(interface_index, &config)
.await
.map_err(|err| {
- log::error!("Failed to fetch WireGuard device config: {}", err);
+ log::error!("Failed to set WireGuard device config: {}", err);
TunnelError::SetConfigError
})
})