diff options
Diffstat (limited to 'control/controlbase')
| -rw-r--r-- | control/controlbase/conn.go | 816 | ||||
| -rw-r--r-- | control/controlbase/handshake.go | 988 | ||||
| -rw-r--r-- | control/controlbase/interop_test.go | 512 | ||||
| -rw-r--r-- | control/controlbase/messages.go | 174 |
4 files changed, 1245 insertions, 1245 deletions
diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index dc22212e8..b6fc53b3a 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -1,408 +1,408 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package controlbase implements the base transport of the Tailscale -// 2021 control protocol. -// -// The base transport implements Noise IK, instantiated with -// Curve25519, ChaCha20Poly1305 and BLAKE2s. -package controlbase - -import ( - "crypto/cipher" - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "tailscale.com/types/key" -) - -const ( - // maxMessageSize is the maximum size of a protocol frame on the - // wire, including header and payload. - maxMessageSize = 4096 - // maxCiphertextSize is the maximum amount of ciphertext bytes - // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - 3 - // maxPlaintextSize is the maximum amount of plaintext bytes that - // one protocol frame can carry, after encryption and framing. - maxPlaintextSize = maxCiphertextSize - chp.Overhead -) - -// A Conn is a secured Noise connection. It implements the net.Conn -// interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) causes all future writes to -// fail. -type Conn struct { - conn net.Conn - version uint16 - peer key.MachinePublic - handshakeHash [blake2s.Size]byte - rx rxState - tx txState -} - -// rxState is all the Conn state that Read uses. -type rxState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - buf *maxMsgBuffer // or nil when reads exhausted - n int // number of valid bytes in buf - next int // offset of next undecrypted packet - plaintext []byte // slice into buf of decrypted bytes - hdrBuf [headerLen]byte // small buffer used when buf is nil -} - -// txState is all the Conn state that Write uses. -type txState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - err error // records the first partial write error for all future calls -} - -// ProtocolVersion returns the protocol version that was used to -// establish this Conn. -func (c *Conn) ProtocolVersion() int { - return int(c.version) -} - -// HandshakeHash returns the Noise handshake hash for the connection, -// which can be used to bind other messages to this connection -// (i.e. to ensure that the message wasn't replayed from a different -// connection). -func (c *Conn) HandshakeHash() [blake2s.Size]byte { - return c.handshakeHash -} - -// Peer returns the peer's long-term public key. -func (c *Conn) Peer() key.MachinePublic { - return c.peer -} - -// readNLocked reads into c.rx.buf until buf contains at least total -// bytes. Returns a slice of the total bytes in rxBuf, or an -// error if fewer than total bytes are available. -// -// It may be called with a nil c.rx.buf only if total == headerLen. -// -// On success, c.rx.buf will be non-nil. -func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxMessageSize { - return nil, errReadTooBig{total} - } - for { - if total <= c.rx.n { - return c.rx.buf[:total], nil - } - var n int - var err error - if c.rx.buf == nil { - if c.rx.n != 0 || total != headerLen { - panic("unexpected") - } - // Optimization to reduce memory usage. - // Most connections are blocked forever waiting for - // a read, so we don't want c.rx.buf to be allocated until - // we know there's data to read. Instead, when we're - // waiting for data to arrive here, read into the - // 3 byte hdrBuf: - n, err = c.conn.Read(c.rx.hdrBuf[:]) - if n > 0 { - c.rx.buf = getMaxMsgBuffer() - copy(c.rx.buf[:], c.rx.hdrBuf[:n]) - } - } else { - n, err = c.conn.Read(c.rx.buf[c.rx.n:]) - } - c.rx.n += n - if err != nil { - return nil, err - } - } -} - -// decryptLocked decrypts msg (which is header+ciphertext) in-place -// and sets c.rx.plaintext to the decrypted bytes. -func (c *Conn) decryptLocked(msg []byte) (err error) { - if msgType := msg[0]; msgType != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) - } - // We don't check the length field here, because the caller - // already did in order to figure out how big the msg slice should - // be. - ciphertext := msg[headerLen:] - - if !c.rx.nonce.Valid() { - return errCipherExhausted{} - } - - c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) - c.rx.nonce.Increment() - - if err != nil { - // Once a decryption has failed, our Conn is no longer - // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. Future - // read attempts will return net.ErrClosed. - c.rx.cipher = nil - } - return err -} - -// encryptLocked encrypts plaintext into buf (including the -// packet header) and returns a slice of the ciphertext, or an error -// if the cipher is exhausted (i.e. can no longer be used safely). -func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { - if !c.tx.nonce.Valid() { - // Received 2^64-1 messages on this cipher state. Connection - // is no longer usable. - return nil, errCipherExhausted{} - } - - buf[0] = msgTypeRecord - binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) - ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) - c.tx.nonce.Increment() - - return ret, nil -} - -// wholeMessageLocked returns a slice of one whole Noise transport -// message from c.rx.buf, if one whole message is available, and -// advances the read state to the next Noise message in the -// buffer. Returns nil without advancing read state if there isn't one -// whole message in c.rx.buf. -func (c *Conn) wholeMessageLocked() []byte { - available := c.rx.n - c.rx.next - if available < headerLen { - return nil - } - bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - if len(bs) < totalSize { - return nil - } - c.rx.next += totalSize - return bs[:totalSize] -} - -// decryptOneLocked decrypts one Noise transport message, reading from -// c.conn as needed, and sets c.rx.plaintext to point to the decrypted -// bytes. c.rx.plaintext is only valid if err == nil. -func (c *Conn) decryptOneLocked() error { - c.rx.plaintext = nil - - // Fast path: do we have one whole ciphertext frame buffered - // already? - if bs := c.wholeMessageLocked(); bs != nil { - return c.decryptLocked(bs) - } - - if c.rx.next != 0 { - // To simplify the read logic, move the remainder of the - // buffered bytes back to the head of the buffer, so we can - // grow it without worrying about wraparound. - c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.next = 0 - } - - // Return our buffer to the pool if it's empty, lest we be - // blocked in a long Read call, reading the 3 byte header. We - // don't to keep that buffer unnecessarily alive. - if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { - bufPool.Put(c.rx.buf) - c.rx.buf = nil - } - - bs, err := c.readNLocked(headerLen) - if err != nil { - return err - } - // The rest of the header (besides the length field) gets verified - // in decryptLocked, not here. - messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - bs, err = c.readNLocked(messageLen) - if err != nil { - return err - } - - c.rx.next = len(bs) - - return c.decryptLocked(bs) -} - -// Read implements io.Reader. -func (c *Conn) Read(bs []byte) (int, error) { - c.rx.Lock() - defer c.rx.Unlock() - - if c.rx.cipher == nil { - return 0, net.ErrClosed - } - // If no plaintext is buffered, decrypt incoming frames until we - // have some plaintext. Zero-byte Noise frames are allowed in this - // protocol, which is why we have to loop here rather than decrypt - // a single additional frame. - for len(c.rx.plaintext) == 0 { - if err := c.decryptOneLocked(); err != nil { - return 0, err - } - } - n := copy(bs, c.rx.plaintext) - c.rx.plaintext = c.rx.plaintext[n:] - - // Lose slice's underlying array pointer to unneeded memory so - // GC can collect more. - if len(c.rx.plaintext) == 0 { - c.rx.plaintext = nil - } - return n, nil -} - -// Write implements io.Writer. -func (c *Conn) Write(bs []byte) (n int, err error) { - c.tx.Lock() - defer c.tx.Unlock() - - if c.tx.err != nil { - return 0, c.tx.err - } - defer func() { - if err != nil { - // All write errors are fatal for this conn, so clear the - // cipher state whenever an error happens. - c.tx.cipher = nil - } - if c.tx.err == nil { - // Only set c.tx.err if not nil so that we can return one - // error on the first failure, and a different one for - // subsequent calls. See the error handling around Write - // below for why. - c.tx.err = err - } - }() - - if c.tx.cipher == nil { - return 0, net.ErrClosed - } - - buf := getMaxMsgBuffer() - defer bufPool.Put(buf) - - var sent int - for len(bs) > 0 { - toSend := bs - if len(toSend) > maxPlaintextSize { - toSend = bs[:maxPlaintextSize] - } - bs = bs[len(toSend):] - - ciphertext, err := c.encryptLocked(toSend, buf) - if err != nil { - return sent, err - } - if _, err := c.conn.Write(ciphertext); err != nil { - // Return the raw error on the Write that actually - // failed. For future writes, return that error wrapped in - // a desync error. - c.tx.err = errPartialWrite{err} - return sent, err - } - sent += len(toSend) - } - return sent, nil -} - -// Close implements io.Closer. -func (c *Conn) Close() error { - closeErr := c.conn.Close() // unblocks any waiting reads or writes - - // Remove references to live cipher state. Strictly speaking this - // is unnecessary, but we want to try and hand the active cipher - // state to the garbage collector promptly, to preserve perfect - // forward secrecy as much as we can. - c.rx.Lock() - c.rx.cipher = nil - c.rx.Unlock() - c.tx.Lock() - c.tx.cipher = nil - c.tx.Unlock() - return closeErr -} - -func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } - -// errCipherExhausted is the error returned when we run out of nonces -// on a cipher. -type errCipherExhausted struct{} - -func (errCipherExhausted) Error() string { - return "cipher exhausted, no more nonces available for current key" -} -func (errCipherExhausted) Timeout() bool { return false } -func (errCipherExhausted) Temporary() bool { return false } - -// errPartialWrite is the error returned when the cipher state has -// become unusable due to a past partial write. -type errPartialWrite struct { - err error -} - -func (e errPartialWrite) Error() string { - return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) -} -func (e errPartialWrite) Unwrap() error { return e.err } -func (e errPartialWrite) Temporary() bool { return false } -func (e errPartialWrite) Timeout() bool { return false } - -// errReadTooBig is the error returned when the peer sent an -// unacceptably large Noise frame. -type errReadTooBig struct { - requested int -} - -func (e errReadTooBig) Error() string { - return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) -} -func (e errReadTooBig) Temporary() bool { - // permanent error because this error only occurs when our peer - // sends us a frame so large we're unwilling to ever decode it. - return false -} -func (e errReadTooBig) Timeout() bool { return false } - -type nonce [chp.NonceSize]byte - -func (n *nonce) Valid() bool { - return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce -} - -func (n *nonce) Increment() { - if !n.Valid() { - panic("increment of invalid nonce") - } - binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) -} - -type maxMsgBuffer [maxMessageSize]byte - -// bufPool holds the temporary buffers for Conn.Read & Write. -var bufPool = &sync.Pool{ - New: func() any { - return new(maxMsgBuffer) - }, -} - -func getMaxMsgBuffer() *maxMsgBuffer { - return bufPool.Get().(*maxMsgBuffer) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package controlbase implements the base transport of the Tailscale
+// 2021 control protocol.
+//
+// The base transport implements Noise IK, instantiated with
+// Curve25519, ChaCha20Poly1305 and BLAKE2s.
+package controlbase
+
+import (
+ "crypto/cipher"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "tailscale.com/types/key"
+)
+
+const (
+ // maxMessageSize is the maximum size of a protocol frame on the
+ // wire, including header and payload.
+ maxMessageSize = 4096
+ // maxCiphertextSize is the maximum amount of ciphertext bytes
+ // that one protocol frame can carry, after framing.
+ maxCiphertextSize = maxMessageSize - 3
+ // maxPlaintextSize is the maximum amount of plaintext bytes that
+ // one protocol frame can carry, after encryption and framing.
+ maxPlaintextSize = maxCiphertextSize - chp.Overhead
+)
+
+// A Conn is a secured Noise connection. It implements the net.Conn
+// interface, with the unusual trait that any write error (including a
+// SetWriteDeadline induced i/o timeout) causes all future writes to
+// fail.
+type Conn struct {
+ conn net.Conn
+ version uint16
+ peer key.MachinePublic
+ handshakeHash [blake2s.Size]byte
+ rx rxState
+ tx txState
+}
+
+// rxState is all the Conn state that Read uses.
+type rxState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ buf *maxMsgBuffer // or nil when reads exhausted
+ n int // number of valid bytes in buf
+ next int // offset of next undecrypted packet
+ plaintext []byte // slice into buf of decrypted bytes
+ hdrBuf [headerLen]byte // small buffer used when buf is nil
+}
+
+// txState is all the Conn state that Write uses.
+type txState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ err error // records the first partial write error for all future calls
+}
+
+// ProtocolVersion returns the protocol version that was used to
+// establish this Conn.
+func (c *Conn) ProtocolVersion() int {
+ return int(c.version)
+}
+
+// HandshakeHash returns the Noise handshake hash for the connection,
+// which can be used to bind other messages to this connection
+// (i.e. to ensure that the message wasn't replayed from a different
+// connection).
+func (c *Conn) HandshakeHash() [blake2s.Size]byte {
+ return c.handshakeHash
+}
+
+// Peer returns the peer's long-term public key.
+func (c *Conn) Peer() key.MachinePublic {
+ return c.peer
+}
+
+// readNLocked reads into c.rx.buf until buf contains at least total
+// bytes. Returns a slice of the total bytes in rxBuf, or an
+// error if fewer than total bytes are available.
+//
+// It may be called with a nil c.rx.buf only if total == headerLen.
+//
+// On success, c.rx.buf will be non-nil.
+func (c *Conn) readNLocked(total int) ([]byte, error) {
+ if total > maxMessageSize {
+ return nil, errReadTooBig{total}
+ }
+ for {
+ if total <= c.rx.n {
+ return c.rx.buf[:total], nil
+ }
+ var n int
+ var err error
+ if c.rx.buf == nil {
+ if c.rx.n != 0 || total != headerLen {
+ panic("unexpected")
+ }
+ // Optimization to reduce memory usage.
+ // Most connections are blocked forever waiting for
+ // a read, so we don't want c.rx.buf to be allocated until
+ // we know there's data to read. Instead, when we're
+ // waiting for data to arrive here, read into the
+ // 3 byte hdrBuf:
+ n, err = c.conn.Read(c.rx.hdrBuf[:])
+ if n > 0 {
+ c.rx.buf = getMaxMsgBuffer()
+ copy(c.rx.buf[:], c.rx.hdrBuf[:n])
+ }
+ } else {
+ n, err = c.conn.Read(c.rx.buf[c.rx.n:])
+ }
+ c.rx.n += n
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+// decryptLocked decrypts msg (which is header+ciphertext) in-place
+// and sets c.rx.plaintext to the decrypted bytes.
+func (c *Conn) decryptLocked(msg []byte) (err error) {
+ if msgType := msg[0]; msgType != msgTypeRecord {
+ return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
+ }
+ // We don't check the length field here, because the caller
+ // already did in order to figure out how big the msg slice should
+ // be.
+ ciphertext := msg[headerLen:]
+
+ if !c.rx.nonce.Valid() {
+ return errCipherExhausted{}
+ }
+
+ c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
+ c.rx.nonce.Increment()
+
+ if err != nil {
+ // Once a decryption has failed, our Conn is no longer
+ // synchronized with our peer. Nuke the cipher state to be
+ // safe, so that no further decryptions are attempted. Future
+ // read attempts will return net.ErrClosed.
+ c.rx.cipher = nil
+ }
+ return err
+}
+
+// encryptLocked encrypts plaintext into buf (including the
+// packet header) and returns a slice of the ciphertext, or an error
+// if the cipher is exhausted (i.e. can no longer be used safely).
+func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
+ if !c.tx.nonce.Valid() {
+ // Received 2^64-1 messages on this cipher state. Connection
+ // is no longer usable.
+ return nil, errCipherExhausted{}
+ }
+
+ buf[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
+ ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
+ c.tx.nonce.Increment()
+
+ return ret, nil
+}
+
+// wholeMessageLocked returns a slice of one whole Noise transport
+// message from c.rx.buf, if one whole message is available, and
+// advances the read state to the next Noise message in the
+// buffer. Returns nil without advancing read state if there isn't one
+// whole message in c.rx.buf.
+func (c *Conn) wholeMessageLocked() []byte {
+ available := c.rx.n - c.rx.next
+ if available < headerLen {
+ return nil
+ }
+ bs := c.rx.buf[c.rx.next:c.rx.n]
+ totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ if len(bs) < totalSize {
+ return nil
+ }
+ c.rx.next += totalSize
+ return bs[:totalSize]
+}
+
+// decryptOneLocked decrypts one Noise transport message, reading from
+// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
+// bytes. c.rx.plaintext is only valid if err == nil.
+func (c *Conn) decryptOneLocked() error {
+ c.rx.plaintext = nil
+
+ // Fast path: do we have one whole ciphertext frame buffered
+ // already?
+ if bs := c.wholeMessageLocked(); bs != nil {
+ return c.decryptLocked(bs)
+ }
+
+ if c.rx.next != 0 {
+ // To simplify the read logic, move the remainder of the
+ // buffered bytes back to the head of the buffer, so we can
+ // grow it without worrying about wraparound.
+ c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
+ c.rx.next = 0
+ }
+
+ // Return our buffer to the pool if it's empty, lest we be
+ // blocked in a long Read call, reading the 3 byte header. We
+ // don't to keep that buffer unnecessarily alive.
+ if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
+ bufPool.Put(c.rx.buf)
+ c.rx.buf = nil
+ }
+
+ bs, err := c.readNLocked(headerLen)
+ if err != nil {
+ return err
+ }
+ // The rest of the header (besides the length field) gets verified
+ // in decryptLocked, not here.
+ messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ bs, err = c.readNLocked(messageLen)
+ if err != nil {
+ return err
+ }
+
+ c.rx.next = len(bs)
+
+ return c.decryptLocked(bs)
+}
+
+// Read implements io.Reader.
+func (c *Conn) Read(bs []byte) (int, error) {
+ c.rx.Lock()
+ defer c.rx.Unlock()
+
+ if c.rx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+ // If no plaintext is buffered, decrypt incoming frames until we
+ // have some plaintext. Zero-byte Noise frames are allowed in this
+ // protocol, which is why we have to loop here rather than decrypt
+ // a single additional frame.
+ for len(c.rx.plaintext) == 0 {
+ if err := c.decryptOneLocked(); err != nil {
+ return 0, err
+ }
+ }
+ n := copy(bs, c.rx.plaintext)
+ c.rx.plaintext = c.rx.plaintext[n:]
+
+ // Lose slice's underlying array pointer to unneeded memory so
+ // GC can collect more.
+ if len(c.rx.plaintext) == 0 {
+ c.rx.plaintext = nil
+ }
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (c *Conn) Write(bs []byte) (n int, err error) {
+ c.tx.Lock()
+ defer c.tx.Unlock()
+
+ if c.tx.err != nil {
+ return 0, c.tx.err
+ }
+ defer func() {
+ if err != nil {
+ // All write errors are fatal for this conn, so clear the
+ // cipher state whenever an error happens.
+ c.tx.cipher = nil
+ }
+ if c.tx.err == nil {
+ // Only set c.tx.err if not nil so that we can return one
+ // error on the first failure, and a different one for
+ // subsequent calls. See the error handling around Write
+ // below for why.
+ c.tx.err = err
+ }
+ }()
+
+ if c.tx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+
+ buf := getMaxMsgBuffer()
+ defer bufPool.Put(buf)
+
+ var sent int
+ for len(bs) > 0 {
+ toSend := bs
+ if len(toSend) > maxPlaintextSize {
+ toSend = bs[:maxPlaintextSize]
+ }
+ bs = bs[len(toSend):]
+
+ ciphertext, err := c.encryptLocked(toSend, buf)
+ if err != nil {
+ return sent, err
+ }
+ if _, err := c.conn.Write(ciphertext); err != nil {
+ // Return the raw error on the Write that actually
+ // failed. For future writes, return that error wrapped in
+ // a desync error.
+ c.tx.err = errPartialWrite{err}
+ return sent, err
+ }
+ sent += len(toSend)
+ }
+ return sent, nil
+}
+
+// Close implements io.Closer.
+func (c *Conn) Close() error {
+ closeErr := c.conn.Close() // unblocks any waiting reads or writes
+
+ // Remove references to live cipher state. Strictly speaking this
+ // is unnecessary, but we want to try and hand the active cipher
+ // state to the garbage collector promptly, to preserve perfect
+ // forward secrecy as much as we can.
+ c.rx.Lock()
+ c.rx.cipher = nil
+ c.rx.Unlock()
+ c.tx.Lock()
+ c.tx.cipher = nil
+ c.tx.Unlock()
+ return closeErr
+}
+
+func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
+func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
+func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
+func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
+func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
+
+// errCipherExhausted is the error returned when we run out of nonces
+// on a cipher.
+type errCipherExhausted struct{}
+
+func (errCipherExhausted) Error() string {
+ return "cipher exhausted, no more nonces available for current key"
+}
+func (errCipherExhausted) Timeout() bool { return false }
+func (errCipherExhausted) Temporary() bool { return false }
+
+// errPartialWrite is the error returned when the cipher state has
+// become unusable due to a past partial write.
+type errPartialWrite struct {
+ err error
+}
+
+func (e errPartialWrite) Error() string {
+ return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
+}
+func (e errPartialWrite) Unwrap() error { return e.err }
+func (e errPartialWrite) Temporary() bool { return false }
+func (e errPartialWrite) Timeout() bool { return false }
+
+// errReadTooBig is the error returned when the peer sent an
+// unacceptably large Noise frame.
+type errReadTooBig struct {
+ requested int
+}
+
+func (e errReadTooBig) Error() string {
+ return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
+}
+func (e errReadTooBig) Temporary() bool {
+ // permanent error because this error only occurs when our peer
+ // sends us a frame so large we're unwilling to ever decode it.
+ return false
+}
+func (e errReadTooBig) Timeout() bool { return false }
+
+type nonce [chp.NonceSize]byte
+
+func (n *nonce) Valid() bool {
+ return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
+}
+
+func (n *nonce) Increment() {
+ if !n.Valid() {
+ panic("increment of invalid nonce")
+ }
+ binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
+}
+
+type maxMsgBuffer [maxMessageSize]byte
+
+// bufPool holds the temporary buffers for Conn.Read & Write.
+var bufPool = &sync.Pool{
+ New: func() any {
+ return new(maxMsgBuffer)
+ },
+}
+
+func getMaxMsgBuffer() *maxMsgBuffer {
+ return bufPool.Get().(*maxMsgBuffer)
+}
diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 765a4620b..937969a30 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ "go4.org/mem"
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+ "tailscale.com/types/key"
+)
+
+const (
+ // protocolName is the name of the specific instantiation of Noise
+ // that the control protocol uses. This string's value is fixed by
+ // the Noise spec, and shouldn't be changed unless we're updating
+ // the control protocol to use a different Noise instance.
+ protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+ // protocolVersion is the version of the control protocol that
+ // Client will use when initiating a handshake.
+ //protocolVersion uint16 = 1
+ // protocolVersionPrefix is the name portion of the protocol
+ // name+version string that gets mixed into the handshake as a
+ // prologue.
+ //
+ // This mixing verifies that both clients agree that they're
+ // executing the control protocol at a specific version that
+ // matches the advertised version in the cleartext packet header.
+ protocolVersionPrefix = "Tailscale Control Protocol v"
+ invalidNonce = ^uint64(0)
+)
+
+func protocolVersionPrologue(version uint16) []byte {
+ ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
+ ret = append(ret, protocolVersionPrefix...)
+ return strconv.AppendUint(ret, uint64(version), 10)
+}
+
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
+// is assumed to have already sent the client>server handshake
+// initiation message.
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
+
+// ClientDeferred initiates a control client handshake, returning the
+// initial message to send to the server and a continuation to
+// finalize the handshake.
+//
+// ClientDeferred is split in this way for RTT reduction: we run this
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
+// completely serialized the negotiation followed by the handshake,
+// we'd pay an extra RTT to transmit the handshake initiation after
+// protocol switching. By splitting the handshake into an initial
+// message and a continuation, we can embed the handshake initiation
+// into the HTTP protocol switching request and avoid a bit of delay.
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
+ var s symmetricState
+ s.Initialize()
+
+ // prologue
+ s.MixHash(protocolVersionPrologue(protocolVersion))
+
+ // <- s
+ // ...
+ s.MixHash(controlKey.UntypedBytes())
+
+ // -> e, es, s, ss
+ init := mkInitiationMessage(protocolVersion)
+ machineEphemeral := key.NewMachine()
+ machineEphemeralPub := machineEphemeral.Public()
+ copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(machineEphemeral, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing es: %w", err)
+ }
+ machineKeyPub := machineKey.Public()
+ s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
+ cipher, err = s.MixDH(machineKey, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
+ }
+ s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
+
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
+ }
+ return init[:], cont, nil
+}
+
+// Client wraps ClientDeferred and immediately invokes the returned
+// continuation with conn.
+//
+// This is a helper for when you don't need the fancy
+// continuation-style handshake, and just want to synchronously
+// upgrade a net.Conn to a secure transport.
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(init); err != nil {
+ return nil, err
+ }
+ return cont(ctx, conn)
+}
+
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ // No matter what, this function can only run once per s. Ensure
+ // attempted reuse causes a panic.
+ defer func() {
+ s.finished = true
+ }()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Read in the payload and look for errors/protocol violations from the server.
+ var resp responseMessage
+ if _, err := io.ReadFull(conn, resp.Header()); err != nil {
+ return nil, fmt.Errorf("reading response header: %w", err)
+ }
+ if resp.Type() != msgTypeResponse {
+ if resp.Type() != msgTypeError {
+ return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
+ }
+ msg := make([]byte, resp.Length())
+ if _, err := io.ReadFull(conn, msg); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("server error: %q", msg)
+ }
+ if resp.Length() != len(resp.Payload()) {
+ return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
+ }
+ if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
+ return nil, err
+ }
+
+ // <- e, ee, se
+ controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting payload: %w", err)
+ }
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: protocolVersion,
+ peer: controlKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c1,
+ },
+ rx: rxState{
+ cipher: c2,
+ },
+ }
+ return c, nil
+}
+
+// Server initiates a control server handshake, returning the resulting
+// control connection.
+//
+// optionalInit can be the client's initial handshake message as
+// returned by ClientDeferred, or nil in which case the initial
+// message is read from conn.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Deliberately does not support formatting, so that we don't echo
+ // attacker-controlled input back to them.
+ sendErr := func(msg string) error {
+ if len(msg) >= 1<<16 {
+ msg = msg[:1<<16]
+ }
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeError
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ if _, err := io.WriteString(conn, msg); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ return fmt.Errorf("refused client handshake: %q", msg)
+ }
+
+ var s symmetricState
+ s.Initialize()
+
+ var init initiationMessage
+ if optionalInit != nil {
+ if len(optionalInit) != len(init) {
+ return nil, sendErr("wrong handshake initiation size")
+ }
+ copy(init[:], optionalInit)
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
+ return nil, err
+ }
+ // Just a rename to make it more obvious what the value is. In the
+ // current implementation we don't need to block any protocol
+ // versions at this layer, it's safe to let the handshake proceed
+ // and then let the caller make decisions based on the agreed-upon
+ // protocol version.
+ clientVersion := init.Version()
+ if init.Type() != msgTypeInitiation {
+ return nil, sendErr("unexpected handshake message type")
+ }
+ if init.Length() != len(init.Payload()) {
+ return nil, sendErr("wrong handshake initiation length")
+ }
+ // if optionalInit was provided, we have the payload already.
+ if optionalInit == nil {
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
+ return nil, err
+ }
+ }
+
+ // prologue. Can only do this once we at least think the client is
+ // handshaking using a supported version.
+ s.MixHash(protocolVersionPrologue(clientVersion))
+
+ // <- s
+ // ...
+ controlKeyPub := controlKey.Public()
+ s.MixHash(controlKeyPub.UntypedBytes())
+
+ // -> e, es, s, ss
+ machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(controlKey, machineEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing es: %w", err)
+ }
+ var machineKeyBytes [32]byte
+ if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
+ return nil, fmt.Errorf("decrypting machine key: %w", err)
+ }
+ machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
+ cipher, err = s.MixDH(controlKey, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing ss: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+ }
+
+ // <- e, ee, se
+ resp := mkResponseMessage()
+ controlEphemeral := key.NewMachine()
+ controlEphemeralPub := controlEphemeral.Public()
+ copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err = s.MixDH(controlEphemeral, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ if _, err := conn.Write(resp[:]); err != nil {
+ return nil, err
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: clientVersion,
+ peer: machineKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c2,
+ },
+ rx: rxState{
+ cipher: c1,
+ },
+ }
+ return c, nil
+}
+
+// symmetricState contains the state of an in-flight handshake.
+type symmetricState struct {
+ finished bool
+
+ h [blake2s.Size]byte // hash of currently-processed handshake state
+ ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
+}
+
+func (s *symmetricState) checkFinished() {
+ if s.finished {
+ panic("attempted to use symmetricState after Split was called")
+ }
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any handshake messages.
+func (s *symmetricState) Initialize() {
+ s.checkFinished()
+ s.h = blake2s.Sum256([]byte(protocolName))
+ s.ck = s.h
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+ s.checkFinished()
+ h := newBLAKE2s()
+ h.Write(s.h[:])
+ h.Write(data)
+ h.Sum(s.h[:0])
+}
+
+// MixDH updates s.ck with the result of X25519(priv, pub) and returns
+// a singleUseCHP that can be used to encrypt or decrypt handshake
+// data.
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
+ s.checkFinished()
+ keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
+ if err != nil {
+ return nil, fmt.Errorf("computing X25519: %w", err)
+ }
+
+ r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+ if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+ return nil, fmt.Errorf("extracting ck: %w", err)
+ }
+ var k [chp.KeySize]byte
+ if _, err := io.ReadFull(r, k[:]); err != nil {
+ return nil, fmt.Errorf("extracting k: %w", err)
+ }
+ return newSingleUseCHP(k), nil
+}
+
+// EncryptAndHash encrypts plaintext into ciphertext (which must be
+// the correct size to hold the encrypted plaintext) using cipher,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ panic("ciphertext is wrong size for given plaintext")
+ }
+ ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
+ s.MixHash(ret)
+}
+
+// DecryptAndHash decrypts the given ciphertext into plaintext (which
+// must be the correct size to hold the decrypted ciphertext) using
+// cipher. If decryption is successful, it mixes the ciphertext into
+// s.h.
+func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ return errors.New("plaintext is wrong size for given ciphertext")
+ }
+ if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
+ return err
+ }
+ s.MixHash(ciphertext)
+ return nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derived from
+// the current handshake state. Methods on s cannot be used again
+// after calling Split.
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+ s.finished = true
+
+ var k1, k2 [chp.KeySize]byte
+ r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+ if _, err := io.ReadFull(r, k1[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k1: %w", err)
+ }
+ if _, err := io.ReadFull(r, k2[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k2: %w", err)
+ }
+ c1, err = chp.New(k1[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+ }
+ c2, err = chp.New(k2[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+ }
+ return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ // Should never happen, errors only happen when using BLAKE2s
+ // in MAC mode with a key.
+ panic(err)
+ }
+ return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+ aead, err := chp.New(key[:])
+ if err != nil {
+ // Can only happen if we passed a key of the wrong length. The
+ // function signature prevents that.
+ panic(err)
+ }
+ return aead
+}
+
+// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
+// only once, either for encrypting or decrypting, but not both. The
+// chosen operation is always executed with an all-zeros
+// nonce. Subsequent calls to either Seal or Open panic.
+type singleUseCHP struct {
+ c cipher.AEAD
+}
+
+func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
+ return &singleUseCHP{newCHP(key)}
+}
+
+func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Seal(dst, nonce[:], plaintext, additionalData)
+}
+
+func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Open(dst, nonce[:], ciphertext, additionalData)
+}
diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index c41fbf4dd..d11c04149 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -1,256 +1,256 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - - "tailscale.com/net/memnet" - "tailscale.com/types/key" -) - -// Can a reference Noise IK client talk to our server? -func TestInteropClient(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - serverErr = make(chan error, 2) - serverBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - server, err := Server(context.Background(), s2, controlKey, nil) - serverErr <- err - if err != nil { - return - } - var buf [1024]byte - _, err = io.ReadFull(server, buf[:len(c2s)]) - serverBytes <- buf[:len(c2s)] - if err != nil { - serverErr <- err - return - } - _, err = server.Write([]byte(s2c)) - serverErr <- err - }() - - gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) - if err != nil { - t.Fatalf("failed client interop: %v", err) - } - if string(gotS2C) != s2c { - t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) - } - - if err := <-serverErr; err != nil { - t.Fatalf("server handshake failed: %v", err) - } - if err := <-serverErr; err != nil { - t.Fatalf("server read/write failed: %v", err) - } - if got := string(<-serverBytes); got != c2s { - t.Fatalf("server received %q, want %q", got, c2s) - } -} - -// Can our client talk to a reference Noise IK server? -func TestInteropServer(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - clientErr = make(chan error, 2) - clientBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) - clientErr <- err - if err != nil { - return - } - _, err = client.Write([]byte(c2s)) - if err != nil { - clientErr <- err - return - } - var buf [1024]byte - _, err = io.ReadFull(client, buf[:len(s2c)]) - clientBytes <- buf[:len(s2c)] - clientErr <- err - }() - - gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) - if err != nil { - t.Fatalf("failed server interop: %v", err) - } - if string(gotC2S) != c2s { - t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) - } - - if err := <-clientErr; err != nil { - t.Fatalf("client handshake failed: %v", err) - } - if err := <-clientErr; err != nil { - t.Fatalf("client read/write failed: %v", err) - } - if got := string(<-clientBytes); got != s2c { - t.Fatalf("client received %q, want %q", got, s2c) - } -} - -// noiseExplorerClient uses the Noise Explorer implementation of Noise -// IK to handshake as a Noise client on conn, transmit payload, and -// read+return a payload from the peer. -func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], machineKey.UntypedBytes()) - copy(mk.public_key[:], machineKey.Public().UntypedBytes()) - var peerKey [32]byte - copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) - - _, msg1 := SendMessage(&session, nil) - var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) - hdr[2] = msgTypeInitiation - binary.BigEndian.PutUint16(hdr[3:5], 96) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ns); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ciphertext); err != nil { - return nil, err - } - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:51]); err != nil { - return nil, err - } - // ignore the header for this test, we're only checking the noise - // implementation. - msg2 := messagebuffer{ - ciphertext: buf[35:51], - } - copy(msg2.ne[:], buf[3:35]) - _, p, valid := RecvMessage(&session, &msg2) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg3 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(hdr[:3]); err != nil { - return nil, err - } - if _, err := conn.Write(msg3.ciphertext); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - // Ignore all of the header except the payload length - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg4 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg4) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - return p, nil -} - -func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], controlKey.UntypedBytes()) - copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:101]); err != nil { - return nil, err - } - // Ignore the header, we're just checking the noise implementation. - msg1 := messagebuffer{ - ns: buf[37:85], - ciphertext: buf[85:101], - } - copy(msg1.ne[:], buf[5:37]) - _, p, valid := RecvMessage(&session, &msg1) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg2 := SendMessage(&session, nil) - var hdr [headerLen]byte - hdr[0] = msgTypeResponse - binary.BigEndian.PutUint16(hdr[1:3], 48) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ciphertext[:]); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg3 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg3) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - _, msg4 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg4.ciphertext); err != nil { - return nil, err - } - - return p, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "testing"
+
+ "tailscale.com/net/memnet"
+ "tailscale.com/types/key"
+)
+
+// Can a reference Noise IK client talk to our server?
+func TestInteropClient(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ serverErr = make(chan error, 2)
+ serverBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ server, err := Server(context.Background(), s2, controlKey, nil)
+ serverErr <- err
+ if err != nil {
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(server, buf[:len(c2s)])
+ serverBytes <- buf[:len(c2s)]
+ if err != nil {
+ serverErr <- err
+ return
+ }
+ _, err = server.Write([]byte(s2c))
+ serverErr <- err
+ }()
+
+ gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
+ if err != nil {
+ t.Fatalf("failed client interop: %v", err)
+ }
+ if string(gotS2C) != s2c {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
+ }
+
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server handshake failed: %v", err)
+ }
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server read/write failed: %v", err)
+ }
+ if got := string(<-serverBytes); got != c2s {
+ t.Fatalf("server received %q, want %q", got, c2s)
+ }
+}
+
+// Can our client talk to a reference Noise IK server?
+func TestInteropServer(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ clientErr = make(chan error, 2)
+ clientBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
+ clientErr <- err
+ if err != nil {
+ return
+ }
+ _, err = client.Write([]byte(c2s))
+ if err != nil {
+ clientErr <- err
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(client, buf[:len(s2c)])
+ clientBytes <- buf[:len(s2c)]
+ clientErr <- err
+ }()
+
+ gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
+ if err != nil {
+ t.Fatalf("failed server interop: %v", err)
+ }
+ if string(gotC2S) != c2s {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
+ }
+
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client handshake failed: %v", err)
+ }
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client read/write failed: %v", err)
+ }
+ if got := string(<-clientBytes); got != s2c {
+ t.Fatalf("client received %q, want %q", got, s2c)
+ }
+}
+
+// noiseExplorerClient uses the Noise Explorer implementation of Noise
+// IK to handshake as a Noise client on conn, transmit payload, and
+// read+return a payload from the peer.
+func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], machineKey.UntypedBytes())
+ copy(mk.public_key[:], machineKey.Public().UntypedBytes())
+ var peerKey [32]byte
+ copy(peerKey[:], controlKey.UntypedBytes())
+ session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
+
+ _, msg1 := SendMessage(&session, nil)
+ var hdr [initiationHeaderLen]byte
+ binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
+ hdr[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(hdr[3:5], 96)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ns); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ciphertext); err != nil {
+ return nil, err
+ }
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:51]); err != nil {
+ return nil, err
+ }
+ // ignore the header for this test, we're only checking the noise
+ // implementation.
+ msg2 := messagebuffer{
+ ciphertext: buf[35:51],
+ }
+ copy(msg2.ne[:], buf[3:35])
+ _, p, valid := RecvMessage(&session, &msg2)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg3 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
+ if _, err := conn.Write(hdr[:3]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg3.ciphertext); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ // Ignore all of the header except the payload length
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg4 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg4)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ return p, nil
+}
+
+func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], controlKey.UntypedBytes())
+ copy(mk.public_key[:], controlKey.Public().UntypedBytes())
+ session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:101]); err != nil {
+ return nil, err
+ }
+ // Ignore the header, we're just checking the noise implementation.
+ msg1 := messagebuffer{
+ ns: buf[37:85],
+ ciphertext: buf[85:101],
+ }
+ copy(msg1.ne[:], buf[5:37])
+ _, p, valid := RecvMessage(&session, &msg1)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg2 := SendMessage(&session, nil)
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(hdr[1:3], 48)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg3 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg3)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ _, msg4 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg4.ciphertext); err != nil {
+ return nil, err
+ }
+
+ return p, nil
+}
diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 59073088f..899378681 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import "encoding/binary" - -const ( - // msgTypeInitiation frames carry a Noise IK handshake initiation message. - msgTypeInitiation = 1 - // msgTypeResponse frames carry a Noise IK handshake response message. - msgTypeResponse = 2 - // msgTypeError frames carry an unauthenticated human-readable - // error message. - // - // Errors reported in this message type must be treated as public - // hints only. They are not encrypted or authenticated, and so can - // be seen and tampered with on the wire. - msgTypeError = 3 - // msgTypeRecord frames carry session data bytes. - msgTypeRecord = 4 - - // headerLen is the size of the header on all messages except msgTypeInitiation. - headerLen = 3 - // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. - initiationHeaderLen = 5 -) - -// initiationMessage is the protocol message sent from a client -// machine to a control server. -// -// 2b: protocol version -// 1b: message type (0x01) -// 2b: payload length (96) -// 5b: header (see headerLen for fields) -// 32b: client ephemeral public key (cleartext) -// 48b: client machine public key (encrypted) -// 16b: message tag (authenticates the whole message) -type initiationMessage [101]byte - -func mkInitiationMessage(protocolVersion uint16) initiationMessage { - var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) - return ret -} - -func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } -func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } - -func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } -func (m *initiationMessage) Type() byte { return m[2] } -func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } - -func (m *initiationMessage) EphemeralPub() []byte { - return m[initiationHeaderLen : initiationHeaderLen+32] -} -func (m *initiationMessage) MachinePub() []byte { - return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] -} -func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } - -// responseMessage is the protocol message sent from a control server -// to a client machine. -// -// 1b: message type (0x02) -// 2b: payload length (48) -// 32b: control ephemeral public key (cleartext) -// 16b: message tag (authenticates the whole message) -type responseMessage [51]byte - -func mkResponseMessage() responseMessage { - var ret responseMessage - ret[0] = msgTypeResponse - binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) - return ret -} - -func (m *responseMessage) Header() []byte { return m[:headerLen] } -func (m *responseMessage) Payload() []byte { return m[headerLen:] } - -func (m *responseMessage) Type() byte { return m[0] } -func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } - -func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import "encoding/binary"
+
+const (
+ // msgTypeInitiation frames carry a Noise IK handshake initiation message.
+ msgTypeInitiation = 1
+ // msgTypeResponse frames carry a Noise IK handshake response message.
+ msgTypeResponse = 2
+ // msgTypeError frames carry an unauthenticated human-readable
+ // error message.
+ //
+ // Errors reported in this message type must be treated as public
+ // hints only. They are not encrypted or authenticated, and so can
+ // be seen and tampered with on the wire.
+ msgTypeError = 3
+ // msgTypeRecord frames carry session data bytes.
+ msgTypeRecord = 4
+
+ // headerLen is the size of the header on all messages except msgTypeInitiation.
+ headerLen = 3
+ // initiationHeaderLen is the size of the header on all msgTypeInitiation messages.
+ initiationHeaderLen = 5
+)
+
+// initiationMessage is the protocol message sent from a client
+// machine to a control server.
+//
+// 2b: protocol version
+// 1b: message type (0x01)
+// 2b: payload length (96)
+// 5b: header (see headerLen for fields)
+// 32b: client ephemeral public key (cleartext)
+// 48b: client machine public key (encrypted)
+// 16b: message tag (authenticates the whole message)
+type initiationMessage [101]byte
+
+func mkInitiationMessage(protocolVersion uint16) initiationMessage {
+ var ret initiationMessage
+ binary.BigEndian.PutUint16(ret[:2], protocolVersion)
+ ret[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] }
+func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] }
+
+func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) }
+func (m *initiationMessage) Type() byte { return m[2] }
+func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) }
+
+func (m *initiationMessage) EphemeralPub() []byte {
+ return m[initiationHeaderLen : initiationHeaderLen+32]
+}
+func (m *initiationMessage) MachinePub() []byte {
+ return m[initiationHeaderLen+32 : initiationHeaderLen+32+48]
+}
+func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] }
+
+// responseMessage is the protocol message sent from a control server
+// to a client machine.
+//
+// 1b: message type (0x02)
+// 2b: payload length (48)
+// 32b: control ephemeral public key (cleartext)
+// 16b: message tag (authenticates the whole message)
+type responseMessage [51]byte
+
+func mkResponseMessage() responseMessage {
+ var ret responseMessage
+ ret[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *responseMessage) Header() []byte { return m[:headerLen] }
+func (m *responseMessage) Payload() []byte { return m[headerLen:] }
+
+func (m *responseMessage) Type() byte { return m[0] }
+func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) }
+
+func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] }
+func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }
|
