summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-10-30 23:16:14 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-10-30 23:16:14 +0100
commitdf106cda61bf605f29eecc2b06bf353120a9f7c8 (patch)
tree15423286e42d010eba35594b62c916064e0923d1
parent1c24b27d2f4c1bf3f6f497f43d77d23d938aa20c (diff)
parentecf3a202d601ce2d4763e0037d01a706498b38f3 (diff)
downloadmullvadvpn-df106cda61bf605f29eecc2b06bf353120a9f7c8.tar.xz
mullvadvpn-df106cda61bf605f29eecc2b06bf353120a9f7c8.zip
Merge branch 'simplify-rest-client'
-rw-r--r--mullvad-api/src/access.rs14
-rw-r--r--mullvad-api/src/device.rs110
-rw-r--r--mullvad-api/src/lib.rs173
-rw-r--r--mullvad-api/src/relay_list.rs17
-rw-r--r--mullvad-api/src/rest.rs442
-rw-r--r--mullvad-daemon/src/geoip.rs5
-rw-r--r--mullvad-daemon/src/version_check.rs2
7 files changed, 314 insertions, 449 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs
index 67c83ac4da..276cc1f561 100644
--- a/mullvad-api/src/access.rs
+++ b/mullvad-api/src/access.rs
@@ -1,6 +1,7 @@
use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
+ API,
};
use futures::{
channel::{mpsc, oneshot},
@@ -13,7 +14,7 @@ use tokio::select;
pub const AUTH_URL_PREFIX: &str = "auth/v1";
-#[derive(Clone)]
+#[derive(Debug, Clone)]
pub struct AccessTokenStore {
tx: mpsc::UnboundedSender<StoreAction>,
}
@@ -36,7 +37,8 @@ struct AccountState {
}
impl AccessTokenStore {
- pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
+ pub(crate) fn new(service: RequestServiceHandle) -> Self {
+ let factory = rest::RequestFactory::new(&API.host, None);
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
@@ -174,8 +176,8 @@ async fn fetch_access_token(
account_number: account_token,
};
- let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
- let response = service.request(rest_request).await?;
- let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
- rest::deserialize_body(response).await
+ let rest_request = factory
+ .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?
+ .expected_status(&[StatusCode::OK]);
+ service.request(rest_request).await?.deserialize().await
}
diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs
index 3d8913e366..37410e99b3 100644
--- a/mullvad-api/src/device.rs
+++ b/mullvad-api/src/device.rs
@@ -1,5 +1,5 @@
use chrono::{DateTime, Utc};
-use http::{Method, StatusCode};
+use http::StatusCode;
use mullvad_types::{
account::AccountToken,
device::{Device, DeviceId, DeviceName},
@@ -51,25 +51,13 @@ impl DevicesProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices"),
- Method::POST,
- &submission,
- Some(access_token),
- &[StatusCode::CREATED],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
-
- let response: DeviceResponse = rest::deserialize_body(response?).await?;
+ let request = factory
+ .post_json(&format!("{ACCOUNTS_URL_PREFIX}/devices"), &submission)?
+ .account(account)?
+ .expected_status(&[StatusCode::CREATED]);
+ let response = service.request(request).await?;
let DeviceResponse {
id,
name,
@@ -79,7 +67,7 @@ impl DevicesProxy {
hijack_dns,
created,
..
- } = response;
+ } = response.deserialize().await?;
Ok((
Device {
@@ -104,21 +92,12 @@ impl DevicesProxy {
) -> impl Future<Output = Result<Device, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let device = rest::deserialize_body(response?).await?;
- Ok(device)
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ service.request(request).await?.deserialize().await
}
}
@@ -128,21 +107,12 @@ impl DevicesProxy {
) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let devices = rest::deserialize_body(response?).await?;
- Ok(devices)
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/device"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ service.request(request).await?.deserialize().await
}
}
@@ -153,21 +123,12 @@ impl DevicesProxy {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
- Method::DELETE,
- Some(access_token),
- &[StatusCode::NO_CONTENT],
- )
- .await;
- access_proxy.check_response(&account, &response);
-
- response?;
+ let request = factory
+ .delete(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))?
+ .expected_status(&[StatusCode::NO_CONTENT])
+ .account(account)?;
+ service.request(request).await?;
Ok(())
}
}
@@ -187,30 +148,21 @@ impl DevicesProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
- Method::PUT,
- &req_body,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
-
- let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
+ let request = factory
+ .put_json(
+ &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
+ &req_body,
+ )?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await?;
let DeviceResponse {
ipv4_address,
ipv6_address,
..
- } = updated_device;
+ } = response.deserialize().await?;
Ok(mullvad_types::wireguard::AssociatedAddresses {
ipv4_address,
ipv6_address,
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 6beb1c8e86..91e2bc524a 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -340,7 +340,8 @@ impl Runtime {
self.socket_bypass_tx.clone(),
)
.await;
- let factory = rest::RequestFactory::new(API.host.clone(), None);
+ let token_store = access::AccessTokenStore::new(service.clone());
+ let factory = rest::RequestFactory::new(&API.host, Some(token_store));
rest::MullvadRestHandle::new(
service,
@@ -392,21 +393,13 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
-
- let account: AccountExpiryResponse = rest::deserialize_body(response?).await?;
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await?;
+ let account: AccountExpiryResponse = response.deserialize().await?;
Ok(account.expiry)
}
}
@@ -418,24 +411,21 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
- let response = rest::send_request(
- &self.handle.factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/accounts"),
- Method::POST,
- None,
- &[StatusCode::CREATED],
- );
+ let factory = self.handle.factory.clone();
async move {
- let account: AccountCreationResponse = rest::deserialize_body(response.await?).await?;
+ let request = factory
+ .post(&format!("{ACCOUNTS_URL_PREFIX}/accounts"))?
+ .expected_status(&[StatusCode::CREATED]);
+ let response = service.request(request).await?;
+ let account: AccountCreationResponse = response.deserialize().await?;
Ok(account.number)
}
}
pub fn submit_voucher(
&mut self,
- account_token: AccountToken,
+ account: AccountToken,
voucher_code: String,
) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> {
#[derive(serde::Serialize)]
@@ -445,26 +435,14 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
let submission = VoucherSubmission { voucher_code };
async move {
- let access_token = access_proxy.get_token(&account_token).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{APP_URL_PREFIX}/submit-voucher"),
- Method::POST,
- &submission,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account_token, &response);
-
- rest::deserialize_body(response?).await
+ let request = factory
+ .post_json(&format!("{APP_URL_PREFIX}/submit-voucher"), &submission)?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ service.request(request).await?.deserialize().await
}
}
@@ -480,26 +458,15 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"),
- Method::POST,
- &(),
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
+ let request = factory
+ .post_json(&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), &())?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ let response = service.request(request).await?;
- let PlayPurchaseInitResponse { obfuscated_id } =
- rest::deserialize_body(response?).await?;
+ let PlayPurchaseInitResponse { obfuscated_id } = response.deserialize().await?;
Ok(obfuscated_id)
}
@@ -513,24 +480,16 @@ impl AccountsProxy {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"),
- Method::POST,
- &play_purchase,
- Some(access_token),
- &[StatusCode::ACCEPTED],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
- response?;
+ let request = factory
+ .post_json(
+ &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"),
+ &play_purchase,
+ )?
+ .account(account)?
+ .expected_status(&[StatusCode::ACCEPTED]);
+ service.request(request).await?;
Ok(())
}
}
@@ -546,21 +505,14 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{APP_URL_PREFIX}/www-auth-token"),
- Method::POST,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
+ let request = factory
+ .post(&format!("{APP_URL_PREFIX}/www-auth-token"))?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ let response = service.request(request).await?;
+ let response: AuthTokenResponse = response.deserialize().await?;
Ok(response.auth_token)
}
}
@@ -598,19 +550,13 @@ impl ProblemReportProxy {
};
let service = self.handle.service.clone();
-
- let request = rest::send_json_request(
- &self.handle.factory,
- service,
- &format!("{APP_URL_PREFIX}/problem-report"),
- Method::POST,
- &report,
- None,
- &[StatusCode::NO_CONTENT],
- );
+ let factory = self.handle.factory.clone();
async move {
- request.await?;
+ let request = factory
+ .post_json(&format!("{APP_URL_PREFIX}/problem-report"), &report)?
+ .expected_status(&[StatusCode::NO_CONTENT]);
+ service.request(request).await?;
Ok(())
}
}
@@ -646,12 +592,11 @@ impl AppVersionProxy {
let request = self.handle.factory.request(&path, Method::GET);
async move {
- let mut request = request?;
- request.add_header("M-Platform-Version", &platform_version)?;
-
+ let request = request?
+ .expected_status(&[StatusCode::OK])
+ .header("M-Platform-Version", &platform_version)?;
let response = service.request(request).await?;
- let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
- rest::deserialize_body(parsed_response).await
+ response.deserialize().await
}
}
}
@@ -667,18 +612,12 @@ impl ApiProxy {
}
pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> {
- let service = self.handle.service.clone();
-
- let response = rest::send_request(
- &self.handle.factory,
- service,
- &format!("{APP_URL_PREFIX}/api-addrs"),
- Method::GET,
- None,
- &[StatusCode::OK],
- )
- .await?;
-
- rest::deserialize_body(response).await
+ let request = self
+ .handle
+ .factory
+ .get(&format!("{APP_URL_PREFIX}/api-addrs"))?
+ .expected_status(&[StatusCode::OK]);
+ let response = self.handle.service.request(request).await?;
+ response.deserialize().await
}
}
diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs
index 0eb2b22afd..deaf29ef10 100644
--- a/mullvad-api/src/relay_list.rs
+++ b/mullvad-api/src/relay_list.rs
@@ -36,20 +36,18 @@ impl RelayListProxy {
let request = self.handle.factory.request("app/v1/relays", Method::GET);
async move {
- let mut request = request?;
- request.set_timeout(RELAY_LIST_TIMEOUT);
+ let mut request = request?
+ .timeout(RELAY_LIST_TIMEOUT)
+ .expected_status(&[StatusCode::NOT_MODIFIED, StatusCode::OK]);
if let Some(ref tag) = etag {
- request.add_header(header::IF_NONE_MATCH, tag)?;
+ request = request.header(header::IF_NONE_MATCH, tag)?;
}
let response = service.request(request).await?;
if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED {
return Ok(None);
}
- if response.status() != StatusCode::OK {
- return rest::handle_error_response(response).await;
- }
let etag = response
.headers()
@@ -62,11 +60,8 @@ impl RelayListProxy {
}
});
- Ok(Some(
- rest::deserialize_body::<ServerRelayList>(response)
- .await?
- .into_relay_list(etag),
- ))
+ let relay_list: ServerRelayList = response.deserialize().await?;
+ Ok(Some(relay_list.into_relay_list(etag)))
}
}
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 39ed98d370..63c909507c 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -10,17 +10,16 @@ use crate::{
use futures::{
channel::{mpsc, oneshot},
stream::StreamExt,
- Stream, TryFutureExt,
+ Stream,
};
use hyper::{
- client::Client,
+ client::{connect::Connect, Client},
header::{self, HeaderValue},
Method, Uri,
};
use mullvad_types::account::AccountToken;
use std::{
error::Error as StdError,
- future::Future,
str::FromStr,
sync::{Arc, Weak},
time::Duration,
@@ -32,9 +31,6 @@ use crate::API;
pub use hyper::StatusCode;
-pub type Request = hyper::Request<hyper::Body>;
-pub type Response = hyper::Response<hyper::Body>;
-
const USER_AGENT: &str = "mullvad-app";
const API_IP_CHECK_INITIAL: Duration = Duration::from_secs(15 * 60);
@@ -47,6 +43,9 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Describes all the ways a REST request can fail
#[derive(err_derive::Error, Debug, Clone)]
pub enum Error {
+ #[error(display = "REST client service is down")]
+ RestServiceDown,
+
#[error(display = "Request cancelled")]
Aborted,
@@ -65,12 +64,6 @@ pub enum Error {
#[error(display = "Failed to deserialize data")]
DeserializeError(#[error(source)] Arc<serde_json::Error>),
- #[error(display = "Failed to send request to rest client")]
- SendError,
-
- #[error(display = "Failed to receive response from rest client")]
- ReceiveError,
-
/// Unexpected response code
#[error(display = "Unexpected response status code {} - {}", _0, _1)]
ApiError(StatusCode, String),
@@ -79,9 +72,8 @@ pub enum Error {
#[error(display = "Not a valid URI")]
InvalidUri,
- /// A new API config was requested, but the request could not be completed.
- #[error(display = "Failed to rotate API config")]
- NextApiConfigError,
+ #[error(display = "Set account token on factory with no access token store")]
+ NoAccessTokenStore,
}
impl Error {
@@ -201,45 +193,7 @@ impl<
async fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
- let tx = self.command_tx.upgrade();
- let timeout = request.timeout();
-
- let hyper_request = request.into_request();
-
- let api_availability = self.api_availability.clone();
- let suspend_fut = api_availability.wait_for_unsuspend();
- let request_fut = self.client.request(hyper_request).map_err(Error::from);
-
- let request_future = async move {
- let _ = suspend_fut.await;
- request_fut.await
- };
-
- let future = async move {
- let response = tokio::time::timeout(timeout, request_future)
- .await
- .map_err(|_| Error::TimeoutError);
-
- let response = flatten_result(response).map_err(|error| error.map_aborted());
-
- if let Err(err) = &response {
- if err.is_network_error() && !api_availability.get_state().is_offline() {
- log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
- if let Some(tx) = tx {
- let (completion_tx, _completion_rx) = oneshot::channel();
- let _ =
- tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
- }
- }
- }
-
- if completion_tx.send(response).is_err() {
- log::trace!(
- "Failed to send response to caller, caller channel is shut down"
- );
- }
- };
- tokio::spawn(future);
+ self.handle_new_request(request, completion_tx);
}
RequestCommand::Reset => {
self.connector_handle.reset();
@@ -268,6 +222,34 @@ impl<
}
}
+ fn handle_new_request(
+ &mut self,
+ request: Request,
+ completion_tx: oneshot::Sender<Result<Response>>,
+ ) {
+ let tx = self.command_tx.upgrade();
+
+ let api_availability = self.api_availability.clone();
+ let request_future = request.into_future(self.client.clone(), api_availability.clone());
+
+ tokio::spawn(async move {
+ let response = request_future.await.map_err(|error| error.map_aborted());
+
+ // Switch API endpoint if the request failed due to a network error
+ if let Err(err) = &response {
+ if err.is_network_error() && !api_availability.get_state().is_offline() {
+ log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
+ if let Some(tx) = tx {
+ let (completion_tx, _completion_rx) = oneshot::channel();
+ let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
+ }
+ }
+ }
+
+ let _ = completion_tx.send(response);
+ });
+ }
+
async fn into_future(mut self) {
while let Some(command) = self.command_rx.next().await {
self.process_command(command).await;
@@ -289,12 +271,12 @@ impl RequestServiceHandle {
}
/// Submits a `RestRequest` for execution to the request service.
- pub async fn request(&self, request: RestRequest) -> Result<Response> {
+ pub async fn request(&self, request: Request) -> Result<Response> {
let (completion_tx, completion_rx) = oneshot::channel();
self.tx
.unbounded_send(RequestCommand::NewRequest(request, completion_tx))
- .map_err(|_| Error::SendError)?;
- completion_rx.await.map_err(|_| Error::ReceiveError)?
+ .map_err(|_| Error::RestServiceDown)?;
+ completion_rx.await.map_err(|_| Error::RestServiceDown)?
}
/// Forcibly update the connection mode.
@@ -302,16 +284,15 @@ impl RequestServiceHandle {
let (completion_tx, completion_rx) = oneshot::channel();
self.tx
.unbounded_send(RequestCommand::NextApiConfig(completion_tx))
- .map_err(|_| Error::SendError)?;
-
- completion_rx.await.map_err(|_| Error::NextApiConfigError)?
+ .map_err(|_| Error::RestServiceDown)?;
+ completion_rx.await.map_err(|_| Error::RestServiceDown)?
}
}
#[derive(Debug)]
pub(crate) enum RequestCommand {
NewRequest(
- RestRequest,
+ Request,
oneshot::Sender<std::result::Result<Response, Error>>,
),
Reset,
@@ -320,13 +301,15 @@ pub(crate) enum RequestCommand {
/// A REST request that is sent to the RequestService to be executed.
#[derive(Debug)]
-pub struct RestRequest {
- request: Request,
+pub struct Request {
+ request: hyper::Request<hyper::Body>,
timeout: Duration,
- auth: Option<HeaderValue>,
+ access_token_store: Option<AccessTokenStore>,
+ account: Option<AccountToken>,
+ expected_status: &'static [hyper::StatusCode],
}
-impl RestRequest {
+impl Request {
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
pub fn get(uri: &str) -> Result<Self> {
let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?;
@@ -343,54 +326,112 @@ impl RestRequest {
};
let request = builder.uri(uri).body(hyper::Body::empty())?;
+ Ok(Self::new(request, None))
+ }
- Ok(RestRequest {
- timeout: DEFAULT_TIMEOUT,
- auth: None,
+ fn new(
+ request: hyper::Request<hyper::Body>,
+ access_token_store: Option<AccessTokenStore>,
+ ) -> Self {
+ Self {
request,
- })
+ timeout: DEFAULT_TIMEOUT,
+ access_token_store,
+ account: None,
+ expected_status: &[],
+ }
}
- /// Set the auth header with the following format: `Bearer $auth`.
- pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> {
- let header = match auth {
- Some(auth) => Some(
- HeaderValue::from_str(&format!("Bearer {auth}"))
- .map_err(|_| Error::InvalidHeaderError)?,
- ),
- None => None,
- };
-
- self.auth = header;
- Ok(())
+ /// Set the account token to obtain authentication for.
+ /// This fails if no store is set.
+ pub fn account(mut self, account: AccountToken) -> Result<Self> {
+ if self.access_token_store.is_none() {
+ return Err(Error::NoAccessTokenStore);
+ }
+ self.account = Some(account);
+ Ok(self)
}
/// Sets timeout for the request.
- pub fn set_timeout(&mut self, timeout: Duration) {
+ pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
+ self
}
- /// Retrieves timeout
- pub fn timeout(&self) -> Duration {
- self.timeout
+ pub fn expected_status(mut self, expected_status: &'static [hyper::StatusCode]) -> Self {
+ self.expected_status = expected_status;
+ self
}
- pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> {
+ pub fn header<T: header::IntoHeaderName>(mut self, key: T, value: &str) -> Result<Self> {
let header_value =
http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?;
self.request.headers_mut().insert(key, header_value);
- Ok(())
+ Ok(self)
+ }
+
+ async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
+ self,
+ hyper_client: hyper::Client<C>,
+ api_availability: ApiAvailabilityHandle,
+ ) -> Result<Response> {
+ let timeout = self.timeout;
+ let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
+ tokio::time::timeout(timeout, inner_fut)
+ .await
+ .map_err(|_| Error::TimeoutError)?
}
- /// Converts into a `hyper::Request<hyper::Body>`
- fn into_request(self) -> Request {
- let Self {
- mut request, auth, ..
- } = self;
- if let Some(auth) = auth {
- request.headers_mut().insert(header::AUTHORIZATION, auth);
+ async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>(
+ mut self,
+ hyper_client: hyper::Client<C>,
+ api_availability: ApiAvailabilityHandle,
+ ) -> Result<Response> {
+ let _ = api_availability.wait_for_unsuspend().await;
+
+ // Obtain access token first
+ if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) {
+ let access_token = store.get_token(account).await?;
+ let auth = HeaderValue::from_str(&format!("Bearer {access_token}"))
+ .map_err(|_| Error::InvalidHeaderError)?;
+ self.request
+ .headers_mut()
+ .insert(header::AUTHORIZATION, auth);
+ }
+
+ // Make request to hyper client
+ let response = hyper_client
+ .request(self.request)
+ .await
+ .map_err(Error::from);
+
+ // Notify access token store of expired tokens
+ if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) {
+ store.check_response(account, &response);
+ }
+
+ // Parse unexpected responses and errors
+
+ let response = response?;
+
+ if !self.expected_status.contains(&response.status()) {
+ if !self.expected_status.is_empty() {
+ log::error!(
+ "Unexpected HTTP status code {}, expected codes [{}]",
+ response.status(),
+ self.expected_status
+ .iter()
+ .map(ToString::to_string)
+ .collect::<Vec<_>>()
+ .join(",")
+ );
+ }
+ if !response.status().is_success() {
+ return handle_error_response(response).await;
+ }
}
- request
+
+ Ok(Response::new(response))
}
/// Returns the URI of the request
@@ -399,13 +440,28 @@ impl RestRequest {
}
}
-impl From<Request> for RestRequest {
- fn from(request: Request) -> Self {
- Self {
- request,
- timeout: DEFAULT_TIMEOUT,
- auth: None,
- }
+/// Successful result of a REST request
+#[derive(Debug)]
+pub struct Response {
+ response: hyper::Response<hyper::Body>,
+}
+
+impl Response {
+ fn new(response: hyper::Response<hyper::Body>) -> Self {
+ Self { response }
+ }
+
+ pub fn status(&self) -> StatusCode {
+ self.response.status()
+ }
+
+ pub fn headers(&self) -> &hyper::HeaderMap<HeaderValue> {
+ self.response.headers()
+ }
+
+ pub async fn deserialize<T: serde::de::DeserializeOwned>(self) -> Result<T> {
+ let body_length = get_body_length(&self.response);
+ deserialize_body_inner(self.response, body_length).await
}
}
@@ -423,48 +479,62 @@ struct NewErrorResponse {
#[derive(Clone)]
pub struct RequestFactory {
- hostname: String,
- path_prefix: Option<String>,
- pub timeout: Duration,
+ hostname: &'static str,
+ token_store: Option<AccessTokenStore>,
+ default_timeout: Duration,
}
impl RequestFactory {
- pub fn new(hostname: String, path_prefix: Option<String>) -> Self {
+ pub fn new(hostname: &'static str, token_store: Option<AccessTokenStore>) -> Self {
Self {
hostname,
- path_prefix,
- timeout: DEFAULT_TIMEOUT,
+ token_store,
+ default_timeout: DEFAULT_TIMEOUT,
}
}
- pub fn request(&self, path: &str, method: Method) -> Result<RestRequest> {
- self.hyper_request(path, method)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ pub fn request(&self, path: &str, method: Method) -> Result<Request> {
+ Ok(
+ Request::new(self.hyper_request(path, method)?, self.token_store.clone())
+ .timeout(self.default_timeout),
+ )
+ }
+
+ pub fn get(&self, path: &str) -> Result<Request> {
+ self.request(path, Method::GET)
+ }
+
+ pub fn post(&self, path: &str) -> Result<Request> {
+ self.request(path, Method::POST)
}
- pub fn get(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::GET)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ pub fn put(&self, path: &str) -> Result<Request> {
+ self.request(path, Method::PUT)
}
- pub fn post(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::POST)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ pub fn delete(&self, path: &str) -> Result<Request> {
+ self.request(path, Method::DELETE)
}
- pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
+ pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> {
self.json_request(Method::POST, path, body)
}
+ pub fn put_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> {
+ self.json_request(Method::PUT, path, body)
+ }
+
+ pub fn default_timeout(mut self, timeout: Duration) -> Self {
+ self.default_timeout = timeout;
+ self
+ }
+
fn json_request<S: serde::Serialize>(
&self,
method: Method,
path: &str,
body: &S,
- ) -> Result<RestRequest> {
+ ) -> Result<Request> {
let mut request = self.hyper_request(path, method)?;
let json_body = serde_json::to_string(&body)?;
@@ -482,94 +552,29 @@ impl RequestFactory {
HeaderValue::from_static("application/json"),
);
- Ok(self.set_request_timeout(RestRequest::from(request)))
+ Ok(Request::new(request, self.token_store.clone()).timeout(self.default_timeout))
}
- pub fn delete(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::DELETE)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
- }
-
- fn hyper_request(&self, path: &str, method: Method) -> Result<Request> {
+ fn hyper_request(&self, path: &str, method: Method) -> Result<hyper::Request<hyper::Body>> {
let uri = self.get_uri(path)?;
let request = http::request::Builder::new()
.method(method)
.uri(uri)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
- .header(header::HOST, self.hostname.clone());
+ .header(header::HOST, HeaderValue::from_static(self.hostname));
let result = request.body(hyper::Body::empty())?;
Ok(result)
}
fn get_uri(&self, path: &str) -> Result<Uri> {
- let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
- let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
+ let uri = format!("https://{}/{}", self.hostname, path);
hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri)
}
-
- fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest {
- request.timeout = self.timeout;
- request
- }
-}
-
-pub fn send_request(
- factory: &RequestFactory,
- service: RequestServiceHandle,
- uri: &str,
- method: Method,
- access_token: Option<AccountToken>,
- expected_statuses: &'static [hyper::StatusCode],
-) -> impl Future<Output = Result<Response>> {
- let request = factory.request(uri, method);
-
- async move {
- let mut request = request?;
- request.set_auth(access_token)?;
- let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
- }
-}
-
-pub fn send_json_request<B: serde::Serialize>(
- factory: &RequestFactory,
- service: RequestServiceHandle,
- uri: &str,
- method: Method,
- body: &B,
- access_token: Option<AccountToken>,
- expected_statuses: &'static [hyper::StatusCode],
-) -> impl Future<Output = Result<Response>> {
- let request = factory.json_request(method, uri, body);
- async move {
- let mut request = request?;
- request.set_auth(access_token)?;
- let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
- }
-}
-
-pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> {
- let body_length = get_body_length(&response);
- deserialize_body_inner(response, body_length).await
-}
-
-async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
- mut response: Response,
- body_length: usize,
-) -> Result<T> {
- let mut body: Vec<u8> = Vec::with_capacity(body_length);
- while let Some(chunk) = response.body_mut().next().await {
- body.extend(&chunk?);
- }
-
- serde_json::from_slice(&body).map_err(Error::from)
}
-fn get_body_length(response: &Response) -> usize {
+fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize {
response
.headers()
.get(header::CONTENT_LENGTH)
@@ -578,29 +583,7 @@ fn get_body_length(response: &Response) -> usize {
.unwrap_or(0)
}
-pub async fn parse_rest_response(
- response: Response,
- expected_statuses: &'static [hyper::StatusCode],
-) -> Result<Response> {
- if !expected_statuses.contains(&response.status()) {
- log::error!(
- "Unexpected HTTP status code {}, expected codes [{}]",
- response.status(),
- expected_statuses
- .iter()
- .map(ToString::to_string)
- .collect::<Vec<_>>()
- .join(",")
- );
- if !response.status().is_success() {
- return handle_error_response(response).await;
- }
- }
-
- Ok(response)
-}
-
-pub async fn handle_error_response<T>(response: Response) -> Result<T> {
+async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Result<T> {
let status = response.status();
let error_message = match status {
hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed",
@@ -634,12 +617,23 @@ pub async fn handle_error_response<T>(response: Response) -> Result<T> {
Err(Error::ApiError(status, error_message.to_owned()))
}
+async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
+ mut response: hyper::Response<hyper::Body>,
+ body_length: usize,
+) -> Result<T> {
+ let mut body: Vec<u8> = Vec::with_capacity(body_length);
+ while let Some(chunk) = response.body_mut().next().await {
+ body.extend(&chunk?);
+ }
+
+ serde_json::from_slice(&body).map_err(Error::from)
+}
+
#[derive(Clone)]
pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
pub availability: ApiAvailabilityHandle,
- pub token_store: AccessTokenStore,
}
impl MullvadRestHandle {
@@ -649,13 +643,10 @@ impl MullvadRestHandle {
address_cache: AddressCache,
availability: ApiAvailabilityHandle,
) -> Self {
- let token_store = AccessTokenStore::new(service.clone(), factory.clone());
-
let handle = Self {
service,
factory,
availability,
- token_store,
};
#[cfg(feature = "api-override")]
if API.disable_address_cache {
@@ -714,19 +705,6 @@ impl MullvadRestHandle {
pub fn service(&self) -> RequestServiceHandle {
self.service.clone()
}
-
- pub fn factory(&self) -> &RequestFactory {
- &self.factory
- }
-}
-
-fn flatten_result<T, E>(
- result: std::result::Result<std::result::Result<T, E>, E>,
-) -> std::result::Result<T, E> {
- match result {
- Ok(value) => value,
- Err(err) => Err(err),
- }
}
macro_rules! impl_into_arc_err {
diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs
index 0939787e91..43ffaf1054 100644
--- a/mullvad-daemon/src/geoip.rs
+++ b/mullvad-daemon/src/geoip.rs
@@ -83,9 +83,8 @@ async fn send_location_request_internal(
service: RequestServiceHandle,
) -> Result<AmIMullvad, Error> {
let future_service = service.clone();
- let request = mullvad_api::rest::RestRequest::get(uri)?;
- let response = future_service.request(request).await?;
- mullvad_api::rest::deserialize_body(response).await
+ let request = mullvad_api::rest::Request::get(uri)?;
+ future_service.request(request).await?.deserialize().await
}
fn log_network_error(err: Error, version: &'static str) {
diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs
index dfe0a26b5f..d7c7eaa805 100644
--- a/mullvad-daemon/src/version_check.rs
+++ b/mullvad-daemon/src/version_check.rs
@@ -150,7 +150,7 @@ impl VersionUpdater {
last_app_version_info: Option<AppVersionInfo>,
show_beta_releases: bool,
) -> (Self, VersionUpdaterHandle) {
- api_handle.factory.timeout = DOWNLOAD_TIMEOUT;
+ api_handle.factory = api_handle.factory.default_timeout(DOWNLOAD_TIMEOUT);
let version_proxy = AppVersionProxy::new(api_handle);
let cache_path = cache_dir.join(VERSION_INFO_FILENAME);
let (tx, rx) = mpsc::channel(1);