diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-09-04 09:53:02 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-09-04 09:53:02 +0200 |
| commit | f3bf9e00303eee55e2293bd39057e166938879d1 (patch) | |
| tree | 522317866e9565ef4d3972cd84c841bb5aeb2ec9 | |
| parent | a7c27135483ab512b155ed8ce11d39f337585826 (diff) | |
| parent | 359ffc366907d93f2e3fb3f2640f9309348d1835 (diff) | |
| download | mullvadvpn-f3bf9e00303eee55e2293bd39057e166938879d1.tar.xz mullvadvpn-f3bf9e00303eee55e2293bd39057e166938879d1.zip | |
Merge branch 'masque-win-gso'
| -rw-r--r-- | CHANGELOG.md | 4 | ||||
| -rw-r--r-- | Cargo.lock | 2 | ||||
| -rw-r--r-- | mullvad-masque-proxy/Cargo.toml | 5 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 289 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/lib.rs | 4 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 9 |
6 files changed, 293 insertions, 20 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 955b6ed7c4..bad6478441 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,10 @@ Line wrap the file at 100 chars. Th #### Windows - Add additional logging for tunnel devices and split tunneling to problem reports. +### Changed +#### Windows +- Implement UDP GSO for QUIC on client socket. This improves download speeds slightly. + ### Security #### Windows - Block traffic to exit node from non-Mullvad processes. This fixes a leak where traffic could be diff --git a/Cargo.lock b/Cargo.lock index 0578bec7af..a5e7e02c67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3284,6 +3284,7 @@ dependencies = [ "h3-datagram", "h3-quinn", "http 1.1.0", + "libc", "log", "quinn", "rand 0.8.5", @@ -3293,6 +3294,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "typed-builder 0.21.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml index 31bc379dd3..73a6452bd4 100644 --- a/mullvad-masque-proxy/Cargo.toml +++ b/mullvad-masque-proxy/Cargo.toml @@ -20,9 +20,14 @@ rustls = { version = "0.23", default-features = false } rustls-pemfile = "2.1.3" bytes = "1" anyhow = { workspace = true } +libc = "0.2" log = { workspace = true } typed-builder = "0.21.0" +[target.'cfg(windows)'.dependencies.windows-sys] +workspace = true +features = ["Win32", "Win32_Foundation", "Win32_Networking", "Win32_Networking_WinSock"] + [dev-dependencies] env_logger = { workspace = true } tokio = { workspace = true, features = ["fs", "macros", "io-util", "rt-multi-thread"] } diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 801249cfb2..62139a1ab8 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -1,5 +1,5 @@ use anyhow::{Context, anyhow}; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use rustls::client::danger::ServerCertVerified; use std::{ fs::{self}, @@ -355,11 +355,16 @@ impl Client { return_addr_tx, )); - let mut client_socket_tx_task = tokio::task::spawn(client_socket_tx_task( + let (send_tx, send_rx) = mpsc::channel::<(SocketAddr, Bytes)>(MAX_INFLIGHT_PACKETS); + + let mut client_socket_tx_task = + tokio::task::spawn(client_socket_tx_task(self.client_socket.clone(), send_rx)); + + let mut fragment_reassembly_task = tokio::task::spawn(fragment_reassembly_task( stream_id, server_rx, return_addr_rx, - self.client_socket.clone(), + send_tx, Arc::clone(&self.stats), )); @@ -375,11 +380,13 @@ impl Client { let result = select! { result = &mut client_socket_tx_task => result, + result = &mut fragment_reassembly_task => result, result = &mut client_socket_rx_task => result, result = &mut server_socket_task => result, }; client_socket_tx_task.abort(); + fragment_reassembly_task.abort(); client_socket_rx_task.abort(); server_socket_task.abort(); @@ -457,11 +464,16 @@ async fn client_socket_rx_task( client_tx: mpsc::Sender<Bytes>, return_addr_tx: broadcast::Sender<SocketAddr>, ) -> Result<()> { - let mut client_read_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); + const TOTAL_BUFFER_CAPACITY: usize = 100 * crate::MAX_UDP_SIZE; + + let mut client_read_buf = BytesMut::with_capacity(TOTAL_BUFFER_CAPACITY); let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); loop { - client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); + if !client_read_buf.try_reclaim(crate::MAX_UDP_SIZE) { + // Allocate space for new packets + client_read_buf.reserve(TOTAL_BUFFER_CAPACITY); + } // this is the variable ID used to signify UDP payloads in HTTP datagrams. crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); @@ -488,10 +500,149 @@ async fn client_socket_rx_task( } async fn client_socket_tx_task( + client_socket: Arc<UdpSocket>, + send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + #[cfg(target_os = "windows")] + if *windows::MAX_GSO_SEGMENTS > 1 { + log::debug!("UDP GSO enabled"); + return client_socket_gso_tx_task(client_socket, send_rx).await; + } + + log::debug!("UDP GSO disabled"); + + client_socket_non_gso_tx_task(client_socket, send_rx).await +} + +#[cfg(target_os = "windows")] +async fn client_socket_gso_tx_task( + client_socket: Arc<UdpSocket>, + mut send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + use bytes::Buf; + use std::{collections::VecDeque, mem}; + use tokio::io::Interest; + use windows::*; + use windows_sys::Win32::Networking::WinSock; + + const MAX_SEGMENT_SIZE: usize = 1500; + + let client_socket_ref = socket2::SockRef::from(&client_socket); + + let mut buffer = Vec::with_capacity(*MAX_GSO_SEGMENTS * MAX_SEGMENT_SIZE); + let mut cmsg_buf = Cmsg::new( + mem::size_of::<u32>(), + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + ); + let mut queued_packets = VecDeque::new(); + + loop { + // Fill up queue + if queued_packets.is_empty() { + let Some((dest, packet)) = send_rx.recv().await else { + break; + }; + queued_packets.push_back((dest, packet)); + } + while let Ok((dest, packet)) = send_rx.try_recv() { + queued_packets.push_back((dest, packet)); + } + + let (dest, packet) = queued_packets.pop_front().expect("never empty"); + + // If the queue is empty now, send a single packet using send_to + if queued_packets.is_empty() { + client_socket + .send_to(packet.chunk(), dest) + .await + .map_err(Error::ClientWrite)?; + continue; + } + + let segment_size = packet.len(); + buffer.clear(); + buffer.extend_from_slice(packet.chunk()); + + loop { + let Some((next_dest, next_packet)) = queued_packets.pop_front() else { + break; + }; + + // If the destination differs, stop coalescing packets + if next_dest != dest { + // Flush the buffer now and queue this packet + // This should occur rarely, as we're expecting a single UDP client + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // On overflow, also stop coalescing + if buffer.len() + next_packet.len() > buffer.capacity() { + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // If this packet is larger, we are done + if next_packet.len() > segment_size { + queued_packets.push_front((next_dest, next_packet)); + break; + } + + // Otherwise, append the next packet to the bunch + buffer.extend_from_slice(next_packet.chunk()); + + // The last packet may be smaller than previous segments: + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options + if next_packet.len() < segment_size { + break; + } + } + + client_socket + .async_io(Interest::WRITABLE, || { + use std::io::IoSlice; + + // Call sendmsg with one CMSG containing the segment size. + // This will send all packets in `buffer`. + + // SAFETY: We have allocated capacity for a u32. The data may contain that. + unsafe { *(cmsg_buf.data_mut_ptr() as *mut u32) = segment_size as u32 }; + + let io_slices = [IoSlice::new(&buffer); 1]; + let daddr = socket2::SockAddr::from(dest); + let msg_hdr = socket2::MsgHdr::new() + .with_addr(&daddr) + .with_buffers(&io_slices) + .with_control(cmsg_buf.as_slice()); + + client_socket_ref.sendmsg(&msg_hdr, 0) + }) + .await + .map_err(Error::ClientWrite)?; + } + + Ok(()) +} + +async fn client_socket_non_gso_tx_task( + client_socket: Arc<UdpSocket>, + mut send_rx: mpsc::Receiver<(SocketAddr, Bytes)>, +) -> Result<()> { + while let Some((dest, buf)) = send_rx.recv().await { + client_socket + .send_to(&buf, dest) + .await + .map_err(Error::ClientWrite)?; + } + Ok(()) +} + +async fn fragment_reassembly_task( stream_id: StreamId, mut server_rx: mpsc::Receiver<Datagram>, mut return_addr_rx: broadcast::Receiver<SocketAddr>, - client_socket: Arc<UdpSocket>, + send_tx: mpsc::Sender<(SocketAddr, Bytes)>, stats: Arc<Stats>, ) -> Result<()> { let mut fragments = Fragments::default(); @@ -522,22 +673,22 @@ async fn client_socket_tx_task( let payload = response.into_payload(); let original_payload_len = payload.len(); - let send = async |payload: &[u8]| -> Result<()> { - client_socket - .send_to(payload, return_addr) - .await - .map_err(Error::ClientWrite)?; - Ok(()) - }; - match fragments.handle_incoming_packet(payload) { Ok(DefragReceived::Nonfragmented(payload)) => { stats.rx(payload.len(), false); - send(payload.chunk()).await?; + if send_tx.send((return_addr, payload)).await.is_err() { + break; + } } Ok(DefragReceived::Reassembled(reassembled_payload)) => { stats.rx(original_payload_len, true); - send(reassembled_payload.chunk()).await?; + if send_tx + .send((return_addr, reassembled_payload)) + .await + .is_err() + { + break; + } } Ok(DefragReceived::Fragment) => stats.rx(original_payload_len, true), Err(e) => { @@ -702,3 +853,109 @@ impl rustls::client::danger::ServerCertVerifier for Approver { ] } } + +#[cfg(target_os = "windows")] +mod windows { + use socket2::{Domain, Socket, Type}; + use std::{ffi::c_uchar, mem, sync::LazyLock}; + use std::{ffi::c_uint, os::windows::io::AsRawSocket}; + use windows_sys::Win32::Networking::WinSock::{self, CMSGHDR}; + + /// Struct representing a CMSG + pub struct Cmsg { + buffer: Vec<u8>, + } + + impl Cmsg { + /// Create a new with space for `space` bytes and a CMSG header + pub fn new(space: usize, cmsg_level: i32, cmsg_type: i32) -> Self { + let mut self_ = Self { + buffer: vec![0u8; cmsg_space(space)], + }; + + *self_.header_mut() = CMSGHDR { + cmsg_len: cmsg_len(space), + cmsg_level, + cmsg_type, + }; + + self_ + } + + fn header_mut(&mut self) -> &mut CMSGHDR { + let hdr = self.buffer.as_mut_ptr() as *mut CMSGHDR; + debug_assert!(hdr.is_aligned()); + // SAFETY: `hdr` is aligned and points to an initialized `CMSGHDR` + unsafe { &mut *hdr } + } + + pub fn as_slice(&self) -> &[u8] { + &self.buffer[..] + } + + pub fn data_mut_ptr(&mut self) -> *mut u8 { + let header = self.header_mut(); + // SAFETY: The buffer is initialized using `cmsg_space`, so this points to actual data + // (but len may be 0) + unsafe { cmsg_data(header) } + } + } + + /// The total size of an ancillary data object given the amount of data + /// Source: ws2def.h: CMSG_SPACE macro + pub fn cmsg_space(length: usize) -> usize { + cmsgdata_align(mem::size_of::<CMSGHDR>() + cmsghdr_align(length)) + } + + /// Value to store in the `cmsg_len` of the CMSG header given an amount of data. + /// Source: ws2def.h: CMSG_LEN macro + pub fn cmsg_len(length: usize) -> usize { + cmsgdata_align(mem::size_of::<CMSGHDR>()) + length + } + + /// Pointer to the first byte of data in `cmsg`. + /// Source: ws2def.h: CMSG_DATA macro + pub unsafe fn cmsg_data(cmsg: *mut CMSGHDR) -> *mut c_uchar { + (cmsg as usize + cmsgdata_align(mem::size_of::<CMSGHDR>())) as *mut c_uchar + } + + // Taken from ws2def.h: CMSGHDR_ALIGN macro + pub fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::<WinSock::CMSGHDR>() - 1) + & !(mem::align_of::<WinSock::CMSGHDR>() - 1) + } + + // Source: ws2def.h: CMSGDATA_ALIGN macro + pub fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1) + } + + pub static MAX_GSO_SEGMENTS: LazyLock<usize> = LazyLock::new(|| { + // Detect whether UDP GSO is supported + + let Ok(socket) = Socket::new(Domain::IPV4, Type::DGRAM, None) else { + return 1; + }; + + let mut gso_size: c_uint = 1500; + + // SAFETY: We're correctly passing an *mut c_uint specifying the size, a valid socket, and + // its correct size. + let result = unsafe { + libc::setsockopt( + socket.as_raw_socket() as libc::SOCKET, + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + &mut gso_size as *mut _ as *mut _, + i32::try_from(std::mem::size_of_val(&gso_size)).unwrap(), + ) + }; + + // If non-zero (error), set max segment count to 1. Otherwise, set it to 512. + // 512 is the "empirically found" value also used by quinn + match result { + 0 => 512, + _ => 1, + } + }); +} diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs index 37f8353acd..2c7b64f900 100644 --- a/mullvad-masque-proxy/src/lib.rs +++ b/mullvad-masque-proxy/src/lib.rs @@ -11,9 +11,9 @@ pub const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/"; pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0); pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1); -/// Minimum size of buffer used to hold UDP packets. +/// Maximum possible buffer size UDP packets, plus context ID. // 1 byte for size of HTTP_MASQUE_DATAGRAM_CONTEXT_ID -const PACKET_BUFFER_SIZE: usize = (u16::MAX - UDP_HEADER_SIZE + 1) as usize; +const MAX_UDP_SIZE: usize = (u16::MAX - UDP_HEADER_SIZE + 1) as usize; /// Maximum number of inflight packets, in both directions. const MAX_INFLIGHT_PACKETS: usize = 100; diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs index 3788ca7c80..2232fbdae2 100644 --- a/mullvad-masque-proxy/src/server/mod.rs +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -337,13 +337,18 @@ async fn proxy_rx_task( udp_socket: impl AsRef<UdpSocket>, send_tx: mpsc::Sender<Bytes>, ) { + const TOTAL_BUFFER_CAPACITY: usize = 100 * crate::MAX_UDP_SIZE; + let stream_id_size = VarInt::from(stream_id).size() as u16; let udp_socket = udp_socket.as_ref(); - let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); + let mut proxy_recv_buf = BytesMut::with_capacity(TOTAL_BUFFER_CAPACITY); let mut fragment_id = 0u16; loop { - proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE); + if !proxy_recv_buf.try_reclaim(crate::MAX_UDP_SIZE) { + // Allocate space for new packets + proxy_recv_buf.reserve(TOTAL_BUFFER_CAPACITY); + } crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); let (_n, sender_addr) = match udp_socket.recv_buf_from(&mut proxy_recv_buf).await { |
