diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-09 18:12:55 +0200 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-09 18:12:55 +0200 |
| commit | a64a885ae56d39ce4213077283409891d5aa2e94 (patch) | |
| tree | 989cc27922c2fc997e98a7a28f694e9c3ab067ba | |
| parent | b925b777c7cc3398abf707d8f9680d384a0d2d68 (diff) | |
| parent | 476737ada224c69aca0554439ec971c4f7b9be9b (diff) | |
| download | mullvadvpn-a64a885ae56d39ce4213077283409891d5aa2e94.tar.xz mullvadvpn-a64a885ae56d39ce4213077283409891d5aa2e94.zip | |
Merge branch 'masque-client-multithreading'
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-client.rs | 2 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 262 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/stats.rs | 46 |
4 files changed, 239 insertions, 72 deletions
diff --git a/mullvad-masque-proxy/examples/masque-client.rs b/mullvad-masque-proxy/examples/masque-client.rs index 6005e1f5f4..893c0e7ffc 100644 --- a/mullvad-masque-proxy/examples/masque-client.rs +++ b/mullvad-masque-proxy/examples/masque-client.rs @@ -26,7 +26,7 @@ pub struct ClientArgs { #[arg(long, short = 'p', default_value = "0")] bind_port: u16, - #[arg(long, short = 'S', default_value = "1000")] + #[arg(long, short = 'S', default_value = "1280")] maximum_packet_size: u16, } diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 182ec580ba..1a93f29ea3 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -1,39 +1,51 @@ -use bytes::{Buf, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use rustls::client::danger::ServerCertVerified; use std::{ fs, future, io, net::{Ipv4Addr, SocketAddr}, path::Path, sync::{Arc, LazyLock}, - time::Duration, }; -use tokio::{net::UdpSocket, time::interval}; +use tokio::{ + net::UdpSocket, + select, + sync::{broadcast, mpsc}, +}; use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId}; -use h3_datagram::datagram_traits::HandleDatagramsExt; +use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; use http::{header, uri::Scheme, Response, StatusCode}; use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, TransportConfig}; -use crate::fragment::{self, Fragments}; +use crate::{ + fragment::{self, Fragments}, + stats::Stats, +}; const MAX_HEADER_SIZE: u64 = 8192; const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem"); +const MAX_INFLIGHT_PACKETS: usize = 100; + pub struct Client { - client_socket: UdpSocket, + client_socket: Arc<UdpSocket>, + /// QUIC connection, used to send the actual HTTP datagrams connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, + /// Send stream over a QUIC connection - this needs to be kept alive to not close the HTTP /// QUIC stream. _send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, + /// Request stream for the currently open request, must not be dropped, otherwise proxy /// connection is terminated request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, - /// Packet fragments - fragments: Fragments, + /// Maximum packet size maximum_packet_size: u16, + + stats: Arc<Stats>, } pub type Result<T> = std::result::Result<T, Error>; @@ -150,11 +162,11 @@ impl Client { Ok(Self { connection, - client_socket, + client_socket: Arc::new(client_socket), request_stream, - fragments: Fragments::default(), _send_stream: send_stream, maximum_packet_size, + stats: Arc::default(), }) } @@ -200,76 +212,184 @@ impl Client { } } - pub async fn run(mut self) -> Result<()> { + pub async fn run(self) -> Result<()> { let stream_id: StreamId = self.request_stream.id(); - // this is the variable ID used to signify UDP payloads in HTTP datagrams. - let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); - let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); - let mut fragment_id = 1u16; - let mut interval = interval(Duration::from_secs(3)); + let (client_tx, client_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); + let (server_tx, server_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); + let (return_addr_tx, return_addr_rx) = broadcast::channel(1); - loop { - tokio::select! { - client_read = self.client_socket.recv_buf_from(&mut client_read_buf) => { - let (_bytes_received, recv_addr) = client_read.map_err(Error::ClientRead)?; - return_addr = recv_addr; + let mut client_socket_rx_task = tokio::task::spawn(client_socket_rx_task( + self.client_socket.clone(), + client_tx, + return_addr_tx, + )); - let mut send_buf = client_read_buf.split().freeze(); - if send_buf.len() < (Into::<usize>::into(self.maximum_packet_size) - 100usize) { - self.connection - .send_datagram(stream_id, send_buf) - .map_err(Error::SendDatagram)?; - } else { - // drop the added context ID, since packet will have to be fragmented. - { - let _ = VarInt::decode(&mut send_buf); - } - for fragment in fragment::fragment_packet( - self.maximum_packet_size, - &mut send_buf, - fragment_id) - ? { - self.connection.send_datagram(stream_id, fragment).map_err(Error::SendDatagram)?; - } - fragment_id = fragment_id.wrapping_add(1); - } + let mut client_socket_tx_task = tokio::task::spawn(client_socket_tx_task( + stream_id, + server_rx, + return_addr_rx, + self.client_socket.clone(), + Arc::clone(&self.stats), + )); - client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); - }, - server_response = self.connection.read_datagram() => { - match server_response { - Ok(Some(response)) => { - if response.stream_id() != stream_id { - // log::trace!("Received datagram with an unexpected stream ID"); - continue; - } - let payload = response.into_payload(); - if let Ok(Some(payload)) = self.fragments.handle_incoming_packet(payload) { - self.client_socket - .send_to(payload.chunk(), return_addr) - .await - .map_err(Error::ClientWrite)?; - } - } - Ok(None) => { - return Ok(()); - } - Err(err) => { - return Err(Error::ProxyResponse(err)); + let mut server_socket_task = tokio::task::spawn(server_socket_task( + stream_id, + self.maximum_packet_size, + self.connection, + server_tx, + client_rx, + Arc::clone(&self.stats), + )); + + let result = select! { + result = &mut client_socket_tx_task => result, + result = &mut client_socket_rx_task => result, + result = &mut server_socket_task => result, + }; + + client_socket_tx_task.abort(); + client_socket_rx_task.abort(); + server_socket_task.abort(); + + result.expect("proxy routine panicked") + } +} + +async fn server_socket_task( + stream_id: StreamId, + maximum_packet_size: u16, + mut connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, + server_tx: mpsc::Sender<Datagram>, + mut client_rx: mpsc::Receiver<Bytes>, + stats: Arc<Stats>, +) -> Result<()> { + let mut fragment_id = 1u16; + + loop { + let packet = select! { + datagram = connection.read_datagram() => { + match datagram { + Ok(Some(response)) => { + if server_tx.send(response).await.is_err() { + break; } } - }, - _ = interval.tick() => { - self.fragments.clear_old_fragments( - Duration::from_secs(3) - ); - }, - }; + Ok(None) => break, + Err(err) => return Err(Error::ProxyResponse(err)), + } + + continue; + } + packet = client_rx.recv() => packet, + }; + + let Some(mut packet) = packet else { break }; + + if packet.len() < (Into::<usize>::into(maximum_packet_size) - 100usize) { + stats.tx(packet.len(), false); + connection + .send_datagram(stream_id, packet) + .map_err(Error::SendDatagram)?; + } else { + // drop the added context ID, since packet will have to be fragmented. + let _ = VarInt::decode(&mut packet); + + for fragment in fragment::fragment_packet(maximum_packet_size, &mut packet, fragment_id) + .map_err(Error::PacketTooLarge)? + { + stats.tx(fragment.len(), true); + connection + .send_datagram(stream_id, fragment) + .map_err(Error::SendDatagram)?; + } + fragment_id = fragment_id.wrapping_add(1); } } + + Result::Ok(()) +} + +async fn client_socket_rx_task( + client_socket: Arc<UdpSocket>, + client_tx: mpsc::Sender<Bytes>, + return_addr_tx: broadcast::Sender<SocketAddr>, +) -> Result<()> { + let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); + let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); + + loop { + client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); + + // this is the variable ID used to signify UDP payloads in HTTP datagrams. + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); + + let (_bytes_received, recv_addr) = client_socket + .recv_buf_from(&mut client_read_buf) + .await + .map_err(Error::ClientRead)?; + + if recv_addr != return_addr { + return_addr = recv_addr; + if return_addr_tx.send(return_addr).is_err() { + break; + } + } + let packet = client_read_buf.split().freeze(); + + if client_tx.send(packet).await.is_err() { + break; + }; + } + + Ok(()) +} + +async fn client_socket_tx_task( + stream_id: StreamId, + mut server_rx: mpsc::Receiver<Datagram>, + mut return_addr_rx: broadcast::Receiver<SocketAddr>, + client_socket: Arc<UdpSocket>, + stats: Arc<Stats>, +) -> Result<()> { + let mut fragments = Fragments::default(); + + let mut return_addr = loop { + match return_addr_rx.recv().await { + Ok(addr) => break addr, + Err(broadcast::error::RecvError::Lagged(..)) => continue, + Err(broadcast::error::RecvError::Closed) => return Ok(()), + } + }; + + loop { + let Some(response) = server_rx.recv().await else { + break; + }; + + match return_addr_rx.try_recv() { + Ok(new_addr) => return_addr = new_addr, + Err(broadcast::error::TryRecvError::Empty) => {} + Err(..) => break, + } + + if response.stream_id() != stream_id { + // log::trace!("Received datagram with an unexpected stream ID"); + continue; + } + let payload = response.into_payload(); + + if let Ok(Some(payload)) = fragments.handle_incoming_packet(payload) { + stats.rx(payload.len(), false /* TODO */); + + client_socket + .send_to(payload.chunk(), return_addr) + .await + .map_err(Error::ClientWrite)?; + } + } + + Result::Ok(()) } fn new_connect_request( diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs index bb973d3a80..d4a4e47812 100644 --- a/mullvad-masque-proxy/src/lib.rs +++ b/mullvad-masque-proxy/src/lib.rs @@ -3,6 +3,7 @@ use h3::proto::varint::VarInt; pub mod client; mod fragment; pub mod server; +mod stats; const PACKET_BUFFER_SIZE: usize = 1700; pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0); diff --git a/mullvad-masque-proxy/src/stats.rs b/mullvad-masque-proxy/src/stats.rs new file mode 100644 index 0000000000..412ddcc9bd --- /dev/null +++ b/mullvad-masque-proxy/src/stats.rs @@ -0,0 +1,46 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +#[derive(Debug, Default)] +pub struct Stats { + rx_packets: AtomicUsize, + tx_packets: AtomicUsize, + + rx_bytes: AtomicUsize, + tx_bytes: AtomicUsize, + + fragmented_tx_bytes: AtomicUsize, + fragmented_rx_bytes: AtomicUsize, + + fragmented_tx_packets: AtomicUsize, + fragmented_rx_packets: AtomicUsize, +} + +const ORD: Ordering = Ordering::Relaxed; + +impl Drop for Stats { + fn drop(&mut self) { + println!("stats: {:?}", self); + } +} + +impl Stats { + pub fn tx(&self, packet_len: usize, is_fragment: bool) { + self.tx_packets.fetch_add(1, ORD); + self.tx_bytes.fetch_add(packet_len, ORD); + + if is_fragment { + self.fragmented_tx_packets.fetch_add(1, ORD); + self.fragmented_tx_bytes.fetch_add(packet_len, ORD); + } + } + + pub fn rx(&self, packet_len: usize, is_fragment: bool) { + self.rx_packets.fetch_add(1, ORD); + self.rx_bytes.fetch_add(packet_len, ORD); + + if is_fragment { + self.fragmented_rx_packets.fetch_add(1, ORD); + self.fragmented_rx_bytes.fetch_add(packet_len, ORD); + } + } +} |
