summaryrefslogtreecommitdiffhomepage
path: root/syncs
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@tailscale.com>2023-06-25 12:16:55 -0700
committerBrad Fitzpatrick <brad@danga.com>2023-06-25 12:51:19 -0700
commitba41d143209e9519b36536eafdc8292f0cc67047 (patch)
treefc6933f1b75160eb114eaf1da5a73744375e5c9d /syncs
parent1f57088cbd7ea9c03033c4d4cd0285927e55861f (diff)
downloadtailscale-ba41d143209e9519b36536eafdc8292f0cc67047.tar.xz
tailscale-ba41d143209e9519b36536eafdc8292f0cc67047.zip
syncs: add ShardedMap type
Updates tailscale/corp#7354 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Diffstat (limited to 'syncs')
-rw-r--r--syncs/shardedmap.go111
-rw-r--r--syncs/shardedmap_test.go44
2 files changed, 155 insertions, 0 deletions
diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go
new file mode 100644
index 000000000..00ce3aafa
--- /dev/null
+++ b/syncs/shardedmap.go
@@ -0,0 +1,111 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+import (
+ "sync"
+
+ "golang.org/x/sys/cpu"
+)
+
+// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined
+// K-sharding function.
+//
+// The zero value is not safe for use; use NewShardedMap.
+type ShardedMap[K comparable, V any] struct {
+ shardFunc func(K) int
+ shards []mapShard[K, V]
+}
+
+type mapShard[K comparable, V any] struct {
+ mu sync.Mutex
+ m map[K]V
+ _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes
+}
+
+// NewShardedMap returns a new ShardedMap with the given number of shards and
+// sharding function.
+//
+// The shard func must return a integer in the range [0, shards) purely
+// deterministically based on the provided K.
+func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] {
+ m := &ShardedMap[K, V]{
+ shardFunc: shard,
+ shards: make([]mapShard[K, V], shards),
+ }
+ for i := range m.shards {
+ m.shards[i].m = make(map[K]V)
+ }
+ return m
+}
+
+func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] {
+ return &m.shards[m.shardFunc(key)]
+}
+
+// GetOk returns m[key] and whether it was present.
+func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) {
+ shard := m.shard(key)
+ shard.mu.Lock()
+ defer shard.mu.Unlock()
+ value, ok = shard.m[key]
+ return
+}
+
+// Get returns m[key] or the zero value of V if key is not present.
+func (m *ShardedMap[K, V]) Get(key K) (value V) {
+ value, _ = m.GetOk(key)
+ return
+}
+
+// Set sets m[key] = value.
+//
+// It reports whether the map grew in size (that is, whether key was not already
+// present in m).
+func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) {
+ shard := m.shard(key)
+ shard.mu.Lock()
+ defer shard.mu.Unlock()
+ s0 := len(shard.m)
+ shard.m[key] = value
+ return len(shard.m) > s0
+}
+
+// Delete removes key from m.
+//
+// It reports whether the map size shrunk (that is, whether key was present in
+// the map).
+func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) {
+ shard := m.shard(key)
+ shard.mu.Lock()
+ defer shard.mu.Unlock()
+ s0 := len(shard.m)
+ delete(shard.m, key)
+ return len(shard.m) < s0
+}
+
+// Contains reports whether m contains key.
+func (m *ShardedMap[K, V]) Contains(key K) bool {
+ shard := m.shard(key)
+ shard.mu.Lock()
+ defer shard.mu.Unlock()
+ _, ok := shard.m[key]
+ return ok
+}
+
+// Len returns the number of elements in m.
+//
+// It does so by locking shards one at a time, so it's not particularly cheap,
+// nor does it give a consistent snapshot of the map. It's mostly intended for
+// metrics or testing.
+func (m *ShardedMap[K, V]) Len() int {
+ n := 0
+ for i := range m.shards {
+ shard := &m.shards[i]
+ shard.mu.Lock()
+ n += len(shard.m)
+ shard.mu.Unlock()
+ }
+ return n
+}
diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go
new file mode 100644
index 000000000..b09a268d7
--- /dev/null
+++ b/syncs/shardedmap_test.go
@@ -0,0 +1,44 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+import "testing"
+
+func TestShardedMap(t *testing.T) {
+ m := NewShardedMap[int, string](16, func(i int) int { return i % 16 })
+
+ if m.Contains(1) {
+ t.Errorf("got contains; want !contains")
+ }
+ if !m.Set(1, "one") {
+ t.Errorf("got !set; want set")
+ }
+ if m.Set(1, "one") {
+ t.Errorf("got set; want !set")
+ }
+ if !m.Contains(1) {
+ t.Errorf("got !contains; want contains")
+ }
+ if g, w := m.Get(1), "one"; g != w {
+ t.Errorf("got %q; want %q", g, w)
+ }
+ if _, ok := m.GetOk(1); !ok {
+ t.Errorf("got ok; want !ok")
+ }
+ if _, ok := m.GetOk(2); ok {
+ t.Errorf("got ok; want !ok")
+ }
+ if g, w := m.Len(), 1; g != w {
+ t.Errorf("got Len %v; want %v", g, w)
+ }
+ if m.Delete(2) {
+ t.Errorf("got deleted; want !deleted")
+ }
+ if !m.Delete(1) {
+ t.Errorf("got !deleted; want deleted")
+ }
+ if g, w := m.Len(), 0; g != w {
+ t.Errorf("got Len %v; want %v", g, w)
+ }
+}