diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-02-07 12:30:15 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-01 15:30:25 +0100 |
| commit | bdb19a3723932d8da927bac1456d91afb2de8c4a (patch) | |
| tree | 51a3b556aa43f5678899cfb5569d0d065a883d75 | |
| parent | 0e3275fb181f2b17ebde7f8a2b68713f491c73d4 (diff) | |
| download | mullvadvpn-bdb19a3723932d8da927bac1456d91afb2de8c4a.tar.xz mullvadvpn-bdb19a3723932d8da927bac1456d91afb2de8c4a.zip | |
Drop in-flight REST requests implicitly
| -rw-r--r-- | mullvad-jni/src/lib.rs | 2 | ||||
| -rw-r--r-- | mullvad-rpc/src/abortable_stream.rs | 10 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 79 |
3 files changed, 40 insertions, 51 deletions
diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 72c75e7360..646988e11b 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -940,7 +940,7 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_updateR fn log_request_error(request: &str, error: &daemon_interface::Error) { match error { - daemon_interface::Error::RpcError(RestError::Aborted(_)) => { + daemon_interface::Error::RpcError(RestError::Aborted) => { log::debug!("Request to {} cancelled", request); } error => { diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs index 160e329dfb..af217c5768 100644 --- a/mullvad-rpc/src/abortable_stream.rs +++ b/mullvad-rpc/src/abortable_stream.rs @@ -12,6 +12,10 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[derive(err_derive::Error, Debug)] +#[error(display = "Stream is closed")] +pub struct Aborted(()); + #[derive(Clone, Debug)] pub struct AbortableStreamHandle { tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, @@ -71,7 +75,7 @@ where if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, - "stream is closed", + Aborted(()), ))); } Pin::new(&mut self.stream).poll_write(cx, buf) @@ -81,7 +85,7 @@ where if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, - "stream is closed", + Aborted(()), ))); } Pin::new(&mut self.stream).poll_flush(cx) @@ -104,7 +108,7 @@ where if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, - "stream is closed", + Aborted(()), ))); } Pin::new(&mut self.stream).poll_read(cx, buf) diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 146d92f395..8bb5efc72b 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -8,7 +8,6 @@ use crate::{ }; use futures::{ channel::{mpsc, oneshot}, - future::{abortable, AbortHandle, Aborted}, sink::SinkExt, stream::StreamExt, Stream, TryFutureExt, @@ -19,9 +18,7 @@ use hyper::{ Method, Uri, }; use std::{ - collections::BTreeMap, future::Future, - mem, net::SocketAddr, str::FromStr, time::{Duration, Instant}, @@ -45,7 +42,7 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); #[derive(err_derive::Error, Debug)] pub enum Error { #[error(display = "Request cancelled")] - Aborted(Aborted), + Aborted, #[error(display = "Hyper error")] HyperError(#[error(source)] hyper::Error), @@ -84,6 +81,26 @@ impl Error { _ => false, } } + + /// Returns a new instance for which `abortable_stream::Aborted` is mapped to `Self::Aborted`. + fn map_aborted(self) -> Self { + if let Error::HyperError(error) = &self { + use std::error::Error; + let mut source = error.source(); + while let Some(error) = source { + let io_error: Option<&std::io::Error> = error.downcast_ref(); + if let Some(io_error) = io_error { + let abort_error: Option<&crate::abortable_stream::Aborted> = + io_error.get_ref().and_then(|inner| inner.downcast_ref()); + if abort_error.is_some() { + return Self::Aborted; + } + } + source = error.source(); + } + } + self + } } /// A service that executes HTTP requests, allowing for on-demand termination of all in-flight @@ -97,8 +114,6 @@ pub(crate) struct RequestService< command_rx: mpsc::Receiver<RequestCommand>, connector_handle: HttpsConnectorWithSniHandle, client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, - next_id: u64, - in_flight_requests: BTreeMap<u64, AbortHandle>, proxy_config_provider: T, new_address_callback: F, address_cache: AddressCache, @@ -140,8 +155,6 @@ impl< command_rx, connector_handle, client, - in_flight_requests: BTreeMap::new(), - next_id: 0, proxy_config_provider, new_address_callback, address_cache, @@ -161,7 +174,6 @@ impl< async fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { - let id = self.id(); let mut tx = self.command_tx.clone(); let timeout = request.timeout(); @@ -171,18 +183,17 @@ impl< let suspend_fut = api_availability.wait_for_unsuspend(); let request_fut = self.client.request(hyper_request).map_err(Error::from); - let (request_future, abort_handle) = abortable(async move { + let request_future = async move { let _ = suspend_fut.await; request_fut.await - }); + }; let future = async move { - let response = - tokio::time::timeout(timeout, request_future.map_err(Error::Aborted)) - .await - .map_err(Error::TimeoutError); + let response = tokio::time::timeout(timeout, request_future) + .await + .map_err(Error::TimeoutError); - let response = flatten_result(flatten_result(response)); + let response = flatten_result(response).map_err(|error| error.map_aborted()); if let Err(err) = &response { if err.is_network_error() && !api_availability.get_state().is_offline() { @@ -196,18 +207,11 @@ impl< "Failed to send response to caller, caller channel is shut down" ); } - let _ = tx.send(RequestCommand::RequestFinished(id)).await; }; - - self.in_flight_requests.insert(id, abort_handle); tokio::spawn(future); } - RequestCommand::RequestFinished(id) => { - self.in_flight_requests.remove(&id); - } - RequestCommand::Reset(tx) => { - self.reset(); - let _ = tx.send(()); + RequestCommand::Reset => { + self.connector_handle.reset(); } RequestCommand::NextApiConfig => { if let Some(new_config) = self.proxy_config_provider.next().await { @@ -224,26 +228,11 @@ impl< } } - fn reset(&mut self) { - let old_requests = mem::take(&mut self.in_flight_requests); - for (_, abort_handle) in old_requests { - abort_handle.abort(); - } - - self.connector_handle.reset(); - } - - fn id(&mut self) -> u64 { - let id = self.next_id; - self.next_id = id.wrapping_add(1); - id - } - async fn into_future(mut self) { while let Some(command) = self.command_rx.next().await { self.process_command(command).await; } - self.reset(); + self.connector_handle.reset(); } } @@ -257,10 +246,7 @@ impl RequestServiceHandle { /// Resets the corresponding RequestService, dropping all in-flight requests. pub async fn reset(&self) { let mut tx = self.tx.clone(); - let (done_tx, done_rx) = oneshot::channel(); - - let _ = tx.send(RequestCommand::Reset(done_tx)).await; - let _ = done_rx.await; + let _ = tx.send(RequestCommand::Reset).await; } /// Submits a `RestRequest` for exectuion to the request service. @@ -281,8 +267,7 @@ pub(crate) enum RequestCommand { RestRequest, oneshot::Sender<std::result::Result<Response, Error>>, ), - RequestFinished(u64), - Reset(oneshot::Sender<()>), + Reset, NextApiConfig, } |
