diff options
| -rw-r--r-- | talpid-core/Cargo.toml | 6 | ||||
| -rw-r--r-- | talpid-core/build.rs | 3 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows.rs | 48 | ||||
| -rw-r--r-- | talpid-core/src/lib.rs | 1 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/win.rs | 45 | ||||
| -rw-r--r-- | talpid-core/src/routing/mod.rs | 38 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 72 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 16 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/tun_provider/mod.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/tun_provider/windows.rs | 13 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 131 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 303 |
15 files changed, 655 insertions, 49 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index a91c71d843..360cd01938 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -12,6 +12,7 @@ cfg-if = "0.1" duct = "0.13" err-derive = "0.2.1" futures = "0.1" +hex = "0.4" ipnetwork = "0.15" jsonrpc-core = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" } jsonrpc-macros = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" } @@ -25,15 +26,14 @@ shell-escape = "0.1" talpid-ipc = { path = "../talpid-ipc" } talpid-types = { path = "../talpid-types" } tokio-core = "0.1" +tokio-executor = "0.1" uuid = { version = "0.7", features = ["v4"] } [target.'cfg(unix)'.dependencies] -hex = "0.4" lazy_static = "1.0" nix = "0.15" tokio-process = "0.2" -tokio-executor = "0.1" tokio-io = "0.1" @@ -65,7 +65,7 @@ tun = "0.4.3" [target.'cfg(windows)'.dependencies] widestring = "0.4" winreg = "0.6" -winapi = { version = "0.3.6", features = ["handleapi", "libloaderapi", "synchapi", "winbase", "winuser"] } +winapi = { version = "0.3.6", features = ["handleapi", "ifdef", "libloaderapi", "netioapi", "synchapi", "winbase", "winuser"] } socket2 = "0.3" rand = "0.7" pnet_packet = "0.22" diff --git a/talpid-core/build.rs b/talpid-core/build.rs index 12be7fcb6c..5a63a1f064 100644 --- a/talpid-core/build.rs +++ b/talpid-core/build.rs @@ -53,6 +53,9 @@ fn main() { declare_library(WINFW_DIR_VAR, WINFW_BUILD_DIR, "winfw"); declare_library(WINDNS_DIR_VAR, WINDNS_BUILD_DIR, "windns"); declare_library(WINNET_DIR_VAR, WINNET_BUILD_DIR, "winnet"); + let lib_dir = manifest_dir().join("../dist-assets/binaries/x86_64-pc-windows-msvc/wireguard"); + println!("cargo:rustc-link-search={}", &lib_dir.display()); + println!("cargo:rustc-link-lib=dylib=libwg"); } #[cfg(not(windows))] diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index d04e5b8b00..beaaba2ee9 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -159,7 +159,7 @@ type ErrorSink = extern "system" fn( ); #[allow(non_snake_case)] -extern "system" { +extern "stdcall" { #[link_name = "WinDns_Initialize"] pub fn WinDns_Initialize( diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 67b37713d2..bee16fee3a 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -85,12 +85,17 @@ impl FirewallT for Firewall { match policy { FirewallPolicy::Connecting { peer_endpoint, - // TODO: Allow ICMP traffic to a list of hosts for wireguard - pingable_hosts: _, + pingable_hosts, allow_lan, } => { let cfg = &WinFwSettings::new(allow_lan); - self.set_connecting_state(&peer_endpoint, &cfg) + // TODO: Determine interface alias at runtime + self.set_connecting_state( + &peer_endpoint, + &cfg, + "wg-mullvad".to_string(), + &pingable_hosts, + ) } FirewallPolicy::Connected { peer_endpoint, @@ -128,6 +133,8 @@ impl Firewall { &mut self, endpoint: &Endpoint, winfw_settings: &WinFwSettings, + _tunnel_iface_alias: String, + pingable_hosts: &Vec<IpAddr>, ) -> Result<(), Error> { trace!("Applying 'connecting' firewall policy"); let ip_str = Self::widestring_ip(endpoint.address.ip()); @@ -139,7 +146,31 @@ impl Firewall { protocol: WinFwProt::from(endpoint.protocol), }; - unsafe { WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay).into_result() } + if pingable_hosts.is_empty() { + unsafe { + return WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, ptr::null()) + .into_result(); + } + } + + let pingable_addresses = pingable_hosts + .iter() + .map(|ip| Self::widestring_ip(*ip)) + .collect::<Vec<_>>(); + let pingable_address_ptrs = pingable_addresses + .iter() + .map(|ip| ip.as_ptr()) + .collect::<Vec<_>>(); + + let pingable_hosts = WinFwPingableHosts { + interfaceAlias: ptr::null(), + addresses: pingable_address_ptrs.as_ptr(), + num_addresses: pingable_addresses.len(), + }; + + unsafe { + WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, &pingable_hosts).into_result() + } } fn widestring_ip(ip: IpAddr) -> WideCString { @@ -250,6 +281,14 @@ mod winfw { } } + #[repr(C)] + pub struct WinFwPingableHosts { + // a null pointer implies that all interfaces will be able to ping the supplied addresses + pub interfaceAlias: *const libc::wchar_t, + pub addresses: *const *const libc::wchar_t, + pub num_addresses: usize, + } + ffi_error!(InitializationResult, Error::Initialization); ffi_error!(DeinitializationResult, Error::Deinitialization); ffi_error!(ApplyConnectingResult, Error::ApplyingConnectingPolicy); @@ -280,6 +319,7 @@ mod winfw { pub fn WinFw_ApplyPolicyConnecting( settings: &WinFwSettings, relay: &WinFwRelay, + pingable_hosts: *const WinFwPingableHosts, ) -> ApplyConnectingResult; #[link_name = "WinFw_ApplyPolicyConnected"] diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs index 5fdbd3030d..cd88277d9d 100644 --- a/talpid-core/src/lib.rs +++ b/talpid-core/src/lib.rs @@ -23,7 +23,6 @@ mod winnet; #[cfg(any(target_os = "linux", target_os = "macos"))] /// Working with IP interface devices pub mod network_interface; -#[cfg(not(windows))] /// Abstraction over operating system routing table. pub mod routing; diff --git a/talpid-core/src/ping_monitor/mod.rs b/talpid-core/src/ping_monitor/mod.rs index 8cdc766ad4..20993a8fb5 100644 --- a/talpid-core/src/ping_monitor/mod.rs +++ b/talpid-core/src/ping_monitor/mod.rs @@ -7,4 +7,4 @@ mod imp; #[path = "win.rs"] mod imp; -pub use imp::{monitor_ping, ping, Error}; +pub use imp::{monitor_ping, ping, Error, Pinger}; diff --git a/talpid-core/src/ping_monitor/win.rs b/talpid-core/src/ping_monitor/win.rs index f4540dddbd..6c2018c253 100644 --- a/talpid-core/src/ping_monitor/win.rs +++ b/talpid-core/src/ping_monitor/win.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] -// TODO: Remove lint exemption once ping monitor is used on Windows use pnet_packet::{ icmp::{ self, @@ -18,6 +16,8 @@ use std::{ time::{Duration, Instant}, }; +const SEND_RETRY_ATTEMPTS: u32 = 10; + #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { @@ -40,10 +40,10 @@ pub enum Error { pub fn monitor_ping( ip: Ipv4Addr, timeout_secs: u16, - _interface: &str, + interface: &str, close_receiver: mpsc::Receiver<()>, ) -> Result<()> { - let mut pinger = Pinger::new(ip)?; + let mut pinger = Pinger::new(ip, interface)?; while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() { let start = Instant::now(); pinger.send_ping(Duration::from_secs(timeout_secs.into()))?; @@ -57,8 +57,8 @@ pub fn monitor_ping( Ok(()) } -pub fn ping(ip: Ipv4Addr, timeout_secs: u16, _interface: &str) -> Result<()> { - Pinger::new(ip)?.send_ping(Duration::from_secs(timeout_secs.into())) +pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<()> { + Pinger::new(ip, interface)?.send_ping(Duration::from_secs(timeout_secs.into())) } type Result<T> = std::result::Result<T, Error>; @@ -71,10 +71,12 @@ pub struct Pinger { } impl Pinger { - pub fn new(addr: Ipv4Addr) -> Result<Self> { + pub fn new(addr: Ipv4Addr, _interface_name: &str) -> Result<Self> { let sock = Socket::new(Domain::ipv4(), Type::raw(), Some(Protocol::icmpv4())) .map_err(Error::OpenError)?; sock.set_nonblocking(true).map_err(Error::OpenError)?; + + Ok(Self { sock, id: rand::random(), @@ -87,12 +89,35 @@ impl Pinger { pub fn send_ping(&mut self, timeout: Duration) -> Result<()> { let dest = SocketAddr::new(IpAddr::from(self.addr), 0); let request = self.next_ping_request(); - self.sock - .send_to(request.packet(), &dest.into()) - .map_err(Error::WriteError)?; + self.send_ping_request(&request, dest.into())?; self.wait_for_response(Instant::now() + timeout, &request) } + fn send_ping_request( + &mut self, + request: &EchoRequestPacket<'static>, + destination: SocketAddr, + ) -> Result<()> { + let mut tries = 0; + let mut result = Ok(()); + while tries < SEND_RETRY_ATTEMPTS { + match self.sock.send_to(request.packet(), &destination.into()) { + Ok(_) => { + return Ok(()); + } + Err(err) => { + if Some(10065) != err.raw_os_error() { + return Err(Error::WriteError(err)); + } + result = Err(Error::WriteError(err)); + } + } + thread::sleep(Duration::from_secs(1)); + tries += 1; + } + result + } + /// returns the next ping packet fn next_ping_request(&mut self) -> EchoRequestPacket<'static> { const ICMP_HEADER_LENGTH: usize = 8; diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs index 6f5b73e014..4636b0b27d 100644 --- a/talpid-core/src/routing/mod.rs +++ b/talpid-core/src/routing/mod.rs @@ -1,4 +1,5 @@ #![cfg_attr(target_os = "android", allow(dead_code))] +#![cfg_attr(target_os = "windows", allow(dead_code))] // TODO: remove the allow(dead_code) for android once it's up to scratch. use futures::{sync::oneshot, Future}; use ipnetwork::IpNetwork; @@ -16,6 +17,12 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(target_os = "windows")] +#[path = "windows.rs"] +mod imp; +#[cfg(target_os = "windows")] +use crate::winnet; + pub use imp::Error as PlatformError; /// Errors that can be encountered whilst initializing RouteManager @@ -37,6 +44,8 @@ pub enum Error { /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { tx: Option<oneshot::Sender<oneshot::Sender<()>>>, + #[cfg(target_os = "windows")] + callback_handles: Vec<winnet::WinNetCallbackHandle>, } impl RouteManager { @@ -61,12 +70,34 @@ impl RouteManager { }, ); match start_rx.wait() { - Ok(Ok(())) => Ok(Self { tx: Some(tx) }), + Ok(Ok(())) => Ok(Self { + tx: Some(tx), + #[cfg(target_os = "windows")] + callback_handles: vec![], + }), Ok(Err(e)) => Err(e), Err(_) => Err(Error::RoutingManagerThreadPanic), } } + /// Sets a callback that is called whenever the default route changes. + #[cfg(target_os = "windows")] + pub fn set_default_route_callback<T: 'static>( + &mut self, + callback: Option<winnet::DefaultRouteChangedCallback>, + context: T, + ) { + match winnet::set_default_route_change_callback(callback, context) { + Err(_e) => { + // not sure if this should panic + log::error!("Failed to add callback!"); + } + Ok(handle) => { + self.callback_handles.push(handle); + } + } + } + /// Stops RouteManager and removes all of the applied routes. pub fn stop(&mut self) { if let Some(tx) = self.tx.take() { @@ -85,6 +116,11 @@ impl RouteManager { impl Drop for RouteManager { fn drop(&mut self) { + // Ensuring callbacks are removed before the route manager is stopped + #[cfg(target_os = "windows")] + { + self.callback_handles.clear(); + } self.stop(); } } diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs new file mode 100644 index 0000000000..684d1a3184 --- /dev/null +++ b/talpid-core/src/routing/windows.rs @@ -0,0 +1,72 @@ +use super::NetNode; +use crate::winnet; +use futures::{sync::oneshot, Async, Future}; +use ipnetwork::IpNetwork; +use std::collections::HashMap; + +/// Windows routing errors. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// Failure to apply a route + #[error(display = "Failed to start route manager")] + FailedToStartManager, +} + +pub type Result<T> = std::result::Result<T, Error>; + +pub struct RouteManagerImpl { + shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, +} + +impl RouteManagerImpl { + pub fn new( + required_routes: HashMap<IpNetwork, NetNode>, + shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, + ) -> Result<Self> { + let routes: Vec<_> = required_routes + .iter() + .map(|(destination, node)| { + let destination = winnet::WinNetIpNetwork::from(*destination); + match node { + NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), + NetNode::RealNode(node) => { + winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) + } + } + }) + .collect(); + + if !winnet::activate_routing_manager(&routes) { + return Err(Error::FailedToStartManager); + } + + + Ok(Self { shutdown_rx }) + } +} + +impl Drop for RouteManagerImpl { + fn drop(&mut self) { + if !winnet::deactivate_routing_manager() { + log::error!("Failed to deactivate routing manager"); + } + } +} + + +impl Future for RouteManagerImpl { + type Item = (); + type Error = Error; + fn poll(&mut self) -> Result<Async<()>> { + match self.shutdown_rx.poll() { + Ok(Async::Ready(result_tx)) => { + if let Err(_e) = result_tx.send(()) { + log::error!("Receiver already down"); + } + Ok(Async::Ready(())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => Ok(Async::Ready(())), + } + } +} diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 6eaa77f8b8..c3f22a2560 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -9,15 +9,12 @@ use std::{ }; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; -#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] -use talpid_types::net::wireguard as wireguard_types; -use talpid_types::net::{GenericTunnelOptions, TunnelParameters}; +use talpid_types::net::{wireguard as wireguard_types, GenericTunnelOptions, TunnelParameters}; /// A module for all OpenVPN related tunnel management. #[cfg(not(target_os = "android"))] pub mod openvpn; -#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] pub mod wireguard; /// A module for low level platform specific tunnel device management. @@ -45,7 +42,6 @@ pub enum Error { RotateLogError(#[error(source)] crate::logging::RotateLogError), /// Failure to build Wireguard configuration. - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] #[error(display = "Failed to configure Wireguard with the given parameters")] WireguardConfigError(#[error(source)] self::wireguard::config::Error), @@ -55,7 +51,6 @@ pub enum Error { OpenVpnTunnelMonitoringError(#[error(source)] openvpn::Error), /// There was an error listening for events from the Wireguard tunnel - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] #[error(display = "Failed while listening for events from the Wireguard tunnel")] WireguardTunnelMonitoringError(#[error(source)] wireguard::Error), } @@ -161,16 +156,12 @@ impl TunnelMonitor { #[cfg(target_os = "android")] TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] TunnelParameters::Wireguard(config) => { Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider) } - #[cfg(windows)] - TunnelParameters::Wireguard(_) => Err(Error::UnsupportedPlatform), } } - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] fn start_wireguard_tunnel<L>( params: &wireguard_types::TunnelParameters, log: Option<PathBuf>, @@ -254,7 +245,6 @@ pub enum CloseHandle { #[cfg(not(target_os = "android"))] /// OpenVpn close handle OpenVpn(openvpn::OpenVpnCloseHandle), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] /// Wireguard close handle Wireguard(wireguard::CloseHandle), } @@ -265,7 +255,6 @@ impl CloseHandle { match self { #[cfg(not(target_os = "android"))] CloseHandle::OpenVpn(handle) => handle.close(), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] CloseHandle::Wireguard(mut handle) => { handle.close(); Ok(()) @@ -277,7 +266,6 @@ impl CloseHandle { enum InternalTunnelMonitor { #[cfg(not(target_os = "android"))] OpenVpn(openvpn::OpenVpnMonitor), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] Wireguard(wireguard::WireguardMonitor), } @@ -286,7 +274,6 @@ impl InternalTunnelMonitor { match self { #[cfg(not(target_os = "android"))] InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()), } } @@ -295,7 +282,6 @@ impl InternalTunnelMonitor { match self { #[cfg(not(target_os = "android"))] InternalTunnelMonitor::OpenVpn(tun) => tun.wait()?, - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] InternalTunnelMonitor::Wireguard(tun) => tun.wait()?, } diff --git a/talpid-core/src/tunnel/tun_provider/mod.rs b/talpid-core/src/tunnel/tun_provider/mod.rs index c6701ceac9..f0bf8f69b6 100644 --- a/talpid-core/src/tunnel/tun_provider/mod.rs +++ b/talpid-core/src/tunnel/tun_provider/mod.rs @@ -29,6 +29,10 @@ cfg_if! { } } +/// Windows tunnel +#[cfg(target_os = "windows")] +pub mod windows; + /// Generic tunnel device. /// /// Must be associated with a platform specific file descriptor representing the device. diff --git a/talpid-core/src/tunnel/tun_provider/windows.rs b/talpid-core/src/tunnel/tun_provider/windows.rs new file mode 100644 index 0000000000..9a114bf4b7 --- /dev/null +++ b/talpid-core/src/tunnel/tun_provider/windows.rs @@ -0,0 +1,13 @@ +use super::Tun; + +/// Windows tunnel implementation +pub struct WinTun { + /// Name of tunnel interface + pub interface_name: String, +} + +impl Tun for WinTun { + fn interface_name(&self) -> &str { + &self.interface_name + } +} diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index c9b6988f7b..1f7f11052f 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -43,6 +43,11 @@ pub enum Error { #[error(display = "Failed to stop wireguard tunnel - {}", status)] StopWireguardError { status: i32 }, + /// Failed to set ip addresses on tunnel interface. + #[cfg(target_os = "windows")] + #[error(display = "Failed to set IP addresses on WireGuard interface")] + SetIpAddressesError, + /// Failed to set up routing. #[error(display = "Failed to setup routing")] SetupRoutingError(#[error(source)] crate::routing::Error), @@ -97,8 +102,13 @@ impl WireguardMonitor { Self::get_tunnel_routes(config), )?); let iface_name = tunnel.get_interface_name(); - let route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config)) + let mut route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config)) .map_err(Error::SetupRoutingError)?; + + #[cfg(target_os = "windows")] + route_handle + .set_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()); + let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = mpsc::channel(); let (pinger_tx, pinger_rx) = mpsc::channel(); @@ -121,12 +131,10 @@ impl WireguardMonitor { Ok(()) => { (on_event)(TunnelEvent::Up(metadata)); - match ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx) + if let Err(error) = + ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx) { - Ok(()) => return, - Err(error) => { - log::trace!("{}", error.display_chain_with_msg("Ping monitor failed")); - } + log::trace!("{}", error.display_chain_with_msg("Ping monitor failed")); } } Err(error) => { diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 4e3b9c45fd..75442ce56c 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -1,12 +1,35 @@ use super::{Config, Error, Result, Tunnel}; -use crate::tunnel::tun_provider::{Tun, TunConfig, TunProvider}; +use crate::tunnel::tun_provider::{Tun, TunProvider}; use ipnetwork::IpNetwork; -use std::{ffi::CString, net::IpAddr, os::unix::io::RawFd, path::Path, ptr}; +use std::{ffi::CString, fs, path::Path}; + +#[cfg(not(target_os = "windows"))] +use std::ptr; + +#[cfg(not(target_os = "windows"))] +use crate::tunnel::tun_provider::TunConfig; + +#[cfg(not(target_os = "windows"))] +use std::net::IpAddr; + +#[cfg(not(target_os = "windows"))] +use std::os::unix::io::{AsRawFd, RawFd}; + +#[cfg(target_os = "windows")] +use std::os::windows::io::AsRawHandle; + +#[cfg(target_os = "windows")] +use crate::tunnel::tun_provider::windows::WinTun; + #[cfg(target_os = "android")] use talpid_types::BoxedError; +#[cfg(not(target_os = "windows"))] const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; +#[cfg(target_os = "windows")] +use crate::winnet::{self, add_device_ip_addresses}; + pub struct WgGoTunnel { interface_name: String, handle: Option<i32>, @@ -16,6 +39,7 @@ pub struct WgGoTunnel { } impl WgGoTunnel { + #[cfg(not(target_os = "windows"))] pub fn start_tunnel( config: &Config, log_path: Option<&Path>, @@ -66,6 +90,84 @@ impl WgGoTunnel { }) } + #[cfg(target_os = "windows")] + pub fn start_tunnel( + config: &Config, + log_path: Option<&Path>, + _tun_provider: &dyn TunProvider, + _routes: impl Iterator<Item = IpNetwork>, + ) -> Result<Self> { + let log_file = prepare_log_file(log_path)?; + let wg_config_str = config.to_userspace_format(); + let iface_name: String = "wg-mullvad".to_string(); + let cstr_iface_name = + CString::new(iface_name.as_bytes()).map_err(Error::InterfaceNameError)?; + + let handle = unsafe { + wgTurnOn( + cstr_iface_name.as_ptr(), + config.mtu as i64, + wg_config_str.as_ptr(), + log_file.as_raw_handle(), + WG_GO_LOG_DEBUG, + ) + }; + + if handle < 0 { + return Err(Error::FatalStartWireguardError); + } + + if !add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { + // Todo: what kind of clean-up is required? + return Err(Error::SetIpAddressesError); + } + + Ok(WgGoTunnel { + interface_name: iface_name.clone(), + handle: Some(handle), + _tunnel_device: Box::new(WinTun { + interface_name: iface_name.clone(), + }), + //_log_file: log_file, + }) + } + + // Callback to be used to rebind the tunnel sockets when the default route changes + #[cfg(target_os = "windows")] + pub unsafe extern "system" fn default_route_changed_callback( + event_type: winnet::WinNetDefaultRouteChangeEventType, + address_family: winnet::WinNetIpFamily, + interface_luid: u64, + _ctx: *mut libc::c_void, + ) { + use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex}; + let iface_idx: u32 = match event_type { + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => { + let mut iface_idx = 0u32; + let iface_luid = NET_LUID { + Value: interface_luid, + }; + let status = + ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _); + if status != 0 { + log::error!( + "Failed to convert interface LUID to interface index - {} - {}", + status, + std::io::Error::last_os_error() + ); + return; + } + iface_idx + } + // if there is no new default route, specify 0 as the interface index + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => 0, + }; + + wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx); + } + + + #[cfg(not(target_os = "windows"))] fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig { let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)]; dns_servers.extend(config.ipv6_gateway.map(IpAddr::V6)); @@ -102,6 +204,7 @@ impl WgGoTunnel { Ok(()) } + #[cfg(not(target_os = "windows"))] fn get_tunnel( tun_provider: &mut dyn TunProvider, config: &Config, @@ -138,6 +241,13 @@ impl Drop for WgGoTunnel { } } +#[cfg(target_os = "windows")] +static NULL_DEVICE: &str = "NUL"; + +fn prepare_log_file(log_path: Option<&Path>) -> Result<fs::File> { + fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError) +} + impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> &str { &self.interface_name @@ -158,8 +268,6 @@ type WgLogLevel = i32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose const WG_GO_LOG_DEBUG: WgLogLevel = 3; -#[cfg_attr(target_os = "android", link(name = "wg", kind = "dylib"))] -#[cfg_attr(not(target_os = "android"), link(name = "wg", kind = "static"))] extern "C" { // Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors // for the tunnel device and logging. @@ -167,6 +275,7 @@ extern "C" { // Positive return values are tunnel handles for this specific wireguard tunnel instance. // Negative return values signify errors. All error codes are opaque. #[cfg_attr(target_os = "android", link_name = "wgTurnOnWithFdAndroid")] + #[cfg(not(target_os = "windows"))] fn wgTurnOnWithFd( iface_name: *const i8, mtu: isize, @@ -176,6 +285,16 @@ extern "C" { logLevel: WgLogLevel, ) -> i32; + // Windows + #[cfg(target_os = "windows")] + fn wgTurnOn( + iface_name: *const i8, + mtu: i64, + settings: *const i8, + log_fd: Fd, + logLevel: WgLogLevel, + ) -> i32; + // Pass a handle that was created by wgTurnOnWithFd to stop a wireguard tunnel. fn wgTurnOff(handle: i32) -> i32; @@ -186,4 +305,8 @@ extern "C" { // Returns the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] fn wgGetSocketV6(handle: i32) -> Fd; + + // Rebind tunnel socket when network interfaces change + #[cfg(target_os = "windows")] + fn wgRebindTunnelSocket(family: u16, interfaceIndex: u32); } diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index c6bd6c6115..6b42f6d810 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -2,8 +2,14 @@ use self::api::*; pub use self::api::{ LogSink, WinNet_ActivateConnectivityMonitor, WinNet_DeactivateConnectivityMonitor, }; +use crate::routing::Node; +use ipnetwork::IpNetwork; use libc::{c_char, c_void, wchar_t}; -use std::{ffi::OsString, ptr}; +use std::{ + ffi::{CStr, OsString}, + net::IpAddr, + ptr, +}; use widestring::WideCString; /// Errors that this module may produce. @@ -41,7 +47,6 @@ pub enum LogSeverity { /// Logging callback used with `winnet.dll`. pub extern "system" fn log_sink(severity: LogSeverity, msg: *const c_char, _ctx: *mut c_void) { - use std::ffi::CStr; if msg.is_null() { log::error!("Log message from FFI boundary is NULL"); } else { @@ -119,9 +124,264 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> { Ok(alias.to_os_string()) } +#[repr(C)] +struct WinNetIpType(u32); + +const WINNET_IPV4: u32 = 0; +const WINNET_IPV6: u32 = 1; + +impl WinNetIpType { + pub fn v4() -> Self { + WinNetIpType(WINNET_IPV4) + } + + pub fn v6() -> Self { + WinNetIpType(WINNET_IPV6) + } +} + + +#[repr(C)] +pub struct WinNetIpNetwork { + ip_type: WinNetIpType, + ip_bytes: [u8; 16], + prefix: u8, +} + +impl From<IpNetwork> for WinNetIpNetwork { + fn from(network: IpNetwork) -> WinNetIpNetwork { + let WinNetIp { ip_type, ip_bytes } = WinNetIp::from(network.ip()); + let prefix = network.prefix(); + WinNetIpNetwork { + ip_type, + ip_bytes, + prefix, + } + } +} + +#[repr(C)] +pub struct WinNetIp { + ip_type: WinNetIpType, + ip_bytes: [u8; 16], +} + +impl From<IpAddr> for WinNetIp { + fn from(addr: IpAddr) -> WinNetIp { + let mut bytes = [0u8; 16]; + match addr { + IpAddr::V4(v4_addr) => { + bytes[..4].copy_from_slice(&v4_addr.octets()); + WinNetIp { + ip_type: WinNetIpType::v4(), + ip_bytes: bytes, + } + } + IpAddr::V6(v6_addr) => { + bytes.copy_from_slice(&v6_addr.octets()); + + WinNetIp { + ip_type: WinNetIpType::v6(), + ip_bytes: bytes, + } + } + } + } +} + +#[repr(C)] +pub struct WinNetNode { + gateway: *mut WinNetIp, + device_name: *mut u16, +} + +impl WinNetNode { + fn new(name: &str, ip: WinNetIp) -> Self { + let device_name = WideCString::from_str(name) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + let gateway = Box::into_raw(Box::new(ip)); + Self { + gateway, + device_name, + } + } + + fn from_gateway(ip: WinNetIp) -> Self { + let gateway = Box::into_raw(Box::new(ip)); + Self { + gateway, + device_name: ptr::null_mut(), + } + } + + + fn from_device(name: &str) -> Self { + let device_name = WideCString::from_str(name) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + Self { + gateway: ptr::null_mut(), + device_name, + } + } +} + +impl From<&Node> for WinNetNode { + fn from(node: &Node) -> Self { + match (node.get_address(), node.get_device()) { + (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()), + (None, Some(device)) => WinNetNode::from_device(device), + (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()), + _ => unreachable!(), + } + } +} + +impl Drop for WinNetNode { + fn drop(&mut self) { + if !self.gateway.is_null() { + unsafe { + let _ = Box::from_raw(self.gateway); + } + } + if !self.device_name.is_null() { + unsafe { + let _ = WideCString::from_ptr_str(self.device_name); + } + } + } +} + + +#[repr(C)] +pub struct WinNetRoute { + gateway: WinNetIpNetwork, + node: *mut WinNetNode, +} + +impl WinNetRoute { + pub fn through_default_node(gateway: WinNetIpNetwork) -> Self { + Self { + gateway, + node: ptr::null_mut(), + } + } + + pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self { + let node = Box::into_raw(Box::new(node)); + WinNetRoute { gateway, node } + } +} + +impl Drop for WinNetRoute { + fn drop(&mut self) { + if !self.node.is_null() { + unsafe { + let _ = Box::from_raw(self.node); + } + self.node = ptr::null_mut(); + } + } +} + +pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool { + unsafe { WinNet_ActivateRouteManager(Some(log_sink), ptr::null_mut()) }; + routing_manager_add_routes(routes) +} + +pub struct WinNetCallbackHandle { + handle: *mut libc::c_void, + // allows us to keep the context pointer allive. + _context: Box<dyn std::any::Any>, +} + +unsafe impl Send for WinNetCallbackHandle {} + +impl Drop for WinNetCallbackHandle { + fn drop(&mut self) { + unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) }; + } +} + +#[allow(dead_code)] +#[repr(u16)] +pub enum WinNetDefaultRouteChangeEventType { + DefaultRouteChanged = 0, + DefaultRouteRemoved = 1, +} + +#[allow(dead_code)] +#[repr(u16)] +pub enum WinNetIpFamily { + V4 = 0, + V6 = 1, +} + +impl WinNetIpFamily { + pub fn to_windows_proto_enum(&self) -> u16 { + match self { + Self::V4 => 2, + Self::V6 => 23, + } + } +} + +pub type DefaultRouteChangedCallback = unsafe extern "system" fn( + event_type: WinNetDefaultRouteChangeEventType, + ip_family: WinNetIpFamily, + interface_luid: u64, + ctx: *mut c_void, +); + +#[derive(err_derive::Error, Debug)] +#[error(display = "Failed to set callback for default route")] +pub struct DefaultRouteCallbackError; + +pub fn set_default_route_change_callback<T: 'static>( + callback: Option<DefaultRouteChangedCallback>, + context: T, +) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> { + let mut handle_ptr = ptr::null_mut(); + let mut context = Box::new(context); + let ctx_ptr = &mut *context as *mut T as *mut libc::c_void; + unsafe { + if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _) + { + return Err(DefaultRouteCallbackError); + } + + + Ok(WinNetCallbackHandle { + handle: handle_ptr, + _context: context, + }) + } +} + +pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool { + let ptr = routes.as_ptr(); + let length: u32 = routes.len() as u32; + unsafe { WinNet_AddRoutes(ptr, length) } +} + +pub fn deactivate_routing_manager() -> bool { + unsafe { WinNet_DeactivateRouteManager() } +} + +pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool { + let raw_iface = WideCString::from_str(iface) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + let converted_addresses: Vec<_> = addresses.iter().map(|addr| WinNetIp::from(*addr)).collect(); + let ptr = converted_addresses.as_ptr(); + let length: u32 = converted_addresses.len() as u32; + unsafe { WinNet_AddDeviceIpAddresses(raw_iface, ptr, length, Some(log_sink), ptr::null_mut()) } +} + #[allow(non_snake_case)] mod api { - use super::LogSeverity; + use super::{DefaultRouteChangedCallback, LogSeverity}; use libc::{c_char, c_void, wchar_t}; /// logging callback type for use with `winnet.dll`. @@ -131,6 +391,24 @@ mod api { pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void); extern "system" { + #[link_name = "WinNet_ActivateRouteManager"] + pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *mut c_void); + + #[link_name = "WinNet_AddRoutes"] + pub fn WinNet_AddRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool; + + // #[link_name = "WinNet_AddRoute"] + // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> bool; + + // #[link_name = "WinNet_DeleteRoutes"] + // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool; + + // #[link_name = "WinNet_DeleteRoute"] + // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool; + + #[link_name = "WinNet_DeactivateRouteManager"] + pub fn WinNet_DeactivateRouteManager() -> bool; + #[link_name = "WinNet_EnsureTopMetric"] pub fn WinNet_EnsureTopMetric( tunnel_interface_alias: *const wchar_t, @@ -162,7 +440,26 @@ mod api { sink_context: *mut c_void, ) -> bool; + #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"] + pub fn WinNet_RegisterDefaultRouteChangedCallback( + callback: Option<DefaultRouteChangedCallback>, + callbackContext: *mut libc::c_void, + registrationHandle: *mut *mut libc::c_void, + ) -> bool; + + #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"] + pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void); + #[link_name = "WinNet_DeactivateConnectivityMonitor"] pub fn WinNet_DeactivateConnectivityMonitor() -> bool; + + #[link_name = "WinNet_AddDeviceIpAddresses"] + pub fn WinNet_AddDeviceIpAddresses( + interface_alias: *const wchar_t, + addresses: *const super::WinNetIp, + num_addresses: u32, + sink: Option<LogSink>, + sink_context: *mut c_void, + ) -> bool; } } |
