summaryrefslogtreecommitdiffhomepage
path: root/control/controlbase/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'control/controlbase/conn.go')
-rw-r--r--control/controlbase/conn.go816
1 files changed, 408 insertions, 408 deletions
diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go
index dc22212e8..b6fc53b3a 100644
--- a/control/controlbase/conn.go
+++ b/control/controlbase/conn.go
@@ -1,408 +1,408 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package controlbase implements the base transport of the Tailscale
-// 2021 control protocol.
-//
-// The base transport implements Noise IK, instantiated with
-// Curve25519, ChaCha20Poly1305 and BLAKE2s.
-package controlbase
-
-import (
- "crypto/cipher"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "time"
-
- "golang.org/x/crypto/blake2s"
- chp "golang.org/x/crypto/chacha20poly1305"
- "tailscale.com/types/key"
-)
-
-const (
- // maxMessageSize is the maximum size of a protocol frame on the
- // wire, including header and payload.
- maxMessageSize = 4096
- // maxCiphertextSize is the maximum amount of ciphertext bytes
- // that one protocol frame can carry, after framing.
- maxCiphertextSize = maxMessageSize - 3
- // maxPlaintextSize is the maximum amount of plaintext bytes that
- // one protocol frame can carry, after encryption and framing.
- maxPlaintextSize = maxCiphertextSize - chp.Overhead
-)
-
-// A Conn is a secured Noise connection. It implements the net.Conn
-// interface, with the unusual trait that any write error (including a
-// SetWriteDeadline induced i/o timeout) causes all future writes to
-// fail.
-type Conn struct {
- conn net.Conn
- version uint16
- peer key.MachinePublic
- handshakeHash [blake2s.Size]byte
- rx rxState
- tx txState
-}
-
-// rxState is all the Conn state that Read uses.
-type rxState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- buf *maxMsgBuffer // or nil when reads exhausted
- n int // number of valid bytes in buf
- next int // offset of next undecrypted packet
- plaintext []byte // slice into buf of decrypted bytes
- hdrBuf [headerLen]byte // small buffer used when buf is nil
-}
-
-// txState is all the Conn state that Write uses.
-type txState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- err error // records the first partial write error for all future calls
-}
-
-// ProtocolVersion returns the protocol version that was used to
-// establish this Conn.
-func (c *Conn) ProtocolVersion() int {
- return int(c.version)
-}
-
-// HandshakeHash returns the Noise handshake hash for the connection,
-// which can be used to bind other messages to this connection
-// (i.e. to ensure that the message wasn't replayed from a different
-// connection).
-func (c *Conn) HandshakeHash() [blake2s.Size]byte {
- return c.handshakeHash
-}
-
-// Peer returns the peer's long-term public key.
-func (c *Conn) Peer() key.MachinePublic {
- return c.peer
-}
-
-// readNLocked reads into c.rx.buf until buf contains at least total
-// bytes. Returns a slice of the total bytes in rxBuf, or an
-// error if fewer than total bytes are available.
-//
-// It may be called with a nil c.rx.buf only if total == headerLen.
-//
-// On success, c.rx.buf will be non-nil.
-func (c *Conn) readNLocked(total int) ([]byte, error) {
- if total > maxMessageSize {
- return nil, errReadTooBig{total}
- }
- for {
- if total <= c.rx.n {
- return c.rx.buf[:total], nil
- }
- var n int
- var err error
- if c.rx.buf == nil {
- if c.rx.n != 0 || total != headerLen {
- panic("unexpected")
- }
- // Optimization to reduce memory usage.
- // Most connections are blocked forever waiting for
- // a read, so we don't want c.rx.buf to be allocated until
- // we know there's data to read. Instead, when we're
- // waiting for data to arrive here, read into the
- // 3 byte hdrBuf:
- n, err = c.conn.Read(c.rx.hdrBuf[:])
- if n > 0 {
- c.rx.buf = getMaxMsgBuffer()
- copy(c.rx.buf[:], c.rx.hdrBuf[:n])
- }
- } else {
- n, err = c.conn.Read(c.rx.buf[c.rx.n:])
- }
- c.rx.n += n
- if err != nil {
- return nil, err
- }
- }
-}
-
-// decryptLocked decrypts msg (which is header+ciphertext) in-place
-// and sets c.rx.plaintext to the decrypted bytes.
-func (c *Conn) decryptLocked(msg []byte) (err error) {
- if msgType := msg[0]; msgType != msgTypeRecord {
- return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
- }
- // We don't check the length field here, because the caller
- // already did in order to figure out how big the msg slice should
- // be.
- ciphertext := msg[headerLen:]
-
- if !c.rx.nonce.Valid() {
- return errCipherExhausted{}
- }
-
- c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
- c.rx.nonce.Increment()
-
- if err != nil {
- // Once a decryption has failed, our Conn is no longer
- // synchronized with our peer. Nuke the cipher state to be
- // safe, so that no further decryptions are attempted. Future
- // read attempts will return net.ErrClosed.
- c.rx.cipher = nil
- }
- return err
-}
-
-// encryptLocked encrypts plaintext into buf (including the
-// packet header) and returns a slice of the ciphertext, or an error
-// if the cipher is exhausted (i.e. can no longer be used safely).
-func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
- if !c.tx.nonce.Valid() {
- // Received 2^64-1 messages on this cipher state. Connection
- // is no longer usable.
- return nil, errCipherExhausted{}
- }
-
- buf[0] = msgTypeRecord
- binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
- ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
- c.tx.nonce.Increment()
-
- return ret, nil
-}
-
-// wholeMessageLocked returns a slice of one whole Noise transport
-// message from c.rx.buf, if one whole message is available, and
-// advances the read state to the next Noise message in the
-// buffer. Returns nil without advancing read state if there isn't one
-// whole message in c.rx.buf.
-func (c *Conn) wholeMessageLocked() []byte {
- available := c.rx.n - c.rx.next
- if available < headerLen {
- return nil
- }
- bs := c.rx.buf[c.rx.next:c.rx.n]
- totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- if len(bs) < totalSize {
- return nil
- }
- c.rx.next += totalSize
- return bs[:totalSize]
-}
-
-// decryptOneLocked decrypts one Noise transport message, reading from
-// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
-// bytes. c.rx.plaintext is only valid if err == nil.
-func (c *Conn) decryptOneLocked() error {
- c.rx.plaintext = nil
-
- // Fast path: do we have one whole ciphertext frame buffered
- // already?
- if bs := c.wholeMessageLocked(); bs != nil {
- return c.decryptLocked(bs)
- }
-
- if c.rx.next != 0 {
- // To simplify the read logic, move the remainder of the
- // buffered bytes back to the head of the buffer, so we can
- // grow it without worrying about wraparound.
- c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
- c.rx.next = 0
- }
-
- // Return our buffer to the pool if it's empty, lest we be
- // blocked in a long Read call, reading the 3 byte header. We
- // don't to keep that buffer unnecessarily alive.
- if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
- bufPool.Put(c.rx.buf)
- c.rx.buf = nil
- }
-
- bs, err := c.readNLocked(headerLen)
- if err != nil {
- return err
- }
- // The rest of the header (besides the length field) gets verified
- // in decryptLocked, not here.
- messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- bs, err = c.readNLocked(messageLen)
- if err != nil {
- return err
- }
-
- c.rx.next = len(bs)
-
- return c.decryptLocked(bs)
-}
-
-// Read implements io.Reader.
-func (c *Conn) Read(bs []byte) (int, error) {
- c.rx.Lock()
- defer c.rx.Unlock()
-
- if c.rx.cipher == nil {
- return 0, net.ErrClosed
- }
- // If no plaintext is buffered, decrypt incoming frames until we
- // have some plaintext. Zero-byte Noise frames are allowed in this
- // protocol, which is why we have to loop here rather than decrypt
- // a single additional frame.
- for len(c.rx.plaintext) == 0 {
- if err := c.decryptOneLocked(); err != nil {
- return 0, err
- }
- }
- n := copy(bs, c.rx.plaintext)
- c.rx.plaintext = c.rx.plaintext[n:]
-
- // Lose slice's underlying array pointer to unneeded memory so
- // GC can collect more.
- if len(c.rx.plaintext) == 0 {
- c.rx.plaintext = nil
- }
- return n, nil
-}
-
-// Write implements io.Writer.
-func (c *Conn) Write(bs []byte) (n int, err error) {
- c.tx.Lock()
- defer c.tx.Unlock()
-
- if c.tx.err != nil {
- return 0, c.tx.err
- }
- defer func() {
- if err != nil {
- // All write errors are fatal for this conn, so clear the
- // cipher state whenever an error happens.
- c.tx.cipher = nil
- }
- if c.tx.err == nil {
- // Only set c.tx.err if not nil so that we can return one
- // error on the first failure, and a different one for
- // subsequent calls. See the error handling around Write
- // below for why.
- c.tx.err = err
- }
- }()
-
- if c.tx.cipher == nil {
- return 0, net.ErrClosed
- }
-
- buf := getMaxMsgBuffer()
- defer bufPool.Put(buf)
-
- var sent int
- for len(bs) > 0 {
- toSend := bs
- if len(toSend) > maxPlaintextSize {
- toSend = bs[:maxPlaintextSize]
- }
- bs = bs[len(toSend):]
-
- ciphertext, err := c.encryptLocked(toSend, buf)
- if err != nil {
- return sent, err
- }
- if _, err := c.conn.Write(ciphertext); err != nil {
- // Return the raw error on the Write that actually
- // failed. For future writes, return that error wrapped in
- // a desync error.
- c.tx.err = errPartialWrite{err}
- return sent, err
- }
- sent += len(toSend)
- }
- return sent, nil
-}
-
-// Close implements io.Closer.
-func (c *Conn) Close() error {
- closeErr := c.conn.Close() // unblocks any waiting reads or writes
-
- // Remove references to live cipher state. Strictly speaking this
- // is unnecessary, but we want to try and hand the active cipher
- // state to the garbage collector promptly, to preserve perfect
- // forward secrecy as much as we can.
- c.rx.Lock()
- c.rx.cipher = nil
- c.rx.Unlock()
- c.tx.Lock()
- c.tx.cipher = nil
- c.tx.Unlock()
- return closeErr
-}
-
-func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
-func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
-func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
-func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
-func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
-
-// errCipherExhausted is the error returned when we run out of nonces
-// on a cipher.
-type errCipherExhausted struct{}
-
-func (errCipherExhausted) Error() string {
- return "cipher exhausted, no more nonces available for current key"
-}
-func (errCipherExhausted) Timeout() bool { return false }
-func (errCipherExhausted) Temporary() bool { return false }
-
-// errPartialWrite is the error returned when the cipher state has
-// become unusable due to a past partial write.
-type errPartialWrite struct {
- err error
-}
-
-func (e errPartialWrite) Error() string {
- return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
-}
-func (e errPartialWrite) Unwrap() error { return e.err }
-func (e errPartialWrite) Temporary() bool { return false }
-func (e errPartialWrite) Timeout() bool { return false }
-
-// errReadTooBig is the error returned when the peer sent an
-// unacceptably large Noise frame.
-type errReadTooBig struct {
- requested int
-}
-
-func (e errReadTooBig) Error() string {
- return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
-}
-func (e errReadTooBig) Temporary() bool {
- // permanent error because this error only occurs when our peer
- // sends us a frame so large we're unwilling to ever decode it.
- return false
-}
-func (e errReadTooBig) Timeout() bool { return false }
-
-type nonce [chp.NonceSize]byte
-
-func (n *nonce) Valid() bool {
- return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
-}
-
-func (n *nonce) Increment() {
- if !n.Valid() {
- panic("increment of invalid nonce")
- }
- binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
-}
-
-type maxMsgBuffer [maxMessageSize]byte
-
-// bufPool holds the temporary buffers for Conn.Read & Write.
-var bufPool = &sync.Pool{
- New: func() any {
- return new(maxMsgBuffer)
- },
-}
-
-func getMaxMsgBuffer() *maxMsgBuffer {
- return bufPool.Get().(*maxMsgBuffer)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package controlbase implements the base transport of the Tailscale
+// 2021 control protocol.
+//
+// The base transport implements Noise IK, instantiated with
+// Curve25519, ChaCha20Poly1305 and BLAKE2s.
+package controlbase
+
+import (
+ "crypto/cipher"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/blake2s"
+ chp "golang.org/x/crypto/chacha20poly1305"
+ "tailscale.com/types/key"
+)
+
+const (
+ // maxMessageSize is the maximum size of a protocol frame on the
+ // wire, including header and payload.
+ maxMessageSize = 4096
+ // maxCiphertextSize is the maximum amount of ciphertext bytes
+ // that one protocol frame can carry, after framing.
+ maxCiphertextSize = maxMessageSize - 3
+ // maxPlaintextSize is the maximum amount of plaintext bytes that
+ // one protocol frame can carry, after encryption and framing.
+ maxPlaintextSize = maxCiphertextSize - chp.Overhead
+)
+
+// A Conn is a secured Noise connection. It implements the net.Conn
+// interface, with the unusual trait that any write error (including a
+// SetWriteDeadline induced i/o timeout) causes all future writes to
+// fail.
+type Conn struct {
+ conn net.Conn
+ version uint16
+ peer key.MachinePublic
+ handshakeHash [blake2s.Size]byte
+ rx rxState
+ tx txState
+}
+
+// rxState is all the Conn state that Read uses.
+type rxState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ buf *maxMsgBuffer // or nil when reads exhausted
+ n int // number of valid bytes in buf
+ next int // offset of next undecrypted packet
+ plaintext []byte // slice into buf of decrypted bytes
+ hdrBuf [headerLen]byte // small buffer used when buf is nil
+}
+
+// txState is all the Conn state that Write uses.
+type txState struct {
+ sync.Mutex
+ cipher cipher.AEAD
+ nonce nonce
+ err error // records the first partial write error for all future calls
+}
+
+// ProtocolVersion returns the protocol version that was used to
+// establish this Conn.
+func (c *Conn) ProtocolVersion() int {
+ return int(c.version)
+}
+
+// HandshakeHash returns the Noise handshake hash for the connection,
+// which can be used to bind other messages to this connection
+// (i.e. to ensure that the message wasn't replayed from a different
+// connection).
+func (c *Conn) HandshakeHash() [blake2s.Size]byte {
+ return c.handshakeHash
+}
+
+// Peer returns the peer's long-term public key.
+func (c *Conn) Peer() key.MachinePublic {
+ return c.peer
+}
+
+// readNLocked reads into c.rx.buf until buf contains at least total
+// bytes. Returns a slice of the total bytes in rxBuf, or an
+// error if fewer than total bytes are available.
+//
+// It may be called with a nil c.rx.buf only if total == headerLen.
+//
+// On success, c.rx.buf will be non-nil.
+func (c *Conn) readNLocked(total int) ([]byte, error) {
+ if total > maxMessageSize {
+ return nil, errReadTooBig{total}
+ }
+ for {
+ if total <= c.rx.n {
+ return c.rx.buf[:total], nil
+ }
+ var n int
+ var err error
+ if c.rx.buf == nil {
+ if c.rx.n != 0 || total != headerLen {
+ panic("unexpected")
+ }
+ // Optimization to reduce memory usage.
+ // Most connections are blocked forever waiting for
+ // a read, so we don't want c.rx.buf to be allocated until
+ // we know there's data to read. Instead, when we're
+ // waiting for data to arrive here, read into the
+ // 3 byte hdrBuf:
+ n, err = c.conn.Read(c.rx.hdrBuf[:])
+ if n > 0 {
+ c.rx.buf = getMaxMsgBuffer()
+ copy(c.rx.buf[:], c.rx.hdrBuf[:n])
+ }
+ } else {
+ n, err = c.conn.Read(c.rx.buf[c.rx.n:])
+ }
+ c.rx.n += n
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+// decryptLocked decrypts msg (which is header+ciphertext) in-place
+// and sets c.rx.plaintext to the decrypted bytes.
+func (c *Conn) decryptLocked(msg []byte) (err error) {
+ if msgType := msg[0]; msgType != msgTypeRecord {
+ return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
+ }
+ // We don't check the length field here, because the caller
+ // already did in order to figure out how big the msg slice should
+ // be.
+ ciphertext := msg[headerLen:]
+
+ if !c.rx.nonce.Valid() {
+ return errCipherExhausted{}
+ }
+
+ c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
+ c.rx.nonce.Increment()
+
+ if err != nil {
+ // Once a decryption has failed, our Conn is no longer
+ // synchronized with our peer. Nuke the cipher state to be
+ // safe, so that no further decryptions are attempted. Future
+ // read attempts will return net.ErrClosed.
+ c.rx.cipher = nil
+ }
+ return err
+}
+
+// encryptLocked encrypts plaintext into buf (including the
+// packet header) and returns a slice of the ciphertext, or an error
+// if the cipher is exhausted (i.e. can no longer be used safely).
+func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
+ if !c.tx.nonce.Valid() {
+ // Received 2^64-1 messages on this cipher state. Connection
+ // is no longer usable.
+ return nil, errCipherExhausted{}
+ }
+
+ buf[0] = msgTypeRecord
+ binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
+ ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
+ c.tx.nonce.Increment()
+
+ return ret, nil
+}
+
+// wholeMessageLocked returns a slice of one whole Noise transport
+// message from c.rx.buf, if one whole message is available, and
+// advances the read state to the next Noise message in the
+// buffer. Returns nil without advancing read state if there isn't one
+// whole message in c.rx.buf.
+func (c *Conn) wholeMessageLocked() []byte {
+ available := c.rx.n - c.rx.next
+ if available < headerLen {
+ return nil
+ }
+ bs := c.rx.buf[c.rx.next:c.rx.n]
+ totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ if len(bs) < totalSize {
+ return nil
+ }
+ c.rx.next += totalSize
+ return bs[:totalSize]
+}
+
+// decryptOneLocked decrypts one Noise transport message, reading from
+// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
+// bytes. c.rx.plaintext is only valid if err == nil.
+func (c *Conn) decryptOneLocked() error {
+ c.rx.plaintext = nil
+
+ // Fast path: do we have one whole ciphertext frame buffered
+ // already?
+ if bs := c.wholeMessageLocked(); bs != nil {
+ return c.decryptLocked(bs)
+ }
+
+ if c.rx.next != 0 {
+ // To simplify the read logic, move the remainder of the
+ // buffered bytes back to the head of the buffer, so we can
+ // grow it without worrying about wraparound.
+ c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
+ c.rx.next = 0
+ }
+
+ // Return our buffer to the pool if it's empty, lest we be
+ // blocked in a long Read call, reading the 3 byte header. We
+ // don't to keep that buffer unnecessarily alive.
+ if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
+ bufPool.Put(c.rx.buf)
+ c.rx.buf = nil
+ }
+
+ bs, err := c.readNLocked(headerLen)
+ if err != nil {
+ return err
+ }
+ // The rest of the header (besides the length field) gets verified
+ // in decryptLocked, not here.
+ messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
+ bs, err = c.readNLocked(messageLen)
+ if err != nil {
+ return err
+ }
+
+ c.rx.next = len(bs)
+
+ return c.decryptLocked(bs)
+}
+
+// Read implements io.Reader.
+func (c *Conn) Read(bs []byte) (int, error) {
+ c.rx.Lock()
+ defer c.rx.Unlock()
+
+ if c.rx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+ // If no plaintext is buffered, decrypt incoming frames until we
+ // have some plaintext. Zero-byte Noise frames are allowed in this
+ // protocol, which is why we have to loop here rather than decrypt
+ // a single additional frame.
+ for len(c.rx.plaintext) == 0 {
+ if err := c.decryptOneLocked(); err != nil {
+ return 0, err
+ }
+ }
+ n := copy(bs, c.rx.plaintext)
+ c.rx.plaintext = c.rx.plaintext[n:]
+
+ // Lose slice's underlying array pointer to unneeded memory so
+ // GC can collect more.
+ if len(c.rx.plaintext) == 0 {
+ c.rx.plaintext = nil
+ }
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (c *Conn) Write(bs []byte) (n int, err error) {
+ c.tx.Lock()
+ defer c.tx.Unlock()
+
+ if c.tx.err != nil {
+ return 0, c.tx.err
+ }
+ defer func() {
+ if err != nil {
+ // All write errors are fatal for this conn, so clear the
+ // cipher state whenever an error happens.
+ c.tx.cipher = nil
+ }
+ if c.tx.err == nil {
+ // Only set c.tx.err if not nil so that we can return one
+ // error on the first failure, and a different one for
+ // subsequent calls. See the error handling around Write
+ // below for why.
+ c.tx.err = err
+ }
+ }()
+
+ if c.tx.cipher == nil {
+ return 0, net.ErrClosed
+ }
+
+ buf := getMaxMsgBuffer()
+ defer bufPool.Put(buf)
+
+ var sent int
+ for len(bs) > 0 {
+ toSend := bs
+ if len(toSend) > maxPlaintextSize {
+ toSend = bs[:maxPlaintextSize]
+ }
+ bs = bs[len(toSend):]
+
+ ciphertext, err := c.encryptLocked(toSend, buf)
+ if err != nil {
+ return sent, err
+ }
+ if _, err := c.conn.Write(ciphertext); err != nil {
+ // Return the raw error on the Write that actually
+ // failed. For future writes, return that error wrapped in
+ // a desync error.
+ c.tx.err = errPartialWrite{err}
+ return sent, err
+ }
+ sent += len(toSend)
+ }
+ return sent, nil
+}
+
+// Close implements io.Closer.
+func (c *Conn) Close() error {
+ closeErr := c.conn.Close() // unblocks any waiting reads or writes
+
+ // Remove references to live cipher state. Strictly speaking this
+ // is unnecessary, but we want to try and hand the active cipher
+ // state to the garbage collector promptly, to preserve perfect
+ // forward secrecy as much as we can.
+ c.rx.Lock()
+ c.rx.cipher = nil
+ c.rx.Unlock()
+ c.tx.Lock()
+ c.tx.cipher = nil
+ c.tx.Unlock()
+ return closeErr
+}
+
+func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
+func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
+func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
+func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
+func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
+
+// errCipherExhausted is the error returned when we run out of nonces
+// on a cipher.
+type errCipherExhausted struct{}
+
+func (errCipherExhausted) Error() string {
+ return "cipher exhausted, no more nonces available for current key"
+}
+func (errCipherExhausted) Timeout() bool { return false }
+func (errCipherExhausted) Temporary() bool { return false }
+
+// errPartialWrite is the error returned when the cipher state has
+// become unusable due to a past partial write.
+type errPartialWrite struct {
+ err error
+}
+
+func (e errPartialWrite) Error() string {
+ return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
+}
+func (e errPartialWrite) Unwrap() error { return e.err }
+func (e errPartialWrite) Temporary() bool { return false }
+func (e errPartialWrite) Timeout() bool { return false }
+
+// errReadTooBig is the error returned when the peer sent an
+// unacceptably large Noise frame.
+type errReadTooBig struct {
+ requested int
+}
+
+func (e errReadTooBig) Error() string {
+ return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
+}
+func (e errReadTooBig) Temporary() bool {
+ // permanent error because this error only occurs when our peer
+ // sends us a frame so large we're unwilling to ever decode it.
+ return false
+}
+func (e errReadTooBig) Timeout() bool { return false }
+
+type nonce [chp.NonceSize]byte
+
+func (n *nonce) Valid() bool {
+ return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
+}
+
+func (n *nonce) Increment() {
+ if !n.Valid() {
+ panic("increment of invalid nonce")
+ }
+ binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
+}
+
+type maxMsgBuffer [maxMessageSize]byte
+
+// bufPool holds the temporary buffers for Conn.Read & Write.
+var bufPool = &sync.Pool{
+ New: func() any {
+ return new(maxMsgBuffer)
+ },
+}
+
+func getMaxMsgBuffer() *maxMsgBuffer {
+ return bufPool.Get().(*maxMsgBuffer)
+}