summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-11-09 11:02:39 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-11-09 11:02:39 +0100
commit2af40170885057f6fa6a40ffb071d2b998265a54 (patch)
treed9eecb898dbc276764e77013ef69d1c6fb3a31c6
parentc51cbe279921485e0ca19085d9e322c07508ced0 (diff)
parentb71abff37164d23ad8ebed5a8fee6e87d6b71608 (diff)
downloadmullvadvpn-2af40170885057f6fa6a40ffb071d2b998265a54.tar.xz
mullvadvpn-2af40170885057f6fa6a40ffb071d2b998265a54.zip
Merge branch 'simplify-windows-net'
-rw-r--r--talpid-routing/src/windows/get_best_default_route.rs65
-rw-r--r--talpid-routing/src/windows/route_manager.rs43
-rw-r--r--talpid-wireguard/src/wireguard_go.rs5
3 files changed, 49 insertions, 64 deletions
diff --git a/talpid-routing/src/windows/get_best_default_route.rs b/talpid-routing/src/windows/get_best_default_route.rs
index 940fd93644..37baf8c44b 100644
--- a/talpid-routing/src/windows/get_best_default_route.rs
+++ b/talpid-routing/src/windows/get_best_default_route.rs
@@ -1,5 +1,5 @@
use super::{Error, Result};
-use std::{convert::TryInto, io, net::SocketAddr};
+use std::{io, net::SocketAddr, slice};
use talpid_windows_net::{
get_ip_interface_entry, try_socketaddr_from_inet_sockaddr, AddressFamily,
};
@@ -22,7 +22,7 @@ const TUNNEL_INTERFACE_DESCS: [&WideCStr; 3] = [
widecstr!("Tunnel"),
];
-fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> {
+fn get_ip_forward_table(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> {
let family = family.to_af_family();
let mut table_ptr = std::ptr::null_mut();
@@ -36,32 +36,23 @@ fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>>
}
// SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error
- let num_entries = unsafe { *table_ptr }.NumEntries;
- let mut vec = Vec::with_capacity(num_entries.try_into().unwrap_or_default());
+ let num_entries = usize::try_from(unsafe { *table_ptr }.NumEntries).unwrap();
+ assert!(
+ num_entries
+ .checked_mul(std::mem::size_of::<MIB_IPFORWARD_ROW2>())
+ .unwrap()
+ <= usize::try_from(isize::MAX).unwrap()
+ );
+ // SAFETY: num_entries * size_of(MIB_IPFORWARD_ROW2) is at most isize::MAX
+ let rows = unsafe { slice::from_raw_parts((*table_ptr).Table.as_ptr(), num_entries) }.to_vec();
- for i in 0..num_entries {
- assert!(
- usize::try_from(i).unwrap() * std::mem::size_of::<MIB_IPFORWARD_ROW2>()
- < usize::try_from(isize::MAX).unwrap()
- );
-
- // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error nor have we
- // or will we modify the table
- let ptr: *const MIB_IPFORWARD_ROW2 = unsafe { (*table_ptr).Table.as_ptr() };
-
- // SAFETY: The assert guarantees that the amount of bytes we are jumping is not larger than
- // isize::MAX. Win32 guarantees that the resulting pointer is aligned, non-null,
- // init.
- let row: &MIB_IPFORWARD_ROW2 =
- unsafe { ptr.offset(i.try_into().unwrap()).as_ref() }.unwrap();
- vec.push(row.clone());
- }
// SAFETY: FreeMibTable does not have clear safety rules but it deallocates the
// MIB_IPFORWARD_TABLE2 This pointer is ONLY deallocated here so it is guaranteed to not
// have been already deallocated. We have cloned all MIB_IPFORWARD_ROW2s and the rows do not
// contain pointers to the table so they will not be dangling after this free.
unsafe { FreeMibTable(table_ptr as *const _) }
- Ok(vec)
+
+ Ok(rows)
}
/// General type for passing interface and gateway
@@ -81,7 +72,7 @@ impl PartialEq for InterfaceAndGateway {
/// Get the best default route for the given address family or None if none exists.
pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceAndGateway>> {
- let table = get_ipforward_rows(family)?;
+ let table = get_ip_forward_table(family)?;
// Remove all candidates without a gateway and which are not on a physical interface.
// Then get the annotated routes which are active.
@@ -92,29 +83,29 @@ pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceA
&& route_has_gateway(row)
&& is_route_on_physical_interface(row).unwrap_or(false)
})
- .filter_map(|row| annotate_route(row))
+ .filter_map(annotate_route)
.collect();
- if annotated.is_empty() {
- return Ok(None);
- }
-
// We previously filtered out all inactive routes so we only need to sort by acending
// effective_metric
annotated.sort_by(|lhs, rhs| lhs.effective_metric.cmp(&rhs.effective_metric));
- Ok(Some(InterfaceAndGateway {
- iface: annotated[0].route.InterfaceLuid,
- gateway: try_socketaddr_from_inet_sockaddr(annotated[0].route.NextHop)
- .map_err(|_| Error::InvalidSiFamily)?,
- }))
+ annotated
+ .get(0)
+ .map(|annotated| {
+ Ok(InterfaceAndGateway {
+ iface: annotated.route.InterfaceLuid,
+ gateway: try_socketaddr_from_inet_sockaddr(annotated.route.NextHop)
+ .map_err(|_| Error::InvalidSiFamily)?,
+ })
+ })
+ .transpose()
}
pub fn route_has_gateway(route: &MIB_IPFORWARD_ROW2) -> bool {
- match try_socketaddr_from_inet_sockaddr(route.NextHop) {
- Ok(sock) => !sock.ip().is_unspecified(),
- Err(_) => false,
- }
+ try_socketaddr_from_inet_sockaddr(route.NextHop)
+ .map(|addr| !addr.ip().is_unspecified())
+ .unwrap_or(false)
}
// TODO(Jon): It would be more correct to filter for devices that match the known LUID of the tunnel
diff --git a/talpid-routing/src/windows/route_manager.rs b/talpid-routing/src/windows/route_manager.rs
index ab4c9ef45b..691f39f453 100644
--- a/talpid-routing/src/windows/route_manager.rs
+++ b/talpid-routing/src/windows/route_manager.rs
@@ -528,26 +528,23 @@ impl RouteManagerInternal {
// changed. So removing and adding again is the only option.
//
- match Self::delete_from_routing_table(&affected_route.registered_route) {
- Ok(()) => (),
- Err(e) => {
- log::error!(
- "Failed to delete route when refreshing existing routes: {}",
- e
- );
- continue;
- }
+ if let Err(error) = Self::delete_from_routing_table(&affected_route.registered_route) {
+ log::error!(
+ "Failed to delete route when refreshing existing routes: {}",
+ error
+ );
+ continue;
}
affected_route.registered_route.luid = route.iface;
affected_route.registered_route.next_hop = route.gateway;
- match Self::restore_into_routing_table(&affected_route.registered_route) {
- Ok(()) => (),
- Err(e) => {
- log::error!("Failed to add route when refreshing existing routes: {}", e);
- continue;
- }
+ if let Err(error) = Self::restore_into_routing_table(&affected_route.registered_route) {
+ log::error!(
+ "Failed to add route when refreshing existing routes: {}",
+ error
+ );
+ continue;
}
}
}
@@ -601,11 +598,6 @@ fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> {
})
.collect();
- if matches.is_empty() {
- log::error!("Unable to find network adapter with specified gateway");
- return Err(Error::DeviceGatewayNotFound);
- }
-
// Sort matching interfaces ascending by metric.
//
@@ -621,8 +613,13 @@ fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> {
// Select the interface with the best (lowest) metric.
//
-
- Ok(matches[0].Luid)
+ matches
+ .get(0)
+ .map(|interface| interface.Luid)
+ .ok_or_else(|| {
+ log::error!("Unable to find network adapter with specified gateway");
+ Error::DeviceGatewayNotFound
+ })
}
/// SAFETY: adapter.FirstGatewayAddress must be dereferencable and must live as long as adapter
@@ -741,7 +738,7 @@ impl Adapters {
GetAdaptersAddresses(
family,
flags,
- std::ptr::null_mut() as *mut _,
+ std::ptr::null_mut(),
buffer_pointer as *mut IP_ADAPTER_ADDRESSES_LH,
&mut buffer_size,
)
diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go.rs
index 4232d99e70..aa962ace7d 100644
--- a/talpid-wireguard/src/wireguard_go.rs
+++ b/talpid-wireguard/src/wireguard_go.rs
@@ -217,11 +217,8 @@ impl WgGoTunnel {
let iface_idx: u32 = match event_type {
Updated(default_route) => {
let mut iface_idx = 0u32;
- // TODO: Make sure unwrap is fine
let iface_luid = default_route.iface;
- let status = unsafe {
- ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _)
- };
+ let status = unsafe { ConvertInterfaceLuidToIndex(&iface_luid, &mut iface_idx) };
if status != 0 {
log::error!(
"Failed to convert interface LUID to interface index: {}: {}",