summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-04-06 23:37:11 +0200
committerDavid Lönnhager <david.l@mullvad.net>2023-05-03 10:23:33 +0200
commit5a8ce732572e83e8abbdef7afa38394730ed7ee7 (patch)
treeac124a74d731c0cabbab90aa7c738aa779c15fdb
parent4834225b593b7ac273ff44ef4105b87c21f1ba4e (diff)
downloadmullvadvpn-5a8ce732572e83e8abbdef7afa38394730ed7ee7.tar.xz
mullvadvpn-5a8ce732572e83e8abbdef7afa38394730ed7ee7.zip
Clean up conversion from base64 to key types
-rw-r--r--mullvad-cli/src/cmds/relay.rs12
-rw-r--r--talpid-types/src/net/wireguard.rs57
2 files changed, 30 insertions, 39 deletions
diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs
index 6b216c82c1..437b0977c2 100644
--- a/mullvad-cli/src/cmds/relay.rs
+++ b/mullvad-cli/src/cmds/relay.rs
@@ -145,10 +145,10 @@ pub enum SetCustomCommands {
/// Remote port
port: u16,
/// Base64 encoded public key of remote peer
- // TODO: parse
- peer_pubkey: String,
+ #[arg(value_parser = wireguard::PublicKey::from_base64)]
+ peer_pubkey: wireguard::PublicKey,
/// IP addresses of local tunnel interface
- // TODO: at least one
+ #[arg(required = true, num_args = 1..)]
tunnel_ip: Vec<IpAddr>,
/// IPv4 gateway address
#[arg(long)]
@@ -351,7 +351,7 @@ impl Relay {
async fn read_custom_wireguard_relay(
host: String,
port: u16,
- peer_pubkey: String,
+ peer_pubkey: wireguard::PublicKey,
tunnel_ip: Vec<IpAddr>,
ipv4_gateway: Ipv4Addr,
ipv6_gateway: Option<Ipv6Addr>,
@@ -369,8 +369,6 @@ impl Relay {
.await
.unwrap();
- let peer_public_key = wireguard::PublicKey::from_base64(&peer_pubkey)
- .map_err(|_| Error::InvalidCommand("invalid public key"))?;
let private_key = wireguard::PrivateKey::from_base64(&private_key_str)
.map_err(|_| Error::InvalidCommand("invalid private key"))?;
@@ -382,7 +380,7 @@ impl Relay {
addresses: tunnel_ip,
},
peer: wireguard::PeerConfig {
- public_key: peer_public_key,
+ public_key: peer_pubkey,
allowed_ips: all_of_the_internet(),
endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port),
psk: None,
diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs
index dd7da28b1a..efa35e9eef 100644
--- a/talpid-types/src/net/wireguard.rs
+++ b/talpid-types/src/net/wireguard.rs
@@ -105,14 +105,8 @@ impl PrivateKey {
base64::encode(self.0.to_bytes())
}
- pub fn from_base64(key: &str) -> Result<Self, InvalidKeyError> {
- let bytes = base64::decode(key).map_err(|_| InvalidKeyError(()))?;
- if bytes.len() != 32 {
- return Err(InvalidKeyError(()));
- }
- let mut key = [0u8; 32];
- key.copy_from_slice(&bytes);
- Ok(From::from(key))
+ pub fn from_base64(key: &str) -> Result<Self, InvalidKey> {
+ key_from_base64(key)
}
}
@@ -165,8 +159,13 @@ impl<'de> Deserialize<'de> for PrivateKey {
pub struct PublicKey(x25519_dalek::PublicKey);
/// Error returned if an input represents an invalid key
-#[derive(Debug)]
-pub struct InvalidKeyError(());
+#[derive(Debug, err_derive::Error)]
+pub enum InvalidKey {
+ #[error(display = "Invalid key: {}", _0)]
+ Format(#[error(source)] base64::DecodeError),
+ #[error(display = "Invalid key length: {}", _0)]
+ Length(usize),
+}
impl PublicKey {
/// Get the public key as bytes
@@ -178,14 +177,8 @@ impl PublicKey {
base64::encode(self.as_bytes())
}
- pub fn from_base64(key: &str) -> Result<Self, InvalidKeyError> {
- let bytes = base64::decode(key).map_err(|_| InvalidKeyError(()))?;
- if bytes.len() != 32 {
- return Err(InvalidKeyError(()));
- }
- let mut key = [0u8; 32];
- key.copy_from_slice(&bytes);
- Ok(From::from(key))
+ pub fn from_base64(key: &str) -> Result<Self, InvalidKey> {
+ key_from_base64(key)
}
}
@@ -202,10 +195,11 @@ impl From<[u8; 32]> for PublicKey {
}
impl TryFrom<&[u8]> for PublicKey {
- type Error = InvalidKeyError;
+ type Error = InvalidKey;
fn try_from(public_key: &[u8]) -> Result<PublicKey, Self::Error> {
- let key: [u8; 32] = <[u8; 32]>::try_from(public_key).map_err(|_| InvalidKeyError(()))?;
+ let key: [u8; 32] =
+ <[u8; 32]>::try_from(public_key).map_err(|_| InvalidKey::Length(public_key.len()))?;
Ok(PublicKey(x25519_dalek::PublicKey::from(key)))
}
}
@@ -294,16 +288,15 @@ where
use serde::de::Error;
String::deserialize(deserializer)
- .and_then(|string| base64::decode(string).map_err(|err| Error::custom(err.to_string())))
- .and_then(|buffer| {
- let mut key = [0u8; 32];
- if buffer.len() != 32 {
- return Err(Error::custom(format!(
- "Key has unexpected length: {}",
- buffer.len()
- )));
- }
- key.copy_from_slice(&buffer);
- Ok(From::from(key))
- })
+ .and_then(|string| key_from_base64(&string).map_err(|err| Error::custom(err.to_string())))
+}
+
+fn key_from_base64<K: From<[u8; 32]>>(key: &str) -> Result<K, InvalidKey> {
+ let bytes = base64::decode(key).map_err(InvalidKey::Format)?;
+ if bytes.len() != 32 {
+ return Err(InvalidKey::Length(bytes.len()));
+ }
+ let mut key = [0u8; 32];
+ key.copy_from_slice(&bytes);
+ Ok(From::from(key))
}