summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-01-07 20:50:09 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-01-26 16:27:53 +0100
commit7ce9d490f8ca3983bb9854b10335852831b5b71d (patch)
treeb6f6da8d7c54c1ebc4b6631ed8775e083ffcddef
parentb94cc8429ffe4a987312da6b4b8a31ca41417679 (diff)
downloadmullvadvpn-7ce9d490f8ca3983bb9854b10335852831b5b71d.tar.xz
mullvadvpn-7ce9d490f8ca3983bb9854b10335852831b5b71d.zip
Generalize TcpStream wrapper into AbortableStream
-rw-r--r--mullvad-rpc/src/abortable_stream.rs143
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs14
-rw-r--r--mullvad-rpc/src/lib.rs2
-rw-r--r--mullvad-rpc/src/rest.rs8
-rw-r--r--mullvad-rpc/src/tcp_stream.rs146
5 files changed, 155 insertions, 158 deletions
diff --git a/mullvad-rpc/src/abortable_stream.rs b/mullvad-rpc/src/abortable_stream.rs
new file mode 100644
index 0000000000..57f6556503
--- /dev/null
+++ b/mullvad-rpc/src/abortable_stream.rs
@@ -0,0 +1,143 @@
+//! Wrapper around a stream to make it abortable. This allows in-flight requests to be cancelled
+//! immediately instead of after the socket times out.
+
+use futures::channel::oneshot;
+use hyper::client::connect::{Connected, Connection};
+use std::{
+ future::Future,
+ io,
+ pin::Pin,
+ sync::{Arc, Mutex},
+ task::{Context, Poll},
+};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+#[derive(Clone, Debug)]
+pub struct AbortableStreamHandle {
+ tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
+ notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
+}
+
+impl AbortableStreamHandle {
+ pub fn close(self) {
+ if let Some(tx) = self.tx.lock().unwrap().take() {
+ let _ = tx.send(());
+ }
+ if let Some(tx) = self.notify_tx.lock().unwrap().take() {
+ let _ = tx.send(());
+ }
+ }
+}
+
+pub struct AbortableStream<S: Unpin> {
+ stream: S,
+ /// Notified when the stream is shut down.
+ notify_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
+ shutdown_rx: oneshot::Receiver<()>,
+}
+
+impl<S> AbortableStream<S>
+where
+ S: Unpin + Send + 'static,
+{
+ pub fn new(
+ stream: S,
+ notify_shutdown_tx: Option<oneshot::Sender<()>>,
+ ) -> (Self, AbortableStreamHandle) {
+ let (tx, rx) = oneshot::channel();
+ let notify_tx = Arc::new(Mutex::new(notify_shutdown_tx));
+ let stream_handle = AbortableStreamHandle {
+ tx: Arc::new(Mutex::new(Some(tx))),
+ notify_tx: notify_tx.clone(),
+ };
+ (
+ Self {
+ stream,
+ notify_tx,
+ shutdown_rx: rx,
+ },
+ stream_handle,
+ )
+ }
+}
+
+impl<S: Unpin> AbortableStream<S> {
+ fn maybe_send_shutdown_signal(&mut self) {
+ if let Some(tx) = self.notify_tx.lock().unwrap().take() {
+ let _ = tx.send(());
+ }
+ }
+}
+
+impl<S> Drop for AbortableStream<S>
+where
+ S: Unpin,
+{
+ fn drop(&mut self) {
+ self.maybe_send_shutdown_signal();
+ }
+}
+
+impl<S> AsyncWrite for AbortableStream<S>
+where
+ S: AsyncWrite + Unpin + Send + 'static,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ self.maybe_send_shutdown_signal();
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ "stream is closed",
+ )));
+ }
+ Pin::new(&mut self.stream).poll_write(cx, buf)
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ self.maybe_send_shutdown_signal();
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ "stream is closed",
+ )));
+ }
+ Pin::new(&mut self.stream).poll_flush(cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.stream).poll_shutdown(cx)
+ }
+}
+
+impl<S> AsyncRead for AbortableStream<S>
+where
+ S: AsyncRead + Unpin + Send + 'static,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ self.maybe_send_shutdown_signal();
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ "stream is closed",
+ )));
+ }
+ Pin::new(&mut self.stream).poll_read(cx, buf)
+ }
+}
+
+impl<S> Connection for AbortableStream<S>
+where
+ S: Connection + Unpin,
+{
+ fn connected(&self) -> Connected {
+ self.stream.connected()
+ }
+}
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index 63867f54e8..7b2a7a4b9a 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -1,4 +1,4 @@
-use crate::{rest::RequestCommand, tcp_stream::TcpStream};
+use crate::{abortable_stream::AbortableStream, rest::RequestCommand};
use futures::{
channel::{mpsc, oneshot},
sink::SinkExt,
@@ -27,7 +27,7 @@ use std::{
#[cfg(target_os = "android")]
use tokio::net::TcpSocket;
-use tokio::{net::TcpStream as TokioTcpStream, runtime::Handle, time::timeout};
+use tokio::{net::TcpStream, runtime::Handle, time::timeout};
use tokio_rustls::rustls;
// New LetsEncrypt root certificate
@@ -101,8 +101,8 @@ impl HttpsConnectorWithSni {
}
#[cfg(not(target_os = "android"))]
- async fn open_socket(addr: SocketAddr) -> std::io::Result<TokioTcpStream> {
- timeout(CONNECT_TIMEOUT, TokioTcpStream::connect(addr))
+ async fn open_socket(addr: SocketAddr) -> std::io::Result<TcpStream> {
+ timeout(CONNECT_TIMEOUT, TcpStream::connect(addr))
.await
.map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
}
@@ -111,7 +111,7 @@ impl HttpsConnectorWithSni {
async fn open_socket(
addr: SocketAddr,
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> std::io::Result<TokioTcpStream> {
+ ) -> std::io::Result<TcpStream> {
let socket = match addr {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
@@ -162,7 +162,7 @@ impl fmt::Debug for HttpsConnectorWithSni {
}
impl Service<Uri> for HttpsConnectorWithSni {
- type Response = MaybeHttpsStream<TcpStream>;
+ type Response = MaybeHttpsStream<AbortableStream<TcpStream>>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@@ -214,7 +214,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel();
let (tcp_stream, socket_handle) =
- TcpStream::new(tokio_connection, Some(socket_shutdown_tx));
+ AbortableStream::new(tokio_connection, Some(socket_shutdown_tx));
if let Some(mut service_tx) = service_tx {
if service_tx
.send(RequestCommand::SocketOpened(socket_id, socket_handle))
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 2b9551ddb4..9dfd139b01 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -23,9 +23,9 @@ pub mod rest;
mod https_client_with_sni;
use crate::https_client_with_sni::HttpsConnectorWithSni;
+mod abortable_stream;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
-mod tcp_stream;
mod address_cache;
mod relay_list;
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 00ea973dd3..7b150e7807 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,6 +1,6 @@
use crate::{
- address_cache::AddressCache, availability::ApiAvailabilityHandle,
- https_client_with_sni::HttpsConnectorWithSni, tcp_stream::TcpStreamHandle,
+ abortable_stream::AbortableStreamHandle, address_cache::AddressCache,
+ availability::ApiAvailabilityHandle, https_client_with_sni::HttpsConnectorWithSni,
};
use futures::{
channel::{mpsc, oneshot},
@@ -88,7 +88,7 @@ impl Error {
pub(crate) struct RequestService {
command_tx: mpsc::Sender<RequestCommand>,
command_rx: mpsc::Receiver<RequestCommand>,
- sockets: BTreeMap<usize, TcpStreamHandle>,
+ sockets: BTreeMap<usize, AbortableStreamHandle>,
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
handle: Handle,
next_id: u64,
@@ -296,7 +296,7 @@ pub(crate) enum RequestCommand {
oneshot::Sender<std::result::Result<Response, Error>>,
),
RequestFinished(u64),
- SocketOpened(usize, TcpStreamHandle),
+ SocketOpened(usize, AbortableStreamHandle),
SocketClosed(usize),
Reset(oneshot::Sender<()>),
}
diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs
deleted file mode 100644
index 60bad079b5..0000000000
--- a/mullvad-rpc/src/tcp_stream.rs
+++ /dev/null
@@ -1,146 +0,0 @@
-//! Wrapper around [`tokio::net::TcpStream`]. This allows in-flight requests to be cancelled
-//! immediately instead of after the socket times out.
-
-use futures::channel::oneshot;
-use hyper::client::connect::{Connected, Connection};
-use std::{
- io,
- net::Shutdown,
- pin::Pin,
- sync::{Arc, Mutex, Weak},
- task::{Context, Poll},
-};
-use tokio::{
- io::{AsyncRead, AsyncWrite, ReadBuf},
- net::TcpStream as TokioTcpStream,
-};
-
-#[derive(Debug)]
-pub struct TcpStreamHandle {
- inner: Weak<Mutex<Option<StreamInner>>>,
-}
-
-impl TcpStreamHandle {
- pub fn close(self) {
- if let Some(inner_lock) = self.inner.upgrade() {
- if let Ok(Some(inner)) = inner_lock.lock().map(|mut inner| inner.take()) {
- if let Err(err) = flatten_result(
- inner
- .stream
- .into_std()
- .map(|stream| stream.shutdown(Shutdown::Both)),
- ) {
- log::error!("Failed to shut down TCP socket: {}", err);
- }
- }
- }
- }
-}
-
-pub struct TcpStream {
- inner: Arc<Mutex<Option<StreamInner>>>,
-}
-
-impl TcpStream {
- pub fn new(
- stream: TokioTcpStream,
- shutdown_tx: Option<oneshot::Sender<()>>,
- ) -> (Self, TcpStreamHandle) {
- let inner = Arc::new(Mutex::new(Some(StreamInner {
- stream,
- shutdown_tx,
- })));
- let stream_handle = TcpStreamHandle {
- inner: Arc::downgrade(&inner),
- };
- (Self { inner }, stream_handle)
- }
-
- fn do_stream<T>(
- &self,
- mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T,
- closed_value: T,
- ) -> T {
- let mut inner = self.inner.lock().expect("TCP lock poisoned");
- if let Some(inner) = &mut *inner {
- stream_fn(&mut inner.stream)
- } else {
- closed_value
- }
- }
-}
-
-impl Drop for TcpStream {
- fn drop(&mut self) {
- if let Ok(Some(mut inner)) = self.inner.lock().map(|mut inner| inner.take()) {
- if let Some(tx) = inner.shutdown_tx.take() {
- let _ = tx.send(());
- }
- }
- }
-}
-
-impl AsyncWrite for TcpStream {
- fn poll_write(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<io::Result<usize>> {
- self.do_stream(
- |stream| Pin::new(stream).poll_write(cx, buf),
- Poll::Ready(Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "socket is closed",
- ))),
- )
- }
-
- fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- self.do_stream(
- |stream| Pin::new(stream).poll_flush(cx),
- Poll::Ready(Ok(())),
- )
- }
-
- fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- self.do_stream(
- |stream| Pin::new(stream).poll_shutdown(cx),
- Poll::Ready(Ok(())),
- )
- }
-}
-
-impl AsyncRead for TcpStream {
- fn poll_read(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut ReadBuf<'_>,
- ) -> Poll<io::Result<()>> {
- self.do_stream(
- |stream| Pin::new(stream).poll_read(cx, buf),
- Poll::Ready(Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "socket is closed",
- ))),
- )
- }
-}
-
-impl Connection for TcpStream {
- fn connected(&self) -> Connected {
- Connected::new()
- }
-}
-
-#[derive(Debug)]
-struct StreamInner {
- stream: TokioTcpStream,
- shutdown_tx: Option<oneshot::Sender<()>>,
-}
-
-fn flatten_result<T, E>(result: Result<Result<T, E>, E>) -> Result<T, E> {
- match result {
- Ok(value) => value,
- Err(err) => Err(err),
- }
-}