summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJames Tucker <james@tailscale.com>2025-11-03 14:53:11 -0800
committerJames Tucker <james@tailscale.com>2025-11-19 12:35:26 -0800
commitadf7bbf902e6410139d7517d23ace8c10a91bd00 (patch)
tree9e2a73f8ec284bfc2e4ead6b2cde67b677c7f36c
parent3b865d7c33b1e945e9122dbe6f4eeff696a84e0a (diff)
downloadtailscale-raggi/disco-key-tsmp.tar.xz
tailscale-raggi/disco-key-tsmp.zip
net,wgengine: add support for disco key exchange via TSMPraggi/disco-key-tsmp
Updates tailscale/corp#34037 Signed-off-by: James Tucker <james@tailscale.com>
-rw-r--r--net/packet/tsmp.go68
-rw-r--r--net/packet/tsmp_test.go168
-rw-r--r--net/tstun/wrap.go47
-rw-r--r--wgengine/magicsock/derp.go10
-rw-r--r--wgengine/magicsock/magicsock.go142
-rw-r--r--wgengine/magicsock/magicsock_test.go56
-rw-r--r--wgengine/magicsock/tsmp_disco_test.go291
-rw-r--r--wgengine/userspace.go48
8 files changed, 818 insertions, 12 deletions
diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go
index 0ea321e84..b9b805b1a 100644
--- a/net/packet/tsmp.go
+++ b/net/packet/tsmp.go
@@ -18,7 +18,7 @@ import (
"tailscale.com/types/ipproto"
)
-const minTSMPSize = 7 // the rejected body is 7 bytes
+const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
// TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a
@@ -72,6 +72,12 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
+
+ // TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
+ TSMPTypeDiscoKeyRequest TSMPType = 'd'
+
+ // TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
+ TSMPTypeDiscoKeyUpdate TSMPType = 'D'
)
type TailscaleRejectReason byte
@@ -259,3 +265,63 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
return nil
}
+
+// TSMPDiscoKeyRequest is a TSMP message asking a peer for its disco key.
+//
+// On the wire, after the IP header, it's currently 1 byte:
+//   - 'd' (TSMPTypeDiscoKeyRequest)
+type TSMPDiscoKeyRequest struct{}
+
+// AsTSMPDiscoKeyRequest returns pp as a TSMPDiscoKeyRequest and whether it is
+// one. pp qualifies when it is a TSMP packet whose payload starts with the
+// TSMPTypeDiscoKeyRequest type byte.
+func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
+	if pp.IPProto != ipproto.TSMP {
+		return h, false
+	}
+	payload := pp.Payload()
+	if len(payload) == 0 {
+		return h, false
+	}
+	if payload[0] != byte(TSMPTypeDiscoKeyRequest) {
+		return h, false
+	}
+	return h, true
+}
+
+// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
+// It may be sent in response to a request, or unsolicited when a node
+// believes its peer may have stale disco key information.
+//
+// On the wire, after the IP header, it's currently 33 bytes:
+// - 'D' (TSMPTypeDiscoKeyUpdate)
+// - 32 bytes disco public key
+type TSMPDiscoKeyUpdate struct {
+ // IPHeader is the IP header that Marshal writes ahead of the TSMP payload.
+ // AsTSMPDiscoKeyUpdate leaves it unset.
+ IPHeader Header
+ DiscoKey [32]byte // raw disco public key bytes
+}
+
+// AsTSMPDiscoKeyUpdate reports whether pp is a TSMP disco key update and, if
+// so, returns it with DiscoKey copied out of the payload.
+// The update.IPHeader field is not populated.
+func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
+	if pp.IPProto != ipproto.TSMP {
+		return update, false
+	}
+	payload := pp.Payload()
+	// The payload must hold the type byte plus the full 32-byte key.
+	if len(payload) < 33 {
+		return update, false
+	}
+	if payload[0] != byte(TSMPTypeDiscoKeyUpdate) {
+		return update, false
+	}
+	copy(update.DiscoKey[:], payload[1:33])
+	return update, true
+}
+
+// Len returns the marshaled size in bytes: the IP header followed by the
+// one-byte TSMP type and the disco key.
+func (h TSMPDiscoKeyUpdate) Len() int {
+	const typeByteLen = 1
+	return h.IPHeader.Len() + typeByteLen + len(h.DiscoKey)
+}
+
+// Marshal writes the IP header, the TSMPTypeDiscoKeyUpdate type byte and the
+// disco key into buf. It returns errSmallBuffer when buf cannot hold the IP
+// header plus the 33-byte payload, or the error from marshaling the IP header.
+func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
+	hdrLen := h.IPHeader.Len()
+	if len(buf) < hdrLen+33 {
+		return errSmallBuffer
+	}
+	err := h.IPHeader.Marshal(buf)
+	if err != nil {
+		return err
+	}
+	payload := buf[hdrLen:]
+	payload[0] = byte(TSMPTypeDiscoKeyUpdate)
+	copy(payload[1:33], h.DiscoKey[:])
+	return nil
+}
diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go
index e261e6a41..a8cd3cad5 100644
--- a/net/packet/tsmp_test.go
+++ b/net/packet/tsmp_test.go
@@ -71,3 +71,171 @@ func TestTailscaleRejectedHeader(t *testing.T) {
}
}
}
+
+// TestTSMPDiscoKeyRequest checks that a one-byte 'd' TSMP payload is
+// recognized by AsTSMPDiscoKeyRequest, first on a hand-built Parsed and then
+// on IPv4 and IPv6 packets produced by Generate and run through Decode.
+func TestTSMPDiscoKeyRequest(t *testing.T) {
+	t.Run("Manual", func(t *testing.T) {
+		const hdrLen = 40 // simulate after IP header
+		raw := make([]byte, hdrLen+1)
+		raw[hdrLen] = byte(TSMPTypeDiscoKeyRequest)
+
+		var p Parsed
+		p.IPProto = TSMP
+		p.dataofs = hdrLen
+		p.b = raw
+		p.length = len(raw)
+
+		if _, ok := p.AsTSMPDiscoKeyRequest(); !ok {
+			t.Fatal("failed to parse TSMP disco key request")
+		}
+	})
+
+	t.Run("RoundTripIPv4", func(t *testing.T) {
+		src := netip.MustParseAddr("100.64.0.1")
+		dst := netip.MustParseAddr("100.64.0.2")
+
+		hdr := IP4Header{
+			IPProto: TSMP,
+			Src:     src,
+			Dst:     dst,
+		}
+		body := []byte{byte(TSMPTypeDiscoKeyRequest)}
+
+		pkt := Generate(hdr, body)
+		t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
+
+		// Manually check what decode4 would see
+		if len(pkt) >= 4 {
+			declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
+			t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
+			t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
+		}
+
+		var p Parsed
+		p.Decode(pkt)
+		t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
+			p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
+
+		if p.IPVersion != 4 {
+			t.Errorf("IPVersion = %d, want 4", p.IPVersion)
+		}
+		if p.IPProto != TSMP {
+			t.Errorf("IPProto = %v, want TSMP", p.IPProto)
+		}
+		if got := p.Src.Addr(); got != src {
+			t.Errorf("Src = %v, want %v", got, src)
+		}
+		if got := p.Dst.Addr(); got != dst {
+			t.Errorf("Dst = %v, want %v", got, dst)
+		}
+
+		if _, ok := p.AsTSMPDiscoKeyRequest(); !ok {
+			t.Fatal("failed to parse TSMP disco key request from generated packet")
+		}
+	})
+
+	t.Run("RoundTripIPv6", func(t *testing.T) {
+		src := netip.MustParseAddr("2001:db8::1")
+		dst := netip.MustParseAddr("2001:db8::2")
+
+		hdr := IP6Header{
+			IPProto: TSMP,
+			Src:     src,
+			Dst:     dst,
+		}
+		body := []byte{byte(TSMPTypeDiscoKeyRequest)}
+
+		pkt := Generate(hdr, body)
+		t.Logf("Generated packet: %d bytes", len(pkt))
+
+		var p Parsed
+		p.Decode(pkt)
+
+		if p.IPVersion != 6 {
+			t.Errorf("IPVersion = %d, want 6", p.IPVersion)
+		}
+		if p.IPProto != TSMP {
+			t.Errorf("IPProto = %v, want TSMP", p.IPProto)
+		}
+		if got := p.Src.Addr(); got != src {
+			t.Errorf("Src = %v, want %v", got, src)
+		}
+		if got := p.Dst.Addr(); got != dst {
+			t.Errorf("Dst = %v, want %v", got, dst)
+		}
+
+		if _, ok := p.AsTSMPDiscoKeyRequest(); !ok {
+			t.Fatal("failed to parse TSMP disco key request from generated packet")
+		}
+	})
+}
+
+// TestTSMPDiscoKeyUpdate marshals a TSMPDiscoKeyUpdate over an IPv4 and an
+// IPv6 header, decodes the resulting packet, and checks that
+// AsTSMPDiscoKeyUpdate returns the same disco key.
+func TestTSMPDiscoKeyUpdate(t *testing.T) {
+	var wantKey [32]byte
+	for i := range wantKey {
+		wantKey[i] = byte(i + 10)
+	}
+
+	cases := []struct {
+		name string
+		hdr  Header
+	}{
+		{
+			name: "IPv4",
+			hdr: IP4Header{
+				IPProto: TSMP,
+				Src:     netip.MustParseAddr("1.2.3.4"),
+				Dst:     netip.MustParseAddr("5.6.7.8"),
+			},
+		},
+		{
+			name: "IPv6",
+			hdr: IP6Header{
+				IPProto: TSMP,
+				Src:     netip.MustParseAddr("2001:db8::1"),
+				Dst:     netip.MustParseAddr("2001:db8::2"),
+			},
+		},
+	}
+
+	for _, tc := range cases {
+		hdr := tc.hdr
+		t.Run(tc.name, func(t *testing.T) {
+			msg := TSMPDiscoKeyUpdate{
+				IPHeader: hdr,
+				DiscoKey: wantKey,
+			}
+
+			pkt := make([]byte, msg.Len())
+			if err := msg.Marshal(pkt); err != nil {
+				t.Fatal(err)
+			}
+
+			var p Parsed
+			p.Decode(pkt)
+
+			got, ok := p.AsTSMPDiscoKeyUpdate()
+			if !ok {
+				t.Fatal("failed to parse TSMP disco key update")
+			}
+			if got.DiscoKey != wantKey {
+				t.Errorf("disco key mismatch: got %v, want %v", got.DiscoKey, wantKey)
+			}
+		})
+	}
+}
diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go
index db4f689bf..234dc4941 100644
--- a/net/tstun/wrap.go
+++ b/net/tstun/wrap.go
@@ -188,11 +188,19 @@ type Wrapper struct {
// OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives.
OnTSMPPongReceived func(packet.TSMPPongReply)
+ // OnTSMPDiscoKeyReceived, if non-nil, is called whenever a TSMP disco key update arrives.
+ // The srcIP parameter identifies the peer that sent the update.
+ OnTSMPDiscoKeyReceived func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate)
+
// OnICMPEchoResponseReceived, if non-nil, is called whenever a ICMP echo response
// arrives. If the packet is to be handled internally this returns true,
// false otherwise.
OnICMPEchoResponseReceived func(*packet.Parsed) bool
+ // GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
+ // This is called when responding to TSMP disco key requests.
+ GetDiscoPublicKey func() key.DiscoPublic
+
// PeerAPIPort, if non-nil, returns the peerapi port that's
// running for the given IP address.
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
@@ -1132,6 +1140,15 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
if f := t.OnTSMPPongReceived; f != nil {
f(data)
}
+ } else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
+ t.noteActivity()
+ t.injectOutboundDiscoKeyUpdate(p)
+ return filter.DropSilently, gro
+ } else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
+ if f := t.OnTSMPDiscoKeyReceived; f != nil {
+ f(p.Src.Addr(), discoKeyUpdate)
+ }
+ return filter.DropSilently, gro
}
}
@@ -1440,6 +1457,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
t.InjectOutbound(packet.Generate(pong, nil))
}
+// injectOutboundDiscoKeyUpdate replies to the TSMP disco key request pp by
+// injecting an outbound TSMPDiscoKeyUpdate that carries the key returned by
+// GetDiscoPublicKey, addressed back to pp's sender.
+//
+// Nothing is sent when GetDiscoPublicKey is nil, when it returns a zero key,
+// or when pp is neither IPv4 nor IPv6.
+func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
+	getKey := t.GetDiscoPublicKey
+	if getKey == nil {
+		return
+	}
+	discoKey := getKey()
+	if discoKey.IsZero() {
+		return
+	}
+
+	// Build the reply header by swapping the request's source and destination.
+	var hdr packet.Header
+	switch pp.IPVersion {
+	case 4:
+		h4 := pp.IP4Header()
+		h4.ToResponse()
+		hdr = h4
+	case 6:
+		h6 := pp.IP6Header()
+		h6.ToResponse()
+		hdr = h6
+	default:
+		return
+	}
+
+	reply := packet.TSMPDiscoKeyUpdate{
+		IPHeader: hdr,
+		DiscoKey: discoKey.Raw32(),
+	}
+	t.InjectOutbound(packet.Generate(reply, nil))
+}
+
// InjectOutbound makes the Wrapper device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.
diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go
index 37a4f1a64..8896d4009 100644
--- a/wgengine/magicsock/derp.go
+++ b/wgengine/magicsock/derp.go
@@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
}
+ // Request disco key from peer via TSMP if we receive a WireGuard handshake
+ // over DERP without recent disco success. This handles the "WireGuard-first"
+ // case where WireGuard establishes a tunnel via DERP before disco succeeds
+ // (e.g., control plane unreachable or stale disco keys).
+ // We trigger on the handshake messages themselves (initiation or response);
+ // requestDiscoKeyViaTSMP delays its send briefly before issuing the TSMP request.
+ if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
+ go c.requestDiscoKeyViaTSMP(dm.src, ep)
+ }
+
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 064838a2d..ae5ea3daa 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -155,17 +155,18 @@ type Conn struct {
// This block mirrors the contents and field order of the Options
// struct. Initialized once at construction, then constant.
- eventBus *eventbus.Bus
- eventClient *eventbus.Client
- logf logger.Logf
- epFunc func([]tailcfg.Endpoint)
- derpActiveFunc func()
- idleFunc func() time.Duration // nil means unknown
- testOnlyPacketListener nettype.PacketListener
- noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
- netMon *netmon.Monitor // must be non-nil
- health *health.Tracker // or nil
- controlKnobs *controlknobs.Knobs // or nil
+ eventBus *eventbus.Bus
+ eventClient *eventbus.Client
+ logf logger.Logf
+ epFunc func([]tailcfg.Endpoint)
+ derpActiveFunc func()
+ idleFunc func() time.Duration // nil means unknown
+ testOnlyPacketListener nettype.PacketListener
+ noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
+ sendTSMPDiscoKeyRequest func(netip.Addr) error // or nil, sends TSMP disco key request to peer
+ netMon *netmon.Monitor // must be non-nil
+ health *health.Tracker // or nil
+ controlKnobs *controlknobs.Knobs // or nil
// ================================================================
// No locking required to access these fields, either because
@@ -1800,6 +1801,15 @@ func looksLikeInitiationMsg(b []byte) bool {
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
}
+// looksLikeWireGuardHandshake reports whether b has both the message type and
+// the exact length of a WireGuard handshake initiation or handshake response.
+func looksLikeWireGuardHandshake(b []byte) bool {
+	if len(b) < 4 {
+		return false
+	}
+	switch binary.LittleEndian.Uint32(b) {
+	case device.MessageInitiationType:
+		return len(b) == device.MessageInitiationSize
+	case device.MessageResponseType:
+		return len(b) == device.MessageResponseSize
+	}
+	return false
+}
+
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
//
// size is the length of 'b' to report up to wireguard-go (only relevant if
@@ -2857,6 +2867,14 @@ func (c *Conn) SetSilentDisco(v bool) {
})
}
+// SetSendTSMPDiscoKeyRequest installs fn as the callback used to send a TSMP
+// disco key request to a peer address. The engine/tundev supplies it so that
+// magicsock can have TSMP packets injected. The field is written under c.mu.
+func (c *Conn) SetSendTSMPDiscoKeyRequest(fn func(netip.Addr) error) {
+	c.mu.Lock()
+	c.sendTSMPDiscoKeyRequest = fn
+	c.mu.Unlock()
+}
+
// SilentDisco returns true if silent disco is enabled, otherwise false.
func (c *Conn) SilentDisco() bool {
c.mu.Lock()
@@ -4104,6 +4122,13 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
+
+ // TSMP disco key exchange
+ metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
+ metricTSMPDiscoKeyRequestError = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_error")
+ metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
+ metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
+ metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
)
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@@ -4242,6 +4267,101 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
// See http://go/corp/29422 & http://go/corp/30042
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
+
+ // Request disco key from peer via TSMP if we establish a tunnel
+ // without a recent disco ping. This handles cases where WireGuard
+ // establishes a tunnel before disco succeeds (e.g., control plane
+ // unreachable or stale disco keys).
+ go le.c.requestDiscoKeyViaTSMP(pubKey, ep)
+}
+
+// requestDiscoKeyViaTSMP sends a TSMP disco key request to the peer behind ep
+// unless a disco ping for ep's current disco key was recorded within
+// discoPingInterval.
+//
+// It does nothing when no send callback has been installed via
+// SetSendTSMPDiscoKeyRequest or when ep has no valid node address. nodeKey is
+// used only for logging.
+//
+// It takes c.mu and sleeps briefly, so it must be called without c.mu held;
+// the visible callers start it on its own goroutine.
+func (c *Conn) requestDiscoKeyViaTSMP(nodeKey key.NodePublic, ep *endpoint) {
+	if !ep.nodeAddr.IsValid() {
+		return
+	}
+
+	// Snapshot the callback under c.mu: SetSendTSMPDiscoKeyRequest writes the
+	// field under that lock, so reading it unlocked from this goroutine would
+	// be a data race.
+	c.mu.Lock()
+	send := c.sendTSMPDiscoKeyRequest
+	recentDiscoPing := false
+	if epDisco := ep.disco.Load(); epDisco != nil {
+		di := c.discoInfo[epDisco.key]
+		recentDiscoPing = di != nil && time.Since(di.lastPingTime) < discoPingInterval
+	}
+	c.mu.Unlock()
+
+	if send == nil || recentDiscoPing {
+		return
+	}
+
+	// YUCK. once again goroutines fight back - we need to deterministically
+	// schedule this _after_ the wireguard handshake response or else we trigger
+	// the wireguard handshake race problem. Maybe it's ok though, as we should
+	// really be singleflighting this, and perhaps we just use a singleflight
+	// with a short cork.
+	time.Sleep(time.Millisecond)
+
+	c.logf("magicsock: sending TSMP disco key request to %v (%v)", nodeKey.ShortString(), ep.nodeAddr)
+	if err := send(ep.nodeAddr); err != nil {
+		c.logf("magicsock: failed to send TSMP disco key request: %v", err)
+		metricTSMPDiscoKeyRequestError.Add(1)
+		return
+	}
+	metricTSMPDiscoKeyRequestSent.Add(1)
+}
+
+// HandleDiscoKeyUpdate processes a TSMP disco key update.
+// The update may be solicited (in response to a request) or unsolicited.
+// srcIP is the Tailscale IP of the peer that sent the update.
+//
+// The sending peer is found by matching srcIP against the addresses of the
+// known peers. When a peer and its endpoint are found, the endpoint's disco
+// key is replaced with the received one and the peer map is re-indexed.
+// Updates carrying an all-zero key, updates from an address that matches no
+// peer, and updates for a peer with no endpoint are logged and ignored.
+//
+// It acquires c.mu, so it must be called without c.mu held.
+func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
+	discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
+	c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
+	metricTSMPDiscoKeyUpdateReceived.Add(1)
+
+	// The zero DiscoPublic is used below as the "no previous key" sentinel,
+	// and senders never emit it, so don't install it on an endpoint.
+	if discoKey.IsZero() {
+		c.logf("magicsock: ignoring zero disco key update from %v", srcIP)
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Map the source IP back to the node key of the peer that owns it.
+	var nodeKey key.NodePublic
+	found := false
+peers:
+	for _, peer := range c.peers.All() {
+		for _, addr := range peer.Addresses().All() {
+			if addr.Addr() == srcIP {
+				nodeKey = peer.Key()
+				found = true
+				break peers
+			}
+		}
+	}
+
+	if !found {
+		c.logf("magicsock: disco key update from unknown peer %v", srcIP)
+		metricTSMPDiscoKeyUpdateUnknown.Add(1)
+		return
+	}
+
+	ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
+	if !ok {
+		c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
+		return
+	}
+
+	// Remember the key being replaced so upsertEndpoint can drop the stale
+	// mapping; it stays zero when the endpoint had no disco key.
+	oldDiscoKey := key.DiscoPublic{}
+	if epDisco := ep.disco.Load(); epDisco != nil {
+		oldDiscoKey = epDisco.key
+	}
+	c.discoInfoForKnownPeerLocked(discoKey)
+	ep.disco.Store(&endpointDisco{
+		key:   discoKey,
+		short: discoKey.ShortString(),
+	})
+	c.peerMap.upsertEndpoint(ep, oldDiscoKey)
+	c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
+	metricTSMPDiscoKeyUpdateApplied.Add(1)
+}
// PeerRelays returns the current set of candidate peer relays.
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 7ae422906..59c8cfb9f 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
+ "tailscale.com/types/views"
"tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
@@ -4305,3 +4306,58 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
+
+// TestSendTSMPDiscoKeyRequest checks that lazyEndpoint.FromPeer, called with
+// the node key of a known endpoint, results in the Conn's
+// sendTSMPDiscoKeyRequest callback being invoked with a valid address.
+func TestSendTSMPDiscoKeyRequest(t *testing.T) {
+	conn := newConn(t.Logf)
+
+	ep := &endpoint{
+		nodeID:    1,
+		publicKey: key.NewNode().Public(),
+		nodeAddr:  netip.MustParseAddr("100.64.0.1"),
+	}
+	epDiscoKey := key.NewDisco().Public()
+	ep.disco.Store(&endpointDisco{
+		key:   epDiscoKey,
+		short: epDiscoKey.ShortString(),
+	})
+	ep.c = conn
+
+	// The callback hands the requested address to the test over a channel.
+	requested := make(chan netip.Addr, 1)
+	conn.sendTSMPDiscoKeyRequest = func(ip netip.Addr) error {
+		requested <- ip
+		return nil
+	}
+
+	peer := (&tailcfg.Node{
+		Key: ep.publicKey,
+		Addresses: []netip.Prefix{
+			netip.MustParsePrefix("100.64.0.1/32"),
+		},
+	}).View()
+	conn.mu.Lock()
+	conn.peers = views.SliceOf([]tailcfg.NodeView{peer})
+	conn.mu.Unlock()
+
+	var rawNodeKey [32]byte
+	copy(rawNodeKey[:], ep.publicKey.AppendTo(nil))
+	conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
+
+	le := &lazyEndpoint{
+		c:   conn,
+		src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
+	}
+
+	le.FromPeer(rawNodeKey)
+
+	select {
+	case gotIP := <-requested:
+		if !gotIP.IsValid() {
+			t.Error("TSMP request sent with invalid IP")
+		}
+		t.Logf("TSMP disco key request sent to %v", gotIP)
+	case <-time.After(time.Second):
+		t.Error("TSMP disco key request was not sent")
+	}
+}
diff --git a/wgengine/magicsock/tsmp_disco_test.go b/wgengine/magicsock/tsmp_disco_test.go
new file mode 100644
index 000000000..7a89e27ec
--- /dev/null
+++ b/wgengine/magicsock/tsmp_disco_test.go
@@ -0,0 +1,291 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun/tuntest"
+ "tailscale.com/net/netaddr"
+ "tailscale.com/net/packet"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tstest"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/key"
+ "tailscale.com/types/netmap"
+ "tailscale.com/util/set"
+ "tailscale.com/wgengine/wgcfg/nmcfg"
+)
+
+// TestTSMPDiscoKeyExchange brings up two magicsock stacks, wires the TSMP
+// disco key hooks between them, rotates m1's disco key, and then waits for m2
+// to report a current address for m1. Once it does, the test checks that the
+// TSMP request-sent, update-received and update-applied metrics all advanced.
+//
+// The wait is bounded: the test fails instead of hanging if the ping cannot
+// be queued or the connection never comes up.
+func TestTSMPDiscoKeyExchange(t *testing.T) {
+	tstest.ResourceCheck(t)
+
+	// Set up DERP and STUN servers
+	derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
+	defer cleanup()
+
+	// Create two magicsock peers
+	m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
+	defer m1.Close()
+	m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
+	defer m2.Close()
+
+	// Wire up TSMP hooks to enable disco key exchange
+	// This mimics what userspaceEngine does in wgengine/userspace.go
+
+	// Hook 0: GetDiscoPublicKey - allows TSMP replies to include current disco key
+	m1.tsTun.GetDiscoPublicKey = m1.conn.DiscoPublicKey
+	m2.tsTun.GetDiscoPublicKey = m2.conn.DiscoPublicKey
+
+	// Hook 1: OnTSMPDiscoKeyReceived - handle incoming TSMP disco key updates
+	m1.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
+		t.Logf("m1: received TSMP disco key update from %v", srcIP)
+		m1.conn.HandleDiscoKeyUpdate(srcIP, update)
+	}
+	m2.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
+		t.Logf("m2: received TSMP disco key update from %v", srcIP)
+		m2.conn.HandleDiscoKeyUpdate(srcIP, update)
+	}
+
+	// sendTSMPDiscoKeyRequest injects a request on whichever stack is NOT the
+	// destination, i.e. the request travels from the other peer to dstIP.
+	sendTSMPDiscoKeyRequest := func(dstIP netip.Addr) error {
+		var srcIP netip.Addr
+		var stack *magicStack
+
+		switch dstIP {
+		case m1.IP():
+			srcIP = m2.IP()
+			stack = m2
+			t.Logf("m2: sending disco key request to m1")
+		case m2.IP():
+			srcIP = m1.IP()
+			stack = m1
+			t.Logf("m1: sending disco key request to m2")
+		}
+
+		// equivalent to the implementation in userspace.Engine
+		iph := packet.IP4Header{
+			IPProto: ipproto.TSMP,
+			Src:     srcIP,
+			Dst:     dstIP,
+		}
+
+		var tsmpPayload [1]byte
+		tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
+
+		tsmpRequest := packet.Generate(iph, tsmpPayload[:])
+		return stack.tsTun.InjectOutbound(tsmpRequest)
+	}
+
+	// Hook 2: SetSendTSMPDiscoKeyRequest - send TSMP disco key requests
+	m1.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
+	m2.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
+
+	// Get initial disco keys
+	disco1Original := m1.conn.DiscoPublicKey()
+	disco2 := m2.conn.DiscoPublicKey()
+
+	t.Logf("m1: node=%v disco=%v", m1.Public().ShortString(), disco1Original.ShortString())
+	t.Logf("m2: node=%v disco=%v", m2.Public().ShortString(), disco2.ShortString())
+
+	// Wait for initial endpoints
+	var eps1, eps2 []tailcfg.Endpoint
+	select {
+	case eps1 = <-m1.epCh:
+	case <-time.After(2 * time.Second):
+		t.Fatal("timeout waiting for m1 endpoints")
+	}
+	select {
+	case eps2 = <-m2.epCh:
+	case <-time.After(2 * time.Second):
+		t.Fatal("timeout waiting for m2 endpoints")
+	}
+
+	// Build initial network maps and establish connection
+	nm1 := &netmap.NetworkMap{
+		NodeKey: m1.Public(),
+		SelfNode: (&tailcfg.Node{
+			Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
+		}).View(),
+		Peers: []tailcfg.NodeView{
+			(&tailcfg.Node{
+				ID:         2,
+				Key:        m2.Public(),
+				DiscoKey:   disco2,
+				Addresses:  []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
+				AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
+				Endpoints:  epFromTyped(eps2),
+				HomeDERP:   1,
+			}).View(),
+		},
+	}
+
+	nm2 := &netmap.NetworkMap{
+		NodeKey: m2.Public(),
+		SelfNode: (&tailcfg.Node{
+			Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
+		}).View(),
+		Peers: []tailcfg.NodeView{
+			(&tailcfg.Node{
+				ID:         1,
+				Key:        m1.Public(),
+				DiscoKey:   disco1Original,
+				Addresses:  []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
+				AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
+				Endpoints:  epFromTyped(eps1),
+				HomeDERP:   1,
+			}).View(),
+		},
+	}
+
+	cfg1, err := nmcfg.WGCfg(m1.privateKey, nm1, t.Logf, 0, "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	cfg2, err := nmcfg.WGCfg(m2.privateKey, nm2, t.Logf, 0, "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	nv1 := NodeViewsUpdate{
+		SelfNode: nm1.SelfNode,
+		Peers:    nm1.Peers,
+	}
+	m1.conn.onNodeViewsUpdate(nv1)
+
+	peerSet1 := set.Set[key.NodePublic]{}
+	peerSet1.Add(m2.Public())
+	m1.conn.UpdatePeers(peerSet1)
+
+	nv2 := NodeViewsUpdate{
+		SelfNode: nm2.SelfNode,
+		Peers:    nm2.Peers,
+	}
+	m2.conn.onNodeViewsUpdate(nv2)
+
+	peerSet2 := set.Set[key.NodePublic]{}
+	peerSet2.Add(m1.Public())
+	m2.conn.UpdatePeers(peerSet2)
+
+	if err := m1.Reconfig(cfg1); err != nil {
+		t.Fatal(err)
+	}
+	if err := m2.Reconfig(cfg2); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("=== INITIAL CONFIGURATION COMPLETE ===")
+
+	// Start goroutines to drain TUN inbound channels so TSMP packets can be received
+	drainTun := func(name string, stack *magicStack) {
+		go func() {
+			for {
+				select {
+				case <-t.Context().Done():
+					return
+				case pkt := <-stack.tun.Inbound:
+					var p packet.Parsed
+					p.Decode(pkt)
+					if p.IPProto == ipproto.TSMP {
+						t.Logf("%s: received TSMP packet on TUN inbound: %d bytes", name, len(pkt))
+					} else if p.IPProto == ipproto.ICMPv4 {
+						t.Logf("%s: received ICMPv4 packet on TUN inbound: %d bytes", name, len(pkt))
+					} else {
+						t.Logf("%s: received packet on TUN inbound: %d bytes, proto=%v", name, len(pkt), p.IPProto)
+					}
+				}
+			}
+		}()
+	}
+	drainTun("m1", m1)
+	drainTun("m2", m2)
+
+	initialRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
+	initialUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
+	initialUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
+
+	t.Logf("Initial metrics: requests_sent=%d updates_received=%d updates_applied=%d",
+		initialRequestsSent, initialUpdatesReceived, initialUpdatesApplied)
+
+	t.Logf("=== ROTATING m1's DISCO KEY ===")
+	m1.conn.RotateDiscoKey()
+	disco1New := m1.conn.DiscoPublicKey()
+
+	if disco1Original.Compare(disco1New) == 0 {
+		t.Fatal("disco key failed to rotate")
+	}
+	t.Logf("Rotated: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
+
+	t.Logf("=== SENDING PACKETS TO TRIGGER TSMP EXCHANGE ===")
+
+	ping1to2 := tuntest.Ping(netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.1"))
+
+	// Inject one ping into m1's TUN only - this will trigger m1's handshake
+	// initiation, and when m2 receives the encrypted packet it should trigger
+	// FromPeer -> TSMP. Fail if the ping can't be queued: silently dropping it
+	// would leave the wait below with nothing to wait for.
+	select {
+	case m1.tun.Outbound <- ping1to2:
+	case <-time.After(2 * time.Second):
+		t.Fatal("timeout queueing ping on m1's TUN")
+	}
+
+	// Poll m2's status until it reports a current address for m1, bounded so
+	// a broken exchange fails the test instead of hanging it.
+	deadline := time.Now().Add(10 * time.Second)
+	for {
+		if time.Now().After(deadline) {
+			t.Fatal("timeout waiting for m2 to establish a connection to m1 after disco key rotation")
+		}
+		time.Sleep(time.Millisecond)
+
+		st := m2.Status()
+		ps, ok := st.Peer[m1.Public()]
+		if !ok || ps.CurAddr == "" {
+			continue
+		}
+
+		t.Logf("Connection established after disco key rotation")
+		t.Logf("m2 -> m1 via %v", ps.CurAddr)
+		t.Logf("Disco key rotation: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())
+
+		// Verify TSMP metrics incremented
+		finalRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
+		finalUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
+		finalUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
+
+		t.Logf("Final metrics: requests_sent=%d updates_received=%d updates_applied=%d",
+			finalRequestsSent, finalUpdatesReceived, finalUpdatesApplied)
+
+		// Check that at least one TSMP request was sent
+		if finalRequestsSent <= initialRequestsSent {
+			t.Errorf("Expected TSMP disco key request to be sent, but metric did not increment: %d -> %d",
+				initialRequestsSent, finalRequestsSent)
+		} else {
+			t.Logf("✓ TSMP disco key request sent (metric: %d -> %d)",
+				initialRequestsSent, finalRequestsSent)
+		}
+
+		// Check that at least one TSMP update was received
+		if finalUpdatesReceived <= initialUpdatesReceived {
+			t.Errorf("Expected TSMP disco key update to be received, but metric did not increment: %d -> %d",
+				initialUpdatesReceived, finalUpdatesReceived)
+		} else {
+			t.Logf("✓ TSMP disco key update received (metric: %d -> %d)",
+				initialUpdatesReceived, finalUpdatesReceived)
+		}
+
+		// Check that at least one TSMP update was applied
+		if finalUpdatesApplied <= initialUpdatesApplied {
+			t.Errorf("Expected TSMP disco key update to be applied, but metric did not increment: %d -> %d",
+				initialUpdatesApplied, finalUpdatesApplied)
+		} else {
+			t.Logf("✓ TSMP disco key update applied (metric: %d -> %d)",
+				initialUpdatesApplied, finalUpdatesApplied)
+		}
+
+		// Surface (without failing on) any request errors recorded so far.
+		requestErrors := metricTSMPDiscoKeyRequestError.Value()
+		if requestErrors > 0 {
+			t.Logf("Warning: TSMP disco key request errors: %d", requestErrors)
+		}
+
+		// Likewise for updates that arrived from an address matching no peer.
+		unknownPeers := metricTSMPDiscoKeyUpdateUnknown.Value()
+		if unknownPeers > 0 {
+			t.Logf("Warning: TSMP disco key updates from unknown peers: %d", unknownPeers)
+		}
+
+		t.Logf("TSMP disco key exchange infrastructure is functional")
+		return
+	}
+}
diff --git a/wgengine/userspace.go b/wgengine/userspace.go
index 8ad771fc5..53d3bb0dc 100644
--- a/wgengine/userspace.go
+++ b/wgengine/userspace.go
@@ -466,6 +466,25 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return true
}
+ e.tundev.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
+ e.logf("wgengine: got TSMP disco key update from %v", srcIP)
+ if e.magicConn != nil {
+ e.magicConn.HandleDiscoKeyUpdate(srcIP, update)
+ }
+ }
+
+ e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
+ if e.magicConn == nil {
+ return key.DiscoPublic{}
+ }
+ return e.magicConn.DiscoPublicKey()
+ }
+
+ // Wire up TSMP disco key request sending to magicsock
+ if e.magicConn != nil {
+ e.magicConn.SetSendTSMPDiscoKeyRequest(e.sendTSMPDiscoKeyRequest)
+ }
+
// wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating WireGuard device...")
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
@@ -1563,6 +1582,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
}
}
+// sendTSMPDiscoKeyRequest injects a one-byte TSMP disco key request addressed
+// to ip into the tun device. The source address is the self IP returned by
+// mySelfIPMatchingFamily for ip; if that lookup fails, its error is returned
+// and nothing is sent. Otherwise the result of InjectOutbound is returned.
+func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
+	srcIP, err := e.mySelfIPMatchingFamily(ip)
+	if err != nil {
+		return err
+	}
+
+	// Pick the IP header type from the source address family.
+	var hdr packet.Header
+	switch {
+	case srcIP.Is4():
+		hdr = packet.IP4Header{
+			IPProto: ipproto.TSMP,
+			Src:     srcIP,
+			Dst:     ip,
+		}
+	default:
+		hdr = packet.IP6Header{
+			IPProto: ipproto.TSMP,
+			Src:     srcIP,
+			Dst:     ip,
+		}
+	}
+
+	payload := []byte{byte(packet.TSMPTypeDiscoKeyRequest)}
+	return e.tundev.InjectOutbound(packet.Generate(hdr, payload))
+}
+
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
e.mu.Lock()
defer e.mu.Unlock()