diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-10 18:43:12 +0200 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-10 18:43:12 +0200 |
| commit | affa09e9901fca9dc4a8120c4d339e41a65a1de3 (patch) | |
| tree | 5ddf69060af30b6a7a78ddf3ce47eceec958d917 | |
| parent | c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2 (diff) | |
| parent | 4846498154e58fa06284f49746ed2f0c014f705e (diff) | |
| download | mullvadvpn-affa09e9901fca9dc4a8120c4d339e41a65a1de3.tar.xz mullvadvpn-affa09e9901fca9dc4a8120c4d339e41a65a1de3.zip | |
Merge branch 'add-masque-server-multithreading'
| -rw-r--r-- | mullvad-masque-proxy/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 4 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/fragment.rs | 20 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/lib.rs | 8 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 226 |
5 files changed, 157 insertions, 103 deletions
diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml index 27b1a9ca1f..232a83ca6d 100644 --- a/mullvad-masque-proxy/Cargo.toml +++ b/mullvad-masque-proxy/Cargo.toml @@ -19,9 +19,9 @@ http = "1" rustls = { version = "0.23", default-features = false } rustls-pemfile = "2.1.3" bytes = "1" +anyhow = { workspace = true } [dev-dependencies] -anyhow = { workspace = true } tokio = { workspace = true, features = ["fs", "macros", "io-util", "rt-multi-thread"] } clap = { workspace = true } rand = "0.8.5" diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index df65339c03..a0041c17d2 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -24,15 +24,13 @@ use crate::{ compute_udp_payload_size, fragment::{self, Fragments}, stats::Stats, - MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, + MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, }; const MAX_HEADER_SIZE: u64 = 8192; const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem"); -const MAX_INFLIGHT_PACKETS: usize = 100; - pub struct Client { client_socket: Arc<UdpSocket>, diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs index 3d9fc94273..94aafad574 100644 --- a/mullvad-masque-proxy/src/fragment.rs +++ b/mullvad-masque-proxy/src/fragment.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{BTreeMap, VecDeque}, - time::{Duration, Instant}, -}; +use std::collections::{BTreeMap, VecDeque}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use h3::proto::varint::VarInt; @@ -79,11 +76,7 @@ impl Fragments { let fragment_count = payload .try_get_u8() .map_err(|_| DefragError::PayloadTooSmall)?; - let fragment = Fragment { - index, - payload, - time_received: Instant::now(), - }; + let fragment = Fragment { index, payload }; // ensure that the fifo has capacity before pushing the new fragment id if self.fragment_index_fifo.len() >= FRAGMENT_BUFFER_CAP { @@ -140,20 +133,11 @@ impl Fragments { Some(payload) } - - pub fn clear_old_fragments(&mut self, max_age: Duration) { - self.fragment_map.retain(|_, fragments| { - fragments - .iter() - .any(|fragment| fragment.time_received.elapsed() <= max_age) - }); - } } struct Fragment { index: u8, payload: Bytes, - time_received: Instant, } /// Fragment packet using the given maximum fragment size (including headers). diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs index 5e9c2902b0..d8467790ec 100644 --- a/mullvad-masque-proxy/src/lib.rs +++ b/mullvad-masque-proxy/src/lib.rs @@ -6,10 +6,16 @@ mod fragment; pub mod server; mod stats; -const PACKET_BUFFER_SIZE: usize = 64 * 1024; pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0); pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1); +/// Minimum size of buffer used to hold UDP packets. +// 1 byte for size of HTTP_MASQUE_DATAGRAM_CONTEXT_ID +const PACKET_BUFFER_SIZE: usize = (u16::MAX - UDP_HEADER_SIZE + 1) as usize; + +/// Maximum number of inflight packets, in both directions. +const MAX_INFLIGHT_PACKETS: usize = 100; + /// Fragment headers size for fragmented packets const FRAGMENT_HEADER_SIZE_FRAGMENTED: u16 = 5; diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs index a57282d172..4644ff5215 100644 --- a/mullvad-masque-proxy/src/server/mod.rs +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -3,9 +3,9 @@ use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, - time::Duration, }; +use anyhow::{ensure, Context}; use bytes::{Bytes, BytesMut}; use h3::{ proto::varint::VarInt, @@ -15,12 +15,12 @@ use h3::{ use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; use http::{Request, StatusCode}; use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming}; -use tokio::{net::UdpSocket, time::interval}; +use tokio::{net::UdpSocket, select, sync::mpsc, task}; use crate::{ compute_udp_payload_size, fragment::{self, Fragments}, - MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, + MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, }; #[derive(Debug, thiserror::Error)] @@ -151,7 +151,7 @@ impl Server { } async fn handle_proxy_request<T: BidiStream<Bytes>>( - mut connection: Connection<h3_quinn::Connection, Bytes>, + connection: Connection<h3_quinn::Connection, Bytes>, quinn_conn: quinn::Connection, request: Request<()>, mut stream: RequestStream<T, Bytes>, @@ -178,99 +178,165 @@ impl Server { return; } - let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr); - let stream_id = stream.id(); - let stream_id_size = VarInt::from(stream_id).size() as u16; - let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); + let udp_socket = Arc::new(udp_socket); + let (client_tx, client_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); + let (send_tx, send_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); - let mut fragments = Fragments::default(); - let mut fragment_id = 0u16; + let mut connection_task = + task::spawn(connection_task(stream_id, connection, send_rx, client_tx)); + let mut proxy_rx_task = task::spawn(proxy_rx_task( + stream_id, + quinn_conn, + target_addr, + mtu, + Arc::clone(&udp_socket), + send_tx, + )); + let mut proxy_tx_task = task::spawn(proxy_tx_task(udp_socket, client_rx)); - let mut interval = interval(Duration::from_secs(3)); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); + select! { + _ = &mut connection_task => {} + _ = &mut proxy_rx_task => {} + _ = &mut proxy_tx_task => {} + } + + connection_task.abort(); + proxy_rx_task.abort(); + proxy_tx_task.abort(); + + // TODO: stream.finish()? + } +} + +/// Forward packets from `send_rx` to `connection`, and from `connection` to `client_tx`. +async fn connection_task( + stream_id: StreamId, + mut connection: Connection<h3_quinn::Connection, Bytes>, + mut send_rx: mpsc::Receiver<Bytes>, + client_tx: mpsc::Sender<Datagram>, +) -> anyhow::Result<()> { + loop { + tokio::select! { + outgoing_packet = send_rx.recv() => { + let Some(outgoing_packet) = outgoing_packet else { + break; // sender is gone + }; - loop { - tokio::select! { - client_send = connection.read_datagram() => { - match client_send { - Ok(Some(received_packet)) => { - handle_client_packet(received_packet, stream_id, &mut fragments, &udp_socket).await; - }, - Ok(None) => { - return; - } - Err(_err) => { - // client connection QUIC connection failed, should return now. - return; - }, + // TODO: is this blocking? + connection.send_datagram(stream_id, outgoing_packet) + .context("Error sending QUIC datagram to client")?; + } + incoming_packet = connection.read_datagram() => match incoming_packet { + Ok(Some(received_packet)) => { + ensure!( + received_packet.stream_id() == stream_id, + "Received unexpected stream ID from client", + ); + + if client_tx.send(received_packet).await.is_err() { + break; // receiver is gone } - }, - recv_result = udp_socket.recv_buf_from(&mut proxy_recv_buf) => { - match recv_result { - Ok((_bytes_received, sender_addr)) => { - if sender_addr != target_addr { - continue - } + } + Ok(None) => break, // EOF + Err(err) => { + return Err(err).context("Error reading QUIC datagram from client"); + } + }, + } + } - let mut received_packet = proxy_recv_buf.split().freeze(); + Ok(()) +} - // Maximum QUIC payload (including fragmentation headers) - let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() { - max_datagram_size as u16 - stream_id_size - } else { - max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size - }; +/// Reassemble and forward packet fragments from `client_rx` to `udp_socket`. +async fn proxy_tx_task(udp_socket: impl AsRef<UdpSocket>, mut client_rx: mpsc::Receiver<Datagram>) { + let udp_socket = udp_socket.as_ref(); + let mut fragments = Fragments::default(); + loop { + let Some(quic_datagram) = client_rx.recv().await else { + break; + }; - if received_packet.len() <= usize::from(maximum_packet_size) { - if connection.send_datagram(stream_id, received_packet).is_err() { - return; - } - } else { - let _ = VarInt::decode(&mut received_packet); + let quic_payload = quic_datagram.into_payload(); - let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; }; - fragment_id += 1; - for payload in fragments { - if connection.send_datagram(stream_id, payload).is_err() { - return; - } - } - }; + let packet = match fragments.handle_incoming_packet(quic_payload) { + Ok(Some(packet)) => packet, + Ok(None) => continue, + Err(_defrag_err) => { + // TODO: log::trace!() + continue; + } + }; - proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); - }, - Err(err) => { - println!("Failed to receive packet from proxy connection: {err}"); - let _ = stream.finish().await; - return; - } - } - }, - _ = interval.tick() => { - fragments.clear_old_fragments( - Duration::from_secs(3) - ); - }, - }; + if let Err(_err) = udp_socket.send(&packet).await { + // TODO: log::trace!() } } } -async fn handle_client_packet( - received_packet: Datagram, +/// Forward packets from `udp_socket` to `send_tx`, and fragment them if they exceed +/// `maximum_packet_size`. +async fn proxy_rx_task( stream_id: StreamId, - fragments: &mut Fragments, - proxy_socket: &UdpSocket, + quinn_conn: quinn::Connection, + target_addr: SocketAddr, + mtu: u16, + udp_socket: impl AsRef<UdpSocket>, + send_tx: mpsc::Sender<Bytes>, ) { - if received_packet.stream_id() != stream_id { - // log::trace!("Received unexpected stream ID from server"); - return; - } + let stream_id_size = VarInt::from(stream_id).size() as u16; + let udp_socket = udp_socket.as_ref(); + let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); + let mut fragment_id = 0u16; + + loop { + proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE); + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); + + let (_n, sender_addr) = match udp_socket.recv_buf_from(&mut proxy_recv_buf).await { + Ok(recv) => recv, + Err(err) => { + println!("Failed to receive packet from proxy socket: {err}"); + continue; + } + }; + + if sender_addr != target_addr { + continue; + } - if let Ok(Some(payload)) = fragments.handle_incoming_packet(received_packet.into_payload()) { - let _ = proxy_socket.send(&payload).await; + let mut received_packet = proxy_recv_buf.split().freeze(); + + let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr); + + // Maximum QUIC payload (including fragmentation headers) + let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() { + max_datagram_size as u16 - stream_id_size + } else { + max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size + }; + + if received_packet.len() < usize::from(maximum_packet_size) { + if send_tx.send(received_packet).await.is_err() { + break; + }; + } else { + // TODO: consider fragmenting packets on a different task + + let _ = VarInt::decode(&mut received_packet); + let Ok(fragments) = + fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) + else { + continue; + }; + fragment_id += 1; + for payload in fragments { + if send_tx.send(payload).await.is_err() { + break; + } + } + }; } } |
