diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-04-10 16:53:29 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-04-10 16:53:29 +0200 |
| commit | e4591042b587d73794515da0be9627aedb27ac94 (patch) | |
| tree | 50e21c575da46099c477092baa706dec2e57e47e | |
| parent | 510c224d5fe3c769b834b64ec81b85fc21bfb792 (diff) | |
| parent | afdb6280d30c4bb05c51695075f00530684a37dc (diff) | |
| download | mullvadvpn-e4591042b587d73794515da0be9627aedb27ac94.tar.xz mullvadvpn-e4591042b587d73794515da0be9627aedb27ac94.zip | |
Merge branch 'fix-masque-sizes'
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-client.rs | 6 | ||||
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-server.rs | 4 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 96 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/fragment.rs | 68 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/lib.rs | 31 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 93 | ||||
| -rw-r--r-- | mullvad-masque-proxy/tests/proxy.rs | 162 |
7 files changed, 350 insertions, 110 deletions
diff --git a/mullvad-masque-proxy/examples/masque-client.rs b/mullvad-masque-proxy/examples/masque-client.rs index 780e95d044..3a6b591640 100644 --- a/mullvad-masque-proxy/examples/masque-client.rs +++ b/mullvad-masque-proxy/examples/masque-client.rs @@ -31,7 +31,7 @@ pub struct ClientArgs { /// Maximum packet size #[arg(long, short = 'S', default_value = "1280")] - maximum_packet_size: u16, + mtu: u16, } #[tokio::main] @@ -42,7 +42,7 @@ async fn main() { root_cert_path, server_hostname, bind_port, - maximum_packet_size, + mtu, } = ClientArgs::parse(); let tls_config = match root_cert_path { @@ -67,7 +67,7 @@ async fn main() { target_addr, &server_hostname, tls_config, - maximum_packet_size, + mtu, ) .await; if let Err(err) = &client { diff --git a/mullvad-masque-proxy/examples/masque-server.rs b/mullvad-masque-proxy/examples/masque-server.rs index 9c07423d9c..fe216070ef 100644 --- a/mullvad-masque-proxy/examples/masque-server.rs +++ b/mullvad-masque-proxy/examples/masque-server.rs @@ -28,7 +28,7 @@ pub struct ServerArgs { /// Maximum packet size #[arg(long, short = 'm', default_value = "1700")] - maximum_packet_size: u16, + mtu: u16, } #[tokio::main] @@ -42,7 +42,7 @@ async fn main() { args.bind_addr, args.allowed_ips.iter().cloned().collect(), tls_config.into(), - args.maximum_packet_size, + args.mtu, ) .expect("Failed to initialize server"); println!("Listening on {}", args.bind_addr); diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 4d59f0dd03..df65339c03 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -21,8 +21,10 @@ use quinn::{ }; use crate::{ + compute_udp_payload_size, fragment::{self, Fragments}, stats::Stats, + MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, }; const MAX_HEADER_SIZE: u64 = 8192; @@ -34,6 +36,9 @@ const MAX_INFLIGHT_PACKETS: usize = 100; pub struct Client { client_socket: Arc<UdpSocket>, + /// QUIC endpoint + quinn_conn: quinn::Connection, + /// QUIC connection, used to send the actual HTTP datagrams connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, @@ -45,8 +50,8 @@ pub struct Client { /// connection is terminated request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, - /// Maximum packet size - maximum_packet_size: u16, + /// Maximum UDP payload size (packet size including QUIC overhead) + max_udp_payload_size: u16, stats: Arc<Stats>, } @@ -61,6 +66,8 @@ pub enum Error { Connect(#[from] quinn::ConnectError), #[error("Failed to connect to QUIC endpoint")] Connection(#[from] quinn::ConnectionError), + #[error("Invalid MTU: must be at least {min_mtu}")] + InvalidMtu { min_mtu: u16 }, #[error("Invalid max_udp_payload_size")] InvalidMaxUdpPayload(#[source] quinn::ConfigError), #[error("Connection closed while sending request to initiate proxying")] @@ -100,7 +107,7 @@ impl Client { local_addr: SocketAddr, target_addr: SocketAddr, server_host: &str, - maximum_packet_size: u16, + mtu: u16, ) -> Result<Self> { Self::connect_with_tls_config( client_socket, @@ -109,7 +116,7 @@ impl Client { target_addr, server_host, default_tls_config(), - maximum_packet_size, + mtu, ) .await } @@ -121,7 +128,7 @@ impl Client { target_addr: SocketAddr, server_host: &str, tls_config: Arc<rustls::ClientConfig>, - maximum_packet_size: u16, + mtu: u16, ) -> Result<Self> { let quic_client_config = QuicClientConfig::try_from(tls_config) .expect("Failed to construct a valid TLS configuration"); @@ -140,7 +147,7 @@ impl Client { target_addr, server_host, client_config, - maximum_packet_size, + mtu, ) .await } @@ -152,34 +159,56 @@ impl Client { target_addr: SocketAddr, server_host: &str, client_config: ClientConfig, - maximum_packet_size: u16, + mtu: u16, ) -> Result<Self> { - let endpoint = Self::setup_quic_endpoint(local_addr, maximum_packet_size)?; + Self::validate_mtu(mtu, target_addr)?; + + let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr); + + let endpoint = Self::setup_quic_endpoint(local_addr, max_udp_payload_size)?; let connecting = endpoint.connect_with(client_config, server_addr, server_host)?; let connection = connecting.await?; - let (connection, send_stream, request_stream) = - Self::setup_h3_connection(connection, target_addr, server_host, maximum_packet_size) - .await?; + let (h3_connection, send_stream, request_stream) = Self::setup_h3_connection( + connection.clone(), + target_addr, + server_host, + max_udp_payload_size, + ) + .await?; Ok(Self { - connection, + quinn_conn: connection, + connection: h3_connection, client_socket: Arc::new(client_socket), request_stream, _send_stream: send_stream, - maximum_packet_size, + max_udp_payload_size, stats: Arc::default(), }) } - fn setup_quic_endpoint(local_addr: SocketAddr, maximum_packet_size: u16) -> Result<Endpoint> { + const fn validate_mtu(mtu: u16, target_addr: SocketAddr) -> Result<()> { + let min_mtu = if target_addr.is_ipv4() { + MIN_IPV4_MTU + } else { + MIN_IPV6_MTU + }; + if mtu >= min_mtu { + Ok(()) + } else { + Err(Error::InvalidMtu { min_mtu }) + } + } + + fn setup_quic_endpoint(local_addr: SocketAddr, max_udp_payload_size: u16) -> Result<Endpoint> { let local_socket = std::net::UdpSocket::bind(local_addr).map_err(Error::Bind)?; let mut endpoint_config = EndpointConfig::default(); endpoint_config - .max_udp_payload_size(maximum_packet_size) + .max_udp_payload_size(max_udp_payload_size) .map_err(Error::InvalidMaxUdpPayload)?; Endpoint::new(endpoint_config, None, local_socket, Arc::new(TokioRuntime)) @@ -191,7 +220,7 @@ impl Client { connection: quinn::Connection, target: SocketAddr, server_host: &str, - maximum_packet_size: u16, + mtu: u16, ) -> Result<( client::Connection<h3_quinn::Connection, bytes::Bytes>, client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, @@ -205,7 +234,7 @@ impl Client { .await .map_err(Error::CreateClient)?; - let request = new_connect_request(target, &server_host, maximum_packet_size)?; + let request = new_connect_request(target, &server_host, mtu)?; let request_future = async move { let mut request_stream = send_stream.send_request(request).await?; @@ -251,7 +280,8 @@ impl Client { let mut server_socket_task = tokio::task::spawn(server_socket_task( stream_id, - self.maximum_packet_size, + self.max_udp_payload_size, + self.quinn_conn, self.connection, server_tx, client_rx, @@ -274,13 +304,15 @@ impl Client { async fn server_socket_task( stream_id: StreamId, - maximum_packet_size: u16, + max_udp_payload_size: u16, + quinn_conn: quinn::Connection, mut connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, server_tx: mpsc::Sender<Datagram>, mut client_rx: mpsc::Receiver<Bytes>, stats: Arc<Stats>, ) -> Result<()> { let mut fragment_id = 1u16; + let stream_id_size = VarInt::from(stream_id).size() as u16; loop { let packet = select! { @@ -302,7 +334,14 @@ async fn server_socket_task( let Some(mut packet) = packet else { break }; - if packet.len() < (Into::<usize>::into(maximum_packet_size) - 100usize) { + // Maximum QUIC payload (including fragmentation headers) + let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() { + max_datagram_size as u16 - stream_id_size + } else { + max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size + }; + + if packet.len() <= usize::from(maximum_packet_size) { stats.tx(packet.len(), false); connection .send_datagram(stream_id, packet) @@ -314,6 +353,8 @@ async fn server_socket_task( for fragment in fragment::fragment_packet(maximum_packet_size, &mut packet, fragment_id) .map_err(Error::PacketTooLarge)? { + debug_assert!(fragment.len() <= maximum_packet_size as usize); + stats.tx(fragment.len(), true); connection .send_datagram(stream_id, fragment) @@ -331,7 +372,7 @@ async fn client_socket_rx_task( client_tx: mpsc::Sender<Bytes>, return_addr_tx: broadcast::Sender<SocketAddr>, ) -> Result<()> { - let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); + let mut client_read_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); loop { @@ -411,7 +452,7 @@ async fn client_socket_tx_task( fn new_connect_request( socket_addr: SocketAddr, authority: &dyn AsRef<str>, - maximum_packet_size: u16, + mtu: u16, ) -> Result<http::Request<()>> { let host = socket_addr.ip(); let port = socket_addr.port(); @@ -432,7 +473,7 @@ fn new_connect_request( // TODO: Not needed since we set the max_udp_payload_size transport param .header( b"X-Mullvad-Uplink-Mtu".as_slice(), - format!("{maximum_packet_size}"), + format!("{mtu}"), ) .body(()) .expect("failed to construct a body"); @@ -503,9 +544,12 @@ fn read_cert_store_from_reader(reader: &mut dyn io::BufRead) -> Result<rustls::R Ok(cert_store) } -#[test] -fn test_zero_stream_id() { - h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?"); +#[cfg(test)] +mod test { + #[test] + fn test_zero_stream_id() { + h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?"); + } } #[derive(Debug)] diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs index f60c3b927d..c699ed2d78 100644 --- a/mullvad-masque-proxy/src/fragment.rs +++ b/mullvad-masque-proxy/src/fragment.rs @@ -6,6 +6,8 @@ use std::{ use bytes::{Buf, BufMut, Bytes, BytesMut}; use h3::proto::varint::VarInt; +use crate::FRAGMENT_HEADER_SIZE_FRAGMENTED; + #[derive(Default)] pub struct Fragments { fragment_map: BTreeMap<u16, Vec<Fragment>>, @@ -102,19 +104,29 @@ struct Fragment { time_received: Instant, } +/// Fragment packet using the given maximum fragment size (including headers). +/// +/// `payload` must not contain any fragmentation headers. +/// `maximum_packet_size` is the maximum fragment size including headers. pub fn fragment_packet( maximum_packet_size: u16, payload: &'_ mut Bytes, packet_id: u16, ) -> Result<impl Iterator<Item = Bytes> + '_, PacketTooLarge> { - let num_fragments: usize = payload.chunks(maximum_packet_size.into()).count(); + let fragment_payload_size = maximum_packet_size - FRAGMENT_HEADER_SIZE_FRAGMENTED; + + let num_fragments: usize = payload.chunks(fragment_payload_size.into()).count(); let Ok(fragment_count): std::result::Result<u8, _> = num_fragments.try_into() else { return Err(PacketTooLarge(payload.len())); }; - let iterator = payload.chunks(maximum_packet_size.into()).enumerate().map( - move |(fragment_index, fragment_payload)| { - let mut fragment = BytesMut::with_capacity((maximum_packet_size + 1).into()); + let iterator = payload + .chunks(fragment_payload_size.into()) + .enumerate() + .map(move |(fragment_index, fragment_payload)| { + let mut fragment = BytesMut::with_capacity(usize::from( + fragment_payload_size + FRAGMENT_HEADER_SIZE_FRAGMENTED, + )); crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment); fragment.put_u16(packet_id); fragment.put_u8( @@ -123,36 +135,44 @@ pub fn fragment_packet( .expect("fragment index must fit in an u8, since num_fragments fits is an u8"), ); fragment.put_u8(fragment_count); + + debug_assert!(fragment.len() == usize::from(FRAGMENT_HEADER_SIZE_FRAGMENTED)); + fragment.extend_from_slice(fragment_payload); fragment.freeze() - }, - ); + }); Ok(iterator) } -#[test] -fn test_fragment_reconstruction() { - use rand::{seq::SliceRandom, thread_rng}; +#[cfg(test)] +mod test { + use super::*; - let payload = (0..255).collect::<Vec<u8>>(); - let max_payload_size = 50; - let packet_id = 76; + #[test] + fn test_fragment_reconstruction() { + use rand::{seq::SliceRandom, thread_rng}; - let mut fragments = Fragments::default(); + let payload = (0..255).collect::<Vec<u8>>(); + let max_payload_size = 50; + let packet_id = 76; - let mut payload_clone = Bytes::from(payload.clone()); - let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) - .unwrap() - .collect::<Vec<_>>(); + let mut fragments = Fragments::default(); - fragment_buf.shuffle(&mut thread_rng()); + let mut payload_clone = Bytes::from(payload.clone()); + let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) + .unwrap() + .collect::<Vec<_>>(); - for fragment in fragment_buf { - if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() { - assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); - return; + fragment_buf.shuffle(&mut thread_rng()); + + for fragment in fragment_buf { + if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() + { + assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); + return; + } } - } - panic!("Failed to reconstruct packet"); + panic!("Failed to reconstruct packet"); + } } diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs index d4a4e47812..5e9c2902b0 100644 --- a/mullvad-masque-proxy/src/lib.rs +++ b/mullvad-masque-proxy/src/lib.rs @@ -1,10 +1,39 @@ use h3::proto::varint::VarInt; +use std::net::SocketAddr; pub mod client; mod fragment; pub mod server; mod stats; -const PACKET_BUFFER_SIZE: usize = 1700; +const PACKET_BUFFER_SIZE: usize = 64 * 1024; pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0); pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1); + +/// Fragment headers size for fragmented packets +const FRAGMENT_HEADER_SIZE_FRAGMENTED: u16 = 5; + +/// UDP header overhead +const UDP_HEADER_SIZE: u16 = 8; + +/// QUIC header size. This is conservative, real overhead varies +const QUIC_HEADER_SIZE: u16 = 41; + +/// This is the size of the payload that stores QUIC packets +/// MTU - IP header - UDP header +const fn compute_udp_payload_size(mtu: u16, target_addr: SocketAddr) -> u16 { + let ip_overhead = if target_addr.is_ipv4() { 20 } else { 40 }; + mtu - ip_overhead - UDP_HEADER_SIZE +} + +/// Minimum allowed MTU (IPv6) is the overhead of all headers, plus 1 byte for actual data. +/// QUIC defines that clients must support UDP payloads of at least 1200 bytes. +/// <https://datatracker.ietf.org/doc/html/rfc9000#section-8.1> +// 20 = IPv4 header (without optional fields) +pub const MIN_IPV4_MTU: u16 = 20 + UDP_HEADER_SIZE + 1200; + +/// Minimum allowed MTU (IPv6) is the overhead of all headers, plus 1 byte for actual data. +/// QUIC defines that clients must support UDP payloads of at least 1200 bytes. +/// <https://datatracker.ietf.org/doc/html/rfc9000#section-8.1> +// 40 = IPv6 header (without optional fields) +pub const MIN_IPV6_MTU: u16 = 40 + UDP_HEADER_SIZE + 1200; diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs index 710edaa537..a57282d172 100644 --- a/mullvad-masque-proxy/src/server/mod.rs +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -17,7 +17,11 @@ use http::{Request, StatusCode}; use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming}; use tokio::{net::UdpSocket, time::interval}; -use crate::fragment::{self, Fragments}; +use crate::{ + compute_udp_payload_size, + fragment::{self, Fragments}, + MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE, +}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -27,6 +31,8 @@ pub enum Error { BindSocket(#[source] io::Error), #[error("Failed to send negotiation response")] SendNegotiationResponse(#[source] h3::Error), + #[error("Invalid MTU: must be at least {min_mtu}")] + InvalidMtu { min_mtu: u16 }, } pub type Result<T> = std::result::Result<T, Error>; @@ -36,7 +42,7 @@ const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/"; pub struct Server { endpoint: Endpoint, allowed_hosts: AllowedIps, - maximum_packet_size: u16, + mtu: u16, } #[derive(Clone)] @@ -55,8 +61,10 @@ impl Server { bind_addr: SocketAddr, allowed_hosts: HashSet<IpAddr>, tls_config: Arc<rustls::ServerConfig>, - maximum_packet_size: u16, + mtu: u16, ) -> Result<Self> { + Self::validate_mtu(mtu, bind_addr)?; + let server_config = quinn::ServerConfig::with_crypto(Arc::new( QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?, )); @@ -68,10 +76,23 @@ impl Server { allowed_hosts: AllowedIps { hosts: Arc::new(allowed_hosts), }, - maximum_packet_size, + mtu, }) } + const fn validate_mtu(mtu: u16, bind_addr: SocketAddr) -> Result<()> { + let min_mtu = if bind_addr.is_ipv4() { + MIN_IPV4_MTU + } else { + MIN_IPV6_MTU + }; + if mtu >= min_mtu { + Ok(()) + } else { + Err(Error::InvalidMtu { min_mtu }) + } + } + pub fn local_addr(&self) -> io::Result<SocketAddr> { self.endpoint.local_addr() } @@ -81,21 +102,19 @@ impl Server { tokio::spawn(Self::handle_incoming_connection( new_connection, self.allowed_hosts.clone(), - self.maximum_packet_size, + self.mtu, )); } Ok(()) } - async fn handle_incoming_connection( - connection: Incoming, - allowed_hosts: AllowedIps, - maximum_packet_size: u16, - ) { + async fn handle_incoming_connection(connection: Incoming, allowed_hosts: AllowedIps, mtu: u16) { match connection.await { Ok(conn) => { println!("new connection established"); + let quinn_conn = conn.clone(); + let Ok(mut connection) = server::builder() .enable_datagram(true) .build(h3_quinn::Connection::new(conn)) @@ -109,10 +128,11 @@ impl Server { Ok(Some((req, stream))) => { tokio::spawn(Self::handle_proxy_request( connection, + quinn_conn, req, stream, allowed_hosts.clone(), - maximum_packet_size, + mtu, )); } @@ -132,10 +152,11 @@ impl Server { async fn handle_proxy_request<T: BidiStream<Bytes>>( mut connection: Connection<h3_quinn::Connection, Bytes>, + quinn_conn: quinn::Connection, request: Request<()>, mut stream: RequestStream<T, Bytes>, allowed_hosts: AllowedIps, - maximum_packet_size: u16, + mtu: u16, ) { let Some(target_addr) = get_target_socketaddr(request.uri().path()) else { return; @@ -157,8 +178,11 @@ impl Server { return; } + let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr); + let stream_id = stream.id(); - let mut proxy_recv_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE); + let stream_id_size = VarInt::from(stream_id).size() as u16; + let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE); let mut fragments = Fragments::default(); let mut fragment_id = 0u16; @@ -191,12 +215,20 @@ impl Server { let mut received_packet = proxy_recv_buf.split().freeze(); - if received_packet.len() < maximum_packet_size.into() { + // Maximum QUIC payload (including fragmentation headers) + let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() { + max_datagram_size as u16 - stream_id_size + } else { + max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size + }; + + if received_packet.len() <= usize::from(maximum_packet_size) { if connection.send_datagram(stream_id, received_packet).is_err() { return; } } else { let _ = VarInt::decode(&mut received_packet); + let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; }; fragment_id += 1; for payload in fragments { @@ -296,21 +328,26 @@ fn unspecified_addr(addr: IpAddr) -> IpAddr { } } -#[test] -fn test_get_good_slashy_ocketaddr() { - let addr: IpAddr = "192.168.1.1".parse().unwrap(); - let port: u16 = 7979; - let expected_addr = SocketAddr::new(addr, port); - let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////"); +#[cfg(test)] +mod test { + use super::*; - assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr) -} + #[test] + fn test_get_good_slashy_ocketaddr() { + let addr: IpAddr = "192.168.1.1".parse().unwrap(); + let port: u16 = 7979; + let expected_addr = SocketAddr::new(addr, port); + let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////"); + + assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr) + } -#[test] -fn test_get_bad_socketaddr() { - let addr: IpAddr = "192.168.1.1".parse().unwrap(); - let port: u16 = 7979; - let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}"); + #[test] + fn test_get_bad_socketaddr() { + let addr: IpAddr = "192.168.1.1".parse().unwrap(); + let port: u16 = 7979; + let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}"); - assert_eq!(get_target_socketaddr(&good_path), None) + assert_eq!(get_target_socketaddr(&good_path), None) + } } diff --git a/mullvad-masque-proxy/tests/proxy.rs b/mullvad-masque-proxy/tests/proxy.rs index 6825f7d75d..61a15e5136 100644 --- a/mullvad-masque-proxy/tests/proxy.rs +++ b/mullvad-masque-proxy/tests/proxy.rs @@ -1,25 +1,145 @@ +use std::iter; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; +use anyhow::anyhow; use anyhow::Context; use bytes::BytesMut; +use mullvad_masque_proxy::MIN_IPV4_MTU; +use rand::RngCore; use tokio::fs; use mullvad_masque_proxy::client; use mullvad_masque_proxy::server; use tokio::net::UdpSocket; +use tokio::time::timeout; /// Set up a MASQUE proxy and test that it can be used to communicate with some UDP destination #[tokio::test] async fn test_server_and_client_forwarding() -> anyhow::Result<()> { - const MAXIMUM_PACKET_SIZE: u16 = 1700; + timeout(Duration::from_secs(1), async { + const MTU: u16 = 1700; + let (client, server) = setup_masque(MTU).await?; + + // Proxy client -> destination + let mut rx_buf = BytesMut::with_capacity(128); + client.send(b"abc").await?; + let (_, proxy_addr) = server + .recv_buf_from(&mut rx_buf) + .await + .context("Expected to receive message")?; + assert_eq!(&*rx_buf, b"abc", "Expected to receive message from client"); + + // Destination -> proxy client + let mut rx_buf = BytesMut::with_capacity(128); + server.send_to(b"def", proxy_addr).await?; + client + .recv_buf(&mut rx_buf) + .await + .context("Expected to receive message")?; + assert_eq!(&*rx_buf, b"def", "Expected to receive message from server"); + + Ok(()) + }) + .await? +} + +/// End to end test with fragmentation. +/// Note: This doesn't actually check whether fragmentation occurs, only that packets actually +/// reach their destinations when fragmentation *should* be present. +#[tokio::test] +async fn test_server_and_client_fragmentation() -> anyhow::Result<()> { + #[allow(unused_mut)] + let mut valid_send_packet_sizes = vec![0u16, 1, 10, 100, 1280, 5000]; + + // Maximum packet size sans UDP and QUIC headers, sans 1 byte context ID. + // + // NOTE: On macOS, the maximum UDP packet size is equal to the value set by + // `sysctl net.inet.udp.maxdgram` + #[cfg(not(target_os = "macos"))] + valid_send_packet_sizes.push(u16::MAX - 8 - 41 - 1); + + let valid_mtus = [MIN_IPV4_MTU, 1280, 1500, 1700, 5000, 20000, u16::MAX]; + + let params = valid_mtus + .into_iter() + .flat_map(|mtu| iter::repeat(mtu).zip(&valid_send_packet_sizes)); + + async fn run_test(mtu: u16, send_packet_size: usize) -> anyhow::Result<()> { + let (client, server) = setup_masque(mtu).await?; + + // Proxy client -> destination + // Send a random packet, large enough to be fragmented + let mut fragment_me = vec![0u8; send_packet_size]; + rand::thread_rng().fill_bytes(&mut fragment_me); + + client.send(&fragment_me).await?; + + let mut rx_buf = BytesMut::with_capacity(send_packet_size + 100); + let (_, proxy_addr) = server + .recv_buf_from(&mut rx_buf) + .await + .context("Expected to receive message")?; + let read = rx_buf.split(); + assert_eq!( + &*read, &fragment_me, + "Expected to receive reassembled message from client" + ); + + // Destination -> proxy client + // Send a random packet, large enough to be fragmented + let mut fragment_me = vec![0u8; send_packet_size]; + rand::thread_rng().fill_bytes(&mut fragment_me); + + server.send_to(&fragment_me, proxy_addr).await?; + + let mut rx_buf = BytesMut::with_capacity(send_packet_size + 100); + let blen = client + .recv_buf(&mut rx_buf) + .await + .context("Expected to receive message")?; + + let read = rx_buf.split(); + eprintln!( + "from server: {}, {}, {}", + fragment_me.len(), + read.len(), + blen + ); + assert_eq!( + &*read, &fragment_me, + "Expected to receive reassembled message from server" + ); + + Ok(()) + } + + for (mtu, &send_packet_size) in params { + timeout( + Duration::from_secs(1), + run_test(mtu, send_packet_size.into()), + ) + .await? + .context(anyhow!("mtu={mtu}, send_packet_size={send_packet_size}"))?; + } + + Ok(()) +} + +/// Set up a client and server connected by a MASQUE proxy. +/// This returns a UDP socket that is connected to the local MASQUE client, +/// and a UDP socket that represents the other endpoint. +/// Note that the server socket (second returned value) is not connected, +/// so `recv_from` must be used. +async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> { const HOST: &str = "test.test"; let any_localhost_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); // Set up destination UDP server - let target_udp_server = UdpSocket::bind(any_localhost_addr).await?; - let target_udp_addr = target_udp_server + let destination_udp_server = UdpSocket::bind(any_localhost_addr).await?; + let target_udp_addr = destination_udp_server .local_addr() .context("Retrieve dest UDP server addr")?; @@ -29,13 +149,17 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> { any_localhost_addr, Default::default(), Arc::new(server_tls_config), - MAXIMUM_PACKET_SIZE, + mtu, ) .context("Failed to start MASQUE server")?; let masque_server_addr = server.local_addr()?; - tokio::spawn(server.run()); + tokio::spawn(async move { + if let Err(err) = server.run().await { + eprintln!("server.run() failed: {err}"); + } + }); // Set up MASQUE client let local_socket = UdpSocket::bind(any_localhost_addr) @@ -51,12 +175,16 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> { target_udp_addr, HOST, client::default_tls_config(), - MAXIMUM_PACKET_SIZE, + mtu, ) .await .context("Failed to start MASQUE client")?; - tokio::spawn(client.run()); + tokio::spawn(async move { + if let Err(err) = client.run().await { + eprintln!("client.run() failed: {err}"); + } + }); // Connect to local UDP socket let proxy_client = UdpSocket::bind(any_localhost_addr).await?; @@ -65,25 +193,7 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> { .await .context("Failed to connect to local UDP server")?; - // Proxy client -> destination - let mut rx_buf = BytesMut::with_capacity(128); - proxy_client.send(b"abc").await?; - let (_, proxy_addr) = target_udp_server - .recv_buf_from(&mut rx_buf) - .await - .context("Expected to receive message")?; - assert_eq!(&*rx_buf, b"abc", "Expected to receive message from client"); - - // Destination -> proxy client - let mut rx_buf = BytesMut::with_capacity(128); - target_udp_server.send_to(b"def", proxy_addr).await?; - proxy_client - .recv_buf(&mut rx_buf) - .await - .context("Expected to receive message")?; - assert_eq!(&*rx_buf, b"def", "Expected to receive message from server"); - - Ok(()) + Ok((proxy_client, destination_udp_server)) } async fn load_server_test_cert() -> anyhow::Result<rustls::ServerConfig> { |
