summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls Piņķis <emils@mullvad.net>2018-06-19 15:48:41 +0100
committerEmīls Piņķis <emils@mullvad.net>2018-06-21 16:06:19 +0100
commit7f592d6477e89372394853a56346a7feaf252208 (patch)
tree5f891cf8929fbd3a0196cc163110db354bfe51aa
parent98cceca9a1f1ce0fc1efcc0d790640e8aa6b4fbd (diff)
downloadmullvadvpn-7f592d6477e89372394853a56346a7feaf252208.tar.xz
mullvadvpn-7f592d6477e89372394853a56346a7feaf252208.zip
Change Firewall trait to take a cache dir as a parameter in it's constructor
-rw-r--r--Cargo.lock1
-rw-r--r--mullvad-daemon/src/main.rs4
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/firewall/linux/mod.rs3
-rw-r--r--talpid-core/src/firewall/macos/mod.rs3
-rw-r--r--talpid-core/src/firewall/mod.rs5
-rw-r--r--talpid-core/src/firewall/windows/dns.rs94
-rw-r--r--talpid-core/src/firewall/windows/ffi.rs51
-rw-r--r--talpid-core/src/firewall/windows/mod.rs79
-rw-r--r--talpid-core/src/firewall/windows/system_state.rs (renamed from talpid-core/src/firewall/system_state.rs)74
-rw-r--r--talpid-core/src/lib.rs2
11 files changed, 128 insertions, 189 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 1e8e6c17f3..10cb559070 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1394,7 +1394,6 @@ dependencies = [
"jsonrpc-macros 8.0.0 (git+https://github.com/paritytech/jsonrpc?tag=v8.0.1)",
"libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "mullvad-paths 0.1.0",
"notify 4.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
"openvpn-plugin 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"os_pipe 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs
index ff936ef8ea..892c438188 100644
--- a/mullvad-daemon/src/main.rs
+++ b/mullvad-daemon/src/main.rs
@@ -239,7 +239,7 @@ impl Daemon {
let (tx, rx) = mpsc::channel();
let management_interface_broadcaster =
- Self::start_management_interface(tx.clone(), cache_dir)?;
+ Self::start_management_interface(tx.clone(), cache_dir.clone())?;
let state = TunnelState::NotRunning;
let target_state = TargetState::Unsecured;
Ok(Daemon {
@@ -260,7 +260,7 @@ impl Daemon {
http_handle,
tokio_remote,
relay_selector,
- firewall: FirewallProxy::new().chain_err(|| ErrorKind::FirewallError)?,
+ firewall: FirewallProxy::new(&cache_dir).chain_err(|| ErrorKind::FirewallError)?,
current_relay: None,
tunnel_endpoint: None,
tunnel_metadata: None,
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 1184980232..e87ea35eb5 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -33,7 +33,6 @@ tokio-core = "0.1"
[target.'cfg(windows)'.dependencies]
libc = "0.2.20"
widestring = "0.3"
-mullvad-paths = { path = "../mullvad-paths" }
[dev-dependencies]
tempfile = "3.0"
diff --git a/talpid-core/src/firewall/linux/mod.rs b/talpid-core/src/firewall/linux/mod.rs
index 6b8e213f54..7761afd7e9 100644
--- a/talpid-core/src/firewall/linux/mod.rs
+++ b/talpid-core/src/firewall/linux/mod.rs
@@ -1,4 +1,5 @@
use error_chain::ChainedError;
+use std::path::Path;
use super::{Firewall, SecurityPolicy};
@@ -20,7 +21,7 @@ pub struct Netfilter {
impl Firewall for Netfilter {
type Error = Error;
- fn new() -> Result<Self> {
+ fn new<P: AsRef<Path>>(_cache_dir: P) -> Result<Self> {
Ok(Netfilter {
dns_settings: DnsSettings::new()?,
})
diff --git a/talpid-core/src/firewall/macos/mod.rs b/talpid-core/src/firewall/macos/mod.rs
index 5cacf9a2dc..394fd916d3 100644
--- a/talpid-core/src/firewall/macos/mod.rs
+++ b/talpid-core/src/firewall/macos/mod.rs
@@ -5,6 +5,7 @@ use self::pfctl::ipnetwork::{IpNetwork, Ipv4Network};
use super::{Firewall, SecurityPolicy};
use std::net::Ipv4Addr;
+use std::path::Path;
use talpid_types::net;
@@ -32,7 +33,7 @@ pub struct PacketFilter {
impl Firewall for PacketFilter {
type Error = Error;
- fn new() -> Result<Self> {
+ fn new<P: AsRef<Path>>(_cache_dir: P) -> Result<Self> {
Ok(PacketFilter {
pf: pfctl::PfCtl::new()?,
pf_was_enabled: None,
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index 6b2a073d40..50b26fbdb9 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -1,6 +1,5 @@
use talpid_types::net::Endpoint;
-
-mod system_state;
+use std::path::Path;
/// A enum that describes firewall rules strategy
@@ -31,7 +30,7 @@ pub trait Firewall {
type Error: ::std::error::Error;
/// Create new instance of Firewall
- fn new() -> ::std::result::Result<Self, Self::Error>
+ fn new<P: AsRef<Path>>(cache_dir: P) -> ::std::result::Result<Self, Self::Error>
where
Self: Sized;
diff --git a/talpid-core/src/firewall/windows/dns.rs b/talpid-core/src/firewall/windows/dns.rs
index fa6f00362e..62b323c34f 100644
--- a/talpid-core/src/firewall/windows/dns.rs
+++ b/talpid-core/src/firewall/windows/dns.rs
@@ -1,72 +1,66 @@
-extern crate libc;
-extern crate mullvad_paths;
extern crate widestring;
-use super::super::system_state::SystemStateWriter;
+use super::system_state::SystemStateWriter;
use super::ffi;
-use self::mullvad_paths::cache_dir;
use self::widestring::WideCString;
use std::net::IpAddr;
-use std::os::raw::c_void;
+use libc;
use std::ptr;
use std::slice;
+use std::path::Path;
const DNS_STATE_FILENAME: &'static str = "dns_state_backup";
error_chain!{
errors{
- #[doc = "Failure to initialize WinDNS"]
+ /// Failure to initialize WinDns
Initialization{
- description("Failed to initialize WinDNS")
+ description("Failed to initialize WinDns")
}
- #[doc = "Failure to deinitialize WinDNS"]
+ /// Failure to deinitialize WinDns
Deinitialization{
- description("Failed to deinitialize WinDNS")
+ description("Failed to deinitialize WinDns")
}
- #[doc = "Failure to set new DNS servers"]
+ /// Failure to set new DNS servers
Setting{
description("Failed to set new DNS servers")
}
- #[doc = "Failure to reset DNS settings"]
+ /// Failure to reset DNS settings
Resetting{
description("Failed to reset DNS")
}
- #[doc = "Failure to reset DNS settings from backup"]
+ /// Failure to reset DNS settings from backup
Recovery{
description("Failed to recover to backed up system state")
}
}
- links {
- NoCacheDir(mullvad_paths::Error, mullvad_paths::ErrorKind) #[doc = "Failure to create a cache directory"];
- }
-
foreign_links {
Io(::std::io::Error) #[doc = "IO error, most probably occurs when reading system state backup"];
}
}
-pub struct WinDNS {
+pub struct WinDns {
backup_writer: SystemStateWriter,
}
-impl WinDNS {
- pub fn new() -> Result<Self> {
+impl WinDns {
+ pub fn new<P: AsRef<Path>>(cache_dir: P) -> Result<Self> {
unsafe { WinDns_Initialize(Some(ffi::error_sink), ptr::null_mut()).into_result()? };
- let backup_writer = SystemStateWriter::new(cache_dir()?.join(DNS_STATE_FILENAME));
- let mut dns = WinDNS { backup_writer };
+ let backup_writer = SystemStateWriter::new(cache_dir.as_ref().join(DNS_STATE_FILENAME).into_boxed_path());
+ let mut dns = WinDns { backup_writer };
dns.restore_system_backup()?;
Ok(dns)
}
- pub fn set_dns(&mut self, servers: Vec<IpAddr>) -> Result<()> {
- debug!("Setting DNS servers - {:?}", servers);
+ pub fn set_dns(&mut self, servers: &[IpAddr]) -> Result<()> {
+ info!("Setting DNS servers - {}", servers.iter().map(|ip| ip.to_string()).collect::<Vec<String>>().join(", "));
let widestring_ips = servers
.iter()
.map(|ip| ip.to_string().encode_utf16().collect::<Vec<_>>())
@@ -83,7 +77,7 @@ impl WinDNS {
ip_ptrs.as_mut_ptr(),
widestring_ips.len() as u32,
Some(write_system_state_backup_cb),
- &self.backup_writer as *const _ as *const c_void,
+ &self.backup_writer as *const _ as *const libc::c_void,
).into_result()
}
}
@@ -92,7 +86,7 @@ impl WinDNS {
trace!("Resetting DNS");
unsafe { WinDns_Reset().into_result()? };
- if let Err(e) = self.backup_writer.remove_state_file() {
+ if let Err(e) = self.backup_writer.remove_backup() {
warn!("Failed to remove DNS state backup file: {}", e);
}
Ok(())
@@ -103,13 +97,14 @@ impl WinDNS {
}
fn restore_system_backup(&mut self) -> Result<()> {
- if let Some(previous_state) = self.backup_writer.consume_state_backup()? {
+ if let Some(previous_state) = self.backup_writer.read_backup()? {
trace!("Restoring system backed up DNS state");
- if let Err(e) = self.restore_dns_settings(&previous_state) {
- self.backup_writer.write_backup(&previous_state)?;
- return Err(e.into());
- }
- trace!("Successfully restored DNS state");
+
+ self.restore_dns_settings(&previous_state)?;
+ info!("Successfully restored DNS state");
+ if let Err(e) = self.backup_writer.remove_backup() {
+ error!("Failed to remove DNS config backup after restoring it: {}", e);
+ }
return Ok(());
}
trace!("No dns state to restore");
@@ -117,30 +112,32 @@ impl WinDNS {
}
}
-impl Drop for WinDNS {
+impl Drop for WinDns {
fn drop(&mut self) {
if unsafe { WinDns_Deinitialize().into_result().is_ok() } {
- trace!("Successfully deinitialized WinDNS");
+ trace!("Successfully deinitialized WinDns");
} else {
- error!("Failed to deinitialize WinDNS");
+ error!("Failed to deinitialize WinDns");
}
}
}
-ffi_error!(init, ErrorKind::Initialization.into());
-ffi_error!(deinit, ErrorKind::Deinitialization.into());
-ffi_error!(setting, ErrorKind::Setting.into());
-ffi_error!(resetting, ErrorKind::Resetting.into());
-ffi_error!(recovering, ErrorKind::Recovery.into());
+ffi_error!(InitializationResult, ErrorKind::Initialization.into());
+ffi_error!(DeinitializationResult, ErrorKind::Deinitialization.into());
+ffi_error!(SettingResult, ErrorKind::Setting.into());
+ffi_error!(ResettingResult, ErrorKind::Resetting.into());
+ffi_error!(RecoveringResult, ErrorKind::Recovery.into());
/// A callback for writing system state data
pub extern "system" fn write_system_state_backup_cb(
blob: *const u8,
length: u32,
- state_writer: *mut SystemStateWriter,
+ state_writer_ptr: *mut libc::c_void,
) -> i32 {
+
+ let state_writer = state_writer_ptr as *mut SystemStateWriter;
if state_writer.is_null() {
error!("State writer pointer is null, can't save system state backup");
return -1;
@@ -168,24 +165,23 @@ pub extern "system" fn write_system_state_backup_cb(
}
-#[allow(improper_ctypes)]
type DNSConfigSink =
- extern "system" fn(data: *const u8, length: u32, state_writer: *mut SystemStateWriter) -> i32;
+ extern "system" fn(data: *const u8, length: u32, state_writer: *mut libc::c_void) -> i32;
-#[allow(non_snake_case, improper_ctypes)]
+#[allow(non_snake_case)]
extern "system" {
#[link_name(WinDns_Initialize)]
pub fn WinDns_Initialize(
sink: Option<ffi::ErrorSink>,
sink_context: *mut libc::c_void,
- ) -> init::FFIResult;
+ ) -> InitializationResult;
// WinDns_Deinitialize:
//
// Call this function once before unloading WINDNS or exiting the process.
#[link_name(WinDns_Deinitialize)]
- pub fn WinDns_Deinitialize() -> deinit::FFIResult;
+ pub fn WinDns_Deinitialize() -> DeinitializationResult;
// Configure which DNS servers should be used and start enforcing these settings.
#[link_name(WinDns_Set)]
@@ -193,16 +189,16 @@ extern "system" {
ips: *mut *const u16,
n_ips: u32,
callback: Option<DNSConfigSink>,
- backup_writer: *const c_void,
- ) -> setting::FFIResult;
+ backup_writer: *const libc::c_void,
+ ) -> SettingResult;
// Revert server settings to what they were before calling WinDns_Set.
//
// (Also taking into account external changes to DNS settings that have ocurred
// during the period of enforcing specific settings.)
#[link_name(WinDns_Reset)]
- pub fn WinDns_Reset() -> resetting::FFIResult;
+ pub fn WinDns_Reset() -> ResettingResult;
#[link_name(WinDns_Recover)]
- pub fn WinDns_Recover(data: *const u8, length: u32) -> recovering::FFIResult;
+ pub fn WinDns_Recover(data: *const u8, length: u32) -> RecoveringResult;
}
diff --git a/talpid-core/src/firewall/windows/ffi.rs b/talpid-core/src/firewall/windows/ffi.rs
index 2cd6c34836..d91ae1dd0f 100644
--- a/talpid-core/src/firewall/windows/ffi.rs
+++ b/talpid-core/src/firewall/windows/ffi.rs
@@ -1,13 +1,11 @@
-extern crate libc;
-use std::os::raw::c_char;
-use std::ptr;
+use libc::{c_char, c_void};
-pub type ErrorSink = extern "system" fn(msg: *const c_char, ctx: *mut libc::c_void);
+pub type ErrorSink = extern "system" fn(msg: *const c_char, ctx: *mut c_void);
-pub extern "system" fn error_sink(msg: *const c_char, _ctx: *mut libc::c_void) {
+pub extern "system" fn error_sink(msg: *const c_char, _ctx: *mut c_void) {
use std::ffi::CStr;
- if msg == ptr::null() {
+ if msg.is_null() {
error!("Log message from FFI boundary is NULL");
} else {
error!("{}", unsafe { CStr::from_ptr(msg).to_string_lossy() });
@@ -17,29 +15,26 @@ pub extern "system" fn error_sink(msg: *const c_char, _ctx: *mut libc::c_void) {
#[macro_export]
macro_rules! ffi_error {
($result:ident, $error:expr) => {
- pub mod $result {
- use super::*;
- #[repr(C)]
- #[derive(Debug)]
- pub struct FFIResult {
- success: bool,
- }
+ #[repr(C)]
+ #[derive(Debug)]
+ pub struct $result {
+ success: bool,
+ }
- impl FFIResult {
- pub fn into_result(self) -> Result<()> {
- match self.success {
- true => Ok(()),
- false => Err($error),
- }
- }
- }
+ impl $result {
+ pub fn into_result(self) -> Result<()> {
+ match self.success {
+ true => Ok(()),
+ false => Err($error),
+ }
+ }
+ }
- impl Into<Result<()>> for FFIResult {
- fn into(self) -> Result<()> {
- self.into_result()
- }
- }
- }
- };
+ impl Into<Result<()>> for $result {
+ fn into(self) -> Result<()> {
+ self.into_result()
+ }
+ }
+}
}
diff --git a/talpid-core/src/firewall/windows/mod.rs b/talpid-core/src/firewall/windows/mod.rs
index 1e36388574..60c3d068d7 100644
--- a/talpid-core/src/firewall/windows/mod.rs
+++ b/talpid-core/src/firewall/windows/mod.rs
@@ -3,6 +3,7 @@ extern crate widestring;
use super::{Firewall, SecurityPolicy};
use std::net::IpAddr;
use std::ptr;
+use std::path::Path;
use self::winfw::*;
use talpid_types::net::Endpoint;
@@ -13,39 +14,40 @@ use self::widestring::WideCString;
#[macro_use]
mod ffi;
mod dns;
+mod system_state;
-use self::dns::WinDNS;
+use self::dns::WinDns;
error_chain!{
errors{
- #[doc = "Failure to initialize windows firewall module"]
+ /// Failure to initialize windows firewall module
Initialization{
description("Failed to initialise windows firewall module")
}
- #[doc = "Failure to deinitialize windows firewall module"]
+ /// Failure to deinitialize windows firewall module
Deinitialization{
description("Failed to deinitialize windows firewall module")
}
- #[doc = "Failure to apply a firewall _connected_ policy"]
+ /// Failure to apply a firewall _connected_ policy
ApplyingConnectedPolicy{
description("Failed to apply firewall policy for when the daemon is connecting to a tunnel")
}
- #[doc = "Failure to apply a firewall _connecting_ policy"]
+ /// Failure to apply a firewall _connecting_ policy
ApplyingConnectingPolicy{
description("Failed to apply firewall policy for when the daemon is connected to a tunnel")
}
- #[doc = "Failure to reset firewall policies"]
+ /// Failure to reset firewall policies
ResettingPolicy{
description("Failed to reset firewall policies")
}
}
links {
- WinDNS(dns::Error, dns::ErrorKind) #[doc = "WinDNS failure"];
+ WinDns(dns::Error, dns::ErrorKind) #[doc = "WinDNS failure"];
}
}
@@ -53,13 +55,14 @@ const WINFW_TIMEOUT_SECONDS: u32 = 2;
/// The Windows implementation for the `Firewall` trait.
pub struct WindowsFirewall {
- dns: WinDNS,
+ dns: WinDns,
}
impl Firewall for WindowsFirewall {
type Error = Error;
- fn new() -> Result<Self> {
+ fn new<P: AsRef<Path>>(cache_dir: P) -> Result<Self> {
+ let windns = WinDns::new(cache_dir)?;
unsafe {
WinFw_Initialize(
WINFW_TIMEOUT_SECONDS,
@@ -68,18 +71,6 @@ impl Firewall for WindowsFirewall {
).into_result()?
};
trace!("Successfully initialized windows firewall module");
- let windns = match WinDNS::new() {
- Ok(w) => w,
- Err(e) => {
- unsafe { WinFw_Deinitialize() }
- .into_result()
- .unwrap_or_else(|_| {
- error!("Failed to denitialize windows firewall module after failing to initialize WinDNS")
- });
- return Err(Error::from(e));
- }
- };
-
Ok(WindowsFirewall { dns: windns })
}
@@ -125,19 +116,19 @@ impl WindowsFirewall {
fn set_connecting_state(
&mut self,
endpoint: &Endpoint,
- wfp_settings: &WinFwSettings,
+ winfw_settings: &WinFwSettings,
) -> Result<()> {
trace!("Applying 'connecting' firewall policy");
let ip_str = Self::widestring_ip(&endpoint.address.ip());
- // ip_str has to outlive wfp_relay
- let wfp_relay = WinFwRelay {
+ // ip_str has to outlive winfw_relay
+ let winfw_relay = WinFwRelay {
ip: ip_str.as_wide_c_str().as_ptr(),
port: endpoint.address.port(),
protocol: WinFwProt::from(endpoint.protocol),
};
- unsafe { WinFw_ApplyPolicyConnecting(wfp_settings, &wfp_relay).into_result() }
+ unsafe { WinFw_ApplyPolicyConnecting(winfw_settings, &winfw_relay).into_result() }
}
fn widestring_ip(ip: &IpAddr) -> WideCString {
@@ -148,7 +139,7 @@ impl WindowsFirewall {
fn set_connected_state(
&mut self,
endpoint: &Endpoint,
- wfp_settings: &WinFwSettings,
+ winfw_settings: &WinFwSettings,
tunnel_metadata: &::tunnel::TunnelMetadata,
) -> Result<()> {
trace!("Applying 'connected' firewall policy");
@@ -158,18 +149,18 @@ impl WindowsFirewall {
let tunnel_alias =
WideCString::new(tunnel_metadata.interface.encode_utf16().collect::<Vec<_>>()).unwrap();
- // ip_str, gateway_str and tunnel_alias have to outlive wfp_relay
- let wfp_relay = WinFwRelay {
+ // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay
+ let winfw_relay = WinFwRelay {
ip: ip_str.as_wide_c_str().as_ptr(),
port: endpoint.address.port(),
protocol: WinFwProt::from(endpoint.protocol),
};
- self.dns.set_dns(vec![tunnel_metadata.gateway.into()])?;
+ self.dns.set_dns(&vec![tunnel_metadata.gateway.into()])?;
unsafe {
WinFw_ApplyPolicyConnected(
- wfp_settings,
- &wfp_relay,
+ winfw_settings,
+ &winfw_relay,
tunnel_alias.as_wide_c_str().as_ptr(),
gateway_str.as_wide_c_str().as_ptr(),
).into_result()
@@ -180,10 +171,8 @@ impl WindowsFirewall {
#[allow(non_snake_case)]
mod winfw {
-
- use super::ffi;
- use super::libc;
- use super::{ErrorKind, Result};
+ use libc;
+ use super::{ffi, ErrorKind, Result};
use talpid_types::net::TransportProtocol;
#[repr(C)]
@@ -224,11 +213,11 @@ mod winfw {
}
}
- ffi_error!(init, ErrorKind::Initialization.into());
- ffi_error!(deinit, ErrorKind::Deinitialization.into());
- ffi_error!(apply_connected, ErrorKind::ApplyingConnectedPolicy.into());
- ffi_error!(apply_connecting, ErrorKind::ApplyingConnectingPolicy.into());
- ffi_error!(reset_policy, ErrorKind::ResettingPolicy.into());
+ ffi_error!(InitializationResult, ErrorKind::Initialization.into());
+ ffi_error!(DeinitializationResult, ErrorKind::Deinitialization.into());
+ ffi_error!(ApplyConnectedResult, ErrorKind::ApplyingConnectedPolicy.into());
+ ffi_error!(ApplyConnectingResult, ErrorKind::ApplyingConnectingPolicy.into());
+ ffi_error!(ResettingPolicyResult, ErrorKind::ResettingPolicy.into());
extern "system" {
#[link_name(WinFw_Initialize)]
@@ -236,16 +225,16 @@ mod winfw {
timeout: libc::c_uint,
sink: Option<ffi::ErrorSink>,
sink_context: *mut libc::c_void,
- ) -> init::FFIResult;
+ ) -> InitializationResult;
#[link_name(WinFw_Deinitialize)]
- pub fn WinFw_Deinitialize() -> deinit::FFIResult;
+ pub fn WinFw_Deinitialize() -> DeinitializationResult;
#[link_name(WinFw_ApplyPolicyConnecting)]
pub fn WinFw_ApplyPolicyConnecting(
settings: &WinFwSettings,
relay: &WinFwRelay,
- ) -> apply_connecting::FFIResult;
+ ) -> ApplyConnectingResult;
#[link_name(WinFw_ApplyPolicyConnected)]
pub fn WinFw_ApplyPolicyConnected(
@@ -253,9 +242,9 @@ mod winfw {
relay: &WinFwRelay,
tunnelIfaceAlias: *const libc::wchar_t,
primaryDns: *const libc::wchar_t,
- ) -> apply_connected::FFIResult;
+ ) -> ApplyConnectingResult;
#[link_name(WinFw_Reset)]
- pub fn WinFw_Reset() -> reset_policy::FFIResult;
+ pub fn WinFw_Reset() -> ResettingPolicyResult;
}
}
diff --git a/talpid-core/src/firewall/system_state.rs b/talpid-core/src/firewall/windows/system_state.rs
index e99fc5033c..682e36d1f8 100644
--- a/talpid-core/src/firewall/system_state.rs
+++ b/talpid-core/src/firewall/windows/system_state.rs
@@ -7,7 +7,6 @@ use std::path::Path;
/// This struct is responsible for saving a binary blob to disk. The binary blob is intended to
/// store system state that should be resotred when the security policy is reset.
-#[repr(C)]
pub struct SystemStateWriter {
/// Full path to the system state backup file
pub backup_path: Box<Path>,
@@ -28,24 +27,19 @@ impl SystemStateWriter {
fs::write(&self.backup_path, &data)
}
- /// Tries to read a previously saved backup and deletes it after reading it if it exists.
- pub fn consume_state_backup(&self) -> io::Result<Option<Vec<u8>>> {
- match fs::read(&self.backup_path) {
- Ok(blob) => {
- if let Err(e) = self.remove_state_file() {
- error!("Failed to remove system state backup: {}", e)
- };
- Ok(Some(blob))
- }
- Err(e) => match e.kind() {
- io::ErrorKind::NotFound => Ok(None),
- _ => Err(e),
- },
- }
+ pub fn read_backup(&self) -> io::Result<Option<Vec<u8>>> {
+ match fs::read(&self.backup_path).map(|blob| Some(blob)) {
+ Ok(b) => Ok(b),
+ Err(e) => match e.kind() {
+ io::ErrorKind::NotFound => Ok(None),
+ _ => Err(e),
+ }
+ }
}
+
/// Removes a previously created state backup if it exists.
- pub fn remove_state_file(&self) -> io::Result<()> {
+ pub fn remove_backup(&self) -> io::Result<()> {
match fs::remove_file(&self.backup_path) {
Err(e) => {
if e.kind() != io::ErrorKind::NotFound {
@@ -77,15 +71,10 @@ mod tests {
.expect("failed to write system state");
let backup = writer
- .consume_state_backup()
+ .read_backup()
.expect("error when reading system state backup")
.expect("expected to read system state backup");
assert_eq!(backup, mock_system_state);
-
- let empty_read = writer
- .consume_state_backup()
- .expect("error when reading system state backup");
- assert_eq!(empty_read, None);
}
#[test]
@@ -95,38 +84,19 @@ mod tests {
let writer = SystemStateWriter::new(&temp_file);
let backup = writer
- .consume_state_backup()
+ .read_backup()
.expect("error when reading system state backup");
assert_eq!(backup, None);
}
- #[cfg(unix)]
- #[test]
- fn cant_read_without_access() {
- let temp_dir = PathBuf::from("/dev/null/bogus");
-
- let writer = SystemStateWriter::new(&temp_dir);
- let mock_system_state: Vec<_> = b"8.8.8.8\n8.8.4.4\n".to_vec();
-
- let failure = writer
- .write_backup(&mock_system_state)
- .expect_err("successfully wrote backup file to a directory in /dev/null");
- assert_eq!(failure.kind(), io::ErrorKind::Other);
-
- let recovery_failure = writer
- .consume_state_backup()
- .expect_err("successfully read backup file in /dev/null");
- assert_eq!(recovery_failure.kind(), io::ErrorKind::Other);
- }
-
#[test]
fn can_remove_when_no_backup_exists() {
let temp_dir = tempfile::tempdir().expect("failed to crate temp dir");
let temp_file = temp_dir.path().join("test_file");
let writer = SystemStateWriter::new(&temp_file);
- writer.remove_state_file().expect(
- "Encountered IO error when running remove_state_file when no state file exists",
+ writer.remove_backup().expect(
+ "Encountered IO error when running remove_backup when no state file exists",
);
}
@@ -141,24 +111,12 @@ mod tests {
.write_backup(&mock_system_state)
.expect("Failed to write backup");
writer
- .remove_state_file()
+ .remove_backup()
.expect("Failed to remove state file");
let empty_backup = writer
- .consume_state_backup()
+ .read_backup()
.expect("Encountered IO error when no backup file exists");
assert_eq!(empty_backup, None);
}
-
- #[cfg(unix)]
- #[test]
- fn cant_remove_backup_with_io_error() {
- let temp_dir = PathBuf::from("/dev/null/bogus");
-
- let writer = SystemStateWriter::new(&temp_dir);
- let removal_failure = writer
- .remove_state_file()
- .expect_err("successfully removed state file in /dev/null");
- assert_eq!(removal_failure.kind(), io::ErrorKind::Other);
- }
}
diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs
index a005cd6b8d..0cda5cdf8f 100644
--- a/talpid-core/src/lib.rs
+++ b/talpid-core/src/lib.rs
@@ -27,6 +27,8 @@ extern crate openvpn_plugin;
extern crate talpid_ipc;
extern crate talpid_types;
+#[cfg(windows)] extern crate libc;
+
/// Working with processes.
pub mod process;