diff options
| -rw-r--r-- | talpid-core/src/firewall/linux/mod.rs | 44 |
1 files changed, 19 insertions, 25 deletions
diff --git a/talpid-core/src/firewall/linux/mod.rs b/talpid-core/src/firewall/linux/mod.rs index abe2e1d920..25e38f99cc 100644 --- a/talpid-core/src/firewall/linux/mod.rs +++ b/talpid-core/src/firewall/linux/mod.rs @@ -9,7 +9,7 @@ use nftnl::{ expr::{self, Verdict}, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table, }; -use talpid_types::net; +use talpid_types::net::{Endpoint, TransportProtocol}; use tunnel; use std::env; @@ -232,8 +232,8 @@ impl<'a> PolicyBatch<'a> { self.add_allow_endpoint_rules(relay_endpoint)?; if let Some(tunnel) = tunnel { - self.add_dns_rule(tunnel, net::TransportProtocol::Udp)?; - self.add_dns_rule(tunnel, net::TransportProtocol::Tcp)?; + self.add_dns_rule(tunnel, TransportProtocol::Udp)?; + self.add_dns_rule(tunnel, TransportProtocol::Tcp)?; self.add_allow_tunnel_rules(tunnel)?; } if allow_lan { @@ -242,7 +242,7 @@ impl<'a> PolicyBatch<'a> { Ok(()) } - fn add_allow_endpoint_rules(&mut self, endpoint: &net::Endpoint) -> Result<()> { + fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) -> Result<()> { let mut in_rule = Rule::new(&self.in_chain)?; check_endpoint(&mut in_rule, End::Src, endpoint)?; @@ -267,7 +267,7 @@ impl<'a> PolicyBatch<'a> { fn add_dns_rule( &mut self, tunnel: &tunnel::TunnelMetadata, - protocol: net::TransportProtocol, + protocol: TransportProtocol, ) -> Result<()> { let mut rule = Rule::new(&self.out_chain)?; @@ -325,7 +325,6 @@ impl<'a> PolicyBatch<'a> { } fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a>> { - const PROTOCOL: net::TransportProtocol = net::TransportProtocol::Udp; const SERVER_PORT: u16 = 67; const CLIENT_PORT: u16 = 68; let broadcast_addr = IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)); @@ -334,12 +333,12 @@ fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a match direction { Direction::In => { - check_port(&mut rule, PROTOCOL, End::Src, SERVER_PORT)?; - check_port(&mut rule, PROTOCOL, End::Dst, CLIENT_PORT)?; + check_port(&mut rule, TransportProtocol::Udp, End::Src, SERVER_PORT)?; + check_port(&mut rule, TransportProtocol::Udp, End::Dst, CLIENT_PORT)?; } Direction::Out => { - check_port(&mut rule, PROTOCOL, End::Src, CLIENT_PORT)?; - check_port(&mut rule, PROTOCOL, End::Dst, SERVER_PORT)?; + check_port(&mut rule, TransportProtocol::Udp, End::Src, CLIENT_PORT)?; + check_port(&mut rule, TransportProtocol::Udp, End::Dst, SERVER_PORT)?; check_ip(&mut rule, End::Dst, broadcast_addr)?; } } @@ -398,7 +397,7 @@ fn check_net(rule: &mut Rule, end: End, net: IpNetwork) -> Result<()> { Ok(()) } -fn check_endpoint(rule: &mut Rule, end: End, endpoint: &net::Endpoint) -> Result<()> { +fn check_endpoint(rule: &mut Rule, end: End, endpoint: &Endpoint) -> Result<()> { check_ip(rule, end, endpoint.address.ip())?; check_port(rule, endpoint.protocol, end, endpoint.address.port())?; Ok(()) @@ -422,20 +421,15 @@ fn check_ip(rule: &mut Rule, end: End, ip: IpAddr) -> Result<()> { Ok(()) } -fn check_port( - rule: &mut Rule, - protocol: net::TransportProtocol, - end: End, - port: u16, -) -> Result<()> { +fn check_port(rule: &mut Rule, protocol: TransportProtocol, end: End, port: u16) -> Result<()> { // Must check transport layer protocol before loading transport layer payload check_l4proto(rule, protocol)?; rule.add_expr(match (protocol, end) { - (net::TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport), - (net::TransportProtocol::Udp, End::Dst) => nft_expr!(payload udp dport), - (net::TransportProtocol::Tcp, End::Src) => nft_expr!(payload tcp sport), - (net::TransportProtocol::Tcp, End::Dst) => nft_expr!(payload tcp dport), + (TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport), + (TransportProtocol::Udp, End::Dst) => nft_expr!(payload udp dport), + (TransportProtocol::Tcp, End::Src) => nft_expr!(payload tcp sport), + (TransportProtocol::Tcp, End::Dst) => nft_expr!(payload tcp dport), })?; rule.add_expr(nft_expr!(cmp == port.to_be()))?; Ok(()) @@ -454,16 +448,16 @@ fn l3proto(addr: IpAddr) -> u8 { } } -fn check_l4proto(rule: &mut Rule, protocol: net::TransportProtocol) -> Result<()> { +fn check_l4proto(rule: &mut Rule, protocol: TransportProtocol) -> Result<()> { rule.add_expr(nft_expr!(meta l4proto))?; rule.add_expr(nft_expr!(cmp == l4proto(protocol)))?; Ok(()) } -fn l4proto(protocol: net::TransportProtocol) -> u8 { +fn l4proto(protocol: TransportProtocol) -> u8 { match protocol { - net::TransportProtocol::Udp => libc::IPPROTO_UDP as u8, - net::TransportProtocol::Tcp => libc::IPPROTO_TCP as u8, + TransportProtocol::Udp => libc::IPPROTO_UDP as u8, + TransportProtocol::Tcp => libc::IPPROTO_TCP as u8, } } |
