diff options
| author | Emīls <emils@mullvad.net> | 2021-01-20 11:00:15 +0000 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2021-01-25 15:29:39 +0000 |
| commit | d7f6da0c20b77f23a896173f88a960926bc3d38d (patch) | |
| tree | fba23b03cfcf2f8c123d343176b5f9356021458d | |
| parent | 631b690b32a183b03f2fc7ec2fb2ecf0d7e40aae (diff) | |
| download | mullvadvpn-d7f6da0c20b77f23a896173f88a960926bc3d38d.tar.xz mullvadvpn-d7f6da0c20b77f23a896173f88a960926bc3d38d.zip | |
Add tcp_stream.rs
| -rw-r--r-- | mullvad-rpc/Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-rpc/src/tcp_stream.rs | 124 |
3 files changed, 126 insertions, 0 deletions
diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml index cf27d2e287..322640f7ac 100644 --- a/mullvad-rpc/Cargo.toml +++ b/mullvad-rpc/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" publish = false [dependencies] +bytes = "0.5" chrono = { version = "0.4", features = ["serde"] } err-derive = "0.2.1" futures = "0.3" diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index e211a89c2e..ebb1041d54 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -20,6 +20,7 @@ pub mod rest; mod https_client_with_sni; use crate::https_client_with_sni::HttpsConnectorWithSni; +mod tcp_stream; mod address_cache; mod relay_list; diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs new file mode 100644 index 0000000000..7bf89a47b2 --- /dev/null +++ b/mullvad-rpc/src/tcp_stream.rs @@ -0,0 +1,124 @@ +use bytes::buf::Buf; +use futures::channel::oneshot; +use hyper::client::connect::{Connected, Connection}; +use std::{ + io, + net::Shutdown, + pin::Pin, + sync::{Arc, Mutex, Weak}, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream as TokioTcpStream, +}; + +#[derive(Debug)] +pub struct TcpStreamHandle { + inner: Weak<Mutex<StreamInner>>, +} + +impl TcpStreamHandle { + pub fn close(self) { + if let Some(inner_lock) = self.inner.upgrade() { + if let Ok(mut inner) = inner_lock.lock() { + if let Err(err) = inner.stream.shutdown(Shutdown::Both) { + log::error!("Failed to shut down TCP socket: {}", err); + } + let _ = inner.shutdown_tx.take(); + } + } + } +} + + +pub struct TcpStream { + inner: Arc<Mutex<StreamInner>>, +} + +impl TcpStream { + pub fn new( + stream: TokioTcpStream, + id: usize, + shutdown_tx: Option<oneshot::Sender<()>>, + ) -> (Self, TcpStreamHandle) { + let inner = Arc::new(Mutex::new(StreamInner { + id, + stream, + shutdown_tx, + })); + ( + Self { + inner: inner.clone(), + }, + TcpStreamHandle { + inner: Arc::downgrade(&inner), + }, + ) + } + + fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T { + let mut inner = self.inner.lock().expect("TCP lock poisoned"); + stream_fn(&mut inner.stream) + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + if let Ok(mut inner) = self.inner.lock() { + if let Some(tx) = inner.shutdown_tx.take() { + let _ = tx.send(()); + } + } + } +} + + +impl AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf)) + } + + fn poll_write_buf<B: Buf>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_write_buf(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx)) + } +} + +impl AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf)) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[derive(Debug)] +struct StreamInner { + id: usize, + stream: TokioTcpStream, + shutdown_tx: Option<oneshot::Sender<()>>, +} |
