diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-03-11 14:24:48 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-03-14 13:01:06 +0100 |
| commit | 69950ff6070cacc3ba5ad3a445ccd4c7c05ab180 (patch) | |
| tree | cca20e246852e341aa2b356202b4e1d45e74fe93 | |
| parent | 4c7327dd7e6bdd59b1086f96fc042cca34c32581 (diff) | |
| download | mullvadvpn-69950ff6070cacc3ba5ad3a445ccd4c7c05ab180.tar.xz mullvadvpn-69950ff6070cacc3ba5ad3a445ccd4c7c05ab180.zip | |
Simplify route manager handle
| -rw-r--r-- | talpid-core/src/offline/macos.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/offline/windows.rs | 13 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 31 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 16 | ||||
| -rw-r--r-- | talpid-openvpn/src/lib.rs | 6 | ||||
| -rw-r--r-- | talpid-routing/src/lib.rs | 6 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 262 | ||||
| -rw-r--r-- | talpid-routing/src/windows/mod.rs | 132 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 2 |
9 files changed, 170 insertions, 304 deletions
diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs index 2813dc4211..daafecb052 100644 --- a/talpid-core/src/offline/macos.rs +++ b/talpid-core/src/offline/macos.rs @@ -65,14 +65,14 @@ impl ConnectivityInner { pub async fn spawn_monitor( notify_tx: UnboundedSender<Connectivity>, - route_manager_handle: RouteManagerHandle, + route_manager: RouteManagerHandle, ) -> Result<MonitorHandle, Error> { let notify_tx = Arc::new(notify_tx); // note: begin observing before initializing the state - let route_listener = route_manager_handle.default_route_listener().await?; + let route_listener = route_manager.default_route_listener().await?; - let (ipv4, ipv6) = match route_manager_handle.get_default_routes().await { + let (ipv4, ipv6) = match route_manager.get_default_routes().await { Ok((v4_route, v6_route)) => (v4_route.is_some(), v6_route.is_some()), Err(error) => { log::warn!("Failed to initialize offline monitor: {error}"); diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs index f47fe8dd4d..5e09763cd0 100644 --- a/talpid-core/src/offline/windows.rs +++ b/talpid-core/src/offline/windows.rs @@ -29,7 +29,7 @@ unsafe impl Send for BroadcastListener {} impl BroadcastListener { pub async fn start( notify_tx: UnboundedSender<Connectivity>, - route_manager_handle: RouteManagerHandle, + route_manager: RouteManagerHandle, mut power_mgmt_rx: PowerManagementListener, ) -> Result<Self, Error> { let notify_tx = Arc::new(notify_tx); @@ -67,8 +67,7 @@ impl BroadcastListener { }); let callback_handle = - Self::setup_network_connectivity_listener(system_state.clone(), route_manager_handle) - .await?; + Self::setup_network_connectivity_listener(system_state.clone(), route_manager).await?; Ok(BroadcastListener { system_state, @@ -107,9 +106,9 @@ impl BroadcastListener { /// until after `WinNet_DeactivateConnectivityMonitor` has been called. async fn setup_network_connectivity_listener( system_state: Arc<Mutex<SystemState>>, - route_manager_handle: RouteManagerHandle, + route_manager: RouteManagerHandle, ) -> Result<CallbackHandle, Error> { - let change_handle = route_manager_handle + let change_handle = route_manager .add_default_route_change_callback(Box::new(move |event, addr_family| { Self::connectivity_callback(event, addr_family, &system_state) })) @@ -202,10 +201,10 @@ pub type MonitorHandle = BroadcastListener; pub async fn spawn_monitor( sender: UnboundedSender<Connectivity>, - route_manager_handle: RouteManagerHandle, + route_manager: RouteManagerHandle, ) -> Result<MonitorHandle, Error> { let power_mgmt_rx = crate::window::PowerManagementListener::new(); - BroadcastListener::start(sender, route_manager_handle, power_mgmt_rx).await + BroadcastListener::start(sender, route_manager, power_mgmt_rx).await } fn apply_system_state_change(state: Arc<Mutex<SystemState>>, change: StateChange) { diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 927de207bf..71a88e64fb 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -18,7 +18,7 @@ use std::{ thread, time::{Duration, Instant}, }; -use talpid_routing::RouteManager; +use talpid_routing::RouteManagerHandle; use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use talpid_types::{ net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters}, @@ -60,9 +60,9 @@ impl ConnectingState { if shared_values.connectivity.is_offline() { // FIXME: Temporary: Nudge route manager to update the default interface #[cfg(target_os = "macos")] - if let Ok(handle) = shared_values.route_manager.handle() { + { log::debug!("Poking route manager to update default routes"); - let _ = handle.refresh_routes(); + let _ = shared_values.route_manager.refresh_routes(); } return ErrorState::enter(shared_values, ErrorStateCause::IsOffline); } @@ -189,7 +189,7 @@ impl ConnectingState { log_dir: &Option<PathBuf>, resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: &RouteManager, + route_manager: &RouteManagerHandle, retry_attempt: u32, ) -> Self { let (event_tx, event_rx) = mpsc::unbounded(); @@ -202,7 +202,7 @@ impl ConnectingState { }) }; - let route_manager_handle = route_manager.handle(); + let route_manager = route_manager.clone(); let log_dir = log_dir.clone(); let resource_dir = resource_dir.to_path_buf(); @@ -214,25 +214,6 @@ impl ConnectingState { tokio::task::spawn_blocking(move || { let start = Instant::now(); - let route_manager_handle = match route_manager_handle { - Ok(handle) => handle, - Err(error) => { - if tunnel_close_event_tx - .send(Some(ErrorStateCause::StartTunnelError)) - .is_err() - { - log::warn!( - "Tunnel state machine stopped before receiving tunnel closed event" - ); - } - log::error!( - "{}", - error.display_chain_with_msg("Failed to obtain route monitor handle") - ); - return; - } - }; - let args = TunnelArgs { runtime, resource_dir: &resource_dir, @@ -240,7 +221,7 @@ impl ConnectingState { tunnel_close_rx, tun_provider, retry_attempt, - route_manager: route_manager_handle, + route_manager, }; let block_reason = match TunnelMonitor::start(&mut tunnel_parameters, &log_dir, args) { diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 75f6cc1ced..bee32bb31d 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -21,7 +21,7 @@ use crate::{ }; #[cfg(windows)] use std::ffi::OsString; -use talpid_routing::RouteManager; +use talpid_routing::RouteManagerHandle; use talpid_tunnel::{tun_provider::TunProvider, TunnelEvent}; use futures::{ @@ -269,7 +269,7 @@ impl TunnelStateMachine { #[cfg(target_os = "macos")] let filtering_resolver = crate::resolver::start_resolver().await?; - let route_manager = RouteManager::new( + let route_manager = RouteManagerHandle::spawn( #[cfg(target_os = "linux")] args.linux_ids.fwmark, #[cfg(target_os = "linux")] @@ -284,9 +284,7 @@ impl TunnelStateMachine { args.resource_dir.clone(), args.command_tx.clone(), volume_update_rx, - route_manager - .handle() - .map_err(Error::InitRouteManagerError)?, + route_manager.clone(), ) .map_err(Error::InitSplitTunneling)?; @@ -308,9 +306,7 @@ impl TunnelStateMachine { #[cfg(target_os = "linux")] runtime.clone(), #[cfg(target_os = "linux")] - route_manager - .handle() - .map_err(Error::InitRouteManagerError)?, + route_manager.clone(), #[cfg(target_os = "macos")] args.command_tx.clone(), ) @@ -331,7 +327,7 @@ impl TunnelStateMachine { let offline_monitor = offline::spawn_monitor( offline_tx, #[cfg(not(target_os = "android"))] - route_manager.handle()?, + route_manager.clone(), #[cfg(target_os = "linux")] Some(args.linux_ids.fwmark), #[cfg(target_os = "android")] @@ -436,7 +432,7 @@ struct SharedTunnelStateValues { runtime: tokio::runtime::Handle, firewall: Firewall, dns_monitor: DnsMonitor, - route_manager: RouteManager, + route_manager: RouteManagerHandle, _offline_monitor: offline::MonitorHandle, /// Should LAN access be allowed outside the tunnel. allow_lan: bool, diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 7426228812..da30c1940f 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -312,7 +312,7 @@ impl OpenVpnMonitor<OpenVpnCommand> { proxy_auth_file_path: proxy_auth_file_path.clone(), abort_server_tx: event_server_abort_tx, proxy: params.proxy.clone(), - route_manager_handle: route_manager, + route_manager, #[cfg(target_os = "linux")] ipv6_enabled, }, @@ -817,7 +817,7 @@ mod event_server { pub proxy_auth_file_path: Option<super::PathBuf>, pub abort_server_tx: triggered::Trigger, pub proxy: Option<CustomProxy>, - pub route_manager_handle: talpid_routing::RouteManagerHandle, + pub route_manager: talpid_routing::RouteManagerHandle, #[cfg(target_os = "linux")] pub ipv6_enabled: bool, } @@ -864,7 +864,7 @@ mod event_server { let route = talpid_routing::RequiredRoute::new(network, node); routes.insert(route); } - let route_handle = self.route_manager_handle.clone(); + let route_handle = self.route_manager.clone(); #[cfg(target_os = "linux")] { diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index d525b12435..f15489fcb6 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -25,7 +25,7 @@ use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; #[cfg(target_os = "macos")] pub use imp::{imp::RouteError, DefaultRouteEvent, PlatformError}; -pub use imp::{Error, RouteManager, RouteManagerHandle}; +pub use imp::{Error, RouteManagerHandle}; /// A network route with a specific network node, destination and an optional metric. #[derive(Debug, Hash, Eq, PartialEq, Clone)] @@ -81,7 +81,7 @@ impl fmt::Display for Route { } } -/// A network route that should be applied by the RouteManager. +/// A network route that should be applied by the route manager. /// It can either be routed through a specific network node or it can be routed through the current /// default route. #[derive(Debug, Hash, Eq, PartialEq, Clone)] @@ -130,7 +130,7 @@ impl RequiredRoute { #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub enum NetNode { /// A real node will be used to set a regular route that will remain unchanged for the lifetime - /// of the RouteManager + /// of the route manager RealNode(Node), /// A default node is a symbolic node that will resolve to the network node used in the current /// most preferable default route diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 768000c010..7fe7e9bf31 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -32,7 +32,7 @@ mod imp; pub use imp::Error as PlatformError; -/// Errors that can be encountered whilst initializing RouteManager +/// Errors that can be encountered whilst initializing route manager #[derive(thiserror::Error, Debug)] pub enum Error { /// Route manager thread may have panicked @@ -65,25 +65,123 @@ impl Error { } } -/// Handle to a route manager. -#[derive(Clone)] +/// Represents a firewall mark. +#[cfg(target_os = "linux")] +type Fwmark = u32; + +/// Commands for the underlying route manager object. +#[derive(Debug)] +pub(crate) enum RouteManagerCommand { + AddRoutes( + HashSet<RequiredRoute>, + oneshot::Sender<Result<(), PlatformError>>, + ), + ClearRoutes, + Shutdown(oneshot::Sender<()>), + #[cfg(target_os = "macos")] + RefreshRoutes, + #[cfg(target_os = "macos")] + NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>), + #[cfg(target_os = "macos")] + GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>), + #[cfg(target_os = "linux")] + CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>), + #[cfg(target_os = "linux")] + ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>), + #[cfg(target_os = "linux")] + NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>), + #[cfg(target_os = "linux")] + GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>), + /// Attempt to fetch a route for the given destination with an optional firewall mark. + #[cfg(target_os = "linux")] + GetDestinationRoute( + IpAddr, + Option<Fwmark>, + oneshot::Sender<Result<Option<Route>, PlatformError>>, + ), +} + +/// Event that is sent when a preferred non-tunnel default route is +/// added or removed. +#[cfg(target_os = "macos")] +#[derive(Debug, Clone, Copy)] +pub enum DefaultRouteEvent { + /// Added or updated a non-tunnel default IPv4 route + AddedOrChangedV4, + /// Added or updated a non-tunnel default IPv6 route + AddedOrChangedV6, + /// Non-tunnel default IPv4 route was removed + RemovedV4, + /// Non-tunnel default IPv6 route was removed + RemovedV6, +} + +#[cfg(target_os = "linux")] +#[derive(Debug, Clone)] +pub enum CallbackMessage { + NewRoute(Route), + DelRoute(Route), +} + +/// Route manager applies a set of routes to the route table. +/// If a destination has to be routed through the default node, +/// the route will be adjusted dynamically when the default route changes. +#[derive(Debug, Clone)] pub struct RouteManagerHandle { tx: Arc<UnboundedSender<RouteManagerCommand>>, } impl RouteManagerHandle { - /// Applies the given routes while the route manager is running. + /// Construct a route manager. + pub async fn spawn( + #[cfg(target_os = "linux")] fwmark: u32, + #[cfg(target_os = "linux")] table_id: u32, + ) -> Result<Self, Error> { + let (manage_tx, manage_rx) = mpsc::unbounded(); + let manage_tx = Arc::new(manage_tx); + let manager = imp::RouteManagerImpl::new( + #[cfg(target_os = "linux")] + fwmark, + #[cfg(target_os = "linux")] + table_id, + #[cfg(target_os = "macos")] + Arc::downgrade(&manage_tx), + ) + .await?; + tokio::spawn(manager.run(manage_rx)); + + Ok(Self { tx: manage_tx }) + } + + /// Stop route manager and revert all changes to routing + pub async fn stop(&self) { + let (wait_tx, wait_rx) = oneshot::channel(); + let _ = self + .tx + .unbounded_send(RouteManagerCommand::Shutdown(wait_tx)); + let _ = wait_rx.await; + } + + /// Applies the given routes until they are cleared pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { - let (response_tx, response_rx) = oneshot::channel(); + let (result_tx, result_rx) = oneshot::channel(); self.tx - .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) + .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) .map_err(|_| Error::RouteManagerDown)?; - response_rx + + result_rx .await .map_err(|_| Error::ManagerChannelDown)? .map_err(Error::PlatformError) } + /// Removes all routes previously applied in [`RouteManager::add_routes`]. + pub fn clear_routes(&self) -> Result<(), Error> { + self.tx + .unbounded_send(RouteManagerCommand::ClearRoutes) + .map_err(|_| Error::RouteManagerDown) + } + /// Listen for non-tunnel default route changes. #[cfg(target_os = "macos")] pub async fn default_route_listener( @@ -187,153 +285,3 @@ impl RouteManagerHandle { .map_err(Error::PlatformError) } } - -/// Represents a firewall mark. -#[cfg(target_os = "linux")] -type Fwmark = u32; - -/// Commands for the underlying route manager object. -#[derive(Debug)] -pub(crate) enum RouteManagerCommand { - AddRoutes( - HashSet<RequiredRoute>, - oneshot::Sender<Result<(), PlatformError>>, - ), - ClearRoutes, - Shutdown(oneshot::Sender<()>), - #[cfg(target_os = "macos")] - RefreshRoutes, - #[cfg(target_os = "macos")] - NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>), - #[cfg(target_os = "macos")] - GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>), - #[cfg(target_os = "linux")] - CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>), - #[cfg(target_os = "linux")] - ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>), - #[cfg(target_os = "linux")] - NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>), - #[cfg(target_os = "linux")] - GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>), - /// Attempt to fetch a route for the given destination with an optional firewall mark. - #[cfg(target_os = "linux")] - GetDestinationRoute( - IpAddr, - Option<Fwmark>, - oneshot::Sender<Result<Option<Route>, PlatformError>>, - ), -} - -/// Event that is sent when a preferred non-tunnel default route is -/// added or removed. -#[cfg(target_os = "macos")] -#[derive(Debug, Clone, Copy)] -pub enum DefaultRouteEvent { - /// Added or updated a non-tunnel default IPv4 route - AddedOrChangedV4, - /// Added or updated a non-tunnel default IPv6 route - AddedOrChangedV6, - /// Non-tunnel default IPv4 route was removed - RemovedV4, - /// Non-tunnel default IPv6 route was removed - RemovedV6, -} - -#[cfg(target_os = "linux")] -#[derive(Debug, Clone)] -pub enum CallbackMessage { - NewRoute(Route), - DelRoute(Route), -} - -/// RouteManager applies a set of routes to the route table. -/// If a destination has to be routed through the default node, -/// the route will be adjusted dynamically when the default route changes. -pub struct RouteManager { - manage_tx: Option<Arc<UnboundedSender<RouteManagerCommand>>>, -} - -impl RouteManager { - /// Construct a RouteManager. - pub async fn new( - #[cfg(target_os = "linux")] fwmark: u32, - #[cfg(target_os = "linux")] table_id: u32, - ) -> Result<Self, Error> { - let (manage_tx, manage_rx) = mpsc::unbounded(); - let manage_tx = Arc::new(manage_tx); - let manager = imp::RouteManagerImpl::new( - #[cfg(target_os = "linux")] - fwmark, - #[cfg(target_os = "linux")] - table_id, - #[cfg(target_os = "macos")] - Arc::downgrade(&manage_tx), - ) - .await?; - tokio::spawn(manager.run(manage_rx)); - - Ok(Self { - manage_tx: Some(manage_tx), - }) - } - - /// Stops RouteManager and removes all of the applied routes. - pub async fn stop(&mut self) { - if let Some(tx) = self.manage_tx.take() { - let (wait_tx, wait_rx) = oneshot::channel(); - let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(wait_tx)); - let _ = wait_rx.await; - } - } - - /// Applies the given routes until [`RouteManager::stop`] is called. - pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { - let tx = self.get_command_tx()?; - let (result_tx, result_rx) = oneshot::channel(); - tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) - .map_err(|_| Error::RouteManagerDown)?; - - result_rx - .await - .map_err(|_| Error::ManagerChannelDown)? - .map_err(Error::PlatformError) - } - - /// Removes all routes previously applied in [`RouteManager::add_routes`]. - pub fn clear_routes(&self) -> Result<(), Error> { - let tx = self.get_command_tx()?; - tx.unbounded_send(RouteManagerCommand::ClearRoutes) - .map_err(|_| Error::RouteManagerDown) - } - - /// Ensure that packets are routed using the correct tables. - #[cfg(target_os = "linux")] - pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { - self.handle()?.create_routing_rules(enable_ipv6).await - } - - /// Remove any routing rules created by [Self::create_routing_rules]. - #[cfg(target_os = "linux")] - pub async fn clear_routing_rules(&self) -> Result<(), Error> { - self.handle()?.clear_routing_rules().await - } - - /// Retrieve a sender directly to the command channel. - pub fn handle(&self) -> Result<RouteManagerHandle, Error> { - let tx = self.get_command_tx()?; - Ok(RouteManagerHandle { tx: tx.clone() }) - } - - fn get_command_tx(&self) -> Result<&Arc<UnboundedSender<RouteManagerCommand>>, Error> { - self.manage_tx.as_ref().ok_or(Error::RouteManagerDown) - } -} - -impl Drop for RouteManager { - fn drop(&mut self) { - if let Some(tx) = self.manage_tx.take() { - let (done_tx, _) = oneshot::channel(); - let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx)); - } - } -} diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs index c03beea8cf..c158938ebd 100644 --- a/talpid-routing/src/windows/mod.rs +++ b/talpid-routing/src/windows/mod.rs @@ -21,9 +21,6 @@ mod route_manager; /// Windows routing errors. #[derive(thiserror::Error, Debug)] pub enum Error { - /// The sender was dropped unexpectedly -- possible panic - #[error("The channel sender was dropped")] - ManagerChannelDown, /// Failure to initialize route manager #[error("Failed to start route manager")] FailedToStartManager, @@ -101,17 +98,31 @@ impl Error { pub type Result<T> = std::result::Result<T, Error>; /// Manages routes by calling into WinNet -pub struct RouteManager { - manage_tx: Option<UnboundedSender<RouteManagerCommand>>, -} - -/// Handle to a route manager. -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct RouteManagerHandle { tx: UnboundedSender<RouteManagerCommand>, } +pub enum RouteManagerCommand { + AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), + GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>), + ClearRoutes, + RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>), + Shutdown(oneshot::Sender<()>), +} + impl RouteManagerHandle { + /// Create a new route manager + #[allow(clippy::unused_async)] + pub async fn spawn() -> Result<Self> { + let internal = RouteManagerInternal::new().map_err(|_| Error::FailedToStartManager)?; + let (tx, rx) = mpsc::unbounded(); + let handle = Self { tx }; + tokio::spawn(RouteManagerHandle::run(rx, internal)); + + Ok(handle) + } + /// Add a callback which will be called if the default route changes. pub async fn add_default_route_change_callback( &self, @@ -124,7 +135,7 @@ impl RouteManagerHandle { response_tx, )) .map_err(|_| Error::RouteManagerDown)?; - response_rx.await.map_err(|_| Error::ManagerChannelDown) + response_rx.await.map_err(|_| Error::RouteManagerDown) } /// Applies the given routes while the route manager is running. @@ -133,65 +144,35 @@ impl RouteManagerHandle { self.tx .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) .map_err(|_| Error::RouteManagerDown)?; - response_rx.await.map_err(|_| Error::ManagerChannelDown)? + response_rx.await.map_err(|_| Error::RouteManagerDown)? } - /// Applies the given routes while the route manager is running. + /// Retrieve MTU for the given destination/route. pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> { let (response_tx, response_rx) = oneshot::channel(); self.tx .unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx)) .map_err(|_| Error::RouteManagerDown)?; - response_rx.await.map_err(|_| Error::ManagerChannelDown)? - } -} - -pub enum RouteManagerCommand { - AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), - GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>), - ClearRoutes, - RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>), - Shutdown(oneshot::Sender<Result<()>>), -} - -impl RouteManager { - /// Create a new route manager - #[allow(clippy::unused_async)] - pub async fn new() -> Result<Self> { - let internal = match RouteManagerInternal::new() { - Ok(internal) => internal, - Err(_) => return Err(Error::FailedToStartManager), - }; - let (manage_tx, manage_rx) = mpsc::unbounded(); - let manager = Self { - manage_tx: Some(manage_tx), - }; - tokio::spawn(RouteManager::listen(manage_rx, internal)); - - Ok(manager) + response_rx.await.map_err(|_| Error::RouteManagerDown)? } - /// Add a callback which will be called if the default route changes. - pub async fn add_default_route_change_callback( - &self, - callback: Callback, - ) -> Result<CallbackHandle> { - let tx = self.get_command_tx()?; + /// Stop the routing manager actor and revert all changes to routing + pub async fn stop(&self) { let (result_tx, result_rx) = oneshot::channel(); - tx.unbounded_send(RouteManagerCommand::RegisterDefaultRouteChangeCallback( - callback, result_tx, - )) - .map_err(|_| Error::RouteManagerDown)?; - result_rx.await.map_err(|_| Error::ManagerChannelDown) + _ = self + .tx + .unbounded_send(RouteManagerCommand::Shutdown(result_tx)); + _ = result_rx.await; } - /// Retrieve a sender directly to the command channel. - pub fn handle(&self) -> Result<RouteManagerHandle> { - let tx = self.get_command_tx()?; - Ok(RouteManagerHandle { tx: tx.clone() }) + /// Removes all routes previously applied in [`RouteManager::add_routes`]. + pub fn clear_routes(&self) -> Result<()> { + self.tx + .unbounded_send(RouteManagerCommand::ClearRoutes) + .map_err(|_| Error::RouteManagerDown) } - async fn listen( + async fn run( mut manage_rx: UnboundedReceiver<RouteManagerCommand>, mut internal: RouteManagerInternal, ) { @@ -235,42 +216,12 @@ impl RouteManager { } RouteManagerCommand::Shutdown(tx) => { drop(internal); - let _ = tx.send(Ok(())); + let _ = tx.send(()); break; } } } } - - /// Stops the routing manager and invalidates the route manager - no new default route callbacks - /// can be added - pub async fn stop(&mut self) { - if let Some(tx) = self.manage_tx.take() { - let (result_tx, result_rx) = oneshot::channel(); - let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(result_tx)); - _ = result_rx.await; - } - } - - /// Applies the given routes until [`RouteManager::stop`] is called. - pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - let tx = self.get_command_tx()?; - let (result_tx, result_rx) = oneshot::channel(); - tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) - .map_err(|_| Error::RouteManagerDown)?; - result_rx.await.map_err(|_| Error::ManagerChannelDown)? - } - - /// Removes all routes previously applied in [`RouteManager::add_routes`]. - pub fn clear_routes(&self) -> Result<()> { - let tx = self.get_command_tx()?; - tx.unbounded_send(RouteManagerCommand::ClearRoutes) - .map_err(|_| Error::RouteManagerDown) - } - - fn get_command_tx(&self) -> Result<&UnboundedSender<RouteManagerCommand>> { - self.manage_tx.as_ref().ok_or(Error::RouteManagerDown) - } } fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> { @@ -292,12 +243,3 @@ fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> { } } } - -impl Drop for RouteManager { - fn drop(&mut self) { - if let Some(tx) = self.manage_tx.take() { - let (done_tx, _) = oneshot::channel(); - let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx)); - } - } -} diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 4f02bfcb09..72d6a31566 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -718,7 +718,7 @@ impl WireguardMonitor { resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, #[cfg(target_os = "android")] psk_negotiation: bool, - #[cfg(windows)] route_manager_handle: crate::routing::RouteManagerHandle, + #[cfg(windows)] route_manager: crate::routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { log::debug!("Tunnel MTU: {}", config.mtu); |
