summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2025-01-13 09:25:36 +0100
committerDavid Lönnhager <david.l@mullvad.net>2025-01-24 17:35:03 +0100
commit81fe90fe201497d953fefc26357631ac4fd69d54 (patch)
treec94719da088d24311a92b18acc8f4464dd1892c1
parent3e10836afaa167571ea3be3a9922cc74f5acfdb1 (diff)
downloadmullvadvpn-81fe90fe201497d953fefc26357631ac4fd69d54.tar.xz
mullvadvpn-81fe90fe201497d953fefc26357631ac4fd69d54.zip
Handle network changes for wireguard-go (rebind endpoint socket)
-rw-r--r--talpid-wireguard/src/lib.rs8
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs56
2 files changed, 60 insertions, 4 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 6eecefbbb6..9cf6082a48 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -697,13 +697,13 @@ impl WireguardMonitor {
let use_userspace_wg = config.daita;
if use_userspace_wg {
log::debug!("Using userspace WireGuard implementation");
- let tunnel = Self::open_wireguard_go_tunnel(
+ let tunnel = runtime.block_on(Self::open_wireguard_go_tunnel(
runtime,
config,
log_path,
setup_done_tx,
route_manager,
- )
+ ))
.map(Box::new)?;
return Ok(tunnel);
}
@@ -736,10 +736,12 @@ impl WireguardMonitor {
#[cfg(wireguard_go)]
#[allow(clippy::unused_async)]
async fn open_wireguard_go_tunnel(
+ #[cfg(windows)] runtime: tokio::runtime::Handle,
config: &Config,
log_path: Option<&Path>,
#[cfg(unix)] tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
+ #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle,
#[cfg(target_os = "android")] gateway_only: bool,
#[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver,
) -> Result<WgGoTunnel> {
@@ -753,7 +755,7 @@ impl WireguardMonitor {
.map_err(Error::TunnelError)?;
#[cfg(target_os = "windows")]
- let tunnel = WgGoTunnel::start_tunnel(config, log_path, setup_done_tx)
+ let tunnel = WgGoTunnel::start_tunnel(runtime, config, log_path, route_manager, setup_done_tx)
.map_err(Error::TunnelError)?;
// Android uses multihop implemented in Mullvad's wireguard-go fork. When negotiating
diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index 86646c1d0c..6699b7a6ad 100644
--- a/talpid-wireguard/src/wireguard_go/mod.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -178,6 +178,10 @@ pub(crate) struct WgGoTunnelState {
/// This is used to cancel the connectivity checks that occur when toggling multihop
#[cfg(target_os = "android")]
cancel_receiver: connectivity::CancelReceiver,
+ /// Default route change callback. This is used to rebind the endpoint socket when the default
+ /// route (network) is changed.
+ #[cfg(target_os = "windows")]
+ _socket_update_cb: Option<talpid_routing::CallbackHandle>,
}
impl WgGoTunnelState {
@@ -257,8 +261,10 @@ impl WgGoTunnel {
#[cfg(target_os = "windows")]
pub fn start_tunnel(
+ runtime: tokio::runtime::Handle,
config: &Config,
log_path: Option<&Path>,
+ route_manager: talpid_routing::RouteManagerHandle,
mut setup_done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Self> {
use futures::SinkExt;
@@ -269,7 +275,16 @@ impl WgGoTunnel {
.map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
.map_err(TunnelError::LoggingError)?;
- // TODO: default route clalback
+ let socket_update_cb = runtime
+ .block_on(
+ route_manager.add_default_route_change_callback(Box::new(
+ Self::default_route_changed_callback,
+ )),
+ )
+ .ok();
+ if socket_update_cb.is_none() {
+ log::warn!("Failed to register default route callback");
+ }
let handle = wireguard_go_rs::Tunnel::turn_on(
c"Mullvad",
@@ -312,11 +327,50 @@ impl WgGoTunnel {
interface_name: interface_name.to_owned(),
tunnel_handle: handle,
_logging_context: logging_context,
+ _socket_update_cb: socket_update_cb,
#[cfg(daita)]
config: config.clone(),
}))
}
+ // Callback to be used to rebind the tunnel sockets when the default route changes
+ #[cfg(target_os = "windows")]
+ fn default_route_changed_callback(
+ event_type: talpid_routing::EventType<'_>,
+ address_family: talpid_windows::net::AddressFamily,
+ ) {
+ use talpid_routing::EventType::*;
+
+ let iface_idx: u32 = match event_type {
+ Updated(default_route) => {
+ let iface_luid = default_route.iface;
+ match talpid_windows::net::index_from_luid(&iface_luid) {
+ Ok(idx) => idx,
+ Err(err) => {
+ log::error!(
+ "Failed to convert interface LUID to interface index: {}",
+ err,
+ );
+ return;
+ },
+ }
+ }
+ // if there is no new default route, specify 0 as the interface index
+ Removed => 0,
+ // ignore interface updates that don't affect the interface to use
+ UpdatedDetails(_) => return,
+ };
+
+ match address_family {
+ talpid_windows::net::AddressFamily::Ipv4 => {
+ wireguard_go_rs::rebind_tunnel_socket_v4(iface_idx);
+ }
+ talpid_windows::net::AddressFamily::Ipv6 => {
+ wireguard_go_rs::rebind_tunnel_socket_v6(iface_idx);
+ }
+ }
+ }
+
#[cfg(unix)]
fn get_tunnel(
tun_provider: Arc<Mutex<TunProvider>>,