diff options
| author | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
|---|---|---|
| committer | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
| commit | 0267fe83b200f1702a2fa0a395442c02a053fadb (patch) | |
| tree | 63654c55225eeb834de59a5a0bc8d19033c6145b /wgengine | |
| parent | 87546a5edf6b6503a87eeb2d666baba57398a066 (diff) | |
| download | tailscale-1.78.0.tar.xz tailscale-1.78.0.zip | |
VERSION.txt: this is v1.78.0 (tag: v1.78.0)
Signed-off-by: Nick Khyl <nickk@tailscale.com>
Diffstat (limited to 'wgengine')
| -rw-r--r-- | wgengine/bench/bench.go | 818 | ||||
| -rw-r--r-- | wgengine/bench/bench_test.go | 216 | ||||
| -rw-r--r-- | wgengine/bench/trafficgen.go | 518 | ||||
| -rw-r--r-- | wgengine/capture/capture.go | 476 | ||||
| -rw-r--r-- | wgengine/magicsock/blockforever_conn.go | 110 | ||||
| -rw-r--r-- | wgengine/magicsock/endpoint_default.go | 44 | ||||
| -rw-r--r-- | wgengine/magicsock/endpoint_stub.go | 26 | ||||
| -rw-r--r-- | wgengine/magicsock/endpoint_tracker.go | 496 | ||||
| -rw-r--r-- | wgengine/magicsock/magicsock_unix_test.go | 120 | ||||
| -rw-r--r-- | wgengine/magicsock/peermtu_darwin.go | 102 | ||||
| -rw-r--r-- | wgengine/magicsock/peermtu_linux.go | 98 | ||||
| -rw-r--r-- | wgengine/magicsock/peermtu_unix.go | 84 | ||||
| -rw-r--r-- | wgengine/mem_ios.go | 40 | ||||
| -rw-r--r-- | wgengine/netstack/netstack_linux.go | 38 | ||||
| -rw-r--r-- | wgengine/router/runner.go | 240 | ||||
| -rw-r--r-- | wgengine/watchdog_js.go | 34 | ||||
| -rw-r--r-- | wgengine/wgcfg/device.go | 136 | ||||
| -rw-r--r-- | wgengine/wgcfg/device_test.go | 522 | ||||
| -rw-r--r-- | wgengine/wgcfg/parser.go | 372 | ||||
| -rw-r--r-- | wgengine/winnet/winnet_windows.go | 52 |
20 files changed, 2271 insertions, 2271 deletions
diff --git a/wgengine/bench/bench.go b/wgengine/bench/bench.go index 8695f18d1..b94930ee5 100644 --- a/wgengine/bench/bench.go +++ b/wgengine/bench/bench.go @@ -1,409 +1,409 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "bufio" - "io" - "log" - "net" - "net/http" - "net/http/pprof" - "net/netip" - "os" - "strconv" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const PayloadSize = 1000 -const ICMPMinSize = 24 - -var Addr1 = netip.MustParsePrefix("100.64.1.1/32") -var Addr2 = netip.MustParsePrefix("100.64.1.2/32") - -func main() { - var logf logger.Logf = log.Printf - log.SetFlags(0) - - debugMux := newDebugMux() - go runDebugServer(debugMux, "0.0.0.0:8999") - - mode, err := strconv.Atoi(os.Args[1]) - if err != nil { - log.Fatalf("%q: %v", os.Args[1], err) - } - - traf := NewTrafficGen(nil) - - // Sample test results below are using GOMAXPROCS=2 (for some - // tests, including wireguard-go, higher GOMAXPROCS goes slower) - // on apenwarr's old Linux box: - // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz - // My 2019 Mac Mini is about 20% faster on most tests. - - switch mode { - // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) - case 1: - setupTrivialNoAllocTest(logf, traf) - - // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) - case 2: - setupTrivialTest(logf, traf) - - // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) - case 11: - setupBlockingChannelTest(logf, traf) - - // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) - // (much faster on macOS??) - case 12: - setupNonblockingChannelTest(logf, traf) - - // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) - // (much faster on macOS??) 
- case 13: - setupDoubleChannelTest(logf, traf) - - // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) - case 21: - setupUDPTest(logf, traf) - - // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) - case 31: - setupBatchTCPTest(logf, traf) - - // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) - case 101: - setupWGTest(nil, logf, traf, Addr1, Addr2) - - default: - log.Fatalf("provide a valid test number (0..n)") - } - - logf("initialized ok.") - traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) - - var cur, prev Snapshot - var pps int64 - i := 0 - for { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec == 0 { - logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) - } else { - logf("%v @%7d pkt/s", d, pps) - } - } - - pps = traf.Adjust() - } -} - -func newDebugMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - return mux -} - -func runDebugServer(mux *http.ServeMux, addr string) { - srv := &http.Server{ - Addr: addr, - Handler: mux, - } - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } -} - -// The absolute minimal test of the traffic generator: have it fill -// a packet buffer, then absorb it again. Zero packet loss. -func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { - go func() { - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Almost the same, but this time allocate a fresh buffer each time -// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. 
-func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { - go func() { - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Pass packets through a blocking channel between sender and receiver. -// Still zero packet loss since the sender stops when the channel is full. -// Max speed depends on channel length (I'm not sure why). -func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - ch <- b[0 : n+16] - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as setupBlockingChannelTest, but now we drop packets whenever the -// channel is full. Max speed is about the same as the above test, but -// now with nonzero packet loss. -func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as above, but at an intermediate blocking channel and goroutine -// to make things a little more like wireguard-go. Roughly 20% slower than -// the single-channel version. 
-func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - ch2 := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // intermediary - for b := range ch { - ch2 <- b - } - close(ch2) - }() - - go func() { - // receiver - for b := range ch2 { - traf.GotPacket(b, 16) - } - }() -} - -// Instead of a channel, pass packets through a UDP socket. -func setupUDPTest(logf logger.Logf, traf *TrafficGen) { - la, err := net.ResolveUDPAddr("udp", ":0") - if err != nil { - log.Fatalf("resolve: %v", err) - } - - s1, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen1: %v", err) - } - s2, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen2: %v", err) - } - - a2 := s2.LocalAddr() - - // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, - // which is what returns from .LocalAddr() above. We have to - // force it to localhost instead. - a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") - - s1.SetWriteBuffer(1024 * 1024) - s2.SetReadBuffer(1024 * 1024) - - go func() { - // transmitter - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - s1.WriteTo(b[16:n+16], a2) - } - }() - - go func() { - // receiver - b := make([]byte, 1600) - for traf.Running() { - // Use ReadFrom instead of Read, to be more like - // how wireguard-go does it, even though we're not - // going to actually look at the address. - n, _, err := s2.ReadFrom(b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} - -// Instead of a channel, pass packets through a TCP socket. -// TCP is a single stream, so we can amortize one syscall across -// multiple packets. 
10x amortization seems to make it go ~10x faster, -// as expected, getting us close to the speed of the channel tests above. -// There's also zero packet loss. -func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { - sl, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("listen: %v", err) - } - - var slCloseOnce sync.Once - slClose := func() { - slCloseOnce.Do(func() { - sl.Close() - }) - } - - s1, err := net.Dial("tcp", sl.Addr().String()) - if err != nil { - log.Fatalf("dial: %v", err) - } - - s2, err := sl.Accept() - if err != nil { - log.Fatalf("accept: %v", err) - } - - s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) - s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) - - ch := make(chan int) - - go func() { - // transmitter - defer slClose() - defer s1.Close() - - bs1 := bufio.NewWriterSize(s1, 1024*1024) - - b := make([]byte, 1600) - i := 0 - for { - i += 1 - n := traf.Generate(b, 16) - if n == 0 { - break - } - if i == 1 { - ch <- n - } - bs1.Write(b[16 : n+16]) - - // TODO: this is a pretty half-baked batching - // function, which we'd never want to employ in - // a real-life program. - // - // In real life, we'd probably want to flush - // immediately when there are no more packets to - // generate, and queue up only if we fall behind. - // - // In our case however, we just want to see the - // technical benefits of batching 10 syscalls - // into 1, so a fixed ratio makes more sense. - if (i % 10) == 0 { - bs1.Flush() - } - } - }() - - go func() { - // receiver - defer slClose() - defer s2.Close() - - bs2 := bufio.NewReaderSize(s2, 1024*1024) - - // Find out the packet size (we happen to know they're - // all the same size) - packetSize := <-ch - - b := make([]byte, packetSize) - for traf.Running() { - // TODO: can't use ReadFrom() here, which is - // unfair compared to UDP. (ReadFrom for UDP - // apparently allocates memory per packet, which - // this test does not.) 
- n, err := io.ReadFull(bs2, b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Create two wgengine instances and pass data through them, measuring
+// throughput, latency, and packet loss.
+package main
+
+import (
+ "bufio"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/http/pprof"
+ "net/netip"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+// PayloadSize is the per-packet payload size, in bytes, used by main; it is
+// added to ICMPMinSize to form the bytesPerPacket argument of Start.
+const PayloadSize = 1000
+// ICMPMinSize is the number of header bytes that precede the payload in a
+// generated packet. The 8-byte sequence number is stored at this offset.
+const ICMPMinSize = 24
+
+// Addr1 and Addr2 are the source and destination addresses passed to
+// TrafficGen.Start and to setupWGTest.
+var Addr1 = netip.MustParsePrefix("100.64.1.1/32")
+var Addr2 = netip.MustParsePrefix("100.64.1.2/32")
+
+// main runs a single benchmark scenario, selected by the test number given
+// as the first command-line argument, and then logs traffic statistics about
+// once per second until the process is killed.
+//
+// A pprof debug server is started on 0.0.0.0:8999 for the lifetime of the
+// process.
+func main() {
+	var logf logger.Logf = log.Printf
+	log.SetFlags(0)
+
+	// Indexing os.Args[1] without this check panics with "index out of
+	// range" when no test number is supplied; fail with the usage message
+	// instead.
+	if len(os.Args) < 2 {
+		log.Fatalf("provide a valid test number (0..n)")
+	}
+
+	debugMux := newDebugMux()
+	go runDebugServer(debugMux, "0.0.0.0:8999")
+
+	mode, err := strconv.Atoi(os.Args[1])
+	if err != nil {
+		log.Fatalf("%q: %v", os.Args[1], err)
+	}
+
+	// The generator is created locked; the setup functions below start
+	// goroutines that block in Generate until traf.Start is called.
+	traf := NewTrafficGen(nil)
+
+	// Sample test results below are using GOMAXPROCS=2 (for some
+	// tests, including wireguard-go, higher GOMAXPROCS goes slower)
+	// on apenwarr's old Linux box:
+	// Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz
+	// My 2019 Mac Mini is about 20% faster on most tests.
+
+	switch mode {
+	// tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec)
+	case 1:
+		setupTrivialNoAllocTest(logf, traf)
+
+	// tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec)
+	case 2:
+		setupTrivialTest(logf, traf)
+
+	// tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec)
+	case 11:
+		setupBlockingChannelTest(logf, traf)
+
+	// tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec)
+	// (much faster on macOS??)
+	case 12:
+		setupNonblockingChannelTest(logf, traf)
+
+	// tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec)
+	// (much faster on macOS??)
+	case 13:
+		setupDoubleChannelTest(logf, traf)
+
+	// tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec)
+	case 21:
+		setupUDPTest(logf, traf)
+
+	// tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec)
+	case 31:
+		setupBatchTCPTest(logf, traf)
+
+	// tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec)
+	case 101:
+		setupWGTest(nil, logf, traf, Addr1, Addr2)
+
+	default:
+		log.Fatalf("provide a valid test number (0..n)")
+	}
+
+	logf("initialized ok.")
+	// maxPackets of 0 means the generator never finishes on its own.
+	traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0)
+
+	var cur, prev Snapshot
+	var pps int64
+	i := 0
+	for {
+		i++
+		time.Sleep(10 * time.Millisecond)
+
+		// Every 100 ticks of 10ms (about once per second), log the
+		// traffic seen since the previous report.
+		if (i % 100) == 0 {
+			prev = cur
+			cur = traf.Snap()
+			d := cur.Sub(prev)
+
+			if prev.WhenNsec == 0 {
+				// First report: there is no earlier snapshot, so
+				// no meaningful duration to compute rates from.
+				logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets)
+			} else {
+				logf("%v @%7d pkt/s", d, pps)
+			}
+		}
+
+		// Re-tune the transmit rate on every tick.
+		pps = traf.Adjust()
+	}
+}
+
+// newDebugMux returns a new ServeMux with the net/http/pprof index, cmdline,
+// profile, symbol and trace handlers registered under /debug/pprof/.
+func newDebugMux() *http.ServeMux {
+	routes := []struct {
+		pattern string
+		handler func(http.ResponseWriter, *http.Request)
+	}{
+		{"/debug/pprof/", pprof.Index},
+		{"/debug/pprof/cmdline", pprof.Cmdline},
+		{"/debug/pprof/profile", pprof.Profile},
+		{"/debug/pprof/symbol", pprof.Symbol},
+		{"/debug/pprof/trace", pprof.Trace},
+	}
+
+	mux := http.NewServeMux()
+	for _, r := range routes {
+		mux.HandleFunc(r.pattern, r.handler)
+	}
+	return mux
+}
+
+// runDebugServer serves mux on addr and blocks while the server is running.
+// Any error returned by ListenAndServe ends the process via log.Fatal.
+func runDebugServer(mux *http.ServeMux, addr string) {
+	srv := &http.Server{Addr: addr, Handler: mux}
+	err := srv.ListenAndServe()
+	if err != nil {
+		log.Fatal(err)
+	}
+}
+
+// The absolute minimal test of the traffic generator: have it fill
+// a packet buffer, then absorb it again. Zero packet loss.
+func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) {
+ go func() {
+ b := make([]byte, 1600)
+ for {
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ break
+ }
+ traf.GotPacket(b[0:n+16], 16)
+ }
+ }()
+}
+
+// Almost the same, but this time allocate a fresh buffer each time
+// through the loop. Still zero packet loss. Runs about 2/3 as fast for me.
+func setupTrivialTest(logf logger.Logf, traf *TrafficGen) {
+ go func() {
+ for {
+ b := make([]byte, 1600)
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ break
+ }
+ traf.GotPacket(b[0:n+16], 16)
+ }
+ }()
+}
+
+// Pass packets through a blocking channel between sender and receiver.
+// Still zero packet loss since the sender stops when the channel is full.
+// Max speed depends on channel length (I'm not sure why).
+func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) {
+ ch := make(chan []byte, 1000)
+
+ go func() {
+ // transmitter
+ for {
+ b := make([]byte, 1600)
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ close(ch)
+ break
+ }
+ ch <- b[0 : n+16]
+ }
+ }()
+
+ go func() {
+ // receiver
+ for b := range ch {
+ traf.GotPacket(b, 16)
+ }
+ }()
+}
+
+// Same as setupBlockingChannelTest, but now we drop packets whenever the
+// channel is full. Max speed is about the same as the above test, but
+// now with nonzero packet loss.
+func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) {
+ ch := make(chan []byte, 1000)
+
+ go func() {
+ // transmitter
+ for {
+ b := make([]byte, 1600)
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ close(ch)
+ break
+ }
+ select {
+ case ch <- b[0 : n+16]:
+ default:
+ }
+ }
+ }()
+
+ go func() {
+ // receiver
+ for b := range ch {
+ traf.GotPacket(b, 16)
+ }
+ }()
+}
+
+// Same as above, but at an intermediate blocking channel and goroutine
+// to make things a little more like wireguard-go. Roughly 20% slower than
+// the single-channel version.
+func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) {
+ ch := make(chan []byte, 1000)
+ ch2 := make(chan []byte, 1000)
+
+ go func() {
+ // transmitter
+ for {
+ b := make([]byte, 1600)
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ close(ch)
+ break
+ }
+ select {
+ case ch <- b[0 : n+16]:
+ default:
+ }
+ }
+ }()
+
+ go func() {
+ // intermediary
+ for b := range ch {
+ ch2 <- b
+ }
+ close(ch2)
+ }()
+
+ go func() {
+ // receiver
+ for b := range ch2 {
+ traf.GotPacket(b, 16)
+ }
+ }()
+}
+
+// setupUDPTest passes packets through a pair of UDP sockets instead of a
+// channel: a transmitter goroutine writes each generated packet from s1 to
+// s2's address, and a receiver goroutine reads them back from s2.
+//
+// Failure to resolve or listen, and any read error, terminates the process
+// via log.Fatalf.
+//
+// NOTE(review): neither socket is closed anywhere in this function, and the
+// receiver only re-checks traf.Running() between reads, so it can remain
+// blocked in ReadFrom after the transmitter has stopped.
+func setupUDPTest(logf logger.Logf, traf *TrafficGen) {
+ la, err := net.ResolveUDPAddr("udp", ":0")
+ if err != nil {
+ log.Fatalf("resolve: %v", err)
+ }
+
+ s1, err := net.ListenUDP("udp", la)
+ if err != nil {
+ log.Fatalf("listen1: %v", err)
+ }
+ s2, err := net.ListenUDP("udp", la)
+ if err != nil {
+ log.Fatalf("listen2: %v", err)
+ }
+
+ a2 := s2.LocalAddr()
+
+ // On macOS (but not Linux), you can't transmit to 0.0.0.0:port,
+ // which is what returns from .LocalAddr() above. We have to
+ // force it to localhost instead.
+ a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1")
+
+ // Request 1 MiB socket buffers. The returned errors are ignored, so
+ // this is best-effort.
+ s1.SetWriteBuffer(1024 * 1024)
+ s2.SetReadBuffer(1024 * 1024)
+
+ go func() {
+ // transmitter
+ b := make([]byte, 1600)
+ for {
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ break
+ }
+ // Only the packet itself goes on the wire: the 16-byte
+ // offset region at the front of b is skipped. Write
+ // errors are ignored.
+ s1.WriteTo(b[16:n+16], a2)
+ }
+ }()
+
+ go func() {
+ // receiver
+ b := make([]byte, 1600)
+ for traf.Running() {
+ // Use ReadFrom instead of Read, to be more like
+ // how wireguard-go does it, even though we're not
+ // going to actually look at the address.
+ n, _, err := s2.ReadFrom(b)
+ if err != nil {
+ log.Fatalf("s2.Read: %v", err)
+ }
+ // The received datagram starts at offset 0, hence ofs=0.
+ traf.GotPacket(b[:n], 0)
+ }
+ }()
+}
+
+// setupBatchTCPTest passes packets through a TCP socket instead of a channel.
+// TCP is a single stream, so we can amortize one syscall across
+// multiple packets. 10x amortization seems to make it go ~10x faster,
+// as expected, getting us close to the speed of the channel tests above.
+// There's also zero packet loss.
+//
+// A listener is opened on an ephemeral port, then dialed and accepted within
+// this function to obtain the connected pair s1 (writer) and s2 (reader).
+// Listen, dial, accept and read errors terminate the process via log.Fatalf.
+func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) {
+ sl, err := net.Listen("tcp", ":0")
+ if err != nil {
+ log.Fatalf("listen: %v", err)
+ }
+
+ // Both goroutines below defer slClose; the Once ensures the listener
+ // is closed a single time, by whichever goroutine exits first.
+ var slCloseOnce sync.Once
+ slClose := func() {
+ slCloseOnce.Do(func() {
+ sl.Close()
+ })
+ }
+
+ s1, err := net.Dial("tcp", sl.Addr().String())
+ if err != nil {
+ log.Fatalf("dial: %v", err)
+ }
+
+ s2, err := sl.Accept()
+ if err != nil {
+ log.Fatalf("accept: %v", err)
+ }
+
+ // Request 1 MiB socket buffers. The returned errors are ignored.
+ s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024)
+ s2.(*net.TCPConn).SetReadBuffer(1024 * 1024)
+
+ // ch carries the size of the first generated packet from the
+ // transmitter to the receiver. It is unbuffered, so the transmitter
+ // waits on its first packet until the receiver has picked the size up.
+ ch := make(chan int)
+
+ go func() {
+ // transmitter
+ defer slClose()
+ defer s1.Close()
+
+ bs1 := bufio.NewWriterSize(s1, 1024*1024)
+
+ b := make([]byte, 1600)
+ i := 0
+ for {
+ i += 1
+ n := traf.Generate(b, 16)
+ if n == 0 {
+ // No Flush here: whatever is still buffered in
+ // bs1 is not written before s1 is closed.
+ break
+ }
+ if i == 1 {
+ ch <- n
+ }
+ // The Write error is ignored.
+ bs1.Write(b[16 : n+16])
+
+ // TODO: this is a pretty half-baked batching
+ // function, which we'd never want to employ in
+ // a real-life program.
+ //
+ // In real life, we'd probably want to flush
+ // immediately when there are no more packets to
+ // generate, and queue up only if we fall behind.
+ //
+ // In our case however, we just want to see the
+ // technical benefits of batching 10 syscalls
+ // into 1, so a fixed ratio makes more sense.
+ if (i % 10) == 0 {
+ bs1.Flush()
+ }
+ }
+ }()
+
+ go func() {
+ // receiver
+ defer slClose()
+ defer s2.Close()
+
+ bs2 := bufio.NewReaderSize(s2, 1024*1024)
+
+ // Find out the packet size (we happen to know they're
+ // all the same size)
+ packetSize := <-ch
+
+ b := make([]byte, packetSize)
+ for traf.Running() {
+ // TODO: can't use ReadFrom() here, which is
+ // unfair compared to UDP. (ReadFrom for UDP
+ // apparently allocates memory per packet, which
+ // this test does not.)
+ // ReadFull reads exactly one fixed-size packet from
+ // the stream.
+ n, err := io.ReadFull(bs2, b)
+ if err != nil {
+ log.Fatalf("s2.Read: %v", err)
+ }
+ traf.GotPacket(b[:n], 0)
+ }
+ }()
+}
diff --git a/wgengine/bench/bench_test.go b/wgengine/bench/bench_test.go index 4fae86c05..42571d055 100644 --- a/wgengine/bench/bench_test.go +++ b/wgengine/bench/bench_test.go @@ -1,108 +1,108 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "fmt" - "testing" - "time" - - "tailscale.com/types/logger" -) - -func BenchmarkTrivialNoAlloc(b *testing.B) { - run(b, setupTrivialNoAllocTest) -} -func BenchmarkTrivial(b *testing.B) { - run(b, setupTrivialTest) -} - -func BenchmarkBlockingChannel(b *testing.B) { - run(b, setupBlockingChannelTest) -} - -func BenchmarkNonblockingChannel(b *testing.B) { - run(b, setupNonblockingChannelTest) -} - -func BenchmarkDoubleChannel(b *testing.B) { - run(b, setupDoubleChannelTest) -} - -func BenchmarkUDP(b *testing.B) { - run(b, setupUDPTest) -} - -func BenchmarkBatchTCP(b *testing.B) { - run(b, setupBatchTCPTest) -} - -func BenchmarkWireGuardTest(b *testing.B) { - b.Skip("https://github.com/tailscale/tailscale/issues/2716") - run(b, func(logf logger.Logf, traf *TrafficGen) { - setupWGTest(b, logf, traf, Addr1, Addr2) - }) -} - -type SetupFunc func(logger.Logf, *TrafficGen) - -func run(b *testing.B, setup SetupFunc) { - sizes := []int{ - ICMPMinSize + 8, - ICMPMinSize + 100, - ICMPMinSize + 1000, - } - - for _, size := range sizes { - b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { - runOnce(b, setup, size) - }) - } -} - -func runOnce(b *testing.B, setup SetupFunc, payload int) { - b.StopTimer() - b.ReportAllocs() - - var logf logger.Logf = b.Logf - if !testing.Verbose() { - logf = logger.Discard - } - - traf := NewTrafficGen(b.StartTimer) - setup(logf, traf) - - logf("initialized. 
(n=%v)", b.N) - b.SetBytes(int64(payload)) - - traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) - - var cur, prev Snapshot - var pps int64 - i := 0 - for traf.Running() { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec != 0 { - logf("%v @%7d pkt/sec", d, pps) - } - } - - pps = traf.Adjust() - } - - cur = traf.Snap() - d := cur.Sub(prev) - loss := float64(d.LostPackets) / float64(d.RxPackets) - - b.ReportMetric(loss*100, "%lost") -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Create two wgengine instances and pass data through them, measuring
+// throughput, latency, and packet loss.
+package main
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+// BenchmarkTrivialNoAlloc runs the setupTrivialNoAllocTest scenario at each
+// packet size used by run.
+func BenchmarkTrivialNoAlloc(b *testing.B) {
+ run(b, setupTrivialNoAllocTest)
+}
+// BenchmarkTrivial runs the setupTrivialTest scenario at each packet size
+// used by run.
+func BenchmarkTrivial(b *testing.B) {
+ run(b, setupTrivialTest)
+}
+
+// BenchmarkBlockingChannel runs the setupBlockingChannelTest scenario at
+// each packet size used by run.
+func BenchmarkBlockingChannel(b *testing.B) {
+ run(b, setupBlockingChannelTest)
+}
+
+// BenchmarkNonblockingChannel runs the setupNonblockingChannelTest scenario
+// at each packet size used by run.
+func BenchmarkNonblockingChannel(b *testing.B) {
+ run(b, setupNonblockingChannelTest)
+}
+
+// BenchmarkDoubleChannel runs the setupDoubleChannelTest scenario at each
+// packet size used by run.
+func BenchmarkDoubleChannel(b *testing.B) {
+ run(b, setupDoubleChannelTest)
+}
+
+// BenchmarkUDP runs the setupUDPTest scenario at each packet size used by
+// run.
+func BenchmarkUDP(b *testing.B) {
+ run(b, setupUDPTest)
+}
+
+// BenchmarkBatchTCP runs the setupBatchTCPTest scenario at each packet size
+// used by run.
+func BenchmarkBatchTCP(b *testing.B) {
+ run(b, setupBatchTCPTest)
+}
+
+// BenchmarkWireGuardTest would run the setupWGTest scenario between Addr1
+// and Addr2, but is currently skipped unconditionally; see the linked issue.
+func BenchmarkWireGuardTest(b *testing.B) {
+ // b.Skip stops this benchmark here, so the run call below is never
+ // reached while the skip is in place.
+ b.Skip("https://github.com/tailscale/tailscale/issues/2716")
+ run(b, func(logf logger.Logf, traf *TrafficGen) {
+ setupWGTest(b, logf, traf, Addr1, Addr2)
+ })
+}
+
+// SetupFunc wires a benchmark scenario to a traffic generator. It is called
+// by runOnce before the generator is started.
+type SetupFunc func(logger.Logf, *TrafficGen)
+
+// run executes setup as a sub-benchmark for each of three packet sizes:
+// ICMPMinSize plus 8, 100 and 1000 bytes. Each sub-benchmark is named after
+// its total packet size.
+func run(b *testing.B, setup SetupFunc) {
+	for _, extra := range [...]int{8, 100, 1000} {
+		size := ICMPMinSize + extra
+		b.Run(fmt.Sprintf("%d", size), func(sub *testing.B) {
+			runOnce(sub, setup, size)
+		})
+	}
+}
+
+// runOnce runs one benchmark scenario with packets of payload bytes until the
+// traffic generator reports it is no longer running, then reports the packet
+// loss as a "%lost" metric.
+//
+// The benchmark timer is stopped up front and restarted by the generator's
+// first-packet callback, so scenario setup is not included in the timing.
+// Progress is logged only when tests run in verbose mode.
+func runOnce(b *testing.B, setup SetupFunc, payload int) {
+	b.StopTimer()
+	b.ReportAllocs()
+
+	var logf logger.Logf = logger.Discard
+	if testing.Verbose() {
+		logf = b.Logf
+	}
+
+	traf := NewTrafficGen(b.StartTimer)
+	setup(logf, traf)
+
+	logf("initialized. (n=%v)", b.N)
+	b.SetBytes(int64(payload))
+
+	// The generator stops after b.N packets have been received.
+	traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N))
+
+	var latest, earlier Snapshot
+	var rate int64
+	for tick := 1; traf.Running(); tick++ {
+		time.Sleep(10 * time.Millisecond)
+
+		// Every 100 ticks of 10ms, take a snapshot and log the traffic
+		// since the previous one (if there was a previous one).
+		if tick%100 == 0 {
+			earlier = latest
+			latest = traf.Snap()
+			if earlier.WhenNsec != 0 {
+				logf("%v @%7d pkt/sec", latest.Sub(earlier), rate)
+			}
+		}
+
+		rate = traf.Adjust()
+	}
+
+	// Loss is measured from the older of the last two periodic snapshots
+	// to the end of the run, as lost packets relative to received packets.
+	final := traf.Snap().Sub(earlier)
+	lossRatio := float64(final.LostPackets) / float64(final.RxPackets)
+
+	b.ReportMetric(lossRatio*100, "%lost")
+}
diff --git a/wgengine/bench/trafficgen.go b/wgengine/bench/trafficgen.go index ce79c616f..9de3c2e6b 100644 --- a/wgengine/bench/trafficgen.go +++ b/wgengine/bench/trafficgen.go @@ -1,259 +1,259 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/binary" - "fmt" - "log" - "net/netip" - "sync" - "time" - - "tailscale.com/net/packet" - "tailscale.com/types/ipproto" -) - -type Snapshot struct { - WhenNsec int64 // current time - timeAcc int64 // accumulated time (+NSecPerTx per transmit) - - LastSeqTx int64 // last sequence number sent - LastSeqRx int64 // last sequence number received - TotalLost int64 // packets out-of-order or lost so far - TotalOOO int64 // packets out-of-order so far - TotalBytesRx int64 // total bytes received so far -} - -type Delta struct { - DurationNsec int64 - TxPackets int64 - RxPackets int64 - LostPackets int64 - OOOPackets int64 - Bytes int64 -} - -func (b Snapshot) Sub(a Snapshot) Delta { - return Delta{ - DurationNsec: b.WhenNsec - a.WhenNsec, - TxPackets: b.LastSeqTx - a.LastSeqTx, - RxPackets: (b.LastSeqRx - a.LastSeqRx) - - (b.TotalLost - a.TotalLost) + - (b.TotalOOO - a.TotalOOO), - LostPackets: b.TotalLost - a.TotalLost, - OOOPackets: b.TotalOOO - a.TotalOOO, - Bytes: b.TotalBytesRx - a.TotalBytesRx, - } -} - -func (d Delta) String() string { - return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", - d.TxPackets, d.RxPackets, d.LostPackets, - float64(d.LostPackets)*100/float64(d.TxPackets), - d.OOOPackets, - float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) -} - -type TrafficGen struct { - mu sync.Mutex - cur, prev Snapshot // snapshots used for rate control - buf []byte // pre-generated packet buffer - done bool // true if the test has completed - - onFirstPacket func() // function to call on first received packet - - // maxPackets is the max packets to receive (not send) before - // ending the test. 
If it's zero, the test runs forever. - maxPackets int64 - - // nsPerPacket is the target average nanoseconds between packets. - // It's initially zero, which means transmit as fast as the - // caller wants to go. - nsPerPacket int64 - - // ppsHistory is the observed packets-per-second from recent - // samples. - ppsHistory [5]int64 -} - -// NewTrafficGen creates a new, initially locked, TrafficGen. -// Until Start() is called, Generate() will block forever. -func NewTrafficGen(onFirstPacket func()) *TrafficGen { - t := TrafficGen{ - onFirstPacket: onFirstPacket, - } - - // initially locked, until first Start() - t.mu.Lock() - - return &t -} - -// Start starts the traffic generator. It assumes mu is already locked, -// and unlocks it. -func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { - h12 := packet.ICMP4Header{ - IP4Header: packet.IP4Header{ - IPProto: ipproto.ICMPv4, - IPID: 0, - Src: src, - Dst: dst, - }, - Type: packet.ICMP4EchoRequest, - Code: packet.ICMP4NoCode, - } - - // ensure there's room for ICMP header plus sequence number - if bytesPerPacket < ICMPMinSize+8 { - log.Fatalf("bytesPerPacket must be > 24+8") - } - - t.maxPackets = maxPackets - - payload := make([]byte, bytesPerPacket-ICMPMinSize) - t.buf = packet.Generate(h12, payload) - - t.mu.Unlock() -} - -func (t *TrafficGen) Snap() Snapshot { - t.mu.Lock() - defer t.mu.Unlock() - - t.cur.WhenNsec = time.Now().UnixNano() - return t.cur -} - -func (t *TrafficGen) Running() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return !t.done -} - -// Generate produces the next packet in the sequence. It sleeps if -// it's too soon for the next packet to be sent. -// -// The generated packet is placed into buf at offset ofs, for compatibility -// with the wireguard-go conventions. -// -// The return value is the number of bytes generated in the packet, or 0 -// if the test has finished running. 
-func (t *TrafficGen) Generate(b []byte, ofs int) int { - t.mu.Lock() - - now := time.Now().UnixNano() - if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { - t.cur.timeAcc = now - 1 - } - if t.cur.timeAcc >= now { - // too soon - t.mu.Unlock() - time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) - t.mu.Lock() - - now = t.cur.timeAcc - } - if t.done { - t.mu.Unlock() - return 0 - } - - t.cur.timeAcc += t.nsPerPacket - t.cur.LastSeqTx += 1 - t.cur.WhenNsec = now - seq := t.cur.LastSeqTx - - t.mu.Unlock() - - copy(b[ofs:], t.buf) - binary.BigEndian.PutUint64( - b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], - uint64(seq)) - - return len(t.buf) -} - -// GotPacket processes a packet that came back on the receive side. -func (t *TrafficGen) GotPacket(b []byte, ofs int) { - t.mu.Lock() - defer t.mu.Unlock() - - s := &t.cur - seq := int64(binary.BigEndian.Uint64( - b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) - if seq > s.LastSeqRx { - if s.LastSeqRx > 0 { - // only count lost packets after the very first - // successful one. - s.TotalLost += seq - s.LastSeqRx - 1 - } - s.LastSeqRx = seq - } else { - s.TotalOOO += 1 - } - - // +1 packet since we only start counting after the first one - if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { - t.done = true - } - s.TotalBytesRx += int64(len(b) - ofs) - - f := t.onFirstPacket - t.onFirstPacket = nil - if f != nil { - f() - } -} - -// Adjust tunes the transmit rate based on the received packets. -// The goal is to converge on the fastest transmit rate that still has -// minimal packet loss. Returns the new target rate in packets/sec. -// -// We need to play this guessing game in order to balance out tx and rx -// rates when there's a lossy network between them. Otherwise we can end -// up using 99% of the CPU to blast out transmitted packets and leaving only -// 1% to receive them, leading to a misleading throughput calculation. -// -// Call this function multiple times per second. 
-func (t *TrafficGen) Adjust() (pps int64) { - t.mu.Lock() - defer t.mu.Unlock() - - d := t.cur.Sub(t.prev) - - // don't adjust rate until the first full period *after* receiving - // the first packet. This skips any handshake time in the underlying - // transport. - if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { - t.prev = t.cur - return 0 // no estimate yet, continue at max speed - } - - pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) - - // We use a rate selection algorithm based loosely on TCP BBR. - // Basically, we set the transmit rate to be a bit higher than - // the best observed transmit rate in the last several time - // periods. This guarantees some packet loss, but should converge - // quickly on a rate near the sustainable maximum. - bestPPS := pps - for _, p := range t.ppsHistory { - if p > bestPPS { - bestPPS = p - } - } - if pps > 0 && t.prev.WhenNsec > 0 { - copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) - t.ppsHistory[0] = pps - } - if bestPPS > 0 { - pps = bestPPS * 103 / 100 - t.nsPerPacket = int64(1e9 / pps) - } - t.prev = t.cur - - return pps -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "encoding/binary"
+ "fmt"
+ "log"
+ "net/netip"
+ "sync"
+ "time"
+
+ "tailscale.com/net/packet"
+ "tailscale.com/types/ipproto"
+)
+
+// Snapshot is a point-in-time copy of the traffic generator's cumulative
+// counters. Subtracting an earlier Snapshot from a later one with Sub gives
+// the activity between the two.
+type Snapshot struct {
+ WhenNsec int64 // time of the snapshot, in Unix nanoseconds
+ timeAcc int64 // accumulated time (+nsPerPacket per transmit)
+
+ LastSeqTx int64 // last sequence number sent
+ LastSeqRx int64 // last sequence number received
+ TotalLost int64 // packets out-of-order or lost so far
+ TotalOOO int64 // packets out-of-order so far
+ TotalBytesRx int64 // total bytes received so far
+}
+
+type Delta struct {
+ DurationNsec int64
+ TxPackets int64
+ RxPackets int64
+ LostPackets int64
+ OOOPackets int64
+ Bytes int64
+}
+
+func (b Snapshot) Sub(a Snapshot) Delta {
+ return Delta{
+ DurationNsec: b.WhenNsec - a.WhenNsec,
+ TxPackets: b.LastSeqTx - a.LastSeqTx,
+ RxPackets: (b.LastSeqRx - a.LastSeqRx) -
+ (b.TotalLost - a.TotalLost) +
+ (b.TotalOOO - a.TotalOOO),
+ LostPackets: b.TotalLost - a.TotalLost,
+ OOOPackets: b.TotalOOO - a.TotalOOO,
+ Bytes: b.TotalBytesRx - a.TotalBytesRx,
+ }
+}
+
+func (d Delta) String() string {
+ return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)",
+ d.TxPackets, d.RxPackets, d.LostPackets,
+ float64(d.LostPackets)*100/float64(d.TxPackets),
+ d.OOOPackets,
+ float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6)
+}
+
// TrafficGen produces sequence-numbered test packets at a controlled rate
// and keeps counters about the packets that come back.
type TrafficGen struct {
	mu        sync.Mutex // locked by NewTrafficGen and first released by Start
	cur, prev Snapshot   // snapshots used for rate control
	buf       []byte     // pre-generated packet buffer
	done      bool       // true if the test has completed

	onFirstPacket func() // function to call on first received packet

	// maxPackets is the max packets to receive (not send) before
	// ending the test. If it's zero, the test runs forever.
	maxPackets int64

	// nsPerPacket is the target average nanoseconds between packets.
	// It's initially zero, which means transmit as fast as the
	// caller wants to go.
	nsPerPacket int64

	// ppsHistory is the observed packets-per-second from recent
	// samples, most recent first.
	ppsHistory [5]int64
}
+
+// NewTrafficGen creates a new, initially locked, TrafficGen.
+// Until Start() is called, Generate() will block forever.
+func NewTrafficGen(onFirstPacket func()) *TrafficGen {
+ t := TrafficGen{
+ onFirstPacket: onFirstPacket,
+ }
+
+ // initially locked, until first Start()
+ t.mu.Lock()
+
+ return &t
+}
+
+// Start starts the traffic generator. It assumes mu is already locked,
+// and unlocks it.
+func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) {
+ h12 := packet.ICMP4Header{
+ IP4Header: packet.IP4Header{
+ IPProto: ipproto.ICMPv4,
+ IPID: 0,
+ Src: src,
+ Dst: dst,
+ },
+ Type: packet.ICMP4EchoRequest,
+ Code: packet.ICMP4NoCode,
+ }
+
+ // ensure there's room for ICMP header plus sequence number
+ if bytesPerPacket < ICMPMinSize+8 {
+ log.Fatalf("bytesPerPacket must be > 24+8")
+ }
+
+ t.maxPackets = maxPackets
+
+ payload := make([]byte, bytesPerPacket-ICMPMinSize)
+ t.buf = packet.Generate(h12, payload)
+
+ t.mu.Unlock()
+}
+
+func (t *TrafficGen) Snap() Snapshot {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.cur.WhenNsec = time.Now().UnixNano()
+ return t.cur
+}
+
+func (t *TrafficGen) Running() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ return !t.done
+}
+
// Generate produces the next packet in the sequence. It sleeps if
// it's too soon for the next packet to be sent.
//
// The generated packet is placed into b at offset ofs, for compatibility
// with the wireguard-go conventions.
//
// The return value is the number of bytes generated in the packet, or 0
// if the test has finished running.
func (t *TrafficGen) Generate(b []byte, ofs int) int {
	t.mu.Lock()

	now := time.Now().UnixNano()
	// With no rate limit, or on the very first call, make the packet
	// due immediately by putting the schedule just in the past.
	if t.nsPerPacket == 0 || t.cur.timeAcc == 0 {
		t.cur.timeAcc = now - 1
	}
	if t.cur.timeAcc >= now {
		// too soon
		// The lock is released while sleeping so that receivers and
		// Adjust are not blocked.
		// NOTE(review): t.cur.timeAcc is read for the Sleep argument
		// after Unlock; this looks safe only with a single sender.
		t.mu.Unlock()
		time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond)
		t.mu.Lock()

		now = t.cur.timeAcc
	}
	if t.done {
		t.mu.Unlock()
		return 0
	}

	// Schedule the next packet and claim the next sequence number.
	t.cur.timeAcc += t.nsPerPacket
	t.cur.LastSeqTx += 1
	t.cur.WhenNsec = now
	seq := t.cur.LastSeqTx

	t.mu.Unlock()

	// Copy the packet template, then write the sequence number,
	// big-endian, into the 8 bytes that follow the first ICMPMinSize
	// bytes of the packet.
	copy(b[ofs:], t.buf)
	binary.BigEndian.PutUint64(
		b[ofs+ICMPMinSize:ofs+ICMPMinSize+8],
		uint64(seq))

	return len(t.buf)
}
+
+// GotPacket processes a packet that came back on the receive side.
+func (t *TrafficGen) GotPacket(b []byte, ofs int) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ s := &t.cur
+ seq := int64(binary.BigEndian.Uint64(
+ b[ofs+ICMPMinSize : ofs+ICMPMinSize+8]))
+ if seq > s.LastSeqRx {
+ if s.LastSeqRx > 0 {
+ // only count lost packets after the very first
+ // successful one.
+ s.TotalLost += seq - s.LastSeqRx - 1
+ }
+ s.LastSeqRx = seq
+ } else {
+ s.TotalOOO += 1
+ }
+
+ // +1 packet since we only start counting after the first one
+ if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 {
+ t.done = true
+ }
+ s.TotalBytesRx += int64(len(b) - ofs)
+
+ f := t.onFirstPacket
+ t.onFirstPacket = nil
+ if f != nil {
+ f()
+ }
+}
+
+// Adjust tunes the transmit rate based on the received packets.
+// The goal is to converge on the fastest transmit rate that still has
+// minimal packet loss. Returns the new target rate in packets/sec.
+//
+// We need to play this guessing game in order to balance out tx and rx
+// rates when there's a lossy network between them. Otherwise we can end
+// up using 99% of the CPU to blast out transmitted packets and leaving only
+// 1% to receive them, leading to a misleading throughput calculation.
+//
+// Call this function multiple times per second.
+func (t *TrafficGen) Adjust() (pps int64) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ d := t.cur.Sub(t.prev)
+
+ // don't adjust rate until the first full period *after* receiving
+ // the first packet. This skips any handshake time in the underlying
+ // transport.
+ if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 {
+ t.prev = t.cur
+ return 0 // no estimate yet, continue at max speed
+ }
+
+ pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec)
+
+ // We use a rate selection algorithm based loosely on TCP BBR.
+ // Basically, we set the transmit rate to be a bit higher than
+ // the best observed transmit rate in the last several time
+ // periods. This guarantees some packet loss, but should converge
+ // quickly on a rate near the sustainable maximum.
+ bestPPS := pps
+ for _, p := range t.ppsHistory {
+ if p > bestPPS {
+ bestPPS = p
+ }
+ }
+ if pps > 0 && t.prev.WhenNsec > 0 {
+ copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1])
+ t.ppsHistory[0] = pps
+ }
+ if bestPPS > 0 {
+ pps = bestPPS * 103 / 100
+ t.nsPerPacket = int64(1e9 / pps)
+ }
+ t.prev = t.cur
+
+ return pps
+}
diff --git a/wgengine/capture/capture.go b/wgengine/capture/capture.go index 6ea5a9549..01f79ea9f 100644 --- a/wgengine/capture/capture.go +++ b/wgengine/capture/capture.go @@ -1,238 +1,238 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package capture formats packet logging into a debug pcap stream. -package capture - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net/http" - "sync" - "time" - - _ "embed" - - "tailscale.com/net/packet" - "tailscale.com/util/set" -) - -//go:embed ts-dissector.lua -var DissectorLua string - -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) - -var bufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -const flushPeriod = 100 * time.Millisecond - -func writePcapHeader(w io.Writer) { - binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number - binary.Write(w, binary.LittleEndian, uint16(2)) // version major - binary.Write(w, binary.LittleEndian, uint16(4)) // version minor - binary.Write(w, binary.LittleEndian, uint32(0)) // this zone - binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures - binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len - binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 -} - -func writePktHeader(w *bytes.Buffer, when time.Time, length int) { - s := when.Unix() - us := when.UnixMicro() - (s * 1000000) - - binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds - binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds - binary.Write(w, binary.LittleEndian, uint32(length)) // length present - binary.Write(w, binary.LittleEndian, uint32(length)) // total 
length -} - -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. -const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { - ctx, c := context.WithCancel(context.Background()) - return &Sink{ - ctx: ctx, - ctxCancel: c, - } -} - -// Type Sink handles callbacks with packets to be logged, -// formatting them into a pcap stream which is mirrored to -// all registered outputs. -type Sink struct { - ctx context.Context - ctxCancel context.CancelFunc - - mu sync.Mutex - outputs set.HandleSet[io.Writer] - flushTimer *time.Timer // or nil if none running -} - -// RegisterOutput connects an output to this sink, which -// will be written to with a pcap stream as packets are logged. -// A function is returned which unregisters the output when -// called. -// -// If w implements io.Closer, it will be closed upon error -// or when the sink is closed. If w implements http.Flusher, -// it will be flushed periodically. 
-func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { - select { - case <-s.ctx.Done(): - return func() {} - default: - } - - writePcapHeader(w) - s.mu.Lock() - hnd := s.outputs.Add(w) - s.mu.Unlock() - - return func() { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.outputs, hnd) - } -} - -// NumOutputs returns the number of outputs registered with the sink. -func (s *Sink) NumOutputs() int { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.outputs) -} - -// Close shuts down the sink. Future calls to LogPacket -// are ignored, and any registered output that implements -// io.Closer is closed. -func (s *Sink) Close() error { - s.ctxCancel() - s.mu.Lock() - defer s.mu.Unlock() - if s.flushTimer != nil { - s.flushTimer.Stop() - s.flushTimer = nil - } - - for _, o := range s.outputs { - if o, ok := o.(io.Closer); ok { - o.Close() - } - } - s.outputs = nil - return nil -} - -// WaitCh returns a channel which blocks until -// the sink is closed. -func (s *Sink) WaitCh() <-chan struct{} { - return s.ctx.Done() -} - -func customDataLen(meta packet.CaptureMeta) int { - length := 4 - if meta.DidSNAT { - length += meta.OriginalSrc.Addr().BitLen() / 8 - } - if meta.DidDNAT { - length += meta.OriginalDst.Addr().BitLen() / 8 - } - return length -} - -// LogPacket is called to insert a packet into the capture. -// -// This function does not take ownership of the provided data slice. 
-func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { - select { - case <-s.ctx.Done(): - return - default: - } - - extraLen := customDataLen(meta) - b := bufferPool.Get().(*bytes.Buffer) - b.Reset() - b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) - defer bufferPool.Put(b) - - writePktHeader(b, when, len(data)+extraLen) - - // Custom tailscale debugging data - binary.Write(b, binary.LittleEndian, uint16(path)) - if meta.DidSNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) - b.Write(meta.OriginalSrc.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 - } - if meta.DidDNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) - b.Write(meta.OriginalDst.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 - } - - b.Write(data) - - s.mu.Lock() - defer s.mu.Unlock() - - var hadError []set.Handle - for hnd, o := range s.outputs { - if _, err := o.Write(b.Bytes()); err != nil { - hadError = append(hadError, hnd) - continue - } - } - for _, hnd := range hadError { - if o, ok := s.outputs[hnd].(io.Closer); ok { - o.Close() - } - delete(s.outputs, hnd) - } - - if s.flushTimer == nil { - s.flushTimer = time.AfterFunc(flushPeriod, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, o := range s.outputs { - if f, ok := o.(http.Flusher); ok { - f.Flush() - } - } - s.flushTimer = nil - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package capture formats packet logging into a debug pcap stream.
+package capture
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+
+ _ "embed"
+
+ "tailscale.com/net/packet"
+ "tailscale.com/util/set"
+)
+
+//go:embed ts-dissector.lua
+var DissectorLua string
+
+// Callback describes a function which is called to
+// record packets when debugging packet-capture.
+// Such callbacks must not take ownership of the
+// provided data slice: it may only copy out of it
+// within the lifetime of the function.
+type Callback func(Path, time.Time, []byte, packet.CaptureMeta)
+
// bufferPool recycles the scratch buffers that LogPacket uses to assemble
// one pcap record, so that logging a packet does not allocate a new
// bytes.Buffer each time.
var bufferPool = sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}
+
+const flushPeriod = 100 * time.Millisecond
+
// writePcapHeader writes the 24-byte pcap global file header to w.
//
// The header is assembled in memory and emitted with a single Write, so
// w receives either the whole header or an error, never a header that is
// cut short by a failure between fields.
func writePcapHeader(w io.Writer) {
	var hdr [24]byte
	le := binary.LittleEndian

	le.PutUint32(hdr[0:], 0xA1B2C3D4) // pcap magic number
	le.PutUint16(hdr[4:], 2)          // version major
	le.PutUint16(hdr[6:], 4)          // version minor
	le.PutUint32(hdr[8:], 0)          // this zone
	le.PutUint32(hdr[12:], 0)         // zone significant figures
	le.PutUint32(hdr[16:], 65535)     // max packet len
	le.PutUint32(hdr[20:], 147)       // link-layer ID - USER0

	// The error is deliberately dropped: this function has no way to
	// report it, and a broken writer fails again on its next write.
	_, _ = w.Write(hdr[:])
}
+
// writePktHeader appends the 16-byte pcap per-packet record header to w.
// when is the capture timestamp, and length is written as both the
// captured length and the original length of the record.
func writePktHeader(w *bytes.Buffer, when time.Time, length int) {
	secs := when.Unix()
	micros := when.UnixMicro() - (secs * 1000000)

	var rec [16]byte
	le := binary.LittleEndian
	le.PutUint32(rec[0:], uint32(secs))    // timestamp in seconds
	le.PutUint32(rec[4:], uint32(micros))  // timestamp microseconds
	le.PutUint32(rec[8:], uint32(length))  // length present
	le.PutUint32(rec[12:], uint32(length)) // total length
	w.Write(rec[:])
}
+
// Path describes where in the data path the packet was captured.
type Path uint8

// Valid Path values.
const (
	// FromLocal indicates the packet was logged as it traversed the FromLocal path,
	// i.e. a packet from the local system into the TUN.
	FromLocal Path = 0
	// FromPeer indicates the packet was logged upon reception from a remote peer.
	FromPeer Path = 1
	// SynthesizedToLocal indicates the packet was generated from within tailscaled,
	// and is being routed to the local machine's network stack.
	SynthesizedToLocal Path = 2
	// SynthesizedToPeer indicates the packet was generated from within tailscaled,
	// and is being routed to a remote WireGuard peer.
	SynthesizedToPeer Path = 3

	// PathDisco indicates the packet is information about a disco frame.
	PathDisco Path = 254
)
+
+// New creates a new capture sink.
+func New() *Sink {
+ ctx, c := context.WithCancel(context.Background())
+ return &Sink{
+ ctx: ctx,
+ ctxCancel: c,
+ }
+}
+
// Sink handles callbacks with packets to be logged,
// formatting them into a pcap stream which is mirrored to
// all registered outputs.
type Sink struct {
	ctx       context.Context    // done once the sink has been closed
	ctxCancel context.CancelFunc // cancels ctx; called by Close

	mu         sync.Mutex               // guards the fields below
	outputs    set.HandleSet[io.Writer] // registered outputs, keyed by handle
	flushTimer *time.Timer              // or nil if none running
}
+
+// RegisterOutput connects an output to this sink, which
+// will be written to with a pcap stream as packets are logged.
+// A function is returned which unregisters the output when
+// called.
+//
+// If w implements io.Closer, it will be closed upon error
+// or when the sink is closed. If w implements http.Flusher,
+// it will be flushed periodically.
+func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) {
+ select {
+ case <-s.ctx.Done():
+ return func() {}
+ default:
+ }
+
+ writePcapHeader(w)
+ s.mu.Lock()
+ hnd := s.outputs.Add(w)
+ s.mu.Unlock()
+
+ return func() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.outputs, hnd)
+ }
+}
+
+// NumOutputs returns the number of outputs registered with the sink.
+func (s *Sink) NumOutputs() int {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return len(s.outputs)
+}
+
+// Close shuts down the sink. Future calls to LogPacket
+// are ignored, and any registered output that implements
+// io.Closer is closed.
+func (s *Sink) Close() error {
+ s.ctxCancel()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.flushTimer != nil {
+ s.flushTimer.Stop()
+ s.flushTimer = nil
+ }
+
+ for _, o := range s.outputs {
+ if o, ok := o.(io.Closer); ok {
+ o.Close()
+ }
+ }
+ s.outputs = nil
+ return nil
+}
+
// WaitCh returns a channel which blocks until
// the sink is closed. The channel is the Done channel of the
// sink's context, which Close cancels.
func (s *Sink) WaitCh() <-chan struct{} {
	return s.ctx.Done()
}
+
+func customDataLen(meta packet.CaptureMeta) int {
+ length := 4
+ if meta.DidSNAT {
+ length += meta.OriginalSrc.Addr().BitLen() / 8
+ }
+ if meta.DidDNAT {
+ length += meta.OriginalDst.Addr().BitLen() / 8
+ }
+ return length
+}
+
// LogPacket is called to insert a packet into the capture.
//
// The record is assembled once in a pooled buffer and then written to
// every registered output. An output whose Write fails is closed (if it
// implements io.Closer) and unregistered. Calls made after the sink has
// been closed are ignored.
//
// This function does not take ownership of the provided data slice.
func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) {
	// Bail out early once the sink is closed.
	select {
	case <-s.ctx.Done():
		return
	default:
	}

	extraLen := customDataLen(meta)
	b := bufferPool.Get().(*bytes.Buffer)
	b.Reset()
	b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload)
	defer bufferPool.Put(b)

	// The record length covers the metadata as well as the payload.
	writePktHeader(b, when, len(data)+extraLen)

	// Custom tailscale debugging data
	// Layout: 2-byte path, then for SNAT and for DNAT a 1-byte address
	// length followed by that many address bytes (length 0 if unused).
	binary.Write(b, binary.LittleEndian, uint16(path))
	if meta.DidSNAT {
		binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8))
		b.Write(meta.OriginalSrc.Addr().AsSlice())
	} else {
		binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0
	}
	if meta.DidDNAT {
		binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8))
		b.Write(meta.OriginalDst.Addr().AsSlice())
	} else {
		binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0
	}

	b.Write(data)

	s.mu.Lock()
	defer s.mu.Unlock()

	// Collect failed outputs first and remove them after the loop.
	var hadError []set.Handle
	for hnd, o := range s.outputs {
		if _, err := o.Write(b.Bytes()); err != nil {
			hadError = append(hadError, hnd)
			continue
		}
	}
	for _, hnd := range hadError {
		if o, ok := s.outputs[hnd].(io.Closer); ok {
			o.Close()
		}
		delete(s.outputs, hnd)
	}

	// Arm a one-shot flush if none is pending. The timer clears itself
	// when it fires, so the next logged packet arms a new one.
	if s.flushTimer == nil {
		s.flushTimer = time.AfterFunc(flushPeriod, func() {
			s.mu.Lock()
			defer s.mu.Unlock()
			for _, o := range s.outputs {
				if f, ok := o.(http.Flusher); ok {
					f.Flush()
				}
			}
			s.flushTimer = nil
		})
	}
}
diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index f2e85dcd5..58359acdd 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "errors" - "net" - "net/netip" - "sync" - "syscall" - "time" -) - -// blockForeverConn is a net.PacketConn whose reads block until it is closed. -type blockForeverConn struct { - mu sync.Mutex - cond *sync.Cond - closed bool -} - -func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, netip.AddrPort{}, net.ErrClosed -} - -func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { - // Silently drop writes. - return len(p), nil -} - -func (c *blockForeverConn) LocalAddr() net.Addr { - // Return a *net.UDPAddr because lots of code assumes that it will. - return new(net.UDPAddr) -} - -func (c *blockForeverConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - c.closed = true - c.cond.Broadcast() - return nil -} - -func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "errors"
+ "net"
+ "net/netip"
+ "sync"
+ "syscall"
+ "time"
+)
+
// blockForeverConn is a net.PacketConn whose reads block until it is closed.
type blockForeverConn struct {
	mu sync.Mutex
	// cond is waited on by readers and broadcast by Close, so it must be
	// non-nil before either is called.
	// NOTE(review): presumably created with sync.NewCond(&mu) where the
	// conn is constructed, which is not shown here — confirm.
	cond   *sync.Cond
	closed bool // set by Close; guarded by mu
}
+
+func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
+ c.mu.Lock()
+ for !c.closed {
+ c.cond.Wait()
+ }
+ c.mu.Unlock()
+ return 0, netip.AddrPort{}, net.ErrClosed
+}
+
// WriteToUDPAddrPort discards p and reports it as fully written.
// It never returns an error.
func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) {
	// Silently drop writes.
	return len(p), nil
}
+
+func (c *blockForeverConn) LocalAddr() net.Addr {
+ // Return a *net.UDPAddr because lots of code assumes that it will.
+ return new(net.UDPAddr)
+}
+
+func (c *blockForeverConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return net.ErrClosed
+ }
+ c.closed = true
+ c.cond.Broadcast()
+ return nil
+}
+
// The deadline setters are not supported by blockForeverConn: each one
// returns an "unimplemented" error. SyscallConn returns a nil RawConn
// together with errUnsupportedConnType.
func (c *blockForeverConn) SetDeadline(t time.Time) error         { return errors.New("unimplemented") }
func (c *blockForeverConn) SetReadDeadline(t time.Time) error     { return errors.New("unimplemented") }
func (c *blockForeverConn) SetWriteDeadline(t time.Time) error    { return errors.New("unimplemented") }
func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType }
diff --git a/wgengine/magicsock/endpoint_default.go b/wgengine/magicsock/endpoint_default.go index 1ed6e5e0e..9ffeef5f8 100644 --- a/wgengine/magicsock/endpoint_default.go +++ b/wgengine/magicsock/endpoint_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !wasm && !plan9 - -package magicsock - -import ( - "errors" - "syscall" -) - -// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to -// errors.Is while avoiding an allocation per call. -var errHOSTUNREACH error = syscall.EHOSTUNREACH - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, and for unknown -// errors always reports false. -func isBadEndpointErr(err error) bool { - return errors.Is(err, errHOSTUNREACH) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !js && !wasm && !plan9
+
+package magicsock
+
+import (
+ "errors"
+ "syscall"
+)
+
// errHOSTUNREACH holds syscall.EHOSTUNREACH in an error interface value
// so it can be passed to errors.Is while avoiding an allocation per call.
var errHOSTUNREACH error = syscall.EHOSTUNREACH
+
// isBadEndpointErr checks if err is one which is known to report that an
// endpoint can no longer be sent to. It is not exhaustive, and for unknown
// errors always reports false.
//
// Currently the only recognized error is EHOSTUNREACH, matched anywhere
// in err's chain via errors.Is.
func isBadEndpointErr(err error) bool {
	return errors.Is(err, errHOSTUNREACH)
}
diff --git a/wgengine/magicsock/endpoint_stub.go b/wgengine/magicsock/endpoint_stub.go index a209c352b..9a5c9d937 100644 --- a/wgengine/magicsock/endpoint_stub.go +++ b/wgengine/magicsock/endpoint_stub.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 - -package magicsock - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, but covers known -// cases. -func isBadEndpointErr(err error) bool { - return false -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build wasm || plan9
+
+package magicsock
+
// isBadEndpointErr checks if err is one which is known to report that an
// endpoint can no longer be sent to. This build of the function
// recognizes no such errors, so it always reports false.
func isBadEndpointErr(err error) bool {
	return false
}
diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index 5caddd1a0..e2ac926b4 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -1,248 +1,248 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "net/netip" - "slices" - "sync" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tempfork/heap" - "tailscale.com/util/mak" - "tailscale.com/util/set" -) - -const ( - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second - - // endpointTrackerMaxPerAddr is how many cached addresses we track for - // a given netip.Addr. This allows e.g. restricting the number of STUN - // endpoints we cache (which usually have the same netip.Addr but - // different ports). - // - // The value of 6 is chosen because we can advertise up to 3 endpoints - // based on the STUN IP: - // 1. The STUN endpoint itself (EndpointSTUN) - // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) - // 3. The STUN IP with a portmapped port (EndpointPortmapped) - // - // Storing 6 endpoints in the cache means we can store up to 2 previous - // sets of endpoints. - endpointTrackerMaxPerAddr = 6 -) - -// endpointTrackerEntry is an entry in an endpointHeap that stores the state of -// a given cached endpoint. -type endpointTrackerEntry struct { - // endpoint is the cached endpoint. - endpoint tailcfg.Endpoint - // until is the time until which this endpoint is being cached. - until time.Time - // index is the index within the containing endpointHeap. - index int -} - -// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in -// ascending order by the 'until' expiry time (i.e. oldest first). 
-type endpointHeap []*endpointTrackerEntry - -var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) - -// Len implements heap.Interface. -func (eh endpointHeap) Len() int { return len(eh) } - -// Less implements heap.Interface. -func (eh endpointHeap) Less(i, j int) bool { - // We want to store items so that the lowest item in the heap is the - // oldest, so that heap.Pop()-ing from the endpointHeap will remove the - // oldest entry. - return eh[i].until.Before(eh[j].until) -} - -// Swap implements heap.Interface. -func (eh endpointHeap) Swap(i, j int) { - eh[i], eh[j] = eh[j], eh[i] - eh[i].index = i - eh[j].index = j -} - -// Push implements heap.Interface. -func (eh *endpointHeap) Push(item *endpointTrackerEntry) { - n := len(*eh) - item.index = n - *eh = append(*eh, item) -} - -// Pop implements heap.Interface. -func (eh *endpointHeap) Pop() *endpointTrackerEntry { - old := *eh - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.index = -1 // for safety - *eh = old[0 : n-1] - return item -} - -// Min returns a pointer to the minimum element in the heap, without removing -// it. Since this is a min-heap ordered by the 'until' field, this returns the -// chronologically "earliest" element in the heap. -// -// Len() must be non-zero. -func (eh endpointHeap) Min() *endpointTrackerEntry { - return eh[0] -} - -// endpointTracker caches endpoints that are advertised to peers. This allows -// peers to still reach this node if there's a temporary endpoint flap; rather -// than withdrawing an endpoint and then re-advertising it the next time we run -// a netcheck, we keep advertising the endpoint until it's not present for a -// defined timeout. -// -// See tailscale/tailscale#7877 for more information. 
-type endpointTracker struct { - mu sync.Mutex - endpoints map[netip.Addr]*endpointHeap -} - -// update takes as input the current sent of discovered endpoints and the -// current time, and returns the set of endpoints plus any previous-cached and -// non-expired endpoints that should be advertised to peers. -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Extend endpoints that already exist in the cache. We do this before - // we remove expired endpoints, below, so we don't remove something - // that would otherwise have survived by extending. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.extendLocked(ep, until) - } - - // Now that we've extended existing endpoints, remove everything that - // has expired. - et.removeExpiredLocked(now) - - // Add entries from the input set of endpoints into the cache; we do - // this after removing expired ones so that we can store as many as - // possible, with space freed by the entries removed after expiry. - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Finally, add entries to the return array that aren't already there. - epsPlusCached = eps - for _, heap := range et.endpoints { - for _, ep := range *heap { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(ep.endpoint.Addr) { - continue - } else if now.After(ep.until) { - // Defense-in-depth; should never happen since - // we removed expired entries above, but ignore - // it anyway. - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - } - - return epsPlusCached -} - -// extendLocked will update the expiry time of the provided endpoint in the -// cache, if it is present. 
If it is not present, nothing will be done. -// -// et.mu must be held. -func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - epHeap, found := et.endpoints[key] - if !found { - return - } - - // Find the entry for this exact address; this loop is quick since we - // bound the number of items in the heap. - // - // TODO(andrew): this means we iterate over the entire heap once per - // endpoint; even if the heap is small, if we have a lot of input - // endpoints this can be expensive? - for i, entry := range *epHeap { - if entry.endpoint == ep { - entry.until = until - heap.Fix(epHeap, i) - return - } - } -} - -// addLocked will store the provided endpoint(s) in the cache for a fixed -// period of time, ensuring that the size of the endpoint cache remains below -// the maximum. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - - // Create or get the heap for this endpoint's addr - epHeap := et.endpoints[key] - if epHeap == nil { - epHeap = new(endpointHeap) - mak.Set(&et.endpoints, key, epHeap) - } - - // Find the entry for this exact address; this loop is quick - // since we bound the number of items in the heap. - found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { - return v.endpoint == ep - }) - if !found { - // Add address to heap; either the endpoint is new, or the heap - // was newly-created and thus empty. - heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) - } - - // Now that we've added everything, pop from our heap until we're below - // the limit. This is a min-heap, so popping removes the lowest (and - // thus oldest) endpoint. - for epHeap.Len() > endpointTrackerMaxPerAddr { - heap.Pop(epHeap) - } -} - -// removeExpired will remove all expired entries from the cache. -// -// et.mu must be held. 
-func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, epHeap := range et.endpoints { - // The minimum element is oldest/earliest endpoint; repeatedly - // pop from the heap while it's in the past. - for epHeap.Len() > 0 { - minElem := epHeap.Min() - if now.After(minElem.until) { - heap.Pop(epHeap) - } else { - break - } - } - - if epHeap.Len() == 0 { - // Free up space in the map by removing the empty heap. - delete(et.endpoints, k) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "net/netip"
+ "slices"
+ "sync"
+ "time"
+
+ "tailscale.com/tailcfg"
+ "tailscale.com/tempfork/heap"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/set"
+)
+
+const (
+ // endpointTrackerLifetime is how long we continue advertising an
+ // endpoint after we last see it. This is intentionally chosen to be
+ // slightly longer than a full netcheck period.
+ endpointTrackerLifetime = 5*time.Minute + 10*time.Second
+
+ // endpointTrackerMaxPerAddr is how many cached addresses we track for
+ // a given netip.Addr. This allows e.g. restricting the number of STUN
+ // endpoints we cache (which usually have the same netip.Addr but
+ // different ports).
+ //
+ // The value of 6 is chosen because we can advertise up to 3 endpoints
+ // based on the STUN IP:
+ // 1. The STUN endpoint itself (EndpointSTUN)
+ // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort)
+ // 3. The STUN IP with a portmapped port (EndpointPortmapped)
+ //
+ // Storing 6 endpoints in the cache means we can store up to 2 previous
+ // sets of endpoints.
+ endpointTrackerMaxPerAddr = 6
+)
+
+// endpointTrackerEntry is an entry in an endpointHeap that stores the state of
+// a given cached endpoint.
+type endpointTrackerEntry struct {
+ // endpoint is the cached endpoint.
+ endpoint tailcfg.Endpoint
+ // until is the time until which this endpoint is being cached.
+ until time.Time
+ // index is the index within the containing endpointHeap.
+ index int
+}
+
+// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in
+// ascending order by the 'until' expiry time (i.e. oldest first).
+type endpointHeap []*endpointTrackerEntry
+
+var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil)
+
+// Len implements heap.Interface.
+func (eh endpointHeap) Len() int { return len(eh) }
+
+// Less implements heap.Interface.
+func (eh endpointHeap) Less(i, j int) bool {
+ // We want to store items so that the lowest item in the heap is the
+ // oldest, so that heap.Pop()-ing from the endpointHeap will remove the
+ // oldest entry.
+ return eh[i].until.Before(eh[j].until)
+}
+
+// Swap implements heap.Interface.
+func (eh endpointHeap) Swap(i, j int) {
+ eh[i], eh[j] = eh[j], eh[i]
+ eh[i].index = i
+ eh[j].index = j
+}
+
+// Push implements heap.Interface.
+func (eh *endpointHeap) Push(item *endpointTrackerEntry) {
+ n := len(*eh)
+ item.index = n
+ *eh = append(*eh, item)
+}
+
+// Pop implements heap.Interface.
+func (eh *endpointHeap) Pop() *endpointTrackerEntry {
+ old := *eh
+ n := len(old)
+ item := old[n-1]
+ old[n-1] = nil // avoid memory leak
+ item.index = -1 // for safety
+ *eh = old[0 : n-1]
+ return item
+}
+
+// Min returns a pointer to the minimum element in the heap, without removing
+// it. Since this is a min-heap ordered by the 'until' field, this returns the
+// chronologically "earliest" element in the heap.
+//
+// Len() must be non-zero.
+func (eh endpointHeap) Min() *endpointTrackerEntry {
+ return eh[0]
+}
+
+// endpointTracker caches endpoints that are advertised to peers. This allows
+// peers to still reach this node if there's a temporary endpoint flap; rather
+// than withdrawing an endpoint and then re-advertising it the next time we run
+// a netcheck, we keep advertising the endpoint until it's not present for a
+// defined timeout.
+//
+// See tailscale/tailscale#7877 for more information.
+type endpointTracker struct {
+ mu sync.Mutex
+ endpoints map[netip.Addr]*endpointHeap
+}
+
+// update takes as input the current sent of discovered endpoints and the
+// current time, and returns the set of endpoints plus any previous-cached and
+// non-expired endpoints that should be advertised to peers.
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+ var inputEps set.Slice[netip.AddrPort]
+ for _, ep := range eps {
+ inputEps.Add(ep.Addr)
+ }
+
+ et.mu.Lock()
+ defer et.mu.Unlock()
+
+ // Extend endpoints that already exist in the cache. We do this before
+ // we remove expired endpoints, below, so we don't remove something
+ // that would otherwise have survived by extending.
+ until := now.Add(endpointTrackerLifetime)
+ for _, ep := range eps {
+ et.extendLocked(ep, until)
+ }
+
+ // Now that we've extended existing endpoints, remove everything that
+ // has expired.
+ et.removeExpiredLocked(now)
+
+ // Add entries from the input set of endpoints into the cache; we do
+ // this after removing expired ones so that we can store as many as
+ // possible, with space freed by the entries removed after expiry.
+ for _, ep := range eps {
+ et.addLocked(now, ep, until)
+ }
+
+ // Finally, add entries to the return array that aren't already there.
+ epsPlusCached = eps
+ for _, heap := range et.endpoints {
+ for _, ep := range *heap {
+ // If the endpoint was in the input list, or has expired, skip it.
+ if inputEps.Contains(ep.endpoint.Addr) {
+ continue
+ } else if now.After(ep.until) {
+ // Defense-in-depth; should never happen since
+ // we removed expired entries above, but ignore
+ // it anyway.
+ continue
+ }
+
+ // We haven't seen this endpoint; add to the return array
+ epsPlusCached = append(epsPlusCached, ep.endpoint)
+ }
+ }
+
+ return epsPlusCached
+}
+
+// extendLocked will update the expiry time of the provided endpoint in the
+// cache, if it is present. If it is not present, nothing will be done.
+//
+// et.mu must be held.
+func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
+ key := ep.Addr.Addr()
+ epHeap, found := et.endpoints[key]
+ if !found {
+ return
+ }
+
+ // Find the entry for this exact address; this loop is quick since we
+ // bound the number of items in the heap.
+ //
+ // TODO(andrew): this means we iterate over the entire heap once per
+ // endpoint; even if the heap is small, if we have a lot of input
+ // endpoints this can be expensive?
+ for i, entry := range *epHeap {
+ if entry.endpoint == ep {
+ entry.until = until
+ heap.Fix(epHeap, i)
+ return
+ }
+ }
+}
+
+// addLocked will store the provided endpoint(s) in the cache for a fixed
+// period of time, ensuring that the size of the endpoint cache remains below
+// the maximum.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+ key := ep.Addr.Addr()
+
+ // Create or get the heap for this endpoint's addr
+ epHeap := et.endpoints[key]
+ if epHeap == nil {
+ epHeap = new(endpointHeap)
+ mak.Set(&et.endpoints, key, epHeap)
+ }
+
+ // Find the entry for this exact address; this loop is quick
+ // since we bound the number of items in the heap.
+ found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
+ return v.endpoint == ep
+ })
+ if !found {
+ // Add address to heap; either the endpoint is new, or the heap
+ // was newly-created and thus empty.
+ heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
+ }
+
+ // Now that we've added everything, pop from our heap until we're below
+ // the limit. This is a min-heap, so popping removes the lowest (and
+ // thus oldest) endpoint.
+ for epHeap.Len() > endpointTrackerMaxPerAddr {
+ heap.Pop(epHeap)
+ }
+}
+
+// removeExpired will remove all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) {
+ for k, epHeap := range et.endpoints {
+ // The minimum element is oldest/earliest endpoint; repeatedly
+ // pop from the heap while it's in the past.
+ for epHeap.Len() > 0 {
+ minElem := epHeap.Min()
+ if now.After(minElem.until) {
+ heap.Pop(epHeap)
+ } else {
+ break
+ }
+ }
+
+ if epHeap.Len() == 0 {
+ // Free up space in the map by removing the empty heap.
+ delete(et.endpoints, k)
+ }
+ }
+}
diff --git a/wgengine/magicsock/magicsock_unix_test.go b/wgengine/magicsock/magicsock_unix_test.go index b0700a8eb..9ad8cab93 100644 --- a/wgengine/magicsock/magicsock_unix_test.go +++ b/wgengine/magicsock/magicsock_unix_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package magicsock - -import ( - "net" - "syscall" - "testing" - - "tailscale.com/types/nettype" -) - -func TestTrySetSocketBuffer(t *testing.T) { - c, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) - } - defer c.Close() - - rc, err := c.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - getBufs := func() (int, int) { - var rcv, snd int - rc.Control(func(fd uintptr) { - rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - t.Errorf("getsockopt(SO_RCVBUF): %v", err) - } - snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - t.Errorf("getsockopt(SO_SNDBUF): %v", err) - } - }) - return rcv, snd - } - - curRcv, curSnd := getBufs() - - trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) - - newRcv, newSnd := getBufs() - - if curRcv > newRcv { - t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) - } - if curSnd > newSnd { - t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) - } - - // On many systems we may not increase the value, particularly running as a - // regular user, so log the information for manual verification. - t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) - t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build unix
+
+package magicsock
+
+import (
+ "net"
+ "syscall"
+ "testing"
+
+ "tailscale.com/types/nettype"
+)
+
+func TestTrySetSocketBuffer(t *testing.T) {
+ c, err := net.ListenPacket("udp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ rc, err := c.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ getBufs := func() (int, int) {
+ var rcv, snd int
+ rc.Control(func(fd uintptr) {
+ rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
+ if err != nil {
+ t.Errorf("getsockopt(SO_RCVBUF): %v", err)
+ }
+ snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ t.Errorf("getsockopt(SO_SNDBUF): %v", err)
+ }
+ })
+ return rcv, snd
+ }
+
+ curRcv, curSnd := getBufs()
+
+ trySetSocketBuffer(c.(nettype.PacketConn), t.Logf)
+
+ newRcv, newSnd := getBufs()
+
+ if curRcv > newRcv {
+ t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv)
+ }
+ if curSnd > newSnd {
+ t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd)
+ }
+
+ // On many systems we may not increase the value, particularly running as a
+ // regular user, so log the information for manual verification.
+ t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv)
+ t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv)
+}
diff --git a/wgengine/magicsock/peermtu_darwin.go b/wgengine/magicsock/peermtu_darwin.go index a0a1aacb5..b2a1ed217 100644 --- a/wgengine/magicsock/peermtu_darwin.go +++ b/wgengine/magicsock/peermtu_darwin.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package magicsock - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return unix.IP_DONTFRAG - } - return unix.IPV6_DONTFRAG -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := 1 - if enable == false { - optArg = 0 - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == 1 { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin && !ios
+
+package magicsock
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func getDontFragOpt(network string) int {
+ if network == "udp4" {
+ return unix.IP_DONTFRAG
+ }
+ return unix.IPV6_DONTFRAG
+}
+
+func (c *Conn) setDontFragment(network string, enable bool) error {
+ optArg := 1
+ if enable == false {
+ optArg = 0
+ }
+ var err error
+ rcErr := c.connControl(network, func(fd uintptr) {
+ err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg)
+ })
+
+ if rcErr != nil {
+ return rcErr
+ }
+ return err
+}
+
+func (c *Conn) getDontFragment(network string) (bool, error) {
+ var v int
+ var err error
+ rcErr := c.connControl(network, func(fd uintptr) {
+ v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network))
+ })
+
+ if rcErr != nil {
+ return false, rcErr
+ }
+ if v == 1 {
+ return true, err
+ }
+ return false, err
+}
diff --git a/wgengine/magicsock/peermtu_linux.go b/wgengine/magicsock/peermtu_linux.go index b76f30f08..d32ead099 100644 --- a/wgengine/magicsock/peermtu_linux.go +++ b/wgengine/magicsock/peermtu_linux.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !android - -package magicsock - -import ( - "syscall" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return syscall.IP_MTU_DISCOVER - } - return syscall.IPV6_MTU_DISCOVER -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := syscall.IP_PMTUDISC_DO - if enable == false { - optArg = syscall.IP_PMTUDISC_DONT - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == syscall.IP_PMTUDISC_DO { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux && !android
+
+package magicsock
+
+import (
+ "syscall"
+)
+
+func getDontFragOpt(network string) int {
+ if network == "udp4" {
+ return syscall.IP_MTU_DISCOVER
+ }
+ return syscall.IPV6_MTU_DISCOVER
+}
+
+func (c *Conn) setDontFragment(network string, enable bool) error {
+ optArg := syscall.IP_PMTUDISC_DO
+ if enable == false {
+ optArg = syscall.IP_PMTUDISC_DONT
+ }
+ var err error
+ rcErr := c.connControl(network, func(fd uintptr) {
+ err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg)
+ })
+
+ if rcErr != nil {
+ return rcErr
+ }
+ return err
+}
+
+func (c *Conn) getDontFragment(network string) (bool, error) {
+ var v int
+ var err error
+ rcErr := c.connControl(network, func(fd uintptr) {
+ v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network))
+ })
+
+ if rcErr != nil {
+ return false, rcErr
+ }
+ if v == syscall.IP_PMTUDISC_DO {
+ return true, err
+ }
+ return false, err
+}
diff --git a/wgengine/magicsock/peermtu_unix.go b/wgengine/magicsock/peermtu_unix.go index eec3d744f..59e808ee7 100644 --- a/wgengine/magicsock/peermtu_unix.go +++ b/wgengine/magicsock/peermtu_unix.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (darwin && !ios) || (linux && !android) - -package magicsock - -import ( - "syscall" -) - -// getIPProto returns the value of the get/setsockopt proto argument necessary -// to set an IP sockopt that corresponds with the string network, which must be -// "udp4" or "udp6". -func getIPProto(network string) int { - if network == "udp4" { - return syscall.IPPROTO_IP - } - return syscall.IPPROTO_IPV6 -} - -// connControl allows the caller to run a system call on the socket underlying -// Conn specified by the string network, which must be "udp4" or "udp6". If the -// pconn type implements the syscall method, this function returns the value of -// of the system call fn called with the fd of the socket as its arg (or the -// error from rc.Control() if that fails). Otherwise it returns the error -// errUnsupportedConnType. -func (c *Conn) connControl(network string, fn func(fd uintptr)) error { - pconn := c.pconn4.pconn - if network == "udp6" { - pconn = c.pconn6.pconn - } - sc, ok := pconn.(syscall.Conn) - if !ok { - return errUnsupportedConnType - } - rc, err := sc.SyscallConn() - if err != nil { - return err - } - return rc.Control(fn) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build (darwin && !ios) || (linux && !android)
+
+package magicsock
+
+import (
+ "syscall"
+)
+
+// getIPProto returns the value of the get/setsockopt proto argument necessary
+// to set an IP sockopt that corresponds with the string network, which must be
+// "udp4" or "udp6".
+func getIPProto(network string) int {
+ if network == "udp4" {
+ return syscall.IPPROTO_IP
+ }
+ return syscall.IPPROTO_IPV6
+}
+
+// connControl allows the caller to run a system call on the socket underlying
+// Conn specified by the string network, which must be "udp4" or "udp6". If the
+// pconn type implements the syscall method, this function returns the value of
+// of the system call fn called with the fd of the socket as its arg (or the
+// error from rc.Control() if that fails). Otherwise it returns the error
+// errUnsupportedConnType.
+func (c *Conn) connControl(network string, fn func(fd uintptr)) error {
+ pconn := c.pconn4.pconn
+ if network == "udp6" {
+ pconn = c.pconn6.pconn
+ }
+ sc, ok := pconn.(syscall.Conn)
+ if !ok {
+ return errUnsupportedConnType
+ }
+ rc, err := sc.SyscallConn()
+ if err != nil {
+ return err
+ }
+ return rc.Control(fn)
+}
diff --git a/wgengine/mem_ios.go b/wgengine/mem_ios.go index cc266ea3a..975dfca61 100644 --- a/wgengine/mem_ios.go +++ b/wgengine/mem_ios.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgengine - -import ( - "github.com/tailscale/wireguard-go/device" -) - -// iOS has a very restrictive memory limit on network extensions. -// Reduce the maximum amount of memory that wireguard-go can allocate -// to avoid getting killed. - -func init() { - device.QueueStagedSize = 64 - device.QueueOutboundSize = 64 - device.QueueInboundSize = 64 - device.QueueHandshakeSize = 64 - device.PreallocatedBuffersPerPool = 64 -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package wgengine
+
+import (
+ "github.com/tailscale/wireguard-go/device"
+)
+
+// iOS has a very restrictive memory limit on network extensions.
+// Reduce the maximum amount of memory that wireguard-go can allocate
+// to avoid getting killed.
+
+func init() {
+ device.QueueStagedSize = 64
+ device.QueueOutboundSize = 64
+ device.QueueInboundSize = 64
+ device.QueueHandshakeSize = 64
+ device.PreallocatedBuffersPerPool = 64
+}
diff --git a/wgengine/netstack/netstack_linux.go b/wgengine/netstack/netstack_linux.go index a0bfb4456..9e27b7819 100644 --- a/wgengine/netstack/netstack_linux.go +++ b/wgengine/netstack/netstack_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstack - -import ( - "os/exec" - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - setAmbientCapsRaw = func(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_RAW}, - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netstack
+
+import (
+ "os/exec"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ setAmbientCapsRaw = func(cmd *exec.Cmd) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ AmbientCaps: []uintptr{unix.CAP_NET_RAW},
+ }
+ }
+}
diff --git a/wgengine/router/runner.go b/wgengine/router/runner.go index 8fa068e33..7ba633344 100644 --- a/wgengine/router/runner.go +++ b/wgengine/router/runner.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package router - -import ( - "errors" - "fmt" - "os" - "os/exec" - "strconv" - "strings" - "syscall" - - "golang.org/x/sys/unix" -) - -// commandRunner abstracts helpers to run OS commands. It exists -// purely to swap out osCommandRunner (below) with a fake runner in -// tests. -type commandRunner interface { - run(...string) error - output(...string) ([]byte, error) -} - -type osCommandRunner struct { - // ambientCapNetAdmin determines whether commands are executed with - // CAP_NET_ADMIN. - // CAP_NET_ADMIN is required when running as non-root and executing cmds - // like `ip rule`. Even if our process has the capability, we need to - // explicitly grant it to the new process. - // We specifically need this for Synology DSM7 where tailscaled no longer - // runs as root. - ambientCapNetAdmin bool -} - -// errCode extracts and returns the process exit code from err, or -// zero if err is nil. -func errCode(err error) int { - if err == nil { - return 0 - } - var e *exec.ExitError - if ok := errors.As(err, &e); ok { - return e.ExitCode() - } - s := err.Error() - if strings.HasPrefix(s, "exitcode:") { - code, err := strconv.Atoi(s[9:]) - if err == nil { - return code - } - } - return -42 -} - -func (o osCommandRunner) run(args ...string) error { - _, err := o.output(args...) - return err -} - -func (o osCommandRunner) output(args ...string) ([]byte, error) { - if len(args) == 0 { - return nil, errors.New("cmd: no argv[0]") - } - - cmd := exec.Command(args[0], args[1:]...) 
- cmd.Env = append(os.Environ(), "LC_ALL=C") - if o.ambientCapNetAdmin { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, - } - } - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) - } - - return out, nil -} - -type runGroup struct { - OkCode []int // error codes that are acceptable, other than 0, if any - Runner commandRunner // the runner that actually runs our commands - ErrAcc error // first error encountered, if any -} - -func newRunGroup(okCode []int, runner commandRunner) *runGroup { - return &runGroup{ - OkCode: okCode, - Runner: runner, - } -} - -func (rg *runGroup) okCode(err error) bool { - got := errCode(err) - for _, want := range rg.OkCode { - if got == want { - return true - } - } - return false -} - -func (rg *runGroup) Output(args ...string) []byte { - b, err := rg.Runner.output(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } - return b -} - -func (rg *runGroup) Run(args ...string) { - err := rg.Runner.run(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package router
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+// commandRunner abstracts helpers to run OS commands. It exists
+// purely to swap out osCommandRunner (below) with a fake runner in
+// tests.
+type commandRunner interface {
+ run(...string) error
+ output(...string) ([]byte, error)
+}
+
+type osCommandRunner struct {
+ // ambientCapNetAdmin determines whether commands are executed with
+ // CAP_NET_ADMIN.
+ // CAP_NET_ADMIN is required when running as non-root and executing cmds
+ // like `ip rule`. Even if our process has the capability, we need to
+ // explicitly grant it to the new process.
+ // We specifically need this for Synology DSM7 where tailscaled no longer
+ // runs as root.
+ ambientCapNetAdmin bool
+}
+
+// errCode extracts and returns the process exit code from err, or
+// zero if err is nil.
+func errCode(err error) int {
+ if err == nil {
+ return 0
+ }
+ var e *exec.ExitError
+ if ok := errors.As(err, &e); ok {
+ return e.ExitCode()
+ }
+ s := err.Error()
+ if strings.HasPrefix(s, "exitcode:") {
+ code, err := strconv.Atoi(s[9:])
+ if err == nil {
+ return code
+ }
+ }
+ return -42
+}
+
+func (o osCommandRunner) run(args ...string) error {
+ _, err := o.output(args...)
+ return err
+}
+
+func (o osCommandRunner) output(args ...string) ([]byte, error) {
+ if len(args) == 0 {
+ return nil, errors.New("cmd: no argv[0]")
+ }
+
+ cmd := exec.Command(args[0], args[1:]...)
+ cmd.Env = append(os.Environ(), "LC_ALL=C")
+ if o.ambientCapNetAdmin {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ AmbientCaps: []uintptr{unix.CAP_NET_ADMIN},
+ }
+ }
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out)
+ }
+
+ return out, nil
+}
+
+type runGroup struct {
+ OkCode []int // error codes that are acceptable, other than 0, if any
+ Runner commandRunner // the runner that actually runs our commands
+ ErrAcc error // first error encountered, if any
+}
+
+func newRunGroup(okCode []int, runner commandRunner) *runGroup {
+ return &runGroup{
+ OkCode: okCode,
+ Runner: runner,
+ }
+}
+
+func (rg *runGroup) okCode(err error) bool {
+ got := errCode(err)
+ for _, want := range rg.OkCode {
+ if got == want {
+ return true
+ }
+ }
+ return false
+}
+
+func (rg *runGroup) Output(args ...string) []byte {
+ b, err := rg.Runner.output(args...)
+ if rg.ErrAcc == nil && err != nil && !rg.okCode(err) {
+ rg.ErrAcc = err
+ }
+ return b
+}
+
+func (rg *runGroup) Run(args ...string) {
+ err := rg.Runner.run(args...)
+ if rg.ErrAcc == nil && err != nil && !rg.okCode(err) {
+ rg.ErrAcc = err
+ }
+}
diff --git a/wgengine/watchdog_js.go b/wgengine/watchdog_js.go index 872ce36d5..9dcb29c4e 100644 --- a/wgengine/watchdog_js.go +++ b/wgengine/watchdog_js.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js - -package wgengine - -import "tailscale.com/net/dns/resolver" - -type watchdogEngine struct { - Engine - wrap Engine -} - -func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { - return nil, false -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build js
+
+package wgengine
+
+import "tailscale.com/net/dns/resolver"
+
+type watchdogEngine struct {
+ Engine
+ wrap Engine
+}
+
+func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) {
+ return nil, false
+}
diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index 80fa159e3..9b83998cb 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "io" - "sort" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" - "tailscale.com/util/multierr" -) - -// NewDevice returns a wireguard-go Device configured for Tailscale use. -func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { - ret := device.NewDevice(tunDev, bind, logger) - ret.DisableSomeRoamingForBrokenMobileSemantics() - return ret -} - -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := multierr.New(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - -// ReconfigDevice replaces the existing device configuration with cfg. -func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { - defer func() { - if err != nil { - logf("wgcfg.Reconfig failed: %v", err) - } - }() - - prev, err := DeviceConfig(d) - if err != nil { - return err - } - - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() - - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return multierr.New(setErr, toErr) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package wgcfg
+
+import (
+ "io"
+ "sort"
+
+ "github.com/tailscale/wireguard-go/conn"
+ "github.com/tailscale/wireguard-go/device"
+ "github.com/tailscale/wireguard-go/tun"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/multierr"
+)
+
+// NewDevice returns a wireguard-go Device configured for Tailscale use.
+func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device {
+ ret := device.NewDevice(tunDev, bind, logger)
+ ret.DisableSomeRoamingForBrokenMobileSemantics()
+ return ret
+}
+
+func DeviceConfig(d *device.Device) (*Config, error) {
+ r, w := io.Pipe()
+ errc := make(chan error, 1)
+ go func() {
+ errc <- d.IpcGetOperation(w)
+ w.Close()
+ }()
+ cfg, fromErr := FromUAPI(r)
+ r.Close()
+ getErr := <-errc
+ err := multierr.New(getErr, fromErr)
+ if err != nil {
+ return nil, err
+ }
+ sort.Slice(cfg.Peers, func(i, j int) bool {
+ return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey)
+ })
+ return cfg, nil
+}
+
+// ReconfigDevice replaces the existing device configuration with cfg.
+func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) {
+ defer func() {
+ if err != nil {
+ logf("wgcfg.Reconfig failed: %v", err)
+ }
+ }()
+
+ prev, err := DeviceConfig(d)
+ if err != nil {
+ return err
+ }
+
+ r, w := io.Pipe()
+ errc := make(chan error, 1)
+ go func() {
+ errc <- d.IpcSetOperation(r)
+ r.Close()
+ }()
+
+ toErr := cfg.ToUAPI(logf, w, prev)
+ w.Close()
+ setErr := <-errc
+ return multierr.New(setErr, toErr)
+}
diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index d54282e4b..c54ad16d9 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -1,261 +1,261 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "os" - "sort" - "strings" - "sync" - "testing" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - k2, pk2 := newK() - ip2 := netip.MustParsePrefix("10.0.0.2/32") - - k3, _ := newK() - ip3 := netip.MustParsePrefix("10.0.0.3/32") - - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } - - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { - t.Fatal(err) - } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != 
wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 add new peer", func(t *testing.T) { - cfg1.Peers = append(cfg1.Peers, Peer{ - PublicKey: k3, - AllowedIPs: []netip.Prefix{ip3}, - }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := 
func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") - } - }) - - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) -} - -// TODO: replace with a loopback tunnel -type nilTun struct { - events chan tun.Event - closed chan struct{} -} - -func newNilTun() tun.Device { - return &nilTun{ - events: make(chan tun.Event), - closed: make(chan struct{}), - } -} - -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() <-chan tun.Event { return t.events } - -func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Write(data [][]byte, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Close() error { - close(t.events) - close(t.closed) - return nil -} - -func (t *nilTun) BatchSize() int { return 1 } - -// A noopBind is a conn.Bind that does no actual binding work. 
-type noopBind struct{} - -func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - return nil, 1, nil -} -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } -func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return dummyEndpoint(s), nil -} -func (noopBind) BatchSize() int { return 1 } - -// A dummyEndpoint is a string holding the endpoint destination. -type dummyEndpoint string - -func (e dummyEndpoint) ClearSrc() {} -func (e dummyEndpoint) SrcToString() string { return "" } -func (e dummyEndpoint) DstToString() string { return string(e) } -func (e dummyEndpoint) DstToBytes() []byte { return nil } -func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package wgcfg
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "net/netip"
+ "os"
+ "sort"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/tailscale/wireguard-go/conn"
+ "github.com/tailscale/wireguard-go/device"
+ "github.com/tailscale/wireguard-go/tun"
+ "go4.org/mem"
+ "tailscale.com/types/key"
+)
+
+// TestDeviceConfig applies a series of configurations to two devices built on
+// a nilTun and a noopBind and checks, after each ReconfigDevice call, that
+// DeviceConfig reads back a configuration that serializes identically.
+//
+// The subtests share cfg1 and the two devices and mutate them step by step,
+// so they rely on running in the order written.
+func TestDeviceConfig(t *testing.T) {
+	genKeys := func() (key.NodePublic, key.NodePrivate) {
+		t.Helper()
+		priv := key.NewNode()
+		return priv.Public(), priv
+	}
+	k1, pk1 := genKeys()
+	k2, pk2 := genKeys()
+	k3, _ := genKeys()
+
+	ip1 := netip.MustParsePrefix("10.0.0.1/32")
+	ip2 := netip.MustParsePrefix("10.0.0.2/32")
+	ip3 := netip.MustParsePrefix("10.0.0.3/32")
+
+	// Each device starts out knowing only the other one as a peer.
+	cfg1 := &Config{
+		PrivateKey: pk1,
+		Peers: []Peer{
+			{PublicKey: k2, AllowedIPs: []netip.Prefix{ip2}},
+		},
+	}
+	cfg2 := &Config{
+		PrivateKey: pk2,
+		Peers: []Peer{
+			{PublicKey: k1, AllowedIPs: []netip.Prefix{ip1}, PersistentKeepalive: 5},
+		},
+	}
+
+	device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1"))
+	device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2"))
+	defer device1.Close()
+	defer device2.Close()
+
+	// checkConfig reports a test error unless the configuration read back
+	// from d serializes to the same UAPI text as want. Both are serialized
+	// against the same empty previous config.
+	checkConfig := func(t *testing.T, d *device.Device, want *Config) {
+		t.Helper()
+		got, err := DeviceConfig(d)
+		if err != nil {
+			t.Fatal(err)
+		}
+		prev := new(Config)
+		var gotUAPI, wantUAPI strings.Builder
+		if err := got.ToUAPI(t.Logf, &gotUAPI, prev); err != nil {
+			t.Errorf("got.ToUAPI(): error: %v", err)
+			return
+		}
+		if err := want.ToUAPI(t.Logf, &wantUAPI, prev); err != nil {
+			t.Errorf("want.ToUAPI(): error: %v", err)
+			return
+		}
+		gotStr, wantStr := gotUAPI.String(), wantUAPI.String()
+		if gotStr == wantStr {
+			return
+		}
+		// On a mismatch, include the device's own IpcGetOperation output in
+		// the failure message.
+		var raw bytes.Buffer
+		w := bufio.NewWriter(&raw)
+		if err := d.IpcGetOperation(w); err != nil {
+			t.Errorf("on error, could not IpcGetOperation: %v", err)
+		}
+		w.Flush()
+		t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, raw.String())
+	}
+
+	// applyAndCheck reconfigures d with cfg and then verifies the result.
+	applyAndCheck := func(t *testing.T, d *device.Device, cfg *Config) {
+		t.Helper()
+		if err := ReconfigDevice(d, cfg, t.Logf); err != nil {
+			t.Fatal(err)
+		}
+		checkConfig(t, d, cfg)
+	}
+
+	t.Run("device1 config", func(t *testing.T) {
+		applyAndCheck(t, device1, cfg1)
+	})
+
+	t.Run("device2 config", func(t *testing.T) {
+		applyAndCheck(t, device2, cfg2)
+	})
+
+	// This is only to test that Config and Reconfig are properly synchronized.
+	// The results of both calls are deliberately ignored.
+	t.Run("device2 config/reconfig", func(t *testing.T) {
+		var wg sync.WaitGroup
+		wg.Add(2)
+
+		go func() {
+			defer wg.Done()
+			ReconfigDevice(device2, cfg2, t.Logf)
+		}()
+
+		go func() {
+			defer wg.Done()
+			DeviceConfig(device2)
+		}()
+
+		wg.Wait()
+	})
+
+	t.Run("device1 modify peer", func(t *testing.T) {
+		cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0}))
+		applyAndCheck(t, device1, cfg1)
+	})
+
+	t.Run("device1 replace endpoint", func(t *testing.T) {
+		cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0}))
+		applyAndCheck(t, device1, cfg1)
+	})
+
+	t.Run("device1 add new peer", func(t *testing.T) {
+		cfg1.Peers = append(cfg1.Peers, Peer{
+			PublicKey:  k3,
+			AllowedIPs: []netip.Prefix{ip3},
+		})
+		sort.Slice(cfg1.Peers, func(i, j int) bool {
+			return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey)
+		})
+
+		// Snapshot the device state before the new peer is applied.
+		origCfg, err := DeviceConfig(device1)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		applyAndCheck(t, device1, cfg1)
+
+		newCfg, err := DeviceConfig(device1)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// The peer keyed by k2 must be unchanged by adding the k3 peer.
+		lookupPeer2 := func(cfg *Config) Peer {
+			t.Helper()
+			p, ok := cfg.PeerWithKey(k2)
+			if !ok {
+				t.Fatal("failed to look up peer 2")
+			}
+			return p
+		}
+		peersEqual := func(p, q Peer) bool {
+			return p.PublicKey == q.PublicKey &&
+				p.DiscoKey == q.DiscoKey &&
+				p.PersistentKeepalive == q.PersistentKeepalive &&
+				cidrsEqual(p.AllowedIPs, q.AllowedIPs)
+		}
+		if !peersEqual(lookupPeer2(origCfg), lookupPeer2(newCfg)) {
+			t.Error("reconfig modified old peer")
+		}
+	})
+
+	t.Run("device1 remove peer", func(t *testing.T) {
+		last := len(cfg1.Peers) - 1
+		removeKey := cfg1.Peers[last].PublicKey
+		cfg1.Peers = cfg1.Peers[:last]
+
+		applyAndCheck(t, device1, cfg1)
+
+		newCfg, err := DeviceConfig(device1)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if _, ok := newCfg.PeerWithKey(removeKey); ok {
+			t.Error("reconfig failed to remove peer")
+		}
+	})
+}
+
+// nilTun is a tun.Device stub that never carries packets: Read and Write
+// block until Close is called and then report io.EOF.
+//
+// TODO: replace with a loopback tunnel
+type nilTun struct {
+	events chan tun.Event // handed out by Events; closed by Close
+	closed chan struct{}  // closed by Close to release blocked Read and Write calls
+}
+
+// newNilTun returns a nilTun whose channels are freshly allocated and
+// unbuffered.
+func newNilTun() tun.Device {
+	dev := new(nilTun)
+	dev.events = make(chan tun.Event)
+	dev.closed = make(chan struct{})
+	return dev
+}
+
+// File returns nil.
+func (t *nilTun) File() *os.File { return nil }
+
+// Flush does nothing and returns nil.
+func (t *nilTun) Flush() error { return nil }
+
+// MTU always reports 1420.
+func (t *nilTun) MTU() (int, error) { return 1420, nil }
+
+// Name always reports "niltun".
+func (t *nilTun) Name() (string, error) { return "niltun", nil }
+
+// Events returns the receive side of the events channel, on which nilTun
+// itself never sends.
+func (t *nilTun) Events() <-chan tun.Event { return t.events }
+
+// Read blocks until the device is closed, then returns 0 and io.EOF.
+func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) {
+	<-t.closed
+	return 0, io.EOF
+}
+
+// Write blocks until the device is closed, then returns 0 and io.EOF.
+func (t *nilTun) Write(data [][]byte, offset int) (int, error) {
+	<-t.closed
+	return 0, io.EOF
+}
+
+// Close closes both channels, releasing any blocked Read or Write.
+// Calling it a second time panics, because the channels are already closed.
+func (t *nilTun) Close() error {
+	close(t.events)
+	close(t.closed)
+	return nil
+}
+
+// BatchSize always reports 1.
+func (t *nilTun) BatchSize() int { return 1 }
+
+// A noopBind is a conn.Bind that does no actual binding work: every method
+// returns immediately with a fixed result.
+type noopBind struct{}
+
+// Open returns no receive functions and reports port 1, whatever port was
+// requested.
+func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+	actualPort = 1
+	return fns, actualPort, err
+}
+
+// Close does nothing and returns nil.
+func (noopBind) Close() error { return nil }
+
+// SetMark ignores mark and returns nil.
+func (noopBind) SetMark(mark uint32) error { return nil }
+
+// Send discards b and returns nil.
+func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil }
+
+// ParseEndpoint wraps s in a dummyEndpoint without validating it.
+func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+	ep := dummyEndpoint(s)
+	return ep, nil
+}
+
+// BatchSize always reports 1.
+func (noopBind) BatchSize() int { return 1 }
+
+// A dummyEndpoint is a string holding the endpoint destination.
+// Only DstToString reads that string; every other method returns a fixed
+// empty value.
+type dummyEndpoint string
+
+// ClearSrc does nothing.
+func (dummyEndpoint) ClearSrc() {}
+
+// SrcToString always returns the empty string.
+func (dummyEndpoint) SrcToString() string { return "" }
+
+// DstToString returns the destination string the endpoint was created from.
+func (e dummyEndpoint) DstToString() string { return string(e) }
+
+// DstToBytes always returns nil.
+func (dummyEndpoint) DstToBytes() []byte { return nil }
+
+// DstIP always returns the zero netip.Addr.
+func (dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} }
+
+// SrcIP always returns the zero netip.Addr.
+func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index ec3d008f7..553aaecbb 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. 
-func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. 
- var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. 
- peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package wgcfg
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "strconv"
+ "strings"
+
+ "go4.org/mem"
+ "tailscale.com/types/key"
+)
+
+// ParseError reports a piece of text that could not be parsed.
+type ParseError struct {
+	why      string // short description of the problem
+	offender string // the text that caused it
+}
+
+// Error formats the error as the description followed by the quoted
+// offending text.
+func (e *ParseError) Error() string {
+	return fmt.Sprintf("%s: %q", e.why, e.offender)
+}
+
+// parseEndpoint splits s, which must have the form "host:port", into its
+// host and numeric port.
+//
+// An IPv6 host must be written in square brackets ("[::1]:51820"); the
+// brackets are stripped from the returned host. A host that contains a colon
+// or a bracket without being a complete bracketed IPv6 literal is rejected
+// with a *ParseError, as is a missing port separator or an empty host. A
+// port that is not a decimal number fitting in 16 bits yields the error from
+// strconv.ParseUint.
+func parseEndpoint(s string) (host string, port uint16, err error) {
+	i := strings.LastIndexByte(s, ':')
+	if i < 0 {
+		return "", 0, &ParseError{"Missing port from endpoint", s}
+	}
+	host, portStr := s[:i], s[i+1:]
+	if host == "" {
+		return "", 0, &ParseError{"Invalid endpoint host", host}
+	}
+	uport, err := strconv.ParseUint(portStr, 10, 16)
+	if err != nil {
+		return "", 0, err
+	}
+
+	opens, closes := host[0] == '[', host[len(host)-1] == ']'
+	if !opens && !closes && !strings.Contains(host, ":") {
+		// Plain host name or IPv4 address: nothing further to validate.
+		return host, uint16(uport), nil
+	}
+
+	// From here on the host must be a bracketed IPv6 literal. Any colon in
+	// an unbracketed host lands here, including a leading one, so "::1:80"
+	// is rejected the same way as "fe80::1:80" rather than being returned
+	// as the host "::1".
+	if opens && closes && len(host) >= 2 {
+		inner := host[1 : len(host)-1]
+		// Requiring a colon keeps bracketed IPv4 such as "[1.2.3.4]" out.
+		if strings.Contains(inner, ":") && net.ParseIP(inner) != nil {
+			return inner, uint16(uport), nil
+		}
+	}
+	return "", 0, &ParseError{"Brackets must contain an IPv6 address", host}
+}
+
+// memROCut splits s around the first occurrence of sep. If sep is present it
+// returns the part before it, the part after it, and true. Otherwise it
+// returns two empty ROs and false.
+func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) {
+	i := mem.IndexByte(s, sep)
+	if i < 0 {
+		// before and after are still their zero values here.
+		return before, after, false
+	}
+	return s.SliceTo(i), s.SliceFrom(i + 1), true
+}
+
+// FromUAPI generates a Config from r.
+// r should be generated by calling device.IpcGetOperation;
+// it is not compatible with other uapi streams.
+//
+// The input is read line by line as key=value pairs. Lines before the first
+// public_key line are handled as device settings; each public_key line
+// starts a new peer, and the lines after it are applied to that peer. Empty
+// lines are skipped. The first malformed or unexpected line aborts parsing
+// with an error and a nil Config.
+func FromUAPI(r io.Reader) (*Config, error) {
+	cfg := new(Config)
+	var (
+		peer     *Peer  // peer currently being filled in; nil until the first public_key line
+		inDevice = true // true until the first public_key line is seen
+	)
+
+	sc := bufio.NewScanner(r)
+	for sc.Scan() {
+		raw := sc.Bytes()
+		line := mem.B(raw)
+		if line.Len() == 0 {
+			continue
+		}
+		k, value, ok := memROCut(line, '=')
+		if !ok {
+			return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy())
+		}
+		// The same value again, as a plain byte slice of the scanner's buffer.
+		valueBytes := raw[k.Len()+1:]
+
+		var err error
+		switch {
+		case k.EqualString("public_key"):
+			// Load/create the peer we are now configuring.
+			inDevice = false
+			peer, err = cfg.handlePublicKeyLine(valueBytes)
+		case inDevice:
+			err = cfg.handleDeviceLine(k, value, valueBytes)
+		default:
+			err = cfg.handlePeerLine(peer, k, value, valueBytes)
+		}
+		if err != nil {
+			return nil, err
+		}
+	}
+	if err := sc.Err(); err != nil {
+		return nil, err
+	}
+	return cfg, nil
+}
+
+// handleDeviceLine applies one device-level key/value pair to cfg.
+// private_key is parsed into cfg.PrivateKey, listen_port and fwmark are
+// accepted and ignored, and any other key is an error. valueBytes is not
+// used.
+func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error {
+	if k.EqualString("private_key") {
+		// wireguard-go guarantees not to send zero value; private keys are already clamped.
+		var err error
+		cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value)
+		return err
+	}
+	if k.EqualString("listen_port") || k.EqualString("fwmark") {
+		// ignore
+		return nil
+	}
+	return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())
+}
+
+func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) {
+ p := Peer{}
+ var err error
+ p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes))
+ if err != nil {
+ return nil, err
+ }
+ cfg.Peers = append(cfg.Peers, p)
+ return &cfg.Peers[len(cfg.Peers)-1], nil
+}
+
+func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error {
+ switch {
+ case k.EqualString("endpoint"):
+ nk, err := key.ParseNodePublicUntyped(value)
+ if err != nil {
+ return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString())
+ }
+ // nk ought to equal peer.PublicKey.
+ // Under some rare circumstances, it might not. See corp issue #3016.
+ // Even if that happens, don't stop early, so that we can recover from it.
+ // Instead, note the value of nk so we can fix as needed.
+ peer.WGEndpoint = nk
+ case k.EqualString("persistent_keepalive_interval"):
+ n, err := mem.ParseUint(value, 10, 16)
+ if err != nil {
+ return err
+ }
+ peer.PersistentKeepalive = uint16(n)
+ case k.EqualString("allowed_ip"):
+ ipp := netip.Prefix{}
+ err := ipp.UnmarshalText(valueBytes)
+ if err != nil {
+ return err
+ }
+ peer.AllowedIPs = append(peer.AllowedIPs, ipp)
+ case k.EqualString("protocol_version"):
+ if !value.EqualString("1") {
+ return fmt.Errorf("invalid protocol version: %q", value.StringCopy())
+ }
+ case k.EqualString("replace_allowed_ips") ||
+ k.EqualString("preshared_key") ||
+ k.EqualString("last_handshake_time_sec") ||
+ k.EqualString("last_handshake_time_nsec") ||
+ k.EqualString("tx_bytes") ||
+ k.EqualString("rx_bytes"):
+ // ignore
+ default:
+ return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())
+ }
+ return nil
+}
diff --git a/wgengine/winnet/winnet_windows.go b/wgengine/winnet/winnet_windows.go index 283ce5ad1..01e38517d 100644 --- a/wgengine/winnet/winnet_windows.go +++ b/wgengine/winnet/winnet_windows.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package winnet - -import ( - "fmt" - "syscall" - "unsafe" - - "github.com/go-ole/go-ole" -) - -func (v *INetworkConnection) GetAdapterId() (string, error) { - buf := ole.GUID{} - hr, _, _ := syscall.Syscall( - v.VTable().GetAdapterId, - 2, - uintptr(unsafe.Pointer(v)), - uintptr(unsafe.Pointer(&buf)), - 0) - if hr != 0 { - return "", fmt.Errorf("GetAdapterId failed: %08x", hr) - } - return buf.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winnet
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+
+ "github.com/go-ole/go-ole"
+)
+
+func (v *INetworkConnection) GetAdapterId() (string, error) {
+ buf := ole.GUID{}
+ hr, _, _ := syscall.Syscall(
+ v.VTable().GetAdapterId,
+ 2,
+ uintptr(unsafe.Pointer(v)),
+ uintptr(unsafe.Pointer(&buf)),
+ 0)
+ if hr != 0 {
+ return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
+ }
+ return buf.String(), nil
+}
|
