summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-28 17:59:50 +0200
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-30 12:58:01 +0200
commit961f47a58379d8ffcc41da35a2581fb4c04dc258 (patch)
tree1d66a70ef66dc9f389bba9105742d95f0cc4ca45
parent4165dacc7f2c2b6f85fc6ec1abc665b0d85a0b30 (diff)
downloadmullvadvpn-961f47a58379d8ffcc41da35a2581fb4c04dc258.tar.xz
mullvadvpn-961f47a58379d8ffcc41da35a2581fb4c04dc258.zip
Handle HTTP redirects in masque client
-rw-r--r--mullvad-masque-proxy/src/client/mod.rs99
-rw-r--r--mullvad-masque-proxy/src/lib.rs2
-rw-r--r--mullvad-masque-proxy/src/server/mod.rs4
3 files changed, 81 insertions, 24 deletions
diff --git a/mullvad-masque-proxy/src/client/mod.rs b/mullvad-masque-proxy/src/client/mod.rs
index ea34ea2fab..26a069bc81 100644
--- a/mullvad-masque-proxy/src/client/mod.rs
+++ b/mullvad-masque-proxy/src/client/mod.rs
@@ -1,9 +1,12 @@
+use anyhow::{anyhow, Context};
use bytes::{Buf, Bytes, BytesMut};
use rustls::client::danger::ServerCertVerified;
use std::{
- fs, future, io,
+ fs::{self},
+ future, io,
net::{Ipv4Addr, SocketAddr},
path::Path,
+ str::FromStr as _,
sync::{Arc, LazyLock},
time::Duration,
};
@@ -16,7 +19,7 @@ use typed_builder::TypedBuilder;
use h3::{client, ext::Protocol, proto::varint::VarInt, quic::StreamId};
use h3_datagram::{datagram::Datagram, datagram_traits::HandleDatagramsExt};
-use http::{header, uri::Scheme, Response, StatusCode};
+use http::{header, uri::Scheme, StatusCode, Uri};
use quinn::{
crypto::rustls::QuicClientConfig, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime,
TransportConfig,
@@ -26,11 +29,13 @@ use crate::{
compute_udp_payload_size,
fragment::{self, Fragments},
stats::Stats,
- MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
+ MASQUE_WELL_KNOWN_PATH, MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
};
const MAX_HEADER_SIZE: u64 = 8192;
+const MAX_REDIRECT_COUNT: usize = 1;
+
const LE_ROOT_CERT: &[u8] = include_bytes!("../../../mullvad-api/le_root_cert.pem");
pub struct Client {
@@ -103,6 +108,8 @@ pub enum Error {
PacketTooLarge(#[from] fragment::PacketTooLarge),
#[error("The provided idle timeout was invalid")]
InvalidIdleTimeout(quinn::VarIntBoundsExceeded),
+ #[error("The server returned an invalid HTTP redirect")]
+ InvalidHttpRedirect(#[source] anyhow::Error),
}
#[derive(TypedBuilder)]
@@ -244,7 +251,7 @@ impl Client {
.map_err(Error::Bind)
}
- // Returns an h3 connection that is ready to be used for sending UDP datagrams.
+ /// Returns an h3 connection that is ready to be used for sending UDP datagrams.
async fn setup_h3_connection(
connection: quinn::Connection,
target: SocketAddr,
@@ -255,7 +262,7 @@ impl Client {
client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
)> {
- let (mut connection, mut send_stream) = client::builder()
+ let (connection, send_stream) = client::builder()
.max_field_section_size(MAX_HEADER_SIZE)
.enable_datagram(true)
.send_grease(true)
@@ -263,6 +270,24 @@ impl Client {
.await
.map_err(Error::CreateClient)?;
+ Self::send_connect_request(connection, send_stream, server_host, target, mtu, 0).await
+ }
+
+ /// Send an HTTP CONNECT request and set up the h3 connection for sending datagrams.
+ ///
+ /// This function will follow HTTP redirects up to [MAX_REDIRECT_COUNT].
+ async fn send_connect_request(
+ mut connection: client::Connection<h3_quinn::Connection, bytes::Bytes>,
+ mut send_stream: client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
+ server_host: &str,
+ target: SocketAddr,
+ mtu: u16,
+ redirect_count: usize,
+ ) -> Result<(
+ client::Connection<h3_quinn::Connection, bytes::Bytes>,
+ client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
+ client::RequestStream<h3_quinn::BidiStream<bytes::Bytes>, bytes::Bytes>,
+ )> {
let request = new_connect_request(target, &server_host, mtu)?;
let request_future = async move {
@@ -271,18 +296,57 @@ impl Client {
Ok((response, send_stream, request_stream))
};
- tokio::select! {
+ let response = tokio::select! {
+ response = request_future => response,
+ // TODO: this arm completes first when the connection is gracefully terminated from the
+ // peer, but ideally we want to be able to handle the response above in that case.
closed = future::poll_fn(|cx| connection.poll_close(cx)) => {
- match closed {
+ return match closed {
Ok(()) => Err(Error::ConnectionClosedPrematurely),
Err(err) => Err(Error::ConnectionFailed(err)),
- }
- },
- response = request_future => {
- let (response, send_stream, request_stream) = response.map_err(Error::RequestError)?;
- handle_response(response)?;
- Ok((connection, send_stream, request_stream))
+ };
},
+ };
+
+ let (response, send_stream, request_stream) = response.map_err(Error::RequestError)?;
+
+ match response.status() {
+ StatusCode::OK => Ok((connection, send_stream, request_stream)),
+
+ // If we are trying to connect with the wrong `host` in the HTTP URI, then the masque
+ // server will redirect us to the URI with the correct `host`.
+ status @ StatusCode::PERMANENT_REDIRECT => {
+ if redirect_count >= MAX_REDIRECT_COUNT {
+ log::error!("Too many redirects (redirect loop?)");
+ return Err(anyhow!("Too many redirects")).map_err(Error::InvalidHttpRedirect);
+ }
+
+ let server_host = response
+ .headers()
+ .get("Location")
+ .and_then(|header| header.to_str().ok())
+ .and_then(|location| Uri::from_str(location).ok())
+ .inspect(|location| log::info!("Redirected to {location:?} (HTTP {status})"))
+ .and_then(|location| location.host().map(String::from))
+ .context("Failed to decode `Location` HTTP header")
+ .map_err(Error::InvalidHttpRedirect)?;
+
+ // Repeat the request, but using the new host
+ //
+ // We are re-using the same h3 connection for this HTTP request, meaning that we
+ // will never redirect to a *different* server. We are only re-issuing the same
+ // HTTP request, using the same connection, but with a different URI.
+ Box::pin(Self::send_connect_request(
+ connection,
+ send_stream,
+ &server_host,
+ target,
+ mtu,
+ redirect_count + 1,
+ ))
+ .await
+ }
+ status => Err(Error::UnexpectedStatus(status)),
}
}
@@ -485,7 +549,7 @@ fn new_connect_request(
) -> Result<http::Request<()>> {
let host = socket_addr.ip();
let port = socket_addr.port();
- let path = format!("/.well-known/masque/udp/{host}/{port}/");
+ let path = format!("{MASQUE_WELL_KNOWN_PATH}{host}/{port}/");
let uri = http::uri::Builder::new()
.scheme(Scheme::HTTPS)
.authority(authority.as_ref())
@@ -511,13 +575,6 @@ fn new_connect_request(
Ok(request)
}
-fn handle_response(response: Response<()>) -> Result<()> {
- if response.status() != StatusCode::OK {
- return Err(Error::UnexpectedStatus(response.status()));
- }
- Ok(())
-}
-
// TODO: resuse the same TLS code from `mullvad-api` maybe
pub fn default_tls_config() -> Arc<rustls::ClientConfig> {
static TLS_CONFIG: LazyLock<Arc<rustls::ClientConfig>> =
diff --git a/mullvad-masque-proxy/src/lib.rs b/mullvad-masque-proxy/src/lib.rs
index d8467790ec..e6533b5368 100644
--- a/mullvad-masque-proxy/src/lib.rs
+++ b/mullvad-masque-proxy/src/lib.rs
@@ -6,6 +6,8 @@ mod fragment;
pub mod server;
mod stats;
+pub const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/";
+
pub const HTTP_MASQUE_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(0);
pub const HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID: VarInt = VarInt::from_u32(1);
diff --git a/mullvad-masque-proxy/src/server/mod.rs b/mullvad-masque-proxy/src/server/mod.rs
index 9a8a088892..60e8db010b 100644
--- a/mullvad-masque-proxy/src/server/mod.rs
+++ b/mullvad-masque-proxy/src/server/mod.rs
@@ -21,7 +21,7 @@ use tokio::{net::UdpSocket, select, sync::mpsc, task};
use crate::{
compute_udp_payload_size,
fragment::{self, Fragments},
- MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
+ MASQUE_WELL_KNOWN_PATH, MAX_INFLIGHT_PACKETS, MIN_IPV4_MTU, MIN_IPV6_MTU, QUIC_HEADER_SIZE,
};
#[derive(Debug, thiserror::Error)]
@@ -38,8 +38,6 @@ pub enum Error {
pub type Result<T> = std::result::Result<T, Error>;
-const MASQUE_WELL_KNOWN_PATH: &str = "/.well-known/masque/udp/";
-
pub struct Server {
endpoint: Endpoint,
params: Arc<ServerParams>,