summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author James Tucker <james@tailscale.com> 2025-11-03 14:53:11 -0800
committer James Tucker <james@tailscale.com> 2025-11-26 14:37:17 -0800
commit 5bfa8e97f6b419e6a1d2923c47a8fab7258b04db (patch)
tree b2407cecbe5da32ee9aa2959fbb216c0bffc1a32
parent 9eff8a45034bc36a17004dce1fe6e7732af631a4 (diff)
download tailscale-raggi/disco-key-tsmp2.tar.xz
tailscale-raggi/disco-key-tsmp2.zip
net,wgengine: add support for disco key exchange via TSMP (raggi/disco-key-tsmp2)
Updates tailscale/corp#34037 Signed-off-by: James Tucker <james@tailscale.com>
-rw-r--r-- net/packet/tsmp.go 89
-rw-r--r-- net/packet/tsmp_test.go 211
-rw-r--r-- net/tstun/wrap.go 68
-rw-r--r-- net/tstun/wrap_test.go 55
-rw-r--r-- wgengine/magicsock/derp.go 10
-rw-r--r-- wgengine/magicsock/magicsock.go 106
-rw-r--r-- wgengine/magicsock/magicsock_test.go 64
-rw-r--r-- wgengine/userspace.go 94
-rw-r--r-- wgengine/userspace_test.go 13
9 files changed, 552 insertions, 158 deletions
diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go
index 8fad1d503..b9b805b1a 100644
--- a/net/packet/tsmp.go
+++ b/net/packet/tsmp.go
@@ -15,12 +15,10 @@ import (
"fmt"
"net/netip"
- "go4.org/mem"
"tailscale.com/types/ipproto"
- "tailscale.com/types/key"
)
-const minTSMPSize = 7 // the rejected body is 7 bytes
+const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
// TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a
@@ -75,8 +73,11 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
- // TSPMTypeDiscoAdvertisement is the type byte for sending disco keys
- TSMPTypeDiscoAdvertisement TSMPType = 'a'
+ // TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
+ TSMPTypeDiscoKeyRequest TSMPType = 'd'
+
+ // TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
+ TSMPTypeDiscoKeyUpdate TSMPType = 'D'
)
type TailscaleRejectReason byte
@@ -265,52 +266,62 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
return nil
}
-// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
+// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
//
-// On the wire, after the IP header, it's currently 33 bytes:
-// - 'a' (TSMPTypeDiscoAdvertisement)
-// - 32 disco key bytes
-type TSMPDiscoKeyAdvertisement struct {
- Src, Dst netip.Addr
- Key key.DiscoPublic
-}
+// On the wire, after the IP header, it's currently 1 byte:
+// - 'd' (TSMPTypeDiscoKeyRequest)
+type TSMPDiscoKeyRequest struct{}
-func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
- var iph Header
- if ka.Src.Is4() {
- iph = IP4Header{
- IPProto: ipproto.TSMP,
- Src: ka.Src,
- Dst: ka.Dst,
- }
- } else {
- iph = IP6Header{
- IPProto: ipproto.TSMP,
- Src: ka.Src,
- Dst: ka.Dst,
- }
+func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
+ if pp.IPProto != ipproto.TSMP {
+ return
}
- payload := make([]byte, 0, 33)
- payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
- payload = ka.Key.AppendTo(payload)
- if len(payload) != 33 {
- // Mostly to safeguard against ourselves changing this in the future.
- return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
+ p := pp.Payload()
+ if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
+ return
}
+ return h, true
+}
- return Generate(iph, payload), nil
+// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
+// It may be sent in response to a request, or unsolicited when a node
+// believes its peer may have stale disco key information.
+//
+// On the wire, after the IP header, it's currently 33 bytes:
+// - 'D' (TSMPTypeDiscoKeyUpdate)
+// - 32 bytes disco public key
+type TSMPDiscoKeyUpdate struct {
+ IPHeader Header
+ DiscoKey [32]byte // raw disco public key bytes
}
-func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
+// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
+// The update.IPHeader field is not populated.
+func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
- if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
+ if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
return
}
- tka.Src = pp.Src.Addr()
- tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
+ copy(update.DiscoKey[:], p[1:33])
+ return update, true
+}
- return tka, true
+func (h TSMPDiscoKeyUpdate) Len() int {
+ return h.IPHeader.Len() + 33
+}
+
+func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if err := h.IPHeader.Marshal(buf); err != nil {
+ return err
+ }
+ buf = buf[h.IPHeader.Len():]
+ buf[0] = byte(TSMPTypeDiscoKeyUpdate)
+ copy(buf[1:33], h.DiscoKey[:])
+ return nil
}
diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go
index d8f1d38d5..52e829d70 100644
--- a/net/packet/tsmp_test.go
+++ b/net/packet/tsmp_test.go
@@ -4,14 +4,8 @@
package packet
import (
- "bytes"
- "encoding/hex"
"net/netip"
- "slices"
"testing"
-
- "go4.org/mem"
- "tailscale.com/types/key"
)
func TestTailscaleRejectedHeader(t *testing.T) {
@@ -78,61 +72,168 @@ func TestTailscaleRejectedHeader(t *testing.T) {
}
}
-func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) {
- var (
- // IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum
- headerV4, _ = hex.DecodeString("45000035000000004063705d")
- // IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
- headerV6, _ = hex.DecodeString("6000000000216340")
+func TestTSMPDiscoKeyRequest(t *testing.T) {
+ t.Run("Manual", func(t *testing.T) {
+ var payload [1]byte
+ payload[0] = byte(TSMPTypeDiscoKeyRequest)
+
+ var p Parsed
+ p.IPProto = TSMP
+ p.dataofs = 40 // simulate after IP header
+ buf := make([]byte, 40+1)
+ copy(buf[40:], payload[:])
+ p.b = buf
+ p.length = len(buf)
+
+ _, ok := p.AsTSMPDiscoKeyRequest()
+ if !ok {
+ t.Fatal("failed to parse TSMP disco key request")
+ }
+ })
+
+ t.Run("RoundTripIPv4", func(t *testing.T) {
+ src := netip.MustParseAddr("100.64.0.1")
+ dst := netip.MustParseAddr("100.64.0.2")
+
+ iph := IP4Header{
+ IPProto: TSMP,
+ Src: src,
+ Dst: dst,
+ }
+
+ var payload [1]byte
+ payload[0] = byte(TSMPTypeDiscoKeyRequest)
+
+ pkt := Generate(iph, payload[:])
+ t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
+
+ // Manually check what decode4 would see
+ if len(pkt) >= 4 {
+ declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
+ t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
+ t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
+ }
+
+ var p Parsed
+ p.Decode(pkt)
+ t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
+ p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
+
+ if p.IPVersion != 4 {
+ t.Errorf("IPVersion = %d, want 4", p.IPVersion)
+ }
+ if p.IPProto != TSMP {
+ t.Errorf("IPProto = %v, want TSMP", p.IPProto)
+ }
+ if p.Src.Addr() != src {
+ t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
+ }
+ if p.Dst.Addr() != dst {
+ t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
+ }
- packetType = []byte{'a'}
- testKey = bytes.Repeat([]byte{'a'}, 32)
+ _, ok := p.AsTSMPDiscoKeyRequest()
+ if !ok {
+ t.Fatal("failed to parse TSMP disco key request from generated packet")
+ }
+ })
+
+ t.Run("RoundTripIPv6", func(t *testing.T) {
+ src := netip.MustParseAddr("2001:db8::1")
+ dst := netip.MustParseAddr("2001:db8::2")
+
+ iph := IP6Header{
+ IPProto: TSMP,
+ Src: src,
+ Dst: dst,
+ }
- // IPs
- srcV4 = netip.MustParseAddr("1.2.3.4")
- dstV4 = netip.MustParseAddr("4.3.2.1")
- srcV6 = netip.MustParseAddr("2001:db8::1")
- dstV6 = netip.MustParseAddr("2001:db8::2")
- )
+ var payload [1]byte
+ payload[0] = byte(TSMPTypeDiscoKeyRequest)
- join := func(parts ...[]byte) []byte {
- return bytes.Join(parts, nil)
+ pkt := Generate(iph, payload[:])
+ t.Logf("Generated packet: %d bytes", len(pkt))
+
+ var p Parsed
+ p.Decode(pkt)
+
+ if p.IPVersion != 6 {
+ t.Errorf("IPVersion = %d, want 6", p.IPVersion)
+ }
+ if p.IPProto != TSMP {
+ t.Errorf("IPProto = %v, want TSMP", p.IPProto)
+ }
+ if p.Src.Addr() != src {
+ t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
+ }
+ if p.Dst.Addr() != dst {
+ t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
+ }
+
+ _, ok := p.AsTSMPDiscoKeyRequest()
+ if !ok {
+ t.Fatal("failed to parse TSMP disco key request from generated packet")
+ }
+ })
+}
+
+func TestTSMPDiscoKeyUpdate(t *testing.T) {
+ var discoKey [32]byte
+ for i := range discoKey {
+ discoKey[i] = byte(i + 10)
}
- tests := []struct {
- name string
- tka TSMPDiscoKeyAdvertisement
- want []byte
- }{
- {
- name: "v4Header",
- tka: TSMPDiscoKeyAdvertisement{
- Src: srcV4,
- Dst: dstV4,
- Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
+ t.Run("IPv4", func(t *testing.T) {
+ update := TSMPDiscoKeyUpdate{
+ IPHeader: IP4Header{
+ IPProto: TSMP,
+ Src: netip.MustParseAddr("1.2.3.4"),
+ Dst: netip.MustParseAddr("5.6.7.8"),
},
- want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey),
- },
- {
- name: "v6Header",
- tka: TSMPDiscoKeyAdvertisement{
- Src: srcV6,
- Dst: dstV6,
- Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
+ DiscoKey: discoKey,
+ }
+
+ pkt := make([]byte, update.Len())
+ if err := update.Marshal(pkt); err != nil {
+ t.Fatal(err)
+ }
+
+ var p Parsed
+ p.Decode(pkt)
+
+ parsed, ok := p.AsTSMPDiscoKeyUpdate()
+ if !ok {
+ t.Fatal("failed to parse TSMP disco key update")
+ }
+ if parsed.DiscoKey != discoKey {
+ t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ update := TSMPDiscoKeyUpdate{
+ IPHeader: IP6Header{
+ IPProto: TSMP,
+ Src: netip.MustParseAddr("2001:db8::1"),
+ Dst: netip.MustParseAddr("2001:db8::2"),
},
- want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey),
- },
- }
+ DiscoKey: discoKey,
+ }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := tt.tka.Marshal()
- if err != nil {
- t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
- }
- if !slices.Equal(got, tt.want) {
- t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
- }
- })
- }
+ pkt := make([]byte, update.Len())
+ if err := update.Marshal(pkt); err != nil {
+ t.Fatal(err)
+ }
+
+ var p Parsed
+ p.Decode(pkt)
+
+ parsed, ok := p.AsTSMPDiscoKeyUpdate()
+ if !ok {
+ t.Fatal("failed to parse TSMP disco key update")
+ }
+ if parsed.DiscoKey != discoKey {
+ t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
+ }
+ })
}
diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go
index 6e07c7a3d..0619276ae 100644
--- a/net/tstun/wrap.go
+++ b/net/tstun/wrap.go
@@ -194,6 +194,10 @@ type Wrapper struct {
// false otherwise.
OnICMPEchoResponseReceived func(*packet.Parsed) bool
+ // GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
+ // This is called when responding to TSMP disco key requests.
+ GetDiscoPublicKey func() key.DiscoPublic
+
// PeerAPIPort, if non-nil, returns the peerapi port that's
// running for the given IP address.
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
@@ -211,8 +215,8 @@ type Wrapper struct {
metrics *metrics
- eventClient *eventbus.Client
- discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement]
+ eventClient *eventbus.Client
+ discoKeyUpdatePub *eventbus.Publisher[DiscoKeyUpdate]
}
type metrics struct {
@@ -227,6 +231,12 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
}
}
+// DiscoKeyUpdate is published on the event bus when a TSMP disco key update is received.
+type DiscoKeyUpdate struct {
+ SrcIP netip.Addr
+ Key [32]byte
+}
+
// tunInjectedRead is an injected packet pretending to be a tun.Read().
type tunInjectedRead struct {
// Only one of packet or data should be set, and are read in that order of
@@ -288,7 +298,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry,
}
w.eventClient = bus.Client("net.tstun")
- w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient)
+ w.discoKeyUpdatePub = eventbus.Publish[DiscoKeyUpdate](w.eventClient)
w.vectorBuffer = make([][]byte, tdev.BatchSize())
for i := range w.vectorBuffer {
@@ -1126,11 +1136,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
return n, err
}
-type DiscoKeyAdvertisement struct {
- Src netip.Addr
- Key key.DiscoPublic
-}
-
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
if captHook != nil {
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
@@ -1141,16 +1146,21 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
t.noteActivity()
t.injectOutboundPong(p, pingReq)
return filter.DropSilently, gro
- } else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
- t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
- Src: discoKeyAdvert.Src,
- Key: discoKeyAdvert.Key,
- })
- return filter.DropSilently, gro
} else if data, ok := p.AsTSMPPong(); ok {
if f := t.OnTSMPPongReceived; f != nil {
f(data)
}
+ } else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
+ t.noteActivity()
+ t.injectOutboundDiscoKeyUpdate(p)
+ return filter.DropSilently, gro
+ } else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
+ // Publish to eventbus for subscribers
+ t.discoKeyUpdatePub.Publish(DiscoKeyUpdate{
+ SrcIP: p.Src.Addr(),
+ Key: discoKeyUpdate.DiscoKey,
+ })
+ return filter.DropSilently, gro
}
}
@@ -1459,6 +1469,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
t.InjectOutbound(packet.Generate(pong, nil))
}
+func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
+ if t.GetDiscoPublicKey == nil {
+ return
+ }
+
+ discoKey := t.GetDiscoPublicKey()
+ if discoKey.IsZero() {
+ return
+ }
+
+ update := packet.TSMPDiscoKeyUpdate{
+ DiscoKey: discoKey.Raw32(),
+ }
+
+ switch pp.IPVersion {
+ case 4:
+ h4 := pp.IP4Header()
+ h4.ToResponse()
+ update.IPHeader = h4
+ case 6:
+ h6 := pp.IP6Header()
+ h6.ToResponse()
+ update.IPHeader = h6
+ default:
+ return
+ }
+
+ t.InjectOutbound(packet.Generate(update, nil))
+}
+
// InjectOutbound makes the Wrapper device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.
diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go
index c7d0708df..2d33228b8 100644
--- a/net/tstun/wrap_test.go
+++ b/net/tstun/wrap_test.go
@@ -966,28 +966,57 @@ func TestCaptureHook(t *testing.T) {
}
func TestTSMPDisco(t *testing.T) {
- t.Run("IPv6DiscoAdvert", func(t *testing.T) {
+ t.Run("DiscoKeyRequest", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
- discoKey := key.NewDisco()
- buf, _ := (&packet.TSMPDiscoKeyAdvertisement{
- Src: src,
- Dst: dst,
- Key: discoKey.Public(),
- }).Marshal()
+
+ iph := packet.IP6Header{
+ IPProto: ipproto.TSMP,
+ Src: src,
+ Dst: dst,
+ }
+
+ var payload [1]byte
+ payload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
+ buf := packet.Generate(iph, payload[:])
var p packet.Parsed
p.Decode(buf)
- tda, ok := p.AsTSMPDiscoAdvertisement()
+ _, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
- t.Error("Unable to parse message as TSMPDiscoAdversitement")
+ t.Error("Unable to parse message as TSMPDiscoKeyRequest")
}
- if tda.Src != src {
- t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
+ })
+
+ t.Run("DiscoKeyUpdate", func(t *testing.T) {
+ src := netip.MustParseAddr("2001:db8::1")
+ dst := netip.MustParseAddr("2001:db8::2")
+ discoKey := key.NewDisco()
+
+ update := packet.TSMPDiscoKeyUpdate{
+ IPHeader: packet.IP6Header{
+ IPProto: ipproto.TSMP,
+ Src: src,
+ Dst: dst,
+ },
+ DiscoKey: discoKey.Public().Raw32(),
+ }
+
+ buf := make([]byte, update.Len())
+ if err := update.Marshal(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ var p packet.Parsed
+ p.Decode(buf)
+
+ parsed, ok := p.AsTSMPDiscoKeyUpdate()
+ if !ok {
+ t.Error("Unable to parse message as TSMPDiscoKeyUpdate")
}
- if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
- t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
+ if parsed.DiscoKey != discoKey.Public().Raw32() {
+ t.Errorf("Key did not match, expected %v, got %v", discoKey.Public().Raw32(), parsed.DiscoKey)
}
})
}
diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go
index 37a4f1a64..e20acadc4 100644
--- a/wgengine/magicsock/derp.go
+++ b/wgengine/magicsock/derp.go
@@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
}
+ // Request disco key from peer via TSMP if we receive a WireGuard handshake
+ // over DERP without recent disco success. This handles the "WireGuard-first"
+ // case where WireGuard establishes a tunnel via DERP before disco succeeds
+ // (e.g., control plane unreachable or stale disco keys).
+ if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
+ c.mu.Lock()
+ c.requestDiscoKeyViaTSMPLocked(dm.src, ep)
+ c.mu.Unlock()
+ }
+
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 064838a2d..da4f5c47d 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -178,9 +178,10 @@ type Conn struct {
// A publisher for synchronization points to ensure correct ordering of
// config changes between magicsock and wireguard.
- syncPub *eventbus.Publisher[syncPoint]
- allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
- portUpdatePub *eventbus.Publisher[router.PortUpdate]
+ syncPub *eventbus.Publisher[syncPoint]
+ allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
+ portUpdatePub *eventbus.Publisher[router.PortUpdate]
+ tsmpDiscoKeyRequestPub *eventbus.Publisher[TSMPDiscoKeyRequest]
// pconn4 and pconn6 are the underlying UDP sockets used to
// send/receive packets for wireguard and other magicsock
@@ -572,6 +573,14 @@ type UDPRelayAllocReq struct {
Message *disco.AllocateUDPRelayEndpointRequest
}
+// TSMPDiscoKeyRequest is published on the event bus when magicsock needs to
+// send a TSMP disco key request to a peer. Subscribers should inject the
+// TSMP packet into the tunnel device.
+type TSMPDiscoKeyRequest struct {
+ DstIP netip.Addr
+ MetricSent *clientmetric.Metric
+}
+
// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse]
// that is yet to be transmitted over DERP (or delivered locally if
// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from
@@ -691,6 +700,7 @@ func NewConn(opts Options) (*Conn, error) {
c.syncPub = eventbus.Publish[syncPoint](ec)
c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec)
c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec)
+ c.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](ec)
eventbus.SubscribeFunc(ec, c.onPortMapChanged)
eventbus.SubscribeFunc(ec, c.onFilterUpdate)
eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate)
@@ -1800,6 +1810,15 @@ func looksLikeInitiationMsg(b []byte) bool {
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
}
+func looksLikeWireGuardHandshake(b []byte) bool {
+ if len(b) < 4 {
+ return false
+ }
+ msgType := binary.LittleEndian.Uint32(b)
+ return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) ||
+ (len(b) == device.MessageResponseSize && msgType == device.MessageResponseType)
+}
+
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
//
// size is the length of 'b' to report up to wireguard-go (only relevant if
@@ -4104,6 +4123,12 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
+
+ // TSMP disco key exchange
+ metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
+ metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
+ metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
+ metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
)
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@@ -4242,6 +4267,81 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
// See http://go/corp/29422 & http://go/corp/30042
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
+
+ le.c.requestDiscoKeyViaTSMPLocked(pubKey, ep)
+}
+
+// requestDiscoKeyViaTSMPLocked sends a TSMP disco key request to a peer if there
+// hasn't been a recent disco ping.
+// c.mu must be held.
+func (c *Conn) requestDiscoKeyViaTSMPLocked(nodeKey key.NodePublic, ep *endpoint) {
+ if !ep.nodeAddr.IsValid() {
+ return
+ }
+
+ epDisco := ep.disco.Load()
+ if epDisco != nil {
+ di := c.discoInfo[epDisco.key]
+ recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval
+
+ if recentDiscoPing {
+ return
+ }
+ }
+
+ go c.tsmpDiscoKeyRequestPub.Publish(TSMPDiscoKeyRequest{DstIP: ep.nodeAddr, MetricSent: metricTSMPDiscoKeyRequestSent})
+}
+
+// HandleDiscoKeyUpdate processes a TSMP disco key update.
+// The update may be solicited (in response to a request) or unsolicited.
+// srcIP is the Tailscale IP of the peer that sent the update.
+func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
+ discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
+ c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
+ metricTSMPDiscoKeyUpdateReceived.Add(1)
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ var nodeKey key.NodePublic
+ var found bool
+ for _, peer := range c.peers.All() {
+ for _, addr := range peer.Addresses().All() {
+ if addr.Addr() == srcIP {
+ nodeKey = peer.Key()
+ found = true
+ break
+ }
+ }
+ if found {
+ break
+ }
+ }
+
+ if !found {
+ c.logf("magicsock: disco key update from unknown peer %v", srcIP)
+ metricTSMPDiscoKeyUpdateUnknown.Add(1)
+ return
+ }
+
+ ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
+ if !ok {
+ c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
+ return
+ }
+
+ oldDiscoKey := key.DiscoPublic{}
+ if epDisco := ep.disco.Load(); epDisco != nil {
+ oldDiscoKey = epDisco.key
+ }
+ c.discoInfoForKnownPeerLocked(discoKey)
+ ep.disco.Store(&endpointDisco{
+ key: discoKey,
+ short: discoKey.ShortString(),
+ })
+ c.peerMap.upsertEndpoint(ep, oldDiscoKey)
+ c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
+ metricTSMPDiscoKeyUpdateApplied.Add(1)
}
// PeerRelays returns the current set of candidate peer relays.
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 4e1024886..b7694ba56 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
+ "tailscale.com/types/views"
"tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
@@ -4302,3 +4303,66 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
+
+func TestSendTSMPDiscoKeyRequest(t *testing.T) {
+ ep := &endpoint{
+ nodeID: 1,
+ publicKey: key.NewNode().Public(),
+ nodeAddr: netip.MustParseAddr("100.64.0.1"),
+ }
+ discoKey := key.NewDisco().Public()
+ ep.disco.Store(&endpointDisco{
+ key: discoKey,
+ short: discoKey.ShortString(),
+ })
+ bus := eventbustest.NewBus(t)
+ conn := newConn(t.Logf)
+ conn.eventBus = bus
+ conn.eventClient = bus.Client("magicsock.Conn.test")
+ conn.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](conn.eventClient)
+ ep.c = conn
+
+ tsmpRequestCalled := make(chan struct{}, 1)
+ var capturedIP netip.Addr
+ ec := bus.Client("test")
+ defer ec.Close()
+ eventbus.SubscribeFunc(ec, func(req TSMPDiscoKeyRequest) {
+ capturedIP = req.DstIP
+ if req.MetricSent != nil {
+ req.MetricSent.Add(1)
+ }
+ tsmpRequestCalled <- struct{}{}
+ })
+
+ conn.mu.Lock()
+ conn.peers = views.SliceOf([]tailcfg.NodeView{
+ (&tailcfg.Node{
+ Key: ep.publicKey,
+ Addresses: []netip.Prefix{
+ netip.MustParsePrefix("100.64.0.1/32"),
+ },
+ }).View(),
+ })
+ conn.mu.Unlock()
+
+ var pubKey [32]byte
+ copy(pubKey[:], ep.publicKey.AppendTo(nil))
+ conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
+
+ le := &lazyEndpoint{
+ c: conn,
+ src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
+ }
+
+ le.FromPeer(pubKey)
+
+ select {
+ case <-tsmpRequestCalled:
+ if !capturedIP.IsValid() {
+ t.Error("TSMP request sent with invalid IP")
+ }
+ t.Logf("TSMP disco key request sent to %v", capturedIP)
+ case <-time.After(time.Second):
+ t.Error("TSMP disco key request was not sent")
+ }
+}
diff --git a/wgengine/userspace.go b/wgengine/userspace.go
index a369fa343..c0e79633a 100644
--- a/wgengine/userspace.go
+++ b/wgengine/userspace.go
@@ -54,6 +54,7 @@ import (
"tailscale.com/util/execqueue"
"tailscale.com/util/mak"
"tailscale.com/util/set"
+ "tailscale.com/util/singleflight"
"tailscale.com/util/testenv"
"tailscale.com/util/usermetric"
"tailscale.com/version"
@@ -469,6 +470,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return true
}
+ e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
+ if e.magicConn == nil {
+ return key.DiscoPublic{}
+ }
+ return e.magicConn.DiscoPublicKey()
+ }
+
// wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating WireGuard device...")
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
@@ -549,6 +557,36 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
}
e.linkChangeQueue.Add(func() { e.linkChange(&cd) })
})
+ eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyUpdate) {
+ e.logf("wgengine: got TSMP disco key update from %v via eventbus", update.SrcIP)
+ if e.magicConn != nil {
+ pkt := packet.TSMPDiscoKeyUpdate{
+ DiscoKey: update.Key,
+ }
+ e.magicConn.HandleDiscoKeyUpdate(update.SrcIP, pkt)
+ }
+ })
+ var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}]
+ eventbus.SubscribeFunc(ec, func(req magicsock.TSMPDiscoKeyRequest) {
+ go tsmpRequestGroup.Do(req.DstIP, func() (struct{}, error) {
+ // DiscoKeyRequests are triggered by an incoming WireGuard handshake
+ // initiation arriving before a disco ping, which is a likely
+ // indicator that disco pings failed due to a lack of key
+ // synchronization. If the requests are sent immediately, before the
+ // handshake state is accepted in the WireGuard client state
+ // machine, this starts a new session, and the two peer state
+ // machines conflict, causing loss and additional delays. Delaying
+ // the send avoids this, so coalesce duplicate sends, and delay them
+ // by a short time to avoid the state machine conflict.
+ time.Sleep(time.Millisecond)
+ if err := e.sendTSMPDiscoKeyRequest(req.DstIP); err != nil {
+ e.logf("wgengine: failed to send TSMP disco key request: %v", err)
+ }
+ e.logf("wgengine: sending TSMP disco key request to %v", req.DstIP)
+ req.MetricSent.Add(1)
+ return struct{}{}, nil
+ })
+ })
e.eventClient = ec
e.logf("Engine created.")
return e, nil
@@ -1436,7 +1474,6 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in
e.magicConn.Ping(peer, res, size, cb)
case "TSMP":
e.sendTSMPPing(ip, peer, res, cb)
- e.sendTSMPDiscoAdvertisement(ip)
case "ICMP":
e.sendICMPEchoRequest(ip, peer, res, cb)
}
@@ -1557,29 +1594,6 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res
e.tundev.InjectOutbound(tsmpPing)
}
-func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) {
- srcIP, err := e.mySelfIPMatchingFamily(ip)
- if err != nil {
- e.logf("getting matching node: %s", err)
- return
- }
- tdka := packet.TSMPDiscoKeyAdvertisement{
- Src: srcIP,
- Dst: ip,
- Key: e.magicConn.DiscoPublicKey(),
- }
- payload, err := tdka.Marshal()
- if err != nil {
- e.logf("error generating TSMP Advertisement: %s", err)
- metricTSMPDiscoKeyAdvertisementError.Add(1)
- } else if err := e.tundev.InjectOutbound(payload); err != nil {
- e.logf("error sending TSMP Advertisement: %s", err)
- metricTSMPDiscoKeyAdvertisementError.Add(1)
- } else {
- metricTSMPDiscoKeyAdvertisementSent.Add(1)
- }
-}
-
func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1593,6 +1607,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
}
}
+// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP.
+func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
+ srcIP, err := e.mySelfIPMatchingFamily(ip)
+ if err != nil {
+ return err
+ }
+
+ var iph packet.Header
+ if srcIP.Is4() {
+ iph = packet.IP4Header{
+ IPProto: ipproto.TSMP,
+ Src: srcIP,
+ Dst: ip,
+ }
+ } else {
+ iph = packet.IP6Header{
+ IPProto: ipproto.TSMP,
+ Src: srcIP,
+ Dst: ip,
+ }
+ }
+
+ var tsmpPayload [1]byte
+ tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
+
+ tsmpRequest := packet.Generate(iph, tsmpPayload[:])
+ return e.tundev.InjectOutbound(tsmpRequest)
+}
+
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1746,9 +1789,6 @@ var (
metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes")
metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes")
-
- metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent")
- metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error")
)
func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) {
diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go
index 0a1d2924d..abcf2f64f 100644
--- a/wgengine/userspace_test.go
+++ b/wgengine/userspace_test.go
@@ -325,7 +325,7 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
}
}
-func TestTSMPKeyAdvertisement(t *testing.T) {
+func TestTSMPDiscoKeyRequest(t *testing.T) {
var knobs controlknobs.Knobs
bus := eventbustest.NewBus(t)
@@ -369,13 +369,12 @@ func TestTSMPKeyAdvertisement(t *testing.T) {
t.Fatal(err)
}
- addr := netip.MustParseAddr("100.100.99.1")
- previousValue := metricTSMPDiscoKeyAdvertisementSent.Value()
- ue.sendTSMPDiscoAdvertisement(addr)
- if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue {
- errs := metricTSMPDiscoKeyAdvertisementError.Value()
- t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs)
+ peerAddr := netip.MustParseAddr("100.100.99.1")
+ err = ue.sendTSMPDiscoKeyRequest(peerAddr)
+ if err != nil {
+ t.Fatalf("sendTSMPDiscoKeyRequest failed: %v", err)
}
+
// Remove config to have the engine shut down more consistently
err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{})
if err != nil {