summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2025-05-12 15:18:17 +0000
committerSebastian Holmin <sebastian.holmin@mullvad.net>2025-05-28 13:25:34 +0200
commit10dc35368c408561dffdf94f8ba1b7797b4fe475 (patch)
tree0c25e0f186202cdb5045cd2b6b9a71807b012eea
parent9191354880f615a03d164f593ec81a36ab90201b (diff)
downloadmullvadvpn-10dc35368c408561dffdf94f8ba1b7797b4fe475.tar.xz
mullvadvpn-10dc35368c408561dffdf94f8ba1b7797b4fe475.zip
Add download timeout and retry logic (#8149)
* Add timeout to download * Retry failed downloads on network errors Previously, the download would either fail immediately or hang indefinitely if when the user e.g. changed their tunnel state. * Fix progress when resuming download * Import thiserror on all platforms * Add to installer downloader changelog
-rw-r--r--Cargo.lock1
-rw-r--r--installer-downloader/CHANGELOG.md2
-rw-r--r--mullvad-update/Cargo.toml2
-rw-r--r--mullvad-update/src/client/app.rs2
-rw-r--r--mullvad-update/src/client/fetch.rs202
-rw-r--r--test/Cargo.lock1
6 files changed, 179 insertions, 31 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 08708f4541..2878ae49ec 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3238,6 +3238,7 @@ dependencies = [
"hex",
"insta",
"json-canon",
+ "log",
"mockito",
"mullvad-version",
"rand 0.8.5",
diff --git a/installer-downloader/CHANGELOG.md b/installer-downloader/CHANGELOG.md
index aceeb14b41..9629809539 100644
--- a/installer-downloader/CHANGELOG.md
+++ b/installer-downloader/CHANGELOG.md
@@ -21,10 +21,10 @@ Line wrap the file at 100 chars. Th
## [Unreleased]
### Fix
+- Fix downloads hanging indefinitely on switching networks
#### macOS
- Fix rendering issues on old (unsupported) macOS versions.
-
## [1.0.0] - 2025-05-13
### Fixed
#### Windows
diff --git a/mullvad-update/Cargo.toml b/mullvad-update/Cargo.toml
index 020391b635..d2f2bbae37 100644
--- a/mullvad-update/Cargo.toml
+++ b/mullvad-update/Cargo.toml
@@ -24,6 +24,7 @@ hex = { version = "0.4" }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
zeroize = { version = "1.8", features = ["zeroize_derive"] }
+log = { workspace = true }
reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true }
sha2 = { workspace = true, optional = true }
@@ -36,7 +37,6 @@ mullvad-version = { path = "../mullvad-version", features = ["serde"] }
clap = { workspace = true, optional = true }
rand = { version = "0.8.5", optional = true }
-[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
thiserror = { workspace = true, optional = true }
[dev-dependencies]
diff --git a/mullvad-update/src/client/app.rs b/mullvad-update/src/client/app.rs
index ad031861d3..473f8daf2e 100644
--- a/mullvad-update/src/client/app.rs
+++ b/mullvad-update/src/client/app.rs
@@ -19,7 +19,7 @@ use crate::{
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
#[error("Failed to download app")]
- FetchApp(#[source] anyhow::Error),
+ FetchApp(#[from] anyhow::Error),
#[error("Failed to verify app")]
Verification(#[source] anyhow::Error),
#[error("Failed to launch app")]
diff --git a/mullvad-update/src/client/fetch.rs b/mullvad-update/src/client/fetch.rs
index 706e3897f3..4a5a8c8d56 100644
--- a/mullvad-update/src/client/fetch.rs
+++ b/mullvad-update/src/client/fetch.rs
@@ -1,9 +1,11 @@
//! A downloader that supports HTTP range requests and resuming downloads
use std::{
+ error::Error,
path::Path,
pin::Pin,
task::{ready, Poll},
+ time::Duration,
};
use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE};
@@ -12,7 +14,110 @@ use tokio::{
io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter},
};
-use anyhow::Context;
+use thiserror::Error;
+
+/// Start value of the read timeout. This is doubled on each retry.
+const READ_TIMEOUT: Duration = Duration::from_secs(1);
+const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
+// Maximum number of retry attempts for timeouts
+const MAX_RETRY_ATTEMPTS: u32 = 4;
+
+/// Custom error type for download operations
+#[derive(Error, Debug)]
+pub enum DownloadError {
+ /// Failed to initialize client
+ #[error("Failed to initialize HTTP client")]
+ ClientInitialization(#[source] reqwest::Error),
+
+ /// Failed to get content length
+ #[error("Failed to request download")]
+ HeadRequest(#[source] reqwest::Error),
+
+ /// Server returned error status
+ #[error("Download failed: {0}")]
+ HttpStatus(reqwest::StatusCode),
+
+ /// Invalid content length header
+ #[error("Invalid content length header: {0}")]
+ InvalidContentLength(&'static str),
+
+ /// Failed to make range request
+ #[error("Failed to retrieve range")]
+ RangeRequest(#[source] reqwest::Error),
+
+ /// Failed to read chunk
+ #[error("Failed to read chunk")]
+ ChunkRead(#[source] reqwest::Error),
+
+ /// Failed to write chunk
+ #[error("Failed to write chunk")]
+ ChunkWrite(#[source] io::Error),
+
+ /// Failed to get stream position
+ #[error("Failed to get existing file size")]
+ StreamPosition(#[source] io::Error),
+
+ /// Failed to flush writer
+ #[error("Failed to flush writer")]
+ Flush(#[source] io::Error),
+
+ /// Size validation error
+ #[error("Size validation failed: {0}")]
+ SizeValidation(String),
+
+ /// File operation error
+ #[error("File operation failed: {0}")]
+ FileOperation(#[source] io::Error),
+
+ /// Other error
+ #[error("{0}")]
+ Other(&'static str),
+}
+
+impl DownloadError {
+ /// Checks if the error is caused by a timeout or network issue that can be retried
+ pub fn should_retry(&self) -> bool {
+ match self {
+ DownloadError::HeadRequest(e)
+ | DownloadError::RangeRequest(e)
+ | DownloadError::ChunkRead(e)
+ | DownloadError::ClientInitialization(e) => is_network_error(e),
+ DownloadError::HttpStatus(status) => {
+ // Retry server errors and timeout status
+ status.is_server_error() || *status == reqwest::StatusCode::REQUEST_TIMEOUT
+ }
+ // Don't retry other types of errors
+ _ => false,
+ }
+ }
+}
+
+/// Checks if the error is a network-related error that can be retried
+fn is_network_error(error: &reqwest::Error) -> bool {
+ // Retry on timeout errors
+ // Retry on connection errors (which often happen when switching networks)
+ // Retry on request errors (like "connection reset")
+ if error.is_timeout() || error.is_connect() || error.is_request() {
+ return true;
+ }
+
+ let mut error = error as &dyn Error;
+ loop {
+ if let Some(io_err) = error.downcast_ref::<std::io::Error>() {
+ // Check if the error is a timeout or connection error
+ if io_err.kind() == io::ErrorKind::TimedOut
+ || io_err.kind() == io::ErrorKind::ConnectionReset
+ {
+ return true;
+ }
+ }
+ if let Some(source) = error.source() {
+ error = source;
+ } else {
+ break false;
+ }
+ }
+}
/// Receiver of the current progress so far
pub trait ProgressUpdater: Send + 'static {
@@ -67,9 +172,26 @@ pub async fn get_to_file(
progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint,
) -> anyhow::Result<()> {
- let file = create_or_append(file).await?;
- let file = BufWriter::new(file);
- get_to_writer(file, url, progress_updater, size_hint).await
+ let file = create_or_append(file)
+ .await
+ .map_err(DownloadError::FileOperation)?;
+ let mut file = BufWriter::new(file);
+ let mut attempts = 0;
+ let mut read_timeout = READ_TIMEOUT;
+ while let Err(err) =
+ get_to_writer(&mut file, url, progress_updater, size_hint, read_timeout).await
+ {
+ if !err.should_retry() {
+ anyhow::bail!(err);
+ }
+ attempts += 1;
+ read_timeout *= 2;
+ if attempts >= MAX_RETRY_ATTEMPTS {
+ anyhow::bail!("Max retry attempts reached: {err}");
+ }
+ log::warn!("Download failed: {err}. Retrying in with timeout: {read_timeout:?}");
+ }
+ Ok(())
}
/// Download `url` to `writer`.
@@ -82,41 +204,59 @@ pub async fn get_to_writer(
url: &str,
progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint,
-) -> anyhow::Result<()> {
- let client = reqwest::Client::new();
+ read_timeout: Duration,
+) -> Result<(), DownloadError> {
+ // Create a new client for each download attempt to prevent stale connections
+ let client = reqwest::Client::builder()
+ .read_timeout(read_timeout)
+ .connect_timeout(CONNECT_TIMEOUT)
+ .build()
+ .map_err(DownloadError::ClientInitialization)?;
progress_updater.set_url(url);
- progress_updater.set_progress(0.);
// Fetch content length first
- let response = client.head(url).send().await.context("HEAD failed")?;
+ let response = client
+ .head(url)
+ .send()
+ .await
+ .map_err(DownloadError::HeadRequest)?;
+
if !response.status().is_success() {
- return response
- .error_for_status()
- .map(|_| ())
- .context("Download failed");
+ return Err(DownloadError::HttpStatus(response.status()));
}
let total_size = response
.headers()
.get(CONTENT_LENGTH)
- .context("Missing file size")?;
- let total_size: usize = total_size.to_str()?.parse().context("invalid size")?;
- size_hint.check_size(total_size)?;
+ .ok_or_else(|| DownloadError::InvalidContentLength("Missing file size"))?;
+
+ let total_size: usize = total_size
+ .to_str()
+ .map_err(|_| DownloadError::InvalidContentLength("Invalid content length header"))?
+ .parse()
+ .map_err(|_| DownloadError::InvalidContentLength("Invalid size format"))?;
+
+ match size_hint.check_size(total_size) {
+ Ok(_) => {}
+ Err(e) => return Err(DownloadError::SizeValidation(e.to_string())),
+ }
let already_fetched_bytes = writer
.stream_position()
.await
- .context("failed to get existing file size")?
+ .map_err(DownloadError::StreamPosition)?
.try_into()
- .context("invalid size")?;
+ .map_err(|_| DownloadError::Other("Invalid file position"))?;
+ progress_updater.set_progress(already_fetched_bytes as f32 / total_size as f32);
if total_size == already_fetched_bytes {
- progress_updater.set_progress(1.);
return Ok(());
}
if already_fetched_bytes > total_size {
- anyhow::bail!("Found existing file that was larger");
+ return Err(DownloadError::SizeValidation(
+ "Found existing file that was larger".to_string(),
+ ));
}
// Fetch content, one range at a time
@@ -133,32 +273,32 @@ pub async fn get_to_writer(
.header(RANGE, range)
.send()
.await
- .context("Failed to retrieve range")?;
+ .map_err(DownloadError::RangeRequest)?;
+
let status = response.status();
if !status.is_success() {
- return response
- .error_for_status()
- .map(|_| ())
- .context("Download failed");
+ return Err(DownloadError::HttpStatus(status));
}
let mut bytes_read = 0;
- while let Some(chunk) = response.chunk().await.context("Failed to read chunk")? {
+ while let Some(chunk) = response.chunk().await.map_err(DownloadError::ChunkRead)? {
bytes_read += chunk.len();
if bytes_read > total_size - already_fetched_bytes {
// Protect against servers responding with more data than expected
- anyhow::bail!("Server returned more than requested bytes");
+ return Err(DownloadError::SizeValidation(
+ "Server returned more than requested bytes".to_string(),
+ ));
}
writer
.write_all(&chunk)
.await
- .context("Failed to write chunk")?;
+ .map_err(DownloadError::ChunkWrite)?;
}
}
- writer.shutdown().await.context("Failed to flush")?;
+ writer.shutdown().await.map_err(DownloadError::Flush)?;
Ok(())
}
@@ -261,6 +401,7 @@ impl<PU: ProgressUpdater, Writer: AsyncWrite + Unpin> AsyncWrite
mod test {
use std::io::Cursor;
+ use anyhow::Context;
use async_tempfile::TempDir;
use rand::RngCore;
use tokio::{fs, io::AsyncWriteExt};
@@ -344,6 +485,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
+ READ_TIMEOUT,
)
.await
.context("Complete download failed")?;
@@ -378,6 +520,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
+ READ_TIMEOUT,
)
.await
.expect_err("Expected interrupted download");
@@ -408,6 +551,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
+ READ_TIMEOUT,
)
.await
.context("Partial download failed")?;
@@ -468,6 +612,7 @@ mod test {
&file_url,
&mut FakeProgressUpdater::default(),
SizeHint::Exact(1),
+ READ_TIMEOUT,
)
.await
.expect_err("Reject unexpected content length");
@@ -492,6 +637,7 @@ mod test {
&file_url,
&mut FakeProgressUpdater::default(),
SizeHint::Exact(file_data.len()),
+ READ_TIMEOUT,
)
.await
.expect_err("Reject unexpected chunk sizes");
diff --git a/test/Cargo.lock b/test/Cargo.lock
index 2d77e49eba..1ccf04e2e7 100644
--- a/test/Cargo.lock
+++ b/test/Cargo.lock
@@ -2217,6 +2217,7 @@ dependencies = [
"ed25519-dalek",
"hex",
"json-canon",
+ "log",
"mullvad-version",
"reqwest",
"serde",