diff options
| author | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
|---|---|---|
| committer | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
| commit | 0267fe83b200f1702a2fa0a395442c02a053fadb (patch) | |
| tree | 63654c55225eeb834de59a5a0bc8d19033c6145b /control | |
| parent | 87546a5edf6b6503a87eeb2d666baba57398a066 (diff) | |
| download | tailscale-1.78.0.tar.xz tailscale-1.78.0.zip | |
VERSION.txt: this is v1.78.0v1.78.0
Signed-off-by: Nick Khyl <nickk@tailscale.com>
Diffstat (limited to 'control')
| -rw-r--r-- | control/controlbase/conn.go | 816 | ||||
| -rw-r--r-- | control/controlbase/handshake.go | 988 | ||||
| -rw-r--r-- | control/controlbase/interop_test.go | 512 | ||||
| -rw-r--r-- | control/controlbase/messages.go | 174 | ||||
| -rw-r--r-- | control/controlclient/sign.go | 84 | ||||
| -rw-r--r-- | control/controlclient/sign_supported_test.go | 472 | ||||
| -rw-r--r-- | control/controlclient/sign_unsupported.go | 32 | ||||
| -rw-r--r-- | control/controlclient/status.go | 250 | ||||
| -rw-r--r-- | control/controlhttp/client_common.go | 34 |
9 files changed, 1681 insertions, 1681 deletions
diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index dc22212e8..b6fc53b3a 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -1,408 +1,408 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package controlbase implements the base transport of the Tailscale -// 2021 control protocol. -// -// The base transport implements Noise IK, instantiated with -// Curve25519, ChaCha20Poly1305 and BLAKE2s. -package controlbase - -import ( - "crypto/cipher" - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "tailscale.com/types/key" -) - -const ( - // maxMessageSize is the maximum size of a protocol frame on the - // wire, including header and payload. - maxMessageSize = 4096 - // maxCiphertextSize is the maximum amount of ciphertext bytes - // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - 3 - // maxPlaintextSize is the maximum amount of plaintext bytes that - // one protocol frame can carry, after encryption and framing. - maxPlaintextSize = maxCiphertextSize - chp.Overhead -) - -// A Conn is a secured Noise connection. It implements the net.Conn -// interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) causes all future writes to -// fail. -type Conn struct { - conn net.Conn - version uint16 - peer key.MachinePublic - handshakeHash [blake2s.Size]byte - rx rxState - tx txState -} - -// rxState is all the Conn state that Read uses. -type rxState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - buf *maxMsgBuffer // or nil when reads exhausted - n int // number of valid bytes in buf - next int // offset of next undecrypted packet - plaintext []byte // slice into buf of decrypted bytes - hdrBuf [headerLen]byte // small buffer used when buf is nil -} - -// txState is all the Conn state that Write uses. -type txState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - err error // records the first partial write error for all future calls -} - -// ProtocolVersion returns the protocol version that was used to -// establish this Conn. -func (c *Conn) ProtocolVersion() int { - return int(c.version) -} - -// HandshakeHash returns the Noise handshake hash for the connection, -// which can be used to bind other messages to this connection -// (i.e. to ensure that the message wasn't replayed from a different -// connection). -func (c *Conn) HandshakeHash() [blake2s.Size]byte { - return c.handshakeHash -} - -// Peer returns the peer's long-term public key. -func (c *Conn) Peer() key.MachinePublic { - return c.peer -} - -// readNLocked reads into c.rx.buf until buf contains at least total -// bytes. Returns a slice of the total bytes in rxBuf, or an -// error if fewer than total bytes are available. -// -// It may be called with a nil c.rx.buf only if total == headerLen. -// -// On success, c.rx.buf will be non-nil. -func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxMessageSize { - return nil, errReadTooBig{total} - } - for { - if total <= c.rx.n { - return c.rx.buf[:total], nil - } - var n int - var err error - if c.rx.buf == nil { - if c.rx.n != 0 || total != headerLen { - panic("unexpected") - } - // Optimization to reduce memory usage. - // Most connections are blocked forever waiting for - // a read, so we don't want c.rx.buf to be allocated until - // we know there's data to read. Instead, when we're - // waiting for data to arrive here, read into the - // 3 byte hdrBuf: - n, err = c.conn.Read(c.rx.hdrBuf[:]) - if n > 0 { - c.rx.buf = getMaxMsgBuffer() - copy(c.rx.buf[:], c.rx.hdrBuf[:n]) - } - } else { - n, err = c.conn.Read(c.rx.buf[c.rx.n:]) - } - c.rx.n += n - if err != nil { - return nil, err - } - } -} - -// decryptLocked decrypts msg (which is header+ciphertext) in-place -// and sets c.rx.plaintext to the decrypted bytes. -func (c *Conn) decryptLocked(msg []byte) (err error) { - if msgType := msg[0]; msgType != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) - } - // We don't check the length field here, because the caller - // already did in order to figure out how big the msg slice should - // be. - ciphertext := msg[headerLen:] - - if !c.rx.nonce.Valid() { - return errCipherExhausted{} - } - - c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) - c.rx.nonce.Increment() - - if err != nil { - // Once a decryption has failed, our Conn is no longer - // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. Future - // read attempts will return net.ErrClosed. - c.rx.cipher = nil - } - return err -} - -// encryptLocked encrypts plaintext into buf (including the -// packet header) and returns a slice of the ciphertext, or an error -// if the cipher is exhausted (i.e. can no longer be used safely). -func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { - if !c.tx.nonce.Valid() { - // Received 2^64-1 messages on this cipher state. Connection - // is no longer usable. - return nil, errCipherExhausted{} - } - - buf[0] = msgTypeRecord - binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) - ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) - c.tx.nonce.Increment() - - return ret, nil -} - -// wholeMessageLocked returns a slice of one whole Noise transport -// message from c.rx.buf, if one whole message is available, and -// advances the read state to the next Noise message in the -// buffer. Returns nil without advancing read state if there isn't one -// whole message in c.rx.buf. -func (c *Conn) wholeMessageLocked() []byte { - available := c.rx.n - c.rx.next - if available < headerLen { - return nil - } - bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - if len(bs) < totalSize { - return nil - } - c.rx.next += totalSize - return bs[:totalSize] -} - -// decryptOneLocked decrypts one Noise transport message, reading from -// c.conn as needed, and sets c.rx.plaintext to point to the decrypted -// bytes. c.rx.plaintext is only valid if err == nil. -func (c *Conn) decryptOneLocked() error { - c.rx.plaintext = nil - - // Fast path: do we have one whole ciphertext frame buffered - // already? - if bs := c.wholeMessageLocked(); bs != nil { - return c.decryptLocked(bs) - } - - if c.rx.next != 0 { - // To simplify the read logic, move the remainder of the - // buffered bytes back to the head of the buffer, so we can - // grow it without worrying about wraparound. - c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.next = 0 - } - - // Return our buffer to the pool if it's empty, lest we be - // blocked in a long Read call, reading the 3 byte header. We - // don't to keep that buffer unnecessarily alive. - if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { - bufPool.Put(c.rx.buf) - c.rx.buf = nil - } - - bs, err := c.readNLocked(headerLen) - if err != nil { - return err - } - // The rest of the header (besides the length field) gets verified - // in decryptLocked, not here. - messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - bs, err = c.readNLocked(messageLen) - if err != nil { - return err - } - - c.rx.next = len(bs) - - return c.decryptLocked(bs) -} - -// Read implements io.Reader. -func (c *Conn) Read(bs []byte) (int, error) { - c.rx.Lock() - defer c.rx.Unlock() - - if c.rx.cipher == nil { - return 0, net.ErrClosed - } - // If no plaintext is buffered, decrypt incoming frames until we - // have some plaintext. Zero-byte Noise frames are allowed in this - // protocol, which is why we have to loop here rather than decrypt - // a single additional frame. - for len(c.rx.plaintext) == 0 { - if err := c.decryptOneLocked(); err != nil { - return 0, err - } - } - n := copy(bs, c.rx.plaintext) - c.rx.plaintext = c.rx.plaintext[n:] - - // Lose slice's underlying array pointer to unneeded memory so - // GC can collect more. - if len(c.rx.plaintext) == 0 { - c.rx.plaintext = nil - } - return n, nil -} - -// Write implements io.Writer. -func (c *Conn) Write(bs []byte) (n int, err error) { - c.tx.Lock() - defer c.tx.Unlock() - - if c.tx.err != nil { - return 0, c.tx.err - } - defer func() { - if err != nil { - // All write errors are fatal for this conn, so clear the - // cipher state whenever an error happens. - c.tx.cipher = nil - } - if c.tx.err == nil { - // Only set c.tx.err if not nil so that we can return one - // error on the first failure, and a different one for - // subsequent calls. See the error handling around Write - // below for why. - c.tx.err = err - } - }() - - if c.tx.cipher == nil { - return 0, net.ErrClosed - } - - buf := getMaxMsgBuffer() - defer bufPool.Put(buf) - - var sent int - for len(bs) > 0 { - toSend := bs - if len(toSend) > maxPlaintextSize { - toSend = bs[:maxPlaintextSize] - } - bs = bs[len(toSend):] - - ciphertext, err := c.encryptLocked(toSend, buf) - if err != nil { - return sent, err - } - if _, err := c.conn.Write(ciphertext); err != nil { - // Return the raw error on the Write that actually - // failed. For future writes, return that error wrapped in - // a desync error. - c.tx.err = errPartialWrite{err} - return sent, err - } - sent += len(toSend) - } - return sent, nil -} - -// Close implements io.Closer. -func (c *Conn) Close() error { - closeErr := c.conn.Close() // unblocks any waiting reads or writes - - // Remove references to live cipher state. Strictly speaking this - // is unnecessary, but we want to try and hand the active cipher - // state to the garbage collector promptly, to preserve perfect - // forward secrecy as much as we can. - c.rx.Lock() - c.rx.cipher = nil - c.rx.Unlock() - c.tx.Lock() - c.tx.cipher = nil - c.tx.Unlock() - return closeErr -} - -func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } - -// errCipherExhausted is the error returned when we run out of nonces -// on a cipher. -type errCipherExhausted struct{} - -func (errCipherExhausted) Error() string { - return "cipher exhausted, no more nonces available for current key" -} -func (errCipherExhausted) Timeout() bool { return false } -func (errCipherExhausted) Temporary() bool { return false } - -// errPartialWrite is the error returned when the cipher state has -// become unusable due to a past partial write. -type errPartialWrite struct { - err error -} - -func (e errPartialWrite) Error() string { - return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) -} -func (e errPartialWrite) Unwrap() error { return e.err } -func (e errPartialWrite) Temporary() bool { return false } -func (e errPartialWrite) Timeout() bool { return false } - -// errReadTooBig is the error returned when the peer sent an -// unacceptably large Noise frame. -type errReadTooBig struct { - requested int -} - -func (e errReadTooBig) Error() string { - return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) -} -func (e errReadTooBig) Temporary() bool { - // permanent error because this error only occurs when our peer - // sends us a frame so large we're unwilling to ever decode it. - return false -} -func (e errReadTooBig) Timeout() bool { return false } - -type nonce [chp.NonceSize]byte - -func (n *nonce) Valid() bool { - return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce -} - -func (n *nonce) Increment() { - if !n.Valid() { - panic("increment of invalid nonce") - } - binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) -} - -type maxMsgBuffer [maxMessageSize]byte - -// bufPool holds the temporary buffers for Conn.Read & Write. -var bufPool = &sync.Pool{ - New: func() any { - return new(maxMsgBuffer) - }, -} - -func getMaxMsgBuffer() *maxMsgBuffer { - return bufPool.Get().(*maxMsgBuffer) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package controlbase implements the base transport of the Tailscale
+// 2021 control protocol.
+//
+// The base transport implements Noise IK, instantiated with
+// Curve25519, ChaCha20Poly1305 and BLAKE2s.
+package controlbase
+
+import (
+ "crypto/cipher"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "tailscale.com/types/key"
+)
+
+const (
+ // maxMessageSize is the maximum size of a protocol frame on the
+ // wire, including header and payload.
+ maxMessageSize = 4096
+ // maxCiphertextSize is the maximum amount of ciphertext bytes
+ // that one protocol frame can carry, after framing.
+ maxCiphertextSize = maxMessageSize - 3
+ // maxPlaintextSize is the maximum amount of plaintext bytes that
+ // one protocol frame can carry, after encryption and framing.
+ maxPlaintextSize = maxCiphertextSize - chp.Overhead
+)
+
+// A Conn is a secured Noise connection. It implements the net.Conn
+// interface, with the unusual trait that any write error (including a
+// SetWriteDeadline induced i/o timeout) causes all future writes to
+// fail.
+type Conn struct {
+ conn net.Conn
+ version uint16
+ peer key.MachinePublic
+ handshakeHash [blake2s.Size]byte
+ rx rxState
+ tx txState
+}
+
+// rxState is all the Conn state that Read uses.
+type rxState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ buf *maxMsgBuffer // or nil when reads exhausted
+ n int // number of valid bytes in buf
+ next int // offset of next undecrypted packet
+ plaintext []byte // slice into buf of decrypted bytes
+ hdrBuf [headerLen]byte // small buffer used when buf is nil
+}
+
+// txState is all the Conn state that Write uses.
+type txState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ err error // records the first partial write error for all future calls
+}
+
+// ProtocolVersion returns the protocol version that was used to
+// establish this Conn.
+func (c *Conn) ProtocolVersion() int {
+ return int(c.version)
+}
+
+// HandshakeHash returns the Noise handshake hash for the connection,
+// which can be used to bind other messages to this connection
+// (i.e. to ensure that the message wasn't replayed from a different
+// connection).
+func (c *Conn) HandshakeHash() [blake2s.Size]byte {
+ return c.handshakeHash
+}
+
+// Peer returns the peer's long-term public key.
+func (c *Conn) Peer() key.MachinePublic {
+ return c.peer
+}
+
+// readNLocked reads into c.rx.buf until buf contains at least total
+// bytes. Returns a slice of the total bytes in rxBuf, or an
+// error if fewer than total bytes are available.
+//
+// It may be called with a nil c.rx.buf only if total == headerLen.
+//
+// On success, c.rx.buf will be non-nil.
+func (c *Conn) readNLocked(total int) ([]byte, error) {
+ if total > maxMessageSize {
+ return nil, errReadTooBig{total}
+ }
+ for {
+ if total <= c.rx.n {
+ return c.rx.buf[:total], nil
+ }
+ var n int
+ var err error
+ if c.rx.buf == nil {
+ if c.rx.n != 0 || total != headerLen {
+ panic("unexpected")
+ }
+ // Optimization to reduce memory usage.
+ // Most connections are blocked forever waiting for
+ // a read, so we don't want c.rx.buf to be allocated until
+ // we know there's data to read. Instead, when we're
+ // waiting for data to arrive here, read into the
+ // 3 byte hdrBuf:
+ n, err = c.conn.Read(c.rx.hdrBuf[:])
+ if n > 0 {
+ c.rx.buf = getMaxMsgBuffer()
+ copy(c.rx.buf[:], c.rx.hdrBuf[:n])
+ }
+ } else {
+ n, err = c.conn.Read(c.rx.buf[c.rx.n:])
+ }
+ c.rx.n += n
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+// decryptLocked decrypts msg (which is header+ciphertext) in-place
+// and sets c.rx.plaintext to the decrypted bytes.
+func (c *Conn) decryptLocked(msg []byte) (err error) {
+ if msgType := msg[0]; msgType != msgTypeRecord {
+ return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
+ }
+ // We don't check the length field here, because the caller
+ // already did in order to figure out how big the msg slice should
+ // be.
+ ciphertext := msg[headerLen:]
+
+ if !c.rx.nonce.Valid() {
+ return errCipherExhausted{}
+ }
+
+ c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
+ c.rx.nonce.Increment()
+
+ if err != nil {
+ // Once a decryption has failed, our Conn is no longer
+ // synchronized with our peer. Nuke the cipher state to be
+ // safe, so that no further decryptions are attempted. Future
+ // read attempts will return net.ErrClosed.
+ c.rx.cipher = nil
+ }
+ return err
+}
+
+// encryptLocked encrypts plaintext into buf (including the
+// packet header) and returns a slice of the ciphertext, or an error
+// if the cipher is exhausted (i.e. can no longer be used safely).
+func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
+ if !c.tx.nonce.Valid() {
+ // Received 2^64-1 messages on this cipher state. Connection
+ // is no longer usable.
+ return nil, errCipherExhausted{}
+ }
+
+ buf[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
+ ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
+ c.tx.nonce.Increment()
+
+ return ret, nil
+}
+
+// wholeMessageLocked returns a slice of one whole Noise transport
+// message from c.rx.buf, if one whole message is available, and
+// advances the read state to the next Noise message in the
+// buffer. Returns nil without advancing read state if there isn't one
+// whole message in c.rx.buf.
+func (c *Conn) wholeMessageLocked() []byte {
+ available := c.rx.n - c.rx.next
+ if available < headerLen {
+ return nil
+ }
+ bs := c.rx.buf[c.rx.next:c.rx.n]
+ totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ if len(bs) < totalSize {
+ return nil
+ }
+ c.rx.next += totalSize
+ return bs[:totalSize]
+}
+
+// decryptOneLocked decrypts one Noise transport message, reading from
+// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
+// bytes. c.rx.plaintext is only valid if err == nil.
+func (c *Conn) decryptOneLocked() error {
+ c.rx.plaintext = nil
+
+ // Fast path: do we have one whole ciphertext frame buffered
+ // already?
+ if bs := c.wholeMessageLocked(); bs != nil {
+ return c.decryptLocked(bs)
+ }
+
+ if c.rx.next != 0 {
+ // To simplify the read logic, move the remainder of the
+ // buffered bytes back to the head of the buffer, so we can
+ // grow it without worrying about wraparound.
+ c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
+ c.rx.next = 0
+ }
+
+ // Return our buffer to the pool if it's empty, lest we be
+ // blocked in a long Read call, reading the 3 byte header. We
+ // don't to keep that buffer unnecessarily alive.
+ if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
+ bufPool.Put(c.rx.buf)
+ c.rx.buf = nil
+ }
+
+ bs, err := c.readNLocked(headerLen)
+ if err != nil {
+ return err
+ }
+ // The rest of the header (besides the length field) gets verified
+ // in decryptLocked, not here.
+ messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ bs, err = c.readNLocked(messageLen)
+ if err != nil {
+ return err
+ }
+
+ c.rx.next = len(bs)
+
+ return c.decryptLocked(bs)
+}
+
+// Read implements io.Reader.
+func (c *Conn) Read(bs []byte) (int, error) {
+ c.rx.Lock()
+ defer c.rx.Unlock()
+
+ if c.rx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+ // If no plaintext is buffered, decrypt incoming frames until we
+ // have some plaintext. Zero-byte Noise frames are allowed in this
+ // protocol, which is why we have to loop here rather than decrypt
+ // a single additional frame.
+ for len(c.rx.plaintext) == 0 {
+ if err := c.decryptOneLocked(); err != nil {
+ return 0, err
+ }
+ }
+ n := copy(bs, c.rx.plaintext)
+ c.rx.plaintext = c.rx.plaintext[n:]
+
+ // Lose slice's underlying array pointer to unneeded memory so
+ // GC can collect more.
+ if len(c.rx.plaintext) == 0 {
+ c.rx.plaintext = nil
+ }
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (c *Conn) Write(bs []byte) (n int, err error) {
+ c.tx.Lock()
+ defer c.tx.Unlock()
+
+ if c.tx.err != nil {
+ return 0, c.tx.err
+ }
+ defer func() {
+ if err != nil {
+ // All write errors are fatal for this conn, so clear the
+ // cipher state whenever an error happens.
+ c.tx.cipher = nil
+ }
+ if c.tx.err == nil {
+ // Only set c.tx.err if not nil so that we can return one
+ // error on the first failure, and a different one for
+ // subsequent calls. See the error handling around Write
+ // below for why.
+ c.tx.err = err
+ }
+ }()
+
+ if c.tx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+
+ buf := getMaxMsgBuffer()
+ defer bufPool.Put(buf)
+
+ var sent int
+ for len(bs) > 0 {
+ toSend := bs
+ if len(toSend) > maxPlaintextSize {
+ toSend = bs[:maxPlaintextSize]
+ }
+ bs = bs[len(toSend):]
+
+ ciphertext, err := c.encryptLocked(toSend, buf)
+ if err != nil {
+ return sent, err
+ }
+ if _, err := c.conn.Write(ciphertext); err != nil {
+ // Return the raw error on the Write that actually
+ // failed. For future writes, return that error wrapped in
+ // a desync error.
+ c.tx.err = errPartialWrite{err}
+ return sent, err
+ }
+ sent += len(toSend)
+ }
+ return sent, nil
+}
+
+// Close implements io.Closer.
+func (c *Conn) Close() error {
+ closeErr := c.conn.Close() // unblocks any waiting reads or writes
+
+ // Remove references to live cipher state. Strictly speaking this
+ // is unnecessary, but we want to try and hand the active cipher
+ // state to the garbage collector promptly, to preserve perfect
+ // forward secrecy as much as we can.
+ c.rx.Lock()
+ c.rx.cipher = nil
+ c.rx.Unlock()
+ c.tx.Lock()
+ c.tx.cipher = nil
+ c.tx.Unlock()
+ return closeErr
+}
+
+func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
+func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
+func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
+func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
+func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
+
+// errCipherExhausted is the error returned when we run out of nonces
+// on a cipher.
+type errCipherExhausted struct{}
+
+func (errCipherExhausted) Error() string {
+ return "cipher exhausted, no more nonces available for current key"
+}
+func (errCipherExhausted) Timeout() bool { return false }
+func (errCipherExhausted) Temporary() bool { return false }
+
+// errPartialWrite is the error returned when the cipher state has
+// become unusable due to a past partial write.
+type errPartialWrite struct {
+ err error
+}
+
+func (e errPartialWrite) Error() string {
+ return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
+}
+func (e errPartialWrite) Unwrap() error { return e.err }
+func (e errPartialWrite) Temporary() bool { return false }
+func (e errPartialWrite) Timeout() bool { return false }
+
+// errReadTooBig is the error returned when the peer sent an
+// unacceptably large Noise frame.
+type errReadTooBig struct {
+ requested int
+}
+
+func (e errReadTooBig) Error() string {
+ return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
+}
+func (e errReadTooBig) Temporary() bool {
+ // permanent error because this error only occurs when our peer
+ // sends us a frame so large we're unwilling to ever decode it.
+ return false
+}
+func (e errReadTooBig) Timeout() bool { return false }
+
+type nonce [chp.NonceSize]byte
+
+func (n *nonce) Valid() bool {
+ return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
+}
+
+func (n *nonce) Increment() {
+ if !n.Valid() {
+ panic("increment of invalid nonce")
+ }
+ binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
+}
+
+type maxMsgBuffer [maxMessageSize]byte
+
+// bufPool holds the temporary buffers for Conn.Read & Write.
+var bufPool = &sync.Pool{
+ New: func() any {
+ return new(maxMsgBuffer)
+ },
+}
+
+func getMaxMsgBuffer() *maxMsgBuffer {
+ return bufPool.Get().(*maxMsgBuffer)
+}
diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 765a4620b..937969a30 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ "go4.org/mem"
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+ "tailscale.com/types/key"
+)
+
+const (
+ // protocolName is the name of the specific instantiation of Noise
+ // that the control protocol uses. This string's value is fixed by
+ // the Noise spec, and shouldn't be changed unless we're updating
+ // the control protocol to use a different Noise instance.
+ protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+ // protocolVersion is the version of the control protocol that
+ // Client will use when initiating a handshake.
+ //protocolVersion uint16 = 1
+ // protocolVersionPrefix is the name portion of the protocol
+ // name+version string that gets mixed into the handshake as a
+ // prologue.
+ //
+ // This mixing verifies that both clients agree that they're
+ // executing the control protocol at a specific version that
+ // matches the advertised version in the cleartext packet header.
+ protocolVersionPrefix = "Tailscale Control Protocol v"
+ invalidNonce = ^uint64(0)
+)
+
+func protocolVersionPrologue(version uint16) []byte {
+ ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
+ ret = append(ret, protocolVersionPrefix...)
+ return strconv.AppendUint(ret, uint64(version), 10)
+}
+
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
+// is assumed to have already sent the client>server handshake
+// initiation message.
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
+
+// ClientDeferred initiates a control client handshake, returning the
+// initial message to send to the server and a continuation to
+// finalize the handshake.
+//
+// ClientDeferred is split in this way for RTT reduction: we run this
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
+// completely serialized the negotiation followed by the handshake,
+// we'd pay an extra RTT to transmit the handshake initiation after
+// protocol switching. By splitting the handshake into an initial
+// message and a continuation, we can embed the handshake initiation
+// into the HTTP protocol switching request and avoid a bit of delay.
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
+ var s symmetricState
+ s.Initialize()
+
+ // prologue
+ s.MixHash(protocolVersionPrologue(protocolVersion))
+
+ // <- s
+ // ...
+ s.MixHash(controlKey.UntypedBytes())
+
+ // -> e, es, s, ss
+ init := mkInitiationMessage(protocolVersion)
+ machineEphemeral := key.NewMachine()
+ machineEphemeralPub := machineEphemeral.Public()
+ copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(machineEphemeral, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing es: %w", err)
+ }
+ machineKeyPub := machineKey.Public()
+ s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
+ cipher, err = s.MixDH(machineKey, controlKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
+ }
+ s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
+
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
+ }
+ return init[:], cont, nil
+}
+
+// Client wraps ClientDeferred and immediately invokes the returned
+// continuation with conn.
+//
+// This is a helper for when you don't need the fancy
+// continuation-style handshake, and just want to synchronously
+// upgrade a net.Conn to a secure transport.
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(init); err != nil {
+ return nil, err
+ }
+ return cont(ctx, conn)
+}
+
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
+ // No matter what, this function can only run once per s. Ensure
+ // attempted reuse causes a panic.
+ defer func() {
+ s.finished = true
+ }()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Read in the payload and look for errors/protocol violations from the server.
+ var resp responseMessage
+ if _, err := io.ReadFull(conn, resp.Header()); err != nil {
+ return nil, fmt.Errorf("reading response header: %w", err)
+ }
+ if resp.Type() != msgTypeResponse {
+ if resp.Type() != msgTypeError {
+ return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
+ }
+ msg := make([]byte, resp.Length())
+ if _, err := io.ReadFull(conn, msg); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("server error: %q", msg)
+ }
+ if resp.Length() != len(resp.Payload()) {
+ return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
+ }
+ if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
+ return nil, err
+ }
+
+ // <- e, ee, se
+ controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting payload: %w", err)
+ }
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: protocolVersion,
+ peer: controlKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c1,
+ },
+ rx: rxState{
+ cipher: c2,
+ },
+ }
+ return c, nil
+}
+
+// Server initiates a control server handshake, returning the resulting
+// control connection.
+//
+// optionalInit can be the client's initial handshake message as
+// returned by ClientDeferred, or nil in which case the initial
+// message is read from conn.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
+ }
+ defer func() {
+ conn.SetDeadline(time.Time{})
+ }()
+ }
+
+ // Deliberately does not support formatting, so that we don't echo
+ // attacker-controlled input back to them.
+ sendErr := func(msg string) error {
+ if len(msg) >= 1<<16 {
+ msg = msg[:1<<16]
+ }
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeError
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ if _, err := io.WriteString(conn, msg); err != nil {
+ return fmt.Errorf("sending %q error to client: %w", msg, err)
+ }
+ return fmt.Errorf("refused client handshake: %q", msg)
+ }
+
+ var s symmetricState
+ s.Initialize()
+
+ var init initiationMessage
+ if optionalInit != nil {
+ if len(optionalInit) != len(init) {
+ return nil, sendErr("wrong handshake initiation size")
+ }
+ copy(init[:], optionalInit)
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
+ return nil, err
+ }
+ // Just a rename to make it more obvious what the value is. In the
+ // current implementation we don't need to block any protocol
+ // versions at this layer, it's safe to let the handshake proceed
+ // and then let the caller make decisions based on the agreed-upon
+ // protocol version.
+ clientVersion := init.Version()
+ if init.Type() != msgTypeInitiation {
+ return nil, sendErr("unexpected handshake message type")
+ }
+ if init.Length() != len(init.Payload()) {
+ return nil, sendErr("wrong handshake initiation length")
+ }
+ // if optionalInit was provided, we have the payload already.
+ if optionalInit == nil {
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
+ return nil, err
+ }
+ }
+
+ // prologue. Can only do this once we at least think the client is
+ // handshaking using a supported version.
+ s.MixHash(protocolVersionPrologue(clientVersion))
+
+ // <- s
+ // ...
+ controlKeyPub := controlKey.Public()
+ s.MixHash(controlKeyPub.UntypedBytes())
+
+ // -> e, es, s, ss
+ machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
+ s.MixHash(machineEphemeralPub.UntypedBytes())
+ cipher, err := s.MixDH(controlKey, machineEphemeralPub)
+ if err != nil {
+ return nil, fmt.Errorf("computing es: %w", err)
+ }
+ var machineKeyBytes [32]byte
+ if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
+ return nil, fmt.Errorf("decrypting machine key: %w", err)
+ }
+ machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
+ cipher, err = s.MixDH(controlKey, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing ss: %w", err)
+ }
+ if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
+ return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+ }
+
+ // <- e, ee, se
+ resp := mkResponseMessage()
+ controlEphemeral := key.NewMachine()
+ controlEphemeralPub := controlEphemeral.Public()
+ copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
+ s.MixHash(controlEphemeralPub.UntypedBytes())
+ if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+ return nil, fmt.Errorf("computing ee: %w", err)
+ }
+ cipher, err = s.MixDH(controlEphemeral, machineKey)
+ if err != nil {
+ return nil, fmt.Errorf("computing se: %w", err)
+ }
+ s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
+
+ c1, c2, err := s.Split()
+ if err != nil {
+ return nil, fmt.Errorf("finalizing handshake: %w", err)
+ }
+
+ if _, err := conn.Write(resp[:]); err != nil {
+ return nil, err
+ }
+
+ c := &Conn{
+ conn: conn,
+ version: clientVersion,
+ peer: machineKey,
+ handshakeHash: s.h,
+ tx: txState{
+ cipher: c2,
+ },
+ rx: rxState{
+ cipher: c1,
+ },
+ }
+ return c, nil
+}
+
+// symmetricState contains the state of an in-flight handshake.
+type symmetricState struct {
+ finished bool
+
+ h [blake2s.Size]byte // hash of currently-processed handshake state
+ ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
+}
+
+func (s *symmetricState) checkFinished() {
+ if s.finished {
+ panic("attempted to use symmetricState after Split was called")
+ }
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any handshake messages.
+func (s *symmetricState) Initialize() {
+ s.checkFinished()
+ s.h = blake2s.Sum256([]byte(protocolName))
+ s.ck = s.h
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+ s.checkFinished()
+ h := newBLAKE2s()
+ h.Write(s.h[:])
+ h.Write(data)
+ h.Sum(s.h[:0])
+}
+
+// MixDH updates s.ck with the result of X25519(priv, pub) and returns
+// a singleUseCHP that can be used to encrypt or decrypt handshake
+// data.
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
+ s.checkFinished()
+ keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
+ if err != nil {
+ return nil, fmt.Errorf("computing X25519: %w", err)
+ }
+
+ r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+ if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+ return nil, fmt.Errorf("extracting ck: %w", err)
+ }
+ var k [chp.KeySize]byte
+ if _, err := io.ReadFull(r, k[:]); err != nil {
+ return nil, fmt.Errorf("extracting k: %w", err)
+ }
+ return newSingleUseCHP(k), nil
+}
+
+// EncryptAndHash encrypts plaintext into ciphertext (which must be
+// the correct size to hold the encrypted plaintext) using cipher,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ panic("ciphertext is wrong size for given plaintext")
+ }
+ ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
+ s.MixHash(ret)
+}
+
+// DecryptAndHash decrypts the given ciphertext into plaintext (which
+// must be the correct size to hold the decrypted ciphertext) using
+// cipher. If decryption is successful, it mixes the ciphertext into
+// s.h.
+func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
+ s.checkFinished()
+ if len(ciphertext) != len(plaintext)+chp.Overhead {
+ return errors.New("plaintext is wrong size for given ciphertext")
+ }
+ if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
+ return err
+ }
+ s.MixHash(ciphertext)
+ return nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derived from
+// the current handshake state. Methods on s cannot be used again
+// after calling Split.
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+ s.finished = true
+
+ var k1, k2 [chp.KeySize]byte
+ r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+ if _, err := io.ReadFull(r, k1[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k1: %w", err)
+ }
+ if _, err := io.ReadFull(r, k2[:]); err != nil {
+ return nil, nil, fmt.Errorf("extracting k2: %w", err)
+ }
+ c1, err = chp.New(k1[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+ }
+ c2, err = chp.New(k2[:])
+ if err != nil {
+ return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+ }
+ return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+ h, err := blake2s.New256(nil)
+ if err != nil {
+ // Should never happen, errors only happen when using BLAKE2s
+ // in MAC mode with a key.
+ panic(err)
+ }
+ return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+ aead, err := chp.New(key[:])
+ if err != nil {
+ // Can only happen if we passed a key of the wrong length. The
+ // function signature prevents that.
+ panic(err)
+ }
+ return aead
+}
+
+// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
+// only once, either for encrypting or decrypting, but not both. The
+// chosen operation is always executed with an all-zeros
+// nonce. Subsequent calls to either Seal or Open panic.
+type singleUseCHP struct {
+ c cipher.AEAD
+}
+
+func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
+ return &singleUseCHP{newCHP(key)}
+}
+
+func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Seal(dst, nonce[:], plaintext, additionalData)
+}
+
+func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
+ if c.c == nil {
+ panic("Attempted reuse of singleUseAEAD")
+ }
+ cipher := c.c
+ c.c = nil
+ var nonce [chp.NonceSize]byte
+ return cipher.Open(dst, nonce[:], ciphertext, additionalData)
+}
diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index c41fbf4dd..d11c04149 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -1,256 +1,256 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - - "tailscale.com/net/memnet" - "tailscale.com/types/key" -) - -// Can a reference Noise IK client talk to our server? -func TestInteropClient(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - serverErr = make(chan error, 2) - serverBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - server, err := Server(context.Background(), s2, controlKey, nil) - serverErr <- err - if err != nil { - return - } - var buf [1024]byte - _, err = io.ReadFull(server, buf[:len(c2s)]) - serverBytes <- buf[:len(c2s)] - if err != nil { - serverErr <- err - return - } - _, err = server.Write([]byte(s2c)) - serverErr <- err - }() - - gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) - if err != nil { - t.Fatalf("failed client interop: %v", err) - } - if string(gotS2C) != s2c { - t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) - } - - if err := <-serverErr; err != nil { - t.Fatalf("server handshake failed: %v", err) - } - if err := <-serverErr; err != nil { - t.Fatalf("server read/write failed: %v", err) - } - if got := string(<-serverBytes); got != c2s { - t.Fatalf("server received %q, want %q", got, c2s) - } -} - -// Can our client talk to a reference Noise IK server? -func TestInteropServer(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - clientErr = make(chan error, 2) - clientBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) - clientErr <- err - if err != nil { - return - } - _, err = client.Write([]byte(c2s)) - if err != nil { - clientErr <- err - return - } - var buf [1024]byte - _, err = io.ReadFull(client, buf[:len(s2c)]) - clientBytes <- buf[:len(s2c)] - clientErr <- err - }() - - gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) - if err != nil { - t.Fatalf("failed server interop: %v", err) - } - if string(gotC2S) != c2s { - t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) - } - - if err := <-clientErr; err != nil { - t.Fatalf("client handshake failed: %v", err) - } - if err := <-clientErr; err != nil { - t.Fatalf("client read/write failed: %v", err) - } - if got := string(<-clientBytes); got != s2c { - t.Fatalf("client received %q, want %q", got, s2c) - } -} - -// noiseExplorerClient uses the Noise Explorer implementation of Noise -// IK to handshake as a Noise client on conn, transmit payload, and -// read+return a payload from the peer. -func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], machineKey.UntypedBytes()) - copy(mk.public_key[:], machineKey.Public().UntypedBytes()) - var peerKey [32]byte - copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) - - _, msg1 := SendMessage(&session, nil) - var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) - hdr[2] = msgTypeInitiation - binary.BigEndian.PutUint16(hdr[3:5], 96) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ns); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ciphertext); err != nil { - return nil, err - } - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:51]); err != nil { - return nil, err - } - // ignore the header for this test, we're only checking the noise - // implementation. - msg2 := messagebuffer{ - ciphertext: buf[35:51], - } - copy(msg2.ne[:], buf[3:35]) - _, p, valid := RecvMessage(&session, &msg2) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg3 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(hdr[:3]); err != nil { - return nil, err - } - if _, err := conn.Write(msg3.ciphertext); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - // Ignore all of the header except the payload length - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg4 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg4) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - return p, nil -} - -func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], controlKey.UntypedBytes()) - copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:101]); err != nil { - return nil, err - } - // Ignore the header, we're just checking the noise implementation. - msg1 := messagebuffer{ - ns: buf[37:85], - ciphertext: buf[85:101], - } - copy(msg1.ne[:], buf[5:37]) - _, p, valid := RecvMessage(&session, &msg1) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg2 := SendMessage(&session, nil) - var hdr [headerLen]byte - hdr[0] = msgTypeResponse - binary.BigEndian.PutUint16(hdr[1:3], 48) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ciphertext[:]); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg3 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg3) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - _, msg4 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg4.ciphertext); err != nil { - return nil, err - } - - return p, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "testing"
+
+ "tailscale.com/net/memnet"
+ "tailscale.com/types/key"
+)
+
+// Can a reference Noise IK client talk to our server?
+func TestInteropClient(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ serverErr = make(chan error, 2)
+ serverBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ server, err := Server(context.Background(), s2, controlKey, nil)
+ serverErr <- err
+ if err != nil {
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(server, buf[:len(c2s)])
+ serverBytes <- buf[:len(c2s)]
+ if err != nil {
+ serverErr <- err
+ return
+ }
+ _, err = server.Write([]byte(s2c))
+ serverErr <- err
+ }()
+
+ gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
+ if err != nil {
+ t.Fatalf("failed client interop: %v", err)
+ }
+ if string(gotS2C) != s2c {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
+ }
+
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server handshake failed: %v", err)
+ }
+ if err := <-serverErr; err != nil {
+ t.Fatalf("server read/write failed: %v", err)
+ }
+ if got := string(<-serverBytes); got != c2s {
+ t.Fatalf("server received %q, want %q", got, c2s)
+ }
+}
+
+// Can our client talk to a reference Noise IK server?
+func TestInteropServer(t *testing.T) {
+ var (
+ s1, s2 = memnet.NewConn("noise", 128000)
+ controlKey = key.NewMachine()
+ machineKey = key.NewMachine()
+ clientErr = make(chan error, 2)
+ clientBytes = make(chan []byte, 1)
+ c2s = "client>server"
+ s2c = "server>client"
+ )
+
+ go func() {
+ client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
+ clientErr <- err
+ if err != nil {
+ return
+ }
+ _, err = client.Write([]byte(c2s))
+ if err != nil {
+ clientErr <- err
+ return
+ }
+ var buf [1024]byte
+ _, err = io.ReadFull(client, buf[:len(s2c)])
+ clientBytes <- buf[:len(s2c)]
+ clientErr <- err
+ }()
+
+ gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
+ if err != nil {
+ t.Fatalf("failed server interop: %v", err)
+ }
+ if string(gotC2S) != c2s {
+ t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
+ }
+
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client handshake failed: %v", err)
+ }
+ if err := <-clientErr; err != nil {
+ t.Fatalf("client read/write failed: %v", err)
+ }
+ if got := string(<-clientBytes); got != s2c {
+ t.Fatalf("client received %q, want %q", got, s2c)
+ }
+}
+
+// noiseExplorerClient uses the Noise Explorer implementation of Noise
+// IK to handshake as a Noise client on conn, transmit payload, and
+// read+return a payload from the peer.
+func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], machineKey.UntypedBytes())
+ copy(mk.public_key[:], machineKey.Public().UntypedBytes())
+ var peerKey [32]byte
+ copy(peerKey[:], controlKey.UntypedBytes())
+ session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
+
+ _, msg1 := SendMessage(&session, nil)
+ var hdr [initiationHeaderLen]byte
+ binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
+ hdr[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(hdr[3:5], 96)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ns); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg1.ciphertext); err != nil {
+ return nil, err
+ }
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:51]); err != nil {
+ return nil, err
+ }
+ // ignore the header for this test, we're only checking the noise
+ // implementation.
+ msg2 := messagebuffer{
+ ciphertext: buf[35:51],
+ }
+ copy(msg2.ne[:], buf[3:35])
+ _, p, valid := RecvMessage(&session, &msg2)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg3 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
+ if _, err := conn.Write(hdr[:3]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg3.ciphertext); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ // Ignore all of the header except the payload length
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg4 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg4)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ return p, nil
+}
+
+func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
+ var mk keypair
+ copy(mk.private_key[:], controlKey.UntypedBytes())
+ copy(mk.public_key[:], controlKey.Public().UntypedBytes())
+ session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
+
+ var buf [1024]byte
+ if _, err := io.ReadFull(conn, buf[:101]); err != nil {
+ return nil, err
+ }
+ // Ignore the header, we're just checking the noise implementation.
+ msg1 := messagebuffer{
+ ns: buf[37:85],
+ ciphertext: buf[85:101],
+ }
+ copy(msg1.ne[:], buf[5:37])
+ _, p, valid := RecvMessage(&session, &msg1)
+ if !valid {
+ return nil, errors.New("handshake failed")
+ }
+ if len(p) != 0 {
+ return nil, errors.New("non-empty payload")
+ }
+
+ _, msg2 := SendMessage(&session, nil)
+ var hdr [headerLen]byte
+ hdr[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(hdr[1:3], 48)
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ne[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.ReadFull(conn, buf[:3]); err != nil {
+ return nil, err
+ }
+ plen := int(binary.BigEndian.Uint16(buf[1:3]))
+ if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+ return nil, err
+ }
+
+ msg3 := messagebuffer{
+ ciphertext: buf[:plen],
+ }
+ _, p, valid = RecvMessage(&session, &msg3)
+ if !valid {
+ return nil, errors.New("transport message decryption failed")
+ }
+
+ _, msg4 := SendMessage(&session, payload)
+ hdr[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
+ if _, err := conn.Write(hdr[:]); err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(msg4.ciphertext); err != nil {
+ return nil, err
+ }
+
+ return p, nil
+}
diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 59073088f..899378681 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import "encoding/binary" - -const ( - // msgTypeInitiation frames carry a Noise IK handshake initiation message. - msgTypeInitiation = 1 - // msgTypeResponse frames carry a Noise IK handshake response message. - msgTypeResponse = 2 - // msgTypeError frames carry an unauthenticated human-readable - // error message. - // - // Errors reported in this message type must be treated as public - // hints only. They are not encrypted or authenticated, and so can - // be seen and tampered with on the wire. - msgTypeError = 3 - // msgTypeRecord frames carry session data bytes. - msgTypeRecord = 4 - - // headerLen is the size of the header on all messages except msgTypeInitiation. - headerLen = 3 - // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. - initiationHeaderLen = 5 -) - -// initiationMessage is the protocol message sent from a client -// machine to a control server. -// -// 2b: protocol version -// 1b: message type (0x01) -// 2b: payload length (96) -// 5b: header (see headerLen for fields) -// 32b: client ephemeral public key (cleartext) -// 48b: client machine public key (encrypted) -// 16b: message tag (authenticates the whole message) -type initiationMessage [101]byte - -func mkInitiationMessage(protocolVersion uint16) initiationMessage { - var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) - return ret -} - -func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } -func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } - -func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } -func (m *initiationMessage) Type() byte { return m[2] } -func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } - -func (m *initiationMessage) EphemeralPub() []byte { - return m[initiationHeaderLen : initiationHeaderLen+32] -} -func (m *initiationMessage) MachinePub() []byte { - return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] -} -func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } - -// responseMessage is the protocol message sent from a control server -// to a client machine. -// -// 1b: message type (0x02) -// 2b: payload length (48) -// 32b: control ephemeral public key (cleartext) -// 16b: message tag (authenticates the whole message) -type responseMessage [51]byte - -func mkResponseMessage() responseMessage { - var ret responseMessage - ret[0] = msgTypeResponse - binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) - return ret -} - -func (m *responseMessage) Header() []byte { return m[:headerLen] } -func (m *responseMessage) Payload() []byte { return m[headerLen:] } - -func (m *responseMessage) Type() byte { return m[0] } -func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } - -func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlbase
+
+import "encoding/binary"
+
+const (
+ // msgTypeInitiation frames carry a Noise IK handshake initiation message.
+ msgTypeInitiation = 1
+ // msgTypeResponse frames carry a Noise IK handshake response message.
+ msgTypeResponse = 2
+ // msgTypeError frames carry an unauthenticated human-readable
+ // error message.
+ //
+ // Errors reported in this message type must be treated as public
+ // hints only. They are not encrypted or authenticated, and so can
+ // be seen and tampered with on the wire.
+ msgTypeError = 3
+ // msgTypeRecord frames carry session data bytes.
+ msgTypeRecord = 4
+
+ // headerLen is the size of the header on all messages except msgTypeInitiation.
+ headerLen = 3
+ // initiationHeaderLen is the size of the header on all msgTypeInitiation messages.
+ initiationHeaderLen = 5
+)
+
+// initiationMessage is the protocol message sent from a client
+// machine to a control server.
+//
+// 2b: protocol version
+// 1b: message type (0x01)
+// 2b: payload length (96)
+// 5b: header (see headerLen for fields)
+// 32b: client ephemeral public key (cleartext)
+// 48b: client machine public key (encrypted)
+// 16b: message tag (authenticates the whole message)
+type initiationMessage [101]byte
+
+func mkInitiationMessage(protocolVersion uint16) initiationMessage {
+ var ret initiationMessage
+ binary.BigEndian.PutUint16(ret[:2], protocolVersion)
+ ret[2] = msgTypeInitiation
+ binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] }
+func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] }
+
+func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) }
+func (m *initiationMessage) Type() byte { return m[2] }
+func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) }
+
+func (m *initiationMessage) EphemeralPub() []byte {
+ return m[initiationHeaderLen : initiationHeaderLen+32]
+}
+func (m *initiationMessage) MachinePub() []byte {
+ return m[initiationHeaderLen+32 : initiationHeaderLen+32+48]
+}
+func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] }
+
+// responseMessage is the protocol message sent from a control server
+// to a client machine.
+//
+// 1b: message type (0x02)
+// 2b: payload length (48)
+// 32b: control ephemeral public key (cleartext)
+// 16b: message tag (authenticates the whole message)
+type responseMessage [51]byte
+
+func mkResponseMessage() responseMessage {
+ var ret responseMessage
+ ret[0] = msgTypeResponse
+ binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload())))
+ return ret
+}
+
+func (m *responseMessage) Header() []byte { return m[:headerLen] }
+func (m *responseMessage) Payload() []byte { return m[headerLen:] }
+
+func (m *responseMessage) Type() byte { return m[0] }
+func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) }
+
+func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] }
+func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }
diff --git a/control/controlclient/sign.go b/control/controlclient/sign.go index e3a479c28..5e72f1cf4 100644 --- a/control/controlclient/sign.go +++ b/control/controlclient/sign.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "crypto" - "errors" - "fmt" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -var ( - errNoCertStore = errors.New("no certificate store") - errCertificateNotConfigured = errors.New("no certificate subject configured") - errUnsupportedSignatureVersion = errors.New("unsupported signature version") -) - -// HashRegisterRequest generates the hash required sign or verify a -// tailcfg.RegisterRequest. -func HashRegisterRequest( - version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, - serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { - h := crypto.SHA256.New() - - // hash.Hash.Write never returns an error, so we don't check for one here. - switch version { - case tailcfg.SignatureV1: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) - case tailcfg.SignatureV2: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) - default: - return nil, errUnsupportedSignatureVersion - } - - return h.Sum(nil), nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlclient
+
+import (
+ "crypto"
+ "errors"
+ "fmt"
+ "time"
+
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/key"
+)
+
+var (
+ errNoCertStore = errors.New("no certificate store")
+ errCertificateNotConfigured = errors.New("no certificate subject configured")
+ errUnsupportedSignatureVersion = errors.New("unsupported signature version")
+)
+
+// HashRegisterRequest generates the hash required sign or verify a
+// tailcfg.RegisterRequest.
+func HashRegisterRequest(
+ version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte,
+ serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) {
+ h := crypto.SHA256.New()
+
+ // hash.Hash.Write never returns an error, so we don't check for one here.
+ switch version {
+ case tailcfg.SignatureV1:
+ fmt.Fprintf(h, "%s%s%s%s%s",
+ ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString())
+ case tailcfg.SignatureV2:
+ fmt.Fprintf(h, "%s%s%s%s%s",
+ ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey)
+ default:
+ return nil, errUnsupportedSignatureVersion
+ }
+
+ return h.Sum(nil), nil
+}
diff --git a/control/controlclient/sign_supported_test.go b/control/controlclient/sign_supported_test.go index e20349a4e..ca41794d1 100644 --- a/control/controlclient/sign_supported_test.go +++ b/control/controlclient/sign_supported_test.go @@ -1,236 +1,236 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows && cgo - -package controlclient - -import ( - "crypto" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "reflect" - "testing" - "time" - - "github.com/tailscale/certstore" -) - -const ( - testRootCommonName = "testroot" - testRootSubject = "CN=testroot" -) - -type testIdentity struct { - chain []*x509.Certificate -} - -func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { - return []*x509.Certificate{ - { - NotBefore: notBefore, - NotAfter: notAfter, - PublicKeyAlgorithm: x509.RSA, - }, - { - Subject: pkix.Name{ - CommonName: rootCommonName, - }, - PublicKeyAlgorithm: x509.RSA, - }, - } -} - -func (t *testIdentity) Certificate() (*x509.Certificate, error) { - return t.chain[0], nil -} - -func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { - return t.chain, nil -} - -func (t *testIdentity) Signer() (crypto.Signer, error) { - return nil, errors.New("not implemented") -} - -func (t *testIdentity) Delete() error { - return errors.New("not implemented") -} - -func (t *testIdentity) Close() {} - -func TestSelectIdentityFromSlice(t *testing.T) { - var times []time.Time - for _, ts := range []string{ - "2000-01-01T00:00:00Z", - "2001-01-01T00:00:00Z", - "2002-01-01T00:00:00Z", - "2003-01-01T00:00:00Z", - } { - tm, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatal(err) - } - times = append(times, tm) - } - - tests := []struct { - name string - subject string - ids []certstore.Identity - now time.Time - // wantIndex is an index into ids, or -1 for nil. - wantIndex int - }{ - { - name: "single unexpired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 0, - }, - { - name: "single expired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[2]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 1, - }, - { - name: "expired with unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[3]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two certs both unexpired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two unexpired one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) - - if gotId == nil && gotChain != nil { - t.Error("id is nil: got non-nil chain, want nil chain") - return - } - if gotId != nil && gotChain == nil { - t.Error("id is not nil: got nil chain, want non-nil chain") - return - } - if tt.wantIndex == -1 { - if gotId != nil { - t.Error("got non-nil id, want nil id") - } - return - } - if gotId == nil { - t.Error("got nil id, want non-nil id") - return - } - if gotId != tt.ids[tt.wantIndex] { - found := -1 - for i := range tt.ids { - if tt.ids[i] == gotId { - found = i - break - } - } - if found == -1 { - t.Errorf("got unknown id, want id at index %v", tt.wantIndex) - } else { - t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) - } - } - - tid, ok := tt.ids[tt.wantIndex].(*testIdentity) - if !ok { - t.Error("got non-testIdentity, want testIdentity") - return - } - - if !reflect.DeepEqual(tid.chain, gotChain) { - t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows && cgo
+
+package controlclient
+
+import (
+ "crypto"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "errors"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/tailscale/certstore"
+)
+
+const (
+ testRootCommonName = "testroot"
+ testRootSubject = "CN=testroot"
+)
+
+type testIdentity struct {
+ chain []*x509.Certificate
+}
+
+func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
+ return []*x509.Certificate{
+ {
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ PublicKeyAlgorithm: x509.RSA,
+ },
+ {
+ Subject: pkix.Name{
+ CommonName: rootCommonName,
+ },
+ PublicKeyAlgorithm: x509.RSA,
+ },
+ }
+}
+
+func (t *testIdentity) Certificate() (*x509.Certificate, error) {
+ return t.chain[0], nil
+}
+
+func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
+ return t.chain, nil
+}
+
+func (t *testIdentity) Signer() (crypto.Signer, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (t *testIdentity) Delete() error {
+ return errors.New("not implemented")
+}
+
+func (t *testIdentity) Close() {}
+
+func TestSelectIdentityFromSlice(t *testing.T) {
+ var times []time.Time
+ for _, ts := range []string{
+ "2000-01-01T00:00:00Z",
+ "2001-01-01T00:00:00Z",
+ "2002-01-01T00:00:00Z",
+ "2003-01-01T00:00:00Z",
+ } {
+ tm, err := time.Parse(time.RFC3339, ts)
+ if err != nil {
+ t.Fatal(err)
+ }
+ times = append(times, tm)
+ }
+
+ tests := []struct {
+ name string
+ subject string
+ ids []certstore.Identity
+ now time.Time
+ // wantIndex is an index into ids, or -1 for nil.
+ wantIndex int
+ }{
+ {
+ name: "single unexpired identity",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[2]),
+ },
+ },
+ now: times[1],
+ wantIndex: 0,
+ },
+ {
+ name: "single expired identity",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[1]),
+ },
+ },
+ now: times[2],
+ wantIndex: -1,
+ },
+ {
+ name: "unrelated ids",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain("something", times[0], times[2]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[2]),
+ },
+ &testIdentity{
+ chain: makeChain("else", times[0], times[2]),
+ },
+ },
+ now: times[1],
+ wantIndex: 1,
+ },
+ {
+ name: "expired with unrelated ids",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain("something", times[0], times[3]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[1]),
+ },
+ &testIdentity{
+ chain: makeChain("else", times[0], times[3]),
+ },
+ },
+ now: times[2],
+ wantIndex: -1,
+ },
+ {
+ name: "one expired",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[1]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[1], times[3]),
+ },
+ },
+ now: times[2],
+ wantIndex: 1,
+ },
+ {
+ name: "two certs both unexpired",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[3]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[1], times[3]),
+ },
+ },
+ now: times[2],
+ wantIndex: 1,
+ },
+ {
+ name: "two unexpired one expired",
+ subject: testRootSubject,
+ ids: []certstore.Identity{
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[3]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[1], times[3]),
+ },
+ &testIdentity{
+ chain: makeChain(testRootCommonName, times[0], times[1]),
+ },
+ },
+ now: times[2],
+ wantIndex: 1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
+
+ if gotId == nil && gotChain != nil {
+ t.Error("id is nil: got non-nil chain, want nil chain")
+ return
+ }
+ if gotId != nil && gotChain == nil {
+ t.Error("id is not nil: got nil chain, want non-nil chain")
+ return
+ }
+ if tt.wantIndex == -1 {
+ if gotId != nil {
+ t.Error("got non-nil id, want nil id")
+ }
+ return
+ }
+ if gotId == nil {
+ t.Error("got nil id, want non-nil id")
+ return
+ }
+ if gotId != tt.ids[tt.wantIndex] {
+ found := -1
+ for i := range tt.ids {
+ if tt.ids[i] == gotId {
+ found = i
+ break
+ }
+ }
+ if found == -1 {
+ t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
+ } else {
+ t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
+ }
+ }
+
+ tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
+ if !ok {
+ t.Error("got non-testIdentity, want testIdentity")
+ return
+ }
+
+ if !reflect.DeepEqual(tid.chain, gotChain) {
+ t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
+ }
+ })
+ }
+}
diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index 5e161dcbc..4ec40d502 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package controlclient - -import ( - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// signRegisterRequest on non-supported platforms always returns errNoCertStore. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { - return errNoCertStore -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package controlclient
+
+import (
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/key"
+)
+
+// signRegisterRequest on non-supported platforms always returns errNoCertStore.
+func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error {
+ return errNoCertStore
+}
diff --git a/control/controlclient/status.go b/control/controlclient/status.go index d0fdf80d7..7dba14d3f 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "encoding/json" - "fmt" - "reflect" - - "tailscale.com/types/netmap" - "tailscale.com/types/persist" - "tailscale.com/types/structs" -) - -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. -type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - -type Status struct { - _ structs.Incomparable - - // Err, if non-nil, is an error that occurred while logging in. - // - // If it's of type UserVisibleError then it's meant to be shown to users in - // their Tailscale client. Otherwise it's just logged to tailscaled's logs. - Err error - - // URL, if non-empty, is the interactive URL to visit to finish logging in. - URL string - - // NetMap is the latest server-pushed state of the tailnet network. - NetMap *netmap.NetworkMap - - // Persist, when Valid, is the locally persisted configuration. - // - // TODO(bradfitz,maisem): clarify this. - Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State -} - -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - -// Equal reports whether s and s2 are equal. -func (s *Status) Equal(s2 *Status) bool { - if s == nil && s2 == nil { - return true - } - return s != nil && s2 != nil && - s.Err == s2.Err && - s.URL == s2.URL && - s.state == s2.state && - reflect.DeepEqual(s.Persist, s2.Persist) && - reflect.DeepEqual(s.NetMap, s2.NetMap) -} - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlclient
+
+import (
+ "encoding/json"
+ "fmt"
+ "reflect"
+
+ "tailscale.com/types/netmap"
+ "tailscale.com/types/persist"
+ "tailscale.com/types/structs"
+)
+
+// State is the high-level state of the client. It is used only in
+// unit tests for proper sequencing, don't depend on it anywhere else.
+//
+// TODO(apenwarr): eliminate the state, as it's now obsolete.
+//
+// apenwarr: Historical note: controlclient.Auto was originally
+// intended to be the state machine for the whole tailscale client, but that
+// turned out to not be the right abstraction layer, and it moved to
+// ipn.Backend. Since ipn.Backend now has a state machine, it would be
+// much better if controlclient could be a simple stateless API. But the
+// current server-side API (two interlocking polling https calls) makes that
+// very hard to implement. A server side API change could untangle this and
+// remove all the statefulness.
+type State int
+
+const (
+ StateNew = State(iota)
+ StateNotAuthenticated
+ StateAuthenticating
+ StateURLVisitRequired
+ StateAuthenticated
+ StateSynchronized // connected and received map update
+)
+
+func (s State) AppendText(b []byte) ([]byte, error) {
+ return append(b, s.String()...), nil
+}
+
+func (s State) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+func (s State) String() string {
+ switch s {
+ case StateNew:
+ return "state:new"
+ case StateNotAuthenticated:
+ return "state:not-authenticated"
+ case StateAuthenticating:
+ return "state:authenticating"
+ case StateURLVisitRequired:
+ return "state:url-visit-required"
+ case StateAuthenticated:
+ return "state:authenticated"
+ case StateSynchronized:
+ return "state:synchronized"
+ default:
+ return fmt.Sprintf("state:unknown:%d", int(s))
+ }
+}
+
+type Status struct {
+ _ structs.Incomparable
+
+ // Err, if non-nil, is an error that occurred while logging in.
+ //
+ // If it's of type UserVisibleError then it's meant to be shown to users in
+ // their Tailscale client. Otherwise it's just logged to tailscaled's logs.
+ Err error
+
+ // URL, if non-empty, is the interactive URL to visit to finish logging in.
+ URL string
+
+ // NetMap is the latest server-pushed state of the tailnet network.
+ NetMap *netmap.NetworkMap
+
+ // Persist, when Valid, is the locally persisted configuration.
+ //
+ // TODO(bradfitz,maisem): clarify this.
+ Persist persist.PersistView
+
+ // state is the internal state. It should not be exposed outside this
+ // package, but we have some automated tests elsewhere that need to
+ // use it via the StateForTest accessor.
+ // TODO(apenwarr): Unexport or remove these.
+ state State
+}
+
+// LoginFinished reports whether the controlclient is in its "StateAuthenticated"
+// state where it's in a happy register state but not yet in a map poll.
+//
+// TODO(bradfitz): delete this and everything around Status.state.
+func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated }
+
+// StateForTest returns the internal state of s for tests only.
+func (s *Status) StateForTest() State { return s.state }
+
+// SetStateForTest sets the internal state of s for tests only.
+func (s *Status) SetStateForTest(state State) { s.state = state }
+
+// Equal reports whether s and s2 are equal.
+func (s *Status) Equal(s2 *Status) bool {
+ if s == nil && s2 == nil {
+ return true
+ }
+ return s != nil && s2 != nil &&
+ s.Err == s2.Err &&
+ s.URL == s2.URL &&
+ s.state == s2.state &&
+ reflect.DeepEqual(s.Persist, s2.Persist) &&
+ reflect.DeepEqual(s.NetMap, s2.NetMap)
+}
+
+func (s Status) String() string {
+ b, err := json.MarshalIndent(s, "", "\t")
+ if err != nil {
+ panic(err)
+ }
+ return s.state.String() + " " + string(b)
+}
diff --git a/control/controlhttp/client_common.go b/control/controlhttp/client_common.go index dd94e93cd..72a89e3cd 100644 --- a/control/controlhttp/client_common.go +++ b/control/controlhttp/client_common.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlhttp - -import ( - "tailscale.com/control/controlbase" -) - -// ClientConn is a Tailscale control client as returned by the Dialer. -// -// It's effectively just a *controlbase.Conn (which it embeds) with -// optional metadata. -type ClientConn struct { - // Conn is the noise connection. - *controlbase.Conn -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package controlhttp
+
+import (
+ "tailscale.com/control/controlbase"
+)
+
+// ClientConn is a Tailscale control client as returned by the Dialer.
+//
+// It's effectively just a *controlbase.Conn (which it embeds) with
+// optional metadata.
+type ClientConn struct {
+ // Conn is the noise connection.
+ *controlbase.Conn
+}
|
