diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-04-06 23:37:11 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-05-03 10:23:33 +0200 |
| commit | 5a8ce732572e83e8abbdef7afa38394730ed7ee7 (patch) | |
| tree | ac124a74d731c0cabbab90aa7c738aa779c15fdb | |
| parent | 4834225b593b7ac273ff44ef4105b87c21f1ba4e (diff) | |
| download | mullvadvpn-5a8ce732572e83e8abbdef7afa38394730ed7ee7.tar.xz mullvadvpn-5a8ce732572e83e8abbdef7afa38394730ed7ee7.zip | |
Clean up conversion from base64 to key types
| -rw-r--r-- | mullvad-cli/src/cmds/relay.rs | 12 | ||||
| -rw-r--r-- | talpid-types/src/net/wireguard.rs | 57 |
2 files changed, 30 insertions, 39 deletions
diff --git a/mullvad-cli/src/cmds/relay.rs b/mullvad-cli/src/cmds/relay.rs index 6b216c82c1..437b0977c2 100644 --- a/mullvad-cli/src/cmds/relay.rs +++ b/mullvad-cli/src/cmds/relay.rs @@ -145,10 +145,10 @@ pub enum SetCustomCommands { /// Remote port port: u16, /// Base64 encoded public key of remote peer - // TODO: parse - peer_pubkey: String, + #[arg(value_parser = wireguard::PublicKey::from_base64)] + peer_pubkey: wireguard::PublicKey, /// IP addresses of local tunnel interface - // TODO: at least one + #[arg(required = true, num_args = 1..)] tunnel_ip: Vec<IpAddr>, /// IPv4 gateway address #[arg(long)] @@ -351,7 +351,7 @@ impl Relay { async fn read_custom_wireguard_relay( host: String, port: u16, - peer_pubkey: String, + peer_pubkey: wireguard::PublicKey, tunnel_ip: Vec<IpAddr>, ipv4_gateway: Ipv4Addr, ipv6_gateway: Option<Ipv6Addr>, @@ -369,8 +369,6 @@ impl Relay { .await .unwrap(); - let peer_public_key = wireguard::PublicKey::from_base64(&peer_pubkey) - .map_err(|_| Error::InvalidCommand("invalid public key"))?; let private_key = wireguard::PrivateKey::from_base64(&private_key_str) .map_err(|_| Error::InvalidCommand("invalid private key"))?; @@ -382,7 +380,7 @@ impl Relay { addresses: tunnel_ip, }, peer: wireguard::PeerConfig { - public_key: peer_public_key, + public_key: peer_pubkey, allowed_ips: all_of_the_internet(), endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port), psk: None, diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index dd7da28b1a..efa35e9eef 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -105,14 +105,8 @@ impl PrivateKey { base64::encode(self.0.to_bytes()) } - pub fn from_base64(key: &str) -> Result<Self, InvalidKeyError> { - let bytes = base64::decode(key).map_err(|_| InvalidKeyError(()))?; - if bytes.len() != 32 { - return Err(InvalidKeyError(())); - } - let mut key = [0u8; 32]; - key.copy_from_slice(&bytes); - Ok(From::from(key)) + pub fn from_base64(key: &str) -> Result<Self, InvalidKey> { + key_from_base64(key) } } @@ -165,8 +159,13 @@ impl<'de> Deserialize<'de> for PrivateKey { pub struct PublicKey(x25519_dalek::PublicKey); /// Error returned if an input represents an invalid key -#[derive(Debug)] -pub struct InvalidKeyError(()); +#[derive(Debug, err_derive::Error)] +pub enum InvalidKey { + #[error(display = "Invalid key: {}", _0)] + Format(#[error(source)] base64::DecodeError), + #[error(display = "Invalid key length: {}", _0)] + Length(usize), +} impl PublicKey { /// Get the public key as bytes @@ -178,14 +177,8 @@ impl PublicKey { base64::encode(self.as_bytes()) } - pub fn from_base64(key: &str) -> Result<Self, InvalidKeyError> { - let bytes = base64::decode(key).map_err(|_| InvalidKeyError(()))?; - if bytes.len() != 32 { - return Err(InvalidKeyError(())); - } - let mut key = [0u8; 32]; - key.copy_from_slice(&bytes); - Ok(From::from(key)) + pub fn from_base64(key: &str) -> Result<Self, InvalidKey> { + key_from_base64(key) } } @@ -202,10 +195,11 @@ impl From<[u8; 32]> for PublicKey { } impl TryFrom<&[u8]> for PublicKey { - type Error = InvalidKeyError; + type Error = InvalidKey; fn try_from(public_key: &[u8]) -> Result<PublicKey, Self::Error> { - let key: [u8; 32] = <[u8; 32]>::try_from(public_key).map_err(|_| InvalidKeyError(()))?; + let key: [u8; 32] = + <[u8; 32]>::try_from(public_key).map_err(|_| InvalidKey::Length(public_key.len()))?; Ok(PublicKey(x25519_dalek::PublicKey::from(key))) } } @@ -294,16 +288,15 @@ where use serde::de::Error; String::deserialize(deserializer) - .and_then(|string| base64::decode(string).map_err(|err| Error::custom(err.to_string()))) - .and_then(|buffer| { - let mut key = [0u8; 32]; - if buffer.len() != 32 { - return Err(Error::custom(format!( - "Key has unexpected length: {}", - buffer.len() - ))); - } - key.copy_from_slice(&buffer); - Ok(From::from(key)) - }) + .and_then(|string| key_from_base64(&string).map_err(|err| Error::custom(err.to_string()))) +} + +fn key_from_base64<K: From<[u8; 32]>>(key: &str) -> Result<K, InvalidKey> { + let bytes = base64::decode(key).map_err(InvalidKey::Format)?; + if bytes.len() != 32 { + return Err(InvalidKey::Length(bytes.len())); + } + let mut key = [0u8; 32]; + key.copy_from_slice(&bytes); + Ok(From::from(key)) } |
