diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-10-21 14:54:09 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-10-30 23:15:25 +0100 |
| commit | 8dd49b1e9643ba0aecd414f2d544c5190a38e530 (patch) | |
| tree | 5031812333a0a0e60a56221c377e0327fa5f54c3 | |
| parent | 87f73399c400ee8b99e62b895785c8b2dd67e082 (diff) | |
| download | mullvadvpn-8dd49b1e9643ba0aecd414f2d544c5190a38e530.tar.xz mullvadvpn-8dd49b1e9643ba0aecd414f2d544c5190a38e530.zip | |
Newtype REST Response
| -rw-r--r-- | mullvad-api/src/access.rs | 3 | ||||
| -rw-r--r-- | mullvad-api/src/device.rs | 18 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 19 | ||||
| -rw-r--r-- | mullvad-api/src/relay_list.rs | 7 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 66 | ||||
| -rw-r--r-- | mullvad-daemon/src/geoip.rs | 3 |
6 files changed, 59 insertions, 57 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index 3edf580c8a..2c244d1513 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -179,6 +179,5 @@ async fn fetch_access_token( let rest_request = factory .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)? .expected_status(&[StatusCode::OK]); - let response = service.request(rest_request).await?; - rest::deserialize_body(response).await + service.request(rest_request).await?.deserialize().await } diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs index 1f7175e6a2..37410e99b3 100644 --- a/mullvad-api/src/device.rs +++ b/mullvad-api/src/device.rs @@ -57,9 +57,7 @@ impl DevicesProxy { .post_json(&format!("{ACCOUNTS_URL_PREFIX}/devices"), &submission)? .account(account)? .expected_status(&[StatusCode::CREATED]); - let response = service.request(request).await; - - let response: DeviceResponse = rest::deserialize_body(response?).await?; + let response = service.request(request).await?; let DeviceResponse { id, name, @@ -69,7 +67,7 @@ impl DevicesProxy { hijack_dns, created, .. - } = response; + } = response.deserialize().await?; Ok(( Device { @@ -99,9 +97,7 @@ impl DevicesProxy { .get(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))? .expected_status(&[StatusCode::OK]) .account(account)?; - let response = service.request(request).await?; - let device = rest::deserialize_body(response).await?; - Ok(device) + service.request(request).await?.deserialize().await } } @@ -116,9 +112,7 @@ impl DevicesProxy { .get(&format!("{ACCOUNTS_URL_PREFIX}/device"))? .expected_status(&[StatusCode::OK]) .account(account)?; - let response = service.request(request).await?; - let devices = rest::deserialize_body(response).await?; - Ok(devices) + service.request(request).await?.deserialize().await } } @@ -164,13 +158,11 @@ impl DevicesProxy { .expected_status(&[StatusCode::OK]) .account(account)?; let response = service.request(request).await?; - - let updated_device: DeviceResponse = rest::deserialize_body(response).await?; let DeviceResponse { ipv4_address, ipv6_address, .. - } = updated_device; + } = response.deserialize().await?; Ok(mullvad_types::wireguard::AssociatedAddresses { ipv4_address, ipv6_address, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index b97fc37dab..26c6d3f71d 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -398,9 +398,8 @@ impl AccountsProxy { .get(&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"))? .expected_status(&[StatusCode::OK]) .account(account)?; - let response = service.request(request).await; - - let account: AccountExpiryResponse = rest::deserialize_body(response?).await?; + let response = service.request(request).await?; + let account: AccountExpiryResponse = response.deserialize().await?; Ok(account.expiry) } } @@ -419,7 +418,7 @@ impl AccountsProxy { .post(&format!("{ACCOUNTS_URL_PREFIX}/accounts"))? .expected_status(&[StatusCode::CREATED]); let response = service.request(request).await?; - let account: AccountCreationResponse = rest::deserialize_body(response).await?; + let account: AccountCreationResponse = response.deserialize().await?; Ok(account.number) } } @@ -443,8 +442,7 @@ impl AccountsProxy { .post_json(&format!("{APP_URL_PREFIX}/submit-voucher"), &submission)? .account(account)? .expected_status(&[StatusCode::OK]); - let response = service.request(request).await?; - rest::deserialize_body(response).await + service.request(request).await?.deserialize().await } } @@ -468,8 +466,7 @@ impl AccountsProxy { .expected_status(&[StatusCode::OK]); let response = service.request(request).await?; - let PlayPurchaseInitResponse { obfuscated_id } = - rest::deserialize_body(response).await?; + let PlayPurchaseInitResponse { obfuscated_id } = response.deserialize().await?; Ok(obfuscated_id) } @@ -515,7 +512,7 @@ impl AccountsProxy { .account(account)? .expected_status(&[StatusCode::OK]); let response = service.request(request).await?; - let response: AuthTokenResponse = rest::deserialize_body(response).await?; + let response: AuthTokenResponse = response.deserialize().await?; Ok(response.auth_token) } } @@ -599,7 +596,7 @@ impl AppVersionProxy { .expected_status(&[StatusCode::OK]) .header("M-Platform-Version", &platform_version)?; let response = service.request(request).await?; - rest::deserialize_body(response).await + response.deserialize().await } } } @@ -621,6 +618,6 @@ impl ApiProxy { .get(&format!("{APP_URL_PREFIX}/api-addrs"))? .expected_status(&[StatusCode::OK]); let response = self.handle.service.request(request).await?; - rest::deserialize_body(response).await + response.deserialize().await } } diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 6f14e6f05b..deaf29ef10 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -60,11 +60,8 @@ impl RelayListProxy { } }); - Ok(Some( - rest::deserialize_body::<ServerRelayList>(response) - .await? - .into_relay_list(etag), - )) + let relay_list: ServerRelayList = response.deserialize().await?; + Ok(Some(relay_list.into_relay_list(etag))) } } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index ff704533c0..de6635b0b1 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -225,7 +225,7 @@ impl< fn handle_new_request( &mut self, request: Request, - completion_tx: oneshot::Sender<Result<hyper::Response<hyper::Body>>>, + completion_tx: oneshot::Sender<Result<Response>>, ) { let tx = self.command_tx.upgrade(); @@ -273,7 +273,7 @@ impl RequestServiceHandle { } /// Submits a `RestRequest` for execution to the request service. - pub async fn request(&self, request: Request) -> Result<hyper::Response<hyper::Body>> { + pub async fn request(&self, request: Request) -> Result<Response> { let (completion_tx, completion_rx) = oneshot::channel(); self.tx .unbounded_send(RequestCommand::NewRequest(request, completion_tx)) @@ -295,7 +295,7 @@ impl RequestServiceHandle { pub(crate) enum RequestCommand { NewRequest( Request, - oneshot::Sender<std::result::Result<hyper::Response<hyper::Body>, Error>>, + oneshot::Sender<std::result::Result<Response, Error>>, ), Reset, NextApiConfig(oneshot::Sender<std::result::Result<(), Error>>), @@ -375,7 +375,7 @@ impl Request { async fn into_future<C: Connect + Clone + Send + Sync + 'static>( mut self, hyper_client: hyper::Client<C>, - ) -> Result<hyper::Response<hyper::Body>> { + ) -> Result<Response> { // Obtain access token first if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) { let access_token = store.get_token(account).await?; @@ -419,7 +419,7 @@ impl Request { } } - Ok(response) + Ok(Response::new(response)) } /// Returns the URI of the request @@ -428,6 +428,31 @@ impl Request { } } +/// Successful result of a REST request +#[derive(Debug)] +pub struct Response { + response: hyper::Response<hyper::Body>, +} + +impl Response { + fn new(response: hyper::Response<hyper::Body>) -> Self { + Self { response } + } + + pub fn status(&self) -> StatusCode { + self.response.status() + } + + pub fn headers(&self) -> &hyper::HeaderMap<HeaderValue> { + self.response.headers() + } + + pub async fn deserialize<T: serde::de::DeserializeOwned>(self) -> Result<T> { + let body_length = get_body_length(&self.response); + deserialize_body_inner(self.response, body_length).await + } +} + #[derive(serde::Deserialize)] struct OldErrorResponse { pub code: String, @@ -539,25 +564,6 @@ impl RequestFactory { } } -pub async fn deserialize_body<T: serde::de::DeserializeOwned>( - response: hyper::Response<hyper::Body>, -) -> Result<T> { - let body_length = get_body_length(&response); - deserialize_body_inner(response, body_length).await -} - -async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( - mut response: hyper::Response<hyper::Body>, - body_length: usize, -) -> Result<T> { - let mut body: Vec<u8> = Vec::with_capacity(body_length); - while let Some(chunk) = response.body_mut().next().await { - body.extend(&chunk?); - } - - serde_json::from_slice(&body).map_err(Error::from) -} - fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize { response .headers() @@ -601,6 +607,18 @@ async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Res Err(Error::ApiError(status, error_message.to_owned())) } +async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( + mut response: hyper::Response<hyper::Body>, + body_length: usize, +) -> Result<T> { + let mut body: Vec<u8> = Vec::with_capacity(body_length); + while let Some(chunk) = response.body_mut().next().await { + body.extend(&chunk?); + } + + serde_json::from_slice(&body).map_err(Error::from) +} + #[derive(Clone)] pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs index d6672464e1..43ffaf1054 100644 --- a/mullvad-daemon/src/geoip.rs +++ b/mullvad-daemon/src/geoip.rs @@ -84,8 +84,7 @@ async fn send_location_request_internal( ) -> Result<AmIMullvad, Error> { let future_service = service.clone(); let request = mullvad_api::rest::Request::get(uri)?; - let response = future_service.request(request).await?; - mullvad_api::rest::deserialize_body(response).await + future_service.request(request).await?.deserialize().await } fn log_network_error(err: Error, version: &'static str) { |
