diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-03-14 13:40:36 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-14 13:40:36 +0100 |
| commit | 6459ae7beefcc5f13eb54254dfe402dd807c62fe (patch) | |
| tree | bc03c4027aad5c47f00dfa4c1fb3584dff4d1add | |
| parent | 78dc4644a82d7b3fb904ef3cbac8a1f705f0a213 (diff) | |
| parent | 3e1271777fd7556a76abc582bd3c44356ecbd15a (diff) | |
| download | mullvadvpn-6459ae7beefcc5f13eb54254dfe402dd807c62fe.tar.xz mullvadvpn-6459ae7beefcc5f13eb54254dfe402dd807c62fe.zip | |
Merge branch 'device-api'
37 files changed, 3097 insertions, 1821 deletions
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt index 4565844daa..8470f314d7 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt @@ -45,7 +45,8 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { } fun generateWireguardKey(): KeygenEvent? { - return generateWireguardKey(daemonInterfaceAddress) + // TODO: remove + return null } fun getAccountData(accountToken: String): GetAccountDataResult { @@ -85,6 +86,7 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { } fun getWireguardKey(): PublicKey? { + // TODO: no longer needed return getWireguardKey(daemonInterfaceAddress) } @@ -97,7 +99,7 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { } fun setAccount(accountToken: String?) { - setAccount(daemonInterfaceAddress, accountToken) + // TODO: replace with login+logout } fun setAllowLan(allowLan: Boolean) { @@ -154,7 +156,6 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { private external fun connect(daemonInterfaceAddress: Long) private external fun createNewAccount(daemonInterfaceAddress: Long): String? private external fun disconnect(daemonInterfaceAddress: Long) - private external fun generateWireguardKey(daemonInterfaceAddress: Long): KeygenEvent? private external fun getAccountData( daemonInterfaceAddress: Long, accountToken: String @@ -170,7 +171,8 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { private external fun getWireguardKey(daemonInterfaceAddress: Long): PublicKey? private external fun reconnect(daemonInterfaceAddress: Long) private external fun clearAccountHistory(daemonInterfaceAddress: Long) - private external fun setAccount(daemonInterfaceAddress: Long, accountToken: String?) + private external fun loginAccount(daemonInterfaceAddress: Long, accountToken: String?) + private external fun logoutAccount(daemonInterfaceAddress: Long) private external fun setAllowLan(daemonInterfaceAddress: Long, allowLan: Boolean) private external fun setAutoConnect(daemonInterfaceAddress: Long, alwaysOn: Boolean) private external fun setDnsOptions(daemonInterfaceAddress: Long, dnsOptions: DnsOptions) @@ -190,10 +192,6 @@ class MullvadDaemon(val vpnService: MullvadVpnService) { onAppVersionInfoChange?.invoke(appVersionInfo) } - private fun notifyKeygenEvent(event: KeygenEvent) { - onKeygenEvent?.invoke(event) - } - private fun notifyRelayListEvent(relayList: RelayList) { onRelayListChange?.invoke(relayList) } diff --git a/dist-assets/linux/before-remove.sh b/dist-assets/linux/before-remove.sh index 602d3d09fc..6d4ec5262f 100644 --- a/dist-assets/linux/before-remove.sh +++ b/dist-assets/linux/before-remove.sh @@ -26,4 +26,4 @@ fi pkill -x "mullvad-gui" || true /opt/Mullvad\ VPN/resources/mullvad-setup reset-firewall || echo "Failed to reset firewall" -/opt/Mullvad\ VPN/resources/mullvad-setup remove-wireguard-key || echo "Failed to remove leftover WireGuard key" +/opt/Mullvad\ VPN/resources/mullvad-setup remove-device || echo "Failed to remove device from account" diff --git a/dist-assets/uninstall_macos.sh b/dist-assets/uninstall_macos.sh index 7833ba528f..b7f27302f2 100755 --- a/dist-assets/uninstall_macos.sh +++ b/dist-assets/uninstall_macos.sh @@ -23,7 +23,7 @@ sudo dscl . -delete /groups/mullvad-exclusion || echo "Failed to remove 'mullvad echo "Resetting firewall" sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup reset-firewall -sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup remove-wireguard-key +sudo /Applications/Mullvad\ VPN.app/Contents/Resources/mullvad-setup remove-device echo "Removing zsh shell completion symlink ..." sudo rm -f /usr/local/share/zsh/site-functions/_mullvad diff --git a/dist-assets/windows/installer.nsh b/dist-assets/windows/installer.nsh index 422d6d3169..b04fadfb2d 100644 --- a/dist-assets/windows/installer.nsh +++ b/dist-assets/windows/installer.nsh @@ -712,25 +712,25 @@ !define FirewallWarningCheck '!insertmacro "FirewallWarningCheck"' # -# RemoveWireGuardKey +# RemoveCurrentDevice # -# Remove the WireGuard key from the account, if there is one +# Remove the device from the account, if there is one # -!macro RemoveWireGuardKey +!macro RemoveCurrentDevice - log::Log "RemoveWireGuardKey()" + log::Log "RemoveCurrentDevice()" Push $0 Push $1 - nsExec::ExecToStack '"$TEMP\mullvad-setup.exe" remove-wireguard-key' + nsExec::ExecToStack '"$TEMP\mullvad-setup.exe" remove-device' Pop $0 Pop $1 ${If} $0 != ${MVSETUP_OK} - log::LogWithDetails "RemoveWireGuardKey() failed" $1 + log::LogWithDetails "RemoveCurrentDevice() failed" $1 ${Else} - log::Log "RemoveWireGuardKey() completed successfully" + log::Log "RemoveCurrentDevice() completed successfully" ${EndIf} Pop $1 @@ -738,7 +738,7 @@ !macroend -!define RemoveWireGuardKey '!insertmacro "RemoveWireGuardKey"' +!define RemoveCurrentDevice '!insertmacro "RemoveCurrentDevice"' # @@ -1170,7 +1170,7 @@ ${If} $FullUninstall == 1 ${ClearFirewallRules} - ${RemoveWireGuardKey} + ${RemoveCurrentDevice} ${ExtractWireGuard} ${RemoveWintun} diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs index 0bbbc28024..b4ef7c7f14 100644 --- a/mullvad-cli/src/cmds/account.rs +++ b/mullvad-cli/src/cmds/account.rs @@ -1,9 +1,20 @@ use crate::{new_rpc_client, Command, Error, Result}; use itertools::Itertools; -use mullvad_management_interface::{types::Timestamp, Code}; -use mullvad_types::account::AccountToken; +use mullvad_management_interface::{ + types::{self, Timestamp}, + Code, ManagementServiceClient, Status, +}; +use mullvad_types::{account::AccountToken, device::Device}; use std::io::{self, Write}; +const NOT_LOGGED_IN_ERROR: &str = "Not logged in to any account"; +const DEVICE_NOT_FOUND_ERROR: &str = "There is no such device"; +const INVALID_ACCOUNT_ERROR: &str = "The account does not exist"; +const TOO_MANY_DEVICES_ERROR: &str = + "There are too many devices on this account. Revoke one to log in"; +const ALREADY_LOGGED_IN_ERROR: &str = + "You are already logged in. Please log out before creating a new account"; + pub struct Account; #[mullvad_management_interface::async_trait] @@ -16,23 +27,55 @@ impl Command for Account { clap::App::new(self.name()) .about("Control and display information about your Mullvad account") .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::App::new("create").about("Create and log in to a new account")) .subcommand( - clap::App::new("set").about("Change account").arg( - clap::Arg::new("token") + clap::App::new("login").about("Log in to an account").arg( + clap::Arg::new("account") .help("The Mullvad account token to configure the client with") .required(false), ), ) + .subcommand(clap::App::new("logout").about("Log out of the current account")) .subcommand( clap::App::new("get") - .about("Display information about the currently configured account"), + .about("Display information about the current account") + .arg( + clap::Arg::new("verbose") + .long("verbose") + .short('v') + .help("Enables verbose output"), + ), ) .subcommand( - clap::App::new("unset").about("Removes the account number from the settings"), + clap::App::new("list-devices") + .about("List devices associated with an account") + .arg( + clap::Arg::new("account") + .help("Mullvad account number") + .long("account") + .takes_value(true), + ) + .arg( + clap::Arg::new("verbose") + .long("verbose") + .short('v') + .help("Enables verbose output"), + ), ) .subcommand( - clap::App::new("create") - .about("Creates a new account and sets it as the active one"), + clap::App::new("revoke-device") + .about("Revoke a device associated with an account") + .arg( + clap::Arg::new("account") + .help("Mullvad account number") + .long("account") + .takes_value(true), + ) + .arg( + clap::Arg::new("device") + .help("Name or ID of the device to revoke") + .required(true), + ), ) .subcommand( clap::App::new("redeem").about("Redeems a voucher").arg( @@ -44,29 +87,19 @@ impl Command for Account { } async fn run(&self, matches: &clap::ArgMatches) -> Result<()> { - if let Some(set_matches) = matches.subcommand_matches("set") { - let mut token = match set_matches.value_of("token") { - Some(token) => token.to_string(), - None => { - let mut token = String::new(); - io::stdout() - .write_all(b"Enter account token: ") - .expect("Failed to write to STDOUT"); - let _ = io::stdout().flush(); - io::stdin() - .read_line(&mut token) - .expect("Failed to read from STDIN"); - token - } - }; - token = token.split_whitespace().join("").to_string(); - self.set(Some(token)).await - } else if let Some(_matches) = matches.subcommand_matches("get") { - self.get().await - } else if let Some(_matches) = matches.subcommand_matches("unset") { - self.set(None).await - } else if let Some(_matches) = matches.subcommand_matches("create") { + if let Some(_matches) = matches.subcommand_matches("create") { self.create().await + } else if let Some(set_matches) = matches.subcommand_matches("login") { + self.login(parse_token_else_stdin(set_matches)).await + } else if let Some(_matches) = matches.subcommand_matches("logout") { + self.logout().await + } else if let Some(set_matches) = matches.subcommand_matches("get") { + let verbose = set_matches.is_present("verbose"); + self.get(verbose).await + } else if let Some(set_matches) = matches.subcommand_matches("list-devices") { + self.list_devices(set_matches).await + } else if let Some(set_matches) = matches.subcommand_matches("revoke-device") { + self.revoke_device(set_matches).await } else if let Some(matches) = matches.subcommand_matches("redeem") { let voucher = matches.value_of_t_or_exit("voucher"); self.redeem_voucher(voucher).await @@ -77,24 +110,52 @@ impl Command for Account { } impl Account { - async fn set(&self, token: Option<AccountToken>) -> Result<()> { + async fn create(&self) -> Result<()> { let mut rpc = new_rpc_client().await?; - rpc.set_account(token.clone().unwrap_or_default()).await?; - if let Some(token) = token { - println!("Mullvad account \"{}\" set", token); - } else { - println!("Mullvad account removed"); - } + rpc.create_new_account(()).await.map_err(map_device_error)?; + println!("New account created!"); + self.get(false).await + } + + async fn login(&self, token: AccountToken) -> Result<()> { + let mut rpc = new_rpc_client().await?; + rpc.login_account(token.clone()) + .await + .map_err(map_device_error)?; + println!("Mullvad account \"{}\" set", token); Ok(()) } - async fn get(&self) -> Result<()> { + async fn logout(&self) -> Result<()> { let mut rpc = new_rpc_client().await?; - let settings = rpc.get_settings(()).await?.into_inner(); - if settings.account_token != "" { - println!("Mullvad account: {}", settings.account_token); + rpc.logout_account(()).await?; + println!("Removed device from Mullvad account"); + Ok(()) + } + + async fn get(&self, verbose: bool) -> Result<()> { + let mut rpc = new_rpc_client().await?; + let device = rpc + .get_device(()) + .await + .map_err(|error| match error.code() { + Code::NotFound => Error::Other(NOT_LOGGED_IN_ERROR), + _other => map_device_error(error), + })? + .into_inner(); + if !device.account_token.is_empty() { + println!("Mullvad account: {}", device.account_token); + let inner_device = Device::try_from(device.device.unwrap()).unwrap(); + println!("Device name : {}", inner_device.pretty_name()); + if verbose { + println!("Device id : {}", inner_device.id); + println!("Device pubkey : {}", inner_device.pubkey); + for port in inner_device.ports { + println!("Device port : {}", port); + } + } let expiry = rpc - .get_account_data(settings.account_token) + .get_account_data(device.account_token) .await .map_err(|error| Error::RpcFailedExt("Failed to fetch account data", error))? .into_inner(); @@ -108,11 +169,88 @@ impl Account { Ok(()) } - async fn create(&self) -> Result<()> { + async fn list_devices(&self, matches: &clap::ArgMatches) -> Result<()> { let mut rpc = new_rpc_client().await?; - rpc.create_new_account(()).await?; - println!("New account created!"); - self.get().await + let token = self.parse_account_else_current(&mut rpc, matches).await?; + let device_list = rpc + .list_devices(token) + .await + .map_err(map_device_error)? + .into_inner(); + + let verbose = matches.is_present("verbose"); + + println!("Devices on the account:"); + for device in device_list.devices { + let device = Device::try_from(device.clone()).unwrap(); + if verbose { + println!(); + println!("Name : {}", device.pretty_name()); + println!("Id : {}", device.id); + println!("Public key: {}", device.pubkey); + for port in device.ports { + println!("Port : {}", port); + } + } else { + println!("{}", device.pretty_name()); + } + } + + Ok(()) + } + + async fn revoke_device(&self, matches: &clap::ArgMatches) -> Result<()> { + let mut rpc = new_rpc_client().await?; + + let token = self.parse_account_else_current(&mut rpc, matches).await?; + let device_to_revoke = parse_device_name(matches); + + let device_list = rpc + .list_devices(token.clone()) + .await + .map_err(map_device_error)? + .into_inner(); + let device_id = device_list + .devices + .into_iter() + .find(|dev| { + dev.name.eq_ignore_ascii_case(&device_to_revoke) + || dev.id.eq_ignore_ascii_case(&device_to_revoke) + }) + .map(|dev| dev.id) + .ok_or_else(|| Error::Other(DEVICE_NOT_FOUND_ERROR))?; + + rpc.remove_device(types::DeviceRemoval { + account_token: token, + device_id, + }) + .await + .map_err(map_device_error)?; + println!("Removed device"); + Ok(()) + } + + async fn parse_account_else_current( + &self, + rpc: &mut ManagementServiceClient, + matches: &clap::ArgMatches, + ) -> Result<String> { + match matches.value_of("account").map(str::to_string) { + Some(token) => Ok(token), + None => { + let device = rpc + .get_device(()) + .await + .map_err(|error| match error.code() { + mullvad_management_interface::Code::NotFound => { + Error::Other("Log in or specify an account") + } + _ => Error::RpcFailedExt("Failed to obtain device", error), + })? + .into_inner(); + Ok(device.account_token) + } + } } async fn redeem_voucher(&self, mut voucher: String) -> Result<()> { @@ -163,3 +301,46 @@ impl Account { utc.with_timezone(&chrono::Local).to_string() } } + +fn map_device_error(error: Status) -> Error { + match error.code() { + Code::ResourceExhausted => Error::Other(TOO_MANY_DEVICES_ERROR), + Code::Unauthenticated => Error::Other(INVALID_ACCOUNT_ERROR), + Code::AlreadyExists => Error::Other(ALREADY_LOGGED_IN_ERROR), + Code::NotFound => Error::Other(DEVICE_NOT_FOUND_ERROR), + _other => Error::RpcFailed(error), + } +} + +fn parse_token_else_stdin(matches: &clap::ArgMatches) -> String { + parse_from_match_else_stdin("Enter account number: ", "account", matches) + .split_whitespace() + .join("") +} + +fn parse_device_name(matches: &clap::ArgMatches) -> String { + parse_from_match_else_stdin("Enter device name: ", "device", matches) + .trim() + .to_string() +} + +fn parse_from_match_else_stdin( + prompt_str: &'static str, + key: &'static str, + matches: &clap::ArgMatches, +) -> String { + match matches.value_of(key) { + Some(device) => device.to_string(), + None => { + let mut val = String::new(); + io::stdout() + .write_all(prompt_str.as_bytes()) + .expect("Failed to write to STDOUT"); + let _ = io::stdout().flush(); + io::stdin() + .read_line(&mut val) + .expect("Failed to read from STDIN"); + val + } + } +} diff --git a/mullvad-cli/src/cmds/status.rs b/mullvad-cli/src/cmds/status.rs index 8c4a929c30..69052dcaf1 100644 --- a/mullvad-cli/src/cmds/status.rs +++ b/mullvad-cli/src/cmds/status.rs @@ -1,4 +1,4 @@ -use crate::{format, format::print_keygen_event, new_rpc_client, Command, Error, Result}; +use crate::{format, new_rpc_client, Command, Error, Result}; use mullvad_management_interface::{ types::daemon_event::Event as EventType, ManagementServiceClient, }; @@ -74,10 +74,14 @@ impl Command for Status { println!("New app version info: {:#?}", app_version_info); } } - EventType::KeyEvent(key_event) => { + EventType::Device(device) => { if verbose { - print!("Key event: "); - print_keygen_event(&key_event); + println!("Device event: {:#?}", device); + } + } + EventType::RemoveDevice(device) => { + if verbose { + println!("Remove device event: {:#?}", device); } } } diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs index f3b218648e..f01452a925 100644 --- a/mullvad-cli/src/cmds/tunnel.rs +++ b/mullvad-cli/src/cmds/tunnel.rs @@ -1,4 +1,4 @@ -use crate::{format::print_keygen_event, new_rpc_client, Command, Error, Result}; +use crate::{new_rpc_client, Command, Error, Result}; use mullvad_management_interface::types::{self, Timestamp, TunnelOptions}; use mullvad_types::wireguard::DEFAULT_ROTATION_INTERVAL; use std::{convert::TryFrom, time::Duration}; @@ -246,20 +246,13 @@ impl Tunnel { println!("No key is set"); return Ok(()); } - - let is_valid = rpc - .verify_wireguard_key(()) - .await - .map_err(|error| Error::RpcFailedExt("Failed to verify key", error))? - .into_inner(); - println!("Key is valid for use with current account: {}", is_valid); Ok(()) } async fn process_wireguard_key_generate() -> Result<()> { let mut rpc = new_rpc_client().await?; - let keygen_event = rpc.generate_wireguard_key(()).await?; - print_keygen_event(&keygen_event.into_inner()); + rpc.rotate_wireguard_key(()).await?; + println!("Rotated WireGuard key"); Ok(()) } diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs index b056ffff53..eb91ffcca8 100644 --- a/mullvad-cli/src/format.rs +++ b/mullvad-cli/src/format.rs @@ -5,30 +5,11 @@ use mullvad_management_interface::types::{ }, tunnel_state, tunnel_state::State::*, - ErrorState, KeygenEvent, ProxyType, TransportProtocol, TunnelEndpoint, TunnelState, TunnelType, + ErrorState, ProxyType, TransportProtocol, TunnelEndpoint, TunnelState, TunnelType, }; use mullvad_types::auth_failed::AuthFailed; use std::fmt::Write; -pub fn print_keygen_event(key_event: &KeygenEvent) { - use mullvad_management_interface::types::keygen_event::KeygenEvent as EventType; - - match EventType::from_i32(key_event.event).unwrap() { - EventType::NewKey => { - println!( - "New WireGuard key: {}", - base64::encode(&key_event.new_key.as_ref().unwrap().key) - ); - } - EventType::TooManyKeys => { - println!("Account has too many keys already"); - } - EventType::GenerationFailure => { - println!("Failed to generate new WireGuard key"); - } - } -} - pub fn print_state(state: &TunnelState) { print!("Tunnel status: "); match state.state.as_ref().unwrap() { diff --git a/mullvad-cli/src/main.rs b/mullvad-cli/src/main.rs index 55a195cdb8..df7ef0a04c 100644 --- a/mullvad-cli/src/main.rs +++ b/mullvad-cli/src/main.rs @@ -49,6 +49,9 @@ pub enum Error { //#[cfg(all(unix, not(target_os = "android")) #[error(display = "Failed to generate shell completions")] CompletionsError(#[error(source, no_from)] io::Error), + + #[error(display = "{}", _0)] + Other(&'static str), } #[tokio::main] diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs deleted file mode 100644 index f5655c9d1f..0000000000 --- a/mullvad-daemon/src/account.rs +++ /dev/null @@ -1,173 +0,0 @@ -use chrono::{DateTime, Utc}; -use futures::future::{abortable, AbortHandle}; -use mullvad_rpc::{ - availability::ApiAvailabilityHandle, - rest::{self, Error as RestError, MullvadRestHandle}, - AccountsProxy, -}; -use mullvad_types::account::{AccountToken, VoucherSubmission}; -use std::{future::Future, time::Duration}; -use talpid_core::future_retry::{ - constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered, -}; - -const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; -const RETRY_ACTION_MAX_RETRIES: usize = 2; - -const RETRY_EXPIRY_CHECK_INTERVAL_INITIAL: Duration = Duration::from_secs(4); -const RETRY_EXPIRY_CHECK_INTERVAL_FACTOR: u32 = 5; -const RETRY_EXPIRY_CHECK_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); - -pub struct Account(()); - -#[derive(Clone)] -pub struct AccountHandle { - api_availability: ApiAvailabilityHandle, - initial_check_abort_handle: AbortHandle, - proxy: AccountsProxy, -} - -impl AccountHandle { - pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> { - let mut proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - retry_future_n( - move || proxy.create_account(), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - } - - pub fn get_www_auth_token( - &self, - account: AccountToken, - ) -> impl Future<Output = Result<String, rest::Error>> { - let proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - retry_future_n( - move || proxy.get_www_auth_token(account.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - } - - pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { - let proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - let result = retry_future_n( - move || proxy.get_expiry(token.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - .await; - if handle_expiry_result_inner(&result, &self.api_availability) { - self.initial_check_abort_handle.abort(); - } - result - } - - pub async fn submit_voucher( - &mut self, - account_token: AccountToken, - voucher: String, - ) -> Result<VoucherSubmission, rest::Error> { - let mut proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - let result = retry_future_n( - move || proxy.submit_voucher(account_token.clone(), voucher.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - .await; - if result.is_ok() { - self.initial_check_abort_handle.abort(); - self.api_availability.resume_background(); - } - result - } - - fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool { - match result { - Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), - _ => false, - } - } -} - -impl Account { - pub fn new( - runtime: tokio::runtime::Handle, - rpc_handle: MullvadRestHandle, - token: Option<String>, - api_availability: ApiAvailabilityHandle, - ) -> AccountHandle { - let accounts_proxy = AccountsProxy::new(rpc_handle); - api_availability.pause_background(); - - let api_availability_copy = api_availability.clone(); - let accounts_proxy_copy = accounts_proxy.clone(); - - let (future, initial_check_abort_handle) = abortable(async move { - let token = if let Some(token) = token { - token - } else { - api_availability.pause_background(); - return; - }; - - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new( - RETRY_EXPIRY_CHECK_INTERVAL_INITIAL, - RETRY_EXPIRY_CHECK_INTERVAL_FACTOR, - ) - .max_delay(RETRY_EXPIRY_CHECK_INTERVAL_MAX), - ); - let future_generator = move || { - let wait_online = api_availability.wait_online(); - let expiry_fut = accounts_proxy.get_expiry(token.clone()); - let api_availability_copy = api_availability.clone(); - async move { - let _ = wait_online.await; - handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) - } - }; - let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; - retry_future(future_generator, should_retry, retry_strategy).await; - }); - runtime.spawn(future); - - AccountHandle { - api_availability: api_availability_copy, - initial_check_abort_handle, - proxy: accounts_proxy_copy, - } - } -} - -fn handle_expiry_result_inner( - result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>, - api_availability: &ApiAvailabilityHandle, -) -> bool { - match result { - Ok(_expiry) if *_expiry >= chrono::Utc::now() => { - api_availability.resume_background(); - true - } - Ok(_expiry) => { - api_availability.pause_background(); - true - } - Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => { - if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH { - api_availability.pause_background(); - return true; - } - false - } - Err(_) => false, - } -} diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs new file mode 100644 index 0000000000..868dc003d8 --- /dev/null +++ b/mullvad-daemon/src/device.rs @@ -0,0 +1,1136 @@ +use chrono::{DateTime, Utc}; +use futures::{ + channel::{mpsc, oneshot}, + future::{abortable, AbortHandle}, + stream::StreamExt, +}; +use mullvad_rpc::{ + availability::ApiAvailabilityHandle, + rest::{self, Error as RestError, MullvadRestHandle}, + AccountsProxy, DevicesProxy, +}; +use mullvad_types::{ + account::{AccountToken, VoucherSubmission}, + device::{Device, DeviceData, DeviceEvent, DeviceId}, + wireguard::{RotationInterval, WireguardData}, +}; +use std::{ + future::Future, + path::Path, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, SystemTime}, +}; +use talpid_core::{ + future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered}, + mpsc::Sender, +}; +use talpid_types::{ + net::{wireguard::PrivateKey, TunnelType}, + tunnel::TunnelStateTransition, + ErrorExt, +}; +use tokio::{ + fs, + io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; + +/// How often to check whether the key has expired. +/// A short interval is used in case the computer is ever suspended. +const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(5 * 60); + +/// File that used to store account and device data. +const DEVICE_CACHE_FILENAME: &str = "device.json"; + +const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; +const RETRY_ACTION_MAX_RETRIES: usize = 2; + +const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4); +const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5; +const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); + +/// How long to keep the known status for [AccountManagerHandle::validate_device]. +const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10); + +/// How long to wait on logout (device removal) before letting it continue as a background task. +const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2); + +/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts +/// to set up a WireGuard tunnel. +const WG_DEVICE_CHECK_THRESHOLD: usize = 3; + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "The account already has a maximum number of devices")] + MaxDevicesReached, + #[error(display = "No device is set")] + NoDevice, + #[error(display = "Device not found")] + InvalidDevice, + #[error(display = "Invalid account")] + InvalidAccount, + #[error(display = "Failed to read or write device cache")] + DeviceIoError(#[error(source)] io::Error), + #[error(display = "Failed parse device cache")] + ParseDeviceCache(#[error(source)] serde_json::Error), + #[error(display = "Unexpected HTTP request error")] + OtherRestError(#[error(source)] rest::Error), + #[error(display = "The device update task is not running")] + DeviceUpdaterCancelled(#[error(source)] oneshot::Canceled), + #[error(display = "The account manager is down")] + AccountManagerDown, +} + +#[derive(Clone)] +pub(crate) enum InnerDeviceEvent { + /// The device was removed due to user (or daemon) action. + Logout, + /// Logged in to a new device. + Login(DeviceData), + /// The device was updated remotely, but not its key. + Updated(DeviceData), + /// The key was rotated. + RotatedKey(DeviceData), + /// Device was removed because it was not found remotely. + Revoked, +} + +impl From<InnerDeviceEvent> for DeviceEvent { + fn from(event: InnerDeviceEvent) -> DeviceEvent { + match event { + InnerDeviceEvent::Logout => DeviceEvent::revoke(false), + InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false), + InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true), + InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false), + InnerDeviceEvent::Revoked => DeviceEvent::revoke(true), + } + } +} + +impl InnerDeviceEvent { + fn data(&self) -> Option<&DeviceData> { + match self { + InnerDeviceEvent::Login(data) => Some(&data), + InnerDeviceEvent::Updated(data) => Some(&data), + InnerDeviceEvent::RotatedKey(data) => Some(&data), + InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None, + } + } + + fn into_data(self) -> Option<DeviceData> { + match self { + InnerDeviceEvent::Login(data) => Some(data), + InnerDeviceEvent::Updated(data) => Some(data), + InnerDeviceEvent::RotatedKey(data) => Some(data), + InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None, + } + } +} + +impl Error { + pub fn is_network_error(&self) -> bool { + if let Error::OtherRestError(error) = self { + error.is_network_error() + } else { + false + } + } +} + +pub enum ValidationResult { + /// The device and key were valid. + Valid, + /// The device was valid but the key was replaced + RotatedKey, + /// The device was valid but one or more fields, such as ports, were replaced + Updated, + /// The device was not found remotely and was removed from the cache. + Removed, +} + +type ResponseTx<T> = oneshot::Sender<Result<T, Error>>; + +enum AccountManagerCommand { + Login(AccountToken, ResponseTx<()>), + Logout(ResponseTx<()>), + SetData(DeviceData, ResponseTx<()>), + GetData(ResponseTx<Option<DeviceData>>), + RotateKey(ResponseTx<()>), + SetRotationInterval(RotationInterval, ResponseTx<()>), + GetRotationInterval(ResponseTx<RotationInterval>), + ValidateDevice(ResponseTx<ValidationResult>), + ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>), + Shutdown(oneshot::Sender<()>), +} + +#[derive(Clone)] +pub(crate) struct AccountManagerHandle { + cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>, + pub account_service: AccountService, + pub device_service: DeviceService, +} + +impl AccountManagerHandle { + pub async fn login(&self, token: AccountToken) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::Login(token, tx)) + .await + } + + pub async fn logout(&self) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::Logout(tx)) + .await + } + + pub async fn set(&self, data: DeviceData) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::SetData(data, tx)) + .await + } + + pub async fn data(&self) -> Result<Option<DeviceData>, Error> { + self.send_command(|tx| AccountManagerCommand::GetData(tx)) + .await + } + + pub async fn rotate_key(&self) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::RotateKey(tx)) + .await + } + + pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx)) + .await + } + + pub async fn rotation_interval(&self) -> Result<RotationInterval, Error> { + self.send_command(|tx| AccountManagerCommand::GetRotationInterval(tx)) + .await + } + + pub async fn validate_device(&self) -> Result<ValidationResult, Error> { + self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx)) + .await + } + + pub async fn receive_events( + &self, + events_tx: impl Sender<InnerDeviceEvent> + Send + 'static, + ) -> Result<(), Error> { + self.send_command(|tx| { + AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx) + }) + .await + } + + pub async fn shutdown(self) { + let (tx, rx) = oneshot::channel(); + let _ = self + .cmd_tx + .unbounded_send(AccountManagerCommand::Shutdown(tx)); + let _ = rx.await; + } + + async fn send_command<T>( + &self, + make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand, + ) -> Result<T, Error> { + let (tx, rx) = oneshot::channel(); + self.cmd_tx + .unbounded_send(make_cmd(tx)) + .map_err(|_| Error::AccountManagerDown)?; + rx.await.map_err(|_| Error::AccountManagerDown)? + } +} + +pub(crate) struct AccountManager { + cacher: DeviceCacher, + device_service: DeviceService, + data: Option<DeviceData>, + rotation_interval: RotationInterval, + listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>, + last_validation: Option<SystemTime>, +} + +impl AccountManager { + pub async fn spawn( + rest_handle: rest::MullvadRestHandle, + api_availability: ApiAvailabilityHandle, + settings_dir: &Path, + initial_rotation_interval: RotationInterval, + ) -> Result<AccountManagerHandle, Error> { + let (cacher, data) = DeviceCacher::new(settings_dir).await?; + let token = data.as_ref().map(|state| state.token.clone()); + let account_service = + spawn_account_service(rest_handle.clone(), token, api_availability.clone()); + + let (cmd_tx, cmd_rx) = mpsc::unbounded(); + + let device_service = DeviceService::new(rest_handle, api_availability); + let manager = AccountManager { + cacher, + device_service: device_service.clone(), + data, + rotation_interval: initial_rotation_interval, + listeners: vec![], + last_validation: None, + }; + + tokio::spawn(manager.run(cmd_rx)); + let handle = AccountManagerHandle { + cmd_tx, + account_service, + device_service, + }; + KeyUpdater::spawn(handle.clone()).await?; + Ok(handle) + } + + async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) { + let mut shutdown_tx = None; + while let Some(cmd) = cmd_rx.next().await { + match cmd { + AccountManagerCommand::Shutdown(tx) => { + shutdown_tx = Some(tx); + break; + } + other => self.service_command(other).await, + } + } + self.shutdown().await; + if let Some(tx) = shutdown_tx { + let _ = tx.send(()); + } + log::debug!("Account manager has stopped"); + } + + async fn service_command(&mut self, cmd: AccountManagerCommand) { + match cmd { + AccountManagerCommand::Login(token, tx) => { + let _ = tx.send(self.login(token).await); + } + AccountManagerCommand::Logout(tx) => { + let _ = tx.send(self.logout().await); + } + AccountManagerCommand::SetData(data, tx) => { + let _ = tx.send(self.set(InnerDeviceEvent::Login(data)).await); + } + AccountManagerCommand::GetData(tx) => { + let _ = tx.send(Ok(self.data.clone())); + } + AccountManagerCommand::RotateKey(tx) => { + let _ = tx.send(self.rotate_key().await); + } + AccountManagerCommand::SetRotationInterval(interval, tx) => { + self.rotation_interval = interval; + let _ = tx.send(Ok(())); + } + AccountManagerCommand::GetRotationInterval(tx) => { + let _ = tx.send(Ok(self.rotation_interval)); + } + AccountManagerCommand::ValidateDevice(tx) => { + let _ = tx.send(self.validate_device().await); + } + AccountManagerCommand::ReceiveEvents(events_tx, tx) => { + let _ = tx.send(Ok(self.listeners.push(events_tx))); + } + AccountManagerCommand::Shutdown(_) => unreachable!("shutdown is handled earlier"), + } + } + + async fn login(&mut self, token: AccountToken) -> Result<(), Error> { + let data = self.device_service.generate_for_account(token).await?; + self.set(InnerDeviceEvent::Login(data)).await?; + Ok(()) + } + + async fn logout(&mut self) -> Result<(), Error> { + if self.data.is_some() { + self.cacher.write(None).await?; + let _ = tokio::time::timeout(LOGOUT_TIMEOUT, self.logout_inner()).await; + + let event = InnerDeviceEvent::Logout; + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); + } + Ok(()) + } + + async fn logout_inner(&mut self) -> tokio::task::JoinHandle<()> { + let prev_data = self.data.take(); + let service = self.device_service.clone(); + + tokio::spawn(async move { + if let Some(data) = prev_data { + if let Err(error) = service + .remove_device_with_backoff(data.token, data.device.id) + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to remove a previous device") + ); + } + } + }) + } + + async fn set(&mut self, event: InnerDeviceEvent) -> Result<(), Error> { + let data = event.data(); + if data == self.data.as_ref() { + return Ok(()); + } + + self.cacher.write(data).await?; + self.last_validation = None; + + if self + .data + .as_ref() + .map(|current| data.as_ref().map(|d| &d.device.id) != Some(¤t.device.id)) + .unwrap_or(false) + { + // Remove the existing device if its ID differs. Otherwise, only update + // the data. + self.logout_inner().await; + } + + self.data = data.cloned(); + + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); + + Ok(()) + } + + async fn rotate_key(&mut self) -> Result<(), Error> { + // TODO: Update all data opportunistically? + let data = self.data.as_ref().ok_or(Error::NoDevice)?; + + let wg_data = self + .device_service + .rotate_key(data.token.clone(), data.device.id.clone()) + .await?; + + // Copy the data to keep a predictable state if an error occurs. + let mut new_data = data.clone(); + new_data.device.pubkey = wg_data.private_key.public_key(); + new_data.wg_data = wg_data; + self.set(InnerDeviceEvent::RotatedKey(new_data)).await + } + + /// Check if the device is valid for the account, and yank it if it no longer exists. + /// This also updates any associated data and returns whether it changed. + async fn validate_device(&mut self) -> Result<ValidationResult, Error> { + log::debug!("Checking whether the device is still valid"); + + if let Some(result) = self.cached_validation() { + log::debug!("The current device is still valid"); + return Ok(result); + } + + let data = self.data.as_ref().ok_or(Error::NoDevice)?; + + match self + .device_service + .get(data.token.clone(), data.device.id.clone()) + .await + { + Ok(device) => { + if device.pubkey == data.device.pubkey { + if device == data.device { + log::debug!("The current device is still valid"); + Ok(ValidationResult::Valid) + } else { + log::debug!("Updating data for the current device"); + // Copy the data to keep a predictable state if an error occurs. + let new_data = DeviceData { + device, + ..data.clone() + }; + self.set(InnerDeviceEvent::Updated(new_data)).await?; + Ok(ValidationResult::Updated) + } + } else { + log::debug!("Rotating invalid WireGuard key"); + self.rotate_key().await?; + Ok(ValidationResult::RotatedKey) + } + } + Err(Error::InvalidAccount) | Err(Error::InvalidDevice) => { + log::debug!("The current device is no longer valid for this account"); + + self.cacher.write(None).await?; + self.data = None; + + let event = InnerDeviceEvent::Revoked; + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); + + Ok(ValidationResult::Removed) + } + Err(error) => Err(error), + } + } + + fn cached_validation(&mut self) -> Option<ValidationResult> { + if self.data.is_none() { + return None; + } + + let now = SystemTime::now(); + + let elapsed = self + .last_validation + .and_then(|last_check| now.duration_since(last_check).ok()) + .unwrap_or(VALIDITY_CACHE_TIMEOUT); + + if elapsed >= VALIDITY_CACHE_TIMEOUT { + self.last_validation = Some(now); + return None; + } + + Some(ValidationResult::Valid) + } + + async fn shutdown(self) { + self.cacher.finalize().await; + } +} + +struct KeyUpdater { + handle: AccountManagerHandle, + rx: mpsc::UnboundedReceiver<InnerDeviceEvent>, + data: Option<DeviceData>, +} + +impl KeyUpdater { + async fn spawn(handle: AccountManagerHandle) -> Result<(), Error> { + let (tx, rx) = mpsc::unbounded(); + handle.receive_events(tx).await?; + let data = handle.data().await?; + let mut key_rotator = KeyUpdater { handle, rx, data }; + + tokio::spawn(async move { + loop { + tokio::time::sleep(KEY_CHECK_INTERVAL).await; + + if let Err(error) = key_rotator.check_key_validity().await { + if let Error::AccountManagerDown = error { + break; + } + log::error!( + "{}", + error.display_chain_with_msg("Stopping key rotation task due to an error") + ); + break; + } + } + log::debug!("Stopping key updater"); + }); + + Ok(()) + } + + async fn check_key_validity(&mut self) -> Result<(), Error> { + let rotation_interval = self.handle.rotation_interval().await?; + let data = self.wait_for_data().await?; + + if (chrono::Utc::now() + .signed_duration_since(data.wg_data.created) + .num_seconds() as u64) + < rotation_interval.as_duration().as_secs() + { + return Ok(()); + } + + let mut data = data.clone(); + + let rotation_fut = self + .handle + .device_service + .rotate_key_with_backoff(data.token.clone(), data.device.id.clone()); + + match futures::future::select(Box::pin(rotation_fut), self.rx.next()).await { + futures::future::Either::Left((Ok(wg_data), _)) => { + log::debug!("Rotating WireGuard key"); + data.device.pubkey = wg_data.private_key.public_key(); + data.wg_data = wg_data; + self.handle.set(data).await?; + } + futures::future::Either::Left((Err(error), _)) => { + log::error!( + "{}", + error.display_chain_with_msg("Stopping key rotation due to an error") + ); + + // Forget the current device. Key rotation will restart when + // it is updated in any way. + self.data = None; + } + futures::future::Either::Right((event, _)) => { + // Abort key rotation if the device changed + if let Some(event) = event { + self.data = event.into_data(); + } else { + return Err(Error::AccountManagerDown); + } + } + } + + Ok(()) + } + + async fn wait_for_data(&mut self) -> Result<&DeviceData, Error> { + while let Ok(item) = self.rx.try_next() { + match item { + Some(event) => { + self.data = event.into_data(); + } + None => return Err(Error::AccountManagerDown), + } + } + + match self.data { + Some(ref data) => Ok(data), + None => loop { + let event = self.rx.next().await; + match event { + Some(event) => { + if let Some(data) = event.into_data() { + self.data = Some(data); + break Ok(self.data.as_ref().unwrap()); + } + } + None => break Err(Error::AccountManagerDown), + } + }, + } + } +} + +#[derive(Clone)] +pub struct DeviceService { + api_availability: ApiAvailabilityHandle, + proxy: DevicesProxy, +} + +impl DeviceService { + pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self { + Self { + proxy: DevicesProxy::new(handle), + api_availability, + } + } + + /// Generate a new device for a given token + pub async fn generate_for_account(&self, token: AccountToken) -> Result<DeviceData, Error> { + let private_key = PrivateKey::new_from_random(); + let pubkey = private_key.public_key(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let token_copy = token.clone(); + let (device, addresses) = retry_future_n( + move || proxy.create(token_copy.clone(), pubkey.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(map_rest_error)?; + + Ok(DeviceData { + token, + device, + wg_data: WireguardData { + private_key, + addresses, + created: Utc::now(), + }, + }) + } + + pub async fn generate_for_account_with_backoff( + &self, + token: AccountToken, + ) -> Result<DeviceData, Error> { + let private_key = PrivateKey::new_from_random(); + let pubkey = private_key.public_key(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let token_copy = token.clone(); + let (device, addresses) = retry_future( + move || api_handle.when_online(proxy.create(token_copy.clone(), pubkey.clone())), + should_retry_backoff, + retry_strategy(), + ) + .await + .map_err(map_rest_error)?; + + Ok(DeviceData { + token, + device, + wg_data: WireguardData { + private_key, + addresses, + created: Utc::now(), + }, + }) + } + + pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.remove(token.clone(), device.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(map_rest_error)?; + Ok(()) + } + + pub async fn remove_device_with_backoff( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<(), Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ), // Not setting a maximum interval + ); + + retry_future( + // NOTE: Not honoring "paused" state, because the account may have no time on it. + move || api_handle.when_online(proxy.remove(token.clone(), device.clone())), + should_retry_backoff, + retry_strategy, + ) + .await + .map_err(map_rest_error)?; + + Ok(()) + } + + pub async fn rotate_key( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<WireguardData, Error> { + let private_key = PrivateKey::new_from_random(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let pubkey = private_key.public_key(); + let addresses = retry_future_n( + move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(map_rest_error)?; + + Ok(WireguardData { + private_key, + addresses, + created: Utc::now(), + }) + } + + pub async fn rotate_key_with_backoff( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<WireguardData, Error> { + let private_key = PrivateKey::new_from_random(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let pubkey = private_key.public_key(); + + let addresses = retry_future( + move || { + api_handle.when_bg_resumes(proxy.replace_wg_key( + token.clone(), + device.clone(), + pubkey.clone(), + )) + }, + should_retry_backoff, + retry_strategy(), + ) + .await + .map_err(map_rest_error)?; + + Ok(WireguardData { + private_key, + addresses, + created: Utc::now(), + }) + } + + pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.list(token.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(map_rest_error) + } + + pub async fn list_devices_with_backoff( + &self, + token: AccountToken, + ) -> Result<Vec<Device>, Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + + retry_future( + move || api_handle.when_online(proxy.list(token.clone())), + should_retry_backoff, + retry_strategy(), + ) + .await + .map_err(map_rest_error) + } + + pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result<Device, Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.get(token.clone(), device.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(map_rest_error) + } +} + +pub struct DeviceCacher { + file: io::BufWriter<fs::File>, + path: std::path::PathBuf, +} + +impl DeviceCacher { + pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> { + let mut options = std::fs::OpenOptions::new(); + #[cfg(unix)] + { + use std::os::unix::fs::OpenOptionsExt; + options.mode(0o600); + } + #[cfg(windows)] + { + use std::os::windows::fs::OpenOptionsExt; + // exclusive access + options.share_mode(0); + } + + let path = settings_dir.join(DEVICE_CACHE_FILENAME); + let cache_exists = path.is_file(); + + let mut file = fs::OpenOptions::from(options) + .write(true) + .read(true) + .create(true) + .open(&path) + .await?; + + let device: Option<DeviceData> = if cache_exists { + let mut reader = io::BufReader::new(&mut file); + let mut buffer = String::new(); + reader.read_to_string(&mut buffer).await?; + if !buffer.is_empty() { + serde_json::from_str(&buffer)? + } else { + None + } + } else { + None + }; + + Ok(( + DeviceCacher { + file: io::BufWriter::new(file), + path, + }, + device, + )) + } + + pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> { + let data = serde_json::to_vec_pretty(&device).unwrap(); + + self.file.get_mut().set_len(0).await?; + self.file.seek(io::SeekFrom::Start(0)).await?; + self.file.write_all(&data).await?; + self.file.flush().await?; + self.file.get_mut().sync_data().await?; + + Ok(()) + } + + pub async fn remove(self) -> Result<(), Error> { + let path = { + let DeviceCacher { path, file } = self; + let std_file = file.into_inner().into_std().await; + let _ = tokio::task::spawn_blocking(move || drop(std_file)).await; + path + }; + tokio::fs::remove_file(path).await?; + Ok(()) + } + + async fn finalize(self) { + let std_file = self.file.into_inner().into_std().await; + let _ = tokio::task::spawn_blocking(move || drop(std_file)).await; + } +} + +#[derive(Clone)] +pub struct AccountService { + api_availability: ApiAvailabilityHandle, + initial_check_abort_handle: AbortHandle, + proxy: AccountsProxy, +} + +impl AccountService { + pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> { + let mut proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.create_account(), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + } + + pub fn get_www_auth_token( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<String, rest::Error>> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.get_www_auth_token(account.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + } + + pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let result = retry_future_n( + move || proxy.get_expiry(token.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await; + if handle_expiry_result_inner(&result, &self.api_availability) { + self.initial_check_abort_handle.abort(); + } + result + } + + pub async fn submit_voucher( + &mut self, + account_token: AccountToken, + voucher: String, + ) -> Result<VoucherSubmission, rest::Error> { + let mut proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let result = retry_future_n( + move || proxy.submit_voucher(account_token.clone(), voucher.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await; + if result.is_ok() { + self.initial_check_abort_handle.abort(); + self.api_availability.resume_background(); + } + result + } +} + +pub fn spawn_account_service( + rpc_handle: MullvadRestHandle, + token: Option<String>, + api_availability: ApiAvailabilityHandle, +) -> AccountService { + let accounts_proxy = AccountsProxy::new(rpc_handle); + api_availability.pause_background(); + + let api_availability_copy = api_availability.clone(); + let accounts_proxy_copy = accounts_proxy.clone(); + + let (future, initial_check_abort_handle) = abortable(async move { + let token = if let Some(token) = token { + token + } else { + api_availability.pause_background(); + return; + }; + + let future_generator = move || { + let expiry_fut = api_availability.when_online(accounts_proxy.get_expiry(token.clone())); + let api_availability_copy = api_availability.clone(); + async move { handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) } + }; + let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; + retry_future(future_generator, should_retry, retry_strategy()).await; + }); + tokio::spawn(future); + + AccountService { + api_availability: api_availability_copy, + initial_check_abort_handle, + proxy: accounts_proxy_copy, + } +} + +fn handle_expiry_result_inner( + result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>, + api_availability: &ApiAvailabilityHandle, +) -> bool { + match result { + Ok(_expiry) if *_expiry >= chrono::Utc::now() => { + api_availability.resume_background(); + true + } + Ok(_expiry) => { + api_availability.pause_background(); + true + } + Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => { + if code == mullvad_rpc::INVALID_ACCOUNT { + api_availability.pause_background(); + return true; + } + false + } + Err(_) => false, + } +} + +fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool { + match result { + Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), + _ => false, + } +} + +fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool { + match result { + Ok(_) => false, + Err(error) => { + if let RestError::ApiError(status, code) = error { + *status != rest::StatusCode::NOT_FOUND + && code != mullvad_rpc::INVALID_ACCOUNT + && code != mullvad_rpc::MAX_DEVICES_REACHED + && code != mullvad_rpc::PUBKEY_IN_USE + } else { + true + } + } + } +} + +fn map_rest_error(error: rest::Error) -> Error { + match error { + RestError::ApiError(status, ref code) => { + if status == rest::StatusCode::NOT_FOUND { + return Error::InvalidDevice; + } + match code.as_str() { + mullvad_rpc::INVALID_ACCOUNT => Error::InvalidAccount, + mullvad_rpc::MAX_DEVICES_REACHED => Error::MaxDevicesReached, + _ => Error::OtherRestError(error), + } + } + error => Error::OtherRestError(error), + } +} + +fn retry_strategy() -> Jittered<ExponentialBackoff> { + Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ) + .max_delay(RETRY_BACKOFF_INTERVAL_MAX), + ) +} + +/// Checks if the current device is valid if a WireGuard tunnel cannot be set up +/// after multiple attempts. +pub(crate) struct TunnelStateChangeHandler { + manager: AccountManagerHandle, + check_validity: Arc<AtomicBool>, + wg_retry_attempt: usize, +} + +impl TunnelStateChangeHandler { + pub fn new(manager: AccountManagerHandle) -> Self { + Self { + manager, + check_validity: Arc::new(AtomicBool::new(true)), + wg_retry_attempt: 0, + } + } + + pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) { + match new_state { + TunnelStateTransition::Connecting(endpoint) => { + if endpoint.tunnel_type != TunnelType::Wireguard { + return; + } + self.wg_retry_attempt += 1; + if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 { + let handle = self.manager.clone(); + let check_validity = self.check_validity.clone(); + tokio::spawn(async move { + if !check_validity.swap(false, Ordering::SeqCst) { + return; + } + if let Err(error) = handle.validate_device().await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check device validity") + ); + if error.is_network_error() { + check_validity.store(true, Ordering::SeqCst); + } + } + }); + } + } + TunnelStateTransition::Connected(_) | TunnelStateTransition::Disconnected => { + self.check_validity.store(true, Ordering::SeqCst); + self.wg_retry_attempt = 0; + } + _ => (), + } + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 0d9ec96c87..3f1d363694 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -4,9 +4,9 @@ #[macro_use] extern crate serde; -mod account; pub mod account_history; mod api; +pub mod device; pub mod exception_logging; #[cfg(target_os = "macos")] pub mod exclusion_gid; @@ -25,6 +25,7 @@ pub mod version; mod version_check; use crate::target_state::PersistentTargetState; +use device::InnerDeviceEvent; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Future}, @@ -36,6 +37,7 @@ use mullvad_rpc::{ }; use mullvad_types::{ account::{AccountData, AccountToken, VoucherSubmission}, + device::{Device, DeviceConfig, DeviceData, DeviceEvent, DeviceId, RemoveDeviceEvent}, endpoint::MullvadEndpoint, location::{Coordinates, GeoIpLocation}, relay_constraints::{ @@ -46,7 +48,7 @@ use mullvad_types::{ settings::{DnsOptions, DnsState, Settings}, states::{TargetState, TunnelState}, version::{AppVersion, AppVersionInfo}, - wireguard::{KeygenEvent, RotationInterval}, + wireguard::{PublicKey, RotationInterval}, }; use settings::SettingsPersister; #[cfg(target_os = "android")] @@ -75,7 +77,7 @@ use talpid_types::android::AndroidContext; use talpid_types::{ net::{ openvpn::{self, ProxySettings}, - TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType, + wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType, }, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, @@ -84,12 +86,6 @@ use talpid_types::{ use tokio::fs; use tokio::io; -#[path = "wireguard.rs"] -mod wireguard; - -/// Timeout for first WireGuard key pushing -const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5); - /// Delay between generating a new WireGuard key and reconnecting const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60); @@ -124,17 +120,35 @@ pub enum Error { #[error(display = "Unable to load account history")] LoadAccountHistory(#[error(source)] account_history::Error), + #[error(display = "Failed to start account manager")] + LoadAccountManager(#[error(source)] device::Error), + + #[error(display = "Failed to log in to account")] + LoginError(#[error(source)] device::Error), + + #[error(display = "Failed to log out of account")] + LogoutError(#[error(source)] device::Error), + + #[error(display = "Failed to rotate WireGuard key")] + KeyRotationError(#[error(source)] device::Error), + + #[error(display = "Failed to list devices")] + ListDevicesError(#[error(source)] device::Error), + + #[error(display = "Failed to remove device")] + RemoveDeviceError(#[error(source)] device::Error), + #[cfg(target_os = "linux")] #[error(display = "Unable to initialize split tunneling")] InitSplitTunneling(#[error(source)] split_tunnel::Error), - #[error(display = "The account has too many wireguard keys")] - TooManyKeys, - #[cfg(windows)] #[error(display = "Split tunneling error")] SplitTunnelError(#[error(source)] split_tunnel::Error), + #[error(display = "An account is already set")] + AlreadyLoggedIn, + #[error(display = "No wireguard private key available")] NoKeyAvailable, @@ -226,8 +240,16 @@ pub enum DaemonCommand { /// Trigger an asynchronous relay list update. This returns before the relay list is actually /// updated. UpdateRelayLocations, - /// Set which account token to use for subsequent connection attempts. - SetAccount(ResponseTx<(), settings::Error>, Option<AccountToken>), + /// Log in with a given account and create a new device. + LoginAccount(ResponseTx<(), Error>, AccountToken), + /// Log out of the current account and remove the device, if they exist. + LogoutAccount(ResponseTx<(), Error>), + /// Return the current device configuration, if there is one. + GetDevice(ResponseTx<Option<DeviceConfig>, Error>), + /// Return all the devices for a given account token. + ListDevices(ResponseTx<Vec<Device>, Error>, AccountToken), + /// Remove device from a given account. + RemoveDevice(ResponseTx<(), Error>, AccountToken, DeviceId), /// Place constraints on the type of tunnel and relay UpdateRelaySettings(ResponseTx<(), settings::Error>, RelaySettingsUpdate), /// Set the allow LAN setting. @@ -256,11 +278,9 @@ pub enum DaemonCommand { /// Get the daemon settings GetSettings(oneshot::Sender<Settings>), /// Generate new wireguard key - GenerateWireguardKey(ResponseTx<wireguard::KeygenEvent, Error>), + RotateWireguardKey(ResponseTx<(), Error>), /// Return a public key of the currently set wireguard private key, if there is one - GetWireguardKey(ResponseTx<Option<wireguard::PublicKey>, Error>), - /// Verify if the currently set wireguard key is valid. - VerifyWireguardKey(ResponseTx<bool, Error>), + GetWireguardKey(ResponseTx<Option<PublicKey>, Error>), /// Get information about the currently running and latest app versions GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>), /// Get current version of the app @@ -320,19 +340,14 @@ pub(crate) enum InternalDaemonEvent { Command(DaemonCommand), /// Daemon shutdown triggered by a signal, ctrl-c or similar. TriggerShutdown, - /// Wireguard key generation event - WgKeyEvent( - ( - AccountToken, - Result<mullvad_types::wireguard::WireguardData, wireguard::Error>, - ), - ), - /// New Account created - NewAccountEvent(AccountToken, oneshot::Sender<Result<String, Error>>), /// The background job fetching new `AppVersionInfo`s got a new info object. NewAppVersionInfo(AppVersionInfo), /// Request from REST client to use a different API endpoint. GenerateApiConnectionMode(api::ApiConnectionModeRequest), + /// Sent when a device is updated in any way (key rotation, login, logout, etc.). + DeviceEvent(InnerDeviceEvent), + /// Handles updates from versions without devices. + DeviceMigrationEvent(DeviceData), /// The split tunnel paths or state were updated. #[cfg(target_os = "windows")] ExcludedPathsEvent(ExcludedPathsUpdate, oneshot::Sender<Result<(), Error>>), @@ -368,6 +383,12 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent { } } +impl From<InnerDeviceEvent> for InternalDaemonEvent { + fn from(event: InnerDeviceEvent) -> Self { + InternalDaemonEvent::DeviceEvent(event) + } +} + #[derive(Clone, Debug, Eq, PartialEq)] enum DaemonExecutionState { Running, @@ -529,8 +550,11 @@ pub trait EventListener { /// Or some flag about the currently running version is changed. fn notify_app_version(&self, app_version_info: AppVersionInfo); - /// Notify clients of a key generation event. - fn notify_key_event(&self, key_event: KeygenEvent); + /// Notify that device changed (login, logout, or key rotation). + fn notify_device_event(&self, event: DeviceEvent); + + /// Notify that a device was revoked using `RemoveDevice`. + fn notify_remove_device_event(&self, event: RemoveDeviceEvent); } pub struct Daemon<L: EventListener> { @@ -546,10 +570,10 @@ pub struct Daemon<L: EventListener> { event_listener: L, settings: SettingsPersister, account_history: account_history::AccountHistory, - account: account::AccountHandle, + device_checker: device::TunnelStateChangeHandler, + account_manager: device::AccountManagerHandle, rpc_runtime: mullvad_rpc::MullvadRpcRuntime, rpc_handle: mullvad_rpc::rest::MullvadRestHandle, - wireguard_key_manager: wireguard::KeyManager, version_updater_handle: version_check::VersionUpdaterHandle, relay_selector: relays::RelaySelector, last_generated_relay: Option<Relay>, @@ -584,11 +608,38 @@ where mullvad_rpc::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; - let runtime = tokio::runtime::Handle::current(); - let (internal_event_tx, internal_event_rx) = command_channel.destructure(); - if let Err(error) = migrations::migrate_all(&cache_dir, &settings_dir).await { + let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( + &cache_dir, + true, + #[cfg(target_os = "android")] + Self::create_bypass_tx(&internal_event_tx), + ) + .await + .map_err(Error::InitRpcFactory)?; + + let api_availability = rpc_runtime.availability_handle(); + api_availability.suspend(); + + let endpoint_updater = api::ApiEndpointUpdaterHandle::new(); + + let proxy_provider = api::create_api_config_provider( + internal_event_tx.to_specialized_sender(), + ApiConnectionMode::Direct, + ); + let rpc_handle = rpc_runtime + .mullvad_rest_handle(proxy_provider, endpoint_updater.callback()) + .await; + + if let Err(error) = migrations::migrate_all( + &cache_dir, + &settings_dir, + rpc_handle.clone(), + internal_event_tx.clone(), + ) + .await + { log::error!( "{}", error.display_chain_with_msg("Failed to migrate settings or cache") @@ -596,19 +647,45 @@ where } let settings = SettingsPersister::load(&settings_dir).await; - let target_state = if settings.get_account_token().is_none() { - PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await - } else if settings.auto_connect { + let tunnel_parameters_generator = MullvadTunnelParametersGenerator { + tx: internal_event_tx.clone(), + }; + + let account_manager = device::AccountManager::spawn( + rpc_handle.clone(), + api_availability.clone(), + &settings_dir, + settings + .tunnel_options + .wireguard + .rotation_interval + .unwrap_or_default(), + ) + .await + .map_err(Error::LoadAccountManager)?; + account_manager + .receive_events(internal_event_tx.to_specialized_sender()) + .await + .map_err(Error::LoadAccountManager)?; + let data = account_manager + .data() + .await + .map_err(Error::LoadAccountManager)?; + + let account_history = account_history::AccountHistory::new( + &settings_dir, + data.as_ref().map(|device| device.token.clone()), + ) + .await + .map_err(Error::LoadAccountHistory)?; + + let target_state = if settings.auto_connect { log::info!("Automatically connecting since auto-connect is turned on"); PersistentTargetState::force(&cache_dir, TargetState::Secured).await } else { PersistentTargetState::new(&cache_dir).await }; - let tunnel_parameters_generator = MullvadTunnelParametersGenerator { - tx: internal_event_tx.clone(), - }; - #[cfg(windows)] let exclude_paths = if settings.split_tunnel.enable_exclusions { settings @@ -621,18 +698,6 @@ where vec![] }; - let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( - &cache_dir, - true, - #[cfg(target_os = "android")] - Self::create_bypass_tx(&internal_event_tx), - ) - .await - .map_err(Error::InitRpcFactory)?; - - let api_availability = rpc_runtime.availability_handle(); - api_availability.suspend(); - let initial_api_endpoint = api::get_allowed_endpoint(rpc_runtime.address_cache.get_address().await); @@ -664,17 +729,8 @@ where .await .map_err(Error::TunnelError)?; - let endpoint_updater = api::ApiEndpointUpdaterHandle::new(); endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx)); - let proxy_provider = api::create_api_config_provider( - internal_event_tx.to_specialized_sender(), - ApiConnectionMode::Direct, - ); - let rpc_handle = rpc_runtime - .mullvad_rest_handle(proxy_provider, endpoint_updater.callback()) - .await; - Self::forward_offline_state(api_availability.clone(), offline_state_rx).await; let relay_list_listener = event_listener.clone(); @@ -700,28 +756,11 @@ where settings.show_beta_releases, ); tokio::spawn(version_updater.run()); - let account_history = - account_history::AccountHistory::new(&settings_dir, settings.get_account_token()) - .await - .map_err(Error::LoadAccountHistory)?; - - let wireguard_key_manager = wireguard::KeyManager::new( - internal_event_tx.clone(), - api_availability.clone(), - rpc_handle.clone(), - ); - - let account = account::Account::new( - runtime, - rpc_handle.clone(), - settings.get_account_token(), - api_availability.clone(), - ); // Attempt to download a fresh relay list relay_selector.update().await; - let mut daemon = Daemon { + let daemon = Daemon { tunnel_command_tx, tunnel_state: TunnelState::Disconnected, target_state, @@ -734,10 +773,10 @@ where event_listener, settings, account_history, - account, + device_checker: device::TunnelStateChangeHandler::new(account_manager.clone()), + account_manager, rpc_runtime, rpc_handle, - wireguard_key_manager, version_updater_handle, relay_selector, last_generated_relay: None, @@ -751,8 +790,6 @@ where volume_update_tx, }; - daemon.ensure_wireguard_keys_for_current_account().await; - api_availability.unsuspend(); Ok(daemon) @@ -856,10 +893,12 @@ where rpc_runtime, tunnel_state_machine_handle, target_state, + account_manager, .. } = self; shutdown_tasks.push(Box::pin(target_state.finalize())); + shutdown_tasks.push(Box::pin(account_manager.shutdown())); ( event_listener, @@ -881,16 +920,14 @@ where } Command(command) => self.handle_command(command).await, TriggerShutdown => self.trigger_shutdown_event(), - WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await, - NewAccountEvent(account_token, tx) => { - self.handle_new_account_event(account_token, tx).await - } NewAppVersionInfo(app_version_info) => { self.handle_new_app_version_info(app_version_info) } GenerateApiConnectionMode(request) => { self.handle_generate_api_connection_mode(request).await } + DeviceEvent(event) => self.handle_device_event(event).await, + DeviceMigrationEvent(event) => self.handle_device_migration_event(event).await, #[cfg(windows)] ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await, } @@ -902,6 +939,9 @@ where ) { self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition) .await; + self.device_checker + .handle_state_transition(&tunnel_state_transition); + let tunnel_state = match tunnel_state_transition { TunnelStateTransition::Disconnected => TunnelState::Disconnected, TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting { @@ -918,7 +958,12 @@ where TunnelStateTransition::Error(error_state) => TunnelState::Error(error_state), }; - self.unschedule_reconnect(); + if !tunnel_state.is_connected() { + // Cancel reconnects except when entering the connected state. + // Exempt the latter because a reconnect scheduled while connecting should not be + // aborted. + self.unschedule_reconnect(); + } log::debug!("New tunnel state: {:?}", tunnel_state); match tunnel_state { @@ -937,7 +982,7 @@ where } if let ErrorStateCause::AuthFailed(_) = error_state.cause() { - self.schedule_reconnect(Duration::from_secs(60)).await + self.schedule_reconnect(Duration::from_secs(60)) } } _ => {} @@ -967,7 +1012,7 @@ where >, retry_attempt: u32, ) { - if let Some(account_token) = self.settings.get_account_token() { + if let Ok(Some(device)) = self.account_manager.data().await { let result = match self.settings.get_relay_settings() { RelaySettings::CustomTunnelEndpoint(custom_relay) => { self.last_generated_relay = None; @@ -987,7 +1032,6 @@ where &constraints, self.settings.get_bridge_state(), retry_attempt, - self.settings.get_wireguard().is_some(), ) .ok(); if let Some(relays::RelaySelectorResult { @@ -1000,7 +1044,7 @@ where .create_tunnel_parameters( &exit_relay, endpoint, - account_token, + device.token, retry_attempt, ) .await; @@ -1111,7 +1155,13 @@ where .into()) } MullvadEndpoint::Wireguard(endpoint) => { - let wg_data = self.settings.get_wireguard().ok_or(Error::NoKeyAvailable)?; + let wg_data = self + .account_manager + .data() + .await + .map_err(|_| Error::NoKeyAvailable)? + .map(|device| device.wg_data) + .ok_or(Error::NoKeyAvailable)?; let tunnel = wireguard::TunnelConfig { private_key: wg_data.private_key, addresses: vec![ @@ -1135,7 +1185,7 @@ where } } - async fn schedule_reconnect(&mut self, delay: Duration) { + fn schedule_reconnect(&mut self, delay: Duration) { self.unschedule_reconnect(); let tunnel_command_tx = self.tx.to_specialized_sender(); @@ -1175,7 +1225,13 @@ where SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await, GetRelayLocations(tx) => self.on_get_relay_locations(tx), UpdateRelayLocations => self.on_update_relay_locations().await, - SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await, + LoginAccount(tx, account_token) => self.on_login_account(tx, account_token).await, + LogoutAccount(tx) => self.on_logout_account(tx).await, + GetDevice(tx) => self.on_get_device(tx).await, + ListDevices(tx, account_token) => self.on_list_devices(tx, account_token).await, + RemoveDevice(tx, account_token, device_id) => { + self.on_remove_device(tx, account_token, device_id).await + } GetAccountHistory(tx) => self.on_get_account_history(tx), ClearAccountHistory(tx) => self.on_clear_account_history(tx).await, UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update).await, @@ -1198,9 +1254,8 @@ where self.on_set_wireguard_rotation_interval(tx, interval).await } GetSettings(tx) => self.on_get_settings(tx), - GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await, + RotateWireguardKey(tx) => self.on_rotate_wireguard_key(tx).await, GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await, - VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await, GetVersionInfo(tx) => self.on_get_version_info(tx).await, GetCurrentVersion(tx) => self.on_get_current_version(tx), #[cfg(not(target_os = "android"))] @@ -1232,106 +1287,6 @@ where } } - async fn handle_wireguard_key_event( - &mut self, - event: ( - AccountToken, - Result<mullvad_types::wireguard::WireguardData, wireguard::Error>, - ), - ) { - let (account, result) = event; - // If the account has been reset whilst a key was being generated, the event should be - // dropped even if a new key was generated. - if self - .settings - .get_account_token() - .map(|current_account| current_account != account) - .unwrap_or(true) - { - log::info!("Dropping wireguard key event since account has been changed"); - return; - } - - match result { - Ok(data) => { - let public_key = data.get_public_key(); - let is_first_key = self.settings.get_wireguard().is_none(); - match self.settings.set_wireguard(Some(data)).await { - Ok(_) => { - if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY).await; - } - self.event_listener - .notify_key_event(KeygenEvent::NewKey(public_key)); - if is_first_key { - self.ensure_key_rotation().await; - } - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg( - "Failed to add new wireguard key to account data" - ) - ); - self.event_listener - .notify_key_event(KeygenEvent::GenerationFailure) - } - } - } - Err(wireguard::Error::TooManyKeys) => { - self.event_listener - .notify_key_event(KeygenEvent::TooManyKeys); - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg("Failed to generate wireguard key") - ); - self.event_listener - .notify_key_event(KeygenEvent::GenerationFailure); - } - } - } - - async fn ensure_key_rotation(&mut self) { - let token = match self.settings.get_account_token() { - Some(token) => token, - None => return, - }; - let public_key = match self.settings.get_wireguard() { - Some(data) => data.get_public_key(), - None => return, - }; - self.wireguard_key_manager - .set_rotation_interval( - public_key, - token, - self.settings.tunnel_options.wireguard.rotation_interval, - ) - .await; - } - - async fn handle_new_account_event( - &mut self, - new_token: AccountToken, - tx: ResponseTx<String, Error>, - ) { - match self.set_account(Some(new_token.clone())).await { - Ok(_) => { - self.set_target_state(TargetState::Unsecured).await; - let _ = tx.send(Ok(new_token)); - } - Err(err) => { - log::error!( - "{}", - err.display_chain_with_msg("Failed to save new account") - ); - let _ = tx.send(Err(Error::SettingsError(err))); - } - }; - } - fn handle_new_app_version_info(&mut self, app_version_info: AppVersionInfo) { self.app_version_info = Some(app_version_info.clone()); self.event_listener.notify_app_version(app_version_info); @@ -1409,6 +1364,49 @@ where let _ = request.response_tx.send(config); } + async fn handle_device_event(&mut self, event: InnerDeviceEvent) { + match &event { + InnerDeviceEvent::Login(device) => { + if let Err(error) = self.account_history.set(device.token.clone()).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update account history") + ); + } + if *self.target_state == TargetState::Secured { + log::debug!("Initiating tunnel restart because the account token changed"); + self.reconnect_tunnel(); + } + } + InnerDeviceEvent::Logout => { + log::info!("Disconnecting because account token was cleared"); + self.set_target_state(TargetState::Unsecured).await; + } + InnerDeviceEvent::RotatedKey(_) => { + if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { + self.schedule_reconnect(WG_RECONNECT_DELAY); + } + } + _ => (), + } + self.event_listener + .notify_device_event(DeviceEvent::from(event)); + } + + async fn handle_device_migration_event(&mut self, data: DeviceData) { + if let Ok(Some(_)) = self.account_manager.data().await { + // Discard stale device + return; + } + if let Err(error) = self.account_manager.set(data).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to move over account from old settings") + ); + } + self.reconnect_tunnel(); + } + #[cfg(windows)] async fn handle_new_excluded_paths( &mut self, @@ -1540,17 +1538,30 @@ where } async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) { - let daemon_tx = self.tx.clone(); - let future = self.account.create_account(); + let account_manager = self.account_manager.clone(); tokio::spawn(async move { - match future.await { - Ok(account_token) => { - let _ = daemon_tx.send(InternalDaemonEvent::NewAccountEvent(account_token, tx)); + let result = async { + if let Ok(Some(_)) = account_manager.data().await { + return Err(Error::AlreadyLoggedIn); } - Err(err) => { - let _ = tx.send(Err(Error::RestError(err))); - } - } + let token = account_manager + .account_service + .create_account() + .await + .map_err(Error::RestError)?; + account_manager + .login(token.clone()) + .await + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Creating new account failed") + ); + Error::LoginError(error) + })?; + Ok(token) + }; + Self::oneshot_send(tx, result.await, "create new account"); }); } @@ -1559,7 +1570,7 @@ where tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>, account_token: AccountToken, ) { - let account = self.account.clone(); + let account = self.account_manager.account_service.clone(); tokio::spawn(async move { let result = account.check_expiry(account_token).await; Self::oneshot_send( @@ -1571,16 +1582,18 @@ where } async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) { - if let Some(account_token) = self.settings.get_account_token() { - let future = self.account.get_www_auth_token(account_token); - let rpc_call = async { + if let Ok(Some(device)) = self.account_manager.data().await { + let future = self + .account_manager + .account_service + .get_www_auth_token(device.token); + tokio::spawn(async { Self::oneshot_send( tx, future.await.map_err(Error::RestError), "get_www_auth_token response", ); - }; - tokio::spawn(rpc_call); + }); } else { Self::oneshot_send( tx, @@ -1595,13 +1608,13 @@ where tx: ResponseTx<VoucherSubmission, Error>, voucher: String, ) { - if let Some(account_token) = self.settings.get_account_token() { - let mut account = self.account.clone(); + if let Ok(Some(device)) = self.account_manager.data().await { + let mut account = self.account_manager.account_service.clone(); tokio::spawn(async move { Self::oneshot_send( tx, account - .submit_voucher(account_token, voucher) + .submit_voucher(device.token, voucher) .await .map_err(Error::RestError), "submit_voucher response", @@ -1620,90 +1633,120 @@ where self.relay_selector.update().await; } - async fn on_set_account( - &mut self, - tx: ResponseTx<(), settings::Error>, - account_token: Option<String>, - ) { - match self.set_account(account_token.clone()).await { - Ok(account_changed) => { - if account_changed { - match account_token { - Some(_) => { - log::info!( - "Initiating tunnel restart because the account token changed" - ); - self.reconnect_tunnel(); - } - None => { - log::info!("Disconnecting because account token was cleared"); - self.set_target_state(TargetState::Unsecured).await; - } - }; + async fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) { + let account_manager = self.account_manager.clone(); + tokio::spawn(async move { + let result = async { + account_manager.login(account_token).await.map_err(|error| { + log::error!("{}", error.display_chain_with_msg("Login failed")); + Error::LoginError(error) + }) + }; + Self::oneshot_send(tx, result.await, "login_account response"); + }); + } + + async fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) { + let account_manager = self.account_manager.clone(); + tokio::spawn(async move { + let result = async { + account_manager.logout().await.map_err(|error| { + log::error!("{}", error.display_chain_with_msg("Logout failed")); + Error::LogoutError(error) + }) + }; + Self::oneshot_send(tx, result.await, "logout_account response"); + }); + } + + async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceConfig>, Error>) { + let account_manager = self.account_manager.clone(); + tokio::spawn(async move { + // Make sure the device is updated + match account_manager.validate_device().await { + Ok(_) | Err(device::Error::NoDevice) => (), + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update device data") + ); } - Self::oneshot_send(tx, Ok(()), "set_account response"); } - Err(error) => { - log::error!("{}", error.display_chain_with_msg("Failed to set account")); - Self::oneshot_send(tx, Err(error), "set_account response"); - } - } + + Self::oneshot_send( + tx, + Ok(account_manager + .data() + .await + .unwrap_or(None) + .map(DeviceConfig::from)), + "get_device response", + ); + }); } - async fn set_account( - &mut self, - account_token: Option<String>, - ) -> Result<bool, settings::Error> { - let previous_token = self.settings.get_account_token(); - let account_changed = self - .settings - .set_account_token(account_token.clone()) - .await?; - if account_changed { - self.event_listener - .notify_settings(self.settings.to_settings()); + async fn on_list_devices(&self, tx: ResponseTx<Vec<Device>, Error>, token: AccountToken) { + let service = self.account_manager.device_service.clone(); + tokio::spawn(async move { + Self::oneshot_send( + tx, + service + .list_devices(token) + .await + .map_err(Error::ListDevicesError), + "list_devices response", + ); + }); + } - let history_token = match account_token { - Some(token) => token, - None => previous_token.clone().unwrap_or("".to_string()), - }; - if let Err(error) = self.account_history.set(history_token).await { - log::error!( - "{}", - error.display_chain_with_msg("Failed to update account history") - ); - } + async fn on_remove_device( + &mut self, + tx: ResponseTx<(), Error>, + token: AccountToken, + device_id: DeviceId, + ) { + let device_service = self.account_manager.device_service.clone(); + let event_listener = self.event_listener.clone(); - if let Some(previous_token) = previous_token { - if let Some(previous_key) = self - .settings - .get_wireguard() - .map(|data| data.private_key.public_key()) - { - let remove_key = self - .wireguard_key_manager - .remove_key_with_backoff(previous_token, previous_key); - tokio::spawn(async move { - if let Err(error) = remove_key.await { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to remove WireGuard key for previous account" - ) - ); - } - }); + tokio::spawn(async move { + let mut devices = match device_service + .list_devices(token.clone()) + .await + .map_err(Error::ListDevicesError) + { + Ok(devices) => devices, + Err(error) => { + Self::oneshot_send(tx, Err(error), "remove_device response"); + return; } - } - if let Err(error) = self.settings.set_wireguard(None).await { - log::error!( - "{}", - error.display_chain_with_msg("Error resetting WireGuard key") - ); - } - self.ensure_wireguard_keys_for_current_account().await; - } - Ok(account_changed) + }; + if let Err(error) = device_service + .remove_device(token.clone(), device_id.clone()) + .await + .map_err(Error::RemoveDeviceError) + { + Self::oneshot_send(tx, Err(error), "remove_device response"); + return; + }; + let removed_device = + if let Some(index) = devices.iter().position(|device| device.id == device_id) { + devices.swap_remove(index) + } else { + log::error!("List did not contain the revoked device"); + Device { + id: device_id, + name: "unknown device".to_string(), + pubkey: talpid_types::net::wireguard::PublicKey::from([0u8; 32]), + ports: vec![], + } + }; + event_listener.notify_remove_device_event(RemoveDeviceEvent { + account_token: token, + removed_device, + new_devices: devices, + }); + Self::oneshot_send(tx, Ok(()), "remove_device response"); + }); } fn on_get_account_history(&mut self, tx: oneshot::Sender<Option<AccountToken>>) { @@ -1723,37 +1766,6 @@ where Self::oneshot_send(tx, result, "clear_account_history response"); } - // Remove the key associated with the current account, if there is one. - // This does not modify settings or account history. - #[cfg(not(target_os = "android"))] - fn remove_current_key_rpc(&self) -> impl std::future::Future<Output = Result<(), Error>> { - let remove_key = if let Some(token) = self.settings.get_account_token() { - if let Some(wg_data) = self.settings.get_wireguard() { - Some( - self.wireguard_key_manager - .remove_key(token, wg_data.private_key.public_key()), - ) - } else { - None - } - } else { - None - }; - - async move { - if let Some(task) = remove_key { - match task.await { - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - // This result should never occur - Err(wireguard::Error::TooManyKeys) => Err(Error::TooManyKeys), - _ => Ok(()), - } - } else { - Ok(()) - } - } - } - async fn on_get_version_info(&mut self, tx: oneshot::Sender<Option<AppVersionInfo>>) { if self.app_version_info.is_none() { log::debug!("No version cache found. Fetching new info"); @@ -1795,17 +1807,13 @@ where async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) { let mut last_error = Ok(()); - let remove_key = self.remove_current_key_rpc(); - tokio::spawn(async move { - if let Err(error) = remove_key.await { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to remove WireGuard key for previous account" - ) - ); - } - }); + if let Err(error) = self.account_manager.logout().await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to clear device cache") + ); + last_error = Err(Error::LogoutError(error)); + } if let Err(error) = self.account_history.clear().await { log::error!( @@ -2315,7 +2323,16 @@ where Ok(settings_changed) => { Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response"); if settings_changed { - self.ensure_key_rotation().await; + if let Err(error) = self + .account_manager + .set_rotation_interval(interval.unwrap_or_default()) + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update rotation interval") + ); + } self.event_listener .notify_settings(self.settings.to_settings()); } @@ -2327,128 +2344,27 @@ where } } - async fn ensure_wireguard_keys_for_current_account(&mut self) { - if let Some(account) = self.settings.get_account_token() { - if self.settings.get_wireguard().is_none() { - log::info!("Generating new WireGuard key for account"); - self.wireguard_key_manager - .spawn_key_generation_task(account, Some(FIRST_KEY_PUSH_TIMEOUT)) - .await; - } else { - log::info!("Account already has WireGuard key"); - self.ensure_key_rotation().await; - } - } - } - - async fn on_generate_wireguard_key(&mut self, tx: ResponseTx<KeygenEvent, Error>) { - match self.on_generate_wireguard_key_inner().await { - Ok(key_event) => { - Self::oneshot_send(tx, Ok(key_event), "generate_wireguard_key"); - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg("Failed to generate new wireguard key") - ); - Self::oneshot_send(tx, Err(e), "generate_wireguard_key"); - } - } - } - - async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, Error> { - let account_token = self - .settings - .get_account_token() - .ok_or(Error::NoAccountToken)?; - let wireguard_data = self.settings.get_wireguard(); - - let gen_result = match &wireguard_data { - Some(wireguard_data) => { - self.wireguard_key_manager - .replace_key(account_token.clone(), wireguard_data.get_public_key()) - .await - } - None => { - self.wireguard_key_manager - .generate_key_sync(account_token.clone()) - .await - } - }; - - match gen_result { - Ok(new_data) => { - let public_key = new_data.get_public_key(); - self.settings - .set_wireguard(Some(new_data)) - .await - .map_err(Error::SettingsError)?; - if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY).await; - } - let keygen_event = KeygenEvent::NewKey(public_key.clone()); - self.event_listener.notify_key_event(keygen_event.clone()); - - // update automatic rotation - self.wireguard_key_manager - .set_rotation_interval( - public_key, - account_token, - self.settings.tunnel_options.wireguard.rotation_interval, - ) - .await; - - Ok(keygen_event) - } - Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys), - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), - } + async fn on_rotate_wireguard_key(&self, tx: ResponseTx<(), Error>) { + let manager = self.account_manager.clone(); + tokio::spawn(async move { + let result = manager + .rotate_key() + .await + .map(|_| ()) + .map_err(Error::KeyRotationError); + Self::oneshot_send(tx, result, "rotate_wireguard_key response"); + }); } - async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<wireguard::PublicKey>, Error>) { - let result = if self.settings.get_account_token().is_some() { - Ok(self - .settings - .get_wireguard() - .map(|data| data.get_public_key())) + async fn on_get_wireguard_key(&self, tx: ResponseTx<Option<PublicKey>, Error>) { + let result = if let Ok(Some(device)) = self.account_manager.data().await { + Ok(Some(device.wg_data.get_public_key())) } else { Err(Error::NoAccountToken) }; Self::oneshot_send(tx, result, "get_wireguard_key response"); } - async fn on_verify_wireguard_key(&mut self, tx: ResponseTx<bool, Error>) { - let account = match self.settings.get_account_token() { - Some(account) => account, - None => { - Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response"); - return; - } - }; - let public_key = match self.settings.get_wireguard() { - Some(wg_data) => wg_data.private_key.public_key(), - None => { - Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response"); - return; - } - }; - - let verification_rpc = self - .wireguard_key_manager - .verify_wireguard_key(account, public_key); - - tokio::spawn(async move { - let result = match verification_rpc.await { - Ok(is_valid) => Ok(is_valid), - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), - Err(wireguard::Error::TooManyKeys) => return, - }; - Self::oneshot_send(tx, result, "verify_wireguard_key response"); - }); - } - fn on_get_settings(&self, tx: oneshot::Sender<Settings>) { Self::oneshot_send(tx, self.settings.to_settings(), "get_settings response"); } diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index ba828ed903..b6413b357b 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -1,4 +1,4 @@ -use crate::{account_history, settings, DaemonCommand, DaemonCommandSender, EventListener}; +use crate::{account_history, device, settings, DaemonCommand, DaemonCommandSender, EventListener}; use futures::{ channel::{mpsc, oneshot}, StreamExt, @@ -370,6 +370,7 @@ impl ManagementService for ManagementServiceImpl { // async fn create_new_account(&self, _: Request<()>) -> ServiceResult<String> { + log::debug!("create_new_account"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::CreateNewAccount(tx))?; self.wait_for_result(rx) @@ -378,20 +379,25 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_daemon_error) } - async fn set_account(&self, request: Request<AccountToken>) -> ServiceResult<()> { - log::debug!("set_account"); + async fn login_account(&self, request: Request<AccountToken>) -> ServiceResult<()> { + log::debug!("login_account"); let account_token = request.into_inner(); - let account_token = if account_token == "" { - None - } else { - Some(account_token) - }; let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::SetAccount(tx, account_token))?; + self.send_command_to_daemon(DaemonCommand::LoginAccount(tx, account_token))?; self.wait_for_result(rx) .await? .map(Response::new) - .map_err(map_settings_error) + .map_err(map_daemon_error) + } + + async fn logout_account(&self, _: Request<()>) -> ServiceResult<()> { + log::debug!("logout_account"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::LogoutAccount(tx))?; + self.wait_for_result(rx) + .await? + .map(Response::new) + .map_err(map_daemon_error) } async fn get_account_data( @@ -479,6 +485,44 @@ impl ManagementService for ManagementServiceImpl { }) } + // Device management + async fn get_device(&self, _: Request<()>) -> ServiceResult<types::DeviceConfig> { + log::debug!("get_device"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::GetDevice(tx))?; + let device = self + .wait_for_result(rx) + .await? + .map_err(map_daemon_error)? + .ok_or(Status::new(Code::NotFound, "no device is set"))?; + Ok(Response::new(types::DeviceConfig::from(device))) + } + + async fn list_devices( + &self, + request: Request<AccountToken>, + ) -> ServiceResult<types::DeviceList> { + log::debug!("list_devices"); + let (tx, rx) = oneshot::channel(); + let token = request.into_inner(); + self.send_command_to_daemon(DaemonCommand::ListDevices(tx, token))?; + let device = self.wait_for_result(rx).await?.map_err(map_daemon_error)?; + Ok(Response::new(types::DeviceList::from(device))) + } + + async fn remove_device(&self, request: Request<types::DeviceRemoval>) -> ServiceResult<()> { + log::debug!("remove_device"); + let (tx, rx) = oneshot::channel(); + let removal = request.into_inner(); + self.send_command_to_daemon(DaemonCommand::RemoveDevice( + tx, + removal.account_token, + removal.device_id, + ))?; + self.wait_for_result(rx).await?.map_err(map_daemon_error)?; + Ok(Response::new(())) + } + // WireGuard key management // @@ -515,15 +559,13 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_settings_error) } - async fn generate_wireguard_key(&self, _: Request<()>) -> ServiceResult<types::KeygenEvent> { - // TODO: return error for TooManyKeys, GenerationFailure - // on success, simply return the new key or nil - log::debug!("generate_wireguard_key"); + async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> { + log::debug!("rotate_wireguard_key"); let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::GenerateWireguardKey(tx))?; + self.send_command_to_daemon(DaemonCommand::RotateWireguardKey(tx))?; self.wait_for_result(rx) .await? - .map(|event| Response::new(types::KeygenEvent::from(event))) + .map(Response::new) .map_err(map_daemon_error) } @@ -538,16 +580,6 @@ impl ManagementService for ManagementServiceImpl { } } - async fn verify_wireguard_key(&self, _: Request<()>) -> ServiceResult<bool> { - log::debug!("verify_wireguard_key"); - let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::VerifyWireguardKey(tx))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_daemon_error) - } - // Split tunneling // @@ -832,14 +864,23 @@ impl EventListener for ManagementInterfaceEventBroadcaster { }) } - fn notify_key_event(&self, key_event: mullvad_types::wireguard::KeygenEvent) { - log::debug!("Broadcasting new wireguard key event"); + fn notify_device_event(&self, device: mullvad_types::device::DeviceEvent) { + log::debug!("Broadcasting device event"); self.notify(types::DaemonEvent { - event: Some(daemon_event::Event::KeyEvent(types::KeygenEvent::from( - key_event, + event: Some(daemon_event::Event::Device(types::DeviceEvent::from( + device, ))), }) } + + fn notify_remove_device_event(&self, remove_event: mullvad_types::device::RemoveDeviceEvent) { + log::debug!("Broadcasting remove device event"); + self.notify(types::DaemonEvent { + event: Some(daemon_event::Event::RemoveDevice( + types::RemoveDeviceEvent::from(remove_event), + )), + }) + } } impl ManagementInterfaceEventBroadcaster { @@ -857,6 +898,12 @@ fn map_daemon_error(error: crate::Error) -> Status { match error { DaemonError::RestError(error) => map_rest_error(error), DaemonError::SettingsError(error) => map_settings_error(error), + DaemonError::AlreadyLoggedIn => Status::already_exists(error.to_string()), + DaemonError::LoginError(error) => map_device_error(error), + DaemonError::LogoutError(error) => map_device_error(error), + DaemonError::KeyRotationError(error) => map_device_error(error), + DaemonError::ListDevicesError(error) => map_device_error(error), + DaemonError::RemoveDeviceError(error) => map_device_error(error), #[cfg(windows)] DaemonError::SplitTunnelError(error) => map_split_tunnel_error(error), DaemonError::AccountHistory(error) => map_account_history_error(error), @@ -929,6 +976,22 @@ fn map_settings_error(error: settings::Error) -> Status { } } +/// Converts an instance of [`mullvad_daemon::device::Error`] into a tonic status. +fn map_device_error(error: device::Error) -> Status { + match error { + device::Error::MaxDevicesReached => Status::new(Code::ResourceExhausted, error.to_string()), + device::Error::InvalidAccount => Status::new(Code::Unauthenticated, error.to_string()), + device::Error::InvalidDevice | device::Error::NoDevice => { + Status::new(Code::NotFound, error.to_string()) + } + device::Error::DeviceIoError(ref _error) => { + Status::new(Code::Unavailable, error.to_string()) + } + device::Error::OtherRestError(error) => map_rest_error(error), + _ => Status::new(Code::Unknown, error.to_string()), + } +} + /// Converts an instance of [`mullvad_daemon::account_history::Error`] into a tonic status. fn map_account_history_error(error: account_history::Error) -> Status { match error { diff --git a/mullvad-daemon/src/migrations/mod.rs b/mullvad-daemon/src/migrations/mod.rs index 98ad71c23c..8347b3cd76 100644 --- a/mullvad-daemon/src/migrations/mod.rs +++ b/mullvad-daemon/src/migrations/mod.rs @@ -87,7 +87,12 @@ pub enum Error { pub type Result<T> = std::result::Result<T, Error>; -pub async fn migrate_all(cache_dir: &Path, settings_dir: &Path) -> Result<()> { +pub(crate) async fn migrate_all( + cache_dir: &Path, + settings_dir: &Path, + rest_handle: mullvad_rpc::rest::MullvadRestHandle, + daemon_tx: crate::DaemonEventSender, +) -> Result<()> { #[cfg(windows)] windows::migrate_after_windows_update(settings_dir) .await @@ -114,11 +119,12 @@ pub async fn migrate_all(cache_dir: &Path, settings_dir: &Path) -> Result<()> { v2::migrate(&mut settings)?; v3::migrate(&mut settings)?; v4::migrate(&mut settings)?; - v5::migrate(&mut settings)?; account_history::migrate_location(cache_dir, settings_dir).await; account_history::migrate_formats(settings_dir, &mut settings).await?; + v5::migrate(&mut settings, rest_handle, daemon_tx).await?; + if settings == old_settings { // Nothing changed return Ok(()); diff --git a/mullvad-daemon/src/migrations/v5.rs b/mullvad-daemon/src/migrations/v5.rs index 0fcaca4e08..0695ee8c7e 100644 --- a/mullvad-daemon/src/migrations/v5.rs +++ b/mullvad-daemon/src/migrations/v5.rs @@ -1,5 +1,10 @@ use super::{Error, Result}; -use mullvad_types::settings::SettingsVersion; +use crate::{device::DeviceService, DaemonEventSender, InternalDaemonEvent}; +use mullvad_types::{ + account::AccountToken, device::DeviceData, settings::SettingsVersion, wireguard::WireguardData, +}; +use talpid_core::mpsc::Sender; +use talpid_types::ErrorExt; // ====================================================== // Section for vendoring types and values that @@ -21,16 +26,48 @@ use mullvad_types::settings::SettingsVersion; /// * `use_mulithop` was not present in the settings /// * A multihop entry location had been previously specified. /// -/// This change is backwards compatible since older daemons will just ignore `use_multihop` if -/// present. -/// /// It is also no longer valid to have `entry_location` set to null. So remove the field if it /// is null in order to make it default back to the default location. -pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { - if !version_matches(settings) { - return Ok(()); +/// +/// This also removes the account token and WireGuard key from the settings, looks up the +/// corresponding device, and eventually stores them in `device.json` instead. This is done by +/// sending the `DeviceMigrationEvent` event to the daemon. Because this is fallible, it can +/// result in the account token and private key being lost. This should not be not critical since +/// the account token is also stored in the account history. +pub(crate) async fn migrate( + settings: &mut serde_json::Value, + rest_handle: mullvad_rpc::rest::MullvadRestHandle, + daemon_tx: DaemonEventSender, +) -> Result<()> { + let migration_data = migrate_inner(settings).await?; + + if let Some(migration_data) = migration_data { + let api_handle = rest_handle.availability.clone(); + let service = DeviceService::new(rest_handle, api_handle); + match (migration_data.token, migration_data.wg_data) { + (token, Some(wg_data)) => { + log::info!("Creating a new device cache from previous settings"); + tokio::spawn(cache_from_wireguard_key(daemon_tx, service, token, wg_data)); + } + (token, None) => { + log::info!("Generating a new device for the account"); + tokio::spawn(cache_from_account(daemon_tx, service, token)); + } + } } + Ok(()) +} + +struct MigrationData { + token: AccountToken, + wg_data: Option<WireguardData>, +} + +async fn migrate_inner(settings: &mut serde_json::Value) -> Result<Option<MigrationData>> { + if !version_matches(settings) { + return Ok(None); + } let wireguard_constraints = || -> Option<&serde_json::Value> { settings .get("relay_settings")? @@ -54,11 +91,35 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { } } + if let Some(token) = settings.get("account_token").filter(|t| !t.is_null()) { + let token: AccountToken = + serde_json::from_value(token.clone()).map_err(Error::ParseError)?; + let mig_data = if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) { + let wg_data: WireguardData = + serde_json::from_value(wg_data.clone()).map_err(Error::ParseError)?; + Ok(Some(MigrationData { + token, + wg_data: Some(wg_data), + })) + } else { + Ok(Some(MigrationData { + token, + wg_data: None, + })) + }; + + let settings_map = settings.as_object_mut().ok_or(Error::NoMatchingVersion)?; + settings_map.remove("account_token"); + settings_map.remove("wireguard"); + + return mig_data; + } + // Note: Not incrementing the version number yet, since this migration is still open // for future modification. // settings["settings_version"] = serde_json::json!(SettingsVersion::V6); - Ok(()) + Ok(None) } fn version_matches(settings: &mut serde_json::Value) -> bool { @@ -68,9 +129,56 @@ fn version_matches(settings: &mut serde_json::Value) -> bool { .unwrap_or(false) } +async fn cache_from_wireguard_key( + daemon_tx: DaemonEventSender, + service: DeviceService, + token: AccountToken, + wg_data: WireguardData, +) { + let devices = match service.list_devices_with_backoff(token.clone()).await { + Ok(devices) => devices, + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to enumerate devices for account") + ); + return; + } + }; + + for device in devices.into_iter() { + if device.pubkey == wg_data.private_key.public_key() { + let _ = daemon_tx.send(InternalDaemonEvent::DeviceMigrationEvent(DeviceData { + token, + device, + wg_data, + })); + return; + } + } + log::info!("The existing WireGuard key is not valid; generating a new device"); + cache_from_account(daemon_tx, service, token).await; +} + +async fn cache_from_account( + daemon_tx: DaemonEventSender, + service: DeviceService, + token: AccountToken, +) { + match service.generate_for_account_with_backoff(token).await { + Ok(device_data) => { + let _ = daemon_tx.send(InternalDaemonEvent::DeviceMigrationEvent(device_data)); + } + Err(error) => log::error!( + "{}", + error.display_chain_with_msg("Failed to generate new device for account") + ), + } +} + #[cfg(test)] mod test { - use super::{migrate, version_matches}; + use super::{migrate_inner, version_matches}; use serde_json; pub const V5_SETTINGS_V1: &str = r#" @@ -144,7 +252,6 @@ mod test { pub const V5_SETTINGS_V2: &str = r#" { - "account_token": "1234", "relay_settings": { "normal": { "location": { @@ -212,13 +319,12 @@ mod test { } "#; - #[test] - fn test_v5_v1_migration() { + #[tokio::test] + async fn test_v5_v1_migration() { let mut old_settings = serde_json::from_str(V5_SETTINGS_V1).unwrap(); assert!(version_matches(&mut old_settings)); - - migrate(&mut old_settings).unwrap(); + migrate_inner(&mut old_settings).await.unwrap(); let new_settings: serde_json::Value = serde_json::from_str(V5_SETTINGS_V2).unwrap(); assert_eq!(&old_settings, &new_settings); diff --git a/mullvad-daemon/src/relays/mod.rs b/mullvad-daemon/src/relays/mod.rs index 332ca5fea3..c4a136369c 100644 --- a/mullvad-daemon/src/relays/mod.rs +++ b/mullvad-daemon/src/relays/mod.rs @@ -276,7 +276,6 @@ impl RelaySelector { relay_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> Result<RelaySelectorResult, Error> { match relay_constraints.tunnel_protocol { Constraint::Only(TunnelType::OpenVpn) => self.get_openvpn_endpoint( @@ -293,12 +292,9 @@ impl RelaySelector { &relay_constraints.wireguard_constraints, retry_attempt, ), - Constraint::Any => self.get_any_tunnel_endpoint( - relay_constraints, - bridge_state, - retry_attempt, - wg_key_exists, - ), + Constraint::Any => { + self.get_any_tunnel_endpoint(relay_constraints, bridge_state, retry_attempt) + } } } @@ -479,14 +475,9 @@ impl RelaySelector { relay_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> Result<RelaySelectorResult, Error> { - let preferred_constraints = self.preferred_constraints( - &relay_constraints, - bridge_state, - retry_attempt, - wg_key_exists, - ); + let preferred_constraints = + self.preferred_constraints(&relay_constraints, bridge_state, retry_attempt); let original_matcher: RelayMatcher<_> = relay_constraints.clone().into(); let preferred_tunnel_protocol = preferred_constraints.tunnel_protocol; @@ -543,14 +534,12 @@ impl RelaySelector { original_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> RelayConstraints { let (preferred_port, preferred_protocol, preferred_tunnel) = self .preferred_tunnel_constraints( retry_attempt, &original_constraints.location, &original_constraints.providers, - wg_key_exists, ); let mut relay_constraints = original_constraints.clone(); @@ -731,7 +720,6 @@ impl RelaySelector { retry_attempt: u32, location_constraint: &Constraint<LocationConstraint>, providers_constraint: &Constraint<Providers>, - wg_key_exists: bool, ) -> (Constraint<u16>, TransportProtocol, TunnelType) { #[cfg(target_os = "windows")] { @@ -757,7 +745,7 @@ impl RelaySelector { }); // If location does not support WireGuard, defer to preferred OpenVPN tunnel // constraints - if !location_supports_wireguard || !wg_key_exists { + if !location_supports_wireguard { let (preferred_port, preferred_protocol) = Self::preferred_openvpn_constraints(retry_attempt); return (preferred_port, preferred_protocol, TunnelType::OpenVpn); @@ -1159,7 +1147,7 @@ mod test { }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) @@ -1167,7 +1155,7 @@ mod test { for attempt in 0..10 { assert!(relay_selector - .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .is_ok()); } @@ -1184,7 +1172,7 @@ mod test { }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1192,7 +1180,7 @@ mod test { for attempt in 0..10 { assert!(relay_selector - .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .is_ok()); } @@ -1205,7 +1193,6 @@ mod test { &relay_constraints, BridgeState::Off, attempt, - true, ); assert_eq!( preferred.tunnel_protocol, @@ -1215,7 +1202,6 @@ mod test { &relay_constraints, BridgeState::Off, attempt, - true, ) { Ok(result) if matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)) => (), _ => panic!("OpenVPN endpoint was not selected"), @@ -1250,14 +1236,14 @@ mod test { // The same host cannot be used for entry and exit assert!(relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .is_err()); relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location2); // If the entry and exit differ, this should succeed assert!(relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .is_ok()); } @@ -1286,7 +1272,7 @@ mod test { // The exit must not equal the entry let exit_relay = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .map_err(|error| error.to_string())? .exit_relay; @@ -1301,7 +1287,7 @@ mod test { endpoint, .. } = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .map_err(|error| error.to_string())?; assert_eq!(exit_relay.hostname, specific_hostname); @@ -1336,7 +1322,7 @@ mod test { }); let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1362,7 +1348,7 @@ mod test { ..RelayConstraints::default() }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) @@ -1381,14 +1367,14 @@ mod test { #[cfg(all(unix, not(target_os = "android")))] { let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) ); } let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1405,54 +1391,6 @@ mod test { } #[test] - fn test_wg_relay_with_no_key() { - let mut relay_constraints = RelayConstraints { - tunnel_protocol: Constraint::Only(TunnelType::Wireguard), - ..RelayConstraints::default() - }; - - let relay_selector = new_relay_selector(); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect("Failed to get WireGuard relay when WireGuard relay was specified as the only tunnel protocol"); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - - relay_constraints.tunnel_protocol = Constraint::Any; - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect("Failed to get OpenVPN relay with tunnel protocol constraint set to Any and without a WireGuard key"); - - assert!(matches!(result.endpoint, MullvadEndpoint::OpenVpn(_))); - - let wireguard_specific_location = LocationConstraint::Hostname( - "se".to_string(), - "got".to_string(), - "se9-wireguard".to_string(), - ); - relay_constraints.location = Constraint::Only(wireguard_specific_location); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect( - "Failed to get a valid WireGuard relay when tunnel constraints are set to any - tunnel protocol and with a wireguard specific location without a wireguard key", - ); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) - .expect( - "Failed to get a valid WireGuard relay when tunnel constraints are set to any - tunnel protocol and with a wireguard specific location with a wireguard key", - ); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - } - - #[test] fn test_selecting_any_relay_will_consider_multihop() { let relay_constraints = RelayConstraints { wireguard_constraints: WireguardConstraints { @@ -1467,7 +1405,7 @@ mod test { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection"); // Windows will ignore WireGuard until WireGuard is supported well enough // TODO: Remove this caveat once Windows defaults to using WireGuard @@ -1502,7 +1440,7 @@ mod test { fn test_selecting_wireguard_location_will_consider_multihop() { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, true) + let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0) .expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection"); @@ -1526,7 +1464,7 @@ mod test { let relay_selector = new_relay_selector(); let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get WireGuard TCP multihop relay"); assert!(result.entry_relay.is_some()); @@ -1555,7 +1493,7 @@ mod test { let relay_selector = new_relay_selector(); let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get WireGuard TCP relay"); let endpoint = result.endpoint.unwrap_wireguard(); assert!(matches!(endpoint.peer.protocol, TransportProtocol::Tcp)); @@ -1570,7 +1508,7 @@ mod test { const INVALID_UDP_PORTS: [u16; 2] = [80, 443]; for attempt in 0..1000 { let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .expect("Failed to get WireGuard TCP multihop relay"); assert!(!INVALID_UDP_PORTS.contains(&result.endpoint.to_endpoint().address.port())); assert_eq!( @@ -1587,7 +1525,7 @@ mod test { const VALID_TCP_PORTS: [u16; 3] = [80, 443, 5001]; for attempt in 0..1000 { let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .expect("Failed to get WireGuard TCP multihop relay"); assert!(VALID_TCP_PORTS.contains(&result.endpoint.to_endpoint().address.port())); assert_eq!( @@ -1609,7 +1547,7 @@ mod test { ..RelayConstraints::default() }; relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&constraints, BridgeState::Off, 0) .expect_err("Successfully selected a relay that should be filtered"); constraints.location = Constraint::Only(LocationConstraint::Hostname( @@ -1619,7 +1557,7 @@ mod test { )); relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&constraints, BridgeState::Off, 0) .expect_err("Successfully selected a relay that should be filtered"); } } diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index ec610f63d4..bf3fe710c8 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -3,7 +3,7 @@ use futures::TryFutureExt; use mullvad_types::{ relay_constraints::{BridgeSettings, BridgeState, RelaySettingsUpdate}, settings::{DnsOptions, Settings}, - wireguard::{RotationInterval, WireguardData}, + wireguard::RotationInterval, }; #[cfg(target_os = "windows")] use std::collections::HashSet; @@ -191,21 +191,6 @@ impl SettingsPersister { settings } - /// Changes account number to the one given. Also saves the new settings to disk. - /// The boolean in the Result indicates if the account token changed or not - pub async fn set_account_token( - &mut self, - account_token: Option<String>, - ) -> Result<bool, Error> { - let should_save = self.settings.set_account_token(account_token); - self.update(should_save).await - } - - pub async fn set_wireguard(&mut self, wireguard: Option<WireguardData>) -> Result<bool, Error> { - let should_save = self.settings.set_wireguard(wireguard); - self.update(should_save).await - } - pub async fn update_relay_settings( &mut self, update: RelaySettingsUpdate, diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs deleted file mode 100644 index eb198b858b..0000000000 --- a/mullvad-daemon/src/wireguard.rs +++ /dev/null @@ -1,499 +0,0 @@ -use crate::{DaemonEventSender, InternalDaemonEvent}; -use chrono::offset::Utc; -use mullvad_rpc::{ - availability::ApiAvailabilityHandle, - rest::{Error as RestError, MullvadRestHandle}, -}; -use mullvad_types::account::AccountToken; -pub use mullvad_types::wireguard::*; -use std::{future::Future, pin::Pin, time::Duration}; - -use futures::future::{abortable, AbortHandle}; -#[cfg(not(target_os = "android"))] -use talpid_core::future_retry::constant_interval; -use talpid_core::{ - future_retry::{retry_future, retry_future_n, ExponentialBackoff, Jittered}, - mpsc::Sender, -}; - -pub use talpid_types::net::wireguard::{ - ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters, -}; -use talpid_types::ErrorExt; - -/// How long to wait before starting key rotation -const ROTATION_START_DELAY: Duration = Duration::from_secs(60 * 3); - -/// How often to check whether the key has expired. -/// A short interval is used in case the computer is ever suspended. -const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60); - -const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4); -const RETRY_INTERVAL_FACTOR: u32 = 5; -const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); - -#[cfg(not(target_os = "android"))] -const SHORT_RETRY_INTERVAL: Duration = Duration::ZERO; - -const MAX_KEY_REMOVAL_RETRIES: usize = 2; - -#[derive(err_derive::Error, Debug)] -pub enum Error { - #[error(display = "Unexpected HTTP request error")] - RestError(#[error(source)] mullvad_rpc::rest::Error), - #[error(display = "API availability check was interrupted")] - ApiCheckError(#[error(source)] mullvad_rpc::availability::Error), - #[error(display = "Account already has maximum number of keys")] - TooManyKeys, -} - -pub type Result<T> = std::result::Result<T, Error>; - -pub struct KeyManager { - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - current_job: Option<AbortHandle>, - - abort_scheduler_tx: Option<AbortHandle>, - auto_rotation_interval: RotationInterval, -} - -impl KeyManager { - pub(crate) fn new( - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - ) -> Self { - Self { - daemon_tx, - availability_handle, - http_handle, - current_job: None, - abort_scheduler_tx: None, - auto_rotation_interval: RotationInterval::default(), - } - } - - /// Reset key rotation, cancelling the current one and starting a new one for the specified - /// account - pub async fn reset_rotation(&mut self, current_key: PublicKey, account_token: AccountToken) { - self.run_automatic_rotation(account_token, current_key) - .await - } - - /// Update automatic key rotation interval - /// Passing `None` for the interval will cause the default value to be used. - pub async fn set_rotation_interval( - &mut self, - current_key: PublicKey, - account_token: AccountToken, - auto_rotation_interval: Option<RotationInterval>, - ) { - self.auto_rotation_interval = auto_rotation_interval.unwrap_or_default(); - self.reset_rotation(current_key, account_token).await; - } - - /// Stop current key generation - pub fn reset(&mut self) { - if let Some(job) = self.current_job.take() { - job.abort() - } - } - - /// Generate a new private key - pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> { - self.reset(); - let private_key = PrivateKey::new_from_random(); - - self.push_future_generator(account, private_key, None)() - .await - .map_err(Self::map_rpc_error) - } - - /// Replace a key for an account synchronously - pub async fn replace_key( - &mut self, - account: AccountToken, - old_key: PublicKey, - ) -> Result<WireguardData> { - self.reset(); - - let new_key = PrivateKey::new_from_random(); - Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await - } - - /// Verifies whether a key is valid or not. - pub fn verify_wireguard_key( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<bool>> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - async move { - match rpc.get_wireguard_key(account, &key).await { - Ok(_) => Ok(true), - Err(mullvad_rpc::rest::Error::ApiError(status, _code)) - if status == mullvad_rpc::StatusCode::NOT_FOUND => - { - Ok(false) - } - Err(err) => Err(Self::map_rpc_error(err)), - } - } - } - - /// Removes a key from an account - #[cfg(not(target_os = "android"))] - pub fn remove_key( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<()>> { - self.remove_key_inner(account, key, constant_interval(SHORT_RETRY_INTERVAL), false) - } - - /// Removes a key from an account - pub fn remove_key_with_backoff( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<()>> { - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - self.remove_key_inner(account, key, retry_strategy, true) - } - - fn remove_key_inner<D: Iterator<Item = Duration> + 'static>( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - retry_strategy: D, - offline_check: bool, - ) -> impl Future<Output = Result<()>> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - let api_handle = self.availability_handle.clone(); - let api_handle_2 = api_handle.clone(); - let future = retry_future_n( - move || { - let remove_key = rpc.remove_wireguard_key(account.clone(), key.clone()); - let wait_future = api_handle.wait_online(); - async move { - if offline_check { - let _ = wait_future.await; - } - remove_key.await - } - }, - move |result| match result { - Ok(_) => false, - Err(error) => Self::should_retry_removal(error, &api_handle_2), - }, - retry_strategy, - MAX_KEY_REMOVAL_RETRIES, - ); - async move { future.await.map_err(Self::map_rpc_error) } - } - - fn should_retry_removal(error: &RestError, api_handle: &ApiAvailabilityHandle) -> bool { - error.is_network_error() && !api_handle.get_state().is_offline() - } - - fn should_retry(error: &RestError) -> bool { - if let RestError::ApiError(_status, code) = &error { - code != mullvad_rpc::INVALID_ACCOUNT && code != mullvad_rpc::KEY_LIMIT_REACHED - } else { - true - } - } - - /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel. - pub async fn spawn_key_generation_task( - &mut self, - account: AccountToken, - timeout: Option<Duration>, - ) { - self.reset(); - let private_key = PrivateKey::new_from_random(); - - let error_tx = self.daemon_tx.clone(); - let error_account = account.clone(); - - let mut inner_future_generator = - self.push_future_generator(account.clone(), private_key, timeout); - - let availability_handle = self.availability_handle.clone(); - - let future_generator = move || { - let wait_available = availability_handle.wait_background(); - let fut = inner_future_generator(); - let error_tx = error_tx.clone(); - let error_account = error_account.clone(); - async move { - let error_account_copy = error_account.clone(); - wait_available.await.map_err(|error| { - let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( - error_account_copy, - Err(Error::ApiCheckError(error)), - ))); - false - })?; - let response = fut.await; - match response { - Ok(addresses) => Ok(addresses), - Err(err) => { - let should_retry = Self::should_retry(&err); - let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( - error_account, - Err(Self::map_rpc_error(err)), - ))); - Err(should_retry) - } - } - } - }; - - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - - let should_retry = move |result: &std::result::Result<_, bool>| -> bool { - match result { - Ok(_) => false, - Err(should_retry) => *should_retry, - } - }; - - let upload_future = retry_future(future_generator, should_retry, retry_strategy); - - let (cancellable_upload, abort_handle) = abortable(Box::pin(upload_future)); - let daemon_tx = self.daemon_tx.clone(); - let future = async move { - match cancellable_upload.await { - Ok(Ok(wireguard_data)) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account, - Ok(wireguard_data), - ))); - } - Ok(Err(_)) => {} - Err(_) => { - log::error!("Key generation cancelled"); - } - } - }; - - tokio::spawn(Box::pin(future)); - self.current_job = Some(abort_handle); - } - - fn push_future_generator( - &self, - account: AccountToken, - private_key: PrivateKey, - timeout: Option<Duration>, - ) -> Box< - dyn FnMut() -> Pin< - Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>, - > + Send, - > { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - let public_key = private_key.public_key(); - - let push_future = - move || -> std::pin::Pin<Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send >> { - let key = private_key.clone(); - let address_future = rpc - .push_wg_key(account.clone(), public_key.clone(), timeout); - Box::pin(async move { - let addresses = address_future.await?; - Ok(WireguardData { - private_key: key, - addresses, - created: Utc::now(), - }) - }) - }; - Box::new(push_future) - } - - async fn replace_key_rpc( - http_handle: MullvadRestHandle, - account: AccountToken, - old_key: PublicKey, - new_key: PrivateKey, - ) -> Result<WireguardData> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(http_handle); - let new_public_key = new_key.public_key(); - let addresses = rpc - .replace_wg_key(account, old_key.key, new_public_key) - .await - .map_err(Self::map_rpc_error)?; - Ok(WireguardData { - private_key: new_key, - addresses, - created: Utc::now(), - }) - } - - fn map_rpc_error(err: mullvad_rpc::rest::Error) -> Error { - match &err { - // TODO: Consider handling the invalid account case too. - mullvad_rpc::rest::Error::ApiError(status, message) - if *status == mullvad_rpc::StatusCode::BAD_REQUEST - && message == mullvad_rpc::KEY_LIMIT_REACHED => - { - Error::TooManyKeys - } - _ => Error::RestError(err), - } - } - - async fn wait_for_key_expiry(key: &PublicKey, rotation_interval_secs: u64) { - let mut interval = tokio::time::interval(KEY_CHECK_INTERVAL); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - interval.tick().await; - if (Utc::now().signed_duration_since(key.created)).num_seconds() as u64 - >= rotation_interval_secs - { - return; - } - } - } - - async fn create_automatic_rotation( - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - mut public_key: PublicKey, - rotation_interval_secs: u64, - account_token: AccountToken, - ) { - tokio::time::sleep(ROTATION_START_DELAY).await; - - let rotate_key_for_account = - move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> { - let wait_available = availability_handle.wait_background(); - let rotate = Self::rotate_key( - daemon_tx.clone(), - http_handle.clone(), - account_token.clone(), - old_key.clone(), - ); - Box::pin(async move { - wait_available.await?; - rotate.await - }) - }; - - loop { - Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await; - - let rotate_key_for_account_copy = rotate_key_for_account.clone(); - match Self::rotate_key_with_retries(public_key.clone(), rotate_key_for_account_copy) - .await - { - Ok(new_key) => public_key = new_key, - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg( - "Stopping automatic key rotation due to an error" - ) - ); - return; - } - } - } - } - - fn rotate_key( - daemon_tx: DaemonEventSender, - http_handle: MullvadRestHandle, - account_token: AccountToken, - old_key: PublicKey, - ) -> impl Future<Output = Result<PublicKey>> { - let new_key = PrivateKey::new_from_random(); - let rpc_result = - Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key); - - async move { - match rpc_result.await { - Ok(data) => { - // Update account data - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token, - Ok(data.clone()), - ))); - Ok(data.get_public_key()) - } - Err(Error::TooManyKeys) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token, - Err(Error::TooManyKeys), - ))); - Err(Error::TooManyKeys) - } - Err(unknown) => Err(unknown), - } - } - } - - async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey> - where - F: FnMut(&PublicKey) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> - + Clone - + 'static, - { - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - let should_retry = move |result: &Result<PublicKey>| -> bool { - match result { - Ok(_) => false, - Err(error) => match error { - Error::RestError(error) => Self::should_retry(error), - _ => false, - }, - } - }; - - retry_future( - move || rotate_key.clone()(&old_key), - should_retry, - retry_strategy, - ) - .await - } - - async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) { - self.stop_automatic_rotation(); - - log::debug!("Starting automatic key rotation job"); - // Schedule cancellable series of repeating rotation tasks - let fut = Self::create_automatic_rotation( - self.daemon_tx.clone(), - self.availability_handle.clone(), - self.http_handle.clone(), - public_key, - self.auto_rotation_interval.as_duration().as_secs(), - account_token, - ); - let (request, abort_handle) = abortable(Box::pin(fut)); - - tokio::spawn(request); - self.abort_scheduler_tx = Some(abort_handle); - } - - fn stop_automatic_rotation(&mut self) { - if let Some(abort_handle) = self.abort_scheduler_tx.take() { - log::info!("Stopping automatic key rotation"); - abort_handle.abort(); - } - } -} diff --git a/mullvad-jni/src/daemon_interface.rs b/mullvad-jni/src/daemon_interface.rs index 6f47fe21e8..3550b20d32 100644 --- a/mullvad-jni/src/daemon_interface.rs +++ b/mullvad-jni/src/daemon_interface.rs @@ -1,14 +1,15 @@ use futures::{channel::oneshot, executor::block_on}; -use mullvad_daemon::{DaemonCommand, DaemonCommandSender}; +use mullvad_daemon::{device, DaemonCommand, DaemonCommandSender}; use mullvad_types::{ account::{AccountData, AccountToken, VoucherSubmission}, + device::{Device, DeviceConfig}, location::GeoIpLocation, relay_constraints::RelaySettingsUpdate, relay_list::RelayList, settings::{DnsOptions, Settings}, states::{TargetState, TunnelState}, version::AppVersionInfo, - wireguard::{self, KeygenEvent}, + wireguard, }; #[derive(Debug, err_derive::Error)] @@ -37,6 +38,12 @@ impl From<mullvad_daemon::Error> for Error { fn from(error: mullvad_daemon::Error) -> Error { match error { mullvad_daemon::Error::RestError(error) => Error::RpcError(error), + mullvad_daemon::Error::LoginError(device::Error::OtherRestError(error)) => { + Error::RpcError(error) + } + mullvad_daemon::Error::ListDevicesError(device::Error::OtherRestError(error)) => { + Error::RpcError(error) + } error => Error::OtherError(error), } } @@ -79,16 +86,6 @@ impl DaemonInterface { block_on(rx).map(|_| ()).map_err(|_| Error::NoResponse) } - pub fn generate_wireguard_key(&self) -> Result<KeygenEvent> { - let (tx, rx) = oneshot::channel(); - - self.send_command(DaemonCommand::GenerateWireguardKey(tx))?; - - block_on(rx) - .map_err(|_| Error::NoResponse)? - .map_err(Error::from) - } - pub fn get_account_data(&self, account_token: String) -> Result<AccountData> { let (tx, rx) = oneshot::channel(); @@ -195,23 +192,54 @@ impl DaemonInterface { .map_err(Error::from) } - pub fn verify_wireguard_key(&self) -> Result<bool> { + pub fn login_account(&self, account_token: String) -> Result<()> { + let (tx, rx) = oneshot::channel(); + + self.send_command(DaemonCommand::LoginAccount(tx, account_token))?; + + block_on(rx) + .map_err(|_| Error::NoResponse)? + .map_err(Error::from) + } + + pub fn logout_account(&self) -> Result<()> { let (tx, rx) = oneshot::channel(); - self.send_command(DaemonCommand::VerifyWireguardKey(tx))?; + self.send_command(DaemonCommand::LogoutAccount(tx))?; + block_on(rx) .map_err(|_| Error::NoResponse)? .map_err(Error::from) } - pub fn set_account(&self, account_token: Option<String>) -> Result<()> { + pub fn get_device(&self) -> Result<Option<DeviceConfig>> { let (tx, rx) = oneshot::channel(); - self.send_command(DaemonCommand::SetAccount(tx, account_token))?; + self.send_command(DaemonCommand::GetDevice(tx))?; block_on(rx) .map_err(|_| Error::NoResponse)? - .map_err(|_| Error::SettingsError) + .map_err(Error::from) + } + + pub fn list_devices(&self, account_token: String) -> Result<Vec<Device>> { + let (tx, rx) = oneshot::channel(); + + self.send_command(DaemonCommand::ListDevices(tx, account_token))?; + + block_on(rx) + .map_err(|_| Error::NoResponse)? + .map_err(Error::from) + } + + pub fn remove_device(&self, account_token: String, device_id: String) -> Result<()> { + let (tx, rx) = oneshot::channel(); + + self.send_command(DaemonCommand::RemoveDevice(tx, account_token, device_id))?; + + block_on(rx) + .map_err(|_| Error::NoResponse)? + .map_err(Error::from) } pub fn set_allow_lan(&self, allow_lan: bool) -> Result<()> { diff --git a/mullvad-jni/src/jni_event_listener.rs b/mullvad-jni/src/jni_event_listener.rs index 9fd5c3d2ea..553f3f48f3 100644 --- a/mullvad-jni/src/jni_event_listener.rs +++ b/mullvad-jni/src/jni_event_listener.rs @@ -7,8 +7,11 @@ use jnix::{ }; use mullvad_daemon::EventListener; use mullvad_types::{ - relay_list::RelayList, settings::Settings, states::TunnelState, version::AppVersionInfo, - wireguard::KeygenEvent, + device::{DeviceEvent, RemoveDeviceEvent}, + relay_list::RelayList, + settings::Settings, + states::TunnelState, + version::AppVersionInfo, }; use std::{sync::mpsc, thread}; use talpid_types::ErrorExt; @@ -27,11 +30,12 @@ pub enum Error { } enum Event { - KeygenEvent(KeygenEvent), RelayList(RelayList), Settings(Settings), Tunnel(TunnelState), AppVersionInfo(AppVersionInfo), + DeviceEvent(DeviceEvent), + RemoveDeviceEvent(RemoveDeviceEvent), } #[derive(Clone, Debug)] @@ -44,10 +48,6 @@ impl JniEventListener { } impl EventListener for JniEventListener { - fn notify_key_event(&self, key_event: KeygenEvent) { - let _ = self.0.send(Event::KeygenEvent(key_event)); - } - fn notify_new_state(&self, state: TunnelState) { let _ = self.0.send(Event::Tunnel(state)); } @@ -63,16 +63,25 @@ impl EventListener for JniEventListener { fn notify_app_version(&self, app_version_info: AppVersionInfo) { let _ = self.0.send(Event::AppVersionInfo(app_version_info)); } + + fn notify_device_event(&self, event: DeviceEvent) { + let _ = self.0.send(Event::DeviceEvent(event)); + } + + fn notify_remove_device_event(&self, event: RemoveDeviceEvent) { + let _ = self.0.send(Event::RemoveDeviceEvent(event)); + } } struct JniEventHandler<'env> { env: JnixEnv<'env>, mullvad_ipc_client: JObject<'env>, notify_app_version_info_event: JMethodID<'env>, - notify_keygen_event: JMethodID<'env>, notify_relay_list_event: JMethodID<'env>, notify_settings_event: JMethodID<'env>, notify_tunnel_event: JMethodID<'env>, + notify_device_event: JMethodID<'env>, + notify_remove_device_event: JMethodID<'env>, events: mpsc::Receiver<Event>, } @@ -123,12 +132,6 @@ impl<'env> JniEventHandler<'env> { "notifyAppVersionInfoEvent", "(Lnet/mullvad/mullvadvpn/model/AppVersionInfo;)V", )?; - let notify_keygen_event = Self::get_method_id( - &env, - &class, - "notifyKeygenEvent", - "(Lnet/mullvad/mullvadvpn/model/KeygenEvent;)V", - )?; let notify_relay_list_event = Self::get_method_id( &env, &class, @@ -147,15 +150,28 @@ impl<'env> JniEventHandler<'env> { "notifyTunnelStateEvent", "(Lnet/mullvad/mullvadvpn/model/TunnelState;)V", )?; + let notify_device_event = Self::get_method_id( + &env, + &class, + "notifyDeviceEvent", + "(Lnet/mullvad/mullvadvpn/model/DeviceEvent;)V", + )?; + let notify_remove_device_event = Self::get_method_id( + &env, + &class, + "notifyRemoveDeviceEvent", + "(Lnet/mullvad/mullvadvpn/model/RemoveDeviceEvent;)V", + )?; Ok(JniEventHandler { env, mullvad_ipc_client, notify_app_version_info_event, - notify_keygen_event, notify_relay_list_event, notify_settings_event, notify_tunnel_event, + notify_device_event, + notify_remove_device_event, events, }) } @@ -173,31 +189,53 @@ impl<'env> JniEventHandler<'env> { fn run(&mut self) { while let Ok(event) = self.events.recv() { match event { - Event::KeygenEvent(keygen_event) => self.handle_keygen_event(keygen_event), Event::RelayList(relay_list) => self.handle_relay_list_event(relay_list), Event::Settings(settings) => self.handle_settings(settings), Event::Tunnel(tunnel_event) => self.handle_tunnel_event(tunnel_event), Event::AppVersionInfo(app_version_info) => { self.handle_app_version_info_event(app_version_info) } + Event::DeviceEvent(device_event) => self.handle_device_event(device_event), + Event::RemoveDeviceEvent(device_event) => { + self.handle_remove_device_event(device_event) + } } } } - fn handle_keygen_event(&self, event: KeygenEvent) { - let java_keygen_event = event.into_java(&self.env); + fn handle_device_event(&self, device_event: DeviceEvent) { + let java_event = device_event.into_java(&self.env); + + let result = self.env.call_method_unchecked( + self.mullvad_ipc_client, + self.notify_device_event, + JavaType::Primitive(Primitive::Void), + &[JValue::Object(java_event.as_obj())], + ); + + if let Err(error) = result { + log::error!( + "{}", + error.display_chain_with_msg("Failed to call MullvadDaemon.notifyDeviceEvent") + ); + } + } + + fn handle_remove_device_event(&self, remove_event: RemoveDeviceEvent) { + let java_event = remove_event.into_java(&self.env); let result = self.env.call_method_unchecked( self.mullvad_ipc_client, - self.notify_keygen_event, + self.notify_remove_device_event, JavaType::Primitive(Primitive::Void), - &[JValue::Object(java_keygen_event.as_obj())], + &[JValue::Object(java_event.as_obj())], ); if let Err(error) = result { log::error!( "{}", - error.display_chain_with_msg("Failed to call MullvadDaemon.notifyKeygenEvent") + error + .display_chain_with_msg("Failed to call MullvadDaemon.notifyRemoveDeviceEvent") ); } } diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 646988e11b..c98c132e60 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -19,7 +19,8 @@ use jnix::{ FromJava, IntoJava, JnixEnv, }; use mullvad_daemon::{ - exception_logging, logging, runtime::new_runtime_builder, version, Daemon, DaemonCommandChannel, + device, exception_logging, logging, runtime::new_runtime_builder, version, Daemon, + DaemonCommandChannel, }; use mullvad_rpc::{rest::Error as RestError, StatusCode}; use mullvad_types::{ @@ -92,6 +93,65 @@ impl From<Result<AccountData, daemon_interface::Error>> for GetAccountDataResult #[derive(IntoJava)] #[jnix(package = "net.mullvad.mullvadvpn.model")] +pub enum LoginResult { + Ok, + InvalidAccount, + MaxDevicesReached, + RpcError, + OtherError, +} + +impl From<Result<(), daemon_interface::Error>> for LoginResult { + fn from(result: Result<(), daemon_interface::Error>) -> Self { + match result { + Ok(()) => LoginResult::Ok, + Err(error) => match error { + daemon_interface::Error::OtherError(mullvad_daemon::Error::LoginError(error)) => { + match error { + device::Error::InvalidAccount => LoginResult::InvalidAccount, + device::Error::MaxDevicesReached => LoginResult::MaxDevicesReached, + device::Error::OtherRestError(_) => LoginResult::RpcError, + _ => LoginResult::OtherError, + } + } + daemon_interface::Error::RpcError(_) => LoginResult::RpcError, + _ => LoginResult::OtherError, + }, + } + } +} + +#[derive(IntoJava)] +#[jnix(package = "net.mullvad.mullvadvpn.model")] +pub enum RemoveDeviceResult { + Ok, + NotFound, + RpcError, + OtherError, +} + +impl From<Result<(), daemon_interface::Error>> for RemoveDeviceResult { + fn from(result: Result<(), daemon_interface::Error>) -> Self { + match result { + Ok(()) => RemoveDeviceResult::Ok, + Err(error) => match error { + daemon_interface::Error::OtherError(mullvad_daemon::Error::LoginError(error)) => { + match error { + device::Error::InvalidAccount => RemoveDeviceResult::RpcError, + device::Error::InvalidDevice => RemoveDeviceResult::NotFound, + device::Error::OtherRestError(_) => RemoveDeviceResult::RpcError, + _ => RemoveDeviceResult::OtherError, + } + } + daemon_interface::Error::RpcError(_) => RemoveDeviceResult::RpcError, + _ => RemoveDeviceResult::OtherError, + }, + } + } +} + +#[derive(IntoJava)] +#[jnix(package = "net.mullvad.mullvadvpn.model")] pub enum VoucherSubmissionResult { Ok(VoucherSubmission), Error(VoucherSubmissionError), @@ -439,66 +499,6 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_disconn #[no_mangle] #[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_generateWireguardKey< - 'env, ->( - env: JNIEnv<'env>, - _: JObject<'_>, - daemon_interface_address: jlong, -) -> JObject<'env> { - let env = JnixEnv::from(env); - - if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { - match daemon_interface.generate_wireguard_key() { - Ok(keygen_event) => keygen_event.into_java(&env).forget(), - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to request to generate wireguard key") - ); - JObject::null() - } - } - } else { - JObject::null() - } -} - -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_verifyWireguardKey< - 'env, ->( - env: JNIEnv<'env>, - _: JObject<'_>, - daemon_interface_address: jlong, -) -> JObject<'env> { - let env = JnixEnv::from(env); - - if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { - match daemon_interface.verify_wireguard_key() { - Ok(key_is_valid) => env - .new_object( - &env.get_class("java/lang/Boolean"), - "(Z)V", - &[JValue::Bool(key_is_valid as jboolean)], - ) - .expect("Failed to create Boolean Java object"), - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to verify wireguard key") - ); - JObject::null() - } - } - } else { - JObject::null() - } -} - -#[no_mangle] -#[allow(non_snake_case)] pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_getAccountHistory<'env>( env: JNIEnv<'env>, _: JObject<'_>, @@ -768,25 +768,118 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_clearAc #[no_mangle] #[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_setAccount( - env: JNIEnv<'_>, +pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_loginAccount<'env>( + env: JNIEnv<'env>, _: JObject<'_>, daemon_interface_address: jlong, accountToken: JString<'_>, +) -> JObject<'env> { + let env = JnixEnv::from(env); + + if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { + let account = String::from_java(&env, accountToken); + let result = daemon_interface.login_account(account); + + if let Err(ref error) = &result { + log_request_error("login account", error); + } + + LoginResult::from(result).into_java(&env).forget() + } else { + LoginResult::OtherError.into_java(&env).forget() + } +} + +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_logoutAccount( + _: JNIEnv<'_>, + _: JObject<'_>, + daemon_interface_address: jlong, ) { + if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { + if let Err(error) = daemon_interface.logout_account() { + log::error!("{}", error.display_chain_with_msg("Failed to log out")); + } + } +} + +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_getDevice<'env>( + env: JNIEnv<'env>, + _: JObject<'_>, + daemon_interface_address: jlong, +) -> JObject<'env> { let env = JnixEnv::from(env); if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { - let account = Option::from_java(&env, accountToken); + match daemon_interface.get_device() { + Ok(key) => key.into_java(&env).forget(), + Err(error) => { + log::error!("{}", error.display_chain_with_msg("Failed to get device")); + JObject::null() + } + } + } else { + JObject::null() + } +} - if let Err(error) = daemon_interface.set_account(account) { - log::error!("{}", error.display_chain_with_msg("Failed to set account")); +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_listDevices<'env>( + env: JNIEnv<'env>, + _: JObject<'_>, + daemon_interface_address: jlong, + account_token: JString<'_>, +) -> JObject<'env> { + let env = JnixEnv::from(env); + + if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { + let token = String::from_java(&env, account_token); + match daemon_interface.list_devices(token) { + Ok(key) => key.into_java(&env).forget(), + Err(error) => { + log::error!("{}", error.display_chain_with_msg("Failed to list devices")); + JObject::null() + } } + } else { + JObject::null() } } #[no_mangle] #[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_removeDevice<'env>( + env: JNIEnv<'env>, + _: JObject<'_>, + daemon_interface_address: jlong, + account_token: JString<'_>, + device_id: JString<'_>, +) -> JObject<'env> { + let env = JnixEnv::from(env); + + let result = if let Some(daemon_interface) = get_daemon_interface(daemon_interface_address) { + let token = String::from_java(&env, account_token); + let device_id = String::from_java(&env, device_id); + let raw_result = daemon_interface.remove_device(token, device_id); + + if let Err(ref error) = &raw_result { + log_request_error("remove device", error); + } + + RemoveDeviceResult::from(raw_result) + } else { + RemoveDeviceResult::OtherError + }; + + result.into_java(&env).forget() +} + +#[no_mangle] +#[allow(non_snake_case)] pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_setAllowLan( env: JNIEnv<'_>, _: JObject<'_>, diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index e690557aae..21eb6ab512 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -44,19 +44,24 @@ service ManagementService { // Account management rpc CreateNewAccount(google.protobuf.Empty) returns (google.protobuf.StringValue) {} - rpc SetAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc LoginAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc LogoutAccount(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetAccountData(google.protobuf.StringValue) returns (AccountData) {} rpc GetAccountHistory(google.protobuf.Empty) returns (AccountHistory) {} rpc ClearAccountHistory(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetWwwAuthToken(google.protobuf.Empty) returns (google.protobuf.StringValue) {} rpc SubmitVoucher(google.protobuf.StringValue) returns (VoucherSubmission) {} + // Device management + rpc GetDevice(google.protobuf.Empty) returns (DeviceConfig) {} + rpc ListDevices(google.protobuf.StringValue) returns (DeviceList) {} + rpc RemoveDevice(DeviceRemoval) returns (google.protobuf.Empty) {} + // WireGuard key management rpc SetWireguardRotationInterval(google.protobuf.Duration) returns (google.protobuf.Empty) {} rpc ResetWireguardRotationInterval(google.protobuf.Empty) returns (google.protobuf.Empty) {} - rpc GenerateWireguardKey(google.protobuf.Empty) returns (KeygenEvent) {} + rpc RotateWireguardKey(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {} - rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {} // Split tunneling (Linux) rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {} @@ -265,16 +270,15 @@ message BridgeState { } message Settings { - string account_token = 1; - RelaySettings relay_settings = 2; - BridgeSettings bridge_settings = 3; - BridgeState bridge_state = 4; - bool allow_lan = 5; - bool block_when_disconnected = 6; - bool auto_connect = 7; - TunnelOptions tunnel_options = 8; - bool show_beta_releases = 9; - SplitTunnelSettings split_tunnel = 10; + RelaySettings relay_settings = 1; + BridgeSettings bridge_settings = 2; + BridgeState bridge_state = 3; + bool allow_lan = 4; + bool block_when_disconnected = 5; + bool auto_connect = 6; + TunnelOptions tunnel_options = 7; + bool show_beta_releases = 8; + SplitTunnelSettings split_tunnel = 9; } message SplitTunnelSettings { @@ -423,16 +427,6 @@ message PublicKey { google.protobuf.Timestamp created = 2; } -message KeygenEvent { - enum KeygenEvent { - NEW_KEY = 0; - TOO_MANY_KEYS = 1; - GENERATION_FAILURE = 2; - } - KeygenEvent event = 1; - PublicKey new_key = 2; -} - message AppVersionInfo { bool supported = 1; string latest_stable = 2; @@ -521,10 +515,47 @@ message DaemonEvent { Settings settings = 2; RelayList relay_list = 3; AppVersionInfo version_info = 4; - KeygenEvent key_event = 5; + DeviceEvent device = 5; + RemoveDeviceEvent remove_device = 6; } } message RelayList { repeated RelayListCountry countries = 1; } + +message DeviceConfig { + string account_token = 1; + Device device = 2; +} + +message Device { + string id = 1; + string name = 2; + bytes pubkey = 3; + repeated DevicePort ports = 4; +} + +message DevicePort { + string id = 1; +} + +message DeviceList { + repeated Device devices = 1; +} + +message DeviceRemoval { + string account_token = 1; + string device_id = 2; +} + +message DeviceEvent { + DeviceConfig device = 1; + bool remote = 2; +} + +message RemoveDeviceEvent { + string account_token = 1; + Device removed_device = 2; + repeated Device new_device_list = 3; +} diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 5398927569..c76ada98d1 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -2,7 +2,7 @@ pub use prost_types::{Duration, Timestamp}; use mullvad_types::relay_constraints::Constraint; use std::convert::TryFrom; -use talpid_types::ErrorExt; +use talpid_types::{net::wireguard, ErrorExt}; tonic::include_proto!("mullvad_daemon.management_interface"); @@ -197,22 +197,58 @@ impl From<mullvad_types::states::TunnelState> for TunnelState { } } -impl From<mullvad_types::wireguard::KeygenEvent> for KeygenEvent { - fn from(event: mullvad_types::wireguard::KeygenEvent) -> Self { - use keygen_event::KeygenEvent as Event; - use mullvad_types::wireguard::KeygenEvent as MullvadEvent; +impl From<mullvad_types::device::Device> for Device { + fn from(device: mullvad_types::device::Device) -> Self { + Device { + id: device.id, + name: device.name, + pubkey: device.pubkey.as_bytes().to_vec(), + ports: device.ports.into_iter().map(DevicePort::from).collect(), + } + } +} - KeygenEvent { - event: match event { - MullvadEvent::NewKey(_) => i32::from(Event::NewKey), - MullvadEvent::TooManyKeys => i32::from(Event::TooManyKeys), - MullvadEvent::GenerationFailure => i32::from(Event::GenerationFailure), - }, - new_key: if let MullvadEvent::NewKey(key) = event { - Some(PublicKey::from(key)) - } else { - None - }, +impl From<mullvad_types::device::DevicePort> for DevicePort { + fn from(port: mullvad_types::device::DevicePort) -> Self { + DevicePort { id: port.id } + } +} + +impl From<mullvad_types::device::DeviceEvent> for DeviceEvent { + fn from(event: mullvad_types::device::DeviceEvent) -> Self { + DeviceEvent { + device: event.device.map(|config| DeviceConfig { + account_token: config.token, + device: Some(Device::from(config.device)), + }), + remote: event.remote, + } + } +} + +impl From<mullvad_types::device::RemoveDeviceEvent> for RemoveDeviceEvent { + fn from(event: mullvad_types::device::RemoveDeviceEvent) -> Self { + RemoveDeviceEvent { + account_token: event.account_token, + removed_device: Some(Device::from(event.removed_device)), + new_device_list: event.new_devices.into_iter().map(Device::from).collect(), + } + } +} + +impl From<mullvad_types::device::DeviceConfig> for DeviceConfig { + fn from(device: mullvad_types::device::DeviceConfig) -> Self { + DeviceConfig { + account_token: device.token, + device: Some(Device::from(device.device)), + } + } +} + +impl From<Vec<mullvad_types::device::Device>> for DeviceList { + fn from(devices: Vec<mullvad_types::device::Device>) -> Self { + DeviceList { + devices: devices.into_iter().map(Device::from).collect(), } } } @@ -387,7 +423,6 @@ impl From<&mullvad_types::settings::Settings> for Settings { let split_tunnel = None; Self { - account_token: settings.get_account_token().unwrap_or_default(), relay_settings: Some(RelaySettings::from(settings.get_relay_settings())), bridge_settings: Some(BridgeSettings::from(settings.bridge_settings.clone())), bridge_state: Some(BridgeState::from(settings.get_bridge_state())), @@ -689,6 +724,29 @@ pub enum FromProtobufTypeError { InvalidArgument(&'static str), } +impl TryFrom<Device> for mullvad_types::device::Device { + type Error = FromProtobufTypeError; + + fn try_from(device: Device) -> Result<Self, Self::Error> { + Ok(mullvad_types::device::Device { + id: device.id, + name: device.name, + pubkey: bytes_to_pubkey(&device.pubkey)?, + ports: device + .ports + .into_iter() + .map(mullvad_types::device::DevicePort::from) + .collect(), + }) + } +} + +impl From<DevicePort> for mullvad_types::device::DevicePort { + fn from(port: DevicePort) -> Self { + mullvad_types::device::DevicePort { id: port.id } + } +} + impl TryFrom<&WireguardConstraints> for mullvad_types::relay_constraints::WireguardConstraints { type Error = FromProtobufTypeError; @@ -929,7 +987,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { type Error = FromProtobufTypeError; fn try_from(config: ConnectionConfig) -> Result<mullvad_types::ConnectionConfig, Self::Error> { - use talpid_types::net::{self, openvpn, wireguard}; + use talpid_types::net::{self, openvpn}; let config = config.config.ok_or(FromProtobufTypeError::InvalidArgument( "missing connection config", @@ -974,14 +1032,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { "missing peer config", ))?; - // Copy the public key to an array - if peer.public_key.len() != 32 { - return Err(FromProtobufTypeError::InvalidArgument("invalid public key")); - } - - let mut public_key = [0; 32]; - let buffer = &peer.public_key[..public_key.len()]; - public_key.copy_from_slice(buffer); + let public_key = bytes_to_pubkey(&peer.public_key)?; let ipv4_gateway = match config.ipv4_gateway.parse() { Ok(address) => address, @@ -1037,7 +1088,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { addresses: tunnel_addresses, }, peer: wireguard::PeerConfig { - public_key: wireguard::PublicKey::from(public_key), + public_key, allowed_ips, endpoint, protocol: try_transport_protocol_from_i32(peer.protocol)?, @@ -1052,6 +1103,15 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { } } +fn bytes_to_pubkey(bytes: &[u8]) -> Result<wireguard::PublicKey, FromProtobufTypeError> { + if bytes.len() != 32 { + return Err(FromProtobufTypeError::InvalidArgument("invalid public key")); + } + let mut public_key = [0; 32]; + public_key.copy_from_slice(&bytes[..32]); + Ok(wireguard::PublicKey::from(public_key)) +} + impl From<RelayLocation> for Constraint<mullvad_types::relay_constraints::LocationConstraint> { fn from(location: RelayLocation) -> Self { use mullvad_types::relay_constraints::LocationConstraint; diff --git a/mullvad-rpc/src/access.rs b/mullvad-rpc/src/access.rs new file mode 100644 index 0000000000..d95a5319c2 --- /dev/null +++ b/mullvad-rpc/src/access.rs @@ -0,0 +1,110 @@ +use crate::{ + rest, + rest::{RequestFactory, RequestServiceHandle}, +}; +use hyper::StatusCode; +use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; +use talpid_types::ErrorExt; + +pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1"; + +#[derive(Clone)] +pub struct AccessTokenProxy { + service: RequestServiceHandle, + factory: RequestFactory, + access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>, +} + +impl AccessTokenProxy { + pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { + Self { + service, + factory, + access_from_account: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Obtain access token for an account, requesting a new one from the API if necessary. + pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> { + let existing_token = { + self.access_from_account + .lock() + .unwrap() + .get(account.as_str()) + .cloned() + }; + if let Some(access_token) = existing_token { + if access_token.is_expired() { + log::debug!("Replacing expired access token"); + return self.request_new_token(account.clone()).await; + } + log::trace!("Using stored access token"); + return Ok(access_token.access_token.clone()); + } + self.request_new_token(account.clone()).await + } + + /// Remove an access token if the API response calls for it. + pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) { + if let Err(rest::Error::ApiError(_status, code)) = response { + if code == crate::INVALID_ACCESS_TOKEN { + log::debug!("Dropping invalid access token"); + self.remove_token(account); + } + } + } + + /// Removes a stored access token. + fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> { + self.access_from_account + .lock() + .unwrap() + .remove(account) + .map(|v| v.access_token) + } + + async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> { + log::debug!("Fetching access token for an account"); + let access_token = self + .fetch_access_token(account.clone()) + .await + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain access token") + ); + error + })?; + self.access_from_account + .lock() + .unwrap() + .insert(account, access_token.clone()); + Ok(access_token.access_token) + } + + async fn fetch_access_token( + &self, + account_token: AccountToken, + ) -> Result<AccessTokenData, rest::Error> { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_number: String, + } + let request = AccessTokenRequest { + account_number: account_token, + }; + + let service = self.service.clone(); + + let rest_request = self + .factory + .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await + } +} diff --git a/mullvad-rpc/src/availability.rs b/mullvad-rpc/src/availability.rs index da8d624e80..2cf40cf53b 100644 --- a/mullvad-rpc/src/availability.rs +++ b/mullvad-rpc/src/availability.rs @@ -122,10 +122,26 @@ impl ApiAvailabilityHandle { self.wait_for_state(|state| !state.is_suspended()) } + pub fn when_bg_resumes<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> { + let wait_task = self.wait_for_state(|state| !state.is_background_paused()); + async move { + let _ = wait_task.await; + task.await + } + } + pub fn wait_background(&self) -> impl Future<Output = Result<(), Error>> { self.wait_for_state(|state| !state.is_background_paused()) } + pub fn when_online<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> { + let wait_task = self.wait_for_state(|state| !state.is_offline()); + async move { + let _ = wait_task.await; + task.await + } + } + pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> { self.wait_for_state(|state| !state.is_offline()) } diff --git a/mullvad-rpc/src/device.rs b/mullvad-rpc/src/device.rs new file mode 100644 index 0000000000..de572aa20d --- /dev/null +++ b/mullvad-rpc/src/device.rs @@ -0,0 +1,196 @@ +use http::{Method, StatusCode}; +use mullvad_types::{ + account::AccountToken, + device::{Device, DeviceId, DeviceName, DevicePort}, +}; +use std::future::Future; +use talpid_types::net::wireguard; + +use crate::rest; + +use super::ACCOUNTS_URL_PREFIX; + +#[derive(Clone)] +pub struct DevicesProxy { + handle: rest::MullvadRestHandle, +} + +#[derive(serde::Deserialize)] +struct DeviceResponse { + id: DeviceId, + name: DeviceName, + pubkey: wireguard::PublicKey, + ipv4_address: ipnetwork::Ipv4Network, + ipv6_address: ipnetwork::Ipv6Network, + ports: Vec<DevicePort>, +} + +impl DevicesProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn create( + &self, + account: AccountToken, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>> + { + #[derive(serde::Serialize)] + struct DeviceSubmission { + pubkey: wireguard::PublicKey, + } + + let submission = DeviceSubmission { pubkey }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account)), + &[StatusCode::CREATED], + ) + .await; + + let response: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + id, + name, + pubkey, + ipv4_address, + ipv6_address, + ports, + .. + } = response; + + Ok(( + Device { + id, + name, + pubkey, + ports, + }, + mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }, + )) + } + } + + pub fn get( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<Device, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn list( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn remove( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<(), rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::DELETE, + Some((access_proxy, account)), + &[StatusCode::NO_CONTENT], + ) + .await; + + response?; + Ok(()) + } + } + + pub fn replace_wg_key( + &self, + account: AccountToken, + id: DeviceId, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + { + #[derive(serde::Serialize)] + struct RotateDevicePubkey { + pubkey: wireguard::PublicKey, + } + let req_body = RotateDevicePubkey { pubkey }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id), + Method::PUT, + &req_body, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + ipv4_address, + ipv6_address, + .. + } = updated_device; + Ok(mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }) + } + } +} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 614aa3bdb6..f93d27262a 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -16,7 +16,7 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, }; -use talpid_types::{net::wireguard, ErrorExt}; +use talpid_types::ErrorExt; pub mod availability; use availability::{ApiAvailability, ApiAvailabilityHandle}; @@ -29,9 +29,12 @@ mod tls_stream; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; +mod access; mod address_cache; +pub mod device; mod relay_list; pub use address_cache::AddressCache; +pub use device::DevicesProxy; pub use hyper::StatusCode; pub use relay_list::RelayListProxy; @@ -44,11 +47,17 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER"; /// Error code returned by the Mullvad API if the account token is invalid. pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT"; -/// Error code returned by the Mullvad API if the account token is missing or invalid. -pub const INVALID_AUTH: &str = "INVALID_AUTH"; +/// Error code returned by the Mullvad API if the access token is invalid. +pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN"; + +pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED"; +pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE"; pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt"; +const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1"; +const APP_URL_PREFIX: &str = "app/v1"; + lazy_static::lazy_static! { static ref API: ApiEndpoint = ApiEndpoint::get(); } @@ -257,7 +266,7 @@ impl MullvadRpcRuntime { self.socket_bypass_tx.clone(), ) .await; - let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned())); + let factory = rest::RequestFactory::new(API.host.clone(), None); rest::MullvadRestHandle::new( service, @@ -295,8 +304,8 @@ pub struct AccountsProxy { #[derive(serde::Deserialize)] struct AccountResponse { - token: AccountToken, - expires: DateTime<Utc>, + number: AccountToken, + expiry: DateTime<Utc>, } impl AccountsProxy { @@ -309,18 +318,21 @@ impl AccountsProxy { account: AccountToken, ) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> { let service = self.handle.service.clone(); - - let response = rest::send_request( - &self.handle.factory, - service, - "/v1/me", - Method::GET, - Some(account), - &[StatusCode::OK], - ); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); async move { - let account: AccountResponse = rest::deserialize_body(response.await?).await?; - Ok(account.expires) + let response = rest::send_request( + &factory, + service, + &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let account: AccountResponse = rest::deserialize_body(response?).await?; + Ok(account.expiry) } } @@ -329,7 +341,7 @@ impl AccountsProxy { let response = rest::send_request( &self.handle.factory, service, - "/v1/accounts", + &format!("{}/accounts", ACCOUNTS_URL_PREFIX), Method::POST, None, &[StatusCode::CREATED], @@ -337,7 +349,7 @@ impl AccountsProxy { async move { let account: AccountResponse = rest::deserialize_body(response.await?).await?; - Ok(account.token) + Ok(account.number) } } @@ -352,18 +364,23 @@ impl AccountsProxy { } let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); let submission = VoucherSubmission { voucher_code }; - let response = rest::post_request_with_json( - &self.handle.factory, - service, - "/v1/submit-voucher", - &submission, - Some(account_token), - &[StatusCode::OK], - ); - - async move { rest::deserialize_body(response.await?).await } + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/submit-voucher", APP_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account_token)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } } pub fn get_www_auth_token( @@ -376,17 +393,20 @@ impl AccountsProxy { } let service = self.handle.service.clone(); - let response = rest::send_request( - &self.handle.factory, - service, - "/v1/www-auth-token", - Method::POST, - Some(account), - &[StatusCode::OK], - ); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); async move { - let response: AuthTokenResponse = rest::deserialize_body(response.await?).await?; + let response = rest::send_request( + &factory, + service, + &format!("{}/www-auth-token", APP_URL_PREFIX), + Method::POST, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + let response: AuthTokenResponse = rest::deserialize_body(response?).await?; Ok(response.auth_token) } } @@ -425,10 +445,11 @@ impl ProblemReportProxy { let service = self.handle.service.clone(); - let request = rest::post_request_with_json( + let request = rest::send_json_request( &self.handle.factory, service, - "/v1/problem-report", + &format!("{}/problem-report", APP_URL_PREFIX), + Method::POST, &report, None, &[StatusCode::NO_CONTENT], @@ -467,7 +488,7 @@ impl AppVersionProxy { ) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> { let service = self.handle.service.clone(); - let path = format!("/v1/releases/{}/{}", platform, app_version); + let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version); let request = self.handle.factory.request(&path, Method::GET); async move { @@ -481,123 +502,6 @@ impl AppVersionProxy { } } -/// Error code for when an account has too many keys. Returned when trying to push a new key. -pub const KEY_LIMIT_REACHED: &str = "KEY_LIMIT_REACHED"; -#[derive(Clone)] -pub struct WireguardKeyProxy { - handle: rest::MullvadRestHandle, -} - -impl WireguardKeyProxy { - pub fn new(handle: rest::MullvadRestHandle) -> Self { - Self { handle } - } - - pub fn push_wg_key( - &mut self, - account_token: AccountToken, - public_key: wireguard::PublicKey, - timeout: Option<std::time::Duration>, - ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + 'static - { - #[derive(serde::Serialize)] - struct PublishRequest { - pubkey: wireguard::PublicKey, - } - - let service = self.handle.service.clone(); - let body = PublishRequest { pubkey: public_key }; - - let request = self.handle.factory.post_json(&"/v1/wireguard-keys", &body); - async move { - let mut request = request?; - if let Some(timeout) = timeout { - request.set_timeout(timeout); - } - request.set_auth(Some(account_token))?; - let response = service.request(request).await?; - rest::deserialize_body( - rest::parse_rest_response(response, &[StatusCode::CREATED]).await?, - ) - .await - } - } - - pub async fn replace_wg_key( - &mut self, - account_token: AccountToken, - old: wireguard::PublicKey, - new: wireguard::PublicKey, - ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> { - #[derive(serde::Serialize)] - struct ReplacementRequest { - old: wireguard::PublicKey, - new: wireguard::PublicKey, - } - - let service = self.handle.service.clone(); - let body = ReplacementRequest { old, new }; - - let response = rest::post_request_with_json( - &self.handle.factory, - service, - &"/v1/replace-wireguard-key", - &body, - Some(account_token), - [StatusCode::CREATED, StatusCode::OK].as_slice(), - ) - .await?; - - rest::deserialize_body(response).await - } - - pub async fn get_wireguard_key( - &mut self, - account_token: AccountToken, - key: &wireguard::PublicKey, - ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> { - let service = self.handle.service.clone(); - - let response = rest::send_request( - &self.handle.factory, - service, - &format!( - "/v1/wireguard-keys/{}", - urlencoding::encode(&key.to_base64()) - ), - Method::GET, - Some(account_token), - &[StatusCode::OK], - ) - .await?; - - rest::deserialize_body(response).await - } - - pub fn remove_wireguard_key( - &mut self, - account_token: AccountToken, - key: wireguard::PublicKey, - ) -> impl Future<Output = Result<(), rest::Error>> { - let service = self.handle.service.clone(); - let future = rest::send_request( - &self.handle.factory, - service, - &format!( - "/v1/wireguard-keys/{}", - urlencoding::encode(&key.to_base64()) - ), - Method::DELETE, - Some(account_token), - &[StatusCode::NO_CONTENT], - ); - async move { - let _ = future.await?; - Ok(()) - } - } -} - #[derive(Clone)] pub struct ApiProxy { handle: rest::MullvadRestHandle, @@ -614,7 +518,7 @@ impl ApiProxy { let response = rest::send_request( &self.handle.factory, service, - "/v1/api-addrs", + &format!("{}/api-addrs", APP_URL_PREFIX), Method::GET, None, &[StatusCode::OK], diff --git a/mullvad-rpc/src/relay_list.rs b/mullvad-rpc/src/relay_list.rs index f1ed2217fd..5a8a01836f 100644 --- a/mullvad-rpc/src/relay_list.rs +++ b/mullvad-rpc/src/relay_list.rs @@ -13,7 +13,7 @@ use std::{ time::Duration, }; -/// Fetches relay list from <https://api.mullvad.net/v1/relays> +/// Fetches relay list from https://api.mullvad.net/app/v1/relays #[derive(Clone)] pub struct RelayListProxy { handle: rest::MullvadRestHandle, @@ -33,7 +33,7 @@ impl RelayListProxy { etag: Option<String>, ) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> { let service = self.handle.service.clone(); - let request = self.handle.factory.request("/v1/relays", Method::GET); + let request = self.handle.factory.request("app/v1/relays", Method::GET); let future = async move { let mut request = request?; diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 17362cce05..6f36a2a096 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,6 +1,7 @@ #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ + access::AccessTokenProxy, address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, @@ -17,6 +18,7 @@ use hyper::{ header::{self, HeaderValue}, Method, Uri, }; +use mullvad_types::account::AccountToken; use std::{ future::Future, str::FromStr, @@ -29,6 +31,8 @@ pub use hyper::StatusCode; pub type Request = hyper::Request<hyper::Body>; pub type Response = hyper::Response<hyper::Body>; +const USER_AGENT: &str = "mullvad-app"; + const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60); const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60); const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); @@ -285,6 +289,7 @@ impl RestRequest { let mut builder = http::request::Builder::new() .method(Method::GET) + .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) .header(header::ACCEPT, HeaderValue::from_static("application/json")); if let Some(host) = uri.host() { builder = builder.header(header::HOST, HeaderValue::from_str(&host)?); @@ -302,11 +307,11 @@ impl RestRequest { }) } - /// Set the auth header with the following format: `Token $auth`. + /// Set the auth header with the following format: `Bearer $auth`. pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> { let header = match auth { Some(auth) => Some( - HeaderValue::from_str(&format!("Token {}", auth)) + HeaderValue::from_str(&format!("Bearer {}", auth)) .map_err(Error::InvalidHeaderError)?, ), None => None, @@ -399,7 +404,16 @@ impl RequestFactory { } pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> { - let mut request = self.hyper_request(path, Method::POST)?; + self.json_request(Method::POST, path, body) + } + + fn json_request<S: serde::Serialize>( + &self, + method: Method, + path: &str, + body: &S, + ) -> Result<RestRequest> { + let mut request = self.hyper_request(path, method)?; let json_body = serde_json::to_string(&body)?; let body_length = json_body.as_bytes().len() as u64; @@ -429,6 +443,7 @@ impl RequestFactory { let request = http::request::Builder::new() .method(method) .uri(uri) + .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) .header(header::ACCEPT, HeaderValue::from_static("application/json")) .header(header::HOST, self.hostname.clone()); @@ -468,44 +483,64 @@ pub fn send_request( service: RequestServiceHandle, uri: &str, method: Method, - auth: Option<String>, + auth: Option<(AccessTokenProxy, AccountToken)>, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future<Output = Result<Response>> { let request = factory.request(uri, method); async move { let mut request = request?; - request.set_auth(auth)?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result } } -pub fn post_request_with_json<B: serde::Serialize>( +pub fn send_json_request<B: serde::Serialize>( factory: &RequestFactory, service: RequestServiceHandle, uri: &str, + method: Method, body: &B, - auth: Option<String>, + auth: Option<(AccessTokenProxy, AccountToken)>, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future<Output = Result<Response>> { - let request = factory.post_json(uri, body); + let request = factory.json_request(method, uri, body); async move { let mut request = request?; - request.set_auth(auth)?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result } } -pub async fn deserialize_body<T: serde::de::DeserializeOwned>(mut response: Response) -> Result<T> { - let body_length: usize = response - .headers() - .get(header::CONTENT_LENGTH) - .and_then(|header_value| header_value.to_str().ok()) - .and_then(|length| length.parse::<usize>().ok()) - .unwrap_or(0); +pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> { + let body_length = get_body_length(&response); + deserialize_body_inner(response, body_length).await +} +async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( + mut response: Response, + body_length: usize, +) -> Result<T> { let mut body: Vec<u8> = Vec::with_capacity(body_length); while let Some(chunk) = response.body_mut().next().await { body.extend(&chunk?); @@ -514,6 +549,15 @@ pub async fn deserialize_body<T: serde::de::DeserializeOwned>(mut response: Resp serde_json::from_slice(&body).map_err(Error::DeserializeError) } +fn get_body_length(response: &Response) -> usize { + response + .headers() + .get(header::CONTENT_LENGTH) + .and_then(|header_value| header_value.to_str().ok()) + .and_then(|length| length.parse::<usize>().ok()) + .unwrap_or(0) +} + pub async fn parse_rest_response( response: Response, expected_statuses: &'static [hyper::StatusCode], @@ -537,23 +581,27 @@ pub async fn parse_rest_response( } pub async fn handle_error_response<T>(response: Response) -> Result<T> { - let error_message = match response.status() { + let status = response.status(); + let error_message = match status { hyper::StatusCode::NOT_FOUND => "Not found", hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed", - status => { - let err: ErrorResponse = deserialize_body(response).await?; - - return Err(Error::ApiError(status, err.code)); - } + status => match get_body_length(&response) { + 0 => status.canonical_reason().unwrap_or("Unexpected error"), + body_length => { + let err: ErrorResponse = deserialize_body_inner(response, body_length).await?; + return Err(Error::ApiError(status, err.code)); + } + }, }; - Err(Error::ApiError(response.status(), error_message.to_owned())) + Err(Error::ApiError(status, error_message.to_owned())) } #[derive(Clone)] pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, - availability: ApiAvailabilityHandle, + pub availability: ApiAvailabilityHandle, + pub token_store: AccessTokenProxy, } impl MullvadRestHandle { @@ -563,10 +611,13 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { + let token_store = AccessTokenProxy::new(service.clone(), factory.clone()); + let handle = Self { service, factory, availability, + token_store, }; if !super::API.disable_address_cache { handle.spawn_api_address_fetcher(address_cache); diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index e65b1278f8..e9289b115a 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -17,7 +17,7 @@ lazy_static::lazy_static! { } const KEY_RETRY_INTERVAL: Duration = Duration::ZERO; -const KEY_RETRY_MAX_RETRIES: usize = 2; +const KEY_RETRY_MAX_RETRIES: usize = 4; #[repr(i32)] enum ExitStatus { @@ -63,8 +63,8 @@ pub enum Error { #[error(display = "Failed to initialize mullvad RPC runtime")] RpcInitializationError(#[error(source)] mullvad_rpc::Error), - #[error(display = "Failed to remove WireGuard key for account")] - RemoveKeyError(#[error(source)] mullvad_rpc::rest::Error), + #[error(display = "Failed to remove device from account")] + RemoveDeviceError(#[error(source)] mullvad_rpc::rest::Error), #[error(display = "Failed to obtain settings directory path")] SettingsPathError(#[error(source)] SettingsPathErrorType), @@ -72,8 +72,11 @@ pub enum Error { #[error(display = "Failed to obtain cache directory path")] CachePathError(#[error(source)] mullvad_paths::Error), - #[error(display = "Failed to update the settings")] - SettingsError(#[error(source)] mullvad_daemon::settings::Error), + #[error(display = "Failed to read the device cache")] + ReadDeviceCacheError(#[error(source)] mullvad_daemon::device::Error), + + #[error(display = "Failed to write the device cache")] + WriteDeviceCacheError(#[error(source)] mullvad_daemon::device::Error), #[error(display = "Cannot parse the version string")] ParseVersionStringError, @@ -87,7 +90,7 @@ async fn main() { App::new("prepare-restart") .about("Move a running daemon into a blocking state and save its target state"), App::new("reset-firewall").about("Remove any firewall rules introduced by the daemon"), - App::new("remove-wireguard-key").about("Removes the WireGuard key from the active account"), + App::new("remove-device").about("Remove the current device from the active account"), App::new("is-older-version") .about("Checks whether the given version is older than the current version") .arg( @@ -110,7 +113,7 @@ async fn main() { let result = match matches.subcommand() { Some(("prepare-restart", _)) => prepare_restart().await, Some(("reset-firewall", _)) => reset_firewall().await, - Some(("remove-wireguard-key", _)) => remove_wireguard_key().await, + Some(("remove-device", _)) => remove_device().await, Some(("is-older-version", sub_matches)) => { let old_version = sub_matches.value_of("OLDVERSION").unwrap(); match is_older_version(old_version).await { @@ -159,43 +162,42 @@ async fn reset_firewall() -> Result<(), Error> { .map_err(Error::FirewallError) } -async fn remove_wireguard_key() -> Result<(), Error> { +async fn remove_device() -> Result<(), Error> { let (cache_path, settings_path) = get_paths()?; - let mut settings = mullvad_daemon::settings::SettingsPersister::load(&settings_path).await; + let (cacher, data) = mullvad_daemon::device::DeviceCacher::new(&settings_path) + .await + .map_err(Error::ReadDeviceCacheError)?; + if let Some(device) = data { + let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false) + .await + .map_err(Error::RpcInitializationError)?; - if let Some(token) = settings.get_account_token() { - if let Some(wg_data) = settings.get_wireguard() { - let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false) - .await - .map_err(Error::RpcInitializationError)?; - let mut key_proxy = mullvad_rpc::WireguardKeyProxy::new( - rpc_runtime - .mullvad_rest_handle( - ApiConnectionMode::try_from_cache(&cache_path) - .await - .into_repeat(), - |_| async { true }, - ) - .await, - ); - retry_future_n( - move || { - key_proxy.remove_wireguard_key(token.clone(), wg_data.private_key.public_key()) - }, - move |result| match result { - Err(error) => error.is_network_error(), - _ => false, - }, - constant_interval(KEY_RETRY_INTERVAL), - KEY_RETRY_MAX_RETRIES, - ) + let proxy = mullvad_rpc::DevicesProxy::new( + rpc_runtime + .mullvad_rest_handle( + ApiConnectionMode::try_from_cache(&cache_path) + .await + .into_repeat(), + |_| async { true }, + ) + .await, + ); + retry_future_n( + move || proxy.remove(device.token.clone(), device.device.id.clone()), + move |result| match result { + Err(error) => error.is_network_error(), + _ => false, + }, + constant_interval(KEY_RETRY_INTERVAL), + KEY_RETRY_MAX_RETRIES, + ) + .await + .map_err(Error::RemoveDeviceError)?; + + cacher + .remove() .await - .map_err(Error::RemoveKeyError)?; - settings - .set_wireguard(None) - .await - .map_err(Error::SettingsError)?; - } + .map_err(Error::WriteDeviceCacheError)?; } Ok(()) diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs index b5479640e6..16f6a963f2 100644 --- a/mullvad-types/src/account.rs +++ b/mullvad-types/src/account.rs @@ -3,9 +3,12 @@ use chrono::{offset::Utc, DateTime}; use jnix::IntoJava; use serde::{Deserialize, Serialize}; -/// Identifier used to authenticate or identify a Mullvad account. +/// Identifier used to identify a Mullvad account. pub type AccountToken = String; +/// Identifier used to authenticate a Mullvad account. +pub type AccessToken = String; + /// Account expiration info returned by the API via `/v1/me`. #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[cfg_attr(target_os = "android", derive(IntoJava))] @@ -18,7 +21,7 @@ pub struct AccountData { impl AccountData { /// Return true if the account has no time left. pub fn is_expired(&self) -> bool { - self.expiry >= Utc::now() + Utc::now() >= self.expiry } } @@ -35,3 +38,17 @@ pub struct VoucherSubmission { #[cfg_attr(target_os = "android", jnix(map = "|expiry| expiry.to_string()"))] pub new_expiry: DateTime<Utc>, } + +/// Token used for authentication in the API. +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct AccessTokenData { + pub access_token: AccessToken, + pub expiry: DateTime<Utc>, +} + +impl AccessTokenData { + /// Return true if the token is no longer valid. + pub fn is_expired(&self) -> bool { + Utc::now() >= self.expiry + } +} diff --git a/mullvad-types/src/device.rs b/mullvad-types/src/device.rs new file mode 100644 index 0000000000..4b0123e6a9 --- /dev/null +++ b/mullvad-types/src/device.rs @@ -0,0 +1,142 @@ +use crate::{account::AccountToken, wireguard}; +#[cfg(target_os = "android")] +use jnix::IntoJava; +use serde::{Deserialize, Serialize}; +use std::fmt; +use talpid_types::net::wireguard::PublicKey; + +/// UUID for a device. +pub type DeviceId = String; + +/// Human-readable device identifier. +pub type DeviceName = String; + +/// Contains data for a device returned by the API. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[cfg_attr(target_os = "android", derive(IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] +pub struct Device { + pub id: DeviceId, + pub name: DeviceName, + #[cfg_attr(target_os = "android", jnix(map = "|key| *key.as_bytes()"))] + pub pubkey: PublicKey, + pub ports: Vec<DevicePort>, +} + +impl Eq for Device {} + +impl Device { + /// Return name with each word capitalized: "Happy Seagull" instead of "happy seagull" + pub fn pretty_name(&self) -> String { + self.name + .split_whitespace() + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(c) => c.to_uppercase().chain(chars).collect(), + } + }) + .collect::<Vec<String>>() + .join(" ") + } + + pub fn eq_id(&self, other: &Device) -> bool { + self.id == other.id + } +} + +/// Ports associated with a device. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[cfg_attr(target_os = "android", derive(IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] +pub struct DevicePort { + /// Port identifier. + pub id: String, +} + +impl fmt::Display for DevicePort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.id) + } +} + +/// A complete device configuration. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct DeviceData { + pub token: AccountToken, + pub device: Device, + pub wg_data: wireguard::WireguardData, +} + +impl From<DeviceData> for Device { + fn from(data: DeviceData) -> Device { + data.device + } +} + +/// [`DeviceData`] excluding the private key. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[cfg_attr(target_os = "android", derive(IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] +pub struct DeviceConfig { + pub token: AccountToken, + pub device: Device, +} + +impl From<DeviceData> for DeviceConfig { + fn from(data: DeviceData) -> DeviceConfig { + DeviceConfig { + token: data.token, + device: data.device, + } + } +} + +/// Emitted when logging in or out of an account, or when the device changes. +#[derive(Clone, Debug)] +#[cfg_attr(target_os = "android", derive(IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] +pub struct DeviceEvent { + /// Device that was affected. + pub device: Option<DeviceConfig>, + /// Indicates whether the change was initiated remotely or by the daemon. + pub remote: bool, +} + +impl DeviceEvent { + pub fn new(data: Option<DeviceData>, remote: bool) -> DeviceEvent { + DeviceEvent { + device: data.map(DeviceConfig::from), + remote, + } + } + + pub fn from_device(data: DeviceData, remote: bool) -> DeviceEvent { + DeviceEvent { + device: Some(DeviceConfig { + token: data.token, + device: data.device, + }), + remote, + } + } + + pub fn revoke(remote: bool) -> Self { + Self { + device: None, + remote, + } + } +} + +/// Emitted when a device is removed using the `RemoveDevice` RPC. +/// This is not sent by a normal logout or when it is revoked remotely. +#[derive(Clone, Debug)] +#[cfg_attr(target_os = "android", derive(IntoJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] +pub struct RemoveDeviceEvent { + pub account_token: AccountToken, + pub removed_device: Device, + pub new_devices: Vec<Device>, +} diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs index e93ab2f606..6d636aceb5 100644 --- a/mullvad-types/src/lib.rs +++ b/mullvad-types/src/lib.rs @@ -2,6 +2,7 @@ pub mod account; pub mod auth_failed; +pub mod device; pub mod endpoint; pub mod location; pub mod relay_constraints; diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs index 26a24202a5..63ccb480a2 100644 --- a/mullvad-types/src/settings/mod.rs +++ b/mullvad-types/src/settings/mod.rs @@ -61,9 +61,6 @@ impl Serialize for SettingsVersion { #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] pub struct Settings { - account_token: Option<String>, - #[cfg_attr(target_os = "android", jnix(skip))] - wireguard: Option<wireguard::WireguardData>, relay_settings: RelaySettings, #[cfg_attr(target_os = "android", jnix(skip))] pub bridge_settings: BridgeSettings, @@ -102,8 +99,6 @@ pub struct SplitTunnelSettings { impl Default for Settings { fn default() -> Self { Settings { - account_token: None, - wireguard: None, relay_settings: RelaySettings::Normal(RelayConstraints { location: Constraint::Only(LocationConstraint::Country("se".to_owned())), ..Default::default() @@ -123,45 +118,6 @@ impl Default for Settings { } impl Settings { - pub fn get_account_token(&self) -> Option<String> { - self.account_token.clone() - } - - /// Changes account number to the one given. Also saves the new settings to disk. - /// The boolean in the Result indicates if the account token changed or not - pub fn set_account_token(&mut self, mut account_token: Option<String>) -> bool { - if account_token.as_ref().map(String::len) == Some(0) { - log::debug!("Setting empty account token is treated as unsetting it"); - account_token = None; - } - if account_token != self.account_token { - if account_token.is_none() { - log::info!("Unsetting account token"); - } else if self.account_token.is_none() { - log::info!("Setting account token"); - } else { - log::info!("Changing account token") - } - self.account_token = account_token; - true - } else { - false - } - } - - pub fn get_wireguard(&self) -> Option<wireguard::WireguardData> { - self.wireguard.clone() - } - - pub fn set_wireguard(&mut self, wireguard: Option<wireguard::WireguardData>) -> bool { - if wireguard != self.wireguard { - self.wireguard = wireguard; - true - } else { - false - } - } - pub fn get_relay_settings(&self) -> RelaySettings { self.relay_settings.clone() } diff --git a/mullvad-types/src/states.rs b/mullvad-types/src/states.rs index 9d3b188db4..86d9e816fe 100644 --- a/mullvad-types/src/states.rs +++ b/mullvad-types/src/states.rs @@ -55,4 +55,12 @@ impl TunnelState { _ => false, } } + + /// Returns true if the tunnel state is in the connected state. + pub fn is_connected(&self) -> bool { + match self { + TunnelState::Connected { .. } => true, + _ => false, + } + } } diff --git a/mullvad-types/src/wireguard.rs b/mullvad-types/src/wireguard.rs index 2991eb1a1d..4c05f1e552 100644 --- a/mullvad-types/src/wireguard.rs +++ b/mullvad-types/src/wireguard.rs @@ -145,24 +145,3 @@ pub struct AssociatedAddresses { pub ipv4_address: ipnetwork::Ipv4Network, pub ipv6_address: ipnetwork::Ipv6Network, } - -/// Event that is emitted when the daemon has finished generating a key. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -#[cfg_attr(target_os = "android", derive(IntoJava))] -#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] -pub enum KeygenEvent { - NewKey(PublicKey), - TooManyKeys, - GenerationFailure, -} - -impl fmt::Display for KeygenEvent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - KeygenEvent::NewKey(new_key) => write!(f, "New wireguard key {}", new_key.key), - KeygenEvent::TooManyKeys => write!(f, "Account has too many keys already"), - KeygenEvent::GenerationFailure => write!(f, "Failed to generate new wireguard key"), - } - } -} diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs index 050b90c81c..8c6424bc01 100644 --- a/talpid-core/src/mpsc.rs +++ b/talpid-core/src/mpsc.rs @@ -3,3 +3,9 @@ pub trait Sender<T> { /// Sends an item over the underlying channel, failing only if the channel is closed. fn send(&self, item: T) -> Result<(), ()>; } + +impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> { + fn send(&self, content: E) -> Result<(), ()> { + self.unbounded_send(content).map_err(|_| ()) + } +} |
