diff options
| -rw-r--r-- | talpid-core/src/firewall/linux.rs | 40 | ||||
| -rw-r--r-- | talpid-core/src/firewall/macos.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows.rs | 23 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 16 | ||||
| -rw-r--r-- | talpid-tunnel-config-client/src/lib.rs | 4 | ||||
| -rw-r--r-- | talpid-types/src/net/mod.rs | 33 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp | 10 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp | 10 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/rules/shared.cpp | 19 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/rules/shared.h | 2 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/winfw.h | 2 |
11 files changed, 39 insertions, 132 deletions
diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 24d52eb718..cf4fb267dc 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -12,9 +12,9 @@ use std::{ env, ffi::{CStr, CString}, io, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr}, }; -use talpid_types::net::{AllowedTunnelTraffic, Endpoint, Protocol, TransportProtocol}; +use talpid_types::net::{AllowedTunnelTraffic, Endpoint, TransportProtocol}; /// Priority for rules that tag split tunneling packets. Equals NF_IP_PRI_MANGLE. const MANGLE_CHAIN_PRIORITY: i32 = libc::NF_IP_PRI_MANGLE; @@ -573,12 +573,8 @@ impl<'a> PolicyBatch<'a> { self.add_allow_tunnel_rules(&tunnel.interface)?; } AllowedTunnelTraffic::None => (), - AllowedTunnelTraffic::Only(address, protocol) => { - self.add_allow_in_tunnel_endpoint_rules( - &tunnel.interface, - *address, - *protocol, - )?; + AllowedTunnelTraffic::Only(endpoint) => { + self.add_allow_in_tunnel_endpoint_rules(&tunnel.interface, endpoint)?; } } if *allow_lan { @@ -787,26 +783,16 @@ impl<'a> PolicyBatch<'a> { fn add_allow_in_tunnel_endpoint_rules( &mut self, tunnel_interface: &str, - address: SocketAddr, - protocol: Protocol, + endpoint: &Endpoint, ) -> Result<()> { for (chain, dir, end) in [ (&self.out_chain, Direction::Out, End::Dst), (&self.in_chain, Direction::In, End::Src), ] { let mut rule = Rule::new(chain); - check_iface(&mut rule, dir, tunnel_interface)?; - check_ip(&mut rule, end, address.ip()); - match protocol { - Protocol::IcmpV4 | Protocol::IcmpV6 => check_l4proto(&mut rule, protocol), - Protocol::Tcp => { - check_port(&mut rule, TransportProtocol::Tcp, end, address.port()); - } - Protocol::Udp => { - check_port(&mut rule, TransportProtocol::Udp, end, address.port()); - } - } + check_ip(&mut rule, end, endpoint.address.ip()); + check_port(&mut rule, endpoint.protocol, end, endpoint.address.port()); add_verdict(&mut rule, &Verdict::Accept); self.batch.add(&rule, nftnl::MsgType::Add); } @@ -1030,7 +1016,7 @@ fn check_ip(rule: &mut Rule<'_>, end: End, ip: impl Into<IpAddr>) { fn check_port(rule: &mut Rule<'_>, protocol: TransportProtocol, end: End, port: u16) { // Must check transport layer protocol before loading transport layer payload - check_l4proto(rule, protocol.into()); + check_l4proto(rule, protocol); rule.add_expr(&match (protocol, end) { (TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport), @@ -1053,17 +1039,15 @@ fn l3proto(addr: IpAddr) -> u8 { } } -fn check_l4proto(rule: &mut Rule<'_>, protocol: Protocol) { +fn check_l4proto(rule: &mut Rule<'_>, protocol: TransportProtocol) { rule.add_expr(&nft_expr!(meta l4proto)); rule.add_expr(&nft_expr!(cmp == l4proto(protocol))); } -fn l4proto(protocol: Protocol) -> u8 { +fn l4proto(protocol: TransportProtocol) -> u8 { match protocol { - Protocol::Udp => libc::IPPROTO_UDP as u8, - Protocol::Tcp => libc::IPPROTO_TCP as u8, - Protocol::IcmpV4 => libc::IPPROTO_ICMP as u8, - Protocol::IcmpV6 => libc::IPPROTO_ICMPV6 as u8, + TransportProtocol::Udp => libc::IPPROTO_UDP as u8, + TransportProtocol::Tcp => libc::IPPROTO_TCP as u8, } } diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs index fa40ae7e0d..43b8e81bc1 100644 --- a/talpid-core/src/firewall/macos.rs +++ b/talpid-core/src/firewall/macos.rs @@ -340,15 +340,9 @@ impl Firewall { .keep_state(pfctl::StatePolicy::Keep) .tcp_flags(Self::get_tcp_flags()); match allowed_traffic { - AllowedTunnelTraffic::Only(addr, protocol) => { - use talpid_types::net::Protocol::*; - let pfctl_proto = match protocol { - Udp => pfctl::Proto::Udp, - Tcp => pfctl::Proto::Tcp, - IcmpV4 => pfctl::Proto::Icmp, - IcmpV6 => pfctl::Proto::IcmpV6, - }; - base_rule = base_rule.to(*addr).proto(pfctl_proto); + AllowedTunnelTraffic::Only(endpoint) => { + let pfctl_proto = as_pfctl_proto(endpoint.protocol); + base_rule = base_rule.to(endpoint.address).proto(pfctl_proto); } AllowedTunnelTraffic::All => {} AllowedTunnelTraffic::None => { diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 4683924d17..471e464c0a 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -174,12 +174,12 @@ impl Firewall { let allowed_tun_ip; let allowed_tunnel_endpoint = - if let AllowedTunnelTraffic::Only(addr, proto) = allowed_tunnel_traffic { - allowed_tun_ip = widestring_ip(addr.ip()); + if let AllowedTunnelTraffic::Only(endpoint) = allowed_tunnel_traffic { + allowed_tun_ip = widestring_ip(endpoint.address.ip()); Some(WinFwEndpoint { ip: allowed_tun_ip.as_ptr(), - port: addr.port(), - protocol: WinFwProt::from(*proto), + port: endpoint.address.port(), + protocol: WinFwProt::from(endpoint.protocol), }) } else { None @@ -303,7 +303,7 @@ mod winfw { use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString}; use crate::logging::windows::LogSink; use libc; - use talpid_types::net::{Protocol, TransportProtocol}; + use talpid_types::net::TransportProtocol; pub struct WinFwAllowedEndpointContainer { _clients: Box<[WideCString]>, @@ -397,8 +397,6 @@ mod winfw { pub enum WinFwProt { Tcp = 0u8, Udp = 1u8, - IcmpV4 = 2u8, - IcmpV6 = 3u8, } impl From<TransportProtocol> for WinFwProt { @@ -410,17 +408,6 @@ mod winfw { } } - impl From<Protocol> for WinFwProt { - fn from(prot: Protocol) -> WinFwProt { - match prot { - Protocol::Tcp => WinFwProt::Tcp, - Protocol::Udp => WinFwProt::Udp, - Protocol::IcmpV4 => WinFwProt::IcmpV4, - Protocol::IcmpV6 => WinFwProt::IcmpV6, - } - } - } - #[repr(C)] pub struct WinFwSettings { permitDhcp: bool, diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index b5ee282b7e..7f17726c33 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -21,7 +21,7 @@ use std::io; use std::{ borrow::Cow, convert::Infallible, - net::{IpAddr, SocketAddrV4}, + net::IpAddr, path::Path, pin::Pin, sync::{mpsc as sync_mpsc, Arc, Mutex}, @@ -30,7 +30,10 @@ use std::{ #[cfg(windows)] use talpid_types::BoxedError; use talpid_types::{ - net::{obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Protocol}, + net::{ + obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Endpoint, + TransportProtocol, + }, ErrorExt, }; use tunnel_obfuscation::{ @@ -262,10 +265,11 @@ impl WireguardMonitor { .await?; let allowed_traffic = if psk_negotiation.is_some() { - AllowedTunnelTraffic::Only( - SocketAddrV4::new(config.ipv4_gateway, 1337).into(), - Protocol::Tcp, - ) + AllowedTunnelTraffic::Only(Endpoint::new( + config.ipv4_gateway, + talpid_tunnel_config_client::CONFIG_SERVICE_PORT, + TransportProtocol::Tcp, + )) } else { AllowedTunnelTraffic::All }; diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs index fafb95affb..9ab7f6b9d0 100644 --- a/talpid-tunnel-config-client/src/lib.rs +++ b/talpid-tunnel-config-client/src/lib.rs @@ -41,7 +41,9 @@ impl std::error::Error for Error { type RelayConfigService = proto::post_quantum_secure_client::PostQuantumSecureClient<Channel>; -const CONFIG_SERVICE_PORT: u16 = 1337; +/// Port used by the tunnel config service. +pub const CONFIG_SERVICE_PORT: u16 = 1337; + const ALGORITHM_NAME: &str = "Classic-McEliece-8192128f"; /// Generates a new WireGuard key pair and negotiates a PSK with the relay in a PQ-safe diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 16a30bc57d..6c62b02928 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -287,7 +287,7 @@ impl fmt::Display for AllowedEndpoint { pub enum AllowedTunnelTraffic { None, All, - Only(SocketAddr, Protocol), + Only(Endpoint), } impl fmt::Display for AllowedTunnelTraffic { @@ -295,36 +295,7 @@ impl fmt::Display for AllowedTunnelTraffic { match *self { AllowedTunnelTraffic::None => "None".fmt(f), AllowedTunnelTraffic::All => "All".fmt(f), - AllowedTunnelTraffic::Only(addr, proto) => write!(f, "{}/{}", addr, proto), - } - } -} - -/// A protocol: UDP, TCP, or ICMP. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub enum Protocol { - Udp, - Tcp, - IcmpV4, - IcmpV6, -} - -impl fmt::Display for Protocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - Protocol::Udp => "UDP".fmt(f), - Protocol::Tcp => "TCP".fmt(f), - Protocol::IcmpV4 => "ICMPv4".fmt(f), - Protocol::IcmpV6 => "ICMPv6".fmt(f), - } - } -} - -impl From<TransportProtocol> for Protocol { - fn from(proto: TransportProtocol) -> Self { - match proto { - TransportProtocol::Udp => Protocol::Udp, - TransportProtocol::Tcp => Protocol::Tcp, + AllowedTunnelTraffic::Only(endpoint) => endpoint.fmt(f), } } } diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp index d9a1af0f28..9c45d63c92 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp @@ -54,10 +54,7 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller) if (m_tunnelOnlyEndpoint.has_value()) { conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); - if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) - { - conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); - } + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); } @@ -85,10 +82,7 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller) if (m_tunnelOnlyEndpoint.has_value()) { conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); - if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) - { - conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); - } + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); } diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp index 42214b6a77..a4ff6a65e5 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp @@ -54,10 +54,7 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller) if (m_tunnelOnlyEndpoint.has_value()) { conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); - if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) - { - conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); - } + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); } @@ -84,10 +81,7 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller) if (m_tunnelOnlyEndpoint.has_value()) { conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); - if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) - { - conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); - } + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); } diff --git a/windows/winfw/src/winfw/rules/shared.cpp b/windows/winfw/src/winfw/rules/shared.cpp index 1d1123e3eb..0ed80bbd70 100644 --- a/windows/winfw/src/winfw/rules/shared.cpp +++ b/windows/winfw/src/winfw/rules/shared.cpp @@ -45,25 +45,6 @@ std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinF { case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); case WinFwProtocol::Udp: return ConditionProtocol::Udp(); - case WinFwProtocol::Icmp: return ConditionProtocol::Icmp(); - case WinFwProtocol::IcmpV6: return ConditionProtocol::IcmpV6(); - default: - { - THROW_ERROR("Missing case handler in switch clause"); - } - }; -} - -bool ProtocolHasPort(WinFwProtocol protocol) -{ - switch (protocol) - { - case WinFwProtocol::Tcp: - case WinFwProtocol::Udp: - return true; - case WinFwProtocol::Icmp: - case WinFwProtocol::IcmpV6: - return false; default: { THROW_ERROR("Missing case handler in switch clause"); diff --git a/windows/winfw/src/winfw/rules/shared.h b/windows/winfw/src/winfw/rules/shared.h index 4f4da187ca..1fd55cb548 100644 --- a/windows/winfw/src/winfw/rules/shared.h +++ b/windows/winfw/src/winfw/rules/shared.h @@ -15,6 +15,4 @@ void SplitAddresses(const IpSet &in, IpSet &outIpv4, IpSet &outIpv6); std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol); -bool ProtocolHasPort(WinFwProtocol protocol); - } diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index 6394893d91..7a7a1ca9e2 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -33,8 +33,6 @@ enum WinFwProtocol : uint8_t { Tcp = 0, Udp = 1, - Icmp = 2, - IcmpV6 = 3 }; typedef struct tag_WinFwEndpoint |
