summaryrefslogtreecommitdiffhomepage
path: root/util
diff options
context:
space:
mode:
authorMarwan Sulaiman <marwan@tailscale.com>2024-07-23 12:57:14 -0400
committerMarwan Sulaiman <marwan@tailscale.com>2024-07-23 12:57:14 -0400
commit2d15835bb30d358d1894f3cdc39bc899dd50ff6a (patch)
tree00a5d37bb6aab13ef905dca7b6e0400054b28451 /util
parent57856fc0d5cbffdb81d28b0dd94e5ab2110fd58c (diff)
downloadtailscale-marwan/offunc.tar.xz
tailscale-marwan/offunc.zip
util/set: add SetOfFuncmarwan/offunc
Fixes #12901 Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
Diffstat (limited to 'util')
-rw-r--r--util/set/set.go15
-rw-r--r--util/set/set_test.go16
2 files changed, 31 insertions, 0 deletions
diff --git a/util/set/set.go b/util/set/set.go
index eb0697536..acdb2e034 100644
--- a/util/set/set.go
+++ b/util/set/set.go
@@ -17,6 +17,21 @@ func SetOf[T comparable](slice []T) Set[T] {
return Of(slice...)
}
+// SetOfFunc returns a set based on the given slice and func. It is helpful
+// when you want to turn a slice of a non-comparable type to a comparable aspect of it.
+// For example, to turn a slice of "users" to a set by their user ID field you can do something like:
+/*
+ users := []*User{...}
+ userIDs := SetOfFunc(users, func(u *User) string { return u.ID }) // returns set.Set[string]
+*/
+func SetOfFunc[T any, K comparable](slice []T, f func(T) K) Set[K] {
+ s := make(Set[K])
+ for _, e := range slice {
+ s.Add(f(e))
+ }
+ return s
+}
+
// Of returns a new set constructed from the elements in slice.
func Of[T comparable](slice ...T) Set[T] {
s := make(Set[T])
diff --git a/util/set/set_test.go b/util/set/set_test.go
index 85913ad24..4fe056d1e 100644
--- a/util/set/set_test.go
+++ b/util/set/set_test.go
@@ -64,6 +64,22 @@ func TestSetOf(t *testing.T) {
}
}
+func TestSetOfFunc(t *testing.T) {
+ type T struct {
+ ID string
+ }
+ ts := []*T{{"one"}, {"two"}, {"three"}, {"four"}}
+ s := SetOfFunc(ts, func(t *T) string { return t.ID })
+ if s.Len() != 4 {
+ t.Errorf("wrong len %d; want 4", s.Len())
+ }
+ for _, e := range []string{"one", "two", "three", "four"} {
+ if !s.Contains(e) {
+ t.Errorf("should contain %s", e)
+ }
+ }
+}
+
func TestEqual(t *testing.T) {
type test struct {
name string