diff options
| -rw-r--r-- | control/controlhttp/client.go | 2 | ||||
| -rw-r--r-- | util/multierr/multierr.go | 38 | ||||
| -rw-r--r-- | util/multierr/multierr_test.go | 55 |
3 files changed, 94 insertions, 1 deletions
diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 9b1d5a1a5..bf8d973e7 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -246,7 +246,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { results[i].conn = nil // so we don't close it in the defer return conn, nil } - merr := multierr.New(errs...) + merr := multierr.New(multierr.DeduplicateContextErrors(errs)...) // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error()) diff --git a/util/multierr/multierr.go b/util/multierr/multierr.go index 93ca068f5..42be001c6 100644 --- a/util/multierr/multierr.go +++ b/util/multierr/multierr.go @@ -6,6 +6,7 @@ package multierr import ( + "context" "errors" "slices" "strings" @@ -134,3 +135,40 @@ func Range(err error, fn func(error) bool) bool { } return true } + +// DeduplicateContextErrors returns a new slice of errors with at most one +// occurrence of each [context.Canceled] or [context.DeadlineExceeded], if one +// or more of them are present in the input slice. +// +// All other non-nil errors are returned as-is; nil errors are skipped. +func DeduplicateContextErrors(errs []error) []error { + // preserve nil/non-nil distinction + if errs == nil { + return nil + } else if len(errs) == 0 { + return []error{} + } + + var ( + ret []error + sawCanceled, sawDeadline bool + ) + for _, err := range errs { + if err == nil { + continue + } + if errors.Is(err, context.Canceled) { + if sawCanceled { + continue + } + sawCanceled = true + } else if errors.Is(err, context.DeadlineExceeded) { + if sawDeadline { + continue + } + sawDeadline = true + } + ret = append(ret, err) + } + return ret +} diff --git a/util/multierr/multierr_test.go b/util/multierr/multierr_test.go index de7721a66..ea8c56637 100644 --- a/util/multierr/multierr_test.go +++ b/util/multierr/multierr_test.go @@ -4,6 +4,7 @@ package multierr_test import ( + "context" "errors" "fmt" "io" @@ -107,6 +108,60 @@ func TestRange(t *testing.T) { })), want) } +func TestDeduplicateContextErrors(t *testing.T) { + testError := errors.New("test error") + + tests := []struct { + name string + input []error + want []error + }{ + {name: "nil", input: nil, want: nil}, + {name: "empty", input: []error{}, want: []error{}}, + {name: "single", input: []error{testError}, want: []error{testError}}, + { + name: "duplicate_non_context", + input: []error{testError, testError}, + want: []error{testError, testError}, + }, + { + name: "single_context", + input: []error{context.Canceled}, + want: []error{context.Canceled}, + }, + { + name: "duplicate_context", + input: []error{testError, context.Canceled, context.Canceled}, + want: []error{testError, context.Canceled}, + }, + { + name: "duplicate_context_mixed", + input: []error{ + testError, + context.Canceled, + context.Canceled, + testError, + context.DeadlineExceeded, + context.DeadlineExceeded, + }, + want: []error{ + testError, + context.Canceled, + testError, + context.DeadlineExceeded, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := multierr.DeduplicateContextErrors(tt.input) + if diff := cmp.Diff(tt.want, got, cmpopts.EquateErrors()); diff != "" { + t.Errorf("DeduplicateContextErrors() mismatch (-want +got):\n%s", diff) + } + }) + } +} + var sink error func BenchmarkEmpty(b *testing.B) { |
