diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-07 15:38:58 +0200 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-09 18:02:55 +0200 |
| commit | 959d55b72868b492a9992f237360581e93bbf874 (patch) | |
| tree | 592776e114277f7e13e38e1e712b6519f641e7c1 | |
| parent | f0547062d88f468d304bd5bec5937f1e0786cee3 (diff) | |
| download | mullvadvpn-959d55b72868b492a9992f237360581e93bbf874.tar.xz mullvadvpn-959d55b72868b492a9992f237360581e93bbf874.zip | |
Split Client::run into multiple tasks
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 263 |
1 files changed, 180 insertions, 83 deletions
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 424a26b8d9..1f7cfa18c7 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -1,16 +1,19 @@ -use bytes::{Buf, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use rustls::client::danger::ServerCertVerified; use std::{ fs, future, io, net::{Ipv4Addr, SocketAddr}, path::Path, sync::{Arc, LazyLock}, - time::{Duration, Instant}, }; -use tokio::{net::UdpSocket, time::interval}; +use tokio::{ + net::UdpSocket, + select, + sync::{broadcast, mpsc}, +}; use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId}; -use h3_datagram::datagram_traits::HandleDatagramsExt; +use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; use http::{header, uri::Scheme, Response, StatusCode}; use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, TransportConfig}; @@ -23,18 +26,22 @@ const MAX_HEADER_SIZE: u64 = 8192; const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem"); +const MAX_INFLIGHT_PACKETS: usize = 100; + pub struct Client { - client_socket: UdpSocket, + client_socket: Arc<UdpSocket>, + /// QUIC connection, used to send the actual HTTP datagrams connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, + /// Send stream over a QUIC connection - this needs to be kept alive to not close the HTTP /// QUIC stream. _send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, + /// Request stream for the currently open request, must not be dropped, otherwise proxy /// connection is terminated request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, - /// Packet fragments - fragments: Fragments, + /// Maximum packet size maximum_packet_size: u16, @@ -155,9 +162,8 @@ impl Client { Ok(Self { connection, - client_socket, + client_socket: Arc::new(client_socket), request_stream, - fragments: Fragments::default(), _send_stream: send_stream, maximum_packet_size, stats: Arc::default(), @@ -206,94 +212,185 @@ impl Client { } } - pub async fn run(mut self) -> Result<()> { + pub async fn run(self) -> Result<()> { let stream_id: StreamId = self.request_stream.id(); - // this is the variable ID used to signify UDP payloads in HTTP datagrams. - let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); - let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); - let mut fragment_id = 1u16; - let mut interval = interval(Duration::from_secs(3)); + let (client_tx, client_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); + let (server_tx, server_rx) = mpsc::channel(MAX_INFLIGHT_PACKETS); + let (return_addr_tx, return_addr_rx) = broadcast::channel(1); - let mut prev_stats = Instant::now(); + let mut client_socket_rx_task = tokio::task::spawn(client_socket_rx_task( + self.client_socket.clone(), + client_tx, + return_addr_tx, + )); - loop { - tokio::select! { - client_read = self.client_socket.recv_buf_from(&mut client_read_buf) => { - let (bytes_received, recv_addr) = client_read.map_err(Error::ClientRead)?; - return_addr = recv_addr; + let mut client_socket_tx_task = tokio::task::spawn(client_socket_tx_task( + stream_id, + server_rx, + return_addr_rx, + self.client_socket.clone(), + self.stats.clone(), + )); - /*if prev_stats.elapsed() >= Duration::from_secs(3) { - prev_stats = Instant::now(); - println!("stats: {:?}", self.stats); - }*/ + let mut server_socket_task = tokio::task::spawn(server_socket_task( + stream_id, + self.maximum_packet_size, + self.connection, + server_tx, + client_rx, + Arc::clone(&self.stats), + )); - let mut send_buf = client_read_buf.split().freeze(); - if send_buf.len() < (Into::<usize>::into(self.maximum_packet_size) - 100usize) { - self.stats.tx(bytes_received, false); - self.connection - .send_datagram(stream_id, send_buf) - .map_err(Error::SendDatagram)?; - } else { - // drop the added context ID, since packet will have to be fragmented. - { - let _ = VarInt::decode(&mut send_buf); - } - for fragment in fragment::fragment_packet( - self.maximum_packet_size, - &mut send_buf, - fragment_id) - ? { - self.stats.tx(fragment.len(), true); - self.connection.send_datagram(stream_id, fragment).map_err(Error::SendDatagram)?; - } - fragment_id = fragment_id.wrapping_add(1); - } + let result = select! { + result = &mut client_socket_tx_task => result, + result = &mut client_socket_rx_task => result, + result = &mut server_socket_task => result, + }; - client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); - crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); - }, - server_response = self.connection.read_datagram() => { - match server_response { - Ok(Some(response)) => { - if response.stream_id() != stream_id { - // log::trace!("Received datagram with an unexpected stream ID"); - continue; - } - let payload = response.into_payload(); + client_socket_tx_task.abort(); + client_socket_rx_task.abort(); + server_socket_task.abort(); - /*if prev_stats.elapsed() >= Duration::from_secs(3) { - prev_stats = Instant::now(); - println!("stats: {:?}", self.stats); - }*/ + result.expect("proxy routine panicked") + } +} - let fragment_len = payload.len(); - if let Ok(Some(payload)) = self.fragments.handle_incoming_packet(payload) { - self.stats.rx(payload.len(), fragment_len != payload.len()); +async fn server_socket_task( + stream_id: StreamId, + maximum_packet_size: u16, + mut connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, + server_tx: mpsc::Sender<Datagram>, + mut client_rx: mpsc::Receiver<Bytes>, + stats: Arc<Stats>, +) -> Result<()> { + let mut fragment_id = 1u16; - self.client_socket - .send_to(payload.chunk(), return_addr) - .await - .map_err(Error::ClientWrite)?; - } - } - Ok(None) => { - return Ok(()); - } - Err(err) => { - return Err(Error::ProxyResponse(err)); + loop { + let packet = select! { + datagram = connection.read_datagram() => { + match datagram { + Ok(Some(response)) => { + if server_tx.send(response).await.is_err() { + break; } } - }, - _ = interval.tick() => { - self.fragments.clear_old_fragments( - Duration::from_secs(3) - ); - }, - }; + Ok(None) => break, + Err(err) => return Err(Error::ProxyResponse(err)), + } + + continue; + } + packet = client_rx.recv() => packet, + }; + + let Some(mut packet) = packet else { break }; + + if packet.len() < (Into::<usize>::into(maximum_packet_size) - 100usize) { + stats.tx(packet.len(), false); + connection + .send_datagram(stream_id, packet) + .map_err(Error::SendDatagram)?; + } else { + // drop the added context ID, since packet will have to be fragmented. + let _ = VarInt::decode(&mut packet); + + for fragment in fragment::fragment_packet(maximum_packet_size, &mut packet, fragment_id) + .map_err(Error::PacketTooLarge)? + { + stats.tx(fragment.len(), true); + connection + .send_datagram(stream_id, fragment) + .map_err(Error::SendDatagram)?; + } + fragment_id = fragment_id.wrapping_add(1); + } + } + + Result::Ok(()) +} + +async fn client_socket_rx_task( + client_socket: Arc<UdpSocket>, + client_tx: mpsc::Sender<Bytes>, + return_addr_tx: broadcast::Sender<SocketAddr>, +) -> Result<()> { + let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); + let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); + + loop { + client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); + + // this is the variable ID used to signify UDP payloads in HTTP datagrams. + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); + + let (_bytes_received, recv_addr) = client_socket + .recv_buf_from(&mut client_read_buf) + .await + .map_err(Error::ClientRead)?; + + if recv_addr != return_addr { + return_addr = recv_addr; + if return_addr_tx.send(return_addr).is_err() { + break; + } + } + let packet = client_read_buf.split().freeze(); + + if client_tx.send(packet).await.is_err() { + break; + }; + } + + Ok(()) +} + +async fn client_socket_tx_task( + stream_id: StreamId, + mut server_rx: mpsc::Receiver<Datagram>, + mut return_addr_rx: broadcast::Receiver<SocketAddr>, + client_socket: Arc<UdpSocket>, + stats: Arc<Stats>, +) -> Result<()> { + let mut fragments = Fragments::default(); + + let mut return_addr = loop { + match return_addr_rx.recv().await { + Ok(addr) => break addr, + Err(broadcast::error::RecvError::Lagged(..)) => continue, + Err(broadcast::error::RecvError::Closed) => return Ok(()), + } + }; + + loop { + let Some(response) = server_rx.recv().await else { + break; + }; + + match return_addr_rx.try_recv() { + Ok(new_addr) => return_addr = new_addr, + Err(broadcast::error::TryRecvError::Empty) => {} + Err(..) => break, + } + + if response.stream_id() != stream_id { + // log::trace!("Received datagram with an unexpected stream ID"); + continue; + } + let payload = response.into_payload(); + + let fragment_len = payload.len(); + if let Ok(Some(payload)) = fragments.handle_incoming_packet(payload) { + stats.rx(payload.len(), fragment_len != payload.len()); + + client_socket + .send_to(payload.chunk(), return_addr) + .await + .map_err(Error::ClientWrite)?; } } + + Result::Ok(()) } fn new_connect_request( |
