summaryrefslogtreecommitdiffhomepage
path: root/control/tsp/map.go
diff options
context:
space:
mode:
Diffstat (limited to 'control/tsp/map.go')
-rw-r--r--control/tsp/map.go339
1 files changed, 339 insertions, 0 deletions
diff --git a/control/tsp/map.go b/control/tsp/map.go
new file mode 100644
index 000000000..96531255b
--- /dev/null
+++ b/control/tsp/map.go
@@ -0,0 +1,339 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsp
+
+import (
+ "bytes"
+ "cmp"
+ "context"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "github.com/klauspost/compress/zstd"
+ "tailscale.com/control/ts2021"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/key"
+)
+
+// errSessionClosed is returned by [MapSession.Next] and
+// [MapSession.NextInto] when called after [MapSession.Close].
+var errSessionClosed = errors.New("tsp: map session closed")
+
+// DefaultMaxMessageSize is the default cap, in bytes, on the size of a
+// single compressed map response frame. See [MapOpts.MaxMessageSize].
+const DefaultMaxMessageSize = 4 << 20
+
+// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to
+// amortize the cost of setting up zstd state. Decoders are returned via
+// [MapSession.Close]; entries are reclaimed by the runtime under memory
+// pressure via sync.Pool semantics.
+var zstdDecoderPool sync.Pool // of *zstd.Decoder
+
+// MapOpts contains options for sending a map request.
+type MapOpts struct {
+ // NodeKey is the node's private key. Required.
+ NodeKey key.NodePrivate
+
+ // Hostinfo is the host information to send. Optional;
+ // if nil, a minimal default is used.
+ Hostinfo *tailcfg.Hostinfo
+
+ // Stream is whether to receive multiple MapResponses over
+ // the same HTTP connection.
+ Stream bool
+
+ // OmitPeers is whether the client is okay with the Peers list
+ // being omitted in the response.
+ OmitPeers bool
+
+ // MaxMessageSize is the maximum size in bytes of any single
+ // compressed map response frame on the wire. If zero,
+ // [DefaultMaxMessageSize] is used.
+ MaxMessageSize int64
+}
+
+// framedReader is an io.Reader that consumes a stream of length-prefixed
+// frames (each a little-endian uint32 length followed by that many bytes)
+// from r and yields only the frame payloads back-to-back.
+//
+// This lets us feed the concatenated zstd frames from our wire protocol
+// into a single streaming zstd decoder. Zstd's file format permits
+// concatenation (RFC 8478 §2), and klauspost's decoder handles it
+// transparently.
+//
+// If onNewFrame is non-nil, it is called after each new 4-byte length
+// header is successfully read. Used to reset the per-message decoded-size
+// budget downstream.
+type framedReader struct {
+ r io.Reader
+ maxSize int64 // per-frame compressed-size cap
+ remain int // bytes remaining in the current frame
+ onNewFrame func()
+}
+
+func (f *framedReader) Read(p []byte) (int, error) {
+ if f.remain == 0 {
+ var hdr [4]byte
+ if _, err := io.ReadFull(f.r, hdr[:]); err != nil {
+ return 0, err
+ }
+ sz := int64(binary.LittleEndian.Uint32(hdr[:]))
+ if sz == 0 {
+ return 0, fmt.Errorf("map response: zero-length frame")
+ }
+ if sz > f.maxSize {
+ return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize)
+ }
+ f.remain = int(sz)
+ if f.onNewFrame != nil {
+ f.onNewFrame()
+ }
+ }
+ if len(p) > f.remain {
+ p = p[:f.remain]
+ }
+ n, err := f.r.Read(p)
+ f.remain -= n
+ return n, err
+}
+
+// boundedReader is an io.Reader that yields at most remain bytes from r
+// before returning an error. Call reset to raise the budget back to max,
+// typically at a new message boundary.
+//
+// Used to cap the decoded size of a single map response so a malicious
+// server can't send a small zstd frame that explodes into gigabytes of
+// junk for the json.Decoder to consume.
+type boundedReader struct {
+ r io.Reader
+ max int64
+ remain int64
+}
+
+func (b *boundedReader) Read(p []byte) (int, error) {
+ if b.remain <= 0 {
+ return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max)
+ }
+ if int64(len(p)) > b.remain {
+ p = p[:b.remain]
+ }
+ n, err := b.r.Read(p)
+ b.remain -= int64(n)
+ return n, err
+}
+
+func (b *boundedReader) reset() { b.remain = b.max }
+
+// MapSession wraps an in-progress map response stream. Call Next to read
+// each MapResponse. Call Close when done.
+type MapSession struct {
+ res *http.Response
+ stream bool
+ noiseDoer func(*http.Request) (*http.Response, error)
+
+ // inNext detects concurrent NextInto callers. It CAS-flips
+ // false→true on entry and back to false on exit; a failed CAS
+ // panics, akin to how the Go runtime detects concurrent map
+ // access. It does not serialize Close vs. NextInto; that's
+ // nextMu's job.
+ inNext atomic.Bool
+
+ // nextMu is held while [MapSession.NextInto] is running jdec.Decode,
+ // so that Close can wait for an in-flight Decode to unwind before it
+ // touches zdec (Reset, pool-Put) and avoids racing with the running
+ // Read chain that Decode drives.
+ nextMu sync.Mutex
+ read int // guarded by nextMu
+ closed bool // guarded by nextMu
+ zdec *zstd.Decoder // reads from a framedReader wrapping res.Body
+ jdec *json.Decoder // reads decompressed JSON from zdec
+
+ closeOnce sync.Once
+ closeErr error
+}
+
+// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session.
+func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) {
+ return s.noiseDoer(req)
+}
+
+// Next reads and returns the next MapResponse from the stream.
+// For non-streaming sessions, the first call returns the single response
+// and subsequent calls return io.EOF.
+// For streaming sessions, Next blocks until the next response arrives
+// or the server closes the connection.
+//
+// Each call allocates a fresh MapResponse. Callers that want to amortize
+// the allocation across calls can use [MapSession.NextInto].
+//
+// Next and NextInto are not safe to call concurrently from multiple
+// goroutines on the same [MapSession]; a concurrent call panics, akin
+// to the Go runtime's concurrent map access detection. [MapSession.Close]
+// may be called concurrently to abort an in-flight Next.
+func (s *MapSession) Next() (*tailcfg.MapResponse, error) {
+ var resp tailcfg.MapResponse
+ if err := s.NextInto(&resp); err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+// NextInto is like [MapSession.Next] but decodes the next MapResponse into
+// the caller-supplied *resp rather than allocating a new one. The pointer's
+// pointee is zeroed before decoding so fields from a prior response do not
+// persist.
+//
+// For non-streaming sessions, the first call decodes the single response
+// and subsequent calls return io.EOF.
+// For streaming sessions, NextInto blocks until the next response arrives
+// or the server closes the connection.
+//
+// See [MapSession.Next] for concurrency rules; those apply to NextInto too.
+func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error {
+ if !s.inNext.CompareAndSwap(false, true) {
+ panic("tsp: invalid concurrent call to MapSession.Next/NextInto")
+ }
+ defer s.inNext.Store(false)
+
+ s.nextMu.Lock()
+ defer s.nextMu.Unlock()
+ if s.closed {
+ return errSessionClosed
+ }
+ if !s.stream && s.read > 0 {
+ return io.EOF
+ }
+ *resp = tailcfg.MapResponse{}
+ if err := s.jdec.Decode(resp); err != nil {
+ return err
+ }
+ s.read++
+ return nil
+}
+
+// Close returns the session's zstd decoder to the pool and closes the
+// underlying HTTP response body. It is safe to call Close multiple times
+// and from multiple goroutines, including while a [MapSession.Next] or
+// [MapSession.NextInto] call is in flight on another goroutine (which
+// will return an error once the body close propagates).
+func (s *MapSession) Close() error {
+ // Callers are likely to race a deferred Close with a time.AfterFunc
+ // timeout (or similar) Close that aborts a hung Next. Without the
+ // Once, both Closes would Put the same *zstd.Decoder into the pool,
+ // corrupting it, and the Reset/Put in one would race with the
+ // zdec.Read that the hung Next is driving.
+ //
+ // Ordering inside the Once: close the body first to unblock any
+ // in-flight NextInto (its Read chain ends at res.Body and will
+ // return an error once it's closed). That lets NextInto unwind and
+ // release nextMu. Only then do we take nextMu ourselves and touch
+ // zdec, which is safe because no goroutine is still reading from
+ // it. Acquiring nextMu before closing the body would deadlock
+ // against a hung NextInto.
+ s.closeOnce.Do(func() {
+ s.closeErr = s.res.Body.Close()
+ s.nextMu.Lock()
+ defer s.nextMu.Unlock()
+ s.closed = true
+ s.zdec.Reset(nil)
+ zstdDecoderPool.Put(s.zdec)
+ })
+ return s.closeErr
+}
+
+// Map sends a map request to the coordination server and returns a MapSession
+// for reading the framed, zstd-compressed response(s).
+func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) {
+ if opts.NodeKey.IsZero() {
+ return nil, fmt.Errorf("NodeKey is required")
+ }
+
+ hi := opts.Hostinfo
+ if hi == nil {
+ hi = defaultHostinfo()
+ }
+
+ mapReq := tailcfg.MapRequest{
+ Version: tailcfg.CurrentCapabilityVersion,
+ NodeKey: opts.NodeKey.Public(),
+ Hostinfo: hi,
+ Stream: opts.Stream,
+ Compress: "zstd",
+ OmitPeers: opts.OmitPeers,
+ // Streaming requires the server to track us as "connected",
+ // which in turn requires ReadOnly=false. Non-streaming polls
+ // stay ReadOnly to minimize side effects.
+ ReadOnly: !opts.Stream,
+ }
+
+ body, err := json.Marshal(mapReq)
+ if err != nil {
+ return nil, fmt.Errorf("encoding map request: %w", err)
+ }
+
+ nc, err := c.noiseClient(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("establishing noise connection: %w", err)
+ }
+
+ url := c.serverURL + "/machine/map"
+ url = strings.Replace(url, "http:", "https:", 1)
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("creating map request: %w", err)
+ }
+ ts2021.AddLBHeader(req, opts.NodeKey.Public())
+
+ res, err := nc.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("map request: %w", err)
+ }
+
+ if res.StatusCode != 200 {
+ msg, _ := io.ReadAll(res.Body)
+ res.Body.Close()
+ return nil, fmt.Errorf("map request: http %d: %.200s",
+ res.StatusCode, strings.TrimSpace(string(msg)))
+ }
+
+ maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize)
+ bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
+ fr := &framedReader{
+ r: res.Body,
+ maxSize: maxMessageSize,
+ onNewFrame: bounded.reset,
+ }
+
+ zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder)
+ if zdec != nil {
+ if err := zdec.Reset(fr); err != nil {
+ // Reset can fail if the previous stream is in a bad state; drop
+ // the decoder and create a fresh one.
+ zdec = nil
+ }
+ }
+ if zdec == nil {
+ zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
+ if err != nil {
+ res.Body.Close()
+ return nil, fmt.Errorf("creating zstd decoder: %w", err)
+ }
+ }
+ bounded.r = zdec
+
+ return &MapSession{
+ res: res,
+ stream: opts.Stream,
+ noiseDoer: nc.Do,
+ zdec: zdec,
+ jdec: json.NewDecoder(bounded),
+ }, nil
+}