diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-09-03 13:43:40 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-09-04 09:52:36 +0200 |
| commit | c08d081cb4d5f19e5e3f35904dfe2b8874ca4b88 (patch) | |
| tree | adcbaaca627c8dac8608f50cdedb9f73f21a8a47 | |
| parent | a7c27135483ab512b155ed8ce11d39f337585826 (diff) | |
| download | mullvadvpn-c08d081cb4d5f19e5e3f35904dfe2b8874ca4b88.tar.xz mullvadvpn-c08d081cb4d5f19e5e3f35904dfe2b8874ca4b88.zip | |
Enable UDP GSO on Windows for masque proxy client
| -rw-r--r-- | CHANGELOG.md | 4 | ||||
| -rw-r--r-- | Cargo.lock | 2 | ||||
| -rw-r--r-- | mullvad-masque-proxy/Cargo.toml | 5 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 280 |
4 files changed, 277 insertions, 14 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 955b6ed7c4..bad6478441 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,10 @@ Line wrap the file at 100 chars. Th #### Windows - Add additional logging for tunnel devices and split tunneling to problem reports. +### Changed +#### Windows +- Implement UDP GSO for QUIC on client socket. This improves download speeds slightly. + ### Security #### Windows - Block traffic to exit node from non-Mullvad processes. This fixes a leak where traffic could be diff --git a/Cargo.lock b/Cargo.lock index 0578bec7af..a5e7e02c67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3284,6 +3284,7 @@ dependencies = [ "h3-datagram", "h3-quinn", "http 1.1.0", + "libc", "log", "quinn", "rand 0.8.5", @@ -3293,6 +3294,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "typed-builder 0.21.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml index 31bc379dd3..73a6452bd4 100644 --- a/mullvad-masque-proxy/Cargo.toml +++ b/mullvad-masque-proxy/Cargo.toml @@ -20,9 +20,14 @@ rustls = { version = "0.23", default-features = false } rustls-pemfile = "2.1.3" bytes = "1" anyhow = { workspace = true } +libc = "0.2" log = { workspace = true } typed-builder = "0.21.0" +[target.'cfg(windows)'.dependencies.windows-sys] +workspace = true +features = ["Win32", "Win32_Foundation", "Win32_Networking", "Win32_Networking_WinSock"] + [dev-dependencies] env_logger = { workspace = true } tokio = { workspace = true, features = ["fs", "macros", "io-util", "rt-multi-thread"] } diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 801249cfb2..cb178788ec 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -1,5 +1,5 @@ use anyhow::{Context, anyhow}; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use rustls::client::danger::ServerCertVerified; use std::{ fs::{self}, @@ -355,11 +355,16 @@ impl Client { return_addr_tx, )); - let mut client_socket_tx_task = tokio::task::spawn(client_socket_tx_task( + let (send_tx, send_rx) = mpsc::channel::<(SocketAddr, Bytes)>(MAX_INFLIGHT_PACKETS); + + let mut client_socket_tx_task = + tokio::task::spawn(client_socket_tx_task(self.client_socket.clone(), send_rx)); + + let mut fragment_reassembly_task = tokio::task::spawn(fragment_reassembly_task( stream_id, server_rx, return_addr_rx, - self.client_socket.clone(), + send_tx, Arc::clone(&self.stats), )); @@ -375,11 +380,13 @@ impl Client { let result = select! { result = &mut client_socket_tx_task => result, + result = &mut fragment_reassembly_task => result, result = &mut client_socket_rx_task => result, result = &mut server_socket_task => result, }; client_socket_tx_task.abort(); + fragment_reassembly_task.abort(); client_socket_rx_task.abort(); server_socket_task.abort(); @@ -488,10 +495,149 @@ async fn client_socket_rx_task( } async fn client_socket_tx_task( + client_socket: Arc<UdpSocket>, + send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + #[cfg(target_os = "windows")] + if *windows::MAX_GSO_SEGMENTS > 1 { + log::debug!("UDP GSO enabled"); + return client_socket_gso_tx_task(client_socket, send_rx).await; + } + + log::debug!("UDP GSO disabled"); + + client_socket_non_gso_tx_task(client_socket, send_rx).await +} + +#[cfg(target_os = "windows")] +async fn client_socket_gso_tx_task( + client_socket: Arc<UdpSocket>, + mut send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + use bytes::Buf; + use std::{collections::VecDeque, mem}; + use tokio::io::Interest; + use windows::*; + use windows_sys::Win32::Networking::WinSock; + + const MAX_SEGMENT_SIZE: usize = 1500; + + let client_socket_ref = socket2::SockRef::from(&client_socket); + + let mut buffer = Vec::with_capacity(*MAX_GSO_SEGMENTS * MAX_SEGMENT_SIZE); + let mut cmsg_buf = Cmsg::new( + mem::size_of::<u32>(), + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + ); + let mut queued_packets = VecDeque::new(); + + loop { + // Fill up queue + if queued_packets.is_empty() { + let Some((dest, packet)) = send_rx.recv().await else { + break; + }; + queued_packets.push_back((dest, packet)); + } + while let Ok((dest, packet)) = send_rx.try_recv() { + queued_packets.push_back((dest, packet)); + } + + let (dest, packet) = queued_packets.pop_front().expect("never empty"); + + // If the queue is empty now, send a single packet using send_to + if queued_packets.is_empty() { + client_socket + .send_to(packet.chunk(), dest) + .await + .map_err(Error::ClientWrite)?; + continue; + } + + let segment_size = packet.len(); + buffer.clear(); + buffer.extend_from_slice(packet.chunk()); + + loop { + let Some((next_dest, next_packet)) = queued_packets.pop_front() else { + break; + }; + + // If the destination differs, stop coalescing packets + if next_dest != dest { + // Flush the buffer now and queue this packet + // This should occur rarely, as we're expecting a single UDP client + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // On overflow, also stop coalescing + if buffer.len() + next_packet.len() > buffer.capacity() { + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // If this packet is larger, we are done + if next_packet.len() > segment_size { + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // Otherwise, append the next packet to the bunch + buffer.extend_from_slice(next_packet.chunk()); + + // The last packet may be smaller than previous segments: + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options + if next_packet.len() < segment_size { + break; + } + } + + client_socket + .async_io(Interest::WRITABLE, || { + use std::io::IoSlice; + + // Call sendmsg with one CMSG containing the segment size. + // This will send all packets in `buffer`. + + // SAFETY: We have allocated capacity for a u32. The data may contain that. + unsafe { *(cmsg_buf.data_mut_ptr() as *mut u32) = segment_size as u32 }; + + let io_slices = [IoSlice::new(&buffer); 1]; + let daddr = socket2::SockAddr::from(dest); + let msg_hdr = socket2::MsgHdr::new() + .with_addr(&daddr) + .with_buffers(&io_slices) + .with_control(cmsg_buf.as_slice()); + + client_socket_ref.sendmsg(&msg_hdr, 0) + }) + .await + .map_err(Error::ClientWrite)?; + } + + Ok(()) +} + +async fn client_socket_non_gso_tx_task( + client_socket: Arc<UdpSocket>, + mut send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + while let Some((dest, buf)) = send_rx.recv().await { + client_socket + .send_to(&buf, dest) + .await + .map_err(Error::ClientWrite)?; + } + Ok(()) +} + +async fn fragment_reassembly_task( stream_id: StreamId, mut server_rx: mpsc::Receiver<Datagram>, mut return_addr_rx: broadcast::Receiver<SocketAddr>, - client_socket: Arc<UdpSocket>, + send_tx: mpsc::Sender<(SocketAddr, Bytes)>, stats: Arc<Stats>, ) -> Result<()> { let mut fragments = Fragments::default(); @@ -522,22 +668,22 @@ async fn client_socket_tx_task( let payload = response.into_payload(); let original_payload_len = payload.len(); - let send = async |payload: &[u8]| -> Result<()> { - client_socket - .send_to(payload, return_addr) - .await - .map_err(Error::ClientWrite)?; - Ok(()) - }; - match fragments.handle_incoming_packet(payload) { Ok(DefragReceived::Nonfragmented(payload)) => { stats.rx(payload.len(), false); - send(payload.chunk()).await?; + if send_tx.send((return_addr, payload)).await.is_err() { + break; + } } Ok(DefragReceived::Reassembled(reassembled_payload)) => { stats.rx(original_payload_len, true); - send(reassembled_payload.chunk()).await?; + if send_tx + .send((return_addr, reassembled_payload)) + .await + .is_err() + { + break; + } } Ok(DefragReceived::Fragment) => stats.rx(original_payload_len, true), Err(e) => { @@ -702,3 +848,109 @@ impl rustls::client::danger::ServerCertVerifier for Approver { ] } } + +#[cfg(target_os = "windows")] +mod windows { + use socket2::{Domain, Socket, Type}; + use std::{ffi::c_uchar, mem, sync::LazyLock}; + use std::{ffi::c_uint, os::windows::io::AsRawSocket}; + use windows_sys::Win32::Networking::WinSock::{self, CMSGHDR}; + + /// Struct representing a CMSG + pub struct Cmsg { + buffer: Vec<u8>, + } + + impl Cmsg { + /// Create a new with space for `space` bytes and a CMSG header + pub fn new(space: usize, cmsg_level: i32, cmsg_type: i32) -> Self { + let mut self_ = Self { + buffer: vec![0u8; cmsg_space(space)], + }; + + *self_.header_mut() = CMSGHDR { + cmsg_len: cmsg_len(space), + cmsg_level, + cmsg_type, + }; + + self_ + } + + fn header_mut(&mut self) -> &mut CMSGHDR { + let hdr = self.buffer.as_mut_ptr() as *mut CMSGHDR; + debug_assert!(hdr.is_aligned()); + // SAFETY: `hdr` is aligned and points to an initialized `CMSGHDR` + unsafe { &mut *hdr } + } + + pub fn as_slice(&self) -> &[u8] { + &self.buffer[..] + } + + pub fn data_mut_ptr(&mut self) -> *mut u8 { + let header = self.header_mut(); + // SAFETY: The buffer is initialized using `cmsg_space`, so this points to actual data + // (but len may be 0) + unsafe { cmsg_data(header) } + } + } + + /// The total size of an ancillary data object given the amount of data + /// Source: ws2def.h: CMSG_SPACE macro + pub fn cmsg_space(length: usize) -> usize { + cmsgdata_align(mem::size_of::<CMSGHDR>() + cmsghdr_align(length)) + } + + /// Value to store in the `cmsg_len` of the CMSG header given an amount of data. + /// Source: ws2def.h: CMSG_LEN macro + pub fn cmsg_len(length: usize) -> usize { + cmsgdata_align(mem::size_of::<CMSGHDR>()) + length + } + + /// Pointer to the first byte of data in `cmsg`. + /// Source: ws2def.h: CMSG_DATA macro + pub unsafe fn cmsg_data(cmsg: *mut CMSGHDR) -> *mut c_uchar { + (cmsg as usize + cmsgdata_align(mem::size_of::<CMSGHDR>())) as *mut c_uchar + } + + // Taken from ws2def.h: CMSGHDR_ALIGN macro + pub fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::<WinSock::CMSGHDR>() - 1) + & !(mem::align_of::<WinSock::CMSGHDR>() - 1) + } + + // Source: ws2def.h: CMSGDATA_ALIGN macro + pub fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1) + } + + pub static MAX_GSO_SEGMENTS: LazyLock<usize> = LazyLock::new(|| { + // Detect whether UDP GSO is supported + + let Ok(socket) = Socket::new(Domain::IPV4, Type::DGRAM, None) else { + return 1; + }; + + let mut gso_size: c_uint = 1500; + + // SAFETY: We're correctly passing an *mut c_uint specifying the size, a valid socket, and + // its correct size. + let result = unsafe { + libc::setsockopt( + socket.as_raw_socket() as libc::SOCKET, + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + &mut gso_size as *mut _ as *mut _, + i32::try_from(std::mem::size_of_val(&gso_size)).unwrap(), + ) + }; + + // If non-zero (error), set max segment count to 1. Otherwise, set it to 512. + // 512 is the "empirically found" value also used by quinn + match result { + 0 => 512, + _ => 1, + } + }); +} |
