diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-01-21 13:23:27 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-01 15:30:22 +0100 |
| commit | bcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd (patch) | |
| tree | 95721793a135d304f73a8724a5f014960316df80 | |
| parent | 45d827f96b524f1183a2f3709700c0bea643faab (diff) | |
| download | mullvadvpn-bcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd.tar.xz mullvadvpn-bcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd.zip | |
Add Shadowsocks support to HTTPS connector
| -rw-r--r-- | mullvad-rpc/src/address_cache.rs | 139 | ||||
| -rw-r--r-- | mullvad-rpc/src/bin/relay_list.rs | 16 | ||||
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 184 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 68 | ||||
| -rw-r--r-- | mullvad-rpc/src/proxy.rs | 171 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 93 |
6 files changed, 452 insertions, 219 deletions
diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs index 099ffbc075..3b6fcba074 100644 --- a/mullvad-rpc/src/address_cache.rs +++ b/mullvad-rpc/src/address_cache.rs @@ -1,13 +1,9 @@ -use std::{ - io, - net::SocketAddr, - ops::{Deref, DerefMut}, - path::Path, - sync::{Arc, Mutex}, -}; +use super::API; +use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt}, + sync::Mutex, }; #[derive(err_derive::Error, Debug)] @@ -27,99 +23,66 @@ pub enum Error { #[error(display = "The address cache is empty")] EmptyAddressCache, - - #[error(display = "The address change listener returned an error")] - ChangeListenerError, } -pub type CurrentAddressChangeListener = - dyn Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static; - #[derive(Clone)] pub struct AddressCache { inner: Arc<Mutex<AddressCacheInner>>, write_path: Option<Arc<Path>>, - change_listener: Arc<Box<CurrentAddressChangeListener>>, } impl AddressCache { - /// Initialize cache using the given list, and write changes to `write_path`. - pub fn new(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> { + /// Initialize cache using the hardcoded address, and write changes to `write_path`. + pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> { + Self::new_inner(API.addr, write_path) + } + + /// Initialize cache using `read_path`, and write changes to `write_path`. + pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> { + log::debug!("Loading API addresses from {}", read_path.display()); + Self::new_inner(read_address_file(read_path).await?, write_path) + } + + fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> { let cache = AddressCacheInner::from_address(address); - log::trace!("API address cache: {:?}", cache.address); - log::debug!("Using API address: {:?}", Self::get_address_inner(&cache)); + log::debug!("Using API address: {}", cache.address); let address_cache = Self { inner: Arc::new(Mutex::new(cache)), write_path: write_path.map(|cache| Arc::from(cache)), - change_listener: Arc::new(Box::new(|_| Ok(()))), }; Ok(address_cache) } - /// Initialize cache using `read_path`, and write changes to `write_path`. - pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> { - log::debug!("Loading API addresses from {}", read_path.display()); - Self::new(read_address_file(read_path).await?, write_path) - } - - pub fn set_change_listener(&mut self, change_listener: Arc<Box<CurrentAddressChangeListener>>) { - self.change_listener = change_listener; + /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. + pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> { + if hostname.eq_ignore_ascii_case(&API.host) { + Some(self.get_address().await) + } else { + None + } } /// Returns the currently selected address. - pub fn get_address(&self) -> SocketAddr { - let inner = self.inner.lock().unwrap(); - Self::get_address_inner(&inner) - } - - fn get_address_inner(inner: &AddressCacheInner) -> SocketAddr { - inner.address + pub async fn get_address(&self) -> SocketAddr { + self.inner.lock().await.address } pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> { - let should_update = { - let mut inner = self.inner.lock().unwrap(); - let mut transaction = AddressCacheTransaction::new(&mut inner); - - let current_address = transaction.address.clone(); - - if address != current_address { - transaction.address = address.clone(); - tokio::task::block_in_place(move || { - if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() { - log::error!("Failed to select new API endpoint"); - return Err(io::Error::new( - io::ErrorKind::Other, - "callback returned an error", - )); - } - transaction.commit(); - Ok(()) - })?; - true - } else { - false - } - }; - if should_update { - log::trace!("API address cache: {}", address); - self.save_to_disk().await?; + let mut inner = self.inner.lock().await; + if address != inner.address { + self.save_to_disk(&address).await?; + inner.address = address; } Ok(()) } - async fn save_to_disk(&self) -> io::Result<()> { + async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> { let write_path = match self.write_path.as_ref() { Some(write_path) => write_path, None => return Ok(()), }; - let address = { - let inner = self.inner.lock().unwrap(); - inner.address.clone() - }; - let temp_path = write_path.with_file_name("api-cache.temp"); let mut file = fs::File::create(&temp_path).await?; @@ -132,16 +95,6 @@ impl AddressCache { } } -impl crate::rest::AddressProvider for AddressCache { - fn get_address(&self) -> String { - self.get_address().to_string() - } - - fn clone_box(&self) -> Box<dyn crate::rest::AddressProvider> { - Box::new(self.clone()) - } -} - #[derive(Clone, PartialEq, Eq)] struct AddressCacheInner { address: SocketAddr, @@ -153,38 +106,6 @@ impl AddressCacheInner { } } -struct AddressCacheTransaction<'a> { - current: &'a mut AddressCacheInner, - working_cache: AddressCacheInner, -} - -impl<'a> AddressCacheTransaction<'a> { - fn new(cache: &'a mut AddressCacheInner) -> Self { - Self { - working_cache: cache.clone(), - current: cache, - } - } - - fn commit(self) { - *self.current = self.working_cache; - } -} - -impl<'a> Deref for AddressCacheTransaction<'a> { - type Target = AddressCacheInner; - - fn deref(&self) -> &Self::Target { - &self.working_cache - } -} - -impl<'a> DerefMut for AddressCacheTransaction<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.working_cache - } -} - async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> { let mut file = fs::File::open(path) .await diff --git a/mullvad-rpc/src/bin/relay_list.rs b/mullvad-rpc/src/bin/relay_list.rs index db7fc29854..66118d3ade 100644 --- a/mullvad-rpc/src/bin/relay_list.rs +++ b/mullvad-rpc/src/bin/relay_list.rs @@ -2,18 +2,24 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_rpc::{rest::Error as RestError, MullvadRpcRuntime, RelayListProxy}; +use mullvad_rpc::{ + proxy::ApiConnectionMode, rest::Error as RestError, MullvadRpcRuntime, RelayListProxy, +}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let mut runtime = + let runtime = MullvadRpcRuntime::new(tokio::runtime::Handle::current()).expect("Failed to load runtime"); - let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle()) - .relay_list(None) - .await; + let relay_list_request = RelayListProxy::new( + runtime + .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat(), |_| async { true }) + .await, + ) + .relay_list(None) + .await; let relay_list = match relay_list_request { Ok(relay_list) => relay_list, diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index 4d8108f89b..8328ed93d9 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -1,8 +1,10 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, + proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, tls_stream::TlsStream, + AddressCache, }; -use futures::{channel::mpsc, StreamExt}; +use futures::{channel::mpsc, future, StreamExt}; #[cfg(target_os = "android")] use futures::{channel::oneshot, sink::SinkExt}; use http::uri::Scheme; @@ -11,6 +13,13 @@ use hyper::{ service::Service, Uri, }; +use shadowsocks::{ + config::ServerType, + context::{Context as SsContext, SharedContext}, + crypto::v1::CipherKind, + relay::tcprelay::ProxyClientStream, + ServerConfig, +}; #[cfg(target_os = "android")] use std::os::unix::io::{AsRawFd, RawFd}; use std::{ @@ -24,6 +33,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; +use talpid_types::ErrorExt; #[cfg(target_os = "android")] use tokio::net::TcpSocket; @@ -33,13 +43,70 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); #[derive(Clone)] pub struct HttpsConnectorWithSniHandle { - tx: mpsc::UnboundedSender<()>, + tx: mpsc::UnboundedSender<HttpsConnectorRequest>, } impl HttpsConnectorWithSniHandle { /// Stop all streams produced by this connector pub fn reset(&self) { - let _ = self.tx.unbounded_send(()); + let _ = self.tx.unbounded_send(HttpsConnectorRequest::Reset); + } + + /// Change the proxy settings for the connector + pub fn set_connection_mode(&self, proxy: ApiConnectionMode) { + let _ = self + .tx + .unbounded_send(HttpsConnectorRequest::SetConnectionMode(proxy)); + } +} + +enum HttpsConnectorRequest { + Reset, + SetConnectionMode(ApiConnectionMode), +} + +#[derive(Clone)] +enum InnerConnectionMode { + /// Connect directly to the target. + Direct, + /// Connect to the destination via a proxy. + Proxied(ParsedShadowsocksConfig), +} + +#[derive(Clone)] +struct ParsedShadowsocksConfig { + peer: SocketAddr, + password: String, + cipher: CipherKind, +} + +impl From<ParsedShadowsocksConfig> for ServerConfig { + fn from(config: ParsedShadowsocksConfig) -> Self { + ServerConfig::new(config.peer, config.password, config.cipher) + } +} + +#[derive(err_derive::Error, Debug)] +enum ProxyConfigError { + #[error(display = "Unrecognized cipher selected: {}", _0)] + InvalidCipher(String), +} + +impl TryFrom<ApiConnectionMode> for InnerConnectionMode { + type Error = ProxyConfigError; + + fn try_from(config: ApiConnectionMode) -> Result<Self, Self::Error> { + Ok(match config { + ApiConnectionMode::Direct => InnerConnectionMode::Direct, + ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(config)) => { + InnerConnectionMode::Proxied(ParsedShadowsocksConfig { + peer: config.peer, + password: config.password, + cipher: CipherKind::from_str(&config.cipher) + .map_err(|_| ProxyConfigError::InvalidCipher(config.cipher))?, + }) + } + }) } } @@ -48,12 +115,16 @@ impl HttpsConnectorWithSniHandle { pub struct HttpsConnectorWithSni { inner: Arc<Mutex<HttpsConnectorWithSniInner>>, sni_hostname: Option<String>, + address_cache: AddressCache, + abort_notify: Arc<tokio::sync::Notify>, + proxy_context: SharedContext, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } struct HttpsConnectorWithSniInner { stream_handles: Vec<AbortableStreamHandle>, + proxy_config: InnerConnectionMode, } #[cfg(target_os = "android")] @@ -63,24 +134,46 @@ impl HttpsConnectorWithSni { pub fn new( handle: Handle, sni_hostname: Option<String>, + address_cache: AddressCache, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> (Self, HttpsConnectorWithSniHandle) { - let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded(); + let (tx, mut rx) = mpsc::unbounded(); + let abort_notify = Arc::new(tokio::sync::Notify::new()); let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner { stream_handles: vec![], + proxy_config: InnerConnectionMode::Direct, })); let inner_copy = inner.clone(); + let notify = abort_notify.clone(); handle.spawn(async move { // Handle requests by `HttpsConnectorWithSniHandle`s - while let Some(()) = rx.next().await { + while let Some(request) = rx.next().await { let handles = { let mut inner = inner_copy.lock().unwrap(); + + if let HttpsConnectorRequest::SetConnectionMode(config) = request { + match InnerConnectionMode::try_from(config) { + Ok(config) => { + inner.proxy_config = config; + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to parse new API proxy config" + ) + ); + } + } + } + std::mem::take(&mut inner.stream_handles) }; for handle in handles { handle.close(); } + notify.notify_waiters(); } }); @@ -88,6 +181,9 @@ impl HttpsConnectorWithSni { HttpsConnectorWithSni { inner, sni_hostname, + address_cache, + abort_notify, + proxy_context: SsContext::new_shared(ServerType::Local), #[cfg(target_os = "android")] socket_bypass_tx, }, @@ -125,17 +221,24 @@ impl HttpsConnectorWithSni { .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? } - async fn resolve_address(uri: &Uri) -> io::Result<SocketAddr> { + async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> { let hostname = uri.host().ok_or(io::Error::new( io::ErrorKind::InvalidInput, "invalid url, missing host", ))?; let port = uri.port_u16().unwrap_or(443); - if let Some(addr) = hostname.parse::<IpAddr>().ok() { return Ok(SocketAddr::new(addr, port)); } + // Preferentially, use cached address. + // + if let Some(addr) = address_cache.resolve_hostname(hostname).await { + return Ok(SocketAddr::new(addr.ip(), port)); + } + + // Use getaddrinfo as a fallback + // let mut addrs = GaiResolver::new() .call( Name::from_str(&hostname) @@ -157,7 +260,7 @@ impl fmt::Debug for HttpsConnectorWithSni { } impl Service<Uri> for HttpsConnectorWithSni { - type Response = TlsStream<AbortableStream<TcpStream>>; + type Response = AbortableStream<ApiConnection>; type Error = io::Error; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -175,8 +278,11 @@ impl Service<Uri> for HttpsConnectorWithSni { io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") }); let inner = self.inner.clone(); + let abort_notify = self.abort_notify.clone(); + let proxy_context = self.proxy_context.clone(); #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); + let address_cache = self.address_cache.clone(); let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { @@ -187,16 +293,62 @@ impl Service<Uri> for HttpsConnectorWithSni { } let hostname = sni_hostname?; - let addr = Self::resolve_address(&uri).await?; + let addr = Self::resolve_address(address_cache, uri).await?; - let tokio_connection = Self::open_socket( - addr, + // Loop until we have established a connection. This starts over if a new endpoint + // is selected while connecting. + let stream = loop { + let config = { inner.lock().unwrap().proxy_config.clone() }; + let hostname_copy = hostname.clone(); + let addr_copy = addr.clone(); + let context = proxy_context.clone(); #[cfg(target_os = "android")] - socket_bypass_tx, - ) - .await?; + let socket_bypass_tx_copy = socket_bypass_tx.clone(); + + let stream_fut: Pin< + Box<dyn Future<Output = Result<ApiConnection, io::Error>> + Send>, + > = Box::pin(async move { + match config { + InnerConnectionMode::Direct => { + let socket = Self::open_socket( + addr_copy, + #[cfg(target_os = "android")] + socket_bypass_tx_copy, + ) + .await?; + let tls_stream = + TlsStream::connect_https(socket, &hostname_copy).await?; + Ok(ApiConnection::Direct(tls_stream)) + } + InnerConnectionMode::Proxied(proxy_config) => { + let socket = Self::open_socket( + proxy_config.peer, + #[cfg(target_os = "android")] + socket_bypass_tx_copy, + ) + .await?; + let proxy = ProxyClientStream::from_stream( + context, + socket, + &ServerConfig::from(proxy_config), + addr, + ); + let tls_stream = + TlsStream::connect_https(proxy, &hostname_copy).await?; + Ok(ApiConnection::Proxied(tls_stream)) + } + } + }); + + // Wait for connection. Abort and retry if we switched to a different server. + if let future::Either::Left((stream, _)) = + future::select(stream_fut, Box::pin(abort_notify.notified())).await + { + break stream?; + } + }; - let (tcp_stream, socket_handle) = AbortableStream::new(tokio_connection); + let (stream, socket_handle) = AbortableStream::new(stream); { let mut inner = inner.lock().unwrap(); @@ -204,7 +356,7 @@ impl Service<Uri> for HttpsConnectorWithSni { inner.stream_handles.push(socket_handle); } - Ok(TlsStream::connect_https(tcp_stream, &hostname).await?) + Ok(stream) }; Box::pin(fut) diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 27cd3f87ae..e625dca376 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -3,17 +3,18 @@ use chrono::{offset::Utc, DateTime}; #[cfg(target_os = "android")] use futures::channel::mpsc; +use futures::Stream; use hyper::Method; use mullvad_types::{ account::{AccountToken, VoucherSubmission}, version::AppVersion, }; +use proxy::ApiConnectionMode; use std::{ collections::BTreeMap, future::Future, net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, - sync::Arc, }; use talpid_types::{net::wireguard, ErrorExt}; @@ -23,14 +24,14 @@ pub mod rest; mod abortable_stream; mod https_client_with_sni; +pub mod proxy; mod tls_stream; -mod proxy; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; mod address_cache; mod relay_list; -pub use address_cache::{AddressCache, CurrentAddressChangeListener}; +pub use address_cache::AddressCache; pub use hyper::StatusCode; pub use relay_list::RelayListProxy; @@ -151,7 +152,7 @@ impl MullvadRpcRuntime { ) -> Result<Self, Error> { Ok(MullvadRpcRuntime { handle, - address_cache: AddressCache::new(API.addr, None)?, + address_cache: AddressCache::new(None)?, api_availability: ApiAvailability::new(availability::State::default()), #[cfg(target_os = "android")] socket_bypass_tx, @@ -192,7 +193,7 @@ impl MullvadRpcRuntime { ) ); } - AddressCache::new(API.addr, write_file)? + AddressCache::new(write_file)? } }; @@ -205,44 +206,52 @@ impl MullvadRpcRuntime { }) } - pub fn set_address_change_listener( - &mut self, - address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static, - ) { - self.address_cache - .set_change_listener(Arc::new(Box::new(address_change_listener))); - } - /// Creates a new request service and returns a handle to it. - fn new_request_service( - &mut self, + async fn new_request_service< + T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static, + AcceptedNewEndpoint: Future<Output = bool> + Send + 'static, + >( + &self, sni_hostname: Option<String>, + proxy_provider: T, + new_address_callback: impl (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> rest::RequestServiceHandle { let service = rest::RequestService::new( self.handle.clone(), sni_hostname, self.api_availability.handle(), + self.address_cache.clone(), + proxy_provider, + new_address_callback, #[cfg(target_os = "android")] socket_bypass_tx, - ); + ) + .await; let handle = service.handle(); self.handle.spawn(service.into_future()); handle } /// Returns a request factory initialized to create requests for the master API - pub fn mullvad_rest_handle(&mut self) -> rest::MullvadRestHandle { - let service = self.new_request_service( - Some(API.host.clone()), - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ); - let factory = rest::RequestFactory::new( - API.host.clone(), - Box::new(self.address_cache.clone()), - Some("app".to_owned()), - ); + pub async fn mullvad_rest_handle< + T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static, + AcceptedNewEndpoint: Future<Output = bool> + Send + 'static, + >( + &self, + proxy_provider: T, + new_address_callback: impl (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static, + ) -> rest::MullvadRestHandle { + let service = self + .new_request_service( + Some(API.host.clone()), + proxy_provider, + new_address_callback, + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), + ) + .await; + let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned())); rest::MullvadRestHandle::new( service, @@ -253,12 +262,15 @@ impl MullvadRpcRuntime { } /// Returns a new request service handle - pub fn rest_handle(&mut self) -> rest::RequestServiceHandle { + pub async fn rest_handle(&mut self) -> rest::RequestServiceHandle { self.new_request_service( None, + ApiConnectionMode::Direct.into_repeat(), + |_| async { true }, #[cfg(target_os = "android")] None, ) + .await } pub fn handle(&mut self) -> &mut tokio::runtime::Handle { diff --git a/mullvad-rpc/src/proxy.rs b/mullvad-rpc/src/proxy.rs index dd21df164e..009a1960dc 100644 --- a/mullvad-rpc/src/proxy.rs +++ b/mullvad-rpc/src/proxy.rs @@ -1,67 +1,204 @@ use crate::tls_stream::TlsStream; +use futures::Stream; use hyper::client::connect::{Connected, Connection}; +use rand::{distributions::Alphanumeric, Rng}; +use serde::{Deserialize, Serialize}; use shadowsocks::relay::tcprelay::ProxyClientStream; use std::{ - io, + fmt, io, + net::SocketAddr, + path::Path, pin::Pin, task::{self, Poll}, }; +use talpid_types::{net::openvpn::ShadowsocksProxySettings, ErrorExt}; use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, + fs, + io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, net::TcpStream, }; +const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json"; + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +pub enum ApiConnectionMode { + /// Connect directly to the target. + Direct, + /// Connect to the destination via a proxy. + Proxied(ProxyConfig), +} + +impl fmt::Display for ApiConnectionMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + ApiConnectionMode::Direct => write!(f, "unproxied"), + ApiConnectionMode::Proxied(settings) => settings.fmt(f), + } + } +} + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +pub enum ProxyConfig { + Shadowsocks(ShadowsocksProxySettings), +} + +impl fmt::Display for ProxyConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + // TODO: Do not hardcode TCP + ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer), + } + } +} + +impl ApiConnectionMode { + /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`. + /// This returns `ApiConnectionMode::Direct` if reading from disk fails for any reason. + pub async fn try_from_cache(cache_dir: &Path) -> Self { + Self::from_cache(cache_dir).await.unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to read API endpoint cache") + ); + ApiConnectionMode::Direct + }) + } + + /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`. + /// If the file does not exist, this returns `Ok(ApiConnectionMode::Direct)`. + async fn from_cache(cache_dir: &Path) -> io::Result<Self> { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + match fs::read_to_string(path).await { + Ok(s) => serde_json::from_str(&s).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg(&format!( + "Failed to deserialize \"{}\"", + CURRENT_CONFIG_FILENAME + )) + ); + io::Error::new(io::ErrorKind::Other, "deserialization failed") + }), + Err(error) => { + if error.kind() == io::ErrorKind::NotFound { + Ok(ApiConnectionMode::Direct) + } else { + Err(error) + } + } + } + } + + /// Stores this config to `CURRENT_CONFIG_FILENAME`. + /// The content is saved to a temporary file first, which ensures that + /// consumers of the file never end up with partial content. + pub async fn save(&self, cache_dir: &Path) -> io::Result<()> { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + let mut temp_ext = String::from("temp"); + temp_ext.push_str( + &rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect::<String>(), + ); + let temp_path = path.with_extension(temp_ext); + + { + let mut file = fs::File::create(&temp_path).await?; + let json = serde_json::to_string_pretty(self) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "serialization failed"))?; + file.write_all(json.as_bytes()).await?; + file.write_all(b"\n").await?; + file.sync_data().await?; + } + + fs::rename(&temp_path, path).await + } + + /// Attempts to remove `CURRENT_CONFIG_FILENAME`, if it exists. + pub async fn try_delete_cache(cache_dir: &Path) { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + if let Err(err) = fs::remove_file(path).await { + if err.kind() != std::io::ErrorKind::NotFound { + log::error!( + "{}", + err.display_chain_with_msg("Failed to remove old API config") + ); + } + } + } + + /// Returns the remote address, or `None` for `ApiConnectionMode::Direct`. + pub fn get_endpoint(&self) -> Option<SocketAddr> { + match self { + ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(ss)) => Some(ss.peer), + ApiConnectionMode::Direct => None, + } + } + + pub fn is_proxy(&self) -> bool { + *self != ApiConnectionMode::Direct + } + + /// Convenience function that returns a stream that repeats + /// this config forever. + pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> { + futures::stream::repeat(self) + } +} + /// Stream that is either a regular TLS stream or TLS via shadowsocks -pub enum MaybeProxyStream { - Tls(TlsStream<TcpStream>), +pub enum ApiConnection { + Direct(TlsStream<TcpStream>), Proxied(TlsStream<ProxyClientStream<TcpStream>>), } -impl AsyncRead for MaybeProxyStream { +impl AsyncRead for ApiConnection { fn poll_read( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { match Pin::get_mut(self) { - MaybeProxyStream::Tls(s) => Pin::new(s).poll_read(cx, buf), - MaybeProxyStream::Proxied(s) => Pin::new(s).poll_read(cx, buf), + ApiConnection::Direct(s) => Pin::new(s).poll_read(cx, buf), + ApiConnection::Proxied(s) => Pin::new(s).poll_read(cx, buf), } } } -impl AsyncWrite for MaybeProxyStream { +impl AsyncWrite for ApiConnection { fn poll_write( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { match Pin::get_mut(self) { - MaybeProxyStream::Tls(s) => Pin::new(s).poll_write(cx, buf), - MaybeProxyStream::Proxied(s) => Pin::new(s).poll_write(cx, buf), + ApiConnection::Direct(s) => Pin::new(s).poll_write(cx, buf), + ApiConnection::Proxied(s) => Pin::new(s).poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { match Pin::get_mut(self) { - MaybeProxyStream::Tls(s) => Pin::new(s).poll_flush(cx), - MaybeProxyStream::Proxied(s) => Pin::new(s).poll_flush(cx), + ApiConnection::Direct(s) => Pin::new(s).poll_flush(cx), + ApiConnection::Proxied(s) => Pin::new(s).poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { match Pin::get_mut(self) { - MaybeProxyStream::Tls(s) => Pin::new(s).poll_shutdown(cx), - MaybeProxyStream::Proxied(s) => Pin::new(s).poll_shutdown(cx), + ApiConnection::Direct(s) => Pin::new(s).poll_shutdown(cx), + ApiConnection::Proxied(s) => Pin::new(s).poll_shutdown(cx), } } } -impl Connection for MaybeProxyStream { +impl Connection for ApiConnection { fn connected(&self) -> Connected { match self { - MaybeProxyStream::Tls(s) => s.connected(), - MaybeProxyStream::Proxied(s) => s.connected(), + ApiConnection::Direct(s) => s.connected(), + ApiConnection::Proxied(s) => s.connected(), } } } diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 014719b26b..3b7d0a0bc7 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -4,13 +4,14 @@ use crate::{ address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, + proxy::ApiConnectionMode, }; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Aborted}, sink::SinkExt, stream::StreamExt, - TryFutureExt, + Stream, TryFutureExt, }; use hyper::{ client::Client, @@ -21,7 +22,7 @@ use std::{ collections::BTreeMap, future::Future, mem, - net::{IpAddr, SocketAddr}, + net::SocketAddr, str::FromStr, time::{Duration, Instant}, }; @@ -88,7 +89,11 @@ impl Error { /// A service that executes HTTP requests, allowing for on-demand termination of all in-flight /// requests -pub(crate) struct RequestService { +pub(crate) struct RequestService< + T: Stream<Item = ApiConnectionMode>, + F: Fn(SocketAddr) -> AcceptedNewEndpoint, + AcceptedNewEndpoint: Future<Output = bool>, +> { command_tx: mpsc::Sender<RequestCommand>, command_rx: mpsc::Receiver<RequestCommand>, connector_handle: HttpsConnectorWithSniHandle, @@ -96,24 +101,41 @@ pub(crate) struct RequestService { handle: Handle, next_id: u64, in_flight_requests: BTreeMap<u64, AbortHandle>, + proxy_config_provider: T, + new_address_callback: F, + address_cache: AddressCache, api_availability: ApiAvailabilityHandle, } -impl RequestService { +impl< + T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static, + F: (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static, + AcceptedNewEndpoint: Future<Output = bool> + Send + 'static, + > RequestService<T, F, AcceptedNewEndpoint> +{ /// Constructs a new request service. - pub fn new( + pub async fn new( handle: Handle, sni_hostname: Option<String>, api_availability: ApiAvailabilityHandle, + address_cache: AddressCache, + mut proxy_config_provider: T, + new_address_callback: F, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - ) -> RequestService { + ) -> RequestService<T> { let (connector, connector_handle) = HttpsConnectorWithSni::new( handle.clone(), sni_hostname, + address_cache.clone(), #[cfg(target_os = "android")] socket_bypass_tx.clone(), ); + proxy_config_provider + .next() + .await + .map(|config| connector_handle.set_connection_mode(config)); + let (command_tx, command_rx) = mpsc::channel(1); let client = Client::builder().build(connector); @@ -125,6 +147,9 @@ impl RequestService { handle, in_flight_requests: BTreeMap::new(), next_id: 0, + proxy_config_provider, + new_address_callback, + address_cache, api_availability, } } @@ -137,7 +162,7 @@ impl RequestService { } } - fn process_command(&mut self, command: RequestCommand) { + async fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { let id = self.id(); @@ -166,10 +191,7 @@ impl RequestService { if let Err(err) = &response { if err.is_network_error() && !api_availability.get_state().is_offline() { log::error!("{}", err.display_chain_with_msg("HTTP request failed")); - - // TODO: ask provider for new proxy config - // TODO: notify connector handle of this new config - // TODO: pass proxy config to tunnel state machine + let _ = tx.send(RequestCommand::NextApiConfig).await; } } @@ -191,6 +213,18 @@ impl RequestService { self.reset(); let _ = tx.send(()); } + RequestCommand::NextApiConfig => { + if let Some(new_config) = self.proxy_config_provider.next().await { + let endpoint = match new_config.get_endpoint() { + Some(endpoint) => endpoint, + None => self.address_cache.get_address().await, + }; + // Switch to new connection mode unless rejected by address change callback + if (self.new_address_callback)(endpoint).await { + self.connector_handle.set_connection_mode(new_config); + } + } + } } } @@ -211,7 +245,7 @@ impl RequestService { pub async fn into_future(mut self) { while let Some(command) = self.command_rx.next().await { - self.process_command(command); + self.process_command(command).await; } self.reset(); } @@ -259,6 +293,7 @@ pub(crate) enum RequestCommand { ), RequestFinished(u64), Reset(oneshot::Sender<()>), + NextApiConfig, } /// A REST request that is sent to the RequestService to be executed. @@ -358,20 +393,14 @@ pub struct ErrorResponse { #[derive(Clone)] pub struct RequestFactory { hostname: String, - address_provider: Box<dyn AddressProvider>, path_prefix: Option<String>, pub timeout: Duration, } impl RequestFactory { - pub fn new( - hostname: String, - address_provider: Box<dyn AddressProvider>, - path_prefix: Option<String>, - ) -> Self { + pub fn new(hostname: String, path_prefix: Option<String>) -> Self { Self { hostname, - address_provider, path_prefix, timeout: DEFAULT_TIMEOUT, } @@ -432,9 +461,8 @@ impl RequestFactory { } fn get_uri(&self, path: &str) -> Result<Uri> { - let host = self.address_provider.get_address(); let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or(""); - let uri = format!("https://{}/{}{}", host, prefix, path); + let uri = format!("https://{}/{}{}", self.hostname, prefix, path); hyper::Uri::from_str(&uri).map_err(Error::UriError) } @@ -444,29 +472,6 @@ impl RequestFactory { } } -pub trait AddressProvider: Send + Sync { - /// Must return a string that represents either a host or a host with port - fn get_address(&self) -> String; - fn clone_box(&self) -> Box<dyn AddressProvider>; -} - -impl Clone for Box<dyn AddressProvider> { - fn clone(&self) -> Self { - self.clone_box() - } -} - -impl AddressProvider for IpAddr { - /// Must return a string that represents either a host or a host with port - fn get_address(&self) -> String { - self.to_string() - } - - fn clone_box(&self) -> Box<dyn AddressProvider> { - Box::new(*self) - } -} - pub fn get_request<T: serde::de::DeserializeOwned>( factory: &RequestFactory, service: RequestServiceHandle, |
