diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-10-05 15:17:40 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-10-05 15:17:40 +0200 |
| commit | bb1b0d0b5d33893dc65c09e80ca957daf669983a (patch) | |
| tree | b5695f0b0b72917a250f9a197fb2b93de911f011 | |
| parent | 017dec4c3f522aecf786004ce09c981ef17e1464 (diff) | |
| parent | 4399583c323784d5352a659dc32fcee39c87b9b3 (diff) | |
| download | mullvadvpn-bb1b0d0b5d33893dc65c09e80ca957daf669983a.tar.xz mullvadvpn-bb1b0d0b5d33893dc65c09e80ca957daf669983a.zip | |
Merge branch 'macos-add-routing-debounce' into main
| -rw-r--r-- | talpid-core/src/offline/macos.rs | 8 | ||||
| -rw-r--r-- | talpid-routing/src/debounce.rs | 95 | ||||
| -rw-r--r-- | talpid-routing/src/lib.rs | 3 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/mod.rs | 59 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 11 | ||||
| -rw-r--r-- | talpid-routing/src/windows/default_route_monitor.rs | 96 |
6 files changed, 150 insertions, 122 deletions
diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs index 7f13638ec9..9be5aae8a8 100644 --- a/talpid-core/src/offline/macos.rs +++ b/talpid-core/src/offline/macos.rs @@ -12,9 +12,6 @@ use std::sync::{ }; use talpid_routing::{DefaultRouteEvent, RouteManagerHandle}; -/// How long to wait before announcing changes to the offline state -//const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(2); - #[derive(err_derive::Error, Debug)] pub enum Error { #[error(display = "Failed to initialize route monitor")] @@ -120,11 +117,6 @@ pub async fn spawn_monitor( None => return, }; - // Debounce event updates - // FIXME: Debounce is disabled because the DNS config can get messed up - // when switching between networks otherwise. - //tokio::time::sleep(DEBOUNCE_INTERVAL).await; - if prev_notified.swap(new_connectivity, Ordering::AcqRel) == new_connectivity { // We don't care about network changes here return; diff --git a/talpid-routing/src/debounce.rs b/talpid-routing/src/debounce.rs new file mode 100644 index 0000000000..ba1e52250c --- /dev/null +++ b/talpid-routing/src/debounce.rs @@ -0,0 +1,95 @@ +#![allow(dead_code)] + +use std::{ + sync::mpsc::{channel, RecvTimeoutError, Sender}, + time::{Duration, Instant}, +}; + +/// BurstGuard is a wrapper for a function that protects that function from being called too many +/// times in a short amount of time. To call the function use `burst_guard.trigger()`, at that point +/// `BurstGuard` will wait for `buffer_period` and if no more calls to `trigger` are made then it +/// will call the wrapped function. If another call to `trigger` is made during this wait then it +/// will wait another `buffer_period`, this happens over and over until either +/// `longest_buffer_period` time has elapsed or until no call to `trigger` has been made in +/// `buffer_period`. At which point the wrapped function will be called. +pub struct BurstGuard { + sender: Sender<BurstGuardEvent>, +} + +enum BurstGuardEvent { + Trigger, + Shutdown(Sender<()>), +} + +impl BurstGuard { + pub fn new<F: Fn() + Send + 'static>(callback: F) -> Self { + /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent + /// before it calls the callback. + const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); + /// This is the longest period that the `BurstGuard` will wait from the first trigger till + /// it calls the callback. + const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); + + let (sender, listener) = channel(); + std::thread::spawn(move || { + // The `stop` implementation assumes that this thread will not call `callback` again + // if the listener has been dropped. + while let Ok(message) = listener.recv() { + match message { + BurstGuardEvent::Trigger => { + let start = Instant::now(); + loop { + match listener.recv_timeout(BURST_BUFFER_PERIOD) { + Ok(BurstGuardEvent::Trigger) => { + if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD { + callback(); + break; + } + } + Ok(BurstGuardEvent::Shutdown(tx)) => { + let _ = tx.send(()); + return; + } + Err(RecvTimeoutError::Timeout) => { + callback(); + break; + } + Err(RecvTimeoutError::Disconnected) => { + break; + } + } + } + } + BurstGuardEvent::Shutdown(tx) => { + let _ = tx.send(()); + return; + } + } + } + }); + Self { sender } + } + + /// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further + /// calls to `callback`. + pub fn stop(self) { + let (sender, listener) = channel(); + // If we could not send then it means the thread has already shut down and we can return + if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() { + // We do not care what the result is, if it is OK it means the thread shut down, if + // it is Err it also means it shut down. + let _ = listener.recv(); + } + } + + /// Stop without waiting for in-flight events to complete. + pub fn stop_nonblocking(self) { + let (sender, _listener) = channel(); + let _ = self.sender.send(BurstGuardEvent::Shutdown(sender)); + } + + /// Asynchronously trigger burst + pub fn trigger(&self) { + self.sender.send(BurstGuardEvent::Trigger).unwrap(); + } +} diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index 8985b4e394..dd5fd3a761 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -6,6 +6,9 @@ use ipnetwork::IpNetwork; use std::{fmt, net::IpAddr}; +#[cfg(any(target_os = "windows", target_os = "macos"))] +mod debounce; + #[cfg(target_os = "windows")] #[path = "windows/mod.rs"] mod imp; diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs index c2cd486f71..7aceeba78f 100644 --- a/talpid-routing/src/unix/macos/mod.rs +++ b/talpid-routing/src/unix/macos/mod.rs @@ -1,4 +1,4 @@ -use crate::{NetNode, Node, RequiredRoute, Route}; +use crate::{debounce::BurstGuard, NetNode, Node, RequiredRoute, Route}; use futures::{ channel::mpsc, @@ -7,11 +7,11 @@ use futures::{ }; use ipnetwork::IpNetwork; use nix::sys::socket::{AddressFamily, SockaddrLike, SockaddrStorage}; -use std::pin::Pin; use std::{ collections::{BTreeMap, HashSet}, time::Duration, }; +use std::{pin::Pin, sync::Weak}; use talpid_types::ErrorExt; use watch::RoutingTable; @@ -85,6 +85,7 @@ pub struct RouteManagerImpl { applied_routes: BTreeMap<RouteDestination, RouteMessage>, v4_default_route: Option<data::RouteMessage>, v6_default_route: Option<data::RouteMessage>, + update_trigger: BurstGuard, default_route_listeners: Vec<mpsc::UnboundedSender<DefaultRouteEvent>>, check_default_routes_restored: Pin<Box<dyn FusedStream<Item = ()> + Send>>, } @@ -92,8 +93,16 @@ pub struct RouteManagerImpl { impl RouteManagerImpl { /// Create new route manager #[allow(clippy::unused_async)] - pub async fn new() -> Result<Self> { + pub(crate) async fn new( + manage_tx: Weak<mpsc::UnboundedSender<RouteManagerCommand>>, + ) -> Result<Self> { let routing_table = RoutingTable::new().map_err(Error::RoutingTable)?; + let update_trigger = BurstGuard::new(move || { + let Some(manage_tx) = manage_tx.upgrade() else { + return; + }; + let _ = manage_tx.unbounded_send(RouteManagerCommand::RefreshRoutes); + }); Ok(Self { routing_table, non_tunnel_routes: HashSet::new(), @@ -102,6 +111,7 @@ impl RouteManagerImpl { applied_routes: BTreeMap::new(), v4_default_route: None, v6_default_route: None, + update_trigger, default_route_listeners: vec![], check_default_routes_restored: Box::pin(futures::stream::pending()), }) @@ -129,10 +139,12 @@ impl RouteManagerImpl { ); }); + let mut completion_tx = None; + loop { futures::select_biased! { route_message = self.routing_table.next_message().fuse() => { - self.handle_route_message(route_message).await; + self.handle_route_message(route_message); } _ = self.check_default_routes_restored.next() => { @@ -148,11 +160,8 @@ impl RouteManagerImpl { command = manage_rx.next() => { match command { Some(RouteManagerCommand::Shutdown(tx)) => { - if let Err(err) = self.cleanup_routes().await { - log::error!("Failed to clean up routes: {err}"); - } - let _ = tx.send(()); - return; + completion_tx = Some(tx); + break; }, Some(RouteManagerCommand::NewDefaultRouteListener(tx)) => { @@ -214,6 +223,11 @@ impl RouteManagerImpl { log::error!("Failed to clean up rotues: {err}"); } }, + Some(RouteManagerCommand::RefreshRoutes) => { + if let Err(error) = self.refresh_routes().await { + log::error!("Failed to refresh routes: {error}") + } + }, None => { break; } @@ -225,6 +239,12 @@ impl RouteManagerImpl { if let Err(err) = self.cleanup_routes().await { log::error!("Failed to clean up routing table when shutting down: {err}"); } + + self.update_trigger.stop_nonblocking(); + + if let Some(tx) = completion_tx { + let _ = tx.send(()); + } } async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { @@ -287,7 +307,7 @@ impl RouteManagerImpl { Ok(()) } - async fn handle_route_message( + fn handle_route_message( &mut self, message: std::result::Result<RouteSocketMessage, watch::Error>, ) { @@ -303,18 +323,19 @@ impl RouteManagerImpl { log::error!("Failed to process deleted route: {err}"); } } - - if let Err(error) = self.handle_route_socket_message().await { - log::error!("Failed to process route change: {error}"); + if route.errno() == 0 && route.is_default().unwrap_or(true) { + self.update_trigger.trigger(); } } - Ok(RouteSocketMessage::AddRoute(_)) - | Ok(RouteSocketMessage::ChangeRoute(_)) - | Ok(RouteSocketMessage::AddAddress(_) | RouteSocketMessage::DeleteAddress(_)) => { - if let Err(error) = self.handle_route_socket_message().await { - log::error!("Failed to process route/address change: {error}"); + Ok(RouteSocketMessage::AddRoute(route)) + | Ok(RouteSocketMessage::ChangeRoute(route)) => { + if route.errno() == 0 && route.is_default().unwrap_or(true) { + self.update_trigger.trigger(); } } + Ok(RouteSocketMessage::AddAddress(_) | RouteSocketMessage::DeleteAddress(_)) => { + self.update_trigger.trigger(); + } // ignore all other message types Ok(_) => {} Err(err) => { @@ -329,7 +350,7 @@ impl RouteManagerImpl { /// * At the same time, update the route used by non-tunnel interfaces to reach the relay/VPN /// server. The gateway of the relay route is set to the first interface in the network /// service order that has a working ifscoped default route. - async fn handle_route_socket_message(&mut self) -> Result<()> { + async fn refresh_routes(&mut self) -> Result<()> { self.update_best_default_route(interface::Family::V4) .await?; self.update_best_default_route(interface::Family::V6) diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 757d3775fc..02dac8ac0f 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -7,7 +7,7 @@ use futures::channel::{ mpsc::{self, UnboundedSender}, oneshot, }; -use std::{collections::HashSet, io}; +use std::{collections::HashSet, io, sync::Arc}; #[cfg(any(target_os = "linux", target_os = "macos"))] use futures::stream::Stream; @@ -55,7 +55,7 @@ pub enum Error { /// Handle to a route manager. #[derive(Clone)] pub struct RouteManagerHandle { - tx: UnboundedSender<RouteManagerCommand>, + tx: Arc<UnboundedSender<RouteManagerCommand>>, } impl RouteManagerHandle { @@ -181,6 +181,8 @@ pub(crate) enum RouteManagerCommand { ClearRoutes, Shutdown(oneshot::Sender<()>), #[cfg(target_os = "macos")] + RefreshRoutes, + #[cfg(target_os = "macos")] NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>), #[cfg(target_os = "macos")] GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>), @@ -227,7 +229,7 @@ pub enum CallbackMessage { /// If a destination has to be routed through the default node, /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { - manage_tx: Option<UnboundedSender<RouteManagerCommand>>, + manage_tx: Option<Arc<UnboundedSender<RouteManagerCommand>>>, runtime: tokio::runtime::Handle, } @@ -238,11 +240,14 @@ impl RouteManager { #[cfg(target_os = "linux")] table_id: u32, ) -> Result<Self, Error> { let (manage_tx, manage_rx) = mpsc::unbounded(); + let manage_tx = Arc::new(manage_tx); let manager = imp::RouteManagerImpl::new( #[cfg(target_os = "linux")] fwmark, #[cfg(target_os = "linux")] table_id, + #[cfg(target_os = "macos")] + Arc::downgrade(&manage_tx), ) .await?; tokio::spawn(manager.run(manage_rx)); diff --git a/talpid-routing/src/windows/default_route_monitor.rs b/talpid-routing/src/windows/default_route_monitor.rs index dbc0615e43..d42dbc91de 100644 --- a/talpid-routing/src/windows/default_route_monitor.rs +++ b/talpid-routing/src/windows/default_route_monitor.rs @@ -2,14 +2,11 @@ use super::{ get_best_default_route, get_best_default_route::route_has_gateway, Error, InterfaceAndGateway, Result, }; +use crate::debounce::BurstGuard; use std::{ ffi::c_void, - sync::{ - mpsc::{channel, RecvTimeoutError, Sender}, - Arc, Mutex, - }, - time::{Duration, Instant}, + sync::{Arc, Mutex}, }; use talpid_types::win32_err; use windows_sys::Win32::{ @@ -125,10 +122,8 @@ impl Drop for DefaultRouteMonitor { let context = unsafe { Box::from_raw(self.context as *mut ContextAndBurstGuard) }; // Stop the burst guard - context.burst_guard.lock().unwrap().stop(); - - // Drop the context now that we are guaranteed nothing might try to access the context - drop(context); + let context = context.burst_guard.into_inner().unwrap(); + context.stop(); } } @@ -350,86 +345,3 @@ unsafe extern "system" fn ip_address_change_callback( context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex); context_and_burst.burst_guard.lock().unwrap().trigger(); } - -/// BurstGuard is a wrapper for a function that protects that function from being called too many -/// times in a short amount of time. To call the function use `burst_guard.trigger()`, at that point -/// `BurstGuard` will wait for `buffer_period` and if no more calls to `trigger` are made then it -/// will call the wrapped function. If another call to `trigger` is made during this wait then it -/// will wait another `buffer_period`, this happens over and over until either -/// `longest_buffer_period` time has elapsed or until no call to `trigger` has been made in -/// `buffer_period`. At which point the wrapped function will be called. -struct BurstGuard { - sender: Sender<BurstGuardEvent>, -} - -enum BurstGuardEvent { - Trigger, - Shutdown(Sender<()>), -} - -impl BurstGuard { - fn new<F: Fn() + Send + 'static>(callback: F) -> Self { - /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent - /// before it calls the callback. - const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); - /// This is the longest period that the `BurstGuard` will wait from the first trigger till - /// it calls the callback. - const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); - - let (sender, listener) = channel(); - std::thread::spawn(move || { - // The `stop` implementation assumes that this thread will not call `callback` again - // if the listener has been dropped. - while let Ok(message) = listener.recv() { - match message { - BurstGuardEvent::Trigger => { - let start = Instant::now(); - loop { - match listener.recv_timeout(BURST_BUFFER_PERIOD) { - Ok(BurstGuardEvent::Trigger) => { - if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD { - callback(); - break; - } - } - Ok(BurstGuardEvent::Shutdown(tx)) => { - let _ = tx.send(()); - return; - } - Err(RecvTimeoutError::Timeout) => { - callback(); - break; - } - Err(RecvTimeoutError::Disconnected) => { - break; - } - } - } - } - BurstGuardEvent::Shutdown(tx) => { - let _ = tx.send(()); - return; - } - } - } - }); - Self { sender } - } - - /// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further - /// calls to `callback`. - fn stop(&self) { - let (sender, listener) = channel(); - // If we could not send then it means the thread has already shut down and we can return - if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() { - // We do not care what the result is, if it is OK it means the thread shut down, if - // it is Err it also means it shut down. - let _ = listener.recv(); - } - } - - /// Non-blocking - fn trigger(&self) { - self.sender.send(BurstGuardEvent::Trigger).unwrap(); - } -} |
