summaryrefslogtreecommitdiffhomepage
path: root/talpid-core
diff options
context:
space:
mode:
authorJonathan <jonathan@mullvad.net>2022-06-13 10:49:46 +0200
committerJonathan <jonathan@mullvad.net>2022-06-21 14:31:40 +0200
commitd3da8745c8ff9e66d6698d8a239b8139dbe8abfe (patch)
tree528a15026535b01bc3324892be783797aa64bbe4 /talpid-core
parentb6b80b9ffe6521a78ea6b2cdfd0e6965e67479fd (diff)
downloadmullvadvpn-d3da8745c8ff9e66d6698d8a239b8139dbe8abfe.tar.xz
mullvadvpn-d3da8745c8ff9e66d6698d8a239b8139dbe8abfe.zip
Fix the large majority of clippy warnings
This commit fixes most of the remaining clippy warnings in the codebase. These warnings were the more semantically difficult ones to fix. There are some warnings that remain from the rebase that will be fixed in the upcoming PR.
Diffstat (limited to 'talpid-core')
-rw-r--r--talpid-core/src/dns/linux/resolvconf.rs20
-rw-r--r--talpid-core/src/dns/linux/static_resolv_conf.rs8
-rw-r--r--talpid-core/src/mpsc.rs15
-rw-r--r--talpid-core/src/ping_monitor/icmp.rs2
-rw-r--r--talpid-core/src/routing/linux.rs57
-rw-r--r--talpid-core/src/routing/unix.rs3
-rw-r--r--talpid-core/src/tunnel/mod.rs46
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs143
-rw-r--r--talpid-core/src/tunnel/tun_provider/unix.rs6
-rw-r--r--talpid-core/src/tunnel/wireguard/logging.rs4
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs28
-rw-r--r--talpid-core/src/tunnel/wireguard/stats.rs12
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs56
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs16
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs16
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs103
16 files changed, 280 insertions, 255 deletions
diff --git a/talpid-core/src/dns/linux/resolvconf.rs b/talpid-core/src/dns/linux/resolvconf.rs
index ea2f3f3704..97db14b622 100644
--- a/talpid-core/src/dns/linux/resolvconf.rs
+++ b/talpid-core/src/dns/linux/resolvconf.rs
@@ -22,16 +22,16 @@ pub enum Error {
RunResolvconf(#[error(source)] io::Error),
#[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)]
- AddRecordError { stderr: String },
+ AddRecord { stderr: String },
#[error(display = "Using 'resolvconf' to delete a record failed")]
- DeleteRecordError,
+ DeleteRecord,
#[error(display = "Detected dnsmasq is runing and misconfigured")]
- DnsmasqMisconfigurationError,
+ DnsmasqMisconfiguration,
#[error(display = "Current /etc/resolv.conf is not generated by resolvconf")]
- ResolvconfNotInUseError,
+ ResolvconfNotInUse,
}
pub struct Resolvconf {
@@ -50,15 +50,15 @@ impl Resolvconf {
// Check if resolvconf is managing DNS by /etc/resolv.conf
if !is_dnsmasq_running
- && !(Self::check_if_resolvconf_is_symlinked_correctly()
- || Self::check_if_resolvconf_was_generated())
+ && !Self::check_if_resolvconf_is_symlinked_correctly()
+ && !Self::check_if_resolvconf_was_generated()
{
- return Err(Error::ResolvconfNotInUseError);
+ return Err(Error::ResolvconfNotInUse);
}
// Check if resolvconf can manage DNS via dnsmasq
if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() {
- return Err(Error::DnsmasqMisconfigurationError);
+ return Err(Error::DnsmasqMisconfiguration);
}
Ok(Resolvconf {
@@ -94,7 +94,7 @@ impl Resolvconf {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
- return Err(Error::AddRecordError { stderr });
+ return Err(Error::AddRecord { stderr });
}
self.record_names.insert(record_name);
@@ -118,7 +118,7 @@ impl Resolvconf {
record_name,
String::from_utf8_lossy(&output.stderr)
);
- result = Err(Error::DeleteRecordError);
+ result = Err(Error::DeleteRecord);
}
}
diff --git a/talpid-core/src/dns/linux/static_resolv_conf.rs b/talpid-core/src/dns/linux/static_resolv_conf.rs
index 196fb31003..691d7b468b 100644
--- a/talpid-core/src/dns/linux/static_resolv_conf.rs
+++ b/talpid-core/src/dns/linux/static_resolv_conf.rs
@@ -28,7 +28,7 @@ pub enum Error {
ReadResolvConf(&'static str, #[error(source)] io::Error),
#[error(display = "resolv.conf at {} could not be parsed", _0)]
- ParseError(&'static str, #[error(source)] resolv_conf::ParseError),
+ Parse(&'static str, #[error(source)] resolv_conf::ParseError),
#[error(display = "Failed to remove stale resolv.conf backup at {}", _0)]
RemoveBackup(&'static str, #[error(source)] io::Error),
@@ -179,7 +179,7 @@ fn read_config() -> Result<Config> {
let contents = fs::read_to_string(RESOLV_CONF_PATH)
.map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?;
- let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?;
+ let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?;
Ok(config)
}
@@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> {
match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) {
Ok(backup) => {
log::info!("Restoring DNS state from backup");
- let config = Config::parse(&backup)
- .map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?;
+ let config =
+ Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?;
write_config(&config)?;
diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs
index 8c6424bc01..6492796cfc 100644
--- a/talpid-core/src/mpsc.rs
+++ b/talpid-core/src/mpsc.rs
@@ -1,11 +1,20 @@
+/// Error type for `Sender` trait.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// The underlying channel is closed.
+ #[error(display = "Channel is closed")]
+ ChannelClosed,
+}
+
/// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`.
pub trait Sender<T> {
/// Sends an item over the underlying channel, failing only if the channel is closed.
- fn send(&self, item: T) -> Result<(), ()>;
+ fn send(&self, item: T) -> Result<(), Error>;
}
impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> {
- fn send(&self, content: E) -> Result<(), ()> {
- self.unbounded_send(content).map_err(|_| ())
+ fn send(&self, content: E) -> Result<(), Error> {
+ self.unbounded_send(content)
+ .map_err(|_| Error::ChannelClosed)
}
}
diff --git a/talpid-core/src/ping_monitor/icmp.rs b/talpid-core/src/ping_monitor/icmp.rs
index 67f5b70cb5..0bcd9da72f 100644
--- a/talpid-core/src/ping_monitor/icmp.rs
+++ b/talpid-core/src/ping_monitor/icmp.rs
@@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner(
let checksum = internet_checksum::checksum(buffer);
(&mut buffer[ICMP_CHECKSUM_OFFSET..])
- .write(&checksum)
+ .write_all(&checksum)
.unwrap();
true
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index 092ad6f52a..4b039fe9eb 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to open a netlink connection")]
- ConnectError(#[error(source)] io::Error),
+ Connect(#[error(source)] io::Error),
#[error(display = "Failed to bind netlink socket")]
- BindError(#[error(source)] io::Error),
+ Bind(#[error(source)] io::Error),
#[error(display = "Netlink error")]
- NetlinkError(#[error(source)] rtnetlink::Error),
+ Netlink(#[error(source)] rtnetlink::Error),
#[error(display = "Route without a valid node")]
InvalidRoute,
@@ -108,16 +108,16 @@ pub enum Error {
UnknownDeviceIndex(u32),
#[error(display = "Failed to get a route for the given IP address")]
- GetRouteError(#[error(source)] rtnetlink::Error),
+ GetRoute(#[error(source)] rtnetlink::Error),
#[error(display = "No netlink response for route query")]
- NoRouteError,
+ NoRoute,
#[error(display = "Route node was malformed")]
InvalidRouteNode,
#[error(display = "No link found")]
- LinkNotFoundError,
+ LinkNotFound,
/// Unable to create routing table for tagged connections and packets.
#[error(display = "Cannot find a free routing table ID")]
@@ -140,14 +140,11 @@ pub struct RouteManagerImpl {
impl RouteManagerImpl {
pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
let (mut connection, handle, messages) =
- rtnetlink::new_connection().map_err(Error::ConnectError)?;
+ rtnetlink::new_connection().map_err(Error::Connect)?;
let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY;
let addr = SocketAddr::new(0, mgroup_flags);
- connection
- .socket_mut()
- .bind(&addr)
- .map_err(Error::BindError)?;
+ connection.socket_mut().bind(&addr).map_err(Error::Bind)?;
tokio::spawn(connection);
@@ -179,11 +176,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@@ -236,7 +233,7 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
let mut rules = vec![];
@@ -246,7 +243,7 @@ impl RouteManagerImpl {
rules.push(rule);
}
NetlinkPayload::Error(error) => {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
_ => (),
}
@@ -260,12 +257,12 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
if error.to_io().kind() != io::ErrorKind::NotFound {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@@ -296,7 +293,7 @@ impl RouteManagerImpl {
) -> Result<BTreeMap<u32, NetworkInterface>> {
let mut link_map = BTreeMap::new();
let mut link_request = handle.link().get().execute();
- while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? {
+ while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? {
if let Some((idx, device)) = Self::map_interface(link) {
link_map.insert(idx, device);
}
@@ -543,7 +540,7 @@ impl RouteManagerImpl {
async fn delete_route_if_exists(&self, route: &Route) -> Result<()> {
if let Err(error) = self.delete_route(route).await {
- if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error {
+ if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error {
if msg.code == -libc::ESRCH {
return Ok(());
}
@@ -619,7 +616,7 @@ impl RouteManagerImpl {
.del(route_message)
.execute()
.await
- .map_err(Error::NetlinkError)
+ .map_err(Error::Netlink)
}
async fn add_route_direct(&mut self, route: Route) -> Result<()> {
@@ -693,11 +690,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(err) = message.payload {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err)));
}
}
Ok(())
@@ -759,7 +756,7 @@ impl RouteManagerImpl {
}
None => {
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
- return Err(Error::NoRouteError);
+ return Err(Error::NoRoute);
}
}
}
@@ -767,17 +764,13 @@ impl RouteManagerImpl {
"Retried {} times looking for the correct device and could not find it",
RECURSION_LIMIT
);
- Err(Error::NoRouteError)
+ Err(Error::NoRoute)
}
async fn get_device_mtu(&self, device: String) -> Result<u16> {
let mut links = self.handle.link().get().execute();
let target_device = LinkNla::IfName(device);
- while let Some(msg) = links
- .try_next()
- .await
- .map_err(|_| Error::LinkNotFoundError)?
- {
+ while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? {
let found = msg.nlas.iter().any(|e| *e == target_device);
if found {
if let Some(LinkNla::Mtu(mtu)) =
@@ -788,7 +781,7 @@ impl RouteManagerImpl {
}
}
}
- Err(Error::LinkNotFoundError)
+ Err(Error::LinkNotFound)
}
async fn get_destination_route(
@@ -813,11 +806,11 @@ impl RouteManagerImpl {
let mut stream = execute_route_get_request(self.handle.clone(), message.clone());
match stream.try_next().await {
Ok(Some(route_msg)) => self.parse_route_message(route_msg),
- Ok(None) => Err(Error::NoRouteError),
+ Ok(None) => Err(Error::NoRoute),
Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
Ok(None)
}
- Err(err) => Err(Error::GetRouteError(err)),
+ Err(err) => Err(Error::GetRoute(err)),
}
}
}
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index edfbdd2b85..326fb1fad1 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -19,16 +19,19 @@ use futures::stream::Stream;
#[cfg(target_os = "linux")]
use std::net::IpAddr;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
mod imp;
#[cfg(target_os = "macos")]
pub(crate) use imp::listen_for_default_route_changes;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "linux")]
#[path = "linux.rs"]
mod imp;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "android")]
#[path = "android.rs"]
mod imp;
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 5da3c092a4..f6ada1c2cf 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -1,6 +1,6 @@
use self::tun_provider::TunProvider;
use crate::{logging, routing::RouteManagerHandle};
-use futures::channel::oneshot;
+use futures::{channel::oneshot, future::BoxFuture};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
@@ -98,6 +98,20 @@ pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
}
+/// Arguments for creating a tunnel.
+pub struct TunnelArgs<'a, L>
+where
+ // L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
+{
+ /// Resource directory.
+ pub resource_dir: &'a Path,
+ /// Callback function called when an event happens.
+ pub on_event: L,
+ /// Receiver oneshot channel for closing the tunnel.
+ pub tunnel_close_rx: oneshot::Receiver<()>,
+}
+
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
@@ -107,12 +121,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
tunnel_parameters: &mut TunnelParameters,
log_dir: &Option<PathBuf>,
- resource_dir: &Path,
- on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -129,9 +141,9 @@ impl TunnelMonitor {
TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel(
config,
log_file,
- resource_dir,
- on_event,
- tunnel_close_rx,
+ init_args.resource_dir,
+ init_args.on_event,
+ init_args.tunnel_close_rx,
#[cfg(target_os = "linux")]
route_manager,
)),
@@ -142,12 +154,10 @@ impl TunnelMonitor {
runtime,
config,
log_file,
- resource_dir,
- on_event,
tun_provider,
- route_manager,
retry_attempt,
- tunnel_close_rx,
+ route_manager,
+ init_args,
),
}
}
@@ -178,12 +188,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
params: &mut wireguard_types::TunnelParameters,
log: Option<PathBuf>,
- resource_dir: &Path,
- on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -211,12 +219,10 @@ impl TunnelMonitor {
None
},
log.as_deref(),
- resource_dir,
- on_event,
tun_provider,
- route_manager,
retry_attempt,
- tunnel_close_rx,
+ route_manager,
+ init_args,
)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index 910d5bb49e..9fdfb3e80b 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> {
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let openvpn_init_args = OpenVpnTunnelInitArgs {
+ event_server_abort_tx: event_server_abort_tx.clone(),
+ event_server_abort_rx,
+ plugin_path,
+ log_path,
+ user_pass_file,
+ proxy_auth_file,
+ proxy_monitor,
+ tunnel_close_rx,
+ };
Self::new_internal(
cmd,
- event_server_abort_tx.clone(),
- event_server_abort_rx,
+ openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
on_event,
user_pass_file_path: user_pass_file_path.clone(),
@@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(target_os = "linux")]
ipv6_enabled,
},
- plugin_path,
- log_path,
- user_pass_file,
- proxy_auth_file,
- proxy_monitor,
- tunnel_close_rx,
#[cfg(windows)]
Box::new(wintun),
)
@@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
Ok(routes)
}
+struct OpenVpnTunnelInitArgs {
+ event_server_abort_tx: triggered::Trigger,
+ event_server_abort_rx: triggered::Listener,
+ plugin_path: PathBuf,
+ log_path: Option<PathBuf>,
+ user_pass_file: mktemp::TempFile,
+ proxy_auth_file: Option<mktemp::TempFile>,
+ proxy_monitor: Option<Box<dyn ProxyMonitor>>,
+ tunnel_close_rx: oneshot::Receiver<()>,
+}
+
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
async fn new_internal<L>(
mut cmd: C,
- event_server_abort_tx: triggered::Trigger,
- event_server_abort_rx: triggered::Listener,
+ init_args: OpenVpnTunnelInitArgs,
on_event: L,
- plugin_path: PathBuf,
- log_path: Option<PathBuf>,
- user_pass_file: mktemp::TempFile,
- proxy_auth_file: Option<mktemp::TempFile>,
- proxy_monitor: Option<Box<dyn ProxyMonitor>>,
- tunnel_close_rx: oneshot::Receiver<()>,
#[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
L: event_server::OpenvpnEventProxy + Send + Sync + 'static,
{
+ let event_server_abort_tx = init_args.event_server_abort_tx;
+ let event_server_abort_rx = init_args.event_server_abort_rx;
+ let plugin_path = init_args.plugin_path;
+ let log_path = init_args.log_path;
+ let user_pass_file = init_args.user_pass_file;
+ let proxy_auth_file = init_args.proxy_auth_file;
+ let proxy_monitor = init_args.proxy_monitor;
+ let tunnel_close_rx = init_args.tunnel_close_rx;
+
let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx)
.await
.map_err(Error::EventDispatcherError)?;
@@ -1220,23 +1236,37 @@ mod tests {
.map_err(Error::RuntimeError)
}
+ fn create_init_args_plugin_log(
+ plugin_path: PathBuf,
+ log_path: Option<PathBuf>,
+ ) -> OpenVpnTunnelInitArgs {
+ let (_close_tx, close_rx) = oneshot::channel();
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ OpenVpnTunnelInitArgs {
+ event_server_abort_tx,
+ event_server_abort_rx,
+ plugin_path,
+ log_path,
+ user_pass_file: TempFile::new(),
+ proxy_auth_file: None,
+ proxy_monitor: None,
+ tunnel_close_rx: close_rx,
+ }
+ }
+
+ fn create_init_args() -> OpenVpnTunnelInitArgs {
+ create_init_args_plugin_log("".into(), None)
+ }
+
#[test]
fn sets_plugin() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None);
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "./my_test_plugin".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@@ -1249,20 +1279,13 @@ mod tests {
#[test]
fn sets_log() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args =
+ create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file")));
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- Some(PathBuf::from("./my_test_log_file")),
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@@ -1276,21 +1299,13 @@ mod tests {
fn exit_successfully() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(0));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1302,21 +1317,13 @@ mod tests {
fn exit_error() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1328,21 +1335,13 @@ mod tests {
fn wait_closed() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1354,21 +1353,13 @@ mod tests {
#[test]
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let result = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
diff --git a/talpid-core/src/tunnel/tun_provider/unix.rs b/talpid-core/src/tunnel/tun_provider/unix.rs
index d8d3b7ce01..5c48a3c663 100644
--- a/talpid-core/src/tunnel/tun_provider/unix.rs
+++ b/talpid-core/src/tunnel/tun_provider/unix.rs
@@ -22,6 +22,12 @@ pub enum Error {
/// Factory of tunnel devices on Unix systems.
pub struct UnixTunProvider;
+impl Default for UnixTunProvider {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
impl UnixTunProvider {
pub fn new() -> Self {
UnixTunProvider
diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs
index 1a7a52ed8b..35ec10fc2f 100644
--- a/talpid-core/src/tunnel/wireguard/logging.rs
+++ b/talpid-core/src/tunnel/wireguard/logging.rs
@@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback(
let level = match level {
WG_GO_LOG_VERBOSE => LogLevel::Verbose,
- WG_GO_LOG_ERROR | _ => LogLevel::Error,
+ _ => LogLevel::Error,
};
log_inner(logfile, level, "wireguard-go", &managed_msg);
}
@@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback(
pub type WgLogLevel = u32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
// const WG_GO_LOG_SILENT: WgLogLevel = 0;
-const WG_GO_LOG_ERROR: WgLogLevel = 1;
+// const WG_GO_LOG_ERROR: WgLogLevel = 1;
const WG_GO_LOG_VERBOSE: WgLogLevel = 2;
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 7f17726c33..e49286cb30 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -1,15 +1,11 @@
use self::config::Config;
#[cfg(not(windows))]
use super::tun_provider;
-use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata};
+use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use crate::routing::{self, RequiredRoute, RouteManagerHandle};
+use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
#[cfg(windows)]
use futures::{channel::mpsc, StreamExt};
-use futures::{
- channel::oneshot,
- future::{abortable, AbortHandle as FutureAbortHandle},
- Future,
-};
#[cfg(target_os = "linux")]
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
@@ -54,6 +50,7 @@ mod wireguard_nt;
use self::wireguard_go::WgGoTunnel;
type Result<T> = std::result::Result<T, Error>;
+type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
/// Errors that can happen in the Wireguard tunnel monitor.
#[derive(err_derive::Error, Debug)]
@@ -104,12 +101,7 @@ pub struct WireguardMonitor {
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
- event_callback: Box<
- dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- >,
+ event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
_obfuscator: Option<ObfuscatorHandle>,
@@ -208,13 +200,13 @@ impl WireguardMonitor {
mut config: Config,
psk_negotiation: Option<PublicKey>,
log_path: Option<&Path>,
- resource_dir: &Path,
- on_event: F,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
+ let on_event = init_args.on_event;
+
let endpoint_addrs: Vec<IpAddr> =
config.peers.iter().map(|peer| peer.endpoint.ip()).collect();
let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel();
@@ -228,7 +220,7 @@ impl WireguardMonitor {
runtime.clone(),
&Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
log_path,
- resource_dir,
+ init_args.resource_dir,
tun_provider,
#[cfg(target_os = "windows")]
setup_done_tx,
@@ -351,7 +343,7 @@ impl WireguardMonitor {
});
tokio::spawn(async move {
- if tunnel_close_rx.await.is_ok() {
+ if init_args.tunnel_close_rx.await.is_ok() {
monitor_handle.abort();
let _ = close_msg_sender.send(CloseMsg::Stop);
}
diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs
index bda8af2e1f..cec033f611 100644
--- a/talpid-core/src/tunnel/wireguard/stats.rs
+++ b/talpid-core/src/tunnel/wireguard/stats.rs
@@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla};
#[derive(err_derive::Error, Debug, PartialEq)]
pub enum Error {
#[error(display = "Failed to parse peer pubkey from string \"_0\"")]
- PubKeyParseError(String, #[error(source)] hex::FromHexError),
+ PubKeyParse(String, #[error(source)] hex::FromHexError),
#[error(display = "Failed to parse integer from string \"_0\"")]
- IntParseError(String, #[error(source)] std::num::ParseIntError),
+ IntParse(String, #[error(source)] std::num::ParseIntError),
#[error(display = "Device no longer exists")]
NoTunnelDevice,
@@ -47,7 +47,7 @@ impl Stats {
"public_key" => {
let mut buffer = [0u8; 32];
hex::decode_to_slice(value, &mut buffer)
- .map_err(|err| Error::PubKeyParseError(value.to_string(), err))?;
+ .map_err(|err| Error::PubKeyParse(value.to_string(), err))?;
peer = Some(buffer);
tx_bytes = None;
rx_bytes = None;
@@ -57,7 +57,7 @@ impl Stats {
value
.trim()
.parse()
- .map_err(|err| Error::IntParseError(value.to_string(), err))?,
+ .map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
"tx_bytes" => {
@@ -65,7 +65,7 @@ impl Stats {
value
.trim()
.parse()
- .map_err(|err| Error::IntParseError(value.to_string(), err))?,
+ .map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
@@ -145,7 +145,7 @@ mod test {
assert_eq!(
Stats::parse_config_str(invalid_input),
- Err(Error::IntParseError(invalid_str, int_err))
+ Err(Error::IntParse(invalid_str, int_err))
);
}
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
index 0f3866500e..5b7b6a1e12 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
@@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to decode netlink message")]
- DecodeError(#[error(source)] DecodeError),
+ Decode(#[error(source)] DecodeError),
#[error(display = "Failed to execute netlink control request")]
- NetlinkControlMessageError(#[error(source)] nl_message::Error),
+ NetlinkControlMessage(#[error(source)] nl_message::Error),
#[error(display = "Failed to open netlink socket")]
- NetlinkSocketError(#[error(source)] std::io::Error),
+ NetlinkSocket(#[error(source)] std::io::Error),
#[error(display = "Failed to send netlink control request")]
- NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
+ NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
#[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")]
WireguardNetlinkInterfaceUnavailable,
@@ -60,25 +60,25 @@ pub enum Error {
NoDevice,
#[error(display = "Failed to get config: _0")]
- WgGetConfError(netlink_packet_core::error::ErrorMessage),
+ WgGetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Failed to apply config: _0")]
- WgSetConfError(netlink_packet_core::error::ErrorMessage),
+ WgSetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Interface name too long")]
- InterfaceNameError,
+ InterfaceName,
#[error(display = "Send request error")]
- SendRequestError(#[error(source)] NetlinkError<DeviceMessage>),
+ SendRequest(#[error(source)] NetlinkError<DeviceMessage>),
#[error(display = "Create device error")]
- NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error),
+ NetlinkCreateDevice(#[error(source)] rtnetlink::Error),
#[error(display = "Add IP to device error")]
- NetlinkSetIpError(rtnetlink::Error),
+ NetlinkSetIp(rtnetlink::Error),
#[error(display = "Failed to delete device")]
- DeleteDeviceError(#[error(source)] rtnetlink::Error),
+ DeleteDevice(#[error(source)] rtnetlink::Error),
#[error(display = "NetworkManager error")]
NetworkManager(#[error(source)] nm_tunnel::Error),
@@ -98,7 +98,7 @@ impl Handle {
pub async fn connect() -> Result<Self, Error> {
let message_type = Self::get_wireguard_message_type().await?;
let (conn, wireguard_connection, _messages) =
- netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
+ netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let wg_handle = WireguardConnection {
message_type,
connection: wireguard_connection,
@@ -106,7 +106,7 @@ impl Handle {
let (abortable_connection, wg_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
let (conn, route_handle, _messages) =
- rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?;
+ rtnetlink::new_connection().map_err(Error::NetlinkSocket)?;
let (abortable_connection, route_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
@@ -120,21 +120,21 @@ impl Handle {
async fn get_wireguard_message_type() -> Result<u16, Error> {
let (conn, mut handle, _messages) =
- netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
+ netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let (conn, abort_handle) = abortable(conn);
tokio::spawn(conn);
let result = async move {
let mut message: NetlinkMessage<NetlinkControlMessage> =
NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap())
- .map_err(Error::NetlinkControlMessageError)?
+ .map_err(Error::NetlinkControlMessage)?
.into();
message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
let mut req = handle
.request(message, SocketAddr::new(0, 0))
- .map_err(Error::NetlinkRequestError)?;
+ .map_err(Error::NetlinkRequest)?;
let response = req.next().await;
if let Some(response) = response {
if let NetlinkPayload::InnerMessage(msg) = response.payload {
@@ -177,14 +177,14 @@ impl Handle {
let mut response = self
.route_handle
.request(add_request)
- .map_err(Error::NetlinkCreateDeviceError)?;
+ .map_err(Error::NetlinkCreateDevice)?;
while let Some(response_message) = response.next().await {
if let NetlinkPayload::Error(err) = response_message.payload {
// if the device exists, verify that it's a wireguard device
if -err.code != libc::EEXIST {
- return Err(Error::NetlinkCreateDeviceError(
- rtnetlink::Error::NetlinkError(err),
- ));
+ return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError(
+ err,
+ )));
}
}
}
@@ -208,9 +208,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
- .map_err(Error::NetlinkSetIpError)?;
+ .map_err(Error::NetlinkSetIp)?;
while let Some(response_message) = response.next().await {
- consume_netlink_error(response_message, Error::NetlinkSetIpError)?;
+ consume_netlink_error(response_message, Error::NetlinkSetIp)?;
}
Ok(())
@@ -226,9 +226,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
- .map_err(Error::DeleteDeviceError)?;
+ .map_err(Error::DeleteDevice)?;
while let Some(message) = response.next().await {
- consume_netlink_error(message, Error::DeleteDeviceError)?;
+ consume_netlink_error(message, Error::DeleteDevice)?;
}
Ok(())
@@ -269,7 +269,7 @@ impl WireguardConnection {
let mut response = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
- .map_err(Error::SendRequestError)?;
+ .map_err(Error::SendRequest)?;
match response.next().await {
Some(received_message) => match received_message.payload {
NetlinkPayload::InnerMessage(inner) => Ok(inner),
@@ -277,7 +277,7 @@ impl WireguardConnection {
if err.code == -libc::ENODEV {
Err(Error::NoDevice)
} else {
- Err(Error::WgGetConfError(err))
+ Err(Error::WgGetConf(err))
}
}
anything_else => {
@@ -297,11 +297,11 @@ impl WireguardConnection {
let mut request = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
- .map_err(Error::SendRequestError)?;
+ .map_err(Error::SendRequest)?;
while let Some(response) = request.next().await {
if let NetlinkPayload::Error(err) = response.payload {
- return Err(Error::WgSetConfError(err));
+ return Err(Error::WgSetConf(err));
}
}
Ok(())
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
index be2231f771..f2de334762 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
@@ -110,9 +110,9 @@ impl DeviceMessage {
}
pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> {
- let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?;
+ let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?;
if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ {
- return Err(Error::InterfaceNameError);
+ return Err(Error::InterfaceName);
}
Ok(Self {
@@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage {
let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..];
let mut nlas = vec![];
for buf in NlasIterator::new(new_payload) {
- nlas.push(
- DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?,
- );
+ nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?);
}
Ok(DeviceMessage {
@@ -391,13 +389,13 @@ impl Nla for PeerNla {
InetAddr::V4(sockaddr_in) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(sockaddr_in) })
+ .write_all(unsafe { struct_as_slice(sockaddr_in) })
.expect("Buffer too small for sockaddr_in");
}
InetAddr::V6(sockaddr_in6) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(sockaddr_in6) })
+ .write_all(unsafe { struct_as_slice(sockaddr_in6) })
.expect("Buffer too small for sockaddr_in6");
}
},
@@ -408,7 +406,7 @@ impl Nla for PeerNla {
let timespec: &libc::timespec = last_handshake.as_ref();
// SAFETY: `timespec` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(timespec) })
+ .write_all(unsafe { struct_as_slice(timespec) })
.expect("Buffer too small for timespec");
}
RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes),
@@ -535,7 +533,7 @@ impl Nla for AllowedIpNla {
}
IpAddr(ip_addr) => {
buffer
- .write(&ip_addr_to_bytes(ip_addr))
+ .write_all(&ip_addr_to_bytes(ip_addr))
.expect("Buffer too small for AllowedIpNla::IpAddr");
}
CidrMask(cidr_mask) => buffer[0] = *cidr_mask,
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 7536b26b09..e787729c04 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -6,7 +6,9 @@ use super::{
use crate::{
firewall::FirewallPolicy,
routing::RouteManager,
- tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor},
+ tunnel::{
+ self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor,
+ },
};
use cfg_if::cfg_if;
use futures::{
@@ -142,16 +144,20 @@ impl ConnectingState {
}
};
+ let init_args = TunnelArgs {
+ resource_dir: &resource_dir,
+ on_event: on_tunnel_event,
+ tunnel_close_rx,
+ };
+
let block_reason = match TunnelMonitor::start(
runtime,
&mut tunnel_parameters,
&log_dir,
- &resource_dir,
- on_tunnel_event,
tun_provider,
- route_manager_handle,
retry_attempt,
- tunnel_close_rx,
+ route_manager_handle,
+ init_args,
) {
Ok(monitor) => {
let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 3552eeab61..4c3eda0ecb 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -132,23 +132,25 @@ pub async fn spawn(
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let weak_command_tx = Arc::downgrade(&command_tx);
- let state_machine = TunnelStateMachine::new(
- initial_settings,
- weak_command_tx,
- offline_state_listener,
+
+ let init_args = TunnelStateMachineInitArgs {
+ settings: initial_settings,
+ command_tx: weak_command_tx,
+ offline_state_tx: offline_state_listener,
tunnel_parameters_generator,
tun_provider,
log_dir,
resource_dir,
- command_rx,
+ commands_rx: command_rx,
#[cfg(target_os = "windows")]
volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
android_context,
- )
- .await?;
+ };
+
+ let state_machine = TunnelStateMachine::new(init_args).await?;
#[cfg(windows)]
let split_tunnel = state_machine.shared_values.split_tunnel.handle();
@@ -219,20 +221,35 @@ struct TunnelStateMachine {
shared_values: SharedTunnelStateValues,
}
+/// Tunnel state machine initialization arguments arguments
+struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> {
+ settings: InitialTunnelState,
+ command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ offline_state_tx: mpsc::UnboundedSender<bool>,
+ tunnel_parameters_generator: G,
+ tun_provider: TunProvider,
+ log_dir: Option<PathBuf>,
+ resource_dir: PathBuf,
+ commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
+ #[cfg(target_os = "windows")]
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ #[cfg(target_os = "macos")]
+ exclusion_gid: u32,
+ #[cfg(target_os = "android")]
+ android_context: AndroidContext,
+}
+
impl TunnelStateMachine {
async fn new(
- settings: InitialTunnelState,
- command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
- offline_state_tx: mpsc::UnboundedSender<bool>,
- tunnel_parameters_generator: impl TunnelParametersGenerator,
- tun_provider: TunProvider,
- log_dir: Option<PathBuf>,
- resource_dir: PathBuf,
- commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
- #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
- #[cfg(target_os = "macos")] exclusion_gid: u32,
- #[cfg(target_os = "android")] android_context: AndroidContext,
+ args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>,
) -> Result<Self, Error> {
+ #[cfg(target_os = "windows")]
+ let volume_update_rx = args.volume_update_rx;
+ #[cfg(target_os = "macos")]
+ let exclusion_gid = args.exclusion_gid;
+ #[cfg(target_os = "android")]
+ let android_context = args.android_context;
+
let runtime = tokio::runtime::Handle::current();
#[cfg(target_os = "macos")]
@@ -242,20 +259,24 @@ impl TunnelStateMachine {
let power_mgmt_rx = crate::windows::window::PowerManagementListener::new();
#[cfg(windows)]
- let split_tunnel =
- split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
- .map_err(Error::InitSplitTunneling)?;
+ let split_tunnel = split_tunnel::SplitTunnel::new(
+ runtime.clone(),
+ args.command_tx.clone(),
+ volume_update_rx,
+ )
+ .map_err(Error::InitSplitTunneling)?;
- let args = FirewallArguments {
- initial_state: if settings.block_when_disconnected || !settings.reset_firewall {
- InitialFirewallState::Blocked(settings.allowed_endpoint.clone())
+ let fw_args = FirewallArguments {
+ initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall
+ {
+ InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone())
} else {
InitialFirewallState::None
},
- allow_lan: settings.allow_lan,
+ allow_lan: args.settings.allow_lan,
};
- let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?;
+ let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?;
let route_manager = RouteManager::new(HashSet::new())
.await
.map_err(Error::InitRouteManagerError)?;
@@ -267,20 +288,20 @@ impl TunnelStateMachine {
.handle()
.map_err(Error::InitRouteManagerError)?,
#[cfg(target_os = "macos")]
- command_tx.clone(),
+ args.command_tx.clone(),
)
.map_err(Error::InitDnsMonitorError)?;
let (offline_tx, mut offline_rx) = mpsc::unbounded();
- let initial_offline_state_tx = offline_state_tx.clone();
+ let initial_offline_state_tx = args.offline_state_tx.clone();
tokio::spawn(async move {
while let Some(offline) = offline_rx.next().await {
- if let Some(tx) = command_tx.upgrade() {
+ if let Some(tx) = args.command_tx.upgrade() {
let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline));
} else {
break;
}
- let _ = offline_state_tx.unbounded_send(offline);
+ let _ = args.offline_state_tx.unbounded_send(offline);
}
});
let mut offline_monitor = offline::spawn_monitor(
@@ -301,7 +322,7 @@ impl TunnelStateMachine {
#[cfg(windows)]
split_tunnel
- .set_paths_sync(&settings.exclude_paths)
+ .set_paths_sync(&args.settings.exclude_paths)
.map_err(Error::InitSplitTunneling)?;
let mut shared_values = SharedTunnelStateValues {
@@ -312,15 +333,15 @@ impl TunnelStateMachine {
dns_monitor,
route_manager,
_offline_monitor: offline_monitor,
- allow_lan: settings.allow_lan,
- block_when_disconnected: settings.block_when_disconnected,
+ allow_lan: args.settings.allow_lan,
+ block_when_disconnected: args.settings.block_when_disconnected,
is_offline,
- dns_servers: settings.dns_servers,
- allowed_endpoint: settings.allowed_endpoint,
- tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
- tun_provider: Arc::new(Mutex::new(tun_provider)),
- log_dir,
- resource_dir,
+ dns_servers: args.settings.dns_servers,
+ allowed_endpoint: args.settings.allowed_endpoint,
+ tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator),
+ tun_provider: Arc::new(Mutex::new(args.tun_provider)),
+ log_dir: args.log_dir,
+ resource_dir: args.resource_dir,
#[cfg(target_os = "linux")]
connectivity_check_was_enabled: None,
#[cfg(target_os = "macos")]
@@ -331,11 +352,11 @@ impl TunnelStateMachine {
tokio::task::spawn_blocking(move || {
let (initial_state, _) =
- DisconnectedState::enter(&mut shared_values, settings.reset_firewall);
+ DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall);
Ok(TunnelStateMachine {
current_state: Some(initial_state),
- commands: commands_rx.fuse(),
+ commands: args.commands_rx.fuse(),
shared_values,
})
})