summaryrefslogtreecommitdiffhomepage
path: root/control/controlbase/handshake.go
diff options
context:
space:
mode:
Diffstat (limited to 'control/controlbase/handshake.go')
-rw-r--r--control/controlbase/handshake.go988
1 files changed, 494 insertions, 494 deletions
diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go
index 765a4620b..937969a30 100644
--- a/control/controlbase/handshake.go
+++ b/control/controlbase/handshake.go
@@ -1,494 +1,494 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package controlbase
-
-import (
- "context"
- "crypto/cipher"
- "encoding/binary"
- "errors"
- "fmt"
- "hash"
- "io"
- "net"
- "strconv"
- "time"
-
- "go4.org/mem"
- "golang.org/x/crypto/blake2s"
- chp "golang.org/x/crypto/chacha20poly1305"
- "golang.org/x/crypto/curve25519"
- "golang.org/x/crypto/hkdf"
- "tailscale.com/types/key"
-)
-
-const (
- // protocolName is the name of the specific instantiation of Noise
- // that the control protocol uses. This string's value is fixed by
- // the Noise spec, and shouldn't be changed unless we're updating
- // the control protocol to use a different Noise instance.
- protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
- // protocolVersion is the version of the control protocol that
- // Client will use when initiating a handshake.
- //protocolVersion uint16 = 1
- // protocolVersionPrefix is the name portion of the protocol
- // name+version string that gets mixed into the handshake as a
- // prologue.
- //
- // This mixing verifies that both clients agree that they're
- // executing the control protocol at a specific version that
- // matches the advertised version in the cleartext packet header.
- protocolVersionPrefix = "Tailscale Control Protocol v"
- invalidNonce = ^uint64(0)
-)
-
-func protocolVersionPrologue(version uint16) []byte {
- ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
- ret = append(ret, protocolVersionPrefix...)
- return strconv.AppendUint(ret, uint64(version), 10)
-}
-
-// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
-// is assumed to have already sent the client>server handshake
-// initiation message.
-type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
-
-// ClientDeferred initiates a control client handshake, returning the
-// initial message to send to the server and a continuation to
-// finalize the handshake.
-//
-// ClientDeferred is split in this way for RTT reduction: we run this
-// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
-// completely serialized the negotiation followed by the handshake,
-// we'd pay an extra RTT to transmit the handshake initiation after
-// protocol switching. By splitting the handshake into an initial
-// message and a continuation, we can embed the handshake initiation
-// into the HTTP protocol switching request and avoid a bit of delay.
-func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
- var s symmetricState
- s.Initialize()
-
- // prologue
- s.MixHash(protocolVersionPrologue(protocolVersion))
-
- // <- s
- // ...
- s.MixHash(controlKey.UntypedBytes())
-
- // -> e, es, s, ss
- init := mkInitiationMessage(protocolVersion)
- machineEphemeral := key.NewMachine()
- machineEphemeralPub := machineEphemeral.Public()
- copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
- s.MixHash(machineEphemeralPub.UntypedBytes())
- cipher, err := s.MixDH(machineEphemeral, controlKey)
- if err != nil {
- return nil, nil, fmt.Errorf("computing es: %w", err)
- }
- machineKeyPub := machineKey.Public()
- s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
- cipher, err = s.MixDH(machineKey, controlKey)
- if err != nil {
- return nil, nil, fmt.Errorf("computing ss: %w", err)
- }
- s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
-
- cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
- return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
- }
- return init[:], cont, nil
-}
-
-// Client wraps ClientDeferred and immediately invokes the returned
-// continuation with conn.
-//
-// This is a helper for when you don't need the fancy
-// continuation-style handshake, and just want to synchronously
-// upgrade a net.Conn to a secure transport.
-func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
- init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
- if err != nil {
- return nil, err
- }
- if _, err := conn.Write(init); err != nil {
- return nil, err
- }
- return cont(ctx, conn)
-}
-
-func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
- // No matter what, this function can only run once per s. Ensure
- // attempted reuse causes a panic.
- defer func() {
- s.finished = true
- }()
-
- if deadline, ok := ctx.Deadline(); ok {
- if err := conn.SetDeadline(deadline); err != nil {
- return nil, fmt.Errorf("setting conn deadline: %w", err)
- }
- defer func() {
- conn.SetDeadline(time.Time{})
- }()
- }
-
- // Read in the payload and look for errors/protocol violations from the server.
- var resp responseMessage
- if _, err := io.ReadFull(conn, resp.Header()); err != nil {
- return nil, fmt.Errorf("reading response header: %w", err)
- }
- if resp.Type() != msgTypeResponse {
- if resp.Type() != msgTypeError {
- return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
- }
- msg := make([]byte, resp.Length())
- if _, err := io.ReadFull(conn, msg); err != nil {
- return nil, err
- }
- return nil, fmt.Errorf("server error: %q", msg)
- }
- if resp.Length() != len(resp.Payload()) {
- return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
- }
- if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
- return nil, err
- }
-
- // <- e, ee, se
- controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
- s.MixHash(controlEphemeralPub.UntypedBytes())
- if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
- return nil, fmt.Errorf("computing ee: %w", err)
- }
- cipher, err := s.MixDH(machineKey, controlEphemeralPub)
- if err != nil {
- return nil, fmt.Errorf("computing se: %w", err)
- }
- if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
- return nil, fmt.Errorf("decrypting payload: %w", err)
- }
-
- c1, c2, err := s.Split()
- if err != nil {
- return nil, fmt.Errorf("finalizing handshake: %w", err)
- }
-
- c := &Conn{
- conn: conn,
- version: protocolVersion,
- peer: controlKey,
- handshakeHash: s.h,
- tx: txState{
- cipher: c1,
- },
- rx: rxState{
- cipher: c2,
- },
- }
- return c, nil
-}
-
-// Server initiates a control server handshake, returning the resulting
-// control connection.
-//
-// optionalInit can be the client's initial handshake message as
-// returned by ClientDeferred, or nil in which case the initial
-// message is read from conn.
-//
-// The context deadline, if any, covers the entire handshaking
-// process.
-func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
- if deadline, ok := ctx.Deadline(); ok {
- if err := conn.SetDeadline(deadline); err != nil {
- return nil, fmt.Errorf("setting conn deadline: %w", err)
- }
- defer func() {
- conn.SetDeadline(time.Time{})
- }()
- }
-
- // Deliberately does not support formatting, so that we don't echo
- // attacker-controlled input back to them.
- sendErr := func(msg string) error {
- if len(msg) >= 1<<16 {
- msg = msg[:1<<16]
- }
- var hdr [headerLen]byte
- hdr[0] = msgTypeError
- binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
- if _, err := conn.Write(hdr[:]); err != nil {
- return fmt.Errorf("sending %q error to client: %w", msg, err)
- }
- if _, err := io.WriteString(conn, msg); err != nil {
- return fmt.Errorf("sending %q error to client: %w", msg, err)
- }
- return fmt.Errorf("refused client handshake: %q", msg)
- }
-
- var s symmetricState
- s.Initialize()
-
- var init initiationMessage
- if optionalInit != nil {
- if len(optionalInit) != len(init) {
- return nil, sendErr("wrong handshake initiation size")
- }
- copy(init[:], optionalInit)
- } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
- return nil, err
- }
- // Just a rename to make it more obvious what the value is. In the
- // current implementation we don't need to block any protocol
- // versions at this layer, it's safe to let the handshake proceed
- // and then let the caller make decisions based on the agreed-upon
- // protocol version.
- clientVersion := init.Version()
- if init.Type() != msgTypeInitiation {
- return nil, sendErr("unexpected handshake message type")
- }
- if init.Length() != len(init.Payload()) {
- return nil, sendErr("wrong handshake initiation length")
- }
- // if optionalInit was provided, we have the payload already.
- if optionalInit == nil {
- if _, err := io.ReadFull(conn, init.Payload()); err != nil {
- return nil, err
- }
- }
-
- // prologue. Can only do this once we at least think the client is
- // handshaking using a supported version.
- s.MixHash(protocolVersionPrologue(clientVersion))
-
- // <- s
- // ...
- controlKeyPub := controlKey.Public()
- s.MixHash(controlKeyPub.UntypedBytes())
-
- // -> e, es, s, ss
- machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
- s.MixHash(machineEphemeralPub.UntypedBytes())
- cipher, err := s.MixDH(controlKey, machineEphemeralPub)
- if err != nil {
- return nil, fmt.Errorf("computing es: %w", err)
- }
- var machineKeyBytes [32]byte
- if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
- return nil, fmt.Errorf("decrypting machine key: %w", err)
- }
- machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
- cipher, err = s.MixDH(controlKey, machineKey)
- if err != nil {
- return nil, fmt.Errorf("computing ss: %w", err)
- }
- if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
- return nil, fmt.Errorf("decrypting initiation tag: %w", err)
- }
-
- // <- e, ee, se
- resp := mkResponseMessage()
- controlEphemeral := key.NewMachine()
- controlEphemeralPub := controlEphemeral.Public()
- copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
- s.MixHash(controlEphemeralPub.UntypedBytes())
- if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
- return nil, fmt.Errorf("computing ee: %w", err)
- }
- cipher, err = s.MixDH(controlEphemeral, machineKey)
- if err != nil {
- return nil, fmt.Errorf("computing se: %w", err)
- }
- s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
-
- c1, c2, err := s.Split()
- if err != nil {
- return nil, fmt.Errorf("finalizing handshake: %w", err)
- }
-
- if _, err := conn.Write(resp[:]); err != nil {
- return nil, err
- }
-
- c := &Conn{
- conn: conn,
- version: clientVersion,
- peer: machineKey,
- handshakeHash: s.h,
- tx: txState{
- cipher: c2,
- },
- rx: rxState{
- cipher: c1,
- },
- }
- return c, nil
-}
-
-// symmetricState contains the state of an in-flight handshake.
-type symmetricState struct {
- finished bool
-
- h [blake2s.Size]byte // hash of currently-processed handshake state
- ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
-}
-
-func (s *symmetricState) checkFinished() {
- if s.finished {
- panic("attempted to use symmetricState after Split was called")
- }
-}
-
-// Initialize sets s to the initial handshake state, prior to
-// processing any handshake messages.
-func (s *symmetricState) Initialize() {
- s.checkFinished()
- s.h = blake2s.Sum256([]byte(protocolName))
- s.ck = s.h
-}
-
-// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
-// concatenation.
-func (s *symmetricState) MixHash(data []byte) {
- s.checkFinished()
- h := newBLAKE2s()
- h.Write(s.h[:])
- h.Write(data)
- h.Sum(s.h[:0])
-}
-
-// MixDH updates s.ck with the result of X25519(priv, pub) and returns
-// a singleUseCHP that can be used to encrypt or decrypt handshake
-// data.
-//
-// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
-// it as a single function allows for strongly-typed arguments that
-// reduce the risk of error in the caller (e.g. invoking X25519 with
-// two private keys, or two public keys), and thus producing the wrong
-// calculation.
-func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
- s.checkFinished()
- keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
- if err != nil {
- return nil, fmt.Errorf("computing X25519: %w", err)
- }
-
- r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
- if _, err := io.ReadFull(r, s.ck[:]); err != nil {
- return nil, fmt.Errorf("extracting ck: %w", err)
- }
- var k [chp.KeySize]byte
- if _, err := io.ReadFull(r, k[:]); err != nil {
- return nil, fmt.Errorf("extracting k: %w", err)
- }
- return newSingleUseCHP(k), nil
-}
-
-// EncryptAndHash encrypts plaintext into ciphertext (which must be
-// the correct size to hold the encrypted plaintext) using cipher,
-// mixes the ciphertext into s.h, and returns the ciphertext.
-func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
- s.checkFinished()
- if len(ciphertext) != len(plaintext)+chp.Overhead {
- panic("ciphertext is wrong size for given plaintext")
- }
- ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
- s.MixHash(ret)
-}
-
-// DecryptAndHash decrypts the given ciphertext into plaintext (which
-// must be the correct size to hold the decrypted ciphertext) using
-// cipher. If decryption is successful, it mixes the ciphertext into
-// s.h.
-func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
- s.checkFinished()
- if len(ciphertext) != len(plaintext)+chp.Overhead {
- return errors.New("plaintext is wrong size for given ciphertext")
- }
- if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
- return err
- }
- s.MixHash(ciphertext)
- return nil
-}
-
-// Split returns two ChaCha20Poly1305 ciphers with keys derived from
-// the current handshake state. Methods on s cannot be used again
-// after calling Split.
-func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
- s.finished = true
-
- var k1, k2 [chp.KeySize]byte
- r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
- if _, err := io.ReadFull(r, k1[:]); err != nil {
- return nil, nil, fmt.Errorf("extracting k1: %w", err)
- }
- if _, err := io.ReadFull(r, k2[:]); err != nil {
- return nil, nil, fmt.Errorf("extracting k2: %w", err)
- }
- c1, err = chp.New(k1[:])
- if err != nil {
- return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
- }
- c2, err = chp.New(k2[:])
- if err != nil {
- return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
- }
- return c1, c2, nil
-}
-
-// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
-// error.
-func newBLAKE2s() hash.Hash {
- h, err := blake2s.New256(nil)
- if err != nil {
- // Should never happen, errors only happen when using BLAKE2s
- // in MAC mode with a key.
- panic(err)
- }
- return h
-}
-
-// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
-// panics on error.
-func newCHP(key [chp.KeySize]byte) cipher.AEAD {
- aead, err := chp.New(key[:])
- if err != nil {
- // Can only happen if we passed a key of the wrong length. The
- // function signature prevents that.
- panic(err)
- }
- return aead
-}
-
-// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
-// only once, either for encrypting or decrypting, but not both. The
-// chosen operation is always executed with an all-zeros
-// nonce. Subsequent calls to either Seal or Open panic.
-type singleUseCHP struct {
- c cipher.AEAD
-}
-
-func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
- return &singleUseCHP{newCHP(key)}
-}
-
-func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
- if c.c == nil {
- panic("Attempted reuse of singleUseAEAD")
- }
- cipher := c.c
- c.c = nil
- var nonce [chp.NonceSize]byte
- return cipher.Seal(dst, nonce[:], plaintext, additionalData)
-}
-
-func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
- if c.c == nil {
- panic("Attempted reuse of singleUseAEAD")
- }
- cipher := c.c
- c.c = nil
- var nonce [chp.NonceSize]byte
- return cipher.Open(dst, nonce[:], ciphertext, additionalData)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ "go4.org/mem"
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+ "tailscale.com/types/key"
+)
+
+const (
+ // protocolName is the name of the specific instantiation of Noise
+ // that the control protocol uses. This string's value is fixed by
+ // the Noise spec, and shouldn't be changed unless we're updating
+ // the control protocol to use a different Noise instance.
+ protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+ // protocolVersion is the version of the control protocol that
+ // Client will use when initiating a handshake.
+ //protocolVersion uint16 = 1
+ // protocolVersionPrefix is the name portion of the protocol
+ // name+version string that gets mixed into the handshake as a
+ // prologue.
+ //
+ // This mixing verifies that both clients agree that they're
+ // executing the control protocol at a specific version that
+ // matches the advertised version in the cleartext packet header.
+ protocolVersionPrefix = "Tailscale Control Protocol v"
+ invalidNonce = ^uint64(0)
+)
+
+func protocolVersionPrologue(version uint16) []byte {
+ ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
+ ret = append(ret, protocolVersionPrefix...)
+ return strconv.AppendUint(ret, uint64(version), 10)
+}
+
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
+// is assumed to have already sent the client>server handshake
+// initiation message.
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
+
+// ClientDeferred initiates a control client handshake, returning the
+// initial message to send to the server and a continuation to
+// finalize the handshake.
+//
+// ClientDeferred is split in this way for RTT reduction: we run this
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
+// completely serialized the negotiation followed by the handshake,
+// we'd pay an extra RTT to transmit the handshake initiation after
+// protocol switching. By splitting the handshake into an initial
+// message and a continuation, we can embed the handshake initiation
+// into the HTTP protocol switching request and avoid a bit of delay.
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
+ var s symmetricState
+ s.Initialize()
+
+ // prologue
+ s.MixHash(protocolVersionPrologue(protocolVersion))
+
+ // <- s
+ // ...
+ s.MixHash(controlKey.UntypedBytes())
+
+ // -> e, es, s, ss
+ init := mkInitiationMessage(protocolVersion)
+ machineEphemeral := key.NewMachine()
+ machineEphemeralPub := machineEphemeral.Public()
+ copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(machineEphemeral, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing es: %w", err)
+ }
+ machineKeyPub := machineKey.Public()
+ s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
+ cipher, err = s.MixDH(machineKey, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
+ }
+ s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
+
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
+ }
+ return init[:], cont, nil
+}
+
+// Client wraps ClientDeferred and immediately invokes the returned
+// continuation with conn.
+//
+// This is a helper for when you don't need the fancy
+// continuation-style handshake, and just want to synchronously
+// upgrade a net.Conn to a secure transport.
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(init); err != nil {
+ return nil, err
+ }
+ return cont(ctx, conn)
+}
+
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ // No matter what, this function can only run once per s. Ensure
+ // attempted reuse causes a panic.
+ defer func() {
+ s.finished = true
+ }()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Read in the payload and look for errors/protocol violations from the server.
+ var resp responseMessage
+ if _, err := io.ReadFull(conn, resp.Header()); err != nil {
+ return nil, fmt.Errorf("reading response header: %w", err)
+ }
+ if resp.Type() != msgTypeResponse {
+ if resp.Type() != msgTypeError {
+ return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
+ }
+ msg := make([]byte, resp.Length())
+ if _, err := io.ReadFull(conn, msg); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("server error: %q", msg)
+ }
+ if resp.Length() != len(resp.Payload()) {
+ return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
+ }
+ if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
+ return nil, err
+ }
+
+ // <- e, ee, se
+ controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting payload: %w", err)
+ }
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: protocolVersion,
+ peer: controlKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c1,
+ },
+ rx: rxState{
+ cipher: c2,
+ },
+ }
+ return c, nil
+}
+
+// Server initiates a control server handshake, returning the resulting
+// control connection.
+//
+// optionalInit can be the client's initial handshake message as
+// returned by ClientDeferred, or nil in which case the initial
+// message is read from conn.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Deliberately does not support formatting, so that we don't echo
+ // attacker-controlled input back to them.
+ sendErr := func(msg string) error {
+ if len(msg) >= 1<<16 {
+ msg = msg[:1<<16]
+ }
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeError
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ if _, err := io.WriteString(conn, msg); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ return fmt.Errorf("refused client handshake: %q", msg)
+ }
+
+ var s symmetricState
+ s.Initialize()
+
+ var init initiationMessage
+ if optionalInit != nil {
+ if len(optionalInit) != len(init) {
+ return nil, sendErr("wrong handshake initiation size")
+ }
+ copy(init[:], optionalInit)
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
+ return nil, err
+ }
+ // Just a rename to make it more obvious what the value is. In the
+ // current implementation we don't need to block any protocol
+ // versions at this layer, it's safe to let the handshake proceed
+ // and then let the caller make decisions based on the agreed-upon
+ // protocol version.
+ clientVersion := init.Version()
+ if init.Type() != msgTypeInitiation {
+ return nil, sendErr("unexpected handshake message type")
+ }
+ if init.Length() != len(init.Payload()) {
+ return nil, sendErr("wrong handshake initiation length")
+ }
+ // if optionalInit was provided, we have the payload already.
+ if optionalInit == nil {
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
+ return nil, err
+ }
+ }
+
+ // prologue. Can only do this once we at least think the client is
+ // handshaking using a supported version.
+ s.MixHash(protocolVersionPrologue(clientVersion))
+
+ // <- s
+ // ...
+ controlKeyPub := controlKey.Public()
+ s.MixHash(controlKeyPub.UntypedBytes())
+
+ // -> e, es, s, ss
+ machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(controlKey, machineEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing es: %w", err)
+ }
+ var machineKeyBytes [32]byte
+ if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
+ return nil, fmt.Errorf("decrypting machine key: %w", err)
+ }
+ machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
+ cipher, err = s.MixDH(controlKey, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing ss: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+ }
+
+ // <- e, ee, se
+ resp := mkResponseMessage()
+ controlEphemeral := key.NewMachine()
+ controlEphemeralPub := controlEphemeral.Public()
+ copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err = s.MixDH(controlEphemeral, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ if _, err := conn.Write(resp[:]); err != nil {
+ return nil, err
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: clientVersion,
+ peer: machineKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c2,
+ },
+ rx: rxState{
+ cipher: c1,
+ },
+ }
+ return c, nil
+}
+
+// symmetricState contains the state of an in-flight handshake.
+type symmetricState struct {
+ finished bool
+
+ h [blake2s.Size]byte // hash of currently-processed handshake state
+ ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
+}
+
+func (s *symmetricState) checkFinished() {
+ if s.finished {
+ panic("attempted to use symmetricState after Split was called")
+ }
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any handshake messages.
+func (s *symmetricState) Initialize() {
+ s.checkFinished()
+ s.h = blake2s.Sum256([]byte(protocolName))
+ s.ck = s.h
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+ s.checkFinished()
+ h := newBLAKE2s()
+ h.Write(s.h[:])
+ h.Write(data)
+ h.Sum(s.h[:0])
+}
+
+// MixDH updates s.ck with the result of X25519(priv, pub) and returns
+// a singleUseCHP that can be used to encrypt or decrypt handshake
+// data.
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
+ s.checkFinished()
+ keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
+ if err != nil {
+ return nil, fmt.Errorf("computing X25519: %w", err)
+ }
+
+ r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+ if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+ return nil, fmt.Errorf("extracting ck: %w", err)
+ }
+ var k [chp.KeySize]byte
+ if _, err := io.ReadFull(r, k[:]); err != nil {
+ return nil, fmt.Errorf("extracting k: %w", err)
+ }
+ return newSingleUseCHP(k), nil
+}
+
+// EncryptAndHash encrypts plaintext into ciphertext (which must be
+// the correct size to hold the encrypted plaintext) using cipher,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ panic("ciphertext is wrong size for given plaintext")
+ }
+ ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
+ s.MixHash(ret)
+}
+
+// DecryptAndHash decrypts the given ciphertext into plaintext (which
+// must be the correct size to hold the decrypted ciphertext) using
+// cipher. If decryption is successful, it mixes the ciphertext into
+// s.h.
+func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ return errors.New("plaintext is wrong size for given ciphertext")
+ }
+ if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
+ return err
+ }
+ s.MixHash(ciphertext)
+ return nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derived from
+// the current handshake state. Methods on s cannot be used again
+// after calling Split.
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+ s.finished = true
+
+ var k1, k2 [chp.KeySize]byte
+ r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+ if _, err := io.ReadFull(r, k1[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k1: %w", err)
+ }
+ if _, err := io.ReadFull(r, k2[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k2: %w", err)
+ }
+ c1, err = chp.New(k1[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+ }
+ c2, err = chp.New(k2[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+ }
+ return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ // Should never happen, errors only happen when using BLAKE2s
+ // in MAC mode with a key.
+ panic(err)
+ }
+ return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+ aead, err := chp.New(key[:])
+ if err != nil {
+ // Can only happen if we passed a key of the wrong length. The
+ // function signature prevents that.
+ panic(err)
+ }
+ return aead
+}
+
+// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
+// only once, either for encrypting or decrypting, but not both. The
+// chosen operation is always executed with an all-zeros
+// nonce. Subsequent calls to either Seal or Open panic.
+type singleUseCHP struct {
+ c cipher.AEAD
+}
+
+func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
+ return &singleUseCHP{newCHP(key)}
+}
+
+func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Seal(dst, nonce[:], plaintext, additionalData)
+}
+
+func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Open(dst, nonce[:], ciphertext, additionalData)
+}