summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-12-20 17:13:26 +0100
committerSebastian Holmin <sebastian.holmin@mullvad.net>2024-12-20 17:36:38 +0100
commit4fc8028b301f59e0da4154aaa7d7c650a59cffeb (patch)
treeee1e7a5477c65af1d078df7b6bac6218c4d20592
parentbb087ff19f28ac839c9d892cab798d2666843047 (diff)
downloadmullvadvpn-4fc8028b301f59e0da4154aaa7d7c650a59cffeb.tar.xz
mullvadvpn-4fc8028b301f59e0da4154aaa7d7c650a59cffeb.zip
Move SocketSniffer to separate module
-rw-r--r--talpid-tunnel-config-client/src/lib.rs110
1 files changed, 58 insertions, 52 deletions
diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs
index f09cd32074..3a1b25bb46 100644
--- a/talpid-tunnel-config-client/src/lib.rs
+++ b/talpid-tunnel-config-client/src/lib.rs
@@ -5,7 +5,6 @@ use std::net::SocketAddr;
#[cfg(not(target_os = "ios"))]
use std::net::{IpAddr, Ipv4Addr};
use talpid_types::net::wireguard::{PresharedKey, PublicKey};
-use tokio::io::{AsyncRead, AsyncWrite};
use tonic::transport::Channel;
#[cfg(not(target_os = "ios"))]
use tonic::transport::Endpoint;
@@ -279,6 +278,8 @@ fn xor_assign(dst: &mut [u8; 32], src: &[u8; 32]) {
/// value has been speficically lowered, to avoid MTU issues. See the `socket` module.
#[cfg(not(target_os = "ios"))]
async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService, Error> {
+ use hyper_util::rt::tokio::TokioIo;
+
let endpoint = Endpoint::from_static("tcp://0.0.0.0:0");
let addr = SocketAddr::new(IpAddr::V4(ip), CONFIG_SERVICE_PORT);
@@ -286,13 +287,13 @@ async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService,
.connect_with_connector(service_fn(move |_| async move {
let sock = socket::TcpSocket::new()?;
let stream = sock.connect(addr).await?;
- let sniffer = SocketSniffer {
+ let sniffer = socket_sniffer::SocketSniffer {
s: stream,
rx_bytes: 0,
tx_bytes: 0,
start_time: std::time::Instant::now(),
};
- Ok::<_, std::io::Error>(hyper_util::rt::tokio::TokioIo::new(sniffer))
+ Ok::<_, std::io::Error>(TokioIo::new(sniffer))
}))
.await
.map_err(Error::GrpcConnectError)?;
@@ -300,63 +301,68 @@ async fn connect_relay_config_client(ip: Ipv4Addr) -> Result<RelayConfigService,
Ok(RelayConfigService::new(connection))
}
-struct SocketSniffer<S> {
- s: S,
- rx_bytes: usize,
- tx_bytes: usize,
- start_time: std::time::Instant,
-}
-
-impl<S> Drop for SocketSniffer<S> {
- fn drop(&mut self) {
- let duration = self.start_time.elapsed();
- log::debug!(
- "Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s",
- self.rx_bytes,
- self.tx_bytes,
- duration.as_secs()
- );
+mod socket_sniffer {
+ pub struct SocketSniffer<S> {
+ pub s: S,
+ pub rx_bytes: usize,
+ pub tx_bytes: usize,
+ pub start_time: std::time::Instant,
}
-}
+ use std::{
+ io,
+ pin::Pin,
+ task::{Context, Poll},
+ };
+
+ use tokio::io::AsyncWrite;
-impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncRead for SocketSniffer<S> {
- fn poll_read(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &mut tokio::io::ReadBuf<'_>,
- ) -> std::task::Poll<std::io::Result<()>> {
- let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_read(cx, buf));
- if bytes.is_ok() {
- self.rx_bytes += buf.filled().len();
+ use tokio::io::{AsyncRead, ReadBuf};
+
+ impl<S> Drop for SocketSniffer<S> {
+ fn drop(&mut self) {
+ let duration = self.start_time.elapsed();
+ log::debug!(
+ "Tunnel config client connection ended. RX: {} bytes, TX: {} bytes, duration: {} s",
+ self.rx_bytes,
+ self.tx_bytes,
+ duration.as_secs()
+ );
}
- std::task::Poll::Ready(bytes)
}
-}
-impl<S: AsyncRead + AsyncWrite + std::marker::Unpin> AsyncWrite for SocketSniffer<S> {
- fn poll_write(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &[u8],
- ) -> std::task::Poll<std::io::Result<usize>> {
- let bytes = std::task::ready!(std::pin::Pin::new(&mut self.s).poll_write(cx, buf));
- if bytes.is_ok() {
- self.tx_bytes += buf.len();
+ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SocketSniffer<S> {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let bytes = std::task::ready!(Pin::new(&mut self.s).poll_read(cx, buf));
+ if bytes.is_ok() {
+ self.rx_bytes += buf.filled().len();
+ }
+ Poll::Ready(bytes)
}
- std::task::Poll::Ready(bytes)
}
- fn poll_flush(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<std::io::Result<()>> {
- std::pin::Pin::new(&mut self.s).poll_flush(cx)
- }
+ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SocketSniffer<S> {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ let bytes = std::task::ready!(Pin::new(&mut self.s).poll_write(cx, buf));
+ if bytes.is_ok() {
+ self.tx_bytes += buf.len();
+ }
+ Poll::Ready(bytes)
+ }
- fn poll_shutdown(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<std::io::Result<()>> {
- std::pin::Pin::new(&mut self.s).poll_shutdown(cx)
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.s).poll_flush(cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.s).poll_shutdown(cx)
+ }
}
}