summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorLinus Färnstrand <linus@mullvad.net>2023-08-03 09:18:16 +0200
committerLinus Färnstrand <linus@mullvad.net>2023-08-08 08:39:59 +0200
commite350149d6dd85834cfa7b3dee4a959854cf5302b (patch)
tree94fd04c9272a84cd5cb39b10005e2525807dd47f
parent7f948ba86b041c57ee4a9a8f11ec21eabff66c31 (diff)
downloadmullvadvpn-e350149d6dd85834cfa7b3dee4a959854cf5302b.tar.xz
mullvadvpn-e350149d6dd85834cfa7b3dee4a959854cf5302b.zip
Adapt talpid-windows-net to windows-sys 0.48
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-types/src/error.rs70
-rw-r--r--talpid-types/src/lib.rs55
-rw-r--r--talpid-windows-net/Cargo.toml2
-rw-r--r--talpid-windows-net/src/net.rs76
5 files changed, 95 insertions, 109 deletions
diff --git a/Cargo.lock b/Cargo.lock
index f4e3d1a322..eb7309d0db 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3774,6 +3774,7 @@ dependencies = [
"futures",
"libc",
"socket2 0.4.9",
+ "talpid-types",
"winapi",
"windows-sys 0.48.0",
]
diff --git a/talpid-types/src/error.rs b/talpid-types/src/error.rs
new file mode 100644
index 0000000000..77f0664375
--- /dev/null
+++ b/talpid-types/src/error.rs
@@ -0,0 +1,70 @@
+use std::{error::Error, fmt, fmt::Write};
+
+/// Used to generate string representations of error chains.
+pub trait ErrorExt {
+ /// Creates a string representation of the entire error chain.
+ fn display_chain(&self) -> String;
+
+ /// Like [Self::display_chain] but with an extra message at the start of the chain
+ fn display_chain_with_msg(&self, msg: &str) -> String;
+}
+
+impl<E: Error> ErrorExt for E {
+ fn display_chain(&self) -> String {
+ let mut s = format!("Error: {self}");
+ let mut source = self.source();
+ while let Some(error) = source {
+ write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
+ source = error.source();
+ }
+ s
+ }
+
+ fn display_chain_with_msg(&self, msg: &str) -> String {
+ let mut s = format!("Error: {msg}\nCaused by: {self}");
+ let mut source = self.source();
+ while let Some(error) = source {
+ write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
+ source = error.source();
+ }
+ s
+ }
+}
+
+#[derive(Debug)]
+pub struct BoxedError(Box<dyn Error + 'static + Send>);
+
+impl fmt::Display for BoxedError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl Error for BoxedError {
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ self.0.source()
+ }
+}
+
+impl BoxedError {
+ pub fn new(error: impl Error + 'static + Send) -> Self {
+ BoxedError(Box::new(error))
+ }
+}
+
+/// Helper macro allowing simpler handling of Windows FFI returning `WIN32_ERROR`
+/// status codes. Converts a `WIN32_ERROR` into an `io::Result<()>`.
+///
+/// The caller of this macro must have `windows_sys` as a dependency.
+#[cfg(windows)]
+#[macro_export]
+macro_rules! win32_err {
+ ($expr:expr) => {{
+ let status = $expr;
+ if status == ::windows_sys::Win32::Foundation::NO_ERROR {
+ Ok(())
+ } else {
+ Err(::std::io::Error::from_raw_os_error(status as i32))
+ }
+ }};
+}
diff --git a/talpid-types/src/lib.rs b/talpid-types/src/lib.rs
index 92c0a2ce76..ee2c6d172a 100644
--- a/talpid-types/src/lib.rs
+++ b/talpid-types/src/lib.rs
@@ -1,7 +1,5 @@
#![deny(rust_2018_idioms)]
-use std::{error::Error, fmt, fmt::Write};
-
#[cfg(target_os = "android")]
pub mod android;
pub mod net;
@@ -13,54 +11,5 @@ pub mod cgroup;
#[cfg(target_os = "windows")]
pub mod split_tunnel;
-/// Used to generate string representations of error chains.
-pub trait ErrorExt {
- /// Creates a string representation of the entire error chain.
- fn display_chain(&self) -> String;
-
- /// Like [Self::display_chain] but with an extra message at the start of the chain
- fn display_chain_with_msg(&self, msg: &str) -> String;
-}
-
-impl<E: Error> ErrorExt for E {
- fn display_chain(&self) -> String {
- let mut s = format!("Error: {self}");
- let mut source = self.source();
- while let Some(error) = source {
- write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
- source = error.source();
- }
- s
- }
-
- fn display_chain_with_msg(&self, msg: &str) -> String {
- let mut s = format!("Error: {msg}\nCaused by: {self}");
- let mut source = self.source();
- while let Some(error) = source {
- write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
- source = error.source();
- }
- s
- }
-}
-
-#[derive(Debug)]
-pub struct BoxedError(Box<dyn Error + 'static + Send>);
-
-impl fmt::Display for BoxedError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.0.fmt(f)
- }
-}
-
-impl Error for BoxedError {
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- self.0.source()
- }
-}
-
-impl BoxedError {
- pub fn new(error: impl Error + 'static + Send) -> Self {
- BoxedError(Box::new(error))
- }
-}
+mod error;
+pub use error::*;
diff --git a/talpid-windows-net/Cargo.toml b/talpid-windows-net/Cargo.toml
index 40caf827f4..351bc539ea 100644
--- a/talpid-windows-net/Cargo.toml
+++ b/talpid-windows-net/Cargo.toml
@@ -15,6 +15,8 @@ socket2 = { version = "0.4.2", features = ["all"] }
futures = "0.3.15"
winapi = { version = "0.3.6", features = ["ws2def"] }
+talpid-types = { path = "../talpid-types" }
+
[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
features = [
diff --git a/talpid-windows-net/src/net.rs b/talpid-windows-net/src/net.rs
index cd7ff14358..3c0d4d6880 100644
--- a/talpid-windows-net/src/net.rs
+++ b/talpid-windows-net/src/net.rs
@@ -9,11 +9,12 @@ use std::{
sync::Mutex,
time::{Duration, Instant},
};
+use talpid_types::win32_err;
use winapi::shared::ws2def::SOCKADDR_STORAGE as sockaddr_storage;
use windows_sys::{
core::GUID,
Win32::{
- Foundation::{ERROR_NOT_FOUND, HANDLE, NO_ERROR},
+ Foundation::{ERROR_NOT_FOUND, HANDLE},
NetworkManagement::{
IpHelper::{
CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias,
@@ -174,7 +175,7 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send
handle: 0,
});
- let status = unsafe {
+ win32_err!(unsafe {
NotifyIpInterfaceChange(
af_family_from_family(family),
Some(inner_callback),
@@ -182,13 +183,8 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send
0,
(&mut context.handle) as *mut _,
)
- };
-
- if status == NO_ERROR as i32 {
- Ok(context)
- } else {
- Err(io::Error::from_raw_os_error(status))
- }
+ })?;
+ Ok(context)
}
/// Returns information about a network IP interface.
@@ -200,22 +196,13 @@ pub fn get_ip_interface_entry(
row.Family = family as u16;
row.InterfaceLuid = *luid;
- let result = unsafe { GetIpInterfaceEntry(&mut row) };
- if result == NO_ERROR as i32 {
- Ok(row)
- } else {
- Err(io::Error::from_raw_os_error(result))
- }
+ win32_err!(unsafe { GetIpInterfaceEntry(&mut row) })?;
+ Ok(row)
}
/// Set the properties of an IP interface.
pub fn set_ip_interface_entry(row: &mut MIB_IPINTERFACE_ROW) -> io::Result<()> {
- let result = unsafe { SetIpInterfaceEntry(row as *mut _) };
- if result == NO_ERROR as i32 {
- Ok(())
- } else {
- Err(io::Error::from_raw_os_error(result))
- }
+ win32_err!(unsafe { SetIpInterfaceEntry(row as *mut _) })
}
fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID_LH) -> io::Result<bool> {
@@ -293,12 +280,8 @@ pub async fn wait_for_addresses(luid: NET_LUID_LH) -> Result<()> {
let mut ready = true;
for row in &mut unicast_rows {
- let status = unsafe { GetUnicastIpAddressEntry(row) };
- if status != NO_ERROR as i32 {
- return Err(Error::ObtainUnicastAddress(io::Error::from_raw_os_error(
- status,
- )));
- }
+ win32_err!(unsafe { GetUnicastIpAddressEntry(row) })
+ .map_err(Error::ObtainUnicastAddress)?;
if row.DadState == IpDadStateTentative {
ready = false;
break;
@@ -347,13 +330,7 @@ pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Resul
row.DadState = IpDadStatePreferred;
row.OnLinkPrefixLength = 255;
- let status = unsafe { CreateUnicastIpAddressEntry(&row) };
- if status != NO_ERROR as i32 {
- return Err(Error::CreateUnicastEntry(io::Error::from_raw_os_error(
- status,
- )));
- }
- Ok(())
+ win32_err!(unsafe { CreateUnicastIpAddressEntry(&row) }).map_err(Error::CreateUnicastEntry)
}
/// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are
@@ -364,11 +341,9 @@ pub fn get_unicast_table(
let mut unicast_rows = vec![];
let mut unicast_table: *mut MIB_UNICASTIPADDRESS_TABLE = std::ptr::null_mut();
- let status =
- unsafe { GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table) };
- if status != NO_ERROR as i32 {
- return Err(io::Error::from_raw_os_error(status));
- }
+ win32_err!(unsafe {
+ GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table)
+ })?;
let first_row = unsafe { &(*unicast_table).Table[0] } as *const MIB_UNICASTIPADDRESS_ROW;
for i in 0..unsafe { *unicast_table }.NumEntries {
unicast_rows.push(unsafe { *(first_row.offset(i as isize)) });
@@ -381,20 +356,14 @@ pub fn get_unicast_table(
/// Returns the index of a network interface given its LUID.
pub fn index_from_luid(luid: &NET_LUID_LH) -> io::Result<u32> {
let mut index = 0u32;
- let status = unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) };
- if status != NO_ERROR as i32 {
- return Err(io::Error::from_raw_os_error(status));
- }
+ win32_err!(unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) })?;
Ok(index)
}
/// Returns the GUID of a network interface given its LUID.
pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> {
let mut guid = MaybeUninit::zeroed();
- let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) };
- if status != NO_ERROR as i32 {
- return Err(io::Error::from_raw_os_error(status));
- }
+ win32_err!(unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) })?;
Ok(unsafe { guid.assume_init() })
}
@@ -406,21 +375,16 @@ pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID_LH> {
.chain(std::iter::once(0u16))
.collect();
let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() };
- let status = unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) };
- if status != NO_ERROR as i32 {
- return Err(io::Error::from_raw_os_error(status));
- }
+ win32_err!(unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) })?;
Ok(luid)
}
/// Returns the alias of an interface given its LUID.
pub fn alias_from_luid(luid: &NET_LUID_LH) -> io::Result<OsString> {
let mut buffer = [0u16; IF_MAX_STRING_SIZE as usize + 1];
- let status =
- unsafe { ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len()) };
- if status != NO_ERROR as i32 {
- return Err(io::Error::from_raw_os_error(status));
- }
+ win32_err!(unsafe {
+ ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len())
+ })?;
let nul = buffer.iter().position(|&c| c == 0u16).unwrap();
Ok(OsString::from_wide(&buffer[0..nul]))
}