diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2025-04-28 16:53:04 +0200 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2025-05-28 13:25:32 +0200 |
| commit | 9ca3e7493077c9d3ddff98c657d18a2501701aec (patch) | |
| tree | 3d2e5a098edb92312cd2b22fa868ba4cb7c8d442 | |
| parent | bc56796a0a7726392e2c69317b1a519db2f6d988 (diff) | |
| download | mullvadvpn-9ca3e7493077c9d3ddff98c657d18a2501701aec.tar.xz mullvadvpn-9ca3e7493077c9d3ddff98c657d18a2501701aec.zip | |
Add tests for in app upgrade tests in daemon (#8015)
Also add check for metadata version
| -rw-r--r-- | installer-downloader/src/controller.rs | 2 | ||||
| -rw-r--r-- | mullvad-api/src/version.rs | 12 | ||||
| -rw-r--r-- | mullvad-daemon/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 31 | ||||
| -rw-r--r-- | mullvad-daemon/src/version/check.rs | 182 | ||||
| -rw-r--r-- | mullvad-daemon/src/version/downloader.rs | 99 | ||||
| -rw-r--r-- | mullvad-daemon/src/version/router.rs | 809 | ||||
| -rw-r--r-- | mullvad-update/mullvad-release/src/main.rs | 3 | ||||
| -rw-r--r-- | mullvad-update/mullvad-release/src/platform.rs | 9 | ||||
| -rw-r--r-- | mullvad-update/src/client/app.rs | 31 | ||||
| -rw-r--r-- | mullvad-update/src/version.rs | 3 |
11 files changed, 888 insertions, 295 deletions
diff --git a/installer-downloader/src/controller.rs b/installer-downloader/src/controller.rs index a279498e27..a5e2d6aabe 100644 --- a/installer-downloader/src/controller.rs +++ b/installer-downloader/src/controller.rs @@ -146,7 +146,7 @@ where // For the downloader, the rollout version is always preferred rollout: mullvad_update::version::IGNORE, // The downloader allows any version - lowest_metadata_version: 0, + lowest_metadata_version: mullvad_update::version::MIN_VERIFY_METADATA_VERSION, }; let err = match version_provider.get_version_info(version_params).await { diff --git a/mullvad-api/src/version.rs b/mullvad-api/src/version.rs index 78f5ded1f9..490fc5b5fd 100644 --- a/mullvad-api/src/version.rs +++ b/mullvad-api/src/version.rs @@ -57,7 +57,7 @@ impl AppVersionProxy { architecture: mullvad_update::format::Architecture, rollout: f32, lowest_metadata_version: usize, - ) -> impl Future<Output = Result<VersionInfo, rest::Error>> + use<> { + ) -> impl Future<Output = Result<(VersionInfo, usize), rest::Error>> + use<> { let service = self.handle.service.clone(); let path = format!("app/releases/{platform}.json"); let request = self.handle.factory.get(&path); @@ -79,9 +79,13 @@ impl AppVersionProxy { lowest_metadata_version, }; - VersionInfo::try_from_response(¶ms, response.signed) - .map_err(Arc::new) - .map_err(rest::Error::FetchVersions) + let metadata_version = response.signed.metadata_version; + Ok(( + VersionInfo::try_from_response(¶ms, response.signed) + .map_err(Arc::new) + .map_err(rest::Error::FetchVersions)?, + metadata_version, + )) } } } diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml index c7a25031fa..07a2a50a4d 100644 --- a/mullvad-daemon/Cargo.toml +++ b/mullvad-daemon/Cargo.toml @@ -58,8 +58,8 @@ talpid-time = { path = "../talpid-time", features = ["test"] } tokio = { workspace = true, features = ["test-util"] } [target.'cfg(target_os="android")'.dependencies] -android_logger = "0.8" async-trait = "0.1" +android_logger = "0.8" hickory-resolver = { workspace = true } [target.'cfg(unix)'.dependencies] diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 60832c708e..7a9a2d36ae 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -900,7 +900,7 @@ impl Daemon { on_relay_list_update, ); - let version_handle = version::router::VersionRouter::spawn( + let version_handle = version::router::spawn_version_router( api_handle.clone(), api_handle.availability.clone(), config.cache_dir.clone(), @@ -3228,14 +3228,35 @@ impl Daemon { Self::oneshot_send(tx, (), "on_toggle_relay response"); } + #[cfg_attr(not(in_app_upgrade), allow(clippy::unused_async))] async fn on_app_upgrade(&self, tx: ResponseTx<(), version::Error>) { - let result = self.version_handle.update_application().await; - Self::oneshot_send(tx, result, "on_app_upgrade response"); + #[cfg(update)] + { + let result = self.version_handle.update_application().await; + Self::oneshot_send(tx, result, "on_app_upgrade response"); + } + #[cfg(not(update))] + { + log::warn!("Ignoring app upgrade command as in-app upgrades are disabled on this OS"); + Self::oneshot_send(tx, Ok(()), "on_app_upgrade response") + }; } + #[cfg_attr(not(in_app_upgrade), allow(clippy::unused_async))] async fn on_app_upgrade_abort(&self, tx: ResponseTx<(), version::Error>) { - let result = self.version_handle.cancel_update().await; - Self::oneshot_send(tx, result, "on_app_upgrade_abort response"); + #[cfg(update)] + { + let result = self.version_handle.cancel_update().await; + Self::oneshot_send(tx, result, "on_app_upgrade_abort response"); + } + #[cfg(not(update))] + { + log::warn!( + "Ignoring cancel app upgrade command as in-app upgrades are disabled on this OS" + ); + + Self::oneshot_send(tx, Ok(()), "on_app_upgrade_abort response") + }; } /// Set the target state of the client. If it changed trigger the operations needed to diff --git a/mullvad-daemon/src/version/check.rs b/mullvad-daemon/src/version/check.rs index e6ebff9920..e90cf7e431 100644 --- a/mullvad-daemon/src/version/check.rs +++ b/mullvad-daemon/src/version/check.rs @@ -55,6 +55,8 @@ pub(super) struct VersionCache { pub current_version_supported: bool, /// The latest available versions pub latest_version: mullvad_update::version::VersionInfo, + #[cfg(update)] + pub metadata_version: usize, } pub(crate) struct VersionUpdater(()); @@ -67,40 +69,17 @@ struct VersionUpdaterInner { get_version_info_responders: Vec<oneshot::Sender<VersionCache>>, } -type VersionUpdateCommand = oneshot::Sender<VersionCache>; - -#[derive(Clone)] -pub(crate) struct VersionUpdaterHandle { - tx: mpsc::UnboundedSender<VersionUpdateCommand>, -} - -impl VersionUpdaterHandle { - /// Get the latest cached [AppVersionInfo]. - /// - /// If the cache is stale or missing, this will immediately query the API for the latest - /// version. This may take a few seconds. - pub(super) async fn get_version_info(&self) -> Result<VersionCache, Error> { - let (done_tx, done_rx) = oneshot::channel(); - if self.tx.unbounded_send(done_tx).is_err() { - Err(Error::VersionUpdaterDown) - } else { - done_rx.await.map_err(|_| Error::UpdateAborted) - } - } -} - impl VersionUpdater { pub(super) async fn spawn( mut api_handle: MullvadRestHandle, availability_handle: ApiAvailability, cache_dir: PathBuf, update_sender: mpsc::UnboundedSender<VersionCache>, - ) -> VersionUpdaterHandle { + refresh_rx: mpsc::UnboundedReceiver<()>, + ) { // load the last known AppVersionInfo from cache let last_app_version_info = load_cache(&cache_dir).await; - let (tx, rx) = mpsc::unbounded(); - api_handle.factory = api_handle.factory.default_timeout(DOWNLOAD_TIMEOUT); let version_proxy = AppVersionProxy::new(api_handle); let cache_path = cache_dir.join(VERSION_INFO_FILENAME); @@ -112,7 +91,7 @@ impl VersionUpdater { get_version_info_responders: vec![], } .run( - rx, + refresh_rx, UpdateContext { cache_path, update_sender, @@ -124,8 +103,6 @@ impl VersionUpdater { }, ), ); - - VersionUpdaterHandle { tx } } } @@ -135,13 +112,38 @@ impl VersionUpdaterInner { self.last_app_version_info.as_ref().map(|(info, _)| info) } + #[cfg(update)] + pub fn get_min_metadata_version(&self) -> usize { + self.last_app_version_info + .as_ref() + // Reject version responses with a lower metadata version + // than the newest version we know about. This is + // important to prevent downgrade attacks. + .map(|(info, _)| info.metadata_version) + .unwrap_or(mullvad_update::version::MIN_VERIFY_METADATA_VERSION) + } + + #[cfg(not(update))] + pub fn get_min_metadata_version(&self) -> usize { + mullvad_update::version::MIN_VERIFY_METADATA_VERSION + } + /// Update [Self::last_app_version_info] and write it to disk cache, and notify the `update` /// callback. + #[allow(unused_mut)] async fn update_version_info( &mut self, update: &impl Fn(VersionCache) -> BoxFuture<'static, Result<(), Error>>, - new_version_info: VersionCache, + mut new_version_info: VersionCache, ) { + #[cfg(update)] + if let Some((current_cache, _)) = self.last_app_version_info.as_ref() { + if current_cache.metadata_version == new_version_info.metadata_version { + log::trace!("Ignoring version info with same metadata version"); + new_version_info = current_cache.clone(); + } + } + if let Err(err) = update(new_version_info.clone()).await { log::error!("Failed to save version cache to disk: {}", err); } @@ -191,55 +193,66 @@ impl VersionUpdaterInner { async fn run( self, - mut rx: mpsc::UnboundedReceiver<VersionUpdateCommand>, + mut refresh_rx: mpsc::UnboundedReceiver<()>, update: UpdateContext, api: ApiContext, ) { // If this is a dev build, there's no need to pester the API for version checks. if *IS_DEV_BUILD { log::warn!("Not checking for updates because this is a development build"); - while let Some(done_tx) = rx.next().await { + while let Some(()) = refresh_rx.next().await { log::info!("Version check is disabled in dev builds"); - let _ = done_tx.send(dev_version_cache()); } return; } let update = |info| Box::pin(update.update(info)) as BoxFuture<'static, _>; - let do_version_check = || do_version_check(api.clone()); - let do_version_check_in_background = || do_version_check_in_background(api.clone()); + let do_version_check = + |min_metadata_version| do_version_check(api.clone(), min_metadata_version); + let do_version_check_in_background = |min_metadata_version| { + do_version_check_in_background(api.clone(), min_metadata_version) + }; - self.run_inner(rx, update, do_version_check, do_version_check_in_background) - .await + self.run_inner( + refresh_rx, + update, + do_version_check, + do_version_check_in_background, + ) + .await } async fn run_inner( mut self, - mut rx: mpsc::UnboundedReceiver<VersionUpdateCommand>, + mut refresh_rx: mpsc::UnboundedReceiver<()>, update: impl Fn(VersionCache) -> BoxFuture<'static, Result<(), Error>>, - do_version_check: impl Fn() -> BoxFuture<'static, Result<VersionCache, Error>>, - do_version_check_in_background: impl Fn() -> BoxFuture<'static, Result<VersionCache, Error>>, + do_version_check: impl Fn(usize) -> BoxFuture<'static, Result<VersionCache, Error>>, + do_version_check_in_background: impl Fn( + usize, + ) + -> BoxFuture<'static, Result<VersionCache, Error>>, ) { let mut version_is_stale = self.wait_until_version_is_stale(); let mut version_check = futures::future::Fuse::terminated(); loop { futures::select! { - command = rx.next() => match command { + command = refresh_rx.next() => match command { - Some(done_tx) => { + Some(()) => { match (self.version_is_stale(), self.last_app_version_info()) { - (false, Some(version_info)) => { + (false, Some(version_cache)) => { // if the version_info isn't stale, return it immediately. - let _ = done_tx.send(version_info.clone()); + if let Err(err) = update(version_cache.clone()).await { + log::error!("Failed to save version cache to disk: {}", err); + } } _ => { // otherwise, start a foreground query to get the latest version_info. if !self.is_running_version_check() { - version_check = do_version_check().fuse(); + version_check = do_version_check(self.get_min_metadata_version()).fuse(); } - self.get_version_info_responders.retain(|r| !r.is_canceled()); - self.get_version_info_responders.push(done_tx); + } } } @@ -254,23 +267,17 @@ impl VersionUpdaterInner { if self.is_running_version_check() { continue; } - version_check = do_version_check_in_background().fuse(); + version_check = do_version_check_in_background(self.get_min_metadata_version()).fuse(); }, response = version_check => { match response { Ok(version_info) => { - // Respond to all pending GetVersionInfo commands - for done_tx in self.get_version_info_responders.drain(..) { - let _ = done_tx.send(version_info.clone()); - } - self.update_version_info(&update, version_info).await; } Err(err) => { log::error!("Failed to fetch version info: {err:#}"); - self.get_version_info_responders.clear(); } } @@ -314,10 +321,13 @@ struct ApiContext { } /// Immediately query the API for the latest [AppVersionInfo]. -fn do_version_check(api: ApiContext) -> BoxFuture<'static, Result<VersionCache, Error>> { +fn do_version_check( + api: ApiContext, + min_metadata_version: usize, +) -> BoxFuture<'static, Result<VersionCache, Error>> { let api_handle = api.api_handle.clone(); - let download_future_factory = move || version_check_inner(&api); + let download_future_factory = move || version_check_inner(&api, min_metadata_version); // retry immediately on network errors (unless we're offline) let should_retry_immediate = move |result: &Result<_, Error>| { @@ -340,10 +350,11 @@ fn do_version_check(api: ApiContext) -> BoxFuture<'static, Result<VersionCache, /// On any error, this function retries repeatedly every [UPDATE_INTERVAL_ERROR] until success. fn do_version_check_in_background( api: ApiContext, + min_metadata_version: usize, ) -> BoxFuture<'static, Result<VersionCache, Error>> { let download_future_factory = move || { let when_available = api.api_handle.wait_background(); - let version_cache = version_check_inner(&api); + let version_cache = version_check_inner(&api, min_metadata_version); async move { when_available.await.map_err(Error::ApiCheck)?; version_cache.await @@ -358,34 +369,49 @@ fn do_version_check_in_background( } /// Combine the old version and new version endpoint -#[cfg(any(target_os = "windows", target_os = "macos"))] -fn version_check_inner(api: &ApiContext) -> impl Future<Output = Result<VersionCache, Error>> { +#[cfg(update)] +fn version_check_inner( + api: &ApiContext, + min_metadata_version: usize, +) -> impl Future<Output = Result<VersionCache, Error>> { let v1_endpoint = api.version_proxy.version_check( mullvad_version::VERSION.to_owned(), PLATFORM, api.platform_version.clone(), ); + + let architecture = match talpid_platform_metadata::get_native_arch() + .expect("IO error while getting native architecture") + .expect("Failed to get native architecture") + { + talpid_platform_metadata::Architecture::X86 => mullvad_update::format::Architecture::X86, + talpid_platform_metadata::Architecture::Arm64 => { + mullvad_update::format::Architecture::Arm64 + } + }; let v2_endpoint = api.version_proxy.version_check_2( PLATFORM, - // TODO: get current architecture (from talpid_platform_metadata) - mullvad_update::format::Architecture::X86, - // TODO: set reasonable rollout, - 0., - // TODO: set last known metadata version + 1 - 0, + architecture, + mullvad_update::version::IGNORE, + min_metadata_version, ); async move { let (v1_response, v2_response) = tokio::try_join!(v1_endpoint, v2_endpoint).map_err(Error::Download)?; Ok(VersionCache { current_version_supported: v1_response.supported, - latest_version: v2_response, + latest_version: v2_response.0, + metadata_version: v2_response.1, }) } } -#[cfg(any(target_os = "linux", target_os = "android"))] -fn version_check_inner(api: &ApiContext) -> impl Future<Output = Result<VersionCache, Error>> { +#[cfg(not(update))] +fn version_check_inner( + api: &ApiContext, + // NOTE: This is unused when `update` is disabled + _min_metadata_version: usize, +) -> impl Future<Output = Result<VersionCache, Error>> { let v1_endpoint = api.version_proxy.version_check( mullvad_version::VERSION.to_owned(), PLATFORM, @@ -507,6 +533,8 @@ fn dev_version_cache() -> VersionCache { }, beta: None, }, + #[cfg(update)] + metadata_version: 0, } } @@ -687,17 +715,17 @@ mod test { updated.store(false, Ordering::SeqCst); - // The next request should do nothing + // The next request should trigger an update, even if the version has not changed send_version_request(&mut tx).await.unwrap(); talpid_time::sleep(Duration::from_secs(1)).await; - assert!(!updated.load(Ordering::SeqCst), "expected cached version"); + assert!(updated.load(Ordering::SeqCst), "expected cached version"); } async fn send_version_request( - tx: &mut mpsc::UnboundedSender<VersionUpdateCommand>, + tx: &mut mpsc::UnboundedSender<()>, ) -> Result<(), futures::channel::mpsc::SendError> { - let (done_tx, _done_rx) = oneshot::channel(); - tx.send(done_tx).await + tx.send(()).await?; + Ok(()) } fn fake_updater( @@ -709,11 +737,15 @@ mod test { } } - fn fake_version_check() -> BoxFuture<'static, Result<VersionCache, Error>> { + fn fake_version_check( + _min_metadata_version: usize, + ) -> BoxFuture<'static, Result<VersionCache, Error>> { Box::pin(async { Ok(fake_version_response()) }) } - fn fake_version_check_err() -> BoxFuture<'static, Result<VersionCache, Error>> { + fn fake_version_check_err( + _min_metadata_version: usize, + ) -> BoxFuture<'static, Result<VersionCache, Error>> { Box::pin(retry_future( || async { Err(Error::Download(mullvad_api::rest::Error::TimeoutError)) }, |_| true, @@ -735,6 +767,8 @@ mod test { }, beta: None, }, + #[cfg(update)] + metadata_version: 0, } } } diff --git a/mullvad-daemon/src/version/downloader.rs b/mullvad-daemon/src/version/downloader.rs index f7cd2cc5cd..b2b7146f0d 100644 --- a/mullvad-daemon/src/version/downloader.rs +++ b/mullvad-daemon/src/version/downloader.rs @@ -1,12 +1,11 @@ #![cfg(update)] use mullvad_types::version::{AppUpgradeDownloadProgress, AppUpgradeError, AppUpgradeEvent}; -use mullvad_update::app::{ - AppDownloader, AppDownloaderParameters, DownloadError, HttpAppDownloader, -}; +use mullvad_update::app::{bin_path, AppDownloader, AppDownloaderParameters, DownloadError}; use rand::seq::SliceRandom; use std::path::PathBuf; use std::time::{Duration, Instant}; +use talpid_types::ErrorExt; use tokio::fs; use tokio::sync::broadcast; @@ -28,7 +27,7 @@ pub enum Error { NoUrlFound, } -type Result<T> = std::result::Result<T, Error>; +pub type Result<T> = std::result::Result<T, Error>; #[derive(Debug)] pub struct DownloaderHandle { @@ -48,41 +47,58 @@ impl Drop for DownloaderHandle { } } -impl DownloaderHandle { - /// Wait for the downloader to finish - pub async fn wait(&mut self) -> Result<PathBuf> { - let path = (&mut self.task).await?; +impl std::future::Future for DownloaderHandle { + type Output = Result<PathBuf>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Self::Output> { + let task = std::pin::Pin::new(&mut self.task); + let ready = futures::ready!(task.poll(cx))?; self.dropped_tx = None; // Prevent sending the aborted event after successful download - path + std::task::Poll::Ready(ready) } } -pub fn spawn_downloader( +pub fn spawn_downloader<D>( version: mullvad_update::version::Version, event_tx: broadcast::Sender<AppUpgradeEvent>, -) -> DownloaderHandle { +) -> DownloaderHandle +where + D: AppDownloader + Send + 'static, + D: From<AppDownloaderParameters<ProgressUpdater>>, +{ DownloaderHandle { - task: tokio::spawn(start(version, event_tx.clone())), + task: tokio::spawn(start::<D>(version, event_tx.clone())), dropped_tx: Some(event_tx), } } /// Begin or resume download of `version` -async fn start( +async fn start<D>( version: mullvad_update::version::Version, event_tx: broadcast::Sender<AppUpgradeEvent>, -) -> Result<PathBuf> { +) -> Result<PathBuf> +where + D: AppDownloader + Send + 'static, + D: From<AppDownloaderParameters<ProgressUpdater>>, +{ let url = select_cdn_url(&version.urls) .ok_or(Error::NoUrlFound)? .to_owned(); log::info!("Downloading app version '{}' from {url}", version.version); - let download_dir = mullvad_paths::cache_dir()?.join("mullvad-update"); - log::trace!("Download directory: {download_dir:?}"); - fs::create_dir_all(&download_dir) - .await - .map_err(Error::CreateDownloadDir)?; + let download_dir = if cfg!(test) { + PathBuf::new() + } else { + create_download_dir().await.inspect_err(|err| { + log::error!("Failed to get download directory: {}", err.display_chain()); + let _ = event_tx.send(AppUpgradeEvent::Error(AppUpgradeError::GeneralError)); + })? + }; + let bin_path = bin_path(&version.version, &download_dir); let params = AppDownloaderParameters { app_version: version.version, @@ -92,27 +108,48 @@ async fn start( app_sha256: version.sha256, cache_dir: download_dir, }; - let mut downloader = HttpAppDownloader::from(params); + let mut downloader = D::from(params); - if let Err(download_err) = downloader.download_executable().await { - log::error!("Failed to download app: {download_err}"); + let _ = event_tx.send(AppUpgradeEvent::DownloadStarting); + if let Err(err) = downloader.download_executable().await { let _ = event_tx.send(AppUpgradeEvent::Error(AppUpgradeError::DownloadFailed)); - return Err(download_err.into()); + log::error!("{}", err.display_chain()); + log::info!("Cleaning up download at '{bin_path:?}'",); + #[cfg(not(test))] + tokio::fs::remove_file(&bin_path) + .await + .expect("Removing download file"); + return Err(err.into()); }; - let _ = event_tx.send(AppUpgradeEvent::VerifyingInstaller); - - if let Err(verify_err) = downloader.verify().await { - log::error!("Failed to verify downloaded app: {verify_err}"); + if let Err(err) = downloader.verify().await { let _ = event_tx.send(AppUpgradeEvent::Error(AppUpgradeError::VerificationFailed)); - return Err(verify_err.into()); + log::error!("{}", err.display_chain()); + log::info!("Cleaning up download at '{:?}'", bin_path); + #[cfg(not(test))] + tokio::fs::remove_file(&bin_path) + .await + .expect("Removing download file"); + return Err(err.into()); }; - let _ = event_tx.send(AppUpgradeEvent::VerifiedInstaller); - Ok(downloader.bin_path()) + + // Note that we cannot call `downloader.install()` here, as it must be done by the user process. + // Instead, the GUI is responsible for launching the installer. + + Ok(bin_path) +} + +async fn create_download_dir() -> Result<PathBuf> { + let download_dir = mullvad_paths::cache_dir()?.join("mullvad-update"); + log::trace!("Download directory: {download_dir:?}"); + fs::create_dir_all(&download_dir) + .await + .map_err(Error::CreateDownloadDir)?; + Ok(download_dir) } -struct ProgressUpdater { +pub struct ProgressUpdater { server: String, event_tx: broadcast::Sender<AppUpgradeEvent>, complete_frac: f32, diff --git a/mullvad-daemon/src/version/router.rs b/mullvad-daemon/src/version/router.rs index 6e9eec9f45..7640122c54 100644 --- a/mullvad-daemon/src/version/router.rs +++ b/mullvad-daemon/src/version/router.rs @@ -1,21 +1,22 @@ -use std::future::Future; +use std::ops::ControlFlow; use std::path::PathBuf; -use std::pin::Pin; use futures::channel::{mpsc, oneshot}; -use futures::future::{Fuse, FusedFuture}; use futures::stream::StreamExt; -use futures::FutureExt; use mullvad_api::{availability::ApiAvailability, rest::MullvadRestHandle}; use mullvad_types::version::{AppVersionInfo, SuggestedUpgrade}; +#[cfg(update)] +use mullvad_update::app::{AppDownloader, AppDownloaderParameters, HttpAppDownloader}; use mullvad_update::version::VersionInfo; use talpid_core::mpsc::Sender; use crate::management_interface::AppUpgradeBroadcast; use crate::DaemonEventSender; +#[cfg(update)] +use super::downloader::ProgressUpdater; use super::{ - check::{self, VersionCache, VersionUpdater}, + check::{VersionCache, VersionUpdater}, Error, }; @@ -23,7 +24,7 @@ use super::{ use super::downloader; use std::mem; -type Result<T> = std::result::Result<T, Error>; +pub type Result<T> = std::result::Result<T, Error>; #[derive(Clone)] pub struct VersionRouterHandle { @@ -66,29 +67,45 @@ impl VersionRouterHandle { } } +// These wrapper traits and type aliases exist to help feature gate the module +#[cfg(update)] +trait Downloader: + AppDownloader + Send + 'static + From<AppDownloaderParameters<ProgressUpdater>> +{ +} +#[cfg(not(update))] +trait Downloader {} + +#[cfg(update)] +type DefaultDownloader = HttpAppDownloader<ProgressUpdater>; +#[cfg(not(update))] +type DefaultDownloader = (); + +impl Downloader for DefaultDownloader {} + /// Router of version updates and update requests. /// /// New available app version events are forwarded from the [`VersionUpdater`]. /// If an update is in progress, these events are paused until the update is completed or canceled. /// This is done to prevent frontends from confusing which version is currently being installed, /// in case new version info is received while the update is in progress. -pub struct VersionRouter { - rx: mpsc::UnboundedReceiver<Message>, +struct VersionRouter<S = DaemonEventSender<AppVersionInfo>, D = DefaultDownloader> { + daemon_rx: mpsc::UnboundedReceiver<Message>, state: State, beta_program: bool, - version_event_sender: DaemonEventSender<AppVersionInfo>, - /// Version updater - version_check: check::VersionUpdaterHandle, + version_event_sender: S, + /// Channel used to trigger a version check. The result will always be sent to the + /// `new_version_rx` channel. + refresh_version_check_tx: mpsc::UnboundedSender<()>, /// Channel used to receive updates from `version_check` new_version_rx: mpsc::UnboundedReceiver<VersionCache>, - /// Future that resolves when `get_latest_version` resolves - version_request: Fuse<Pin<Box<dyn Future<Output = Result<VersionCache>> + Send>>>, /// Channels that receive responses to `get_latest_version` version_request_channels: Vec<oneshot::Sender<Result<AppVersionInfo>>>, - /// Broadcast channel for app upgrade events #[cfg(update)] app_upgrade_broadcast: AppUpgradeBroadcast, + /// Type used to spawn the downloader task, replaced when testing + _phantom: std::marker::PhantomData<D>, } enum Message { @@ -133,6 +150,11 @@ enum State { }, } +struct AppVersionInfoEvent { + app_version_info: AppVersionInfo, + is_new: bool, +} + impl std::fmt::Display for State { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -163,88 +185,83 @@ impl State { } } } +} - fn get_verified_installer_path(&self) -> Option<&PathBuf> { - match self { +#[cfg_attr(not(update), allow(unused_variables))] +pub(crate) fn spawn_version_router( + api_handle: MullvadRestHandle, + availability_handle: ApiAvailability, + cache_dir: PathBuf, + version_event_sender: DaemonEventSender<AppVersionInfo>, + beta_program: bool, + app_upgrade_broadcast: AppUpgradeBroadcast, +) -> VersionRouterHandle { + let (tx, rx) = mpsc::unbounded(); + + tokio::spawn(async move { + let (new_version_tx, new_version_rx) = mpsc::unbounded(); + let (refresh_version_check_tx, refresh_version_check_rx) = mpsc::unbounded(); + VersionUpdater::spawn( + api_handle, + availability_handle, + cache_dir, + new_version_tx, + refresh_version_check_rx, + ) + .await; + + VersionRouter { + daemon_rx: rx, + state: State::NoVersion, + beta_program, + version_event_sender, + new_version_rx, + version_request_channels: vec![], #[cfg(update)] - State::Downloaded { - verified_installer_path, - .. - } => Some(verified_installer_path), - _ => None, + app_upgrade_broadcast, + refresh_version_check_tx, + _phantom: std::marker::PhantomData::<DefaultDownloader>, } - } + .run() + .await; + }); + VersionRouterHandle { tx } } -impl VersionRouter { - #[cfg_attr(not(update), allow(unused_variables))] - pub(crate) fn spawn( - api_handle: MullvadRestHandle, - availability_handle: ApiAvailability, - cache_dir: PathBuf, - version_event_sender: DaemonEventSender<AppVersionInfo>, - beta_program: bool, - app_upgrade_broadcast: AppUpgradeBroadcast, - ) -> VersionRouterHandle { - let (tx, rx) = mpsc::unbounded(); - - tokio::spawn(async move { - let (new_version_tx, new_version_rx) = mpsc::unbounded(); - let version_check = - VersionUpdater::spawn(api_handle, availability_handle, cache_dir, new_version_tx) - .await; - - Self { - rx, - state: State::NoVersion, - beta_program, - version_check, - version_event_sender, - new_version_rx, - version_request: Fuse::terminated(), - version_request_channels: vec![], - #[cfg(update)] - app_upgrade_broadcast, - } - .run() - .await; - }); - VersionRouterHandle { tx } +impl<S, D> VersionRouter<S, D> +where + S: Sender<AppVersionInfo> + Send + 'static, + D: Downloader, +{ + async fn run(mut self) { + log::debug!("Version router started"); + // Loop until the router is closed + while self.run_step().await.is_continue() {} + log::debug!("Version router closed"); } - async fn run(mut self) { - loop { - tokio::select! { - // Respond to version info requests - update_result = &mut self.version_request => { - match update_result { - Ok(new_version) => { - self.on_new_version(new_version.clone()); - } - Err(error) => { - log::error!("Failed to retrieve version: {error}"); - for tx in self.version_request_channels.drain(..) { - // TODO: More appropriate error? But Error isn't Clone - let _ = tx.send(Err(Error::UpdateAborted)); - } - } - } - } - // Received version event from `check` - Some(new_version) = self.new_version_rx.next() => { - self.on_new_version(new_version); + /// Run a single step of the router, handling messages from the daemon and version events + async fn run_step(&mut self) -> ControlFlow<()> { + tokio::select! { + // Received version event from `check` + Some(new_version) = self.new_version_rx.next() => { + let AppVersionInfoEvent { app_version_info, is_new } = self.on_new_version(new_version); + self.notify_version_requesters(app_version_info.clone()); + if is_new { + // Notify the daemon about new version + let _ = self.version_event_sender.send(app_version_info); } - res = wait_for_update(&mut self.state) => { - // If the download was successful, we send the new version - if let Some(app_update_info) = res { - let _ = self.version_event_sender.send(app_update_info); - } - }, - Some(message) = self.rx.next() => self.handle_message(message), - else => break, } + res = wait_for_update(&mut self.state) => { + // If the download was successful, we send the new version + if let Some(app_update_info) = res { + let _ = self.version_event_sender.send(app_update_info); + } + }, + Some(message) = self.daemon_rx.next() => self.handle_message(message), + else => return ControlFlow::Break(()), } - log::info!("Version router closed"); + ControlFlow::Continue(()) } /// Handle [Message] sent by user @@ -276,27 +293,32 @@ impl VersionRouter { /// /// If the router is in the process of upgrading, it will not propagate versions, but only /// remember it for when it transitions back into the "idle" (version check) state. - fn on_new_version(&mut self, version_cache: VersionCache) { + fn on_new_version(&mut self, version_cache: VersionCache) -> AppVersionInfoEvent { + #[cfg(update)] + let verified_installer_path = self.get_verified_installer_path(); match &mut self.state { State::NoVersion => { // Receive first version let app_version_info = to_app_version_info(&version_cache, self.beta_program, None); - let _ = self.version_event_sender.send(app_version_info.clone()); self.state = State::HasVersion { version_cache }; + AppVersionInfoEvent { + app_version_info, + is_new: true, + } } // Already have version info, just update it State::HasVersion { version_cache: prev_cache, } => { - if let Some(version_info) = updated_app_version_info_on_new_version_cache( - prev_cache, - &version_cache, - self.beta_program, - ) { - // New version available - let _ = self.version_event_sender.send(version_info.clone()); - } + let prev_app_version = to_app_version_info(prev_cache, self.beta_program, None); + let new_app_version = to_app_version_info(&version_cache, self.beta_program, None); + self.state = State::HasVersion { version_cache }; + + AppVersionInfoEvent { + is_new: new_app_version != prev_app_version, + app_version_info: new_app_version, + } } #[cfg(update)] State::Downloaded { @@ -307,59 +329,68 @@ impl VersionRouter { version_cache: ref mut prev_cache, .. } => { - // If version changed, cancel download - if let Some(version_info) = updated_app_version_info_on_new_version_cache( + let prev_app_version = to_app_version_info( prev_cache, - &version_cache, self.beta_program, - ) { - log::warn!("Received new version while upgrading: {version_info:?}, aborting"); + verified_installer_path.clone(), + ); + let new_app_version = + to_app_version_info(&version_cache, self.beta_program, verified_installer_path); - let _ = self.version_event_sender.send(version_info.clone()); + let is_new = new_app_version != prev_app_version; + // If version changed, cancel download by switching state + if is_new { + log::warn!("Received new version while upgrading: {new_app_version:?}"); self.state = State::HasVersion { version_cache }; } else { *prev_cache = version_cache; + }; + AppVersionInfoEvent { + app_version_info: new_app_version, + is_new, } } } - - // Notify version requesters - if let Some(cache) = self.state.get_version_cache() { - self.notify_version_requesters(to_app_version_info( - cache, - self.beta_program, - self.state.get_verified_installer_path().cloned(), - )); - } } fn notify_version_requesters(&mut self, new_app_version_info: AppVersionInfo) { - // Cancel update notifications - self.version_request = Fuse::terminated(); // Notify all requesters for tx in self.version_request_channels.drain(..) { let _ = tx.send(Ok(new_app_version_info.clone())); } } + #[cfg(update)] + fn get_verified_installer_path(&self) -> Option<PathBuf> { + match &self.state { + State::Downloaded { + verified_installer_path, + .. + } => Some(verified_installer_path.clone()), + _ => None, + } + } + fn set_beta_program(&mut self, new_state: bool) { if new_state == self.beta_program { return; } let previous_state = self.beta_program; self.beta_program = new_state; - let Some(new_app_version_info) = self.state.get_version_cache().and_then(|version_cache| { - updated_app_version_info_on_new_beta(version_cache, previous_state, new_state) - }) else { + let Some(version_cache) = self.state.get_version_cache() else { + return; + }; + let prev_app_version = to_app_version_info(version_cache, previous_state, None); + let new_app_version = to_app_version_info(version_cache, new_state, None); + if new_app_version == prev_app_version { return; }; // Always cancel download if the suggested upgrade changes - let version_cache = match mem::replace(&mut self.state, State::NoVersion) { #[cfg(update)] State::Downloaded { version_cache, .. } | State::Downloading { version_cache, .. } => { - log::warn!("Switching beta after while updating resulted in new suggested upgrade: {:?}, aborting", new_app_version_info.suggested_upgrade); + log::warn!("Switching beta after updating resulted in new suggested upgrade: {:?}, aborting", new_app_version.suggested_upgrade); version_cache } State::HasVersion { version_cache } => version_cache, @@ -369,9 +400,9 @@ impl VersionRouter { }; self.state = State::HasVersion { version_cache }; - let _ = self.version_event_sender.send(new_app_version_info.clone()); + let _ = self.version_event_sender.send(new_app_version.clone()); - self.notify_version_requesters(new_app_version_info); + self.notify_version_requesters(new_app_version); } fn get_latest_version( @@ -390,8 +421,6 @@ impl VersionRouter { .send(Err(err)) .unwrap_or_else(|e| log::warn!("Failed to send version request result: {e:?}")), } - // Append to response channels - self.version_request_channels.push(result_tx); } #[cfg(update)] @@ -399,8 +428,7 @@ impl VersionRouter { use crate::version::downloader::spawn_downloader; match mem::replace(&mut self.state, State::NoVersion) { - // If we're already downloading or have a version, do nothing - State::Downloaded { version_cache, .. } | State::HasVersion { version_cache } => { + State::HasVersion { version_cache } => { let Some(upgrading_to_version) = recommended_version_upgrade(&version_cache.latest_version, self.beta_program) else { @@ -414,7 +442,7 @@ impl VersionRouter { upgrading_to_version.version ); - let downloader_handle = spawn_downloader( + let downloader_handle = spawn_downloader::<D>( upgrading_to_version.clone(), self.app_upgrade_broadcast.clone(), ); @@ -425,7 +453,6 @@ impl VersionRouter { downloader_handle, }; } - // Already downloading/downloaded or there is no version: do nothing state => { log::debug!("Ignoring update request while in state {:?}", state); self.state = state; @@ -452,38 +479,6 @@ impl VersionRouter { } } -fn updated_app_version_info_on_new_version_cache( - version_cache: &VersionCache, - new_version_cache: &VersionCache, - beta_program: bool, -) -> Option<AppVersionInfo> { - let prev_app_version = to_app_version_info(version_cache, beta_program, None); - let new_app_version = to_app_version_info(new_version_cache, beta_program, None); - - // Update version info - if new_app_version != prev_app_version { - Some(new_app_version) - } else { - None - } -} - -fn updated_app_version_info_on_new_beta( - version_cache: &VersionCache, - previous_beta_state: bool, - new_beta_state: bool, -) -> Option<AppVersionInfo> { - let prev_app_version = to_app_version_info(version_cache, previous_beta_state, None); - let new_app_version = to_app_version_info(version_cache, new_beta_state, None); - - // Update version info - if new_app_version != prev_app_version { - Some(new_app_version) - } else { - None - } -} - /// Wait for the update to finish. In case no update is in progress (or the platform does not /// support in-app upgrades), then the future will never resolve as to not escape the select statement. #[allow(clippy::unused_async, unused_variables)] @@ -495,7 +490,7 @@ async fn wait_for_update(state: &mut State) -> Option<AppVersionInfo> { ref mut downloader_handle, upgrading_to_version, .. - } => match downloader_handle.wait().await { + } => match downloader_handle.await { Ok(verified_installer_path) => { let app_update_info = AppVersionInfo { current_version_supported: version_cache.current_version_supported, @@ -514,7 +509,7 @@ async fn wait_for_update(state: &mut State) -> Option<AppVersionInfo> { Some(app_update_info) } Err(err) => { - log::trace!("Downloader task ended: {err}"); + log::warn!("Downloader task ended: {err}"); *state = State::HasVersion { version_cache: version_cache.clone(), }; @@ -574,3 +569,497 @@ fn recommended_version_upgrade( None } } + +#[cfg(all(test, update))] +mod test { + use super::downloader::ProgressUpdater; + use futures::channel::mpsc::unbounded; + use mullvad_types::version::{AppUpgradeDownloadProgress, AppUpgradeEvent}; + use mullvad_update::{app::DownloadError, fetch::ProgressUpdater as _}; + use tokio::sync::broadcast::error::TryRecvError; + + use super::*; + + /// To be able to test events occurring during the download process, we need to + /// call `tokio::time::sleep` in the downloader. This will not affect the runtime + /// of the tests, as set `start_paused = true`. + const DOWNLOAD_DURATION: std::time::Duration = std::time::Duration::from_millis(1000); + + /// Mock downloader that simulates a successful download + struct SuccessfulAppDownloader(AppDownloaderParameters<ProgressUpdater>); + + impl AppDownloader for SuccessfulAppDownloader { + async fn download_executable(&mut self) -> std::result::Result<(), DownloadError> { + tokio::time::sleep(DOWNLOAD_DURATION).await; + self.0.app_progress.set_progress(1.0); + Ok(()) + } + + async fn verify(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + + async fn install(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + } + + impl From<AppDownloaderParameters<ProgressUpdater>> for SuccessfulAppDownloader { + fn from(parameters: AppDownloaderParameters<ProgressUpdater>) -> Self { + Self(parameters) + } + } + + impl Downloader for SuccessfulAppDownloader {} + + /// Mock downloader that simulates a failed download + struct FailingAppDownloader; + + impl AppDownloader for FailingAppDownloader { + async fn download_executable(&mut self) -> std::result::Result<(), DownloadError> { + Err(DownloadError::FetchApp(anyhow::anyhow!("Download failed"))) + } + + async fn verify(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + + async fn install(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + } + + impl From<AppDownloaderParameters<ProgressUpdater>> for FailingAppDownloader { + fn from(_parameters: AppDownloaderParameters<ProgressUpdater>) -> Self { + Self + } + } + + impl Downloader for FailingAppDownloader {} + + /// Mock downloader that simulates a failed verification, but a successful download + struct FailingAppVerifier; + + impl AppDownloader for FailingAppVerifier { + async fn download_executable(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + + async fn verify(&mut self) -> std::result::Result<(), DownloadError> { + Err(DownloadError::Verification(anyhow::anyhow!( + "Verification failed" + ))) + } + + async fn install(&mut self) -> std::result::Result<(), DownloadError> { + Ok(()) + } + } + + impl From<AppDownloaderParameters<ProgressUpdater>> for FailingAppVerifier { + fn from(_parameters: AppDownloaderParameters<ProgressUpdater>) -> Self { + Self + } + } + + impl Downloader for FailingAppVerifier {} + + /// Channels used to communicate with the version router and receive version events. + /// This is used in the tests to simulate the daemon and `VersionUpdater`. + struct VersionRouterChannels { + daemon_tx: futures::channel::mpsc::UnboundedSender<Message>, + new_version_tx: futures::channel::mpsc::UnboundedSender<VersionCache>, + refresh_version_check_rx: futures::channel::mpsc::UnboundedReceiver<()>, + version_event_receiver: futures::channel::mpsc::UnboundedReceiver<AppVersionInfo>, + } + + fn make_version_router<D>() -> ( + VersionRouter<futures::channel::mpsc::UnboundedSender<AppVersionInfo>, D>, + VersionRouterChannels, + ) { + let (version_event_sender, version_event_receiver) = unbounded(); + let (daemon_tx, daemon_rx) = unbounded(); + let (app_upgrade_broadcast, _) = tokio::sync::broadcast::channel(10); + let (refresh_version_check_tx, refresh_version_check_rx) = unbounded(); + let (new_version_tx, new_version_rx) = unbounded(); + ( + VersionRouter { + daemon_rx, + state: State::NoVersion, + beta_program: false, + version_event_sender, + new_version_rx, + version_request_channels: vec![], + app_upgrade_broadcast, + refresh_version_check_tx, + _phantom: std::marker::PhantomData::<D>, + }, + VersionRouterChannels { + daemon_tx, + new_version_tx, + refresh_version_check_rx, + version_event_receiver, + }, + ) + } + + /// Create a version cache with a stable version that is newer than the current version + fn get_new_stable_version_cache() -> VersionCache { + let mut version: mullvad_version::Version = mullvad_version::VERSION.parse().unwrap(); + version.incremental += 1; + VersionCache { + current_version_supported: true, + latest_version: VersionInfo { + beta: None, + stable: mullvad_update::version::Version { + version, + urls: vec!["https://example.com".to_string()], + size: 123456, + changelog: "Changelog".to_string(), + sha256: [0; 32], + }, + }, + metadata_version: 0, + } + } + + /// Create a version cache with a beta version that is newer than the current version + fn get_new_beta_version_cache() -> VersionCache { + let stable = mullvad_update::version::Version { + version: mullvad_version::VERSION.parse().unwrap(), + urls: vec!["https://example.com".to_string()], + size: 123456, + changelog: "Changelog".to_string(), + sha256: [0; 32], + }; + let mut beta = stable.clone(); + beta.version.pre_stable = Some(mullvad_version::PreStableType::Beta(1)); + beta.version.incremental += 1; + VersionCache { + current_version_supported: true, + latest_version: VersionInfo { + beta: Some(beta), + stable, + }, + metadata_version: 0, + } + } + + #[tokio::test(start_paused = true)] + async fn test_upgrade_with_no_version() { + let (mut version_router, _channels) = make_version_router::<SuccessfulAppDownloader>(); + let upgrade_events = version_router.app_upgrade_broadcast.subscribe(); + version_router.update_application(); + assert!( + matches!(version_router.state, State::NoVersion), + "State should stay as NoVersion after calling update_application" + ); + assert!( + upgrade_events.is_empty(), + "No upgrade events should be sent" + ); + } + + #[tokio::test(start_paused = true)] + async fn test_new_beta() { + let (mut version_router, mut channels) = make_version_router::<SuccessfulAppDownloader>(); + let version_cache = get_new_beta_version_cache(); + + // Test that new beta version is ignored if beta program is off + version_router.set_beta_program(false); // This is default value, but set it for clarity + assert!( + matches!(version_router.state, State::NoVersion), + "State should not transition" + ); + version_router.on_new_version(version_cache); + assert!(matches!(version_router.state, State::HasVersion { .. })); + assert!( + channels.version_event_receiver.try_next().is_err(), + "No version event should be sent on beta program change" + ); + version_router.update_application(); + assert!( + matches!(version_router.state, State::HasVersion { .. }), + "State should not transition to Downloading as the beta version is ignored" + ); + + // Test that switching to beta program sends version event for the previously received beta + // version and allows upgrades. + version_router.set_beta_program(true); + assert!( + channels.version_event_receiver.try_next().is_ok(), + "Version event should be sent on beta program change" + ); + version_router.update_application(); + assert!( + matches!(version_router.state, State::Downloading { .. }), + "State should transition to Downloading as the beta version is accepted" + ); + } + + /// Test that when the daemon calls `get_latest_version`, it will trigger a version check + /// and send the result back to the daemon, both on the response channel and in the + /// version event stream. + #[tokio::test(start_paused = true)] + async fn test_get_latest_version() { + let (mut version_router, mut channels) = make_version_router::<SuccessfulAppDownloader>(); + let version_cache_test = get_new_stable_version_cache(); + + // Make a request to the router to get the latest version + // Note that we could as well call `version_router.get_latest_version()`, + // but this way we test the actual message passing between the router and + // the daemon. + let (tx, mut get_latest_version_rx) = oneshot::channel(); + channels + .daemon_tx + .unbounded_send(Message::GetLatestVersion(tx)) + .unwrap(); + version_router.run_step().await; + + // Here, we play the role of `VersionUpdater`. + // It should receive a version check request and send a version in response + assert!( + matches!(channels.refresh_version_check_rx.try_next(), Ok(Some(()))), + "Version check should be triggered" + ); + channels + .new_version_tx + .unbounded_send(version_cache_test.clone()) + .unwrap(); + + // On the next step, the router should receive the version info + // and send it to as a response to the oneshot from `GetLatestVersion` + // and to the daemon in the `version_event_receiver` channel. + version_router.run_step().await; + let version_info = get_latest_version_rx + .try_recv() + .expect("Sender should not be dropped") + .expect("Version info should have been sent") + .expect("Version request should be successful"); + match &version_router.state { + State::HasVersion { version_cache } => assert_eq!(version_cache, &version_cache_test), + other => panic!("State should be HasVersion, was {other:?}"), + } + assert_eq!( + version_info, + channels + .version_event_receiver + .try_next() + .expect("Version event sender should not be closed") + .expect("Version event should be sent"), + "Version event sent to the daemon should be the same as the one sent to the requester" + ); + } + + #[tokio::test(start_paused = true)] + async fn test_upgrade() { + let (mut version_router, mut channels) = make_version_router::<SuccessfulAppDownloader>(); + let version_cache_test = get_new_stable_version_cache(); + + version_router.on_new_version(version_cache_test.clone()); + match &version_router.state { + State::HasVersion { version_cache } => assert_eq!(version_cache, &version_cache_test), + other => panic!("State should be HasVersion, was {other:?}"), + } + + // Start upgrading + let mut app_upgrade_listener = version_router.app_upgrade_broadcast.subscribe(); + version_router.update_application(); + // Check that the state is now downloading + match &version_router.state { + State::Downloading { + version_cache, + upgrading_to_version, + .. + } => { + assert_eq!(version_cache, &version_cache_test); + assert_eq!( + upgrading_to_version.version, + version_cache_test.latest_version.stable.version + ); + } + other => panic!("State should be Downloading, was {other:?}"), + } + + version_router.update_application(); + assert!( + matches!(version_router.state, State::Downloading { .. }), + "Triggering an update while in the downloading shout be ignored" + ); + + // Drive the download to completion, and get the verified installer path + version_router.run_step().await; + let verified_installer_path = match &version_router.state { + State::Downloaded { + version_cache, + verified_installer_path, + .. + } => { + assert_eq!(version_cache, &version_cache_test); + verified_installer_path + } + other => panic!("State should be Downloaded, was {other:?}"), + }; + + // Check that the app upgrade events were sent + let events = [ + Ok(AppUpgradeEvent::DownloadStarting), + Ok(AppUpgradeEvent::DownloadProgress( + AppUpgradeDownloadProgress { + progress: 100, + server: "example.com".to_string(), + time_left: None, + }, + )), + Ok(AppUpgradeEvent::VerifyingInstaller), + Ok(AppUpgradeEvent::VerifiedInstaller), + Err(TryRecvError::Empty), // No more events should be sent + ]; + for event in events { + assert_eq!(app_upgrade_listener.try_recv(), event); + } + + // Check that the version event was sent with the verified installer path + let version_info = channels + .version_event_receiver + .try_next() + .expect("Version event sender should not be closed") + .expect("Version event should be sent"); + assert_eq!( + version_info + .suggested_upgrade + .as_ref() + .unwrap() + .verified_installer_path, + Some(verified_installer_path.clone()) + ); + + version_router.update_application(); + assert!( + matches!(version_router.state, State::Downloaded { .. }), + "Triggering an update while in the downloaded shout be ignored" + ); + + version_router.cancel_upgrade(); + assert!( + matches!(version_router.state, State::HasVersion { .. }), + "State should be HasVersion after cancelling the upgrade" + ); + + assert_eq!( + app_upgrade_listener.try_recv(), + Err(TryRecvError::Empty), + "The `AppUpgradeEvent::Aborted` should not be sent when cancelling a finished download" + ); + } + + /// Test that the update is aborted if a new version is received while downloading + #[tokio::test(start_paused = true)] + async fn test_abort_on_new_version() { + let (mut version_router, _channels) = make_version_router::<SuccessfulAppDownloader>(); + let upgrade_version = get_new_stable_version_cache(); + let mut upgrade_version_newer = upgrade_version.clone(); + upgrade_version_newer + .latest_version + .stable + .version + .incremental += 1; + + version_router.on_new_version(upgrade_version.clone()); + + // Start upgrading + let mut app_upgrade_listener = version_router.app_upgrade_broadcast.subscribe(); + version_router.update_application(); + // Check that the state is now downloading + assert!(matches!(version_router.state, State::Downloading { .. }),); + + tokio::time::sleep(DOWNLOAD_DURATION / 2).await; + version_router.on_new_version(upgrade_version); + + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::DownloadStarting + ); + assert_eq!(app_upgrade_listener.try_recv(), Err(TryRecvError::Empty)); + + version_router.on_new_version(upgrade_version_newer); + + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::Aborted + ); + assert_eq!(app_upgrade_listener.try_recv(), Err(TryRecvError::Empty)); + } + + #[tokio::test] + async fn test_failed_download() { + let (mut version_router, _channels) = make_version_router::<FailingAppDownloader>(); + let version_cache_test = get_new_stable_version_cache(); + + version_router.on_new_version(version_cache_test.clone()); + + // Start upgrading + let mut app_upgrade_listener = version_router.app_upgrade_broadcast.subscribe(); + version_router.update_application(); + // Check that the state is now downloading + assert!(matches!(version_router.state, State::Downloading { .. }),); + + // Drive the download to completion + version_router.run_step().await; + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::DownloadStarting + ); + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::Error(mullvad_types::version::AppUpgradeError::DownloadFailed) + ); + assert_eq!(app_upgrade_listener.try_recv(), Err(TryRecvError::Empty)); + version_router.update_application(); + + // Verify that we can restart the download again + version_router.run_step().await; + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::DownloadStarting + ); + } + + #[tokio::test] + async fn test_failed_verification() { + let (mut version_router, _channels) = make_version_router::<FailingAppVerifier>(); + let version_cache_test = get_new_stable_version_cache(); + + version_router.on_new_version(version_cache_test.clone()); + + // Start upgrading + let mut app_upgrade_listener = version_router.app_upgrade_broadcast.subscribe(); + version_router.update_application(); + // Check that the state is now downloading + assert!(matches!(version_router.state, State::Downloading { .. }),); + + // Drive the download to completion + version_router.run_step().await; + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::DownloadStarting + ); + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::VerifyingInstaller + ); + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::Error(mullvad_types::version::AppUpgradeError::VerificationFailed) + ); + assert_eq!(app_upgrade_listener.try_recv(), Err(TryRecvError::Empty)); + version_router.update_application(); + + // Verify that we can restart the download again + version_router.run_step().await; + assert_eq!( + app_upgrade_listener.try_recv().unwrap(), + AppUpgradeEvent::DownloadStarting + ); + } +} diff --git a/mullvad-update/mullvad-release/src/main.rs b/mullvad-update/mullvad-release/src/main.rs index 7ba8469cd1..900caf481b 100644 --- a/mullvad-update/mullvad-release/src/main.rs +++ b/mullvad-update/mullvad-release/src/main.rs @@ -25,9 +25,6 @@ const DEFAULT_EXPIRY_MONTHS: usize = 6; /// Rollout to use when not specified const DEFAULT_ROLLOUT: f32 = 1.; -/// Lowest version to accept using 'verify' -const MIN_VERIFY_METADATA_VERSION: usize = 0; - /// A tool that generates signed Mullvad version metadata. /// /// Unsigned work is stored in `work/`, and signed work is stored in `signed/` diff --git a/mullvad-update/mullvad-release/src/platform.rs b/mullvad-update/mullvad-release/src/platform.rs index 06ad8a7185..5ab5b8ba39 100644 --- a/mullvad-update/mullvad-release/src/platform.rs +++ b/mullvad-update/mullvad-release/src/platform.rs @@ -105,7 +105,7 @@ impl Platform { let response = HttpVersionInfoProvider::get_versions_for_platform( platform, - crate::MIN_VERIFY_METADATA_VERSION, + mullvad_update::version::MIN_VERIFY_METADATA_VERSION, ) .await .context("Failed to retrieve versions")?; @@ -204,8 +204,11 @@ impl Platform { println!("Verifying signature of {}...", signed_path.display()); let bytes = fs::read(signed_path).await.context("Failed to read file")?; - format::SignedResponse::deserialize_and_verify(&bytes, crate::MIN_VERIFY_METADATA_VERSION) - .context("Failed to verify metadata for {platform}: {error}")?; + format::SignedResponse::deserialize_and_verify( + &bytes, + mullvad_update::version::MIN_VERIFY_METADATA_VERSION, + ) + .context("Failed to verify metadata for {platform}: {error}")?; Ok(()) } diff --git a/mullvad-update/src/client/app.rs b/mullvad-update/src/client/app.rs index c43326fb88..ad031861d3 100644 --- a/mullvad-update/src/client/app.rs +++ b/mullvad-update/src/client/app.rs @@ -2,7 +2,12 @@ //! This module implements the flow of downloading and verifying the app. -use std::{ffi::OsString, future::Future, path::PathBuf, time::Duration}; +use std::{ + ffi::OsString, + future::Future, + path::{Path, PathBuf}, + time::Duration, +}; use tokio::{process::Command, time::timeout}; @@ -81,7 +86,7 @@ impl<AppProgress: ProgressUpdater> From<AppDownloaderParameters<AppProgress>> impl<AppProgress: ProgressUpdater> AppDownloader for HttpAppDownloader<AppProgress> { async fn download_executable(&mut self) -> Result<(), DownloadError> { - let bin_path = self.bin_path(); + let bin_path = bin_path(&self.params.app_version, &self.params.cache_dir); fetch::get_to_file( bin_path, &self.params.app_url, @@ -93,7 +98,7 @@ impl<AppProgress: ProgressUpdater> AppDownloader for HttpAppDownloader<AppProgre } async fn verify(&mut self) -> Result<(), DownloadError> { - let bin_path = self.bin_path(); + let bin_path = bin_path(&self.params.app_version, &self.params.cache_dir); let hash = self.hash_sha256(); match Sha256Verifier::verify(&bin_path, *hash) @@ -133,21 +138,21 @@ impl<AppProgress: ProgressUpdater> AppDownloader for HttpAppDownloader<AppProgre } } -impl<AppProgress> HttpAppDownloader<AppProgress> { - pub fn bin_path(&self) -> PathBuf { - #[cfg(windows)] - let bin_filename = format!("mullvad-{}.exe", self.params.app_version); +pub fn bin_path(app_version: &mullvad_version::Version, cache_dir: &Path) -> PathBuf { + #[cfg(windows)] + let bin_filename = format!("mullvad-{}.exe", app_version); - #[cfg(target_os = "macos")] - let bin_filename = format!("mullvad-{}.pkg", self.params.app_version); + #[cfg(target_os = "macos")] + let bin_filename = format!("mullvad-{}.pkg", app_version); - self.params.cache_dir.join(bin_filename) - } + cache_dir.join(bin_filename) +} +impl<AppProgress> HttpAppDownloader<AppProgress> { fn launch_path(&self) -> PathBuf { #[cfg(target_os = "windows")] { - self.bin_path() + bin_path(&self.params.app_version, &self.params.cache_dir) } #[cfg(target_os = "macos")] @@ -166,7 +171,7 @@ impl<AppProgress> HttpAppDownloader<AppProgress> { #[cfg(target_os = "macos")] { - vec![self.bin_path().into()] + vec![bin_path(&self.params.app_version, &self.params.cache_dir).into()] } } diff --git a/mullvad-update/src/version.rs b/mullvad-update/src/version.rs index 5fc0acae42..a6bad44b22 100644 --- a/mullvad-update/src/version.rs +++ b/mullvad-update/src/version.rs @@ -11,6 +11,9 @@ use mullvad_version::PreStableType; use crate::format; +/// Lowest version to accept using 'verify' +pub const MIN_VERIFY_METADATA_VERSION: usize = 0; + /// Query type for [VersionInfo] #[derive(Debug)] pub struct VersionParameters { |
