summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--util/topk/topk.go261
-rw-r--r--util/topk/topk_test.go135
2 files changed, 0 insertions, 396 deletions
diff --git a/util/topk/topk.go b/util/topk/topk.go
deleted file mode 100644
index 95ebd895d..000000000
--- a/util/topk/topk.go
+++ /dev/null
@@ -1,261 +0,0 @@
-// Copyright (c) Tailscale Inc & contributors
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package topk defines a count-min sketch and a cheap probabilistic top-K data
-// structure that uses the count-min sketch to track the top K items in
-// constant memory and O(log(k)) time.
-package topk
-
-import (
- "container/heap"
- "hash/maphash"
- "math"
- "slices"
- "sync"
-)
-
-// TopK is a probabilistic counter of the top K items, using a count-min sketch
-// to keep track of item counts and a heap to track the top K of them.
-type TopK[T any] struct {
- heap minHeap[T]
- k int
- sf SerializeFunc[T]
- cms CountMinSketch
-}
-
-// HashFunc is responsible for providing a []byte serialization of a value,
-// appended to the provided byte slice. This is used for hashing the value when
-// adding to a CountMinSketch.
-type SerializeFunc[T any] func([]byte, T) []byte
-
-// New creates a new TopK that stores k values. Parameters for the underlying
-// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of
-// error.
-func New[T any](k int, sf SerializeFunc[T]) *TopK[T] {
- hashes, buckets := PickParams(0.001, 0.001)
- return NewWithParams(k, sf, hashes, buckets)
-}
-
-// NewWithParams creates a new TopK that stores k values, and additionally
-// allows customizing the parameters for the underlying count-min sketch.
-func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] {
- ret := &TopK[T]{
- heap: make(minHeap[T], 0, k),
- k: k,
- sf: sf,
- }
- ret.cms.init(numHashes, numCols)
- return ret
-}
-
-// Add calls AddN(val, 1).
-func (tk *TopK[T]) Add(val T) uint64 {
- return tk.AddN(val, 1)
-}
-
-var hashPool = &sync.Pool{
- New: func() any {
- buf := make([]byte, 0, 128)
- return &buf
- },
-}
-
-// AddN adds the given item to the set with the provided count, returning the
-// new estimated count.
-func (tk *TopK[T]) AddN(val T, count uint64) uint64 {
- buf := hashPool.Get().(*[]byte)
- defer hashPool.Put(buf)
- ser := tk.sf((*buf)[:0], val)
-
- vcount := tk.cms.AddN(ser, count)
-
- // If we don't have a full heap, just push it.
- if len(tk.heap) < tk.k {
- heap.Push(&tk.heap, mhValue[T]{
- count: vcount,
- val: val,
- })
- return vcount
- }
-
- // If this item's count surpasses the heap's minimum, update the heap.
- if vcount > tk.heap[0].count {
- tk.heap[0] = mhValue[T]{
- count: vcount,
- val: val,
- }
- heap.Fix(&tk.heap, 0)
- }
- return vcount
-}
-
-// Top returns the estimated top K items as stored by this TopK.
-func (tk *TopK[T]) Top() []T {
- ret := make([]T, 0, tk.k)
- for _, item := range tk.heap {
- ret = append(ret, item.val)
- }
- return ret
-}
-
-// AppendTop appends the estimated top K items as stored by this TopK to the
-// provided slice, allocating only if the slice does not have enough capacity
-// to store all items. The provided slice can be nil.
-func (tk *TopK[T]) AppendTop(sl []T) []T {
- sl = slices.Grow(sl, tk.k)
- for _, item := range tk.heap {
- sl = append(sl, item.val)
- }
- return sl
-}
-
-// CountMinSketch implements a count-min sketch, a probabilistic data structure
-// that tracks the frequency of events in a stream of data.
-//
-// See: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch
-type CountMinSketch struct {
- hashes []maphash.Seed
- nbuckets int
- matrix []uint64
-}
-
-// NewCountMinSketch creates a new CountMinSketch with the provided number of
-// hashes and buckets. Hashes and buckets are often called "depth" and "width",
-// or "d" and "w", respectively.
-func NewCountMinSketch(hashes, buckets int) *CountMinSketch {
- ret := &CountMinSketch{}
- ret.init(hashes, buckets)
- return ret
-}
-
-// PickParams provides good parameters for 'hashes' and 'buckets' when
-// constructing a CountMinSketch, given an estimated total number of counts
-// (i.e. the sum of all counts ever stored), the error factor ϵ as a float
-// (e.g. 1% is 0.001), and the probability factor δ.
-//
-// Parameters are chosen such that with a probability of 1−δ, the error is at
-// most ϵ∗totalCount. Or, in other words: if N is the true count of an event,
-// E is the estimate given by a sketch and T the total count of items in the
-// sketch, E ≤ N + T*ϵ with probability (1 - δ).
-func PickParams(err, probability float64) (hashes, buckets int) {
- d := math.Ceil(math.Log(1 / probability))
- w := math.Ceil(math.E / err)
-
- return int(d), int(w)
-}
-
-func (cms *CountMinSketch) init(hashes, buckets int) {
- for range hashes {
- cms.hashes = append(cms.hashes, maphash.MakeSeed())
- }
-
- // Need a matrix of hashes * buckets to store counts
- cms.nbuckets = buckets
- cms.matrix = make([]uint64, hashes*buckets)
-}
-
-// Add calls AddN(val, 1).
-func (cms *CountMinSketch) Add(val []byte) uint64 {
- return cms.AddN(val, 1)
-}
-
-// AddN increments the count for the given value by the provided count,
-// returning the new count.
-func (cms *CountMinSketch) AddN(val []byte, count uint64) uint64 {
- var (
- mh maphash.Hash
- ret uint64 = math.MaxUint64
- )
- for i, seed := range cms.hashes {
- mh.SetSeed(seed)
-
- // Generate a hash for this value using Lemire's alternative to modular reduction:
- // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
- mh.Write(val)
- hash := mh.Sum64()
- hash = multiplyHigh64(hash, uint64(cms.nbuckets))
-
- // The index in our matrix is (i * buckets) to move "down" i
- // rows in our matrix to the row for this hash, plus 'hash' to
- // move inside this row.
- idx := (i * cms.nbuckets) + int(hash)
-
- // Add to this row
- cms.matrix[idx] += count
- ret = min(ret, cms.matrix[idx])
- }
- return ret
-}
-
-// Get returns the count for the provided value.
-func (cms *CountMinSketch) Get(val []byte) uint64 {
- var (
- mh maphash.Hash
- ret uint64 = math.MaxUint64
- )
- for i, seed := range cms.hashes {
- mh.SetSeed(seed)
-
- // Generate a hash for this value using Lemire's alternative to modular reduction:
- // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
- mh.Write(val)
- hash := mh.Sum64()
- hash = multiplyHigh64(hash, uint64(cms.nbuckets))
-
- // The index in our matrix is (i * buckets) to move "down" i
- // rows in our matrix to the row for this hash, plus 'hash' to
- // move inside this row.
- idx := (i * cms.nbuckets) + int(hash)
-
- // Select the minimal value among all rows
- ret = min(ret, cms.matrix[idx])
- }
- return ret
-}
-
-// multiplyHigh64 implements (x * y) >> 64 "the long way" without access to a
-// 128-bit type. This function is adapted from something similar in Tensorflow:
-//
-// https://github.com/tensorflow/tensorflow/commit/a47a300185026fe7829990def9113bf3a5109fed
-//
-// TODO(andrew-d): this could be replaced with a single "MULX" instruction on
-// x86_64 platforms, which we can do if this ever turns out to be a performance
-// bottleneck.
-func multiplyHigh64(x, y uint64) uint64 {
- x_lo := x & 0xffffffff
- x_hi := x >> 32
- buckets_lo := y & 0xffffffff
- buckets_hi := y >> 32
- prod_hi := x_hi * buckets_hi
- prod_lo := x_lo * buckets_lo
- prod_mid1 := x_hi * buckets_lo
- prod_mid2 := x_lo * buckets_hi
- carry := ((prod_mid1 & 0xffffffff) + (prod_mid2 & 0xffffffff) + (prod_lo >> 32)) >> 32
- return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry
-}
-
-type mhValue[T any] struct {
- count uint64
- val T
-}
-
-// An minHeap is a min-heap of ints and associated values.
-type minHeap[T any] []mhValue[T]
-
-func (h minHeap[T]) Len() int { return len(h) }
-func (h minHeap[T]) Less(i, j int) bool { return h[i].count < h[j].count }
-func (h minHeap[T]) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
-
-func (h *minHeap[T]) Push(x any) {
- // Push and Pop use pointer receivers because they modify the slice's length,
- // not just its contents.
- *h = append(*h, x.(mhValue[T]))
-}
-
-func (h *minHeap[T]) Pop() any {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[0 : n-1]
- return x
-}
diff --git a/util/topk/topk_test.go b/util/topk/topk_test.go
deleted file mode 100644
index 7679f59a3..000000000
--- a/util/topk/topk_test.go
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright (c) Tailscale Inc & contributors
-// SPDX-License-Identifier: BSD-3-Clause
-
-package topk
-
-import (
- "encoding/binary"
- "fmt"
- "slices"
- "testing"
-)
-
-func TestCountMinSketch(t *testing.T) {
- cms := NewCountMinSketch(4, 10)
- items := []string{"foo", "bar", "baz", "asdf", "quux"}
- for _, item := range items {
- cms.Add([]byte(item))
- }
- for _, item := range items {
- count := cms.Get([]byte(item))
- if count < 1 {
- t.Errorf("item %q should have count >= 1", item)
- } else if count > 1 {
- t.Logf("item %q has count > 1: %d", item, count)
- }
- }
-
- // Test that an item that's *not* in the set has a value lower than the
- // total number of items we inserted (in the case that all items
- // collided).
- noItemCount := cms.Get([]byte("doesn't exist"))
- if noItemCount > uint64(len(items)) {
- t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount)
- }
-}
-
-func TestTopK(t *testing.T) {
- // This is probabilistic, so we're going to try 10 times to get the
- // "right" value; the likelihood that we fail on all attempts is
- // vanishingly small since the number of hash buckets is drastically
- // larger than the number of items we're inserting.
- var (
- got []int
- want = []int{5, 6, 7, 8, 9}
- )
- for range 10 {
- topk := NewWithParams[int](5, func(in []byte, val int) []byte {
- return binary.LittleEndian.AppendUint64(in, uint64(val))
- }, 4, 1000)
-
- // Add the first 10 integers with counts equal to 2x their value
- for i := range 10 {
- topk.AddN(i, uint64(i*2))
- }
-
- got = topk.Top()
- t.Logf("top K items: %+v", got)
- slices.Sort(got)
-
- if slices.Equal(got, want) {
- // All good!
- return
- }
-
- // continue and retry or fail
- }
-
- t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want)
-}
-
-func TestPickParams(t *testing.T) {
- hashes, buckets := PickParams(
- 0.001, // 0.1% error rate
- 0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error
- )
- t.Logf("hashes = %d, buckets = %d", hashes, buckets)
-}
-
-func BenchmarkCountMinSketch(b *testing.B) {
- cms := NewCountMinSketch(PickParams(0.001, 0.001))
- b.ResetTimer()
- b.ReportAllocs()
-
- var enc [8]byte
- for i := range b.N {
- binary.LittleEndian.PutUint64(enc[:], uint64(i))
- cms.Add(enc[:])
- }
-}
-
-func BenchmarkTopK(b *testing.B) {
- for _, n := range []int{
- 10,
- 128,
- 256,
- 1024,
- 8192,
- } {
- b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) {
- out := make([]int, 0, n)
- topk := New[int](n, func(in []byte, val int) []byte {
- return binary.LittleEndian.AppendUint64(in, uint64(val))
- })
- b.ResetTimer()
- b.ReportAllocs()
-
- for i := range b.N {
- topk.Add(i)
- }
- out = topk.AppendTop(out[:0]) // should not allocate
- _ = out // appease linter
- })
- }
-}
-
-func TestMultiplyHigh64(t *testing.T) {
- testCases := []struct {
- x, y uint64
- want uint64
- }{
- {0, 0, 0},
- {0xffffffff, 0xffffffff, 0},
- {0x2, 0xf000000000000000, 1},
- {0x3, 0xf000000000000000, 2},
- {0x3, 0xf000000000000001, 2},
- {0x3, 0xffffffffffffffff, 2},
- {0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe},
- }
- for _, tc := range testCases {
- got := multiplyHigh64(tc.x, tc.y)
- if got != tc.want {
- t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want)
- }
- }
-}