summaryrefslogtreecommitdiffhomepage
path: root/syncs/shardvalue_test.go
blob: ab34527abd77f4922f4bcb71d75bce2462a07639 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

package syncs

import (
	"math"
	"runtime"
	"sync"
	"sync/atomic"
	"testing"

	"golang.org/x/sys/cpu"
)

// TestShardValue exercises ShardValue's One, All and Len methods using a
// cache-line-padded atomic counter as the per-shard value.
func TestShardValue(t *testing.T) {
	// intVal pads each counter with cpu.CacheLinePad to avoid false sharing
	// between shards that are updated from different goroutines.
	type intVal struct {
		atomic.Int64
		_ cpu.CacheLinePad
	}

	// One must apply its callback to exactly one shard: after a single
	// Store(10), the sum over every shard is 10 (0 or 20 would indicate
	// zero or multiple shards were touched).
	t.Run("One", func(t *testing.T) {
		sv := NewShardValue[intVal]()
		sv.One(func(v *intVal) {
			v.Store(10)
		})

		var sum int64
		for i := range sv.shards {
			sum += sv.shards[i].Load()
		}
		if sum != 10 {
			t.Errorf("got %v, want 10", sum)
		}
	})

	// All should visit every shard. Shard i holds the value i, so the
	// visited total must be 0 + 1 + ... + (n-1) = n*(n-1)/2.
	t.Run("All", func(t *testing.T) {
		sv := NewShardValue[intVal]()
		for i := range sv.shards {
			sv.shards[i].Store(int64(i))
		}

		var total int64
		sv.All(func(v *intVal) bool {
			total += v.Load()
			return true
		})
		n := int64(len(sv.shards))
		if want := n * (n - 1) / 2; total != want {
			t.Errorf("got %v, want %v", total, want)
		}
	})

	// Len is expected to report one shard per CPU.
	t.Run("Len", func(t *testing.T) {
		sv := NewShardValue[intVal]()
		if got, want := sv.Len(), runtime.NumCPU(); got != want {
			t.Errorf("got %v, want %v", got, want)
		}
	})

	// distribution calls One concurrently from many goroutines, checks that
	// no increment is lost, and reports how evenly increments spread across
	// shards. The balance checks only log (t.Logf) rather than fail —
	// presumably because shard choice depends on scheduling and is not
	// deterministic.
	t.Run("distribution", func(t *testing.T) {
		sv := NewShardValue[intVal]()

		const (
			goroutines = 1000
			iterations = 10000
		)
		var wg sync.WaitGroup
		wg.Add(goroutines)
		for range goroutines {
			go func() {
				defer wg.Done()
				for range iterations {
					sv.One(func(v *intVal) {
						v.Add(1)
					})
				}
			}()
		}
		wg.Wait()

		// All writers have finished (wg.Wait above), so each shard's value
		// is stable and can be loaded once.
		var total int64
		distribution := make([]int64, 0, sv.Len())
		t.Log("distribution:")
		sv.All(func(v *intVal) bool {
			n := v.Load()
			total += n
			distribution = append(distribution, n)
			t.Logf("%d", n)
			return true
		})

		if got, want := total, int64(goroutines*iterations); got != want {
			t.Errorf("got %v, want %v", got, want)
		}
		if got, want := len(distribution), runtime.NumCPU(); got != want {
			t.Errorf("got %v, want %v", got, want)
		}
		if len(distribution) == 0 {
			// Already reported by the length check above; bail out rather
			// than divide by zero when computing the mean.
			return
		}

		mean := total / int64(len(distribution))
		for _, v := range distribution {
			if v < mean/10 || v > mean*10 {
				// Log once: the full slice already shows every outlier.
				t.Logf("distribution is very unbalanced: %v", distribution)
				break
			}
		}
		t.Logf("mean:  %d", mean)

		// Population standard deviation, accumulated in float64 so the
		// variance is neither truncated by integer division before the
		// square root nor at risk of int64 overflow for larger loads.
		var sumSq float64
		for _, v := range distribution {
			d := float64(v - mean)
			sumSq += d * d
		}
		standardDev := int64(math.Sqrt(sumSq / float64(len(distribution))))
		t.Logf("stdev: %d", standardDev)

		if standardDev > mean/3 {
			t.Logf("standard deviation is too high: %v", standardDev)
		}
	})
}