summaryrefslogtreecommitdiffhomepage
path: root/mullvad-masque-proxy/examples/masque-server.rs
blob: b321fcf6592db353db0c4b89206fe8743fb005e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use clap::Parser;
use mullvad_masque_proxy::server::{AllowedIps, ServerParams};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};

use std::{
    fs,
    net::{IpAddr, SocketAddr},
    path::{Path, PathBuf},
    sync::Arc,
};

// Command-line arguments for the example MASQUE server, parsed with clap's derive API.
// NOTE: with clap's derive, the `///` comments below are the runtime `--help` text and the
// field names are the long flag names, so edits to either change the CLI.
#[derive(Parser, Debug)]
pub struct ServerArgs {
    // Defaults to all IPv4 interfaces with port 0, i.e. an OS-chosen port.
    /// Bind address
    #[arg(long, short = 'b', default_value = "0.0.0.0:0")]
    bind_addr: SocketAddr,

    // Loaded by `load_server_config`: parsed as DER when the extension is `der`, PEM otherwise.
    /// Path to cert
    #[arg(long, short = 'c')]
    cert_path: PathBuf,

    // Loaded by `load_server_config`: parsed as DER when the extension is `der`, PEM otherwise.
    /// Path to key
    #[arg(long, short = 'k')]
    key_path: PathBuf,

    // Repeatable: pass `--allowed-ip`/`-a` once per address. `required = false` is redundant
    // here, since a `Vec` field is already optional in clap.
    // NOTE(review): what an empty list means (allow all vs. allow none) is decided by
    // `AllowedIps::from` — confirm.
    /// Allowed IPs
    #[arg(long = "allowed-ip", short = 'a', required = false)]
    allowed_ips: Vec<IpAddr>,

    /// Server hostname.
    ///
    /// If set, the client must provide the correct hostname when connecting. If they don't, the
    /// server will provide an HTTP 308 redirect to the correct URI.
    #[arg(long)]
    hostname: Option<String>,

    // NOTE(review): the help text says "packet size" while the field feeds
    // `ServerParams::mtu` — confirm which size (e.g. IP packet vs. datagram payload) is meant.
    /// Maximum packet size
    #[arg(long, short = 'm', default_value = "1700")]
    mtu: u16,

    // NOTE(review): because of `default_value`, clap always produces `Some(..)` here, so the
    // CLI offers no way to run the server without an auth value. The help text's "to set" reads
    // like client wording; on the server this is passed to `ServerParams::auth_header` —
    // presumably the value clients must present; confirm.
    /// Authorization header value to set
    #[arg(long, default_value = "Bearer test")]
    auth: Option<String>,
}

#[tokio::main]
async fn main() {
    env_logger::builder()
        .filter_level(log::LevelFilter::Info)
        .parse_default_env()
        .init();

    let args = ServerArgs::parse();
    let _keylog = rustls::KeyLogFile::new();

    let tls_config = load_server_config(&args.key_path, &args.cert_path).unwrap();

    let params = ServerParams::builder()
        .allowed_hosts(AllowedIps::from(args.allowed_ips))
        .hostname(args.hostname)
        .mtu(args.mtu)
        .auth_header(args.auth)
        .build();

    let server =
        mullvad_masque_proxy::server::Server::bind(args.bind_addr, tls_config.into(), params)
            .expect("Failed to initialize server");
    log::info!("Listening on {}", args.bind_addr);
    server.run().await.expect("Server failed.")
}

/// Builds a TLS 1.3 [`rustls::ServerConfig`] for the MASQUE (HTTP/3) server from a private key
/// file and a certificate file.
///
/// Each file is parsed as raw DER when its extension is `der` (compared ASCII
/// case-insensitively), and as PEM otherwise.
/// - A DER key is assumed to be PKCS#8 encoded.
/// - A PEM key file may hold any key type `rustls_pemfile::private_key` recognizes; the first
///   one found is used.
/// - A DER certificate file is treated as a single certificate.
/// - A PEM certificate file may hold a whole chain.
///
/// The config uses the `ring` crypto provider, accepts only TLS 1.3, does not request client
/// certificates, and advertises only the `h3` ALPN protocol.
///
/// # Errors
///
/// Returns an error if either file cannot be read or parsed, if a PEM key file contains no
/// private key, if a PEM certificate file contains no certificates, or if rustls rejects the
/// protocol versions or the key/certificate pair.
fn load_server_config(
    key_path: &Path,
    cert_path: &Path,
) -> Result<rustls::ServerConfig, Box<dyn std::error::Error>> {
    let is_der = |path: &Path| {
        path.extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("der"))
    };

    let key = fs::read(key_path)?;
    let key = if is_der(key_path) {
        PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key))
    } else {
        // A key file without a key is a user error: report it rather than panicking.
        rustls_pemfile::private_key(&mut &*key)?
            .ok_or("Expected PEM file to contain private key")?
    };

    let cert_chain = fs::read(cert_path)?;
    let cert_chain = if is_der(cert_path) {
        vec![CertificateDer::from(cert_chain)]
    } else {
        let certs: Vec<_> = rustls_pemfile::certs(&mut &*cert_chain).collect::<Result<_, _>>()?;
        // Fail early with a clear message instead of handing rustls an empty chain.
        if certs.is_empty() {
            return Err("Expected PEM file to contain at least one certificate".into());
        }
        certs
    };

    let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
        rustls::crypto::ring::default_provider(),
    ))
    .with_protocol_versions(&[&rustls::version::TLS13])?
    .with_no_client_auth()
    .with_single_cert(cert_chain, key)?;

    // Enables 0-RTT. Over QUIC, RFC 9001 §4.6.1 only permits 0xffffffff here when early data is
    // offered, so `u32::MAX` is the only valid non-zero value.
    tls_config.max_early_data_size = u32::MAX;
    tls_config.alpn_protocols = vec![b"h3".into()];

    Ok(tls_config)
}