summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@tailscale.com>2026-03-04 03:31:13 +0000
committerBrad Fitzpatrick <brad@danga.com>2026-03-03 20:56:20 -0800
commitd42b3743b72c8fd7df77945f99bf6aaec617f33d (patch)
treebd859c0d3c7913dbd1e6de9f91d42223d56e73d5
parent120f27f383d5501d1483c5238b591e66db500fe4 (diff)
downloadtailscale-d42b3743b72c8fd7df77945f99bf6aaec617f33d.tar.xz
tailscale-d42b3743b72c8fd7df77945f99bf6aaec617f33d.zip
net/porttrack: add net.Listen wrapper to help tests allocate ports race-free
Updates tailscale/corp#27805 Updates tailscale/corp#27806 Updates tailscale/corp#37964 Change-Id: I7bb5ed7f258e840a8208e5d725c7b2f126d7ef96 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
-rw-r--r--net/porttrack/porttrack.go176
-rw-r--r--net/porttrack/porttrack_test.go95
2 files changed, 271 insertions, 0 deletions
diff --git a/net/porttrack/porttrack.go b/net/porttrack/porttrack.go
new file mode 100644
index 000000000..822e7200e
--- /dev/null
+++ b/net/porttrack/porttrack.go
@@ -0,0 +1,176 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package porttrack provides race-free ephemeral port assignment for
+// subprocess tests. The parent test process creates a [Collector] that
+// listens on a TCP port; the child process uses [Listen] which, when
+// given a magic address, binds to localhost:0 and reports the actual
+// port back to the collector.
+//
+// The magic address format is:
+//
+// testport-report:HOST:PORT/LABEL
+//
+// where HOST:PORT is the collector's TCP address and LABEL identifies
+// which listener this is (e.g. "main", "plaintext").
+//
+// When [Listen] is called with a non-magic address, it falls through to
+// [net.Listen] with zero overhead beyond a single [strings.HasPrefix]
+// check.
+package porttrack
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+ "sync"
+
+ "tailscale.com/util/testenv"
+)
+
+const magicPrefix = "testport-report:"
+
+// Collector is the parent/test side of the porttrack protocol. It
+// listens for port reports from child processes that used [Listen]
+// with a magic address obtained from [Collector.Addr].
+type Collector struct {
+ ln net.Listener
+ mu sync.Mutex
+ cond *sync.Cond
+ ports map[string]int
+ err error // non-nil if a context passed to Port was cancelled
+}
+
+// NewCollector creates a new Collector. The collector's TCP listener is
+// closed when t finishes.
+func NewCollector(t testenv.TB) *Collector {
+ t.Helper()
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("porttrack.NewCollector: %v", err)
+ }
+ c := &Collector{
+ ln: ln,
+ ports: make(map[string]int),
+ }
+ c.cond = sync.NewCond(&c.mu)
+ go c.accept(t)
+ t.Cleanup(func() { ln.Close() })
+ return c
+}
+
+// accept runs in a goroutine, accepting connections and parsing port
+// reports until the listener is closed.
+func (c *Collector) accept(t testenv.TB) {
+ for {
+ conn, err := c.ln.Accept()
+ if err != nil {
+ return // listener closed
+ }
+ go c.handleConn(t, conn)
+ }
+}
+
+func (c *Collector) handleConn(t testenv.TB, conn net.Conn) {
+ defer conn.Close()
+ scanner := bufio.NewScanner(conn)
+ for scanner.Scan() {
+ line := scanner.Text()
+ label, portStr, ok := strings.Cut(line, "\t")
+ if !ok {
+ t.Errorf("porttrack: malformed report line: %q", line)
+ return
+ }
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ t.Errorf("porttrack: bad port in report %q: %v", line, err)
+ return
+ }
+ c.mu.Lock()
+ c.ports[label] = port
+ c.cond.Broadcast()
+ c.mu.Unlock()
+ }
+}
+
+// Addr returns a magic address string that, when passed to [Listen],
+// causes the child to bind to localhost:0 and report its actual port
+// back to this collector under the given label.
+func (c *Collector) Addr(label string) string {
+ return magicPrefix + c.ln.Addr().String() + "/" + label
+}
+
+// Port blocks until the child process has reported the port for the
+// given label, then returns it. If ctx is cancelled before a port is
+// reported, Port returns the context's cause as an error.
+func (c *Collector) Port(ctx context.Context, label string) (int, error) {
+ stop := context.AfterFunc(ctx, func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.err == nil {
+ c.err = context.Cause(ctx)
+ }
+ c.cond.Broadcast()
+ })
+ defer stop()
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for {
+ if p, ok := c.ports[label]; ok {
+ return p, nil
+ }
+ if c.err != nil {
+ return 0, c.err
+ }
+ c.cond.Wait()
+ }
+}
+
+// Listen is the child/production side of the porttrack protocol.
+//
+// If address has the magic prefix (as returned by [Collector.Addr]),
+// Listen binds to localhost:0 on the given network, then TCP-connects
+// to the collector and writes "LABEL\tPORT\n" to report the actual
+// port. The collector connection is closed before returning.
+//
+// If address does not have the magic prefix, Listen is simply
+// [net.Listen](network, address).
+func Listen(network, address string) (net.Listener, error) {
+ rest, ok := strings.CutPrefix(address, magicPrefix)
+ if !ok {
+ return net.Listen(network, address)
+ }
+
+ // rest is "HOST:PORT/LABEL"
+ slashIdx := strings.LastIndex(rest, "/")
+ if slashIdx < 0 {
+ return nil, fmt.Errorf("porttrack: malformed magic address %q: missing /LABEL", address)
+ }
+ collectorAddr := rest[:slashIdx]
+ label := rest[slashIdx+1:]
+
+ ln, err := net.Listen(network, "localhost:0")
+ if err != nil {
+ return nil, err
+ }
+
+ port := ln.Addr().(*net.TCPAddr).Port
+
+ conn, err := net.Dial("tcp", collectorAddr)
+ if err != nil {
+ ln.Close()
+ return nil, fmt.Errorf("porttrack: failed to connect to collector at %s: %v", collectorAddr, err)
+ }
+ _, err = fmt.Fprintf(conn, "%s\t%d\n", label, port)
+ conn.Close()
+ if err != nil {
+ ln.Close()
+ return nil, fmt.Errorf("porttrack: failed to report port to collector: %v", err)
+ }
+
+ return ln, nil
+}
diff --git a/net/porttrack/porttrack_test.go b/net/porttrack/porttrack_test.go
new file mode 100644
index 000000000..06412d875
--- /dev/null
+++ b/net/porttrack/porttrack_test.go
@@ -0,0 +1,95 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package porttrack
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "testing"
+)
+
+func TestCollectorAndListen(t *testing.T) {
+ c := NewCollector(t)
+
+ labels := []string{"main", "plaintext", "debug"}
+ ports := make([]int, len(labels))
+
+ for i, label := range labels {
+ ln, err := Listen("tcp", c.Addr(label))
+ if err != nil {
+ t.Fatalf("Listen(%q): %v", label, err)
+ }
+ defer ln.Close()
+ p, err := c.Port(t.Context(), label)
+ if err != nil {
+ t.Fatalf("Port(%q): %v", label, err)
+ }
+ ports[i] = p
+ }
+
+ // All ports should be distinct non-zero values.
+ seen := map[int]string{}
+ for i, label := range labels {
+ if ports[i] == 0 {
+ t.Errorf("Port(%q) = 0", label)
+ }
+ if prev, ok := seen[ports[i]]; ok {
+ t.Errorf("Port(%q) = Port(%q) = %d", label, prev, ports[i])
+ }
+ seen[ports[i]] = label
+ }
+}
+
+func TestListenPassthrough(t *testing.T) {
+ ln, err := Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Listen passthrough: %v", err)
+ }
+ defer ln.Close()
+ if ln.Addr().(*net.TCPAddr).Port == 0 {
+ t.Fatal("expected non-zero port")
+ }
+}
+
+func TestRoundTrip(t *testing.T) {
+ c := NewCollector(t)
+
+ ln, err := Listen("tcp", c.Addr("http"))
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ // Start a server on the listener.
+ go http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNoContent)
+ }))
+
+ port, err := c.Port(t.Context(), "http")
+ if err != nil {
+ t.Fatalf("Port: %v", err)
+ }
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port))
+ if err != nil {
+ t.Fatalf("http.Get: %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusNoContent {
+ t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
+ }
+}
+
+func TestPortContextCancelled(t *testing.T) {
+ c := NewCollector(t)
+ // Nobody will ever report "never", so Port should block until ctx is done.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+ _, err := c.Port(ctx, "never")
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("Port with cancelled context: got %v, want %v", err, context.Canceled)
+ }
+}