summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-02-07 12:30:15 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-01 15:30:25 +0100
commitbdb19a3723932d8da927bac1456d91afb2de8c4a (patch)
tree51a3b556aa43f5678899cfb5569d0d065a883d75
parent0e3275fb181f2b17ebde7f8a2b68713f491c73d4 (diff)
downloadmullvadvpn-bdb19a3723932d8da927bac1456d91afb2de8c4a.tar.xz
mullvadvpn-bdb19a3723932d8da927bac1456d91afb2de8c4a.zip
Drop in-flight REST requests implicitly
-rw-r--r--mullvad-jni/src/lib.rs2
-rw-r--r--mullvad-rpc/src/abortable_stream.rs10
-rw-r--r--mullvad-rpc/src/rest.rs79
3 files changed, 40 insertions, 51 deletions
diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs
index 72c75e7360..646988e11b 100644
--- a/mullvad-jni/src/lib.rs
+++ b/mullvad-jni/src/lib.rs
@@ -940,7 +940,7 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_updateR
fn log_request_error(request: &str, error: &daemon_interface::Error) {
match error {
- daemon_interface::Error::RpcError(RestError::Aborted(_)) => {
+ daemon_interface::Error::RpcError(RestError::Aborted) => {
log::debug!("Request to {} cancelled", request);
}
error => {
diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs
index 160e329dfb..af217c5768 100644
--- a/mullvad-rpc/src/abortable_stream.rs
+++ b/mullvad-rpc/src/abortable_stream.rs
@@ -12,6 +12,10 @@ use std::{
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+#[derive(err_derive::Error, Debug)]
+#[error(display = "Stream is closed")]
+pub struct Aborted(());
+
#[derive(Clone, Debug)]
pub struct AbortableStreamHandle {
tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
@@ -71,7 +75,7 @@ where
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
- "stream is closed",
+ Aborted(()),
)));
}
Pin::new(&mut self.stream).poll_write(cx, buf)
@@ -81,7 +85,7 @@ where
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
- "stream is closed",
+ Aborted(()),
)));
}
Pin::new(&mut self.stream).poll_flush(cx)
@@ -104,7 +108,7 @@ where
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
- "stream is closed",
+ Aborted(()),
)));
}
Pin::new(&mut self.stream).poll_read(cx, buf)
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 146d92f395..8bb5efc72b 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -8,7 +8,6 @@ use crate::{
};
use futures::{
channel::{mpsc, oneshot},
- future::{abortable, AbortHandle, Aborted},
sink::SinkExt,
stream::StreamExt,
Stream, TryFutureExt,
@@ -19,9 +18,7 @@ use hyper::{
Method, Uri,
};
use std::{
- collections::BTreeMap,
future::Future,
- mem,
net::SocketAddr,
str::FromStr,
time::{Duration, Instant},
@@ -45,7 +42,7 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(err_derive::Error, Debug)]
pub enum Error {
#[error(display = "Request cancelled")]
- Aborted(Aborted),
+ Aborted,
#[error(display = "Hyper error")]
HyperError(#[error(source)] hyper::Error),
@@ -84,6 +81,26 @@ impl Error {
_ => false,
}
}
+
+ /// Returns a new instance for which `abortable_stream::Aborted` is mapped to `Self::Aborted`.
+ fn map_aborted(self) -> Self {
+ if let Error::HyperError(error) = &self {
+ use std::error::Error;
+ let mut source = error.source();
+ while let Some(error) = source {
+ let io_error: Option<&std::io::Error> = error.downcast_ref();
+ if let Some(io_error) = io_error {
+ let abort_error: Option<&crate::abortable_stream::Aborted> =
+ io_error.get_ref().and_then(|inner| inner.downcast_ref());
+ if abort_error.is_some() {
+ return Self::Aborted;
+ }
+ }
+ source = error.source();
+ }
+ }
+ self
+ }
}
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
@@ -97,8 +114,6 @@ pub(crate) struct RequestService<
command_rx: mpsc::Receiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
- next_id: u64,
- in_flight_requests: BTreeMap<u64, AbortHandle>,
proxy_config_provider: T,
new_address_callback: F,
address_cache: AddressCache,
@@ -140,8 +155,6 @@ impl<
command_rx,
connector_handle,
client,
- in_flight_requests: BTreeMap::new(),
- next_id: 0,
proxy_config_provider,
new_address_callback,
address_cache,
@@ -161,7 +174,6 @@ impl<
async fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
- let id = self.id();
let mut tx = self.command_tx.clone();
let timeout = request.timeout();
@@ -171,18 +183,17 @@ impl<
let suspend_fut = api_availability.wait_for_unsuspend();
let request_fut = self.client.request(hyper_request).map_err(Error::from);
- let (request_future, abort_handle) = abortable(async move {
+ let request_future = async move {
let _ = suspend_fut.await;
request_fut.await
- });
+ };
let future = async move {
- let response =
- tokio::time::timeout(timeout, request_future.map_err(Error::Aborted))
- .await
- .map_err(Error::TimeoutError);
+ let response = tokio::time::timeout(timeout, request_future)
+ .await
+ .map_err(Error::TimeoutError);
- let response = flatten_result(flatten_result(response));
+ let response = flatten_result(response).map_err(|error| error.map_aborted());
if let Err(err) = &response {
if err.is_network_error() && !api_availability.get_state().is_offline() {
@@ -196,18 +207,11 @@ impl<
"Failed to send response to caller, caller channel is shut down"
);
}
- let _ = tx.send(RequestCommand::RequestFinished(id)).await;
};
-
- self.in_flight_requests.insert(id, abort_handle);
tokio::spawn(future);
}
- RequestCommand::RequestFinished(id) => {
- self.in_flight_requests.remove(&id);
- }
- RequestCommand::Reset(tx) => {
- self.reset();
- let _ = tx.send(());
+ RequestCommand::Reset => {
+ self.connector_handle.reset();
}
RequestCommand::NextApiConfig => {
if let Some(new_config) = self.proxy_config_provider.next().await {
@@ -224,26 +228,11 @@ impl<
}
}
- fn reset(&mut self) {
- let old_requests = mem::take(&mut self.in_flight_requests);
- for (_, abort_handle) in old_requests {
- abort_handle.abort();
- }
-
- self.connector_handle.reset();
- }
-
- fn id(&mut self) -> u64 {
- let id = self.next_id;
- self.next_id = id.wrapping_add(1);
- id
- }
-
async fn into_future(mut self) {
while let Some(command) = self.command_rx.next().await {
self.process_command(command).await;
}
- self.reset();
+ self.connector_handle.reset();
}
}
@@ -257,10 +246,7 @@ impl RequestServiceHandle {
/// Resets the corresponding RequestService, dropping all in-flight requests.
pub async fn reset(&self) {
let mut tx = self.tx.clone();
- let (done_tx, done_rx) = oneshot::channel();
-
- let _ = tx.send(RequestCommand::Reset(done_tx)).await;
- let _ = done_rx.await;
+ let _ = tx.send(RequestCommand::Reset).await;
}
/// Submits a `RestRequest` for exectuion to the request service.
@@ -281,8 +267,7 @@ pub(crate) enum RequestCommand {
RestRequest,
oneshot::Sender<std::result::Result<Response, Error>>,
),
- RequestFinished(u64),
- Reset(oneshot::Sender<()>),
+ Reset,
NextApiConfig,
}