summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author Brad Fitzpatrick <bradfitz@tailscale.com> 2025-09-20 16:48:18 -0700
committer Brad Fitzpatrick <bradfitz@tailscale.com> 2025-09-21 01:05:54 +0000
commit 17643b05eb5a1e24b84c6937061f781e31dce0b9 (patch)
tree 906a6fe511f9a15e5eea5810337f04b71bdf6e90
parent 1b6bc37f2859007dc4ed949b14f1f8531990b3cf (diff)
download tailscale-jamesbrad/controlhttp-race-dial.tar.xz
tailscale-jamesbrad/controlhttp-race-dial.zip
control/controlhttp: simplify, fix race dialing, remove priority concept (jamesbrad/controlhttp-race-dial)
Fixes tailscale/corp#32534

Co-authored-by: James Tucker <james@tailscale.com>
Change-Id: I4eb57f046d8b40403220e40eb67a31c41adb3a38
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
-rw-r--r-- control/controlhttp/client.go 182
-rw-r--r-- control/controlhttp/http_test.go 161
2 files changed, 121 insertions, 222 deletions
diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go
index 87061c310..5b1fed422 100644
--- a/control/controlhttp/client.go
+++ b/control/controlhttp/client.go
@@ -27,14 +27,13 @@ import (
"errors"
"fmt"
"io"
- "math"
+ "log"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"net/url"
"runtime"
- "sort"
"sync/atomic"
"time"
@@ -53,7 +52,6 @@ import (
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
- "tailscale.com/util/multierr"
)
var stdDialer net.Dialer
@@ -110,18 +108,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
}
candidates := a.DialPlan.Candidates
- // Otherwise, we try dialing per the plan. Store the highest priority
- // in the list, so that if we get a connection to one of those
- // candidates we can return quickly.
- var highestPriority int = math.MinInt
- for _, c := range candidates {
- if c.Priority > highestPriority {
- highestPriority = c.Priority
- }
- }
-
- // This context allows us to cancel in-flight connections if we get a
- // highest-priority connection before we're all done.
+ // Create a context to be canceled as we return, so once we get a good connection,
+ // we can drop all the other ones.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -129,142 +117,61 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
type dialResult struct {
conn *ClientConn
err error
- cand tailcfg.ControlIPCandidate
}
- resultsCh := make(chan dialResult, len(candidates))
-
- var pending atomic.Int32
- pending.Store(int32(len(candidates)))
- for _, c := range candidates {
- go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
- var (
- conn *ClientConn
- err error
- )
-
- // Always send results back to our channel.
- defer func() {
- resultsCh <- dialResult{conn, err, c}
- if pending.Add(-1) == 0 {
- close(resultsCh)
- }
- }()
+ resultsCh := make(chan dialResult) // unbuffered, never closed
- // If non-zero, wait the configured start timeout
- // before we do anything.
- if c.DialStartDelaySec > 0 {
- a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
- tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
- defer tmr.Stop()
- select {
- case <-ctx.Done():
- err = ctx.Err()
- return
- case <-tmrChannel:
- }
- }
+ dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) {
+ a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %v", cand.DialStartDelaySec, a.Hostname, cand.IP)
- // Now, create a sub-context with the given timeout and
- // try dialing the provided host.
- ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
- defer cancel()
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second)))
+ defer cancel()
- if c.IP.IsValid() {
- a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
- } else if c.ACEHost != "" {
- a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost)
- }
- // This will dial, and the defer above sends it back to our parent.
- conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost)
- }(ctx, c)
+ if cand.IP.IsValid() {
+ a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, cand.IP)
+ } else if cand.ACEHost != "" {
+ a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, cand.ACEHost)
+ }
+ // This will dial; the caller sends the result back to our parent on resultsCh.
+ return a.dialHostOpt(ctx, cand.IP, cand.ACEHost)
}
- var results []dialResult
- for res := range resultsCh {
- // If we get a response that has the highest priority, we don't
- // need to wait for any of the other connections to finish; we
- // can just return this connection.
- //
- // TODO(andrew): we could make this better by keeping track of
- // the highest remaining priority dynamically, instead of just
- // checking for the highest total
- if res.cand.Priority == highestPriority && res.conn != nil {
- a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String()))
-
- // Drain the channel and any existing connections in
- // the background.
+ for _, cand := range candidates {
+ timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() {
go func() {
- for _, res := range results {
- if res.conn != nil {
- res.conn.Close()
+ conn, err := dialCand(cand)
+ select {
+ case resultsCh <- dialResult{conn, err}:
+ if err == nil {
+ a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String()))
}
- }
- for res := range resultsCh {
- if res.conn != nil {
- res.conn.Close()
+ case <-ctx.Done():
+ if conn != nil {
+ conn.Close()
}
}
- if a.drainFinished != nil {
- close(a.drainFinished)
- }
}()
- return res.conn, nil
- }
-
- // This isn't a highest-priority result, so just store it until
- // we're done.
- results = append(results, res)
+ })
+ defer timer.Stop()
}
- // After we finish this function, close any remaining open connections.
- defer func() {
- for _, result := range results {
- // Note: below, we nil out the returned connection (if
- // any) in the slice so we don't close it.
- if result.conn != nil {
- result.conn.Close()
+ var errs []error
+ for {
+ select {
+ case res := <-resultsCh:
+ if res.err == nil {
+ return res.conn, nil
}
+ errs = append(errs, res.err)
+ if len(errs) == len(candidates) {
+ // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
+ a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...))
+ return a.dialHost(ctx)
+ }
+ case <-ctx.Done():
+ a.logf("controlhttp: context aborted dialing")
+ return nil, ctx.Err()
}
-
- // We don't drain asynchronously after this point, so notify our
- // channel when we return.
- if a.drainFinished != nil {
- close(a.drainFinished)
- }
- }()
-
- // Sort by priority, then take the first non-error response.
- sort.Slice(results, func(i, j int) bool {
- // NOTE: intentionally inverted so that the highest priority
- // item comes first
- return results[i].cand.Priority > results[j].cand.Priority
- })
-
- var (
- conn *ClientConn
- errs []error
- )
- for i, result := range results {
- if result.err != nil {
- errs = append(errs, result.err)
- continue
- }
-
- a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String()))
- conn = result.conn
- results[i].conn = nil // so we don't close it in the defer
- return conn, nil
}
- if ctx.Err() != nil {
- a.logf("controlhttp: context aborted dialing")
- return nil, ctx.Err()
- }
-
- merr := multierr.New(errs...)
-
- // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
- a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
- return a.dialHost(ctx)
}
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
@@ -422,6 +329,11 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost
go try(u443)
} // else we lost the race and it started already which is what we want
case u443:
+ if u80 == nil {
+ log.Printf("XXXX no port 80 so returning error: %v", res.err)
+ // We never started a port 80 dial, so just return the port 443 error.
+ return nil, res.err
+ }
err443 = res.err
default:
panic("invalid")
diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go
index 0b4e117f9..ffae74a6f 100644
--- a/control/controlhttp/http_test.go
+++ b/control/controlhttp/http_test.go
@@ -15,17 +15,19 @@ import (
"net/http/httputil"
"net/netip"
"net/url"
- "runtime"
"slices"
"strconv"
+ "strings"
"sync"
"testing"
+ "testing/synctest"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpcommon"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/health"
+ "tailscale.com/net/memnet"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/net/socks5"
@@ -545,35 +547,13 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
}
func TestDialPlan(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skip("only works on Linux due to multiple localhost addresses")
- }
-
client, server := key.NewMachine(), key.NewMachine()
const (
testProtocolVersion = 1
)
- getRandomPort := func() string {
- ln, err := net.Listen("tcp", ":0")
- if err != nil {
- t.Fatalf("net.Listen: %v", err)
- }
- defer ln.Close()
- _, port, err := net.SplitHostPort(ln.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- return port
- }
-
- // We need consistent ports for each address; these are chosen
- // randomly and we hope that they won't conflict during this test.
- httpPort := getRandomPort()
- httpsPort := getRandomPort()
-
- makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
+ makeHandler := func(t *testing.T, memNet *memnet.Network, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
done := make(chan struct{})
t.Cleanup(func() {
close(done)
@@ -592,11 +572,11 @@ func TestDialPlan(t *testing.T) {
handler = wrap(handler)
}
- httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
+ httpLn, err := memNet.Listen("tcp", host.String()+":80")
if err != nil {
t.Fatalf("HTTP listen: %v", err)
}
- httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
+ httpsLn, err := memNet.Listen("tcp", host.String()+":443")
if err != nil {
t.Fatalf("HTTPS listen: %v", err)
}
@@ -616,7 +596,6 @@ func TestDialPlan(t *testing.T) {
t.Cleanup(func() {
httpsServer.Close()
})
- return
}
fallbackAddr := netip.MustParseAddr("127.0.0.1")
@@ -686,74 +665,82 @@ func TestDialPlan(t *testing.T) {
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
- // TODO(awly): replace this with tstest.NewClock and update the
- // test to advance the clock correctly.
- clock := tstime.StdClock{}
- makeHandler(t, "fallback", fallbackAddr, nil)
- makeHandler(t, "good", goodAddr, nil)
- makeHandler(t, "other", otherAddr, nil)
- makeHandler(t, "other2", other2Addr, nil)
- makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
- return brokenMITMHandler(clock)
- })
+ synctest.Test(t, func(t *testing.T) {
- dialer := closeTrackDialer{
- t: t,
- inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial,
- conns: make(map[*closeTrackConn]bool),
- }
- defer dialer.Done()
+ // Get the synctest clock way out to 2025 at least so the
+ // net/http/httptest TLS client certs are valid?
+ // TODO(bradfitz): this might not be necessary. Still debugging.
+ time.Sleep(26 * 365 * 24 * time.Hour)
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
+ var memNet memnet.Network
- // By default, we intentionally point to something that
- // we know won't connect, since we want a fallback to
- // DNS to be an error.
- host := "example.com"
- if tt.allowFallback {
- host = "localhost"
- }
+ clock := tstime.StdClock{}
+ makeHandler(t, &memNet, "fallback", fallbackAddr, nil)
+ makeHandler(t, &memNet, "good", goodAddr, nil)
+ makeHandler(t, &memNet, "other", otherAddr, nil)
+ makeHandler(t, &memNet, "other2", other2Addr, nil)
+ makeHandler(t, &memNet, "broken", brokenAddr, func(h http.Handler) http.Handler {
+ return brokenMITMHandler(clock)
+ })
- drained := make(chan struct{})
- a := &Dialer{
- Hostname: host,
- HTTPPort: httpPort,
- HTTPSPort: httpsPort,
- MachineKey: client,
- ControlKey: server.Public(),
- ProtocolVersion: testProtocolVersion,
- Dialer: dialer.Dial,
- Logf: t.Logf,
- DialPlan: tt.plan,
- proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
- drainFinished: drained,
- omitCertErrorLogging: true,
- testFallbackDelay: 50 * time.Millisecond,
- Clock: clock,
- HealthTracker: health.NewTracker(eventbustest.NewBus(t)),
- }
+ dialer := closeTrackDialer{
+ t: t,
+ inner: memNet.Dial,
+ conns: make(map[*closeTrackConn]bool),
+ }
+ defer dialer.Done()
- conn, err := a.dial(ctx)
- if err != nil {
- t.Fatalf("dialing controlhttp: %v", err)
- }
- defer conn.Close()
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // By default, we intentionally point to something that
+ // we know won't connect, since we want a fallback to
+ // DNS to be an error.
+ host := "example.com"
+ if tt.allowFallback {
+ host = "localhost"
+ }
- raddr := conn.RemoteAddr().(*net.TCPAddr)
+ a := &Dialer{
+ Hostname: host,
+ MachineKey: client,
+ ControlKey: server.Public(),
+ ProtocolVersion: testProtocolVersion,
+ Dialer: dialer.Dial,
+ Logf: t.Logf,
+ DialPlan: tt.plan,
+ proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
+ omitCertErrorLogging: true,
+ testFallbackDelay: 50 * time.Millisecond,
+ Clock: clock,
+ HealthTracker: health.NewTracker(eventbustest.NewBus(t)),
+ }
- got, ok := netip.AddrFromSlice(raddr.IP)
- if !ok {
- t.Errorf("invalid remote IP: %v", raddr.IP)
- } else if got != tt.want {
- t.Errorf("got connection from %q; want %q", got, tt.want)
- } else {
- t.Logf("successfully connected to %q", raddr.String())
- }
+ conn, err := a.dial(ctx)
+ if err != nil {
+ t.Fatalf("dialing controlhttp: %v", err)
+ }
+ defer conn.Close()
+
+ raddrStr := conn.RemoteAddr().String()
+
+ raddrStr = strings.TrimSuffix(raddrStr, "|1") // memnet noise
+ raddrPort, err := netip.ParseAddrPort(raddrStr)
+ if err != nil {
+ t.Fatalf("parsing remote addr %q: %v", raddrStr, err)
+ }
+
+ got := raddrPort.Addr()
+ if got != tt.want {
+ t.Errorf("got connection from %q; want %q", got, tt.want)
+ } else {
+ t.Logf("successfully connected to %q", got)
+ }
- // Wait until our dialer drains so we can verify that
- // all connections are closed.
- <-drained
+ // Wait until our dialer drains so we can verify that
+ // all connections are closed.
+ synctest.Wait()
+ })
})
}
}