diff options
| -rw-r--r-- | mullvad_daemon/src/main.rs | 14 | ||||
| -rw-r--r-- | talpid_core/src/net.rs | 90 | ||||
| -rw-r--r-- | talpid_core/src/process/openvpn.rs | 80 | ||||
| -rw-r--r-- | talpid_core/src/tunnel/mod.rs | 8 |
4 files changed, 69 insertions, 123 deletions
diff --git a/mullvad_daemon/src/main.rs b/mullvad_daemon/src/main.rs index 4848ef956e..62c09608a5 100644 --- a/mullvad_daemon/src/main.rs +++ b/mullvad_daemon/src/main.rs @@ -32,7 +32,7 @@ use std::sync::{Arc, Mutex, mpsc}; use std::thread; use talpid_core::mpsc::IntoSender; -use talpid_core::net::RemoteAddr; +use talpid_core::net::{Endpoint, TransportProtocol}; use talpid_core::tunnel::{self, TunnelEvent, TunnelMonitor}; error_chain!{ @@ -55,10 +55,10 @@ error_chain!{ lazy_static! { // Temporary store of hardcoded remotes. - static ref REMOTES: [RemoteAddr; 3] = [ - RemoteAddr::new("se5.mullvad.net", 1300), - RemoteAddr::new("se6.mullvad.net", 1300), - RemoteAddr::new("se7.mullvad.net", 1300), + static ref REMOTES: [Endpoint; 3] = [ + Endpoint::new("se5.mullvad.net", 1300, TransportProtocol::Udp), + Endpoint::new("se6.mullvad.net", 1300, TransportProtocol::Udp), + Endpoint::new("se7.mullvad.net", 1300, TransportProtocol::Udp), ]; } @@ -117,7 +117,7 @@ struct Daemon { // Just for testing. A cyclic iterator iterating over the hardcoded remotes, // picking a new one for each retry. - remote_iter: std::iter::Cycle<std::iter::Cloned<std::slice::Iter<'static, RemoteAddr>>>, + remote_iter: std::iter::Cycle<std::iter::Cloned<std::slice::Iter<'static, Endpoint>>>, } impl Daemon { @@ -314,7 +314,7 @@ impl Daemon { Ok(()) } - fn spawn_tunnel_monitor(&self, remote: RemoteAddr) -> Result<TunnelMonitor> { + fn spawn_tunnel_monitor(&self, remote: Endpoint) -> Result<TunnelMonitor> { // Must wrap the channel in a Mutex because TunnelMonitor forces the closure to be Sync let event_tx = Arc::new(Mutex::new(self.tx.clone())); let on_tunnel_event = move |event| { diff --git a/talpid_core/src/net.rs b/talpid_core/src/net.rs index f5f40688a7..310adee6ca 100644 --- a/talpid_core/src/net.rs +++ b/talpid_core/src/net.rs @@ -1,11 +1,6 @@ use std::fmt; -use std::io; -use std::iter; use std::net::SocketAddr; -use std::option; -use std::slice; use std::str::FromStr; -use std::vec; error_chain! { @@ -19,6 +14,34 @@ error_chain! { } +/// Represents a network layer IP address together with the transport layer protocol and port. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Endpoint { + /// The address part of this endpoint, contains the IP and port. + pub address: RemoteAddr, + /// The protocol part of this endpoint. + pub protocol: TransportProtocol, +} + +impl Endpoint { + /// Constructs a new `Endpoint` from the given parameters. + pub fn new(address: &str, port: u16, protocol: TransportProtocol) -> Self { + Endpoint { + address: RemoteAddr::new(address, port), + protocol: protocol, + } + } +} + +/// Representation of a transport protocol, either UDP or TCP. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum TransportProtocol { + /// Represents the UDP transport protocol. + Udp, + /// Represents the TCP transport protocol. + Tcp, +} + /// Representation of a TCP or UDP endpoint. The IP level address is represented by either an IP /// directly or a hostname/domain. The IP level address together with a port becomes a socket /// address. @@ -101,63 +124,6 @@ impl fmt::Display for RemoteAddr { } } -/// A trait for objects which can be converted to one or more `RemoteAddr` values. -pub trait ToRemoteAddrs { - /// Returned iterator over remote addresses which this type may correspond - /// to. - type Iter: Iterator<Item = RemoteAddr>; - - /// Converts this object to an iterator of parsed `RemoteAddr`s. - /// - /// # Errors - /// - /// Any errors encountered during parsing will be returned as an `Err`. - fn to_remote_addrs(&self) -> io::Result<Self::Iter>; -} - -impl ToRemoteAddrs for RemoteAddr { - type Iter = option::IntoIter<RemoteAddr>; - - fn to_remote_addrs(&self) -> io::Result<Self::Iter> { - Ok(Some(self.clone()).into_iter()) - } -} - -impl<'a> ToRemoteAddrs for &'a [RemoteAddr] { - type Iter = iter::Cloned<slice::Iter<'a, RemoteAddr>>; - - fn to_remote_addrs(&self) -> io::Result<Self::Iter> { - Ok(self.iter().cloned()) - } -} - -impl<'a> ToRemoteAddrs for &'a str { - type Iter = option::IntoIter<RemoteAddr>; - - fn to_remote_addrs(&self) -> io::Result<Self::Iter> { - let parsed_addr = str_to_remote_addr(self)?; - Ok(Some(parsed_addr).into_iter()) - } -} - -impl<'a> ToRemoteAddrs for &'a [&'a str] { - type Iter = vec::IntoIter<RemoteAddr>; - - fn to_remote_addrs(&self) -> io::Result<Self::Iter> { - let mut addrs = Vec::with_capacity(self.len()); - for addr in self.iter() { - addrs.push(str_to_remote_addr(addr)?); - } - Ok(addrs.into_iter()) - } -} - -fn str_to_remote_addr(s: &str) -> io::Result<RemoteAddr> { - RemoteAddr::from_str(s) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.description())) -} - - #[cfg(test)] mod remote_addr_tests { diff --git a/talpid_core/src/process/openvpn.rs b/talpid_core/src/process/openvpn.rs index 652256da0f..ba6688b034 100644 --- a/talpid_core/src/process/openvpn.rs +++ b/talpid_core/src/process/openvpn.rs @@ -2,11 +2,10 @@ extern crate openvpn_ffi; use duct; -use net::{RemoteAddr, ToRemoteAddrs}; +use net; use std::ffi::{OsStr, OsString}; use std::fmt; -use std::io; use std::path::{Path, PathBuf}; static BASE_ARGUMENTS: &[&[&str]] = &[ @@ -35,7 +34,7 @@ static ALLOWED_TLS_CIPHERS: &[&str] = &[ pub struct OpenVpnCommand { openvpn_bin: OsString, config: Option<PathBuf>, - remotes: Vec<RemoteAddr>, + remote: Option<net::Endpoint>, plugin: Option<(PathBuf, Vec<String>)>, } @@ -46,7 +45,7 @@ impl OpenVpnCommand { OpenVpnCommand { openvpn_bin: OsString::from(openvpn_bin.as_ref()), config: None, - remotes: vec![], + remote: None, plugin: None, } } @@ -57,11 +56,10 @@ impl OpenVpnCommand { self } - /// Sets the addresses that OpenVPN will connect to. See OpenVPN documentation for how multiple - /// remotes are handled. - pub fn remotes<A: ToRemoteAddrs>(&mut self, remotes: A) -> io::Result<&mut Self> { - self.remotes = remotes.to_remote_addrs()?.collect(); - Ok(self) + /// Sets the address and protocol that OpenVPN will connect to. + pub fn remote(&mut self, remote: net::Endpoint) -> &mut Self { + self.remote = Some(remote); + self } /// Sets a plugin and its arguments that OpenVPN will be started with. @@ -84,11 +82,9 @@ impl OpenVpnCommand { args.push(OsString::from("--config")); args.push(OsString::from(config.as_os_str())); } - for remote in &self.remotes { - args.push(OsString::from("--remote")); - args.push(OsString::from(remote.address())); - args.push(OsString::from(remote.port().to_string())); - } + + args.extend(self.remote_arguments().iter().map(OsString::from)); + if let Some((ref path, ref plugin_args)) = self.plugin { args.push(OsString::from("--plugin")); args.push(OsString::from(path)); @@ -116,6 +112,23 @@ impl OpenVpnCommand { args.push(ALLOWED_TLS_CIPHERS.join(":")); args } + + fn remote_arguments(&self) -> Vec<String> { + let mut args: Vec<String> = vec![]; + if let Some(ref endpoint) = self.remote { + args.push("--proto".to_owned()); + args.push( + match endpoint.protocol { + net::TransportProtocol::Udp => "udp".to_owned(), + net::TransportProtocol::Tcp => "tcp-client".to_owned(), + }, + ); + args.push("--remote".to_owned()); + args.push(endpoint.address.address()); + args.push(endpoint.address.port().to_string()); + } + args + } } impl fmt::Display for OpenVpnCommand { @@ -147,52 +160,21 @@ fn write_argument(fmt: &mut fmt::Formatter, arg: &str) -> fmt::Result { #[cfg(test)] mod tests { use super::OpenVpnCommand; - use net::RemoteAddr; + use net::{Endpoint, TransportProtocol}; use std::ffi::OsString; #[test] fn passes_one_remote() { - let remote = RemoteAddr::new("example.com", 3333); + let remote = Endpoint::new("example.com", 3333, TransportProtocol::Udp); - let testee_args = OpenVpnCommand::new("").remotes(remote).unwrap().get_arguments(); + let testee_args = OpenVpnCommand::new("").remote(remote).get_arguments(); + assert!(testee_args.contains(&OsString::from("udp"))); assert!(testee_args.contains(&OsString::from("example.com"))); assert!(testee_args.contains(&OsString::from("3333"))); } #[test] - fn passes_two_remotes() { - let remotes = vec![ - RemoteAddr::new("127.0.0.1", 998), - RemoteAddr::new("fe80::1", 1337), - ]; - - let testee_args = OpenVpnCommand::new("").remotes(&remotes[..]).unwrap().get_arguments(); - - assert!(testee_args.contains(&OsString::from("127.0.0.1"))); - assert!(testee_args.contains(&OsString::from("998"))); - assert!(testee_args.contains(&OsString::from("fe80::1"))); - assert!(testee_args.contains(&OsString::from("1337"))); - } - - #[test] - fn accepts_str() { - assert!(OpenVpnCommand::new("").remotes("10.0.0.1:1377").is_ok()); - } - - #[test] - fn accepts_slice_of_str() { - let remotes = ["10.0.0.1:1337", "127.0.0.1:99"]; - - let testee_args = OpenVpnCommand::new("").remotes(&remotes[..]).unwrap().get_arguments(); - - assert!(testee_args.contains(&OsString::from("10.0.0.1"))); - assert!(testee_args.contains(&OsString::from("1337"))); - assert!(testee_args.contains(&OsString::from("127.0.0.1"))); - assert!(testee_args.contains(&OsString::from("99"))); - } - - #[test] fn passes_plugin_path() { let path = "./a/path"; let testee_args = OpenVpnCommand::new("").plugin(path, vec![]).get_arguments(); diff --git a/talpid_core/src/tunnel/mod.rs b/talpid_core/src/tunnel/mod.rs index 6060e13208..45b518ad9e 100644 --- a/talpid_core/src/tunnel/mod.rs +++ b/talpid_core/src/tunnel/mod.rs @@ -56,7 +56,7 @@ pub struct TunnelMonitor { impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` /// on tunnel state changes. - pub fn new<L>(remote: net::RemoteAddr, on_event: L) -> Result<Self> + pub fn new<L>(remote: net::Endpoint, on_event: L) -> Result<Self> where L: Fn(TunnelEvent) + Send + Sync + 'static { let on_openvpn_event = move |event, _env| match TunnelEvent::from_openvpn_event(&event) { @@ -69,11 +69,9 @@ impl TunnelMonitor { Ok(TunnelMonitor { monitor }) } - fn create_openvpn_cmd(remote: net::RemoteAddr) -> OpenVpnCommand { + fn create_openvpn_cmd(remote: net::Endpoint) -> OpenVpnCommand { let mut cmd = OpenVpnCommand::new("openvpn"); - cmd.config(get_config_path()) - .remotes(remote) - .unwrap(); + cmd.config(get_config_path()).remote(remote); cmd } |
