diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-10-06 23:04:23 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-10-09 19:35:07 +0200 |
| commit | 2087800f46766df575ebd9acbbb912aeca8f7ac6 (patch) | |
| tree | 7515f8bb8bf097bbdcd18abc9fa9059066da4164 | |
| parent | 14ae580f8c06f783a9a9d4c0d0560f469716fdb1 (diff) | |
| download | mullvadvpn-2087800f46766df575ebd9acbbb912aeca8f7ac6.tar.xz mullvadvpn-2087800f46766df575ebd9acbbb912aeca8f7ac6.zip | |
Make BurstGuard configurable
| -rw-r--r-- | talpid-routing/src/debounce.rs | 46 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/mod.rs | 21 | ||||
| -rw-r--r-- | talpid-routing/src/windows/default_route_monitor.rs | 14 |
3 files changed, 57 insertions, 24 deletions
diff --git a/talpid-routing/src/debounce.rs b/talpid-routing/src/debounce.rs index ba1e52250c..2463b2c429 100644 --- a/talpid-routing/src/debounce.rs +++ b/talpid-routing/src/debounce.rs @@ -14,34 +14,39 @@ use std::{ /// `buffer_period`. At which point the wrapped function will be called. pub struct BurstGuard { sender: Sender<BurstGuardEvent>, + /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent + /// before it calls the callback. + buffer_period: Duration, + /// This is the longest period that the `BurstGuard` will wait from the first trigger till + /// it calls the callback. + longest_buffer_period: Duration, } enum BurstGuardEvent { - Trigger, + Trigger(Duration), Shutdown(Sender<()>), } impl BurstGuard { - pub fn new<F: Fn() + Send + 'static>(callback: F) -> Self { - /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent - /// before it calls the callback. - const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); - /// This is the longest period that the `BurstGuard` will wait from the first trigger till - /// it calls the callback. - const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); - + pub fn new<F: Fn() + Send + 'static>( + buffer_period: Duration, + longest_buffer_period: Duration, + callback: F, + ) -> Self { let (sender, listener) = channel(); std::thread::spawn(move || { // The `stop` implementation assumes that this thread will not call `callback` again // if the listener has been dropped. while let Ok(message) = listener.recv() { match message { - BurstGuardEvent::Trigger => { + BurstGuardEvent::Trigger(mut period) => { let start = Instant::now(); loop { - match listener.recv_timeout(BURST_BUFFER_PERIOD) { - Ok(BurstGuardEvent::Trigger) => { - if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD { + match listener.recv_timeout(period) { + Ok(BurstGuardEvent::Trigger(new_period)) => { + period = new_period; + let max_period = std::cmp::max(longest_buffer_period, period); + if start.elapsed() >= max_period { callback(); break; } @@ -67,7 +72,11 @@ impl BurstGuard { } } }); - Self { sender } + Self { + sender, + buffer_period, + longest_buffer_period, + } } /// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further @@ -90,6 +99,13 @@ impl BurstGuard { /// Asynchronously trigger burst pub fn trigger(&self) { - self.sender.send(BurstGuardEvent::Trigger).unwrap(); + self.trigger_with_period(self.buffer_period) + } + + /// Asynchronously trigger burst + pub fn trigger_with_period(&self, buffer_period: Duration) { + self.sender + .send(BurstGuardEvent::Trigger(buffer_period)) + .unwrap(); } } diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs index 8cca3594f9..24923ea9af 100644 --- a/talpid-routing/src/unix/macos/mod.rs +++ b/talpid-routing/src/unix/macos/mod.rs @@ -25,6 +25,9 @@ mod watch; pub type Result<T> = std::result::Result<T, Error>; +const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); +const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); + /// Errors that can happen in the macOS routing integration. #[derive(err_derive::Error, Debug)] #[error(no_from)] @@ -93,12 +96,18 @@ impl RouteManagerImpl { manage_tx: Weak<mpsc::UnboundedSender<RouteManagerCommand>>, ) -> Result<Self> { let routing_table = RoutingTable::new().map_err(Error::RoutingTable)?; - let update_trigger = BurstGuard::new(move || { - let Some(manage_tx) = manage_tx.upgrade() else { - return; - }; - let _ = manage_tx.unbounded_send(RouteManagerCommand::RefreshRoutes); - }); + + let update_trigger = BurstGuard::new( + BURST_BUFFER_PERIOD, + BURST_LONGEST_BUFFER_PERIOD, + move || { + let Some(manage_tx) = manage_tx.upgrade() else { + return; + }; + let _ = manage_tx.unbounded_send(RouteManagerCommand::RefreshRoutes); + }, + ); + Ok(Self { routing_table, non_tunnel_routes: HashSet::new(), diff --git a/talpid-routing/src/windows/default_route_monitor.rs b/talpid-routing/src/windows/default_route_monitor.rs index d42dbc91de..0f7d64e3a8 100644 --- a/talpid-routing/src/windows/default_route_monitor.rs +++ b/talpid-routing/src/windows/default_route_monitor.rs @@ -7,6 +7,7 @@ use crate::debounce::BurstGuard; use std::{ ffi::c_void, sync::{Arc, Mutex}, + time::Duration, }; use talpid_types::win32_err; use windows_sys::Win32::{ @@ -173,10 +174,17 @@ impl DefaultRouteMonitor { family, ))); + const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); + const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); + let moved_context = context.clone(); - let burst_guard = Mutex::new(BurstGuard::new(move || { - moved_context.lock().unwrap().evaluate_routes(); - })); + let burst_guard = Mutex::new(BurstGuard::new( + BURST_BUFFER_PERIOD, + BURST_LONGEST_BUFFER_PERIOD, + move || { + moved_context.lock().unwrap().evaluate_routes(); + }, + )); // SAFETY: We need to send the ContextAndBurstGuard to the windows notification functions as // a raw pointer. This imposes the requirement it is not mutated or dropped until |
