summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/routing/windows.rs5
-rw-r--r--talpid-core/src/tunnel/mod.rs3
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs244
-rw-r--r--talpid-core/src/tunnel/openvpn/windows.rs132
-rw-r--r--talpid-core/src/tunnel/windows.rs133
-rw-r--r--talpid-core/src/tunnel/wireguard/connectivity_check.rs5
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs100
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs18
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs9
-rw-r--r--wireguard/libwg/libwg_windows.go48
12 files changed, 435 insertions, 264 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 781f984822..7fb4f90439 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2524,6 +2524,7 @@ dependencies = [
name = "talpid-core"
version = "0.1.0"
dependencies = [
+ "async-trait",
"atty",
"byteorder",
"cfg-if 1.0.0",
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 6f0c60c08d..f84e4a5bd4 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2018"
publish = false
[dependencies]
+async-trait = "0.1"
atty = "0.2"
cfg-if = "1.0"
duct = "0.13"
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index 25c80899e5..ca5fbb8ea2 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -98,6 +98,11 @@ impl RouteManager {
}
}
+ /// Retrieve handle for the tokio runtime.
+ pub fn runtime_handle(&self) -> tokio::runtime::Handle {
+ self.runtime.clone()
+ }
+
async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) {
while let Some(command) = manage_rx.next().await {
match command {
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 063feacc3d..69cd842138 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -24,6 +24,9 @@ pub mod wireguard;
/// A module for low level platform specific tunnel device management.
pub(crate) mod tun_provider;
+#[cfg(target_os = "windows")]
+mod windows;
+
const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index 8d1aa03607..44efa96668 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -22,7 +22,7 @@ use std::{
process::ExitStatus,
sync::{
atomic::{AtomicBool, Ordering},
- mpsc, Arc,
+ mpsc, Arc, Mutex,
},
thread,
time::Duration,
@@ -33,12 +33,9 @@ use std::{collections::HashSet, net::IpAddr};
use std::{
ffi::{OsStr, OsString},
os::windows::ffi::OsStrExt,
- sync::Mutex,
time::Instant,
};
-use talpid_types::net::openvpn;
-#[cfg(any(windows, target_os = "linux"))]
-use talpid_types::ErrorExt;
+use talpid_types::{net::openvpn, ErrorExt};
use tokio::task;
#[cfg(target_os = "linux")]
use which;
@@ -180,6 +177,10 @@ pub enum Error {
#[error(display = "OpenVPN process died unexpectedly")]
ChildProcessDied,
+ /// Failed before OpenVPN started
+ #[error(display = "Failed to start OpenVPN")]
+ StartProcessError,
+
/// The IP routing program was not found.
#[cfg(target_os = "linux")]
#[error(display = "The IP routing program `ip` was not found")]
@@ -260,9 +261,15 @@ const OPENVPN_BIN_FILENAME: &str = "openvpn.exe";
/// Struct for monitoring an OpenVPN process.
#[derive(Debug)]
pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
- child: Arc<C::ProcessHandle>,
+ spawn_task: Option<
+ tokio::task::JoinHandle<
+ std::result::Result<io::Result<C::ProcessHandle>, futures::future::Aborted>,
+ >,
+ >,
+ abort_spawn: futures::future::AbortHandle,
+
+ child: Arc<Mutex<Option<Arc<C::ProcessHandle>>>>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
- log_path: Option<PathBuf>,
closed: Arc<AtomicBool>,
/// Keep the `TempFile` for the user-pass file in the struct, so it's removed on drop.
_user_pass_file: mktemp::TempFile,
@@ -274,9 +281,52 @@ pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
server_join_handle: Option<task::JoinHandle<std::result::Result<(), event_server::Error>>>,
#[cfg(windows)]
- wintun_adapter: Option<windows::TemporaryWintunAdapter>,
- #[cfg(windows)]
- _wintun_logger: Option<windows::WintunLoggerHandle>,
+ wintun: Arc<Box<dyn WintunContext>>,
+}
+
+#[cfg(windows)]
+#[async_trait::async_trait]
+trait WintunContext: Send + Sync {
+ fn luid(&self) -> NET_LUID;
+ fn ipv6(&self) -> bool;
+ async fn wait_for_interfaces(&self) -> io::Result<()>;
+}
+
+#[cfg(windows)]
+impl std::fmt::Debug for dyn WintunContext {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "WintunContext {{ luid: {}, ipv6: {} }}",
+ self.luid().Value,
+ self.ipv6()
+ )
+ }
+}
+
+#[cfg(windows)]
+#[derive(Debug)]
+struct WintunContextImpl {
+ adapter: windows::TemporaryWintunAdapter,
+ wait_v6_interface: bool,
+ _logger: windows::WintunLoggerHandle,
+}
+
+#[cfg(windows)]
+#[async_trait::async_trait]
+impl WintunContext for WintunContextImpl {
+ fn luid(&self) -> NET_LUID {
+ self.adapter.adapter().luid()
+ }
+
+ fn ipv6(&self) -> bool {
+ self.wait_v6_interface
+ }
+
+ async fn wait_for_interfaces(&self) -> io::Result<()> {
+ let luid = self.adapter.adapter().luid();
+ super::windows::wait_for_interfaces(luid, true, self.wait_v6_interface).await
+ }
}
@@ -399,14 +449,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
log::warn!("You may need to restart Windows to complete the install of Wintun");
}
- log::debug!("Wait for IP interfaces");
- windows::wait_for_interfaces(
- &adapter.adapter().luid(),
- true,
- params.generic_options.enable_ipv6,
- )
- .map_err(Error::IpInterfacesError)?;
-
let assigned_guid = adapter.adapter().guid();
let assigned_guid = assigned_guid.as_ref().unwrap_or_else(|error| {
log::error!(
@@ -475,15 +517,17 @@ impl OpenVpnMonitor<OpenVpnCommand> {
Self::new_internal(
cmd,
on_openvpn_event,
- &plugin_path,
+ plugin_path,
log_path,
user_pass_file,
proxy_auth_file,
proxy_monitor,
#[cfg(windows)]
- Some(wintun_adapter),
- #[cfg(windows)]
- Some(wintun_logger),
+ Box::new(WintunContextImpl {
+ adapter: wintun_adapter,
+ wait_v6_interface: params.generic_options.enable_ipv6,
+ _logger: wintun_logger,
+ }),
)
}
}
@@ -527,17 +571,16 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
Ok(routes)
}
-impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
+impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
fn new_internal<L>(
mut cmd: C,
on_event: L,
- plugin_path: impl AsRef<Path>,
+ plugin_path: PathBuf,
log_path: Option<PathBuf>,
user_pass_file: mktemp::TempFile,
proxy_auth_file: Option<mktemp::TempFile>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
- #[cfg(windows)] wintun_adapter: Option<windows::TemporaryWintunAdapter>,
- #[cfg(windows)] wintun_logger: Option<windows::WintunLoggerHandle>,
+ #[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static,
@@ -574,16 +617,23 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
.unwrap_err());
}
- let child = cmd
- .plugin(plugin_path, vec![ipc_path])
- .log(log_path.as_ref().map(|p| p.as_path()))
- .start()
- .map_err(|e| Error::ChildProcessError("Failed to start", e))?;
+ #[cfg(windows)]
+ let wintun = Arc::new(wintun);
+
+ cmd.plugin(plugin_path, vec![ipc_path])
+ .log(log_path.as_ref().map(|p| p.as_path()));
+ let (spawn_task, abort_spawn) = futures::future::abortable(Self::prepare_process(
+ cmd,
+ #[cfg(windows)]
+ wintun.clone(),
+ ));
+ let spawn_task = runtime.spawn(spawn_task);
Ok(OpenVpnMonitor {
- child: Arc::new(child),
+ spawn_task: Some(spawn_task),
+ abort_spawn,
+ child: Arc::new(Mutex::new(None)),
proxy_monitor,
- log_path,
closed: Arc::new(AtomicBool::new(false)),
_user_pass_file: user_pass_file,
_proxy_auth_file: proxy_auth_file,
@@ -593,17 +643,28 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
server_join_handle: Some(server_join_handle),
#[cfg(windows)]
- wintun_adapter,
- #[cfg(windows)]
- _wintun_logger: wintun_logger,
+ wintun,
})
}
+ async fn prepare_process(
+ cmd: C,
+ #[cfg(windows)] wintun: Arc<Box<dyn WintunContext>>,
+ ) -> io::Result<C::ProcessHandle> {
+ #[cfg(windows)]
+ {
+ log::debug!("Wait for IP interfaces");
+ wintun.wait_for_interfaces().await?;
+ }
+ cmd.start()
+ }
+
/// Creates a handle to this monitor, allowing the tunnel to be closed while some other
/// thread is blocked in `wait`.
pub fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> {
OpenVpnCloseHandle {
child: self.child.clone(),
+ abort_spawn: self.abort_spawn.clone(),
closed: self.closed.clone(),
}
}
@@ -656,9 +717,19 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
}
/// Supplement `inner_wait_tunnel()` with logging and error handling.
- fn wait_tunnel(&mut self) -> Result<()> {
+ fn wait_tunnel(self) -> Result<()> {
let result = self.inner_wait_tunnel();
match result {
+ WaitResult::Preparation(result) => match result {
+ Err(error) => {
+ log::debug!(
+ "{}",
+ error.display_chain_with_msg("Failed to start OpenVPN")
+ );
+ Err(Error::StartProcessError)
+ }
+ _ => Ok(()),
+ },
WaitResult::Child(Ok(exit_status), closed) => {
if exit_status.success() || closed {
log::debug!(
@@ -684,8 +755,28 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
/// Waits for both the child process and the event dispatcher in parallel. After both have
/// returned this returns the earliest result.
- fn inner_wait_tunnel(&mut self) -> WaitResult {
- let child_wait_handle = self.child.clone();
+ fn inner_wait_tunnel(mut self) -> WaitResult {
+ let child = match self
+ .runtime
+ .block_on(self.spawn_task.take().unwrap())
+ .expect("spawn task panicked")
+ {
+ Ok(Ok(child)) => Arc::new(child),
+ Ok(Err(error)) => {
+ self.closed.swap(true, Ordering::SeqCst);
+ return WaitResult::Preparation(Err(error));
+ }
+ Err(_) => return WaitResult::Preparation(Ok(())),
+ };
+
+ if self.closed.load(Ordering::SeqCst) {
+ return WaitResult::Preparation(Ok(()));
+ }
+
+ {
+ self.child.lock().unwrap().replace(child.clone());
+ }
+
let closed_handle = self.closed.clone();
let child_close_handle = self.close_handle();
@@ -695,7 +786,7 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
let event_server_abort_tx = self.event_server_abort_tx.clone();
thread::spawn(move || {
- let result = child_wait_handle.wait();
+ let result = child.wait();
let closed = closed_handle.load(Ordering::SeqCst);
child_tx.send(WaitResult::Child(result, closed)).unwrap();
event_server_abort_tx.trigger();
@@ -835,7 +926,8 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> {
/// A handle to an `OpenVpnMonitor` for closing it.
#[derive(Debug, Clone)]
pub struct OpenVpnCloseHandle<H: ProcessHandle = OpenVpnProcHandle> {
- child: Arc<H>,
+ child: Arc<Mutex<Option<Arc<H>>>>,
+ abort_spawn: futures::future::AbortHandle,
closed: Arc<AtomicBool>,
}
@@ -843,7 +935,12 @@ impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
/// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return.
pub fn close(self) -> io::Result<()> {
if !self.closed.swap(true, Ordering::SeqCst) {
- self.child.kill()
+ self.abort_spawn.abort();
+ if let Some(child) = self.child.lock().unwrap().as_ref() {
+ child.kill()
+ } else {
+ Ok(())
+ }
} else {
Ok(())
}
@@ -853,6 +950,7 @@ impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
/// Internal enum to differentiate between if the child process or the event dispatcher died first.
#[derive(Debug)]
enum WaitResult {
+ Preparation(io::Result<()>),
Child(io::Result<ExitStatus>, bool),
EventDispatcher,
}
@@ -1152,6 +1250,24 @@ mod tests {
sync::Arc,
};
+ #[cfg(windows)]
+ #[derive(Debug)]
+ struct TestWintunContext {}
+
+ #[cfg(windows)]
+ #[async_trait::async_trait]
+ impl WintunContext for TestWintunContext {
+ fn luid(&self) -> NET_LUID {
+ NET_LUID { Value: 0u64 }
+ }
+ fn ipv6(&self) -> bool {
+ false
+ }
+ async fn wait_for_interfaces(&self) -> io::Result<()> {
+ Ok(())
+ }
+ }
+
#[derive(Debug, Default, Clone)]
struct TestOpenVpnBuilder {
pub plugin: Arc<Mutex<Option<PathBuf>>>,
@@ -1205,15 +1321,13 @@ mod tests {
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
|_, _| {},
- "./my_test_plugin",
+ "./my_test_plugin".into(),
None,
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
);
assert_eq!(
Some(PathBuf::from("./my_test_plugin")),
@@ -1227,15 +1341,13 @@ mod tests {
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
|_, _| {},
- "",
+ "".into(),
Some(PathBuf::from("./my_test_log_file")),
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
);
assert_eq!(
Some(PathBuf::from("./my_test_log_file")),
@@ -1250,15 +1362,13 @@ mod tests {
let testee = OpenVpnMonitor::new_internal(
builder,
|_, _| {},
- "",
+ "".into(),
None,
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
)
.unwrap();
assert!(testee.wait().is_ok());
@@ -1271,15 +1381,13 @@ mod tests {
let testee = OpenVpnMonitor::new_internal(
builder,
|_, _| {},
- "",
+ "".into(),
None,
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
)
.unwrap();
assert!(testee.wait().is_err());
@@ -1292,15 +1400,13 @@ mod tests {
let testee = OpenVpnMonitor::new_internal(
builder,
|_, _| {},
- "",
+ "".into(),
None,
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
)
.unwrap();
testee.close_handle().close().unwrap();
@@ -1310,22 +1416,20 @@ mod tests {
#[test]
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
- let error = OpenVpnMonitor::new_internal(
+ let result = OpenVpnMonitor::new_internal(
builder,
|_, _| {},
- "",
+ "".into(),
None,
TempFile::new(),
None,
None,
#[cfg(windows)]
- None,
- #[cfg(windows)]
- None,
+ Box::new(TestWintunContext {}),
)
- .unwrap_err();
- match error {
- Error::ChildProcessError(..) => (),
+ .unwrap();
+ match result.wait() {
+ Err(Error::StartProcessError) => (),
_ => panic!("Wrong error"),
}
}
diff --git a/talpid-core/src/tunnel/openvpn/windows.rs b/talpid-core/src/tunnel/openvpn/windows.rs
index a88c6c756f..9b907e101a 100644
--- a/talpid-core/src/tunnel/openvpn/windows.rs
+++ b/talpid-core/src/tunnel/openvpn/windows.rs
@@ -4,8 +4,7 @@ use std::{
os::windows::{ffi::OsStrExt, io::RawHandle},
path::Path,
ptr,
- sync::{Arc, Mutex},
- time::Duration,
+ sync::Arc,
};
use talpid_types::ErrorExt;
use widestring::{U16CStr, U16CString};
@@ -14,13 +13,8 @@ use winapi::{
guiddef::GUID,
ifdef::NET_LUID,
minwindef::{BOOL, FARPROC, HINSTANCE, HMODULE},
- netioapi::{
- CancelMibChangeNotify2, ConvertInterfaceLuidToGuid, GetIpInterfaceEntry,
- MibAddInstance, NotifyIpInterfaceChange, MIB_IPINTERFACE_ROW,
- },
- ntdef::FALSE,
+ netioapi::ConvertInterfaceLuidToGuid,
winerror::NO_ERROR,
- ws2def::{AF_INET, AF_INET6, AF_UNSPEC},
},
um::{
libloaderapi::{
@@ -35,8 +29,6 @@ use winreg::{enums::HKEY_LOCAL_MACHINE, RegKey};
/// Longest possible adapter name (in characters), including null terminator
const MAX_ADAPTER_NAME: usize = 128;
-const INTERFACE_WAIT_TIMEOUT: Duration = Duration::from_secs(5);
-
type WintunOpenAdapterFn =
unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle;
@@ -142,6 +134,7 @@ impl fmt::Debug for WintunAdapter {
}
unsafe impl Send for WintunAdapter {}
+unsafe impl Sync for WintunAdapter {}
impl WintunAdapter {
pub fn open(dll_handle: Arc<WintunDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> {
@@ -406,6 +399,7 @@ pub fn string_from_guid(guid: &GUID) -> String {
}
}
+/// Returns the registry key for a network device identified by its GUID.
pub fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result<RegKey> {
let net_devs = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_with_flags(
r"SYSTEM\CurrentControlSet\Control\Class\{4d36e972-e325-11ce-bfc1-08002be10318}",
@@ -433,124 +427,6 @@ pub fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Re
Err(io::Error::new(io::ErrorKind::NotFound, "device not found"))
}
-pub struct IpNotifierHandle<'a> {
- mutex: Mutex<()>,
- callback: Option<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>>,
- handle: RawHandle,
-}
-
-impl<'a> Drop for IpNotifierHandle<'a> {
- fn drop(&mut self) {
- // Inner callback may be called while destructing
- unsafe { CancelMibChangeNotify2(self.handle as *mut _) };
-
- let _ = self
- .mutex
- .lock()
- .expect("NotifyIpInterfaceChange mutex poisoned");
- let _ = self.callback.take();
- }
-}
-
-unsafe extern "system" fn inner_callback(
- context: *mut winapi::ctypes::c_void,
- row: *mut MIB_IPINTERFACE_ROW,
- notify_type: u32,
-) {
- let context = &mut *(context as *mut IpNotifierHandle<'_>);
- let _ = context
- .mutex
- .lock()
- .expect("NotifyIpInterfaceChange mutex poisoned");
-
- if let Some(ref mut callback) = context.callback {
- callback(&*row, notify_type);
- }
-}
-
-pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>(
- callback: T,
- family: u16,
-) -> io::Result<Box<IpNotifierHandle<'a>>> {
- let mut context = Box::new(IpNotifierHandle {
- mutex: Mutex::default(),
- callback: Some(Box::new(callback)),
- handle: std::ptr::null_mut(),
- });
-
- let status = unsafe {
- NotifyIpInterfaceChange(
- family,
- Some(inner_callback),
- &mut *context as *mut _ as *mut _,
- FALSE,
- (&mut context.handle) as *mut _,
- )
- };
-
- if status != NO_ERROR {
- return Err(io::Error::last_os_error());
- }
-
- Ok(context)
-}
-
-pub fn get_ip_interface_entry(family: u16, luid: &NET_LUID) -> io::Result<MIB_IPINTERFACE_ROW> {
- let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() };
- row.Family = family;
- row.InterfaceLuid = *luid;
-
- let result = unsafe { GetIpInterfaceEntry(&mut row as *mut _) };
- if result != NO_ERROR {
- return Err(io::Error::last_os_error());
- }
-
- Ok(row)
-}
-
-pub fn wait_for_interfaces(luid: &NET_LUID, ipv4: bool, ipv6: bool) -> io::Result<()> {
- let (tx, rx) = std::sync::mpsc::channel();
-
- let mut found_ipv4 = if ipv4 { false } else { true };
- let mut found_ipv6 = if ipv6 { false } else { true };
-
- let _handle = notify_ip_interface_change(
- move |row, notification_type| {
- if found_ipv4 && found_ipv6 {
- return;
- }
- if notification_type != MibAddInstance {
- return;
- }
- if row.InterfaceLuid.Value != luid.Value {
- return;
- }
- match row.Family as i32 {
- AF_INET => found_ipv4 = true,
- AF_INET6 => found_ipv6 = true,
- _ => (),
- }
- if found_ipv4 && found_ipv6 {
- let _ = tx.send(());
- }
- },
- AF_UNSPEC as u16,
- )?;
-
- // Make sure they don't already exist
- if (!ipv4 || get_ip_interface_entry(AF_INET as u16, luid).is_ok())
- && (!ipv6 || get_ip_interface_entry(AF_INET6 as u16, luid).is_ok())
- {
- return Ok(());
- }
-
- let _ = rx
- .recv_timeout(INTERFACE_WAIT_TIMEOUT)
- .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out waiting on interfaces"))?;
-
- Ok(())
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/talpid-core/src/tunnel/windows.rs b/talpid-core/src/tunnel/windows.rs
new file mode 100644
index 0000000000..2d4fcf85e5
--- /dev/null
+++ b/talpid-core/src/tunnel/windows.rs
@@ -0,0 +1,133 @@
+use std::{io, mem, os::windows::io::RawHandle, sync::Mutex};
+use winapi::shared::{
+ ifdef::NET_LUID,
+ netioapi::{
+ CancelMibChangeNotify2, GetIpInterfaceEntry, MibAddInstance, NotifyIpInterfaceChange,
+ MIB_IPINTERFACE_ROW,
+ },
+ ntdef::FALSE,
+ winerror::{ERROR_NOT_FOUND, NO_ERROR},
+ ws2def::{AF_INET, AF_INET6, AF_UNSPEC},
+};
+
+/// Context for [`notify_ip_interface_change`]. When it is dropped,
+/// the callback is unregistered.
+pub struct IpNotifierHandle<'a> {
+ callback: Mutex<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>>,
+ handle: RawHandle,
+}
+
+unsafe impl Send for IpNotifierHandle<'_> {}
+
+impl<'a> Drop for IpNotifierHandle<'a> {
+ fn drop(&mut self) {
+ unsafe { CancelMibChangeNotify2(self.handle as *mut _) };
+ }
+}
+
+unsafe extern "system" fn inner_callback(
+ context: *mut winapi::ctypes::c_void,
+ row: *mut MIB_IPINTERFACE_ROW,
+ notify_type: u32,
+) {
+ let context = &mut *(context as *mut IpNotifierHandle<'_>);
+ context
+ .callback
+ .lock()
+ .expect("NotifyIpInterfaceChange mutex poisoned")(&*row, notify_type);
+}
+
+/// Registers a callback function that is invoked when an interface is added, removed,
+/// or changed.
+pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>(
+ callback: T,
+ family: u16,
+) -> io::Result<Box<IpNotifierHandle<'a>>> {
+ let mut context = Box::new(IpNotifierHandle {
+ callback: Mutex::new(Box::new(callback)),
+ handle: std::ptr::null_mut(),
+ });
+
+ let status = unsafe {
+ NotifyIpInterfaceChange(
+ family,
+ Some(inner_callback),
+ &mut *context as *mut _ as *mut _,
+ FALSE,
+ (&mut context.handle) as *mut _,
+ )
+ };
+
+ if status == NO_ERROR {
+ Ok(context)
+ } else {
+ Err(io::Error::from_raw_os_error(status as i32))
+ }
+}
+
+/// Returns information about a network IP interface.
+pub fn get_ip_interface_entry(family: u16, luid: &NET_LUID) -> io::Result<MIB_IPINTERFACE_ROW> {
+ let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() };
+ row.Family = family;
+ row.InterfaceLuid = *luid;
+
+ let result = unsafe { GetIpInterfaceEntry(&mut row as *mut _) };
+ if result == NO_ERROR {
+ Ok(row)
+ } else {
+ Err(io::Error::from_raw_os_error(result as i32))
+ }
+}
+
+fn ip_interface_entry_exists(family: u16, luid: &NET_LUID) -> io::Result<bool> {
+ match get_ip_interface_entry(family, luid) {
+ Ok(_) => Ok(true),
+ Err(error) if error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => Ok(false),
+ Err(error) => Err(error),
+ }
+}
+
+/// Waits until the specified IP interfaces have attached to a given network interface.
+pub async fn wait_for_interfaces(luid: NET_LUID, ipv4: bool, ipv6: bool) -> io::Result<()> {
+ let (tx, rx) = futures::channel::oneshot::channel();
+
+ let mut found_ipv4 = if ipv4 { false } else { true };
+ let mut found_ipv6 = if ipv6 { false } else { true };
+
+ let mut tx = Some(tx);
+
+ let _handle = notify_ip_interface_change(
+ move |row, notification_type| {
+ if found_ipv4 && found_ipv6 {
+ return;
+ }
+ if notification_type != MibAddInstance {
+ return;
+ }
+ if row.InterfaceLuid.Value != luid.Value {
+ return;
+ }
+ match row.Family as i32 {
+ AF_INET => found_ipv4 = true,
+ AF_INET6 => found_ipv6 = true,
+ _ => (),
+ }
+ if found_ipv4 && found_ipv6 {
+ if let Some(tx) = tx.take() {
+ let _ = tx.send(());
+ }
+ }
+ },
+ AF_UNSPEC as u16,
+ )?;
+
+ // Make sure they don't already exist
+ if (!ipv4 || ip_interface_entry_exists(AF_INET as u16, &luid)?)
+ && (!ipv6 || ip_interface_entry_exists(AF_INET6 as u16, &luid)?)
+ {
+ return Ok(());
+ }
+
+ let _ = rx.await;
+ Ok(())
+}
diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
index 9b62fc47e8..60318c071b 100644
--- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs
+++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
@@ -524,6 +524,11 @@ mod test {
"mock-tunnel".to_string()
}
+ #[cfg(windows)]
+ fn get_interface_luid(&self) -> u64 {
+ 0
+ }
+
fn stop(self: Box<Self>) -> Result<(), TunnelError> {
Ok(())
}
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 199fe3e6ca..0703633b4e 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -8,6 +8,8 @@ use futures::future::abortable;
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
use std::env;
+#[cfg(windows)]
+use std::io;
use std::{
collections::HashSet,
net::SocketAddr,
@@ -54,9 +56,19 @@ pub enum Error {
#[error(display = "Failed obtain local address for the UDP socket in Udp2Tcp")]
GetLocalUdpAddress(#[error(source)] std::io::Error),
- /// Failed to setup connectivity monitor
+ /// Failed to set up connectivity monitor
#[error(display = "Connectivity monitor failed")]
ConnectivityMonitorError(#[error(source)] connectivity_check::Error),
+
+ /// Failed to set up IP interfaces.
+ #[cfg(windows)]
+ #[error(display = "Failed while waiting on IP interfaces")]
+ IpInterfacesError(#[error(source)] io::Error),
+
+ /// Failed to set IP addresses on WireGuard interface
+ #[cfg(target_os = "windows")]
+ #[error(display = "Failed to set IP addresses on WireGuard interface")]
+ SetIpAddressesError,
}
@@ -68,6 +80,8 @@ pub struct WireguardMonitor {
event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>,
close_msg_sender: mpsc::Sender<CloseMsg>,
close_msg_receiver: mpsc::Receiver<CloseMsg>,
+ #[cfg(target_os = "windows")]
+ stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>,
pinger_stop_sender: mpsc::Sender<()>,
_tcp_proxies: Vec<TcpProxy>,
}
@@ -158,18 +172,11 @@ impl WireguardMonitor {
let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?;
let iface_name = tunnel.get_interface_name().to_string();
+ #[cfg(windows)]
+ let iface_luid = tunnel.get_interface_luid();
(on_event)(TunnelEvent::InterfaceUp(iface_name.clone()));
- #[cfg(target_os = "linux")]
- route_manager
- .create_routing_rules(config.enable_ipv6)
- .map_err(Error::SetupRoutingError)?;
-
- route_manager
- .add_routes(Self::get_routes(&iface_name, &config))
- .map_err(Error::SetupRoutingError)?;
-
#[cfg(target_os = "windows")]
route_manager
.add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
@@ -177,11 +184,15 @@ impl WireguardMonitor {
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
let (pinger_tx, pinger_rx) = mpsc::channel();
+ #[cfg(target_os = "windows")]
+ let (stop_setup_tx, stop_setup_rx) = futures::channel::oneshot::channel();
let monitor = WireguardMonitor {
tunnel: Arc::new(Mutex::new(Some(tunnel))),
event_callback,
close_msg_sender,
close_msg_receiver,
+ #[cfg(target_os = "windows")]
+ stop_setup_tx: Some(stop_setup_tx),
pinger_stop_sender: pinger_tx,
_tcp_proxies: tcp_proxies,
};
@@ -191,13 +202,67 @@ impl WireguardMonitor {
let close_sender = monitor.close_msg_sender.clone();
let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
gateway,
- iface_name.to_string(),
+ iface_name.clone(),
Arc::downgrade(&monitor.tunnel),
pinger_rx,
)
.map_err(Error::ConnectivityMonitorError)?;
+ let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
+ #[cfg(windows)]
+ let runtime = route_manager.runtime_handle();
+
std::thread::spawn(move || {
+ #[cfg(windows)]
+ {
+ let iface_close_sender = close_sender.clone();
+ let enable_ipv6 = config.ipv6_gateway.is_some();
+
+ let result = runtime.block_on(async move {
+ use futures::future::FutureExt;
+ use winapi::shared::ifdef::NET_LUID;
+ let luid = NET_LUID { Value: iface_luid };
+ let setup_future = super::windows::wait_for_interfaces(luid, true, enable_ipv6);
+
+ futures::select! {
+ result = setup_future.fuse() => {
+ result.map_err(|error|
+ iface_close_sender.send(CloseMsg::SetupError(
+ Error::IpInterfacesError(error)
+ ))
+ .unwrap_or(())
+ )
+ }
+ _ = stop_setup_rx.fuse() => Err(()),
+ }
+ });
+
+ if result.is_err() {
+ return;
+ }
+ }
+
+ let setup_iface_routes = move || -> Result<()> {
+ #[cfg(target_os = "windows")]
+ if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
+ return Err(Error::SetIpAddressesError);
+ }
+
+ #[cfg(target_os = "linux")]
+ route_handle
+ .create_routing_rules(config.enable_ipv6)
+ .map_err(Error::SetupRoutingError)?;
+
+ route_handle
+ .add_routes(Self::get_routes(&iface_name, &config))
+ .map_err(Error::SetupRoutingError)
+ };
+
+ if let Err(error) = setup_iface_routes() {
+ let _ = close_sender.send(CloseMsg::SetupError(error));
+ return;
+ }
+
match connectivity_monitor.establish_connectivity() {
Ok(true) => {
(on_event)(TunnelEvent::Up(metadata));
@@ -291,9 +356,14 @@ impl WireguardMonitor {
let wait_result = match self.close_msg_receiver.recv() {
Ok(CloseMsg::PingErr) => Err(Error::TimeoutError),
Ok(CloseMsg::Stop) => Ok(()),
+ Ok(CloseMsg::SetupError(error)) => Err(error),
Err(_) => Ok(()),
};
+ #[cfg(windows)]
+ if let Some(stop_tx) = self.stop_setup_tx.take() {
+ let _ = stop_tx.send(());
+ }
let _ = self.pinger_stop_sender.send(());
self.stop_tunnel();
@@ -439,6 +509,7 @@ impl WireguardMonitor {
enum CloseMsg {
Stop,
PingErr,
+ SetupError(Error),
}
/// Close handle for a WireGuard tunnel.
@@ -458,6 +529,8 @@ impl CloseHandle {
pub(crate) trait Tunnel: Send {
fn get_interface_name(&self) -> String;
+ #[cfg(target_os = "windows")]
+ fn get_interface_luid(&self) -> u64;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
fn get_tunnel_stats(&self) -> std::result::Result<stats::Stats, TunnelError>;
#[cfg(target_os = "linux")]
@@ -522,11 +595,6 @@ pub enum TunnelError {
#[error(display = "Failed to convert adapter alias")]
InvalidAlias,
- /// Failed to set ip addresses on tunnel interface.
- #[cfg(target_os = "windows")]
- #[error(display = "Failed to set IP addresses on WireGuard interface")]
- SetIpAddressesError,
-
/// Failure to set up logging
#[error(display = "Failed to set up logging")]
LoggingError(#[error(source)] logging::Error),
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 2751fbdfbf..2145bb96d6 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -29,7 +29,7 @@ use {
type Result<T> = std::result::Result<T, TunnelError>;
#[cfg(target_os = "windows")]
-use crate::winnet::{self, add_device_ip_addresses};
+use crate::winnet;
#[cfg(not(target_os = "windows"))]
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
@@ -44,6 +44,8 @@ impl Drop for LoggingContext {
pub struct WgGoTunnel {
interface_name: String,
+ #[cfg(windows)]
+ interface_luid: u64,
handle: Option<i32>,
// holding on to the tunnel device and the log file ensures that the associated file handles
// live long enough and get closed when the tunnel is stopped
@@ -134,6 +136,7 @@ impl WgGoTunnel {
.any(|config| config.allowed_ips.iter().any(|ip| ip.is_ipv6()));
let mut alias_ptr = std::ptr::null_mut();
+ let mut interface_luid = 0u64;
let handle = unsafe {
wgTurnOn(
@@ -142,6 +145,7 @@ impl WgGoTunnel {
wait_on_ipv6 as u8,
wg_config_str.as_ptr(),
&mut alias_ptr,
+ &mut interface_luid,
Some(logging_callback),
logging_context.0 as *mut libc::c_void,
)
@@ -163,13 +167,9 @@ impl WgGoTunnel {
log::debug!("Adapter alias: {}", actual_iface_name);
- if !add_device_ip_addresses(&actual_iface_name, &config.tunnel.addresses) {
- // Todo: what kind of clean-up is required?
- return Err(TunnelError::SetIpAddressesError);
- }
-
Ok(WgGoTunnel {
interface_name: actual_iface_name,
+ interface_luid,
handle: Some(handle),
_logging_context: logging_context,
})
@@ -302,6 +302,11 @@ impl Tunnel for WgGoTunnel {
self.interface_name.clone()
}
+ #[cfg(target_os = "windows")]
+ fn get_interface_luid(&self) -> u64 {
+ self.interface_luid
+ }
+
fn get_tunnel_stats(&self) -> Result<Stats> {
let config_str = unsafe {
let ptr = wgGetConfig(self.handle.unwrap());
@@ -376,6 +381,7 @@ extern "C" {
wait_on_ipv6: u8,
settings: *const i8,
iface_name_out: *const *mut std::os::raw::c_char,
+ iface_luid_out: *mut u64,
logging_callback: Option<LoggingCallback>,
logging_context: *mut libc::c_void,
) -> i32;
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 553cf9377d..fe87d00f0f 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -167,6 +167,15 @@ impl ConnectingState {
log::debug!("WireGuard tunnel timed out");
None
}
+ error @ tunnel::Error::WireguardTunnelMonitoringError(..)
+ if !should_retry(&error) =>
+ {
+ error!(
+ "{}",
+ error.display_chain_with_msg("Tunnel has stopped unexpectedly")
+ );
+ Some(ErrorStateCause::StartTunnelError)
+ }
error => {
warn!(
"{}",
diff --git a/wireguard/libwg/libwg_windows.go b/wireguard/libwg/libwg_windows.go
index 42f27148b3..2cf04ed140 100644
--- a/wireguard/libwg/libwg_windows.go
+++ b/wireguard/libwg/libwg_windows.go
@@ -21,9 +21,7 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/wintun"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
- "github.com/mullvad/mullvadvpn-app/wireguard/libwg/interfacewatcher"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/logging"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer"
)
@@ -43,30 +41,8 @@ func init() {
}
}
-func createInterfaceWatcherEvents(waitOnIpv6 bool, tunLuid uint64) []interfacewatcher.Event {
- if waitOnIpv6 {
- return []interfacewatcher.Event{
- {
- Luid: winipcfg.LUID(tunLuid),
- Family: windows.AF_INET,
- },
- interfacewatcher.Event {
- Luid: winipcfg.LUID(tunLuid),
- Family: windows.AF_INET6,
- },
- }
- } else {
- return []interfacewatcher.Event{
- {
- Luid: winipcfg.LUID(tunLuid),
- Family: windows.AF_INET,
- },
- }
- }
-}
-
//export wgTurnOn
-func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, cIfaceNameOut **C.char, logSink LogSink, logContext LogContext) int32 {
+func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, cIfaceNameOut **C.char, cLuidOut *uint64, logSink LogSink, logContext LogContext) int32 {
logger := logging.NewLogger(logSink, logContext)
if cIfaceNameOut != nil {
*cIfaceNameOut = nil
@@ -88,13 +64,6 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c
// {AFE43773-E1F8-4EBB-8536-576AB86AFE9A}
networkId := windows.GUID{0xafe43773, 0xe1f8, 0x4ebb, [8]byte{0x85, 0x36, 0x57, 0x6a, 0xb8, 0x6a, 0xfe, 0x9a}}
- watcher, err := interfacewatcher.NewWatcher()
- if err != nil {
- logger.Errorf("%s\n", err)
- return ERROR_GENERAL_FAILURE
- }
- defer watcher.Destroy()
-
if tun.WintunPool != MullvadPool {
tun.WintunPool = MullvadPool
}
@@ -132,18 +101,6 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c
device.Up()
- interfaces := createInterfaceWatcherEvents(waitOnIpv6, nativeTun.LUID())
-
- logger.Verbosef("Waiting for interfaces to attach\n")
-
- if !watcher.Join(interfaces, 5) {
- logger.Errorf("Failed to wait for IP interfaces to become available\n")
- device.Close()
- return ERROR_GENERAL_FAILURE
- }
-
- logger.Verbosef("Interfaces OK\n")
-
context := tunnelcontainer.Context{
Device: device,
Logger: logger,
@@ -159,6 +116,9 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c
if cIfaceNameOut != nil {
*cIfaceNameOut = C.CString(actualInterfaceName)
}
+ if cLuidOut != nil {
+ *cLuidOut = nativeTun.LUID()
+ }
return handle
}