summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoe Tsai <joetsai@digital-static.net>2025-07-22 09:22:17 -1000
committerGitHub <noreply@github.com>2025-07-22 12:22:17 -0700
commit0de5e7b94f0bb89bcaed108f656d3ed50da85d02 (patch)
treef91e1073913ecd0d3bb107feaa2947f80c035b70
parent44947054967e3eda476c92206e0a14fd1ffc4ec0 (diff)
downloadtailscale-0de5e7b94f0bb89bcaed108f656d3ed50da85d02.tar.xz
tailscale-0de5e7b94f0bb89bcaed108f656d3ed50da85d02.zip
util/set: add IntSet (#16602)
IntSet is a set optimized for integers. Updates tailscale/corp#29809 Signed-off-by: Joe Tsai <joetsai@digital-static.net>
-rw-r--r--util/set/intset.go172
-rw-r--r--util/set/intset_test.go174
2 files changed, 346 insertions, 0 deletions
diff --git a/util/set/intset.go b/util/set/intset.go
new file mode 100644
index 000000000..b747d3bff
--- /dev/null
+++ b/util/set/intset.go
@@ -0,0 +1,172 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package set
+
+import (
+ "iter"
+ "maps"
+ "math/bits"
+ "math/rand/v2"
+
+ "golang.org/x/exp/constraints"
+ "tailscale.com/util/mak"
+)
+
+// IntSet is a set optimized for integer values close to zero
+// or for sets of integers that are close in value to one another.
+type IntSet[T constraints.Integer] struct {
+ // bits is a [bitSet] for elements whose zig-zag encoding is less than [bits.UintSize].
+ bits bitSet
+
+ // extra holds a [bitSet] for each group of encoded numbers not in bits,
+ // where the key is the encoded number divided by [bits.UintSize] and the bit index is the remainder.
+ extra map[uint64]bitSet
+
+ // extraLen is the count of numbers in extra since len(extra)
+ // does not reflect that each bitSet may have multiple numbers.
+ extraLen int
+}
+
+// Values returns an iterator over every element of the set.
+// Elements are produced in an unspecified order that callers must not rely on.
+func (s IntSet[T]) Values() iter.Seq[T] {
+	return func(yield func(T) bool) {
+		// Small encoded values live directly in s.bits.
+		if s.bits != 0 {
+			for enc := range s.bits.values() {
+				if !yield(decodeZigZag[T](enc)) {
+					return
+				}
+			}
+		}
+		// Larger encoded values are grouped by their quotient by bits.UintSize.
+		// Ranging over a nil map is a no-op, so no explicit nil check is needed.
+		for group, members := range s.extra {
+			base := group * bits.UintSize
+			for offset := range members.values() {
+				if !yield(decodeZigZag[T](base + offset)) {
+					return
+				}
+			}
+		}
+	}
+}
+
+// Contains reports whether e is in the set.
+func (s IntSet[T]) Contains(e T) bool {
+ if v := encodeZigZag(e); v < bits.UintSize {
+ return s.bits.contains(v)
+ } else {
+ hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize)
+ return s.extra[hi].contains(lo)
+ }
+}
+
+// Add adds e to the set. Adding an element that is already present is a no-op.
+//
+// When storing an IntSet in a map as a value type,
+// it is important to re-assign the map entry after calling Add or Delete,
+// as the IntSet's representation may change.
+func (s *IntSet[T]) Add(e T) {
+	v := encodeZigZag(e)
+	if v < bits.UintSize {
+		s.bits.add(v)
+		return
+	}
+	hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize)
+	bs := s.extra[hi] // zero bitSet if hi is absent or s.extra is nil
+	if bs.contains(lo) {
+		return
+	}
+	bs.add(lo)
+	// mak.Set allocates s.extra on first use and stores the entry,
+	// so no separate map assignment is needed afterwards.
+	mak.Set(&s.extra, hi, bs)
+	s.extraLen++
+}
+
+// AddSeq inserts every value produced by seq into the set.
+func (s *IntSet[T]) AddSeq(seq iter.Seq[T]) {
+	for v := range seq {
+		s.Add(v)
+	}
+}
+
+// Len reports how many elements the set currently holds.
+func (s IntSet[T]) Len() int {
+	inline := s.bits.len() // elements stored directly in s.bits
+	return inline + s.extraLen
+}
+
+// Delete removes e from the set. Deleting an absent element is a no-op.
+//
+// When storing an IntSet in a map as a value type,
+// it is important to re-assign the map entry after calling Add or Delete,
+// as the IntSet's representation may change.
+func (s *IntSet[T]) Delete(e T) {
+	v := encodeZigZag(e)
+	if v < bits.UintSize {
+		s.bits.delete(v)
+		return
+	}
+	hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize)
+	bs := s.extra[hi] // zero bitSet if hi is absent or s.extra is nil
+	if !bs.contains(lo) {
+		return
+	}
+	bs.delete(lo)
+	if bs == 0 {
+		// Drop emptied groups so the map does not accumulate dead entries
+		// that Values and Clone would otherwise keep visiting.
+		delete(s.extra, hi)
+	} else {
+		// s.extra is non-nil here because it just yielded a non-empty bitSet.
+		s.extra[hi] = bs
+	}
+	s.extraLen--
+}
+
+// Clone returns an independent copy of s; mutating either set
+// afterwards does not affect the other.
+func (s IntSet[T]) Clone() IntSet[T] {
+	// s is already a by-value copy of the receiver, so bits and extraLen
+	// are detached; only the map needs to be duplicated.
+	s.extra = maps.Clone(s.extra)
+	return s
+}
+
+// bitSet is a set of small integers in the range [0, bits.UintSize),
+// stored as one bit per member in a single machine word.
+type bitSet uint
+
+// values returns an iterator over the members of s.
+// Each run of the iterator picks ascending or descending order at random.
+func (s bitSet) values() iter.Seq[uint64] {
+	return func(yield func(uint64) bool) {
+		// Hyrum-proofing: randomly iterate in forwards or reverse.
+		if rand.Uint64()%2 == 0 {
+			for i := 0; i < bits.UintSize; i++ {
+				if s.contains(uint64(i)) && !yield(uint64(i)) {
+					return
+				}
+			}
+		} else {
+			// Start at the highest valid bit index;
+			// bits.UintSize itself is one past the end of the word.
+			for i := bits.UintSize - 1; i >= 0; i-- {
+				if s.contains(uint64(i)) && !yield(uint64(i)) {
+					return
+				}
+			}
+		}
+	}
+}
+
+// len reports the number of members in s.
+func (s bitSet) len() int { return bits.OnesCount(uint(s)) }
+
+// contains reports whether i is a member of s.
+// i is expected to be less than bits.UintSize.
+func (s bitSet) contains(i uint64) bool { return s&(1<<i) != 0 }
+
+// add inserts i into s. i is expected to be less than bits.UintSize.
+func (s *bitSet) add(i uint64) { *s |= 1 << i }
+
+// delete removes i from s. i is expected to be less than bits.UintSize.
+func (s *bitSet) delete(i uint64) { *s &= ^(1 << i) }
+
+// encodeZigZag maps an integer onto a uint64 such that values of small
+// magnitude, whether negative or positive, receive small encodings.
+// Unsigned inputs are passed through unchanged.
+func encodeZigZag[T constraints.Integer](v T) uint64 {
+	// The complement of zero is the maximum value for an unsigned T
+	// (hence non-negative) and -1 for a signed T.
+	var zero T
+	if isUnsigned := ^zero >= 0; isUnsigned {
+		return uint64(v)
+	}
+	// See [google.golang.org/protobuf/encoding/protowire.EncodeZigZag]
+	x := int64(v)
+	return uint64(x<<1) ^ uint64(x>>63)
+}
+
+// decodeZigZag is the inverse of encodeZigZag: it turns an encoded uint64
+// back into the integer it represents.
+// For unsigned T, the value is converted verbatim.
+func decodeZigZag[T constraints.Integer](v uint64) T {
+	var zero T
+	if ^zero >= 0 { // T is an unsigned type
+		return T(v)
+	}
+	// See [google.golang.org/protobuf/encoding/protowire.DecodeZigZag]
+	magnitude := int64(v >> 1)
+	signMask := int64(v) << 63 >> 63 // all ones when the low bit of v is set, else zero
+	return T(magnitude ^ signMask)
+}
diff --git a/util/set/intset_test.go b/util/set/intset_test.go
new file mode 100644
index 000000000..9523fe88d
--- /dev/null
+++ b/util/set/intset_test.go
@@ -0,0 +1,174 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package set
+
+import (
+ "maps"
+ "math"
+ "slices"
+ "testing"
+
+ "golang.org/x/exp/constraints"
+)
+
+// TestIntSet drives an IntSet and a reference Set through the same
+// sequence of additions and deletions and checks that they stay in agreement.
+func TestIntSet(t *testing.T) {
+	t.Run("Int64", func(t *testing.T) {
+		ref := make(Set[int64])
+		var ints IntSet[int64]
+		add := func(vs ...int64) {
+			t.Helper()
+			for _, v := range vs {
+				addInt(t, ref, &ints, v)
+			}
+		}
+		del := func(vs ...int64) {
+			t.Helper()
+			for _, v := range vs {
+				deleteInt(t, ref, &ints, v)
+			}
+		}
+		check := func() {
+			t.Helper()
+			intValues(t, ref, ints)
+		}
+
+		check()
+		del(-5, 2, 75) // deleting from an empty set
+		check()
+		add(2, 75, 75, -3, -3, -3, math.MinInt64, 8)
+		check()
+		add(77, 76, 76, 76)
+		check()
+		add(-5, 7, -83, math.MaxInt64)
+		check()
+		del(-5, 2, 75)
+		check()
+		del(math.MinInt64, math.MaxInt64)
+		check()
+	})
+
+	t.Run("Uint64", func(t *testing.T) {
+		ref := make(Set[uint64])
+		var ints IntSet[uint64]
+		add := func(vs ...uint64) {
+			t.Helper()
+			for _, v := range vs {
+				addInt(t, ref, &ints, v)
+			}
+		}
+		del := func(vs ...uint64) {
+			t.Helper()
+			for _, v := range vs {
+				deleteInt(t, ref, &ints, v)
+			}
+		}
+		check := func() {
+			t.Helper()
+			intValues(t, ref, ints)
+		}
+
+		check()
+		del(5, 2, 75) // deleting from an empty set
+		check()
+		add(2, 75, 75, 3, 3, 8)
+		check()
+		add(77, 76, 76, 76)
+		check()
+		add(5, 7, 83, math.MaxInt64)
+		check()
+		del(5, 2, 75)
+		check()
+		del(math.MaxInt64)
+		check()
+	})
+}
+
+// intValues verifies that the IntSet under test (si) holds exactly the same
+// elements, and reports the same length, as the reference Set (ss).
+func intValues[T constraints.Integer](t testing.TB, ss Set[T], si IntSet[T]) {
+	t.Helper()
+	// got comes from the implementation under test, want from the reference.
+	// Both iterate in unspecified order, so sort before comparing.
+	got := slices.Collect(si.Values())
+	slices.Sort(got)
+	want := slices.Collect(maps.Keys(ss))
+	slices.Sort(want)
+	if !slices.Equal(got, want) {
+		t.Fatalf("Values mismatch:\n\tgot %v\n\twant %v", got, want)
+	}
+	if got, want := si.Len(), ss.Len(); got != want {
+		t.Fatalf("Len() = %v, want %v", got, want)
+	}
+}
+
+// addInt adds v to both the reference Set and the IntSet under test,
+// checking membership agreement beforehand and membership and length afterwards.
+func addInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) {
+	t.Helper()
+	before := si.Contains(v)
+	if expect := ss.Contains(v); before != expect {
+		t.Fatalf("Contains(%v) = %v, want %v", v, before, expect)
+	}
+
+	ss.Add(v)
+	si.Add(v)
+
+	if present := si.Contains(v); !present {
+		t.Fatalf("Contains(%v) = false, want true", v)
+	}
+	gotLen, wantLen := si.Len(), ss.Len()
+	if gotLen != wantLen {
+		t.Fatalf("Len() = %v, want %v", gotLen, wantLen)
+	}
+}
+
+// deleteInt removes v from both the reference Set and the IntSet under test,
+// checking membership agreement beforehand and absence and length afterwards.
+func deleteInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) {
+	t.Helper()
+	before := si.Contains(v)
+	if expect := ss.Contains(v); before != expect {
+		t.Fatalf("Contains(%v) = %v, want %v", v, before, expect)
+	}
+
+	ss.Delete(v)
+	si.Delete(v)
+
+	if present := si.Contains(v); present {
+		t.Fatalf("Contains(%v) = true, want false", v)
+	}
+	gotLen, wantLen := si.Len(), ss.Len()
+	if gotLen != wantLen {
+		t.Fatalf("Len() = %v, want %v", gotLen, wantLen)
+	}
+}
+
+func TestZigZag(t *testing.T) {
+ t.Run("Int64", func(t *testing.T) {
+ for _, tt := range []struct {
+ decoded int64
+ encoded uint64
+ }{
+ {math.MinInt64, math.MaxUint64},
+ {-2, 3},
+ {-1, 1},
+ {0, 0},
+ {1, 2},
+ {2, 4},
+ {math.MaxInt64, math.MaxUint64 - 1},
+ } {
+ encoded := encodeZigZag(tt.decoded)
+ if encoded != tt.encoded {
+ t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded)
+ }
+ decoded := decodeZigZag[int64](tt.encoded)
+ if decoded != tt.decoded {
+ t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded)
+ }
+ }
+ })
+ t.Run("Uint64", func(t *testing.T) {
+ for _, tt := range []struct {
+ decoded uint64
+ encoded uint64
+ }{
+ {0, 0},
+ {1, 1},
+ {2, 2},
+ {math.MaxInt64, math.MaxInt64},
+ {math.MaxUint64, math.MaxUint64},
+ } {
+ encoded := encodeZigZag(tt.decoded)
+ if encoded != tt.encoded {
+ t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded)
+ }
+ decoded := decodeZigZag[uint64](tt.encoded)
+ if decoded != tt.decoded {
+ t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded)
+ }
+ }
+ })
+}