summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2025-04-30 18:38:24 +0200
committerDavid Lönnhager <david.l@mullvad.net>2025-05-05 13:02:29 +0200
commitee380d68e2e12ef52b49cdb7439e3d2bb1637bcf (patch)
tree8fc8db3c8fe81857f5e1883d7b2d941b330bd56d
parentc026f50ce79454596e247b4027a885df5df0d212 (diff)
downloadmullvadvpn-ee380d68e2e12ef52b49cdb7439e3d2bb1637bcf.tar.xz
mullvadvpn-ee380d68e2e12ef52b49cdb7439e3d2bb1637bcf.zip
Add masque server params builder and authorization header option
-rw-r--r--mullvad-masque-proxy/examples/masque-server.rs23
-rw-r--r--mullvad-masque-proxy/src/server/mod.rs61
-rw-r--r--mullvad-masque-proxy/tests/proxy.rs20
3 files changed, 69 insertions, 35 deletions
diff --git a/mullvad-masque-proxy/examples/masque-server.rs b/mullvad-masque-proxy/examples/masque-server.rs
index e884e91c93..b321fcf659 100644
--- a/mullvad-masque-proxy/examples/masque-server.rs
+++ b/mullvad-masque-proxy/examples/masque-server.rs
@@ -1,4 +1,5 @@
use clap::Parser;
+use mullvad_masque_proxy::server::{AllowedIps, ServerParams};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::{
@@ -36,6 +37,10 @@ pub struct ServerArgs {
/// Maximum packet size
#[arg(long, short = 'm', default_value = "1700")]
mtu: u16,
+
+ /// Authorization header value to set
+ #[arg(long, default_value = "Bearer test")]
+ auth: Option<String>,
}
#[tokio::main]
@@ -50,14 +55,16 @@ async fn main() {
let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap();
- let server = mullvad_masque_proxy::server::Server::bind(
- args.bind_addr,
- args.allowed_ips.iter().cloned().collect(),
- args.hostname,
- tls_config.into(),
- args.mtu,
- )
- .expect("Failed to initialize server");
+ let params = ServerParams::builder()
+ .allowed_hosts(AllowedIps::from(args.allowed_ips))
+ .hostname(args.hostname)
+ .mtu(args.mtu)
+ .auth_header(args.auth)
+ .build();
+
+ let server =
+ mullvad_masque_proxy::server::Server::bind(args.bind_addr, tls_config.into(), params)
+ .expect("Failed to initialize server");
log::info!("Listening on {}", args.bind_addr);
server.run().await.expect("Server failed.")
}
diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs
index 4f160eaea6..af1cb3e1d4 100644
--- a/mullvad-masque-proxy/src/server/mod.rs
+++ b/mullvad-masque-proxy/src/server/mod.rs
@@ -14,9 +14,10 @@ use h3::{
server::{self, Connection, RequestStream},
};
use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt};
-use http::{StatusCode, Uri};
+use http::{header, StatusCode, Uri};
use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming};
use tokio::{net::UdpSocket, select, sync::mpsc, task};
+use typed_builder::TypedBuilder;
use crate::{
compute_udp_payload_size,
@@ -43,22 +44,37 @@ pub struct Server {
params: Arc<ServerParams>,
}
-struct ServerParams {
+#[derive(TypedBuilder)]
+pub struct ServerParams {
/// Allowed target IPs for the proxy connection
- allowed_hosts: AllowedIps,
+ pub allowed_hosts: AllowedIps,
- /// Server hostname (optional)
- hostname: Option<String>,
+ /// Server hostname expected from clients
+ #[builder(default)]
+ pub hostname: Option<String>,
/// Maximum transfer unit
- mtu: u16,
+ #[builder(default = 1500)]
+ pub mtu: u16,
+
+ /// Authorization header expected from clients
+ #[builder(default)]
+ pub auth_header: Option<String>,
}
-#[derive(Clone)]
-struct AllowedIps {
+#[derive(Default, Clone)]
+pub struct AllowedIps {
hosts: Arc<HashSet<IpAddr>>,
}
+impl<T: IntoIterator<Item = IpAddr>> From<T> for AllowedIps {
+ fn from(value: T) -> Self {
+ AllowedIps {
+ hosts: Arc::new(value.into_iter().collect()),
+ }
+ }
+}
+
impl AllowedIps {
fn ip_allowed(&self, ip: IpAddr) -> bool {
self.hosts.is_empty() || self.hosts.contains(&ip)
@@ -68,12 +84,10 @@ impl AllowedIps {
impl Server {
pub fn bind(
bind_addr: SocketAddr,
- allowed_hosts: HashSet<IpAddr>,
- hostname: Option<String>,
tls_config: Arc<rustls::ServerConfig>,
- mtu: u16,
+ params: ServerParams,
) -> Result<Self> {
- Self::validate_mtu(mtu, bind_addr)?;
+ Self::validate_mtu(params.mtu, bind_addr)?;
let server_config = quinn::ServerConfig::with_crypto(Arc::new(
QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?,
@@ -83,13 +97,7 @@ impl Server {
Ok(Self {
endpoint,
- params: Arc::new(ServerParams {
- allowed_hosts: AllowedIps {
- hosts: Arc::new(allowed_hosts),
- },
- hostname,
- mtu,
- }),
+ params: Arc::new(params),
})
}
@@ -171,6 +179,13 @@ impl Server {
}
};
+ if let Some(required_auth) = &server_params.auth_header {
+ match http_request.headers().get(header::AUTHORIZATION) {
+ Some(actual_auth) if actual_auth == required_auth => (),
+ _ => return handle_invalid_auth(stream).await,
+ }
+ }
+
if let Some(hostname) = &server_params.hostname {
if &proxy_uri.hostname != hostname {
let valid_uri = ProxyUri {
@@ -389,6 +404,14 @@ async fn handle_established_connection<T: BidiStream<Bytes>>(
Ok(())
}
+async fn handle_invalid_auth<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) {
+ let response = http::Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body(())
+ .unwrap();
+ let _ = stream.send_response(response).await;
+}
+
async fn handle_disallowed_ip<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) {
let response = http::Response::builder()
.status(StatusCode::BAD_REQUEST)
diff --git a/mullvad-masque-proxy/tests/proxy.rs b/mullvad-masque-proxy/tests/proxy.rs
index 29a90ca6fb..25a95bbbcf 100644
--- a/mullvad-masque-proxy/tests/proxy.rs
+++ b/mullvad-masque-proxy/tests/proxy.rs
@@ -6,6 +6,8 @@ use std::time::Duration;
use anyhow::anyhow;
use anyhow::Context;
use bytes::BytesMut;
+use mullvad_masque_proxy::server::AllowedIps;
+use mullvad_masque_proxy::server::ServerParams;
use mullvad_masque_proxy::MIN_IPV4_MTU;
use rand::RngCore;
use tokio::fs;
@@ -145,14 +147,15 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> {
// Set up MASQUE server
let server_tls_config = load_server_test_cert().await?;
- let server = server::Server::bind(
- any_localhost_addr,
- Default::default(),
- None,
- Arc::new(server_tls_config),
- mtu,
- )
- .context("Failed to start MASQUE server")?;
+
+ let params = ServerParams::builder()
+ .allowed_hosts(AllowedIps::default())
+ .mtu(mtu)
+ .auth_header(Some("Bearer test".to_owned()))
+ .build();
+
+ let server = server::Server::bind(any_localhost_addr, Arc::new(server_tls_config), params)
+ .context("Failed to start MASQUE server")?;
let masque_server_addr = server.local_addr()?;
@@ -176,6 +179,7 @@ async fn setup_masque(mtu: u16) -> anyhow::Result<(UdpSocket, UdpSocket)> {
.target_addr(target_udp_addr)
.mtu(mtu)
.idle_timeout(Some(Duration::from_secs(10)))
+ .auth_header(Some("Bearer test".to_owned()))
.build();
let client = client::Client::connect(client_config)