summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-11-17 13:21:23 +0100
committerDavid Lönnhager <david.l@mullvad.net>2021-12-07 14:39:50 +0100
commitf9f31a08c8b7061dbee96c850892a545ff02d43c (patch)
tree9ee4db786f14accbfa6d353534ca5cff744bc683
parentc560ccf5a9bcd97ef04cbada264d399e33cfac67 (diff)
downloadmullvadvpn-f9f31a08c8b7061dbee96c850892a545ff02d43c.tar.xz
mullvadvpn-f9f31a08c8b7061dbee96c850892a545ff02d43c.zip
Update wg-nt tunnel to use new API
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_nt.rs177
1 files changed, 32 insertions, 145 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
index 66b9c086fc..3752fe8dba 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
@@ -40,7 +40,7 @@ use winapi::{
lazy_static! {
static ref WG_NT_DLL: Mutex<Option<Arc<WgNtDll>>> = Mutex::new(None);
- static ref ADAPTER_POOL: U16CString = U16CString::from_str("Mullvad").unwrap();
+ static ref ADAPTER_TYPE: U16CString = U16CString::from_str("Mullvad").unwrap();
static ref ADAPTER_ALIAS: U16CString = U16CString::from_str("Mullvad").unwrap();
}
@@ -51,24 +51,14 @@ const ADAPTER_GUID: GUID = GUID {
Data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9],
};
-/// Longest possible adapter name (in characters), including null terminator
-const MAX_ADAPTER_NAME: usize = 128;
-
-type WireGuardOpenAdapterFn =
- unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle;
type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn(
- pool: *const u16,
name: *const u16,
+ tunnel_type: *const u16,
requested_guid: *const GUID,
- reboot_required: *mut BOOL,
) -> RawHandle;
-type WireGuardFreeAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle);
-type WireGuardDeleteAdapterFn =
- unsafe extern "stdcall" fn(adapter: RawHandle, reboot_required: *mut BOOL) -> BOOL;
+type WireGuardCloseAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle);
type WireGuardGetAdapterLuidFn =
unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID);
-type WireGuardGetAdapterNameFn =
- unsafe extern "stdcall" fn(adapter: RawHandle, name: *mut u16) -> BOOL;
type WireGuardSetConfigurationFn = unsafe extern "stdcall" fn(
adapter: RawHandle,
config: *const MaybeUninit<u8>,
@@ -116,8 +106,6 @@ enum WireGuardAdapterLogState {
type WireGuardSetAdapterLoggingFn =
unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> BOOL;
-type RebootRequired = bool;
-
pub type Result<T> = std::result::Result<T, Error>;
#[derive(err_derive::Error, Debug)]
@@ -127,18 +115,10 @@ pub enum Error {
#[error(display = "Failed to load wireguard.dll")]
DllError(#[error(source)] io::Error),
- /// Failed to remove tunnel interface
- #[error(display = "Failed to remove residual tunnel device")]
- DeleteExistingTunnelError(#[error(source)] io::Error),
-
/// Failed to create tunnel interface
#[error(display = "Failed to create WireGuard device")]
CreateTunnelDeviceError(#[error(source)] io::Error),
- /// Failed to delete tunnel interface
- #[error(display = "Failed to delete WireGuard device")]
- DeleteTunnelDeviceError(#[error(source)] io::Error),
-
/// Failed to obtain tunnel interface alias
#[error(display = "Failed to obtain interface name")]
ObtainAliasError(#[error(source)] io::Error),
@@ -432,40 +412,20 @@ impl WgNtTunnel {
resource_dir: &Path,
) -> Result<Self> {
let dll = load_wg_nt_dll(resource_dir)?;
-
let logger_handle = LoggerHandle::new(dll.clone(), log_path)?;
-
- {
- if let Ok(device) = WgNtAdapter::open(dll.clone(), &*ADAPTER_POOL, &*ADAPTER_ALIAS) {
- device.delete().map_err(Error::DeleteExistingTunnelError)?;
- }
- }
-
- let (device, reboot_required) = WgNtAdapter::create(
+ let device = WgNtAdapter::create(
dll.clone(),
- &*ADAPTER_POOL,
&*ADAPTER_ALIAS,
+ &*ADAPTER_TYPE,
Some(ADAPTER_GUID.clone()),
)
.map_err(Error::CreateTunnelDeviceError)?;
- if reboot_required {
- log::warn!("You may need to reboot to finish installing WireGuardNT");
- }
-
let interface_luid = device.luid();
- let interface_name = match device.name() {
- Ok(name) => name.to_string_lossy(),
- Err(error) => {
- if let Err(error) = device.delete() {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to delete tunnel device")
- );
- }
- return Err(Error::ObtainAliasError(error));
- }
- };
+ let interface_name = device
+ .name()
+ .map_err(Error::ObtainAliasError)?
+ .to_string_lossy();
let tunnel = WgNtTunnel {
device: Some(device),
@@ -477,13 +437,8 @@ impl WgNtTunnel {
Ok(tunnel)
}
- fn stop_tunnel(&mut self) -> Result<()> {
- if let Some(device) = self.device.take() {
- if let Err(error) = device.delete() {
- return Err(Error::DeleteTunnelDeviceError(error));
- }
- }
- Ok(())
+ fn stop_tunnel(&mut self) {
+ let _ = self.device.take();
}
fn configure(&self, config: &Config) -> Result<()> {
@@ -510,12 +465,7 @@ impl WgNtTunnel {
impl Drop for WgNtTunnel {
fn drop(&mut self) {
- if let Err(error) = self.stop_tunnel() {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to stop WireGuardNT tunnel")
- );
- }
+ self.stop_tunnel();
}
}
@@ -580,27 +530,21 @@ unsafe impl Send for WgNtAdapter {}
unsafe impl Sync for WgNtAdapter {}
impl WgNtAdapter {
- fn open(dll_handle: Arc<WgNtDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> {
- let handle = dll_handle.open_adapter(pool, name)?;
- Ok(Self { dll_handle, handle })
- }
-
fn create(
dll_handle: Arc<WgNtDll>,
- pool: &U16CStr,
name: &U16CStr,
+ tunnel_type: &U16CStr,
requested_guid: Option<GUID>,
- ) -> io::Result<(Self, RebootRequired)> {
- let (handle, restart_required) = dll_handle.create_adapter(pool, name, requested_guid)?;
- Ok((Self { dll_handle, handle }, restart_required))
- }
-
- fn delete(self) -> io::Result<RebootRequired> {
- unsafe { self.dll_handle.delete_adapter(self.handle) }
+ ) -> io::Result<Self> {
+ let handle = dll_handle.create_adapter(name, tunnel_type, requested_guid)?;
+ Ok(Self { dll_handle, handle })
}
fn name(&self) -> io::Result<U16CString> {
- unsafe { self.dll_handle.get_adapter_name(self.handle) }
+ windows::alias_from_luid(&self.luid()).and_then(|alias| {
+ U16CString::from_os_str(alias)
+ .map_err(|_| io::Error::new(io::ErrorKind::Other, "unexpected null char"))
+ })
}
fn luid(&self) -> NET_LUID {
@@ -638,18 +582,15 @@ impl WgNtAdapter {
impl Drop for WgNtAdapter {
fn drop(&mut self) {
- unsafe { self.dll_handle.free_adapter(self.handle) };
+ unsafe { self.dll_handle.close_adapter(self.handle) };
}
}
struct WgNtDll {
handle: HINSTANCE,
- func_open: WireGuardOpenAdapterFn,
func_create: WireGuardCreateAdapterFn,
- func_delete: WireGuardDeleteAdapterFn,
- func_free: WireGuardFreeAdapterFn,
+ func_close: WireGuardCloseAdapterFn,
func_get_adapter_luid: WireGuardGetAdapterLuidFn,
- func_get_adapter_name: WireGuardGetAdapterNameFn,
func_set_configuration: WireGuardSetConfigurationFn,
func_get_configuration: WireGuardGetConfigurationFn,
func_set_adapter_state: WireGuardSetStateFn,
@@ -683,28 +624,16 @@ impl WgNtDll {
) -> io::Result<Self> {
Ok(WgNtDll {
handle,
- func_open: unsafe {
- *((&get_proc_fn(
- handle,
- CStr::from_bytes_with_nul(b"WireGuardOpenAdapter\0").unwrap(),
- )?) as *const _ as *const _)
- },
func_create: unsafe {
*((&get_proc_fn(
handle,
CStr::from_bytes_with_nul(b"WireGuardCreateAdapter\0").unwrap(),
)?) as *const _ as *const _)
},
- func_delete: unsafe {
- *((&get_proc_fn(
- handle,
- CStr::from_bytes_with_nul(b"WireGuardDeleteAdapter\0").unwrap(),
- )?) as *const _ as *const _)
- },
- func_free: unsafe {
+ func_close: unsafe {
*((&get_proc_fn(
handle,
- CStr::from_bytes_with_nul(b"WireGuardFreeAdapter\0").unwrap(),
+ CStr::from_bytes_with_nul(b"WireGuardCloseAdapter\0").unwrap(),
)?) as *const _ as *const _)
},
func_get_adapter_luid: unsafe {
@@ -713,12 +642,6 @@ impl WgNtDll {
CStr::from_bytes_with_nul(b"WireGuardGetAdapterLUID\0").unwrap(),
)?) as *const _ as *const _)
},
- func_get_adapter_name: unsafe {
- *((&get_proc_fn(
- handle,
- CStr::from_bytes_with_nul(b"WireGuardGetAdapterName\0").unwrap(),
- )?) as *const _ as *const _)
- },
func_set_configuration: unsafe {
*((&get_proc_fn(
handle,
@@ -760,54 +683,25 @@ impl WgNtDll {
Ok(handle)
}
- pub fn open_adapter(&self, pool: &U16CStr, name: &U16CStr) -> io::Result<RawHandle> {
- let handle = unsafe { (self.func_open)(pool.as_ptr(), name.as_ptr()) };
- if handle == ptr::null_mut() {
- return Err(io::Error::last_os_error());
- }
- Ok(handle)
- }
-
pub fn create_adapter(
&self,
- pool: &U16CStr,
name: &U16CStr,
+ tunnel_type: &U16CStr,
requested_guid: Option<GUID>,
- ) -> io::Result<(RawHandle, RebootRequired)> {
+ ) -> io::Result<RawHandle> {
let guid_ptr = match requested_guid.as_ref() {
Some(guid) => guid as *const _,
None => ptr::null_mut(),
};
- let mut reboot_required = 0;
- let handle = unsafe {
- (self.func_create)(pool.as_ptr(), name.as_ptr(), guid_ptr, &mut reboot_required)
- };
+ let handle = unsafe { (self.func_create)(name.as_ptr(), tunnel_type.as_ptr(), guid_ptr) };
if handle == ptr::null_mut() {
return Err(io::Error::last_os_error());
}
- Ok((handle, reboot_required != 0))
- }
-
- pub unsafe fn delete_adapter(&self, adapter: RawHandle) -> io::Result<RebootRequired> {
- let mut reboot_required = 0;
- let result = (self.func_delete)(adapter, &mut reboot_required);
- if result == 0 {
- return Err(io::Error::last_os_error());
- }
- Ok(reboot_required != 0)
+ Ok(handle)
}
- pub unsafe fn free_adapter(&self, adapter: RawHandle) {
- (self.func_free)(adapter);
- }
-
- pub unsafe fn get_adapter_name(&self, adapter: RawHandle) -> io::Result<U16CString> {
- let mut alias_buffer = vec![0u16; MAX_ADAPTER_NAME];
- let result = (self.func_get_adapter_name)(adapter, alias_buffer.as_mut_ptr());
- if result == 0 {
- return Err(io::Error::last_os_error());
- }
- Ok(U16CString::from_vec_truncate(alias_buffer))
+ pub unsafe fn close_adapter(&self, adapter: RawHandle) {
+ (self.func_close)(adapter);
}
pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID {
@@ -1051,15 +945,8 @@ impl Tunnel for WgNtTunnel {
}
fn stop(mut self: Box<Self>) -> std::result::Result<(), super::TunnelError> {
- if let Err(error) = self.stop_tunnel() {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to stop WireGuardNT tunnel")
- );
- Err(super::TunnelError::StopWireguardError { status: 0 })
- } else {
- Ok(())
- }
+ self.stop_tunnel();
+ Ok(())
}
}