diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-12-07 20:31:40 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-01-04 16:50:18 +0100 |
| commit | d9baa6bf9d98858d9f5bae95740b9d5ecb192c0f (patch) | |
| tree | 1ce85578b7bb8ca1dddad78120231a1c28113146 | |
| parent | 07d363b919ee0c9e33f444475361194a29f37216 (diff) | |
| download | mullvadvpn-d9baa6bf9d98858d9f5bae95740b9d5ecb192c0f.tar.xz mullvadvpn-d9baa6bf9d98858d9f5bae95740b9d5ecb192c0f.zip | |
Unblock API endpoint while connecting or blocked
28 files changed, 635 insertions, 97 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 9a55798f1e..00cbec6d36 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -62,7 +62,7 @@ use talpid_core::{ #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; use talpid_types::{ - net::{openvpn, TransportProtocol, TunnelParameters, TunnelType}, + net::{openvpn, Endpoint, TransportProtocol, TunnelParameters, TunnelType}, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, }; @@ -261,7 +261,7 @@ pub(crate) enum InternalDaemonEvent { /// The background job fetching new `AppVersionInfo`s got a new info object. NewAppVersionInfo(AppVersionInfo), /// A new API endpoint is being used - NewApiAddress(SocketAddr), + NewApiAddress(SocketAddr, oneshot::Sender<()>), } impl From<TunnelStateTransition> for InternalDaemonEvent { @@ -495,6 +495,7 @@ where let (internal_event_tx, internal_event_rx) = command_channel.destructure(); let address_change_tx = std::sync::Mutex::new(internal_event_tx.clone()); + let address_change_runtime = tokio::runtime::Handle::current(); let mut rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( tokio::runtime::Handle::current(), @@ -502,13 +503,18 @@ where &user_cache_dir, true, move |address| { + let (result_tx, result_rx) = oneshot::channel(); + let tx = address_change_tx.lock().unwrap(); if tx - .send(InternalDaemonEvent::NewApiAddress(address)) + .send(InternalDaemonEvent::NewApiAddress(address, result_tx)) .is_err() { log::error!("Failed to send API address daemon event"); + return Err(()); } + + address_change_runtime.block_on(result_rx).map_err(|_| ()) }, ) .await @@ -590,10 +596,16 @@ where TargetState::Unsecured }; + let initial_api_endpoint = Endpoint::from_socket_address( + rpc_runtime.address_cache.peek_address(), + TransportProtocol::Tcp, + ); + let tunnel_command_tx = tunnel_state_machine::spawn( settings.allow_lan, settings.block_when_disconnected, Self::get_custom_resolvers(&settings.tunnel_options.dns_options), + initial_api_endpoint, tunnel_parameters_generator, log_dir, resource_dir, @@ -763,9 +775,11 @@ where NewAppVersionInfo(app_version_info) => { self.handle_new_app_version_info(app_version_info) } - NewApiAddress(address) => { - // TODO - log::info!("ADDRESS! {:?}", address); + NewApiAddress(address, tx) => { + self.send_tunnel_command(TunnelCommand::AllowEndpoint( + Endpoint::from_socket_address(address, TransportProtocol::Tcp), + tx, + )); } } } diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 869661eb4e..f4ee9d73ae 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -281,7 +281,7 @@ pub fn send_problem_report( None, user_cache_dir, false, - |_| {}, + |_| Ok(()), )) .map_err(Error::CreateRpcClientError)?; let rpc_client = mullvad_rpc::ProblemReportProxy::new(rpc_manager.mullvad_rest_handle()); diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs index 5da1f09359..757d6645ed 100644 --- a/mullvad-rpc/src/address_cache.rs +++ b/mullvad-rpc/src/address_cache.rs @@ -3,6 +3,7 @@ use rand::seq::SliceRandom; use std::{ io, net::SocketAddr, + ops::{Deref, DerefMut}, path::Path, sync::{Arc, Mutex}, }; @@ -26,9 +27,13 @@ pub enum Error { #[error(display = "The address cache is empty")] EmptyAddressCache, + + #[error(display = "The address change listener returned an error")] + ChangeListenerError, } -pub type CurrentAddressChangeListener = dyn Fn(SocketAddr) + Send + Sync + 'static; +pub type CurrentAddressChangeListener = + dyn Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static; #[derive(Clone)] pub struct AddressCache { @@ -71,6 +76,10 @@ impl AddressCache { ) } + pub fn set_change_listener(&mut self, change_listener: Arc<Box<CurrentAddressChangeListener>>) { + self.change_listener = change_listener; + } + /// Returns the currently selected address. pub fn get_address(&self) -> SocketAddr { let mut inner = self.inner.lock().unwrap(); @@ -101,18 +110,24 @@ impl AddressCache { } pub async fn select_new_address(&self) { - let (new_address, new_choice, old_choice) = { + { let mut inner = self.inner.lock().unwrap(); - let old_choice = inner.choice; - inner.choice = inner.choice.wrapping_add(1); - (Self::get_address_inner(&inner), inner.choice, old_choice) - }; + let mut transaction = AddressCacheTransaction::new(&mut inner); - if new_choice == old_choice { - return; - } + transaction.choice = transaction.current.choice.wrapping_add(1); + if transaction.choice == transaction.current.choice { + return; + } + transaction.tried_current = false; - (*self.change_listener)(new_address); + tokio::task::block_in_place(move || { + if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() { + log::error!("Failed to select a new API endpoint"); + return; + } + transaction.commit(); + }); + } if let Err(error) = self.save_to_disk().await { log::error!("{}", error.display_chain()); @@ -124,12 +139,25 @@ impl AddressCache { pub async fn randomize(&self) -> Result<(), Error> { { let mut inner = self.inner.lock().unwrap(); - inner.shuffle(); - inner.choice = 0; - inner.tried_current = false; - let new_address = Self::get_address_inner(&inner); - (*self.change_listener)(new_address); + let mut transaction = AddressCacheTransaction::new(&mut inner); + transaction.shuffle(); + transaction.choice = 0; + + let current_address = Self::get_address_inner(&transaction.current); + let new_address = Self::get_address_inner(&transaction); + + tokio::task::block_in_place(move || { + if new_address != current_address { + transaction.tried_current = false; + if (*self.change_listener)(new_address).is_err() { + return Err(Error::ChangeListenerError); + } + } + + transaction.commit(); + Ok(()) + })?; } self.save_to_disk().await.map_err(Error::WriteAddressCache) } @@ -137,28 +165,42 @@ impl AddressCache { pub async fn set_addresses(&self, mut addresses: Vec<SocketAddr>) -> io::Result<()> { let should_update = { let mut inner = self.inner.lock().unwrap(); + let mut transaction = AddressCacheTransaction::new(&mut inner); + addresses.sort(); - let mut current_sorted = inner.addresses.clone(); + + let mut current_sorted = transaction.addresses.clone(); current_sorted.sort(); + if addresses != current_sorted { - let current_address = Self::get_address_inner(&inner); + let current_address = Self::get_address_inner(&transaction); - inner.addresses = addresses.clone(); - inner.shuffle(); + transaction.addresses = addresses.clone(); + transaction.shuffle(); // Prefer a likely-working address - let choice = inner + let choice = transaction .addresses .iter() .position(|&addr| addr == current_address); if let Some(choice) = choice { - inner.choice = choice; + transaction.choice = choice; + transaction.commit(); } else { - inner.choice = 0; - inner.tried_current = false; + transaction.choice = 0; + transaction.tried_current = false; - let new_address = Self::get_address_inner(&inner); - (*self.change_listener)(new_address); + tokio::task::block_in_place(move || { + if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() { + log::error!("Failed to select a new API endpoint"); + return Err(io::Error::new( + io::ErrorKind::Other, + "callback returned an error", + )); + } + transaction.commit(); + Ok(()) + })?; } true @@ -217,6 +259,7 @@ impl crate::rest::AddressProvider for AddressCache { } +#[derive(Clone, PartialEq, Eq)] struct AddressCacheInner { addresses: Vec<SocketAddr>, choice: usize, @@ -247,6 +290,38 @@ impl AddressCacheInner { } } +struct AddressCacheTransaction<'a> { + current: &'a mut AddressCacheInner, + working_cache: AddressCacheInner, +} + +impl<'a> AddressCacheTransaction<'a> { + fn new(cache: &'a mut AddressCacheInner) -> Self { + Self { + working_cache: cache.clone(), + current: cache, + } + } + + fn commit(self) { + *self.current = self.working_cache; + } +} + +impl<'a> Deref for AddressCacheTransaction<'a> { + type Target = AddressCacheInner; + + fn deref(&self) -> &Self::Target { + &self.working_cache + } +} + +impl<'a> DerefMut for AddressCacheTransaction<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.working_cache + } +} + async fn read_address_file(path: &Path) -> Result<Vec<SocketAddr>, Error> { let file = fs::File::open(path) .await diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index ca24288696..4beefa7d3b 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -23,8 +23,7 @@ use crate::https_client_with_sni::HttpsConnectorWithSni; mod address_cache; mod relay_list; -use address_cache::AddressCache; -pub use address_cache::CurrentAddressChangeListener; +pub use address_cache::{AddressCache, CurrentAddressChangeListener}; pub use hyper::StatusCode; pub use relay_list::RelayListProxy; @@ -44,7 +43,7 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443); pub struct MullvadRpcRuntime { https_connector: HttpsConnectorWithSni, handle: tokio::runtime::Handle, - address_cache: AddressCache, + pub address_cache: AddressCache, } #[derive(err_derive::Error, Debug)] @@ -65,7 +64,7 @@ impl MullvadRpcRuntime { address_cache: AddressCache::new( vec![API_ADDRESS.into()], None, - Arc::new(Box::new(|_| {})), + Arc::new(Box::new(|_| Ok(()))), )?, }) } @@ -78,7 +77,7 @@ impl MullvadRpcRuntime { resource_dir: Option<&Path>, cache_dir: &Path, write_changes: bool, - address_change_listener: impl Fn(SocketAddr) + Send + Sync + 'static, + address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static, ) -> Result<Self, Error> { let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); let write_file = if write_changes { @@ -113,13 +112,12 @@ impl MullvadRpcRuntime { match resource_dir { Some(resource_dir) => { let read_file = resource_dir.join(API_IP_CACHE_FILENAME); - let cache = AddressCache::from_file( - &read_file, - write_file, - address_change_listener, - ) - .await?; + let empty_listener = + Arc::<Box<CurrentAddressChangeListener>>::new(Box::new(|_| Ok(()))); + let mut cache = + AddressCache::from_file(&read_file, write_file, empty_listener).await?; cache.randomize().await?; + cache.set_change_listener(address_change_listener); cache } None => return Err(Error::AddressCacheError(error)), diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 7e1021f4ea..c5ed2d38f0 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -144,6 +144,7 @@ async fn reset_firewall() -> Result<(), Error> { let mut firewall = Firewall::new(FirewallArguments { initialize_blocked: false, allow_lan: true, + allowed_endpoint: None, }) .map_err(Error::FirewallError)?; @@ -158,7 +159,7 @@ async fn clear_history() -> Result<(), Error> { None, &user_cache_path, false, - |_| {}, + |_| Ok(()), ) .await .map_err(Error::RpcInitializationError)?; diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 3c252313ce..04bd00777d 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -531,10 +531,12 @@ impl<'a> PolicyBatch<'a> { peer_endpoint, pingable_hosts, allow_lan, + allowed_endpoint, use_fwmark, } => { self.add_allow_icmp_pingable_hosts(&pingable_hosts); - self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_endpoint_rules(allowed_endpoint); // Important to block DNS after allow relay rule (so the relay can operate // over port 53) but before allow LAN (so DNS does not leak to the LAN) @@ -548,7 +550,7 @@ impl<'a> PolicyBatch<'a> { dns_servers, use_fwmark, } => { - self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark); self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Udp)?; self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Tcp)?; // Important to block DNS *before* we allow the tunnel and allow LAN. So DNS @@ -560,7 +562,12 @@ impl<'a> PolicyBatch<'a> { } *allow_lan } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { + self.add_allow_endpoint_rules(allowed_endpoint); + // Important to drop DNS before allowing LAN (to stop DNS leaking to the LAN) self.add_drop_dns_rule(); *allow_lan @@ -582,7 +589,7 @@ impl<'a> PolicyBatch<'a> { Ok(()) } - fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) { + fn add_allow_tunnel_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) { let mut in_rule = Rule::new(&self.in_chain); check_endpoint(&mut in_rule, End::Src, endpoint); @@ -608,6 +615,20 @@ impl<'a> PolicyBatch<'a> { self.batch.add(&out_rule, nftnl::MsgType::Add); } + fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) { + let mut in_rule = Rule::new(&self.in_chain); + check_endpoint(&mut in_rule, End::Src, endpoint); + add_verdict(&mut in_rule, &Verdict::Accept); + + self.batch.add(&in_rule, nftnl::MsgType::Add); + + let mut out_rule = Rule::new(&self.out_chain); + check_endpoint(&mut out_rule, End::Dst, endpoint); + add_verdict(&mut out_rule, &Verdict::Accept); + + self.batch.add(&out_rule, nftnl::MsgType::Add); + } + fn add_allow_icmp_pingable_hosts(&mut self, pingable_hosts: &[IpAddr]) { for host in pingable_hosts { let icmp_proto = match &host { diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs index dfdc1e31fc..2e23c99dd2 100644 --- a/talpid-core/src/firewall/macos.rs +++ b/talpid-core/src/firewall/macos.rs @@ -98,9 +98,11 @@ impl Firewall { FirewallPolicy::Connecting { peer_endpoint, allow_lan, + allowed_endpoint, pingable_hosts, } => { let mut rules = vec![self.get_allow_relay_rule(peer_endpoint)?]; + rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?); rules.extend(self.get_allow_pingable_hosts(&pingable_hosts)?); if allow_lan { // Important to block DNS after allow relay rule (so the relay can operate @@ -136,8 +138,12 @@ impl Firewall { Ok(rules) } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { let mut rules = Vec::new(); + rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?); if allow_lan { // Important to block DNS before allow LAN (so DNS does not leak to the LAN) rules.append(&mut self.get_block_dns_rules()?); @@ -247,6 +253,22 @@ impl Firewall { .build()?) } + fn get_allowed_endpoint_rule( + &self, + allowed_endpoint: net::Endpoint, + ) -> Result<pfctl::FilterRule> { + let pfctl_proto = as_pfctl_proto(allowed_endpoint.protocol); + + Ok(self + .create_rule_builder(FilterRuleAction::Pass) + .direction(pfctl::Direction::Out) + .to(allowed_endpoint.address) + .proto(pfctl_proto) + .keep_state(pfctl::StatePolicy::Keep) + .quick(true) + .build()?) + } + fn get_block_dns_rules(&self) -> Result<Vec<pfctl::FilterRule>> { let block_tcp_dns_rule = self .create_rule_builder(FilterRuleAction::Drop(DropAction::Return)) diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index b467f37d98..83a112ce88 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -107,6 +107,8 @@ pub enum FirewallPolicy { pingable_hosts: Vec<IpAddr>, /// Flag setting if communication with LAN networks should be possible. allow_lan: bool, + /// Host that should be reachable by the tunnel client while connecting. + allowed_endpoint: Endpoint, /// A process that is allowed to send packets to the relay. #[cfg(windows)] relay_client: PathBuf, @@ -140,6 +142,8 @@ pub enum FirewallPolicy { Blocked { /// Flag setting if communication with LAN networks should be possible. allow_lan: bool, + /// Host that should be reachable while in the blocked state. + allowed_endpoint: Endpoint, }, } @@ -182,10 +186,14 @@ impl fmt::Display for FirewallPolicy { tunnel.ipv6_gateway, if *allow_lan { "Allowing" } else { "Blocking" } ), - FirewallPolicy::Blocked { allow_lan } => write!( + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => write!( f, - "Blocked, {} LAN", - if *allow_lan { "Allowing" } else { "Blocking" } + "Blocked. {} LAN. Allowing endpoint {}", + if *allow_lan { "Allowing" } else { "Blocking" }, + allowed_endpoint, ), } } @@ -203,6 +211,8 @@ pub struct FirewallArguments { pub initialize_blocked: bool, /// This argument is required for the blocked state to configure the firewall correctly. pub allow_lan: bool, + /// This argument is required for the blocked state to configure the firewall correctly. + pub allowed_endpoint: Option<Endpoint>, } impl Firewall { diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index d1fa08a3e6..8375eb55d6 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -57,10 +57,23 @@ impl FirewallT for Firewall { if args.initialize_blocked { let cfg = &WinFwSettings::new(args.allow_lan); + + let winfw_allowed_endpoint = if let Some(allowed_endpoint) = args.allowed_endpoint { + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }) + } else { + None + }; + unsafe { WinFw_InitializeBlocked( WINFW_TIMEOUT_SECONDS, &cfg, + winfw_allowed_endpoint.as_ptr(), Some(log_sink), logging_context, ) @@ -83,6 +96,7 @@ impl FirewallT for Firewall { peer_endpoint, pingable_hosts, allow_lan, + allowed_endpoint, relay_client, } => { let cfg = &WinFwSettings::new(allow_lan); @@ -91,6 +105,7 @@ impl FirewallT for Firewall { &peer_endpoint, &cfg, "Mullvad".to_string(), + &allowed_endpoint, &pingable_hosts, &relay_client, ) @@ -105,9 +120,12 @@ impl FirewallT for Firewall { let cfg = &WinFwSettings::new(allow_lan); self.set_connected_state(&peer_endpoint, &cfg, &tunnel, &dns_servers, &relay_client) } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { let cfg = &WinFwSettings::new(allow_lan); - self.set_blocked_state(&cfg) + self.set_blocked_state(&cfg, &allowed_endpoint) } } } @@ -138,12 +156,13 @@ impl Firewall { endpoint: &Endpoint, winfw_settings: &WinFwSettings, _tunnel_iface_alias: String, + allowed_endpoint: &Endpoint, pingable_hosts: &Vec<IpAddr>, relay_client: &Path, ) -> Result<(), Error> { trace!("Applying 'connecting' firewall policy"); let ip_str = Self::widestring_ip(endpoint.address.ip()); - let winfw_relay = WinFwRelay { + let winfw_relay = WinFwEndpoint { ip: ip_str.as_ptr(), port: endpoint.address.port(), protocol: WinFwProt::from(endpoint.protocol), @@ -171,12 +190,20 @@ impl Firewall { None }; + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + let winfw_allowed_endpoint = Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }); + unsafe { WinFw_ApplyPolicyConnecting( winfw_settings, &winfw_relay, relay_client.as_ptr(), pingable_hosts.as_ptr(), + winfw_allowed_endpoint.as_ptr(), ) .into_result() .map_err(Error::ApplyingConnectingPolicy) @@ -207,7 +234,7 @@ impl Firewall { WideCString::new(tunnel_metadata.interface.encode_utf16().collect::<Vec<_>>()).unwrap(); // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay - let winfw_relay = WinFwRelay { + let winfw_relay = WinFwEndpoint { ip: ip_str.as_ptr(), port: endpoint.address.port(), protocol: WinFwProt::from(endpoint.protocol), @@ -258,10 +285,22 @@ impl Firewall { } } - fn set_blocked_state(&mut self, winfw_settings: &WinFwSettings) -> Result<(), Error> { + fn set_blocked_state( + &mut self, + winfw_settings: &WinFwSettings, + allowed_endpoint: &Endpoint, + ) -> Result<(), Error> { trace!("Applying 'blocked' firewall policy"); + + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + let winfw_allowed_endpoint = Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }); + unsafe { - WinFw_ApplyPolicyBlocked(winfw_settings) + WinFw_ApplyPolicyBlocked(winfw_settings, winfw_allowed_endpoint.as_ptr()) .into_result() .map_err(Error::ApplyingBlockedPolicy) } @@ -289,7 +328,7 @@ mod winfw { use talpid_types::net::TransportProtocol; #[repr(C)] - pub struct WinFwRelay { + pub struct WinFwEndpoint { pub ip: *const libc::wchar_t, pub port: u16, pub protocol: WinFwProt, @@ -385,6 +424,7 @@ mod winfw { pub fn WinFw_InitializeBlocked( timeout: libc::c_uint, settings: &WinFwSettings, + allowed_endpoint: *const WinFwEndpoint, sink: Option<LogSink>, sink_context: *const u8, ) -> InitializationResult; @@ -395,15 +435,16 @@ mod winfw { #[link_name = "WinFw_ApplyPolicyConnecting"] pub fn WinFw_ApplyPolicyConnecting( settings: &WinFwSettings, - relay: &WinFwRelay, + relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, pingable_hosts: *const WinFwPingableHosts, + allowed_endpoint: *const WinFwEndpoint, ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyConnected"] pub fn WinFw_ApplyPolicyConnected( settings: &WinFwSettings, - relay: &WinFwRelay, + relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, tunnelIfaceAlias: *const libc::wchar_t, v4Gateway: *const libc::wchar_t, @@ -413,7 +454,10 @@ mod winfw { ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyBlocked"] - pub fn WinFw_ApplyPolicyBlocked(settings: &WinFwSettings) -> WinFwPolicyStatus; + pub fn WinFw_ApplyPolicyBlocked( + settings: &WinFwSettings, + allowed_endpoint: *const WinFwEndpoint, + ) -> WinFwPolicyStatus; #[link_name = "WinFw_Reset"] pub fn WinFw_Reset() -> WinFwPolicyStatus; diff --git a/talpid-core/src/tunnel/tun_provider/android/mod.rs b/talpid-core/src/tunnel/tun_provider/android/mod.rs index b9385f13a7..fa48f115b9 100644 --- a/talpid-core/src/tunnel/tun_provider/android/mod.rs +++ b/talpid-core/src/tunnel/tun_provider/android/mod.rs @@ -66,6 +66,7 @@ pub struct AndroidTunProvider { object: GlobalRef, last_tun_config: TunConfig, allow_lan: bool, + allowed_endpoint: IpAddr, custom_dns_servers: Option<Vec<IpAddr>>, } @@ -74,6 +75,7 @@ impl AndroidTunProvider { pub fn new( context: AndroidContext, allow_lan: bool, + allowed_endpoint: IpAddr, custom_dns_servers: Option<Vec<IpAddr>>, ) -> Self { let env = JnixEnv::from( @@ -90,6 +92,7 @@ impl AndroidTunProvider { object: context.vpn_service, last_tun_config: TunConfig::default(), allow_lan, + allowed_endpoint, custom_dns_servers, } } @@ -103,6 +106,10 @@ impl AndroidTunProvider { Ok(()) } + pub fn set_allowed_endpoint(&mut self, endpoint: IpAddr) { + self.allowed_endpoint = endpoint; + } + pub fn set_custom_dns_servers(&mut self, servers: Option<Vec<IpAddr>>) -> Result<(), Error> { if self.custom_dns_servers != servers { self.custom_dns_servers = servers; @@ -129,6 +136,19 @@ impl AndroidTunProvider { }) } + /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and (potentially) + /// LAN routes via the tunnel device. + /// + /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be + /// closed. + pub fn create_blocking_tun(&mut self) -> Result<(), Error> { + let mut config = TunConfig::default(); + self.prepare_tun_config(&mut config); + self.prepare_tun_config_for_allowed_endpoint(&mut config); + let _ = self.get_tun(config)?; + Ok(()) + } + /// Open a tunnel device using the previous or the default configuration. /// /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be @@ -231,6 +251,24 @@ impl AndroidTunProvider { } } + fn prepare_tun_config_for_allowed_endpoint(&self, config: &mut TunConfig) { + let endpoint_net = IpNetwork::from(self.allowed_endpoint); + let routes = config + .routes + .iter() + .flat_map(|&route| { + if route.is_ipv4() && endpoint_net.is_ipv4() { + route.sub(endpoint_net).collect() + } else if route.is_ipv6() && endpoint_net.is_ipv6() { + route.sub(endpoint_net).collect() + } else { + vec![route] + } + }) + .collect(); + config.routes = routes; + } + fn prepare_tun_config(&self, config: &mut TunConfig) { self.prepare_tun_config_for_allow_lan(config); self.prepare_tun_config_for_custom_dns(config); diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 7292da0c67..0c305de9a7 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -192,6 +192,13 @@ impl ConnectedState { } } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { match shared_values.set_custom_dns(servers) { Ok(true) => { diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 44dcd9f153..0b03ceeca1 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -63,6 +63,7 @@ impl ConnectingState { peer_endpoint, pingable_hosts: gateway_list_from_params(params), allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), #[cfg(windows)] relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, ¶ms), #[cfg(target_os = "linux")] @@ -235,6 +236,22 @@ impl ConnectingState { } } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + if let Err(error) = + Self::set_firewall_policy(shared_values, &self.tunnel_parameters) + { + return self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), + ); + } + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { match shared_values.set_custom_dns(servers) { #[cfg(target_os = "android")] diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index dcc4660e9f..922eb69c88 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -17,6 +17,7 @@ impl DisconnectedState { let result = if shared_values.block_when_disconnected { let policy = FirewallPolicy::Blocked { allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), }; shared_values.firewall.apply_policy(policy).map_err(|e| { e.display_chain_with_msg( @@ -77,6 +78,15 @@ impl TunnelState for DisconnectedState { } SameState(self.into()) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + Self::set_firewall_policy(shared_values, true); + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { // Same situation as allow LAN above. shared_values diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 0928834d1c..48a83a6dc3 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -32,6 +32,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Nothing } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Nothing + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Nothing @@ -53,6 +60,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Block(reason) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Block(reason) + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Block(reason) @@ -79,6 +93,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Reconnect(retry_attempt) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Reconnect(retry_attempt) + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Reconnect(retry_attempt) diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index a87dccd5b4..51159d274f 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -21,6 +21,7 @@ impl ErrorState { ) -> Result<(), FirewallPolicyError> { let policy = FirewallPolicy::Blocked { allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), }; #[cfg(target_os = "linux")] @@ -47,7 +48,7 @@ impl ErrorState { /// Returns true if a new tunnel device was successfully created. #[cfg(target_os = "android")] fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool { - match shared_values.tun_provider.create_tun_if_closed() { + match shared_values.tun_provider.create_blocking_tun() { Ok(()) => true, Err(error) => { log::error!( @@ -105,6 +106,23 @@ impl TunnelState for ErrorState { SameState(self.into()) } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + let _ = Self::set_firewall_policy(shared_values); + + #[cfg(target_os = "android")] + if !Self::create_blocking_tun(shared_values) { + return NewState(Self::enter( + shared_values, + ErrorStateCause::SetFirewallPolicyError(FirewallPolicyError::Generic), + )); + } + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { if let Err(error_state_cause) = shared_values.set_custom_dns(servers) { NewState(Self::enter(shared_values, error_state_cause)) diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index fbec1bf2b1..b657ec5e36 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -33,7 +33,7 @@ use std::{ #[cfg(target_os = "android")] use talpid_types::{android::AndroidContext, ErrorExt}; use talpid_types::{ - net::TunnelParameters, + net::{Endpoint, TunnelParameters}, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, }; @@ -75,6 +75,7 @@ pub async fn spawn( allow_lan: bool, block_when_disconnected: bool, custom_dns: Option<Vec<IpAddr>>, + allowed_endpoint: Endpoint, tunnel_parameters_generator: impl TunnelParametersGenerator, log_dir: Option<PathBuf>, resource_dir: PathBuf, @@ -101,6 +102,8 @@ pub async fn spawn( #[cfg(target_os = "android")] allow_lan, #[cfg(target_os = "android")] + allowed_endpoint.address.ip(), + #[cfg(target_os = "android")] custom_dns.clone(), ); @@ -114,6 +117,7 @@ pub async fn spawn( block_when_disconnected, is_offline, custom_dns, + allowed_endpoint, tunnel_parameters_generator, tun_provider, log_dir, @@ -152,6 +156,9 @@ pub async fn spawn( pub enum TunnelCommand { /// Enable or disable LAN access in the firewall. AllowLan(bool), + /// Endpoint that should never be blocked. + /// If an error occurs, the sender is dropped. + AllowEndpoint(Endpoint, oneshot::Sender<()>), /// Set custom DNS servers to use. CustomDns(Option<Vec<IpAddr>>), /// Enable or disable the block_when_disconnected feature. @@ -193,6 +200,7 @@ impl TunnelStateMachine { block_when_disconnected: bool, is_offline: bool, custom_dns: Option<Vec<IpAddr>>, + allowed_endpoint: Endpoint, tunnel_parameters_generator: impl TunnelParametersGenerator, tun_provider: TunProvider, log_dir: Option<PathBuf>, @@ -204,6 +212,7 @@ impl TunnelStateMachine { let args = FirewallArguments { initialize_blocked: block_when_disconnected || !reset_firewall, allow_lan, + allowed_endpoint: Some(allowed_endpoint), }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; @@ -218,6 +227,7 @@ impl TunnelStateMachine { block_when_disconnected, is_offline, custom_dns, + allowed_endpoint, tunnel_parameters_generator: Box::new(tunnel_parameters_generator), tun_provider, log_dir, @@ -291,6 +301,8 @@ struct SharedTunnelStateValues { is_offline: bool, /// Custom DNS servers to use. custom_dns: Option<Vec<IpAddr>>, + /// Endpoint that should not be blocked by the firewall. + allowed_endpoint: Endpoint, /// The generator of new `TunnelParameter`s tunnel_parameters_generator: Box<dyn TunnelParametersGenerator>, /// The provider of tunnel devices. @@ -328,6 +340,20 @@ impl SharedTunnelStateValues { Ok(()) } + pub fn set_allowed_endpoint(&mut self, endpoint: Endpoint) -> bool { + if self.allowed_endpoint != endpoint { + self.allowed_endpoint = endpoint; + + #[cfg(target_os = "android")] + self.tun_provider + .set_allowed_endpoint(endpoint.address.ip()); + + true + } else { + false + } + } + pub fn set_custom_dns( &mut self, custom_dns: Option<Vec<IpAddr>>, diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 15bf33a5bb..29a871ee07 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -144,6 +144,10 @@ impl Endpoint { protocol, } } + + pub fn from_socket_address(address: SocketAddr, protocol: TransportProtocol) -> Self { + Endpoint { address, protocol } + } } impl fmt::Display for Endpoint { diff --git a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp index a722501832..7f5b9d980c 100644 --- a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp +++ b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp @@ -26,9 +26,9 @@ WinFwProtocol TranslateProtocol(const std::wstring &protocol) return (0 == _wcsicmp(protocol.c_str(), L"tcp") ? WinFwProtocol::Tcp : WinFwProtocol::Udp); } -WinFwRelay CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol) +WinFwEndpoint CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol) { - WinFwRelay r; + WinFwEndpoint r; r.ip = ip; r.port = common::string::LexicalCast<uint16_t>(port); diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index b8959cfc01..99a3d3dc26 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -15,6 +15,7 @@ #include "rules/baseline/permitvpntunnelservice.h" #include "rules/baseline/permitping.h" #include "rules/baseline/permitdns.h" +#include "rules/baseline/permitendpoint.h" #include "rules/dns/blockall.h" #include "rules/dns/permittunnel.h" #include "rules/dns/permitnontunnel.h" @@ -43,6 +44,19 @@ multi::PermitVpnRelay::Protocol TranslateProtocol(WinFwProtocol protocol) }; } +baseline::PermitEndpoint::Protocol TranslateEndpointProtocol(WinFwProtocol protocol) +{ + switch (protocol) + { + case Tcp: return baseline::PermitEndpoint::Protocol::Tcp; + case Udp: return baseline::PermitEndpoint::Protocol::Udp; + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + // // Since the PermitLan rule doesn't specifically address DNS, it will allow DNS requests targetting // a local resolver to leave the machine. From the local resolver the request will either be @@ -91,7 +105,7 @@ void AppendSettingsRules void AppendRelayRules ( FwContext::Ruleset &ruleset, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient ) { @@ -111,6 +125,22 @@ void AppendRelayRules )); } +// +// Refer comment on `AppendSettingsRules`. +// +void AppendAllowedEndpointRules +( + FwContext::Ruleset &ruleset, + const WinFwEndpoint &endpoint +) +{ + ruleset.emplace_back(std::make_unique<baseline::PermitEndpoint>( + wfp::IpAddress(endpoint.ip), + endpoint.port, + TranslateEndpointProtocol(endpoint.protocol) + )); +} + void AppendNetBlockedRules(FwContext::Ruleset &ruleset) { ruleset.emplace_back(std::make_unique<baseline::BlockAll>()); @@ -145,7 +175,8 @@ FwContext::FwContext FwContext::FwContext ( uint32_t timeout, - const WinFwSettings &settings + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint ) : m_baseline(0) , m_activePolicy(Policy::None) @@ -159,7 +190,7 @@ FwContext::FwContext uint32_t checkpoint = 0; - if (false == applyBlockedBaseConfiguration(settings, checkpoint)) + if (false == applyBlockedBaseConfiguration(settings, allowedEndpoint, checkpoint)) { THROW_ERROR("Failed to apply base configuration in BFE"); } @@ -171,9 +202,10 @@ FwContext::FwContext bool FwContext::applyPolicyConnecting ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, - const std::optional<PingableHosts> &pingableHosts + const std::optional<PingableHosts> &pingableHosts, + const std::optional<WinFwEndpoint> &allowedEndpoint ) { Ruleset ruleset; @@ -182,6 +214,11 @@ bool FwContext::applyPolicyConnecting AppendSettingsRules(ruleset, settings); AppendRelayRules(ruleset, relay, relayClient); + if (allowedEndpoint.has_value()) + { + AppendAllowedEndpointRules(ruleset, allowedEndpoint.value()); + } + // // Permit pinging the gateway inside the tunnel. // @@ -208,7 +245,7 @@ bool FwContext::applyPolicyConnecting bool FwContext::applyPolicyConnected ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, const std::wstring &tunnelInterfaceAlias, const std::vector<wfp::IpAddress> &tunnelDnsServers, @@ -252,9 +289,9 @@ bool FwContext::applyPolicyConnected return status; } -bool FwContext::applyPolicyBlocked(const WinFwSettings &settings) +bool FwContext::applyPolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint) { - const auto status = applyRuleset(composePolicyBlocked(settings)); + const auto status = applyRuleset(composePolicyBlocked(settings, allowedEndpoint)); if (status) { @@ -284,13 +321,18 @@ FwContext::Policy FwContext::activePolicy() const return m_activePolicy; } -FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings) +FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint) { Ruleset ruleset; AppendNetBlockedRules(ruleset); AppendSettingsRules(ruleset, settings); + if (allowedEndpoint.has_value()) + { + AppendAllowedEndpointRules(ruleset, allowedEndpoint.value()); + } + return ruleset; } @@ -302,7 +344,7 @@ bool FwContext::applyBaseConfiguration() }); } -bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint) +bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint) { return m_sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine) { @@ -318,7 +360,7 @@ bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uin // checkpoint = controller.peekCheckpoint(); - return applyRulesetDirectly(composePolicyBlocked(settings), controller); + return applyRulesetDirectly(composePolicyBlocked(settings, allowedEndpoint), controller); }); } diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index 100672073a..bbbb1de485 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -20,7 +20,8 @@ public: FwContext ( uint32_t timeout, - const WinFwSettings &settings + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint ); struct PingableHosts @@ -32,22 +33,26 @@ public: bool applyPolicyConnecting ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, - const std::optional<PingableHosts> &pingableHosts + const std::optional<PingableHosts> &pingableHosts, + const std::optional<WinFwEndpoint> &allowedEndpoint ); bool applyPolicyConnected ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, const std::wstring &tunnelInterfaceAlias, const std::vector<wfp::IpAddress> &tunnelDnsServers, const std::vector<wfp::IpAddress> &nonTunnelDnsServers ); - bool applyPolicyBlocked(const WinFwSettings &settings); + bool applyPolicyBlocked( + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint + ); bool reset(); @@ -68,10 +73,10 @@ private: FwContext(const FwContext &) = delete; FwContext &operator=(const FwContext &) = delete; - Ruleset composePolicyBlocked(const WinFwSettings &settings); + Ruleset composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint); bool applyBaseConfiguration(); - bool applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint); + bool applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint); bool applyCommonBaseConfiguration(SessionController &controller, wfp::FilterEngine &engine); bool applyRuleset(const Ruleset &ruleset); diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp index 0a22be1740..417b157f82 100644 --- a/windows/winfw/src/winfw/mullvadguids.cpp +++ b/windows/winfw/src/winfw/mullvadguids.cpp @@ -129,6 +129,7 @@ MullvadGuids::DetailedIdentityRegistry MullvadGuids::DetailedRegistry(IdentityQu registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Inbound_Request_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Outbound_Response_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnRelay())); + registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitEndpoint())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnelService_Ipv4())); @@ -644,6 +645,20 @@ const GUID &MullvadGuids::Filter_Baseline_PermitVpnRelay() } //static +const GUID &MullvadGuids::Filter_Baseline_PermitEndpoint() +{ + static const GUID g = + { + 0x99dc8dac, + 0x8520, + 0x41be, + { 0xbf, 0xab, 0x0c, 0x9, 0xbf, 0x12, 0xeb, 0 } + }; + + return g; +} + +//static const GUID &MullvadGuids::Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4() { static const GUID g = diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h index 11e396fc2b..7f00863811 100644 --- a/windows/winfw/src/winfw/mullvadguids.h +++ b/windows/winfw/src/winfw/mullvadguids.h @@ -69,6 +69,8 @@ public: static const GUID &Filter_Baseline_PermitVpnRelay(); + static const GUID &Filter_Baseline_PermitEndpoint(); + static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4(); static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6(); diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp new file mode 100644 index 0000000000..217631579d --- /dev/null +++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp @@ -0,0 +1,87 @@ +#include "stdafx.h" +#include "permitendpoint.h" +#include <winfw/mullvadguids.h> +#include <libwfp/filterbuilder.h> +#include <libwfp/conditionbuilder.h> +#include <libwfp/conditions/conditionprotocol.h> +#include <libwfp/conditions/conditionip.h> +#include <libwfp/conditions/conditionport.h> +#include <libwfp/conditions/conditionapplication.h> +#include <libcommon/error.h> + +using namespace wfp::conditions; + +namespace rules::baseline +{ + +namespace +{ + +const GUID &OutboundLayerFromIp(const wfp::IpAddress &ip) +{ + switch (ip.type()) + { + case wfp::IpAddress::Type::Ipv4: return FWPM_LAYER_ALE_AUTH_CONNECT_V4; + case wfp::IpAddress::Type::Ipv6: return FWPM_LAYER_ALE_AUTH_CONNECT_V6; + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + +std::unique_ptr<ConditionProtocol> CreateProtocolCondition(PermitEndpoint::Protocol protocol) +{ + switch (protocol) + { + case PermitEndpoint::Protocol::Tcp: return ConditionProtocol::Tcp(); + case PermitEndpoint::Protocol::Udp: return ConditionProtocol::Udp(); + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + +} // anonymous namespace + +PermitEndpoint::PermitEndpoint +( + const wfp::IpAddress &address, + uint16_t port, + Protocol protocol +) + : m_address(address) + , m_port(port) + , m_protocol(protocol) +{ +} + +bool PermitEndpoint::apply(IObjectInstaller &objectInstaller) +{ + wfp::FilterBuilder filterBuilder; + + // + // Permit outbound connections to endpoint. + // + + filterBuilder + .key(MullvadGuids::Filter_Baseline_PermitEndpoint()) + .name(L"Permit outbound connections to a given endpoint") + .description(L"This filter is part of a rule that permits traffic to a specific endpoint") + .provider(MullvadGuids::Provider()) + .layer(OutboundLayerFromIp(m_address)) + .sublayer(MullvadGuids::SublayerBaseline()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(OutboundLayerFromIp(m_address)); + + conditionBuilder.add_condition(ConditionIp::Remote(m_address)); + conditionBuilder.add_condition(ConditionPort::Remote(m_port)); + conditionBuilder.add_condition(CreateProtocolCondition(m_protocol)); + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +} diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.h b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h new file mode 100644 index 0000000000..cfa57b003d --- /dev/null +++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h @@ -0,0 +1,36 @@ +#pragma once + +#include <winfw/rules/ifirewallrule.h> +#include <libwfp/ipaddress.h> +#include <string> + +namespace rules::baseline +{ + +class PermitEndpoint : public IFirewallRule +{ +public: + + enum class Protocol + { + Tcp, + Udp + }; + + PermitEndpoint + ( + const wfp::IpAddress &address, + uint16_t port, + Protocol protocol + ); + + bool apply(IObjectInstaller &objectInstaller) override; + +private: + + const wfp::IpAddress m_address; + const uint16_t m_port; + const Protocol m_protocol; +}; + +} diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index f2d5a66b2a..ee7842877e 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -65,6 +65,16 @@ HandlePolicyException(const common::error::WindowsException &err) return WINFW_POLICY_STATUS_GENERAL_FAILURE; } +template<typename T> +std::optional<T> MakeOptional(T* object) +{ + if (nullptr == object) + { + return std::nullopt; + } + return std::make_optional(*object); +} + // // Networks for which DNS requests can be made on all network adapters. // @@ -136,6 +146,7 @@ WINFW_API WinFw_InitializeBlocked( uint32_t timeout, const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint, MullvadLogSink logSink, void *logSinkContext ) @@ -162,7 +173,7 @@ WinFw_InitializeBlocked( g_logSink = logSink; g_logSinkContext = logSinkContext; - g_fwContext = new FwContext(timeout_ms, *settings); + g_fwContext = new FwContext(timeout_ms, *settings, MakeOptional(allowedEndpoint)); } catch (std::exception &err) { @@ -247,9 +258,10 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, - const PingableHosts *pingableHosts + const PingableHosts *pingableHosts, + const WinFwEndpoint *allowedEndpoint ) { if (nullptr == g_fwContext) @@ -278,7 +290,8 @@ WinFw_ApplyPolicyConnecting( *settings, *relay, relayClient, - ConvertPingableHosts(pingableHosts) + ConvertPingableHosts(pingableHosts), + MakeOptional(allowedEndpoint) ) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE; } catch (common::error::WindowsException &err) @@ -305,7 +318,7 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, @@ -447,7 +460,8 @@ WINFW_LINKAGE WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyBlocked( - const WinFwSettings *settings + const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint ) { if (nullptr == g_fwContext) @@ -462,7 +476,7 @@ WinFw_ApplyPolicyBlocked( THROW_ERROR("Invalid argument: settings"); } - return g_fwContext->applyPolicyBlocked(*settings) + return g_fwContext->applyPolicyBlocked(*settings, MakeOptional(allowedEndpoint)) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE; } diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index f0a487cb12..23163786e9 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -37,13 +37,13 @@ enum WinFwProtocol : uint8_t Udp = 1 }; -typedef struct tag_WinFwRelay +typedef struct tag_WinFwEndpoint { const wchar_t *ip; uint16_t port; WinFwProtocol protocol; } -WinFwRelay; +WinFwEndpoint; #pragma pack(pop) @@ -88,6 +88,7 @@ WINFW_API WinFw_InitializeBlocked( uint32_t timeout, const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint, MullvadLogSink logSink, void *logSinkContext ); @@ -155,9 +156,10 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, - const PingableHosts *pingableHosts + const PingableHosts *pingableHosts, + const WinFwEndpoint *allowedEndpoint ); // @@ -183,7 +185,7 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, @@ -203,7 +205,8 @@ WINFW_LINKAGE WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyBlocked( - const WinFwSettings *settings + const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint ); // diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj index 8f9c37f919..3f9502a10d 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj +++ b/windows/winfw/src/winfw/winfw.vcxproj @@ -27,6 +27,7 @@ <ClCompile Include="rules\baseline\permitdhcp.cpp" /> <ClCompile Include="rules\baseline\permitdhcpserver.cpp" /> <ClCompile Include="rules\baseline\permitdns.cpp" /> + <ClCompile Include="rules\baseline\permitendpoint.cpp" /> <ClCompile Include="rules\baseline\permitlan.cpp" /> <ClCompile Include="rules\baseline\permitlanservice.cpp" /> <ClCompile Include="rules\baseline\permitloopback.cpp" /> @@ -61,6 +62,7 @@ <ClInclude Include="rules\baseline\permitdhcp.h" /> <ClInclude Include="rules\baseline\permitdhcpserver.h" /> <ClInclude Include="rules\baseline\permitdns.h" /> + <ClInclude Include="rules\baseline\permitendpoint.h" /> <ClInclude Include="rules\baseline\permitlan.h" /> <ClInclude Include="rules\baseline\permitlanservice.h" /> <ClInclude Include="rules\baseline\permitloopback.h" /> diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters index 312045876e..7a2aa85487 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj.filters +++ b/windows/winfw/src/winfw/winfw.vcxproj.filters @@ -61,6 +61,9 @@ <ClCompile Include="rules\persistent\blockall.cpp"> <Filter>rules\persistent</Filter> </ClCompile> + <ClCompile Include="rules\baseline\permitendpoint.cpp"> + <Filter>rules\baseline</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -132,6 +135,9 @@ <ClInclude Include="rules\persistent\blockall.h"> <Filter>rules\persistent</Filter> </ClInclude> + <ClInclude Include="rules\baseline\permitendpoint.h"> + <Filter>rules\baseline</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <Filter Include="rules"> |
