diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-11-17 13:21:23 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-12-07 14:39:50 +0100 |
| commit | f9f31a08c8b7061dbee96c850892a545ff02d43c (patch) | |
| tree | 9ee4db786f14accbfa6d353534ca5cff744bc683 | |
| parent | c560ccf5a9bcd97ef04cbada264d399e33cfac67 (diff) | |
| download | mullvadvpn-f9f31a08c8b7061dbee96c850892a545ff02d43c.tar.xz mullvadvpn-f9f31a08c8b7061dbee96c850892a545ff02d43c.zip | |
Update wg-nt tunnel to use new API
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 177 |
1 files changed, 32 insertions, 145 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 66b9c086fc..3752fe8dba 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -40,7 +40,7 @@ use winapi::{ lazy_static! { static ref WG_NT_DLL: Mutex<Option<Arc<WgNtDll>>> = Mutex::new(None); - static ref ADAPTER_POOL: U16CString = U16CString::from_str("Mullvad").unwrap(); + static ref ADAPTER_TYPE: U16CString = U16CString::from_str("Mullvad").unwrap(); static ref ADAPTER_ALIAS: U16CString = U16CString::from_str("Mullvad").unwrap(); } @@ -51,24 +51,14 @@ const ADAPTER_GUID: GUID = GUID { Data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9], }; -/// Longest possible adapter name (in characters), including null terminator -const MAX_ADAPTER_NAME: usize = 128; - -type WireGuardOpenAdapterFn = - unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle; type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn( - pool: *const u16, name: *const u16, + tunnel_type: *const u16, requested_guid: *const GUID, - reboot_required: *mut BOOL, ) -> RawHandle; -type WireGuardFreeAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle); -type WireGuardDeleteAdapterFn = - unsafe extern "stdcall" fn(adapter: RawHandle, reboot_required: *mut BOOL) -> BOOL; +type WireGuardCloseAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle); type WireGuardGetAdapterLuidFn = unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID); -type WireGuardGetAdapterNameFn = - unsafe extern "stdcall" fn(adapter: RawHandle, name: *mut u16) -> BOOL; type WireGuardSetConfigurationFn = unsafe extern "stdcall" fn( adapter: RawHandle, config: *const MaybeUninit<u8>, @@ -116,8 +106,6 @@ enum WireGuardAdapterLogState { type WireGuardSetAdapterLoggingFn = unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> BOOL; -type RebootRequired = bool; - pub type Result<T> = std::result::Result<T, Error>; #[derive(err_derive::Error, Debug)] @@ -127,18 +115,10 @@ pub enum Error { #[error(display = "Failed to load wireguard.dll")] DllError(#[error(source)] io::Error), - /// Failed to remove tunnel interface - #[error(display = "Failed to remove residual tunnel device")] - DeleteExistingTunnelError(#[error(source)] io::Error), - /// Failed to create tunnel interface #[error(display = "Failed to create WireGuard device")] CreateTunnelDeviceError(#[error(source)] io::Error), - /// Failed to delete tunnel interface - #[error(display = "Failed to delete WireGuard device")] - DeleteTunnelDeviceError(#[error(source)] io::Error), - /// Failed to obtain tunnel interface alias #[error(display = "Failed to obtain interface name")] ObtainAliasError(#[error(source)] io::Error), @@ -432,40 +412,20 @@ impl WgNtTunnel { resource_dir: &Path, ) -> Result<Self> { let dll = load_wg_nt_dll(resource_dir)?; - let logger_handle = LoggerHandle::new(dll.clone(), log_path)?; - - { - if let Ok(device) = WgNtAdapter::open(dll.clone(), &*ADAPTER_POOL, &*ADAPTER_ALIAS) { - device.delete().map_err(Error::DeleteExistingTunnelError)?; - } - } - - let (device, reboot_required) = WgNtAdapter::create( + let device = WgNtAdapter::create( dll.clone(), - &*ADAPTER_POOL, &*ADAPTER_ALIAS, + &*ADAPTER_TYPE, Some(ADAPTER_GUID.clone()), ) .map_err(Error::CreateTunnelDeviceError)?; - if reboot_required { - log::warn!("You may need to reboot to finish installing WireGuardNT"); - } - let interface_luid = device.luid(); - let interface_name = match device.name() { - Ok(name) => name.to_string_lossy(), - Err(error) => { - if let Err(error) = device.delete() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to delete tunnel device") - ); - } - return Err(Error::ObtainAliasError(error)); - } - }; + let interface_name = device + .name() + .map_err(Error::ObtainAliasError)? + .to_string_lossy(); let tunnel = WgNtTunnel { device: Some(device), @@ -477,13 +437,8 @@ impl WgNtTunnel { Ok(tunnel) } - fn stop_tunnel(&mut self) -> Result<()> { - if let Some(device) = self.device.take() { - if let Err(error) = device.delete() { - return Err(Error::DeleteTunnelDeviceError(error)); - } - } - Ok(()) + fn stop_tunnel(&mut self) { + let _ = self.device.take(); } fn configure(&self, config: &Config) -> Result<()> { @@ -510,12 +465,7 @@ impl WgNtTunnel { impl Drop for WgNtTunnel { fn drop(&mut self) { - if let Err(error) = self.stop_tunnel() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to stop WireGuardNT tunnel") - ); - } + self.stop_tunnel(); } } @@ -580,27 +530,21 @@ unsafe impl Send for WgNtAdapter {} unsafe impl Sync for WgNtAdapter {} impl WgNtAdapter { - fn open(dll_handle: Arc<WgNtDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> { - let handle = dll_handle.open_adapter(pool, name)?; - Ok(Self { dll_handle, handle }) - } - fn create( dll_handle: Arc<WgNtDll>, - pool: &U16CStr, name: &U16CStr, + tunnel_type: &U16CStr, requested_guid: Option<GUID>, - ) -> io::Result<(Self, RebootRequired)> { - let (handle, restart_required) = dll_handle.create_adapter(pool, name, requested_guid)?; - Ok((Self { dll_handle, handle }, restart_required)) - } - - fn delete(self) -> io::Result<RebootRequired> { - unsafe { self.dll_handle.delete_adapter(self.handle) } + ) -> io::Result<Self> { + let handle = dll_handle.create_adapter(name, tunnel_type, requested_guid)?; + Ok(Self { dll_handle, handle }) } fn name(&self) -> io::Result<U16CString> { - unsafe { self.dll_handle.get_adapter_name(self.handle) } + windows::alias_from_luid(&self.luid()).and_then(|alias| { + U16CString::from_os_str(alias) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "unexpected null char")) + }) } fn luid(&self) -> NET_LUID { @@ -638,18 +582,15 @@ impl WgNtAdapter { impl Drop for WgNtAdapter { fn drop(&mut self) { - unsafe { self.dll_handle.free_adapter(self.handle) }; + unsafe { self.dll_handle.close_adapter(self.handle) }; } } struct WgNtDll { handle: HINSTANCE, - func_open: WireGuardOpenAdapterFn, func_create: WireGuardCreateAdapterFn, - func_delete: WireGuardDeleteAdapterFn, - func_free: WireGuardFreeAdapterFn, + func_close: WireGuardCloseAdapterFn, func_get_adapter_luid: WireGuardGetAdapterLuidFn, - func_get_adapter_name: WireGuardGetAdapterNameFn, func_set_configuration: WireGuardSetConfigurationFn, func_get_configuration: WireGuardGetConfigurationFn, func_set_adapter_state: WireGuardSetStateFn, @@ -683,28 +624,16 @@ impl WgNtDll { ) -> io::Result<Self> { Ok(WgNtDll { handle, - func_open: unsafe { - *((&get_proc_fn( - handle, - CStr::from_bytes_with_nul(b"WireGuardOpenAdapter\0").unwrap(), - )?) as *const _ as *const _) - }, func_create: unsafe { *((&get_proc_fn( handle, CStr::from_bytes_with_nul(b"WireGuardCreateAdapter\0").unwrap(), )?) as *const _ as *const _) }, - func_delete: unsafe { - *((&get_proc_fn( - handle, - CStr::from_bytes_with_nul(b"WireGuardDeleteAdapter\0").unwrap(), - )?) as *const _ as *const _) - }, - func_free: unsafe { + func_close: unsafe { *((&get_proc_fn( handle, - CStr::from_bytes_with_nul(b"WireGuardFreeAdapter\0").unwrap(), + CStr::from_bytes_with_nul(b"WireGuardCloseAdapter\0").unwrap(), )?) as *const _ as *const _) }, func_get_adapter_luid: unsafe { @@ -713,12 +642,6 @@ impl WgNtDll { CStr::from_bytes_with_nul(b"WireGuardGetAdapterLUID\0").unwrap(), )?) as *const _ as *const _) }, - func_get_adapter_name: unsafe { - *((&get_proc_fn( - handle, - CStr::from_bytes_with_nul(b"WireGuardGetAdapterName\0").unwrap(), - )?) as *const _ as *const _) - }, func_set_configuration: unsafe { *((&get_proc_fn( handle, @@ -760,54 +683,25 @@ impl WgNtDll { Ok(handle) } - pub fn open_adapter(&self, pool: &U16CStr, name: &U16CStr) -> io::Result<RawHandle> { - let handle = unsafe { (self.func_open)(pool.as_ptr(), name.as_ptr()) }; - if handle == ptr::null_mut() { - return Err(io::Error::last_os_error()); - } - Ok(handle) - } - pub fn create_adapter( &self, - pool: &U16CStr, name: &U16CStr, + tunnel_type: &U16CStr, requested_guid: Option<GUID>, - ) -> io::Result<(RawHandle, RebootRequired)> { + ) -> io::Result<RawHandle> { let guid_ptr = match requested_guid.as_ref() { Some(guid) => guid as *const _, None => ptr::null_mut(), }; - let mut reboot_required = 0; - let handle = unsafe { - (self.func_create)(pool.as_ptr(), name.as_ptr(), guid_ptr, &mut reboot_required) - }; + let handle = unsafe { (self.func_create)(name.as_ptr(), tunnel_type.as_ptr(), guid_ptr) }; if handle == ptr::null_mut() { return Err(io::Error::last_os_error()); } - Ok((handle, reboot_required != 0)) - } - - pub unsafe fn delete_adapter(&self, adapter: RawHandle) -> io::Result<RebootRequired> { - let mut reboot_required = 0; - let result = (self.func_delete)(adapter, &mut reboot_required); - if result == 0 { - return Err(io::Error::last_os_error()); - } - Ok(reboot_required != 0) + Ok(handle) } - pub unsafe fn free_adapter(&self, adapter: RawHandle) { - (self.func_free)(adapter); - } - - pub unsafe fn get_adapter_name(&self, adapter: RawHandle) -> io::Result<U16CString> { - let mut alias_buffer = vec![0u16; MAX_ADAPTER_NAME]; - let result = (self.func_get_adapter_name)(adapter, alias_buffer.as_mut_ptr()); - if result == 0 { - return Err(io::Error::last_os_error()); - } - Ok(U16CString::from_vec_truncate(alias_buffer)) + pub unsafe fn close_adapter(&self, adapter: RawHandle) { + (self.func_close)(adapter); } pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID { @@ -1051,15 +945,8 @@ impl Tunnel for WgNtTunnel { } fn stop(mut self: Box<Self>) -> std::result::Result<(), super::TunnelError> { - if let Err(error) = self.stop_tunnel() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to stop WireGuardNT tunnel") - ); - Err(super::TunnelError::StopWireguardError { status: 0 }) - } else { - Ok(()) - } + self.stop_tunnel(); + Ok(()) } } |
