diff options
| author | Emīls <emils@mullvad.net> | 2020-01-28 12:02:16 +0000 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2020-01-28 12:02:16 +0000 |
| commit | c48e088f5cd7bf101db33e117d8ea7acf86c9580 (patch) | |
| tree | bcc0f9edf8916b8ff4ba252a1134464cf89cf810 | |
| parent | ed5a61312692268e774abe39599d6099e47f0cae (diff) | |
| parent | a42a28cf6ab60e6122ca9e0ca42d36222e5302eb (diff) | |
| download | mullvadvpn-c48e088f5cd7bf101db33e117d8ea7acf86c9580.tar.xz mullvadvpn-c48e088f5cd7bf101db33e117d8ea7acf86c9580.zip | |
Merge branch 'improve-connectivity-check'
| -rw-r--r-- | CHANGELOG.md | 1 | ||||
| -rw-r--r-- | Cargo.lock | 7 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/unix.rs | 96 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/win.rs | 135 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/connectivity_check.rs | 201 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 100 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/stats.rs | 95 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 43 |
10 files changed, 461 insertions, 220 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 970c924031..1d48bf630a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ Line wrap the file at 100 chars. Th a bit less frequently. - Increase WireGuard ping timeout from 7 to 15 seconds. - Updated `wireguard-go` to `v0.0.20200121` +- Use traffic data from WireGuard to infer connectivity to improve stability of the connection. #### Linux - DNS management with static `/etc/resolv.conf` will now work even when no diff --git a/Cargo.lock b/Cargo.lock index 599fcf0c47..8e6b03aa4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2469,6 +2469,7 @@ dependencies = [ "widestring 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", "winreg 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", + "zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -3139,6 +3140,11 @@ dependencies = [ "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "zeroize" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [metadata] "checksum aho-corasick 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)" = "81ce3d38065e618af2d7b77e10c5ad9a069859b4be3c2250f674af3840d9c8a5" "checksum aho-corasick 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)" = "58fb5e95d83b38284460a5fda7d6470aa0b8844d283a0b614b8535e880800d2d" @@ -3460,3 +3466,4 @@ dependencies = [ "checksum winres 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "ff4fb510bbfe5b8992ff15f77a2e6fe6cf062878f0eda00c0f44963a807ca5dc" "checksum ws2_32-sys 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d59cefebd0c892fa2dd6de581e937301d8552cb44489cdff035c6187cb63fa5e" "checksum x25519-dalek 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7ee1585dc1484373cbc1cee7aafda26634665cf449436fd6e24bfd1fad230538" +"checksum zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3cbac2ed2ba24cc90f5e06485ac8c7c1e5449fe8911aef4d8877218af021a5b8" diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 4230080dbd..7182c676c2 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -30,6 +30,7 @@ talpid-types = { path = "../talpid-types" } tokio-core = "0.1" tokio-executor = "0.1" uuid = { version = "0.7", features = ["v4"] } +zeroize = "1" [target.'cfg(unix)'.dependencies] diff --git a/talpid-core/src/ping_monitor/mod.rs b/talpid-core/src/ping_monitor/mod.rs index 8cdc766ad4..bc148a5892 100644 --- a/talpid-core/src/ping_monitor/mod.rs +++ b/talpid-core/src/ping_monitor/mod.rs @@ -7,4 +7,4 @@ mod imp; #[path = "win.rs"] mod imp; -pub use imp::{monitor_ping, ping, Error}; +pub use imp::{Error, Pinger}; diff --git a/talpid-core/src/ping_monitor/unix.rs b/talpid-core/src/ping_monitor/unix.rs index 8a040d476d..c1afa60fae 100644 --- a/talpid-core/src/ping_monitor/unix.rs +++ b/talpid-core/src/ping_monitor/unix.rs @@ -1,12 +1,4 @@ -#[allow(dead_code)] -// TODO: remove the lint exemption above when ping monitor is used -use std::{ - io, - net::Ipv4Addr, - sync::mpsc, - thread, - time::{Duration, Instant}, -}; +use std::{io, net::Ipv4Addr}; #[derive(err_derive::Error, Debug)] pub enum Error { @@ -17,51 +9,55 @@ pub enum Error { TimeoutError, } -pub fn monitor_ping( - ip: Ipv4Addr, - timeout_secs: u16, - interface: &str, - close_receiver: mpsc::Receiver<()>, -) -> Result<(), Error> { - while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() { - let start = Instant::now(); - internal_ping(ip, timeout_secs, &interface, false)?; - if let Some(remaining) = - Duration::from_secs(timeout_secs.into()).checked_sub(start.elapsed()) - { - thread::sleep(remaining); - } +/// A pinger that sends ICMP requests without waiting for responses +pub struct Pinger { + addr: Ipv4Addr, + interface_name: String, + processes: Vec<duct::Handle>, +} + +impl Pinger { + pub fn new(addr: Ipv4Addr, interface_name: String) -> Result<Self, Error> { + Ok(Self { + processes: vec![], + addr, + interface_name, + }) } - Ok(()) -} + // Send an ICMP packet without waiting for a reply + pub fn send_icmp(&mut self) -> Result<(), Error> { + self.try_deplete_process_list(); -pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<(), Error> { - internal_ping(ip, timeout_secs, interface, true) + let cmd = ping_cmd(self.addr, 1, &self.interface_name); + let handle = cmd.start().map_err(Error::PingError)?; + self.processes.push(handle); + Ok(()) + } + + fn try_deplete_process_list(&mut self) { + self.processes.retain(|child| { + match child.try_wait() { + // child has terminated, doesn't have to be retained + Ok(Some(_)) => false, + _ => true, + } + }); + } } -fn internal_ping( - ip: Ipv4Addr, - timeout_secs: u16, - interface: &str, - exit_on_first_reply: bool, -) -> Result<(), Error> { - let output = ping_cmd(ip, timeout_secs, interface, exit_on_first_reply) - .run() - .map_err(Error::PingError)?; - if output.status.success() { - Ok(()) - } else { - Err(Error::TimeoutError) +impl Drop for Pinger { + fn drop(&mut self) { + for child in self.processes.iter_mut() { + if let Err(e) = child.kill() { + log::error!("Failed to kill ping process - {}", e); + } + } } } -fn ping_cmd( - ip: Ipv4Addr, - timeout_secs: u16, - interface: &str, - exit_on_first_reply: bool, -) -> duct::Expression { + +fn ping_cmd(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> duct::Expression { let mut args = vec!["-n", "-i", "1"]; let timeout_flag = if cfg!(target_os = "linux") || cfg!(target_os = "android") { @@ -85,14 +81,6 @@ fn ping_cmd( args.extend_from_slice(&[interface_flag, interface]); } - if exit_on_first_reply { - if cfg!(target_os = "macos") { - args.push("-o"); - } else { - args.extend_from_slice(&["-c", "1"]) - } - } - let ip = ip.to_string(); args.push(&ip); diff --git a/talpid-core/src/ping_monitor/win.rs b/talpid-core/src/ping_monitor/win.rs index 40fa523584..c2fd16dd9d 100644 --- a/talpid-core/src/ping_monitor/win.rs +++ b/talpid-core/src/ping_monitor/win.rs @@ -1,7 +1,6 @@ use pnet_packet::{ icmp::{ self, - echo_reply::EchoReplyPacket, echo_request::{EchoRequestPacket, MutableEchoRequestPacket}, IcmpCode, IcmpPacket, IcmpType, }, @@ -11,9 +10,8 @@ use socket2::{Domain, Protocol, Socket, Type}; use std::{ io, net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::mpsc, thread, - time::{Duration, Instant}, + time::Duration, }; const SEND_RETRY_ATTEMPTS: u32 = 10; @@ -37,30 +35,6 @@ pub enum Error { TimeoutError, } -pub fn monitor_ping( - ip: Ipv4Addr, - timeout_secs: u16, - interface: &str, - close_receiver: mpsc::Receiver<()>, -) -> Result<()> { - let mut pinger = Pinger::new(ip, interface)?; - while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() { - let start = Instant::now(); - pinger.send_ping(Duration::from_secs(timeout_secs.into()))?; - if let Some(remaining) = - Duration::from_secs(timeout_secs.into()).checked_sub(start.elapsed()) - { - thread::sleep(remaining); - } - } - - Ok(()) -} - -pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<()> { - Pinger::new(ip, interface)?.send_ping(Duration::from_secs(timeout_secs.into())) -} - type Result<T> = std::result::Result<T, Error>; pub struct Pinger { @@ -70,10 +44,9 @@ pub struct Pinger { seq: u16, } -const NUM_PINGS_TO_SEND: usize = 3; impl Pinger { - pub fn new(addr: Ipv4Addr, _interface_name: &str) -> Result<Self> { + pub fn new(addr: Ipv4Addr, _interface_name: String) -> Result<Self> { let sock = Socket::new(Domain::ipv4(), Type::raw(), Some(Protocol::icmpv4())) .map_err(Error::OpenError)?; sock.set_nonblocking(true).map_err(Error::OpenError)?; @@ -87,19 +60,13 @@ impl Pinger { }) } - /// Sends an ICMP echo request - pub fn send_ping(&mut self, timeout: Duration) -> Result<()> { + pub fn send_icmp(&mut self) -> Result<()> { let dest = SocketAddr::new(IpAddr::from(self.addr), 0); - let requests = (0..NUM_PINGS_TO_SEND) - .map(|_| { - let request = self.next_ping_request(); - self.send_ping_request(&request, dest)?; - Ok(request) - }) - .collect::<Result<Vec<_>>>()?; - self.wait_for_response(Instant::now() + timeout, &requests) + let request = self.next_ping_request(); + self.send_ping_request(&request, dest) } + fn send_ping_request( &mut self, request: &EchoRequestPacket<'static>, @@ -149,94 +116,4 @@ impl Pinger { self.seq += 1; seq } - - - fn wait_for_response( - &mut self, - deadline: Instant, - requests: &[EchoRequestPacket<'_>], - ) -> Result<()> { - let mut recv_buffer = [0u8; 4096]; - let mut bytes_received = 0; - let mut success = false; - let mut requests = requests.iter().map(|req| (false, req)).collect::<Vec<_>>(); - 'outer: while Instant::now() < deadline { - match self.sock.recv(&mut recv_buffer) { - Ok(recv_len) => { - bytes_received += recv_len; - if recv_len > 20 { - // have to slice off first 20 bytes for the IP header. - if let Some(reply) = Self::parse_response(&recv_buffer[20..recv_len]) { - for (used, req) in requests.iter_mut() { - if *used { - continue; - } - if Self::request_and_response_match(req, &reply) { - *used = true; - success = true; - continue 'outer; - } - } - } - } - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if success { - return Ok(()); - } - std::thread::sleep(Duration::from_millis(100)); - continue; - } - Err(e) => { - return Err(Error::ReadError(e)); - } - } - } - log::debug!( - "Timing out whilst waiting for ICMP response after receiving {} bytes", - bytes_received - ); - Err(Error::TimeoutError) - } - - fn request_and_response_match(req: &EchoRequestPacket<'_>, resp: &EchoReplyPacket<'_>) -> bool { - if req.get_identifier() != resp.get_identifier() { - log::debug!( - "Expected idnetifier {} - got {}", - req.get_identifier(), - resp.get_identifier() - ); - return false; - } - - if req.get_sequence_number() != resp.get_sequence_number() { - log::debug!( - "Expected sequence number {} - got {}", - req.get_sequence_number(), - resp.get_sequence_number() - ); - return false; - } - - if req.payload() != resp.payload() { - log::debug!( - "Expected payload {:?} - got {:?}", - req.payload(), - resp.payload() - ); - return false; - } - - return true; - } - - fn parse_response<'a>(buffer: &'a [u8]) -> Option<EchoReplyPacket<'a>> { - let icmp_checksum = icmp::checksum(&IcmpPacket::new(buffer)?); - let reply = EchoReplyPacket::new(buffer)?; - if reply.get_checksum() == icmp_checksum { - Some(reply) - } else { - None - } - } } diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs new file mode 100644 index 0000000000..7b23198d2f --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -0,0 +1,201 @@ +use crate::{ping_monitor::Pinger, tunnel::wireguard::stats::Stats}; +use std::{ + net::Ipv4Addr, + sync::{mpsc, Mutex, Weak}, + time::{Duration, Instant}, +}; + +use super::{Error, Tunnel}; + +/// Sleep time used when initially establishing connectivity +const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); +/// Sleep time used when checking if an established connection is still working. +const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); + + +/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is +/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic +/// is received. +const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5); +/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will +/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received. +const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120); +/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this +/// timeout is reached, it is assumed that the connection is lost. +const PING_TIMEOUT: Duration = Duration::from_secs(15); +/// Number of seconds to wait between sending ICMP packets +const SECONDS_PER_PING: Duration = Duration::from_secs(3); + + +/// Verifies if a connection to a tunnel is working. +/// The connectivity monitor is biased to receiving traffic - it is expected that all outgoing +/// traffic will be answered with a response. +/// +/// The connectivity monitor tries to opportunistically use information about how much data has +/// been sent through the tunnel to infer connectivity. This is done by reading the traffic data +/// from the tunnel and recording the time of the reading - the connectivity monitor only stores +/// the timestamp of when was the last time an increase in either incoming or outgoing traffic was +/// observed. The connectivity monitor tries to read the data at a set interval, and the connection +/// is considered to be working if the incoming traffic timestamp has been incremented in a given +/// timeout. A connection is considered to be established the first time an increase in incoming +/// traffic is observed. +/// +/// The connectivity monitor will start sending pings and start the countdown to `PING_TIMEOUT` in +/// the following cases: +/// - In case that we have observed a bump in the outgoing traffic but no coressponding incoming +/// traffic for longer than `BYTES_RX_TIMEOUT`, then the monitor will start pinging. +/// - In case that no increase in outgoing or incoming traffic has been observed for longer than +/// `TRAFFIC_TIMEOUT`, then the monitor will start pinging as well. +/// +/// Once a connection established, a connection is only considered broken once the connectivity +/// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. +pub struct ConnectivityMonitor { + tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, + last_stats: Stats, + tx_timestamp: Instant, + rx_timestamp: Instant, + initial_ping_timestamp: Option<Instant>, + num_pings_sent: u32, + pinger: Pinger, + close_receiver: mpsc::Receiver<()>, +} + +impl ConnectivityMonitor { + pub fn new( + addr: Ipv4Addr, + interface: String, + tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>, + close_receiver: mpsc::Receiver<()>, + ) -> Result<Self, Error> { + let pinger = Pinger::new(addr, interface).map_err(Error::PingError)?; + + let now = Instant::now(); + + Ok(Self { + tunnel_handle, + last_stats: Default::default(), + tx_timestamp: now, + rx_timestamp: now, + initial_ping_timestamp: None, + num_pings_sent: 0, + pinger, + close_receiver, + }) + } + + // checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is + // successfull at the start of a connection. + pub fn establish_connectivity(&mut self) -> Result<bool, Error> { + if self.last_stats.rx_bytes > 0 { + return Ok(true); + } + + let start = Instant::now(); + while start.elapsed() < PING_TIMEOUT { + if self.check_connectivity()? { + return Ok(true); + } + if self.should_shut_down(DELAY_ON_INITIAL_SETUP) { + return Ok(false); + } + } + Ok(false) + } + + pub fn run(&mut self) -> Result<(), Error> { + self.wait_loop(REGULAR_LOOP_SLEEP) + } + + /// Returns true if monitor should be shut down + fn should_shut_down(&mut self, timeout: Duration) -> bool { + match self.close_receiver.recv_timeout(timeout) { + Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, + Err(mpsc::RecvTimeoutError::Timeout) => false, + } + } + + fn wait_loop(&mut self, iter_delay: Duration) -> Result<(), Error> { + while self.check_connectivity()? && !self.should_shut_down(iter_delay) {} + Ok(()) + } + + /// Returns true if connection is established + fn check_connectivity(&mut self) -> Result<bool, Error> { + let now = Instant::now(); + match self.get_stats() { + None => Ok(false), + Some(new_stats) => { + let new_stats = new_stats?; + let last_stats = self.last_stats; + self.last_stats = new_stats; + + if new_stats.tx_bytes > last_stats.tx_bytes { + self.tx_timestamp = now; + } + + if new_stats.rx_bytes > last_stats.rx_bytes { + self.rx_timestamp = now; + // resetting ping + self.initial_ping_timestamp = None; + self.num_pings_sent = 0; + return Ok(true); + } + + self.maybe_send_ping()?; + Ok(!self.ping_timed_out() && self.last_stats.rx_bytes > 0) + } + } + } + + /// If None is returned, then the underlying tunnel has already been closed and all subsequent + /// calls will also return None. + fn get_stats(&self) -> Option<Result<Stats, Error>> { + self.tunnel_handle + .upgrade()? + .lock() + .ok()? + .as_ref() + .map(|tunnel| tunnel.get_config()) + } + + fn maybe_send_ping(&mut self) -> Result<(), Error> { + // Only send out a ping if we haven't received a byte in a while or no traffic has flowed + // in the last 2 minutes, but if a ping already has been sent out, only send one out every + // 3 seconds. + if (self.rx_timed_out() || self.traffic_timed_out()) + && self + .initial_ping_timestamp + .map(|initial_ping_timestamp| { + initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING + }) + .unwrap_or(true) + { + self.pinger.send_icmp().map_err(Error::PingError)?; + if self.initial_ping_timestamp.is_none() { + self.initial_ping_timestamp = Some(Instant::now()); + } + self.num_pings_sent += 1; + } + Ok(()) + } + + // check if last time data was received is too long ago + fn rx_timed_out(&self) -> bool { + // if last sent bytes were sent after last received bytes + self.tx_timestamp > self.rx_timestamp + // and the response hasn't been seen for BYTES_RX_TIMEOUT + && self.rx_timestamp.elapsed() >= BYTES_RX_TIMEOUT + } + + // check if no bytes have been sent or received in a while + fn traffic_timed_out(&self) -> bool { + self.rx_timestamp.elapsed() >= TRAFFIC_TIMEOUT + || self.tx_timestamp.elapsed() >= TRAFFIC_TIMEOUT + } + + fn ping_timed_out(&self) -> bool { + self.initial_ping_timestamp + .map(|initial_ping_timestamp| initial_ping_timestamp.elapsed() > PING_TIMEOUT) + .unwrap_or(false) + } +} diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 9b84a83cf0..13bff72d27 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -5,17 +5,20 @@ use self::config::Config; use super::tun_provider; use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; use crate::{ping_monitor, routing}; -use std::{collections::HashMap, io, path::Path, sync::mpsc}; -use talpid_types::ErrorExt; +use std::{ + collections::HashMap, + io, + path::Path, + sync::{mpsc, Arc, Mutex}, +}; pub mod config; +mod connectivity_check; +mod stats; pub mod wireguard_go; pub use self::wireguard_go::WgGoTunnel; -// amount of seconds to run `ping` until it returns. -const PING_TIMEOUT: u16 = 15; - pub type Result<T> = std::result::Result<T, Error>; /// Errors that can happen in the Wireguard tunnel monitor. @@ -46,6 +49,10 @@ pub enum Error { #[error(display = "Failed to stop wireguard tunnel - {}", status)] StopWireguardError { status: i32 }, + /// Failed to get tunnel config + #[error(display = "Failed to obtain tunnel config")] + GetConfigError, + /// Failed to set ip addresses on tunnel interface. #[cfg(target_os = "windows")] #[error(display = "Failed to set IP addresses on WireGuard interface")] @@ -73,15 +80,28 @@ pub enum Error { #[error(display = "Failed to duplicate tunnel file descriptor for wireguard-go")] FdDuplicationError(#[error(source)] nix::Error), + /// Error whilst trying to read stats + #[error(display = "Reading tunnel stats failed")] + StatsError(#[error(source)] stats::Error), + + /// Tunnel handle is invalid + #[error(display = "Tunnel handle is invalid")] + InvalidTunnelHandle, + /// Pinging timed out. #[error(display = "Ping timed out")] - PingTimeoutError, + PingError(#[error(source)] ping_monitor::Error), + + /// Tunnel timed out + #[error(display = "Tunnel timed out")] + TimeoutError, } + /// Spawns and monitors a wireguard tunnel pub struct WireguardMonitor { /// Tunnel implementation - tunnel: Box<dyn Tunnel>, + tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, /// Route manager route_handle: routing::RouteManager, /// Callback to signal tunnel events @@ -104,20 +124,21 @@ impl WireguardMonitor { tun_provider, Self::get_tunnel_routes(config), )?); - let iface_name = tunnel.get_interface_name(); + let iface_name = tunnel.get_interface_name().to_string(); #[cfg_attr(not(windows), allow(unused_mut))] - let mut route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config)) + let mut route_handle = routing::RouteManager::new(Self::get_routes(&iface_name, &config)) .map_err(Error::SetupRoutingError)?; #[cfg(target_os = "windows")] route_handle .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()); + let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = mpsc::channel(); let (pinger_tx, pinger_rx) = mpsc::channel(); let monitor = WireguardMonitor { - tunnel, + tunnel: Arc::new(Mutex::new(Some(tunnel))), route_handle, event_callback, close_msg_sender, @@ -125,29 +146,28 @@ impl WireguardMonitor { pinger_stop_sender: pinger_tx, }; - let metadata = monitor.tunnel_metadata(&config); - let iface_name = monitor.tunnel.get_interface_name().to_string(); - let gateway = config.ipv4_gateway.into(); + let metadata = Self::tunnel_metadata(&iface_name, &config); + let gateway = config.ipv4_gateway; let close_sender = monitor.close_msg_sender.clone(); + let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( + gateway, + iface_name, + Arc::downgrade(&monitor.tunnel), + pinger_rx, + )?; std::thread::spawn(move || { - match ping_monitor::ping(gateway, PING_TIMEOUT, &iface_name) { - Ok(()) => { - (on_event)(TunnelEvent::Up(metadata)); - - if let Err(error) = - ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx) - { - log::trace!("{}", error.display_chain_with_msg("Ping monitor failed")); - } - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("First ping to gateway failed") - ); + match connectivity_monitor.establish_connectivity() { + Ok(true) => (on_event)(TunnelEvent::Up(metadata)), + Ok(false) => return, + Err(err) => { + log::error!("ConnectivityMonitor failed: {}", err); + return; } } + if let Err(err) = connectivity_monitor.run() { + log::error!("Connectivity monitor failed - {}", err); + } let _ = close_sender.send(CloseMsg::PingErr); }); @@ -163,7 +183,7 @@ impl WireguardMonitor { pub fn wait(mut self) -> Result<()> { let wait_result = match self.close_msg_receiver.recv() { - Ok(CloseMsg::PingErr) => Err(Error::PingTimeoutError), + Ok(CloseMsg::PingErr) => Err(Error::TimeoutError), Ok(CloseMsg::Stop) => Ok(()), Err(_) => Ok(()), }; @@ -175,13 +195,25 @@ impl WireguardMonitor { // routes that were set. self.route_handle.stop(); - if let Err(e) = self.tunnel.stop() { - log::error!("Failed to stop tunnel - {}", e); - } + self.stop_tunnel(); + (self.event_callback)(TunnelEvent::Down); wait_result } + fn stop_tunnel(&mut self) { + match self.tunnel.lock().expect("Tunnel lock poisoned").take() { + Some(tunnel) => { + if let Err(e) = tunnel.stop() { + log::error!("Failed to stop tunnel - {}", e); + } + } + None => { + log::debug!("Tunnel already stopped"); + } + } + } + fn get_tunnel_routes(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { config .peers @@ -218,8 +250,7 @@ impl WireguardMonitor { routes } - fn tunnel_metadata(&self, config: &Config) -> TunnelMetadata { - let interface_name = self.tunnel.get_interface_name(); + fn tunnel_metadata(interface_name: &str, config: &Config) -> TunnelMetadata { TunnelMetadata { interface: interface_name.to_string(), ips: config.tunnel.addresses.clone(), @@ -250,4 +281,5 @@ impl CloseHandle { pub trait Tunnel: Send { fn get_interface_name(&self) -> &str; fn stop(self: Box<Self>) -> Result<()>; + fn get_config(&self) -> Result<stats::Stats>; } diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs new file mode 100644 index 0000000000..545a49d688 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/stats.rs @@ -0,0 +1,95 @@ +#[derive(err_derive::Error, Debug, PartialEq)] +#[error(no_from)] +pub enum Error { + #[error(display = "Failed to parse integer from string \"_0\"")] + IntParseError(String, #[error(source)] std::num::ParseIntError), + + #[error(display = "Config key not found")] + KeyNotFoundError, +} + +/// Contains bytes sent and received through a tunnel +#[derive(Default, Debug, PartialEq, Clone, Copy)] +pub struct Stats { + pub tx_bytes: u64, + pub rx_bytes: u64, +} + +impl Stats { + pub fn parse_config_str(config: &str) -> Result<Self, Error> { + let mut tx_bytes = None; + let mut rx_bytes = None; + + // parts iterates over keys and values + let parts = config.split('\n').filter_map(|line| { + let mut pair = line.split('='); + let key = pair.next()?; + let value = pair.next()?; + Some((key, value)) + }); + + for (key, value) in parts { + match key { + "rx_bytes" => { + rx_bytes = Some( + value + .trim() + .parse() + .map_err(|err| Error::IntParseError(value.to_string(), err))?, + ); + } + "tx_bytes" => { + tx_bytes = Some( + value + .trim() + .parse() + .map_err(|err| Error::IntParseError(value.to_string(), err))?, + ); + } + + _ => continue, + } + } + + match (tx_bytes, rx_bytes) { + (Some(tx_bytes), Some(rx_bytes)) => Ok(Self { tx_bytes, rx_bytes }), + _ => Err(Error::KeyNotFoundError), + } + } +} + + +#[cfg(test)] +mod test { + use super::{Error, Stats}; + + #[test] + fn test_parsing() { + let valid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=2740\nrx_bytes=2396\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; + + let stats = Stats::parse_config_str(valid_input).expect("Failed to parse valid input"); + assert_eq!(stats.rx_bytes, 2396); + assert_eq!(stats.tx_bytes, 2740); + } + + #[test] + fn test_parsing_invalid_input() { + let invalid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=27error40\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; + let invalid_str = "27error40".to_string(); + let int_err = invalid_str.parse::<u64>().unwrap_err(); + + assert_eq!( + Stats::parse_config_str(invalid_input), + Err(Error::IntParseError(invalid_str, int_err)) + ); + } + + #[test] + fn test_parsing_missing_keys() { + let invalid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=2740\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; + assert_eq!( + Stats::parse_config_str(invalid_input), + Err(Error::KeyNotFoundError) + ); + } +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index cff7f2f946..91af5aa7a2 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -1,7 +1,12 @@ -use super::{Config, Error, Result, Tunnel}; +use super::{stats::Stats, Config, Error, Result, Tunnel}; use crate::tunnel::tun_provider::TunProvider; use ipnetwork::IpNetwork; -use std::{ffi::CString, path::Path}; +use std::{ + ffi::{c_void, CStr, CString}, + os::raw::c_char, + path::Path, +}; +use zeroize::Zeroize; #[cfg(target_os = "android")] use crate::tunnel::tun_provider; @@ -315,6 +320,34 @@ impl Tunnel for WgGoTunnel { &self.interface_name } + fn get_config(&self) -> Result<Stats> { + let config_str = unsafe { + let ptr = wgGetConfig(self.handle.unwrap()); + if ptr.is_null() { + log::error!("Failed to get config !"); + return Err(Error::GetConfigError); + } + + CStr::from_ptr(ptr) + }; + + let result = + Stats::parse_config_str(config_str.to_str().expect("Go strings are always UTF-8")) + .map_err(Error::StatsError); + unsafe { + // Zeroing out config string to not leave private key in memory. + let slice = std::slice::from_raw_parts_mut( + config_str.as_ptr() as *mut c_char, + config_str.to_bytes().len(), + ); + slice.zeroize(); + + wgFreePtr(config_str.as_ptr() as *mut c_void); + } + + result + } + fn stop(mut self: Box<Self>) -> Result<()> { self.stop_tunnel() } @@ -373,6 +406,12 @@ extern "C" { fn wgTurnOff(handle: i32) -> i32; // Returns the file descriptor of the tunnel IPv4 socket. + fn wgGetConfig(handle: i32) -> *mut std::os::raw::c_char; + + // Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig + fn wgFreePtr(ptr: *mut c_void); + + // Returns the file descriptor of the tunnel IPv4 socket. #[cfg(target_os = "android")] fn wgGetSocketV4(handle: i32) -> Fd; |
