summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/firewall/linux/mod.rs50
1 files changed, 30 insertions, 20 deletions
diff --git a/talpid-core/src/firewall/linux/mod.rs b/talpid-core/src/firewall/linux/mod.rs
index fdfedb6cb5..502f12e033 100644
--- a/talpid-core/src/firewall/linux/mod.rs
+++ b/talpid-core/src/firewall/linux/mod.rs
@@ -6,7 +6,7 @@ use ipnetwork::IpNetwork;
use libc;
use nftnl::{
self,
- expr::{self, InterfaceName, Verdict},
+ expr::{self, Verdict},
Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table,
};
use talpid_types::net;
@@ -14,6 +14,7 @@ use tunnel;
use std::env;
use std::ffi::CString;
+use std::io;
use std::net::{IpAddr, Ipv4Addr};
use std::path::Path;
@@ -32,6 +33,11 @@ error_chain! {
NetlinkRecvError { description("Error while reading from netlink socket") }
/// Error while processing an incoming netlink message
ProcessNetlinkError { description("Error while processing an incoming netlink message") }
+ /// The name is not a valid Linux network interface name
+ InvalidInterfaceName(name: String) {
+ description("Invalid network interface name")
+ display("Invalid network interface name: {}", name)
+ }
}
links {
DnsSettings(self::dns::Error, self::dns::ErrorKind) #[doc = "DNS error"];
@@ -187,13 +193,13 @@ impl<'a> PolicyBatch<'a> {
}
fn add_loopback_rules(&mut self) -> Result<()> {
- let loopback_device = InterfaceName::Exact(CString::new("lo").unwrap());
+ const LOOPBACK_IFACE_NAME: &str = "lo";
self.batch.add(
- &allow_interface_rule(&self.out_chain, Direction::Out, &loopback_device)?,
+ &allow_interface_rule(&self.out_chain, Direction::Out, LOOPBACK_IFACE_NAME)?,
nftnl::MsgType::Add,
)?;
self.batch.add(
- &allow_interface_rule(&self.in_chain, Direction::In, &loopback_device)?,
+ &allow_interface_rule(&self.in_chain, Direction::In, LOOPBACK_IFACE_NAME)?,
nftnl::MsgType::Add,
)?;
Ok(())
@@ -267,11 +273,9 @@ impl<'a> PolicyBatch<'a> {
protocol: net::TransportProtocol,
) -> Result<()> {
let mut rule = Rule::new(&self.out_chain)?;
- rule.add_expr(nft_expr!(meta oifname))?;
- rule.add_expr(nft_expr!(cmp == tunnel_iface_name(tunnel)))?;
+ check_iface(&mut rule, Direction::Out, &tunnel.interface[..])?;
check_port(&mut rule, protocol, End::Dst, 53)?;
-
check_l3proto(&mut rule, IpAddr::V4(tunnel.gateway))?;
rule.add_expr(nft_expr!(payload ipv4 daddr))?;
@@ -284,13 +288,12 @@ impl<'a> PolicyBatch<'a> {
}
fn add_allow_tunnel_rules(&mut self, tunnel: &tunnel::TunnelMetadata) -> Result<()> {
- let tunnel_interface = tunnel_iface_name(tunnel);
self.batch.add(
- &allow_interface_rule(&self.out_chain, Direction::Out, &tunnel_interface)?,
+ &allow_interface_rule(&self.out_chain, Direction::Out, &tunnel.interface[..])?,
nftnl::MsgType::Add,
)?;
self.batch.add(
- &allow_interface_rule(&self.in_chain, Direction::In, &tunnel_interface)?,
+ &allow_interface_rule(&self.in_chain, Direction::In, &tunnel.interface[..])?,
nftnl::MsgType::Add,
)?;
Ok(())
@@ -352,7 +355,7 @@ fn allow_dhcp_rule<'a>(chain: &'a Chain, direction: Direction) -> Result<Rule<'a
fn allow_interface_rule<'a>(
chain: &'a Chain,
direction: Direction,
- iface: &InterfaceName,
+ iface: &str,
) -> Result<Rule<'a>> {
let mut rule = Rule::new(&chain)?;
check_iface(&mut rule, direction, iface)?;
@@ -361,16 +364,27 @@ fn allow_interface_rule<'a>(
Ok(rule)
}
-
-fn check_iface(rule: &mut Rule, direction: Direction, iface: &InterfaceName) -> Result<()> {
+fn check_iface(rule: &mut Rule, direction: Direction, iface: &str) -> Result<()> {
+ let iface_index = iface_index(iface)?;
rule.add_expr(match direction {
- Direction::In => nft_expr!(meta iifname),
- Direction::Out => nft_expr!(meta oifname),
+ Direction::In => nft_expr!(meta iif),
+ Direction::Out => nft_expr!(meta oif),
})?;
- rule.add_expr(nft_expr!(cmp == iface))?;
+ rule.add_expr(nft_expr!(cmp == iface_index))?;
Ok(())
}
+fn iface_index(name: &str) -> Result<libc::c_uint> {
+ let c_name = CString::new(name).chain_err(|| ErrorKind::InvalidInterfaceName(name.to_owned()))?;
+ let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
+ if index == 0 {
+ let error = io::Error::last_os_error();
+ Err(error).chain_err(|| ErrorKind::InvalidInterfaceName(name.to_owned()))
+ } else {
+ Ok(index)
+ }
+}
+
fn check_net(rule: &mut Rule, end: End, net: IpNetwork) -> Result<()> {
// Must check network layer protocol before loading network layer payload
check_l3proto(rule, net.ip())?;
@@ -430,10 +444,6 @@ fn check_port(
Ok(())
}
-fn tunnel_iface_name(tunnel: &tunnel::TunnelMetadata) -> InterfaceName {
- InterfaceName::Exact(CString::new(&tunnel.interface[..]).unwrap())
-}
-
fn check_l3proto(rule: &mut Rule, ip: IpAddr) -> Result<()> {
rule.add_expr(nft_expr!(meta nfproto))?;
rule.add_expr(nft_expr!(cmp == l3proto(ip)))?;