summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-03-14 13:40:36 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-14 13:40:36 +0100
commit6459ae7beefcc5f13eb54254dfe402dd807c62fe (patch)
treebc03c4027aad5c47f00dfa4c1fb3584dff4d1add
parent78dc4644a82d7b3fb904ef3cbac8a1f705f0a213 (diff)
parent3e1271777fd7556a76abc582bd3c44356ecbd15a (diff)
downloadmullvadvpn-6459ae7beefcc5f13eb54254dfe402dd807c62fe.tar.xz
mullvadvpn-6459ae7beefcc5f13eb54254dfe402dd807c62fe.zip
Merge branch 'device-api'
-rw-r--r--android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt14
-rw-r--r--dist-assets/linux/before-remove.sh2
-rwxr-xr-xdist-assets/uninstall_macos.sh2
-rw-r--r--dist-assets/windows/installer.nsh18
-rw-r--r--mullvad-cli/src/cmds/account.rs273
-rw-r--r--mullvad-cli/src/cmds/status.rs12
-rw-r--r--mullvad-cli/src/cmds/tunnel.rs13
-rw-r--r--mullvad-cli/src/format.rs21
-rw-r--r--mullvad-cli/src/main.rs3
-rw-r--r--mullvad-daemon/src/account.rs173
-rw-r--r--mullvad-daemon/src/device.rs1136
-rw-r--r--mullvad-daemon/src/lib.rs816
-rw-r--r--mullvad-daemon/src/management_interface.rs123
-rw-r--r--mullvad-daemon/src/migrations/mod.rs10
-rw-r--r--mullvad-daemon/src/migrations/v5.rs134
-rw-r--r--mullvad-daemon/src/relays/mod.rs114
-rw-r--r--mullvad-daemon/src/settings.rs17
-rw-r--r--mullvad-daemon/src/wireguard.rs499
-rw-r--r--mullvad-jni/src/daemon_interface.rs62
-rw-r--r--mullvad-jni/src/jni_event_listener.rs80
-rw-r--r--mullvad-jni/src/lib.rs225
-rw-r--r--mullvad-management-interface/proto/management_interface.proto79
-rw-r--r--mullvad-management-interface/src/types.rs114
-rw-r--r--mullvad-rpc/src/access.rs110
-rw-r--r--mullvad-rpc/src/availability.rs16
-rw-r--r--mullvad-rpc/src/device.rs196
-rw-r--r--mullvad-rpc/src/lib.rs222
-rw-r--r--mullvad-rpc/src/relay_list.rs4
-rw-r--r--mullvad-rpc/src/rest.rs103
-rw-r--r--mullvad-setup/src/main.rs84
-rw-r--r--mullvad-types/src/account.rs21
-rw-r--r--mullvad-types/src/device.rs142
-rw-r--r--mullvad-types/src/lib.rs1
-rw-r--r--mullvad-types/src/settings/mod.rs44
-rw-r--r--mullvad-types/src/states.rs8
-rw-r--r--mullvad-types/src/wireguard.rs21
-rw-r--r--talpid-core/src/mpsc.rs6
37 files changed, 3097 insertions, 1821 deletions
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt
index 4565844daa..8470f314d7 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt
@@ -45,7 +45,8 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
}
fun generateWireguardKey(): KeygenEvent? {
- return generateWireguardKey(daemonInterfaceAddress)
+ // TODO: remove
+ return null
}
fun getAccountData(accountToken: String): GetAccountDataResult {
@@ -85,6 +86,7 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
}
fun getWireguardKey(): PublicKey? {
+ // TODO: no longer needed
return getWireguardKey(daemonInterfaceAddress)
}
@@ -97,7 +99,7 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
}
fun setAccount(accountToken: String?) {
- setAccount(daemonInterfaceAddress, accountToken)
+ // TODO: replace with login+logout
}
fun setAllowLan(allowLan: Boolean) {
@@ -154,7 +156,6 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
private external fun connect(daemonInterfaceAddress: Long)
private external fun createNewAccount(daemonInterfaceAddress: Long): String?
private external fun disconnect(daemonInterfaceAddress: Long)
- private external fun generateWireguardKey(daemonInterfaceAddress: Long): KeygenEvent?
private external fun getAccountData(
daemonInterfaceAddress: Long,
accountToken: String
@@ -170,7 +171,8 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
private external fun getWireguardKey(daemonInterfaceAddress: Long): PublicKey?
private external fun reconnect(daemonInterfaceAddress: Long)
private external fun clearAccountHistory(daemonInterfaceAddress: Long)
- private external fun setAccount(daemonInterfaceAddress: Long, accountToken: String?)
+ private external fun loginAccount(daemonInterfaceAddress: Long, accountToken: String?)
+ private external fun logoutAccount(daemonInterfaceAddress: Long)
private external fun setAllowLan(daemonInterfaceAddress: Long, allowLan: Boolean)
private external fun setAutoConnect(daemonInterfaceAddress: Long, alwaysOn: Boolean)
private external fun setDnsOptions(daemonInterfaceAddress: Long, dnsOptions: DnsOptions)
@@ -190,10 +192,6 @@ class MullvadDaemon(val vpnService: MullvadVpnService) {
onAppVersionInfoChange?.invoke(appVersionInfo)
}
- private fun notifyKeygenEvent(event: KeygenEvent) {
- onKeygenEvent?.invoke(event)
- }
-
private fun notifyRelayListEvent(relayList: RelayList) {
onRelayListChange?.invoke(relayList)
}
diff --git a/dist-assets/linux/before-remove.sh b/dist-assets/linux/before-remove.sh
index 602d3d09fc..6d4ec5262f 100644
--- a/dist-assets/linux/before-remove.sh
+++ b/dist-assets/linux/before-remove.sh
@@ -26,4 +26,4 @@ fi
pkill -x "mullvad-gui" || true
/opt/Mullvad\ VPN/resources/mullvad-setup reset-firewall || echo "Failed to reset firewall"
-/opt/Mullvad\ VPN/resources/mullvad-setup remove-wireguard-key || echo "Failed to remove leftover WireGuard key"
+/opt/Mullvad\ VPN/resources/mullvad-setup remove-device || echo "Failed to remove device from account"
diff --git a/dist-assets/uninstall_macos.sh b/dist-assets/uninstall_macos.sh
index 7833ba528f..b7f27302f2 100755
--- a/dist-assets/uninstall_macos.sh
+++ b/dist-assets/uninstall_macos.sh
@@ -23,7 +23,7 @@ sudo dscl . -delete /groups/mullvad-exclusion || echo "Failed to remove 'mullvad
echo "Resetting firewall"
sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup reset-firewall
-sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup remove-wireguard-key
+sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup remove-device
echo "Removing zsh shell completion symlink ..."
sudo rm -f /usr/local/share/zsh/site-functions/_mullvad
diff --git a/dist-assets/windows/installer.nsh b/dist-assets/windows/installer.nsh
index 422d6d3169..b04fadfb2d 100644
--- a/dist-assets/windows/installer.nsh
+++ b/dist-assets/windows/installer.nsh
@@ -712,25 +712,25 @@
!define FirewallWarningCheck '!insertmacro "FirewallWarningCheck"'
#
-# RemoveWireGuardKey
+# RemoveCurrentDevice
#
-# Remove the WireGuard key from the account, if there is one
+# Remove the device from the account, if there is one
#
-!macro RemoveWireGuardKey
+!macro RemoveCurrentDevice
- log::Log "RemoveWireGuardKey()"
+ log::Log "RemoveCurrentDevice()"
Push $0
Push $1
- nsExec::ExecToStack '"$TEMP\mullvad-setup.exe" remove-wireguard-key'
+ nsExec::ExecToStack '"$TEMP\mullvad-setup.exe" remove-device'
Pop $0
Pop $1
${If} $0 != ${MVSETUP_OK}
- log::LogWithDetails "RemoveWireGuardKey() failed" $1
+ log::LogWithDetails "RemoveCurrentDevice() failed" $1
${Else}
- log::Log "RemoveWireGuardKey() completed successfully"
+ log::Log "RemoveCurrentDevice() completed successfully"
${EndIf}
Pop $1
@@ -738,7 +738,7 @@
!macroend
-!define RemoveWireGuardKey '!insertmacro "RemoveWireGuardKey"'
+!define RemoveCurrentDevice '!insertmacro "RemoveCurrentDevice"'
#
@@ -1170,7 +1170,7 @@
${If} $FullUninstall == 1
${ClearFirewallRules}
- ${RemoveWireGuardKey}
+ ${RemoveCurrentDevice}
${ExtractWireGuard}
${RemoveWintun}
diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs
index 0bbbc28024..b4ef7c7f14 100644
--- a/mullvad-cli/src/cmds/account.rs
+++ b/mullvad-cli/src/cmds/account.rs
@@ -1,9 +1,20 @@
use crate::{new_rpc_client, Command, Error, Result};
use itertools::Itertools;
-use mullvad_management_interface::{types::Timestamp, Code};
-use mullvad_types::account::AccountToken;
+use mullvad_management_interface::{
+ types::{self, Timestamp},
+ Code, ManagementServiceClient, Status,
+};
+use mullvad_types::{account::AccountToken, device::Device};
use std::io::{self, Write};
+const NOT_LOGGED_IN_ERROR: &str = "Not logged in to any account";
+const DEVICE_NOT_FOUND_ERROR: &str = "There is no such device";
+const INVALID_ACCOUNT_ERROR: &str = "The account does not exist";
+const TOO_MANY_DEVICES_ERROR: &str =
+ "There are too many devices on this account. Revoke one to log in";
+const ALREADY_LOGGED_IN_ERROR: &str =
+ "You are already logged in. Please log out before creating a new account";
+
pub struct Account;
#[mullvad_management_interface::async_trait]
@@ -16,23 +27,55 @@ impl Command for Account {
clap::App::new(self.name())
.about("Control and display information about your Mullvad account")
.setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::App::new("create").about("Create and log in to a new account"))
.subcommand(
- clap::App::new("set").about("Change account").arg(
- clap::Arg::new("token")
+ clap::App::new("login").about("Log in to an account").arg(
+ clap::Arg::new("account")
.help("The Mullvad account token to configure the client with")
.required(false),
),
)
+ .subcommand(clap::App::new("logout").about("Log out of the current account"))
.subcommand(
clap::App::new("get")
- .about("Display information about the currently configured account"),
+ .about("Display information about the current account")
+ .arg(
+ clap::Arg::new("verbose")
+ .long("verbose")
+ .short('v')
+ .help("Enables verbose output"),
+ ),
)
.subcommand(
- clap::App::new("unset").about("Removes the account number from the settings"),
+ clap::App::new("list-devices")
+ .about("List devices associated with an account")
+ .arg(
+ clap::Arg::new("account")
+ .help("Mullvad account number")
+ .long("account")
+ .takes_value(true),
+ )
+ .arg(
+ clap::Arg::new("verbose")
+ .long("verbose")
+ .short('v')
+ .help("Enables verbose output"),
+ ),
)
.subcommand(
- clap::App::new("create")
- .about("Creates a new account and sets it as the active one"),
+ clap::App::new("revoke-device")
+ .about("Revoke a device associated with an account")
+ .arg(
+ clap::Arg::new("account")
+ .help("Mullvad account number")
+ .long("account")
+ .takes_value(true),
+ )
+ .arg(
+ clap::Arg::new("device")
+ .help("Name or ID of the device to revoke")
+ .required(true),
+ ),
)
.subcommand(
clap::App::new("redeem").about("Redeems a voucher").arg(
@@ -44,29 +87,19 @@ impl Command for Account {
}
async fn run(&self, matches: &clap::ArgMatches) -> Result<()> {
- if let Some(set_matches) = matches.subcommand_matches("set") {
- let mut token = match set_matches.value_of("token") {
- Some(token) => token.to_string(),
- None => {
- let mut token = String::new();
- io::stdout()
- .write_all(b"Enter account token: ")
- .expect("Failed to write to STDOUT");
- let _ = io::stdout().flush();
- io::stdin()
- .read_line(&mut token)
- .expect("Failed to read from STDIN");
- token
- }
- };
- token = token.split_whitespace().join("").to_string();
- self.set(Some(token)).await
- } else if let Some(_matches) = matches.subcommand_matches("get") {
- self.get().await
- } else if let Some(_matches) = matches.subcommand_matches("unset") {
- self.set(None).await
- } else if let Some(_matches) = matches.subcommand_matches("create") {
+ if let Some(_matches) = matches.subcommand_matches("create") {
self.create().await
+ } else if let Some(set_matches) = matches.subcommand_matches("login") {
+ self.login(parse_token_else_stdin(set_matches)).await
+ } else if let Some(_matches) = matches.subcommand_matches("logout") {
+ self.logout().await
+ } else if let Some(set_matches) = matches.subcommand_matches("get") {
+ let verbose = set_matches.is_present("verbose");
+ self.get(verbose).await
+ } else if let Some(set_matches) = matches.subcommand_matches("list-devices") {
+ self.list_devices(set_matches).await
+ } else if let Some(set_matches) = matches.subcommand_matches("revoke-device") {
+ self.revoke_device(set_matches).await
} else if let Some(matches) = matches.subcommand_matches("redeem") {
let voucher = matches.value_of_t_or_exit("voucher");
self.redeem_voucher(voucher).await
@@ -77,24 +110,52 @@ impl Command for Account {
}
impl Account {
- async fn set(&self, token: Option<AccountToken>) -> Result<()> {
+ async fn create(&self) -> Result<()> {
let mut rpc = new_rpc_client().await?;
- rpc.set_account(token.clone().unwrap_or_default()).await?;
- if let Some(token) = token {
- println!("Mullvad account \"{}\" set", token);
- } else {
- println!("Mullvad account removed");
- }
+ rpc.create_new_account(()).await.map_err(map_device_error)?;
+ println!("New account created!");
+ self.get(false).await
+ }
+
+ async fn login(&self, token: AccountToken) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ rpc.login_account(token.clone())
+ .await
+ .map_err(map_device_error)?;
+ println!("Mullvad account \"{}\" set", token);
Ok(())
}
- async fn get(&self) -> Result<()> {
+ async fn logout(&self) -> Result<()> {
let mut rpc = new_rpc_client().await?;
- let settings = rpc.get_settings(()).await?.into_inner();
- if settings.account_token != "" {
- println!("Mullvad account: {}", settings.account_token);
+ rpc.logout_account(()).await?;
+ println!("Removed device from Mullvad account");
+ Ok(())
+ }
+
+ async fn get(&self, verbose: bool) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ let device = rpc
+ .get_device(())
+ .await
+ .map_err(|error| match error.code() {
+ Code::NotFound => Error::Other(NOT_LOGGED_IN_ERROR),
+ _other => map_device_error(error),
+ })?
+ .into_inner();
+ if !device.account_token.is_empty() {
+ println!("Mullvad account: {}", device.account_token);
+ let inner_device = Device::try_from(device.device.unwrap()).unwrap();
+ println!("Device name : {}", inner_device.pretty_name());
+ if verbose {
+ println!("Device id : {}", inner_device.id);
+ println!("Device pubkey : {}", inner_device.pubkey);
+ for port in inner_device.ports {
+ println!("Device port : {}", port);
+ }
+ }
let expiry = rpc
- .get_account_data(settings.account_token)
+ .get_account_data(device.account_token)
.await
.map_err(|error| Error::RpcFailedExt("Failed to fetch account data", error))?
.into_inner();
@@ -108,11 +169,88 @@ impl Account {
Ok(())
}
- async fn create(&self) -> Result<()> {
+ async fn list_devices(&self, matches: &clap::ArgMatches) -> Result<()> {
let mut rpc = new_rpc_client().await?;
- rpc.create_new_account(()).await?;
- println!("New account created!");
- self.get().await
+ let token = self.parse_account_else_current(&mut rpc, matches).await?;
+ let device_list = rpc
+ .list_devices(token)
+ .await
+ .map_err(map_device_error)?
+ .into_inner();
+
+ let verbose = matches.is_present("verbose");
+
+ println!("Devices on the account:");
+ for device in device_list.devices {
+ let device = Device::try_from(device.clone()).unwrap();
+ if verbose {
+ println!();
+ println!("Name : {}", device.pretty_name());
+ println!("Id : {}", device.id);
+ println!("Public key: {}", device.pubkey);
+ for port in device.ports {
+ println!("Port : {}", port);
+ }
+ } else {
+ println!("{}", device.pretty_name());
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn revoke_device(&self, matches: &clap::ArgMatches) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+
+ let token = self.parse_account_else_current(&mut rpc, matches).await?;
+ let device_to_revoke = parse_device_name(matches);
+
+ let device_list = rpc
+ .list_devices(token.clone())
+ .await
+ .map_err(map_device_error)?
+ .into_inner();
+ let device_id = device_list
+ .devices
+ .into_iter()
+ .find(|dev| {
+ dev.name.eq_ignore_ascii_case(&device_to_revoke)
+ || dev.id.eq_ignore_ascii_case(&device_to_revoke)
+ })
+ .map(|dev| dev.id)
+ .ok_or_else(|| Error::Other(DEVICE_NOT_FOUND_ERROR))?;
+
+ rpc.remove_device(types::DeviceRemoval {
+ account_token: token,
+ device_id,
+ })
+ .await
+ .map_err(map_device_error)?;
+ println!("Removed device");
+ Ok(())
+ }
+
+ async fn parse_account_else_current(
+ &self,
+ rpc: &mut ManagementServiceClient,
+ matches: &clap::ArgMatches,
+ ) -> Result<String> {
+ match matches.value_of("account").map(str::to_string) {
+ Some(token) => Ok(token),
+ None => {
+ let device = rpc
+ .get_device(())
+ .await
+ .map_err(|error| match error.code() {
+ mullvad_management_interface::Code::NotFound => {
+ Error::Other("Log in or specify an account")
+ }
+ _ => Error::RpcFailedExt("Failed to obtain device", error),
+ })?
+ .into_inner();
+ Ok(device.account_token)
+ }
+ }
}
async fn redeem_voucher(&self, mut voucher: String) -> Result<()> {
@@ -163,3 +301,46 @@ impl Account {
utc.with_timezone(&chrono::Local).to_string()
}
}
+
+fn map_device_error(error: Status) -> Error {
+ match error.code() {
+ Code::ResourceExhausted => Error::Other(TOO_MANY_DEVICES_ERROR),
+ Code::Unauthenticated => Error::Other(INVALID_ACCOUNT_ERROR),
+ Code::AlreadyExists => Error::Other(ALREADY_LOGGED_IN_ERROR),
+ Code::NotFound => Error::Other(DEVICE_NOT_FOUND_ERROR),
+ _other => Error::RpcFailed(error),
+ }
+}
+
+fn parse_token_else_stdin(matches: &clap::ArgMatches) -> String {
+ parse_from_match_else_stdin("Enter account number: ", "account", matches)
+ .split_whitespace()
+ .join("")
+}
+
+fn parse_device_name(matches: &clap::ArgMatches) -> String {
+ parse_from_match_else_stdin("Enter device name: ", "device", matches)
+ .trim()
+ .to_string()
+}
+
+fn parse_from_match_else_stdin(
+ prompt_str: &'static str,
+ key: &'static str,
+ matches: &clap::ArgMatches,
+) -> String {
+ match matches.value_of(key) {
+ Some(device) => device.to_string(),
+ None => {
+ let mut val = String::new();
+ io::stdout()
+ .write_all(prompt_str.as_bytes())
+ .expect("Failed to write to STDOUT");
+ let _ = io::stdout().flush();
+ io::stdin()
+ .read_line(&mut val)
+ .expect("Failed to read from STDIN");
+ val
+ }
+ }
+}
diff --git a/mullvad-cli/src/cmds/status.rs b/mullvad-cli/src/cmds/status.rs
index 8c4a929c30..69052dcaf1 100644
--- a/mullvad-cli/src/cmds/status.rs
+++ b/mullvad-cli/src/cmds/status.rs
@@ -1,4 +1,4 @@
-use crate::{format, format::print_keygen_event, new_rpc_client, Command, Error, Result};
+use crate::{format, new_rpc_client, Command, Error, Result};
use mullvad_management_interface::{
types::daemon_event::Event as EventType, ManagementServiceClient,
};
@@ -74,10 +74,14 @@ impl Command for Status {
println!("New app version info: {:#?}", app_version_info);
}
}
- EventType::KeyEvent(key_event) => {
+ EventType::Device(device) => {
if verbose {
- print!("Key event: ");
- print_keygen_event(&key_event);
+ println!("Device event: {:#?}", device);
+ }
+ }
+ EventType::RemoveDevice(device) => {
+ if verbose {
+ println!("Remove device event: {:#?}", device);
}
}
}
diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs
index f3b218648e..f01452a925 100644
--- a/mullvad-cli/src/cmds/tunnel.rs
+++ b/mullvad-cli/src/cmds/tunnel.rs
@@ -1,4 +1,4 @@
-use crate::{format::print_keygen_event, new_rpc_client, Command, Error, Result};
+use crate::{new_rpc_client, Command, Error, Result};
use mullvad_management_interface::types::{self, Timestamp, TunnelOptions};
use mullvad_types::wireguard::DEFAULT_ROTATION_INTERVAL;
use std::{convert::TryFrom, time::Duration};
@@ -246,20 +246,13 @@ impl Tunnel {
println!("No key is set");
return Ok(());
}
-
- let is_valid = rpc
- .verify_wireguard_key(())
- .await
- .map_err(|error| Error::RpcFailedExt("Failed to verify key", error))?
- .into_inner();
- println!("Key is valid for use with current account: {}", is_valid);
Ok(())
}
async fn process_wireguard_key_generate() -> Result<()> {
let mut rpc = new_rpc_client().await?;
- let keygen_event = rpc.generate_wireguard_key(()).await?;
- print_keygen_event(&keygen_event.into_inner());
+ rpc.rotate_wireguard_key(()).await?;
+ println!("Rotated WireGuard key");
Ok(())
}
diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs
index b056ffff53..eb91ffcca8 100644
--- a/mullvad-cli/src/format.rs
+++ b/mullvad-cli/src/format.rs
@@ -5,30 +5,11 @@ use mullvad_management_interface::types::{
},
tunnel_state,
tunnel_state::State::*,
- ErrorState, KeygenEvent, ProxyType, TransportProtocol, TunnelEndpoint, TunnelState, TunnelType,
+ ErrorState, ProxyType, TransportProtocol, TunnelEndpoint, TunnelState, TunnelType,
};
use mullvad_types::auth_failed::AuthFailed;
use std::fmt::Write;
-pub fn print_keygen_event(key_event: &KeygenEvent) {
- use mullvad_management_interface::types::keygen_event::KeygenEvent as EventType;
-
- match EventType::from_i32(key_event.event).unwrap() {
- EventType::NewKey => {
- println!(
- "New WireGuard key: {}",
- base64::encode(&key_event.new_key.as_ref().unwrap().key)
- );
- }
- EventType::TooManyKeys => {
- println!("Account has too many keys already");
- }
- EventType::GenerationFailure => {
- println!("Failed to generate new WireGuard key");
- }
- }
-}
-
pub fn print_state(state: &TunnelState) {
print!("Tunnel status: ");
match state.state.as_ref().unwrap() {
diff --git a/mullvad-cli/src/main.rs b/mullvad-cli/src/main.rs
index 55a195cdb8..df7ef0a04c 100644
--- a/mullvad-cli/src/main.rs
+++ b/mullvad-cli/src/main.rs
@@ -49,6 +49,9 @@ pub enum Error {
//#[cfg(all(unix, not(target_os = "android"))
#[error(display = "Failed to generate shell completions")]
CompletionsError(#[error(source, no_from)] io::Error),
+
+ #[error(display = "{}", _0)]
+ Other(&'static str),
}
#[tokio::main]
diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs
deleted file mode 100644
index f5655c9d1f..0000000000
--- a/mullvad-daemon/src/account.rs
+++ /dev/null
@@ -1,173 +0,0 @@
-use chrono::{DateTime, Utc};
-use futures::future::{abortable, AbortHandle};
-use mullvad_rpc::{
- availability::ApiAvailabilityHandle,
- rest::{self, Error as RestError, MullvadRestHandle},
- AccountsProxy,
-};
-use mullvad_types::account::{AccountToken, VoucherSubmission};
-use std::{future::Future, time::Duration};
-use talpid_core::future_retry::{
- constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered,
-};
-
-const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
-const RETRY_ACTION_MAX_RETRIES: usize = 2;
-
-const RETRY_EXPIRY_CHECK_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
-const RETRY_EXPIRY_CHECK_INTERVAL_FACTOR: u32 = 5;
-const RETRY_EXPIRY_CHECK_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-
-pub struct Account(());
-
-#[derive(Clone)]
-pub struct AccountHandle {
- api_availability: ApiAvailabilityHandle,
- initial_check_abort_handle: AbortHandle,
- proxy: AccountsProxy,
-}
-
-impl AccountHandle {
- pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.create_account(),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub fn get_www_auth_token(
- &self,
- account: AccountToken,
- ) -> impl Future<Output = Result<String, rest::Error>> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.get_www_auth_token(account.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.get_expiry(token.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if handle_expiry_result_inner(&result, &self.api_availability) {
- self.initial_check_abort_handle.abort();
- }
- result
- }
-
- pub async fn submit_voucher(
- &mut self,
- account_token: AccountToken,
- voucher: String,
- ) -> Result<VoucherSubmission, rest::Error> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if result.is_ok() {
- self.initial_check_abort_handle.abort();
- self.api_availability.resume_background();
- }
- result
- }
-
- fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
- match result {
- Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
- _ => false,
- }
- }
-}
-
-impl Account {
- pub fn new(
- runtime: tokio::runtime::Handle,
- rpc_handle: MullvadRestHandle,
- token: Option<String>,
- api_availability: ApiAvailabilityHandle,
- ) -> AccountHandle {
- let accounts_proxy = AccountsProxy::new(rpc_handle);
- api_availability.pause_background();
-
- let api_availability_copy = api_availability.clone();
- let accounts_proxy_copy = accounts_proxy.clone();
-
- let (future, initial_check_abort_handle) = abortable(async move {
- let token = if let Some(token) = token {
- token
- } else {
- api_availability.pause_background();
- return;
- };
-
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(
- RETRY_EXPIRY_CHECK_INTERVAL_INITIAL,
- RETRY_EXPIRY_CHECK_INTERVAL_FACTOR,
- )
- .max_delay(RETRY_EXPIRY_CHECK_INTERVAL_MAX),
- );
- let future_generator = move || {
- let wait_online = api_availability.wait_online();
- let expiry_fut = accounts_proxy.get_expiry(token.clone());
- let api_availability_copy = api_availability.clone();
- async move {
- let _ = wait_online.await;
- handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy)
- }
- };
- let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
- retry_future(future_generator, should_retry, retry_strategy).await;
- });
- runtime.spawn(future);
-
- AccountHandle {
- api_availability: api_availability_copy,
- initial_check_abort_handle,
- proxy: accounts_proxy_copy,
- }
- }
-}
-
-fn handle_expiry_result_inner(
- result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>,
- api_availability: &ApiAvailabilityHandle,
-) -> bool {
- match result {
- Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
- api_availability.resume_background();
- true
- }
- Ok(_expiry) => {
- api_availability.pause_background();
- true
- }
- Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => {
- if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH {
- api_availability.pause_background();
- return true;
- }
- false
- }
- Err(_) => false,
- }
-}
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs
new file mode 100644
index 0000000000..868dc003d8
--- /dev/null
+++ b/mullvad-daemon/src/device.rs
@@ -0,0 +1,1136 @@
+use chrono::{DateTime, Utc};
+use futures::{
+ channel::{mpsc, oneshot},
+ future::{abortable, AbortHandle},
+ stream::StreamExt,
+};
+use mullvad_rpc::{
+ availability::ApiAvailabilityHandle,
+ rest::{self, Error as RestError, MullvadRestHandle},
+ AccountsProxy, DevicesProxy,
+};
+use mullvad_types::{
+ account::{AccountToken, VoucherSubmission},
+ device::{Device, DeviceData, DeviceEvent, DeviceId},
+ wireguard::{RotationInterval, WireguardData},
+};
+use std::{
+ future::Future,
+ path::Path,
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+ },
+ time::{Duration, SystemTime},
+};
+use talpid_core::{
+ future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered},
+ mpsc::Sender,
+};
+use talpid_types::{
+ net::{wireguard::PrivateKey, TunnelType},
+ tunnel::TunnelStateTransition,
+ ErrorExt,
+};
+use tokio::{
+ fs,
+ io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
+};
+
+/// How often to check whether the key has expired.
+/// A short interval is used in case the computer is ever suspended.
+const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(5 * 60);
+
+/// File that used to store account and device data.
+const DEVICE_CACHE_FILENAME: &str = "device.json";
+
+const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
+const RETRY_ACTION_MAX_RETRIES: usize = 2;
+
+const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
+const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5;
+const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
+
+/// How long to keep the known status for [AccountManagerHandle::validate_device].
+const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// How long to wait on logout (device removal) before letting it continue as a background task.
+const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2);
+
+/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts
+/// to set up a WireGuard tunnel.
+const WG_DEVICE_CHECK_THRESHOLD: usize = 3;
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "The account already has a maximum number of devices")]
+ MaxDevicesReached,
+ #[error(display = "No device is set")]
+ NoDevice,
+ #[error(display = "Device not found")]
+ InvalidDevice,
+ #[error(display = "Invalid account")]
+ InvalidAccount,
+ #[error(display = "Failed to read or write device cache")]
+ DeviceIoError(#[error(source)] io::Error),
+ #[error(display = "Failed parse device cache")]
+ ParseDeviceCache(#[error(source)] serde_json::Error),
+ #[error(display = "Unexpected HTTP request error")]
+ OtherRestError(#[error(source)] rest::Error),
+ #[error(display = "The device update task is not running")]
+ DeviceUpdaterCancelled(#[error(source)] oneshot::Canceled),
+ #[error(display = "The account manager is down")]
+ AccountManagerDown,
+}
+
+#[derive(Clone)]
+pub(crate) enum InnerDeviceEvent {
+ /// The device was removed due to user (or daemon) action.
+ Logout,
+ /// Logged in to a new device.
+ Login(DeviceData),
+ /// The device was updated remotely, but not its key.
+ Updated(DeviceData),
+ /// The key was rotated.
+ RotatedKey(DeviceData),
+ /// Device was removed because it was not found remotely.
+ Revoked,
+}
+
+impl From<InnerDeviceEvent> for DeviceEvent {
+ fn from(event: InnerDeviceEvent) -> DeviceEvent {
+ match event {
+ InnerDeviceEvent::Logout => DeviceEvent::revoke(false),
+ InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true),
+ InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Revoked => DeviceEvent::revoke(true),
+ }
+ }
+}
+
+impl InnerDeviceEvent {
+ fn data(&self) -> Option<&DeviceData> {
+ match self {
+ InnerDeviceEvent::Login(data) => Some(&data),
+ InnerDeviceEvent::Updated(data) => Some(&data),
+ InnerDeviceEvent::RotatedKey(data) => Some(&data),
+ InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
+ }
+ }
+
+ fn into_data(self) -> Option<DeviceData> {
+ match self {
+ InnerDeviceEvent::Login(data) => Some(data),
+ InnerDeviceEvent::Updated(data) => Some(data),
+ InnerDeviceEvent::RotatedKey(data) => Some(data),
+ InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
+ }
+ }
+}
+
+impl Error {
+ pub fn is_network_error(&self) -> bool {
+ if let Error::OtherRestError(error) = self {
+ error.is_network_error()
+ } else {
+ false
+ }
+ }
+}
+
+pub enum ValidationResult {
+ /// The device and key were valid.
+ Valid,
+ /// The device was valid but the key was replaced
+ RotatedKey,
+ /// The device was valid but one or more fields, such as ports, were replaced
+ Updated,
+ /// The device was not found remotely and was removed from the cache.
+ Removed,
+}
+
+type ResponseTx<T> = oneshot::Sender<Result<T, Error>>;
+
+enum AccountManagerCommand {
+ Login(AccountToken, ResponseTx<()>),
+ Logout(ResponseTx<()>),
+ SetData(DeviceData, ResponseTx<()>),
+ GetData(ResponseTx<Option<DeviceData>>),
+ RotateKey(ResponseTx<()>),
+ SetRotationInterval(RotationInterval, ResponseTx<()>),
+ GetRotationInterval(ResponseTx<RotationInterval>),
+ ValidateDevice(ResponseTx<ValidationResult>),
+ ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>),
+ Shutdown(oneshot::Sender<()>),
+}
+
+#[derive(Clone)]
+pub(crate) struct AccountManagerHandle {
+ cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>,
+ pub account_service: AccountService,
+ pub device_service: DeviceService,
+}
+
+impl AccountManagerHandle {
+ pub async fn login(&self, token: AccountToken) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Login(token, tx))
+ .await
+ }
+
+ pub async fn logout(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Logout(tx))
+ .await
+ }
+
+ pub async fn set(&self, data: DeviceData) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetData(data, tx))
+ .await
+ }
+
+ pub async fn data(&self) -> Result<Option<DeviceData>, Error> {
+ self.send_command(|tx| AccountManagerCommand::GetData(tx))
+ .await
+ }
+
+ pub async fn rotate_key(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::RotateKey(tx))
+ .await
+ }
+
+ pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx))
+ .await
+ }
+
+ pub async fn rotation_interval(&self) -> Result<RotationInterval, Error> {
+ self.send_command(|tx| AccountManagerCommand::GetRotationInterval(tx))
+ .await
+ }
+
+ pub async fn validate_device(&self) -> Result<ValidationResult, Error> {
+ self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx))
+ .await
+ }
+
+ pub async fn receive_events(
+ &self,
+ events_tx: impl Sender<InnerDeviceEvent> + Send + 'static,
+ ) -> Result<(), Error> {
+ self.send_command(|tx| {
+ AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx)
+ })
+ .await
+ }
+
+ pub async fn shutdown(self) {
+ let (tx, rx) = oneshot::channel();
+ let _ = self
+ .cmd_tx
+ .unbounded_send(AccountManagerCommand::Shutdown(tx));
+ let _ = rx.await;
+ }
+
+ async fn send_command<T>(
+ &self,
+ make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand,
+ ) -> Result<T, Error> {
+ let (tx, rx) = oneshot::channel();
+ self.cmd_tx
+ .unbounded_send(make_cmd(tx))
+ .map_err(|_| Error::AccountManagerDown)?;
+ rx.await.map_err(|_| Error::AccountManagerDown)?
+ }
+}
+
+pub(crate) struct AccountManager {
+ cacher: DeviceCacher,
+ device_service: DeviceService,
+ data: Option<DeviceData>,
+ rotation_interval: RotationInterval,
+ listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>,
+ last_validation: Option<SystemTime>,
+}
+
+impl AccountManager {
+ pub async fn spawn(
+ rest_handle: rest::MullvadRestHandle,
+ api_availability: ApiAvailabilityHandle,
+ settings_dir: &Path,
+ initial_rotation_interval: RotationInterval,
+ ) -> Result<AccountManagerHandle, Error> {
+ let (cacher, data) = DeviceCacher::new(settings_dir).await?;
+ let token = data.as_ref().map(|state| state.token.clone());
+ let account_service =
+ spawn_account_service(rest_handle.clone(), token, api_availability.clone());
+
+ let (cmd_tx, cmd_rx) = mpsc::unbounded();
+
+ let device_service = DeviceService::new(rest_handle, api_availability);
+ let manager = AccountManager {
+ cacher,
+ device_service: device_service.clone(),
+ data,
+ rotation_interval: initial_rotation_interval,
+ listeners: vec![],
+ last_validation: None,
+ };
+
+ tokio::spawn(manager.run(cmd_rx));
+ let handle = AccountManagerHandle {
+ cmd_tx,
+ account_service,
+ device_service,
+ };
+ KeyUpdater::spawn(handle.clone()).await?;
+ Ok(handle)
+ }
+
+ async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) {
+ let mut shutdown_tx = None;
+ while let Some(cmd) = cmd_rx.next().await {
+ match cmd {
+ AccountManagerCommand::Shutdown(tx) => {
+ shutdown_tx = Some(tx);
+ break;
+ }
+ other => self.service_command(other).await,
+ }
+ }
+ self.shutdown().await;
+ if let Some(tx) = shutdown_tx {
+ let _ = tx.send(());
+ }
+ log::debug!("Account manager has stopped");
+ }
+
+ async fn service_command(&mut self, cmd: AccountManagerCommand) {
+ match cmd {
+ AccountManagerCommand::Login(token, tx) => {
+ let _ = tx.send(self.login(token).await);
+ }
+ AccountManagerCommand::Logout(tx) => {
+ let _ = tx.send(self.logout().await);
+ }
+ AccountManagerCommand::SetData(data, tx) => {
+ let _ = tx.send(self.set(InnerDeviceEvent::Login(data)).await);
+ }
+ AccountManagerCommand::GetData(tx) => {
+ let _ = tx.send(Ok(self.data.clone()));
+ }
+ AccountManagerCommand::RotateKey(tx) => {
+ let _ = tx.send(self.rotate_key().await);
+ }
+ AccountManagerCommand::SetRotationInterval(interval, tx) => {
+ self.rotation_interval = interval;
+ let _ = tx.send(Ok(()));
+ }
+ AccountManagerCommand::GetRotationInterval(tx) => {
+ let _ = tx.send(Ok(self.rotation_interval));
+ }
+ AccountManagerCommand::ValidateDevice(tx) => {
+ let _ = tx.send(self.validate_device().await);
+ }
+ AccountManagerCommand::ReceiveEvents(events_tx, tx) => {
+ let _ = tx.send(Ok(self.listeners.push(events_tx)));
+ }
+ AccountManagerCommand::Shutdown(_) => unreachable!("shutdown is handled earlier"),
+ }
+ }
+
+ async fn login(&mut self, token: AccountToken) -> Result<(), Error> {
+ let data = self.device_service.generate_for_account(token).await?;
+ self.set(InnerDeviceEvent::Login(data)).await?;
+ Ok(())
+ }
+
+ async fn logout(&mut self) -> Result<(), Error> {
+ if self.data.is_some() {
+ self.cacher.write(None).await?;
+ let _ = tokio::time::timeout(LOGOUT_TIMEOUT, self.logout_inner()).await;
+
+ let event = InnerDeviceEvent::Logout;
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+ }
+ Ok(())
+ }
+
+ async fn logout_inner(&mut self) -> tokio::task::JoinHandle<()> {
+ let prev_data = self.data.take();
+ let service = self.device_service.clone();
+
+ tokio::spawn(async move {
+ if let Some(data) = prev_data {
+ if let Err(error) = service
+ .remove_device_with_backoff(data.token, data.device.id)
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to remove a previous device")
+ );
+ }
+ }
+ })
+ }
+
+ async fn set(&mut self, event: InnerDeviceEvent) -> Result<(), Error> {
+ let data = event.data();
+ if data == self.data.as_ref() {
+ return Ok(());
+ }
+
+ self.cacher.write(data).await?;
+ self.last_validation = None;
+
+ if self
+ .data
+ .as_ref()
+ .map(|current| data.as_ref().map(|d| &d.device.id) != Some(&current.device.id))
+ .unwrap_or(false)
+ {
+ // Remove the existing device if its ID differs. Otherwise, only update
+ // the data.
+ self.logout_inner().await;
+ }
+
+ self.data = data.cloned();
+
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+
+ Ok(())
+ }
+
+ async fn rotate_key(&mut self) -> Result<(), Error> {
+ // TODO: Update all data opportunistically?
+ let data = self.data.as_ref().ok_or(Error::NoDevice)?;
+
+ let wg_data = self
+ .device_service
+ .rotate_key(data.token.clone(), data.device.id.clone())
+ .await?;
+
+ // Copy the data to keep a predictable state if an error occurs.
+ let mut new_data = data.clone();
+ new_data.device.pubkey = wg_data.private_key.public_key();
+ new_data.wg_data = wg_data;
+ self.set(InnerDeviceEvent::RotatedKey(new_data)).await
+ }
+
+ /// Check if the device is valid for the account, and yank it if it no longer exists.
+ /// This also updates any associated data and returns whether it changed.
+ async fn validate_device(&mut self) -> Result<ValidationResult, Error> {
+ log::debug!("Checking whether the device is still valid");
+
+ if let Some(result) = self.cached_validation() {
+ log::debug!("The current device is still valid");
+ return Ok(result);
+ }
+
+ let data = self.data.as_ref().ok_or(Error::NoDevice)?;
+
+ match self
+ .device_service
+ .get(data.token.clone(), data.device.id.clone())
+ .await
+ {
+ Ok(device) => {
+ if device.pubkey == data.device.pubkey {
+ if device == data.device {
+ log::debug!("The current device is still valid");
+ Ok(ValidationResult::Valid)
+ } else {
+ log::debug!("Updating data for the current device");
+ // Copy the data to keep a predictable state if an error occurs.
+ let new_data = DeviceData {
+ device,
+ ..data.clone()
+ };
+ self.set(InnerDeviceEvent::Updated(new_data)).await?;
+ Ok(ValidationResult::Updated)
+ }
+ } else {
+ log::debug!("Rotating invalid WireGuard key");
+ self.rotate_key().await?;
+ Ok(ValidationResult::RotatedKey)
+ }
+ }
+ Err(Error::InvalidAccount) | Err(Error::InvalidDevice) => {
+ log::debug!("The current device is no longer valid for this account");
+
+ self.cacher.write(None).await?;
+ self.data = None;
+
+ let event = InnerDeviceEvent::Revoked;
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+
+ Ok(ValidationResult::Removed)
+ }
+ Err(error) => Err(error),
+ }
+ }
+
+ fn cached_validation(&mut self) -> Option<ValidationResult> {
+ if self.data.is_none() {
+ return None;
+ }
+
+ let now = SystemTime::now();
+
+ let elapsed = self
+ .last_validation
+ .and_then(|last_check| now.duration_since(last_check).ok())
+ .unwrap_or(VALIDITY_CACHE_TIMEOUT);
+
+ if elapsed >= VALIDITY_CACHE_TIMEOUT {
+ self.last_validation = Some(now);
+ return None;
+ }
+
+ Some(ValidationResult::Valid)
+ }
+
+ async fn shutdown(self) {
+ self.cacher.finalize().await;
+ }
+}
+
+struct KeyUpdater {
+ handle: AccountManagerHandle,
+ rx: mpsc::UnboundedReceiver<InnerDeviceEvent>,
+ data: Option<DeviceData>,
+}
+
+impl KeyUpdater {
+ async fn spawn(handle: AccountManagerHandle) -> Result<(), Error> {
+ let (tx, rx) = mpsc::unbounded();
+ handle.receive_events(tx).await?;
+ let data = handle.data().await?;
+ let mut key_rotator = KeyUpdater { handle, rx, data };
+
+ tokio::spawn(async move {
+ loop {
+ tokio::time::sleep(KEY_CHECK_INTERVAL).await;
+
+ if let Err(error) = key_rotator.check_key_validity().await {
+ if let Error::AccountManagerDown = error {
+ break;
+ }
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Stopping key rotation task due to an error")
+ );
+ break;
+ }
+ }
+ log::debug!("Stopping key updater");
+ });
+
+ Ok(())
+ }
+
+ async fn check_key_validity(&mut self) -> Result<(), Error> {
+ let rotation_interval = self.handle.rotation_interval().await?;
+ let data = self.wait_for_data().await?;
+
+ if (chrono::Utc::now()
+ .signed_duration_since(data.wg_data.created)
+ .num_seconds() as u64)
+ < rotation_interval.as_duration().as_secs()
+ {
+ return Ok(());
+ }
+
+ let mut data = data.clone();
+
+ let rotation_fut = self
+ .handle
+ .device_service
+ .rotate_key_with_backoff(data.token.clone(), data.device.id.clone());
+
+ match futures::future::select(Box::pin(rotation_fut), self.rx.next()).await {
+ futures::future::Either::Left((Ok(wg_data), _)) => {
+ log::debug!("Rotating WireGuard key");
+ data.device.pubkey = wg_data.private_key.public_key();
+ data.wg_data = wg_data;
+ self.handle.set(data).await?;
+ }
+ futures::future::Either::Left((Err(error), _)) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Stopping key rotation due to an error")
+ );
+
+ // Forget the current device. Key rotation will restart when
+ // it is updated in any way.
+ self.data = None;
+ }
+ futures::future::Either::Right((event, _)) => {
+ // Abort key rotation if the device changed
+ if let Some(event) = event {
+ self.data = event.into_data();
+ } else {
+ return Err(Error::AccountManagerDown);
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn wait_for_data(&mut self) -> Result<&DeviceData, Error> {
+ while let Ok(item) = self.rx.try_next() {
+ match item {
+ Some(event) => {
+ self.data = event.into_data();
+ }
+ None => return Err(Error::AccountManagerDown),
+ }
+ }
+
+ match self.data {
+ Some(ref data) => Ok(data),
+ None => loop {
+ let event = self.rx.next().await;
+ match event {
+ Some(event) => {
+ if let Some(data) = event.into_data() {
+ self.data = Some(data);
+ break Ok(self.data.as_ref().unwrap());
+ }
+ }
+ None => break Err(Error::AccountManagerDown),
+ }
+ },
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct DeviceService {
+ api_availability: ApiAvailabilityHandle,
+ proxy: DevicesProxy,
+}
+
+impl DeviceService {
+ pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
+ Self {
+ proxy: DevicesProxy::new(handle),
+ api_availability,
+ }
+ }
+
+ /// Generate a new device for a given token
+ pub async fn generate_for_account(&self, token: AccountToken) -> Result<DeviceData, Error> {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ let (device, addresses) = retry_future_n(
+ move || proxy.create(token_copy.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+
+ pub async fn generate_for_account_with_backoff(
+ &self,
+ token: AccountToken,
+ ) -> Result<DeviceData, Error> {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ let (device, addresses) = retry_future(
+ move || api_handle.when_online(proxy.create(token_copy.clone(), pubkey.clone())),
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+
+ pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.remove(token.clone(), device.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+ Ok(())
+ }
+
+ pub async fn remove_device_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ ), // Not setting a maximum interval
+ );
+
+ retry_future(
+ // NOTE: Not honoring "paused" state, because the account may have no time on it.
+ move || api_handle.when_online(proxy.remove(token.clone(), device.clone())),
+ should_retry_backoff,
+ retry_strategy,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(())
+ }
+
+ pub async fn rotate_key(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+ let addresses = retry_future_n(
+ move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn rotate_key_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+
+ let addresses = retry_future(
+ move || {
+ api_handle.when_bg_resumes(proxy.replace_wg_key(
+ token.clone(),
+ device.clone(),
+ pubkey.clone(),
+ ))
+ },
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.list(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+
+ pub async fn list_devices_with_backoff(
+ &self,
+ token: AccountToken,
+ ) -> Result<Vec<Device>, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+
+ retry_future(
+ move || api_handle.when_online(proxy.list(token.clone())),
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+
+ pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result<Device, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.get(token.clone(), device.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+}
+
+pub struct DeviceCacher {
+ file: io::BufWriter<fs::File>,
+ path: std::path::PathBuf,
+}
+
+impl DeviceCacher {
+ pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> {
+ let mut options = std::fs::OpenOptions::new();
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::OpenOptionsExt;
+ options.mode(0o600);
+ }
+ #[cfg(windows)]
+ {
+ use std::os::windows::fs::OpenOptionsExt;
+ // exclusive access
+ options.share_mode(0);
+ }
+
+ let path = settings_dir.join(DEVICE_CACHE_FILENAME);
+ let cache_exists = path.is_file();
+
+ let mut file = fs::OpenOptions::from(options)
+ .write(true)
+ .read(true)
+ .create(true)
+ .open(&path)
+ .await?;
+
+ let device: Option<DeviceData> = if cache_exists {
+ let mut reader = io::BufReader::new(&mut file);
+ let mut buffer = String::new();
+ reader.read_to_string(&mut buffer).await?;
+ if !buffer.is_empty() {
+ serde_json::from_str(&buffer)?
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ Ok((
+ DeviceCacher {
+ file: io::BufWriter::new(file),
+ path,
+ },
+ device,
+ ))
+ }
+
+ pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> {
+ let data = serde_json::to_vec_pretty(&device).unwrap();
+
+ self.file.get_mut().set_len(0).await?;
+ self.file.seek(io::SeekFrom::Start(0)).await?;
+ self.file.write_all(&data).await?;
+ self.file.flush().await?;
+ self.file.get_mut().sync_data().await?;
+
+ Ok(())
+ }
+
+ pub async fn remove(self) -> Result<(), Error> {
+ let path = {
+ let DeviceCacher { path, file } = self;
+ let std_file = file.into_inner().into_std().await;
+ let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
+ path
+ };
+ tokio::fs::remove_file(path).await?;
+ Ok(())
+ }
+
+ async fn finalize(self) {
+ let std_file = self.file.into_inner().into_std().await;
+ let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
+ }
+}
+
+#[derive(Clone)]
+pub struct AccountService {
+ api_availability: ApiAvailabilityHandle,
+ initial_check_abort_handle: AbortHandle,
+ proxy: AccountsProxy,
+}
+
+impl AccountService {
+ pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.create_account(),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub fn get_www_auth_token(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<String, rest::Error>> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.get_www_auth_token(account.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.get_expiry(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if handle_expiry_result_inner(&result, &self.api_availability) {
+ self.initial_check_abort_handle.abort();
+ }
+ result
+ }
+
+ pub async fn submit_voucher(
+ &mut self,
+ account_token: AccountToken,
+ voucher: String,
+ ) -> Result<VoucherSubmission, rest::Error> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if result.is_ok() {
+ self.initial_check_abort_handle.abort();
+ self.api_availability.resume_background();
+ }
+ result
+ }
+}
+
+pub fn spawn_account_service(
+ rpc_handle: MullvadRestHandle,
+ token: Option<String>,
+ api_availability: ApiAvailabilityHandle,
+) -> AccountService {
+ let accounts_proxy = AccountsProxy::new(rpc_handle);
+ api_availability.pause_background();
+
+ let api_availability_copy = api_availability.clone();
+ let accounts_proxy_copy = accounts_proxy.clone();
+
+ let (future, initial_check_abort_handle) = abortable(async move {
+ let token = if let Some(token) = token {
+ token
+ } else {
+ api_availability.pause_background();
+ return;
+ };
+
+ let future_generator = move || {
+ let expiry_fut = api_availability.when_online(accounts_proxy.get_expiry(token.clone()));
+ let api_availability_copy = api_availability.clone();
+ async move { handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) }
+ };
+ let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
+ retry_future(future_generator, should_retry, retry_strategy()).await;
+ });
+ tokio::spawn(future);
+
+ AccountService {
+ api_availability: api_availability_copy,
+ initial_check_abort_handle,
+ proxy: accounts_proxy_copy,
+ }
+}
+
+fn handle_expiry_result_inner(
+ result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>,
+ api_availability: &ApiAvailabilityHandle,
+) -> bool {
+ match result {
+ Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
+ api_availability.resume_background();
+ true
+ }
+ Ok(_expiry) => {
+ api_availability.pause_background();
+ true
+ }
+ Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => {
+ if code == mullvad_rpc::INVALID_ACCOUNT {
+ api_availability.pause_background();
+ return true;
+ }
+ false
+ }
+ Err(_) => false,
+ }
+}
+
+fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
+ match result {
+ Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
+ _ => false,
+ }
+}
+
+fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
+ match result {
+ Ok(_) => false,
+ Err(error) => {
+ if let RestError::ApiError(status, code) = error {
+ *status != rest::StatusCode::NOT_FOUND
+ && code != mullvad_rpc::INVALID_ACCOUNT
+ && code != mullvad_rpc::MAX_DEVICES_REACHED
+ && code != mullvad_rpc::PUBKEY_IN_USE
+ } else {
+ true
+ }
+ }
+ }
+}
+
+fn map_rest_error(error: rest::Error) -> Error {
+ match error {
+ RestError::ApiError(status, ref code) => {
+ if status == rest::StatusCode::NOT_FOUND {
+ return Error::InvalidDevice;
+ }
+ match code.as_str() {
+ mullvad_rpc::INVALID_ACCOUNT => Error::InvalidAccount,
+ mullvad_rpc::MAX_DEVICES_REACHED => Error::MaxDevicesReached,
+ _ => Error::OtherRestError(error),
+ }
+ }
+ error => Error::OtherRestError(error),
+ }
+}
+
+fn retry_strategy() -> Jittered<ExponentialBackoff> {
+ Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ )
+ .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
+ )
+}
+
+/// Checks if the current device is valid if a WireGuard tunnel cannot be set up
+/// after multiple attempts.
+pub(crate) struct TunnelStateChangeHandler {
+ manager: AccountManagerHandle,
+ check_validity: Arc<AtomicBool>,
+ wg_retry_attempt: usize,
+}
+
+impl TunnelStateChangeHandler {
+ pub fn new(manager: AccountManagerHandle) -> Self {
+ Self {
+ manager,
+ check_validity: Arc::new(AtomicBool::new(true)),
+ wg_retry_attempt: 0,
+ }
+ }
+
+ pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) {
+ match new_state {
+ TunnelStateTransition::Connecting(endpoint) => {
+ if endpoint.tunnel_type != TunnelType::Wireguard {
+ return;
+ }
+ self.wg_retry_attempt += 1;
+ if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 {
+ let handle = self.manager.clone();
+ let check_validity = self.check_validity.clone();
+ tokio::spawn(async move {
+ if !check_validity.swap(false, Ordering::SeqCst) {
+ return;
+ }
+ if let Err(error) = handle.validate_device().await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check device validity")
+ );
+ if error.is_network_error() {
+ check_validity.store(true, Ordering::SeqCst);
+ }
+ }
+ });
+ }
+ }
+ TunnelStateTransition::Connected(_) | TunnelStateTransition::Disconnected => {
+ self.check_validity.store(true, Ordering::SeqCst);
+ self.wg_retry_attempt = 0;
+ }
+ _ => (),
+ }
+ }
+}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 0d9ec96c87..3f1d363694 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -4,9 +4,9 @@
#[macro_use]
extern crate serde;
-mod account;
pub mod account_history;
mod api;
+pub mod device;
pub mod exception_logging;
#[cfg(target_os = "macos")]
pub mod exclusion_gid;
@@ -25,6 +25,7 @@ pub mod version;
mod version_check;
use crate::target_state::PersistentTargetState;
+use device::InnerDeviceEvent;
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Future},
@@ -36,6 +37,7 @@ use mullvad_rpc::{
};
use mullvad_types::{
account::{AccountData, AccountToken, VoucherSubmission},
+ device::{Device, DeviceConfig, DeviceData, DeviceEvent, DeviceId, RemoveDeviceEvent},
endpoint::MullvadEndpoint,
location::{Coordinates, GeoIpLocation},
relay_constraints::{
@@ -46,7 +48,7 @@ use mullvad_types::{
settings::{DnsOptions, DnsState, Settings},
states::{TargetState, TunnelState},
version::{AppVersion, AppVersionInfo},
- wireguard::{KeygenEvent, RotationInterval},
+ wireguard::{PublicKey, RotationInterval},
};
use settings::SettingsPersister;
#[cfg(target_os = "android")]
@@ -75,7 +77,7 @@ use talpid_types::android::AndroidContext;
use talpid_types::{
net::{
openvpn::{self, ProxySettings},
- TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType,
+ wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType,
},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
@@ -84,12 +86,6 @@ use talpid_types::{
use tokio::fs;
use tokio::io;
-#[path = "wireguard.rs"]
-mod wireguard;
-
-/// Timeout for first WireGuard key pushing
-const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5);
-
/// Delay between generating a new WireGuard key and reconnecting
const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60);
@@ -124,17 +120,35 @@ pub enum Error {
#[error(display = "Unable to load account history")]
LoadAccountHistory(#[error(source)] account_history::Error),
+ #[error(display = "Failed to start account manager")]
+ LoadAccountManager(#[error(source)] device::Error),
+
+ #[error(display = "Failed to log in to account")]
+ LoginError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to log out of account")]
+ LogoutError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to rotate WireGuard key")]
+ KeyRotationError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to list devices")]
+ ListDevicesError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to remove device")]
+ RemoveDeviceError(#[error(source)] device::Error),
+
#[cfg(target_os = "linux")]
#[error(display = "Unable to initialize split tunneling")]
InitSplitTunneling(#[error(source)] split_tunnel::Error),
- #[error(display = "The account has too many wireguard keys")]
- TooManyKeys,
-
#[cfg(windows)]
#[error(display = "Split tunneling error")]
SplitTunnelError(#[error(source)] split_tunnel::Error),
+ #[error(display = "An account is already set")]
+ AlreadyLoggedIn,
+
#[error(display = "No wireguard private key available")]
NoKeyAvailable,
@@ -226,8 +240,16 @@ pub enum DaemonCommand {
/// Trigger an asynchronous relay list update. This returns before the relay list is actually
/// updated.
UpdateRelayLocations,
- /// Set which account token to use for subsequent connection attempts.
- SetAccount(ResponseTx<(), settings::Error>, Option<AccountToken>),
+ /// Log in with a given account and create a new device.
+ LoginAccount(ResponseTx<(), Error>, AccountToken),
+ /// Log out of the current account and remove the device, if they exist.
+ LogoutAccount(ResponseTx<(), Error>),
+ /// Return the current device configuration, if there is one.
+ GetDevice(ResponseTx<Option<DeviceConfig>, Error>),
+ /// Return all the devices for a given account token.
+ ListDevices(ResponseTx<Vec<Device>, Error>, AccountToken),
+ /// Remove device from a given account.
+ RemoveDevice(ResponseTx<(), Error>, AccountToken, DeviceId),
/// Place constraints on the type of tunnel and relay
UpdateRelaySettings(ResponseTx<(), settings::Error>, RelaySettingsUpdate),
/// Set the allow LAN setting.
@@ -256,11 +278,9 @@ pub enum DaemonCommand {
/// Get the daemon settings
GetSettings(oneshot::Sender<Settings>),
/// Generate new wireguard key
- GenerateWireguardKey(ResponseTx<wireguard::KeygenEvent, Error>),
+ RotateWireguardKey(ResponseTx<(), Error>),
/// Return a public key of the currently set wireguard private key, if there is one
- GetWireguardKey(ResponseTx<Option<wireguard::PublicKey>, Error>),
- /// Verify if the currently set wireguard key is valid.
- VerifyWireguardKey(ResponseTx<bool, Error>),
+ GetWireguardKey(ResponseTx<Option<PublicKey>, Error>),
/// Get information about the currently running and latest app versions
GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>),
/// Get current version of the app
@@ -320,19 +340,14 @@ pub(crate) enum InternalDaemonEvent {
Command(DaemonCommand),
/// Daemon shutdown triggered by a signal, ctrl-c or similar.
TriggerShutdown,
- /// Wireguard key generation event
- WgKeyEvent(
- (
- AccountToken,
- Result<mullvad_types::wireguard::WireguardData, wireguard::Error>,
- ),
- ),
- /// New Account created
- NewAccountEvent(AccountToken, oneshot::Sender<Result<String, Error>>),
/// The background job fetching new `AppVersionInfo`s got a new info object.
NewAppVersionInfo(AppVersionInfo),
/// Request from REST client to use a different API endpoint.
GenerateApiConnectionMode(api::ApiConnectionModeRequest),
+ /// Sent when a device is updated in any way (key rotation, login, logout, etc.).
+ DeviceEvent(InnerDeviceEvent),
+ /// Handles updates from versions without devices.
+ DeviceMigrationEvent(DeviceData),
/// The split tunnel paths or state were updated.
#[cfg(target_os = "windows")]
ExcludedPathsEvent(ExcludedPathsUpdate, oneshot::Sender<Result<(), Error>>),
@@ -368,6 +383,12 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent {
}
}
+impl From<InnerDeviceEvent> for InternalDaemonEvent {
+ fn from(event: InnerDeviceEvent) -> Self {
+ InternalDaemonEvent::DeviceEvent(event)
+ }
+}
+
#[derive(Clone, Debug, Eq, PartialEq)]
enum DaemonExecutionState {
Running,
@@ -529,8 +550,11 @@ pub trait EventListener {
/// Or some flag about the currently running version is changed.
fn notify_app_version(&self, app_version_info: AppVersionInfo);
- /// Notify clients of a key generation event.
- fn notify_key_event(&self, key_event: KeygenEvent);
+ /// Notify that device changed (login, logout, or key rotation).
+ fn notify_device_event(&self, event: DeviceEvent);
+
+ /// Notify that a device was revoked using `RemoveDevice`.
+ fn notify_remove_device_event(&self, event: RemoveDeviceEvent);
}
pub struct Daemon<L: EventListener> {
@@ -546,10 +570,10 @@ pub struct Daemon<L: EventListener> {
event_listener: L,
settings: SettingsPersister,
account_history: account_history::AccountHistory,
- account: account::AccountHandle,
+ device_checker: device::TunnelStateChangeHandler,
+ account_manager: device::AccountManagerHandle,
rpc_runtime: mullvad_rpc::MullvadRpcRuntime,
rpc_handle: mullvad_rpc::rest::MullvadRestHandle,
- wireguard_key_manager: wireguard::KeyManager,
version_updater_handle: version_check::VersionUpdaterHandle,
relay_selector: relays::RelaySelector,
last_generated_relay: Option<Relay>,
@@ -584,11 +608,38 @@ where
mullvad_rpc::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await;
- let runtime = tokio::runtime::Handle::current();
-
let (internal_event_tx, internal_event_rx) = command_channel.destructure();
- if let Err(error) = migrations::migrate_all(&cache_dir, &settings_dir).await {
+ let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
+ &cache_dir,
+ true,
+ #[cfg(target_os = "android")]
+ Self::create_bypass_tx(&internal_event_tx),
+ )
+ .await
+ .map_err(Error::InitRpcFactory)?;
+
+ let api_availability = rpc_runtime.availability_handle();
+ api_availability.suspend();
+
+ let endpoint_updater = api::ApiEndpointUpdaterHandle::new();
+
+ let proxy_provider = api::create_api_config_provider(
+ internal_event_tx.to_specialized_sender(),
+ ApiConnectionMode::Direct,
+ );
+ let rpc_handle = rpc_runtime
+ .mullvad_rest_handle(proxy_provider, endpoint_updater.callback())
+ .await;
+
+ if let Err(error) = migrations::migrate_all(
+ &cache_dir,
+ &settings_dir,
+ rpc_handle.clone(),
+ internal_event_tx.clone(),
+ )
+ .await
+ {
log::error!(
"{}",
error.display_chain_with_msg("Failed to migrate settings or cache")
@@ -596,19 +647,45 @@ where
}
let settings = SettingsPersister::load(&settings_dir).await;
- let target_state = if settings.get_account_token().is_none() {
- PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await
- } else if settings.auto_connect {
+ let tunnel_parameters_generator = MullvadTunnelParametersGenerator {
+ tx: internal_event_tx.clone(),
+ };
+
+ let account_manager = device::AccountManager::spawn(
+ rpc_handle.clone(),
+ api_availability.clone(),
+ &settings_dir,
+ settings
+ .tunnel_options
+ .wireguard
+ .rotation_interval
+ .unwrap_or_default(),
+ )
+ .await
+ .map_err(Error::LoadAccountManager)?;
+ account_manager
+ .receive_events(internal_event_tx.to_specialized_sender())
+ .await
+ .map_err(Error::LoadAccountManager)?;
+ let data = account_manager
+ .data()
+ .await
+ .map_err(Error::LoadAccountManager)?;
+
+ let account_history = account_history::AccountHistory::new(
+ &settings_dir,
+ data.as_ref().map(|device| device.token.clone()),
+ )
+ .await
+ .map_err(Error::LoadAccountHistory)?;
+
+ let target_state = if settings.auto_connect {
log::info!("Automatically connecting since auto-connect is turned on");
PersistentTargetState::force(&cache_dir, TargetState::Secured).await
} else {
PersistentTargetState::new(&cache_dir).await
};
- let tunnel_parameters_generator = MullvadTunnelParametersGenerator {
- tx: internal_event_tx.clone(),
- };
-
#[cfg(windows)]
let exclude_paths = if settings.split_tunnel.enable_exclusions {
settings
@@ -621,18 +698,6 @@ where
vec![]
};
- let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
- &cache_dir,
- true,
- #[cfg(target_os = "android")]
- Self::create_bypass_tx(&internal_event_tx),
- )
- .await
- .map_err(Error::InitRpcFactory)?;
-
- let api_availability = rpc_runtime.availability_handle();
- api_availability.suspend();
-
let initial_api_endpoint =
api::get_allowed_endpoint(rpc_runtime.address_cache.get_address().await);
@@ -664,17 +729,8 @@ where
.await
.map_err(Error::TunnelError)?;
- let endpoint_updater = api::ApiEndpointUpdaterHandle::new();
endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx));
- let proxy_provider = api::create_api_config_provider(
- internal_event_tx.to_specialized_sender(),
- ApiConnectionMode::Direct,
- );
- let rpc_handle = rpc_runtime
- .mullvad_rest_handle(proxy_provider, endpoint_updater.callback())
- .await;
-
Self::forward_offline_state(api_availability.clone(), offline_state_rx).await;
let relay_list_listener = event_listener.clone();
@@ -700,28 +756,11 @@ where
settings.show_beta_releases,
);
tokio::spawn(version_updater.run());
- let account_history =
- account_history::AccountHistory::new(&settings_dir, settings.get_account_token())
- .await
- .map_err(Error::LoadAccountHistory)?;
-
- let wireguard_key_manager = wireguard::KeyManager::new(
- internal_event_tx.clone(),
- api_availability.clone(),
- rpc_handle.clone(),
- );
-
- let account = account::Account::new(
- runtime,
- rpc_handle.clone(),
- settings.get_account_token(),
- api_availability.clone(),
- );
// Attempt to download a fresh relay list
relay_selector.update().await;
- let mut daemon = Daemon {
+ let daemon = Daemon {
tunnel_command_tx,
tunnel_state: TunnelState::Disconnected,
target_state,
@@ -734,10 +773,10 @@ where
event_listener,
settings,
account_history,
- account,
+ device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()),
+ account_manager,
rpc_runtime,
rpc_handle,
- wireguard_key_manager,
version_updater_handle,
relay_selector,
last_generated_relay: None,
@@ -751,8 +790,6 @@ where
volume_update_tx,
};
- daemon.ensure_wireguard_keys_for_current_account().await;
-
api_availability.unsuspend();
Ok(daemon)
@@ -856,10 +893,12 @@ where
rpc_runtime,
tunnel_state_machine_handle,
target_state,
+ account_manager,
..
} = self;
shutdown_tasks.push(Box::pin(target_state.finalize()));
+ shutdown_tasks.push(Box::pin(account_manager.shutdown()));
(
event_listener,
@@ -881,16 +920,14 @@ where
}
Command(command) => self.handle_command(command).await,
TriggerShutdown => self.trigger_shutdown_event(),
- WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await,
- NewAccountEvent(account_token, tx) => {
- self.handle_new_account_event(account_token, tx).await
- }
NewAppVersionInfo(app_version_info) => {
self.handle_new_app_version_info(app_version_info)
}
GenerateApiConnectionMode(request) => {
self.handle_generate_api_connection_mode(request).await
}
+ DeviceEvent(event) => self.handle_device_event(event).await,
+ DeviceMigrationEvent(event) => self.handle_device_migration_event(event).await,
#[cfg(windows)]
ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await,
}
@@ -902,6 +939,9 @@ where
) {
self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition)
.await;
+ self.device_checker
+ .handle_state_transition(&tunnel_state_transition);
+
let tunnel_state = match tunnel_state_transition {
TunnelStateTransition::Disconnected => TunnelState::Disconnected,
TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting {
@@ -918,7 +958,12 @@ where
TunnelStateTransition::Error(error_state) => TunnelState::Error(error_state),
};
- self.unschedule_reconnect();
+ if !tunnel_state.is_connected() {
+ // Cancel reconnects except when entering the connected state.
+ // Exempt the latter because a reconnect scheduled while connecting should not be
+ // aborted.
+ self.unschedule_reconnect();
+ }
log::debug!("New tunnel state: {:?}", tunnel_state);
match tunnel_state {
@@ -937,7 +982,7 @@ where
}
if let ErrorStateCause::AuthFailed(_) = error_state.cause() {
- self.schedule_reconnect(Duration::from_secs(60)).await
+ self.schedule_reconnect(Duration::from_secs(60))
}
}
_ => {}
@@ -967,7 +1012,7 @@ where
>,
retry_attempt: u32,
) {
- if let Some(account_token) = self.settings.get_account_token() {
+ if let Ok(Some(device)) = self.account_manager.data().await {
let result = match self.settings.get_relay_settings() {
RelaySettings::CustomTunnelEndpoint(custom_relay) => {
self.last_generated_relay = None;
@@ -987,7 +1032,6 @@ where
&constraints,
self.settings.get_bridge_state(),
retry_attempt,
- self.settings.get_wireguard().is_some(),
)
.ok();
if let Some(relays::RelaySelectorResult {
@@ -1000,7 +1044,7 @@ where
.create_tunnel_parameters(
&exit_relay,
endpoint,
- account_token,
+ device.token,
retry_attempt,
)
.await;
@@ -1111,7 +1155,13 @@ where
.into())
}
MullvadEndpoint::Wireguard(endpoint) => {
- let wg_data = self.settings.get_wireguard().ok_or(Error::NoKeyAvailable)?;
+ let wg_data = self
+ .account_manager
+ .data()
+ .await
+ .map_err(|_| Error::NoKeyAvailable)?
+ .map(|device| device.wg_data)
+ .ok_or(Error::NoKeyAvailable)?;
let tunnel = wireguard::TunnelConfig {
private_key: wg_data.private_key,
addresses: vec![
@@ -1135,7 +1185,7 @@ where
}
}
- async fn schedule_reconnect(&mut self, delay: Duration) {
+ fn schedule_reconnect(&mut self, delay: Duration) {
self.unschedule_reconnect();
let tunnel_command_tx = self.tx.to_specialized_sender();
@@ -1175,7 +1225,13 @@ where
SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await,
GetRelayLocations(tx) => self.on_get_relay_locations(tx),
UpdateRelayLocations => self.on_update_relay_locations().await,
- SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await,
+ LoginAccount(tx, account_token) => self.on_login_account(tx, account_token).await,
+ LogoutAccount(tx) => self.on_logout_account(tx).await,
+ GetDevice(tx) => self.on_get_device(tx).await,
+ ListDevices(tx, account_token) => self.on_list_devices(tx, account_token).await,
+ RemoveDevice(tx, account_token, device_id) => {
+ self.on_remove_device(tx, account_token, device_id).await
+ }
GetAccountHistory(tx) => self.on_get_account_history(tx),
ClearAccountHistory(tx) => self.on_clear_account_history(tx).await,
UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update).await,
@@ -1198,9 +1254,8 @@ where
self.on_set_wireguard_rotation_interval(tx, interval).await
}
GetSettings(tx) => self.on_get_settings(tx),
- GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await,
+ RotateWireguardKey(tx) => self.on_rotate_wireguard_key(tx).await,
GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await,
- VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await,
GetVersionInfo(tx) => self.on_get_version_info(tx).await,
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
@@ -1232,106 +1287,6 @@ where
}
}
- async fn handle_wireguard_key_event(
- &mut self,
- event: (
- AccountToken,
- Result<mullvad_types::wireguard::WireguardData, wireguard::Error>,
- ),
- ) {
- let (account, result) = event;
- // If the account has been reset whilst a key was being generated, the event should be
- // dropped even if a new key was generated.
- if self
- .settings
- .get_account_token()
- .map(|current_account| current_account != account)
- .unwrap_or(true)
- {
- log::info!("Dropping wireguard key event since account has been changed");
- return;
- }
-
- match result {
- Ok(data) => {
- let public_key = data.get_public_key();
- let is_first_key = self.settings.get_wireguard().is_none();
- match self.settings.set_wireguard(Some(data)).await {
- Ok(_) => {
- if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY).await;
- }
- self.event_listener
- .notify_key_event(KeygenEvent::NewKey(public_key));
- if is_first_key {
- self.ensure_key_rotation().await;
- }
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg(
- "Failed to add new wireguard key to account data"
- )
- );
- self.event_listener
- .notify_key_event(KeygenEvent::GenerationFailure)
- }
- }
- }
- Err(wireguard::Error::TooManyKeys) => {
- self.event_listener
- .notify_key_event(KeygenEvent::TooManyKeys);
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg("Failed to generate wireguard key")
- );
- self.event_listener
- .notify_key_event(KeygenEvent::GenerationFailure);
- }
- }
- }
-
- async fn ensure_key_rotation(&mut self) {
- let token = match self.settings.get_account_token() {
- Some(token) => token,
- None => return,
- };
- let public_key = match self.settings.get_wireguard() {
- Some(data) => data.get_public_key(),
- None => return,
- };
- self.wireguard_key_manager
- .set_rotation_interval(
- public_key,
- token,
- self.settings.tunnel_options.wireguard.rotation_interval,
- )
- .await;
- }
-
- async fn handle_new_account_event(
- &mut self,
- new_token: AccountToken,
- tx: ResponseTx<String, Error>,
- ) {
- match self.set_account(Some(new_token.clone())).await {
- Ok(_) => {
- self.set_target_state(TargetState::Unsecured).await;
- let _ = tx.send(Ok(new_token));
- }
- Err(err) => {
- log::error!(
- "{}",
- err.display_chain_with_msg("Failed to save new account")
- );
- let _ = tx.send(Err(Error::SettingsError(err)));
- }
- };
- }
-
fn handle_new_app_version_info(&mut self, app_version_info: AppVersionInfo) {
self.app_version_info = Some(app_version_info.clone());
self.event_listener.notify_app_version(app_version_info);
@@ -1409,6 +1364,49 @@ where
let _ = request.response_tx.send(config);
}
+ async fn handle_device_event(&mut self, event: InnerDeviceEvent) {
+ match &event {
+ InnerDeviceEvent::Login(device) => {
+ if let Err(error) = self.account_history.set(device.token.clone()).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to update account history")
+ );
+ }
+ if *self.target_state == TargetState::Secured {
+ log::debug!("Initiating tunnel restart because the account token changed");
+ self.reconnect_tunnel();
+ }
+ }
+ InnerDeviceEvent::Logout => {
+ log::info!("Disconnecting because account token was cleared");
+ self.set_target_state(TargetState::Unsecured).await;
+ }
+ InnerDeviceEvent::RotatedKey(_) => {
+ if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
+ self.schedule_reconnect(WG_RECONNECT_DELAY);
+ }
+ }
+ _ => (),
+ }
+ self.event_listener
+ .notify_device_event(DeviceEvent::from(event));
+ }
+
+ async fn handle_device_migration_event(&mut self, data: DeviceData) {
+ if let Ok(Some(_)) = self.account_manager.data().await {
+ // Discard stale device
+ return;
+ }
+ if let Err(error) = self.account_manager.set(data).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to move over account from old settings")
+ );
+ }
+ self.reconnect_tunnel();
+ }
+
#[cfg(windows)]
async fn handle_new_excluded_paths(
&mut self,
@@ -1540,17 +1538,30 @@ where
}
async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) {
- let daemon_tx = self.tx.clone();
- let future = self.account.create_account();
+ let account_manager = self.account_manager.clone();
tokio::spawn(async move {
- match future.await {
- Ok(account_token) => {
- let _ = daemon_tx.send(InternalDaemonEvent::NewAccountEvent(account_token, tx));
+ let result = async {
+ if let Ok(Some(_)) = account_manager.data().await {
+ return Err(Error::AlreadyLoggedIn);
}
- Err(err) => {
- let _ = tx.send(Err(Error::RestError(err)));
- }
- }
+ let token = account_manager
+ .account_service
+ .create_account()
+ .await
+ .map_err(Error::RestError)?;
+ account_manager
+ .login(token.clone())
+ .await
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Creating new account failed")
+ );
+ Error::LoginError(error)
+ })?;
+ Ok(token)
+ };
+ Self::oneshot_send(tx, result.await, "create new account");
});
}
@@ -1559,7 +1570,7 @@ where
tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>,
account_token: AccountToken,
) {
- let account = self.account.clone();
+ let account = self.account_manager.account_service.clone();
tokio::spawn(async move {
let result = account.check_expiry(account_token).await;
Self::oneshot_send(
@@ -1571,16 +1582,18 @@ where
}
async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) {
- if let Some(account_token) = self.settings.get_account_token() {
- let future = self.account.get_www_auth_token(account_token);
- let rpc_call = async {
+ if let Ok(Some(device)) = self.account_manager.data().await {
+ let future = self
+ .account_manager
+ .account_service
+ .get_www_auth_token(device.token);
+ tokio::spawn(async {
Self::oneshot_send(
tx,
future.await.map_err(Error::RestError),
"get_www_auth_token response",
);
- };
- tokio::spawn(rpc_call);
+ });
} else {
Self::oneshot_send(
tx,
@@ -1595,13 +1608,13 @@ where
tx: ResponseTx<VoucherSubmission, Error>,
voucher: String,
) {
- if let Some(account_token) = self.settings.get_account_token() {
- let mut account = self.account.clone();
+ if let Ok(Some(device)) = self.account_manager.data().await {
+ let mut account = self.account_manager.account_service.clone();
tokio::spawn(async move {
Self::oneshot_send(
tx,
account
- .submit_voucher(account_token, voucher)
+ .submit_voucher(device.token, voucher)
.await
.map_err(Error::RestError),
"submit_voucher response",
@@ -1620,90 +1633,120 @@ where
self.relay_selector.update().await;
}
- async fn on_set_account(
- &mut self,
- tx: ResponseTx<(), settings::Error>,
- account_token: Option<String>,
- ) {
- match self.set_account(account_token.clone()).await {
- Ok(account_changed) => {
- if account_changed {
- match account_token {
- Some(_) => {
- log::info!(
- "Initiating tunnel restart because the account token changed"
- );
- self.reconnect_tunnel();
- }
- None => {
- log::info!("Disconnecting because account token was cleared");
- self.set_target_state(TargetState::Unsecured).await;
- }
- };
+ async fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) {
+ let account_manager = self.account_manager.clone();
+ tokio::spawn(async move {
+ let result = async {
+ account_manager.login(account_token).await.map_err(|error| {
+ log::error!("{}", error.display_chain_with_msg("Login failed"));
+ Error::LoginError(error)
+ })
+ };
+ Self::oneshot_send(tx, result.await, "login_account response");
+ });
+ }
+
+ async fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) {
+ let account_manager = self.account_manager.clone();
+ tokio::spawn(async move {
+ let result = async {
+ account_manager.logout().await.map_err(|error| {
+ log::error!("{}", error.display_chain_with_msg("Logout failed"));
+ Error::LogoutError(error)
+ })
+ };
+ Self::oneshot_send(tx, result.await, "logout_account response");
+ });
+ }
+
+ async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceConfig>, Error>) {
+ let account_manager = self.account_manager.clone();
+ tokio::spawn(async move {
+ // Make sure the device is updated
+ match account_manager.validate_device().await {
+ Ok(_) | Err(device::Error::NoDevice) => (),
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to update device data")
+ );
}
- Self::oneshot_send(tx, Ok(()), "set_account response");
}
- Err(error) => {
- log::error!("{}", error.display_chain_with_msg("Failed to set account"));
- Self::oneshot_send(tx, Err(error), "set_account response");
- }
- }
+
+ Self::oneshot_send(
+ tx,
+ Ok(account_manager
+ .data()
+ .await
+ .unwrap_or(None)
+ .map(DeviceConfig::from)),
+ "get_device response",
+ );
+ });
}
- async fn set_account(
- &mut self,
- account_token: Option<String>,
- ) -> Result<bool, settings::Error> {
- let previous_token = self.settings.get_account_token();
- let account_changed = self
- .settings
- .set_account_token(account_token.clone())
- .await?;
- if account_changed {
- self.event_listener
- .notify_settings(self.settings.to_settings());
+ async fn on_list_devices(&self, tx: ResponseTx<Vec<Device>, Error>, token: AccountToken) {
+ let service = self.account_manager.device_service.clone();
+ tokio::spawn(async move {
+ Self::oneshot_send(
+ tx,
+ service
+ .list_devices(token)
+ .await
+ .map_err(Error::ListDevicesError),
+ "list_devices response",
+ );
+ });
+ }
- let history_token = match account_token {
- Some(token) => token,
- None => previous_token.clone().unwrap_or("".to_string()),
- };
- if let Err(error) = self.account_history.set(history_token).await {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to update account history")
- );
- }
+ async fn on_remove_device(
+ &mut self,
+ tx: ResponseTx<(), Error>,
+ token: AccountToken,
+ device_id: DeviceId,
+ ) {
+ let device_service = self.account_manager.device_service.clone();
+ let event_listener = self.event_listener.clone();
- if let Some(previous_token) = previous_token {
- if let Some(previous_key) = self
- .settings
- .get_wireguard()
- .map(|data| data.private_key.public_key())
- {
- let remove_key = self
- .wireguard_key_manager
- .remove_key_with_backoff(previous_token, previous_key);
- tokio::spawn(async move {
- if let Err(error) = remove_key.await {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to remove WireGuard key for previous account"
- )
- );
- }
- });
+ tokio::spawn(async move {
+ let mut devices = match device_service
+ .list_devices(token.clone())
+ .await
+ .map_err(Error::ListDevicesError)
+ {
+ Ok(devices) => devices,
+ Err(error) => {
+ Self::oneshot_send(tx, Err(error), "remove_device response");
+ return;
}
- }
- if let Err(error) = self.settings.set_wireguard(None).await {
- log::error!(
- "{}",
- error.display_chain_with_msg("Error resetting WireGuard key")
- );
- }
- self.ensure_wireguard_keys_for_current_account().await;
- }
- Ok(account_changed)
+ };
+ if let Err(error) = device_service
+ .remove_device(token.clone(), device_id.clone())
+ .await
+ .map_err(Error::RemoveDeviceError)
+ {
+ Self::oneshot_send(tx, Err(error), "remove_device response");
+ return;
+ };
+ let removed_device =
+ if let Some(index) = devices.iter().position(|device| device.id == device_id) {
+ devices.swap_remove(index)
+ } else {
+ log::error!("List did not contain the revoked device");
+ Device {
+ id: device_id,
+ name: "unknown device".to_string(),
+ pubkey: talpid_types::net::wireguard::PublicKey::from([0u8; 32]),
+ ports: vec![],
+ }
+ };
+ event_listener.notify_remove_device_event(RemoveDeviceEvent {
+ account_token: token,
+ removed_device,
+ new_devices: devices,
+ });
+ Self::oneshot_send(tx, Ok(()), "remove_device response");
+ });
}
fn on_get_account_history(&mut self, tx: oneshot::Sender<Option<AccountToken>>) {
@@ -1723,37 +1766,6 @@ where
Self::oneshot_send(tx, result, "clear_account_history response");
}
- // Remove the key associated with the current account, if there is one.
- // This does not modify settings or account history.
- #[cfg(not(target_os = "android"))]
- fn remove_current_key_rpc(&self) -> impl std::future::Future<Output = Result<(), Error>> {
- let remove_key = if let Some(token) = self.settings.get_account_token() {
- if let Some(wg_data) = self.settings.get_wireguard() {
- Some(
- self.wireguard_key_manager
- .remove_key(token, wg_data.private_key.public_key()),
- )
- } else {
- None
- }
- } else {
- None
- };
-
- async move {
- if let Some(task) = remove_key {
- match task.await {
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- // This result should never occur
- Err(wireguard::Error::TooManyKeys) => Err(Error::TooManyKeys),
- _ => Ok(()),
- }
- } else {
- Ok(())
- }
- }
- }
-
async fn on_get_version_info(&mut self, tx: oneshot::Sender<Option<AppVersionInfo>>) {
if self.app_version_info.is_none() {
log::debug!("No version cache found. Fetching new info");
@@ -1795,17 +1807,13 @@ where
async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) {
let mut last_error = Ok(());
- let remove_key = self.remove_current_key_rpc();
- tokio::spawn(async move {
- if let Err(error) = remove_key.await {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to remove WireGuard key for previous account"
- )
- );
- }
- });
+ if let Err(error) = self.account_manager.logout().await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to clear device cache")
+ );
+ last_error = Err(Error::LogoutError(error));
+ }
if let Err(error) = self.account_history.clear().await {
log::error!(
@@ -2315,7 +2323,16 @@ where
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response");
if settings_changed {
- self.ensure_key_rotation().await;
+ if let Err(error) = self
+ .account_manager
+ .set_rotation_interval(interval.unwrap_or_default())
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to update rotation interval")
+ );
+ }
self.event_listener
.notify_settings(self.settings.to_settings());
}
@@ -2327,128 +2344,27 @@ where
}
}
- async fn ensure_wireguard_keys_for_current_account(&mut self) {
- if let Some(account) = self.settings.get_account_token() {
- if self.settings.get_wireguard().is_none() {
- log::info!("Generating new WireGuard key for account");
- self.wireguard_key_manager
- .spawn_key_generation_task(account, Some(FIRST_KEY_PUSH_TIMEOUT))
- .await;
- } else {
- log::info!("Account already has WireGuard key");
- self.ensure_key_rotation().await;
- }
- }
- }
-
- async fn on_generate_wireguard_key(&mut self, tx: ResponseTx<KeygenEvent, Error>) {
- match self.on_generate_wireguard_key_inner().await {
- Ok(key_event) => {
- Self::oneshot_send(tx, Ok(key_event), "generate_wireguard_key");
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg("Failed to generate new wireguard key")
- );
- Self::oneshot_send(tx, Err(e), "generate_wireguard_key");
- }
- }
- }
-
- async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, Error> {
- let account_token = self
- .settings
- .get_account_token()
- .ok_or(Error::NoAccountToken)?;
- let wireguard_data = self.settings.get_wireguard();
-
- let gen_result = match &wireguard_data {
- Some(wireguard_data) => {
- self.wireguard_key_manager
- .replace_key(account_token.clone(), wireguard_data.get_public_key())
- .await
- }
- None => {
- self.wireguard_key_manager
- .generate_key_sync(account_token.clone())
- .await
- }
- };
-
- match gen_result {
- Ok(new_data) => {
- let public_key = new_data.get_public_key();
- self.settings
- .set_wireguard(Some(new_data))
- .await
- .map_err(Error::SettingsError)?;
- if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY).await;
- }
- let keygen_event = KeygenEvent::NewKey(public_key.clone());
- self.event_listener.notify_key_event(keygen_event.clone());
-
- // update automatic rotation
- self.wireguard_key_manager
- .set_rotation_interval(
- public_key,
- account_token,
- self.settings.tunnel_options.wireguard.rotation_interval,
- )
- .await;
-
- Ok(keygen_event)
- }
- Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys),
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
- }
+ async fn on_rotate_wireguard_key(&self, tx: ResponseTx<(), Error>) {
+ let manager = self.account_manager.clone();
+ tokio::spawn(async move {
+ let result = manager
+ .rotate_key()
+ .await
+ .map(|_| ())
+ .map_err(Error::KeyRotationError);
+ Self::oneshot_send(tx, result, "rotate_wireguard_key response");
+ });
}
- async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<wireguard::PublicKey>, Error>) {
- let result = if self.settings.get_account_token().is_some() {
- Ok(self
- .settings
- .get_wireguard()
- .map(|data| data.get_public_key()))
+ async fn on_get_wireguard_key(&self, tx: ResponseTx<Option<PublicKey>, Error>) {
+ let result = if let Ok(Some(device)) = self.account_manager.data().await {
+ Ok(Some(device.wg_data.get_public_key()))
} else {
Err(Error::NoAccountToken)
};
Self::oneshot_send(tx, result, "get_wireguard_key response");
}
- async fn on_verify_wireguard_key(&mut self, tx: ResponseTx<bool, Error>) {
- let account = match self.settings.get_account_token() {
- Some(account) => account,
- None => {
- Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response");
- return;
- }
- };
- let public_key = match self.settings.get_wireguard() {
- Some(wg_data) => wg_data.private_key.public_key(),
- None => {
- Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response");
- return;
- }
- };
-
- let verification_rpc = self
- .wireguard_key_manager
- .verify_wireguard_key(account, public_key);
-
- tokio::spawn(async move {
- let result = match verification_rpc.await {
- Ok(is_valid) => Ok(is_valid),
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
- Err(wireguard::Error::TooManyKeys) => return,
- };
- Self::oneshot_send(tx, result, "verify_wireguard_key response");
- });
- }
-
fn on_get_settings(&self, tx: oneshot::Sender<Settings>) {
Self::oneshot_send(tx, self.settings.to_settings(), "get_settings response");
}
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index ba828ed903..b6413b357b 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -1,4 +1,4 @@
-use crate::{account_history, settings, DaemonCommand, DaemonCommandSender, EventListener};
+use crate::{account_history, device, settings, DaemonCommand, DaemonCommandSender, EventListener};
use futures::{
channel::{mpsc, oneshot},
StreamExt,
@@ -370,6 +370,7 @@ impl ManagementService for ManagementServiceImpl {
//
async fn create_new_account(&self, _: Request<()>) -> ServiceResult<String> {
+ log::debug!("create_new_account");
let (tx, rx) = oneshot::channel();
self.send_command_to_daemon(DaemonCommand::CreateNewAccount(tx))?;
self.wait_for_result(rx)
@@ -378,20 +379,25 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_daemon_error)
}
- async fn set_account(&self, request: Request<AccountToken>) -> ServiceResult<()> {
- log::debug!("set_account");
+ async fn login_account(&self, request: Request<AccountToken>) -> ServiceResult<()> {
+ log::debug!("login_account");
let account_token = request.into_inner();
- let account_token = if account_token == "" {
- None
- } else {
- Some(account_token)
- };
let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::SetAccount(tx, account_token))?;
+ self.send_command_to_daemon(DaemonCommand::LoginAccount(tx, account_token))?;
self.wait_for_result(rx)
.await?
.map(Response::new)
- .map_err(map_settings_error)
+ .map_err(map_daemon_error)
+ }
+
+ async fn logout_account(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("logout_account");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::LogoutAccount(tx))?;
+ self.wait_for_result(rx)
+ .await?
+ .map(Response::new)
+ .map_err(map_daemon_error)
}
async fn get_account_data(
@@ -479,6 +485,44 @@ impl ManagementService for ManagementServiceImpl {
})
}
+ // Device management
+ async fn get_device(&self, _: Request<()>) -> ServiceResult<types::DeviceConfig> {
+ log::debug!("get_device");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::GetDevice(tx))?;
+ let device = self
+ .wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)?
+ .ok_or(Status::new(Code::NotFound, "no device is set"))?;
+ Ok(Response::new(types::DeviceConfig::from(device)))
+ }
+
+ async fn list_devices(
+ &self,
+ request: Request<AccountToken>,
+ ) -> ServiceResult<types::DeviceList> {
+ log::debug!("list_devices");
+ let (tx, rx) = oneshot::channel();
+ let token = request.into_inner();
+ self.send_command_to_daemon(DaemonCommand::ListDevices(tx, token))?;
+ let device = self.wait_for_result(rx).await?.map_err(map_daemon_error)?;
+ Ok(Response::new(types::DeviceList::from(device)))
+ }
+
+ async fn remove_device(&self, request: Request<types::DeviceRemoval>) -> ServiceResult<()> {
+ log::debug!("remove_device");
+ let (tx, rx) = oneshot::channel();
+ let removal = request.into_inner();
+ self.send_command_to_daemon(DaemonCommand::RemoveDevice(
+ tx,
+ removal.account_token,
+ removal.device_id,
+ ))?;
+ self.wait_for_result(rx).await?.map_err(map_daemon_error)?;
+ Ok(Response::new(()))
+ }
+
// WireGuard key management
//
@@ -515,15 +559,13 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_settings_error)
}
- async fn generate_wireguard_key(&self, _: Request<()>) -> ServiceResult<types::KeygenEvent> {
- // TODO: return error for TooManyKeys, GenerationFailure
- // on success, simply return the new key or nil
- log::debug!("generate_wireguard_key");
+ async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("rotate_wireguard_key");
let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::GenerateWireguardKey(tx))?;
+ self.send_command_to_daemon(DaemonCommand::RotateWireguardKey(tx))?;
self.wait_for_result(rx)
.await?
- .map(|event| Response::new(types::KeygenEvent::from(event)))
+ .map(Response::new)
.map_err(map_daemon_error)
}
@@ -538,16 +580,6 @@ impl ManagementService for ManagementServiceImpl {
}
}
- async fn verify_wireguard_key(&self, _: Request<()>) -> ServiceResult<bool> {
- log::debug!("verify_wireguard_key");
- let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::VerifyWireguardKey(tx))?;
- self.wait_for_result(rx)
- .await?
- .map(Response::new)
- .map_err(map_daemon_error)
- }
-
// Split tunneling
//
@@ -832,14 +864,23 @@ impl EventListener for ManagementInterfaceEventBroadcaster {
})
}
- fn notify_key_event(&self, key_event: mullvad_types::wireguard::KeygenEvent) {
- log::debug!("Broadcasting new wireguard key event");
+ fn notify_device_event(&self, device: mullvad_types::device::DeviceEvent) {
+ log::debug!("Broadcasting device event");
self.notify(types::DaemonEvent {
- event: Some(daemon_event::Event::KeyEvent(types::KeygenEvent::from(
- key_event,
+ event: Some(daemon_event::Event::Device(types::DeviceEvent::from(
+ device,
))),
})
}
+
+ fn notify_remove_device_event(&self, remove_event: mullvad_types::device::RemoveDeviceEvent) {
+ log::debug!("Broadcasting remove device event");
+ self.notify(types::DaemonEvent {
+ event: Some(daemon_event::Event::RemoveDevice(
+ types::RemoveDeviceEvent::from(remove_event),
+ )),
+ })
+ }
}
impl ManagementInterfaceEventBroadcaster {
@@ -857,6 +898,12 @@ fn map_daemon_error(error: crate::Error) -> Status {
match error {
DaemonError::RestError(error) => map_rest_error(error),
DaemonError::SettingsError(error) => map_settings_error(error),
+ DaemonError::AlreadyLoggedIn => Status::already_exists(error.to_string()),
+ DaemonError::LoginError(error) => map_device_error(error),
+ DaemonError::LogoutError(error) => map_device_error(error),
+ DaemonError::KeyRotationError(error) => map_device_error(error),
+ DaemonError::ListDevicesError(error) => map_device_error(error),
+ DaemonError::RemoveDeviceError(error) => map_device_error(error),
#[cfg(windows)]
DaemonError::SplitTunnelError(error) => map_split_tunnel_error(error),
DaemonError::AccountHistory(error) => map_account_history_error(error),
@@ -929,6 +976,22 @@ fn map_settings_error(error: settings::Error) -> Status {
}
}
+/// Converts an instance of [`mullvad_daemon::device::Error`] into a tonic status.
+fn map_device_error(error: device::Error) -> Status {
+ match error {
+ device::Error::MaxDevicesReached => Status::new(Code::ResourceExhausted, error.to_string()),
+ device::Error::InvalidAccount => Status::new(Code::Unauthenticated, error.to_string()),
+ device::Error::InvalidDevice | device::Error::NoDevice => {
+ Status::new(Code::NotFound, error.to_string())
+ }
+ device::Error::DeviceIoError(ref _error) => {
+ Status::new(Code::Unavailable, error.to_string())
+ }
+ device::Error::OtherRestError(error) => map_rest_error(error),
+ _ => Status::new(Code::Unknown, error.to_string()),
+ }
+}
+
/// Converts an instance of [`mullvad_daemon::account_history::Error`] into a tonic status.
fn map_account_history_error(error: account_history::Error) -> Status {
match error {
diff --git a/mullvad-daemon/src/migrations/mod.rs b/mullvad-daemon/src/migrations/mod.rs
index 98ad71c23c..8347b3cd76 100644
--- a/mullvad-daemon/src/migrations/mod.rs
+++ b/mullvad-daemon/src/migrations/mod.rs
@@ -87,7 +87,12 @@ pub enum Error {
pub type Result<T> = std::result::Result<T, Error>;
-pub async fn migrate_all(cache_dir: &Path, settings_dir: &Path) -> Result<()> {
+pub(crate) async fn migrate_all(
+ cache_dir: &Path,
+ settings_dir: &Path,
+ rest_handle: mullvad_rpc::rest::MullvadRestHandle,
+ daemon_tx: crate::DaemonEventSender,
+) -> Result<()> {
#[cfg(windows)]
windows::migrate_after_windows_update(settings_dir)
.await
@@ -114,11 +119,12 @@ pub async fn migrate_all(cache_dir: &Path, settings_dir: &Path) -> Result<()> {
v2::migrate(&mut settings)?;
v3::migrate(&mut settings)?;
v4::migrate(&mut settings)?;
- v5::migrate(&mut settings)?;
account_history::migrate_location(cache_dir, settings_dir).await;
account_history::migrate_formats(settings_dir, &mut settings).await?;
+ v5::migrate(&mut settings, rest_handle, daemon_tx).await?;
+
if settings == old_settings {
// Nothing changed
return Ok(());
diff --git a/mullvad-daemon/src/migrations/v5.rs b/mullvad-daemon/src/migrations/v5.rs
index 0fcaca4e08..0695ee8c7e 100644
--- a/mullvad-daemon/src/migrations/v5.rs
+++ b/mullvad-daemon/src/migrations/v5.rs
@@ -1,5 +1,10 @@
use super::{Error, Result};
-use mullvad_types::settings::SettingsVersion;
+use crate::{device::DeviceService, DaemonEventSender, InternalDaemonEvent};
+use mullvad_types::{
+ account::AccountToken, device::DeviceData, settings::SettingsVersion, wireguard::WireguardData,
+};
+use talpid_core::mpsc::Sender;
+use talpid_types::ErrorExt;
// ======================================================
// Section for vendoring types and values that
@@ -21,16 +26,48 @@ use mullvad_types::settings::SettingsVersion;
/// * `use_mulithop` was not present in the settings
/// * A multihop entry location had been previously specified.
///
-/// This change is backwards compatible since older daemons will just ignore `use_multihop` if
-/// present.
-///
/// It is also no longer valid to have `entry_location` set to null. So remove the field if it
/// is null in order to make it default back to the default location.
-pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
- if !version_matches(settings) {
- return Ok(());
+///
+/// This also removes the account token and WireGuard key from the settings, looks up the
+/// corresponding device, and eventually stores them in `device.json` instead. This is done by
+/// sending the `DeviceMigrationEvent` event to the daemon. Because this is fallible, it can
+/// result in the account token and private key being lost. This should not be not critical since
+/// the account token is also stored in the account history.
+pub(crate) async fn migrate(
+ settings: &mut serde_json::Value,
+ rest_handle: mullvad_rpc::rest::MullvadRestHandle,
+ daemon_tx: DaemonEventSender,
+) -> Result<()> {
+ let migration_data = migrate_inner(settings).await?;
+
+ if let Some(migration_data) = migration_data {
+ let api_handle = rest_handle.availability.clone();
+ let service = DeviceService::new(rest_handle, api_handle);
+ match (migration_data.token, migration_data.wg_data) {
+ (token, Some(wg_data)) => {
+ log::info!("Creating a new device cache from previous settings");
+ tokio::spawn(cache_from_wireguard_key(daemon_tx, service, token, wg_data));
+ }
+ (token, None) => {
+ log::info!("Generating a new device for the account");
+ tokio::spawn(cache_from_account(daemon_tx, service, token));
+ }
+ }
}
+ Ok(())
+}
+
+struct MigrationData {
+ token: AccountToken,
+ wg_data: Option<WireguardData>,
+}
+
+async fn migrate_inner(settings: &mut serde_json::Value) -> Result<Option<MigrationData>> {
+ if !version_matches(settings) {
+ return Ok(None);
+ }
let wireguard_constraints = || -> Option<&serde_json::Value> {
settings
.get("relay_settings")?
@@ -54,11 +91,35 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
}
}
+ if let Some(token) = settings.get("account_token").filter(|t| !t.is_null()) {
+ let token: AccountToken =
+ serde_json::from_value(token.clone()).map_err(Error::ParseError)?;
+ let mig_data = if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) {
+ let wg_data: WireguardData =
+ serde_json::from_value(wg_data.clone()).map_err(Error::ParseError)?;
+ Ok(Some(MigrationData {
+ token,
+ wg_data: Some(wg_data),
+ }))
+ } else {
+ Ok(Some(MigrationData {
+ token,
+ wg_data: None,
+ }))
+ };
+
+ let settings_map = settings.as_object_mut().ok_or(Error::NoMatchingVersion)?;
+ settings_map.remove("account_token");
+ settings_map.remove("wireguard");
+
+ return mig_data;
+ }
+
// Note: Not incrementing the version number yet, since this migration is still open
// for future modification.
// settings["settings_version"] = serde_json::json!(SettingsVersion::V6);
- Ok(())
+ Ok(None)
}
fn version_matches(settings: &mut serde_json::Value) -> bool {
@@ -68,9 +129,56 @@ fn version_matches(settings: &mut serde_json::Value) -> bool {
.unwrap_or(false)
}
+async fn cache_from_wireguard_key(
+ daemon_tx: DaemonEventSender,
+ service: DeviceService,
+ token: AccountToken,
+ wg_data: WireguardData,
+) {
+ let devices = match service.list_devices_with_backoff(token.clone()).await {
+ Ok(devices) => devices,
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to enumerate devices for account")
+ );
+ return;
+ }
+ };
+
+ for device in devices.into_iter() {
+ if device.pubkey == wg_data.private_key.public_key() {
+ let _ = daemon_tx.send(InternalDaemonEvent::DeviceMigrationEvent(DeviceData {
+ token,
+ device,
+ wg_data,
+ }));
+ return;
+ }
+ }
+ log::info!("The existing WireGuard key is not valid; generating a new device");
+ cache_from_account(daemon_tx, service, token).await;
+}
+
+async fn cache_from_account(
+ daemon_tx: DaemonEventSender,
+ service: DeviceService,
+ token: AccountToken,
+) {
+ match service.generate_for_account_with_backoff(token).await {
+ Ok(device_data) => {
+ let _ = daemon_tx.send(InternalDaemonEvent::DeviceMigrationEvent(device_data));
+ }
+ Err(error) => log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to generate new device for account")
+ ),
+ }
+}
+
#[cfg(test)]
mod test {
- use super::{migrate, version_matches};
+ use super::{migrate_inner, version_matches};
use serde_json;
pub const V5_SETTINGS_V1: &str = r#"
@@ -144,7 +252,6 @@ mod test {
pub const V5_SETTINGS_V2: &str = r#"
{
- "account_token": "1234",
"relay_settings": {
"normal": {
"location": {
@@ -212,13 +319,12 @@ mod test {
}
"#;
- #[test]
- fn test_v5_v1_migration() {
+ #[tokio::test]
+ async fn test_v5_v1_migration() {
let mut old_settings = serde_json::from_str(V5_SETTINGS_V1).unwrap();
assert!(version_matches(&mut old_settings));
-
- migrate(&mut old_settings).unwrap();
+ migrate_inner(&mut old_settings).await.unwrap();
let new_settings: serde_json::Value = serde_json::from_str(V5_SETTINGS_V2).unwrap();
assert_eq!(&old_settings, &new_settings);
diff --git a/mullvad-daemon/src/relays/mod.rs b/mullvad-daemon/src/relays/mod.rs
index 332ca5fea3..c4a136369c 100644
--- a/mullvad-daemon/src/relays/mod.rs
+++ b/mullvad-daemon/src/relays/mod.rs
@@ -276,7 +276,6 @@ impl RelaySelector {
relay_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> Result<RelaySelectorResult, Error> {
match relay_constraints.tunnel_protocol {
Constraint::Only(TunnelType::OpenVpn) => self.get_openvpn_endpoint(
@@ -293,12 +292,9 @@ impl RelaySelector {
&relay_constraints.wireguard_constraints,
retry_attempt,
),
- Constraint::Any => self.get_any_tunnel_endpoint(
- relay_constraints,
- bridge_state,
- retry_attempt,
- wg_key_exists,
- ),
+ Constraint::Any => {
+ self.get_any_tunnel_endpoint(relay_constraints, bridge_state, retry_attempt)
+ }
}
}
@@ -479,14 +475,9 @@ impl RelaySelector {
relay_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> Result<RelaySelectorResult, Error> {
- let preferred_constraints = self.preferred_constraints(
- &relay_constraints,
- bridge_state,
- retry_attempt,
- wg_key_exists,
- );
+ let preferred_constraints =
+ self.preferred_constraints(&relay_constraints, bridge_state, retry_attempt);
let original_matcher: RelayMatcher<_> = relay_constraints.clone().into();
let preferred_tunnel_protocol = preferred_constraints.tunnel_protocol;
@@ -543,14 +534,12 @@ impl RelaySelector {
original_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> RelayConstraints {
let (preferred_port, preferred_protocol, preferred_tunnel) = self
.preferred_tunnel_constraints(
retry_attempt,
&original_constraints.location,
&original_constraints.providers,
- wg_key_exists,
);
let mut relay_constraints = original_constraints.clone();
@@ -731,7 +720,6 @@ impl RelaySelector {
retry_attempt: u32,
location_constraint: &Constraint<LocationConstraint>,
providers_constraint: &Constraint<Providers>,
- wg_key_exists: bool,
) -> (Constraint<u16>, TransportProtocol, TunnelType) {
#[cfg(target_os = "windows")]
{
@@ -757,7 +745,7 @@ impl RelaySelector {
});
// If location does not support WireGuard, defer to preferred OpenVPN tunnel
// constraints
- if !location_supports_wireguard || !wg_key_exists {
+ if !location_supports_wireguard {
let (preferred_port, preferred_protocol) =
Self::preferred_openvpn_constraints(retry_attempt);
return (preferred_port, preferred_protocol, TunnelType::OpenVpn);
@@ -1159,7 +1147,7 @@ mod test {
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
@@ -1167,7 +1155,7 @@ mod test {
for attempt in 0..10 {
assert!(relay_selector
- .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.is_ok());
}
@@ -1184,7 +1172,7 @@ mod test {
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1192,7 +1180,7 @@ mod test {
for attempt in 0..10 {
assert!(relay_selector
- .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.is_ok());
}
@@ -1205,7 +1193,6 @@ mod test {
&relay_constraints,
BridgeState::Off,
attempt,
- true,
);
assert_eq!(
preferred.tunnel_protocol,
@@ -1215,7 +1202,6 @@ mod test {
&relay_constraints,
BridgeState::Off,
attempt,
- true,
) {
Ok(result) if matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)) => (),
_ => panic!("OpenVPN endpoint was not selected"),
@@ -1250,14 +1236,14 @@ mod test {
// The same host cannot be used for entry and exit
assert!(relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.is_err());
relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location2);
// If the entry and exit differ, this should succeed
assert!(relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.is_ok());
}
@@ -1286,7 +1272,7 @@ mod test {
// The exit must not equal the entry
let exit_relay = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.map_err(|error| error.to_string())?
.exit_relay;
@@ -1301,7 +1287,7 @@ mod test {
endpoint,
..
} = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.map_err(|error| error.to_string())?;
assert_eq!(exit_relay.hostname, specific_hostname);
@@ -1336,7 +1322,7 @@ mod test {
});
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1362,7 +1348,7 @@ mod test {
..RelayConstraints::default()
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
@@ -1381,14 +1367,14 @@ mod test {
#[cfg(all(unix, not(target_os = "android")))]
{
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
);
}
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1405,54 +1391,6 @@ mod test {
}
#[test]
- fn test_wg_relay_with_no_key() {
- let mut relay_constraints = RelayConstraints {
- tunnel_protocol: Constraint::Only(TunnelType::Wireguard),
- ..RelayConstraints::default()
- };
-
- let relay_selector = new_relay_selector();
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect("Failed to get WireGuard relay when WireGuard relay was specified as the only tunnel protocol");
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
-
- relay_constraints.tunnel_protocol = Constraint::Any;
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect("Failed to get OpenVPN relay with tunnel protocol constraint set to Any and without a WireGuard key");
-
- assert!(matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)));
-
- let wireguard_specific_location = LocationConstraint::Hostname(
- "se".to_string(),
- "got".to_string(),
- "se9-wireguard".to_string(),
- );
- relay_constraints.location = Constraint::Only(wireguard_specific_location);
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect(
- "Failed to get a valid WireGuard relay when tunnel constraints are set to any
- tunnel protocol and with a wireguard specific location without a wireguard key",
- );
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
- .expect(
- "Failed to get a valid WireGuard relay when tunnel constraints are set to any
- tunnel protocol and with a wireguard specific location with a wireguard key",
- );
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
- }
-
- #[test]
fn test_selecting_any_relay_will_consider_multihop() {
let relay_constraints = RelayConstraints {
wireguard_constraints: WireguardConstraints {
@@ -1467,7 +1405,7 @@ mod test {
let relay_selector = new_relay_selector();
- let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection");
// Windows will ignore WireGuard until WireGuard is supported well enough
// TODO: Remove this caveat once Windows defaults to using WireGuard
@@ -1502,7 +1440,7 @@ mod test {
fn test_selecting_wireguard_location_will_consider_multihop() {
let relay_selector = new_relay_selector();
- let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, true)
+ let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0)
.expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection");
@@ -1526,7 +1464,7 @@ mod test {
let relay_selector = new_relay_selector();
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(result.entry_relay.is_some());
@@ -1555,7 +1493,7 @@ mod test {
let relay_selector = new_relay_selector();
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get WireGuard TCP relay");
let endpoint = result.endpoint.unwrap_wireguard();
assert!(matches!(endpoint.peer.protocol, TransportProtocol::Tcp));
@@ -1570,7 +1508,7 @@ mod test {
const INVALID_UDP_PORTS: [u16; 2] = [80, 443];
for attempt in 0..1000 {
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(!INVALID_UDP_PORTS.contains(&result.endpoint.to_endpoint().address.port()));
assert_eq!(
@@ -1587,7 +1525,7 @@ mod test {
const VALID_TCP_PORTS: [u16; 3] = [80, 443, 5001];
for attempt in 0..1000 {
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(VALID_TCP_PORTS.contains(&result.endpoint.to_endpoint().address.port()));
assert_eq!(
@@ -1609,7 +1547,7 @@ mod test {
..RelayConstraints::default()
};
relay_selector
- .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&constraints, BridgeState::Off, 0)
.expect_err("Successfully selected a relay that should be filtered");
constraints.location = Constraint::Only(LocationConstraint::Hostname(
@@ -1619,7 +1557,7 @@ mod test {
));
relay_selector
- .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&constraints, BridgeState::Off, 0)
.expect_err("Successfully selected a relay that should be filtered");
}
}
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs
index ec610f63d4..bf3fe710c8 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings.rs
@@ -3,7 +3,7 @@ use futures::TryFutureExt;
use mullvad_types::{
relay_constraints::{BridgeSettings, BridgeState, RelaySettingsUpdate},
settings::{DnsOptions, Settings},
- wireguard::{RotationInterval, WireguardData},
+ wireguard::RotationInterval,
};
#[cfg(target_os = "windows")]
use std::collections::HashSet;
@@ -191,21 +191,6 @@ impl SettingsPersister {
settings
}
- /// Changes account number to the one given. Also saves the new settings to disk.
- /// The boolean in the Result indicates if the account token changed or not
- pub async fn set_account_token(
- &mut self,
- account_token: Option<String>,
- ) -> Result<bool, Error> {
- let should_save = self.settings.set_account_token(account_token);
- self.update(should_save).await
- }
-
- pub async fn set_wireguard(&mut self, wireguard: Option<WireguardData>) -> Result<bool, Error> {
- let should_save = self.settings.set_wireguard(wireguard);
- self.update(should_save).await
- }
-
pub async fn update_relay_settings(
&mut self,
update: RelaySettingsUpdate,
diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs
deleted file mode 100644
index eb198b858b..0000000000
--- a/mullvad-daemon/src/wireguard.rs
+++ /dev/null
@@ -1,499 +0,0 @@
-use crate::{DaemonEventSender, InternalDaemonEvent};
-use chrono::offset::Utc;
-use mullvad_rpc::{
- availability::ApiAvailabilityHandle,
- rest::{Error as RestError, MullvadRestHandle},
-};
-use mullvad_types::account::AccountToken;
-pub use mullvad_types::wireguard::*;
-use std::{future::Future, pin::Pin, time::Duration};
-
-use futures::future::{abortable, AbortHandle};
-#[cfg(not(target_os = "android"))]
-use talpid_core::future_retry::constant_interval;
-use talpid_core::{
- future_retry::{retry_future, retry_future_n, ExponentialBackoff, Jittered},
- mpsc::Sender,
-};
-
-pub use talpid_types::net::wireguard::{
- ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters,
-};
-use talpid_types::ErrorExt;
-
-/// How long to wait before starting key rotation
-const ROTATION_START_DELAY: Duration = Duration::from_secs(60 * 3);
-
-/// How often to check whether the key has expired.
-/// A short interval is used in case the computer is ever suspended.
-const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60);
-
-const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
-const RETRY_INTERVAL_FACTOR: u32 = 5;
-const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-
-#[cfg(not(target_os = "android"))]
-const SHORT_RETRY_INTERVAL: Duration = Duration::ZERO;
-
-const MAX_KEY_REMOVAL_RETRIES: usize = 2;
-
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- #[error(display = "Unexpected HTTP request error")]
- RestError(#[error(source)] mullvad_rpc::rest::Error),
- #[error(display = "API availability check was interrupted")]
- ApiCheckError(#[error(source)] mullvad_rpc::availability::Error),
- #[error(display = "Account already has maximum number of keys")]
- TooManyKeys,
-}
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-pub struct KeyManager {
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- current_job: Option<AbortHandle>,
-
- abort_scheduler_tx: Option<AbortHandle>,
- auto_rotation_interval: RotationInterval,
-}
-
-impl KeyManager {
- pub(crate) fn new(
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- ) -> Self {
- Self {
- daemon_tx,
- availability_handle,
- http_handle,
- current_job: None,
- abort_scheduler_tx: None,
- auto_rotation_interval: RotationInterval::default(),
- }
- }
-
- /// Reset key rotation, cancelling the current one and starting a new one for the specified
- /// account
- pub async fn reset_rotation(&mut self, current_key: PublicKey, account_token: AccountToken) {
- self.run_automatic_rotation(account_token, current_key)
- .await
- }
-
- /// Update automatic key rotation interval
- /// Passing `None` for the interval will cause the default value to be used.
- pub async fn set_rotation_interval(
- &mut self,
- current_key: PublicKey,
- account_token: AccountToken,
- auto_rotation_interval: Option<RotationInterval>,
- ) {
- self.auto_rotation_interval = auto_rotation_interval.unwrap_or_default();
- self.reset_rotation(current_key, account_token).await;
- }
-
- /// Stop current key generation
- pub fn reset(&mut self) {
- if let Some(job) = self.current_job.take() {
- job.abort()
- }
- }
-
- /// Generate a new private key
- pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> {
- self.reset();
- let private_key = PrivateKey::new_from_random();
-
- self.push_future_generator(account, private_key, None)()
- .await
- .map_err(Self::map_rpc_error)
- }
-
- /// Replace a key for an account synchronously
- pub async fn replace_key(
- &mut self,
- account: AccountToken,
- old_key: PublicKey,
- ) -> Result<WireguardData> {
- self.reset();
-
- let new_key = PrivateKey::new_from_random();
- Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await
- }
-
- /// Verifies whether a key is valid or not.
- pub fn verify_wireguard_key(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<bool>> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- async move {
- match rpc.get_wireguard_key(account, &key).await {
- Ok(_) => Ok(true),
- Err(mullvad_rpc::rest::Error::ApiError(status, _code))
- if status == mullvad_rpc::StatusCode::NOT_FOUND =>
- {
- Ok(false)
- }
- Err(err) => Err(Self::map_rpc_error(err)),
- }
- }
- }
-
- /// Removes a key from an account
- #[cfg(not(target_os = "android"))]
- pub fn remove_key(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<()>> {
- self.remove_key_inner(account, key, constant_interval(SHORT_RETRY_INTERVAL), false)
- }
-
- /// Removes a key from an account
- pub fn remove_key_with_backoff(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<()>> {
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
- self.remove_key_inner(account, key, retry_strategy, true)
- }
-
- fn remove_key_inner<D: Iterator<Item = Duration> + 'static>(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- retry_strategy: D,
- offline_check: bool,
- ) -> impl Future<Output = Result<()>> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- let api_handle = self.availability_handle.clone();
- let api_handle_2 = api_handle.clone();
- let future = retry_future_n(
- move || {
- let remove_key = rpc.remove_wireguard_key(account.clone(), key.clone());
- let wait_future = api_handle.wait_online();
- async move {
- if offline_check {
- let _ = wait_future.await;
- }
- remove_key.await
- }
- },
- move |result| match result {
- Ok(_) => false,
- Err(error) => Self::should_retry_removal(error, &api_handle_2),
- },
- retry_strategy,
- MAX_KEY_REMOVAL_RETRIES,
- );
- async move { future.await.map_err(Self::map_rpc_error) }
- }
-
- fn should_retry_removal(error: &RestError, api_handle: &ApiAvailabilityHandle) -> bool {
- error.is_network_error() && !api_handle.get_state().is_offline()
- }
-
- fn should_retry(error: &RestError) -> bool {
- if let RestError::ApiError(_status, code) = &error {
- code != mullvad_rpc::INVALID_ACCOUNT && code != mullvad_rpc::KEY_LIMIT_REACHED
- } else {
- true
- }
- }
-
- /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel.
- pub async fn spawn_key_generation_task(
- &mut self,
- account: AccountToken,
- timeout: Option<Duration>,
- ) {
- self.reset();
- let private_key = PrivateKey::new_from_random();
-
- let error_tx = self.daemon_tx.clone();
- let error_account = account.clone();
-
- let mut inner_future_generator =
- self.push_future_generator(account.clone(), private_key, timeout);
-
- let availability_handle = self.availability_handle.clone();
-
- let future_generator = move || {
- let wait_available = availability_handle.wait_background();
- let fut = inner_future_generator();
- let error_tx = error_tx.clone();
- let error_account = error_account.clone();
- async move {
- let error_account_copy = error_account.clone();
- wait_available.await.map_err(|error| {
- let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent((
- error_account_copy,
- Err(Error::ApiCheckError(error)),
- )));
- false
- })?;
- let response = fut.await;
- match response {
- Ok(addresses) => Ok(addresses),
- Err(err) => {
- let should_retry = Self::should_retry(&err);
- let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent((
- error_account,
- Err(Self::map_rpc_error(err)),
- )));
- Err(should_retry)
- }
- }
- }
- };
-
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
-
- let should_retry = move |result: &std::result::Result<_, bool>| -> bool {
- match result {
- Ok(_) => false,
- Err(should_retry) => *should_retry,
- }
- };
-
- let upload_future = retry_future(future_generator, should_retry, retry_strategy);
-
- let (cancellable_upload, abort_handle) = abortable(Box::pin(upload_future));
- let daemon_tx = self.daemon_tx.clone();
- let future = async move {
- match cancellable_upload.await {
- Ok(Ok(wireguard_data)) => {
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account,
- Ok(wireguard_data),
- )));
- }
- Ok(Err(_)) => {}
- Err(_) => {
- log::error!("Key generation cancelled");
- }
- }
- };
-
- tokio::spawn(Box::pin(future));
- self.current_job = Some(abort_handle);
- }
-
- fn push_future_generator(
- &self,
- account: AccountToken,
- private_key: PrivateKey,
- timeout: Option<Duration>,
- ) -> Box<
- dyn FnMut() -> Pin<
- Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>,
- > + Send,
- > {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- let public_key = private_key.public_key();
-
- let push_future =
- move || -> std::pin::Pin<Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send >> {
- let key = private_key.clone();
- let address_future = rpc
- .push_wg_key(account.clone(), public_key.clone(), timeout);
- Box::pin(async move {
- let addresses = address_future.await?;
- Ok(WireguardData {
- private_key: key,
- addresses,
- created: Utc::now(),
- })
- })
- };
- Box::new(push_future)
- }
-
- async fn replace_key_rpc(
- http_handle: MullvadRestHandle,
- account: AccountToken,
- old_key: PublicKey,
- new_key: PrivateKey,
- ) -> Result<WireguardData> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(http_handle);
- let new_public_key = new_key.public_key();
- let addresses = rpc
- .replace_wg_key(account, old_key.key, new_public_key)
- .await
- .map_err(Self::map_rpc_error)?;
- Ok(WireguardData {
- private_key: new_key,
- addresses,
- created: Utc::now(),
- })
- }
-
- fn map_rpc_error(err: mullvad_rpc::rest::Error) -> Error {
- match &err {
- // TODO: Consider handling the invalid account case too.
- mullvad_rpc::rest::Error::ApiError(status, message)
- if *status == mullvad_rpc::StatusCode::BAD_REQUEST
- && message == mullvad_rpc::KEY_LIMIT_REACHED =>
- {
- Error::TooManyKeys
- }
- _ => Error::RestError(err),
- }
- }
-
- async fn wait_for_key_expiry(key: &PublicKey, rotation_interval_secs: u64) {
- let mut interval = tokio::time::interval(KEY_CHECK_INTERVAL);
- interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
- loop {
- interval.tick().await;
- if (Utc::now().signed_duration_since(key.created)).num_seconds() as u64
- >= rotation_interval_secs
- {
- return;
- }
- }
- }
-
- async fn create_automatic_rotation(
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- mut public_key: PublicKey,
- rotation_interval_secs: u64,
- account_token: AccountToken,
- ) {
- tokio::time::sleep(ROTATION_START_DELAY).await;
-
- let rotate_key_for_account =
- move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> {
- let wait_available = availability_handle.wait_background();
- let rotate = Self::rotate_key(
- daemon_tx.clone(),
- http_handle.clone(),
- account_token.clone(),
- old_key.clone(),
- );
- Box::pin(async move {
- wait_available.await?;
- rotate.await
- })
- };
-
- loop {
- Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await;
-
- let rotate_key_for_account_copy = rotate_key_for_account.clone();
- match Self::rotate_key_with_retries(public_key.clone(), rotate_key_for_account_copy)
- .await
- {
- Ok(new_key) => public_key = new_key,
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Stopping automatic key rotation due to an error"
- )
- );
- return;
- }
- }
- }
- }
-
- fn rotate_key(
- daemon_tx: DaemonEventSender,
- http_handle: MullvadRestHandle,
- account_token: AccountToken,
- old_key: PublicKey,
- ) -> impl Future<Output = Result<PublicKey>> {
- let new_key = PrivateKey::new_from_random();
- let rpc_result =
- Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key);
-
- async move {
- match rpc_result.await {
- Ok(data) => {
- // Update account data
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account_token,
- Ok(data.clone()),
- )));
- Ok(data.get_public_key())
- }
- Err(Error::TooManyKeys) => {
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account_token,
- Err(Error::TooManyKeys),
- )));
- Err(Error::TooManyKeys)
- }
- Err(unknown) => Err(unknown),
- }
- }
- }
-
- async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey>
- where
- F: FnMut(&PublicKey) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>>
- + Clone
- + 'static,
- {
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
- let should_retry = move |result: &Result<PublicKey>| -> bool {
- match result {
- Ok(_) => false,
- Err(error) => match error {
- Error::RestError(error) => Self::should_retry(error),
- _ => false,
- },
- }
- };
-
- retry_future(
- move || rotate_key.clone()(&old_key),
- should_retry,
- retry_strategy,
- )
- .await
- }
-
- async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) {
- self.stop_automatic_rotation();
-
- log::debug!("Starting automatic key rotation job");
- // Schedule cancellable series of repeating rotation tasks
- let fut = Self::create_automatic_rotation(
- self.daemon_tx.clone(),
- self.availability_handle.clone(),
- self.http_handle.clone(),
- public_key,
- self.auto_rotation_interval.as_duration().as_secs(),
- account_token,
- );
- let (request, abort_handle) = abortable(Box::pin(fut));
-
- tokio::spawn(request);
- self.abort_scheduler_tx = Some(abort_handle);
- }
-
- fn stop_automatic_rotation(&mut self) {
- if let Some(abort_handle) = self.abort_scheduler_tx.take() {
- log::info!("Stopping automatic key rotation");
- abort_handle.abort();
- }
- }
-}
diff --git a/mullvad-jni/src/daemon_interface.rs b/mullvad-jni/src/daemon_interface.rs
index 6f47fe21e8..3550b20d32 100644
--- a/mullvad-jni/src/daemon_interface.rs
+++ b/mullvad-jni/src/daemon_interface.rs
@@ -1,14 +1,15 @@
use futures::{channel::oneshot, executor::block_on};
-use mullvad_daemon::{DaemonCommand, DaemonCommandSender};
+use mullvad_daemon::{device, DaemonCommand, DaemonCommandSender};
use mullvad_types::{
account::{AccountData, AccountToken, VoucherSubmission},
+ device::{Device, DeviceConfig},
location::GeoIpLocation,
relay_constraints::RelaySettingsUpdate,
relay_list::RelayList,
settings::{DnsOptions, Settings},
states::{TargetState, TunnelState},
version::AppVersionInfo,
- wireguard::{self, KeygenEvent},
+ wireguard,
};
#[derive(Debug, err_derive::Error)]
@@ -37,6 +38,12 @@ impl From<mullvad_daemon::Error> for Error {
fn from(error: mullvad_daemon::Error) -> Error {
match error {
mullvad_daemon::Error::RestError(error) => Error::RpcError(error),
+ mullvad_daemon::Error::LoginError(device::Error::OtherRestError(error)) => {
+ Error::RpcError(error)
+ }
+ mullvad_daemon::Error::ListDevicesError(device::Error::OtherRestError(error)) => {
+ Error::RpcError(error)
+ }
error => Error::OtherError(error),
}
}
@@ -79,16 +86,6 @@ impl DaemonInterface {
block_on(rx).map(|_| ()).map_err(|_| Error::NoResponse)
}
- pub fn generate_wireguard_key(&self) -> Result<KeygenEvent> {
- let (tx, rx) = oneshot::channel();
-
- self.send_command(DaemonCommand::GenerateWireguardKey(tx))?;
-
- block_on(rx)
- .map_err(|_| Error::NoResponse)?
- .map_err(Error::from)
- }
-
pub fn get_account_data(&self, account_token: String) -> Result<AccountData> {
let (tx, rx) = oneshot::channel();
@@ -195,23 +192,54 @@ impl DaemonInterface {
.map_err(Error::from)
}
- pub fn verify_wireguard_key(&self) -> Result<bool> {
+ pub fn login_account(&self, account_token: String) -> Result<()> {
+ let (tx, rx) = oneshot::channel();
+
+ self.send_command(DaemonCommand::LoginAccount(tx, account_token))?;
+
+ block_on(rx)
+ .map_err(|_| Error::NoResponse)?
+ .map_err(Error::from)
+ }
+
+ pub fn logout_account(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();
- self.send_command(DaemonCommand::VerifyWireguardKey(tx))?;
+ self.send_command(DaemonCommand::LogoutAccount(tx))?;
+
block_on(rx)
.map_err(|_| Error::NoResponse)?
.map_err(Error::from)
}
- pub fn set_account(&self, account_token: Option<String>) -> Result<()> {
+ pub fn get_device(&self) -> Result<Option<DeviceConfig>> {
let (tx, rx) = oneshot::channel();
- self.send_command(DaemonCommand::SetAccount(tx, account_token))?;
+ self.send_command(DaemonCommand::GetDevice(tx))?;
block_on(rx)
.map_err(|_| Error::NoResponse)?
- .map_err(|_| Error::SettingsError)
+ .map_err(Error::from)
+ }
+
+ pub fn list_devices(&self, account_token: String) -> Result<Vec<Device>> {
+ let (tx, rx) = oneshot::channel();
+
+ self.send_command(DaemonCommand::ListDevices(tx, account_token))?;
+
+ block_on(rx)
+ .map_err(|_| Error::NoResponse)?
+ .map_err(Error::from)
+ }
+
+ pub fn remove_device(&self, account_token: String, device_id: String) -> Result<()> {
+ let (tx, rx) = oneshot::channel();
+
+ self.send_command(DaemonCommand::RemoveDevice(tx, account_token, device_id))?;
+
+ block_on(rx)
+ .map_err(|_| Error::NoResponse)?
+ .map_err(Error::from)
}
pub fn set_allow_lan(&self, allow_lan: bool) -> Result<()> {
diff --git a/mullvad-jni/src/jni_event_listener.rs b/mullvad-jni/src/jni_event_listener.rs
index 9fd5c3d2ea..553f3f48f3 100644
--- a/mullvad-jni/src/jni_event_listener.rs
+++ b/mullvad-jni/src/jni_event_listener.rs
@@ -7,8 +7,11 @@ use jnix::{
};
use mullvad_daemon::EventListener;
use mullvad_types::{
- relay_list::RelayList, settings::Settings, states::TunnelState, version::AppVersionInfo,
- wireguard::KeygenEvent,
+ device::{DeviceEvent, RemoveDeviceEvent},
+ relay_list::RelayList,
+ settings::Settings,
+ states::TunnelState,
+ version::AppVersionInfo,
};
use std::{sync::mpsc, thread};
use talpid_types::ErrorExt;
@@ -27,11 +30,12 @@ pub enum Error {
}
enum Event {
- KeygenEvent(KeygenEvent),
RelayList(RelayList),
Settings(Settings),
Tunnel(TunnelState),
AppVersionInfo(AppVersionInfo),
+ DeviceEvent(DeviceEvent),
+ RemoveDeviceEvent(RemoveDeviceEvent),
}
#[derive(Clone, Debug)]
@@ -44,10 +48,6 @@ impl JniEventListener {
}
impl EventListener for JniEventListener {
- fn notify_key_event(&self, key_event: KeygenEvent) {
- let _ = self.0.send(Event::KeygenEvent(key_event));
- }
-
fn notify_new_state(&self, state: TunnelState) {
let _ = self.0.send(Event::Tunnel(state));
}
@@ -63,16 +63,25 @@ impl EventListener for JniEventListener {
fn notify_app_version(&self, app_version_info: AppVersionInfo) {
let _ = self.0.send(Event::AppVersionInfo(app_version_info));
}
+
+ fn notify_device_event(&self, event: DeviceEvent) {
+ let _ = self.0.send(Event::DeviceEvent(event));
+ }
+
+ fn notify_remove_device_event(&self, event: RemoveDeviceEvent) {
+ let _ = self.0.send(Event::RemoveDeviceEvent(event));
+ }
}
struct JniEventHandler<'env> {
env: JnixEnv<'env>,
mullvad_ipc_client: JObject<'env>,
notify_app_version_info_event: JMethodID<'env>,
- notify_keygen_event: JMethodID<'env>,
notify_relay_list_event: JMethodID<'env>,
notify_settings_event: JMethodID<'env>,
notify_tunnel_event: JMethodID<'env>,
+ notify_device_event: JMethodID<'env>,
+ notify_remove_device_event: JMethodID<'env>,
events: mpsc::Receiver<Event>,
}
@@ -123,12 +132,6 @@ impl<'env> JniEventHandler<'env> {
"notifyAppVersionInfoEvent",
"(Lnet/mullvad/mullvadvpn/model/AppVersionInfo;)V",
)?;
- let notify_keygen_event = Self::get_method_id(
- &env,
- &class,
- "notifyKeygenEvent",
- "(Lnet/mullvad/mullvadvpn/model/KeygenEvent;)V",
- )?;
let notify_relay_list_event = Self::get_method_id(
&env,
&class,
@@ -147,15 +150,28 @@ impl<'env> JniEventHandler<'env> {
"notifyTunnelStateEvent",
"(Lnet/mullvad/mullvadvpn/model/TunnelState;)V",
)?;
+ let notify_device_event = Self::get_method_id(
+ &env,
+ &class,
+ "notifyDeviceEvent",
+ "(Lnet/mullvad/mullvadvpn/model/DeviceEvent;)V",
+ )?;
+ let notify_remove_device_event = Self::get_method_id(
+ &env,
+ &class,
+ "notifyRemoveDeviceEvent",
+ "(Lnet/mullvad/mullvadvpn/model/RemoveDeviceEvent;)V",
+ )?;
Ok(JniEventHandler {
env,
mullvad_ipc_client,
notify_app_version_info_event,
- notify_keygen_event,
notify_relay_list_event,
notify_settings_event,
notify_tunnel_event,
+ notify_device_event,
+ notify_remove_device_event,
events,
})
}
@@ -173,31 +189,53 @@ impl<'env> JniEventHandler<'env> {
fn run(&mut self) {
while let Ok(event) = self.events.recv() {
match event {
- Event::KeygenEvent(keygen_event) => self.handle_keygen_event(keygen_event),
Event::RelayList(relay_list) => self.handle_relay_list_event(relay_list),
Event::Settings(settings) => self.handle_settings(settings),
Event::Tunnel(tunnel_event) => self.handle_tunnel_event(tunnel_event),
Event::AppVersionInfo(app_version_info) => {
self.handle_app_version_info_event(app_version_info)
}
+ Event::DeviceEvent(device_event) => self.handle_device_event(device_event),
+ Event::RemoveDeviceEvent(device_event) => {
+ self.handle_remove_device_event(device_event)
+ }
}
}
}
- fn handle_keygen_event(&self, event: KeygenEvent) {
- let java_keygen_event = event.into_java(&self.env);
+ fn handle_device_event(&self, device_event: DeviceEvent) {
+ let java_event = device_event.into_java(&self.env);
+
+ let result = self.env.call_method_unchecked(
+ self.mullvad_ipc_client,
+ self.notify_device_event,
+ JavaType::Primitive(Primitive::Void),
+ &[JValue::Object(java_event.as_obj())],
+ );
+
+ if let Err(error) = result {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to call MullvadDaemon.notifyDeviceEvent")
+ );
+ }
+ }
+
+ fn handle_remove_device_event(&self, remove_event: RemoveDeviceEvent) {
+ let java_event = remove_event.into_java(&self.env);
let result = self.env.call_method_unchecked(
self.mullvad_ipc_client,
- self.notify_keygen_event,
+ self.notify_remove_device_event,
JavaType::Primitive(Primitive::Void),
- &[JValue::Object(java_keygen_event.as_obj())],
+ &[JValue::Object(java_event.as_obj())],
);
if let Err(error) = result {
log::error!(
"{}",
- error.display_chain_with_msg("Failed to call MullvadDaemon.notifyKeygenEvent")
+ error
+ .display_chain_with_msg("Failed to call MullvadDaemon.notifyRemoveDeviceEvent")
);
}
}
diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs
index 646988e11b..c98c132e60 100644
--- a/mullvad-jni/src/lib.rs
+++ b/mullvad-jni/src/lib.rs
@@ -19,7 +19,8 @@ use jnix::{
FromJava, IntoJava, JnixEnv,
};
use mullvad_daemon::{
- exception_logging, logging, runtime::new_runtime_builder, version, Daemon, DaemonCommandChannel,
+ device, exception_logging, logging, runtime::new_runtime_builder, version, Daemon,
+ DaemonCommandChannel,
};
use mullvad_rpc::{rest::Error as RestError, StatusCode};
use mullvad_types::{
@@ -92,6 +93,65 @@ impl From<Result<AccountData, daemon_interface::Error>> for GetAccountDataResult
#[derive(IntoJava)]
#[jnix(package = "net.mullvad.mullvadvpn.model")]
+pub enum LoginResult {
+ Ok,
+ InvalidAccount,
+ MaxDevicesReached,
+ RpcError,
+ OtherError,
+}
+
+impl From<Result<(), daemon_interface::Error>> for LoginResult {
+ fn from(result: Result<(), daemon_interface::Error>) -> Self {
+ match result {
+ Ok(()) => LoginResult::Ok,
+ Err(error) => match error {
+ daemon_interface::Error::OtherError(mullvad_daemon::Error::LoginError(error)) => {
+ match error {
+ device::Error::InvalidAccount => LoginResult::InvalidAccount,
+ device::Error::MaxDevicesReached => LoginResult::MaxDevicesReached,
+ device::Error::OtherRestError(_) => LoginResult::RpcError,
+ _ => LoginResult::OtherError,
+ }
+ }
+ daemon_interface::Error::RpcError(_) => LoginResult::RpcError,
+ _ => LoginResult::OtherError,
+ },
+ }
+ }
+}
+
+#[derive(IntoJava)]
+#[jnix(package = "net.mullvad.mullvadvpn.model")]
+pub enum RemoveDeviceResult {
+ Ok,
+ NotFound,
+ RpcError,
+ OtherError,
+}
+
+impl From<Result<(), daemon_interface::Error>> for RemoveDeviceResult {
+ fn from(result: Result<(), daemon_interface::Error>) -> Self {
+ match result {
+ Ok(()) => RemoveDeviceResult::Ok,
+ Err(error) => match error {
+ daemon_interface::Error::OtherError(mullvad_daemon::Error::LoginError(error)) => {
+ match error {
+ device::Error::InvalidAccount => RemoveDeviceResult::RpcError,
+ device::Error::InvalidDevice => RemoveDeviceResult::NotFound,
+ device::Error::OtherRestError(_) => RemoveDeviceResult::RpcError,
+ _ => RemoveDeviceResult::OtherError,
+ }
+ }
+ daemon_interface::Error::RpcError(_) => RemoveDeviceResult::RpcError,
+ _ => RemoveDeviceResult::OtherError,
+ },
+ }
+ }
+}
+
+#[derive(IntoJava)]
+#[jnix(package = "net.mullvad.mullvadvpn.model")]
pub enum VoucherSubmissionResult {
Ok(VoucherSubmission),
Error(VoucherSubmissionError),
@@ -439,66 +499,6 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_disconn
#[no_mangle]
#[allow(non_snake_case)]
-pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_generateWireguardKey<
- 'env,
->(
- env: JNIEnv<'env>,
- _: JObject<'_>,
- daemon_interface_address: jlong,
-) -> JObject<'env> {
- let env = JnixEnv::from(env);
-
- if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
- match daemon_interface.generate_wireguard_key() {
- Ok(keygen_event) => keygen_event.into_java(&env).forget(),
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to request to generate wireguard key")
- );
- JObject::null()
- }
- }
- } else {
- JObject::null()
- }
-}
-
-#[no_mangle]
-#[allow(non_snake_case)]
-pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_verifyWireguardKey<
- 'env,
->(
- env: JNIEnv<'env>,
- _: JObject<'_>,
- daemon_interface_address: jlong,
-) -> JObject<'env> {
- let env = JnixEnv::from(env);
-
- if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
- match daemon_interface.verify_wireguard_key() {
- Ok(key_is_valid) => env
- .new_object(
- &env.get_class("java/lang/Boolean"),
- "(Z)V",
- &[JValue::Bool(key_is_valid as jboolean)],
- )
- .expect("Failed to create Boolean Java object"),
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to verify wireguard key")
- );
- JObject::null()
- }
- }
- } else {
- JObject::null()
- }
-}
-
-#[no_mangle]
-#[allow(non_snake_case)]
pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_getAccountHistory<'env>(
env: JNIEnv<'env>,
_: JObject<'_>,
@@ -768,25 +768,118 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_clearAc
#[no_mangle]
#[allow(non_snake_case)]
-pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_setAccount(
- env: JNIEnv<'_>,
+pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_loginAccount<'env>(
+ env: JNIEnv<'env>,
_: JObject<'_>,
daemon_interface_address: jlong,
accountToken: JString<'_>,
+) -> JObject<'env> {
+ let env = JnixEnv::from(env);
+
+ if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
+ let account = String::from_java(&env, accountToken);
+ let result = daemon_interface.login_account(account);
+
+ if let Err(ref error) = &result {
+ log_request_error("login account", error);
+ }
+
+ LoginResult::from(result).into_java(&env).forget()
+ } else {
+ LoginResult::OtherError.into_java(&env).forget()
+ }
+}
+
+#[no_mangle]
+#[allow(non_snake_case)]
+pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_logoutAccount(
+ _: JNIEnv<'_>,
+ _: JObject<'_>,
+ daemon_interface_address: jlong,
) {
+ if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
+ if let Err(error) = daemon_interface.logout_account() {
+ log::error!("{}", error.display_chain_with_msg("Failed to log out"));
+ }
+ }
+}
+
+#[no_mangle]
+#[allow(non_snake_case)]
+pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_getDevice<'env>(
+ env: JNIEnv<'env>,
+ _: JObject<'_>,
+ daemon_interface_address: jlong,
+) -> JObject<'env> {
let env = JnixEnv::from(env);
if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
- let account = Option::from_java(&env, accountToken);
+ match daemon_interface.get_device() {
+ Ok(key) => key.into_java(&env).forget(),
+ Err(error) => {
+ log::error!("{}", error.display_chain_with_msg("Failed to get device"));
+ JObject::null()
+ }
+ }
+ } else {
+ JObject::null()
+ }
+}
- if let Err(error) = daemon_interface.set_account(account) {
- log::error!("{}", error.display_chain_with_msg("Failed to set account"));
+#[no_mangle]
+#[allow(non_snake_case)]
+pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_listDevices<'env>(
+ env: JNIEnv<'env>,
+ _: JObject<'_>,
+ daemon_interface_address: jlong,
+ account_token: JString<'_>,
+) -> JObject<'env> {
+ let env = JnixEnv::from(env);
+
+ if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
+ let token = String::from_java(&env, account_token);
+ match daemon_interface.list_devices(token) {
+ Ok(key) => key.into_java(&env).forget(),
+ Err(error) => {
+ log::error!("{}", error.display_chain_with_msg("Failed to list devices"));
+ JObject::null()
+ }
}
+ } else {
+ JObject::null()
}
}
#[no_mangle]
#[allow(non_snake_case)]
+pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_removeDevice<'env>(
+ env: JNIEnv<'env>,
+ _: JObject<'_>,
+ daemon_interface_address: jlong,
+ account_token: JString<'_>,
+ device_id: JString<'_>,
+) -> JObject<'env> {
+ let env = JnixEnv::from(env);
+
+ let result = if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) {
+ let token = String::from_java(&env, account_token);
+ let device_id = String::from_java(&env, device_id);
+ let raw_result = daemon_interface.remove_device(token, device_id);
+
+ if let Err(ref error) = &raw_result {
+ log_request_error("remove device", error);
+ }
+
+ RemoveDeviceResult::from(raw_result)
+ } else {
+ RemoveDeviceResult::OtherError
+ };
+
+ result.into_java(&env).forget()
+}
+
+#[no_mangle]
+#[allow(non_snake_case)]
pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_setAllowLan(
env: JNIEnv<'_>,
_: JObject<'_>,
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index e690557aae..21eb6ab512 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -44,19 +44,24 @@ service ManagementService {
// Account management
rpc CreateNewAccount(google.protobuf.Empty) returns (google.protobuf.StringValue) {}
- rpc SetAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc LoginAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc LogoutAccount(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetAccountData(google.protobuf.StringValue) returns (AccountData) {}
rpc GetAccountHistory(google.protobuf.Empty) returns (AccountHistory) {}
rpc ClearAccountHistory(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetWwwAuthToken(google.protobuf.Empty) returns (google.protobuf.StringValue) {}
rpc SubmitVoucher(google.protobuf.StringValue) returns (VoucherSubmission) {}
+ // Device management
+ rpc GetDevice(google.protobuf.Empty) returns (DeviceConfig) {}
+ rpc ListDevices(google.protobuf.StringValue) returns (DeviceList) {}
+ rpc RemoveDevice(DeviceRemoval) returns (google.protobuf.Empty) {}
+
// WireGuard key management
rpc SetWireguardRotationInterval(google.protobuf.Duration) returns (google.protobuf.Empty) {}
rpc ResetWireguardRotationInterval(google.protobuf.Empty) returns (google.protobuf.Empty) {}
- rpc GenerateWireguardKey(google.protobuf.Empty) returns (KeygenEvent) {}
+ rpc RotateWireguardKey(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {}
- rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {}
// Split tunneling (Linux)
rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {}
@@ -265,16 +270,15 @@ message BridgeState {
}
message Settings {
- string account_token = 1;
- RelaySettings relay_settings = 2;
- BridgeSettings bridge_settings = 3;
- BridgeState bridge_state = 4;
- bool allow_lan = 5;
- bool block_when_disconnected = 6;
- bool auto_connect = 7;
- TunnelOptions tunnel_options = 8;
- bool show_beta_releases = 9;
- SplitTunnelSettings split_tunnel = 10;
+ RelaySettings relay_settings = 1;
+ BridgeSettings bridge_settings = 2;
+ BridgeState bridge_state = 3;
+ bool allow_lan = 4;
+ bool block_when_disconnected = 5;
+ bool auto_connect = 6;
+ TunnelOptions tunnel_options = 7;
+ bool show_beta_releases = 8;
+ SplitTunnelSettings split_tunnel = 9;
}
message SplitTunnelSettings {
@@ -423,16 +427,6 @@ message PublicKey {
google.protobuf.Timestamp created = 2;
}
-message KeygenEvent {
- enum KeygenEvent {
- NEW_KEY = 0;
- TOO_MANY_KEYS = 1;
- GENERATION_FAILURE = 2;
- }
- KeygenEvent event = 1;
- PublicKey new_key = 2;
-}
-
message AppVersionInfo {
bool supported = 1;
string latest_stable = 2;
@@ -521,10 +515,47 @@ message DaemonEvent {
Settings settings = 2;
RelayList relay_list = 3;
AppVersionInfo version_info = 4;
- KeygenEvent key_event = 5;
+ DeviceEvent device = 5;
+ RemoveDeviceEvent remove_device = 6;
}
}
message RelayList {
repeated RelayListCountry countries = 1;
}
+
+message DeviceConfig {
+ string account_token = 1;
+ Device device = 2;
+}
+
+message Device {
+ string id = 1;
+ string name = 2;
+ bytes pubkey = 3;
+ repeated DevicePort ports = 4;
+}
+
+message DevicePort {
+ string id = 1;
+}
+
+message DeviceList {
+ repeated Device devices = 1;
+}
+
+message DeviceRemoval {
+ string account_token = 1;
+ string device_id = 2;
+}
+
+message DeviceEvent {
+ DeviceConfig device = 1;
+ bool remote = 2;
+}
+
+message RemoveDeviceEvent {
+ string account_token = 1;
+ Device removed_device = 2;
+ repeated Device new_device_list = 3;
+}
diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs
index 5398927569..c76ada98d1 100644
--- a/mullvad-management-interface/src/types.rs
+++ b/mullvad-management-interface/src/types.rs
@@ -2,7 +2,7 @@ pub use prost_types::{Duration, Timestamp};
use mullvad_types::relay_constraints::Constraint;
use std::convert::TryFrom;
-use talpid_types::ErrorExt;
+use talpid_types::{net::wireguard, ErrorExt};
tonic::include_proto!("mullvad_daemon.management_interface");
@@ -197,22 +197,58 @@ impl From<mullvad_types::states::TunnelState> for TunnelState {
}
}
-impl From<mullvad_types::wireguard::KeygenEvent> for KeygenEvent {
- fn from(event: mullvad_types::wireguard::KeygenEvent) -> Self {
- use keygen_event::KeygenEvent as Event;
- use mullvad_types::wireguard::KeygenEvent as MullvadEvent;
+impl From<mullvad_types::device::Device> for Device {
+ fn from(device: mullvad_types::device::Device) -> Self {
+ Device {
+ id: device.id,
+ name: device.name,
+ pubkey: device.pubkey.as_bytes().to_vec(),
+ ports: device.ports.into_iter().map(DevicePort::from).collect(),
+ }
+ }
+}
- KeygenEvent {
- event: match event {
- MullvadEvent::NewKey(_) => i32::from(Event::NewKey),
- MullvadEvent::TooManyKeys => i32::from(Event::TooManyKeys),
- MullvadEvent::GenerationFailure => i32::from(Event::GenerationFailure),
- },
- new_key: if let MullvadEvent::NewKey(key) = event {
- Some(PublicKey::from(key))
- } else {
- None
- },
+impl From<mullvad_types::device::DevicePort> for DevicePort {
+ fn from(port: mullvad_types::device::DevicePort) -> Self {
+ DevicePort { id: port.id }
+ }
+}
+
+impl From<mullvad_types::device::DeviceEvent> for DeviceEvent {
+ fn from(event: mullvad_types::device::DeviceEvent) -> Self {
+ DeviceEvent {
+ device: event.device.map(|config| DeviceConfig {
+ account_token: config.token,
+ device: Some(Device::from(config.device)),
+ }),
+ remote: event.remote,
+ }
+ }
+}
+
+impl From<mullvad_types::device::RemoveDeviceEvent> for RemoveDeviceEvent {
+ fn from(event: mullvad_types::device::RemoveDeviceEvent) -> Self {
+ RemoveDeviceEvent {
+ account_token: event.account_token,
+ removed_device: Some(Device::from(event.removed_device)),
+ new_device_list: event.new_devices.into_iter().map(Device::from).collect(),
+ }
+ }
+}
+
+impl From<mullvad_types::device::DeviceConfig> for DeviceConfig {
+ fn from(device: mullvad_types::device::DeviceConfig) -> Self {
+ DeviceConfig {
+ account_token: device.token,
+ device: Some(Device::from(device.device)),
+ }
+ }
+}
+
+impl From<Vec<mullvad_types::device::Device>> for DeviceList {
+ fn from(devices: Vec<mullvad_types::device::Device>) -> Self {
+ DeviceList {
+ devices: devices.into_iter().map(Device::from).collect(),
}
}
}
@@ -387,7 +423,6 @@ impl From<&mullvad_types::settings::Settings> for Settings {
let split_tunnel = None;
Self {
- account_token: settings.get_account_token().unwrap_or_default(),
relay_settings: Some(RelaySettings::from(settings.get_relay_settings())),
bridge_settings: Some(BridgeSettings::from(settings.bridge_settings.clone())),
bridge_state: Some(BridgeState::from(settings.get_bridge_state())),
@@ -689,6 +724,29 @@ pub enum FromProtobufTypeError {
InvalidArgument(&'static str),
}
+impl TryFrom<Device> for mullvad_types::device::Device {
+ type Error = FromProtobufTypeError;
+
+ fn try_from(device: Device) -> Result<Self, Self::Error> {
+ Ok(mullvad_types::device::Device {
+ id: device.id,
+ name: device.name,
+ pubkey: bytes_to_pubkey(&device.pubkey)?,
+ ports: device
+ .ports
+ .into_iter()
+ .map(mullvad_types::device::DevicePort::from)
+ .collect(),
+ })
+ }
+}
+
+impl From<DevicePort> for mullvad_types::device::DevicePort {
+ fn from(port: DevicePort) -> Self {
+ mullvad_types::device::DevicePort { id: port.id }
+ }
+}
+
impl TryFrom<&WireguardConstraints> for mullvad_types::relay_constraints::WireguardConstraints {
type Error = FromProtobufTypeError;
@@ -929,7 +987,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig {
type Error = FromProtobufTypeError;
fn try_from(config: ConnectionConfig) -> Result<mullvad_types::ConnectionConfig, Self::Error> {
- use talpid_types::net::{self, openvpn, wireguard};
+ use talpid_types::net::{self, openvpn};
let config = config.config.ok_or(FromProtobufTypeError::InvalidArgument(
"missing connection config",
@@ -974,14 +1032,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig {
"missing peer config",
))?;
- // Copy the public key to an array
- if peer.public_key.len() != 32 {
- return Err(FromProtobufTypeError::InvalidArgument("invalid public key"));
- }
-
- let mut public_key = [0; 32];
- let buffer = &peer.public_key[..public_key.len()];
- public_key.copy_from_slice(buffer);
+ let public_key = bytes_to_pubkey(&peer.public_key)?;
let ipv4_gateway = match config.ipv4_gateway.parse() {
Ok(address) => address,
@@ -1037,7 +1088,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig {
addresses: tunnel_addresses,
},
peer: wireguard::PeerConfig {
- public_key: wireguard::PublicKey::from(public_key),
+ public_key,
allowed_ips,
endpoint,
protocol: try_transport_protocol_from_i32(peer.protocol)?,
@@ -1052,6 +1103,15 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig {
}
}
+fn bytes_to_pubkey(bytes: &[u8]) -> Result<wireguard::PublicKey, FromProtobufTypeError> {
+ if bytes.len() != 32 {
+ return Err(FromProtobufTypeError::InvalidArgument("invalid public key"));
+ }
+ let mut public_key = [0; 32];
+ public_key.copy_from_slice(&bytes[..32]);
+ Ok(wireguard::PublicKey::from(public_key))
+}
+
impl From<RelayLocation> for Constraint<mullvad_types::relay_constraints::LocationConstraint> {
fn from(location: RelayLocation) -> Self {
use mullvad_types::relay_constraints::LocationConstraint;
diff --git a/mullvad-rpc/src/access.rs b/mullvad-rpc/src/access.rs
new file mode 100644
index 0000000000..d95a5319c2
--- /dev/null
+++ b/mullvad-rpc/src/access.rs
@@ -0,0 +1,110 @@
+use crate::{
+ rest,
+ rest::{RequestFactory, RequestServiceHandle},
+};
+use hyper::StatusCode;
+use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
+use std::{
+ collections::HashMap,
+ sync::{Arc, Mutex},
+};
+use talpid_types::ErrorExt;
+
+pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1";
+
+#[derive(Clone)]
+pub struct AccessTokenProxy {
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
+}
+
+impl AccessTokenProxy {
+ pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
+ Self {
+ service,
+ factory,
+ access_from_account: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ /// Obtain access token for an account, requesting a new one from the API if necessary.
+ pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
+ let existing_token = {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .get(account.as_str())
+ .cloned()
+ };
+ if let Some(access_token) = existing_token {
+ if access_token.is_expired() {
+ log::debug!("Replacing expired access token");
+ return self.request_new_token(account.clone()).await;
+ }
+ log::trace!("Using stored access token");
+ return Ok(access_token.access_token.clone());
+ }
+ self.request_new_token(account.clone()).await
+ }
+
+ /// Remove an access token if the API response calls for it.
+ pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
+ if let Err(rest::Error::ApiError(_status, code)) = response {
+ if code == crate::INVALID_ACCESS_TOKEN {
+ log::debug!("Dropping invalid access token");
+ self.remove_token(account);
+ }
+ }
+ }
+
+ /// Removes a stored access token.
+ fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .remove(account)
+ .map(|v| v.access_token)
+ }
+
+ async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
+ log::debug!("Fetching access token for an account");
+ let access_token = self
+ .fetch_access_token(account.clone())
+ .await
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain access token")
+ );
+ error
+ })?;
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .insert(account, access_token.clone());
+ Ok(access_token.access_token)
+ }
+
+ async fn fetch_access_token(
+ &self,
+ account_token: AccountToken,
+ ) -> Result<AccessTokenData, rest::Error> {
+ #[derive(serde::Serialize)]
+ struct AccessTokenRequest {
+ account_number: String,
+ }
+ let request = AccessTokenRequest {
+ account_number: account_token,
+ };
+
+ let service = self.service.clone();
+
+ let rest_request = self
+ .factory
+ .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?;
+ let response = service.request(rest_request).await?;
+ let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(response).await
+ }
+}
diff --git a/mullvad-rpc/src/availability.rs b/mullvad-rpc/src/availability.rs
index da8d624e80..2cf40cf53b 100644
--- a/mullvad-rpc/src/availability.rs
+++ b/mullvad-rpc/src/availability.rs
@@ -122,10 +122,26 @@ impl ApiAvailabilityHandle {
self.wait_for_state(|state| !state.is_suspended())
}
+ pub fn when_bg_resumes<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> {
+ let wait_task = self.wait_for_state(|state| !state.is_background_paused());
+ async move {
+ let _ = wait_task.await;
+ task.await
+ }
+ }
+
pub fn wait_background(&self) -> impl Future<Output = Result<(), Error>> {
self.wait_for_state(|state| !state.is_background_paused())
}
+ pub fn when_online<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> {
+ let wait_task = self.wait_for_state(|state| !state.is_offline());
+ async move {
+ let _ = wait_task.await;
+ task.await
+ }
+ }
+
pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> {
self.wait_for_state(|state| !state.is_offline())
}
diff --git a/mullvad-rpc/src/device.rs b/mullvad-rpc/src/device.rs
new file mode 100644
index 0000000000..de572aa20d
--- /dev/null
+++ b/mullvad-rpc/src/device.rs
@@ -0,0 +1,196 @@
+use http::{Method, StatusCode};
+use mullvad_types::{
+ account::AccountToken,
+ device::{Device, DeviceId, DeviceName, DevicePort},
+};
+use std::future::Future;
+use talpid_types::net::wireguard;
+
+use crate::rest;
+
+use super::ACCOUNTS_URL_PREFIX;
+
+#[derive(Clone)]
+pub struct DevicesProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+#[derive(serde::Deserialize)]
+struct DeviceResponse {
+ id: DeviceId,
+ name: DeviceName,
+ pubkey: wireguard::PublicKey,
+ ipv4_address: ipnetwork::Ipv4Network,
+ ipv6_address: ipnetwork::Ipv6Network,
+ ports: Vec<DevicePort>,
+}
+
+impl DevicesProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn create(
+ &self,
+ account: AccountToken,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct DeviceSubmission {
+ pubkey: wireguard::PublicKey,
+ }
+
+ let submission = DeviceSubmission { pubkey };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account)),
+ &[StatusCode::CREATED],
+ )
+ .await;
+
+ let response: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ id,
+ name,
+ pubkey,
+ ipv4_address,
+ ipv6_address,
+ ports,
+ ..
+ } = response;
+
+ Ok((
+ Device {
+ id,
+ name,
+ pubkey,
+ ports,
+ },
+ mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ },
+ ))
+ }
+ }
+
+ pub fn get(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<Device, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn list(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn remove(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<(), rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::DELETE,
+ Some((access_proxy, account)),
+ &[StatusCode::NO_CONTENT],
+ )
+ .await;
+
+ response?;
+ Ok(())
+ }
+ }
+
+ pub fn replace_wg_key(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct RotateDevicePubkey {
+ pubkey: wireguard::PublicKey,
+ }
+ let req_body = RotateDevicePubkey { pubkey };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id),
+ Method::PUT,
+ &req_body,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ ipv4_address,
+ ipv6_address,
+ ..
+ } = updated_device;
+ Ok(mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ })
+ }
+ }
+}
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 614aa3bdb6..f93d27262a 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -16,7 +16,7 @@ use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
path::Path,
};
-use talpid_types::{net::wireguard, ErrorExt};
+use talpid_types::ErrorExt;
pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
@@ -29,9 +29,12 @@ mod tls_stream;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
+mod access;
mod address_cache;
+pub mod device;
mod relay_list;
pub use address_cache::AddressCache;
+pub use device::DevicesProxy;
pub use hyper::StatusCode;
pub use relay_list::RelayListProxy;
@@ -44,11 +47,17 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER";
/// Error code returned by the Mullvad API if the account token is invalid.
pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT";
-/// Error code returned by the Mullvad API if the account token is missing or invalid.
-pub const INVALID_AUTH: &str = "INVALID_AUTH";
+/// Error code returned by the Mullvad API if the access token is invalid.
+pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN";
+
+pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED";
+pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE";
pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt";
+const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1";
+const APP_URL_PREFIX: &str = "app/v1";
+
lazy_static::lazy_static! {
static ref API: ApiEndpoint = ApiEndpoint::get();
}
@@ -257,7 +266,7 @@ impl MullvadRpcRuntime {
self.socket_bypass_tx.clone(),
)
.await;
- let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned()));
+ let factory = rest::RequestFactory::new(API.host.clone(), None);
rest::MullvadRestHandle::new(
service,
@@ -295,8 +304,8 @@ pub struct AccountsProxy {
#[derive(serde::Deserialize)]
struct AccountResponse {
- token: AccountToken,
- expires: DateTime<Utc>,
+ number: AccountToken,
+ expiry: DateTime<Utc>,
}
impl AccountsProxy {
@@ -309,18 +318,21 @@ impl AccountsProxy {
account: AccountToken,
) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> {
let service = self.handle.service.clone();
-
- let response = rest::send_request(
- &self.handle.factory,
- service,
- "/v1/me",
- Method::GET,
- Some(account),
- &[StatusCode::OK],
- );
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
async move {
- let account: AccountResponse = rest::deserialize_body(response.await?).await?;
- Ok(account.expires)
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let account: AccountResponse = rest::deserialize_body(response?).await?;
+ Ok(account.expiry)
}
}
@@ -329,7 +341,7 @@ impl AccountsProxy {
let response = rest::send_request(
&self.handle.factory,
service,
- "/v1/accounts",
+ &format!("{}/accounts", ACCOUNTS_URL_PREFIX),
Method::POST,
None,
&[StatusCode::CREATED],
@@ -337,7 +349,7 @@ impl AccountsProxy {
async move {
let account: AccountResponse = rest::deserialize_body(response.await?).await?;
- Ok(account.token)
+ Ok(account.number)
}
}
@@ -352,18 +364,23 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
let submission = VoucherSubmission { voucher_code };
- let response = rest::post_request_with_json(
- &self.handle.factory,
- service,
- "/v1/submit-voucher",
- &submission,
- Some(account_token),
- &[StatusCode::OK],
- );
-
- async move { rest::deserialize_body(response.await?).await }
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/submit-voucher", APP_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account_token)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
}
pub fn get_www_auth_token(
@@ -376,17 +393,20 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
- let response = rest::send_request(
- &self.handle.factory,
- service,
- "/v1/www-auth-token",
- Method::POST,
- Some(account),
- &[StatusCode::OK],
- );
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
async move {
- let response: AuthTokenResponse = rest::deserialize_body(response.await?).await?;
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/www-auth-token", APP_URL_PREFIX),
+ Method::POST,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
Ok(response.auth_token)
}
}
@@ -425,10 +445,11 @@ impl ProblemReportProxy {
let service = self.handle.service.clone();
- let request = rest::post_request_with_json(
+ let request = rest::send_json_request(
&self.handle.factory,
service,
- "/v1/problem-report",
+ &format!("{}/problem-report", APP_URL_PREFIX),
+ Method::POST,
&report,
None,
&[StatusCode::NO_CONTENT],
@@ -467,7 +488,7 @@ impl AppVersionProxy {
) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> {
let service = self.handle.service.clone();
- let path = format!("/v1/releases/{}/{}", platform, app_version);
+ let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version);
let request = self.handle.factory.request(&path, Method::GET);
async move {
@@ -481,123 +502,6 @@ impl AppVersionProxy {
}
}
-/// Error code for when an account has too many keys. Returned when trying to push a new key.
-pub const KEY_LIMIT_REACHED: &str = "KEY_LIMIT_REACHED";
-#[derive(Clone)]
-pub struct WireguardKeyProxy {
- handle: rest::MullvadRestHandle,
-}
-
-impl WireguardKeyProxy {
- pub fn new(handle: rest::MullvadRestHandle) -> Self {
- Self { handle }
- }
-
- pub fn push_wg_key(
- &mut self,
- account_token: AccountToken,
- public_key: wireguard::PublicKey,
- timeout: Option<std::time::Duration>,
- ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + 'static
- {
- #[derive(serde::Serialize)]
- struct PublishRequest {
- pubkey: wireguard::PublicKey,
- }
-
- let service = self.handle.service.clone();
- let body = PublishRequest { pubkey: public_key };
-
- let request = self.handle.factory.post_json(&"/v1/wireguard-keys", &body);
- async move {
- let mut request = request?;
- if let Some(timeout) = timeout {
- request.set_timeout(timeout);
- }
- request.set_auth(Some(account_token))?;
- let response = service.request(request).await?;
- rest::deserialize_body(
- rest::parse_rest_response(response, &[StatusCode::CREATED]).await?,
- )
- .await
- }
- }
-
- pub async fn replace_wg_key(
- &mut self,
- account_token: AccountToken,
- old: wireguard::PublicKey,
- new: wireguard::PublicKey,
- ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> {
- #[derive(serde::Serialize)]
- struct ReplacementRequest {
- old: wireguard::PublicKey,
- new: wireguard::PublicKey,
- }
-
- let service = self.handle.service.clone();
- let body = ReplacementRequest { old, new };
-
- let response = rest::post_request_with_json(
- &self.handle.factory,
- service,
- &"/v1/replace-wireguard-key",
- &body,
- Some(account_token),
- [StatusCode::CREATED, StatusCode::OK].as_slice(),
- )
- .await?;
-
- rest::deserialize_body(response).await
- }
-
- pub async fn get_wireguard_key(
- &mut self,
- account_token: AccountToken,
- key: &wireguard::PublicKey,
- ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> {
- let service = self.handle.service.clone();
-
- let response = rest::send_request(
- &self.handle.factory,
- service,
- &format!(
- "/v1/wireguard-keys/{}",
- urlencoding::encode(&key.to_base64())
- ),
- Method::GET,
- Some(account_token),
- &[StatusCode::OK],
- )
- .await?;
-
- rest::deserialize_body(response).await
- }
-
- pub fn remove_wireguard_key(
- &mut self,
- account_token: AccountToken,
- key: wireguard::PublicKey,
- ) -> impl Future<Output = Result<(), rest::Error>> {
- let service = self.handle.service.clone();
- let future = rest::send_request(
- &self.handle.factory,
- service,
- &format!(
- "/v1/wireguard-keys/{}",
- urlencoding::encode(&key.to_base64())
- ),
- Method::DELETE,
- Some(account_token),
- &[StatusCode::NO_CONTENT],
- );
- async move {
- let _ = future.await?;
- Ok(())
- }
- }
-}
-
#[derive(Clone)]
pub struct ApiProxy {
handle: rest::MullvadRestHandle,
@@ -614,7 +518,7 @@ impl ApiProxy {
let response = rest::send_request(
&self.handle.factory,
service,
- "/v1/api-addrs",
+ &format!("{}/api-addrs", APP_URL_PREFIX),
Method::GET,
None,
&[StatusCode::OK],
diff --git a/mullvad-rpc/src/relay_list.rs b/mullvad-rpc/src/relay_list.rs
index f1ed2217fd..5a8a01836f 100644
--- a/mullvad-rpc/src/relay_list.rs
+++ b/mullvad-rpc/src/relay_list.rs
@@ -13,7 +13,7 @@ use std::{
time::Duration,
};
-/// Fetches relay list from <https://api.mullvad.net/v1/relays>
+/// Fetches relay list from https://api.mullvad.net/app/v1/relays
#[derive(Clone)]
pub struct RelayListProxy {
handle: rest::MullvadRestHandle,
@@ -33,7 +33,7 @@ impl RelayListProxy {
etag: Option<String>,
) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> {
let service = self.handle.service.clone();
- let request = self.handle.factory.request("/v1/relays", Method::GET);
+ let request = self.handle.factory.request("app/v1/relays", Method::GET);
let future = async move {
let mut request = request?;
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 17362cce05..6f36a2a096 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,6 +1,7 @@
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
+ access::AccessTokenProxy,
address_cache::AddressCache,
availability::ApiAvailabilityHandle,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
@@ -17,6 +18,7 @@ use hyper::{
header::{self, HeaderValue},
Method, Uri,
};
+use mullvad_types::account::AccountToken;
use std::{
future::Future,
str::FromStr,
@@ -29,6 +31,8 @@ pub use hyper::StatusCode;
pub type Request = hyper::Request<hyper::Body>;
pub type Response = hyper::Response<hyper::Body>;
+const USER_AGENT: &str = "mullvad-app";
+
const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60);
const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60);
const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
@@ -285,6 +289,7 @@ impl RestRequest {
let mut builder = http::request::Builder::new()
.method(Method::GET)
+ .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"));
if let Some(host) = uri.host() {
builder = builder.header(header::HOST, HeaderValue::from_str(&host)?);
@@ -302,11 +307,11 @@ impl RestRequest {
})
}
- /// Set the auth header with the following format: `Token $auth`.
+ /// Set the auth header with the following format: `Bearer $auth`.
pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> {
let header = match auth {
Some(auth) => Some(
- HeaderValue::from_str(&format!("Token {}", auth))
+ HeaderValue::from_str(&format!("Bearer {}", auth))
.map_err(Error::InvalidHeaderError)?,
),
None => None,
@@ -399,7 +404,16 @@ impl RequestFactory {
}
pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
- let mut request = self.hyper_request(path, Method::POST)?;
+ self.json_request(Method::POST, path, body)
+ }
+
+ fn json_request<S: serde::Serialize>(
+ &self,
+ method: Method,
+ path: &str,
+ body: &S,
+ ) -> Result<RestRequest> {
+ let mut request = self.hyper_request(path, method)?;
let json_body = serde_json::to_string(&body)?;
let body_length = json_body.as_bytes().len() as u64;
@@ -429,6 +443,7 @@ impl RequestFactory {
let request = http::request::Builder::new()
.method(method)
.uri(uri)
+ .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
.header(header::HOST, self.hostname.clone());
@@ -468,44 +483,64 @@ pub fn send_request(
service: RequestServiceHandle,
uri: &str,
method: Method,
- auth: Option<String>,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
let request = factory.request(uri, method);
async move {
let mut request = request?;
- request.set_auth(auth)?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
}
}
-pub fn post_request_with_json<B: serde::Serialize>(
+pub fn send_json_request<B: serde::Serialize>(
factory: &RequestFactory,
service: RequestServiceHandle,
uri: &str,
+ method: Method,
body: &B,
- auth: Option<String>,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
- let request = factory.post_json(uri, body);
+ let request = factory.json_request(method, uri, body);
async move {
let mut request = request?;
- request.set_auth(auth)?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
}
}
-pub async fn deserialize_body<T: serde::de::DeserializeOwned>(mut response: Response) -> Result<T> {
- let body_length: usize = response
- .headers()
- .get(header::CONTENT_LENGTH)
- .and_then(|header_value| header_value.to_str().ok())
- .and_then(|length| length.parse::<usize>().ok())
- .unwrap_or(0);
+pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> {
+ let body_length = get_body_length(&response);
+ deserialize_body_inner(response, body_length).await
+}
+async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
+ mut response: Response,
+ body_length: usize,
+) -> Result<T> {
let mut body: Vec<u8> = Vec::with_capacity(body_length);
while let Some(chunk) = response.body_mut().next().await {
body.extend(&chunk?);
@@ -514,6 +549,15 @@ pub async fn deserialize_body<T: serde::de::DeserializeOwned>(mut response: Resp
serde_json::from_slice(&body).map_err(Error::DeserializeError)
}
+fn get_body_length(response: &Response) -> usize {
+ response
+ .headers()
+ .get(header::CONTENT_LENGTH)
+ .and_then(|header_value| header_value.to_str().ok())
+ .and_then(|length| length.parse::<usize>().ok())
+ .unwrap_or(0)
+}
+
pub async fn parse_rest_response(
response: Response,
expected_statuses: &'static [hyper::StatusCode],
@@ -537,23 +581,27 @@ pub async fn parse_rest_response(
}
pub async fn handle_error_response<T>(response: Response) -> Result<T> {
- let error_message = match response.status() {
+ let status = response.status();
+ let error_message = match status {
hyper::StatusCode::NOT_FOUND => "Not found",
hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed",
- status => {
- let err: ErrorResponse = deserialize_body(response).await?;
-
- return Err(Error::ApiError(status, err.code));
- }
+ status => match get_body_length(&response) {
+ 0 => status.canonical_reason().unwrap_or("Unexpected error"),
+ body_length => {
+ let err: ErrorResponse = deserialize_body_inner(response, body_length).await?;
+ return Err(Error::ApiError(status, err.code));
+ }
+ },
};
- Err(Error::ApiError(response.status(), error_message.to_owned()))
+ Err(Error::ApiError(status, error_message.to_owned()))
}
#[derive(Clone)]
pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
- availability: ApiAvailabilityHandle,
+ pub availability: ApiAvailabilityHandle,
+ pub token_store: AccessTokenProxy,
}
impl MullvadRestHandle {
@@ -563,10 +611,13 @@ impl MullvadRestHandle {
address_cache: AddressCache,
availability: ApiAvailabilityHandle,
) -> Self {
+ let token_store = AccessTokenProxy::new(service.clone(), factory.clone());
+
let handle = Self {
service,
factory,
availability,
+ token_store,
};
if !super::API.disable_address_cache {
handle.spawn_api_address_fetcher(address_cache);
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index e65b1278f8..e9289b115a 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -17,7 +17,7 @@ lazy_static::lazy_static! {
}
const KEY_RETRY_INTERVAL: Duration = Duration::ZERO;
-const KEY_RETRY_MAX_RETRIES: usize = 2;
+const KEY_RETRY_MAX_RETRIES: usize = 4;
#[repr(i32)]
enum ExitStatus {
@@ -63,8 +63,8 @@ pub enum Error {
#[error(display = "Failed to initialize mullvad RPC runtime")]
RpcInitializationError(#[error(source)] mullvad_rpc::Error),
- #[error(display = "Failed to remove WireGuard key for account")]
- RemoveKeyError(#[error(source)] mullvad_rpc::rest::Error),
+ #[error(display = "Failed to remove device from account")]
+ RemoveDeviceError(#[error(source)] mullvad_rpc::rest::Error),
#[error(display = "Failed to obtain settings directory path")]
SettingsPathError(#[error(source)] SettingsPathErrorType),
@@ -72,8 +72,11 @@ pub enum Error {
#[error(display = "Failed to obtain cache directory path")]
CachePathError(#[error(source)] mullvad_paths::Error),
- #[error(display = "Failed to update the settings")]
- SettingsError(#[error(source)] mullvad_daemon::settings::Error),
+ #[error(display = "Failed to read the device cache")]
+ ReadDeviceCacheError(#[error(source)] mullvad_daemon::device::Error),
+
+ #[error(display = "Failed to write the device cache")]
+ WriteDeviceCacheError(#[error(source)] mullvad_daemon::device::Error),
#[error(display = "Cannot parse the version string")]
ParseVersionStringError,
@@ -87,7 +90,7 @@ async fn main() {
App::new("prepare-restart")
.about("Move a running daemon into a blocking state and save its target state"),
App::new("reset-firewall").about("Remove any firewall rules introduced by the daemon"),
- App::new("remove-wireguard-key").about("Removes the WireGuard key from the active account"),
+ App::new("remove-device").about("Remove the current device from the active account"),
App::new("is-older-version")
.about("Checks whether the given version is older than the current version")
.arg(
@@ -110,7 +113,7 @@ async fn main() {
let result = match matches.subcommand() {
Some(("prepare-restart", _)) => prepare_restart().await,
Some(("reset-firewall", _)) => reset_firewall().await,
- Some(("remove-wireguard-key", _)) => remove_wireguard_key().await,
+ Some(("remove-device", _)) => remove_device().await,
Some(("is-older-version", sub_matches)) => {
let old_version = sub_matches.value_of("OLDVERSION").unwrap();
match is_older_version(old_version).await {
@@ -159,43 +162,42 @@ async fn reset_firewall() -> Result<(), Error> {
.map_err(Error::FirewallError)
}
-async fn remove_wireguard_key() -> Result<(), Error> {
+async fn remove_device() -> Result<(), Error> {
let (cache_path, settings_path) = get_paths()?;
- let mut settings = mullvad_daemon::settings::SettingsPersister::load(&settings_path).await;
+ let (cacher, data) = mullvad_daemon::device::DeviceCacher::new(&settings_path)
+ .await
+ .map_err(Error::ReadDeviceCacheError)?;
+ if let Some(device) = data {
+ let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false)
+ .await
+ .map_err(Error::RpcInitializationError)?;
- if let Some(token) = settings.get_account_token() {
- if let Some(wg_data) = settings.get_wireguard() {
- let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false)
- .await
- .map_err(Error::RpcInitializationError)?;
- let mut key_proxy = mullvad_rpc::WireguardKeyProxy::new(
- rpc_runtime
- .mullvad_rest_handle(
- ApiConnectionMode::try_from_cache(&cache_path)
- .await
- .into_repeat(),
- |_| async { true },
- )
- .await,
- );
- retry_future_n(
- move || {
- key_proxy.remove_wireguard_key(token.clone(), wg_data.private_key.public_key())
- },
- move |result| match result {
- Err(error) => error.is_network_error(),
- _ => false,
- },
- constant_interval(KEY_RETRY_INTERVAL),
- KEY_RETRY_MAX_RETRIES,
- )
+ let proxy = mullvad_rpc::DevicesProxy::new(
+ rpc_runtime
+ .mullvad_rest_handle(
+ ApiConnectionMode::try_from_cache(&cache_path)
+ .await
+ .into_repeat(),
+ |_| async { true },
+ )
+ .await,
+ );
+ retry_future_n(
+ move || proxy.remove(device.token.clone(), device.device.id.clone()),
+ move |result| match result {
+ Err(error) => error.is_network_error(),
+ _ => false,
+ },
+ constant_interval(KEY_RETRY_INTERVAL),
+ KEY_RETRY_MAX_RETRIES,
+ )
+ .await
+ .map_err(Error::RemoveDeviceError)?;
+
+ cacher
+ .remove()
.await
- .map_err(Error::RemoveKeyError)?;
- settings
- .set_wireguard(None)
- .await
- .map_err(Error::SettingsError)?;
- }
+ .map_err(Error::WriteDeviceCacheError)?;
}
Ok(())
diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs
index b5479640e6..16f6a963f2 100644
--- a/mullvad-types/src/account.rs
+++ b/mullvad-types/src/account.rs
@@ -3,9 +3,12 @@ use chrono::{offset::Utc, DateTime};
use jnix::IntoJava;
use serde::{Deserialize, Serialize};
-/// Identifier used to authenticate or identify a Mullvad account.
+/// Identifier used to identify a Mullvad account.
pub type AccountToken = String;
+/// Identifier used to authenticate a Mullvad account.
+pub type AccessToken = String;
+
/// Account expiration info returned by the API via `/v1/me`.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
@@ -18,7 +21,7 @@ pub struct AccountData {
impl AccountData {
/// Return true if the account has no time left.
pub fn is_expired(&self) -> bool {
- self.expiry >= Utc::now()
+ Utc::now() >= self.expiry
}
}
@@ -35,3 +38,17 @@ pub struct VoucherSubmission {
#[cfg_attr(target_os = "android", jnix(map = "|expiry| expiry.to_string()"))]
pub new_expiry: DateTime<Utc>,
}
+
+/// Token used for authentication in the API.
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+pub struct AccessTokenData {
+ pub access_token: AccessToken,
+ pub expiry: DateTime<Utc>,
+}
+
+impl AccessTokenData {
+ /// Return true if the token is no longer valid.
+ pub fn is_expired(&self) -> bool {
+ Utc::now() >= self.expiry
+ }
+}
diff --git a/mullvad-types/src/device.rs b/mullvad-types/src/device.rs
new file mode 100644
index 0000000000..4b0123e6a9
--- /dev/null
+++ b/mullvad-types/src/device.rs
@@ -0,0 +1,142 @@
+use crate::{account::AccountToken, wireguard};
+#[cfg(target_os = "android")]
+use jnix::IntoJava;
+use serde::{Deserialize, Serialize};
+use std::fmt;
+use talpid_types::net::wireguard::PublicKey;
+
+/// UUID for a device.
+pub type DeviceId = String;
+
+/// Human-readable device identifier.
+pub type DeviceName = String;
+
+/// Contains data for a device returned by the API.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+#[cfg_attr(target_os = "android", derive(IntoJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
+pub struct Device {
+ pub id: DeviceId,
+ pub name: DeviceName,
+ #[cfg_attr(target_os = "android", jnix(map = "|key| *key.as_bytes()"))]
+ pub pubkey: PublicKey,
+ pub ports: Vec<DevicePort>,
+}
+
+impl Eq for Device {}
+
+impl Device {
+ /// Return name with each word capitalized: "Happy Seagull" instead of "happy seagull"
+ pub fn pretty_name(&self) -> String {
+ self.name
+ .split_whitespace()
+ .map(|word| {
+ let mut chars = word.chars();
+ match chars.next() {
+ None => String::new(),
+ Some(c) => c.to_uppercase().chain(chars).collect(),
+ }
+ })
+ .collect::<Vec<String>>()
+ .join(" ")
+ }
+
+ pub fn eq_id(&self, other: &Device) -> bool {
+ self.id == other.id
+ }
+}
+
+/// Ports associated with a device.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+#[cfg_attr(target_os = "android", derive(IntoJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
+pub struct DevicePort {
+ /// Port identifier.
+ pub id: String,
+}
+
+impl fmt::Display for DevicePort {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.id)
+ }
+}
+
+/// A complete device configuration.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+pub struct DeviceData {
+ pub token: AccountToken,
+ pub device: Device,
+ pub wg_data: wireguard::WireguardData,
+}
+
+impl From<DeviceData> for Device {
+ fn from(data: DeviceData) -> Device {
+ data.device
+ }
+}
+
+/// [`DeviceData`] excluding the private key.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+#[cfg_attr(target_os = "android", derive(IntoJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
+pub struct DeviceConfig {
+ pub token: AccountToken,
+ pub device: Device,
+}
+
+impl From<DeviceData> for DeviceConfig {
+ fn from(data: DeviceData) -> DeviceConfig {
+ DeviceConfig {
+ token: data.token,
+ device: data.device,
+ }
+ }
+}
+
+/// Emitted when logging in or out of an account, or when the device changes.
+#[derive(Clone, Debug)]
+#[cfg_attr(target_os = "android", derive(IntoJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
+pub struct DeviceEvent {
+ /// Device that was affected.
+ pub device: Option<DeviceConfig>,
+ /// Indicates whether the change was initiated remotely or by the daemon.
+ pub remote: bool,
+}
+
+impl DeviceEvent {
+ pub fn new(data: Option<DeviceData>, remote: bool) -> DeviceEvent {
+ DeviceEvent {
+ device: data.map(DeviceConfig::from),
+ remote,
+ }
+ }
+
+ pub fn from_device(data: DeviceData, remote: bool) -> DeviceEvent {
+ DeviceEvent {
+ device: Some(DeviceConfig {
+ token: data.token,
+ device: data.device,
+ }),
+ remote,
+ }
+ }
+
+ pub fn revoke(remote: bool) -> Self {
+ Self {
+ device: None,
+ remote,
+ }
+ }
+}
+
+/// Emitted when a device is removed using the `RemoveDevice` RPC.
+/// This is not sent by a normal logout or when it is revoked remotely.
+#[derive(Clone, Debug)]
+#[cfg_attr(target_os = "android", derive(IntoJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
+pub struct RemoveDeviceEvent {
+ pub account_token: AccountToken,
+ pub removed_device: Device,
+ pub new_devices: Vec<Device>,
+}
diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs
index e93ab2f606..6d636aceb5 100644
--- a/mullvad-types/src/lib.rs
+++ b/mullvad-types/src/lib.rs
@@ -2,6 +2,7 @@
pub mod account;
pub mod auth_failed;
+pub mod device;
pub mod endpoint;
pub mod location;
pub mod relay_constraints;
diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs
index 26a24202a5..63ccb480a2 100644
--- a/mullvad-types/src/settings/mod.rs
+++ b/mullvad-types/src/settings/mod.rs
@@ -61,9 +61,6 @@ impl Serialize for SettingsVersion {
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
pub struct Settings {
- account_token: Option<String>,
- #[cfg_attr(target_os = "android", jnix(skip))]
- wireguard: Option<wireguard::WireguardData>,
relay_settings: RelaySettings,
#[cfg_attr(target_os = "android", jnix(skip))]
pub bridge_settings: BridgeSettings,
@@ -102,8 +99,6 @@ pub struct SplitTunnelSettings {
impl Default for Settings {
fn default() -> Self {
Settings {
- account_token: None,
- wireguard: None,
relay_settings: RelaySettings::Normal(RelayConstraints {
location: Constraint::Only(LocationConstraint::Country("se".to_owned())),
..Default::default()
@@ -123,45 +118,6 @@ impl Default for Settings {
}
impl Settings {
- pub fn get_account_token(&self) -> Option<String> {
- self.account_token.clone()
- }
-
- /// Changes account number to the one given. Also saves the new settings to disk.
- /// The boolean in the Result indicates if the account token changed or not
- pub fn set_account_token(&mut self, mut account_token: Option<String>) -> bool {
- if account_token.as_ref().map(String::len) == Some(0) {
- log::debug!("Setting empty account token is treated as unsetting it");
- account_token = None;
- }
- if account_token != self.account_token {
- if account_token.is_none() {
- log::info!("Unsetting account token");
- } else if self.account_token.is_none() {
- log::info!("Setting account token");
- } else {
- log::info!("Changing account token")
- }
- self.account_token = account_token;
- true
- } else {
- false
- }
- }
-
- pub fn get_wireguard(&self) -> Option<wireguard::WireguardData> {
- self.wireguard.clone()
- }
-
- pub fn set_wireguard(&mut self, wireguard: Option<wireguard::WireguardData>) -> bool {
- if wireguard != self.wireguard {
- self.wireguard = wireguard;
- true
- } else {
- false
- }
- }
-
pub fn get_relay_settings(&self) -> RelaySettings {
self.relay_settings.clone()
}
diff --git a/mullvad-types/src/states.rs b/mullvad-types/src/states.rs
index 9d3b188db4..86d9e816fe 100644
--- a/mullvad-types/src/states.rs
+++ b/mullvad-types/src/states.rs
@@ -55,4 +55,12 @@ impl TunnelState {
_ => false,
}
}
+
+ /// Returns true if the tunnel state is in the connected state.
+ pub fn is_connected(&self) -> bool {
+ match self {
+ TunnelState::Connected { .. } => true,
+ _ => false,
+ }
+ }
}
diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs
index 2991eb1a1d..4c05f1e552 100644
--- a/mullvad-types/src/wireguard.rs
+++ b/mullvad-types/src/wireguard.rs
@@ -145,24 +145,3 @@ pub struct AssociatedAddresses {
pub ipv4_address: ipnetwork::Ipv4Network,
pub ipv6_address: ipnetwork::Ipv6Network,
}
-
-/// Event that is emitted when the daemon has finished generating a key.
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(rename_all = "snake_case")]
-#[cfg_attr(target_os = "android", derive(IntoJava))]
-#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
-pub enum KeygenEvent {
- NewKey(PublicKey),
- TooManyKeys,
- GenerationFailure,
-}
-
-impl fmt::Display for KeygenEvent {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
- match self {
- KeygenEvent::NewKey(new_key) => write!(f, "New wireguard key {}", new_key.key),
- KeygenEvent::TooManyKeys => write!(f, "Account has too many keys already"),
- KeygenEvent::GenerationFailure => write!(f, "Failed to generate new wireguard key"),
- }
- }
-}
diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs
index 050b90c81c..8c6424bc01 100644
--- a/talpid-core/src/mpsc.rs
+++ b/talpid-core/src/mpsc.rs
@@ -3,3 +3,9 @@ pub trait Sender<T> {
/// Sends an item over the underlying channel, failing only if the channel is closed.
fn send(&self, item: T) -> Result<(), ()>;
}
+
+impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> {
+ fn send(&self, content: E) -> Result<(), ()> {
+ self.unbounded_send(content).map_err(|_| ())
+ }
+}