summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2025-04-10 16:53:29 +0200
committerDavid Lönnhager <david.l@mullvad.net>2025-04-10 16:53:29 +0200
commite4591042b587d73794515da0be9627aedb27ac94 (patch)
tree50e21c575da46099c477092baa706dec2e57e47e
parent510c224d5fe3c769b834b64ec81b85fc21bfb792 (diff)
parentafdb6280d30c4bb05c51695075f00530684a37dc (diff)
downloadmullvadvpn-e4591042b587d73794515da0be9627aedb27ac94.tar.xz
mullvadvpn-e4591042b587d73794515da0be9627aedb27ac94.zip
Merge branch 'fix-masque-sizes'
-rw-r--r--mullvad-masque-proxy/examples/masque-client.rs6
-rw-r--r--mullvad-masque-proxy/examples/masque-server.rs4
-rw-r--r--mullvad-masque-proxy/src/client/mod.rs96
-rw-r--r--mullvad-masque-proxy/src/fragment.rs68
-rw-r--r--mullvad-masque-proxy/src/lib.rs31
-rw-r--r--mullvad-masque-proxy/src/server/mod.rs93
-rw-r--r--mullvad-masque-proxy/tests/proxy.rs162
7 files changed, 350 insertions, 110 deletions
diff --git a/mullvad-masque-proxy/examples/masque-client.rs b/mullvad-masque-proxy/examples/masque-client.rs
index 780e95d044..3a6b591640 100644
--- a/mullvad-masque-proxy/examples/masque-client.rs
+++ b/mullvad-masque-proxy/examples/masque-client.rs
@@ -31,7 +31,7 @@ pub struct ClientArgs {
/// Maximum packet size
#[arg(long, short = 'S', default_value = "1280")]
- maximum_packet_size: u16,
+ mtu: u16,
}
#[tokio::main]
@@ -42,7 +42,7 @@ async fn main() {
root_cert_path,
server_hostname,
bind_port,
- maximum_packet_size,
+ mtu,
} = ClientArgs::parse();
let tls_config = match root_cert_path {
@@ -67,7 +67,7 @@ async fn main() {
target_addr,
&server_hostname,
tls_config,
- maximum_packet_size,
+ mtu,
)
.await;
if let Err(err) = &client {
diff --git a/mullvad-masque-proxy/examples/masque-server.rs b/mullvad-masque-proxy/examples/masque-server.rs
index 9c07423d9c..fe216070ef 100644
--- a/mullvad-masque-proxy/examples/masque-server.rs
+++ b/mullvad-masque-proxy/examples/masque-server.rs
@@ -28,7 +28,7 @@ pub struct ServerArgs {
/// Maximum packet size
#[arg(long, short = 'm', default_value = "1700")]
- maximum_packet_size: u16,
+ mtu: u16,
}
#[tokio::main]
@@ -42,7 +42,7 @@ async fn main() {
args.bind_addr,
args.allowed_ips.iter().cloned().collect(),
tls_config.into(),
- args.maximum_packet_size,
+ args.mtu,
)
.expect("Failed to initialize server");
println!("Listening on {}", args.bind_addr);
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs
index 4d59f0dd03..df65339c03 100644
--- a/mullvad-masque-proxy/src/client/mod.rs
+++ b/mullvad-masque-proxy/src/client/mod.rs
@@ -21,8 +21,10 @@ use quinn::{
};
use crate::{
+ compute_udp_payload_size,
fragment::{self, Fragments},
stats::Stats,
+ MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
};
const MAX_HEADER_SIZE: u64 = 8192;
@@ -34,6 +36,9 @@ const MAX_INFLIGHT_PACKETS: usize = 100;
pub struct Client {
client_socket: Arc<UdpSocket>,
+ /// QUIC endpoint
+ quinn_conn: quinn::Connection,
+
/// QUIC connection, used to send the actual HTTP datagrams
connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>,
@@ -45,8 +50,8 @@ pub struct Client {
/// connection is terminated
request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
- /// Maximum packet size
- maximum_packet_size: u16,
+ /// Maximum UDP payload size (packet size including QUIC overhead)
+ max_udp_payload_size: u16,
stats: Arc<Stats>,
}
@@ -61,6 +66,8 @@ pub enum Error {
Connect(#[from] quinn::ConnectError),
#[error("Failed to connect to QUIC endpoint")]
Connection(#[from] quinn::ConnectionError),
+ #[error("Invalid MTU: must be at least {min_mtu}")]
+ InvalidMtu { min_mtu: u16 },
#[error("Invalid max_udp_payload_size")]
InvalidMaxUdpPayload(#[source] quinn::ConfigError),
#[error("Connection closed while sending request to initiate proxying")]
@@ -100,7 +107,7 @@ impl Client {
local_addr: SocketAddr,
target_addr: SocketAddr,
server_host: &str,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<Self> {
Self::connect_with_tls_config(
client_socket,
@@ -109,7 +116,7 @@ impl Client {
target_addr,
server_host,
default_tls_config(),
- maximum_packet_size,
+ mtu,
)
.await
}
@@ -121,7 +128,7 @@ impl Client {
target_addr: SocketAddr,
server_host: &str,
tls_config: Arc<rustls::ClientConfig>,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<Self> {
let quic_client_config = QuicClientConfig::try_from(tls_config)
.expect("Failed to construct a valid TLS configuration");
@@ -140,7 +147,7 @@ impl Client {
target_addr,
server_host,
client_config,
- maximum_packet_size,
+ mtu,
)
.await
}
@@ -152,34 +159,56 @@ impl Client {
target_addr: SocketAddr,
server_host: &str,
client_config: ClientConfig,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<Self> {
- let endpoint = Self::setup_quic_endpoint(local_addr, maximum_packet_size)?;
+ Self::validate_mtu(mtu, target_addr)?;
+
+ let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr);
+
+ let endpoint = Self::setup_quic_endpoint(local_addr, max_udp_payload_size)?;
let connecting = endpoint.connect_with(client_config, server_addr, server_host)?;
let connection = connecting.await?;
- let (connection, send_stream, request_stream) =
- Self::setup_h3_connection(connection, target_addr, server_host, maximum_packet_size)
- .await?;
+ let (h3_connection, send_stream, request_stream) = Self::setup_h3_connection(
+ connection.clone(),
+ target_addr,
+ server_host,
+ max_udp_payload_size,
+ )
+ .await?;
Ok(Self {
- connection,
+ quinn_conn: connection,
+ connection: h3_connection,
client_socket: Arc::new(client_socket),
request_stream,
_send_stream: send_stream,
- maximum_packet_size,
+ max_udp_payload_size,
stats: Arc::default(),
})
}
- fn setup_quic_endpoint(local_addr: SocketAddr, maximum_packet_size: u16) -> Result<Endpoint> {
+ const fn validate_mtu(mtu: u16, target_addr: SocketAddr) -> Result<()> {
+ let min_mtu = if target_addr.is_ipv4() {
+ MIN_IPV4_MTU
+ } else {
+ MIN_IPV6_MTU
+ };
+ if mtu >= min_mtu {
+ Ok(())
+ } else {
+ Err(Error::InvalidMtu { min_mtu })
+ }
+ }
+
+ fn setup_quic_endpoint(local_addr: SocketAddr, max_udp_payload_size: u16) -> Result<Endpoint> {
let local_socket = std::net::UdpSocket::bind(local_addr).map_err(Error::Bind)?;
let mut endpoint_config = EndpointConfig::default();
endpoint_config
- .max_udp_payload_size(maximum_packet_size)
+ .max_udp_payload_size(max_udp_payload_size)
.map_err(Error::InvalidMaxUdpPayload)?;
Endpoint::new(endpoint_config, None, local_socket, Arc::new(TokioRuntime))
@@ -191,7 +220,7 @@ impl Client {
connection: quinn::Connection,
target: SocketAddr,
server_host: &str,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<(
client::Connection<h3_quinn::Connection, bytes::Bytes>,
client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
@@ -205,7 +234,7 @@ impl Client {
.await
.map_err(Error::CreateClient)?;
- let request = new_connect_request(target, &server_host, maximum_packet_size)?;
+ let request = new_connect_request(target, &server_host, mtu)?;
let request_future = async move {
let mut request_stream = send_stream.send_request(request).await?;
@@ -251,7 +280,8 @@ impl Client {
let mut server_socket_task = tokio::task::spawn(server_socket_task(
stream_id,
- self.maximum_packet_size,
+ self.max_udp_payload_size,
+ self.quinn_conn,
self.connection,
server_tx,
client_rx,
@@ -274,13 +304,15 @@ impl Client {
async fn server_socket_task(
stream_id: StreamId,
- maximum_packet_size: u16,
+ max_udp_payload_size: u16,
+ quinn_conn: quinn::Connection,
mut connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>,
server_tx: mpsc::Sender<Datagram>,
mut client_rx: mpsc::Receiver<Bytes>,
stats: Arc<Stats>,
) -> Result<()> {
let mut fragment_id = 1u16;
+ let stream_id_size = VarInt::from(stream_id).size() as u16;
loop {
let packet = select! {
@@ -302,7 +334,14 @@ async fn server_socket_task(
let Some(mut packet) = packet else { break };
- if packet.len() < (Into::<usize>::into(maximum_packet_size) - 100usize) {
+ // Maximum QUIC payload (including fragmentation headers)
+ let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() {
+ max_datagram_size as u16 - stream_id_size
+ } else {
+ max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size
+ };
+
+ if packet.len() <= usize::from(maximum_packet_size) {
stats.tx(packet.len(), false);
connection
.send_datagram(stream_id, packet)
@@ -314,6 +353,8 @@ async fn server_socket_task(
for fragment in fragment::fragment_packet(maximum_packet_size, &mut packet, fragment_id)
.map_err(Error::PacketTooLarge)?
{
+ debug_assert!(fragment.len() <= maximum_packet_size as usize);
+
stats.tx(fragment.len(), true);
connection
.send_datagram(stream_id, fragment)
@@ -331,7 +372,7 @@ async fn client_socket_rx_task(
client_tx: mpsc::Sender<Bytes>,
return_addr_tx: broadcast::Sender<SocketAddr>,
) -> Result<()> {
- let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024);
+ let mut client_read_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE);
let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
loop {
@@ -411,7 +452,7 @@ async fn client_socket_tx_task(
fn new_connect_request(
socket_addr: SocketAddr,
authority: &dyn AsRef<str>,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<http::Request<()>> {
let host = socket_addr.ip();
let port = socket_addr.port();
@@ -432,7 +473,7 @@ fn new_connect_request(
// TODO: Not needed since we set the max_udp_payload_size transport param
.header(
b"X-Mullvad-Uplink-Mtu".as_slice(),
- format!("{maximum_packet_size}"),
+ format!("{mtu}"),
)
.body(())
.expect("failed to construct a body");
@@ -503,9 +544,12 @@ fn read_cert_store_from_reader(reader: &mut dyn io::BufRead) -> Result<rustls::R
Ok(cert_store)
}
-#[test]
-fn test_zero_stream_id() {
- h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?");
+#[cfg(test)]
+mod test {
+ #[test]
+ fn test_zero_stream_id() {
+ h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?");
+ }
}
#[derive(Debug)]
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs
index f60c3b927d..c699ed2d78 100644
--- a/mullvad-masque-proxy/src/fragment.rs
+++ b/mullvad-masque-proxy/src/fragment.rs
@@ -6,6 +6,8 @@ use std::{
use bytes::{Buf, BufMut, Bytes, BytesMut};
use h3::proto::varint::VarInt;
+use crate::FRAGMENT_HEADER_SIZE_FRAGMENTED;
+
#[derive(Default)]
pub struct Fragments {
fragment_map: BTreeMap<u16, Vec<Fragment>>,
@@ -102,19 +104,29 @@ struct Fragment {
time_received: Instant,
}
+/// Fragment packet using the given maximum fragment size (including headers).
+///
+/// `payload` must not contain any fragmentation headers.
+/// `maximum_packet_size` is the maximum fragment size including headers.
pub fn fragment_packet(
maximum_packet_size: u16,
payload: &'_ mut Bytes,
packet_id: u16,
) -> Result<impl Iterator<Item = Bytes> + '_, PacketTooLarge> {
- let num_fragments: usize = payload.chunks(maximum_packet_size.into()).count();
+ let fragment_payload_size = maximum_packet_size - FRAGMENT_HEADER_SIZE_FRAGMENTED;
+
+ let num_fragments: usize = payload.chunks(fragment_payload_size.into()).count();
let Ok(fragment_count): std::result::Result<u8, _> = num_fragments.try_into() else {
return Err(PacketTooLarge(payload.len()));
};
- let iterator = payload.chunks(maximum_packet_size.into()).enumerate().map(
- move |(fragment_index, fragment_payload)| {
- let mut fragment = BytesMut::with_capacity((maximum_packet_size + 1).into());
+ let iterator = payload
+ .chunks(fragment_payload_size.into())
+ .enumerate()
+ .map(move |(fragment_index, fragment_payload)| {
+ let mut fragment = BytesMut::with_capacity(usize::from(
+ fragment_payload_size + FRAGMENT_HEADER_SIZE_FRAGMENTED,
+ ));
crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment);
fragment.put_u16(packet_id);
fragment.put_u8(
@@ -123,36 +135,44 @@ pub fn fragment_packet(
.expect("fragment index must fit in an u8, since num_fragments fits is an u8"),
);
fragment.put_u8(fragment_count);
+
+ debug_assert!(fragment.len() == usize::from(FRAGMENT_HEADER_SIZE_FRAGMENTED));
+
fragment.extend_from_slice(fragment_payload);
fragment.freeze()
- },
- );
+ });
Ok(iterator)
}
-#[test]
-fn test_fragment_reconstruction() {
- use rand::{seq::SliceRandom, thread_rng};
+#[cfg(test)]
+mod test {
+ use super::*;
- let payload = (0..255).collect::<Vec<u8>>();
- let max_payload_size = 50;
- let packet_id = 76;
+ #[test]
+ fn test_fragment_reconstruction() {
+ use rand::{seq::SliceRandom, thread_rng};
- let mut fragments = Fragments::default();
+ let payload = (0..255).collect::<Vec<u8>>();
+ let max_payload_size = 50;
+ let packet_id = 76;
- let mut payload_clone = Bytes::from(payload.clone());
- let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
- .unwrap()
- .collect::<Vec<_>>();
+ let mut fragments = Fragments::default();
- fragment_buf.shuffle(&mut thread_rng());
+ let mut payload_clone = Bytes::from(payload.clone());
+ let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
+ .unwrap()
+ .collect::<Vec<_>>();
- for fragment in fragment_buf {
- if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() {
- assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
- return;
+ fragment_buf.shuffle(&mut thread_rng());
+
+ for fragment in fragment_buf {
+ if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap()
+ {
+ assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
+ return;
+ }
}
- }
- panic!("Failed to reconstruct packet");
+ panic!("Failed to reconstruct packet");
+ }
}
diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs
index d4a4e47812..5e9c2902b0 100644
--- a/mullvad-masque-proxy/src/lib.rs
+++ b/mullvad-masque-proxy/src/lib.rs
@@ -1,10 +1,39 @@
use h3::proto::varint::VarInt;
+use std::net::SocketAddr;
pub mod client;
mod fragment;
pub mod server;
mod stats;
-const PACKET_BUFFER_SIZE: usize = 1700;
+const PACKET_BUFFER_SIZE: usize = 64 * 1024;
pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0);
pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1);
+
+/// Fragment headers size for fragmented packets
+const FRAGMENT_HEADER_SIZE_FRAGMENTED: u16 = 5;
+
+/// UDP header overhead
+const UDP_HEADER_SIZE: u16 = 8;
+
+/// QUIC header size. This is conservative, real overhead varies
+const QUIC_HEADER_SIZE: u16 = 41;
+
+/// This is the size of the payload that stores QUIC packets
+/// MTU - IP header - UDP header
+const fn compute_udp_payload_size(mtu: u16, target_addr: SocketAddr) -> u16 {
+ let ip_overhead = if target_addr.is_ipv4() { 20 } else { 40 };
+ mtu - ip_overhead - UDP_HEADER_SIZE
+}
+
+/// Minimum allowed MTU (IPv6) is the overhead of all headers, plus 1 byte for actual data.
+/// QUIC defines that clients must support UDP payloads of at least 1200 bytes.
+/// <https://datatracker.ietf.org/doc/html/rfc9000#section-8.1>
+// 20 = IPv4 header (without optional fields)
+pub const MIN_IPV4_MTU: u16 = 20 + UDP_HEADER_SIZE + 1200;
+
+/// Minimum allowed MTU (IPv6) is the overhead of all headers, plus 1 byte for actual data.
+/// QUIC defines that clients must support UDP payloads of at least 1200 bytes.
+/// <https://datatracker.ietf.org/doc/html/rfc9000#section-8.1>
+// 40 = IPv6 header (without optional fields)
+pub const MIN_IPV6_MTU: u16 = 40 + UDP_HEADER_SIZE + 1200;
diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs
index 710edaa537..a57282d172 100644
--- a/mullvad-masque-proxy/src/server/mod.rs
+++ b/mullvad-masque-proxy/src/server/mod.rs
@@ -17,7 +17,11 @@ use http::{Request, StatusCode};
use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming};
use tokio::{net::UdpSocket, time::interval};
-use crate::fragment::{self, Fragments};
+use crate::{
+ compute_udp_payload_size,
+ fragment::{self, Fragments},
+ MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
+};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@@ -27,6 +31,8 @@ pub enum Error {
BindSocket(#[source] io::Error),
#[error("Failed to send negotiation response")]
SendNegotiationResponse(#[source] h3::Error),
+ #[error("Invalid MTU: must be at least {min_mtu}")]
+ InvalidMtu { min_mtu: u16 },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -36,7 +42,7 @@ const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/";
pub struct Server {
endpoint: Endpoint,
allowed_hosts: AllowedIps,
- maximum_packet_size: u16,
+ mtu: u16,
}
#[derive(Clone)]
@@ -55,8 +61,10 @@ impl Server {
bind_addr: SocketAddr,
allowed_hosts: HashSet<IpAddr>,
tls_config: Arc<rustls::ServerConfig>,
- maximum_packet_size: u16,
+ mtu: u16,
) -> Result<Self> {
+ Self::validate_mtu(mtu, bind_addr)?;
+
let server_config = quinn::ServerConfig::with_crypto(Arc::new(
QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?,
));
@@ -68,10 +76,23 @@ impl Server {
allowed_hosts: AllowedIps {
hosts: Arc::new(allowed_hosts),
},
- maximum_packet_size,
+ mtu,
})
}
+ const fn validate_mtu(mtu: u16, bind_addr: SocketAddr) -> Result<()> {
+ let min_mtu = if bind_addr.is_ipv4() {
+ MIN_IPV4_MTU
+ } else {
+ MIN_IPV6_MTU
+ };
+ if mtu >= min_mtu {
+ Ok(())
+ } else {
+ Err(Error::InvalidMtu { min_mtu })
+ }
+ }
+
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.endpoint.local_addr()
}
@@ -81,21 +102,19 @@ impl Server {
tokio::spawn(Self::handle_incoming_connection(
new_connection,
self.allowed_hosts.clone(),
- self.maximum_packet_size,
+ self.mtu,
));
}
Ok(())
}
- async fn handle_incoming_connection(
- connection: Incoming,
- allowed_hosts: AllowedIps,
- maximum_packet_size: u16,
- ) {
+ async fn handle_incoming_connection(connection: Incoming, allowed_hosts: AllowedIps, mtu: u16) {
match connection.await {
Ok(conn) => {
println!("new connection established");
+ let quinn_conn = conn.clone();
+
let Ok(mut connection) = server::builder()
.enable_datagram(true)
.build(h3_quinn::Connection::new(conn))
@@ -109,10 +128,11 @@ impl Server {
Ok(Some((req, stream))) => {
tokio::spawn(Self::handle_proxy_request(
connection,
+ quinn_conn,
req,
stream,
allowed_hosts.clone(),
- maximum_packet_size,
+ mtu,
));
}
@@ -132,10 +152,11 @@ impl Server {
async fn handle_proxy_request<T: BidiStream<Bytes>>(
mut connection: Connection<h3_quinn::Connection, Bytes>,
+ quinn_conn: quinn::Connection,
request: Request<()>,
mut stream: RequestStream<T, Bytes>,
allowed_hosts: AllowedIps,
- maximum_packet_size: u16,
+ mtu: u16,
) {
let Some(target_addr) = get_target_socketaddr(request.uri().path()) else {
return;
@@ -157,8 +178,11 @@ impl Server {
return;
}
+ let max_udp_payload_size = compute_udp_payload_size(mtu, target_addr);
+
let stream_id = stream.id();
- let mut proxy_recv_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE);
+ let stream_id_size = VarInt::from(stream_id).size() as u16;
+ let mut proxy_recv_buf = BytesMut::with_capacity(100 * crate::PACKET_BUFFER_SIZE);
let mut fragments = Fragments::default();
let mut fragment_id = 0u16;
@@ -191,12 +215,20 @@ impl Server {
let mut received_packet = proxy_recv_buf.split().freeze();
- if received_packet.len() < maximum_packet_size.into() {
+ // Maximum QUIC payload (including fragmentation headers)
+ let maximum_packet_size = if let Some(max_datagram_size) = quinn_conn.max_datagram_size() {
+ max_datagram_size as u16 - stream_id_size
+ } else {
+ max_udp_payload_size - QUIC_HEADER_SIZE - stream_id_size
+ };
+
+ if received_packet.len() <= usize::from(maximum_packet_size) {
if connection.send_datagram(stream_id, received_packet).is_err() {
return;
}
} else {
let _ = VarInt::decode(&mut received_packet);
+
let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; };
fragment_id += 1;
for payload in fragments {
@@ -296,21 +328,26 @@ fn unspecified_addr(addr: IpAddr) -> IpAddr {
}
}
-#[test]
-fn test_get_good_slashy_ocketaddr() {
- let addr: IpAddr = "192.168.1.1".parse().unwrap();
- let port: u16 = 7979;
- let expected_addr = SocketAddr::new(addr, port);
- let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////");
+#[cfg(test)]
+mod test {
+ use super::*;
- assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr)
-}
+ #[test]
+ fn test_get_good_slashy_ocketaddr() {
+ let addr: IpAddr = "192.168.1.1".parse().unwrap();
+ let port: u16 = 7979;
+ let expected_addr = SocketAddr::new(addr, port);
+ let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////");
+
+ assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr)
+ }
-#[test]
-fn test_get_bad_socketaddr() {
- let addr: IpAddr = "192.168.1.1".parse().unwrap();
- let port: u16 = 7979;
- let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}");
+ #[test]
+ fn test_get_bad_socketaddr() {
+ let addr: IpAddr = "192.168.1.1".parse().unwrap();
+ let port: u16 = 7979;
+ let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}");
- assert_eq!(get_target_socketaddr(&good_path), None)
+ assert_eq!(get_target_socketaddr(&good_path), None)
+ }
}
diff --git a/mullvad-masque-proxy/tests/proxy.rs b/mullvad-masque-proxy/tests/proxy.rs
index 6825f7d75d..61a15e5136 100644
--- a/mullvad-masque-proxy/tests/proxy.rs
+++ b/mullvad-masque-proxy/tests/proxy.rs
@@ -1,25 +1,145 @@
+use std::iter;
use std::net::SocketAddr;
use std::sync::Arc;
+use std::time::Duration;
+use anyhow::anyhow;
use anyhow::Context;
use bytes::BytesMut;
+use mullvad_masque_proxy::MIN_IPV4_MTU;
+use rand::RngCore;
use tokio::fs;
use mullvad_masque_proxy::client;
use mullvad_masque_proxy::server;
use tokio::net::UdpSocket;
+use tokio::time::timeout;
/// Set up a MASQUE proxy and test that it can be used to communicate with some UDP destination
#[tokio::test]
async fn test_server_and_client_forwarding() -> anyhow::Result<()> {
- const MAXIMUM_PACKET_SIZE: u16 = 1700;
+ timeout(Duration::from_secs(1), async {
+ const MTU: u16 = 1700;
+ let (client, server) = setup_masque(MTU).await?;
+
+ // Proxy client -> destination
+ let mut rx_buf = BytesMut::with_capacity(128);
+ client.send(b"abc").await?;
+ let (_, proxy_addr) = server
+ .recv_buf_from(&mut rx_buf)
+ .await
+ .context("Expected to receive message")?;
+ assert_eq!(&*rx_buf, b"abc", "Expected to receive message from client");
+
+ // Destination -> proxy client
+ let mut rx_buf = BytesMut::with_capacity(128);
+ server.send_to(b"def", proxy_addr).await?;
+ client
+ .recv_buf(&mut rx_buf)
+ .await
+ .context("Expected to receive message")?;
+ assert_eq!(&*rx_buf, b"def", "Expected to receive message from server");
+
+ Ok(())
+ })
+ .await?
+}
+
+/// End to end test with fragmentation.
+/// Note: This doesn't actually check whether fragmentation occurs, only that packets actually
+/// reach their destinations when fragmentation *should* be present.
+#[tokio::test]
+async fn test_server_and_client_fragmentation() -> anyhow::Result<()> {
+ #[allow(unused_mut)]
+ let mut valid_send_packet_sizes = vec![0u16, 1, 10, 100, 1280, 5000];
+
+ // Maximum packet size sans UDP and QUIC headers, sans 1 byte context ID.
+ //
+ // NOTE: On macOS, the maximum UDP packet size is equal to the value set by
+ // `sysctl net.inet.udp.maxdgram`
+ #[cfg(not(target_os = "macos"))]
+ valid_send_packet_sizes.push(u16::MAX - 8 - 41 - 1);
+
+ let valid_mtus = [MIN_IPV4_MTU, 1280, 1500, 1700, 5000, 20000, u16::MAX];
+
+ let params = valid_mtus
+ .into_iter()
+ .flat_map(|mtu| iter::repeat(mtu).zip(&valid_send_packet_sizes));
+
+ async fn run_test(mtu: u16, send_packet_size: usize) -> anyhow::Result<()> {
+ let (client, server) = setup_masque(mtu).await?;
+
+ // Proxy client -> destination
+ // Send a random packet, large enough to be fragmented
+ let mut fragment_me = vec![0u8; send_packet_size];
+ rand::thread_rng().fill_bytes(&mut fragment_me);
+
+ client.send(&fragment_me).await?;
+
+ let mut rx_buf = BytesMut::with_capacity(send_packet_size + 100);
+ let (_, proxy_addr) = server
+ .recv_buf_from(&mut rx_buf)
+ .await
+ .context("Expected to receive message")?;
+ let read = rx_buf.split();
+ assert_eq!(
+ &*read, &fragment_me,
+ "Expected to receive reassembled message from client"
+ );
+
+ // Destination -> proxy client
+ // Send a random packet, large enough to be fragmented
+ let mut fragment_me = vec![0u8; send_packet_size];
+ rand::thread_rng().fill_bytes(&mut fragment_me);
+
+ server.send_to(&fragment_me, proxy_addr).await?;
+
+ let mut rx_buf = BytesMut::with_capacity(send_packet_size + 100);
+ let blen = client
+ .recv_buf(&mut rx_buf)
+ .await
+ .context("Expected to receive message")?;
+
+ let read = rx_buf.split();
+ eprintln!(
+ "from server: {}, {}, {}",
+ fragment_me.len(),
+ read.len(),
+ blen
+ );
+ assert_eq!(
+ &*read, &fragment_me,
+ "Expected to receive reassembled message from server"
+ );
+
+ Ok(())
+ }
+
+ for (mtu, &send_packet_size) in params {
+ timeout(
+ Duration::from_secs(1),
+ run_test(mtu, send_packet_size.into()),
+ )
+ .await?
+ .context(anyhow!("mtu={mtu}, send_packet_size={send_packet_size}"))?;
+ }
+
+ Ok(())
+}
+
+/// Set up a client and server connected by a MASQUE proxy.
+/// This returns a UDP socket that is connected to the local MASQUE client,
+/// and a UDP socket that represents the other endpoint.
+/// Note that the server socket (second returned value) is not connected,
+/// so `recv_from` must be used.
+async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> {
const HOST: &str = "test.test";
let any_localhost_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
// Set up destination UDP server
- let target_udp_server = UdpSocket::bind(any_localhost_addr).await?;
- let target_udp_addr = target_udp_server
+ let destination_udp_server = UdpSocket::bind(any_localhost_addr).await?;
+ let target_udp_addr = destination_udp_server
.local_addr()
.context("Retrieve dest UDP server addr")?;
@@ -29,13 +149,17 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> {
any_localhost_addr,
Default::default(),
Arc::new(server_tls_config),
- MAXIMUM_PACKET_SIZE,
+ mtu,
)
.context("Failed to start MASQUE server")?;
let masque_server_addr = server.local_addr()?;
- tokio::spawn(server.run());
+ tokio::spawn(async move {
+ if let Err(err) = server.run().await {
+ eprintln!("server.run() failed: {err}");
+ }
+ });
// Set up MASQUE client
let local_socket = UdpSocket::bind(any_localhost_addr)
@@ -51,12 +175,16 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> {
target_udp_addr,
HOST,
client::default_tls_config(),
- MAXIMUM_PACKET_SIZE,
+ mtu,
)
.await
.context("Failed to start MASQUE client")?;
- tokio::spawn(client.run());
+ tokio::spawn(async move {
+ if let Err(err) = client.run().await {
+ eprintln!("client.run() failed: {err}");
+ }
+ });
// Connect to local UDP socket
let proxy_client = UdpSocket::bind(any_localhost_addr).await?;
@@ -65,25 +193,7 @@ async fn test_server_and_client_forwarding() -> anyhow::Result<()> {
.await
.context("Failed to connect to local UDP server")?;
- // Proxy client -> destination
- let mut rx_buf = BytesMut::with_capacity(128);
- proxy_client.send(b"abc").await?;
- let (_, proxy_addr) = target_udp_server
- .recv_buf_from(&mut rx_buf)
- .await
- .context("Expected to receive message")?;
- assert_eq!(&*rx_buf, b"abc", "Expected to receive message from client");
-
- // Destination -> proxy client
- let mut rx_buf = BytesMut::with_capacity(128);
- target_udp_server.send_to(b"def", proxy_addr).await?;
- proxy_client
- .recv_buf(&mut rx_buf)
- .await
- .context("Expected to receive message")?;
- assert_eq!(&*rx_buf, b"def", "Expected to receive message from server");
-
- Ok(())
+ Ok((proxy_client, destination_udp_server))
}
async fn load_server_test_cert() -> anyhow::Result<rustls::ServerConfig> {