summaryrefslogtreecommitdiffhomepage
path: root/control/controlclient
diff options
context:
space:
mode:
Diffstat (limited to 'control/controlclient')
-rw-r--r--control/controlclient/direct.go32
-rw-r--r--control/controlclient/noise.go50
2 files changed, 64 insertions, 18 deletions
diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go
index f5d1f0410..8c501dc10 100644
--- a/control/controlclient/direct.go
+++ b/control/controlclient/direct.go
@@ -42,6 +42,7 @@ import (
"tailscale.com/net/tlsdial"
"tailscale.com/net/tsdial"
"tailscale.com/net/tshttpproxy"
+ "tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tka"
"tailscale.com/tstime"
@@ -82,6 +83,11 @@ type Direct struct {
dialPlan ControlDialPlanner // can be nil
+ // lastServerAddr is set to the most recent address that we
+ // successfully connected to. It is used to prioritize this address
+ // when reconnecting (e.g. when a control server restart happens).
+ lastServerAddr syncs.AtomicValue[netip.Addr]
+
mu sync.Mutex // mutex guards the following fields
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
serverNoiseKey key.MachinePublic
@@ -1428,6 +1434,8 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, d time.Duration, cl
}
}
+var useLastAddr = envknob.RegisterBool("TS_CONTROLCLIENT_USE_LAST_ADDR")
+
// getNoiseClient returns the noise client, creating one if one doesn't exist.
func (c *Direct) getNoiseClient() (*NoiseClient, error) {
c.mu.Lock()
@@ -1444,6 +1452,12 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) {
if c.dialPlan != nil {
dp = c.dialPlan.Load
}
+
+ var lastAddr *syncs.AtomicValue[netip.Addr]
+ if useLastAddr() {
+ lastAddr = &c.lastServerAddr
+ }
+
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) {
k, err := c.getMachinePrivKey()
if err != nil {
@@ -1451,18 +1465,20 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) {
}
c.logf("[v1] creating new noise client")
nc, err := NewNoiseClient(NoiseOpts{
- PrivKey: k,
- ServerPubKey: serverNoiseKey,
- ServerURL: c.serverURL,
- Dialer: c.dialer,
- DNSCache: c.dnsCache,
- Logf: c.logf,
- NetMon: c.netMon,
- DialPlan: dp,
+ PrivKey: k,
+ ServerPubKey: serverNoiseKey,
+ ServerURL: c.serverURL,
+ Dialer: c.dialer,
+ DNSCache: c.dnsCache,
+ Logf: c.logf,
+ NetMon: c.netMon,
+ DialPlan: dp,
+ LastServerAddr: lastAddr,
})
if err != nil {
return nil, err
}
+
c.mu.Lock()
defer c.mu.Unlock()
c.noiseClient = nc
diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go
index f3e5f1bde..3191312f5 100644
--- a/control/controlclient/noise.go
+++ b/control/controlclient/noise.go
@@ -12,6 +12,7 @@ import (
"io"
"math"
"net/http"
+ "net/netip"
"net/url"
"sync"
"time"
@@ -22,6 +23,7 @@ import (
"tailscale.com/net/dnscache"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
+ "tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key"
@@ -172,6 +174,8 @@ type NoiseClient struct {
// be nil.
dialPlan func() *tailcfg.ControlDialPlan
+ lastServerAddr *syncs.AtomicValue[netip.Addr] // can be nil
+
logf logger.Logf
netMon *netmon.Monitor
@@ -207,6 +211,12 @@ type NoiseOpts struct {
// DialPlan, if set, is a function that should return an explicit plan
// on how to connect to the server.
DialPlan func() *tailcfg.ControlDialPlan
+ // LastServerAddr, if non-nil, contains storage for the last address
+ // used to (successfully) connect to the control server. It will be
+ // prioritized when making a connection to the server.
+ //
+ // If nil, no last address will be stored or used.
+ LastServerAddr *syncs.AtomicValue[netip.Addr]
}
// NewNoiseClient returns a new noiseClient for the provided server and machine key.
@@ -237,16 +247,17 @@ func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) {
}
np := &NoiseClient{
- serverPubKey: opts.ServerPubKey,
- privKey: opts.PrivKey,
- host: u.Hostname(),
- httpPort: httpPort,
- httpsPort: httpsPort,
- dialer: opts.Dialer,
- dnsCache: opts.DNSCache,
- dialPlan: opts.DialPlan,
- logf: opts.Logf,
- netMon: opts.NetMon,
+ serverPubKey: opts.ServerPubKey,
+ privKey: opts.PrivKey,
+ host: u.Hostname(),
+ httpPort: httpPort,
+ httpsPort: httpsPort,
+ dialer: opts.Dialer,
+ dnsCache: opts.DNSCache,
+ dialPlan: opts.DialPlan,
+ lastServerAddr: opts.LastServerAddr,
+ logf: opts.Logf,
+ netMon: opts.NetMon,
}
// Create the HTTP/2 Transport using a net/http.Transport
@@ -334,6 +345,14 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
})
var ce contextErr
if err == nil || !errors.As(err, &ce) {
+ // Store this address as our last-successful address for future
+ // use if we need to reconnect.
+ if nc.lastServerAddr != nil {
+ if addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()); err == nil {
+ nc.lastServerAddr.Store(addr.Addr())
+ }
+ }
+
return conn, err
}
if ctx.Err() == nil {
@@ -429,6 +448,16 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
// handshake.
timeoutSec += 5
+ // If we have a last server address, then give ourselves a bit more
+ // time to try it first.
+ var lastAddr netip.Addr
+ if nc.lastServerAddr != nil {
+ lastAddr = nc.lastServerAddr.Load()
+ }
+ if lastAddr.IsValid() {
+ timeoutSec += 5
+ }
+
// Be extremely defensive and ensure that the timeout is in the range
// [5, 60] seconds (e.g. if we accidentally get a negative number).
if timeoutSec > 60 {
@@ -451,6 +480,7 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
Dialer: nc.dialer.SystemDial,
DNSCache: nc.dnsCache,
DialPlan: dialPlan,
+ LastServerAddr: lastAddr,
Logf: nc.logf,
NetMon: nc.netMon,
Clock: tstime.StdClock{},