summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-10 18:43:12 +0200
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-10 18:43:12 +0200
commitaffa09e9901fca9dc4a8120c4d339e41a65a1de3 (patch)
tree5ddf69060af30b6a7a78ddf3ce47eceec958d917
parentc870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2 (diff)
parent4846498154e58fa06284f49746ed2f0c014f705e (diff)
downloadmullvadvpn-affa09e9901fca9dc4a8120c4d339e41a65a1de3.tar.xz
mullvadvpn-affa09e9901fca9dc4a8120c4d339e41a65a1de3.zip
Merge branch 'add-masque-server-multithreading'
-rw-r--r--mullvad-masque-proxy/Cargo.toml2
-rw-r--r--mullvad-masque-proxy/src/client/mod.rs4
-rw-r--r--mullvad-masque-proxy/src/fragment.rs20
-rw-r--r--mullvad-masque-proxy/src/lib.rs8
-rw-r--r--mullvad-masque-proxy/src/server/mod.rs226
5 files changed, 157 insertions, 103 deletions
diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml
index 27b1a9ca1f..232a83ca6d 100644
--- a/mullvad-masque-proxy/Cargo.toml
+++ b/mullvad-masque-proxy/Cargo.toml
@@ -19,9 +19,9 @@ http = "1"
rustls = { version = "0.23", default-features = false }
rustls-pemfile = "2.1.3"
bytes = "1"
+anyhow = { workspace = true }
[dev-dependencies]
-anyhow = { workspace = true }
tokio = { workspace = true, features = ["fs", "macros", "io-util", "rt-multi-thread"] }
clap = { workspace = true }
rand = "0.8.5"
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs
index df65339c03..a0041c17d2 100644
--- a/mullvad-masque-proxy/src/client/mod.rs
+++ b/mullvad-masque-proxy/src/client/mod.rs
@@ -24,15 +24,13 @@ use crate::{
compute_udp_payload_size,
fragment::{self, Fragments},
stats::Stats,
- MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
+ MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
};
const MAX_HEADER_SIZE: u64 = 8192;
const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem");
-const MAX_INFLIGHT_PACKETS: usize = 100;
-
pub struct Client {
client_socket: Arc<UdpSocket>,
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs
index 3d9fc94273..94aafad574 100644
--- a/mullvad-masque-proxy/src/fragment.rs
+++ b/mullvad-masque-proxy/src/fragment.rs
@@ -1,7 +1,4 @@
-use std::{
- collections::{BTreeMap, VecDeque},
- time::{Duration, Instant},
-};
+use std::collections::{BTreeMap, VecDeque};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use h3::proto::varint::VarInt;
@@ -79,11 +76,7 @@ impl Fragments {
let fragment_count = payload
.try_get_u8()
.map_err(|_| DefragError::PayloadTooSmall)?;
- let fragment = Fragment {
- index,
- payload,
- time_received: Instant::now(),
- };
+ let fragment = Fragment { index, payload };
// ensure that the fifo has capacity before pushing the new fragment id
if self.fragment_index_fifo.len() >= FRAGMENT_BUFFER_CAP {
@@ -140,20 +133,11 @@ impl Fragments {
Some(payload)
}
-
- pub fn clear_old_fragments(&mut self, max_age: Duration) {
- self.fragment_map.retain(|_, fragments| {
- fragments
- .iter()
- .any(|fragment| fragment.time_received.elapsed() <= max_age)
- });
- }
}
struct Fragment {
index: u8,
payload: Bytes,
- time_received: Instant,
}
/// Fragment packet using the given maximum fragment size (including headers).
diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs
index 5e9c2902b0..d8467790ec 100644
--- a/mullvad-masque-proxy/src/lib.rs
+++ b/mullvad-masque-proxy/src/lib.rs
@@ -6,10 +6,16 @@ mod fragment;
pub mod server;
mod stats;
-const PACKET_BUFFER_SIZE: usize = 64 * 1024;
pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0);
pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1);
+/// Minimum size of buffer used to hold UDP packets.
+// 1 byte for size of HTTP_MASQUE_DATAGRAM_CONTEXT_ID
+const PACKET_BUFFER_SIZE: usize = (u16::MAX - UDP_HEADER_SIZE + 1) as usize;
+
+/// Maximum number of inflight packets, in both directions.
+const MAX_INFLIGHT_PACKETS: usize = 100;
+
/// Fragment headers size for fragmented packets
const FRAGMENT_HEADER_SIZE_FRAGMENTED: u16 = 5;
diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs
index a57282d172..4644ff5215 100644
--- a/mullvad-masque-proxy/src/server/mod.rs
+++ b/mullvad-masque-proxy/src/server/mod.rs
@@ -3,9 +3,9 @@ use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
- time::Duration,
};
+use anyhow::{ensure, Context};
use bytes::{Bytes, BytesMut};
use h3::{
proto::varint::VarInt,
@@ -15,12 +15,12 @@ use h3::{
use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt};
use http::{Request, StatusCode};
use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming};
-use tokio::{net::UdpSocket, time::interval};
+use tokio::{net::UdpSocket, select, sync::mpsc, task};
use crate::{
compute_udp_payload_size,
fragment::{self, Fragments},
- MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
+ MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
};
#[derive(Debug, thiserror::Error)]
@@ -151,7 +151,7 @@ impl Server {
}
async fn handle_proxy_request<T: BidiStream<Bytes>>(
- mut connection: Connection<h3_quinn::Connection, Bytes>,
+ connection: Connection<h3_quinn::Connection, Bytes>,
quinn_conn: quinn::Connection,
request: Request<()>,
mut stream: RequestStream<T, Bytes>,
@@ -178,99 +178,165 @@ impl Server {
return;
}
- let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr);
-
let stream_id = stream.id();
- let stream_id_size = VarInt::from(stream_id).size() as u16;
- let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE);
+ let udp_socket = Arc::new(udp_socket);
+ let (client_tx, client_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS);
+ let (send_tx, send_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS);
- let mut fragments = Fragments::default();
- let mut fragment_id = 0u16;
+ let mut connection_task =
+ task::spawn(connection_task(stream_id, connection, send_rx, client_tx));
+ let mut proxy_rx_task = task::spawn(proxy_rx_task(
+ stream_id,
+ quinn_conn,
+ target_addr,
+ mtu,
+ Arc::clone(&udp_socket),
+ send_tx,
+ ));
+ let mut proxy_tx_task = task::spawn(proxy_tx_task(udp_socket, client_rx));
- let mut interval = interval(Duration::from_secs(3));
- crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf);
+ select! {
+ _ = &mut connection_task => {}
+ _ = &mut proxy_rx_task => {}
+ _ = &mut proxy_tx_task => {}
+ }
+
+ connection_task.abort();
+ proxy_rx_task.abort();
+ proxy_tx_task.abort();
+
+ // TODO: stream.finish()?
+ }
+}
+
+/// Forward packets from `send_rx` to `connection`, and from `connection` to `client_tx`.
+async fn connection_task(
+ stream_id: StreamId,
+ mut connection: Connection<h3_quinn::Connection, Bytes>,
+ mut send_rx: mpsc::Receiver<Bytes>,
+ client_tx: mpsc::Sender<Datagram>,
+) -> anyhow::Result<()> {
+ loop {
+ tokio::select! {
+ outgoing_packet = send_rx.recv() => {
+ let Some(outgoing_packet) = outgoing_packet else {
+ break; // sender is gone
+ };
- loop {
- tokio::select! {
- client_send = connection.read_datagram() => {
- match client_send {
- Ok(Some(received_packet)) => {
- handle_client_packet(received_packet, stream_id, &mut fragments, &udp_socket).await;
- },
- Ok(None) => {
- return;
- }
- Err(_err) => {
- // client connection QUIC connection failed, should return now.
- return;
- },
+ // TODO: is this blocking?
+ connection.send_datagram(stream_id, outgoing_packet)
+ .context("Error sending QUIC datagram to client")?;
+ }
+ incoming_packet = connection.read_datagram() => match incoming_packet {
+ Ok(Some(received_packet)) => {
+ ensure!(
+ received_packet.stream_id() == stream_id,
+ "Received unexpected stream ID from client",
+ );
+
+ if client_tx.send(received_packet).await.is_err() {
+ break; // receiver is gone
}
- },
- recv_result = udp_socket.recv_buf_from(&mut proxy_recv_buf) => {
- match recv_result {
- Ok((_bytes_received, sender_addr)) => {
- if sender_addr != target_addr {
- continue
- }
+ }
+ Ok(None) => break, // EOF
+ Err(err) => {
+ return Err(err).context("Error reading QUIC datagram from client");
+ }
+ },
+ }
+ }
- let mut received_packet = proxy_recv_buf.split().freeze();
+ Ok(())
+}
- // Maximum QUIC payload (including fragmentation headers)
- let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() {
- max_datagram_size as u16 - stream_id_size
- } else {
- max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size
- };
+/// Reassemble and forward packet fragments from `client_rx` to `udp_socket`.
+async fn proxy_tx_task(udp_socket: impl AsRef<UdpSocket>, mut client_rx: mpsc::Receiver<Datagram>) {
+ let udp_socket = udp_socket.as_ref();
+ let mut fragments = Fragments::default();
+ loop {
+ let Some(quic_datagram) = client_rx.recv().await else {
+ break;
+ };
- if received_packet.len() <= usize::from(maximum_packet_size) {
- if connection.send_datagram(stream_id, received_packet).is_err() {
- return;
- }
- } else {
- let _ = VarInt::decode(&mut received_packet);
+ let quic_payload = quic_datagram.into_payload();
- let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; };
- fragment_id += 1;
- for payload in fragments {
- if connection.send_datagram(stream_id, payload).is_err() {
- return;
- }
- }
- };
+ let packet = match fragments.handle_incoming_packet(quic_payload) {
+ Ok(Some(packet)) => packet,
+ Ok(None) => continue,
+ Err(_defrag_err) => {
+ // TODO: log::trace!()
+ continue;
+ }
+ };
- proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE);
- crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf);
- },
- Err(err) => {
- println!("Failed to receive packet from proxy connection: {err}");
- let _ = stream.finish().await;
- return;
- }
- }
- },
- _ = interval.tick() => {
- fragments.clear_old_fragments(
- Duration::from_secs(3)
- );
- },
- };
+ if let Err(_err) = udp_socket.send(&packet).await {
+ // TODO: log::trace!()
}
}
}
-async fn handle_client_packet(
- received_packet: Datagram,
+/// Forward packets from `udp_socket` to `send_tx`, and fragment them if they exceed
+/// `maximum_packet_size`.
+async fn proxy_rx_task(
stream_id: StreamId,
- fragments: &mut Fragments,
- proxy_socket: &UdpSocket,
+ quinn_conn: quinn::Connection,
+ target_addr: SocketAddr,
+ mtu: u16,
+ udp_socket: impl AsRef<UdpSocket>,
+ send_tx: mpsc::Sender<Bytes>,
) {
- if received_packet.stream_id() != stream_id {
- // log::trace!("Received unexpected stream ID from server");
- return;
- }
+ let stream_id_size = VarInt::from(stream_id).size() as u16;
+ let udp_socket = udp_socket.as_ref();
+ let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE);
+ let mut fragment_id = 0u16;
+
+ loop {
+ proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE);
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf);
+
+ let (_n, sender_addr) = match udp_socket.recv_buf_from(&mut proxy_recv_buf).await {
+ Ok(recv) => recv,
+ Err(err) => {
+ println!("Failed to receive packet from proxy socket: {err}");
+ continue;
+ }
+ };
+
+ if sender_addr != target_addr {
+ continue;
+ }
- if let Ok(Some(payload)) = fragments.handle_incoming_packet(received_packet.into_payload()) {
- let _ = proxy_socket.send(&payload).await;
+ let mut received_packet = proxy_recv_buf.split().freeze();
+
+ let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr);
+
+ // Maximum QUIC payload (including fragmentation headers)
+ let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() {
+ max_datagram_size as u16 - stream_id_size
+ } else {
+ max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size
+ };
+
+ if received_packet.len() < usize::from(maximum_packet_size) {
+ if send_tx.send(received_packet).await.is_err() {
+ break;
+ };
+ } else {
+ // TODO: consider fragmenting packets on a different task
+
+ let _ = VarInt::decode(&mut received_packet);
+ let Ok(fragments) =
+ fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id)
+ else {
+ continue;
+ };
+ fragment_id += 1;
+ for payload in fragments {
+ if send_tx.send(payload).await.is_err() {
+ break;
+ }
+ }
+ };
}
}