summaryrefslogtreecommitdiffhomepage
path: root/control/controlbase
diff options
context:
space:
mode:
Diffstat (limited to 'control/controlbase')
-rw-r--r--control/controlbase/conn.go816
-rw-r--r--control/controlbase/handshake.go988
-rw-r--r--control/controlbase/interop_test.go512
-rw-r--r--control/controlbase/messages.go174
4 files changed, 1245 insertions, 1245 deletions
diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go
index dc22212e8..b6fc53b3a 100644
--- a/control/controlbase/conn.go
+++ b/control/controlbase/conn.go
@@ -1,408 +1,408 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package controlbase implements the base transport of the Tailscale
-// 2021 control protocol.
-//
-// The base transport implements Noise IK, instantiated with
-// Curve25519, ChaCha20Poly1305 and BLAKE2s.
-package controlbase
-
-import (
- "crypto/cipher"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "time"
-
- "golang.org/x/crypto/blake2s"
- chp "golang.org/x/crypto/chacha20poly1305"
- "tailscale.com/types/key"
-)
-
-const (
- // maxMessageSize is the maximum size of a protocol frame on the
- // wire, including header and payload.
- maxMessageSize = 4096
- // maxCiphertextSize is the maximum amount of ciphertext bytes
- // that one protocol frame can carry, after framing.
- maxCiphertextSize = maxMessageSize - 3
- // maxPlaintextSize is the maximum amount of plaintext bytes that
- // one protocol frame can carry, after encryption and framing.
- maxPlaintextSize = maxCiphertextSize - chp.Overhead
-)
-
-// A Conn is a secured Noise connection. It implements the net.Conn
-// interface, with the unusual trait that any write error (including a
-// SetWriteDeadline induced i/o timeout) causes all future writes to
-// fail.
-type Conn struct {
- conn net.Conn
- version uint16
- peer key.MachinePublic
- handshakeHash [blake2s.Size]byte
- rx rxState
- tx txState
-}
-
-// rxState is all the Conn state that Read uses.
-type rxState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- buf *maxMsgBuffer // or nil when reads exhausted
- n int // number of valid bytes in buf
- next int // offset of next undecrypted packet
- plaintext []byte // slice into buf of decrypted bytes
- hdrBuf [headerLen]byte // small buffer used when buf is nil
-}
-
-// txState is all the Conn state that Write uses.
-type txState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- err error // records the first partial write error for all future calls
-}
-
-// ProtocolVersion returns the protocol version that was used to
-// establish this Conn.
-func (c *Conn) ProtocolVersion() int {
- return int(c.version)
-}
-
-// HandshakeHash returns the Noise handshake hash for the connection,
-// which can be used to bind other messages to this connection
-// (i.e. to ensure that the message wasn't replayed from a different
-// connection).
-func (c *Conn) HandshakeHash() [blake2s.Size]byte {
- return c.handshakeHash
-}
-
-// Peer returns the peer's long-term public key.
-func (c *Conn) Peer() key.MachinePublic {
- return c.peer
-}
-
-// readNLocked reads into c.rx.buf until buf contains at least total
-// bytes. Returns a slice of the total bytes in rxBuf, or an
-// error if fewer than total bytes are available.
-//
-// It may be called with a nil c.rx.buf only if total == headerLen.
-//
-// On success, c.rx.buf will be non-nil.
-func (c *Conn) readNLocked(total int) ([]byte, error) {
- if total > maxMessageSize {
- return nil, errReadTooBig{total}
- }
- for {
- if total <= c.rx.n {
- return c.rx.buf[:total], nil
- }
- var n int
- var err error
- if c.rx.buf == nil {
- if c.rx.n != 0 || total != headerLen {
- panic("unexpected")
- }
- // Optimization to reduce memory usage.
- // Most connections are blocked forever waiting for
- // a read, so we don't want c.rx.buf to be allocated until
- // we know there's data to read. Instead, when we're
- // waiting for data to arrive here, read into the
- // 3 byte hdrBuf:
- n, err = c.conn.Read(c.rx.hdrBuf[:])
- if n > 0 {
- c.rx.buf = getMaxMsgBuffer()
- copy(c.rx.buf[:], c.rx.hdrBuf[:n])
- }
- } else {
- n, err = c.conn.Read(c.rx.buf[c.rx.n:])
- }
- c.rx.n += n
- if err != nil {
- return nil, err
- }
- }
-}
-
-// decryptLocked decrypts msg (which is header+ciphertext) in-place
-// and sets c.rx.plaintext to the decrypted bytes.
-func (c *Conn) decryptLocked(msg []byte) (err error) {
- if msgType := msg[0]; msgType != msgTypeRecord {
- return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
- }
- // We don't check the length field here, because the caller
- // already did in order to figure out how big the msg slice should
- // be.
- ciphertext := msg[headerLen:]
-
- if !c.rx.nonce.Valid() {
- return errCipherExhausted{}
- }
-
- c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
- c.rx.nonce.Increment()
-
- if err != nil {
- // Once a decryption has failed, our Conn is no longer
- // synchronized with our peer. Nuke the cipher state to be
- // safe, so that no further decryptions are attempted. Future
- // read attempts will return net.ErrClosed.
- c.rx.cipher = nil
- }
- return err
-}
-
-// encryptLocked encrypts plaintext into buf (including the
-// packet header) and returns a slice of the ciphertext, or an error
-// if the cipher is exhausted (i.e. can no longer be used safely).
-func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
- if !c.tx.nonce.Valid() {
- // Received 2^64-1 messages on this cipher state. Connection
- // is no longer usable.
- return nil, errCipherExhausted{}
- }
-
- buf[0] = msgTypeRecord
- binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
- ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
- c.tx.nonce.Increment()
-
- return ret, nil
-}
-
-// wholeMessageLocked returns a slice of one whole Noise transport
-// message from c.rx.buf, if one whole message is available, and
-// advances the read state to the next Noise message in the
-// buffer. Returns nil without advancing read state if there isn't one
-// whole message in c.rx.buf.
-func (c *Conn) wholeMessageLocked() []byte {
- available := c.rx.n - c.rx.next
- if available < headerLen {
- return nil
- }
- bs := c.rx.buf[c.rx.next:c.rx.n]
- totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- if len(bs) < totalSize {
- return nil
- }
- c.rx.next += totalSize
- return bs[:totalSize]
-}
-
-// decryptOneLocked decrypts one Noise transport message, reading from
-// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
-// bytes. c.rx.plaintext is only valid if err == nil.
-func (c *Conn) decryptOneLocked() error {
- c.rx.plaintext = nil
-
- // Fast path: do we have one whole ciphertext frame buffered
- // already?
- if bs := c.wholeMessageLocked(); bs != nil {
- return c.decryptLocked(bs)
- }
-
- if c.rx.next != 0 {
- // To simplify the read logic, move the remainder of the
- // buffered bytes back to the head of the buffer, so we can
- // grow it without worrying about wraparound.
- c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
- c.rx.next = 0
- }
-
- // Return our buffer to the pool if it's empty, lest we be
- // blocked in a long Read call, reading the 3 byte header. We
- // don't to keep that buffer unnecessarily alive.
- if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
- bufPool.Put(c.rx.buf)
- c.rx.buf = nil
- }
-
- bs, err := c.readNLocked(headerLen)
- if err != nil {
- return err
- }
- // The rest of the header (besides the length field) gets verified
- // in decryptLocked, not here.
- messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- bs, err = c.readNLocked(messageLen)
- if err != nil {
- return err
- }
-
- c.rx.next = len(bs)
-
- return c.decryptLocked(bs)
-}
-
-// Read implements io.Reader.
-func (c *Conn) Read(bs []byte) (int, error) {
- c.rx.Lock()
- defer c.rx.Unlock()
-
- if c.rx.cipher == nil {
- return 0, net.ErrClosed
- }
- // If no plaintext is buffered, decrypt incoming frames until we
- // have some plaintext. Zero-byte Noise frames are allowed in this
- // protocol, which is why we have to loop here rather than decrypt
- // a single additional frame.
- for len(c.rx.plaintext) == 0 {
- if err := c.decryptOneLocked(); err != nil {
- return 0, err
- }
- }
- n := copy(bs, c.rx.plaintext)
- c.rx.plaintext = c.rx.plaintext[n:]
-
- // Lose slice's underlying array pointer to unneeded memory so
- // GC can collect more.
- if len(c.rx.plaintext) == 0 {
- c.rx.plaintext = nil
- }
- return n, nil
-}
-
-// Write implements io.Writer.
-func (c *Conn) Write(bs []byte) (n int, err error) {
- c.tx.Lock()
- defer c.tx.Unlock()
-
- if c.tx.err != nil {
- return 0, c.tx.err
- }
- defer func() {
- if err != nil {
- // All write errors are fatal for this conn, so clear the
- // cipher state whenever an error happens.
- c.tx.cipher = nil
- }
- if c.tx.err == nil {
- // Only set c.tx.err if not nil so that we can return one
- // error on the first failure, and a different one for
- // subsequent calls. See the error handling around Write
- // below for why.
- c.tx.err = err
- }
- }()
-
- if c.tx.cipher == nil {
- return 0, net.ErrClosed
- }
-
- buf := getMaxMsgBuffer()
- defer bufPool.Put(buf)
-
- var sent int
- for len(bs) > 0 {
- toSend := bs
- if len(toSend) > maxPlaintextSize {
- toSend = bs[:maxPlaintextSize]
- }
- bs = bs[len(toSend):]
-
- ciphertext, err := c.encryptLocked(toSend, buf)
- if err != nil {
- return sent, err
- }
- if _, err := c.conn.Write(ciphertext); err != nil {
- // Return the raw error on the Write that actually
- // failed. For future writes, return that error wrapped in
- // a desync error.
- c.tx.err = errPartialWrite{err}
- return sent, err
- }
- sent += len(toSend)
- }
- return sent, nil
-}
-
-// Close implements io.Closer.
-func (c *Conn) Close() error {
- closeErr := c.conn.Close() // unblocks any waiting reads or writes
-
- // Remove references to live cipher state. Strictly speaking this
- // is unnecessary, but we want to try and hand the active cipher
- // state to the garbage collector promptly, to preserve perfect
- // forward secrecy as much as we can.
- c.rx.Lock()
- c.rx.cipher = nil
- c.rx.Unlock()
- c.tx.Lock()
- c.tx.cipher = nil
- c.tx.Unlock()
- return closeErr
-}
-
-func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
-func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
-func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
-func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
-func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
-
-// errCipherExhausted is the error returned when we run out of nonces
-// on a cipher.
-type errCipherExhausted struct{}
-
-func (errCipherExhausted) Error() string {
- return "cipher exhausted, no more nonces available for current key"
-}
-func (errCipherExhausted) Timeout() bool { return false }
-func (errCipherExhausted) Temporary() bool { return false }
-
-// errPartialWrite is the error returned when the cipher state has
-// become unusable due to a past partial write.
-type errPartialWrite struct {
- err error
-}
-
-func (e errPartialWrite) Error() string {
- return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
-}
-func (e errPartialWrite) Unwrap() error { return e.err }
-func (e errPartialWrite) Temporary() bool { return false }
-func (e errPartialWrite) Timeout() bool { return false }
-
-// errReadTooBig is the error returned when the peer sent an
-// unacceptably large Noise frame.
-type errReadTooBig struct {
- requested int
-}
-
-func (e errReadTooBig) Error() string {
- return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
-}
-func (e errReadTooBig) Temporary() bool {
- // permanent error because this error only occurs when our peer
- // sends us a frame so large we're unwilling to ever decode it.
- return false
-}
-func (e errReadTooBig) Timeout() bool { return false }
-
-type nonce [chp.NonceSize]byte
-
-func (n *nonce) Valid() bool {
- return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
-}
-
-func (n *nonce) Increment() {
- if !n.Valid() {
- panic("increment of invalid nonce")
- }
- binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
-}
-
-type maxMsgBuffer [maxMessageSize]byte
-
-// bufPool holds the temporary buffers for Conn.Read & Write.
-var bufPool = &sync.Pool{
- New: func() any {
- return new(maxMsgBuffer)
- },
-}
-
-func getMaxMsgBuffer() *maxMsgBuffer {
- return bufPool.Get().(*maxMsgBuffer)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package controlbase implements the base transport of the Tailscale
+// 2021 control protocol.
+//
+// The base transport implements Noise IK, instantiated with
+// Curve25519, ChaCha20Poly1305 and BLAKE2s.
+package controlbase
+
+import (
+ "crypto/cipher"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "tailscale.com/types/key"
+)
+
+const (
+ // maxMessageSize is the maximum size of a protocol frame on the
+ // wire, including header and payload.
+ maxMessageSize = 4096
+ // maxCiphertextSize is the maximum amount of ciphertext bytes
+ // that one protocol frame can carry, after framing.
+ maxCiphertextSize = maxMessageSize - 3
+ // maxPlaintextSize is the maximum amount of plaintext bytes that
+ // one protocol frame can carry, after encryption and framing.
+ maxPlaintextSize = maxCiphertextSize - chp.Overhead
+)
+
+// A Conn is a secured Noise connection. It implements the net.Conn
+// interface, with the unusual trait that any write error (including a
+// SetWriteDeadline induced i/o timeout) causes all future writes to
+// fail.
+type Conn struct {
+ conn net.Conn
+ version uint16
+ peer key.MachinePublic
+ handshakeHash [blake2s.Size]byte
+ rx rxState
+ tx txState
+}
+
+// rxState is all the Conn state that Read uses.
+type rxState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ buf *maxMsgBuffer // or nil when reads exhausted
+ n int // number of valid bytes in buf
+ next int // offset of next undecrypted packet
+ plaintext []byte // slice into buf of decrypted bytes
+ hdrBuf [headerLen]byte // small buffer used when buf is nil
+}
+
+// txState is all the Conn state that Write uses.
+type txState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ err error // records the first partial write error for all future calls
+}
+
+// ProtocolVersion returns the protocol version that was used to
+// establish this Conn.
+func (c *Conn) ProtocolVersion() int {
+ return int(c.version)
+}
+
+// HandshakeHash returns the Noise handshake hash for the connection,
+// which can be used to bind other messages to this connection
+// (i.e. to ensure that the message wasn't replayed from a different
+// connection).
+func (c *Conn) HandshakeHash() [blake2s.Size]byte {
+ return c.handshakeHash
+}
+
+// Peer returns the peer's long-term public key.
+func (c *Conn) Peer() key.MachinePublic {
+ return c.peer
+}
+
+// readNLocked reads into c.rx.buf until buf contains at least total
+// bytes. Returns a slice of the total bytes in rxBuf, or an
+// error if fewer than total bytes are available.
+//
+// It may be called with a nil c.rx.buf only if total == headerLen.
+//
+// On success, c.rx.buf will be non-nil.
+func (c *Conn) readNLocked(total int) ([]byte, error) {
+ if total > maxMessageSize {
+ return nil, errReadTooBig{total}
+ }
+ for {
+ if total <= c.rx.n {
+ return c.rx.buf[:total], nil
+ }
+ var n int
+ var err error
+ if c.rx.buf == nil {
+ if c.rx.n != 0 || total != headerLen {
+ panic("unexpected")
+ }
+ // Optimization to reduce memory usage.
+ // Most connections are blocked forever waiting for
+ // a read, so we don't want c.rx.buf to be allocated until
+ // we know there's data to read. Instead, when we're
+ // waiting for data to arrive here, read into the
+ // 3 byte hdrBuf:
+ n, err = c.conn.Read(c.rx.hdrBuf[:])
+ if n > 0 {
+ c.rx.buf = getMaxMsgBuffer()
+ copy(c.rx.buf[:], c.rx.hdrBuf[:n])
+ }
+ } else {
+ n, err = c.conn.Read(c.rx.buf[c.rx.n:])
+ }
+ c.rx.n += n
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+// decryptLocked decrypts msg (which is header+ciphertext) in-place
+// and sets c.rx.plaintext to the decrypted bytes.
+func (c *Conn) decryptLocked(msg []byte) (err error) {
+ if msgType := msg[0]; msgType != msgTypeRecord {
+ return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
+ }
+ // We don't check the length field here, because the caller
+ // already did in order to figure out how big the msg slice should
+ // be.
+ ciphertext := msg[headerLen:]
+
+ if !c.rx.nonce.Valid() {
+ return errCipherExhausted{}
+ }
+
+ c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
+ c.rx.nonce.Increment()
+
+ if err != nil {
+ // Once a decryption has failed, our Conn is no longer
+ // synchronized with our peer. Nuke the cipher state to be
+ // safe, so that no further decryptions are attempted. Future
+ // read attempts will return net.ErrClosed.
+ c.rx.cipher = nil
+ }
+ return err
+}
+
+// encryptLocked encrypts plaintext into buf (including the
+// packet header) and returns a slice of the ciphertext, or an error
+// if the cipher is exhausted (i.e. can no longer be used safely).
+func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
+ if !c.tx.nonce.Valid() {
+ // Received 2^64-1 messages on this cipher state. Connection
+ // is no longer usable.
+ return nil, errCipherExhausted{}
+ }
+
+ buf[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
+ ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
+ c.tx.nonce.Increment()
+
+ return ret, nil
+}
+
+// wholeMessageLocked returns a slice of one whole Noise transport
+// message from c.rx.buf, if one whole message is available, and
+// advances the read state to the next Noise message in the
+// buffer. Returns nil without advancing read state if there isn't one
+// whole message in c.rx.buf.
+func (c *Conn) wholeMessageLocked() []byte {
+ available := c.rx.n - c.rx.next
+ if available < headerLen {
+ return nil
+ }
+ bs := c.rx.buf[c.rx.next:c.rx.n]
+ totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ if len(bs) < totalSize {
+ return nil
+ }
+ c.rx.next += totalSize
+ return bs[:totalSize]
+}
+
+// decryptOneLocked decrypts one Noise transport message, reading from
+// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
+// bytes. c.rx.plaintext is only valid if err == nil.
+func (c *Conn) decryptOneLocked() error {
+ c.rx.plaintext = nil
+
+ // Fast path: do we have one whole ciphertext frame buffered
+ // already?
+ if bs := c.wholeMessageLocked(); bs != nil {
+ return c.decryptLocked(bs)
+ }
+
+ if c.rx.next != 0 {
+ // To simplify the read logic, move the remainder of the
+ // buffered bytes back to the head of the buffer, so we can
+ // grow it without worrying about wraparound.
+ c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
+ c.rx.next = 0
+ }
+
+ // Return our buffer to the pool if it's empty, lest we be
+ // blocked in a long Read call, reading the 3 byte header. We
+ // don't to keep that buffer unnecessarily alive.
+ if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
+ bufPool.Put(c.rx.buf)
+ c.rx.buf = nil
+ }
+
+ bs, err := c.readNLocked(headerLen)
+ if err != nil {
+ return err
+ }
+ // The rest of the header (besides the length field) gets verified
+ // in decryptLocked, not here.
+ messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ bs, err = c.readNLocked(messageLen)
+ if err != nil {
+ return err
+ }
+
+ c.rx.next = len(bs)
+
+ return c.decryptLocked(bs)
+}
+
+// Read implements io.Reader.
+func (c *Conn) Read(bs []byte) (int, error) {
+ c.rx.Lock()
+ defer c.rx.Unlock()
+
+ if c.rx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+ // If no plaintext is buffered, decrypt incoming frames until we
+ // have some plaintext. Zero-byte Noise frames are allowed in this
+ // protocol, which is why we have to loop here rather than decrypt
+ // a single additional frame.
+ for len(c.rx.plaintext) == 0 {
+ if err := c.decryptOneLocked(); err != nil {
+ return 0, err
+ }
+ }
+ n := copy(bs, c.rx.plaintext)
+ c.rx.plaintext = c.rx.plaintext[n:]
+
+ // Lose slice's underlying array pointer to unneeded memory so
+ // GC can collect more.
+ if len(c.rx.plaintext) == 0 {
+ c.rx.plaintext = nil
+ }
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (c *Conn) Write(bs []byte) (n int, err error) {
+ c.tx.Lock()
+ defer c.tx.Unlock()
+
+ if c.tx.err != nil {
+ return 0, c.tx.err
+ }
+ defer func() {
+ if err != nil {
+ // All write errors are fatal for this conn, so clear the
+ // cipher state whenever an error happens.
+ c.tx.cipher = nil
+ }
+ if c.tx.err == nil {
+ // Only set c.tx.err if not nil so that we can return one
+ // error on the first failure, and a different one for
+ // subsequent calls. See the error handling around Write
+ // below for why.
+ c.tx.err = err
+ }
+ }()
+
+ if c.tx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+
+ buf := getMaxMsgBuffer()
+ defer bufPool.Put(buf)
+
+ var sent int
+ for len(bs) > 0 {
+ toSend := bs
+ if len(toSend) > maxPlaintextSize {
+ toSend = bs[:maxPlaintextSize]
+ }
+ bs = bs[len(toSend):]
+
+ ciphertext, err := c.encryptLocked(toSend, buf)
+ if err != nil {
+ return sent, err
+ }
+ if _, err := c.conn.Write(ciphertext); err != nil {
+ // Return the raw error on the Write that actually
+ // failed. For future writes, return that error wrapped in
+ // a desync error.
+ c.tx.err = errPartialWrite{err}
+ return sent, err
+ }
+ sent += len(toSend)
+ }
+ return sent, nil
+}
+
+// Close implements io.Closer.
+func (c *Conn) Close() error {
+ closeErr := c.conn.Close() // unblocks any waiting reads or writes
+
+ // Remove references to live cipher state. Strictly speaking this
+ // is unnecessary, but we want to try and hand the active cipher
+ // state to the garbage collector promptly, to preserve perfect
+ // forward secrecy as much as we can.
+ c.rx.Lock()
+ c.rx.cipher = nil
+ c.rx.Unlock()
+ c.tx.Lock()
+ c.tx.cipher = nil
+ c.tx.Unlock()
+ return closeErr
+}
+
+func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
+func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
+func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
+func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
+func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
+
+// errCipherExhausted is the error returned when we run out of nonces
+// on a cipher.
+type errCipherExhausted struct{}
+
+func (errCipherExhausted) Error() string {
+ return "cipher exhausted, no more nonces available for current key"
+}
+func (errCipherExhausted) Timeout() bool { return false }
+func (errCipherExhausted) Temporary() bool { return false }
+
+// errPartialWrite is the error returned when the cipher state has
+// become unusable due to a past partial write.
+type errPartialWrite struct {
+ err error
+}
+
+func (e errPartialWrite) Error() string {
+ return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
+}
+func (e errPartialWrite) Unwrap() error { return e.err }
+func (e errPartialWrite) Temporary() bool { return false }
+func (e errPartialWrite) Timeout() bool { return false }
+
+// errReadTooBig is the error returned when the peer sent an
+// unacceptably large Noise frame.
+type errReadTooBig struct {
+ requested int
+}
+
+func (e errReadTooBig) Error() string {
+ return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
+}
+func (e errReadTooBig) Temporary() bool {
+ // permanent error because this error only occurs when our peer
+ // sends us a frame so large we're unwilling to ever decode it.
+ return false
+}
+func (e errReadTooBig) Timeout() bool { return false }
+
+type nonce [chp.NonceSize]byte
+
+func (n *nonce) Valid() bool {
+ return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
+}
+
+func (n *nonce) Increment() {
+ if !n.Valid() {
+ panic("increment of invalid nonce")
+ }
+ binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
+}
+
+type maxMsgBuffer [maxMessageSize]byte
+
+// bufPool holds the temporary buffers for Conn.Read & Write.
+var bufPool = &sync.Pool{
+ New: func() any {
+ return new(maxMsgBuffer)
+ },
+}
+
+func getMaxMsgBuffer() *maxMsgBuffer {
+ return bufPool.Get().(*maxMsgBuffer)
+}
diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go
index 765a4620b..937969a30 100644
--- a/control/controlbase/handshake.go
+++ b/control/controlbase/handshake.go
@@ -1,494 +1,494 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package controlbase
-
-import (
- "context"
- "crypto/cipher"
- "encoding/binary"
- "errors"
- "fmt"
- "hash"
- "io"
- "net"
- "strconv"
- "time"
-
- "go4.org/mem"
- "golang.org/x/crypto/blake2s"
- chp "golang.org/x/crypto/chacha20poly1305"
- "golang.org/x/crypto/curve25519"
- "golang.org/x/crypto/hkdf"
- "tailscale.com/types/key"
-)
-
-const (
- // protocolName is the name of the specific instantiation of Noise
- // that the control protocol uses. This string's value is fixed by
- // the Noise spec, and shouldn't be changed unless we're updating
- // the control protocol to use a different Noise instance.
- protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
- // protocolVersion is the version of the control protocol that
- // Client will use when initiating a handshake.
- //protocolVersion uint16 = 1
- // protocolVersionPrefix is the name portion of the protocol
- // name+version string that gets mixed into the handshake as a
- // prologue.
- //
- // This mixing verifies that both clients agree that they're
- // executing the control protocol at a specific version that
- // matches the advertised version in the cleartext packet header.
- protocolVersionPrefix = "Tailscale Control Protocol v"
- invalidNonce = ^uint64(0)
-)
-
-func protocolVersionPrologue(version uint16) []byte {
- ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
- ret = append(ret, protocolVersionPrefix...)
- return strconv.AppendUint(ret, uint64(version), 10)
-}
-
-// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
-// is assumed to have already sent the client>server handshake
-// initiation message.
-type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
-
-// ClientDeferred initiates a control client handshake, returning the
-// initial message to send to the server and a continuation to
-// finalize the handshake.
-//
-// ClientDeferred is split in this way for RTT reduction: we run this
-// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
-// completely serialized the negotiation followed by the handshake,
-// we'd pay an extra RTT to transmit the handshake initiation after
-// protocol switching. By splitting the handshake into an initial
-// message and a continuation, we can embed the handshake initiation
-// into the HTTP protocol switching request and avoid a bit of delay.
-func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
- var s symmetricState
- s.Initialize()
-
- // prologue
- s.MixHash(protocolVersionPrologue(protocolVersion))
-
- // <- s
- // ...
- s.MixHash(controlKey.UntypedBytes())
-
- // -> e, es, s, ss
- init := mkInitiationMessage(protocolVersion)
- machineEphemeral := key.NewMachine()
- machineEphemeralPub := machineEphemeral.Public()
- copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
- s.MixHash(machineEphemeralPub.UntypedBytes())
- cipher, err := s.MixDH(machineEphemeral, controlKey)
- if err != nil {
- return nil, nil, fmt.Errorf("computing es: %w", err)
- }
- machineKeyPub := machineKey.Public()
- s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
- cipher, err = s.MixDH(machineKey, controlKey)
- if err != nil {
- return nil, nil, fmt.Errorf("computing ss: %w", err)
- }
- s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
-
- cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
- return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
- }
- return init[:], cont, nil
-}
-
-// Client wraps ClientDeferred and immediately invokes the returned
-// continuation with conn.
-//
-// This is a helper for when you don't need the fancy
-// continuation-style handshake, and just want to synchronously
-// upgrade a net.Conn to a secure transport.
-func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
- init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
- if err != nil {
- return nil, err
- }
- if _, err := conn.Write(init); err != nil {
- return nil, err
- }
- return cont(ctx, conn)
-}
-
-func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
- // No matter what, this function can only run once per s. Ensure
- // attempted reuse causes a panic.
- defer func() {
- s.finished = true
- }()
-
- if deadline, ok := ctx.Deadline(); ok {
- if err := conn.SetDeadline(deadline); err != nil {
- return nil, fmt.Errorf("setting conn deadline: %w", err)
- }
- defer func() {
- conn.SetDeadline(time.Time{})
- }()
- }
-
- // Read in the payload and look for errors/protocol violations from the server.
- var resp responseMessage
- if _, err := io.ReadFull(conn, resp.Header()); err != nil {
- return nil, fmt.Errorf("reading response header: %w", err)
- }
- if resp.Type() != msgTypeResponse {
- if resp.Type() != msgTypeError {
- return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
- }
- msg := make([]byte, resp.Length())
- if _, err := io.ReadFull(conn, msg); err != nil {
- return nil, err
- }
- return nil, fmt.Errorf("server error: %q", msg)
- }
- if resp.Length() != len(resp.Payload()) {
- return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
- }
- if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
- return nil, err
- }
-
- // <- e, ee, se
- controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
- s.MixHash(controlEphemeralPub.UntypedBytes())
- if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
- return nil, fmt.Errorf("computing ee: %w", err)
- }
- cipher, err := s.MixDH(machineKey, controlEphemeralPub)
- if err != nil {
- return nil, fmt.Errorf("computing se: %w", err)
- }
- if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
- return nil, fmt.Errorf("decrypting payload: %w", err)
- }
-
- c1, c2, err := s.Split()
- if err != nil {
- return nil, fmt.Errorf("finalizing handshake: %w", err)
- }
-
- c := &Conn{
- conn: conn,
- version: protocolVersion,
- peer: controlKey,
- handshakeHash: s.h,
- tx: txState{
- cipher: c1,
- },
- rx: rxState{
- cipher: c2,
- },
- }
- return c, nil
-}
-
-// Server initiates a control server handshake, returning the resulting
-// control connection.
-//
-// optionalInit can be the client's initial handshake message as
-// returned by ClientDeferred, or nil in which case the initial
-// message is read from conn.
-//
-// The context deadline, if any, covers the entire handshaking
-// process.
-func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
- if deadline, ok := ctx.Deadline(); ok {
- if err := conn.SetDeadline(deadline); err != nil {
- return nil, fmt.Errorf("setting conn deadline: %w", err)
- }
- defer func() {
- conn.SetDeadline(time.Time{})
- }()
- }
-
- // Deliberately does not support formatting, so that we don't echo
- // attacker-controlled input back to them.
- sendErr := func(msg string) error {
- if len(msg) >= 1<<16 {
- msg = msg[:1<<16]
- }
- var hdr [headerLen]byte
- hdr[0] = msgTypeError
- binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
- if _, err := conn.Write(hdr[:]); err != nil {
- return fmt.Errorf("sending %q error to client: %w", msg, err)
- }
- if _, err := io.WriteString(conn, msg); err != nil {
- return fmt.Errorf("sending %q error to client: %w", msg, err)
- }
- return fmt.Errorf("refused client handshake: %q", msg)
- }
-
- var s symmetricState
- s.Initialize()
-
- var init initiationMessage
- if optionalInit != nil {
- if len(optionalInit) != len(init) {
- return nil, sendErr("wrong handshake initiation size")
- }
- copy(init[:], optionalInit)
- } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
- return nil, err
- }
- // Just a rename to make it more obvious what the value is. In the
- // current implementation we don't need to block any protocol
- // versions at this layer, it's safe to let the handshake proceed
- // and then let the caller make decisions based on the agreed-upon
- // protocol version.
- clientVersion := init.Version()
- if init.Type() != msgTypeInitiation {
- return nil, sendErr("unexpected handshake message type")
- }
- if init.Length() != len(init.Payload()) {
- return nil, sendErr("wrong handshake initiation length")
- }
- // if optionalInit was provided, we have the payload already.
- if optionalInit == nil {
- if _, err := io.ReadFull(conn, init.Payload()); err != nil {
- return nil, err
- }
- }
-
- // prologue. Can only do this once we at least think the client is
- // handshaking using a supported version.
- s.MixHash(protocolVersionPrologue(clientVersion))
-
- // <- s
- // ...
- controlKeyPub := controlKey.Public()
- s.MixHash(controlKeyPub.UntypedBytes())
-
- // -> e, es, s, ss
- machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
- s.MixHash(machineEphemeralPub.UntypedBytes())
- cipher, err := s.MixDH(controlKey, machineEphemeralPub)
- if err != nil {
- return nil, fmt.Errorf("computing es: %w", err)
- }
- var machineKeyBytes [32]byte
- if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
- return nil, fmt.Errorf("decrypting machine key: %w", err)
- }
- machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
- cipher, err = s.MixDH(controlKey, machineKey)
- if err != nil {
- return nil, fmt.Errorf("computing ss: %w", err)
- }
- if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
- return nil, fmt.Errorf("decrypting initiation tag: %w", err)
- }
-
- // <- e, ee, se
- resp := mkResponseMessage()
- controlEphemeral := key.NewMachine()
- controlEphemeralPub := controlEphemeral.Public()
- copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
- s.MixHash(controlEphemeralPub.UntypedBytes())
- if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
- return nil, fmt.Errorf("computing ee: %w", err)
- }
- cipher, err = s.MixDH(controlEphemeral, machineKey)
- if err != nil {
- return nil, fmt.Errorf("computing se: %w", err)
- }
- s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
-
- c1, c2, err := s.Split()
- if err != nil {
- return nil, fmt.Errorf("finalizing handshake: %w", err)
- }
-
- if _, err := conn.Write(resp[:]); err != nil {
- return nil, err
- }
-
- c := &Conn{
- conn: conn,
- version: clientVersion,
- peer: machineKey,
- handshakeHash: s.h,
- tx: txState{
- cipher: c2,
- },
- rx: rxState{
- cipher: c1,
- },
- }
- return c, nil
-}
-
-// symmetricState contains the state of an in-flight handshake.
-type symmetricState struct {
- finished bool
-
- h [blake2s.Size]byte // hash of currently-processed handshake state
- ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
-}
-
-func (s *symmetricState) checkFinished() {
- if s.finished {
- panic("attempted to use symmetricState after Split was called")
- }
-}
-
-// Initialize sets s to the initial handshake state, prior to
-// processing any handshake messages.
-func (s *symmetricState) Initialize() {
- s.checkFinished()
- s.h = blake2s.Sum256([]byte(protocolName))
- s.ck = s.h
-}
-
-// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
-// concatenation.
-func (s *symmetricState) MixHash(data []byte) {
- s.checkFinished()
- h := newBLAKE2s()
- h.Write(s.h[:])
- h.Write(data)
- h.Sum(s.h[:0])
-}
-
-// MixDH updates s.ck with the result of X25519(priv, pub) and returns
-// a singleUseCHP that can be used to encrypt or decrypt handshake
-// data.
-//
-// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
-// it as a single function allows for strongly-typed arguments that
-// reduce the risk of error in the caller (e.g. invoking X25519 with
-// two private keys, or two public keys), and thus producing the wrong
-// calculation.
-func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
- s.checkFinished()
- keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
- if err != nil {
- return nil, fmt.Errorf("computing X25519: %w", err)
- }
-
- r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
- if _, err := io.ReadFull(r, s.ck[:]); err != nil {
- return nil, fmt.Errorf("extracting ck: %w", err)
- }
- var k [chp.KeySize]byte
- if _, err := io.ReadFull(r, k[:]); err != nil {
- return nil, fmt.Errorf("extracting k: %w", err)
- }
- return newSingleUseCHP(k), nil
-}
-
-// EncryptAndHash encrypts plaintext into ciphertext (which must be
-// the correct size to hold the encrypted plaintext) using cipher,
-// mixes the ciphertext into s.h, and returns the ciphertext.
-func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
- s.checkFinished()
- if len(ciphertext) != len(plaintext)+chp.Overhead {
- panic("ciphertext is wrong size for given plaintext")
- }
- ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
- s.MixHash(ret)
-}
-
-// DecryptAndHash decrypts the given ciphertext into plaintext (which
-// must be the correct size to hold the decrypted ciphertext) using
-// cipher. If decryption is successful, it mixes the ciphertext into
-// s.h.
-func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
- s.checkFinished()
- if len(ciphertext) != len(plaintext)+chp.Overhead {
- return errors.New("plaintext is wrong size for given ciphertext")
- }
- if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
- return err
- }
- s.MixHash(ciphertext)
- return nil
-}
-
-// Split returns two ChaCha20Poly1305 ciphers with keys derived from
-// the current handshake state. Methods on s cannot be used again
-// after calling Split.
-func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
- s.finished = true
-
- var k1, k2 [chp.KeySize]byte
- r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
- if _, err := io.ReadFull(r, k1[:]); err != nil {
- return nil, nil, fmt.Errorf("extracting k1: %w", err)
- }
- if _, err := io.ReadFull(r, k2[:]); err != nil {
- return nil, nil, fmt.Errorf("extracting k2: %w", err)
- }
- c1, err = chp.New(k1[:])
- if err != nil {
- return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
- }
- c2, err = chp.New(k2[:])
- if err != nil {
- return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
- }
- return c1, c2, nil
-}
-
-// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
-// error.
-func newBLAKE2s() hash.Hash {
- h, err := blake2s.New256(nil)
- if err != nil {
- // Should never happen, errors only happen when using BLAKE2s
- // in MAC mode with a key.
- panic(err)
- }
- return h
-}
-
-// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
-// panics on error.
-func newCHP(key [chp.KeySize]byte) cipher.AEAD {
- aead, err := chp.New(key[:])
- if err != nil {
- // Can only happen if we passed a key of the wrong length. The
- // function signature prevents that.
- panic(err)
- }
- return aead
-}
-
-// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
-// only once, either for encrypting or decrypting, but not both. The
-// chosen operation is always executed with an all-zeros
-// nonce. Subsequent calls to either Seal or Open panic.
-type singleUseCHP struct {
- c cipher.AEAD
-}
-
-func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
- return &singleUseCHP{newCHP(key)}
-}
-
-func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
- if c.c == nil {
- panic("Attempted reuse of singleUseAEAD")
- }
- cipher := c.c
- c.c = nil
- var nonce [chp.NonceSize]byte
- return cipher.Seal(dst, nonce[:], plaintext, additionalData)
-}
-
-func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
- if c.c == nil {
- panic("Attempted reuse of singleUseAEAD")
- }
- cipher := c.c
- c.c = nil
- var nonce [chp.NonceSize]byte
- return cipher.Open(dst, nonce[:], ciphertext, additionalData)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ "go4.org/mem"
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+ "tailscale.com/types/key"
+)
+
+const (
+ // protocolName is the name of the specific instantiation of Noise
+ // that the control protocol uses. This string's value is fixed by
+ // the Noise spec, and shouldn't be changed unless we're updating
+ // the control protocol to use a different Noise instance.
+ protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+ // protocolVersion is the version of the control protocol that
+ // Client will use when initiating a handshake.
+ //protocolVersion uint16 = 1
+ // protocolVersionPrefix is the name portion of the protocol
+ // name+version string that gets mixed into the handshake as a
+ // prologue.
+ //
+ // This mixing verifies that both clients agree that they're
+ // executing the control protocol at a specific version that
+ // matches the advertised version in the cleartext packet header.
+ protocolVersionPrefix = "Tailscale Control Protocol v"
+ invalidNonce = ^uint64(0)
+)
+
+func protocolVersionPrologue(version uint16) []byte {
+ ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
+ ret = append(ret, protocolVersionPrefix...)
+ return strconv.AppendUint(ret, uint64(version), 10)
+}
+
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
+// is assumed to have already sent the client>server handshake
+// initiation message.
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
+
+// ClientDeferred initiates a control client handshake, returning the
+// initial message to send to the server and a continuation to
+// finalize the handshake.
+//
+// ClientDeferred is split in this way for RTT reduction: we run this
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
+// completely serialized the negotiation followed by the handshake,
+// we'd pay an extra RTT to transmit the handshake initiation after
+// protocol switching. By splitting the handshake into an initial
+// message and a continuation, we can embed the handshake initiation
+// into the HTTP protocol switching request and avoid a bit of delay.
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
+ var s symmetricState
+ s.Initialize()
+
+ // prologue
+ s.MixHash(protocolVersionPrologue(protocolVersion))
+
+ // <- s
+ // ...
+ s.MixHash(controlKey.UntypedBytes())
+
+ // -> e, es, s, ss
+ init := mkInitiationMessage(protocolVersion)
+ machineEphemeral := key.NewMachine()
+ machineEphemeralPub := machineEphemeral.Public()
+ copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(machineEphemeral, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing es: %w", err)
+ }
+ machineKeyPub := machineKey.Public()
+ s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
+ cipher, err = s.MixDH(machineKey, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
+ }
+ s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
+
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
+ }
+ return init[:], cont, nil
+}
+
+// Client wraps ClientDeferred and immediately invokes the returned
+// continuation with conn.
+//
+// This is a helper for when you don't need the fancy
+// continuation-style handshake, and just want to synchronously
+// upgrade a net.Conn to a secure transport.
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(init); err != nil {
+ return nil, err
+ }
+ return cont(ctx, conn)
+}
+
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ // No matter what, this function can only run once per s. Ensure
+ // attempted reuse causes a panic.
+ defer func() {
+ s.finished = true
+ }()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Read in the payload and look for errors/protocol violations from the server.
+ var resp responseMessage
+ if _, err := io.ReadFull(conn, resp.Header()); err != nil {
+ return nil, fmt.Errorf("reading response header: %w", err)
+ }
+ if resp.Type() != msgTypeResponse {
+ if resp.Type() != msgTypeError {
+ return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
+ }
+ msg := make([]byte, resp.Length())
+ if _, err := io.ReadFull(conn, msg); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("server error: %q", msg)
+ }
+ if resp.Length() != len(resp.Payload()) {
+ return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
+ }
+ if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
+ return nil, err
+ }
+
+ // <- e, ee, se
+ controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting payload: %w", err)
+ }
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: protocolVersion,
+ peer: controlKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c1,
+ },
+ rx: rxState{
+ cipher: c2,
+ },
+ }
+ return c, nil
+}
+
+// Server initiates a control server handshake, returning the resulting
+// control connection.
+//
+// optionalInit can be the client's initial handshake message as
+// returned by ClientDeferred, or nil in which case the initial
+// message is read from conn.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Deliberately does not support formatting, so that we don't echo
+ // attacker-controlled input back to them.
+ sendErr := func(msg string) error {
+ if len(msg) >= 1<<16 {
+ msg = msg[:1<<16]
+ }
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeError
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ if _, err := io.WriteString(conn, msg); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ return fmt.Errorf("refused client handshake: %q", msg)
+ }
+
+ var s symmetricState
+ s.Initialize()
+
+ var init initiationMessage
+ if optionalInit != nil {
+ if len(optionalInit) != len(init) {
+ return nil, sendErr("wrong handshake initiation size")
+ }
+ copy(init[:], optionalInit)
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
+ return nil, err
+ }
+ // Just a rename to make it more obvious what the value is. In the
+ // current implementation we don't need to block any protocol
+ // versions at this layer, it's safe to let the handshake proceed
+ // and then let the caller make decisions based on the agreed-upon
+ // protocol version.
+ clientVersion := init.Version()
+ if init.Type() != msgTypeInitiation {
+ return nil, sendErr("unexpected handshake message type")
+ }
+ if init.Length() != len(init.Payload()) {
+ return nil, sendErr("wrong handshake initiation length")
+ }
+ // if optionalInit was provided, we have the payload already.
+ if optionalInit == nil {
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
+ return nil, err
+ }
+ }
+
+ // prologue. Can only do this once we at least think the client is
+ // handshaking using a supported version.
+ s.MixHash(protocolVersionPrologue(clientVersion))
+
+ // <- s
+ // ...
+ controlKeyPub := controlKey.Public()
+ s.MixHash(controlKeyPub.UntypedBytes())
+
+ // -> e, es, s, ss
+ machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(controlKey, machineEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing es: %w", err)
+ }
+ var machineKeyBytes [32]byte
+ if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
+ return nil, fmt.Errorf("decrypting machine key: %w", err)
+ }
+ machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
+ cipher, err = s.MixDH(controlKey, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing ss: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+ }
+
+ // <- e, ee, se
+ resp := mkResponseMessage()
+ controlEphemeral := key.NewMachine()
+ controlEphemeralPub := controlEphemeral.Public()
+ copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err = s.MixDH(controlEphemeral, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ if _, err := conn.Write(resp[:]); err != nil {
+ return nil, err
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: clientVersion,
+ peer: machineKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c2,
+ },
+ rx: rxState{
+ cipher: c1,
+ },
+ }
+ return c, nil
+}
+
+// symmetricState contains the state of an in-flight handshake.
+type symmetricState struct {
+ finished bool
+
+ h [blake2s.Size]byte // hash of currently-processed handshake state
+ ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
+}
+
+func (s *symmetricState) checkFinished() {
+ if s.finished {
+ panic("attempted to use symmetricState after Split was called")
+ }
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any handshake messages.
+func (s *symmetricState) Initialize() {
+ s.checkFinished()
+ s.h = blake2s.Sum256([]byte(protocolName))
+ s.ck = s.h
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+ s.checkFinished()
+ h := newBLAKE2s()
+ h.Write(s.h[:])
+ h.Write(data)
+ h.Sum(s.h[:0])
+}
+
+// MixDH updates s.ck with the result of X25519(priv, pub) and returns
+// a singleUseCHP that can be used to encrypt or decrypt handshake
+// data.
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
+ s.checkFinished()
+ keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
+ if err != nil {
+ return nil, fmt.Errorf("computing X25519: %w", err)
+ }
+
+ r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+ if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+ return nil, fmt.Errorf("extracting ck: %w", err)
+ }
+ var k [chp.KeySize]byte
+ if _, err := io.ReadFull(r, k[:]); err != nil {
+ return nil, fmt.Errorf("extracting k: %w", err)
+ }
+ return newSingleUseCHP(k), nil
+}
+
+// EncryptAndHash encrypts plaintext into ciphertext (which must be
+// the correct size to hold the encrypted plaintext) using cipher,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ panic("ciphertext is wrong size for given plaintext")
+ }
+ ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
+ s.MixHash(ret)
+}
+
+// DecryptAndHash decrypts the given ciphertext into plaintext (which
+// must be the correct size to hold the decrypted ciphertext) using
+// cipher. If decryption is successful, it mixes the ciphertext into
+// s.h.
+func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ return errors.New("plaintext is wrong size for given ciphertext")
+ }
+ if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
+ return err
+ }
+ s.MixHash(ciphertext)
+ return nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derived from
+// the current handshake state. Methods on s cannot be used again
+// after calling Split.
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+ s.finished = true
+
+ var k1, k2 [chp.KeySize]byte
+ r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+ if _, err := io.ReadFull(r, k1[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k1: %w", err)
+ }
+ if _, err := io.ReadFull(r, k2[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k2: %w", err)
+ }
+ c1, err = chp.New(k1[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+ }
+ c2, err = chp.New(k2[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+ }
+ return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ // Should never happen, errors only happen when using BLAKE2s
+ // in MAC mode with a key.
+ panic(err)
+ }
+ return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+ aead, err := chp.New(key[:])
+ if err != nil {
+ // Can only happen if we passed a key of the wrong length. The
+ // function signature prevents that.
+ panic(err)
+ }
+ return aead
+}
+
+// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
+// only once, either for encrypting or decrypting, but not both. The
+// chosen operation is always executed with an all-zeros
+// nonce. Subsequent calls to either Seal or Open panic.
+type singleUseCHP struct {
+ c cipher.AEAD
+}
+
+func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
+ return &singleUseCHP{newCHP(key)}
+}
+
+func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Seal(dst, nonce[:], plaintext, additionalData)
+}
+
+func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Open(dst, nonce[:], ciphertext, additionalData)
+}
diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go
index c41fbf4dd..d11c04149 100644
--- a/control/controlbase/interop_test.go
+++ b/control/controlbase/interop_test.go
@@ -1,256 +1,256 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package controlbase
-
-import (
- "context"
- "encoding/binary"
- "errors"
- "io"
- "net"
- "testing"
-
- "tailscale.com/net/memnet"
- "tailscale.com/types/key"
-)
-
-// Can a reference Noise IK client talk to our server?
-func TestInteropClient(t *testing.T) {
- var (
- s1, s2 = memnet.NewConn("noise", 128000)
- controlKey = key.NewMachine()
- machineKey = key.NewMachine()
- serverErr = make(chan error, 2)
- serverBytes = make(chan []byte, 1)
- c2s = "client>server"
- s2c = "server>client"
- )
-
- go func() {
- server, err := Server(context.Background(), s2, controlKey, nil)
- serverErr <- err
- if err != nil {
- return
- }
- var buf [1024]byte
- _, err = io.ReadFull(server, buf[:len(c2s)])
- serverBytes <- buf[:len(c2s)]
- if err != nil {
- serverErr <- err
- return
- }
- _, err = server.Write([]byte(s2c))
- serverErr <- err
- }()
-
- gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
- if err != nil {
- t.Fatalf("failed client interop: %v", err)
- }
- if string(gotS2C) != s2c {
- t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
- }
-
- if err := <-serverErr; err != nil {
- t.Fatalf("server handshake failed: %v", err)
- }
- if err := <-serverErr; err != nil {
- t.Fatalf("server read/write failed: %v", err)
- }
- if got := string(<-serverBytes); got != c2s {
- t.Fatalf("server received %q, want %q", got, c2s)
- }
-}
-
-// Can our client talk to a reference Noise IK server?
-func TestInteropServer(t *testing.T) {
- var (
- s1, s2 = memnet.NewConn("noise", 128000)
- controlKey = key.NewMachine()
- machineKey = key.NewMachine()
- clientErr = make(chan error, 2)
- clientBytes = make(chan []byte, 1)
- c2s = "client>server"
- s2c = "server>client"
- )
-
- go func() {
- client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
- clientErr <- err
- if err != nil {
- return
- }
- _, err = client.Write([]byte(c2s))
- if err != nil {
- clientErr <- err
- return
- }
- var buf [1024]byte
- _, err = io.ReadFull(client, buf[:len(s2c)])
- clientBytes <- buf[:len(s2c)]
- clientErr <- err
- }()
-
- gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
- if err != nil {
- t.Fatalf("failed server interop: %v", err)
- }
- if string(gotC2S) != c2s {
- t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
- }
-
- if err := <-clientErr; err != nil {
- t.Fatalf("client handshake failed: %v", err)
- }
- if err := <-clientErr; err != nil {
- t.Fatalf("client read/write failed: %v", err)
- }
- if got := string(<-clientBytes); got != s2c {
- t.Fatalf("client received %q, want %q", got, s2c)
- }
-}
-
-// noiseExplorerClient uses the Noise Explorer implementation of Noise
-// IK to handshake as a Noise client on conn, transmit payload, and
-// read+return a payload from the peer.
-func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
- var mk keypair
- copy(mk.private_key[:], machineKey.UntypedBytes())
- copy(mk.public_key[:], machineKey.Public().UntypedBytes())
- var peerKey [32]byte
- copy(peerKey[:], controlKey.UntypedBytes())
- session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
-
- _, msg1 := SendMessage(&session, nil)
- var hdr [initiationHeaderLen]byte
- binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
- hdr[2] = msgTypeInitiation
- binary.BigEndian.PutUint16(hdr[3:5], 96)
- if _, err := conn.Write(hdr[:]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg1.ne[:]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg1.ns); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg1.ciphertext); err != nil {
- return nil, err
- }
-
- var buf [1024]byte
- if _, err := io.ReadFull(conn, buf[:51]); err != nil {
- return nil, err
- }
- // ignore the header for this test, we're only checking the noise
- // implementation.
- msg2 := messagebuffer{
- ciphertext: buf[35:51],
- }
- copy(msg2.ne[:], buf[3:35])
- _, p, valid := RecvMessage(&session, &msg2)
- if !valid {
- return nil, errors.New("handshake failed")
- }
- if len(p) != 0 {
- return nil, errors.New("non-empty payload")
- }
-
- _, msg3 := SendMessage(&session, payload)
- hdr[0] = msgTypeRecord
- binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
- if _, err := conn.Write(hdr[:3]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg3.ciphertext); err != nil {
- return nil, err
- }
-
- if _, err := io.ReadFull(conn, buf[:3]); err != nil {
- return nil, err
- }
- // Ignore all of the header except the payload length
- plen := int(binary.BigEndian.Uint16(buf[1:3]))
- if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
- return nil, err
- }
-
- msg4 := messagebuffer{
- ciphertext: buf[:plen],
- }
- _, p, valid = RecvMessage(&session, &msg4)
- if !valid {
- return nil, errors.New("transport message decryption failed")
- }
-
- return p, nil
-}
-
-func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
- var mk keypair
- copy(mk.private_key[:], controlKey.UntypedBytes())
- copy(mk.public_key[:], controlKey.Public().UntypedBytes())
- session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
-
- var buf [1024]byte
- if _, err := io.ReadFull(conn, buf[:101]); err != nil {
- return nil, err
- }
- // Ignore the header, we're just checking the noise implementation.
- msg1 := messagebuffer{
- ns: buf[37:85],
- ciphertext: buf[85:101],
- }
- copy(msg1.ne[:], buf[5:37])
- _, p, valid := RecvMessage(&session, &msg1)
- if !valid {
- return nil, errors.New("handshake failed")
- }
- if len(p) != 0 {
- return nil, errors.New("non-empty payload")
- }
-
- _, msg2 := SendMessage(&session, nil)
- var hdr [headerLen]byte
- hdr[0] = msgTypeResponse
- binary.BigEndian.PutUint16(hdr[1:3], 48)
- if _, err := conn.Write(hdr[:]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg2.ne[:]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
- return nil, err
- }
-
- if _, err := io.ReadFull(conn, buf[:3]); err != nil {
- return nil, err
- }
- plen := int(binary.BigEndian.Uint16(buf[1:3]))
- if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
- return nil, err
- }
-
- msg3 := messagebuffer{
- ciphertext: buf[:plen],
- }
- _, p, valid = RecvMessage(&session, &msg3)
- if !valid {
- return nil, errors.New("transport message decryption failed")
- }
-
- _, msg4 := SendMessage(&session, payload)
- hdr[0] = msgTypeRecord
- binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
- if _, err := conn.Write(hdr[:]); err != nil {
- return nil, err
- }
- if _, err := conn.Write(msg4.ciphertext); err != nil {
- return nil, err
- }
-
- return p, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "testing"
+
+ "tailscale.com/net/memnet"
+ "tailscale.com/types/key"
+)
+
+// Can a reference Noise IK client talk to our server?
+func TestInteropClient(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ serverErr = make(chan error, 2)
+ serverBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ server, err := Server(context.Background(), s2, controlKey, nil)
+ serverErr <- err
+ if err != nil {
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(server, buf[:len(c2s)])
+ serverBytes <- buf[:len(c2s)]
+ if err != nil {
+ serverErr <- err
+ return
+ }
+ _, err = server.Write([]byte(s2c))
+ serverErr <- err
+ }()
+
+ gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
+ if err != nil {
+ t.Fatalf("failed client interop: %v", err)
+ }
+ if string(gotS2C) != s2c {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
+ }
+
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server handshake failed: %v", err)
+ }
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server read/write failed: %v", err)
+ }
+ if got := string(<-serverBytes); got != c2s {
+ t.Fatalf("server received %q, want %q", got, c2s)
+ }
+}
+
+// Can our client talk to a reference Noise IK server?
+func TestInteropServer(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ clientErr = make(chan error, 2)
+ clientBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
+ clientErr <- err
+ if err != nil {
+ return
+ }
+ _, err = client.Write([]byte(c2s))
+ if err != nil {
+ clientErr <- err
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(client, buf[:len(s2c)])
+ clientBytes <- buf[:len(s2c)]
+ clientErr <- err
+ }()
+
+ gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
+ if err != nil {
+ t.Fatalf("failed server interop: %v", err)
+ }
+ if string(gotC2S) != c2s {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
+ }
+
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client handshake failed: %v", err)
+ }
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client read/write failed: %v", err)
+ }
+ if got := string(<-clientBytes); got != s2c {
+ t.Fatalf("client received %q, want %q", got, s2c)
+ }
+}
+
+// noiseExplorerClient uses the Noise Explorer implementation of Noise
+// IK to handshake as a Noise client on conn, transmit payload, and
+// read+return a payload from the peer.
+func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], machineKey.UntypedBytes())
+ copy(mk.public_key[:], machineKey.Public().UntypedBytes())
+ var peerKey [32]byte
+ copy(peerKey[:], controlKey.UntypedBytes())
+ session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
+
+ _, msg1 := SendMessage(&session, nil)
+ var hdr [initiationHeaderLen]byte
+ binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
+ hdr[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(hdr[3:5], 96)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ns); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ciphertext); err != nil {
+ return nil, err
+ }
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:51]); err != nil {
+ return nil, err
+ }
+ // ignore the header for this test, we're only checking the noise
+ // implementation.
+ msg2 := messagebuffer{
+ ciphertext: buf[35:51],
+ }
+ copy(msg2.ne[:], buf[3:35])
+ _, p, valid := RecvMessage(&session, &msg2)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg3 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
+ if _, err := conn.Write(hdr[:3]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg3.ciphertext); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ // Ignore all of the header except the payload length
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg4 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg4)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ return p, nil
+}
+
+func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], controlKey.UntypedBytes())
+ copy(mk.public_key[:], controlKey.Public().UntypedBytes())
+ session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:101]); err != nil {
+ return nil, err
+ }
+ // Ignore the header, we're just checking the noise implementation.
+ msg1 := messagebuffer{
+ ns: buf[37:85],
+ ciphertext: buf[85:101],
+ }
+ copy(msg1.ne[:], buf[5:37])
+ _, p, valid := RecvMessage(&session, &msg1)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg2 := SendMessage(&session, nil)
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(hdr[1:3], 48)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg3 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg3)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ _, msg4 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg4.ciphertext); err != nil {
+ return nil, err
+ }
+
+ return p, nil
+}
diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go
index 59073088f..899378681 100644
--- a/control/controlbase/messages.go
+++ b/control/controlbase/messages.go
@@ -1,87 +1,87 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package controlbase
-
-import "encoding/binary"
-
-const (
- // msgTypeInitiation frames carry a Noise IK handshake initiation message.
- msgTypeInitiation = 1
- // msgTypeResponse frames carry a Noise IK handshake response message.
- msgTypeResponse = 2
- // msgTypeError frames carry an unauthenticated human-readable
- // error message.
- //
- // Errors reported in this message type must be treated as public
- // hints only. They are not encrypted or authenticated, and so can
- // be seen and tampered with on the wire.
- msgTypeError = 3
- // msgTypeRecord frames carry session data bytes.
- msgTypeRecord = 4
-
- // headerLen is the size of the header on all messages except msgTypeInitiation.
- headerLen = 3
- // initiationHeaderLen is the size of the header on all msgTypeInitiation messages.
- initiationHeaderLen = 5
-)
-
-// initiationMessage is the protocol message sent from a client
-// machine to a control server.
-//
-// 2b: protocol version
-// 1b: message type (0x01)
-// 2b: payload length (96)
-// 5b: header (see headerLen for fields)
-// 32b: client ephemeral public key (cleartext)
-// 48b: client machine public key (encrypted)
-// 16b: message tag (authenticates the whole message)
-type initiationMessage [101]byte
-
-func mkInitiationMessage(protocolVersion uint16) initiationMessage {
- var ret initiationMessage
- binary.BigEndian.PutUint16(ret[:2], protocolVersion)
- ret[2] = msgTypeInitiation
- binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
- return ret
-}
-
-func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] }
-func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] }
-
-func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) }
-func (m *initiationMessage) Type() byte { return m[2] }
-func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) }
-
-func (m *initiationMessage) EphemeralPub() []byte {
- return m[initiationHeaderLen : initiationHeaderLen+32]
-}
-func (m *initiationMessage) MachinePub() []byte {
- return m[initiationHeaderLen+32 : initiationHeaderLen+32+48]
-}
-func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] }
-
-// responseMessage is the protocol message sent from a control server
-// to a client machine.
-//
-// 1b: message type (0x02)
-// 2b: payload length (48)
-// 32b: control ephemeral public key (cleartext)
-// 16b: message tag (authenticates the whole message)
-type responseMessage [51]byte
-
-func mkResponseMessage() responseMessage {
- var ret responseMessage
- ret[0] = msgTypeResponse
- binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload())))
- return ret
-}
-
-func (m *responseMessage) Header() []byte { return m[:headerLen] }
-func (m *responseMessage) Payload() []byte { return m[headerLen:] }
-
-func (m *responseMessage) Type() byte { return m[0] }
-func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) }
-
-func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] }
-func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import "encoding/binary"
+
+const (
+ // msgTypeInitiation frames carry a Noise IK handshake initiation message.
+ msgTypeInitiation = 1
+ // msgTypeResponse frames carry a Noise IK handshake response message.
+ msgTypeResponse = 2
+ // msgTypeError frames carry an unauthenticated human-readable
+ // error message.
+ //
+ // Errors reported in this message type must be treated as public
+ // hints only. They are not encrypted or authenticated, and so can
+ // be seen and tampered with on the wire.
+ msgTypeError = 3
+ // msgTypeRecord frames carry session data bytes.
+ msgTypeRecord = 4
+
+ // headerLen is the size of the header on all messages except msgTypeInitiation.
+ headerLen = 3
+ // initiationHeaderLen is the size of the header on all msgTypeInitiation messages.
+ initiationHeaderLen = 5
+)
+
+// initiationMessage is the protocol message sent from a client
+// machine to a control server.
+//
+// 2b: protocol version
+// 1b: message type (0x01)
+// 2b: payload length (96)
+// 5b: header (see headerLen for fields)
+// 32b: client ephemeral public key (cleartext)
+// 48b: client machine public key (encrypted)
+// 16b: message tag (authenticates the whole message)
+type initiationMessage [101]byte
+
+func mkInitiationMessage(protocolVersion uint16) initiationMessage {
+ var ret initiationMessage
+ binary.BigEndian.PutUint16(ret[:2], protocolVersion)
+ ret[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] }
+func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] }
+
+func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) }
+func (m *initiationMessage) Type() byte { return m[2] }
+func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) }
+
+func (m *initiationMessage) EphemeralPub() []byte {
+ return m[initiationHeaderLen : initiationHeaderLen+32]
+}
+func (m *initiationMessage) MachinePub() []byte {
+ return m[initiationHeaderLen+32 : initiationHeaderLen+32+48]
+}
+func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] }
+
+// responseMessage is the protocol message sent from a control server
+// to a client machine.
+//
+// 1b: message type (0x02)
+// 2b: payload length (48)
+// 32b: control ephemeral public key (cleartext)
+// 16b: message tag (authenticates the whole message)
+type responseMessage [51]byte
+
+func mkResponseMessage() responseMessage {
+ var ret responseMessage
+ ret[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *responseMessage) Header() []byte { return m[:headerLen] }
+func (m *responseMessage) Payload() []byte { return m[headerLen:] }
+
+func (m *responseMessage) Type() byte { return m[0] }
+func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) }
+
+func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] }
+func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }