diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-01-04 16:51:07 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-01-04 16:51:07 +0100 |
| commit | 9cfad3027adea59fa26e4c610d5a0ad41e5743ed (patch) | |
| tree | f3d08577213319b78483adaa33646649fa154ce0 | |
| parent | 1680065e267b84f1443317ae942cd0c609e86b14 (diff) | |
| parent | a2379f2b5eb7a79cd0a05f63692207674e9fabd5 (diff) | |
| download | mullvadvpn-9cfad3027adea59fa26e4c610d5a0ad41e5743ed.tar.xz mullvadvpn-9cfad3027adea59fa26e4c610d5a0ad41e5743ed.zip | |
Merge branch 'unblock-api-ip'
44 files changed, 882 insertions, 217 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 523aaf6859..2858f85f99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,11 @@ Line wrap the file at 100 chars. Th OpenVPN by correctly applying the fix for [CVE-2019-14899](https://seclists.org/oss-sec/2019/q4/122). +### Changed +- Allow the API to be accessed while in a blocking state. +- Prefer the last used API endpoint when the service starts back up, as well as in other tools such + as the problem report tool. + ## [2020.8-beta2] - 2020-12-11 This release is for desktop only. diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt index ae2d1b2df2..52795f0964 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt @@ -22,7 +22,7 @@ class MullvadProblemReport { } val logDirectory = CompletableDeferred<File>() - val resourcesDirectory = CompletableDeferred<File>() + val cacheDirectory = CompletableDeferred<File>() private val commandChannel = spawnActor() @@ -112,7 +112,7 @@ class MullvadProblemReport { userEmail, userMessage, problemReportPath.await().absolutePath, - resourcesDirectory.await().absolutePath + cacheDirectory.await().absolutePath ) if (result) { @@ -132,6 +132,6 @@ class MullvadProblemReport { userEmail: String, userMessage: String, reportPath: String, - resourcesDirectory: String + cacheDirectory: String ): Boolean } diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/ui/MainActivity.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/ui/MainActivity.kt index 44996b55b8..613927fe49 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/ui/MainActivity.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/ui/MainActivity.kt @@ -88,7 +88,7 @@ class MainActivity : FragmentActivity() { problemReport.apply { logDirectory.complete(filesDir) - resourcesDirectory.complete(filesDir) + cacheDirectory.complete(cacheDir) } setContentView(R.layout.main) diff --git a/dist-assets/pkg-scripts/preinstall b/dist-assets/pkg-scripts/preinstall index 5b56185a58..6c0ee99ff3 100755 --- a/dist-assets/pkg-scripts/preinstall +++ b/dist-assets/pkg-scripts/preinstall @@ -80,7 +80,10 @@ fi # Remove the existing relay and API address cache lists. # There is a risk that they're incompatible with the format this version wants rm "$NEW_CACHE_DIR/relays.json" || true +# Old API IP cache rm "$NEW_CACHE_DIR/api-ip-address.txt" || true +# New API IP cache +rm "/Library/Caches/mullvad-vpn/api-ip-address.txt" || true # Notify the running daemon that we are going to kill it and replace it with a newer version. # This will make the daemon save it's state to a file and then lock the firewall to prevent diff --git a/dist-assets/uninstall_macos.sh b/dist-assets/uninstall_macos.sh index 622457f994..4c02245a30 100755 --- a/dist-assets/uninstall_macos.sh +++ b/dist-assets/uninstall_macos.sh @@ -40,7 +40,7 @@ sudo pkgutil --forget net.mullvad.vpn || true read -p "Do you want to delete the log and cache files the app has created? (y/n) " if [[ "$REPLY" =~ [Yy]$ ]]; then - sudo rm -rf /var/log/mullvad-vpn /var/root/Library/Caches/mullvad-vpn + sudo rm -rf /var/log/mullvad-vpn /var/root/Library/Caches/mullvad-vpn /Library/Caches/mullvad-vpn for user in /Users/*; do user_log_dir="$user/Library/Logs/Mullvad VPN" if [[ -d "$user_log_dir" ]]; then diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 290fdaa9ee..00cbec6d36 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -48,7 +48,7 @@ use std::{ io, marker::PhantomData, mem, - net::IpAddr, + net::{IpAddr, SocketAddr}, path::PathBuf, sync::{mpsc as sync_mpsc, Arc, Weak}, time::Duration, @@ -62,7 +62,7 @@ use talpid_core::{ #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; use talpid_types::{ - net::{openvpn, TransportProtocol, TunnelParameters, TunnelType}, + net::{openvpn, Endpoint, TransportProtocol, TunnelParameters, TunnelType}, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, }; @@ -260,6 +260,8 @@ pub(crate) enum InternalDaemonEvent { ), /// The background job fetching new `AppVersionInfo`s got a new info object. NewAppVersionInfo(AppVersionInfo), + /// A new API endpoint is being used + NewApiAddress(SocketAddr, oneshot::Sender<()>), } impl From<TunnelStateTransition> for InternalDaemonEvent { @@ -483,6 +485,7 @@ where resource_dir: PathBuf, settings_dir: PathBuf, cache_dir: PathBuf, + user_cache_dir: PathBuf, event_listener: L, command_channel: DaemonCommandChannel, #[cfg(target_os = "android")] android_context: AndroidContext, @@ -490,10 +493,29 @@ where let (tunnel_state_machine_shutdown_tx, tunnel_state_machine_shutdown_signal) = oneshot::channel(); + let (internal_event_tx, internal_event_rx) = command_channel.destructure(); + let address_change_tx = std::sync::Mutex::new(internal_event_tx.clone()); + let address_change_runtime = tokio::runtime::Handle::current(); + let mut rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( tokio::runtime::Handle::current(), - &resource_dir, - Some(&cache_dir), + Some(&resource_dir), + &user_cache_dir, + true, + move |address| { + let (result_tx, result_rx) = oneshot::channel(); + + let tx = address_change_tx.lock().unwrap(); + if tx + .send(InternalDaemonEvent::NewApiAddress(address, result_tx)) + .is_err() + { + log::error!("Failed to send API address daemon event"); + return Err(()); + } + + address_change_runtime.block_on(result_rx).map_err(|_| ()) + }, ) .await .map_err(Error::InitRpcFactory)?; @@ -510,8 +532,6 @@ where &cache_dir, ); - let (internal_event_tx, internal_event_rx) = command_channel.destructure(); - let mut settings = SettingsPersister::load(&settings_dir); @@ -576,10 +596,16 @@ where TargetState::Unsecured }; + let initial_api_endpoint = Endpoint::from_socket_address( + rpc_runtime.address_cache.peek_address(), + TransportProtocol::Tcp, + ); + let tunnel_command_tx = tunnel_state_machine::spawn( settings.allow_lan, settings.block_when_disconnected, Self::get_custom_resolvers(&settings.tunnel_options.dns_options), + initial_api_endpoint, tunnel_parameters_generator, log_dir, resource_dir, @@ -749,6 +775,12 @@ where NewAppVersionInfo(app_version_info) => { self.handle_new_app_version_info(app_version_info) } + NewApiAddress(address, tx) => { + self.send_tunnel_command(TunnelCommand::AllowEndpoint( + Endpoint::from_socket_address(address, TransportProtocol::Tcp), + tx, + )); + } } } diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 4e055183b9..bd91729af1 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -123,6 +123,8 @@ async fn create_daemon( .map_err(|e| e.display_chain_with_msg("Unable to get settings dir"))?; let cache_dir = mullvad_paths::cache_dir() .map_err(|e| e.display_chain_with_msg("Unable to get cache dir"))?; + let user_cache_dir = mullvad_paths::user_cache_dir() + .map_err(|e| e.display_chain_with_msg("Unable to get user cache dir"))?; let command_channel = DaemonCommandChannel::new(); let event_listener = spawn_management_interface(command_channel.sender()).await?; @@ -132,6 +134,7 @@ async fn create_daemon( resource_dir, settings_dir, cache_dir, + user_cache_dir, event_listener, command_channel, ) diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 35fe78dd01..0b95495746 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -236,6 +236,7 @@ fn spawn_daemon( Some(resource_dir.clone()), resource_dir.clone(), resource_dir, + cache_dir.clone(), cache_dir, listener, command_channel, diff --git a/mullvad-jni/src/problem_report.rs b/mullvad-jni/src/problem_report.rs index fb83c06bdc..f1a7e884db 100644 --- a/mullvad-jni/src/problem_report.rs +++ b/mullvad-jni/src/problem_report.rs @@ -43,21 +43,21 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_dataproxy_MullvadProblemRepor userEmail: JString<'_>, userMessage: JString<'_>, outputPath: JString<'_>, - resourcesDirectory: JString<'_>, + cacheDirectory: JString<'_>, ) -> jboolean { let env = JnixEnv::from(env); let user_email = String::from_java(&env, userEmail); let user_message = String::from_java(&env, userMessage); let output_path_string = String::from_java(&env, outputPath); let output_path = Path::new(&output_path_string); - let resources_directory_string = String::from_java(&env, resourcesDirectory); - let resources_directory = Path::new(&resources_directory_string); + let cache_directory_string = String::from_java(&env, cacheDirectory); + let cache_directory = Path::new(&cache_directory_string); let send_result = mullvad_problem_report::send_problem_report( &user_email, &user_message, output_path, - resources_directory, + cache_directory, ); match send_result { diff --git a/mullvad-paths/src/cache.rs b/mullvad-paths/src/cache.rs index 35b99738f0..490fe9f702 100644 --- a/mullvad-paths/src/cache.rs +++ b/mullvad-paths/src/cache.rs @@ -33,3 +33,26 @@ pub fn get_default_cache_dir() -> Result<PathBuf> { Ok(std::path::Path::new(crate::APP_PATH).join("cache")) } } + +/// Creates and returns a cache directory that is readable by all users. +pub fn user_cache_dir() -> Result<PathBuf> { + #[cfg(not(target_os = "macos"))] + let permissions = None; + #[cfg(target_os = "macos")] + let permissions = Some(std::os::unix::fs::PermissionsExt::from_mode(0o755)); + crate::create_and_return(get_user_cache_dir, permissions) +} + +pub fn get_user_cache_dir() -> Result<PathBuf> { + #[cfg(windows)] + { + let dir = crate::get_allusersprofile_dir(); + dir.map(|dir| dir.join(crate::PRODUCT_NAME)) + } + #[cfg(target_os = "macos")] + { + Ok(std::path::Path::new("/Library/Caches").join(crate::PRODUCT_NAME)) + } + #[cfg(not(any(target_os = "macos", windows)))] + get_cache_dir() +} diff --git a/mullvad-paths/src/lib.rs b/mullvad-paths/src/lib.rs index 7b06c9937d..03453e9adc 100644 --- a/mullvad-paths/src/lib.rs +++ b/mullvad-paths/src/lib.rs @@ -56,7 +56,7 @@ fn create_and_return( } mod cache; -pub use crate::cache::{cache_dir, get_default_cache_dir}; +pub use crate::cache::{cache_dir, get_default_cache_dir, get_user_cache_dir, user_cache_dir}; mod logs; pub use crate::logs::{get_default_log_dir, get_log_dir, log_dir}; diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 276a1557b1..f4ee9d73ae 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -69,6 +69,9 @@ pub enum Error { #[error(display = "Unable to spawn Tokio runtime")] CreateRuntime(#[error(source)] io::Error), + + #[error(display = "Unable to find cache directory")] + ObtainCacheDirectory(#[error(source)] mullvad_paths::Error), } /// These are errors that can happen during problem report collection. @@ -253,7 +256,7 @@ pub fn send_problem_report( user_email: &str, user_message: &str, report_path: &Path, - resource_dir: &Path, + user_cache_dir: &Path, ) -> Result<(), Error> { let report_content = normalize_newlines( read_file_lossy(report_path, REPORT_MAX_SIZE).map_err(|source| { @@ -275,8 +278,10 @@ pub fn send_problem_report( let mut rpc_manager = runtime .block_on(mullvad_rpc::MullvadRpcRuntime::with_cache( runtime.handle().clone(), - resource_dir, None, + user_cache_dir, + false, + |_| Ok(()), )) .map_err(Error::CreateRpcClientError)?; let rpc_client = mullvad_rpc::ProblemReportProxy::new(rpc_manager.mullvad_rest_handle()); diff --git a/mullvad-problem-report/src/main.rs b/mullvad-problem-report/src/main.rs index 7090cecd6d..4f0dfb0265 100644 --- a/mullvad-problem-report/src/main.rs +++ b/mullvad-problem-report/src/main.rs @@ -113,8 +113,8 @@ fn run() -> Result<(), Error> { let report_path = Path::new(send_matches.value_of_os("report").unwrap()); let user_email = send_matches.value_of("email").unwrap_or(""); let user_message = send_matches.value_of("message").unwrap_or(""); - let resource_dir = mullvad_paths::get_resource_dir(); - send_problem_report(user_email, user_message, report_path, &resource_dir) + let user_cache_dir = mullvad_paths::get_user_cache_dir()?; + send_problem_report(user_email, user_message, report_path, &user_cache_dir) } else { unreachable!("No sub command given"); } diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs index b2181c763b..757d6645ed 100644 --- a/mullvad-rpc/src/address_cache.rs +++ b/mullvad-rpc/src/address_cache.rs @@ -3,9 +3,11 @@ use rand::seq::SliceRandom; use std::{ io, net::SocketAddr, + ops::{Deref, DerefMut}, path::Path, sync::{Arc, Mutex}, }; +use talpid_types::ErrorExt; use tokio::{ fs, io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, @@ -20,41 +22,75 @@ pub enum Error { #[error(display = "Failed to read the address cache file")] ReadAddressCache(#[error(source)] io::Error), + #[error(display = "Failed to update the address cache file")] + WriteAddressCache(#[error(source)] io::Error), + #[error(display = "The address cache is empty")] EmptyAddressCache, + + #[error(display = "The address change listener returned an error")] + ChangeListenerError, } +pub type CurrentAddressChangeListener = + dyn Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static; + #[derive(Clone)] pub struct AddressCache { inner: Arc<Mutex<AddressCacheInner>>, - cache_path: Option<Arc<Path>>, + write_path: Option<Arc<Path>>, + change_listener: Arc<Box<CurrentAddressChangeListener>>, } impl AddressCache { - /// Initialize cache using the given list, and write changes to `cache_path`. - pub fn new(addresses: Vec<SocketAddr>, cache_path: Option<Box<Path>>) -> Result<Self, Error> { - log::trace!("API address cache: {:?}", addresses); - - let cache = AddressCacheInner::from_addresses(addresses)?; + /// Initialize cache using the given list, and write changes to `write_path`. + pub fn new( + addresses: Vec<SocketAddr>, + write_path: Option<Box<Path>>, + change_listener: Arc<Box<CurrentAddressChangeListener>>, + ) -> Result<Self, Error> { + let mut cache = AddressCacheInner::from_addresses(addresses)?; + cache.shuffle_tail(); + log::trace!("API address cache: {:?}", cache.addresses); log::debug!("Using API address: {:?}", Self::get_address_inner(&cache)); let address_cache = Self { inner: Arc::new(Mutex::new(cache)), - cache_path: cache_path.map(|cache| Arc::from(cache)), + write_path: write_path.map(|cache| Arc::from(cache)), + change_listener, }; Ok(address_cache) } - /// Initialize cache using `read_path`, and write changes to `cache_path`. - pub async fn from_file(read_path: &Path, cache_path: Option<Box<Path>>) -> Result<Self, Error> { + /// Initialize cache using `read_path`, and write changes to `write_path`. + pub async fn from_file( + read_path: &Path, + write_path: Option<Box<Path>>, + change_listener: Arc<Box<CurrentAddressChangeListener>>, + ) -> Result<Self, Error> { log::debug!("Loading API addresses from {:?}", read_path); - Self::new(read_address_file(read_path).await?, cache_path) + Self::new( + read_address_file(read_path).await?, + write_path, + change_listener, + ) } + pub fn set_change_listener(&mut self, change_listener: Arc<Box<CurrentAddressChangeListener>>) { + self.change_listener = change_listener; + } + + /// Returns the currently selected address. pub fn get_address(&self) -> SocketAddr { let mut inner = self.inner.lock().unwrap(); - inner.last_try = Some(inner.choice); + inner.tried_current = true; + Self::get_address_inner(&inner) + } + /// Returns the current address without registering it as "tried" + /// in [`has_tried_current_address`]. + pub fn peek_address(&self) -> SocketAddr { + let inner = self.inner.lock().unwrap(); Self::get_address_inner(&inner) } @@ -68,38 +104,105 @@ impl AddressCache { .unwrap_or(&API_ADDRESS.into()) } - pub fn register_failure(&self, failed_addr: SocketAddr, err: &dyn std::error::Error) { - let mut inner = self.inner.lock().unwrap(); + pub fn has_tried_current_address(&self) -> bool { + let inner = self.inner.lock().unwrap(); + inner.tried_current + } - let current_address = Self::get_address_inner(&inner); - // Only choose the next server if the current one has been tried before and it failed - if failed_addr == current_address - && inner - .last_try - .map(|last_try| last_try == inner.choice) - .unwrap_or(false) + pub async fn select_new_address(&self) { { - inner.choice = inner.choice.wrapping_add(1); - let new_address = Self::get_address_inner(&inner); - log::error!( - "HTTP request failed: {}, using address {}. Trying next API address: {}", - err, - failed_addr, - new_address - ); + let mut inner = self.inner.lock().unwrap(); + let mut transaction = AddressCacheTransaction::new(&mut inner); + + transaction.choice = transaction.current.choice.wrapping_add(1); + if transaction.choice == transaction.current.choice { + return; + } + transaction.tried_current = false; + + tokio::task::block_in_place(move || { + if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() { + log::error!("Failed to select a new API endpoint"); + return; + } + transaction.commit(); + }); } + + if let Err(error) = self.save_to_disk().await { + log::error!("{}", error.display_chain()); + } + } + + /// Forgets the currently selected address and randomizes + /// the entire list. + pub async fn randomize(&self) -> Result<(), Error> { + { + let mut inner = self.inner.lock().unwrap(); + + let mut transaction = AddressCacheTransaction::new(&mut inner); + transaction.shuffle(); + transaction.choice = 0; + + let current_address = Self::get_address_inner(&transaction.current); + let new_address = Self::get_address_inner(&transaction); + + tokio::task::block_in_place(move || { + if new_address != current_address { + transaction.tried_current = false; + if (*self.change_listener)(new_address).is_err() { + return Err(Error::ChangeListenerError); + } + } + + transaction.commit(); + Ok(()) + })?; + } + self.save_to_disk().await.map_err(Error::WriteAddressCache) } pub async fn set_addresses(&self, mut addresses: Vec<SocketAddr>) -> io::Result<()> { let should_update = { let mut inner = self.inner.lock().unwrap(); + let mut transaction = AddressCacheTransaction::new(&mut inner); + addresses.sort(); - let mut current_sorted = inner.addresses.clone(); + + let mut current_sorted = transaction.addresses.clone(); current_sorted.sort(); + if addresses != current_sorted { - inner.addresses = addresses.clone(); - inner.shuffle(); - inner.choice = 0; + let current_address = Self::get_address_inner(&transaction); + + transaction.addresses = addresses.clone(); + transaction.shuffle(); + + // Prefer a likely-working address + let choice = transaction + .addresses + .iter() + .position(|&addr| addr == current_address); + if let Some(choice) = choice { + transaction.choice = choice; + transaction.commit(); + } else { + transaction.choice = 0; + transaction.tried_current = false; + + tokio::task::block_in_place(move || { + if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() { + log::error!("Failed to select a new API endpoint"); + return Err(io::Error::new( + io::ErrorKind::Other, + "callback returned an error", + )); + } + transaction.commit(); + Ok(()) + })?; + } + true } else { false @@ -107,26 +210,41 @@ impl AddressCache { }; if should_update { log::trace!("API address cache: {:?}", addresses); - self.save_to_disk(addresses).await?; + self.save_to_disk().await?; } Ok(()) } - async fn save_to_disk(&self, addresses: Vec<SocketAddr>) -> io::Result<()> { - if let Some(cache_path) = self.cache_path.as_ref() { - let mut file = fs::File::create(cache_path).await?; - let mut contents = addresses - .iter() - .map(ToString::to_string) - .collect::<Vec<String>>() - .join("\n"); - contents += "\n"; + async fn save_to_disk(&self) -> io::Result<()> { + let write_path = match self.write_path.as_ref() { + Some(write_path) => write_path, + None => return Ok(()), + }; - file.write_all(contents.as_bytes()).await?; - file.sync_data().await?; + let (mut addresses, choice) = { + let inner = self.inner.lock().unwrap(); + (inner.addresses.clone(), inner.choice) + }; + + // Place the current choice on top + if !addresses.is_empty() { + let addresses_len = addresses.len(); + addresses.swap(0, choice % addresses_len); } - Ok(()) + let temp_path = write_path.with_file_name("api-cache.temp"); + + let mut file = fs::File::create(&temp_path).await?; + let mut contents = addresses + .iter() + .map(ToString::to_string) + .collect::<Vec<String>>() + .join("\n"); + contents += "\n"; + file.write_all(contents.as_bytes()).await?; + file.sync_data().await?; + + fs::rename(&temp_path, write_path).await } } @@ -141,10 +259,11 @@ impl crate::rest::AddressProvider for AddressCache { } +#[derive(Clone, PartialEq, Eq)] struct AddressCacheInner { addresses: Vec<SocketAddr>, choice: usize, - last_try: Option<usize>, + tried_current: bool, } impl AddressCacheInner { @@ -152,19 +271,55 @@ impl AddressCacheInner { if addresses.is_empty() { return Err(Error::EmptyAddressCache); } - let mut cache = Self { + Ok(Self { addresses, choice: 0, - last_try: None, - }; - cache.shuffle(); - Ok(cache) + tried_current: false, + }) } fn shuffle(&mut self) { let mut rng = rand::thread_rng(); (&mut self.addresses[..]).shuffle(&mut rng); } + + /// Shuffle all but the first element + fn shuffle_tail(&mut self) { + let mut rng = rand::thread_rng(); + (&mut self.addresses[1..]).shuffle(&mut rng); + } +} + +struct AddressCacheTransaction<'a> { + current: &'a mut AddressCacheInner, + working_cache: AddressCacheInner, +} + +impl<'a> AddressCacheTransaction<'a> { + fn new(cache: &'a mut AddressCacheInner) -> Self { + Self { + working_cache: cache.clone(), + current: cache, + } + } + + fn commit(self) { + *self.current = self.working_cache; + } +} + +impl<'a> Deref for AddressCacheTransaction<'a> { + type Target = AddressCacheInner; + + fn deref(&self) -> &Self::Target { + &self.working_cache + } +} + +impl<'a> DerefMut for AddressCacheTransaction<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.working_cache + } } async fn read_address_file(path: &Path) -> Result<Vec<SocketAddr>, Error> { @@ -178,7 +333,6 @@ async fn read_address_file(path: &Path) -> Result<Vec<SocketAddr>, Error> { .await .map_err(|error| Error::ReadAddressCache(error))? { - // for line in lines.next_line() { match line.trim().parse() { Ok(address) => addresses.push(address), Err(err) => { diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index d92be6e8be..4beefa7d3b 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -11,6 +11,7 @@ use std::{ future::Future, net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, + sync::Arc, }; use talpid_types::{net::wireguard, ErrorExt}; @@ -22,7 +23,7 @@ use crate::https_client_with_sni::HttpsConnectorWithSni; mod address_cache; mod relay_list; -use address_cache::AddressCache; +pub use address_cache::{AddressCache, CurrentAddressChangeListener}; pub use hyper::StatusCode; pub use relay_list::RelayListProxy; @@ -42,7 +43,7 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443); pub struct MullvadRpcRuntime { https_connector: HttpsConnectorWithSni, handle: tokio::runtime::Handle, - address_cache: AddressCache, + pub address_cache: AddressCache, } #[derive(err_derive::Error, Debug)] @@ -60,7 +61,11 @@ impl MullvadRpcRuntime { Ok(MullvadRpcRuntime { https_connector: HttpsConnectorWithSni::new(), handle, - address_cache: AddressCache::new(vec![API_ADDRESS.into()], None)?, + address_cache: AddressCache::new( + vec![API_ADDRESS.into()], + None, + Arc::new(Box::new(|_| Ok(()))), + )?, }) } @@ -69,31 +74,55 @@ impl MullvadRpcRuntime { /// if it fails. pub async fn with_cache( handle: tokio::runtime::Handle, - resource_dir: &Path, - cache_dir: Option<&Path>, + resource_dir: Option<&Path>, + cache_dir: &Path, + write_changes: bool, + address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static, ) -> Result<Self, Error> { - let resource_file = resource_dir.join(API_IP_CACHE_FILENAME); + let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); + let write_file = if write_changes { + Some(cache_file.clone().into_boxed_path()) + } else { + None + }; + + let address_change_listener = + Arc::<Box<CurrentAddressChangeListener>>::new(Box::new(address_change_listener)); - let address_cache = if let Some(cache_dir) = cache_dir { - let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); - let cache_file_boxed = cache_file.clone().into_boxed_path(); + let address_cache = match AddressCache::from_file( + &cache_file, + write_file.clone(), + address_change_listener.clone(), + ) + .await + { + Ok(cache) => cache, + Err(error) => { + let cache_exists = cache_file.exists(); + if cache_exists { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to load cached API addresses. Falling back on bundled list" + ) + ); + } - match AddressCache::from_file(&cache_file, Some(cache_file_boxed.clone())).await { - Ok(cache) => cache, - Err(error) => { - if cache_file.exists() { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to load cached API addresses. Falling back on bundled list" - ) - ); + // Initialize the cache directory cache using the resource directory + match resource_dir { + Some(resource_dir) => { + let read_file = resource_dir.join(API_IP_CACHE_FILENAME); + let empty_listener = + Arc::<Box<CurrentAddressChangeListener>>::new(Box::new(|_| Ok(()))); + let mut cache = + AddressCache::from_file(&read_file, write_file, empty_listener).await?; + cache.randomize().await?; + cache.set_change_listener(address_change_listener); + cache } - AddressCache::from_file(&resource_file, Some(cache_file_boxed)).await? + None => return Err(Error::AddressCacheError(error)), } } - } else { - AddressCache::from_file(&resource_file, None).await? }; let https_connector = HttpsConnectorWithSni::new(); diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 2bf8511d24..12bc10263e 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -139,7 +139,20 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { if let Err(err) = &response { match err { Error::HyperError(_) | Error::TimeoutError(_) => { - address_cache.register_failure(host_addr, err); + let current_address = address_cache.peek_address(); + if current_address == host_addr + && address_cache.has_tried_current_address() + { + address_cache.select_new_address().await; + let new_address = address_cache.peek_address(); + + log::error!( + "HTTP request failed: {}, using address {}. Trying next API address: {}", + err, + current_address, + new_address, + ); + } } _ => (), } diff --git a/mullvad-setup/src/daemon_paths.rs b/mullvad-setup/src/daemon_paths.rs index 6f00952703..fce81ede8c 100644 --- a/mullvad-setup/src/daemon_paths.rs +++ b/mullvad-setup/src/daemon_paths.rs @@ -15,9 +15,7 @@ use winapi::{ um::{ combaseapi::CoTaskMemFree, handleapi::CloseHandle, - knownfolders::{ - FOLDERID_LocalAppData, FOLDERID_ProgramFiles, FOLDERID_RoamingAppData, FOLDERID_System, - }, + knownfolders::{FOLDERID_LocalAppData, FOLDERID_System}, processthreadsapi::{GetCurrentThread, OpenProcess, OpenProcessToken, OpenThreadToken}, psapi::K32EnumProcesses, securitybaseapi::{AdjustTokenPrivileges, ImpersonateSelf, RevertToSelf}, @@ -37,16 +35,6 @@ pub fn get_mullvad_daemon_settings_path() -> io::Result<PathBuf> { .map(|settings| settings.join(mullvad_paths::PRODUCT_NAME)) } -pub fn get_mullvad_resource_path() -> io::Result<PathBuf> { - get_known_folder_path(FOLDERID_ProgramFiles, KF_FLAG_DEFAULT, ptr::null_mut()) - .map(|settings| settings.join(mullvad_paths::PRODUCT_NAME).join("resources")) -} - -pub fn get_mullvad_daemon_cache_path() -> io::Result<PathBuf> { - get_system_service_known_folder(FOLDERID_RoamingAppData) - .map(|settings| settings.join(mullvad_paths::PRODUCT_NAME)) -} - /// Get local AppData path for the system service user. Requires elevated privileges to work. /// Useful for deducing the config path for the daemon on Windows when running as a user that diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 2e109089a4..c5ed2d38f0 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -24,11 +24,11 @@ enum ExitStatus { #[cfg(windows)] mod daemon_paths; -#[cfg(not(windows))] -type PathError = mullvad_paths::Error; - #[cfg(windows)] -type PathError = std::io::Error; +type SettingsPathErrorType = std::io::Error; + +#[cfg(not(windows))] +type SettingsPathErrorType = mullvad_paths::Error; #[derive(err_derive::Error, Debug)] #[error(no_from)] @@ -49,14 +49,10 @@ pub enum Error { RpcInitializationError(#[error(source)] mullvad_rpc::Error), #[error(display = "Failed to obtain settings directory path")] - SettingsPathError(#[error(source)] PathError), - - #[cfg(windows)] - #[error(display = "Failed to obtain resource directory path")] - ResourcePathError(#[error(source)] PathError), + SettingsPathError(#[error(source)] SettingsPathErrorType), #[error(display = "Failed to obtain cache directory path")] - CachePathError(#[error(source)] PathError), + CachePathError(#[error(source)] mullvad_paths::Error), #[error(display = "Failed to initialize account history")] InitializeAccountHistoryError(#[error(source)] account_history::Error), @@ -148,6 +144,7 @@ async fn reset_firewall() -> Result<(), Error> { let mut firewall = Firewall::new(FirewallArguments { initialize_blocked: false, allow_lan: true, + allowed_endpoint: None, }) .map_err(Error::FirewallError)?; @@ -155,18 +152,20 @@ async fn reset_firewall() -> Result<(), Error> { } async fn clear_history() -> Result<(), Error> { - let (cache_path, resource_path, settings_path) = get_paths()?; + let (user_cache_path, settings_path) = get_paths()?; let mut rpc_runtime = MullvadRpcRuntime::with_cache( tokio::runtime::Handle::current(), - &resource_path, - Some(&cache_path), + None, + &user_cache_path, + false, + |_| Ok(()), ) .await .map_err(Error::RpcInitializationError)?; let mut account_history = account_history::AccountHistory::new( - &cache_path, + &user_cache_path, &settings_path, rpc_runtime.mullvad_rest_handle(), ) @@ -180,21 +179,16 @@ async fn clear_history() -> Result<(), Error> { } #[cfg(not(windows))] -fn get_paths() -> Result<(PathBuf, PathBuf, PathBuf), Error> { - let cache_path = mullvad_paths::cache_dir().map_err(Error::CachePathError)?; - let resource_path = mullvad_paths::get_resource_dir(); +fn get_paths() -> Result<(PathBuf, PathBuf), Error> { + let user_cache_path = mullvad_paths::user_cache_dir().map_err(Error::CachePathError)?; let settings_path = mullvad_paths::settings_dir().map_err(Error::SettingsPathError)?; - Ok((cache_path, resource_path, settings_path)) + Ok((user_cache_path, settings_path)) } #[cfg(windows)] -fn get_paths() -> Result<(PathBuf, PathBuf, PathBuf), Error> { +fn get_paths() -> Result<(PathBuf, PathBuf), Error> { + let user_cache_path = mullvad_paths::user_cache_dir().map_err(Error::CachePathError)?; let settings_path = - daemon_paths::get_mullvad_daemon_settings_path().map_err(Error::CachePathError)?; - let resource_path = - daemon_paths::get_mullvad_resource_path().map_err(Error::ResourcePathError)?; - let cache_path = - daemon_paths::get_mullvad_daemon_cache_path().map_err(Error::SettingsPathError)?; - - Ok((cache_path, resource_path, settings_path)) + daemon_paths::get_mullvad_daemon_settings_path().map_err(Error::SettingsPathError)?; + Ok((user_cache_path, settings_path)) } diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 3c252313ce..04bd00777d 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -531,10 +531,12 @@ impl<'a> PolicyBatch<'a> { peer_endpoint, pingable_hosts, allow_lan, + allowed_endpoint, use_fwmark, } => { self.add_allow_icmp_pingable_hosts(&pingable_hosts); - self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_endpoint_rules(allowed_endpoint); // Important to block DNS after allow relay rule (so the relay can operate // over port 53) but before allow LAN (so DNS does not leak to the LAN) @@ -548,7 +550,7 @@ impl<'a> PolicyBatch<'a> { dns_servers, use_fwmark, } => { - self.add_allow_endpoint_rules(peer_endpoint, *use_fwmark); + self.add_allow_tunnel_endpoint_rules(peer_endpoint, *use_fwmark); self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Udp)?; self.add_allow_dns_rules(tunnel, &dns_servers, TransportProtocol::Tcp)?; // Important to block DNS *before* we allow the tunnel and allow LAN. So DNS @@ -560,7 +562,12 @@ impl<'a> PolicyBatch<'a> { } *allow_lan } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { + self.add_allow_endpoint_rules(allowed_endpoint); + // Important to drop DNS before allowing LAN (to stop DNS leaking to the LAN) self.add_drop_dns_rule(); *allow_lan @@ -582,7 +589,7 @@ impl<'a> PolicyBatch<'a> { Ok(()) } - fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) { + fn add_allow_tunnel_endpoint_rules(&mut self, endpoint: &Endpoint, use_fwmark: bool) { let mut in_rule = Rule::new(&self.in_chain); check_endpoint(&mut in_rule, End::Src, endpoint); @@ -608,6 +615,20 @@ impl<'a> PolicyBatch<'a> { self.batch.add(&out_rule, nftnl::MsgType::Add); } + fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) { + let mut in_rule = Rule::new(&self.in_chain); + check_endpoint(&mut in_rule, End::Src, endpoint); + add_verdict(&mut in_rule, &Verdict::Accept); + + self.batch.add(&in_rule, nftnl::MsgType::Add); + + let mut out_rule = Rule::new(&self.out_chain); + check_endpoint(&mut out_rule, End::Dst, endpoint); + add_verdict(&mut out_rule, &Verdict::Accept); + + self.batch.add(&out_rule, nftnl::MsgType::Add); + } + fn add_allow_icmp_pingable_hosts(&mut self, pingable_hosts: &[IpAddr]) { for host in pingable_hosts { let icmp_proto = match &host { diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs index dfdc1e31fc..2e23c99dd2 100644 --- a/talpid-core/src/firewall/macos.rs +++ b/talpid-core/src/firewall/macos.rs @@ -98,9 +98,11 @@ impl Firewall { FirewallPolicy::Connecting { peer_endpoint, allow_lan, + allowed_endpoint, pingable_hosts, } => { let mut rules = vec![self.get_allow_relay_rule(peer_endpoint)?]; + rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?); rules.extend(self.get_allow_pingable_hosts(&pingable_hosts)?); if allow_lan { // Important to block DNS after allow relay rule (so the relay can operate @@ -136,8 +138,12 @@ impl Firewall { Ok(rules) } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { let mut rules = Vec::new(); + rules.push(self.get_allowed_endpoint_rule(allowed_endpoint)?); if allow_lan { // Important to block DNS before allow LAN (so DNS does not leak to the LAN) rules.append(&mut self.get_block_dns_rules()?); @@ -247,6 +253,22 @@ impl Firewall { .build()?) } + fn get_allowed_endpoint_rule( + &self, + allowed_endpoint: net::Endpoint, + ) -> Result<pfctl::FilterRule> { + let pfctl_proto = as_pfctl_proto(allowed_endpoint.protocol); + + Ok(self + .create_rule_builder(FilterRuleAction::Pass) + .direction(pfctl::Direction::Out) + .to(allowed_endpoint.address) + .proto(pfctl_proto) + .keep_state(pfctl::StatePolicy::Keep) + .quick(true) + .build()?) + } + fn get_block_dns_rules(&self) -> Result<Vec<pfctl::FilterRule>> { let block_tcp_dns_rule = self .create_rule_builder(FilterRuleAction::Drop(DropAction::Return)) diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index b467f37d98..83a112ce88 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -107,6 +107,8 @@ pub enum FirewallPolicy { pingable_hosts: Vec<IpAddr>, /// Flag setting if communication with LAN networks should be possible. allow_lan: bool, + /// Host that should be reachable by the tunnel client while connecting. + allowed_endpoint: Endpoint, /// A process that is allowed to send packets to the relay. #[cfg(windows)] relay_client: PathBuf, @@ -140,6 +142,8 @@ pub enum FirewallPolicy { Blocked { /// Flag setting if communication with LAN networks should be possible. allow_lan: bool, + /// Host that should be reachable while in the blocked state. + allowed_endpoint: Endpoint, }, } @@ -182,10 +186,14 @@ impl fmt::Display for FirewallPolicy { tunnel.ipv6_gateway, if *allow_lan { "Allowing" } else { "Blocking" } ), - FirewallPolicy::Blocked { allow_lan } => write!( + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => write!( f, - "Blocked, {} LAN", - if *allow_lan { "Allowing" } else { "Blocking" } + "Blocked. {} LAN. Allowing endpoint {}", + if *allow_lan { "Allowing" } else { "Blocking" }, + allowed_endpoint, ), } } @@ -203,6 +211,8 @@ pub struct FirewallArguments { pub initialize_blocked: bool, /// This argument is required for the blocked state to configure the firewall correctly. pub allow_lan: bool, + /// This argument is required for the blocked state to configure the firewall correctly. + pub allowed_endpoint: Option<Endpoint>, } impl Firewall { diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index d1fa08a3e6..8375eb55d6 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -57,10 +57,23 @@ impl FirewallT for Firewall { if args.initialize_blocked { let cfg = &WinFwSettings::new(args.allow_lan); + + let winfw_allowed_endpoint = if let Some(allowed_endpoint) = args.allowed_endpoint { + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }) + } else { + None + }; + unsafe { WinFw_InitializeBlocked( WINFW_TIMEOUT_SECONDS, &cfg, + winfw_allowed_endpoint.as_ptr(), Some(log_sink), logging_context, ) @@ -83,6 +96,7 @@ impl FirewallT for Firewall { peer_endpoint, pingable_hosts, allow_lan, + allowed_endpoint, relay_client, } => { let cfg = &WinFwSettings::new(allow_lan); @@ -91,6 +105,7 @@ impl FirewallT for Firewall { &peer_endpoint, &cfg, "Mullvad".to_string(), + &allowed_endpoint, &pingable_hosts, &relay_client, ) @@ -105,9 +120,12 @@ impl FirewallT for Firewall { let cfg = &WinFwSettings::new(allow_lan); self.set_connected_state(&peer_endpoint, &cfg, &tunnel, &dns_servers, &relay_client) } - FirewallPolicy::Blocked { allow_lan } => { + FirewallPolicy::Blocked { + allow_lan, + allowed_endpoint, + } => { let cfg = &WinFwSettings::new(allow_lan); - self.set_blocked_state(&cfg) + self.set_blocked_state(&cfg, &allowed_endpoint) } } } @@ -138,12 +156,13 @@ impl Firewall { endpoint: &Endpoint, winfw_settings: &WinFwSettings, _tunnel_iface_alias: String, + allowed_endpoint: &Endpoint, pingable_hosts: &Vec<IpAddr>, relay_client: &Path, ) -> Result<(), Error> { trace!("Applying 'connecting' firewall policy"); let ip_str = Self::widestring_ip(endpoint.address.ip()); - let winfw_relay = WinFwRelay { + let winfw_relay = WinFwEndpoint { ip: ip_str.as_ptr(), port: endpoint.address.port(), protocol: WinFwProt::from(endpoint.protocol), @@ -171,12 +190,20 @@ impl Firewall { None }; + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + let winfw_allowed_endpoint = Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }); + unsafe { WinFw_ApplyPolicyConnecting( winfw_settings, &winfw_relay, relay_client.as_ptr(), pingable_hosts.as_ptr(), + winfw_allowed_endpoint.as_ptr(), ) .into_result() .map_err(Error::ApplyingConnectingPolicy) @@ -207,7 +234,7 @@ impl Firewall { WideCString::new(tunnel_metadata.interface.encode_utf16().collect::<Vec<_>>()).unwrap(); // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay - let winfw_relay = WinFwRelay { + let winfw_relay = WinFwEndpoint { ip: ip_str.as_ptr(), port: endpoint.address.port(), protocol: WinFwProt::from(endpoint.protocol), @@ -258,10 +285,22 @@ impl Firewall { } } - fn set_blocked_state(&mut self, winfw_settings: &WinFwSettings) -> Result<(), Error> { + fn set_blocked_state( + &mut self, + winfw_settings: &WinFwSettings, + allowed_endpoint: &Endpoint, + ) -> Result<(), Error> { trace!("Applying 'blocked' firewall policy"); + + let allowed_endpoint_ip = Self::widestring_ip(allowed_endpoint.address.ip()); + let winfw_allowed_endpoint = Some(WinFwEndpoint { + ip: allowed_endpoint_ip.as_ptr(), + port: allowed_endpoint.address.port(), + protocol: WinFwProt::from(allowed_endpoint.protocol), + }); + unsafe { - WinFw_ApplyPolicyBlocked(winfw_settings) + WinFw_ApplyPolicyBlocked(winfw_settings, winfw_allowed_endpoint.as_ptr()) .into_result() .map_err(Error::ApplyingBlockedPolicy) } @@ -289,7 +328,7 @@ mod winfw { use talpid_types::net::TransportProtocol; #[repr(C)] - pub struct WinFwRelay { + pub struct WinFwEndpoint { pub ip: *const libc::wchar_t, pub port: u16, pub protocol: WinFwProt, @@ -385,6 +424,7 @@ mod winfw { pub fn WinFw_InitializeBlocked( timeout: libc::c_uint, settings: &WinFwSettings, + allowed_endpoint: *const WinFwEndpoint, sink: Option<LogSink>, sink_context: *const u8, ) -> InitializationResult; @@ -395,15 +435,16 @@ mod winfw { #[link_name = "WinFw_ApplyPolicyConnecting"] pub fn WinFw_ApplyPolicyConnecting( settings: &WinFwSettings, - relay: &WinFwRelay, + relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, pingable_hosts: *const WinFwPingableHosts, + allowed_endpoint: *const WinFwEndpoint, ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyConnected"] pub fn WinFw_ApplyPolicyConnected( settings: &WinFwSettings, - relay: &WinFwRelay, + relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, tunnelIfaceAlias: *const libc::wchar_t, v4Gateway: *const libc::wchar_t, @@ -413,7 +454,10 @@ mod winfw { ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyBlocked"] - pub fn WinFw_ApplyPolicyBlocked(settings: &WinFwSettings) -> WinFwPolicyStatus; + pub fn WinFw_ApplyPolicyBlocked( + settings: &WinFwSettings, + allowed_endpoint: *const WinFwEndpoint, + ) -> WinFwPolicyStatus; #[link_name = "WinFw_Reset"] pub fn WinFw_Reset() -> WinFwPolicyStatus; diff --git a/talpid-core/src/tunnel/tun_provider/android/mod.rs b/talpid-core/src/tunnel/tun_provider/android/mod.rs index b9385f13a7..fa48f115b9 100644 --- a/talpid-core/src/tunnel/tun_provider/android/mod.rs +++ b/talpid-core/src/tunnel/tun_provider/android/mod.rs @@ -66,6 +66,7 @@ pub struct AndroidTunProvider { object: GlobalRef, last_tun_config: TunConfig, allow_lan: bool, + allowed_endpoint: IpAddr, custom_dns_servers: Option<Vec<IpAddr>>, } @@ -74,6 +75,7 @@ impl AndroidTunProvider { pub fn new( context: AndroidContext, allow_lan: bool, + allowed_endpoint: IpAddr, custom_dns_servers: Option<Vec<IpAddr>>, ) -> Self { let env = JnixEnv::from( @@ -90,6 +92,7 @@ impl AndroidTunProvider { object: context.vpn_service, last_tun_config: TunConfig::default(), allow_lan, + allowed_endpoint, custom_dns_servers, } } @@ -103,6 +106,10 @@ impl AndroidTunProvider { Ok(()) } + pub fn set_allowed_endpoint(&mut self, endpoint: IpAddr) { + self.allowed_endpoint = endpoint; + } + pub fn set_custom_dns_servers(&mut self, servers: Option<Vec<IpAddr>>) -> Result<(), Error> { if self.custom_dns_servers != servers { self.custom_dns_servers = servers; @@ -129,6 +136,19 @@ impl AndroidTunProvider { }) } + /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and (potentially) + /// LAN routes via the tunnel device. + /// + /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be + /// closed. + pub fn create_blocking_tun(&mut self) -> Result<(), Error> { + let mut config = TunConfig::default(); + self.prepare_tun_config(&mut config); + self.prepare_tun_config_for_allowed_endpoint(&mut config); + let _ = self.get_tun(config)?; + Ok(()) + } + /// Open a tunnel device using the previous or the default configuration. /// /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be @@ -231,6 +251,24 @@ impl AndroidTunProvider { } } + fn prepare_tun_config_for_allowed_endpoint(&self, config: &mut TunConfig) { + let endpoint_net = IpNetwork::from(self.allowed_endpoint); + let routes = config + .routes + .iter() + .flat_map(|&route| { + if route.is_ipv4() && endpoint_net.is_ipv4() { + route.sub(endpoint_net).collect() + } else if route.is_ipv6() && endpoint_net.is_ipv6() { + route.sub(endpoint_net).collect() + } else { + vec![route] + } + }) + .collect(); + config.routes = routes; + } + fn prepare_tun_config(&self, config: &mut TunConfig) { self.prepare_tun_config_for_allow_lan(config); self.prepare_tun_config_for_custom_dns(config); diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 7292da0c67..0c305de9a7 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -192,6 +192,13 @@ impl ConnectedState { } } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { match shared_values.set_custom_dns(servers) { Ok(true) => { diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 44dcd9f153..0b03ceeca1 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -63,6 +63,7 @@ impl ConnectingState { peer_endpoint, pingable_hosts: gateway_list_from_params(params), allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), #[cfg(windows)] relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, ¶ms), #[cfg(target_os = "linux")] @@ -235,6 +236,22 @@ impl ConnectingState { } } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + if let Err(error) = + Self::set_firewall_policy(shared_values, &self.tunnel_parameters) + { + return self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), + ); + } + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { match shared_values.set_custom_dns(servers) { #[cfg(target_os = "android")] diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index dcc4660e9f..922eb69c88 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -17,6 +17,7 @@ impl DisconnectedState { let result = if shared_values.block_when_disconnected { let policy = FirewallPolicy::Blocked { allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), }; shared_values.firewall.apply_policy(policy).map_err(|e| { e.display_chain_with_msg( @@ -77,6 +78,15 @@ impl TunnelState for DisconnectedState { } SameState(self.into()) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + Self::set_firewall_policy(shared_values, true); + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { // Same situation as allow LAN above. shared_values diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 0928834d1c..48a83a6dc3 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -32,6 +32,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Nothing } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Nothing + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Nothing @@ -53,6 +60,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Block(reason) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Block(reason) + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Block(reason) @@ -79,6 +93,13 @@ impl DisconnectingState { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Reconnect(retry_attempt) } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + let _ = shared_values.set_allowed_endpoint(endpoint); + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + AfterDisconnect::Reconnect(retry_attempt) + } Some(TunnelCommand::CustomDns(servers)) => { let _ = shared_values.set_custom_dns(servers); AfterDisconnect::Reconnect(retry_attempt) diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index a87dccd5b4..51159d274f 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -21,6 +21,7 @@ impl ErrorState { ) -> Result<(), FirewallPolicyError> { let policy = FirewallPolicy::Blocked { allow_lan: shared_values.allow_lan, + allowed_endpoint: shared_values.allowed_endpoint.clone(), }; #[cfg(target_os = "linux")] @@ -47,7 +48,7 @@ impl ErrorState { /// Returns true if a new tunnel device was successfully created. #[cfg(target_os = "android")] fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool { - match shared_values.tun_provider.create_tun_if_closed() { + match shared_values.tun_provider.create_blocking_tun() { Ok(()) => true, Err(error) => { log::error!( @@ -105,6 +106,23 @@ impl TunnelState for ErrorState { SameState(self.into()) } } + Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { + if shared_values.set_allowed_endpoint(endpoint) { + let _ = Self::set_firewall_policy(shared_values); + + #[cfg(target_os = "android")] + if !Self::create_blocking_tun(shared_values) { + return NewState(Self::enter( + shared_values, + ErrorStateCause::SetFirewallPolicyError(FirewallPolicyError::Generic), + )); + } + } + if let Err(_) = tx.send(()) { + log::error!("The AllowEndpoint receiver was dropped"); + } + SameState(self.into()) + } Some(TunnelCommand::CustomDns(servers)) => { if let Err(error_state_cause) = shared_values.set_custom_dns(servers) { NewState(Self::enter(shared_values, error_state_cause)) diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index fbec1bf2b1..b657ec5e36 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -33,7 +33,7 @@ use std::{ #[cfg(target_os = "android")] use talpid_types::{android::AndroidContext, ErrorExt}; use talpid_types::{ - net::TunnelParameters, + net::{Endpoint, TunnelParameters}, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, }; @@ -75,6 +75,7 @@ pub async fn spawn( allow_lan: bool, block_when_disconnected: bool, custom_dns: Option<Vec<IpAddr>>, + allowed_endpoint: Endpoint, tunnel_parameters_generator: impl TunnelParametersGenerator, log_dir: Option<PathBuf>, resource_dir: PathBuf, @@ -101,6 +102,8 @@ pub async fn spawn( #[cfg(target_os = "android")] allow_lan, #[cfg(target_os = "android")] + allowed_endpoint.address.ip(), + #[cfg(target_os = "android")] custom_dns.clone(), ); @@ -114,6 +117,7 @@ pub async fn spawn( block_when_disconnected, is_offline, custom_dns, + allowed_endpoint, tunnel_parameters_generator, tun_provider, log_dir, @@ -152,6 +156,9 @@ pub async fn spawn( pub enum TunnelCommand { /// Enable or disable LAN access in the firewall. AllowLan(bool), + /// Endpoint that should never be blocked. + /// If an error occurs, the sender is dropped. + AllowEndpoint(Endpoint, oneshot::Sender<()>), /// Set custom DNS servers to use. CustomDns(Option<Vec<IpAddr>>), /// Enable or disable the block_when_disconnected feature. @@ -193,6 +200,7 @@ impl TunnelStateMachine { block_when_disconnected: bool, is_offline: bool, custom_dns: Option<Vec<IpAddr>>, + allowed_endpoint: Endpoint, tunnel_parameters_generator: impl TunnelParametersGenerator, tun_provider: TunProvider, log_dir: Option<PathBuf>, @@ -204,6 +212,7 @@ impl TunnelStateMachine { let args = FirewallArguments { initialize_blocked: block_when_disconnected || !reset_firewall, allow_lan, + allowed_endpoint: Some(allowed_endpoint), }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; @@ -218,6 +227,7 @@ impl TunnelStateMachine { block_when_disconnected, is_offline, custom_dns, + allowed_endpoint, tunnel_parameters_generator: Box::new(tunnel_parameters_generator), tun_provider, log_dir, @@ -291,6 +301,8 @@ struct SharedTunnelStateValues { is_offline: bool, /// Custom DNS servers to use. custom_dns: Option<Vec<IpAddr>>, + /// Endpoint that should not be blocked by the firewall. + allowed_endpoint: Endpoint, /// The generator of new `TunnelParameter`s tunnel_parameters_generator: Box<dyn TunnelParametersGenerator>, /// The provider of tunnel devices. @@ -328,6 +340,20 @@ impl SharedTunnelStateValues { Ok(()) } + pub fn set_allowed_endpoint(&mut self, endpoint: Endpoint) -> bool { + if self.allowed_endpoint != endpoint { + self.allowed_endpoint = endpoint; + + #[cfg(target_os = "android")] + self.tun_provider + .set_allowed_endpoint(endpoint.address.ip()); + + true + } else { + false + } + } + pub fn set_custom_dns( &mut self, custom_dns: Option<Vec<IpAddr>>, diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 15bf33a5bb..29a871ee07 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -144,6 +144,10 @@ impl Endpoint { protocol, } } + + pub fn from_socket_address(address: SocketAddr, protocol: TransportProtocol) -> Self { + Endpoint { address, protocol } + } } impl fmt::Display for Endpoint { diff --git a/windows/nsis-plugins/src/cleanup/cleaningops.cpp b/windows/nsis-plugins/src/cleanup/cleaningops.cpp index 475c78d8b7..1a711bf3d8 100644 --- a/windows/nsis-plugins/src/cleanup/cleaningops.cpp +++ b/windows/nsis-plugins/src/cleanup/cleaningops.cpp @@ -333,14 +333,14 @@ void RemoveRelayCacheServiceUser() void RemoveApiAddressCacheServiceUser() { - const auto localAppData = GetSystemUserLocalAppData(); - const auto mullvadAppData = std::filesystem::path(localAppData).append(L"Mullvad VPN"); + const auto programData = common::fs::GetKnownFolderPath(FOLDERID_ProgramData, KF_FLAG_DEFAULT, nullptr); + const auto mullvadProgramData = std::filesystem::path(programData).append(L"Mullvad VPN"); common::fs::ScopedNativeFileSystem nativeFileSystem; - common::security::AddAdminToObjectDacl(mullvadAppData, SE_FILE_OBJECT); + common::security::AddAdminToObjectDacl(mullvadProgramData, SE_FILE_OBJECT); - const auto cacheFile = std::filesystem::path(mullvadAppData).append(L"api-ip-address.txt"); + const auto cacheFile = std::filesystem::path(mullvadProgramData).append(L"api-ip-address.txt"); std::filesystem::remove(cacheFile); } diff --git a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp index a722501832..7f5b9d980c 100644 --- a/windows/winfw/src/extras/cli/commands/winfw/policy.cpp +++ b/windows/winfw/src/extras/cli/commands/winfw/policy.cpp @@ -26,9 +26,9 @@ WinFwProtocol TranslateProtocol(const std::wstring &protocol) return (0 == _wcsicmp(protocol.c_str(), L"tcp") ? WinFwProtocol::Tcp : WinFwProtocol::Udp); } -WinFwRelay CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol) +WinFwEndpoint CreateRelay(const wchar_t *ip, const std::wstring &port, const std::wstring &protocol) { - WinFwRelay r; + WinFwEndpoint r; r.ip = ip; r.port = common::string::LexicalCast<uint16_t>(port); diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index b8959cfc01..793f8c917d 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -15,6 +15,7 @@ #include "rules/baseline/permitvpntunnelservice.h" #include "rules/baseline/permitping.h" #include "rules/baseline/permitdns.h" +#include "rules/baseline/permitendpoint.h" #include "rules/dns/blockall.h" #include "rules/dns/permittunnel.h" #include "rules/dns/permitnontunnel.h" @@ -30,19 +31,6 @@ using namespace rules; namespace { -multi::PermitVpnRelay::Protocol TranslateProtocol(WinFwProtocol protocol) -{ - switch (protocol) - { - case Tcp: return multi::PermitVpnRelay::Protocol::Tcp; - case Udp: return multi::PermitVpnRelay::Protocol::Udp; - default: - { - THROW_ERROR("Missing case handler in switch clause"); - } - }; -} - // // Since the PermitLan rule doesn't specifically address DNS, it will allow DNS requests targetting // a local resolver to leave the machine. From the local resolver the request will either be @@ -91,7 +79,7 @@ void AppendSettingsRules void AppendRelayRules ( FwContext::Ruleset &ruleset, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient ) { @@ -105,12 +93,28 @@ void AppendRelayRules ruleset.emplace_back(std::make_unique<multi::PermitVpnRelay>( wfp::IpAddress(relay.ip), relay.port, - TranslateProtocol(relay.protocol), + relay.protocol, relayClient, sublayer )); } +// +// Refer comment on `AppendSettingsRules`. +// +void AppendAllowedEndpointRules +( + FwContext::Ruleset &ruleset, + const WinFwEndpoint &endpoint +) +{ + ruleset.emplace_back(std::make_unique<baseline::PermitEndpoint>( + wfp::IpAddress(endpoint.ip), + endpoint.port, + endpoint.protocol + )); +} + void AppendNetBlockedRules(FwContext::Ruleset &ruleset) { ruleset.emplace_back(std::make_unique<baseline::BlockAll>()); @@ -145,7 +149,8 @@ FwContext::FwContext FwContext::FwContext ( uint32_t timeout, - const WinFwSettings &settings + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint ) : m_baseline(0) , m_activePolicy(Policy::None) @@ -159,7 +164,7 @@ FwContext::FwContext uint32_t checkpoint = 0; - if (false == applyBlockedBaseConfiguration(settings, checkpoint)) + if (false == applyBlockedBaseConfiguration(settings, allowedEndpoint, checkpoint)) { THROW_ERROR("Failed to apply base configuration in BFE"); } @@ -171,9 +176,10 @@ FwContext::FwContext bool FwContext::applyPolicyConnecting ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, - const std::optional<PingableHosts> &pingableHosts + const std::optional<PingableHosts> &pingableHosts, + const std::optional<WinFwEndpoint> &allowedEndpoint ) { Ruleset ruleset; @@ -182,6 +188,11 @@ bool FwContext::applyPolicyConnecting AppendSettingsRules(ruleset, settings); AppendRelayRules(ruleset, relay, relayClient); + if (allowedEndpoint.has_value()) + { + AppendAllowedEndpointRules(ruleset, allowedEndpoint.value()); + } + // // Permit pinging the gateway inside the tunnel. // @@ -208,7 +219,7 @@ bool FwContext::applyPolicyConnecting bool FwContext::applyPolicyConnected ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, const std::wstring &tunnelInterfaceAlias, const std::vector<wfp::IpAddress> &tunnelDnsServers, @@ -252,9 +263,9 @@ bool FwContext::applyPolicyConnected return status; } -bool FwContext::applyPolicyBlocked(const WinFwSettings &settings) +bool FwContext::applyPolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint) { - const auto status = applyRuleset(composePolicyBlocked(settings)); + const auto status = applyRuleset(composePolicyBlocked(settings, allowedEndpoint)); if (status) { @@ -284,13 +295,18 @@ FwContext::Policy FwContext::activePolicy() const return m_activePolicy; } -FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings) +FwContext::Ruleset FwContext::composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint) { Ruleset ruleset; AppendNetBlockedRules(ruleset); AppendSettingsRules(ruleset, settings); + if (allowedEndpoint.has_value()) + { + AppendAllowedEndpointRules(ruleset, allowedEndpoint.value()); + } + return ruleset; } @@ -302,7 +318,7 @@ bool FwContext::applyBaseConfiguration() }); } -bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint) +bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint) { return m_sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine) { @@ -318,7 +334,7 @@ bool FwContext::applyBlockedBaseConfiguration(const WinFwSettings &settings, uin // checkpoint = controller.peekCheckpoint(); - return applyRulesetDirectly(composePolicyBlocked(settings), controller); + return applyRulesetDirectly(composePolicyBlocked(settings, allowedEndpoint), controller); }); } diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index 100672073a..bbbb1de485 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -20,7 +20,8 @@ public: FwContext ( uint32_t timeout, - const WinFwSettings &settings + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint ); struct PingableHosts @@ -32,22 +33,26 @@ public: bool applyPolicyConnecting ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, - const std::optional<PingableHosts> &pingableHosts + const std::optional<PingableHosts> &pingableHosts, + const std::optional<WinFwEndpoint> &allowedEndpoint ); bool applyPolicyConnected ( const WinFwSettings &settings, - const WinFwRelay &relay, + const WinFwEndpoint &relay, const std::wstring &relayClient, const std::wstring &tunnelInterfaceAlias, const std::vector<wfp::IpAddress> &tunnelDnsServers, const std::vector<wfp::IpAddress> &nonTunnelDnsServers ); - bool applyPolicyBlocked(const WinFwSettings &settings); + bool applyPolicyBlocked( + const WinFwSettings &settings, + const std::optional<WinFwEndpoint> &allowedEndpoint + ); bool reset(); @@ -68,10 +73,10 @@ private: FwContext(const FwContext &) = delete; FwContext &operator=(const FwContext &) = delete; - Ruleset composePolicyBlocked(const WinFwSettings &settings); + Ruleset composePolicyBlocked(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint); bool applyBaseConfiguration(); - bool applyBlockedBaseConfiguration(const WinFwSettings &settings, uint32_t &checkpoint); + bool applyBlockedBaseConfiguration(const WinFwSettings &settings, const std::optional<WinFwEndpoint> &allowedEndpoint, uint32_t &checkpoint); bool applyCommonBaseConfiguration(SessionController &controller, wfp::FilterEngine &engine); bool applyRuleset(const Ruleset &ruleset); diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp index 0a22be1740..417b157f82 100644 --- a/windows/winfw/src/winfw/mullvadguids.cpp +++ b/windows/winfw/src/winfw/mullvadguids.cpp @@ -129,6 +129,7 @@ MullvadGuids::DetailedIdentityRegistry MullvadGuids::DetailedRegistry(IdentityQu registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Inbound_Request_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitDhcpServer_Outbound_Response_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnRelay())); + registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitEndpoint())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6())); registry.insert(std::make_pair(WfpObjectType::Filter, Filter_Baseline_PermitVpnTunnelService_Ipv4())); @@ -644,6 +645,20 @@ const GUID &MullvadGuids::Filter_Baseline_PermitVpnRelay() } //static +const GUID &MullvadGuids::Filter_Baseline_PermitEndpoint() +{ + static const GUID g = + { + 0x99dc8dac, + 0x8520, + 0x41be, + { 0xbf, 0xab, 0x0c, 0x9, 0xbf, 0x12, 0xeb, 0 } + }; + + return g; +} + +//static const GUID &MullvadGuids::Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4() { static const GUID g = diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h index 11e396fc2b..7f00863811 100644 --- a/windows/winfw/src/winfw/mullvadguids.h +++ b/windows/winfw/src/winfw/mullvadguids.h @@ -69,6 +69,8 @@ public: static const GUID &Filter_Baseline_PermitVpnRelay(); + static const GUID &Filter_Baseline_PermitEndpoint(); + static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv4(); static const GUID &Filter_Baseline_PermitVpnTunnel_Outbound_Ipv6(); diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp new file mode 100644 index 0000000000..5b79d64ceb --- /dev/null +++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp @@ -0,0 +1,87 @@ +#include "stdafx.h" +#include "permitendpoint.h" +#include <winfw/mullvadguids.h> +#include <libwfp/filterbuilder.h> +#include <libwfp/conditionbuilder.h> +#include <libwfp/conditions/conditionprotocol.h> +#include <libwfp/conditions/conditionip.h> +#include <libwfp/conditions/conditionport.h> +#include <libwfp/conditions/conditionapplication.h> +#include <libcommon/error.h> + +using namespace wfp::conditions; + +namespace rules::baseline +{ + +namespace +{ + +const GUID &OutboundLayerFromIp(const wfp::IpAddress &ip) +{ + switch (ip.type()) + { + case wfp::IpAddress::Type::Ipv4: return FWPM_LAYER_ALE_AUTH_CONNECT_V4; + case wfp::IpAddress::Type::Ipv6: return FWPM_LAYER_ALE_AUTH_CONNECT_V6; + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + +std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol) +{ + switch (protocol) + { + case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); + case WinFwProtocol::Udp: return ConditionProtocol::Udp(); + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + +} // anonymous namespace + +PermitEndpoint::PermitEndpoint +( + const wfp::IpAddress &address, + uint16_t port, + WinFwProtocol protocol +) + : m_address(address) + , m_port(port) + , m_protocol(protocol) +{ +} + +bool PermitEndpoint::apply(IObjectInstaller &objectInstaller) +{ + wfp::FilterBuilder filterBuilder; + + // + // Permit outbound connections to endpoint. + // + + filterBuilder + .key(MullvadGuids::Filter_Baseline_PermitEndpoint()) + .name(L"Permit outbound connections to a given endpoint") + .description(L"This filter is part of a rule that permits traffic to a specific endpoint") + .provider(MullvadGuids::Provider()) + .layer(OutboundLayerFromIp(m_address)) + .sublayer(MullvadGuids::SublayerBaseline()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(OutboundLayerFromIp(m_address)); + + conditionBuilder.add_condition(ConditionIp::Remote(m_address)); + conditionBuilder.add_condition(ConditionPort::Remote(m_port)); + conditionBuilder.add_condition(CreateProtocolCondition(m_protocol)); + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +} diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.h b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h new file mode 100644 index 0000000000..93564dbd1e --- /dev/null +++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.h @@ -0,0 +1,31 @@ +#pragma once + +#include <winfw/rules/ifirewallrule.h> +#include <winfw/winfw.h> +#include <libwfp/ipaddress.h> +#include <string> + +namespace rules::baseline +{ + +class PermitEndpoint : public IFirewallRule +{ +public: + + PermitEndpoint + ( + const wfp::IpAddress &address, + uint16_t port, + WinFwProtocol protocol + ); + + bool apply(IObjectInstaller &objectInstaller) override; + +private: + + const wfp::IpAddress m_address; + const uint16_t m_port; + const WinFwProtocol m_protocol; +}; + +} diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp index 35e56ba167..ee5ffcb0c4 100644 --- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp +++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "permitvpnrelay.h" #include <winfw/mullvadguids.h> +#include <winfw/winfw.h> #include <libwfp/filterbuilder.h> #include <libwfp/conditionbuilder.h> #include <libwfp/conditions/conditionprotocol.h> @@ -30,12 +31,12 @@ const GUID &LayerFromIp(const wfp::IpAddress &ip) }; } -std::unique_ptr<ConditionProtocol> CreateProtocolCondition(PermitVpnRelay::Protocol protocol) +std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol) { switch (protocol) { - case PermitVpnRelay::Protocol::Tcp: return ConditionProtocol::Tcp(); - case PermitVpnRelay::Protocol::Udp: return ConditionProtocol::Udp(); + case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); + case WinFwProtocol::Udp: return ConditionProtocol::Udp(); default: { THROW_ERROR("Missing case handler in switch clause"); @@ -62,7 +63,7 @@ PermitVpnRelay::PermitVpnRelay ( const wfp::IpAddress &relay, uint16_t relayPort, - Protocol protocol, + WinFwProtocol protocol, const std::wstring &relayClient, Sublayer sublayer ) diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h index 22b7956588..d63f27a862 100644 --- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h +++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h @@ -1,6 +1,7 @@ #pragma once #include <winfw/rules/ifirewallrule.h> +#include <winfw/winfw.h> #include <libwfp/ipaddress.h> #include <string> @@ -11,12 +12,6 @@ class PermitVpnRelay : public IFirewallRule { public: - enum class Protocol - { - Tcp, - Udp - }; - enum class Sublayer { Baseline, @@ -27,7 +22,7 @@ public: ( const wfp::IpAddress &relay, uint16_t relayPort, - Protocol protocol, + WinFwProtocol protocol, const std::wstring &relayClient, Sublayer sublayer ); @@ -38,7 +33,7 @@ private: const wfp::IpAddress m_relay; const uint16_t m_relayPort; - const Protocol m_protocol; + const WinFwProtocol m_protocol; const std::wstring m_relayClient; const Sublayer m_sublayer; }; diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index f2d5a66b2a..ee7842877e 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -65,6 +65,16 @@ HandlePolicyException(const common::error::WindowsException &err) return WINFW_POLICY_STATUS_GENERAL_FAILURE; } +template<typename T> +std::optional<T> MakeOptional(T* object) +{ + if (nullptr == object) + { + return std::nullopt; + } + return std::make_optional(*object); +} + // // Networks for which DNS requests can be made on all network adapters. // @@ -136,6 +146,7 @@ WINFW_API WinFw_InitializeBlocked( uint32_t timeout, const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint, MullvadLogSink logSink, void *logSinkContext ) @@ -162,7 +173,7 @@ WinFw_InitializeBlocked( g_logSink = logSink; g_logSinkContext = logSinkContext; - g_fwContext = new FwContext(timeout_ms, *settings); + g_fwContext = new FwContext(timeout_ms, *settings, MakeOptional(allowedEndpoint)); } catch (std::exception &err) { @@ -247,9 +258,10 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, - const PingableHosts *pingableHosts + const PingableHosts *pingableHosts, + const WinFwEndpoint *allowedEndpoint ) { if (nullptr == g_fwContext) @@ -278,7 +290,8 @@ WinFw_ApplyPolicyConnecting( *settings, *relay, relayClient, - ConvertPingableHosts(pingableHosts) + ConvertPingableHosts(pingableHosts), + MakeOptional(allowedEndpoint) ) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE; } catch (common::error::WindowsException &err) @@ -305,7 +318,7 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, @@ -447,7 +460,8 @@ WINFW_LINKAGE WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyBlocked( - const WinFwSettings *settings + const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint ) { if (nullptr == g_fwContext) @@ -462,7 +476,7 @@ WinFw_ApplyPolicyBlocked( THROW_ERROR("Invalid argument: settings"); } - return g_fwContext->applyPolicyBlocked(*settings) + return g_fwContext->applyPolicyBlocked(*settings, MakeOptional(allowedEndpoint)) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE; } diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index f0a487cb12..23163786e9 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -37,13 +37,13 @@ enum WinFwProtocol : uint8_t Udp = 1 }; -typedef struct tag_WinFwRelay +typedef struct tag_WinFwEndpoint { const wchar_t *ip; uint16_t port; WinFwProtocol protocol; } -WinFwRelay; +WinFwEndpoint; #pragma pack(pop) @@ -88,6 +88,7 @@ WINFW_API WinFw_InitializeBlocked( uint32_t timeout, const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint, MullvadLogSink logSink, void *logSinkContext ); @@ -155,9 +156,10 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, - const PingableHosts *pingableHosts + const PingableHosts *pingableHosts, + const WinFwEndpoint *allowedEndpoint ); // @@ -183,7 +185,7 @@ WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, - const WinFwRelay *relay, + const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, @@ -203,7 +205,8 @@ WINFW_LINKAGE WINFW_POLICY_STATUS WINFW_API WinFw_ApplyPolicyBlocked( - const WinFwSettings *settings + const WinFwSettings *settings, + const WinFwEndpoint *allowedEndpoint ); // diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj index 8f9c37f919..3f9502a10d 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj +++ b/windows/winfw/src/winfw/winfw.vcxproj @@ -27,6 +27,7 @@ <ClCompile Include="rules\baseline\permitdhcp.cpp" /> <ClCompile Include="rules\baseline\permitdhcpserver.cpp" /> <ClCompile Include="rules\baseline\permitdns.cpp" /> + <ClCompile Include="rules\baseline\permitendpoint.cpp" /> <ClCompile Include="rules\baseline\permitlan.cpp" /> <ClCompile Include="rules\baseline\permitlanservice.cpp" /> <ClCompile Include="rules\baseline\permitloopback.cpp" /> @@ -61,6 +62,7 @@ <ClInclude Include="rules\baseline\permitdhcp.h" /> <ClInclude Include="rules\baseline\permitdhcpserver.h" /> <ClInclude Include="rules\baseline\permitdns.h" /> + <ClInclude Include="rules\baseline\permitendpoint.h" /> <ClInclude Include="rules\baseline\permitlan.h" /> <ClInclude Include="rules\baseline\permitlanservice.h" /> <ClInclude Include="rules\baseline\permitloopback.h" /> diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters index 312045876e..7a2aa85487 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj.filters +++ b/windows/winfw/src/winfw/winfw.vcxproj.filters @@ -61,6 +61,9 @@ <ClCompile Include="rules\persistent\blockall.cpp"> <Filter>rules\persistent</Filter> </ClCompile> + <ClCompile Include="rules\baseline\permitendpoint.cpp"> + <Filter>rules\baseline</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -132,6 +135,9 @@ <ClInclude Include="rules\persistent\blockall.h"> <Filter>rules\persistent</Filter> </ClInclude> + <ClInclude Include="rules\baseline\permitendpoint.h"> + <Filter>rules\baseline</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <Filter Include="rules"> |
