diff options
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 6 | ||||
| -rw-r--r-- | talpid-wireguard/src/stats.rs | 145 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go.rs | 111 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/mod.rs | 1 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/stats.rs | 32 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_nt.rs | 9 |
6 files changed, 149 insertions, 155 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 2a98ad3fb0..b4a0b7e1c8 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -29,15 +29,13 @@ use talpid_tunnel::tun_provider; use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use ipnetwork::IpNetwork; -#[cfg(windows)] -use talpid_types::BoxedError; use talpid_types::{ net::{ obfuscation::ObfuscatorConfig, wireguard::{PresharedKey, PrivateKey, PublicKey}, AllowedTunnelTraffic, Endpoint, TransportProtocol, }, - ErrorExt, + BoxedError, ErrorExt, }; use tokio::sync::Mutex as AsyncMutex; use tunnel_obfuscation::{ @@ -1005,7 +1003,7 @@ pub enum TunnelError { /// Error whilst trying to parse the WireGuard config to read the stats #[error(display = "Reading tunnel stats failed")] - StatsError(#[error(source)] stats::Error), + StatsError(#[error(source)] BoxedError), /// Error whilst trying to retrieve config of a WireGuard tunnel #[error(display = "Failed to get config of WireGuard tunnel")] diff --git a/talpid-wireguard/src/stats.rs b/talpid-wireguard/src/stats.rs index 79db5937a0..cdbe8318cf 100644 --- a/talpid-wireguard/src/stats.rs +++ b/talpid-wireguard/src/stats.rs @@ -1,21 +1,3 @@ -#[cfg(target_os = "linux")] -use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla}; - -#[derive(err_derive::Error, Debug, PartialEq)] -pub enum Error { - #[error(display = "Failed to parse peer pubkey from string \"_0\"")] - PubKeyParse(String, #[error(source)] hex::FromHexError), - - #[error(display = "Failed to parse integer from string \"_0\"")] - IntParse(String, #[error(source)] std::num::ParseIntError), - - #[error(display = "Device no longer exists")] - NoTunnelDevice, - - #[error(display = "Failed to obtain tunnel config")] - NoTunnelConfig, -} - /// Contains bytes sent and received through a tunnel #[derive(Default, Debug, PartialEq, Eq, Clone, Copy)] pub struct Stats { @@ -25,130 +7,3 @@ pub struct Stats { /// A map from peer pubkeys to peer stats. pub type StatsMap = std::collections::HashMap<[u8; 32], Stats>; - -impl Stats { - #[cfg(unix)] - pub fn parse_config_str(config: &str) -> Result<StatsMap, Error> { - let mut map = StatsMap::new(); - - let mut peer = None; - let mut tx_bytes = None; - let mut rx_bytes = None; - - // parts iterates over keys and values - let parts = config.split('\n').filter_map(|line| { - let mut pair = line.split('='); - let key = pair.next()?; - let value = pair.next()?; - Some((key, value)) - }); - - for (key, value) in parts { - match key { - "public_key" => { - let mut buffer = [0u8; 32]; - hex::decode_to_slice(value, &mut buffer) - .map_err(|err| Error::PubKeyParse(value.to_string(), err))?; - peer = Some(buffer); - tx_bytes = None; - rx_bytes = None; - } - "rx_bytes" => { - rx_bytes = Some( - value - .trim() - .parse() - .map_err(|err| Error::IntParse(value.to_string(), err))?, - ); - } - "tx_bytes" => { - tx_bytes = Some( - value - .trim() - .parse() - .map_err(|err| Error::IntParse(value.to_string(), err))?, - ); - } - - _ => continue, - } - - if let (Some(peer_val), Some(tx_bytes_val), Some(rx_bytes_val)) = - (peer, tx_bytes, rx_bytes) - { - map.insert( - peer_val, - Self { - tx_bytes: tx_bytes_val, - rx_bytes: rx_bytes_val, - }, - ); - peer = None; - tx_bytes = None; - rx_bytes = None; - } - } - Ok(map) - } - - #[cfg(target_os = "linux")] - pub fn parse_device_message(message: &DeviceMessage) -> StatsMap { - let mut map = StatsMap::new(); - - for nla in &message.nlas { - if let DeviceNla::Peers(peers) = nla { - for msg in peers { - let mut tx_bytes = 0; - let mut rx_bytes = 0; - let mut pub_key = None; - - for nla in &msg.0 { - match nla { - PeerNla::TxBytes(bytes) => tx_bytes = *bytes, - PeerNla::RxBytes(bytes) => rx_bytes = *bytes, - PeerNla::PublicKey(key) => pub_key = Some(*key), - _ => continue, - } - } - if let Some(key) = pub_key { - map.insert(key, Stats { tx_bytes, rx_bytes }); - } - } - } - } - - map - } -} - -#[cfg(test)] -mod test { - use super::{Error, Stats}; - - #[cfg(unix)] - #[test] - fn test_parsing() { - let valid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=2740\nrx_bytes=2396\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; - let pubkey = [0u8; 32]; - - let stats = Stats::parse_config_str(valid_input).expect("Failed to parse valid input"); - assert_eq!(stats.len(), 1); - let actual_keys: Vec<[u8; 32]> = stats.keys().cloned().collect(); - assert_eq!(actual_keys, [pubkey]); - assert_eq!(stats[&pubkey].rx_bytes, 2396); - assert_eq!(stats[&pubkey].tx_bytes, 2740); - } - - #[cfg(unix)] - #[test] - fn test_parsing_invalid_input() { - let invalid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=27error40\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; - let invalid_str = "27error40".to_string(); - let int_err = invalid_str.parse::<u64>().unwrap_err(); - - assert_eq!( - Stats::parse_config_str(invalid_input), - Err(Error::IntParse(invalid_str, int_err)) - ); - } -} diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go.rs index 24ad613659..476507ea16 100644 --- a/talpid-wireguard/src/wireguard_go.rs +++ b/talpid-wireguard/src/wireguard_go.rs @@ -11,6 +11,7 @@ use std::{ pin::Pin, }; use talpid_tunnel::tun_provider::TunProvider; +use talpid_types::BoxedError; use zeroize::Zeroize; #[cfg(target_os = "android")] @@ -201,7 +202,7 @@ impl Tunnel for WgGoTunnel { let result = Stats::parse_config_str(config_str.to_str().expect("Go strings are always UTF-8")) - .map_err(TunnelError::StatsError); + .map_err(|error| TunnelError::StatsError(BoxedError::new(error))); unsafe { // Zeroing out config string to not leave private key in memory. let slice = std::slice::from_raw_parts_mut( @@ -318,3 +319,111 @@ extern "C" { #[cfg(target_os = "android")] fn wgGetSocketV6(handle: i32) -> Fd; } + +mod stats { + use super::{Stats, StatsMap}; + + #[derive(err_derive::Error, Debug, PartialEq)] + pub enum Error { + #[error(display = "Failed to parse peer pubkey from string \"_0\"")] + PubKeyParse(String, #[error(source)] hex::FromHexError), + + #[error(display = "Failed to parse integer from string \"_0\"")] + IntParse(String, #[error(source)] std::num::ParseIntError), + } + + impl Stats { + pub fn parse_config_str(config: &str) -> std::result::Result<StatsMap, Error> { + let mut map = StatsMap::new(); + + let mut peer = None; + let mut tx_bytes = None; + let mut rx_bytes = None; + + // parts iterates over keys and values + let parts = config.split('\n').filter_map(|line| { + let mut pair = line.split('='); + let key = pair.next()?; + let value = pair.next()?; + Some((key, value)) + }); + + for (key, value) in parts { + match key { + "public_key" => { + let mut buffer = [0u8; 32]; + hex::decode_to_slice(value, &mut buffer) + .map_err(|err| Error::PubKeyParse(value.to_string(), err))?; + peer = Some(buffer); + tx_bytes = None; + rx_bytes = None; + } + "rx_bytes" => { + rx_bytes = Some( + value + .trim() + .parse() + .map_err(|err| Error::IntParse(value.to_string(), err))?, + ); + } + "tx_bytes" => { + tx_bytes = Some( + value + .trim() + .parse() + .map_err(|err| Error::IntParse(value.to_string(), err))?, + ); + } + + _ => continue, + } + + if let (Some(peer_val), Some(tx_bytes_val), Some(rx_bytes_val)) = + (peer, tx_bytes, rx_bytes) + { + map.insert( + peer_val, + Self { + tx_bytes: tx_bytes_val, + rx_bytes: rx_bytes_val, + }, + ); + peer = None; + tx_bytes = None; + rx_bytes = None; + } + } + Ok(map) + } + } + + #[cfg(test)] + mod test { + use super::super::stats::{Error, Stats}; + + #[test] + fn test_parsing() { + let valid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=2740\nrx_bytes=2396\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; + let pubkey = [0u8; 32]; + + let stats = Stats::parse_config_str(valid_input).expect("Failed to parse valid input"); + assert_eq!(stats.len(), 1); + let actual_keys: Vec<[u8; 32]> = stats.keys().cloned().collect(); + assert_eq!(actual_keys, [pubkey]); + assert_eq!(stats[&pubkey].rx_bytes, 2396); + assert_eq!(stats[&pubkey].tx_bytes, 2740); + } + + #[test] + fn test_parsing_invalid_input() { + let invalid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=27error40\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; + let invalid_str = "27error40".to_string(); + let int_err = invalid_str.parse::<u64>().unwrap_err(); + + assert_eq!( + Stats::parse_config_str(invalid_input), + Err(Error::IntParse(invalid_str, int_err)) + ); + } + } +} diff --git a/talpid-wireguard/src/wireguard_kernel/mod.rs b/talpid-wireguard/src/wireguard_kernel/mod.rs index 7a64ee3f65..15eaad4c24 100644 --- a/talpid-wireguard/src/wireguard_kernel/mod.rs +++ b/talpid-wireguard/src/wireguard_kernel/mod.rs @@ -18,6 +18,7 @@ use std::{ffi::CString, net::IpAddr}; use tokio_stream::StreamExt; mod parsers; +mod stats; pub mod wg_message; use wg_message::{DeviceMessage, DeviceNla}; diff --git a/talpid-wireguard/src/wireguard_kernel/stats.rs b/talpid-wireguard/src/wireguard_kernel/stats.rs new file mode 100644 index 0000000000..8604e8243f --- /dev/null +++ b/talpid-wireguard/src/wireguard_kernel/stats.rs @@ -0,0 +1,32 @@ +use super::wg_message::{DeviceMessage, DeviceNla, PeerNla}; +use crate::stats::{Stats, StatsMap}; + +impl Stats { + pub fn parse_device_message(message: &DeviceMessage) -> StatsMap { + let mut map = StatsMap::new(); + + for nla in &message.nlas { + if let DeviceNla::Peers(peers) = nla { + for msg in peers { + let mut tx_bytes = 0; + let mut rx_bytes = 0; + let mut pub_key = None; + + for nla in &msg.0 { + match nla { + PeerNla::TxBytes(bytes) => tx_bytes = *bytes, + PeerNla::RxBytes(bytes) => rx_bytes = *bytes, + PeerNla::PublicKey(key) => pub_key = Some(*key), + _ => continue, + } + } + if let Some(key) = pub_key { + map.insert(key, Stats { tx_bytes, rx_bytes }); + } + } + } + } + + map + } +} diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt.rs index 66e830bbc2..1b4405eba2 100644 --- a/talpid-wireguard/src/wireguard_nt.rs +++ b/talpid-wireguard/src/wireguard_nt.rs @@ -946,9 +946,9 @@ impl Tunnel for WgNtTunnel { let (_interface, peers) = device.get_config().map_err(|error| { log::error!( "{}", - error.display_chain_with_msg("Failed to obtain wg-nt tunnel config") + error.display_chain_with_msg("Failed to obtain tunnel config") ); - super::TunnelError::StatsError(super::stats::Error::NoTunnelConfig) + super::TunnelError::GetConfigError })?; for (peer, _allowed_ips) in &peers { map.insert( @@ -961,9 +961,8 @@ impl Tunnel for WgNtTunnel { } Ok(map) } else { - Err(super::TunnelError::StatsError( - super::stats::Error::NoTunnelDevice, - )) + log::error!("Failed to obtain tunnel stats as device no longer exists"); + Err(super::TunnelError::GetConfigError) } } |
