diff options
| author | Emīls <emils@mullvad.net> | 2021-01-25 15:29:56 +0000 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2021-01-25 15:29:56 +0000 |
| commit | 2b055cc39c4eee4b8af8bb193143ddd4055b39ac (patch) | |
| tree | df2c586a211247a022cfb4618a4abbc8423027a0 | |
| parent | 631b690b32a183b03f2fc7ec2fb2ecf0d7e40aae (diff) | |
| parent | 493aeeb995c0027afd92e2e5d26de64d0e0d88d9 (diff) | |
| download | mullvadvpn-2b055cc39c4eee4b8af8bb193143ddd4055b39ac.tar.xz mullvadvpn-2b055cc39c4eee4b8af8bb193143ddd4055b39ac.zip | |
Merge branch 'android-socket-passthrough'
| -rw-r--r-- | CHANGELOG.md | 1 | ||||
| -rw-r--r-- | Cargo.lock | 2 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 91 | ||||
| -rw-r--r-- | mullvad-problem-report/src/lib.rs | 2 | ||||
| -rw-r--r-- | mullvad-rpc/Cargo.toml | 4 | ||||
| -rw-r--r-- | mullvad-rpc/src/https_client_with_sni.rs | 183 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 25 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 61 | ||||
| -rw-r--r-- | mullvad-rpc/src/tcp_stream.rs | 124 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/tun_provider/android/mod.rs | 38 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnecting_state.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/error_state.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 13 |
16 files changed, 500 insertions, 80 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 78dacd3c01..87b49e3166 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th #### Android - Allow to configure the tunnel to use custom DNS servers. - Show only applications that has INTERNET permission on split tunnel screen. +- Allow reaching the API server when connecting, disconnecting or in a blocked state. #### Linux - Improved compatiblitiy with newer versions of systemd-resolved. diff --git a/Cargo.lock b/Cargo.lock index a0cf1de768..5cdb346ab5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1354,6 +1354,7 @@ dependencies = [ name = "mullvad-rpc" version = "0.1.0" dependencies = [ + "bytes", "chrono", "err-derive", "filetime", @@ -1368,6 +1369,7 @@ dependencies = [ "regex", "serde", "serde_json", + "socket2", "talpid-types", "tempfile", "tokio", diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index abf5d9962a..42f067a4f3 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -41,6 +41,8 @@ use mullvad_types::{ wireguard::KeygenEvent, }; use settings::SettingsPersister; +#[cfg(target_os = "android")] +use std::os::unix::io::RawFd; #[cfg(not(target_os = "android"))] use std::path::Path; use std::{ @@ -62,7 +64,7 @@ use talpid_core::{ #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; use talpid_types::{ - net::{openvpn, Endpoint, TransportProtocol, TunnelParameters, TunnelType}, + net::{openvpn, Endpoint, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType}, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, }; @@ -231,6 +233,8 @@ pub enum DaemonCommand { /// Saves the target tunnel state and enters a blocking state. The state is restored /// upon restart. PrepareRestart, + #[cfg(target_os = "android")] + BypassSocket(RawFd, oneshot::Sender<()>), } /// All events that can happen in the daemon. Sent from various threads and exposed interfaces. @@ -513,6 +517,8 @@ where result_rx.await.map_err(|_| ()) }) }, + #[cfg(target_os = "android")] + Self::create_bypass_tx(&internal_event_tx), ) .await .map_err(Error::InitRpcFactory)?; @@ -522,6 +528,7 @@ where let on_relay_list_update = move |relay_list: &RelayList| { relay_list_listener.notify_relay_list(relay_list.clone()); }; + let mut relay_selector = relays::RelaySelector::new( rpc_handle.clone(), on_relay_list_update, @@ -794,17 +801,8 @@ where &mut self, tunnel_state_transition: TunnelStateTransition, ) { - match &tunnel_state_transition { - TunnelStateTransition::Disconnected - | TunnelStateTransition::Connected(_) - | TunnelStateTransition::Error(_) => { - // Reset the RPCs so that they fail immediately after the underlying socket gets - // invalidated due to the tunnel either coming up or breaking. - self.rpc_handle.service().reset().await; - } - _ => (), - }; - + self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition) + .await; let tunnel_state = match tunnel_state_transition { TunnelStateTransition::Disconnected => TunnelState::Disconnected, TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting { @@ -851,6 +849,19 @@ where self.event_listener.notify_new_state(tunnel_state); } + async fn reset_rpc_sockets_on_tunnel_state_transition( + &mut self, + tunnel_state_transition: &TunnelStateTransition, + ) { + match (&self.tunnel_state, &tunnel_state_transition) { + // only reset the API sockets if when connected or leaving the connected state + (&TunnelState::Connected { .. }, _) | (_, &TunnelStateTransition::Connected(_)) => { + self.rpc_handle.service().reset().await; + } + _ => (), + }; + } + async fn handle_generate_tunnel_parameters( &mut self, tunnel_parameters_tx: &sync_mpsc::Sender< @@ -1112,6 +1123,8 @@ where ClearSplitTunnelProcesses(tx) => self.on_clear_split_tunnel_processes(tx), Shutdown => self.trigger_shutdown_event(), PrepareRestart => self.on_prepare_restart(), + #[cfg(target_os = "android")] + BypassSocket(fd, tx) => self.on_bypass_socket(fd, tx), } } @@ -1288,10 +1301,9 @@ where } fn get_geo_location(&mut self) -> impl Future<Output = Result<GeoIpLocation, ()>> { - let https_handle = self.rpc_runtime.rest_handle(); - + let rpc_service = self.rpc_runtime.rest_handle(); async { - geoip::send_location_request(https_handle) + geoip::send_location_request(rpc_service) .await .map_err(|e| { warn!("Unable to fetch GeoIP location: {}", e.display_chain()); @@ -1865,7 +1877,7 @@ where .map_err(|e| { format!("Failed to add new wireguard key to account data: {}", e) })?; - if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { + if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { self.reconnect_tunnel(); } let keygen_event = KeygenEvent::NewKey(public_key); @@ -1974,6 +1986,34 @@ where self.clean_up_target_cache = false; } + #[cfg(target_os = "android")] + fn on_bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) { + match self.tunnel_state { + // When connected, the API connection shouldn't be bypassed. + TunnelState::Connected { .. } => (), + _ => { + self.send_tunnel_command(TunnelCommand::BypassSocket(fd, tx)); + } + } + } + + #[cfg(target_os = "android")] + fn create_bypass_tx( + event_sender: &DaemonEventSender, + ) -> Option<mpsc::Sender<mullvad_rpc::SocketBypassRequest>> { + let (bypass_tx, mut bypass_rx) = mpsc::channel(1); + let daemon_tx = event_sender.to_specialized_sender(); + tokio::runtime::Handle::current().spawn(async move { + while let Some((raw_fd, done_tx)) = bypass_rx.next().await { + if let Err(_) = daemon_tx.send(DaemonCommand::BypassSocket(raw_fd, done_tx)) { + log::error!("Can't send socket bypass request to daemon"); + break; + } + } + }); + Some(bypass_tx) + } + /// Set the target state of the client. If it changed trigger the operations needed to /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. @@ -2026,10 +2066,7 @@ where } fn get_connected_tunnel_type(&self) -> Option<TunnelType> { - use talpid_types::net::TunnelEndpoint; - use TunnelState::Connected; - - if let Connected { + if let TunnelState::Connected { endpoint: TunnelEndpoint { tunnel_type, .. }, .. } = self.tunnel_state @@ -2040,6 +2077,20 @@ where } } + fn get_target_tunnel_type(&self) -> Option<TunnelType> { + match self.tunnel_state { + TunnelState::Connected { + endpoint: TunnelEndpoint { tunnel_type, .. }, + .. + } + | TunnelState::Connecting { + endpoint: TunnelEndpoint { tunnel_type, .. }, + .. + } => Some(tunnel_type), + _ => None, + } + } + fn send_tunnel_command(&mut self, command: TunnelCommand) { self.tunnel_command_tx .unbounded_send(command) diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 9bd1fba185..6b820f58f3 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -282,6 +282,8 @@ pub fn send_problem_report( cache_dir, false, |_| Ok(()), + #[cfg(target_os = "android")] + None, )) .map_err(Error::CreateRpcClientError)?; let rpc_client = mullvad_rpc::ProblemReportProxy::new(rpc_manager.mullvad_rest_handle()); diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml index cf27d2e287..680a880608 100644 --- a/mullvad-rpc/Cargo.toml +++ b/mullvad-rpc/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" publish = false [dependencies] +bytes = "0.5" chrono = { version = "0.4", features = ["serde"] } err-derive = "0.2.1" futures = "0.3" @@ -32,6 +33,9 @@ talpid-types = { path = "../talpid-types" } filetime = "0.2" tempfile = "3.0" +[target.'cfg(target_os="android")'.dependencies] +socket2 = "0.3" + [[bin]] name = "relay_list" diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs index 48b90e76e7..b80b2db95b 100644 --- a/mullvad-rpc/src/https_client_with_sni.rs +++ b/mullvad-rpc/src/https_client_with_sni.rs @@ -1,15 +1,30 @@ +use crate::{rest::RequestCommand, tcp_stream::TcpStream}; +use futures::{ + channel::{mpsc, oneshot}, + sink::SinkExt, +}; use http::uri::Scheme; -use hyper::{client::HttpConnector, service::Service, Uri}; +use hyper::{ + client::connect::dns::{GaiResolver, Name}, + service::Service, + Uri, +}; use hyper_rustls::MaybeHttpsStream; +#[cfg(target_os = "android")] +use std::os::unix::io::{AsRawFd, RawFd}; use std::{ fmt, future::Future, io::{self, BufReader}, + net::{IpAddr, SocketAddr}, pin::Pin, - str, + str::{self, FromStr}, sync::Arc, task::{Context, Poll}, + time::Duration, }; + +use tokio::{net::TcpStream as TokioTcpStream, runtime::Handle, time::timeout}; use tokio_rustls::rustls; use webpki::DNSNameRef; @@ -18,14 +33,23 @@ const OLD_ROOT_CERT: &[u8] = include_bytes!("../old_le_root_cert.pem"); // New LetsEncrypt root certificate const NEW_ROOT_CERT: &[u8] = include_bytes!("../new_le_root_cert.pem"); +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); + /// A Connector for the `https` scheme. #[derive(Clone)] pub struct HttpsConnectorWithSni { + next_socket_id: usize, + handle: Handle, sni_hostname: Option<String>, - http: HttpConnector, + service_tx: Option<mpsc::Sender<RequestCommand>>, + #[cfg(target_os = "android")] + socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, tls: Arc<rustls::ClientConfig>, } +#[cfg(target_os = "android")] +pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>); + impl HttpsConnectorWithSni { /// Construct a new HttpsConnectorWithSni. /// @@ -33,15 +57,24 @@ impl HttpsConnectorWithSni { /// /// This uses hyper's default `HttpConnector`, and default `TlsConnector`. /// If you wish to use something besides the defaults, use `From::from`. - pub fn new() -> Self { - let mut http = HttpConnector::new(); - http.enforce_http(false); - + pub fn new( + handle: Handle, + sni_hostname: Option<String>, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> Self { let mut config = rustls::ClientConfig::new(); config.enable_sni = true; config.root_store = Self::read_cert_store(); - HttpsConnectorWithSni::from((http, config)) + HttpsConnectorWithSni { + next_socket_id: 0, + handle, + sni_hostname, + #[cfg(target_os = "android")] + socket_bypass_tx, + service_tx: None, + tls: Arc::new(config), + } } fn read_cert_store() -> rustls::RootCertStore { @@ -65,22 +98,75 @@ impl HttpsConnectorWithSni { } - /// Configure a hostname to use with SNI. - /// - /// Configures the TLS connection handshake to request a certificate for a given domain, - /// instead of the domain obtained from the URI. Use `None` to use the domain from the URI. - pub fn set_sni_hostname(&mut self, hostname: Option<String>) { - self.sni_hostname = hostname; + /// Set a channel to register sockets with the request service. + pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) { + self.service_tx = Some(service_tx); } -} -impl From<(HttpConnector, rustls::ClientConfig)> for HttpsConnectorWithSni { - fn from(args: (HttpConnector, rustls::ClientConfig)) -> HttpsConnectorWithSni { - HttpsConnectorWithSni { - sni_hostname: None, - http: args.0, - tls: Arc::new(args.1), + fn next_id(&mut self) -> usize { + let next_id = self.next_socket_id; + self.next_socket_id = self.next_socket_id.wrapping_add(1); + next_id + } + + #[cfg(not(target_os = "android"))] + async fn open_socket(addr: SocketAddr) -> std::io::Result<TokioTcpStream> { + timeout(CONNECT_TIMEOUT, TokioTcpStream::connect(addr)) + .await + .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? + } + + #[cfg(target_os = "android")] + async fn open_socket( + addr: SocketAddr, + socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> std::io::Result<TokioTcpStream> { + use socket2::{Domain, Protocol, Socket, Type}; + let domain = match addr { + SocketAddr::V4(_) => Domain::ipv4(), + SocketAddr::V6(_) => Domain::ipv6(), + }; + let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?.into_tcp_stream(); + + if let Some(mut tx) = socket_bypass_tx { + let (done_tx, done_rx) = oneshot::channel(); + let _ = tx.send((socket.as_raw_fd(), done_tx)).await; + if let Err(_) = done_rx.await { + log::error!("Failed to bypass socket, connection might fail"); + } } + + timeout(CONNECT_TIMEOUT, TokioTcpStream::connect_std(socket, &addr)) + .await + .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? + } + + async fn resolve_address(hostname: &str) -> io::Result<SocketAddr> { + match Self::parse_addr(&hostname) { + Some(addr) => Ok(addr), + None => { + let mut addrs = GaiResolver::new() + .call( + Name::from_str(&hostname) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?, + ) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let addr = addrs + .next() + .ok_or(io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; + Ok(SocketAddr::new(addr, 443)) + } + } + } + + + fn parse_addr(hostname: &str) -> Option<SocketAddr> { + if let Ok(addr) = hostname.parse::<SocketAddr>() { + return Some(addr); + } + let ip = hostname.parse::<IpAddr>().ok()?; + Some(SocketAddr::new(ip, 443)) } } @@ -90,8 +176,9 @@ impl fmt::Debug for HttpsConnectorWithSni { } } + impl Service<Uri> for HttpsConnectorWithSni { - type Response = MaybeHttpsStream<tokio::net::TcpStream>; + type Response = MaybeHttpsStream<TcpStream>; type Error = io::Error; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -102,7 +189,6 @@ impl Service<Uri> for HttpsConnectorWithSni { fn call(&mut self, uri: Uri) -> Self::Future { let tls_connector: tokio_rustls::TlsConnector = self.tls.clone().into(); - let mut http = self.http.clone(); let sni_hostname = self .sni_hostname .clone() @@ -110,7 +196,12 @@ impl Service<Uri> for HttpsConnectorWithSni { .ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") }); + let service_tx = self.service_tx.clone(); + let socket_id = self.next_id(); + let handle = self.handle.clone(); + #[cfg(target_os = "android")] + let socket_bypass_tx = self.socket_bypass_tx.clone(); let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { @@ -119,14 +210,49 @@ impl Service<Uri> for HttpsConnectorWithSni { "invalid url, not https", )); } + let host_addr = uri.host().ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid url, missing host", + ))?; let hostname = sni_hostname?; let host = DNSNameRef::try_from_ascii_str(&hostname) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid hostname"))?; - let connection = http - .call(uri) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - let tls_connection = tls_connector.connect(host, connection).await?; + let addr = Self::resolve_address(host_addr).await?; + + let tokio_connection = Self::open_socket( + addr, + #[cfg(target_os = "android")] + socket_bypass_tx, + ) + .await?; + + let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel(); + + + let (tcp_stream, socket_handle) = + TcpStream::new(tokio_connection, socket_id, Some(socket_shutdown_tx)); + if let Some(mut service_tx) = service_tx { + if service_tx + .send(RequestCommand::SocketOpened(socket_id, socket_handle)) + .await + .is_err() + { + log::error!("Failed to submit new socket to request service"); + } + handle.spawn(async move { + let _ = socket_shutdown_rx.await; + if service_tx + .send(RequestCommand::SocketClosed(socket_id)) + .await + .is_err() + { + log::error!("Failed to send socket closure command to request service"); + } + }); + } + + + let tls_connection = tls_connector.connect(host, tcp_stream).await?; Ok(MaybeHttpsStream::Https(tls_connection)) }; @@ -136,6 +262,7 @@ impl Service<Uri> for HttpsConnectorWithSni { } } + #[cfg(test)] mod test { use super::HttpsConnectorWithSni; diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index e211a89c2e..7fd3a2480f 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -1,6 +1,8 @@ #![deny(rust_2018_idioms)] use chrono::{offset::Utc, DateTime}; +#[cfg(target_os = "android")] +use futures::channel::mpsc; use hyper::Method; use mullvad_types::{ account::{AccountToken, VoucherSubmission}, @@ -20,6 +22,9 @@ pub mod rest; mod https_client_with_sni; use crate::https_client_with_sni::HttpsConnectorWithSni; +#[cfg(target_os = "android")] +pub use crate::https_client_with_sni::SocketBypassRequest; +mod tcp_stream; mod address_cache; mod relay_list; @@ -41,9 +46,10 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443); /// A type that helps with the creation of RPC connections. pub struct MullvadRpcRuntime { - https_connector: HttpsConnectorWithSni, handle: tokio::runtime::Handle, pub address_cache: AddressCache, + #[cfg(target_os = "android")] + socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } #[derive(err_derive::Error, Debug)] @@ -59,13 +65,14 @@ impl MullvadRpcRuntime { /// Create a new `MullvadRpcRuntime`. pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { Ok(MullvadRpcRuntime { - https_connector: HttpsConnectorWithSni::new(), handle, address_cache: AddressCache::new( vec![API_ADDRESS.into()], None, Arc::new(Box::new(|_| Ok(()))), )?, + #[cfg(target_os = "android")] + socket_bypass_tx: None, }) } @@ -78,6 +85,7 @@ impl MullvadRpcRuntime { cache_dir: &Path, write_changes: bool, address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); let write_file = if write_changes { @@ -125,19 +133,22 @@ impl MullvadRpcRuntime { } }; - let https_connector = HttpsConnectorWithSni::new(); - Ok(MullvadRpcRuntime { - https_connector, handle, address_cache, + #[cfg(target_os = "android")] + socket_bypass_tx, }) } /// Creates a new request service and returns a handle to it. fn new_request_service(&mut self, sni_hostname: Option<String>) -> rest::RequestServiceHandle { - let mut https_connector = self.https_connector.clone(); - https_connector.set_sni_hostname(sni_hostname); + let https_connector = HttpsConnectorWithSni::new( + self.handle.clone(), + sni_hostname, + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), + ); let service = rest::RequestService::new( https_connector, diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 82a6271c74..07746a3150 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,4 +1,7 @@ -use crate::address_cache::AddressCache; +use crate::{ + address_cache::AddressCache, https_client_with_sni::HttpsConnectorWithSni, + tcp_stream::TcpStreamHandle, +}; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Aborted}, @@ -7,7 +10,7 @@ use futures::{ TryFutureExt, }; use hyper::{ - client::{connect::Connect, Client}, + client::Client, header::{self, HeaderValue}, Method, Uri, }; @@ -73,30 +76,37 @@ pub enum Error { /// A service that executes HTTP requests, allowing for on-demand termination of all in-flight /// requests -pub(crate) struct RequestService<C> { +pub(crate) struct RequestService { command_tx: mpsc::Sender<RequestCommand>, command_rx: mpsc::Receiver<RequestCommand>, - client: hyper::Client<C, hyper::Body>, - connector: C, + sockets: BTreeMap<usize, TcpStreamHandle>, + client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, handle: Handle, next_id: u64, in_flight_requests: BTreeMap<u64, AbortHandle>, address_cache: AddressCache, } -impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { +impl RequestService { /// Constructs a new request service. - pub fn new(connector: C, handle: Handle, address_cache: AddressCache) -> RequestService<C> { - let client = Self::new_client(connector.clone()); - + pub fn new( + mut connector: HttpsConnectorWithSni, + handle: Handle, + address_cache: AddressCache, + ) -> RequestService { let (command_tx, command_rx) = mpsc::channel(1); + + connector.set_service_tx(command_tx.clone()); + let client = Client::builder().build(connector); + + Self { command_tx, command_rx, + sockets: BTreeMap::new(), client, in_flight_requests: BTreeMap::new(), next_id: 0, - connector, handle, address_cache, } @@ -110,10 +120,6 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { } } - fn new_client(connector: C) -> Client<C, hyper::Body> { - Client::builder().pool_max_idle_per_host(0).build(connector) - } - fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { @@ -173,12 +179,19 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { self.in_flight_requests.insert(id, abort_handle); } + RequestCommand::SocketOpened(id, socket) => { + self.sockets.insert(id, socket); + } + RequestCommand::SocketClosed(id) => { + self.sockets.remove(&id); + } RequestCommand::RequestFinished(id) => { self.in_flight_requests.remove(&id); } - RequestCommand::Reset => { + RequestCommand::Reset(tx) => { self.reset(); + let _ = tx.send(()); } } } @@ -188,7 +201,12 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { for (_, abort_handle) in old_requests.into_iter() { abort_handle.abort(); } - let _ = mem::replace(&mut self.client, Self::new_client(self.connector.clone())); + + let old_sockets = mem::replace(&mut self.sockets, BTreeMap::new()); + for (_, socket) in old_sockets.into_iter() { + socket.close(); + } + self.next_id = 0; } @@ -202,6 +220,7 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { while let Some(command) = self.command_rx.next().await { self.process_command(command); } + self.reset(); } } @@ -229,8 +248,10 @@ impl RequestServiceHandle { /// Resets the corresponding RequestService, dropping all in-flight requests. pub async fn reset(&self) { let mut tx = self.tx.clone(); + let (done_tx, done_rx) = oneshot::channel(); - let _ = tx.send(RequestCommand::Reset).await; + let _ = tx.send(RequestCommand::Reset(done_tx)).await; + let _ = done_rx.await; } /// Submits a `RestRequest` for exectuion to the request service. @@ -252,13 +273,15 @@ impl RequestServiceHandle { } #[derive(Debug)] -enum RequestCommand { +pub(crate) enum RequestCommand { NewRequest( RestRequest, oneshot::Sender<std::result::Result<Response, Error>>, ), RequestFinished(u64), - Reset, + SocketOpened(usize, TcpStreamHandle), + SocketClosed(usize), + Reset(oneshot::Sender<()>), } diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs new file mode 100644 index 0000000000..7bf89a47b2 --- /dev/null +++ b/mullvad-rpc/src/tcp_stream.rs @@ -0,0 +1,124 @@ +use bytes::buf::Buf; +use futures::channel::oneshot; +use hyper::client::connect::{Connected, Connection}; +use std::{ + io, + net::Shutdown, + pin::Pin, + sync::{Arc, Mutex, Weak}, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream as TokioTcpStream, +}; + +#[derive(Debug)] +pub struct TcpStreamHandle { + inner: Weak<Mutex<StreamInner>>, +} + +impl TcpStreamHandle { + pub fn close(self) { + if let Some(inner_lock) = self.inner.upgrade() { + if let Ok(mut inner) = inner_lock.lock() { + if let Err(err) = inner.stream.shutdown(Shutdown::Both) { + log::error!("Failed to shut down TCP socket: {}", err); + } + let _ = inner.shutdown_tx.take(); + } + } + } +} + + +pub struct TcpStream { + inner: Arc<Mutex<StreamInner>>, +} + +impl TcpStream { + pub fn new( + stream: TokioTcpStream, + id: usize, + shutdown_tx: Option<oneshot::Sender<()>>, + ) -> (Self, TcpStreamHandle) { + let inner = Arc::new(Mutex::new(StreamInner { + id, + stream, + shutdown_tx, + })); + ( + Self { + inner: inner.clone(), + }, + TcpStreamHandle { + inner: Arc::downgrade(&inner), + }, + ) + } + + fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T { + let mut inner = self.inner.lock().expect("TCP lock poisoned"); + stream_fn(&mut inner.stream) + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + if let Ok(mut inner) = self.inner.lock() { + if let Some(tx) = inner.shutdown_tx.take() { + let _ = tx.send(()); + } + } + } +} + + +impl AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf)) + } + + fn poll_write_buf<B: Buf>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_write_buf(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx)) + } +} + +impl AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<Result<usize, io::Error>> { + self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf)) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[derive(Debug)] +struct StreamInner { + id: usize, + stream: TokioTcpStream, + shutdown_tx: Option<oneshot::Sender<()>>, +} diff --git a/talpid-core/src/tunnel/tun_provider/android/mod.rs b/talpid-core/src/tunnel/tun_provider/android/mod.rs index fa48f115b9..575f05f4c2 100644 --- a/talpid-core/src/tunnel/tun_provider/android/mod.rs +++ b/talpid-core/src/tunnel/tun_provider/android/mod.rs @@ -136,8 +136,8 @@ impl AndroidTunProvider { }) } - /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and (potentially) - /// LAN routes via the tunnel device. + /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and + /// (potentially) LAN routes via the tunnel device. /// /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be /// closed. @@ -321,6 +321,33 @@ impl AndroidTunProvider { } } + /// Allow a socket to bypass the tunnel. + pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { + let env = JnixEnv::from( + self.jvm + .attach_current_thread_as_daemon() + .map_err(|cause| Error::AttachJvmToThread(cause))?, + ); + let create_tun_method = env + .get_method_id(&self.class, "bypass", "(I)Z") + .map_err(|cause| Error::FindMethod("bypass", cause))?; + + let result = env + .call_method_unchecked( + self.object.as_obj(), + create_tun_method, + JavaType::Primitive(Primitive::Boolean), + &[JValue::Int(socket)], + ) + .map_err(|cause| Error::CallMethod("bypass", cause))?; + + match result { + JValue::Bool(0) => Err(Error::Bypass), + JValue::Bool(_) => Ok(()), + value => Err(Error::InvalidMethodResult("bypass", format!("{:?}", value))), + } + } + fn call_method( &self, name: &'static str, @@ -385,11 +412,10 @@ impl VpnServiceTun { ) .map_err(|cause| Error::CallMethod("bypass", cause))?; - match result { - JValue::Bool(0) => Err(Error::Bypass), - JValue::Bool(_) => Ok(()), - value => Err(Error::InvalidMethodResult("bypass", format!("{:?}", value))), + if !bool::from_java(&env, result) { + return Err(Error::Bypass); } + Ok(()) } } diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 0c305de9a7..2f1d53fd9a 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -258,6 +258,11 @@ impl ConnectedState { Some(TunnelCommand::Block(reason)) => { self.disconnect(shared_values, AfterDisconnect::Block(reason)) } + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + SameState(self.into()) + } } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 0b03ceeca1..7ec0aaf55a 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -284,6 +284,11 @@ impl ConnectingState { Some(TunnelCommand::Block(reason)) => { self.disconnect(shared_values, AfterDisconnect::Block(reason)) } + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + SameState(self.into()) + } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 922eb69c88..7a02a9e550 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -110,6 +110,11 @@ impl TunnelState for DisconnectedState { Some(TunnelCommand::Block(reason)) => { NewState(ErrorState::enter(shared_values, reason)) } + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + SameState(self.into()) + } Some(_) => SameState(self.into()), None => Finished, } diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 48a83a6dc3..12e0bebe5e 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -54,6 +54,11 @@ impl DisconnectingState { Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing, Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + AfterDisconnect::Nothing + } }, AfterDisconnect::Block(reason) => match command { Some(TunnelCommand::AllowLan(allow_lan)) => { @@ -86,6 +91,11 @@ impl DisconnectingState { Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), Some(TunnelCommand::Disconnect) => AfterDisconnect::Nothing, Some(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason), + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + AfterDisconnect::Block(reason) + } None => AfterDisconnect::Block(reason), }, AfterDisconnect::Reconnect(retry_attempt) => match command { @@ -119,6 +129,11 @@ impl DisconnectingState { Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt), Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing, Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + AfterDisconnect::Reconnect(retry_attempt) + } }, }; diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 51159d274f..6674cca13e 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -151,6 +151,12 @@ impl TunnelState for ErrorState { Some(TunnelCommand::Block(reason)) => { NewState(ErrorState::enter(shared_values, reason)) } + + #[cfg(target_os = "android")] + Some(TunnelCommand::BypassSocket(fd, done_tx)) => { + shared_values.bypass_socket(fd, done_tx); + SameState(self.into()) + } } } } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 5a86722c01..90012a8296 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -23,6 +23,8 @@ use futures::{ channel::{mpsc, oneshot}, stream, StreamExt, }; +#[cfg(target_os = "android")] +use std::os::unix::io::RawFd; use std::{ collections::HashSet, io, @@ -171,6 +173,9 @@ pub enum TunnelCommand { Disconnect, /// Disconnect any open tunnel and block all network access Block(ErrorStateCause), + /// Bypass a socket, allowing traffic to flow through outside the tunnel. + #[cfg(target_os = "android")] + BypassSocket(RawFd, oneshot::Sender<()>), } type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>; @@ -405,6 +410,14 @@ impl SharedTunnelStateValues { log::trace!("Connectivity check wasn't disabled by the daemon"); } } + + #[cfg(target_os = "android")] + pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) { + if let Err(err) = self.tun_provider.bypass(fd) { + log::error!("Failed to bypass socket {}", err); + } + let _ = tx.send(()); + } } /// Asynchronous result of an attempt to progress a state. |
