summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-03-11 14:24:48 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-03-14 13:01:06 +0100
commit69950ff6070cacc3ba5ad3a445ccd4c7c05ab180 (patch)
treecca20e246852e341aa2b356202b4e1d45e74fe93
parent4c7327dd7e6bdd59b1086f96fc042cca34c32581 (diff)
downloadmullvadvpn-69950ff6070cacc3ba5ad3a445ccd4c7c05ab180.tar.xz
mullvadvpn-69950ff6070cacc3ba5ad3a445ccd4c7c05ab180.zip
Simplify route manager handle
-rw-r--r--talpid-core/src/offline/macos.rs6
-rw-r--r--talpid-core/src/offline/windows.rs13
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs31
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs16
-rw-r--r--talpid-openvpn/src/lib.rs6
-rw-r--r--talpid-routing/src/lib.rs6
-rw-r--r--talpid-routing/src/unix/mod.rs262
-rw-r--r--talpid-routing/src/windows/mod.rs132
-rw-r--r--talpid-wireguard/src/lib.rs2
9 files changed, 170 insertions, 304 deletions
diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs
index 2813dc4211..daafecb052 100644
--- a/talpid-core/src/offline/macos.rs
+++ b/talpid-core/src/offline/macos.rs
@@ -65,14 +65,14 @@ impl ConnectivityInner {
pub async fn spawn_monitor(
notify_tx: UnboundedSender<Connectivity>,
- route_manager_handle: RouteManagerHandle,
+ route_manager: RouteManagerHandle,
) -> Result<MonitorHandle, Error> {
let notify_tx = Arc::new(notify_tx);
// note: begin observing before initializing the state
- let route_listener = route_manager_handle.default_route_listener().await?;
+ let route_listener = route_manager.default_route_listener().await?;
- let (ipv4, ipv6) = match route_manager_handle.get_default_routes().await {
+ let (ipv4, ipv6) = match route_manager.get_default_routes().await {
Ok((v4_route, v6_route)) => (v4_route.is_some(), v6_route.is_some()),
Err(error) => {
log::warn!("Failed to initialize offline monitor: {error}");
diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs
index f47fe8dd4d..5e09763cd0 100644
--- a/talpid-core/src/offline/windows.rs
+++ b/talpid-core/src/offline/windows.rs
@@ -29,7 +29,7 @@ unsafe impl Send for BroadcastListener {}
impl BroadcastListener {
pub async fn start(
notify_tx: UnboundedSender<Connectivity>,
- route_manager_handle: RouteManagerHandle,
+ route_manager: RouteManagerHandle,
mut power_mgmt_rx: PowerManagementListener,
) -> Result<Self, Error> {
let notify_tx = Arc::new(notify_tx);
@@ -67,8 +67,7 @@ impl BroadcastListener {
});
let callback_handle =
- Self::setup_network_connectivity_listener(system_state.clone(), route_manager_handle)
- .await?;
+ Self::setup_network_connectivity_listener(system_state.clone(), route_manager).await?;
Ok(BroadcastListener {
system_state,
@@ -107,9 +106,9 @@ impl BroadcastListener {
/// until after `WinNet_DeactivateConnectivityMonitor` has been called.
async fn setup_network_connectivity_listener(
system_state: Arc<Mutex<SystemState>>,
- route_manager_handle: RouteManagerHandle,
+ route_manager: RouteManagerHandle,
) -> Result<CallbackHandle, Error> {
- let change_handle = route_manager_handle
+ let change_handle = route_manager
.add_default_route_change_callback(Box::new(move |event, addr_family| {
Self::connectivity_callback(event, addr_family, &system_state)
}))
@@ -202,10 +201,10 @@ pub type MonitorHandle = BroadcastListener;
pub async fn spawn_monitor(
sender: UnboundedSender<Connectivity>,
- route_manager_handle: RouteManagerHandle,
+ route_manager: RouteManagerHandle,
) -> Result<MonitorHandle, Error> {
let power_mgmt_rx = crate::window::PowerManagementListener::new();
- BroadcastListener::start(sender, route_manager_handle, power_mgmt_rx).await
+ BroadcastListener::start(sender, route_manager, power_mgmt_rx).await
}
fn apply_system_state_change(state: Arc<Mutex<SystemState>>, change: StateChange) {
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 927de207bf..71a88e64fb 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -18,7 +18,7 @@ use std::{
thread,
time::{Duration, Instant},
};
-use talpid_routing::RouteManager;
+use talpid_routing::RouteManagerHandle;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use talpid_types::{
net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters},
@@ -60,9 +60,9 @@ impl ConnectingState {
if shared_values.connectivity.is_offline() {
// FIXME: Temporary: Nudge route manager to update the default interface
#[cfg(target_os = "macos")]
- if let Ok(handle) = shared_values.route_manager.handle() {
+ {
log::debug!("Poking route manager to update default routes");
- let _ = handle.refresh_routes();
+ let _ = shared_values.route_manager.refresh_routes();
}
return ErrorState::enter(shared_values, ErrorStateCause::IsOffline);
}
@@ -189,7 +189,7 @@ impl ConnectingState {
log_dir: &Option<PathBuf>,
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: &RouteManager,
+ route_manager: &RouteManagerHandle,
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
@@ -202,7 +202,7 @@ impl ConnectingState {
})
};
- let route_manager_handle = route_manager.handle();
+ let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
let resource_dir = resource_dir.to_path_buf();
@@ -214,25 +214,6 @@ impl ConnectingState {
tokio::task::spawn_blocking(move || {
let start = Instant::now();
- let route_manager_handle = match route_manager_handle {
- Ok(handle) => handle,
- Err(error) => {
- if tunnel_close_event_tx
- .send(Some(ErrorStateCause::StartTunnelError))
- .is_err()
- {
- log::warn!(
- "Tunnel state machine stopped before receiving tunnel closed event"
- );
- }
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to obtain route monitor handle")
- );
- return;
- }
- };
-
let args = TunnelArgs {
runtime,
resource_dir: &resource_dir,
@@ -240,7 +221,7 @@ impl ConnectingState {
tunnel_close_rx,
tun_provider,
retry_attempt,
- route_manager: route_manager_handle,
+ route_manager,
};
let block_reason = match TunnelMonitor::start(&mut tunnel_parameters, &log_dir, args) {
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 75f6cc1ced..bee32bb31d 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -21,7 +21,7 @@ use crate::{
};
#[cfg(windows)]
use std::ffi::OsString;
-use talpid_routing::RouteManager;
+use talpid_routing::RouteManagerHandle;
use talpid_tunnel::{tun_provider::TunProvider, TunnelEvent};
use futures::{
@@ -269,7 +269,7 @@ impl TunnelStateMachine {
#[cfg(target_os = "macos")]
let filtering_resolver = crate::resolver::start_resolver().await?;
- let route_manager = RouteManager::new(
+ let route_manager = RouteManagerHandle::spawn(
#[cfg(target_os = "linux")]
args.linux_ids.fwmark,
#[cfg(target_os = "linux")]
@@ -284,9 +284,7 @@ impl TunnelStateMachine {
args.resource_dir.clone(),
args.command_tx.clone(),
volume_update_rx,
- route_manager
- .handle()
- .map_err(Error::InitRouteManagerError)?,
+ route_manager.clone(),
)
.map_err(Error::InitSplitTunneling)?;
@@ -308,9 +306,7 @@ impl TunnelStateMachine {
#[cfg(target_os = "linux")]
runtime.clone(),
#[cfg(target_os = "linux")]
- route_manager
- .handle()
- .map_err(Error::InitRouteManagerError)?,
+ route_manager.clone(),
#[cfg(target_os = "macos")]
args.command_tx.clone(),
)
@@ -331,7 +327,7 @@ impl TunnelStateMachine {
let offline_monitor = offline::spawn_monitor(
offline_tx,
#[cfg(not(target_os = "android"))]
- route_manager.handle()?,
+ route_manager.clone(),
#[cfg(target_os = "linux")]
Some(args.linux_ids.fwmark),
#[cfg(target_os = "android")]
@@ -436,7 +432,7 @@ struct SharedTunnelStateValues {
runtime: tokio::runtime::Handle,
firewall: Firewall,
dns_monitor: DnsMonitor,
- route_manager: RouteManager,
+ route_manager: RouteManagerHandle,
_offline_monitor: offline::MonitorHandle,
/// Should LAN access be allowed outside the tunnel.
allow_lan: bool,
diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs
index 7426228812..da30c1940f 100644
--- a/talpid-openvpn/src/lib.rs
+++ b/talpid-openvpn/src/lib.rs
@@ -312,7 +312,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
proxy_auth_file_path: proxy_auth_file_path.clone(),
abort_server_tx: event_server_abort_tx,
proxy: params.proxy.clone(),
- route_manager_handle: route_manager,
+ route_manager,
#[cfg(target_os = "linux")]
ipv6_enabled,
},
@@ -817,7 +817,7 @@ mod event_server {
pub proxy_auth_file_path: Option<super::PathBuf>,
pub abort_server_tx: triggered::Trigger,
pub proxy: Option<CustomProxy>,
- pub route_manager_handle: talpid_routing::RouteManagerHandle,
+ pub route_manager: talpid_routing::RouteManagerHandle,
#[cfg(target_os = "linux")]
pub ipv6_enabled: bool,
}
@@ -864,7 +864,7 @@ mod event_server {
let route = talpid_routing::RequiredRoute::new(network, node);
routes.insert(route);
}
- let route_handle = self.route_manager_handle.clone();
+ let route_handle = self.route_manager.clone();
#[cfg(target_os = "linux")]
{
diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs
index d525b12435..f15489fcb6 100644
--- a/talpid-routing/src/lib.rs
+++ b/talpid-routing/src/lib.rs
@@ -25,7 +25,7 @@ use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN;
#[cfg(target_os = "macos")]
pub use imp::{imp::RouteError, DefaultRouteEvent, PlatformError};
-pub use imp::{Error, RouteManager, RouteManagerHandle};
+pub use imp::{Error, RouteManagerHandle};
/// A network route with a specific network node, destination and an optional metric.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
@@ -81,7 +81,7 @@ impl fmt::Display for Route {
}
}
-/// A network route that should be applied by the RouteManager.
+/// A network route that should be applied by the route manager.
/// It can either be routed through a specific network node or it can be routed through the current
/// default route.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
@@ -130,7 +130,7 @@ impl RequiredRoute {
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum NetNode {
/// A real node will be used to set a regular route that will remain unchanged for the lifetime
- /// of the RouteManager
+ /// of the route manager
RealNode(Node),
/// A default node is a symbolic node that will resolve to the network node used in the current
/// most preferable default route
diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs
index 768000c010..7fe7e9bf31 100644
--- a/talpid-routing/src/unix/mod.rs
+++ b/talpid-routing/src/unix/mod.rs
@@ -32,7 +32,7 @@ mod imp;
pub use imp::Error as PlatformError;
-/// Errors that can be encountered whilst initializing RouteManager
+/// Errors that can be encountered whilst initializing route manager
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Route manager thread may have panicked
@@ -65,25 +65,123 @@ impl Error {
}
}
-/// Handle to a route manager.
-#[derive(Clone)]
+/// Represents a firewall mark.
+#[cfg(target_os = "linux")]
+type Fwmark = u32;
+
+/// Commands for the underlying route manager object.
+#[derive(Debug)]
+pub(crate) enum RouteManagerCommand {
+ AddRoutes(
+ HashSet<RequiredRoute>,
+ oneshot::Sender<Result<(), PlatformError>>,
+ ),
+ ClearRoutes,
+ Shutdown(oneshot::Sender<()>),
+ #[cfg(target_os = "macos")]
+ RefreshRoutes,
+ #[cfg(target_os = "macos")]
+ NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>),
+ #[cfg(target_os = "macos")]
+ GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>),
+ #[cfg(target_os = "linux")]
+ CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>),
+ #[cfg(target_os = "linux")]
+ ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>),
+ #[cfg(target_os = "linux")]
+ NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>),
+ #[cfg(target_os = "linux")]
+ GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>),
+ /// Attempt to fetch a route for the given destination with an optional firewall mark.
+ #[cfg(target_os = "linux")]
+ GetDestinationRoute(
+ IpAddr,
+ Option<Fwmark>,
+ oneshot::Sender<Result<Option<Route>, PlatformError>>,
+ ),
+}
+
+/// Event that is sent when a preferred non-tunnel default route is
+/// added or removed.
+#[cfg(target_os = "macos")]
+#[derive(Debug, Clone, Copy)]
+pub enum DefaultRouteEvent {
+ /// Added or updated a non-tunnel default IPv4 route
+ AddedOrChangedV4,
+ /// Added or updated a non-tunnel default IPv6 route
+ AddedOrChangedV6,
+ /// Non-tunnel default IPv4 route was removed
+ RemovedV4,
+ /// Non-tunnel default IPv6 route was removed
+ RemovedV6,
+}
+
+#[cfg(target_os = "linux")]
+#[derive(Debug, Clone)]
+pub enum CallbackMessage {
+ NewRoute(Route),
+ DelRoute(Route),
+}
+
+/// Route manager applies a set of routes to the route table.
+/// If a destination has to be routed through the default node,
+/// the route will be adjusted dynamically when the default route changes.
+#[derive(Debug, Clone)]
pub struct RouteManagerHandle {
tx: Arc<UnboundedSender<RouteManagerCommand>>,
}
impl RouteManagerHandle {
- /// Applies the given routes while the route manager is running.
+ /// Construct a route manager.
+ pub async fn spawn(
+ #[cfg(target_os = "linux")] fwmark: u32,
+ #[cfg(target_os = "linux")] table_id: u32,
+ ) -> Result<Self, Error> {
+ let (manage_tx, manage_rx) = mpsc::unbounded();
+ let manage_tx = Arc::new(manage_tx);
+ let manager = imp::RouteManagerImpl::new(
+ #[cfg(target_os = "linux")]
+ fwmark,
+ #[cfg(target_os = "linux")]
+ table_id,
+ #[cfg(target_os = "macos")]
+ Arc::downgrade(&manage_tx),
+ )
+ .await?;
+ tokio::spawn(manager.run(manage_rx));
+
+ Ok(Self { tx: manage_tx })
+ }
+
+ /// Stop route manager and revert all changes to routing
+ pub async fn stop(&self) {
+ let (wait_tx, wait_rx) = oneshot::channel();
+ let _ = self
+ .tx
+ .unbounded_send(RouteManagerCommand::Shutdown(wait_tx));
+ let _ = wait_rx.await;
+ }
+
+ /// Applies the given routes until they are cleared
pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
- let (response_tx, response_rx) = oneshot::channel();
+ let (result_tx, result_rx) = oneshot::channel();
self.tx
- .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
.map_err(|_| Error::RouteManagerDown)?;
- response_rx
+
+ result_rx
.await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}
+ /// Removes all routes previously applied in [`RouteManager::add_routes`].
+ pub fn clear_routes(&self) -> Result<(), Error> {
+ self.tx
+ .unbounded_send(RouteManagerCommand::ClearRoutes)
+ .map_err(|_| Error::RouteManagerDown)
+ }
+
/// Listen for non-tunnel default route changes.
#[cfg(target_os = "macos")]
pub async fn default_route_listener(
@@ -187,153 +285,3 @@ impl RouteManagerHandle {
.map_err(Error::PlatformError)
}
}
-
-/// Represents a firewall mark.
-#[cfg(target_os = "linux")]
-type Fwmark = u32;
-
-/// Commands for the underlying route manager object.
-#[derive(Debug)]
-pub(crate) enum RouteManagerCommand {
- AddRoutes(
- HashSet<RequiredRoute>,
- oneshot::Sender<Result<(), PlatformError>>,
- ),
- ClearRoutes,
- Shutdown(oneshot::Sender<()>),
- #[cfg(target_os = "macos")]
- RefreshRoutes,
- #[cfg(target_os = "macos")]
- NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>),
- #[cfg(target_os = "macos")]
- GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>),
- #[cfg(target_os = "linux")]
- CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>),
- #[cfg(target_os = "linux")]
- ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>),
- #[cfg(target_os = "linux")]
- NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>),
- #[cfg(target_os = "linux")]
- GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>),
- /// Attempt to fetch a route for the given destination with an optional firewall mark.
- #[cfg(target_os = "linux")]
- GetDestinationRoute(
- IpAddr,
- Option<Fwmark>,
- oneshot::Sender<Result<Option<Route>, PlatformError>>,
- ),
-}
-
-/// Event that is sent when a preferred non-tunnel default route is
-/// added or removed.
-#[cfg(target_os = "macos")]
-#[derive(Debug, Clone, Copy)]
-pub enum DefaultRouteEvent {
- /// Added or updated a non-tunnel default IPv4 route
- AddedOrChangedV4,
- /// Added or updated a non-tunnel default IPv6 route
- AddedOrChangedV6,
- /// Non-tunnel default IPv4 route was removed
- RemovedV4,
- /// Non-tunnel default IPv6 route was removed
- RemovedV6,
-}
-
-#[cfg(target_os = "linux")]
-#[derive(Debug, Clone)]
-pub enum CallbackMessage {
- NewRoute(Route),
- DelRoute(Route),
-}
-
-/// RouteManager applies a set of routes to the route table.
-/// If a destination has to be routed through the default node,
-/// the route will be adjusted dynamically when the default route changes.
-pub struct RouteManager {
- manage_tx: Option<Arc<UnboundedSender<RouteManagerCommand>>>,
-}
-
-impl RouteManager {
- /// Construct a RouteManager.
- pub async fn new(
- #[cfg(target_os = "linux")] fwmark: u32,
- #[cfg(target_os = "linux")] table_id: u32,
- ) -> Result<Self, Error> {
- let (manage_tx, manage_rx) = mpsc::unbounded();
- let manage_tx = Arc::new(manage_tx);
- let manager = imp::RouteManagerImpl::new(
- #[cfg(target_os = "linux")]
- fwmark,
- #[cfg(target_os = "linux")]
- table_id,
- #[cfg(target_os = "macos")]
- Arc::downgrade(&manage_tx),
- )
- .await?;
- tokio::spawn(manager.run(manage_rx));
-
- Ok(Self {
- manage_tx: Some(manage_tx),
- })
- }
-
- /// Stops RouteManager and removes all of the applied routes.
- pub async fn stop(&mut self) {
- if let Some(tx) = self.manage_tx.take() {
- let (wait_tx, wait_rx) = oneshot::channel();
- let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(wait_tx));
- let _ = wait_rx.await;
- }
- }
-
- /// Applies the given routes until [`RouteManager::stop`] is called.
- pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
- let tx = self.get_command_tx()?;
- let (result_tx, result_rx) = oneshot::channel();
- tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
- .map_err(|_| Error::RouteManagerDown)?;
-
- result_rx
- .await
- .map_err(|_| Error::ManagerChannelDown)?
- .map_err(Error::PlatformError)
- }
-
- /// Removes all routes previously applied in [`RouteManager::add_routes`].
- pub fn clear_routes(&self) -> Result<(), Error> {
- let tx = self.get_command_tx()?;
- tx.unbounded_send(RouteManagerCommand::ClearRoutes)
- .map_err(|_| Error::RouteManagerDown)
- }
-
- /// Ensure that packets are routed using the correct tables.
- #[cfg(target_os = "linux")]
- pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> {
- self.handle()?.create_routing_rules(enable_ipv6).await
- }
-
- /// Remove any routing rules created by [Self::create_routing_rules].
- #[cfg(target_os = "linux")]
- pub async fn clear_routing_rules(&self) -> Result<(), Error> {
- self.handle()?.clear_routing_rules().await
- }
-
- /// Retrieve a sender directly to the command channel.
- pub fn handle(&self) -> Result<RouteManagerHandle, Error> {
- let tx = self.get_command_tx()?;
- Ok(RouteManagerHandle { tx: tx.clone() })
- }
-
- fn get_command_tx(&self) -> Result<&Arc<UnboundedSender<RouteManagerCommand>>, Error> {
- self.manage_tx.as_ref().ok_or(Error::RouteManagerDown)
- }
-}
-
-impl Drop for RouteManager {
- fn drop(&mut self) {
- if let Some(tx) = self.manage_tx.take() {
- let (done_tx, _) = oneshot::channel();
- let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx));
- }
- }
-}
diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs
index c03beea8cf..c158938ebd 100644
--- a/talpid-routing/src/windows/mod.rs
+++ b/talpid-routing/src/windows/mod.rs
@@ -21,9 +21,6 @@ mod route_manager;
/// Windows routing errors.
#[derive(thiserror::Error, Debug)]
pub enum Error {
- /// The sender was dropped unexpectedly -- possible panic
- #[error("The channel sender was dropped")]
- ManagerChannelDown,
/// Failure to initialize route manager
#[error("Failed to start route manager")]
FailedToStartManager,
@@ -101,17 +98,31 @@ impl Error {
pub type Result<T> = std::result::Result<T, Error>;
/// Manages routes by calling into WinNet
-pub struct RouteManager {
- manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
-}
-
-/// Handle to a route manager.
-#[derive(Clone)]
+#[derive(Debug, Clone)]
pub struct RouteManagerHandle {
tx: UnboundedSender<RouteManagerCommand>,
}
+pub enum RouteManagerCommand {
+ AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
+ GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>),
+ ClearRoutes,
+ RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>),
+ Shutdown(oneshot::Sender<()>),
+}
+
impl RouteManagerHandle {
+ /// Create a new route manager
+ #[allow(clippy::unused_async)]
+ pub async fn spawn() -> Result<Self> {
+ let internal = RouteManagerInternal::new().map_err(|_| Error::FailedToStartManager)?;
+ let (tx, rx) = mpsc::unbounded();
+ let handle = Self { tx };
+ tokio::spawn(RouteManagerHandle::run(rx, internal));
+
+ Ok(handle)
+ }
+
/// Add a callback which will be called if the default route changes.
pub async fn add_default_route_change_callback(
&self,
@@ -124,7 +135,7 @@ impl RouteManagerHandle {
response_tx,
))
.map_err(|_| Error::RouteManagerDown)?;
- response_rx.await.map_err(|_| Error::ManagerChannelDown)
+ response_rx.await.map_err(|_| Error::RouteManagerDown)
}
/// Applies the given routes while the route manager is running.
@@ -133,65 +144,35 @@ impl RouteManagerHandle {
self.tx
.unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
.map_err(|_| Error::RouteManagerDown)?;
- response_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ response_rx.await.map_err(|_| Error::RouteManagerDown)?
}
- /// Applies the given routes while the route manager is running.
+ /// Retrieve MTU for the given destination/route.
pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
.map_err(|_| Error::RouteManagerDown)?;
- response_rx.await.map_err(|_| Error::ManagerChannelDown)?
- }
-}
-
-pub enum RouteManagerCommand {
- AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
- GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>),
- ClearRoutes,
- RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>),
- Shutdown(oneshot::Sender<Result<()>>),
-}
-
-impl RouteManager {
- /// Create a new route manager
- #[allow(clippy::unused_async)]
- pub async fn new() -> Result<Self> {
- let internal = match RouteManagerInternal::new() {
- Ok(internal) => internal,
- Err(_) => return Err(Error::FailedToStartManager),
- };
- let (manage_tx, manage_rx) = mpsc::unbounded();
- let manager = Self {
- manage_tx: Some(manage_tx),
- };
- tokio::spawn(RouteManager::listen(manage_rx, internal));
-
- Ok(manager)
+ response_rx.await.map_err(|_| Error::RouteManagerDown)?
}
- /// Add a callback which will be called if the default route changes.
- pub async fn add_default_route_change_callback(
- &self,
- callback: Callback,
- ) -> Result<CallbackHandle> {
- let tx = self.get_command_tx()?;
+ /// Stop the routing manager actor and revert all changes to routing
+ pub async fn stop(&self) {
let (result_tx, result_rx) = oneshot::channel();
- tx.unbounded_send(RouteManagerCommand::RegisterDefaultRouteChangeCallback(
- callback, result_tx,
- ))
- .map_err(|_| Error::RouteManagerDown)?;
- result_rx.await.map_err(|_| Error::ManagerChannelDown)
+ _ = self
+ .tx
+ .unbounded_send(RouteManagerCommand::Shutdown(result_tx));
+ _ = result_rx.await;
}
- /// Retrieve a sender directly to the command channel.
- pub fn handle(&self) -> Result<RouteManagerHandle> {
- let tx = self.get_command_tx()?;
- Ok(RouteManagerHandle { tx: tx.clone() })
+ /// Removes all routes previously applied in [`RouteManager::add_routes`].
+ pub fn clear_routes(&self) -> Result<()> {
+ self.tx
+ .unbounded_send(RouteManagerCommand::ClearRoutes)
+ .map_err(|_| Error::RouteManagerDown)
}
- async fn listen(
+ async fn run(
mut manage_rx: UnboundedReceiver<RouteManagerCommand>,
mut internal: RouteManagerInternal,
) {
@@ -235,42 +216,12 @@ impl RouteManager {
}
RouteManagerCommand::Shutdown(tx) => {
drop(internal);
- let _ = tx.send(Ok(()));
+ let _ = tx.send(());
break;
}
}
}
}
-
- /// Stops the routing manager and invalidates the route manager - no new default route callbacks
- /// can be added
- pub async fn stop(&mut self) {
- if let Some(tx) = self.manage_tx.take() {
- let (result_tx, result_rx) = oneshot::channel();
- let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(result_tx));
- _ = result_rx.await;
- }
- }
-
- /// Applies the given routes until [`RouteManager::stop`] is called.
- pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
- let tx = self.get_command_tx()?;
- let (result_tx, result_rx) = oneshot::channel();
- tx.unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
- .map_err(|_| Error::RouteManagerDown)?;
- result_rx.await.map_err(|_| Error::ManagerChannelDown)?
- }
-
- /// Removes all routes previously applied in [`RouteManager::add_routes`].
- pub fn clear_routes(&self) -> Result<()> {
- let tx = self.get_command_tx()?;
- tx.unbounded_send(RouteManagerCommand::ClearRoutes)
- .map_err(|_| Error::RouteManagerDown)
- }
-
- fn get_command_tx(&self) -> Result<&UnboundedSender<RouteManagerCommand>> {
- self.manage_tx.as_ref().ok_or(Error::RouteManagerDown)
- }
}
fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> {
@@ -292,12 +243,3 @@ fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> {
}
}
}
-
-impl Drop for RouteManager {
- fn drop(&mut self) {
- if let Some(tx) = self.manage_tx.take() {
- let (done_tx, _) = oneshot::channel();
- let _ = tx.unbounded_send(RouteManagerCommand::Shutdown(done_tx));
- }
- }
-}
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 4f02bfcb09..72d6a31566 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -718,7 +718,7 @@ impl WireguardMonitor {
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] psk_negotiation: bool,
- #[cfg(windows)] route_manager_handle: crate::routing::RouteManagerHandle,
+ #[cfg(windows)] route_manager: crate::routing::RouteManagerHandle,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Box<dyn Tunnel>> {
log::debug!("Tunnel MTU: {}", config.mtu);