diff options
| -rw-r--r-- | Cargo.lock | 11 | ||||
| -rw-r--r-- | Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-relay-selector/Cargo.toml | 1 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/mod.rs | 4 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/relay_selector/query.rs | 277 | ||||
| -rw-r--r-- | mullvad-types/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-types/intersection-derive/Cargo.toml | 16 | ||||
| -rw-r--r-- | mullvad-types/intersection-derive/src/lib.rs | 68 | ||||
| -rw-r--r-- | mullvad-types/src/constraints/constraint.rs | 24 | ||||
| -rw-r--r-- | mullvad-types/src/constraints/mod.rs | 128 | ||||
| -rw-r--r-- | mullvad-types/src/lib.rs | 3 | ||||
| -rw-r--r-- | mullvad-types/src/relay_constraints.rs | 23 | ||||
| -rw-r--r-- | test/Cargo.lock | 11 |
13 files changed, 285 insertions, 284 deletions
diff --git a/Cargo.lock b/Cargo.lock index 25b7582fb6..cc54dcb168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1554,6 +1554,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc6d6206008e25125b1f97fbe5d309eb7b85141cf9199d52dbd3729a1584dd16" [[package]] +name = "intersection-derive" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + +[[package]] name = "ioctl-sys" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2109,6 +2118,7 @@ name = "mullvad-relay-selector" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork", "itertools 0.12.1", "log", @@ -2147,6 +2157,7 @@ version = "0.0.0" dependencies = [ "chrono", "clap", + "intersection-derive", "ipnetwork", "jnix", "log", diff --git a/Cargo.toml b/Cargo.toml index d760c77f6b..b0c28f016a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "mullvad-relay-selector", "mullvad-setup", "mullvad-types", + "mullvad-types/intersection-derive", "mullvad-version", "talpid-core", "talpid-dbus", diff --git a/mullvad-relay-selector/Cargo.toml b/mullvad-relay-selector/Cargo.toml index ec8943df0b..3224fd7983 100644 --- a/mullvad-relay-selector/Cargo.toml +++ b/mullvad-relay-selector/Cargo.toml @@ -22,6 +22,7 @@ serde_json = "1.0" talpid-types = { path = "../talpid-types" } mullvad-types = { path = "../mullvad-types" } +intersection-derive = { path = "../mullvad-types/intersection-derive"} [dev-dependencies] proptest = { workspace = true } diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs index e0949beae9..cbb3473f7f 100644 --- a/mullvad-relay-selector/src/relay_selector/mod.rs +++ b/mullvad-relay-selector/src/relay_selector/mod.rs @@ -28,7 +28,7 @@ use mullvad_types::{ }, relay_list::{Relay, RelayEndpointData, RelayList}, settings::Settings, - CustomTunnelEndpoint, + CustomTunnelEndpoint, Intersection, }; use talpid_types::{ net::{ @@ -43,7 +43,7 @@ use self::{ detailer::{openvpn_endpoint, wireguard_endpoint}, matcher::{filter_matching_bridges, filter_matching_relay_list}, parsed_relays::ParsedRelays, - query::{BridgeQuery, Intersection, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery}, + query::{BridgeQuery, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery}, }; /// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should diff --git a/mullvad-relay-selector/src/relay_selector/query.rs b/mullvad-relay-selector/src/relay_selector/query.rs index 0536379dbc..f112e83968 100644 --- a/mullvad-relay-selector/src/relay_selector/query.rs +++ b/mullvad-relay-selector/src/relay_selector/query.rs @@ -35,6 +35,7 @@ use mullvad_types::{ RelayConstraints, SelectedObfuscation, TransportPort, Udp2TcpObfuscationSettings, WireguardConstraints, }, + Intersection, }; use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType}; @@ -70,7 +71,7 @@ use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType}; /// This example demonstrates creating a `RelayQuery` which can then be passed /// to the [`crate::RelaySelector`] to find a relay that matches the criteria. /// See [`builder`] for more info on how to construct queries. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct RelayQuery { pub location: Constraint<LocationConstraint>, pub providers: Constraint<Providers>, @@ -88,7 +89,7 @@ impl RelayQuery { /// Note that the following identity applies for any `other_query`: /// ```rust /// # use mullvad_relay_selector::query::RelayQuery; - /// # use crate::mullvad_relay_selector::query::Intersection; + /// # use mullvad_types::Intersection; /// /// # let other_query = RelayQuery::new(); /// assert_eq!(RelayQuery::new().intersection(other_query.clone()), Some(other_query)); @@ -107,105 +108,6 @@ impl RelayQuery { } } -impl Intersection for RelayQuery { - /// Return a new [`RelayQuery`] which matches the intersected queries. - /// - /// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned: - /// ```rust - /// # use mullvad_relay_selector::query::builder::RelayQueryBuilder; - /// # use crate::mullvad_relay_selector::query::Intersection; - /// let query_a = RelayQueryBuilder::new().wireguard().build(); - /// let query_b = RelayQueryBuilder::new().openvpn().build(); - /// assert_eq!(query_a.intersection(query_b), None); - /// ``` - /// - /// * Otherwise, a new [`RelayQuery`] is returned where each constraint is - /// as specific as possible. See [`Constraint`] for further details. - /// ```rust - /// # use crate::mullvad_relay_selector::*; - /// # use crate::mullvad_relay_selector::query::*; - /// # use crate::mullvad_relay_selector::query::builder::*; - /// # use mullvad_types::relay_list::*; - /// # use talpid_types::net::wireguard::PublicKey; - /// - /// // The relay list used by `relay_selector` in this example - /// let relay_list = RelayList { - /// # etag: None, - /// # openvpn: OpenVpnEndpointData { ports: vec![] }, - /// # bridge: BridgeEndpointData { - /// # shadowsocks: vec![], - /// # }, - /// # wireguard: WireguardEndpointData { - /// # port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)], - /// # ipv4_gateway: "10.64.0.1".parse().unwrap(), - /// # ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(), - /// # udp2tcp_ports: vec![], - /// # }, - /// countries: vec![RelayListCountry { - /// name: "Sweden".to_string(), - /// # code: "Sweden".to_string(), - /// cities: vec![RelayListCity { - /// name: "Gothenburg".to_string(), - /// # code: "Gothenburg".to_string(), - /// # latitude: 57.70887, - /// # longitude: 11.97456, - /// relays: vec![Relay { - /// hostname: "se9-wireguard".to_string(), - /// ipv4_addr_in: "185.213.154.68".parse().unwrap(), - /// # ipv6_addr_in: Some("2a03:1b20:5:f011::a09f".parse().unwrap()), - /// # include_in_country: false, - /// # active: true, - /// # owned: true, - /// # provider: "31173".to_string(), - /// # weight: 1, - /// # endpoint_data: RelayEndpointData::Wireguard(WireguardRelayEndpointData { - /// # public_key: PublicKey::from_base64( - /// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=", - /// # ) - /// # .unwrap(), - /// # }), - /// # location: None, - /// }], - /// }], - /// }], - /// }; - /// - /// # let relay_selector = RelaySelector::from_list(SelectorConfig::default(), relay_list.clone()); - /// # let city = |country, city| GeographicLocationConstraint::city(country, city); - /// - /// let query_a = RelayQueryBuilder::new().wireguard().build(); - /// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build(); - /// - /// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap()); - /// assert!(result.is_ok()); - /// ``` - /// - /// This way, if the mullvad app wants to check if the user's relay settings - /// are compatible with any other [`RelayQuery`], for examples those defined by - /// [`RETRY_ORDER`] , taking the intersection between them will never result in - /// a situation where the app can override the user's preferences. - /// - /// [`RETRY_ORDER`]: crate::RETRY_ORDER - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(RelayQuery { - location: self.location.intersection(other.location)?, - providers: self.providers.intersection(other.providers)?, - ownership: self.ownership.intersection(other.ownership)?, - tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?, - wireguard_constraints: self - .wireguard_constraints - .intersection(other.wireguard_constraints)?, - openvpn_constraints: self - .openvpn_constraints - .intersection(other.openvpn_constraints)?, - }) - } -} - impl From<RelayQuery> for RelayConstraints { /// The mapping from [`RelayQuery`] to [`RelayConstraints`]. fn from(value: RelayQuery) -> Self { @@ -229,7 +131,7 @@ impl From<RelayQuery> for RelayConstraints { /// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner. /// Notice that [obfuscation][`SelectedObfuscation`] is not a [`Constraint`], but it is trivial /// to define [`Intersection`] on it, so it is fine. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct WireguardRelayQuery { pub port: Constraint<u16>, pub ip_version: Constraint<IpVersion>, @@ -257,37 +159,6 @@ impl WireguardRelayQuery { } } } -impl Intersection for WireguardRelayQuery { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(WireguardRelayQuery { - port: self.port.intersection(other.port)?, - ip_version: self.ip_version.intersection(other.ip_version)?, - use_multihop: self.use_multihop.intersection(other.use_multihop)?, - entry_location: self.entry_location.intersection(other.entry_location)?, - obfuscation: self.obfuscation.intersection(other.obfuscation)?, - udp2tcp_port: self.udp2tcp_port.intersection(other.udp2tcp_port)?, - }) - } -} - -impl Intersection for SelectedObfuscation { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - match (self, other) { - (left, SelectedObfuscation::Auto) => Some(left), - (SelectedObfuscation::Auto, right) => Some(right), - (left, right) if left == right => Some(left), - _ => None, - } - } -} impl From<WireguardRelayQuery> for WireguardConstraints { /// The mapping from [`WireguardRelayQuery`] to [`WireguardConstraints`]. @@ -307,7 +178,7 @@ impl From<WireguardRelayQuery> for WireguardConstraints { /// This struct is meant to be that type in the "universe of relay queries". The difference /// between them may seem subtle, but in a [`OpenVpnRelayQuery`] every field is represented /// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Intersection)] pub struct OpenVpnRelayQuery { pub port: Constraint<TransportPort>, pub bridge_settings: Constraint<BridgeQuery>, @@ -322,29 +193,8 @@ impl OpenVpnRelayQuery { } } -impl Intersection for OpenVpnRelayQuery { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - let bridge_settings = { - match (self.bridge_settings, other.bridge_settings) { - // Recursive case - (Constraint::Only(left), Constraint::Only(right)) => { - Constraint::Only(left.intersection(right)?) - } - (left, right) => left.intersection(right)?, - } - }; - Some(OpenVpnRelayQuery { - port: self.port.intersection(other.port)?, - bridge_settings, - }) - } -} - -/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay queries". +/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay +/// queries". /// /// [`BridgeState`]: mullvad_types::relay_constraints::BridgeState /// [`BridgeSettings`]: mullvad_types::relay_constraints::BridgeSettings @@ -364,7 +214,7 @@ pub enum BridgeQuery { } impl BridgeQuery { - ///If `bridge_constraints` is `Any`, bridges should not be used due to + /// If `bridge_constraints` is `Any`, bridges should not be used due to /// latency concerns. /// /// If `bridge_constraints` is `Only(settings)`, then `settings` will be @@ -402,20 +252,6 @@ impl Intersection for BridgeQuery { } } -impl Intersection for BridgeConstraints { - fn intersection(self, other: Self) -> Option<Self> - where - Self: PartialEq, - Self: Sized, - { - Some(BridgeConstraints { - location: self.location.intersection(other.location)?, - providers: self.providers.intersection(other.providers)?, - ownership: self.ownership.intersection(other.ownership)?, - }) - } -} - impl From<OpenVpnRelayQuery> for OpenVpnConstraints { /// The mapping from [`OpenVpnRelayQuery`] to [`OpenVpnConstraints`]. fn from(value: OpenVpnRelayQuery) -> Self { @@ -423,103 +259,10 @@ impl From<OpenVpnRelayQuery> for OpenVpnConstraints { } } -/// Any type that wish to implement `Intersection` should make sure that the -/// following properties are upheld: -/// -/// - idempotency (if there is an identity element) -/// - commutativity -/// - associativity -pub trait Intersection { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized; -} - -impl<T: Intersection> Intersection for Constraint<T> { - /// Define the intersection between two arbitrary [`Constraint`]s. - /// - /// This operation may be compared to the set operation with the same name. - /// In contrast to the general set intersection, this function represents a - /// very specific case where [`Constraint::Any`] is equivalent to the set - /// universe and [`Constraint::Only`] represents a singleton set. Notable is - /// that the representation of any empty set is [`Option::None`]. - fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> { - use Constraint::*; - match (self, other) { - (Any, Any) => Some(Any), - (Only(t), Any) | (Any, Only(t)) => Some(Only(t)), - // Recurse on `left` and `right` to see if there exist an intersection - (Only(left), Only(right)) => Some(Only(left.intersection(right)?)), - } - } -} - -// Implement `Intersection` for different types - -impl Intersection for Providers { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - Providers::new(self.providers().intersection(other.providers())).ok() - } -} - -impl Intersection for Udp2TcpObfuscationSettings { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - Some(Udp2TcpObfuscationSettings { - port: self.port.intersection(other.port)?, - }) - } -} - -impl Intersection for TransportPort { - fn intersection(self, other: Self) -> Option<Self> - where - Self: Sized, - { - let protocol = if self.protocol == other.protocol { - Some(self.protocol) - } else { - None - }?; - let port = self.port.intersection(other.port)?; - Some(TransportPort { protocol, port }) - } -} - -/// Auto-implement `Intersection` for trivial cases where the logic should just check if -/// `self` is equal to `other`. -macro_rules! impl_intersection_partialeq { - ($ty:ty) => { - impl Intersection for $ty { - fn intersection(self, other: Self) -> Option<Self> { - if self == other { - Some(self) - } else { - None - } - } - } - }; -} -impl_intersection_partialeq!(u16); -impl_intersection_partialeq!(bool); -// FIXME: [`LocationConstraint`] deserves a hand-rolled implementation of [`Intersection`], but -// it would probably be best to implement it for [`ResolvedLocationConstraint`] instead to properly -// handle custom lists. -impl_intersection_partialeq!(LocationConstraint); -impl_intersection_partialeq!(Ownership); -impl_intersection_partialeq!(talpid_types::net::TransportProtocol); -impl_intersection_partialeq!(talpid_types::net::TunnelType); -impl_intersection_partialeq!(talpid_types::net::IpVersion); - #[allow(unused)] pub mod builder { - //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate pattern. + //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate + //! pattern. use mullvad_types::{ constraints::Constraint, relay_constraints::{ diff --git a/mullvad-types/Cargo.toml b/mullvad-types/Cargo.toml index 8b614dd182..c4b4f76486 100644 --- a/mullvad-types/Cargo.toml +++ b/mullvad-types/Cargo.toml @@ -21,6 +21,8 @@ serde = { version = "1.0", features = ["derive"] } uuid = { version = "1.4.1", features = ["v4", "serde" ] } talpid-types = { path = "../talpid-types" } +intersection-derive = { path = "intersection-derive" } + clap = { workspace = true , optional = true } diff --git a/mullvad-types/intersection-derive/Cargo.toml b/mullvad-types/intersection-derive/Cargo.toml new file mode 100644 index 0000000000..f167c28aed --- /dev/null +++ b/mullvad-types/intersection-derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "intersection-derive" +description = "Derive macro for the `Intersection` trait" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +syn = "2" +quote = "1"
\ No newline at end of file diff --git a/mullvad-types/intersection-derive/src/lib.rs b/mullvad-types/intersection-derive/src/lib.rs new file mode 100644 index 0000000000..65e5af15bd --- /dev/null +++ b/mullvad-types/intersection-derive/src/lib.rs @@ -0,0 +1,68 @@ +//! This `proc-macro` crate exports the [`Intersection`] derive macro, see the trait documentation +//! for more information. +extern crate proc_macro; + +use proc_macro::TokenStream; +use syn::{parse_macro_input, DeriveInput}; + +/// Derive macro for the [`Intersection`] trait on structs. +#[proc_macro_derive(Intersection)] +pub fn intersection_derive(item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as DeriveInput); + + inner::derive(input).into() +} + +mod inner { + use proc_macro2::TokenStream; + use quote::{quote, TokenStreamExt}; + use syn::{spanned::Spanned, DeriveInput, Error}; + + pub(crate) fn derive(input: DeriveInput) -> TokenStream { + if let syn::Data::Struct(data) = &input.data { + derive_for_struct(&input, data).unwrap_or_else(Error::into_compile_error) + } else { + syn::Error::new( + input.span(), + "Deriving `Intersection` is only supported for structs", + ) + .into_compile_error() + } + } + + pub(crate) fn derive_for_struct( + input: &DeriveInput, + data: &syn::DataStruct, + ) -> syn::Result<TokenStream> { + let my_type = &input.ident; + let mut field_conversions = quote! {}; + for field in &data.fields { + let Some(name) = &field.ident else { + return Err(syn::Error::new( + field.span(), + "Tuple structs are not currently supported", + )); + }; + + // TODO(Sebastian): Here, and in the `quote` below, we are referring to `Intersection` + // with its relative name, which will fail if the user renames the trait + // when importing, e.g. `use mullvad_types::Intersection as SomethingElse`. + // This is a know limitation of procural macros (declarative macros can use the `$crate` + // syntax). If the issue arises then it can be solve using the + // <https://crates.io/crates/proc-macro-crate> crate. Add it if necessary. + field_conversions.append_all(quote! { + #name: Intersection::intersection(self.#name, other.#name)?, + }) + } + + Ok(quote! { + impl Intersection for #my_type { + fn intersection(self, other: Self) -> ::core::option::Option<Self> { + ::core::option::Option::Some(Self { + #field_conversions + }) + } + } + }) + } +} diff --git a/mullvad-types/src/constraints/constraint.rs b/mullvad-types/src/constraints/constraint.rs index c1669e33ae..26e1df3c6c 100644 --- a/mullvad-types/src/constraints/constraint.rs +++ b/mullvad-types/src/constraints/constraint.rs @@ -3,8 +3,9 @@ #[cfg(target_os = "android")] use jnix::{FromJava, IntoJava}; use serde::{Deserialize, Serialize}; -use std::fmt; -use std::str::FromStr; +use std::{fmt, str::FromStr}; + +use crate::Intersection; /// Limits the set of [`crate::relay_list::Relay`]s that a `RelaySelector` may select. #[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] @@ -17,6 +18,25 @@ pub enum Constraint<T> { Only(T), } +impl<T: Intersection> Intersection for Constraint<T> { + /// Define the intersection between two arbitrary [`Constraint`]s. + /// + /// This operation may be compared to the set operation with the same name. + /// In contrast to the general set intersection, this function represents a + /// very specific case where [`Constraint::Any`] is equivalent to the set + /// universe and [`Constraint::Only`] represents a singleton set. Notable is + /// that the representation of any empty set is [`Option::None`]. + fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> { + use Constraint::*; + match (self, other) { + (Any, Any) => Some(Any), + (Only(t), Any) | (Any, Only(t)) => Some(Only(t)), + // Pick any of `left` or `right` if they are the same. + (Only(left), Only(right)) => left.intersection(right).map(Only), + } + } +} + impl<T: fmt::Display> fmt::Display for Constraint<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { diff --git a/mullvad-types/src/constraints/mod.rs b/mullvad-types/src/constraints/mod.rs index 4caa31e901..8ee6e93195 100644 --- a/mullvad-types/src/constraints/mod.rs +++ b/mullvad-types/src/constraints/mod.rs @@ -22,14 +22,124 @@ impl<T: Match<U>, U> Match<U> for Constraint<T> { } } -impl<T: Set<U>, U> Set<Constraint<U>> for Constraint<T> { - fn is_subset(&self, other: &Constraint<U>) -> bool { - match self { - Constraint::Any => other.is_any(), - Constraint::Only(ref constraint) => match other { - Constraint::Only(ref other_constraint) => constraint.is_subset(other_constraint), - _ => true, - }, +/// The intersection of two sets of criteria on [`Relay`](crate::relay_list::Relay)s is another +/// criteria which matches the given relay iff both of the original criteria matched. It is +/// primarily used by the relay selector to check whether a given connection method is compatible +/// with the users settings. +/// +/// # Examples +/// +/// The [`Intersection`] implementation of [`RelayQuery`] upholds the following properties: +/// +/// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned: +/// ```rust, ignore +/// # use mullvad_relay_selector::query::builder::RelayQueryBuilder; +/// let query_a = RelayQueryBuilder::new().wireguard().build(); +/// let query_b = RelayQueryBuilder::new().openvpn().build(); +/// assert_eq!(query_a.intersection(query_b), None); +/// ``` +/// +/// * Otherwise, a new [`RelayQuery`] is returned where each constraint is +/// as specific as possible. See [`Constraint`] for further details. +/// ```rust, ignore +/// let query_a = RelayQueryBuilder::new().wireguard().build(); +/// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build(); +/// +/// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap()); +/// assert!(result.is_ok()); +/// ``` +/// +/// This way, if the mullvad app wants to check if the user's relay settings +/// are compatible with any other [`RelayQuery`], for examples those defined by +/// [`RETRY_ORDER`] , taking the intersection between them will never result in +/// a situation where the app can override the user's preferences. +/// +/// [`RETRY_ORDER`]: crate::RETRY_ORDER +/// +/// The macro recursively applies the intersection on each field of the struct and returns the +/// resulting type or `None` if any of the intersections failed to overlap. +/// +/// The macro requires the types of each field to also implement [`Intersection`], which may be done +/// using this derive macro, the +/// +/// # Implementing [`Intersection`] +/// +/// For structs where each field already implements `Intersection`, the easiest way to implement the +/// trait is using the derive macro. Using the derive macro on [`RelayQuery`] +/// ```rust, ignore +/// #[derive(Intersection)] +/// struct RelayQuery { +/// pub location: Constraint<LocationConstraint>, +/// pub providers: Constraint<Providers>, +/// pub ownership: Constraint<Ownership>, +/// pub tunnel_protocol: Constraint<TunnelType>, +/// pub wireguard_constraints: WireguardRelayQuery, +/// pub openvpn_constraints: OpenVpnRelayQuery, +/// } +/// ``` +/// +/// produces an implementation like this: +/// +/// ```rust, ignore +/// impl Intersection for RelayQuery { +/// fn intersection(self, other: Self) -> Option<Self> +/// where +/// Self: PartialEq, +/// Self: Sized, +/// { +/// Some(RelayQuery { +/// location: self.location.intersection(other.location)?, +/// providers: self.providers.intersection(other.providers)?, +/// ownership: self.ownership.intersection(other.ownership)?, +/// tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?, +/// wireguard_constraints: self +/// .wireguard_constraints +/// .intersection(other.wireguard_constraints)?, +/// openvpn_constraints: self +/// .openvpn_constraints +/// .intersection(other.openvpn_constraints)?, +/// }) +/// } +/// } +/// ``` +/// +/// For types that cannot "overlap", e.g. they only intersect if they are equal, the declarative +/// macro [`impl_intersection_partialeq`] can be used. +/// +/// For less trivial cases, the trait needs to be implemented manually. When doing so, make sure +/// that the following properties are upheld: +/// +/// - idempotency (if there is an identity element) +/// - commutativity +/// - associativity +pub trait Intersection: Sized { + fn intersection(self, other: Self) -> Option<Self>; +} + +#[macro_export] +macro_rules! impl_intersection_partialeq { + ($ty:ty) => { + impl $crate::Intersection for $ty { + fn intersection(self, other: Self) -> Option<Self> { + if self == other { + Some(self) + } else { + None + } + } } - } + }; } + +impl_intersection_partialeq!(u16); +impl_intersection_partialeq!(bool); + +// NOTE: this implementation does not do what you may expect of an intersection +impl_intersection_partialeq!(relay_constraints::Providers); +// NOTE: should take actual intersection +impl_intersection_partialeq!(relay_constraints::LocationConstraint); +impl_intersection_partialeq!(relay_constraints::Ownership); +// NOTE: it contains an inner constraint +impl_intersection_partialeq!(talpid_types::net::TransportProtocol); +impl_intersection_partialeq!(talpid_types::net::TunnelType); +impl_intersection_partialeq!(talpid_types::net::IpVersion); diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs index 231f80300b..d1a50fa0ea 100644 --- a/mullvad-types/src/lib.rs +++ b/mullvad-types/src/lib.rs @@ -21,3 +21,6 @@ pub use crate::custom_tunnel::*; pub const TUNNEL_TABLE_ID: u32 = 0x6d6f6c65; #[cfg(target_os = "linux")] pub const TUNNEL_FWMARK: u32 = 0x6d6f6c65; + +pub use constraints::Intersection; +pub use intersection_derive::Intersection; diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 05d07477be..dc6d0608e5 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -6,7 +6,7 @@ use crate::{ custom_list::{CustomListsSettings, Id}, location::{CityCode, CountryCode, Hostname}, relay_list::Relay, - CustomTunnelEndpoint, + CustomTunnelEndpoint, Intersection, }; #[cfg(target_os = "android")] use jnix::{jni::objects::JObject, FromJava, IntoJava, JnixEnv}; @@ -463,7 +463,7 @@ impl fmt::Display for GeographicLocationConstraint { } } -#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] pub struct TransportPort { pub protocol: TransportProtocol, pub port: Constraint<u16>, @@ -646,6 +646,21 @@ pub enum SelectedObfuscation { Udp2Tcp, } +impl Intersection for SelectedObfuscation { + fn intersection(self, other: Self) -> Option<Self> + where + Self: PartialEq, + Self: Sized, + { + match (self, other) { + (left, SelectedObfuscation::Auto) => Some(left), + (SelectedObfuscation::Auto, right) => Some(right), + (left, right) if left == right => Some(left), + _ => None, + } + } +} + impl fmt::Display for SelectedObfuscation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -656,7 +671,7 @@ impl fmt::Display for SelectedObfuscation { } } -#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] #[serde(rename_all = "snake_case")] @@ -716,7 +731,7 @@ pub struct ObfuscationSettings { } /// Limits the set of bridge servers to use in `mullvad-daemon`. -#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)] #[serde(default)] #[serde(rename_all = "snake_case")] pub struct BridgeConstraints { diff --git a/test/Cargo.lock b/test/Cargo.lock index 5f207d9a34..eb38b5f37a 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -1380,6 +1380,15 @@ dependencies = [ ] [[package]] +name = "intersection-derive" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + +[[package]] name = "inventory" version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1814,6 +1823,7 @@ name = "mullvad-relay-selector" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork 0.16.0", "itertools 0.12.1", "log", @@ -1830,6 +1840,7 @@ name = "mullvad-types" version = "0.0.0" dependencies = [ "chrono", + "intersection-derive", "ipnetwork 0.16.0", "jnix", "log", |
