diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-02-07 17:50:24 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-02-07 17:50:24 +0100 |
| commit | 0d4ee241b523a7d024cb2aebfcbbf3924a8f3bb5 (patch) | |
| tree | c095e7b898d4d736322d267965948b30fde3bf50 /test/test-runner/src | |
| parent | 20d9c98f5ec44166b461730fec9ca292b622265f (diff) | |
| parent | eed7234599253f3d742be8bb4b6b1ecbf1299dc3 (diff) | |
| download | mullvadvpn-0d4ee241b523a7d024cb2aebfcbbf3924a8f3bb5.tar.xz mullvadvpn-0d4ee241b523a7d024cb2aebfcbbf3924a8f3bb5.zip | |
Merge branch 'testing-add-socks-server'
Diffstat (limited to 'test/test-runner/src')
| -rw-r--r-- | test/test-runner/src/forward.rs | 127 | ||||
| -rw-r--r-- | test/test-runner/src/main.rs | 19 |
2 files changed, 146 insertions, 0 deletions
diff --git a/test/test-runner/src/forward.rs b/test/test-runner/src/forward.rs new file mode 100644 index 0000000000..ec9e8a98f1 --- /dev/null +++ b/test/test-runner/src/forward.rs @@ -0,0 +1,127 @@ +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use test_rpc::net::SockHandleId; +use tokio::net::TcpListener; +use tokio::net::TcpStream; + +static SERVERS: Lazy<Mutex<HashMap<SockHandleId, Handle>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +/// Spawn a TCP forwarder that sends TCP via `via_addr` +pub async fn start_server( + bind_addr: SocketAddr, + via_addr: SocketAddr, +) -> Result<(SockHandleId, SocketAddr), test_rpc::Error> { + let next_nonce = { + static NONCE: AtomicUsize = AtomicUsize::new(0); + || NONCE.fetch_add(1, Ordering::Relaxed) + }; + let id = SockHandleId(next_nonce()); + + let handle = tcp_forward(bind_addr, via_addr).await.map_err(|error| { + log::error!("Failed to start TCP forwarder listener: {error}"); + test_rpc::Error::TcpForward + })?; + + let bind_addr = handle.local_addr(); + + let mut servers = SERVERS.lock().unwrap(); + servers.insert(id, handle); + + Ok((id, bind_addr)) +} + +/// Stop TCP forwarder given some ID returned by `start_server` +pub fn stop_server(id: SockHandleId) -> Result<(), test_rpc::Error> { + let handle = { + let mut servers = SERVERS.lock().unwrap(); + servers.remove(&id) + }; + + if let Some(handle) = handle { + handle.close(); + } + Ok(()) +} + +struct Handle { + handle: tokio::task::JoinHandle<()>, + bind_addr: SocketAddr, + clients: Arc<Mutex<Vec<tokio::task::JoinHandle<()>>>>, +} + +impl Handle { + pub fn close(&self) { + self.handle.abort(); + + let mut clients = self.clients.lock().unwrap(); + for client in clients.drain(..) { + client.abort(); + } + } + + pub fn local_addr(&self) -> SocketAddr { + self.bind_addr + } +} + +impl Drop for Handle { + fn drop(&mut self) { + self.close(); + } +} + +/// Forward TCP traffic via `proxy_addr` +async fn tcp_forward( + bind_addr: SocketAddr, + proxy_addr: SocketAddr, +) -> Result<Handle, test_rpc::Error> { + let listener = TcpListener::bind(&bind_addr).await.map_err(|error| { + log::error!("Failed to bind TCP forward socket: {error}"); + test_rpc::Error::TcpForward + })?; + let bind_addr = listener.local_addr().map_err(|error| { + log::error!("Failed to get TCP socket addr: {error}"); + test_rpc::Error::TcpForward + })?; + + let clients = Arc::new(Mutex::new(vec![])); + + let clients_copy = clients.clone(); + + let handle = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((mut client, _addr)) => { + let client_handle = tokio::spawn(async move { + let mut proxy = match TcpStream::connect(proxy_addr).await { + Ok(proxy) => proxy, + Err(error) => { + log::error!("failed to connect to TCP proxy: {error}"); + return; + } + }; + + if let Err(error) = + tokio::io::copy_bidirectional(&mut client, &mut proxy).await + { + log::error!("copy_directional failed: {error}"); + } + }); + clients_copy.lock().unwrap().push(client_handle); + } + Err(error) => { + log::error!("failed to accept TCP client: {error}"); + } + } + } + }); + Ok(Handle { + handle, + bind_addr, + clients, + }) +} diff --git a/test/test-runner/src/main.rs b/test/test-runner/src/main.rs index 1c2c301b27..74f7761cc2 100644 --- a/test/test-runner/src/main.rs +++ b/test/test-runner/src/main.rs @@ -10,6 +10,7 @@ use tarpc::context; use tarpc::server::Channel; use test_rpc::{ mullvad_daemon::{ServiceStatus, SOCKET_PATH}, + net::SockHandleId, package::Package, transport::GrpcForwarder, AppTrace, Service, @@ -22,6 +23,7 @@ use tokio::{ use tokio_util::codec::{Decoder, LengthDelimitedCodec}; mod app; +mod forward; mod logging; mod net; mod package; @@ -167,6 +169,23 @@ impl Service for TestServer { .collect()) } + async fn start_tcp_forward( + self, + _: context::Context, + bind_addr: SocketAddr, + via_addr: SocketAddr, + ) -> Result<(SockHandleId, SocketAddr), test_rpc::Error> { + forward::start_server(bind_addr, via_addr).await + } + + async fn stop_tcp_forward( + self, + _: context::Context, + id: SockHandleId, + ) -> Result<(), test_rpc::Error> { + forward::stop_server(id) + } + async fn get_interface_ip( self, _: context::Context, |
