summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--disco/disco.go59
-rw-r--r--disco/disco_test.go15
-rw-r--r--types/key/node.go20
-rw-r--r--wgengine/magicsock/magicsock.go141
-rw-r--r--wgengine/magicsock/magicsock_test.go61
5 files changed, 296 insertions, 0 deletions
diff --git a/disco/disco.go b/disco/disco.go
index 0e7c3f7e5..7c92507d6 100644
--- a/disco/disco.go
+++ b/disco/disco.go
@@ -27,6 +27,7 @@ import (
"net/netip"
"go4.org/mem"
+ "golang.org/x/crypto/nacl/box"
"tailscale.com/types/key"
)
@@ -44,6 +45,8 @@ const (
TypePing = MessageType(0x01)
TypePong = MessageType(0x02)
TypeCallMeMaybe = MessageType(0x03)
+ TypeKnock = MessageType(0x04)
+ TypeKnockReply = MessageType(0x05)
)
const v0 = byte(0)
@@ -83,6 +86,10 @@ func Parse(p []byte) (Message, error) {
return parsePong(ver, p)
case TypeCallMeMaybe:
return parseCallMeMaybe(ver, p)
+ case TypeKnock:
+ return parseKnock(ver, p)
+ case TypeKnockReply:
+ return parseKnockReply(ver, p)
default:
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
}
@@ -240,6 +247,54 @@ func parsePong(ver uint8, p []byte) (m *Pong, err error) {
return m, nil
}
+type Knock struct {
+ // SealedNonce is the random client-generated per-knock nonce,
+ // which is NaCL-box sealed to the node key of the destination.
+ // The unencrypted nonce is 8 bytes.
+ SealedNonce [box.AnonymousOverhead + 8]byte
+}
+
+func (m *Knock) AppendMarshal(b []byte) []byte {
+ dataLen := box.AnonymousOverhead + 8
+ ret, d := appendMsgHeader(b, TypeKnock, v0, dataLen)
+ copy(d, m.SealedNonce[:])
+ return ret
+}
+
+func parseKnock(ver uint8, p []byte) (m *Knock, err error) {
+ if len(p) < (box.AnonymousOverhead + 8) {
+ return nil, errShort
+ }
+ m = new(Knock)
+ p = p[copy(m.SealedNonce[:], p):]
+ // Deliberately lax on longer-than-expected messages, for future
+ // compatibility.
+ return m, nil
+}
+
+type KnockReply struct {
+ // Nonce is the nonce value from the Knock request.
+ Nonce [8]byte
+}
+
+func (m *KnockReply) AppendMarshal(b []byte) []byte {
+ dataLen := 8
+ ret, d := appendMsgHeader(b, TypeKnockReply, v0, dataLen)
+ copy(d, m.Nonce[:])
+ return ret
+}
+
+func parseKnockReply(ver uint8, p []byte) (m *KnockReply, err error) {
+ if len(p) < 8 {
+ return nil, errShort
+ }
+ m = new(KnockReply)
+ p = p[copy(m.Nonce[:], p):]
+ // Deliberately lax on longer-than-expected messages, for future
+ // compatibility.
+ return m, nil
+}
+
// MessageSummary returns a short summary of m for logging purposes.
func MessageSummary(m Message) string {
switch m := m.(type) {
@@ -249,6 +304,10 @@ func MessageSummary(m Message) string {
return fmt.Sprintf("pong tx=%x", m.TxID[:6])
case *CallMeMaybe:
return "call-me-maybe"
+ case *Knock:
+ return fmt.Sprintf("knock")
+ case *KnockReply:
+ return fmt.Sprintf("knock reply nonce=%x", m.Nonce[:])
default:
return fmt.Sprintf("%#v", m)
}
diff --git a/disco/disco_test.go b/disco/disco_test.go
index 67bd1561a..6d8a2be13 100644
--- a/disco/disco_test.go
+++ b/disco/disco_test.go
@@ -4,6 +4,7 @@
package disco
import (
+ "bytes"
"fmt"
"net/netip"
"reflect"
@@ -66,6 +67,20 @@ func TestMarshalAndParse(t *testing.T) {
},
want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15",
},
+ {
+ name: "knock",
+ m: &Knock{
+ SealedNonce: [16 + 32 + 8]byte(bytes.Repeat([]byte{1, 2}, 28)),
+ },
+ want: "04 00 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02",
+ },
+ {
+ name: "knock_reply",
+ m: &KnockReply{
+ Nonce: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ want: "05 00 01 02 03 04 05 06 07 08",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff --git a/types/key/node.go b/types/key/node.go
index a84057231..8d7b99d59 100644
--- a/types/key/node.go
+++ b/types/key/node.go
@@ -142,6 +142,26 @@ func (k NodePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte
return box.Open(nil, ciphertext[len(nonce):], nonce, &p.k, &k.k)
}
+// SealAnonymous seals the cleartext to the node key k.
+func (k NodePublic) SealAnonymous(cleartext []byte) (ciphertext []byte, err error) {
+ if k.IsZero() {
+ panic("can't seal with zero keys")
+ }
+ return box.SealAnonymous(nil, cleartext, &k.k, nil)
+}
+
+// OpenAnonymous opens the anonymous NaCl box ciphertext, which must be a value
+// created by SealAnonymous, and returns the inner cleartext if ciphertext is
+// a valid box to k.
+func (k NodePrivate) OpenAnonymous(ciphertext []byte) (cleartext []byte, ok bool) {
+ if k.IsZero() {
+ panic("can't open with zero keys")
+ }
+
+ p := k.Public()
+ return box.OpenAnonymous(nil, ciphertext, &p.k, &k.k)
+}
+
func (k NodePrivate) UntypedHexString() string {
return hex.EncodeToString(k.k[:])
}
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index dea8b2d97..ee7f755a6 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -2306,6 +2306,25 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
ep.publicKey.ShortString(), derpStr(src.String()),
len(dm.MyNumber))
go ep.handleCallMeMaybe(dm)
+ case *disco.Knock:
+ metricRecvDiscoKnock.Add(1)
+ if isDERP {
+ metricRecvDiscoKnockBadDisco.Add(1)
+ c.logf("[unexpected] Knock packets should only come via LAN")
+ return
+ }
+ c.handleKnockLocked(dm, src, di)
+ case *disco.KnockReply:
+ metricRecvDiscoKnockReply.Add(1)
+ if isDERP {
+ metricRecvDiscoKnockReplyBadDisco.Add(1)
+ c.logf("[unexpected] Knock reply packets should only come via LAN")
+ return
+ }
+ c.logf("magicsock: disco: got knock reply %v from %v", dm, src)
+ c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) {
+ return !ep.handleKnockReplyLocked(dm, src, di)
+ })
}
return
}
@@ -2348,6 +2367,114 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic
return nk, false
}
+// handleKnockReplyLocked handles a DISCO Knock Reply message. If the nonce is
+// correct, the callback for the pending knock is invoked.
+//
+// True is returned if this endpoint handled the nonce.
+//
+// di is the discoInfo of the source of the knock packet.
+func (de *endpoint) handleKnockReplyLocked(dm *disco.KnockReply, src netip.AddrPort, di *discoInfo) bool {
+ de.mu.Lock()
+ defer de.mu.Unlock()
+
+ if de.pendingKnock == nil || !bytes.Equal(dm.Nonce[:], de.pendingKnock.nonce[:]) {
+ return false
+ }
+
+ // From this point on, nonce is correct
+ cb := de.pendingKnock.cb
+ de.pendingKnock = nil
+ go cb(nil)
+ return true
+}
+
+// handleKnockLocked handles a DISCO Knock message. If the recieved packet
+// is in order, a response is sent containing the unwrapped nonce.
+//
+// di is the discoInfo of the source of the knock packet.
+func (c *Conn) handleKnockLocked(dm *disco.Knock, src netip.AddrPort, di *discoInfo) {
+ // TODO(tom): Filter to LAN-only sources
+
+ nonceBytes, ok := c.privateKey.OpenAnonymous(dm.SealedNonce[:])
+ if !ok {
+ metricRecvDiscoKnockBadSeal.Add(1)
+ c.logf("magicsock: disco: dropping bad knock from %v", src)
+ return
+ }
+
+ var nonce [8]byte
+ copy(nonce[:], nonceBytes)
+
+ c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) {
+ go c.sendDiscoMessage(src, ep.publicKey, di.discoKey, &disco.KnockReply{
+ Nonce: nonce,
+ }, discoVerboseLog)
+ return true
+ })
+}
+
+// Knock handles a request to knock a specific peer.
+func (c *Conn) Knock(addr netip.AddrPort, peer *tailcfg.Node, cb func(error)) {
+ if runtime.GOOS == "js" {
+ cb(errors.New("no direct over tsconnect"))
+ return
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.privateKey.IsZero() {
+ cb(errNetworkDown)
+ return
+ }
+
+ ep, ok := c.peerMap.endpointForNodeKey(peer.Key)
+ if !ok {
+ cb(errors.New("unknown peer"))
+ return
+ }
+ ep.knock(addr, cb)
+}
+
+func (de *endpoint) knock(addr netip.AddrPort, cb func(error)) {
+ de.mu.Lock()
+ defer de.mu.Unlock()
+
+ if de.expired {
+ cb(errExpired)
+ return
+ }
+ epDisco := de.disco.Load()
+ if epDisco == nil {
+ cb(errors.New("no disco key"))
+ return
+ }
+
+ var nonce [8]byte
+ if _, err := crand.Read(nonce[:]); err != nil {
+ panic(err) // worth dying for
+ }
+ sealed, err := de.publicKey.SealAnonymous(nonce[:])
+ if err != nil {
+ cb(err)
+ return
+ }
+
+ if de.pendingKnock != nil {
+ de.pendingKnock.cb(errors.New("superceded"))
+ }
+ de.pendingKnock = &pendingKnock{addr, cb, nonce}
+
+ go func() {
+ knock := disco.Knock{}
+ copy(knock.SealedNonce[:], sealed)
+ sent, _ := de.c.sendDiscoMessage(addr, de.publicKey, epDisco.key, &knock, discoVerboseLog)
+ if !sent {
+ panic("not sent")
+ }
+ }()
+ de.noteActiveLocked()
+}
+
// di is the discoInfo of the source of the ping.
// derpNodeSrc is non-zero if the ping arrived via DERP.
func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInfo, derpNodeSrc key.NodePublic) {
@@ -4141,6 +4268,8 @@ type endpoint struct {
pendingCLIPings []pendingCLIPing // any outstanding "tailscale ping" commands running
+ pendingKnock *pendingKnock // any outstanding knock challenge, if any
+
// The following fields are related to the new "silent disco"
// implementation that's a WIP as of 2022-10-20.
// See #540 for background.
@@ -4156,6 +4285,12 @@ type pendingCLIPing struct {
cb func(*ipnstate.PingResult)
}
+type pendingKnock struct {
+ addr netip.AddrPort
+ cb func(error)
+ nonce [8]byte
+}
+
const (
// sessionActiveTimeout is how long since the last activity we
// try to keep an established endpoint peering alive.
@@ -5269,6 +5404,7 @@ func (de *endpoint) stopAndReset() {
de.heartBeatTimer = nil
}
de.pendingCLIPings = nil
+ de.pendingKnock = nil
}
// resetLocked clears all the endpoint's p2p state, reverting it to a
@@ -5468,6 +5604,11 @@ var (
metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe")
metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node")
metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco")
+ metricRecvDiscoKnock = clientmetric.NewCounter("magicsock_disco_recv_knock")
+ metricRecvDiscoKnockBadDisco = clientmetric.NewCounter("magicsock_disco_recv_knock_bad_disco")
+ metricRecvDiscoKnockBadSeal = clientmetric.NewCounter("magicsock_disco_recv_knock_bad_seal")
+ metricRecvDiscoKnockReply = clientmetric.NewCounter("magicsock_disco_recv_knock_reply")
+ metricRecvDiscoKnockReplyBadDisco = clientmetric.NewCounter("magicsock_disco_recv_knock_reply_bad_disco")
metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here")
metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown")
// metricDERPHomeChange is how many times our DERP home region DI has
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 78e5bb232..0423447c8 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -2908,3 +2908,64 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
}
}
}
+
+func TestDiscoKnock(t *testing.T) {
+ mstun := &natlab.Machine{Name: "stun"}
+ m1 := &natlab.Machine{Name: "m1"}
+ m2 := &natlab.Machine{Name: "m2"}
+ inet := natlab.NewInternet()
+ sif := mstun.Attach("eth0", inet)
+ m1if := m1.Attach("eth0", inet)
+ m2if := m2.Attach("eth0", inet)
+
+ d := &devices{
+ m1: m1,
+ m1IP: m1if.V4(),
+ m2: m2,
+ m2IP: m2if.V4(),
+ stun: mstun,
+ stunIP: sif.V4(),
+ }
+
+ logf, closeLogf := logger.LogfCloser(t.Logf)
+ defer closeLogf()
+
+ derpMap, cleanup := runDERPAndStun(t, logf, d.stun, d.stunIP)
+ defer cleanup()
+
+ ms1 := newMagicStack(t, logger.WithPrefix(logf, "conn1: "), d.m1, derpMap)
+ defer ms1.Close()
+ ms2 := newMagicStack(t, logger.WithPrefix(logf, "conn2: "), d.m2, derpMap)
+ defer ms2.Close()
+
+ cleanup = meshStacks(t.Logf, nil, ms1, ms2)
+ defer cleanup()
+
+ // Wait for both peers to know about each other.
+ for {
+ if s1 := ms1.Status(); len(s1.Peer) != 1 {
+ time.Sleep(10 * time.Millisecond)
+ continue
+ }
+ if s2 := ms2.Status(); len(s2.Peer) != 1 {
+ time.Sleep(10 * time.Millisecond)
+ continue
+ }
+ break
+ }
+
+ cbErr := make(chan error, 1)
+ ms1.conn.Knock(netip.AddrPortFrom(m2if.V4(), ms2.conn.pconn4.LocalAddr().AddrPort().Port()), &tailcfg.Node{Key: ms2.privateKey.Public()}, func(err error) {
+ cbErr <- err
+ })
+
+ select {
+ case err := <-cbErr:
+ if err != nil {
+ t.Errorf("Knock failed: %v", err)
+ }
+
+ case <-time.After(2 * time.Second):
+ t.Error("timeout waiting for knock callback")
+ }
+}