diff options
Diffstat (limited to 'mullvad-daemon')
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 15 | ||||
| -rw-r--r-- | mullvad-daemon/src/management_interface.rs | 128 | ||||
| -rw-r--r-- | mullvad-daemon/src/settings/mod.rs (renamed from mullvad-daemon/src/settings.rs) | 19 | ||||
| -rw-r--r-- | mullvad-daemon/src/settings/patch.rs | 481 |
4 files changed, 559 insertions, 84 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d4964a8b80..890cdfb13e 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -349,6 +349,8 @@ pub enum DaemonCommand { /// Verify that a google play payment was successful through the API. #[cfg(target_os = "android")] VerifyPlayPurchase(ResponseTx<(), Error>, PlayPurchase), + /// Patch the settings using a blob of JSON settings + ApplyJsonSettings(ResponseTx<(), settings::patch::Error>, String), } /// All events that can happen in the daemon. Sent from various threads and exposed interfaces. @@ -1171,6 +1173,7 @@ where VerifyPlayPurchase(tx, play_purchase) => { self.on_verify_play_purchase(tx, play_purchase) } + ApplyJsonSettings(tx, blob) => self.on_apply_json_settings(tx, blob).await, } } @@ -2439,6 +2442,18 @@ where }); } + async fn on_apply_json_settings( + &mut self, + tx: ResponseTx<(), settings::patch::Error>, + blob: String, + ) { + let result = settings::patch::merge_validate_patch(&mut self.settings, &blob).await; + if result.is_ok() { + self.reconnect_tunnel(); + } + Self::oneshot_send(tx, result, "apply_json_settings response"); + } + /// Set the target state of the client. If it changed trigger the operations needed to /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index e67a02117c..f042a923e5 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -1,4 +1,4 @@ -use crate::{account_history, device, settings, DaemonCommand, DaemonCommandSender, EventListener}; +use crate::{account_history, device, DaemonCommand, DaemonCommandSender, EventListener}; use futures::{ channel::{mpsc, oneshot}, StreamExt, @@ -177,10 +177,8 @@ impl ManagementService for ManagementServiceImpl { let message = DaemonCommand::SetRelaySettings(tx, constraints_update); self.send_command_to_daemon(message)?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn get_relay_locations(&self, _: Request<()>) -> ServiceResult<types::RelayList> { @@ -215,10 +213,8 @@ impl ManagementService for ManagementServiceImpl { let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetBridgeSettings(tx, settings))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_obfuscation_settings( @@ -230,10 +226,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_obfuscation_settings({:?})", settings); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetObfuscationSettings(tx, settings))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_bridge_state(&self, request: Request<types::BridgeState>) -> ServiceResult<()> { @@ -243,10 +237,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_bridge_state({:?})", bridge_state); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetBridgeState(tx, bridge_state))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } // Settings @@ -266,10 +258,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_allow_lan({})", allow_lan); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetAllowLan(tx, allow_lan))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_show_beta_releases(&self, request: Request<bool>) -> ServiceResult<()> { @@ -277,10 +267,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_show_beta_releases({})", enabled); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetShowBetaReleases(tx, enabled))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_block_when_disconnected(&self, request: Request<bool>) -> ServiceResult<()> { @@ -291,10 +279,8 @@ impl ManagementService for ManagementServiceImpl { tx, block_when_disconnected, ))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_auto_connect(&self, request: Request<bool>) -> ServiceResult<()> { @@ -302,10 +288,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_auto_connect({})", auto_connect); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetAutoConnect(tx, auto_connect))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_openvpn_mssfix(&self, request: Request<u32>) -> ServiceResult<()> { @@ -318,10 +302,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_openvpn_mssfix({:?})", mssfix); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetOpenVpnMssfix(tx, mssfix))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_wireguard_mtu(&self, request: Request<u32>) -> ServiceResult<()> { @@ -330,10 +312,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_wireguard_mtu({:?})", mtu); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetWireguardMtu(tx, mtu))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_enable_ipv6(&self, request: Request<bool>) -> ServiceResult<()> { @@ -341,10 +321,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_enable_ipv6({})", enable_ipv6); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetEnableIpv6(tx, enable_ipv6))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_quantum_resistant_tunnel( @@ -357,10 +335,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_quantum_resistant_tunnel({state:?})"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetQuantumResistantTunnel(tx, state))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } #[cfg(not(target_os = "android"))] @@ -370,10 +346,8 @@ impl ManagementService for ManagementServiceImpl { let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetDnsOptions(tx, options))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } #[cfg(target_os = "android")] @@ -390,20 +364,16 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_relay_override"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetRelayOverride(tx, relay_override))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn clear_all_relay_overrides(&self, _: Request<()>) -> ServiceResult<()> { log::debug!("clear_all_relay_overrides"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::ClearAllRelayOverrides(tx))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } // Account management @@ -571,20 +541,16 @@ impl ManagementService for ManagementServiceImpl { tx, Some(interval), ))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn reset_wireguard_rotation_interval(&self, _: Request<()>) -> ServiceResult<()> { log::debug!("reset_wireguard_rotation_interval"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetWireguardRotationInterval(tx, None))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> { @@ -929,6 +895,14 @@ impl ManagementService for ManagementServiceImpl { async fn check_volumes(&self, _: Request<()>) -> ServiceResult<()> { Ok(Response::new(())) } + + async fn apply_json_settings(&self, blob: Request<String>) -> ServiceResult<()> { + log::debug!("apply_json_settings"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::ApplyJsonSettings(tx, blob.into_inner()))?; + self.wait_for_result(rx).await??; + Ok(Response::new(())) + } } impl ManagementServiceImpl { @@ -1061,7 +1035,7 @@ fn map_daemon_error(error: crate::Error) -> Status { match error { DaemonError::RestError(error) => map_rest_error(&error), - DaemonError::SettingsError(error) => map_settings_error(error), + DaemonError::SettingsError(error) => Status::from(error), DaemonError::AlreadyLoggedIn => Status::already_exists(error.to_string()), DaemonError::LoginError(error) => map_device_error(&error), DaemonError::LogoutError(error) => map_device_error(&error), @@ -1121,20 +1095,6 @@ fn map_rest_error(error: &RestError) -> Status { } } -/// Converts an instance of [`mullvad_daemon::settings::Error`] into a tonic status. -fn map_settings_error(error: settings::Error) -> Status { - match error { - settings::Error::DeleteError(..) - | settings::Error::WriteError(..) - | settings::Error::ReadError(..) => { - Status::new(Code::FailedPrecondition, error.to_string()) - } - settings::Error::SerializeError(..) | settings::Error::ParseError(..) => { - Status::new(Code::Internal, error.to_string()) - } - } -} - /// Converts an instance of [`mullvad_daemon::device::Error`] into a tonic status. fn map_device_error(error: &device::Error) -> Status { match error { diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings/mod.rs index f5c5e31e94..99f637c023 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings/mod.rs @@ -16,6 +16,8 @@ use tokio::{ io::{self, AsyncWriteExt}, }; +pub mod patch; + const SETTINGS_FILE: &str = "settings.json"; #[derive(err_derive::Error, Debug)] @@ -38,6 +40,23 @@ pub enum Error { WriteError(String, #[error(source)] io::Error), } +/// Converts an [Error] to a management interface status +#[cfg(not(target_os = "android"))] +impl From<Error> for mullvad_management_interface::Status { + fn from(error: Error) -> mullvad_management_interface::Status { + use mullvad_management_interface::{Code, Status}; + + match error { + Error::DeleteError(..) | Error::WriteError(..) | Error::ReadError(..) => { + Status::new(Code::FailedPrecondition, error.to_string()) + } + Error::SerializeError(..) | Error::ParseError(..) => { + Status::new(Code::Internal, error.to_string()) + } + } + } +} + pub struct SettingsPersister { settings: Settings, path: PathBuf, diff --git a/mullvad-daemon/src/settings/patch.rs b/mullvad-daemon/src/settings/patch.rs new file mode 100644 index 0000000000..f006b21056 --- /dev/null +++ b/mullvad-daemon/src/settings/patch.rs @@ -0,0 +1,481 @@ +//! This module provides functionality for updating settings using a JSON string, i.e. applying a +//! patch. It is intended to be relatively safe, preventing editing of "dangerous" settings such as +//! custom DNS. +//! +//! Patching the settings is a three-step procedure: +//! 1. Validating the input. Only a subset of settings is allowed to be edited using this method. +//! Attempting to edit prohibited or invalid settings results in an error. +//! 2. Merging the changes. When the patch has been accepted, it can be applied to the existing +//! settings. How they're merged depends on the actual setting. See [MergeStrategy]. +//! 3. Deserialize the resulting JSON back to a [Settings] instance, and, if valid, replace the +//! existing settings. +//! +//! Permitted settings and merge strategies are defined in the [PERMITTED_SUBKEYS] constant. + +use super::SettingsPersister; +use mullvad_types::settings::Settings; + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Missing expected JSON object + #[error(display = "Incorrect or missing value: {}", _0)] + InvalidOrMissingValue(&'static str), + /// Unknown or prohibited key + #[error(display = "Invalid or prohibited key: {}", _0)] + UnknownOrProhibitedKey(String), + /// Failed to parse patch json + #[error(display = "Failed to parse settings patch")] + ParsePatch(#[error(source)] serde_json::Error), + /// Failed to deserialize patched settings + #[error(display = "Failed to deserialize patched settings")] + DeserializePatched(#[error(source)] serde_json::Error), + /// Failed to serialize settings + #[error(display = "Failed to serialize current settings")] + SerializeSettings(#[error(source)] serde_json::Error), + /// Recursion limit reached + #[error(display = "Maximum JSON object depth reached")] + RecursionLimit, + /// Settings error + #[error(display = "Settings error")] + Settings(#[error(source)] super::Error), +} + +/// Converts an [Error] to a management interface status +#[cfg(not(target_os = "android"))] +impl From<Error> for mullvad_management_interface::Status { + fn from(error: Error) -> mullvad_management_interface::Status { + use mullvad_management_interface::Status; + + match error { + Error::InvalidOrMissingValue(_) + | Error::UnknownOrProhibitedKey(_) + | Error::ParsePatch(_) + | Error::DeserializePatched(_) + | Error::RecursionLimit => Status::invalid_argument(error.to_string()), + Error::Settings(error) => Status::from(error), + Error::SerializeSettings(error) => Status::internal(error.to_string()), + } + } +} + +enum MergeStrategy { + /// Replace or append keys to objects, and replace everything else + Replace, + /// Call a function to combine an existing setting (which may be null) with the patch. + /// The returned value replaces the existing node. + Custom(fn(&serde_json::Value, &serde_json::Value) -> Result<serde_json::Value, Error>), +} + +// TODO: Use Default trait when `const_trait_impl`` is available. +const DEFAULT_MERGE_STRATEGY: MergeStrategy = MergeStrategy::Replace; + +struct PermittedKey { + key_type: PermittedKeyValue, + merge_strategy: MergeStrategy, +} + +impl PermittedKey { + const fn object(keys: &'static [(&'static str, PermittedKey)]) -> Self { + Self { + key_type: PermittedKeyValue::Object(keys), + merge_strategy: DEFAULT_MERGE_STRATEGY, + } + } + + const fn array(key: &'static PermittedKey) -> Self { + Self { + key_type: PermittedKeyValue::Array(key), + merge_strategy: DEFAULT_MERGE_STRATEGY, + } + } + + const fn any() -> Self { + Self { + key_type: PermittedKeyValue::Any, + merge_strategy: DEFAULT_MERGE_STRATEGY, + } + } + + const fn merge_strategy(mut self, merge_strategy: MergeStrategy) -> Self { + self.merge_strategy = merge_strategy; + self + } +} + +enum PermittedKeyValue { + /// Select subkeys that can be modified at this level + Object(&'static [(&'static str, PermittedKey)]), + /// Array that can be modified at this level + Array(&'static PermittedKey), + /// Accept any object at this level + Any, +} + +const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "relay_overrides", + PermittedKey::array(&PermittedKey::object(&[ + ("hostname", PermittedKey::any()), + ("ipv4_addr_in", PermittedKey::any()), + ("ipv6_addr_in", PermittedKey::any()), + ])) + .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)), +)]); +/// Prohibit stack overflow via excessive recursion. It might be possible to forgo this when +/// tail-call optimization can be enforced? +const RECURSE_LIMIT: usize = 15; + +/// Update the settings with the supplied patch. Only settings specified in `PERMITTED_SUBKEYS` can +/// be updated. All other changes are rejected +pub async fn merge_validate_patch( + settings: &mut SettingsPersister, + json_patch: &str, +) -> Result<(), Error> { + let mut settings_value: serde_json::Value = + serde_json::to_value(settings.to_settings()).map_err(Error::SerializeSettings)?; + let patch_value: serde_json::Value = + serde_json::from_str(json_patch).map_err(Error::ParsePatch)?; + + validate_patch_value(PERMITTED_SUBKEYS, &patch_value, 0)?; + merge_patch_to_value(PERMITTED_SUBKEYS, &mut settings_value, &patch_value, 0)?; + + let new_settings: Settings = + serde_json::from_value(settings_value).map_err(Error::DeserializePatched)?; + + settings + .update(move |settings| *settings = new_settings) + .await + .map_err(Error::Settings)?; + + Ok(()) +} + +/// Replace overrides for existing values in the array if there's a matching hostname. For hostnames +/// that do not exist, just append the overrides. +fn merge_relay_overrides( + current_settings: &serde_json::Value, + patch: &serde_json::Value, +) -> Result<serde_json::Value, Error> { + if current_settings.is_null() { + return Ok(patch.to_owned()); + } + + let patch_array = patch.as_array().ok_or(Error::InvalidOrMissingValue( + "relay overrides must be array", + ))?; + let current_array = current_settings + .as_array() + .ok_or(Error::InvalidOrMissingValue( + "existing overrides should be an array", + ))?; + let mut new_array = current_array.clone(); + + for patch_override in patch_array.iter().cloned() { + let patch_obj = patch_override + .as_object() + .ok_or(Error::InvalidOrMissingValue("override entry"))?; + let patch_hostname = patch_obj + .get("hostname") + .and_then(|hostname| hostname.as_str()) + .ok_or(Error::InvalidOrMissingValue("hostname"))?; + + let existing_obj = new_array.iter_mut().find(|value| { + value + .as_object() + .and_then(|obj| obj.get("hostname")) + .map(|hostname| hostname.as_str() == Some(patch_hostname)) + .unwrap_or(false) + }); + + match existing_obj { + Some(existing_val) => { + // Replace or append to existing values + match (existing_val, patch_override) { + ( + serde_json::Value::Object(ref mut current), + serde_json::Value::Object(ref patch), + ) => { + for (k, v) in patch { + current.insert(k.to_owned(), v.to_owned()); + } + } + _ => { + return Err(Error::InvalidOrMissingValue( + "all override entries must be objects", + )); + } + } + } + None => new_array.push(patch_override), + } + } + + Ok(serde_json::Value::Array(new_array)) +} + +fn merge_patch_to_value( + permitted_key: &'static PermittedKey, + current_value: &mut serde_json::Value, + patch_value: &serde_json::Value, + recurse_level: usize, +) -> Result<(), Error> { + if recurse_level >= RECURSE_LIMIT { + return Err(Error::RecursionLimit); + } + + match permitted_key.merge_strategy { + MergeStrategy::Replace => { + match (&permitted_key.key_type, current_value, patch_value) { + // Append or replace keys to objects + ( + PermittedKeyValue::Object(sub_permitteds), + serde_json::Value::Object(ref mut current), + serde_json::Value::Object(ref patch), + ) => { + for (k, sub_patch) in patch { + let Some((_, sub_permitted)) = sub_permitteds + .iter() + .find(|(permitted_key, _)| k == permitted_key) + else { + return Err(Error::UnknownOrProhibitedKey(k.to_owned())); + }; + let sub_current = current.entry(k).or_insert(serde_json::Value::Null); + merge_patch_to_value( + sub_permitted, + sub_current, + sub_patch, + recurse_level + 1, + )?; + } + } + // Totally replace anything else + (_, current, patch) => { + *current = patch.clone(); + } + } + } + MergeStrategy::Custom(merge_function) => { + *current_value = merge_function(current_value, patch_value)?; + } + } + + Ok(()) +} + +fn validate_patch_value( + permitted_key: &'static PermittedKey, + json_value: &serde_json::Value, + recurse_level: usize, +) -> Result<(), Error> { + if recurse_level >= RECURSE_LIMIT { + return Err(Error::RecursionLimit); + } + + match permitted_key.key_type { + PermittedKeyValue::Object(subkeys) => { + let map = json_value.as_object().ok_or(Error::InvalidOrMissingValue( + "expected JSON object in patch", + ))?; + for (k, v) in map.into_iter() { + // NOTE: We're relying on the parser to shed duplicate keys here. + // As of this writing, `Map` is implemented using BTreeMap. + let Some((_, subkey)) = + subkeys.iter().find(|(permitted_key, _)| k == permitted_key) + else { + return Err(Error::UnknownOrProhibitedKey(k.to_owned())); + }; + validate_patch_value(subkey, v, recurse_level + 1)?; + } + Ok(()) + } + PermittedKeyValue::Array(subkey) => { + let values = json_value + .as_array() + .ok_or(Error::InvalidOrMissingValue("expected JSON array in patch"))?; + for v in values { + validate_patch_value(subkey, v, recurse_level + 1)?; + } + Ok(()) + } + PermittedKeyValue::Any => Ok(()), + } +} + +#[test] +fn test_permitted_value() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "key", + PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])), + )]); + + let patch = r#"{"key": [ {"a": "test" } ] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); +} + +#[test] +fn test_prohibited_value() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "key", + PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])), + )]); + + let patch = r#"{"keyx": [] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); + + let patch = r#"{"key": { "b": 1 } }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); +} + +#[test] +fn test_merge_append_to_object() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[ + ("test0", PermittedKey::any()), + ("test1", PermittedKey::any()), + ]); + + let current = r#"{ "test0": 1 }"#; + let patch = r#"{ "test1": [] }"#; + let expected = r#"{ "test0": 1, "test1": [] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} + +#[test] +fn test_merge_replace_in_object() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[ + ("test0", PermittedKey::any()), + ( + "test1", + PermittedKey::object(&[("a", PermittedKey::any()), ("test0", PermittedKey::any())]), + ), + ]); + + let current = r#"{ "test0": 1, "test1": { "a": 1, "test0": [] } }"#; + let patch = r#"{ "test1": { "test0": [1, 2, 3] } }"#; + let expected = r#"{ "test0": 1, "test1": { "a": 1, "test0": [1, 2, 3] } }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} + +#[test] +fn test_overflow() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::any()), + ))), + ))), + ))), + ))), + ))), + )); + + let patch = r#"[[[[[[[[[[[[[[[[[[[[[[]]]]]]]]]]]]]]]]]]]]]]"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + assert!(matches!( + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0), + Err(Error::RecursionLimit) + )); +} + +#[test] +fn test_patch_relay_override() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "relay_overrides", + PermittedKey::array(&PermittedKey::object(&[ + ("hostname", PermittedKey::any()), + ("ipv4_addr_in", PermittedKey::any()), + ("ipv6_addr_in", PermittedKey::any()), + ])) + .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)), + )]); + + // If override has no hostname, fail + // + let patch = r#"{ "relay_overrides": [ { "invalid": 0 } ] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); + + // If there are no overrides, append new override + // + let current = r#"{ "other": 1 }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let expected = r#"{ "other": 1, "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // If there are overrides, append new override to existing list + // + let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // If there are overrides, replace existing overrides but keep rest + // + let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "1.2.3.4" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // For same hostname, only update specified overrides + // + let current = + r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv6_addr_in": "::1" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7", "ipv6_addr_in": "::1" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} |
