summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-12-07 20:31:40 +0100
committerDavid Lönnhager <david.l@mullvad.net>2021-01-04 16:50:18 +0100
commitd9baa6bf9d98858d9f5bae95740b9d5ecb192c0f (patch)
tree1ce85578b7bb8ca1dddad78120231a1c28113146
parent07d363b919ee0c9e33f444475361194a29f37216 (diff)
downloadmullvadvpn-d9baa6bf9d98858d9f5bae95740b9d5ecb192c0f.tar.xz
mullvadvpn-d9baa6bf9d98858d9f5bae95740b9d5ecb192c0f.zip
Unblock API endpoint while connecting or blocked
-rw-r--r--mullvad-daemon/src/lib.rs26
-rw-r--r--mullvad-problem-report/src/lib.rs2
-rw-r--r--mullvad-rpc/src/address_cache.rs125
-rw-r--r--mullvad-rpc/src/lib.rs20
-rw-r--r--mullvad-setup/src/main.rs3
-rw-r--r--talpid-core/src/firewall/linux.rs29
-rw-r--r--talpid-core/src/firewall/macos.rs24
-rw-r--r--talpid-core/src/firewall/mod.rs16
-rw-r--r--talpid-core/src/firewall/windows.rs64
-rw-r--r--talpid-core/src/tunnel/tun_provider/android/mod.rs38
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs7
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs17
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs10
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs21
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs20
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs28
-rw-r--r--talpid-types/src/net/mod.rs4
-rw-r--r--windows/winfw/src/extras/cli/commands/winfw/policy.cpp4
-rw-r--r--windows/winfw/src/winfw/fwcontext.cpp64
-rw-r--r--windows/winfw/src/winfw/fwcontext.h19
-rw-r--r--windows/winfw/src/winfw/mullvadguids.cpp15
-rw-r--r--windows/winfw/src/winfw/mullvadguids.h2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp87
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitendpoint.h36
-rw-r--r--windows/winfw/src/winfw/winfw.cpp28
-rw-r--r--windows/winfw/src/winfw/winfw.h15
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj2
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj.filters6
28 files changed, 635 insertions, 97 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 9a55798f1e..00cbec6d36 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -62,7 +62,7 @@ use talpid_core::{
#[cfg(target_os = "android")]
use talpid_types::android::AndroidContext;
use talpid_types::{
- net::{openvpn, TransportProtocol, TunnelParameters, TunnelType},
+ net::{openvpn, Endpoint, TransportProtocol, TunnelParameters, TunnelType},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
};
@@ -261,7 +261,7 @@ pub(crate) enum InternalDaemonEvent {
/// The background job fetching new `AppVersionInfo`s got a new info object.
NewAppVersionInfo(AppVersionInfo),
/// A new API endpoint is being used
- NewApiAddress(SocketAddr),
+ NewApiAddress(SocketAddr, oneshot::Sender<()>),
}
impl From<TunnelStateTransition> for InternalDaemonEvent {
@@ -495,6 +495,7 @@ where
let (internal_event_tx, internal_event_rx) = command_channel.destructure();
let address_change_tx = std::sync::Mutex::new(internal_event_tx.clone());
+ let address_change_runtime = tokio::runtime::Handle::current();
let mut rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
tokio::runtime::Handle::current(),
@@ -502,13 +503,18 @@ where
&user_cache_dir,
true,
move |address| {
+ let (result_tx, result_rx) = oneshot::channel();
+
let tx = address_change_tx.lock().unwrap();
if tx
- .send(InternalDaemonEvent::NewApiAddress(address))
+ .send(InternalDaemonEvent::NewApiAddress(address, result_tx))
.is_err()
{
log::error!("Failed to send API address daemon event");
+ return Err(());
}
+
+ address_change_runtime.block_on(result_rx).map_err(|_| ())
},
)
.await
@@ -590,10 +596,16 @@ where
TargetState::Unsecured
};
+ let initial_api_endpoint = Endpoint::from_socket_address(
+ rpc_runtime.address_cache.peek_address(),
+ TransportProtocol::Tcp,
+ );
+
let tunnel_command_tx = tunnel_state_machine::spawn(
settings.allow_lan,
settings.block_when_disconnected,
Self::get_custom_resolvers(&settings.tunnel_options.dns_options),
+ initial_api_endpoint,
tunnel_parameters_generator,
log_dir,
resource_dir,
@@ -763,9 +775,11 @@ where
NewAppVersionInfo(app_version_info) => {
self.handle_new_app_version_info(app_version_info)
}
- NewApiAddress(address) => {
- // TODO
- log::info!("ADDRESS! {:?}", address);
+ NewApiAddress(address, tx) => {
+ self.send_tunnel_command(TunnelCommand::AllowEndpoint(
+ Endpoint::from_socket_address(address, TransportProtocol::Tcp),
+ tx,
+ ));
}
}
}
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 869661eb4e..f4ee9d73ae 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -281,7 +281,7 @@ pub fn send_problem_report(
None,
user_cache_dir,
false,
- |_| {},
+ |_| Ok(()),
))
.map_err(Error::CreateRpcClientError)?;
let rpc_client = mullvad_rpc::ProblemReportProxy::new(rpc_manager.mullvad_rest_handle());
diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs
index 5da1f09359..757d6645ed 100644
--- a/mullvad-rpc/src/address_cache.rs
+++ b/mullvad-rpc/src/address_cache.rs
@@ -3,6 +3,7 @@ use rand::seq::SliceRandom;
use std::{
io,
net::SocketAddr,
+ ops::{Deref, DerefMut},
path::Path,
sync::{Arc, Mutex},
};
@@ -26,9 +27,13 @@ pub enum Error {
#[error(display = "The address cache is empty")]
EmptyAddressCache,
+
+ #[error(display = "The address change listener returned an error")]
+ ChangeListenerError,
}
-pub type CurrentAddressChangeListener = dyn Fn(SocketAddr) + Send + Sync + 'static;
+pub type CurrentAddressChangeListener =
+ dyn Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static;
#[derive(Clone)]
pub struct AddressCache {
@@ -71,6 +76,10 @@ impl AddressCache {
)
}
+ pub fn set_change_listener(&mut self, change_listener: Arc<Box<CurrentAddressChangeListener>>) {
+ self.change_listener = change_listener;
+ }
+
/// Returns the currently selected address.
pub fn get_address(&self) -> SocketAddr {
let mut inner = self.inner.lock().unwrap();
@@ -101,18 +110,24 @@ impl AddressCache {
}
pub async fn select_new_address(&self) {
- let (new_address, new_choice, old_choice) = {
+ {
let mut inner = self.inner.lock().unwrap();
- let old_choice = inner.choice;
- inner.choice = inner.choice.wrapping_add(1);
- (Self::get_address_inner(&inner), inner.choice, old_choice)
- };
+ let mut transaction = AddressCacheTransaction::new(&mut inner);
- if new_choice == old_choice {
- return;
- }
+ transaction.choice = transaction.current.choice.wrapping_add(1);
+ if transaction.choice == transaction.current.choice {
+ return;
+ }
+ transaction.tried_current = false;
- (*self.change_listener)(new_address);
+ tokio::task::block_in_place(move || {
+ if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() {
+ log::error!("Failed to select a new API endpoint");
+ return;
+ }
+ transaction.commit();
+ });
+ }
if let Err(error) = self.save_to_disk().await {
log::error!("{}", error.display_chain());
@@ -124,12 +139,25 @@ impl AddressCache {
pub async fn randomize(&self) -> Result<(), Error> {
{
let mut inner = self.inner.lock().unwrap();
- inner.shuffle();
- inner.choice = 0;
- inner.tried_current = false;
- let new_address = Self::get_address_inner(&inner);
- (*self.change_listener)(new_address);
+ let mut transaction = AddressCacheTransaction::new(&mut inner);
+ transaction.shuffle();
+ transaction.choice = 0;
+
+ let current_address = Self::get_address_inner(&transaction.current);
+ let new_address = Self::get_address_inner(&transaction);
+
+ tokio::task::block_in_place(move || {
+ if new_address != current_address {
+ transaction.tried_current = false;
+ if (*self.change_listener)(new_address).is_err() {
+ return Err(Error::ChangeListenerError);
+ }
+ }
+
+ transaction.commit();
+ Ok(())
+ })?;
}
self.save_to_disk().await.map_err(Error::WriteAddressCache)
}
@@ -137,28 +165,42 @@ impl AddressCache {
pub async fn set_addresses(&self, mut addresses: Vec<SocketAddr>) -> io::Result<()> {
let should_update = {
let mut inner = self.inner.lock().unwrap();
+ let mut transaction = AddressCacheTransaction::new(&mut inner);
+
addresses.sort();
- let mut current_sorted = inner.addresses.clone();
+
+ let mut current_sorted = transaction.addresses.clone();
current_sorted.sort();
+
if addresses != current_sorted {
- let current_address = Self::get_address_inner(&inner);
+ let current_address = Self::get_address_inner(&transaction);
- inner.addresses = addresses.clone();
- inner.shuffle();
+ transaction.addresses = addresses.clone();
+ transaction.shuffle();
// Prefer a likely-working address
- let choice = inner
+ let choice = transaction
.addresses
.iter()
.position(|&addr| addr == current_address);
if let Some(choice) = choice {
- inner.choice = choice;
+ transaction.choice = choice;
+ transaction.commit();
} else {
- inner.choice = 0;
- inner.tried_current = false;
+ transaction.choice = 0;
+ transaction.tried_current = false;
- let new_address = Self::get_address_inner(&inner);
- (*self.change_listener)(new_address);
+ tokio::task::block_in_place(move || {
+ if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() {
+ log::error!("Failed to select a new API endpoint");
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "callback returned an error",
+ ));
+ }
+ transaction.commit();
+ Ok(())
+ })?;
}
true
@@ -217,6 +259,7 @@ impl crate::rest::AddressProvider for AddressCache {
}
+#[derive(Clone, PartialEq, Eq)]
struct AddressCacheInner {
addresses: Vec<SocketAddr>,
choice: usize,
@@ -247,6 +290,38 @@ impl AddressCacheInner {
}
}
+struct AddressCacheTransaction<'a> {
+ current: &'a mut AddressCacheInner,
+ working_cache: AddressCacheInner,
+}
+
+impl<'a> AddressCacheTransaction<'a> {
+ fn new(cache: &'a mut AddressCacheInner) -> Self {
+ Self {
+ working_cache: cache.clone(),
+ current: cache,
+ }
+ }
+
+ fn commit(self) {
+ *self.current = self.working_cache;
+ }
+}
+
+impl<'a> Deref for AddressCacheTransaction<'a> {
+ type Target = AddressCacheInner;
+
+ fn deref(&self) -> &Self::Target {
+ &self.working_cache
+ }
+}
+
+impl<'a> DerefMut for AddressCacheTransaction<'a> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.working_cache
+ }
+}
+
async fn read_address_file(path: &Path) -> Result<Vec<SocketAddr>, Error> {
let file = fs::File::open(path)
.await
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index ca24288696..4beefa7d3b 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -23,8 +23,7 @@ use crate::https_client_with_sni::HttpsConnectorWithSni;
mod address_cache;
mod relay_list;
-use address_cache::AddressCache;
-pub use address_cache::CurrentAddressChangeListener;
+pub use address_cache::{AddressCache, CurrentAddressChangeListener};
pub use hyper::StatusCode;
pub use relay_list::RelayListProxy;
@@ -44,7 +43,7 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443);
pub struct MullvadRpcRuntime {
https_connector: HttpsConnectorWithSni,
handle: tokio::runtime::Handle,
- address_cache: AddressCache,
+ pub address_cache: AddressCache,
}
#[derive(err_derive::Error, Debug)]
@@ -65,7 +64,7 @@ impl MullvadRpcRuntime {
address_cache: AddressCache::new(
vec![API_ADDRESS.into()],
None,
- Arc::new(Box::new(|_| {})),
+ Arc::new(Box::new(|_| Ok(()))),
)?,
})
}
@@ -78,7 +77,7 @@ impl MullvadRpcRuntime {
resource_dir: Option<&Path>,
cache_dir: &Path,
write_changes: bool,
- address_change_listener: impl Fn(SocketAddr) + Send + Sync + 'static,
+ address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static,
) -> Result<Self, Error> {
let cache_file = cache_dir.join(API_IP_CACHE_FILENAME);
let write_file = if write_changes {
@@ -113,13 +112,12 @@ impl MullvadRpcRuntime {
match resource_dir {
Some(resource_dir) => {
let read_file = resource_dir.join(API_IP_CACHE_FILENAME);
- let cache = AddressCache::from_file(
- &read_file,
- write_file,
- address_change_listener,
- )
- .await?;
+ let empty_listener =
+ Arc::<Box<CurrentAddressChangeListener>>::new(Box::new(|_| Ok(())));
+ let mut cache =
+ AddressCache::from_file(&read_file, write_file, empty_listener).await?;
cache.randomize().await?;
+ cache.set_change_listener(address_change_listener);
cache
}
None => return Err(Error::AddressCacheError(error)),
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index 7e1021f4ea..c5ed2d38f0 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -144,6 +144,7 @@ async fn reset_firewall() -> Result<(), Error> {
let mut firewall = Firewall::new(FirewallArguments {
initialize_blocked: false,
allow_lan: true,
+ allowed_endpoint: None,
})
.map_err(Error::FirewallError)?;
@@ -158,7 +159,7 @@ async fn clear_history() -> Result<(), Error> {
None,
&user_cache_path,
false,
- |_| {},
+ |_| Ok(()),
)
.await
.map_err(Error::RpcInitializationError)?;
diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs
index 3c252313ce..04bd00777d 100644
--- a/talpid-core/src/firewall/linux.rs
+++ b/talpid-core/src/firewall/linux.rs
@@ -531,10 +531,12 @@ impl<'a> PolicyBatch<'a> {
peer_endpoint,
pingable_hosts,
allow_lan,
+ allowed_endpoint,
use_fwmark,
} => {
self.add_allow_icmp_pingable_hosts(&pingable_hosts);
- self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark);
+ self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark);
+ self.add_allow_endpoint_rules(allowed_endpoint);
// Important to block DNS after allow relay rule (so the relay can operate
// over port 53) but before allow LAN (so DNS does not leak to the LAN)
@@ -548,7 +550,7 @@ impl<'a> PolicyBatch<'a> {
dns_servers,
use_fwmark,
} => {
- self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark);
+ self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark);
self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Udp)?;
self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Tcp)?;
// Important to block DNS *before* we allow the tunnel and allow LAN. So DNS
@@ -560,7 +562,12 @@ impl<'a> PolicyBatch<'a> {
}
*allow_lan
}
- FirewallPolicy::Blocked { allow_lan } => {
+ FirewallPolicy::Blocked {
+ allow_lan,
+ allowed_endpoint,
+ } => {
+ self.add_allow_endpoint_rules(allowed_endpoint);
+
// Important to drop DNS before allowing LAN (to stop DNS leaking to the LAN)
self.add_drop_dns_rule();
*allow_lan
@@ -582,7 +589,7 @@ impl<'a> PolicyBatch<'a> {
Ok(())
}
- fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) {
+ fn add_allow_tunnel_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) {
let mut in_rule = Rule::new(&self.in_chain);
check_endpoint(&mut in_rule, End::Src, endpoint);
@@ -608,6 +615,20 @@ impl<'a> PolicyBatch<'a> {
self.batch.add(&out_rule, nftnl::MsgType::Add);
}
+ fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) {
+ let mut in_rule = Rule::new(&self.in_chain);
+ check_endpoint(&mut in_rule, End::Src, endpoint);
+ add_verdict(&mut in_rule, &Verdict::Accept);
+
+ self.batch.add(&in_rule, nftnl::MsgType::Add);
+
+ let mut out_rule = Rule::new(&self.out_chain);
+ check_endpoint(&mut out_rule, End::Dst, endpoint);
+ add_verdict(&mut out_rule, &Verdict::Accept);
+
+ self.batch.add(&out_rule, nftnl::MsgType::Add);
+ }
+
fn add_allow_icmp_pingable_hosts(&mut self, pingable_hosts: &[IpAddr]) {
for host in pingable_hosts {
let icmp_proto = match &host {
diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs
index dfdc1e31fc..2e23c99dd2 100644
--- a/talpid-core/src/firewall/macos.rs
+++ b/talpid-core/src/firewall/macos.rs
@@ -98,9 +98,11 @@ impl Firewall {
FirewallPolicy::Connecting {
peer_endpoint,
allow_lan,
+ allowed_endpoint,
pingable_hosts,
} => {
let mut rules = vec![self.get_allow_relay_rule(peer_endpoint)?];
+ rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?);
rules.extend(self.get_allow_pingable_hosts(&pingable_hosts)?);
if allow_lan {
// Important to block DNS after allow relay rule (so the relay can operate
@@ -136,8 +138,12 @@ impl Firewall {
Ok(rules)
}
- FirewallPolicy::Blocked { allow_lan } => {
+ FirewallPolicy::Blocked {
+ allow_lan,
+ allowed_endpoint,
+ } => {
let mut rules = Vec::new();
+ rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?);
if allow_lan {
// Important to block DNS before allow LAN (so DNS does not leak to the LAN)
rules.append(&mut self.get_block_dns_rules()?);
@@ -247,6 +253,22 @@ impl Firewall {
.build()?)
}
+ fn get_allowed_endpoint_rule(
+ &self,
+ allowed_endpoint: net::Endpoint,
+ ) -> Result<pfctl::FilterRule> {
+ let pfctl_proto = as_pfctl_proto(allowed_endpoint.protocol);
+
+ Ok(self
+ .create_rule_builder(FilterRuleAction::Pass)
+ .direction(pfctl::Direction::Out)
+ .to(allowed_endpoint.address)
+ .proto(pfctl_proto)
+ .keep_state(pfctl::StatePolicy::Keep)
+ .quick(true)
+ .build()?)
+ }
+
fn get_block_dns_rules(&self) -> Result<Vec<pfctl::FilterRule>> {
let block_tcp_dns_rule = self
.create_rule_builder(FilterRuleAction::Drop(DropAction::Return))
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index b467f37d98..83a112ce88 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -107,6 +107,8 @@ pub enum FirewallPolicy {
pingable_hosts: Vec<IpAddr>,
/// Flag setting if communication with LAN networks should be possible.
allow_lan: bool,
+ /// Host that should be reachable by the tunnel client while connecting.
+ allowed_endpoint: Endpoint,
/// A process that is allowed to send packets to the relay.
#[cfg(windows)]
relay_client: PathBuf,
@@ -140,6 +142,8 @@ pub enum FirewallPolicy {
Blocked {
/// Flag setting if communication with LAN networks should be possible.
allow_lan: bool,
+ /// Host that should be reachable while in the blocked state.
+ allowed_endpoint: Endpoint,
},
}
@@ -182,10 +186,14 @@ impl fmt::Display for FirewallPolicy {
tunnel.ipv6_gateway,
if *allow_lan { "Allowing" } else { "Blocking" }
),
- FirewallPolicy::Blocked { allow_lan } => write!(
+ FirewallPolicy::Blocked {
+ allow_lan,
+ allowed_endpoint,
+ } => write!(
f,
- "Blocked, {} LAN",
- if *allow_lan { "Allowing" } else { "Blocking" }
+ "Blocked. {} LAN. Allowing endpoint {}",
+ if *allow_lan { "Allowing" } else { "Blocking" },
+ allowed_endpoint,
),
}
}
@@ -203,6 +211,8 @@ pub struct FirewallArguments {
pub initialize_blocked: bool,
/// This argument is required for the blocked state to configure the firewall correctly.
pub allow_lan: bool,
+ /// This argument is required for the blocked state to configure the firewall correctly.
+ pub allowed_endpoint: Option<Endpoint>,
}
impl Firewall {
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index d1fa08a3e6..8375eb55d6 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -57,10 +57,23 @@ impl FirewallT for Firewall {
if args.initialize_blocked {
let cfg = &WinFwSettings::new(args.allow_lan);
+
+ let winfw_allowed_endpoint = if let Some(allowed_endpoint) = args.allowed_endpoint {
+ let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip());
+ Some(WinFwEndpoint {
+ ip: allowed_endpoint_ip.as_ptr(),
+ port: allowed_endpoint.address.port(),
+ protocol: WinFwProt::from(allowed_endpoint.protocol),
+ })
+ } else {
+ None
+ };
+
unsafe {
WinFw_InitializeBlocked(
WINFW_TIMEOUT_SECONDS,
&cfg,
+ winfw_allowed_endpoint.as_ptr(),
Some(log_sink),
logging_context,
)
@@ -83,6 +96,7 @@ impl FirewallT for Firewall {
peer_endpoint,
pingable_hosts,
allow_lan,
+ allowed_endpoint,
relay_client,
} => {
let cfg = &WinFwSettings::new(allow_lan);
@@ -91,6 +105,7 @@ impl FirewallT for Firewall {
&peer_endpoint,
&cfg,
"Mullvad".to_string(),
+ &allowed_endpoint,
&pingable_hosts,
&relay_client,
)
@@ -105,9 +120,12 @@ impl FirewallT for Firewall {
let cfg = &WinFwSettings::new(allow_lan);
self.set_connected_state(&peer_endpoint, &cfg, &tunnel, &dns_servers, &relay_client)
}
- FirewallPolicy::Blocked { allow_lan } => {
+ FirewallPolicy::Blocked {
+ allow_lan,
+ allowed_endpoint,
+ } => {
let cfg = &WinFwSettings::new(allow_lan);
- self.set_blocked_state(&cfg)
+ self.set_blocked_state(&cfg, &allowed_endpoint)
}
}
}
@@ -138,12 +156,13 @@ impl Firewall {
endpoint: &Endpoint,
winfw_settings: &WinFwSettings,
_tunnel_iface_alias: String,
+ allowed_endpoint: &Endpoint,
pingable_hosts: &Vec<IpAddr>,
relay_client: &Path,
) -> Result<(), Error> {
trace!("Applying 'connecting' firewall policy");
let ip_str = Self::widestring_ip(endpoint.address.ip());
- let winfw_relay = WinFwRelay {
+ let winfw_relay = WinFwEndpoint {
ip: ip_str.as_ptr(),
port: endpoint.address.port(),
protocol: WinFwProt::from(endpoint.protocol),
@@ -171,12 +190,20 @@ impl Firewall {
None
};
+ let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip());
+ let winfw_allowed_endpoint = Some(WinFwEndpoint {
+ ip: allowed_endpoint_ip.as_ptr(),
+ port: allowed_endpoint.address.port(),
+ protocol: WinFwProt::from(allowed_endpoint.protocol),
+ });
+
unsafe {
WinFw_ApplyPolicyConnecting(
winfw_settings,
&winfw_relay,
relay_client.as_ptr(),
pingable_hosts.as_ptr(),
+ winfw_allowed_endpoint.as_ptr(),
)
.into_result()
.map_err(Error::ApplyingConnectingPolicy)
@@ -207,7 +234,7 @@ impl Firewall {
WideCString::new(tunnel_metadata.interface.encode_utf16().collect::<Vec<_>>()).unwrap();
// ip_str, gateway_str and tunnel_alias have to outlive winfw_relay
- let winfw_relay = WinFwRelay {
+ let winfw_relay = WinFwEndpoint {
ip: ip_str.as_ptr(),
port: endpoint.address.port(),
protocol: WinFwProt::from(endpoint.protocol),
@@ -258,10 +285,22 @@ impl Firewall {
}
}
- fn set_blocked_state(&mut self, winfw_settings: &WinFwSettings) -> Result<(), Error> {
+ fn set_blocked_state(
+ &mut self,
+ winfw_settings: &WinFwSettings,
+ allowed_endpoint: &Endpoint,
+ ) -> Result<(), Error> {
trace!("Applying 'blocked' firewall policy");
+
+ let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip());
+ let winfw_allowed_endpoint = Some(WinFwEndpoint {
+ ip: allowed_endpoint_ip.as_ptr(),
+ port: allowed_endpoint.address.port(),
+ protocol: WinFwProt::from(allowed_endpoint.protocol),
+ });
+
unsafe {
- WinFw_ApplyPolicyBlocked(winfw_settings)
+ WinFw_ApplyPolicyBlocked(winfw_settings, winfw_allowed_endpoint.as_ptr())
.into_result()
.map_err(Error::ApplyingBlockedPolicy)
}
@@ -289,7 +328,7 @@ mod winfw {
use talpid_types::net::TransportProtocol;
#[repr(C)]
- pub struct WinFwRelay {
+ pub struct WinFwEndpoint {
pub ip: *const libc::wchar_t,
pub port: u16,
pub protocol: WinFwProt,
@@ -385,6 +424,7 @@ mod winfw {
pub fn WinFw_InitializeBlocked(
timeout: libc::c_uint,
settings: &WinFwSettings,
+ allowed_endpoint: *const WinFwEndpoint,
sink: Option<LogSink>,
sink_context: *const u8,
) -> InitializationResult;
@@ -395,15 +435,16 @@ mod winfw {
#[link_name = "WinFw_ApplyPolicyConnecting"]
pub fn WinFw_ApplyPolicyConnecting(
settings: &WinFwSettings,
- relay: &WinFwRelay,
+ relay: &WinFwEndpoint,
relayClient: *const libc::wchar_t,
pingable_hosts: *const WinFwPingableHosts,
+ allowed_endpoint: *const WinFwEndpoint,
) -> WinFwPolicyStatus;
#[link_name = "WinFw_ApplyPolicyConnected"]
pub fn WinFw_ApplyPolicyConnected(
settings: &WinFwSettings,
- relay: &WinFwRelay,
+ relay: &WinFwEndpoint,
relayClient: *const libc::wchar_t,
tunnelIfaceAlias: *const libc::wchar_t,
v4Gateway: *const libc::wchar_t,
@@ -413,7 +454,10 @@ mod winfw {
) -> WinFwPolicyStatus;
#[link_name = "WinFw_ApplyPolicyBlocked"]
- pub fn WinFw_ApplyPolicyBlocked(settings: &WinFwSettings) -> WinFwPolicyStatus;
+ pub fn WinFw_ApplyPolicyBlocked(
+ settings: &WinFwSettings,
+ allowed_endpoint: *const WinFwEndpoint,
+ ) -> WinFwPolicyStatus;
#[link_name = "WinFw_Reset"]
pub fn WinFw_Reset() -> WinFwPolicyStatus;
diff --git a/talpid-core/src/tunnel/tun_provider/android/mod.rs b/talpid-core/src/tunnel/tun_provider/android/mod.rs
index b9385f13a7..fa48f115b9 100644
--- a/talpid-core/src/tunnel/tun_provider/android/mod.rs
+++ b/talpid-core/src/tunnel/tun_provider/android/mod.rs
@@ -66,6 +66,7 @@ pub struct AndroidTunProvider {
object: GlobalRef,
last_tun_config: TunConfig,
allow_lan: bool,
+ allowed_endpoint: IpAddr,
custom_dns_servers: Option<Vec<IpAddr>>,
}
@@ -74,6 +75,7 @@ impl AndroidTunProvider {
pub fn new(
context: AndroidContext,
allow_lan: bool,
+ allowed_endpoint: IpAddr,
custom_dns_servers: Option<Vec<IpAddr>>,
) -> Self {
let env = JnixEnv::from(
@@ -90,6 +92,7 @@ impl AndroidTunProvider {
object: context.vpn_service,
last_tun_config: TunConfig::default(),
allow_lan,
+ allowed_endpoint,
custom_dns_servers,
}
}
@@ -103,6 +106,10 @@ impl AndroidTunProvider {
Ok(())
}
+ pub fn set_allowed_endpoint(&mut self, endpoint: IpAddr) {
+ self.allowed_endpoint = endpoint;
+ }
+
pub fn set_custom_dns_servers(&mut self, servers: Option<Vec<IpAddr>>) -> Result<(), Error> {
if self.custom_dns_servers != servers {
self.custom_dns_servers = servers;
@@ -129,6 +136,19 @@ impl AndroidTunProvider {
})
}
+ /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and (potentially)
+ /// LAN routes via the tunnel device.
+ ///
+ /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be
+ /// closed.
+ pub fn create_blocking_tun(&mut self) -> Result<(), Error> {
+ let mut config = TunConfig::default();
+ self.prepare_tun_config(&mut config);
+ self.prepare_tun_config_for_allowed_endpoint(&mut config);
+ let _ = self.get_tun(config)?;
+ Ok(())
+ }
+
/// Open a tunnel device using the previous or the default configuration.
///
/// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be
@@ -231,6 +251,24 @@ impl AndroidTunProvider {
}
}
+ fn prepare_tun_config_for_allowed_endpoint(&self, config: &mut TunConfig) {
+ let endpoint_net = IpNetwork::from(self.allowed_endpoint);
+ let routes = config
+ .routes
+ .iter()
+ .flat_map(|&route| {
+ if route.is_ipv4() && endpoint_net.is_ipv4() {
+ route.sub(endpoint_net).collect()
+ } else if route.is_ipv6() && endpoint_net.is_ipv6() {
+ route.sub(endpoint_net).collect()
+ } else {
+ vec![route]
+ }
+ })
+ .collect();
+ config.routes = routes;
+ }
+
fn prepare_tun_config(&self, config: &mut TunConfig) {
self.prepare_tun_config_for_allow_lan(config);
self.prepare_tun_config_for_custom_dns(config);
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 7292da0c67..0c305de9a7 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -192,6 +192,13 @@ impl ConnectedState {
}
}
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ let _ = shared_values.set_allowed_endpoint(endpoint);
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ SameState(self.into())
+ }
Some(TunnelCommand::CustomDns(servers)) => {
match shared_values.set_custom_dns(servers) {
Ok(true) => {
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 44dcd9f153..0b03ceeca1 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -63,6 +63,7 @@ impl ConnectingState {
peer_endpoint,
pingable_hosts: gateway_list_from_params(params),
allow_lan: shared_values.allow_lan,
+ allowed_endpoint: shared_values.allowed_endpoint.clone(),
#[cfg(windows)]
relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, &params),
#[cfg(target_os = "linux")]
@@ -235,6 +236,22 @@ impl ConnectingState {
}
}
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ if shared_values.set_allowed_endpoint(endpoint) {
+ if let Err(error) =
+ Self::set_firewall_policy(shared_values, &self.tunnel_parameters)
+ {
+ return self.disconnect(
+ shared_values,
+ AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
+ );
+ }
+ }
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ SameState(self.into())
+ }
Some(TunnelCommand::CustomDns(servers)) => {
match shared_values.set_custom_dns(servers) {
#[cfg(target_os = "android")]
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index dcc4660e9f..922eb69c88 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -17,6 +17,7 @@ impl DisconnectedState {
let result = if shared_values.block_when_disconnected {
let policy = FirewallPolicy::Blocked {
allow_lan: shared_values.allow_lan,
+ allowed_endpoint: shared_values.allowed_endpoint.clone(),
};
shared_values.firewall.apply_policy(policy).map_err(|e| {
e.display_chain_with_msg(
@@ -77,6 +78,15 @@ impl TunnelState for DisconnectedState {
}
SameState(self.into())
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ if shared_values.set_allowed_endpoint(endpoint) {
+ Self::set_firewall_policy(shared_values, true);
+ }
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ SameState(self.into())
+ }
Some(TunnelCommand::CustomDns(servers)) => {
// Same situation as allow LAN above.
shared_values
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 0928834d1c..48a83a6dc3 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -32,6 +32,13 @@ impl DisconnectingState {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Nothing
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ let _ = shared_values.set_allowed_endpoint(endpoint);
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ AfterDisconnect::Nothing
+ }
Some(TunnelCommand::CustomDns(servers)) => {
let _ = shared_values.set_custom_dns(servers);
AfterDisconnect::Nothing
@@ -53,6 +60,13 @@ impl DisconnectingState {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Block(reason)
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ let _ = shared_values.set_allowed_endpoint(endpoint);
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ AfterDisconnect::Block(reason)
+ }
Some(TunnelCommand::CustomDns(servers)) => {
let _ = shared_values.set_custom_dns(servers);
AfterDisconnect::Block(reason)
@@ -79,6 +93,13 @@ impl DisconnectingState {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Reconnect(retry_attempt)
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ let _ = shared_values.set_allowed_endpoint(endpoint);
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ AfterDisconnect::Reconnect(retry_attempt)
+ }
Some(TunnelCommand::CustomDns(servers)) => {
let _ = shared_values.set_custom_dns(servers);
AfterDisconnect::Reconnect(retry_attempt)
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index a87dccd5b4..51159d274f 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -21,6 +21,7 @@ impl ErrorState {
) -> Result<(), FirewallPolicyError> {
let policy = FirewallPolicy::Blocked {
allow_lan: shared_values.allow_lan,
+ allowed_endpoint: shared_values.allowed_endpoint.clone(),
};
#[cfg(target_os = "linux")]
@@ -47,7 +48,7 @@ impl ErrorState {
/// Returns true if a new tunnel device was successfully created.
#[cfg(target_os = "android")]
fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool {
- match shared_values.tun_provider.create_tun_if_closed() {
+ match shared_values.tun_provider.create_blocking_tun() {
Ok(()) => true,
Err(error) => {
log::error!(
@@ -105,6 +106,23 @@ impl TunnelState for ErrorState {
SameState(self.into())
}
}
+ Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
+ if shared_values.set_allowed_endpoint(endpoint) {
+ let _ = Self::set_firewall_policy(shared_values);
+
+ #[cfg(target_os = "android")]
+ if !Self::create_blocking_tun(shared_values) {
+ return NewState(Self::enter(
+ shared_values,
+ ErrorStateCause::SetFirewallPolicyError(FirewallPolicyError::Generic),
+ ));
+ }
+ }
+ if let Err(_) = tx.send(()) {
+ log::error!("The AllowEndpoint receiver was dropped");
+ }
+ SameState(self.into())
+ }
Some(TunnelCommand::CustomDns(servers)) => {
if let Err(error_state_cause) = shared_values.set_custom_dns(servers) {
NewState(Self::enter(shared_values, error_state_cause))
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index fbec1bf2b1..b657ec5e36 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -33,7 +33,7 @@ use std::{
#[cfg(target_os = "android")]
use talpid_types::{android::AndroidContext, ErrorExt};
use talpid_types::{
- net::TunnelParameters,
+ net::{Endpoint, TunnelParameters},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
};
@@ -75,6 +75,7 @@ pub async fn spawn(
allow_lan: bool,
block_when_disconnected: bool,
custom_dns: Option<Vec<IpAddr>>,
+ allowed_endpoint: Endpoint,
tunnel_parameters_generator: impl TunnelParametersGenerator,
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
@@ -101,6 +102,8 @@ pub async fn spawn(
#[cfg(target_os = "android")]
allow_lan,
#[cfg(target_os = "android")]
+ allowed_endpoint.address.ip(),
+ #[cfg(target_os = "android")]
custom_dns.clone(),
);
@@ -114,6 +117,7 @@ pub async fn spawn(
block_when_disconnected,
is_offline,
custom_dns,
+ allowed_endpoint,
tunnel_parameters_generator,
tun_provider,
log_dir,
@@ -152,6 +156,9 @@ pub async fn spawn(
pub enum TunnelCommand {
/// Enable or disable LAN access in the firewall.
AllowLan(bool),
+ /// Endpoint that should never be blocked.
+ /// If an error occurs, the sender is dropped.
+ AllowEndpoint(Endpoint, oneshot::Sender<()>),
/// Set custom DNS servers to use.
CustomDns(Option<Vec<IpAddr>>),
/// Enable or disable the block_when_disconnected feature.
@@ -193,6 +200,7 @@ impl TunnelStateMachine {
block_when_disconnected: bool,
is_offline: bool,
custom_dns: Option<Vec<IpAddr>>,
+ allowed_endpoint: Endpoint,
tunnel_parameters_generator: impl TunnelParametersGenerator,
tun_provider: TunProvider,
log_dir: Option<PathBuf>,
@@ -204,6 +212,7 @@ impl TunnelStateMachine {
let args = FirewallArguments {
initialize_blocked: block_when_disconnected || !reset_firewall,
allow_lan,
+ allowed_endpoint: Some(allowed_endpoint),
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
@@ -218,6 +227,7 @@ impl TunnelStateMachine {
block_when_disconnected,
is_offline,
custom_dns,
+ allowed_endpoint,
tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
tun_provider,
log_dir,
@@ -291,6 +301,8 @@ struct SharedTunnelStateValues {
is_offline: bool,
/// Custom DNS servers to use.
custom_dns: Option<Vec<IpAddr>>,
+ /// Endpoint that should not be blocked by the firewall.
+ allowed_endpoint: Endpoint,
/// The generator of new `TunnelParameter`s
tunnel_parameters_generator: Box<dyn TunnelParametersGenerator>,
/// The provider of tunnel devices.
@@ -328,6 +340,20 @@ impl SharedTunnelStateValues {
Ok(())
}
+ pub fn set_allowed_endpoint(&mut self, endpoint: Endpoint) -> bool {
+ if self.allowed_endpoint != endpoint {
+ self.allowed_endpoint = endpoint;
+
+ #[cfg(target_os = "android")]
+ self.tun_provider
+ .set_allowed_endpoint(endpoint.address.ip());
+
+ true
+ } else {
+ false
+ }
+ }
+
pub fn set_custom_dns(
&mut self,
custom_dns: Option<Vec<IpAddr>>,
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 15bf33a5bb..29a871ee07 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -144,6 +144,10 @@ impl Endpoint {
protocol,
}
}
+
+ pub fn from_socket_address(address: SocketAddr, protocol: TransportProtocol) -> Self {
+ Endpoint { address, protocol }
+ }
}
impl fmt::Display for Endpoint {
diff --git a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp
index a722501832..7f5b9d980c 100644
--- a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp
+++ b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp
@@ -26,9 +26,9 @@ WinFwProtocol TranslateProtocol(const std::wstring &protocol)
return (0 == _wcsicmp(protocol.c_str(), L"tcp") ? WinFwProtocol::Tcp : WinFwProtocol::Udp);
}
-WinFwRelay CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol)
+WinFwEndpoint CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol)
{
- WinFwRelay r;
+ WinFwEndpoint r;
r.ip = ip;
r.port = common::string::LexicalCast<uint16_t>(port);
diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp
index b8959cfc01..99a3d3dc26 100644
--- a/windows/winfw/src/winfw/fwcontext.cpp
+++ b/windows/winfw/src/winfw/fwcontext.cpp
@@ -15,6 +15,7 @@
#include "rules/baseline/permitvpntunnelservice.h"
#include "rules/baseline/permitping.h"
#include "rules/baseline/permitdns.h"
+#include "rules/baseline/permitendpoint.h"
#include "rules/dns/blockall.h"
#include "rules/dns/permittunnel.h"
#include "rules/dns/permitnontunnel.h"
@@ -43,6 +44,19 @@ multi::PermitVpnRelay::Protocol TranslateProtocol(WinFwProtocol protocol)
};
}
+baseline::PermitEndpoint::Protocol TranslateEndpointProtocol(WinFwProtocol protocol)
+{
+ switch (protocol)
+ {
+ case Tcp: return baseline::PermitEndpoint::Protocol::Tcp;
+ case Udp: return baseline::PermitEndpoint::Protocol::Udp;
+ default:
+ {
+ THROW_ERROR("Missing case handler in switch clause");
+ }
+ };
+}
+
//
// Since the PermitLan rule doesn't specifically address DNS, it will allow DNS requests targetting
// a local resolver to leave the machine. From the local resolver the request will either be
@@ -91,7 +105,7 @@ void AppendSettingsRules
void AppendRelayRules
(
FwContext::Ruleset &ruleset,
- const WinFwRelay &relay,
+ const WinFwEndpoint &relay,
const std::wstring &relayClient
)
{
@@ -111,6 +125,22 @@ void AppendRelayRules
));
}
+//
+// Refer comment on `AppendSettingsRules`.
+//
+void AppendAllowedEndpointRules
+(
+ FwContext::Ruleset &ruleset,
+ const WinFwEndpoint &endpoint
+)
+{
+ ruleset.emplace_back(std::make_unique<baseline::PermitEndpoint>(
+ wfp::IpAddress(endpoint.ip),
+ endpoint.port,
+ TranslateEndpointProtocol(endpoint.protocol)
+ ));
+}
+
void AppendNetBlockedRules(FwContext::Ruleset &ruleset)
{
ruleset.emplace_back(std::make_unique<baseline::BlockAll>());
@@ -145,7 +175,8 @@ FwContext::FwContext
FwContext::FwContext
(
uint32_t timeout,
- const WinFwSettings &settings
+ const WinFwSettings &settings,
+ const std::optional<WinFwEndpoint> &allowedEndpoint
)
: m_baseline(0)
, m_activePolicy(Policy::None)
@@ -159,7 +190,7 @@ FwContext::FwContext
uint32_t checkpoint = 0;
- if (false == applyBlockedBaseConfiguration(settings, checkpoint))
+ if (false == applyBlockedBaseConfiguration(settings, allowedEndpoint, checkpoint))
{
THROW_ERROR("Failed to apply base configuration in BFE");
}
@@ -171,9 +202,10 @@ FwContext::FwContext
bool FwContext::applyPolicyConnecting
(
const WinFwSettings &settings,
- const WinFwRelay &relay,
+ const WinFwEndpoint &relay,
const std::wstring &relayClient,
- const std::optional<PingableHosts> &pingableHosts
+ const std::optional<PingableHosts> &pingableHosts,
+ const std::optional<WinFwEndpoint> &allowedEndpoint
)
{
Ruleset ruleset;
@@ -182,6 +214,11 @@ bool FwContext::applyPolicyConnecting
AppendSettingsRules(ruleset, settings);
AppendRelayRules(ruleset, relay, relayClient);
+ if (allowedEndpoint.has_value())
+ {
+ AppendAllowedEndpointRules(ruleset, allowedEndpoint.value());
+ }
+
//
// Permit pinging the gateway inside the tunnel.
//
@@ -208,7 +245,7 @@ bool FwContext::applyPolicyConnecting
bool FwContext::applyPolicyConnected
(
const WinFwSettings &settings,
- const WinFwRelay &relay,
+ const WinFwEndpoint &relay,
const std::wstring &relayClient,
const std::wstring &tunnelInterfaceAlias,
const std::vector<wfp::IpAddress> &tunnelDnsServers,
@@ -252,9 +289,9 @@ bool FwContext::applyPolicyConnected
return status;
}
-bool FwContext::applyPolicyBlocked(const WinFwSettings &settings)
+bool FwContext::applyPolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint)
{
- const auto status = applyRuleset(composePolicyBlocked(settings));
+ const auto status = applyRuleset(composePolicyBlocked(settings, allowedEndpoint));
if (status)
{
@@ -284,13 +321,18 @@ FwContext::Policy FwContext::activePolicy() const
return m_activePolicy;
}
-FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings)
+FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint)
{
Ruleset ruleset;
AppendNetBlockedRules(ruleset);
AppendSettingsRules(ruleset, settings);
+ if (allowedEndpoint.has_value())
+ {
+ AppendAllowedEndpointRules(ruleset, allowedEndpoint.value());
+ }
+
return ruleset;
}
@@ -302,7 +344,7 @@ bool FwContext::applyBaseConfiguration()
});
}
-bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint)
+bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint)
{
return m_sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine)
{
@@ -318,7 +360,7 @@ bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uin
//
checkpoint = controller.peekCheckpoint();
- return applyRulesetDirectly(composePolicyBlocked(settings), controller);
+ return applyRulesetDirectly(composePolicyBlocked(settings, allowedEndpoint), controller);
});
}
diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h
index 100672073a..bbbb1de485 100644
--- a/windows/winfw/src/winfw/fwcontext.h
+++ b/windows/winfw/src/winfw/fwcontext.h
@@ -20,7 +20,8 @@ public:
FwContext
(
uint32_t timeout,
- const WinFwSettings &settings
+ const WinFwSettings &settings,
+ const std::optional<WinFwEndpoint> &allowedEndpoint
);
struct PingableHosts
@@ -32,22 +33,26 @@ public:
bool applyPolicyConnecting
(
const WinFwSettings &settings,
- const WinFwRelay &relay,
+ const WinFwEndpoint &relay,
const std::wstring &relayClient,
- const std::optional<PingableHosts> &pingableHosts
+ const std::optional<PingableHosts> &pingableHosts,
+ const std::optional<WinFwEndpoint> &allowedEndpoint
);
bool applyPolicyConnected
(
const WinFwSettings &settings,
- const WinFwRelay &relay,
+ const WinFwEndpoint &relay,
const std::wstring &relayClient,
const std::wstring &tunnelInterfaceAlias,
const std::vector<wfp::IpAddress> &tunnelDnsServers,
const std::vector<wfp::IpAddress> &nonTunnelDnsServers
);
- bool applyPolicyBlocked(const WinFwSettings &settings);
+ bool applyPolicyBlocked(
+ const WinFwSettings &settings,
+ const std::optional<WinFwEndpoint> &allowedEndpoint
+ );
bool reset();
@@ -68,10 +73,10 @@ private:
FwContext(const FwContext &) = delete;
FwContext &operator=(const FwContext &) = delete;
- Ruleset composePolicyBlocked(const WinFwSettings &settings);
+ Ruleset composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint);
bool applyBaseConfiguration();
- bool applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint);
+ bool applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint);
bool applyCommonBaseConfiguration(SessionController &controller, wfp::FilterEngine &engine);
bool applyRuleset(const Ruleset &ruleset);
diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp
index 0a22be1740..417b157f82 100644
--- a/windows/winfw/src/winfw/mullvadguids.cpp
+++ b/windows/winfw/src/winfw/mullvadguids.cpp
@@ -129,6 +129,7 @@ MullvadGuids::DetailedIdentityRegistry MullvadGuids::DetailedRegistry(IdentityQu
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Inbound_Request_Ipv4()));
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Outbound_Response_Ipv4()));
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnRelay()));
+ registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitEndpoint()));
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4()));
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6()));
registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnelService_Ipv4()));
@@ -644,6 +645,20 @@ const GUID &MullvadGuids::Filter_Baseline_PermitVpnRelay()
}
//static
+const GUID &MullvadGuids::Filter_Baseline_PermitEndpoint()
+{
+ static const GUID g =
+ {
+ 0x99dc8dac,
+ 0x8520,
+ 0x41be,
+ { 0xbf, 0xab, 0x0c, 0x9, 0xbf, 0x12, 0xeb, 0 }
+ };
+
+ return g;
+}
+
+//static
const GUID &MullvadGuids::Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4()
{
static const GUID g =
diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h
index 11e396fc2b..7f00863811 100644
--- a/windows/winfw/src/winfw/mullvadguids.h
+++ b/windows/winfw/src/winfw/mullvadguids.h
@@ -69,6 +69,8 @@ public:
static const GUID &Filter_Baseline_PermitVpnRelay();
+ static const GUID &Filter_Baseline_PermitEndpoint();
+
static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4();
static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6();
diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp
new file mode 100644
index 0000000000..217631579d
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp
@@ -0,0 +1,87 @@
+#include "stdafx.h"
+#include "permitendpoint.h"
+#include <winfw/mullvadguids.h>
+#include <libwfp/filterbuilder.h>
+#include <libwfp/conditionbuilder.h>
+#include <libwfp/conditions/conditionprotocol.h>
+#include <libwfp/conditions/conditionip.h>
+#include <libwfp/conditions/conditionport.h>
+#include <libwfp/conditions/conditionapplication.h>
+#include <libcommon/error.h>
+
+using namespace wfp::conditions;
+
+namespace rules::baseline
+{
+
+namespace
+{
+
+const GUID &OutboundLayerFromIp(const wfp::IpAddress &ip)
+{
+ switch (ip.type())
+ {
+ case wfp::IpAddress::Type::Ipv4: return FWPM_LAYER_ALE_AUTH_CONNECT_V4;
+ case wfp::IpAddress::Type::Ipv6: return FWPM_LAYER_ALE_AUTH_CONNECT_V6;
+ default:
+ {
+ THROW_ERROR("Missing case handler in switch clause");
+ }
+ };
+}
+
+std::unique_ptr<ConditionProtocol> CreateProtocolCondition(PermitEndpoint::Protocol protocol)
+{
+ switch (protocol)
+ {
+ case PermitEndpoint::Protocol::Tcp: return ConditionProtocol::Tcp();
+ case PermitEndpoint::Protocol::Udp: return ConditionProtocol::Udp();
+ default:
+ {
+ THROW_ERROR("Missing case handler in switch clause");
+ }
+ };
+}
+
+} // anonymous namespace
+
+PermitEndpoint::PermitEndpoint
+(
+ const wfp::IpAddress &address,
+ uint16_t port,
+ Protocol protocol
+)
+ : m_address(address)
+ , m_port(port)
+ , m_protocol(protocol)
+{
+}
+
+bool PermitEndpoint::apply(IObjectInstaller &objectInstaller)
+{
+ wfp::FilterBuilder filterBuilder;
+
+ //
+ // Permit outbound connections to endpoint.
+ //
+
+ filterBuilder
+ .key(MullvadGuids::Filter_Baseline_PermitEndpoint())
+ .name(L"Permit outbound connections to a given endpoint")
+ .description(L"This filter is part of a rule that permits traffic to a specific endpoint")
+ .provider(MullvadGuids::Provider())
+ .layer(OutboundLayerFromIp(m_address))
+ .sublayer(MullvadGuids::SublayerBaseline())
+ .weight(wfp::FilterBuilder::WeightClass::Max)
+ .permit();
+
+ wfp::ConditionBuilder conditionBuilder(OutboundLayerFromIp(m_address));
+
+ conditionBuilder.add_condition(ConditionIp::Remote(m_address));
+ conditionBuilder.add_condition(ConditionPort::Remote(m_port));
+ conditionBuilder.add_condition(CreateProtocolCondition(m_protocol));
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+}
+
+}
diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.h b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h
new file mode 100644
index 0000000000..cfa57b003d
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <winfw/rules/ifirewallrule.h>
+#include <libwfp/ipaddress.h>
+#include <string>
+
+namespace rules::baseline
+{
+
+class PermitEndpoint : public IFirewallRule
+{
+public:
+
+ enum class Protocol
+ {
+ Tcp,
+ Udp
+ };
+
+ PermitEndpoint
+ (
+ const wfp::IpAddress &address,
+ uint16_t port,
+ Protocol protocol
+ );
+
+ bool apply(IObjectInstaller &objectInstaller) override;
+
+private:
+
+ const wfp::IpAddress m_address;
+ const uint16_t m_port;
+ const Protocol m_protocol;
+};
+
+}
diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp
index f2d5a66b2a..ee7842877e 100644
--- a/windows/winfw/src/winfw/winfw.cpp
+++ b/windows/winfw/src/winfw/winfw.cpp
@@ -65,6 +65,16 @@ HandlePolicyException(const common::error::WindowsException &err)
return WINFW_POLICY_STATUS_GENERAL_FAILURE;
}
+template<typename T>
+std::optional<T> MakeOptional(T* object)
+{
+ if (nullptr == object)
+ {
+ return std::nullopt;
+ }
+ return std::make_optional(*object);
+}
+
//
// Networks for which DNS requests can be made on all network adapters.
//
@@ -136,6 +146,7 @@ WINFW_API
WinFw_InitializeBlocked(
uint32_t timeout,
const WinFwSettings *settings,
+ const WinFwEndpoint *allowedEndpoint,
MullvadLogSink logSink,
void *logSinkContext
)
@@ -162,7 +173,7 @@ WinFw_InitializeBlocked(
g_logSink = logSink;
g_logSinkContext = logSinkContext;
- g_fwContext = new FwContext(timeout_ms, *settings);
+ g_fwContext = new FwContext(timeout_ms, *settings, MakeOptional(allowedEndpoint));
}
catch (std::exception &err)
{
@@ -247,9 +258,10 @@ WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings *settings,
- const WinFwRelay *relay,
+ const WinFwEndpoint *relay,
const wchar_t *relayClient,
- const PingableHosts *pingableHosts
+ const PingableHosts *pingableHosts,
+ const WinFwEndpoint *allowedEndpoint
)
{
if (nullptr == g_fwContext)
@@ -278,7 +290,8 @@ WinFw_ApplyPolicyConnecting(
*settings,
*relay,
relayClient,
- ConvertPingableHosts(pingableHosts)
+ ConvertPingableHosts(pingableHosts),
+ MakeOptional(allowedEndpoint)
) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE;
}
catch (common::error::WindowsException &err)
@@ -305,7 +318,7 @@ WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyConnected(
const WinFwSettings *settings,
- const WinFwRelay *relay,
+ const WinFwEndpoint *relay,
const wchar_t *relayClient,
const wchar_t *tunnelInterfaceAlias,
const wchar_t *v4Gateway,
@@ -447,7 +460,8 @@ WINFW_LINKAGE
WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyBlocked(
- const WinFwSettings *settings
+ const WinFwSettings *settings,
+ const WinFwEndpoint *allowedEndpoint
)
{
if (nullptr == g_fwContext)
@@ -462,7 +476,7 @@ WinFw_ApplyPolicyBlocked(
THROW_ERROR("Invalid argument: settings");
}
- return g_fwContext->applyPolicyBlocked(*settings)
+ return g_fwContext->applyPolicyBlocked(*settings, MakeOptional(allowedEndpoint))
? WINFW_POLICY_STATUS_SUCCESS
: WINFW_POLICY_STATUS_GENERAL_FAILURE;
}
diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h
index f0a487cb12..23163786e9 100644
--- a/windows/winfw/src/winfw/winfw.h
+++ b/windows/winfw/src/winfw/winfw.h
@@ -37,13 +37,13 @@ enum WinFwProtocol : uint8_t
Udp = 1
};
-typedef struct tag_WinFwRelay
+typedef struct tag_WinFwEndpoint
{
const wchar_t *ip;
uint16_t port;
WinFwProtocol protocol;
}
-WinFwRelay;
+WinFwEndpoint;
#pragma pack(pop)
@@ -88,6 +88,7 @@ WINFW_API
WinFw_InitializeBlocked(
uint32_t timeout,
const WinFwSettings *settings,
+ const WinFwEndpoint *allowedEndpoint,
MullvadLogSink logSink,
void *logSinkContext
);
@@ -155,9 +156,10 @@ WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings *settings,
- const WinFwRelay *relay,
+ const WinFwEndpoint *relay,
const wchar_t *relayClient,
- const PingableHosts *pingableHosts
+ const PingableHosts *pingableHosts,
+ const WinFwEndpoint *allowedEndpoint
);
//
@@ -183,7 +185,7 @@ WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyConnected(
const WinFwSettings *settings,
- const WinFwRelay *relay,
+ const WinFwEndpoint *relay,
const wchar_t *relayClient,
const wchar_t *tunnelInterfaceAlias,
const wchar_t *v4Gateway,
@@ -203,7 +205,8 @@ WINFW_LINKAGE
WINFW_POLICY_STATUS
WINFW_API
WinFw_ApplyPolicyBlocked(
- const WinFwSettings *settings
+ const WinFwSettings *settings,
+ const WinFwEndpoint *allowedEndpoint
);
//
diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj
index 8f9c37f919..3f9502a10d 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj
+++ b/windows/winfw/src/winfw/winfw.vcxproj
@@ -27,6 +27,7 @@
<ClCompile Include="rules\baseline\permitdhcp.cpp" />
<ClCompile Include="rules\baseline\permitdhcpserver.cpp" />
<ClCompile Include="rules\baseline\permitdns.cpp" />
+ <ClCompile Include="rules\baseline\permitendpoint.cpp" />
<ClCompile Include="rules\baseline\permitlan.cpp" />
<ClCompile Include="rules\baseline\permitlanservice.cpp" />
<ClCompile Include="rules\baseline\permitloopback.cpp" />
@@ -61,6 +62,7 @@
<ClInclude Include="rules\baseline\permitdhcp.h" />
<ClInclude Include="rules\baseline\permitdhcpserver.h" />
<ClInclude Include="rules\baseline\permitdns.h" />
+ <ClInclude Include="rules\baseline\permitendpoint.h" />
<ClInclude Include="rules\baseline\permitlan.h" />
<ClInclude Include="rules\baseline\permitlanservice.h" />
<ClInclude Include="rules\baseline\permitloopback.h" />
diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters
index 312045876e..7a2aa85487 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj.filters
+++ b/windows/winfw/src/winfw/winfw.vcxproj.filters
@@ -61,6 +61,9 @@
<ClCompile Include="rules\persistent\blockall.cpp">
<Filter>rules\persistent</Filter>
</ClCompile>
+ <ClCompile Include="rules\baseline\permitendpoint.cpp">
+ <Filter>rules\baseline</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -132,6 +135,9 @@
<ClInclude Include="rules\persistent\blockall.h">
<Filter>rules\persistent</Filter>
</ClInclude>
+ <ClInclude Include="rules\baseline\permitendpoint.h">
+ <Filter>rules\baseline</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="rules">