summaryrefslogtreecommitdiffhomepage
path: root/talpid-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'talpid-core/src')
-rw-r--r--talpid-core/src/dns/linux/resolvconf.rs20
-rw-r--r--talpid-core/src/dns/linux/static_resolv_conf.rs8
-rw-r--r--talpid-core/src/mpsc.rs15
-rw-r--r--talpid-core/src/ping_monitor/icmp.rs2
-rw-r--r--talpid-core/src/routing/linux.rs57
-rw-r--r--talpid-core/src/routing/unix.rs3
-rw-r--r--talpid-core/src/tunnel/mod.rs46
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs143
-rw-r--r--talpid-core/src/tunnel/tun_provider/unix.rs6
-rw-r--r--talpid-core/src/tunnel/wireguard/logging.rs4
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs28
-rw-r--r--talpid-core/src/tunnel/wireguard/stats.rs12
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs56
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs16
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs16
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs103
16 files changed, 280 insertions, 255 deletions
diff --git a/talpid-core/src/dns/linux/resolvconf.rs b/talpid-core/src/dns/linux/resolvconf.rs
index ea2f3f3704..97db14b622 100644
--- a/talpid-core/src/dns/linux/resolvconf.rs
+++ b/talpid-core/src/dns/linux/resolvconf.rs
@@ -22,16 +22,16 @@ pub enum Error {
RunResolvconf(#[error(source)] io::Error),
#[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)]
- AddRecordError { stderr: String },
+ AddRecord { stderr: String },
#[error(display = "Using 'resolvconf' to delete a record failed")]
- DeleteRecordError,
+ DeleteRecord,
#[error(display = "Detected dnsmasq is runing and misconfigured")]
- DnsmasqMisconfigurationError,
+ DnsmasqMisconfiguration,
#[error(display = "Current /etc/resolv.conf is not generated by resolvconf")]
- ResolvconfNotInUseError,
+ ResolvconfNotInUse,
}
pub struct Resolvconf {
@@ -50,15 +50,15 @@ impl Resolvconf {
// Check if resolvconf is managing DNS by /etc/resolv.conf
if !is_dnsmasq_running
- && !(Self::check_if_resolvconf_is_symlinked_correctly()
- || Self::check_if_resolvconf_was_generated())
+ && !Self::check_if_resolvconf_is_symlinked_correctly()
+ && !Self::check_if_resolvconf_was_generated()
{
- return Err(Error::ResolvconfNotInUseError);
+ return Err(Error::ResolvconfNotInUse);
}
// Check if resolvconf can manage DNS via dnsmasq
if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() {
- return Err(Error::DnsmasqMisconfigurationError);
+ return Err(Error::DnsmasqMisconfiguration);
}
Ok(Resolvconf {
@@ -94,7 +94,7 @@ impl Resolvconf {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
- return Err(Error::AddRecordError { stderr });
+ return Err(Error::AddRecord { stderr });
}
self.record_names.insert(record_name);
@@ -118,7 +118,7 @@ impl Resolvconf {
record_name,
String::from_utf8_lossy(&output.stderr)
);
- result = Err(Error::DeleteRecordError);
+ result = Err(Error::DeleteRecord);
}
}
diff --git a/talpid-core/src/dns/linux/static_resolv_conf.rs b/talpid-core/src/dns/linux/static_resolv_conf.rs
index 196fb31003..691d7b468b 100644
--- a/talpid-core/src/dns/linux/static_resolv_conf.rs
+++ b/talpid-core/src/dns/linux/static_resolv_conf.rs
@@ -28,7 +28,7 @@ pub enum Error {
ReadResolvConf(&'static str, #[error(source)] io::Error),
#[error(display = "resolv.conf at {} could not be parsed", _0)]
- ParseError(&'static str, #[error(source)] resolv_conf::ParseError),
+ Parse(&'static str, #[error(source)] resolv_conf::ParseError),
#[error(display = "Failed to remove stale resolv.conf backup at {}", _0)]
RemoveBackup(&'static str, #[error(source)] io::Error),
@@ -179,7 +179,7 @@ fn read_config() -> Result<Config> {
let contents = fs::read_to_string(RESOLV_CONF_PATH)
.map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?;
- let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?;
+ let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?;
Ok(config)
}
@@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> {
match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) {
Ok(backup) => {
log::info!("Restoring DNS state from backup");
- let config = Config::parse(&backup)
- .map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?;
+ let config =
+ Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?;
write_config(&config)?;
diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs
index 8c6424bc01..6492796cfc 100644
--- a/talpid-core/src/mpsc.rs
+++ b/talpid-core/src/mpsc.rs
@@ -1,11 +1,20 @@
+/// Error type for `Sender` trait.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// The underlying channel is closed.
+ #[error(display = "Channel is closed")]
+ ChannelClosed,
+}
+
/// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`.
pub trait Sender<T> {
/// Sends an item over the underlying channel, failing only if the channel is closed.
- fn send(&self, item: T) -> Result<(), ()>;
+ fn send(&self, item: T) -> Result<(), Error>;
}
impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> {
- fn send(&self, content: E) -> Result<(), ()> {
- self.unbounded_send(content).map_err(|_| ())
+ fn send(&self, content: E) -> Result<(), Error> {
+ self.unbounded_send(content)
+ .map_err(|_| Error::ChannelClosed)
}
}
diff --git a/talpid-core/src/ping_monitor/icmp.rs b/talpid-core/src/ping_monitor/icmp.rs
index 67f5b70cb5..0bcd9da72f 100644
--- a/talpid-core/src/ping_monitor/icmp.rs
+++ b/talpid-core/src/ping_monitor/icmp.rs
@@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner(
let checksum = internet_checksum::checksum(buffer);
(&mut buffer[ICMP_CHECKSUM_OFFSET..])
- .write(&checksum)
+ .write_all(&checksum)
.unwrap();
true
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index 092ad6f52a..4b039fe9eb 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to open a netlink connection")]
- ConnectError(#[error(source)] io::Error),
+ Connect(#[error(source)] io::Error),
#[error(display = "Failed to bind netlink socket")]
- BindError(#[error(source)] io::Error),
+ Bind(#[error(source)] io::Error),
#[error(display = "Netlink error")]
- NetlinkError(#[error(source)] rtnetlink::Error),
+ Netlink(#[error(source)] rtnetlink::Error),
#[error(display = "Route without a valid node")]
InvalidRoute,
@@ -108,16 +108,16 @@ pub enum Error {
UnknownDeviceIndex(u32),
#[error(display = "Failed to get a route for the given IP address")]
- GetRouteError(#[error(source)] rtnetlink::Error),
+ GetRoute(#[error(source)] rtnetlink::Error),
#[error(display = "No netlink response for route query")]
- NoRouteError,
+ NoRoute,
#[error(display = "Route node was malformed")]
InvalidRouteNode,
#[error(display = "No link found")]
- LinkNotFoundError,
+ LinkNotFound,
/// Unable to create routing table for tagged connections and packets.
#[error(display = "Cannot find a free routing table ID")]
@@ -140,14 +140,11 @@ pub struct RouteManagerImpl {
impl RouteManagerImpl {
pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
let (mut connection, handle, messages) =
- rtnetlink::new_connection().map_err(Error::ConnectError)?;
+ rtnetlink::new_connection().map_err(Error::Connect)?;
let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY;
let addr = SocketAddr::new(0, mgroup_flags);
- connection
- .socket_mut()
- .bind(&addr)
- .map_err(Error::BindError)?;
+ connection.socket_mut().bind(&addr).map_err(Error::Bind)?;
tokio::spawn(connection);
@@ -179,11 +176,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@@ -236,7 +233,7 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
let mut rules = vec![];
@@ -246,7 +243,7 @@ impl RouteManagerImpl {
rules.push(rule);
}
NetlinkPayload::Error(error) => {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
_ => (),
}
@@ -260,12 +257,12 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
if error.to_io().kind() != io::ErrorKind::NotFound {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@@ -296,7 +293,7 @@ impl RouteManagerImpl {
) -> Result<BTreeMap<u32, NetworkInterface>> {
let mut link_map = BTreeMap::new();
let mut link_request = handle.link().get().execute();
- while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? {
+ while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? {
if let Some((idx, device)) = Self::map_interface(link) {
link_map.insert(idx, device);
}
@@ -543,7 +540,7 @@ impl RouteManagerImpl {
async fn delete_route_if_exists(&self, route: &Route) -> Result<()> {
if let Err(error) = self.delete_route(route).await {
- if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error {
+ if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error {
if msg.code == -libc::ESRCH {
return Ok(());
}
@@ -619,7 +616,7 @@ impl RouteManagerImpl {
.del(route_message)
.execute()
.await
- .map_err(Error::NetlinkError)
+ .map_err(Error::Netlink)
}
async fn add_route_direct(&mut self, route: Route) -> Result<()> {
@@ -693,11 +690,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
- let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
+ let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(err) = message.payload {
- return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err)));
+ return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err)));
}
}
Ok(())
@@ -759,7 +756,7 @@ impl RouteManagerImpl {
}
None => {
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
- return Err(Error::NoRouteError);
+ return Err(Error::NoRoute);
}
}
}
@@ -767,17 +764,13 @@ impl RouteManagerImpl {
"Retried {} times looking for the correct device and could not find it",
RECURSION_LIMIT
);
- Err(Error::NoRouteError)
+ Err(Error::NoRoute)
}
async fn get_device_mtu(&self, device: String) -> Result<u16> {
let mut links = self.handle.link().get().execute();
let target_device = LinkNla::IfName(device);
- while let Some(msg) = links
- .try_next()
- .await
- .map_err(|_| Error::LinkNotFoundError)?
- {
+ while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? {
let found = msg.nlas.iter().any(|e| *e == target_device);
if found {
if let Some(LinkNla::Mtu(mtu)) =
@@ -788,7 +781,7 @@ impl RouteManagerImpl {
}
}
}
- Err(Error::LinkNotFoundError)
+ Err(Error::LinkNotFound)
}
async fn get_destination_route(
@@ -813,11 +806,11 @@ impl RouteManagerImpl {
let mut stream = execute_route_get_request(self.handle.clone(), message.clone());
match stream.try_next().await {
Ok(Some(route_msg)) => self.parse_route_message(route_msg),
- Ok(None) => Err(Error::NoRouteError),
+ Ok(None) => Err(Error::NoRoute),
Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
Ok(None)
}
- Err(err) => Err(Error::GetRouteError(err)),
+ Err(err) => Err(Error::GetRoute(err)),
}
}
}
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index edfbdd2b85..326fb1fad1 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -19,16 +19,19 @@ use futures::stream::Stream;
#[cfg(target_os = "linux")]
use std::net::IpAddr;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
mod imp;
#[cfg(target_os = "macos")]
pub(crate) use imp::listen_for_default_route_changes;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "linux")]
#[path = "linux.rs"]
mod imp;
+#[allow(clippy::module_inception)]
#[cfg(target_os = "android")]
#[path = "android.rs"]
mod imp;
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 5da3c092a4..f6ada1c2cf 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -1,6 +1,6 @@
use self::tun_provider::TunProvider;
use crate::{logging, routing::RouteManagerHandle};
-use futures::channel::oneshot;
+use futures::{channel::oneshot, future::BoxFuture};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
@@ -98,6 +98,20 @@ pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
}
+/// Arguments for creating a tunnel.
+pub struct TunnelArgs<'a, L>
+where
+ // L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
+{
+ /// Resource directory.
+ pub resource_dir: &'a Path,
+ /// Callback function called when an event happens.
+ pub on_event: L,
+ /// Receiver oneshot channel for closing the tunnel.
+ pub tunnel_close_rx: oneshot::Receiver<()>,
+}
+
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
@@ -107,12 +121,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
tunnel_parameters: &mut TunnelParameters,
log_dir: &Option<PathBuf>,
- resource_dir: &Path,
- on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -129,9 +141,9 @@ impl TunnelMonitor {
TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel(
config,
log_file,
- resource_dir,
- on_event,
- tunnel_close_rx,
+ init_args.resource_dir,
+ init_args.on_event,
+ init_args.tunnel_close_rx,
#[cfg(target_os = "linux")]
route_manager,
)),
@@ -142,12 +154,10 @@ impl TunnelMonitor {
runtime,
config,
log_file,
- resource_dir,
- on_event,
tun_provider,
- route_manager,
retry_attempt,
- tunnel_close_rx,
+ route_manager,
+ init_args,
),
}
}
@@ -178,12 +188,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
params: &mut wireguard_types::TunnelParameters,
log: Option<PathBuf>,
- resource_dir: &Path,
- on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -211,12 +219,10 @@ impl TunnelMonitor {
None
},
log.as_deref(),
- resource_dir,
- on_event,
tun_provider,
- route_manager,
retry_attempt,
- tunnel_close_rx,
+ route_manager,
+ init_args,
)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index 910d5bb49e..9fdfb3e80b 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> {
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let openvpn_init_args = OpenVpnTunnelInitArgs {
+ event_server_abort_tx: event_server_abort_tx.clone(),
+ event_server_abort_rx,
+ plugin_path,
+ log_path,
+ user_pass_file,
+ proxy_auth_file,
+ proxy_monitor,
+ tunnel_close_rx,
+ };
Self::new_internal(
cmd,
- event_server_abort_tx.clone(),
- event_server_abort_rx,
+ openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
on_event,
user_pass_file_path: user_pass_file_path.clone(),
@@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(target_os = "linux")]
ipv6_enabled,
},
- plugin_path,
- log_path,
- user_pass_file,
- proxy_auth_file,
- proxy_monitor,
- tunnel_close_rx,
#[cfg(windows)]
Box::new(wintun),
)
@@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
Ok(routes)
}
+struct OpenVpnTunnelInitArgs {
+ event_server_abort_tx: triggered::Trigger,
+ event_server_abort_rx: triggered::Listener,
+ plugin_path: PathBuf,
+ log_path: Option<PathBuf>,
+ user_pass_file: mktemp::TempFile,
+ proxy_auth_file: Option<mktemp::TempFile>,
+ proxy_monitor: Option<Box<dyn ProxyMonitor>>,
+ tunnel_close_rx: oneshot::Receiver<()>,
+}
+
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
async fn new_internal<L>(
mut cmd: C,
- event_server_abort_tx: triggered::Trigger,
- event_server_abort_rx: triggered::Listener,
+ init_args: OpenVpnTunnelInitArgs,
on_event: L,
- plugin_path: PathBuf,
- log_path: Option<PathBuf>,
- user_pass_file: mktemp::TempFile,
- proxy_auth_file: Option<mktemp::TempFile>,
- proxy_monitor: Option<Box<dyn ProxyMonitor>>,
- tunnel_close_rx: oneshot::Receiver<()>,
#[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
L: event_server::OpenvpnEventProxy + Send + Sync + 'static,
{
+ let event_server_abort_tx = init_args.event_server_abort_tx;
+ let event_server_abort_rx = init_args.event_server_abort_rx;
+ let plugin_path = init_args.plugin_path;
+ let log_path = init_args.log_path;
+ let user_pass_file = init_args.user_pass_file;
+ let proxy_auth_file = init_args.proxy_auth_file;
+ let proxy_monitor = init_args.proxy_monitor;
+ let tunnel_close_rx = init_args.tunnel_close_rx;
+
let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx)
.await
.map_err(Error::EventDispatcherError)?;
@@ -1220,23 +1236,37 @@ mod tests {
.map_err(Error::RuntimeError)
}
+ fn create_init_args_plugin_log(
+ plugin_path: PathBuf,
+ log_path: Option<PathBuf>,
+ ) -> OpenVpnTunnelInitArgs {
+ let (_close_tx, close_rx) = oneshot::channel();
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ OpenVpnTunnelInitArgs {
+ event_server_abort_tx,
+ event_server_abort_rx,
+ plugin_path,
+ log_path,
+ user_pass_file: TempFile::new(),
+ proxy_auth_file: None,
+ proxy_monitor: None,
+ tunnel_close_rx: close_rx,
+ }
+ }
+
+ fn create_init_args() -> OpenVpnTunnelInitArgs {
+ create_init_args_plugin_log("".into(), None)
+ }
+
#[test]
fn sets_plugin() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None);
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "./my_test_plugin".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@@ -1249,20 +1279,13 @@ mod tests {
#[test]
fn sets_log() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args =
+ create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file")));
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- Some(PathBuf::from("./my_test_log_file")),
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@@ -1276,21 +1299,13 @@ mod tests {
fn exit_successfully() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(0));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1302,21 +1317,13 @@ mod tests {
fn exit_error() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1328,21 +1335,13 @@ mod tests {
fn wait_closed() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@@ -1354,21 +1353,13 @@ mod tests {
#[test]
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
- let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
+ let openvpn_init_args = create_init_args();
let result = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
- event_server_abort_tx,
- event_server_abort_rx,
+ openvpn_init_args,
TestOpenvpnEventProxy {},
- "".into(),
- None,
- TempFile::new(),
- None,
- None,
- close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
diff --git a/talpid-core/src/tunnel/tun_provider/unix.rs b/talpid-core/src/tunnel/tun_provider/unix.rs
index d8d3b7ce01..5c48a3c663 100644
--- a/talpid-core/src/tunnel/tun_provider/unix.rs
+++ b/talpid-core/src/tunnel/tun_provider/unix.rs
@@ -22,6 +22,12 @@ pub enum Error {
/// Factory of tunnel devices on Unix systems.
pub struct UnixTunProvider;
+impl Default for UnixTunProvider {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
impl UnixTunProvider {
pub fn new() -> Self {
UnixTunProvider
diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs
index 1a7a52ed8b..35ec10fc2f 100644
--- a/talpid-core/src/tunnel/wireguard/logging.rs
+++ b/talpid-core/src/tunnel/wireguard/logging.rs
@@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback(
let level = match level {
WG_GO_LOG_VERBOSE => LogLevel::Verbose,
- WG_GO_LOG_ERROR | _ => LogLevel::Error,
+ _ => LogLevel::Error,
};
log_inner(logfile, level, "wireguard-go", &managed_msg);
}
@@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback(
pub type WgLogLevel = u32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
// const WG_GO_LOG_SILENT: WgLogLevel = 0;
-const WG_GO_LOG_ERROR: WgLogLevel = 1;
+// const WG_GO_LOG_ERROR: WgLogLevel = 1;
const WG_GO_LOG_VERBOSE: WgLogLevel = 2;
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 7f17726c33..e49286cb30 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -1,15 +1,11 @@
use self::config::Config;
#[cfg(not(windows))]
use super::tun_provider;
-use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata};
+use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use crate::routing::{self, RequiredRoute, RouteManagerHandle};
+use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
#[cfg(windows)]
use futures::{channel::mpsc, StreamExt};
-use futures::{
- channel::oneshot,
- future::{abortable, AbortHandle as FutureAbortHandle},
- Future,
-};
#[cfg(target_os = "linux")]
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
@@ -54,6 +50,7 @@ mod wireguard_nt;
use self::wireguard_go::WgGoTunnel;
type Result<T> = std::result::Result<T, Error>;
+type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
/// Errors that can happen in the Wireguard tunnel monitor.
#[derive(err_derive::Error, Debug)]
@@ -104,12 +101,7 @@ pub struct WireguardMonitor {
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
- event_callback: Box<
- dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
- + Send
- + Sync
- + 'static,
- >,
+ event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
_obfuscator: Option<ObfuscatorHandle>,
@@ -208,13 +200,13 @@ impl WireguardMonitor {
mut config: Config,
psk_negotiation: Option<PublicKey>,
log_path: Option<&Path>,
- resource_dir: &Path,
- on_event: F,
tun_provider: Arc<Mutex<TunProvider>>,
- route_manager: RouteManagerHandle,
retry_attempt: u32,
- tunnel_close_rx: oneshot::Receiver<()>,
+ route_manager: RouteManagerHandle,
+ init_args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
+ let on_event = init_args.on_event;
+
let endpoint_addrs: Vec<IpAddr> =
config.peers.iter().map(|peer| peer.endpoint.ip()).collect();
let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel();
@@ -228,7 +220,7 @@ impl WireguardMonitor {
runtime.clone(),
&Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
log_path,
- resource_dir,
+ init_args.resource_dir,
tun_provider,
#[cfg(target_os = "windows")]
setup_done_tx,
@@ -351,7 +343,7 @@ impl WireguardMonitor {
});
tokio::spawn(async move {
- if tunnel_close_rx.await.is_ok() {
+ if init_args.tunnel_close_rx.await.is_ok() {
monitor_handle.abort();
let _ = close_msg_sender.send(CloseMsg::Stop);
}
diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs
index bda8af2e1f..cec033f611 100644
--- a/talpid-core/src/tunnel/wireguard/stats.rs
+++ b/talpid-core/src/tunnel/wireguard/stats.rs
@@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla};
#[derive(err_derive::Error, Debug, PartialEq)]
pub enum Error {
#[error(display = "Failed to parse peer pubkey from string \"_0\"")]
- PubKeyParseError(String, #[error(source)] hex::FromHexError),
+ PubKeyParse(String, #[error(source)] hex::FromHexError),
#[error(display = "Failed to parse integer from string \"_0\"")]
- IntParseError(String, #[error(source)] std::num::ParseIntError),
+ IntParse(String, #[error(source)] std::num::ParseIntError),
#[error(display = "Device no longer exists")]
NoTunnelDevice,
@@ -47,7 +47,7 @@ impl Stats {
"public_key" => {
let mut buffer = [0u8; 32];
hex::decode_to_slice(value, &mut buffer)
- .map_err(|err| Error::PubKeyParseError(value.to_string(), err))?;
+ .map_err(|err| Error::PubKeyParse(value.to_string(), err))?;
peer = Some(buffer);
tx_bytes = None;
rx_bytes = None;
@@ -57,7 +57,7 @@ impl Stats {
value
.trim()
.parse()
- .map_err(|err| Error::IntParseError(value.to_string(), err))?,
+ .map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
"tx_bytes" => {
@@ -65,7 +65,7 @@ impl Stats {
value
.trim()
.parse()
- .map_err(|err| Error::IntParseError(value.to_string(), err))?,
+ .map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
@@ -145,7 +145,7 @@ mod test {
assert_eq!(
Stats::parse_config_str(invalid_input),
- Err(Error::IntParseError(invalid_str, int_err))
+ Err(Error::IntParse(invalid_str, int_err))
);
}
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
index 0f3866500e..5b7b6a1e12 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs
@@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to decode netlink message")]
- DecodeError(#[error(source)] DecodeError),
+ Decode(#[error(source)] DecodeError),
#[error(display = "Failed to execute netlink control request")]
- NetlinkControlMessageError(#[error(source)] nl_message::Error),
+ NetlinkControlMessage(#[error(source)] nl_message::Error),
#[error(display = "Failed to open netlink socket")]
- NetlinkSocketError(#[error(source)] std::io::Error),
+ NetlinkSocket(#[error(source)] std::io::Error),
#[error(display = "Failed to send netlink control request")]
- NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
+ NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
#[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")]
WireguardNetlinkInterfaceUnavailable,
@@ -60,25 +60,25 @@ pub enum Error {
NoDevice,
#[error(display = "Failed to get config: _0")]
- WgGetConfError(netlink_packet_core::error::ErrorMessage),
+ WgGetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Failed to apply config: _0")]
- WgSetConfError(netlink_packet_core::error::ErrorMessage),
+ WgSetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Interface name too long")]
- InterfaceNameError,
+ InterfaceName,
#[error(display = "Send request error")]
- SendRequestError(#[error(source)] NetlinkError<DeviceMessage>),
+ SendRequest(#[error(source)] NetlinkError<DeviceMessage>),
#[error(display = "Create device error")]
- NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error),
+ NetlinkCreateDevice(#[error(source)] rtnetlink::Error),
#[error(display = "Add IP to device error")]
- NetlinkSetIpError(rtnetlink::Error),
+ NetlinkSetIp(rtnetlink::Error),
#[error(display = "Failed to delete device")]
- DeleteDeviceError(#[error(source)] rtnetlink::Error),
+ DeleteDevice(#[error(source)] rtnetlink::Error),
#[error(display = "NetworkManager error")]
NetworkManager(#[error(source)] nm_tunnel::Error),
@@ -98,7 +98,7 @@ impl Handle {
pub async fn connect() -> Result<Self, Error> {
let message_type = Self::get_wireguard_message_type().await?;
let (conn, wireguard_connection, _messages) =
- netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
+ netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let wg_handle = WireguardConnection {
message_type,
connection: wireguard_connection,
@@ -106,7 +106,7 @@ impl Handle {
let (abortable_connection, wg_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
let (conn, route_handle, _messages) =
- rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?;
+ rtnetlink::new_connection().map_err(Error::NetlinkSocket)?;
let (abortable_connection, route_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
@@ -120,21 +120,21 @@ impl Handle {
async fn get_wireguard_message_type() -> Result<u16, Error> {
let (conn, mut handle, _messages) =
- netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
+ netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let (conn, abort_handle) = abortable(conn);
tokio::spawn(conn);
let result = async move {
let mut message: NetlinkMessage<NetlinkControlMessage> =
NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap())
- .map_err(Error::NetlinkControlMessageError)?
+ .map_err(Error::NetlinkControlMessage)?
.into();
message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
let mut req = handle
.request(message, SocketAddr::new(0, 0))
- .map_err(Error::NetlinkRequestError)?;
+ .map_err(Error::NetlinkRequest)?;
let response = req.next().await;
if let Some(response) = response {
if let NetlinkPayload::InnerMessage(msg) = response.payload {
@@ -177,14 +177,14 @@ impl Handle {
let mut response = self
.route_handle
.request(add_request)
- .map_err(Error::NetlinkCreateDeviceError)?;
+ .map_err(Error::NetlinkCreateDevice)?;
while let Some(response_message) = response.next().await {
if let NetlinkPayload::Error(err) = response_message.payload {
// if the device exists, verify that it's a wireguard device
if -err.code != libc::EEXIST {
- return Err(Error::NetlinkCreateDeviceError(
- rtnetlink::Error::NetlinkError(err),
- ));
+ return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError(
+ err,
+ )));
}
}
}
@@ -208,9 +208,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
- .map_err(Error::NetlinkSetIpError)?;
+ .map_err(Error::NetlinkSetIp)?;
while let Some(response_message) = response.next().await {
- consume_netlink_error(response_message, Error::NetlinkSetIpError)?;
+ consume_netlink_error(response_message, Error::NetlinkSetIp)?;
}
Ok(())
@@ -226,9 +226,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
- .map_err(Error::DeleteDeviceError)?;
+ .map_err(Error::DeleteDevice)?;
while let Some(message) = response.next().await {
- consume_netlink_error(message, Error::DeleteDeviceError)?;
+ consume_netlink_error(message, Error::DeleteDevice)?;
}
Ok(())
@@ -269,7 +269,7 @@ impl WireguardConnection {
let mut response = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
- .map_err(Error::SendRequestError)?;
+ .map_err(Error::SendRequest)?;
match response.next().await {
Some(received_message) => match received_message.payload {
NetlinkPayload::InnerMessage(inner) => Ok(inner),
@@ -277,7 +277,7 @@ impl WireguardConnection {
if err.code == -libc::ENODEV {
Err(Error::NoDevice)
} else {
- Err(Error::WgGetConfError(err))
+ Err(Error::WgGetConf(err))
}
}
anything_else => {
@@ -297,11 +297,11 @@ impl WireguardConnection {
let mut request = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
- .map_err(Error::SendRequestError)?;
+ .map_err(Error::SendRequest)?;
while let Some(response) = request.next().await {
if let NetlinkPayload::Error(err) = response.payload {
- return Err(Error::WgSetConfError(err));
+ return Err(Error::WgSetConf(err));
}
}
Ok(())
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
index be2231f771..f2de334762 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
@@ -110,9 +110,9 @@ impl DeviceMessage {
}
pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> {
- let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?;
+ let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?;
if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ {
- return Err(Error::InterfaceNameError);
+ return Err(Error::InterfaceName);
}
Ok(Self {
@@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage {
let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..];
let mut nlas = vec![];
for buf in NlasIterator::new(new_payload) {
- nlas.push(
- DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?,
- );
+ nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?);
}
Ok(DeviceMessage {
@@ -391,13 +389,13 @@ impl Nla for PeerNla {
InetAddr::V4(sockaddr_in) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(sockaddr_in) })
+ .write_all(unsafe { struct_as_slice(sockaddr_in) })
.expect("Buffer too small for sockaddr_in");
}
InetAddr::V6(sockaddr_in6) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(sockaddr_in6) })
+ .write_all(unsafe { struct_as_slice(sockaddr_in6) })
.expect("Buffer too small for sockaddr_in6");
}
},
@@ -408,7 +406,7 @@ impl Nla for PeerNla {
let timespec: &libc::timespec = last_handshake.as_ref();
// SAFETY: `timespec` has no padding bytes
buffer
- .write(unsafe { struct_as_slice(timespec) })
+ .write_all(unsafe { struct_as_slice(timespec) })
.expect("Buffer too small for timespec");
}
RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes),
@@ -535,7 +533,7 @@ impl Nla for AllowedIpNla {
}
IpAddr(ip_addr) => {
buffer
- .write(&ip_addr_to_bytes(ip_addr))
+ .write_all(&ip_addr_to_bytes(ip_addr))
.expect("Buffer too small for AllowedIpNla::IpAddr");
}
CidrMask(cidr_mask) => buffer[0] = *cidr_mask,
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 7536b26b09..e787729c04 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -6,7 +6,9 @@ use super::{
use crate::{
firewall::FirewallPolicy,
routing::RouteManager,
- tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor},
+ tunnel::{
+ self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor,
+ },
};
use cfg_if::cfg_if;
use futures::{
@@ -142,16 +144,20 @@ impl ConnectingState {
}
};
+ let init_args = TunnelArgs {
+ resource_dir: &resource_dir,
+ on_event: on_tunnel_event,
+ tunnel_close_rx,
+ };
+
let block_reason = match TunnelMonitor::start(
runtime,
&mut tunnel_parameters,
&log_dir,
- &resource_dir,
- on_tunnel_event,
tun_provider,
- route_manager_handle,
retry_attempt,
- tunnel_close_rx,
+ route_manager_handle,
+ init_args,
) {
Ok(monitor) => {
let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 3552eeab61..4c3eda0ecb 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -132,23 +132,25 @@ pub async fn spawn(
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let weak_command_tx = Arc::downgrade(&command_tx);
- let state_machine = TunnelStateMachine::new(
- initial_settings,
- weak_command_tx,
- offline_state_listener,
+
+ let init_args = TunnelStateMachineInitArgs {
+ settings: initial_settings,
+ command_tx: weak_command_tx,
+ offline_state_tx: offline_state_listener,
tunnel_parameters_generator,
tun_provider,
log_dir,
resource_dir,
- command_rx,
+ commands_rx: command_rx,
#[cfg(target_os = "windows")]
volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
android_context,
- )
- .await?;
+ };
+
+ let state_machine = TunnelStateMachine::new(init_args).await?;
#[cfg(windows)]
let split_tunnel = state_machine.shared_values.split_tunnel.handle();
@@ -219,20 +221,35 @@ struct TunnelStateMachine {
shared_values: SharedTunnelStateValues,
}
+/// Tunnel state machine initialization arguments arguments
+struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> {
+ settings: InitialTunnelState,
+ command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ offline_state_tx: mpsc::UnboundedSender<bool>,
+ tunnel_parameters_generator: G,
+ tun_provider: TunProvider,
+ log_dir: Option<PathBuf>,
+ resource_dir: PathBuf,
+ commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
+ #[cfg(target_os = "windows")]
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ #[cfg(target_os = "macos")]
+ exclusion_gid: u32,
+ #[cfg(target_os = "android")]
+ android_context: AndroidContext,
+}
+
impl TunnelStateMachine {
async fn new(
- settings: InitialTunnelState,
- command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
- offline_state_tx: mpsc::UnboundedSender<bool>,
- tunnel_parameters_generator: impl TunnelParametersGenerator,
- tun_provider: TunProvider,
- log_dir: Option<PathBuf>,
- resource_dir: PathBuf,
- commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
- #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
- #[cfg(target_os = "macos")] exclusion_gid: u32,
- #[cfg(target_os = "android")] android_context: AndroidContext,
+ args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>,
) -> Result<Self, Error> {
+ #[cfg(target_os = "windows")]
+ let volume_update_rx = args.volume_update_rx;
+ #[cfg(target_os = "macos")]
+ let exclusion_gid = args.exclusion_gid;
+ #[cfg(target_os = "android")]
+ let android_context = args.android_context;
+
let runtime = tokio::runtime::Handle::current();
#[cfg(target_os = "macos")]
@@ -242,20 +259,24 @@ impl TunnelStateMachine {
let power_mgmt_rx = crate::windows::window::PowerManagementListener::new();
#[cfg(windows)]
- let split_tunnel =
- split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
- .map_err(Error::InitSplitTunneling)?;
+ let split_tunnel = split_tunnel::SplitTunnel::new(
+ runtime.clone(),
+ args.command_tx.clone(),
+ volume_update_rx,
+ )
+ .map_err(Error::InitSplitTunneling)?;
- let args = FirewallArguments {
- initial_state: if settings.block_when_disconnected || !settings.reset_firewall {
- InitialFirewallState::Blocked(settings.allowed_endpoint.clone())
+ let fw_args = FirewallArguments {
+ initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall
+ {
+ InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone())
} else {
InitialFirewallState::None
},
- allow_lan: settings.allow_lan,
+ allow_lan: args.settings.allow_lan,
};
- let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?;
+ let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?;
let route_manager = RouteManager::new(HashSet::new())
.await
.map_err(Error::InitRouteManagerError)?;
@@ -267,20 +288,20 @@ impl TunnelStateMachine {
.handle()
.map_err(Error::InitRouteManagerError)?,
#[cfg(target_os = "macos")]
- command_tx.clone(),
+ args.command_tx.clone(),
)
.map_err(Error::InitDnsMonitorError)?;
let (offline_tx, mut offline_rx) = mpsc::unbounded();
- let initial_offline_state_tx = offline_state_tx.clone();
+ let initial_offline_state_tx = args.offline_state_tx.clone();
tokio::spawn(async move {
while let Some(offline) = offline_rx.next().await {
- if let Some(tx) = command_tx.upgrade() {
+ if let Some(tx) = args.command_tx.upgrade() {
let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline));
} else {
break;
}
- let _ = offline_state_tx.unbounded_send(offline);
+ let _ = args.offline_state_tx.unbounded_send(offline);
}
});
let mut offline_monitor = offline::spawn_monitor(
@@ -301,7 +322,7 @@ impl TunnelStateMachine {
#[cfg(windows)]
split_tunnel
- .set_paths_sync(&settings.exclude_paths)
+ .set_paths_sync(&args.settings.exclude_paths)
.map_err(Error::InitSplitTunneling)?;
let mut shared_values = SharedTunnelStateValues {
@@ -312,15 +333,15 @@ impl TunnelStateMachine {
dns_monitor,
route_manager,
_offline_monitor: offline_monitor,
- allow_lan: settings.allow_lan,
- block_when_disconnected: settings.block_when_disconnected,
+ allow_lan: args.settings.allow_lan,
+ block_when_disconnected: args.settings.block_when_disconnected,
is_offline,
- dns_servers: settings.dns_servers,
- allowed_endpoint: settings.allowed_endpoint,
- tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
- tun_provider: Arc::new(Mutex::new(tun_provider)),
- log_dir,
- resource_dir,
+ dns_servers: args.settings.dns_servers,
+ allowed_endpoint: args.settings.allowed_endpoint,
+ tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator),
+ tun_provider: Arc::new(Mutex::new(args.tun_provider)),
+ log_dir: args.log_dir,
+ resource_dir: args.resource_dir,
#[cfg(target_os = "linux")]
connectivity_check_was_enabled: None,
#[cfg(target_os = "macos")]
@@ -331,11 +352,11 @@ impl TunnelStateMachine {
tokio::task::spawn_blocking(move || {
let (initial_state, _) =
- DisconnectedState::enter(&mut shared_values, settings.reset_firewall);
+ DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall);
Ok(TunnelStateMachine {
current_state: Some(initial_state),
- commands: commands_rx.fuse(),
+ commands: args.commands_rx.fuse(),
shared_values,
})
})