diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-07-12 10:36:40 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-07-13 15:57:40 +0200 |
| commit | ef2a4d8616274996cd6b5620e890ce8a3f507986 (patch) | |
| tree | 2f037f1450fa54ae98b2eacbea92c2eb4375e367 | |
| parent | c77f4785d0d70561a82bf49c90f14aed9497a6e6 (diff) | |
| download | mullvadvpn-ef2a4d8616274996cd6b5620e890ce8a3f507986.tar.xz mullvadvpn-ef2a4d8616274996cd6b5620e890ce8a3f507986.zip | |
Fix TcpStream shutdown
| -rw-r--r-- | mullvad-rpc/src/tcp_stream.rs | 71 |
1 files changed, 49 insertions, 22 deletions
diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs index ba414054d2..2919d73b0d 100644 --- a/mullvad-rpc/src/tcp_stream.rs +++ b/mullvad-rpc/src/tcp_stream.rs @@ -14,17 +14,21 @@ use tokio::{ #[derive(Debug)] pub struct TcpStreamHandle { - inner: Weak<Mutex<StreamInner>>, + inner: Weak<Mutex<Option<StreamInner>>>, } impl TcpStreamHandle { pub fn close(self) { if let Some(inner_lock) = self.inner.upgrade() { - if let Ok(mut inner) = inner_lock.lock() { - if let Err(err) = inner.stream.shutdown(Shutdown::Both) { + if let Ok(Some(inner)) = inner_lock.lock().map(|mut inner| inner.take()) { + if let Err(err) = flatten_result( + inner + .stream + .into_std() + .map(|stream| stream.shutdown(Shutdown::Both)), + ) { log::error!("Failed to shut down TCP socket: {}", err); } - let _ = inner.shutdown_tx.take(); } } } @@ -32,7 +36,7 @@ impl TcpStreamHandle { pub struct TcpStream { - inner: Arc<Mutex<StreamInner>>, + inner: Arc<Mutex<Option<StreamInner>>>, } impl TcpStream { @@ -41,30 +45,34 @@ impl TcpStream { id: usize, shutdown_tx: Option<oneshot::Sender<()>>, ) -> (Self, TcpStreamHandle) { - let inner = Arc::new(Mutex::new(StreamInner { + let inner = Arc::new(Mutex::new(Some(StreamInner { id, stream, shutdown_tx, - })); - ( - Self { - inner: inner.clone(), - }, - TcpStreamHandle { - inner: Arc::downgrade(&inner), - }, - ) + }))); + let stream_handle = TcpStreamHandle { + inner: Arc::downgrade(&inner), + }; + (Self { inner }, stream_handle) } - fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T { + fn do_stream<T>( + &self, + mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T, + closed_value: T, + ) -> T { let mut inner = self.inner.lock().expect("TCP lock poisoned"); - stream_fn(&mut inner.stream) + if let Some(inner) = &mut *inner { + stream_fn(&mut inner.stream) + } else { + closed_value + } } } impl Drop for TcpStream { fn drop(&mut self) { - if let Ok(mut inner) = self.inner.lock() { + if let Ok(Some(mut inner)) = self.inner.lock().map(|mut inner| inner.take()) { if let Some(tx) = inner.shutdown_tx.take() { let _ = tx.send(()); } @@ -79,15 +87,24 @@ impl AsyncWrite for TcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf)) + self.do_stream( + |stream| Pin::new(stream).poll_write(cx, buf), + Poll::Ready(Ok(0)), + ) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - self.do_stream(|stream| Pin::new(stream).poll_flush(cx)) + self.do_stream( + |stream| Pin::new(stream).poll_flush(cx), + Poll::Ready(Ok(())), + ) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx)) + self.do_stream( + |stream| Pin::new(stream).poll_shutdown(cx), + Poll::Ready(Ok(())), + ) } } @@ -97,7 +114,10 @@ impl AsyncRead for TcpStream { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { - self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf)) + self.do_stream( + |stream| Pin::new(stream).poll_read(cx, buf), + Poll::Ready(Ok(())), + ) } } @@ -113,3 +133,10 @@ struct StreamInner { stream: TokioTcpStream, shutdown_tx: Option<oneshot::Sender<()>>, } + +fn flatten_result<T, E>(result: Result<Result<T, E>, E>) -> Result<T, E> { + match result { + Ok(value) => value, + Err(err) => Err(err), + } +} |
