diff options
| -rw-r--r-- | tsnet/tsnet.go | 63 | ||||
| -rw-r--r-- | tsnet/tsnet_test.go | 279 | ||||
| -rw-r--r-- | wgengine/netstack/netstack.go | 75 | ||||
| -rw-r--r-- | wgengine/netstack/netstack_test.go | 280 |
4 files changed, 683 insertions, 14 deletions
diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ea165e932..1bce7737d 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -158,6 +158,23 @@ type Server struct { // that the control server will allow the node to adopt that tag. AdvertiseTags []string + // IPPacketHandler, if non-nil, specifies a handler that will be called for + // incoming TCP/UDP packets that netstack will not process. + // + // The packet slice contains a complete IP packet starting with the IP header. + // The handler should not retain the packet slice after returning. + // + // Return true if the packet was handled, false to pass it to the host network stack. + // + // The handler will NOT see packets processed by netstack, including packets to + // local Tailscale IPs, subnet IPs, PeerAPI, SSH, or service IPs (100.100.100.100). + // + // The handler is called from the packet receive path and should not block. + // + // This must be set before calling Start, Listen, Dial, or Up. + // To send packets, call [IPPacketWriter()] to get an [IPPacketWriter]. + IPPacketHandler func(packet []byte) bool + getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error) initOnce sync.Once @@ -200,9 +217,25 @@ type Server struct { // over the TCP conn. type FallbackTCPHandler func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) +// IPPacketWriter is a function that sends an IP packet into the Tailscale network. +// The packet must be a complete IP packet starting with the IP header. +// It will be routed through the Tailscale network to the appropriate peer +// based on the destination IP address. +type IPPacketWriter func(pkt []byte) error + +// ErrIPPacketHandlerSet is returned by Listen, ListenPacket, ListenTLS, +// ListenFunnel, and Dial when IPPacketHandler is set. These APIs rely on +// netstack to process packets, which is bypassed when IPPacketHandler is set. +var ErrIPPacketHandlerSet = errors.New("tsnet: cannot use socket APIs when IPPacketHandler is set") + // Dial connects to the address on the tailnet. // It will start the server if it has not been started yet. +// +// If IPPacketHandler is set, Dial returns ErrIPPacketHandlerSet. func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if s.IPPacketHandler != nil { + return nil, ErrIPPacketHandlerSet + } if err := s.Start(); err != nil { return nil, err } @@ -667,6 +700,7 @@ func (s *Server) start() (reterr error) { ns.ProcessSubnets = true ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow + ns.HandleIPPacket = s.IPPacketHandler s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := eng.PeerForIP(ip) @@ -1018,7 +1052,12 @@ func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(net // IPv6 address of this node) only. To listen for traffic on other addresses // such as those routed inbound via subnet routes, explicitly specify // the listening address or use RegisterFallbackTCPHandler. +// +// If IPPacketHandler is set, Listen returns ErrIPPacketHandlerSet. func (s *Server) Listen(network, addr string) (net.Listener, error) { + if s.IPPacketHandler != nil { + return nil, ErrIPPacketHandlerSet + } return s.listen(network, addr, listenOnTailnet) } @@ -1029,7 +1068,12 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { // corresponding to "udp4" or "udp6" respectively. IP must be specified. // // If s has not been started yet, it will be started. +// +// If IPPacketHandler is set, ListenPacket returns ErrIPPacketHandlerSet. func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { + if s.IPPacketHandler != nil { + return nil, ErrIPPacketHandlerSet + } ap, err := resolveListenAddr(network, addr) if err != nil { return nil, err @@ -1050,10 +1094,24 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { return s.netstack.ListenPacket(network, ap.String()) } +// IPPacketWriter returns a function to send IP packets into the Tailscale network. +func (s *Server) IPPacketWriter() (IPPacketWriter, error) { + if err := s.Start(); err != nil { + return nil, err + } + + return s.netstack.InjectOutbound, nil +} + // ListenTLS announces only on the Tailscale network. // It returns a TLS listener wrapping the tsnet listener. // It will start the server if it has not been started yet. +// +// If IPPacketHandler is set, ListenTLS returns ErrIPPacketHandlerSet. func (s *Server) ListenTLS(network, addr string) (net.Listener, error) { + if s.IPPacketHandler != nil { + return nil, ErrIPPacketHandlerSet + } if network != "tcp" { return nil, fmt.Errorf("ListenTLS(%q, %q): only tcp is supported", network, addr) } @@ -1159,7 +1217,12 @@ func FunnelTLSConfig(conf *tls.Config) FunnelOption { // and the only other supported addrs currently are ":8443" and ":10000". // // It will start the server if it has not been started yet. +// +// If IPPacketHandler is set, ListenFunnel returns ErrIPPacketHandlerSet. func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.Listener, error) { + if s.IPPacketHandler != nil { + return nil, ErrIPPacketHandlerSet + } if network != "tcp" { return nil, fmt.Errorf("ListenFunnel(%q, %q): only tcp is supported", network, addr) } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 838d5f3f5..3943bda56 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/binary" "errors" "flag" "fmt" @@ -42,11 +43,13 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/deptest" "tailscale.com/tstest/integration" "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/must" @@ -1592,3 +1595,279 @@ func TestResolveAuthKey(t *testing.T) { }) } } + +// TestIPPacketHandler tests the IPPacketHandler functionality end-to-end. +// It sets up two tsnet instances: +// - s1: a "regular" tsnet using the high-level socket API (ListenPacket) +// - s2: a "raw packet" tsnet using IPPacketHandler/IPPacketWriter +// +// The test verifies: +// 1. s2 can send a UDP packet to s1 using raw IP packets via IPPacketWriter +// 2. s2 receives the UDP echo reply from s1 via IPPacketHandler +// 3. PeerAPI still works on s2 (raw packet handling doesn't break internal services) +func TestIPPacketHandler(t *testing.T) { + tstest.Shard(t) + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + + // Start s1 - a regular tsnet server that will echo UDP packets + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:9999", s1ip))) + defer pc.Close() + + echoServerDone := make(chan struct{}) + t.Logf("s1: starting UDP echo server on %v:9999", s1ip) + go func() { + defer close(echoServerDone) + defer t.Log("s1: echo server goroutine exiting") + buf := make([]byte, 1500) + for { + t.Log("s1: echo server waiting for packet...") + n, addr, err := pc.ReadFrom(buf) + if err != nil { + t.Logf("s1: echo server ReadFrom error: %v", err) + return + } + t.Logf("s1: received UDP packet from %v: %q", addr, buf[:n]) + // Echo back the payload + if _, err := pc.WriteTo(buf[:n], addr); err != nil { + t.Logf("s1: error writing response: %v", err) + } + } + }() + defer func() { + pc.Close() + <-echoServerDone + }() + + // Set up s2 with IPPacketHandler - the raw packet server + receivedPackets := make(chan []byte, 10) + tmp := filepath.Join(t.TempDir(), "s2") + os.MkdirAll(tmp, 0755) + s2 := &Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + getCertForTesting: testCertRoot.getCert, + IPPacketHandler: func(pkt []byte) bool { + // Make a copy since we can't retain the slice + cpy := make([]byte, len(pkt)) + copy(cpy, pkt) + + var parsed packet.Parsed + parsed.Decode(cpy) + t.Logf("s2 IPPacketHandler: received %s packet %v -> %v", parsed.IPProto, parsed.Src, parsed.Dst) + + select { + case receivedPackets <- cpy: + default: + t.Logf("s2: dropped packet, channel full") + } + return true // consume the packet + }, + } + if *verboseNodes { + s2.Logf = log.Printf + } + t.Cleanup(func() { s2.Close() }) + + status, err := s2.Up(ctx) + if err != nil { + t.Fatal(err) + } + s2ip := status.TailscaleIPs[0] + t.Logf("s2 IP: %v", s2ip) + + writer, err := s2.IPPacketWriter() + if err != nil { + t.Fatal(err) + } + + lc2 := must.Get(s2.LocalClient()) + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatalf("ping failed: %v", err) + } + t.Logf("ping success: latency=%v", res.LatencySeconds) + + const srcPort = 12345 + const dstPort = 9999 + payload := []byte("hello from raw packet") + + udpPkt := buildUDPPacket(t, s2ip, s1ip, srcPort, dstPort, payload) + t.Logf("s2: sending raw UDP packet (%d bytes) to %v:%d", len(udpPkt), s1ip, dstPort) + + if err := writer(udpPkt); err != nil { + t.Fatalf("IPPacketWriter failed: %v", err) + } + t.Log("s2: packet injected successfully") + + t.Log("s2: waiting for UDP echo reply") + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for UDP echo reply") + case pkt := <-receivedPackets: + t.Logf("s2: received raw packet (%d bytes)", len(pkt)) + + var parsed packet.Parsed + parsed.Decode(pkt) + if parsed.IPVersion != 4 { + t.Errorf("expected IPv4, got version %d", parsed.IPVersion) + } + if parsed.IPProto != ipproto.UDP { + t.Errorf("expected UDP, got protocol %d", parsed.IPProto) + } + if parsed.Src.Addr() != s1ip { + t.Errorf("expected src %v, got %v", s1ip, parsed.Src.Addr()) + } + if parsed.Dst.Addr() != s2ip { + t.Errorf("expected dst %v, got %v", s2ip, parsed.Dst.Addr()) + } + if parsed.Src.Port() != dstPort { + t.Errorf("expected src port %d, got %d", dstPort, parsed.Src.Port()) + } + if parsed.Dst.Port() != srcPort { + t.Errorf("expected dst port %d, got %d", srcPort, parsed.Dst.Port()) + } + if !bytes.Equal(parsed.Payload(), payload) { + t.Errorf("payload mismatch: got %q, want %q", parsed.Payload(), payload) + } + } + + t.Log("Testing PeerAPI access to raw packet server...") + s2Status := must.Get(lc2.StatusWithoutPeers(ctx)) + if len(s2Status.Self.PeerAPIURL) == 0 { + t.Fatal("s2 has no PeerAPI URLs") + } + peerAPIURL := s2Status.Self.PeerAPIURL[0] + t.Logf("s2 PeerAPI URL: %s", peerAPIURL) + + resp, err := s1.HTTPClient().Get(peerAPIURL) + if err != nil { + t.Fatalf("failed to GET s2 PeerAPI from s1: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("PeerAPI request returned status %d, want 200", resp.StatusCode) + } + t.Log("PeerAPI access to raw packet server succeeded") +} + +// buildUDPPacket constructs a complete IPv4 UDP packet with valid checksums. +func buildUDPPacket(t *testing.T, src, dst netip.Addr, srcPort, dstPort uint16, payload []byte) []byte { + t.Helper() + + if !src.Is4() || !dst.Is4() { + t.Fatal("buildUDPPacket only supports IPv4") + } + + const ipHeaderLen = 20 + const udpHeaderLen = 8 + totalLen := ipHeaderLen + udpHeaderLen + len(payload) + udpLen := udpHeaderLen + len(payload) + + pkt := make([]byte, totalLen) + + // IPv4 header + pkt[0] = 0x45 // Version (4) + IHL (5 = 20 bytes) + pkt[1] = 0 // DSCP + ECN + binary.BigEndian.PutUint16(pkt[2:4], uint16(totalLen)) // Total length + binary.BigEndian.PutUint16(pkt[4:6], 0) // Identification + binary.BigEndian.PutUint16(pkt[6:8], 0) // Flags + Fragment offset + pkt[8] = 64 // TTL + pkt[9] = uint8(ipproto.UDP) // Protocol + // pkt[10:12] = checksum (computed below) + copy(pkt[12:16], src.AsSlice()) + copy(pkt[16:20], dst.AsSlice()) + + // Compute IP header checksum + var sum uint32 + for i := 0; i < ipHeaderLen; i += 2 { + sum += uint32(binary.BigEndian.Uint16(pkt[i : i+2])) + } + for sum > 0xffff { + sum = (sum & 0xffff) + (sum >> 16) + } + binary.BigEndian.PutUint16(pkt[10:12], ^uint16(sum)) + + // UDP header + udp := pkt[ipHeaderLen:] + binary.BigEndian.PutUint16(udp[0:2], srcPort) + binary.BigEndian.PutUint16(udp[2:4], dstPort) + binary.BigEndian.PutUint16(udp[4:6], uint16(udpLen)) + // udp[6:8] = checksum (computed below with pseudo-header) + + // Payload + copy(udp[udpHeaderLen:], payload) + + // Compute UDP checksum with pseudo-header + // Pseudo-header: src IP (4) + dst IP (4) + zero (1) + protocol (1) + UDP length (2) = 12 bytes + var udpSum uint32 + // Add source IP + udpSum += uint32(binary.BigEndian.Uint16(src.AsSlice()[0:2])) + udpSum += uint32(binary.BigEndian.Uint16(src.AsSlice()[2:4])) + // Add dest IP + udpSum += uint32(binary.BigEndian.Uint16(dst.AsSlice()[0:2])) + udpSum += uint32(binary.BigEndian.Uint16(dst.AsSlice()[2:4])) + // Add protocol (UDP = 17) + udpSum += uint32(ipproto.UDP) + // Add UDP length + udpSum += uint32(udpLen) + // Add UDP header and payload + for i := 0; i < len(udp); i += 2 { + if i+1 < len(udp) { + udpSum += uint32(binary.BigEndian.Uint16(udp[i : i+2])) + } else { + udpSum += uint32(udp[i]) << 8 // odd byte + } + } + // Fold 32-bit sum to 16 bits + for udpSum > 0xffff { + udpSum = (udpSum & 0xffff) + (udpSum >> 16) + } + checksum := ^uint16(udpSum) + if checksum == 0 { + checksum = 0xffff // UDP checksum of 0 is transmitted as 0xffff + } + binary.BigEndian.PutUint16(udp[6:8], checksum) + return pkt +} + +// TestIPPacketHandlerSocketAPIGuards verifies that Listen, ListenPacket, +// ListenTLS, ListenFunnel, and Dial return ErrIPPacketHandlerSet when +// IPPacketHandler is set. +func TestIPPacketHandlerSocketAPIGuards(t *testing.T) { + s := &Server{ + IPPacketHandler: func([]byte) bool { return true }, + } + + _, err := s.Listen("tcp", ":0") + if !errors.Is(err, ErrIPPacketHandlerSet) { + t.Errorf("Listen: got %v, want ErrIPPacketHandlerSet", err) + } + + _, err = s.ListenPacket("udp", "127.0.0.1:0") + if !errors.Is(err, ErrIPPacketHandlerSet) { + t.Errorf("ListenPacket: got %v, want ErrIPPacketHandlerSet", err) + } + + _, err = s.ListenTLS("tcp", ":0") + if !errors.Is(err, ErrIPPacketHandlerSet) { + t.Errorf("ListenTLS: got %v, want ErrIPPacketHandlerSet", err) + } + + _, err = s.ListenFunnel("tcp", ":443") + if !errors.Is(err, ErrIPPacketHandlerSet) { + t.Errorf("ListenFunnel: got %v, want ErrIPPacketHandlerSet", err) + } + + _, err = s.Dial(context.Background(), "tcp", "127.0.0.1:80") + if !errors.Is(err, ErrIPPacketHandlerSet) { + t.Errorf("Dial: got %v, want ErrIPPacketHandlerSet", err) + } +} diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index c2b5d8a32..17f9cfd7d 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -21,6 +21,7 @@ import ( "time" "github.com/tailscale/wireguard-go/conn" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -119,8 +120,8 @@ func maxInFlightConnectionAttemptsPerClient() int { var debugNetstack = envknob.RegisterBool("TS_DEBUG_NETSTACK") var ( - serviceIP = tsaddr.TailscaleServiceIP() - serviceIPv6 = tsaddr.TailscaleServiceIPv6() + tsServiceIP = tsaddr.TailscaleServiceIP() + tsServiceIPv6 = tsaddr.TailscaleServiceIPv6() ) func init() { @@ -176,6 +177,22 @@ type Impl struct { // It can only be set before calling Start. ProcessSubnets bool + // HandleIPPacket, if non-nil, is called for incoming TCP/UDP packets that + // netstack will not process. + // + // The packet slice contains a complete IP packet starting with the IP header. + // The packet slice is only valid for the duration of the call and must not + // be retained. + // + // If HandleIPPacket returns true, the packet is consumed. If false, the packet + // is passed to the host network stack (if available). + // + // The handler will NOT see packets processed by netstack (local IPs, subnet IPs, + // PeerAPI, SSH, service IPs, or flows handled by GetTCPHandlerForFlow/GetUDPHandlerForFlow). + // + // The handler is called from the packet receive path and must not block. + HandleIPPacket func(packet []byte) bool + ipstack *stack.Stack linkEP *linkEndpoint tundev *tstun.Wrapper @@ -415,6 +432,16 @@ func (ns *Impl) Close() error { return nil } +// InjectOutbound sends an IP packet through the Tailscale network. +// The packet must be a complete IP packet starting with the IP header. +func (ns *Impl) InjectOutbound(pkt []byte) error { + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(pkt), + }) + // InjectOutboundPacketBuffer decrements the buffer reference count. + return ns.tundev.InjectOutboundPacketBuffer(packetBuf) +} + // SetTransportProtocolOption forwards to the underlying // [stack.Stack.SetTransportProtocolOption]. Callers are responsible for // ensuring that the options are valid, compatible and appropriate for their use @@ -772,7 +799,7 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper, gro *gro. // Determine if we care about this local packet. dst := p.Dst.Addr() switch { - case dst == serviceIP || dst == serviceIPv6: + case dst == tsServiceIP || dst == tsServiceIPv6: // We want to intercept some traffic to the "service IP" (e.g. // 100.100.100.100 for IPv4). However, of traffic to the // service IP, we only care about UDP 53, and TCP on port 53, @@ -994,13 +1021,13 @@ func (ns *Impl) shouldSendToHost(pkt *stack.PacketBuffer) bool { switch v := hdr.(type) { case header.IPv4: srcIP := netip.AddrFrom4(v.SourceAddress().As4()) - if serviceIP == srcIP { + if tsServiceIP == srcIP { return true } case header.IPv6: srcIP := netip.AddrFrom16(v.SourceAddress().As16()) - if srcIP == serviceIPv6 { + if srcIP == tsServiceIPv6 { return true } @@ -1064,7 +1091,7 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { // Handle incoming peerapi connections in netstack. dstIP := p.Dst.Addr() isLocal := ns.isLocalIP(dstIP) - isService := ns.isVIPServiceIP(dstIP) + isVIPService := ns.isVIPServiceIP(dstIP) // Handle TCP connection to the Tailscale IP(s) in some cases: if ns.lb != nil && p.IPProto == ipproto.TCP && isLocal { @@ -1087,7 +1114,20 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { return true } } - if buildfeatures.HasServe && isService { + hittingServiceIP := dstIP == tsServiceIP || dstIP == tsServiceIPv6 + if hittingServiceIP { + if p.IsEchoRequest() { + return true + } + if p.Dst.Port() == 53 { + return true + } + if ns.isLoopbackPort(p.Dst.Port()) { + return true + } + return false + } + if buildfeatures.HasServe && isVIPService { if p.IsEchoRequest() { return true } @@ -1103,6 +1143,13 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { if p.IPVersion == 6 && !isLocal && viaRange.Contains(dstIP) { return ns.lb != nil && ns.lb.ShouldHandleViaIP(dstIP) } + // If HandleIPPacket is set, don't process normal traffic in netstack. + // This allows the handler to receive all non-special packets. + // Special traffic (PeerAPI, SSH, service IP, VIP services) is still + // handled by netstack via the checks above. + if ns.HandleIPPacket != nil { + return false + } if ns.ProcessLocalIPs && isLocal { return true } @@ -1189,7 +1236,11 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper, gro *gro.GRO) } if !ns.shouldProcessInbound(p, t) { - // Let the host network stack (if any) deal with it. + if ns.HandleIPPacket != nil { + if ns.HandleIPPacket(p.Buffer()) { + return filter.DropSilently, gro + } + } return filter.Accept, gro } @@ -1392,7 +1443,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } // Local Services (DNS and WebDAV) - hittingServiceIP := dialIP == serviceIP || dialIP == serviceIPv6 + hittingServiceIP := dialIP == tsServiceIP || dialIP == tsServiceIPv6 hittingDNS := hittingServiceIP && reqDetails.LocalPort == 53 if hittingDNS { c := getConnOrReset() @@ -1433,7 +1484,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } switch { case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort): - if dialIP == serviceIPv6 { + if dialIP == tsServiceIPv6 { dialIP = ipv6Loopback } else { dialIP = ipv4Loopback @@ -1635,14 +1686,14 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { } // Handle magicDNS and loopback traffic (via UDP) here. - if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 { + if dst := dstAddr.Addr(); dst == tsServiceIP || dst == tsServiceIPv6 { switch { case dstAddr.Port() == 53: c := gonet.NewUDPConn(&wq, ep) go ns.handleMagicDNSUDP(srcAddr, c) return case ns.isLoopbackPort(dstAddr.Port()): - if dst == serviceIPv6 { + if dst == tsServiceIPv6 { dstAddr = netip.AddrPortFrom(ipv6Loopback, dstAddr.Port()) } else { dstAddr = netip.AddrPortFrom(ipv4Loopback, dstAddr.Port()) diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 93022811c..3eeb4d56b 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -847,13 +847,13 @@ func TestShouldSendToHost(t *testing.T) { // not over WireGuard. { name: "from_service_ip_to_localhost", - src: netip.AddrPortFrom(serviceIP, 53), + src: netip.AddrPortFrom(tsServiceIP, 53), dst: netip.MustParseAddrPort("127.0.0.1:9999"), want: true, }, { name: "from_service_ip_to_localhost_v6", - src: netip.AddrPortFrom(serviceIPv6, 53), + src: netip.AddrPortFrom(tsServiceIPv6, 53), dst: netip.MustParseAddrPort("[::1]:9999"), want: true, }, @@ -1019,3 +1019,279 @@ func makeUDP6PacketBuffer(src, dst netip.AddrPort) *stack.PacketBuffer { return pkt } + +func TestHandleIPPacket(t *testing.T) { + impl := makeNetstack(t, func(ns *Impl) {}) + + client := netip.MustParseAddr("100.64.1.2") + destAddr := netip.MustParseAddr("100.64.1.1") + pkt := tcp4syn(t, client, destAddr, 1234, 5678) + + var handlerCalled bool + var receivedPacket []byte + impl.HandleIPPacket = func(p []byte) bool { + handlerCalled = true + receivedPacket = append([]byte(nil), p...) + return true + } + + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.injectInbound(&parsed, impl.tundev, nil) + + if !handlerCalled { + t.Error("HandleIPPacket was not called") + } + if resp != filter.DropSilently { + t.Errorf("Expected DropSilently response when handler returns true, got %v", resp) + } + if len(receivedPacket) == 0 { + t.Error("Handler received empty packet") + } + if len(receivedPacket) < 20 { + t.Errorf("Packet too short: %d bytes", len(receivedPacket)) + } + ipVer := receivedPacket[0] >> 4 + if ipVer != 4 { + t.Errorf("Expected IPv4 packet (version 4), got version %d", ipVer) + } + + handlerCalled = false + receivedPacket = nil + impl.HandleIPPacket = func(p []byte) bool { + handlerCalled = true + receivedPacket = append([]byte(nil), p...) + return false + } + + var parsed2 packet.Parsed + parsed2.Decode(pkt) + resp2, _ := impl.injectInbound(&parsed2, impl.tundev, nil) + + if !handlerCalled { + t.Error("HandleIPPacket was not called on second test") + } + if resp2 != filter.Accept { + t.Errorf("Expected Accept when handler declines, got %v", resp2) + } + if len(receivedPacket) == 0 { + t.Error("Handler received empty packet on second test") + } + + impl.HandleIPPacket = nil + var parsed3 packet.Parsed + parsed3.Decode(pkt) + resp3, _ := impl.injectInbound(&parsed3, impl.tundev, nil) + + if resp3 != filter.Accept { + t.Errorf("Expected Accept with no handler, got %v", resp3) + } +} + +func TestHandleIPPacketIPv6(t *testing.T) { + impl := makeNetstack(t, func(ns *Impl) {}) + + src := netip.MustParseAddr("fd7a:115c:a1e0::2") + dst := netip.MustParseAddr("fd7a:115c:a1e0::1") + const tcpLen = header.TCPMinimumSize + ip := header.IPv6(make([]byte, header.IPv6MinimumSize+tcpLen)) + ip.Encode(&header.IPv6Fields{ + TransportProtocol: header.TCPProtocolNumber, + PayloadLength: tcpLen, + HopLimit: 64, + SrcAddr: tcpip.AddrFrom16(src.As16()), + DstAddr: tcpip.AddrFrom16(dst.As16()), + }) + + tcp := header.TCP(ip[header.IPv6MinimumSize:]) + tcp.Encode(&header.TCPFields{ + SrcPort: 1234, + DstPort: 5678, + SeqNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 65535, + }) + + pkt := []byte(ip) + + var receivedPacket []byte + impl.HandleIPPacket = func(p []byte) bool { + receivedPacket = append([]byte(nil), p...) + return true + } + + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.injectInbound(&parsed, impl.tundev, nil) + + if resp != filter.DropSilently { + t.Errorf("Expected DropSilently, got %v", resp) + } + if len(receivedPacket) < 40 { + t.Errorf("Packet too short for IPv6: %d bytes", len(receivedPacket)) + } + ipVer := receivedPacket[0] >> 4 + if ipVer != 6 { + t.Errorf("Expected IPv6 packet (version 6), got version %d", ipVer) + } +} + +func TestHandleIPPacketNotProcessed(t *testing.T) { + impl := makeNetstack(t, func(ns *Impl) { + ns.ProcessLocalIPs = false + ns.ProcessSubnets = false + }) + + var handlerCalled bool + impl.HandleIPPacket = func(p []byte) bool { + handlerCalled = true + return true + } + + client := netip.MustParseAddr("100.64.1.2") + destAddr := netip.MustParseAddr("100.64.1.1") + pkt := tcp4syn(t, client, destAddr, 1234, 5678) + + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.injectInbound(&parsed, impl.tundev, nil) + + if !handlerCalled { + t.Error("HandleIPPacket should be called when shouldProcessInbound returns false") + } + if resp != filter.DropSilently { + t.Errorf("Expected DropSilently when handler consumes packet, got %v", resp) + } +} + +func TestHandleIPPacketRealPacket(t *testing.T) { + impl := makeNetstack(t, func(ns *Impl) {}) + + client := netip.MustParseAddr("100.64.1.2") + destAddr := netip.MustParseAddr("100.64.1.1") + pkt := tcp4syn(t, client, destAddr, 1234, 5678) + + var parsed packet.Parsed + parsed.Decode(pkt) + if parsed.IPVersion != 4 { + t.Fatalf("Expected IPv4, got version %d", parsed.IPVersion) + } + if parsed.Src.Addr() != client { + t.Errorf("Expected src %v, got %v", client, parsed.Src.Addr()) + } + if parsed.Dst.Addr() != destAddr { + t.Errorf("Expected dst %v, got %v", destAddr, parsed.Dst.Addr()) + } + + var receivedPacket []byte + impl.HandleIPPacket = func(p []byte) bool { + receivedPacket = append([]byte(nil), p...) + return true + } + + resp, _ := impl.injectInbound(&parsed, impl.tundev, nil) + + if resp != filter.DropSilently { + t.Errorf("Expected DropSilently, got %v", resp) + } + + if len(receivedPacket) != len(pkt) { + t.Errorf("Packet length mismatch: got %d, want %d", len(receivedPacket), len(pkt)) + } + + var parsedFromHandler packet.Parsed + parsedFromHandler.Decode(receivedPacket) + if parsedFromHandler.IPVersion != 4 { + t.Errorf("Handler received non-IPv4 packet: version %d", parsedFromHandler.IPVersion) + } + if parsedFromHandler.Src.Addr() != client { + t.Errorf("Handler packet src mismatch: got %v, want %v", parsedFromHandler.Src.Addr(), client) + } +} + +// TestHandleIPPacketWithServices verifies that PeerAPI and Service IP packets +// don't reach the handler (they're handled by shouldProcessInbound). +func TestHandleIPPacketWithServices(t *testing.T) { + var handlerCalledFor []string + var handlerFunc = func(p []byte) bool { + var parsed packet.Parsed + parsed.Decode(p) + handlerCalledFor = append(handlerCalledFor, fmt.Sprintf("%v->%v", parsed.Src, parsed.Dst)) + return true + } + + client := netip.MustParseAddr("100.64.1.2") + selfIP := netip.MustParseAddr("100.64.1.1") + + tests := []struct { + name string + setupImpl func(*Impl) + dstAddr netip.Addr + dstPort uint16 + shouldReach bool + expectedResp filter.Response + description string + }{ + { + name: "peerapi_packet", + setupImpl: func(ns *Impl) { + ns.ProcessLocalIPs = true + ns.peerapiPort4Atomic.Store(5555) + }, + dstAddr: selfIP, + dstPort: 5555, + shouldReach: false, + expectedResp: filter.DropSilently, + description: "PeerAPI packets should NOT reach handler (handled by netstack)", + }, + { + name: "service_ip_dns", + setupImpl: func(ns *Impl) {}, + dstAddr: netip.MustParseAddr("100.100.100.100"), + dstPort: 53, + shouldReach: false, + expectedResp: filter.DropSilently, + description: "Service IP (DNS) packets should NOT reach handler", + }, + { + name: "normal_packet", + setupImpl: func(ns *Impl) {}, + dstAddr: selfIP, + dstPort: 8080, + shouldReach: true, + expectedResp: filter.DropSilently, + description: "Normal packets (not processed by netstack) should reach handler", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + impl := makeNetstack(t, tt.setupImpl) + impl.HandleIPPacket = handlerFunc + handlerCalledFor = nil + + pkt := tcp4syn(t, client, tt.dstAddr, 1234, tt.dstPort) + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.injectInbound(&parsed, impl.tundev, nil) + + handlerCalled := len(handlerCalledFor) > 0 + + if tt.shouldReach && !handlerCalled { + t.Errorf("%s: handler was not called, but should have been", tt.description) + } + if !tt.shouldReach && handlerCalled { + t.Errorf("%s: handler was called for %v, but should NOT have been", tt.description, handlerCalledFor) + } + + if resp != tt.expectedResp { + t.Errorf("Expected %v, got %v", tt.expectedResp, resp) + } + }) + } +} |
