summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-10-19 21:35:07 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-10-21 09:21:47 +0200
commitf94cd1509d3b1f4a14f33c4425def750cb9879b8 (patch)
tree74700e285d1f14b373f4c5e7a2fcbea90a0071c0
parentc32e72ab7ce0940f36060b79891471ad32dfd703 (diff)
downloadmullvadvpn-f94cd1509d3b1f4a14f33c4425def750cb9879b8.tar.xz
mullvadvpn-f94cd1509d3b1f4a14f33c4425def750cb9879b8.zip
Parameterize `Forwarder` over any Async{Read,Write} stream
-rw-r--r--mullvad-encrypted-dns-proxy/src/forwarder.rs95
1 files changed, 48 insertions, 47 deletions
diff --git a/mullvad-encrypted-dns-proxy/src/forwarder.rs b/mullvad-encrypted-dns-proxy/src/forwarder.rs
index e8f366e167..f607d10978 100644
--- a/mullvad-encrypted-dns-proxy/src/forwarder.rs
+++ b/mullvad-encrypted-dns-proxy/src/forwarder.rs
@@ -1,9 +1,6 @@
//! Forward TCP traffic over various proxy configurations.
-use std::{
- io,
- task::{ready, Poll},
-};
+use std::io;
use tokio::{
io::{AsyncRead, AsyncWrite},
@@ -16,17 +13,16 @@ use crate::config::Obfuscator;
///
/// Obtain [`ProxyConfig`](crate::config::ProxyConfig)s with
/// [resolve_configs](crate::config_resolver::resolve_configs).
-pub struct Forwarder {
+pub struct Forwarder<S> {
read_obfuscator: Option<Box<dyn Obfuscator>>,
write_obfuscator: Option<Box<dyn Obfuscator>>,
- server_connection: TcpStream,
+ stream: S,
}
-impl Forwarder {
+impl Forwarder<TcpStream> {
/// Create a forwarder that will connect to a given proxy endpoint.
pub async fn connect(proxy_config: &crate::config::ProxyConfig) -> io::Result<Self> {
let server_connection = TcpStream::connect(proxy_config.addr).await?;
-
let (read_obfuscator, write_obfuscator) =
if let Some(obfuscation_config) = &proxy_config.obfuscation {
(
@@ -40,14 +36,14 @@ impl Forwarder {
Ok(Self {
read_obfuscator,
write_obfuscator,
- server_connection,
+ stream: server_connection,
})
}
/// Forwards traffic from the client stream to the remote proxy, obfuscating and deobfuscating
/// it in the process.
pub async fn forward(self, client_stream: TcpStream) {
- let (server_read, server_write) = self.server_connection.into_split();
+ let (server_read, server_write) = self.stream.into_split();
let (client_read, client_write) = client_stream.into_split();
let _ = tokio::join!(
forward(self.read_obfuscator, client_read, server_write),
@@ -56,13 +52,38 @@ impl Forwarder {
}
}
-impl tokio::io::AsyncRead for Forwarder {
+async fn forward(
+ mut obfuscator: Option<Box<dyn Obfuscator>>,
+ mut source: impl AsyncRead + Unpin,
+ mut sink: impl AsyncWrite + Unpin,
+) -> io::Result<()> {
+ use tokio::io::{AsyncReadExt, AsyncWriteExt};
+ let mut buf = vec![0u8; 1024 * 64];
+ while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await {
+ if n_bytes_read == 0 {
+ break;
+ }
+ let bytes_received = &mut buf[..n_bytes_read];
+
+ if let Some(obfuscator) = &mut obfuscator {
+ obfuscator.obfuscate(bytes_received);
+ }
+ sink.write_all(bytes_received).await?;
+ }
+ Ok(())
+}
+
+impl<S> tokio::io::AsyncRead for Forwarder<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
- let socket = std::pin::pin!(&mut self.server_connection);
+ use std::task::{ready, Poll};
+ let socket = std::pin::pin!(&mut self.stream);
match ready!(socket.poll_read(cx, buf)) {
// in this case, we can read and deobfuscate.
Ok(()) => {
@@ -76,59 +97,39 @@ impl tokio::io::AsyncRead for Forwarder {
}
}
-impl tokio::io::AsyncWrite for Forwarder {
+impl<S> tokio::io::AsyncWrite for Forwarder<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
- ) -> Poll<Result<usize, io::Error>> {
- let socket = std::pin::pin!(&mut self.server_connection);
- if let Err(err) = ready!(socket.poll_write_ready(cx)) {
- return Poll::Ready(Err(err));
- };
-
+ ) -> std::task::Poll<Result<usize, io::Error>> {
let mut owned_buf = buf.to_vec();
if let Some(write_obfuscator) = &mut self.write_obfuscator {
write_obfuscator.obfuscate(&mut owned_buf);
}
- let socket = std::pin::pin!(&mut self.server_connection);
- socket.poll_write(cx, &owned_buf)
+ let stream = std::pin::pin!(&mut self.stream);
+ // If the object is not ready for writing, the method returns Poll::Pending
+ // and arranges for the current task (via cx.waker()) to receive a notification
+ // when the object becomes writable or is closed.
+ stream.poll_write(cx, &owned_buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
- ) -> Poll<Result<(), io::Error>> {
- std::pin::pin!(&mut self.server_connection).poll_flush(cx)
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ std::pin::pin!(&mut self.stream).poll_flush(cx)
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
- ) -> Poll<Result<(), io::Error>> {
- std::pin::pin!(&mut self.server_connection).poll_shutdown(cx)
- }
-}
-
-async fn forward(
- mut obfuscator: Option<Box<dyn Obfuscator>>,
- mut source: impl AsyncRead + Unpin,
- mut sink: impl AsyncWrite + Unpin,
-) -> io::Result<()> {
- use tokio::io::{AsyncReadExt, AsyncWriteExt};
- let mut buf = vec![0u8; 1024 * 64];
- while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await {
- if n_bytes_read == 0 {
- break;
- }
- let bytes_received = &mut buf[..n_bytes_read];
-
- if let Some(obfuscator) = &mut obfuscator {
- obfuscator.obfuscate(bytes_received);
- }
- sink.write_all(bytes_received).await?;
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ std::pin::pin!(&mut self.stream).poll_shutdown(cx)
}
- Ok(())
}
#[cfg(test)]
@@ -169,7 +170,7 @@ mod tests {
let mut forwarder = Forwarder {
read_obfuscator: Some(obfuscation_config.create_obfuscator()),
write_obfuscator: Some(obfuscation_config.create_obfuscator()),
- server_connection: client_conn,
+ stream: client_conn,
};
let mut buf = vec![0u8; 1024];
while let Ok(bytes_read) = forwarder.read(&mut buf).await {