diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-10-01 10:06:39 +0200 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-10-16 09:56:49 +0200 |
| commit | ae1c1cefb5c048fcec021397a85ad886197913f7 (patch) | |
| tree | 60075efd2ef0fddcdf2f4700519ce8639eb589de | |
| parent | 088b1f68db4268e923d011ea79c05f6c46ca72f2 (diff) | |
| download | mullvadvpn-ae1c1cefb5c048fcec021397a85ad886197913f7.tar.xz mullvadvpn-ae1c1cefb5c048fcec021397a85ad886197913f7.zip | |
Replace occurrences of old `Body` type in `rest` mod
Use `Empty<Bytes>` for outgoing, `Incoming` for responses
and generic paras for our type wrapping `Request`.
| -rw-r--r-- | mullvad-api/src/lib.rs | 3 | ||||
| -rw-r--r-- | mullvad-api/src/relay_list.rs | 4 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 223 | ||||
| -rw-r--r-- | mullvad-daemon/src/geoip.rs | 2 |
4 files changed, 154 insertions, 78 deletions
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index fc731ee6ea..6b3ac3c951 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -1,7 +1,6 @@ #![allow(rustdoc::private_intra_doc_links)] #[cfg(target_os = "android")] use futures::channel::mpsc; -use hyper::Method; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; use mullvad_types::{ @@ -710,7 +709,7 @@ impl AppVersionProxy { let service = self.handle.service.clone(); let path = format!("{APP_URL_PREFIX}/releases/{platform}/{app_version}"); - let request = self.handle.factory.request(&path, Method::GET); + let request = self.handle.factory.get(&path); async move { let request = request? diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 5f2b2d6d81..f1375b5f6f 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -2,7 +2,7 @@ use crate::rest; -use hyper::{header, Method, StatusCode}; +use hyper::{header, StatusCode}; use mullvad_types::{location, relay_list}; use talpid_types::net::wireguard; @@ -34,7 +34,7 @@ impl RelayListProxy { etag: Option<String>, ) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> { let service = self.handle.service.clone(); - let request = self.handle.factory.request("app/v1/relays", Method::GET); + let request = self.handle.factory.get("app/v1/relays"); async move { let mut request = request? diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index c8df2aea8e..0a2bed93b7 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -11,10 +11,13 @@ use futures::{ channel::{mpsc, oneshot}, stream::StreamExt, }; +use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; use hyper::{ - client::{connect::Connect, Client}, + body::{Body, Bytes, Incoming}, + // client::{connect::Connect, Client}, header::{self, HeaderValue}, - Method, Uri, + Method, + Uri, }; use mullvad_types::account::AccountNumber; use std::{ @@ -62,8 +65,8 @@ pub enum Error { ApiError(StatusCode, String), /// The string given was not a valid URI. - #[error("Not a valid URI")] - InvalidUri, + #[error("Not a valid URI {0}")] + InvalidUri(#[from] Arc<http::uri::InvalidUri>), #[error("Set account number on factory with no access token store")] NoAccessTokenStore, @@ -119,7 +122,11 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> { command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>, command_rx: mpsc::UnboundedReceiver<RequestCommand>, connector_handle: HttpsConnectorWithSniHandle, - client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, + // client: hyper_util::client::legacy::Client< + // HttpsConnectorWithSni, + // BoxBody<dyn hyper::body::Buf, Error>, + // >, + client: HttpsConnectorWithSni, connection_mode_provider: T, connection_mode_generation: usize, api_availability: ApiAvailability, @@ -144,7 +151,8 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { connector_handle.set_connection_mode(connection_mode_provider.initial()); let (command_tx, command_rx) = mpsc::unbounded(); - let client = Client::builder().build(connector); + // let client = + // hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector); let command_tx = Arc::new(command_tx); @@ -152,7 +160,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { command_tx: Arc::downgrade(&command_tx), command_rx, connector_handle, - client, + client: connector, connection_mode_provider, connection_mode_generation: 0, api_availability, @@ -203,13 +211,15 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { fn handle_new_request( &mut self, - request: Request, - completion_tx: oneshot::Sender<Result<Response>>, + request: Request<BoxBody<Bytes, Error>>, + completion_tx: oneshot::Sender<Result<Response<Incoming>>>, ) { let tx = self.command_tx.upgrade(); let api_availability = self.api_availability.clone(); - let request_future = request.into_future(self.client.clone(), api_availability.clone()); + let request_future = request + .map(|r| http::Request::map(r, BodyExt::boxed)) + .into_future(self.client.clone(), api_availability.clone()); let connection_mode_generation = self.connection_mode_generation; @@ -246,8 +256,14 @@ impl RequestServiceHandle { } /// Submits a `RestRequest` for execution to the request service. - pub async fn request(&self, request: Request) -> Result<Response> { + pub async fn request<B>(&self, request: Request<B>) -> Result<Response<Incoming>> + where + B: Body + Send + Sync + 'static, + Error: From<B::Error>, + Bytes: From<B::Data>, + { let (completion_tx, completion_rx) = oneshot::channel(); + let request = request.map(|r| r.map(box_body)); self.tx .unbounded_send(RequestCommand::NewRequest(request, completion_tx)) .map_err(|_| Error::RestServiceDown)?; @@ -258,8 +274,8 @@ impl RequestServiceHandle { #[derive(Debug)] pub(crate) enum RequestCommand { NewRequest( - Request, - oneshot::Sender<std::result::Result<Response, Error>>, + Request<BoxBody<Bytes, Error>>, + oneshot::Sender<std::result::Result<Response<Incoming>, Error>>, ), Reset, NextApiConfig(usize), @@ -267,18 +283,18 @@ pub(crate) enum RequestCommand { /// A REST request that is sent to the RequestService to be executed. #[derive(Debug)] -pub struct Request { - request: hyper::Request<hyper::Body>, +pub struct Request<B> { + request: hyper::Request<B>, timeout: Duration, access_token_store: Option<AccessTokenStore>, account: Option<AccountNumber>, expected_status: &'static [hyper::StatusCode], } -impl Request { +// TODO: merge with `RequestFactory::get` /// Constructs a GET request with the given URI. Returns an error if the URI is not valid. - pub fn get(uri: &str) -> Result<Self> { - let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; +pub fn get(uri: &str) -> Result<Request<Empty<Bytes>>> { + let uri = hyper::Uri::from_str(uri)?; let mut builder = http::request::Builder::new() .method(Method::GET) @@ -287,18 +303,16 @@ impl Request { if let Some(host) = uri.host() { builder = builder.header( header::HOST, - HeaderValue::from_str(host).map_err(|_| Error::InvalidHeaderError)?, + HeaderValue::from_str(host).map_err(|_e| Error::InvalidHeaderError)?, ); }; - let request = builder.uri(uri).body(hyper::Body::empty())?; - Ok(Self::new(request, None)) + let request = builder.uri(uri).body(Empty::<Bytes>::new())?; + Ok(Request::new(request, None)) } - fn new( - request: hyper::Request<hyper::Body>, - access_token_store: Option<AccessTokenStore>, - ) -> Self { +impl<B: Body> Request<B> { + fn new(request: hyper::Request<B>, access_token_store: Option<AccessTokenStore>) -> Self { Self { request, timeout: DEFAULT_TIMEOUT, @@ -336,11 +350,64 @@ impl Request { Ok(self) } + /// Returns the URI of the request + pub fn uri(&self) -> &Uri { + self.request.uri() + } +} +impl<B> Request<B> { + /// Map the underlying [`hyper::Request`] type + fn map<F, B2>(self, f: F) -> Request<B2> + where + F: FnOnce(hyper::Request<B>) -> hyper::Request<B2>, + { + Request { + request: f(self.request), + timeout: self.timeout, + access_token_store: self.access_token_store, + account: self.account, + expected_status: self.expected_status, + } + } +} + +fn box_body<B>(body: B) -> BoxBody<Bytes, Error> +where + B: Body + Send + Sync + 'static, + Error: From<B::Error>, + Bytes: From<B::Data>, +{ + try_downcast(body).unwrap_or_else(|body| { + body.map_frame(|frame| frame.map_data(Bytes::from)) + .map_err(Error::from) + .boxed() + }) +} + +pub(crate) fn try_downcast<T, K>(k: K) -> core::result::Result<T, K> +where + T: 'static, + K: Send + 'static, +{ + let mut k = Some(k); + if let Some(k) = <dyn std::any::Any>::downcast_mut::<Option<T>>(&mut k) { + Ok(k.take().unwrap()) + } else { + Err(k.unwrap()) + } +} + +impl<B> Request<B> +where + B: Body + Send + 'static + Unpin, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ async fn into_future<C: Connect + Clone + Send + Sync + 'static>( self, - hyper_client: hyper::Client<C>, + hyper_client: hyper_util::client::legacy::Client<C, B>, api_availability: ApiAvailability, - ) -> Result<Response> { + ) -> Result<Response<Incoming>> { let timeout = self.timeout; let inner_fut = self.into_future_without_timeout(hyper_client, api_availability); tokio::time::timeout(timeout, inner_fut) @@ -348,11 +415,14 @@ impl Request { .map_err(|_| Error::TimeoutError)? } - async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>( + async fn into_future_without_timeout<C>( mut self, - hyper_client: hyper::Client<C>, + hyper_client: hyper_util::client::legacy::Client<C, B>, api_availability: ApiAvailability, - ) -> Result<Response> { + ) -> Result<Response<Incoming>> + where + C: Connect + Clone + Send + Sync + 'static, + { let _ = api_availability.wait_for_unsuspend().await; // Obtain access token first @@ -399,21 +469,19 @@ impl Request { Ok(Response::new(response)) } - - /// Returns the URI of the request - pub fn uri(&self) -> &Uri { - self.request.uri() - } } /// Successful result of a REST request #[derive(Debug)] -pub struct Response { - response: hyper::Response<hyper::Body>, +pub struct Response<B> { + response: hyper::Response<B>, } -impl Response { - fn new(response: hyper::Response<hyper::Body>) -> Self { +impl<B: Body> Response<B> +where + Error: From<<B as Body>::Error>, +{ + fn new(response: hyper::Response<B>) -> Self { Self { response } } @@ -426,8 +494,7 @@ impl Response { } pub async fn deserialize<T: serde::de::DeserializeOwned>(self) -> Result<T> { - let body_length = get_body_length(&self.response); - deserialize_body_inner(self.response, body_length).await + deserialize_body_inner(self.response).await } } @@ -462,38 +529,46 @@ impl RequestFactory { } } - pub fn request(&self, path: &str, method: Method) -> Result<Request> { + pub fn request<B: Body + Default>(&self, path: &str, method: Method) -> Result<Request<B>> { Ok( Request::new(self.hyper_request(path, method)?, self.token_store.clone()) .timeout(self.default_timeout), ) } - pub fn get(&self, path: &str) -> Result<Request> { + pub fn get(&self, path: &str) -> Result<Request<Empty<Bytes>>> { self.request(path, Method::GET) } - pub fn post(&self, path: &str) -> Result<Request> { + pub fn post(&self, path: &str) -> Result<Request<Empty<Bytes>>> { self.request(path, Method::POST) } - pub fn put(&self, path: &str) -> Result<Request> { + pub fn put(&self, path: &str) -> Result<Request<Empty<Bytes>>> { self.request(path, Method::PUT) } - pub fn delete(&self, path: &str) -> Result<Request> { + pub fn delete(&self, path: &str) -> Result<Request<Empty<Bytes>>> { self.request(path, Method::DELETE) } - pub fn head(&self, path: &str) -> Result<Request> { + pub fn head(&self, path: &str) -> Result<Request<Empty<Bytes>>> { self.request(path, Method::HEAD) } - pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> { + pub fn post_json<S: serde::Serialize>( + &self, + path: &str, + body: &S, + ) -> Result<Request<Full<Bytes>>> { self.json_request(Method::POST, path, body) } - pub fn put_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> { + pub fn put_json<S: serde::Serialize>( + &self, + path: &str, + body: &S, + ) -> Result<Request<Full<Bytes>>> { self.json_request(Method::PUT, path, body) } @@ -501,18 +576,17 @@ impl RequestFactory { self.default_timeout = timeout; self } - fn json_request<S: serde::Serialize>( &self, method: Method, path: &str, body: &S, - ) -> Result<Request> { + ) -> Result<Request<Full<Bytes>>> { let mut request = self.hyper_request(path, method)?; - let json_body = serde_json::to_string(&body)?; - let body_length = json_body.as_bytes().len(); - *request.body_mut() = json_body.into_bytes().into(); + let json_body = serde_json::to_vec(&body)?; + let body_length = json_body.len(); + *request.body_mut() = Full::new(Bytes::from(json_body)); let headers = request.headers_mut(); headers.insert(header::CONTENT_LENGTH, HeaderValue::from(body_length)); @@ -524,7 +598,7 @@ impl RequestFactory { Ok(Request::new(request, self.token_store.clone()).timeout(self.default_timeout)) } - fn hyper_request(&self, path: &str, method: Method) -> Result<hyper::Request<hyper::Body>> { + fn hyper_request<B: Default>(&self, path: &str, method: Method) -> Result<http::Request<B>> { let uri = self.get_uri(path)?; let request = http::request::Builder::new() .method(method) @@ -536,17 +610,17 @@ impl RequestFactory { HeaderValue::from_str(&self.hostname).map_err(|_| Error::InvalidHeaderError)?, ); - let result = request.body(hyper::Body::empty())?; + let result = request.body(B::default())?; Ok(result) } fn get_uri(&self, path: &str) -> Result<Uri> { let uri = format!("https://{}/{}", self.hostname, path); - hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri) + Ok(hyper::Uri::from_str(&uri)?) } } -fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize { +fn get_body_length<B>(response: &hyper::Response<B>) -> usize { response .headers() .get(header::CONTENT_LENGTH) @@ -555,20 +629,22 @@ fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize { .unwrap_or(0) } -async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Result<T> { +async fn handle_error_response<T, B: Body>(response: hyper::Response<B>) -> Result<T> +where + Error: From<B::Error>, +{ let status = response.status(); let error_message = match status { hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed", status => match get_body_length(&response) { 0 => status.canonical_reason().unwrap_or("Unexpected error"), - body_length => { + _length => { return match response.headers().get("content-type") { Some(content_type) if content_type == "application/problem+json" => { // TODO: We should make sure we unify the new error format and the old // error format so that they both produce the same Errors for the same // problems after being processed. - let err: NewErrorResponse = - deserialize_body_inner(response, body_length).await?; + let err: NewErrorResponse = deserialize_body_inner(response).await?; // The new error type replaces the `code` field with the `type` field. // This is what is used to programmatically check the error. Err(Error::ApiError( @@ -578,8 +654,7 @@ async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Res )) } _ => { - let err: OldErrorResponse = - deserialize_body_inner(response, body_length).await?; + let err: OldErrorResponse = deserialize_body_inner(response).await?; Err(Error::ApiError(status, err.code)) } }; @@ -589,16 +664,17 @@ async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Res Err(Error::ApiError(status, error_message.to_owned())) } -async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( - mut response: hyper::Response<hyper::Body>, - body_length: usize, -) -> Result<T> { - let mut body: Vec<u8> = Vec::with_capacity(body_length); - while let Some(chunk) = response.body_mut().next().await { - body.extend(&chunk?); - } +async fn deserialize_body_inner<T, B>(response: hyper::Response<B>) -> Result<T> +where + T: serde::de::DeserializeOwned, + B: Body, + Error: From<B::Error>, +{ + use http_body_util::BodyExt; - serde_json::from_slice(&body).map_err(Error::from) + let collected = BodyExt::collect(response).await?; + let res = serde_json::from_slice(&collected.to_bytes())?; + Ok(res) } #[derive(Clone)] @@ -639,3 +715,4 @@ macro_rules! impl_into_arc_err { impl_into_arc_err!(hyper::Error); impl_into_arc_err!(serde_json::Error); impl_into_arc_err!(http::Error); +impl_into_arc_err!(http::uri::InvalidUri); diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs index da2fb3e8db..815b83b13f 100644 --- a/mullvad-daemon/src/geoip.rs +++ b/mullvad-daemon/src/geoip.rs @@ -154,7 +154,7 @@ async fn send_location_request_internal( service: RequestServiceHandle, ) -> Result<AmIMullvad, Error> { let future_service = service.clone(); - let request = mullvad_api::rest::Request::get(uri)?; + let request = mullvad_api::rest::get(uri)?; future_service.request(request).await?.deserialize().await } |
