summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/firewall/linux/mod.rs44
1 files changed, 19 insertions, 25 deletions
diff --git a/talpid-core/src/firewall/linux/mod.rs b/talpid-core/src/firewall/linux/mod.rs
index abe2e1d920..25e38f99cc 100644
--- a/talpid-core/src/firewall/linux/mod.rs
+++ b/talpid-core/src/firewall/linux/mod.rs
@@ -9,7 +9,7 @@ use nftnl::{
expr::{self, Verdict},
Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table,
};
-use talpid_types::net;
+use talpid_types::net::{Endpoint, TransportProtocol};
use tunnel;
use std::env;
@@ -232,8 +232,8 @@ impl<'a> PolicyBatch<'a> {
self.add_allow_endpoint_rules(relay_endpoint)?;
if let Some(tunnel) = tunnel {
- self.add_dns_rule(tunnel, net::TransportProtocol::Udp)?;
- self.add_dns_rule(tunnel, net::TransportProtocol::Tcp)?;
+ self.add_dns_rule(tunnel, TransportProtocol::Udp)?;
+ self.add_dns_rule(tunnel, TransportProtocol::Tcp)?;
self.add_allow_tunnel_rules(tunnel)?;
}
if allow_lan {
@@ -242,7 +242,7 @@ impl<'a> PolicyBatch<'a> {
Ok(())
}
- fn add_allow_endpoint_rules(&mut self, endpoint: &net::Endpoint) -> Result<()> {
+ fn add_allow_endpoint_rules(&mut self, endpoint: &Endpoint) -> Result<()> {
let mut in_rule = Rule::new(&self.in_chain)?;
check_endpoint(&mut in_rule, End::Src, endpoint)?;
@@ -267,7 +267,7 @@ impl<'a> PolicyBatch<'a> {
fn add_dns_rule(
&mut self,
tunnel: &tunnel::TunnelMetadata,
- protocol: net::TransportProtocol,
+ protocol: TransportProtocol,
) -> Result<()> {
let mut rule = Rule::new(&self.out_chain)?;
@@ -325,7 +325,6 @@ impl<'a> PolicyBatch<'a> {
}
fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a>> {
- const PROTOCOL: net::TransportProtocol = net::TransportProtocol::Udp;
const SERVER_PORT: u16 = 67;
const CLIENT_PORT: u16 = 68;
let broadcast_addr = IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255));
@@ -334,12 +333,12 @@ fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a
match direction {
Direction::In => {
- check_port(&mut rule, PROTOCOL, End::Src, SERVER_PORT)?;
- check_port(&mut rule, PROTOCOL, End::Dst, CLIENT_PORT)?;
+ check_port(&mut rule, TransportProtocol::Udp, End::Src, SERVER_PORT)?;
+ check_port(&mut rule, TransportProtocol::Udp, End::Dst, CLIENT_PORT)?;
}
Direction::Out => {
- check_port(&mut rule, PROTOCOL, End::Src, CLIENT_PORT)?;
- check_port(&mut rule, PROTOCOL, End::Dst, SERVER_PORT)?;
+ check_port(&mut rule, TransportProtocol::Udp, End::Src, CLIENT_PORT)?;
+ check_port(&mut rule, TransportProtocol::Udp, End::Dst, SERVER_PORT)?;
check_ip(&mut rule, End::Dst, broadcast_addr)?;
}
}
@@ -398,7 +397,7 @@ fn check_net(rule: &mut Rule, end: End, net: IpNetwork) -> Result<()> {
Ok(())
}
-fn check_endpoint(rule: &mut Rule, end: End, endpoint: &net::Endpoint) -> Result<()> {
+fn check_endpoint(rule: &mut Rule, end: End, endpoint: &Endpoint) -> Result<()> {
check_ip(rule, end, endpoint.address.ip())?;
check_port(rule, endpoint.protocol, end, endpoint.address.port())?;
Ok(())
@@ -422,20 +421,15 @@ fn check_ip(rule: &mut Rule, end: End, ip: IpAddr) -> Result<()> {
Ok(())
}
-fn check_port(
- rule: &mut Rule,
- protocol: net::TransportProtocol,
- end: End,
- port: u16,
-) -> Result<()> {
+fn check_port(rule: &mut Rule, protocol: TransportProtocol, end: End, port: u16) -> Result<()> {
// Must check transport layer protocol before loading transport layer payload
check_l4proto(rule, protocol)?;
rule.add_expr(match (protocol, end) {
- (net::TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport),
- (net::TransportProtocol::Udp, End::Dst) => nft_expr!(payload udp dport),
- (net::TransportProtocol::Tcp, End::Src) => nft_expr!(payload tcp sport),
- (net::TransportProtocol::Tcp, End::Dst) => nft_expr!(payload tcp dport),
+ (TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport),
+ (TransportProtocol::Udp, End::Dst) => nft_expr!(payload udp dport),
+ (TransportProtocol::Tcp, End::Src) => nft_expr!(payload tcp sport),
+ (TransportProtocol::Tcp, End::Dst) => nft_expr!(payload tcp dport),
})?;
rule.add_expr(nft_expr!(cmp == port.to_be()))?;
Ok(())
@@ -454,16 +448,16 @@ fn l3proto(addr: IpAddr) -> u8 {
}
}
-fn check_l4proto(rule: &mut Rule, protocol: net::TransportProtocol) -> Result<()> {
+fn check_l4proto(rule: &mut Rule, protocol: TransportProtocol) -> Result<()> {
rule.add_expr(nft_expr!(meta l4proto))?;
rule.add_expr(nft_expr!(cmp == l4proto(protocol)))?;
Ok(())
}
-fn l4proto(protocol: net::TransportProtocol) -> u8 {
+fn l4proto(protocol: TransportProtocol) -> u8 {
match protocol {
- net::TransportProtocol::Udp => libc::IPPROTO_UDP as u8,
- net::TransportProtocol::Tcp => libc::IPPROTO_TCP as u8,
+ TransportProtocol::Udp => libc::IPPROTO_UDP as u8,
+ TransportProtocol::Tcp => libc::IPPROTO_TCP as u8,
}
}