summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-01-20 20:08:22 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-01-26 16:27:53 +0100
commit43a89edb9feaed2d5595c92424f29de3cc10999e (patch)
tree1a076417b2628846b1030c53d5cb20c988445db3
parent865b571f07517c73bbf30ba097e787a43cc62289 (diff)
downloadmullvadvpn-43a89edb9feaed2d5595c92424f29de3cc10999e.tar.xz
mullvadvpn-43a89edb9feaed2d5595c92424f29de3cc10999e.zip
Refactor API socket cancellation
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs111
-rw-r--r--mullvad-rpc/src/lib.rs14
-rw-r--r--mullvad-rpc/src/rest.rs48
3 files changed, 99 insertions, 74 deletions
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index 7cc97ddd39..e7a785601d 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -1,7 +1,12 @@
-use crate::{abortable_stream::AbortableStream, rest::RequestCommand, tls_stream::TlsStream};
+use crate::{
+ abortable_stream::{AbortableStream, AbortableStreamHandle},
+ tls_stream::TlsStream,
+};
+#[cfg(target_os = "android")]
+use futures::sink::SinkExt;
use futures::{
channel::{mpsc, oneshot},
- sink::SinkExt,
+ StreamExt,
};
use http::uri::Scheme;
use hyper::{
@@ -12,12 +17,14 @@ use hyper::{
#[cfg(target_os = "android")]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
+ collections::BTreeMap,
fmt,
future::Future,
io,
net::{IpAddr, SocketAddr},
pin::Pin,
str::{self, FromStr},
+ sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
@@ -28,17 +35,33 @@ use tokio::{net::TcpStream, runtime::Handle, time::timeout};
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
+#[derive(Clone)]
+pub struct HttpsConnectorWithSniHandle {
+ tx: mpsc::UnboundedSender<()>,
+}
+
+impl HttpsConnectorWithSniHandle {
+ /// Stop all streams produced by this connector
+ pub fn reset(&self) {
+ let _ = self.tx.unbounded_send(());
+ }
+}
+
/// A Connector for the `https` scheme.
#[derive(Clone)]
pub struct HttpsConnectorWithSni {
- next_socket_id: usize,
+ inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
handle: Handle,
sni_hostname: Option<String>,
- service_tx: Option<mpsc::Sender<RequestCommand>>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
+struct HttpsConnectorWithSniInner {
+ next_socket_id: usize,
+ stream_handles: BTreeMap<usize, AbortableStreamHandle>,
+}
+
#[cfg(target_os = "android")]
pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
@@ -47,25 +70,43 @@ impl HttpsConnectorWithSni {
handle: Handle,
sni_hostname: Option<String>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> Self {
- HttpsConnectorWithSni {
+ ) -> (Self, HttpsConnectorWithSniHandle) {
+ let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded();
+ let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner {
+ stream_handles: BTreeMap::new(),
next_socket_id: 0,
- handle,
- sni_hostname,
- #[cfg(target_os = "android")]
- socket_bypass_tx,
- service_tx: None,
- }
- }
+ }));
+
+ let inner_copy = inner.clone();
+ handle.spawn(async move {
+ // Handle requests by `HttpsConnectorWithSniHandle`s
+ while let Some(()) = rx.next().await {
+ let handles = {
+ let mut inner = inner_copy.lock().unwrap();
+ std::mem::take(&mut inner.stream_handles)
+ };
+ for (_, handle) in handles {
+ handle.close();
+ }
+ }
+ });
- /// Set a channel to register sockets with the request service.
- pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) {
- self.service_tx = Some(service_tx);
+ (
+ HttpsConnectorWithSni {
+ inner,
+ handle,
+ sni_hostname,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ },
+ HttpsConnectorWithSniHandle { tx },
+ )
}
fn next_id(&mut self) -> usize {
- let next_id = self.next_socket_id;
- self.next_socket_id = self.next_socket_id.wrapping_add(1);
+ let mut inner = self.inner.lock().unwrap();
+ let next_id = inner.next_socket_id;
+ inner.next_socket_id = inner.next_socket_id.wrapping_add(1);
next_id
}
@@ -148,8 +189,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
- let service_tx = self.service_tx.clone();
-
+ let inner = self.inner.clone();
let socket_id = self.next_id();
let handle = self.handle.clone();
#[cfg(target_os = "android")]
@@ -177,25 +217,22 @@ impl Service<Uri> for HttpsConnectorWithSni {
let (tcp_stream, socket_handle) =
AbortableStream::new(tokio_connection, Some(socket_shutdown_tx));
- if let Some(mut service_tx) = service_tx {
- if service_tx
- .send(RequestCommand::SocketOpened(socket_id, socket_handle))
- .await
- .is_err()
+
+ {
+ let mut inner = inner.lock().unwrap();
+ inner
+ .stream_handles
+ .insert(socket_id, socket_handle.clone());
+ }
+
+ handle.spawn(async move {
+ let _ = socket_shutdown_rx.await;
{
- log::error!("Failed to submit new socket to request service");
+ let mut inner = inner.lock().unwrap();
+ inner.stream_handles.remove(&socket_id);
}
- handle.spawn(async move {
- let _ = socket_shutdown_rx.await;
- if service_tx
- .send(RequestCommand::SocketClosed(socket_id))
- .await
- .is_err()
- {
- log::error!("Failed to send socket closure command to request service");
- }
- });
- }
+ });
+
Ok(TlsStream::connect_https(tcp_stream, &hostname).await?)
};
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index fce57d3991..39aaf20c2c 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -21,9 +21,8 @@ pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
pub mod rest;
-mod https_client_with_sni;
-use crate::https_client_with_sni::HttpsConnectorWithSni;
mod abortable_stream;
+mod https_client_with_sni;
mod tls_stream;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
@@ -228,18 +227,13 @@ impl MullvadRpcRuntime {
/// Creates a new request service and returns a handle to it.
fn new_request_service(&mut self, sni_hostname: Option<String>) -> rest::RequestServiceHandle {
- let https_connector = HttpsConnectorWithSni::new(
- self.handle.clone(),
- sni_hostname,
- #[cfg(target_os = "android")]
- self.socket_bypass_tx.clone(),
- );
-
let service = rest::RequestService::new(
- https_connector,
self.handle.clone(),
+ sni_hostname,
self.api_availability.handle(),
self.address_cache.clone(),
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
);
let handle = service.handle();
self.handle.spawn(service.into_future());
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 7b150e7807..41b637f84d 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,6 +1,9 @@
+#[cfg(target_os = "android")]
+pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
- abortable_stream::AbortableStreamHandle, address_cache::AddressCache,
- availability::ApiAvailabilityHandle, https_client_with_sni::HttpsConnectorWithSni,
+ address_cache::AddressCache,
+ availability::ApiAvailabilityHandle,
+ https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
};
use futures::{
channel::{mpsc, oneshot},
@@ -88,7 +91,7 @@ impl Error {
pub(crate) struct RequestService {
command_tx: mpsc::Sender<RequestCommand>,
command_rx: mpsc::Receiver<RequestCommand>,
- sockets: BTreeMap<usize, AbortableStreamHandle>,
+ connector_handle: HttpsConnectorWithSniHandle,
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
handle: Handle,
next_id: u64,
@@ -100,24 +103,30 @@ pub(crate) struct RequestService {
impl RequestService {
/// Constructs a new request service.
pub fn new(
- mut connector: HttpsConnectorWithSni,
handle: Handle,
+ sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestService {
- let (command_tx, command_rx) = mpsc::channel(1);
+ let (connector, connector_handle) = HttpsConnectorWithSni::new(
+ handle.clone(),
+ sni_hostname,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx.clone(),
+ );
- connector.set_service_tx(command_tx.clone());
+ let (command_tx, command_rx) = mpsc::channel(1);
let client = Client::builder().build(connector);
Self {
command_tx,
command_rx,
- sockets: BTreeMap::new(),
+ connector_handle,
client,
+ handle,
in_flight_requests: BTreeMap::new(),
next_id: 0,
- handle,
api_availability,
address_cache,
}
@@ -194,20 +203,12 @@ impl RequestService {
let _ = tx.send(RequestCommand::RequestFinished(id)).await;
};
- self.handle.spawn(future);
self.in_flight_requests.insert(id, abort_handle);
- }
-
- RequestCommand::SocketOpened(id, socket) => {
- self.sockets.insert(id, socket);
- }
- RequestCommand::SocketClosed(id) => {
- self.sockets.remove(&id);
+ self.handle.spawn(future);
}
RequestCommand::RequestFinished(id) => {
self.in_flight_requests.remove(&id);
}
-
RequestCommand::Reset(tx) => {
self.reset();
let _ = tx.send(());
@@ -216,17 +217,12 @@ impl RequestService {
}
fn reset(&mut self) {
- let old_requests = mem::replace(&mut self.in_flight_requests, BTreeMap::new());
- for (_, abort_handle) in old_requests.into_iter() {
+ let old_requests = mem::take(&mut self.in_flight_requests);
+ for (_, abort_handle) in old_requests {
abort_handle.abort();
}
- let old_sockets = mem::replace(&mut self.sockets, BTreeMap::new());
- for (_, socket) in old_sockets.into_iter() {
- socket.close();
- }
-
- self.next_id = 0;
+ self.connector_handle.reset();
}
fn id(&mut self) -> u64 {
@@ -296,8 +292,6 @@ pub(crate) enum RequestCommand {
oneshot::Sender<std::result::Result<Response, Error>>,
),
RequestFinished(u64),
- SocketOpened(usize, AbortableStreamHandle),
- SocketClosed(usize),
Reset(oneshot::Sender<()>),
}