summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-04-04 16:50:35 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-04-04 16:50:35 +0200
commit7170e05df5cbfd4dd3f372267bd9c155ed49f5f6 (patch)
tree44d8037ea9f552227dfce5a5e2358bc580a8cec5
parente45f61d0f538f20d594106e0ec9ccecb40678a3b (diff)
parentc0b0304be43f994dc661d18696085dbd415afb24 (diff)
downloadmullvadvpn-7170e05df5cbfd4dd3f372267bd9c155ed49f5f6.tar.xz
mullvadvpn-7170e05df5cbfd4dd3f372267bd9c155ed49f5f6.zip
Merge branch 'intersection-macro'
-rw-r--r--Cargo.lock11
-rw-r--r--Cargo.toml1
-rw-r--r--mullvad-relay-selector/Cargo.toml1
-rw-r--r--mullvad-relay-selector/src/relay_selector/detailer.rs7
-rw-r--r--mullvad-relay-selector/src/relay_selector/helpers.rs35
-rw-r--r--mullvad-relay-selector/src/relay_selector/matcher.rs8
-rw-r--r--mullvad-relay-selector/src/relay_selector/mod.rs111
-rw-r--r--mullvad-relay-selector/src/relay_selector/query.rs296
-rw-r--r--mullvad-relay-selector/tests/relay_selector.rs43
-rw-r--r--mullvad-types/Cargo.toml2
-rw-r--r--mullvad-types/intersection-derive/Cargo.toml16
-rw-r--r--mullvad-types/intersection-derive/src/lib.rs68
-rw-r--r--mullvad-types/src/constraints/constraint.rs24
-rw-r--r--mullvad-types/src/constraints/mod.rs133
-rw-r--r--mullvad-types/src/lib.rs3
-rw-r--r--mullvad-types/src/relay_constraints.rs70
-rw-r--r--test/Cargo.lock11
17 files changed, 402 insertions, 438 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 25b7582fb6..cc54dcb168 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1554,6 +1554,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc6d6206008e25125b1f97fbe5d309eb7b85141cf9199d52dbd3729a1584dd16"
[[package]]
+name = "intersection-derive"
+version = "0.0.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
name = "ioctl-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2109,6 +2118,7 @@ name = "mullvad-relay-selector"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork",
"itertools 0.12.1",
"log",
@@ -2147,6 +2157,7 @@ version = "0.0.0"
dependencies = [
"chrono",
"clap",
+ "intersection-derive",
"ipnetwork",
"jnix",
"log",
diff --git a/Cargo.toml b/Cargo.toml
index d760c77f6b..b0c28f016a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ members = [
"mullvad-relay-selector",
"mullvad-setup",
"mullvad-types",
+ "mullvad-types/intersection-derive",
"mullvad-version",
"talpid-core",
"talpid-dbus",
diff --git a/mullvad-relay-selector/Cargo.toml b/mullvad-relay-selector/Cargo.toml
index ec8943df0b..3224fd7983 100644
--- a/mullvad-relay-selector/Cargo.toml
+++ b/mullvad-relay-selector/Cargo.toml
@@ -22,6 +22,7 @@ serde_json = "1.0"
talpid-types = { path = "../talpid-types" }
mullvad-types = { path = "../mullvad-types" }
+intersection-derive = { path = "../mullvad-types/intersection-derive"}
[dev-dependencies]
proptest = { workspace = true }
diff --git a/mullvad-relay-selector/src/relay_selector/detailer.rs b/mullvad-relay-selector/src/relay_selector/detailer.rs
index 374d92da09..8fa36f5434 100644
--- a/mullvad-relay-selector/src/relay_selector/detailer.rs
+++ b/mullvad-relay-selector/src/relay_selector/detailer.rs
@@ -48,7 +48,8 @@ pub enum Error {
/// Constructs a [`MullvadWireguardEndpoint`] with details for how to connect to a Wireguard relay.
///
/// # Returns
-/// - A configured endpoint for Wireguard relay, encapsulating either a single-hop or multi-hop connection.
+/// - A configured endpoint for Wireguard relay, encapsulating either a single-hop or multi-hop
+/// connection.
/// - Returns [`Option::None`] if the desired port is not in a valid port range (see
/// [`WireguardRelayQuery::port`]) or relay addresses cannot be resolved.
pub fn wireguard_endpoint(
@@ -198,8 +199,8 @@ const fn get_public_key(relay: &Relay) -> Result<&PublicKey, Error> {
/// - `port_ranges`: A slice of tuples, each representing a range of valid port numbers.
///
/// # Returns
-/// - `Option<u16>`: A randomly selected port number within the given ranges, or `None` if
-/// the input is empty or the total number of available ports is zero.
+/// - `Option<u16>`: A randomly selected port number within the given ranges, or `None` if the input
+/// is empty or the total number of available ports is zero.
fn select_random_port(port_ranges: &[(u16, u16)]) -> Result<u16, Error> {
use rand::Rng;
let get_port_amount = |range: &(u16, u16)| -> u64 { (1 + range.1 - range.0) as u64 };
diff --git a/mullvad-relay-selector/src/relay_selector/helpers.rs b/mullvad-relay-selector/src/relay_selector/helpers.rs
index 5ad5bf7ab4..263034d836 100644
--- a/mullvad-relay-selector/src/relay_selector/helpers.rs
+++ b/mullvad-relay-selector/src/relay_selector/helpers.rs
@@ -8,24 +8,11 @@ use mullvad_types::{
relay_constraints::Udp2TcpObfuscationSettings,
relay_list::{BridgeEndpointData, Relay, RelayEndpointData},
};
-use rand::{
- seq::{IteratorRandom, SliceRandom},
- thread_rng, Rng,
-};
+use rand::{seq::SliceRandom, thread_rng, Rng};
use talpid_types::net::{obfuscation::ObfuscatorConfig, proxy::CustomProxy};
use crate::SelectedObfuscator;
-/// Pick a random element out of `from`, excluding the element `exclude` from the selection.
-pub fn random<'a, A: PartialEq>(
- from: impl IntoIterator<Item = &'a A>,
- exclude: &A,
-) -> Option<&'a A> {
- from.into_iter()
- .filter(|&a| a != exclude)
- .choose(&mut thread_rng())
-}
-
/// Picks a relay using [pick_random_relay_fn], using the `weight` member of each relay
/// as the weight function.
pub fn pick_random_relay(relays: &[Relay]) -> Option<&Relay> {
@@ -49,16 +36,16 @@ pub fn pick_random_relay_weighted<RelayType>(
// Pick a random number in the range 1..=total_weight. This choses the relay with a
// non-zero weight.
//
- // rng(1..=total_weight)
- // |
- // v
- // _____________________________i___________________________________________________
- // 0|_____________|__________________________|___________|_____|___________|__________| total_weight
- // ^ ^ ^ ^ ^
- // | | | | |
- // ------------------------------------------ ------------
- // | | |
- // weight(relay 0) weight(relay 1) .. .. .. weight(relay n)
+ // rng(1..=total_weight)
+ // |
+ // v
+ // ________________________i_______________________________________________
+ // 0|_____________|____________________|___________|_____|________|__________| total_weight
+ // ^ ^ ^ ^ ^
+ // | | | | |
+ // ------------------------------------ ------------
+ // | | |
+ // weight(relay 0) weight(relay 1) .. .. .. weight(relay n)
let mut i: u64 = rng.gen_range(1..=total_weight);
Some(
relays
diff --git a/mullvad-relay-selector/src/relay_selector/matcher.rs b/mullvad-relay-selector/src/relay_selector/matcher.rs
index f06e04224f..2987d8965a 100644
--- a/mullvad-relay-selector/src/relay_selector/matcher.rs
+++ b/mullvad-relay-selector/src/relay_selector/matcher.rs
@@ -35,10 +35,10 @@ pub fn filter_matching_relay_list<'a, R: Iterator<Item = &'a Relay> + Clone>(
.filter(|relay| filter_on_providers(&query.providers, relay));
// The last filtering to be done is on the `include_in_country` attribute found on each
- // relay. When the location constraint is based on country, a relay which has `include_in_country`
- // set to true should always be prioritized over relays which has this flag set to false.
- // We should only consider relays with `include_in_country` set to false if there are no
- // other candidates left.
+ // relay. When the location constraint is based on country, a relay which has
+ // `include_in_country` set to true should always be prioritized over relays which has this
+ // flag set to false. We should only consider relays with `include_in_country` set to false
+ // if there are no other candidates left.
match &locations {
Constraint::Any => shortlist.cloned().collect(),
Constraint::Only(locations) => {
diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs
index 424ca6385c..cbb3473f7f 100644
--- a/mullvad-relay-selector/src/relay_selector/mod.rs
+++ b/mullvad-relay-selector/src/relay_selector/mod.rs
@@ -9,6 +9,7 @@ pub mod query;
use chrono::{DateTime, Local};
use itertools::Itertools;
use once_cell::sync::Lazy;
+use rand::{seq::IteratorRandom, thread_rng};
use std::{
path::Path,
sync::{Arc, Mutex},
@@ -27,7 +28,7 @@ use mullvad_types::{
},
relay_list::{Relay, RelayEndpointData, RelayList},
settings::Settings,
- CustomTunnelEndpoint,
+ CustomTunnelEndpoint, Intersection,
};
use talpid_types::{
net::{
@@ -42,12 +43,12 @@ use self::{
detailer::{openvpn_endpoint, wireguard_endpoint},
matcher::{filter_matching_bridges, filter_matching_relay_list},
parsed_relays::ParsedRelays,
- query::{BridgeQuery, Intersection, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery},
+ query::{BridgeQuery, OpenVpnRelayQuery, RelayQuery, WireguardRelayQuery},
};
-/// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should prioritize on
-/// successive connection attempts. Note that these will *never* override user preferences.
-/// See [the documentation on `RelayQuery`][RelayQuery] for further details.
+/// [`RETRY_ORDER`] defines an ordered set of relay parameters which the relay selector should
+/// prioritize on successive connection attempts. Note that these will *never* override user
+/// preferences. See [the documentation on `RelayQuery`][RelayQuery] for further details.
///
/// This list should be kept in sync with the expected behavior defined in `docs/relay-selector.md`
pub static RETRY_ORDER: Lazy<Vec<RelayQuery>> = Lazy::new(|| {
@@ -144,14 +145,16 @@ impl Default for RuntimeParameters {
/// This enum exists to separate the two types of [`SelectorConfig`] that exists.
///
-/// The first one is a "regular" config, where [`SelectorConfig::relay_settings`] is [`RelaySettings::Normal`].
-/// This is the most common variant, and there exists a mapping from this variant to [`RelayQueryBuilder`].
-/// Being able to implement `From<NormalSelectorConfig> for RelayQueryBuilder` was the main
-/// motivator for introducing these seemingly useless derivates of [`SelectorConfig`].
+/// The first one is a "regular" config, where [`SelectorConfig::relay_settings`] is
+/// [`RelaySettings::Normal`]. This is the most common variant, and there exists a mapping from this
+/// variant to [`RelayQueryBuilder`]. Being able to implement `From<NormalSelectorConfig> for
+/// RelayQueryBuilder` was the main motivator for introducing these seemingly useless derivates of
+/// [`SelectorConfig`].
///
-/// The second one is a custom config, where [`SelectorConfig::relay_settings`] is [`RelaySettings::Custom`].
-/// For this variant, the endpoint where the client should connect to is already specified inside of the variant,
-/// so in practice the relay selector becomes superfluous. Also, there exists no mapping to [`RelayQueryBuilder`].
+/// The second one is a custom config, where [`SelectorConfig::relay_settings`] is
+/// [`RelaySettings::Custom`]. For this variant, the endpoint where the client should connect to is
+/// already specified inside of the variant, so in practice the relay selector becomes superfluous.
+/// Also, there exists no mapping to [`RelayQueryBuilder`].
#[derive(Debug, Clone)]
enum SpecializedSelectorConfig<'a> {
// This variant implements `From<NormalSelectorConfig> for RelayQuery`
@@ -500,9 +503,9 @@ impl RelaySelector {
self.get_relay_with_custom_params(retry_attempt, &RETRY_ORDER, runtime_params)
}
- /// Peek at which [`TunnelType`] that would be returned for a certain connection attempt for a given
- /// [`SelectorConfig`]. Returns [`Option::None`] if the given config would return a custom
- /// tunnel endpoint.
+ /// Peek at which [`TunnelType`] that would be returned for a certain connection attempt for a
+ /// given [`SelectorConfig`]. Returns [`Option::None`] if the given config would return a
+ /// custom tunnel endpoint.
///
/// # Note
/// This function is only really useful for testing-purposes. It is exposed to ease testing of
@@ -556,15 +559,16 @@ impl RelaySelector {
}
}
- /// This function defines the merge between a set of pre-defined queries and `user_preferences` for the given
- /// `retry_attempt`.
+ /// This function defines the merge between a set of pre-defined queries and `user_preferences`
+ /// for the given `retry_attempt`.
///
- /// This algorithm will loop back to the start of `retry_order` if `retry_attempt < retry_order.len()`.
- /// If `user_preferences` is not compatible with any of the pre-defined queries in `retry_order`, `user_preferences`
- /// is returned.
+ /// This algorithm will loop back to the start of `retry_order` if `retry_attempt <
+ /// retry_order.len()`. If `user_preferences` is not compatible with any of the pre-defined
+ /// queries in `retry_order`, `user_preferences` is returned.
///
/// Runtime parameters may affect which of the default queries that are considered. For example,
- /// queries which rely on IPv6 will not be considered if working IPv6 is not available at runtime.
+ /// queries which rely on IPv6 will not be considered if working IPv6 is not available at
+ /// runtime.
fn pick_and_merge_query(
retry_attempt: usize,
retry_order: &[RelayQuery],
@@ -583,15 +587,19 @@ impl RelaySelector {
.unwrap_or(user_preferences)
}
- /// "Execute" the given query, yielding a final set of relays and/or bridges which the VPN traffic shall be routed through.
+ /// "Execute" the given query, yielding a final set of relays and/or bridges which the VPN
+ /// traffic shall be routed through.
///
/// # Parameters
- /// - `query`: Constraints that filter the available relays, such as geographic location or tunnel protocol.
- /// - `config`: Configuration settings that influence relay selection, including bridge state and custom lists.
+ /// - `query`: Constraints that filter the available relays, such as geographic location or
+ /// tunnel protocol.
+ /// - `config`: Configuration settings that influence relay selection, including bridge state
+ /// and custom lists.
/// - `parsed_relays`: The complete set of parsed relays available for selection.
///
/// # Returns
- /// * A randomly selected relay that meets the specified constraints (and a random bridge/entry relay if applicable).
+ /// * A randomly selected relay that meets the specified constraints (and a random bridge/entry
+ /// relay if applicable).
/// See [`GetRelay`] for more details.
/// * An `Err` if no suitable relay is found
/// * An `Err` if no suitable bridge is found
@@ -629,10 +637,10 @@ impl RelaySelector {
parsed_relays: &ParsedRelays,
config: &NormalSelectorConfig<'_>,
) -> Result<GetRelay, Error> {
- // FIXME: A bit of defensive programming - calling `get_wiregurad_relay` with a query that doesn't
- // specify Wireguard as the desired tunnel type is not valid and will lead to unwanted
- // behavior. This should be seen as a workaround, and it would be nicer to lift this
- // invariant to be checked by the type system instead.
+ // FIXME: A bit of defensive programming - calling `get_wiregurad_relay` with a query that
+ // doesn't specify Wireguard as the desired tunnel type is not valid and will lead
+ // to unwanted behavior. This should be seen as a workaround, and it would be nicer
+ // to lift this invariant to be checked by the type system instead.
query.tunnel_protocol = Constraint::Only(TunnelType::Wireguard);
Self::get_wireguard_relay(query, config, parsed_relays)
}
@@ -712,7 +720,8 @@ impl RelaySelector {
let mut entry_relay_query = query.clone();
entry_relay_query.location = query.wireguard_constraints.entry_location.clone();
// After we have our two queries (one for the exit relay & one for the entry relay),
- // we can query for all exit & entry candidates! All candidates are needed for the next step.
+ // we can query for all exit & entry candidates! All candidates are needed for the next
+ // step.
let exit_candidates =
filter_matching_relay_list(query, parsed_relays.relays(), config.custom_lists);
let entry_candidates = filter_matching_relay_list(
@@ -721,25 +730,25 @@ impl RelaySelector {
config.custom_lists,
);
- // This algorithm gracefully handles a particular edge case that arise when a constraint on
- // the exit relay is more specific than on the entry relay which forces the relay selector
- // to choose one specific relay. The relay selector could end up selecting that specific
- // relay as the entry relay, thus leaving no remaining exit relay candidates or vice versa.
+ fn pick_random_excluding<'a>(list: &'a [Relay], exclude: &'a Relay) -> Option<&'a Relay> {
+ list.iter()
+ .filter(|&a| a != exclude)
+ .choose(&mut thread_rng())
+ }
+ // We avoid picking the same relay for entry and exit by choosing one and excluding it when
+ // choosing the other.
let (exit, entry) = match (exit_candidates.as_slice(), entry_candidates.as_slice()) {
- ([exit], [entry]) if exit == entry => None,
+ // In the case where there is only one entry to choose from, we have to pick it before
+ // the exit
(exits, [entry]) if exits.contains(entry) => {
- let exit = helpers::random(exits, entry).ok_or(Error::NoRelay)?;
- Some((exit, entry))
- }
- ([exit], entrys) if entrys.contains(exit) => {
- let entry = helpers::random(entrys, exit).ok_or(Error::NoRelay)?;
- Some((exit, entry))
+ pick_random_excluding(exits, entry).map(|exit| (exit, entry))
}
- (exits, entrys) => {
- let exit = helpers::pick_random_relay(exits).ok_or(Error::NoRelay)?;
- let entry = helpers::random(entrys, exit).ok_or(Error::NoRelay)?;
- Some((exit, entry))
+ // Vice versa for the case of only one exit
+ ([exit], entries) if entries.contains(exit) => {
+ pick_random_excluding(entries, exit).map(|entry| (exit, entry))
}
+ (exits, entries) => helpers::pick_random_relay(exits)
+ .and_then(|exit| pick_random_excluding(entries, exit).map(|entry| (exit, entry))),
}
.ok_or(Error::NoRelay)?;
@@ -847,17 +856,20 @@ impl RelaySelector {
})
}
- /// Selects a suitable bridge based on the specified settings, relay information, and transport protocol.
+ /// Selects a suitable bridge based on the specified settings, relay information, and transport
+ /// protocol.
///
/// # Parameters
/// - `query`: The filter criteria for selecting a bridge.
/// - `relay`: Information about the current relay, including its location.
/// - `protocol`: The transport protocol (TCP or UDP) in use.
/// - `parsed_relays`: A structured representation of all available relays.
- /// - `custom_lists`: User-defined or application-specific settings that may influence bridge selection.
+ /// - `custom_lists`: User-defined or application-specific settings that may influence bridge
+ /// selection.
///
/// # Returns
- /// * On success, returns an `Option` containing the selected bridge, if one is found. Returns `None` if no suitable bridge meets the criteria or bridges should not be used.
+ /// * On success, returns an `Option` containing the selected bridge, if one is found. Returns
+ /// `None` if no suitable bridge meets the criteria or bridges should not be used.
/// * `Error::NoBridge` if attempting to use OpenVPN bridges over UDP, as this is unsupported.
/// * `Error::NoRelay` if `relay` does not have a location set.
#[cfg(not(target_os = "android"))]
@@ -1010,7 +1022,8 @@ impl RelaySelector {
}
/// # Returns
- /// A randomly selected relay that meets the specified constraints, or `None` if no suitable relay is found.
+ /// A randomly selected relay that meets the specified constraints, or `None` if no suitable
+ /// relay is found.
#[cfg(not(target_os = "android"))]
fn choose_openvpn_relay(
query: &RelayQuery,
diff --git a/mullvad-relay-selector/src/relay_selector/query.rs b/mullvad-relay-selector/src/relay_selector/query.rs
index d5b17b2303..f112e83968 100644
--- a/mullvad-relay-selector/src/relay_selector/query.rs
+++ b/mullvad-relay-selector/src/relay_selector/query.rs
@@ -1,18 +1,20 @@
//! This module provides a flexible way to specify 'queries' for relays.
//!
//! A query is a set of constraints that the [`crate::RelaySelector`] will use when filtering out
-//! potential relays that the daemon should connect to. It supports filtering relays by geographic location,
-//! provider, ownership, and tunnel protocol, along with protocol-specific settings for WireGuard and OpenVPN.
+//! potential relays that the daemon should connect to. It supports filtering relays by geographic
+//! location, provider, ownership, and tunnel protocol, along with protocol-specific settings for
+//! WireGuard and OpenVPN.
//!
//! The main components of this module include:
//!
-//! - [`RelayQuery`]: The core struct for specifying a query to select relay servers. It
-//! aggregates constraints on location, providers, ownership, tunnel protocol, and
-//! protocol-specific constraints for WireGuard and OpenVPN.
+//! - [`RelayQuery`]: The core struct for specifying a query to select relay servers. It aggregates
+//! constraints on location, providers, ownership, tunnel protocol, and protocol-specific
+//! constraints for WireGuard and OpenVPN.
//! - [`WireguardRelayQuery`] and [`OpenVpnRelayQuery`]: Structs that define protocol-specific
//! constraints for selecting WireGuard and OpenVPN relays, respectively.
-//! - [`Intersection`]: A trait implemented by the different query types that support intersection logic,
-//! which allows for combining two queries into a single query that represents the common constraints of both.
+//! - [`Intersection`]: A trait implemented by the different query types that support intersection
+//! logic, which allows for combining two queries into a single query that represents the common
+//! constraints of both.
//! - [Builder patterns][builder]: The module also provides builder patterns for creating instances
//! of `RelayQuery`, `WireguardRelayQuery`, and `OpenVpnRelayQuery` with a fluent API.
//!
@@ -33,6 +35,7 @@ use mullvad_types::{
RelayConstraints, SelectedObfuscation, TransportPort, Udp2TcpObfuscationSettings,
WireguardConstraints,
},
+ Intersection,
};
use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType};
@@ -68,7 +71,7 @@ use talpid_types::net::{proxy::CustomProxy, IpVersion, TunnelType};
/// This example demonstrates creating a `RelayQuery` which can then be passed
/// to the [`crate::RelaySelector`] to find a relay that matches the criteria.
/// See [`builder`] for more info on how to construct queries.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct RelayQuery {
pub location: Constraint<LocationConstraint>,
pub providers: Constraint<Providers>,
@@ -86,7 +89,7 @@ impl RelayQuery {
/// Note that the following identity applies for any `other_query`:
/// ```rust
/// # use mullvad_relay_selector::query::RelayQuery;
- /// # use crate::mullvad_relay_selector::query::Intersection;
+ /// # use mullvad_types::Intersection;
///
/// # let other_query = RelayQuery::new();
/// assert_eq!(RelayQuery::new().intersection(other_query.clone()), Some(other_query));
@@ -105,105 +108,6 @@ impl RelayQuery {
}
}
-impl Intersection for RelayQuery {
- /// Return a new [`RelayQuery`] which matches the intersected queries.
- ///
- /// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned:
- /// ```rust
- /// # use mullvad_relay_selector::query::builder::RelayQueryBuilder;
- /// # use crate::mullvad_relay_selector::query::Intersection;
- /// let query_a = RelayQueryBuilder::new().wireguard().build();
- /// let query_b = RelayQueryBuilder::new().openvpn().build();
- /// assert_eq!(query_a.intersection(query_b), None);
- /// ```
- ///
- /// * Otherwise, a new [`RelayQuery`] is returned where each constraint is
- /// as specific as possible. See [`Constraint`] for further details.
- /// ```rust
- /// # use crate::mullvad_relay_selector::*;
- /// # use crate::mullvad_relay_selector::query::*;
- /// # use crate::mullvad_relay_selector::query::builder::*;
- /// # use mullvad_types::relay_list::*;
- /// # use talpid_types::net::wireguard::PublicKey;
- ///
- /// // The relay list used by `relay_selector` in this example
- /// let relay_list = RelayList {
- /// # etag: None,
- /// # openvpn: OpenVpnEndpointData { ports: vec![] },
- /// # bridge: BridgeEndpointData {
- /// # shadowsocks: vec![],
- /// # },
- /// # wireguard: WireguardEndpointData {
- /// # port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)],
- /// # ipv4_gateway: "10.64.0.1".parse().unwrap(),
- /// # ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
- /// # udp2tcp_ports: vec![],
- /// # },
- /// countries: vec![RelayListCountry {
- /// name: "Sweden".to_string(),
- /// # code: "Sweden".to_string(),
- /// cities: vec![RelayListCity {
- /// name: "Gothenburg".to_string(),
- /// # code: "Gothenburg".to_string(),
- /// # latitude: 57.70887,
- /// # longitude: 11.97456,
- /// relays: vec![Relay {
- /// hostname: "se9-wireguard".to_string(),
- /// ipv4_addr_in: "185.213.154.68".parse().unwrap(),
- /// # ipv6_addr_in: Some("2a03:1b20:5:f011::a09f".parse().unwrap()),
- /// # include_in_country: false,
- /// # active: true,
- /// # owned: true,
- /// # provider: "31173".to_string(),
- /// # weight: 1,
- /// # endpoint_data: RelayEndpointData::Wireguard(WireguardRelayEndpointData {
- /// # public_key: PublicKey::from_base64(
- /// # "BLNHNoGO88LjV/wDBa7CUUwUzPq/fO2UwcGLy56hKy4=",
- /// # )
- /// # .unwrap(),
- /// # }),
- /// # location: None,
- /// }],
- /// }],
- /// }],
- /// };
- ///
- /// # let relay_selector = RelaySelector::from_list(SelectorConfig::default(), relay_list.clone());
- /// # let city = |country, city| GeographicLocationConstraint::city(country, city);
- ///
- /// let query_a = RelayQueryBuilder::new().wireguard().build();
- /// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build();
- ///
- /// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap());
- /// assert!(result.is_ok());
- /// ```
- ///
- /// This way, if the mullvad app wants to check if the user's relay settings
- /// are compatible with any other [`RelayQuery`], for examples those defined by
- /// [`RETRY_ORDER`] , taking the intersection between them will never result in
- /// a situation where the app can override the user's preferences.
- ///
- /// [`RETRY_ORDER`]: crate::RETRY_ORDER
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(RelayQuery {
- location: self.location.intersection(other.location)?,
- providers: self.providers.intersection(other.providers)?,
- ownership: self.ownership.intersection(other.ownership)?,
- tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?,
- wireguard_constraints: self
- .wireguard_constraints
- .intersection(other.wireguard_constraints)?,
- openvpn_constraints: self
- .openvpn_constraints
- .intersection(other.openvpn_constraints)?,
- })
- }
-}
-
impl From<RelayQuery> for RelayConstraints {
/// The mapping from [`RelayQuery`] to [`RelayConstraints`].
fn from(value: RelayQuery) -> Self {
@@ -218,7 +122,8 @@ impl From<RelayQuery> for RelayConstraints {
}
}
-/// A query for a relay with Wireguard-specific properties, such as `multihop` and [wireguard obfuscation][`SelectedObfuscation`].
+/// A query for a relay with Wireguard-specific properties, such as `multihop` and [wireguard
+/// obfuscation][`SelectedObfuscation`].
///
/// This struct may look a lot like [`WireguardConstraints`], and that is the point!
/// This struct is meant to be that type in the "universe of relay queries". The difference
@@ -226,7 +131,7 @@ impl From<RelayQuery> for RelayConstraints {
/// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner.
/// Notice that [obfuscation][`SelectedObfuscation`] is not a [`Constraint`], but it is trivial
/// to define [`Intersection`] on it, so it is fine.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct WireguardRelayQuery {
pub port: Constraint<u16>,
pub ip_version: Constraint<IpVersion>,
@@ -254,37 +159,6 @@ impl WireguardRelayQuery {
}
}
}
-impl Intersection for WireguardRelayQuery {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(WireguardRelayQuery {
- port: self.port.intersection(other.port)?,
- ip_version: self.ip_version.intersection(other.ip_version)?,
- use_multihop: self.use_multihop.intersection(other.use_multihop)?,
- entry_location: self.entry_location.intersection(other.entry_location)?,
- obfuscation: self.obfuscation.intersection(other.obfuscation)?,
- udp2tcp_port: self.udp2tcp_port.intersection(other.udp2tcp_port)?,
- })
- }
-}
-
-impl Intersection for SelectedObfuscation {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- match (self, other) {
- (left, SelectedObfuscation::Auto) => Some(left),
- (SelectedObfuscation::Auto, right) => Some(right),
- (left, right) if left == right => Some(left),
- _ => None,
- }
- }
-}
impl From<WireguardRelayQuery> for WireguardConstraints {
/// The mapping from [`WireguardRelayQuery`] to [`WireguardConstraints`].
@@ -304,7 +178,7 @@ impl From<WireguardRelayQuery> for WireguardConstraints {
/// This struct is meant to be that type in the "universe of relay queries". The difference
/// between them may seem subtle, but in a [`OpenVpnRelayQuery`] every field is represented
/// as a [`Constraint`], which allow us to implement [`Intersection`] in a straight forward manner.
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Intersection)]
pub struct OpenVpnRelayQuery {
pub port: Constraint<TransportPort>,
pub bridge_settings: Constraint<BridgeQuery>,
@@ -319,29 +193,8 @@ impl OpenVpnRelayQuery {
}
}
-impl Intersection for OpenVpnRelayQuery {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- let bridge_settings = {
- match (self.bridge_settings, other.bridge_settings) {
- // Recursive case
- (Constraint::Only(left), Constraint::Only(right)) => {
- Constraint::Only(left.intersection(right)?)
- }
- (left, right) => left.intersection(right)?,
- }
- };
- Some(OpenVpnRelayQuery {
- port: self.port.intersection(other.port)?,
- bridge_settings,
- })
- }
-}
-
-/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay queries".
+/// This is the reflection of [`BridgeState`] + [`BridgeSettings`] in the "universe of relay
+/// queries".
///
/// [`BridgeState`]: mullvad_types::relay_constraints::BridgeState
/// [`BridgeSettings`]: mullvad_types::relay_constraints::BridgeSettings
@@ -361,7 +214,7 @@ pub enum BridgeQuery {
}
impl BridgeQuery {
- ///If `bridge_constraints` is `Any`, bridges should not be used due to
+ /// If `bridge_constraints` is `Any`, bridges should not be used due to
/// latency concerns.
///
/// If `bridge_constraints` is `Only(settings)`, then `settings` will be
@@ -399,20 +252,6 @@ impl Intersection for BridgeQuery {
}
}
-impl Intersection for BridgeConstraints {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: PartialEq,
- Self: Sized,
- {
- Some(BridgeConstraints {
- location: self.location.intersection(other.location)?,
- providers: self.providers.intersection(other.providers)?,
- ownership: self.ownership.intersection(other.ownership)?,
- })
- }
-}
-
impl From<OpenVpnRelayQuery> for OpenVpnConstraints {
/// The mapping from [`OpenVpnRelayQuery`] to [`OpenVpnConstraints`].
fn from(value: OpenVpnRelayQuery) -> Self {
@@ -420,103 +259,10 @@ impl From<OpenVpnRelayQuery> for OpenVpnConstraints {
}
}
-/// Any type that wish to implement `Intersection` should make sure that the
-/// following properties are upheld:
-///
-/// - idempotency (if there is an identity element)
-/// - commutativity
-/// - associativity
-pub trait Intersection {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized;
-}
-
-impl<T: Intersection> Intersection for Constraint<T> {
- /// Define the intersection between two arbitrary [`Constraint`]s.
- ///
- /// This operation may be compared to the set operation with the same name.
- /// In contrast to the general set intersection, this function represents a
- /// very specific case where [`Constraint::Any`] is equivalent to the set
- /// universe and [`Constraint::Only`] represents a singleton set. Notable is
- /// that the representation of any empty set is [`Option::None`].
- fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> {
- use Constraint::*;
- match (self, other) {
- (Any, Any) => Some(Any),
- (Only(t), Any) | (Any, Only(t)) => Some(Only(t)),
- // Recurse on `left` and `right` to see if there exist an intersection
- (Only(left), Only(right)) => Some(Only(left.intersection(right)?)),
- }
- }
-}
-
-// Implement `Intersection` for different types
-
-impl Intersection for Providers {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- Providers::new(self.providers().intersection(other.providers())).ok()
- }
-}
-
-impl Intersection for Udp2TcpObfuscationSettings {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- Some(Udp2TcpObfuscationSettings {
- port: self.port.intersection(other.port)?,
- })
- }
-}
-
-impl Intersection for TransportPort {
- fn intersection(self, other: Self) -> Option<Self>
- where
- Self: Sized,
- {
- let protocol = if self.protocol == other.protocol {
- Some(self.protocol)
- } else {
- None
- }?;
- let port = self.port.intersection(other.port)?;
- Some(TransportPort { protocol, port })
- }
-}
-
-/// Auto-implement `Intersection` for trivial cases where the logic should just check if
-/// `self` is equal to `other`.
-macro_rules! impl_intersection_partialeq {
- ($ty:ty) => {
- impl Intersection for $ty {
- fn intersection(self, other: Self) -> Option<Self> {
- if self == other {
- Some(self)
- } else {
- None
- }
- }
- }
- };
-}
-impl_intersection_partialeq!(u16);
-impl_intersection_partialeq!(bool);
-// FIXME: [`LocationConstraint`] deserves a hand-rolled implementation of [`Intersection`], but
-// it would probably be best to implement it for [`ResolvedLocationConstraint`] instead to properly
-// handle custom lists.
-impl_intersection_partialeq!(LocationConstraint);
-impl_intersection_partialeq!(Ownership);
-impl_intersection_partialeq!(talpid_types::net::TransportProtocol);
-impl_intersection_partialeq!(talpid_types::net::TunnelType);
-impl_intersection_partialeq!(talpid_types::net::IpVersion);
-
#[allow(unused)]
pub mod builder {
- //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate pattern.
+ //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate
+ //! pattern.
use mullvad_types::{
constraints::Constraint,
relay_constraints::{
diff --git a/mullvad-relay-selector/tests/relay_selector.rs b/mullvad-relay-selector/tests/relay_selector.rs
index 99a5491fad..ed3546c62e 100644
--- a/mullvad-relay-selector/tests/relay_selector.rs
+++ b/mullvad-relay-selector/tests/relay_selector.rs
@@ -201,10 +201,11 @@ fn default_relay_selector() -> RelaySelector {
}
/// This is not an actual test. Rather, it serves as a reminder that if [`RETRY_ORDER`] is modified,
-/// the programmer should be made aware to update all external documents which rely on the retry order
-/// to be correct.
+/// the programmer should be made aware to update all external documents which rely on the retry
+/// order to be correct.
///
-/// When all necessary changes have been made, feel free to update this test to mirror the new [`RETRY_ORDER`].
+/// When all necessary changes have been made, feel free to update this test to mirror the new
+/// [`RETRY_ORDER`].
#[test]
fn assert_retry_order() {
use talpid_types::net::{IpVersion, TransportProtocol};
@@ -341,8 +342,8 @@ fn prefer_wireguard_when_auto() {
}
}
-/// If a Wireguard relay is only specified by it's hostname (and not tunnel type), the relay selector should
-/// still return a relay of the correct tunnel type (Wireguard).
+/// If a Wireguard relay is only specified by it's hostname (and not tunnel type), the relay
+/// selector should still return a relay of the correct tunnel type (Wireguard).
#[test]
fn test_prefer_wireguard_if_location_supports_it() {
let relay_selector = default_relay_selector();
@@ -361,8 +362,8 @@ fn test_prefer_wireguard_if_location_supports_it() {
}
}
-/// If an OpenVPN relay is only specified by it's hostname (and not tunnel type), the relay selector should
-/// still return a relay of the correct tunnel type (OpenVPN).
+/// If an OpenVPN relay is only specified by it's hostname (and not tunnel type), the relay selector
+/// should still return a relay of the correct tunnel type (OpenVPN).
#[test]
fn test_prefer_openvpn_if_location_supports_it() {
let relay_selector = default_relay_selector();
@@ -381,9 +382,10 @@ fn test_prefer_openvpn_if_location_supports_it() {
}
}
-/// Assert that the relay selector does *not* return a multihop configuration where the exit and entry relay are
-/// the same, even if the constraints would allow for it. Also verify that the relay selector is smart enough to
-/// pick either the entry or exit relay first depending on which one ends up yielding a valid configuration.
+/// Assert that the relay selector does *not* return a multihop configuration where the exit and
+/// entry relay are the same, even if the constraints would allow for it. Also verify that the relay
+/// selector is smart enough to pick either the entry or exit relay first depending on which one
+/// ends up yielding a valid configuration.
#[test]
fn test_wireguard_entry() {
// Define a relay list containing exactly two Wireguard relays in Gothenburg.
@@ -555,7 +557,8 @@ fn test_wireguard_entry_hostname_collision() {
/// Test that the relay selector:
/// * returns an OpenVPN relay given a constraint of a valid transport protocol + port combo
-/// * does *not* return an OpenVPN relay given a constraint of an *invalid* transport protocol + port combo
+/// * does *not* return an OpenVPN relay given a constraint of an *invalid* transport protocol +
+/// port combo
#[test]
fn test_openvpn_constraints() {
let relay_selector = default_relay_selector();
@@ -644,7 +647,8 @@ fn test_openvpn_constraints() {
}
}
-/// Construct a query for multihop configuration and assert that the relay selector picks an accompanying entry relay.
+/// Construct a query for multihop configuration and assert that the relay selector picks an
+/// accompanying entry relay.
#[test]
fn test_selecting_wireguard_location_will_consider_multihop() {
let relay_selector = default_relay_selector();
@@ -678,8 +682,9 @@ fn test_selecting_any_relay_will_consider_multihop() {
}
}
-/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is selected and multihop is explicitly
-/// turned off. Assert that the relay selector always return an obfuscator configuration.
+/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is selected and
+/// multihop is explicitly turned off. Assert that the relay selector always return an obfuscator
+/// configuration.
#[test]
fn test_selecting_wireguard_endpoint_with_udp2tcp_obfuscation() {
let relay_selector = default_relay_selector();
@@ -704,12 +709,14 @@ fn test_selecting_wireguard_endpoint_with_udp2tcp_obfuscation() {
}
}
-/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is set to "Auto" and multihop is
-/// explicitly turned off. Assert that the relay selector does *not* return an obfuscator config.
+/// Construct a query for a Wireguard configuration where UDP2TCP obfuscation is set to "Auto" and
+/// multihop is explicitly turned off. Assert that the relay selector does *not* return an
+/// obfuscator config.
///
/// # Note
-/// This is a highly specific test which details how the relay selector should behave at the time of writing this test.
-/// The cost (in latency primarily) of using obfuscation is deemed to be too high to enable it as an auto-configuration.
+/// This is a highly specific test which details how the relay selector should behave at the time of
+/// writing this test. The cost (in latency primarily) of using obfuscation is deemed to be too high
+/// to enable it as an auto-configuration.
#[test]
fn test_selecting_wireguard_endpoint_with_auto_obfuscation() {
let relay_selector = default_relay_selector();
diff --git a/mullvad-types/Cargo.toml b/mullvad-types/Cargo.toml
index 8b614dd182..c4b4f76486 100644
--- a/mullvad-types/Cargo.toml
+++ b/mullvad-types/Cargo.toml
@@ -21,6 +21,8 @@ serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.4.1", features = ["v4", "serde" ] }
talpid-types = { path = "../talpid-types" }
+intersection-derive = { path = "intersection-derive" }
+
clap = { workspace = true , optional = true }
diff --git a/mullvad-types/intersection-derive/Cargo.toml b/mullvad-types/intersection-derive/Cargo.toml
new file mode 100644
index 0000000000..f167c28aed
--- /dev/null
+++ b/mullvad-types/intersection-derive/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "intersection-derive"
+description = "Derive macro for the `Intersection` trait"
+authors.workspace = true
+repository.workspace = true
+license.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1"
+syn = "2"
+quote = "1" \ No newline at end of file
diff --git a/mullvad-types/intersection-derive/src/lib.rs b/mullvad-types/intersection-derive/src/lib.rs
new file mode 100644
index 0000000000..65e5af15bd
--- /dev/null
+++ b/mullvad-types/intersection-derive/src/lib.rs
@@ -0,0 +1,68 @@
+//! This `proc-macro` crate exports the [`Intersection`] derive macro, see the trait documentation
+//! for more information.
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+use syn::{parse_macro_input, DeriveInput};
+
+/// Derive macro for the [`Intersection`] trait on structs.
+#[proc_macro_derive(Intersection)]
+pub fn intersection_derive(item: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(item as DeriveInput);
+
+ inner::derive(input).into()
+}
+
+mod inner {
+ use proc_macro2::TokenStream;
+ use quote::{quote, TokenStreamExt};
+ use syn::{spanned::Spanned, DeriveInput, Error};
+
+ pub(crate) fn derive(input: DeriveInput) -> TokenStream {
+ if let syn::Data::Struct(data) = &input.data {
+ derive_for_struct(&input, data).unwrap_or_else(Error::into_compile_error)
+ } else {
+ syn::Error::new(
+ input.span(),
+ "Deriving `Intersection` is only supported for structs",
+ )
+ .into_compile_error()
+ }
+ }
+
+ pub(crate) fn derive_for_struct(
+ input: &DeriveInput,
+ data: &syn::DataStruct,
+ ) -> syn::Result<TokenStream> {
+ let my_type = &input.ident;
+ let mut field_conversions = quote! {};
+ for field in &data.fields {
+ let Some(name) = &field.ident else {
+ return Err(syn::Error::new(
+ field.span(),
+ "Tuple structs are not currently supported",
+ ));
+ };
+
+ // TODO(Sebastian): Here, and in the `quote` below, we are referring to `Intersection`
+ // with its relative name, which will fail if the user renames the trait
+ // when importing, e.g. `use mullvad_types::Intersection as SomethingElse`.
+ // This is a know limitation of procural macros (declarative macros can use the `$crate`
+ // syntax). If the issue arises then it can be solve using the
+ // <https://crates.io/crates/proc-macro-crate> crate. Add it if necessary.
+ field_conversions.append_all(quote! {
+ #name: Intersection::intersection(self.#name, other.#name)?,
+ })
+ }
+
+ Ok(quote! {
+ impl Intersection for #my_type {
+ fn intersection(self, other: Self) -> ::core::option::Option<Self> {
+ ::core::option::Option::Some(Self {
+ #field_conversions
+ })
+ }
+ }
+ })
+ }
+}
diff --git a/mullvad-types/src/constraints/constraint.rs b/mullvad-types/src/constraints/constraint.rs
index c1669e33ae..26e1df3c6c 100644
--- a/mullvad-types/src/constraints/constraint.rs
+++ b/mullvad-types/src/constraints/constraint.rs
@@ -3,8 +3,9 @@
#[cfg(target_os = "android")]
use jnix::{FromJava, IntoJava};
use serde::{Deserialize, Serialize};
-use std::fmt;
-use std::str::FromStr;
+use std::{fmt, str::FromStr};
+
+use crate::Intersection;
/// Limits the set of [`crate::relay_list::Relay`]s that a `RelaySelector` may select.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
@@ -17,6 +18,25 @@ pub enum Constraint<T> {
Only(T),
}
+impl<T: Intersection> Intersection for Constraint<T> {
+ /// Define the intersection between two arbitrary [`Constraint`]s.
+ ///
+ /// This operation may be compared to the set operation with the same name.
+ /// In contrast to the general set intersection, this function represents a
+ /// very specific case where [`Constraint::Any`] is equivalent to the set
+ /// universe and [`Constraint::Only`] represents a singleton set. Notable is
+ /// that the representation of any empty set is [`Option::None`].
+ fn intersection(self, other: Constraint<T>) -> Option<Constraint<T>> {
+ use Constraint::*;
+ match (self, other) {
+ (Any, Any) => Some(Any),
+ (Only(t), Any) | (Any, Only(t)) => Some(Only(t)),
+ // Pick any of `left` or `right` if they are the same.
+ (Only(left), Only(right)) => left.intersection(right).map(Only),
+ }
+ }
+}
+
impl<T: fmt::Display> fmt::Display for Constraint<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
diff --git a/mullvad-types/src/constraints/mod.rs b/mullvad-types/src/constraints/mod.rs
index 4caa31e901..756066c447 100644
--- a/mullvad-types/src/constraints/mod.rs
+++ b/mullvad-types/src/constraints/mod.rs
@@ -5,10 +5,7 @@ mod constraint;
// Re-export bits & pieces from `constraints.rs` as needed
pub use constraint::Constraint;
-/// A limited variant of Sets.
-pub trait Set<T> {
- fn is_subset(&self, other: &T) -> bool;
-}
+use crate::relay_constraints;
pub trait Match<T> {
fn matches(&self, other: &T) -> bool;
@@ -22,14 +19,124 @@ impl<T: Match<U>, U> Match<U> for Constraint<T> {
}
}
-impl<T: Set<U>, U> Set<Constraint<U>> for Constraint<T> {
- fn is_subset(&self, other: &Constraint<U>) -> bool {
- match self {
- Constraint::Any => other.is_any(),
- Constraint::Only(ref constraint) => match other {
- Constraint::Only(ref other_constraint) => constraint.is_subset(other_constraint),
- _ => true,
- },
+/// The intersection of two sets of criteria on [`Relay`](crate::relay_list::Relay)s is another
+/// criteria which matches the given relay iff both of the original criteria matched. It is
+/// primarily used by the relay selector to check whether a given connection method is compatible
+/// with the users settings.
+///
+/// # Examples
+///
+/// The [`Intersection`] implementation of [`RelayQuery`] upholds the following properties:
+///
+/// * If two [`RelayQuery`]s differ such that no relay matches both, [`Option::None`] is returned:
+/// ```rust, ignore
+/// # use mullvad_relay_selector::query::builder::RelayQueryBuilder;
+/// let query_a = RelayQueryBuilder::new().wireguard().build();
+/// let query_b = RelayQueryBuilder::new().openvpn().build();
+/// assert_eq!(query_a.intersection(query_b), None);
+/// ```
+///
+/// * Otherwise, a new [`RelayQuery`] is returned where each constraint is
+/// as specific as possible. See [`Constraint`] for further details.
+/// ```rust, ignore
+/// let query_a = RelayQueryBuilder::new().wireguard().build();
+/// let query_b = RelayQueryBuilder::new().location(city("Sweden", "Gothenburg")).build();
+///
+/// let result = relay_selector.get_relay_by_query(query_a.intersection(query_b).unwrap());
+/// assert!(result.is_ok());
+/// ```
+///
+/// This way, if the mullvad app wants to check if the user's relay settings
+/// are compatible with any other [`RelayQuery`], for examples those defined by
+/// [`RETRY_ORDER`] , taking the intersection between them will never result in
+/// a situation where the app can override the user's preferences.
+///
+/// [`RETRY_ORDER`]: crate::RETRY_ORDER
+///
+/// The macro recursively applies the intersection on each field of the struct and returns the
+/// resulting type or `None` if any of the intersections failed to overlap.
+///
+/// The macro requires the types of each field to also implement [`Intersection`], which may be done
+/// using this derive macro, the
+///
+/// # Implementing [`Intersection`]
+///
+/// For structs where each field already implements `Intersection`, the easiest way to implement the
+/// trait is using the derive macro. Using the derive macro on [`RelayQuery`]
+/// ```rust, ignore
+/// #[derive(Intersection)]
+/// struct RelayQuery {
+/// pub location: Constraint<LocationConstraint>,
+/// pub providers: Constraint<Providers>,
+/// pub ownership: Constraint<Ownership>,
+/// pub tunnel_protocol: Constraint<TunnelType>,
+/// pub wireguard_constraints: WireguardRelayQuery,
+/// pub openvpn_constraints: OpenVpnRelayQuery,
+/// }
+/// ```
+///
+/// produces an implementation like this:
+///
+/// ```rust, ignore
+/// impl Intersection for RelayQuery {
+/// fn intersection(self, other: Self) -> Option<Self>
+/// where
+/// Self: PartialEq,
+/// Self: Sized,
+/// {
+/// Some(RelayQuery {
+/// location: self.location.intersection(other.location)?,
+/// providers: self.providers.intersection(other.providers)?,
+/// ownership: self.ownership.intersection(other.ownership)?,
+/// tunnel_protocol: self.tunnel_protocol.intersection(other.tunnel_protocol)?,
+/// wireguard_constraints: self
+/// .wireguard_constraints
+/// .intersection(other.wireguard_constraints)?,
+/// openvpn_constraints: self
+/// .openvpn_constraints
+/// .intersection(other.openvpn_constraints)?,
+/// })
+/// }
+/// }
+/// ```
+///
+/// For types that cannot "overlap", e.g. they only intersect if they are equal, the declarative
+/// macro [`impl_intersection_partialeq`] can be used.
+///
+/// For less trivial cases, the trait needs to be implemented manually. When doing so, make sure
+/// that the following properties are upheld:
+///
+/// - idempotency (if there is an identity element)
+/// - commutativity
+/// - associativity
+pub trait Intersection: Sized {
+ fn intersection(self, other: Self) -> Option<Self>;
+}
+
+#[macro_export]
+macro_rules! impl_intersection_partialeq {
+ ($ty:ty) => {
+ impl $crate::Intersection for $ty {
+ fn intersection(self, other: Self) -> Option<Self> {
+ if self == other {
+ Some(self)
+ } else {
+ None
+ }
+ }
}
- }
+ };
}
+
+impl_intersection_partialeq!(u16);
+impl_intersection_partialeq!(bool);
+
+// NOTE: this implementation does not do what you may expect of an intersection
+impl_intersection_partialeq!(relay_constraints::Providers);
+// NOTE: should take actual intersection
+impl_intersection_partialeq!(relay_constraints::LocationConstraint);
+impl_intersection_partialeq!(relay_constraints::Ownership);
+// NOTE: it contains an inner constraint
+impl_intersection_partialeq!(talpid_types::net::TransportProtocol);
+impl_intersection_partialeq!(talpid_types::net::TunnelType);
+impl_intersection_partialeq!(talpid_types::net::IpVersion);
diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs
index 231f80300b..d1a50fa0ea 100644
--- a/mullvad-types/src/lib.rs
+++ b/mullvad-types/src/lib.rs
@@ -21,3 +21,6 @@ pub use crate::custom_tunnel::*;
pub const TUNNEL_TABLE_ID: u32 = 0x6d6f6c65;
#[cfg(target_os = "linux")]
pub const TUNNEL_FWMARK: u32 = 0x6d6f6c65;
+
+pub use constraints::Intersection;
+pub use intersection_derive::Intersection;
diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs
index 05d07477be..9b6b640cf5 100644
--- a/mullvad-types/src/relay_constraints.rs
+++ b/mullvad-types/src/relay_constraints.rs
@@ -2,11 +2,11 @@
//! updated as well.
use crate::{
- constraints::{Constraint, Match, Set},
+ constraints::{Constraint, Match},
custom_list::{CustomListsSettings, Id},
location::{CityCode, CountryCode, Hostname},
relay_list::Relay,
- CustomTunnelEndpoint,
+ CustomTunnelEndpoint, Intersection,
};
#[cfg(target_os = "android")]
use jnix::{jni::objects::JObject, FromJava, IntoJava, JnixEnv};
@@ -294,51 +294,6 @@ impl Match<Relay> for GeographicLocationConstraint {
}
}
-impl Set<GeographicLocationConstraint> for GeographicLocationConstraint {
- /// Returns whether `self` is equal to or a subset of `other`.
- fn is_subset(&self, other: &Self) -> bool {
- match self {
- GeographicLocationConstraint::Country(_) => self == other,
- GeographicLocationConstraint::City(ref country, ref _city) => match other {
- GeographicLocationConstraint::Country(ref other_country) => {
- country == other_country
- }
- GeographicLocationConstraint::City(..) => self == other,
- _ => false,
- },
- GeographicLocationConstraint::Hostname(ref country, ref city, ref _hostname) => {
- match other {
- GeographicLocationConstraint::Country(ref other_country) => {
- country == other_country
- }
- GeographicLocationConstraint::City(ref other_country, ref other_city) => {
- country == other_country && city == other_city
- }
- GeographicLocationConstraint::Hostname(..) => self == other,
- }
- }
- }
- }
-}
-
-impl Set<Constraint<Vec<GeographicLocationConstraint>>>
- for Constraint<Vec<GeographicLocationConstraint>>
-{
- fn is_subset(&self, other: &Self) -> bool {
- match self {
- Constraint::Any => other.is_any(),
- Constraint::Only(locations) => match other {
- Constraint::Any => true,
- Constraint::Only(other_locations) => locations.iter().all(|location| {
- other_locations
- .iter()
- .any(|other_location| location.is_subset(other_location))
- }),
- },
- }
- }
-}
-
/// Limits the set of servers to choose based on ownership.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
#[cfg_attr(target_os = "android", derive(IntoJava, FromJava))]
@@ -463,7 +418,7 @@ impl fmt::Display for GeographicLocationConstraint {
}
}
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
pub struct TransportPort {
pub protocol: TransportProtocol,
pub port: Constraint<u16>,
@@ -646,6 +601,21 @@ pub enum SelectedObfuscation {
Udp2Tcp,
}
+impl Intersection for SelectedObfuscation {
+ fn intersection(self, other: Self) -> Option<Self>
+ where
+ Self: PartialEq,
+ Self: Sized,
+ {
+ match (self, other) {
+ (left, SelectedObfuscation::Auto) => Some(left),
+ (SelectedObfuscation::Auto, right) => Some(right),
+ (left, right) if left == right => Some(left),
+ _ => None,
+ }
+ }
+}
+
impl fmt::Display for SelectedObfuscation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -656,7 +626,7 @@ impl fmt::Display for SelectedObfuscation {
}
}
-#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Default, Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
#[serde(rename_all = "snake_case")]
@@ -716,7 +686,7 @@ pub struct ObfuscationSettings {
}
/// Limits the set of bridge servers to use in `mullvad-daemon`.
-#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize, Intersection)]
#[serde(default)]
#[serde(rename_all = "snake_case")]
pub struct BridgeConstraints {
diff --git a/test/Cargo.lock b/test/Cargo.lock
index 5f207d9a34..eb38b5f37a 100644
--- a/test/Cargo.lock
+++ b/test/Cargo.lock
@@ -1380,6 +1380,15 @@ dependencies = [
]
[[package]]
+name = "intersection-derive"
+version = "0.0.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
name = "inventory"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1814,6 +1823,7 @@ name = "mullvad-relay-selector"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork 0.16.0",
"itertools 0.12.1",
"log",
@@ -1830,6 +1840,7 @@ name = "mullvad-types"
version = "0.0.0"
dependencies = [
"chrono",
+ "intersection-derive",
"ipnetwork 0.16.0",
"jnix",
"log",