diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-05-05 13:02:51 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-05-05 13:02:51 +0200 |
| commit | dc860e549ad9b88416ec38b8bd2b7f8936a65856 (patch) | |
| tree | 3b1a603b0d3a33a004fb17852cceffce9734b710 | |
| parent | f55b16817f443098eb3ee0f38076b77fd224e292 (diff) | |
| parent | 591ea62b522d1ae0f5a9d2ad926f6ed0cdc15e1d (diff) | |
| download | mullvadvpn-dc860e549ad9b88416ec38b8bd2b7f8936a65856.tar.xz mullvadvpn-dc860e549ad9b88416ec38b8bd2b7f8936a65856.zip | |
Merge branch 'masque-configurable-auth'
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-client.rs | 8 | ||||
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-server.rs | 23 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 33 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 61 | ||||
| -rw-r--r-- | mullvad-masque-proxy/tests/proxy.rs | 20 | ||||
| -rw-r--r-- | tunnel-obfuscation/src/quic.rs | 6 |
6 files changed, 109 insertions, 42 deletions
diff --git a/mullvad-masque-proxy/examples/masque-client.rs b/mullvad-masque-proxy/examples/masque-client.rs index f873b14b2f..776afa8abb 100644 --- a/mullvad-masque-proxy/examples/masque-client.rs +++ b/mullvad-masque-proxy/examples/masque-client.rs @@ -44,6 +44,10 @@ pub struct ClientArgs { /// Inactivity happens when no data is sent over the proxy. #[arg(long, short = 'i', value_parser = duration_from_seconds)] idle_timeout: Option<Duration>, + + /// Authorization header value to set + #[arg(long, default_value = "Bearer test")] + auth: Option<String>, } /// Parse a duration from a decimal number of seconds @@ -69,6 +73,7 @@ async fn main() { #[cfg(target_os = "linux")] fwmark, idle_timeout, + auth, } = ClientArgs::parse(); let tls_config = match root_cert_path { @@ -94,7 +99,8 @@ async fn main() { .target_addr(target_addr) .mtu(mtu) .tls_config(tls_config) - .idle_timeout(idle_timeout); + .idle_timeout(idle_timeout) + .auth_header(auth); #[cfg(target_os = "linux")] let config = config.fwmark(fwmark); diff --git a/mullvad-masque-proxy/examples/masque-server.rs b/mullvad-masque-proxy/examples/masque-server.rs index e884e91c93..b321fcf659 100644 --- a/mullvad-masque-proxy/examples/masque-server.rs +++ b/mullvad-masque-proxy/examples/masque-server.rs @@ -1,4 +1,5 @@ use clap::Parser; +use mullvad_masque_proxy::server::{AllowedIps, ServerParams}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use std::{ @@ -36,6 +37,10 @@ pub struct ServerArgs { /// Maximum packet size #[arg(long, short = 'm', default_value = "1700")] mtu: u16, + + /// Authorization header value to set + #[arg(long, default_value = "Bearer test")] + auth: Option<String>, } #[tokio::main] @@ -50,14 +55,16 @@ async fn main() { let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap(); - let server = mullvad_masque_proxy::server::Server::bind( - args.bind_addr, - args.allowed_ips.iter().cloned().collect(), - args.hostname, - tls_config.into(), - args.mtu, - ) - .expect("Failed to initialize server"); + let params = ServerParams::builder() + .allowed_hosts(AllowedIps::from(args.allowed_ips)) + .hostname(args.hostname) + .mtu(args.mtu) + .auth_header(args.auth) + .build(); + + let server = + mullvad_masque_proxy::server::Server::bind(args.bind_addr, tls_config.into(), params) + .expect("Failed to initialize server"); log::info!("Listening on {}", args.bind_addr); server.run().await.expect("Server failed.") } diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs index 33dcb0e397..7bdea3a9e6 100644 --- a/mullvad-masque-proxy/src/client/mod.rs +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -145,6 +145,10 @@ pub struct ClientConfig { /// Optional timeout when no data is sent in the proxy. #[builder(default)] pub idle_timeout: Option<Duration>, + + /// Set the authorization header to use in the CONNECT-UDP request. + #[builder(default)] + pub auth_header: Option<String>, } impl Client { @@ -189,6 +193,7 @@ impl Client { config.target_addr, &config.server_host, max_udp_payload_size, + config.auth_header, ) .await?; @@ -257,6 +262,7 @@ impl Client { target: SocketAddr, server_host: &str, mtu: u16, + auth_header: Option<String>, ) -> Result<( client::Connection<h3_quinn::Connection, bytes::Bytes>, client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, @@ -270,7 +276,16 @@ impl Client { .await .map_err(Error::CreateClient)?; - Self::send_connect_request(connection, send_stream, server_host, target, mtu, 0).await + Self::send_connect_request( + connection, + send_stream, + server_host, + target, + mtu, + 0, + auth_header, + ) + .await } /// Send an HTTP CONNECT request and set up the h3 connection for sending datagrams. @@ -283,12 +298,13 @@ impl Client { target: SocketAddr, mtu: u16, redirect_count: usize, + auth_header: Option<String>, ) -> Result<( client::Connection<h3_quinn::Connection, bytes::Bytes>, client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, )> { - let request = new_connect_request(target, &server_host, mtu)?; + let request = new_connect_request(target, &server_host, mtu, auth_header.as_deref())?; let request_future = async move { let mut request_stream = send_stream.send_request(request).await?; @@ -343,6 +359,7 @@ impl Client { target, mtu, redirect_count + 1, + auth_header, )) .await } @@ -547,6 +564,7 @@ fn new_connect_request( socket_addr: SocketAddr, authority: &dyn AsRef<str>, mtu: u16, + authorization: Option<&str>, ) -> Result<http::Request<()>> { let host = socket_addr.ip(); let port = socket_addr.port(); @@ -558,12 +576,17 @@ fn new_connect_request( .build() .map_err(Error::Uri)?; - let mut request = http::Request::builder() + let mut builder = http::Request::builder() .method(http::method::Method::CONNECT) .uri(uri) .header(b"Capsule-Protocol".as_slice(), b"?1".as_slice()) - .header(header::AUTHORIZATION, b"Bearer test".as_slice()) - .header(header::HOST, authority.as_ref()) + .header(header::HOST, authority.as_ref()); + + if let Some(auth) = authorization { + builder = builder.header(header::AUTHORIZATION, auth); + } + + let mut request = builder // TODO: Not needed since we set the max_udp_payload_size transport param .header( b"X-Mullvad-Uplink-Mtu".as_slice(), diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs index 4f160eaea6..af1cb3e1d4 100644 --- a/mullvad-masque-proxy/src/server/mod.rs +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -14,9 +14,10 @@ use h3::{ server::{self, Connection, RequestStream}, }; use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; -use http::{StatusCode, Uri}; +use http::{header, StatusCode, Uri}; use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming}; use tokio::{net::UdpSocket, select, sync::mpsc, task}; +use typed_builder::TypedBuilder; use crate::{ compute_udp_payload_size, @@ -43,22 +44,37 @@ pub struct Server { params: Arc<ServerParams>, } -struct ServerParams { +#[derive(TypedBuilder)] +pub struct ServerParams { /// Allowed target IPs for the proxy connection - allowed_hosts: AllowedIps, + pub allowed_hosts: AllowedIps, - /// Server hostname (optional) - hostname: Option<String>, + /// Server hostname expected from clients + #[builder(default)] + pub hostname: Option<String>, /// Maximum transfer unit - mtu: u16, + #[builder(default = 1500)] + pub mtu: u16, + + /// Authorization header expected from clients + #[builder(default)] + pub auth_header: Option<String>, } -#[derive(Clone)] -struct AllowedIps { +#[derive(Default, Clone)] +pub struct AllowedIps { hosts: Arc<HashSet<IpAddr>>, } +impl<T: IntoIterator<Item = IpAddr>> From<T> for AllowedIps { + fn from(value: T) -> Self { + AllowedIps { + hosts: Arc::new(value.into_iter().collect()), + } + } +} + impl AllowedIps { fn ip_allowed(&self, ip: IpAddr) -> bool { self.hosts.is_empty() || self.hosts.contains(&ip) @@ -68,12 +84,10 @@ impl AllowedIps { impl Server { pub fn bind( bind_addr: SocketAddr, - allowed_hosts: HashSet<IpAddr>, - hostname: Option<String>, tls_config: Arc<rustls::ServerConfig>, - mtu: u16, + params: ServerParams, ) -> Result<Self> { - Self::validate_mtu(mtu, bind_addr)?; + Self::validate_mtu(params.mtu, bind_addr)?; let server_config = quinn::ServerConfig::with_crypto(Arc::new( QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?, @@ -83,13 +97,7 @@ impl Server { Ok(Self { endpoint, - params: Arc::new(ServerParams { - allowed_hosts: AllowedIps { - hosts: Arc::new(allowed_hosts), - }, - hostname, - mtu, - }), + params: Arc::new(params), }) } @@ -171,6 +179,13 @@ impl Server { } }; + if let Some(required_auth) = &server_params.auth_header { + match http_request.headers().get(header::AUTHORIZATION) { + Some(actual_auth) if actual_auth == required_auth => (), + _ => return handle_invalid_auth(stream).await, + } + } + if let Some(hostname) = &server_params.hostname { if &proxy_uri.hostname != hostname { let valid_uri = ProxyUri { @@ -389,6 +404,14 @@ async fn handle_established_connection<T: BidiStream<Bytes>>( Ok(()) } +async fn handle_invalid_auth<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { + let response = http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(()) + .unwrap(); + let _ = stream.send_response(response).await; +} + async fn handle_disallowed_ip<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { let response = http::Response::builder() .status(StatusCode::BAD_REQUEST) diff --git a/mullvad-masque-proxy/tests/proxy.rs b/mullvad-masque-proxy/tests/proxy.rs index 29a90ca6fb..25a95bbbcf 100644 --- a/mullvad-masque-proxy/tests/proxy.rs +++ b/mullvad-masque-proxy/tests/proxy.rs @@ -6,6 +6,8 @@ use std::time::Duration; use anyhow::anyhow; use anyhow::Context; use bytes::BytesMut; +use mullvad_masque_proxy::server::AllowedIps; +use mullvad_masque_proxy::server::ServerParams; use mullvad_masque_proxy::MIN_IPV4_MTU; use rand::RngCore; use tokio::fs; @@ -145,14 +147,15 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> { // Set up MASQUE server let server_tls_config = load_server_test_cert().await?; - let server = server::Server::bind( - any_localhost_addr, - Default::default(), - None, - Arc::new(server_tls_config), - mtu, - ) - .context("Failed to start MASQUE server")?; + + let params = ServerParams::builder() + .allowed_hosts(AllowedIps::default()) + .mtu(mtu) + .auth_header(Some("Bearer test".to_owned())) + .build(); + + let server = server::Server::bind(any_localhost_addr, Arc::new(server_tls_config), params) + .context("Failed to start MASQUE server")?; let masque_server_addr = server.local_addr()?; @@ -176,6 +179,7 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> { .target_addr(target_udp_addr) .mtu(mtu) .idle_timeout(Some(Duration::from_secs(10))) + .auth_header(Some("Bearer test".to_owned())) .build(); let client = client::Client::connect(client_config) diff --git a/tunnel-obfuscation/src/quic.rs b/tunnel-obfuscation/src/quic.rs index 11a315554d..c72a3b84fd 100644 --- a/tunnel-obfuscation/src/quic.rs +++ b/tunnel-obfuscation/src/quic.rs @@ -12,6 +12,9 @@ use crate::Obfuscator; type Result<T> = std::result::Result<T, Error>; +/// Authentication header to set for the CONNECT request +const AUTH_HEADER: &str = "Bearer test"; + #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Failed to bind UDP socket")] @@ -48,7 +51,8 @@ impl Quic { .local_addr((Ipv4Addr::UNSPECIFIED, 0).into()) .server_addr(settings.quic_endpoint) .server_host(settings.hostname.clone()) - .target_addr(settings.wireguard_endpoint); + .target_addr(settings.wireguard_endpoint) + .auth_header(Some(AUTH_HEADER.to_owned())); let task = tokio::spawn(async move { let client = Client::connect(config_builder.build()) |
