diff options
Diffstat (limited to 'talpid-core/src')
| -rw-r--r-- | talpid-core/src/firewall/linux.rs | 40 | ||||
| -rw-r--r-- | talpid-core/src/firewall/macos.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows.rs | 23 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 16 |
4 files changed, 30 insertions, 61 deletions
diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 24d52eb718..cf4fb267dc 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -12,9 +12,9 @@ use std::{ env, ffi::{CStr, CString}, io, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr}, }; -use talpid_types::net::{AllowedTunnelTraffic, Endpoint, Protocol, TransportProtocol}; +use talpid_types::net::{AllowedTunnelTraffic, Endpoint, TransportProtocol}; /// Priority for rules that tag split tunneling packets. Equals NF_IP_PRI_MANGLE. const MANGLE_CHAIN_PRIORITY: i32 = libc::NF_IP_PRI_MANGLE; @@ -573,12 +573,8 @@ impl<'a> PolicyBatch<'a> { self.add_allow_tunnel_rules(&tunnel.interface)?; } AllowedTunnelTraffic::None => (), - AllowedTunnelTraffic::Only(address, protocol) => { - self.add_allow_in_tunnel_endpoint_rules( - &tunnel.interface, - *address, - *protocol, - )?; + AllowedTunnelTraffic::Only(endpoint) => { + self.add_allow_in_tunnel_endpoint_rules(&tunnel.interface, endpoint)?; } } if *allow_lan { @@ -787,26 +783,16 @@ impl<'a> PolicyBatch<'a> { fn add_allow_in_tunnel_endpoint_rules( &mut self, tunnel_interface: &str, - address: SocketAddr, - protocol: Protocol, + endpoint: &Endpoint, ) -> Result<()> { for (chain, dir, end) in [ (&self.out_chain, Direction::Out, End::Dst), (&self.in_chain, Direction::In, End::Src), ] { let mut rule = Rule::new(chain); - check_iface(&mut rule, dir, tunnel_interface)?; - check_ip(&mut rule, end, address.ip()); - match protocol { - Protocol::IcmpV4 | Protocol::IcmpV6 => check_l4proto(&mut rule, protocol), - Protocol::Tcp => { - check_port(&mut rule, TransportProtocol::Tcp, end, address.port()); - } - Protocol::Udp => { - check_port(&mut rule, TransportProtocol::Udp, end, address.port()); - } - } + check_ip(&mut rule, end, endpoint.address.ip()); + check_port(&mut rule, endpoint.protocol, end, endpoint.address.port()); add_verdict(&mut rule, &Verdict::Accept); self.batch.add(&rule, nftnl::MsgType::Add); } @@ -1030,7 +1016,7 @@ fn check_ip(rule: &mut Rule<'_>, end: End, ip: impl Into<IpAddr>) { fn check_port(rule: &mut Rule<'_>, protocol: TransportProtocol, end: End, port: u16) { // Must check transport layer protocol before loading transport layer payload - check_l4proto(rule, protocol.into()); + check_l4proto(rule, protocol); rule.add_expr(&match (protocol, end) { (TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport), @@ -1053,17 +1039,15 @@ fn l3proto(addr: IpAddr) -> u8 { } } -fn check_l4proto(rule: &mut Rule<'_>, protocol: Protocol) { +fn check_l4proto(rule: &mut Rule<'_>, protocol: TransportProtocol) { rule.add_expr(&nft_expr!(meta l4proto)); rule.add_expr(&nft_expr!(cmp == l4proto(protocol))); } -fn l4proto(protocol: Protocol) -> u8 { +fn l4proto(protocol: TransportProtocol) -> u8 { match protocol { - Protocol::Udp => libc::IPPROTO_UDP as u8, - Protocol::Tcp => libc::IPPROTO_TCP as u8, - Protocol::IcmpV4 => libc::IPPROTO_ICMP as u8, - Protocol::IcmpV6 => libc::IPPROTO_ICMPV6 as u8, + TransportProtocol::Udp => libc::IPPROTO_UDP as u8, + TransportProtocol::Tcp => libc::IPPROTO_TCP as u8, } } diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs index fa40ae7e0d..43b8e81bc1 100644 --- a/talpid-core/src/firewall/macos.rs +++ b/talpid-core/src/firewall/macos.rs @@ -340,15 +340,9 @@ impl Firewall { .keep_state(pfctl::StatePolicy::Keep) .tcp_flags(Self::get_tcp_flags()); match allowed_traffic { - AllowedTunnelTraffic::Only(addr, protocol) => { - use talpid_types::net::Protocol::*; - let pfctl_proto = match protocol { - Udp => pfctl::Proto::Udp, - Tcp => pfctl::Proto::Tcp, - IcmpV4 => pfctl::Proto::Icmp, - IcmpV6 => pfctl::Proto::IcmpV6, - }; - base_rule = base_rule.to(*addr).proto(pfctl_proto); + AllowedTunnelTraffic::Only(endpoint) => { + let pfctl_proto = as_pfctl_proto(endpoint.protocol); + base_rule = base_rule.to(endpoint.address).proto(pfctl_proto); } AllowedTunnelTraffic::All => {} AllowedTunnelTraffic::None => { diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 4683924d17..471e464c0a 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -174,12 +174,12 @@ impl Firewall { let allowed_tun_ip; let allowed_tunnel_endpoint = - if let AllowedTunnelTraffic::Only(addr, proto) = allowed_tunnel_traffic { - allowed_tun_ip = widestring_ip(addr.ip()); + if let AllowedTunnelTraffic::Only(endpoint) = allowed_tunnel_traffic { + allowed_tun_ip = widestring_ip(endpoint.address.ip()); Some(WinFwEndpoint { ip: allowed_tun_ip.as_ptr(), - port: addr.port(), - protocol: WinFwProt::from(*proto), + port: endpoint.address.port(), + protocol: WinFwProt::from(endpoint.protocol), }) } else { None @@ -303,7 +303,7 @@ mod winfw { use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString}; use crate::logging::windows::LogSink; use libc; - use talpid_types::net::{Protocol, TransportProtocol}; + use talpid_types::net::TransportProtocol; pub struct WinFwAllowedEndpointContainer { _clients: Box<[WideCString]>, @@ -397,8 +397,6 @@ mod winfw { pub enum WinFwProt { Tcp = 0u8, Udp = 1u8, - IcmpV4 = 2u8, - IcmpV6 = 3u8, } impl From<TransportProtocol> for WinFwProt { @@ -410,17 +408,6 @@ mod winfw { } } - impl From<Protocol> for WinFwProt { - fn from(prot: Protocol) -> WinFwProt { - match prot { - Protocol::Tcp => WinFwProt::Tcp, - Protocol::Udp => WinFwProt::Udp, - Protocol::IcmpV4 => WinFwProt::IcmpV4, - Protocol::IcmpV6 => WinFwProt::IcmpV6, - } - } - } - #[repr(C)] pub struct WinFwSettings { permitDhcp: bool, diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index b5ee282b7e..7f17726c33 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -21,7 +21,7 @@ use std::io; use std::{ borrow::Cow, convert::Infallible, - net::{IpAddr, SocketAddrV4}, + net::IpAddr, path::Path, pin::Pin, sync::{mpsc as sync_mpsc, Arc, Mutex}, @@ -30,7 +30,10 @@ use std::{ #[cfg(windows)] use talpid_types::BoxedError; use talpid_types::{ - net::{obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Protocol}, + net::{ + obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Endpoint, + TransportProtocol, + }, ErrorExt, }; use tunnel_obfuscation::{ @@ -262,10 +265,11 @@ impl WireguardMonitor { .await?; let allowed_traffic = if psk_negotiation.is_some() { - AllowedTunnelTraffic::Only( - SocketAddrV4::new(config.ipv4_gateway, 1337).into(), - Protocol::Tcp, - ) + AllowedTunnelTraffic::Only(Endpoint::new( + config.ipv4_gateway, + talpid_tunnel_config_client::CONFIG_SERVICE_PORT, + TransportProtocol::Tcp, + )) } else { AllowedTunnelTraffic::All }; |
