summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2025-02-13 13:39:39 +0100
committerDavid Lönnhager <david.l@mullvad.net>2025-04-04 22:15:32 +0200
commit60e908ffe92a935c7d9295c13995ec8f1faf0fac (patch)
treef4f220b26ee3e8f95e26d0793e2f6b07872d9424
parentd226825b952d5ae5f993bcd44042f6237c63c133 (diff)
downloadmullvadvpn-60e908ffe92a935c7d9295c13995ec8f1faf0fac.tar.xz
mullvadvpn-60e908ffe92a935c7d9295c13995ec8f1faf0fac.zip
Add initial QUIC obfuscation crate
-rw-r--r--Cargo.lock298
-rw-r--r--Cargo.toml1
-rw-r--r--mullvad-masque-proxy/Cargo.toml28
-rw-r--r--mullvad-masque-proxy/examples/client.rs80
-rw-r--r--mullvad-masque-proxy/examples/server.rs78
-rw-r--r--mullvad-masque-proxy/src/client/mod.rs436
-rw-r--r--mullvad-masque-proxy/src/fragment.rs153
-rw-r--r--mullvad-masque-proxy/src/lib.rs9
-rw-r--r--mullvad-masque-proxy/src/server/mod.rs311
9 files changed, 1392 insertions, 2 deletions
diff --git a/Cargo.lock b/Cargo.lock
index d380fa522c..6a75690afc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -231,6 +231,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]]
+name = "aws-lc-rs"
+version = "1.12.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c2b7ddaa2c56a367ad27a094ad8ef4faacf8a617c2575acb2ba88949df999ca"
+dependencies = [
+ "aws-lc-sys",
+ "paste",
+ "zeroize",
+]
+
+[[package]]
+name = "aws-lc-sys"
+version = "0.25.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54ac4f13dad353b209b34cbec082338202cbc01c8f00336b55c750c13ac91f8f"
+dependencies = [
+ "bindgen",
+ "cc",
+ "cmake",
+ "dunce",
+ "fs_extra",
+ "paste",
+]
+
+[[package]]
name = "axum"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -326,6 +351,29 @@ dependencies = [
]
[[package]]
+name = "bindgen"
+version = "0.69.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
+dependencies = [
+ "bitflags 2.6.0",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.12.1",
+ "lazy_static",
+ "lazycell",
+ "log",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash 1.1.0",
+ "shlex",
+ "syn 2.0.100",
+ "which",
+]
+
+[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -483,6 +531,8 @@ version = "1.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c"
dependencies = [
+ "jobserver",
+ "libc",
"shlex",
]
@@ -493,6 +543,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
+name = "cexpr"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
+dependencies = [
+ "nom",
+]
+
+[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -559,6 +618,17 @@ dependencies = [
]
[[package]]
+name = "clang-sys"
+version = "1.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
+dependencies = [
+ "glob",
+ "libc",
+ "libloading",
+]
+
+[[package]]
name = "clap"
version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -619,6 +689,15 @@ dependencies = [
]
[[package]]
+name = "cmake"
+version = "0.1.54"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
+dependencies = [
+ "cc",
+]
+
+[[package]]
name = "colorchoice"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1009,6 +1088,12 @@ dependencies = [
]
[[package]]
+name = "dunce"
+version = "1.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
+
+[[package]]
name = "ecdsa"
version = "0.16.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1273,6 +1358,12 @@ dependencies = [
]
[[package]]
+name = "fs_extra"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
+
+[[package]]
name = "fsevent-sys"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1477,6 +1568,43 @@ dependencies = [
]
[[package]]
+name = "h3"
+version = "0.0.6"
+source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac"
+dependencies = [
+ "bytes",
+ "fastrand",
+ "futures-util",
+ "http 1.1.0",
+ "pin-project-lite",
+ "tokio",
+]
+
+[[package]]
+name = "h3-datagram"
+version = "0.0.1"
+source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac"
+dependencies = [
+ "bytes",
+ "h3",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "h3-quinn"
+version = "0.0.7"
+source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac"
+dependencies = [
+ "bytes",
+ "futures",
+ "h3",
+ "h3-datagram",
+ "quinn",
+ "tokio",
+ "tokio-util 0.7.10",
+]
+
+[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2221,6 +2349,15 @@ dependencies = [
]
[[package]]
+name = "jobserver"
+version = "0.1.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
+dependencies = [
+ "libc",
+]
+
+[[package]]
name = "js-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2297,6 +2434,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
+name = "lazycell"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
+
+[[package]]
name = "libc"
version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2485,6 +2628,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
+name = "minimal-lexical"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
+
+[[package]]
name = "miniz_oxide"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2813,6 +2962,23 @@ dependencies = [
]
[[package]]
+name = "mullvad-masque-proxy"
+version = "0.1.0"
+dependencies = [
+ "bytes",
+ "clap",
+ "h3",
+ "h3-datagram",
+ "h3-quinn",
+ "http 1.1.0",
+ "quinn",
+ "rand 0.8.5",
+ "rustls 0.23.18",
+ "rustls-pemfile 2.1.3",
+ "tokio",
+]
+
+[[package]]
name = "mullvad-nsis"
version = "0.0.0"
dependencies = [
@@ -3155,6 +3321,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65"
[[package]]
+name = "nom"
+version = "7.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
+dependencies = [
+ "memchr",
+ "minimal-lexical",
+]
+
+[[package]]
name = "notify"
version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3183,12 +3359,31 @@ dependencies = [
]
[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
name = "num-traits"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3340,6 +3535,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
+name = "openssl-probe"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
+
+[[package]]
name = "openvpn-plugin"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3918,10 +4119,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
+ "futures-io",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
- "rustc-hash",
+ "rustc-hash 2.1.0",
"rustls 0.23.18",
"socket2",
"thiserror 2.0.9",
@@ -3939,9 +4141,10 @@ dependencies = [
"getrandom 0.2.14",
"rand 0.8.5",
"ring",
- "rustc-hash",
+ "rustc-hash 2.1.0",
"rustls 0.23.18",
"rustls-pki-types",
+ "rustls-platform-verifier",
"slab",
"thiserror 2.0.9",
"tinyvec",
@@ -4231,6 +4434,12 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
+[[package]]
+name = "rustc-hash"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497"
@@ -4275,6 +4484,7 @@ version = "0.23.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f"
dependencies = [
+ "aws-lc-rs",
"log",
"once_cell",
"ring",
@@ -4285,6 +4495,19 @@ dependencies = [
]
[[package]]
+name = "rustls-native-certs"
+version = "0.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
+dependencies = [
+ "openssl-probe",
+ "rustls-pemfile 2.1.3",
+ "rustls-pki-types",
+ "schannel",
+ "security-framework",
+]
+
+[[package]]
name = "rustls-pemfile"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4313,6 +4536,33 @@ dependencies = [
]
[[package]]
+name = "rustls-platform-verifier"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4c7dc240fec5517e6c4eab3310438636cfe6391dfc345ba013109909a90d136"
+dependencies = [
+ "core-foundation",
+ "core-foundation-sys",
+ "jni",
+ "log",
+ "once_cell",
+ "rustls 0.23.18",
+ "rustls-native-certs",
+ "rustls-platform-verifier-android",
+ "rustls-webpki 0.102.8",
+ "security-framework",
+ "security-framework-sys",
+ "webpki-root-certs",
+ "windows-sys 0.52.0",
+]
+
+[[package]]
+name = "rustls-platform-verifier-android"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
+
+[[package]]
name = "rustls-webpki"
version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4328,6 +4578,7 @@ version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
+ "aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
@@ -4373,6 +4624,15 @@ dependencies = [
]
[[package]]
+name = "schannel"
+version = "0.1.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
+dependencies = [
+ "windows-sys 0.59.0",
+]
+
+[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4402,6 +4662,30 @@ dependencies = [
]
[[package]]
+name = "security-framework"
+version = "2.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
+dependencies = [
+ "bitflags 2.6.0",
+ "core-foundation",
+ "core-foundation-sys",
+ "libc",
+ "num-bigint",
+ "security-framework-sys",
+]
+
+[[package]]
+name = "security-framework-sys"
+version = "2.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
+[[package]]
name = "semver"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -5481,6 +5765,7 @@ version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
+ "log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@@ -5791,6 +6076,15 @@ dependencies = [
]
[[package]]
+name = "webpki-root-certs"
+version = "0.26.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09aed61f5e8d2c18344b3faa33a4c837855fe56642757754775548fee21386c4"
+dependencies = [
+ "rustls-pki-types",
+]
+
+[[package]]
name = "webpki-roots"
version = "0.25.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 3768230f7c..eedb961b2c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ members = [
"mullvad-jni",
"mullvad-leak-checker",
"mullvad-management-interface",
+ "mullvad-masque-proxy",
"mullvad-nsis",
"mullvad-paths",
"mullvad-problem-report",
diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml
new file mode 100644
index 0000000000..ae3a1a4928
--- /dev/null
+++ b/mullvad-masque-proxy/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "mullvad-masque-proxy"
+version = "0.1.0"
+authors.workspace = true
+repository.workspace = true
+license.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+description = "A limited functionality UDP over HTTP3 proxy"
+
+[dependencies]
+quinn = "0.11"
+tokio = { workspace = true, features = [ "macros", "io-util" ] }
+h3 = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac" }
+h3-datagram = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac" }
+h3-quinn = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac", features = [ "datagram" ]}
+http = "1"
+rustls = { version = "0.23" }
+rustls-pemfile = "2.1.3"
+bytes = "1"
+
+[dev-dependencies]
+tokio = { workspace = true, features = [ "macros", "io-util", "rt-multi-thread" ] }
+clap = { workspace = true }
+rand = "0.8.5"
+
+[lints]
+workspace = true
diff --git a/mullvad-masque-proxy/examples/client.rs b/mullvad-masque-proxy/examples/client.rs
new file mode 100644
index 0000000000..6005e1f5f4
--- /dev/null
+++ b/mullvad-masque-proxy/examples/client.rs
@@ -0,0 +1,80 @@
+use clap::Parser;
+use mullvad_masque_proxy::client::Error;
+use tokio::net::UdpSocket;
+
+use std::{
+ net::{Ipv4Addr, SocketAddr},
+ path::PathBuf,
+};
+
+#[derive(Parser, Debug)]
+pub struct ClientArgs {
+ #[arg(long, short = 't')]
+ target_addr: SocketAddr,
+
+ /// Path to cert
+ #[arg(long, short = 'c', required = false)]
+ root_cert_path: Option<PathBuf>,
+
+ /// Server address
+ #[arg(long, short = 's')]
+ server_addr: SocketAddr,
+
+ #[arg(long, short = 'H')]
+ server_hostname: String,
+
+ #[arg(long, short = 'p', default_value = "0")]
+ bind_port: u16,
+
+ #[arg(long, short = 'S', default_value = "1000")]
+ maximum_packet_size: u16,
+}
+
+#[tokio::main]
+async fn main() {
+ let ClientArgs {
+ server_addr,
+ target_addr,
+ root_cert_path,
+ server_hostname,
+ bind_port,
+ maximum_packet_size,
+ } = ClientArgs::parse();
+
+ let tls_config = match root_cert_path {
+ Some(path) => mullvad_masque_proxy::client::client_tls_config_from_cert_path(path.as_ref())
+ .expect("Failed to get TLS config"),
+ None => mullvad_masque_proxy::client::default_tls_config(),
+ };
+
+ let _keylog = rustls::KeyLogFile::new();
+
+ let unbound_local_addr: SocketAddr = (Ipv4Addr::UNSPECIFIED, bind_port).into();
+ let local_socket = UdpSocket::bind(unbound_local_addr)
+ .await
+ .expect("Failed to bind address");
+ let local_addr = local_socket.local_addr().unwrap();
+ println!("Listening on {local_addr}");
+
+ let client = mullvad_masque_proxy::client::Client::connect_with_tls_config(
+ local_socket,
+ server_addr,
+ (Ipv4Addr::UNSPECIFIED, 0).into(),
+ target_addr,
+ &server_hostname,
+ tls_config,
+ maximum_packet_size,
+ )
+ .await;
+ if let Err(err) = &client {
+ println!("ERROR: {:?}", err);
+ if let Error::Connection(err) = err {
+ println!("ERROR: {}", err);
+ }
+ }
+ client
+ .expect("Failed to connect client")
+ .run()
+ .await
+ .unwrap();
+}
diff --git a/mullvad-masque-proxy/examples/server.rs b/mullvad-masque-proxy/examples/server.rs
new file mode 100644
index 0000000000..f9dfee557f
--- /dev/null
+++ b/mullvad-masque-proxy/examples/server.rs
@@ -0,0 +1,78 @@
+use clap::Parser;
+use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
+
+use std::{
+ fs,
+ net::{IpAddr, SocketAddr},
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+
+#[derive(Parser, Debug)]
+pub struct ServerArgs {
+ #[arg(long, short = 'b', default_value = "0.0.0.0:0")]
+ bind_addr: SocketAddr,
+
+ /// Path to cert
+ #[arg(long, short = 'c')]
+ cert_path: PathBuf,
+
+ /// Path to key
+ #[arg(long, short = 'k')]
+ key_path: PathBuf,
+
+ /// Allowed IPs
+ #[arg(long = "alloewd-ip", short = 'a', required = false)]
+ allowed_ips: Vec<IpAddr>,
+ /// Maximums packet size
+ #[arg(long = "maximum-packet-size", short = 'm', default_value = "1700")]
+ maximum_packet_size: u16,
+}
+
+#[tokio::main]
+async fn main() {
+ let args = ServerArgs::parse();
+ let _keylog = rustls::KeyLogFile::new();
+
+ let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap();
+
+ let server = mullvad_masque_proxy::server::Server::bind(
+ args.bind_addr,
+ args.allowed_ips.iter().cloned().collect(),
+ tls_config.into(),
+ args.maximum_packet_size,
+ )
+ .expect("Failed to initialize server");
+ println!("Listening on {}", args.bind_addr);
+ server.run().await.expect("Server failed.")
+}
+
+fn load_server_config(
+ key_path: &Path,
+ cert_path: &Path,
+) -> Result<rustls::ServerConfig, Box<dyn std::error::Error>> {
+ let key = fs::read(key_path)?;
+ let key = if key_path.extension().is_some_and(|x| x == "der") {
+ PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key))
+ } else {
+ rustls_pemfile::private_key(&mut &*key)?.expect("Expected PEM file to contain private key")
+ };
+ let cert_chain = fs::read(cert_path)?;
+ let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
+ vec![CertificateDer::from(cert_chain)]
+ } else {
+ rustls_pemfile::certs(&mut &*cert_chain).collect::<Result<_, _>>()?
+ };
+
+ let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
+ rustls::crypto::ring::default_provider(),
+ ))
+ .with_protocol_versions(&[&rustls::version::TLS13])?
+ .with_no_client_auth()
+ .with_single_cert(cert_chain, key)?;
+
+ tls_config.max_early_data_size = u32::MAX;
+ tls_config.alpn_protocols = vec![b"h3".into()];
+
+ Ok(tls_config)
+}
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs
new file mode 100644
index 0000000000..e07d0f76d3
--- /dev/null
+++ b/mullvad-masque-proxy/src/client/mod.rs
@@ -0,0 +1,436 @@
+use bytes::{Buf, BytesMut};
+use rustls::client::danger::ServerCertVerified;
+use std::{
+ fs, future, io,
+ net::{Ipv4Addr, SocketAddr},
+ path::Path,
+ sync::{Arc, LazyLock},
+ time::Duration,
+};
+use tokio::{net::UdpSocket, time::interval};
+
+use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId};
+use h3_datagram::datagram_traits::HandleDatagramsExt;
+use http::{header, uri::Scheme, Response, StatusCode};
+use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, TransportConfig};
+
+use crate::fragment::{self, Fragments};
+
+const MAX_HEADER_SIZE: u64 = 8192;
+
+const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem");
+
+pub struct Client {
+ client_socket: UdpSocket,
+ /// QUIC connection, used to send the actual HTTP datagrams
+ connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>,
+ /// Send stream over a QUIC connection - this needs to be kept alive to not close the HTTP
+ /// QUIC stream.
+ _send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
+ /// Request stream for the currently open request, must not be dropped, otherwise proxy
+ /// connection is terminated
+ request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
+ /// Packet fragments
+ fragments: Fragments,
+ /// Maximum packet size
+ maximum_packet_size: u16,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug)]
+pub enum Error {
+ Bind(io::Error),
+ Connect(quinn::ConnectError),
+ Connection(quinn::ConnectionError),
+ /// Connection closed while sending request to initiate proxying
+ ConnectionClosedPrematurely,
+ /// QUIC connection failed while sending request to initiate proxying
+ ConnectionFailed(h3::Error),
+ /// Request failed to illicit a response.
+ RequestError(h3::Error),
+ /// Received response was not a 200.
+ UnexpectedStatus(http::StatusCode),
+ /// Failed to receive data from client socket
+ ClientRead(io::Error),
+ /// Failed to send data to client socket
+ ClientWrite(io::Error),
+ /// Failed to receive data from server socket
+ ServerRead(h3::Error),
+ /// Failed to create a client
+ CreateClient(h3::Error),
+ /// Failed to receive good response from proxy
+ ProxyResponse(h3::Error),
+ /// Failed to construct a URI
+ Uri(http::Error),
+ /// Failed to send datagram to proxy
+ SendDatagram(h3::Error),
+ /// Failed to read certificates
+ ReadCerts(io::Error),
+ /// Failed to parse certificates
+ ParseCerts,
+ /// Failed to fragment a packet - it is too large
+ PacketTooLarge(fragment::PacketTooLarge),
+}
+
+impl Client {
+ pub async fn connect(
+ client_socket: UdpSocket,
+ server_addr: SocketAddr,
+ local_addr: SocketAddr,
+ target_addr: SocketAddr,
+ server_host: &str,
+ maximum_packet_size: u16,
+ ) -> Result<Self> {
+ Self::connect_with_tls_config(
+ client_socket,
+ server_addr,
+ local_addr,
+ target_addr,
+ server_host,
+ default_tls_config(),
+ maximum_packet_size,
+ )
+ .await
+ }
+
+ pub async fn connect_with_tls_config(
+ client_socket: UdpSocket,
+ server_addr: SocketAddr,
+ local_addr: SocketAddr,
+ target_addr: SocketAddr,
+ server_host: &str,
+ tls_config: Arc<rustls::ClientConfig>,
+ maximum_packet_size: u16,
+ ) -> Result<Self> {
+ let quic_client_config = QuicClientConfig::try_from(tls_config)
+ .expect("Failed to construct a valid TLS configuration");
+
+ let mut client_config = ClientConfig::new(Arc::new(quic_client_config));
+ let transport_config = TransportConfig::default();
+ // TODO: Set datagram_receive_buffer_size if needed
+ // TODO: Set datagram_send_buffer_size if needed
+ // When would it be needed? If we need to buffer more packets or buffer less packets for
+ // better performance.
+ client_config.transport_config(Arc::new(transport_config));
+ Self::connect_with_local_addr(
+ client_socket,
+ server_addr,
+ local_addr,
+ target_addr,
+ server_host,
+ client_config,
+ maximum_packet_size,
+ )
+ .await
+ }
+
+ async fn connect_with_local_addr(
+ client_socket: UdpSocket,
+ server_addr: SocketAddr,
+ local_addr: SocketAddr,
+ target_addr: SocketAddr,
+ server_host: &str,
+ client_config: ClientConfig,
+ maximum_packet_size: u16,
+ ) -> Result<Self> {
+ let endpoint = Endpoint::client(local_addr).map_err(Error::Bind)?;
+
+ let connecting = endpoint
+ .connect_with(client_config, server_addr, server_host)
+ .map_err(Error::Connect)?;
+
+ let connection = connecting.await.map_err(Error::Connection)?;
+
+ let (connection, send_stream, request_stream) =
+ Self::setup_h3_connection(connection, target_addr, server_host, maximum_packet_size)
+ .await?;
+
+ Ok(Self {
+ connection,
+ client_socket,
+ request_stream,
+ fragments: Fragments::default(),
+ _send_stream: send_stream,
+ maximum_packet_size,
+ })
+ }
+
+ // Returns an h3 connection that is ready to be used for sending UDP datagrams.
+ async fn setup_h3_connection(
+ connection: quinn::Connection,
+ target: SocketAddr,
+ server_host: &str,
+ maximum_packet_size: u16,
+ ) -> Result<(
+ client::Connection<h3_quinn::Connection, bytes::Bytes>,
+ client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
+ client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
+ )> {
+ let (mut connection, mut send_stream) = client::builder()
+ .max_field_section_size(MAX_HEADER_SIZE)
+ .enable_datagram(true)
+ .send_grease(true)
+ .build(h3_quinn::Connection::new(connection))
+ .await
+ .map_err(Error::CreateClient)?;
+
+ let request = new_connect_request(target, &server_host, maximum_packet_size)?;
+
+ let request_future = async move {
+ let mut request_stream = send_stream.send_request(request).await?;
+ let response = request_stream.recv_response().await?;
+ Ok((response, send_stream, request_stream))
+ };
+
+ tokio::select! {
+ closed = future::poll_fn(|cx| connection.poll_close(cx)) => {
+ match closed {
+ Ok(()) => Err(Error::ConnectionClosedPrematurely),
+ Err(err) => Err(Error::ConnectionFailed(err)),
+ }
+ },
+ response = request_future => {
+ let (response, send_stream, request_stream) = response.map_err(Error::RequestError)?;
+ handle_response(response)?;
+ Ok((connection, send_stream, request_stream))
+ },
+ }
+ }
+
+ pub async fn run(mut self) -> Result<()> {
+ let stream_id: StreamId = self.request_stream.id();
+ // this is the variable ID used to signify UDP payloads in HTTP datagrams.
+ let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024);
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf);
+
+ let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
+ let mut fragment_id = 1u16;
+ let mut interval = interval(Duration::from_secs(3));
+
+ loop {
+ tokio::select! {
+ client_read = self.client_socket.recv_buf_from(&mut client_read_buf) => {
+ let (_bytes_received, recv_addr) = client_read.map_err(Error::ClientRead)?;
+ return_addr = recv_addr;
+
+ let mut send_buf = client_read_buf.split().freeze();
+ if send_buf.len() < (Into::<usize>::into(self.maximum_packet_size) - 100usize) {
+ self.connection
+ .send_datagram(stream_id, send_buf)
+ .map_err(Error::SendDatagram)?;
+ } else {
+ // drop the added context ID, since packet will have to be fragmented.
+ {
+ let _ = VarInt::decode(&mut send_buf);
+ }
+ for fragment in fragment::fragment_packet(
+ self.maximum_packet_size,
+ &mut send_buf,
+ fragment_id)
+ .map_err(Error::PacketTooLarge)
+ ? {
+ self.connection.send_datagram(stream_id, fragment).map_err(Error::SendDatagram)?;
+ }
+ fragment_id = fragment_id.wrapping_add(1);
+ }
+
+ client_read_buf.reserve(crate::PACKET_BUFFER_SIZE);
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf);
+ },
+ server_response = self.connection.read_datagram() => {
+ match server_response {
+ Ok(Some(response)) => {
+ if response.stream_id() != stream_id {
+ // log::trace!("Received datagram with an unexpected stream ID");
+ continue;
+ }
+ let mut payload = response.into_payload();
+ let context = VarInt::decode(&mut payload);
+ match context {
+ Ok(crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID) => {
+ self.client_socket
+ .send_to(payload.as_ref(), return_addr)
+ .await
+ .map_err(Error::ClientWrite)?;
+ }
+ Ok(crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID) => {
+ if let Ok(Some(payload)) = self.fragments.handle_incoming_packet(payload) {
+ self.client_socket
+ .send_to(payload.chunk(), return_addr)
+ .await
+ .map_err(Error::ClientWrite)?;
+ }
+ },
+ _ => (),
+
+ }
+ }
+ Ok(None) => {
+ return Ok(());
+ }
+ Err(err) => {
+ return Err(Error::ProxyResponse(err));
+ }
+ }
+ },
+ _ = interval.tick() => {
+ self.fragments.clear_old_fragments(
+ Duration::from_secs(3)
+ );
+ },
+ };
+ }
+ }
+}
+
+fn new_connect_request(
+ socket_addr: SocketAddr,
+ authority: &dyn AsRef<str>,
+ maximum_packet_size: u16,
+) -> Result<http::Request<()>> {
+ let host = socket_addr.ip();
+ let port = socket_addr.port();
+ let path = format!("/.well-known/masque/udp/{host}/{port}/");
+ let uri = http::uri::Builder::new()
+ .scheme(Scheme::HTTPS)
+ .authority(authority.as_ref())
+ .path_and_query(&path)
+ .build()
+ .map_err(Error::Uri)?;
+
+ let mut request = http::Request::builder()
+ .method(http::method::Method::CONNECT)
+ .uri(uri)
+ .header(b"Capsule-Protocol".as_slice(), b"?1".as_slice())
+ .header(header::AUTHORIZATION, b"Bearer test".as_slice())
+ .header(header::HOST, authority.as_ref())
+ .header(
+ b"X-Mullvad-Uplink-Mtu".as_slice(),
+ format!("{maximum_packet_size}"),
+ )
+ .body(())
+ .expect("failed to construct a body");
+
+ request.extensions_mut().insert(Protocol::CONNECT_UDP);
+ Ok(request)
+}
+
+fn handle_response(response: Response<()>) -> Result<()> {
+ if response.status() != StatusCode::OK {
+ return Err(Error::UnexpectedStatus(response.status()));
+ }
+ Ok(())
+}
+
+// TODO: resuse the same TLS code from `mullvad-api` maybe
+pub fn default_tls_config() -> Arc<rustls::ClientConfig> {
+ static TLS_CONFIG: LazyLock<Arc<rustls::ClientConfig>> =
+ LazyLock::new(|| client_tls_config_with_certs(read_cert_store()));
+
+ TLS_CONFIG.clone()
+}
+
+fn client_tls_config_with_certs(certs: rustls::RootCertStore) -> Arc<rustls::ClientConfig> {
+ let mut config = rustls::ClientConfig::builder_with_provider(Arc::new(
+ rustls::crypto::ring::default_provider(),
+ ))
+ .with_protocol_versions(&[&rustls::version::TLS13])
+ .expect("ring crypt-prover should support TLS 1.3")
+ .with_root_certificates(certs)
+ .with_no_client_auth();
+ config.alpn_protocols = vec![b"h3".to_vec()];
+
+ let approver = Approver {};
+ config.key_log = Arc::new(rustls::KeyLogFile::new());
+ config
+ .dangerous()
+ .set_certificate_verifier(Arc::new(approver));
+ Arc::new(config)
+}
+
+fn read_cert_store() -> rustls::RootCertStore {
+ read_cert_store_from_reader(&mut std::io::BufReader::new(LE_ROOT_CERT))
+ .expect("failed to read built-in cert store")
+}
+
+pub fn client_tls_config_from_cert_path(path: &Path) -> Result<Arc<rustls::ClientConfig>> {
+ let certs = read_cert_store_from_path(path)?;
+ Ok(client_tls_config_with_certs(certs))
+}
+
+fn read_cert_store_from_path(path: &Path) -> Result<rustls::RootCertStore> {
+ let cert_path = fs::File::open(path).map_err(Error::ReadCerts)?;
+ read_cert_store_from_reader(&mut std::io::BufReader::new(cert_path))
+}
+
+fn read_cert_store_from_reader(reader: &mut dyn io::BufRead) -> Result<rustls::RootCertStore> {
+ let mut cert_store = rustls::RootCertStore::empty();
+
+ let certs = rustls_pemfile::certs(reader)
+ .collect::<std::result::Result<Vec<_>, _>>()
+ .map_err(Error::ReadCerts)?;
+ let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(certs);
+ if num_failures > 0 || num_certs_added == 0 {
+ return Err(Error::ParseCerts);
+ }
+
+ Ok(cert_store)
+}
+
+#[test]
+fn test_zero_stream_id() {
+ h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?");
+}
+
+#[derive(Debug)]
+struct Approver {}
+
+impl rustls::client::danger::ServerCertVerifier for Approver {
+ fn verify_server_cert(
+ &self,
+ _end_entity: &rustls::pki_types::CertificateDer<'_>,
+ _intermediates: &[rustls::pki_types::CertificateDer<'_>],
+ _server_name: &rustls::pki_types::ServerName<'_>,
+ _ocsp_response: &[u8],
+ _now: rustls::pki_types::UnixTime,
+ ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
+ Ok(ServerCertVerified::assertion())
+ }
+
+ fn verify_tls12_signature(
+ &self,
+ _message: &[u8],
+ _cert: &rustls::pki_types::CertificateDer<'_>,
+ _dss: &rustls::DigitallySignedStruct,
+ ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
+ Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
+ }
+
+ fn verify_tls13_signature(
+ &self,
+ _message: &[u8],
+ _cert: &rustls::pki_types::CertificateDer<'_>,
+ _dss: &rustls::DigitallySignedStruct,
+ ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
+ Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
+ }
+
+ fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
+ vec![
+ rustls::SignatureScheme::RSA_PKCS1_SHA1,
+ rustls::SignatureScheme::ECDSA_SHA1_Legacy,
+ rustls::SignatureScheme::RSA_PKCS1_SHA256,
+ rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
+ rustls::SignatureScheme::RSA_PKCS1_SHA384,
+ rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
+ rustls::SignatureScheme::RSA_PKCS1_SHA512,
+ rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
+ rustls::SignatureScheme::RSA_PSS_SHA256,
+ rustls::SignatureScheme::RSA_PSS_SHA384,
+ rustls::SignatureScheme::RSA_PSS_SHA512,
+ rustls::SignatureScheme::ED25519,
+ rustls::SignatureScheme::ED448,
+ ]
+ }
+}
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs
new file mode 100644
index 0000000000..6224ab7f63
--- /dev/null
+++ b/mullvad-masque-proxy/src/fragment.rs
@@ -0,0 +1,153 @@
+use std::{
+ collections::BTreeMap,
+ time::{Duration, Instant},
+};
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use h3::proto::varint::VarInt;
+
+#[derive(Default)]
+pub struct Fragments {
+ fragment_map: BTreeMap<u16, Vec<Fragment>>,
+}
+
+// When a packet that arrives is too small to be decoded.
+#[derive(Debug)]
+pub enum DefragError {
+ BadContextId(Result<VarInt, h3::proto::coding::UnexpectedEnd>),
+ PayloadTooSmall,
+}
+
+// When a packet is larger than u16::MAX, it can't be fragmented.
+#[derive(Debug)]
+pub struct PacketTooLarge(pub usize);
+
+impl Fragments {
+ // TODO: Let caller provide output buffer.
+ pub fn handle_incoming_packet(
+ &mut self,
+ mut payload: Bytes,
+ ) -> Result<Option<Bytes>, DefragError> {
+ match VarInt::decode(&mut payload) {
+ Ok(crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID) => {
+ return Ok(Some(payload));
+ }
+ Ok(crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID) => {}
+ unexpected_context_id => {
+ return Err(DefragError::BadContextId(unexpected_context_id));
+ }
+ }
+
+ let id = payload
+ .try_get_u16()
+ .map_err(|_| DefragError::PayloadTooSmall)?;
+ let index = payload
+ .try_get_u8()
+ .map_err(|_| DefragError::PayloadTooSmall)?;
+ let fragment_count = payload
+ .try_get_u8()
+ .map_err(|_| DefragError::PayloadTooSmall)?;
+ let fragment = Fragment {
+ index,
+ payload,
+ time_received: Instant::now(),
+ };
+
+ let fragments = self.fragment_map.entry(id).or_default();
+ fragments.push(fragment);
+
+ Ok(self.try_fetch(id, fragment_count))
+ }
+
+ // TODO: Let caller provide output buffer.
+ fn try_fetch(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> {
+ // establish that there are enough fragments to reconstruct the whole packet
+ let payload = {
+ let fragments = self.fragment_map.get_mut(&id)?;
+
+ if fragments.len() != fragment_count.into() {
+ return None;
+ }
+
+ fragments.sort_by_key(|f| f.index);
+ let mut payload =
+ BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum());
+ for fragment in fragments {
+ payload.extend_from_slice(&fragment.payload);
+ }
+ payload
+ };
+
+ self.fragment_map.remove(&id);
+ Some(payload.into())
+ }
+
+ pub fn clear_old_fragments(&mut self, max_age: Duration) {
+ self.fragment_map.retain(|_, fragments| {
+ fragments
+ .iter()
+ .any(|fragment| fragment.time_received.elapsed() <= max_age)
+ });
+ }
+}
+
+struct Fragment {
+ index: u8,
+ payload: Bytes,
+ time_received: Instant,
+}
+
+pub fn fragment_packet(
+ maximum_packet_size: u16,
+ payload: &'_ mut Bytes,
+ packet_id: u16,
+) -> Result<impl Iterator<Item = Bytes> + '_, PacketTooLarge> {
+ let num_fragments: usize = payload.chunks(maximum_packet_size.into()).count();
+ let Ok(fragment_count): std::result::Result<u8, _> = num_fragments.try_into() else {
+ return Err(PacketTooLarge(payload.len()));
+ };
+
+ let iterator = payload.chunks(maximum_packet_size.into()).enumerate().map(
+ move |(fragment_index, fragment_payload)| {
+ let mut fragment = BytesMut::with_capacity((maximum_packet_size + 1).into());
+ crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment);
+ fragment.put_u16(packet_id);
+ fragment.put_u8(
+ // fragment indexes start at 1
+ u8::try_from(fragment_index +1)
+ .expect("fragment index must fit in an u8, since num_fragments fits is an u8"),
+ );
+ fragment.put_u8(fragment_count);
+ fragment.extend_from_slice(fragment_payload);
+ fragment.freeze()
+ },
+ );
+ Ok(iterator)
+}
+
+#[test]
+fn test_fragment_reconstruction() {
+ use rand::{seq::SliceRandom, thread_rng};
+
+ let payload = (0..255).collect::<Vec<u8>>();
+ let max_payload_size = 50;
+ let packet_id = 76;
+
+ let mut fragments = Fragments::default();
+
+ let mut payload_clone = Bytes::from(payload.clone());
+ let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
+ .unwrap()
+ .collect::<Vec<_>>();
+
+ fragment_buf.shuffle(&mut thread_rng());
+
+ for fragment in fragment_buf {
+ if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() {
+ assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
+ return;
+ }
+ }
+
+ panic!("Failed to reconstruct packet");
+}
diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs
new file mode 100644
index 0000000000..bb973d3a80
--- /dev/null
+++ b/mullvad-masque-proxy/src/lib.rs
@@ -0,0 +1,9 @@
+use h3::proto::varint::VarInt;
+
+pub mod client;
+mod fragment;
+pub mod server;
+
+const PACKET_BUFFER_SIZE: usize = 1700;
+pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0);
+pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1);
diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs
new file mode 100644
index 0000000000..b6d263066a
--- /dev/null
+++ b/mullvad-masque-proxy/src/server/mod.rs
@@ -0,0 +1,311 @@
+use std::{
+ collections::HashSet,
+ io,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
+ sync::Arc,
+ time::Duration,
+};
+
+use bytes::{Bytes, BytesMut};
+use h3::{
+ proto::varint::VarInt,
+ quic::{BidiStream, StreamId},
+ server::{self, Connection, RequestStream},
+};
+use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt};
+use http::{Request, StatusCode};
+use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming};
+use tokio::{net::UdpSocket, time::interval};
+
+use crate::fragment::{self, Fragments};
+
+#[derive(Debug)]
+pub enum Error {
+ BadTlsConfig(quinn::crypto::rustls::NoInitialCipherSuite),
+ BindSocket(io::Error),
+ SendNegotiationResponse(h3::Error),
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/";
+
+pub struct Server {
+ endpoint: Endpoint,
+ allowed_hosts: AllowedIps,
+ max_packet_size: u16,
+}
+
+#[derive(Clone)]
+struct AllowedIps {
+ hosts: Arc<HashSet<IpAddr>>,
+}
+
+impl AllowedIps {
+ fn ip_allowed(&self, ip: IpAddr) -> bool {
+ self.hosts.is_empty() || self.hosts.contains(&ip)
+ }
+}
+
+impl Server {
+ pub fn bind(
+ bind_addr: SocketAddr,
+ allowed_hosts: HashSet<IpAddr>,
+ tls_config: Arc<rustls::ServerConfig>,
+ max_packet_size: u16,
+ ) -> Result<Self> {
+ let server_config = quinn::ServerConfig::with_crypto(Arc::new(
+ QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?,
+ ));
+
+ let endpoint = Endpoint::server(server_config, bind_addr).map_err(Error::BindSocket)?;
+
+ Ok(Self {
+ endpoint,
+ allowed_hosts: AllowedIps {
+ hosts: Arc::new(allowed_hosts),
+ },
+ max_packet_size,
+ })
+ }
+
+ pub async fn run(self) -> Result<()> {
+ while let Some(new_connection) = self.endpoint.accept().await {
+ tokio::spawn(Self::handle_incoming_connection(
+ new_connection,
+ self.allowed_hosts.clone(),
+ self.max_packet_size,
+ ));
+ }
+ Ok(())
+ }
+
+ async fn handle_incoming_connection(
+ connection: Incoming,
+ allowed_hosts: AllowedIps,
+ maximum_packet_size: u16,
+ ) {
+ match connection.await {
+ Ok(conn) => {
+ println!("new connection established");
+
+ let Ok(mut connection) = server::builder()
+ .enable_datagram(true)
+ .build(h3_quinn::Connection::new(conn))
+ .await
+ else {
+ println!("Failed to construct a new H3 server connection");
+ return;
+ };
+
+ match connection.accept().await {
+ Ok(Some((req, stream))) => {
+ tokio::spawn(Self::handle_proxy_request(
+ connection,
+ req,
+ stream,
+ allowed_hosts.clone(),
+ maximum_packet_size,
+ ));
+ }
+
+ // indicating no more streams to be received
+ Ok(None) => {}
+
+ Err(err) => {
+ println!("error on accept {}", err);
+ return;
+ }
+ }
+ }
+ Err(err) => {
+ println!("accepting connection failed: {:?}", err);
+ }
+ }
+ }
+
+ async fn handle_proxy_request<T: BidiStream<Bytes>>(
+ mut connection: Connection<h3_quinn::Connection, Bytes>,
+ request: Request<()>,
+ mut stream: RequestStream<T, Bytes>,
+ allowed_hosts: AllowedIps,
+ maximum_packet_size: u16,
+ ) {
+ let Some(target_addr) = get_target_socketaddr(request.uri().path()) else {
+ return;
+ };
+ if !allowed_hosts.ip_allowed(target_addr.ip()) {
+ return handle_disallowed_ip(stream).await;
+ }
+
+ let bind_addr = SocketAddr::new(unspecified_addr(target_addr.ip()), 0);
+ let Ok(udp_socket) = UdpSocket::bind(bind_addr).await else {
+ return handle_failed_socket(stream).await;
+ };
+ if let Err(err) = udp_socket.connect(target_addr).await {
+ println!("Failed to set destination for UDP socket: {err}");
+ return handle_failed_socket(stream).await;
+ };
+
+ if handle_established_connection(&mut stream).await.is_err() {
+ return;
+ }
+
+ let stream_id = stream.id();
+ let mut proxy_recv_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE);
+
+ let mut fragments = Fragments::default();
+ let mut fragment_id = 0u16;
+
+ let mut interval = interval(Duration::from_secs(3));
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf);
+
+ loop {
+ tokio::select! {
+ client_send = connection.read_datagram() => {
+ match client_send {
+ Ok(Some(received_packet)) => {
+ handle_client_packet(received_packet, stream_id, &mut fragments, &udp_socket, target_addr).await;
+ },
+ Ok(None) => {
+ return;
+ }
+ Err(_err) => {
+ // client connection QUIC connection failed, should return now.
+ return;
+ },
+ }
+ },
+ recv_result = udp_socket.recv_buf_from(&mut proxy_recv_buf) => {
+ match recv_result {
+ Ok((_bytes_received, sender_addr)) => {
+ if sender_addr != target_addr {
+ continue
+ }
+
+ let mut received_packet = proxy_recv_buf.split().freeze();
+
+ if proxy_recv_buf.len() < maximum_packet_size.into() {
+ if connection.send_datagram(stream_id, received_packet).is_err() {
+ return;
+ }
+ } else {
+ let _ = VarInt::decode(&mut received_packet);
+ let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; };
+ fragment_id += 1;
+ for payload in fragments {
+ if connection.send_datagram(stream_id, payload).is_err() {
+ return;
+ }
+ }
+ };
+
+ proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE);
+ crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf);
+ },
+ Err(err) => {
+ println!("Failed to receive packet from proxy connection: {err}");
+ let _ = stream.finish().await;
+ return;
+ }
+ }
+ },
+ _ = interval.tick() => {
+ fragments.clear_old_fragments(
+ Duration::from_secs(3)
+ );
+ },
+ };
+ }
+ }
+}
+
+async fn handle_client_packet(
+ received_packet: Datagram,
+ stream_id: StreamId,
+ fragments: &mut Fragments,
+ proxy_socket: &UdpSocket,
+ target_addr: SocketAddr,
+) {
+ if received_packet.stream_id() != stream_id {
+ // log::trace!("Received unexpected stream ID from server");
+ return;
+ }
+
+ if let Ok(Some(payload)) = fragments.handle_incoming_packet(received_packet.into_payload()) {
+ let _ = proxy_socket.send_to(&payload, target_addr).await;
+ }
+}
+
+async fn handle_established_connection<T: BidiStream<Bytes>>(
+ stream: &mut RequestStream<T, Bytes>,
+) -> Result<()> {
+ let response = http::Response::builder()
+ .status(StatusCode::OK)
+ .body(())
+ .unwrap();
+ stream
+ .send_response(response)
+ .await
+ .map_err(Error::SendNegotiationResponse)?;
+ Ok(())
+}
+
+async fn handle_disallowed_ip<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) {
+ let response = http::Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body(())
+ .unwrap();
+ let _ = stream.send_response(response).await;
+}
+
+async fn handle_failed_socket<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) {
+ let response = http::Response::builder()
+ .status(StatusCode::BAD_GATEWAY)
+ .body(())
+ .unwrap();
+ let _ = stream.send_response(response).await;
+}
+
+fn get_target_socketaddr(request_path: &str) -> Option<SocketAddr> {
+ // Establish if the URL path looks like `/.well-known/masque/udp/{ip}/{port}`
+ if !request_path.starts_with(MASQUE_WELL_KNOWN_PATH) {
+ return None;
+ };
+ let (addr_str, port_str) = request_path
+ .strip_prefix(MASQUE_WELL_KNOWN_PATH)?
+ .trim_start_matches('/')
+ .split_once('/')?;
+ let port_str = port_str.trim_end_matches('/');
+
+ Some(SocketAddr::new(
+ addr_str.trim_start_matches('/').parse().ok()?,
+ port_str.parse().ok()?,
+ ))
+}
+
+fn unspecified_addr(addr: IpAddr) -> IpAddr {
+ match addr {
+ IpAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(),
+ IpAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(),
+ }
+}
+
+#[test]
+fn test_get_good_slashy_ocketaddr() {
+ let addr: IpAddr = "192.168.1.1".parse().unwrap();
+ let port: u16 = 7979;
+ let expected_addr = SocketAddr::new(addr, port);
+ let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////");
+
+ assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr)
+}
+
+#[test]
+fn test_get_bad_socketaddr() {
+ let addr: IpAddr = "192.168.1.1".parse().unwrap();
+ let port: u16 = 7979;
+ let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}");
+
+ assert_eq!(get_target_socketaddr(&good_path), None)
+}