diff options
| -rw-r--r-- | talpid-core/src/firewall/linux/mod.rs | 50 |
1 files changed, 30 insertions, 20 deletions
diff --git a/talpid-core/src/firewall/linux/mod.rs b/talpid-core/src/firewall/linux/mod.rs index fdfedb6cb5..502f12e033 100644 --- a/talpid-core/src/firewall/linux/mod.rs +++ b/talpid-core/src/firewall/linux/mod.rs @@ -6,7 +6,7 @@ use ipnetwork::IpNetwork; use libc; use nftnl::{ self, - expr::{self, InterfaceName, Verdict}, + expr::{self, Verdict}, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table, }; use talpid_types::net; @@ -14,6 +14,7 @@ use tunnel; use std::env; use std::ffi::CString; +use std::io; use std::net::{IpAddr, Ipv4Addr}; use std::path::Path; @@ -32,6 +33,11 @@ error_chain! { NetlinkRecvError { description("Error while reading from netlink socket") } /// Error while processing an incoming netlink message ProcessNetlinkError { description("Error while processing an incoming netlink message") } + /// The name is not a valid Linux network interface name + InvalidInterfaceName(name: String) { + description("Invalid network interface name") + display("Invalid network interface name: {}", name) + } } links { DnsSettings(self::dns::Error, self::dns::ErrorKind) #[doc = "DNS error"]; @@ -187,13 +193,13 @@ impl<'a> PolicyBatch<'a> { } fn add_loopback_rules(&mut self) -> Result<()> { - let loopback_device = InterfaceName::Exact(CString::new("lo").unwrap()); + const LOOPBACK_IFACE_NAME: &str = "lo"; self.batch.add( - &allow_interface_rule(&self.out_chain, Direction::Out, &loopback_device)?, + &allow_interface_rule(&self.out_chain, Direction::Out, LOOPBACK_IFACE_NAME)?, nftnl::MsgType::Add, )?; self.batch.add( - &allow_interface_rule(&self.in_chain, Direction::In, &loopback_device)?, + &allow_interface_rule(&self.in_chain, Direction::In, LOOPBACK_IFACE_NAME)?, nftnl::MsgType::Add, )?; Ok(()) @@ -267,11 +273,9 @@ impl<'a> PolicyBatch<'a> { protocol: net::TransportProtocol, ) -> Result<()> { let mut rule = Rule::new(&self.out_chain)?; - rule.add_expr(nft_expr!(meta oifname))?; - rule.add_expr(nft_expr!(cmp == tunnel_iface_name(tunnel)))?; + check_iface(&mut rule, Direction::Out, &tunnel.interface[..])?; check_port(&mut rule, protocol, End::Dst, 53)?; - check_l3proto(&mut rule, IpAddr::V4(tunnel.gateway))?; rule.add_expr(nft_expr!(payload ipv4 daddr))?; @@ -284,13 +288,12 @@ impl<'a> PolicyBatch<'a> { } fn add_allow_tunnel_rules(&mut self, tunnel: &tunnel::TunnelMetadata) -> Result<()> { - let tunnel_interface = tunnel_iface_name(tunnel); self.batch.add( - &allow_interface_rule(&self.out_chain, Direction::Out, &tunnel_interface)?, + &allow_interface_rule(&self.out_chain, Direction::Out, &tunnel.interface[..])?, nftnl::MsgType::Add, )?; self.batch.add( - &allow_interface_rule(&self.in_chain, Direction::In, &tunnel_interface)?, + &allow_interface_rule(&self.in_chain, Direction::In, &tunnel.interface[..])?, nftnl::MsgType::Add, )?; Ok(()) @@ -352,7 +355,7 @@ fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a fn allow_interface_rule<'a>( chain: &'a Chain, direction: Direction, - iface: &InterfaceName, + iface: &str, ) -> Result<Rule<'a>> { let mut rule = Rule::new(&chain)?; check_iface(&mut rule, direction, iface)?; @@ -361,16 +364,27 @@ fn allow_interface_rule<'a>( Ok(rule) } - -fn check_iface(rule: &mut Rule, direction: Direction, iface: &InterfaceName) -> Result<()> { +fn check_iface(rule: &mut Rule, direction: Direction, iface: &str) -> Result<()> { + let iface_index = iface_index(iface)?; rule.add_expr(match direction { - Direction::In => nft_expr!(meta iifname), - Direction::Out => nft_expr!(meta oifname), + Direction::In => nft_expr!(meta iif), + Direction::Out => nft_expr!(meta oif), })?; - rule.add_expr(nft_expr!(cmp == iface))?; + rule.add_expr(nft_expr!(cmp == iface_index))?; Ok(()) } +fn iface_index(name: &str) -> Result<libc::c_uint> { + let c_name = CString::new(name).chain_err(|| ErrorKind::InvalidInterfaceName(name.to_owned()))?; + let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; + if index == 0 { + let error = io::Error::last_os_error(); + Err(error).chain_err(|| ErrorKind::InvalidInterfaceName(name.to_owned())) + } else { + Ok(index) + } +} + fn check_net(rule: &mut Rule, end: End, net: IpNetwork) -> Result<()> { // Must check network layer protocol before loading network layer payload check_l3proto(rule, net.ip())?; @@ -430,10 +444,6 @@ fn check_port( Ok(()) } -fn tunnel_iface_name(tunnel: &tunnel::TunnelMetadata) -> InterfaceName { - InterfaceName::Exact(CString::new(&tunnel.interface[..]).unwrap()) -} - fn check_l3proto(rule: &mut Rule, ip: IpAddr) -> Result<()> { rule.add_expr(nft_expr!(meta nfproto))?; rule.add_expr(nft_expr!(cmp == l3proto(ip)))?; |
