diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-03-20 17:00:37 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-03-20 17:00:37 +0100 |
| commit | a3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3 (patch) | |
| tree | d7079f9f6775583a0e4da4a061b672372e7af2b3 /test/test-manager | |
| parent | 7e62d03a4366fb8eaabb13dec354bb0237cd0d08 (diff) | |
| parent | 6594c6de52763ab313d99edecf9231596a003e1f (diff) | |
| download | mullvadvpn-a3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3.tar.xz mullvadvpn-a3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3.zip | |
Merge remote-tracking branch 'origin/write-tests-for-split-tunneling-des-276'
Diffstat (limited to 'test/test-manager')
| -rw-r--r-- | test/test-manager/src/config.rs | 10 | ||||
| -rw-r--r-- | test/test-manager/src/logging.rs | 2 | ||||
| -rw-r--r-- | test/test-manager/src/run_tests.rs | 6 | ||||
| -rw-r--r-- | test/test-manager/src/tests/mod.rs | 5 | ||||
| -rw-r--r-- | test/test-manager/src/tests/split_tunnel.rs | 357 | ||||
| -rw-r--r-- | test/test-manager/src/tests/test_metadata.rs | 6 | ||||
| -rw-r--r-- | test/test-manager/src/vm/provision.rs | 5 | ||||
| -rw-r--r-- | test/test-manager/src/vm/qemu.rs | 34 | ||||
| -rw-r--r-- | test/test-manager/test_macro/Cargo.toml | 1 | ||||
| -rw-r--r-- | test/test-manager/test_macro/src/lib.rs | 211 |
10 files changed, 523 insertions, 114 deletions
diff --git a/test/test-manager/src/config.rs b/test/test-manager/src/config.rs index 1605661d53..6921c0b33f 100644 --- a/test/test-manager/src/config.rs +++ b/test/test-manager/src/config.rs @@ -139,6 +139,16 @@ pub struct VmConfig { #[serde(default)] #[arg(long)] pub tpm: bool, + + /// Override the path to `OVMF_VARS.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_vars_path: Option<String>, + + /// Override the path to `OVMF_CODE.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_code_path: Option<String>, } impl VmConfig { diff --git a/test/test-manager/src/logging.rs b/test/test-manager/src/logging.rs index cd0bd4af28..e85920b1cd 100644 --- a/test/test-manager/src/logging.rs +++ b/test/test-manager/src/logging.rs @@ -1,4 +1,4 @@ -use crate::tests::Error; +use anyhow::Error; use colored::Colorize; use std::sync::{Arc, Mutex}; use test_rpc::logging::{LogOutput, Output}; diff --git a/test/test-manager/src/run_tests.rs b/test/test-manager/src/run_tests.rs index 6af1536562..6b3da37138 100644 --- a/test/test-manager/src/run_tests.rs +++ b/test/test-manager/src/run_tests.rs @@ -2,9 +2,7 @@ use crate::summary::{self, maybe_log_test_result}; use crate::tests::{config::TEST_CONFIG, TestContext}; use crate::{ logging::{panic_as_string, TestOutput}, - mullvad_daemon, tests, - tests::Error, - vm, + mullvad_daemon, tests, vm, }; use anyhow::{Context, Result}; use futures::FutureExt; @@ -187,7 +185,7 @@ pub async fn run_test<F, R, MullvadClient>( ) -> TestOutput where F: Fn(super::tests::TestContext, ServiceClient, MullvadClient) -> R, - R: Future<Output = Result<(), Error>>, + R: Future<Output = anyhow::Result<()>>, { let _flushed = runner_rpc.try_poll_output().await; diff --git a/test/test-manager/src/tests/mod.rs b/test/test-manager/src/tests/mod.rs index 0cf1357696..48d75b9e3f 100644 --- a/test/test-manager/src/tests/mod.rs +++ b/test/test-manager/src/tests/mod.rs @@ -6,6 +6,7 @@ mod helpers; mod install; mod settings; mod software; +mod split_tunnel; mod test_metadata; mod tunnel; mod tunnel_state; @@ -32,7 +33,7 @@ pub type TestWrapperFunction = fn( TestContext, ServiceClient, Box<dyn std::any::Any + Send>, -) -> BoxFuture<'static, Result<(), Error>>; +) -> BoxFuture<'static, anyhow::Result<()>>; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -40,7 +41,7 @@ pub enum Error { Rpc(#[from] test_rpc::Error), #[error("geoip lookup failed")] - GeoipLookup(test_rpc::Error), + GeoipLookup(#[source] test_rpc::Error), #[error("Found running daemon unexpectedly")] DaemonRunning, diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs new file mode 100644 index 0000000000..336ee5b5ab --- /dev/null +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -0,0 +1,357 @@ +use anyhow::{anyhow, bail, ensure, Context}; +use mullvad_management_interface::MullvadProxyClient; +use pcap::Direction; +use pnet_packet::ip::IpNextHeaderProtocols; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str, + time::Duration, +}; +use test_macro::test_function; +use test_rpc::{meta::Os, ServiceClient, SpawnOpts}; +use tokio::time::{sleep, timeout}; + +use crate::network_monitor::{start_packet_monitor, MonitorOptions}; + +use super::{config::TEST_CONFIG, helpers, TestContext}; + +const CHECKER_FILENAME_WINDOWS: &str = "connection-checker.exe"; +const CHECKER_FILENAME_UNIX: &str = "connection-checker"; +const LEAK_DESTINATION: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1337); + +/// Test that split tunneling works by asserting the following: +/// - Splitting a process shouldn't do anything if tunnel is not connected. +/// - A split process should never push traffic through the tunnel. +/// - Splitting/unsplitting should work regardless if process is running. +#[test_function(target_os = "linux", target_os = "windows")] +pub async fn test_split_tunnel( + _ctx: TestContext, + rpc: ServiceClient, + mut mullvad_client: MullvadProxyClient, +) -> anyhow::Result<()> { + let mut checker = ConnChecker::new(rpc.clone(), mullvad_client.clone()); + + // Test that program is behaving when we are disconnected + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and unsplit")?; + checker.split().await?; + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and split")?; + checker.unsplit().await?; + + // Test that program is behaving being split/unsplit while running and we are disconnected + let mut handle = checker.spawn().await?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being split while running")?; + handle.unsplit().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being unsplit while running")?; + drop(handle); + + helpers::connect_and_wait(&mut mullvad_client).await?; + + // Test running an unsplit program + checker + .spawn() + .await? + .assert_secure() + .await + .with_context(|| "Test connected and unsplit")?; + + // Test running a split program + checker.split().await?; + checker + .spawn() + .await? + .assert_insecure() + .await + .with_context(|| "Test connected and split")?; + + checker.unsplit().await?; + + // Test splitting and unsplitting a program while it's running + let mut handle = checker.spawn().await?; + (handle.assert_secure().await).with_context(|| "Test connected and unsplit (again)")?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test connected and being split while running")?; + handle.unsplit().await?; + (handle.assert_secure().await) + .with_context(|| "Test connected and being unsplit while running")?; + + Ok(()) +} + +/// This helper spawns a seperate process which checks if we are connected to Mullvad, and tries to +/// leak traffic outside the tunnel by sending TCP, UDP, and ICMP packets to [LEAK_DESTINATION]. +struct ConnChecker { + rpc: ServiceClient, + mullvad_client: MullvadProxyClient, + + /// Path to the process binary. + executable_path: String, + + /// Whether the process should be split when spawned. Needed on Linux. + split: bool, +} + +struct ConnCheckerHandle<'a> { + checker: &'a mut ConnChecker, + + /// ID of the spawned process. + pid: u32, +} + +struct ConnectionStatus { + /// True if <https://am.i.mullvad.net/> reported we are connected. + am_i_mullvad: bool, + + /// True if we sniffed TCP packets going outside the tunnel. + leaked_tcp: bool, + + /// True if we sniffed UDP packets going outside the tunnel. + leaked_udp: bool, + + /// True if we sniffed ICMP packets going outside the tunnel. + leaked_icmp: bool, +} + +impl ConnChecker { + pub fn new(rpc: ServiceClient, mullvad_client: MullvadProxyClient) -> Self { + let artifacts_dir = &TEST_CONFIG.artifacts_dir; + let executable_path = match TEST_CONFIG.os { + Os::Linux | Os::Macos => format!("{artifacts_dir}/{CHECKER_FILENAME_UNIX}"), + Os::Windows => format!("{artifacts_dir}\\{CHECKER_FILENAME_WINDOWS}"), + }; + + Self { + rpc, + mullvad_client, + split: false, + executable_path, + } + } + + /// Spawn the connecton checker process and return a handle to it. + /// + /// Dropping the handle will stop the process. + /// **NOTE**: The handle must be dropped from a tokio runtime context. + pub async fn spawn(&mut self) -> anyhow::Result<ConnCheckerHandle<'_>> { + log::debug!("spawning connection checker"); + + let opts = SpawnOpts { + attach_stdin: true, + attach_stdout: true, + args: [ + "--interactive", + "--timeout", + "10000", + // try to leak traffic to LEAK_DESTINATION + "--leak", + &LEAK_DESTINATION.to_string(), + "--leak-timeout", + "500", + "--leak-tcp", + "--leak-udp", + "--leak-icmp", + ] + .map(String::from) + .to_vec(), + ..SpawnOpts::new(&self.executable_path) + }; + + let pid = self.rpc.spawn(opts).await?; + + if self.split && TEST_CONFIG.os == Os::Linux { + self.mullvad_client + .add_split_tunnel_process(pid as i32) + .await?; + } + + Ok(ConnCheckerHandle { pid, checker: self }) + } + + /// Enable split tunneling for the connection checker. + pub async fn split(&mut self) -> anyhow::Result<()> { + log::debug!("enable split tunnel"); + self.split = true; + + match TEST_CONFIG.os { + Os::Linux => { /* linux programs can't be split until they are spawned */ } + Os::Windows => { + self.mullvad_client + .add_split_tunnel_app(&self.executable_path) + .await?; + self.mullvad_client.set_split_tunnel_state(true).await?; + } + Os::Macos => unimplemented!("MacOS"), + } + + Ok(()) + } + + /// Disable split tunneling for the connection checker. + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + log::debug!("disable split tunnel"); + self.split = false; + + match TEST_CONFIG.os { + Os::Linux => {} + Os::Windows => { + self.mullvad_client.set_split_tunnel_state(false).await?; + self.mullvad_client + .remove_split_tunnel_app(&self.executable_path) + .await?; + } + Os::Macos => unimplemented!("MacOS"), + } + + Ok(()) + } +} + +impl ConnCheckerHandle<'_> { + pub async fn split(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .add_split_tunnel_process(self.pid as i32) + .await?; + } + + self.checker.split().await + } + + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .remove_split_tunnel_process(self.pid as i32) + .await?; + } + + self.checker.unsplit().await + } + + /// Assert that traffic is flowing through the Mullvad tunnel and that no packets are leaked. + pub async fn assert_secure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is secure"); + let status = self.check_connection().await?; + ensure!(status.am_i_mullvad); + ensure!(!status.leaked_tcp); + ensure!(!status.leaked_udp); + ensure!(!status.leaked_icmp); + + Ok(()) + } + + /// Assert that traffic is NOT flowing through the Mullvad tunnel and that packets ARE leaked. + pub async fn assert_insecure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is not secure"); + let status = self.check_connection().await?; + ensure!(!status.am_i_mullvad); + ensure!(status.leaked_tcp); + ensure!(status.leaked_udp); + ensure!(status.leaked_icmp); + + Ok(()) + } + + async fn check_connection(&mut self) -> anyhow::Result<ConnectionStatus> { + // Monitor all pakets going to LEAK_DESTINATION during the check. + let monitor = start_packet_monitor( + |packet| packet.destination.ip() == LEAK_DESTINATION.ip(), + MonitorOptions { + direction: Some(Direction::In), + ..MonitorOptions::default() + }, + ) + .await; + + // Write a newline to the connection checker to prompt it to perform the check. + self.checker + .rpc + .write_child_stdin(self.pid, "Say the line, Bart!\r\n".into()) + .await?; + + // The checker responds when the check is complete. + let line = self.read_stdout_line().await?; + + let monitor_result = monitor + .into_result() + .await + .map_err(|_e| anyhow!("Packet monitor unexpectedly stopped"))?; + + Ok(ConnectionStatus { + am_i_mullvad: parse_am_i_mullvad(line)?, + + leaked_tcp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Tcp), + + leaked_udp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Udp), + + leaked_icmp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Icmp), + }) + } + + /// Try to a single line of output from the spawned process + async fn read_stdout_line(&mut self) -> anyhow::Result<String> { + // Add a timeout to avoid waiting forever. + timeout(Duration::from_secs(8), async { + let mut line = String::new(); + + // tarpc doesn't support streams, so we poll the checker process in a loop instead + loop { + let Some(output) = self.checker.rpc.read_child_stdout(self.pid).await? else { + bail!("got EOF from connection checker process"); + }; + + if output.is_empty() { + sleep(Duration::from_millis(500)).await; + continue; + } + + line.push_str(&output); + + if line.contains('\n') { + log::info!("output from child process: {output:?}"); + return Ok(line); + } + } + }) + .await + .with_context(|| "Timeout reading stdout from connection checker")? + } +} + +impl Drop for ConnCheckerHandle<'_> { + fn drop(&mut self) { + let rpc = self.checker.rpc.clone(); + let pid = self.pid; + + let Ok(runtime_handle) = tokio::runtime::Handle::try_current() else { + log::error!("ConnCheckerHandle dropped outside of a tokio runtime."); + return; + }; + + runtime_handle.spawn(async move { + // Make sure child process is stopped when this handle is dropped. + // Closing stdin does the trick. + let _ = rpc.close_child_stdin(pid).await; + }); + } +} + +/// Parse output from connection-checker. Returns true if connected to Mullvad. +fn parse_am_i_mullvad(result: String) -> anyhow::Result<bool> { + Ok(if result.contains("You are connected") { + true + } else if result.contains("You are not connected") { + false + } else { + bail!("Unexpected output from connection-checker: {result:?}") + }) +} diff --git a/test/test-manager/src/tests/test_metadata.rs b/test/test-manager/src/tests/test_metadata.rs index 3e28a4380b..d4ffa9bfd0 100644 --- a/test/test-manager/src/tests/test_metadata.rs +++ b/test/test-manager/src/tests/test_metadata.rs @@ -5,7 +5,7 @@ use test_rpc::mullvad_daemon::MullvadClientVersion; pub struct TestMetadata { pub name: &'static str, pub command: &'static str, - pub target_os: Option<Os>, + pub targets: &'static [Os], pub mullvad_client_version: MullvadClientVersion, pub func: TestWrapperFunction, pub priority: Option<i32>, @@ -16,9 +16,7 @@ pub struct TestMetadata { impl TestMetadata { pub fn should_run_on_os(&self, os: Os) -> bool { - self.target_os - .map(|target_os| target_os == os) - .unwrap_or(true) + self.targets.is_empty() || self.targets.contains(&os) } } diff --git a/test/test-manager/src/vm/provision.rs b/test/test-manager/src/vm/provision.rs index 5f01e8f192..8667b6c133 100644 --- a/test/test-manager/src/vm/provision.rs +++ b/test/test-manager/src/vm/provision.rs @@ -106,6 +106,11 @@ fn blocking_ssh( ssh_send_file_path(&session, &source, temp_dir) .context("Failed to send test runner to remote")?; + // Transfer connection-checker + let source = local_runner_dir.join("connection-checker"); + ssh_send_file_path(&session, &source, temp_dir) + .context("Failed to send connection-checker to remote")?; + // Transfer app packages ssh_send_file_path(&session, &local_app_manifest.current_app_path, temp_dir) .context("Failed to send current app package to remote")?; diff --git a/test/test-manager/src/vm/qemu.rs b/test/test-manager/src/vm/qemu.rs index 5688f47101..62613d5e1d 100644 --- a/test/test-manager/src/vm/qemu.rs +++ b/test/test-manager/src/vm/qemu.rs @@ -134,7 +134,7 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result<QemuInstance> // Configure OVMF. Currently, this is enabled implicitly if using a TPM let ovmf_handle = if vm_config.tpm { - let handle = OvmfHandle::new().await?; + let handle = OvmfHandle::new(vm_config).await?; handle.append_qemu_args(&mut qemu_cmd); Some(handle) } else { @@ -202,32 +202,50 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result<QemuInstance> /// Used to set up UEFI and append options to the QEMU command struct OvmfHandle { temp_vars: TempFile, + ovmf_code_path: String, } impl OvmfHandle { - pub async fn new() -> Result<Self> { - const OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + pub async fn new(config: &VmConfig) -> Result<Self> { + const DEFAULT_OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + const DEFAULT_OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; + + let ovmf_code_path = config + .ovmf_code_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_CODE_PATH) + .to_owned(); + + let ovmf_vars_path = config + .ovmf_vars_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_VARS_PATH); // Create a local copy of OVMF_VARS let temp_vars_path = random_tempfile_name(); - fs::copy(OVMF_VARS_PATH, &temp_vars_path) + fs::copy(ovmf_vars_path, &temp_vars_path) .await .map_err(Error::CopyOvmfVars)?; let temp_vars = TempFile::from_existing(temp_vars_path, async_tempfile::Ownership::Owned) .await .map_err(|_| Error::WrapOvmfVars)?; - Ok(OvmfHandle { temp_vars }) + + Ok(OvmfHandle { + temp_vars, + ovmf_code_path, + }) } pub fn append_qemu_args(&self, qemu_cmd: &mut Command) { - const OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; - qemu_cmd.args([ "-global", "driver=cfi.pflash01,property=secure,value=on", "-drive", - &format!("if=pflash,format=raw,unit=0,file={OVMF_CODE_PATH},readonly=on"), + &format!( + "if=pflash,format=raw,unit=0,file={},readonly=on", + self.ovmf_code_path + ), "-drive", &format!( "if=pflash,format=raw,unit=1,file={}", diff --git a/test/test-manager/test_macro/Cargo.toml b/test/test-manager/test_macro/Cargo.toml index a064b6d200..19a405d08f 100644 --- a/test/test-manager/test_macro/Cargo.toml +++ b/test/test-manager/test_macro/Cargo.toml @@ -14,3 +14,4 @@ proc-macro = true syn = "1.0" quote = "1.0" proc-macro2 = "1.0" +test-rpc = { path = "../../test-rpc" } diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs index d95c3f8832..7cb8407230 100644 --- a/test/test-manager/test_macro/src/lib.rs +++ b/test/test-manager/test_macro/src/lib.rs @@ -1,6 +1,7 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{AttributeArgs, Lit, Meta, NestedMeta}; +use syn::{AttributeArgs, Lit, Meta, NestedMeta, Result}; +use test_rpc::meta::Os; /// Register an `async` function to be run by `test-manager`. /// @@ -52,7 +53,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -67,7 +68,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -76,7 +77,10 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream let function: syn::ItemFn = syn::parse(code).unwrap(); let attributes = syn::parse_macro_input!(attributes as AttributeArgs); - let test_function = parse_marked_test_function(&attributes, &function); + let test_function = match parse_marked_test_function(&attributes, &function) { + Ok(tf) => tf, + Err(e) => return e.into_compile_error().into(), + }; let register_test = create_test(test_function); @@ -88,73 +92,91 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream .into() } -fn parse_marked_test_function(attributes: &AttributeArgs, function: &syn::ItemFn) -> TestFunction { - let macro_parameters = get_test_macro_parameters(attributes); +/// Shorthand for `return syn::Error::new(...)`. +macro_rules! bail { + ($span:expr, $($tt:tt)*) => {{ + return ::core::result::Result::Err(::syn::Error::new( + ::syn::spanned::Spanned::span(&$span), + ::core::format_args!($($tt)*), + )) + }}; +} - let function_parameters = get_test_function_parameters(&function.sig.inputs); +fn parse_marked_test_function( + attributes: &AttributeArgs, + function: &syn::ItemFn, +) -> Result<TestFunction> { + let macro_parameters = get_test_macro_parameters(attributes)?; + let function_parameters = get_test_function_parameters(&function.sig.inputs)?; - TestFunction { + Ok(TestFunction { name: function.sig.ident.clone(), function_parameters, macro_parameters, - } + }) } -fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> MacroParameters { +fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result<MacroParameters> { let mut priority = None; let mut cleanup = true; let mut always_run = false; let mut must_succeed = false; - let mut target_os = None; + let mut targets = vec![]; for attribute in attributes { - if let NestedMeta::Meta(Meta::NameValue(nv)) = attribute { - if nv.path.is_ident("priority") { - match &nv.lit { - Lit::Int(lit_int) => { - priority = Some(lit_int.base10_parse().unwrap()); - } - _ => panic!("'priority' should have an integer value"), - } - } else if nv.path.is_ident("always_run") { - match &nv.lit { - Lit::Bool(lit_bool) => { - always_run = lit_bool.value(); - } - _ => panic!("'always_run' should have a bool value"), - } - } else if nv.path.is_ident("must_succeed") { - match &nv.lit { - Lit::Bool(lit_bool) => { - must_succeed = lit_bool.value(); - } - _ => panic!("'must_succeed' should have a bool value"), - } - } else if nv.path.is_ident("cleanup") { - match &nv.lit { - Lit::Bool(lit_bool) => { - cleanup = lit_bool.value(); - } - _ => panic!("'cleanup' should have a bool value"), - } - } else if nv.path.is_ident("target_os") { - match &nv.lit { - Lit::Str(lit_str) => { - target_os = Some(lit_str.value()); - } - _ => panic!("'target_os' should have a string value"), - } + // we only use name-value attributes + let NestedMeta::Meta(Meta::NameValue(nv)) = attribute else { + bail!(attribute, "unknown attribute"); + }; + let lit = &nv.lit; + + if nv.path.is_ident("priority") { + match lit { + Lit::Int(lit_int) => priority = Some(lit_int.base10_parse().unwrap()), + _ => bail!(nv, "'priority' should have an integer value"), + } + } else if nv.path.is_ident("always_run") { + match lit { + Lit::Bool(lit_bool) => always_run = lit_bool.value(), + _ => bail!(nv, "'always_run' should have a bool value"), } + } else if nv.path.is_ident("must_succeed") { + match lit { + Lit::Bool(lit_bool) => must_succeed = lit_bool.value(), + _ => bail!(nv, "'must_succeed' should have a bool value"), + } + } else if nv.path.is_ident("cleanup") { + match lit { + Lit::Bool(lit_bool) => cleanup = lit_bool.value(), + _ => bail!(nv, "'cleanup' should have a bool value"), + } + } else if nv.path.is_ident("target_os") { + let Lit::Str(lit_str) = lit else { + bail!(nv, "'target_os' should have a string value"); + }; + + let target = match lit_str.value().parse() { + Ok(os) => os, + Err(e) => bail!(lit_str, "{e}"), + }; + + if targets.contains(&target) { + bail!(nv, "Duplicate target"); + } + + targets.push(target); + } else { + bail!(nv, "unknown attribute"); } } - MacroParameters { + Ok(MacroParameters { priority, cleanup, always_run, must_succeed, - target_os, - } + targets, + }) } fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { @@ -162,17 +184,14 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { Some(priority) => quote! { Some(#priority) }, None => quote! { None }, }; - let target_os = match test_function.macro_parameters.target_os.as_deref() { - Some("linux") => quote! { Some(::test_rpc::meta::Os::Linux) }, - Some("macos") => quote! { Some(::test_rpc::meta::Os::Macos) }, - Some("windows") => quote! { Some(::test_rpc::meta::Os::Windows) }, - Some(target_os) => { - return quote! { - compile_error!("invalid target_os: {:?}", #target_os); - }; - } - None => quote! { None }, - }; + let targets: proc_macro2::TokenStream = (test_function.macro_parameters.targets.iter()) + .map(|&os| match os { + Os::Linux => quote! { ::test_rpc::meta::Os::Linux, }, + Os::Macos => quote! { ::test_rpc::meta::Os::Macos, }, + Os::Windows => quote! { ::test_rpc::meta::Os::Windows, }, + }) + .collect(); + let should_cleanup = test_function.macro_parameters.cleanup; let always_run = test_function.macro_parameters.always_run; let must_succeed = test_function.macro_parameters.must_succeed; @@ -193,7 +212,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { use std::any::Any; let mullvad_client = mullvad_client.downcast::<#mullvad_client_type>().expect("invalid mullvad client"); Box::pin(async move { - #func_name(test_context, rpc, *mullvad_client).await + #func_name(test_context, rpc, *mullvad_client).await.map_err(Into::into) }) } } @@ -202,9 +221,9 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { quote! { |test_context: crate::tests::TestContext, rpc: test_rpc::ServiceClient, - mullvad_client: Box<dyn std::any::Any + Send>| { + _mullvad_client: Box<dyn std::any::Any + Send>| { Box::pin(async move { - #func_name(test_context, rpc).await + #func_name(test_context, rpc).await.map_err(Into::into) }) } } @@ -215,7 +234,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { inventory::submit!(crate::tests::test_metadata::TestMetadata { name: stringify!(#func_name), command: stringify!(#func_name), - target_os: #target_os, + targets: &[#targets], mullvad_client_version: #function_mullvad_version, func: #wrapper_closure, priority: #test_function_priority, @@ -237,7 +256,7 @@ struct MacroParameters { cleanup: bool, always_run: bool, must_succeed: bool, - target_os: Option<String>, + targets: Vec<Os>, } enum MullvadClient { @@ -269,36 +288,38 @@ struct FunctionParameters { } fn get_test_function_parameters( - inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>, -) -> FunctionParameters { - if inputs.len() > 2 { - match inputs[2].clone() { - syn::FnArg::Typed(pat_type) => { - let mullvad_client = match &*pat_type.ty { - syn::Type::Path(syn::TypePath { path, .. }) => { - match path.segments[0].ident.to_string().as_str() { - "mullvad_management_interface" | "MullvadProxyClient" => { - let mullvad_client_version = - quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; - MullvadClient::New { - mullvad_client_type: pat_type.ty, - mullvad_client_version, - } - } - _ => panic!("cannot infer mullvad client type"), - } - } - _ => panic!("unexpected 'mullvad_client' type"), - }; - FunctionParameters { mullvad_client } - } - syn::FnArg::Receiver(_) => panic!("unexpected 'mullvad_client' arg"), - } - } else { - FunctionParameters { + args: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>, +) -> Result<FunctionParameters> { + if args.len() <= 2 { + return Ok(FunctionParameters { mullvad_client: MullvadClient::None { - mullvad_client_version: quote! { test_rpc::mullvad_daemon::MullvadClientVersion::None }, + mullvad_client_version: quote! { + test_rpc::mullvad_daemon::MullvadClientVersion::None + }, }, - } + }); } + + let arg = args[2].clone(); + let syn::FnArg::Typed(pat_type) = arg else { + bail!(arg, "unexpected 'mullvad_client' arg"); + }; + + let syn::Type::Path(syn::TypePath { path, .. }) = &*pat_type.ty else { + bail!(pat_type, "unexpected 'mullvad_client' type"); + }; + + let mullvad_client = match path.segments[0].ident.to_string().as_str() { + "mullvad_management_interface" | "MullvadProxyClient" => { + let mullvad_client_version = + quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; + MullvadClient::New { + mullvad_client_type: pat_type.ty, + mullvad_client_version, + } + } + _ => bail!(pat_type, "cannot infer mullvad client type"), + }; + + Ok(FunctionParameters { mullvad_client }) } |
