summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-06-14 13:05:24 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-06-14 13:05:24 +0200
commitc3bdb0ebd3e99c22051656219f2db5a715da0a0d (patch)
treea83f10937eccad7e8a001cfc46357725c15a65da
parenta38dce1ce10893f8e1a077c95ca4afe085bcaeea (diff)
parent7f49fb807b4e49113dc078302fa1aabdbd0a931e (diff)
downloadmullvadvpn-c3bdb0ebd3e99c22051656219f2db5a715da0a0d.tar.xz
mullvadvpn-c3bdb0ebd3e99c22051656219f2db5a715da0a0d.zip
Merge branch 'add-pq-safe-tunnels'
-rw-r--r--CHANGELOG.md1
-rw-r--r--Cargo.lock95
-rw-r--r--Cargo.toml6
-rw-r--r--gui/locales/messages.pot6
-rw-r--r--gui/src/renderer/components/SecuredLabel.tsx10
-rw-r--r--gui/src/renderer/components/TunnelControl.tsx22
-rw-r--r--gui/src/shared/daemon-rpc-types.ts1
-rw-r--r--mullvad-cli/src/cmds/relay.rs19
-rw-r--r--mullvad-cli/src/cmds/tunnel.rs56
-rw-r--r--mullvad-cli/src/format.rs11
-rw-r--r--mullvad-daemon/src/lib.rs41
-rw-r--r--mullvad-daemon/src/management_interface.rs11
-rw-r--r--mullvad-daemon/src/settings.rs16
-rw-r--r--mullvad-management-interface/proto/management_interface.proto9
-rw-r--r--mullvad-management-interface/src/types.rs4
-rw-r--r--mullvad-relay-selector/src/matcher.rs1
-rw-r--r--mullvad-types/src/wireguard.rs3
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/firewall/linux.rs60
-rw-r--r--talpid-core/src/firewall/macos.rs46
-rw-r--r--talpid-core/src/firewall/mod.rs8
-rw-r--r--talpid-core/src/firewall/windows.rs70
-rw-r--r--talpid-core/src/tunnel/mod.rs16
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs7
-rw-r--r--talpid-core/src/tunnel/wireguard/config.rs4
-rw-r--r--talpid-core/src/tunnel/wireguard/connectivity_check.rs11
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs199
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs20
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs20
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs21
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs8
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_nt.rs37
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs23
-rw-r--r--talpid-core/src/winnet.rs2
-rw-r--r--talpid-tunnel-config-client/Cargo.toml25
-rw-r--r--talpid-tunnel-config-client/build.rs5
-rw-r--r--talpid-tunnel-config-client/examples/psk-exchange.rs25
-rw-r--r--talpid-tunnel-config-client/proto/tunnel_config.proto38
-rw-r--r--talpid-tunnel-config-client/src/kem.rs67
-rw-r--r--talpid-tunnel-config-client/src/lib.rs84
-rw-r--r--talpid-types/src/net/mod.rs56
-rw-r--r--talpid-types/src/net/wireguard.rs40
-rw-r--r--windows/winfw/src/winfw/fwcontext.cpp49
-rw-r--r--windows/winfw/src/winfw/fwcontext.h3
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp14
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp46
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h15
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp64
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h10
-rw-r--r--windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp14
-rw-r--r--windows/winfw/src/winfw/rules/shared.cpp34
-rw-r--r--windows/winfw/src/winfw/rules/shared.h7
-rw-r--r--windows/winfw/src/winfw/winfw.cpp11
-rw-r--r--windows/winfw/src/winfw/winfw.h23
-rw-r--r--wireguard/libwg/libwg.go22
55 files changed, 1367 insertions, 150 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a985a3d4a6..e49f5da290 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ Line wrap the file at 100 chars. Th
## [Unreleased]
### Added
- Add option to filter relays by ownership in the desktop apps.
+- Experimental: Add support for quantum-resistant PSK exchange to the CLI.
#### Linux
- Automatically attempt to detect and set the correct MTU for Wireguard tunnels.
diff --git a/Cargo.lock b/Cargo.lock
index 7d2f5f5582..5ef4205b00 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -338,6 +338,15 @@ dependencies = [
]
[[package]]
+name = "classic-mceliece-rust"
+version = "1.0.1"
+source = "git+https://github.com/mullvad/classic-mceliece-rust?rev=5130d9e3bfbf54735177e15636a643366c250b78#5130d9e3bfbf54735177e15636a643366c250b78"
+dependencies = [
+ "rand 0.8.4",
+ "sha3",
+]
+
+[[package]]
name = "colored"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1406,6 +1415,12 @@ dependencies = [
]
[[package]]
+name = "keccak"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9b7d56ba4a8344d6be9729995e6b06f928af29998cdf79fe390cbf6b1fee838"
+
+[[package]]
name = "kernel32-sys"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1779,8 +1794,8 @@ dependencies = [
"mullvad-types",
"nix 0.23.1",
"parity-tokio-ipc",
- "prost",
- "prost-types",
+ "prost 0.8.0",
+ "prost-types 0.8.0",
"talpid-types",
"tokio",
"tonic",
@@ -2468,7 +2483,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020"
dependencies = [
"bytes",
- "prost-derive",
+ "prost-derive 0.8.0",
+]
+
+[[package]]
+name = "prost"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001"
+dependencies = [
+ "bytes",
+ "prost-derive 0.9.0",
]
[[package]]
@@ -2483,8 +2508,8 @@ dependencies = [
"log",
"multimap",
"petgraph",
- "prost",
- "prost-types",
+ "prost 0.8.0",
+ "prost-types 0.8.0",
"tempfile",
"which",
]
@@ -2503,13 +2528,36 @@ dependencies = [
]
[[package]]
+name = "prost-derive"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe"
+dependencies = [
+ "anyhow",
+ "itertools",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "prost-types"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b"
dependencies = [
"bytes",
- "prost",
+ "prost 0.8.0",
+]
+
+[[package]]
+name = "prost-types"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a"
+dependencies = [
+ "bytes",
+ "prost 0.9.0",
]
[[package]]
@@ -2995,6 +3043,16 @@ dependencies = [
]
[[package]]
+name = "sha3"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "31f935e31cf406e8c0e96c2815a5516181b7004ae8c5f296293221e9b1e356bd"
+dependencies = [
+ "digest 0.10.1",
+ "keccak",
+]
+
+[[package]]
name = "shadowsocks"
version = "1.14.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3267,7 +3325,7 @@ dependencies = [
"parity-tokio-ipc",
"parking_lot 0.11.2",
"pfctl",
- "prost",
+ "prost 0.8.0",
"quickcheck",
"quickcheck_macros",
"rand 0.7.3",
@@ -3282,6 +3340,7 @@ dependencies = [
"talpid-dbus",
"talpid-platform-metadata",
"talpid-time",
+ "talpid-tunnel-config-client",
"talpid-types",
"tempfile",
"tokio",
@@ -3321,7 +3380,7 @@ dependencies = [
"log",
"openvpn-plugin",
"parity-tokio-ipc",
- "prost",
+ "prost 0.8.0",
"talpid-types",
"tokio",
"tonic",
@@ -3349,6 +3408,22 @@ dependencies = [
]
[[package]]
+name = "talpid-tunnel-config-client"
+version = "0.1.0"
+dependencies = [
+ "classic-mceliece-rust",
+ "log",
+ "prost 0.8.0",
+ "prost-types 0.9.0",
+ "rand 0.8.4",
+ "talpid-types",
+ "tokio",
+ "tonic",
+ "tonic-build",
+ "tower",
+]
+
+[[package]]
name = "talpid-types"
version = "0.1.0"
dependencies = [
@@ -3566,8 +3641,8 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
- "prost",
- "prost-derive",
+ "prost 0.8.0",
+ "prost-derive 0.8.0",
"tokio",
"tokio-stream",
"tokio-util",
diff --git a/Cargo.toml b/Cargo.toml
index 75bbc7d0ae..209384de44 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ members = [
"talpid-dbus",
"talpid-platform-metadata",
"talpid-time",
+ "talpid-tunnel-config-client",
"mullvad-management-interface",
"tunnel-obfuscation",
]
@@ -24,3 +25,8 @@ members = [
[profile.release]
opt-level = 3
lto = true
+
+# Key generation may take over one minute without optimizations
+# enabled.
+[profile.dev.package."classic-mceliece-rust"]
+opt-level = 3
diff --git a/gui/locales/messages.pot b/gui/locales/messages.pot
index 9deffc1e8d..30997c9290 100644
--- a/gui/locales/messages.pot
+++ b/gui/locales/messages.pot
@@ -108,6 +108,9 @@ msgstr ""
msgid "Connecting"
msgstr ""
+msgid "CREATING QUANTUM SECURE CONNECTION"
+msgstr ""
+
msgid "CREATING SECURE CONNECTION"
msgstr ""
@@ -172,6 +175,9 @@ msgstr ""
msgid "Open URL"
msgstr ""
+msgid "QUANTUM SECURE CONNECTION"
+msgstr ""
+
msgid "Reconnect"
msgstr ""
diff --git a/gui/src/renderer/components/SecuredLabel.tsx b/gui/src/renderer/components/SecuredLabel.tsx
index cda3c5174d..f622c8a847 100644
--- a/gui/src/renderer/components/SecuredLabel.tsx
+++ b/gui/src/renderer/components/SecuredLabel.tsx
@@ -5,8 +5,10 @@ import { messages } from '../../shared/gettext';
export enum SecuredDisplayStyle {
secured,
+ securedPq,
blocked,
securing,
+ securingPq,
unsecured,
unsecuring,
failedToSecure,
@@ -14,8 +16,10 @@ export enum SecuredDisplayStyle {
const securedDisplayStyleColorMap = {
[SecuredDisplayStyle.securing]: colors.white,
+ [SecuredDisplayStyle.securingPq]: colors.white,
[SecuredDisplayStyle.unsecuring]: colors.white,
[SecuredDisplayStyle.secured]: colors.green,
+ [SecuredDisplayStyle.securedPq]: colors.green,
[SecuredDisplayStyle.blocked]: colors.white,
[SecuredDisplayStyle.unsecured]: colors.red,
[SecuredDisplayStyle.failedToSecure]: colors.red,
@@ -45,12 +49,18 @@ function getLabelText(displayStyle: SecuredDisplayStyle) {
case SecuredDisplayStyle.secured:
return messages.gettext('SECURE CONNECTION');
+ case SecuredDisplayStyle.securedPq:
+ return messages.gettext('QUANTUM SECURE CONNECTION');
+
case SecuredDisplayStyle.blocked:
return messages.gettext('BLOCKED CONNECTION');
case SecuredDisplayStyle.securing:
return messages.gettext('CREATING SECURE CONNECTION');
+ case SecuredDisplayStyle.securingPq:
+ return messages.gettext('CREATING QUANTUM SECURE CONNECTION');
+
case SecuredDisplayStyle.unsecured:
return messages.gettext('UNSECURED CONNECTION');
diff --git a/gui/src/renderer/components/TunnelControl.tsx b/gui/src/renderer/components/TunnelControl.tsx
index c7617bfdda..4d210ed2c2 100644
--- a/gui/src/renderer/components/TunnelControl.tsx
+++ b/gui/src/renderer/components/TunnelControl.tsx
@@ -77,6 +77,7 @@ const SelectedLocationChevron = styled(AppButton.Icon)({
export default class TunnelControl extends React.Component<ITunnelControlProps> {
public render() {
let state = this.props.tunnelState.state;
+ let pq = false;
switch (this.props.tunnelState.state) {
case 'disconnecting':
@@ -92,14 +93,23 @@ export default class TunnelControl extends React.Component<ITunnelControlProps>
break;
}
break;
+ case 'connecting':
+ if (this.props.tunnelState.details) {
+ pq = this.props.tunnelState.details.endpoint.quantumResistant;
+ }
+ break;
+ case 'connected':
+ pq = this.props.tunnelState.details.endpoint.quantumResistant;
+ break;
}
switch (state) {
- case 'connecting':
+ case 'connecting': {
+ const displayStyle = pq ? SecuredDisplayStyle.securingPq : SecuredDisplayStyle.securing;
return (
<Wrapper>
<Body>
- <Secured displayStyle={SecuredDisplayStyle.securing} />
+ <Secured displayStyle={displayStyle} />
<Location>
{this.renderCountry()}
{this.renderCity()}
@@ -112,11 +122,14 @@ export default class TunnelControl extends React.Component<ITunnelControlProps>
</Footer>
</Wrapper>
);
- case 'connected':
+ }
+
+ case 'connected': {
+ const displayStyle = pq ? SecuredDisplayStyle.securedPq : SecuredDisplayStyle.secured;
return (
<Wrapper>
<Body>
- <Secured displayStyle={SecuredDisplayStyle.secured} />
+ <Secured displayStyle={displayStyle} />
<Location>
{this.renderCountry()}
{this.renderCity()}
@@ -129,6 +142,7 @@ export default class TunnelControl extends React.Component<ITunnelControlProps>
</Footer>
</Wrapper>
);
+ }
case 'error':
if (
diff --git a/gui/src/shared/daemon-rpc-types.ts b/gui/src/shared/daemon-rpc-types.ts
index 45af1c605e..8bee8003b0 100644
--- a/gui/src/shared/daemon-rpc-types.ts
+++ b/gui/src/shared/daemon-rpc-types.ts
@@ -92,6 +92,7 @@ export interface ITunnelEndpoint {
address: string;
protocol: RelayProtocol;
tunnelType: TunnelType;
+ quantumResistant: boolean;
proxy?: IProxyEndpoint;
obfuscationEndpoint?: IObfuscationEndpoint;
entryEndpoint?: IEndpoint;
diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs
index c624c2a25b..ac7bd74c63 100644
--- a/mullvad-cli/src/cmds/relay.rs
+++ b/mullvad-cli/src/cmds/relay.rs
@@ -573,7 +573,24 @@ impl Relay {
}
if let Some(entry) = matches.values_of("entry location") {
wireguard_constraints.entry_location = parse_entry_location_constraint(entry);
- wireguard_constraints.use_multihop = wireguard_constraints.entry_location.is_some();
+ let use_multihop = wireguard_constraints.entry_location.is_some();
+ if use_multihop {
+ let use_pq = rpc
+ .get_settings(())
+ .await?
+ .into_inner()
+ .tunnel_options
+ .unwrap()
+ .wireguard
+ .unwrap()
+ .use_pq_safe_psk;
+ if use_pq {
+ return Err(Error::CommandFailed(
+ "PQ PSK exchange does not work when multihop is enabled",
+ ));
+ }
+ }
+ wireguard_constraints.use_multihop = use_multihop;
}
self.update_constraints(types::RelaySettingsUpdate {
diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs
index f01452a925..7856cc849c 100644
--- a/mullvad-cli/src/cmds/tunnel.rs
+++ b/mullvad-cli/src/cmds/tunnel.rs
@@ -37,6 +37,7 @@ fn create_wireguard_subcommand() -> clap::App<'static> {
.about("Manage options for Wireguard tunnels")
.setting(clap::AppSettings::SubcommandRequiredElseHelp)
.subcommand(create_wireguard_mtu_subcommand())
+ .subcommand(create_wireguard_quantum_resistant_tunnel_subcommand())
.subcommand(create_wireguard_keys_subcommand());
#[cfg(windows)]
{
@@ -57,6 +58,14 @@ fn create_wireguard_mtu_subcommand() -> clap::App<'static> {
.subcommand(clap::App::new("set").arg(clap::Arg::new("mtu").required(true)))
}
+fn create_wireguard_quantum_resistant_tunnel_subcommand() -> clap::App<'static> {
+ clap::App::new("quantum-resistant-tunnel")
+ .about("EXPERIMENTAL: Enables quantum-resistant PSK exchange in the tunnel")
+ .setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::App::new("get"))
+ .subcommand(clap::App::new("set").arg(clap::Arg::new("policy").required(true)))
+}
+
fn create_wireguard_keys_subcommand() -> clap::App<'static> {
clap::App::new("key")
.about("Manage your wireguard key")
@@ -163,6 +172,14 @@ impl Tunnel {
_ => unreachable!("unhandled command"),
},
+ Some(("quantum-resistant-tunnel", matches)) => match matches.subcommand() {
+ Some(("get", _)) => Self::process_wireguard_quantum_resistant_tunnel_get().await,
+ Some(("set", matches)) => {
+ Self::process_wireguard_quantum_resistant_tunnel_set(matches).await
+ }
+ _ => unreachable!("unhandled command"),
+ },
+
#[cfg(windows)]
Some(("use-wireguard-nt", matches)) => match matches.subcommand() {
Some(("get", _)) => Self::process_wireguard_use_wg_nt_get().await,
@@ -203,6 +220,45 @@ impl Tunnel {
Ok(())
}
+ async fn process_wireguard_quantum_resistant_tunnel_get() -> Result<()> {
+ let tunnel_options = Self::get_tunnel_options().await?;
+ if tunnel_options.wireguard.unwrap().use_pq_safe_psk {
+ println!("enabled");
+ } else {
+ println!("disabled");
+ }
+ Ok(())
+ }
+
+ async fn process_wireguard_quantum_resistant_tunnel_set(
+ matches: &clap::ArgMatches,
+ ) -> Result<()> {
+ let new_state = matches.value_of("policy").unwrap() == "on";
+ let mut rpc = new_rpc_client().await?;
+ let settings = rpc.get_settings(()).await?;
+ let multihop_is_enabled = settings
+ .into_inner()
+ .relay_settings
+ .unwrap()
+ .endpoint
+ .and_then(|endpoint| {
+ if let types::relay_settings::Endpoint::Normal(settings) = endpoint {
+ Some(settings.wireguard_constraints.unwrap().use_multihop)
+ } else {
+ None
+ }
+ })
+ .unwrap_or(false);
+ if multihop_is_enabled {
+ return Err(Error::CommandFailed(
+ "PQ PSK exchange does not work when multihop is enabled",
+ ));
+ }
+ rpc.set_quantum_resistant_tunnel(new_state).await?;
+ println!("Updated quantum resistant tunnel setting");
+ Ok(())
+ }
+
#[cfg(windows)]
async fn process_wireguard_use_wg_nt_get() -> Result<()> {
let tunnel_options = Self::get_tunnel_options().await?;
diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs
index 74630166d2..47c05611b2 100644
--- a/mullvad-cli/src/format.rs
+++ b/mullvad-cli/src/format.rs
@@ -109,6 +109,15 @@ fn format_relay_connection(relay_info: &TunnelStateRelayInfo, verbose: bool) ->
} else {
String::new()
};
+ let quantum_resistant = if verbose {
+ if endpoint.quantum_resistant {
+ "\nQuantum resistant tunnel: true".to_string()
+ } else {
+ "\nQuantum resistant tunnel: false".to_string()
+ }
+ } else {
+ String::new()
+ };
let mut bridge_type = String::new();
let mut obfuscator_type = String::new();
@@ -127,7 +136,7 @@ fn format_relay_connection(relay_info: &TunnelStateRelayInfo, verbose: bool) ->
}
format!(
- "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{bridge_type}{obfuscator_type}",
+ "{exit_endpoint}{first_hop}{bridge}{obfuscator}{tunnel_type}{quantum_resistant}{bridge_type}{obfuscator_type}",
first_hop = first_hop.unwrap_or_default(),
bridge = bridge.unwrap_or_default(),
obfuscator = obfuscator.unwrap_or_default(),
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index b75ad9121c..c013361925 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -220,6 +220,8 @@ pub enum DaemonCommand {
SetBridgeState(ResponseTx<(), settings::Error>, BridgeState),
/// Set if IPv6 should be enabled in the tunnel
SetEnableIpv6(ResponseTx<(), settings::Error>, bool),
+ /// Set whether to enable PQ PSK exchange in the tunnel
+ SetQuantumResistantTunnel(ResponseTx<(), settings::Error>, bool),
/// Set DNS options or servers to use
SetDnsOptions(ResponseTx<(), settings::Error>, DnsOptions),
/// Toggle macOS network check leak
@@ -979,6 +981,9 @@ where
}
SetBridgeState(tx, bridge_state) => self.on_set_bridge_state(tx, bridge_state).await,
SetEnableIpv6(tx, enable_ipv6) => self.on_set_enable_ipv6(tx, enable_ipv6).await,
+ SetQuantumResistantTunnel(tx, enable_pq) => {
+ self.on_set_quantum_resistant_tunnel(tx, enable_pq).await
+ }
SetDnsOptions(tx, dns_servers) => self.on_set_dns_options(tx, dns_servers).await,
SetWireguardMtu(tx, mtu) => self.on_set_wireguard_mtu(tx, mtu).await,
SetWireguardRotationInterval(tx, interval) => {
@@ -1053,7 +1058,7 @@ where
}
}
PrivateDeviceEvent::RotatedKey(_) => {
- if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
+ if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) {
self.schedule_reconnect(WG_RECONNECT_DELAY);
}
}
@@ -1683,7 +1688,7 @@ where
.set_tunnel_options(&self.settings.tunnel_options);
self.event_listener
.notify_settings(self.settings.to_settings());
- if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
+ if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
log::info!("Initiating tunnel restart");
self.reconnect_tunnel();
}
@@ -1834,7 +1839,7 @@ where
.set_tunnel_options(&self.settings.tunnel_options);
self.event_listener
.notify_settings(self.settings.to_settings());
- if let Some(TunnelType::OpenVpn) = self.get_connected_tunnel_type() {
+ if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) {
log::info!(
"Initiating tunnel restart because the OpenVPN mssfix setting changed"
);
@@ -1954,6 +1959,36 @@ where
}
}
+ async fn on_set_quantum_resistant_tunnel(
+ &mut self,
+ tx: ResponseTx<(), settings::Error>,
+ use_pq_safe_psk: bool,
+ ) {
+ let save_result = self
+ .settings
+ .set_quantum_resistant_tunnel(use_pq_safe_psk)
+ .await;
+ match save_result {
+ Ok(settings_changed) => {
+ Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response");
+ if settings_changed {
+ self.parameters_generator
+ .set_tunnel_options(&self.settings.tunnel_options);
+ self.event_listener
+ .notify_settings(self.settings.to_settings());
+ if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) {
+ log::info!("Reconnecting because the PQ safety setting changed");
+ self.reconnect_tunnel();
+ }
+ }
+ }
+ Err(e) => {
+ log::error!("{}", e.display_chain_with_msg("Unable to save settings"));
+ Self::oneshot_send(tx, Err(e), "set_quantum_resistant_tunnel response");
+ }
+ }
+ }
+
async fn on_set_dns_options(
&mut self,
tx: ResponseTx<(), settings::Error>,
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 2264159fc8..a7478eb048 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -373,6 +373,17 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_settings_error)
}
+ async fn set_quantum_resistant_tunnel(&self, request: Request<bool>) -> ServiceResult<()> {
+ let enable = request.into_inner();
+ log::debug!("set_quantum_resistant_tunnel({})", enable);
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::SetQuantumResistantTunnel(tx, enable))?;
+ self.wait_for_result(rx)
+ .await?
+ .map(Response::new)
+ .map_err(map_settings_error)
+ }
+
#[cfg(not(target_os = "android"))]
async fn set_dns_options(&self, request: Request<types::DnsOptions>) -> ServiceResult<()> {
let options = DnsOptions::try_from(request.into_inner()).map_err(map_protobuf_type_err)?;
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs
index 52800028cc..e69ca8316f 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings.rs
@@ -236,6 +236,22 @@ impl SettingsPersister {
self.update(should_save).await
}
+ pub async fn set_quantum_resistant_tunnel(
+ &mut self,
+ use_pq_safe_psk: bool,
+ ) -> Result<bool, Error> {
+ let should_save = Self::update_field(
+ &mut self
+ .settings
+ .tunnel_options
+ .wireguard
+ .options
+ .use_pq_safe_psk,
+ use_pq_safe_psk,
+ );
+ self.update(should_save).await
+ }
+
pub async fn set_dns_options(&mut self, options: DnsOptions) -> Result<bool, Error> {
let should_save =
Self::update_field(&mut self.settings.tunnel_options.dns_options, options);
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 20d5318a58..9148a10150 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -43,6 +43,7 @@ service ManagementService {
rpc SetOpenvpnMssfix(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {}
rpc SetWireguardMtu(google.protobuf.UInt32Value) returns (google.protobuf.Empty) {}
rpc SetEnableIpv6(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
+ rpc SetQuantumResistantTunnel(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
rpc SetDnsOptions(DnsOptions) returns (google.protobuf.Empty) {}
// Account management
@@ -193,9 +194,10 @@ message TunnelEndpoint {
string address = 1;
TransportProtocol protocol = 2;
TunnelType tunnel_type = 3;
- ProxyEndpoint proxy = 4;
- ObfuscationEndpoint obfuscation = 5;
- Endpoint entry_endpoint = 6;
+ bool quantum_resistant = 4;
+ ProxyEndpoint proxy = 5;
+ ObfuscationEndpoint obfuscation = 6;
+ Endpoint entry_endpoint = 7;
}
enum ObfuscationType {
@@ -435,6 +437,7 @@ message TunnelOptions {
uint32 mtu = 1;
google.protobuf.Duration rotation_interval = 2;
bool use_wireguard_nt = 3;
+ bool use_pq_safe_psk = 4;
}
message GenericOptions {
bool enable_ipv6 = 1;
diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs
index 01f16d8220..cc2086d66e 100644
--- a/mullvad-management-interface/src/types.rs
+++ b/mullvad-management-interface/src/types.rs
@@ -35,6 +35,7 @@ impl From<talpid_types::net::TunnelEndpoint> for TunnelEndpoint {
net::TunnelType::Wireguard => i32::from(TunnelType::Wireguard),
net::TunnelType::OpenVpn => i32::from(TunnelType::Openvpn),
},
+ quantum_resistant: endpoint.quantum_resistant,
proxy: endpoint.proxy.map(|proxy_ep| ProxyEndpoint {
address: proxy_ep.endpoint.address.to_string(),
protocol: i32::from(TransportProtocol::from(proxy_ep.endpoint.protocol)),
@@ -680,6 +681,7 @@ impl From<&mullvad_types::settings::TunnelOptions> for TunnelOptions {
use_wireguard_nt: options.wireguard.options.use_wireguard_nt,
#[cfg(not(windows))]
use_wireguard_nt: false,
+ use_pq_safe_psk: options.wireguard.options.use_pq_safe_psk,
}),
generic: Some(tunnel_options::GenericOptions {
enable_ipv6: options.generic.enable_ipv6,
@@ -1185,6 +1187,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig {
public_key,
allowed_ips,
endpoint,
+ psk: None,
},
exit_peer: None,
ipv4_gateway,
@@ -1412,6 +1415,7 @@ impl TryFrom<TunnelOptions> for mullvad_types::settings::TunnelOptions {
} else {
None
},
+ use_pq_safe_psk: wireguard_options.use_pq_safe_psk,
#[cfg(windows)]
use_wireguard_nt: wireguard_options.use_wireguard_nt,
},
diff --git a/mullvad-relay-selector/src/matcher.rs b/mullvad-relay-selector/src/matcher.rs
index 089510a6c0..13e16646ab 100644
--- a/mullvad-relay-selector/src/matcher.rs
+++ b/mullvad-relay-selector/src/matcher.rs
@@ -189,6 +189,7 @@ impl WireguardMatcher {
public_key: data.public_key,
endpoint: SocketAddr::new(host, port),
allowed_ips: all_of_the_internet(),
+ psk: None,
};
Some(MullvadEndpoint::Wireguard(MullvadWireguardEndpoint {
peer: peer_config,
diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs
index 4c05f1e552..da7f25828a 100644
--- a/mullvad-types/src/wireguard.rs
+++ b/mullvad-types/src/wireguard.rs
@@ -113,7 +113,8 @@ impl Default for RotationInterval {
}
}
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(default)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(
target_os = "android",
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index e422a537ca..b6ff02bd8c 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -26,6 +26,7 @@ regex = "1.1.0"
shell-escape = "0.1"
talpid-types = { path = "../talpid-types" }
talpid-time = { path = "../talpid-time" }
+talpid-tunnel-config-client = { path = "../talpid-tunnel-config-client" }
uuid = { version = "0.8", features = ["v4"] }
zeroize = "1"
chrono = "0.4.19"
diff --git a/talpid-core/src/firewall/linux.rs b/talpid-core/src/firewall/linux.rs
index 8928e94e3c..4d66e94b1a 100644
--- a/talpid-core/src/firewall/linux.rs
+++ b/talpid-core/src/firewall/linux.rs
@@ -12,9 +12,9 @@ use std::{
env,
ffi::{CStr, CString},
io,
- net::{IpAddr, Ipv4Addr},
+ net::{IpAddr, Ipv4Addr, SocketAddr},
};
-use talpid_types::net::{Endpoint, TransportProtocol};
+use talpid_types::net::{AllowedTunnelTraffic, Endpoint, Protocol, TransportProtocol};
/// Priority for rules that tag split tunneling packets. Equals NF_IP_PRI_MANGLE.
const MANGLE_CHAIN_PRIORITY: i32 = libc::NF_IP_PRI_MANGLE;
@@ -558,6 +558,7 @@ impl<'a> PolicyBatch<'a> {
tunnel,
allow_lan,
allowed_endpoint,
+ allowed_tunnel_traffic,
} => {
self.add_allow_tunnel_endpoint_rules(peer_endpoint);
self.add_allow_endpoint_rules(&allowed_endpoint.endpoint);
@@ -567,7 +568,19 @@ impl<'a> PolicyBatch<'a> {
self.add_drop_dns_rule();
if let Some(tunnel) = tunnel {
- self.add_allow_tunnel_rules(&tunnel.interface)?;
+ match allowed_tunnel_traffic {
+ AllowedTunnelTraffic::All => {
+ self.add_allow_tunnel_rules(&tunnel.interface)?;
+ }
+ AllowedTunnelTraffic::None => (),
+ AllowedTunnelTraffic::Only(address, protocol) => {
+ self.add_allow_in_tunnel_endpoint_rules(
+ &tunnel.interface,
+ *address,
+ *protocol,
+ )?;
+ }
+ }
if *allow_lan {
self.add_block_cve_2019_14899(tunnel);
}
@@ -771,6 +784,35 @@ impl<'a> PolicyBatch<'a> {
}
}
+ fn add_allow_in_tunnel_endpoint_rules(
+ &mut self,
+ tunnel_interface: &str,
+ address: SocketAddr,
+ protocol: Protocol,
+ ) -> Result<()> {
+ for (chain, dir, end) in [
+ (&self.out_chain, Direction::Out, End::Dst),
+ (&self.in_chain, Direction::In, End::Src),
+ ] {
+ let mut rule = Rule::new(chain);
+
+ check_iface(&mut rule, dir, tunnel_interface)?;
+ check_ip(&mut rule, end, address.ip());
+ match protocol {
+ Protocol::IcmpV4 | Protocol::IcmpV6 => check_l4proto(&mut rule, protocol),
+ Protocol::Tcp => {
+ check_port(&mut rule, TransportProtocol::Tcp, end, address.port());
+ }
+ Protocol::Udp => {
+ check_port(&mut rule, TransportProtocol::Udp, end, address.port());
+ }
+ }
+ add_verdict(&mut rule, &Verdict::Accept);
+ self.batch.add(&rule, nftnl::MsgType::Add);
+ }
+ Ok(())
+ }
+
fn add_allow_tunnel_rules(&mut self, tunnel_interface: &str) -> Result<()> {
self.batch.add(
&allow_interface_rule(&self.out_chain, Direction::Out, tunnel_interface)?,
@@ -988,7 +1030,7 @@ fn check_ip(rule: &mut Rule<'_>, end: End, ip: impl Into<IpAddr>) {
fn check_port(rule: &mut Rule<'_>, protocol: TransportProtocol, end: End, port: u16) {
// Must check transport layer protocol before loading transport layer payload
- check_l4proto(rule, protocol);
+ check_l4proto(rule, protocol.into());
rule.add_expr(&match (protocol, end) {
(TransportProtocol::Udp, End::Src) => nft_expr!(payload udp sport),
@@ -1011,15 +1053,17 @@ fn l3proto(addr: IpAddr) -> u8 {
}
}
-fn check_l4proto(rule: &mut Rule<'_>, protocol: TransportProtocol) {
+fn check_l4proto(rule: &mut Rule<'_>, protocol: Protocol) {
rule.add_expr(&nft_expr!(meta l4proto));
rule.add_expr(&nft_expr!(cmp == l4proto(protocol)));
}
-fn l4proto(protocol: TransportProtocol) -> u8 {
+fn l4proto(protocol: Protocol) -> u8 {
match protocol {
- TransportProtocol::Udp => libc::IPPROTO_UDP as u8,
- TransportProtocol::Tcp => libc::IPPROTO_TCP as u8,
+ Protocol::Udp => libc::IPPROTO_UDP as u8,
+ Protocol::Tcp => libc::IPPROTO_TCP as u8,
+ Protocol::IcmpV4 => libc::IPPROTO_ICMP as u8,
+ Protocol::IcmpV6 => libc::IPPROTO_ICMPV6 as u8,
}
}
diff --git a/talpid-core/src/firewall/macos.rs b/talpid-core/src/firewall/macos.rs
index 7c2e1455b2..fa40ae7e0d 100644
--- a/talpid-core/src/firewall/macos.rs
+++ b/talpid-core/src/firewall/macos.rs
@@ -6,7 +6,7 @@ use std::{
net::{IpAddr, Ipv4Addr},
};
use subslice::SubsliceExt;
-use talpid_types::net;
+use talpid_types::net::{self, AllowedTunnelTraffic};
pub use pfctl::Error;
@@ -119,6 +119,7 @@ impl Firewall {
tunnel,
allow_lan,
allowed_endpoint,
+ allowed_tunnel_traffic,
} => {
let mut rules = vec![self.get_allow_relay_rule(*peer_endpoint)?];
rules.push(self.get_allowed_endpoint_rule(allowed_endpoint.endpoint)?);
@@ -128,7 +129,10 @@ impl Firewall {
rules.append(&mut self.get_block_dns_rules()?);
if let Some(tunnel) = tunnel {
- rules.push(self.get_allow_tunnel_rule(&tunnel.interface)?);
+ rules.extend(
+ self.get_allow_tunnel_rule(&tunnel.interface, allowed_tunnel_traffic)?
+ .into_iter(),
+ );
}
if *allow_lan {
@@ -154,7 +158,13 @@ impl Firewall {
// can't leak to the wrong IPs in the tunnel or on the LAN.
rules.append(&mut self.get_block_dns_rules()?);
- rules.push(self.get_allow_tunnel_rule(tunnel.interface.as_str())?);
+ rules.extend(
+ self.get_allow_tunnel_rule(
+ tunnel.interface.as_str(),
+ &AllowedTunnelTraffic::All,
+ )?
+ .into_iter(),
+ );
if *allow_lan {
rules.append(&mut self.get_allow_lan_rules()?);
@@ -318,14 +328,34 @@ impl Firewall {
Ok(vec![block_tcp_dns_rule, block_udp_dns_rule])
}
- fn get_allow_tunnel_rule(&self, tunnel_interface: &str) -> Result<pfctl::FilterRule> {
- Ok(self
- .create_rule_builder(FilterRuleAction::Pass)
+ fn get_allow_tunnel_rule(
+ &self,
+ tunnel_interface: &str,
+ allowed_traffic: &AllowedTunnelTraffic,
+ ) -> Result<Option<pfctl::FilterRule>> {
+ let mut rule_builder = self.create_rule_builder(FilterRuleAction::Pass);
+ let mut base_rule = rule_builder
.quick(true)
.interface(tunnel_interface)
.keep_state(pfctl::StatePolicy::Keep)
- .tcp_flags(Self::get_tcp_flags())
- .build()?)
+ .tcp_flags(Self::get_tcp_flags());
+ match allowed_traffic {
+ AllowedTunnelTraffic::Only(addr, protocol) => {
+ use talpid_types::net::Protocol::*;
+ let pfctl_proto = match protocol {
+ Udp => pfctl::Proto::Udp,
+ Tcp => pfctl::Proto::Tcp,
+ IcmpV4 => pfctl::Proto::Icmp,
+ IcmpV6 => pfctl::Proto::IcmpV6,
+ };
+ base_rule = base_rule.to(*addr).proto(pfctl_proto);
+ }
+ AllowedTunnelTraffic::All => {}
+ AllowedTunnelTraffic::None => {
+ return Ok(None);
+ }
+ };
+ Ok(Some(base_rule.build()?))
}
fn get_allow_loopback_rules(&self) -> Result<Vec<pfctl::FilterRule>> {
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index b23e16b017..a72ad65b1b 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -8,7 +8,7 @@ use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr},
};
-use talpid_types::net::{AllowedEndpoint, Endpoint};
+use talpid_types::net::{AllowedEndpoint, AllowedTunnelTraffic, Endpoint};
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
@@ -109,6 +109,8 @@ pub enum FirewallPolicy {
allow_lan: bool,
/// Host that should be reachable while connecting.
allowed_endpoint: AllowedEndpoint,
+ /// Networks for which to permit in-tunnel traffic.
+ allowed_tunnel_traffic: AllowedTunnelTraffic,
/// A process that is allowed to send packets to the relay.
#[cfg(windows)]
relay_client: PathBuf,
@@ -151,12 +153,13 @@ impl fmt::Display for FirewallPolicy {
tunnel,
allow_lan,
allowed_endpoint,
+ allowed_tunnel_traffic,
..
} => {
if let Some(tunnel) = tunnel {
write!(
f,
- "Connecting to {} over \"{}\" (ip: {}, v4 gw: {}, v6 gw: {:?}), {} LAN. Allowing endpoint {}",
+ "Connecting to {} over \"{}\" (ip: {}, v4 gw: {}, v6 gw: {:?}, allowed in-tunnel traffic: {}), {} LAN. Allowing endpoint {}",
peer_endpoint,
tunnel.interface,
tunnel
@@ -167,6 +170,7 @@ impl fmt::Display for FirewallPolicy {
.join(","),
tunnel.ipv4_gateway,
tunnel.ipv6_gateway,
+ allowed_tunnel_traffic,
if *allow_lan { "Allowing" } else { "Blocking" },
allowed_endpoint,
)
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index dca86d2257..4683924d17 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -6,7 +6,7 @@ use self::winfw::*;
use super::{FirewallArguments, FirewallPolicy, InitialFirewallState};
use crate::winnet;
use talpid_types::{
- net::{AllowedEndpoint, Endpoint},
+ net::{AllowedEndpoint, AllowedTunnelTraffic, Endpoint},
tunnel::FirewallPolicyError,
};
use widestring::WideCString;
@@ -102,6 +102,7 @@ impl Firewall {
tunnel,
allow_lan,
allowed_endpoint,
+ allowed_tunnel_traffic,
relay_client,
} => {
let cfg = &WinFwSettings::new(allow_lan);
@@ -111,6 +112,7 @@ impl Firewall {
&cfg,
&tunnel,
&WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(),
+ &allowed_tunnel_traffic,
&relay_client,
)
}
@@ -148,6 +150,7 @@ impl Firewall {
winfw_settings: &WinFwSettings,
tunnel_metadata: &Option<TunnelMetadata>,
allowed_endpoint: &WinFwAllowedEndpoint<'_>,
+ allowed_tunnel_traffic: &AllowedTunnelTraffic,
relay_client: &Path,
) -> Result<(), Error> {
log::trace!("Applying 'connecting' firewall policy");
@@ -169,6 +172,26 @@ impl Firewall {
ptr::null()
};
+ let allowed_tun_ip;
+ let allowed_tunnel_endpoint =
+ if let AllowedTunnelTraffic::Only(addr, proto) = allowed_tunnel_traffic {
+ allowed_tun_ip = widestring_ip(addr.ip());
+ Some(WinFwEndpoint {
+ ip: allowed_tun_ip.as_ptr(),
+ port: addr.port(),
+ protocol: WinFwProt::from(*proto),
+ })
+ } else {
+ None
+ };
+ let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic {
+ type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic),
+ endpoint: allowed_tunnel_endpoint
+ .as_ref()
+ .map(|ep| ep as *const _)
+ .unwrap_or(ptr::null()),
+ };
+
unsafe {
WinFw_ApplyPolicyConnecting(
winfw_settings,
@@ -176,6 +199,7 @@ impl Firewall {
relay_client.as_ptr(),
interface_wstr_ptr,
allowed_endpoint,
+ &allowed_tunnel_traffic,
)
.into_result()
.map_err(Error::ApplyingConnectingPolicy)
@@ -276,10 +300,10 @@ fn widestring_ip(ip: IpAddr) -> WideCString {
#[allow(non_snake_case)]
mod winfw {
- use super::{widestring_ip, AllowedEndpoint, Error, WideCString};
+ use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString};
use crate::logging::windows::LogSink;
use libc;
- use talpid_types::net::TransportProtocol;
+ use talpid_types::net::{Protocol, TransportProtocol};
pub struct WinFwAllowedEndpointContainer {
_clients: Box<[WideCString]>,
@@ -338,6 +362,30 @@ mod winfw {
}
#[repr(C)]
+ pub struct WinFwAllowedTunnelTraffic {
+ pub type_: WinFwAllowedTunnelTrafficType,
+ pub endpoint: *const WinFwEndpoint,
+ }
+
+ #[repr(u8)]
+ #[derive(Clone, Copy)]
+ pub enum WinFwAllowedTunnelTrafficType {
+ None,
+ All,
+ Only,
+ }
+
+ impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType {
+ fn from(traffic: &AllowedTunnelTraffic) -> Self {
+ match traffic {
+ AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None,
+ AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All,
+ AllowedTunnelTraffic::Only(..) => WinFwAllowedTunnelTrafficType::Only,
+ }
+ }
+ }
+
+ #[repr(C)]
pub struct WinFwEndpoint {
pub ip: *const libc::wchar_t,
pub port: u16,
@@ -349,6 +397,8 @@ mod winfw {
pub enum WinFwProt {
Tcp = 0u8,
Udp = 1u8,
+ IcmpV4 = 2u8,
+ IcmpV6 = 3u8,
}
impl From<TransportProtocol> for WinFwProt {
@@ -360,6 +410,17 @@ mod winfw {
}
}
+ impl From<Protocol> for WinFwProt {
+ fn from(prot: Protocol) -> WinFwProt {
+ match prot {
+ Protocol::Tcp => WinFwProt::Tcp,
+ Protocol::Udp => WinFwProt::Udp,
+ Protocol::IcmpV4 => WinFwProt::IcmpV4,
+ Protocol::IcmpV6 => WinFwProt::IcmpV6,
+ }
+ }
+ }
+
#[repr(C)]
pub struct WinFwSettings {
permitDhcp: bool,
@@ -440,7 +501,8 @@ mod winfw {
relay: &WinFwEndpoint,
relayClient: *const libc::wchar_t,
tunnelIfaceAlias: *const libc::wchar_t,
- allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
+ allowedEndpoint: *const WinFwAllowedEndpoint<'_>,
+ allowedTunnelTraffic: &WinFwAllowedTunnelTraffic,
) -> WinFwPolicyStatus;
#[link_name = "WinFw_ApplyPolicyConnected"]
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index ea7ec35133..bafc839cf6 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -8,7 +8,7 @@ use std::{
};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
-use talpid_types::net::{wireguard as wireguard_types, TunnelParameters};
+use talpid_types::net::{wireguard as wireguard_types, AllowedTunnelTraffic, TunnelParameters};
#[cfg(target_os = "android")]
pub use self::tun_provider::TunConfig;
@@ -73,7 +73,7 @@ pub enum TunnelEvent {
/// Sent when the tunnel fails to connect due to an authentication error.
AuthFailed(Option<String>),
/// Sent when the tunnel interface has been created, before routes are set up.
- InterfaceUp(TunnelMetadata),
+ InterfaceUp(TunnelMetadata, AllowedTunnelTraffic),
/// Sent when the tunnel comes up and is ready for traffic.
Up(TunnelMetadata),
/// Sent when the tunnel goes down.
@@ -198,6 +198,18 @@ impl TunnelMonitor {
let monitor = wireguard::WireguardMonitor::start(
runtime,
config,
+ if params.options.use_pq_safe_psk {
+ Some(
+ params
+ .connection
+ .exit_peer
+ .as_ref()
+ .map(|peer| peer.public_key.clone())
+ .unwrap_or(params.connection.peer.public_key.clone()),
+ )
+ } else {
+ None
+ },
log.as_deref(),
resource_dir,
on_event,
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index cb885fb8ec..d9ddf8dd9a 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -873,9 +873,10 @@ mod event_server {
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
- (self.on_event)(super::TunnelEvent::InterfaceUp(Self::get_tunnel_metadata(
- &env,
- )?))
+ (self.on_event)(super::TunnelEvent::InterfaceUp(
+ Self::get_tunnel_metadata(&env)?,
+ talpid_types::net::AllowedTunnelTraffic::All,
+ ))
.await;
Ok(Response::new(()))
}
diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs
index 026c6c9fae..ed285742c4 100644
--- a/talpid-core/src/tunnel/wireguard/config.rs
+++ b/talpid-core/src/tunnel/wireguard/config.rs
@@ -6,6 +6,7 @@ use std::{
use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions};
/// Config required to set up a single WireGuard tunnel
+#[derive(Clone)]
pub struct Config {
/// Contains tunnel endpoint specific config
pub tunnel: wireguard::TunnelConfig,
@@ -147,6 +148,9 @@ impl Config {
.add("public_key", peer.public_key.as_bytes().as_ref())
.add("endpoint", peer.endpoint.to_string().as_str())
.add("replace_allowed_ips", "true");
+ if let Some(ref psk) = peer.psk {
+ wg_conf.add("preshared_key", psk.as_bytes().as_ref());
+ }
for addr in &peer.allowed_ips {
wg_conf.add("allowed_ip", addr.to_string().as_str());
}
diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
index ec2af873a0..bcb28d7c17 100644
--- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs
+++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs
@@ -391,12 +391,16 @@ impl ConnState {
#[cfg(test)]
mod test {
+ use futures::Future;
+
use super::*;
use crate::tunnel::wireguard::{
+ config::Config,
stats::{self, Stats},
TunnelError,
};
use std::{
+ pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
@@ -598,6 +602,13 @@ mod test {
fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> {
(self.on_get_stats)()
}
+
+ fn set_config(
+ &self,
+ _config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
+ Box::pin(async { Ok(()) })
+ }
}
fn mock_monitor(
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 80c8e2ae8e..3d0de8ba5c 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -8,6 +8,7 @@ use futures::{channel::mpsc, StreamExt};
use futures::{
channel::oneshot,
future::{abortable, AbortHandle as FutureAbortHandle},
+ Future,
};
#[cfg(target_os = "linux")]
use lazy_static::lazy_static;
@@ -18,14 +19,20 @@ use std::env;
#[cfg(windows)]
use std::io;
use std::{
+ borrow::Cow,
convert::Infallible,
- net::IpAddr,
+ net::{IpAddr, SocketAddrV4},
path::Path,
+ pin::Pin,
sync::{mpsc as sync_mpsc, Arc, Mutex},
+ time::Duration,
};
#[cfg(windows)]
use talpid_types::BoxedError;
-use talpid_types::{net::obfuscation::ObfuscatorConfig, ErrorExt};
+use talpid_types::{
+ net::{obfuscation::ObfuscatorConfig, wireguard::PublicKey, AllowedTunnelTraffic, Protocol},
+ ErrorExt,
+};
use tunnel_obfuscation::{
create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings,
};
@@ -73,6 +80,10 @@ pub enum Error {
#[error(display = "Connectivity monitor failed")]
ConnectivityMonitorError(#[error(source)] connectivity_check::Error),
+ /// Failed to negotiate PQ PSK
+ #[error(display = "Failed to negotiate PQ PSK")]
+ PskNegotiationError(#[error(source)] talpid_tunnel_config_client::Error),
+
/// Failed to set up IP interfaces.
#[cfg(windows)]
#[error(display = "Failed to set up IP interfaces")]
@@ -101,6 +112,10 @@ pub struct WireguardMonitor {
_obfuscator: Option<ObfuscatorHandle>,
}
+const INITIAL_PSK_EXCHANGE_TIMEOUT: Duration = Duration::from_secs(4);
+const MAX_PSK_EXCHANGE_TIMEOUT: Duration = Duration::from_secs(15);
+const PSK_EXCHANGE_TIMEOUT_MULTIPLIER: u32 = 2;
+
/// Simple wrapper that automatically cancels the future which runs an obfuscator.
struct ObfuscatorHandle {
abort_handle: FutureAbortHandle,
@@ -180,7 +195,7 @@ fn maybe_create_obfuscator(
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
pub fn start<
- F: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ Clone
@@ -188,6 +203,7 @@ impl WireguardMonitor {
>(
runtime: tokio::runtime::Handle,
mut config: Config,
+ psk_negotiation: Option<PublicKey>,
log_path: Option<&Path>,
resource_dir: &Path,
on_event: F,
@@ -203,10 +219,11 @@ impl WireguardMonitor {
let obfuscator = maybe_create_obfuscator(&runtime, &mut config, close_msg_sender.clone())?;
#[cfg(target_os = "windows")]
- let (setup_done_tx, mut setup_done_rx) = mpsc::channel(0);
+ let (setup_done_tx, setup_done_rx) = mpsc::channel(0);
+
let tunnel = Self::open_tunnel(
runtime.clone(),
- &config,
+ &Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
log_path,
resource_dir,
tun_provider,
@@ -237,31 +254,22 @@ impl WireguardMonitor {
.map_err(Error::ConnectivityMonitorError)?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
+ let tunnel = monitor.tunnel.clone();
let tunnel_fut = async move {
#[cfg(windows)]
- {
- setup_done_rx
- .next()
- .await
- .ok_or_else(|| {
- // Tunnel was shut down early
- CloseMsg::SetupError(Error::IpInterfacesError)
- })?
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to configure tunnel interface")
- );
- CloseMsg::SetupError(Error::IpInterfacesError)
- })?;
+ Self::add_device_ip_addresses(&iface_name, &config.tunnel.addresses, setup_done_rx)
+ .await?;
- if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
- return Err(CloseMsg::SetupError(Error::SetIpAddressesError));
- }
- }
-
- (on_event)(TunnelEvent::InterfaceUp(metadata.clone())).await;
+ let allowed_traffic = if psk_negotiation.is_some() {
+ AllowedTunnelTraffic::Only(
+ SocketAddrV4::new(config.ipv4_gateway, 1337).into(),
+ Protocol::Tcp,
+ )
+ } else {
+ AllowedTunnelTraffic::All
+ };
+ (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
// Add non-default routes before establishing the tunnel.
#[cfg(target_os = "linux")]
@@ -280,6 +288,15 @@ impl WireguardMonitor {
.map_err(Error::SetupRoutingError)
.map_err(CloseMsg::SetupError)?;
+ if let Some(pubkey) = psk_negotiation {
+ Self::perform_psk_negotiation(tunnel, retry_attempt, pubkey, &mut config).await?;
+ (on_event)(TunnelEvent::InterfaceUp(
+ metadata.clone(),
+ AllowedTunnelTraffic::All,
+ ))
+ .await;
+ }
+
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
match connectivity_monitor.establish_connectivity(retry_attempt) {
Ok(true) => Ok(connectivity_monitor),
@@ -339,6 +356,125 @@ impl WireguardMonitor {
Ok(monitor)
}
+ /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true.
+ /// Used to block traffic to other destinations while connecting on Android.
+ fn patch_allowed_ips<'a>(config: &'a Config, gateway_only: bool) -> Cow<'a, Config> {
+ if gateway_only {
+ let mut patched_config = config.clone();
+ let gateway_net_v4 = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway));
+ let gateway_net_v6 = config
+ .ipv6_gateway
+ .map(|net| ipnetwork::IpNetwork::from(IpAddr::from(net)));
+ for peer in &mut patched_config.peers {
+ peer.allowed_ips = peer
+ .allowed_ips
+ .iter()
+ .cloned()
+ .filter_map(|mut allowed_ip| {
+ if allowed_ip.prefix() == 0 {
+ if allowed_ip.is_ipv4() {
+ allowed_ip = gateway_net_v4;
+ } else {
+ if let Some(net) = gateway_net_v6 {
+ allowed_ip = net;
+ } else {
+ return None;
+ }
+ }
+ }
+ Some(allowed_ip)
+ })
+ .collect();
+ }
+ Cow::Owned(patched_config)
+ } else {
+ Cow::Borrowed(config)
+ }
+ }
+
+ #[cfg(windows)]
+ async fn add_device_ip_addresses(
+ iface_name: &str,
+ addresses: &[IpAddr],
+ mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>,
+ ) -> std::result::Result<(), CloseMsg> {
+ setup_done_rx
+ .next()
+ .await
+ .ok_or_else(|| {
+ // Tunnel was shut down early
+ CloseMsg::SetupError(Error::IpInterfacesError)
+ })?
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to configure tunnel interface")
+ );
+ CloseMsg::SetupError(Error::IpInterfacesError)
+ })?;
+ if !crate::winnet::add_device_ip_addresses(iface_name, addresses) {
+ return Err(CloseMsg::SetupError(Error::SetIpAddressesError));
+ }
+ Ok(())
+ }
+
+ async fn perform_psk_negotiation(
+ tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ retry_attempt: u32,
+ current_pubkey: PublicKey,
+ config: &mut Config,
+ ) -> std::result::Result<(), CloseMsg> {
+ log::debug!("Performing PQ-safe PSK exchange");
+
+ let timeout = std::cmp::min(
+ MAX_PSK_EXCHANGE_TIMEOUT,
+ INITIAL_PSK_EXCHANGE_TIMEOUT
+ .saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)),
+ );
+
+ let (private_key, psk) = tokio::time::timeout(
+ timeout,
+ talpid_tunnel_config_client::push_pq_key(
+ IpAddr::V4(config.ipv4_gateway),
+ config.tunnel.private_key.public_key(),
+ ),
+ )
+ .await
+ .map_err(|_timeout_err| {
+ log::warn!("Timeout while negotiating PSK");
+ CloseMsg::PskNegotiationTimeout
+ })?
+ .map_err(Error::PskNegotiationError)
+ .map_err(CloseMsg::SetupError)?;
+
+ config.tunnel.private_key = private_key;
+
+ for peer in &mut config.peers {
+ if current_pubkey == peer.public_key {
+ peer.psk = Some(psk);
+ break;
+ }
+ }
+
+ log::trace!(
+ "Ephemeral pubkey: {}",
+ config.tunnel.private_key.public_key()
+ );
+
+ let set_config_future = tunnel
+ .lock()
+ .unwrap()
+ .as_ref()
+ .map(|tunnel| tunnel.set_config(config.clone()));
+ if let Some(f) = set_config_future {
+ f.await
+ .map_err(Error::TunnelError)
+ .map_err(CloseMsg::SetupError)?;
+ }
+
+ Ok(())
+ }
+
#[allow(unused_variables)]
fn open_tunnel(
runtime: tokio::runtime::Handle,
@@ -424,7 +560,7 @@ impl WireguardMonitor {
/// Blocks the current thread until tunnel disconnects
pub fn wait(mut self) -> Result<()> {
let wait_result = match self.close_msg_receiver.recv() {
- Ok(CloseMsg::PingErr) => Err(Error::TimeoutError),
+ Ok(CloseMsg::PskNegotiationTimeout) | Ok(CloseMsg::PingErr) => Err(Error::TimeoutError),
Ok(CloseMsg::Stop) | Ok(CloseMsg::ObfuscatorExpired) => Ok(()),
Ok(CloseMsg::SetupError(error)) => Err(error),
Ok(CloseMsg::ObfuscatorFailed(error)) => Err(error),
@@ -584,6 +720,7 @@ impl WireguardMonitor {
enum CloseMsg {
Stop,
+ PskNegotiationTimeout,
PingErr,
SetupError(Error),
ObfuscatorExpired,
@@ -594,6 +731,10 @@ pub(crate) trait Tunnel: Send {
fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
+ fn set_config(
+ &self,
+ _config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>;
}
/// Errors to be returned from WireGuard implementations, namely implementers of the Tunnel trait
@@ -630,6 +771,10 @@ pub enum TunnelError {
#[error(display = "Failed to get config of WireGuard tunnel")]
GetConfigError,
+ /// Failed to set WireGuard tunnel config on device
+ #[error(display = "Failed to set config of WireGuard tunnel")]
+ SetConfigError,
+
/// Failed to duplicate tunnel file descriptor for wireguard-go
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "android"))]
#[error(display = "Failed to duplicate tunnel file descriptor for wireguard-go")]
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index a3d34acc0e..4fbdcefcdd 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -13,8 +13,10 @@ use futures::SinkExt;
use ipnetwork::IpNetwork;
use std::{
ffi::{c_void, CStr},
+ future::Future,
os::raw::c_char,
path::Path,
+ pin::Pin,
};
#[cfg(windows)]
use talpid_types::BoxedError;
@@ -354,6 +356,21 @@ impl Tunnel for WgGoTunnel {
fn stop(mut self: Box<Self>) -> Result<()> {
self.stop_tunnel()
}
+
+ fn set_config(
+ &self,
+ config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
+ let wg_config_str = config.to_userspace_format();
+ let handle = self.handle.unwrap();
+ Box::pin(async move {
+ let status = unsafe { wgSetConfig(handle, wg_config_str.as_ptr() as *const i8) };
+ if status != 0 {
+ return Err(TunnelError::SetConfigError);
+ }
+ Ok(())
+ })
+ }
}
fn check_wg_status(wg_code: i32) -> Result<()> {
@@ -422,6 +439,9 @@ extern "C" {
// Returns the file descriptor of the tunnel IPv4 socket.
fn wgGetConfig(handle: i32) -> *mut std::os::raw::c_char;
+ // Sets the config of the WireGuard interface.
+ fn wgSetConfig(handle: i32, settings: *const i8) -> i32;
+
// Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig
fn wgFreePtr(ptr: *mut c_void);
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs
index 8ab3234fd5..109ae75367 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs
@@ -1,3 +1,7 @@
+use std::pin::Pin;
+
+use futures::Future;
+
use super::{
super::stats::{Stats, StatsMap},
wg_message::DeviceNla,
@@ -109,4 +113,20 @@ impl Tunnel for NetlinkTunnel {
result
}
+
+ fn set_config(
+ &self,
+ config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> {
+ let mut wg = self.netlink_connections.wg_handle.clone();
+ let interface_index = self.interface_index;
+ Box::pin(async move {
+ wg.set_config(interface_index, &config)
+ .await
+ .map_err(|err| {
+ log::error!("Failed to fetch WireGuard device config: {}", err);
+ TunnelError::SetConfigError
+ })
+ })
+ }
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs
index fed29b93d9..92a8f5e81c 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nm_tunnel.rs
@@ -2,7 +2,8 @@ use super::{
super::stats::{Stats, StatsMap},
Config, Error as WgKernelError, Handle, Tunnel, TunnelError, MULLVAD_INTERFACE_NAME,
};
-use std::collections::HashMap;
+use futures::Future;
+use std::{collections::HashMap, pin::Pin};
use talpid_dbus::{
dbus,
network_manager::{
@@ -91,6 +92,24 @@ impl Tunnel for NetworkManagerTunnel {
Ok(Stats::parse_device_message(&device))
})
}
+
+ fn set_config(
+ &self,
+ config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
+ let interface_name = self.interface_name.clone();
+ let mut wg = self.netlink_connections.wg_handle.clone();
+ Box::pin(async move {
+ let index = crate::linux::iface_index(&interface_name).map_err(|err| {
+ log::error!("Failed to fetch WireGuard device index: {}", err);
+ TunnelError::SetConfigError
+ })?;
+ wg.set_config(index, &config).await.map_err(|err| {
+ log::error!("Failed to apply WireGuard config: {}", err);
+ TunnelError::SetConfigError
+ })
+ })
+ }
}
fn convert_config_to_dbus(config: &Config) -> DeviceConfig {
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
index 007acb8df7..dd17bc0cb3 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs
@@ -81,12 +81,16 @@ impl DeviceMessage {
for peer in config.peers.iter() {
let peer_endpoint = InetAddr::from_std(&peer.endpoint);
let allowed_ips = peer.allowed_ips.iter().map(From::from).collect();
- peers.push(PeerMessage(vec![
+ let mut peer_nlas = vec![
PeerNla::PublicKey(*peer.public_key.as_bytes()),
PeerNla::Endpoint(peer_endpoint),
PeerNla::AllowedIps(allowed_ips),
PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
- ]));
+ ];
+ if let Some(psk) = peer.psk.as_ref() {
+ peer_nlas.push(PeerNla::PresharedKey(psk.as_bytes().clone()));
+ }
+ peers.push(PeerMessage(peer_nlas));
}
let nlas = vec![
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
index 21c9b705ce..0ae469144b 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
@@ -11,11 +11,14 @@ use ipnetwork::IpNetwork;
use lazy_static::lazy_static;
use std::{
ffi::CStr,
- fmt, io, mem,
+ fmt,
+ future::Future,
+ io, mem,
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
os::windows::io::RawHandle,
path::Path,
+ pin::Pin,
ptr,
sync::{Arc, Mutex},
};
@@ -833,11 +836,20 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> {
buffer.extend(windows::as_uninit_byte_slice(&header));
for peer in &config.peers {
+ let flags = if peer.psk.is_some() {
+ WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT
+ } else {
+ WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT
+ };
let wg_peer = WgPeer {
- flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT,
+ flags,
reserved: 0,
public_key: peer.public_key.as_bytes().clone(),
- preshared_key: [0u8; WIREGUARD_KEY_LENGTH],
+ preshared_key: peer
+ .psk
+ .as_ref()
+ .map(|psk| psk.as_bytes().clone())
+ .unwrap_or([0u8; WIREGUARD_KEY_LENGTH]),
persistent_keepalive: 0,
endpoint: windows::inet_sockaddr_from_socketaddr(peer.endpoint).into(),
tx_bytes: 0,
@@ -976,6 +988,24 @@ impl Tunnel for WgNtTunnel {
self.stop_tunnel();
Ok(())
}
+
+ fn set_config(
+ &self,
+ config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
+ let device = self.device.clone();
+ Box::pin(async move {
+ let guard = device.lock().unwrap();
+ let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?;
+ device.set_config(&config).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set wg-nt tunnel config")
+ );
+ super::TunnelError::SetConfigError
+ })
+ })
+ }
}
#[cfg(test)]
@@ -1006,6 +1036,7 @@ mod tests {
public_key: WG_PUBLIC_KEY.clone(),
allowed_ips: vec!["1.3.3.0/24".parse().unwrap()],
endpoint: "1.2.3.4:1234".parse().unwrap(),
+ psk: None,
}],
ipv4_gateway: "0.0.0.0".parse().unwrap(),
ipv6_gateway: None,
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 747a5bdcbb..7536b26b09 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -21,7 +21,7 @@ use std::{
time::{Duration, Instant},
};
use talpid_types::{
- net::TunnelParameters,
+ net::{AllowedTunnelTraffic, TunnelParameters},
tunnel::{ErrorStateCause, FirewallPolicyError},
ErrorExt,
};
@@ -47,6 +47,7 @@ pub struct ConnectingState {
tunnel_events: TunnelEventsReceiver,
tunnel_parameters: TunnelParameters,
tunnel_metadata: Option<TunnelMetadata>,
+ allowed_tunnel_traffic: AllowedTunnelTraffic,
tunnel_close_event: TunnelCloseEvent,
tunnel_close_tx: oneshot::Sender<()>,
retry_attempt: u32,
@@ -57,6 +58,7 @@ impl ConnectingState {
shared_values: &mut SharedTunnelStateValues,
params: &TunnelParameters,
tunnel_metadata: &Option<TunnelMetadata>,
+ allowed_tunnel_traffic: AllowedTunnelTraffic,
) -> Result<(), FirewallPolicyError> {
#[cfg(target_os = "linux")]
shared_values.disable_connectivity_check();
@@ -68,6 +70,7 @@ impl ConnectingState {
tunnel: tunnel_metadata.clone(),
allow_lan: shared_values.allow_lan,
allowed_endpoint: shared_values.allowed_endpoint.clone(),
+ allowed_tunnel_traffic,
#[cfg(windows)]
relay_client: TunnelMonitor::get_relay_client(&shared_values.resource_dir, &params),
};
@@ -207,6 +210,7 @@ impl ConnectingState {
tunnel_events: event_rx.fuse(),
tunnel_parameters: parameters,
tunnel_metadata: None,
+ allowed_tunnel_traffic: AllowedTunnelTraffic::None,
tunnel_close_event: tunnel_close_event_rx.fuse(),
tunnel_close_tx,
retry_attempt,
@@ -294,6 +298,7 @@ impl ConnectingState {
shared_values,
&self.tunnel_parameters,
&self.tunnel_metadata,
+ self.allowed_tunnel_traffic.clone(),
) {
Ok(()) => {
cfg_if! {
@@ -333,6 +338,7 @@ impl ConnectingState {
shared_values,
&self.tunnel_parameters,
&self.tunnel_metadata,
+ self.allowed_tunnel_traffic.clone(),
) {
let _ = tx.send(());
return self.disconnect(
@@ -399,7 +405,7 @@ impl ConnectingState {
shared_values,
AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)),
),
- Some((TunnelEvent::InterfaceUp(metadata), _done_tx)) => {
+ Some((TunnelEvent::InterfaceUp(metadata, allowed_tunnel_traffic), _done_tx)) => {
#[cfg(windows)]
if let Err(error) = shared_values
.split_tunnel
@@ -416,11 +422,15 @@ impl ConnectingState {
AfterDisconnect::Block(ErrorStateCause::SplitTunnelError),
);
}
+
+ self.allowed_tunnel_traffic = allowed_tunnel_traffic;
self.tunnel_metadata = Some(metadata);
+
match Self::set_firewall_policy(
shared_values,
&self.tunnel_parameters,
&self.tunnel_metadata,
+ self.allowed_tunnel_traffic.clone(),
) {
Ok(()) => SameState(self.into()),
Err(error) => self.disconnect(
@@ -549,9 +559,12 @@ impl TunnelState for ConnectingState {
return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError);
}
- if let Err(error) =
- Self::set_firewall_policy(shared_values, &tunnel_parameters, &None)
- {
+ if let Err(error) = Self::set_firewall_policy(
+ shared_values,
+ &tunnel_parameters,
+ &None,
+ AllowedTunnelTraffic::None,
+ ) {
ErrorState::enter(
shared_values,
ErrorStateCause::SetFirewallPolicyError(error),
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 7f31082541..3d0075257b 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -404,7 +404,7 @@ pub fn interface_luid_to_ip(
}
}
-pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
+pub fn add_device_ip_addresses(iface: &str, addresses: &[IpAddr]) -> bool {
let raw_iface = WideCString::from_str(iface)
.expect("Failed to convert UTF-8 string to null terminated UCS string")
.into_raw();
diff --git a/talpid-tunnel-config-client/Cargo.toml b/talpid-tunnel-config-client/Cargo.toml
new file mode 100644
index 0000000000..5cb8983c41
--- /dev/null
+++ b/talpid-tunnel-config-client/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "talpid-tunnel-config-client"
+version = "0.1.0"
+authors = ["Mullvad VPN"]
+description = "Uses the relay RPC service to set up PQ-safe peers, etc."
+license = "GPL-3.0"
+edition = "2021"
+publish = false
+
+[dependencies]
+log = "0.4"
+rand = "0.8"
+talpid-types = { path = "../talpid-types" }
+tonic = "0.5"
+prost = "0.8"
+prost-types = "0.9"
+tower = "0.4"
+tokio = "1"
+classic-mceliece-rust = { git = "https://github.com/mullvad/classic-mceliece-rust", rev = "5130d9e3bfbf54735177e15636a643366c250b78", features = ["mceliece8192128f"] }
+
+[dev-dependencies]
+tokio = { version = "1", features = ["rt-multi-thread"] }
+
+[build-dependencies]
+tonic-build = { version = "0.5", default-features = false, features = ["transport", "prost"] } \ No newline at end of file
diff --git a/talpid-tunnel-config-client/build.rs b/talpid-tunnel-config-client/build.rs
new file mode 100644
index 0000000000..2732fb69d0
--- /dev/null
+++ b/talpid-tunnel-config-client/build.rs
@@ -0,0 +1,5 @@
+fn main() {
+ const PROTO_FILE: &str = "proto/tunnel_config.proto";
+ tonic_build::compile_protos(PROTO_FILE).unwrap();
+ println!("cargo:rerun-if-changed={}", PROTO_FILE);
+}
diff --git a/talpid-tunnel-config-client/examples/psk-exchange.rs b/talpid-tunnel-config-client/examples/psk-exchange.rs
new file mode 100644
index 0000000000..b4607be950
--- /dev/null
+++ b/talpid-tunnel-config-client/examples/psk-exchange.rs
@@ -0,0 +1,25 @@
+use std::{
+ io,
+ net::{IpAddr, Ipv4Addr},
+};
+
+use talpid_types::net::wireguard::PublicKey;
+
+#[tokio::main]
+async fn main() {
+ println!("Make sure you're connected to a WireGuard peer and enter your public key: ");
+
+ let mut pubkey_s = String::new();
+ io::stdin()
+ .read_line(&mut pubkey_s)
+ .expect("Failed to read from stdin");
+ let pubkey = PublicKey::from_base64(pubkey_s.trim()).expect("Invalid public key");
+
+ let (private_key, psk) =
+ talpid_tunnel_config_client::push_pq_key(IpAddr::V4(Ipv4Addr::new(10, 64, 0, 1)), pubkey)
+ .await
+ .unwrap();
+
+ println!("private key: {:?}", private_key);
+ println!("psk: {:?}", psk);
+}
diff --git a/talpid-tunnel-config-client/proto/tunnel_config.proto b/talpid-tunnel-config-client/proto/tunnel_config.proto
new file mode 100644
index 0000000000..9ce4232d0d
--- /dev/null
+++ b/talpid-tunnel-config-client/proto/tunnel_config.proto
@@ -0,0 +1,38 @@
+//
+// If you need to (re)generate the gRPC code, see prerequisites
+//
+// https://grpc.io/docs/languages/go/quickstart/
+//
+// and then run
+//
+// protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative tunnel_config.proto
+//
+// from this directory.
+//
+
+syntax = "proto3";
+
+option go_package = "github.com/mullvad/wg-manager/server/tuncfg";
+
+package tunnel_config;
+
+service PostQuantumSecure {
+ // PskExchangeExperimentalV0 uses the common API defined by LibOQS. See:
+ // https://github.com/open-quantum-safe/liboqs
+ rpc PskExchangeExperimentalV0(PskRequestExperimentalV0) returns (PskResponseExperimentalV0) {}
+}
+
+message PskRequestExperimentalV0 {
+ bytes wg_pubkey = 1;
+ bytes wg_psk_pubkey = 2;
+ KemPubkeyExperimentalV0 kem_pubkey = 3;
+}
+
+message KemPubkeyExperimentalV0 {
+ string algorithm_name = 1;
+ bytes key_data = 2;
+}
+
+message PskResponseExperimentalV0 {
+ bytes ciphertext = 1;
+}
diff --git a/talpid-tunnel-config-client/src/kem.rs b/talpid-tunnel-config-client/src/kem.rs
new file mode 100644
index 0000000000..35679ccddf
--- /dev/null
+++ b/talpid-tunnel-config-client/src/kem.rs
@@ -0,0 +1,67 @@
+use std::fmt;
+
+use super::Error;
+
+use classic_mceliece_rust::{
+ crypto_kem_dec, crypto_kem_keypair, CRYPTO_BYTES, CRYPTO_PUBLICKEYBYTES, CRYPTO_SECRETKEYBYTES,
+};
+use talpid_types::net::wireguard::PresharedKey;
+
+const STACK_SIZE: usize = 8 * 1024 * 1024;
+pub use classic_mceliece_rust::CRYPTO_CIPHERTEXTBYTES;
+
+#[derive(Debug)]
+pub struct PublicKey(Box<[u8; CRYPTO_PUBLICKEYBYTES]>);
+
+impl PublicKey {
+ pub fn into_vec(self) -> Vec<u8> {
+ (self.0 as Box<[u8]>).into_vec()
+ }
+}
+
+pub struct SecretKey(Box<[u8; CRYPTO_SECRETKEYBYTES]>);
+
+impl fmt::Debug for SecretKey {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SecretKey").finish()
+ }
+}
+
+pub async fn generate_keys() -> Result<(PublicKey, SecretKey), Error> {
+ let (tx, rx) = tokio::sync::oneshot::channel();
+
+ let gen_key = move || {
+ let mut rng = rand::thread_rng();
+ let mut pubkey = Box::new([0u8; CRYPTO_PUBLICKEYBYTES]);
+ let mut secret = Box::new([0u8; CRYPTO_SECRETKEYBYTES]);
+ crypto_kem_keypair(&mut pubkey, &mut secret, &mut rng).map_err(|error| {
+ log::error!("KEM keypair generation failed: {error}");
+ Error::KeyGenerationFailed
+ })?;
+
+ Ok((PublicKey(pubkey), SecretKey(secret)))
+ };
+
+ std::thread::Builder::new()
+ .stack_size(STACK_SIZE)
+ .spawn(move || {
+ let _ = tx.send(gen_key());
+ })
+ .unwrap();
+
+ rx.await.unwrap()
+}
+
+pub fn decapsulate(
+ secret: &SecretKey,
+ ciphertext: &[u8; CRYPTO_CIPHERTEXTBYTES],
+) -> Result<PresharedKey, Error> {
+ let mut psk = [0u8; CRYPTO_BYTES];
+
+ crypto_kem_dec(&mut psk, ciphertext, &secret.0).map_err(|error| {
+ log::error!("KEM decapsulation failed: {error}");
+ Error::DecapsulationError
+ })?;
+
+ Ok(PresharedKey::from(psk))
+}
diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs
new file mode 100644
index 0000000000..fafb95affb
--- /dev/null
+++ b/talpid-tunnel-config-client/src/lib.rs
@@ -0,0 +1,84 @@
+use std::{fmt, net::IpAddr};
+use talpid_types::net::wireguard::{PresharedKey, PrivateKey, PublicKey};
+use tonic::transport::Channel;
+
+mod kem;
+
+mod proto {
+ tonic::include_proto!("tunnel_config");
+}
+
+#[derive(Debug)]
+pub enum Error {
+ GrpcConnectError(tonic::transport::Error),
+ GrpcError(tonic::Status),
+ KeyGenerationFailed,
+ DecapsulationError,
+ InvalidCiphertext,
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ use Error::*;
+ match self {
+ GrpcConnectError(_) => "Failed to connect to config service".fmt(f),
+ GrpcError(status) => write!(f, "RPC failed: {}", status),
+ KeyGenerationFailed => "Failed to generate KEM key pair".fmt(f),
+ DecapsulationError => "Failed to decapsulate secret".fmt(f),
+ InvalidCiphertext => "The service returned an invalid ciphertext".fmt(f),
+ }
+ }
+}
+
+impl std::error::Error for Error {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ Self::GrpcConnectError(error) => Some(error),
+ _ => None,
+ }
+ }
+}
+
+type RelayConfigService = proto::post_quantum_secure_client::PostQuantumSecureClient<Channel>;
+
+const CONFIG_SERVICE_PORT: u16 = 1337;
+const ALGORITHM_NAME: &str = "Classic-McEliece-8192128f";
+
+/// Generates a new WireGuard key pair and negotiates a PSK with the relay in a PQ-safe
+/// manner. This creates a peer on the relay with the new WireGuard pubkey and PSK,
+/// which can then be used to establish a PQ-safe tunnel to the relay.
+// TODO: consider binding to the tunnel interface here, on non-windows platforms
+pub async fn push_pq_key(
+ service_address: IpAddr,
+ wg_pubkey: PublicKey,
+) -> Result<(PrivateKey, PresharedKey), Error> {
+ let wg_psk_privkey = PrivateKey::new_from_random();
+ let (kem_pubkey, kem_secret) = kem::generate_keys().await?;
+
+ let mut client = new_client(service_address).await?;
+ let response = client
+ .psk_exchange_experimental_v0(proto::PskRequestExperimentalV0 {
+ wg_pubkey: wg_pubkey.as_bytes().to_vec(),
+ wg_psk_pubkey: wg_psk_privkey.public_key().as_bytes().to_vec(),
+ kem_pubkey: Some(proto::KemPubkeyExperimentalV0 {
+ algorithm_name: ALGORITHM_NAME.to_string(),
+ key_data: kem_pubkey.into_vec(),
+ }),
+ })
+ .await
+ .map_err(Error::GrpcError)?;
+
+ let ciphertext: [u8; kem::CRYPTO_CIPHERTEXTBYTES] = response
+ .into_inner()
+ .ciphertext
+ .try_into()
+ .map_err(|_| Error::InvalidCiphertext)?;
+
+ Ok((wg_psk_privkey, kem::decapsulate(&kem_secret, &ciphertext)?))
+}
+
+async fn new_client(addr: IpAddr) -> Result<RelayConfigService, Error> {
+ RelayConfigService::connect(format!("tcp://{addr}:{CONFIG_SERVICE_PORT}"))
+ .await
+ .map_err(Error::GrpcConnectError)
+}
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 36bfba4f9d..a35c227ddd 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -29,6 +29,7 @@ impl TunnelParameters {
match self {
TunnelParameters::OpenVpn(params) => TunnelEndpoint {
tunnel_type: TunnelType::OpenVpn,
+ quantum_resistant: false,
endpoint: params.config.endpoint,
proxy: params.proxy.as_ref().map(|proxy| proxy.get_endpoint()),
obfuscation: None,
@@ -36,6 +37,7 @@ impl TunnelParameters {
},
TunnelParameters::Wireguard(params) => TunnelEndpoint {
tunnel_type: TunnelType::Wireguard,
+ quantum_resistant: params.options.use_pq_safe_psk,
endpoint: params
.connection
.get_exit_endpoint()
@@ -134,6 +136,8 @@ pub struct TunnelEndpoint {
#[cfg_attr(target_os = "android", jnix(skip))]
pub tunnel_type: TunnelType,
#[cfg_attr(target_os = "android", jnix(skip))]
+ pub quantum_resistant: bool,
+ #[cfg_attr(target_os = "android", jnix(skip))]
pub proxy: Option<proxy::ProxyEndpoint>,
#[cfg_attr(target_os = "android", jnix(skip))]
pub obfuscation: Option<ObfuscationEndpoint>,
@@ -143,7 +147,11 @@ pub struct TunnelEndpoint {
impl fmt::Display for TunnelEndpoint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
- write!(f, "{} - {}", self.tunnel_type, self.endpoint)?;
+ write!(f, "{} ", self.tunnel_type)?;
+ if self.quantum_resistant {
+ write!(f, "(quantum resistant) ")?;
+ }
+ write!(f, "- {}", self.endpoint)?;
match self.tunnel_type {
TunnelType::OpenVpn => {
if let Some(ref proxy) = self.proxy {
@@ -275,6 +283,52 @@ impl fmt::Display for AllowedEndpoint {
}
}
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub enum AllowedTunnelTraffic {
+ None,
+ All,
+ Only(SocketAddr, Protocol),
+}
+
+impl fmt::Display for AllowedTunnelTraffic {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match *self {
+ AllowedTunnelTraffic::None => "None".fmt(f),
+ AllowedTunnelTraffic::All => "All".fmt(f),
+ AllowedTunnelTraffic::Only(addr, proto) => write!(f, "{}/{}", addr, proto),
+ }
+ }
+}
+
+/// A protocol: UDP, TCP, or ICMP.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+pub enum Protocol {
+ Udp,
+ Tcp,
+ IcmpV4,
+ IcmpV6,
+}
+
+impl fmt::Display for Protocol {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ Protocol::Udp => "UDP".fmt(f),
+ Protocol::Tcp => "TCP".fmt(f),
+ Protocol::IcmpV4 => "ICMPv4".fmt(f),
+ Protocol::IcmpV6 => "ICMPv6".fmt(f),
+ }
+ }
+}
+
+impl From<TransportProtocol> for Protocol {
+ fn from(proto: TransportProtocol) -> Self {
+ match proto {
+ TransportProtocol::Udp => Protocol::Udp,
+ TransportProtocol::Tcp => Protocol::Tcp,
+ }
+ }
+}
+
/// IP protocol version.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs
index 199ff3bd29..ff7ebef090 100644
--- a/talpid-types/src/net/wireguard.rs
+++ b/talpid-types/src/net/wireguard.rs
@@ -55,6 +55,8 @@ pub struct PeerConfig {
pub allowed_ips: Vec<IpNetwork>,
/// IP address of the WireGuard server.
pub endpoint: SocketAddr,
+ /// Preshared key.
+ pub psk: Option<PresharedKey>,
}
#[derive(Clone, Eq, PartialEq, Deserialize, Serialize, Debug)]
@@ -66,6 +68,7 @@ pub struct TunnelConfig {
/// Options in [`TunnelParameters`] that apply to any WireGuard connection.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(default)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(
target_os = "android",
@@ -78,6 +81,8 @@ pub struct TunnelOptions {
jnix(map = "|maybe_mtu| maybe_mtu.map(|mtu| mtu as i32)")
)]
pub mtu: Option<u16>,
+ /// Obtain a PSK using the relay config client.
+ pub use_pq_safe_psk: bool,
/// Temporary switch for wireguard-nt
#[cfg(windows)]
#[serde(default = "default_wgnt_setting")]
@@ -94,6 +99,7 @@ impl Default for TunnelOptions {
fn default() -> Self {
Self {
mtu: None,
+ use_pq_safe_psk: false,
#[cfg(windows)]
use_wireguard_nt: default_wgnt_setting(),
}
@@ -253,6 +259,40 @@ impl fmt::Display for PublicKey {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct PresharedKey([u8; 32]);
+
+impl PresharedKey {
+ /// Get the PSK as bytes
+ pub fn as_bytes(&self) -> &[u8; 32] {
+ &self.0
+ }
+}
+
+impl From<[u8; 32]> for PresharedKey {
+ fn from(key: [u8; 32]) -> PresharedKey {
+ PresharedKey(key)
+ }
+}
+
+impl Serialize for PresharedKey {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ serialize_key(&self.0, serializer)
+ }
+}
+
+impl<'de> Deserialize<'de> for PresharedKey {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserialize_key(deserializer)
+ }
+}
+
fn serialize_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp
index fe6c227e7f..b72cefe2d3 100644
--- a/windows/winfw/src/winfw/fwcontext.cpp
+++ b/windows/winfw/src/winfw/fwcontext.cpp
@@ -185,7 +185,8 @@ bool FwContext::applyPolicyConnecting
const WinFwEndpoint &relay,
const std::wstring &relayClient,
const std::optional<std::wstring> &tunnelInterfaceAlias,
- const std::optional<WinFwAllowedEndpoint> &allowedEndpoint
+ const std::optional<WinFwAllowedEndpoint> &allowedEndpoint,
+ const WinFwAllowedTunnelTraffic &allowedTunnelTraffic
)
{
Ruleset ruleset;
@@ -201,13 +202,39 @@ bool FwContext::applyPolicyConnecting
if (tunnelInterfaceAlias.has_value())
{
- ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>(
- *tunnelInterfaceAlias
- ));
-
- ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>(
- *tunnelInterfaceAlias
- ));
+ switch (allowedTunnelTraffic.type)
+ {
+ case WinFwAllowedTunnelTrafficType::All:
+ {
+ ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>(
+ *tunnelInterfaceAlias,
+ std::nullopt
+ ));
+ ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>(
+ *tunnelInterfaceAlias,
+ std::nullopt
+ ));
+ break;
+ }
+ case WinFwAllowedTunnelTrafficType::Only:
+ {
+ const auto onlyEndpoint = std::make_optional(baseline::PermitVpnTunnel::Endpoint{
+ wfp::IpAddress(allowedTunnelTraffic.endpoint->ip),
+ allowedTunnelTraffic.endpoint->port,
+ allowedTunnelTraffic.endpoint->protocol
+ });
+ ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>(
+ *tunnelInterfaceAlias,
+ onlyEndpoint
+ ));
+ ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>(
+ *tunnelInterfaceAlias,
+ onlyEndpoint
+ ));
+ break;
+ }
+ // For the "None" case, do nothing.
+ }
}
const auto status = applyRuleset(ruleset);
@@ -250,11 +277,13 @@ bool FwContext::applyPolicyConnected
}
ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnel>(
- tunnelInterfaceAlias
+ tunnelInterfaceAlias,
+ std::nullopt
));
ruleset.emplace_back(std::make_unique<baseline::PermitVpnTunnelService>(
- tunnelInterfaceAlias
+ tunnelInterfaceAlias,
+ std::nullopt
));
const auto status = applyRuleset(ruleset);
diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h
index bf67565993..5fc23f09a7 100644
--- a/windows/winfw/src/winfw/fwcontext.h
+++ b/windows/winfw/src/winfw/fwcontext.h
@@ -30,7 +30,8 @@ public:
const WinFwEndpoint &relay,
const std::wstring &relayClient,
const std::optional<std::wstring> &tunnelInterfaceAlias,
- const std::optional<WinFwAllowedEndpoint> &allowedEndpoint
+ const std::optional<WinFwAllowedEndpoint> &allowedEndpoint,
+ const WinFwAllowedTunnelTraffic &allowedTunnelTraffic
);
bool applyPolicyConnected
diff --git a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp
index 09d8937535..c1c74ba6ba 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitendpoint.cpp
@@ -1,6 +1,7 @@
#include "stdafx.h"
#include "permitendpoint.h"
#include <winfw/mullvadguids.h>
+#include <winfw/rules/shared.h>
#include <libwfp/filterbuilder.h>
#include <libwfp/conditionbuilder.h>
#include <libwfp/conditions/conditionprotocol.h>
@@ -30,19 +31,6 @@ const GUID &OutboundLayerFromIp(const wfp::IpAddress &ip)
};
}
-std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol)
-{
- switch (protocol)
- {
- case WinFwProtocol::Tcp: return ConditionProtocol::Tcp();
- case WinFwProtocol::Udp: return ConditionProtocol::Udp();
- default:
- {
- THROW_ERROR("Missing case handler in switch clause");
- }
- };
-}
-
} // anonymous namespace
PermitEndpoint::PermitEndpoint
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
index e756d68464..d9a1af0f28 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
@@ -1,17 +1,26 @@
#include "stdafx.h"
#include "permitvpntunnel.h"
#include <winfw/mullvadguids.h>
+#include <winfw/rules/shared.h>
#include <libwfp/filterbuilder.h>
#include <libwfp/conditionbuilder.h>
#include <libwfp/conditions/conditioninterface.h>
+#include <libwfp/conditions/conditionip.h>
+#include <libwfp/conditions/conditionport.h>
+#include <libwfp/conditions/conditionprotocol.h>
+#include <libcommon/error.h>
using namespace wfp::conditions;
namespace rules::baseline
{
-PermitVpnTunnel::PermitVpnTunnel(const std::wstring &tunnelInterfaceAlias)
+PermitVpnTunnel::PermitVpnTunnel(
+ const std::wstring &tunnelInterfaceAlias,
+ const std::optional<Endpoint> &onlyEndpoint
+)
: m_tunnelInterfaceAlias(tunnelInterfaceAlias)
+ , m_tunnelOnlyEndpoint(onlyEndpoint)
{
}
@@ -19,6 +28,9 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller)
{
wfp::FilterBuilder filterBuilder;
+ bool includeV4 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv4;
+ bool includeV6 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv6;
+
//
// #1 Permit outbound connections, IPv4.
//
@@ -33,11 +45,22 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller)
.weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
+ if (includeV4)
{
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
+ if (m_tunnelOnlyEndpoint.has_value())
+ {
+ conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip));
+ if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol))
+ {
+ conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port));
+ }
+ conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol));
+ }
+
if (!objectInstaller.addFilter(filterBuilder, conditionBuilder))
{
return false;
@@ -53,11 +76,26 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller)
.name(L"Permit outbound connections on tunnel interface (IPv6)")
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
- wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
+ if (includeV6)
+ {
+ wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
- conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
+
+ if (m_tunnelOnlyEndpoint.has_value())
+ {
+ conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip));
+ if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol))
+ {
+ conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port));
+ }
+ conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+ }
- return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+ return true;
}
}
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h
index 9c9a7b14c1..ee030a9f43 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.h
@@ -1,7 +1,10 @@
#pragma once
#include <winfw/rules/ifirewallrule.h>
+#include <winfw/winfw.h>
+#include <libwfp/ipaddress.h>
#include <string>
+#include <optional>
namespace rules::baseline
{
@@ -10,13 +13,23 @@ class PermitVpnTunnel : public IFirewallRule
{
public:
- PermitVpnTunnel(const std::wstring &tunnelInterfaceAlias);
+ struct Endpoint {
+ wfp::IpAddress ip;
+ uint16_t port;
+ WinFwProtocol protocol;
+ };
+
+ PermitVpnTunnel(
+ const std::wstring &tunnelInterfaceAlias,
+ const std::optional<Endpoint> &onlyEndpoint
+ );
bool apply(IObjectInstaller &objectInstaller) override;
private:
const std::wstring m_tunnelInterfaceAlias;
+ const std::optional<Endpoint> m_tunnelOnlyEndpoint;
};
}
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
index 00fbc8e76b..42214b6a77 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
@@ -1,17 +1,26 @@
#include "stdafx.h"
#include "permitvpntunnelservice.h"
#include <winfw/mullvadguids.h>
+#include <winfw/rules/shared.h>
#include <libwfp/filterbuilder.h>
#include <libwfp/conditionbuilder.h>
#include <libwfp/conditions/conditioninterface.h>
+#include <libwfp/conditions/conditionip.h>
+#include <libwfp/conditions/conditionport.h>
+#include <libwfp/conditions/conditionprotocol.h>
+#include <libcommon/error.h>
using namespace wfp::conditions;
namespace rules::baseline
{
-PermitVpnTunnelService::PermitVpnTunnelService(const std::wstring &tunnelInterfaceAlias)
+PermitVpnTunnelService::PermitVpnTunnelService(
+ const std::wstring &tunnelInterfaceAlias,
+ const std::optional<PermitVpnTunnel::Endpoint> &onlyEndpoint
+)
: m_tunnelInterfaceAlias(tunnelInterfaceAlias)
+ , m_tunnelOnlyEndpoint(onlyEndpoint)
{
}
@@ -19,6 +28,9 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller)
{
wfp::FilterBuilder filterBuilder;
+ bool includeV4 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv4;
+ bool includeV6 = !m_tunnelOnlyEndpoint.has_value() || m_tunnelOnlyEndpoint->ip.type() == wfp::IpAddress::Ipv6;
+
//
// #1 Permit inbound connections, IPv4.
//
@@ -35,26 +47,54 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller)
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4);
- conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
-
- if (!objectInstaller.addFilter(filterBuilder, conditionBuilder))
+ if (includeV4)
{
- return false;
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
+
+ if (m_tunnelOnlyEndpoint.has_value())
+ {
+ conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip));
+ if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol))
+ {
+ conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port));
+ }
+ conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol));
+ }
+
+ if (!objectInstaller.addFilter(filterBuilder, conditionBuilder))
+ {
+ return false;
+ }
}
//
// #2 Permit inbound connections, IPv6.
//
- filterBuilder
- .key(MullvadGuids::Filter_Baseline_PermitVpnTunnelService_Ipv6())
- .name(L"Permit inbound connections on tunnel interface (IPv6)")
- .layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6);
+ if (includeV6)
+ {
+ filterBuilder
+ .key(MullvadGuids::Filter_Baseline_PermitVpnTunnelService_Ipv6())
+ .name(L"Permit inbound connections on tunnel interface (IPv6)")
+ .layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6);
+
+ conditionBuilder.reset(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6);
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
- conditionBuilder.reset(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6);
- conditionBuilder.add_condition(ConditionInterface::Alias(m_tunnelInterfaceAlias));
+ if (m_tunnelOnlyEndpoint.has_value())
+ {
+ conditionBuilder.add_condition(ConditionIp::Remote(m_tunnelOnlyEndpoint->ip));
+ if (ProtocolHasPort(m_tunnelOnlyEndpoint->protocol))
+ {
+ conditionBuilder.add_condition(ConditionPort::Remote(m_tunnelOnlyEndpoint->port));
+ }
+ conditionBuilder.add_condition(CreateProtocolCondition(m_tunnelOnlyEndpoint->protocol));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+ }
- return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+ return true;
}
}
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h
index 8880c06328..8011659b97 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.h
@@ -1,7 +1,11 @@
#pragma once
#include <winfw/rules/ifirewallrule.h>
+#include <winfw/rules/baseline/permitvpntunnel.h>
+#include <winfw/winfw.h>
+#include <libwfp/ipaddress.h>
#include <string>
+#include <optional>
namespace rules::baseline
{
@@ -10,13 +14,17 @@ class PermitVpnTunnelService : public IFirewallRule
{
public:
- PermitVpnTunnelService(const std::wstring &tunnelInterfaceAlias);
+ PermitVpnTunnelService(
+ const std::wstring &tunnelInterfaceAlias,
+ const std::optional<PermitVpnTunnel::Endpoint> &onlyEndpoint
+ );
bool apply(IObjectInstaller &objectInstaller) override;
private:
const std::wstring m_tunnelInterfaceAlias;
+ const std::optional<PermitVpnTunnel::Endpoint> m_tunnelOnlyEndpoint;
};
}
diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
index a403230df9..3c913cab14 100644
--- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
+++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
@@ -2,6 +2,7 @@
#include "permitvpnrelay.h"
#include <winfw/mullvadguids.h>
#include <winfw/winfw.h>
+#include <winfw/rules/shared.h>
#include <libwfp/filterbuilder.h>
#include <libwfp/conditionbuilder.h>
#include <libwfp/conditions/conditionprotocol.h>
@@ -31,19 +32,6 @@ const GUID &LayerFromIp(const wfp::IpAddress &ip)
};
}
-std::unique_ptr<ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol)
-{
- switch (protocol)
- {
- case WinFwProtocol::Tcp: return ConditionProtocol::Tcp();
- case WinFwProtocol::Udp: return ConditionProtocol::Udp();
- default:
- {
- THROW_ERROR("Missing case handler in switch clause");
- }
- };
-}
-
const GUID &TranslateSublayer(PermitVpnRelay::Sublayer sublayer)
{
switch (sublayer)
diff --git a/windows/winfw/src/winfw/rules/shared.cpp b/windows/winfw/src/winfw/rules/shared.cpp
index 66cbbdfc83..1d1123e3eb 100644
--- a/windows/winfw/src/winfw/rules/shared.cpp
+++ b/windows/winfw/src/winfw/rules/shared.cpp
@@ -2,6 +2,8 @@
#include "shared.h"
#include <libcommon/error.h>
+using namespace wfp::conditions;
+
namespace rules
{
@@ -37,4 +39,36 @@ void SplitAddresses(const IpSet &in, IpSet &outIpv4, IpSet &outIpv6)
}
}
+std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol)
+{
+ switch (protocol)
+ {
+ case WinFwProtocol::Tcp: return ConditionProtocol::Tcp();
+ case WinFwProtocol::Udp: return ConditionProtocol::Udp();
+ case WinFwProtocol::Icmp: return ConditionProtocol::Icmp();
+ case WinFwProtocol::IcmpV6: return ConditionProtocol::IcmpV6();
+ default:
+ {
+ THROW_ERROR("Missing case handler in switch clause");
+ }
+ };
+}
+
+bool ProtocolHasPort(WinFwProtocol protocol)
+{
+ switch (protocol)
+ {
+ case WinFwProtocol::Tcp:
+ case WinFwProtocol::Udp:
+ return true;
+ case WinFwProtocol::Icmp:
+ case WinFwProtocol::IcmpV6:
+ return false;
+ default:
+ {
+ THROW_ERROR("Missing case handler in switch clause");
+ }
+ };
+}
+
}
diff --git a/windows/winfw/src/winfw/rules/shared.h b/windows/winfw/src/winfw/rules/shared.h
index 1b08d3ed02..4f4da187ca 100644
--- a/windows/winfw/src/winfw/rules/shared.h
+++ b/windows/winfw/src/winfw/rules/shared.h
@@ -1,6 +1,9 @@
#pragma once
#include <vector>
+#include <memory>
+#include <winfw/winfw.h>
+#include <libwfp/conditions/conditionprotocol.h>
#include <libwfp/ipaddress.h>
namespace rules
@@ -10,4 +13,8 @@ using IpSet = std::vector<wfp::IpAddress>;
void SplitAddresses(const IpSet &in, IpSet &outIpv4, IpSet &outIpv6);
+std::unique_ptr<wfp::conditions::ConditionProtocol> CreateProtocolCondition(WinFwProtocol protocol);
+
+bool ProtocolHasPort(WinFwProtocol protocol);
+
}
diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp
index ae0f0791de..4110dcd2f8 100644
--- a/windows/winfw/src/winfw/winfw.cpp
+++ b/windows/winfw/src/winfw/winfw.cpp
@@ -233,7 +233,8 @@ WinFw_ApplyPolicyConnecting(
const WinFwEndpoint *relay,
const wchar_t *relayClient,
const wchar_t *tunnelInterfaceAlias,
- const WinFwAllowedEndpoint *allowedEndpoint
+ const WinFwAllowedEndpoint *allowedEndpoint,
+ const WinFwAllowedTunnelTraffic *allowedTunnelTraffic
)
{
if (nullptr == g_fwContext)
@@ -258,12 +259,18 @@ WinFw_ApplyPolicyConnecting(
THROW_ERROR("Invalid argument: relayClient");
}
+ if (nullptr == allowedTunnelTraffic)
+ {
+ THROW_ERROR("Invalid argument: allowedTunnelTraffic");
+ }
+
return g_fwContext->applyPolicyConnecting(
*settings,
*relay,
relayClient,
tunnelInterfaceAlias != nullptr ? std::make_optional(tunnelInterfaceAlias) : std::nullopt,
- MakeOptional(allowedEndpoint)
+ MakeOptional(allowedEndpoint),
+ *allowedTunnelTraffic
) ? WINFW_POLICY_STATUS_SUCCESS : WINFW_POLICY_STATUS_GENERAL_FAILURE;
}
catch (common::error::WindowsException &err)
diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h
index 4913ae7baf..6394893d91 100644
--- a/windows/winfw/src/winfw/winfw.h
+++ b/windows/winfw/src/winfw/winfw.h
@@ -32,7 +32,9 @@ WinFwSettings;
enum WinFwProtocol : uint8_t
{
Tcp = 0,
- Udp = 1
+ Udp = 1,
+ Icmp = 2,
+ IcmpV6 = 3
};
typedef struct tag_WinFwEndpoint
@@ -55,6 +57,20 @@ typedef struct tag_WinFwAllowedEndpoint
}
WinFwAllowedEndpoint;
+enum WinFwAllowedTunnelTrafficType : uint8_t
+{
+ None,
+ All,
+ Only
+};
+
+typedef struct tag_WinFwAllowedTunnelTraffic
+{
+ WinFwAllowedTunnelTrafficType type;
+ WinFwEndpoint *endpoint;
+}
+WinFwAllowedTunnelTraffic;
+
///////////////////////////////////////////////////////////////////////////////
// Functions
///////////////////////////////////////////////////////////////////////////////
@@ -139,7 +155,7 @@ enum WINFW_POLICY_STATUS
// Apply restrictions in the firewall that block all traffic, except:
// - What is specified by settings
// - Communication with the relay server
-// - Non-DNS traffic inside the VPN tunnel
+// - Specified in-tunnel traffic, except DNS.
//
extern "C"
WINFW_LINKAGE
@@ -150,7 +166,8 @@ WinFw_ApplyPolicyConnecting(
const WinFwEndpoint *relay,
const wchar_t *relayClient,
const wchar_t *tunnelInterfaceAlias,
- const WinFwAllowedEndpoint *allowedEndpoint
+ const WinFwAllowedEndpoint *allowedEndpoint,
+ const WinFwAllowedTunnelTraffic *allowedTunnelTraffic
);
//
diff --git a/wireguard/libwg/libwg.go b/wireguard/libwg/libwg.go
index 82c6e8205f..e26ea7b7da 100644
--- a/wireguard/libwg/libwg.go
+++ b/wireguard/libwg/libwg.go
@@ -13,6 +13,7 @@ import (
"bufio"
"bytes"
"runtime"
+ "strings"
"unsafe"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer"
@@ -59,6 +60,27 @@ func wgGetConfig(tunnelHandle int32) *C.char {
return C.CString(settings.String())
}
+//export wgSetConfig
+func wgSetConfig(tunnelHandle int32, cSettings *C.char) int32 {
+ tunnel, err := tunnels.Get(tunnelHandle)
+ if err != nil {
+ return ERROR_GENERAL_FAILURE
+ }
+ if cSettings == nil {
+ tunnel.Logger.Errorf("cSettings is null\n")
+ return ERROR_GENERAL_FAILURE
+ }
+ settings := C.GoString(cSettings)
+
+ setError := tunnel.Device.IpcSetOperation(bufio.NewReader(strings.NewReader(settings)))
+ if setError != nil {
+ tunnel.Logger.Errorf("Failed to set device configuration\n")
+ tunnel.Logger.Errorf("%s\n", setError)
+ return ERROR_GENERAL_FAILURE
+ }
+ return 0
+}
+
//export wgFreePtr
func wgFreePtr(ptr unsafe.Pointer) {
C.free(ptr)