diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2025-05-12 15:18:17 +0000 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2025-05-28 13:25:34 +0200 |
| commit | 10dc35368c408561dffdf94f8ba1b7797b4fe475 (patch) | |
| tree | 0c25e0f186202cdb5045cd2b6b9a71807b012eea | |
| parent | 9191354880f615a03d164f593ec81a36ab90201b (diff) | |
| download | mullvadvpn-10dc35368c408561dffdf94f8ba1b7797b4fe475.tar.xz mullvadvpn-10dc35368c408561dffdf94f8ba1b7797b4fe475.zip | |
Add download timeout and retry logic (#8149)
* Add timeout to download
* Retry failed downloads on network errors
Previously, the download would either fail immediately or hang
indefinitely if when the user e.g. changed their tunnel state.
* Fix progress when resuming download
* Import thiserror on all platforms
* Add to installer downloader changelog
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | installer-downloader/CHANGELOG.md | 2 | ||||
| -rw-r--r-- | mullvad-update/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-update/src/client/app.rs | 2 | ||||
| -rw-r--r-- | mullvad-update/src/client/fetch.rs | 202 | ||||
| -rw-r--r-- | test/Cargo.lock | 1 |
6 files changed, 179 insertions, 31 deletions
diff --git a/Cargo.lock b/Cargo.lock index 08708f4541..2878ae49ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3238,6 +3238,7 @@ dependencies = [ "hex", "insta", "json-canon", + "log", "mockito", "mullvad-version", "rand 0.8.5", diff --git a/installer-downloader/CHANGELOG.md b/installer-downloader/CHANGELOG.md index aceeb14b41..9629809539 100644 --- a/installer-downloader/CHANGELOG.md +++ b/installer-downloader/CHANGELOG.md @@ -21,10 +21,10 @@ Line wrap the file at 100 chars. Th ## [Unreleased] ### Fix +- Fix downloads hanging indefinitely on switching networks #### macOS - Fix rendering issues on old (unsupported) macOS versions. - ## [1.0.0] - 2025-05-13 ### Fixed #### Windows diff --git a/mullvad-update/Cargo.toml b/mullvad-update/Cargo.toml index 020391b635..d2f2bbae37 100644 --- a/mullvad-update/Cargo.toml +++ b/mullvad-update/Cargo.toml @@ -24,6 +24,7 @@ hex = { version = "0.4" } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } zeroize = { version = "1.8", features = ["zeroize_derive"] } +log = { workspace = true } reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true } sha2 = { workspace = true, optional = true } @@ -36,7 +37,6 @@ mullvad-version = { path = "../mullvad-version", features = ["serde"] } clap = { workspace = true, optional = true } rand = { version = "0.8.5", optional = true } -[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies] thiserror = { workspace = true, optional = true } [dev-dependencies] diff --git a/mullvad-update/src/client/app.rs b/mullvad-update/src/client/app.rs index ad031861d3..473f8daf2e 100644 --- a/mullvad-update/src/client/app.rs +++ b/mullvad-update/src/client/app.rs @@ -19,7 +19,7 @@ use crate::{ #[derive(Debug, thiserror::Error)] pub enum DownloadError { #[error("Failed to download app")] - FetchApp(#[source] anyhow::Error), + FetchApp(#[from] anyhow::Error), #[error("Failed to verify app")] Verification(#[source] anyhow::Error), #[error("Failed to launch app")] diff --git a/mullvad-update/src/client/fetch.rs b/mullvad-update/src/client/fetch.rs index 706e3897f3..4a5a8c8d56 100644 --- a/mullvad-update/src/client/fetch.rs +++ b/mullvad-update/src/client/fetch.rs @@ -1,9 +1,11 @@ //! A downloader that supports HTTP range requests and resuming downloads use std::{ + error::Error, path::Path, pin::Pin, task::{ready, Poll}, + time::Duration, }; use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE}; @@ -12,7 +14,110 @@ use tokio::{ io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}, }; -use anyhow::Context; +use thiserror::Error; + +/// Start value of the read timeout. This is doubled on each retry. +const READ_TIMEOUT: Duration = Duration::from_secs(1); +const CONNECT_TIMEOUT: Duration = Duration::from_secs(30); +// Maximum number of retry attempts for timeouts +const MAX_RETRY_ATTEMPTS: u32 = 4; + +/// Custom error type for download operations +#[derive(Error, Debug)] +pub enum DownloadError { + /// Failed to initialize client + #[error("Failed to initialize HTTP client")] + ClientInitialization(#[source] reqwest::Error), + + /// Failed to get content length + #[error("Failed to request download")] + HeadRequest(#[source] reqwest::Error), + + /// Server returned error status + #[error("Download failed: {0}")] + HttpStatus(reqwest::StatusCode), + + /// Invalid content length header + #[error("Invalid content length header: {0}")] + InvalidContentLength(&'static str), + + /// Failed to make range request + #[error("Failed to retrieve range")] + RangeRequest(#[source] reqwest::Error), + + /// Failed to read chunk + #[error("Failed to read chunk")] + ChunkRead(#[source] reqwest::Error), + + /// Failed to write chunk + #[error("Failed to write chunk")] + ChunkWrite(#[source] io::Error), + + /// Failed to get stream position + #[error("Failed to get existing file size")] + StreamPosition(#[source] io::Error), + + /// Failed to flush writer + #[error("Failed to flush writer")] + Flush(#[source] io::Error), + + /// Size validation error + #[error("Size validation failed: {0}")] + SizeValidation(String), + + /// File operation error + #[error("File operation failed: {0}")] + FileOperation(#[source] io::Error), + + /// Other error + #[error("{0}")] + Other(&'static str), +} + +impl DownloadError { + /// Checks if the error is caused by a timeout or network issue that can be retried + pub fn should_retry(&self) -> bool { + match self { + DownloadError::HeadRequest(e) + | DownloadError::RangeRequest(e) + | DownloadError::ChunkRead(e) + | DownloadError::ClientInitialization(e) => is_network_error(e), + DownloadError::HttpStatus(status) => { + // Retry server errors and timeout status + status.is_server_error() || *status == reqwest::StatusCode::REQUEST_TIMEOUT + } + // Don't retry other types of errors + _ => false, + } + } +} + +/// Checks if the error is a network-related error that can be retried +fn is_network_error(error: &reqwest::Error) -> bool { + // Retry on timeout errors + // Retry on connection errors (which often happen when switching networks) + // Retry on request errors (like "connection reset") + if error.is_timeout() || error.is_connect() || error.is_request() { + return true; + } + + let mut error = error as &dyn Error; + loop { + if let Some(io_err) = error.downcast_ref::<std::io::Error>() { + // Check if the error is a timeout or connection error + if io_err.kind() == io::ErrorKind::TimedOut + || io_err.kind() == io::ErrorKind::ConnectionReset + { + return true; + } + } + if let Some(source) = error.source() { + error = source; + } else { + break false; + } + } +} /// Receiver of the current progress so far pub trait ProgressUpdater: Send + 'static { @@ -67,9 +172,26 @@ pub async fn get_to_file( progress_updater: &mut impl ProgressUpdater, size_hint: SizeHint, ) -> anyhow::Result<()> { - let file = create_or_append(file).await?; - let file = BufWriter::new(file); - get_to_writer(file, url, progress_updater, size_hint).await + let file = create_or_append(file) + .await + .map_err(DownloadError::FileOperation)?; + let mut file = BufWriter::new(file); + let mut attempts = 0; + let mut read_timeout = READ_TIMEOUT; + while let Err(err) = + get_to_writer(&mut file, url, progress_updater, size_hint, read_timeout).await + { + if !err.should_retry() { + anyhow::bail!(err); + } + attempts += 1; + read_timeout *= 2; + if attempts >= MAX_RETRY_ATTEMPTS { + anyhow::bail!("Max retry attempts reached: {err}"); + } + log::warn!("Download failed: {err}. Retrying in with timeout: {read_timeout:?}"); + } + Ok(()) } /// Download `url` to `writer`. @@ -82,41 +204,59 @@ pub async fn get_to_writer( url: &str, progress_updater: &mut impl ProgressUpdater, size_hint: SizeHint, -) -> anyhow::Result<()> { - let client = reqwest::Client::new(); + read_timeout: Duration, +) -> Result<(), DownloadError> { + // Create a new client for each download attempt to prevent stale connections + let client = reqwest::Client::builder() + .read_timeout(read_timeout) + .connect_timeout(CONNECT_TIMEOUT) + .build() + .map_err(DownloadError::ClientInitialization)?; progress_updater.set_url(url); - progress_updater.set_progress(0.); // Fetch content length first - let response = client.head(url).send().await.context("HEAD failed")?; + let response = client + .head(url) + .send() + .await + .map_err(DownloadError::HeadRequest)?; + if !response.status().is_success() { - return response - .error_for_status() - .map(|_| ()) - .context("Download failed"); + return Err(DownloadError::HttpStatus(response.status())); } let total_size = response .headers() .get(CONTENT_LENGTH) - .context("Missing file size")?; - let total_size: usize = total_size.to_str()?.parse().context("invalid size")?; - size_hint.check_size(total_size)?; + .ok_or_else(|| DownloadError::InvalidContentLength("Missing file size"))?; + + let total_size: usize = total_size + .to_str() + .map_err(|_| DownloadError::InvalidContentLength("Invalid content length header"))? + .parse() + .map_err(|_| DownloadError::InvalidContentLength("Invalid size format"))?; + + match size_hint.check_size(total_size) { + Ok(_) => {} + Err(e) => return Err(DownloadError::SizeValidation(e.to_string())), + } let already_fetched_bytes = writer .stream_position() .await - .context("failed to get existing file size")? + .map_err(DownloadError::StreamPosition)? .try_into() - .context("invalid size")?; + .map_err(|_| DownloadError::Other("Invalid file position"))?; + progress_updater.set_progress(already_fetched_bytes as f32 / total_size as f32); if total_size == already_fetched_bytes { - progress_updater.set_progress(1.); return Ok(()); } if already_fetched_bytes > total_size { - anyhow::bail!("Found existing file that was larger"); + return Err(DownloadError::SizeValidation( + "Found existing file that was larger".to_string(), + )); } // Fetch content, one range at a time @@ -133,32 +273,32 @@ pub async fn get_to_writer( .header(RANGE, range) .send() .await - .context("Failed to retrieve range")?; + .map_err(DownloadError::RangeRequest)?; + let status = response.status(); if !status.is_success() { - return response - .error_for_status() - .map(|_| ()) - .context("Download failed"); + return Err(DownloadError::HttpStatus(status)); } let mut bytes_read = 0; - while let Some(chunk) = response.chunk().await.context("Failed to read chunk")? { + while let Some(chunk) = response.chunk().await.map_err(DownloadError::ChunkRead)? { bytes_read += chunk.len(); if bytes_read > total_size - already_fetched_bytes { // Protect against servers responding with more data than expected - anyhow::bail!("Server returned more than requested bytes"); + return Err(DownloadError::SizeValidation( + "Server returned more than requested bytes".to_string(), + )); } writer .write_all(&chunk) .await - .context("Failed to write chunk")?; + .map_err(DownloadError::ChunkWrite)?; } } - writer.shutdown().await.context("Failed to flush")?; + writer.shutdown().await.map_err(DownloadError::Flush)?; Ok(()) } @@ -261,6 +401,7 @@ impl<PU: ProgressUpdater, Writer: AsyncWrite + Unpin> AsyncWrite mod test { use std::io::Cursor; + use anyhow::Context; use async_tempfile::TempDir; use rand::RngCore; use tokio::{fs, io::AsyncWriteExt}; @@ -344,6 +485,7 @@ mod test { &file_url, &mut progress_updater, SizeHint::Exact(file_data.len()), + READ_TIMEOUT, ) .await .context("Complete download failed")?; @@ -378,6 +520,7 @@ mod test { &file_url, &mut progress_updater, SizeHint::Exact(file_data.len()), + READ_TIMEOUT, ) .await .expect_err("Expected interrupted download"); @@ -408,6 +551,7 @@ mod test { &file_url, &mut progress_updater, SizeHint::Exact(file_data.len()), + READ_TIMEOUT, ) .await .context("Partial download failed")?; @@ -468,6 +612,7 @@ mod test { &file_url, &mut FakeProgressUpdater::default(), SizeHint::Exact(1), + READ_TIMEOUT, ) .await .expect_err("Reject unexpected content length"); @@ -492,6 +637,7 @@ mod test { &file_url, &mut FakeProgressUpdater::default(), SizeHint::Exact(file_data.len()), + READ_TIMEOUT, ) .await .expect_err("Reject unexpected chunk sizes"); diff --git a/test/Cargo.lock b/test/Cargo.lock index 2d77e49eba..1ccf04e2e7 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -2217,6 +2217,7 @@ dependencies = [ "ed25519-dalek", "hex", "json-canon", + "log", "mullvad-version", "reqwest", "serde", |
