diff options
| -rw-r--r-- | cmd/derper/depaware.txt | 1 | ||||
| -rw-r--r-- | cmd/derper/derper.go | 112 | ||||
| -rw-r--r-- | cmd/derper/stun.go | 170 | ||||
| -rw-r--r-- | cmd/derper/stun_test.go | 55 |
4 files changed, 265 insertions, 73 deletions
diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index e25e7b92c..e46ad22b7 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -262,6 +262,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa net/url from crypto/x509+ os from crypto/rand+ os/exec from golang.zx2c4.com/wireguard/windows/tunnel/winipcfg+ + os/signal from tailscale.com/cmd/derper W os/user from tailscale.com/util/winutil path from golang.org/x/crypto/acme/autocert+ path/filepath from crypto/x509+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index a757745ba..ac634ea89 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -17,11 +17,12 @@ import ( "math" "net" "net/http" - "net/netip" "os" + "os/signal" "path/filepath" "regexp" "strings" + "syscall" "time" "go4.org/mem" @@ -30,7 +31,6 @@ import ( "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/metrics" - "tailscale.com/net/stun" "tailscale.com/tsweb" "tailscale.com/types/key" "tailscale.com/util/cmpx" @@ -56,28 +56,17 @@ var ( acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") + + stunOnly = flag.Bool("stun-only", false, "only start a stun server (used by the stun subprocess spawner") + stunSubprocess = flag.Bool("stun-subprocess", false, "spawn the stun server as a sub-process rather than in the host process") ) var ( - stats = new(metrics.Set) - stunDisposition = &metrics.LabelMap{Label: "disposition"} - stunAddrFamily = &metrics.LabelMap{Label: "family"} tlsRequestVersion = &metrics.LabelMap{Label: "version"} tlsActiveVersion = &metrics.LabelMap{Label: "version"} - - stunReadError = stunDisposition.Get("read_error") - stunNotSTUN = stunDisposition.Get("not_stun") - stunWriteError = stunDisposition.Get("write_error") - stunSuccess = stunDisposition.Get("success") - - stunIPv4 = stunAddrFamily.Get("ipv4") - stunIPv6 = stunAddrFamily.Get("ipv6") ) func init() { - stats.Set("counter_requests", stunDisposition) - stats.Set("counter_addrfamily", stunAddrFamily) - expvar.Publish("stun", stats) expvar.Publish("derper_tls_request_version", tlsRequestVersion) expvar.Publish("gauge_derper_tls_active_version", tlsActiveVersion) } @@ -132,9 +121,22 @@ func writeNewConfig() config { return cfg } +func cancelOnSignal(cancelf func()) { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + cancelf() + }() +} + func main() { flag.Parse() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cancelOnSignal(cancel) + if *dev { *addr = ":3340" // above the keys DERP log.Printf("Running in dev mode.") @@ -146,6 +148,21 @@ func main() { log.Fatalf("invalid server address: %v", err) } + if *runSTUN { + if *stunSubprocess { + if *stunOnly { + log.SetPrefix(fmt.Sprintf("stun(%d) ", os.Getpid())) + log.Printf("Starting in stun-only mode.") + go printSTUNStats(ctx, os.Stdout, time.Second*10) + serveSTUN(ctx, listenHost, *stunPort) + return + } + go startChildSTUN(ctx) + } else { + go serveSTUN(ctx, listenHost, *stunPort) + } + } + cfg := loadConfig() serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual" @@ -221,10 +238,6 @@ func main() { })) debug.Handle("traffic", "Traffic check", http.HandlerFunc(s.ServeDebugTraffic)) - if *runSTUN { - go serveSTUN(listenHost, *stunPort) - } - quietLogger := log.New(logFilter{}, "", 0) httpsrv := &http.Server{ Addr: *addr, @@ -241,6 +254,12 @@ func main() { ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } + go func() { + <-ctx.Done() + timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + httpsrv.Shutdown(timeoutCtx) + }() if serveTLS { log.Printf("derper: serving on %s with TLS", *addr) @@ -351,59 +370,6 @@ func probeHandler(w http.ResponseWriter, r *http.Request) { } } -func serveSTUN(host string, port int) { - pc, err := net.ListenPacket("udp", net.JoinHostPort(host, fmt.Sprint(port))) - if err != nil { - log.Fatalf("failed to open STUN listener: %v", err) - } - log.Printf("running STUN server on %v", pc.LocalAddr()) - serverSTUNListener(context.Background(), pc.(*net.UDPConn)) -} - -func serverSTUNListener(ctx context.Context, pc *net.UDPConn) { - var buf [64 << 10]byte - var ( - n int - ua *net.UDPAddr - err error - ) - for { - n, ua, err = pc.ReadFromUDP(buf[:]) - if err != nil { - if ctx.Err() != nil { - return - } - log.Printf("STUN ReadFrom: %v", err) - time.Sleep(time.Second) - stunReadError.Add(1) - continue - } - pkt := buf[:n] - if !stun.Is(pkt) { - stunNotSTUN.Add(1) - continue - } - txid, err := stun.ParseBindingRequest(pkt) - if err != nil { - stunNotSTUN.Add(1) - continue - } - if ua.IP.To4() != nil { - stunIPv4.Add(1) - } else { - stunIPv6.Add(1) - } - addr, _ := netip.AddrFromSlice(ua.IP) - res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) - _, err = pc.WriteTo(res, ua) - if err != nil { - stunWriteError.Add(1) - } else { - stunSuccess.Add(1) - } - } -} - var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`) func prodAutocertHostPolicy(_ context.Context, host string) error { diff --git a/cmd/derper/stun.go b/cmd/derper/stun.go new file mode 100644 index 000000000..34ad200b6 --- /dev/null +++ b/cmd/derper/stun.go @@ -0,0 +1,170 @@ +package main + +import ( + "context" + "encoding/json" + "expvar" + "fmt" + "io" + "log" + "net" + "net/netip" + "os" + "os/exec" + "time" + + "tailscale.com/metrics" + "tailscale.com/net/stun" +) + +var ( + stats = new(metrics.Set) + stunDisposition = &metrics.LabelMap{Label: "disposition"} + stunAddrFamily = &metrics.LabelMap{Label: "family"} + stunReadError = stunDisposition.Get("read_error") + stunNotSTUN = stunDisposition.Get("not_stun") + stunWriteError = stunDisposition.Get("write_error") + stunSuccess = stunDisposition.Get("success") + + stunIPv4 = stunAddrFamily.Get("ipv4") + stunIPv6 = stunAddrFamily.Get("ipv6") +) + +// statsEntry is the structure of the JSON output of the above stats. +type statsEntry struct { + CounterAddrfamily struct { + Ipv4 int64 `json:"ipv4"` + Ipv6 int64 `json:"ipv6"` + } `json:"counter_addrfamily"` + CounterRequests struct { + NotStun int64 `json:"not_stun"` + ReadError int64 `json:"read_error"` + Success int64 `json:"success"` + WriteError int64 `json:"write_error"` + } `json:"counter_requests"` +} + +func (e *statsEntry) Set() { + stunIPv4.Set(e.CounterAddrfamily.Ipv4) + stunIPv6.Set(e.CounterAddrfamily.Ipv6) + stunNotSTUN.Set(e.CounterRequests.NotStun) + stunReadError.Set(e.CounterRequests.ReadError) + stunSuccess.Set(e.CounterRequests.Success) + stunWriteError.Set(e.CounterRequests.WriteError) +} + +func init() { + stats.Set("counter_requests", stunDisposition) + stats.Set("counter_addrfamily", stunAddrFamily) + expvar.Publish("stun", stats) +} + +// printSTUNStats prints STUN stats to w every d until ctx is done. +func printSTUNStats(ctx context.Context, w io.Writer, d time.Duration) { + ticker := time.NewTicker(d) + for { + expvar.Do(func(kv expvar.KeyValue) { + if kv.Key == "stun" { + fmt.Fprintf(w, "%s\n", kv.Value) + } + }) + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + } +} + +// readSTUNStats reads lines from r containing STUN statistics and updates matching expvar values. +func readSTUNStats(ctx context.Context, r io.Reader) { + d := json.NewDecoder(r) + var entry statsEntry + for { + if err := d.Decode(&entry); err != nil { + return + } + entry.Set() + if ctx.Err() != nil { + return + } + } +} + +// serveChildSTUN starts a stun server in a child process. If the process exits before context is done, serveChildSTUN will with a log entry. +func startChildSTUN(ctx context.Context) { + cmd := exec.Command(os.Args[0], append(os.Args[1:], "-stun-only=true")...) + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Fatalf("stun: failed to create stdout pipe: %v", err) + } + cmd.Stderr = os.Stderr + err = cmd.Start() + if err != nil { + log.Fatalf("stun: failed to start subprocess: %v", err) + } + readSTUNStats(ctx, stdout) + cmd.Process.Kill() + cmd.Process.Wait() + if ctx.Err() == nil { + log.Fatalf("stun: subprocess exited unexpectedly: %v", cmd.ProcessState) + } +} + +func serveSTUN(ctx context.Context, host string, port int) { + pc, err := net.ListenPacket("udp", net.JoinHostPort(host, fmt.Sprint(port))) + if err != nil { + log.Fatalf("failed to open STUN listener: %v", err) + } + log.Printf("running STUN server on %v", pc.LocalAddr()) + // close the listener on shutdown in order to rbeak out of the read loop + go func() { + <-ctx.Done() + pc.Close() + }() + serverSTUNListener(ctx, pc.(*net.UDPConn)) +} + +func serverSTUNListener(ctx context.Context, pc *net.UDPConn) { + var buf [64 << 10]byte + var ( + n int + ua *net.UDPAddr + err error + ) + for { + n, ua, err = pc.ReadFromUDP(buf[:]) + if err != nil { + if ctx.Err() != nil { + return + } + log.Printf("STUN ReadFrom: %v", err) + time.Sleep(time.Second) + stunReadError.Add(1) + continue + } + pkt := buf[:n] + if !stun.Is(pkt) { + stunNotSTUN.Add(1) + continue + } + txid, err := stun.ParseBindingRequest(pkt) + if err != nil { + stunNotSTUN.Add(1) + continue + } + if ua.IP.To4() != nil { + stunIPv4.Add(1) + } else { + stunIPv6.Add(1) + } + addr, _ := netip.AddrFromSlice(ua.IP) + res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) + _, err = pc.WriteTo(res, ua) + if err != nil { + stunWriteError.Add(1) + } else { + stunSuccess.Add(1) + } + } +} diff --git a/cmd/derper/stun_test.go b/cmd/derper/stun_test.go new file mode 100644 index 000000000..8ea3baa46 --- /dev/null +++ b/cmd/derper/stun_test.go @@ -0,0 +1,55 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "expvar" + "strings" + "testing" + "time" +) + +var allStats = []*expvar.Int{ + stunIPv4, + stunIPv6, + stunNotSTUN, + stunReadError, + stunSuccess, + stunWriteError, +} + +func TestStunStats(t *testing.T) { + doneCtx, cancel := context.WithCancel(context.Background()) + cancel() + + var buf bytes.Buffer + + for _, s := range allStats { + s.Set(5) + } + + printSTUNStats(doneCtx, &buf, time.Millisecond) + + for _, s := range allStats { + s.Set(10) + } + + readSTUNStats(doneCtx, &buf) + + for _, s := range allStats { + if s.Value() != 5 { + t.Errorf("expected %d, got %d", 5, s.Value()) + } + } +} + +func TestStatsEntryContainsAllFields(t *testing.T) { + s := stats.String() + var e statsEntry + d := json.NewDecoder(strings.NewReader(s)) + d.DisallowUnknownFields() + if err := d.Decode(&e); err != nil { + t.Fatal(err) + } +} |
