summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Andrew Dunham <andrew@du.nham.ca> 2024-08-15 13:16:40 -0400
committer: Andrew Dunham <andrew@du.nham.ca> 2024-08-15 16:37:23 -0400
commit: 7dde340194d36f506a734774a417ec3bb4f74060 (patch)
tree: 1ad26a1a07199dd0329377779d9499063084619e
parent: 35b91cb66a41e36fc8c32eaf52da5aa143bf14b7 (diff)
downloadtailscale-andrew/disco-af-packet-refactor.tar.xz
tailscale-andrew/disco-af-packet-refactor.zip
fixup! wgengine/magicsock: actually use AF_PACKET socket for raw disco (branch: andrew/disco-af-packet-refactor)
Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I2c71d7598b9e30df717329db7dc17cb4ad3f05f6
-rw-r--r-- wgengine/magicsock/magicsock_linux.go 112
-rw-r--r-- wgengine/magicsock/magicsock_linux_test.go 182
2 files changed, 154 insertions, 140 deletions
diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go
index ebb0988d1..d8a8ad74f 100644
--- a/wgengine/magicsock/magicsock_linux.go
+++ b/wgengine/magicsock/magicsock_linux.go
@@ -6,6 +6,7 @@ package magicsock
import (
"bytes"
"context"
+ "encoding/binary"
"errors"
"fmt"
"io"
@@ -24,7 +25,6 @@ import (
"tailscale.com/disco"
"tailscale.com/envknob"
"tailscale.com/net/netns"
- "tailscale.com/net/packet"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
@@ -51,6 +51,14 @@ var (
// receives the entire IPv4 packet, but not the Ethernet
// header.
+ // Double-check that this is a UDP packet; we shouldn't be
+ // seeing anything else given how we create our AF_PACKET
+ // socket, but an extra check here is cheap, and matches the
+ // check that we do in the IPv6 path.
+ bpf.LoadAbsolute{Off: 9, Size: 1},
+ bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
+ bpf.RetConstant{Val: 0x0},
+
// Disco packets are so small they should never get
// fragmented, and we don't want to handle reassembly.
bpf.LoadAbsolute{Off: 6, Size: 2},
@@ -235,7 +243,6 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
var (
ctx = context.Background()
buf [1500]byte
- pkt packet.Parsed
)
for {
n, _, err := sock.Recvfrom(ctx, buf[:], 0)
@@ -244,10 +251,11 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
}
- if !decodeDiscoPacket(&pkt, c.discoLogf, buf[:n], family == "ip6") {
+ _ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
+ if payload == nil {
continue
}
- if payload := pkt.Payload(); !bytes.Equal(payload, testDiscoPacket) {
+ if !bytes.Equal(payload, testDiscoPacket) {
c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
continue
}
@@ -260,50 +268,60 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return sock, nil
}
-// decodeDiscoPacket decodes a disco packet from buf, using pkt as storage for
-// the parsed packet. It returns true if the packet is a valid disco packet,
-// and false otherwise.
+// parseUDPPacket is a basic parser for UDP packets that returns the source and
+// destination addresses, and the payload. The returned payload is a sub-slice
+// of the input buffer.
+//
+// It expects to be called with a buffer that contains the entire UDP packet,
+// including the IP header, and one that has been filtered with the BPF
+// programs above.
//
-// It will log the reason for the packet being invalid to logf; it is the
-// caller's responsibility to control log verbosity.
-func decodeDiscoPacket(pkt *packet.Parsed, logf logger.Logf, buf []byte, isIPv6 bool) bool {
- // Do a quick length check before we parse the packet, so we can drop
- // things that we know are too small.
- var minSize int
+// If an error occurs, it will return the zero values for all return values.
+func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
+ // First, parse the IPv4 or IPv6 header to get to the UDP header. Since
+ // we assume this was filtered with BPF, we know that there will be no
+ // IPv6 extension headers.
+ var (
+ srcIP, dstIP netip.Addr
+ udp []byte
+ )
if isIPv6 {
- minSize = ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize
- } else {
- minSize = ipv4.HeaderLen + udpHeaderSize + discoMinHeaderSize
- }
- if len(buf) < minSize {
- logf("decodeDiscoPacket: received packet too small to be a disco packet: %d bytes < %d", len(buf), minSize)
- return false
- }
-
- // Parse the packet.
- pkt.Decode(buf)
+ // Basic length check to ensure that we don't panic
+ if len(buf) < ipv6.HeaderLen+udpHeaderSize {
+ return
+ }
- // Verify that this is a UDP packet.
- if pkt.IPProto != ipproto.UDP {
- logf("decodeDiscoPacket: received non-UDP packet: %d", pkt.IPProto)
- return false
- }
+ // Extract the source and destination addresses from the IPv6
+ // header.
+ srcIP, _ = netip.AddrFromSlice(buf[8:24])
+ dstIP, _ = netip.AddrFromSlice(buf[24:40])
- // Ensure that it's the right version of IP; given how we configure our
- // listening sockets, we shouldn't ever get the wrong one, but it's
- // best to confirm.
- var wantVersion uint8
- if isIPv6 {
- wantVersion = 6
+ // We know that the UDP header starts immediately after the
+ // fixed-size IPv6 header.
+ udp = buf[ipv6.HeaderLen:]
} else {
- wantVersion = 4
- }
- if pkt.IPVersion != wantVersion {
- logf("decodeDiscoPacket: received mismatched IP version %d (want %d)", pkt.IPVersion, wantVersion)
- return false
+ // This is an IPv4 packet; read the length field from the header.
+ if len(buf) < ipv4.HeaderLen {
+ return
+ }
+ udpOffset := int((buf[0] & 0x0F) << 2)
+ if udpOffset+udpHeaderSize > len(buf) {
+ return
+ }
+
+ // Parse the source and destination IPs.
+ srcIP, _ = netip.AddrFromSlice(buf[12:16])
+ dstIP, _ = netip.AddrFromSlice(buf[16:20])
+ udp = buf[udpOffset:]
}
- return true
+ // Parse the ports
+ srcPort := binary.BigEndian.Uint16(udp[0:2])
+ dstPort := binary.BigEndian.Uint16(udp[2:4])
+
+ // The payload starts after the UDP header.
+ payload = udp[8:]
+ return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
}
// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
@@ -358,10 +376,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
)
- var (
- buf [1500]byte
- pkt packet.Parsed
- )
+ var buf [1500]byte
for {
n, src, err := pc.Recvfrom(ctx, buf[:], 0)
if debugRawDiscoReads() {
@@ -375,12 +390,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
return
}
- if !decodeDiscoPacket(&pkt, dlogf, buf[:n], isIPV6) {
+ srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
+ if payload == nil {
// not a parseable UDP packet; parseUDPPacket does not log
continue
}
- dstPort := pkt.Dst.Port()
+ dstPort := dstAddr.Port()
if dstPort == 0 {
logf("[unexpected] received packet for port 0")
}
@@ -417,7 +433,7 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
metricRecvDiscoPacketIPv4.Add(1)
}
- c.handleDiscoMessage(pkt.Payload(), pkt.Src, key.NodePublic{}, discoRXPathRawSocket)
+ c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
}
}
diff --git a/wgengine/magicsock/magicsock_linux_test.go b/wgengine/magicsock/magicsock_linux_test.go
index e9e7d73d8..6b86b04f2 100644
--- a/wgengine/magicsock/magicsock_linux_test.go
+++ b/wgengine/magicsock/magicsock_linux_test.go
@@ -4,114 +4,112 @@
package magicsock
import (
+ "bytes"
"encoding/binary"
- "net"
"net/netip"
"testing"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/cpu"
"golang.org/x/sys/unix"
"tailscale.com/disco"
- "tailscale.com/net/packet"
- "tailscale.com/types/ipproto"
)
-func TestDecodeDiscoPacket(t *testing.T) {
- mk4 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
- if !src.Is4() || !dst.Is4() {
- panic("not an IPv4 address")
- }
- iph := &ipv4.Header{
- Version: ipv4.Version,
- Len: ipv4.HeaderLen,
- TotalLen: ipv4.HeaderLen + len(data),
- TTL: 64,
- Protocol: int(proto),
- Src: net.IP(src.AsSlice()),
- Dst: net.IP(dst.AsSlice()),
- }
- hdr, err := iph.Marshal()
- if err != nil {
- panic(err)
- }
- return append(hdr, data...)
- }
- mk6 := func(proto ipproto.Proto, src, dst netip.Addr, data []byte) []byte {
- if !src.Is6() || !dst.Is6() {
- panic("not an IPv6 address")
- }
- // The ipv6 package doesn't have a Marshal method, so just do
- // the most basic thing and construct the header manually.
- buf := make([]byte, ipv6.HeaderLen, ipv6.HeaderLen+len(data))
- buf[0] = 6 << 4 // version
- binary.BigEndian.PutUint16(buf[4:6], uint16(len(data)))
- buf[6] = byte(proto)
- copy(buf[8:24], src.AsSlice())
- copy(buf[24:40], dst.AsSlice())
- return append(buf, data...)
- }
+func TestParseUDPPacket(t *testing.T) {
+ src4 := netip.MustParseAddrPort("127.0.0.1:12345")
+ dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
- mkUDP := func(srcPort, dstPort uint16, data []byte) []byte {
- udp := make([]byte, 8, 8+len(data))
- binary.BigEndian.PutUint16(udp[0:2], srcPort)
- binary.BigEndian.PutUint16(udp[2:4], dstPort)
- binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(data)))
- return append(udp, data...)
- }
- mkUDP4 := func(src, dst netip.AddrPort, data []byte) []byte {
- return mk4(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
- }
- mkUDP6 := func(src, dst netip.AddrPort, data []byte) []byte {
- return mk6(ipproto.UDP, src.Addr(), dst.Addr(), mkUDP(src.Port(), dst.Port(), data))
+ src6 := netip.MustParseAddrPort("[::1]:12345")
+ dst6 := netip.MustParseAddrPort("[::2]:54321")
+
+ udp4Packet := []byte{
+ // IPv4 header
+ 0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
+ 0x40, 0x11, 0x00, 0x00,
+ 0x7f, 0x00, 0x00, 0x01, // source ip
+ 0x7f, 0x00, 0x00, 0x02, // dest ip
+
+ // UDP header
+ 0x30, 0x39, // src port
+ 0xd4, 0x31, // dest port
+ 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
+ 0x00, 0x00, // checksum; unused
+
+ // Payload: disco magic plus 4 bytes
+ 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
+ udp6Packet := []byte{
+ // IPv6 header
+ 0x60, 0x00, 0x00, 0x00,
+ 0x00, 0x12, // payload length
+ 0x11, // next header: UDP
+ 0x00, // hop limit; unused
+
+ // Source IP
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
+ // Dest IP
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
- ip4 := netip.MustParseAddrPort("127.0.0.10:12345")
- ip4_2 := netip.MustParseAddrPort("127.0.0.99:54321")
- ip6 := netip.MustParseAddrPort("[::1]:12345")
+ // UDP header
+ 0x30, 0x39, // src port
+ 0xd4, 0x31, // dest port
+ 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
+ 0x00, 0x00, // checksum; unused
- testCases := []struct {
- name string
- in []byte
- is6 bool
- want bool
- }{
- {
- name: "too_short_4",
- in: mkUDP4(ip4, ip4_2, append([]byte(disco.Magic), 0x00, 0x00)),
- is6: false,
- want: false,
- },
- {
- name: "too_short_6",
- in: mkUDP6(ip6, ip6, append([]byte(disco.Magic), 0x00, 0x00)),
- is6: true,
- want: false,
- },
- {
- name: "valid_4",
- in: mkUDP4(ip4, ip4_2, testDiscoPacket),
- is6: false,
- want: true,
- },
- {
- name: "valid_6",
- in: mkUDP6(ip6, ip6, testDiscoPacket),
- is6: true,
- want: true,
- },
+ // Payload: disco magic plus 4 bytes
+ 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- var pkt packet.Parsed
- got := decodeDiscoPacket(&pkt, t.Logf, tc.in, tc.is6)
- if got != tc.want {
- t.Errorf("got %v; want %v", got, tc.want)
+ // Verify that parsing the UDP packet works correctly.
+ t.Run("IPv4", func(t *testing.T) {
+ src, dst, payload := parseUDPPacket(udp4Packet, false)
+ if src != src4 {
+ t.Errorf("src = %v; want %v", src, src4)
+ }
+ if dst != dst4 {
+ t.Errorf("dst = %v; want %v", dst, dst4)
+ }
+ if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
+ t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ src, dst, payload := parseUDPPacket(udp6Packet, true)
+ if src != src6 {
+ t.Errorf("src = %v; want %v", src, src6)
+ }
+ if dst != dst6 {
+ t.Errorf("dst = %v; want %v", dst, dst6)
+ }
+ if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
+ t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
+ }
+ })
+ t.Run("Truncated", func(t *testing.T) {
+ truncateBy := func(b []byte, n int) []byte {
+ if n >= len(b) {
+ return nil
}
- })
- }
+ return b[:len(b)-n]
+ }
+
+ src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
+ if payload != nil {
+ t.Errorf("payload = %x; want nil", payload)
+ }
+ if src.IsValid() || dst.IsValid() {
+ t.Errorf("src = %v, dst = %v; want invalid", src, dst)
+ }
+
+ src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
+ if payload != nil {
+ t.Errorf("payload = %x; want nil", payload)
+ }
+ if src.IsValid() || dst.IsValid() {
+ t.Errorf("src = %v, dst = %v; want invalid", src, dst)
+ }
+ })
}
func TestEthernetProto(t *testing.T) {