summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-08-14 14:18:37 +0200
committerDavid Lönnhager <david.l@mullvad.net>2024-08-16 09:13:30 +0200
commit97db97d200d547a9a80e3001b69e2edf85b968bd (patch)
tree33c27fc9f7a592ec2e4fb2bf598fb9dcb6d2d75f
parent57285fa162fced4bd62ff6bb7b71e1e0ccd42309 (diff)
downloadmullvadvpn-97db97d200d547a9a80e3001b69e2edf85b968bd.tar.xz
mullvadvpn-97db97d200d547a9a80e3001b69e2edf85b968bd.zip
Simplify random port selection
-rw-r--r--mullvad-relay-selector/src/relay_selector/detailer.rs11
-rw-r--r--mullvad-relay-selector/src/relay_selector/helpers.rs62
-rw-r--r--mullvad-types/src/constraints/constraint.rs7
3 files changed, 41 insertions, 39 deletions
diff --git a/mullvad-relay-selector/src/relay_selector/detailer.rs b/mullvad-relay-selector/src/relay_selector/detailer.rs
index 536ea50f26..7b5f24f94c 100644
--- a/mullvad-relay-selector/src/relay_selector/detailer.rs
+++ b/mullvad-relay-selector/src/relay_selector/detailer.rs
@@ -180,15 +180,8 @@ fn get_port_for_wireguard_relay(
query: &WireguardRelayQuery,
data: &WireguardEndpointData,
) -> Result<u16, Error> {
- if let Constraint::Only(port) = query.port {
- if super::helpers::port_in_range(port, &data.port_ranges) {
- Ok(port)
- } else {
- Err(Error::PortSelectionError)
- }
- } else {
- super::helpers::select_random_port(&data.port_ranges).map_err(|_| Error::PortSelectionError)
- }
+ super::helpers::desired_or_random_port_from_range(&data.port_ranges, query.port)
+ .map_err(|_err| Error::PortSelectionError)
}
/// Read the [`PublicKey`] of a relay. This will only succeed if [relay][`Relay`] is a
diff --git a/mullvad-relay-selector/src/relay_selector/helpers.rs b/mullvad-relay-selector/src/relay_selector/helpers.rs
index 6cf13db45c..c49a7df04d 100644
--- a/mullvad-relay-selector/src/relay_selector/helpers.rs
+++ b/mullvad-relay-selector/src/relay_selector/helpers.rs
@@ -156,41 +156,48 @@ fn get_shadowsocks_obfuscator_inner<R: RangeBounds<u16> + Iterator<Item = u16> +
.unwrap_or(wg_in_addr);
let selected_port = if extra_in_addrs.is_empty() {
- desired_port_from_range(wg_in_addr_port_ranges, desired_port)
+ desired_or_random_port_from_range(wg_in_addr_port_ranges, desired_port)
} else {
- desired_port_from_range(SHADOWSOCKS_EXTRA_PORT_RANGES, desired_port)
+ desired_or_random_port_from_range(SHADOWSOCKS_EXTRA_PORT_RANGES, desired_port)
}?;
Ok(SocketAddr::from((in_ip, selected_port)))
}
-fn desired_port_from_range<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
+/// Return `desired_port` if it is specified and included in `port_ranges`.
+/// If `desired_port` isn't specified, a random port from the ranges is returned.
+/// If `desired_port` is specified but not in range, an error is returned.
+pub fn desired_or_random_port_from_range<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
port_ranges: &[R],
desired_port: Constraint<u16>,
) -> Result<u16, Error> {
- match desired_port {
- // Selected a specific, in-range port
- Constraint::Only(port) if port_in_range(port, port_ranges) => Ok(port),
- // Selected a specific, out-of-range port
- Constraint::Only(_port) => Err(Error::NoMatchingPort),
- // Selected no specific port
- Constraint::Any => select_random_port(port_ranges),
- }
+ desired_port
+ .map(|port| port_if_in_range(port_ranges, port))
+ .unwrap_or_else(|| select_random_port(port_ranges))
+}
+
+/// Return `Ok(port)`, if and only if `port` is in `port_ranges`. Otherwise, return an error.
+fn port_if_in_range<R: RangeBounds<u16>>(port_ranges: &[R], port: u16) -> Result<u16, Error> {
+ port_ranges
+ .iter()
+ .find_map(|range| {
+ if range.contains(&port) {
+ Some(port)
+ } else {
+ None
+ }
+ })
+ .ok_or(Error::NoMatchingPort)
}
/// Selects a random port number from a list of provided port ranges.
///
-/// This function iterates over a list of port ranges, each represented as a tuple (u16, u16)
-/// where the first element is the start of the range and the second is the end (inclusive),
-/// and selects a random port from the set of all ranges.
-///
/// # Parameters
-/// - `port`: Constraint to apply to the port selection
-/// - `port_ranges`: A slice of tuples, each representing a range of valid port numbers.
+/// - `port_ranges`: A slice of port numbers.
///
/// # Returns
-/// - A randomly selected port number within the given ranges.
-/// - An error if `port_ranges` is empty.
+/// - On success, a randomly selected port number within the given ranges. Otherwise,
+/// an error is returned.
pub fn select_random_port<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
port_ranges: &[R],
) -> Result<u16, Error> {
@@ -202,13 +209,11 @@ pub fn select_random_port<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
.ok_or(Error::NoMatchingPort)
}
-pub fn port_in_range<R: RangeBounds<u16>>(port: u16, port_ranges: &[R]) -> bool {
- port_ranges.iter().any(|range| range.contains(&port))
-}
-
#[cfg(test)]
mod tests {
- use super::{get_shadowsocks_obfuscator_inner, port_in_range, SHADOWSOCKS_EXTRA_PORT_RANGES};
+ use super::{
+ get_shadowsocks_obfuscator_inner, port_if_in_range, SHADOWSOCKS_EXTRA_PORT_RANGES,
+ };
use mullvad_types::constraints::Constraint;
use std::{net::IpAddr, ops::RangeInclusive};
@@ -226,7 +231,7 @@ mod tests {
assert_eq!(selected_addr.ip(), wg_in_ip);
assert!(
- port_in_range(selected_addr.port(), PORT_RANGES),
+ port_if_in_range(PORT_RANGES, selected_addr.port()).is_ok(),
"expected port in port range"
);
@@ -240,7 +245,7 @@ mod tests {
assert_eq!(selected_addr.ip(), wg_in_ip);
assert!(
- port_in_range(selected_addr.port(), PORT_RANGES),
+ port_if_in_range(PORT_RANGES, selected_addr.port()).is_ok(),
"expected port in port range"
);
@@ -278,10 +283,7 @@ mod tests {
extra_in_addrs.contains(&selected_addr.ip()),
"expected extra IP to be selected"
);
- assert!(port_in_range(
- selected_addr.port(),
- SHADOWSOCKS_EXTRA_PORT_RANGES
- ));
+ assert!(port_if_in_range(SHADOWSOCKS_EXTRA_PORT_RANGES, selected_addr.port(),).is_ok());
let selected_addr = get_shadowsocks_obfuscator_inner(
wg_in_ip,
diff --git a/mullvad-types/src/constraints/constraint.rs b/mullvad-types/src/constraints/constraint.rs
index 2554a97d25..35b83e1320 100644
--- a/mullvad-types/src/constraints/constraint.rs
+++ b/mullvad-types/src/constraints/constraint.rs
@@ -56,6 +56,13 @@ impl<T> Constraint<T> {
}
}
+ pub fn unwrap_or_else<F: FnOnce() -> T>(self, or_else: F) -> T {
+ match self {
+ Constraint::Only(value) => value,
+ Constraint::Any => or_else(),
+ }
+ }
+
pub fn or(self, other: Constraint<T>) -> Constraint<T> {
match self {
Constraint::Any => other,