summaryrefslogtreecommitdiffhomepage
path: root/mullvad-daemon/src/settings
diff options
context:
space:
mode:
Diffstat (limited to 'mullvad-daemon/src/settings')
-rw-r--r--mullvad-daemon/src/settings/mod.rs461
-rw-r--r--mullvad-daemon/src/settings/patch.rs481
2 files changed, 942 insertions, 0 deletions
diff --git a/mullvad-daemon/src/settings/mod.rs b/mullvad-daemon/src/settings/mod.rs
new file mode 100644
index 0000000000..99f637c023
--- /dev/null
+++ b/mullvad-daemon/src/settings/mod.rs
@@ -0,0 +1,461 @@
+#[cfg(not(target_os = "android"))]
+use futures::TryFutureExt;
+use mullvad_types::{
+ relay_constraints::{RelayConstraints, RelaySettings, WireguardConstraints},
+ settings::{DnsState, Settings},
+};
+use std::{
+ fmt::{self, Display},
+ ops::Deref,
+ path::{Path, PathBuf},
+};
+use talpid_core::firewall::is_local_address;
+use talpid_types::ErrorExt;
+use tokio::{
+ fs,
+ io::{self, AsyncWriteExt},
+};
+
+pub mod patch;
+
+const SETTINGS_FILE: &str = "settings.json";
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ #[error(display = "Unable to read settings file {}", _0)]
+ ReadError(String, #[error(source)] io::Error),
+
+ #[error(display = "Unable to parse settings file")]
+ ParseError(#[error(source)] serde_json::Error),
+
+ #[error(display = "Unable to remove settings file {}", _0)]
+ #[cfg(not(target_os = "android"))]
+ DeleteError(String, #[error(source)] io::Error),
+
+ #[error(display = "Unable to serialize settings to JSON")]
+ SerializeError(#[error(source)] serde_json::Error),
+
+ #[error(display = "Unable to write settings to {}", _0)]
+ WriteError(String, #[error(source)] io::Error),
+}
+
+/// Converts an [Error] to a management interface status
+#[cfg(not(target_os = "android"))]
+impl From<Error> for mullvad_management_interface::Status {
+ fn from(error: Error) -> mullvad_management_interface::Status {
+ use mullvad_management_interface::{Code, Status};
+
+ match error {
+ Error::DeleteError(..) | Error::WriteError(..) | Error::ReadError(..) => {
+ Status::new(Code::FailedPrecondition, error.to_string())
+ }
+ Error::SerializeError(..) | Error::ParseError(..) => {
+ Status::new(Code::Internal, error.to_string())
+ }
+ }
+ }
+}
+
+pub struct SettingsPersister {
+ settings: Settings,
+ path: PathBuf,
+ #[allow(clippy::type_complexity)]
+ on_change_listeners: Vec<Box<dyn Fn(&Settings)>>,
+}
+
+pub type MadeChanges = bool;
+
+impl SettingsPersister {
+ /// Loads user settings from file. If it fails, it returns the defaults.
+ pub async fn load(settings_dir: &Path) -> Self {
+ let path = settings_dir.join(SETTINGS_FILE);
+ let (mut settings, mut should_save) = match Self::load_from_file(&path).await {
+ Ok(value) => value,
+ Err(error) => {
+ log::warn!(
+ "{}",
+ error.display_chain_with_msg("Failed to load settings. Using defaults.")
+ );
+ let mut settings = Self::default_settings();
+
+ // Protect the user by blocking the internet by default. Previous settings may
+ // not have caused the daemon to enter the non-blocking disconnected state.
+ settings.block_when_disconnected = true;
+
+ (settings, true)
+ }
+ };
+
+ // Force IPv6 to be enabled on Android
+ if cfg!(target_os = "android") {
+ should_save |= !settings.tunnel_options.generic.enable_ipv6;
+ settings.tunnel_options.generic.enable_ipv6 = true;
+ }
+ if crate::version::is_beta_version() {
+ should_save |= !settings.show_beta_releases;
+ settings.show_beta_releases = true;
+ }
+
+ let mut persister = SettingsPersister {
+ settings,
+ path,
+ on_change_listeners: vec![],
+ };
+
+ if should_save {
+ if let Err(error) = persister.save().await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to save updated settings")
+ );
+ }
+ }
+
+ persister
+ }
+
+ pub fn register_change_listener(&mut self, change_listener: impl Fn(&Settings) + 'static) {
+ self.on_change_listeners.push(Box::new(change_listener));
+ }
+
+ fn notify_listeners(&self) {
+ for listener in &self.on_change_listeners {
+ listener(&self.settings);
+ }
+ }
+
+ async fn load_from_file(path: &Path) -> Result<(Settings, bool), Error> {
+ log::info!("Loading settings from {}", path.display());
+
+ let settings_bytes = match fs::read(path).await {
+ Ok(bytes) => bytes,
+ Err(error) => {
+ if error.kind() == io::ErrorKind::NotFound {
+ log::info!("No settings were found. Using defaults.");
+ return Ok((Self::default_settings(), true));
+ } else {
+ return Err(Error::ReadError(path.display().to_string(), error));
+ }
+ }
+ };
+ Ok((Self::load_from_bytes(&settings_bytes)?, false))
+ }
+
+ fn load_from_bytes(bytes: &[u8]) -> Result<Settings, Error> {
+ serde_json::from_slice(bytes).map_err(Error::ParseError)
+ }
+
+ async fn save(&mut self) -> Result<(), Error> {
+ Self::save_inner(&self.path, &self.settings).await
+ }
+
+ /// Serializes the settings and saves them to the given file.
+ async fn save_inner(path: &Path, settings: &Settings) -> Result<(), Error> {
+ log::debug!("Writing settings to {}", path.display());
+
+ let buffer = serde_json::to_string_pretty(settings).map_err(Error::SerializeError)?;
+ let mut file = mullvad_fs::AtomicFile::new(path)
+ .await
+ .map_err(|e| Error::WriteError(path.display().to_string(), e))?;
+ file.write_all(&buffer.into_bytes())
+ .await
+ .map_err(|e| Error::WriteError(path.display().to_string(), e))?;
+ file.finalize()
+ .await
+ .map_err(|e| Error::WriteError(path.display().to_string(), e))?;
+
+ Ok(())
+ }
+
+ /// Resets default settings
+ #[cfg(not(target_os = "android"))]
+ pub async fn reset(&mut self) -> Result<(), Error> {
+ self.settings = Self::default_settings();
+ let path = self.path.clone();
+ self.save()
+ .or_else(|e| async move {
+ log::error!(
+ "{}",
+ e.display_chain_with_msg("Unable to save default settings")
+ );
+ log::info!("Will attempt to remove settings file");
+ fs::remove_file(&path)
+ .map_err(|e| Error::DeleteError(path.display().to_string(), e))
+ .await
+ })
+ .await?;
+
+ self.notify_listeners();
+
+ Ok(())
+ }
+
+ pub fn to_settings(&self) -> Settings {
+ self.settings.clone()
+ }
+
+ /// Modifies `Settings::default()` somewhat, e.g. depending on whether a beta version
+ /// is being run or not.
+ fn default_settings() -> Settings {
+ let mut settings = Settings::default();
+
+ if crate::version::is_beta_version() {
+ settings.show_beta_releases = true;
+ }
+ settings
+ }
+
+ /// Edit the settings in a closure, and write the changes, if any, to disk.
+ ///
+ /// On success, the function returns a boolean indicating whether any settings were changed.
+ /// If the settings could not be written to disk, all changes are rolled back, and an error is
+ /// returned.
+ pub async fn update(
+ &mut self,
+ update_fn: impl FnOnce(&mut Settings),
+ ) -> Result<MadeChanges, Error> {
+ let mut new_settings = self.settings.clone();
+
+ update_fn(&mut new_settings);
+
+ if self.settings == new_settings {
+ return Ok(false);
+ }
+
+ Self::save_inner(&self.path, &new_settings).await?;
+ self.settings = new_settings;
+
+ self.notify_listeners();
+
+ Ok(true)
+ }
+
+ /// Return a compact summary of important settings
+ pub fn summary(&self) -> SettingsSummary<'_> {
+ SettingsSummary {
+ settings: &self.settings,
+ }
+ }
+}
+
+impl Deref for SettingsPersister {
+ type Target = Settings;
+
+ fn deref(&self) -> &Self::Target {
+ &self.settings
+ }
+}
+
+/// A compact summary of important settings
+pub struct SettingsSummary<'a> {
+ settings: &'a Settings,
+}
+
+impl<'a> Display for SettingsSummary<'a> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let bool_to_label = |state| {
+ if state {
+ "on"
+ } else {
+ "off"
+ }
+ };
+
+ let relay_settings = self.settings.get_relay_settings();
+
+ write!(f, "openvpn mssfix: ")?;
+ Self::fmt_option(f, self.settings.tunnel_options.openvpn.mssfix)?;
+ write!(f, ", wg mtu: ")?;
+ Self::fmt_option(f, self.settings.tunnel_options.wireguard.mtu)?;
+
+ if let RelaySettings::Normal(RelayConstraints {
+ wireguard_constraints: WireguardConstraints { ip_version, .. },
+ ..
+ }) = relay_settings
+ {
+ write!(f, ", wg ip version: {ip_version}")?;
+ }
+
+ let multihop = matches!(
+ relay_settings,
+ RelaySettings::Normal(RelayConstraints {
+ wireguard_constraints: WireguardConstraints {
+ use_multihop: true,
+ ..
+ },
+ ..
+ })
+ );
+
+ write!(
+ f,
+ ", multihop: {}, ipv6 (tun): {}, lan: {}, pq: {}, obfs: {}",
+ bool_to_label(multihop),
+ bool_to_label(self.settings.tunnel_options.generic.enable_ipv6),
+ bool_to_label(self.settings.allow_lan),
+ self.settings.tunnel_options.wireguard.quantum_resistant,
+ self.settings.obfuscation_settings.selected_obfuscation,
+ )?;
+
+ // Print DNS options
+
+ write!(f, ", dns: ")?;
+
+ match self.settings.tunnel_options.dns_options.state {
+ DnsState::Default => {
+ let mut content = vec![];
+ let default_options = &self.settings.tunnel_options.dns_options.default_options;
+
+ if default_options.block_ads {
+ content.push("ads");
+ }
+ if default_options.block_trackers {
+ content.push("trackers");
+ }
+ if default_options.block_malware {
+ content.push("malware");
+ }
+ if default_options.block_adult_content {
+ content.push("adult");
+ }
+ if default_options.block_gambling {
+ content.push("gambling");
+ }
+ if default_options.block_social_media {
+ content.push("social media");
+ }
+ if content.is_empty() {
+ content.push("default");
+ }
+ write!(f, "{}", content.join(" "))?;
+ }
+ DnsState::Custom => {
+ // NOTE: Technically inaccurate, as the gateway IP is a local IP but isn't treated
+ // as one.
+ let contains_local = self
+ .settings
+ .tunnel_options
+ .dns_options
+ .custom_options
+ .addresses
+ .iter()
+ .any(is_local_address);
+ let contains_public = self
+ .settings
+ .tunnel_options
+ .dns_options
+ .custom_options
+ .addresses
+ .iter()
+ .any(|addr| !is_local_address(addr));
+
+ match (contains_public, contains_local) {
+ (true, true) => f.write_str("custom, public, local")?,
+ (true, false) => f.write_str("custom, public")?,
+ (false, false) => f.write_str("custom, no addrs")?,
+ (false, true) => f.write_str("custom, local")?,
+ }
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<'a> SettingsSummary<'a> {
+ fn fmt_option<T: Display>(f: &mut fmt::Formatter<'_>, val: Option<T>) -> fmt::Result {
+ if let Some(inner) = &val {
+ inner.fmt(f)
+ } else {
+ f.write_str("unset")
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::SettingsPersister;
+ use mullvad_types::settings::SettingsVersion;
+ use serde_json;
+
+ #[test]
+ #[should_panic]
+ fn test_deserialization_failure_version_too_small() {
+ let _version: SettingsVersion = serde_json::from_str("1").expect("Version too small");
+ }
+
+ #[test]
+ #[should_panic]
+ fn test_deserialization_failure_version_too_big() {
+ let _version: SettingsVersion = serde_json::from_str("1000").expect("Version too big");
+ }
+
+ #[test]
+ fn test_deserialization_success() {
+ let _version: SettingsVersion =
+ serde_json::from_str("2").expect("Failed to deserialize valid version");
+ }
+
+ #[test]
+ fn test_serialization_success() {
+ let version = SettingsVersion::V2;
+ let s = serde_json::to_string(&version).expect("Failed to serialize");
+ assert_eq!(s, "2");
+ }
+
+ #[test]
+ fn test_deserialization() {
+ let settings = br#"{
+ "account_token": "0000000000000000",
+ "relay_settings": {
+ "normal": {
+ "location": {
+ "only": {
+ "location": {
+ "country": "gb"
+ }
+ }
+ },
+ "tunnel_protocol": {
+ "only": "wireguard"
+ },
+ "wireguard_constraints": {
+ "port": "any"
+ },
+ "openvpn_constraints": {
+ "port": "any",
+ "protocol": "any"
+ }
+ }
+ },
+ "bridge_settings": {
+ "normal": {
+ "location": "any"
+ }
+ },
+ "bridge_state": "auto",
+ "allow_lan": true,
+ "block_when_disconnected": false,
+ "auto_connect": true,
+ "tunnel_options": {
+ "openvpn": {
+ "mssfix": null
+ },
+ "wireguard": {
+ "mtu": null,
+ "rotation_interval": null
+ },
+ "generic": {
+ "enable_ipv6": true
+ }
+ },
+ "settings_version": 5,
+ "show_beta_releases": false,
+ "custom_lists": {
+ "custom_lists": []
+ }
+ }"#;
+
+ let _ = SettingsPersister::load_from_bytes(settings).unwrap();
+ }
+}
diff --git a/mullvad-daemon/src/settings/patch.rs b/mullvad-daemon/src/settings/patch.rs
new file mode 100644
index 0000000000..f006b21056
--- /dev/null
+++ b/mullvad-daemon/src/settings/patch.rs
@@ -0,0 +1,481 @@
+//! This module provides functionality for updating settings using a JSON string, i.e. applying a
+//! patch. It is intended to be relatively safe, preventing editing of "dangerous" settings such as
+//! custom DNS.
+//!
+//! Patching the settings is a three-step procedure:
+//! 1. Validating the input. Only a subset of settings is allowed to be edited using this method.
+//! Attempting to edit prohibited or invalid settings results in an error.
+//! 2. Merging the changes. When the patch has been accepted, it can be applied to the existing
+//! settings. How they're merged depends on the actual setting. See [MergeStrategy].
+//! 3. Deserialize the resulting JSON back to a [Settings] instance, and, if valid, replace the
+//! existing settings.
+//!
+//! Permitted settings and merge strategies are defined in the [PERMITTED_SUBKEYS] constant.
+
+use super::SettingsPersister;
+use mullvad_types::settings::Settings;
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Missing expected JSON object
+ #[error(display = "Incorrect or missing value: {}", _0)]
+ InvalidOrMissingValue(&'static str),
+ /// Unknown or prohibited key
+ #[error(display = "Invalid or prohibited key: {}", _0)]
+ UnknownOrProhibitedKey(String),
+ /// Failed to parse patch json
+ #[error(display = "Failed to parse settings patch")]
+ ParsePatch(#[error(source)] serde_json::Error),
+ /// Failed to deserialize patched settings
+ #[error(display = "Failed to deserialize patched settings")]
+ DeserializePatched(#[error(source)] serde_json::Error),
+ /// Failed to serialize settings
+ #[error(display = "Failed to serialize current settings")]
+ SerializeSettings(#[error(source)] serde_json::Error),
+ /// Recursion limit reached
+ #[error(display = "Maximum JSON object depth reached")]
+ RecursionLimit,
+ /// Settings error
+ #[error(display = "Settings error")]
+ Settings(#[error(source)] super::Error),
+}
+
+/// Converts an [Error] to a management interface status
+#[cfg(not(target_os = "android"))]
+impl From<Error> for mullvad_management_interface::Status {
+ fn from(error: Error) -> mullvad_management_interface::Status {
+ use mullvad_management_interface::Status;
+
+ match error {
+ Error::InvalidOrMissingValue(_)
+ | Error::UnknownOrProhibitedKey(_)
+ | Error::ParsePatch(_)
+ | Error::DeserializePatched(_)
+ | Error::RecursionLimit => Status::invalid_argument(error.to_string()),
+ Error::Settings(error) => Status::from(error),
+ Error::SerializeSettings(error) => Status::internal(error.to_string()),
+ }
+ }
+}
+
+enum MergeStrategy {
+ /// Replace or append keys to objects, and replace everything else
+ Replace,
+ /// Call a function to combine an existing setting (which may be null) with the patch.
+ /// The returned value replaces the existing node.
+ Custom(fn(&serde_json::Value, &serde_json::Value) -> Result<serde_json::Value, Error>),
+}
+
+// TODO: Use Default trait when `const_trait_impl`` is available.
+const DEFAULT_MERGE_STRATEGY: MergeStrategy = MergeStrategy::Replace;
+
+struct PermittedKey {
+ key_type: PermittedKeyValue,
+ merge_strategy: MergeStrategy,
+}
+
+impl PermittedKey {
+ const fn object(keys: &'static [(&'static str, PermittedKey)]) -> Self {
+ Self {
+ key_type: PermittedKeyValue::Object(keys),
+ merge_strategy: DEFAULT_MERGE_STRATEGY,
+ }
+ }
+
+ const fn array(key: &'static PermittedKey) -> Self {
+ Self {
+ key_type: PermittedKeyValue::Array(key),
+ merge_strategy: DEFAULT_MERGE_STRATEGY,
+ }
+ }
+
+ const fn any() -> Self {
+ Self {
+ key_type: PermittedKeyValue::Any,
+ merge_strategy: DEFAULT_MERGE_STRATEGY,
+ }
+ }
+
+ const fn merge_strategy(mut self, merge_strategy: MergeStrategy) -> Self {
+ self.merge_strategy = merge_strategy;
+ self
+ }
+}
+
+enum PermittedKeyValue {
+ /// Select subkeys that can be modified at this level
+ Object(&'static [(&'static str, PermittedKey)]),
+ /// Array that can be modified at this level
+ Array(&'static PermittedKey),
+ /// Accept any object at this level
+ Any,
+}
+
+const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[(
+ "relay_overrides",
+ PermittedKey::array(&PermittedKey::object(&[
+ ("hostname", PermittedKey::any()),
+ ("ipv4_addr_in", PermittedKey::any()),
+ ("ipv6_addr_in", PermittedKey::any()),
+ ]))
+ .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)),
+)]);
+/// Prohibit stack overflow via excessive recursion. It might be possible to forgo this when
+/// tail-call optimization can be enforced?
+const RECURSE_LIMIT: usize = 15;
+
+/// Update the settings with the supplied patch. Only settings specified in `PERMITTED_SUBKEYS` can
+/// be updated. All other changes are rejected
+pub async fn merge_validate_patch(
+ settings: &mut SettingsPersister,
+ json_patch: &str,
+) -> Result<(), Error> {
+ let mut settings_value: serde_json::Value =
+ serde_json::to_value(settings.to_settings()).map_err(Error::SerializeSettings)?;
+ let patch_value: serde_json::Value =
+ serde_json::from_str(json_patch).map_err(Error::ParsePatch)?;
+
+ validate_patch_value(PERMITTED_SUBKEYS, &patch_value, 0)?;
+ merge_patch_to_value(PERMITTED_SUBKEYS, &mut settings_value, &patch_value, 0)?;
+
+ let new_settings: Settings =
+ serde_json::from_value(settings_value).map_err(Error::DeserializePatched)?;
+
+ settings
+ .update(move |settings| *settings = new_settings)
+ .await
+ .map_err(Error::Settings)?;
+
+ Ok(())
+}
+
+/// Replace overrides for existing values in the array if there's a matching hostname. For hostnames
+/// that do not exist, just append the overrides.
+fn merge_relay_overrides(
+ current_settings: &serde_json::Value,
+ patch: &serde_json::Value,
+) -> Result<serde_json::Value, Error> {
+ if current_settings.is_null() {
+ return Ok(patch.to_owned());
+ }
+
+ let patch_array = patch.as_array().ok_or(Error::InvalidOrMissingValue(
+ "relay overrides must be array",
+ ))?;
+ let current_array = current_settings
+ .as_array()
+ .ok_or(Error::InvalidOrMissingValue(
+ "existing overrides should be an array",
+ ))?;
+ let mut new_array = current_array.clone();
+
+ for patch_override in patch_array.iter().cloned() {
+ let patch_obj = patch_override
+ .as_object()
+ .ok_or(Error::InvalidOrMissingValue("override entry"))?;
+ let patch_hostname = patch_obj
+ .get("hostname")
+ .and_then(|hostname| hostname.as_str())
+ .ok_or(Error::InvalidOrMissingValue("hostname"))?;
+
+ let existing_obj = new_array.iter_mut().find(|value| {
+ value
+ .as_object()
+ .and_then(|obj| obj.get("hostname"))
+ .map(|hostname| hostname.as_str() == Some(patch_hostname))
+ .unwrap_or(false)
+ });
+
+ match existing_obj {
+ Some(existing_val) => {
+ // Replace or append to existing values
+ match (existing_val, patch_override) {
+ (
+ serde_json::Value::Object(ref mut current),
+ serde_json::Value::Object(ref patch),
+ ) => {
+ for (k, v) in patch {
+ current.insert(k.to_owned(), v.to_owned());
+ }
+ }
+ _ => {
+ return Err(Error::InvalidOrMissingValue(
+ "all override entries must be objects",
+ ));
+ }
+ }
+ }
+ None => new_array.push(patch_override),
+ }
+ }
+
+ Ok(serde_json::Value::Array(new_array))
+}
+
+fn merge_patch_to_value(
+ permitted_key: &'static PermittedKey,
+ current_value: &mut serde_json::Value,
+ patch_value: &serde_json::Value,
+ recurse_level: usize,
+) -> Result<(), Error> {
+ if recurse_level >= RECURSE_LIMIT {
+ return Err(Error::RecursionLimit);
+ }
+
+ match permitted_key.merge_strategy {
+ MergeStrategy::Replace => {
+ match (&permitted_key.key_type, current_value, patch_value) {
+ // Append or replace keys to objects
+ (
+ PermittedKeyValue::Object(sub_permitteds),
+ serde_json::Value::Object(ref mut current),
+ serde_json::Value::Object(ref patch),
+ ) => {
+ for (k, sub_patch) in patch {
+ let Some((_, sub_permitted)) = sub_permitteds
+ .iter()
+ .find(|(permitted_key, _)| k == permitted_key)
+ else {
+ return Err(Error::UnknownOrProhibitedKey(k.to_owned()));
+ };
+ let sub_current = current.entry(k).or_insert(serde_json::Value::Null);
+ merge_patch_to_value(
+ sub_permitted,
+ sub_current,
+ sub_patch,
+ recurse_level + 1,
+ )?;
+ }
+ }
+ // Totally replace anything else
+ (_, current, patch) => {
+ *current = patch.clone();
+ }
+ }
+ }
+ MergeStrategy::Custom(merge_function) => {
+ *current_value = merge_function(current_value, patch_value)?;
+ }
+ }
+
+ Ok(())
+}
+
+fn validate_patch_value(
+ permitted_key: &'static PermittedKey,
+ json_value: &serde_json::Value,
+ recurse_level: usize,
+) -> Result<(), Error> {
+ if recurse_level >= RECURSE_LIMIT {
+ return Err(Error::RecursionLimit);
+ }
+
+ match permitted_key.key_type {
+ PermittedKeyValue::Object(subkeys) => {
+ let map = json_value.as_object().ok_or(Error::InvalidOrMissingValue(
+ "expected JSON object in patch",
+ ))?;
+ for (k, v) in map.into_iter() {
+ // NOTE: We're relying on the parser to shed duplicate keys here.
+ // As of this writing, `Map` is implemented using BTreeMap.
+ let Some((_, subkey)) =
+ subkeys.iter().find(|(permitted_key, _)| k == permitted_key)
+ else {
+ return Err(Error::UnknownOrProhibitedKey(k.to_owned()));
+ };
+ validate_patch_value(subkey, v, recurse_level + 1)?;
+ }
+ Ok(())
+ }
+ PermittedKeyValue::Array(subkey) => {
+ let values = json_value
+ .as_array()
+ .ok_or(Error::InvalidOrMissingValue("expected JSON array in patch"))?;
+ for v in values {
+ validate_patch_value(subkey, v, recurse_level + 1)?;
+ }
+ Ok(())
+ }
+ PermittedKeyValue::Any => Ok(()),
+ }
+}
+
+#[test]
+fn test_permitted_value() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[(
+ "key",
+ PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])),
+ )]);
+
+ let patch = r#"{"key": [ {"a": "test" } ] }"#;
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap();
+}
+
+#[test]
+fn test_prohibited_value() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[(
+ "key",
+ PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])),
+ )]);
+
+ let patch = r#"{"keyx": [] }"#;
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err();
+
+ let patch = r#"{"key": { "b": 1 } }"#;
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err();
+}
+
+#[test]
+fn test_merge_append_to_object() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[
+ ("test0", PermittedKey::any()),
+ ("test1", PermittedKey::any()),
+ ]);
+
+ let current = r#"{ "test0": 1 }"#;
+ let patch = r#"{ "test1": [] }"#;
+ let expected = r#"{ "test0": 1, "test1": [] }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+}
+
+#[test]
+fn test_merge_replace_in_object() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[
+ ("test0", PermittedKey::any()),
+ (
+ "test1",
+ PermittedKey::object(&[("a", PermittedKey::any()), ("test0", PermittedKey::any())]),
+ ),
+ ]);
+
+ let current = r#"{ "test0": 1, "test1": { "a": 1, "test0": [] } }"#;
+ let patch = r#"{ "test1": { "test0": [1, 2, 3] } }"#;
+ let expected = r#"{ "test0": 1, "test1": { "a": 1, "test0": [1, 2, 3] } }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+}
+
+#[test]
+fn test_overflow() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::array(&PermittedKey::array(
+ &PermittedKey::array(&PermittedKey::any()),
+ ))),
+ ))),
+ ))),
+ ))),
+ ))),
+ ));
+
+ let patch = r#"[[[[[[[[[[[[[[[[[[[[[[]]]]]]]]]]]]]]]]]]]]]]"#;
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+
+ assert!(matches!(
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0),
+ Err(Error::RecursionLimit)
+ ));
+}
+
+#[test]
+fn test_patch_relay_override() {
+ const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[(
+ "relay_overrides",
+ PermittedKey::array(&PermittedKey::object(&[
+ ("hostname", PermittedKey::any()),
+ ("ipv4_addr_in", PermittedKey::any()),
+ ("ipv6_addr_in", PermittedKey::any()),
+ ]))
+ .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)),
+ )]);
+
+ // If override has no hostname, fail
+ //
+ let patch = r#"{ "relay_overrides": [ { "invalid": 0 } ] }"#;
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err();
+
+ // If there are no overrides, append new override
+ //
+ let current = r#"{ "other": 1 }"#;
+ let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#;
+ let expected = r#"{ "other": 1, "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap();
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+
+ // If there are overrides, append new override to existing list
+ //
+ let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#;
+ let patch = r#"{ "relay_overrides": [ { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#;
+ let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap();
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+
+ // If there are overrides, replace existing overrides but keep rest
+ //
+ let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "1.2.3.4" } ] }"#;
+ let patch = r#"{ "relay_overrides": [ { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#;
+ let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap();
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+
+ // For same hostname, only update specified overrides
+ //
+ let current =
+ r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#;
+ let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv6_addr_in": "::1" } ] }"#;
+ let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7", "ipv6_addr_in": "::1" } ] }"#;
+
+ let mut current: serde_json::Value = serde_json::from_str(current).unwrap();
+ let patch: serde_json::Value = serde_json::from_str(patch).unwrap();
+ let expected: serde_json::Value = serde_json::from_str(expected).unwrap();
+
+ validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap();
+ merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap();
+
+ assert_eq!(current, expected);
+}