summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-10-21 14:20:28 +0200
committerDavid Lönnhager <david.l@mullvad.net>2023-10-30 23:15:24 +0100
commit93ca54370d6011225f0fe16e4a94559e2287b9ee (patch)
treedfaaece6127f571ebf14bacc4e2e6d534a1c5414
parent1c24b27d2f4c1bf3f6f497f43d77d23d938aa20c (diff)
downloadmullvadvpn-93ca54370d6011225f0fe16e4a94559e2287b9ee.tar.xz
mullvadvpn-93ca54370d6011225f0fe16e4a94559e2287b9ee.zip
Handle authentication and errors in API client
-rw-r--r--mullvad-api/src/access.rs11
-rw-r--r--mullvad-api/src/device.rs104
-rw-r--r--mullvad-api/src/lib.rs168
-rw-r--r--mullvad-api/src/relay_list.rs10
-rw-r--r--mullvad-api/src/rest.rs345
5 files changed, 246 insertions, 392 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs
index 67c83ac4da..3edf580c8a 100644
--- a/mullvad-api/src/access.rs
+++ b/mullvad-api/src/access.rs
@@ -1,6 +1,7 @@
use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
+ API,
};
use futures::{
channel::{mpsc, oneshot},
@@ -13,7 +14,7 @@ use tokio::select;
pub const AUTH_URL_PREFIX: &str = "auth/v1";
-#[derive(Clone)]
+#[derive(Debug, Clone)]
pub struct AccessTokenStore {
tx: mpsc::UnboundedSender<StoreAction>,
}
@@ -36,7 +37,8 @@ struct AccountState {
}
impl AccessTokenStore {
- pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
+ pub(crate) fn new(service: RequestServiceHandle) -> Self {
+ let factory = rest::RequestFactory::new(API.host.clone(), None);
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
@@ -174,8 +176,9 @@ async fn fetch_access_token(
account_number: account_token,
};
- let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
+ let rest_request = factory
+ .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?
+ .expected_status(&[StatusCode::OK]);
let response = service.request(rest_request).await?;
- let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
rest::deserialize_body(response).await
}
diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs
index 3d8913e366..1f7175e6a2 100644
--- a/mullvad-api/src/device.rs
+++ b/mullvad-api/src/device.rs
@@ -1,5 +1,5 @@
use chrono::{DateTime, Utc};
-use http::{Method, StatusCode};
+use http::StatusCode;
use mullvad_types::{
account::AccountToken,
device::{Device, DeviceId, DeviceName},
@@ -51,23 +51,13 @@ impl DevicesProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices"),
- Method::POST,
- &submission,
- Some(access_token),
- &[StatusCode::CREATED],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
+ let request = factory
+ .post_json(&format!("{ACCOUNTS_URL_PREFIX}/devices"), &submission)?
+ .account(account)?
+ .expected_status(&[StatusCode::CREATED]);
+ let response = service.request(request).await;
let response: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
@@ -104,20 +94,13 @@ impl DevicesProxy {
) -> impl Future<Output = Result<Device, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let device = rest::deserialize_body(response?).await?;
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await?;
+ let device = rest::deserialize_body(response).await?;
Ok(device)
}
}
@@ -128,20 +111,13 @@ impl DevicesProxy {
) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let devices = rest::deserialize_body(response?).await?;
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/device"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await?;
+ let devices = rest::deserialize_body(response).await?;
Ok(devices)
}
}
@@ -153,21 +129,12 @@ impl DevicesProxy {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
- Method::DELETE,
- Some(access_token),
- &[StatusCode::NO_CONTENT],
- )
- .await;
- access_proxy.check_response(&account, &response);
-
- response?;
+ let request = factory
+ .delete(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))?
+ .expected_status(&[StatusCode::NO_CONTENT])
+ .account(account)?;
+ service.request(request).await?;
Ok(())
}
}
@@ -187,25 +154,18 @@ impl DevicesProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
- Method::PUT,
- &req_body,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
+ let request = factory
+ .put_json(
+ &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
+ &req_body,
+ )?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await?;
- let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
+ let updated_device: DeviceResponse = rest::deserialize_body(response).await?;
let DeviceResponse {
ipv4_address,
ipv6_address,
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 6beb1c8e86..b97fc37dab 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -340,7 +340,8 @@ impl Runtime {
self.socket_bypass_tx.clone(),
)
.await;
- let factory = rest::RequestFactory::new(API.host.clone(), None);
+ let token_store = access::AccessTokenStore::new(service.clone());
+ let factory = rest::RequestFactory::new(API.host.clone(), None, Some(token_store));
rest::MullvadRestHandle::new(
service,
@@ -392,19 +393,12 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"),
- Method::GET,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
+ let request = factory
+ .get(&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"))?
+ .expected_status(&[StatusCode::OK])
+ .account(account)?;
+ let response = service.request(request).await;
let account: AccountExpiryResponse = rest::deserialize_body(response?).await?;
Ok(account.expiry)
@@ -418,24 +412,21 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
- let response = rest::send_request(
- &self.handle.factory,
- service,
- &format!("{ACCOUNTS_URL_PREFIX}/accounts"),
- Method::POST,
- None,
- &[StatusCode::CREATED],
- );
+ let factory = self.handle.factory.clone();
async move {
- let account: AccountCreationResponse = rest::deserialize_body(response.await?).await?;
+ let request = factory
+ .post(&format!("{ACCOUNTS_URL_PREFIX}/accounts"))?
+ .expected_status(&[StatusCode::CREATED]);
+ let response = service.request(request).await?;
+ let account: AccountCreationResponse = rest::deserialize_body(response).await?;
Ok(account.number)
}
}
pub fn submit_voucher(
&mut self,
- account_token: AccountToken,
+ account: AccountToken,
voucher_code: String,
) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> {
#[derive(serde::Serialize)]
@@ -445,26 +436,15 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
let submission = VoucherSubmission { voucher_code };
async move {
- let access_token = access_proxy.get_token(&account_token).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{APP_URL_PREFIX}/submit-voucher"),
- Method::POST,
- &submission,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account_token, &response);
-
- rest::deserialize_body(response?).await
+ let request = factory
+ .post_json(&format!("{APP_URL_PREFIX}/submit-voucher"), &submission)?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ let response = service.request(request).await?;
+ rest::deserialize_body(response).await
}
}
@@ -480,26 +460,16 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"),
- Method::POST,
- &(),
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
+ let request = factory
+ .post_json(&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), &())?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ let response = service.request(request).await?;
let PlayPurchaseInitResponse { obfuscated_id } =
- rest::deserialize_body(response?).await?;
+ rest::deserialize_body(response).await?;
Ok(obfuscated_id)
}
@@ -513,24 +483,16 @@ impl AccountsProxy {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
-
- let response = rest::send_json_request(
- &factory,
- service,
- &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"),
- Method::POST,
- &play_purchase,
- Some(access_token),
- &[StatusCode::ACCEPTED],
- )
- .await;
-
- access_proxy.check_response(&account, &response);
- response?;
+ let request = factory
+ .post_json(
+ &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"),
+ &play_purchase,
+ )?
+ .account(account)?
+ .expected_status(&[StatusCode::ACCEPTED]);
+ service.request(request).await?;
Ok(())
}
}
@@ -546,21 +508,14 @@ impl AccountsProxy {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
- let access_proxy = self.handle.token_store.clone();
async move {
- let access_token = access_proxy.get_token(&account).await?;
- let response = rest::send_request(
- &factory,
- service,
- &format!("{APP_URL_PREFIX}/www-auth-token"),
- Method::POST,
- Some(access_token),
- &[StatusCode::OK],
- )
- .await;
- access_proxy.check_response(&account, &response);
- let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
+ let request = factory
+ .post(&format!("{APP_URL_PREFIX}/www-auth-token"))?
+ .account(account)?
+ .expected_status(&[StatusCode::OK]);
+ let response = service.request(request).await?;
+ let response: AuthTokenResponse = rest::deserialize_body(response).await?;
Ok(response.auth_token)
}
}
@@ -598,19 +553,13 @@ impl ProblemReportProxy {
};
let service = self.handle.service.clone();
-
- let request = rest::send_json_request(
- &self.handle.factory,
- service,
- &format!("{APP_URL_PREFIX}/problem-report"),
- Method::POST,
- &report,
- None,
- &[StatusCode::NO_CONTENT],
- );
+ let factory = self.handle.factory.clone();
async move {
- request.await?;
+ let request = factory
+ .post_json(&format!("{APP_URL_PREFIX}/problem-report"), &report)?
+ .expected_status(&[StatusCode::NO_CONTENT]);
+ service.request(request).await?;
Ok(())
}
}
@@ -646,12 +595,11 @@ impl AppVersionProxy {
let request = self.handle.factory.request(&path, Method::GET);
async move {
- let mut request = request?;
- request.add_header("M-Platform-Version", &platform_version)?;
-
+ let request = request?
+ .expected_status(&[StatusCode::OK])
+ .header("M-Platform-Version", &platform_version)?;
let response = service.request(request).await?;
- let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
- rest::deserialize_body(parsed_response).await
+ rest::deserialize_body(response).await
}
}
}
@@ -667,18 +615,12 @@ impl ApiProxy {
}
pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> {
- let service = self.handle.service.clone();
-
- let response = rest::send_request(
- &self.handle.factory,
- service,
- &format!("{APP_URL_PREFIX}/api-addrs"),
- Method::GET,
- None,
- &[StatusCode::OK],
- )
- .await?;
-
+ let request = self
+ .handle
+ .factory
+ .get(&format!("{APP_URL_PREFIX}/api-addrs"))?
+ .expected_status(&[StatusCode::OK]);
+ let response = self.handle.service.request(request).await?;
rest::deserialize_body(response).await
}
}
diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs
index 0eb2b22afd..6f14e6f05b 100644
--- a/mullvad-api/src/relay_list.rs
+++ b/mullvad-api/src/relay_list.rs
@@ -36,20 +36,18 @@ impl RelayListProxy {
let request = self.handle.factory.request("app/v1/relays", Method::GET);
async move {
- let mut request = request?;
- request.set_timeout(RELAY_LIST_TIMEOUT);
+ let mut request = request?
+ .timeout(RELAY_LIST_TIMEOUT)
+ .expected_status(&[StatusCode::NOT_MODIFIED, StatusCode::OK]);
if let Some(ref tag) = etag {
- request.add_header(header::IF_NONE_MATCH, tag)?;
+ request = request.header(header::IF_NONE_MATCH, tag)?;
}
let response = service.request(request).await?;
if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED {
return Ok(None);
}
- if response.status() != StatusCode::OK {
- return rest::handle_error_response(response).await;
- }
let etag = response
.headers()
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 39ed98d370..89917105ee 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -10,17 +10,16 @@ use crate::{
use futures::{
channel::{mpsc, oneshot},
stream::StreamExt,
- Stream, TryFutureExt,
+ Stream,
};
use hyper::{
- client::Client,
+ client::{connect::Connect, Client},
header::{self, HeaderValue},
Method, Uri,
};
use mullvad_types::account::AccountToken;
use std::{
error::Error as StdError,
- future::Future,
str::FromStr,
sync::{Arc, Weak},
time::Duration,
@@ -47,6 +46,9 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Describes all the ways a REST request can fail
#[derive(err_derive::Error, Debug, Clone)]
pub enum Error {
+ #[error(display = "REST client service is down")]
+ RestServiceDown,
+
#[error(display = "Request cancelled")]
Aborted,
@@ -65,12 +67,6 @@ pub enum Error {
#[error(display = "Failed to deserialize data")]
DeserializeError(#[error(source)] Arc<serde_json::Error>),
- #[error(display = "Failed to send request to rest client")]
- SendError,
-
- #[error(display = "Failed to receive response from rest client")]
- ReceiveError,
-
/// Unexpected response code
#[error(display = "Unexpected response status code {} - {}", _0, _1)]
ApiError(StatusCode, String),
@@ -79,9 +75,8 @@ pub enum Error {
#[error(display = "Not a valid URI")]
InvalidUri,
- /// A new API config was requested, but the request could not be completed.
- #[error(display = "Failed to rotate API config")]
- NextApiConfigError,
+ #[error(display = "Set account token on factory with no access token store")]
+ NoAccessTokenStore,
}
impl Error {
@@ -201,45 +196,7 @@ impl<
async fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
- let tx = self.command_tx.upgrade();
- let timeout = request.timeout();
-
- let hyper_request = request.into_request();
-
- let api_availability = self.api_availability.clone();
- let suspend_fut = api_availability.wait_for_unsuspend();
- let request_fut = self.client.request(hyper_request).map_err(Error::from);
-
- let request_future = async move {
- let _ = suspend_fut.await;
- request_fut.await
- };
-
- let future = async move {
- let response = tokio::time::timeout(timeout, request_future)
- .await
- .map_err(|_| Error::TimeoutError);
-
- let response = flatten_result(response).map_err(|error| error.map_aborted());
-
- if let Err(err) = &response {
- if err.is_network_error() && !api_availability.get_state().is_offline() {
- log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
- if let Some(tx) = tx {
- let (completion_tx, _completion_rx) = oneshot::channel();
- let _ =
- tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
- }
- }
- }
-
- if completion_tx.send(response).is_err() {
- log::trace!(
- "Failed to send response to caller, caller channel is shut down"
- );
- }
- };
- tokio::spawn(future);
+ self.handle_new_request(request, completion_tx);
}
RequestCommand::Reset => {
self.connector_handle.reset();
@@ -268,6 +225,36 @@ impl<
}
}
+ fn handle_new_request(
+ &mut self,
+ request: RestRequest,
+ completion_tx: oneshot::Sender<Result<Response>>,
+ ) {
+ let tx = self.command_tx.upgrade();
+
+ let api_availability = self.api_availability.clone();
+ let request_future = request.into_future(self.client.clone());
+
+ tokio::spawn(async move {
+ let response = request_future.await.map_err(|error| error.map_aborted());
+
+ // Switch API endpoint if the request failed due to a network error
+ if let Err(err) = &response {
+ if err.is_network_error() && !api_availability.get_state().is_offline() {
+ log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
+ if let Some(tx) = tx {
+ let (completion_tx, _completion_rx) = oneshot::channel();
+ let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
+ }
+ }
+ }
+
+ if completion_tx.send(response).is_err() {
+ log::trace!("Failed to send response to caller, caller channel is shut down");
+ }
+ });
+ }
+
async fn into_future(mut self) {
while let Some(command) = self.command_rx.next().await {
self.process_command(command).await;
@@ -293,8 +280,8 @@ impl RequestServiceHandle {
let (completion_tx, completion_rx) = oneshot::channel();
self.tx
.unbounded_send(RequestCommand::NewRequest(request, completion_tx))
- .map_err(|_| Error::SendError)?;
- completion_rx.await.map_err(|_| Error::ReceiveError)?
+ .map_err(|_| Error::RestServiceDown)?;
+ completion_rx.await.map_err(|_| Error::RestServiceDown)?
}
/// Forcibly update the connection mode.
@@ -302,9 +289,8 @@ impl RequestServiceHandle {
let (completion_tx, completion_rx) = oneshot::channel();
self.tx
.unbounded_send(RequestCommand::NextApiConfig(completion_tx))
- .map_err(|_| Error::SendError)?;
-
- completion_rx.await.map_err(|_| Error::NextApiConfigError)?
+ .map_err(|_| Error::RestServiceDown)?;
+ completion_rx.await.map_err(|_| Error::RestServiceDown)?
}
}
@@ -321,9 +307,11 @@ pub(crate) enum RequestCommand {
/// A REST request that is sent to the RequestService to be executed.
#[derive(Debug)]
pub struct RestRequest {
- request: Request,
+ request: hyper::Request<hyper::Body>,
timeout: Duration,
- auth: Option<HeaderValue>,
+ access_token_store: Option<AccessTokenStore>,
+ account: Option<AccountToken>,
+ expected_status: &'static [hyper::StatusCode],
}
impl RestRequest {
@@ -343,54 +331,98 @@ impl RestRequest {
};
let request = builder.uri(uri).body(hyper::Body::empty())?;
+ Ok(Self::new(request, None))
+ }
- Ok(RestRequest {
- timeout: DEFAULT_TIMEOUT,
- auth: None,
+ fn new(
+ request: hyper::Request<hyper::Body>,
+ access_token_store: Option<AccessTokenStore>,
+ ) -> Self {
+ Self {
request,
- })
+ timeout: DEFAULT_TIMEOUT,
+ access_token_store,
+ account: None,
+ expected_status: &[],
+ }
}
- /// Set the auth header with the following format: `Bearer $auth`.
- pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> {
- let header = match auth {
- Some(auth) => Some(
- HeaderValue::from_str(&format!("Bearer {auth}"))
- .map_err(|_| Error::InvalidHeaderError)?,
- ),
- None => None,
- };
-
- self.auth = header;
- Ok(())
+ /// Set the account token to obtain authentication for.
+ /// This fails if no store is set.
+ pub fn account(mut self, account: AccountToken) -> Result<Self> {
+ if self.access_token_store.is_none() {
+ return Err(Error::NoAccessTokenStore);
+ }
+ self.account = Some(account);
+ Ok(self)
}
/// Sets timeout for the request.
- pub fn set_timeout(&mut self, timeout: Duration) {
+ pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
+ self
}
- /// Retrieves timeout
- pub fn timeout(&self) -> Duration {
- self.timeout
+ pub fn expected_status(mut self, expected_status: &'static [hyper::StatusCode]) -> Self {
+ self.expected_status = expected_status;
+ self
}
- pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> {
+ pub fn header<T: header::IntoHeaderName>(mut self, key: T, value: &str) -> Result<Self> {
let header_value =
http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?;
self.request.headers_mut().insert(key, header_value);
- Ok(())
+ Ok(self)
}
- /// Converts into a `hyper::Request<hyper::Body>`
- fn into_request(self) -> Request {
- let Self {
- mut request, auth, ..
- } = self;
- if let Some(auth) = auth {
- request.headers_mut().insert(header::AUTHORIZATION, auth);
+ async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
+ mut self,
+ hyper_client: hyper::Client<C>,
+ ) -> Result<Response> {
+ // Obtain access token first
+ if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) {
+ let access_token = store.get_token(account).await?;
+ let auth = HeaderValue::from_str(&format!("Bearer {access_token}"))
+ .map_err(|_| Error::InvalidHeaderError)?;
+ self.request
+ .headers_mut()
+ .insert(header::AUTHORIZATION, auth);
+ }
+
+ // Make request to hyper client
+ let request_fut = hyper_client.request(self.request);
+ let response = tokio::time::timeout(self.timeout, request_fut)
+ .await
+ .map_err(|_| Error::TimeoutError)?
+ .map_err(Error::from);
+
+ // Notify access token store of expired tokens
+ if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) {
+ store.check_response(account, &response);
+ }
+
+ // Parse unexpected responses and errors
+
+ let response = response?;
+
+ if !self.expected_status.contains(&response.status()) {
+ if !self.expected_status.is_empty() {
+ log::error!(
+ "Unexpected HTTP status code {}, expected codes [{}]",
+ response.status(),
+ self.expected_status
+ .iter()
+ .map(ToString::to_string)
+ .collect::<Vec<_>>()
+ .join(",")
+ );
+ }
+ if !response.status().is_success() {
+ return handle_error_response(response).await;
+ }
}
- request
+
+ Ok(response)
}
/// Returns the URI of the request
@@ -399,16 +431,6 @@ impl RestRequest {
}
}
-impl From<Request> for RestRequest {
- fn from(request: Request) -> Self {
- Self {
- request,
- timeout: DEFAULT_TIMEOUT,
- auth: None,
- }
- }
-}
-
#[derive(serde::Deserialize)]
struct OldErrorResponse {
pub code: String,
@@ -425,40 +447,55 @@ struct NewErrorResponse {
pub struct RequestFactory {
hostname: String,
path_prefix: Option<String>,
+ token_store: Option<AccessTokenStore>,
pub timeout: Duration,
}
impl RequestFactory {
- pub fn new(hostname: String, path_prefix: Option<String>) -> Self {
+ pub fn new(
+ hostname: String,
+ path_prefix: Option<String>,
+ token_store: Option<AccessTokenStore>,
+ ) -> Self {
Self {
hostname,
path_prefix,
+ token_store,
timeout: DEFAULT_TIMEOUT,
}
}
pub fn request(&self, path: &str, method: Method) -> Result<RestRequest> {
- self.hyper_request(path, method)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ Ok(
+ RestRequest::new(self.hyper_request(path, method)?, self.token_store.clone())
+ .timeout(self.timeout),
+ )
}
pub fn get(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::GET)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ self.request(path, Method::GET)
}
pub fn post(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::POST)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ self.request(path, Method::POST)
+ }
+
+ pub fn put(&self, path: &str) -> Result<RestRequest> {
+ self.request(path, Method::PUT)
+ }
+
+ pub fn delete(&self, path: &str) -> Result<RestRequest> {
+ self.request(path, Method::DELETE)
}
pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
self.json_request(Method::POST, path, body)
}
+ pub fn put_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
+ self.json_request(Method::PUT, path, body)
+ }
+
fn json_request<S: serde::Serialize>(
&self,
method: Method,
@@ -482,16 +519,10 @@ impl RequestFactory {
HeaderValue::from_static("application/json"),
);
- Ok(self.set_request_timeout(RestRequest::from(request)))
- }
-
- pub fn delete(&self, path: &str) -> Result<RestRequest> {
- self.hyper_request(path, Method::DELETE)
- .map(RestRequest::from)
- .map(|req| self.set_request_timeout(req))
+ Ok(RestRequest::new(request, self.token_store.clone()).timeout(self.timeout))
}
- fn hyper_request(&self, path: &str, method: Method) -> Result<Request> {
+ fn hyper_request(&self, path: &str, method: Method) -> Result<hyper::Request<hyper::Body>> {
let uri = self.get_uri(path)?;
let request = http::request::Builder::new()
.method(method)
@@ -509,47 +540,6 @@ impl RequestFactory {
let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri)
}
-
- fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest {
- request.timeout = self.timeout;
- request
- }
-}
-
-pub fn send_request(
- factory: &RequestFactory,
- service: RequestServiceHandle,
- uri: &str,
- method: Method,
- access_token: Option<AccountToken>,
- expected_statuses: &'static [hyper::StatusCode],
-) -> impl Future<Output = Result<Response>> {
- let request = factory.request(uri, method);
-
- async move {
- let mut request = request?;
- request.set_auth(access_token)?;
- let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
- }
-}
-
-pub fn send_json_request<B: serde::Serialize>(
- factory: &RequestFactory,
- service: RequestServiceHandle,
- uri: &str,
- method: Method,
- body: &B,
- access_token: Option<AccountToken>,
- expected_statuses: &'static [hyper::StatusCode],
-) -> impl Future<Output = Result<Response>> {
- let request = factory.json_request(method, uri, body);
- async move {
- let mut request = request?;
- request.set_auth(access_token)?;
- let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
- }
}
pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> {
@@ -578,29 +568,7 @@ fn get_body_length(response: &Response) -> usize {
.unwrap_or(0)
}
-pub async fn parse_rest_response(
- response: Response,
- expected_statuses: &'static [hyper::StatusCode],
-) -> Result<Response> {
- if !expected_statuses.contains(&response.status()) {
- log::error!(
- "Unexpected HTTP status code {}, expected codes [{}]",
- response.status(),
- expected_statuses
- .iter()
- .map(ToString::to_string)
- .collect::<Vec<_>>()
- .join(",")
- );
- if !response.status().is_success() {
- return handle_error_response(response).await;
- }
- }
-
- Ok(response)
-}
-
-pub async fn handle_error_response<T>(response: Response) -> Result<T> {
+async fn handle_error_response<T>(response: Response) -> Result<T> {
let status = response.status();
let error_message = match status {
hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed",
@@ -639,7 +607,6 @@ pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
pub availability: ApiAvailabilityHandle,
- pub token_store: AccessTokenStore,
}
impl MullvadRestHandle {
@@ -649,13 +616,10 @@ impl MullvadRestHandle {
address_cache: AddressCache,
availability: ApiAvailabilityHandle,
) -> Self {
- let token_store = AccessTokenStore::new(service.clone(), factory.clone());
-
let handle = Self {
service,
factory,
availability,
- token_store,
};
#[cfg(feature = "api-override")]
if API.disable_address_cache {
@@ -714,19 +678,6 @@ impl MullvadRestHandle {
pub fn service(&self) -> RequestServiceHandle {
self.service.clone()
}
-
- pub fn factory(&self) -> &RequestFactory {
- &self.factory
- }
-}
-
-fn flatten_result<T, E>(
- result: std::result::Result<std::result::Result<T, E>, E>,
-) -> std::result::Result<T, E> {
- match result {
- Ok(value) => value,
- Err(err) => Err(err),
- }
}
macro_rules! impl_into_arc_err {