summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2021-01-20 11:00:15 +0000
committerEmīls <emils@mullvad.net>2021-01-25 15:29:39 +0000
commitd7f6da0c20b77f23a896173f88a960926bc3d38d (patch)
treefba23b03cfcf2f8c123d343176b5f9356021458d
parent631b690b32a183b03f2fc7ec2fb2ecf0d7e40aae (diff)
downloadmullvadvpn-d7f6da0c20b77f23a896173f88a960926bc3d38d.tar.xz
mullvadvpn-d7f6da0c20b77f23a896173f88a960926bc3d38d.zip
Add tcp_stream.rs
-rw-r--r--mullvad-rpc/Cargo.toml1
-rw-r--r--mullvad-rpc/src/lib.rs1
-rw-r--r--mullvad-rpc/src/tcp_stream.rs124
3 files changed, 126 insertions, 0 deletions
diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml
index cf27d2e287..322640f7ac 100644
--- a/mullvad-rpc/Cargo.toml
+++ b/mullvad-rpc/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2018"
publish = false
[dependencies]
+bytes = "0.5"
chrono = { version = "0.4", features = ["serde"] }
err-derive = "0.2.1"
futures = "0.3"
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index e211a89c2e..ebb1041d54 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -20,6 +20,7 @@ pub mod rest;
mod https_client_with_sni;
use crate::https_client_with_sni::HttpsConnectorWithSni;
+mod tcp_stream;
mod address_cache;
mod relay_list;
diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs
new file mode 100644
index 0000000000..7bf89a47b2
--- /dev/null
+++ b/mullvad-rpc/src/tcp_stream.rs
@@ -0,0 +1,124 @@
+use bytes::buf::Buf;
+use futures::channel::oneshot;
+use hyper::client::connect::{Connected, Connection};
+use std::{
+ io,
+ net::Shutdown,
+ pin::Pin,
+ sync::{Arc, Mutex, Weak},
+ task::{Context, Poll},
+};
+use tokio::{
+ io::{AsyncRead, AsyncWrite},
+ net::TcpStream as TokioTcpStream,
+};
+
+#[derive(Debug)]
+pub struct TcpStreamHandle {
+ inner: Weak<Mutex<StreamInner>>,
+}
+
+impl TcpStreamHandle {
+ pub fn close(self) {
+ if let Some(inner_lock) = self.inner.upgrade() {
+ if let Ok(mut inner) = inner_lock.lock() {
+ if let Err(err) = inner.stream.shutdown(Shutdown::Both) {
+ log::error!("Failed to shut down TCP socket: {}", err);
+ }
+ let _ = inner.shutdown_tx.take();
+ }
+ }
+ }
+}
+
+
+pub struct TcpStream {
+ inner: Arc<Mutex<StreamInner>>,
+}
+
+impl TcpStream {
+ pub fn new(
+ stream: TokioTcpStream,
+ id: usize,
+ shutdown_tx: Option<oneshot::Sender<()>>,
+ ) -> (Self, TcpStreamHandle) {
+ let inner = Arc::new(Mutex::new(StreamInner {
+ id,
+ stream,
+ shutdown_tx,
+ }));
+ (
+ Self {
+ inner: inner.clone(),
+ },
+ TcpStreamHandle {
+ inner: Arc::downgrade(&inner),
+ },
+ )
+ }
+
+ fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T {
+ let mut inner = self.inner.lock().expect("TCP lock poisoned");
+ stream_fn(&mut inner.stream)
+ }
+}
+
+impl Drop for TcpStream {
+ fn drop(&mut self) {
+ if let Ok(mut inner) = self.inner.lock() {
+ if let Some(tx) = inner.shutdown_tx.take() {
+ let _ = tx.send(());
+ }
+ }
+ }
+}
+
+
+impl AsyncWrite for TcpStream {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf))
+ }
+
+ fn poll_write_buf<B: Buf>(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_write_buf(cx, buf))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_flush(cx))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx))
+ }
+}
+
+impl AsyncRead for TcpStream {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf))
+ }
+}
+
+impl Connection for TcpStream {
+ fn connected(&self) -> Connected {
+ Connected::new()
+ }
+}
+
+#[derive(Debug)]
+struct StreamInner {
+ id: usize,
+ stream: TokioTcpStream,
+ shutdown_tx: Option<oneshot::Sender<()>>,
+}