diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-03-11 13:35:29 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-03-11 13:35:29 +0100 |
| commit | 579945fc6602d4127ee750eaab97f7582ed06b8c (patch) | |
| tree | 0b0cca7d377db0749682079dbc6b9b82f4e53569 | |
| parent | 1fa40d1ea735c3639552c71e3e980cc99707cf27 (diff) | |
| parent | 9e0f9a1e3fc2559f0820e1c9d2ed4450ab2af02b (diff) | |
| download | mullvadvpn-579945fc6602d4127ee750eaab97f7582ed06b8c.tar.xz mullvadvpn-579945fc6602d4127ee750eaab97f7582ed06b8c.zip | |
Merge branch 'routing-improvements' into main
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 4 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 82 | ||||
| -rw-r--r-- | talpid-routing/src/windows/mod.rs | 77 |
3 files changed, 64 insertions, 99 deletions
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 26a03328a4..75f6cc1ced 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -409,7 +409,9 @@ impl TunnelStateMachine { } } - log::debug!("Exiting tunnel state machine loop"); + log::debug!("Tunnel state machine exited"); + + runtime.block_on(self.shared_values.route_manager.stop()); } } diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 45b7b92b60..768000c010 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -7,7 +7,7 @@ use futures::channel::{ mpsc::{self, UnboundedSender}, oneshot, }; -use std::{collections::HashSet, io, sync::Arc}; +use std::{collections::HashSet, sync::Arc}; #[cfg(any(target_os = "linux", target_os = "macos"))] use futures::stream::Stream; @@ -41,12 +41,6 @@ pub enum Error { /// Platform specific error occurred #[error("Internal route manager error")] PlatformError(#[from] imp::Error), - /// Failed to spawn route manager future - #[error("Failed to spawn route manager on the provided executor")] - FailedToSpawnManager, - /// Failed to spawn route manager runtime - #[error("Failed to spawn route manager runtime")] - FailedToSpawnRuntime(#[from] io::Error), /// Attempt to use route manager that has been dropped #[error("Cannot send message to route manager since it is down")] RouteManagerDown, @@ -257,7 +251,6 @@ pub enum CallbackMessage { /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { manage_tx: Option<Arc<UnboundedSender<RouteManagerCommand>>>, - runtime: tokio::runtime::Handle, } impl RouteManager { @@ -280,7 +273,6 @@ impl RouteManager { tokio::spawn(manager.run(manage_rx)); Ok(Self { - runtime: tokio::runtime::Handle::current(), manage_tx: Some(manage_tx), }) } @@ -289,77 +281,59 @@ impl RouteManager { pub async fn stop(&mut self) { if let Some(tx) = self.manage_tx.take() { let (wait_tx, wait_rx) = oneshot::channel(); - - if tx - .unbounded_send(RouteManagerCommand::Shutdown(wait_tx)) - .is_err() - { - log::error!("RouteManager already down!"); - return; - } - - if wait_rx.await.is_err() { - log::error!("{}", Error::ManagerChannelDown); - } + let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(wait_tx)); + let _ = wait_rx.await; } } /// Applies the given routes until [`RouteManager::stop`] is called. - pub async fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { - if let Some(tx) = &self.manage_tx { - let (result_tx, result_rx) = oneshot::channel(); - if tx - .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) - .is_err() - { - return Err(Error::RouteManagerDown); - } + pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { + let tx = self.get_command_tx()?; + let (result_tx, result_rx) = oneshot::channel(); + tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .map_err(|_| Error::RouteManagerDown)?; - result_rx - .await - .map_err(|_| Error::ManagerChannelDown)? - .map_err(Error::PlatformError) - } else { - Err(Error::RouteManagerDown) - } + result_rx + .await + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) } /// Removes all routes previously applied in [`RouteManager::add_routes`]. - pub fn clear_routes(&mut self) -> Result<(), Error> { - if let Some(tx) = &self.manage_tx { - if tx.unbounded_send(RouteManagerCommand::ClearRoutes).is_err() { - return Err(Error::RouteManagerDown); - } - Ok(()) - } else { - Err(Error::RouteManagerDown) - } + pub fn clear_routes(&self) -> Result<(), Error> { + let tx = self.get_command_tx()?; + tx.unbounded_send(RouteManagerCommand::ClearRoutes) + .map_err(|_| Error::RouteManagerDown) } /// Ensure that packets are routed using the correct tables. #[cfg(target_os = "linux")] - pub async fn create_routing_rules(&mut self, enable_ipv6: bool) -> Result<(), Error> { + pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { self.handle()?.create_routing_rules(enable_ipv6).await } /// Remove any routing rules created by [Self::create_routing_rules]. #[cfg(target_os = "linux")] - pub async fn clear_routing_rules(&mut self) -> Result<(), Error> { + pub async fn clear_routing_rules(&self) -> Result<(), Error> { self.handle()?.clear_routing_rules().await } /// Retrieve a sender directly to the command channel. pub fn handle(&self) -> Result<RouteManagerHandle, Error> { - if let Some(tx) = &self.manage_tx { - Ok(RouteManagerHandle { tx: tx.clone() }) - } else { - Err(Error::RouteManagerDown) - } + let tx = self.get_command_tx()?; + Ok(RouteManagerHandle { tx: tx.clone() }) + } + + fn get_command_tx(&self) -> Result<&Arc<UnboundedSender<RouteManagerCommand>>, Error> { + self.manage_tx.as_ref().ok_or(Error::RouteManagerDown) } } impl Drop for RouteManager { fn drop(&mut self) { - self.runtime.clone().block_on(self.stop()); + if let Some(tx) = self.manage_tx.take() { + let (done_tx, _) = oneshot::channel(); + let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx)); + } } } diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs index 055b4c2b68..c03beea8cf 100644 --- a/talpid-routing/src/windows/mod.rs +++ b/talpid-routing/src/windows/mod.rs @@ -151,7 +151,7 @@ pub enum RouteManagerCommand { GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>), ClearRoutes, RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>), - Shutdown, + Shutdown(oneshot::Sender<Result<()>>), } impl RouteManager { @@ -176,29 +176,19 @@ impl RouteManager { &self, callback: Callback, ) -> Result<CallbackHandle> { - if let Some(tx) = &self.manage_tx { - let (result_tx, result_rx) = oneshot::channel(); - if tx - .unbounded_send(RouteManagerCommand::RegisterDefaultRouteChangeCallback( - callback, result_tx, - )) - .is_err() - { - return Err(Error::RouteManagerDown); - } - Ok(result_rx.await.map_err(|_| Error::ManagerChannelDown)?) - } else { - Err(Error::RouteManagerDown) - } + let tx = self.get_command_tx()?; + let (result_tx, result_rx) = oneshot::channel(); + tx.unbounded_send(RouteManagerCommand::RegisterDefaultRouteChangeCallback( + callback, result_tx, + )) + .map_err(|_| Error::RouteManagerDown)?; + result_rx.await.map_err(|_| Error::ManagerChannelDown) } /// Retrieve a sender directly to the command channel. pub fn handle(&self) -> Result<RouteManagerHandle> { - if let Some(tx) = &self.manage_tx { - Ok(RouteManagerHandle { tx: tx.clone() }) - } else { - Err(Error::RouteManagerDown) - } + let tx = self.get_command_tx()?; + Ok(RouteManagerHandle { tx: tx.clone() }) } async fn listen( @@ -243,7 +233,9 @@ impl RouteManager { RouteManagerCommand::RegisterDefaultRouteChangeCallback(callback, tx) => { let _ = tx.send(internal.register_default_route_changed_callback(callback)); } - RouteManagerCommand::Shutdown => { + RouteManagerCommand::Shutdown(tx) => { + drop(internal); + let _ = tx.send(Ok(())); break; } } @@ -252,38 +244,32 @@ impl RouteManager { /// Stops the routing manager and invalidates the route manager - no new default route callbacks /// can be added - pub fn stop(&mut self) { + pub async fn stop(&mut self) { if let Some(tx) = self.manage_tx.take() { - if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() { - log::error!("RouteManager channel already down or thread panicked"); - } + let (result_tx, result_rx) = oneshot::channel(); + let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(result_tx)); + _ = result_rx.await; } } /// Applies the given routes until [`RouteManager::stop`] is called. pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - if let Some(tx) = &self.manage_tx { - let (result_tx, result_rx) = oneshot::channel(); - if tx - .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) - .is_err() - { - return Err(Error::RouteManagerDown); - } - result_rx.await.map_err(|_| Error::ManagerChannelDown)? - } else { - Err(Error::RouteManagerDown) - } + let tx = self.get_command_tx()?; + let (result_tx, result_rx) = oneshot::channel(); + tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .map_err(|_| Error::RouteManagerDown)?; + result_rx.await.map_err(|_| Error::ManagerChannelDown)? } /// Removes all routes previously applied in [`RouteManager::add_routes`]. pub fn clear_routes(&self) -> Result<()> { - if let Some(tx) = &self.manage_tx { - tx.unbounded_send(RouteManagerCommand::ClearRoutes) - .map_err(|_| Error::RouteManagerDown) - } else { - Err(Error::RouteManagerDown) - } + let tx = self.get_command_tx()?; + tx.unbounded_send(RouteManagerCommand::ClearRoutes) + .map_err(|_| Error::RouteManagerDown) + } + + fn get_command_tx(&self) -> Result<&UnboundedSender<RouteManagerCommand>> { + self.manage_tx.as_ref().ok_or(Error::RouteManagerDown) } } @@ -309,6 +295,9 @@ fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> { impl Drop for RouteManager { fn drop(&mut self) { - self.stop(); + if let Some(tx) = self.manage_tx.take() { + let (done_tx, _) = oneshot::channel(); + let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx)); + } } } |
