diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-10-21 14:20:28 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-10-30 23:15:24 +0100 |
| commit | 93ca54370d6011225f0fe16e4a94559e2287b9ee (patch) | |
| tree | dfaaece6127f571ebf14bacc4e2e6d534a1c5414 | |
| parent | 1c24b27d2f4c1bf3f6f497f43d77d23d938aa20c (diff) | |
| download | mullvadvpn-93ca54370d6011225f0fe16e4a94559e2287b9ee.tar.xz mullvadvpn-93ca54370d6011225f0fe16e4a94559e2287b9ee.zip | |
Handle authentication and errors in API client
| -rw-r--r-- | mullvad-api/src/access.rs | 11 | ||||
| -rw-r--r-- | mullvad-api/src/device.rs | 104 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 168 | ||||
| -rw-r--r-- | mullvad-api/src/relay_list.rs | 10 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 345 |
5 files changed, 246 insertions, 392 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index 67c83ac4da..3edf580c8a 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -1,6 +1,7 @@ use crate::{ rest, rest::{RequestFactory, RequestServiceHandle}, + API, }; use futures::{ channel::{mpsc, oneshot}, @@ -13,7 +14,7 @@ use tokio::select; pub const AUTH_URL_PREFIX: &str = "auth/v1"; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct AccessTokenStore { tx: mpsc::UnboundedSender<StoreAction>, } @@ -36,7 +37,8 @@ struct AccountState { } impl AccessTokenStore { - pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { + pub(crate) fn new(service: RequestServiceHandle) -> Self { + let factory = rest::RequestFactory::new(API.host.clone(), None); let (tx, rx) = mpsc::unbounded(); tokio::spawn(Self::service_requests(rx, service, factory)); Self { tx } @@ -174,8 +176,9 @@ async fn fetch_access_token( account_number: account_token, }; - let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; + let rest_request = factory + .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)? + .expected_status(&[StatusCode::OK]); let response = service.request(rest_request).await?; - let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; rest::deserialize_body(response).await } diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs index 3d8913e366..1f7175e6a2 100644 --- a/mullvad-api/src/device.rs +++ b/mullvad-api/src/device.rs @@ -1,5 +1,5 @@ use chrono::{DateTime, Utc}; -use http::{Method, StatusCode}; +use http::StatusCode; use mullvad_types::{ account::AccountToken, device::{Device, DeviceId, DeviceName}, @@ -51,23 +51,13 @@ impl DevicesProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices"), - Method::POST, - &submission, - Some(access_token), - &[StatusCode::CREATED], - ) - .await; - - access_proxy.check_response(&account, &response); + let request = factory + .post_json(&format!("{ACCOUNTS_URL_PREFIX}/devices"), &submission)? + .account(account)? + .expected_status(&[StatusCode::CREATED]); + let response = service.request(request).await; let response: DeviceResponse = rest::deserialize_body(response?).await?; let DeviceResponse { @@ -104,20 +94,13 @@ impl DevicesProxy { ) -> impl Future<Output = Result<Device, rest::Error>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let device = rest::deserialize_body(response?).await?; + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await?; + let device = rest::deserialize_body(response).await?; Ok(device) } } @@ -128,20 +111,13 @@ impl DevicesProxy { ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let devices = rest::deserialize_body(response?).await?; + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/device"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await?; + let devices = rest::deserialize_body(response).await?; Ok(devices) } } @@ -153,21 +129,12 @@ impl DevicesProxy { ) -> impl Future<Output = Result<(), rest::Error>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), - Method::DELETE, - Some(access_token), - &[StatusCode::NO_CONTENT], - ) - .await; - access_proxy.check_response(&account, &response); - - response?; + let request = factory + .delete(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))? + .expected_status(&[StatusCode::NO_CONTENT]) + .account(account)?; + service.request(request).await?; Ok(()) } } @@ -187,25 +154,18 @@ impl DevicesProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), - Method::PUT, - &req_body, - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account, &response); + let request = factory + .put_json( + &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), + &req_body, + )? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await?; - let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; + let updated_device: DeviceResponse = rest::deserialize_body(response).await?; let DeviceResponse { ipv4_address, ipv6_address, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 6beb1c8e86..b97fc37dab 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -340,7 +340,8 @@ impl Runtime { self.socket_bypass_tx.clone(), ) .await; - let factory = rest::RequestFactory::new(API.host.clone(), None); + let token_store = access::AccessTokenStore::new(service.clone()); + let factory = rest::RequestFactory::new(API.host.clone(), None, Some(token_store)); rest::MullvadRestHandle::new( service, @@ -392,19 +393,12 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await; let account: AccountExpiryResponse = rest::deserialize_body(response?).await?; Ok(account.expiry) @@ -418,24 +412,21 @@ impl AccountsProxy { } let service = self.handle.service.clone(); - let response = rest::send_request( - &self.handle.factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/accounts"), - Method::POST, - None, - &[StatusCode::CREATED], - ); + let factory = self.handle.factory.clone(); async move { - let account: AccountCreationResponse = rest::deserialize_body(response.await?).await?; + let request = factory + .post(&format!("{ACCOUNTS_URL_PREFIX}/accounts"))? + .expected_status(&[StatusCode::CREATED]); + let response = service.request(request).await?; + let account: AccountCreationResponse = rest::deserialize_body(response).await?; Ok(account.number) } } pub fn submit_voucher( &mut self, - account_token: AccountToken, + account: AccountToken, voucher_code: String, ) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> { #[derive(serde::Serialize)] @@ -445,26 +436,15 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); let submission = VoucherSubmission { voucher_code }; async move { - let access_token = access_proxy.get_token(&account_token).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{APP_URL_PREFIX}/submit-voucher"), - Method::POST, - &submission, - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account_token, &response); - - rest::deserialize_body(response?).await + let request = factory + .post_json(&format!("{APP_URL_PREFIX}/submit-voucher"), &submission)? + .account(account)? + .expected_status(&[StatusCode::OK]); + let response = service.request(request).await?; + rest::deserialize_body(response).await } } @@ -480,26 +460,16 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), - Method::POST, - &(), - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account, &response); + let request = factory + .post_json(&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), &())? + .account(account)? + .expected_status(&[StatusCode::OK]); + let response = service.request(request).await?; let PlayPurchaseInitResponse { obfuscated_id } = - rest::deserialize_body(response?).await?; + rest::deserialize_body(response).await?; Ok(obfuscated_id) } @@ -513,24 +483,16 @@ impl AccountsProxy { ) -> impl Future<Output = Result<(), rest::Error>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), - Method::POST, - &play_purchase, - Some(access_token), - &[StatusCode::ACCEPTED], - ) - .await; - - access_proxy.check_response(&account, &response); - response?; + let request = factory + .post_json( + &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), + &play_purchase, + )? + .account(account)? + .expected_status(&[StatusCode::ACCEPTED]); + service.request(request).await?; Ok(()) } } @@ -546,21 +508,14 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{APP_URL_PREFIX}/www-auth-token"), - Method::POST, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let response: AuthTokenResponse = rest::deserialize_body(response?).await?; + let request = factory + .post(&format!("{APP_URL_PREFIX}/www-auth-token"))? + .account(account)? + .expected_status(&[StatusCode::OK]); + let response = service.request(request).await?; + let response: AuthTokenResponse = rest::deserialize_body(response).await?; Ok(response.auth_token) } } @@ -598,19 +553,13 @@ impl ProblemReportProxy { }; let service = self.handle.service.clone(); - - let request = rest::send_json_request( - &self.handle.factory, - service, - &format!("{APP_URL_PREFIX}/problem-report"), - Method::POST, - &report, - None, - &[StatusCode::NO_CONTENT], - ); + let factory = self.handle.factory.clone(); async move { - request.await?; + let request = factory + .post_json(&format!("{APP_URL_PREFIX}/problem-report"), &report)? + .expected_status(&[StatusCode::NO_CONTENT]); + service.request(request).await?; Ok(()) } } @@ -646,12 +595,11 @@ impl AppVersionProxy { let request = self.handle.factory.request(&path, Method::GET); async move { - let mut request = request?; - request.add_header("M-Platform-Version", &platform_version)?; - + let request = request? + .expected_status(&[StatusCode::OK]) + .header("M-Platform-Version", &platform_version)?; let response = service.request(request).await?; - let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(parsed_response).await + rest::deserialize_body(response).await } } } @@ -667,18 +615,12 @@ impl ApiProxy { } pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> { - let service = self.handle.service.clone(); - - let response = rest::send_request( - &self.handle.factory, - service, - &format!("{APP_URL_PREFIX}/api-addrs"), - Method::GET, - None, - &[StatusCode::OK], - ) - .await?; - + let request = self + .handle + .factory + .get(&format!("{APP_URL_PREFIX}/api-addrs"))? + .expected_status(&[StatusCode::OK]); + let response = self.handle.service.request(request).await?; rest::deserialize_body(response).await } } diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 0eb2b22afd..6f14e6f05b 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -36,20 +36,18 @@ impl RelayListProxy { let request = self.handle.factory.request("app/v1/relays", Method::GET); async move { - let mut request = request?; - request.set_timeout(RELAY_LIST_TIMEOUT); + let mut request = request? + .timeout(RELAY_LIST_TIMEOUT) + .expected_status(&[StatusCode::NOT_MODIFIED, StatusCode::OK]); if let Some(ref tag) = etag { - request.add_header(header::IF_NONE_MATCH, tag)?; + request = request.header(header::IF_NONE_MATCH, tag)?; } let response = service.request(request).await?; if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED { return Ok(None); } - if response.status() != StatusCode::OK { - return rest::handle_error_response(response).await; - } let etag = response .headers() diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 39ed98d370..89917105ee 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -10,17 +10,16 @@ use crate::{ use futures::{ channel::{mpsc, oneshot}, stream::StreamExt, - Stream, TryFutureExt, + Stream, }; use hyper::{ - client::Client, + client::{connect::Connect, Client}, header::{self, HeaderValue}, Method, Uri, }; use mullvad_types::account::AccountToken; use std::{ error::Error as StdError, - future::Future, str::FromStr, sync::{Arc, Weak}, time::Duration, @@ -47,6 +46,9 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); /// Describes all the ways a REST request can fail #[derive(err_derive::Error, Debug, Clone)] pub enum Error { + #[error(display = "REST client service is down")] + RestServiceDown, + #[error(display = "Request cancelled")] Aborted, @@ -65,12 +67,6 @@ pub enum Error { #[error(display = "Failed to deserialize data")] DeserializeError(#[error(source)] Arc<serde_json::Error>), - #[error(display = "Failed to send request to rest client")] - SendError, - - #[error(display = "Failed to receive response from rest client")] - ReceiveError, - /// Unexpected response code #[error(display = "Unexpected response status code {} - {}", _0, _1)] ApiError(StatusCode, String), @@ -79,9 +75,8 @@ pub enum Error { #[error(display = "Not a valid URI")] InvalidUri, - /// A new API config was requested, but the request could not be completed. - #[error(display = "Failed to rotate API config")] - NextApiConfigError, + #[error(display = "Set account token on factory with no access token store")] + NoAccessTokenStore, } impl Error { @@ -201,45 +196,7 @@ impl< async fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { - let tx = self.command_tx.upgrade(); - let timeout = request.timeout(); - - let hyper_request = request.into_request(); - - let api_availability = self.api_availability.clone(); - let suspend_fut = api_availability.wait_for_unsuspend(); - let request_fut = self.client.request(hyper_request).map_err(Error::from); - - let request_future = async move { - let _ = suspend_fut.await; - request_fut.await - }; - - let future = async move { - let response = tokio::time::timeout(timeout, request_future) - .await - .map_err(|_| Error::TimeoutError); - - let response = flatten_result(response).map_err(|error| error.map_aborted()); - - if let Err(err) = &response { - if err.is_network_error() && !api_availability.get_state().is_offline() { - log::error!("{}", err.display_chain_with_msg("HTTP request failed")); - if let Some(tx) = tx { - let (completion_tx, _completion_rx) = oneshot::channel(); - let _ = - tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx)); - } - } - } - - if completion_tx.send(response).is_err() { - log::trace!( - "Failed to send response to caller, caller channel is shut down" - ); - } - }; - tokio::spawn(future); + self.handle_new_request(request, completion_tx); } RequestCommand::Reset => { self.connector_handle.reset(); @@ -268,6 +225,36 @@ impl< } } + fn handle_new_request( + &mut self, + request: RestRequest, + completion_tx: oneshot::Sender<Result<Response>>, + ) { + let tx = self.command_tx.upgrade(); + + let api_availability = self.api_availability.clone(); + let request_future = request.into_future(self.client.clone()); + + tokio::spawn(async move { + let response = request_future.await.map_err(|error| error.map_aborted()); + + // Switch API endpoint if the request failed due to a network error + if let Err(err) = &response { + if err.is_network_error() && !api_availability.get_state().is_offline() { + log::error!("{}", err.display_chain_with_msg("HTTP request failed")); + if let Some(tx) = tx { + let (completion_tx, _completion_rx) = oneshot::channel(); + let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx)); + } + } + } + + if completion_tx.send(response).is_err() { + log::trace!("Failed to send response to caller, caller channel is shut down"); + } + }); + } + async fn into_future(mut self) { while let Some(command) = self.command_rx.next().await { self.process_command(command).await; @@ -293,8 +280,8 @@ impl RequestServiceHandle { let (completion_tx, completion_rx) = oneshot::channel(); self.tx .unbounded_send(RequestCommand::NewRequest(request, completion_tx)) - .map_err(|_| Error::SendError)?; - completion_rx.await.map_err(|_| Error::ReceiveError)? + .map_err(|_| Error::RestServiceDown)?; + completion_rx.await.map_err(|_| Error::RestServiceDown)? } /// Forcibly update the connection mode. @@ -302,9 +289,8 @@ impl RequestServiceHandle { let (completion_tx, completion_rx) = oneshot::channel(); self.tx .unbounded_send(RequestCommand::NextApiConfig(completion_tx)) - .map_err(|_| Error::SendError)?; - - completion_rx.await.map_err(|_| Error::NextApiConfigError)? + .map_err(|_| Error::RestServiceDown)?; + completion_rx.await.map_err(|_| Error::RestServiceDown)? } } @@ -321,9 +307,11 @@ pub(crate) enum RequestCommand { /// A REST request that is sent to the RequestService to be executed. #[derive(Debug)] pub struct RestRequest { - request: Request, + request: hyper::Request<hyper::Body>, timeout: Duration, - auth: Option<HeaderValue>, + access_token_store: Option<AccessTokenStore>, + account: Option<AccountToken>, + expected_status: &'static [hyper::StatusCode], } impl RestRequest { @@ -343,54 +331,98 @@ impl RestRequest { }; let request = builder.uri(uri).body(hyper::Body::empty())?; + Ok(Self::new(request, None)) + } - Ok(RestRequest { - timeout: DEFAULT_TIMEOUT, - auth: None, + fn new( + request: hyper::Request<hyper::Body>, + access_token_store: Option<AccessTokenStore>, + ) -> Self { + Self { request, - }) + timeout: DEFAULT_TIMEOUT, + access_token_store, + account: None, + expected_status: &[], + } } - /// Set the auth header with the following format: `Bearer $auth`. - pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> { - let header = match auth { - Some(auth) => Some( - HeaderValue::from_str(&format!("Bearer {auth}")) - .map_err(|_| Error::InvalidHeaderError)?, - ), - None => None, - }; - - self.auth = header; - Ok(()) + /// Set the account token to obtain authentication for. + /// This fails if no store is set. + pub fn account(mut self, account: AccountToken) -> Result<Self> { + if self.access_token_store.is_none() { + return Err(Error::NoAccessTokenStore); + } + self.account = Some(account); + Ok(self) } /// Sets timeout for the request. - pub fn set_timeout(&mut self, timeout: Duration) { + pub fn timeout(mut self, timeout: Duration) -> Self { self.timeout = timeout; + self } - /// Retrieves timeout - pub fn timeout(&self) -> Duration { - self.timeout + pub fn expected_status(mut self, expected_status: &'static [hyper::StatusCode]) -> Self { + self.expected_status = expected_status; + self } - pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> { + pub fn header<T: header::IntoHeaderName>(mut self, key: T, value: &str) -> Result<Self> { let header_value = http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?; self.request.headers_mut().insert(key, header_value); - Ok(()) + Ok(self) } - /// Converts into a `hyper::Request<hyper::Body>` - fn into_request(self) -> Request { - let Self { - mut request, auth, .. - } = self; - if let Some(auth) = auth { - request.headers_mut().insert(header::AUTHORIZATION, auth); + async fn into_future<C: Connect + Clone + Send + Sync + 'static>( + mut self, + hyper_client: hyper::Client<C>, + ) -> Result<Response> { + // Obtain access token first + if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) { + let access_token = store.get_token(account).await?; + let auth = HeaderValue::from_str(&format!("Bearer {access_token}")) + .map_err(|_| Error::InvalidHeaderError)?; + self.request + .headers_mut() + .insert(header::AUTHORIZATION, auth); + } + + // Make request to hyper client + let request_fut = hyper_client.request(self.request); + let response = tokio::time::timeout(self.timeout, request_fut) + .await + .map_err(|_| Error::TimeoutError)? + .map_err(Error::from); + + // Notify access token store of expired tokens + if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) { + store.check_response(account, &response); + } + + // Parse unexpected responses and errors + + let response = response?; + + if !self.expected_status.contains(&response.status()) { + if !self.expected_status.is_empty() { + log::error!( + "Unexpected HTTP status code {}, expected codes [{}]", + response.status(), + self.expected_status + .iter() + .map(ToString::to_string) + .collect::<Vec<_>>() + .join(",") + ); + } + if !response.status().is_success() { + return handle_error_response(response).await; + } } - request + + Ok(response) } /// Returns the URI of the request @@ -399,16 +431,6 @@ impl RestRequest { } } -impl From<Request> for RestRequest { - fn from(request: Request) -> Self { - Self { - request, - timeout: DEFAULT_TIMEOUT, - auth: None, - } - } -} - #[derive(serde::Deserialize)] struct OldErrorResponse { pub code: String, @@ -425,40 +447,55 @@ struct NewErrorResponse { pub struct RequestFactory { hostname: String, path_prefix: Option<String>, + token_store: Option<AccessTokenStore>, pub timeout: Duration, } impl RequestFactory { - pub fn new(hostname: String, path_prefix: Option<String>) -> Self { + pub fn new( + hostname: String, + path_prefix: Option<String>, + token_store: Option<AccessTokenStore>, + ) -> Self { Self { hostname, path_prefix, + token_store, timeout: DEFAULT_TIMEOUT, } } pub fn request(&self, path: &str, method: Method) -> Result<RestRequest> { - self.hyper_request(path, method) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + Ok( + RestRequest::new(self.hyper_request(path, method)?, self.token_store.clone()) + .timeout(self.timeout), + ) } pub fn get(&self, path: &str) -> Result<RestRequest> { - self.hyper_request(path, Method::GET) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + self.request(path, Method::GET) } pub fn post(&self, path: &str) -> Result<RestRequest> { - self.hyper_request(path, Method::POST) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + self.request(path, Method::POST) + } + + pub fn put(&self, path: &str) -> Result<RestRequest> { + self.request(path, Method::PUT) + } + + pub fn delete(&self, path: &str) -> Result<RestRequest> { + self.request(path, Method::DELETE) } pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> { self.json_request(Method::POST, path, body) } + pub fn put_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> { + self.json_request(Method::PUT, path, body) + } + fn json_request<S: serde::Serialize>( &self, method: Method, @@ -482,16 +519,10 @@ impl RequestFactory { HeaderValue::from_static("application/json"), ); - Ok(self.set_request_timeout(RestRequest::from(request))) - } - - pub fn delete(&self, path: &str) -> Result<RestRequest> { - self.hyper_request(path, Method::DELETE) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + Ok(RestRequest::new(request, self.token_store.clone()).timeout(self.timeout)) } - fn hyper_request(&self, path: &str, method: Method) -> Result<Request> { + fn hyper_request(&self, path: &str, method: Method) -> Result<hyper::Request<hyper::Body>> { let uri = self.get_uri(path)?; let request = http::request::Builder::new() .method(method) @@ -509,47 +540,6 @@ impl RequestFactory { let uri = format!("https://{}/{}{}", self.hostname, prefix, path); hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri) } - - fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest { - request.timeout = self.timeout; - request - } -} - -pub fn send_request( - factory: &RequestFactory, - service: RequestServiceHandle, - uri: &str, - method: Method, - access_token: Option<AccountToken>, - expected_statuses: &'static [hyper::StatusCode], -) -> impl Future<Output = Result<Response>> { - let request = factory.request(uri, method); - - async move { - let mut request = request?; - request.set_auth(access_token)?; - let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await - } -} - -pub fn send_json_request<B: serde::Serialize>( - factory: &RequestFactory, - service: RequestServiceHandle, - uri: &str, - method: Method, - body: &B, - access_token: Option<AccountToken>, - expected_statuses: &'static [hyper::StatusCode], -) -> impl Future<Output = Result<Response>> { - let request = factory.json_request(method, uri, body); - async move { - let mut request = request?; - request.set_auth(access_token)?; - let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await - } } pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> { @@ -578,29 +568,7 @@ fn get_body_length(response: &Response) -> usize { .unwrap_or(0) } -pub async fn parse_rest_response( - response: Response, - expected_statuses: &'static [hyper::StatusCode], -) -> Result<Response> { - if !expected_statuses.contains(&response.status()) { - log::error!( - "Unexpected HTTP status code {}, expected codes [{}]", - response.status(), - expected_statuses - .iter() - .map(ToString::to_string) - .collect::<Vec<_>>() - .join(",") - ); - if !response.status().is_success() { - return handle_error_response(response).await; - } - } - - Ok(response) -} - -pub async fn handle_error_response<T>(response: Response) -> Result<T> { +async fn handle_error_response<T>(response: Response) -> Result<T> { let status = response.status(); let error_message = match status { hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed", @@ -639,7 +607,6 @@ pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, pub availability: ApiAvailabilityHandle, - pub token_store: AccessTokenStore, } impl MullvadRestHandle { @@ -649,13 +616,10 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { - let token_store = AccessTokenStore::new(service.clone(), factory.clone()); - let handle = Self { service, factory, availability, - token_store, }; #[cfg(feature = "api-override")] if API.disable_address_cache { @@ -714,19 +678,6 @@ impl MullvadRestHandle { pub fn service(&self) -> RequestServiceHandle { self.service.clone() } - - pub fn factory(&self) -> &RequestFactory { - &self.factory - } -} - -fn flatten_result<T, E>( - result: std::result::Result<std::result::Result<T, E>, E>, -) -> std::result::Result<T, E> { - match result { - Ok(value) => value, - Err(err) => Err(err), - } } macro_rules! impl_into_arc_err { |
