summaryrefslogtreecommitdiffhomepage
path: root/control/controlclient/controlclient_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'control/controlclient/controlclient_test.go')
-rw-r--r--control/controlclient/controlclient_test.go112
1 files changed, 112 insertions, 0 deletions
diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go
index 2205a0eb3..5c25af0f4 100644
--- a/control/controlclient/controlclient_test.go
+++ b/control/controlclient/controlclient_test.go
@@ -406,6 +406,118 @@ func testHTTPS(t *testing.T, withProxy bool) {
}
}
+// TestRegisterRateLimited verifies that the client correctly handles 429
+// responses to registration requests by parsing the Retry-After header
+// and returning a rateLimitError.
+func TestRegisterRateLimited(t *testing.T) {
+ bakedroots.ResetForTest(t, tlstest.TestRootCA())
+
+ bus := eventbustest.NewBus(t)
+
+ controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer controlLn.Close()
+
+ var registerAttempts atomic.Int64
+ tc := &testcontrol.Server{
+ Logf: tstest.WhileTestRunningLogger(t),
+ MaybeRateLimitRegister: func() (bool, string, string) {
+ if registerAttempts.Add(1) == 1 {
+ return true, "30", "try again later"
+ }
+ return false, "", ""
+ },
+ }
+ controlSrv := &http.Server{
+ Handler: tc,
+ ErrorLog: logger.StdLogger(t.Logf),
+ }
+ go controlSrv.Serve(controlLn)
+
+ const fakeControlIP = "1.2.3.4"
+
+ dialer := &tsdial.Dialer{}
+ dialer.SetNetMon(netmon.NewStatic())
+ dialer.SetBus(bus)
+ dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
+ }
+ var d net.Dialer
+ if host == fakeControlIP {
+ return d.DialContext(ctx, network, controlLn.Addr().String())
+ }
+ return nil, fmt.Errorf("unexpected dial to %q", addr)
+ })
+
+ opts := Options{
+ Persist: persist.Persist{},
+ GetMachinePrivateKey: func() (key.MachinePrivate, error) {
+ return key.NewMachine(), nil
+ },
+ ServerURL: "https://controlplane.tstest",
+ Clock: tstime.StdClock{},
+ Hostinfo: &tailcfg.Hostinfo{
+ BackendLogID: "test-backend-log-id",
+ },
+ DiscoPublicKey: key.NewDisco().Public(),
+ Logf: t.Logf,
+ HealthTracker: health.NewTracker(bus),
+ PopBrowserURL: func(url string) {
+ t.Logf("PopBrowserURL: %q", url)
+ },
+ Dialer: dialer,
+ Bus: bus,
+ }
+ d, err := NewDirect(opts)
+ if err != nil {
+ t.Fatalf("NewDirect: %v", err)
+ }
+
+ d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
+ if host == "controlplane.tstest" {
+ return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
+ }
+ t.Errorf("unexpected DNS query for %q", host)
+ return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // First attempt should get a 429 and return a rateLimitError.
+ _, err = d.TryLogin(ctx, LoginEphemeral)
+ if err == nil {
+ t.Fatal("expected rate limit error on first attempt, got nil")
+ }
+ var rle *rateLimitError
+ if !errors.As(err, &rle) {
+ t.Fatalf("expected *rateLimitError, got %T: %v", err, err)
+ }
+ if rle.retryAfter != 30*time.Second {
+ t.Errorf("retryAfter = %v, want 30s", rle.retryAfter)
+ }
+ if rle.msg != "try again later" {
+ t.Errorf("msg = %q, want %q", rle.msg, "try again later")
+ }
+
+ // Second attempt should succeed (server no longer rate-limiting).
+ url, err := d.TryLogin(ctx, LoginEphemeral)
+ if err != nil {
+ t.Fatalf("TryLogin after rate limit: %v", err)
+ }
+ if url != "" {
+ t.Errorf("got URL %q, want empty", url)
+ }
+
+ if got := registerAttempts.Load(); got != 2 {
+ t.Errorf("register attempts = %d, want 2", got)
+ }
+}
+
func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != target {