summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock11
-rw-r--r--Cargo.toml1
-rw-r--r--mullvad-relay-selector/Cargo.toml1
-rw-r--r--mullvad-relay-selector/src/relay_selector/mod.rs4
-rw-r--r--mullvad-relay-selector/src/relay_selector/query.rs277
-rw-r--r--mullvad-types/Cargo.toml2
-rw-r--r--mullvad-types/intersection-derive/Cargo.toml16
-rw-r--r--mullvad-types/intersection-derive/src/lib.rs68
-rw-r--r--mullvad-types/src/constraints/constraint.rs24
-rw-r--r--mullvad-types/src/constraints/mod.rs128
-rw-r--r--mullvad-types/src/lib.rs3
-rw-r--r--mullvad-types/src/relay_constraints.rs23
-rw-r--r--test/Cargo.lock11
13 files changed, 285 insertions, 284 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 25b7582fb6..cc54dcb168 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1554,6 +1554,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc6d6206008e25125b1f97fbe5d309eb7b85141cf9199d52dbd3729a1584dd16"
[[package]]
+name = "intersection-derive"
+version = "0.0.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
name = "ioctl-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2109,6 +2118,7 @@ name = "mullvad-relay-selector"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork",
"itertools 0.12.1",
"log",
@@ -2147,6 +2157,7 @@ version = "0.0.0"
dependencies = [
"chrono",
"clap",
+ "intersection-derive",
"ipnetwork",
"jnix",
"log",
diff --git a/Cargo.toml b/Cargo.toml
index d760c77f6b..b0c28f016a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ members = [
"mullvad-relay-selector",
"mullvad-setup",
"mullvad-types",
+ "mullvad-types/intersection-derive",
"mullvad-version",
"talpid-core",
"talpid-dbus",
diff --git a/mullvad-relay-selector/Cargo.toml b/mullvad-relay-selector/Cargo.toml
index ec8943df0b..3224fd7983 100644
--- a/mullvad-relay-selector/Cargo.toml
+++ b/mullvad-relay-selector/Cargo.toml
@@ -22,6 +22,7 @@ serde_json = "1.0"
talpid-types = { path = "../talpid-types" }
mullvad-types = { path = "../mullvad-types" }
+intersection-derive = { path = "../mullvad-types/intersection-derive"}
[dev-dependencies]
proptest = { workspace = true }
diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs
index e0949beae9..cbb3473f7f 100644
--- a/mullvad-relay-selector/src/relay_selector/mod.rs
+++ b/mullvad-relay-selector/src/relay_selector/mod.rs
@@ -28,7 +28,7 @@ use mullvad_types::{
},
relay_list::{Relay, RelayEndpointData, RelayList},
settings::Settings,
- CustomTunnelEndpoint,
+ CustomTunnelEndpoint, Intersection,
};
use talpid_types::{
net::{
@@ -43,7 +43,7 @@ use self::{
detailer::{openvpn_endpoint, wireguard_endpoint},
matcher::{filter_matching_bridges, filter_matching_relay_list},
parsed_relays::ParsedRelays,
- query::{BridgeQuery, Intersection, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery},
+ query::{BridgeQuery, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery},
};
/// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should
diff --git a/mullvad-relay-selector/src/relay_selector/query.rs b/mullvad-relay-selector/src/relay_selector/query.rs
index 0536379dbc..f112e83968 100644
--- a/mullvad-relay-selector/src/relay_selector/query.rs
+++ b/mullvad-relay-selector/src/relay_selector/query.rs
@@ -35,6 +35,7 @@ use mullvad_types::{
RelayConstraints, SelectedObfuscation, TransportPort, Udp2TcpObfuscationSettings,
WireguardConstraints,
},
+ Intersection,
};
use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType};
@@ -70,7 +71,7 @@ use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType};
/// This example demonstrates creating a `RelayQuery` which can then be passed
/// to the [`crate::RelaySelector`] to find a relay that matches the criteria.
/// See [`builder`] for more info on how to construct queries.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct RelayQuery {
pub location: Constraint<LocationConstraint>,
pub providers: Constraint<Providers>,
@@ -88,7 +89,7 @@ impl RelayQuery {
/// Note that the following identity applies for any `other_query`:
/// ```rust
/// # use mullvad_relay_selector::query::RelayQuery;
- /// # use crate::mullvad_relay_selector::query::Intersection;
+ /// # use mullvad_types::Intersection;
///
/// # let other_query = RelayQuery::new();
/// assert_eq!(RelayQuery::new().intersection(other_query.clone()), Some(other_query));
@@ -107,105 +108,6 @@ impl RelayQuery {
}
}
-impl Intersection for RelayQuery {
- /// Return a new [`RelayQuery`] which matches the intersected queries.
- ///
- /// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned:
- /// ```rust
- /// # use mullvad_relay_selector::query::builder::RelayQueryBuilder;
- /// # use crate::mullvad_relay_selector::query::Intersection;
- /// let query_a = RelayQueryBuilder::new().wireguard().build();
- /// let query_b = RelayQueryBuilder::new().openvpn().build();
- /// assert_eq!(query_a.intersection(query_b), None);
- /// ```
- ///
- /// * Otherwise, a new [`RelayQuery`] is returned where each constraint is
- /// as specific as possible. See [`Constraint`] for further details.
- /// ```rust
- /// # use crate::mullvad_relay_selector::*;
- /// # use crate::mullvad_relay_selector::query::*;
- /// # use crate::mullvad_relay_selector::query::builder::*;
- /// # use mullvad_types::relay_list::*;
- /// # use talpid_types::net::wireguard::PublicKey;
- ///
- /// // The relay list used by `relay_selector` in this example
- /// let relay_list = RelayList {
- /// # etag: None,
- /// # openvpn: OpenVpnEndpointData { ports: vec![] },
- /// # bridge: BridgeEndpointData {
- /// # shadowsocks: vec![],
- /// # },
- /// # wireguard: WireguardEndpointData {
- /// # port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)],
- /// # ipv4_gateway: "10.64.0.1".parse().unwrap(),
- /// # ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
- /// # udp2tcp_ports: vec![],
- /// # },
- /// countries: vec![RelayListCountry {
- /// name: "Sweden".to_string(),
- /// # code: "Sweden".to_string(),
- /// cities: vec![RelayListCity {
- /// name: "Gothenburg".to_string(),
- /// # code: "Gothenburg".to_string(),
- /// # latitude: 57.70887,
- /// # longitude: 11.97456,
- /// relays: vec![Relay {
- /// hostname: "se9-wireguard".to_string(),
- /// ipv4_addr_in: "185.213.154.68".parse().unwrap(),
- /// # ipv6_addr_in: Some("2a03:1b20:5:f011::a09f".parse().unwrap()),
- /// # include_in_country: false,
- /// # active: true,
- /// # owned: true,
- /// # provider: "31173".to_string(),
- /// # weight: 1,
- /// # endpoint_data: RelayEndpointData::Wireguard(WireguardRelayEndpointData {
- /// # public_key: PublicKey::from_base64(
- /// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
- /// # )
- /// # .unwrap(),
- /// # }),
- /// # location: None,
- /// }],
- /// }],
- /// }],
- /// };
- ///
- /// # let relay_selector = RelaySelector::from_list(SelectorConfig::default(), relay_list.clone());
- /// # let city = |country, city| GeographicLocationConstraint::city(country, city);
- ///
- /// let query_a = RelayQueryBuilder::new().wireguard().build();
- /// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build();
- ///
- /// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap());
- /// assert!(result.is_ok());
- /// ```
- ///
- /// This way, if the mullvad app wants to check if the user's relay settings
- /// are compatible with any other [`RelayQuery`], for examples those defined by
- /// [`RETRY_ORDER`] , taking the intersection between them will never result in
- /// a situation where the app can override the user's preferences.
- ///
- /// [`RETRY_ORDER`]: crate::RETRY_ORDER
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(RelayQuery {
- location: self.location.intersection(other.location)?,
- providers: self.providers.intersection(other.providers)?,
- ownership: self.ownership.intersection(other.ownership)?,
- tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?,
- wireguard_constraints: self
- .wireguard_constraints
- .intersection(other.wireguard_constraints)?,
- openvpn_constraints: self
- .openvpn_constraints
- .intersection(other.openvpn_constraints)?,
- })
- }
-}
-
impl From<RelayQuery> for RelayConstraints {
/// The mapping from [`RelayQuery`] to [`RelayConstraints`].
fn from(value: RelayQuery) -> Self {
@@ -229,7 +131,7 @@ impl From<RelayQuery> for RelayConstraints {
/// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner.
/// Notice that [obfuscation][`SelectedObfuscation`] is not a [`Constraint`], but it is trivial
/// to define [`Intersection`] on it, so it is fine.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct WireguardRelayQuery {
pub port: Constraint<u16>,
pub ip_version: Constraint<IpVersion>,
@@ -257,37 +159,6 @@ impl WireguardRelayQuery {
}
}
}
-impl Intersection for WireguardRelayQuery {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(WireguardRelayQuery {
- port: self.port.intersection(other.port)?,
- ip_version: self.ip_version.intersection(other.ip_version)?,
- use_multihop: self.use_multihop.intersection(other.use_multihop)?,
- entry_location: self.entry_location.intersection(other.entry_location)?,
- obfuscation: self.obfuscation.intersection(other.obfuscation)?,
- udp2tcp_port: self.udp2tcp_port.intersection(other.udp2tcp_port)?,
- })
- }
-}
-
-impl Intersection for SelectedObfuscation {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- match (self, other) {
- (left, SelectedObfuscation::Auto) => Some(left),
- (SelectedObfuscation::Auto, right) => Some(right),
- (left, right) if left == right => Some(left),
- _ => None,
- }
- }
-}
impl From<WireguardRelayQuery> for WireguardConstraints {
/// The mapping from [`WireguardRelayQuery`] to [`WireguardConstraints`].
@@ -307,7 +178,7 @@ impl From<WireguardRelayQuery> for WireguardConstraints {
/// This struct is meant to be that type in the "universe of relay queries". The difference
/// between them may seem subtle, but in a [`OpenVpnRelayQuery`] every field is represented
/// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct OpenVpnRelayQuery {
pub port: Constraint<TransportPort>,
pub bridge_settings: Constraint<BridgeQuery>,
@@ -322,29 +193,8 @@ impl OpenVpnRelayQuery {
}
}
-impl Intersection for OpenVpnRelayQuery {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- let bridge_settings = {
- match (self.bridge_settings, other.bridge_settings) {
- // Recursive case
- (Constraint::Only(left), Constraint::Only(right)) => {
- Constraint::Only(left.intersection(right)?)
- }
- (left, right) => left.intersection(right)?,
- }
- };
- Some(OpenVpnRelayQuery {
- port: self.port.intersection(other.port)?,
- bridge_settings,
- })
- }
-}
-
-/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay queries".
+/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay
+/// queries".
///
/// [`BridgeState`]: mullvad_types::relay_constraints::BridgeState
/// [`BridgeSettings`]: mullvad_types::relay_constraints::BridgeSettings
@@ -364,7 +214,7 @@ pub enum BridgeQuery {
}
impl BridgeQuery {
- ///If `bridge_constraints` is `Any`, bridges should not be used due to
+ /// If `bridge_constraints` is `Any`, bridges should not be used due to
/// latency concerns.
///
/// If `bridge_constraints` is `Only(settings)`, then `settings` will be
@@ -402,20 +252,6 @@ impl Intersection for BridgeQuery {
}
}
-impl Intersection for BridgeConstraints {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(BridgeConstraints {
- location: self.location.intersection(other.location)?,
- providers: self.providers.intersection(other.providers)?,
- ownership: self.ownership.intersection(other.ownership)?,
- })
- }
-}
-
impl From<OpenVpnRelayQuery> for OpenVpnConstraints {
/// The mapping from [`OpenVpnRelayQuery`] to [`OpenVpnConstraints`].
fn from(value: OpenVpnRelayQuery) -> Self {
@@ -423,103 +259,10 @@ impl From<OpenVpnRelayQuery> for OpenVpnConstraints {
}
}
-/// Any type that wish to implement `Intersection` should make sure that the
-/// following properties are upheld:
-///
-/// - idempotency (if there is an identity element)
-/// - commutativity
-/// - associativity
-pub trait Intersection {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized;
-}
-
-impl<T: Intersection> Intersection for Constraint<T> {
- /// Define the intersection between two arbitrary [`Constraint`]s.
- ///
- /// This operation may be compared to the set operation with the same name.
- /// In contrast to the general set intersection, this function represents a
- /// very specific case where [`Constraint::Any`] is equivalent to the set
- /// universe and [`Constraint::Only`] represents a singleton set. Notable is
- /// that the representation of any empty set is [`Option::None`].
- fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> {
- use Constraint::*;
- match (self, other) {
- (Any, Any) => Some(Any),
- (Only(t), Any) | (Any, Only(t)) => Some(Only(t)),
- // Recurse on `left` and `right` to see if there exist an intersection
- (Only(left), Only(right)) => Some(Only(left.intersection(right)?)),
- }
- }
-}
-
-// Implement `Intersection` for different types
-
-impl Intersection for Providers {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- Providers::new(self.providers().intersection(other.providers())).ok()
- }
-}
-
-impl Intersection for Udp2TcpObfuscationSettings {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- Some(Udp2TcpObfuscationSettings {
- port: self.port.intersection(other.port)?,
- })
- }
-}
-
-impl Intersection for TransportPort {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- let protocol = if self.protocol == other.protocol {
- Some(self.protocol)
- } else {
- None
- }?;
- let port = self.port.intersection(other.port)?;
- Some(TransportPort { protocol, port })
- }
-}
-
-/// Auto-implement `Intersection` for trivial cases where the logic should just check if
-/// `self` is equal to `other`.
-macro_rules! impl_intersection_partialeq {
- ($ty:ty) => {
- impl Intersection for $ty {
- fn intersection(self, other: Self) -> Option<Self> {
- if self == other {
- Some(self)
- } else {
- None
- }
- }
- }
- };
-}
-impl_intersection_partialeq!(u16);
-impl_intersection_partialeq!(bool);
-// FIXME: [`LocationConstraint`] deserves a hand-rolled implementation of [`Intersection`], but
-// it would probably be best to implement it for [`ResolvedLocationConstraint`] instead to properly
-// handle custom lists.
-impl_intersection_partialeq!(LocationConstraint);
-impl_intersection_partialeq!(Ownership);
-impl_intersection_partialeq!(talpid_types::net::TransportProtocol);
-impl_intersection_partialeq!(talpid_types::net::TunnelType);
-impl_intersection_partialeq!(talpid_types::net::IpVersion);
-
#[allow(unused)]
pub mod builder {
- //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate pattern.
+ //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate
+ //! pattern.
use mullvad_types::{
constraints::Constraint,
relay_constraints::{
diff --git a/mullvad-types/Cargo.toml b/mullvad-types/Cargo.toml
index 8b614dd182..c4b4f76486 100644
--- a/mullvad-types/Cargo.toml
+++ b/mullvad-types/Cargo.toml
@@ -21,6 +21,8 @@ serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.4.1", features = ["v4", "serde" ] }
talpid-types = { path = "../talpid-types" }
+intersection-derive = { path = "intersection-derive" }
+
clap = { workspace = true , optional = true }
diff --git a/mullvad-types/intersection-derive/Cargo.toml b/mullvad-types/intersection-derive/Cargo.toml
new file mode 100644
index 0000000000..f167c28aed
--- /dev/null
+++ b/mullvad-types/intersection-derive/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "intersection-derive"
+description = "Derive macro for the `Intersection` trait"
+authors.workspace = true
+repository.workspace = true
+license.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1"
+syn = "2"
+quote = "1" \ No newline at end of file
diff --git a/mullvad-types/intersection-derive/src/lib.rs b/mullvad-types/intersection-derive/src/lib.rs
new file mode 100644
index 0000000000..65e5af15bd
--- /dev/null
+++ b/mullvad-types/intersection-derive/src/lib.rs
@@ -0,0 +1,68 @@
+//! This `proc-macro` crate exports the [`Intersection`] derive macro, see the trait documentation
+//! for more information.
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+use syn::{parse_macro_input, DeriveInput};
+
+/// Derive macro for the [`Intersection`] trait on structs.
+#[proc_macro_derive(Intersection)]
+pub fn intersection_derive(item: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(item as DeriveInput);
+
+ inner::derive(input).into()
+}
+
+mod inner {
+ use proc_macro2::TokenStream;
+ use quote::{quote, TokenStreamExt};
+ use syn::{spanned::Spanned, DeriveInput, Error};
+
+ pub(crate) fn derive(input: DeriveInput) -> TokenStream {
+ if let syn::Data::Struct(data) = &input.data {
+ derive_for_struct(&input, data).unwrap_or_else(Error::into_compile_error)
+ } else {
+ syn::Error::new(
+ input.span(),
+ "Deriving `Intersection` is only supported for structs",
+ )
+ .into_compile_error()
+ }
+ }
+
+ pub(crate) fn derive_for_struct(
+ input: &DeriveInput,
+ data: &syn::DataStruct,
+ ) -> syn::Result<TokenStream> {
+ let my_type = &input.ident;
+ let mut field_conversions = quote! {};
+ for field in &data.fields {
+ let Some(name) = &field.ident else {
+ return Err(syn::Error::new(
+ field.span(),
+ "Tuple structs are not currently supported",
+ ));
+ };
+
+ // TODO(Sebastian): Here, and in the `quote` below, we are referring to `Intersection`
+ // with its relative name, which will fail if the user renames the trait
+ // when importing, e.g. `use mullvad_types::Intersection as SomethingElse`.
+ // This is a know limitation of procural macros (declarative macros can use the `$crate`
+ // syntax). If the issue arises then it can be solve using the
+ // <https://crates.io/crates/proc-macro-crate> crate. Add it if necessary.
+ field_conversions.append_all(quote! {
+ #name: Intersection::intersection(self.#name, other.#name)?,
+ })
+ }
+
+ Ok(quote! {
+ impl Intersection for #my_type {
+ fn intersection(self, other: Self) -> ::core::option::Option<Self> {
+ ::core::option::Option::Some(Self {
+ #field_conversions
+ })
+ }
+ }
+ })
+ }
+}
diff --git a/mullvad-types/src/constraints/constraint.rs b/mullvad-types/src/constraints/constraint.rs
index c1669e33ae..26e1df3c6c 100644
--- a/mullvad-types/src/constraints/constraint.rs
+++ b/mullvad-types/src/constraints/constraint.rs
@@ -3,8 +3,9 @@
#[cfg(target_os = "android")]
use jnix::{FromJava, IntoJava};
use serde::{Deserialize, Serialize};
-use std::fmt;
-use std::str::FromStr;
+use std::{fmt, str::FromStr};
+
+use crate::Intersection;
/// Limits the set of [`crate::relay_list::Relay`]s that a `RelaySelector` may select.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
@@ -17,6 +18,25 @@ pub enum Constraint<T> {
Only(T),
}
+impl<T: Intersection> Intersection for Constraint<T> {
+ /// Define the intersection between two arbitrary [`Constraint`]s.
+ ///
+ /// This operation may be compared to the set operation with the same name.
+ /// In contrast to the general set intersection, this function represents a
+ /// very specific case where [`Constraint::Any`] is equivalent to the set
+ /// universe and [`Constraint::Only`] represents a singleton set. Notable is
+ /// that the representation of any empty set is [`Option::None`].
+ fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> {
+ use Constraint::*;
+ match (self, other) {
+ (Any, Any) => Some(Any),
+ (Only(t), Any) | (Any, Only(t)) => Some(Only(t)),
+ // Pick any of `left` or `right` if they are the same.
+ (Only(left), Only(right)) => left.intersection(right).map(Only),
+ }
+ }
+}
+
impl<T: fmt::Display> fmt::Display for Constraint<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
diff --git a/mullvad-types/src/constraints/mod.rs b/mullvad-types/src/constraints/mod.rs
index 4caa31e901..8ee6e93195 100644
--- a/mullvad-types/src/constraints/mod.rs
+++ b/mullvad-types/src/constraints/mod.rs
@@ -22,14 +22,124 @@ impl<T: Match<U>, U> Match<U> for Constraint<T> {
}
}
-impl<T: Set<U>, U> Set<Constraint<U>> for Constraint<T> {
- fn is_subset(&self, other: &Constraint<U>) -> bool {
- match self {
- Constraint::Any => other.is_any(),
- Constraint::Only(ref constraint) => match other {
- Constraint::Only(ref other_constraint) => constraint.is_subset(other_constraint),
- _ => true,
- },
+/// The intersection of two sets of criteria on [`Relay`](crate::relay_list::Relay)s is another
+/// criteria which matches the given relay iff both of the original criteria matched. It is
+/// primarily used by the relay selector to check whether a given connection method is compatible
+/// with the users settings.
+///
+/// # Examples
+///
+/// The [`Intersection`] implementation of [`RelayQuery`] upholds the following properties:
+///
+/// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned:
+/// ```rust, ignore
+/// # use mullvad_relay_selector::query::builder::RelayQueryBuilder;
+/// let query_a = RelayQueryBuilder::new().wireguard().build();
+/// let query_b = RelayQueryBuilder::new().openvpn().build();
+/// assert_eq!(query_a.intersection(query_b), None);
+/// ```
+///
+/// * Otherwise, a new [`RelayQuery`] is returned where each constraint is
+/// as specific as possible. See [`Constraint`] for further details.
+/// ```rust, ignore
+/// let query_a = RelayQueryBuilder::new().wireguard().build();
+/// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build();
+///
+/// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap());
+/// assert!(result.is_ok());
+/// ```
+///
+/// This way, if the mullvad app wants to check if the user's relay settings
+/// are compatible with any other [`RelayQuery`], for examples those defined by
+/// [`RETRY_ORDER`] , taking the intersection between them will never result in
+/// a situation where the app can override the user's preferences.
+///
+/// [`RETRY_ORDER`]: crate::RETRY_ORDER
+///
+/// The macro recursively applies the intersection on each field of the struct and returns the
+/// resulting type or `None` if any of the intersections failed to overlap.
+///
+/// The macro requires the types of each field to also implement [`Intersection`], which may be done
+/// using this derive macro, the
+///
+/// # Implementing [`Intersection`]
+///
+/// For structs where each field already implements `Intersection`, the easiest way to implement the
+/// trait is using the derive macro. Using the derive macro on [`RelayQuery`]
+/// ```rust, ignore
+/// #[derive(Intersection)]
+/// struct RelayQuery {
+/// pub location: Constraint<LocationConstraint>,
+/// pub providers: Constraint<Providers>,
+/// pub ownership: Constraint<Ownership>,
+/// pub tunnel_protocol: Constraint<TunnelType>,
+/// pub wireguard_constraints: WireguardRelayQuery,
+/// pub openvpn_constraints: OpenVpnRelayQuery,
+/// }
+/// ```
+///
+/// produces an implementation like this:
+///
+/// ```rust, ignore
+/// impl Intersection for RelayQuery {
+/// fn intersection(self, other: Self) -> Option<Self>
+/// where
+/// Self: PartialEq,
+/// Self: Sized,
+/// {
+/// Some(RelayQuery {
+/// location: self.location.intersection(other.location)?,
+/// providers: self.providers.intersection(other.providers)?,
+/// ownership: self.ownership.intersection(other.ownership)?,
+/// tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?,
+/// wireguard_constraints: self
+/// .wireguard_constraints
+/// .intersection(other.wireguard_constraints)?,
+/// openvpn_constraints: self
+/// .openvpn_constraints
+/// .intersection(other.openvpn_constraints)?,
+/// })
+/// }
+/// }
+/// ```
+///
+/// For types that cannot "overlap", e.g. they only intersect if they are equal, the declarative
+/// macro [`impl_intersection_partialeq`] can be used.
+///
+/// For less trivial cases, the trait needs to be implemented manually. When doing so, make sure
+/// that the following properties are upheld:
+///
+/// - idempotency (if there is an identity element)
+/// - commutativity
+/// - associativity
+pub trait Intersection: Sized {
+ fn intersection(self, other: Self) -> Option<Self>;
+}
+
+#[macro_export]
+macro_rules! impl_intersection_partialeq {
+ ($ty:ty) => {
+ impl $crate::Intersection for $ty {
+ fn intersection(self, other: Self) -> Option<Self> {
+ if self == other {
+ Some(self)
+ } else {
+ None
+ }
+ }
}
- }
+ };
}
+
+impl_intersection_partialeq!(u16);
+impl_intersection_partialeq!(bool);
+
+// NOTE: this implementation does not do what you may expect of an intersection
+impl_intersection_partialeq!(relay_constraints::Providers);
+// NOTE: should take actual intersection
+impl_intersection_partialeq!(relay_constraints::LocationConstraint);
+impl_intersection_partialeq!(relay_constraints::Ownership);
+// NOTE: it contains an inner constraint
+impl_intersection_partialeq!(talpid_types::net::TransportProtocol);
+impl_intersection_partialeq!(talpid_types::net::TunnelType);
+impl_intersection_partialeq!(talpid_types::net::IpVersion);
diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs
index 231f80300b..d1a50fa0ea 100644
--- a/mullvad-types/src/lib.rs
+++ b/mullvad-types/src/lib.rs
@@ -21,3 +21,6 @@ pub use crate::custom_tunnel::*;
pub const TUNNEL_TABLE_ID: u32 = 0x6d6f6c65;
#[cfg(target_os = "linux")]
pub const TUNNEL_FWMARK: u32 = 0x6d6f6c65;
+
+pub use constraints::Intersection;
+pub use intersection_derive::Intersection;
diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs
index 05d07477be..dc6d0608e5 100644
--- a/mullvad-types/src/relay_constraints.rs
+++ b/mullvad-types/src/relay_constraints.rs
@@ -6,7 +6,7 @@ use crate::{
custom_list::{CustomListsSettings, Id},
location::{CityCode, CountryCode, Hostname},
relay_list::Relay,
- CustomTunnelEndpoint,
+ CustomTunnelEndpoint, Intersection,
};
#[cfg(target_os = "android")]
use jnix::{jni::objects::JObject, FromJava, IntoJava, JnixEnv};
@@ -463,7 +463,7 @@ impl fmt::Display for GeographicLocationConstraint {
}
}
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
pub struct TransportPort {
pub protocol: TransportProtocol,
pub port: Constraint<u16>,
@@ -646,6 +646,21 @@ pub enum SelectedObfuscation {
Udp2Tcp,
}
+impl Intersection for SelectedObfuscation {
+ fn intersection(self, other: Self) -> Option<Self>
+ where
+ Self: PartialEq,
+ Self: Sized,
+ {
+ match (self, other) {
+ (left, SelectedObfuscation::Auto) => Some(left),
+ (SelectedObfuscation::Auto, right) => Some(right),
+ (left, right) if left == right => Some(left),
+ _ => None,
+ }
+ }
+}
+
impl fmt::Display for SelectedObfuscation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -656,7 +671,7 @@ impl fmt::Display for SelectedObfuscation {
}
}
-#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
#[serde(rename_all = "snake_case")]
@@ -716,7 +731,7 @@ pub struct ObfuscationSettings {
}
/// Limits the set of bridge servers to use in `mullvad-daemon`.
-#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
#[serde(default)]
#[serde(rename_all = "snake_case")]
pub struct BridgeConstraints {
diff --git a/test/Cargo.lock b/test/Cargo.lock
index 5f207d9a34..eb38b5f37a 100644
--- a/test/Cargo.lock
+++ b/test/Cargo.lock
@@ -1380,6 +1380,15 @@ dependencies = [
]
[[package]]
+name = "intersection-derive"
+version = "0.0.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
name = "inventory"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1814,6 +1823,7 @@ name = "mullvad-relay-selector"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork 0.16.0",
"itertools 0.12.1",
"log",
@@ -1830,6 +1840,7 @@ name = "mullvad-types"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork 0.16.0",
"jnix",
"log",