summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrew Dunham <andrew@du.nham.ca>2023-05-18 16:38:43 -0400
committerAndrew Dunham <andrew@du.nham.ca>2023-05-20 19:12:20 -0400
commit98b15d46d9767696f4823272b47d01edb51d3521 (patch)
treeb72af37a05bfbcc6d656d1a6a37e48169001abd6
parent0ca8bf1e26514c37999a9f111b192b8ce7dc4167 (diff)
downloadtailscale-andrew/slicesx-deduplicate.tar.xz
tailscale-andrew/slicesx-deduplicate.zip
util/slicesx: add Deduplicate/DeduplicateFuncandrew/slicesx-deduplicate
These functions allow deduplicating elements in a slice, either via direct comparison or via a function that returns a key to be used for comparison. Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Ie6a20acf0431247487ac5ead110d56580dacfee4
-rw-r--r--util/slicesx/slicesx.go60
-rw-r--r--util/slicesx/slicesx_test.go62
2 files changed, 122 insertions, 0 deletions
diff --git a/util/slicesx/slicesx.go b/util/slicesx/slicesx.go
index ce55594db..826bd173e 100644
--- a/util/slicesx/slicesx.go
+++ b/util/slicesx/slicesx.go
@@ -42,3 +42,63 @@ func Shuffle[S ~[]T, T any](s S) {
s[i], s[j] = s[j], s[i]
}
}
+
+// Deduplicate removes duplicate elements from the provided slice, compared as
+// if using the == operator. The slice is modified and returned, similar to the
+// append function.
+func Deduplicate[S ~[]T, T comparable](s S) S {
+ // Avoid allocs on empty slices
+ if s == nil {
+ return nil
+ }
+
+ var (
+ ret = s[:0]
+ seen = make(map[T]bool)
+ )
+ for _, elem := range s {
+ if seen[elem] {
+ continue
+ }
+ seen[elem] = true
+ ret = append(ret, elem)
+ }
+
+ // Zero out elements remaining at end of existing slice.
+ var zero T
+ for i := len(ret); i < len(s); i++ {
+ s[i] = zero
+ }
+
+ return ret
+}
+
+// DeduplicateFunc is the same as Deduplicate, but uses the provided function
+// to provide a key that is used for deduplication.
+func DeduplicateFunc[S ~[]T, T any, K comparable](s S, fn func(T) K) S {
+ // Avoid allocs on empty slices
+ if s == nil {
+ return nil
+ }
+
+ var (
+ ret = s[:0]
+ seen = make(map[K]bool)
+ )
+ for _, elem := range s {
+ key := fn(elem)
+ if seen[key] {
+ continue
+ }
+ seen[key] = true
+ ret = append(ret, elem)
+ }
+
+ // Zero out elements remaining at end of existing slice.
+ var zero T
+ for i := len(ret); i < len(s); i++ {
+ s[i] = zero
+ }
+
+ return ret
+}
diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go
index 1d6062d6a..8623207ce 100644
--- a/util/slicesx/slicesx_test.go
+++ b/util/slicesx/slicesx_test.go
@@ -44,6 +44,7 @@ func BenchmarkInterleave(b *testing.B) {
)
}
}
+
func TestShuffle(t *testing.T) {
var sl []int
for i := 0; i < 100; i++ {
@@ -64,3 +65,64 @@ func TestShuffle(t *testing.T) {
t.Errorf("expected shuffle after 10 tries")
}
}
+
+func TestDeduplicate(t *testing.T) {
+ testCases := []struct {
+ name string
+ ss []int
+ want []int
+ }{
+ {name: "no_dupes", ss: []int{1, 2, 3, 4}, want: []int{1, 2, 3, 4}},
+ {name: "ordered_dupes", ss: []int{1, 1, 2, 2, 3, 1}, want: []int{1, 2, 3}},
+ {name: "unordered_dupes", ss: []int{1, 2, 3, 1, 2, 3}, want: []int{1, 2, 3}},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := Deduplicate(tc.ss)
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("got %v; want %v", got, tc.want)
+ }
+ })
+ }
+}
+
+func TestDeduplicateFunc(t *testing.T) {
+ type uncomparable struct {
+ _ [0]map[string]int
+ key string
+ }
+
+ testCases := []struct {
+ name string
+ ss []uncomparable
+ want []uncomparable
+ }{
+ {
+ name: "no_dupes",
+ ss: []uncomparable{{key: "one"}, {key: "two"}},
+ want: []uncomparable{{key: "one"}, {key: "two"}},
+ },
+ {
+ name: "ordered_dupes",
+ ss: []uncomparable{{key: "one"}, {key: "one"}, {key: "two"}, {key: "two"}, {key: "two"}},
+ want: []uncomparable{{key: "one"}, {key: "two"}},
+ },
+ {
+ name: "unordered_dupes",
+ ss: []uncomparable{{key: "one"}, {key: "two"}, {key: "one"}, {key: "two"}, {key: "one"}},
+ want: []uncomparable{{key: "one"}, {key: "two"}},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := DeduplicateFunc(tc.ss, func(uu uncomparable) string {
+ return uu.key
+ })
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("got %v; want %v", got, tc.want)
+ }
+ })
+ }
+}