diff options
| author | Emīls <emils@mullvad.net> | 2025-02-13 13:39:39 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-04-04 22:15:32 +0200 |
| commit | 60e908ffe92a935c7d9295c13995ec8f1faf0fac (patch) | |
| tree | f4f220b26ee3e8f95e26d0793e2f6b07872d9424 | |
| parent | d226825b952d5ae5f993bcd44042f6237c63c133 (diff) | |
| download | mullvadvpn-60e908ffe92a935c7d9295c13995ec8f1faf0fac.tar.xz mullvadvpn-60e908ffe92a935c7d9295c13995ec8f1faf0fac.zip | |
Add initial QUIC obfuscation crate
| -rw-r--r-- | Cargo.lock | 298 | ||||
| -rw-r--r-- | Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-masque-proxy/Cargo.toml | 28 | ||||
| -rw-r--r-- | mullvad-masque-proxy/examples/client.rs | 80 | ||||
| -rw-r--r-- | mullvad-masque-proxy/examples/server.rs | 78 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/client/mod.rs | 436 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/fragment.rs | 153 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/lib.rs | 9 | ||||
| -rw-r--r-- | mullvad-masque-proxy/src/server/mod.rs | 311 |
9 files changed, 1392 insertions, 2 deletions
diff --git a/Cargo.lock b/Cargo.lock index d380fa522c..6a75690afc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -231,6 +231,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] +name = "aws-lc-rs" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2b7ddaa2c56a367ad27a094ad8ef4faacf8a617c2575acb2ba88949df999ca" +dependencies = [ + "aws-lc-sys", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54ac4f13dad353b209b34cbec082338202cbc01c8f00336b55c750c13ac91f8f" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "paste", +] + +[[package]] name = "axum" version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -326,6 +351,29 @@ dependencies = [ ] [[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.6.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.100", + "which", +] + +[[package]] name = "bit-set" version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -483,6 +531,8 @@ version = "1.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -493,6 +543,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" [[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + +[[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -559,6 +618,17 @@ dependencies = [ ] [[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] name = "clap" version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -619,6 +689,15 @@ dependencies = [ ] [[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] name = "colorchoice" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1009,6 +1088,12 @@ dependencies = [ ] [[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] name = "ecdsa" version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1273,6 +1358,12 @@ dependencies = [ ] [[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] name = "fsevent-sys" version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1477,6 +1568,43 @@ dependencies = [ ] [[package]] +name = "h3" +version = "0.0.6" +source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac" +dependencies = [ + "bytes", + "fastrand", + "futures-util", + "http 1.1.0", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "h3-datagram" +version = "0.0.1" +source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac" +dependencies = [ + "bytes", + "h3", + "pin-project-lite", +] + +[[package]] +name = "h3-quinn" +version = "0.0.7" +source = "git+https://github.com/mullvad/h3?rev=01ae01192300d29abaf6a3233d862e40c9f92bac#01ae01192300d29abaf6a3233d862e40c9f92bac" +dependencies = [ + "bytes", + "futures", + "h3", + "h3-datagram", + "quinn", + "tokio", + "tokio-util 0.7.10", +] + +[[package]] name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2221,6 +2349,15 @@ dependencies = [ ] [[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] name = "js-sys" version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2297,6 +2434,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] name = "libc" version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2485,6 +2628,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] name = "miniz_oxide" version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2813,6 +2962,23 @@ dependencies = [ ] [[package]] +name = "mullvad-masque-proxy" +version = "0.1.0" +dependencies = [ + "bytes", + "clap", + "h3", + "h3-datagram", + "h3-quinn", + "http 1.1.0", + "quinn", + "rand 0.8.5", + "rustls 0.23.18", + "rustls-pemfile 2.1.3", + "tokio", +] + +[[package]] name = "mullvad-nsis" version = "0.0.0" dependencies = [ @@ -3155,6 +3321,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65" [[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] name = "notify" version = "6.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3183,12 +3359,31 @@ dependencies = [ ] [[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] name = "num-traits" version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3340,6 +3535,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] name = "openvpn-plugin" version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3918,10 +4119,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", + "futures-io", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash", + "rustc-hash 2.1.0", "rustls 0.23.18", "socket2", "thiserror 2.0.9", @@ -3939,9 +4141,10 @@ dependencies = [ "getrandom 0.2.14", "rand 0.8.5", "ring", - "rustc-hash", + "rustc-hash 2.1.0", "rustls 0.23.18", "rustls-pki-types", + "rustls-platform-verifier", "slab", "thiserror 2.0.9", "tinyvec", @@ -4231,6 +4434,12 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustc-hash" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" @@ -4275,6 +4484,7 @@ version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -4285,6 +4495,19 @@ dependencies = [ ] [[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] name = "rustls-pemfile" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4313,6 +4536,33 @@ dependencies = [ ] [[package]] +name = "rustls-platform-verifier" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4c7dc240fec5517e6c4eab3310438636cfe6391dfc345ba013109909a90d136" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls 0.23.18", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki 0.102.8", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] name = "rustls-webpki" version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4328,6 +4578,7 @@ version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -4373,6 +4624,15 @@ dependencies = [ ] [[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4402,6 +4662,30 @@ dependencies = [ ] [[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "num-bigint", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] name = "semver" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5481,6 +5765,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -5791,6 +6076,15 @@ dependencies = [ ] [[package]] +name = "webpki-root-certs" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09aed61f5e8d2c18344b3faa33a4c837855fe56642757754775548fee21386c4" +dependencies = [ + "rustls-pki-types", +] + +[[package]] name = "webpki-roots" version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index 3768230f7c..eedb961b2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "mullvad-jni", "mullvad-leak-checker", "mullvad-management-interface", + "mullvad-masque-proxy", "mullvad-nsis", "mullvad-paths", "mullvad-problem-report", diff --git a/mullvad-masque-proxy/Cargo.toml b/mullvad-masque-proxy/Cargo.toml new file mode 100644 index 0000000000..ae3a1a4928 --- /dev/null +++ b/mullvad-masque-proxy/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "mullvad-masque-proxy" +version = "0.1.0" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +description = "A limited functionality UDP over HTTP3 proxy" + +[dependencies] +quinn = "0.11" +tokio = { workspace = true, features = [ "macros", "io-util" ] } +h3 = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac" } +h3-datagram = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac" } +h3-quinn = { git = "https://github.com/mullvad/h3", rev = "01ae01192300d29abaf6a3233d862e40c9f92bac", features = [ "datagram" ]} +http = "1" +rustls = { version = "0.23" } +rustls-pemfile = "2.1.3" +bytes = "1" + +[dev-dependencies] +tokio = { workspace = true, features = [ "macros", "io-util", "rt-multi-thread" ] } +clap = { workspace = true } +rand = "0.8.5" + +[lints] +workspace = true diff --git a/mullvad-masque-proxy/examples/client.rs b/mullvad-masque-proxy/examples/client.rs new file mode 100644 index 0000000000..6005e1f5f4 --- /dev/null +++ b/mullvad-masque-proxy/examples/client.rs @@ -0,0 +1,80 @@ +use clap::Parser; +use mullvad_masque_proxy::client::Error; +use tokio::net::UdpSocket; + +use std::{ + net::{Ipv4Addr, SocketAddr}, + path::PathBuf, +}; + +#[derive(Parser, Debug)] +pub struct ClientArgs { + #[arg(long, short = 't')] + target_addr: SocketAddr, + + /// Path to cert + #[arg(long, short = 'c', required = false)] + root_cert_path: Option<PathBuf>, + + /// Server address + #[arg(long, short = 's')] + server_addr: SocketAddr, + + #[arg(long, short = 'H')] + server_hostname: String, + + #[arg(long, short = 'p', default_value = "0")] + bind_port: u16, + + #[arg(long, short = 'S', default_value = "1000")] + maximum_packet_size: u16, +} + +#[tokio::main] +async fn main() { + let ClientArgs { + server_addr, + target_addr, + root_cert_path, + server_hostname, + bind_port, + maximum_packet_size, + } = ClientArgs::parse(); + + let tls_config = match root_cert_path { + Some(path) => mullvad_masque_proxy::client::client_tls_config_from_cert_path(path.as_ref()) + .expect("Failed to get TLS config"), + None => mullvad_masque_proxy::client::default_tls_config(), + }; + + let _keylog = rustls::KeyLogFile::new(); + + let unbound_local_addr: SocketAddr = (Ipv4Addr::UNSPECIFIED, bind_port).into(); + let local_socket = UdpSocket::bind(unbound_local_addr) + .await + .expect("Failed to bind address"); + let local_addr = local_socket.local_addr().unwrap(); + println!("Listening on {local_addr}"); + + let client = mullvad_masque_proxy::client::Client::connect_with_tls_config( + local_socket, + server_addr, + (Ipv4Addr::UNSPECIFIED, 0).into(), + target_addr, + &server_hostname, + tls_config, + maximum_packet_size, + ) + .await; + if let Err(err) = &client { + println!("ERROR: {:?}", err); + if let Error::Connection(err) = err { + println!("ERROR: {}", err); + } + } + client + .expect("Failed to connect client") + .run() + .await + .unwrap(); +} diff --git a/mullvad-masque-proxy/examples/server.rs b/mullvad-masque-proxy/examples/server.rs new file mode 100644 index 0000000000..f9dfee557f --- /dev/null +++ b/mullvad-masque-proxy/examples/server.rs @@ -0,0 +1,78 @@ +use clap::Parser; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + +use std::{ + fs, + net::{IpAddr, SocketAddr}, + path::{Path, PathBuf}, + sync::Arc, +}; + +#[derive(Parser, Debug)] +pub struct ServerArgs { + #[arg(long, short = 'b', default_value = "0.0.0.0:0")] + bind_addr: SocketAddr, + + /// Path to cert + #[arg(long, short = 'c')] + cert_path: PathBuf, + + /// Path to key + #[arg(long, short = 'k')] + key_path: PathBuf, + + /// Allowed IPs + #[arg(long = "alloewd-ip", short = 'a', required = false)] + allowed_ips: Vec<IpAddr>, + /// Maximums packet size + #[arg(long = "maximum-packet-size", short = 'm', default_value = "1700")] + maximum_packet_size: u16, +} + +#[tokio::main] +async fn main() { + let args = ServerArgs::parse(); + let _keylog = rustls::KeyLogFile::new(); + + let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap(); + + let server = mullvad_masque_proxy::server::Server::bind( + args.bind_addr, + args.allowed_ips.iter().cloned().collect(), + tls_config.into(), + args.maximum_packet_size, + ) + .expect("Failed to initialize server"); + println!("Listening on {}", args.bind_addr); + server.run().await.expect("Server failed.") +} + +fn load_server_config( + key_path: &Path, + cert_path: &Path, +) -> Result<rustls::ServerConfig, Box<dyn std::error::Error>> { + let key = fs::read(key_path)?; + let key = if key_path.extension().is_some_and(|x| x == "der") { + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key)) + } else { + rustls_pemfile::private_key(&mut &*key)?.expect("Expected PEM file to contain private key") + }; + let cert_chain = fs::read(cert_path)?; + let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") { + vec![CertificateDer::from(cert_chain)] + } else { + rustls_pemfile::certs(&mut &*cert_chain).collect::<Result<_, _>>()? + }; + + let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13])? + .with_no_client_auth() + .with_single_cert(cert_chain, key)?; + + tls_config.max_early_data_size = u32::MAX; + tls_config.alpn_protocols = vec![b"h3".into()]; + + Ok(tls_config) +} diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs new file mode 100644 index 0000000000..e07d0f76d3 --- /dev/null +++ b/mullvad-masque-proxy/src/client/mod.rs @@ -0,0 +1,436 @@ +use bytes::{Buf, BytesMut}; +use rustls::client::danger::ServerCertVerified; +use std::{ + fs, future, io, + net::{Ipv4Addr, SocketAddr}, + path::Path, + sync::{Arc, LazyLock}, + time::Duration, +}; +use tokio::{net::UdpSocket, time::interval}; + +use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId}; +use h3_datagram::datagram_traits::HandleDatagramsExt; +use http::{header, uri::Scheme, Response, StatusCode}; +use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, TransportConfig}; + +use crate::fragment::{self, Fragments}; + +const MAX_HEADER_SIZE: u64 = 8192; + +const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem"); + +pub struct Client { + client_socket: UdpSocket, + /// QUIC connection, used to send the actual HTTP datagrams + connection: h3::client::Connection<h3_quinn::Connection, bytes::Bytes>, + /// Send stream over a QUIC connection - this needs to be kept alive to not close the HTTP + /// QUIC stream. + _send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, + /// Request stream for the currently open request, must not be dropped, otherwise proxy + /// connection is terminated + request_stream: client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, + /// Packet fragments + fragments: Fragments, + /// Maximum packet size + maximum_packet_size: u16, +} + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug)] +pub enum Error { + Bind(io::Error), + Connect(quinn::ConnectError), + Connection(quinn::ConnectionError), + /// Connection closed while sending request to initiate proxying + ConnectionClosedPrematurely, + /// QUIC connection failed while sending request to initiate proxying + ConnectionFailed(h3::Error), + /// Request failed to illicit a response. + RequestError(h3::Error), + /// Received response was not a 200. + UnexpectedStatus(http::StatusCode), + /// Failed to receive data from client socket + ClientRead(io::Error), + /// Failed to send data to client socket + ClientWrite(io::Error), + /// Failed to receive data from server socket + ServerRead(h3::Error), + /// Failed to create a client + CreateClient(h3::Error), + /// Failed to receive good response from proxy + ProxyResponse(h3::Error), + /// Failed to construct a URI + Uri(http::Error), + /// Failed to send datagram to proxy + SendDatagram(h3::Error), + /// Failed to read certificates + ReadCerts(io::Error), + /// Failed to parse certificates + ParseCerts, + /// Failed to fragment a packet - it is too large + PacketTooLarge(fragment::PacketTooLarge), +} + +impl Client { + pub async fn connect( + client_socket: UdpSocket, + server_addr: SocketAddr, + local_addr: SocketAddr, + target_addr: SocketAddr, + server_host: &str, + maximum_packet_size: u16, + ) -> Result<Self> { + Self::connect_with_tls_config( + client_socket, + server_addr, + local_addr, + target_addr, + server_host, + default_tls_config(), + maximum_packet_size, + ) + .await + } + + pub async fn connect_with_tls_config( + client_socket: UdpSocket, + server_addr: SocketAddr, + local_addr: SocketAddr, + target_addr: SocketAddr, + server_host: &str, + tls_config: Arc<rustls::ClientConfig>, + maximum_packet_size: u16, + ) -> Result<Self> { + let quic_client_config = QuicClientConfig::try_from(tls_config) + .expect("Failed to construct a valid TLS configuration"); + + let mut client_config = ClientConfig::new(Arc::new(quic_client_config)); + let transport_config = TransportConfig::default(); + // TODO: Set datagram_receive_buffer_size if needed + // TODO: Set datagram_send_buffer_size if needed + // When would it be needed? If we need to buffer more packets or buffer less packets for + // better performance. + client_config.transport_config(Arc::new(transport_config)); + Self::connect_with_local_addr( + client_socket, + server_addr, + local_addr, + target_addr, + server_host, + client_config, + maximum_packet_size, + ) + .await + } + + async fn connect_with_local_addr( + client_socket: UdpSocket, + server_addr: SocketAddr, + local_addr: SocketAddr, + target_addr: SocketAddr, + server_host: &str, + client_config: ClientConfig, + maximum_packet_size: u16, + ) -> Result<Self> { + let endpoint = Endpoint::client(local_addr).map_err(Error::Bind)?; + + let connecting = endpoint + .connect_with(client_config, server_addr, server_host) + .map_err(Error::Connect)?; + + let connection = connecting.await.map_err(Error::Connection)?; + + let (connection, send_stream, request_stream) = + Self::setup_h3_connection(connection, target_addr, server_host, maximum_packet_size) + .await?; + + Ok(Self { + connection, + client_socket, + request_stream, + fragments: Fragments::default(), + _send_stream: send_stream, + maximum_packet_size, + }) + } + + // Returns an h3 connection that is ready to be used for sending UDP datagrams. + async fn setup_h3_connection( + connection: quinn::Connection, + target: SocketAddr, + server_host: &str, + maximum_packet_size: u16, + ) -> Result<( + client::Connection<h3_quinn::Connection, bytes::Bytes>, + client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>, + client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>, + )> { + let (mut connection, mut send_stream) = client::builder() + .max_field_section_size(MAX_HEADER_SIZE) + .enable_datagram(true) + .send_grease(true) + .build(h3_quinn::Connection::new(connection)) + .await + .map_err(Error::CreateClient)?; + + let request = new_connect_request(target, &server_host, maximum_packet_size)?; + + let request_future = async move { + let mut request_stream = send_stream.send_request(request).await?; + let response = request_stream.recv_response().await?; + Ok((response, send_stream, request_stream)) + }; + + tokio::select! { + closed = future::poll_fn(|cx| connection.poll_close(cx)) => { + match closed { + Ok(()) => Err(Error::ConnectionClosedPrematurely), + Err(err) => Err(Error::ConnectionFailed(err)), + } + }, + response = request_future => { + let (response, send_stream, request_stream) = response.map_err(Error::RequestError)?; + handle_response(response)?; + Ok((connection, send_stream, request_stream)) + }, + } + } + + pub async fn run(mut self) -> Result<()> { + let stream_id: StreamId = self.request_stream.id(); + // this is the variable ID used to signify UDP payloads in HTTP datagrams. + let mut client_read_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE * 1024); + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); + + let mut return_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); + let mut fragment_id = 1u16; + let mut interval = interval(Duration::from_secs(3)); + + loop { + tokio::select! { + client_read = self.client_socket.recv_buf_from(&mut client_read_buf) => { + let (_bytes_received, recv_addr) = client_read.map_err(Error::ClientRead)?; + return_addr = recv_addr; + + let mut send_buf = client_read_buf.split().freeze(); + if send_buf.len() < (Into::<usize>::into(self.maximum_packet_size) - 100usize) { + self.connection + .send_datagram(stream_id, send_buf) + .map_err(Error::SendDatagram)?; + } else { + // drop the added context ID, since packet will have to be fragmented. + { + let _ = VarInt::decode(&mut send_buf); + } + for fragment in fragment::fragment_packet( + self.maximum_packet_size, + &mut send_buf, + fragment_id) + .map_err(Error::PacketTooLarge) + ? { + self.connection.send_datagram(stream_id, fragment).map_err(Error::SendDatagram)?; + } + fragment_id = fragment_id.wrapping_add(1); + } + + client_read_buf.reserve(crate::PACKET_BUFFER_SIZE); + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut client_read_buf); + }, + server_response = self.connection.read_datagram() => { + match server_response { + Ok(Some(response)) => { + if response.stream_id() != stream_id { + // log::trace!("Received datagram with an unexpected stream ID"); + continue; + } + let mut payload = response.into_payload(); + let context = VarInt::decode(&mut payload); + match context { + Ok(crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID) => { + self.client_socket + .send_to(payload.as_ref(), return_addr) + .await + .map_err(Error::ClientWrite)?; + } + Ok(crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID) => { + if let Ok(Some(payload)) = self.fragments.handle_incoming_packet(payload) { + self.client_socket + .send_to(payload.chunk(), return_addr) + .await + .map_err(Error::ClientWrite)?; + } + }, + _ => (), + + } + } + Ok(None) => { + return Ok(()); + } + Err(err) => { + return Err(Error::ProxyResponse(err)); + } + } + }, + _ = interval.tick() => { + self.fragments.clear_old_fragments( + Duration::from_secs(3) + ); + }, + }; + } + } +} + +fn new_connect_request( + socket_addr: SocketAddr, + authority: &dyn AsRef<str>, + maximum_packet_size: u16, +) -> Result<http::Request<()>> { + let host = socket_addr.ip(); + let port = socket_addr.port(); + let path = format!("/.well-known/masque/udp/{host}/{port}/"); + let uri = http::uri::Builder::new() + .scheme(Scheme::HTTPS) + .authority(authority.as_ref()) + .path_and_query(&path) + .build() + .map_err(Error::Uri)?; + + let mut request = http::Request::builder() + .method(http::method::Method::CONNECT) + .uri(uri) + .header(b"Capsule-Protocol".as_slice(), b"?1".as_slice()) + .header(header::AUTHORIZATION, b"Bearer test".as_slice()) + .header(header::HOST, authority.as_ref()) + .header( + b"X-Mullvad-Uplink-Mtu".as_slice(), + format!("{maximum_packet_size}"), + ) + .body(()) + .expect("failed to construct a body"); + + request.extensions_mut().insert(Protocol::CONNECT_UDP); + Ok(request) +} + +fn handle_response(response: Response<()>) -> Result<()> { + if response.status() != StatusCode::OK { + return Err(Error::UnexpectedStatus(response.status())); + } + Ok(()) +} + +// TODO: resuse the same TLS code from `mullvad-api` maybe +pub fn default_tls_config() -> Arc<rustls::ClientConfig> { + static TLS_CONFIG: LazyLock<Arc<rustls::ClientConfig>> = + LazyLock::new(|| client_tls_config_with_certs(read_cert_store())); + + TLS_CONFIG.clone() +} + +fn client_tls_config_with_certs(certs: rustls::RootCertStore) -> Arc<rustls::ClientConfig> { + let mut config = rustls::ClientConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .expect("ring crypt-prover should support TLS 1.3") + .with_root_certificates(certs) + .with_no_client_auth(); + config.alpn_protocols = vec![b"h3".to_vec()]; + + let approver = Approver {}; + config.key_log = Arc::new(rustls::KeyLogFile::new()); + config + .dangerous() + .set_certificate_verifier(Arc::new(approver)); + Arc::new(config) +} + +fn read_cert_store() -> rustls::RootCertStore { + read_cert_store_from_reader(&mut std::io::BufReader::new(LE_ROOT_CERT)) + .expect("failed to read built-in cert store") +} + +pub fn client_tls_config_from_cert_path(path: &Path) -> Result<Arc<rustls::ClientConfig>> { + let certs = read_cert_store_from_path(path)?; + Ok(client_tls_config_with_certs(certs)) +} + +fn read_cert_store_from_path(path: &Path) -> Result<rustls::RootCertStore> { + let cert_path = fs::File::open(path).map_err(Error::ReadCerts)?; + read_cert_store_from_reader(&mut std::io::BufReader::new(cert_path)) +} + +fn read_cert_store_from_reader(reader: &mut dyn io::BufRead) -> Result<rustls::RootCertStore> { + let mut cert_store = rustls::RootCertStore::empty(); + + let certs = rustls_pemfile::certs(reader) + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(Error::ReadCerts)?; + let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(certs); + if num_failures > 0 || num_certs_added == 0 { + return Err(Error::ParseCerts); + } + + Ok(cert_store) +} + +#[test] +fn test_zero_stream_id() { + h3::quic::StreamId::try_from(0).expect("need to be able to create stream IDs with 0, no?"); +} + +#[derive(Debug)] +struct Approver {} + +impl rustls::client::danger::ServerCertVerifier for Approver { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA1, + rustls::SignatureScheme::ECDSA_SHA1_Legacy, + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP521_SHA512, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::ED448, + ] + } +} diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs new file mode 100644 index 0000000000..6224ab7f63 --- /dev/null +++ b/mullvad-masque-proxy/src/fragment.rs @@ -0,0 +1,153 @@ +use std::{ + collections::BTreeMap, + time::{Duration, Instant}, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use h3::proto::varint::VarInt; + +#[derive(Default)] +pub struct Fragments { + fragment_map: BTreeMap<u16, Vec<Fragment>>, +} + +// When a packet that arrives is too small to be decoded. +#[derive(Debug)] +pub enum DefragError { + BadContextId(Result<VarInt, h3::proto::coding::UnexpectedEnd>), + PayloadTooSmall, +} + +// When a packet is larger than u16::MAX, it can't be fragmented. +#[derive(Debug)] +pub struct PacketTooLarge(pub usize); + +impl Fragments { + // TODO: Let caller provide output buffer. + pub fn handle_incoming_packet( + &mut self, + mut payload: Bytes, + ) -> Result<Option<Bytes>, DefragError> { + match VarInt::decode(&mut payload) { + Ok(crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID) => { + return Ok(Some(payload)); + } + Ok(crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID) => {} + unexpected_context_id => { + return Err(DefragError::BadContextId(unexpected_context_id)); + } + } + + let id = payload + .try_get_u16() + .map_err(|_| DefragError::PayloadTooSmall)?; + let index = payload + .try_get_u8() + .map_err(|_| DefragError::PayloadTooSmall)?; + let fragment_count = payload + .try_get_u8() + .map_err(|_| DefragError::PayloadTooSmall)?; + let fragment = Fragment { + index, + payload, + time_received: Instant::now(), + }; + + let fragments = self.fragment_map.entry(id).or_default(); + fragments.push(fragment); + + Ok(self.try_fetch(id, fragment_count)) + } + + // TODO: Let caller provide output buffer. + fn try_fetch(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> { + // establish that there are enough fragments to reconstruct the whole packet + let payload = { + let fragments = self.fragment_map.get_mut(&id)?; + + if fragments.len() != fragment_count.into() { + return None; + } + + fragments.sort_by_key(|f| f.index); + let mut payload = + BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum()); + for fragment in fragments { + payload.extend_from_slice(&fragment.payload); + } + payload + }; + + self.fragment_map.remove(&id); + Some(payload.into()) + } + + pub fn clear_old_fragments(&mut self, max_age: Duration) { + self.fragment_map.retain(|_, fragments| { + fragments + .iter() + .any(|fragment| fragment.time_received.elapsed() <= max_age) + }); + } +} + +struct Fragment { + index: u8, + payload: Bytes, + time_received: Instant, +} + +pub fn fragment_packet( + maximum_packet_size: u16, + payload: &'_ mut Bytes, + packet_id: u16, +) -> Result<impl Iterator<Item = Bytes> + '_, PacketTooLarge> { + let num_fragments: usize = payload.chunks(maximum_packet_size.into()).count(); + let Ok(fragment_count): std::result::Result<u8, _> = num_fragments.try_into() else { + return Err(PacketTooLarge(payload.len())); + }; + + let iterator = payload.chunks(maximum_packet_size.into()).enumerate().map( + move |(fragment_index, fragment_payload)| { + let mut fragment = BytesMut::with_capacity((maximum_packet_size + 1).into()); + crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment); + fragment.put_u16(packet_id); + fragment.put_u8( + // fragment indexes start at 1 + u8::try_from(fragment_index +1) + .expect("fragment index must fit in an u8, since num_fragments fits is an u8"), + ); + fragment.put_u8(fragment_count); + fragment.extend_from_slice(fragment_payload); + fragment.freeze() + }, + ); + Ok(iterator) +} + +#[test] +fn test_fragment_reconstruction() { + use rand::{seq::SliceRandom, thread_rng}; + + let payload = (0..255).collect::<Vec<u8>>(); + let max_payload_size = 50; + let packet_id = 76; + + let mut fragments = Fragments::default(); + + let mut payload_clone = Bytes::from(payload.clone()); + let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) + .unwrap() + .collect::<Vec<_>>(); + + fragment_buf.shuffle(&mut thread_rng()); + + for fragment in fragment_buf { + if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() { + assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); + return; + } + } + + panic!("Failed to reconstruct packet"); +} diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs new file mode 100644 index 0000000000..bb973d3a80 --- /dev/null +++ b/mullvad-masque-proxy/src/lib.rs @@ -0,0 +1,9 @@ +use h3::proto::varint::VarInt; + +pub mod client; +mod fragment; +pub mod server; + +const PACKET_BUFFER_SIZE: usize = 1700; +pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0); +pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1); diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs new file mode 100644 index 0000000000..b6d263066a --- /dev/null +++ b/mullvad-masque-proxy/src/server/mod.rs @@ -0,0 +1,311 @@ +use std::{ + collections::HashSet, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use bytes::{Bytes, BytesMut}; +use h3::{ + proto::varint::VarInt, + quic::{BidiStream, StreamId}, + server::{self, Connection, RequestStream}, +}; +use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt}; +use http::{Request, StatusCode}; +use quinn::{crypto::rustls::QuicServerConfig, Endpoint, Incoming}; +use tokio::{net::UdpSocket, time::interval}; + +use crate::fragment::{self, Fragments}; + +#[derive(Debug)] +pub enum Error { + BadTlsConfig(quinn::crypto::rustls::NoInitialCipherSuite), + BindSocket(io::Error), + SendNegotiationResponse(h3::Error), +} + +pub type Result<T> = std::result::Result<T, Error>; + +const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/"; + +pub struct Server { + endpoint: Endpoint, + allowed_hosts: AllowedIps, + max_packet_size: u16, +} + +#[derive(Clone)] +struct AllowedIps { + hosts: Arc<HashSet<IpAddr>>, +} + +impl AllowedIps { + fn ip_allowed(&self, ip: IpAddr) -> bool { + self.hosts.is_empty() || self.hosts.contains(&ip) + } +} + +impl Server { + pub fn bind( + bind_addr: SocketAddr, + allowed_hosts: HashSet<IpAddr>, + tls_config: Arc<rustls::ServerConfig>, + max_packet_size: u16, + ) -> Result<Self> { + let server_config = quinn::ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(tls_config).map_err(Error::BadTlsConfig)?, + )); + + let endpoint = Endpoint::server(server_config, bind_addr).map_err(Error::BindSocket)?; + + Ok(Self { + endpoint, + allowed_hosts: AllowedIps { + hosts: Arc::new(allowed_hosts), + }, + max_packet_size, + }) + } + + pub async fn run(self) -> Result<()> { + while let Some(new_connection) = self.endpoint.accept().await { + tokio::spawn(Self::handle_incoming_connection( + new_connection, + self.allowed_hosts.clone(), + self.max_packet_size, + )); + } + Ok(()) + } + + async fn handle_incoming_connection( + connection: Incoming, + allowed_hosts: AllowedIps, + maximum_packet_size: u16, + ) { + match connection.await { + Ok(conn) => { + println!("new connection established"); + + let Ok(mut connection) = server::builder() + .enable_datagram(true) + .build(h3_quinn::Connection::new(conn)) + .await + else { + println!("Failed to construct a new H3 server connection"); + return; + }; + + match connection.accept().await { + Ok(Some((req, stream))) => { + tokio::spawn(Self::handle_proxy_request( + connection, + req, + stream, + allowed_hosts.clone(), + maximum_packet_size, + )); + } + + // indicating no more streams to be received + Ok(None) => {} + + Err(err) => { + println!("error on accept {}", err); + return; + } + } + } + Err(err) => { + println!("accepting connection failed: {:?}", err); + } + } + } + + async fn handle_proxy_request<T: BidiStream<Bytes>>( + mut connection: Connection<h3_quinn::Connection, Bytes>, + request: Request<()>, + mut stream: RequestStream<T, Bytes>, + allowed_hosts: AllowedIps, + maximum_packet_size: u16, + ) { + let Some(target_addr) = get_target_socketaddr(request.uri().path()) else { + return; + }; + if !allowed_hosts.ip_allowed(target_addr.ip()) { + return handle_disallowed_ip(stream).await; + } + + let bind_addr = SocketAddr::new(unspecified_addr(target_addr.ip()), 0); + let Ok(udp_socket) = UdpSocket::bind(bind_addr).await else { + return handle_failed_socket(stream).await; + }; + if let Err(err) = udp_socket.connect(target_addr).await { + println!("Failed to set destination for UDP socket: {err}"); + return handle_failed_socket(stream).await; + }; + + if handle_established_connection(&mut stream).await.is_err() { + return; + } + + let stream_id = stream.id(); + let mut proxy_recv_buf = BytesMut::with_capacity(crate::PACKET_BUFFER_SIZE); + + let mut fragments = Fragments::default(); + let mut fragment_id = 0u16; + + let mut interval = interval(Duration::from_secs(3)); + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); + + loop { + tokio::select! { + client_send = connection.read_datagram() => { + match client_send { + Ok(Some(received_packet)) => { + handle_client_packet(received_packet, stream_id, &mut fragments, &udp_socket, target_addr).await; + }, + Ok(None) => { + return; + } + Err(_err) => { + // client connection QUIC connection failed, should return now. + return; + }, + } + }, + recv_result = udp_socket.recv_buf_from(&mut proxy_recv_buf) => { + match recv_result { + Ok((_bytes_received, sender_addr)) => { + if sender_addr != target_addr { + continue + } + + let mut received_packet = proxy_recv_buf.split().freeze(); + + if proxy_recv_buf.len() < maximum_packet_size.into() { + if connection.send_datagram(stream_id, received_packet).is_err() { + return; + } + } else { + let _ = VarInt::decode(&mut received_packet); + let Ok(fragments) = fragment::fragment_packet(maximum_packet_size, &mut received_packet, fragment_id) else { continue; }; + fragment_id += 1; + for payload in fragments { + if connection.send_datagram(stream_id, payload).is_err() { + return; + } + } + }; + + proxy_recv_buf.reserve(crate::PACKET_BUFFER_SIZE); + crate::HTTP_MASQUE_DATAGRAM_CONTEXT_ID.encode(&mut proxy_recv_buf); + }, + Err(err) => { + println!("Failed to receive packet from proxy connection: {err}"); + let _ = stream.finish().await; + return; + } + } + }, + _ = interval.tick() => { + fragments.clear_old_fragments( + Duration::from_secs(3) + ); + }, + }; + } + } +} + +async fn handle_client_packet( + received_packet: Datagram, + stream_id: StreamId, + fragments: &mut Fragments, + proxy_socket: &UdpSocket, + target_addr: SocketAddr, +) { + if received_packet.stream_id() != stream_id { + // log::trace!("Received unexpected stream ID from server"); + return; + } + + if let Ok(Some(payload)) = fragments.handle_incoming_packet(received_packet.into_payload()) { + let _ = proxy_socket.send_to(&payload, target_addr).await; + } +} + +async fn handle_established_connection<T: BidiStream<Bytes>>( + stream: &mut RequestStream<T, Bytes>, +) -> Result<()> { + let response = http::Response::builder() + .status(StatusCode::OK) + .body(()) + .unwrap(); + stream + .send_response(response) + .await + .map_err(Error::SendNegotiationResponse)?; + Ok(()) +} + +async fn handle_disallowed_ip<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { + let response = http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(()) + .unwrap(); + let _ = stream.send_response(response).await; +} + +async fn handle_failed_socket<T: BidiStream<Bytes>>(mut stream: RequestStream<T, Bytes>) { + let response = http::Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(()) + .unwrap(); + let _ = stream.send_response(response).await; +} + +fn get_target_socketaddr(request_path: &str) -> Option<SocketAddr> { + // Establish if the URL path looks like `/.well-known/masque/udp/{ip}/{port}` + if !request_path.starts_with(MASQUE_WELL_KNOWN_PATH) { + return None; + }; + let (addr_str, port_str) = request_path + .strip_prefix(MASQUE_WELL_KNOWN_PATH)? + .trim_start_matches('/') + .split_once('/')?; + let port_str = port_str.trim_end_matches('/'); + + Some(SocketAddr::new( + addr_str.trim_start_matches('/').parse().ok()?, + port_str.parse().ok()?, + )) +} + +fn unspecified_addr(addr: IpAddr) -> IpAddr { + match addr { + IpAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(), + IpAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(), + } +} + +#[test] +fn test_get_good_slashy_ocketaddr() { + let addr: IpAddr = "192.168.1.1".parse().unwrap(); + let port: u16 = 7979; + let expected_addr = SocketAddr::new(addr, port); + let good_path = format!("{MASQUE_WELL_KNOWN_PATH}///{addr}/{port}////"); + + assert_eq!(get_target_socketaddr(&good_path).unwrap(), expected_addr) +} + +#[test] +fn test_get_bad_socketaddr() { + let addr: IpAddr = "192.168.1.1".parse().unwrap(); + let port: u16 = 7979; + let good_path = format!("{MASQUE_WELL_KNOWN_PATH}{addr}adsfasd/asdfasdf/{port}"); + + assert_eq!(get_target_socketaddr(&good_path), None) +} |
