diff options
| author | David Lönnhager <david.l@mullvad.net> | 2025-01-13 09:25:36 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2025-01-24 17:35:03 +0100 |
| commit | 81fe90fe201497d953fefc26357631ac4fd69d54 (patch) | |
| tree | c94719da088d24311a92b18acc8f4464dd1892c1 | |
| parent | 3e10836afaa167571ea3be3a9922cc74f5acfdb1 (diff) | |
| download | mullvadvpn-81fe90fe201497d953fefc26357631ac4fd69d54.tar.xz mullvadvpn-81fe90fe201497d953fefc26357631ac4fd69d54.zip | |
Handle network changes for wireguard-go (rebind endpoint socket)
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 8 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 56 |
2 files changed, 60 insertions, 4 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 6eecefbbb6..9cf6082a48 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -697,13 +697,13 @@ impl WireguardMonitor { let use_userspace_wg = config.daita; if use_userspace_wg { log::debug!("Using userspace WireGuard implementation"); - let tunnel = Self::open_wireguard_go_tunnel( + let tunnel = runtime.block_on(Self::open_wireguard_go_tunnel( runtime, config, log_path, setup_done_tx, route_manager, - ) + )) .map(Box::new)?; return Ok(tunnel); } @@ -736,10 +736,12 @@ impl WireguardMonitor { #[cfg(wireguard_go)] #[allow(clippy::unused_async)] async fn open_wireguard_go_tunnel( + #[cfg(windows)] runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, #[cfg(unix)] tun_provider: Arc<Mutex<TunProvider>>, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, + #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(target_os = "android")] gateway_only: bool, #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, ) -> Result<WgGoTunnel> { @@ -753,7 +755,7 @@ impl WireguardMonitor { .map_err(Error::TunnelError)?; #[cfg(target_os = "windows")] - let tunnel = WgGoTunnel::start_tunnel(config, log_path, setup_done_tx) + let tunnel = WgGoTunnel::start_tunnel(runtime, config, log_path, route_manager, setup_done_tx) .map_err(Error::TunnelError)?; // Android uses multihop implemented in Mullvad's wireguard-go fork. When negotiating diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 86646c1d0c..6699b7a6ad 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -178,6 +178,10 @@ pub(crate) struct WgGoTunnelState { /// This is used to cancel the connectivity checks that occur when toggling multihop #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, + /// Default route change callback. This is used to rebind the endpoint socket when the default + /// route (network) is changed. + #[cfg(target_os = "windows")] + _socket_update_cb: Option<talpid_routing::CallbackHandle>, } impl WgGoTunnelState { @@ -257,8 +261,10 @@ impl WgGoTunnel { #[cfg(target_os = "windows")] pub fn start_tunnel( + runtime: tokio::runtime::Handle, config: &Config, log_path: Option<&Path>, + route_manager: talpid_routing::RouteManagerHandle, mut setup_done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Self> { use futures::SinkExt; @@ -269,7 +275,16 @@ impl WgGoTunnel { .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned))) .map_err(TunnelError::LoggingError)?; - // TODO: default route clalback + let socket_update_cb = runtime + .block_on( + route_manager.add_default_route_change_callback(Box::new( + Self::default_route_changed_callback, + )), + ) + .ok(); + if socket_update_cb.is_none() { + log::warn!("Failed to register default route callback"); + } let handle = wireguard_go_rs::Tunnel::turn_on( c"Mullvad", @@ -312,11 +327,50 @@ impl WgGoTunnel { interface_name: interface_name.to_owned(), tunnel_handle: handle, _logging_context: logging_context, + _socket_update_cb: socket_update_cb, #[cfg(daita)] config: config.clone(), })) } + // Callback to be used to rebind the tunnel sockets when the default route changes + #[cfg(target_os = "windows")] + fn default_route_changed_callback( + event_type: talpid_routing::EventType<'_>, + address_family: talpid_windows::net::AddressFamily, + ) { + use talpid_routing::EventType::*; + + let iface_idx: u32 = match event_type { + Updated(default_route) => { + let iface_luid = default_route.iface; + match talpid_windows::net::index_from_luid(&iface_luid) { + Ok(idx) => idx, + Err(err) => { + log::error!( + "Failed to convert interface LUID to interface index: {}", + err, + ); + return; + }, + } + } + // if there is no new default route, specify 0 as the interface index + Removed => 0, + // ignore interface updates that don't affect the interface to use + UpdatedDetails(_) => return, + }; + + match address_family { + talpid_windows::net::AddressFamily::Ipv4 => { + wireguard_go_rs::rebind_tunnel_socket_v4(iface_idx); + } + talpid_windows::net::AddressFamily::Ipv6 => { + wireguard_go_rs::rebind_tunnel_socket_v6(iface_idx); + } + } + } + #[cfg(unix)] fn get_tunnel( tun_provider: Arc<Mutex<TunProvider>>, |
