diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-01-20 20:08:22 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-01-26 16:27:53 +0100 |
| commit | 43a89edb9feaed2d5595c92424f29de3cc10999e (patch) | |
| tree | 1a076417b2628846b1030c53d5cb20c988445db3 | |
| parent | 865b571f07517c73bbf30ba097e787a43cc62289 (diff) | |
| download | mullvadvpn-43a89edb9feaed2d5595c92424f29de3cc10999e.tar.xz mullvadvpn-43a89edb9feaed2d5595c92424f29de3cc10999e.zip | |
Refactor API socket cancellation
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 111 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 14 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 48 |
3 files changed, 99 insertions, 74 deletions
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index 7cc97ddd39..e7a785601d 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -1,7 +1,12 @@ -use crate::{abortable_stream::AbortableStream, rest::RequestCommand, tls_stream::TlsStream}; +use crate::{ + abortable_stream::{AbortableStream, AbortableStreamHandle}, + tls_stream::TlsStream, +}; +#[cfg(target_os = "android")] +use futures::sink::SinkExt; use futures::{ channel::{mpsc, oneshot}, - sink::SinkExt, + StreamExt, }; use http::uri::Scheme; use hyper::{ @@ -12,12 +17,14 @@ use hyper::{ #[cfg(target_os = "android")] use std::os::unix::io::{AsRawFd, RawFd}; use std::{ + collections::BTreeMap, fmt, future::Future, io, net::{IpAddr, SocketAddr}, pin::Pin, str::{self, FromStr}, + sync::{Arc, Mutex}, task::{Context, Poll}, time::Duration, }; @@ -28,17 +35,33 @@ use tokio::{net::TcpStream, runtime::Handle, time::timeout}; const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +#[derive(Clone)] +pub struct HttpsConnectorWithSniHandle { + tx: mpsc::UnboundedSender<()>, +} + +impl HttpsConnectorWithSniHandle { + /// Stop all streams produced by this connector + pub fn reset(&self) { + let _ = self.tx.unbounded_send(()); + } +} + /// A Connector for the `https` scheme. #[derive(Clone)] pub struct HttpsConnectorWithSni { - next_socket_id: usize, + inner: Arc<Mutex<HttpsConnectorWithSniInner>>, handle: Handle, sni_hostname: Option<String>, - service_tx: Option<mpsc::Sender<RequestCommand>>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } +struct HttpsConnectorWithSniInner { + next_socket_id: usize, + stream_handles: BTreeMap<usize, AbortableStreamHandle>, +} + #[cfg(target_os = "android")] pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>); @@ -47,25 +70,43 @@ impl HttpsConnectorWithSni { handle: Handle, sni_hostname: Option<String>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - ) -> Self { - HttpsConnectorWithSni { + ) -> (Self, HttpsConnectorWithSniHandle) { + let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded(); + let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner { + stream_handles: BTreeMap::new(), next_socket_id: 0, - handle, - sni_hostname, - #[cfg(target_os = "android")] - socket_bypass_tx, - service_tx: None, - } - } + })); + + let inner_copy = inner.clone(); + handle.spawn(async move { + // Handle requests by `HttpsConnectorWithSniHandle`s + while let Some(()) = rx.next().await { + let handles = { + let mut inner = inner_copy.lock().unwrap(); + std::mem::take(&mut inner.stream_handles) + }; + for (_, handle) in handles { + handle.close(); + } + } + }); - /// Set a channel to register sockets with the request service. - pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) { - self.service_tx = Some(service_tx); + ( + HttpsConnectorWithSni { + inner, + handle, + sni_hostname, + #[cfg(target_os = "android")] + socket_bypass_tx, + }, + HttpsConnectorWithSniHandle { tx }, + ) } fn next_id(&mut self) -> usize { - let next_id = self.next_socket_id; - self.next_socket_id = self.next_socket_id.wrapping_add(1); + let mut inner = self.inner.lock().unwrap(); + let next_id = inner.next_socket_id; + inner.next_socket_id = inner.next_socket_id.wrapping_add(1); next_id } @@ -148,8 +189,7 @@ impl Service<Uri> for HttpsConnectorWithSni { .ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") }); - let service_tx = self.service_tx.clone(); - + let inner = self.inner.clone(); let socket_id = self.next_id(); let handle = self.handle.clone(); #[cfg(target_os = "android")] @@ -177,25 +217,22 @@ impl Service<Uri> for HttpsConnectorWithSni { let (tcp_stream, socket_handle) = AbortableStream::new(tokio_connection, Some(socket_shutdown_tx)); - if let Some(mut service_tx) = service_tx { - if service_tx - .send(RequestCommand::SocketOpened(socket_id, socket_handle)) - .await - .is_err() + + { + let mut inner = inner.lock().unwrap(); + inner + .stream_handles + .insert(socket_id, socket_handle.clone()); + } + + handle.spawn(async move { + let _ = socket_shutdown_rx.await; { - log::error!("Failed to submit new socket to request service"); + let mut inner = inner.lock().unwrap(); + inner.stream_handles.remove(&socket_id); } - handle.spawn(async move { - let _ = socket_shutdown_rx.await; - if service_tx - .send(RequestCommand::SocketClosed(socket_id)) - .await - .is_err() - { - log::error!("Failed to send socket closure command to request service"); - } - }); - } + }); + Ok(TlsStream::connect_https(tcp_stream, &hostname).await?) }; diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index fce57d3991..39aaf20c2c 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -21,9 +21,8 @@ pub mod availability; use availability::{ApiAvailability, ApiAvailabilityHandle}; pub mod rest; -mod https_client_with_sni; -use crate::https_client_with_sni::HttpsConnectorWithSni; mod abortable_stream; +mod https_client_with_sni; mod tls_stream; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; @@ -228,18 +227,13 @@ impl MullvadRpcRuntime { /// Creates a new request service and returns a handle to it. fn new_request_service(&mut self, sni_hostname: Option<String>) -> rest::RequestServiceHandle { - let https_connector = HttpsConnectorWithSni::new( - self.handle.clone(), - sni_hostname, - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ); - let service = rest::RequestService::new( - https_connector, self.handle.clone(), + sni_hostname, self.api_availability.handle(), self.address_cache.clone(), + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), ); let handle = service.handle(); self.handle.spawn(service.into_future()); diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 7b150e7807..41b637f84d 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,6 +1,9 @@ +#[cfg(target_os = "android")] +pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ - abortable_stream::AbortableStreamHandle, address_cache::AddressCache, - availability::ApiAvailabilityHandle, https_client_with_sni::HttpsConnectorWithSni, + address_cache::AddressCache, + availability::ApiAvailabilityHandle, + https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, }; use futures::{ channel::{mpsc, oneshot}, @@ -88,7 +91,7 @@ impl Error { pub(crate) struct RequestService { command_tx: mpsc::Sender<RequestCommand>, command_rx: mpsc::Receiver<RequestCommand>, - sockets: BTreeMap<usize, AbortableStreamHandle>, + connector_handle: HttpsConnectorWithSniHandle, client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, handle: Handle, next_id: u64, @@ -100,24 +103,30 @@ pub(crate) struct RequestService { impl RequestService { /// Constructs a new request service. pub fn new( - mut connector: HttpsConnectorWithSni, handle: Handle, + sni_hostname: Option<String>, api_availability: ApiAvailabilityHandle, address_cache: AddressCache, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> RequestService { - let (command_tx, command_rx) = mpsc::channel(1); + let (connector, connector_handle) = HttpsConnectorWithSni::new( + handle.clone(), + sni_hostname, + #[cfg(target_os = "android")] + socket_bypass_tx.clone(), + ); - connector.set_service_tx(command_tx.clone()); + let (command_tx, command_rx) = mpsc::channel(1); let client = Client::builder().build(connector); Self { command_tx, command_rx, - sockets: BTreeMap::new(), + connector_handle, client, + handle, in_flight_requests: BTreeMap::new(), next_id: 0, - handle, api_availability, address_cache, } @@ -194,20 +203,12 @@ impl RequestService { let _ = tx.send(RequestCommand::RequestFinished(id)).await; }; - self.handle.spawn(future); self.in_flight_requests.insert(id, abort_handle); - } - - RequestCommand::SocketOpened(id, socket) => { - self.sockets.insert(id, socket); - } - RequestCommand::SocketClosed(id) => { - self.sockets.remove(&id); + self.handle.spawn(future); } RequestCommand::RequestFinished(id) => { self.in_flight_requests.remove(&id); } - RequestCommand::Reset(tx) => { self.reset(); let _ = tx.send(()); @@ -216,17 +217,12 @@ impl RequestService { } fn reset(&mut self) { - let old_requests = mem::replace(&mut self.in_flight_requests, BTreeMap::new()); - for (_, abort_handle) in old_requests.into_iter() { + let old_requests = mem::take(&mut self.in_flight_requests); + for (_, abort_handle) in old_requests { abort_handle.abort(); } - let old_sockets = mem::replace(&mut self.sockets, BTreeMap::new()); - for (_, socket) in old_sockets.into_iter() { - socket.close(); - } - - self.next_id = 0; + self.connector_handle.reset(); } fn id(&mut self) -> u64 { @@ -296,8 +292,6 @@ pub(crate) enum RequestCommand { oneshot::Sender<std::result::Result<Response, Error>>, ), RequestFinished(u64), - SocketOpened(usize, AbortableStreamHandle), - SocketClosed(usize), Reset(oneshot::Sender<()>), } |
