diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-01-20 10:17:33 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-01-26 16:27:53 +0100 |
| commit | 865b571f07517c73bbf30ba097e787a43cc62289 (patch) | |
| tree | bcbc820fd582a39d912ca6dbd72776ac4fbfd303 | |
| parent | 7ce9d490f8ca3983bb9854b10335852831b5b71d (diff) | |
| download | mullvadvpn-865b571f07517c73bbf30ba097e787a43cc62289.tar.xz mullvadvpn-865b571f07517c73bbf30ba097e787a43cc62289.zip | |
Add TlsStream type
| -rw-r--r-- | Cargo.lock | 8 | ||||
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 59 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-rpc/src/tls_stream.rs | 124 |
4 files changed, 133 insertions, 59 deletions
diff --git a/Cargo.lock b/Cargo.lock index 483cb54781..c6673a446f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -748,9 +748,9 @@ checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7" [[package]] name = "h2" -version = "0.3.6" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c06815895acec637cd6ed6e9662c935b866d20a106f8361892893a7d9234964" +checksum = "0c9de88456263e249e241fcd211d3954e2c9b0ef7ccfc235a444eb367cae3689" dependencies = [ "bytes", "fnv", @@ -858,9 +858,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.13" +version = "0.14.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15d1cfb9e4f68655fa04c01f59edb405b6074a0f7118ea881e5026e4a1cd8593" +checksum = "b7ec3e62bdc98a2f0393a5048e4c30ef659440ea6e0e572965103e72bd836f55" dependencies = [ "bytes", "futures-channel", diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index 7b2a7a4b9a..7cc97ddd39 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -1,4 +1,4 @@ -use crate::{abortable_stream::AbortableStream, rest::RequestCommand}; +use crate::{abortable_stream::AbortableStream, rest::RequestCommand, tls_stream::TlsStream}; use futures::{ channel::{mpsc, oneshot}, sink::SinkExt, @@ -9,18 +9,15 @@ use hyper::{ service::Service, Uri, }; -use hyper_rustls::MaybeHttpsStream; -use rustls::ServerName; #[cfg(target_os = "android")] use std::os::unix::io::{AsRawFd, RawFd}; use std::{ fmt, future::Future, - io::{self, BufReader}, + io, net::{IpAddr, SocketAddr}, pin::Pin, str::{self, FromStr}, - sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -28,10 +25,6 @@ use std::{ use tokio::net::TcpSocket; use tokio::{net::TcpStream, runtime::Handle, time::timeout}; -use tokio_rustls::rustls; - -// New LetsEncrypt root certificate -const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem"); const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); @@ -44,7 +37,6 @@ pub struct HttpsConnectorWithSni { service_tx: Option<mpsc::Sender<RequestCommand>>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - tls: Arc<rustls::ClientConfig>, } #[cfg(target_os = "android")] @@ -56,15 +48,6 @@ impl HttpsConnectorWithSni { sni_hostname: Option<String>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Self { - let mut config = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_root_certificates(Self::read_cert_store()) - .with_no_client_auth(); - config.enable_sni = true; - HttpsConnectorWithSni { next_socket_id: 0, handle, @@ -72,23 +55,9 @@ impl HttpsConnectorWithSni { #[cfg(target_os = "android")] socket_bypass_tx, service_tx: None, - tls: Arc::new(config), } } - fn read_cert_store() -> rustls::RootCertStore { - let mut cert_store = rustls::RootCertStore::empty(); - - let certs = rustls_pemfile::certs(&mut BufReader::new(LE_ROOT_CERT)) - .expect("Failed to parse pem file"); - let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs); - if num_failures > 0 || num_certs_added != 1 { - panic!("Failed to add root cert"); - } - - cert_store - } - /// Set a channel to register sockets with the request service. pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) { self.service_tx = Some(service_tx); @@ -162,7 +131,7 @@ impl fmt::Debug for HttpsConnectorWithSni { } impl Service<Uri> for HttpsConnectorWithSni { - type Response = MaybeHttpsStream<AbortableStream<TcpStream>>; + type Response = TlsStream<AbortableStream<TcpStream>>; type Error = io::Error; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -172,7 +141,6 @@ impl Service<Uri> for HttpsConnectorWithSni { } fn call(&mut self, uri: Uri) -> Self::Future { - let tls_connector: tokio_rustls::TlsConnector = self.tls.clone().into(); let sni_hostname = self .sni_hostname .clone() @@ -196,12 +164,6 @@ impl Service<Uri> for HttpsConnectorWithSni { } let hostname = sni_hostname?; - let host = ServerName::try_from(hostname.as_str()).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("invalid hostname \"{}\"", hostname), - ) - })?; let addr = Self::resolve_address(&uri).await?; let tokio_connection = Self::open_socket( @@ -234,22 +196,9 @@ impl Service<Uri> for HttpsConnectorWithSni { } }); } - - let tls_connection = tls_connector.connect(host, tcp_stream).await?; - - Ok(MaybeHttpsStream::Https(tls_connection)) + Ok(TlsStream::connect_https(tcp_stream, &hostname).await?) }; Box::pin(fut) } } - -#[cfg(test)] -mod test { - use super::HttpsConnectorWithSni; - - #[test] - fn test_cert_loading() { - let _certs = HttpsConnectorWithSni::read_cert_store(); - } -} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 9dfd139b01..fce57d3991 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -24,6 +24,7 @@ pub mod rest; mod https_client_with_sni; use crate::https_client_with_sni::HttpsConnectorWithSni; mod abortable_stream; +mod tls_stream; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; diff --git a/mullvad-rpc/src/tls_stream.rs b/mullvad-rpc/src/tls_stream.rs new file mode 100644 index 0000000000..232bb39b92 --- /dev/null +++ b/mullvad-rpc/src/tls_stream.rs @@ -0,0 +1,124 @@ +//! Provides a TLS 1.3 stream with SNI and LE root cert only. +use std::{ + io::{self, ErrorKind}, + pin::Pin, + sync::Arc, + task::{self, Poll}, +}; + +use hyper::client::connect::{Connected, Connection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{ + rustls::{self, ClientConfig, ServerName}, + TlsConnector, +}; + +const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem"); + +pub struct TlsStream<S: AsyncRead + AsyncWrite + Unpin> { + stream: Pin<Box<tokio_rustls::client::TlsStream<S>>>, +} + +impl<S> TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub async fn connect_https(stream: S, domain: &str) -> io::Result<TlsStream<S>> { + lazy_static::lazy_static! { + static ref TLS_CONFIG: Arc<ClientConfig> = { + let config = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(read_cert_store()) + .with_no_client_auth(); + Arc::new(config) + }; + } + + let connector = TlsConnector::from(TLS_CONFIG.clone()); + + let host = match ServerName::try_from(domain) { + Ok(n) => n, + Err(_) => { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("invalid hostname \"{}\"", domain), + )); + } + }; + + let tls_stream = connector.connect(host, stream).await?; + + Ok(TlsStream { + stream: Box::pin(tls_stream), + }) + } +} + +fn read_cert_store() -> rustls::RootCertStore { + let mut cert_store = rustls::RootCertStore::empty(); + + let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(LE_ROOT_CERT)) + .expect("Failed to parse pem file"); + let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs); + if num_failures > 0 || num_certs_added != 1 { + panic!("Failed to add root cert"); + } + + cert_store +} + +impl<S> AsyncRead for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.stream.as_mut().poll_read(cx, buf) + } +} + +impl<S> AsyncWrite for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.stream.as_mut().poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + self.stream.as_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + self.stream.as_mut().poll_shutdown(cx) + } +} + +impl<S> Connection for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_cert_loading() { + let _certs = read_cert_store(); + } +} |
