diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-01-07 20:50:09 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-01-26 16:27:53 +0100 |
| commit | 7ce9d490f8ca3983bb9854b10335852831b5b71d (patch) | |
| tree | b6f6da8d7c54c1ebc4b6631ed8775e083ffcddef | |
| parent | b94cc8429ffe4a987312da6b4b8a31ca41417679 (diff) | |
| download | mullvadvpn-7ce9d490f8ca3983bb9854b10335852831b5b71d.tar.xz mullvadvpn-7ce9d490f8ca3983bb9854b10335852831b5b71d.zip | |
Generalize TcpStream wrapper into AbortableStream
| -rw-r--r-- | mullvad-rpc/src/abortable_stream.rs | 143 | ||||
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 14 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 2 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 8 | ||||
| -rw-r--r-- | mullvad-rpc/src/tcp_stream.rs | 146 |
5 files changed, 155 insertions, 158 deletions
diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs new file mode 100644 index 0000000000..57f6556503 --- /dev/null +++ b/mullvad-rpc/src/abortable_stream.rs @@ -0,0 +1,143 @@ +//! Wrapper around a stream to make it abortable. This allows in-flight requests to be cancelled +//! immediately instead of after the socket times out. + +use futures::channel::oneshot; +use hyper::client::connect::{Connected, Connection}; +use std::{ + future::Future, + io, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(Clone, Debug)] +pub struct AbortableStreamHandle { + tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, + notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, +} + +impl AbortableStreamHandle { + pub fn close(self) { + if let Some(tx) = self.tx.lock().unwrap().take() { + let _ = tx.send(()); + } + if let Some(tx) = self.notify_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } +} + +pub struct AbortableStream<S: Unpin> { + stream: S, + /// Notified when the stream is shut down. + notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, + shutdown_rx: oneshot::Receiver<()>, +} + +impl<S> AbortableStream<S> +where + S: Unpin + Send + 'static, +{ + pub fn new( + stream: S, + notify_shutdown_tx: Option<oneshot::Sender<()>>, + ) -> (Self, AbortableStreamHandle) { + let (tx, rx) = oneshot::channel(); + let notify_tx = Arc::new(Mutex::new(notify_shutdown_tx)); + let stream_handle = AbortableStreamHandle { + tx: Arc::new(Mutex::new(Some(tx))), + notify_tx: notify_tx.clone(), + }; + ( + Self { + stream, + notify_tx, + shutdown_rx: rx, + }, + stream_handle, + ) + } +} + +impl<S: Unpin> AbortableStream<S> { + fn maybe_send_shutdown_signal(&mut self) { + if let Some(tx) = self.notify_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } +} + +impl<S> Drop for AbortableStream<S> +where + S: Unpin, +{ + fn drop(&mut self) { + self.maybe_send_shutdown_signal(); + } +} + +impl<S> AsyncWrite for AbortableStream<S> +where + S: AsyncWrite + Unpin + Send + 'static, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + self.maybe_send_shutdown_signal(); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "stream is closed", + ))); + } + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + self.maybe_send_shutdown_signal(); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "stream is closed", + ))); + } + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl<S> AsyncRead for AbortableStream<S> +where + S: AsyncRead + Unpin + Send + 'static, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + self.maybe_send_shutdown_signal(); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "stream is closed", + ))); + } + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl<S> Connection for AbortableStream<S> +where + S: Connection + Unpin, +{ + fn connected(&self) -> Connected { + self.stream.connected() + } +} diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index 63867f54e8..7b2a7a4b9a 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -1,4 +1,4 @@ -use crate::{rest::RequestCommand, tcp_stream::TcpStream}; +use crate::{abortable_stream::AbortableStream, rest::RequestCommand}; use futures::{ channel::{mpsc, oneshot}, sink::SinkExt, @@ -27,7 +27,7 @@ use std::{ #[cfg(target_os = "android")] use tokio::net::TcpSocket; -use tokio::{net::TcpStream as TokioTcpStream, runtime::Handle, time::timeout}; +use tokio::{net::TcpStream, runtime::Handle, time::timeout}; use tokio_rustls::rustls; // New LetsEncrypt root certificate @@ -101,8 +101,8 @@ impl HttpsConnectorWithSni { } #[cfg(not(target_os = "android"))] - async fn open_socket(addr: SocketAddr) -> std::io::Result<TokioTcpStream> { - timeout(CONNECT_TIMEOUT, TokioTcpStream::connect(addr)) + async fn open_socket(addr: SocketAddr) -> std::io::Result<TcpStream> { + timeout(CONNECT_TIMEOUT, TcpStream::connect(addr)) .await .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? } @@ -111,7 +111,7 @@ impl HttpsConnectorWithSni { async fn open_socket( addr: SocketAddr, socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - ) -> std::io::Result<TokioTcpStream> { + ) -> std::io::Result<TcpStream> { let socket = match addr { SocketAddr::V4(_) => TcpSocket::new_v4()?, SocketAddr::V6(_) => TcpSocket::new_v6()?, @@ -162,7 +162,7 @@ impl fmt::Debug for HttpsConnectorWithSni { } impl Service<Uri> for HttpsConnectorWithSni { - type Response = MaybeHttpsStream<TcpStream>; + type Response = MaybeHttpsStream<AbortableStream<TcpStream>>; type Error = io::Error; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -214,7 +214,7 @@ impl Service<Uri> for HttpsConnectorWithSni { let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel(); let (tcp_stream, socket_handle) = - TcpStream::new(tokio_connection, Some(socket_shutdown_tx)); + AbortableStream::new(tokio_connection, Some(socket_shutdown_tx)); if let Some(mut service_tx) = service_tx { if service_tx .send(RequestCommand::SocketOpened(socket_id, socket_handle)) diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 2b9551ddb4..9dfd139b01 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -23,9 +23,9 @@ pub mod rest; mod https_client_with_sni; use crate::https_client_with_sni::HttpsConnectorWithSni; +mod abortable_stream; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; -mod tcp_stream; mod address_cache; mod relay_list; diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 00ea973dd3..7b150e7807 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,6 +1,6 @@ use crate::{ - address_cache::AddressCache, availability::ApiAvailabilityHandle, - https_client_with_sni::HttpsConnectorWithSni, tcp_stream::TcpStreamHandle, + abortable_stream::AbortableStreamHandle, address_cache::AddressCache, + availability::ApiAvailabilityHandle, https_client_with_sni::HttpsConnectorWithSni, }; use futures::{ channel::{mpsc, oneshot}, @@ -88,7 +88,7 @@ impl Error { pub(crate) struct RequestService { command_tx: mpsc::Sender<RequestCommand>, command_rx: mpsc::Receiver<RequestCommand>, - sockets: BTreeMap<usize, TcpStreamHandle>, + sockets: BTreeMap<usize, AbortableStreamHandle>, client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, handle: Handle, next_id: u64, @@ -296,7 +296,7 @@ pub(crate) enum RequestCommand { oneshot::Sender<std::result::Result<Response, Error>>, ), RequestFinished(u64), - SocketOpened(usize, TcpStreamHandle), + SocketOpened(usize, AbortableStreamHandle), SocketClosed(usize), Reset(oneshot::Sender<()>), } diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs deleted file mode 100644 index 60bad079b5..0000000000 --- a/mullvad-rpc/src/tcp_stream.rs +++ /dev/null @@ -1,146 +0,0 @@ -//! Wrapper around [`tokio::net::TcpStream`]. This allows in-flight requests to be cancelled -//! immediately instead of after the socket times out. - -use futures::channel::oneshot; -use hyper::client::connect::{Connected, Connection}; -use std::{ - io, - net::Shutdown, - pin::Pin, - sync::{Arc, Mutex, Weak}, - task::{Context, Poll}, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - net::TcpStream as TokioTcpStream, -}; - -#[derive(Debug)] -pub struct TcpStreamHandle { - inner: Weak<Mutex<Option<StreamInner>>>, -} - -impl TcpStreamHandle { - pub fn close(self) { - if let Some(inner_lock) = self.inner.upgrade() { - if let Ok(Some(inner)) = inner_lock.lock().map(|mut inner| inner.take()) { - if let Err(err) = flatten_result( - inner - .stream - .into_std() - .map(|stream| stream.shutdown(Shutdown::Both)), - ) { - log::error!("Failed to shut down TCP socket: {}", err); - } - } - } - } -} - -pub struct TcpStream { - inner: Arc<Mutex<Option<StreamInner>>>, -} - -impl TcpStream { - pub fn new( - stream: TokioTcpStream, - shutdown_tx: Option<oneshot::Sender<()>>, - ) -> (Self, TcpStreamHandle) { - let inner = Arc::new(Mutex::new(Some(StreamInner { - stream, - shutdown_tx, - }))); - let stream_handle = TcpStreamHandle { - inner: Arc::downgrade(&inner), - }; - (Self { inner }, stream_handle) - } - - fn do_stream<T>( - &self, - mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T, - closed_value: T, - ) -> T { - let mut inner = self.inner.lock().expect("TCP lock poisoned"); - if let Some(inner) = &mut *inner { - stream_fn(&mut inner.stream) - } else { - closed_value - } - } -} - -impl Drop for TcpStream { - fn drop(&mut self) { - if let Ok(Some(mut inner)) = self.inner.lock().map(|mut inner| inner.take()) { - if let Some(tx) = inner.shutdown_tx.take() { - let _ = tx.send(()); - } - } - } -} - -impl AsyncWrite for TcpStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<io::Result<usize>> { - self.do_stream( - |stream| Pin::new(stream).poll_write(cx, buf), - Poll::Ready(Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "socket is closed", - ))), - ) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - self.do_stream( - |stream| Pin::new(stream).poll_flush(cx), - Poll::Ready(Ok(())), - ) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - self.do_stream( - |stream| Pin::new(stream).poll_shutdown(cx), - Poll::Ready(Ok(())), - ) - } -} - -impl AsyncRead for TcpStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll<io::Result<()>> { - self.do_stream( - |stream| Pin::new(stream).poll_read(cx, buf), - Poll::Ready(Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "socket is closed", - ))), - ) - } -} - -impl Connection for TcpStream { - fn connected(&self) -> Connected { - Connected::new() - } -} - -#[derive(Debug)] -struct StreamInner { - stream: TokioTcpStream, - shutdown_tx: Option<oneshot::Sender<()>>, -} - -fn flatten_result<T, E>(result: Result<Result<T, E>, E>) -> Result<T, E> { - match result { - Ok(value) => value, - Err(err) => Err(err), - } -} |
