summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-api/src/access.rs231
-rw-r--r--mullvad-api/src/bin/relay_list.rs2
-rw-r--r--mullvad-api/src/device.rs30
-rw-r--r--mullvad-api/src/lib.rs36
-rw-r--r--mullvad-api/src/rest.rs99
-rw-r--r--mullvad-daemon/src/device/service.rs12
-rw-r--r--mullvad-daemon/src/geoip.rs3
-rw-r--r--mullvad-daemon/src/management_interface.rs2
8 files changed, 262 insertions, 153 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs
index a3bec3f725..67c83ac4da 100644
--- a/mullvad-api/src/access.rs
+++ b/mullvad-api/src/access.rs
@@ -2,109 +2,180 @@ use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
};
+use futures::{
+ channel::{mpsc, oneshot},
+ StreamExt,
+};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
-use std::{
- collections::HashMap,
- sync::{Arc, Mutex},
-};
-use talpid_types::ErrorExt;
+use std::collections::HashMap;
+use tokio::select;
pub const AUTH_URL_PREFIX: &str = "auth/v1";
#[derive(Clone)]
-pub struct AccessTokenProxy {
- service: RequestServiceHandle,
- factory: RequestFactory,
- access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
+pub struct AccessTokenStore {
+ tx: mpsc::UnboundedSender<StoreAction>,
+}
+
+enum StoreAction {
+ /// Request an access token for `AccountToken`, or return a saved one if it's not expired.
+ GetAccessToken(
+ AccountToken,
+ oneshot::Sender<Result<AccessToken, rest::Error>>,
+ ),
+ /// Forget cached access token for `AccountToken`, and drop any in-flight requests
+ InvalidateToken(AccountToken),
+}
+
+#[derive(Default)]
+struct AccountState {
+ current_access_token: Option<AccessTokenData>,
+ inflight_request: Option<tokio::task::JoinHandle<()>>,
+ response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}
-impl AccessTokenProxy {
+impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
- Self {
- service,
- factory,
- access_from_account: Arc::new(Mutex::new(HashMap::new())),
+ let (tx, rx) = mpsc::unbounded();
+ tokio::spawn(Self::service_requests(rx, service, factory));
+ Self { tx }
+ }
+
+ async fn service_requests(
+ mut rx: mpsc::UnboundedReceiver<StoreAction>,
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ ) {
+ let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new();
+
+ let (completed_tx, mut completed_rx) = mpsc::unbounded();
+
+ loop {
+ select! {
+ action = rx.next() => {
+ let Some(action) = action else {
+ // We're done
+ break;
+ };
+
+ match action {
+ StoreAction::GetAccessToken(account, response_tx) => {
+ let account_state = account_states
+ .entry(account.clone())
+ .or_default();
+
+ // If there is an unexpired access token, just return it.
+ // Otherwise, generate a new token
+ if let Some(ref access_token) = account_state.current_access_token {
+ if !access_token.is_expired() {
+ log::trace!("Using stored access token");
+ let _ = response_tx.send(Ok(access_token.access_token.clone()));
+ continue;
+ }
+
+ log::debug!("Replacing expired access token");
+ account_state.current_access_token = None;
+ }
+
+ // Begin requesting an access token if it's not already underway.
+ // If there's already an inflight request, just save `response_tx`
+ account_state
+ .inflight_request
+ .get_or_insert_with(|| {
+ let completed_tx = completed_tx.clone();
+ let account = account.clone();
+ let service = service.clone();
+ let factory = factory.clone();
+
+ log::debug!("Fetching access token for an account");
+
+ tokio::spawn(async move {
+ let result = fetch_access_token(service, factory, account.clone()).await;
+ let _ = completed_tx.unbounded_send((account, result));
+ })
+ });
+
+ // Save the channel to respond to later
+ account_state.response_channels.push(response_tx);
+ }
+ StoreAction::InvalidateToken(account) => {
+ let account_state = account_states
+ .entry(account)
+ .or_default();
+
+ // Drop in-flight requests for the account
+ // & forget any existing access token
+
+ log::debug!("Invalidating access token for an account");
+
+ if let Some(task) = account_state.inflight_request.take() {
+ task.abort();
+ let _ = task.await;
+ }
+
+ account_state.response_channels.clear();
+ account_state.current_access_token = None;
+ }
+ }
+ }
+
+ Some((account, result)) = completed_rx.next() => {
+ let account_state = account_states
+ .entry(account)
+ .or_default();
+
+ account_state.inflight_request = None;
+
+ // Send response to all channels
+ for tx in account_state.response_channels.drain(..) {
+ let _ = tx.send(result.clone().map(|data| data.access_token));
+ }
+
+ if let Ok(access_token) = result {
+ account_state.current_access_token = Some(access_token);
+ }
+ }
+ }
}
}
/// Obtain access token for an account, requesting a new one from the API if necessary.
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
- let existing_token = {
- self.access_from_account
- .lock()
- .unwrap()
- .get(account.as_str())
- .cloned()
- };
- if let Some(access_token) = existing_token {
- if access_token.is_expired() {
- log::debug!("Replacing expired access token");
- return self.request_new_token(account.clone()).await;
- }
- log::trace!("Using stored access token");
- return Ok(access_token.access_token.clone());
- }
- self.request_new_token(account.clone()).await
+ let (tx, rx) = oneshot::channel();
+ let _ = self
+ .tx
+ .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx));
+ rx.await.map_err(|_| rest::Error::Aborted)?
}
/// Remove an access token if the API response calls for it.
- pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
+ pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) {
if let Err(rest::Error::ApiError(_status, code)) = response {
if code == crate::INVALID_ACCESS_TOKEN {
- log::debug!("Dropping invalid access token");
- self.remove_token(account);
+ let _ = self
+ .tx
+ .unbounded_send(StoreAction::InvalidateToken(account.to_owned()));
}
}
}
+}
- /// Removes a stored access token.
- fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
- self.access_from_account
- .lock()
- .unwrap()
- .remove(account)
- .map(|v| v.access_token)
- }
-
- async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
- log::debug!("Fetching access token for an account");
- let access_token = self
- .fetch_access_token(account.clone())
- .await
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to obtain access token")
- );
- error
- })?;
- self.access_from_account
- .lock()
- .unwrap()
- .insert(account, access_token.clone());
- Ok(access_token.access_token)
+async fn fetch_access_token(
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ account_token: AccountToken,
+) -> Result<AccessTokenData, rest::Error> {
+ #[derive(serde::Serialize)]
+ struct AccessTokenRequest {
+ account_number: String,
}
+ let request = AccessTokenRequest {
+ account_number: account_token,
+ };
- async fn fetch_access_token(
- &self,
- account_token: AccountToken,
- ) -> Result<AccessTokenData, rest::Error> {
- #[derive(serde::Serialize)]
- struct AccessTokenRequest {
- account_number: String,
- }
- let request = AccessTokenRequest {
- account_number: account_token,
- };
-
- let service = self.service.clone();
-
- let rest_request = self
- .factory
- .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
- let response = service.request(rest_request).await?;
- let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
- rest::deserialize_body(response).await
- }
+ let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
+ let response = service.request(rest_request).await?;
+ let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(response).await
}
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index 2139e51f54..ffb65c28b2 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -21,7 +21,7 @@ async fn main() {
let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
- Err(RestError::TimeoutError(_)) => {
+ Err(RestError::TimeoutError) => {
eprintln!("Request timed out");
process::exit(2);
}
diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs
index 585e863ccf..3d8913e366 100644
--- a/mullvad-api/src/device.rs
+++ b/mullvad-api/src/device.rs
@@ -54,17 +54,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
+
let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::POST,
&submission,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::CREATED],
)
.await;
+ access_proxy.check_response(&account, &response);
+
let response: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
id,
@@ -102,16 +106,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::GET,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
- rest::deserialize_body(response?).await
+ access_proxy.check_response(&account, &response);
+ let device = rest::deserialize_body(response?).await?;
+ Ok(device)
}
}
@@ -123,16 +130,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::GET,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
- rest::deserialize_body(response?).await
+ access_proxy.check_response(&account, &response);
+ let devices = rest::deserialize_body(response?).await?;
+ Ok(devices)
}
}
@@ -145,15 +155,17 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::DELETE,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::NO_CONTENT],
)
.await;
+ access_proxy.check_response(&account, &response);
response?;
Ok(())
@@ -178,17 +190,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
+
let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
Method::PUT,
&req_body,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
+ access_proxy.check_response(&account, &response);
+
let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
ipv4_address,
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 63f5c2ad5b..6beb1c8e86 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -394,15 +394,17 @@ impl AccountsProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"),
Method::GET,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
+ access_proxy.check_response(&account, &response);
let account: AccountExpiryResponse = rest::deserialize_body(response?).await?;
Ok(account.expiry)
@@ -447,16 +449,21 @@ impl AccountsProxy {
let submission = VoucherSubmission { voucher_code };
async move {
+ let access_token = access_proxy.get_token(&account_token).await?;
+
let response = rest::send_json_request(
&factory,
service,
&format!("{APP_URL_PREFIX}/submit-voucher"),
Method::POST,
&submission,
- Some((access_proxy, account_token)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
+
+ access_proxy.check_response(&account_token, &response);
+
rest::deserialize_body(response?).await
}
}
@@ -464,7 +471,7 @@ impl AccountsProxy {
#[cfg(target_os = "android")]
pub fn init_play_purchase(
&mut self,
- account_token: AccountToken,
+ account: AccountToken,
) -> impl Future<Output = Result<PlayPurchasePaymentToken, rest::Error>> {
#[derive(serde::Deserialize)]
struct PlayPurchaseInitResponse {
@@ -476,17 +483,21 @@ impl AccountsProxy {
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
+
let response = rest::send_json_request(
&factory,
service,
&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"),
Method::POST,
&(),
- Some((access_proxy, account_token)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
+ access_proxy.check_response(&account, &response);
+
let PlayPurchaseInitResponse { obfuscated_id } =
rest::deserialize_body(response?).await?;
@@ -497,7 +508,7 @@ impl AccountsProxy {
#[cfg(target_os = "android")]
pub fn verify_play_purchase(
&mut self,
- account_token: AccountToken,
+ account: AccountToken,
play_purchase: PlayPurchase,
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
@@ -505,16 +516,21 @@ impl AccountsProxy {
let access_proxy = self.handle.token_store.clone();
async move {
- rest::send_json_request(
+ let access_token = access_proxy.get_token(&account).await?;
+
+ let response = rest::send_json_request(
&factory,
service,
&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"),
Method::POST,
&play_purchase,
- Some((access_proxy, account_token)),
+ Some(access_token),
&[StatusCode::ACCEPTED],
)
- .await?;
+ .await;
+
+ access_proxy.check_response(&account, &response);
+ response?;
Ok(())
}
}
@@ -533,15 +549,17 @@ impl AccountsProxy {
let access_proxy = self.handle.token_store.clone();
async move {
+ let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{APP_URL_PREFIX}/www-auth-token"),
Method::POST,
- Some((access_proxy, account)),
+ Some(access_token),
&[StatusCode::OK],
)
.await;
+ access_proxy.check_response(&account, &response);
let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
Ok(response.auth_token)
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index c3687a1eee..3690f1450c 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -1,7 +1,7 @@
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
- access::AccessTokenProxy,
+ access::AccessTokenStore,
address_cache::AddressCache,
availability::ApiAvailabilityHandle,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
@@ -44,25 +44,25 @@ pub type Result<T> = std::result::Result<T, Error>;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Describes all the ways a REST request can fail
-#[derive(err_derive::Error, Debug)]
+#[derive(err_derive::Error, Debug, Clone)]
pub enum Error {
#[error(display = "Request cancelled")]
Aborted,
#[error(display = "Hyper error")]
- HyperError(#[error(source)] hyper::Error),
+ HyperError(#[error(source)] Arc<hyper::Error>),
#[error(display = "Invalid header value")]
- InvalidHeaderError(#[error(source)] http::header::InvalidHeaderValue),
+ InvalidHeaderError,
#[error(display = "HTTP error")]
- HttpError(#[error(source)] http::Error),
+ HttpError(#[error(source)] Arc<http::Error>),
#[error(display = "Request timed out")]
- TimeoutError(#[error(source)] tokio::time::error::Elapsed),
+ TimeoutError,
#[error(display = "Failed to deserialize data")]
- DeserializeError(#[error(source)] serde_json::Error),
+ DeserializeError(#[error(source)] Arc<serde_json::Error>),
#[error(display = "Failed to send request to rest client")]
SendError,
@@ -76,7 +76,7 @@ pub enum Error {
/// The string given was not a valid URI.
#[error(display = "Not a valid URI")]
- UriError(#[error(source)] http::uri::InvalidUri),
+ InvalidUri,
/// A new API config was requested, but the request could not be completed.
#[error(display = "Failed to rotate API config")]
@@ -85,7 +85,7 @@ pub enum Error {
impl Error {
pub fn is_network_error(&self) -> bool {
- matches!(self, Error::HyperError(_) | Error::TimeoutError(_))
+ matches!(self, Error::HyperError(_) | Error::TimeoutError)
}
pub fn is_aborted(&self) -> bool {
@@ -203,7 +203,7 @@ impl<
let future = async move {
let response = tokio::time::timeout(timeout, request_future)
.await
- .map_err(Error::TimeoutError);
+ .map_err(|_| Error::TimeoutError);
let response = flatten_result(response).map_err(|error| error.map_aborted());
@@ -314,20 +314,20 @@ pub struct RestRequest {
impl RestRequest {
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
pub fn get(uri: &str) -> Result<Self> {
- let uri = hyper::Uri::from_str(uri).map_err(Error::UriError)?;
+ let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?;
let mut builder = http::request::Builder::new()
.method(Method::GET)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"));
if let Some(host) = uri.host() {
- builder = builder.header(header::HOST, HeaderValue::from_str(host)?);
+ builder = builder.header(
+ header::HOST,
+ HeaderValue::from_str(host).map_err(|_| Error::InvalidHeaderError)?,
+ );
};
- let request = builder
- .uri(uri)
- .body(hyper::Body::empty())
- .map_err(Error::HttpError)?;
+ let request = builder.uri(uri).body(hyper::Body::empty())?;
Ok(RestRequest {
timeout: DEFAULT_TIMEOUT,
@@ -341,7 +341,7 @@ impl RestRequest {
let header = match auth {
Some(auth) => Some(
HeaderValue::from_str(&format!("Bearer {auth}"))
- .map_err(Error::InvalidHeaderError)?,
+ .map_err(|_| Error::InvalidHeaderError)?,
),
None => None,
};
@@ -361,7 +361,8 @@ impl RestRequest {
}
pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> {
- let header_value = http::HeaderValue::from_str(value).map_err(Error::InvalidHeaderError)?;
+ let header_value =
+ http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?;
self.request.headers_mut().insert(key, header_value);
Ok(())
}
@@ -458,7 +459,8 @@ impl RequestFactory {
let headers = request.headers_mut();
headers.insert(
header::CONTENT_LENGTH,
- HeaderValue::from_str(&body_length.to_string()).map_err(Error::InvalidHeaderError)?,
+ HeaderValue::from_str(&body_length.to_string())
+ .map_err(|_| Error::InvalidHeaderError)?,
);
headers.insert(
header::CONTENT_TYPE,
@@ -483,13 +485,14 @@ impl RequestFactory {
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
.header(header::HOST, self.hostname.clone());
- request.body(hyper::Body::empty()).map_err(Error::HttpError)
+ let result = request.body(hyper::Body::empty())?;
+ Ok(result)
}
fn get_uri(&self, path: &str) -> Result<Uri> {
let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
- hyper::Uri::from_str(&uri).map_err(Error::UriError)
+ hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri)
}
fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest {
@@ -503,25 +506,16 @@ pub fn send_request(
service: RequestServiceHandle,
uri: &str,
method: Method,
- auth: Option<(AccessTokenProxy, AccountToken)>,
+ access_token: Option<AccountToken>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
let request = factory.request(uri, method);
async move {
let mut request = request?;
- if let Some((store, account)) = &auth {
- let access_token = store.get_token(account).await?;
- request.set_auth(Some(access_token))?;
- }
+ request.set_auth(access_token)?;
let response = service.request(request).await?;
- let result = parse_rest_response(response, expected_statuses).await;
-
- if let Some((store, account)) = &auth {
- store.check_response(account, &result);
- }
-
- result
+ parse_rest_response(response, expected_statuses).await
}
}
@@ -531,24 +525,15 @@ pub fn send_json_request<B: serde::Serialize>(
uri: &str,
method: Method,
body: &B,
- auth: Option<(AccessTokenProxy, AccountToken)>,
+ access_token: Option<AccountToken>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
let request = factory.json_request(method, uri, body);
async move {
let mut request = request?;
- if let Some((store, account)) = &auth {
- let access_token = store.get_token(account).await?;
- request.set_auth(Some(access_token))?;
- }
+ request.set_auth(access_token)?;
let response = service.request(request).await?;
- let result = parse_rest_response(response, expected_statuses).await;
-
- if let Some((store, account)) = &auth {
- store.check_response(account, &result);
- }
-
- result
+ parse_rest_response(response, expected_statuses).await
}
}
@@ -566,7 +551,7 @@ async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
body.extend(&chunk?);
}
- serde_json::from_slice(&body).map_err(Error::DeserializeError)
+ serde_json::from_slice(&body).map_err(Error::from)
}
fn get_body_length(response: &Response) -> usize {
@@ -639,7 +624,7 @@ pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
pub availability: ApiAvailabilityHandle,
- pub token_store: AccessTokenProxy,
+ pub token_store: AccessTokenStore,
}
impl MullvadRestHandle {
@@ -649,7 +634,7 @@ impl MullvadRestHandle {
address_cache: AddressCache,
availability: ApiAvailabilityHandle,
) -> Self {
- let token_store = AccessTokenProxy::new(service.clone(), factory.clone());
+ let token_store = AccessTokenStore::new(service.clone(), factory.clone());
let handle = Self {
service,
@@ -728,3 +713,21 @@ fn flatten_result<T, E>(
Err(err) => Err(err),
}
}
+
+impl From<hyper::Error> for Error {
+ fn from(value: hyper::Error) -> Self {
+ Error::HyperError(Arc::new(value))
+ }
+}
+
+impl From<serde_json::Error> for Error {
+ fn from(value: serde_json::Error) -> Self {
+ Error::DeserializeError(Arc::new(value))
+ }
+}
+
+impl From<http::Error> for Error {
+ fn from(value: http::Error) -> Self {
+ Error::HttpError(Arc::new(value))
+ }
+}
diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs
index fdda61297f..a56cf10b48 100644
--- a/mullvad-daemon/src/device/service.rs
+++ b/mullvad-daemon/src/device/service.rs
@@ -14,7 +14,7 @@ use talpid_types::net::wireguard::PrivateKey;
use super::{Error, PrivateAccountAndDevice, PrivateDevice};
use mullvad_api::{
availability::ApiAvailabilityHandle,
- rest::{self, Error as RestError, MullvadRestHandle},
+ rest::{self, MullvadRestHandle},
AccountsProxy, DevicesProxy,
};
use talpid_core::future_retry::{retry_future, ConstantInterval, ExponentialBackoff, Jittered};
@@ -402,7 +402,7 @@ pub fn spawn_account_service(
}
fn handle_expiry_result_inner(
- result: &Result<chrono::DateTime<chrono::Utc>, mullvad_api::rest::Error>,
+ result: &Result<chrono::DateTime<chrono::Utc>, rest::Error>,
api_availability: &ApiAvailabilityHandle,
) -> bool {
match result {
@@ -425,18 +425,18 @@ fn handle_expiry_result_inner(
}
}
-fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
+fn should_retry<T>(result: &Result<T, rest::Error>, api_handle: &ApiAvailabilityHandle) -> bool {
match result {
Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
_ => false,
}
}
-fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
+fn should_retry_backoff<T>(result: &Result<T, rest::Error>) -> bool {
match result {
Ok(_) => false,
Err(error) => {
- if let RestError::ApiError(status, code) = error {
+ if let rest::Error::ApiError(status, code) = error {
*status != rest::StatusCode::NOT_FOUND
&& code != mullvad_api::DEVICE_NOT_FOUND
&& code != mullvad_api::INVALID_ACCOUNT
@@ -451,7 +451,7 @@ fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
fn map_rest_error(error: rest::Error) -> Error {
match error {
- RestError::ApiError(_status, ref code) => match code.as_str() {
+ rest::Error::ApiError(_status, ref code) => match code.as_str() {
// TODO: Implement invalid payment
mullvad_api::DEVICE_NOT_FOUND => Error::InvalidDevice,
mullvad_api::INVALID_ACCOUNT => Error::InvalidAccount,
diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs
index 527e06cf61..2619939215 100644
--- a/mullvad-daemon/src/geoip.rs
+++ b/mullvad-daemon/src/geoip.rs
@@ -89,10 +89,11 @@ async fn send_location_request_internal(
}
fn log_network_error(err: Error, version: &'static str) {
+ use std::sync::Arc;
let err_message = &format!("Unable to fetch {version} GeoIP location");
match err {
Error::HyperError(hyper_err) if hyper_err.is_connect() => {
- if let Some(cause) = hyper_err.into_cause() {
+ if let Some(cause) = Arc::into_inner(hyper_err).and_then(|x| x.into_cause()) {
if let Some(err) = cause.downcast_ref::<std::io::Error>() {
// Don't log ENETUNREACH errors, they are not informative.
if err.raw_os_error() == Some(libc::ENETUNREACH) {
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 993f0f9ece..61e4b025ba 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -1088,7 +1088,7 @@ fn map_rest_error(error: &RestError) -> Status {
{
Status::new(Code::Unauthenticated, message)
}
- RestError::TimeoutError(_elapsed) => Status::deadline_exceeded("API request timed out"),
+ RestError::TimeoutError => Status::deadline_exceeded("API request timed out"),
RestError::HyperError(_) => Status::unavailable("Cannot reach the API"),
error => Status::unknown(format!("REST error: {error}")),
}