summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorNick Khyl <nickk@tailscale.com>2026-02-19 08:50:32 -0600
committerNick Khyl <nickk@tailscale.com>2026-02-19 08:50:32 -0600
commit07fe367032c7715fd896210cde2d396430d75308 (patch)
tree929777d3a86523dc630314f6700131bb3bf7f542
parent4ec71f64801f2a805d8b7c7a223de77b91c6760e (diff)
downloadtailscale-nickkhyl/rioconn.tar.xz
tailscale-nickkhyl/rioconn.zip
net/rioconn/rioperf: add CLI to measure UDPConn performancenickkhyl/rioconn
Updates tailscale/corp#8610 Signed-off-by: Nick Khyl <nickk@tailscale.com>
-rw-r--r--net/rioconn/rioperf/main.go329
1 files changed, 329 insertions, 0 deletions
diff --git a/net/rioconn/rioperf/main.go b/net/rioconn/rioperf/main.go
new file mode 100644
index 000000000..8c4017ade
--- /dev/null
+++ b/net/rioconn/rioperf/main.go
@@ -0,0 +1,329 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+// Command rioperf is a simple traffic generator and performance test tool for rioconn.UDPConn.
+//
+// Run with "server" to start the server, or "client <dest>" to send traffic to it.
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "math"
+ "net"
+ "net/netip"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/peterbourgon/ff/v3/ffcli"
+ "golang.org/x/net/ipv6"
+ "tailscale.com/net/packet"
+ "tailscale.com/net/rioconn"
+)
+
+func main() {
+ root := &ffcli.Command{
+ Exec: func(ctx context.Context, args []string) error {
+ return flag.ErrHelp
+ },
+ ShortUsage: "rioperf <client|server> [flags]",
+ ShortHelp: "rioconn performance test tool",
+ FlagSet: flag.NewFlagSet("rioperf", flag.ExitOnError),
+ Subcommands: []*ffcli.Command{
+ {
+ Name: "server",
+ ShortUsage: "server [flags]",
+ ShortHelp: "Start a server",
+ FlagSet: buildServerFlags(),
+ Exec: func(ctx context.Context, args []string) error {
+ runServer()
+ return nil
+ },
+ },
+ {
+ Name: "client",
+ ShortUsage: "client <dest> [flags]",
+ ShortHelp: "Start a client",
+ FlagSet: buildClientFlags(),
+ Exec: func(ctx context.Context, args []string) error {
+ if len(args) < 1 {
+ return flag.ErrHelp
+ }
+ destStr := args[0]
+ if !strings.Contains(destStr, ":") {
+ destStr += ":0"
+ }
+ destUDPAddr, err := net.ResolveUDPAddr("udp", destStr)
+ if err != nil {
+ return fmt.Errorf("invalid destination address: %w", err)
+ }
+ if destUDPAddr.Port == 0 && clientFlags.port != 0 {
+ destUDPAddr.Port = clientFlags.port
+ }
+ destAddrPort := destUDPAddr.AddrPort()
+ destAddrPort = netip.AddrPortFrom(
+ destAddrPort.Addr().Unmap(),
+ destAddrPort.Port(),
+ )
+ runClient(destAddrPort)
+ return nil
+ },
+ },
+ },
+ }
+
+ if err := root.ParseAndRun(context.Background(), os.Args[1:]); err != nil {
+ fmt.Fprintf(os.Stderr, "error: %v\n", err)
+ os.Exit(1)
+ }
+}
+
+var serverFlags struct {
+ port int
+ batch int
+ maxMsgLen int
+ uro bool
+ memLimit uintptr
+}
+
+func buildServerFlags() *flag.FlagSet {
+ fs := flag.NewFlagSet("server", flag.ExitOnError)
+ fs.IntVar(&serverFlags.port, "port", 9999, "UDP port to listen on")
+ fs.IntVar(&serverFlags.batch, "batch", 64, "number of packets to read per batch")
+ fs.IntVar(&serverFlags.maxMsgLen, "maxlen", math.MaxUint16, "maximum length of each UDP message to receive")
+ fs.BoolVar(&serverFlags.uro, "uro", false, "enable UDP RSC Offload (URO)")
+ fs.Func("mem", "memory limit for RIO buffers (e.g. 512k, 1m, 16m)", func(s string) error {
+ var err error
+ serverFlags.memLimit, err = parseSize(s)
+ return err
+ })
+ return fs
+}
+
+var clientFlags struct {
+ port int
+ batchSize int
+ msgLen int
+ uso bool
+ memLimit uintptr
+}
+
+func buildClientFlags() *flag.FlagSet {
+ fs := flag.NewFlagSet("client", flag.ExitOnError)
+ fs.IntVar(&clientFlags.port, "port", 9999, "UDP port to send to")
+ fs.IntVar(&clientFlags.batchSize, "batch", 64, "number of packets to send per batch")
+ fs.IntVar(&clientFlags.msgLen, "len", 1312, "length of each UDP message to send")
+ fs.BoolVar(&clientFlags.uso, "uso", true, "enable UDP Segmentation Offload (USO)")
+ fs.Func("mem", "memory limit for RIO buffers (e.g. 512k, 1m, 16m)", func(s string) error {
+ var err error
+ clientFlags.memLimit, err = parseSize(s)
+ return err
+ })
+ return fs
+}
+
+func runServer() {
+ opts := []rioconn.UDPOption{
+ rioconn.URO(serverFlags.uro),
+ rioconn.RxMemoryLimit(serverFlags.memLimit),
+ rioconn.TxMemoryLimit(serverFlags.memLimit),
+ rioconn.RxMaxPayloadLen(uintptr(serverFlags.maxMsgLen)),
+ }
+
+ conn, err := rioconn.ListenUDP("udp4", &net.UDPAddr{Port: serverFlags.port}, opts...)
+ if err != nil {
+ fmt.Printf("Failed to create connection: %v\n", err)
+ os.Exit(1)
+ }
+ defer conn.Close()
+
+ fmt.Printf("Server listening on %v\n", conn.LocalAddrPort())
+
+ msgs := make([]ipv6.Message, serverFlags.batch)
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].Buffers[0] = make([]byte, conn.Config().Rx().MaxPayloadLen())
+ }
+
+ var mu sync.Mutex
+ var sessionID uint64
+ var sessionPackets, sessionBytes uint64
+ var sessionStart time.Time
+ var lastReceived time.Time
+
+ go func() {
+ ticker := time.NewTicker(time.Second)
+ defer ticker.Stop()
+ for range ticker.C {
+ mu.Lock()
+ if !lastReceived.IsZero() && time.Since(lastReceived) >= 1*time.Second && sessionPackets > 0 {
+ dt := lastReceived.Sub(sessionStart)
+ if dt > 0 {
+ pps := float64(sessionPackets) / dt.Seconds()
+ bitrate, units := formatBitrate(float64(sessionBytes) * 8 / dt.Seconds())
+ fmt.Printf("[avg] %.2f Mpps %.2f %s over %.1fs\n", pps/1e6, bitrate, units, dt.Seconds())
+ }
+ sessionPackets = 0
+ sessionBytes = 0
+ lastReceived = time.Time{}
+ }
+ mu.Unlock()
+ }
+ }()
+
+ var totalPackets, totalBytes uint64
+ var lastPackets, lastBytes uint64
+ lastReport := time.Now()
+
+ for {
+ n, err := conn.ReadBatch(msgs, 0)
+ if err != nil {
+ fmt.Printf("ReadBatch error: %v\n", err)
+ os.Exit(1)
+ }
+
+ now := time.Now()
+ mu.Lock()
+ if lastReceived.IsZero() {
+ sessionStart = now
+ sessionID++
+ fmt.Printf("\nSession %d started by %v\n", sessionID, msgs[0].Addr)
+ }
+ for i := 0; i < n; i++ {
+ totalBytes += uint64(msgs[i].N)
+ sessionBytes += uint64(msgs[i].N)
+ }
+ totalPackets += uint64(n)
+ sessionPackets += uint64(n)
+ lastReceived = now
+ mu.Unlock()
+
+ if pktsDelta := totalPackets - lastPackets; pktsDelta > 1024 {
+ if now.Sub(lastReport) >= time.Second {
+ dt := now.Sub(lastReport)
+ bytesDelta := totalBytes - lastBytes
+ pps := float64(pktsDelta) / dt.Seconds()
+ bitrate, units := formatBitrate(float64(bytesDelta) * 8 / dt.Seconds())
+ fmt.Printf("%.2f Mpps %.2f %s\n", pps/1e6, bitrate, units)
+ lastPackets = totalPackets
+ lastBytes = totalBytes
+ lastReport = now
+ }
+ }
+ }
+}
+
+func runClient(destAddrPort netip.AddrPort) {
+ opts := []rioconn.UDPOption{
+ rioconn.USO(clientFlags.uso),
+ rioconn.RxMemoryLimit(clientFlags.memLimit),
+ rioconn.TxMemoryLimit(clientFlags.memLimit),
+ rioconn.TxMaxPayloadLen(uintptr(clientFlags.msgLen)),
+ }
+
+ conn, err := rioconn.ListenUDP("udp4", &net.UDPAddr{}, opts...)
+ if err != nil {
+ fmt.Printf("Failed to create connection: %v\n", err)
+ os.Exit(1)
+ }
+ defer conn.Close()
+
+ fmt.Printf("Client sending to %v from %v\n", destAddrPort, conn.LocalAddrPort())
+
+ buffs := make([][]byte, clientFlags.batchSize)
+ for i := range buffs {
+ buffs[i] = make([]byte, clientFlags.msgLen)
+ for j := range buffs[i] {
+ buffs[i][j] = byte(i + j)
+ }
+ }
+
+ var totalPackets, totalBytes uint64
+ var lastPackets, lastBytes uint64
+ start := time.Now()
+ lastReport := start
+ endTime := start.Add(30 * time.Second)
+
+ for {
+ if err := conn.WriteBatchTo(buffs, destAddrPort, packet.GeneveHeader{}, 0); err != nil {
+ fmt.Fprintf(os.Stderr, "WriteBatchTo error: %v\n", err)
+ break
+ }
+
+ totalPackets += uint64(len(buffs))
+ totalBytes += uint64(len(buffs) * clientFlags.msgLen)
+
+ if pktsDelta := totalPackets - lastPackets; pktsDelta > 1024 {
+ now := time.Now()
+ if now.After(endTime) {
+ break
+ }
+ if dt := now.Sub(lastReport); dt >= time.Second {
+ bytesDelta := totalBytes - lastBytes
+ pps := float64(pktsDelta) / dt.Seconds()
+ bitrate, units := formatBitrate(float64(bytesDelta) * 8 / dt.Seconds())
+ fmt.Printf("%.2f Mpps %.2f %s\n", pps/1e6, bitrate, units)
+ lastPackets = totalPackets
+ lastBytes = totalBytes
+ lastReport = now
+ }
+ }
+ }
+
+ elapsed := time.Since(start).Seconds()
+ pps := float64(totalPackets) / elapsed
+ bitrate, units := formatBitrate(float64(totalBytes) * 8 / elapsed)
+ fmt.Printf("\n[avg] %.2f Mpps %.2f %s over %.1fs\n", pps/1e6, bitrate, units, endTime.Sub(start).Seconds())
+}
+
+func parseSize(s string) (uintptr, error) {
+ if s = strings.TrimSpace(s); s == "" {
+ return 0, fmt.Errorf("empty size")
+ }
+ multiplier := uintptr(1)
+
+ last := s[len(s)-1]
+ switch last {
+ case 'k', 'K':
+ multiplier = 1 << 10
+ s = s[:len(s)-1]
+ case 'm', 'M':
+ multiplier = 1 << 20
+ s = s[:len(s)-1]
+ case 'g', 'G':
+ multiplier = 1 << 30
+ s = s[:len(s)-1]
+ }
+
+ n, err := strconv.ParseUint(s, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("invalid size %q: %w", s, err)
+ }
+
+ return uintptr(n) * multiplier, nil
+}
+
+func formatBitrate[T float64 | int64](bps T) (n float64, unit string) {
+ const (
+ Kbps = 1000
+ Mbps = 1000 * Kbps
+ Gbps = 1000 * Mbps
+ )
+
+ switch {
+ case float64(bps) >= 0.9*Gbps:
+ return float64(bps) / float64(Gbps), "Gbps"
+ case float64(bps) >= 0.9*Mbps:
+ return float64(bps) / float64(Mbps), "Mbps"
+ case float64(bps) >= 0.9*Kbps:
+ return float64(bps) / float64(Kbps), "Kbps"
+ default:
+ return float64(bps), "bps"
+ }
+}