summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2019-11-25 14:23:36 +0100
committerOdd Stranne <odd@mullvad.net>2019-11-25 14:23:36 +0100
commit67a86af237d3305f84bb7044aa1bdf5e1122e81f (patch)
treef910fc098a52d5466de50224d8cd66c203aac4fa
parentdc6d5d8e87738f919b8f924b3381a1954097138b (diff)
parente4f46afbe027cb8ddbb45a66ce014af9acbc54b6 (diff)
downloadmullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.tar.xz
mullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.zip
Merge branch 'win-wireguard'
-rw-r--r--CHANGELOG.md1
-rw-r--r--Cargo.lock1
m---------dist-assets/binaries0
-rw-r--r--gui/src/renderer/components/AdvancedSettings.tsx47
-rw-r--r--gui/src/renderer/containers/AdvancedSettingsPage.tsx2
-rw-r--r--gui/tasks/distribution.js1
-rw-r--r--mullvad-daemon/src/lib.rs5
-rw-r--r--talpid-core/Cargo.toml9
-rw-r--r--talpid-core/build.rs3
-rw-r--r--talpid-core/src/dns/windows/mod.rs2
-rw-r--r--talpid-core/src/firewall/windows.rs48
-rw-r--r--talpid-core/src/lib.rs1
-rw-r--r--talpid-core/src/ping_monitor/win.rs126
-rw-r--r--talpid-core/src/routing/mod.rs38
-rw-r--r--talpid-core/src/routing/windows.rs72
-rw-r--r--talpid-core/src/tunnel/mod.rs34
-rw-r--r--talpid-core/src/tunnel/tun_provider/mod.rs4
-rw-r--r--talpid-core/src/tunnel/tun_provider/windows.rs13
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs20
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs210
-rw-r--r--talpid-core/src/winnet.rs303
-rw-r--r--windows/winfw/src/winfw/fwcontext.cpp25
-rw-r--r--windows/winfw/src/winfw/fwcontext.h16
-rw-r--r--windows/winfw/src/winfw/mullvadguids.cpp30
-rw-r--r--windows/winfw/src/winfw/mullvadguids.h3
-rw-r--r--windows/winfw/src/winfw/rules/permitping.cpp98
-rw-r--r--windows/winfw/src/winfw/rules/permitping.h28
-rw-r--r--windows/winfw/src/winfw/winfw.cpp34
-rw-r--r--windows/winfw/src/winfw/winfw.h21
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj2
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj.filters6
-rw-r--r--windows/winnet/src/extras/loader/loader.vcxproj.filters4
-rw-r--r--windows/winnet/src/winnet/interfaceutils.cpp20
-rw-r--r--windows/winnet/src/winnet/interfaceutils.h13
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.cpp177
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.h69
-rw-r--r--windows/winnet/src/winnet/routing/helpers.cpp275
-rw-r--r--windows/winnet/src/winnet/routing/helpers.h46
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.cpp692
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.h112
-rw-r--r--windows/winnet/src/winnet/routing/types.cpp84
-rw-r--r--windows/winnet/src/winnet/routing/types.h77
-rw-r--r--windows/winnet/src/winnet/winnet.cpp526
-rw-r--r--windows/winnet/src/winnet/winnet.def3
-rw-r--r--windows/winnet/src/winnet/winnet.h145
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj12
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj.filters29
47 files changed, 3372 insertions, 115 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7884f211cf..879b58a49b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ Line wrap the file at 100 chars. Th
## [Unreleased]
### Added
#### Windows
+- Full WireGuard support, GUI and CLI.
- Install Wintun driver that provides the WireGuard TUN adapter.
- Remove Mullvad TAP adapter on uninstall. Also remove the TAP driver if there are no other TAP
adapters in the system.
diff --git a/Cargo.lock b/Cargo.lock
index 5c6db6f0b8..b23c1d31c8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2423,6 +2423,7 @@ version = "0.1.0"
dependencies = [
"atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
+ "chrono 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)",
"dbus 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
"duct 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)",
"err-derive 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/dist-assets/binaries b/dist-assets/binaries
-Subproject 66091e4249f8afcbf3daf4ffb01bb05bf8d64d0
+Subproject fe1f86af8f3b99eed99b1299b5d4ca15f20ebab
diff --git a/gui/src/renderer/components/AdvancedSettings.tsx b/gui/src/renderer/components/AdvancedSettings.tsx
index 334da211a8..dd9568bed5 100644
--- a/gui/src/renderer/components/AdvancedSettings.tsx
+++ b/gui/src/renderer/components/AdvancedSettings.tsx
@@ -43,7 +43,6 @@ interface IProps {
wireguard: { port?: number };
mssfix?: number;
bridgeState: BridgeState;
- enableWireguardKeysPage: boolean;
setBridgeState: (value: BridgeState) => void;
setEnableIpv6: (value: boolean) => void;
setBlockWhenDisconnected: (value: boolean) => void;
@@ -226,18 +225,14 @@ export default class AdvancedSettings extends Component<IProps, IState> {
)}
</Cell.Footer>
- {process.platform !== 'win32' ? (
- <View style={styles.advanced_settings__content}>
- <Selector
- title={messages.pgettext('advanced-settings-view', 'Tunnel protocol')}
- values={this.tunnelProtocolItems}
- value={this.props.tunnelProtocol}
- onSelect={this.onSelectTunnelProtocol}
- />
- </View>
- ) : (
- undefined
- )}
+ <View style={styles.advanced_settings__content}>
+ <Selector
+ title={messages.pgettext('advanced-settings-view', 'Tunnel protocol')}
+ values={this.tunnelProtocolItems}
+ value={this.props.tunnelProtocol}
+ onSelect={this.onSelectTunnelProtocol}
+ />
+ </View>
{this.props.tunnelProtocol !== 'wireguard' ? (
<View style={styles.advanced_settings__content}>
@@ -277,7 +272,7 @@ export default class AdvancedSettings extends Component<IProps, IState> {
undefined
)}
- {this.props.tunnelProtocol === 'wireguard' && process.platform !== 'win32' ? (
+ {this.props.tunnelProtocol === 'wireguard' ? (
<View style={styles.advanced_settings__content}>
<Selector
// TRANSLATORS: The title for the shadowsocks bridge selector section.
@@ -336,7 +331,14 @@ export default class AdvancedSettings extends Component<IProps, IState> {
)}
</Cell.FooterText>
</Cell.Footer>
- {this.wireguardKeysButton()}
+ <View style={styles.advanced_settings__wgkeys_cell}>
+ <Cell.CellButton onPress={this.props.onViewWireguardKeys}>
+ <Cell.Label>
+ {messages.pgettext('advanced-settings-view', 'WireGuard key')}
+ </Cell.Label>
+ <Cell.Icon height={12} width={7} source="icon-chevron" />
+ </Cell.CellButton>
+ </View>
</NavigationScrollbars>
</View>
</NavigationContainer>
@@ -346,21 +348,6 @@ export default class AdvancedSettings extends Component<IProps, IState> {
);
}
- private wireguardKeysButton() {
- if (this.props.enableWireguardKeysPage) {
- return (
- <View style={styles.advanced_settings__wgkeys_cell}>
- <Cell.CellButton onPress={this.props.onViewWireguardKeys}>
- <Cell.Label>{messages.pgettext('advanced-settings-view', 'WireGuard key')}</Cell.Label>
- <Cell.Icon height={12} width={7} source="icon-chevron" />
- </Cell.CellButton>
- </View>
- );
- } else {
- return null;
- }
- }
-
private onSelectTunnelProtocol = (protocol?: TunnelProtocol) => {
this.props.setTunnelProtocol(protocol);
};
diff --git a/gui/src/renderer/containers/AdvancedSettingsPage.tsx b/gui/src/renderer/containers/AdvancedSettingsPage.tsx
index e31365b86d..fad5704a97 100644
--- a/gui/src/renderer/containers/AdvancedSettingsPage.tsx
+++ b/gui/src/renderer/containers/AdvancedSettingsPage.tsx
@@ -12,14 +12,12 @@ import { IReduxState, ReduxDispatch } from '../redux/store';
const mapStateToProps = (state: IReduxState) => {
const protocolAndPort = mapRelaySettingsToProtocolAndPort(state.settings.relaySettings);
- const enableWireguardKeysPage = process.platform === 'linux' || process.platform === 'darwin';
return {
enableIpv6: state.settings.enableIpv6,
blockWhenDisconnected: state.settings.blockWhenDisconnected,
mssfix: state.settings.openVpn.mssfix,
bridgeState: state.settings.bridgeState,
- enableWireguardKeysPage,
...protocolAndPort,
};
};
diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js
index be262cf80e..7ae4dc3296 100644
--- a/gui/tasks/distribution.js
+++ b/gui/tasks/distribution.js
@@ -93,6 +93,7 @@ const config = {
{ from: root('windows/winutil/bin/x64-Release/winutil.dll'), to: '.' },
{ from: distAssets('binaries/x86_64-pc-windows-msvc/openvpn.exe'), to: '.' },
{ from: distAssets('binaries/x86_64-pc-windows-msvc/sslocal.exe'), to: '.' },
+ { from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard/libwg.dll'), to: '.' },
],
},
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 6081f0fe25..0a2c65cb3b 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -1319,12 +1319,7 @@ where
}
}
- #[cfg_attr(target_os = "windows", allow(unreachable_code))]
fn ensure_wireguard_keys_for_current_account(&mut self) {
- #[cfg(target_os = "windows")]
- {
- return;
- }
if let Some(account) = self.settings.get_account_token() {
if self
.account_history
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index a91c71d843..1df24dd7c4 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -12,9 +12,11 @@ cfg-if = "0.1"
duct = "0.13"
err-derive = "0.2.1"
futures = "0.1"
+hex = "0.4"
ipnetwork = "0.15"
jsonrpc-core = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" }
jsonrpc-macros = { git = "https://github.com/mullvad/jsonrpc", branch = "mullvad-fork" }
+lazy_static = "1.0"
libc = "0.2.20"
log = "0.4"
openvpn-plugin = { git = "https://github.com/mullvad/openvpn-plugin-rs", branch = "auth-failed-event", features = ["serde"] }
@@ -25,15 +27,13 @@ shell-escape = "0.1"
talpid-ipc = { path = "../talpid-ipc" }
talpid-types = { path = "../talpid-types" }
tokio-core = "0.1"
+tokio-executor = "0.1"
uuid = { version = "0.7", features = ["v4"] }
[target.'cfg(unix)'.dependencies]
-hex = "0.4"
-lazy_static = "1.0"
nix = "0.15"
tokio-process = "0.2"
-tokio-executor = "0.1"
tokio-io = "0.1"
@@ -63,9 +63,10 @@ tun = "0.4.3"
[target.'cfg(windows)'.dependencies]
+chrono = "0.4"
widestring = "0.4"
winreg = "0.6"
-winapi = { version = "0.3.6", features = ["handleapi", "libloaderapi", "synchapi", "winbase", "winuser"] }
+winapi = { version = "0.3.6", features = ["handleapi", "ifdef", "libloaderapi", "netioapi", "synchapi", "winbase", "winuser"] }
socket2 = "0.3"
rand = "0.7"
pnet_packet = "0.22"
diff --git a/talpid-core/build.rs b/talpid-core/build.rs
index 12be7fcb6c..5a63a1f064 100644
--- a/talpid-core/build.rs
+++ b/talpid-core/build.rs
@@ -53,6 +53,9 @@ fn main() {
declare_library(WINFW_DIR_VAR, WINFW_BUILD_DIR, "winfw");
declare_library(WINDNS_DIR_VAR, WINDNS_BUILD_DIR, "windns");
declare_library(WINNET_DIR_VAR, WINNET_BUILD_DIR, "winnet");
+ let lib_dir = manifest_dir().join("../dist-assets/binaries/x86_64-pc-windows-msvc/wireguard");
+ println!("cargo:rustc-link-search={}", &lib_dir.display());
+ println!("cargo:rustc-link-lib=dylib=libwg");
}
#[cfg(not(windows))]
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index d04e5b8b00..beaaba2ee9 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -159,7 +159,7 @@ type ErrorSink = extern "system" fn(
);
#[allow(non_snake_case)]
-extern "system" {
+extern "stdcall" {
#[link_name = "WinDns_Initialize"]
pub fn WinDns_Initialize(
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index 67b37713d2..bee16fee3a 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -85,12 +85,17 @@ impl FirewallT for Firewall {
match policy {
FirewallPolicy::Connecting {
peer_endpoint,
- // TODO: Allow ICMP traffic to a list of hosts for wireguard
- pingable_hosts: _,
+ pingable_hosts,
allow_lan,
} => {
let cfg = &WinFwSettings::new(allow_lan);
- self.set_connecting_state(&peer_endpoint, &cfg)
+ // TODO: Determine interface alias at runtime
+ self.set_connecting_state(
+ &peer_endpoint,
+ &cfg,
+ "wg-mullvad".to_string(),
+ &pingable_hosts,
+ )
}
FirewallPolicy::Connected {
peer_endpoint,
@@ -128,6 +133,8 @@ impl Firewall {
&mut self,
endpoint: &Endpoint,
winfw_settings: &WinFwSettings,
+ _tunnel_iface_alias: String,
+ pingable_hosts: &Vec<IpAddr>,
) -> Result<(), Error> {
trace!("Applying 'connecting' firewall policy");
let ip_str = Self::widestring_ip(endpoint.address.ip());
@@ -139,7 +146,31 @@ impl Firewall {
protocol: WinFwProt::from(endpoint.protocol),
};
- unsafe { WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay).into_result() }
+ if pingable_hosts.is_empty() {
+ unsafe {
+ return WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, ptr::null())
+ .into_result();
+ }
+ }
+
+ let pingable_addresses = pingable_hosts
+ .iter()
+ .map(|ip| Self::widestring_ip(*ip))
+ .collect::<Vec<_>>();
+ let pingable_address_ptrs = pingable_addresses
+ .iter()
+ .map(|ip| ip.as_ptr())
+ .collect::<Vec<_>>();
+
+ let pingable_hosts = WinFwPingableHosts {
+ interfaceAlias: ptr::null(),
+ addresses: pingable_address_ptrs.as_ptr(),
+ num_addresses: pingable_addresses.len(),
+ };
+
+ unsafe {
+ WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay, &pingable_hosts).into_result()
+ }
}
fn widestring_ip(ip: IpAddr) -> WideCString {
@@ -250,6 +281,14 @@ mod winfw {
}
}
+ #[repr(C)]
+ pub struct WinFwPingableHosts {
+ // a null pointer implies that all interfaces will be able to ping the supplied addresses
+ pub interfaceAlias: *const libc::wchar_t,
+ pub addresses: *const *const libc::wchar_t,
+ pub num_addresses: usize,
+ }
+
ffi_error!(InitializationResult, Error::Initialization);
ffi_error!(DeinitializationResult, Error::Deinitialization);
ffi_error!(ApplyConnectingResult, Error::ApplyingConnectingPolicy);
@@ -280,6 +319,7 @@ mod winfw {
pub fn WinFw_ApplyPolicyConnecting(
settings: &WinFwSettings,
relay: &WinFwRelay,
+ pingable_hosts: *const WinFwPingableHosts,
) -> ApplyConnectingResult;
#[link_name = "WinFw_ApplyPolicyConnected"]
diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs
index 5fdbd3030d..cd88277d9d 100644
--- a/talpid-core/src/lib.rs
+++ b/talpid-core/src/lib.rs
@@ -23,7 +23,6 @@ mod winnet;
#[cfg(any(target_os = "linux", target_os = "macos"))]
/// Working with IP interface devices
pub mod network_interface;
-#[cfg(not(windows))]
/// Abstraction over operating system routing table.
pub mod routing;
diff --git a/talpid-core/src/ping_monitor/win.rs b/talpid-core/src/ping_monitor/win.rs
index f4540dddbd..40fa523584 100644
--- a/talpid-core/src/ping_monitor/win.rs
+++ b/talpid-core/src/ping_monitor/win.rs
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-// TODO: Remove lint exemption once ping monitor is used on Windows
use pnet_packet::{
icmp::{
self,
@@ -18,6 +16,8 @@ use std::{
time::{Duration, Instant},
};
+const SEND_RETRY_ATTEMPTS: u32 = 10;
+
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
@@ -40,10 +40,10 @@ pub enum Error {
pub fn monitor_ping(
ip: Ipv4Addr,
timeout_secs: u16,
- _interface: &str,
+ interface: &str,
close_receiver: mpsc::Receiver<()>,
) -> Result<()> {
- let mut pinger = Pinger::new(ip)?;
+ let mut pinger = Pinger::new(ip, interface)?;
while let Err(mpsc::TryRecvError::Empty) = close_receiver.try_recv() {
let start = Instant::now();
pinger.send_ping(Duration::from_secs(timeout_secs.into()))?;
@@ -57,8 +57,8 @@ pub fn monitor_ping(
Ok(())
}
-pub fn ping(ip: Ipv4Addr, timeout_secs: u16, _interface: &str) -> Result<()> {
- Pinger::new(ip)?.send_ping(Duration::from_secs(timeout_secs.into()))
+pub fn ping(ip: Ipv4Addr, timeout_secs: u16, interface: &str) -> Result<()> {
+ Pinger::new(ip, interface)?.send_ping(Duration::from_secs(timeout_secs.into()))
}
type Result<T> = std::result::Result<T, Error>;
@@ -70,11 +70,15 @@ pub struct Pinger {
seq: u16,
}
+const NUM_PINGS_TO_SEND: usize = 3;
+
impl Pinger {
- pub fn new(addr: Ipv4Addr) -> Result<Self> {
+ pub fn new(addr: Ipv4Addr, _interface_name: &str) -> Result<Self> {
let sock = Socket::new(Domain::ipv4(), Type::raw(), Some(Protocol::icmpv4()))
.map_err(Error::OpenError)?;
sock.set_nonblocking(true).map_err(Error::OpenError)?;
+
+
Ok(Self {
sock,
id: rand::random(),
@@ -86,19 +90,49 @@ impl Pinger {
/// Sends an ICMP echo request
pub fn send_ping(&mut self, timeout: Duration) -> Result<()> {
let dest = SocketAddr::new(IpAddr::from(self.addr), 0);
- let request = self.next_ping_request();
- self.sock
- .send_to(request.packet(), &dest.into())
- .map_err(Error::WriteError)?;
- self.wait_for_response(Instant::now() + timeout, &request)
+ let requests = (0..NUM_PINGS_TO_SEND)
+ .map(|_| {
+ let request = self.next_ping_request();
+ self.send_ping_request(&request, dest)?;
+ Ok(request)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ self.wait_for_response(Instant::now() + timeout, &requests)
+ }
+
+ fn send_ping_request(
+ &mut self,
+ request: &EchoRequestPacket<'static>,
+ destination: SocketAddr,
+ ) -> Result<()> {
+ let mut tries = 0;
+ let mut result = Ok(());
+ while tries < SEND_RETRY_ATTEMPTS {
+ match self.sock.send_to(request.packet(), &destination.into()) {
+ Ok(_) => {
+ return Ok(());
+ }
+ Err(err) => {
+ if Some(10065) != err.raw_os_error() {
+ return Err(Error::WriteError(err));
+ }
+ result = Err(Error::WriteError(err));
+ }
+ }
+ thread::sleep(Duration::from_secs(1));
+ tries += 1;
+ }
+ result
}
/// returns the next ping packet
fn next_ping_request(&mut self) -> EchoRequestPacket<'static> {
+ use rand::Rng;
const ICMP_HEADER_LENGTH: usize = 8;
- const ICMP_PAYLOAD_LENGTH: usize = 24;
+ const ICMP_PAYLOAD_LENGTH: usize = 150;
const ICMP_PACKET_LENGTH: usize = ICMP_HEADER_LENGTH + ICMP_PAYLOAD_LENGTH;
- let payload: [u8; ICMP_PAYLOAD_LENGTH] = rand::random();
+ let mut payload = [0u8; ICMP_PAYLOAD_LENGTH];
+ rand::thread_rng().fill(&mut payload[..]);
let mut packet = MutableEchoRequestPacket::owned(vec![0u8; ICMP_PACKET_LENGTH])
.expect("Failed to construct an empty packet");
packet.set_icmp_type(IcmpType::new(8));
@@ -117,24 +151,39 @@ impl Pinger {
}
- fn wait_for_response(&mut self, deadline: Instant, req: &EchoRequestPacket<'_>) -> Result<()> {
+ fn wait_for_response(
+ &mut self,
+ deadline: Instant,
+ requests: &[EchoRequestPacket<'_>],
+ ) -> Result<()> {
let mut recv_buffer = [0u8; 4096];
- while Instant::now() < deadline {
+ let mut bytes_received = 0;
+ let mut success = false;
+ let mut requests = requests.iter().map(|req| (false, req)).collect::<Vec<_>>();
+ 'outer: while Instant::now() < deadline {
match self.sock.recv(&mut recv_buffer) {
Ok(recv_len) => {
+ bytes_received += recv_len;
if recv_len > 20 {
// have to slice off first 20 bytes for the IP header.
if let Some(reply) = Self::parse_response(&recv_buffer[20..recv_len]) {
- if reply.get_identifier() == req.get_identifier()
- && reply.get_sequence_number() == req.get_sequence_number()
- && req.payload() == reply.payload()
- {
- return Ok(());
+ for (used, req) in requests.iter_mut() {
+ if *used {
+ continue;
+ }
+ if Self::request_and_response_match(req, &reply) {
+ *used = true;
+ success = true;
+ continue 'outer;
+ }
}
}
}
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ if success {
+ return Ok(());
+ }
std::thread::sleep(Duration::from_millis(100));
continue;
}
@@ -143,9 +192,44 @@ impl Pinger {
}
}
}
+ log::debug!(
+ "Timing out whilst waiting for ICMP response after receiving {} bytes",
+ bytes_received
+ );
Err(Error::TimeoutError)
}
+ fn request_and_response_match(req: &EchoRequestPacket<'_>, resp: &EchoReplyPacket<'_>) -> bool {
+ if req.get_identifier() != resp.get_identifier() {
+ log::debug!(
+ "Expected idnetifier {} - got {}",
+ req.get_identifier(),
+ resp.get_identifier()
+ );
+ return false;
+ }
+
+ if req.get_sequence_number() != resp.get_sequence_number() {
+ log::debug!(
+ "Expected sequence number {} - got {}",
+ req.get_sequence_number(),
+ resp.get_sequence_number()
+ );
+ return false;
+ }
+
+ if req.payload() != resp.payload() {
+ log::debug!(
+ "Expected payload {:?} - got {:?}",
+ req.payload(),
+ resp.payload()
+ );
+ return false;
+ }
+
+ return true;
+ }
+
fn parse_response<'a>(buffer: &'a [u8]) -> Option<EchoReplyPacket<'a>> {
let icmp_checksum = icmp::checksum(&IcmpPacket::new(buffer)?);
let reply = EchoReplyPacket::new(buffer)?;
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs
index 6f5b73e014..4636b0b27d 100644
--- a/talpid-core/src/routing/mod.rs
+++ b/talpid-core/src/routing/mod.rs
@@ -1,4 +1,5 @@
#![cfg_attr(target_os = "android", allow(dead_code))]
+#![cfg_attr(target_os = "windows", allow(dead_code))]
// TODO: remove the allow(dead_code) for android once it's up to scratch.
use futures::{sync::oneshot, Future};
use ipnetwork::IpNetwork;
@@ -16,6 +17,12 @@ mod imp;
#[path = "android.rs"]
mod imp;
+#[cfg(target_os = "windows")]
+#[path = "windows.rs"]
+mod imp;
+#[cfg(target_os = "windows")]
+use crate::winnet;
+
pub use imp::Error as PlatformError;
/// Errors that can be encountered whilst initializing RouteManager
@@ -37,6 +44,8 @@ pub enum Error {
/// the route will be adjusted dynamically when the default route changes.
pub struct RouteManager {
tx: Option<oneshot::Sender<oneshot::Sender<()>>>,
+ #[cfg(target_os = "windows")]
+ callback_handles: Vec<winnet::WinNetCallbackHandle>,
}
impl RouteManager {
@@ -61,12 +70,34 @@ impl RouteManager {
},
);
match start_rx.wait() {
- Ok(Ok(())) => Ok(Self { tx: Some(tx) }),
+ Ok(Ok(())) => Ok(Self {
+ tx: Some(tx),
+ #[cfg(target_os = "windows")]
+ callback_handles: vec![],
+ }),
Ok(Err(e)) => Err(e),
Err(_) => Err(Error::RoutingManagerThreadPanic),
}
}
+ /// Sets a callback that is called whenever the default route changes.
+ #[cfg(target_os = "windows")]
+ pub fn set_default_route_callback<T: 'static>(
+ &mut self,
+ callback: Option<winnet::DefaultRouteChangedCallback>,
+ context: T,
+ ) {
+ match winnet::set_default_route_change_callback(callback, context) {
+ Err(_e) => {
+ // not sure if this should panic
+ log::error!("Failed to add callback!");
+ }
+ Ok(handle) => {
+ self.callback_handles.push(handle);
+ }
+ }
+ }
+
/// Stops RouteManager and removes all of the applied routes.
pub fn stop(&mut self) {
if let Some(tx) = self.tx.take() {
@@ -85,6 +116,11 @@ impl RouteManager {
impl Drop for RouteManager {
fn drop(&mut self) {
+ // Ensuring callbacks are removed before the route manager is stopped
+ #[cfg(target_os = "windows")]
+ {
+ self.callback_handles.clear();
+ }
self.stop();
}
}
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
new file mode 100644
index 0000000000..684d1a3184
--- /dev/null
+++ b/talpid-core/src/routing/windows.rs
@@ -0,0 +1,72 @@
+use super::NetNode;
+use crate::winnet;
+use futures::{sync::oneshot, Async, Future};
+use ipnetwork::IpNetwork;
+use std::collections::HashMap;
+
+/// Windows routing errors.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// Failure to apply a route
+ #[error(display = "Failed to start route manager")]
+ FailedToStartManager,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+pub struct RouteManagerImpl {
+ shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+}
+
+impl RouteManagerImpl {
+ pub fn new(
+ required_routes: HashMap<IpNetwork, NetNode>,
+ shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+ ) -> Result<Self> {
+ let routes: Vec<_> = required_routes
+ .iter()
+ .map(|(destination, node)| {
+ let destination = winnet::WinNetIpNetwork::from(*destination);
+ match node {
+ NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
+ NetNode::RealNode(node) => {
+ winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
+ }
+ }
+ })
+ .collect();
+
+ if !winnet::activate_routing_manager(&routes) {
+ return Err(Error::FailedToStartManager);
+ }
+
+
+ Ok(Self { shutdown_rx })
+ }
+}
+
+impl Drop for RouteManagerImpl {
+ fn drop(&mut self) {
+ if !winnet::deactivate_routing_manager() {
+ log::error!("Failed to deactivate routing manager");
+ }
+ }
+}
+
+
+impl Future for RouteManagerImpl {
+ type Item = ();
+ type Error = Error;
+ fn poll(&mut self) -> Result<Async<()>> {
+ match self.shutdown_rx.poll() {
+ Ok(Async::Ready(result_tx)) => {
+ if let Err(_e) = result_tx.send(()) {
+ log::error!("Receiver already down");
+ }
+ Ok(Async::Ready(()))
+ }
+ Ok(Async::NotReady) => Ok(Async::NotReady),
+ Err(_) => Ok(Async::Ready(())),
+ }
+ }
+}
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 6eaa77f8b8..ddf19fcb66 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -9,15 +9,12 @@ use std::{
};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
-#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
-use talpid_types::net::wireguard as wireguard_types;
-use talpid_types::net::{GenericTunnelOptions, TunnelParameters};
+use talpid_types::net::{wireguard as wireguard_types, GenericTunnelOptions, TunnelParameters};
/// A module for all OpenVPN related tunnel management.
#[cfg(not(target_os = "android"))]
pub mod openvpn;
-#[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
pub mod wireguard;
/// A module for low level platform specific tunnel device management.
@@ -45,7 +42,6 @@ pub enum Error {
RotateLogError(#[error(source)] crate::logging::RotateLogError),
/// Failure to build Wireguard configuration.
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
#[error(display = "Failed to configure Wireguard with the given parameters")]
WireguardConfigError(#[error(source)] self::wireguard::config::Error),
@@ -55,7 +51,6 @@ pub enum Error {
OpenVpnTunnelMonitoringError(#[error(source)] openvpn::Error),
/// There was an error listening for events from the Wireguard tunnel
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
#[error(display = "Failed while listening for events from the Wireguard tunnel")]
WireguardTunnelMonitoringError(#[error(source)] wireguard::Error),
}
@@ -161,16 +156,12 @@ impl TunnelMonitor {
#[cfg(target_os = "android")]
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
TunnelParameters::Wireguard(config) => {
Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider)
}
- #[cfg(windows)]
- TunnelParameters::Wireguard(_) => Err(Error::UnsupportedPlatform),
}
}
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
fn start_wireguard_tunnel<L>(
params: &wireguard_types::TunnelParameters,
log: Option<PathBuf>,
@@ -216,6 +207,7 @@ impl TunnelMonitor {
}
}
+ #[cfg(not(target_os = "windows"))]
fn prepare_tunnel_log_file(
parameters: &TunnelParameters,
log_dir: &Option<PathBuf>,
@@ -234,6 +226,23 @@ impl TunnelMonitor {
}
}
+ #[cfg(target_os = "windows")]
+ fn prepare_tunnel_log_file(
+ parameters: &TunnelParameters,
+ log_dir: &Option<PathBuf>,
+ ) -> Result<Option<PathBuf>> {
+ if let Some(ref log_dir) = log_dir {
+ let filename = match parameters {
+ TunnelParameters::OpenVpn(_) => OPENVPN_LOG_FILENAME,
+ TunnelParameters::Wireguard(_) => WIREGUARD_LOG_FILENAME,
+ };
+ let tunnel_log = log_dir.join(filename);
+ logging::rotate_log(&tunnel_log)?;
+ Ok(Some(tunnel_log))
+ } else {
+ Ok(None)
+ }
+ }
/// Creates a handle to this monitor, allowing the tunnel to be closed while some other
/// thread
@@ -254,7 +263,6 @@ pub enum CloseHandle {
#[cfg(not(target_os = "android"))]
/// OpenVpn close handle
OpenVpn(openvpn::OpenVpnCloseHandle),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
/// Wireguard close handle
Wireguard(wireguard::CloseHandle),
}
@@ -265,7 +273,6 @@ impl CloseHandle {
match self {
#[cfg(not(target_os = "android"))]
CloseHandle::OpenVpn(handle) => handle.close(),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
CloseHandle::Wireguard(mut handle) => {
handle.close();
Ok(())
@@ -277,7 +284,6 @@ impl CloseHandle {
enum InternalTunnelMonitor {
#[cfg(not(target_os = "android"))]
OpenVpn(openvpn::OpenVpnMonitor),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
Wireguard(wireguard::WireguardMonitor),
}
@@ -286,7 +292,6 @@ impl InternalTunnelMonitor {
match self {
#[cfg(not(target_os = "android"))]
InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()),
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()),
}
}
@@ -295,7 +300,6 @@ impl InternalTunnelMonitor {
match self {
#[cfg(not(target_os = "android"))]
InternalTunnelMonitor::OpenVpn(tun) => tun.wait()?,
- #[cfg(any(target_os = "android", target_os = "linux", target_os = "macos"))]
InternalTunnelMonitor::Wireguard(tun) => tun.wait()?,
}
diff --git a/talpid-core/src/tunnel/tun_provider/mod.rs b/talpid-core/src/tunnel/tun_provider/mod.rs
index c6701ceac9..f0bf8f69b6 100644
--- a/talpid-core/src/tunnel/tun_provider/mod.rs
+++ b/talpid-core/src/tunnel/tun_provider/mod.rs
@@ -29,6 +29,10 @@ cfg_if! {
}
}
+/// Windows tunnel
+#[cfg(target_os = "windows")]
+pub mod windows;
+
/// Generic tunnel device.
///
/// Must be associated with a platform specific file descriptor representing the device.
diff --git a/talpid-core/src/tunnel/tun_provider/windows.rs b/talpid-core/src/tunnel/tun_provider/windows.rs
new file mode 100644
index 0000000000..9a114bf4b7
--- /dev/null
+++ b/talpid-core/src/tunnel/tun_provider/windows.rs
@@ -0,0 +1,13 @@
+use super::Tun;
+
+/// Windows tunnel implementation
+pub struct WinTun {
+ /// Name of tunnel interface
+ pub interface_name: String,
+}
+
+impl Tun for WinTun {
+ fn interface_name(&self) -> &str {
+ &self.interface_name
+ }
+}
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index c9b6988f7b..1f7f11052f 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -43,6 +43,11 @@ pub enum Error {
#[error(display = "Failed to stop wireguard tunnel - {}", status)]
StopWireguardError { status: i32 },
+ /// Failed to set ip addresses on tunnel interface.
+ #[cfg(target_os = "windows")]
+ #[error(display = "Failed to set IP addresses on WireGuard interface")]
+ SetIpAddressesError,
+
/// Failed to set up routing.
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] crate::routing::Error),
@@ -97,8 +102,13 @@ impl WireguardMonitor {
Self::get_tunnel_routes(config),
)?);
let iface_name = tunnel.get_interface_name();
- let route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config))
+ let mut route_handle = routing::RouteManager::new(Self::get_routes(iface_name, &config))
.map_err(Error::SetupRoutingError)?;
+
+ #[cfg(target_os = "windows")]
+ route_handle
+ .set_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
+
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
let (pinger_tx, pinger_rx) = mpsc::channel();
@@ -121,12 +131,10 @@ impl WireguardMonitor {
Ok(()) => {
(on_event)(TunnelEvent::Up(metadata));
- match ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx)
+ if let Err(error) =
+ ping_monitor::monitor_ping(gateway, PING_TIMEOUT, &iface_name, pinger_rx)
{
- Ok(()) => return,
- Err(error) => {
- log::trace!("{}", error.display_chain_with_msg("Ping monitor failed"));
- }
+ log::trace!("{}", error.display_chain_with_msg("Ping monitor failed"));
}
}
Err(error) => {
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 4e3b9c45fd..50df570d01 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -1,21 +1,56 @@
use super::{Config, Error, Result, Tunnel};
-use crate::tunnel::tun_provider::{Tun, TunConfig, TunProvider};
+use crate::tunnel::tun_provider::{Tun, TunProvider};
use ipnetwork::IpNetwork;
-use std::{ffi::CString, net::IpAddr, os::unix::io::RawFd, path::Path, ptr};
+use std::{ffi::CString, path::Path};
+
+#[cfg(not(target_os = "windows"))]
+use {
+ crate::tunnel::tun_provider::TunConfig,
+ std::{net::IpAddr, os::unix::io::RawFd, ptr},
+};
+
+
+#[cfg(target_os = "windows")]
+use crate::{
+ tunnel::tun_provider::windows::WinTun,
+ winnet::{self, add_device_ip_addresses},
+};
+
#[cfg(target_os = "android")]
use talpid_types::BoxedError;
+#[cfg(not(target_os = "windows"))]
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
+#[cfg(target_os = "windows")]
+use {
+ chrono,
+ parking_lot::Mutex,
+ std::{collections::HashMap, fs, io::Write},
+};
+
+
pub struct WgGoTunnel {
interface_name: String,
handle: Option<i32>,
// holding on to the tunnel device and the log file ensures that the associated file handles
// live long enough and get closed when the tunnel is stopped
_tunnel_device: Box<dyn Tun>,
+ // ordinal that maps to fs::File instance, used with logging callback
+ #[cfg(target_os = "windows")]
+ log_context_ordinal: u32,
+}
+
+#[cfg(target_os = "windows")]
+lazy_static::lazy_static! {
+ static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new());
}
+#[cfg(target_os = "windows")]
+static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0;
+
impl WgGoTunnel {
+ #[cfg(not(target_os = "windows"))]
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
@@ -66,6 +101,129 @@ impl WgGoTunnel {
})
}
+ #[cfg(target_os = "windows")]
+ pub fn start_tunnel(
+ config: &Config,
+ log_path: Option<&Path>,
+ _tun_provider: &dyn TunProvider,
+ _routes: impl Iterator<Item = IpNetwork>,
+ ) -> Result<Self> {
+ let log_file = prepare_log_file(log_path)?;
+
+ let log_context_ordinal = unsafe {
+ let mut map = LOG_MUTEX.lock();
+ let ordinal = LOG_CONTEXT_NEXT_ORDINAL;
+ LOG_CONTEXT_NEXT_ORDINAL += 1;
+ map.insert(ordinal, log_file);
+ ordinal
+ };
+
+ let wg_config_str = config.to_userspace_format();
+ let iface_name: String = "wg-mullvad".to_string();
+ let cstr_iface_name =
+ CString::new(iface_name.as_bytes()).map_err(Error::InterfaceNameError)?;
+
+ let handle = unsafe {
+ wgTurnOn(
+ cstr_iface_name.as_ptr(),
+ config.mtu as i64,
+ wg_config_str.as_ptr(),
+ Some(Self::logging_callback),
+ log_context_ordinal as *mut libc::c_void,
+ )
+ };
+
+ if handle < 0 {
+ clean_up_log_file(log_context_ordinal);
+ return Err(Error::FatalStartWireguardError);
+ }
+
+ if !add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
+ // Todo: what kind of clean-up is required?
+ clean_up_log_file(log_context_ordinal);
+ return Err(Error::SetIpAddressesError);
+ }
+
+ Ok(WgGoTunnel {
+ interface_name: iface_name.clone(),
+ handle: Some(handle),
+ _tunnel_device: Box::new(WinTun {
+ interface_name: iface_name.clone(),
+ }),
+ log_context_ordinal,
+ })
+ }
+
+ // Callback to be used to rebind the tunnel sockets when the default route changes
+ #[cfg(target_os = "windows")]
+ pub unsafe extern "system" fn default_route_changed_callback(
+ event_type: winnet::WinNetDefaultRouteChangeEventType,
+ address_family: winnet::WinNetIpFamily,
+ interface_luid: u64,
+ _ctx: *mut libc::c_void,
+ ) {
+ use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex};
+ let iface_idx: u32 = match event_type {
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ let mut iface_idx = 0u32;
+ let iface_luid = NET_LUID {
+ Value: interface_luid,
+ };
+ let status =
+ ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _);
+ if status != 0 {
+ log::error!(
+ "Failed to convert interface LUID to interface index - {} - {}",
+ status,
+ std::io::Error::last_os_error()
+ );
+ return;
+ }
+ iface_idx
+ }
+ // if there is no new default route, specify 0 as the interface index
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => 0,
+ };
+
+ wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx);
+ }
+
+ // Callback that receives messages from WireGuard
+ #[cfg(target_os = "windows")]
+ pub unsafe extern "system" fn logging_callback(
+ level: WgLogLevel,
+ msg: *const libc::c_char,
+ context: *mut libc::c_void,
+ ) {
+ let map = LOG_MUTEX.lock();
+ if let Some(mut logfile) = map.get(&(context as u32)) {
+ let managed_msg = if !msg.is_null() {
+ std::ffi::CStr::from_ptr(msg)
+ .to_string_lossy()
+ .to_string()
+ .replace("\n", "\r\n")
+ } else {
+ "Logging message from WireGuard is NULL".to_string()
+ };
+
+ let level_str = match level {
+ WG_GO_LOG_DEBUG => "DEBUG",
+ WG_GO_LOG_INFO => "INFO",
+ WG_GO_LOG_ERROR | _ => "ERROR",
+ };
+
+ let _ = write!(
+ logfile,
+ "{}[{}][{}] {}",
+ chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"),
+ "wireguard-go",
+ level_str,
+ managed_msg
+ );
+ }
+ }
+
+ #[cfg(not(target_os = "windows"))]
fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig {
let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)];
dns_servers.extend(config.ipv6_gateway.map(IpAddr::V6));
@@ -102,6 +260,7 @@ impl WgGoTunnel {
Ok(())
}
+ #[cfg(not(target_os = "windows"))]
fn get_tunnel(
tun_provider: &mut dyn TunProvider,
config: &Config,
@@ -130,14 +289,30 @@ impl WgGoTunnel {
}
}
+#[cfg(target_os = "windows")]
+fn clean_up_log_file(ordinal: u32) {
+ let mut map = LOG_MUTEX.lock();
+ map.remove(&ordinal);
+}
+
impl Drop for WgGoTunnel {
fn drop(&mut self) {
if let Err(e) = self.stop_tunnel() {
log::error!("Failed to stop tunnel - {}", e);
}
+ #[cfg(target_os = "windows")]
+ clean_up_log_file(self.log_context_ordinal);
}
}
+#[cfg(target_os = "windows")]
+static NULL_DEVICE: &str = "NUL";
+
+#[cfg(target_os = "windows")]
+fn prepare_log_file(log_path: Option<&Path>) -> Result<fs::File> {
+ fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError)
+}
+
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> &str {
&self.interface_name
@@ -154,12 +329,22 @@ pub type Fd = std::os::unix::io::RawFd;
#[cfg(windows)]
pub type Fd = std::os::windows::io::RawHandle;
-type WgLogLevel = i32;
+type WgLogLevel = u32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
+// const WG_GO_LOG_SILENT: WgLogLevel = 0;
+#[cfg(target_os = "windows")]
+const WG_GO_LOG_ERROR: WgLogLevel = 1;
+#[cfg(target_os = "windows")]
+const WG_GO_LOG_INFO: WgLogLevel = 2;
const WG_GO_LOG_DEBUG: WgLogLevel = 3;
-#[cfg_attr(target_os = "android", link(name = "wg", kind = "dylib"))]
-#[cfg_attr(not(target_os = "android"), link(name = "wg", kind = "static"))]
+#[cfg(target_os = "windows")]
+pub type LoggingCallback = unsafe extern "system" fn(
+ level: WgLogLevel,
+ msg: *const libc::c_char,
+ context: *mut libc::c_void,
+);
+
extern "C" {
// Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors
// for the tunnel device and logging.
@@ -167,6 +352,7 @@ extern "C" {
// Positive return values are tunnel handles for this specific wireguard tunnel instance.
// Negative return values signify errors. All error codes are opaque.
#[cfg_attr(target_os = "android", link_name = "wgTurnOnWithFdAndroid")]
+ #[cfg(not(target_os = "windows"))]
fn wgTurnOnWithFd(
iface_name: *const i8,
mtu: isize,
@@ -176,6 +362,16 @@ extern "C" {
logLevel: WgLogLevel,
) -> i32;
+ // Windows
+ #[cfg(target_os = "windows")]
+ fn wgTurnOn(
+ iface_name: *const i8,
+ mtu: i64,
+ settings: *const i8,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: *mut libc::c_void,
+ ) -> i32;
+
// Pass a handle that was created by wgTurnOnWithFd to stop a wireguard tunnel.
fn wgTurnOff(handle: i32) -> i32;
@@ -186,4 +382,8 @@ extern "C" {
// Returns the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
fn wgGetSocketV6(handle: i32) -> Fd;
+
+ // Rebind tunnel socket when network interfaces change
+ #[cfg(target_os = "windows")]
+ fn wgRebindTunnelSocket(family: u16, interfaceIndex: u32);
}
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index c6bd6c6115..6b42f6d810 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -2,8 +2,14 @@ use self::api::*;
pub use self::api::{
LogSink, WinNet_ActivateConnectivityMonitor, WinNet_DeactivateConnectivityMonitor,
};
+use crate::routing::Node;
+use ipnetwork::IpNetwork;
use libc::{c_char, c_void, wchar_t};
-use std::{ffi::OsString, ptr};
+use std::{
+ ffi::{CStr, OsString},
+ net::IpAddr,
+ ptr,
+};
use widestring::WideCString;
/// Errors that this module may produce.
@@ -41,7 +47,6 @@ pub enum LogSeverity {
/// Logging callback used with `winnet.dll`.
pub extern "system" fn log_sink(severity: LogSeverity, msg: *const c_char, _ctx: *mut c_void) {
- use std::ffi::CStr;
if msg.is_null() {
log::error!("Log message from FFI boundary is NULL");
} else {
@@ -119,9 +124,264 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> {
Ok(alias.to_os_string())
}
+#[repr(C)]
+struct WinNetIpType(u32);
+
+const WINNET_IPV4: u32 = 0;
+const WINNET_IPV6: u32 = 1;
+
+impl WinNetIpType {
+ pub fn v4() -> Self {
+ WinNetIpType(WINNET_IPV4)
+ }
+
+ pub fn v6() -> Self {
+ WinNetIpType(WINNET_IPV6)
+ }
+}
+
+
+#[repr(C)]
+pub struct WinNetIpNetwork {
+ ip_type: WinNetIpType,
+ ip_bytes: [u8; 16],
+ prefix: u8,
+}
+
+impl From<IpNetwork> for WinNetIpNetwork {
+ fn from(network: IpNetwork) -> WinNetIpNetwork {
+ let WinNetIp { ip_type, ip_bytes } = WinNetIp::from(network.ip());
+ let prefix = network.prefix();
+ WinNetIpNetwork {
+ ip_type,
+ ip_bytes,
+ prefix,
+ }
+ }
+}
+
+#[repr(C)]
+pub struct WinNetIp {
+ ip_type: WinNetIpType,
+ ip_bytes: [u8; 16],
+}
+
+impl From<IpAddr> for WinNetIp {
+ fn from(addr: IpAddr) -> WinNetIp {
+ let mut bytes = [0u8; 16];
+ match addr {
+ IpAddr::V4(v4_addr) => {
+ bytes[..4].copy_from_slice(&v4_addr.octets());
+ WinNetIp {
+ ip_type: WinNetIpType::v4(),
+ ip_bytes: bytes,
+ }
+ }
+ IpAddr::V6(v6_addr) => {
+ bytes.copy_from_slice(&v6_addr.octets());
+
+ WinNetIp {
+ ip_type: WinNetIpType::v6(),
+ ip_bytes: bytes,
+ }
+ }
+ }
+ }
+}
+
+#[repr(C)]
+pub struct WinNetNode {
+ gateway: *mut WinNetIp,
+ device_name: *mut u16,
+}
+
+impl WinNetNode {
+ fn new(name: &str, ip: WinNetIp) -> Self {
+ let device_name = WideCString::from_str(name)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ let gateway = Box::into_raw(Box::new(ip));
+ Self {
+ gateway,
+ device_name,
+ }
+ }
+
+ fn from_gateway(ip: WinNetIp) -> Self {
+ let gateway = Box::into_raw(Box::new(ip));
+ Self {
+ gateway,
+ device_name: ptr::null_mut(),
+ }
+ }
+
+
+ fn from_device(name: &str) -> Self {
+ let device_name = WideCString::from_str(name)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ Self {
+ gateway: ptr::null_mut(),
+ device_name,
+ }
+ }
+}
+
+impl From<&Node> for WinNetNode {
+ fn from(node: &Node) -> Self {
+ match (node.get_address(), node.get_device()) {
+ (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()),
+ (None, Some(device)) => WinNetNode::from_device(device),
+ (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()),
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl Drop for WinNetNode {
+ fn drop(&mut self) {
+ if !self.gateway.is_null() {
+ unsafe {
+ let _ = Box::from_raw(self.gateway);
+ }
+ }
+ if !self.device_name.is_null() {
+ unsafe {
+ let _ = WideCString::from_ptr_str(self.device_name);
+ }
+ }
+ }
+}
+
+
+#[repr(C)]
+pub struct WinNetRoute {
+ gateway: WinNetIpNetwork,
+ node: *mut WinNetNode,
+}
+
+impl WinNetRoute {
+ pub fn through_default_node(gateway: WinNetIpNetwork) -> Self {
+ Self {
+ gateway,
+ node: ptr::null_mut(),
+ }
+ }
+
+ pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self {
+ let node = Box::into_raw(Box::new(node));
+ WinNetRoute { gateway, node }
+ }
+}
+
+impl Drop for WinNetRoute {
+ fn drop(&mut self) {
+ if !self.node.is_null() {
+ unsafe {
+ let _ = Box::from_raw(self.node);
+ }
+ self.node = ptr::null_mut();
+ }
+ }
+}
+
+pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool {
+ unsafe { WinNet_ActivateRouteManager(Some(log_sink), ptr::null_mut()) };
+ routing_manager_add_routes(routes)
+}
+
+pub struct WinNetCallbackHandle {
+ handle: *mut libc::c_void,
+ // allows us to keep the context pointer allive.
+ _context: Box<dyn std::any::Any>,
+}
+
+unsafe impl Send for WinNetCallbackHandle {}
+
+impl Drop for WinNetCallbackHandle {
+ fn drop(&mut self) {
+ unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) };
+ }
+}
+
+#[allow(dead_code)]
+#[repr(u16)]
+pub enum WinNetDefaultRouteChangeEventType {
+ DefaultRouteChanged = 0,
+ DefaultRouteRemoved = 1,
+}
+
+#[allow(dead_code)]
+#[repr(u16)]
+pub enum WinNetIpFamily {
+ V4 = 0,
+ V6 = 1,
+}
+
+impl WinNetIpFamily {
+ pub fn to_windows_proto_enum(&self) -> u16 {
+ match self {
+ Self::V4 => 2,
+ Self::V6 => 23,
+ }
+ }
+}
+
+pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
+ event_type: WinNetDefaultRouteChangeEventType,
+ ip_family: WinNetIpFamily,
+ interface_luid: u64,
+ ctx: *mut c_void,
+);
+
+#[derive(err_derive::Error, Debug)]
+#[error(display = "Failed to set callback for default route")]
+pub struct DefaultRouteCallbackError;
+
+pub fn set_default_route_change_callback<T: 'static>(
+ callback: Option<DefaultRouteChangedCallback>,
+ context: T,
+) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> {
+ let mut handle_ptr = ptr::null_mut();
+ let mut context = Box::new(context);
+ let ctx_ptr = &mut *context as *mut T as *mut libc::c_void;
+ unsafe {
+ if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _)
+ {
+ return Err(DefaultRouteCallbackError);
+ }
+
+
+ Ok(WinNetCallbackHandle {
+ handle: handle_ptr,
+ _context: context,
+ })
+ }
+}
+
+pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool {
+ let ptr = routes.as_ptr();
+ let length: u32 = routes.len() as u32;
+ unsafe { WinNet_AddRoutes(ptr, length) }
+}
+
+pub fn deactivate_routing_manager() -> bool {
+ unsafe { WinNet_DeactivateRouteManager() }
+}
+
+pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
+ let raw_iface = WideCString::from_str(iface)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string")
+ .into_raw();
+ let converted_addresses: Vec<_> = addresses.iter().map(|addr| WinNetIp::from(*addr)).collect();
+ let ptr = converted_addresses.as_ptr();
+ let length: u32 = converted_addresses.len() as u32;
+ unsafe { WinNet_AddDeviceIpAddresses(raw_iface, ptr, length, Some(log_sink), ptr::null_mut()) }
+}
+
#[allow(non_snake_case)]
mod api {
- use super::LogSeverity;
+ use super::{DefaultRouteChangedCallback, LogSeverity};
use libc::{c_char, c_void, wchar_t};
/// logging callback type for use with `winnet.dll`.
@@ -131,6 +391,24 @@ mod api {
pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void);
extern "system" {
+ #[link_name = "WinNet_ActivateRouteManager"]
+ pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *mut c_void);
+
+ #[link_name = "WinNet_AddRoutes"]
+ pub fn WinNet_AddRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool;
+
+ // #[link_name = "WinNet_AddRoute"]
+ // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> bool;
+
+ // #[link_name = "WinNet_DeleteRoutes"]
+ // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool;
+
+ // #[link_name = "WinNet_DeleteRoute"]
+ // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool;
+
+ #[link_name = "WinNet_DeactivateRouteManager"]
+ pub fn WinNet_DeactivateRouteManager() -> bool;
+
#[link_name = "WinNet_EnsureTopMetric"]
pub fn WinNet_EnsureTopMetric(
tunnel_interface_alias: *const wchar_t,
@@ -162,7 +440,26 @@ mod api {
sink_context: *mut c_void,
) -> bool;
+ #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"]
+ pub fn WinNet_RegisterDefaultRouteChangedCallback(
+ callback: Option<DefaultRouteChangedCallback>,
+ callbackContext: *mut libc::c_void,
+ registrationHandle: *mut *mut libc::c_void,
+ ) -> bool;
+
+ #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"]
+ pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void);
+
#[link_name = "WinNet_DeactivateConnectivityMonitor"]
pub fn WinNet_DeactivateConnectivityMonitor() -> bool;
+
+ #[link_name = "WinNet_AddDeviceIpAddresses"]
+ pub fn WinNet_AddDeviceIpAddresses(
+ interface_alias: *const wchar_t,
+ addresses: *const super::WinNetIp,
+ num_addresses: u32,
+ sink: Option<LogSink>,
+ sink_context: *mut c_void,
+ ) -> bool;
}
}
diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp
index e5325f0b9c..49f8793572 100644
--- a/windows/winfw/src/winfw/fwcontext.cpp
+++ b/windows/winfw/src/winfw/fwcontext.cpp
@@ -13,10 +13,10 @@
#include "rules/permitvpnrelay.h"
#include "rules/permitvpntunnel.h"
#include "rules/permitvpntunnelservice.h"
+#include "rules/permitping.h"
#include "rules/restrictdns.h"
#include "libwfp/transaction.h"
#include "libwfp/filterengine.h"
-#include "libwfp/ipaddress.h"
#include <functional>
#include <stdexcept>
#include <utility>
@@ -99,7 +99,12 @@ FwContext::FwContext(uint32_t timeout, const WinFwSettings &settings)
m_baseline = checkpoint;
}
-bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay)
+bool FwContext::applyPolicyConnecting
+(
+ const WinFwSettings &settings,
+ const WinFwRelay &relay,
+ const std::optional<PingableHosts> &pingableHosts
+)
{
Ruleset ruleset;
@@ -112,6 +117,22 @@ bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFw
TranslateProtocol(relay.protocol)
));
+ //
+ // Permit pinging the gateway inside the tunnel.
+ //
+ if (pingableHosts.has_value())
+ {
+ const auto &ph = pingableHosts.value();
+
+ for (const auto &host : ph.hosts)
+ {
+ ruleset.emplace_back(std::make_unique<rules::PermitPing>(
+ ph.tunnelInterfaceAlias,
+ host
+ ));
+ }
+ }
+
return applyRuleset(ruleset);
}
diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h
index 89ef40e1d3..9d5b34c51b 100644
--- a/windows/winfw/src/winfw/fwcontext.h
+++ b/windows/winfw/src/winfw/fwcontext.h
@@ -3,9 +3,11 @@
#include "winfw.h"
#include "sessioncontroller.h"
#include "rules/ifirewallrule.h"
+#include "libwfp/ipaddress.h"
#include <cstdint>
#include <memory>
#include <vector>
+#include <optional>
class FwContext
{
@@ -16,7 +18,19 @@ public:
// This ctor applies the "blocked" policy.
FwContext(uint32_t timeout, const WinFwSettings &settings);
- bool applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay);
+ struct PingableHosts
+ {
+ std::optional<std::wstring> tunnelInterfaceAlias;
+ std::vector<wfp::IpAddress> hosts;
+ };
+
+ bool applyPolicyConnecting
+ (
+ const WinFwSettings &settings,
+ const WinFwRelay &relay,
+ const std::optional<PingableHosts> &pingableHosts
+ );
+
bool applyPolicyConnected
(
const WinFwSettings &settings,
diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp
index 010d41e44a..e73fac26ed 100644
--- a/windows/winfw/src/winfw/mullvadguids.cpp
+++ b/windows/winfw/src/winfw/mullvadguids.cpp
@@ -59,6 +59,8 @@ DetailedWfpObjectRegistry MullvadGuids::BuildDetailedRegistry()
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Outbound_Router_Solicitation()));
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Router_Advertisement()));
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Redirect()));
+ registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv4()));
+ registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv6()));
return registry;
}
@@ -567,3 +569,31 @@ const GUID &MullvadGuids::FilterPermitNdp_Inbound_Redirect()
return g;
}
+
+//static
+const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv4()
+{
+ static const GUID g =
+ {
+ 0x2ecf7ff7,
+ 0xc951,
+ 0x4056,
+ { 0xb0, 0xf7, 0x40, 0xa4, 0x5c, 0x7e, 0xb4, 0xc2 }
+ };
+
+ return g;
+}
+
+//static
+const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv6()
+{
+ static const GUID g =
+ {
+ 0x3deb8cab,
+ 0x1edb,
+ 0x4aa1,
+ { 0xb2, 0x73, 0xec, 0x61, 0x4f, 0x50, 0xdc, 0x13 }
+ };
+
+ return g;
+}
diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h
index d4fb470d90..3c3ca9702b 100644
--- a/windows/winfw/src/winfw/mullvadguids.h
+++ b/windows/winfw/src/winfw/mullvadguids.h
@@ -67,4 +67,7 @@ public:
static const GUID &FilterPermitNdp_Outbound_Router_Solicitation();
static const GUID &FilterPermitNdp_Inbound_Router_Advertisement();
static const GUID &FilterPermitNdp_Inbound_Redirect();
+
+ static const GUID &FilterPermitPing_Outbound_Icmpv4();
+ static const GUID &FilterPermitPing_Outbound_Icmpv6();
};
diff --git a/windows/winfw/src/winfw/rules/permitping.cpp b/windows/winfw/src/winfw/rules/permitping.cpp
new file mode 100644
index 0000000000..f6aed36bf2
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/permitping.cpp
@@ -0,0 +1,98 @@
+#include "stdafx.h"
+#include "permitping.h"
+#include "winfw/mullvadguids.h"
+#include "libwfp/filterbuilder.h"
+#include "libwfp/conditionbuilder.h"
+#include "libwfp/conditions/conditionip.h"
+#include "libwfp/conditions/conditioninterface.h"
+#include "libwfp/conditions/conditionprotocol.h"
+
+
+using namespace wfp::conditions;
+
+namespace rules
+{
+
+PermitPing::PermitPing
+(
+ const std::optional<std::wstring> &interfaceAlias,
+ const wfp::IpAddress &host
+)
+ : m_interfaceAlias(interfaceAlias)
+ , m_host(host)
+{
+}
+
+bool PermitPing::apply(IObjectInstaller &objectInstaller)
+{
+ if (wfp::IpAddress::Type::Ipv4 == m_host.type())
+ {
+ return applyIcmpv4(objectInstaller);
+ }
+
+ return applyIcmpv6(objectInstaller);
+}
+
+bool PermitPing::applyIcmpv4(IObjectInstaller &objectInstaller) const
+{
+ wfp::FilterBuilder filterBuilder;
+
+ //
+ // #1 Permit outbound ICMPv4 to %host% on %interface%
+ //
+
+ filterBuilder
+ .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv4())
+ .name(L"Permit outbound ICMP to specific host (ICMPv4)")
+ .description(L"This filter is part of a rule that permits ping")
+ .provider(MullvadGuids::Provider())
+ .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
+ .sublayer(MullvadGuids::SublayerWhitelist())
+ .weight(wfp::FilterBuilder::WeightClass::Max)
+ .permit();
+
+ wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
+
+ conditionBuilder.add_condition(ConditionIp::Remote(m_host));
+ conditionBuilder.add_condition(ConditionProtocol::Icmp());
+
+ if (m_interfaceAlias.has_value())
+ {
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value()));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+}
+
+bool PermitPing::applyIcmpv6(IObjectInstaller &objectInstaller) const
+{
+ wfp::FilterBuilder filterBuilder;
+
+ //
+ // #1 Permit outbound ICMPv6 to %host% on %interface%
+ //
+
+ filterBuilder
+ .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv6())
+ .name(L"Permit outbound ICMP to specific host (ICMPv6)")
+ .description(L"This filter is part of a rule that permits ping")
+ .provider(MullvadGuids::Provider())
+ .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
+ .sublayer(MullvadGuids::SublayerWhitelist())
+ .weight(wfp::FilterBuilder::WeightClass::Max)
+ .permit();
+
+ wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
+
+ conditionBuilder.add_condition(ConditionIp::Remote(m_host));
+ conditionBuilder.add_condition(ConditionProtocol::IcmpV6());
+
+ if (m_interfaceAlias.has_value())
+ {
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value()));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+}
+
+}
diff --git a/windows/winfw/src/winfw/rules/permitping.h b/windows/winfw/src/winfw/rules/permitping.h
new file mode 100644
index 0000000000..c8238ceaa8
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/permitping.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "ifirewallrule.h"
+#include <libwfp/ipaddress.h>
+#include <string>
+#include <optional>
+
+namespace rules
+{
+
+class PermitPing : public IFirewallRule
+{
+public:
+
+ PermitPing(const std::optional<std::wstring> &interfaceAlias, const wfp::IpAddress &host);
+
+ bool apply(IObjectInstaller &objectInstaller) override;
+
+private:
+
+ const std::optional<std::wstring> m_interfaceAlias;
+ const wfp::IpAddress m_host;
+
+ bool applyIcmpv4(IObjectInstaller &objectInstaller) const;
+ bool applyIcmpv6(IObjectInstaller &objectInstaller) const;
+};
+
+}
diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp
index 7b9ea2dc6b..3065408f3d 100644
--- a/windows/winfw/src/winfw/winfw.cpp
+++ b/windows/winfw/src/winfw/winfw.cpp
@@ -4,6 +4,7 @@
#include "objectpurger.h"
#include <windows.h>
#include <stdexcept>
+#include <optional>
namespace
{
@@ -15,6 +16,34 @@ void * g_errorContext = nullptr;
FwContext *g_fwContext = nullptr;
+std::optional<FwContext::PingableHosts> ConvertPingableHosts(const PingableHosts *pingableHosts)
+{
+ if (nullptr == pingableHosts)
+ {
+ return {};
+ }
+
+ if (nullptr == pingableHosts->hosts
+ || 0 == pingableHosts->numHosts)
+ {
+ throw std::runtime_error("Invalid PingableHosts structure");
+ }
+
+ FwContext::PingableHosts converted;
+
+ if (nullptr != pingableHosts->tunnelInterfaceAlias)
+ {
+ converted.tunnelInterfaceAlias = pingableHosts->tunnelInterfaceAlias;
+ }
+
+ for (size_t i = 0; i < pingableHosts->numHosts; ++i)
+ {
+ converted.hosts.emplace_back(wfp::IpAddress(pingableHosts->hosts[i]));
+ }
+
+ return converted;
+}
+
} // anonymous namespace
WINFW_LINKAGE
@@ -130,7 +159,8 @@ bool
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings &settings,
- const WinFwRelay &relay
+ const WinFwRelay &relay,
+ const PingableHosts *pingableHosts
)
{
if (nullptr == g_fwContext)
@@ -140,7 +170,7 @@ WinFw_ApplyPolicyConnecting(
try
{
- return g_fwContext->applyPolicyConnecting(settings, relay);
+ return g_fwContext->applyPolicyConnecting(settings, relay, ConvertPingableHosts(pingableHosts));
}
catch (std::exception &err)
{
diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h
index 95e66a608f..6d43b0db4c 100644
--- a/windows/winfw/src/winfw/winfw.h
+++ b/windows/winfw/src/winfw/winfw.h
@@ -105,11 +105,29 @@ WINFW_API
WinFw_Deinitialize();
//
+// PingableHosts:
+//
+// Specifies a set of IP addresses that should be reachable by ICMP when the connecting
+// policy is effective.
+//
+// The interface alias is optional and can be used to restrict the traffic such
+// that it is only allowed on that specific interface.
+//
+typedef struct tag_PingableHosts
+{
+ const wchar_t *tunnelInterfaceAlias;
+ const wchar_t **hosts;
+ size_t numHosts;
+}
+PingableHosts;
+
+//
// ApplyPolicyConnecting:
//
// Apply restrictions in the firewall that block all traffic, except:
// - What is specified by settings
// - Communication with the relay server
+// - ICMP (for ping) to/from tunnel gateway
//
extern "C"
WINFW_LINKAGE
@@ -117,7 +135,8 @@ bool
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings &settings,
- const WinFwRelay &relay
+ const WinFwRelay &relay,
+ const PingableHosts *pingableHosts
);
//
diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj
index 4777503f72..cbabe2f4f7 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj
+++ b/windows/winfw/src/winfw/winfw.vcxproj
@@ -30,6 +30,7 @@
<ClCompile Include="rules\permitlanservice.cpp" />
<ClCompile Include="rules\permitloopback.cpp" />
<ClCompile Include="rules\permitndp.cpp" />
+ <ClCompile Include="rules\permitping.cpp" />
<ClCompile Include="rules\permitvpntunnelservice.cpp" />
<ClCompile Include="rules\permitvpnrelay.cpp" />
<ClCompile Include="rules\permitvpntunnel.cpp" />
@@ -53,6 +54,7 @@
<ClInclude Include="objectpurger.h" />
<ClInclude Include="rules\permitdhcpserver.h" />
<ClInclude Include="rules\permitndp.h" />
+ <ClInclude Include="rules\permitping.h" />
<ClInclude Include="wfpobjecttype.h" />
<ClInclude Include="rules\blockall.h" />
<ClInclude Include="rules\ifirewallrule.h" />
diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters
index 0319b0214a..a758a1c9ec 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj.filters
+++ b/windows/winfw/src/winfw/winfw.vcxproj.filters
@@ -43,6 +43,9 @@
<ClCompile Include="rules\permitndp.cpp">
<Filter>rules</Filter>
</ClCompile>
+ <ClCompile Include="rules\permitping.cpp">
+ <Filter>rules</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -93,6 +96,9 @@
<ClInclude Include="rules\permitndp.h">
<Filter>rules</Filter>
</ClInclude>
+ <ClInclude Include="rules\permitping.h">
+ <Filter>rules</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="rules">
diff --git a/windows/winnet/src/extras/loader/loader.vcxproj.filters b/windows/winnet/src/extras/loader/loader.vcxproj.filters
index cd0f4643c7..408a9591b1 100644
--- a/windows/winnet/src/extras/loader/loader.vcxproj.filters
+++ b/windows/winnet/src/extras/loader/loader.vcxproj.filters
@@ -3,9 +3,13 @@
<ItemGroup>
<ClCompile Include="loader.cpp" />
<ClCompile Include="stdafx.cpp" />
+ <ClCompile Include="..\..\winnet\routemanager.cpp" />
+ <ClCompile Include="..\..\winnet\adapters.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
+ <ClInclude Include="..\..\winnet\routemanager.h" />
+ <ClInclude Include="..\..\winnet\adapters.h" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/windows/winnet/src/winnet/interfaceutils.cpp b/windows/winnet/src/winnet/interfaceutils.cpp
index babe03eba6..202d9d0724 100644
--- a/windows/winnet/src/winnet/interfaceutils.cpp
+++ b/windows/winnet/src/winnet/interfaceutils.cpp
@@ -2,13 +2,8 @@
#include "interfaceutils.h"
#include "libcommon/error.h"
#include "libcommon/string.h"
-#include <vector>
#include <cstdint>
#include <algorithm>
-#include <winsock2.h>
-#include <iphlpapi.h>
-#include <windows.h>
-
//static
std::set<InterfaceUtils::NetworkAdapter> InterfaceUtils::GetAllAdapters()
@@ -112,3 +107,18 @@ std::wstring InterfaceUtils::GetTapInterfaceAlias()
throw std::runtime_error("Unable to find TAP adapter");
}
+
+//static
+void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses)
+{
+ for (const auto &address : addresses)
+ {
+ MIB_UNICASTIPADDRESS_ROW row;
+ InitializeUnicastIpAddressEntry(&row);
+
+ row.InterfaceLuid = device;
+ row.Address = address;
+
+ THROW_UNLESS(NO_ERROR, CreateUnicastIpAddressEntry(&row), "Assign IP address on network interface");
+ }
+}
diff --git a/windows/winnet/src/winnet/interfaceutils.h b/windows/winnet/src/winnet/interfaceutils.h
index f5c31963c2..8ab1249a50 100644
--- a/windows/winnet/src/winnet/interfaceutils.h
+++ b/windows/winnet/src/winnet/interfaceutils.h
@@ -2,6 +2,17 @@
#include <string>
#include <set>
+#include <vector>
+
+// Secret include order to get most common networking structs/apis
+// And avoiding compilation errors
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2def.h>
+#include <ws2ipdef.h>
+#include <iphlpapi.h>
+#include <netioapi.h>
+// end
class InterfaceUtils
{
@@ -35,4 +46,6 @@ public:
// Determines alias of primary TAP adapter.
//
static std::wstring GetTapInterfaceAlias();
+
+ static void AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses);
};
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
new file mode 100644
index 0000000000..55d7560904
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
@@ -0,0 +1,177 @@
+#include "stdafx.h"
+#include <libcommon/error.h>
+#include "defaultroutemonitor.h"
+#include "helpers.h"
+
+namespace winnet::routing
+{
+
+namespace
+{
+
+const uint32_t POINT_TWO_SECOND_BURST = 200;
+const uint32_t TWO_SECOND_INTERFERENCE = 2000;
+
+} // anonymous namespace
+
+DefaultRouteMonitor::DefaultRouteMonitor
+(
+ ADDRESS_FAMILY family,
+ Callback callback,
+ std::shared_ptr<common::logging::ILogSink> logSink
+)
+ : m_family(family)
+ , m_callback(callback)
+ , m_logSink(logSink)
+ , m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>(
+ std::bind(&DefaultRouteMonitor::evaluateRoutes, this),
+ POINT_TWO_SECOND_BURST,
+ TWO_SECOND_INTERFERENCE
+ ))
+{
+ try
+ {
+ m_bestRoute = GetBestDefaultRoute(m_family);
+ }
+ catch (...)
+ {
+ }
+
+ const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle);
+
+ THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications");
+
+ try
+ {
+ const auto s2 = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this,
+ FALSE, &m_interfaceNotificationHandle);
+
+ THROW_UNLESS(NO_ERROR, status, "Register for network interface change notifications");
+ }
+ catch (...)
+ {
+ CancelMibChangeNotify2(m_routeNotificationHandle);
+ throw;
+ }
+}
+
+DefaultRouteMonitor::~DefaultRouteMonitor()
+{
+ //
+ // Cancel notifications to stop triggering the BurstGuard.
+ //
+
+ CancelMibChangeNotify2(m_interfaceNotificationHandle);
+ CancelMibChangeNotify2(m_routeNotificationHandle);
+
+ //
+ // Controlled destruction of BurstGuard to prevent it from calling here
+ // after other member variables have been destructed.
+ //
+
+ m_evaluateRoutesGuard.reset();
+}
+
+//static
+void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback
+(
+ void *context,
+ MIB_IPFORWARD_ROW2 *row,
+ MIB_NOTIFICATION_TYPE
+)
+{
+ //
+ // We're only interested in changes that add/remove/update a default route.
+ //
+
+ if (0 != row->DestinationPrefix.PrefixLength
+ || false == RouteHasGateway(*row))
+ {
+ return;
+ }
+
+ reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+}
+
+//static
+void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback
+(
+ void *context,
+ MIB_IPINTERFACE_ROW *,
+ MIB_NOTIFICATION_TYPE
+)
+{
+ reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+}
+
+void DefaultRouteMonitor::evaluateRoutes()
+{
+ std::scoped_lock<std::mutex> lock(m_evaluationLock);
+
+ try
+ {
+ evaluateRoutesInner();
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failure while evaluating route table: ").append(ex.what());
+ m_logSink->error(msg.c_str());
+ }
+ catch (...)
+ {
+ m_logSink->error("Unspecified failure while evaluating route table");
+ }
+}
+
+void DefaultRouteMonitor::evaluateRoutesInner()
+{
+ std::optional<InterfaceAndGateway> currentBestRoute;
+
+ try
+ {
+ currentBestRoute = GetBestDefaultRoute(m_family);
+ }
+ catch (...)
+ {
+ }
+
+ //
+ // If there was no default route previously.
+ //
+
+ if (false == m_bestRoute.has_value())
+ {
+ if (currentBestRoute.has_value())
+ {
+ m_bestRoute = currentBestRoute;
+ m_callback(EventType::Updated, m_bestRoute);
+ }
+
+ return;
+ }
+
+ //
+ // There used to be a default route.
+ // If there is not currently a default route.
+ //
+
+ if (false == currentBestRoute.has_value())
+ {
+ m_bestRoute.reset();
+ m_callback(EventType::Removed, std::nullopt);
+
+ return;
+ }
+
+ //
+ // The current best route may have changed.
+ //
+
+ if (m_bestRoute.value() != currentBestRoute.value())
+ {
+ m_bestRoute = currentBestRoute;
+ m_callback(EventType::Updated, m_bestRoute);
+ }
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
new file mode 100644
index 0000000000..5575685a82
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <ifdef.h>
+#include <ws2def.h>
+#include <functional>
+#include <optional>
+#include <memory>
+#include <mutex>
+#include <libcommon/logging/ilogsink.h>
+#include <libcommon/burstguard.h>
+#include "types.h"
+
+namespace winnet::routing
+{
+
+class DefaultRouteMonitor
+{
+public:
+
+ enum class EventType
+ {
+ // The best default route changed.
+ Updated,
+
+ // No default routes exist.
+ Removed,
+ };
+
+ using Callback = std::function<void
+ (
+ EventType eventType,
+
+ // For update events, data associated with the new best default route.
+ const std::optional<InterfaceAndGateway> &route
+ )>;
+
+ DefaultRouteMonitor(ADDRESS_FAMILY family, Callback callback, std::shared_ptr<common::logging::ILogSink> logSink);
+ ~DefaultRouteMonitor();
+
+ DefaultRouteMonitor(const DefaultRouteMonitor &) = delete;
+ DefaultRouteMonitor(DefaultRouteMonitor &&) = delete;
+ DefaultRouteMonitor &operator=(const DefaultRouteMonitor &) = delete;
+ DefaultRouteMonitor &operator=(DefaultRouteMonitor &&) = delete;
+
+private:
+
+ ADDRESS_FAMILY m_family;
+ Callback m_callback;
+ std::shared_ptr<common::logging::ILogSink> m_logSink;
+
+ // This can't be a plain member variable.
+ // We need to be able to delete it explicitly in order to have a controlled tear down.
+ std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard;
+
+ std::optional<InterfaceAndGateway> m_bestRoute;
+
+ HANDLE m_routeNotificationHandle;
+ HANDLE m_interfaceNotificationHandle;
+
+ std::mutex m_evaluationLock;
+
+ static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType);
+ static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType);
+
+ void evaluateRoutes();
+ void evaluateRoutesInner();
+};
+
+}
diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp
new file mode 100644
index 0000000000..cabf19bce6
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/helpers.cpp
@@ -0,0 +1,275 @@
+#include "stdafx.h"
+#include "helpers.h"
+#include <stdexcept>
+#include <ws2def.h>
+#include <in6addr.h>
+#include <numeric>
+//#include <netioapi.h>
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+
+namespace winnet::routing
+{
+
+bool EqualAddress(const Network &lhs, const Network &rhs)
+{
+ if (lhs.PrefixLength != rhs.PrefixLength)
+ {
+ return false;
+ }
+
+ return EqualAddress(lhs.Prefix, rhs.Prefix);
+}
+
+bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs)
+{
+ if (lhs.si_family != rhs.si_family)
+ {
+ return false;
+ }
+
+ switch (lhs.si_family)
+ {
+ case AF_INET:
+ {
+ return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR));
+ }
+ default:
+ {
+ throw std::runtime_error("Invalid address family for network address");
+ }
+ }
+}
+
+bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs)
+{
+ if (lhs->si_family != rhs->lpSockaddr->sa_family)
+ {
+ return false;
+ }
+
+ switch (lhs->si_family)
+ {
+ case AF_INET:
+ {
+ auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr);
+ return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr);
+ return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16);
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+}
+
+bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface)
+{
+ memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW));
+
+ iface->Family = addressFamily;
+ iface->InterfaceLuid = adapter;
+
+ return NO_ERROR == GetIpInterfaceEntry(iface);
+}
+
+std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes)
+{
+ std::vector<AnnotatedRoute> annotated;
+ annotated.reserve(routes.size());
+
+ for (auto route : routes)
+ {
+ MIB_IPINTERFACE_ROW iface;
+
+ if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface))
+ {
+ continue;
+ }
+
+ annotated.emplace_back
+ (
+ AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric }
+ );
+ }
+
+ return annotated;
+}
+
+bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route)
+{
+ switch (route.NextHop.si_family)
+ {
+ case AF_INET:
+ {
+ return 0 != route.NextHop.Ipv4.sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0];
+ const uint8_t *end = begin + 16;
+
+ return 0 != std::accumulate(begin, end, 0);
+ }
+ default:
+ {
+ return false;
+ }
+ };
+}
+
+InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family)
+{
+ PMIB_IPFORWARD_TABLE2 table;
+
+ auto status = GetIpForwardTable2(family, &table);
+
+ THROW_UNLESS(NO_ERROR, status, "Acquire route table");
+
+ common::memory::ScopeDestructor sd;
+
+ sd += [table]
+ {
+ FreeMibTable(table);
+ };
+
+ std::vector<const MIB_IPFORWARD_ROW2 *> candidates;
+ candidates.reserve(table->NumEntries);
+
+ //
+ // Enumerate routes looking for: route 0/0 && gateway specified.
+ //
+
+ for (ULONG i = 0; i < table->NumEntries; ++i)
+ {
+ const MIB_IPFORWARD_ROW2 &candidate = table->Table[i];
+
+ if (0 == candidate.DestinationPrefix.PrefixLength
+ && RouteHasGateway(candidate))
+ {
+ candidates.emplace_back(&candidate);
+ }
+ }
+
+ auto annotated = AnnotateRoutes(candidates);
+
+ if (annotated.empty())
+ {
+ throw std::runtime_error("Unable to determine details of default route");
+ }
+
+ //
+ // Sort on (active, effectiveMetric) ascending by metric.
+ //
+
+ std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs)
+ {
+ if (lhs.active == rhs.active)
+ {
+ return lhs.effectiveMetric < rhs.effectiveMetric;
+ }
+
+ return lhs.active && false == rhs.active;
+ });
+
+ //
+ // Ensure the top rated route is active.
+ //
+
+ if (false == annotated[0].active)
+ {
+ throw std::runtime_error("Unable to identify active default route");
+ }
+
+ return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop };
+}
+
+bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family)
+{
+ switch (family)
+ {
+ case AF_INET:
+ {
+ return 0 != adapter->Ipv4Enabled;
+ }
+ case AF_INET6:
+ {
+ return 0 != adapter->Ipv6Enabled;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+}
+
+std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses
+(
+ PIP_ADAPTER_GATEWAY_ADDRESS_LH head,
+ ADDRESS_FAMILY family
+)
+{
+ std::vector<const SOCKET_ADDRESS *> matches;
+
+ for (auto gateway = head; nullptr != gateway; gateway = gateway->Next)
+ {
+ if (family == gateway->Address.lpSockaddr->sa_family)
+ {
+ matches.emplace_back(&gateway->Address);
+ }
+ }
+
+ return matches;
+}
+
+bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle)
+{
+ for (const auto candidate : hay)
+ {
+ if (EqualAddress(needle, candidate))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa)
+//{
+// NodeAddress out = { 0 };
+//
+// switch (sa->lpSockaddr->sa_family)
+// {
+// case AF_INET:
+// {
+// out.si_family = AF_INET;
+// out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr);
+//
+// break;
+// }
+// case AF_INET6:
+// {
+// out.si_family = AF_INET6;
+// out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr);
+//
+// break;
+// }
+// default:
+// {
+// throw std::runtime_error("Missing case handler in switch clause");
+// }
+// };
+//
+// return out;
+//}
+
+}
diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h
new file mode 100644
index 0000000000..3ef5e85b75
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/helpers.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "types.h"
+#include <vector>
+
+namespace winnet::routing
+{
+
+bool EqualAddress(const Network &lhs, const Network &rhs);
+bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs);
+bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs);
+
+bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface);
+
+struct AnnotatedRoute
+{
+ const MIB_IPFORWARD_ROW2 *route;
+ bool active;
+ uint32_t effectiveMetric;
+};
+
+template<typename T>
+bool bool_cast(const T &value)
+{
+ return 0 != value;
+}
+
+std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes);
+
+bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route);
+
+InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family);
+
+bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family);
+
+std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses
+(
+ PIP_ADAPTER_GATEWAY_ADDRESS_LH head,
+ ADDRESS_FAMILY family
+);
+
+bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle);
+
+//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa);
+
+}
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp
new file mode 100644
index 0000000000..668e64bb68
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/routemanager.cpp
@@ -0,0 +1,692 @@
+#include "stdafx.h"
+#include "routemanager.h"
+#include "helpers.h"
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+#include <libcommon/string.h>
+#include <libcommon/network/adapters.h>
+#include <vector>
+#include <algorithm>
+#include <numeric>
+#include <sstream>
+#include <stdexcept>
+
+using AutoLockType = std::scoped_lock<std::mutex>;
+using AutoRecursiveLockType = std::scoped_lock<std::recursive_mutex>;
+using namespace std::placeholders;
+
+namespace winnet::routing
+{
+
+namespace
+{
+
+using Adapters = common::network::Adapters;
+
+NET_LUID InterfaceLuidFromGateway(const NodeAddress &gateway)
+{
+ const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER
+ | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS;
+
+ Adapters adapters(gateway.si_family, adapterFlags);
+
+ //
+ // Process adapters to find matching ones.
+ //
+
+ std::vector<const IP_ADAPTER_ADDRESSES *> matches;
+
+ for (auto adapter = adapters.next(); nullptr != adapter; adapter = adapters.next())
+ {
+ if (false == AdapterInterfaceEnabled(adapter, gateway.si_family))
+ {
+ continue;
+ }
+
+ auto gateways = IsolateGatewayAddresses(adapter->FirstGatewayAddress, gateway.si_family);
+
+ if (AddressPresent(gateways, &gateway))
+ {
+ matches.emplace_back(adapter);
+ }
+ }
+
+ if (matches.empty())
+ {
+ throw std::runtime_error("Unable to find network adapter with specified gateway");
+ }
+
+ //
+ // Sort matching interfaces ascending by metric.
+ //
+
+ const bool targetV4 = (AF_INET == gateway.si_family);
+
+ std::sort(matches.begin(), matches.end(), [&targetV4](const IP_ADAPTER_ADDRESSES *lhs, const IP_ADAPTER_ADDRESSES *rhs)
+ {
+ if (targetV4)
+ {
+ return lhs->Ipv4Metric < rhs->Ipv4Metric;
+ }
+
+ return lhs->Ipv6Metric < rhs->Ipv6Metric;
+ });
+
+ //
+ // Select the interface with the best (lowest) metric.
+ //
+
+ return matches[0]->Luid;
+}
+
+bool ParseStringEncodedLuid(const std::wstring &encodedLuid, NET_LUID &luid)
+{
+ //
+ // The `#` is a valid character in adapter names so we use `?` instead.
+ // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes.
+ // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`.
+ //
+
+ static const size_t StringEncodedLuidLength = 17;
+
+ if (encodedLuid.size() != StringEncodedLuidLength
+ || L'?' != encodedLuid[0])
+ {
+ return false;
+ }
+
+ try
+ {
+ std::wstringstream ss;
+
+ ss << std::hex << &encodedLuid[1];
+ ss >> luid.Value;
+ }
+ catch (...)
+ {
+ const auto ansi = common::string::ToAnsi(encodedLuid);
+ const auto err = std::string("Failed to parse string encoded LUID: ").append(ansi);
+
+ std::throw_with_nested(std::runtime_error(err));
+ }
+
+ return true;
+}
+
+InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node> &optionalNode)
+{
+ //
+ // There are four cases:
+ //
+ // Unspecified node (use interface and gateway of default route).
+ // Node is specified by name.
+ // Node is specified by name and gateway.
+ // Node is specified by gateway.
+ //
+
+ if (false == optionalNode.has_value())
+ {
+ return GetBestDefaultRoute(family);
+ }
+
+ const auto &node = optionalNode.value();
+
+ if (node.deviceName().has_value())
+ {
+ const auto &deviceName = node.deviceName().value();
+ NET_LUID luid;
+
+ if (false == ParseStringEncodedLuid(deviceName, luid)
+ && 0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceName);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ auto onLinkProvider = [&family]()
+ {
+ NodeAddress onLink = { 0 };
+ onLink.si_family = family;
+
+ return onLink;
+ };
+
+ return InterfaceAndGateway{ luid, node.gateway().value_or(onLinkProvider()) };
+ }
+
+ //
+ // The node is specified only by gateway.
+ //
+
+ return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() };
+}
+
+// TODO: Move to libcommon
+uint32_t ByteSwap(uint32_t val)
+{
+ return
+ (
+ ((val & 0xFF) << 24) |
+ ((val & 0xFF00) << 8) |
+ ((val & 0xFF0000) >> 8) |
+ ((val & 0xFF000000) >> 24)
+ );
+}
+
+std::wstring FormatNetwork(const Network &network)
+{
+ switch (network.Prefix.si_family)
+ {
+ case AF_INET:
+ {
+ return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength);
+ }
+ case AF_INET6:
+ {
+ return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength);
+ }
+ default:
+ {
+ return L"Failed to format network details";
+ }
+ }
+}
+
+} // anonymous namespace
+
+RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink)
+ : m_logSink(logSink)
+ , m_routeMonitorV4(std::make_unique<DefaultRouteMonitor>(
+ static_cast<ADDRESS_FAMILY>(AF_INET),
+ std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET), _1, _2),
+ logSink
+ ))
+ , m_routeMonitorV6(std::make_unique<DefaultRouteMonitor>(
+ static_cast<ADDRESS_FAMILY>(AF_INET6),
+ std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET6), _1, _2),
+ logSink
+ ))
+{
+}
+
+RouteManager::~RouteManager()
+{
+ //
+ // Stop callbacks that are triggered by events in Windows from coming in.
+ //
+
+ m_routeMonitorV4.reset();
+ m_routeMonitorV6.reset();
+
+ //
+ // Delete all routes owned by us.
+ //
+
+ for (const auto &record : m_routes)
+ {
+ try
+ {
+ deleteFromRoutingTable(record.registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ std::wstringstream ss;
+
+ ss << L"Failed to delete route as part of cleaning up, Route: "
+ << FormatRegisteredRoute(record.registeredRoute);
+
+ m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
+ m_logSink->error(ex.what());
+ }
+ }
+}
+
+void RouteManager::addRoutes(const std::vector<Route> &routes)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::vector<EventEntry> eventLog;
+
+ for (const auto &route : routes)
+ {
+ try
+ {
+ auto record = findRouteRecord(route);
+
+ if (record != m_routes.end())
+ {
+ deleteFromRoutingTable(record->registeredRoute);
+ eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
+ m_routes.erase(record);
+ }
+
+ const RouteRecord newRecord { route, addIntoRoutingTable(route) };
+
+ eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord });
+ m_routes.emplace_back(std::move(newRecord));
+ }
+ catch (...)
+ {
+ undoEvents(eventLog);
+
+ std::throw_with_nested(std::runtime_error("Failed during batch insertion of routes"));
+ }
+ }
+}
+
+void RouteManager::addRoute(const Route &route)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::optional<RouteRecord> deletedRecord;
+
+ auto record = findRouteRecord(route);
+
+ if (record != m_routes.end())
+ {
+ try
+ {
+ deleteFromRoutingTable(record->registeredRoute);
+ }
+ catch (...)
+ {
+ std::throw_with_nested(std::runtime_error("Failed to evict old route when adding new route"));
+ }
+
+ deletedRecord = *record;
+ m_routes.erase(record);
+ }
+
+ try
+ {
+ m_routes.emplace_back
+ (
+ RouteRecord{ route, addIntoRoutingTable(route) }
+ );
+ }
+ catch (...)
+ {
+ //
+ // Restore deleted record.
+ //
+
+ if (deletedRecord.has_value())
+ {
+ auto &r = deletedRecord.value();
+
+ try
+ {
+ restoreIntoRoutingTable(r.registeredRoute);
+ m_routes.emplace_back(r);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what());
+ m_logSink->error(err.c_str());
+ }
+ }
+
+ //
+ // Just rethrow because the error is from addIntoRoutingTable().
+ //
+
+ throw;
+ }
+}
+
+void RouteManager::deleteRoutes(const std::vector<Route> &routes)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::vector<EventEntry> eventLog;
+
+ for (const auto &route : routes)
+ {
+ try
+ {
+ auto record = findRouteRecord(route);
+
+ if (m_routes.end() == record)
+ {
+ const auto err = std::wstring(L"Request to delete previously unregistered route: ")
+ .append(FormatNetwork(route.network()));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+
+ continue;
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
+ m_routes.erase(record);
+ }
+ catch (...)
+ {
+ undoEvents(eventLog);
+
+ std::throw_with_nested(std::runtime_error("Failed during batch removal of routes"));
+ }
+ }
+}
+
+void RouteManager::deleteRoute(const Route &route)
+{
+ AutoLockType lock(m_routesLock);
+
+ auto record = findRouteRecord(route);
+
+ if (m_routes.end() == record)
+ {
+ const auto err = std::wstring(L"Request to delete previously unregistered route: ")
+ .append(FormatNetwork(route.network()));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+
+ return;
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ m_routes.erase(record);
+}
+
+RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
+{
+ AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
+
+ m_defaultRouteCallbacks.emplace_back(callback);
+
+ // Return raw address of record in list.
+ return &m_defaultRouteCallbacks.back();
+}
+
+void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle)
+{
+ AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
+
+ for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it)
+ {
+ // Match on raw address of record.
+ if (&*it == handle)
+ {
+ m_defaultRouteCallbacks.erase(it);
+ return;
+ }
+ }
+}
+
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network)
+{
+ return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate)
+ {
+ return EqualAddress(network, candidate.route.network());
+ });
+}
+
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route)
+{
+ return findRouteRecord(route.network());
+}
+
+RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route)
+{
+ const auto node = ResolveNode(route.network().Prefix.si_family, route.node());
+
+ MIB_IPFORWARD_ROW2 spec;
+
+ InitializeIpForwardEntry(&spec);
+
+ spec.InterfaceLuid = node.iface;
+ spec.DestinationPrefix = route.network();
+ spec.NextHop = node.gateway;
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ //
+ // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful.
+ // Because it may not take route metric into consideration.
+ //
+
+ THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table");
+
+ return RegisteredRoute { route.network(), node.iface, node.gateway };
+}
+
+void RouteManager::restoreIntoRoutingTable(const RegisteredRoute &route)
+{
+ MIB_IPFORWARD_ROW2 spec;
+
+ InitializeIpForwardEntry(&spec);
+
+ spec.InterfaceLuid = route.luid;
+ spec.DestinationPrefix = route.network;
+ spec.NextHop = route.nextHop;
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table");
+}
+
+void RouteManager::deleteFromRoutingTable(const RegisteredRoute &route)
+{
+ MIB_IPFORWARD_ROW2 r = { 0};
+
+ r.InterfaceLuid = route.luid;
+ r.DestinationPrefix = route.network;
+ r.NextHop = route.nextHop;
+
+ auto status = DeleteIpForwardEntry2(&r);
+
+ if (ERROR_NOT_FOUND == status)
+ {
+ status = NO_ERROR;
+
+ const auto err = std::wstring(L"Attempting to delete route which was not present in routing table, " \
+ "ignoring and proceeding. Route: ").append(FormatRegisteredRoute(route));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+ }
+
+ THROW_UNLESS(NO_ERROR, status, "Delete route in routing table");
+}
+
+void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog)
+{
+ //
+ // Rewind state by processing events in the reverse order.
+ //
+
+ for (auto it = eventLog.rbegin(); it != eventLog.rend(); ++it)
+ {
+ try
+ {
+ switch (it->type)
+ {
+ case EventType::ADD_ROUTE:
+ {
+ auto record = findRouteRecord(it->record.route);
+
+ if (m_routes.end() == record)
+ {
+ throw std::runtime_error("Internal state inconsistency in route manager");
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ m_routes.erase(record);
+
+ break;
+ }
+ case EventType::DELETE_ROUTE:
+ {
+ restoreIntoRoutingTable(it->record.registeredRoute);
+ m_routes.emplace_back(it->record);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Missing case handler in switch clause");
+ }
+ }
+ }
+ catch (const std::exception &ex)
+ {
+ const auto err = std::string("Attempting to rollback state: ").append(ex.what());
+ m_logSink->error(err.c_str());
+ }
+ }
+}
+
+// static
+std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
+{
+ //
+ // TODO: Fix broken IP formatting
+ // Update FormatIpv4 function with an additional argument to specify network/host byte order.
+ //
+
+ std::wstringstream ss;
+
+ if (AF_INET == route.network.Prefix.si_family)
+ {
+ std::wstring gateway(L"\"On-link\"");
+
+ if (0 != route.nextHop.Ipv4.sin_addr.s_addr)
+ {
+ gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr));
+ }
+
+ ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength)
+ << L" with gateway " << gateway
+ << L" on interface with LUID 0x" << std::hex << route.luid.Value;
+ }
+ else if (AF_INET6 == route.network.Prefix.si_family)
+ {
+ std::wstring gateway(L"\"On-link\"");
+
+ const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0];
+ const uint8_t *end = begin + 16;
+
+ if (0 != std::accumulate(begin, end, 0))
+ {
+ gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte);
+ }
+
+ ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength)
+ << L" with gateway " << gateway
+ << L" on interface with LUID 0x" << std::hex << route.luid.Value;
+ }
+ else
+ {
+ ss << L"Failed to format route details";
+ }
+
+ return ss.str();
+}
+
+void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType,
+ const std::optional<InterfaceAndGateway> &route)
+{
+ //
+ // Forward event to all registered listeners.
+ //
+
+ m_defaultRouteCallbacksLock.lock();
+
+ for (const auto &callback : m_defaultRouteCallbacks)
+ {
+ try
+ {
+ callback(eventType, family, route);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failure in default-route-changed callback: ").append(ex.what());
+ m_logSink->error(msg.c_str());
+ }
+ catch (...)
+ {
+ m_logSink->error("Unspecified failure in default-route-changed callback");
+ }
+ }
+
+ m_defaultRouteCallbacksLock.unlock();
+
+ //
+ // Examine event to determine if best default route has changed.
+ //
+
+ if (DefaultRouteMonitor::EventType::Updated != eventType)
+ {
+ return;
+ }
+
+ //
+ // Examine our routes to see if any of them are policy bound to the best default route.
+ //
+
+ AutoLockType routesLock(m_routesLock);
+
+ using RecordIterator = std::list<RouteRecord>::iterator;
+
+ std::list<RecordIterator> affectedRoutes;
+
+ for (RecordIterator it = m_routes.begin(); it != m_routes.end(); ++it)
+ {
+ if (false == it->route.node().has_value()
+ && family == it->route.network().Prefix.si_family)
+ {
+ affectedRoutes.emplace_back(it);
+ }
+ }
+
+ if (affectedRoutes.empty())
+ {
+ return;
+ }
+
+ //
+ // Update all affected routes.
+ //
+
+ m_logSink->info("Best default route has changed. Refreshing dependent routes");
+
+ for (auto &it : affectedRoutes)
+ {
+ try
+ {
+ deleteFromRoutingTable(it->registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failed to delete route when refreshing " \
+ "existing routes: ").append(ex.what());
+
+ m_logSink->error(msg.c_str());
+
+ continue;
+ }
+
+ it->registeredRoute.luid = route.value().iface;
+ it->registeredRoute.nextHop = route.value().gateway;
+
+ try
+ {
+ restoreIntoRoutingTable(it->registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failed to add route when refreshing " \
+ "existing routes: ").append(ex.what());
+
+ m_logSink->error(msg.c_str());
+
+ continue;
+ }
+ }
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h
new file mode 100644
index 0000000000..981c8e6834
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/routemanager.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <string>
+#include <memory>
+#include <vector>
+#include <list>
+#include <optional>
+#include <mutex>
+#include <functional>
+#include <windows.h>
+#include <ws2def.h>
+#include <ifdef.h>
+#include <libcommon/string.h>
+#include <libcommon/logging/ilogsink.h>
+#include "defaultroutemonitor.h"
+
+namespace winnet::routing
+{
+
+class RouteManager
+{
+public:
+
+ RouteManager(std::shared_ptr<common::logging::ILogSink> logSink);
+ ~RouteManager();
+
+ RouteManager(const RouteManager &) = delete;
+ RouteManager(RouteManager &&) = default;
+ RouteManager &operator=(const RouteManager &) = delete;
+ RouteManager &operator=(RouteManager &&) = delete;
+
+ void addRoutes(const std::vector<Route> &routes);
+ void addRoute(const Route &route);
+
+ void deleteRoutes(const std::vector<Route> &routes);
+ void deleteRoute(const Route &route);
+
+ using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType;
+
+ using DefaultRouteChangedCallback = std::function<void
+ (
+ DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family,
+
+ // For update events, data associated with the new best default route.
+ const std::optional<InterfaceAndGateway> &route
+ )>;
+
+ using CallbackHandle = void*;
+
+ CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback);
+ void unregisterDefaultRouteChangedCallback(CallbackHandle handle);
+
+private:
+
+ std::shared_ptr<common::logging::ILogSink> m_logSink;
+
+ std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV4;
+ std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV6;
+
+ // These are the exact details derived from the route specification (`Route`).
+ // They are used when registering and deleting a route in the system.
+ struct RegisteredRoute
+ {
+ Network network;
+ NET_LUID luid;
+ NodeAddress nextHop;
+ };
+
+ struct RouteRecord
+ {
+ Route route;
+ RegisteredRoute registeredRoute;
+ };
+
+ std::list<RouteRecord> m_routes;
+ std::mutex m_routesLock;
+
+ std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks;
+ std::recursive_mutex m_defaultRouteCallbacksLock;
+
+ // Find record based on destination and mask.
+ std::list<RouteRecord>::iterator findRouteRecord(const Network &network);
+
+ // Note: Same as above!
+ std::list<RouteRecord>::iterator findRouteRecord(const Route &route);
+
+ RegisteredRoute addIntoRoutingTable(const Route &route);
+ void restoreIntoRoutingTable(const RegisteredRoute &route);
+ void deleteFromRoutingTable(const RegisteredRoute &route);
+
+ enum class EventType
+ {
+ ADD_ROUTE,
+ DELETE_ROUTE,
+ };
+
+ struct EventEntry
+ {
+ EventType type;
+ RouteRecord record;
+ };
+
+ void undoEvents(const std::vector<EventEntry> &eventLog);
+
+ static std::wstring FormatRegisteredRoute(const RegisteredRoute &route);
+
+ void defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType,
+ const std::optional<InterfaceAndGateway> &route);
+};
+
+}
diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp
new file mode 100644
index 0000000000..ac71c8108f
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/types.cpp
@@ -0,0 +1,84 @@
+#include "stdafx.h"
+#include "types.h"
+#include "helpers.h"
+#include <libcommon/string.h>
+
+namespace winnet::routing
+{
+
+Node::Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway)
+ : m_deviceName(deviceName)
+ , m_gateway(gateway)
+{
+ if (false == m_deviceName.has_value() && false == m_gateway.has_value())
+ {
+ throw std::runtime_error("Invalid node definition");
+ }
+
+ if (m_deviceName.has_value())
+ {
+ const auto trimmed = common::string::Trim<>(m_deviceName.value());
+
+ if (trimmed.empty())
+ {
+ throw std::runtime_error("Invalid device name in node definition");
+ }
+
+ m_deviceName = std::move(trimmed);
+ }
+}
+
+bool Node::operator==(const Node &rhs) const
+{
+ if (m_deviceName.has_value())
+ {
+ if (false == rhs.m_deviceName.has_value()
+ || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str()))
+ {
+ return false;
+ }
+ }
+
+ if (m_gateway.has_value())
+ {
+ if (false == rhs.m_gateway.has_value()
+ || false == EqualAddress(m_gateway.value(), rhs.gateway().value()))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+Route::Route(const Network &network, const std::optional<Node> &node)
+ : m_network(network)
+ , m_node(node)
+{
+}
+
+bool Route::operator==(const Route &rhs) const
+{
+ if (m_node.has_value())
+ {
+ return rhs.node().has_value()
+ && EqualAddress(m_network, rhs.network())
+ && m_node.value() == rhs.node().value();
+ }
+
+ return false == rhs.node().has_value()
+ && EqualAddress(m_network, rhs.network());
+}
+
+bool InterfaceAndGateway::operator==(const InterfaceAndGateway &rhs)
+{
+ return iface.Value == rhs.iface.Value
+ && EqualAddress(gateway, rhs.gateway);
+}
+
+bool InterfaceAndGateway::operator!=(const InterfaceAndGateway &rhs)
+{
+ return !(*this == rhs);
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h
new file mode 100644
index 0000000000..1e132feb00
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/types.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#include <string>
+#include <optional>
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2def.h>
+#include <ws2ipdef.h>
+#include <iphlpapi.h>
+//#include <netioapi.h>
+//#include <functional>
+
+
+namespace winnet::routing
+{
+
+using Network = IP_ADDRESS_PREFIX;
+using NodeAddress = SOCKADDR_INET;
+
+class Node
+{
+public:
+
+ Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway);
+
+ const std::optional<std::wstring> &deviceName() const
+ {
+ return m_deviceName;
+ }
+
+ const std::optional<NodeAddress> &gateway() const
+ {
+ return m_gateway;
+ }
+
+ bool operator==(const Node &rhs) const;
+
+private:
+
+ std::optional<std::wstring> m_deviceName;
+ std::optional<NodeAddress> m_gateway;
+};
+
+class Route
+{
+public:
+
+ Route(const Network &network, const std::optional<Node> &node);
+
+ const Network &network() const
+ {
+ return m_network;
+ }
+
+ const std::optional<Node> &node() const
+ {
+ return m_node;
+ }
+
+ bool operator==(const Route &rhs) const;
+
+private:
+
+ Network m_network;
+ std::optional<Node> m_node;
+};
+
+struct InterfaceAndGateway
+{
+ NET_LUID iface;
+ NodeAddress gateway;
+
+ bool operator==(const InterfaceAndGateway &rhs);
+ bool operator!=(const InterfaceAndGateway &rhs);
+};
+
+}
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 4b006964a6..48d12b5ea3 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -3,17 +3,135 @@
#include "NetworkInterfaces.h"
#include "interfaceutils.h"
#include "offlinemonitor.h"
+#include "routing/routemanager.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
+#include <libcommon/network.h>
#include <cstdint>
#include <stdexcept>
#include <memory>
+#include <optional>
+#include <mutex>
+
+using namespace winnet::routing;
+using AutoLockType = std::scoped_lock<std::mutex>;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+std::mutex g_RouteManagerLock;
+RouteManager *g_RouteManager = nullptr;
+std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
+
+Network ConvertNetwork(const WINNET_IPNETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out{ 0 };
+
+ out.PrefixLength = in.prefix;
+
+ switch (in.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ out.Prefix.si_family = AF_INET;
+ out.Prefix.Ipv4.sin_family = AF_INET;
+ out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ out.Prefix.si_family = AF_INET6;
+ out.Prefix.Ipv6.sin6_family = AF_INET6;
+ memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return {};
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ throw std::runtime_error("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ NodeAddress gw { 0 };
+
+ switch (in->gateway->type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ gw.si_family = AF_INET;
+ gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ gw.si_family = AF_INET6;
+ memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid gateway type specifier in 'WINNET_NODE' definition");
+ }
+ }
+
+ gateway = gw;
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::exception &err)
{
if (nullptr == logSink)
@@ -26,6 +144,49 @@ void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::excep
common::error::UnwindException(err, logger);
}
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
+{
+ //
+ // This duplicates the same logic we have above.
+ // TODO: Fix when time permits.
+ //
+
+ std::vector<SOCKADDR_INET> out;
+ out.reserve(numAddresses);
+
+ for (uint32_t i = 0; i < numAddresses; ++i)
+ {
+ const WINNET_IP &from = addresses[i];
+ SOCKADDR_INET to{ 0 };
+
+ switch (from.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ to.si_family = AF_INET;
+ to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(from.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ to.si_family = AF_INET6;
+ memcpy(&to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid address family in 'WINNET_IP' definition");
+ }
+ }
+
+ out.push_back(to);
+ }
+
+ return out;
+}
+
} //anonymous namespace
extern "C"
@@ -66,12 +227,12 @@ WinNet_GetTapInterfaceIpv6Status(
{
try
{
- MIB_IPINTERFACE_ROW interface = { 0 };
+ MIB_IPINTERFACE_ROW iface = { 0 };
- interface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
- interface.Family = AF_INET6;
+ iface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
+ iface.Family = AF_INET6;
- const auto status = GetIpInterfaceEntry(&interface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -201,3 +362,360 @@ WinNet_DeactivateConnectivityMonitor(
{
}
}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ if (nullptr != g_RouteManager)
+ {
+ throw std::runtime_error("Cannot activate route manager twice");
+ }
+
+ g_RouteManagerLogSink = std::make_shared<shared::LogSinkAdapter>(logSink, logSinkContext);
+ g_RouteManager = new RouteManager(g_RouteManagerLogSink);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+//
+// TODO: Move to libcommon.
+//
+struct ValueMapper
+{
+ template<typename T, typename U, std::size_t S>
+ static U map(T t, const std::pair<T, U> (&dictionary)[S])
+ {
+ for (const auto &entry : dictionary)
+ {
+ if (t == entry.first)
+ {
+ return entry.second;
+ }
+ }
+
+ throw std::runtime_error("Could not map between values");
+ }
+};
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ auto forwarder = [callback, context](RouteManager::DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family, const std::optional<InterfaceAndGateway> &route)
+ {
+ //
+ // Translate the event type.
+ //
+
+ using from_t = RouteManager::DefaultRouteChangedEventType;
+ using to_t = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE;
+
+ static const std::pair<from_t, to_t> eventTypeMap[] =
+ {
+ { from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
+ };
+
+ const auto translatedEventType = ValueMapper::map<>(eventType, eventTypeMap);
+
+ //
+ // Translate the family type.
+ //
+
+ static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
+ {
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ };
+
+ const auto translatedFamily = ValueMapper::map<>(family, familyMap);
+
+ //
+ // Determine which LUID to forward.
+ //
+
+ uint64_t translatedLuid = 0;
+
+ if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ {
+ translatedLuid = route.value().iface.Value;
+ }
+
+ //
+ // Forward to client.
+ //
+
+ callback(translatedEventType, translatedFamily, translatedLuid, context);
+ };
+
+ *registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return;
+ }
+
+ try
+ {
+ g_RouteManager->unregisterDefaultRouteChangedCallback(registrationHandle);
+ }
+ catch (const std::exception &err)
+ {
+ g_RouteManagerLogSink->error("Failed to unregister default-route-changed callback");
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ delete g_RouteManager;
+ g_RouteManager = nullptr;
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ NET_LUID luid;
+
+ if (0 != ConvertInterfaceAliasToLuid(deviceAlias, &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceAlias);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ InterfaceUtils::AddDeviceIpAddresses(luid, ConvertAddresses(addresses, numAddresses));
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def
index 04c3f22ee3..b23ae6c854 100644
--- a/windows/winnet/src/winnet/winnet.def
+++ b/windows/winnet/src/winnet/winnet.def
@@ -6,3 +6,6 @@ EXPORTS
WinNet_ReleaseString
WinNet_ActivateConnectivityMonitor
WinNet_DeactivateConnectivityMonitor
+ WinNet_ActivateRouteManager
+ WinNet_DeactivateRouteManager
+ WinNet_AddDeviceIpAddresses
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 9b1af52e36..c7a161c3d8 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -1,6 +1,7 @@
#pragma once
#include "../../shared/logsink.h"
+#include <stdint.h>
#include <stdbool.h>
#ifndef WINNET_STATIC
@@ -89,3 +90,147 @@ void
WINNET_API
WinNet_DeactivateConnectivityMonitor(
);
+
+enum WINNET_IP_TYPE
+{
+ WINNET_IP_TYPE_IPV4 = 0,
+ WINNET_IP_TYPE_IPV6 = 1,
+};
+
+typedef struct tag_WINNET_IPNETWORK
+{
+ WINNET_IP_TYPE type;
+ uint8_t bytes[16]; // Network byte order.
+ uint8_t prefix;
+}
+WINNET_IPNETWORK;
+
+typedef struct tag_WINNET_IP
+{
+ WINNET_IP_TYPE type;
+ uint8_t bytes[16]; // Network byte order.
+}
+WINNET_IP;
+
+typedef struct tag_WINNET_NODE
+{
+ const WINNET_IP *gateway;
+ const wchar_t *deviceName;
+}
+WINNET_NODE;
+
+typedef struct tag_WINNET_ROUTE
+{
+ WINNET_IPNETWORK network;
+ const WINNET_NODE *node;
+}
+WINNET_ROUTE;
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+);
+
+enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
+{
+ // Best default route changed.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0,
+
+ // No default routes exist.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1,
+};
+
+enum WINNET_IP_FAMILY
+{
+ WINNET_IP_FAMILY_V4 = 0,
+ WINNET_IP_FAMILY_V6 = 1,
+};
+
+typedef void (WINNET_API *WinNetDefaultRouteChangedCallback)
+(
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType,
+
+ // Signals which IP family the event relates to.
+ WINNET_IP_FAMILY family,
+
+ // For update events, signals the interface associated with the new best default route.
+ uint64_t interfaceLuid,
+
+ void *context
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+);
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+);
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+);
+
diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj
index 192320daaf..5e71a1f733 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj
+++ b/windows/winnet/src/winnet/winnet.vcxproj
@@ -33,6 +33,10 @@
<ClCompile Include="interfaceutils.cpp" />
<ClCompile Include="offlinemonitor.cpp" />
<ClCompile Include="NetworkInterfaces.cpp" />
+ <ClCompile Include="routing\defaultroutemonitor.cpp" />
+ <ClCompile Include="routing\helpers.cpp" />
+ <ClCompile Include="routing\routemanager.cpp" />
+ <ClCompile Include="routing\types.cpp" />
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="winnet.cpp" />
</ItemGroup>
@@ -42,6 +46,10 @@
<ClInclude Include="interfaceutils.h" />
<ClInclude Include="offlinemonitor.h" />
<ClInclude Include="NetworkInterfaces.h" />
+ <ClInclude Include="routing\defaultroutemonitor.h" />
+ <ClInclude Include="routing\helpers.h" />
+ <ClInclude Include="routing\routemanager.h" />
+ <ClInclude Include="routing\types.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="winnet.h" />
@@ -208,7 +216,7 @@
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
<LanguageStandard>stdcpplatest</LanguageStandard>
- <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@@ -278,7 +286,7 @@
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<LanguageStandard>stdcpplatest</LanguageStandard>
- <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters
index 9a901d3203..dfe6d29ec7 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj.filters
+++ b/windows/winnet/src/winnet/winnet.vcxproj.filters
@@ -9,6 +9,18 @@
<ClCompile Include="interfaceutils.cpp" />
<ClCompile Include="networkadaptermonitor.cpp" />
<ClCompile Include="offlinemonitor.cpp" />
+ <ClCompile Include="routing\types.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\helpers.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\defaultroutemonitor.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\routemanager.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -19,6 +31,18 @@
<ClInclude Include="interfaceutils.h" />
<ClInclude Include="networkadaptermonitor.h" />
<ClInclude Include="offlinemonitor.h" />
+ <ClInclude Include="routing\types.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\helpers.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\defaultroutemonitor.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\routemanager.h">
+ <Filter>routing</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="winnet.def" />
@@ -26,4 +50,9 @@
<ItemGroup>
<ResourceCompile Include="winnet.rc" />
</ItemGroup>
+ <ItemGroup>
+ <Filter Include="routing">
+ <UniqueIdentifier>{8df22cc6-597f-4342-bc57-7647393084be}</UniqueIdentifier>
+ </Filter>
+ </ItemGroup>
</Project> \ No newline at end of file