diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-06-14 13:05:24 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 13:05:24 +0200 |
| commit | c3bdb0ebd3e99c22051656219f2db5a715da0a0d (patch) | |
| tree | a83f10937eccad7e8a001cfc46357725c15a65da | |
| parent | a38dce1ce10893f8e1a077c95ca4afe085bcaeea (diff) | |
| parent | 7f49fb807b4e49113dc078302fa1aabdbd0a931e (diff) | |
| download | mullvadvpn-c3bdb0ebd3e99c22051656219f2db5a715da0a0d.tar.xz mullvadvpn-c3bdb0ebd3e99c22051656219f2db5a715da0a0d.zip | |
Merge branch 'add-pq-safe-tunnels'
55 files changed, 1367 insertions, 150 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index a985a3d4a6..e49f5da290 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Line wrap the file at 100 chars. Th ## [Unreleased] ### Added - Add option to filter relays by ownership in the desktop apps. +- Experimental: Add support for quantum-resistant PSK exchange to the CLI. #### Linux - Automatically attempt to detect and set the correct MTU for Wireguard tunnels. diff --git a/Cargo.lock b/Cargo.lock index 7d2f5f5582..5ef4205b00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -338,6 +338,15 @@ dependencies = [ ] [[package]] +name = "classic-mceliece-rust" +version = "1.0.1" +source = "git+https://github.com/mullvad/classic-mceliece-rust?rev=5130d9e3bfbf54735177e15636a643366c250b78#5130d9e3bfbf54735177e15636a643366c250b78" +dependencies = [ + "rand 0.8.4", + "sha3", +] + +[[package]] name = "colored" version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1406,6 +1415,12 @@ dependencies = [ ] [[package]] +name = "keccak" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9b7d56ba4a8344d6be9729995e6b06f928af29998cdf79fe390cbf6b1fee838" + +[[package]] name = "kernel32-sys" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1779,8 +1794,8 @@ dependencies = [ "mullvad-types", "nix 0.23.1", "parity-tokio-ipc", - "prost", - "prost-types", + "prost 0.8.0", + "prost-types 0.8.0", "talpid-types", "tokio", "tonic", @@ -2468,7 +2483,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.8.0", +] + +[[package]] +name = "prost" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001" +dependencies = [ + "bytes", + "prost-derive 0.9.0", ] [[package]] @@ -2483,8 +2508,8 @@ dependencies = [ "log", "multimap", "petgraph", - "prost", - "prost-types", + "prost 0.8.0", + "prost-types 0.8.0", "tempfile", "which", ] @@ -2503,13 +2528,36 @@ dependencies = [ ] [[package]] +name = "prost-derive" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] name = "prost-types" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b" dependencies = [ "bytes", - "prost", + "prost 0.8.0", +] + +[[package]] +name = "prost-types" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a" +dependencies = [ + "bytes", + "prost 0.9.0", ] [[package]] @@ -2995,6 +3043,16 @@ dependencies = [ ] [[package]] +name = "sha3" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f935e31cf406e8c0e96c2815a5516181b7004ae8c5f296293221e9b1e356bd" +dependencies = [ + "digest 0.10.1", + "keccak", +] + +[[package]] name = "shadowsocks" version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3267,7 +3325,7 @@ dependencies = [ "parity-tokio-ipc", "parking_lot 0.11.2", "pfctl", - "prost", + "prost 0.8.0", "quickcheck", "quickcheck_macros", "rand 0.7.3", @@ -3282,6 +3340,7 @@ dependencies = [ "talpid-dbus", "talpid-platform-metadata", "talpid-time", + "talpid-tunnel-config-client", "talpid-types", "tempfile", "tokio", @@ -3321,7 +3380,7 @@ dependencies = [ "log", "openvpn-plugin", "parity-tokio-ipc", - "prost", + "prost 0.8.0", "talpid-types", "tokio", "tonic", @@ -3349,6 +3408,22 @@ dependencies = [ ] [[package]] +name = "talpid-tunnel-config-client" +version = "0.1.0" +dependencies = [ + "classic-mceliece-rust", + "log", + "prost 0.8.0", + "prost-types 0.9.0", + "rand 0.8.4", + "talpid-types", + "tokio", + "tonic", + "tonic-build", + "tower", +] + +[[package]] name = "talpid-types" version = "0.1.0" dependencies = [ @@ -3566,8 +3641,8 @@ dependencies = [ "hyper-timeout", "percent-encoding", "pin-project", - "prost", - "prost-derive", + "prost 0.8.0", + "prost-derive 0.8.0", "tokio", "tokio-stream", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index 75bbc7d0ae..209384de44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "talpid-dbus", "talpid-platform-metadata", "talpid-time", + "talpid-tunnel-config-client", "mullvad-management-interface", "tunnel-obfuscation", ] @@ -24,3 +25,8 @@ members = [ [profile.release] opt-level = 3 lto = true + +# Key generation may take over one minute without optimizations +# enabled. +[profile.dev.package."classic-mceliece-rust"] +opt-level = 3 diff --git a/gui/locales/messages.pot b/gui/locales/messages.pot index 9deffc1e8d..30997c9290 100644 --- a/gui/locales/messages.pot +++ b/gui/locales/messages.pot @@ -108,6 +108,9 @@ msgstr "" msgid "Connecting" msgstr "" +msgid "CREATING QUANTUM SECURE CONNECTION" +msgstr "" + msgid "CREATING SECURE CONNECTION" msgstr "" @@ -172,6 +175,9 @@ msgstr "" msgid "Open URL" msgstr "" +msgid "QUANTUM SECURE CONNECTION" +msgstr "" + msgid "Reconnect" msgstr "" diff --git a/gui/src/renderer/components/SecuredLabel.tsx b/gui/src/renderer/components/SecuredLabel.tsx index cda3c5174d..f622c8a847 100644 --- a/gui/src/renderer/components/SecuredLabel.tsx +++ b/gui/src/renderer/components/SecuredLabel.tsx @@ -5,8 +5,10 @@ import { messages } from '../../shared/gettext'; export enum SecuredDisplayStyle { secured, + securedPq, blocked, securing, + securingPq, unsecured, unsecuring, failedToSecure, @@ -14,8 +16,10 @@ export enum SecuredDisplayStyle { const securedDisplayStyleColorMap = { [SecuredDisplayStyle.securing]: colors.white, + [SecuredDisplayStyle.securingPq]: colors.white, [SecuredDisplayStyle.unsecuring]: colors.white, [SecuredDisplayStyle.secured]: colors.green, + [SecuredDisplayStyle.securedPq]: colors.green, [SecuredDisplayStyle.blocked]: colors.white, [SecuredDisplayStyle.unsecured]: colors.red, [SecuredDisplayStyle.failedToSecure]: colors.red, @@ -45,12 +49,18 @@ function getLabelText(displayStyle: SecuredDisplayStyle) { case SecuredDisplayStyle.secured: return messages.gettext('SECURE CONNECTION'); + case SecuredDisplayStyle.securedPq: + return messages.gettext('QUANTUM SECURE CONNECTION'); + case SecuredDisplayStyle.blocked: return messages.gettext('BLOCKED CONNECTION'); case SecuredDisplayStyle.securing: return messages.gettext('CREATING SECURE CONNECTION'); + case SecuredDisplayStyle.securingPq: + return messages.gettext('CREATING QUANTUM SECURE CONNECTION'); + case SecuredDisplayStyle.unsecured: return messages.gettext('UNSECURED CONNECTION'); diff --git a/gui/src/renderer/components/TunnelControl.tsx b/gui/src/renderer/components/TunnelControl.tsx index c7617bfdda..4d210ed2c2 100644 --- a/gui/src/renderer/components/TunnelControl.tsx +++ b/gui/src/renderer/components/TunnelControl.tsx @@ -77,6 +77,7 @@ const SelectedLocationChevron = styled(AppButton.Icon)({ export default class TunnelControl extends React.Component<ITunnelControlProps> { public render() { let state = this.props.tunnelState.state; + let pq = false; switch (this.props.tunnelState.state) { case 'disconnecting': @@ -92,14 +93,23 @@ export default class TunnelControl extends React.Component<ITunnelControlProps> break; } break; + case 'connecting': + if (this.props.tunnelState.details) { + pq = this.props.tunnelState.details.endpoint.quantumResistant; + } + break; + case 'connected': + pq = this.props.tunnelState.details.endpoint.quantumResistant; + break; } switch (state) { - case 'connecting': + case 'connecting': { + const displayStyle = pq ? SecuredDisplayStyle.securingPq : SecuredDisplayStyle.securing; return ( <Wrapper> <Body> - <Secured displayStyle={SecuredDisplayStyle.securing} /> + <Secured displayStyle={displayStyle} /> <Location> {this.renderCountry()} {this.renderCity()} @@ -112,11 +122,14 @@ export default class TunnelControl extends React.Component<ITunnelControlProps> </Footer> </Wrapper> ); - case 'connected': + } + + case 'connected': { + const displayStyle = pq ? SecuredDisplayStyle.securedPq : SecuredDisplayStyle.secured; return ( <Wrapper> <Body> - <Secured displayStyle={SecuredDisplayStyle.secured} /> + <Secured displayStyle={displayStyle} /> <Location> {this.renderCountry()} {this.renderCity()} @@ -129,6 +142,7 @@ export default class TunnelControl extends React.Component<ITunnelControlProps> </Footer> </Wrapper> ); + } case 'error': if ( diff --git a/gui/src/shared/daemon-rpc-types.ts b/gui/src/shared/daemon-rpc-types.ts index 45af1c605e..8bee8003b0 100644 --- a/gui/src/shared/daemon-rpc-types.ts +++ b/gui/src/shared/daemon-rpc-types.ts @@ -92,6 +92,7 @@ export interface ITunnelEndpoint { address: string; protocol: RelayProtocol; tunnelType: TunnelType; + quantumResistant: boolean; proxy?: IProxyEndpoint; obfuscationEndpoint?: IObfuscationEndpoint; entryEndpoint?: IEndpoint; diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs index c624c2a25b..ac7bd74c63 100644 --- a/mullvad-cli/src/cmds/relay.rs +++ b/mullvad-cli/src/cmds/relay.rs @@ -573,7 +573,24 @@ impl Relay { } if let Some(entry) = matches.values_of("entry location") { wireguard_constraints.entry_location = parse_entry_location_constraint(entry); - wireguard_constraints.use_multihop = wireguard_constraints.entry_location.is_some(); + let use_multihop = wireguard_constraints.entry_location.is_some(); + if use_multihop { + let use_pq = rpc + .get_settings(()) + .await? + .into_inner() + .tunnel_options + .unwrap() + .wireguard + .unwrap() + .use_pq_safe_psk; + if use_pq { + return Err(Error::CommandFailed( + "PQ PSK exchange does not work when multihop is enabled", + )); + } + } + wireguard_constraints.use_multihop = use_multihop; } self.update_constraints(types::RelaySettingsUpdate { diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs index f01452a925..7856cc849c 100644 --- a/mullvad-cli/src/cmds/tunnel.rs +++ b/mullvad-cli/src/cmds/tunnel.rs @@ -37,6 +37,7 @@ fn create_wireguard_subcommand() -> clap::App<'static> { .about("Manage options for Wireguard tunnels") .setting(clap::AppSettings::SubcommandRequiredElseHelp) .subcommand(create_wireguard_mtu_subcommand()) + .subcommand(create_wireguard_quantum_resistant_tunnel_subcommand()) .subcommand(create_wireguard_keys_subcommand()); #[cfg(windows)] { @@ -57,6 +58,14 @@ fn create_wireguard_mtu_subcommand() -> clap::App<'static> { .subcommand(clap::App::new("set").arg(clap::Arg::new("mtu").required(true))) } +fn create_wireguard_quantum_resistant_tunnel_subcommand() -> clap::App<'static> { + clap::App::new("quantum-resistant-tunnel") + .about("EXPERIMENTAL: Enables quantum-resistant PSK exchange in the tunnel") + .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::App::new("get")) + .subcommand(clap::App::new("set").arg(clap::Arg::new("policy").required(true))) +} + fn create_wireguard_keys_subcommand() -> clap::App<'static> { clap::App::new("key") .about("Manage your wireguard key") @@ -163,6 +172,14 @@ impl Tunnel { _ => unreachable!("unhandled command"), }, + Some(("quantum-resistant-tunnel", matches)) => match matches.subcommand() { + Some(("get", _)) => Self::process_wireguard_quantum_resistant_tunnel_get().await, + Some(("set", matches)) => { + Self::process_wireguard_quantum_resistant_tunnel_set(matches).await + } + _ => unreachable!("unhandled command"), + }, + #[cfg(windows)] Some(("use-wireguard-nt", matches)) => match matches.subcommand() { Some(("get", _)) => Self::process_wireguard_use_wg_nt_get().await, @@ -203,6 +220,45 @@ impl Tunnel { Ok(()) } + async fn process_wireguard_quantum_resistant_tunnel_get() -> Result<()> { + let tunnel_options = Self::get_tunnel_options().await?; + if tunnel_options.wireguard.unwrap().use_pq_safe_psk { + println!("enabled"); + } else { + println!("disabled"); + } + Ok(()) + } + + async fn process_wireguard_quantum_resistant_tunnel_set( + matches: &clap::ArgMatches, + ) -> Result<()> { + let new_state = matches.value_of("policy").unwrap() == "on"; + let mut rpc = new_rpc_client().await?; + let settings = rpc.get_settings(()).await?; + let multihop_is_enabled = settings + .into_inner() + .relay_settings + .unwrap() + .endpoint + .and_then(|endpoint| { + if let types::relay_settings::Endpoint::Normal(settings) = endpoint { + Some(settings.wireguard_constraints.unwrap().use_multihop) + } else { + None + } + }) + .unwrap_or(false); + if multihop_is_enabled { + return Err(Error::CommandFailed( + "PQ PSK exchange does not work when multihop is enabled", + )); + } + rpc.set_quantum_resistant_tunnel(new_state).await?; + println!("Updated quantum resistant tunnel setting"); + Ok(()) + } + #[cfg(windows)] async fn process_wireguard_use_wg_nt_get() -> Result<()> { let tunnel_options = Self::get_tunnel_options().await?; diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs index 74630166d2..47c05611b2 100644 --- a/mullvad-cli/src/format.rs +++ b/mullvad-cli/src/format.rs @@ -109,6 +109,15 @@ fn format_relay_connection(relay_info: &TunnelStateRelayInfo, verbose: bool) -> } else { String::new() }; + let quantum_resistant = if verbose { + if endpoint.quantum_resistant { + "\nQuantum resistant tunnel: true".to_string() + } else { + "\nQuantum resistant tunnel: false".to_string() + } + } else { + String::new() + }; let mut bridge_type = String::new(); let mut obfuscator_type = String::new(); @@ -127,7 +136,7 @@ fn format_relay_connection(relay_info: &TunnelStateRelayInfo, verbose: bool) -> } format!( - "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{bridge_type}{obfuscator_type}", + "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{bridge_type}{obfuscator_type}", first_hop = first_hop.unwrap_or_default(), bridge = bridge.unwrap_or_default(), obfuscator = obfuscator.unwrap_or_default(), diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index b75ad9121c..c013361925 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -220,6 +220,8 @@ pub enum DaemonCommand { SetBridgeState(ResponseTx<(), settings::Error>, BridgeState), /// Set if IPv6 should be enabled in the tunnel SetEnableIpv6(ResponseTx<(), settings::Error>, bool), + /// Set whether to enable PQ PSK exchange in the tunnel + SetQuantumResistantTunnel(ResponseTx<(), settings::Error>, bool), /// Set DNS options or servers to use SetDnsOptions(ResponseTx<(), settings::Error>, DnsOptions), /// Toggle macOS network check leak @@ -979,6 +981,9 @@ where } SetBridgeState(tx, bridge_state) => self.on_set_bridge_state(tx, bridge_state).await, SetEnableIpv6(tx, enable_ipv6) => self.on_set_enable_ipv6(tx, enable_ipv6).await, + SetQuantumResistantTunnel(tx, enable_pq) => { + self.on_set_quantum_resistant_tunnel(tx, enable_pq).await + } SetDnsOptions(tx, dns_servers) => self.on_set_dns_options(tx, dns_servers).await, SetWireguardMtu(tx, mtu) => self.on_set_wireguard_mtu(tx, mtu).await, SetWireguardRotationInterval(tx, interval) => { @@ -1053,7 +1058,7 @@ where } } PrivateDeviceEvent::RotatedKey(_) => { - if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { + if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) { self.schedule_reconnect(WG_RECONNECT_DELAY); } } @@ -1683,7 +1688,7 @@ where .set_tunnel_options(&self.settings.tunnel_options); self.event_listener .notify_settings(self.settings.to_settings()); - if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { + if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { log::info!("Initiating tunnel restart"); self.reconnect_tunnel(); } @@ -1834,7 +1839,7 @@ where .set_tunnel_options(&self.settings.tunnel_options); self.event_listener .notify_settings(self.settings.to_settings()); - if let Some(TunnelType::OpenVpn) = self.get_connected_tunnel_type() { + if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) { log::info!( "Initiating tunnel restart because the OpenVPN mssfix setting changed" ); @@ -1954,6 +1959,36 @@ where } } + async fn on_set_quantum_resistant_tunnel( + &mut self, + tx: ResponseTx<(), settings::Error>, + use_pq_safe_psk: bool, + ) { + let save_result = self + .settings + .set_quantum_resistant_tunnel(use_pq_safe_psk) + .await; + match save_result { + Ok(settings_changed) => { + Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response"); + if settings_changed { + self.parameters_generator + .set_tunnel_options(&self.settings.tunnel_options); + self.event_listener + .notify_settings(self.settings.to_settings()); + if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) { + log::info!("Reconnecting because the PQ safety setting changed"); + self.reconnect_tunnel(); + } + } + } + Err(e) => { + log::error!("{}", e.display_chain_with_msg("Unable to save settings")); + Self::oneshot_send(tx, Err(e), "set_quantum_resistant_tunnel response"); + } + } + } + async fn on_set_dns_options( &mut self, tx: ResponseTx<(), settings::Error>, diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 2264159fc8..a7478eb048 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -373,6 +373,17 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_settings_error) } + async fn set_quantum_resistant_tunnel(&self, request: Request<bool>) -> ServiceResult<()> { + let enable = request.into_inner(); + log::debug!("set_quantum_resistant_tunnel({})", enable); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::SetQuantumResistantTunnel(tx, enable))?; + self.wait_for_result(rx) + .await? + .map(Response::new) + .map_err(map_settings_error) + } + #[cfg(not(target_os = "android"))] async fn set_dns_options(&self, request: Request<types::DnsOptions>) -> ServiceResult<()> { let options = DnsOptions::try_from(request.into_inner()).map_err(map_protobuf_type_err)?; diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index 52800028cc..e69ca8316f 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -236,6 +236,22 @@ impl SettingsPersister { self.update(should_save).await } + pub async fn set_quantum_resistant_tunnel( + &mut self, + use_pq_safe_psk: bool, + ) -> Result<bool, Error> { + let should_save = Self::update_field( + &mut self + .settings + .tunnel_options + .wireguard + .options + .use_pq_safe_psk, + use_pq_safe_psk, + ); + self.update(should_save).await + } + pub async fn set_dns_options(&mut self, options: DnsOptions) -> Result<bool, Error> { let should_save = Self::update_field(&mut self.settings.tunnel_options.dns_options, options); diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 20d5318a58..9148a10150 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -43,6 +43,7 @@ service ManagementService { rpc SetOpenvpnMssfix(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {} rpc SetWireguardMtu(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {} rpc SetEnableIpv6(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} + rpc SetQuantumResistantTunnel(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} rpc SetDnsOptions(DnsOptions) returns (google.protobuf.Empty) {} // Account management @@ -193,9 +194,10 @@ message TunnelEndpoint { string address = 1; TransportProtocol protocol = 2; TunnelType tunnel_type = 3; - ProxyEndpoint proxy = 4; - ObfuscationEndpoint obfuscation = 5; - Endpoint entry_endpoint = 6; + bool quantum_resistant = 4; + ProxyEndpoint proxy = 5; + ObfuscationEndpoint obfuscation = 6; + Endpoint entry_endpoint = 7; } enum ObfuscationType { @@ -435,6 +437,7 @@ message TunnelOptions { uint32 mtu = 1; google.protobuf.Duration rotation_interval = 2; bool use_wireguard_nt = 3; + bool use_pq_safe_psk = 4; } message GenericOptions { bool enable_ipv6 = 1; diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 01f16d8220..cc2086d66e 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -35,6 +35,7 @@ impl From<talpid_types::net::TunnelEndpoint> for TunnelEndpoint { net::TunnelType::Wireguard => i32::from(TunnelType::Wireguard), net::TunnelType::OpenVpn => i32::from(TunnelType::Openvpn), }, + quantum_resistant: endpoint.quantum_resistant, proxy: endpoint.proxy.map(|proxy_ep| ProxyEndpoint { address: proxy_ep.endpoint.address.to_string(), protocol: i32::from(TransportProtocol::from(proxy_ep.endpoint.protocol)), @@ -680,6 +681,7 @@ impl From<&mullvad_types::settings::TunnelOptions> for TunnelOptions { use_wireguard_nt: options.wireguard.options.use_wireguard_nt, #[cfg(not(windows))] use_wireguard_nt: false, + use_pq_safe_psk: options.wireguard.options.use_pq_safe_psk, }), generic: Some(tunnel_options::GenericOptions { enable_ipv6: options.generic.enable_ipv6, @@ -1185,6 +1187,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { public_key, allowed_ips, endpoint, + psk: None, }, exit_peer: None, ipv4_gateway, @@ -1412,6 +1415,7 @@ impl TryFrom<TunnelOptions> for mullvad_types::settings::TunnelOptions { } else { None }, + use_pq_safe_psk: wireguard_options.use_pq_safe_psk, #[cfg(windows)] use_wireguard_nt: wireguard_options.use_wireguard_nt, }, diff --git a/mullvad-relay-selector/src/matcher.rs b/mullvad-relay-selector/src/matcher.rs index 089510a6c0..13e16646ab 100644 --- a/mullvad-relay-selector/src/matcher.rs +++ b/mullvad-relay-selector/src/matcher.rs @@ -189,6 +189,7 @@ impl WireguardMatcher { public_key: data.public_key, endpoint: SocketAddr::new(host, port), allowed_ips: all_of_the_internet(), + psk: None, }; Some(MullvadEndpoint::Wireguard(MullvadWireguardEndpoint { peer: peer_config, diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs index 4c05f1e552..da7f25828a 100644 --- a/mullvad-types/src/wireguard.rs +++ b/mullvad-types/src/wireguard.rs @@ -113,7 +113,8 @@ impl Default for RotationInterval { } } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(default)] #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr( target_os = "android", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index e422a537ca..b6ff02bd8c 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -26,6 +26,7 @@ regex = "1.1.0" shell-escape = "0.1" talpid-types = { path = "../talpid-types" } talpid-time = { path = "../talpid-time" } +talpid-tunnel-config-client = { path = "../talpid-tunnel-config-client" } uuid = { version = "0.8", features = ["v4"] } zeroize = "1" chrono = "0.4.19" diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs index 8928e94e3c..4d66e94b1a 100644 --- a/talpid-core/src/firewall/linux.rs +++ b/talpid-core/src/firewall/linux.rs @@ -12,9 +12,9 @@ use std::{ env, ffi::{CStr, CString}, io, - net::{IpAddr, Ipv4Addr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, }; -use talpid_types::net::{Endpoint, TransportProtocol}; +use talpid_types::net::{AllowedTunnelTraffic, Endpoint, Protocol, TransportProtocol}; /// Priority for rules that tag split tunneling packets. Equals NF_IP_PRI_MANGLE. const MANGLE_CHAIN_PRIORITY: i32 = libc::NF_IP_PRI_MANGLE; @@ -558,6 +558,7 @@ impl<'a> PolicyBatch<'a> { tunnel, allow_lan, allowed_endpoint, + allowed_tunnel_traffic, } => { self.add_allow_tunnel_endpoint_rules(peer_endpoint); self.add_allow_endpoint_rules(&allowed_endpoint.endpoint); @@ -567,7 +568,19 @@ impl<'a> PolicyBatch<'a> { self.add_drop_dns_rule(); if let Some(tunnel) = tunnel { - self.add_allow_tunnel_rules(&tunnel.interface)?; + match allowed_tunnel_traffic { + AllowedTunnelTraffic::All => { + self.add_allow_tunnel_rules(&tunnel.interface)?; + } + AllowedTunnelTraffic::None => (), + AllowedTunnelTraffic::Only(address, protocol) => { + self.add_allow_in_tunnel_endpoint_rules( + &tunnel.interface, + *address, + *protocol, + )?; + } + } if *allow_lan { self.add_block_cve_2019_14899(tunnel); } @@ -771,6 +784,35 @@ impl<'a> PolicyBatch<'a> { } } + fn add_allow_in_tunnel_endpoint_rules( + &mut self, + tunnel_interface: &str, + address: SocketAddr, + protocol: Protocol, + ) -> Result<()> { + for (chain, dir, end) in [ + (&self.out_chain, Direction::Out, End::Dst), + (&self.in_chain, Direction::In, End::Src), + ] { + let mut rule = Rule::new(chain); + + check_iface(&mut rule, dir, tunnel_interface)?; + check_ip(&mut rule, end, address.ip()); + match protocol { + Protocol::IcmpV4 | Protocol::IcmpV6 => check_l4proto(&mut rule, protocol), + Protocol::Tcp => { + check_port(&mut rule, TransportProtocol::Tcp, end, address.port()); + } + Protocol::Udp => { + check_port(&mut rule, TransportProtocol::Udp, end, address.port()); + } + } + add_verdict(&mut rule, &Verdict::Accept); + self.batch.add(&rule, nftnl::MsgType::Add); + } + Ok(()) + } + fn add_allow_tunnel_rules(&mut self, tunnel_interface: &str) -> Result<()> { self.batch.add( &allow_interface_rule(&self.out_chain, Direction::Out, tunnel_interface)?, @@ -988,7 +1030,7 @@ fn check_ip(rule: &mut Rule<'_>, end: End, ip: impl Into<IpAddr>) { fn check_port(rule: &mut Rule<'_>, protocol: TransportProtocol, end: End, port: u16) { // Must check transport layer protocol before loading transport layer payload - check_l4proto(rule, protocol); + check_l4proto(rule, protocol.into()); rule.add_expr(&match (protocol, end) { (TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport), @@ -1011,15 +1053,17 @@ fn l3proto(addr: IpAddr) -> u8 { } } -fn check_l4proto(rule: &mut Rule<'_>, protocol: TransportProtocol) { +fn check_l4proto(rule: &mut Rule<'_>, protocol: Protocol) { rule.add_expr(&nft_expr!(meta l4proto)); rule.add_expr(&nft_expr!(cmp == l4proto(protocol))); } -fn l4proto(protocol: TransportProtocol) -> u8 { +fn l4proto(protocol: Protocol) -> u8 { match protocol { - TransportProtocol::Udp => libc::IPPROTO_UDP as u8, - TransportProtocol::Tcp => libc::IPPROTO_TCP as u8, + Protocol::Udp => libc::IPPROTO_UDP as u8, + Protocol::Tcp => libc::IPPROTO_TCP as u8, + Protocol::IcmpV4 => libc::IPPROTO_ICMP as u8, + Protocol::IcmpV6 => libc::IPPROTO_ICMPV6 as u8, } } diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs index 7c2e1455b2..fa40ae7e0d 100644 --- a/talpid-core/src/firewall/macos.rs +++ b/talpid-core/src/firewall/macos.rs @@ -6,7 +6,7 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; use subslice::SubsliceExt; -use talpid_types::net; +use talpid_types::net::{self, AllowedTunnelTraffic}; pub use pfctl::Error; @@ -119,6 +119,7 @@ impl Firewall { tunnel, allow_lan, allowed_endpoint, + allowed_tunnel_traffic, } => { let mut rules = vec![self.get_allow_relay_rule(*peer_endpoint)?]; rules.push(self.get_allowed_endpoint_rule(allowed_endpoint.endpoint)?); @@ -128,7 +129,10 @@ impl Firewall { rules.append(&mut self.get_block_dns_rules()?); if let Some(tunnel) = tunnel { - rules.push(self.get_allow_tunnel_rule(&tunnel.interface)?); + rules.extend( + self.get_allow_tunnel_rule(&tunnel.interface, allowed_tunnel_traffic)? + .into_iter(), + ); } if *allow_lan { @@ -154,7 +158,13 @@ impl Firewall { // can't leak to the wrong IPs in the tunnel or on the LAN. rules.append(&mut self.get_block_dns_rules()?); - rules.push(self.get_allow_tunnel_rule(tunnel.interface.as_str())?); + rules.extend( + self.get_allow_tunnel_rule( + tunnel.interface.as_str(), + &AllowedTunnelTraffic::All, + )? + .into_iter(), + ); if *allow_lan { rules.append(&mut self.get_allow_lan_rules()?); @@ -318,14 +328,34 @@ impl Firewall { Ok(vec![block_tcp_dns_rule, block_udp_dns_rule]) } - fn get_allow_tunnel_rule(&self, tunnel_interface: &str) -> Result<pfctl::FilterRule> { - Ok(self - .create_rule_builder(FilterRuleAction::Pass) + fn get_allow_tunnel_rule( + &self, + tunnel_interface: &str, + allowed_traffic: &AllowedTunnelTraffic, + ) -> Result<Option<pfctl::FilterRule>> { + let mut rule_builder = self.create_rule_builder(FilterRuleAction::Pass); + let mut base_rule = rule_builder .quick(true) .interface(tunnel_interface) .keep_state(pfctl::StatePolicy::Keep) - .tcp_flags(Self::get_tcp_flags()) - .build()?) + .tcp_flags(Self::get_tcp_flags()); + match allowed_traffic { + AllowedTunnelTraffic::Only(addr, protocol) => { + use talpid_types::net::Protocol::*; + let pfctl_proto = match protocol { + Udp => pfctl::Proto::Udp, + Tcp => pfctl::Proto::Tcp, + IcmpV4 => pfctl::Proto::Icmp, + IcmpV6 => pfctl::Proto::IcmpV6, + }; + base_rule = base_rule.to(*addr).proto(pfctl_proto); + } + AllowedTunnelTraffic::All => {} + AllowedTunnelTraffic::None => { + return Ok(None); + } + }; + Ok(Some(base_rule.build()?)) } fn get_allow_loopback_rules(&self) -> Result<Vec<pfctl::FilterRule>> { diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index b23e16b017..a72ad65b1b 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -8,7 +8,7 @@ use std::{ fmt, net::{Ipv4Addr, Ipv6Addr}, }; -use talpid_types::net::{AllowedEndpoint, Endpoint}; +use talpid_types::net::{AllowedEndpoint, AllowedTunnelTraffic, Endpoint}; #[cfg(target_os = "macos")] #[path = "macos.rs"] @@ -109,6 +109,8 @@ pub enum FirewallPolicy { allow_lan: bool, /// Host that should be reachable while connecting. allowed_endpoint: AllowedEndpoint, + /// Networks for which to permit in-tunnel traffic. + allowed_tunnel_traffic: AllowedTunnelTraffic, /// A process that is allowed to send packets to the relay. #[cfg(windows)] relay_client: PathBuf, @@ -151,12 +153,13 @@ impl fmt::Display for FirewallPolicy { tunnel, allow_lan, allowed_endpoint, + allowed_tunnel_traffic, .. } => { if let Some(tunnel) = tunnel { write!( f, - "Connecting to {} over \"{}\" (ip: {}, v4 gw: {}, v6 gw: {:?}), {} LAN. Allowing endpoint {}", + "Connecting to {} over \"{}\" (ip: {}, v4 gw: {}, v6 gw: {:?}, allowed in-tunnel traffic: {}), {} LAN. Allowing endpoint {}", peer_endpoint, tunnel.interface, tunnel @@ -167,6 +170,7 @@ impl fmt::Display for FirewallPolicy { .join(","), tunnel.ipv4_gateway, tunnel.ipv6_gateway, + allowed_tunnel_traffic, if *allow_lan { "Allowing" } else { "Blocking" }, allowed_endpoint, ) diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index dca86d2257..4683924d17 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -6,7 +6,7 @@ use self::winfw::*; use super::{FirewallArguments, FirewallPolicy, InitialFirewallState}; use crate::winnet; use talpid_types::{ - net::{AllowedEndpoint, Endpoint}, + net::{AllowedEndpoint, AllowedTunnelTraffic, Endpoint}, tunnel::FirewallPolicyError, }; use widestring::WideCString; @@ -102,6 +102,7 @@ impl Firewall { tunnel, allow_lan, allowed_endpoint, + allowed_tunnel_traffic, relay_client, } => { let cfg = &WinFwSettings::new(allow_lan); @@ -111,6 +112,7 @@ impl Firewall { &cfg, &tunnel, &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(), + &allowed_tunnel_traffic, &relay_client, ) } @@ -148,6 +150,7 @@ impl Firewall { winfw_settings: &WinFwSettings, tunnel_metadata: &Option<TunnelMetadata>, allowed_endpoint: &WinFwAllowedEndpoint<'_>, + allowed_tunnel_traffic: &AllowedTunnelTraffic, relay_client: &Path, ) -> Result<(), Error> { log::trace!("Applying 'connecting' firewall policy"); @@ -169,6 +172,26 @@ impl Firewall { ptr::null() }; + let allowed_tun_ip; + let allowed_tunnel_endpoint = + if let AllowedTunnelTraffic::Only(addr, proto) = allowed_tunnel_traffic { + allowed_tun_ip = widestring_ip(addr.ip()); + Some(WinFwEndpoint { + ip: allowed_tun_ip.as_ptr(), + port: addr.port(), + protocol: WinFwProt::from(*proto), + }) + } else { + None + }; + let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic { + type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic), + endpoint: allowed_tunnel_endpoint + .as_ref() + .map(|ep| ep as *const _) + .unwrap_or(ptr::null()), + }; + unsafe { WinFw_ApplyPolicyConnecting( winfw_settings, @@ -176,6 +199,7 @@ impl Firewall { relay_client.as_ptr(), interface_wstr_ptr, allowed_endpoint, + &allowed_tunnel_traffic, ) .into_result() .map_err(Error::ApplyingConnectingPolicy) @@ -276,10 +300,10 @@ fn widestring_ip(ip: IpAddr) -> WideCString { #[allow(non_snake_case)] mod winfw { - use super::{widestring_ip, AllowedEndpoint, Error, WideCString}; + use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString}; use crate::logging::windows::LogSink; use libc; - use talpid_types::net::TransportProtocol; + use talpid_types::net::{Protocol, TransportProtocol}; pub struct WinFwAllowedEndpointContainer { _clients: Box<[WideCString]>, @@ -338,6 +362,30 @@ mod winfw { } #[repr(C)] + pub struct WinFwAllowedTunnelTraffic { + pub type_: WinFwAllowedTunnelTrafficType, + pub endpoint: *const WinFwEndpoint, + } + + #[repr(u8)] + #[derive(Clone, Copy)] + pub enum WinFwAllowedTunnelTrafficType { + None, + All, + Only, + } + + impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType { + fn from(traffic: &AllowedTunnelTraffic) -> Self { + match traffic { + AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None, + AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All, + AllowedTunnelTraffic::Only(..) => WinFwAllowedTunnelTrafficType::Only, + } + } + } + + #[repr(C)] pub struct WinFwEndpoint { pub ip: *const libc::wchar_t, pub port: u16, @@ -349,6 +397,8 @@ mod winfw { pub enum WinFwProt { Tcp = 0u8, Udp = 1u8, + IcmpV4 = 2u8, + IcmpV6 = 3u8, } impl From<TransportProtocol> for WinFwProt { @@ -360,6 +410,17 @@ mod winfw { } } + impl From<Protocol> for WinFwProt { + fn from(prot: Protocol) -> WinFwProt { + match prot { + Protocol::Tcp => WinFwProt::Tcp, + Protocol::Udp => WinFwProt::Udp, + Protocol::IcmpV4 => WinFwProt::IcmpV4, + Protocol::IcmpV6 => WinFwProt::IcmpV6, + } + } + } + #[repr(C)] pub struct WinFwSettings { permitDhcp: bool, @@ -440,7 +501,8 @@ mod winfw { relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, tunnelIfaceAlias: *const libc::wchar_t, - allowed_endpoint: *const WinFwAllowedEndpoint<'_>, + allowedEndpoint: *const WinFwAllowedEndpoint<'_>, + allowedTunnelTraffic: &WinFwAllowedTunnelTraffic, ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyConnected"] diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index ea7ec35133..bafc839cf6 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -8,7 +8,7 @@ use std::{ }; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; -use talpid_types::net::{wireguard as wireguard_types, TunnelParameters}; +use talpid_types::net::{wireguard as wireguard_types, AllowedTunnelTraffic, TunnelParameters}; #[cfg(target_os = "android")] pub use self::tun_provider::TunConfig; @@ -73,7 +73,7 @@ pub enum TunnelEvent { /// Sent when the tunnel fails to connect due to an authentication error. AuthFailed(Option<String>), /// Sent when the tunnel interface has been created, before routes are set up. - InterfaceUp(TunnelMetadata), + InterfaceUp(TunnelMetadata, AllowedTunnelTraffic), /// Sent when the tunnel comes up and is ready for traffic. Up(TunnelMetadata), /// Sent when the tunnel goes down. @@ -198,6 +198,18 @@ impl TunnelMonitor { let monitor = wireguard::WireguardMonitor::start( runtime, config, + if params.options.use_pq_safe_psk { + Some( + params + .connection + .exit_peer + .as_ref() + .map(|peer| peer.public_key.clone()) + .unwrap_or(params.connection.peer.public_key.clone()), + ) + } else { + None + }, log.as_deref(), resource_dir, on_event, diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index cb885fb8ec..d9ddf8dd9a 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -873,9 +873,10 @@ mod event_server { request: Request<EventDetails>, ) -> std::result::Result<Response<()>, tonic::Status> { let env = request.into_inner().env; - (self.on_event)(super::TunnelEvent::InterfaceUp(Self::get_tunnel_metadata( - &env, - )?)) + (self.on_event)(super::TunnelEvent::InterfaceUp( + Self::get_tunnel_metadata(&env)?, + talpid_types::net::AllowedTunnelTraffic::All, + )) .await; Ok(Response::new(())) } diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs index 026c6c9fae..ed285742c4 100644 --- a/talpid-core/src/tunnel/wireguard/config.rs +++ b/talpid-core/src/tunnel/wireguard/config.rs @@ -6,6 +6,7 @@ use std::{ use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions}; /// Config required to set up a single WireGuard tunnel +#[derive(Clone)] pub struct Config { /// Contains tunnel endpoint specific config pub tunnel: wireguard::TunnelConfig, @@ -147,6 +148,9 @@ impl Config { .add("public_key", peer.public_key.as_bytes().as_ref()) .add("endpoint", peer.endpoint.to_string().as_str()) .add("replace_allowed_ips", "true"); + if let Some(ref psk) = peer.psk { + wg_conf.add("preshared_key", psk.as_bytes().as_ref()); + } for addr in &peer.allowed_ips { wg_conf.add("allowed_ip", addr.to_string().as_str()); } diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index ec2af873a0..bcb28d7c17 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -391,12 +391,16 @@ impl ConnState { #[cfg(test)] mod test { + use futures::Future; + use super::*; use crate::tunnel::wireguard::{ + config::Config, stats::{self, Stats}, TunnelError, }; use std::{ + pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, Mutex, @@ -598,6 +602,13 @@ mod test { fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> { (self.on_get_stats)() } + + fn set_config( + &self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { + Box::pin(async { Ok(()) }) + } } fn mock_monitor( diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 80c8e2ae8e..3d0de8ba5c 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -8,6 +8,7 @@ use futures::{channel::mpsc, StreamExt}; use futures::{ channel::oneshot, future::{abortable, AbortHandle as FutureAbortHandle}, + Future, }; #[cfg(target_os = "linux")] use lazy_static::lazy_static; @@ -18,14 +19,20 @@ use std::env; #[cfg(windows)] use std::io; use std::{ + borrow::Cow, convert::Infallible, - net::IpAddr, + net::{IpAddr, SocketAddrV4}, path::Path, + pin::Pin, sync::{mpsc as sync_mpsc, Arc, Mutex}, + time::Duration, }; #[cfg(windows)] use talpid_types::BoxedError; -use talpid_types::{net::obfuscation::ObfuscatorConfig, ErrorExt}; +use talpid_types::{ + net::{obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Protocol}, + ErrorExt, +}; use tunnel_obfuscation::{ create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings, }; @@ -73,6 +80,10 @@ pub enum Error { #[error(display = "Connectivity monitor failed")] ConnectivityMonitorError(#[error(source)] connectivity_check::Error), + /// Failed to negotiate PQ PSK + #[error(display = "Failed to negotiate PQ PSK")] + PskNegotiationError(#[error(source)] talpid_tunnel_config_client::Error), + /// Failed to set up IP interfaces. #[cfg(windows)] #[error(display = "Failed to set up IP interfaces")] @@ -101,6 +112,10 @@ pub struct WireguardMonitor { _obfuscator: Option<ObfuscatorHandle>, } +const INITIAL_PSK_EXCHANGE_TIMEOUT: Duration = Duration::from_secs(4); +const MAX_PSK_EXCHANGE_TIMEOUT: Duration = Duration::from_secs(15); +const PSK_EXCHANGE_TIMEOUT_MULTIPLIER: u32 = 2; + /// Simple wrapper that automatically cancels the future which runs an obfuscator. struct ObfuscatorHandle { abort_handle: FutureAbortHandle, @@ -180,7 +195,7 @@ fn maybe_create_obfuscator( impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config pub fn start< - F: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + Send + Sync + Clone @@ -188,6 +203,7 @@ impl WireguardMonitor { >( runtime: tokio::runtime::Handle, mut config: Config, + psk_negotiation: Option<PublicKey>, log_path: Option<&Path>, resource_dir: &Path, on_event: F, @@ -203,10 +219,11 @@ impl WireguardMonitor { let obfuscator = maybe_create_obfuscator(&runtime, &mut config, close_msg_sender.clone())?; #[cfg(target_os = "windows")] - let (setup_done_tx, mut setup_done_rx) = mpsc::channel(0); + let (setup_done_tx, setup_done_rx) = mpsc::channel(0); + let tunnel = Self::open_tunnel( runtime.clone(), - &config, + &Self::patch_allowed_ips(&config, psk_negotiation.is_some()), log_path, resource_dir, tun_provider, @@ -237,31 +254,22 @@ impl WireguardMonitor { .map_err(Error::ConnectivityMonitorError)?; let metadata = Self::tunnel_metadata(&iface_name, &config); + let tunnel = monitor.tunnel.clone(); let tunnel_fut = async move { #[cfg(windows)] - { - setup_done_rx - .next() - .await - .ok_or_else(|| { - // Tunnel was shut down early - CloseMsg::SetupError(Error::IpInterfacesError) - })? - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to configure tunnel interface") - ); - CloseMsg::SetupError(Error::IpInterfacesError) - })?; + Self::add_device_ip_addresses(&iface_name, &config.tunnel.addresses, setup_done_rx) + .await?; - if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { - return Err(CloseMsg::SetupError(Error::SetIpAddressesError)); - } - } - - (on_event)(TunnelEvent::InterfaceUp(metadata.clone())).await; + let allowed_traffic = if psk_negotiation.is_some() { + AllowedTunnelTraffic::Only( + SocketAddrV4::new(config.ipv4_gateway, 1337).into(), + Protocol::Tcp, + ) + } else { + AllowedTunnelTraffic::All + }; + (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; // Add non-default routes before establishing the tunnel. #[cfg(target_os = "linux")] @@ -280,6 +288,15 @@ impl WireguardMonitor { .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; + if let Some(pubkey) = psk_negotiation { + Self::perform_psk_negotiation(tunnel, retry_attempt, pubkey, &mut config).await?; + (on_event)(TunnelEvent::InterfaceUp( + metadata.clone(), + AllowedTunnelTraffic::All, + )) + .await; + } + let mut connectivity_monitor = tokio::task::spawn_blocking(move || { match connectivity_monitor.establish_connectivity(retry_attempt) { Ok(true) => Ok(connectivity_monitor), @@ -339,6 +356,125 @@ impl WireguardMonitor { Ok(monitor) } + /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. + /// Used to block traffic to other destinations while connecting on Android. + fn patch_allowed_ips<'a>(config: &'a Config, gateway_only: bool) -> Cow<'a, Config> { + if gateway_only { + let mut patched_config = config.clone(); + let gateway_net_v4 = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway)); + let gateway_net_v6 = config + .ipv6_gateway + .map(|net| ipnetwork::IpNetwork::from(IpAddr::from(net))); + for peer in &mut patched_config.peers { + peer.allowed_ips = peer + .allowed_ips + .iter() + .cloned() + .filter_map(|mut allowed_ip| { + if allowed_ip.prefix() == 0 { + if allowed_ip.is_ipv4() { + allowed_ip = gateway_net_v4; + } else { + if let Some(net) = gateway_net_v6 { + allowed_ip = net; + } else { + return None; + } + } + } + Some(allowed_ip) + }) + .collect(); + } + Cow::Owned(patched_config) + } else { + Cow::Borrowed(config) + } + } + + #[cfg(windows)] + async fn add_device_ip_addresses( + iface_name: &str, + addresses: &[IpAddr], + mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>, + ) -> std::result::Result<(), CloseMsg> { + setup_done_rx + .next() + .await + .ok_or_else(|| { + // Tunnel was shut down early + CloseMsg::SetupError(Error::IpInterfacesError) + })? + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to configure tunnel interface") + ); + CloseMsg::SetupError(Error::IpInterfacesError) + })?; + if !crate::winnet::add_device_ip_addresses(iface_name, addresses) { + return Err(CloseMsg::SetupError(Error::SetIpAddressesError)); + } + Ok(()) + } + + async fn perform_psk_negotiation( + tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, + retry_attempt: u32, + current_pubkey: PublicKey, + config: &mut Config, + ) -> std::result::Result<(), CloseMsg> { + log::debug!("Performing PQ-safe PSK exchange"); + + let timeout = std::cmp::min( + MAX_PSK_EXCHANGE_TIMEOUT, + INITIAL_PSK_EXCHANGE_TIMEOUT + .saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)), + ); + + let (private_key, psk) = tokio::time::timeout( + timeout, + talpid_tunnel_config_client::push_pq_key( + IpAddr::V4(config.ipv4_gateway), + config.tunnel.private_key.public_key(), + ), + ) + .await + .map_err(|_timeout_err| { + log::warn!("Timeout while negotiating PSK"); + CloseMsg::PskNegotiationTimeout + })? + .map_err(Error::PskNegotiationError) + .map_err(CloseMsg::SetupError)?; + + config.tunnel.private_key = private_key; + + for peer in &mut config.peers { + if current_pubkey == peer.public_key { + peer.psk = Some(psk); + break; + } + } + + log::trace!( + "Ephemeral pubkey: {}", + config.tunnel.private_key.public_key() + ); + + let set_config_future = tunnel + .lock() + .unwrap() + .as_ref() + .map(|tunnel| tunnel.set_config(config.clone())); + if let Some(f) = set_config_future { + f.await + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } + + Ok(()) + } + #[allow(unused_variables)] fn open_tunnel( runtime: tokio::runtime::Handle, @@ -424,7 +560,7 @@ impl WireguardMonitor { /// Blocks the current thread until tunnel disconnects pub fn wait(mut self) -> Result<()> { let wait_result = match self.close_msg_receiver.recv() { - Ok(CloseMsg::PingErr) => Err(Error::TimeoutError), + Ok(CloseMsg::PskNegotiationTimeout) | Ok(CloseMsg::PingErr) => Err(Error::TimeoutError), Ok(CloseMsg::Stop) | Ok(CloseMsg::ObfuscatorExpired) => Ok(()), Ok(CloseMsg::SetupError(error)) => Err(error), Ok(CloseMsg::ObfuscatorFailed(error)) => Err(error), @@ -584,6 +720,7 @@ impl WireguardMonitor { enum CloseMsg { Stop, + PskNegotiationTimeout, PingErr, SetupError(Error), ObfuscatorExpired, @@ -594,6 +731,10 @@ pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>; + fn set_config( + &self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>; } /// Errors to be returned from WireGuard implementations, namely implementers of the Tunnel trait @@ -630,6 +771,10 @@ pub enum TunnelError { #[error(display = "Failed to get config of WireGuard tunnel")] GetConfigError, + /// Failed to set WireGuard tunnel config on device + #[error(display = "Failed to set config of WireGuard tunnel")] + SetConfigError, + /// Failed to duplicate tunnel file descriptor for wireguard-go #[cfg(any(target_os = "linux", target_os = "macos", target_os = "android"))] #[error(display = "Failed to duplicate tunnel file descriptor for wireguard-go")] diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index a3d34acc0e..4fbdcefcdd 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -13,8 +13,10 @@ use futures::SinkExt; use ipnetwork::IpNetwork; use std::{ ffi::{c_void, CStr}, + future::Future, os::raw::c_char, path::Path, + pin::Pin, }; #[cfg(windows)] use talpid_types::BoxedError; @@ -354,6 +356,21 @@ impl Tunnel for WgGoTunnel { fn stop(mut self: Box<Self>) -> Result<()> { self.stop_tunnel() } + + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { + let wg_config_str = config.to_userspace_format(); + let handle = self.handle.unwrap(); + Box::pin(async move { + let status = unsafe { wgSetConfig(handle, wg_config_str.as_ptr() as *const i8) }; + if status != 0 { + return Err(TunnelError::SetConfigError); + } + Ok(()) + }) + } } fn check_wg_status(wg_code: i32) -> Result<()> { @@ -422,6 +439,9 @@ extern "C" { // Returns the file descriptor of the tunnel IPv4 socket. fn wgGetConfig(handle: i32) -> *mut std::os::raw::c_char; + // Sets the config of the WireGuard interface. + fn wgSetConfig(handle: i32, settings: *const i8) -> i32; + // Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig fn wgFreePtr(ptr: *mut c_void); diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs index 8ab3234fd5..109ae75367 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs @@ -1,3 +1,7 @@ +use std::pin::Pin; + +use futures::Future; + use super::{ super::stats::{Stats, StatsMap}, wg_message::DeviceNla, @@ -109,4 +113,20 @@ impl Tunnel for NetlinkTunnel { result } + + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> { + let mut wg = self.netlink_connections.wg_handle.clone(); + let interface_index = self.interface_index; + Box::pin(async move { + wg.set_config(interface_index, &config) + .await + .map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::SetConfigError + }) + }) + } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs index fed29b93d9..92a8f5e81c 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs @@ -2,7 +2,8 @@ use super::{ super::stats::{Stats, StatsMap}, Config, Error as WgKernelError, Handle, Tunnel, TunnelError, MULLVAD_INTERFACE_NAME, }; -use std::collections::HashMap; +use futures::Future; +use std::{collections::HashMap, pin::Pin}; use talpid_dbus::{ dbus, network_manager::{ @@ -91,6 +92,24 @@ impl Tunnel for NetworkManagerTunnel { Ok(Stats::parse_device_message(&device)) }) } + + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { + let interface_name = self.interface_name.clone(); + let mut wg = self.netlink_connections.wg_handle.clone(); + Box::pin(async move { + let index = crate::linux::iface_index(&interface_name).map_err(|err| { + log::error!("Failed to fetch WireGuard device index: {}", err); + TunnelError::SetConfigError + })?; + wg.set_config(index, &config).await.map_err(|err| { + log::error!("Failed to apply WireGuard config: {}", err); + TunnelError::SetConfigError + }) + }) + } } fn convert_config_to_dbus(config: &Config) -> DeviceConfig { diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs index 007acb8df7..dd17bc0cb3 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -81,12 +81,16 @@ impl DeviceMessage { for peer in config.peers.iter() { let peer_endpoint = InetAddr::from_std(&peer.endpoint); let allowed_ips = peer.allowed_ips.iter().map(From::from).collect(); - peers.push(PeerMessage(vec![ + let mut peer_nlas = vec![ PeerNla::PublicKey(*peer.public_key.as_bytes()), PeerNla::Endpoint(peer_endpoint), PeerNla::AllowedIps(allowed_ips), PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), - ])); + ]; + if let Some(psk) = peer.psk.as_ref() { + peer_nlas.push(PeerNla::PresharedKey(psk.as_bytes().clone())); + } + peers.push(PeerMessage(peer_nlas)); } let nlas = vec![ diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 21c9b705ce..0ae469144b 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -11,11 +11,14 @@ use ipnetwork::IpNetwork; use lazy_static::lazy_static; use std::{ ffi::CStr, - fmt, io, mem, + fmt, + future::Future, + io, mem, mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr}, os::windows::io::RawHandle, path::Path, + pin::Pin, ptr, sync::{Arc, Mutex}, }; @@ -833,11 +836,20 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { buffer.extend(windows::as_uninit_byte_slice(&header)); for peer in &config.peers { + let flags = if peer.psk.is_some() { + WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT + } else { + WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT + }; let wg_peer = WgPeer { - flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, + flags, reserved: 0, public_key: peer.public_key.as_bytes().clone(), - preshared_key: [0u8; WIREGUARD_KEY_LENGTH], + preshared_key: peer + .psk + .as_ref() + .map(|psk| psk.as_bytes().clone()) + .unwrap_or([0u8; WIREGUARD_KEY_LENGTH]), persistent_keepalive: 0, endpoint: windows::inet_sockaddr_from_socketaddr(peer.endpoint).into(), tx_bytes: 0, @@ -976,6 +988,24 @@ impl Tunnel for WgNtTunnel { self.stop_tunnel(); Ok(()) } + + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { + let device = self.device.clone(); + Box::pin(async move { + let guard = device.lock().unwrap(); + let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?; + device.set_config(&config).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set wg-nt tunnel config") + ); + super::TunnelError::SetConfigError + }) + }) + } } #[cfg(test)] @@ -1006,6 +1036,7 @@ mod tests { public_key: WG_PUBLIC_KEY.clone(), allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], endpoint: "1.2.3.4:1234".parse().unwrap(), + psk: None, }], ipv4_gateway: "0.0.0.0".parse().unwrap(), ipv6_gateway: None, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 747a5bdcbb..7536b26b09 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -21,7 +21,7 @@ use std::{ time::{Duration, Instant}, }; use talpid_types::{ - net::TunnelParameters, + net::{AllowedTunnelTraffic, TunnelParameters}, tunnel::{ErrorStateCause, FirewallPolicyError}, ErrorExt, }; @@ -47,6 +47,7 @@ pub struct ConnectingState { tunnel_events: TunnelEventsReceiver, tunnel_parameters: TunnelParameters, tunnel_metadata: Option<TunnelMetadata>, + allowed_tunnel_traffic: AllowedTunnelTraffic, tunnel_close_event: TunnelCloseEvent, tunnel_close_tx: oneshot::Sender<()>, retry_attempt: u32, @@ -57,6 +58,7 @@ impl ConnectingState { shared_values: &mut SharedTunnelStateValues, params: &TunnelParameters, tunnel_metadata: &Option<TunnelMetadata>, + allowed_tunnel_traffic: AllowedTunnelTraffic, ) -> Result<(), FirewallPolicyError> { #[cfg(target_os = "linux")] shared_values.disable_connectivity_check(); @@ -68,6 +70,7 @@ impl ConnectingState { tunnel: tunnel_metadata.clone(), allow_lan: shared_values.allow_lan, allowed_endpoint: shared_values.allowed_endpoint.clone(), + allowed_tunnel_traffic, #[cfg(windows)] relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, ¶ms), }; @@ -207,6 +210,7 @@ impl ConnectingState { tunnel_events: event_rx.fuse(), tunnel_parameters: parameters, tunnel_metadata: None, + allowed_tunnel_traffic: AllowedTunnelTraffic::None, tunnel_close_event: tunnel_close_event_rx.fuse(), tunnel_close_tx, retry_attempt, @@ -294,6 +298,7 @@ impl ConnectingState { shared_values, &self.tunnel_parameters, &self.tunnel_metadata, + self.allowed_tunnel_traffic.clone(), ) { Ok(()) => { cfg_if! { @@ -333,6 +338,7 @@ impl ConnectingState { shared_values, &self.tunnel_parameters, &self.tunnel_metadata, + self.allowed_tunnel_traffic.clone(), ) { let _ = tx.send(()); return self.disconnect( @@ -399,7 +405,7 @@ impl ConnectingState { shared_values, AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)), ), - Some((TunnelEvent::InterfaceUp(metadata), _done_tx)) => { + Some((TunnelEvent::InterfaceUp(metadata, allowed_tunnel_traffic), _done_tx)) => { #[cfg(windows)] if let Err(error) = shared_values .split_tunnel @@ -416,11 +422,15 @@ impl ConnectingState { AfterDisconnect::Block(ErrorStateCause::SplitTunnelError), ); } + + self.allowed_tunnel_traffic = allowed_tunnel_traffic; self.tunnel_metadata = Some(metadata); + match Self::set_firewall_policy( shared_values, &self.tunnel_parameters, &self.tunnel_metadata, + self.allowed_tunnel_traffic.clone(), ) { Ok(()) => SameState(self.into()), Err(error) => self.disconnect( @@ -549,9 +559,12 @@ impl TunnelState for ConnectingState { return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError); } - if let Err(error) = - Self::set_firewall_policy(shared_values, &tunnel_parameters, &None) - { + if let Err(error) = Self::set_firewall_policy( + shared_values, + &tunnel_parameters, + &None, + AllowedTunnelTraffic::None, + ) { ErrorState::enter( shared_values, ErrorStateCause::SetFirewallPolicyError(error), diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 7f31082541..3d0075257b 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -404,7 +404,7 @@ pub fn interface_luid_to_ip( } } -pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool { +pub fn add_device_ip_addresses(iface: &str, addresses: &[IpAddr]) -> bool { let raw_iface = WideCString::from_str(iface) .expect("Failed to convert UTF-8 string to null terminated UCS string") .into_raw(); diff --git a/talpid-tunnel-config-client/Cargo.toml b/talpid-tunnel-config-client/Cargo.toml new file mode 100644 index 0000000000..5cb8983c41 --- /dev/null +++ b/talpid-tunnel-config-client/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "talpid-tunnel-config-client" +version = "0.1.0" +authors = ["Mullvad VPN"] +description = "Uses the relay RPC service to set up PQ-safe peers, etc." +license = "GPL-3.0" +edition = "2021" +publish = false + +[dependencies] +log = "0.4" +rand = "0.8" +talpid-types = { path = "../talpid-types" } +tonic = "0.5" +prost = "0.8" +prost-types = "0.9" +tower = "0.4" +tokio = "1" +classic-mceliece-rust = { git = "https://github.com/mullvad/classic-mceliece-rust", rev = "5130d9e3bfbf54735177e15636a643366c250b78", features = ["mceliece8192128f"] } + +[dev-dependencies] +tokio = { version = "1", features = ["rt-multi-thread"] } + +[build-dependencies] +tonic-build = { version = "0.5", default-features = false, features = ["transport", "prost"] }
\ No newline at end of file diff --git a/talpid-tunnel-config-client/build.rs b/talpid-tunnel-config-client/build.rs new file mode 100644 index 0000000000..2732fb69d0 --- /dev/null +++ b/talpid-tunnel-config-client/build.rs @@ -0,0 +1,5 @@ +fn main() { + const PROTO_FILE: &str = "proto/tunnel_config.proto"; + tonic_build::compile_protos(PROTO_FILE).unwrap(); + println!("cargo:rerun-if-changed={}", PROTO_FILE); +} diff --git a/talpid-tunnel-config-client/examples/psk-exchange.rs b/talpid-tunnel-config-client/examples/psk-exchange.rs new file mode 100644 index 0000000000..b4607be950 --- /dev/null +++ b/talpid-tunnel-config-client/examples/psk-exchange.rs @@ -0,0 +1,25 @@ +use std::{ + io, + net::{IpAddr, Ipv4Addr}, +}; + +use talpid_types::net::wireguard::PublicKey; + +#[tokio::main] +async fn main() { + println!("Make sure you're connected to a WireGuard peer and enter your public key: "); + + let mut pubkey_s = String::new(); + io::stdin() + .read_line(&mut pubkey_s) + .expect("Failed to read from stdin"); + let pubkey = PublicKey::from_base64(pubkey_s.trim()).expect("Invalid public key"); + + let (private_key, psk) = + talpid_tunnel_config_client::push_pq_key(IpAddr::V4(Ipv4Addr::new(10, 64, 0, 1)), pubkey) + .await + .unwrap(); + + println!("private key: {:?}", private_key); + println!("psk: {:?}", psk); +} diff --git a/talpid-tunnel-config-client/proto/tunnel_config.proto b/talpid-tunnel-config-client/proto/tunnel_config.proto new file mode 100644 index 0000000000..9ce4232d0d --- /dev/null +++ b/talpid-tunnel-config-client/proto/tunnel_config.proto @@ -0,0 +1,38 @@ +// +// If you need to (re)generate the gRPC code, see prerequisites +// +// https://grpc.io/docs/languages/go/quickstart/ +// +// and then run +// +// protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative tunnel_config.proto +// +// from this directory. +// + +syntax = "proto3"; + +option go_package = "github.com/mullvad/wg-manager/server/tuncfg"; + +package tunnel_config; + +service PostQuantumSecure { + // PskExchangeExperimentalV0 uses the common API defined by LibOQS. See: + // https://github.com/open-quantum-safe/liboqs + rpc PskExchangeExperimentalV0(PskRequestExperimentalV0) returns (PskResponseExperimentalV0) {} +} + +message PskRequestExperimentalV0 { + bytes wg_pubkey = 1; + bytes wg_psk_pubkey = 2; + KemPubkeyExperimentalV0 kem_pubkey = 3; +} + +message KemPubkeyExperimentalV0 { + string algorithm_name = 1; + bytes key_data = 2; +} + +message PskResponseExperimentalV0 { + bytes ciphertext = 1; +} diff --git a/talpid-tunnel-config-client/src/kem.rs b/talpid-tunnel-config-client/src/kem.rs new file mode 100644 index 0000000000..35679ccddf --- /dev/null +++ b/talpid-tunnel-config-client/src/kem.rs @@ -0,0 +1,67 @@ +use std::fmt; + +use super::Error; + +use classic_mceliece_rust::{ + crypto_kem_dec, crypto_kem_keypair, CRYPTO_BYTES, CRYPTO_PUBLICKEYBYTES, CRYPTO_SECRETKEYBYTES, +}; +use talpid_types::net::wireguard::PresharedKey; + +const STACK_SIZE: usize = 8 * 1024 * 1024; +pub use classic_mceliece_rust::CRYPTO_CIPHERTEXTBYTES; + +#[derive(Debug)] +pub struct PublicKey(Box<[u8; CRYPTO_PUBLICKEYBYTES]>); + +impl PublicKey { + pub fn into_vec(self) -> Vec<u8> { + (self.0 as Box<[u8]>).into_vec() + } +} + +pub struct SecretKey(Box<[u8; CRYPTO_SECRETKEYBYTES]>); + +impl fmt::Debug for SecretKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SecretKey").finish() + } +} + +pub async fn generate_keys() -> Result<(PublicKey, SecretKey), Error> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let gen_key = move || { + let mut rng = rand::thread_rng(); + let mut pubkey = Box::new([0u8; CRYPTO_PUBLICKEYBYTES]); + let mut secret = Box::new([0u8; CRYPTO_SECRETKEYBYTES]); + crypto_kem_keypair(&mut pubkey, &mut secret, &mut rng).map_err(|error| { + log::error!("KEM keypair generation failed: {error}"); + Error::KeyGenerationFailed + })?; + + Ok((PublicKey(pubkey), SecretKey(secret))) + }; + + std::thread::Builder::new() + .stack_size(STACK_SIZE) + .spawn(move || { + let _ = tx.send(gen_key()); + }) + .unwrap(); + + rx.await.unwrap() +} + +pub fn decapsulate( + secret: &SecretKey, + ciphertext: &[u8; CRYPTO_CIPHERTEXTBYTES], +) -> Result<PresharedKey, Error> { + let mut psk = [0u8; CRYPTO_BYTES]; + + crypto_kem_dec(&mut psk, ciphertext, &secret.0).map_err(|error| { + log::error!("KEM decapsulation failed: {error}"); + Error::DecapsulationError + })?; + + Ok(PresharedKey::from(psk)) +} diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs new file mode 100644 index 0000000000..fafb95affb --- /dev/null +++ b/talpid-tunnel-config-client/src/lib.rs @@ -0,0 +1,84 @@ +use std::{fmt, net::IpAddr}; +use talpid_types::net::wireguard::{PresharedKey, PrivateKey, PublicKey}; +use tonic::transport::Channel; + +mod kem; + +mod proto { + tonic::include_proto!("tunnel_config"); +} + +#[derive(Debug)] +pub enum Error { + GrpcConnectError(tonic::transport::Error), + GrpcError(tonic::Status), + KeyGenerationFailed, + DecapsulationError, + InvalidCiphertext, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Error::*; + match self { + GrpcConnectError(_) => "Failed to connect to config service".fmt(f), + GrpcError(status) => write!(f, "RPC failed: {}", status), + KeyGenerationFailed => "Failed to generate KEM key pair".fmt(f), + DecapsulationError => "Failed to decapsulate secret".fmt(f), + InvalidCiphertext => "The service returned an invalid ciphertext".fmt(f), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::GrpcConnectError(error) => Some(error), + _ => None, + } + } +} + +type RelayConfigService = proto::post_quantum_secure_client::PostQuantumSecureClient<Channel>; + +const CONFIG_SERVICE_PORT: u16 = 1337; +const ALGORITHM_NAME: &str = "Classic-McEliece-8192128f"; + +/// Generates a new WireGuard key pair and negotiates a PSK with the relay in a PQ-safe +/// manner. This creates a peer on the relay with the new WireGuard pubkey and PSK, +/// which can then be used to establish a PQ-safe tunnel to the relay. +// TODO: consider binding to the tunnel interface here, on non-windows platforms +pub async fn push_pq_key( + service_address: IpAddr, + wg_pubkey: PublicKey, +) -> Result<(PrivateKey, PresharedKey), Error> { + let wg_psk_privkey = PrivateKey::new_from_random(); + let (kem_pubkey, kem_secret) = kem::generate_keys().await?; + + let mut client = new_client(service_address).await?; + let response = client + .psk_exchange_experimental_v0(proto::PskRequestExperimentalV0 { + wg_pubkey: wg_pubkey.as_bytes().to_vec(), + wg_psk_pubkey: wg_psk_privkey.public_key().as_bytes().to_vec(), + kem_pubkey: Some(proto::KemPubkeyExperimentalV0 { + algorithm_name: ALGORITHM_NAME.to_string(), + key_data: kem_pubkey.into_vec(), + }), + }) + .await + .map_err(Error::GrpcError)?; + + let ciphertext: [u8; kem::CRYPTO_CIPHERTEXTBYTES] = response + .into_inner() + .ciphertext + .try_into() + .map_err(|_| Error::InvalidCiphertext)?; + + Ok((wg_psk_privkey, kem::decapsulate(&kem_secret, &ciphertext)?)) +} + +async fn new_client(addr: IpAddr) -> Result<RelayConfigService, Error> { + RelayConfigService::connect(format!("tcp://{addr}:{CONFIG_SERVICE_PORT}")) + .await + .map_err(Error::GrpcConnectError) +} diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index 36bfba4f9d..a35c227ddd 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -29,6 +29,7 @@ impl TunnelParameters { match self { TunnelParameters::OpenVpn(params) => TunnelEndpoint { tunnel_type: TunnelType::OpenVpn, + quantum_resistant: false, endpoint: params.config.endpoint, proxy: params.proxy.as_ref().map(|proxy| proxy.get_endpoint()), obfuscation: None, @@ -36,6 +37,7 @@ impl TunnelParameters { }, TunnelParameters::Wireguard(params) => TunnelEndpoint { tunnel_type: TunnelType::Wireguard, + quantum_resistant: params.options.use_pq_safe_psk, endpoint: params .connection .get_exit_endpoint() @@ -134,6 +136,8 @@ pub struct TunnelEndpoint { #[cfg_attr(target_os = "android", jnix(skip))] pub tunnel_type: TunnelType, #[cfg_attr(target_os = "android", jnix(skip))] + pub quantum_resistant: bool, + #[cfg_attr(target_os = "android", jnix(skip))] pub proxy: Option<proxy::ProxyEndpoint>, #[cfg_attr(target_os = "android", jnix(skip))] pub obfuscation: Option<ObfuscationEndpoint>, @@ -143,7 +147,11 @@ pub struct TunnelEndpoint { impl fmt::Display for TunnelEndpoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(f, "{} - {}", self.tunnel_type, self.endpoint)?; + write!(f, "{} ", self.tunnel_type)?; + if self.quantum_resistant { + write!(f, "(quantum resistant) ")?; + } + write!(f, "- {}", self.endpoint)?; match self.tunnel_type { TunnelType::OpenVpn => { if let Some(ref proxy) = self.proxy { @@ -275,6 +283,52 @@ impl fmt::Display for AllowedEndpoint { } } +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum AllowedTunnelTraffic { + None, + All, + Only(SocketAddr, Protocol), +} + +impl fmt::Display for AllowedTunnelTraffic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match *self { + AllowedTunnelTraffic::None => "None".fmt(f), + AllowedTunnelTraffic::All => "All".fmt(f), + AllowedTunnelTraffic::Only(addr, proto) => write!(f, "{}/{}", addr, proto), + } + } +} + +/// A protocol: UDP, TCP, or ICMP. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub enum Protocol { + Udp, + Tcp, + IcmpV4, + IcmpV6, +} + +impl fmt::Display for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + Protocol::Udp => "UDP".fmt(f), + Protocol::Tcp => "TCP".fmt(f), + Protocol::IcmpV4 => "ICMPv4".fmt(f), + Protocol::IcmpV6 => "ICMPv6".fmt(f), + } + } +} + +impl From<TransportProtocol> for Protocol { + fn from(proto: TransportProtocol) -> Self { + match proto { + TransportProtocol::Udp => Protocol::Udp, + TransportProtocol::Tcp => Protocol::Tcp, + } + } +} + /// IP protocol version. #[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index 199ff3bd29..ff7ebef090 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -55,6 +55,8 @@ pub struct PeerConfig { pub allowed_ips: Vec<IpNetwork>, /// IP address of the WireGuard server. pub endpoint: SocketAddr, + /// Preshared key. + pub psk: Option<PresharedKey>, } #[derive(Clone, Eq, PartialEq, Deserialize, Serialize, Debug)] @@ -66,6 +68,7 @@ pub struct TunnelConfig { /// Options in [`TunnelParameters`] that apply to any WireGuard connection. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(default)] #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr( target_os = "android", @@ -78,6 +81,8 @@ pub struct TunnelOptions { jnix(map = "|maybe_mtu| maybe_mtu.map(|mtu| mtu as i32)") )] pub mtu: Option<u16>, + /// Obtain a PSK using the relay config client. + pub use_pq_safe_psk: bool, /// Temporary switch for wireguard-nt #[cfg(windows)] #[serde(default = "default_wgnt_setting")] @@ -94,6 +99,7 @@ impl Default for TunnelOptions { fn default() -> Self { Self { mtu: None, + use_pq_safe_psk: false, #[cfg(windows)] use_wireguard_nt: default_wgnt_setting(), } @@ -253,6 +259,40 @@ impl fmt::Display for PublicKey { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct PresharedKey([u8; 32]); + +impl PresharedKey { + /// Get the PSK as bytes + pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } +} + +impl From<[u8; 32]> for PresharedKey { + fn from(key: [u8; 32]) -> PresharedKey { + PresharedKey(key) + } +} + +impl Serialize for PresharedKey { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serialize_key(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for PresharedKey { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + deserialize_key(deserializer) + } +} + fn serialize_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index fe6c227e7f..b72cefe2d3 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -185,7 +185,8 @@ bool FwContext::applyPolicyConnecting const WinFwEndpoint &relay, const std::wstring &relayClient, const std::optional<std::wstring> &tunnelInterfaceAlias, - const std::optional<WinFwAllowedEndpoint> &allowedEndpoint + const std::optional<WinFwAllowedEndpoint> &allowedEndpoint, + const WinFwAllowedTunnelTraffic &allowedTunnelTraffic ) { Ruleset ruleset; @@ -201,13 +202,39 @@ bool FwContext::applyPolicyConnecting if (tunnelInterfaceAlias.has_value()) { - ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>( - *tunnelInterfaceAlias - )); - - ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>( - *tunnelInterfaceAlias - )); + switch (allowedTunnelTraffic.type) + { + case WinFwAllowedTunnelTrafficType::All: + { + ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>( + *tunnelInterfaceAlias, + std::nullopt + )); + ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>( + *tunnelInterfaceAlias, + std::nullopt + )); + break; + } + case WinFwAllowedTunnelTrafficType::Only: + { + const auto onlyEndpoint = std::make_optional(baseline::PermitVpnTunnel::Endpoint{ + wfp::IpAddress(allowedTunnelTraffic.endpoint->ip), + allowedTunnelTraffic.endpoint->port, + allowedTunnelTraffic.endpoint->protocol + }); + ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>( + *tunnelInterfaceAlias, + onlyEndpoint + )); + ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>( + *tunnelInterfaceAlias, + onlyEndpoint + )); + break; + } + // For the "None" case, do nothing. + } } const auto status = applyRuleset(ruleset); @@ -250,11 +277,13 @@ bool FwContext::applyPolicyConnected } ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>( - tunnelInterfaceAlias + tunnelInterfaceAlias, + std::nullopt )); ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>( - tunnelInterfaceAlias + tunnelInterfaceAlias, + std::nullopt )); const auto status = applyRuleset(ruleset); diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index bf67565993..5fc23f09a7 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -30,7 +30,8 @@ public: const WinFwEndpoint &relay, const std::wstring &relayClient, const std::optional<std::wstring> &tunnelInterfaceAlias, - const std::optional<WinFwAllowedEndpoint> &allowedEndpoint + const std::optional<WinFwAllowedEndpoint> &allowedEndpoint, + const WinFwAllowedTunnelTraffic &allowedTunnelTraffic ); bool applyPolicyConnected diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp index 09d8937535..c1c74ba6ba 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp +++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "permitendpoint.h" #include <winfw/mullvadguids.h> +#include <winfw/rules/shared.h> #include <libwfp/filterbuilder.h> #include <libwfp/conditionbuilder.h> #include <libwfp/conditions/conditionprotocol.h> @@ -30,19 +31,6 @@ const GUID &OutboundLayerFromIp(const wfp::IpAddress &ip) }; } -std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol) -{ - switch (protocol) - { - case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); - case WinFwProtocol::Udp: return ConditionProtocol::Udp(); - default: - { - THROW_ERROR("Missing case handler in switch clause"); - } - }; -} - } // anonymous namespace PermitEndpoint::PermitEndpoint diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp index e756d68464..d9a1af0f28 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp @@ -1,17 +1,26 @@ #include "stdafx.h" #include "permitvpntunnel.h" #include <winfw/mullvadguids.h> +#include <winfw/rules/shared.h> #include <libwfp/filterbuilder.h> #include <libwfp/conditionbuilder.h> #include <libwfp/conditions/conditioninterface.h> +#include <libwfp/conditions/conditionip.h> +#include <libwfp/conditions/conditionport.h> +#include <libwfp/conditions/conditionprotocol.h> +#include <libcommon/error.h> using namespace wfp::conditions; namespace rules::baseline { -PermitVpnTunnel::PermitVpnTunnel(const std::wstring &tunnelInterfaceAlias) +PermitVpnTunnel::PermitVpnTunnel( + const std::wstring &tunnelInterfaceAlias, + const std::optional<Endpoint> &onlyEndpoint +) : m_tunnelInterfaceAlias(tunnelInterfaceAlias) + , m_tunnelOnlyEndpoint(onlyEndpoint) { } @@ -19,6 +28,9 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller) { wfp::FilterBuilder filterBuilder; + bool includeV4 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv4; + bool includeV6 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv6; + // // #1 Permit outbound connections, IPv4. // @@ -33,11 +45,22 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller) .weight(wfp::FilterBuilder::WeightClass::Medium) .permit(); + if (includeV4) { wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4); conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); + if (m_tunnelOnlyEndpoint.has_value()) + { + conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); + if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) + { + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); + } + conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); + } + if (!objectInstaller.addFilter(filterBuilder, conditionBuilder)) { return false; @@ -53,11 +76,26 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller) .name(L"Permit outbound connections on tunnel interface (IPv6)") .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6); - wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6); + if (includeV6) + { + wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6); - conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); + conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); + + if (m_tunnelOnlyEndpoint.has_value()) + { + conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); + if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) + { + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); + } + conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); + } - return objectInstaller.addFilter(filterBuilder, conditionBuilder); + return true; } } diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h index 9c9a7b14c1..ee030a9f43 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h @@ -1,7 +1,10 @@ #pragma once #include <winfw/rules/ifirewallrule.h> +#include <winfw/winfw.h> +#include <libwfp/ipaddress.h> #include <string> +#include <optional> namespace rules::baseline { @@ -10,13 +13,23 @@ class PermitVpnTunnel : public IFirewallRule { public: - PermitVpnTunnel(const std::wstring &tunnelInterfaceAlias); + struct Endpoint { + wfp::IpAddress ip; + uint16_t port; + WinFwProtocol protocol; + }; + + PermitVpnTunnel( + const std::wstring &tunnelInterfaceAlias, + const std::optional<Endpoint> &onlyEndpoint + ); bool apply(IObjectInstaller &objectInstaller) override; private: const std::wstring m_tunnelInterfaceAlias; + const std::optional<Endpoint> m_tunnelOnlyEndpoint; }; } diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp index 00fbc8e76b..42214b6a77 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp @@ -1,17 +1,26 @@ #include "stdafx.h" #include "permitvpntunnelservice.h" #include <winfw/mullvadguids.h> +#include <winfw/rules/shared.h> #include <libwfp/filterbuilder.h> #include <libwfp/conditionbuilder.h> #include <libwfp/conditions/conditioninterface.h> +#include <libwfp/conditions/conditionip.h> +#include <libwfp/conditions/conditionport.h> +#include <libwfp/conditions/conditionprotocol.h> +#include <libcommon/error.h> using namespace wfp::conditions; namespace rules::baseline { -PermitVpnTunnelService::PermitVpnTunnelService(const std::wstring &tunnelInterfaceAlias) +PermitVpnTunnelService::PermitVpnTunnelService( + const std::wstring &tunnelInterfaceAlias, + const std::optional<PermitVpnTunnel::Endpoint> &onlyEndpoint +) : m_tunnelInterfaceAlias(tunnelInterfaceAlias) + , m_tunnelOnlyEndpoint(onlyEndpoint) { } @@ -19,6 +28,9 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller) { wfp::FilterBuilder filterBuilder; + bool includeV4 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv4; + bool includeV6 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv6; + // // #1 Permit inbound connections, IPv4. // @@ -35,26 +47,54 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller) wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4); - conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); - - if (!objectInstaller.addFilter(filterBuilder, conditionBuilder)) + if (includeV4) { - return false; + conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); + + if (m_tunnelOnlyEndpoint.has_value()) + { + conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); + if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) + { + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); + } + conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); + } + + if (!objectInstaller.addFilter(filterBuilder, conditionBuilder)) + { + return false; + } } // // #2 Permit inbound connections, IPv6. // - filterBuilder - .key(MullvadGuids::Filter_Baseline_PermitVpnTunnelService_Ipv6()) - .name(L"Permit inbound connections on tunnel interface (IPv6)") - .layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6); + if (includeV6) + { + filterBuilder + .key(MullvadGuids::Filter_Baseline_PermitVpnTunnelService_Ipv6()) + .name(L"Permit inbound connections on tunnel interface (IPv6)") + .layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6); + + conditionBuilder.reset(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6); + conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); - conditionBuilder.reset(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6); - conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias)); + if (m_tunnelOnlyEndpoint.has_value()) + { + conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip)); + if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol)) + { + conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port)); + } + conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol)); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); + } - return objectInstaller.addFilter(filterBuilder, conditionBuilder); + return true; } } diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h index 8880c06328..8011659b97 100644 --- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h +++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h @@ -1,7 +1,11 @@ #pragma once #include <winfw/rules/ifirewallrule.h> +#include <winfw/rules/baseline/permitvpntunnel.h> +#include <winfw/winfw.h> +#include <libwfp/ipaddress.h> #include <string> +#include <optional> namespace rules::baseline { @@ -10,13 +14,17 @@ class PermitVpnTunnelService : public IFirewallRule { public: - PermitVpnTunnelService(const std::wstring &tunnelInterfaceAlias); + PermitVpnTunnelService( + const std::wstring &tunnelInterfaceAlias, + const std::optional<PermitVpnTunnel::Endpoint> &onlyEndpoint + ); bool apply(IObjectInstaller &objectInstaller) override; private: const std::wstring m_tunnelInterfaceAlias; + const std::optional<PermitVpnTunnel::Endpoint> m_tunnelOnlyEndpoint; }; } diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp index a403230df9..3c913cab14 100644 --- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp +++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp @@ -2,6 +2,7 @@ #include "permitvpnrelay.h" #include <winfw/mullvadguids.h> #include <winfw/winfw.h> +#include <winfw/rules/shared.h> #include <libwfp/filterbuilder.h> #include <libwfp/conditionbuilder.h> #include <libwfp/conditions/conditionprotocol.h> @@ -31,19 +32,6 @@ const GUID &LayerFromIp(const wfp::IpAddress &ip) }; } -std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol) -{ - switch (protocol) - { - case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); - case WinFwProtocol::Udp: return ConditionProtocol::Udp(); - default: - { - THROW_ERROR("Missing case handler in switch clause"); - } - }; -} - const GUID &TranslateSublayer(PermitVpnRelay::Sublayer sublayer) { switch (sublayer) diff --git a/windows/winfw/src/winfw/rules/shared.cpp b/windows/winfw/src/winfw/rules/shared.cpp index 66cbbdfc83..1d1123e3eb 100644 --- a/windows/winfw/src/winfw/rules/shared.cpp +++ b/windows/winfw/src/winfw/rules/shared.cpp @@ -2,6 +2,8 @@ #include "shared.h" #include <libcommon/error.h> +using namespace wfp::conditions; + namespace rules { @@ -37,4 +39,36 @@ void SplitAddresses(const IpSet &in, IpSet &outIpv4, IpSet &outIpv6) } } +std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol) +{ + switch (protocol) + { + case WinFwProtocol::Tcp: return ConditionProtocol::Tcp(); + case WinFwProtocol::Udp: return ConditionProtocol::Udp(); + case WinFwProtocol::Icmp: return ConditionProtocol::Icmp(); + case WinFwProtocol::IcmpV6: return ConditionProtocol::IcmpV6(); + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + +bool ProtocolHasPort(WinFwProtocol protocol) +{ + switch (protocol) + { + case WinFwProtocol::Tcp: + case WinFwProtocol::Udp: + return true; + case WinFwProtocol::Icmp: + case WinFwProtocol::IcmpV6: + return false; + default: + { + THROW_ERROR("Missing case handler in switch clause"); + } + }; +} + } diff --git a/windows/winfw/src/winfw/rules/shared.h b/windows/winfw/src/winfw/rules/shared.h index 1b08d3ed02..4f4da187ca 100644 --- a/windows/winfw/src/winfw/rules/shared.h +++ b/windows/winfw/src/winfw/rules/shared.h @@ -1,6 +1,9 @@ #pragma once #include <vector> +#include <memory> +#include <winfw/winfw.h> +#include <libwfp/conditions/conditionprotocol.h> #include <libwfp/ipaddress.h> namespace rules @@ -10,4 +13,8 @@ using IpSet = std::vector<wfp::IpAddress>; void SplitAddresses(const IpSet &in, IpSet &outIpv4, IpSet &outIpv6); +std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol); + +bool ProtocolHasPort(WinFwProtocol protocol); + } diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index ae0f0791de..4110dcd2f8 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -233,7 +233,8 @@ WinFw_ApplyPolicyConnecting( const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, - const WinFwAllowedEndpoint *allowedEndpoint + const WinFwAllowedEndpoint *allowedEndpoint, + const WinFwAllowedTunnelTraffic *allowedTunnelTraffic ) { if (nullptr == g_fwContext) @@ -258,12 +259,18 @@ WinFw_ApplyPolicyConnecting( THROW_ERROR("Invalid argument: relayClient"); } + if (nullptr == allowedTunnelTraffic) + { + THROW_ERROR("Invalid argument: allowedTunnelTraffic"); + } + return g_fwContext->applyPolicyConnecting( *settings, *relay, relayClient, tunnelInterfaceAlias != nullptr ? std::make_optional(tunnelInterfaceAlias) : std::nullopt, - MakeOptional(allowedEndpoint) + MakeOptional(allowedEndpoint), + *allowedTunnelTraffic ) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE; } catch (common::error::WindowsException &err) diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index 4913ae7baf..6394893d91 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -32,7 +32,9 @@ WinFwSettings; enum WinFwProtocol : uint8_t { Tcp = 0, - Udp = 1 + Udp = 1, + Icmp = 2, + IcmpV6 = 3 }; typedef struct tag_WinFwEndpoint @@ -55,6 +57,20 @@ typedef struct tag_WinFwAllowedEndpoint } WinFwAllowedEndpoint; +enum WinFwAllowedTunnelTrafficType : uint8_t +{ + None, + All, + Only +}; + +typedef struct tag_WinFwAllowedTunnelTraffic +{ + WinFwAllowedTunnelTrafficType type; + WinFwEndpoint *endpoint; +} +WinFwAllowedTunnelTraffic; + /////////////////////////////////////////////////////////////////////////////// // Functions /////////////////////////////////////////////////////////////////////////////// @@ -139,7 +155,7 @@ enum WINFW_POLICY_STATUS // Apply restrictions in the firewall that block all traffic, except: // - What is specified by settings // - Communication with the relay server -// - Non-DNS traffic inside the VPN tunnel +// - Specified in-tunnel traffic, except DNS. // extern "C" WINFW_LINKAGE @@ -150,7 +166,8 @@ WinFw_ApplyPolicyConnecting( const WinFwEndpoint *relay, const wchar_t *relayClient, const wchar_t *tunnelInterfaceAlias, - const WinFwAllowedEndpoint *allowedEndpoint + const WinFwAllowedEndpoint *allowedEndpoint, + const WinFwAllowedTunnelTraffic *allowedTunnelTraffic ); // diff --git a/wireguard/libwg/libwg.go b/wireguard/libwg/libwg.go index 82c6e8205f..e26ea7b7da 100644 --- a/wireguard/libwg/libwg.go +++ b/wireguard/libwg/libwg.go @@ -13,6 +13,7 @@ import ( "bufio" "bytes" "runtime" + "strings" "unsafe" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer" @@ -59,6 +60,27 @@ func wgGetConfig(tunnelHandle int32) *C.char { return C.CString(settings.String()) } +//export wgSetConfig +func wgSetConfig(tunnelHandle int32, cSettings *C.char) int32 { + tunnel, err := tunnels.Get(tunnelHandle) + if err != nil { + return ERROR_GENERAL_FAILURE + } + if cSettings == nil { + tunnel.Logger.Errorf("cSettings is null\n") + return ERROR_GENERAL_FAILURE + } + settings := C.GoString(cSettings) + + setError := tunnel.Device.IpcSetOperation(bufio.NewReader(strings.NewReader(settings))) + if setError != nil { + tunnel.Logger.Errorf("Failed to set device configuration\n") + tunnel.Logger.Errorf("%s\n", setError) + return ERROR_GENERAL_FAILURE + } + return 0 +} + //export wgFreePtr func wgFreePtr(ptr unsafe.Pointer) { C.free(ptr) |
