summaryrefslogtreecommitdiffhomepage
path: root/test/test-manager
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-03-20 17:00:37 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-03-20 17:00:37 +0100
commita3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3 (patch)
treed7079f9f6775583a0e4da4a061b672372e7af2b3 /test/test-manager
parent7e62d03a4366fb8eaabb13dec354bb0237cd0d08 (diff)
parent6594c6de52763ab313d99edecf9231596a003e1f (diff)
downloadmullvadvpn-a3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3.tar.xz
mullvadvpn-a3a5c67abc4ee719e09d0b4befdc60b7fb2e7ff3.zip
Merge remote-tracking branch 'origin/write-tests-for-split-tunneling-des-276'
Diffstat (limited to 'test/test-manager')
-rw-r--r--test/test-manager/src/config.rs10
-rw-r--r--test/test-manager/src/logging.rs2
-rw-r--r--test/test-manager/src/run_tests.rs6
-rw-r--r--test/test-manager/src/tests/mod.rs5
-rw-r--r--test/test-manager/src/tests/split_tunnel.rs357
-rw-r--r--test/test-manager/src/tests/test_metadata.rs6
-rw-r--r--test/test-manager/src/vm/provision.rs5
-rw-r--r--test/test-manager/src/vm/qemu.rs34
-rw-r--r--test/test-manager/test_macro/Cargo.toml1
-rw-r--r--test/test-manager/test_macro/src/lib.rs211
10 files changed, 523 insertions, 114 deletions
diff --git a/test/test-manager/src/config.rs b/test/test-manager/src/config.rs
index 1605661d53..6921c0b33f 100644
--- a/test/test-manager/src/config.rs
+++ b/test/test-manager/src/config.rs
@@ -139,6 +139,16 @@ pub struct VmConfig {
#[serde(default)]
#[arg(long)]
pub tpm: bool,
+
+ /// Override the path to `OVMF_VARS.secboot.fd`. Requires `tpm`.
+ #[serde(default)]
+ #[arg(long, requires("tpm"))]
+ pub ovmf_vars_path: Option<String>,
+
+ /// Override the path to `OVMF_CODE.secboot.fd`. Requires `tpm`.
+ #[serde(default)]
+ #[arg(long, requires("tpm"))]
+ pub ovmf_code_path: Option<String>,
}
impl VmConfig {
diff --git a/test/test-manager/src/logging.rs b/test/test-manager/src/logging.rs
index cd0bd4af28..e85920b1cd 100644
--- a/test/test-manager/src/logging.rs
+++ b/test/test-manager/src/logging.rs
@@ -1,4 +1,4 @@
-use crate::tests::Error;
+use anyhow::Error;
use colored::Colorize;
use std::sync::{Arc, Mutex};
use test_rpc::logging::{LogOutput, Output};
diff --git a/test/test-manager/src/run_tests.rs b/test/test-manager/src/run_tests.rs
index 6af1536562..6b3da37138 100644
--- a/test/test-manager/src/run_tests.rs
+++ b/test/test-manager/src/run_tests.rs
@@ -2,9 +2,7 @@ use crate::summary::{self, maybe_log_test_result};
use crate::tests::{config::TEST_CONFIG, TestContext};
use crate::{
logging::{panic_as_string, TestOutput},
- mullvad_daemon, tests,
- tests::Error,
- vm,
+ mullvad_daemon, tests, vm,
};
use anyhow::{Context, Result};
use futures::FutureExt;
@@ -187,7 +185,7 @@ pub async fn run_test<F, R, MullvadClient>(
) -> TestOutput
where
F: Fn(super::tests::TestContext, ServiceClient, MullvadClient) -> R,
- R: Future<Output = Result<(), Error>>,
+ R: Future<Output = anyhow::Result<()>>,
{
let _flushed = runner_rpc.try_poll_output().await;
diff --git a/test/test-manager/src/tests/mod.rs b/test/test-manager/src/tests/mod.rs
index 0cf1357696..48d75b9e3f 100644
--- a/test/test-manager/src/tests/mod.rs
+++ b/test/test-manager/src/tests/mod.rs
@@ -6,6 +6,7 @@ mod helpers;
mod install;
mod settings;
mod software;
+mod split_tunnel;
mod test_metadata;
mod tunnel;
mod tunnel_state;
@@ -32,7 +33,7 @@ pub type TestWrapperFunction = fn(
TestContext,
ServiceClient,
Box<dyn std::any::Any + Send>,
-) -> BoxFuture<'static, Result<(), Error>>;
+) -> BoxFuture<'static, anyhow::Result<()>>;
#[derive(thiserror::Error, Debug)]
pub enum Error {
@@ -40,7 +41,7 @@ pub enum Error {
Rpc(#[from] test_rpc::Error),
#[error("geoip lookup failed")]
- GeoipLookup(test_rpc::Error),
+ GeoipLookup(#[source] test_rpc::Error),
#[error("Found running daemon unexpectedly")]
DaemonRunning,
diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs
new file mode 100644
index 0000000000..336ee5b5ab
--- /dev/null
+++ b/test/test-manager/src/tests/split_tunnel.rs
@@ -0,0 +1,357 @@
+use anyhow::{anyhow, bail, ensure, Context};
+use mullvad_management_interface::MullvadProxyClient;
+use pcap::Direction;
+use pnet_packet::ip::IpNextHeaderProtocols;
+use std::{
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ str,
+ time::Duration,
+};
+use test_macro::test_function;
+use test_rpc::{meta::Os, ServiceClient, SpawnOpts};
+use tokio::time::{sleep, timeout};
+
+use crate::network_monitor::{start_packet_monitor, MonitorOptions};
+
+use super::{config::TEST_CONFIG, helpers, TestContext};
+
+const CHECKER_FILENAME_WINDOWS: &str = "connection-checker.exe";
+const CHECKER_FILENAME_UNIX: &str = "connection-checker";
+const LEAK_DESTINATION: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1337);
+
+/// Test that split tunneling works by asserting the following:
+/// - Splitting a process shouldn't do anything if tunnel is not connected.
+/// - A split process should never push traffic through the tunnel.
+/// - Splitting/unsplitting should work regardless if process is running.
+#[test_function(target_os = "linux", target_os = "windows")]
+pub async fn test_split_tunnel(
+ _ctx: TestContext,
+ rpc: ServiceClient,
+ mut mullvad_client: MullvadProxyClient,
+) -> anyhow::Result<()> {
+ let mut checker = ConnChecker::new(rpc.clone(), mullvad_client.clone());
+
+ // Test that program is behaving when we are disconnected
+ (checker.spawn().await?.assert_insecure().await)
+ .with_context(|| "Test disconnected and unsplit")?;
+ checker.split().await?;
+ (checker.spawn().await?.assert_insecure().await)
+ .with_context(|| "Test disconnected and split")?;
+ checker.unsplit().await?;
+
+ // Test that program is behaving being split/unsplit while running and we are disconnected
+ let mut handle = checker.spawn().await?;
+ handle.split().await?;
+ (handle.assert_insecure().await)
+ .with_context(|| "Test disconnected and being split while running")?;
+ handle.unsplit().await?;
+ (handle.assert_insecure().await)
+ .with_context(|| "Test disconnected and being unsplit while running")?;
+ drop(handle);
+
+ helpers::connect_and_wait(&mut mullvad_client).await?;
+
+ // Test running an unsplit program
+ checker
+ .spawn()
+ .await?
+ .assert_secure()
+ .await
+ .with_context(|| "Test connected and unsplit")?;
+
+ // Test running a split program
+ checker.split().await?;
+ checker
+ .spawn()
+ .await?
+ .assert_insecure()
+ .await
+ .with_context(|| "Test connected and split")?;
+
+ checker.unsplit().await?;
+
+ // Test splitting and unsplitting a program while it's running
+ let mut handle = checker.spawn().await?;
+ (handle.assert_secure().await).with_context(|| "Test connected and unsplit (again)")?;
+ handle.split().await?;
+ (handle.assert_insecure().await)
+ .with_context(|| "Test connected and being split while running")?;
+ handle.unsplit().await?;
+ (handle.assert_secure().await)
+ .with_context(|| "Test connected and being unsplit while running")?;
+
+ Ok(())
+}
+
+/// This helper spawns a seperate process which checks if we are connected to Mullvad, and tries to
+/// leak traffic outside the tunnel by sending TCP, UDP, and ICMP packets to [LEAK_DESTINATION].
+struct ConnChecker {
+ rpc: ServiceClient,
+ mullvad_client: MullvadProxyClient,
+
+ /// Path to the process binary.
+ executable_path: String,
+
+ /// Whether the process should be split when spawned. Needed on Linux.
+ split: bool,
+}
+
+struct ConnCheckerHandle<'a> {
+ checker: &'a mut ConnChecker,
+
+ /// ID of the spawned process.
+ pid: u32,
+}
+
+struct ConnectionStatus {
+ /// True if <https://am.i.mullvad.net/> reported we are connected.
+ am_i_mullvad: bool,
+
+ /// True if we sniffed TCP packets going outside the tunnel.
+ leaked_tcp: bool,
+
+ /// True if we sniffed UDP packets going outside the tunnel.
+ leaked_udp: bool,
+
+ /// True if we sniffed ICMP packets going outside the tunnel.
+ leaked_icmp: bool,
+}
+
+impl ConnChecker {
+ pub fn new(rpc: ServiceClient, mullvad_client: MullvadProxyClient) -> Self {
+ let artifacts_dir = &TEST_CONFIG.artifacts_dir;
+ let executable_path = match TEST_CONFIG.os {
+ Os::Linux | Os::Macos => format!("{artifacts_dir}/{CHECKER_FILENAME_UNIX}"),
+ Os::Windows => format!("{artifacts_dir}\\{CHECKER_FILENAME_WINDOWS}"),
+ };
+
+ Self {
+ rpc,
+ mullvad_client,
+ split: false,
+ executable_path,
+ }
+ }
+
+ /// Spawn the connecton checker process and return a handle to it.
+ ///
+ /// Dropping the handle will stop the process.
+ /// **NOTE**: The handle must be dropped from a tokio runtime context.
+ pub async fn spawn(&mut self) -> anyhow::Result<ConnCheckerHandle<'_>> {
+ log::debug!("spawning connection checker");
+
+ let opts = SpawnOpts {
+ attach_stdin: true,
+ attach_stdout: true,
+ args: [
+ "--interactive",
+ "--timeout",
+ "10000",
+ // try to leak traffic to LEAK_DESTINATION
+ "--leak",
+ &LEAK_DESTINATION.to_string(),
+ "--leak-timeout",
+ "500",
+ "--leak-tcp",
+ "--leak-udp",
+ "--leak-icmp",
+ ]
+ .map(String::from)
+ .to_vec(),
+ ..SpawnOpts::new(&self.executable_path)
+ };
+
+ let pid = self.rpc.spawn(opts).await?;
+
+ if self.split && TEST_CONFIG.os == Os::Linux {
+ self.mullvad_client
+ .add_split_tunnel_process(pid as i32)
+ .await?;
+ }
+
+ Ok(ConnCheckerHandle { pid, checker: self })
+ }
+
+ /// Enable split tunneling for the connection checker.
+ pub async fn split(&mut self) -> anyhow::Result<()> {
+ log::debug!("enable split tunnel");
+ self.split = true;
+
+ match TEST_CONFIG.os {
+ Os::Linux => { /* linux programs can't be split until they are spawned */ }
+ Os::Windows => {
+ self.mullvad_client
+ .add_split_tunnel_app(&self.executable_path)
+ .await?;
+ self.mullvad_client.set_split_tunnel_state(true).await?;
+ }
+ Os::Macos => unimplemented!("MacOS"),
+ }
+
+ Ok(())
+ }
+
+ /// Disable split tunneling for the connection checker.
+ pub async fn unsplit(&mut self) -> anyhow::Result<()> {
+ log::debug!("disable split tunnel");
+ self.split = false;
+
+ match TEST_CONFIG.os {
+ Os::Linux => {}
+ Os::Windows => {
+ self.mullvad_client.set_split_tunnel_state(false).await?;
+ self.mullvad_client
+ .remove_split_tunnel_app(&self.executable_path)
+ .await?;
+ }
+ Os::Macos => unimplemented!("MacOS"),
+ }
+
+ Ok(())
+ }
+}
+
+impl ConnCheckerHandle<'_> {
+ pub async fn split(&mut self) -> anyhow::Result<()> {
+ if TEST_CONFIG.os == Os::Linux {
+ self.checker
+ .mullvad_client
+ .add_split_tunnel_process(self.pid as i32)
+ .await?;
+ }
+
+ self.checker.split().await
+ }
+
+ pub async fn unsplit(&mut self) -> anyhow::Result<()> {
+ if TEST_CONFIG.os == Os::Linux {
+ self.checker
+ .mullvad_client
+ .remove_split_tunnel_process(self.pid as i32)
+ .await?;
+ }
+
+ self.checker.unsplit().await
+ }
+
+ /// Assert that traffic is flowing through the Mullvad tunnel and that no packets are leaked.
+ pub async fn assert_secure(&mut self) -> anyhow::Result<()> {
+ log::info!("checking that connection is secure");
+ let status = self.check_connection().await?;
+ ensure!(status.am_i_mullvad);
+ ensure!(!status.leaked_tcp);
+ ensure!(!status.leaked_udp);
+ ensure!(!status.leaked_icmp);
+
+ Ok(())
+ }
+
+ /// Assert that traffic is NOT flowing through the Mullvad tunnel and that packets ARE leaked.
+ pub async fn assert_insecure(&mut self) -> anyhow::Result<()> {
+ log::info!("checking that connection is not secure");
+ let status = self.check_connection().await?;
+ ensure!(!status.am_i_mullvad);
+ ensure!(status.leaked_tcp);
+ ensure!(status.leaked_udp);
+ ensure!(status.leaked_icmp);
+
+ Ok(())
+ }
+
+ async fn check_connection(&mut self) -> anyhow::Result<ConnectionStatus> {
+ // Monitor all pakets going to LEAK_DESTINATION during the check.
+ let monitor = start_packet_monitor(
+ |packet| packet.destination.ip() == LEAK_DESTINATION.ip(),
+ MonitorOptions {
+ direction: Some(Direction::In),
+ ..MonitorOptions::default()
+ },
+ )
+ .await;
+
+ // Write a newline to the connection checker to prompt it to perform the check.
+ self.checker
+ .rpc
+ .write_child_stdin(self.pid, "Say the line, Bart!\r\n".into())
+ .await?;
+
+ // The checker responds when the check is complete.
+ let line = self.read_stdout_line().await?;
+
+ let monitor_result = monitor
+ .into_result()
+ .await
+ .map_err(|_e| anyhow!("Packet monitor unexpectedly stopped"))?;
+
+ Ok(ConnectionStatus {
+ am_i_mullvad: parse_am_i_mullvad(line)?,
+
+ leaked_tcp: (monitor_result.packets.iter())
+ .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Tcp),
+
+ leaked_udp: (monitor_result.packets.iter())
+ .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Udp),
+
+ leaked_icmp: (monitor_result.packets.iter())
+ .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Icmp),
+ })
+ }
+
+ /// Try to a single line of output from the spawned process
+ async fn read_stdout_line(&mut self) -> anyhow::Result<String> {
+ // Add a timeout to avoid waiting forever.
+ timeout(Duration::from_secs(8), async {
+ let mut line = String::new();
+
+ // tarpc doesn't support streams, so we poll the checker process in a loop instead
+ loop {
+ let Some(output) = self.checker.rpc.read_child_stdout(self.pid).await? else {
+ bail!("got EOF from connection checker process");
+ };
+
+ if output.is_empty() {
+ sleep(Duration::from_millis(500)).await;
+ continue;
+ }
+
+ line.push_str(&output);
+
+ if line.contains('\n') {
+ log::info!("output from child process: {output:?}");
+ return Ok(line);
+ }
+ }
+ })
+ .await
+ .with_context(|| "Timeout reading stdout from connection checker")?
+ }
+}
+
+impl Drop for ConnCheckerHandle<'_> {
+ fn drop(&mut self) {
+ let rpc = self.checker.rpc.clone();
+ let pid = self.pid;
+
+ let Ok(runtime_handle) = tokio::runtime::Handle::try_current() else {
+ log::error!("ConnCheckerHandle dropped outside of a tokio runtime.");
+ return;
+ };
+
+ runtime_handle.spawn(async move {
+ // Make sure child process is stopped when this handle is dropped.
+ // Closing stdin does the trick.
+ let _ = rpc.close_child_stdin(pid).await;
+ });
+ }
+}
+
+/// Parse output from connection-checker. Returns true if connected to Mullvad.
+fn parse_am_i_mullvad(result: String) -> anyhow::Result<bool> {
+ Ok(if result.contains("You are connected") {
+ true
+ } else if result.contains("You are not connected") {
+ false
+ } else {
+ bail!("Unexpected output from connection-checker: {result:?}")
+ })
+}
diff --git a/test/test-manager/src/tests/test_metadata.rs b/test/test-manager/src/tests/test_metadata.rs
index 3e28a4380b..d4ffa9bfd0 100644
--- a/test/test-manager/src/tests/test_metadata.rs
+++ b/test/test-manager/src/tests/test_metadata.rs
@@ -5,7 +5,7 @@ use test_rpc::mullvad_daemon::MullvadClientVersion;
pub struct TestMetadata {
pub name: &'static str,
pub command: &'static str,
- pub target_os: Option<Os>,
+ pub targets: &'static [Os],
pub mullvad_client_version: MullvadClientVersion,
pub func: TestWrapperFunction,
pub priority: Option<i32>,
@@ -16,9 +16,7 @@ pub struct TestMetadata {
impl TestMetadata {
pub fn should_run_on_os(&self, os: Os) -> bool {
- self.target_os
- .map(|target_os| target_os == os)
- .unwrap_or(true)
+ self.targets.is_empty() || self.targets.contains(&os)
}
}
diff --git a/test/test-manager/src/vm/provision.rs b/test/test-manager/src/vm/provision.rs
index 5f01e8f192..8667b6c133 100644
--- a/test/test-manager/src/vm/provision.rs
+++ b/test/test-manager/src/vm/provision.rs
@@ -106,6 +106,11 @@ fn blocking_ssh(
ssh_send_file_path(&session, &source, temp_dir)
.context("Failed to send test runner to remote")?;
+ // Transfer connection-checker
+ let source = local_runner_dir.join("connection-checker");
+ ssh_send_file_path(&session, &source, temp_dir)
+ .context("Failed to send connection-checker to remote")?;
+
// Transfer app packages
ssh_send_file_path(&session, &local_app_manifest.current_app_path, temp_dir)
.context("Failed to send current app package to remote")?;
diff --git a/test/test-manager/src/vm/qemu.rs b/test/test-manager/src/vm/qemu.rs
index 5688f47101..62613d5e1d 100644
--- a/test/test-manager/src/vm/qemu.rs
+++ b/test/test-manager/src/vm/qemu.rs
@@ -134,7 +134,7 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result<QemuInstance>
// Configure OVMF. Currently, this is enabled implicitly if using a TPM
let ovmf_handle = if vm_config.tpm {
- let handle = OvmfHandle::new().await?;
+ let handle = OvmfHandle::new(vm_config).await?;
handle.append_qemu_args(&mut qemu_cmd);
Some(handle)
} else {
@@ -202,32 +202,50 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result<QemuInstance>
/// Used to set up UEFI and append options to the QEMU command
struct OvmfHandle {
temp_vars: TempFile,
+ ovmf_code_path: String,
}
impl OvmfHandle {
- pub async fn new() -> Result<Self> {
- const OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd";
+ pub async fn new(config: &VmConfig) -> Result<Self> {
+ const DEFAULT_OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd";
+ const DEFAULT_OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd";
+
+ let ovmf_code_path = config
+ .ovmf_code_path
+ .as_deref()
+ .unwrap_or(DEFAULT_OVMF_CODE_PATH)
+ .to_owned();
+
+ let ovmf_vars_path = config
+ .ovmf_vars_path
+ .as_deref()
+ .unwrap_or(DEFAULT_OVMF_VARS_PATH);
// Create a local copy of OVMF_VARS
let temp_vars_path = random_tempfile_name();
- fs::copy(OVMF_VARS_PATH, &temp_vars_path)
+ fs::copy(ovmf_vars_path, &temp_vars_path)
.await
.map_err(Error::CopyOvmfVars)?;
let temp_vars = TempFile::from_existing(temp_vars_path, async_tempfile::Ownership::Owned)
.await
.map_err(|_| Error::WrapOvmfVars)?;
- Ok(OvmfHandle { temp_vars })
+
+ Ok(OvmfHandle {
+ temp_vars,
+ ovmf_code_path,
+ })
}
pub fn append_qemu_args(&self, qemu_cmd: &mut Command) {
- const OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd";
-
qemu_cmd.args([
"-global",
"driver=cfi.pflash01,property=secure,value=on",
"-drive",
- &format!("if=pflash,format=raw,unit=0,file={OVMF_CODE_PATH},readonly=on"),
+ &format!(
+ "if=pflash,format=raw,unit=0,file={},readonly=on",
+ self.ovmf_code_path
+ ),
"-drive",
&format!(
"if=pflash,format=raw,unit=1,file={}",
diff --git a/test/test-manager/test_macro/Cargo.toml b/test/test-manager/test_macro/Cargo.toml
index a064b6d200..19a405d08f 100644
--- a/test/test-manager/test_macro/Cargo.toml
+++ b/test/test-manager/test_macro/Cargo.toml
@@ -14,3 +14,4 @@ proc-macro = true
syn = "1.0"
quote = "1.0"
proc-macro2 = "1.0"
+test-rpc = { path = "../../test-rpc" }
diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs
index d95c3f8832..7cb8407230 100644
--- a/test/test-manager/test_macro/src/lib.rs
+++ b/test/test-manager/test_macro/src/lib.rs
@@ -1,6 +1,7 @@
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
-use syn::{AttributeArgs, Lit, Meta, NestedMeta};
+use syn::{AttributeArgs, Lit, Meta, NestedMeta, Result};
+use test_rpc::meta::Os;
/// Register an `async` function to be run by `test-manager`.
///
@@ -52,7 +53,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta};
/// pub async fn test_function(
/// rpc: ServiceClient,
/// mut mullvad_client: mullvad_management_interface::MullvadProxyClient,
-/// ) -> Result<(), Error> {
+/// ) -> anyhow::Result<()> {
/// Ok(())
/// }
/// ```
@@ -67,7 +68,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta};
/// pub async fn test_function(
/// rpc: ServiceClient,
/// mut mullvad_client: mullvad_management_interface::MullvadProxyClient,
-/// ) -> Result<(), Error> {
+/// ) -> anyhow::Result<()> {
/// Ok(())
/// }
/// ```
@@ -76,7 +77,10 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream
let function: syn::ItemFn = syn::parse(code).unwrap();
let attributes = syn::parse_macro_input!(attributes as AttributeArgs);
- let test_function = parse_marked_test_function(&attributes, &function);
+ let test_function = match parse_marked_test_function(&attributes, &function) {
+ Ok(tf) => tf,
+ Err(e) => return e.into_compile_error().into(),
+ };
let register_test = create_test(test_function);
@@ -88,73 +92,91 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream
.into()
}
-fn parse_marked_test_function(attributes: &AttributeArgs, function: &syn::ItemFn) -> TestFunction {
- let macro_parameters = get_test_macro_parameters(attributes);
+/// Shorthand for `return syn::Error::new(...)`.
+macro_rules! bail {
+ ($span:expr, $($tt:tt)*) => {{
+ return ::core::result::Result::Err(::syn::Error::new(
+ ::syn::spanned::Spanned::span(&$span),
+ ::core::format_args!($($tt)*),
+ ))
+ }};
+}
- let function_parameters = get_test_function_parameters(&function.sig.inputs);
+fn parse_marked_test_function(
+ attributes: &AttributeArgs,
+ function: &syn::ItemFn,
+) -> Result<TestFunction> {
+ let macro_parameters = get_test_macro_parameters(attributes)?;
+ let function_parameters = get_test_function_parameters(&function.sig.inputs)?;
- TestFunction {
+ Ok(TestFunction {
name: function.sig.ident.clone(),
function_parameters,
macro_parameters,
- }
+ })
}
-fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> MacroParameters {
+fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result<MacroParameters> {
let mut priority = None;
let mut cleanup = true;
let mut always_run = false;
let mut must_succeed = false;
- let mut target_os = None;
+ let mut targets = vec![];
for attribute in attributes {
- if let NestedMeta::Meta(Meta::NameValue(nv)) = attribute {
- if nv.path.is_ident("priority") {
- match &nv.lit {
- Lit::Int(lit_int) => {
- priority = Some(lit_int.base10_parse().unwrap());
- }
- _ => panic!("'priority' should have an integer value"),
- }
- } else if nv.path.is_ident("always_run") {
- match &nv.lit {
- Lit::Bool(lit_bool) => {
- always_run = lit_bool.value();
- }
- _ => panic!("'always_run' should have a bool value"),
- }
- } else if nv.path.is_ident("must_succeed") {
- match &nv.lit {
- Lit::Bool(lit_bool) => {
- must_succeed = lit_bool.value();
- }
- _ => panic!("'must_succeed' should have a bool value"),
- }
- } else if nv.path.is_ident("cleanup") {
- match &nv.lit {
- Lit::Bool(lit_bool) => {
- cleanup = lit_bool.value();
- }
- _ => panic!("'cleanup' should have a bool value"),
- }
- } else if nv.path.is_ident("target_os") {
- match &nv.lit {
- Lit::Str(lit_str) => {
- target_os = Some(lit_str.value());
- }
- _ => panic!("'target_os' should have a string value"),
- }
+ // we only use name-value attributes
+ let NestedMeta::Meta(Meta::NameValue(nv)) = attribute else {
+ bail!(attribute, "unknown attribute");
+ };
+ let lit = &nv.lit;
+
+ if nv.path.is_ident("priority") {
+ match lit {
+ Lit::Int(lit_int) => priority = Some(lit_int.base10_parse().unwrap()),
+ _ => bail!(nv, "'priority' should have an integer value"),
+ }
+ } else if nv.path.is_ident("always_run") {
+ match lit {
+ Lit::Bool(lit_bool) => always_run = lit_bool.value(),
+ _ => bail!(nv, "'always_run' should have a bool value"),
}
+ } else if nv.path.is_ident("must_succeed") {
+ match lit {
+ Lit::Bool(lit_bool) => must_succeed = lit_bool.value(),
+ _ => bail!(nv, "'must_succeed' should have a bool value"),
+ }
+ } else if nv.path.is_ident("cleanup") {
+ match lit {
+ Lit::Bool(lit_bool) => cleanup = lit_bool.value(),
+ _ => bail!(nv, "'cleanup' should have a bool value"),
+ }
+ } else if nv.path.is_ident("target_os") {
+ let Lit::Str(lit_str) = lit else {
+ bail!(nv, "'target_os' should have a string value");
+ };
+
+ let target = match lit_str.value().parse() {
+ Ok(os) => os,
+ Err(e) => bail!(lit_str, "{e}"),
+ };
+
+ if targets.contains(&target) {
+ bail!(nv, "Duplicate target");
+ }
+
+ targets.push(target);
+ } else {
+ bail!(nv, "unknown attribute");
}
}
- MacroParameters {
+ Ok(MacroParameters {
priority,
cleanup,
always_run,
must_succeed,
- target_os,
- }
+ targets,
+ })
}
fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream {
@@ -162,17 +184,14 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream {
Some(priority) => quote! { Some(#priority) },
None => quote! { None },
};
- let target_os = match test_function.macro_parameters.target_os.as_deref() {
- Some("linux") => quote! { Some(::test_rpc::meta::Os::Linux) },
- Some("macos") => quote! { Some(::test_rpc::meta::Os::Macos) },
- Some("windows") => quote! { Some(::test_rpc::meta::Os::Windows) },
- Some(target_os) => {
- return quote! {
- compile_error!("invalid target_os: {:?}", #target_os);
- };
- }
- None => quote! { None },
- };
+ let targets: proc_macro2::TokenStream = (test_function.macro_parameters.targets.iter())
+ .map(|&os| match os {
+ Os::Linux => quote! { ::test_rpc::meta::Os::Linux, },
+ Os::Macos => quote! { ::test_rpc::meta::Os::Macos, },
+ Os::Windows => quote! { ::test_rpc::meta::Os::Windows, },
+ })
+ .collect();
+
let should_cleanup = test_function.macro_parameters.cleanup;
let always_run = test_function.macro_parameters.always_run;
let must_succeed = test_function.macro_parameters.must_succeed;
@@ -193,7 +212,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream {
use std::any::Any;
let mullvad_client = mullvad_client.downcast::<#mullvad_client_type>().expect("invalid mullvad client");
Box::pin(async move {
- #func_name(test_context, rpc, *mullvad_client).await
+ #func_name(test_context, rpc, *mullvad_client).await.map_err(Into::into)
})
}
}
@@ -202,9 +221,9 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream {
quote! {
|test_context: crate::tests::TestContext,
rpc: test_rpc::ServiceClient,
- mullvad_client: Box<dyn std::any::Any + Send>| {
+ _mullvad_client: Box<dyn std::any::Any + Send>| {
Box::pin(async move {
- #func_name(test_context, rpc).await
+ #func_name(test_context, rpc).await.map_err(Into::into)
})
}
}
@@ -215,7 +234,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream {
inventory::submit!(crate::tests::test_metadata::TestMetadata {
name: stringify!(#func_name),
command: stringify!(#func_name),
- target_os: #target_os,
+ targets: &[#targets],
mullvad_client_version: #function_mullvad_version,
func: #wrapper_closure,
priority: #test_function_priority,
@@ -237,7 +256,7 @@ struct MacroParameters {
cleanup: bool,
always_run: bool,
must_succeed: bool,
- target_os: Option<String>,
+ targets: Vec<Os>,
}
enum MullvadClient {
@@ -269,36 +288,38 @@ struct FunctionParameters {
}
fn get_test_function_parameters(
- inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
-) -> FunctionParameters {
- if inputs.len() > 2 {
- match inputs[2].clone() {
- syn::FnArg::Typed(pat_type) => {
- let mullvad_client = match &*pat_type.ty {
- syn::Type::Path(syn::TypePath { path, .. }) => {
- match path.segments[0].ident.to_string().as_str() {
- "mullvad_management_interface" | "MullvadProxyClient" => {
- let mullvad_client_version =
- quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New };
- MullvadClient::New {
- mullvad_client_type: pat_type.ty,
- mullvad_client_version,
- }
- }
- _ => panic!("cannot infer mullvad client type"),
- }
- }
- _ => panic!("unexpected 'mullvad_client' type"),
- };
- FunctionParameters { mullvad_client }
- }
- syn::FnArg::Receiver(_) => panic!("unexpected 'mullvad_client' arg"),
- }
- } else {
- FunctionParameters {
+ args: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
+) -> Result<FunctionParameters> {
+ if args.len() <= 2 {
+ return Ok(FunctionParameters {
mullvad_client: MullvadClient::None {
- mullvad_client_version: quote! { test_rpc::mullvad_daemon::MullvadClientVersion::None },
+ mullvad_client_version: quote! {
+ test_rpc::mullvad_daemon::MullvadClientVersion::None
+ },
},
- }
+ });
}
+
+ let arg = args[2].clone();
+ let syn::FnArg::Typed(pat_type) = arg else {
+ bail!(arg, "unexpected 'mullvad_client' arg");
+ };
+
+ let syn::Type::Path(syn::TypePath { path, .. }) = &*pat_type.ty else {
+ bail!(pat_type, "unexpected 'mullvad_client' type");
+ };
+
+ let mullvad_client = match path.segments[0].ident.to_string().as_str() {
+ "mullvad_management_interface" | "MullvadProxyClient" => {
+ let mullvad_client_version =
+ quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New };
+ MullvadClient::New {
+ mullvad_client_type: pat_type.ty,
+ mullvad_client_version,
+ }
+ }
+ _ => bail!(pat_type, "cannot infer mullvad client type"),
+ };
+
+ Ok(FunctionParameters { mullvad_client })
}