summaryrefslogtreecommitdiffhomepage
path: root/talpid-core
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-11-09 14:55:55 +0100
committerDavid Lönnhager <david.l@mullvad.net>2020-11-09 14:55:55 +0100
commit55169f326ce43da93649f3bc8d1eb78ec8593bed (patch)
tree7d53715b91e3590b3681d1d31b9febcf5b547d44 /talpid-core
parentfc26f6223269889cf1d23579ac4ab07a2d603d4e (diff)
parente429e5caa1cf783b05fc4db3880b41dea02d509c (diff)
downloadmullvadvpn-55169f326ce43da93649f3bc8d1eb78ec8593bed.tar.xz
mullvadvpn-55169f326ce43da93649f3bc8d1eb78ec8593bed.zip
Merge branch 'update-winnet'
Diffstat (limited to 'talpid-core')
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs4
-rw-r--r--talpid-core/src/winnet.rs116
2 files changed, 113 insertions, 7 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index b4d437f046..e1ded8b3a0 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -165,7 +165,7 @@ impl WgGoTunnel {
pub unsafe extern "system" fn default_route_changed_callback(
event_type: winnet::WinNetDefaultRouteChangeEventType,
address_family: winnet::WinNetAddrFamily,
- interface_luid: u64,
+ default_route: winnet::WinNetDefaultRoute,
_ctx: *mut libc::c_void,
) {
use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex};
@@ -173,7 +173,7 @@ impl WgGoTunnel {
winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
let mut iface_idx = 0u32;
let iface_luid = NET_LUID {
- Value: interface_luid,
+ Value: default_route.interface_luid,
};
let status =
ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _);
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index ab9dff5d06..9fed90ac21 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -5,7 +5,7 @@ use ipnetwork::IpNetwork;
use libc::{c_void, wchar_t};
use std::{
ffi::{OsStr, OsString},
- net::IpAddr,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
ptr,
};
use widestring::WideCString;
@@ -29,6 +29,14 @@ pub enum Error {
#[error(display = "Failed to obtain GUID for the network interface")]
GetInterfaceGuid,
+ /// Failed to get the current default route.
+ #[error(display = "Failed to obtain default route")]
+ GetDefaultRoute,
+
+ /// Failed to obtain an IP address given a LUID.
+ #[error(display = "Failed to obtain IP address for the given interface")]
+ GetIpAddressFromLuid,
+
/// Failed to read IPv6 status on the TAP network interface.
#[error(display = "Failed to read IPv6 status on the TAP network interface")]
GetIpv6Status,
@@ -142,6 +150,12 @@ pub enum WinNetAddrFamily {
IPV6 = 1,
}
+impl Default for WinNetAddrFamily {
+ fn default() -> Self {
+ WinNetAddrFamily::IPV4
+ }
+}
+
impl WinNetAddrFamily {
pub fn to_windows_proto_enum(&self) -> u16 {
match self {
@@ -152,9 +166,30 @@ impl WinNetAddrFamily {
}
#[repr(C)]
+#[derive(Default)]
pub struct WinNetIp {
- addr_family: WinNetAddrFamily,
- ip_bytes: [u8; 16],
+ pub addr_family: WinNetAddrFamily,
+ pub ip_bytes: [u8; 16],
+}
+
+#[repr(C)]
+#[derive(Default)]
+pub struct WinNetDefaultRoute {
+ pub interface_luid: u64,
+ pub gateway: WinNetIp,
+}
+
+impl From<WinNetIp> for IpAddr {
+ fn from(addr: WinNetIp) -> IpAddr {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => {
+ let mut bytes: [u8; 4] = Default::default();
+ bytes.clone_from_slice(&addr.ip_bytes[..4]);
+ IpAddr::V4(Ipv4Addr::from(bytes))
+ }
+ WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)),
+ }
+ }
}
impl From<IpAddr> for WinNetIp {
@@ -316,8 +351,8 @@ pub enum WinNetDefaultRouteChangeEventType {
pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
event_type: WinNetDefaultRouteChangeEventType,
- addr_family: WinNetAddrFamily,
- interface_luid: u64,
+ family: WinNetAddrFamily,
+ default_route: WinNetDefaultRoute,
ctx: *mut c_void,
);
@@ -359,6 +394,48 @@ pub fn deactivate_routing_manager() {
unsafe { WinNet_DeactivateRouteManager() }
}
+// TODO: Remove attribute once this is in use.
+#[allow(dead_code)]
+pub fn get_best_default_route(
+ family: WinNetAddrFamily,
+) -> Result<Option<WinNetDefaultRoute>, Error> {
+ let mut default_route = WinNetDefaultRoute::default();
+ match unsafe {
+ WinNet_GetBestDefaultRoute(
+ family,
+ &mut default_route as *mut _,
+ Some(log_sink),
+ logging_context(),
+ )
+ } {
+ WinNetStatus::Success => Ok(Some(default_route)),
+ WinNetStatus::NotFound => Ok(None),
+ WinNetStatus::Failure => Err(Error::GetDefaultRoute),
+ }
+}
+
+// TODO: Remove attribute once this is in use.
+#[allow(dead_code)]
+pub fn interface_luid_to_ip(
+ family: WinNetAddrFamily,
+ luid: u64,
+) -> Result<Option<WinNetIp>, Error> {
+ let mut ip = WinNetIp::default();
+ match unsafe {
+ WinNet_InterfaceLuidToIpAddress(
+ family,
+ luid,
+ &mut ip as *mut _,
+ Some(log_sink),
+ logging_context(),
+ )
+ } {
+ WinNetStatus::Success => Ok(Some(ip)),
+ WinNetStatus::NotFound => Ok(None),
+ WinNetStatus::Failure => Err(Error::GetIpAddressFromLuid),
+ }
+}
+
pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
let raw_iface = WideCString::from_str(iface)
.expect("Failed to convert UTF-8 string to null terminated UCS string")
@@ -379,6 +456,14 @@ mod api {
pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void);
+ #[allow(dead_code)]
+ #[repr(u32)]
+ pub enum WinNetStatus {
+ Success = 0,
+ NotFound = 1,
+ Failure = 2,
+ }
+
extern "system" {
#[link_name = "WinNet_ActivateRouteManager"]
pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *const u8) -> bool;
@@ -415,6 +500,27 @@ mod api {
sink_context: *const u8,
) -> bool;
+ // TODO: Remove "allow(dead_code)" this is in use.
+ #[allow(dead_code)]
+ #[link_name = "WinNet_GetBestDefaultRoute"]
+ pub fn WinNet_GetBestDefaultRoute(
+ family: super::WinNetAddrFamily,
+ default_route: *mut super::WinNetDefaultRoute,
+ sink: Option<LogSink>,
+ sink_context: *const u8,
+ ) -> WinNetStatus;
+
+ // TODO: Remove "allow(dead_code)" this is in use.
+ #[allow(dead_code)]
+ #[link_name = "WinNet_InterfaceLuidToIpAddress"]
+ pub fn WinNet_InterfaceLuidToIpAddress(
+ family: super::WinNetAddrFamily,
+ luid: u64,
+ ip: *mut super::WinNetIp,
+ sink: Option<LogSink>,
+ sink_context: *const u8,
+ ) -> WinNetStatus;
+
#[link_name = "WinNet_GetTapInterfaceAlias"]
pub fn WinNet_GetTapInterfaceAlias(
tunnel_interface_alias: *mut *mut wchar_t,