summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-daemon/src/geoip.rs13
-rw-r--r--mullvad-daemon/src/wireguard.rs4
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs157
3 files changed, 168 insertions, 6 deletions
diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs
index 953780b6f2..1047d3bb15 100644
--- a/mullvad-daemon/src/geoip.rs
+++ b/mullvad-daemon/src/geoip.rs
@@ -1,6 +1,7 @@
use futures::join;
use mullvad_rpc::{self, rest::RequestServiceHandle};
use mullvad_types::location::{AmIMullvad, GeoIpLocation};
+use talpid_types::ErrorExt;
const URI_V4: &str = "https://ipv4.am.i.mullvad.net/json";
const URI_V6: &str = "https://ipv6.am.i.mullvad.net/json";
@@ -11,7 +12,7 @@ pub async fn send_location_request(
let v4_sender = request_sender.clone();
let v4_future = async move {
let location = send_location_request_internal(URI_V4, v4_sender).await?;
- Ok(GeoIpLocation::from(location))
+ Ok::<GeoIpLocation, mullvad_rpc::rest::Error>(GeoIpLocation::from(location))
};
let v6_sender = request_sender.clone();
let v6_future = async move {
@@ -28,11 +29,17 @@ pub async fn send_location_request(
Ok(v4)
}
(Ok(v4), Err(e)) => {
- log::debug!("Unable to fetch IPv6 GeoIP location: {}", e);
+ log::debug!(
+ "{}",
+ e.display_chain_with_msg("Unable to fetch IPv6 GeoIP location")
+ );
Ok(v4)
}
(Err(e), Ok(v6)) => {
- log::debug!("Unable to fetch IPv4 GeoIP location: {}", e);
+ log::debug!(
+ "{}",
+ e.display_chain_with_msg("Unable to fetch IPv4 GeoIP location")
+ );
Ok(v6)
}
(Err(e_v4), Err(_)) => Err(e_v4),
diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs
index 28595fe1da..4ebf6e27d1 100644
--- a/mullvad-daemon/src/wireguard.rs
+++ b/mullvad-daemon/src/wireguard.rs
@@ -234,8 +234,8 @@ impl KeyManager {
timeout: Option<Duration>,
) -> Box<
dyn FnMut() -> Pin<
- Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>,
- > + Send,
+ Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>,
+ > + Send,
> {
let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
let public_key = private_key.public_key();
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index c413c8235a..7f299d8e73 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -29,6 +29,8 @@ use std::{
};
#[cfg(target_os = "linux")]
use std::{collections::HashSet, net::IpAddr};
+#[cfg(windows)]
+use std::{ffi::OsStr, os::windows::ffi::OsStrExt, time::Instant};
use talpid_types::net::openvpn;
#[cfg(target_os = "linux")]
use talpid_types::ErrorExt;
@@ -38,7 +40,17 @@ use which;
#[cfg(windows)]
use widestring::U16CString;
#[cfg(windows)]
-use winapi::shared::{guiddef::GUID, winerror::ERROR_FILE_NOT_FOUND};
+use winapi::shared::{
+ guiddef::GUID,
+ ifdef::NET_LUID,
+ netioapi::{
+ ConvertInterfaceAliasToLuid, FreeMibTable, GetUnicastIpAddressEntry,
+ GetUnicastIpAddressTable, MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE,
+ },
+ nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE},
+ winerror::{ERROR_FILE_NOT_FOUND, NO_ERROR},
+ ws2def::AF_UNSPEC,
+};
#[cfg(windows)]
mod windows;
@@ -58,6 +70,11 @@ const ADAPTER_GUID: GUID = GUID {
Data4: [0x85, 0x36, 0x57, 0x6A, 0xB8, 0x6A, 0xFE, 0x9A],
};
+#[cfg(windows)]
+const DEVICE_READY_TIMEOUT: Duration = Duration::from_secs(5);
+#[cfg(windows)]
+const DEVICE_CHECK_INTERVAL: Duration = Duration::from_millis(100);
+
/// Results from fallible operations on the OpenVPN tunnel.
pub type Result<T> = std::result::Result<T, Error>;
@@ -102,6 +119,31 @@ pub enum Error {
#[error(display = "Failed to delete existing Wintun adapter")]
WintunDeleteExistingError(#[error(source)] io::Error),
+ /// Error returned from `ConvertInterfaceAliasToLuid`
+ #[cfg(windows)]
+ #[error(display = "Cannot find LUID for virtual adapter")]
+ NoDeviceLuid(#[error(source)] io::Error),
+
+ /// Error returned from `GetUnicastIpAddressTable`/`GetUnicastIpAddressEntry`
+ #[cfg(windows)]
+ #[error(display = "Cannot find LUID for virtual adapter")]
+ ObtainUnicastAddress(#[error(source)] io::Error),
+
+ /// `GetUnicastIpAddressTable` contained no addresses for the tunnel interface
+ #[cfg(windows)]
+ #[error(display = "Found no addresses for virtual adapter")]
+ NoUnicastAddress,
+
+ /// Unexpected DAD state returned for a unicast address
+ #[cfg(windows)]
+ #[error(display = "Unexpected DAD state")]
+ DadStateError(#[error(source)] DadStateError),
+
+ /// DAD check failed.
+ #[cfg(windows)]
+ #[error(display = "Timed out waiting on tunnel device")]
+ DeviceReadyTimeout,
+
/// OpenVPN process died unexpectedly
#[error(display = "OpenVPN process died unexpectedly")]
ChildProcessDied,
@@ -262,6 +304,11 @@ impl OpenVpnMonitor<OpenVpnCommand> {
if let Some(ref file_path) = &proxy_auth_file_path {
let _ = fs::remove_file(file_path);
}
+
+ #[cfg(windows)]
+ tokio::task::block_in_place(|| {
+ wait_for_ready_device(env.get("dev").expect("missing tunnel alias")).unwrap();
+ });
}
match TunnelEvent::from_openvpn_event(event, &env) {
Some(tunnel_event) => on_event(tunnel_event),
@@ -888,6 +935,114 @@ mod event_server {
}
}
+#[cfg(windows)]
+fn wait_for_ready_device(alias: &str) -> Result<()> {
+ // Obtain luid for alias
+ let alias_wide: Vec<u16> = OsStr::new(alias)
+ .encode_wide()
+ .chain(std::iter::once(0u16))
+ .collect();
+
+ let mut luid: NET_LUID = unsafe { std::mem::zeroed() };
+ let status = unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) };
+ if status != NO_ERROR {
+ return Err(Error::NoDeviceLuid(io::Error::last_os_error()));
+ }
+
+ // Obtain unicast IP addresses
+ let mut unicast_rows = vec![];
+
+ unsafe {
+ let mut unicast_table: *mut MIB_UNICASTIPADDRESS_TABLE = std::ptr::null_mut();
+
+ let status = GetUnicastIpAddressTable(AF_UNSPEC as u16, &mut unicast_table);
+ if status != NO_ERROR {
+ return Err(Error::ObtainUnicastAddress(io::Error::last_os_error()));
+ }
+
+ if (*unicast_table).NumEntries == 0 {
+ FreeMibTable(unicast_table as *mut _);
+ return Err(Error::NoUnicastAddress);
+ }
+
+ let first_row = &(*unicast_table).Table[0] as *const MIB_UNICASTIPADDRESS_ROW;
+
+ for i in 0..(*unicast_table).NumEntries {
+ let row = first_row.offset(i as isize);
+ if (*row).InterfaceLuid.Value != luid.Value {
+ continue;
+ }
+ unicast_rows.push(*row);
+ }
+
+ FreeMibTable(unicast_table as *mut _);
+ }
+
+ // Poll DAD status using GetUnicastIpAddressEntry
+ // https://docs.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-createunicastipaddressentry
+
+ let deadline = Instant::now() + DEVICE_READY_TIMEOUT;
+ while Instant::now() < deadline {
+ let mut ready = true;
+
+ for row in &mut unicast_rows {
+ let status = unsafe { GetUnicastIpAddressEntry(row as *mut _) };
+ if status != NO_ERROR {
+ return Err(Error::ObtainUnicastAddress(io::Error::last_os_error()));
+ }
+ if row.DadState == IpDadStateTentative {
+ ready = false;
+ break;
+ }
+ if row.DadState != IpDadStatePreferred {
+ return Err(Error::DadStateError(DadStateError::from(row.DadState)));
+ }
+ }
+
+ if ready {
+ return Ok(());
+ }
+ std::thread::sleep(DEVICE_CHECK_INTERVAL);
+ }
+
+ Err(Error::DeviceReadyTimeout)
+}
+
+/// Handles cases where there DAD state is neither tentative nor preferred.
+#[cfg(windows)]
+#[derive(err_derive::Error, Debug)]
+pub enum DadStateError {
+ /// Invalid DAD state.
+ #[error(display = "Invalid DAD state")]
+ Invalid,
+
+ /// Duplicate unicast address.
+ #[error(display = "A duplicate IP address was detected")]
+ Duplicate,
+
+ /// Deprecated unicast address.
+ #[error(display = "The IP address has been deprecated")]
+ Deprecated,
+
+ /// Unknown DAD state constant.
+ #[error(display = "Unknown DAD state: {}", _0)]
+ Unknown(u32),
+}
+
+#[cfg(windows)]
+#[allow(non_upper_case_globals)]
+impl From<NL_DAD_STATE> for DadStateError {
+ fn from(state: NL_DAD_STATE) -> DadStateError {
+ use winapi::shared::nldef::*;
+ match state {
+ IpDadStateInvalid => DadStateError::Invalid,
+ IpDadStateDuplicate => DadStateError::Duplicate,
+ IpDadStateDeprecated => DadStateError::Deprecated,
+ other => DadStateError::Unknown(other),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {