diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-12-20 17:13:26 +0100 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-12-20 17:36:38 +0100 |
| commit | 4fc8028b301f59e0da4154aaa7d7c650a59cffeb (patch) | |
| tree | ee1e7a5477c65af1d078df7b6bac6218c4d20592 | |
| parent | bb087ff19f28ac839c9d892cab798d2666843047 (diff) | |
| download | mullvadvpn-4fc8028b301f59e0da4154aaa7d7c650a59cffeb.tar.xz mullvadvpn-4fc8028b301f59e0da4154aaa7d7c650a59cffeb.zip | |
Move SocketSniffer to separate module
| -rw-r--r-- | talpid-tunnel-config-client/src/lib.rs | 110 |
1 files changed, 58 insertions, 52 deletions
diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs index f09cd32074..3a1b25bb46 100644 --- a/talpid-tunnel-config-client/src/lib.rs +++ b/talpid-tunnel-config-client/src/lib.rs @@ -5,7 +5,6 @@ use std::net::SocketAddr; #[cfg(not(target_os = "ios"))] use std::net::{IpAddr, Ipv4Addr}; use talpid_types::net::wireguard::{PresharedKey, PublicKey}; -use tokio::io::{AsyncRead, AsyncWrite}; use tonic::transport::Channel; #[cfg(not(target_os = "ios"))] use tonic::transport::Endpoint; @@ -279,6 +278,8 @@ fn xor_assign(dst: &mut [u8; 32], src: &[u8; 32]) { /// value has been speficically lowered, to avoid MTU issues. See the `socket` module. #[cfg(not(target_os = "ios"))] async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService, Error> { + use hyper_util::rt::tokio::TokioIo; + let endpoint = Endpoint::from_static("tcp://0.0.0.0:0"); let addr = SocketAddr::new(IpAddr::V4(ip), CONFIG_SERVICE_PORT); @@ -286,13 +287,13 @@ async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService, .connect_with_connector(service_fn(move |_| async move { let sock = socket::TcpSocket::new()?; let stream = sock.connect(addr).await?; - let sniffer = SocketSniffer { + let sniffer = socket_sniffer::SocketSniffer { s: stream, rx_bytes: 0, tx_bytes: 0, start_time: std::time::Instant::now(), }; - Ok::<_, std::io::Error>(hyper_util::rt::tokio::TokioIo::new(sniffer)) + Ok::<_, std::io::Error>(TokioIo::new(sniffer)) })) .await .map_err(Error::GrpcConnectError)?; @@ -300,63 +301,68 @@ async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService, Ok(RelayConfigService::new(connection)) } -struct SocketSniffer<S> { - s: S, - rx_bytes: usize, - tx_bytes: usize, - start_time: std::time::Instant, -} - -impl<S> Drop for SocketSniffer<S> { - fn drop(&mut self) { - let duration = self.start_time.elapsed(); - log::debug!( - "Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s", - self.rx_bytes, - self.tx_bytes, - duration.as_secs() - ); +mod socket_sniffer { + pub struct SocketSniffer<S> { + pub s: S, + pub rx_bytes: usize, + pub tx_bytes: usize, + pub start_time: std::time::Instant, } -} + use std::{ + io, + pin::Pin, + task::{Context, Poll}, + }; + + use tokio::io::AsyncWrite; -impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncRead for SocketSniffer<S> { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll<std::io::Result<()>> { - let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_read(cx, buf)); - if bytes.is_ok() { - self.rx_bytes += buf.filled().len(); + use tokio::io::{AsyncRead, ReadBuf}; + + impl<S> Drop for SocketSniffer<S> { + fn drop(&mut self) { + let duration = self.start_time.elapsed(); + log::debug!( + "Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s", + self.rx_bytes, + self.tx_bytes, + duration.as_secs() + ); } - std::task::Poll::Ready(bytes) } -} -impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncWrite for SocketSniffer<S> { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll<std::io::Result<usize>> { - let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_write(cx, buf)); - if bytes.is_ok() { - self.tx_bytes += buf.len(); + impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SocketSniffer<S> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let bytes = std::task::ready!(Pin::new(&mut self.s).poll_read(cx, buf)); + if bytes.is_ok() { + self.rx_bytes += buf.filled().len(); + } + Poll::Ready(bytes) } - std::task::Poll::Ready(bytes) } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<std::io::Result<()>> { - std::pin::Pin::new(&mut self.s).poll_flush(cx) - } + impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SocketSniffer<S> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let bytes = std::task::ready!(Pin::new(&mut self.s).poll_write(cx, buf)); + if bytes.is_ok() { + self.tx_bytes += buf.len(); + } + Poll::Ready(bytes) + } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<std::io::Result<()>> { - std::pin::Pin::new(&mut self.s).poll_shutdown(cx) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.s).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.s).poll_shutdown(cx) + } } } |
