summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-07-12 10:36:40 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-07-13 15:57:40 +0200
commitef2a4d8616274996cd6b5620e890ce8a3f507986 (patch)
tree2f037f1450fa54ae98b2eacbea92c2eb4375e367
parentc77f4785d0d70561a82bf49c90f14aed9497a6e6 (diff)
downloadmullvadvpn-ef2a4d8616274996cd6b5620e890ce8a3f507986.tar.xz
mullvadvpn-ef2a4d8616274996cd6b5620e890ce8a3f507986.zip
Fix TcpStream shutdown
-rw-r--r--mullvad-rpc/src/tcp_stream.rs71
1 files changed, 49 insertions, 22 deletions
diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs
index ba414054d2..2919d73b0d 100644
--- a/mullvad-rpc/src/tcp_stream.rs
+++ b/mullvad-rpc/src/tcp_stream.rs
@@ -14,17 +14,21 @@ use tokio::{
#[derive(Debug)]
pub struct TcpStreamHandle {
- inner: Weak<Mutex<StreamInner>>,
+ inner: Weak<Mutex<Option<StreamInner>>>,
}
impl TcpStreamHandle {
pub fn close(self) {
if let Some(inner_lock) = self.inner.upgrade() {
- if let Ok(mut inner) = inner_lock.lock() {
- if let Err(err) = inner.stream.shutdown(Shutdown::Both) {
+ if let Ok(Some(inner)) = inner_lock.lock().map(|mut inner| inner.take()) {
+ if let Err(err) = flatten_result(
+ inner
+ .stream
+ .into_std()
+ .map(|stream| stream.shutdown(Shutdown::Both)),
+ ) {
log::error!("Failed to shut down TCP socket: {}", err);
}
- let _ = inner.shutdown_tx.take();
}
}
}
@@ -32,7 +36,7 @@ impl TcpStreamHandle {
pub struct TcpStream {
- inner: Arc<Mutex<StreamInner>>,
+ inner: Arc<Mutex<Option<StreamInner>>>,
}
impl TcpStream {
@@ -41,30 +45,34 @@ impl TcpStream {
id: usize,
shutdown_tx: Option<oneshot::Sender<()>>,
) -> (Self, TcpStreamHandle) {
- let inner = Arc::new(Mutex::new(StreamInner {
+ let inner = Arc::new(Mutex::new(Some(StreamInner {
id,
stream,
shutdown_tx,
- }));
- (
- Self {
- inner: inner.clone(),
- },
- TcpStreamHandle {
- inner: Arc::downgrade(&inner),
- },
- )
+ })));
+ let stream_handle = TcpStreamHandle {
+ inner: Arc::downgrade(&inner),
+ };
+ (Self { inner }, stream_handle)
}
- fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T {
+ fn do_stream<T>(
+ &self,
+ mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T,
+ closed_value: T,
+ ) -> T {
let mut inner = self.inner.lock().expect("TCP lock poisoned");
- stream_fn(&mut inner.stream)
+ if let Some(inner) = &mut *inner {
+ stream_fn(&mut inner.stream)
+ } else {
+ closed_value
+ }
}
}
impl Drop for TcpStream {
fn drop(&mut self) {
- if let Ok(mut inner) = self.inner.lock() {
+ if let Ok(Some(mut inner)) = self.inner.lock().map(|mut inner| inner.take()) {
if let Some(tx) = inner.shutdown_tx.take() {
let _ = tx.send(());
}
@@ -79,15 +87,24 @@ impl AsyncWrite for TcpStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
- self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf))
+ self.do_stream(
+ |stream| Pin::new(stream).poll_write(cx, buf),
+ Poll::Ready(Ok(0)),
+ )
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- self.do_stream(|stream| Pin::new(stream).poll_flush(cx))
+ self.do_stream(
+ |stream| Pin::new(stream).poll_flush(cx),
+ Poll::Ready(Ok(())),
+ )
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx))
+ self.do_stream(
+ |stream| Pin::new(stream).poll_shutdown(cx),
+ Poll::Ready(Ok(())),
+ )
}
}
@@ -97,7 +114,10 @@ impl AsyncRead for TcpStream {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
- self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf))
+ self.do_stream(
+ |stream| Pin::new(stream).poll_read(cx, buf),
+ Poll::Ready(Ok(())),
+ )
}
}
@@ -113,3 +133,10 @@ struct StreamInner {
stream: TokioTcpStream,
shutdown_tx: Option<oneshot::Sender<()>>,
}
+
+fn flatten_result<T, E>(result: Result<Result<T, E>, E>) -> Result<T, E> {
+ match result {
+ Ok(value) => value,
+ Err(err) => Err(err),
+ }
+}