summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-02-09 16:01:51 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-02-09 16:01:51 +0100
commit433547101d2c4e26c5224381a72330bce1ce51fb (patch)
tree8e8db37b1669a0203909ad1eac7b8289263c3aed
parent4011170772e72a2b8fd535b26ee51303b43f1f3e (diff)
parent51693fe1e5925db85e6b7a374ab2be963f000c70 (diff)
downloadmullvadvpn-433547101d2c4e26c5224381a72330bce1ce51fb.tar.xz
mullvadvpn-433547101d2c4e26c5224381a72330bce1ce51fb.zip
Merge branch 'mtu-detection-windows'
-rw-r--r--talpid-core/src/tunnel/mod.rs4
-rw-r--r--talpid-windows/src/net.rs22
-rw-r--r--talpid-wireguard/Cargo.toml2
-rw-r--r--talpid-wireguard/src/lib.rs39
4 files changed, 53 insertions, 14 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 17ad2915d8..652d24e4b9 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -176,14 +176,14 @@ impl TunnelMonitor {
.map(|mtu| Self::clamp_mtu(params, mtu))
.unwrap_or(default_mtu);
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", windows))]
let detect_mtu = params.options.mtu.is_none();
let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?;
let monitor = talpid_wireguard::WireguardMonitor::start(
config,
params.options.quantum_resistant,
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", windows))]
detect_mtu,
log.as_deref(),
args,
diff --git a/talpid-windows/src/net.rs b/talpid-windows/src/net.rs
index e9899743d6..1545b85e4a 100644
--- a/talpid-windows/src/net.rs
+++ b/talpid-windows/src/net.rs
@@ -332,6 +332,28 @@ pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Resul
win32_err!(unsafe { CreateUnicastIpAddressEntry(&row) }).map_err(Error::CreateUnicastEntry)
}
+/// Sets MTU on the specified network interface identified by `luid`.
+pub fn set_mtu(luid: NET_LUID_LH, mtu: u32, use_ipv6: bool) -> io::Result<()> {
+ let ip_families: &[AddressFamily] = if use_ipv6 {
+ &[AddressFamily::Ipv4, AddressFamily::Ipv6]
+ } else {
+ &[AddressFamily::Ipv4]
+ };
+ for family in ip_families {
+ let mut row = match get_ip_interface_entry(*family, &luid) {
+ Ok(row) => row,
+ Err(error) if error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => continue,
+ Err(error) => return Err(error),
+ };
+
+ row.NlMtu = mtu;
+
+ set_ip_interface_entry(&mut row)?;
+ }
+
+ Ok(())
+}
+
/// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are
/// returned.
pub fn get_unicast_table(
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index d28ccca2ae..01f4b5fd1a 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -37,6 +37,7 @@ duct = "0.13"
byteorder = "1"
internet-checksum = "0.2"
socket2 = { version = "0.5.3", features = ["all"] }
+tokio-stream = { version = "0.1", features = ["io-util"] }
[target.'cfg(unix)'.dependencies]
nix = "0.23"
@@ -48,7 +49,6 @@ netlink-packet-route = "0.13"
netlink-packet-utils = "0.5.1"
netlink-proto = "0.10"
talpid-dbus = { path = "../talpid-dbus" }
-tokio-stream = { version = "0.1", features = ["io-util"] }
[target.'cfg(windows)'.dependencies]
bitflags = "1.2"
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 09a0fc929a..71331c2edf 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -3,9 +3,9 @@
#![deny(missing_docs)]
use self::config::Config;
-use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
#[cfg(windows)]
-use futures::{channel::mpsc, StreamExt};
+use futures::channel::mpsc;
+use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
#[cfg(target_os = "linux")]
use once_cell::sync::Lazy;
#[cfg(target_os = "android")]
@@ -26,6 +26,8 @@ use talpid_routing as routing;
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
+#[cfg(not(target_os = "android"))]
+use talpid_tunnel::IPV4_HEADER_SIZE;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use ipnetwork::IpNetwork;
@@ -42,9 +44,6 @@ use tunnel_obfuscation::{
create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings,
};
-#[cfg(any(target_os = "linux", target_os = "macos"))]
-use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
-
/// WireGuard config data-types
pub mod config;
mod connectivity_check;
@@ -270,7 +269,7 @@ impl WireguardMonitor {
>(
mut config: Config,
psk_negotiation: bool,
- #[cfg(target_os = "linux")] detect_mtu: bool,
+ #[cfg(any(target_os = "linux", windows))] detect_mtu: bool,
log_path: Option<&Path>,
args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
@@ -389,7 +388,8 @@ impl WireguardMonitor {
)
.await?;
}
- #[cfg(target_os = "linux")]
+
+ #[cfg(any(target_os = "linux", windows))]
if detect_mtu {
let iface_name_clone = iface_name.clone();
tokio::task::spawn(async move {
@@ -411,7 +411,20 @@ impl WireguardMonitor {
if verified_mtu != config.mtu {
log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
- if let Err(e) = unix::set_mtu(&iface_name_clone, verified_mtu) {
+ #[cfg(target_os = "linux")]
+ let res = unix::set_mtu(&iface_name_clone, verified_mtu);
+ #[cfg(windows)]
+ let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
+ |luid| {
+ talpid_windows::net::set_mtu(
+ luid,
+ verified_mtu as u32,
+ config.ipv6_gateway.is_some(),
+ )
+ },
+ );
+
+ if let Err(e) = res {
log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
};
} else {
@@ -664,6 +677,8 @@ impl WireguardMonitor {
addresses: &[IpAddr],
mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>,
) -> std::result::Result<(), CloseMsg> {
+ use futures::StreamExt;
+
setup_done_rx
.next()
.await
@@ -936,6 +951,8 @@ impl WireguardMonitor {
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
+ use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
+
if !config.is_multihop() {
route
} else {
@@ -991,7 +1008,7 @@ impl WireguardMonitor {
///
/// The detection works by sending evenly spread out range of pings between 576 and the given
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", windows))]
async fn auto_mtu_detection(
gateway: std::net::Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
@@ -1068,7 +1085,7 @@ async fn auto_mtu_detection(
/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
/// points.
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", windows))]
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
assert!(mtu_min < mtu_max);
assert!(step_size < mtu_max);
@@ -1084,7 +1101,7 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
ret
}
-#[cfg(all(test, target_os = "linux"))]
+#[cfg(all(test, any(target_os = "linux", windows)))]
mod tests {
use crate::mtu_spacing;
use proptest::prelude::*;