diff options
Diffstat (limited to 'util/workgraph/workgraph_test.go')
| -rw-r--r-- | util/workgraph/workgraph_test.go | 353 |
1 files changed, 353 insertions, 0 deletions
diff --git a/util/workgraph/workgraph_test.go b/util/workgraph/workgraph_test.go new file mode 100644 index 000000000..c36731f6b --- /dev/null +++ b/util/workgraph/workgraph_test.go @@ -0,0 +1,353 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package workgraph + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync/atomic" + "testing" + + "tailscale.com/util/must" + "tailscale.com/util/set" +) + +func debugGraph(tb testing.TB, g *WorkGraph) { + before := g.Graphviz() + tb.Cleanup(func() { + if !tb.Failed() { + return + } + + after := g.Graphviz() + tb.Logf("graphviz at start of test:\n%s", before) + tb.Logf("graphviz at end of test:\n%s", after) + }) +} + +func makeTestGraph(tb testing.TB) *WorkGraph { + logFunc := func(s string) func(context.Context) error { + return func(_ context.Context) error { + tb.Log(s) + return nil + } + } + makeNode := func(s string) Node { + return NodeFunc(s, logFunc(s+" called")) + } + withDeps := func(ss ...string) *AddNodeOpts { + return &AddNodeOpts{Dependencies: ss} + } + + g := NewWorkGraph() + + // Ensure we have at least 2 concurrent goroutines + g.Concurrency = runtime.GOMAXPROCS(-1) + if g.Concurrency < 2 { + g.Concurrency = 2 + } + + n1 := makeNode("one") + n2 := makeNode("two") + n3 := makeNode("three") + n4 := makeNode("four") + n5 := makeNode("five") + n6 := makeNode("six") + + must.Do(g.AddNode(n1, nil)) // can execute first + must.Do(g.AddNode(n2, nil)) // can execute first + must.Do(g.AddNode(n3, withDeps("one"))) + must.Do(g.AddNode(n4, withDeps("one", "two"))) + must.Do(g.AddNode(n5, withDeps("one"))) + must.Do(g.AddNode(n6, withDeps("four", "five"))) + + return g +} + +func TestWorkGraph(t *testing.T) { + g := makeTestGraph(t) + debugGraph(t, g) + + if err := g.Run(context.Background()); err != nil { + t.Fatal(err) + } +} + +func TestWorkGroup_Error(t *testing.T) { + g := NewWorkGraph() + + terr := errors.New("test error") + + returnsErr := func(_ context.Context) error { return terr } + notCalled := func(_ context.Context) error { panic("unused") } + + n1 := NodeFunc("one", returnsErr) + n2 := NodeFunc("two", notCalled) + n3 := NodeFunc("three", notCalled) + + must.Do(g.AddNode(n1, nil)) + must.Do(g.AddNode(n2, &AddNodeOpts{Dependencies: []string{"one"}})) + must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}})) + + err := g.Run(context.Background()) + if err == nil { + t.Fatal("wanted non-nil error") + } + if !errors.Is(err, terr) { + t.Errorf("got %v, want %v", err, terr) + } +} + +func TestWorkGroup_HandlesPanic(t *testing.T) { + g := NewWorkGraph() + + terr := errors.New("test error") + n1 := NodeFunc("one", func(_ context.Context) error { panic(terr) }) + + must.Do(g.AddNode(n1, nil)) + err := g.Run(context.Background()) + if err == nil { + t.Fatal("wanted non-nil error") + } + if !errors.Is(err, terr) { + t.Errorf("got %v, want %v", err, terr) + } +} + +func TestWorkGroup_Cancellation(t *testing.T) { + g := NewWorkGraph() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var running atomic.Int64 + blocks := func(ctx context.Context) error { + running.Add(1) + <-ctx.Done() + return ctx.Err() + } + + n1 := NodeFunc("one", blocks) + n2 := NodeFunc("two", blocks) + n3 := NodeFunc("three", blocks) + + must.Do(g.AddNode(n1, nil)) + must.Do(g.AddNode(n2, nil)) + + // Ensure that we have a node with dependencies that's also waiting + // since we want to verify that the queue publisher also properly + // handles context cancellation. + must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}})) + + // call Run in a goroutine since it blocks + errCh := make(chan error, 1) + go func() { + errCh <- g.Run(ctx) + }() + + // after all goroutines are running, cancel the context to unblock + for running.Load() != 2 { + // wait + } + cancel() + err := <-errCh + + if err == nil { + t.Fatal("wanted non-nil error") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("got %v, want %v", err, context.Canceled) + } +} + +func TestTopoSortDFS(t *testing.T) { + g := makeTestGraph(t) + debugGraph(t, g) + + sorted, err := g.topoSortDFS() + if err != nil { + t.Fatal(err) + } + t.Logf("DFS topological sort: %v", sorted) + + validateTopologicalSortDFS(t, g, sorted) +} + +func validateTopologicalSortDFS(tb testing.TB, g *WorkGraph, order []string) { + // A valid ordering is any one where a node ID later in the list does + // not depend on a node ID earlier in the list. + for i, node := range order { + for j := 0; j < i; j++ { + if g.edges.Exists(node, order[j]) { + tb.Errorf("invalid edge: %v [%d] -> %v [%d]", node, i, order[j], j) + } + } + } +} + +func TestTopoSortKahn(t *testing.T) { + g := makeTestGraph(t) + debugGraph(t, g) + + groups, err := g.topoSortKahn() + if err != nil { + t.Fatal(err) + } + t.Logf("grouped topological sort: %v", groups) + + validateTopologicalSortKahn(t, g, groups) +} + +func validateTopologicalSortKahn(tb testing.TB, g *WorkGraph, groups []set.Set[string]) { + // A valid ordering is any one where a node ID later in the list does + // not depend on a node ID earlier in the list. + prev := make(map[string]bool) + for i, group := range groups { + for node := range group { + for m := range prev { + if g.edges.Exists(node, m) { + tb.Errorf("group[%d]: invalid edge: %v -> %v", i, node, m) + } + } + prev[node] = true + } + } + + // Verify that our topologically sorted groups contain all nodes. + for nid := range g.nodes { + if !prev[nid] { + tb.Errorf("topological sort missing node %v", nid) + } + } +} + +func FuzzTopSortKahn(f *testing.F) { + // We can't pass a map[string][]string (or similar) into a fuzz + // function, so instead let's create test data by using a combination + // of 'n' nodes and an adjacency matrix of edges from node to node. + // + // We then need to filter this adjacency matrix in the Fuzz function, + // since the fuzzer doesn't distinguish between "invalid fuzz inputs + // due to logic bugs", and "invalid fuzz data that causes a real + // error". + f.Add( + 10, // number of nodes + []byte{ + 1, 0, // 1 depends on 0 + 6, 2, // 6 depends on 2 + 9, 8, // 9 depends on 8 + }, + ) + f.Fuzz(func(t *testing.T, numNodes int, edges []byte) { + g := createGraphFromFuzzInput(t, numNodes, edges) + if g == nil { + return + } + + // This should not error + groups, err := g.topoSortKahn() + if err != nil { + t.Fatal(err) + } + validateTopologicalSortKahn(t, g, groups) + }) +} + +func FuzzTopSortDFS(f *testing.F) { + // We can't pass a map[string][]string (or similar) into a fuzz + // function, so instead let's create test data by using a combination + // of 'n' nodes and an adjacency matrix of edges from node to node. + // + // We then need to filter this adjacency matrix in the Fuzz function, + // since the fuzzer doesn't distinguish between "invalid fuzz inputs + // due to logic bugs", and "invalid fuzz data that causes a real + // error". + f.Add( + 10, // number of nodes + []byte{ + 1, 0, // 1 depends on 0 + 6, 2, // 6 depends on 2 + 9, 8, // 9 depends on 8 + }, + ) + f.Fuzz(func(t *testing.T, numNodes int, edges []byte) { + g := createGraphFromFuzzInput(t, numNodes, edges) + if g == nil { + return + } + + // This should not error + sorted, err := g.topoSortDFS() + if err != nil { + t.Fatal(err) + } + validateTopologicalSortDFS(t, g, sorted) + }) +} + +func createGraphFromFuzzInput(tb testing.TB, numNodes int, edges []byte) *WorkGraph { + nodeName := func(i int) string { + return fmt.Sprintf("node-%d", i) + } + + filterAdjacencyMatrix := func(numNodes int, edges []byte) map[string][]string { + deps := make(map[string][]string) + for i := 0; i < len(edges); i += 2 { + node, dep := int(edges[i]), int(edges[i+1]) + if node >= numNodes || dep >= numNodes { + // invalid node + continue + } + if node == dep { + // can't depend on self + continue + } + + // We add nodes in incrementing order (0, 1, 2, etc.), + // so an edge can't point 'forward' or it'll fail to be + // added. + if dep > node { + continue + } + + nn := nodeName(node) + deps[nn] = append(deps[nn], nodeName(dep)) + } + return deps + } + + // Constrain the number of nodes + if numNodes <= 0 || numNodes > 1000 { + return nil + } + // Must have pairs of edges (from, to) + if len(edges)%2 != 0 { + return nil + } + + // Convert list of edges into list of dependencies + deps := filterAdjacencyMatrix(numNodes, edges) + if len(deps) == 0 { + return nil + } + + // Actually create graph. + g := NewWorkGraph() + doNothing := func(context.Context) error { return nil } + for i := 0; i < numNodes; i++ { + nn := nodeName(i) + node := NodeFunc(nn, doNothing) + if err := g.AddNode(node, &AddNodeOpts{ + Dependencies: deps[nn], + }); err != nil { + tb.Error(err) // shouldn't error after we filtered out bad edges above + } + } + if tb.Failed() { + return nil + } + return g +} |
