diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-01-26 14:12:46 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-01-26 16:28:15 +0100 |
| commit | 6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a (patch) | |
| tree | 1e389c4c3b2b7a062bc75bc3be33798d1f560b24 | |
| parent | b29b5ab4a6d3bd33d5d5242ca0593cec55f77045 (diff) | |
| download | mullvadvpn-6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a.tar.xz mullvadvpn-6d7ad13e51f884006a94a2b7fc87655c2fd2cd0a.zip | |
Remove socket map from HTTP connector and abort handle
| -rw-r--r-- | mullvad-rpc/src/abortable_stream.rs | 70 | ||||
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 45 |
2 files changed, 32 insertions, 83 deletions
diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs index 8754dc8cf1..160e329dfb 100644 --- a/mullvad-rpc/src/abortable_stream.rs +++ b/mullvad-rpc/src/abortable_stream.rs @@ -15,7 +15,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[derive(Clone, Debug)] pub struct AbortableStreamHandle { tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, - notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, } impl AbortableStreamHandle { @@ -23,16 +22,21 @@ impl AbortableStreamHandle { if let Some(tx) = self.tx.lock().unwrap().take() { let _ = tx.send(()); } - if let Some(tx) = self.notify_tx.lock().unwrap().take() { - let _ = tx.send(()); - } + } + + /// Returns whether the stream has already stopped on its own. + pub fn is_closed(&self) -> bool { + self.tx + .lock() + .unwrap() + .as_ref() + .map(|tx| tx.is_canceled()) + .unwrap_or(true) } } pub struct AbortableStream<S: Unpin> { stream: S, - /// Notified when the stream is shut down. - notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, shutdown_rx: oneshot::Receiver<()>, } @@ -40,20 +44,14 @@ impl<S> AbortableStream<S> where S: Unpin + Send + 'static, { - pub fn new( - stream: S, - notify_shutdown_tx: Option<oneshot::Sender<()>>, - ) -> (Self, AbortableStreamHandle) { + pub fn new(stream: S) -> (Self, AbortableStreamHandle) { let (tx, rx) = oneshot::channel(); - let notify_tx = Arc::new(Mutex::new(notify_shutdown_tx)); let stream_handle = AbortableStreamHandle { tx: Arc::new(Mutex::new(Some(tx))), - notify_tx: notify_tx.clone(), }; ( Self { stream, - notify_tx, shutdown_rx: rx, }, stream_handle, @@ -61,23 +59,6 @@ where } } -impl<S: Unpin> AbortableStream<S> { - fn maybe_send_shutdown_signal(&mut self) { - if let Some(tx) = self.notify_tx.lock().unwrap().take() { - let _ = tx.send(()); - } - } -} - -impl<S> Drop for AbortableStream<S> -where - S: Unpin, -{ - fn drop(&mut self) { - self.maybe_send_shutdown_signal(); - } -} - impl<S> AsyncWrite for AbortableStream<S> where S: AsyncWrite + Unpin + Send + 'static, @@ -88,7 +69,6 @@ where buf: &[u8], ) -> Poll<io::Result<usize>> { if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { - self.maybe_send_shutdown_signal(); return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, "stream is closed", @@ -99,7 +79,6 @@ where fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { - self.maybe_send_shutdown_signal(); return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, "stream is closed", @@ -123,7 +102,6 @@ where buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { - self.maybe_send_shutdown_signal(); return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, "stream is closed", @@ -156,7 +134,7 @@ mod test { let (client, _server) = tokio::io::duplex(64); runtime.block_on(async move { - let (mut stream, abort_handle) = AbortableStream::new(client, None); + let (mut stream, abort_handle) = AbortableStream::new(client); let stream_task = tokio::spawn(async move { let mut buf = vec![]; @@ -173,45 +151,45 @@ mod test { }); } - /// Test whether the shutdown signal is sent when the stream is explicitly closed. + /// Test the `AbortableStreamHandle::is_closed` method when explicitly closed. #[test] fn test_shutdown_signal() { let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); let (client, _server) = tokio::io::duplex(64); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); runtime.block_on(async move { - let (_stream, abort_handle) = AbortableStream::new(client, Some(shutdown_tx)); + let (_stream, abort_handle) = AbortableStream::new(client); + let abort_handle_2 = abort_handle.clone(); + assert!(!abort_handle_2.is_closed()); abort_handle.close(); - assert!(tokio::time::timeout(Duration::from_secs(1), shutdown_rx) - .await - .unwrap() - .is_ok()); + assert!(abort_handle_2.is_closed()); }); } - /// Test whether the shutdown signal is sent when the stream stops on its own. + /// Test the `AbortableStreamHandle::is_closed` method when the stream stops on its own. #[test] fn test_shutdown_signal_normal() { let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); let (client, server) = tokio::io::duplex(64); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); runtime.block_on(async move { - let (mut stream, _abort_handle) = AbortableStream::new(client, Some(shutdown_tx)); + let (mut stream, abort_handle) = AbortableStream::new(client); - tokio::spawn(async move { + assert!(!abort_handle.is_closed()); + + let stream_task = tokio::spawn(async move { drop(server); let mut buf = vec![]; stream.read_to_end(&mut buf).await }); - assert!(tokio::time::timeout(Duration::from_secs(1), shutdown_rx) + assert!(tokio::time::timeout(Duration::from_secs(1), stream_task) .await .unwrap() .is_ok()); + assert!(abort_handle.is_closed()); }); } } diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index e7a785601d..4d8108f89b 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -2,12 +2,9 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, tls_stream::TlsStream, }; +use futures::{channel::mpsc, StreamExt}; #[cfg(target_os = "android")] -use futures::sink::SinkExt; -use futures::{ - channel::{mpsc, oneshot}, - StreamExt, -}; +use futures::{channel::oneshot, sink::SinkExt}; use http::uri::Scheme; use hyper::{ client::connect::dns::{GaiResolver, Name}, @@ -17,7 +14,6 @@ use hyper::{ #[cfg(target_os = "android")] use std::os::unix::io::{AsRawFd, RawFd}; use std::{ - collections::BTreeMap, fmt, future::Future, io, @@ -51,15 +47,13 @@ impl HttpsConnectorWithSniHandle { #[derive(Clone)] pub struct HttpsConnectorWithSni { inner: Arc<Mutex<HttpsConnectorWithSniInner>>, - handle: Handle, sni_hostname: Option<String>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } struct HttpsConnectorWithSniInner { - next_socket_id: usize, - stream_handles: BTreeMap<usize, AbortableStreamHandle>, + stream_handles: Vec<AbortableStreamHandle>, } #[cfg(target_os = "android")] @@ -73,8 +67,7 @@ impl HttpsConnectorWithSni { ) -> (Self, HttpsConnectorWithSniHandle) { let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded(); let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner { - stream_handles: BTreeMap::new(), - next_socket_id: 0, + stream_handles: vec![], })); let inner_copy = inner.clone(); @@ -85,7 +78,7 @@ impl HttpsConnectorWithSni { let mut inner = inner_copy.lock().unwrap(); std::mem::take(&mut inner.stream_handles) }; - for (_, handle) in handles { + for handle in handles { handle.close(); } } @@ -94,7 +87,6 @@ impl HttpsConnectorWithSni { ( HttpsConnectorWithSni { inner, - handle, sni_hostname, #[cfg(target_os = "android")] socket_bypass_tx, @@ -103,13 +95,6 @@ impl HttpsConnectorWithSni { ) } - fn next_id(&mut self) -> usize { - let mut inner = self.inner.lock().unwrap(); - let next_id = inner.next_socket_id; - inner.next_socket_id = inner.next_socket_id.wrapping_add(1); - next_id - } - #[cfg(not(target_os = "android"))] async fn open_socket(addr: SocketAddr) -> std::io::Result<TcpStream> { timeout(CONNECT_TIMEOUT, TcpStream::connect(addr)) @@ -190,8 +175,6 @@ impl Service<Uri> for HttpsConnectorWithSni { io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") }); let inner = self.inner.clone(); - let socket_id = self.next_id(); - let handle = self.handle.clone(); #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); @@ -213,26 +196,14 @@ impl Service<Uri> for HttpsConnectorWithSni { ) .await?; - let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel(); - - let (tcp_stream, socket_handle) = - AbortableStream::new(tokio_connection, Some(socket_shutdown_tx)); + let (tcp_stream, socket_handle) = AbortableStream::new(tokio_connection); { let mut inner = inner.lock().unwrap(); - inner - .stream_handles - .insert(socket_id, socket_handle.clone()); + inner.stream_handles.retain(|handle| !handle.is_closed()); + inner.stream_handles.push(socket_handle); } - handle.spawn(async move { - let _ = socket_shutdown_rx.await; - { - let mut inner = inner.lock().unwrap(); - inner.stream_handles.remove(&socket_id); - } - }); - Ok(TlsStream::connect_https(tcp_stream, &hostname).await?) }; |
