diff options
| author | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-04-04 16:50:35 +0200 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-04-04 16:50:35 +0200 |
| commit | 7170e05df5cbfd4dd3f372267bd9c155ed49f5f6 (patch) | |
| tree | 44d8037ea9f552227dfce5a5e2358bc580a8cec5 | |
| parent | e45f61d0f538f20d594106e0ec9ccecb40678a3b (diff) | |
| parent | c0b0304be43f994dc661d18696085dbd415afb24 (diff) | |
| download | mullvadvpn-7170e05df5cbfd4dd3f372267bd9c155ed49f5f6.tar.xz mullvadvpn-7170e05df5cbfd4dd3f372267bd9c155ed49f5f6.zip | |
Merge branch 'intersection-macro'
| -rw-r--r-- | Cargo.lock | 11 | ||||
| -rw-r--r-- | Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-relay-selector/Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/detailer.rs | 7 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/helpers.rs | 35 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/matcher.rs | 8 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/mod.rs | 111 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/query.rs | 296 | ||||
| -rw-r--r-- | mullvad-relay-selector/tests/relay_selector.rs | 43 | ||||
| -rw-r--r-- | mullvad-types/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-types/intersection-derive/Cargo.toml | 16 | ||||
| -rw-r--r-- | mullvad-types/intersection-derive/src/lib.rs | 68 | ||||
| -rw-r--r-- | mullvad-types/src/constraints/constraint.rs | 24 | ||||
| -rw-r--r-- | mullvad-types/src/constraints/mod.rs | 133 | ||||
| -rw-r--r-- | mullvad-types/src/lib.rs | 3 | ||||
| -rw-r--r-- | mullvad-types/src/relay_constraints.rs | 70 | ||||
| -rw-r--r-- | test/Cargo.lock | 11 |
17 files changed, 402 insertions, 438 deletions
diff --git a/Cargo.lock b/Cargo.lock index 25b7582fb6..cc54dcb168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1554,6 +1554,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc6d6206008e25125b1f97fbe5d309eb7b85141cf9199d52dbd3729a1584dd16" [[package]] +name = "intersection-derive" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + +[[package]] name = "ioctl-sys" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2109,6 +2118,7 @@ name = "mullvad-relay-selector" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork", "itertools 0.12.1", "log", @@ -2147,6 +2157,7 @@ version = "0.0.0" dependencies = [ "chrono", "clap", + "intersection-derive", "ipnetwork", "jnix", "log", diff --git a/Cargo.toml b/Cargo.toml index d760c77f6b..b0c28f016a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "mullvad-relay-selector", "mullvad-setup", "mullvad-types", + "mullvad-types/intersection-derive", "mullvad-version", "talpid-core", "talpid-dbus", diff --git a/mullvad-relay-selector/Cargo.toml b/mullvad-relay-selector/Cargo.toml index ec8943df0b..3224fd7983 100644 --- a/mullvad-relay-selector/Cargo.toml +++ b/mullvad-relay-selector/Cargo.toml @@ -22,6 +22,7 @@ serde_json = "1.0" talpid-types = { path = "../talpid-types" } mullvad-types = { path = "../mullvad-types" } +intersection-derive = { path = "../mullvad-types/intersection-derive"} [dev-dependencies] proptest = { workspace = true } diff --git a/mullvad-relay-selector/src/relay_selector/detailer.rs b/mullvad-relay-selector/src/relay_selector/detailer.rs index 374d92da09..8fa36f5434 100644 --- a/mullvad-relay-selector/src/relay_selector/detailer.rs +++ b/mullvad-relay-selector/src/relay_selector/detailer.rs @@ -48,7 +48,8 @@ pub enum Error { /// Constructs a [`MullvadWireguardEndpoint`] with details for how to connect to a Wireguard relay. /// /// # Returns -/// - A configured endpoint for Wireguard relay, encapsulating either a single-hop or multi-hop connection. +/// - A configured endpoint for Wireguard relay, encapsulating either a single-hop or multi-hop +/// connection. /// - Returns [`Option::None`] if the desired port is not in a valid port range (see /// [`WireguardRelayQuery::port`]) or relay addresses cannot be resolved. pub fn wireguard_endpoint( @@ -198,8 +199,8 @@ const fn get_public_key(relay: &Relay) -> Result<&PublicKey, Error> { /// - `port_ranges`: A slice of tuples, each representing a range of valid port numbers. /// /// # Returns -/// - `Option<u16>`: A randomly selected port number within the given ranges, or `None` if -/// the input is empty or the total number of available ports is zero. +/// - `Option<u16>`: A randomly selected port number within the given ranges, or `None` if the input +/// is empty or the total number of available ports is zero. fn select_random_port(port_ranges: &[(u16, u16)]) -> Result<u16, Error> { use rand::Rng; let get_port_amount = |range: &(u16, u16)| -> u64 { (1 + range.1 - range.0) as u64 }; diff --git a/mullvad-relay-selector/src/relay_selector/helpers.rs b/mullvad-relay-selector/src/relay_selector/helpers.rs index 5ad5bf7ab4..263034d836 100644 --- a/mullvad-relay-selector/src/relay_selector/helpers.rs +++ b/mullvad-relay-selector/src/relay_selector/helpers.rs @@ -8,24 +8,11 @@ use mullvad_types::{ relay_constraints::Udp2TcpObfuscationSettings, relay_list::{BridgeEndpointData, Relay, RelayEndpointData}, }; -use rand::{ - seq::{IteratorRandom, SliceRandom}, - thread_rng, Rng, -}; +use rand::{seq::SliceRandom, thread_rng, Rng}; use talpid_types::net::{obfuscation::ObfuscatorConfig, proxy::CustomProxy}; use crate::SelectedObfuscator; -/// Pick a random element out of `from`, excluding the element `exclude` from the selection. -pub fn random<'a, A: PartialEq>( - from: impl IntoIterator<Item = &'a A>, - exclude: &A, -) -> Option<&'a A> { - from.into_iter() - .filter(|&a| a != exclude) - .choose(&mut thread_rng()) -} - /// Picks a relay using [pick_random_relay_fn], using the `weight` member of each relay /// as the weight function. pub fn pick_random_relay(relays: &[Relay]) -> Option<&Relay> { @@ -49,16 +36,16 @@ pub fn pick_random_relay_weighted<RelayType>( // Pick a random number in the range 1..=total_weight. This choses the relay with a // non-zero weight. // - // rng(1..=total_weight) - // | - // v - // _____________________________i___________________________________________________ - // 0|_____________|__________________________|___________|_____|___________|__________| total_weight - // ^ ^ ^ ^ ^ - // | | | | | - // ------------------------------------------ ------------ - // | | | - // weight(relay 0) weight(relay 1) .. .. .. weight(relay n) + // rng(1..=total_weight) + // | + // v + // ________________________i_______________________________________________ + // 0|_____________|____________________|___________|_____|________|__________| total_weight + // ^ ^ ^ ^ ^ + // | | | | | + // ------------------------------------ ------------ + // | | | + // weight(relay 0) weight(relay 1) .. .. .. weight(relay n) let mut i: u64 = rng.gen_range(1..=total_weight); Some( relays diff --git a/mullvad-relay-selector/src/relay_selector/matcher.rs b/mullvad-relay-selector/src/relay_selector/matcher.rs index f06e04224f..2987d8965a 100644 --- a/mullvad-relay-selector/src/relay_selector/matcher.rs +++ b/mullvad-relay-selector/src/relay_selector/matcher.rs @@ -35,10 +35,10 @@ pub fn filter_matching_relay_list<'a, R: Iterator<Item = &'a Relay> + Clone>( .filter(|relay| filter_on_providers(&query.providers, relay)); // The last filtering to be done is on the `include_in_country` attribute found on each - // relay. When the location constraint is based on country, a relay which has `include_in_country` - // set to true should always be prioritized over relays which has this flag set to false. - // We should only consider relays with `include_in_country` set to false if there are no - // other candidates left. + // relay. When the location constraint is based on country, a relay which has + // `include_in_country` set to true should always be prioritized over relays which has this + // flag set to false. We should only consider relays with `include_in_country` set to false + // if there are no other candidates left. match &locations { Constraint::Any => shortlist.cloned().collect(), Constraint::Only(locations) => { diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs index 424ca6385c..cbb3473f7f 100644 --- a/mullvad-relay-selector/src/relay_selector/mod.rs +++ b/mullvad-relay-selector/src/relay_selector/mod.rs @@ -9,6 +9,7 @@ pub mod query; use chrono::{DateTime, Local}; use itertools::Itertools; use once_cell::sync::Lazy; +use rand::{seq::IteratorRandom, thread_rng}; use std::{ path::Path, sync::{Arc, Mutex}, @@ -27,7 +28,7 @@ use mullvad_types::{ }, relay_list::{Relay, RelayEndpointData, RelayList}, settings::Settings, - CustomTunnelEndpoint, + CustomTunnelEndpoint, Intersection, }; use talpid_types::{ net::{ @@ -42,12 +43,12 @@ use self::{ detailer::{openvpn_endpoint, wireguard_endpoint}, matcher::{filter_matching_bridges, filter_matching_relay_list}, parsed_relays::ParsedRelays, - query::{BridgeQuery, Intersection, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery}, + query::{BridgeQuery, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery}, }; -/// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should prioritize on -/// successive connection attempts. Note that these will *never* override user preferences. -/// See [the documentation on `RelayQuery`][RelayQuery] for further details. +/// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should +/// prioritize on successive connection attempts. Note that these will *never* override user +/// preferences. See [the documentation on `RelayQuery`][RelayQuery] for further details. /// /// This list should be kept in sync with the expected behavior defined in `docs/relay-selector.md` pub static RETRY_ORDER: Lazy<Vec<RelayQuery>> = Lazy::new(|| { @@ -144,14 +145,16 @@ impl Default for RuntimeParameters { /// This enum exists to separate the two types of [`SelectorConfig`] that exists. /// -/// The first one is a "regular" config, where [`SelectorConfig::relay_settings`] is [`RelaySettings::Normal`]. -/// This is the most common variant, and there exists a mapping from this variant to [`RelayQueryBuilder`]. -/// Being able to implement `From<NormalSelectorConfig> for RelayQueryBuilder` was the main -/// motivator for introducing these seemingly useless derivates of [`SelectorConfig`]. +/// The first one is a "regular" config, where [`SelectorConfig::relay_settings`] is +/// [`RelaySettings::Normal`]. This is the most common variant, and there exists a mapping from this +/// variant to [`RelayQueryBuilder`]. Being able to implement `From<NormalSelectorConfig> for +/// RelayQueryBuilder` was the main motivator for introducing these seemingly useless derivates of +/// [`SelectorConfig`]. /// -/// The second one is a custom config, where [`SelectorConfig::relay_settings`] is [`RelaySettings::Custom`]. -/// For this variant, the endpoint where the client should connect to is already specified inside of the variant, -/// so in practice the relay selector becomes superfluous. Also, there exists no mapping to [`RelayQueryBuilder`]. +/// The second one is a custom config, where [`SelectorConfig::relay_settings`] is +/// [`RelaySettings::Custom`]. For this variant, the endpoint where the client should connect to is +/// already specified inside of the variant, so in practice the relay selector becomes superfluous. +/// Also, there exists no mapping to [`RelayQueryBuilder`]. #[derive(Debug, Clone)] enum SpecializedSelectorConfig<'a> { // This variant implements `From<NormalSelectorConfig> for RelayQuery` @@ -500,9 +503,9 @@ impl RelaySelector { self.get_relay_with_custom_params(retry_attempt, &RETRY_ORDER, runtime_params) } - /// Peek at which [`TunnelType`] that would be returned for a certain connection attempt for a given - /// [`SelectorConfig`]. Returns [`Option::None`] if the given config would return a custom - /// tunnel endpoint. + /// Peek at which [`TunnelType`] that would be returned for a certain connection attempt for a + /// given [`SelectorConfig`]. Returns [`Option::None`] if the given config would return a + /// custom tunnel endpoint. /// /// # Note /// This function is only really useful for testing-purposes. It is exposed to ease testing of @@ -556,15 +559,16 @@ impl RelaySelector { } } - /// This function defines the merge between a set of pre-defined queries and `user_preferences` for the given - /// `retry_attempt`. + /// This function defines the merge between a set of pre-defined queries and `user_preferences` + /// for the given `retry_attempt`. /// - /// This algorithm will loop back to the start of `retry_order` if `retry_attempt < retry_order.len()`. - /// If `user_preferences` is not compatible with any of the pre-defined queries in `retry_order`, `user_preferences` - /// is returned. + /// This algorithm will loop back to the start of `retry_order` if `retry_attempt < + /// retry_order.len()`. If `user_preferences` is not compatible with any of the pre-defined + /// queries in `retry_order`, `user_preferences` is returned. /// /// Runtime parameters may affect which of the default queries that are considered. For example, - /// queries which rely on IPv6 will not be considered if working IPv6 is not available at runtime. + /// queries which rely on IPv6 will not be considered if working IPv6 is not available at + /// runtime. fn pick_and_merge_query( retry_attempt: usize, retry_order: &[RelayQuery], @@ -583,15 +587,19 @@ impl RelaySelector { .unwrap_or(user_preferences) } - /// "Execute" the given query, yielding a final set of relays and/or bridges which the VPN traffic shall be routed through. + /// "Execute" the given query, yielding a final set of relays and/or bridges which the VPN + /// traffic shall be routed through. /// /// # Parameters - /// - `query`: Constraints that filter the available relays, such as geographic location or tunnel protocol. - /// - `config`: Configuration settings that influence relay selection, including bridge state and custom lists. + /// - `query`: Constraints that filter the available relays, such as geographic location or + /// tunnel protocol. + /// - `config`: Configuration settings that influence relay selection, including bridge state + /// and custom lists. /// - `parsed_relays`: The complete set of parsed relays available for selection. /// /// # Returns - /// * A randomly selected relay that meets the specified constraints (and a random bridge/entry relay if applicable). + /// * A randomly selected relay that meets the specified constraints (and a random bridge/entry + /// relay if applicable). /// See [`GetRelay`] for more details. /// * An `Err` if no suitable relay is found /// * An `Err` if no suitable bridge is found @@ -629,10 +637,10 @@ impl RelaySelector { parsed_relays: &ParsedRelays, config: &NormalSelectorConfig<'_>, ) -> Result<GetRelay, Error> { - // FIXME: A bit of defensive programming - calling `get_wiregurad_relay` with a query that doesn't - // specify Wireguard as the desired tunnel type is not valid and will lead to unwanted - // behavior. This should be seen as a workaround, and it would be nicer to lift this - // invariant to be checked by the type system instead. + // FIXME: A bit of defensive programming - calling `get_wiregurad_relay` with a query that + // doesn't specify Wireguard as the desired tunnel type is not valid and will lead + // to unwanted behavior. This should be seen as a workaround, and it would be nicer + // to lift this invariant to be checked by the type system instead. query.tunnel_protocol = Constraint::Only(TunnelType::Wireguard); Self::get_wireguard_relay(query, config, parsed_relays) } @@ -712,7 +720,8 @@ impl RelaySelector { let mut entry_relay_query = query.clone(); entry_relay_query.location = query.wireguard_constraints.entry_location.clone(); // After we have our two queries (one for the exit relay & one for the entry relay), - // we can query for all exit & entry candidates! All candidates are needed for the next step. + // we can query for all exit & entry candidates! All candidates are needed for the next + // step. let exit_candidates = filter_matching_relay_list(query, parsed_relays.relays(), config.custom_lists); let entry_candidates = filter_matching_relay_list( @@ -721,25 +730,25 @@ impl RelaySelector { config.custom_lists, ); - // This algorithm gracefully handles a particular edge case that arise when a constraint on - // the exit relay is more specific than on the entry relay which forces the relay selector - // to choose one specific relay. The relay selector could end up selecting that specific - // relay as the entry relay, thus leaving no remaining exit relay candidates or vice versa. + fn pick_random_excluding<'a>(list: &'a [Relay], exclude: &'a Relay) -> Option<&'a Relay> { + list.iter() + .filter(|&a| a != exclude) + .choose(&mut thread_rng()) + } + // We avoid picking the same relay for entry and exit by choosing one and excluding it when + // choosing the other. let (exit, entry) = match (exit_candidates.as_slice(), entry_candidates.as_slice()) { - ([exit], [entry]) if exit == entry => None, + // In the case where there is only one entry to choose from, we have to pick it before + // the exit (exits, [entry]) if exits.contains(entry) => { - let exit = helpers::random(exits, entry).ok_or(Error::NoRelay)?; - Some((exit, entry)) - } - ([exit], entrys) if entrys.contains(exit) => { - let entry = helpers::random(entrys, exit).ok_or(Error::NoRelay)?; - Some((exit, entry)) + pick_random_excluding(exits, entry).map(|exit| (exit, entry)) } - (exits, entrys) => { - let exit = helpers::pick_random_relay(exits).ok_or(Error::NoRelay)?; - let entry = helpers::random(entrys, exit).ok_or(Error::NoRelay)?; - Some((exit, entry)) + // Vice versa for the case of only one exit + ([exit], entries) if entries.contains(exit) => { + pick_random_excluding(entries, exit).map(|entry| (exit, entry)) } + (exits, entries) => helpers::pick_random_relay(exits) + .and_then(|exit| pick_random_excluding(entries, exit).map(|entry| (exit, entry))), } .ok_or(Error::NoRelay)?; @@ -847,17 +856,20 @@ impl RelaySelector { }) } - /// Selects a suitable bridge based on the specified settings, relay information, and transport protocol. + /// Selects a suitable bridge based on the specified settings, relay information, and transport + /// protocol. /// /// # Parameters /// - `query`: The filter criteria for selecting a bridge. /// - `relay`: Information about the current relay, including its location. /// - `protocol`: The transport protocol (TCP or UDP) in use. /// - `parsed_relays`: A structured representation of all available relays. - /// - `custom_lists`: User-defined or application-specific settings that may influence bridge selection. + /// - `custom_lists`: User-defined or application-specific settings that may influence bridge + /// selection. /// /// # Returns - /// * On success, returns an `Option` containing the selected bridge, if one is found. Returns `None` if no suitable bridge meets the criteria or bridges should not be used. + /// * On success, returns an `Option` containing the selected bridge, if one is found. Returns + /// `None` if no suitable bridge meets the criteria or bridges should not be used. /// * `Error::NoBridge` if attempting to use OpenVPN bridges over UDP, as this is unsupported. /// * `Error::NoRelay` if `relay` does not have a location set. #[cfg(not(target_os = "android"))] @@ -1010,7 +1022,8 @@ impl RelaySelector { } /// # Returns - /// A randomly selected relay that meets the specified constraints, or `None` if no suitable relay is found. + /// A randomly selected relay that meets the specified constraints, or `None` if no suitable + /// relay is found. #[cfg(not(target_os = "android"))] fn choose_openvpn_relay( query: &RelayQuery, diff --git a/mullvad-relay-selector/src/relay_selector/query.rs b/mullvad-relay-selector/src/relay_selector/query.rs index d5b17b2303..f112e83968 100644 --- a/mullvad-relay-selector/src/relay_selector/query.rs +++ b/mullvad-relay-selector/src/relay_selector/query.rs @@ -1,18 +1,20 @@ //! This module provides a flexible way to specify 'queries' for relays. //! //! A query is a set of constraints that the [`crate::RelaySelector`] will use when filtering out -//! potential relays that the daemon should connect to. It supports filtering relays by geographic location, -//! provider, ownership, and tunnel protocol, along with protocol-specific settings for WireGuard and OpenVPN. +//! potential relays that the daemon should connect to. It supports filtering relays by geographic +//! location, provider, ownership, and tunnel protocol, along with protocol-specific settings for +//! WireGuard and OpenVPN. //! //! The main components of this module include: //! -//! - [`RelayQuery`]: The core struct for specifying a query to select relay servers. It -//! aggregates constraints on location, providers, ownership, tunnel protocol, and -//! protocol-specific constraints for WireGuard and OpenVPN. +//! - [`RelayQuery`]: The core struct for specifying a query to select relay servers. It aggregates +//! constraints on location, providers, ownership, tunnel protocol, and protocol-specific +//! constraints for WireGuard and OpenVPN. //! - [`WireguardRelayQuery`] and [`OpenVpnRelayQuery`]: Structs that define protocol-specific //! constraints for selecting WireGuard and OpenVPN relays, respectively. -//! - [`Intersection`]: A trait implemented by the different query types that support intersection logic, -//! which allows for combining two queries into a single query that represents the common constraints of both. +//! - [`Intersection`]: A trait implemented by the different query types that support intersection +//! logic, which allows for combining two queries into a single query that represents the common +//! constraints of both. //! - [Builder patterns][builder]: The module also provides builder patterns for creating instances //! of `RelayQuery`, `WireguardRelayQuery`, and `OpenVpnRelayQuery` with a fluent API. //! @@ -33,6 +35,7 @@ use mullvad_types::{ RelayConstraints, SelectedObfuscation, TransportPort, Udp2TcpObfuscationSettings, WireguardConstraints, }, + Intersection, }; use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType}; @@ -68,7 +71,7 @@ use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType}; /// This example demonstrates creating a `RelayQuery` which can then be passed /// to the [`crate::RelaySelector`] to find a relay that matches the criteria. /// See [`builder`] for more info on how to construct queries. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct RelayQuery { pub location: Constraint<LocationConstraint>, pub providers: Constraint<Providers>, @@ -86,7 +89,7 @@ impl RelayQuery { /// Note that the following identity applies for any `other_query`: /// ```rust /// # use mullvad_relay_selector::query::RelayQuery; - /// # use crate::mullvad_relay_selector::query::Intersection; + /// # use mullvad_types::Intersection; /// /// # let other_query = RelayQuery::new(); /// assert_eq!(RelayQuery::new().intersection(other_query.clone()), Some(other_query)); @@ -105,105 +108,6 @@ impl RelayQuery { } } -impl Intersection for RelayQuery { - /// Return a new [`RelayQuery`] which matches the intersected queries. - /// - /// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned: - /// ```rust - /// # use mullvad_relay_selector::query::builder::RelayQueryBuilder; - /// # use crate::mullvad_relay_selector::query::Intersection; - /// let query_a = RelayQueryBuilder::new().wireguard().build(); - /// let query_b = RelayQueryBuilder::new().openvpn().build(); - /// assert_eq!(query_a.intersection(query_b), None); - /// ``` - /// - /// * Otherwise, a new [`RelayQuery`] is returned where each constraint is - /// as specific as possible. See [`Constraint`] for further details. - /// ```rust - /// # use crate::mullvad_relay_selector::*; - /// # use crate::mullvad_relay_selector::query::*; - /// # use crate::mullvad_relay_selector::query::builder::*; - /// # use mullvad_types::relay_list::*; - /// # use talpid_types::net::wireguard::PublicKey; - /// - /// // The relay list used by `relay_selector` in this example - /// let relay_list = RelayList { - /// # etag: None, - /// # openvpn: OpenVpnEndpointData { ports: vec![] }, - /// # bridge: BridgeEndpointData { - /// # shadowsocks: vec![], - /// # }, - /// # wireguard: WireguardEndpointData { - /// # port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)], - /// # ipv4_gateway: "10.64.0.1".parse().unwrap(), - /// # ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(), - /// # udp2tcp_ports: vec![], - /// # }, - /// countries: vec![RelayListCountry { - /// name: "Sweden".to_string(), - /// # code: "Sweden".to_string(), - /// cities: vec![RelayListCity { - /// name: "Gothenburg".to_string(), - /// # code: "Gothenburg".to_string(), - /// # latitude: 57.70887, - /// # longitude: 11.97456, - /// relays: vec![Relay { - /// hostname: "se9-wireguard".to_string(), - /// ipv4_addr_in: "185.213.154.68".parse().unwrap(), - /// # ipv6_addr_in: Some("2a03:1b20:5:f011::a09f".parse().unwrap()), - /// # include_in_country: false, - /// # active: true, - /// # owned: true, - /// # provider: "31173".to_string(), - /// # weight: 1, - /// # endpoint_data: RelayEndpointData::Wireguard(WireguardRelayEndpointData { - /// # public_key: PublicKey::from_base64( - /// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", - /// # ) - /// # .unwrap(), - /// # }), - /// # location: None, - /// }], - /// }], - /// }], - /// }; - /// - /// # let relay_selector = RelaySelector::from_list(SelectorConfig::default(), relay_list.clone()); - /// # let city = |country, city| GeographicLocationConstraint::city(country, city); - /// - /// let query_a = RelayQueryBuilder::new().wireguard().build(); - /// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build(); - /// - /// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap()); - /// assert!(result.is_ok()); - /// ``` - /// - /// This way, if the mullvad app wants to check if the user's relay settings - /// are compatible with any other [`RelayQuery`], for examples those defined by - /// [`RETRY_ORDER`] , taking the intersection between them will never result in - /// a situation where the app can override the user's preferences. - /// - /// [`RETRY_ORDER`]: crate::RETRY_ORDER - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(RelayQuery { - location: self.location.intersection(other.location)?, - providers: self.providers.intersection(other.providers)?, - ownership: self.ownership.intersection(other.ownership)?, - tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?, - wireguard_constraints: self - .wireguard_constraints - .intersection(other.wireguard_constraints)?, - openvpn_constraints: self - .openvpn_constraints - .intersection(other.openvpn_constraints)?, - }) - } -} - impl From<RelayQuery> for RelayConstraints { /// The mapping from [`RelayQuery`] to [`RelayConstraints`]. fn from(value: RelayQuery) -> Self { @@ -218,7 +122,8 @@ impl From<RelayQuery> for RelayConstraints { } } -/// A query for a relay with Wireguard-specific properties, such as `multihop` and [wireguard obfuscation][`SelectedObfuscation`]. +/// A query for a relay with Wireguard-specific properties, such as `multihop` and [wireguard +/// obfuscation][`SelectedObfuscation`]. /// /// This struct may look a lot like [`WireguardConstraints`], and that is the point! /// This struct is meant to be that type in the "universe of relay queries". The difference @@ -226,7 +131,7 @@ impl From<RelayQuery> for RelayConstraints { /// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner. /// Notice that [obfuscation][`SelectedObfuscation`] is not a [`Constraint`], but it is trivial /// to define [`Intersection`] on it, so it is fine. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct WireguardRelayQuery { pub port: Constraint<u16>, pub ip_version: Constraint<IpVersion>, @@ -254,37 +159,6 @@ impl WireguardRelayQuery { } } } -impl Intersection for WireguardRelayQuery { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(WireguardRelayQuery { - port: self.port.intersection(other.port)?, - ip_version: self.ip_version.intersection(other.ip_version)?, - use_multihop: self.use_multihop.intersection(other.use_multihop)?, - entry_location: self.entry_location.intersection(other.entry_location)?, - obfuscation: self.obfuscation.intersection(other.obfuscation)?, - udp2tcp_port: self.udp2tcp_port.intersection(other.udp2tcp_port)?, - }) - } -} - -impl Intersection for SelectedObfuscation { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - match (self, other) { - (left, SelectedObfuscation::Auto) => Some(left), - (SelectedObfuscation::Auto, right) => Some(right), - (left, right) if left == right => Some(left), - _ => None, - } - } -} impl From<WireguardRelayQuery> for WireguardConstraints { /// The mapping from [`WireguardRelayQuery`] to [`WireguardConstraints`]. @@ -304,7 +178,7 @@ impl From<WireguardRelayQuery> for WireguardConstraints { /// This struct is meant to be that type in the "universe of relay queries". The difference /// between them may seem subtle, but in a [`OpenVpnRelayQuery`] every field is represented /// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct OpenVpnRelayQuery { pub port: Constraint<TransportPort>, pub bridge_settings: Constraint<BridgeQuery>, @@ -319,29 +193,8 @@ impl OpenVpnRelayQuery { } } -impl Intersection for OpenVpnRelayQuery { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - let bridge_settings = { - match (self.bridge_settings, other.bridge_settings) { - // Recursive case - (Constraint::Only(left), Constraint::Only(right)) => { - Constraint::Only(left.intersection(right)?) - } - (left, right) => left.intersection(right)?, - } - }; - Some(OpenVpnRelayQuery { - port: self.port.intersection(other.port)?, - bridge_settings, - }) - } -} - -/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay queries". +/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay +/// queries". /// /// [`BridgeState`]: mullvad_types::relay_constraints::BridgeState /// [`BridgeSettings`]: mullvad_types::relay_constraints::BridgeSettings @@ -361,7 +214,7 @@ pub enum BridgeQuery { } impl BridgeQuery { - ///If `bridge_constraints` is `Any`, bridges should not be used due to + /// If `bridge_constraints` is `Any`, bridges should not be used due to /// latency concerns. /// /// If `bridge_constraints` is `Only(settings)`, then `settings` will be @@ -399,20 +252,6 @@ impl Intersection for BridgeQuery { } } -impl Intersection for BridgeConstraints { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(BridgeConstraints { - location: self.location.intersection(other.location)?, - providers: self.providers.intersection(other.providers)?, - ownership: self.ownership.intersection(other.ownership)?, - }) - } -} - impl From<OpenVpnRelayQuery> for OpenVpnConstraints { /// The mapping from [`OpenVpnRelayQuery`] to [`OpenVpnConstraints`]. fn from(value: OpenVpnRelayQuery) -> Self { @@ -420,103 +259,10 @@ impl From<OpenVpnRelayQuery> for OpenVpnConstraints { } } -/// Any type that wish to implement `Intersection` should make sure that the -/// following properties are upheld: -/// -/// - idempotency (if there is an identity element) -/// - commutativity -/// - associativity -pub trait Intersection { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized; -} - -impl<T: Intersection> Intersection for Constraint<T> { - /// Define the intersection between two arbitrary [`Constraint`]s. - /// - /// This operation may be compared to the set operation with the same name. - /// In contrast to the general set intersection, this function represents a - /// very specific case where [`Constraint::Any`] is equivalent to the set - /// universe and [`Constraint::Only`] represents a singleton set. Notable is - /// that the representation of any empty set is [`Option::None`]. - fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> { - use Constraint::*; - match (self, other) { - (Any, Any) => Some(Any), - (Only(t), Any) | (Any, Only(t)) => Some(Only(t)), - // Recurse on `left` and `right` to see if there exist an intersection - (Only(left), Only(right)) => Some(Only(left.intersection(right)?)), - } - } -} - -// Implement `Intersection` for different types - -impl Intersection for Providers { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - Providers::new(self.providers().intersection(other.providers())).ok() - } -} - -impl Intersection for Udp2TcpObfuscationSettings { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - Some(Udp2TcpObfuscationSettings { - port: self.port.intersection(other.port)?, - }) - } -} - -impl Intersection for TransportPort { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - let protocol = if self.protocol == other.protocol { - Some(self.protocol) - } else { - None - }?; - let port = self.port.intersection(other.port)?; - Some(TransportPort { protocol, port }) - } -} - -/// Auto-implement `Intersection` for trivial cases where the logic should just check if -/// `self` is equal to `other`. -macro_rules! impl_intersection_partialeq { - ($ty:ty) => { - impl Intersection for $ty { - fn intersection(self, other: Self) -> Option<Self> { - if self == other { - Some(self) - } else { - None - } - } - } - }; -} -impl_intersection_partialeq!(u16); -impl_intersection_partialeq!(bool); -// FIXME: [`LocationConstraint`] deserves a hand-rolled implementation of [`Intersection`], but -// it would probably be best to implement it for [`ResolvedLocationConstraint`] instead to properly -// handle custom lists. -impl_intersection_partialeq!(LocationConstraint); -impl_intersection_partialeq!(Ownership); -impl_intersection_partialeq!(talpid_types::net::TransportProtocol); -impl_intersection_partialeq!(talpid_types::net::TunnelType); -impl_intersection_partialeq!(talpid_types::net::IpVersion); - #[allow(unused)] pub mod builder { - //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate pattern. + //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate + //! pattern. use mullvad_types::{ constraints::Constraint, relay_constraints::{ diff --git a/mullvad-relay-selector/tests/relay_selector.rs b/mullvad-relay-selector/tests/relay_selector.rs index 99a5491fad..ed3546c62e 100644 --- a/mullvad-relay-selector/tests/relay_selector.rs +++ b/mullvad-relay-selector/tests/relay_selector.rs @@ -201,10 +201,11 @@ fn default_relay_selector() -> RelaySelector { } /// This is not an actual test. Rather, it serves as a reminder that if [`RETRY_ORDER`] is modified, -/// the programmer should be made aware to update all external documents which rely on the retry order -/// to be correct. +/// the programmer should be made aware to update all external documents which rely on the retry +/// order to be correct. /// -/// When all necessary changes have been made, feel free to update this test to mirror the new [`RETRY_ORDER`]. +/// When all necessary changes have been made, feel free to update this test to mirror the new +/// [`RETRY_ORDER`]. #[test] fn assert_retry_order() { use talpid_types::net::{IpVersion, TransportProtocol}; @@ -341,8 +342,8 @@ fn prefer_wireguard_when_auto() { } } -/// If a Wireguard relay is only specified by it's hostname (and not tunnel type), the relay selector should -/// still return a relay of the correct tunnel type (Wireguard). +/// If a Wireguard relay is only specified by it's hostname (and not tunnel type), the relay +/// selector should still return a relay of the correct tunnel type (Wireguard). #[test] fn test_prefer_wireguard_if_location_supports_it() { let relay_selector = default_relay_selector(); @@ -361,8 +362,8 @@ fn test_prefer_wireguard_if_location_supports_it() { } } -/// If an OpenVPN relay is only specified by it's hostname (and not tunnel type), the relay selector should -/// still return a relay of the correct tunnel type (OpenVPN). +/// If an OpenVPN relay is only specified by it's hostname (and not tunnel type), the relay selector +/// should still return a relay of the correct tunnel type (OpenVPN). #[test] fn test_prefer_openvpn_if_location_supports_it() { let relay_selector = default_relay_selector(); @@ -381,9 +382,10 @@ fn test_prefer_openvpn_if_location_supports_it() { } } -/// Assert that the relay selector does *not* return a multihop configuration where the exit and entry relay are -/// the same, even if the constraints would allow for it. Also verify that the relay selector is smart enough to -/// pick either the entry or exit relay first depending on which one ends up yielding a valid configuration. +/// Assert that the relay selector does *not* return a multihop configuration where the exit and +/// entry relay are the same, even if the constraints would allow for it. Also verify that the relay +/// selector is smart enough to pick either the entry or exit relay first depending on which one +/// ends up yielding a valid configuration. #[test] fn test_wireguard_entry() { // Define a relay list containing exactly two Wireguard relays in Gothenburg. @@ -555,7 +557,8 @@ fn test_wireguard_entry_hostname_collision() { /// Test that the relay selector: /// * returns an OpenVPN relay given a constraint of a valid transport protocol + port combo -/// * does *not* return an OpenVPN relay given a constraint of an *invalid* transport protocol + port combo +/// * does *not* return an OpenVPN relay given a constraint of an *invalid* transport protocol + +/// port combo #[test] fn test_openvpn_constraints() { let relay_selector = default_relay_selector(); @@ -644,7 +647,8 @@ fn test_openvpn_constraints() { } } -/// Construct a query for multihop configuration and assert that the relay selector picks an accompanying entry relay. +/// Construct a query for multihop configuration and assert that the relay selector picks an +/// accompanying entry relay. #[test] fn test_selecting_wireguard_location_will_consider_multihop() { let relay_selector = default_relay_selector(); @@ -678,8 +682,9 @@ fn test_selecting_any_relay_will_consider_multihop() { } } -/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is selected and multihop is explicitly -/// turned off. Assert that the relay selector always return an obfuscator configuration. +/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is selected and +/// multihop is explicitly turned off. Assert that the relay selector always return an obfuscator +/// configuration. #[test] fn test_selecting_wireguard_endpoint_with_udp2tcp_obfuscation() { let relay_selector = default_relay_selector(); @@ -704,12 +709,14 @@ fn test_selecting_wireguard_endpoint_with_udp2tcp_obfuscation() { } } -/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is set to "Auto" and multihop is -/// explicitly turned off. Assert that the relay selector does *not* return an obfuscator config. +/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is set to "Auto" and +/// multihop is explicitly turned off. Assert that the relay selector does *not* return an +/// obfuscator config. /// /// # Note -/// This is a highly specific test which details how the relay selector should behave at the time of writing this test. -/// The cost (in latency primarily) of using obfuscation is deemed to be too high to enable it as an auto-configuration. +/// This is a highly specific test which details how the relay selector should behave at the time of +/// writing this test. The cost (in latency primarily) of using obfuscation is deemed to be too high +/// to enable it as an auto-configuration. #[test] fn test_selecting_wireguard_endpoint_with_auto_obfuscation() { let relay_selector = default_relay_selector(); diff --git a/mullvad-types/Cargo.toml b/mullvad-types/Cargo.toml index 8b614dd182..c4b4f76486 100644 --- a/mullvad-types/Cargo.toml +++ b/mullvad-types/Cargo.toml @@ -21,6 +21,8 @@ serde = { version = "1.0", features = ["derive"] } uuid = { version = "1.4.1", features = ["v4", "serde" ] } talpid-types = { path = "../talpid-types" } +intersection-derive = { path = "intersection-derive" } + clap = { workspace = true , optional = true } diff --git a/mullvad-types/intersection-derive/Cargo.toml b/mullvad-types/intersection-derive/Cargo.toml new file mode 100644 index 0000000000..f167c28aed --- /dev/null +++ b/mullvad-types/intersection-derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "intersection-derive" +description = "Derive macro for the `Intersection` trait" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +syn = "2" +quote = "1"
\ No newline at end of file diff --git a/mullvad-types/intersection-derive/src/lib.rs b/mullvad-types/intersection-derive/src/lib.rs new file mode 100644 index 0000000000..65e5af15bd --- /dev/null +++ b/mullvad-types/intersection-derive/src/lib.rs @@ -0,0 +1,68 @@ +//! This `proc-macro` crate exports the [`Intersection`] derive macro, see the trait documentation +//! for more information. +extern crate proc_macro; + +use proc_macro::TokenStream; +use syn::{parse_macro_input, DeriveInput}; + +/// Derive macro for the [`Intersection`] trait on structs. +#[proc_macro_derive(Intersection)] +pub fn intersection_derive(item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as DeriveInput); + + inner::derive(input).into() +} + +mod inner { + use proc_macro2::TokenStream; + use quote::{quote, TokenStreamExt}; + use syn::{spanned::Spanned, DeriveInput, Error}; + + pub(crate) fn derive(input: DeriveInput) -> TokenStream { + if let syn::Data::Struct(data) = &input.data { + derive_for_struct(&input, data).unwrap_or_else(Error::into_compile_error) + } else { + syn::Error::new( + input.span(), + "Deriving `Intersection` is only supported for structs", + ) + .into_compile_error() + } + } + + pub(crate) fn derive_for_struct( + input: &DeriveInput, + data: &syn::DataStruct, + ) -> syn::Result<TokenStream> { + let my_type = &input.ident; + let mut field_conversions = quote! {}; + for field in &data.fields { + let Some(name) = &field.ident else { + return Err(syn::Error::new( + field.span(), + "Tuple structs are not currently supported", + )); + }; + + // TODO(Sebastian): Here, and in the `quote` below, we are referring to `Intersection` + // with its relative name, which will fail if the user renames the trait + // when importing, e.g. `use mullvad_types::Intersection as SomethingElse`. + // This is a know limitation of procural macros (declarative macros can use the `$crate` + // syntax). If the issue arises then it can be solve using the + // <https://crates.io/crates/proc-macro-crate> crate. Add it if necessary. + field_conversions.append_all(quote! { + #name: Intersection::intersection(self.#name, other.#name)?, + }) + } + + Ok(quote! { + impl Intersection for #my_type { + fn intersection(self, other: Self) -> ::core::option::Option<Self> { + ::core::option::Option::Some(Self { + #field_conversions + }) + } + } + }) + } +} diff --git a/mullvad-types/src/constraints/constraint.rs b/mullvad-types/src/constraints/constraint.rs index c1669e33ae..26e1df3c6c 100644 --- a/mullvad-types/src/constraints/constraint.rs +++ b/mullvad-types/src/constraints/constraint.rs @@ -3,8 +3,9 @@ #[cfg(target_os = "android")] use jnix::{FromJava, IntoJava}; use serde::{Deserialize, Serialize}; -use std::fmt; -use std::str::FromStr; +use std::{fmt, str::FromStr}; + +use crate::Intersection; /// Limits the set of [`crate::relay_list::Relay`]s that a `RelaySelector` may select. #[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] @@ -17,6 +18,25 @@ pub enum Constraint<T> { Only(T), } +impl<T: Intersection> Intersection for Constraint<T> { + /// Define the intersection between two arbitrary [`Constraint`]s. + /// + /// This operation may be compared to the set operation with the same name. + /// In contrast to the general set intersection, this function represents a + /// very specific case where [`Constraint::Any`] is equivalent to the set + /// universe and [`Constraint::Only`] represents a singleton set. Notable is + /// that the representation of any empty set is [`Option::None`]. + fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> { + use Constraint::*; + match (self, other) { + (Any, Any) => Some(Any), + (Only(t), Any) | (Any, Only(t)) => Some(Only(t)), + // Pick any of `left` or `right` if they are the same. + (Only(left), Only(right)) => left.intersection(right).map(Only), + } + } +} + impl<T: fmt::Display> fmt::Display for Constraint<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { diff --git a/mullvad-types/src/constraints/mod.rs b/mullvad-types/src/constraints/mod.rs index 4caa31e901..756066c447 100644 --- a/mullvad-types/src/constraints/mod.rs +++ b/mullvad-types/src/constraints/mod.rs @@ -5,10 +5,7 @@ mod constraint; // Re-export bits & pieces from `constraints.rs` as needed pub use constraint::Constraint; -/// A limited variant of Sets. -pub trait Set<T> { - fn is_subset(&self, other: &T) -> bool; -} +use crate::relay_constraints; pub trait Match<T> { fn matches(&self, other: &T) -> bool; @@ -22,14 +19,124 @@ impl<T: Match<U>, U> Match<U> for Constraint<T> { } } -impl<T: Set<U>, U> Set<Constraint<U>> for Constraint<T> { - fn is_subset(&self, other: &Constraint<U>) -> bool { - match self { - Constraint::Any => other.is_any(), - Constraint::Only(ref constraint) => match other { - Constraint::Only(ref other_constraint) => constraint.is_subset(other_constraint), - _ => true, - }, +/// The intersection of two sets of criteria on [`Relay`](crate::relay_list::Relay)s is another +/// criteria which matches the given relay iff both of the original criteria matched. It is +/// primarily used by the relay selector to check whether a given connection method is compatible +/// with the users settings. +/// +/// # Examples +/// +/// The [`Intersection`] implementation of [`RelayQuery`] upholds the following properties: +/// +/// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned: +/// ```rust, ignore +/// # use mullvad_relay_selector::query::builder::RelayQueryBuilder; +/// let query_a = RelayQueryBuilder::new().wireguard().build(); +/// let query_b = RelayQueryBuilder::new().openvpn().build(); +/// assert_eq!(query_a.intersection(query_b), None); +/// ``` +/// +/// * Otherwise, a new [`RelayQuery`] is returned where each constraint is +/// as specific as possible. See [`Constraint`] for further details. +/// ```rust, ignore +/// let query_a = RelayQueryBuilder::new().wireguard().build(); +/// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build(); +/// +/// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap()); +/// assert!(result.is_ok()); +/// ``` +/// +/// This way, if the mullvad app wants to check if the user's relay settings +/// are compatible with any other [`RelayQuery`], for examples those defined by +/// [`RETRY_ORDER`] , taking the intersection between them will never result in +/// a situation where the app can override the user's preferences. +/// +/// [`RETRY_ORDER`]: crate::RETRY_ORDER +/// +/// The macro recursively applies the intersection on each field of the struct and returns the +/// resulting type or `None` if any of the intersections failed to overlap. +/// +/// The macro requires the types of each field to also implement [`Intersection`], which may be done +/// using this derive macro, the +/// +/// # Implementing [`Intersection`] +/// +/// For structs where each field already implements `Intersection`, the easiest way to implement the +/// trait is using the derive macro. Using the derive macro on [`RelayQuery`] +/// ```rust, ignore +/// #[derive(Intersection)] +/// struct RelayQuery { +/// pub location: Constraint<LocationConstraint>, +/// pub providers: Constraint<Providers>, +/// pub ownership: Constraint<Ownership>, +/// pub tunnel_protocol: Constraint<TunnelType>, +/// pub wireguard_constraints: WireguardRelayQuery, +/// pub openvpn_constraints: OpenVpnRelayQuery, +/// } +/// ``` +/// +/// produces an implementation like this: +/// +/// ```rust, ignore +/// impl Intersection for RelayQuery { +/// fn intersection(self, other: Self) -> Option<Self> +/// where +/// Self: PartialEq, +/// Self: Sized, +/// { +/// Some(RelayQuery { +/// location: self.location.intersection(other.location)?, +/// providers: self.providers.intersection(other.providers)?, +/// ownership: self.ownership.intersection(other.ownership)?, +/// tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?, +/// wireguard_constraints: self +/// .wireguard_constraints +/// .intersection(other.wireguard_constraints)?, +/// openvpn_constraints: self +/// .openvpn_constraints +/// .intersection(other.openvpn_constraints)?, +/// }) +/// } +/// } +/// ``` +/// +/// For types that cannot "overlap", e.g. they only intersect if they are equal, the declarative +/// macro [`impl_intersection_partialeq`] can be used. +/// +/// For less trivial cases, the trait needs to be implemented manually. When doing so, make sure +/// that the following properties are upheld: +/// +/// - idempotency (if there is an identity element) +/// - commutativity +/// - associativity +pub trait Intersection: Sized { + fn intersection(self, other: Self) -> Option<Self>; +} + +#[macro_export] +macro_rules! impl_intersection_partialeq { + ($ty:ty) => { + impl $crate::Intersection for $ty { + fn intersection(self, other: Self) -> Option<Self> { + if self == other { + Some(self) + } else { + None + } + } } - } + }; } + +impl_intersection_partialeq!(u16); +impl_intersection_partialeq!(bool); + +// NOTE: this implementation does not do what you may expect of an intersection +impl_intersection_partialeq!(relay_constraints::Providers); +// NOTE: should take actual intersection +impl_intersection_partialeq!(relay_constraints::LocationConstraint); +impl_intersection_partialeq!(relay_constraints::Ownership); +// NOTE: it contains an inner constraint +impl_intersection_partialeq!(talpid_types::net::TransportProtocol); +impl_intersection_partialeq!(talpid_types::net::TunnelType); +impl_intersection_partialeq!(talpid_types::net::IpVersion); diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs index 231f80300b..d1a50fa0ea 100644 --- a/mullvad-types/src/lib.rs +++ b/mullvad-types/src/lib.rs @@ -21,3 +21,6 @@ pub use crate::custom_tunnel::*; pub const TUNNEL_TABLE_ID: u32 = 0x6d6f6c65; #[cfg(target_os = "linux")] pub const TUNNEL_FWMARK: u32 = 0x6d6f6c65; + +pub use constraints::Intersection; +pub use intersection_derive::Intersection; diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 05d07477be..9b6b640cf5 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -2,11 +2,11 @@ //! updated as well. use crate::{ - constraints::{Constraint, Match, Set}, + constraints::{Constraint, Match}, custom_list::{CustomListsSettings, Id}, location::{CityCode, CountryCode, Hostname}, relay_list::Relay, - CustomTunnelEndpoint, + CustomTunnelEndpoint, Intersection, }; #[cfg(target_os = "android")] use jnix::{jni::objects::JObject, FromJava, IntoJava, JnixEnv}; @@ -294,51 +294,6 @@ impl Match<Relay> for GeographicLocationConstraint { } } -impl Set<GeographicLocationConstraint> for GeographicLocationConstraint { - /// Returns whether `self` is equal to or a subset of `other`. - fn is_subset(&self, other: &Self) -> bool { - match self { - GeographicLocationConstraint::Country(_) => self == other, - GeographicLocationConstraint::City(ref country, ref _city) => match other { - GeographicLocationConstraint::Country(ref other_country) => { - country == other_country - } - GeographicLocationConstraint::City(..) => self == other, - _ => false, - }, - GeographicLocationConstraint::Hostname(ref country, ref city, ref _hostname) => { - match other { - GeographicLocationConstraint::Country(ref other_country) => { - country == other_country - } - GeographicLocationConstraint::City(ref other_country, ref other_city) => { - country == other_country && city == other_city - } - GeographicLocationConstraint::Hostname(..) => self == other, - } - } - } - } -} - -impl Set<Constraint<Vec<GeographicLocationConstraint>>> - for Constraint<Vec<GeographicLocationConstraint>> -{ - fn is_subset(&self, other: &Self) -> bool { - match self { - Constraint::Any => other.is_any(), - Constraint::Only(locations) => match other { - Constraint::Any => true, - Constraint::Only(other_locations) => locations.iter().all(|location| { - other_locations - .iter() - .any(|other_location| location.is_subset(other_location)) - }), - }, - } - } -} - /// Limits the set of servers to choose based on ownership. #[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)] #[cfg_attr(target_os = "android", derive(IntoJava, FromJava))] @@ -463,7 +418,7 @@ impl fmt::Display for GeographicLocationConstraint { } } -#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] pub struct TransportPort { pub protocol: TransportProtocol, pub port: Constraint<u16>, @@ -646,6 +601,21 @@ pub enum SelectedObfuscation { Udp2Tcp, } +impl Intersection for SelectedObfuscation { + fn intersection(self, other: Self) -> Option<Self> + where + Self: PartialEq, + Self: Sized, + { + match (self, other) { + (left, SelectedObfuscation::Auto) => Some(left), + (SelectedObfuscation::Auto, right) => Some(right), + (left, right) if left == right => Some(left), + _ => None, + } + } +} + impl fmt::Display for SelectedObfuscation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -656,7 +626,7 @@ impl fmt::Display for SelectedObfuscation { } } -#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] #[serde(rename_all = "snake_case")] @@ -716,7 +686,7 @@ pub struct ObfuscationSettings { } /// Limits the set of bridge servers to use in `mullvad-daemon`. -#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] #[serde(default)] #[serde(rename_all = "snake_case")] pub struct BridgeConstraints { diff --git a/test/Cargo.lock b/test/Cargo.lock index 5f207d9a34..eb38b5f37a 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -1380,6 +1380,15 @@ dependencies = [ ] [[package]] +name = "intersection-derive" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + +[[package]] name = "inventory" version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1814,6 +1823,7 @@ name = "mullvad-relay-selector" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork 0.16.0", "itertools 0.12.1", "log", @@ -1830,6 +1840,7 @@ name = "mullvad-types" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork 0.16.0", "jnix", "log", |
