summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-09 18:12:55 +0200
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-09 18:12:55 +0200
commita64a885ae56d39ce4213077283409891d5aa2e94 (patch)
tree989cc27922c2fc997e98a7a28f694e9c3ab067ba
parentb925b777c7cc3398abf707d8f9680d384a0d2d68 (diff)
parent476737ada224c69aca0554439ec971c4f7b9be9b (diff)
downloadmullvadvpn-a64a885ae56d39ce4213077283409891d5aa2e94.tar.xz
mullvadvpn-a64a885ae56d39ce4213077283409891d5aa2e94.zip
Merge branch 'masque-client-multithreading'
-rw-r--r--mullvad-masque-proxy/examples/masque-client.rs2
-rw-r--r--mullvad-masque-proxy/src/client/mod.rs262
-rw-r--r--mullvad-masque-proxy/src/lib.rs1
-rw-r--r--mullvad-masque-proxy/src/stats.rs46
4 files changed, 239 insertions, 72 deletions
diff --git a/mullvad-masque-proxy/examples/masque-client.rs b/mullvad-masque-proxy/examples/masque-client.rs
index 6005e1f5f4..893c0e7ffc 100644
--- a/mullvad-masque-proxy/examples/masque-client.rs
+++ b/mullvad-masque-proxy/examples/masque-client.rs
@@ -26,7 +26,7 @@ pub struct ClientArgs {
#[arg(long, short = 'p', default_value = "0")]
bind_port: u16,
- #[arg(long, short = 'S', default_value = "1000")]
+ #[arg(long, short = 'S', default_value = "1280")]
maximum_packet_size: u16,
}
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs
index 182ec580ba..1a93f29ea3 100644
--- a/mullvad-masque-proxy/src/client/mod.rs
+++ b/mullvad-masque-proxy/src/client/mod.rs
@@ -1,39 +1,51 @@
-use bytes::{Buf, BytesMut};
+use bytes::{Buf, Bytes, BytesMut};
use rustls::client::danger::ServerCertVerified;
use std::{
fs, future, io,
net::{Ipv4Addr, SocketAddr},
path::Path,
sync::{Arc, LazyLock},
- time::Duration,
};
-use tokio::{net::UdpSocket, time::interval};
+use tokio::{
+ net::UdpSocket,
+ select,
+ sync::{broadcast, mpsc},
+};
use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId};
-use h3_datagram::datagram_traits::HandleDatagramsExt;
+use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt};
use http::{header, uri::Scheme, Response, StatusCode};
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, TransportConfig};
-use crate::fragment::{self, Fragments};
+use crate::{
+ fragment::{self, Fragments},
+ stats::Stats,
+};
const MAX_HEADER_SIZE: u64 = 8192;
const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem");
+const MAX_INFLIGHT_PACKETS: usize = 100;
+
pub struct Client {
- client_socket: UdpSocket,
+ client_socket: Arc<UdpSocket>,
+
/// QUIC connection, used to send the actual HTTP datagrams
connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>,
+
/// Send stream over a QUIC connection - this needs to be kept alive to not close the HTTP
/// QUIC stream.
_send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
+
/// Request stream for the currently open request, must not be dropped, otherwise proxy
/// connection is terminated
request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
- /// Packet fragments
- fragments: Fragments,
+
/// Maximum packet size
maximum_packet_size: u16,
+
+ stats: Arc<Stats>,
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -150,11 +162,11 @@ impl Client {
Ok(Self {
connection,
- client_socket,
+ client_socket: Arc::new(client_socket),
request_stream,
- fragments: Fragments::default(),
_send_stream: send_stream,
maximum_packet_size,
+ stats: Arc::default(),
})
}
@@ -200,76 +212,184 @@ impl Client {
}
}
- pub async fn run(mut self) -> Result<()> {
+ pub async fn run(self) -> Result<()> {
let stream_id: StreamId = self.request_stream.id();
- // this is the variable ID used to signify UDP payloads in HTTP datagrams.
- let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024);
- crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf);
- let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
- let mut fragment_id = 1u16;
- let mut interval = interval(Duration::from_secs(3));
+ let (client_tx, client_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS);
+ let (server_tx, server_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS);
+ let (return_addr_tx, return_addr_rx) = broadcast::channel(1);
- loop {
- tokio::select! {
- client_read = self.client_socket.recv_buf_from(&mut client_read_buf) => {
- let (_bytes_received, recv_addr) = client_read.map_err(Error::ClientRead)?;
- return_addr = recv_addr;
+ let mut client_socket_rx_task = tokio::task::spawn(client_socket_rx_task(
+ self.client_socket.clone(),
+ client_tx,
+ return_addr_tx,
+ ));
- let mut send_buf = client_read_buf.split().freeze();
- if send_buf.len() < (Into::<usize>::into(self.maximum_packet_size) - 100usize) {
- self.connection
- .send_datagram(stream_id, send_buf)
- .map_err(Error::SendDatagram)?;
- } else {
- // drop the added context ID, since packet will have to be fragmented.
- {
- let _ = VarInt::decode(&mut send_buf);
- }
- for fragment in fragment::fragment_packet(
- self.maximum_packet_size,
- &mut send_buf,
- fragment_id)
- ? {
- self.connection.send_datagram(stream_id, fragment).map_err(Error::SendDatagram)?;
- }
- fragment_id = fragment_id.wrapping_add(1);
- }
+ let mut client_socket_tx_task = tokio::task::spawn(client_socket_tx_task(
+ stream_id,
+ server_rx,
+ return_addr_rx,
+ self.client_socket.clone(),
+ Arc::clone(&self.stats),
+ ));
- client_read_buf.reserve(crate::PACKET_BUFFER_SIZE);
- crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf);
- },
- server_response = self.connection.read_datagram() => {
- match server_response {
- Ok(Some(response)) => {
- if response.stream_id() != stream_id {
- // log::trace!("Received datagram with an unexpected stream ID");
- continue;
- }
- let payload = response.into_payload();
- if let Ok(Some(payload)) = self.fragments.handle_incoming_packet(payload) {
- self.client_socket
- .send_to(payload.chunk(), return_addr)
- .await
- .map_err(Error::ClientWrite)?;
- }
- }
- Ok(None) => {
- return Ok(());
- }
- Err(err) => {
- return Err(Error::ProxyResponse(err));
+ let mut server_socket_task = tokio::task::spawn(server_socket_task(
+ stream_id,
+ self.maximum_packet_size,
+ self.connection,
+ server_tx,
+ client_rx,
+ Arc::clone(&self.stats),
+ ));
+
+ let result = select! {
+ result = &mut client_socket_tx_task => result,
+ result = &mut client_socket_rx_task => result,
+ result = &mut server_socket_task => result,
+ };
+
+ client_socket_tx_task.abort();
+ client_socket_rx_task.abort();
+ server_socket_task.abort();
+
+ result.expect("proxy routine panicked")
+ }
+}
+
+async fn server_socket_task(
+ stream_id: StreamId,
+ maximum_packet_size: u16,
+ mut connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>,
+ server_tx: mpsc::Sender<Datagram>,
+ mut client_rx: mpsc::Receiver<Bytes>,
+ stats: Arc<Stats>,
+) -> Result<()> {
+ let mut fragment_id = 1u16;
+
+ loop {
+ let packet = select! {
+ datagram = connection.read_datagram() => {
+ match datagram {
+ Ok(Some(response)) => {
+ if server_tx.send(response).await.is_err() {
+ break;
}
}
- },
- _ = interval.tick() => {
- self.fragments.clear_old_fragments(
- Duration::from_secs(3)
- );
- },
- };
+ Ok(None) => break,
+ Err(err) => return Err(Error::ProxyResponse(err)),
+ }
+
+ continue;
+ }
+ packet = client_rx.recv() => packet,
+ };
+
+ let Some(mut packet) = packet else { break };
+
+ if packet.len() < (Into::<usize>::into(maximum_packet_size) - 100usize) {
+ stats.tx(packet.len(), false);
+ connection
+ .send_datagram(stream_id, packet)
+ .map_err(Error::SendDatagram)?;
+ } else {
+ // drop the added context ID, since packet will have to be fragmented.
+ let _ = VarInt::decode(&mut packet);
+
+ for fragment in fragment::fragment_packet(maximum_packet_size, &mut packet, fragment_id)
+ .map_err(Error::PacketTooLarge)?
+ {
+ stats.tx(fragment.len(), true);
+ connection
+ .send_datagram(stream_id, fragment)
+ .map_err(Error::SendDatagram)?;
+ }
+ fragment_id = fragment_id.wrapping_add(1);
}
}
+
+ Result::Ok(())
+}
+
+async fn client_socket_rx_task(
+ client_socket: Arc<UdpSocket>,
+ client_tx: mpsc::Sender<Bytes>,
+ return_addr_tx: broadcast::Sender<SocketAddr>,
+) -> Result<()> {
+ let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024);
+ let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
+
+ loop {
+ client_read_buf.reserve(crate::PACKET_BUFFER_SIZE);
+
+ // this is the variable ID used to signify UDP payloads in HTTP datagrams.
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf);
+
+ let (_bytes_received, recv_addr) = client_socket
+ .recv_buf_from(&mut client_read_buf)
+ .await
+ .map_err(Error::ClientRead)?;
+
+ if recv_addr != return_addr {
+ return_addr = recv_addr;
+ if return_addr_tx.send(return_addr).is_err() {
+ break;
+ }
+ }
+ let packet = client_read_buf.split().freeze();
+
+ if client_tx.send(packet).await.is_err() {
+ break;
+ };
+ }
+
+ Ok(())
+}
+
+async fn client_socket_tx_task(
+ stream_id: StreamId,
+ mut server_rx: mpsc::Receiver<Datagram>,
+ mut return_addr_rx: broadcast::Receiver<SocketAddr>,
+ client_socket: Arc<UdpSocket>,
+ stats: Arc<Stats>,
+) -> Result<()> {
+ let mut fragments = Fragments::default();
+
+ let mut return_addr = loop {
+ match return_addr_rx.recv().await {
+ Ok(addr) => break addr,
+ Err(broadcast::error::RecvError::Lagged(..)) => continue,
+ Err(broadcast::error::RecvError::Closed) => return Ok(()),
+ }
+ };
+
+ loop {
+ let Some(response) = server_rx.recv().await else {
+ break;
+ };
+
+ match return_addr_rx.try_recv() {
+ Ok(new_addr) => return_addr = new_addr,
+ Err(broadcast::error::TryRecvError::Empty) => {}
+ Err(..) => break,
+ }
+
+ if response.stream_id() != stream_id {
+ // log::trace!("Received datagram with an unexpected stream ID");
+ continue;
+ }
+ let payload = response.into_payload();
+
+ if let Ok(Some(payload)) = fragments.handle_incoming_packet(payload) {
+ stats.rx(payload.len(), false /* TODO */);
+
+ client_socket
+ .send_to(payload.chunk(), return_addr)
+ .await
+ .map_err(Error::ClientWrite)?;
+ }
+ }
+
+ Result::Ok(())
}
fn new_connect_request(
diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs
index bb973d3a80..d4a4e47812 100644
--- a/mullvad-masque-proxy/src/lib.rs
+++ b/mullvad-masque-proxy/src/lib.rs
@@ -3,6 +3,7 @@ use h3::proto::varint::VarInt;
pub mod client;
mod fragment;
pub mod server;
+mod stats;
const PACKET_BUFFER_SIZE: usize = 1700;
pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0);
diff --git a/mullvad-masque-proxy/src/stats.rs b/mullvad-masque-proxy/src/stats.rs
new file mode 100644
index 0000000000..412ddcc9bd
--- /dev/null
+++ b/mullvad-masque-proxy/src/stats.rs
@@ -0,0 +1,46 @@
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+#[derive(Debug, Default)]
+pub struct Stats {
+ rx_packets: AtomicUsize,
+ tx_packets: AtomicUsize,
+
+ rx_bytes: AtomicUsize,
+ tx_bytes: AtomicUsize,
+
+ fragmented_tx_bytes: AtomicUsize,
+ fragmented_rx_bytes: AtomicUsize,
+
+ fragmented_tx_packets: AtomicUsize,
+ fragmented_rx_packets: AtomicUsize,
+}
+
+const ORD: Ordering = Ordering::Relaxed;
+
+impl Drop for Stats {
+ fn drop(&mut self) {
+ println!("stats: {:?}", self);
+ }
+}
+
+impl Stats {
+ pub fn tx(&self, packet_len: usize, is_fragment: bool) {
+ self.tx_packets.fetch_add(1, ORD);
+ self.tx_bytes.fetch_add(packet_len, ORD);
+
+ if is_fragment {
+ self.fragmented_tx_packets.fetch_add(1, ORD);
+ self.fragmented_tx_bytes.fetch_add(packet_len, ORD);
+ }
+ }
+
+ pub fn rx(&self, packet_len: usize, is_fragment: bool) {
+ self.rx_packets.fetch_add(1, ORD);
+ self.rx_bytes.fetch_add(packet_len, ORD);
+
+ if is_fragment {
+ self.fragmented_rx_packets.fetch_add(1, ORD);
+ self.fragmented_rx_bytes.fetch_add(packet_len, ORD);
+ }
+ }
+}