summaryrefslogtreecommitdiffhomepage
path: root/talpid-core/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-09-11 19:20:25 +0200
committerDavid Lönnhager <david.l@mullvad.net>2020-11-09 14:54:58 +0100
commitc3921b60fd92e114368e1df07b977ec7178a5e65 (patch)
tree286ed72b083e4993110686a5219aaff1cfc5be8b /talpid-core/src
parentfc26f6223269889cf1d23579ac4ab07a2d603d4e (diff)
downloadmullvadvpn-c3921b60fd92e114368e1df07b977ec7178a5e65.tar.xz
mullvadvpn-c3921b60fd92e114368e1df07b977ec7178a5e65.zip
Add winnet function for obtaining the best default route
Diffstat (limited to 'talpid-core/src')
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs4
-rw-r--r--talpid-core/src/winnet.rs80
2 files changed, 77 insertions, 7 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index b4d437f046..e1ded8b3a0 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -165,7 +165,7 @@ impl WgGoTunnel {
pub unsafe extern "system" fn default_route_changed_callback(
event_type: winnet::WinNetDefaultRouteChangeEventType,
address_family: winnet::WinNetAddrFamily,
- interface_luid: u64,
+ default_route: winnet::WinNetDefaultRoute,
_ctx: *mut libc::c_void,
) {
use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex};
@@ -173,7 +173,7 @@ impl WgGoTunnel {
winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
let mut iface_idx = 0u32;
let iface_luid = NET_LUID {
- Value: interface_luid,
+ Value: default_route.interface_luid,
};
let status =
ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _);
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index ab9dff5d06..a648aad982 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -5,7 +5,7 @@ use ipnetwork::IpNetwork;
use libc::{c_void, wchar_t};
use std::{
ffi::{OsStr, OsString},
- net::IpAddr,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
ptr,
};
use widestring::WideCString;
@@ -29,6 +29,10 @@ pub enum Error {
#[error(display = "Failed to obtain GUID for the network interface")]
GetInterfaceGuid,
+ /// Failed to get the current default route.
+ #[error(display = "Failed to obtain default route")]
+ GetDefaultRoute,
+
/// Failed to read IPv6 status on the TAP network interface.
#[error(display = "Failed to read IPv6 status on the TAP network interface")]
GetIpv6Status,
@@ -142,6 +146,12 @@ pub enum WinNetAddrFamily {
IPV6 = 1,
}
+impl Default for WinNetAddrFamily {
+ fn default() -> Self {
+ WinNetAddrFamily::IPV4
+ }
+}
+
impl WinNetAddrFamily {
pub fn to_windows_proto_enum(&self) -> u16 {
match self {
@@ -152,9 +162,30 @@ impl WinNetAddrFamily {
}
#[repr(C)]
+#[derive(Default)]
pub struct WinNetIp {
- addr_family: WinNetAddrFamily,
- ip_bytes: [u8; 16],
+ pub addr_family: WinNetAddrFamily,
+ pub ip_bytes: [u8; 16],
+}
+
+#[repr(C)]
+#[derive(Default)]
+pub struct WinNetDefaultRoute {
+ pub interface_luid: u64,
+ pub gateway: WinNetIp,
+}
+
+impl From<WinNetIp> for IpAddr {
+ fn from(addr: WinNetIp) -> IpAddr {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => {
+ let mut bytes: [u8; 4] = Default::default();
+ bytes.clone_from_slice(&addr.ip_bytes[..4]);
+ IpAddr::V4(Ipv4Addr::from(bytes))
+ }
+ WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)),
+ }
+ }
}
impl From<IpAddr> for WinNetIp {
@@ -316,8 +347,8 @@ pub enum WinNetDefaultRouteChangeEventType {
pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
event_type: WinNetDefaultRouteChangeEventType,
- addr_family: WinNetAddrFamily,
- interface_luid: u64,
+ family: WinNetAddrFamily,
+ default_route: WinNetDefaultRoute,
ctx: *mut c_void,
);
@@ -359,6 +390,26 @@ pub fn deactivate_routing_manager() {
unsafe { WinNet_DeactivateRouteManager() }
}
+// TODO: Remove attribute once this is in use.
+#[allow(dead_code)]
+pub fn get_best_default_route(
+ family: WinNetAddrFamily,
+) -> Result<Option<WinNetDefaultRoute>, Error> {
+ let mut default_route = WinNetDefaultRoute::default();
+ match unsafe {
+ WinNet_GetBestDefaultRoute(
+ family,
+ &mut default_route as *mut _,
+ Some(log_sink),
+ logging_context(),
+ )
+ } {
+ GetBestDefaultRouteStatus::Success => Ok(Some(default_route)),
+ GetBestDefaultRouteStatus::NotFound => Ok(None),
+ GetBestDefaultRouteStatus::Failure => Err(Error::GetDefaultRoute),
+ }
+}
+
pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
let raw_iface = WideCString::from_str(iface)
.expect("Failed to convert UTF-8 string to null terminated UCS string")
@@ -379,6 +430,15 @@ mod api {
pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void);
+ #[allow(dead_code)]
+ #[repr(u32)]
+ pub enum FailableOptionalStatus {
+ Success = 0,
+ NotFound = 1,
+ Failure = 2,
+ }
+ pub type GetBestDefaultRouteStatus = FailableOptionalStatus;
+
extern "system" {
#[link_name = "WinNet_ActivateRouteManager"]
pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *const u8) -> bool;
@@ -415,6 +475,16 @@ mod api {
sink_context: *const u8,
) -> bool;
+ // TODO: Remove "allow(dead_code)" this is in use.
+ #[allow(dead_code)]
+ #[link_name = "WinNet_GetBestDefaultRoute"]
+ pub fn WinNet_GetBestDefaultRoute(
+ family: super::WinNetAddrFamily,
+ default_route: *mut super::WinNetDefaultRoute,
+ sink: Option<LogSink>,
+ sink_context: *const u8,
+ ) -> GetBestDefaultRouteStatus;
+
#[link_name = "WinNet_GetTapInterfaceAlias"]
pub fn WinNet_GetTapInterfaceAlias(
tunnel_interface_alias: *mut *mut wchar_t,