summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorTom DNetto <tom@tailscale.com>2022-07-11 11:28:18 -0700
committerTom <twitchyliquid64@users.noreply.github.com>2022-07-15 10:44:43 -0700
commit4f1374ec9eec11cf474a6c33ff22ee268d9ec136 (patch)
tree6e94a8f3d2284dcf767acdef49e79371cb766695
parentaf412e8874e94dc3ac57c37c3ec5e0606aa08fbb (diff)
downloadtailscale-4f1374ec9eec11cf474a6c33ff22ee268d9ec136.tar.xz
tailscale-4f1374ec9eec11cf474a6c33ff22ee268d9ec136.zip
tka: implement consensus & state computation internals
Signed-off-by: Tom DNetto <tom@tailscale.com>
-rw-r--r--tka/chaintest_test.go365
-rw-r--r--tka/tka.go332
-rw-r--r--tka/tka_test.go187
3 files changed, 884 insertions, 0 deletions
diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go
new file mode 100644
index 000000000..3b3ce4c4c
--- /dev/null
+++ b/tka/chaintest_test.go
@@ -0,0 +1,365 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tka
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "fmt"
+ "strconv"
+ "strings"
+ "testing"
+ "text/scanner"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+)
+
+// chaintest_test.go implements test helpers for concisely describing
+// chains of possibly signed AUMs, to assist in making tests shorter and
+// easier to read.
+
+// parsed representation of a named AUM in a test chain.
+type testchainNode struct {
+ Name string
+ Parent string
+ Uses []scanner.Position
+
+ HashSeed int
+ Template string
+ SignedWith string
+}
+
+// testChain represents a constructed web of AUMs for testing purposes.
+type testChain struct {
+ Nodes map[string]*testchainNode
+ AUMs map[string]AUM
+ AUMHashes map[string]AUMHash
+
+ // Configured by options to NewTestchain()
+ Template map[string]AUM
+ Key map[string]*Key
+ KeyPrivs map[string]ed25519.PrivateKey
+ SignAllKeys []string
+}
+
+// newTestchain constructs a web of AUMs based on the provided input and
+// options.
+//
+// Input is expected to be a graph & tweaks, looking like this:
+//
+// G1 -> A -> B
+// | -> C
+//
+// which defines AUMs G1, A, B, and C; with G1 having no parent, A having
+// G1 as a parent, and both B & C having A as a parent.
+//
+// Tweaks are specified like this:
+//
+// <AUM>.<tweak> = <value>
+//
+// for example: G1.hashSeed = 2
+//
+// There are 3 available tweaks:
+// - hashSeed: Set to an integer to tweak the AUM hash of that AUM.
+// - template: Set to the name of a template provided via optTemplate().
+// The template is copied and use as the content for that AUM.
+// - signedWith: Set to the name of a key provided via optKey(). This
+// key is used to sign that AUM.
+func newTestchain(t *testing.T, input string, options ...testchainOpt) *testChain {
+ t.Helper()
+
+ var (
+ s scanner.Scanner
+ out = testChain{
+ Nodes: map[string]*testchainNode{},
+ Template: map[string]AUM{},
+ Key: map[string]*Key{},
+ KeyPrivs: map[string]ed25519.PrivateKey{},
+ }
+ )
+
+ // Process any options
+ for _, o := range options {
+ if o.Template != nil {
+ out.Template[o.Name] = *o.Template
+ }
+ if o.Key != nil {
+ out.Key[o.Name] = o.Key
+ out.KeyPrivs[o.Name] = o.Private
+ }
+ if o.SignAllWith {
+ out.SignAllKeys = append(out.SignAllKeys, o.Name)
+ }
+ }
+
+ s.Init(strings.NewReader(input))
+ s.Mode = scanner.ScanIdents | scanner.SkipComments | scanner.ScanComments | scanner.ScanChars | scanner.ScanInts
+ s.Whitespace ^= 1 << '\t' // clear tabs
+ var (
+ lastIdent string
+ lastWasChain bool // if the last token was '->'
+ )
+ for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
+ switch tok {
+ case '\t':
+ t.Fatalf("tabs disallowed, use spaces (seen at %v)", s.Pos())
+
+ case '.': // tweaks, like <ident>.hashSeed = <val>
+ s.Scan()
+ tweak := s.TokenText()
+ if tok := s.Scan(); tok == '=' {
+ s.Scan()
+ switch tweak {
+ case "hashSeed":
+ out.Nodes[lastIdent].HashSeed, _ = strconv.Atoi(s.TokenText())
+ case "template":
+ out.Nodes[lastIdent].Template = s.TokenText()
+ case "signedWith":
+ out.Nodes[lastIdent].SignedWith = s.TokenText()
+ }
+ }
+
+ case scanner.Ident:
+ out.recordPos(s.TokenText(), s.Pos())
+ // If the last token was '->', that means
+ // that the next identifier has a child relationship
+ // with the identifier preceeding '->'.
+ if lastWasChain {
+ out.recordParent(t, s.TokenText(), lastIdent)
+ }
+ lastIdent = s.TokenText()
+
+ case '-': // handle '->'
+ switch s.Peek() {
+ case '>':
+ s.Scan()
+ lastWasChain = true
+ continue
+ }
+
+ case '|': // handle '|'
+ line, col := s.Pos().Line, s.Pos().Column
+ nodeLoop:
+ for _, n := range out.Nodes {
+ for _, p := range n.Uses {
+ // Find the identifier used right here on the line above.
+ if p.Line == line-1 && col <= p.Column && col > p.Column-len(n.Name) {
+ lastIdent = n.Name
+ out.recordPos(n.Name, s.Pos())
+ break nodeLoop
+ }
+ }
+ }
+ }
+ lastWasChain = false
+ // t.Logf("tok = %v, %q", tok, s.TokenText())
+ }
+
+ out.buildChain()
+ return &out
+}
+
+// called from the parser to record the location of an
+// identifier (a named AUM).
+func (c *testChain) recordPos(ident string, pos scanner.Position) {
+ n := c.Nodes[ident]
+ if n == nil {
+ n = &testchainNode{Name: ident}
+ }
+
+ n.Uses = append(n.Uses, pos)
+ c.Nodes[ident] = n
+}
+
+// called from the parser to record a parent relationship between
+// two AUMs.
+func (c *testChain) recordParent(t *testing.T, child, parent string) {
+ if p := c.Nodes[child].Parent; p != "" && p != parent {
+ t.Fatalf("differing parent specified for %s: %q != %q", child, p, parent)
+ }
+ c.Nodes[child].Parent = parent
+}
+
+// called after parsing to build the web of AUM structures.
+// This method populates c.AUMs and c.AUMHashes.
+func (c *testChain) buildChain() {
+ pending := make(map[string]*testchainNode, len(c.Nodes))
+ for k, v := range c.Nodes {
+ pending[k] = v
+ }
+
+ // AUMs with a parent need to know their hash, so we
+ // only compute AUMs who's parents have been computed
+ // each iteration. Since at least the genesis AUM
+ // had no parent, theres always a path to completion
+ // in O(n+1) where n is the number of AUMs.
+ c.AUMs = make(map[string]AUM, len(c.Nodes))
+ c.AUMHashes = make(map[string]AUMHash, len(c.Nodes))
+ for i := 0; i < len(c.Nodes)+1; i++ {
+ if len(pending) == 0 {
+ return
+ }
+
+ next := make([]*testchainNode, 0, 10)
+ for _, v := range pending {
+ if _, parentPending := pending[v.Parent]; !parentPending {
+ next = append(next, v)
+ }
+ }
+
+ for _, v := range next {
+ aum := c.makeAUM(v)
+ h := aum.Hash()
+
+ c.AUMHashes[v.Name] = h
+ c.AUMs[v.Name] = aum
+ delete(pending, v.Name)
+ }
+ }
+ panic("unexpected: incomplete despite len(Nodes)+1 iterations")
+}
+
+func (c *testChain) makeAUM(v *testchainNode) AUM {
+ // By default, the AUM used is just a no-op AUM
+ // with a parent hash set (if any).
+ //
+ // If <AUM>.template is set to the same name as in
+ // a provided optTemplate(), the AUM is built
+ // from a copy of that instead.
+ //
+ // If <AUM>.hashSeed = <int> is set, the KeyID is
+ // tweaked to effect tweaking the hash. This is useful
+ // if you want one AUM to have a lower hash than another.
+ aum := AUM{MessageKind: AUMNoOp}
+ if template := v.Template; template != "" {
+ aum = c.Template[template]
+ }
+ if v.Parent != "" {
+ parentHash := c.AUMHashes[v.Parent]
+ aum.PrevAUMHash = parentHash[:]
+ }
+ if seed := v.HashSeed; seed != 0 {
+ aum.KeyID = []byte{byte(seed)}
+ }
+ if err := aum.StaticValidate(); err != nil {
+ // Usually caused by a test writer specifying a template
+ // AUM which is ultimately invalid.
+ panic(fmt.Sprintf("aum %+v failed static validation: %v", aum, err))
+ }
+
+ sigHash := aum.SigHash()
+ for _, key := range c.SignAllKeys {
+ aum.Signatures = append(aum.Signatures, Signature{
+ KeyID: c.Key[key].ID(),
+ Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
+ })
+ }
+
+ // If the aum was specified as being signed by some key, then
+ // sign it using that key.
+ if key := v.SignedWith; key != "" {
+ aum.Signatures = append(aum.Signatures, Signature{
+ KeyID: c.Key[key].ID(),
+ Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
+ })
+ }
+
+ return aum
+}
+
+// Chonk returns a tailchonk containing all AUMs.
+func (c *testChain) Chonk() Chonk {
+ var out Mem
+ for _, update := range c.AUMs {
+ if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil {
+ panic(err)
+ }
+ }
+ return &out
+}
+
+// ChonkWith returns a tailchonk containing the named AUMs.
+func (c *testChain) ChonkWith(names ...string) Chonk {
+ var out Mem
+ for _, name := range names {
+ update := c.AUMs[name]
+ if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil {
+ panic(err)
+ }
+ }
+ return &out
+}
+
+type testchainOpt struct {
+ Name string
+ Template *AUM
+ Key *Key
+ Private ed25519.PrivateKey
+ SignAllWith bool
+}
+
+func optTemplate(name string, template AUM) testchainOpt {
+ return testchainOpt{
+ Name: name,
+ Template: &template,
+ }
+}
+
+func optKey(name string, key Key, priv ed25519.PrivateKey) testchainOpt {
+ return testchainOpt{
+ Name: name,
+ Key: &key,
+ Private: priv,
+ }
+}
+
+func optSignAllUsing(keyName string) testchainOpt {
+ return testchainOpt{
+ Name: keyName,
+ SignAllWith: true,
+ }
+}
+
+func TestNewTestchain(t *testing.T) {
+ c := newTestchain(t, `
+ genesis -> B -> C
+ | -> D
+ | -> E -> F
+
+ E.hashSeed = 12 // tweak E to have the lowest hash so its chosen
+ F.template = test
+ `, optTemplate("test", AUM{MessageKind: AUMNoOp, KeyID: []byte{10}}))
+
+ want := map[string]*testchainNode{
+ "genesis": &testchainNode{Name: "genesis", Uses: []scanner.Position{{Line: 2, Column: 16}}},
+ "B": &testchainNode{
+ Name: "B",
+ Parent: "genesis",
+ Uses: []scanner.Position{{Line: 2, Column: 21}, {Line: 3, Column: 21}, {Line: 4, Column: 21}},
+ },
+ "C": &testchainNode{Name: "C", Parent: "B", Uses: []scanner.Position{{Line: 2, Column: 26}}},
+ "D": &testchainNode{Name: "D", Parent: "B", Uses: []scanner.Position{{Line: 3, Column: 26}}},
+ "E": &testchainNode{Name: "E", Parent: "B", HashSeed: 12, Uses: []scanner.Position{{Line: 4, Column: 26}, {Line: 6, Column: 10}}},
+ "F": &testchainNode{Name: "F", Parent: "E", Template: "test", Uses: []scanner.Position{{Line: 4, Column: 31}, {Line: 7, Column: 10}}},
+ }
+
+ if diff := cmp.Diff(want, c.Nodes, cmpopts.IgnoreFields(scanner.Position{}, "Offset")); diff != "" {
+ t.Errorf("decoded state differs (-want, +got):\n%s", diff)
+ }
+ if !bytes.Equal(c.AUMs["F"].KeyID, []byte{10}) {
+ t.Errorf("AUM 'F' missing KeyID from template: %v", c.AUMs["F"])
+ }
+
+ // chonk := c.Chonk()
+ // authority, err := Open(chonk)
+ // if err != nil {
+ // t.Errorf("failed to initialize from chonk: %v", err)
+ // }
+
+ // if authority.Head() != c.AUMHashes["F"] {
+ // t.Errorf("head = %X, want %X", authority.Head(), c.AUMHashes["F"])
+ // }
+}
diff --git a/tka/tka.go b/tka/tka.go
index cec790d99..5974950e6 100644
--- a/tka/tka.go
+++ b/tka/tka.go
@@ -4,3 +4,335 @@
// Package tka (WIP) implements the Tailnet Key Authority.
package tka
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "os"
+ "sort"
+)
+
+// A chain describes a linear sequence of updates from Oldest to Head,
+// resulting in some State at Head.
+type chain struct {
+ Oldest AUM
+ Head AUM
+
+ state State
+
+ // Set to true if the AUM chain intersects with the active
+ // chain from a previous run.
+ chainsThroughActive bool
+}
+
+// computeChainCandidates returns all possible chains based on AUMs stored
+// in the given tailchonk. A chain is defined as a unique (oldest, newest)
+// AUM tuple. chain.state is not yet populated in returned chains.
+//
+// If lastKnownOldest is provided, any chain that includes the given AUM
+// has the chainsThroughActive field set to true. This bit is leveraged
+// in computeActiveAncestor() to filter out irrelevant chains when determining
+// the active ancestor from a list of distinct chains.
+func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int) ([]chain, error) {
+ heads, err := storage.Heads()
+ if err != nil {
+ return nil, fmt.Errorf("reading heads: %v", err)
+ }
+ candidates := make([]chain, len(heads))
+ for i := range heads {
+ // Oldest is iteratively computed below.
+ candidates[i] = chain{Oldest: heads[i], Head: heads[i]}
+ }
+ // Not strictly necessary, but simplifies checks in tests.
+ sort.Slice(candidates, func(i, j int) bool {
+ ih, jh := candidates[i].Oldest.Hash(), candidates[j].Oldest.Hash()
+ return bytes.Compare(ih[:], jh[:]) < 0
+ })
+
+ // candidates.Oldest needs to be computed by working backwards from
+ // head as far as we can.
+ iterAgain := true // if theres still work to be done.
+ for i := 0; iterAgain; i++ {
+ if i >= maxIter {
+ return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
+ }
+
+ iterAgain = false
+ for j := range candidates {
+ parent, hasParent := candidates[j].Oldest.Parent()
+ if hasParent {
+ parent, err := storage.AUM(parent)
+ if err != nil {
+ if err == os.ErrNotExist {
+ continue
+ }
+ return nil, fmt.Errorf("reading parent: %v", err)
+ }
+ candidates[j].Oldest = parent
+ if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() {
+ candidates[j].chainsThroughActive = true
+ }
+ iterAgain = true
+ }
+ }
+ }
+ return candidates, nil
+}
+
+// pickNextAUM returns the AUM which should be used as the next
+// AUM in the chain, possibly applying fork resolution logic.
+//
+// In other words: given an AUM with 3 children like this:
+// / - 1
+// P - 2
+// \ - 3
+//
+// pickNextAUM will determine and return the correct branch.
+//
+// This method takes ownership of the provided slice.
+func pickNextAUM(state State, candidates []AUM) AUM {
+ switch len(candidates) {
+ case 0:
+ panic("pickNextAUM called with empty candidate set")
+ case 1:
+ return candidates[0]
+ }
+
+ // Oooof, we have some forks in the chain. We need to pick which
+ // one to use by applying the Fork Resolution Algorithm ✨
+ //
+ // The rules are this:
+ // 1. The child with the highest signature weight is chosen.
+ // 2. If equal, the child which is a RemoveKey AUM is chosen.
+ // 3. If equal, the child with the lowest AUM hash is chosen.
+ sort.Slice(candidates, func(j, i int) bool {
+ // Rule 1.
+ iSigWeight, jSigWeight := candidates[i].Weight(state), candidates[j].Weight(state)
+ if iSigWeight != jSigWeight {
+ return iSigWeight < jSigWeight
+ }
+
+ // Rule 2.
+ if iKind, jKind := candidates[i].MessageKind, candidates[j].MessageKind; iKind != jKind &&
+ (iKind == AUMRemoveKey || jKind == AUMRemoveKey) {
+ return jKind == AUMRemoveKey
+ }
+
+ // Rule 3.
+ iHash, jHash := candidates[i].Hash(), candidates[j].Hash()
+ return bytes.Compare(iHash[:], jHash[:]) > 0
+ })
+
+ return candidates[0]
+}
+
+// advanceChain computes the next AUM to advance with based on all child
+// AUMs, returning the chosen AUM & the state obtained by applying that
+// AUM.
+//
+// The return value for next is nil if there are no children AUMs, hence
+// the provided state is at head (up to date).
+func advanceChain(state State, candidates []AUM) (next *AUM, out State, err error) {
+ if len(candidates) == 0 {
+ return nil, state, nil
+ }
+
+ aum := pickNextAUM(state, candidates)
+ if state, err = state.applyVerifiedAUM(aum); err != nil {
+ return nil, State{}, fmt.Errorf("advancing state: %v", err)
+ }
+ return &aum, state, nil
+}
+
+// fastForward iteratively advances the current state based on known AUMs until
+// the given termination function returns true or there is no more progress possible.
+//
+// The last-processed AUM, and the state computed after applying the last AUM,
+// are returned.
+func fastForward(storage Chonk, maxIter int, startState State, done func(curAUM AUM, curState State) bool) (AUM, State, error) {
+ if startState.LastAUMHash == nil {
+ return AUM{}, State{}, errors.New("invalid initial state")
+ }
+ nextAUM, err := storage.AUM(*startState.LastAUMHash)
+ if err != nil {
+ return AUM{}, State{}, fmt.Errorf("reading next: %v", err)
+ }
+
+ curs := nextAUM
+ state := startState
+ for i := 0; i < maxIter; i++ {
+ if done != nil && done(curs, state) {
+ return curs, state, nil
+ }
+
+ children, err := storage.ChildAUMs(curs.Hash())
+ if err != nil {
+ return AUM{}, State{}, fmt.Errorf("getting children of %X: %v", curs.Hash(), err)
+ }
+ next, nextState, err := advanceChain(state, children)
+ if err != nil {
+ return AUM{}, State{}, fmt.Errorf("advance %X: %v", curs.Hash(), err)
+ }
+ if next == nil {
+ // There were no more children, we are at 'head'.
+ return curs, state, nil
+ }
+ curs = *next
+ state = nextState
+ }
+
+ return AUM{}, State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
+}
+
+// computeStateAt returns the State at wantHash.
+func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) {
+ // TODO(tom): This is going to get expensive for really long
+ // chains. We should make nodes emit a checkpoint every
+ // X updates or something.
+
+ topAUM, err := storage.AUM(wantHash)
+ if err != nil {
+ return State{}, err
+ }
+
+ // Iterate backwards till we find a starting point to compute
+ // the state from.
+ //
+ // Valid starting points are either a checkpoint AUM, or a
+ // genesis AUM.
+ curs := topAUM
+ var state State
+ for i := 0; true; i++ {
+ if i > maxIter {
+ return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
+ }
+
+ // Checkpoints encapsulate the state at that point, dope.
+ if curs.MessageKind == AUMCheckpoint {
+ state = curs.State.cloneForUpdate(&curs)
+ break
+ }
+ parent, hasParent := curs.Parent()
+ if !hasParent {
+ // This is a 'genesis' update: there are none before it, so
+ // this AUM can be applied to the empty state to determine
+ // the state at this AUM.
+ //
+ // It is only valid for NoOp, AddKey, and Checkpoint AUMs
+ // to be a genesis update. Checkpoint was handled earlier.
+ if mk := curs.MessageKind; mk == AUMNoOp || mk == AUMAddKey {
+ var err error
+ if state, err = (State{}).applyVerifiedAUM(curs); err != nil {
+ return State{}, fmt.Errorf("applying genesis (%+v): %v", curs, err)
+ }
+ break
+ }
+ return State{}, fmt.Errorf("invalid genesis update: %+v", curs)
+ }
+
+ // If we got here, the current state is dependent on the previous.
+ // Keep iterating backwards till thats not the case.
+ if curs, err = storage.AUM(parent); err != nil {
+ return State{}, fmt.Errorf("reading parent: %v", err)
+ }
+ }
+
+ // We now know some starting point state. Iterate forward till we
+ // are at the AUM we want state for.
+ _, state, err = fastForward(storage, maxIter, state, func(curs AUM, _ State) bool {
+ return curs.Hash() == wantHash
+ })
+ // fastForward only terminates before the done condition if it
+ // doesnt have any later AUMs to process. This cant be the case
+ // as we've already iterated through them above so they must exist,
+ // but we check anyway to be super duper sure.
+ if err == nil && *state.LastAUMHash != wantHash {
+ panic("unexpected fastForward outcome")
+ }
+ return state, err
+}
+
+// computeActiveAncestor determines which ancestor AUM to use as the
+// ancestor of the valid chain.
+//
+// If all the chains end up having the same ancestor, then thats the
+// only possible ancestor, ezpz. However if there are multiple distinct
+// ancestors, that means there are distinct chains, and we need some
+// hint to choose what to use. For that, we rely on the chainsThroughActive
+// bit, which signals to us that that ancestor was part of the
+// chain in a previous run.
+func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) {
+ // Dedupe possible ancestors, tracking if they were part of
+ // the active chain on a previous run.
+ ancestors := make(map[AUMHash]bool, len(chains))
+ for _, c := range chains {
+ ancestors[c.Oldest.Hash()] = c.chainsThroughActive
+ }
+
+ if len(ancestors) == 1 {
+ // There's only one. DOPE.
+ for k, _ := range ancestors {
+ return k, nil
+ }
+ }
+
+ // Theres more than one, so we need to use the ancestor that was
+ // part of the active chain in a previous iteration.
+ // Note that there can only be one distinct ancestor that was
+ // formerly part of the active chain, because AUMs can only have
+ // one parent and would have converged to a common ancestor.
+ for k, chainsThroughActive := range ancestors {
+ if chainsThroughActive {
+ return k, nil
+ }
+ }
+
+ return AUMHash{}, errors.New("multiple distinct chains")
+}
+
+// computeActiveChain bootstraps the runtime state of the Authority when
+// starting entirely off stored state.
+//
+// TODO(tom): Don't look at head states, just iterate forward from
+// the ancestor.
+//
+// The algorithm is as follows:
+// 1. Determine all possible 'head' (like in git) states.
+// 2. Filter these possible chains based on whether the ancestor was
+// formerly (in a previous run) part of the chain.
+// 3. Compute the state of the state machine at this ancestor. This is
+// needed for fast-forward, as each update operates on the state of
+// the update preceeding it.
+// 4. Iteratively apply updates till we reach head ('fast forward').
+func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (chain, error) {
+ chains, err := computeChainCandidates(storage, lastKnownOldest, maxIter)
+ if err != nil {
+ return chain{}, fmt.Errorf("computing candidates: %v", err)
+ }
+
+ // Find the right ancestor.
+ oldestHash, err := computeActiveAncestor(storage, chains)
+ if err != nil {
+ return chain{}, fmt.Errorf("computing ancestor: %v", err)
+ }
+ ancestor, err := storage.AUM(oldestHash)
+ if err != nil {
+ return chain{}, err
+ }
+
+ // At this stage we know the ancestor AUM, so we have excluded distinct
+ // chains but we might still have forks (so we don't know the head AUM).
+ //
+ // We iterate forward from the ancestor AUM, handling any forks as we go
+ // till we arrive at a head.
+ out := chain{Oldest: ancestor, Head: ancestor}
+ if out.state, err = computeStateAt(storage, maxIter, oldestHash); err != nil {
+ return chain{}, fmt.Errorf("bootstrapping state: %v", err)
+ }
+ out.Head, out.state, err = fastForward(storage, maxIter, out.state, nil)
+ if err != nil {
+ return chain{}, fmt.Errorf("fast forward: %v", err)
+ }
+ return out, nil
+}
diff --git a/tka/tka_test.go b/tka/tka_test.go
new file mode 100644
index 000000000..72b6d3476
--- /dev/null
+++ b/tka/tka_test.go
@@ -0,0 +1,187 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tka
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestComputeChainCandidates(t *testing.T) {
+ c := newTestchain(t, `
+ G1 -> I1 -> I2 -> I3 -> L2
+ | -> L1 | -> L3
+
+ G2 -> L4
+
+ // We tweak these AUMs so they are different hashes.
+ G2.hashSeed = 2
+ L1.hashSeed = 2
+ L3.hashSeed = 2
+ L4.hashSeed = 3
+ `)
+ // Should result in 4 chains:
+ // G1->L1, G1->L2, G1->L3, G2->L4
+
+ i1H := c.AUMHashes["I1"]
+ got, err := computeChainCandidates(c.Chonk(), &i1H, 50)
+ if err != nil {
+ t.Fatalf("computeChainCandidates() failed: %v", err)
+ }
+
+ want := []chain{
+ {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true},
+ {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true},
+ {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true},
+ {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]},
+ }
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" {
+ t.Errorf("chains differ (-want, +got):\n%s", diff)
+ }
+}
+
+func TestForkResolutionHash(t *testing.T) {
+ c := newTestchain(t, `
+ G1 -> L1
+ | -> L2
+
+ // tweak hashes so L1 & L2 are not identical
+ L1.hashSeed = 2
+ L2.hashSeed = 3
+ `)
+
+ got, err := computeActiveChain(c.Chonk(), nil, 50)
+ if err != nil {
+ t.Fatalf("computeActiveChain() failed: %v", err)
+ }
+
+ // The fork with the lowest AUM hash should have been chosen.
+ l1H := c.AUMHashes["L1"]
+ l2H := c.AUMHashes["L2"]
+ want := l1H
+ if bytes.Compare(l2H[:], l1H[:]) < 0 {
+ want = l2H
+ }
+
+ if got := got.Head.Hash(); got != want {
+ t.Errorf("head was %x, want %x", got, want)
+ }
+}
+
+func TestForkResolutionSigWeight(t *testing.T) {
+ pub, priv := testingKey25519(t, 1)
+ key := Key{Kind: Key25519, Public: pub, Votes: 2}
+
+ c := newTestchain(t, `
+ G1 -> L1
+ | -> L2
+
+ G1.template = addKey
+ L1.hashSeed = 2
+ L2.signedWith = key
+ `,
+ optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}),
+ optKey("key", key, priv))
+
+ l1H := c.AUMHashes["L1"]
+ l2H := c.AUMHashes["L2"]
+ if bytes.Compare(l2H[:], l1H[:]) < 0 {
+ t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes")
+ }
+
+ got, err := computeActiveChain(c.Chonk(), nil, 50)
+ if err != nil {
+ t.Fatalf("computeActiveChain() failed: %v", err)
+ }
+
+ // Based on the hash, l1H should be chosen.
+ // But based on the signature weight (which has higher
+ // precedence), it should be l2H
+ want := l2H
+ if got := got.Head.Hash(); got != want {
+ t.Errorf("head was %x, want %x", got, want)
+ }
+}
+
+func TestForkResolutionMessageType(t *testing.T) {
+ pub, _ := testingKey25519(t, 1)
+ key := Key{Kind: Key25519, Public: pub, Votes: 2}
+
+ c := newTestchain(t, `
+ G1 -> L1
+ | -> L2
+ | -> L3
+
+ G1.template = addKey
+ L1.hashSeed = 11
+ L2.template = removeKey
+ L3.hashSeed = 18
+ `,
+ optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}),
+ optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.ID()}))
+
+ l1H := c.AUMHashes["L1"]
+ l2H := c.AUMHashes["L2"]
+ l3H := c.AUMHashes["L3"]
+ if bytes.Compare(l2H[:], l1H[:]) < 0 {
+ t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes")
+ }
+ if bytes.Compare(l2H[:], l3H[:]) < 0 {
+ t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes")
+ }
+
+ got, err := computeActiveChain(c.Chonk(), nil, 50)
+ if err != nil {
+ t.Fatalf("computeActiveChain() failed: %v", err)
+ }
+
+ // Based on the hash, L1 or L3 should be chosen.
+ // But based on the preference for AUMRemoveKey messages,
+ // it should be L2.
+ want := l2H
+ if got := got.Head.Hash(); got != want {
+ t.Errorf("head was %x, want %x", got, want)
+ }
+}
+
+func TestComputeStateAt(t *testing.T) {
+ pub, _ := testingKey25519(t, 1)
+ key := Key{Kind: Key25519, Public: pub, Votes: 2}
+
+ c := newTestchain(t, `
+ G1 -> I1 -> I2
+ I1.template = addKey
+ `,
+ optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}))
+
+ // G1 is before the key, so there shouldn't be a key there.
+ state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"])
+ if err != nil {
+ t.Fatalf("computeStateAt(G1) failed: %v", err)
+ }
+ if _, err := state.GetKey(key.ID()); err != ErrNoSuchKey {
+ t.Errorf("expected key to be missing: err = %v", err)
+ }
+ if *state.LastAUMHash != c.AUMHashes["G1"] {
+ t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"])
+ }
+
+ // I1 & I2 are after the key, so the computed state should contain
+ // the key.
+ for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} {
+ state, err = computeStateAt(c.Chonk(), 500, wantHash)
+ if err != nil {
+ t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err)
+ }
+ if *state.LastAUMHash != wantHash {
+ t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash)
+ }
+ if _, err := state.GetKey(key.ID()); err != nil {
+ t.Errorf("expected key to be present at state: err = %v", err)
+ }
+ }
+}