diff options
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 65 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 13 | ||||
| -rw-r--r-- | talpid-core/src/windows/mod.rs | 50 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 51 | ||||
| -rw-r--r-- | windows/libshared/src/libshared/network/interfaceutils.cpp | 21 | ||||
| -rw-r--r-- | windows/libshared/src/libshared/network/interfaceutils.h | 2 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/converters.cpp | 13 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/converters.h | 1 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 118 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.def | 2 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.h | 25 |
11 files changed, 102 insertions, 259 deletions
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index d177b4f198..24679aa934 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -6,10 +6,12 @@ mod windows; use crate::{ tunnel::TunnelMetadata, tunnel_state_machine::TunnelCommand, - windows::window::{PowerManagementEvent, PowerManagementListener}, - winnet::{ - self, get_best_default_route, interface_luid_to_ip, WinNetAddrFamily, WinNetCallbackHandle, + windows::{ + get_ip_address_for_interface, + window::{PowerManagementEvent, PowerManagementListener}, + AddressFamily, }, + winnet::{self, get_best_default_route, WinNetAddrFamily, WinNetCallbackHandle}, }; use futures::channel::{mpsc, oneshot}; use std::{ @@ -27,7 +29,7 @@ use std::{ time::Duration, }; use talpid_types::{tunnel::ErrorStateCause, ErrorExt}; -use winapi::shared::winerror::ERROR_OPERATION_ABORTED; +use winapi::shared::{ifdef::NET_LUID, winerror::ERROR_OPERATION_ABORTED}; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); @@ -66,7 +68,7 @@ pub enum Error { /// Failed to obtain an IP address given a network interface LUID #[error(display = "Failed to obtain IP address for interface LUID")] - LuidToIp(#[error(source)] winnet::Error), + LuidToIp(#[error(source)] crate::windows::Error), /// Failed to set up callback for monitoring default route changes #[error(display = "Failed to register default route change callback")] @@ -761,11 +763,19 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4) .map_err(Error::ObtainDefaultRoute)? .map(|route| { - interface_luid_to_ip(WinNetAddrFamily::IPV4, route.interface_luid).map(|ip| { - ip.or_else(|| { + get_ip_address_for_interface( + AddressFamily::Ipv4, + NET_LUID { + Value: route.interface_luid, + }, + ) + .map(|ip| match ip { + Some(IpAddr::V4(addr)) => Some(addr), + Some(_) => unreachable!("wrong address family (expected IPv4)"), + None => { log::warn!("No IPv4 address was found for the default route interface"); None - }) + } }) }) .transpose() @@ -774,23 +784,28 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6) .map_err(Error::ObtainDefaultRoute)? .map(|route| { - interface_luid_to_ip(WinNetAddrFamily::IPV6, route.interface_luid).map(|ip| { - ip.or_else(|| { + get_ip_address_for_interface( + AddressFamily::Ipv6, + NET_LUID { + Value: route.interface_luid, + }, + ) + .map(|ip| match ip { + Some(IpAddr::V6(addr)) => Some(addr), + Some(_) => unreachable!("wrong address family (expected IPv6)"), + None => { log::warn!("No IPv6 address was found for the default route interface"); None - }) + } }) }) .transpose() .map_err(Error::LuidToIp)? .flatten(); - self.addresses.internet_ipv4 = internet_ipv4 - .map(|addr| Ipv4Addr::try_from(addr).map_err(|_| Error::IpParseError)) - .transpose()?; - self.addresses.internet_ipv6 = internet_ipv6 - .map(|addr| Ipv6Addr::try_from(addr).map_err(|_| Error::IpParseError)) - .transpose()?; + self.addresses.internet_ipv4 = internet_ipv4; + self.addresses.internet_ipv6 = internet_ipv6; + Ok(()) } } @@ -814,9 +829,16 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( } }; + let translated_family = winnet_to_talpid_family(address_family); + let result = match event_type { DefaultRouteChanged | DefaultRouteUpdatedDetails => { - match interface_luid_to_ip(address_family, default_route.interface_luid) { + match get_ip_address_for_interface( + translated_family, + NET_LUID { + Value: default_route.interface_luid, + }, + ) { Ok(Some(ip)) => match IpAddr::from(ip) { IpAddr::V4(addr) => ctx.addresses.internet_ipv4 = Some(addr), IpAddr::V6(addr) => ctx.addresses.internet_ipv6 = Some(addr), @@ -868,3 +890,10 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); } } + +fn winnet_to_talpid_family(address_family: WinNetAddrFamily) -> AddressFamily { + match address_family { + WinNetAddrFamily::IPV4 => AddressFamily::Ipv4, + WinNetAddrFamily::IPV6 => AddressFamily::Ipv6, + } +} diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 5d5317a2eb..5e8c0ede49 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -93,7 +93,7 @@ pub enum Error { /// Failed to set IP addresses on WireGuard interface #[cfg(target_os = "windows")] #[error(display = "Failed to set IP addresses on WireGuard interface")] - SetIpAddressesError, + SetIpAddressesError(#[error(source)] crate::windows::Error), } /// Spawns and monitors a wireguard tunnel @@ -419,8 +419,15 @@ impl WireguardMonitor { ); CloseMsg::SetupError(Error::IpInterfacesError) })?; - if !crate::winnet::add_device_ip_addresses(iface_name, addresses) { - return Err(CloseMsg::SetupError(Error::SetIpAddressesError)); + + // TODO: The LUID can be obtained directly. + let luid = crate::windows::luid_from_alias(iface_name).map_err(|error| { + log::error!("Failed to obtain tunnel interface LUID: {}", error); + CloseMsg::SetupError(Error::IpInterfacesError) + })?; + for address in addresses { + crate::windows::add_ip_address_for_interface(luid, *address) + .map_err(|error| CloseMsg::SetupError(Error::SetIpAddressesError(error)))?; } Ok(()) } diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs index ae15c9cb9c..6b97ee389e 100644 --- a/talpid-core/src/windows/mod.rs +++ b/talpid-core/src/windows/mod.rs @@ -3,7 +3,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, io, mem::{self, MaybeUninit}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::windows::{ ffi::{OsStrExt, OsStringExt}, io::RawHandle, @@ -22,10 +22,11 @@ use winapi::{ inaddr::IN_ADDR, netioapi::{ CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, - ConvertInterfaceLuidToGuid, FreeMibTable, GetIpInterfaceEntry, - GetUnicastIpAddressEntry, GetUnicastIpAddressTable, MibAddInstance, - NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, - MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE, + ConvertInterfaceLuidToGuid, CreateUnicastIpAddressEntry, FreeMibTable, + GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, + InitializeUnicastIpAddressEntry, MibAddInstance, NotifyIpInterfaceChange, + SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, MIB_UNICASTIPADDRESS_ROW, + MIB_UNICASTIPADDRESS_TABLE, }, nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE}, ntddndis::NDIS_IF_MAX_STRING_SIZE, @@ -72,6 +73,11 @@ pub enum Error { #[error(display = "Found no addresses for the given adapter")] NoUnicastAddress, + /// Error returned from `CreateUnicastIpAddressEntry` + #[cfg(windows)] + #[error(display = "Failed to create unicast IP address")] + CreateUnicastEntry(#[error(source)] io::Error), + /// Unexpected DAD state returned for a unicast address #[cfg(windows)] #[error(display = "Unexpected DAD state")] @@ -342,6 +348,40 @@ pub async fn wait_for_addresses(luid: NET_LUID) -> Result<()> { rx.await.map_err(|_| Error::UnicastSenderDropped)? } +/// Returns the first unicast IP address for the given interface. +pub fn get_ip_address_for_interface( + family: AddressFamily, + luid: NET_LUID, +) -> Result<Option<IpAddr>> { + match get_unicast_table(Some(family)) + .map_err(Error::ObtainUnicastAddress)? + .into_iter() + .find(|row| row.InterfaceLuid.Value == luid.Value) + { + Some(row) => Ok(Some(try_socketaddr_from_inet_sockaddr(row.Address)?.ip())), + None => Ok(None), + } +} + +/// Adds a unicast IP address for the given interface. +pub fn add_ip_address_for_interface(luid: NET_LUID, address: IpAddr) -> Result<()> { + let mut row = unsafe { mem::zeroed() }; + unsafe { InitializeUnicastIpAddressEntry(&mut row) }; + + row.InterfaceLuid = luid; + row.Address = inet_sockaddr_from_socketaddr(SocketAddr::new(address, 0)); + row.DadState = IpDadStatePreferred; + row.OnLinkPrefixLength = 255; + + let status = unsafe { CreateUnicastIpAddressEntry(&row) }; + if status != NO_ERROR { + return Err(Error::CreateUnicastEntry(io::Error::from_raw_os_error( + status as i32, + ))); + } + Ok(()) +} + /// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are /// returned. pub fn get_unicast_table( diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 50ae88b6da..9843d873aa 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -347,43 +347,10 @@ pub fn get_best_default_route( } } -pub fn interface_luid_to_ip( - family: WinNetAddrFamily, - luid: u64, -) -> Result<Option<WinNetIp>, Error> { - let mut ip = WinNetIp::default(); - match unsafe { - WinNet_InterfaceLuidToIpAddress( - family, - luid, - &mut ip as *mut _, - Some(log_sink), - logging_context(), - ) - } { - WinNetStatus::Success => Ok(Some(ip)), - WinNetStatus::NotFound => Ok(None), - WinNetStatus::Failure => Err(Error::GetIpAddressFromLuid), - } -} - -pub fn add_device_ip_addresses(iface: &str, addresses: &[IpAddr]) -> bool { - let raw_iface = WideCString::from_str(iface) - .expect("Failed to convert UTF-8 string to null terminated UCS string") - .into_raw(); - let converted_addresses: Vec<_> = addresses.iter().map(|addr| WinNetIp::from(*addr)).collect(); - let ptr = converted_addresses.as_ptr(); - let length: u32 = converted_addresses.len() as u32; - unsafe { - WinNet_AddDeviceIpAddresses(raw_iface, ptr, length, Some(log_sink), logging_context()) - } -} - #[allow(non_snake_case)] mod api { use super::DefaultRouteChangedCallback; use crate::logging::windows::LogSink; - use libc::wchar_t; #[allow(dead_code)] #[repr(u32)] @@ -436,15 +403,6 @@ mod api { sink_context: *const u8, ) -> WinNetStatus; - #[link_name = "WinNet_InterfaceLuidToIpAddress"] - pub fn WinNet_InterfaceLuidToIpAddress( - family: super::WinNetAddrFamily, - luid: u64, - ip: *mut super::WinNetIp, - sink: Option<LogSink>, - sink_context: *const u8, - ) -> WinNetStatus; - #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"] pub fn WinNet_RegisterDefaultRouteChangedCallback( callback: Option<DefaultRouteChangedCallback>, @@ -454,14 +412,5 @@ mod api { #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"] pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void); - - #[link_name = "WinNet_AddDeviceIpAddresses"] - pub fn WinNet_AddDeviceIpAddresses( - interface_alias: *const wchar_t, - addresses: *const super::WinNetIp, - num_addresses: u32, - sink: Option<LogSink>, - sink_context: *const u8, - ) -> bool; } } diff --git a/windows/libshared/src/libshared/network/interfaceutils.cpp b/windows/libshared/src/libshared/network/interfaceutils.cpp index e381a2abca..fcfcb47a0e 100644 --- a/windows/libshared/src/libshared/network/interfaceutils.cpp +++ b/windows/libshared/src/libshared/network/interfaceutils.cpp @@ -55,25 +55,4 @@ std::set<InterfaceUtils::NetworkAdapter> InterfaceUtils::GetAllAdapters(ULONG fa return adapters; } -//static -void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses) -{ - for (const auto &address : addresses) - { - MIB_UNICASTIPADDRESS_ROW row; - InitializeUnicastIpAddressEntry(&row); - - row.InterfaceLuid = device; - row.Address = address; - row.DadState = IpDadStatePreferred; - - const auto status = CreateUnicastIpAddressEntry(&row); - - if (NO_ERROR != status) - { - THROW_WINDOWS_ERROR(status, "Assign IP address on network interface"); - } - } -} - } diff --git a/windows/libshared/src/libshared/network/interfaceutils.h b/windows/libshared/src/libshared/network/interfaceutils.h index 0fff359d08..1ab637eaf6 100644 --- a/windows/libshared/src/libshared/network/interfaceutils.h +++ b/windows/libshared/src/libshared/network/interfaceutils.h @@ -62,8 +62,6 @@ public: }; static std::set<NetworkAdapter> GetAllAdapters(ULONG family, ULONG flags); - - static void AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses); }; } diff --git a/windows/winnet/src/winnet/converters.cpp b/windows/winnet/src/winnet/converters.cpp index dc3f333c6c..c09ae6339d 100644 --- a/windows/winnet/src/winnet/converters.cpp +++ b/windows/winnet/src/winnet/converters.cpp @@ -129,19 +129,6 @@ std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes) return out; } -std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses) -{ - std::vector<SOCKADDR_INET> out; - out.reserve(numAddresses); - - for (uint32_t i = 0; i < numAddresses; ++i) - { - out.emplace_back(IpToNative(addresses[i])); - } - - return out; -} - std::vector<WINNET_IP> ConvertNativeAddresses(const SOCKADDR_INET *addresses, uint32_t numAddresses) { std::vector<WINNET_IP> out; diff --git a/windows/winnet/src/winnet/converters.h b/windows/winnet/src/winnet/converters.h index e094728902..a608958cee 100644 --- a/windows/winnet/src/winnet/converters.h +++ b/windows/winnet/src/winnet/converters.h @@ -11,7 +11,6 @@ namespace winnet routing::Network ConvertNetwork(const WINNET_IP_NETWORK &in); std::optional<routing::Node> ConvertNode(const WINNET_NODE *in); std::vector<routing::Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes); -std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses); std::vector<WINNET_IP> ConvertNativeAddresses(const SOCKADDR_INET *addresses, uint32_t numAddresses); } diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index 1c0555a10c..aae9f0f44e 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -79,75 +79,6 @@ WinNet_GetBestDefaultRoute( extern "C"
WINNET_LINKAGE
-WINNET_STATUS
-WINNET_API
-WinNet_InterfaceLuidToIpAddress(
- WINNET_ADDR_FAMILY family,
- uint64_t interfaceLuid,
- WINNET_IP *ip,
- MullvadLogSink logSink,
- void *logSinkContext
-)
-{
- try
- {
- if (nullptr == ip)
- {
- THROW_ERROR("Invalid argument: ip");
- }
-
- static const std::pair<WINNET_ADDR_FAMILY, ADDRESS_FAMILY> familyMap[] =
- {
- { WINNET_ADDR_FAMILY_IPV4, static_cast<ADDRESS_FAMILY>(AF_INET) },
- { WINNET_ADDR_FAMILY_IPV6, static_cast<ADDRESS_FAMILY>(AF_INET6) }
- };
- const auto win_family = common::ValueMapper::Map<>(family, familyMap);
-
- MIB_UNICASTIPADDRESS_TABLE *table = nullptr;
- const auto status = GetUnicastIpAddressTable(win_family, &table);
-
- if (NO_ERROR != status)
- {
- THROW_WINDOWS_ERROR(status, "GetUnicastIpAddressTable");
- }
-
- common::memory::ScopeDestructor destructor;
-
- destructor += [table]() {
- FreeMibTable(table);
- };
-
- for (ULONG i = 0; i < table->NumEntries; i++)
- {
- const auto entry = table->Table[i];
-
- if (interfaceLuid != entry.InterfaceLuid.Value)
- {
- continue;
- }
-
- // Found IP address
- const auto ips = winnet::ConvertNativeAddresses(&entry.Address, 1);
- *ip = ips[0];
-
- return WINNET_STATUS_SUCCESS;
- }
-
- return WINNET_STATUS_NOT_FOUND;
- }
- catch (const std::exception & err)
- {
- shared::logging::UnwindAndLog(logSink, logSinkContext, err);
- return WINNET_STATUS_FAILURE;
- }
- catch (...)
- {
- return WINNET_STATUS_FAILURE;
- }
-}
-
-extern "C"
-WINNET_LINKAGE
bool
WINNET_API
WinNet_ActivateRouteManager(
@@ -469,52 +400,3 @@ WinNet_DeactivateRouteManager( {
}
}
-
-extern "C"
-WINNET_LINKAGE
-bool
-WINNET_API
-WinNet_AddDeviceIpAddresses(
- const wchar_t *deviceAlias,
- const WINNET_IP *addresses,
- uint32_t numAddresses,
- MullvadLogSink logSink,
- void *logSinkContext
-)
-{
- try
- {
- if (nullptr == deviceAlias)
- {
- THROW_ERROR("Invalid argument: deviceAlias")
- }
-
- if (nullptr == addresses)
- {
- THROW_ERROR("Invalid argument: addresses")
- }
-
- NET_LUID luid;
-
- if (0 != ConvertInterfaceAliasToLuid(deviceAlias, &luid))
- {
- const auto msg = std::string("Unable to derive interface LUID from interface alias: ")
- .append(common::string::ToAnsi(deviceAlias));
-
- THROW_ERROR(msg.c_str());
- }
-
- InterfaceUtils::AddDeviceIpAddresses(luid, winnet::ConvertAddresses(addresses, numAddresses));
-
- return true;
- }
- catch (const std::exception &err)
- {
- shared::logging::UnwindAndLog(logSink, logSinkContext, err);
- return false;
- }
- catch (...)
- {
- return false;
- }
-}
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def index a5f9a63863..0bc759d8bc 100644 --- a/windows/winnet/src/winnet/winnet.def +++ b/windows/winnet/src/winnet/winnet.def @@ -2,6 +2,4 @@ LIBRARY winnet EXPORTS WinNet_ActivateRouteManager WinNet_DeactivateRouteManager - WinNet_AddDeviceIpAddresses WinNet_GetBestDefaultRoute - WinNet_InterfaceLuidToIpAddress diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 0c72c91eb8..38d1386a58 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -134,18 +134,6 @@ WinNet_GetBestDefaultRoute( void *logSinkContext ); -extern "C" -WINNET_LINKAGE -WINNET_STATUS -WINNET_API -WinNet_InterfaceLuidToIpAddress( - WINNET_ADDR_FAMILY family, - uint64_t interfaceLuid, - WINNET_IP *ip, - MullvadLogSink logSink, - void *logSinkContext -); - enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE { // Best default route changed. @@ -196,16 +184,3 @@ void WINNET_API WinNet_DeactivateRouteManager( ); - -extern "C" -WINNET_LINKAGE -bool -WINNET_API -WinNet_AddDeviceIpAddresses( - const wchar_t *deviceAlias, - const WINNET_IP *addresses, - uint32_t numAddresses, - MullvadLogSink logSink, - void *logSinkContext -); - |
