diff options
Diffstat (limited to 'wgengine/magicsock/magicsock.go')
| -rw-r--r-- | wgengine/magicsock/magicsock.go | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 78ffd0cd0..a868cdb75 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -163,10 +163,12 @@ type Conn struct { derpActiveFunc func() idleFunc func() time.Duration // nil means unknown testOnlyPacketListener nettype.PacketListener - noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity - netMon *netmon.Monitor // must be non-nil - health *health.Tracker // or nil - controlKnobs *controlknobs.Knobs // or nil + noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity + netMon *netmon.Monitor // must be non-nil + health *health.Tracker // or nil + controlKnobs *controlknobs.Knobs // or nil + discoMessageHook func(disco.Message, key.DiscoPublic, key.NodePublic) bool + acceptDiscoFromUnknownPeer func(key.DiscoPublic) bool // ================================================================ // No locking required to access these fields, either because @@ -495,6 +497,16 @@ type Options struct { // DisablePortMapper, if true, disables the portmapper. // This is primarily useful in tests. DisablePortMapper bool + + // DiscoMessageHook, if non-nil, is called when a disco message is + // received from a peer. If it returns true, the message is considered + // handled and no further processing occurs. + DiscoMessageHook func(dm disco.Message, sender key.DiscoPublic, derpNodeSrc key.NodePublic) (handled bool) + + // AcceptDiscoFromUnknownPeer, if non-nil, is called when a disco + // message arrives from an unknown peer. If it returns true, the + // message is accepted and a discoInfo is created for the sender. + AcceptDiscoFromUnknownPeer func(sender key.DiscoPublic) bool } func (o *Options) logf() logger.Logf { @@ -630,6 +642,8 @@ func NewConn(opts Options) (*Conn, error) { c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener c.noteRecvActivity = opts.NoteRecvActivity + c.discoMessageHook = opts.DiscoMessageHook + c.acceptDiscoFromUnknownPeer = opts.AcceptDiscoFromUnknownPeer // Set up publishers and subscribers. Subscribe calls must return before // NewConn otherwise published events can be missed. @@ -2151,6 +2165,8 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake } case c.peerMap.knownPeerDiscoKey(sender): di = c.discoInfoForKnownPeerLocked(sender) + case c.acceptDiscoFromUnknownPeer != nil && c.acceptDiscoFromUnknownPeer(sender): + di = c.discoInfoForKnownPeerLocked(sender) default: metricRecvDiscoBadPeer.Add(1) if debugDisco() { @@ -2233,6 +2249,10 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake return } + if c.discoMessageHook != nil && c.discoMessageHook(dm, sender, derpNodeSrc) { + return + } + switch dm := dm.(type) { case *disco.Ping: metricRecvDiscoPing.Add(1) @@ -2635,6 +2655,34 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { return di } +// SetDiscoKey sets the disco private key used by this Conn. +func (c *Conn) SetDiscoKey(priv key.DiscoPrivate) { + c.discoAtomic.Set(priv) +} + +// SendDiscoMessageOverDERP sends a disco message to a peer identified by +// its disco and node public keys via the specified DERP region. +func (c *Conn) SendDiscoMessageOverDERP(dstDisco key.DiscoPublic, dstNode key.NodePublic, derpRegion int, m disco.Message) (sent bool, err error) { + dstAddr := netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(derpRegion)) + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return false, errConnClosed + } + pkt := make([]byte, 0, 512) + pkt = append(pkt, disco.Magic...) + pkt = c.discoAtomic.Public().AppendTo(pkt) + di := c.discoInfoForKnownPeerLocked(dstDisco) + c.mu.Unlock() + + box := di.sharedKey.Seal(m.AppendMarshal(nil)) + pkt = append(pkt, box...) + const isDisco = true + const isGeneveEncap = false + return c.sendAddr(dstAddr, dstNode, pkt, isDisco, isGeneveEncap) +} + func (c *Conn) SetNetworkUp(up bool) { c.mu.Lock() defer c.mu.Unlock() |
