diff options
31 files changed, 404 insertions, 246 deletions
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt index b2af9c989d..3ab59000ae 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt @@ -5,7 +5,7 @@ import android.os.Messenger import java.net.InetAddress import kotlinx.parcelize.Parcelize import net.mullvad.mullvadvpn.model.DnsOptions -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.model.ObfuscationSettings import net.mullvad.mullvadvpn.model.QuantumResistantState import net.mullvad.mullvadvpn.model.WireguardConstraints @@ -73,7 +73,8 @@ sealed class Request : Message.RequestMessage() { @Parcelize data class SetEnableSplitTunneling(val enable: Boolean) : Request() - @Parcelize data class SetRelayLocation(val relayLocation: LocationConstraint?) : Request() + @Parcelize + data class SetRelayLocation(val relayLocation: GeographicLocationConstraint?) : Request() @Parcelize data class SetWireGuardMtu(val mtu: Int?) : Request() diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/GeographicLocationConstraint.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/GeographicLocationConstraint.kt new file mode 100644 index 0000000000..04f92a72ac --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/GeographicLocationConstraint.kt @@ -0,0 +1,28 @@ +package net.mullvad.mullvadvpn.model + +import android.os.Parcelable +import kotlinx.parcelize.Parcelize + +sealed class GeographicLocationConstraint : Parcelable { + abstract val location: GeoIpLocation + + @Parcelize + data class Country(val countryCode: String) : GeographicLocationConstraint() { + override val location: GeoIpLocation + get() = GeoIpLocation(null, null, countryCode, null, null) + } + + @Parcelize + data class City(val countryCode: String, val cityCode: String) : + GeographicLocationConstraint() { + override val location: GeoIpLocation + get() = GeoIpLocation(null, null, countryCode, cityCode, null) + } + + @Parcelize + data class Hostname(val countryCode: String, val cityCode: String, val hostname: String) : + GeographicLocationConstraint() { + override val location: GeoIpLocation + get() = GeoIpLocation(null, null, countryCode, cityCode, hostname) + } +} diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/LocationConstraint.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/LocationConstraint.kt index 2820a449b8..de7dd4e99b 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/LocationConstraint.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/LocationConstraint.kt @@ -4,24 +4,7 @@ import android.os.Parcelable import kotlinx.parcelize.Parcelize sealed class LocationConstraint : Parcelable { - abstract val location: GeoIpLocation - - @Parcelize - data class Country(val countryCode: String) : LocationConstraint() { - override val location: GeoIpLocation - get() = GeoIpLocation(null, null, countryCode, null, null) - } - - @Parcelize - data class City(val countryCode: String, val cityCode: String) : LocationConstraint() { - override val location: GeoIpLocation - get() = GeoIpLocation(null, null, countryCode, cityCode, null) - } - @Parcelize - data class Hostname(val countryCode: String, val cityCode: String, val hostname: String) : - LocationConstraint() { - override val location: GeoIpLocation - get() = GeoIpLocation(null, null, countryCode, cityCode, hostname) - } + data class Location(val location: GeographicLocationConstraint) : LocationConstraint() + @Parcelize data class CustomList(val listId: String) : LocationConstraint() } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Relay.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Relay.kt index 7afb2249d2..6f7b6760b0 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Relay.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Relay.kt @@ -1,12 +1,13 @@ package net.mullvad.mullvadvpn.relaylist -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint data class Relay(val city: RelayCity, override val name: String, override val active: Boolean) : RelayItem { override val code = name override val type = RelayItemType.Relay - override val location = LocationConstraint.Hostname(city.country.code, city.code, name) + override val location = + GeographicLocationConstraint.Hostname(city.country.code, city.code, name) override val hasChildren = false override val visibleChildCount = 0 diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCity.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCity.kt index 9500c43795..c6244101f6 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCity.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCity.kt @@ -1,6 +1,6 @@ package net.mullvad.mullvadvpn.relaylist -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint class RelayCity( val country: RelayCountry, @@ -10,7 +10,7 @@ class RelayCity( val relays: List<Relay> ) : RelayItem { override val type = RelayItemType.City - override val location = LocationConstraint.City(country.code, code) + override val location = GeographicLocationConstraint.City(country.code, code) override val active get() = relays.any { relay -> relay.active } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCountry.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCountry.kt index 447cc25ff2..d8424cacad 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCountry.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayCountry.kt @@ -1,6 +1,6 @@ package net.mullvad.mullvadvpn.relaylist -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint class RelayCountry( override val name: String, @@ -9,7 +9,7 @@ class RelayCountry( val cities: List<RelayCity> ) : RelayItem { override val type = RelayItemType.Country - override val location = LocationConstraint.Country(code) + override val location = GeographicLocationConstraint.Country(code) override val active get() = cities.any { city -> city.active } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayItem.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayItem.kt index e5f28acee6..fde283fcdf 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayItem.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayItem.kt @@ -1,12 +1,12 @@ package net.mullvad.mullvadvpn.relaylist -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint interface RelayItem { val type: RelayItemType val name: String val code: String - val location: LocationConstraint + val location: GeographicLocationConstraint val active: Boolean val hasChildren: Boolean val visibleChildCount: Int diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt index b5aaed028a..60cbdd46cf 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt @@ -1,7 +1,7 @@ package net.mullvad.mullvadvpn.relaylist import net.mullvad.mullvadvpn.model.Constraint -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint class RelayList { val countries: List<RelayCountry> @@ -41,7 +41,7 @@ class RelayList { } fun findItemForLocation( - constraint: Constraint<LocationConstraint>, + constraint: Constraint<GeographicLocationConstraint>, expand: Boolean = false ): RelayItem? { when (constraint) { @@ -50,10 +50,10 @@ class RelayList { val location = constraint.value when (location) { - is LocationConstraint.Country -> { + is GeographicLocationConstraint.Country -> { return countries.find { country -> country.code == location.countryCode } } - is LocationConstraint.City -> { + is GeographicLocationConstraint.City -> { val country = countries.find { country -> country.code == location.countryCode } @@ -63,7 +63,7 @@ class RelayList { return country?.cities?.find { city -> city.code == location.cityCode } } - is LocationConstraint.Hostname -> { + is GeographicLocationConstraint.Hostname -> { val country = countries.find { country -> country.code == location.countryCode } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/LocationInfoCache.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/LocationInfoCache.kt index 7cc43925d7..7b0d419b45 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/LocationInfoCache.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/LocationInfoCache.kt @@ -19,6 +19,7 @@ import net.mullvad.mullvadvpn.model.GeoIpLocation import net.mullvad.mullvadvpn.model.RelaySettings import net.mullvad.mullvadvpn.model.TunnelState import net.mullvad.mullvadvpn.util.ExponentialBackoff +import net.mullvad.mullvadvpn.util.toGeographicLocationConstraint import net.mullvad.talpid.tunnel.ActionAfterDisconnect class LocationInfoCache(private val endpoint: ServiceEndpoint) { @@ -131,6 +132,6 @@ class LocationInfoCache(private val endpoint: ServiceEndpoint) { val settings = relaySettings as? RelaySettings.Normal val constraint = settings?.relayConstraints?.location as? Constraint.Only - selectedRelayLocation = constraint?.value?.location + selectedRelayLocation = constraint?.value?.toGeographicLocationConstraint()?.location } } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt index 4fa531eeb4..7f4d274d6f 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt @@ -10,6 +10,7 @@ import kotlinx.coroutines.channels.trySendBlocking import net.mullvad.mullvadvpn.ipc.Event import net.mullvad.mullvadvpn.ipc.Request import net.mullvad.mullvadvpn.model.Constraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.model.LocationConstraint import net.mullvad.mullvadvpn.model.RelayConstraintsUpdate import net.mullvad.mullvadvpn.model.RelayList @@ -29,7 +30,7 @@ class RelayListListener(endpoint: ServiceEndpoint) { private val daemon = endpoint.intermittentDaemon private var selectedRelayLocation by - observable<LocationConstraint?>(null) { _, _, _ -> + observable<GeographicLocationConstraint?>(null) { _, _, _ -> commandChannel.trySendBlocking(Command.SetRelayLocation) } private var selectedWireguardConstraints by @@ -93,7 +94,10 @@ class RelayListListener(endpoint: ServiceEndpoint) { private suspend fun updateRelayConstraints() { val location: Constraint<LocationConstraint> = - selectedRelayLocation?.let { location -> Constraint.Only(location) } ?: Constraint.Any() + selectedRelayLocation?.let { location -> + Constraint.Only(LocationConstraint.Location(location)) + } + ?: Constraint.Any() val wireguardConstraints: WireguardConstraints? = selectedWireguardConstraints val update = diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt index 46cf492d01..5efca2648c 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt @@ -5,6 +5,7 @@ import net.mullvad.mullvadvpn.ipc.Event import net.mullvad.mullvadvpn.ipc.EventDispatcher import net.mullvad.mullvadvpn.ipc.Request import net.mullvad.mullvadvpn.model.Constraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.model.LocationConstraint import net.mullvad.mullvadvpn.model.PortRange import net.mullvad.mullvadvpn.model.RelayConstraints @@ -12,6 +13,7 @@ import net.mullvad.mullvadvpn.model.RelaySettings import net.mullvad.mullvadvpn.model.WireguardConstraints import net.mullvad.mullvadvpn.relaylist.RelayItem import net.mullvad.mullvadvpn.relaylist.RelayList +import net.mullvad.mullvadvpn.util.toGeographicLocationConstraint class RelayListListener( private val connection: Messenger, @@ -25,12 +27,12 @@ class RelayListListener( var selectedRelayItem: RelayItem? = null private set - var selectedRelayLocation: LocationConstraint? + var selectedRelayLocation: GeographicLocationConstraint? get() { val settings = relaySettings as? RelaySettings.Normal val location = settings?.relayConstraints?.location as? Constraint.Only - return location?.value + return location?.value?.toGeographicLocationConstraint() } set(value) { connection.send(Request.SetRelayLocation(value).message) @@ -119,7 +121,10 @@ class RelayListListener( is RelaySettings.Normal -> { val location = relaySettings.relayConstraints.location - return relayList?.findItemForLocation(location, true) + return relayList?.findItemForLocation( + location.toGeographicLocationConstraint(), + true + ) } else -> { /* NOOP */ diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/util/LocationConstraintExtensions.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/util/LocationConstraintExtensions.kt new file mode 100644 index 0000000000..2637028111 --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/util/LocationConstraintExtensions.kt @@ -0,0 +1,22 @@ +package net.mullvad.mullvadvpn.util + +import net.mullvad.mullvadvpn.model.Constraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint +import net.mullvad.mullvadvpn.model.LocationConstraint + +fun LocationConstraint.toGeographicLocationConstraint(): GeographicLocationConstraint? = + when (this) { + is LocationConstraint.Location -> this.location + is LocationConstraint.CustomList -> null + } + +fun Constraint<LocationConstraint>.toGeographicLocationConstraint(): + Constraint<GeographicLocationConstraint> = + when (this) { + is Constraint.Only -> + when (this.value) { + is LocationConstraint.Location -> Constraint.Only(this.value.location) + is LocationConstraint.CustomList -> Constraint.Any() + } + is Constraint.Any -> Constraint.Any() + } diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt index fbb5008eb7..ba7ea30c41 100644 --- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt @@ -16,7 +16,7 @@ import kotlinx.coroutines.test.runTest import net.mullvad.mullvadvpn.TestCoroutineRule import net.mullvad.mullvadvpn.assertLists import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState -import net.mullvad.mullvadvpn.model.LocationConstraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.relaylist.RelayCountry import net.mullvad.mullvadvpn.relaylist.RelayItem import net.mullvad.mullvadvpn.relaylist.RelayList @@ -120,7 +120,7 @@ class SelectLocationViewModelTest { fun testSelectRelayAndClose() = runTest { // Arrange val mockRelayItem: RelayItem = mockk() - val mockLocation: LocationConstraint.Country = mockk(relaxed = true) + val mockLocation: GeographicLocationConstraint.Country = mockk(relaxed = true) val connectionProxyMock: ConnectionProxy = mockk(relaxUnitFun = true) every { mockRelayItem.location } returns mockLocation every { mockServiceConnectionManager.relayListListener() } returns mockRelayListListener diff --git a/gui/src/main/daemon-rpc.ts b/gui/src/main/daemon-rpc.ts index e2a909b756..f4d47c7916 100644 --- a/gui/src/main/daemon-rpc.ts +++ b/gui/src/main/daemon-rpc.ts @@ -1232,21 +1232,19 @@ function convertFromLocation(location: grpcTypes.LocationConstraint.AsObject): R // FIXME: This is a hack that assumes that the LocationConstraint is not a custom list. // If it is we just set the country to "any" even if that isn't correct. if (location.location == undefined) { - return { country: "any" }; + return { country: 'any' }; } - else { - const loc = location.location; + const loc = location.location; - if (loc.hostname) { - return { hostname: [loc.country, loc.city, loc.hostname] }; - } - - if (loc.city) { - return { city: [loc.country, loc.city] }; - } + if (loc.hostname) { + return { hostname: [loc.country, loc.city, loc.hostname] }; + } - return { country: loc.country }; + if (loc.city) { + return { city: [loc.country, loc.city] }; } + + return { country: loc.country }; } function convertFromTunnelOptions(tunnelOptions: grpcTypes.TunnelOptions.AsObject): ITunnelOptions { diff --git a/mullvad-cli/src/cmds/bridge.rs b/mullvad-cli/src/cmds/bridge.rs index a170661838..e43fce1b51 100644 --- a/mullvad-cli/src/cmds/bridge.rs +++ b/mullvad-cli/src/cmds/bridge.rs @@ -3,16 +3,15 @@ use clap::Subcommand; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::{ relay_constraints::{ - BridgeConstraints, BridgeSettings, BridgeState, Constraint, - LocationConstraint, Ownership, Provider, Providers, + BridgeConstraints, BridgeSettings, BridgeState, Constraint, LocationConstraint, Ownership, + Provider, Providers, }, relay_list::RelayEndpointData, }; use std::net::{IpAddr, SocketAddr}; use talpid_types::net::openvpn::{self, SHADOWSOCKS_CIPHERS}; -use super::relay::find_relay_by_hostname; -use super::relay_constraints::LocationArgs; +use super::{relay::find_relay_by_hostname, relay_constraints::LocationArgs}; #[derive(Subcommand, Debug)] pub enum Bridge { @@ -163,17 +162,19 @@ impl Bridge { } SetCommands::Location(location) => { let countries = rpc.get_relay_locations().await?.countries; - let location = if let Some(relay) = find_relay_by_hostname(&countries, &location.country) { - Constraint::Only(relay) - } else { - Constraint::from(location) - }; + let location = + if let Some(relay) = find_relay_by_hostname(&countries, &location.country) { + Constraint::Only(relay) + } else { + Constraint::from(location) + }; let location = location.map(|location| LocationConstraint::Location { location }); Self::update_bridge_settings(&mut rpc, Some(location), None, None).await } SetCommands::CustomList { custom_list_name } => { let list = rpc.get_custom_list(custom_list_name).await?; - let location = Constraint::Only(LocationConstraint::CustomList { list_id: list.id }); + let location = + Constraint::Only(LocationConstraint::CustomList { list_id: list.id }); Self::update_bridge_settings(&mut rpc, Some(location), None, None).await } SetCommands::Ownership { ownership } => { diff --git a/mullvad-cli/src/cmds/custom_lists.rs b/mullvad-cli/src/cmds/custom_lists.rs index 12dd0b0d64..9c0770a442 100644 --- a/mullvad-cli/src/cmds/custom_lists.rs +++ b/mullvad-cli/src/cmds/custom_lists.rs @@ -4,9 +4,7 @@ use clap::Subcommand; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::{ custom_list::CustomListLocationUpdate, - relay_constraints::{ - Constraint, GeographicLocationConstraint - }, + relay_constraints::{Constraint, GeographicLocationConstraint}, }; #[derive(Subcommand, Debug)] diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs index 6787c588d8..1bb395c4e5 100644 --- a/mullvad-cli/src/cmds/relay.rs +++ b/mullvad-cli/src/cmds/relay.rs @@ -314,7 +314,9 @@ impl Relay { match subcmd { SetCommands::Custom(subcmd) => Self::set_custom(subcmd).await, SetCommands::Location(location) => Self::set_location(location).await, - SetCommands::CustomList { custom_list_name } => Self::set_custom_list(custom_list_name).await, + SetCommands::CustomList { custom_list_name } => { + Self::set_custom_list(custom_list_name).await + } SetCommands::Provider { providers } => Self::set_providers(providers).await, SetCommands::Ownership { ownership } => Self::set_ownership(ownership).await, SetCommands::Tunnel(subcmd) => Self::set_tunnel(subcmd).await, @@ -448,9 +450,10 @@ impl Relay { // The country field is assumed to be hostname due to CLI argument parsing find_relay_by_hostname(&countries, &location_constraint_args.country) { - Constraint::Only(relay) + Constraint::Only(LocationConstraint::Location { location: relay }) } else { - let location_constraint: Constraint<LocationConstraint> = Constraint::from(location_constraint_args); + let location_constraint: Constraint<GeographicLocationConstraint> = + Constraint::from(location_constraint_args); match &location_constraint { Constraint::Any => (), Constraint::Only(constraint) => { @@ -465,11 +468,11 @@ impl Relay { } } } - location_constraint + location_constraint.map(|location| LocationConstraint::Location { location }) }; Self::update_constraints(RelaySettingsUpdate::Normal(RelayConstraintsUpdate { - location: Some(constraint.map(|location| LocationConstraint::Location { location })), + location: Some(constraint), ..Default::default() })) .await @@ -571,17 +574,18 @@ impl Relay { Some(EntryLocation::EntryLocation(entry)) => { let countries = Self::get_filtered_relays().await?; // The country field is assumed to be hostname due to CLI argument parsing - wireguard_constraints.entry_location = - if let Some(relay) = find_relay_by_hostname(&countries, &entry.country) { - Constraint::Only(relay) - } else { - Constraint::from(entry) - }; - }, + wireguard_constraints.entry_location = + if let Some(relay) = find_relay_by_hostname(&countries, &entry.country) { + Constraint::Only(LocationConstraint::Location { location: relay }) + } else { + Constraint::from(entry) + }; + } Some(EntryLocation::CustomList { custom_list_name }) => { let list = rpc.get_custom_list(custom_list_name).await?; - wireguard_constraints.entry_location = Constraint::Only(LocationConstraint::CustomList { list_id: list.id }); - }, + wireguard_constraints.entry_location = + Constraint::Only(LocationConstraint::CustomList { list_id: list.id }); + } None => (), } @@ -644,7 +648,7 @@ fn parse_transport_port( pub fn find_relay_by_hostname( countries: &[RelayListCountry], hostname: &str, -) -> Option<LocationConstraint> { +) -> Option<GeographicLocationConstraint> { countries .iter() .flat_map(|country| country.cities.clone()) @@ -657,7 +661,7 @@ pub fn find_relay_by_hostname( city_code, .. }| { - LocationConstraint::Normal { location: GeographicLocationConstraint::Hostname(country_code, city_code, relay.hostname) } + GeographicLocationConstraint::Hostname(country_code, city_code, relay.hostname) }, ) }) diff --git a/mullvad-cli/src/cmds/tunnel_state.rs b/mullvad-cli/src/cmds/tunnel_state.rs index 6e115c5ba1..76393091c5 100644 --- a/mullvad-cli/src/cmds/tunnel_state.rs +++ b/mullvad-cli/src/cmds/tunnel_state.rs @@ -2,8 +2,7 @@ use crate::format; use anyhow::{anyhow, Result}; use futures::{Stream, StreamExt}; use mullvad_management_interface::{client::DaemonEvent, MullvadProxyClient}; -use mullvad_types::device::DeviceState; -use mullvad_types::states::TunnelState; +use mullvad_types::{device::DeviceState, states::TunnelState}; pub async fn connect(wait: bool) -> Result<()> { let mut rpc = MullvadProxyClient::new().await?; diff --git a/mullvad-daemon/src/custom_lists.rs b/mullvad-daemon/src/custom_lists.rs index f7bdc1b3ba..ce75e19bc5 100644 --- a/mullvad-daemon/src/custom_lists.rs +++ b/mullvad-daemon/src/custom_lists.rs @@ -43,7 +43,15 @@ where let settings_changed = self .settings .update(|settings| { - settings.custom_lists.custom_lists.remove(&id); + let index = settings + .custom_lists + .custom_lists + .iter() + .position(|custom_list| custom_list.id == id) + .unwrap(); + // NOTE: Not using swap remove because it would make user output slightly + // more confusing and the cost is so small. + settings.custom_lists.custom_lists.remove(index); }) .await .map_err(Error::Settings); @@ -83,11 +91,7 @@ where .settings .update(|settings| { let custom_list = CustomList::new(name); - assert!(settings - .custom_lists - .custom_lists - .insert(custom_list.id.clone(), custom_list) - .is_none()); + settings.custom_lists.custom_lists.push(custom_list); }) .await .map_err(Error::Settings); @@ -127,7 +131,8 @@ where let locations = &mut settings .custom_lists .custom_lists - .get_mut(&id) + .iter_mut() + .find(|custom_list| custom_list.id == id) .unwrap() .locations; @@ -183,7 +188,8 @@ where let locations = &mut settings .custom_lists .custom_lists - .get_mut(&id) + .iter_mut() + .find(|custom_list| custom_list.id == id) .unwrap() .locations; if let Some(index) = locations @@ -247,7 +253,8 @@ where settings .custom_lists .custom_lists - .get_mut(&id) + .iter_mut() + .find(|custom_list| custom_list.id == id) .unwrap() .name = new_name; }) @@ -280,7 +287,15 @@ where need_to_reconnect |= list_id == custom_list_id; } - if let TunnelState::Connecting { endpoint, location: _ } | TunnelState::Connected { endpoint, location: _ } = &self.tunnel_state { + if let TunnelState::Connecting { + endpoint, + location: _, + } + | TunnelState::Connected { + endpoint, + location: _, + } = &self.tunnel_state + { match endpoint.tunnel_type { TunnelType::Wireguard => { if relay_settings.wireguard_constraints.use_multihop { diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index a0e308abe6..46ee8b25df 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -8,6 +8,7 @@ pub mod account_history; mod api; #[cfg(not(target_os = "android"))] mod cleanup; +mod custom_lists; pub mod device; mod dns; pub mod exception_logging; @@ -27,7 +28,6 @@ mod target_state; mod tunnel; pub mod version; mod version_check; -mod custom_lists; use crate::target_state::PersistentTargetState; use device::{AccountEvent, PrivateAccountAndDevice, PrivateDeviceEvent}; @@ -1033,12 +1033,16 @@ where GetSettings(tx) => self.on_get_settings(tx), RotateWireguardKey(tx) => self.on_rotate_wireguard_key(tx).await, GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await, - ListCustomLists(tx) => self.on_list_custom_lists(tx).await, - GetCustomList(tx, name) => self.on_get_custom_list(tx, name).await, + ListCustomLists(tx) => self.on_list_custom_lists(tx), + GetCustomList(tx, name) => self.on_get_custom_list(tx, name), CreateCustomList(tx, name) => self.on_create_custom_list(tx, name).await, DeleteCustomList(tx, name) => self.on_delete_custom_list(tx, name).await, - UpdateCustomListLocation(tx, update) => self.on_update_custom_list_location(tx, update).await, - RenameCustomList(tx, name, new_name) => self.on_rename_custom_list(tx, name, new_name).await, + UpdateCustomListLocation(tx, update) => { + self.on_update_custom_list_location(tx, update).await + } + RenameCustomList(tx, name, new_name) => { + self.on_rename_custom_list(tx, name, new_name).await + } GetVersionInfo(tx) => self.on_get_version_info(tx).await, IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx), GetCurrentVersion(tx) => self.on_get_current_version(tx), @@ -2236,23 +2240,34 @@ where Self::oneshot_send(tx, result, "get_wireguard_key response"); } - async fn on_list_custom_lists(&mut self, tx: ResponseTx<Vec<CustomList>, Error>) { - let result = self.settings.custom_lists.custom_lists.values().cloned().collect(); + fn on_list_custom_lists(&mut self, tx: ResponseTx<Vec<CustomList>, Error>) { + let result = self.settings.custom_lists.custom_lists.clone(); Self::oneshot_send(tx, Ok(result), "list_custom_lists response"); } - async fn on_get_custom_list(&mut self, tx: ResponseTx<CustomList, Error>, name: String) { - let result = self.settings.custom_lists.get_custom_list_with_name(&name).cloned().ok_or(Error::CustomListError(custom_lists::Error::ListNotFound)); + fn on_get_custom_list(&mut self, tx: ResponseTx<CustomList, Error>, name: String) { + let result = self + .settings + .custom_lists + .get_custom_list_with_name(&name) + .cloned() + .ok_or(Error::CustomListError(custom_lists::Error::ListNotFound)); Self::oneshot_send(tx, result, "create_custom_list response"); } async fn on_create_custom_list(&mut self, tx: ResponseTx<(), Error>, name: String) { - let result = self.create_custom_list(name).await.map_err(Error::CustomListError); + let result = self + .create_custom_list(name) + .await + .map_err(Error::CustomListError); Self::oneshot_send(tx, result, "create_custom_list response"); } async fn on_delete_custom_list(&mut self, tx: ResponseTx<(), Error>, name: String) { - let result = self.delete_custom_list(name).await.map_err(Error::CustomListError); + let result = self + .delete_custom_list(name) + .await + .map_err(Error::CustomListError); Self::oneshot_send(tx, result, "delete_custom_list response"); } @@ -2261,7 +2276,10 @@ where tx: ResponseTx<(), Error>, update: CustomListLocationUpdate, ) { - let result = self.update_custom_list_location(update).await.map_err(Error::CustomListError); + let result = self + .update_custom_list_location(update) + .await + .map_err(Error::CustomListError); Self::oneshot_send(tx, result, "update_custom_list_location response"); } @@ -2271,7 +2289,10 @@ where name: String, new_name: String, ) { - let result = self.rename_custom_list(name, new_name).await.map_err(Error::CustomListError); + let result = self + .rename_custom_list(name, new_name) + .await + .map_err(Error::CustomListError); Self::oneshot_send(tx, result, "rename_custom_list response"); } @@ -2442,5 +2463,3 @@ fn new_selector_config( custom_lists: settings.custom_lists.clone(), } } - - diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 84e90b1726..aa46c91584 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -1,4 +1,7 @@ -use crate::{account_history, device, settings, DaemonCommand, DaemonCommandSender, EventListener, custom_lists}; +use crate::{ + account_history, custom_lists, device, settings, DaemonCommand, DaemonCommandSender, + EventListener, +}; use futures::{ channel::{mpsc, oneshot}, StreamExt, @@ -1106,24 +1109,30 @@ fn map_account_history_error(error: account_history::Error) -> Status { /// Converts an instance of [`mullvad_daemon::account_history::Error`] into a tonic status. fn map_custom_list_error(error: custom_lists::Error) -> Status { match error { - custom_lists::Error::ListExists => { - Status::with_details(Code::AlreadyExists, error.to_string(), mullvad_management_interface::CUSTOM_LIST_LIST_EXISTS_DETAILS.into()) - } - custom_lists::Error::ListNotFound => { - Status::with_details(Code::NotFound, error.to_string(), mullvad_management_interface::CUSTOM_LIST_LIST_NOT_FOUND_DETAILS.into()) - } + custom_lists::Error::ListExists => Status::with_details( + Code::AlreadyExists, + error.to_string(), + mullvad_management_interface::CUSTOM_LIST_LIST_EXISTS_DETAILS.into(), + ), + custom_lists::Error::ListNotFound => Status::with_details( + Code::NotFound, + error.to_string(), + mullvad_management_interface::CUSTOM_LIST_LIST_NOT_FOUND_DETAILS.into(), + ), custom_lists::Error::CannotAddOrRemoveAny => { Status::new(Code::InvalidArgument, error.to_string()) } - custom_lists::Error::LocationExists => { - Status::with_details(Code::AlreadyExists, error.to_string(), mullvad_management_interface::CUSTOM_LIST_LOCATION_EXISTS_DETAILS.into()) - } - custom_lists::Error::LocationNotFoundInlist => { - Status::with_details(Code::NotFound, error.to_string(), mullvad_management_interface::CUSTOM_LIST_LOCATION_NOT_FOUND_DETAILS.into()) - } - custom_lists::Error::Settings(error) => { - map_settings_error(error) - } + custom_lists::Error::LocationExists => Status::with_details( + Code::AlreadyExists, + error.to_string(), + mullvad_management_interface::CUSTOM_LIST_LOCATION_EXISTS_DETAILS.into(), + ), + custom_lists::Error::LocationNotFoundInlist => Status::with_details( + Code::NotFound, + error.to_string(), + mullvad_management_interface::CUSTOM_LIST_LOCATION_NOT_FOUND_DETAILS.into(), + ), + custom_lists::Error::Settings(error) => map_settings_error(error), } } diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index 5e90067166..ee3f355598 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -377,7 +377,11 @@ mod test { "normal": { "location": { "only": { - "country": "gb" + "location": { + "location": { + "country": "gb" + } + } } }, "tunnel_protocol": { @@ -414,7 +418,10 @@ mod test { } }, "settings_version": 5, - "show_beta_releases": false + "show_beta_releases": false, + "custom_lists": { + "custom_lists": [] + } }"#; let _ = SettingsPersister::load_from_bytes(settings).unwrap(); diff --git a/mullvad-jni/src/classes.rs b/mullvad-jni/src/classes.rs index e7d34f7966..88b5f6d938 100644 --- a/mullvad-jni/src/classes.rs +++ b/mullvad-jni/src/classes.rs @@ -20,14 +20,18 @@ pub const CLASSES: &[&str] = &[ "net/mullvad/mullvadvpn/model/DeviceState$LoggedOut", "net/mullvad/mullvadvpn/model/DeviceState$Revoked", "net/mullvad/mullvadvpn/model/RemoveDeviceEvent", + "net/mullvad/mullvadvpn/model/GeographicLocationConstraint", + "net/mullvad/mullvadvpn/model/GeographicLocationConstraint$City", + "net/mullvad/mullvadvpn/model/GeographicLocationConstraint$Country", + "net/mullvad/mullvadvpn/model/GeographicLocationConstraint$Hostname", "net/mullvad/mullvadvpn/model/GeoIpLocation", "net/mullvad/mullvadvpn/model/GetAccountDataResult$Ok", "net/mullvad/mullvadvpn/model/GetAccountDataResult$InvalidAccount", "net/mullvad/mullvadvpn/model/GetAccountDataResult$RpcError", "net/mullvad/mullvadvpn/model/GetAccountDataResult$OtherError", - "net/mullvad/mullvadvpn/model/LocationConstraint$City", - "net/mullvad/mullvadvpn/model/LocationConstraint$Country", - "net/mullvad/mullvadvpn/model/LocationConstraint$Hostname", + "net/mullvad/mullvadvpn/model/LocationConstraint", + "net/mullvad/mullvadvpn/model/LocationConstraint$Location", + "net/mullvad/mullvadvpn/model/LocationConstraint$CustomList", "net/mullvad/mullvadvpn/model/ObfuscationSettings", "net/mullvad/mullvadvpn/model/PublicKey", "net/mullvad/mullvadvpn/model/QuantumResistantState", diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 709079bee4..f82f55078a 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -294,10 +294,6 @@ message LocationConstraint { } } -message CustomLocationConstraints { - repeated RelayLocation locations = 1; -} - message RelayLocation { string country = 1; string city = 2; @@ -333,7 +329,7 @@ message CustomListRename { message CustomListLocationUpdate { enum State { ADD = 0; - REMOVE= 1; + REMOVE = 1; } State state = 1; string name = 2; @@ -346,13 +342,9 @@ message CustomList { repeated RelayLocation locations = 3; } -message CustomLists { - repeated CustomList custom_lists = 1; -} +message CustomLists { repeated CustomList custom_lists = 1; } -message CustomListSettings { - map<string, CustomList> custom_lists = 1; -} +message CustomListSettings { repeated CustomList custom_lists = 1; } message Settings { RelaySettings relay_settings = 1; diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 7674479690..b8c0b32085 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -455,12 +455,18 @@ impl MullvadProxyClient { } pub async fn create_custom_list(&mut self, name: String) -> Result<()> { - self.0.create_custom_list(name).await.map_err(map_custom_list_error)?; + self.0 + .create_custom_list(name) + .await + .map_err(map_custom_list_error)?; Ok(()) } pub async fn delete_custom_list(&mut self, name: String) -> Result<()> { - self.0.delete_custom_list(name).await.map_err(map_custom_list_error)?; + self.0 + .delete_custom_list(name) + .await + .map_err(map_custom_list_error)?; Ok(()) } @@ -616,7 +622,7 @@ fn map_custom_list_error(status: Status) -> Error { } else { Error::Rpc(status) } - }, + } Code::AlreadyExists => { let details = status.details(); if details == crate::CUSTOM_LIST_LOCATION_EXISTS_DETAILS { @@ -626,7 +632,7 @@ fn map_custom_list_error(status: Status) -> Error { } else { Error::Rpc(status) } - }, + } Code::InvalidArgument => Error::CustomListCannotAddOrRemoveAny, _other => Error::Rpc(status), } diff --git a/mullvad-management-interface/src/types/conversions/custom_list.rs b/mullvad-management-interface/src/types/conversions/custom_list.rs index acfa8831ee..61e4ce9478 100644 --- a/mullvad-management-interface/src/types/conversions/custom_list.rs +++ b/mullvad-management-interface/src/types/conversions/custom_list.rs @@ -23,9 +23,7 @@ impl From<&mullvad_types::custom_list::CustomListsSettings> for proto::CustomLis custom_lists: settings .custom_lists .iter() - .map(|(id, custom_list)| { - (id.0.to_string(), proto::CustomList::from(custom_list.clone())) - }) + .map(|custom_list| proto::CustomList::from(custom_list.clone())) .collect(), } } @@ -39,13 +37,8 @@ impl TryFrom<proto::CustomListSettings> for mullvad_types::custom_list::CustomLi custom_lists: settings .custom_lists .into_iter() - .map(|(id, custom_list)| { - Ok(( - Id::try_from(id.as_str()).map_err(|_| FromProtobufTypeError::InvalidArgument("Id could not be parsed to a uuid"))?, - mullvad_types::custom_list::CustomList::try_from(custom_list)?, - )) - }) - .collect::<Result<std::collections::HashMap<_, _>, _>>()?, + .map(mullvad_types::custom_list::CustomList::try_from) + .collect::<Result<Vec<_>, _>>()?, }) } } @@ -117,7 +110,7 @@ impl From<mullvad_types::custom_list::CustomList> for proto::CustomList { .map(proto::RelayLocation::from) .collect(); Self { - id: custom_list.id.0.to_string(), + id: custom_list.id.to_string(), name: custom_list.name, locations, } @@ -137,7 +130,9 @@ impl TryFrom<proto::CustomList> for mullvad_types::custom_list::CustomList { FromProtobufTypeError::InvalidArgument("Could not convert custom list from proto") })?; Ok(Self { - id: Id::try_from(custom_list.id.as_str()).map_err(|_| FromProtobufTypeError::InvalidArgument("Id could not be parsed to a uuid"))?, + id: Id::try_from(custom_list.id.as_str()).map_err(|_| { + FromProtobufTypeError::InvalidArgument("Id could not be parsed to a uuid") + })?, name: custom_list.name, locations, }) @@ -153,7 +148,7 @@ impl TryFrom<proto::RelayLocation> for GeographicLocationConstraint { relay_location.city.as_ref(), relay_location.hostname.as_ref(), ) { - ("", _, _) => Err(FromProtobufTypeError::InvalidArgument( + ("", ..) => Err(FromProtobufTypeError::InvalidArgument( "Relay location formatted incorrectly", )), (_country, "", "") => Ok(GeographicLocationConstraint::Country( diff --git a/mullvad-management-interface/src/types/conversions/relay_constraints.rs b/mullvad-management-interface/src/types/conversions/relay_constraints.rs index 8c1f89d368..221c2140b2 100644 --- a/mullvad-management-interface/src/types/conversions/relay_constraints.rs +++ b/mullvad-management-interface/src/types/conversions/relay_constraints.rs @@ -40,8 +40,12 @@ impl TryFrom<&proto::WireguardConstraints> entry_location: constraints .entry_location .clone() - .map(|loc| Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from(loc).ok()) - .flatten() + .and_then(|loc| { + Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from( + loc, + ) + .ok() + }) .unwrap_or(Constraint::Any), }) } @@ -98,8 +102,7 @@ impl TryFrom<proto::RelaySettings> for mullvad_types::relay_constraints::RelaySe proto::relay_settings::Endpoint::Normal(settings) => { let location = settings .location - .map(|loc| Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from(loc).ok()) - .flatten() + .and_then(|loc| Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from(loc).ok()) .unwrap_or(Constraint::Any); let providers = try_providers_constraint_from_proto(&settings.providers)?; let ownership = try_ownership_constraint_from_i32(settings.ownership)?; @@ -251,10 +254,12 @@ impl TryFrom<proto::RelaySettingsUpdate> for mullvad_types::relay_constraints::R // If `location` isn't provided, no changes are made. // If `location` is provided, but is an empty vector, // then the constraint is set to `Constraint::Any`. - let location = settings - .location - .map(|loc| Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from(loc).ok()) - .flatten(); + let location = settings.location.and_then(|loc| { + Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from( + loc, + ) + .ok() + }); let providers = if let Some(ref provider_update) = settings.providers { Some(try_providers_constraint_from_proto( &provider_update.providers, @@ -489,9 +494,7 @@ impl From<mullvad_types::relay_constraints::LocationConstraint> for proto::Locat )), }, LocationConstraint::CustomList { list_id } => Self { - r#type: Some(proto::location_constraint::Type::CustomList( - list_id.0.to_string(), - )), + r#type: Some(proto::location_constraint::Type::CustomList(list_id)), }, } } @@ -589,9 +592,9 @@ impl TryFrom<proto::BridgeSettings> for mullvad_types::relay_constraints::Bridge proto::bridge_settings::Type::Normal(constraints) => { let location = match constraints.location { None => Constraint::Any, - Some(location) => { - Constraint::<mullvad_types::relay_constraints::LocationConstraint>::try_from(location)? - } + Some(location) => Constraint::< + mullvad_types::relay_constraints::LocationConstraint, + >::try_from(location)?, }; let providers = try_providers_constraint_from_proto(&constraints.providers)?; let ownership = try_ownership_constraint_from_i32(constraints.ownership)?; diff --git a/mullvad-relay-selector/src/lib.rs b/mullvad-relay-selector/src/lib.rs index 7798c1dc88..9fd6e3c6da 100644 --- a/mullvad-relay-selector/src/lib.rs +++ b/mullvad-relay-selector/src/lib.rs @@ -489,6 +489,8 @@ impl RelaySelector { ), }; + // Nightly clippy seems wrong about this being a redundant clone + #[allow(clippy::redundant_clone)] let mut preferred_matcher: RelayMatcher<WireguardMatcher> = relay_matcher.clone(); preferred_matcher.endpoint_matcher.port = preferred_matcher .endpoint_matcher @@ -1296,8 +1298,10 @@ impl NormalSelectedRelay { mod test { use super::*; use mullvad_types::{ + custom_list::CustomListsSettings, relay_constraints::{ - BridgeConstraints, RelayConstraints, RelayConstraintsUpdate, RelaySettingsUpdate, + BridgeConstraints, GeographicLocationConstraint, RelayConstraints, + RelayConstraintsUpdate, RelaySettingsUpdate, WireguardConstraints, }, relay_list::{ OpenVpnEndpoint, OpenVpnEndpointData, Relay, RelayListCity, RelayListCountry, @@ -1453,7 +1457,9 @@ mod test { ))), config: Arc::new(Mutex::new(SelectorConfig { relay_settings: RelaySettings::Normal(RelayConstraints { - location: Constraint::Only(Location::Country("se".to_owned())), + location: Constraint::Only(LocationConstraint::from( + GeographicLocationConstraint::Country("se".to_owned()), + )), ..Default::default() }), bridge_settings: BridgeSettings::Normal(BridgeConstraints::default()), @@ -1463,6 +1469,7 @@ mod test { }, bridge_state: BridgeState::Auto, default_tunnel_type: default_tunnel_type(), + custom_lists: CustomListsSettings::default(), })), } } @@ -1476,13 +1483,13 @@ mod test { let relay_selector = new_relay_selector(); // Prefer WG if the location only supports it - let location = Location::Hostname( + let location = GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se9-wireguard".to_string(), ); let relay_constraints = RelayConstraints { - location: Constraint::Only(location), + location: Constraint::Only(LocationConstraint::from(location)), tunnel_protocol: Constraint::Any, ..RelayConstraints::default() }; @@ -1492,6 +1499,7 @@ mod test { BridgeState::Off, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1505,18 +1513,19 @@ mod test { BridgeState::Off, attempt, TunnelType::Wireguard, + &CustomListsSettings::default() ) .is_ok()); } // Prefer OpenVPN if the location only supports it - let location = Location::Hostname( + let location = GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se-got-001".to_string(), ); let relay_constraints = RelayConstraints { - location: Constraint::Only(location), + location: Constraint::Only(LocationConstraint::from(location)), tunnel_protocol: Constraint::Any, ..RelayConstraints::default() }; @@ -1526,6 +1535,7 @@ mod test { BridgeState::Off, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1539,6 +1549,7 @@ mod test { BridgeState::Off, attempt, TunnelType::Wireguard, + &CustomListsSettings::default() ) .is_ok()); } @@ -1553,6 +1564,7 @@ mod test { BridgeState::Off, attempt, TunnelType::OpenVpn, + &CustomListsSettings::default() ); assert_eq!( preferred.tunnel_protocol, @@ -1563,6 +1575,7 @@ mod test { BridgeState::Off, attempt, TunnelType::OpenVpn, + &CustomListsSettings::default() ) { Ok(result) if matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)) => (), _ => panic!("OpenVPN endpoint was not selected"), @@ -1575,25 +1588,26 @@ mod test { fn test_wg_entry_hostname_collision() { let relay_selector = new_relay_selector(); - let location1 = Location::Hostname( + let location1 = GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se9-wireguard".to_string(), ); - let location2 = Location::Hostname( + let location2 = GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se10-wireguard".to_string(), ); let mut relay_constraints = RelayConstraints { - location: Constraint::Only(location1.clone()), + location: Constraint::Only(LocationConstraint::from(location1.clone())), tunnel_protocol: Constraint::Only(TunnelType::Wireguard), ..RelayConstraints::default() }; relay_constraints.wireguard_constraints.use_multihop = true; - relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location1); + relay_constraints.wireguard_constraints.entry_location = + Constraint::Only(LocationConstraint::from(location1)); // The same host cannot be used for entry and exit assert!(relay_selector @@ -1602,10 +1616,12 @@ mod test { BridgeState::Off, 0, TunnelType::Wireguard, + &CustomListsSettings::default() ) .is_err()); - relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location2); + relay_constraints.wireguard_constraints.entry_location = + Constraint::Only(LocationConstraint::from(location2)); // If the entry and exit differ, this should succeed assert!(relay_selector @@ -1614,6 +1630,7 @@ mod test { BridgeState::Off, 0, TunnelType::Wireguard, + &CustomListsSettings::default() ) .is_ok()); } @@ -1624,12 +1641,15 @@ mod test { let specific_hostname = "se10-wireguard"; - let location_general = Location::City("se".to_string(), "got".to_string()); - let location_specific = Location::Hostname( + let location_general = LocationConstraint::from(GeographicLocationConstraint::City( + "se".to_string(), + "got".to_string(), + )); + let location_specific = LocationConstraint::from(GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), specific_hostname.to_string(), - ); + )); let mut relay_constraints = RelayConstraints { location: Constraint::Only(location_general.clone()), @@ -1643,7 +1663,13 @@ mod test { // The exit must not equal the entry let exit_relay = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, TunnelType::OpenVpn) + .get_tunnel_endpoint( + &relay_constraints, + BridgeState::Off, + 0, + TunnelType::OpenVpn, + &CustomListsSettings::default(), + ) .map_err(|error| error.to_string())? .exit_relay; @@ -1663,6 +1689,7 @@ mod test { BridgeState::Off, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ) .map_err(|error| error.to_string())?; @@ -1779,6 +1806,7 @@ mod test { BridgeState::Auto, retry_attempt, default_tunnel_type(), + &CustomListsSettings::default(), ); println!("relay: {relay:?}, constraints: {relay_constraints:?}"); @@ -1807,11 +1835,11 @@ mod test { fn test_bridge_constraints() -> Result<(), String> { let relay_selector = new_relay_selector(); - let location = Location::Hostname( + let location = LocationConstraint::from(GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se-got-001".to_string(), - ); + )); let mut relay_constraints = RelayConstraints { location: Constraint::Only(location), tunnel_protocol: Constraint::Any, @@ -1827,6 +1855,7 @@ mod test { BridgeState::On, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1842,11 +1871,11 @@ mod test { ); // Ignore bridge state where WireGuard is used - let location = Location::Hostname( + let location = LocationConstraint::from(GeographicLocationConstraint::Hostname( "se".to_string(), "got".to_string(), "se10-wireguard".to_string(), - ); + )); let relay_constraints = RelayConstraints { location: Constraint::Only(location), tunnel_protocol: Constraint::Any, @@ -1857,6 +1886,7 @@ mod test { BridgeState::On, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1880,6 +1910,7 @@ mod test { BridgeState::On, 0, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1891,6 +1922,7 @@ mod test { BridgeState::On, 2, TunnelType::Wireguard, + &CustomListsSettings::default(), ); assert_eq!( preferred.tunnel_protocol, @@ -1922,7 +1954,7 @@ mod test { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, default_tunnel_type()) + let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, default_tunnel_type(), &CustomListsSettings::default()) .expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection"); // Windows will ignore WireGuard until WireGuard is supported well enough // TODO: Remove this caveat once Windows defaults to using WireGuard @@ -1974,7 +2006,7 @@ mod test { fn test_selecting_wireguard_location_will_consider_multihop() { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type()) + let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type(), &CustomListsSettings::default()) .expect("Failed to get relay when tunnel constraints are set to default WireGuard multihop constraints"); assert!(result.entry_relay.is_some()); @@ -1985,7 +2017,7 @@ mod test { fn test_selecting_wg_endpoint_with_udp2tcp_obfuscation() { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_SINGLEHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type()) + let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_SINGLEHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type(), &CustomListsSettings::default()) .expect("Failed to get relay when tunnel constraints are set to default WireGuard constraints"); assert!(result.entry_relay.is_none()); @@ -2014,7 +2046,7 @@ mod test { fn test_selecting_wg_endpoint_with_auto_obfuscation() { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_SINGLEHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type()) + let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_SINGLEHOP_CONSTRAINTS, BridgeState::Off, 0, default_tunnel_type(), &CustomListsSettings::default()) .expect("Failed to get relay when tunnel constraints are set to default WireGuard constraints"); assert!(result.entry_relay.is_none()); @@ -2059,6 +2091,7 @@ mod test { BridgeState::Off, attempt, TunnelType::Wireguard, + &CustomListsSettings::default(), ) .expect("Failed to select a WireGuard relay"); assert!(result.entry_relay.is_none()); @@ -2095,7 +2128,13 @@ mod test { for i in 0..10 { constraints.ownership = Constraint::Only(Ownership::MullvadOwned); let relay = relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Auto, i, TunnelType::Wireguard) + .get_tunnel_endpoint( + &constraints, + BridgeState::Auto, + i, + TunnelType::Wireguard, + &CustomListsSettings::default(), + ) .unwrap(); assert!(matches!( relay, @@ -2107,7 +2146,13 @@ mod test { constraints.ownership = Constraint::Only(Ownership::Rented); let relay = relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Auto, i, TunnelType::Wireguard) + .get_tunnel_endpoint( + &constraints, + BridgeState::Auto, + i, + TunnelType::Wireguard, + &CustomListsSettings::default(), + ) .unwrap(); assert!(matches!( relay, @@ -2134,7 +2179,9 @@ mod test { config.relay_settings = config.relay_settings.merge(RelaySettingsUpdate::Normal( RelayConstraintsUpdate { tunnel_protocol: Some(tunnel_protocol), - location: Some(Constraint::Only(Location::Country("se".to_string()))), + location: Some(Constraint::Only(LocationConstraint::from( + GeographicLocationConstraint::Country("se".to_string()), + ))), ..Default::default() }, )); @@ -2178,7 +2225,13 @@ mod test { Providers::new(EXPECTED_PROVIDERS.into_iter().map(|p| p.to_owned())).unwrap(), ); let relay = relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Auto, i, TunnelType::Wireguard) + .get_tunnel_endpoint( + &constraints, + BridgeState::Auto, + i, + TunnelType::Wireguard, + &CustomListsSettings::default(), + ) .unwrap(); assert!( EXPECTED_PROVIDERS.contains(&relay.exit_relay.provider.as_str()), diff --git a/mullvad-types/src/custom_list.rs b/mullvad-types/src/custom_list.rs index f9fa2e84f3..ee4f914f75 100644 --- a/mullvad-types/src/custom_list.rs +++ b/mullvad-types/src/custom_list.rs @@ -1,34 +1,21 @@ use crate::relay_constraints::{Constraint, GeographicLocationConstraint}; +#[cfg(target_os = "android")] +use jnix::{FromJava, IntoJava}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::str::FromStr; -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub struct Id(pub uuid::Uuid); - -impl TryFrom<&str> for Id { - type Error = (); - fn try_from(string: &str) -> Result<Self, Self::Error> { - let uuid = uuid::Uuid::from_str(string).map_err(|_| ())?; - Ok(Id(uuid)) - } -} - -impl std::fmt::Display for Id { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fmt.write_str(&self.0.to_string()) - } -} +pub type Id = String; #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +#[cfg_attr(target_os = "android", derive(FromJava, IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] pub struct CustomListsSettings { - pub custom_lists: HashMap<Id, CustomList>, + pub custom_lists: Vec<CustomList>, } impl CustomListsSettings { pub fn get_custom_list_with_name(&self, name: &String) -> Option<&CustomList> { self.custom_lists - .values() + .iter() .find(|custom_list| &custom_list.name == name) } } @@ -46,6 +33,8 @@ pub enum CustomListLocationUpdate { } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[cfg_attr(target_os = "android", derive(FromJava, IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] pub struct CustomList { pub id: Id, pub name: String, @@ -55,7 +44,7 @@ pub struct CustomList { impl CustomList { pub fn new(name: String) -> Self { CustomList { - id: Id(uuid::Uuid::new_v4()), + id: uuid::Uuid::new_v4().to_string(), name, locations: Vec::new(), } diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index d518330cea..18ab4eef03 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -10,7 +10,7 @@ use crate::{ #[cfg(target_os = "android")] use jnix::{jni::objects::JObject, FromJava, IntoJava, JnixEnv}; use serde::{Deserialize, Serialize}; -use std::{collections::HashSet, fmt, str::FromStr, fmt::Write}; +use std::{collections::HashSet, fmt, fmt::Write, str::FromStr}; use talpid_types::net::{openvpn::ProxySettings, IpVersion, TransportProtocol, TunnelType}; pub trait Match<T> { @@ -205,7 +205,11 @@ pub enum RelaySettings { } impl RelaySettings { - pub fn format(&self, s: &mut String, custom_lists: &CustomListsSettings) -> Result<(), fmt::Error> { + pub fn format( + &self, + s: &mut String, + custom_lists: &CustomListsSettings, + ) -> Result<(), fmt::Error> { match self { RelaySettings::CustomTunnelEndpoint(endpoint) => { write!(s, "custom endpoint {endpoint}") @@ -232,6 +236,9 @@ impl RelaySettings { } #[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(target_os = "android", derive(IntoJava, FromJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] pub enum LocationConstraint { Location { location: GeographicLocationConstraint, @@ -263,7 +270,8 @@ impl ResolvedLocationConstraint { } Constraint::Only(LocationConstraint::CustomList { list_id }) => custom_lists .custom_lists - .get(&list_id) + .iter() + .find(|custom_list| custom_list.id == list_id) .map(|custom_list| { Constraint::Only(Self::Locations { locations: custom_list.locations.clone(), @@ -338,20 +346,22 @@ impl LocationConstraint { fn format(&self, f: &mut String, custom_lists: &CustomListsSettings) -> Result<(), fmt::Error> { match self { Self::Location { location } => writeln!(f, "location - {location}"), - Self::CustomList { list_id } => { - match custom_lists.custom_lists.get(list_id) { - Some(list) => { - writeln!(f, "custom list - {}", list.name)?; - for location in &list.locations { - writeln!(f, "\t{}", location)?; - } - Ok(()) - }, - None => { - writeln!(f, "custom list - list not found") + Self::CustomList { list_id } => match custom_lists + .custom_lists + .iter() + .find(|custom_list| &custom_list.id == list_id) + { + Some(list) => { + writeln!(f, "custom list - {}", list.name)?; + for location in &list.locations { + writeln!(f, "\t{}", location)?; } + Ok(()) } - } + None => { + writeln!(f, "custom list - list not found") + } + }, } } } @@ -407,16 +417,20 @@ impl RelayConstraints { } impl RelayConstraints { - pub fn format(&self, f: &mut String, custom_lists: &CustomListsSettings) -> Result<(), fmt::Error> { + pub fn format( + &self, + f: &mut String, + custom_lists: &CustomListsSettings, + ) -> Result<(), fmt::Error> { match self.tunnel_protocol { Constraint::Any => { writeln!( f, - "Tunnel protocol: Any\nOpenVPN: {}\nWireguard: ", - &self.openvpn_constraints, + "Tunnel protocol: Any\nOpenVPN constraints: {}\nWireguard constraints: ", + &self.openvpn_constraints, )?; self.wireguard_constraints.format(f, custom_lists)?; - }, + } Constraint::Only(ref tunnel_protocol) => { writeln!(f, "Tunnel protocol: {}", tunnel_protocol)?; match tunnel_protocol { @@ -435,7 +449,7 @@ impl RelayConstraints { Constraint::Only(ref location_constraint) => { write!(f, "Location: ")?; location_constraint.format(f, custom_lists)?; - }, + } } match self.providers { Constraint::Any => writeln!(f, "Provider: Any")?, @@ -767,7 +781,7 @@ impl WireguardConstraints { Constraint::Only(location) => { write!(f, "Wireguard entry ")?; location.format(f, custom_lists) - }, + } } } else { Ok(()) @@ -886,10 +900,16 @@ pub struct BridgeConstraints { } impl BridgeConstraints { - pub fn format(&self, f: &mut String, custom_lists: &CustomListsSettings) -> Result<(), fmt::Error> { + pub fn format( + &self, + f: &mut String, + custom_lists: &CustomListsSettings, + ) -> Result<(), fmt::Error> { match self.location { Constraint::Any => write!(f, "any location")?, - Constraint::Only(ref location_constraint) => location_constraint.format(f, custom_lists)?, + Constraint::Only(ref location_constraint) => { + location_constraint.format(f, custom_lists)? + } } write!(f, " using ")?; match self.providers { diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs index e54ea86e7d..5af6ecfe08 100644 --- a/mullvad-types/src/settings/mod.rs +++ b/mullvad-types/src/settings/mod.rs @@ -73,6 +73,7 @@ pub struct Settings { #[cfg_attr(target_os = "android", jnix(skip))] pub bridge_state: BridgeState, /// All of the custom relay lists + #[cfg_attr(target_os = "android", jnix(skip))] pub custom_lists: CustomListsSettings, /// If the daemon should allow communication with private (LAN) networks. pub allow_lan: bool, @@ -165,7 +166,9 @@ impl Settings { } let mut old_settings_string = String::new(); - let _ = self.relay_settings.format(&mut old_settings_string, &self.custom_lists); + let _ = self + .relay_settings + .format(&mut old_settings_string, &self.custom_lists); let mut new_settings_string = String::new(); let _ = new_settings.format(&mut new_settings_string, &self.custom_lists); @@ -213,5 +216,3 @@ impl Default for TunnelOptions { } } } - - |
