summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-01-26 14:12:46 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-01-26 16:28:15 +0100
commit6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a (patch)
tree1e389c4c3b2b7a062bc75bc3be33798d1f560b24
parentb29b5ab4a6d3bd33d5d5242ca0593cec55f77045 (diff)
downloadmullvadvpn-6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a.tar.xz
mullvadvpn-6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a.zip
Remove socket map from HTTP connector and abort handle
-rw-r--r--mullvad-rpc/src/abortable_stream.rs70
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs45
2 files changed, 32 insertions, 83 deletions
diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs
index 8754dc8cf1..160e329dfb 100644
--- a/mullvad-rpc/src/abortable_stream.rs
+++ b/mullvad-rpc/src/abortable_stream.rs
@@ -15,7 +15,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Clone, Debug)]
pub struct AbortableStreamHandle {
tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
- notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
}
impl AbortableStreamHandle {
@@ -23,16 +22,21 @@ impl AbortableStreamHandle {
if let Some(tx) = self.tx.lock().unwrap().take() {
let _ = tx.send(());
}
- if let Some(tx) = self.notify_tx.lock().unwrap().take() {
- let _ = tx.send(());
- }
+ }
+
+ /// Returns whether the stream has already stopped on its own.
+ pub fn is_closed(&self) -> bool {
+ self.tx
+ .lock()
+ .unwrap()
+ .as_ref()
+ .map(|tx| tx.is_canceled())
+ .unwrap_or(true)
}
}
pub struct AbortableStream<S: Unpin> {
stream: S,
- /// Notified when the stream is shut down.
- notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
shutdown_rx: oneshot::Receiver<()>,
}
@@ -40,20 +44,14 @@ impl<S> AbortableStream<S>
where
S: Unpin + Send + 'static,
{
- pub fn new(
- stream: S,
- notify_shutdown_tx: Option<oneshot::Sender<()>>,
- ) -> (Self, AbortableStreamHandle) {
+ pub fn new(stream: S) -> (Self, AbortableStreamHandle) {
let (tx, rx) = oneshot::channel();
- let notify_tx = Arc::new(Mutex::new(notify_shutdown_tx));
let stream_handle = AbortableStreamHandle {
tx: Arc::new(Mutex::new(Some(tx))),
- notify_tx: notify_tx.clone(),
};
(
Self {
stream,
- notify_tx,
shutdown_rx: rx,
},
stream_handle,
@@ -61,23 +59,6 @@ where
}
}
-impl<S: Unpin> AbortableStream<S> {
- fn maybe_send_shutdown_signal(&mut self) {
- if let Some(tx) = self.notify_tx.lock().unwrap().take() {
- let _ = tx.send(());
- }
- }
-}
-
-impl<S> Drop for AbortableStream<S>
-where
- S: Unpin,
-{
- fn drop(&mut self) {
- self.maybe_send_shutdown_signal();
- }
-}
-
impl<S> AsyncWrite for AbortableStream<S>
where
S: AsyncWrite + Unpin + Send + 'static,
@@ -88,7 +69,6 @@ where
buf: &[u8],
) -> Poll<io::Result<usize>> {
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
- self.maybe_send_shutdown_signal();
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"stream is closed",
@@ -99,7 +79,6 @@ where
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
- self.maybe_send_shutdown_signal();
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"stream is closed",
@@ -123,7 +102,6 @@ where
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
- self.maybe_send_shutdown_signal();
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"stream is closed",
@@ -156,7 +134,7 @@ mod test {
let (client, _server) = tokio::io::duplex(64);
runtime.block_on(async move {
- let (mut stream, abort_handle) = AbortableStream::new(client, None);
+ let (mut stream, abort_handle) = AbortableStream::new(client);
let stream_task = tokio::spawn(async move {
let mut buf = vec![];
@@ -173,45 +151,45 @@ mod test {
});
}
- /// Test whether the shutdown signal is sent when the stream is explicitly closed.
+ /// Test the `AbortableStreamHandle::is_closed` method when explicitly closed.
#[test]
fn test_shutdown_signal() {
let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
let (client, _server) = tokio::io::duplex(64);
- let (shutdown_tx, shutdown_rx) = oneshot::channel();
runtime.block_on(async move {
- let (_stream, abort_handle) = AbortableStream::new(client, Some(shutdown_tx));
+ let (_stream, abort_handle) = AbortableStream::new(client);
+ let abort_handle_2 = abort_handle.clone();
+ assert!(!abort_handle_2.is_closed());
abort_handle.close();
- assert!(tokio::time::timeout(Duration::from_secs(1), shutdown_rx)
- .await
- .unwrap()
- .is_ok());
+ assert!(abort_handle_2.is_closed());
});
}
- /// Test whether the shutdown signal is sent when the stream stops on its own.
+ /// Test the `AbortableStreamHandle::is_closed` method when the stream stops on its own.
#[test]
fn test_shutdown_signal_normal() {
let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
let (client, server) = tokio::io::duplex(64);
- let (shutdown_tx, shutdown_rx) = oneshot::channel();
runtime.block_on(async move {
- let (mut stream, _abort_handle) = AbortableStream::new(client, Some(shutdown_tx));
+ let (mut stream, abort_handle) = AbortableStream::new(client);
- tokio::spawn(async move {
+ assert!(!abort_handle.is_closed());
+
+ let stream_task = tokio::spawn(async move {
drop(server);
let mut buf = vec![];
stream.read_to_end(&mut buf).await
});
- assert!(tokio::time::timeout(Duration::from_secs(1), shutdown_rx)
+ assert!(tokio::time::timeout(Duration::from_secs(1), stream_task)
.await
.unwrap()
.is_ok());
+ assert!(abort_handle.is_closed());
});
}
}
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index e7a785601d..4d8108f89b 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -2,12 +2,9 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
tls_stream::TlsStream,
};
+use futures::{channel::mpsc, StreamExt};
#[cfg(target_os = "android")]
-use futures::sink::SinkExt;
-use futures::{
- channel::{mpsc, oneshot},
- StreamExt,
-};
+use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::{
client::connect::dns::{GaiResolver, Name},
@@ -17,7 +14,6 @@ use hyper::{
#[cfg(target_os = "android")]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
- collections::BTreeMap,
fmt,
future::Future,
io,
@@ -51,15 +47,13 @@ impl HttpsConnectorWithSniHandle {
#[derive(Clone)]
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
- handle: Handle,
sni_hostname: Option<String>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
struct HttpsConnectorWithSniInner {
- next_socket_id: usize,
- stream_handles: BTreeMap<usize, AbortableStreamHandle>,
+ stream_handles: Vec<AbortableStreamHandle>,
}
#[cfg(target_os = "android")]
@@ -73,8 +67,7 @@ impl HttpsConnectorWithSni {
) -> (Self, HttpsConnectorWithSniHandle) {
let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded();
let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner {
- stream_handles: BTreeMap::new(),
- next_socket_id: 0,
+ stream_handles: vec![],
}));
let inner_copy = inner.clone();
@@ -85,7 +78,7 @@ impl HttpsConnectorWithSni {
let mut inner = inner_copy.lock().unwrap();
std::mem::take(&mut inner.stream_handles)
};
- for (_, handle) in handles {
+ for handle in handles {
handle.close();
}
}
@@ -94,7 +87,6 @@ impl HttpsConnectorWithSni {
(
HttpsConnectorWithSni {
inner,
- handle,
sni_hostname,
#[cfg(target_os = "android")]
socket_bypass_tx,
@@ -103,13 +95,6 @@ impl HttpsConnectorWithSni {
)
}
- fn next_id(&mut self) -> usize {
- let mut inner = self.inner.lock().unwrap();
- let next_id = inner.next_socket_id;
- inner.next_socket_id = inner.next_socket_id.wrapping_add(1);
- next_id
- }
-
#[cfg(not(target_os = "android"))]
async fn open_socket(addr: SocketAddr) -> std::io::Result<TcpStream> {
timeout(CONNECT_TIMEOUT, TcpStream::connect(addr))
@@ -190,8 +175,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
let inner = self.inner.clone();
- let socket_id = self.next_id();
- let handle = self.handle.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
@@ -213,26 +196,14 @@ impl Service<Uri> for HttpsConnectorWithSni {
)
.await?;
- let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel();
-
- let (tcp_stream, socket_handle) =
- AbortableStream::new(tokio_connection, Some(socket_shutdown_tx));
+ let (tcp_stream, socket_handle) = AbortableStream::new(tokio_connection);
{
let mut inner = inner.lock().unwrap();
- inner
- .stream_handles
- .insert(socket_id, socket_handle.clone());
+ inner.stream_handles.retain(|handle| !handle.is_closed());
+ inner.stream_handles.push(socket_handle);
}
- handle.spawn(async move {
- let _ = socket_shutdown_rx.await;
- {
- let mut inner = inner.lock().unwrap();
- inner.stream_handles.remove(&socket_id);
- }
- });
-
Ok(TlsStream::connect_https(tcp_stream, &hostname).await?)
};