diff options
Diffstat (limited to 'control/controlbase/handshake.go')
| -rw-r--r-- | control/controlbase/handshake.go | 988 |
1 files changed, 494 insertions, 494 deletions
diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 765a4620b..937969a30 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ "go4.org/mem"
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+ "tailscale.com/types/key"
+)
+
+const (
+ // protocolName is the name of the specific instantiation of Noise
+ // that the control protocol uses. This string's value is fixed by
+ // the Noise spec, and shouldn't be changed unless we're updating
+ // the control protocol to use a different Noise instance.
+ protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+ // protocolVersion is the version of the control protocol that
+ // Client will use when initiating a handshake.
+ //protocolVersion uint16 = 1
+ // protocolVersionPrefix is the name portion of the protocol
+ // name+version string that gets mixed into the handshake as a
+ // prologue.
+ //
+ // This mixing verifies that both clients agree that they're
+ // executing the control protocol at a specific version that
+ // matches the advertised version in the cleartext packet header.
+ protocolVersionPrefix = "Tailscale Control Protocol v"
+ invalidNonce = ^uint64(0)
+)
+
+func protocolVersionPrologue(version uint16) []byte {
+ ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
+ ret = append(ret, protocolVersionPrefix...)
+ return strconv.AppendUint(ret, uint64(version), 10)
+}
+
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
+// is assumed to have already sent the client>server handshake
+// initiation message.
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
+
+// ClientDeferred initiates a control client handshake, returning the
+// initial message to send to the server and a continuation to
+// finalize the handshake.
+//
+// ClientDeferred is split in this way for RTT reduction: we run this
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
+// completely serialized the negotiation followed by the handshake,
+// we'd pay an extra RTT to transmit the handshake initiation after
+// protocol switching. By splitting the handshake into an initial
+// message and a continuation, we can embed the handshake initiation
+// into the HTTP protocol switching request and avoid a bit of delay.
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
+ var s symmetricState
+ s.Initialize()
+
+ // prologue
+ s.MixHash(protocolVersionPrologue(protocolVersion))
+
+ // <- s
+ // ...
+ s.MixHash(controlKey.UntypedBytes())
+
+ // -> e, es, s, ss
+ init := mkInitiationMessage(protocolVersion)
+ machineEphemeral := key.NewMachine()
+ machineEphemeralPub := machineEphemeral.Public()
+ copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(machineEphemeral, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing es: %w", err)
+ }
+ machineKeyPub := machineKey.Public()
+ s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
+ cipher, err = s.MixDH(machineKey, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
+ }
+ s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
+
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
+ }
+ return init[:], cont, nil
+}
+
+// Client wraps ClientDeferred and immediately invokes the returned
+// continuation with conn.
+//
+// This is a helper for when you don't need the fancy
+// continuation-style handshake, and just want to synchronously
+// upgrade a net.Conn to a secure transport.
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(init); err != nil {
+ return nil, err
+ }
+ return cont(ctx, conn)
+}
+
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ // No matter what, this function can only run once per s. Ensure
+ // attempted reuse causes a panic.
+ defer func() {
+ s.finished = true
+ }()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Read in the payload and look for errors/protocol violations from the server.
+ var resp responseMessage
+ if _, err := io.ReadFull(conn, resp.Header()); err != nil {
+ return nil, fmt.Errorf("reading response header: %w", err)
+ }
+ if resp.Type() != msgTypeResponse {
+ if resp.Type() != msgTypeError {
+ return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
+ }
+ msg := make([]byte, resp.Length())
+ if _, err := io.ReadFull(conn, msg); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("server error: %q", msg)
+ }
+ if resp.Length() != len(resp.Payload()) {
+ return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
+ }
+ if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
+ return nil, err
+ }
+
+ // <- e, ee, se
+ controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting payload: %w", err)
+ }
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: protocolVersion,
+ peer: controlKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c1,
+ },
+ rx: rxState{
+ cipher: c2,
+ },
+ }
+ return c, nil
+}
+
+// Server initiates a control server handshake, returning the resulting
+// control connection.
+//
+// optionalInit can be the client's initial handshake message as
+// returned by ClientDeferred, or nil in which case the initial
+// message is read from conn.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Deliberately does not support formatting, so that we don't echo
+ // attacker-controlled input back to them.
+ sendErr := func(msg string) error {
+ if len(msg) >= 1<<16 {
+ msg = msg[:1<<16]
+ }
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeError
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ if _, err := io.WriteString(conn, msg); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ return fmt.Errorf("refused client handshake: %q", msg)
+ }
+
+ var s symmetricState
+ s.Initialize()
+
+ var init initiationMessage
+ if optionalInit != nil {
+ if len(optionalInit) != len(init) {
+ return nil, sendErr("wrong handshake initiation size")
+ }
+ copy(init[:], optionalInit)
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
+ return nil, err
+ }
+ // Just a rename to make it more obvious what the value is. In the
+ // current implementation we don't need to block any protocol
+ // versions at this layer, it's safe to let the handshake proceed
+ // and then let the caller make decisions based on the agreed-upon
+ // protocol version.
+ clientVersion := init.Version()
+ if init.Type() != msgTypeInitiation {
+ return nil, sendErr("unexpected handshake message type")
+ }
+ if init.Length() != len(init.Payload()) {
+ return nil, sendErr("wrong handshake initiation length")
+ }
+ // if optionalInit was provided, we have the payload already.
+ if optionalInit == nil {
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
+ return nil, err
+ }
+ }
+
+ // prologue. Can only do this once we at least think the client is
+ // handshaking using a supported version.
+ s.MixHash(protocolVersionPrologue(clientVersion))
+
+ // <- s
+ // ...
+ controlKeyPub := controlKey.Public()
+ s.MixHash(controlKeyPub.UntypedBytes())
+
+ // -> e, es, s, ss
+ machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(controlKey, machineEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing es: %w", err)
+ }
+ var machineKeyBytes [32]byte
+ if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
+ return nil, fmt.Errorf("decrypting machine key: %w", err)
+ }
+ machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
+ cipher, err = s.MixDH(controlKey, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing ss: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+ }
+
+ // <- e, ee, se
+ resp := mkResponseMessage()
+ controlEphemeral := key.NewMachine()
+ controlEphemeralPub := controlEphemeral.Public()
+ copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err = s.MixDH(controlEphemeral, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ if _, err := conn.Write(resp[:]); err != nil {
+ return nil, err
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: clientVersion,
+ peer: machineKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c2,
+ },
+ rx: rxState{
+ cipher: c1,
+ },
+ }
+ return c, nil
+}
+
+// symmetricState contains the state of an in-flight handshake.
+type symmetricState struct {
+ finished bool
+
+ h [blake2s.Size]byte // hash of currently-processed handshake state
+ ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
+}
+
+func (s *symmetricState) checkFinished() {
+ if s.finished {
+ panic("attempted to use symmetricState after Split was called")
+ }
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any handshake messages.
+func (s *symmetricState) Initialize() {
+ s.checkFinished()
+ s.h = blake2s.Sum256([]byte(protocolName))
+ s.ck = s.h
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+ s.checkFinished()
+ h := newBLAKE2s()
+ h.Write(s.h[:])
+ h.Write(data)
+ h.Sum(s.h[:0])
+}
+
+// MixDH updates s.ck with the result of X25519(priv, pub) and returns
+// a singleUseCHP that can be used to encrypt or decrypt handshake
+// data.
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
+ s.checkFinished()
+ keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
+ if err != nil {
+ return nil, fmt.Errorf("computing X25519: %w", err)
+ }
+
+ r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+ if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+ return nil, fmt.Errorf("extracting ck: %w", err)
+ }
+ var k [chp.KeySize]byte
+ if _, err := io.ReadFull(r, k[:]); err != nil {
+ return nil, fmt.Errorf("extracting k: %w", err)
+ }
+ return newSingleUseCHP(k), nil
+}
+
+// EncryptAndHash encrypts plaintext into ciphertext (which must be
+// the correct size to hold the encrypted plaintext) using cipher,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ panic("ciphertext is wrong size for given plaintext")
+ }
+ ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
+ s.MixHash(ret)
+}
+
+// DecryptAndHash decrypts the given ciphertext into plaintext (which
+// must be the correct size to hold the decrypted ciphertext) using
+// cipher. If decryption is successful, it mixes the ciphertext into
+// s.h.
+func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ return errors.New("plaintext is wrong size for given ciphertext")
+ }
+ if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
+ return err
+ }
+ s.MixHash(ciphertext)
+ return nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derived from
+// the current handshake state. Methods on s cannot be used again
+// after calling Split.
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+ s.finished = true
+
+ var k1, k2 [chp.KeySize]byte
+ r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+ if _, err := io.ReadFull(r, k1[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k1: %w", err)
+ }
+ if _, err := io.ReadFull(r, k2[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k2: %w", err)
+ }
+ c1, err = chp.New(k1[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+ }
+ c2, err = chp.New(k2[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+ }
+ return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ // Should never happen, errors only happen when using BLAKE2s
+ // in MAC mode with a key.
+ panic(err)
+ }
+ return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+ aead, err := chp.New(key[:])
+ if err != nil {
+ // Can only happen if we passed a key of the wrong length. The
+ // function signature prevents that.
+ panic(err)
+ }
+ return aead
+}
+
+// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
+// only once, either for encrypting or decrypting, but not both. The
+// chosen operation is always executed with an all-zeros
+// nonce. Subsequent calls to either Seal or Open panic.
+type singleUseCHP struct {
+ c cipher.AEAD
+}
+
+func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
+ return &singleUseCHP{newCHP(key)}
+}
+
+func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Seal(dst, nonce[:], plaintext, additionalData)
+}
+
+func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Open(dst, nonce[:], ciphertext, additionalData)
+}
|
