diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-10-19 21:35:07 +0200 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-10-21 09:21:47 +0200 |
| commit | f94cd1509d3b1f4a14f33c4425def750cb9879b8 (patch) | |
| tree | 74700e285d1f14b373f4c5e7a2fcbea90a0071c0 | |
| parent | c32e72ab7ce0940f36060b79891471ad32dfd703 (diff) | |
| download | mullvadvpn-f94cd1509d3b1f4a14f33c4425def750cb9879b8.tar.xz mullvadvpn-f94cd1509d3b1f4a14f33c4425def750cb9879b8.zip | |
Parameterize `Forwarder` over any Async{Read,Write} stream
| -rw-r--r-- | mullvad-encrypted-dns-proxy/src/forwarder.rs | 95 |
1 files changed, 48 insertions, 47 deletions
diff --git a/mullvad-encrypted-dns-proxy/src/forwarder.rs b/mullvad-encrypted-dns-proxy/src/forwarder.rs index e8f366e167..f607d10978 100644 --- a/mullvad-encrypted-dns-proxy/src/forwarder.rs +++ b/mullvad-encrypted-dns-proxy/src/forwarder.rs @@ -1,9 +1,6 @@ //! Forward TCP traffic over various proxy configurations. -use std::{ - io, - task::{ready, Poll}, -}; +use std::io; use tokio::{ io::{AsyncRead, AsyncWrite}, @@ -16,17 +13,16 @@ use crate::config::Obfuscator; /// /// Obtain [`ProxyConfig`](crate::config::ProxyConfig)s with /// [resolve_configs](crate::config_resolver::resolve_configs). -pub struct Forwarder { +pub struct Forwarder<S> { read_obfuscator: Option<Box<dyn Obfuscator>>, write_obfuscator: Option<Box<dyn Obfuscator>>, - server_connection: TcpStream, + stream: S, } -impl Forwarder { +impl Forwarder<TcpStream> { /// Create a forwarder that will connect to a given proxy endpoint. pub async fn connect(proxy_config: &crate::config::ProxyConfig) -> io::Result<Self> { let server_connection = TcpStream::connect(proxy_config.addr).await?; - let (read_obfuscator, write_obfuscator) = if let Some(obfuscation_config) = &proxy_config.obfuscation { ( @@ -40,14 +36,14 @@ impl Forwarder { Ok(Self { read_obfuscator, write_obfuscator, - server_connection, + stream: server_connection, }) } /// Forwards traffic from the client stream to the remote proxy, obfuscating and deobfuscating /// it in the process. pub async fn forward(self, client_stream: TcpStream) { - let (server_read, server_write) = self.server_connection.into_split(); + let (server_read, server_write) = self.stream.into_split(); let (client_read, client_write) = client_stream.into_split(); let _ = tokio::join!( forward(self.read_obfuscator, client_read, server_write), @@ -56,13 +52,38 @@ impl Forwarder { } } -impl tokio::io::AsyncRead for Forwarder { +async fn forward( + mut obfuscator: Option<Box<dyn Obfuscator>>, + mut source: impl AsyncRead + Unpin, + mut sink: impl AsyncWrite + Unpin, +) -> io::Result<()> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let mut buf = vec![0u8; 1024 * 64]; + while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await { + if n_bytes_read == 0 { + break; + } + let bytes_received = &mut buf[..n_bytes_read]; + + if let Some(obfuscator) = &mut obfuscator { + obfuscator.obfuscate(bytes_received); + } + sink.write_all(bytes_received).await?; + } + Ok(()) +} + +impl<S> tokio::io::AsyncRead for Forwarder<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll<std::io::Result<()>> { - let socket = std::pin::pin!(&mut self.server_connection); + use std::task::{ready, Poll}; + let socket = std::pin::pin!(&mut self.stream); match ready!(socket.poll_read(cx, buf)) { // in this case, we can read and deobfuscate. Ok(()) => { @@ -76,59 +97,39 @@ impl tokio::io::AsyncRead for Forwarder { } } -impl tokio::io::AsyncWrite for Forwarder { +impl<S> tokio::io::AsyncWrite for Forwarder<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> Poll<Result<usize, io::Error>> { - let socket = std::pin::pin!(&mut self.server_connection); - if let Err(err) = ready!(socket.poll_write_ready(cx)) { - return Poll::Ready(Err(err)); - }; - + ) -> std::task::Poll<Result<usize, io::Error>> { let mut owned_buf = buf.to_vec(); if let Some(write_obfuscator) = &mut self.write_obfuscator { write_obfuscator.obfuscate(&mut owned_buf); } - let socket = std::pin::pin!(&mut self.server_connection); - socket.poll_write(cx, &owned_buf) + let stream = std::pin::pin!(&mut self.stream); + // If the object is not ready for writing, the method returns Poll::Pending + // and arranges for the current task (via cx.waker()) to receive a notification + // when the object becomes writable or is closed. + stream.poll_write(cx, &owned_buf) } fn poll_flush( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> Poll<Result<(), io::Error>> { - std::pin::pin!(&mut self.server_connection).poll_flush(cx) + ) -> std::task::Poll<Result<(), io::Error>> { + std::pin::pin!(&mut self.stream).poll_flush(cx) } fn poll_shutdown( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> Poll<Result<(), io::Error>> { - std::pin::pin!(&mut self.server_connection).poll_shutdown(cx) - } -} - -async fn forward( - mut obfuscator: Option<Box<dyn Obfuscator>>, - mut source: impl AsyncRead + Unpin, - mut sink: impl AsyncWrite + Unpin, -) -> io::Result<()> { - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - let mut buf = vec![0u8; 1024 * 64]; - while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await { - if n_bytes_read == 0 { - break; - } - let bytes_received = &mut buf[..n_bytes_read]; - - if let Some(obfuscator) = &mut obfuscator { - obfuscator.obfuscate(bytes_received); - } - sink.write_all(bytes_received).await?; + ) -> std::task::Poll<Result<(), io::Error>> { + std::pin::pin!(&mut self.stream).poll_shutdown(cx) } - Ok(()) } #[cfg(test)] @@ -169,7 +170,7 @@ mod tests { let mut forwarder = Forwarder { read_obfuscator: Some(obfuscation_config.create_obfuscator()), write_obfuscator: Some(obfuscation_config.create_obfuscator()), - server_connection: client_conn, + stream: client_conn, }; let mut buf = vec![0u8; 1024]; while let Ok(bytes_read) = forwarder.read(&mut buf).await { |
