summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-01-20 10:17:33 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-01-26 16:27:53 +0100
commit865b571f07517c73bbf30ba097e787a43cc62289 (patch)
treebcbc820fd582a39d912ca6dbd72776ac4fbfd303
parent7ce9d490f8ca3983bb9854b10335852831b5b71d (diff)
downloadmullvadvpn-865b571f07517c73bbf30ba097e787a43cc62289.tar.xz
mullvadvpn-865b571f07517c73bbf30ba097e787a43cc62289.zip
Add TlsStream type
-rw-r--r--Cargo.lock8
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs59
-rw-r--r--mullvad-rpc/src/lib.rs1
-rw-r--r--mullvad-rpc/src/tls_stream.rs124
4 files changed, 133 insertions, 59 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 483cb54781..c6673a446f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -748,9 +748,9 @@ checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7"
[[package]]
name = "h2"
-version = "0.3.6"
+version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c06815895acec637cd6ed6e9662c935b866d20a106f8361892893a7d9234964"
+checksum = "0c9de88456263e249e241fcd211d3954e2c9b0ef7ccfc235a444eb367cae3689"
dependencies = [
"bytes",
"fnv",
@@ -858,9 +858,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
-version = "0.14.13"
+version = "0.14.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "15d1cfb9e4f68655fa04c01f59edb405b6074a0f7118ea881e5026e4a1cd8593"
+checksum = "b7ec3e62bdc98a2f0393a5048e4c30ef659440ea6e0e572965103e72bd836f55"
dependencies = [
"bytes",
"futures-channel",
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index 7b2a7a4b9a..7cc97ddd39 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -1,4 +1,4 @@
-use crate::{abortable_stream::AbortableStream, rest::RequestCommand};
+use crate::{abortable_stream::AbortableStream, rest::RequestCommand, tls_stream::TlsStream};
use futures::{
channel::{mpsc, oneshot},
sink::SinkExt,
@@ -9,18 +9,15 @@ use hyper::{
service::Service,
Uri,
};
-use hyper_rustls::MaybeHttpsStream;
-use rustls::ServerName;
#[cfg(target_os = "android")]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
fmt,
future::Future,
- io::{self, BufReader},
+ io,
net::{IpAddr, SocketAddr},
pin::Pin,
str::{self, FromStr},
- sync::Arc,
task::{Context, Poll},
time::Duration,
};
@@ -28,10 +25,6 @@ use std::{
use tokio::net::TcpSocket;
use tokio::{net::TcpStream, runtime::Handle, time::timeout};
-use tokio_rustls::rustls;
-
-// New LetsEncrypt root certificate
-const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem");
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
@@ -44,7 +37,6 @@ pub struct HttpsConnectorWithSni {
service_tx: Option<mpsc::Sender<RequestCommand>>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- tls: Arc<rustls::ClientConfig>,
}
#[cfg(target_os = "android")]
@@ -56,15 +48,6 @@ impl HttpsConnectorWithSni {
sni_hostname: Option<String>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Self {
- let mut config = rustls::ClientConfig::builder()
- .with_safe_default_cipher_suites()
- .with_safe_default_kx_groups()
- .with_protocol_versions(&[&rustls::version::TLS13])
- .unwrap()
- .with_root_certificates(Self::read_cert_store())
- .with_no_client_auth();
- config.enable_sni = true;
-
HttpsConnectorWithSni {
next_socket_id: 0,
handle,
@@ -72,23 +55,9 @@ impl HttpsConnectorWithSni {
#[cfg(target_os = "android")]
socket_bypass_tx,
service_tx: None,
- tls: Arc::new(config),
}
}
- fn read_cert_store() -> rustls::RootCertStore {
- let mut cert_store = rustls::RootCertStore::empty();
-
- let certs = rustls_pemfile::certs(&mut BufReader::new(LE_ROOT_CERT))
- .expect("Failed to parse pem file");
- let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs);
- if num_failures > 0 || num_certs_added != 1 {
- panic!("Failed to add root cert");
- }
-
- cert_store
- }
-
/// Set a channel to register sockets with the request service.
pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) {
self.service_tx = Some(service_tx);
@@ -162,7 +131,7 @@ impl fmt::Debug for HttpsConnectorWithSni {
}
impl Service<Uri> for HttpsConnectorWithSni {
- type Response = MaybeHttpsStream<AbortableStream<TcpStream>>;
+ type Response = TlsStream<AbortableStream<TcpStream>>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@@ -172,7 +141,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
fn call(&mut self, uri: Uri) -> Self::Future {
- let tls_connector: tokio_rustls::TlsConnector = self.tls.clone().into();
let sni_hostname = self
.sni_hostname
.clone()
@@ -196,12 +164,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
let hostname = sni_hostname?;
- let host = ServerName::try_from(hostname.as_str()).map_err(|_| {
- io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("invalid hostname \"{}\"", hostname),
- )
- })?;
let addr = Self::resolve_address(&uri).await?;
let tokio_connection = Self::open_socket(
@@ -234,22 +196,9 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
});
}
-
- let tls_connection = tls_connector.connect(host, tcp_stream).await?;
-
- Ok(MaybeHttpsStream::Https(tls_connection))
+ Ok(TlsStream::connect_https(tcp_stream, &hostname).await?)
};
Box::pin(fut)
}
}
-
-#[cfg(test)]
-mod test {
- use super::HttpsConnectorWithSni;
-
- #[test]
- fn test_cert_loading() {
- let _certs = HttpsConnectorWithSni::read_cert_store();
- }
-}
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 9dfd139b01..fce57d3991 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -24,6 +24,7 @@ pub mod rest;
mod https_client_with_sni;
use crate::https_client_with_sni::HttpsConnectorWithSni;
mod abortable_stream;
+mod tls_stream;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
diff --git a/mullvad-rpc/src/tls_stream.rs b/mullvad-rpc/src/tls_stream.rs
new file mode 100644
index 0000000000..232bb39b92
--- /dev/null
+++ b/mullvad-rpc/src/tls_stream.rs
@@ -0,0 +1,124 @@
+//! Provides a TLS 1.3 stream with SNI and LE root cert only.
+use std::{
+ io::{self, ErrorKind},
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll},
+};
+
+use hyper::client::connect::{Connected, Connection};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_rustls::{
+ rustls::{self, ClientConfig, ServerName},
+ TlsConnector,
+};
+
+const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem");
+
+pub struct TlsStream<S: AsyncRead + AsyncWrite + Unpin> {
+ stream: Pin<Box<tokio_rustls::client::TlsStream<S>>>,
+}
+
+impl<S> TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ pub async fn connect_https(stream: S, domain: &str) -> io::Result<TlsStream<S>> {
+ lazy_static::lazy_static! {
+ static ref TLS_CONFIG: Arc<ClientConfig> = {
+ let config = ClientConfig::builder()
+ .with_safe_default_cipher_suites()
+ .with_safe_default_kx_groups()
+ .with_protocol_versions(&[&rustls::version::TLS13])
+ .unwrap()
+ .with_root_certificates(read_cert_store())
+ .with_no_client_auth();
+ Arc::new(config)
+ };
+ }
+
+ let connector = TlsConnector::from(TLS_CONFIG.clone());
+
+ let host = match ServerName::try_from(domain) {
+ Ok(n) => n,
+ Err(_) => {
+ return Err(io::Error::new(
+ ErrorKind::InvalidInput,
+ format!("invalid hostname \"{}\"", domain),
+ ));
+ }
+ };
+
+ let tls_stream = connector.connect(host, stream).await?;
+
+ Ok(TlsStream {
+ stream: Box::pin(tls_stream),
+ })
+ }
+}
+
+fn read_cert_store() -> rustls::RootCertStore {
+ let mut cert_store = rustls::RootCertStore::empty();
+
+ let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(LE_ROOT_CERT))
+ .expect("Failed to parse pem file");
+ let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs);
+ if num_failures > 0 || num_certs_added != 1 {
+ panic!("Failed to add root cert");
+ }
+
+ cert_store
+}
+
+impl<S> AsyncRead for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.stream.as_mut().poll_read(cx, buf)
+ }
+}
+
+impl<S> AsyncWrite for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.stream.as_mut().poll_write(cx, buf)
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ self.stream.as_mut().poll_flush(cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ self.stream.as_mut().poll_shutdown(cx)
+ }
+}
+
+impl<S> Connection for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn connected(&self) -> Connected {
+ Connected::new()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_cert_loading() {
+ let _certs = read_cert_store();
+ }
+}