summaryrefslogtreecommitdiffhomepage
path: root/talpid-core/src
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2019-06-04 13:08:52 +0200
committerOdd Stranne <odd@mullvad.net>2019-11-25 13:49:41 +0100
commitd3aa9cfe61c804767375599138bbf756bb07cb84 (patch)
tree21374abd907969eb664063c5ad469cab45f13da1 /talpid-core/src
parent61f840bc4827863e1c7d5ca281161508ded401c6 (diff)
downloadmullvadvpn-d3aa9cfe61c804767375599138bbf756bb07cb84.tar.xz
mullvadvpn-d3aa9cfe61c804767375599138bbf756bb07cb84.zip
Add Windows support for WireGuard in 'talpid-core'
Diffstat (limited to 'talpid-core/src')
-rw-r--r--talpid-core/src/dns/windows/mod.rs2
-rw-r--r--talpid-core/src/firewall/windows.rs48
-rw-r--r--talpid-core/src/lib.rs1
-rw-r--r--talpid-core/src/ping_monitor/mod.rs2
-rw-r--r--talpid-core/src/ping_monitor/win.rs45
-rw-r--r--talpid-core/src/routing/mod.rs38
-rw-r--r--talpid-core/src/routing/windows.rs72
-rw-r--r--talpid-core/src/tunnel/mod.rs16
-rw-r--r--talpid-core/src/tunnel/tun_provider/mod.rs4
-rw-r--r--talpid-core/src/tunnel/tun_provider/windows.rs13
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs20
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs131
-rw-r--r--talpid-core/src/winnet.rs303
13 files changed, 649 insertions, 46 deletions
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index d04e5b8b00..beaaba2ee9 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -159,7 +159,7 @@ type ErrorSink = extern "system" fn(
);
#[allow(non_snake_case)]
-extern "system" {
+extern "stdcall" {
#[link_name = "WinDns_Initialize"]
pub fn WinDns_Initialize(
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index 67b37713d2..bee16fee3a 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -85,12 +85,17 @@ impl FirewallT for Firewall {
match policy {
FirewallPolicy::Connecting {
peer_endpoint,
- // TODO: Allow ICMP traffic to a list of hosts for wireguard
- pingable_hosts: _,
+ pingable_hosts,
allow_lan,
} => {
let cfg = &WinFwSettings::new(allow_lan);
- self.set_connecting_state(&peer_endpoint, &cfg)
+ // TODO: Determine interface alias at runtime
+ self.set_connecting_state(
+ &peer_endpoint,
+ &cfg,
+ "wg-mullvad".to_string(),
+ &pingable_hosts,
+ )
}
FirewallPolicy::Connected {
peer_endpoint,
@@ -128,6 +133,8 @@ impl Firewall {
&mut self,
endpoint: &Endpoint,
winfw_settings: &WinFwSettings,
+ _tunnel_iface_alias: String,
+ pingable_hosts: &Vec<IpAddr>,
) -> Result<(), Error> {
trace!("Applying 'connecting' firewall policy");
let ip_str = Self::widestring_ip(endpoint.address.ip());
@@ -139,7 +146,31 @@ impl Firewall {
protocol: WinFwProt::from(endpoint.protocol),
};
- unsafe { WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay).into_result() }
+ if pingable_hosts.is_empty() {
+ unsafe {
+ return WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, ptr::null())
+ .into_result();
+ }
+ }
+
+ let pingable_addresses = pingable_hosts
+ .iter()
+ .map(|ip| Self::widestring_ip(*ip))
+ .collect::<Vec<_>>();
+ let pingable_address_ptrs = pingable_addresses
+ .iter()
+ .map(|ip| ip.as_ptr())
+ .collect::<Vec<_>>();
+
+ let pingable_hosts = WinFwPingableHosts {
+ interfaceAlias: ptr::null(),
+ addresses: pingable_address_ptrs.as_ptr(),
+ num_addresses: pingable_addresses.len(),
+ };
+
+ unsafe {
+ WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, &pingable_hosts).into_result()
+ }
}
fn widestring_ip(ip: IpAddr) -> WideCString {
@@ -250,6 +281,14 @@ mod winfw {
}
}
+ #[repr(C)]
+ pub struct WinFwPingableHosts {
+ // a null pointer implies that all interfaces will be able to ping the supplied addresses
+ pub interfaceAlias: *const libc::wchar_t,
+ pub addresses: *const *const libc::wchar_t,
+ pub num_addresses: usize,
+ }
+
ffi_error!(InitializationResult, Error::Initialization);
ffi_error!(DeinitializationResult, Error::Deinitialization);
ffi_error!(ApplyConnectingResult, Error::ApplyingConnectingPolicy);
@@ -280,6 +319,7 @@ mod winfw {
pub fn WinFw_ApplyPolicyConnecting(
settings: &WinFwSettings,
relay: &WinFwRelay,
+ pingable_hosts: *const WinFwPingableHosts,
) -> ApplyConnectingResult;
#[link_name = "WinFw_ApplyPolicyConnected"]
diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs
index 5fdbd3030d..cd88277d9d 100644
--- a/talpid-core/src/lib.rs
+++ b/talpid-core/src/lib.rs
@@ -23,7 +23,6 @@ mod winnet;
#[cfg(any(target_os = "linux", target_os = "macos"))]
/// Working with IP interface devices
pub mod network_interface;
-#[cfg(not(windows))]
/// Abstraction over operating system routing table.
pub mod routing;
diff --git a/talpid-core/src/ping_monitor/mod.rs b/talpid-core/src/ping_monitor/mod.rs
index 8cdc766ad4..20993a8fb5 100644
--- a/talpid-core/src/ping_monitor/mod.rs
+++ b/talpid-core/src/ping_monitor/mod.rs
@@ -7,4 +7,4 @@ mod imp;
#[path = "win.rs"]
mod imp;
-pub use imp::{monitor_ping, ping, Error};
+pub use imp::{monitor_ping, ping, Error, Pinger};
diff --git a/talpid-core/src/ping_monitor/win.rs b/talpid-core/src/ping_monitor/win.rs
index f4540dddbd..6c2018c253 100644
--- a/talpid-core/src/ping_monitor/win.rs
+++ b/talpid-core/src/ping_monitor/win.rs
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-// TODO: Remove lint exemption once ping monitor is used on Windows
use pnet_packet::{
icmp::{
self,
@@ -18,6 +16,8 @@ use std::{
time::{Duration, Instant},
};
+const SEND_RETRY_ATTEMPTS: u32 = 10;
+
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
@@ -40,10 +40,10 @@ pub enum Error {
pub fn monitor_ping(
ip: Ipv4Addr,
timeout_secs: u16,
- _interface: &str,
+ interface: &str,
close_receiver: mpsc::Receiver<()>,
) -> Result<()> {
- let mut pinger = Pinger::new(ip)?;
+ let mut pinger = Pinger::new(ip, interface)?;
while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() {
let start = Instant::now();
pinger.send_ping(Duration::from_secs(timeout_secs.into()))?;
@@ -57,8 +57,8 @@ pub fn monitor_ping(
Ok(())
}
-pub fn ping(ip: Ipv4Addr, timeout_secs: u16, _interface: &str) -> Result<()> {
- Pinger::new(ip)?.send_ping(Duration::from_secs(timeout_secs.into()))
+pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<()> {
+ Pinger::new(ip, interface)?.send_ping(Duration::from_secs(timeout_secs.into()))
}
type Result<T> = std::result::Result<T, Error>;
@@ -71,10 +71,12 @@ pub struct Pinger {
}
impl Pinger {
- pub fn new(addr: Ipv4Addr) -> Result<Self> {
+ pub fn new(addr: Ipv4Addr, _interface_name: &str) -> Result<Self> {
let sock = Socket::new(Domain::ipv4(), Type::raw(), Some(Protocol::icmpv4()))
.map_err(Error::OpenError)?;
sock.set_nonblocking(true).map_err(Error::OpenError)?;
+
+
Ok(Self {
sock,
id: rand::random(),
@@ -87,12 +89,35 @@ impl Pinger {
pub fn send_ping(&mut self, timeout: Duration) -> Result<()> {
let dest = SocketAddr::new(IpAddr::from(self.addr), 0);
let request = self.next_ping_request();
- self.sock
- .send_to(request.packet(), &dest.into())
- .map_err(Error::WriteError)?;
+ self.send_ping_request(&request, dest.into())?;
self.wait_for_response(Instant::now() + timeout, &request)
}
+ fn send_ping_request(
+ &mut self,
+ request: &EchoRequestPacket<'static>,
+ destination: SocketAddr,
+ ) -> Result<()> {
+ let mut tries = 0;
+ let mut result = Ok(());
+ while tries < SEND_RETRY_ATTEMPTS {
+ match self.sock.send_to(request.packet(), &destination.into()) {
+ Ok(_) => {
+ return Ok(());
+ }
+ Err(err) => {
+ if Some(10065) != err.raw_os_error() {
+ return Err(Error::WriteError(err));
+ }
+ result = Err(Error::WriteError(err));
+ }
+ }
+ thread::sleep(Duration::from_secs(1));
+ tries += 1;
+ }
+ result
+ }
+
/// returns the next ping packet
fn next_ping_request(&mut self) -> EchoRequestPacket<'static> {
const ICMP_HEADER_LENGTH: usize = 8;
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs
index 6f5b73e014..4636b0b27d 100644
--- a/talpid-core/src/routing/mod.rs
+++ b/talpid-core/src/routing/mod.rs
@@ -1,4 +1,5 @@
#![cfg_attr(target_os = "android", allow(dead_code))]
+#![cfg_attr(target_os = "windows", allow(dead_code))]
// TODO: remove the allow(dead_code) for android once it's up to scratch.
use futures::{sync::oneshot, Future};
use ipnetwork::IpNetwork;
@@ -16,6 +17,12 @@ mod imp;
#[path = "android.rs"]
mod imp;
+#[cfg(target_os = "windows")]
+#[path = "windows.rs"]
+mod imp;
+#[cfg(target_os = "windows")]
+use crate::winnet;
+
pub use imp::Error as PlatformError;
/// Errors that can be encountered whilst initializing RouteManager
@@ -37,6 +44,8 @@ pub enum Error {
/// the route will be adjusted dynamically when the default route changes.
pub struct RouteManager {
tx: Option<oneshot::Sender<oneshot::Sender<()>>>,
+ #[cfg(target_os = "windows")]
+ callback_handles: Vec<winnet::WinNetCallbackHandle>,
}
impl RouteManager {
@@ -61,12 +70,34 @@ impl RouteManager {
},
);
match start_rx.wait() {
- Ok(Ok(())) => Ok(Self { tx: Some(tx) }),
+ Ok(Ok(())) => Ok(Self {
+ tx: Some(tx),
+ #[cfg(target_os = "windows")]
+ callback_handles: vec![],
+ }),
Ok(Err(e)) => Err(e),
Err(_) => Err(Error::RoutingManagerThreadPanic),
}
}
+ /// Sets a callback that is called whenever the default route changes.
+ #[cfg(target_os = "windows")]
+ pub fn set_default_route_callback<T: 'static>(
+ &mut self,
+ callback: Option<winnet::DefaultRouteChangedCallback>,
+ context: T,
+ ) {
+ match winnet::set_default_route_change_callback(callback, context) {
+ Err(_e) => {
+ // not sure if this should panic
+ log::error!("Failed to add callback!");
+ }
+ Ok(handle) => {
+ self.callback_handles.push(handle);
+ }
+ }
+ }
+
/// Stops RouteManager and removes all of the applied routes.
pub fn stop(&mut self) {
if let Some(tx) = self.tx.take() {
@@ -85,6 +116,11 @@ impl RouteManager {
impl Drop for RouteManager {
fn drop(&mut self) {
+ // Ensuring callbacks are removed before the route manager is stopped
+ #[cfg(target_os = "windows")]
+ {
+ self.callback_handles.clear();
+ }
self.stop();
}
}
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
new file mode 100644
index 0000000000..684d1a3184
--- /dev/null
+++ b/talpid-core/src/routing/windows.rs
@@ -0,0 +1,72 @@
+use super::NetNode;
+use crate::winnet;
+use futures::{sync::oneshot, Async, Future};
+use ipnetwork::IpNetwork;
+use std::collections::HashMap;
+
+/// Windows routing errors.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// Failure to apply a route
+ #[error(display = "Failed to start route manager")]
+ FailedToStartManager,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+pub struct RouteManagerImpl {
+ shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+}
+
+impl RouteManagerImpl {
+ pub fn new(
+ required_routes: HashMap<IpNetwork, NetNode>,
+ shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+ ) -> Result<Self> {
+ let routes: Vec<_> = required_routes
+ .iter()
+ .map(|(destination, node)| {
+ let destination = winnet::WinNetIpNetwork::from(*destination);
+ match node {
+ NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
+ NetNode::RealNode(node) => {
+ winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
+ }
+ }
+ })
+ .collect();
+
+ if !winnet::activate_routing_manager(&routes) {
+ return Err(Error::FailedToStartManager);
+ }
+
+
+ Ok(Self { shutdown_rx })
+ }
+}
+
+impl Drop for RouteManagerImpl {
+ fn drop(&mut self) {
+ if !winnet::deactivate_routing_manager() {
+ log::error!("Failed to deactivate routing manager");
+ }
+ }
+}
+
+
+impl Future for RouteManagerImpl {
+ type Item = ();
+ type Error = Error;
+ fn poll(&mut self) -> Result<Async<()>> {
+ match self.shutdown_rx.poll() {
+ Ok(Async::Ready(result_tx)) => {
+ if let Err(_e) = result_tx.send(()) {
+ log::error!("Receiver already down");
+ }
+ Ok(Async::Ready(()))
+ }
+ Ok(Async::NotReady) => Ok(Async::NotReady),
+ Err(_) => Ok(Async::Ready(())),
+ }
+ }
+}
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 6eaa77f8b8..c3f22a2560 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -9,15 +9,12 @@ use std::{
};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
-#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
-use talpid_types::net::wireguard as wireguard_types;
-use talpid_types::net::{GenericTunnelOptions, TunnelParameters};
+use talpid_types::net::{wireguard as wireguard_types, GenericTunnelOptions, TunnelParameters};
/// A module for all OpenVPN related tunnel management.
#[cfg(not(target_os = "android"))]
pub mod openvpn;
-#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
pub mod wireguard;
/// A module for low level platform specific tunnel device management.
@@ -45,7 +42,6 @@ pub enum Error {
RotateLogError(#[error(source)] crate::logging::RotateLogError),
/// Failure to build Wireguard configuration.
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
#[error(display = "Failed to configure Wireguard with the given parameters")]
WireguardConfigError(#[error(source)] self::wireguard::config::Error),
@@ -55,7 +51,6 @@ pub enum Error {
OpenVpnTunnelMonitoringError(#[error(source)] openvpn::Error),
/// There was an error listening for events from the Wireguard tunnel
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
#[error(display = "Failed while listening for events from the Wireguard tunnel")]
WireguardTunnelMonitoringError(#[error(source)] wireguard::Error),
}
@@ -161,16 +156,12 @@ impl TunnelMonitor {
#[cfg(target_os = "android")]
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
TunnelParameters::Wireguard(config) => {
Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider)
}
- #[cfg(windows)]
- TunnelParameters::Wireguard(_) => Err(Error::UnsupportedPlatform),
}
}
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
fn start_wireguard_tunnel<L>(
params: &wireguard_types::TunnelParameters,
log: Option<PathBuf>,
@@ -254,7 +245,6 @@ pub enum CloseHandle {
#[cfg(not(target_os = "android"))]
/// OpenVpn close handle
OpenVpn(openvpn::OpenVpnCloseHandle),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
/// Wireguard close handle
Wireguard(wireguard::CloseHandle),
}
@@ -265,7 +255,6 @@ impl CloseHandle {
match self {
#[cfg(not(target_os = "android"))]
CloseHandle::OpenVpn(handle) => handle.close(),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
CloseHandle::Wireguard(mut handle) => {
handle.close();
Ok(())
@@ -277,7 +266,6 @@ impl CloseHandle {
enum InternalTunnelMonitor {
#[cfg(not(target_os = "android"))]
OpenVpn(openvpn::OpenVpnMonitor),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
Wireguard(wireguard::WireguardMonitor),
}
@@ -286,7 +274,6 @@ impl InternalTunnelMonitor {
match self {
#[cfg(not(target_os = "android"))]
InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()),
}
}
@@ -295,7 +282,6 @@ impl InternalTunnelMonitor {
match self {
#[cfg(not(target_os = "android"))]
InternalTunnelMonitor::OpenVpn(tun) => tun.wait()?,
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
InternalTunnelMonitor::Wireguard(tun) => tun.wait()?,
}
diff --git a/talpid-core/src/tunnel/tun_provider/mod.rs b/talpid-core/src/tunnel/tun_provider/mod.rs
index c6701ceac9..f0bf8f69b6 100644
--- a/talpid-core/src/tunnel/tun_provider/mod.rs
+++ b/talpid-core/src/tunnel/tun_provider/mod.rs
@@ -29,6 +29,10 @@ cfg_if! {
}
}
+/// Windows tunnel
+#[cfg(target_os = "windows")]
+pub mod windows;
+
/// Generic tunnel device.
///
/// Must be associated with a platform specific file descriptor representing the device.
diff --git a/talpid-core/src/tunnel/tun_provider/windows.rs b/talpid-core/src/tunnel/tun_provider/windows.rs
new file mode 100644
index 0000000000..9a114bf4b7
--- /dev/null
+++ b/talpid-core/src/tunnel/tun_provider/windows.rs
@@ -0,0 +1,13 @@
+use super::Tun;
+
+/// Windows tunnel implementation
+pub struct WinTun {
+ /// Name of tunnel interface
+ pub interface_name: String,
+}
+
+impl Tun for WinTun {
+ fn interface_name(&self) -> &str {
+ &self.interface_name
+ }
+}
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index c9b6988f7b..1f7f11052f 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -43,6 +43,11 @@ pub enum Error {
#[error(display = "Failed to stop wireguard tunnel - {}", status)]
StopWireguardError { status: i32 },
+ /// Failed to set ip addresses on tunnel interface.
+ #[cfg(target_os = "windows")]
+ #[error(display = "Failed to set IP addresses on WireGuard interface")]
+ SetIpAddressesError,
+
/// Failed to set up routing.
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] crate::routing::Error),
@@ -97,8 +102,13 @@ impl WireguardMonitor {
Self::get_tunnel_routes(config),
)?);
let iface_name = tunnel.get_interface_name();
- let route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config))
+ let mut route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config))
.map_err(Error::SetupRoutingError)?;
+
+ #[cfg(target_os = "windows")]
+ route_handle
+ .set_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
+
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
let (pinger_tx, pinger_rx) = mpsc::channel();
@@ -121,12 +131,10 @@ impl WireguardMonitor {
Ok(()) => {
(on_event)(TunnelEvent::Up(metadata));
- match ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx)
+ if let Err(error) =
+ ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx)
{
- Ok(()) => return,
- Err(error) => {
- log::trace!("{}", error.display_chain_with_msg("Ping monitor failed"));
- }
+ log::trace!("{}", error.display_chain_with_msg("Ping monitor failed"));
}
}
Err(error) => {
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 4e3b9c45fd..75442ce56c 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -1,12 +1,35 @@
use super::{Config, Error, Result, Tunnel};
-use crate::tunnel::tun_provider::{Tun, TunConfig, TunProvider};
+use crate::tunnel::tun_provider::{Tun, TunProvider};
use ipnetwork::IpNetwork;
-use std::{ffi::CString, net::IpAddr, os::unix::io::RawFd, path::Path, ptr};
+use std::{ffi::CString, fs, path::Path};
+
+#[cfg(not(target_os = "windows"))]
+use std::ptr;
+
+#[cfg(not(target_os = "windows"))]
+use crate::tunnel::tun_provider::TunConfig;
+
+#[cfg(not(target_os = "windows"))]
+use std::net::IpAddr;
+
+#[cfg(not(target_os = "windows"))]
+use std::os::unix::io::{AsRawFd, RawFd};
+
+#[cfg(target_os = "windows")]
+use std::os::windows::io::AsRawHandle;
+
+#[cfg(target_os = "windows")]
+use crate::tunnel::tun_provider::windows::WinTun;
+
#[cfg(target_os = "android")]
use talpid_types::BoxedError;
+#[cfg(not(target_os = "windows"))]
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
+#[cfg(target_os = "windows")]
+use crate::winnet::{self, add_device_ip_addresses};
+
pub struct WgGoTunnel {
interface_name: String,
handle: Option<i32>,
@@ -16,6 +39,7 @@ pub struct WgGoTunnel {
}
impl WgGoTunnel {
+ #[cfg(not(target_os = "windows"))]
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
@@ -66,6 +90,84 @@ impl WgGoTunnel {
})
}
+ #[cfg(target_os = "windows")]
+ pub fn start_tunnel(
+ config: &Config,
+ log_path: Option<&Path>,
+ _tun_provider: &dyn TunProvider,
+ _routes: impl Iterator<Item = IpNetwork>,
+ ) -> Result<Self> {
+ let log_file = prepare_log_file(log_path)?;
+ let wg_config_str = config.to_userspace_format();
+ let iface_name: String = "wg-mullvad".to_string();
+ let cstr_iface_name =
+ CString::new(iface_name.as_bytes()).map_err(Error::InterfaceNameError)?;
+
+ let handle = unsafe {
+ wgTurnOn(
+ cstr_iface_name.as_ptr(),
+ config.mtu as i64,
+ wg_config_str.as_ptr(),
+ log_file.as_raw_handle(),
+ WG_GO_LOG_DEBUG,
+ )
+ };
+
+ if handle < 0 {
+ return Err(Error::FatalStartWireguardError);
+ }
+
+ if !add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
+ // Todo: what kind of clean-up is required?
+ return Err(Error::SetIpAddressesError);
+ }
+
+ Ok(WgGoTunnel {
+ interface_name: iface_name.clone(),
+ handle: Some(handle),
+ _tunnel_device: Box::new(WinTun {
+ interface_name: iface_name.clone(),
+ }),
+ //_log_file: log_file,
+ })
+ }
+
+ // Callback to be used to rebind the tunnel sockets when the default route changes
+ #[cfg(target_os = "windows")]
+ pub unsafe extern "system" fn default_route_changed_callback(
+ event_type: winnet::WinNetDefaultRouteChangeEventType,
+ address_family: winnet::WinNetIpFamily,
+ interface_luid: u64,
+ _ctx: *mut libc::c_void,
+ ) {
+ use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex};
+ let iface_idx: u32 = match event_type {
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ let mut iface_idx = 0u32;
+ let iface_luid = NET_LUID {
+ Value: interface_luid,
+ };
+ let status =
+ ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _);
+ if status != 0 {
+ log::error!(
+ "Failed to convert interface LUID to interface index - {} - {}",
+ status,
+ std::io::Error::last_os_error()
+ );
+ return;
+ }
+ iface_idx
+ }
+ // if there is no new default route, specify 0 as the interface index
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => 0,
+ };
+
+ wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx);
+ }
+
+
+ #[cfg(not(target_os = "windows"))]
fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig {
let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)];
dns_servers.extend(config.ipv6_gateway.map(IpAddr::V6));
@@ -102,6 +204,7 @@ impl WgGoTunnel {
Ok(())
}
+ #[cfg(not(target_os = "windows"))]
fn get_tunnel(
tun_provider: &mut dyn TunProvider,
config: &Config,
@@ -138,6 +241,13 @@ impl Drop for WgGoTunnel {
}
}
+#[cfg(target_os = "windows")]
+static NULL_DEVICE: &str = "NUL";
+
+fn prepare_log_file(log_path: Option<&Path>) -> Result<fs::File> {
+ fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError)
+}
+
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> &str {
&self.interface_name
@@ -158,8 +268,6 @@ type WgLogLevel = i32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
const WG_GO_LOG_DEBUG: WgLogLevel = 3;
-#[cfg_attr(target_os = "android", link(name = "wg", kind = "dylib"))]
-#[cfg_attr(not(target_os = "android"), link(name = "wg", kind = "static"))]
extern "C" {
// Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors
// for the tunnel device and logging.
@@ -167,6 +275,7 @@ extern "C" {
// Positive return values are tunnel handles for this specific wireguard tunnel instance.
// Negative return values signify errors. All error codes are opaque.
#[cfg_attr(target_os = "android", link_name = "wgTurnOnWithFdAndroid")]
+ #[cfg(not(target_os = "windows"))]
fn wgTurnOnWithFd(
iface_name: *const i8,
mtu: isize,
@@ -176,6 +285,16 @@ extern "C" {
logLevel: WgLogLevel,
) -> i32;
+ // Windows
+ #[cfg(target_os = "windows")]
+ fn wgTurnOn(
+ iface_name: *const i8,
+ mtu: i64,
+ settings: *const i8,
+ log_fd: Fd,
+ logLevel: WgLogLevel,
+ ) -> i32;
+
// Pass a handle that was created by wgTurnOnWithFd to stop a wireguard tunnel.
fn wgTurnOff(handle: i32) -> i32;
@@ -186,4 +305,8 @@ extern "C" {
// Returns the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
fn wgGetSocketV6(handle: i32) -> Fd;
+
+ // Rebind tunnel socket when network interfaces change
+ #[cfg(target_os = "windows")]
+ fn wgRebindTunnelSocket(family: u16, interfaceIndex: u32);
}
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index c6bd6c6115..6b42f6d810 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -2,8 +2,14 @@ use self::api::*;
pub use self::api::{
LogSink, WinNet_ActivateConnectivityMonitor, WinNet_DeactivateConnectivityMonitor,
};
+use crate::routing::Node;
+use ipnetwork::IpNetwork;
use libc::{c_char, c_void, wchar_t};
-use std::{ffi::OsString, ptr};
+use std::{
+ ffi::{CStr, OsString},
+ net::IpAddr,
+ ptr,
+};
use widestring::WideCString;
/// Errors that this module may produce.
@@ -41,7 +47,6 @@ pub enum LogSeverity {
/// Logging callback used with `winnet.dll`.
pub extern "system" fn log_sink(severity: LogSeverity, msg: *const c_char, _ctx: *mut c_void) {
- use std::ffi::CStr;
if msg.is_null() {
log::error!("Log message from FFI boundary is NULL");
} else {
@@ -119,9 +124,264 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> {
Ok(alias.to_os_string())
}
+#[repr(C)]
+struct WinNetIpType(u32);
+
+const WINNET_IPV4: u32 = 0;
+const WINNET_IPV6: u32 = 1;
+
+impl WinNetIpType {
+ pub fn v4() -> Self {
+ WinNetIpType(WINNET_IPV4)
+ }
+
+ pub fn v6() -> Self {
+ WinNetIpType(WINNET_IPV6)
+ }
+}
+
+
+#[repr(C)]
+pub struct WinNetIpNetwork {
+ ip_type: WinNetIpType,
+ ip_bytes: [u8; 16],
+ prefix: u8,
+}
+
+impl From<IpNetwork> for WinNetIpNetwork {
+ fn from(network: IpNetwork) -> WinNetIpNetwork {
+ let WinNetIp { ip_type, ip_bytes } = WinNetIp::from(network.ip());
+ let prefix = network.prefix();
+ WinNetIpNetwork {
+ ip_type,
+ ip_bytes,
+ prefix,
+ }
+ }
+}
+
+#[repr(C)]
+pub struct WinNetIp {
+ ip_type: WinNetIpType,
+ ip_bytes: [u8; 16],
+}
+
+impl From<IpAddr> for WinNetIp {
+ fn from(addr: IpAddr) -> WinNetIp {
+ let mut bytes = [0u8; 16];
+ match addr {
+ IpAddr::V4(v4_addr) => {
+ bytes[..4].copy_from_slice(&v4_addr.octets());
+ WinNetIp {
+ ip_type: WinNetIpType::v4(),
+ ip_bytes: bytes,
+ }
+ }
+ IpAddr::V6(v6_addr) => {
+ bytes.copy_from_slice(&v6_addr.octets());
+
+ WinNetIp {
+ ip_type: WinNetIpType::v6(),
+ ip_bytes: bytes,
+ }
+ }
+ }
+ }
+}
+
+#[repr(C)]
+pub struct WinNetNode {
+ gateway: *mut WinNetIp,
+ device_name: *mut u16,
+}
+
+impl WinNetNode {
+ fn new(name: &str, ip: WinNetIp) -> Self {
+ let device_name = WideCString::from_str(name)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ let gateway = Box::into_raw(Box::new(ip));
+ Self {
+ gateway,
+ device_name,
+ }
+ }
+
+ fn from_gateway(ip: WinNetIp) -> Self {
+ let gateway = Box::into_raw(Box::new(ip));
+ Self {
+ gateway,
+ device_name: ptr::null_mut(),
+ }
+ }
+
+
+ fn from_device(name: &str) -> Self {
+ let device_name = WideCString::from_str(name)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ Self {
+ gateway: ptr::null_mut(),
+ device_name,
+ }
+ }
+}
+
+impl From<&Node> for WinNetNode {
+ fn from(node: &Node) -> Self {
+ match (node.get_address(), node.get_device()) {
+ (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()),
+ (None, Some(device)) => WinNetNode::from_device(device),
+ (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()),
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl Drop for WinNetNode {
+ fn drop(&mut self) {
+ if !self.gateway.is_null() {
+ unsafe {
+ let _ = Box::from_raw(self.gateway);
+ }
+ }
+ if !self.device_name.is_null() {
+ unsafe {
+ let _ = WideCString::from_ptr_str(self.device_name);
+ }
+ }
+ }
+}
+
+
+#[repr(C)]
+pub struct WinNetRoute {
+ gateway: WinNetIpNetwork,
+ node: *mut WinNetNode,
+}
+
+impl WinNetRoute {
+ pub fn through_default_node(gateway: WinNetIpNetwork) -> Self {
+ Self {
+ gateway,
+ node: ptr::null_mut(),
+ }
+ }
+
+ pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self {
+ let node = Box::into_raw(Box::new(node));
+ WinNetRoute { gateway, node }
+ }
+}
+
+impl Drop for WinNetRoute {
+ fn drop(&mut self) {
+ if !self.node.is_null() {
+ unsafe {
+ let _ = Box::from_raw(self.node);
+ }
+ self.node = ptr::null_mut();
+ }
+ }
+}
+
+pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool {
+ unsafe { WinNet_ActivateRouteManager(Some(log_sink), ptr::null_mut()) };
+ routing_manager_add_routes(routes)
+}
+
+pub struct WinNetCallbackHandle {
+ handle: *mut libc::c_void,
+ // allows us to keep the context pointer allive.
+ _context: Box<dyn std::any::Any>,
+}
+
+unsafe impl Send for WinNetCallbackHandle {}
+
+impl Drop for WinNetCallbackHandle {
+ fn drop(&mut self) {
+ unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) };
+ }
+}
+
+#[allow(dead_code)]
+#[repr(u16)]
+pub enum WinNetDefaultRouteChangeEventType {
+ DefaultRouteChanged = 0,
+ DefaultRouteRemoved = 1,
+}
+
+#[allow(dead_code)]
+#[repr(u16)]
+pub enum WinNetIpFamily {
+ V4 = 0,
+ V6 = 1,
+}
+
+impl WinNetIpFamily {
+ pub fn to_windows_proto_enum(&self) -> u16 {
+ match self {
+ Self::V4 => 2,
+ Self::V6 => 23,
+ }
+ }
+}
+
+pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
+ event_type: WinNetDefaultRouteChangeEventType,
+ ip_family: WinNetIpFamily,
+ interface_luid: u64,
+ ctx: *mut c_void,
+);
+
+#[derive(err_derive::Error, Debug)]
+#[error(display = "Failed to set callback for default route")]
+pub struct DefaultRouteCallbackError;
+
+pub fn set_default_route_change_callback<T: 'static>(
+ callback: Option<DefaultRouteChangedCallback>,
+ context: T,
+) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> {
+ let mut handle_ptr = ptr::null_mut();
+ let mut context = Box::new(context);
+ let ctx_ptr = &mut *context as *mut T as *mut libc::c_void;
+ unsafe {
+ if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _)
+ {
+ return Err(DefaultRouteCallbackError);
+ }
+
+
+ Ok(WinNetCallbackHandle {
+ handle: handle_ptr,
+ _context: context,
+ })
+ }
+}
+
+pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool {
+ let ptr = routes.as_ptr();
+ let length: u32 = routes.len() as u32;
+ unsafe { WinNet_AddRoutes(ptr, length) }
+}
+
+pub fn deactivate_routing_manager() -> bool {
+ unsafe { WinNet_DeactivateRouteManager() }
+}
+
+pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
+ let raw_iface = WideCString::from_str(iface)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ let converted_addresses: Vec<_> = addresses.iter().map(|addr| WinNetIp::from(*addr)).collect();
+ let ptr = converted_addresses.as_ptr();
+ let length: u32 = converted_addresses.len() as u32;
+ unsafe { WinNet_AddDeviceIpAddresses(raw_iface, ptr, length, Some(log_sink), ptr::null_mut()) }
+}
+
#[allow(non_snake_case)]
mod api {
- use super::LogSeverity;
+ use super::{DefaultRouteChangedCallback, LogSeverity};
use libc::{c_char, c_void, wchar_t};
/// logging callback type for use with `winnet.dll`.
@@ -131,6 +391,24 @@ mod api {
pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void);
extern "system" {
+ #[link_name = "WinNet_ActivateRouteManager"]
+ pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *mut c_void);
+
+ #[link_name = "WinNet_AddRoutes"]
+ pub fn WinNet_AddRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool;
+
+ // #[link_name = "WinNet_AddRoute"]
+ // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> bool;
+
+ // #[link_name = "WinNet_DeleteRoutes"]
+ // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool;
+
+ // #[link_name = "WinNet_DeleteRoute"]
+ // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool;
+
+ #[link_name = "WinNet_DeactivateRouteManager"]
+ pub fn WinNet_DeactivateRouteManager() -> bool;
+
#[link_name = "WinNet_EnsureTopMetric"]
pub fn WinNet_EnsureTopMetric(
tunnel_interface_alias: *const wchar_t,
@@ -162,7 +440,26 @@ mod api {
sink_context: *mut c_void,
) -> bool;
+ #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"]
+ pub fn WinNet_RegisterDefaultRouteChangedCallback(
+ callback: Option<DefaultRouteChangedCallback>,
+ callbackContext: *mut libc::c_void,
+ registrationHandle: *mut *mut libc::c_void,
+ ) -> bool;
+
+ #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"]
+ pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void);
+
#[link_name = "WinNet_DeactivateConnectivityMonitor"]
pub fn WinNet_DeactivateConnectivityMonitor() -> bool;
+
+ #[link_name = "WinNet_AddDeviceIpAddresses"]
+ pub fn WinNet_AddDeviceIpAddresses(
+ interface_alias: *const wchar_t,
+ addresses: *const super::WinNetIp,
+ num_addresses: u32,
+ sink: Option<LogSink>,
+ sink_context: *mut c_void,
+ ) -> bool;
}
}