diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-04-30 18:38:24 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-05-05 13:02:29 +0200 |
| commit | ee380d68e2e12ef52b49cdb7439e3d2bb1637bcf (patch) | |
| tree | 8fc8db3c8fe81857f5e1883d7b2d941b330bd56d | |
| parent | c026f50ce79454596e247b4027a885df5df0d212 (diff) | |
| download | mullvadvpn-ee380d68e2e12ef52b49cdb7439e3d2bb1637bcf.tar.xz mullvadvpn-ee380d68e2e12ef52b49cdb7439e3d2bb1637bcf.zip | |
Add masque server params builder and authorization header option
| -rw-r--r-- | mullvad-masque-proxy/examples/masque-server.rs | 23 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 61 | ||||
| -rw-r--r-- | mullvad-masque-proxy/tests/proxy.rs | 20 |
3 files changed, 69 insertions, 35 deletions
diff --git a/mullvad-masque-proxy/examples/masque-server.rs b/mullvad-masque-proxy/examples/masque-server.rs index e884e91c93..b321fcf659 100644 --- a/mullvad-masque-proxy/examples/masque-server.rs +++ b/mullvad-masque-proxy/examples/masque-server.rs @@ -1,4 +1,5 @@ use clap::Parser; +use mullvad_masque_proxy::server::{AllowedIps, ServerParams}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use std::{ @@ -36,6 +37,10 @@ pub struct ServerArgs { /// Maximum packet size #[arg(long, short = 'm', default_value = "1700")] mtu: u16, + + /// Authorization header value to set + #[arg(long, default_value = "Bearer test")] + auth: Option<String>, } #[tokio::main] @@ -50,14 +55,16 @@ async fn main() { let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap(); - let server = mullvad_masque_proxy::server::Server::bind( - args.bind_addr, - args.allowed_ips.iter().cloned().collect(), - args.hostname, - tls_config.into(), - args.mtu, - ) - .expect("Failed to initialize server"); + let params = ServerParams::builder() + .allowed_hosts(AllowedIps::from(args.allowed_ips)) + .hostname(args.hostname) + .mtu(args.mtu) + .auth_header(args.auth) + .build(); + + let server = + mullvad_masque_proxy::server::Server::bind(args.bind_addr, tls_config.into(), params) + .expect("Failed to initialize server"); log::info!("Listening on {}", args.bind_addr); server.run().await.expect("Server failed.") } diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs index 4f160eaea6..af1cb3e1d4 100644 --- a/mullvad-masque-proxy/src/server/mod.rs +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -14,9 +14,10 @@ use h3::{ server::{self, Connection, RequestStream}, }; use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; -use http::{StatusCode, Uri}; +use http::{header, StatusCode, Uri}; use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming}; use tokio::{net::UdpSocket, select, sync::mpsc, task}; +use typed_builder::TypedBuilder; use crate::{ compute_udp_payload_size, @@ -43,22 +44,37 @@ pub struct Server { params: Arc<ServerParams>, } -struct ServerParams { +#[derive(TypedBuilder)] +pub struct ServerParams { /// Allowed target IPs for the proxy connection - allowed_hosts: AllowedIps, + pub allowed_hosts: AllowedIps, - /// Server hostname (optional) - hostname: Option<String>, + /// Server hostname expected from clients + #[builder(default)] + pub hostname: Option<String>, /// Maximum transfer unit - mtu: u16, + #[builder(default = 1500)] + pub mtu: u16, + + /// Authorization header expected from clients + #[builder(default)] + pub auth_header: Option<String>, } -#[derive(Clone)] -struct AllowedIps { +#[derive(Default, Clone)] +pub struct AllowedIps { hosts: Arc<HashSet<IpAddr>>, } +impl<T: IntoIterator<Item = IpAddr>> From<T> for AllowedIps { + fn from(value: T) -> Self { + AllowedIps { + hosts: Arc::new(value.into_iter().collect()), + } + } +} + impl AllowedIps { fn ip_allowed(&self, ip: IpAddr) -> bool { self.hosts.is_empty() || self.hosts.contains(&ip) @@ -68,12 +84,10 @@ impl AllowedIps { impl Server { pub fn bind( bind_addr: SocketAddr, - allowed_hosts: HashSet<IpAddr>, - hostname: Option<String>, tls_config: Arc<rustls::ServerConfig>, - mtu: u16, + params: ServerParams, ) -> Result<Self> { - Self::validate_mtu(mtu, bind_addr)?; + Self::validate_mtu(params.mtu, bind_addr)?; let server_config = quinn::ServerConfig::with_crypto(Arc::new( QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?, @@ -83,13 +97,7 @@ impl Server { Ok(Self { endpoint, - params: Arc::new(ServerParams { - allowed_hosts: AllowedIps { - hosts: Arc::new(allowed_hosts), - }, - hostname, - mtu, - }), + params: Arc::new(params), }) } @@ -171,6 +179,13 @@ impl Server { } }; + if let Some(required_auth) = &server_params.auth_header { + match http_request.headers().get(header::AUTHORIZATION) { + Some(actual_auth) if actual_auth == required_auth => (), + _ => return handle_invalid_auth(stream).await, + } + } + if let Some(hostname) = &server_params.hostname { if &proxy_uri.hostname != hostname { let valid_uri = ProxyUri { @@ -389,6 +404,14 @@ async fn handle_established_connection<T: BidiStream<Bytes>>( Ok(()) } +async fn handle_invalid_auth<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { + let response = http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(()) + .unwrap(); + let _ = stream.send_response(response).await; +} + async fn handle_disallowed_ip<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { let response = http::Response::builder() .status(StatusCode::BAD_REQUEST) diff --git a/mullvad-masque-proxy/tests/proxy.rs b/mullvad-masque-proxy/tests/proxy.rs index 29a90ca6fb..25a95bbbcf 100644 --- a/mullvad-masque-proxy/tests/proxy.rs +++ b/mullvad-masque-proxy/tests/proxy.rs @@ -6,6 +6,8 @@ use std::time::Duration; use anyhow::anyhow; use anyhow::Context; use bytes::BytesMut; +use mullvad_masque_proxy::server::AllowedIps; +use mullvad_masque_proxy::server::ServerParams; use mullvad_masque_proxy::MIN_IPV4_MTU; use rand::RngCore; use tokio::fs; @@ -145,14 +147,15 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> { // Set up MASQUE server let server_tls_config = load_server_test_cert().await?; - let server = server::Server::bind( - any_localhost_addr, - Default::default(), - None, - Arc::new(server_tls_config), - mtu, - ) - .context("Failed to start MASQUE server")?; + + let params = ServerParams::builder() + .allowed_hosts(AllowedIps::default()) + .mtu(mtu) + .auth_header(Some("Bearer test".to_owned())) + .build(); + + let server = server::Server::bind(any_localhost_addr, Arc::new(server_tls_config), params) + .context("Failed to start MASQUE server")?; let masque_server_addr = server.local_addr()?; @@ -176,6 +179,7 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> { .target_addr(target_udp_addr) .mtu(mtu) .idle_timeout(Some(Duration::from_secs(10))) + .auth_header(Some("Bearer test".to_owned())) .build(); let client = client::Client::connect(client_config) |
