diff options
| author | Odd Stranne <odd@mullvad.net> | 2019-11-25 14:23:36 +0100 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2019-11-25 14:23:36 +0100 |
| commit | 67a86af237d3305f84bb7044aa1bdf5e1122e81f (patch) | |
| tree | f910fc098a52d5466de50224d8cd66c203aac4fa | |
| parent | dc6d5d8e87738f919b8f924b3381a1954097138b (diff) | |
| parent | e4f46afbe027cb8ddbb45a66ce014af9acbc54b6 (diff) | |
| download | mullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.tar.xz mullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.zip | |
Merge branch 'win-wireguard'
47 files changed, 3372 insertions, 115 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 7884f211cf..879b58a49b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Line wrap the file at 100 chars. Th ## [Unreleased] ### Added #### Windows +- Full WireGuard support, GUI and CLI. - Install Wintun driver that provides the WireGuard TUN adapter. - Remove Mullvad TAP adapter on uninstall. Also remove the TAP driver if there are no other TAP adapters in the system. diff --git a/Cargo.lock b/Cargo.lock index 5c6db6f0b8..b23c1d31c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2423,6 +2423,7 @@ version = "0.1.0" dependencies = [ "atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)", "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "chrono 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)", "dbus 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", "duct 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)", "err-derive 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/dist-assets/binaries b/dist-assets/binaries -Subproject 66091e4249f8afcbf3daf4ffb01bb05bf8d64d0 +Subproject fe1f86af8f3b99eed99b1299b5d4ca15f20ebab diff --git a/gui/src/renderer/components/AdvancedSettings.tsx b/gui/src/renderer/components/AdvancedSettings.tsx index 334da211a8..dd9568bed5 100644 --- a/gui/src/renderer/components/AdvancedSettings.tsx +++ b/gui/src/renderer/components/AdvancedSettings.tsx @@ -43,7 +43,6 @@ interface IProps { wireguard: { port?: number }; mssfix?: number; bridgeState: BridgeState; - enableWireguardKeysPage: boolean; setBridgeState: (value: BridgeState) => void; setEnableIpv6: (value: boolean) => void; setBlockWhenDisconnected: (value: boolean) => void; @@ -226,18 +225,14 @@ export default class AdvancedSettings extends Component<IProps, IState> { )} </Cell.Footer> - {process.platform !== 'win32' ? ( - <View style={styles.advanced_settings__content}> - <Selector - title={messages.pgettext('advanced-settings-view', 'Tunnel protocol')} - values={this.tunnelProtocolItems} - value={this.props.tunnelProtocol} - onSelect={this.onSelectTunnelProtocol} - /> - </View> - ) : ( - undefined - )} + <View style={styles.advanced_settings__content}> + <Selector + title={messages.pgettext('advanced-settings-view', 'Tunnel protocol')} + values={this.tunnelProtocolItems} + value={this.props.tunnelProtocol} + onSelect={this.onSelectTunnelProtocol} + /> + </View> {this.props.tunnelProtocol !== 'wireguard' ? ( <View style={styles.advanced_settings__content}> @@ -277,7 +272,7 @@ export default class AdvancedSettings extends Component<IProps, IState> { undefined )} - {this.props.tunnelProtocol === 'wireguard' && process.platform !== 'win32' ? ( + {this.props.tunnelProtocol === 'wireguard' ? ( <View style={styles.advanced_settings__content}> <Selector // TRANSLATORS: The title for the shadowsocks bridge selector section. @@ -336,7 +331,14 @@ export default class AdvancedSettings extends Component<IProps, IState> { )} </Cell.FooterText> </Cell.Footer> - {this.wireguardKeysButton()} + <View style={styles.advanced_settings__wgkeys_cell}> + <Cell.CellButton onPress={this.props.onViewWireguardKeys}> + <Cell.Label> + {messages.pgettext('advanced-settings-view', 'WireGuard key')} + </Cell.Label> + <Cell.Icon height={12} width={7} source="icon-chevron" /> + </Cell.CellButton> + </View> </NavigationScrollbars> </View> </NavigationContainer> @@ -346,21 +348,6 @@ export default class AdvancedSettings extends Component<IProps, IState> { ); } - private wireguardKeysButton() { - if (this.props.enableWireguardKeysPage) { - return ( - <View style={styles.advanced_settings__wgkeys_cell}> - <Cell.CellButton onPress={this.props.onViewWireguardKeys}> - <Cell.Label>{messages.pgettext('advanced-settings-view', 'WireGuard key')}</Cell.Label> - <Cell.Icon height={12} width={7} source="icon-chevron" /> - </Cell.CellButton> - </View> - ); - } else { - return null; - } - } - private onSelectTunnelProtocol = (protocol?: TunnelProtocol) => { this.props.setTunnelProtocol(protocol); }; diff --git a/gui/src/renderer/containers/AdvancedSettingsPage.tsx b/gui/src/renderer/containers/AdvancedSettingsPage.tsx index e31365b86d..fad5704a97 100644 --- a/gui/src/renderer/containers/AdvancedSettingsPage.tsx +++ b/gui/src/renderer/containers/AdvancedSettingsPage.tsx @@ -12,14 +12,12 @@ import { IReduxState, ReduxDispatch } from '../redux/store'; const mapStateToProps = (state: IReduxState) => { const protocolAndPort = mapRelaySettingsToProtocolAndPort(state.settings.relaySettings); - const enableWireguardKeysPage = process.platform === 'linux' || process.platform === 'darwin'; return { enableIpv6: state.settings.enableIpv6, blockWhenDisconnected: state.settings.blockWhenDisconnected, mssfix: state.settings.openVpn.mssfix, bridgeState: state.settings.bridgeState, - enableWireguardKeysPage, ...protocolAndPort, }; }; diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js index be262cf80e..7ae4dc3296 100644 --- a/gui/tasks/distribution.js +++ b/gui/tasks/distribution.js @@ -93,6 +93,7 @@ const config = { { from: root('windows/winutil/bin/x64-Release/winutil.dll'), to: '.' }, { from: distAssets('binaries/x86_64-pc-windows-msvc/openvpn.exe'), to: '.' }, { from: distAssets('binaries/x86_64-pc-windows-msvc/sslocal.exe'), to: '.' }, + { from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard/libwg.dll'), to: '.' }, ], }, diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 6081f0fe25..0a2c65cb3b 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -1319,12 +1319,7 @@ where } } - #[cfg_attr(target_os = "windows", allow(unreachable_code))] fn ensure_wireguard_keys_for_current_account(&mut self) { - #[cfg(target_os = "windows")] - { - return; - } if let Some(account) = self.settings.get_account_token() { if self .account_history diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index a91c71d843..1df24dd7c4 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -12,9 +12,11 @@ cfg-if = "0.1" duct = "0.13" err-derive = "0.2.1" futures = "0.1" +hex = "0.4" ipnetwork = "0.15" jsonrpc-core = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" } jsonrpc-macros = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" } +lazy_static = "1.0" libc = "0.2.20" log = "0.4" openvpn-plugin = { git = "https://github.com/mullvad/openvpn-plugin-rs", branch = "auth-failed-event", features = ["serde"] } @@ -25,15 +27,13 @@ shell-escape = "0.1" talpid-ipc = { path = "../talpid-ipc" } talpid-types = { path = "../talpid-types" } tokio-core = "0.1" +tokio-executor = "0.1" uuid = { version = "0.7", features = ["v4"] } [target.'cfg(unix)'.dependencies] -hex = "0.4" -lazy_static = "1.0" nix = "0.15" tokio-process = "0.2" -tokio-executor = "0.1" tokio-io = "0.1" @@ -63,9 +63,10 @@ tun = "0.4.3" [target.'cfg(windows)'.dependencies] +chrono = "0.4" widestring = "0.4" winreg = "0.6" -winapi = { version = "0.3.6", features = ["handleapi", "libloaderapi", "synchapi", "winbase", "winuser"] } +winapi = { version = "0.3.6", features = ["handleapi", "ifdef", "libloaderapi", "netioapi", "synchapi", "winbase", "winuser"] } socket2 = "0.3" rand = "0.7" pnet_packet = "0.22" diff --git a/talpid-core/build.rs b/talpid-core/build.rs index 12be7fcb6c..5a63a1f064 100644 --- a/talpid-core/build.rs +++ b/talpid-core/build.rs @@ -53,6 +53,9 @@ fn main() { declare_library(WINFW_DIR_VAR, WINFW_BUILD_DIR, "winfw"); declare_library(WINDNS_DIR_VAR, WINDNS_BUILD_DIR, "windns"); declare_library(WINNET_DIR_VAR, WINNET_BUILD_DIR, "winnet"); + let lib_dir = manifest_dir().join("../dist-assets/binaries/x86_64-pc-windows-msvc/wireguard"); + println!("cargo:rustc-link-search={}", &lib_dir.display()); + println!("cargo:rustc-link-lib=dylib=libwg"); } #[cfg(not(windows))] diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index d04e5b8b00..beaaba2ee9 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -159,7 +159,7 @@ type ErrorSink = extern "system" fn( ); #[allow(non_snake_case)] -extern "system" { +extern "stdcall" { #[link_name = "WinDns_Initialize"] pub fn WinDns_Initialize( diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 67b37713d2..bee16fee3a 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -85,12 +85,17 @@ impl FirewallT for Firewall { match policy { FirewallPolicy::Connecting { peer_endpoint, - // TODO: Allow ICMP traffic to a list of hosts for wireguard - pingable_hosts: _, + pingable_hosts, allow_lan, } => { let cfg = &WinFwSettings::new(allow_lan); - self.set_connecting_state(&peer_endpoint, &cfg) + // TODO: Determine interface alias at runtime + self.set_connecting_state( + &peer_endpoint, + &cfg, + "wg-mullvad".to_string(), + &pingable_hosts, + ) } FirewallPolicy::Connected { peer_endpoint, @@ -128,6 +133,8 @@ impl Firewall { &mut self, endpoint: &Endpoint, winfw_settings: &WinFwSettings, + _tunnel_iface_alias: String, + pingable_hosts: &Vec<IpAddr>, ) -> Result<(), Error> { trace!("Applying 'connecting' firewall policy"); let ip_str = Self::widestring_ip(endpoint.address.ip()); @@ -139,7 +146,31 @@ impl Firewall { protocol: WinFwProt::from(endpoint.protocol), }; - unsafe { WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay).into_result() } + if pingable_hosts.is_empty() { + unsafe { + return WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, ptr::null()) + .into_result(); + } + } + + let pingable_addresses = pingable_hosts + .iter() + .map(|ip| Self::widestring_ip(*ip)) + .collect::<Vec<_>>(); + let pingable_address_ptrs = pingable_addresses + .iter() + .map(|ip| ip.as_ptr()) + .collect::<Vec<_>>(); + + let pingable_hosts = WinFwPingableHosts { + interfaceAlias: ptr::null(), + addresses: pingable_address_ptrs.as_ptr(), + num_addresses: pingable_addresses.len(), + }; + + unsafe { + WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, &pingable_hosts).into_result() + } } fn widestring_ip(ip: IpAddr) -> WideCString { @@ -250,6 +281,14 @@ mod winfw { } } + #[repr(C)] + pub struct WinFwPingableHosts { + // a null pointer implies that all interfaces will be able to ping the supplied addresses + pub interfaceAlias: *const libc::wchar_t, + pub addresses: *const *const libc::wchar_t, + pub num_addresses: usize, + } + ffi_error!(InitializationResult, Error::Initialization); ffi_error!(DeinitializationResult, Error::Deinitialization); ffi_error!(ApplyConnectingResult, Error::ApplyingConnectingPolicy); @@ -280,6 +319,7 @@ mod winfw { pub fn WinFw_ApplyPolicyConnecting( settings: &WinFwSettings, relay: &WinFwRelay, + pingable_hosts: *const WinFwPingableHosts, ) -> ApplyConnectingResult; #[link_name = "WinFw_ApplyPolicyConnected"] diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs index 5fdbd3030d..cd88277d9d 100644 --- a/talpid-core/src/lib.rs +++ b/talpid-core/src/lib.rs @@ -23,7 +23,6 @@ mod winnet; #[cfg(any(target_os = "linux", target_os = "macos"))] /// Working with IP interface devices pub mod network_interface; -#[cfg(not(windows))] /// Abstraction over operating system routing table. pub mod routing; diff --git a/talpid-core/src/ping_monitor/win.rs b/talpid-core/src/ping_monitor/win.rs index f4540dddbd..40fa523584 100644 --- a/talpid-core/src/ping_monitor/win.rs +++ b/talpid-core/src/ping_monitor/win.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] -// TODO: Remove lint exemption once ping monitor is used on Windows use pnet_packet::{ icmp::{ self, @@ -18,6 +16,8 @@ use std::{ time::{Duration, Instant}, }; +const SEND_RETRY_ATTEMPTS: u32 = 10; + #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { @@ -40,10 +40,10 @@ pub enum Error { pub fn monitor_ping( ip: Ipv4Addr, timeout_secs: u16, - _interface: &str, + interface: &str, close_receiver: mpsc::Receiver<()>, ) -> Result<()> { - let mut pinger = Pinger::new(ip)?; + let mut pinger = Pinger::new(ip, interface)?; while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() { let start = Instant::now(); pinger.send_ping(Duration::from_secs(timeout_secs.into()))?; @@ -57,8 +57,8 @@ pub fn monitor_ping( Ok(()) } -pub fn ping(ip: Ipv4Addr, timeout_secs: u16, _interface: &str) -> Result<()> { - Pinger::new(ip)?.send_ping(Duration::from_secs(timeout_secs.into())) +pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<()> { + Pinger::new(ip, interface)?.send_ping(Duration::from_secs(timeout_secs.into())) } type Result<T> = std::result::Result<T, Error>; @@ -70,11 +70,15 @@ pub struct Pinger { seq: u16, } +const NUM_PINGS_TO_SEND: usize = 3; + impl Pinger { - pub fn new(addr: Ipv4Addr) -> Result<Self> { + pub fn new(addr: Ipv4Addr, _interface_name: &str) -> Result<Self> { let sock = Socket::new(Domain::ipv4(), Type::raw(), Some(Protocol::icmpv4())) .map_err(Error::OpenError)?; sock.set_nonblocking(true).map_err(Error::OpenError)?; + + Ok(Self { sock, id: rand::random(), @@ -86,19 +90,49 @@ impl Pinger { /// Sends an ICMP echo request pub fn send_ping(&mut self, timeout: Duration) -> Result<()> { let dest = SocketAddr::new(IpAddr::from(self.addr), 0); - let request = self.next_ping_request(); - self.sock - .send_to(request.packet(), &dest.into()) - .map_err(Error::WriteError)?; - self.wait_for_response(Instant::now() + timeout, &request) + let requests = (0..NUM_PINGS_TO_SEND) + .map(|_| { + let request = self.next_ping_request(); + self.send_ping_request(&request, dest)?; + Ok(request) + }) + .collect::<Result<Vec<_>>>()?; + self.wait_for_response(Instant::now() + timeout, &requests) + } + + fn send_ping_request( + &mut self, + request: &EchoRequestPacket<'static>, + destination: SocketAddr, + ) -> Result<()> { + let mut tries = 0; + let mut result = Ok(()); + while tries < SEND_RETRY_ATTEMPTS { + match self.sock.send_to(request.packet(), &destination.into()) { + Ok(_) => { + return Ok(()); + } + Err(err) => { + if Some(10065) != err.raw_os_error() { + return Err(Error::WriteError(err)); + } + result = Err(Error::WriteError(err)); + } + } + thread::sleep(Duration::from_secs(1)); + tries += 1; + } + result } /// returns the next ping packet fn next_ping_request(&mut self) -> EchoRequestPacket<'static> { + use rand::Rng; const ICMP_HEADER_LENGTH: usize = 8; - const ICMP_PAYLOAD_LENGTH: usize = 24; + const ICMP_PAYLOAD_LENGTH: usize = 150; const ICMP_PACKET_LENGTH: usize = ICMP_HEADER_LENGTH + ICMP_PAYLOAD_LENGTH; - let payload: [u8; ICMP_PAYLOAD_LENGTH] = rand::random(); + let mut payload = [0u8; ICMP_PAYLOAD_LENGTH]; + rand::thread_rng().fill(&mut payload[..]); let mut packet = MutableEchoRequestPacket::owned(vec![0u8; ICMP_PACKET_LENGTH]) .expect("Failed to construct an empty packet"); packet.set_icmp_type(IcmpType::new(8)); @@ -117,24 +151,39 @@ impl Pinger { } - fn wait_for_response(&mut self, deadline: Instant, req: &EchoRequestPacket<'_>) -> Result<()> { + fn wait_for_response( + &mut self, + deadline: Instant, + requests: &[EchoRequestPacket<'_>], + ) -> Result<()> { let mut recv_buffer = [0u8; 4096]; - while Instant::now() < deadline { + let mut bytes_received = 0; + let mut success = false; + let mut requests = requests.iter().map(|req| (false, req)).collect::<Vec<_>>(); + 'outer: while Instant::now() < deadline { match self.sock.recv(&mut recv_buffer) { Ok(recv_len) => { + bytes_received += recv_len; if recv_len > 20 { // have to slice off first 20 bytes for the IP header. if let Some(reply) = Self::parse_response(&recv_buffer[20..recv_len]) { - if reply.get_identifier() == req.get_identifier() - && reply.get_sequence_number() == req.get_sequence_number() - && req.payload() == reply.payload() - { - return Ok(()); + for (used, req) in requests.iter_mut() { + if *used { + continue; + } + if Self::request_and_response_match(req, &reply) { + *used = true; + success = true; + continue 'outer; + } } } } } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if success { + return Ok(()); + } std::thread::sleep(Duration::from_millis(100)); continue; } @@ -143,9 +192,44 @@ impl Pinger { } } } + log::debug!( + "Timing out whilst waiting for ICMP response after receiving {} bytes", + bytes_received + ); Err(Error::TimeoutError) } + fn request_and_response_match(req: &EchoRequestPacket<'_>, resp: &EchoReplyPacket<'_>) -> bool { + if req.get_identifier() != resp.get_identifier() { + log::debug!( + "Expected idnetifier {} - got {}", + req.get_identifier(), + resp.get_identifier() + ); + return false; + } + + if req.get_sequence_number() != resp.get_sequence_number() { + log::debug!( + "Expected sequence number {} - got {}", + req.get_sequence_number(), + resp.get_sequence_number() + ); + return false; + } + + if req.payload() != resp.payload() { + log::debug!( + "Expected payload {:?} - got {:?}", + req.payload(), + resp.payload() + ); + return false; + } + + return true; + } + fn parse_response<'a>(buffer: &'a [u8]) -> Option<EchoReplyPacket<'a>> { let icmp_checksum = icmp::checksum(&IcmpPacket::new(buffer)?); let reply = EchoReplyPacket::new(buffer)?; diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs index 6f5b73e014..4636b0b27d 100644 --- a/talpid-core/src/routing/mod.rs +++ b/talpid-core/src/routing/mod.rs @@ -1,4 +1,5 @@ #![cfg_attr(target_os = "android", allow(dead_code))] +#![cfg_attr(target_os = "windows", allow(dead_code))] // TODO: remove the allow(dead_code) for android once it's up to scratch. use futures::{sync::oneshot, Future}; use ipnetwork::IpNetwork; @@ -16,6 +17,12 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(target_os = "windows")] +#[path = "windows.rs"] +mod imp; +#[cfg(target_os = "windows")] +use crate::winnet; + pub use imp::Error as PlatformError; /// Errors that can be encountered whilst initializing RouteManager @@ -37,6 +44,8 @@ pub enum Error { /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { tx: Option<oneshot::Sender<oneshot::Sender<()>>>, + #[cfg(target_os = "windows")] + callback_handles: Vec<winnet::WinNetCallbackHandle>, } impl RouteManager { @@ -61,12 +70,34 @@ impl RouteManager { }, ); match start_rx.wait() { - Ok(Ok(())) => Ok(Self { tx: Some(tx) }), + Ok(Ok(())) => Ok(Self { + tx: Some(tx), + #[cfg(target_os = "windows")] + callback_handles: vec![], + }), Ok(Err(e)) => Err(e), Err(_) => Err(Error::RoutingManagerThreadPanic), } } + /// Sets a callback that is called whenever the default route changes. + #[cfg(target_os = "windows")] + pub fn set_default_route_callback<T: 'static>( + &mut self, + callback: Option<winnet::DefaultRouteChangedCallback>, + context: T, + ) { + match winnet::set_default_route_change_callback(callback, context) { + Err(_e) => { + // not sure if this should panic + log::error!("Failed to add callback!"); + } + Ok(handle) => { + self.callback_handles.push(handle); + } + } + } + /// Stops RouteManager and removes all of the applied routes. pub fn stop(&mut self) { if let Some(tx) = self.tx.take() { @@ -85,6 +116,11 @@ impl RouteManager { impl Drop for RouteManager { fn drop(&mut self) { + // Ensuring callbacks are removed before the route manager is stopped + #[cfg(target_os = "windows")] + { + self.callback_handles.clear(); + } self.stop(); } } diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs new file mode 100644 index 0000000000..684d1a3184 --- /dev/null +++ b/talpid-core/src/routing/windows.rs @@ -0,0 +1,72 @@ +use super::NetNode; +use crate::winnet; +use futures::{sync::oneshot, Async, Future}; +use ipnetwork::IpNetwork; +use std::collections::HashMap; + +/// Windows routing errors. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// Failure to apply a route + #[error(display = "Failed to start route manager")] + FailedToStartManager, +} + +pub type Result<T> = std::result::Result<T, Error>; + +pub struct RouteManagerImpl { + shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, +} + +impl RouteManagerImpl { + pub fn new( + required_routes: HashMap<IpNetwork, NetNode>, + shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, + ) -> Result<Self> { + let routes: Vec<_> = required_routes + .iter() + .map(|(destination, node)| { + let destination = winnet::WinNetIpNetwork::from(*destination); + match node { + NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), + NetNode::RealNode(node) => { + winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) + } + } + }) + .collect(); + + if !winnet::activate_routing_manager(&routes) { + return Err(Error::FailedToStartManager); + } + + + Ok(Self { shutdown_rx }) + } +} + +impl Drop for RouteManagerImpl { + fn drop(&mut self) { + if !winnet::deactivate_routing_manager() { + log::error!("Failed to deactivate routing manager"); + } + } +} + + +impl Future for RouteManagerImpl { + type Item = (); + type Error = Error; + fn poll(&mut self) -> Result<Async<()>> { + match self.shutdown_rx.poll() { + Ok(Async::Ready(result_tx)) => { + if let Err(_e) = result_tx.send(()) { + log::error!("Receiver already down"); + } + Ok(Async::Ready(())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => Ok(Async::Ready(())), + } + } +} diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 6eaa77f8b8..ddf19fcb66 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -9,15 +9,12 @@ use std::{ }; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; -#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] -use talpid_types::net::wireguard as wireguard_types; -use talpid_types::net::{GenericTunnelOptions, TunnelParameters}; +use talpid_types::net::{wireguard as wireguard_types, GenericTunnelOptions, TunnelParameters}; /// A module for all OpenVPN related tunnel management. #[cfg(not(target_os = "android"))] pub mod openvpn; -#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] pub mod wireguard; /// A module for low level platform specific tunnel device management. @@ -45,7 +42,6 @@ pub enum Error { RotateLogError(#[error(source)] crate::logging::RotateLogError), /// Failure to build Wireguard configuration. - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] #[error(display = "Failed to configure Wireguard with the given parameters")] WireguardConfigError(#[error(source)] self::wireguard::config::Error), @@ -55,7 +51,6 @@ pub enum Error { OpenVpnTunnelMonitoringError(#[error(source)] openvpn::Error), /// There was an error listening for events from the Wireguard tunnel - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] #[error(display = "Failed while listening for events from the Wireguard tunnel")] WireguardTunnelMonitoringError(#[error(source)] wireguard::Error), } @@ -161,16 +156,12 @@ impl TunnelMonitor { #[cfg(target_os = "android")] TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] TunnelParameters::Wireguard(config) => { Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider) } - #[cfg(windows)] - TunnelParameters::Wireguard(_) => Err(Error::UnsupportedPlatform), } } - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] fn start_wireguard_tunnel<L>( params: &wireguard_types::TunnelParameters, log: Option<PathBuf>, @@ -216,6 +207,7 @@ impl TunnelMonitor { } } + #[cfg(not(target_os = "windows"))] fn prepare_tunnel_log_file( parameters: &TunnelParameters, log_dir: &Option<PathBuf>, @@ -234,6 +226,23 @@ impl TunnelMonitor { } } + #[cfg(target_os = "windows")] + fn prepare_tunnel_log_file( + parameters: &TunnelParameters, + log_dir: &Option<PathBuf>, + ) -> Result<Option<PathBuf>> { + if let Some(ref log_dir) = log_dir { + let filename = match parameters { + TunnelParameters::OpenVpn(_) => OPENVPN_LOG_FILENAME, + TunnelParameters::Wireguard(_) => WIREGUARD_LOG_FILENAME, + }; + let tunnel_log = log_dir.join(filename); + logging::rotate_log(&tunnel_log)?; + Ok(Some(tunnel_log)) + } else { + Ok(None) + } + } /// Creates a handle to this monitor, allowing the tunnel to be closed while some other /// thread @@ -254,7 +263,6 @@ pub enum CloseHandle { #[cfg(not(target_os = "android"))] /// OpenVpn close handle OpenVpn(openvpn::OpenVpnCloseHandle), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] /// Wireguard close handle Wireguard(wireguard::CloseHandle), } @@ -265,7 +273,6 @@ impl CloseHandle { match self { #[cfg(not(target_os = "android"))] CloseHandle::OpenVpn(handle) => handle.close(), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] CloseHandle::Wireguard(mut handle) => { handle.close(); Ok(()) @@ -277,7 +284,6 @@ impl CloseHandle { enum InternalTunnelMonitor { #[cfg(not(target_os = "android"))] OpenVpn(openvpn::OpenVpnMonitor), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] Wireguard(wireguard::WireguardMonitor), } @@ -286,7 +292,6 @@ impl InternalTunnelMonitor { match self { #[cfg(not(target_os = "android"))] InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()), - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()), } } @@ -295,7 +300,6 @@ impl InternalTunnelMonitor { match self { #[cfg(not(target_os = "android"))] InternalTunnelMonitor::OpenVpn(tun) => tun.wait()?, - #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))] InternalTunnelMonitor::Wireguard(tun) => tun.wait()?, } diff --git a/talpid-core/src/tunnel/tun_provider/mod.rs b/talpid-core/src/tunnel/tun_provider/mod.rs index c6701ceac9..f0bf8f69b6 100644 --- a/talpid-core/src/tunnel/tun_provider/mod.rs +++ b/talpid-core/src/tunnel/tun_provider/mod.rs @@ -29,6 +29,10 @@ cfg_if! { } } +/// Windows tunnel +#[cfg(target_os = "windows")] +pub mod windows; + /// Generic tunnel device. /// /// Must be associated with a platform specific file descriptor representing the device. diff --git a/talpid-core/src/tunnel/tun_provider/windows.rs b/talpid-core/src/tunnel/tun_provider/windows.rs new file mode 100644 index 0000000000..9a114bf4b7 --- /dev/null +++ b/talpid-core/src/tunnel/tun_provider/windows.rs @@ -0,0 +1,13 @@ +use super::Tun; + +/// Windows tunnel implementation +pub struct WinTun { + /// Name of tunnel interface + pub interface_name: String, +} + +impl Tun for WinTun { + fn interface_name(&self) -> &str { + &self.interface_name + } +} diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index c9b6988f7b..1f7f11052f 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -43,6 +43,11 @@ pub enum Error { #[error(display = "Failed to stop wireguard tunnel - {}", status)] StopWireguardError { status: i32 }, + /// Failed to set ip addresses on tunnel interface. + #[cfg(target_os = "windows")] + #[error(display = "Failed to set IP addresses on WireGuard interface")] + SetIpAddressesError, + /// Failed to set up routing. #[error(display = "Failed to setup routing")] SetupRoutingError(#[error(source)] crate::routing::Error), @@ -97,8 +102,13 @@ impl WireguardMonitor { Self::get_tunnel_routes(config), )?); let iface_name = tunnel.get_interface_name(); - let route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config)) + let mut route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config)) .map_err(Error::SetupRoutingError)?; + + #[cfg(target_os = "windows")] + route_handle + .set_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()); + let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = mpsc::channel(); let (pinger_tx, pinger_rx) = mpsc::channel(); @@ -121,12 +131,10 @@ impl WireguardMonitor { Ok(()) => { (on_event)(TunnelEvent::Up(metadata)); - match ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx) + if let Err(error) = + ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx) { - Ok(()) => return, - Err(error) => { - log::trace!("{}", error.display_chain_with_msg("Ping monitor failed")); - } + log::trace!("{}", error.display_chain_with_msg("Ping monitor failed")); } } Err(error) => { diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 4e3b9c45fd..50df570d01 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -1,21 +1,56 @@ use super::{Config, Error, Result, Tunnel}; -use crate::tunnel::tun_provider::{Tun, TunConfig, TunProvider}; +use crate::tunnel::tun_provider::{Tun, TunProvider}; use ipnetwork::IpNetwork; -use std::{ffi::CString, net::IpAddr, os::unix::io::RawFd, path::Path, ptr}; +use std::{ffi::CString, path::Path}; + +#[cfg(not(target_os = "windows"))] +use { + crate::tunnel::tun_provider::TunConfig, + std::{net::IpAddr, os::unix::io::RawFd, ptr}, +}; + + +#[cfg(target_os = "windows")] +use crate::{ + tunnel::tun_provider::windows::WinTun, + winnet::{self, add_device_ip_addresses}, +}; + #[cfg(target_os = "android")] use talpid_types::BoxedError; +#[cfg(not(target_os = "windows"))] const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; +#[cfg(target_os = "windows")] +use { + chrono, + parking_lot::Mutex, + std::{collections::HashMap, fs, io::Write}, +}; + + pub struct WgGoTunnel { interface_name: String, handle: Option<i32>, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped _tunnel_device: Box<dyn Tun>, + // ordinal that maps to fs::File instance, used with logging callback + #[cfg(target_os = "windows")] + log_context_ordinal: u32, +} + +#[cfg(target_os = "windows")] +lazy_static::lazy_static! { + static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new()); } +#[cfg(target_os = "windows")] +static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0; + impl WgGoTunnel { + #[cfg(not(target_os = "windows"))] pub fn start_tunnel( config: &Config, log_path: Option<&Path>, @@ -66,6 +101,129 @@ impl WgGoTunnel { }) } + #[cfg(target_os = "windows")] + pub fn start_tunnel( + config: &Config, + log_path: Option<&Path>, + _tun_provider: &dyn TunProvider, + _routes: impl Iterator<Item = IpNetwork>, + ) -> Result<Self> { + let log_file = prepare_log_file(log_path)?; + + let log_context_ordinal = unsafe { + let mut map = LOG_MUTEX.lock(); + let ordinal = LOG_CONTEXT_NEXT_ORDINAL; + LOG_CONTEXT_NEXT_ORDINAL += 1; + map.insert(ordinal, log_file); + ordinal + }; + + let wg_config_str = config.to_userspace_format(); + let iface_name: String = "wg-mullvad".to_string(); + let cstr_iface_name = + CString::new(iface_name.as_bytes()).map_err(Error::InterfaceNameError)?; + + let handle = unsafe { + wgTurnOn( + cstr_iface_name.as_ptr(), + config.mtu as i64, + wg_config_str.as_ptr(), + Some(Self::logging_callback), + log_context_ordinal as *mut libc::c_void, + ) + }; + + if handle < 0 { + clean_up_log_file(log_context_ordinal); + return Err(Error::FatalStartWireguardError); + } + + if !add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { + // Todo: what kind of clean-up is required? + clean_up_log_file(log_context_ordinal); + return Err(Error::SetIpAddressesError); + } + + Ok(WgGoTunnel { + interface_name: iface_name.clone(), + handle: Some(handle), + _tunnel_device: Box::new(WinTun { + interface_name: iface_name.clone(), + }), + log_context_ordinal, + }) + } + + // Callback to be used to rebind the tunnel sockets when the default route changes + #[cfg(target_os = "windows")] + pub unsafe extern "system" fn default_route_changed_callback( + event_type: winnet::WinNetDefaultRouteChangeEventType, + address_family: winnet::WinNetIpFamily, + interface_luid: u64, + _ctx: *mut libc::c_void, + ) { + use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex}; + let iface_idx: u32 = match event_type { + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => { + let mut iface_idx = 0u32; + let iface_luid = NET_LUID { + Value: interface_luid, + }; + let status = + ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _); + if status != 0 { + log::error!( + "Failed to convert interface LUID to interface index - {} - {}", + status, + std::io::Error::last_os_error() + ); + return; + } + iface_idx + } + // if there is no new default route, specify 0 as the interface index + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => 0, + }; + + wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx); + } + + // Callback that receives messages from WireGuard + #[cfg(target_os = "windows")] + pub unsafe extern "system" fn logging_callback( + level: WgLogLevel, + msg: *const libc::c_char, + context: *mut libc::c_void, + ) { + let map = LOG_MUTEX.lock(); + if let Some(mut logfile) = map.get(&(context as u32)) { + let managed_msg = if !msg.is_null() { + std::ffi::CStr::from_ptr(msg) + .to_string_lossy() + .to_string() + .replace("\n", "\r\n") + } else { + "Logging message from WireGuard is NULL".to_string() + }; + + let level_str = match level { + WG_GO_LOG_DEBUG => "DEBUG", + WG_GO_LOG_INFO => "INFO", + WG_GO_LOG_ERROR | _ => "ERROR", + }; + + let _ = write!( + logfile, + "{}[{}][{}] {}", + chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"), + "wireguard-go", + level_str, + managed_msg + ); + } + } + + #[cfg(not(target_os = "windows"))] fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig { let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)]; dns_servers.extend(config.ipv6_gateway.map(IpAddr::V6)); @@ -102,6 +260,7 @@ impl WgGoTunnel { Ok(()) } + #[cfg(not(target_os = "windows"))] fn get_tunnel( tun_provider: &mut dyn TunProvider, config: &Config, @@ -130,14 +289,30 @@ impl WgGoTunnel { } } +#[cfg(target_os = "windows")] +fn clean_up_log_file(ordinal: u32) { + let mut map = LOG_MUTEX.lock(); + map.remove(&ordinal); +} + impl Drop for WgGoTunnel { fn drop(&mut self) { if let Err(e) = self.stop_tunnel() { log::error!("Failed to stop tunnel - {}", e); } + #[cfg(target_os = "windows")] + clean_up_log_file(self.log_context_ordinal); } } +#[cfg(target_os = "windows")] +static NULL_DEVICE: &str = "NUL"; + +#[cfg(target_os = "windows")] +fn prepare_log_file(log_path: Option<&Path>) -> Result<fs::File> { + fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError) +} + impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> &str { &self.interface_name @@ -154,12 +329,22 @@ pub type Fd = std::os::unix::io::RawFd; #[cfg(windows)] pub type Fd = std::os::windows::io::RawHandle; -type WgLogLevel = i32; +type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose +// const WG_GO_LOG_SILENT: WgLogLevel = 0; +#[cfg(target_os = "windows")] +const WG_GO_LOG_ERROR: WgLogLevel = 1; +#[cfg(target_os = "windows")] +const WG_GO_LOG_INFO: WgLogLevel = 2; const WG_GO_LOG_DEBUG: WgLogLevel = 3; -#[cfg_attr(target_os = "android", link(name = "wg", kind = "dylib"))] -#[cfg_attr(not(target_os = "android"), link(name = "wg", kind = "static"))] +#[cfg(target_os = "windows")] +pub type LoggingCallback = unsafe extern "system" fn( + level: WgLogLevel, + msg: *const libc::c_char, + context: *mut libc::c_void, +); + extern "C" { // Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors // for the tunnel device and logging. @@ -167,6 +352,7 @@ extern "C" { // Positive return values are tunnel handles for this specific wireguard tunnel instance. // Negative return values signify errors. All error codes are opaque. #[cfg_attr(target_os = "android", link_name = "wgTurnOnWithFdAndroid")] + #[cfg(not(target_os = "windows"))] fn wgTurnOnWithFd( iface_name: *const i8, mtu: isize, @@ -176,6 +362,16 @@ extern "C" { logLevel: WgLogLevel, ) -> i32; + // Windows + #[cfg(target_os = "windows")] + fn wgTurnOn( + iface_name: *const i8, + mtu: i64, + settings: *const i8, + logging_callback: Option<LoggingCallback>, + logging_context: *mut libc::c_void, + ) -> i32; + // Pass a handle that was created by wgTurnOnWithFd to stop a wireguard tunnel. fn wgTurnOff(handle: i32) -> i32; @@ -186,4 +382,8 @@ extern "C" { // Returns the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] fn wgGetSocketV6(handle: i32) -> Fd; + + // Rebind tunnel socket when network interfaces change + #[cfg(target_os = "windows")] + fn wgRebindTunnelSocket(family: u16, interfaceIndex: u32); } diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index c6bd6c6115..6b42f6d810 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -2,8 +2,14 @@ use self::api::*; pub use self::api::{ LogSink, WinNet_ActivateConnectivityMonitor, WinNet_DeactivateConnectivityMonitor, }; +use crate::routing::Node; +use ipnetwork::IpNetwork; use libc::{c_char, c_void, wchar_t}; -use std::{ffi::OsString, ptr}; +use std::{ + ffi::{CStr, OsString}, + net::IpAddr, + ptr, +}; use widestring::WideCString; /// Errors that this module may produce. @@ -41,7 +47,6 @@ pub enum LogSeverity { /// Logging callback used with `winnet.dll`. pub extern "system" fn log_sink(severity: LogSeverity, msg: *const c_char, _ctx: *mut c_void) { - use std::ffi::CStr; if msg.is_null() { log::error!("Log message from FFI boundary is NULL"); } else { @@ -119,9 +124,264 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> { Ok(alias.to_os_string()) } +#[repr(C)] +struct WinNetIpType(u32); + +const WINNET_IPV4: u32 = 0; +const WINNET_IPV6: u32 = 1; + +impl WinNetIpType { + pub fn v4() -> Self { + WinNetIpType(WINNET_IPV4) + } + + pub fn v6() -> Self { + WinNetIpType(WINNET_IPV6) + } +} + + +#[repr(C)] +pub struct WinNetIpNetwork { + ip_type: WinNetIpType, + ip_bytes: [u8; 16], + prefix: u8, +} + +impl From<IpNetwork> for WinNetIpNetwork { + fn from(network: IpNetwork) -> WinNetIpNetwork { + let WinNetIp { ip_type, ip_bytes } = WinNetIp::from(network.ip()); + let prefix = network.prefix(); + WinNetIpNetwork { + ip_type, + ip_bytes, + prefix, + } + } +} + +#[repr(C)] +pub struct WinNetIp { + ip_type: WinNetIpType, + ip_bytes: [u8; 16], +} + +impl From<IpAddr> for WinNetIp { + fn from(addr: IpAddr) -> WinNetIp { + let mut bytes = [0u8; 16]; + match addr { + IpAddr::V4(v4_addr) => { + bytes[..4].copy_from_slice(&v4_addr.octets()); + WinNetIp { + ip_type: WinNetIpType::v4(), + ip_bytes: bytes, + } + } + IpAddr::V6(v6_addr) => { + bytes.copy_from_slice(&v6_addr.octets()); + + WinNetIp { + ip_type: WinNetIpType::v6(), + ip_bytes: bytes, + } + } + } + } +} + +#[repr(C)] +pub struct WinNetNode { + gateway: *mut WinNetIp, + device_name: *mut u16, +} + +impl WinNetNode { + fn new(name: &str, ip: WinNetIp) -> Self { + let device_name = WideCString::from_str(name) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + let gateway = Box::into_raw(Box::new(ip)); + Self { + gateway, + device_name, + } + } + + fn from_gateway(ip: WinNetIp) -> Self { + let gateway = Box::into_raw(Box::new(ip)); + Self { + gateway, + device_name: ptr::null_mut(), + } + } + + + fn from_device(name: &str) -> Self { + let device_name = WideCString::from_str(name) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + Self { + gateway: ptr::null_mut(), + device_name, + } + } +} + +impl From<&Node> for WinNetNode { + fn from(node: &Node) -> Self { + match (node.get_address(), node.get_device()) { + (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()), + (None, Some(device)) => WinNetNode::from_device(device), + (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()), + _ => unreachable!(), + } + } +} + +impl Drop for WinNetNode { + fn drop(&mut self) { + if !self.gateway.is_null() { + unsafe { + let _ = Box::from_raw(self.gateway); + } + } + if !self.device_name.is_null() { + unsafe { + let _ = WideCString::from_ptr_str(self.device_name); + } + } + } +} + + +#[repr(C)] +pub struct WinNetRoute { + gateway: WinNetIpNetwork, + node: *mut WinNetNode, +} + +impl WinNetRoute { + pub fn through_default_node(gateway: WinNetIpNetwork) -> Self { + Self { + gateway, + node: ptr::null_mut(), + } + } + + pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self { + let node = Box::into_raw(Box::new(node)); + WinNetRoute { gateway, node } + } +} + +impl Drop for WinNetRoute { + fn drop(&mut self) { + if !self.node.is_null() { + unsafe { + let _ = Box::from_raw(self.node); + } + self.node = ptr::null_mut(); + } + } +} + +pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool { + unsafe { WinNet_ActivateRouteManager(Some(log_sink), ptr::null_mut()) }; + routing_manager_add_routes(routes) +} + +pub struct WinNetCallbackHandle { + handle: *mut libc::c_void, + // allows us to keep the context pointer allive. + _context: Box<dyn std::any::Any>, +} + +unsafe impl Send for WinNetCallbackHandle {} + +impl Drop for WinNetCallbackHandle { + fn drop(&mut self) { + unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) }; + } +} + +#[allow(dead_code)] +#[repr(u16)] +pub enum WinNetDefaultRouteChangeEventType { + DefaultRouteChanged = 0, + DefaultRouteRemoved = 1, +} + +#[allow(dead_code)] +#[repr(u16)] +pub enum WinNetIpFamily { + V4 = 0, + V6 = 1, +} + +impl WinNetIpFamily { + pub fn to_windows_proto_enum(&self) -> u16 { + match self { + Self::V4 => 2, + Self::V6 => 23, + } + } +} + +pub type DefaultRouteChangedCallback = unsafe extern "system" fn( + event_type: WinNetDefaultRouteChangeEventType, + ip_family: WinNetIpFamily, + interface_luid: u64, + ctx: *mut c_void, +); + +#[derive(err_derive::Error, Debug)] +#[error(display = "Failed to set callback for default route")] +pub struct DefaultRouteCallbackError; + +pub fn set_default_route_change_callback<T: 'static>( + callback: Option<DefaultRouteChangedCallback>, + context: T, +) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> { + let mut handle_ptr = ptr::null_mut(); + let mut context = Box::new(context); + let ctx_ptr = &mut *context as *mut T as *mut libc::c_void; + unsafe { + if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _) + { + return Err(DefaultRouteCallbackError); + } + + + Ok(WinNetCallbackHandle { + handle: handle_ptr, + _context: context, + }) + } +} + +pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool { + let ptr = routes.as_ptr(); + let length: u32 = routes.len() as u32; + unsafe { WinNet_AddRoutes(ptr, length) } +} + +pub fn deactivate_routing_manager() -> bool { + unsafe { WinNet_DeactivateRouteManager() } +} + +pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool { + let raw_iface = WideCString::from_str(iface) + .expect("Failed to convert UTF-8 string to null terminated UCS string") + .into_raw(); + let converted_addresses: Vec<_> = addresses.iter().map(|addr| WinNetIp::from(*addr)).collect(); + let ptr = converted_addresses.as_ptr(); + let length: u32 = converted_addresses.len() as u32; + unsafe { WinNet_AddDeviceIpAddresses(raw_iface, ptr, length, Some(log_sink), ptr::null_mut()) } +} + #[allow(non_snake_case)] mod api { - use super::LogSeverity; + use super::{DefaultRouteChangedCallback, LogSeverity}; use libc::{c_char, c_void, wchar_t}; /// logging callback type for use with `winnet.dll`. @@ -131,6 +391,24 @@ mod api { pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void); extern "system" { + #[link_name = "WinNet_ActivateRouteManager"] + pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *mut c_void); + + #[link_name = "WinNet_AddRoutes"] + pub fn WinNet_AddRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool; + + // #[link_name = "WinNet_AddRoute"] + // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> bool; + + // #[link_name = "WinNet_DeleteRoutes"] + // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool; + + // #[link_name = "WinNet_DeleteRoute"] + // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool; + + #[link_name = "WinNet_DeactivateRouteManager"] + pub fn WinNet_DeactivateRouteManager() -> bool; + #[link_name = "WinNet_EnsureTopMetric"] pub fn WinNet_EnsureTopMetric( tunnel_interface_alias: *const wchar_t, @@ -162,7 +440,26 @@ mod api { sink_context: *mut c_void, ) -> bool; + #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"] + pub fn WinNet_RegisterDefaultRouteChangedCallback( + callback: Option<DefaultRouteChangedCallback>, + callbackContext: *mut libc::c_void, + registrationHandle: *mut *mut libc::c_void, + ) -> bool; + + #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"] + pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void); + #[link_name = "WinNet_DeactivateConnectivityMonitor"] pub fn WinNet_DeactivateConnectivityMonitor() -> bool; + + #[link_name = "WinNet_AddDeviceIpAddresses"] + pub fn WinNet_AddDeviceIpAddresses( + interface_alias: *const wchar_t, + addresses: *const super::WinNetIp, + num_addresses: u32, + sink: Option<LogSink>, + sink_context: *mut c_void, + ) -> bool; } } diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index e5325f0b9c..49f8793572 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -13,10 +13,10 @@ #include "rules/permitvpnrelay.h" #include "rules/permitvpntunnel.h" #include "rules/permitvpntunnelservice.h" +#include "rules/permitping.h" #include "rules/restrictdns.h" #include "libwfp/transaction.h" #include "libwfp/filterengine.h" -#include "libwfp/ipaddress.h" #include <functional> #include <stdexcept> #include <utility> @@ -99,7 +99,12 @@ FwContext::FwContext(uint32_t timeout, const WinFwSettings &settings) m_baseline = checkpoint; } -bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay) +bool FwContext::applyPolicyConnecting +( + const WinFwSettings &settings, + const WinFwRelay &relay, + const std::optional<PingableHosts> &pingableHosts +) { Ruleset ruleset; @@ -112,6 +117,22 @@ bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFw TranslateProtocol(relay.protocol) )); + // + // Permit pinging the gateway inside the tunnel. + // + if (pingableHosts.has_value()) + { + const auto &ph = pingableHosts.value(); + + for (const auto &host : ph.hosts) + { + ruleset.emplace_back(std::make_unique<rules::PermitPing>( + ph.tunnelInterfaceAlias, + host + )); + } + } + return applyRuleset(ruleset); } diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index 89ef40e1d3..9d5b34c51b 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -3,9 +3,11 @@ #include "winfw.h" #include "sessioncontroller.h" #include "rules/ifirewallrule.h" +#include "libwfp/ipaddress.h" #include <cstdint> #include <memory> #include <vector> +#include <optional> class FwContext { @@ -16,7 +18,19 @@ public: // This ctor applies the "blocked" policy. FwContext(uint32_t timeout, const WinFwSettings &settings); - bool applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay); + struct PingableHosts + { + std::optional<std::wstring> tunnelInterfaceAlias; + std::vector<wfp::IpAddress> hosts; + }; + + bool applyPolicyConnecting + ( + const WinFwSettings &settings, + const WinFwRelay &relay, + const std::optional<PingableHosts> &pingableHosts + ); + bool applyPolicyConnected ( const WinFwSettings &settings, diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp index 010d41e44a..e73fac26ed 100644 --- a/windows/winfw/src/winfw/mullvadguids.cpp +++ b/windows/winfw/src/winfw/mullvadguids.cpp @@ -59,6 +59,8 @@ DetailedWfpObjectRegistry MullvadGuids::BuildDetailedRegistry() registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Outbound_Router_Solicitation())); registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Router_Advertisement())); registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Redirect())); + registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv4())); + registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv6())); return registry; } @@ -567,3 +569,31 @@ const GUID &MullvadGuids::FilterPermitNdp_Inbound_Redirect() return g; } + +//static +const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv4() +{ + static const GUID g = + { + 0x2ecf7ff7, + 0xc951, + 0x4056, + { 0xb0, 0xf7, 0x40, 0xa4, 0x5c, 0x7e, 0xb4, 0xc2 } + }; + + return g; +} + +//static +const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv6() +{ + static const GUID g = + { + 0x3deb8cab, + 0x1edb, + 0x4aa1, + { 0xb2, 0x73, 0xec, 0x61, 0x4f, 0x50, 0xdc, 0x13 } + }; + + return g; +} diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h index d4fb470d90..3c3ca9702b 100644 --- a/windows/winfw/src/winfw/mullvadguids.h +++ b/windows/winfw/src/winfw/mullvadguids.h @@ -67,4 +67,7 @@ public: static const GUID &FilterPermitNdp_Outbound_Router_Solicitation(); static const GUID &FilterPermitNdp_Inbound_Router_Advertisement(); static const GUID &FilterPermitNdp_Inbound_Redirect(); + + static const GUID &FilterPermitPing_Outbound_Icmpv4(); + static const GUID &FilterPermitPing_Outbound_Icmpv6(); }; diff --git a/windows/winfw/src/winfw/rules/permitping.cpp b/windows/winfw/src/winfw/rules/permitping.cpp new file mode 100644 index 0000000000..f6aed36bf2 --- /dev/null +++ b/windows/winfw/src/winfw/rules/permitping.cpp @@ -0,0 +1,98 @@ +#include "stdafx.h" +#include "permitping.h" +#include "winfw/mullvadguids.h" +#include "libwfp/filterbuilder.h" +#include "libwfp/conditionbuilder.h" +#include "libwfp/conditions/conditionip.h" +#include "libwfp/conditions/conditioninterface.h" +#include "libwfp/conditions/conditionprotocol.h" + + +using namespace wfp::conditions; + +namespace rules +{ + +PermitPing::PermitPing +( + const std::optional<std::wstring> &interfaceAlias, + const wfp::IpAddress &host +) + : m_interfaceAlias(interfaceAlias) + , m_host(host) +{ +} + +bool PermitPing::apply(IObjectInstaller &objectInstaller) +{ + if (wfp::IpAddress::Type::Ipv4 == m_host.type()) + { + return applyIcmpv4(objectInstaller); + } + + return applyIcmpv6(objectInstaller); +} + +bool PermitPing::applyIcmpv4(IObjectInstaller &objectInstaller) const +{ + wfp::FilterBuilder filterBuilder; + + // + // #1 Permit outbound ICMPv4 to %host% on %interface% + // + + filterBuilder + .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv4()) + .name(L"Permit outbound ICMP to specific host (ICMPv4)") + .description(L"This filter is part of a rule that permits ping") + .provider(MullvadGuids::Provider()) + .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4) + .sublayer(MullvadGuids::SublayerWhitelist()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4); + + conditionBuilder.add_condition(ConditionIp::Remote(m_host)); + conditionBuilder.add_condition(ConditionProtocol::Icmp()); + + if (m_interfaceAlias.has_value()) + { + conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value())); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +bool PermitPing::applyIcmpv6(IObjectInstaller &objectInstaller) const +{ + wfp::FilterBuilder filterBuilder; + + // + // #1 Permit outbound ICMPv6 to %host% on %interface% + // + + filterBuilder + .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv6()) + .name(L"Permit outbound ICMP to specific host (ICMPv6)") + .description(L"This filter is part of a rule that permits ping") + .provider(MullvadGuids::Provider()) + .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6) + .sublayer(MullvadGuids::SublayerWhitelist()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6); + + conditionBuilder.add_condition(ConditionIp::Remote(m_host)); + conditionBuilder.add_condition(ConditionProtocol::IcmpV6()); + + if (m_interfaceAlias.has_value()) + { + conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value())); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +} diff --git a/windows/winfw/src/winfw/rules/permitping.h b/windows/winfw/src/winfw/rules/permitping.h new file mode 100644 index 0000000000..c8238ceaa8 --- /dev/null +++ b/windows/winfw/src/winfw/rules/permitping.h @@ -0,0 +1,28 @@ +#pragma once + +#include "ifirewallrule.h" +#include <libwfp/ipaddress.h> +#include <string> +#include <optional> + +namespace rules +{ + +class PermitPing : public IFirewallRule +{ +public: + + PermitPing(const std::optional<std::wstring> &interfaceAlias, const wfp::IpAddress &host); + + bool apply(IObjectInstaller &objectInstaller) override; + +private: + + const std::optional<std::wstring> m_interfaceAlias; + const wfp::IpAddress m_host; + + bool applyIcmpv4(IObjectInstaller &objectInstaller) const; + bool applyIcmpv6(IObjectInstaller &objectInstaller) const; +}; + +} diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index 7b9ea2dc6b..3065408f3d 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -4,6 +4,7 @@ #include "objectpurger.h" #include <windows.h> #include <stdexcept> +#include <optional> namespace { @@ -15,6 +16,34 @@ void * g_errorContext = nullptr; FwContext *g_fwContext = nullptr; +std::optional<FwContext::PingableHosts> ConvertPingableHosts(const PingableHosts *pingableHosts) +{ + if (nullptr == pingableHosts) + { + return {}; + } + + if (nullptr == pingableHosts->hosts + || 0 == pingableHosts->numHosts) + { + throw std::runtime_error("Invalid PingableHosts structure"); + } + + FwContext::PingableHosts converted; + + if (nullptr != pingableHosts->tunnelInterfaceAlias) + { + converted.tunnelInterfaceAlias = pingableHosts->tunnelInterfaceAlias; + } + + for (size_t i = 0; i < pingableHosts->numHosts; ++i) + { + converted.hosts.emplace_back(wfp::IpAddress(pingableHosts->hosts[i])); + } + + return converted; +} + } // anonymous namespace WINFW_LINKAGE @@ -130,7 +159,8 @@ bool WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings &settings, - const WinFwRelay &relay + const WinFwRelay &relay, + const PingableHosts *pingableHosts ) { if (nullptr == g_fwContext) @@ -140,7 +170,7 @@ WinFw_ApplyPolicyConnecting( try { - return g_fwContext->applyPolicyConnecting(settings, relay); + return g_fwContext->applyPolicyConnecting(settings, relay, ConvertPingableHosts(pingableHosts)); } catch (std::exception &err) { diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index 95e66a608f..6d43b0db4c 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -105,11 +105,29 @@ WINFW_API WinFw_Deinitialize(); // +// PingableHosts: +// +// Specifies a set of IP addresses that should be reachable by ICMP when the connecting +// policy is effective. +// +// The interface alias is optional and can be used to restrict the traffic such +// that it is only allowed on that specific interface. +// +typedef struct tag_PingableHosts +{ + const wchar_t *tunnelInterfaceAlias; + const wchar_t **hosts; + size_t numHosts; +} +PingableHosts; + +// // ApplyPolicyConnecting: // // Apply restrictions in the firewall that block all traffic, except: // - What is specified by settings // - Communication with the relay server +// - ICMP (for ping) to/from tunnel gateway // extern "C" WINFW_LINKAGE @@ -117,7 +135,8 @@ bool WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings &settings, - const WinFwRelay &relay + const WinFwRelay &relay, + const PingableHosts *pingableHosts ); // diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj index 4777503f72..cbabe2f4f7 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj +++ b/windows/winfw/src/winfw/winfw.vcxproj @@ -30,6 +30,7 @@ <ClCompile Include="rules\permitlanservice.cpp" /> <ClCompile Include="rules\permitloopback.cpp" /> <ClCompile Include="rules\permitndp.cpp" /> + <ClCompile Include="rules\permitping.cpp" /> <ClCompile Include="rules\permitvpntunnelservice.cpp" /> <ClCompile Include="rules\permitvpnrelay.cpp" /> <ClCompile Include="rules\permitvpntunnel.cpp" /> @@ -53,6 +54,7 @@ <ClInclude Include="objectpurger.h" /> <ClInclude Include="rules\permitdhcpserver.h" /> <ClInclude Include="rules\permitndp.h" /> + <ClInclude Include="rules\permitping.h" /> <ClInclude Include="wfpobjecttype.h" /> <ClInclude Include="rules\blockall.h" /> <ClInclude Include="rules\ifirewallrule.h" /> diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters index 0319b0214a..a758a1c9ec 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj.filters +++ b/windows/winfw/src/winfw/winfw.vcxproj.filters @@ -43,6 +43,9 @@ <ClCompile Include="rules\permitndp.cpp"> <Filter>rules</Filter> </ClCompile> + <ClCompile Include="rules\permitping.cpp"> + <Filter>rules</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -93,6 +96,9 @@ <ClInclude Include="rules\permitndp.h"> <Filter>rules</Filter> </ClInclude> + <ClInclude Include="rules\permitping.h"> + <Filter>rules</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <Filter Include="rules"> diff --git a/windows/winnet/src/extras/loader/loader.vcxproj.filters b/windows/winnet/src/extras/loader/loader.vcxproj.filters index cd0f4643c7..408a9591b1 100644 --- a/windows/winnet/src/extras/loader/loader.vcxproj.filters +++ b/windows/winnet/src/extras/loader/loader.vcxproj.filters @@ -3,9 +3,13 @@ <ItemGroup> <ClCompile Include="loader.cpp" /> <ClCompile Include="stdafx.cpp" /> + <ClCompile Include="..\..\winnet\routemanager.cpp" /> + <ClCompile Include="..\..\winnet\adapters.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> + <ClInclude Include="..\..\winnet\routemanager.h" /> + <ClInclude Include="..\..\winnet\adapters.h" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/windows/winnet/src/winnet/interfaceutils.cpp b/windows/winnet/src/winnet/interfaceutils.cpp index babe03eba6..202d9d0724 100644 --- a/windows/winnet/src/winnet/interfaceutils.cpp +++ b/windows/winnet/src/winnet/interfaceutils.cpp @@ -2,13 +2,8 @@ #include "interfaceutils.h" #include "libcommon/error.h" #include "libcommon/string.h" -#include <vector> #include <cstdint> #include <algorithm> -#include <winsock2.h> -#include <iphlpapi.h> -#include <windows.h> - //static std::set<InterfaceUtils::NetworkAdapter> InterfaceUtils::GetAllAdapters() @@ -112,3 +107,18 @@ std::wstring InterfaceUtils::GetTapInterfaceAlias() throw std::runtime_error("Unable to find TAP adapter"); } + +//static +void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses) +{ + for (const auto &address : addresses) + { + MIB_UNICASTIPADDRESS_ROW row; + InitializeUnicastIpAddressEntry(&row); + + row.InterfaceLuid = device; + row.Address = address; + + THROW_UNLESS(NO_ERROR, CreateUnicastIpAddressEntry(&row), "Assign IP address on network interface"); + } +} diff --git a/windows/winnet/src/winnet/interfaceutils.h b/windows/winnet/src/winnet/interfaceutils.h index f5c31963c2..8ab1249a50 100644 --- a/windows/winnet/src/winnet/interfaceutils.h +++ b/windows/winnet/src/winnet/interfaceutils.h @@ -2,6 +2,17 @@ #include <string> #include <set> +#include <vector> + +// Secret include order to get most common networking structs/apis +// And avoiding compilation errors +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +#include <netioapi.h> +// end class InterfaceUtils { @@ -35,4 +46,6 @@ public: // Determines alias of primary TAP adapter. // static std::wstring GetTapInterfaceAlias(); + + static void AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses); }; diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp new file mode 100644 index 0000000000..55d7560904 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp @@ -0,0 +1,177 @@ +#include "stdafx.h" +#include <libcommon/error.h> +#include "defaultroutemonitor.h" +#include "helpers.h" + +namespace winnet::routing +{ + +namespace +{ + +const uint32_t POINT_TWO_SECOND_BURST = 200; +const uint32_t TWO_SECOND_INTERFERENCE = 2000; + +} // anonymous namespace + +DefaultRouteMonitor::DefaultRouteMonitor +( + ADDRESS_FAMILY family, + Callback callback, + std::shared_ptr<common::logging::ILogSink> logSink +) + : m_family(family) + , m_callback(callback) + , m_logSink(logSink) + , m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>( + std::bind(&DefaultRouteMonitor::evaluateRoutes, this), + POINT_TWO_SECOND_BURST, + TWO_SECOND_INTERFERENCE + )) +{ + try + { + m_bestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications"); + + try + { + const auto s2 = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this, + FALSE, &m_interfaceNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for network interface change notifications"); + } + catch (...) + { + CancelMibChangeNotify2(m_routeNotificationHandle); + throw; + } +} + +DefaultRouteMonitor::~DefaultRouteMonitor() +{ + // + // Cancel notifications to stop triggering the BurstGuard. + // + + CancelMibChangeNotify2(m_interfaceNotificationHandle); + CancelMibChangeNotify2(m_routeNotificationHandle); + + // + // Controlled destruction of BurstGuard to prevent it from calling here + // after other member variables have been destructed. + // + + m_evaluateRoutesGuard.reset(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback +( + void *context, + MIB_IPFORWARD_ROW2 *row, + MIB_NOTIFICATION_TYPE +) +{ + // + // We're only interested in changes that add/remove/update a default route. + // + + if (0 != row->DestinationPrefix.PrefixLength + || false == RouteHasGateway(*row)) + { + return; + } + + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback +( + void *context, + MIB_IPINTERFACE_ROW *, + MIB_NOTIFICATION_TYPE +) +{ + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +void DefaultRouteMonitor::evaluateRoutes() +{ + std::scoped_lock<std::mutex> lock(m_evaluationLock); + + try + { + evaluateRoutesInner(); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure while evaluating route table: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure while evaluating route table"); + } +} + +void DefaultRouteMonitor::evaluateRoutesInner() +{ + std::optional<InterfaceAndGateway> currentBestRoute; + + try + { + currentBestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + // + // If there was no default route previously. + // + + if (false == m_bestRoute.has_value()) + { + if (currentBestRoute.has_value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } + + return; + } + + // + // There used to be a default route. + // If there is not currently a default route. + // + + if (false == currentBestRoute.has_value()) + { + m_bestRoute.reset(); + m_callback(EventType::Removed, std::nullopt); + + return; + } + + // + // The current best route may have changed. + // + + if (m_bestRoute.value() != currentBestRoute.value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } +} + +} diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h new file mode 100644 index 0000000000..5575685a82 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h @@ -0,0 +1,69 @@ +#pragma once + +#include <ifdef.h> +#include <ws2def.h> +#include <functional> +#include <optional> +#include <memory> +#include <mutex> +#include <libcommon/logging/ilogsink.h> +#include <libcommon/burstguard.h> +#include "types.h" + +namespace winnet::routing +{ + +class DefaultRouteMonitor +{ +public: + + enum class EventType + { + // The best default route changed. + Updated, + + // No default routes exist. + Removed, + }; + + using Callback = std::function<void + ( + EventType eventType, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + DefaultRouteMonitor(ADDRESS_FAMILY family, Callback callback, std::shared_ptr<common::logging::ILogSink> logSink); + ~DefaultRouteMonitor(); + + DefaultRouteMonitor(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor(DefaultRouteMonitor &&) = delete; + DefaultRouteMonitor &operator=(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor &operator=(DefaultRouteMonitor &&) = delete; + +private: + + ADDRESS_FAMILY m_family; + Callback m_callback; + std::shared_ptr<common::logging::ILogSink> m_logSink; + + // This can't be a plain member variable. + // We need to be able to delete it explicitly in order to have a controlled tear down. + std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard; + + std::optional<InterfaceAndGateway> m_bestRoute; + + HANDLE m_routeNotificationHandle; + HANDLE m_interfaceNotificationHandle; + + std::mutex m_evaluationLock; + + static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); + static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType); + + void evaluateRoutes(); + void evaluateRoutesInner(); +}; + +} diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp new file mode 100644 index 0000000000..cabf19bce6 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.cpp @@ -0,0 +1,275 @@ +#include "stdafx.h" +#include "helpers.h" +#include <stdexcept> +#include <ws2def.h> +#include <in6addr.h> +#include <numeric> +//#include <netioapi.h> +#include <libcommon/error.h> +#include <libcommon/memory.h> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs) +{ + if (lhs.PrefixLength != rhs.PrefixLength) + { + return false; + } + + return EqualAddress(lhs.Prefix, rhs.Prefix); +} + +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs) +{ + if (lhs.si_family != rhs.si_family) + { + return false; + } + + switch (lhs.si_family) + { + case AF_INET: + { + return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR)); + } + default: + { + throw std::runtime_error("Invalid address family for network address"); + } + } +} + +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs) +{ + if (lhs->si_family != rhs->lpSockaddr->sa_family) + { + return false; + } + + switch (lhs->si_family) + { + case AF_INET: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr); + return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr; + } + case AF_INET6: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr); + return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16); + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface) +{ + memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW)); + + iface->Family = addressFamily; + iface->InterfaceLuid = adapter; + + return NO_ERROR == GetIpInterfaceEntry(iface); +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes) +{ + std::vector<AnnotatedRoute> annotated; + annotated.reserve(routes.size()); + + for (auto route : routes) + { + MIB_IPINTERFACE_ROW iface; + + if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface)) + { + continue; + } + + annotated.emplace_back + ( + AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric } + ); + } + + return annotated; +} + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route) +{ + switch (route.NextHop.si_family) + { + case AF_INET: + { + return 0 != route.NextHop.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + return 0 != std::accumulate(begin, end, 0); + } + default: + { + return false; + } + }; +} + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family) +{ + PMIB_IPFORWARD_TABLE2 table; + + auto status = GetIpForwardTable2(family, &table); + + THROW_UNLESS(NO_ERROR, status, "Acquire route table"); + + common::memory::ScopeDestructor sd; + + sd += [table] + { + FreeMibTable(table); + }; + + std::vector<const MIB_IPFORWARD_ROW2 *> candidates; + candidates.reserve(table->NumEntries); + + // + // Enumerate routes looking for: route 0/0 && gateway specified. + // + + for (ULONG i = 0; i < table->NumEntries; ++i) + { + const MIB_IPFORWARD_ROW2 &candidate = table->Table[i]; + + if (0 == candidate.DestinationPrefix.PrefixLength + && RouteHasGateway(candidate)) + { + candidates.emplace_back(&candidate); + } + } + + auto annotated = AnnotateRoutes(candidates); + + if (annotated.empty()) + { + throw std::runtime_error("Unable to determine details of default route"); + } + + // + // Sort on (active, effectiveMetric) ascending by metric. + // + + std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs) + { + if (lhs.active == rhs.active) + { + return lhs.effectiveMetric < rhs.effectiveMetric; + } + + return lhs.active && false == rhs.active; + }); + + // + // Ensure the top rated route is active. + // + + if (false == annotated[0].active) + { + throw std::runtime_error("Unable to identify active default route"); + } + + return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop }; +} + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family) +{ + switch (family) + { + case AF_INET: + { + return 0 != adapter->Ipv4Enabled; + } + case AF_INET6: + { + return 0 != adapter->Ipv6Enabled; + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +) +{ + std::vector<const SOCKET_ADDRESS *> matches; + + for (auto gateway = head; nullptr != gateway; gateway = gateway->Next) + { + if (family == gateway->Address.lpSockaddr->sa_family) + { + matches.emplace_back(&gateway->Address); + } + } + + return matches; +} + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle) +{ + for (const auto candidate : hay) + { + if (EqualAddress(needle, candidate)) + { + return true; + } + } + + return false; +} + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa) +//{ +// NodeAddress out = { 0 }; +// +// switch (sa->lpSockaddr->sa_family) +// { +// case AF_INET: +// { +// out.si_family = AF_INET; +// out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr); +// +// break; +// } +// case AF_INET6: +// { +// out.si_family = AF_INET6; +// out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr); +// +// break; +// } +// default: +// { +// throw std::runtime_error("Missing case handler in switch clause"); +// } +// }; +// +// return out; +//} + +} diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h new file mode 100644 index 0000000000..3ef5e85b75 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.h @@ -0,0 +1,46 @@ +#pragma once + +#include "types.h" +#include <vector> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs); +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs); +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs); + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface); + +struct AnnotatedRoute +{ + const MIB_IPFORWARD_ROW2 *route; + bool active; + uint32_t effectiveMetric; +}; + +template<typename T> +bool bool_cast(const T &value) +{ + return 0 != value; +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes); + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route); + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family); + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family); + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +); + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle); + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa); + +} diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp new file mode 100644 index 0000000000..668e64bb68 --- /dev/null +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -0,0 +1,692 @@ +#include "stdafx.h" +#include "routemanager.h" +#include "helpers.h" +#include <libcommon/error.h> +#include <libcommon/memory.h> +#include <libcommon/string.h> +#include <libcommon/network/adapters.h> +#include <vector> +#include <algorithm> +#include <numeric> +#include <sstream> +#include <stdexcept> + +using AutoLockType = std::scoped_lock<std::mutex>; +using AutoRecursiveLockType = std::scoped_lock<std::recursive_mutex>; +using namespace std::placeholders; + +namespace winnet::routing +{ + +namespace +{ + +using Adapters = common::network::Adapters; + +NET_LUID InterfaceLuidFromGateway(const NodeAddress &gateway) +{ + const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER + | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS; + + Adapters adapters(gateway.si_family, adapterFlags); + + // + // Process adapters to find matching ones. + // + + std::vector<const IP_ADAPTER_ADDRESSES *> matches; + + for (auto adapter = adapters.next(); nullptr != adapter; adapter = adapters.next()) + { + if (false == AdapterInterfaceEnabled(adapter, gateway.si_family)) + { + continue; + } + + auto gateways = IsolateGatewayAddresses(adapter->FirstGatewayAddress, gateway.si_family); + + if (AddressPresent(gateways, &gateway)) + { + matches.emplace_back(adapter); + } + } + + if (matches.empty()) + { + throw std::runtime_error("Unable to find network adapter with specified gateway"); + } + + // + // Sort matching interfaces ascending by metric. + // + + const bool targetV4 = (AF_INET == gateway.si_family); + + std::sort(matches.begin(), matches.end(), [&targetV4](const IP_ADAPTER_ADDRESSES *lhs, const IP_ADAPTER_ADDRESSES *rhs) + { + if (targetV4) + { + return lhs->Ipv4Metric < rhs->Ipv4Metric; + } + + return lhs->Ipv6Metric < rhs->Ipv6Metric; + }); + + // + // Select the interface with the best (lowest) metric. + // + + return matches[0]->Luid; +} + +bool ParseStringEncodedLuid(const std::wstring &encodedLuid, NET_LUID &luid) +{ + // + // The `#` is a valid character in adapter names so we use `?` instead. + // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes. + // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`. + // + + static const size_t StringEncodedLuidLength = 17; + + if (encodedLuid.size() != StringEncodedLuidLength + || L'?' != encodedLuid[0]) + { + return false; + } + + try + { + std::wstringstream ss; + + ss << std::hex << &encodedLuid[1]; + ss >> luid.Value; + } + catch (...) + { + const auto ansi = common::string::ToAnsi(encodedLuid); + const auto err = std::string("Failed to parse string encoded LUID: ").append(ansi); + + std::throw_with_nested(std::runtime_error(err)); + } + + return true; +} + +InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node> &optionalNode) +{ + // + // There are four cases: + // + // Unspecified node (use interface and gateway of default route). + // Node is specified by name. + // Node is specified by name and gateway. + // Node is specified by gateway. + // + + if (false == optionalNode.has_value()) + { + return GetBestDefaultRoute(family); + } + + const auto &node = optionalNode.value(); + + if (node.deviceName().has_value()) + { + const auto &deviceName = node.deviceName().value(); + NET_LUID luid; + + if (false == ParseStringEncodedLuid(deviceName, luid) + && 0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid)) + { + const auto ansiName = common::string::ToAnsi(deviceName); + const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName); + + throw std::runtime_error(err); + } + + auto onLinkProvider = [&family]() + { + NodeAddress onLink = { 0 }; + onLink.si_family = family; + + return onLink; + }; + + return InterfaceAndGateway{ luid, node.gateway().value_or(onLinkProvider()) }; + } + + // + // The node is specified only by gateway. + // + + return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() }; +} + +// TODO: Move to libcommon +uint32_t ByteSwap(uint32_t val) +{ + return + ( + ((val & 0xFF) << 24) | + ((val & 0xFF00) << 8) | + ((val & 0xFF0000) >> 8) | + ((val & 0xFF000000) >> 24) + ); +} + +std::wstring FormatNetwork(const Network &network) +{ + switch (network.Prefix.si_family) + { + case AF_INET: + { + return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength); + } + case AF_INET6: + { + return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength); + } + default: + { + return L"Failed to format network details"; + } + } +} + +} // anonymous namespace + +RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink) + : m_logSink(logSink) + , m_routeMonitorV4(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET), _1, _2), + logSink + )) + , m_routeMonitorV6(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET6), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET6), _1, _2), + logSink + )) +{ +} + +RouteManager::~RouteManager() +{ + // + // Stop callbacks that are triggered by events in Windows from coming in. + // + + m_routeMonitorV4.reset(); + m_routeMonitorV6.reset(); + + // + // Delete all routes owned by us. + // + + for (const auto &record : m_routes) + { + try + { + deleteFromRoutingTable(record.registeredRoute); + } + catch (const std::exception &ex) + { + std::wstringstream ss; + + ss << L"Failed to delete route as part of cleaning up, Route: " + << FormatRegisteredRoute(record.registeredRoute); + + m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); + m_logSink->error(ex.what()); + } + } +} + +void RouteManager::addRoutes(const std::vector<Route> &routes) +{ + AutoLockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + + const RouteRecord newRecord { route, addIntoRoutingTable(route) }; + + eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord }); + m_routes.emplace_back(std::move(newRecord)); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch insertion of routes")); + } + } +} + +void RouteManager::addRoute(const Route &route) +{ + AutoLockType lock(m_routesLock); + + std::optional<RouteRecord> deletedRecord; + + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + try + { + deleteFromRoutingTable(record->registeredRoute); + } + catch (...) + { + std::throw_with_nested(std::runtime_error("Failed to evict old route when adding new route")); + } + + deletedRecord = *record; + m_routes.erase(record); + } + + try + { + m_routes.emplace_back + ( + RouteRecord{ route, addIntoRoutingTable(route) } + ); + } + catch (...) + { + // + // Restore deleted record. + // + + if (deletedRecord.has_value()) + { + auto &r = deletedRecord.value(); + + try + { + restoreIntoRoutingTable(r.registeredRoute); + m_routes.emplace_back(r); + } + catch (const std::exception &ex) + { + const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } + + // + // Just rethrow because the error is from addIntoRoutingTable(). + // + + throw; + } +} + +void RouteManager::deleteRoutes(const std::vector<Route> &routes) +{ + AutoLockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + continue; + } + + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch removal of routes")); + } + } +} + +void RouteManager::deleteRoute(const Route &route) +{ + AutoLockType lock(m_routesLock); + + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + return; + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); +} + +RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback) +{ + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); + + m_defaultRouteCallbacks.emplace_back(callback); + + // Return raw address of record in list. + return &m_defaultRouteCallbacks.back(); +} + +void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle) +{ + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); + + for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it) + { + // Match on raw address of record. + if (&*it == handle) + { + m_defaultRouteCallbacks.erase(it); + return; + } + } +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network) +{ + return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate) + { + return EqualAddress(network, candidate.route.network()); + }); +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route) +{ + return findRouteRecord(route.network()); +} + +RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route) +{ + const auto node = ResolveNode(route.network().Prefix.si_family, route.node()); + + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = node.iface; + spec.DestinationPrefix = route.network(); + spec.NextHop = node.gateway; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + // + // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful. + // Because it may not take route metric into consideration. + // + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); + + return RegisteredRoute { route.network(), node.iface, node.gateway }; +} + +void RouteManager::restoreIntoRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = route.luid; + spec.DestinationPrefix = route.network; + spec.NextHop = route.nextHop; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); +} + +void RouteManager::deleteFromRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 r = { 0}; + + r.InterfaceLuid = route.luid; + r.DestinationPrefix = route.network; + r.NextHop = route.nextHop; + + auto status = DeleteIpForwardEntry2(&r); + + if (ERROR_NOT_FOUND == status) + { + status = NO_ERROR; + + const auto err = std::wstring(L"Attempting to delete route which was not present in routing table, " \ + "ignoring and proceeding. Route: ").append(FormatRegisteredRoute(route)); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + } + + THROW_UNLESS(NO_ERROR, status, "Delete route in routing table"); +} + +void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) +{ + // + // Rewind state by processing events in the reverse order. + // + + for (auto it = eventLog.rbegin(); it != eventLog.rend(); ++it) + { + try + { + switch (it->type) + { + case EventType::ADD_ROUTE: + { + auto record = findRouteRecord(it->record.route); + + if (m_routes.end() == record) + { + throw std::runtime_error("Internal state inconsistency in route manager"); + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); + + break; + } + case EventType::DELETE_ROUTE: + { + restoreIntoRoutingTable(it->record.registeredRoute); + m_routes.emplace_back(it->record); + + break; + } + default: + { + throw std::logic_error("Missing case handler in switch clause"); + } + } + } + catch (const std::exception &ex) + { + const auto err = std::string("Attempting to rollback state: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } +} + +// static +std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route) +{ + // + // TODO: Fix broken IP formatting + // Update FormatIpv4 function with an additional argument to specify network/host byte order. + // + + std::wstringstream ss; + + if (AF_INET == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + if (0 != route.nextHop.Ipv4.sin_addr.s_addr) + { + gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr)); + } + + ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else if (AF_INET6 == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + if (0 != std::accumulate(begin, end, 0)) + { + gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte); + } + + ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else + { + ss << L"Failed to format route details"; + } + + return ss.str(); +} + +void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route) +{ + // + // Forward event to all registered listeners. + // + + m_defaultRouteCallbacksLock.lock(); + + for (const auto &callback : m_defaultRouteCallbacks) + { + try + { + callback(eventType, family, route); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure in default-route-changed callback: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure in default-route-changed callback"); + } + } + + m_defaultRouteCallbacksLock.unlock(); + + // + // Examine event to determine if best default route has changed. + // + + if (DefaultRouteMonitor::EventType::Updated != eventType) + { + return; + } + + // + // Examine our routes to see if any of them are policy bound to the best default route. + // + + AutoLockType routesLock(m_routesLock); + + using RecordIterator = std::list<RouteRecord>::iterator; + + std::list<RecordIterator> affectedRoutes; + + for (RecordIterator it = m_routes.begin(); it != m_routes.end(); ++it) + { + if (false == it->route.node().has_value() + && family == it->route.network().Prefix.si_family) + { + affectedRoutes.emplace_back(it); + } + } + + if (affectedRoutes.empty()) + { + return; + } + + // + // Update all affected routes. + // + + m_logSink->info("Best default route has changed. Refreshing dependent routes"); + + for (auto &it : affectedRoutes) + { + try + { + deleteFromRoutingTable(it->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to delete route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + + it->registeredRoute.luid = route.value().iface; + it->registeredRoute.nextHop = route.value().gateway; + + try + { + restoreIntoRoutingTable(it->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to add route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + } +} + +} diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h new file mode 100644 index 0000000000..981c8e6834 --- /dev/null +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -0,0 +1,112 @@ +#pragma once + +#include <string> +#include <memory> +#include <vector> +#include <list> +#include <optional> +#include <mutex> +#include <functional> +#include <windows.h> +#include <ws2def.h> +#include <ifdef.h> +#include <libcommon/string.h> +#include <libcommon/logging/ilogsink.h> +#include "defaultroutemonitor.h" + +namespace winnet::routing +{ + +class RouteManager +{ +public: + + RouteManager(std::shared_ptr<common::logging::ILogSink> logSink); + ~RouteManager(); + + RouteManager(const RouteManager &) = delete; + RouteManager(RouteManager &&) = default; + RouteManager &operator=(const RouteManager &) = delete; + RouteManager &operator=(RouteManager &&) = delete; + + void addRoutes(const std::vector<Route> &routes); + void addRoute(const Route &route); + + void deleteRoutes(const std::vector<Route> &routes); + void deleteRoute(const Route &route); + + using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType; + + using DefaultRouteChangedCallback = std::function<void + ( + DefaultRouteChangedEventType eventType, + ADDRESS_FAMILY family, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + using CallbackHandle = void*; + + CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback); + void unregisterDefaultRouteChangedCallback(CallbackHandle handle); + +private: + + std::shared_ptr<common::logging::ILogSink> m_logSink; + + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV4; + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV6; + + // These are the exact details derived from the route specification (`Route`). + // They are used when registering and deleting a route in the system. + struct RegisteredRoute + { + Network network; + NET_LUID luid; + NodeAddress nextHop; + }; + + struct RouteRecord + { + Route route; + RegisteredRoute registeredRoute; + }; + + std::list<RouteRecord> m_routes; + std::mutex m_routesLock; + + std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks; + std::recursive_mutex m_defaultRouteCallbacksLock; + + // Find record based on destination and mask. + std::list<RouteRecord>::iterator findRouteRecord(const Network &network); + + // Note: Same as above! + std::list<RouteRecord>::iterator findRouteRecord(const Route &route); + + RegisteredRoute addIntoRoutingTable(const Route &route); + void restoreIntoRoutingTable(const RegisteredRoute &route); + void deleteFromRoutingTable(const RegisteredRoute &route); + + enum class EventType + { + ADD_ROUTE, + DELETE_ROUTE, + }; + + struct EventEntry + { + EventType type; + RouteRecord record; + }; + + void undoEvents(const std::vector<EventEntry> &eventLog); + + static std::wstring FormatRegisteredRoute(const RegisteredRoute &route); + + void defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route); +}; + +} diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp new file mode 100644 index 0000000000..ac71c8108f --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.cpp @@ -0,0 +1,84 @@ +#include "stdafx.h" +#include "types.h" +#include "helpers.h" +#include <libcommon/string.h> + +namespace winnet::routing +{ + +Node::Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway) + : m_deviceName(deviceName) + , m_gateway(gateway) +{ + if (false == m_deviceName.has_value() && false == m_gateway.has_value()) + { + throw std::runtime_error("Invalid node definition"); + } + + if (m_deviceName.has_value()) + { + const auto trimmed = common::string::Trim<>(m_deviceName.value()); + + if (trimmed.empty()) + { + throw std::runtime_error("Invalid device name in node definition"); + } + + m_deviceName = std::move(trimmed); + } +} + +bool Node::operator==(const Node &rhs) const +{ + if (m_deviceName.has_value()) + { + if (false == rhs.m_deviceName.has_value() + || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str())) + { + return false; + } + } + + if (m_gateway.has_value()) + { + if (false == rhs.m_gateway.has_value() + || false == EqualAddress(m_gateway.value(), rhs.gateway().value())) + { + return false; + } + } + + return true; +} + +Route::Route(const Network &network, const std::optional<Node> &node) + : m_network(network) + , m_node(node) +{ +} + +bool Route::operator==(const Route &rhs) const +{ + if (m_node.has_value()) + { + return rhs.node().has_value() + && EqualAddress(m_network, rhs.network()) + && m_node.value() == rhs.node().value(); + } + + return false == rhs.node().has_value() + && EqualAddress(m_network, rhs.network()); +} + +bool InterfaceAndGateway::operator==(const InterfaceAndGateway &rhs) +{ + return iface.Value == rhs.iface.Value + && EqualAddress(gateway, rhs.gateway); +} + +bool InterfaceAndGateway::operator!=(const InterfaceAndGateway &rhs) +{ + return !(*this == rhs); +} + +} diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h new file mode 100644 index 0000000000..1e132feb00 --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.h @@ -0,0 +1,77 @@ +#pragma once + +#include <string> +#include <optional> +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +//#include <netioapi.h> +//#include <functional> + + +namespace winnet::routing +{ + +using Network = IP_ADDRESS_PREFIX; +using NodeAddress = SOCKADDR_INET; + +class Node +{ +public: + + Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway); + + const std::optional<std::wstring> &deviceName() const + { + return m_deviceName; + } + + const std::optional<NodeAddress> &gateway() const + { + return m_gateway; + } + + bool operator==(const Node &rhs) const; + +private: + + std::optional<std::wstring> m_deviceName; + std::optional<NodeAddress> m_gateway; +}; + +class Route +{ +public: + + Route(const Network &network, const std::optional<Node> &node); + + const Network &network() const + { + return m_network; + } + + const std::optional<Node> &node() const + { + return m_node; + } + + bool operator==(const Route &rhs) const; + +private: + + Network m_network; + std::optional<Node> m_node; +}; + +struct InterfaceAndGateway +{ + NET_LUID iface; + NodeAddress gateway; + + bool operator==(const InterfaceAndGateway &rhs); + bool operator!=(const InterfaceAndGateway &rhs); +}; + +} diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index 4b006964a6..48d12b5ea3 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -3,17 +3,135 @@ #include "NetworkInterfaces.h"
#include "interfaceutils.h"
#include "offlinemonitor.h"
+#include "routing/routemanager.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
+#include <libcommon/network.h>
#include <cstdint>
#include <stdexcept>
#include <memory>
+#include <optional>
+#include <mutex>
+
+using namespace winnet::routing;
+using AutoLockType = std::scoped_lock<std::mutex>;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+std::mutex g_RouteManagerLock;
+RouteManager *g_RouteManager = nullptr;
+std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
+
+Network ConvertNetwork(const WINNET_IPNETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out{ 0 };
+
+ out.PrefixLength = in.prefix;
+
+ switch (in.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ out.Prefix.si_family = AF_INET;
+ out.Prefix.Ipv4.sin_family = AF_INET;
+ out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ out.Prefix.si_family = AF_INET6;
+ out.Prefix.Ipv6.sin6_family = AF_INET6;
+ memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return {};
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ throw std::runtime_error("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ NodeAddress gw { 0 };
+
+ switch (in->gateway->type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ gw.si_family = AF_INET;
+ gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ gw.si_family = AF_INET6;
+ memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid gateway type specifier in 'WINNET_NODE' definition");
+ }
+ }
+
+ gateway = gw;
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::exception &err)
{
if (nullptr == logSink)
@@ -26,6 +144,49 @@ void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::excep common::error::UnwindException(err, logger);
}
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
+{
+ //
+ // This duplicates the same logic we have above.
+ // TODO: Fix when time permits.
+ //
+
+ std::vector<SOCKADDR_INET> out;
+ out.reserve(numAddresses);
+
+ for (uint32_t i = 0; i < numAddresses; ++i)
+ {
+ const WINNET_IP &from = addresses[i];
+ SOCKADDR_INET to{ 0 };
+
+ switch (from.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ to.si_family = AF_INET;
+ to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(from.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ to.si_family = AF_INET6;
+ memcpy(&to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid address family in 'WINNET_IP' definition");
+ }
+ }
+
+ out.push_back(to);
+ }
+
+ return out;
+}
+
} //anonymous namespace
extern "C"
@@ -66,12 +227,12 @@ WinNet_GetTapInterfaceIpv6Status( {
try
{
- MIB_IPINTERFACE_ROW interface = { 0 };
+ MIB_IPINTERFACE_ROW iface = { 0 };
- interface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
- interface.Family = AF_INET6;
+ iface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
+ iface.Family = AF_INET6;
- const auto status = GetIpInterfaceEntry(&interface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -201,3 +362,360 @@ WinNet_DeactivateConnectivityMonitor( {
}
}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ if (nullptr != g_RouteManager)
+ {
+ throw std::runtime_error("Cannot activate route manager twice");
+ }
+
+ g_RouteManagerLogSink = std::make_shared<shared::LogSinkAdapter>(logSink, logSinkContext);
+ g_RouteManager = new RouteManager(g_RouteManagerLogSink);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+//
+// TODO: Move to libcommon.
+//
+struct ValueMapper
+{
+ template<typename T, typename U, std::size_t S>
+ static U map(T t, const std::pair<T, U> (&dictionary)[S])
+ {
+ for (const auto &entry : dictionary)
+ {
+ if (t == entry.first)
+ {
+ return entry.second;
+ }
+ }
+
+ throw std::runtime_error("Could not map between values");
+ }
+};
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ auto forwarder = [callback, context](RouteManager::DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family, const std::optional<InterfaceAndGateway> &route)
+ {
+ //
+ // Translate the event type.
+ //
+
+ using from_t = RouteManager::DefaultRouteChangedEventType;
+ using to_t = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE;
+
+ static const std::pair<from_t, to_t> eventTypeMap[] =
+ {
+ { from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
+ };
+
+ const auto translatedEventType = ValueMapper::map<>(eventType, eventTypeMap);
+
+ //
+ // Translate the family type.
+ //
+
+ static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
+ {
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ };
+
+ const auto translatedFamily = ValueMapper::map<>(family, familyMap);
+
+ //
+ // Determine which LUID to forward.
+ //
+
+ uint64_t translatedLuid = 0;
+
+ if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ {
+ translatedLuid = route.value().iface.Value;
+ }
+
+ //
+ // Forward to client.
+ //
+
+ callback(translatedEventType, translatedFamily, translatedLuid, context);
+ };
+
+ *registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return;
+ }
+
+ try
+ {
+ g_RouteManager->unregisterDefaultRouteChangedCallback(registrationHandle);
+ }
+ catch (const std::exception &err)
+ {
+ g_RouteManagerLogSink->error("Failed to unregister default-route-changed callback");
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ delete g_RouteManager;
+ g_RouteManager = nullptr;
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ NET_LUID luid;
+
+ if (0 != ConvertInterfaceAliasToLuid(deviceAlias, &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceAlias);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ InterfaceUtils::AddDeviceIpAddresses(luid, ConvertAddresses(addresses, numAddresses));
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def index 04c3f22ee3..b23ae6c854 100644 --- a/windows/winnet/src/winnet/winnet.def +++ b/windows/winnet/src/winnet/winnet.def @@ -6,3 +6,6 @@ EXPORTS WinNet_ReleaseString WinNet_ActivateConnectivityMonitor WinNet_DeactivateConnectivityMonitor + WinNet_ActivateRouteManager + WinNet_DeactivateRouteManager + WinNet_AddDeviceIpAddresses diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 9b1af52e36..c7a161c3d8 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -1,6 +1,7 @@ #pragma once #include "../../shared/logsink.h" +#include <stdint.h> #include <stdbool.h> #ifndef WINNET_STATIC @@ -89,3 +90,147 @@ void WINNET_API WinNet_DeactivateConnectivityMonitor( ); + +enum WINNET_IP_TYPE +{ + WINNET_IP_TYPE_IPV4 = 0, + WINNET_IP_TYPE_IPV6 = 1, +}; + +typedef struct tag_WINNET_IPNETWORK +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. + uint8_t prefix; +} +WINNET_IPNETWORK; + +typedef struct tag_WINNET_IP +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. +} +WINNET_IP; + +typedef struct tag_WINNET_NODE +{ + const WINNET_IP *gateway; + const wchar_t *deviceName; +} +WINNET_NODE; + +typedef struct tag_WINNET_ROUTE +{ + WINNET_IPNETWORK network; + const WINNET_NODE *node; +} +WINNET_ROUTE; + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_ActivateRouteManager( + MullvadLogSink logSink, + void *logSinkContext +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoute( + const WINNET_ROUTE *route +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoute( + const WINNET_ROUTE *route +); + +enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE +{ + // Best default route changed. + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0, + + // No default routes exist. + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1, +}; + +enum WINNET_IP_FAMILY +{ + WINNET_IP_FAMILY_V4 = 0, + WINNET_IP_FAMILY_V6 = 1, +}; + +typedef void (WINNET_API *WinNetDefaultRouteChangedCallback) +( + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType, + + // Signals which IP family the event relates to. + WINNET_IP_FAMILY family, + + // For update events, signals the interface associated with the new best default route. + uint64_t interfaceLuid, + + void *context +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_RegisterDefaultRouteChangedCallback( + WinNetDefaultRouteChangedCallback callback, + void *context, + void **registrationHandle +); + +extern "C" +WINNET_LINKAGE +void +WINNET_API +WinNet_UnregisterDefaultRouteChangedCallback( + void *registrationHandle +); + +extern "C" +WINNET_LINKAGE +void +WINNET_API +WinNet_DeactivateRouteManager( +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddDeviceIpAddresses( + const wchar_t *deviceAlias, + const WINNET_IP *addresses, + uint32_t numAddresses, + MullvadLogSink logSink, + void *logSinkContext +); + diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj index 192320daaf..5e71a1f733 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj +++ b/windows/winnet/src/winnet/winnet.vcxproj @@ -33,6 +33,10 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> <ClCompile Include="NetworkInterfaces.cpp" /> + <ClCompile Include="routing\defaultroutemonitor.cpp" /> + <ClCompile Include="routing\helpers.cpp" /> + <ClCompile Include="routing\routemanager.cpp" /> + <ClCompile Include="routing\types.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="winnet.cpp" /> </ItemGroup> @@ -42,6 +46,10 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="offlinemonitor.h" /> <ClInclude Include="NetworkInterfaces.h" /> + <ClInclude Include="routing\defaultroutemonitor.h" /> + <ClInclude Include="routing\helpers.h" /> + <ClInclude Include="routing\routemanager.h" /> + <ClInclude Include="routing\types.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="winnet.h" /> @@ -208,7 +216,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> @@ -278,7 +286,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreaded</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters index 9a901d3203..dfe6d29ec7 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj.filters +++ b/windows/winnet/src/winnet/winnet.vcxproj.filters @@ -9,6 +9,18 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="networkadaptermonitor.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> + <ClCompile Include="routing\types.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\helpers.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\defaultroutemonitor.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\routemanager.cpp"> + <Filter>routing</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -19,6 +31,18 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="networkadaptermonitor.h" /> <ClInclude Include="offlinemonitor.h" /> + <ClInclude Include="routing\types.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\helpers.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\defaultroutemonitor.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\routemanager.h"> + <Filter>routing</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <None Include="winnet.def" /> @@ -26,4 +50,9 @@ <ItemGroup> <ResourceCompile Include="winnet.rc" /> </ItemGroup> + <ItemGroup> + <Filter Include="routing"> + <UniqueIdentifier>{8df22cc6-597f-4342-bc57-7647393084be}</UniqueIdentifier> + </Filter> + </ItemGroup> </Project>
\ No newline at end of file |
