diff options
Diffstat (limited to 'talpid-core/src')
| -rw-r--r-- | talpid-core/src/dns/linux/resolvconf.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/static_resolv_conf.rs | 8 | ||||
| -rw-r--r-- | talpid-core/src/mpsc.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/ping_monitor/icmp.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/routing/linux.rs | 57 | ||||
| -rw-r--r-- | talpid-core/src/routing/unix.rs | 3 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 46 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 143 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/tun_provider/unix.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/logging.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 28 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/stats.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs | 56 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs | 16 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 16 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 103 |
16 files changed, 280 insertions, 255 deletions
diff --git a/talpid-core/src/dns/linux/resolvconf.rs b/talpid-core/src/dns/linux/resolvconf.rs index ea2f3f3704..97db14b622 100644 --- a/talpid-core/src/dns/linux/resolvconf.rs +++ b/talpid-core/src/dns/linux/resolvconf.rs @@ -22,16 +22,16 @@ pub enum Error { RunResolvconf(#[error(source)] io::Error), #[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)] - AddRecordError { stderr: String }, + AddRecord { stderr: String }, #[error(display = "Using 'resolvconf' to delete a record failed")] - DeleteRecordError, + DeleteRecord, #[error(display = "Detected dnsmasq is runing and misconfigured")] - DnsmasqMisconfigurationError, + DnsmasqMisconfiguration, #[error(display = "Current /etc/resolv.conf is not generated by resolvconf")] - ResolvconfNotInUseError, + ResolvconfNotInUse, } pub struct Resolvconf { @@ -50,15 +50,15 @@ impl Resolvconf { // Check if resolvconf is managing DNS by /etc/resolv.conf if !is_dnsmasq_running - && !(Self::check_if_resolvconf_is_symlinked_correctly() - || Self::check_if_resolvconf_was_generated()) + && !Self::check_if_resolvconf_is_symlinked_correctly() + && !Self::check_if_resolvconf_was_generated() { - return Err(Error::ResolvconfNotInUseError); + return Err(Error::ResolvconfNotInUse); } // Check if resolvconf can manage DNS via dnsmasq if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() { - return Err(Error::DnsmasqMisconfigurationError); + return Err(Error::DnsmasqMisconfiguration); } Ok(Resolvconf { @@ -94,7 +94,7 @@ impl Resolvconf { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr).to_string(); - return Err(Error::AddRecordError { stderr }); + return Err(Error::AddRecord { stderr }); } self.record_names.insert(record_name); @@ -118,7 +118,7 @@ impl Resolvconf { record_name, String::from_utf8_lossy(&output.stderr) ); - result = Err(Error::DeleteRecordError); + result = Err(Error::DeleteRecord); } } diff --git a/talpid-core/src/dns/linux/static_resolv_conf.rs b/talpid-core/src/dns/linux/static_resolv_conf.rs index 196fb31003..691d7b468b 100644 --- a/talpid-core/src/dns/linux/static_resolv_conf.rs +++ b/talpid-core/src/dns/linux/static_resolv_conf.rs @@ -28,7 +28,7 @@ pub enum Error { ReadResolvConf(&'static str, #[error(source)] io::Error), #[error(display = "resolv.conf at {} could not be parsed", _0)] - ParseError(&'static str, #[error(source)] resolv_conf::ParseError), + Parse(&'static str, #[error(source)] resolv_conf::ParseError), #[error(display = "Failed to remove stale resolv.conf backup at {}", _0)] RemoveBackup(&'static str, #[error(source)] io::Error), @@ -179,7 +179,7 @@ fn read_config() -> Result<Config> { let contents = fs::read_to_string(RESOLV_CONF_PATH) .map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?; - let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?; + let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?; Ok(config) } @@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> { match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) { Ok(backup) => { log::info!("Restoring DNS state from backup"); - let config = Config::parse(&backup) - .map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?; + let config = + Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?; write_config(&config)?; diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs index 8c6424bc01..6492796cfc 100644 --- a/talpid-core/src/mpsc.rs +++ b/talpid-core/src/mpsc.rs @@ -1,11 +1,20 @@ +/// Error type for `Sender` trait. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The underlying channel is closed. + #[error(display = "Channel is closed")] + ChannelClosed, +} + /// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`. pub trait Sender<T> { /// Sends an item over the underlying channel, failing only if the channel is closed. - fn send(&self, item: T) -> Result<(), ()>; + fn send(&self, item: T) -> Result<(), Error>; } impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> { - fn send(&self, content: E) -> Result<(), ()> { - self.unbounded_send(content).map_err(|_| ()) + fn send(&self, content: E) -> Result<(), Error> { + self.unbounded_send(content) + .map_err(|_| Error::ChannelClosed) } } diff --git a/talpid-core/src/ping_monitor/icmp.rs b/talpid-core/src/ping_monitor/icmp.rs index 67f5b70cb5..0bcd9da72f 100644 --- a/talpid-core/src/ping_monitor/icmp.rs +++ b/talpid-core/src/ping_monitor/icmp.rs @@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner( let checksum = internet_checksum::checksum(buffer); (&mut buffer[ICMP_CHECKSUM_OFFSET..]) - .write(&checksum) + .write_all(&checksum) .unwrap(); true diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 092ad6f52a..4b039fe9eb 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>; #[error(no_from)] pub enum Error { #[error(display = "Failed to open a netlink connection")] - ConnectError(#[error(source)] io::Error), + Connect(#[error(source)] io::Error), #[error(display = "Failed to bind netlink socket")] - BindError(#[error(source)] io::Error), + Bind(#[error(source)] io::Error), #[error(display = "Netlink error")] - NetlinkError(#[error(source)] rtnetlink::Error), + Netlink(#[error(source)] rtnetlink::Error), #[error(display = "Route without a valid node")] InvalidRoute, @@ -108,16 +108,16 @@ pub enum Error { UnknownDeviceIndex(u32), #[error(display = "Failed to get a route for the given IP address")] - GetRouteError(#[error(source)] rtnetlink::Error), + GetRoute(#[error(source)] rtnetlink::Error), #[error(display = "No netlink response for route query")] - NoRouteError, + NoRoute, #[error(display = "Route node was malformed")] InvalidRouteNode, #[error(display = "No link found")] - LinkNotFoundError, + LinkNotFound, /// Unable to create routing table for tagged connections and packets. #[error(display = "Cannot find a free routing table ID")] @@ -140,14 +140,11 @@ pub struct RouteManagerImpl { impl RouteManagerImpl { pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { let (mut connection, handle, messages) = - rtnetlink::new_connection().map_err(Error::ConnectError)?; + rtnetlink::new_connection().map_err(Error::Connect)?; let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY; let addr = SocketAddr::new(0, mgroup_flags); - connection - .socket_mut() - .bind(&addr) - .map_err(Error::BindError)?; + connection.socket_mut().bind(&addr).map_err(Error::Bind)?; tokio::spawn(connection); @@ -179,11 +176,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -236,7 +233,7 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; let mut rules = vec![]; @@ -246,7 +243,7 @@ impl RouteManagerImpl { rules.push(rule); } NetlinkPayload::Error(error) => { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } _ => (), } @@ -260,12 +257,12 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { if error.to_io().kind() != io::ErrorKind::NotFound { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -296,7 +293,7 @@ impl RouteManagerImpl { ) -> Result<BTreeMap<u32, NetworkInterface>> { let mut link_map = BTreeMap::new(); let mut link_request = handle.link().get().execute(); - while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? { + while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? { if let Some((idx, device)) = Self::map_interface(link) { link_map.insert(idx, device); } @@ -543,7 +540,7 @@ impl RouteManagerImpl { async fn delete_route_if_exists(&self, route: &Route) -> Result<()> { if let Err(error) = self.delete_route(route).await { - if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error { + if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error { if msg.code == -libc::ESRCH { return Ok(()); } @@ -619,7 +616,7 @@ impl RouteManagerImpl { .del(route_message) .execute() .await - .map_err(Error::NetlinkError) + .map_err(Error::Netlink) } async fn add_route_direct(&mut self, route: Route) -> Result<()> { @@ -693,11 +690,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(err) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err))); } } Ok(()) @@ -759,7 +756,7 @@ impl RouteManagerImpl { } None => { log::error!("No route detected when assigning the mtu to the Wireguard tunnel"); - return Err(Error::NoRouteError); + return Err(Error::NoRoute); } } } @@ -767,17 +764,13 @@ impl RouteManagerImpl { "Retried {} times looking for the correct device and could not find it", RECURSION_LIMIT ); - Err(Error::NoRouteError) + Err(Error::NoRoute) } async fn get_device_mtu(&self, device: String) -> Result<u16> { let mut links = self.handle.link().get().execute(); let target_device = LinkNla::IfName(device); - while let Some(msg) = links - .try_next() - .await - .map_err(|_| Error::LinkNotFoundError)? - { + while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? { let found = msg.nlas.iter().any(|e| *e == target_device); if found { if let Some(LinkNla::Mtu(mtu)) = @@ -788,7 +781,7 @@ impl RouteManagerImpl { } } } - Err(Error::LinkNotFoundError) + Err(Error::LinkNotFound) } async fn get_destination_route( @@ -813,11 +806,11 @@ impl RouteManagerImpl { let mut stream = execute_route_get_request(self.handle.clone(), message.clone()); match stream.try_next().await { Ok(Some(route_msg)) => self.parse_route_message(route_msg), - Ok(None) => Err(Error::NoRouteError), + Ok(None) => Err(Error::NoRoute), Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { Ok(None) } - Err(err) => Err(Error::GetRouteError(err)), + Err(err) => Err(Error::GetRoute(err)), } } } diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index edfbdd2b85..326fb1fad1 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -19,16 +19,19 @@ use futures::stream::Stream; #[cfg(target_os = "linux")] use std::net::IpAddr; +#[allow(clippy::module_inception)] #[cfg(target_os = "macos")] #[path = "macos.rs"] mod imp; #[cfg(target_os = "macos")] pub(crate) use imp::listen_for_default_route_changes; +#[allow(clippy::module_inception)] #[cfg(target_os = "linux")] #[path = "linux.rs"] mod imp; +#[allow(clippy::module_inception)] #[cfg(target_os = "android")] #[path = "android.rs"] mod imp; diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 5da3c092a4..f6ada1c2cf 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,6 +1,6 @@ use self::tun_provider::TunProvider; use crate::{logging, routing::RouteManagerHandle}; -use futures::channel::oneshot; +use futures::{channel::oneshot, future::BoxFuture}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, @@ -98,6 +98,20 @@ pub struct TunnelMonitor { monitor: InternalTunnelMonitor, } +/// Arguments for creating a tunnel. +pub struct TunnelArgs<'a, L> +where + // L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, +{ + /// Resource directory. + pub resource_dir: &'a Path, + /// Callback function called when an event happens. + pub on_event: L, + /// Receiver oneshot channel for closing the tunnel. + pub tunnel_close_rx: oneshot::Receiver<()>, +} + // TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` @@ -107,12 +121,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, tunnel_parameters: &mut TunnelParameters, log_dir: &Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -129,9 +141,9 @@ impl TunnelMonitor { TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel( config, log_file, - resource_dir, - on_event, - tunnel_close_rx, + init_args.resource_dir, + init_args.on_event, + init_args.tunnel_close_rx, #[cfg(target_os = "linux")] route_manager, )), @@ -142,12 +154,10 @@ impl TunnelMonitor { runtime, config, log_file, - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, ), } } @@ -178,12 +188,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, params: &mut wireguard_types::TunnelParameters, log: Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -211,12 +219,10 @@ impl TunnelMonitor { None }, log.as_deref(), - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index 910d5bb49e..9fdfb3e80b 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> { let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let openvpn_init_args = OpenVpnTunnelInitArgs { + event_server_abort_tx: event_server_abort_tx.clone(), + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file, + proxy_auth_file, + proxy_monitor, + tunnel_close_rx, + }; Self::new_internal( cmd, - event_server_abort_tx.clone(), - event_server_abort_rx, + openvpn_init_args, event_server::OpenvpnEventProxyImpl { on_event, user_pass_file_path: user_pass_file_path.clone(), @@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { #[cfg(target_os = "linux")] ipv6_enabled, }, - plugin_path, - log_path, - user_pass_file, - proxy_auth_file, - proxy_monitor, - tunnel_close_rx, #[cfg(windows)] Box::new(wintun), ) @@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute Ok(routes) } +struct OpenVpnTunnelInitArgs { + event_server_abort_tx: triggered::Trigger, + event_server_abort_rx: triggered::Listener, + plugin_path: PathBuf, + log_path: Option<PathBuf>, + user_pass_file: mktemp::TempFile, + proxy_auth_file: Option<mktemp::TempFile>, + proxy_monitor: Option<Box<dyn ProxyMonitor>>, + tunnel_close_rx: oneshot::Receiver<()>, +} + impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { async fn new_internal<L>( mut cmd: C, - event_server_abort_tx: triggered::Trigger, - event_server_abort_rx: triggered::Listener, + init_args: OpenVpnTunnelInitArgs, on_event: L, - plugin_path: PathBuf, - log_path: Option<PathBuf>, - user_pass_file: mktemp::TempFile, - proxy_auth_file: Option<mktemp::TempFile>, - proxy_monitor: Option<Box<dyn ProxyMonitor>>, - tunnel_close_rx: oneshot::Receiver<()>, #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where L: event_server::OpenvpnEventProxy + Send + Sync + 'static, { + let event_server_abort_tx = init_args.event_server_abort_tx; + let event_server_abort_rx = init_args.event_server_abort_rx; + let plugin_path = init_args.plugin_path; + let log_path = init_args.log_path; + let user_pass_file = init_args.user_pass_file; + let proxy_auth_file = init_args.proxy_auth_file; + let proxy_monitor = init_args.proxy_monitor; + let tunnel_close_rx = init_args.tunnel_close_rx; + let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx) .await .map_err(Error::EventDispatcherError)?; @@ -1220,23 +1236,37 @@ mod tests { .map_err(Error::RuntimeError) } + fn create_init_args_plugin_log( + plugin_path: PathBuf, + log_path: Option<PathBuf>, + ) -> OpenVpnTunnelInitArgs { + let (_close_tx, close_rx) = oneshot::channel(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + OpenVpnTunnelInitArgs { + event_server_abort_tx, + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file: TempFile::new(), + proxy_auth_file: None, + proxy_monitor: None, + tunnel_close_rx: close_rx, + } + } + + fn create_init_args() -> OpenVpnTunnelInitArgs { + create_init_args_plugin_log("".into(), None) + } + #[test] fn sets_plugin() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "./my_test_plugin".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1249,20 +1279,13 @@ mod tests { #[test] fn sets_log() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = + create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file"))); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - Some(PathBuf::from("./my_test_log_file")), - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1276,21 +1299,13 @@ mod tests { fn exit_successfully() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(0)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1302,21 +1317,13 @@ mod tests { fn exit_error() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1328,21 +1335,13 @@ mod tests { fn wait_closed() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1354,21 +1353,13 @@ mod tests { #[test] fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let result = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) diff --git a/talpid-core/src/tunnel/tun_provider/unix.rs b/talpid-core/src/tunnel/tun_provider/unix.rs index d8d3b7ce01..5c48a3c663 100644 --- a/talpid-core/src/tunnel/tun_provider/unix.rs +++ b/talpid-core/src/tunnel/tun_provider/unix.rs @@ -22,6 +22,12 @@ pub enum Error { /// Factory of tunnel devices on Unix systems. pub struct UnixTunProvider; +impl Default for UnixTunProvider { + fn default() -> Self { + Self::new() + } +} + impl UnixTunProvider { pub fn new() -> Self { UnixTunProvider diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs index 1a7a52ed8b..35ec10fc2f 100644 --- a/talpid-core/src/tunnel/wireguard/logging.rs +++ b/talpid-core/src/tunnel/wireguard/logging.rs @@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback( let level = match level { WG_GO_LOG_VERBOSE => LogLevel::Verbose, - WG_GO_LOG_ERROR | _ => LogLevel::Error, + _ => LogLevel::Error, }; log_inner(logfile, level, "wireguard-go", &managed_msg); } @@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback( pub type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose // const WG_GO_LOG_SILENT: WgLogLevel = 0; -const WG_GO_LOG_ERROR: WgLogLevel = 1; +// const WG_GO_LOG_ERROR: WgLogLevel = 1; const WG_GO_LOG_VERBOSE: WgLogLevel = 2; diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 7f17726c33..e49286cb30 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -1,15 +1,11 @@ use self::config::Config; #[cfg(not(windows))] use super::tun_provider; -use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; +use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use crate::routing::{self, RequiredRoute, RouteManagerHandle}; +use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future}; #[cfg(windows)] use futures::{channel::mpsc, StreamExt}; -use futures::{ - channel::oneshot, - future::{abortable, AbortHandle as FutureAbortHandle}, - Future, -}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -54,6 +50,7 @@ mod wireguard_nt; use self::wireguard_go::WgGoTunnel; type Result<T> = std::result::Result<T, Error>; +type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; /// Errors that can happen in the Wireguard tunnel monitor. #[derive(err_derive::Error, Debug)] @@ -104,12 +101,7 @@ pub struct WireguardMonitor { /// Tunnel implementation tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, /// Callback to signal tunnel events - event_callback: Box< - dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - >, + event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, pinger_stop_sender: sync_mpsc::Sender<()>, _obfuscator: Option<ObfuscatorHandle>, @@ -208,13 +200,13 @@ impl WireguardMonitor { mut config: Config, psk_negotiation: Option<PublicKey>, log_path: Option<&Path>, - resource_dir: &Path, - on_event: F, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { + let on_event = init_args.on_event; + let endpoint_addrs: Vec<IpAddr> = config.peers.iter().map(|peer| peer.endpoint.ip()).collect(); let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel(); @@ -228,7 +220,7 @@ impl WireguardMonitor { runtime.clone(), &Self::patch_allowed_ips(&config, psk_negotiation.is_some()), log_path, - resource_dir, + init_args.resource_dir, tun_provider, #[cfg(target_os = "windows")] setup_done_tx, @@ -351,7 +343,7 @@ impl WireguardMonitor { }); tokio::spawn(async move { - if tunnel_close_rx.await.is_ok() { + if init_args.tunnel_close_rx.await.is_ok() { monitor_handle.abort(); let _ = close_msg_sender.send(CloseMsg::Stop); } diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs index bda8af2e1f..cec033f611 100644 --- a/talpid-core/src/tunnel/wireguard/stats.rs +++ b/talpid-core/src/tunnel/wireguard/stats.rs @@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla}; #[derive(err_derive::Error, Debug, PartialEq)] pub enum Error { #[error(display = "Failed to parse peer pubkey from string \"_0\"")] - PubKeyParseError(String, #[error(source)] hex::FromHexError), + PubKeyParse(String, #[error(source)] hex::FromHexError), #[error(display = "Failed to parse integer from string \"_0\"")] - IntParseError(String, #[error(source)] std::num::ParseIntError), + IntParse(String, #[error(source)] std::num::ParseIntError), #[error(display = "Device no longer exists")] NoTunnelDevice, @@ -47,7 +47,7 @@ impl Stats { "public_key" => { let mut buffer = [0u8; 32]; hex::decode_to_slice(value, &mut buffer) - .map_err(|err| Error::PubKeyParseError(value.to_string(), err))?; + .map_err(|err| Error::PubKeyParse(value.to_string(), err))?; peer = Some(buffer); tx_bytes = None; rx_bytes = None; @@ -57,7 +57,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } "tx_bytes" => { @@ -65,7 +65,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } @@ -145,7 +145,7 @@ mod test { assert_eq!( Stats::parse_config_str(invalid_input), - Err(Error::IntParseError(invalid_str, int_err)) + Err(Error::IntParse(invalid_str, int_err)) ); } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs index 0f3866500e..5b7b6a1e12 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs @@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel; #[error(no_from)] pub enum Error { #[error(display = "Failed to decode netlink message")] - DecodeError(#[error(source)] DecodeError), + Decode(#[error(source)] DecodeError), #[error(display = "Failed to execute netlink control request")] - NetlinkControlMessageError(#[error(source)] nl_message::Error), + NetlinkControlMessage(#[error(source)] nl_message::Error), #[error(display = "Failed to open netlink socket")] - NetlinkSocketError(#[error(source)] std::io::Error), + NetlinkSocket(#[error(source)] std::io::Error), #[error(display = "Failed to send netlink control request")] - NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), + NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")] WireguardNetlinkInterfaceUnavailable, @@ -60,25 +60,25 @@ pub enum Error { NoDevice, #[error(display = "Failed to get config: _0")] - WgGetConfError(netlink_packet_core::error::ErrorMessage), + WgGetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Failed to apply config: _0")] - WgSetConfError(netlink_packet_core::error::ErrorMessage), + WgSetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Interface name too long")] - InterfaceNameError, + InterfaceName, #[error(display = "Send request error")] - SendRequestError(#[error(source)] NetlinkError<DeviceMessage>), + SendRequest(#[error(source)] NetlinkError<DeviceMessage>), #[error(display = "Create device error")] - NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error), + NetlinkCreateDevice(#[error(source)] rtnetlink::Error), #[error(display = "Add IP to device error")] - NetlinkSetIpError(rtnetlink::Error), + NetlinkSetIp(rtnetlink::Error), #[error(display = "Failed to delete device")] - DeleteDeviceError(#[error(source)] rtnetlink::Error), + DeleteDevice(#[error(source)] rtnetlink::Error), #[error(display = "NetworkManager error")] NetworkManager(#[error(source)] nm_tunnel::Error), @@ -98,7 +98,7 @@ impl Handle { pub async fn connect() -> Result<Self, Error> { let message_type = Self::get_wireguard_message_type().await?; let (conn, wireguard_connection, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let wg_handle = WireguardConnection { message_type, connection: wireguard_connection, @@ -106,7 +106,7 @@ impl Handle { let (abortable_connection, wg_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); let (conn, route_handle, _messages) = - rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?; + rtnetlink::new_connection().map_err(Error::NetlinkSocket)?; let (abortable_connection, route_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); @@ -120,21 +120,21 @@ impl Handle { async fn get_wireguard_message_type() -> Result<u16, Error> { let (conn, mut handle, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let (conn, abort_handle) = abortable(conn); tokio::spawn(conn); let result = async move { let mut message: NetlinkMessage<NetlinkControlMessage> = NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap()) - .map_err(Error::NetlinkControlMessageError)? + .map_err(Error::NetlinkControlMessage)? .into(); message.header.flags = NLM_F_REQUEST | NLM_F_ACK; let mut req = handle .request(message, SocketAddr::new(0, 0)) - .map_err(Error::NetlinkRequestError)?; + .map_err(Error::NetlinkRequest)?; let response = req.next().await; if let Some(response) = response { if let NetlinkPayload::InnerMessage(msg) = response.payload { @@ -177,14 +177,14 @@ impl Handle { let mut response = self .route_handle .request(add_request) - .map_err(Error::NetlinkCreateDeviceError)?; + .map_err(Error::NetlinkCreateDevice)?; while let Some(response_message) = response.next().await { if let NetlinkPayload::Error(err) = response_message.payload { // if the device exists, verify that it's a wireguard device if -err.code != libc::EEXIST { - return Err(Error::NetlinkCreateDeviceError( - rtnetlink::Error::NetlinkError(err), - )); + return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError( + err, + ))); } } } @@ -208,9 +208,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::NetlinkSetIpError)?; + .map_err(Error::NetlinkSetIp)?; while let Some(response_message) = response.next().await { - consume_netlink_error(response_message, Error::NetlinkSetIpError)?; + consume_netlink_error(response_message, Error::NetlinkSetIp)?; } Ok(()) @@ -226,9 +226,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::DeleteDeviceError)?; + .map_err(Error::DeleteDevice)?; while let Some(message) = response.next().await { - consume_netlink_error(message, Error::DeleteDeviceError)?; + consume_netlink_error(message, Error::DeleteDevice)?; } Ok(()) @@ -269,7 +269,7 @@ impl WireguardConnection { let mut response = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; match response.next().await { Some(received_message) => match received_message.payload { NetlinkPayload::InnerMessage(inner) => Ok(inner), @@ -277,7 +277,7 @@ impl WireguardConnection { if err.code == -libc::ENODEV { Err(Error::NoDevice) } else { - Err(Error::WgGetConfError(err)) + Err(Error::WgGetConf(err)) } } anything_else => { @@ -297,11 +297,11 @@ impl WireguardConnection { let mut request = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; while let Some(response) = request.next().await { if let NetlinkPayload::Error(err) = response.payload { - return Err(Error::WgSetConfError(err)); + return Err(Error::WgSetConf(err)); } } Ok(()) diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs index be2231f771..f2de334762 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -110,9 +110,9 @@ impl DeviceMessage { } pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> { - let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?; + let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?; if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ { - return Err(Error::InterfaceNameError); + return Err(Error::InterfaceName); } Ok(Self { @@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage { let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..]; let mut nlas = vec![]; for buf in NlasIterator::new(new_payload) { - nlas.push( - DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?, - ); + nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?); } Ok(DeviceMessage { @@ -391,13 +389,13 @@ impl Nla for PeerNla { InetAddr::V4(sockaddr_in) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in) }) + .write_all(unsafe { struct_as_slice(sockaddr_in) }) .expect("Buffer too small for sockaddr_in"); } InetAddr::V6(sockaddr_in6) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in6) }) + .write_all(unsafe { struct_as_slice(sockaddr_in6) }) .expect("Buffer too small for sockaddr_in6"); } }, @@ -408,7 +406,7 @@ impl Nla for PeerNla { let timespec: &libc::timespec = last_handshake.as_ref(); // SAFETY: `timespec` has no padding bytes buffer - .write(unsafe { struct_as_slice(timespec) }) + .write_all(unsafe { struct_as_slice(timespec) }) .expect("Buffer too small for timespec"); } RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes), @@ -535,7 +533,7 @@ impl Nla for AllowedIpNla { } IpAddr(ip_addr) => { buffer - .write(&ip_addr_to_bytes(ip_addr)) + .write_all(&ip_addr_to_bytes(ip_addr)) .expect("Buffer too small for AllowedIpNla::IpAddr"); } CidrMask(cidr_mask) => buffer[0] = *cidr_mask, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7536b26b09..e787729c04 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -6,7 +6,9 @@ use super::{ use crate::{ firewall::FirewallPolicy, routing::RouteManager, - tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor}, + tunnel::{ + self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor, + }, }; use cfg_if::cfg_if; use futures::{ @@ -142,16 +144,20 @@ impl ConnectingState { } }; + let init_args = TunnelArgs { + resource_dir: &resource_dir, + on_event: on_tunnel_event, + tunnel_close_rx, + }; + let block_reason = match TunnelMonitor::start( runtime, &mut tunnel_parameters, &log_dir, - &resource_dir, - on_tunnel_event, tun_provider, - route_manager_handle, retry_attempt, - tunnel_close_rx, + route_manager_handle, + init_args, ) { Ok(monitor) => { let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt); diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 3552eeab61..4c3eda0ecb 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -132,23 +132,25 @@ pub async fn spawn( let (shutdown_tx, shutdown_rx) = oneshot::channel(); let weak_command_tx = Arc::downgrade(&command_tx); - let state_machine = TunnelStateMachine::new( - initial_settings, - weak_command_tx, - offline_state_listener, + + let init_args = TunnelStateMachineInitArgs { + settings: initial_settings, + command_tx: weak_command_tx, + offline_state_tx: offline_state_listener, tunnel_parameters_generator, tun_provider, log_dir, resource_dir, - command_rx, + commands_rx: command_rx, #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "macos")] exclusion_gid, #[cfg(target_os = "android")] android_context, - ) - .await?; + }; + + let state_machine = TunnelStateMachine::new(init_args).await?; #[cfg(windows)] let split_tunnel = state_machine.shared_values.split_tunnel.handle(); @@ -219,20 +221,35 @@ struct TunnelStateMachine { shared_values: SharedTunnelStateValues, } +/// Tunnel state machine initialization arguments arguments +struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> { + settings: InitialTunnelState, + command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, + offline_state_tx: mpsc::UnboundedSender<bool>, + tunnel_parameters_generator: G, + tun_provider: TunProvider, + log_dir: Option<PathBuf>, + resource_dir: PathBuf, + commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, + #[cfg(target_os = "windows")] + volume_update_rx: mpsc::UnboundedReceiver<()>, + #[cfg(target_os = "macos")] + exclusion_gid: u32, + #[cfg(target_os = "android")] + android_context: AndroidContext, +} + impl TunnelStateMachine { async fn new( - settings: InitialTunnelState, - command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, - offline_state_tx: mpsc::UnboundedSender<bool>, - tunnel_parameters_generator: impl TunnelParametersGenerator, - tun_provider: TunProvider, - log_dir: Option<PathBuf>, - resource_dir: PathBuf, - commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, - #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, - #[cfg(target_os = "macos")] exclusion_gid: u32, - #[cfg(target_os = "android")] android_context: AndroidContext, + args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>, ) -> Result<Self, Error> { + #[cfg(target_os = "windows")] + let volume_update_rx = args.volume_update_rx; + #[cfg(target_os = "macos")] + let exclusion_gid = args.exclusion_gid; + #[cfg(target_os = "android")] + let android_context = args.android_context; + let runtime = tokio::runtime::Handle::current(); #[cfg(target_os = "macos")] @@ -242,20 +259,24 @@ impl TunnelStateMachine { let power_mgmt_rx = crate::windows::window::PowerManagementListener::new(); #[cfg(windows)] - let split_tunnel = - split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx) - .map_err(Error::InitSplitTunneling)?; + let split_tunnel = split_tunnel::SplitTunnel::new( + runtime.clone(), + args.command_tx.clone(), + volume_update_rx, + ) + .map_err(Error::InitSplitTunneling)?; - let args = FirewallArguments { - initial_state: if settings.block_when_disconnected || !settings.reset_firewall { - InitialFirewallState::Blocked(settings.allowed_endpoint.clone()) + let fw_args = FirewallArguments { + initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall + { + InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone()) } else { InitialFirewallState::None }, - allow_lan: settings.allow_lan, + allow_lan: args.settings.allow_lan, }; - let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?; + let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?; let route_manager = RouteManager::new(HashSet::new()) .await .map_err(Error::InitRouteManagerError)?; @@ -267,20 +288,20 @@ impl TunnelStateMachine { .handle() .map_err(Error::InitRouteManagerError)?, #[cfg(target_os = "macos")] - command_tx.clone(), + args.command_tx.clone(), ) .map_err(Error::InitDnsMonitorError)?; let (offline_tx, mut offline_rx) = mpsc::unbounded(); - let initial_offline_state_tx = offline_state_tx.clone(); + let initial_offline_state_tx = args.offline_state_tx.clone(); tokio::spawn(async move { while let Some(offline) = offline_rx.next().await { - if let Some(tx) = command_tx.upgrade() { + if let Some(tx) = args.command_tx.upgrade() { let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline)); } else { break; } - let _ = offline_state_tx.unbounded_send(offline); + let _ = args.offline_state_tx.unbounded_send(offline); } }); let mut offline_monitor = offline::spawn_monitor( @@ -301,7 +322,7 @@ impl TunnelStateMachine { #[cfg(windows)] split_tunnel - .set_paths_sync(&settings.exclude_paths) + .set_paths_sync(&args.settings.exclude_paths) .map_err(Error::InitSplitTunneling)?; let mut shared_values = SharedTunnelStateValues { @@ -312,15 +333,15 @@ impl TunnelStateMachine { dns_monitor, route_manager, _offline_monitor: offline_monitor, - allow_lan: settings.allow_lan, - block_when_disconnected: settings.block_when_disconnected, + allow_lan: args.settings.allow_lan, + block_when_disconnected: args.settings.block_when_disconnected, is_offline, - dns_servers: settings.dns_servers, - allowed_endpoint: settings.allowed_endpoint, - tunnel_parameters_generator: Box::new(tunnel_parameters_generator), - tun_provider: Arc::new(Mutex::new(tun_provider)), - log_dir, - resource_dir, + dns_servers: args.settings.dns_servers, + allowed_endpoint: args.settings.allowed_endpoint, + tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator), + tun_provider: Arc::new(Mutex::new(args.tun_provider)), + log_dir: args.log_dir, + resource_dir: args.resource_dir, #[cfg(target_os = "linux")] connectivity_check_was_enabled: None, #[cfg(target_os = "macos")] @@ -331,11 +352,11 @@ impl TunnelStateMachine { tokio::task::spawn_blocking(move || { let (initial_state, _) = - DisconnectedState::enter(&mut shared_values, settings.reset_firewall); + DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall); Ok(TunnelStateMachine { current_state: Some(initial_state), - commands: commands_rx.fuse(), + commands: args.commands_rx.fuse(), shared_values, }) }) |
