summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tsnet/tsnet.go63
-rw-r--r--tsnet/tsnet_test.go279
-rw-r--r--wgengine/netstack/netstack.go75
-rw-r--r--wgengine/netstack/netstack_test.go280
4 files changed, 683 insertions, 14 deletions
diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go
index ea165e932..1bce7737d 100644
--- a/tsnet/tsnet.go
+++ b/tsnet/tsnet.go
@@ -158,6 +158,23 @@ type Server struct {
// that the control server will allow the node to adopt that tag.
AdvertiseTags []string
+ // IPPacketHandler, if non-nil, specifies a handler that will be called for
+ // incoming TCP/UDP packets that netstack will not process.
+ //
+ // The packet slice contains a complete IP packet starting with the IP header.
+ // The handler should not retain the packet slice after returning.
+ //
+ // Return true if the packet was handled, false to pass it to the host network stack.
+ //
+ // The handler will NOT see packets processed by netstack, including packets to
+ // local Tailscale IPs, subnet IPs, PeerAPI, SSH, or service IPs (100.100.100.100).
+ //
+ // The handler is called from the packet receive path and should not block.
+ //
+ // This must be set before calling Start, Listen, Dial, or Up.
+ // To send packets, call [IPPacketWriter()] to get an [IPPacketWriter].
+ IPPacketHandler func(packet []byte) bool
+
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
initOnce sync.Once
@@ -200,9 +217,25 @@ type Server struct {
// over the TCP conn.
type FallbackTCPHandler func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool)
+// IPPacketWriter is a function that sends an IP packet into the Tailscale network.
+// The packet must be a complete IP packet starting with the IP header.
+// It will be routed through the Tailscale network to the appropriate peer
+// based on the destination IP address.
+type IPPacketWriter func(pkt []byte) error
+
+// ErrIPPacketHandlerSet is returned by Listen, ListenPacket, ListenTLS,
+// ListenFunnel, and Dial when IPPacketHandler is set. These APIs rely on
+// netstack to process packets, which is bypassed when IPPacketHandler is set.
+var ErrIPPacketHandlerSet = errors.New("tsnet: cannot use socket APIs when IPPacketHandler is set")
+
// Dial connects to the address on the tailnet.
// It will start the server if it has not been started yet.
+//
+// If IPPacketHandler is set, Dial returns ErrIPPacketHandlerSet.
func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) {
+ if s.IPPacketHandler != nil {
+ return nil, ErrIPPacketHandlerSet
+ }
if err := s.Start(); err != nil {
return nil, err
}
@@ -667,6 +700,7 @@ func (s *Server) start() (reterr error) {
ns.ProcessSubnets = true
ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow
+ ns.HandleIPPacket = s.IPPacketHandler
s.netstack = ns
s.dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := eng.PeerForIP(ip)
@@ -1018,7 +1052,12 @@ func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(net
// IPv6 address of this node) only. To listen for traffic on other addresses
// such as those routed inbound via subnet routes, explicitly specify
// the listening address or use RegisterFallbackTCPHandler.
+//
+// If IPPacketHandler is set, Listen returns ErrIPPacketHandlerSet.
func (s *Server) Listen(network, addr string) (net.Listener, error) {
+ if s.IPPacketHandler != nil {
+ return nil, ErrIPPacketHandlerSet
+ }
return s.listen(network, addr, listenOnTailnet)
}
@@ -1029,7 +1068,12 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
// corresponding to "udp4" or "udp6" respectively. IP must be specified.
//
// If s has not been started yet, it will be started.
+//
+// If IPPacketHandler is set, ListenPacket returns ErrIPPacketHandlerSet.
func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
+ if s.IPPacketHandler != nil {
+ return nil, ErrIPPacketHandlerSet
+ }
ap, err := resolveListenAddr(network, addr)
if err != nil {
return nil, err
@@ -1050,10 +1094,24 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
return s.netstack.ListenPacket(network, ap.String())
}
+// IPPacketWriter returns a function to send IP packets into the Tailscale network.
+func (s *Server) IPPacketWriter() (IPPacketWriter, error) {
+ if err := s.Start(); err != nil {
+ return nil, err
+ }
+
+ return s.netstack.InjectOutbound, nil
+}
+
// ListenTLS announces only on the Tailscale network.
// It returns a TLS listener wrapping the tsnet listener.
// It will start the server if it has not been started yet.
+//
+// If IPPacketHandler is set, ListenTLS returns ErrIPPacketHandlerSet.
func (s *Server) ListenTLS(network, addr string) (net.Listener, error) {
+ if s.IPPacketHandler != nil {
+ return nil, ErrIPPacketHandlerSet
+ }
if network != "tcp" {
return nil, fmt.Errorf("ListenTLS(%q, %q): only tcp is supported", network, addr)
}
@@ -1159,7 +1217,12 @@ func FunnelTLSConfig(conf *tls.Config) FunnelOption {
// and the only other supported addrs currently are ":8443" and ":10000".
//
// It will start the server if it has not been started yet.
+//
+// If IPPacketHandler is set, ListenFunnel returns ErrIPPacketHandlerSet.
func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.Listener, error) {
+ if s.IPPacketHandler != nil {
+ return nil, ErrIPPacketHandlerSet
+ }
if network != "tcp" {
return nil, fmt.Errorf("ListenFunnel(%q, %q): only tcp is supported", network, addr)
}
diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go
index 838d5f3f5..3943bda56 100644
--- a/tsnet/tsnet_test.go
+++ b/tsnet/tsnet_test.go
@@ -13,6 +13,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
+ "encoding/binary"
"errors"
"flag"
"fmt"
@@ -42,11 +43,13 @@ import (
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netns"
+ "tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/deptest"
"tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol"
+ "tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/must"
@@ -1592,3 +1595,279 @@ func TestResolveAuthKey(t *testing.T) {
})
}
}
+
+// TestIPPacketHandler tests the IPPacketHandler functionality end-to-end.
+// It sets up two tsnet instances:
+// - s1: a "regular" tsnet using the high-level socket API (ListenPacket)
+// - s2: a "raw packet" tsnet using IPPacketHandler/IPPacketWriter
+//
+// The test verifies:
+// 1. s2 can send a UDP packet to s1 using raw IP packets via IPPacketWriter
+// 2. s2 receives the UDP echo reply from s1 via IPPacketHandler
+// 3. PeerAPI still works on s2 (raw packet handling doesn't break internal services)
+func TestIPPacketHandler(t *testing.T) {
+ tstest.Shard(t)
+ tstest.ResourceCheck(t)
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
+ controlURL, _ := startControl(t)
+
+ // Start s1 - a regular tsnet server that will echo UDP packets
+ s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
+ pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:9999", s1ip)))
+ defer pc.Close()
+
+ echoServerDone := make(chan struct{})
+ t.Logf("s1: starting UDP echo server on %v:9999", s1ip)
+ go func() {
+ defer close(echoServerDone)
+ defer t.Log("s1: echo server goroutine exiting")
+ buf := make([]byte, 1500)
+ for {
+ t.Log("s1: echo server waiting for packet...")
+ n, addr, err := pc.ReadFrom(buf)
+ if err != nil {
+ t.Logf("s1: echo server ReadFrom error: %v", err)
+ return
+ }
+ t.Logf("s1: received UDP packet from %v: %q", addr, buf[:n])
+ // Echo back the payload
+ if _, err := pc.WriteTo(buf[:n], addr); err != nil {
+ t.Logf("s1: error writing response: %v", err)
+ }
+ }
+ }()
+ defer func() {
+ pc.Close()
+ <-echoServerDone
+ }()
+
+ // Set up s2 with IPPacketHandler - the raw packet server
+ receivedPackets := make(chan []byte, 10)
+ tmp := filepath.Join(t.TempDir(), "s2")
+ os.MkdirAll(tmp, 0755)
+ s2 := &Server{
+ Dir: tmp,
+ ControlURL: controlURL,
+ Hostname: "s2",
+ Store: new(mem.Store),
+ Ephemeral: true,
+ getCertForTesting: testCertRoot.getCert,
+ IPPacketHandler: func(pkt []byte) bool {
+ // Make a copy since we can't retain the slice
+ cpy := make([]byte, len(pkt))
+ copy(cpy, pkt)
+
+ var parsed packet.Parsed
+ parsed.Decode(cpy)
+ t.Logf("s2 IPPacketHandler: received %s packet %v -> %v", parsed.IPProto, parsed.Src, parsed.Dst)
+
+ select {
+ case receivedPackets <- cpy:
+ default:
+ t.Logf("s2: dropped packet, channel full")
+ }
+ return true // consume the packet
+ },
+ }
+ if *verboseNodes {
+ s2.Logf = log.Printf
+ }
+ t.Cleanup(func() { s2.Close() })
+
+ status, err := s2.Up(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s2ip := status.TailscaleIPs[0]
+ t.Logf("s2 IP: %v", s2ip)
+
+ writer, err := s2.IPPacketWriter()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ lc2 := must.Get(s2.LocalClient())
+ res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
+ if err != nil {
+ t.Fatalf("ping failed: %v", err)
+ }
+ t.Logf("ping success: latency=%v", res.LatencySeconds)
+
+ const srcPort = 12345
+ const dstPort = 9999
+ payload := []byte("hello from raw packet")
+
+ udpPkt := buildUDPPacket(t, s2ip, s1ip, srcPort, dstPort, payload)
+ t.Logf("s2: sending raw UDP packet (%d bytes) to %v:%d", len(udpPkt), s1ip, dstPort)
+
+ if err := writer(udpPkt); err != nil {
+ t.Fatalf("IPPacketWriter failed: %v", err)
+ }
+ t.Log("s2: packet injected successfully")
+
+ t.Log("s2: waiting for UDP echo reply")
+ select {
+ case <-ctx.Done():
+ t.Fatal("timeout waiting for UDP echo reply")
+ case pkt := <-receivedPackets:
+ t.Logf("s2: received raw packet (%d bytes)", len(pkt))
+
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+ if parsed.IPVersion != 4 {
+ t.Errorf("expected IPv4, got version %d", parsed.IPVersion)
+ }
+ if parsed.IPProto != ipproto.UDP {
+ t.Errorf("expected UDP, got protocol %d", parsed.IPProto)
+ }
+ if parsed.Src.Addr() != s1ip {
+ t.Errorf("expected src %v, got %v", s1ip, parsed.Src.Addr())
+ }
+ if parsed.Dst.Addr() != s2ip {
+ t.Errorf("expected dst %v, got %v", s2ip, parsed.Dst.Addr())
+ }
+ if parsed.Src.Port() != dstPort {
+ t.Errorf("expected src port %d, got %d", dstPort, parsed.Src.Port())
+ }
+ if parsed.Dst.Port() != srcPort {
+ t.Errorf("expected dst port %d, got %d", srcPort, parsed.Dst.Port())
+ }
+ if !bytes.Equal(parsed.Payload(), payload) {
+ t.Errorf("payload mismatch: got %q, want %q", parsed.Payload(), payload)
+ }
+ }
+
+ t.Log("Testing PeerAPI access to raw packet server...")
+ s2Status := must.Get(lc2.StatusWithoutPeers(ctx))
+ if len(s2Status.Self.PeerAPIURL) == 0 {
+ t.Fatal("s2 has no PeerAPI URLs")
+ }
+ peerAPIURL := s2Status.Self.PeerAPIURL[0]
+ t.Logf("s2 PeerAPI URL: %s", peerAPIURL)
+
+ resp, err := s1.HTTPClient().Get(peerAPIURL)
+ if err != nil {
+ t.Fatalf("failed to GET s2 PeerAPI from s1: %v", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("PeerAPI request returned status %d, want 200", resp.StatusCode)
+ }
+ t.Log("PeerAPI access to raw packet server succeeded")
+}
+
+// buildUDPPacket constructs a complete IPv4 UDP packet with valid checksums.
+func buildUDPPacket(t *testing.T, src, dst netip.Addr, srcPort, dstPort uint16, payload []byte) []byte {
+ t.Helper()
+
+ if !src.Is4() || !dst.Is4() {
+ t.Fatal("buildUDPPacket only supports IPv4")
+ }
+
+ const ipHeaderLen = 20
+ const udpHeaderLen = 8
+ totalLen := ipHeaderLen + udpHeaderLen + len(payload)
+ udpLen := udpHeaderLen + len(payload)
+
+ pkt := make([]byte, totalLen)
+
+ // IPv4 header
+ pkt[0] = 0x45 // Version (4) + IHL (5 = 20 bytes)
+ pkt[1] = 0 // DSCP + ECN
+ binary.BigEndian.PutUint16(pkt[2:4], uint16(totalLen)) // Total length
+ binary.BigEndian.PutUint16(pkt[4:6], 0) // Identification
+ binary.BigEndian.PutUint16(pkt[6:8], 0) // Flags + Fragment offset
+ pkt[8] = 64 // TTL
+ pkt[9] = uint8(ipproto.UDP) // Protocol
+ // pkt[10:12] = checksum (computed below)
+ copy(pkt[12:16], src.AsSlice())
+ copy(pkt[16:20], dst.AsSlice())
+
+ // Compute IP header checksum
+ var sum uint32
+ for i := 0; i < ipHeaderLen; i += 2 {
+ sum += uint32(binary.BigEndian.Uint16(pkt[i : i+2]))
+ }
+ for sum > 0xffff {
+ sum = (sum & 0xffff) + (sum >> 16)
+ }
+ binary.BigEndian.PutUint16(pkt[10:12], ^uint16(sum))
+
+ // UDP header
+ udp := pkt[ipHeaderLen:]
+ binary.BigEndian.PutUint16(udp[0:2], srcPort)
+ binary.BigEndian.PutUint16(udp[2:4], dstPort)
+ binary.BigEndian.PutUint16(udp[4:6], uint16(udpLen))
+ // udp[6:8] = checksum (computed below with pseudo-header)
+
+ // Payload
+ copy(udp[udpHeaderLen:], payload)
+
+ // Compute UDP checksum with pseudo-header
+ // Pseudo-header: src IP (4) + dst IP (4) + zero (1) + protocol (1) + UDP length (2) = 12 bytes
+ var udpSum uint32
+ // Add source IP
+ udpSum += uint32(binary.BigEndian.Uint16(src.AsSlice()[0:2]))
+ udpSum += uint32(binary.BigEndian.Uint16(src.AsSlice()[2:4]))
+ // Add dest IP
+ udpSum += uint32(binary.BigEndian.Uint16(dst.AsSlice()[0:2]))
+ udpSum += uint32(binary.BigEndian.Uint16(dst.AsSlice()[2:4]))
+ // Add protocol (UDP = 17)
+ udpSum += uint32(ipproto.UDP)
+ // Add UDP length
+ udpSum += uint32(udpLen)
+ // Add UDP header and payload
+ for i := 0; i < len(udp); i += 2 {
+ if i+1 < len(udp) {
+ udpSum += uint32(binary.BigEndian.Uint16(udp[i : i+2]))
+ } else {
+ udpSum += uint32(udp[i]) << 8 // odd byte
+ }
+ }
+ // Fold 32-bit sum to 16 bits
+ for udpSum > 0xffff {
+ udpSum = (udpSum & 0xffff) + (udpSum >> 16)
+ }
+ checksum := ^uint16(udpSum)
+ if checksum == 0 {
+ checksum = 0xffff // UDP checksum of 0 is transmitted as 0xffff
+ }
+ binary.BigEndian.PutUint16(udp[6:8], checksum)
+ return pkt
+}
+
+// TestIPPacketHandlerSocketAPIGuards verifies that Listen, ListenPacket,
+// ListenTLS, ListenFunnel, and Dial return ErrIPPacketHandlerSet when
+// IPPacketHandler is set.
+func TestIPPacketHandlerSocketAPIGuards(t *testing.T) {
+ s := &Server{
+ IPPacketHandler: func([]byte) bool { return true },
+ }
+
+ _, err := s.Listen("tcp", ":0")
+ if !errors.Is(err, ErrIPPacketHandlerSet) {
+ t.Errorf("Listen: got %v, want ErrIPPacketHandlerSet", err)
+ }
+
+ _, err = s.ListenPacket("udp", "127.0.0.1:0")
+ if !errors.Is(err, ErrIPPacketHandlerSet) {
+ t.Errorf("ListenPacket: got %v, want ErrIPPacketHandlerSet", err)
+ }
+
+ _, err = s.ListenTLS("tcp", ":0")
+ if !errors.Is(err, ErrIPPacketHandlerSet) {
+ t.Errorf("ListenTLS: got %v, want ErrIPPacketHandlerSet", err)
+ }
+
+ _, err = s.ListenFunnel("tcp", ":443")
+ if !errors.Is(err, ErrIPPacketHandlerSet) {
+ t.Errorf("ListenFunnel: got %v, want ErrIPPacketHandlerSet", err)
+ }
+
+ _, err = s.Dial(context.Background(), "tcp", "127.0.0.1:80")
+ if !errors.Is(err, ErrIPPacketHandlerSet) {
+ t.Errorf("Dial: got %v, want ErrIPPacketHandlerSet", err)
+ }
+}
diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go
index c2b5d8a32..17f9cfd7d 100644
--- a/wgengine/netstack/netstack.go
+++ b/wgengine/netstack/netstack.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/tailscale/wireguard-go/conn"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -119,8 +120,8 @@ func maxInFlightConnectionAttemptsPerClient() int {
var debugNetstack = envknob.RegisterBool("TS_DEBUG_NETSTACK")
var (
- serviceIP = tsaddr.TailscaleServiceIP()
- serviceIPv6 = tsaddr.TailscaleServiceIPv6()
+ tsServiceIP = tsaddr.TailscaleServiceIP()
+ tsServiceIPv6 = tsaddr.TailscaleServiceIPv6()
)
func init() {
@@ -176,6 +177,22 @@ type Impl struct {
// It can only be set before calling Start.
ProcessSubnets bool
+ // HandleIPPacket, if non-nil, is called for incoming TCP/UDP packets that
+ // netstack will not process.
+ //
+ // The packet slice contains a complete IP packet starting with the IP header.
+ // The packet slice is only valid for the duration of the call and must not
+ // be retained.
+ //
+ // If HandleIPPacket returns true, the packet is consumed. If false, the packet
+ // is passed to the host network stack (if available).
+ //
+ // The handler will NOT see packets processed by netstack (local IPs, subnet IPs,
+ // PeerAPI, SSH, service IPs, or flows handled by GetTCPHandlerForFlow/GetUDPHandlerForFlow).
+ //
+ // The handler is called from the packet receive path and must not block.
+ HandleIPPacket func(packet []byte) bool
+
ipstack *stack.Stack
linkEP *linkEndpoint
tundev *tstun.Wrapper
@@ -415,6 +432,16 @@ func (ns *Impl) Close() error {
return nil
}
+// InjectOutbound sends an IP packet through the Tailscale network.
+// The packet must be a complete IP packet starting with the IP header.
+func (ns *Impl) InjectOutbound(pkt []byte) error {
+ packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Payload: buffer.MakeWithData(pkt),
+ })
+ // InjectOutboundPacketBuffer decrements the buffer reference count.
+ return ns.tundev.InjectOutboundPacketBuffer(packetBuf)
+}
+
// SetTransportProtocolOption forwards to the underlying
// [stack.Stack.SetTransportProtocolOption]. Callers are responsible for
// ensuring that the options are valid, compatible and appropriate for their use
@@ -772,7 +799,7 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper, gro *gro.
// Determine if we care about this local packet.
dst := p.Dst.Addr()
switch {
- case dst == serviceIP || dst == serviceIPv6:
+ case dst == tsServiceIP || dst == tsServiceIPv6:
// We want to intercept some traffic to the "service IP" (e.g.
// 100.100.100.100 for IPv4). However, of traffic to the
// service IP, we only care about UDP 53, and TCP on port 53,
@@ -994,13 +1021,13 @@ func (ns *Impl) shouldSendToHost(pkt *stack.PacketBuffer) bool {
switch v := hdr.(type) {
case header.IPv4:
srcIP := netip.AddrFrom4(v.SourceAddress().As4())
- if serviceIP == srcIP {
+ if tsServiceIP == srcIP {
return true
}
case header.IPv6:
srcIP := netip.AddrFrom16(v.SourceAddress().As16())
- if srcIP == serviceIPv6 {
+ if srcIP == tsServiceIPv6 {
return true
}
@@ -1064,7 +1091,7 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
// Handle incoming peerapi connections in netstack.
dstIP := p.Dst.Addr()
isLocal := ns.isLocalIP(dstIP)
- isService := ns.isVIPServiceIP(dstIP)
+ isVIPService := ns.isVIPServiceIP(dstIP)
// Handle TCP connection to the Tailscale IP(s) in some cases:
if ns.lb != nil && p.IPProto == ipproto.TCP && isLocal {
@@ -1087,7 +1114,20 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
return true
}
}
- if buildfeatures.HasServe && isService {
+ hittingServiceIP := dstIP == tsServiceIP || dstIP == tsServiceIPv6
+ if hittingServiceIP {
+ if p.IsEchoRequest() {
+ return true
+ }
+ if p.Dst.Port() == 53 {
+ return true
+ }
+ if ns.isLoopbackPort(p.Dst.Port()) {
+ return true
+ }
+ return false
+ }
+ if buildfeatures.HasServe && isVIPService {
if p.IsEchoRequest() {
return true
}
@@ -1103,6 +1143,13 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
if p.IPVersion == 6 && !isLocal && viaRange.Contains(dstIP) {
return ns.lb != nil && ns.lb.ShouldHandleViaIP(dstIP)
}
+ // If HandleIPPacket is set, don't process normal traffic in netstack.
+ // This allows the handler to receive all non-special packets.
+ // Special traffic (PeerAPI, SSH, service IP, VIP services) is still
+ // handled by netstack via the checks above.
+ if ns.HandleIPPacket != nil {
+ return false
+ }
if ns.ProcessLocalIPs && isLocal {
return true
}
@@ -1189,7 +1236,11 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper, gro *gro.GRO)
}
if !ns.shouldProcessInbound(p, t) {
- // Let the host network stack (if any) deal with it.
+ if ns.HandleIPPacket != nil {
+ if ns.HandleIPPacket(p.Buffer()) {
+ return filter.DropSilently, gro
+ }
+ }
return filter.Accept, gro
}
@@ -1392,7 +1443,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
}
// Local Services (DNS and WebDAV)
- hittingServiceIP := dialIP == serviceIP || dialIP == serviceIPv6
+ hittingServiceIP := dialIP == tsServiceIP || dialIP == tsServiceIPv6
hittingDNS := hittingServiceIP && reqDetails.LocalPort == 53
if hittingDNS {
c := getConnOrReset()
@@ -1433,7 +1484,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
}
switch {
case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort):
- if dialIP == serviceIPv6 {
+ if dialIP == tsServiceIPv6 {
dialIP = ipv6Loopback
} else {
dialIP = ipv4Loopback
@@ -1635,14 +1686,14 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
}
// Handle magicDNS and loopback traffic (via UDP) here.
- if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 {
+ if dst := dstAddr.Addr(); dst == tsServiceIP || dst == tsServiceIPv6 {
switch {
case dstAddr.Port() == 53:
c := gonet.NewUDPConn(&wq, ep)
go ns.handleMagicDNSUDP(srcAddr, c)
return
case ns.isLoopbackPort(dstAddr.Port()):
- if dst == serviceIPv6 {
+ if dst == tsServiceIPv6 {
dstAddr = netip.AddrPortFrom(ipv6Loopback, dstAddr.Port())
} else {
dstAddr = netip.AddrPortFrom(ipv4Loopback, dstAddr.Port())
diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go
index 93022811c..3eeb4d56b 100644
--- a/wgengine/netstack/netstack_test.go
+++ b/wgengine/netstack/netstack_test.go
@@ -847,13 +847,13 @@ func TestShouldSendToHost(t *testing.T) {
// not over WireGuard.
{
name: "from_service_ip_to_localhost",
- src: netip.AddrPortFrom(serviceIP, 53),
+ src: netip.AddrPortFrom(tsServiceIP, 53),
dst: netip.MustParseAddrPort("127.0.0.1:9999"),
want: true,
},
{
name: "from_service_ip_to_localhost_v6",
- src: netip.AddrPortFrom(serviceIPv6, 53),
+ src: netip.AddrPortFrom(tsServiceIPv6, 53),
dst: netip.MustParseAddrPort("[::1]:9999"),
want: true,
},
@@ -1019,3 +1019,279 @@ func makeUDP6PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer {
return pkt
}
+
+func TestHandleIPPacket(t *testing.T) {
+ impl := makeNetstack(t, func(ns *Impl) {})
+
+ client := netip.MustParseAddr("100.64.1.2")
+ destAddr := netip.MustParseAddr("100.64.1.1")
+ pkt := tcp4syn(t, client, destAddr, 1234, 5678)
+
+ var handlerCalled bool
+ var receivedPacket []byte
+ impl.HandleIPPacket = func(p []byte) bool {
+ handlerCalled = true
+ receivedPacket = append([]byte(nil), p...)
+ return true
+ }
+
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+
+ resp, _ := impl.injectInbound(&parsed, impl.tundev, nil)
+
+ if !handlerCalled {
+ t.Error("HandleIPPacket was not called")
+ }
+ if resp != filter.DropSilently {
+ t.Errorf("Expected DropSilently response when handler returns true, got %v", resp)
+ }
+ if len(receivedPacket) == 0 {
+ t.Error("Handler received empty packet")
+ }
+ if len(receivedPacket) < 20 {
+ t.Errorf("Packet too short: %d bytes", len(receivedPacket))
+ }
+ ipVer := receivedPacket[0] >> 4
+ if ipVer != 4 {
+ t.Errorf("Expected IPv4 packet (version 4), got version %d", ipVer)
+ }
+
+ handlerCalled = false
+ receivedPacket = nil
+ impl.HandleIPPacket = func(p []byte) bool {
+ handlerCalled = true
+ receivedPacket = append([]byte(nil), p...)
+ return false
+ }
+
+ var parsed2 packet.Parsed
+ parsed2.Decode(pkt)
+ resp2, _ := impl.injectInbound(&parsed2, impl.tundev, nil)
+
+ if !handlerCalled {
+ t.Error("HandleIPPacket was not called on second test")
+ }
+ if resp2 != filter.Accept {
+ t.Errorf("Expected Accept when handler declines, got %v", resp2)
+ }
+ if len(receivedPacket) == 0 {
+ t.Error("Handler received empty packet on second test")
+ }
+
+ impl.HandleIPPacket = nil
+ var parsed3 packet.Parsed
+ parsed3.Decode(pkt)
+ resp3, _ := impl.injectInbound(&parsed3, impl.tundev, nil)
+
+ if resp3 != filter.Accept {
+ t.Errorf("Expected Accept with no handler, got %v", resp3)
+ }
+}
+
+func TestHandleIPPacketIPv6(t *testing.T) {
+ impl := makeNetstack(t, func(ns *Impl) {})
+
+ src := netip.MustParseAddr("fd7a:115c:a1e0::2")
+ dst := netip.MustParseAddr("fd7a:115c:a1e0::1")
+ const tcpLen = header.TCPMinimumSize
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize+tcpLen))
+ ip.Encode(&header.IPv6Fields{
+ TransportProtocol: header.TCPProtocolNumber,
+ PayloadLength: tcpLen,
+ HopLimit: 64,
+ SrcAddr: tcpip.AddrFrom16(src.As16()),
+ DstAddr: tcpip.AddrFrom16(dst.As16()),
+ })
+
+ tcp := header.TCP(ip[header.IPv6MinimumSize:])
+ tcp.Encode(&header.TCPFields{
+ SrcPort: 1234,
+ DstPort: 5678,
+ SeqNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 65535,
+ })
+
+ pkt := []byte(ip)
+
+ var receivedPacket []byte
+ impl.HandleIPPacket = func(p []byte) bool {
+ receivedPacket = append([]byte(nil), p...)
+ return true
+ }
+
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+
+ resp, _ := impl.injectInbound(&parsed, impl.tundev, nil)
+
+ if resp != filter.DropSilently {
+ t.Errorf("Expected DropSilently, got %v", resp)
+ }
+ if len(receivedPacket) < 40 {
+ t.Errorf("Packet too short for IPv6: %d bytes", len(receivedPacket))
+ }
+ ipVer := receivedPacket[0] >> 4
+ if ipVer != 6 {
+ t.Errorf("Expected IPv6 packet (version 6), got version %d", ipVer)
+ }
+}
+
+func TestHandleIPPacketNotProcessed(t *testing.T) {
+ impl := makeNetstack(t, func(ns *Impl) {
+ ns.ProcessLocalIPs = false
+ ns.ProcessSubnets = false
+ })
+
+ var handlerCalled bool
+ impl.HandleIPPacket = func(p []byte) bool {
+ handlerCalled = true
+ return true
+ }
+
+ client := netip.MustParseAddr("100.64.1.2")
+ destAddr := netip.MustParseAddr("100.64.1.1")
+ pkt := tcp4syn(t, client, destAddr, 1234, 5678)
+
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+
+ resp, _ := impl.injectInbound(&parsed, impl.tundev, nil)
+
+ if !handlerCalled {
+ t.Error("HandleIPPacket should be called when shouldProcessInbound returns false")
+ }
+ if resp != filter.DropSilently {
+ t.Errorf("Expected DropSilently when handler consumes packet, got %v", resp)
+ }
+}
+
+func TestHandleIPPacketRealPacket(t *testing.T) {
+ impl := makeNetstack(t, func(ns *Impl) {})
+
+ client := netip.MustParseAddr("100.64.1.2")
+ destAddr := netip.MustParseAddr("100.64.1.1")
+ pkt := tcp4syn(t, client, destAddr, 1234, 5678)
+
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+ if parsed.IPVersion != 4 {
+ t.Fatalf("Expected IPv4, got version %d", parsed.IPVersion)
+ }
+ if parsed.Src.Addr() != client {
+ t.Errorf("Expected src %v, got %v", client, parsed.Src.Addr())
+ }
+ if parsed.Dst.Addr() != destAddr {
+ t.Errorf("Expected dst %v, got %v", destAddr, parsed.Dst.Addr())
+ }
+
+ var receivedPacket []byte
+ impl.HandleIPPacket = func(p []byte) bool {
+ receivedPacket = append([]byte(nil), p...)
+ return true
+ }
+
+ resp, _ := impl.injectInbound(&parsed, impl.tundev, nil)
+
+ if resp != filter.DropSilently {
+ t.Errorf("Expected DropSilently, got %v", resp)
+ }
+
+ if len(receivedPacket) != len(pkt) {
+ t.Errorf("Packet length mismatch: got %d, want %d", len(receivedPacket), len(pkt))
+ }
+
+ var parsedFromHandler packet.Parsed
+ parsedFromHandler.Decode(receivedPacket)
+ if parsedFromHandler.IPVersion != 4 {
+ t.Errorf("Handler received non-IPv4 packet: version %d", parsedFromHandler.IPVersion)
+ }
+ if parsedFromHandler.Src.Addr() != client {
+ t.Errorf("Handler packet src mismatch: got %v, want %v", parsedFromHandler.Src.Addr(), client)
+ }
+}
+
+// TestHandleIPPacketWithServices verifies that PeerAPI and Service IP packets
+// don't reach the handler (they're handled by shouldProcessInbound).
+func TestHandleIPPacketWithServices(t *testing.T) {
+ var handlerCalledFor []string
+ var handlerFunc = func(p []byte) bool {
+ var parsed packet.Parsed
+ parsed.Decode(p)
+ handlerCalledFor = append(handlerCalledFor, fmt.Sprintf("%v->%v", parsed.Src, parsed.Dst))
+ return true
+ }
+
+ client := netip.MustParseAddr("100.64.1.2")
+ selfIP := netip.MustParseAddr("100.64.1.1")
+
+ tests := []struct {
+ name string
+ setupImpl func(*Impl)
+ dstAddr netip.Addr
+ dstPort uint16
+ shouldReach bool
+ expectedResp filter.Response
+ description string
+ }{
+ {
+ name: "peerapi_packet",
+ setupImpl: func(ns *Impl) {
+ ns.ProcessLocalIPs = true
+ ns.peerapiPort4Atomic.Store(5555)
+ },
+ dstAddr: selfIP,
+ dstPort: 5555,
+ shouldReach: false,
+ expectedResp: filter.DropSilently,
+ description: "PeerAPI packets should NOT reach handler (handled by netstack)",
+ },
+ {
+ name: "service_ip_dns",
+ setupImpl: func(ns *Impl) {},
+ dstAddr: netip.MustParseAddr("100.100.100.100"),
+ dstPort: 53,
+ shouldReach: false,
+ expectedResp: filter.DropSilently,
+ description: "Service IP (DNS) packets should NOT reach handler",
+ },
+ {
+ name: "normal_packet",
+ setupImpl: func(ns *Impl) {},
+ dstAddr: selfIP,
+ dstPort: 8080,
+ shouldReach: true,
+ expectedResp: filter.DropSilently,
+ description: "Normal packets (not processed by netstack) should reach handler",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ impl := makeNetstack(t, tt.setupImpl)
+ impl.HandleIPPacket = handlerFunc
+ handlerCalledFor = nil
+
+ pkt := tcp4syn(t, client, tt.dstAddr, 1234, tt.dstPort)
+ var parsed packet.Parsed
+ parsed.Decode(pkt)
+
+ resp, _ := impl.injectInbound(&parsed, impl.tundev, nil)
+
+ handlerCalled := len(handlerCalledFor) > 0
+
+ if tt.shouldReach && !handlerCalled {
+ t.Errorf("%s: handler was not called, but should have been", tt.description)
+ }
+ if !tt.shouldReach && handlerCalled {
+ t.Errorf("%s: handler was called for %v, but should NOT have been", tt.description, handlerCalledFor)
+ }
+
+ if resp != tt.expectedResp {
+ t.Errorf("Expected %v, got %v", tt.expectedResp, resp)
+ }
+ })
+ }
+}