summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-03-18 10:28:45 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-18 10:28:45 +0100
commiteb937c0ae598b5fbcb0e539f8bc1b300720c17a8 (patch)
treefdb52a1540ff5749e5415e1ddb81a9d8fee4cc39
parentaf7a5fc08e0e887fa4b75db55d1300294e4f7103 (diff)
parent8fdea52ee962c2e58985f4ba8d388848f220255e (diff)
downloadmullvadvpn-eb937c0ae598b5fbcb0e539f8bc1b300720c17a8.tar.xz
mullvadvpn-eb937c0ae598b5fbcb0e539f8bc1b300720c17a8.zip
Merge branch 'winnet-forward-interface-detail-change'
-rw-r--r--talpid-core/src/offline/windows.rs12
-rw-r--r--talpid-core/src/routing/windows.rs13
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs95
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs8
-rw-r--r--talpid-core/src/winnet.rs3
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.cpp80
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.h9
-rw-r--r--windows/winnet/src/winnet/winnet.cpp14
-rw-r--r--windows/winnet/src/winnet/winnet.h6
9 files changed, 167 insertions, 73 deletions
diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs
index 15b5d47169..3c603cf7e9 100644
--- a/talpid-core/src/offline/windows.rs
+++ b/talpid-core/src/offline/windows.rs
@@ -122,11 +122,15 @@ impl BroadcastListener {
_default_route: winnet::WinNetDefaultRoute,
ctx: *mut c_void,
) {
+ use winnet::WinNetDefaultRouteChangeEventType::*;
+
+ if event_type == DefaultRouteUpdatedDetails {
+ // ignore changes that don't affect the route
+ return;
+ }
+
let state_lock: &mut Arc<Mutex<SystemState>> = &mut *(ctx as *mut _);
- let connectivity = match event_type {
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => true,
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => false,
- };
+ let connectivity = event_type != DefaultRouteRemoved;
let change = match family {
winnet::WinNetAddrFamily::IPV4 => StateChange::NetworkV4Connectivity(connectivity),
winnet::WinNetAddrFamily::IPV6 => StateChange::NetworkV6Connectivity(connectivity),
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index ec17d4feae..f7325b3963 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -119,19 +119,6 @@ impl RouteManager {
}
}
- /// Sets a callback that is called whenever the default route changes.
- pub fn add_default_route_callback<T: 'static>(
- &mut self,
- callback: Option<winnet::DefaultRouteChangedCallback>,
- context: T,
- ) -> Result<winnet::WinNetCallbackHandle> {
- if self.manage_tx.is_none() {
- return Err(Error::RouteManagerDown);
- }
- winnet::add_default_route_change_callback(callback, context)
- .map_err(|_| Error::FailedToAddDefaultRouteCallback)
- }
-
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
/// can be added
pub fn stop(&mut self) {
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index efdd75ecf9..01a3d324b5 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -133,18 +133,21 @@ impl Drop for QuitEvent {
enum Request {
SetPaths(Vec<OsString>),
- RegisterIps(
- Option<Ipv4Addr>,
- Option<Ipv6Addr>,
- Option<Ipv4Addr>,
- Option<Ipv6Addr>,
- ),
+ RegisterIps(InterfaceAddresses),
}
type RequestResponseTx = sync_mpsc::Sender<Result<(), Error>>;
type RequestTx = sync_mpsc::Sender<(Request, RequestResponseTx)>;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
+#[derive(Default, PartialEq, Clone)]
+struct InterfaceAddresses {
+ tunnel_ipv4: Option<Ipv4Addr>,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Option<Ipv4Addr>,
+ internet_ipv6: Option<Ipv6Addr>,
+}
+
struct EventThreadContext {
handle: Arc<driver::DeviceHandle>,
event_overlapped: OVERLAPPED,
@@ -362,6 +365,8 @@ impl SplitTunnel {
}
};
+ let mut previous_addresses = InterfaceAddresses::default();
+
while let Ok((request, response_tx)) = rx.recv() {
let response = match request {
Request::SetPaths(paths) => {
@@ -387,19 +392,27 @@ impl SplitTunnel {
result
}
- Request::RegisterIps(
- mut tunnel_ipv4,
- mut tunnel_ipv6,
- internet_ipv4,
- internet_ipv6,
- ) => {
- if internet_ipv4.is_none() && internet_ipv6.is_none() {
- tunnel_ipv4 = None;
- tunnel_ipv6 = None;
+ Request::RegisterIps(mut ips) => {
+ if ips.internet_ipv4.is_none() && ips.internet_ipv6.is_none() {
+ ips.tunnel_ipv4 = None;
+ ips.tunnel_ipv6 = None;
+ }
+ if previous_addresses == ips {
+ Ok(())
+ } else {
+ let result = handle
+ .register_ips(
+ ips.tunnel_ipv4,
+ ips.tunnel_ipv6,
+ ips.internet_ipv4,
+ ips.internet_ipv6,
+ )
+ .map_err(Error::RegisterIps);
+ if result.is_ok() {
+ previous_addresses = ips;
+ }
+ result
}
- handle
- .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6)
- .map_err(Error::RegisterIps)
}
};
if response_tx.send(response).is_err() {
@@ -548,7 +561,7 @@ impl SplitTunnel {
/// Instructs the driver to stop redirecting tunnel traffic and INADDR_ANY.
pub fn clear_tunnel_addresses(&mut self) -> Result<(), Error> {
self._route_change_callback = None;
- self.send_request(Request::RegisterIps(None, None, None, None))
+ self.send_request(Request::RegisterIps(InterfaceAddresses::default()))
}
}
@@ -574,10 +587,7 @@ impl Drop for SplitTunnel {
struct SplitTunnelDefaultRouteChangeHandlerContext {
request_tx: RequestTx,
pub daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
- pub tunnel_ipv4: Option<Ipv4Addr>,
- pub tunnel_ipv6: Option<Ipv6Addr>,
- pub internet_ipv4: Option<Ipv4Addr>,
- pub internet_ipv6: Option<Ipv6Addr>,
+ pub addresses: InterfaceAddresses,
}
impl SplitTunnelDefaultRouteChangeHandlerContext {
@@ -590,22 +600,19 @@ impl SplitTunnelDefaultRouteChangeHandlerContext {
SplitTunnelDefaultRouteChangeHandlerContext {
request_tx,
daemon_tx,
- tunnel_ipv4,
- tunnel_ipv6,
- internet_ipv4: None,
- internet_ipv6: None,
+ addresses: InterfaceAddresses {
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4: None,
+ internet_ipv6: None,
+ },
}
}
pub fn register_ips(&self) -> Result<(), Error> {
SplitTunnel::send_request_inner(
&self.request_tx,
- Request::RegisterIps(
- self.tunnel_ipv4,
- self.tunnel_ipv6,
- self.internet_ipv4,
- self.internet_ipv6,
- ),
+ Request::RegisterIps(self.addresses.clone()),
)
}
@@ -624,10 +631,10 @@ impl SplitTunnelDefaultRouteChangeHandlerContext {
.map_err(Error::LuidToIp)?
.flatten();
- self.internet_ipv4 = internet_ipv4
+ self.addresses.internet_ipv4 = internet_ipv4
.map(|addr| Ipv4Addr::try_from(addr).map_err(|_| Error::IpParseError))
.transpose()?;
- self.internet_ipv6 = internet_ipv6
+ self.addresses.internet_ipv6 = internet_ipv6
.map(|addr| Ipv6Addr::try_from(addr).map_err(|_| Error::IpParseError))
.transpose()?;
Ok(())
@@ -640,6 +647,8 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
default_route: winnet::WinNetDefaultRoute,
ctx: *mut libc::c_void,
) {
+ use winnet::WinNetDefaultRouteChangeEventType::*;
+
// Update the "internet interface" IP when best default route changes
let ctx_mutex = &mut *(ctx as *mut Arc<Mutex<SplitTunnelDefaultRouteChangeHandlerContext>>);
let mut ctx = ctx_mutex.lock().expect("ST route handler mutex poisoned");
@@ -652,20 +661,20 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
};
let result = match event_type {
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ DefaultRouteChanged | DefaultRouteUpdatedDetails => {
match interface_luid_to_ip(address_family, default_route.interface_luid) {
Ok(Some(ip)) => match IpAddr::from(ip) {
- IpAddr::V4(addr) => ctx.internet_ipv4 = Some(addr),
- IpAddr::V6(addr) => ctx.internet_ipv6 = Some(addr),
+ IpAddr::V4(addr) => ctx.addresses.internet_ipv4 = Some(addr),
+ IpAddr::V6(addr) => ctx.addresses.internet_ipv6 = Some(addr),
},
Ok(None) => {
log::warn!("Failed to obtain default route interface address");
match address_family {
WinNetAddrFamily::IPV4 => {
- ctx.internet_ipv4 = None;
+ ctx.addresses.internet_ipv4 = None;
}
WinNetAddrFamily::IPV6 => {
- ctx.internet_ipv6 = None;
+ ctx.addresses.internet_ipv6 = None;
}
}
}
@@ -684,13 +693,13 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
ctx.register_ips()
}
// no default route
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => {
+ DefaultRouteRemoved => {
match address_family {
WinNetAddrFamily::IPV4 => {
- ctx.internet_ipv4 = None;
+ ctx.addresses.internet_ipv4 = None;
}
WinNetAddrFamily::IPV6 => {
- ctx.internet_ipv6 = None;
+ ctx.addresses.internet_ipv6 = None;
}
}
ctx.register_ips()
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 28666b6506..a3d34acc0e 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -197,8 +197,10 @@ impl WgGoTunnel {
_ctx: *mut libc::c_void,
) {
use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex};
+ use winnet::WinNetDefaultRouteChangeEventType::*;
+
let iface_idx: u32 = match event_type {
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ DefaultRouteChanged => {
let mut iface_idx = 0u32;
let iface_luid = NET_LUID {
Value: default_route.interface_luid,
@@ -216,7 +218,9 @@ impl WgGoTunnel {
iface_idx
}
// if there is no new default route, specify 0 as the interface index
- winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => 0,
+ DefaultRouteRemoved => 0,
+ // ignore interface updates that don't affect the interface to use
+ DefaultRouteUpdatedDetails => return,
};
wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx);
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 7c9489ecfb..7f31082541 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -311,7 +311,8 @@ impl Drop for WinNetCallbackHandle {
#[repr(u16)]
pub enum WinNetDefaultRouteChangeEventType {
DefaultRouteChanged = 0,
- DefaultRouteRemoved = 1,
+ DefaultRouteUpdatedDetails = 1,
+ DefaultRouteRemoved = 2,
}
pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
index 523c2d7ba0..8e1e2599ad 100644
--- a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
@@ -23,6 +23,7 @@ DefaultRouteMonitor::DefaultRouteMonitor
: m_family(family)
, m_callback(callback)
, m_logSink(logSink)
+ , m_refreshCurrentRoute(false)
, m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>(
std::bind(&DefaultRouteMonitor::evaluateRoutes, this),
POINT_TWO_SECOND_BURST,
@@ -31,14 +32,14 @@ DefaultRouteMonitor::DefaultRouteMonitor
{
std::scoped_lock<std::mutex> lock(m_evaluationLock);
- auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle);
+ auto status = NotifyRouteChange2(family, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle);
if (NO_ERROR != status)
{
THROW_WINDOWS_ERROR(status, "Register for route table change notifications");
}
- status = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this,
+ status = NotifyIpInterfaceChange(family, InterfaceChangeCallback, this,
FALSE, &m_interfaceNotificationHandle);
if (NO_ERROR != status)
@@ -47,6 +48,16 @@ DefaultRouteMonitor::DefaultRouteMonitor
THROW_WINDOWS_ERROR(status, "Register for network interface change notifications");
}
+ status = NotifyUnicastIpAddressChange(family, AddressChangeCallback, this,
+ FALSE, &m_addressNotificationHandle);
+
+ if (NO_ERROR != status)
+ {
+ CancelMibChangeNotify2(m_routeNotificationHandle);
+ CancelMibChangeNotify2(m_interfaceNotificationHandle);
+ THROW_WINDOWS_ERROR(status, "Register for unicast address change notifications");
+ }
+
try
{
m_bestRoute = GetBestDefaultRoute(m_family);
@@ -62,6 +73,7 @@ DefaultRouteMonitor::~DefaultRouteMonitor()
// Cancel notifications to stop triggering the BurstGuard.
//
+ CancelMibChangeNotify2(m_addressNotificationHandle);
CancelMibChangeNotify2(m_interfaceNotificationHandle);
CancelMibChangeNotify2(m_routeNotificationHandle);
@@ -91,18 +103,62 @@ void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback
return;
}
- reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+ const auto monitor = reinterpret_cast<DefaultRouteMonitor*>(context);
+ monitor->updateRefreshFlag(row->InterfaceLuid, row->InterfaceIndex);
+ monitor->m_evaluateRoutesGuard->trigger();
}
//static
void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback
(
void *context,
- MIB_IPINTERFACE_ROW *,
+ MIB_IPINTERFACE_ROW *row,
+ MIB_NOTIFICATION_TYPE
+)
+{
+ const auto monitor = reinterpret_cast<DefaultRouteMonitor*>(context);
+ monitor->updateRefreshFlag(row->InterfaceLuid, row->InterfaceIndex);
+ monitor->m_evaluateRoutesGuard->trigger();
+}
+
+//static
+void NETIOAPI_API_ DefaultRouteMonitor::AddressChangeCallback
+(
+ void *context,
+ MIB_UNICASTIPADDRESS_ROW *row,
MIB_NOTIFICATION_TYPE
)
{
- reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+ const auto monitor = reinterpret_cast<DefaultRouteMonitor*>(context);
+ monitor->updateRefreshFlag(row->InterfaceLuid, row->InterfaceIndex);
+ monitor->m_evaluateRoutesGuard->trigger();
+}
+
+void DefaultRouteMonitor::updateRefreshFlag(const NET_LUID &luid, const NET_IFINDEX &index)
+{
+ std::scoped_lock<std::mutex> lock(m_evaluationLock);
+
+ if (!m_bestRoute.has_value())
+ {
+ return;
+ }
+
+ if (luid.Value == m_bestRoute->iface.Value)
+ {
+ m_refreshCurrentRoute = true;
+ return;
+ }
+
+ if (luid.Value != 0)
+ {
+ return;
+ }
+
+ NET_IFINDEX defaultInterfaceIndex = 0;
+ const auto routeLuid = &m_bestRoute->iface;
+ ConvertInterfaceLuidToIndex(routeLuid, &defaultInterfaceIndex);
+ m_refreshCurrentRoute = index == defaultInterfaceIndex ||
+ (defaultInterfaceIndex == NET_IFINDEX_UNSPECIFIED);
}
void DefaultRouteMonitor::evaluateRoutes()
@@ -128,6 +184,9 @@ void DefaultRouteMonitor::evaluateRoutesInner()
{
std::optional<InterfaceAndGateway> currentBestRoute;
+ bool refreshCurrent = m_refreshCurrentRoute;
+ m_refreshCurrentRoute = false;
+
try
{
currentBestRoute = GetBestDefaultRoute(m_family);
@@ -172,6 +231,17 @@ void DefaultRouteMonitor::evaluateRoutesInner()
{
m_bestRoute = currentBestRoute;
m_callback(EventType::Updated, m_bestRoute);
+
+ return;
+ }
+
+ //
+ // Interface details may have changed.
+ //
+
+ if (refreshCurrent)
+ {
+ m_callback(EventType::UpdatedDetails, m_bestRoute);
}
}
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
index 5575685a82..ce2a3ce3f6 100644
--- a/windows/winnet/src/winnet/routing/defaultroutemonitor.h
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
@@ -22,6 +22,10 @@ public:
// The best default route changed.
Updated,
+ // Interface details changed; the associated interface and
+ // gateway did not.
+ UpdatedDetails,
+
// No default routes exist.
Removed,
};
@@ -53,14 +57,19 @@ private:
std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard;
std::optional<InterfaceAndGateway> m_bestRoute;
+ bool m_refreshCurrentRoute;
HANDLE m_routeNotificationHandle;
HANDLE m_interfaceNotificationHandle;
+ HANDLE m_addressNotificationHandle;
std::mutex m_evaluationLock;
static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType);
static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType);
+ static void NETIOAPI_API_ AddressChangeCallback(void *context, MIB_UNICASTIPADDRESS_ROW *row, MIB_NOTIFICATION_TYPE notificationType);
+
+ void updateRefreshFlag(const NET_LUID &luid, const NET_IFINDEX &index);
void evaluateRoutes();
void evaluateRoutesInner();
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index c1b864612e..2e35830c96 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -395,6 +395,7 @@ WinNet_RegisterDefaultRouteChangedCallback(
static const std::pair<from_t, to_t> eventTypeMap[] =
{
{ from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::UpdatedDetails, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED_DETAILS },
{ from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
};
@@ -418,11 +419,16 @@ WinNet_RegisterDefaultRouteChangedCallback(
// Determine which LUID and gateway to forward.
//
- if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ switch (eventType)
{
- const auto ips = winnet::ConvertNativeAddresses(&route.value().gateway, 1);
- defaultRoute.gateway = ips[0];
- defaultRoute.interfaceLuid = route.value().iface.Value;
+ case RouteManager::DefaultRouteChangedEventType::Updated:
+ case RouteManager::DefaultRouteChangedEventType::UpdatedDetails:
+ {
+ const auto ips = winnet::ConvertNativeAddresses(&route.value().gateway, 1);
+ defaultRoute.gateway = ips[0];
+ defaultRoute.interfaceLuid = route.value().iface.Value;
+ break;
+ }
}
//
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 3f2f80a5e5..f6f8d56074 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -168,8 +168,12 @@ enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
// Best default route changed.
WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0,
+ // The route (gateway or interface) did not change, but
+ // interface details may have changed.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED_DETAILS = 1,
+
// No default routes exist.
- WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1,
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 2,
};
typedef void (WINNET_API *WinNetDefaultRouteChangedCallback)