diff options
| author | Linus Färnstrand <linus@mullvad.net> | 2023-08-03 09:18:16 +0200 |
|---|---|---|
| committer | Linus Färnstrand <linus@mullvad.net> | 2023-08-08 08:39:59 +0200 |
| commit | e350149d6dd85834cfa7b3dee4a959854cf5302b (patch) | |
| tree | 94fd04c9272a84cd5cb39b10005e2525807dd47f | |
| parent | 7f948ba86b041c57ee4a9a8f11ec21eabff66c31 (diff) | |
| download | mullvadvpn-e350149d6dd85834cfa7b3dee4a959854cf5302b.tar.xz mullvadvpn-e350149d6dd85834cfa7b3dee4a959854cf5302b.zip | |
Adapt talpid-windows-net to windows-sys 0.48
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-types/src/error.rs | 70 | ||||
| -rw-r--r-- | talpid-types/src/lib.rs | 55 | ||||
| -rw-r--r-- | talpid-windows-net/Cargo.toml | 2 | ||||
| -rw-r--r-- | talpid-windows-net/src/net.rs | 76 |
5 files changed, 95 insertions, 109 deletions
diff --git a/Cargo.lock b/Cargo.lock index f4e3d1a322..eb7309d0db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3774,6 +3774,7 @@ dependencies = [ "futures", "libc", "socket2 0.4.9", + "talpid-types", "winapi", "windows-sys 0.48.0", ] diff --git a/talpid-types/src/error.rs b/talpid-types/src/error.rs new file mode 100644 index 0000000000..77f0664375 --- /dev/null +++ b/talpid-types/src/error.rs @@ -0,0 +1,70 @@ +use std::{error::Error, fmt, fmt::Write}; + +/// Used to generate string representations of error chains. +pub trait ErrorExt { + /// Creates a string representation of the entire error chain. + fn display_chain(&self) -> String; + + /// Like [Self::display_chain] but with an extra message at the start of the chain + fn display_chain_with_msg(&self, msg: &str) -> String; +} + +impl<E: Error> ErrorExt for E { + fn display_chain(&self) -> String { + let mut s = format!("Error: {self}"); + let mut source = self.source(); + while let Some(error) = source { + write!(&mut s, "\nCaused by: {error}").expect("formatting failed"); + source = error.source(); + } + s + } + + fn display_chain_with_msg(&self, msg: &str) -> String { + let mut s = format!("Error: {msg}\nCaused by: {self}"); + let mut source = self.source(); + while let Some(error) = source { + write!(&mut s, "\nCaused by: {error}").expect("formatting failed"); + source = error.source(); + } + s + } +} + +#[derive(Debug)] +pub struct BoxedError(Box<dyn Error + 'static + Send>); + +impl fmt::Display for BoxedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Error for BoxedError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.0.source() + } +} + +impl BoxedError { + pub fn new(error: impl Error + 'static + Send) -> Self { + BoxedError(Box::new(error)) + } +} + +/// Helper macro allowing simpler handling of Windows FFI returning `WIN32_ERROR` +/// status codes. Converts a `WIN32_ERROR` into an `io::Result<()>`. +/// +/// The caller of this macro must have `windows_sys` as a dependency. +#[cfg(windows)] +#[macro_export] +macro_rules! win32_err { + ($expr:expr) => {{ + let status = $expr; + if status == ::windows_sys::Win32::Foundation::NO_ERROR { + Ok(()) + } else { + Err(::std::io::Error::from_raw_os_error(status as i32)) + } + }}; +} diff --git a/talpid-types/src/lib.rs b/talpid-types/src/lib.rs index 92c0a2ce76..ee2c6d172a 100644 --- a/talpid-types/src/lib.rs +++ b/talpid-types/src/lib.rs @@ -1,7 +1,5 @@ #![deny(rust_2018_idioms)] -use std::{error::Error, fmt, fmt::Write}; - #[cfg(target_os = "android")] pub mod android; pub mod net; @@ -13,54 +11,5 @@ pub mod cgroup; #[cfg(target_os = "windows")] pub mod split_tunnel; -/// Used to generate string representations of error chains. -pub trait ErrorExt { - /// Creates a string representation of the entire error chain. - fn display_chain(&self) -> String; - - /// Like [Self::display_chain] but with an extra message at the start of the chain - fn display_chain_with_msg(&self, msg: &str) -> String; -} - -impl<E: Error> ErrorExt for E { - fn display_chain(&self) -> String { - let mut s = format!("Error: {self}"); - let mut source = self.source(); - while let Some(error) = source { - write!(&mut s, "\nCaused by: {error}").expect("formatting failed"); - source = error.source(); - } - s - } - - fn display_chain_with_msg(&self, msg: &str) -> String { - let mut s = format!("Error: {msg}\nCaused by: {self}"); - let mut source = self.source(); - while let Some(error) = source { - write!(&mut s, "\nCaused by: {error}").expect("formatting failed"); - source = error.source(); - } - s - } -} - -#[derive(Debug)] -pub struct BoxedError(Box<dyn Error + 'static + Send>); - -impl fmt::Display for BoxedError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl Error for BoxedError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() - } -} - -impl BoxedError { - pub fn new(error: impl Error + 'static + Send) -> Self { - BoxedError(Box::new(error)) - } -} +mod error; +pub use error::*; diff --git a/talpid-windows-net/Cargo.toml b/talpid-windows-net/Cargo.toml index 40caf827f4..351bc539ea 100644 --- a/talpid-windows-net/Cargo.toml +++ b/talpid-windows-net/Cargo.toml @@ -15,6 +15,8 @@ socket2 = { version = "0.4.2", features = ["all"] } futures = "0.3.15" winapi = { version = "0.3.6", features = ["ws2def"] } +talpid-types = { path = "../talpid-types" } + [target.'cfg(windows)'.dependencies.windows-sys] workspace = true features = [ diff --git a/talpid-windows-net/src/net.rs b/talpid-windows-net/src/net.rs index cd7ff14358..3c0d4d6880 100644 --- a/talpid-windows-net/src/net.rs +++ b/talpid-windows-net/src/net.rs @@ -9,11 +9,12 @@ use std::{ sync::Mutex, time::{Duration, Instant}, }; +use talpid_types::win32_err; use winapi::shared::ws2def::SOCKADDR_STORAGE as sockaddr_storage; use windows_sys::{ core::GUID, Win32::{ - Foundation::{ERROR_NOT_FOUND, HANDLE, NO_ERROR}, + Foundation::{ERROR_NOT_FOUND, HANDLE}, NetworkManagement::{ IpHelper::{ CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, @@ -174,7 +175,7 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send handle: 0, }); - let status = unsafe { + win32_err!(unsafe { NotifyIpInterfaceChange( af_family_from_family(family), Some(inner_callback), @@ -182,13 +183,8 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send 0, (&mut context.handle) as *mut _, ) - }; - - if status == NO_ERROR as i32 { - Ok(context) - } else { - Err(io::Error::from_raw_os_error(status)) - } + })?; + Ok(context) } /// Returns information about a network IP interface. @@ -200,22 +196,13 @@ pub fn get_ip_interface_entry( row.Family = family as u16; row.InterfaceLuid = *luid; - let result = unsafe { GetIpInterfaceEntry(&mut row) }; - if result == NO_ERROR as i32 { - Ok(row) - } else { - Err(io::Error::from_raw_os_error(result)) - } + win32_err!(unsafe { GetIpInterfaceEntry(&mut row) })?; + Ok(row) } /// Set the properties of an IP interface. pub fn set_ip_interface_entry(row: &mut MIB_IPINTERFACE_ROW) -> io::Result<()> { - let result = unsafe { SetIpInterfaceEntry(row as *mut _) }; - if result == NO_ERROR as i32 { - Ok(()) - } else { - Err(io::Error::from_raw_os_error(result)) - } + win32_err!(unsafe { SetIpInterfaceEntry(row as *mut _) }) } fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID_LH) -> io::Result<bool> { @@ -293,12 +280,8 @@ pub async fn wait_for_addresses(luid: NET_LUID_LH) -> Result<()> { let mut ready = true; for row in &mut unicast_rows { - let status = unsafe { GetUnicastIpAddressEntry(row) }; - if status != NO_ERROR as i32 { - return Err(Error::ObtainUnicastAddress(io::Error::from_raw_os_error( - status, - ))); - } + win32_err!(unsafe { GetUnicastIpAddressEntry(row) }) + .map_err(Error::ObtainUnicastAddress)?; if row.DadState == IpDadStateTentative { ready = false; break; @@ -347,13 +330,7 @@ pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Resul row.DadState = IpDadStatePreferred; row.OnLinkPrefixLength = 255; - let status = unsafe { CreateUnicastIpAddressEntry(&row) }; - if status != NO_ERROR as i32 { - return Err(Error::CreateUnicastEntry(io::Error::from_raw_os_error( - status, - ))); - } - Ok(()) + win32_err!(unsafe { CreateUnicastIpAddressEntry(&row) }).map_err(Error::CreateUnicastEntry) } /// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are @@ -364,11 +341,9 @@ pub fn get_unicast_table( let mut unicast_rows = vec![]; let mut unicast_table: *mut MIB_UNICASTIPADDRESS_TABLE = std::ptr::null_mut(); - let status = - unsafe { GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table) }; - if status != NO_ERROR as i32 { - return Err(io::Error::from_raw_os_error(status)); - } + win32_err!(unsafe { + GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table) + })?; let first_row = unsafe { &(*unicast_table).Table[0] } as *const MIB_UNICASTIPADDRESS_ROW; for i in 0..unsafe { *unicast_table }.NumEntries { unicast_rows.push(unsafe { *(first_row.offset(i as isize)) }); @@ -381,20 +356,14 @@ pub fn get_unicast_table( /// Returns the index of a network interface given its LUID. pub fn index_from_luid(luid: &NET_LUID_LH) -> io::Result<u32> { let mut index = 0u32; - let status = unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) }; - if status != NO_ERROR as i32 { - return Err(io::Error::from_raw_os_error(status)); - } + win32_err!(unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) })?; Ok(index) } /// Returns the GUID of a network interface given its LUID. pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> { let mut guid = MaybeUninit::zeroed(); - let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) }; - if status != NO_ERROR as i32 { - return Err(io::Error::from_raw_os_error(status)); - } + win32_err!(unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) })?; Ok(unsafe { guid.assume_init() }) } @@ -406,21 +375,16 @@ pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID_LH> { .chain(std::iter::once(0u16)) .collect(); let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() }; - let status = unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) }; - if status != NO_ERROR as i32 { - return Err(io::Error::from_raw_os_error(status)); - } + win32_err!(unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) })?; Ok(luid) } /// Returns the alias of an interface given its LUID. pub fn alias_from_luid(luid: &NET_LUID_LH) -> io::Result<OsString> { let mut buffer = [0u16; IF_MAX_STRING_SIZE as usize + 1]; - let status = - unsafe { ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len()) }; - if status != NO_ERROR as i32 { - return Err(io::Error::from_raw_os_error(status)); - } + win32_err!(unsafe { + ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len()) + })?; let nul = buffer.iter().position(|&c| c == 0u16).unwrap(); Ok(OsString::from_wide(&buffer[0..nul])) } |
