summaryrefslogtreecommitdiffhomepage
path: root/control/tsp/map_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'control/tsp/map_test.go')
-rw-r--r--control/tsp/map_test.go270
1 files changed, 270 insertions, 0 deletions
diff --git a/control/tsp/map_test.go b/control/tsp/map_test.go
new file mode 100644
index 000000000..15b32dd36
--- /dev/null
+++ b/control/tsp/map_test.go
@@ -0,0 +1,270 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsp
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/klauspost/compress/zstd"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tstest/integration/testcontrol"
+ "tailscale.com/types/key"
+)
+
+func TestMapAgainstTestControl(t *testing.T) {
+ ctrl := &testcontrol.Server{}
+ ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl)
+ ctrl.HTTPTestServer.Start()
+ t.Cleanup(ctrl.HTTPTestServer.Close)
+ baseURL := ctrl.HTTPTestServer.URL
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ serverKey, err := DiscoverServerKey(ctx, baseURL)
+ if err != nil {
+ t.Fatalf("DiscoverServerKey: %v", err)
+ }
+
+ register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) {
+ t.Helper()
+ nodeKey = key.NewNode()
+ machineKey = key.NewMachine()
+ c, err := NewClient(ClientOpts{
+ ServerURL: baseURL,
+ MachineKey: machineKey,
+ })
+ if err != nil {
+ t.Fatalf("NewClient %s: %v", hostname, err)
+ }
+ defer c.Close()
+ c.SetControlPublicKey(serverKey)
+ if _, err := c.Register(ctx, RegisterOpts{
+ NodeKey: nodeKey,
+ Hostinfo: &tailcfg.Hostinfo{Hostname: hostname},
+ }); err != nil {
+ t.Fatalf("Register %s: %v", hostname, err)
+ }
+ return nodeKey, machineKey
+ }
+
+ nodeKeyA, machineKeyA := register("a")
+ nodeKeyB, _ := register("b")
+
+ clientA, err := NewClient(ClientOpts{
+ ServerURL: baseURL,
+ MachineKey: machineKeyA,
+ })
+ if err != nil {
+ t.Fatalf("NewClient A: %v", err)
+ }
+ defer clientA.Close()
+ clientA.SetControlPublicKey(serverKey)
+
+ session, err := clientA.Map(ctx, MapOpts{
+ NodeKey: nodeKeyA,
+ Hostinfo: &tailcfg.Hostinfo{Hostname: "a"},
+ Stream: true,
+ })
+ if err != nil {
+ t.Fatalf("Map: %v", err)
+ }
+ defer session.Close()
+
+ // nextNonKeepalive returns the next non-keepalive MapResponse, to keep
+ // the test robust if a server-side keepalive arrives mid-test.
+ nextNonKeepalive := func() *tailcfg.MapResponse {
+ t.Helper()
+ for {
+ resp, err := session.Next()
+ if err != nil {
+ t.Fatalf("session.Next: %v", err)
+ }
+ if resp.KeepAlive {
+ continue
+ }
+ return resp
+ }
+ }
+
+ // First MapResponse: expect node A as self and node B in Peers.
+ first := nextNonKeepalive()
+ if first.Node == nil {
+ t.Fatal("first response has nil Node")
+ }
+ if got, want := first.Node.Key, nodeKeyA.Public(); got != want {
+ t.Errorf("first Node.Key = %v, want %v", got, want)
+ }
+ var foundB bool
+ for _, p := range first.Peers {
+ if p.Key == nodeKeyB.Public() {
+ foundB = true
+ break
+ }
+ }
+ if !foundB {
+ t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers))
+ }
+
+ // Inject raw MapResponses and verify they come out the reader, in order.
+ // msgToSend is single-slot, so we must consume each before injecting the next.
+ for i := range 3 {
+ want := fmt.Sprintf("injected-%d.example.com", i)
+ inject := &tailcfg.MapResponse{Domain: want}
+ if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) {
+ t.Fatalf("AddRawMapResponse %d: node not connected", i)
+ }
+ got := nextNonKeepalive()
+ if got.Domain != want {
+ t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want)
+ }
+ }
+}
+
+// newTestPipeline builds the same framedReader → zstd → boundedReader →
+// json.Decoder pipeline that [Client.Map] builds for a live session, but
+// feeds it from a raw byte slice. Returned jdec can be used with Decode to
+// pull out MapResponses.
+func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder {
+ t.Helper()
+ bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
+ fr := &framedReader{
+ r: bytes.NewReader(wire),
+ maxSize: maxMessageSize,
+ onNewFrame: bounded.reset,
+ }
+ zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
+ if err != nil {
+ t.Fatalf("zstd.NewReader: %v", err)
+ }
+ t.Cleanup(zdec.Close)
+ bounded.r = zdec
+ return json.NewDecoder(bounded)
+}
+
+// zstdFrame returns a zstd-compressed frame of b.
+func zstdFrame(t testing.TB, b []byte) []byte {
+ t.Helper()
+ enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1))
+ if err != nil {
+ t.Fatalf("zstd.NewWriter: %v", err)
+ }
+ defer enc.Close()
+ return enc.EncodeAll(b, nil)
+}
+
+// wireFrame writes a 4-byte little-endian length prefix plus payload to buf.
+func wireFrame(buf *bytes.Buffer, payload []byte) {
+ var hdr [4]byte
+ binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload)))
+ buf.Write(hdr[:])
+ buf.Write(payload)
+}
+
+// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming
+// a frame larger than the configured cap is rejected before any payload
+// bytes are read from the stream.
+func TestMapFrameSizeTooLarge(t *testing.T) {
+ const max = 4 << 20
+ var wire bytes.Buffer
+ var hdr [4]byte
+ binary.LittleEndian.PutUint32(hdr[:], (max + 1))
+ wire.Write(hdr[:])
+
+ jdec := newTestPipeline(t, wire.Bytes(), max)
+ var resp tailcfg.MapResponse
+ err := jdec.Decode(&resp)
+ if err == nil {
+ t.Fatal("Decode: got nil error, want frame-too-large")
+ }
+ if !strings.Contains(err.Error(), "exceeds max") {
+ t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
+ }
+}
+
+// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well
+// under the cap) which decompresses into a huge JSON payload is rejected.
+// This is the "zstd bomb" case: a tiny compressed frame that would
+// explode into a huge decoded payload for json.Decoder to consume.
+func TestMapDecodedSizeTooLarge(t *testing.T) {
+ const max = 4 << 20
+ big := strings.Repeat("a", 5<<20) // 5 MiB of 'a'
+ raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if int64(len(raw)) <= max {
+ t.Fatalf("raw JSON unexpectedly small: %d", len(raw))
+ }
+ compressed := zstdFrame(t, raw)
+ if int64(len(compressed)) >= max {
+ t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed))
+ }
+
+ var wire bytes.Buffer
+ wireFrame(&wire, compressed)
+
+ jdec := newTestPipeline(t, wire.Bytes(), max)
+ var resp tailcfg.MapResponse
+ err = jdec.Decode(&resp)
+ if err == nil {
+ t.Fatal("Decode: got nil error, want decoded-size-exceeded")
+ }
+ if !strings.Contains(err.Error(), "exceeds max") {
+ t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
+ }
+}
+
+// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded
+// budget is reset at each new frame boundary. Two consecutive 3-MiB frames
+// should both decode successfully under a 4-MiB per-frame cap. Without the
+// reset, the second frame would fail (remaining budget after frame 1 =
+// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more).
+func TestMapBudgetResetsBetweenFrames(t *testing.T) {
+ const max = 4 << 20
+ payload := strings.Repeat("a", 3<<20)
+ r1 := &tailcfg.MapResponse{Domain: payload + "-one"}
+ r2 := &tailcfg.MapResponse{Domain: payload + "-two"}
+
+ var wire bytes.Buffer
+ for _, r := range []*tailcfg.MapResponse{r1, r2} {
+ raw, err := json.Marshal(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if int64(len(raw)) >= max {
+ t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max)
+ }
+ compressed := zstdFrame(t, raw)
+ if int64(len(compressed)) >= max {
+ t.Fatalf("compressed size %d >= max %d", len(compressed), max)
+ }
+ wireFrame(&wire, compressed)
+ }
+
+ jdec := newTestPipeline(t, wire.Bytes(), max)
+
+ var got1, got2 tailcfg.MapResponse
+ if err := jdec.Decode(&got1); err != nil {
+ t.Fatalf("first Decode: %v", err)
+ }
+ if got1.Domain != r1.Domain {
+ t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain))
+ }
+ if err := jdec.Decode(&got2); err != nil {
+ t.Fatalf("second Decode: %v", err)
+ }
+ if got2.Domain != r2.Domain {
+ t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain))
+ }
+}