summaryrefslogtreecommitdiffhomepage
path: root/cmd/natc
diff options
context:
space:
mode:
authorFran Bull <fran@tailscale.com>2024-09-30 10:27:39 -0700
committerFran Bull <fran@tailscale.com>2024-11-06 11:45:16 -0800
commit5f24261a1eee73b619e40b328ce20c709eff57df (patch)
tree8fc12810a09a90f1c9c7ed3812422e9b727bdb6b /cmd/natc
parent8dcbd988f7653aa17b33094d3f917125414aeab6 (diff)
downloadtailscale-fran/natc-raft.tar.xz
tailscale-fran/natc-raft.zip
Diffstat (limited to 'cmd/natc')
-rw-r--r--cmd/natc/consensus.go267
-rw-r--r--cmd/natc/http.go132
-rw-r--r--cmd/natc/ippool.go257
-rw-r--r--cmd/natc/ippool_test.go129
-rw-r--r--cmd/natc/natc.go154
5 files changed, 833 insertions, 106 deletions
diff --git a/cmd/natc/consensus.go b/cmd/natc/consensus.go
new file mode 100644
index 000000000..c326bab94
--- /dev/null
+++ b/cmd/natc/consensus.go
@@ -0,0 +1,267 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+ "time"
+
+ "github.com/hashicorp/raft"
+ "tailscale.com/ipn/ipnstate"
+ "tailscale.com/tsnet"
+)
+
+type consensus struct {
+ Raft *raft.Raft
+ CommandClient *commandClient
+ Self selfRaftNode
+}
+
+type selfRaftNode struct {
+ ID string
+ Addr netip.Addr
+}
+
+func (n *selfRaftNode) addrRaftPort() netip.AddrPort {
+ return netip.AddrPortFrom(n.Addr, 6311)
+}
+
+// StreamLayer implements an interface asked for by raft.NetworkTransport.
+// Do the raft interprocess comms via tailscale.
+type StreamLayer struct {
+ net.Listener
+ s *tsnet.Server
+}
+
+// Dial is used to create a new outgoing connection
+func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ return sl.s.Dial(ctx, "tcp", string(address))
+}
+
+type listeners struct {
+ raft *StreamLayer // for the raft goroutine
+ command net.Listener // for the command http goroutine
+}
+
+func NewConsensus(myAddr netip.Addr, httpClient *http.Client) *consensus {
+ cc := commandClient{
+ port: 6312,
+ httpClient: httpClient,
+ }
+ self := selfRaftNode{
+ ID: myAddr.String(),
+ Addr: myAddr,
+ }
+ return &consensus{
+ CommandClient: &cc,
+ Self: self,
+ }
+}
+
+func (c *consensus) Start(lns *listeners, sm *fsm) error {
+ config := raft.DefaultConfig()
+ config.LocalID = raft.ServerID(c.Self.ID)
+ config.HeartbeatTimeout = 1000 * time.Millisecond
+ config.ElectionTimeout = 1000 * time.Millisecond
+ logStore := raft.NewInmemStore()
+ stableStore := raft.NewInmemStore()
+ snapshots := raft.NewInmemSnapshotStore()
+ transport := raft.NewNetworkTransport(lns.raft, 5, 5*time.Second, nil)
+
+ ra, err := raft.NewRaft(config, sm, logStore, stableStore, snapshots, transport)
+ if err != nil {
+ return fmt.Errorf("new raft: %s", err)
+ }
+ c.Raft = ra
+
+ mux := c.makeCommandMux()
+ go func() {
+ defer lns.command.Close()
+ log.Fatal(http.Serve(lns.command, mux))
+ }()
+ return nil
+}
+
+func (c *consensus) handleJoin(jr joinRequest) error {
+ configFuture := c.Raft.GetConfiguration()
+ if err := configFuture.Error(); err != nil {
+ return err
+ }
+
+ for _, srv := range configFuture.Configuration().Servers {
+ // If a node already exists with either the joining node's ID or address,
+ // that node may need to be removed from the config first.
+ if srv.ID == raft.ServerID(jr.RemoteID) || srv.Address == raft.ServerAddress(jr.RemoteAddr) {
+ // However if *both* the ID and the address are the same, then nothing -- not even
+ // a join operation -- is needed.
+ if srv.Address == raft.ServerAddress(jr.RemoteAddr) && srv.ID == raft.ServerID(jr.RemoteID) {
+ log.Printf("node %s at %s already member of cluster, ignoring join request", jr.RemoteID, jr.RemoteAddr)
+ return nil
+ }
+
+ future := c.Raft.RemoveServer(srv.ID, 0, 0)
+ if err := future.Error(); err != nil {
+ return fmt.Errorf("error removing existing node %s at %s: %s", jr.RemoteID, jr.RemoteAddr, err)
+ }
+ }
+ }
+
+ f := c.Raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(jr.RemoteAddr), 0, 0)
+ if f.Error() != nil {
+ return f.Error()
+ }
+ return nil
+}
+
+// try to join a raft cluster, or start one
+func BootstrapConsensus(sm *fsm, myAddr netip.Addr, lns *listeners, targets []*ipnstate.PeerStatus, httpClient *http.Client) (*consensus, error) {
+ cns := NewConsensus(myAddr, httpClient)
+ err := cns.Start(lns, sm)
+ if err != nil {
+ return cns, err
+ }
+ joined := false
+ log.Printf("Trying to find cluster: num targets to try: %d", len(targets))
+ for _, p := range targets {
+ if !p.Online {
+ log.Printf("Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0])
+ } else {
+ log.Printf("Trying to find cluster: trying %s", p.TailscaleIPs[0])
+ err = cns.JoinCluster(p.TailscaleIPs[0])
+ if err != nil {
+ log.Printf("Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err)
+ } else {
+ log.Printf("Trying to find cluster: joined %s", p.TailscaleIPs[0])
+ joined = true
+ break
+ }
+ }
+ }
+
+ if !joined {
+ log.Printf("Trying to find cluster: unsuccessful, starting as leader: %s", myAddr)
+ err = cns.LeadCluster()
+ if err != nil {
+ return cns, err
+ }
+ }
+ return cns, nil
+}
+
+func (c *consensus) JoinCluster(a netip.Addr) error {
+ return c.CommandClient.Join(c.CommandClient.ServerAddressFromAddr(a), joinRequest{
+ RemoteAddr: c.Self.addrRaftPort().String(),
+ RemoteID: c.Self.ID,
+ })
+
+}
+
+func (c *consensus) LeadCluster() error {
+ configuration := raft.Configuration{
+ Servers: []raft.Server{
+ {
+ ID: raft.ServerID(c.Self.ID),
+ Address: raft.ServerAddress(fmt.Sprintf("%s:6311", c.Self.Addr)),
+ },
+ },
+ }
+ f := c.Raft.BootstrapCluster(configuration)
+ return f.Error()
+}
+
+// plumbing for executing a command either locally or via http transport
+// and telling peers we're not the leader and who we think the leader is
+type command struct {
+ Name string
+ Args []byte
+}
+
+type commandResult struct {
+ Err error
+ Result []byte
+}
+
+type lookElsewhereError struct {
+ where string
+}
+
+func (e lookElsewhereError) Error() string {
+ return fmt.Sprintf("not the leader, try: %s", e.where)
+}
+
+func (c *consensus) executeCommandLocally(cmd command) (commandResult, error) {
+ b, err := json.Marshal(cmd)
+ if err != nil {
+ return commandResult{}, err
+ }
+ f := c.Raft.Apply(b, 10*time.Second)
+ err = f.Error()
+ result := f.Response()
+ if errors.Is(err, raft.ErrNotLeader) {
+ raftLeaderAddr, _ := c.Raft.LeaderWithID()
+ leaderAddr := (string)(raftLeaderAddr)
+ if leaderAddr != "" {
+ leaderAddr = leaderAddr[:len(raftLeaderAddr)-1] + "2" // TODO
+ }
+ return commandResult{}, lookElsewhereError{where: leaderAddr}
+ }
+ return result.(commandResult), err
+}
+
+func (c *consensus) executeCommand(cmd command) (commandResult, error) {
+ b, err := json.Marshal(cmd)
+ if err != nil {
+ return commandResult{}, err
+ }
+ result, err := c.executeCommandLocally(cmd)
+ var leErr lookElsewhereError
+ for errors.As(err, &leErr) {
+ result, err = c.CommandClient.ExecuteCommand(leErr.where, b)
+ }
+ return result, err
+}
+
+// fulfil the raft lib functional state machine interface
+type fsm ipPool
+type fsmSnapshot struct{}
+
+func (f *fsm) Apply(l *raft.Log) interface{} {
+ var c command
+ if err := json.Unmarshal(l.Data, &c); err != nil {
+ panic(fmt.Sprintf("failed to unmarshal command: %s", err.Error()))
+ }
+ switch c.Name {
+ case "checkoutAddr":
+ return f.executeCheckoutAddr(c.Args)
+ case "markLastUsed":
+ return f.executeMarkLastUsed(c.Args)
+ default:
+ panic(fmt.Sprintf("unrecognized command: %s", c.Name))
+ }
+}
+
+func (f *fsm) Snapshot() (raft.FSMSnapshot, error) {
+ panic("Snapshot unexpectedly used")
+ return nil, nil
+}
+
+func (f *fsm) Restore(rc io.ReadCloser) error {
+ panic("Restore unexpectedly used")
+ return nil
+}
+
+func (f *fsmSnapshot) Persist(sink raft.SnapshotSink) error {
+ panic("Persist unexpectedly used")
+ return nil
+}
+
+func (f *fsmSnapshot) Release() {
+ panic("Release unexpectedly used")
+}
diff --git a/cmd/natc/http.go b/cmd/natc/http.go
new file mode 100644
index 000000000..06ff334d7
--- /dev/null
+++ b/cmd/natc/http.go
@@ -0,0 +1,132 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/netip"
+ "time"
+)
+
+type joinRequest struct {
+ RemoteAddr string `json:'remoteAddr'`
+ RemoteID string `json:'remoteID'`
+}
+
+type commandClient struct {
+ port int
+ httpClient *http.Client
+}
+
+func (rac *commandClient) ServerAddressFromAddr(addr netip.Addr) string {
+ return fmt.Sprintf("%s:%d", addr, rac.port)
+}
+
+func (rac *commandClient) Url(serverAddr string, path string) string {
+ return fmt.Sprintf("http://%s%s", serverAddr, path)
+}
+
+func (rac *commandClient) Join(serverAddr string, jr joinRequest) error {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ rBs, err := json.Marshal(jr)
+ if err != nil {
+ return err
+ }
+ url := rac.Url(serverAddr, "/join")
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rBs))
+ if err != nil {
+ return err
+ }
+ resp, err := rac.httpClient.Do(req)
+ if err != nil {
+ return err
+ }
+ respBs, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode != 200 {
+ return errors.New(fmt.Sprintf("remote responded %d: %s", resp.StatusCode, string(respBs)))
+ }
+ return nil
+}
+
+func (rac *commandClient) ExecuteCommand(serverAddr string, bs []byte) (commandResult, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ url := rac.Url(serverAddr, "/executeCommand")
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs))
+ if err != nil {
+ return commandResult{}, err
+ }
+ resp, err := rac.httpClient.Do(req)
+ if err != nil {
+ return commandResult{}, err
+ }
+ respBs, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return commandResult{}, err
+ }
+ if resp.StatusCode != 200 {
+ return commandResult{}, errors.New(fmt.Sprintf("remote responded %d: %s", resp.StatusCode, string(respBs)))
+ }
+ var cr commandResult
+ if err = json.Unmarshal(respBs, &cr); err != nil {
+ return commandResult{}, err
+ }
+ return cr, nil
+}
+
+func (c *consensus) makeCommandMux() *http.ServeMux {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Bad Request", http.StatusBadRequest)
+ return
+ }
+ decoder := json.NewDecoder(r.Body)
+ var jr joinRequest
+ err := decoder.Decode(&jr)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ if jr.RemoteAddr == "" {
+ http.Error(w, "Required: remoteAddr", http.StatusBadRequest)
+ return
+ }
+ if jr.RemoteID == "" {
+ http.Error(w, "Required: remoteID", http.StatusBadRequest)
+ return
+ }
+ err = c.handleJoin(jr)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ })
+ mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Bad Request", http.StatusBadRequest)
+ return
+ }
+ decoder := json.NewDecoder(r.Body)
+ var cmd command
+ err := decoder.Decode(&cmd)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ result, err := c.executeCommandLocally(cmd)
+ if err := json.NewEncoder(w).Encode(result); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ })
+ return mux
+}
diff --git a/cmd/natc/ippool.go b/cmd/natc/ippool.go
new file mode 100644
index 000000000..091a93c87
--- /dev/null
+++ b/cmd/natc/ippool.go
@@ -0,0 +1,257 @@
+package main
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "net/netip"
+ "sync"
+ "time"
+
+ "github.com/gaissmai/bart"
+ "tailscale.com/ipn/ipnstate"
+ "tailscale.com/syncs"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tsnet"
+ "tailscale.com/util/mak"
+)
+
+type ipPool struct {
+ perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState]
+ v4Ranges []netip.Prefix
+ dnsAddr netip.Addr
+ consensus *consensus
+}
+
+func (ipp *ipPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, updatedAt time.Time) string {
+ // TODO lock
+ pm, ok := ipp.perPeerMap.Load(from)
+ if !ok {
+ log.Printf("DomainForIP: peer state absent for: %d", from)
+ return ""
+ }
+ ww, ok := pm.AddrToDomain.Lookup(addr)
+ if !ok {
+ log.Printf("DomainForIP: peer state doesn't recognize domain")
+ return ""
+ }
+ go func() {
+ err := ipp.markLastUsed(from, addr, ww.Domain, updatedAt)
+ if err != nil {
+ panic(err)
+ }
+ }()
+ return ww.Domain
+}
+
+type markLastUsedArgs struct {
+ NodeID tailcfg.NodeID
+ Addr netip.Addr
+ Domain string
+ UpdatedAt time.Time
+}
+
+// called by raft
+func (cd *fsm) executeMarkLastUsed(bs []byte) commandResult {
+ var args markLastUsedArgs
+ err := json.Unmarshal(bs, &args)
+ if err != nil {
+ return commandResult{Err: err}
+ }
+ err = cd.applyMarkLastUsed(args.NodeID, args.Addr, args.Domain, args.UpdatedAt)
+ if err != nil {
+ return commandResult{Err: err}
+ }
+ return commandResult{}
+}
+
+func (ipp *fsm) applyMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, domain string, updatedAt time.Time) error {
+ // TODO lock
+ ps, ok := ipp.perPeerMap.Load(from)
+ if !ok {
+ // unexpected in normal operation (but not an error?)
+ return nil
+ }
+ ww, ok := ps.AddrToDomain.Lookup(addr)
+ if !ok {
+ // unexpected in normal operation (but not an error?)
+ return nil
+ }
+ if ww.Domain != domain {
+ // then I guess we're too late to update lastUsed
+ return nil
+ }
+ if ww.LastUsed.After(updatedAt) {
+ // prefer the most recent
+ return nil
+ }
+ ww.LastUsed = updatedAt
+ ps.AddrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), ww)
+ return nil
+}
+
+func (ipp *ipPool) StartConsensus(peers []*ipnstate.PeerStatus, ts *tsnet.Server) {
+ v4, _ := ts.TailscaleIPs()
+ adminLn, err := ts.Listen("tcp", fmt.Sprintf("%s:6312", v4))
+ if err != nil {
+ log.Fatal(err)
+ }
+ raftLn, err := ts.Listen("tcp", fmt.Sprintf("%s:6311", v4))
+ if err != nil {
+ log.Fatal(err)
+ }
+ sl := StreamLayer{s: ts, Listener: raftLn}
+ lns := listeners{command: adminLn, raft: &sl}
+ cns, err := BootstrapConsensus((*fsm)(ipp), v4, &lns, peers, ts.HTTPClient())
+ if err != nil {
+ log.Fatalf("BootstrapConsensus failed: %v", err)
+ }
+ ipp.consensus = cns
+}
+
+type whereWhen struct {
+ Domain string
+ LastUsed time.Time
+}
+
+type perPeerState struct {
+ DomainToAddr map[string]netip.Addr
+ AddrToDomain *bart.Table[whereWhen]
+ mu sync.Mutex // not jsonified
+}
+
+func (ps *perPeerState) unusedIPV4(ranges []netip.Prefix, exclude netip.Addr, reuseDeadline time.Time) (netip.Addr, bool, string, error) {
+ // TODO here we iterate through each ip within the ranges until we find one that's unused
+ // could be done more efficiently either by:
+ // 1) storing an index into ranges and an ip we had last used from that range in perPeerState
+ // (how would this work with checking ips back into the pool though?)
+ // 2) using a random approach like the natc does now, except the raft state machine needs to
+ // be deterministic so it can replay logs, so I think we would do something like generate a
+ // random ip each time, and then have a call into the state machine that says "give me whatever
+ // ip you have, and if you don't have one use this one". I think that would work.
+ for _, r := range ranges {
+ ip := r.Addr()
+ for r.Contains(ip) {
+ if ip != exclude {
+ ww, ok := ps.AddrToDomain.Lookup(ip)
+ if !ok {
+ return ip, false, "", nil
+ }
+ if ww.LastUsed.Before(reuseDeadline) {
+ return ip, true, ww.Domain, nil
+ }
+ }
+ ip = ip.Next()
+ }
+ }
+ return netip.Addr{}, false, "", errors.New("ip pool exhausted")
+}
+
+func (cd *ipPool) IpForDomain(nid tailcfg.NodeID, domain string) (netip.Addr, error) {
+ now := time.Now()
+ args := checkoutAddrArgs{
+ NodeID: nid,
+ Domain: domain,
+ ReuseDeadline: now.Add(-10 * time.Second), // TODO what time period? 48 hours?
+ UpdatedAt: now,
+ }
+ bs, err := json.Marshal(args)
+ if err != nil {
+ return netip.Addr{}, err
+ }
+ c := command{
+ Name: "checkoutAddr",
+ Args: bs,
+ }
+ result, err := cd.consensus.executeCommand(c)
+ if err != nil {
+ log.Printf("IpForDomain: raft error executing command: %v", err)
+ return netip.Addr{}, err
+ }
+ if result.Err != nil {
+ log.Printf("IpForDomain: error returned from state machine: %v", err)
+ return netip.Addr{}, result.Err
+ }
+ var addr netip.Addr
+ err = json.Unmarshal(result.Result, &addr)
+ return addr, err
+}
+
+func (cd *ipPool) markLastUsed(nid tailcfg.NodeID, addr netip.Addr, domain string, lastUsed time.Time) error {
+ args := markLastUsedArgs{
+ NodeID: nid,
+ Addr: addr,
+ Domain: domain,
+ UpdatedAt: lastUsed,
+ }
+ bs, err := json.Marshal(args)
+ if err != nil {
+ return err
+ }
+ c := command{
+ Name: "markLastUsed",
+ Args: bs,
+ }
+ result, err := cd.consensus.executeCommand(c)
+ if err != nil {
+ log.Printf("markLastUsed: raft error executing command: %v", err)
+ return err
+ }
+ if result.Err != nil {
+ log.Printf("markLastUsed: error returned from state machine: %v", err)
+ return result.Err
+ }
+ return nil
+}
+
+type checkoutAddrArgs struct {
+ NodeID tailcfg.NodeID
+ Domain string
+ ReuseDeadline time.Time
+ UpdatedAt time.Time
+}
+
+// called by raft
+func (cd *fsm) executeCheckoutAddr(bs []byte) commandResult {
+ var args checkoutAddrArgs
+ err := json.Unmarshal(bs, &args)
+ if err != nil {
+ return commandResult{Err: err}
+ }
+ addr, err := cd.applyCheckoutAddr(args.NodeID, args.Domain, args.ReuseDeadline, args.UpdatedAt)
+ if err != nil {
+ return commandResult{Err: err}
+ }
+ resultBs, err := json.Marshal(addr)
+ if err != nil {
+ return commandResult{Err: err}
+ }
+ return commandResult{Result: resultBs}
+}
+
+func (cd *fsm) applyCheckoutAddr(nid tailcfg.NodeID, domain string, reuseDeadline, updatedAt time.Time) (netip.Addr, error) {
+ // TODO lock and unlock
+ pm, _ := cd.perPeerMap.LoadOrStore(nid, &perPeerState{
+ AddrToDomain: &bart.Table[whereWhen]{},
+ })
+ if existing, ok := pm.DomainToAddr[domain]; ok {
+ // TODO handle error case where this doesn't exist
+ ww, _ := pm.AddrToDomain.Lookup(existing)
+ ww.LastUsed = updatedAt
+ pm.AddrToDomain.Insert(netip.PrefixFrom(existing, existing.BitLen()), ww)
+ return existing, nil
+ }
+ addr, wasInUse, previousDomain, err := pm.unusedIPV4(cd.v4Ranges, cd.dnsAddr, reuseDeadline)
+ if err != nil {
+ return netip.Addr{}, err
+ }
+ mak.Set(&pm.DomainToAddr, domain, addr)
+ if wasInUse {
+ // remove it from domaintoaddr
+ delete(pm.DomainToAddr, previousDomain)
+ // don't need to remove it from addrtodomain, insert will do that
+ }
+ pm.AddrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), whereWhen{Domain: domain, LastUsed: updatedAt})
+ return addr, nil
+}
diff --git a/cmd/natc/ippool_test.go b/cmd/natc/ippool_test.go
new file mode 100644
index 000000000..fd5fc9c3e
--- /dev/null
+++ b/cmd/natc/ippool_test.go
@@ -0,0 +1,129 @@
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/netip"
+ "testing"
+
+ "tailscale.com/tailcfg"
+)
+
+func TestV6V4(t *testing.T) {
+ c := connector{
+ v6ULA: ula(uint16(1)),
+ }
+
+ tests := [][]string{
+ []string{"100.64.0.0", "fd7a:115c:a1e0:a99c:1:0:6440:0"},
+ []string{"0.0.0.0", "fd7a:115c:a1e0:a99c:1::"},
+ []string{"255.255.255.255", "fd7a:115c:a1e0:a99c:1:0:ffff:ffff"},
+ }
+
+ for i, test := range tests {
+ // to v6
+ v6 := c.v6ForV4(netip.MustParseAddr(test[0]))
+ want := netip.MustParseAddr(test[1])
+ if v6 != want {
+ t.Fatalf("test %d: want: %v, got: %v", i, want, v6)
+ }
+
+ // to v4
+ v4 := v4ForV6(netip.MustParseAddr(test[1]))
+ want = netip.MustParseAddr(test[0])
+ if v4 != want {
+ t.Fatalf("test %d: want: %v, got: %v", i, want, v4)
+ }
+ }
+}
+
+func TestIPForDomain(t *testing.T) {
+ pfx := netip.MustParsePrefix("100.64.0.0/16")
+ ipp := fsm{
+ v4Ranges: []netip.Prefix{pfx},
+ dnsAddr: netip.MustParseAddr("100.64.0.0"),
+ }
+ a, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "example.com")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !pfx.Contains(a) {
+ t.Fatalf("expected %v to be in the prefix %v", a, pfx)
+ }
+
+ b, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "a.example.com")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !pfx.Contains(b) {
+ t.Fatalf("expected %v to be in the prefix %v", b, pfx)
+ }
+ if b == a {
+ t.Fatalf("same address issued twice %v, %v", a, b)
+ }
+
+ c, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "example.com")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if c != a {
+ t.Fatalf("expected %v to be remembered as the addr for example.com, but got %v", a, c)
+ }
+}
+
+func TestDomainForIP(t *testing.T) {
+ pfx := netip.MustParsePrefix("100.64.0.0/16")
+ sm := fsm{
+ v4Ranges: []netip.Prefix{pfx},
+ dnsAddr: netip.MustParseAddr("100.64.0.0"),
+ }
+ ipp := (*ipPool)(&sm)
+ nid := tailcfg.NodeID(1)
+ domain := "example.com"
+ d := ipp.DomainForIP(nid, netip.MustParseAddr("100.64.0.1"))
+ if d != "" {
+ t.Fatalf("expected an empty string if the addr is not found but got %s", d)
+ }
+ a, err := sm.applyCheckoutAddr(nid, domain)
+ if err != nil {
+ t.Fatal(err)
+ }
+ d2 := ipp.DomainForIP(nid, a)
+ if d2 != domain {
+ t.Fatalf("expected %s but got %s", domain, d2)
+ }
+}
+
+func TestBlah(t *testing.T) {
+ type ecr interface {
+ getResult() interface{}
+ setResult(interface{})
+ toJSON() ([]byte, error)
+ fromJSON([]byte) err
+ }
+ type fran struct {
+ Result netip.Addr
+ }
+ func(f *fran) toJSON() string {
+ return json.Marshal(f)
+ }
+ func(f *fran) fromJSON(bs []byte) err {
+ return json.UnMarshal(bs, f)
+ }
+ thrujson := func(in ecr) ecr {
+ bs, err := json.Marshal(in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var out ecr
+ err = json.Unmarshal(bs, &out)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return out
+ }
+ a := netip.Addr{}
+ out := thrujson(ecr{Result: a}).Result
+ b := (out).(netip.Addr)
+ fmt.Println(b)
+}
diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go
index d94523c6e..433bdc6a6 100644
--- a/cmd/natc/natc.go
+++ b/cmd/natc/natc.go
@@ -8,18 +8,16 @@ package main
import (
"context"
- "encoding/binary"
"errors"
"flag"
"fmt"
"log"
- "math/rand/v2"
"net"
"net/http"
"net/netip"
"os"
+ "slices"
"strings"
- "sync"
"time"
"github.com/gaissmai/bart"
@@ -30,13 +28,11 @@ import (
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/ipn"
+ "tailscale.com/ipn/ipnstate"
"tailscale.com/net/netutil"
- "tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tsweb"
- "tailscale.com/util/dnsname"
- "tailscale.com/util/mak"
)
func main() {
@@ -56,6 +52,7 @@ func main() {
printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit")
ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore")
wgPort = fs.Uint("wg-port", 0, "udp port for wireguard and peer to peer traffic")
+ clusterTag = fs.String("cluster-tag", "", "TODO")
)
ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_NATC"))
@@ -105,6 +102,7 @@ func main() {
ts := &tsnet.Server{
Hostname: *hostname,
}
+ ts.ControlURL = "http://host.docker.internal:31544"
if *wgPort != 0 {
if *wgPort >= 1<<16 {
log.Fatalf("wg-port must be in the range [0, 65535]")
@@ -112,6 +110,7 @@ func main() {
ts.Port = uint16(*wgPort)
}
defer ts.Close()
+
if *verboseTSNet {
ts.Logf = log.Printf
}
@@ -136,6 +135,28 @@ func main() {
if _, err := ts.Up(ctx); err != nil {
log.Fatalf("ts.Up: %v", err)
}
+ woo, err := lc.Status(ctx)
+ if err != nil {
+ panic(err)
+ }
+ var peers []*ipnstate.PeerStatus
+ if *clusterTag != "" && woo.Self.Tags != nil && slices.Contains(woo.Self.Tags.AsSlice(), *clusterTag) {
+ for _, v := range woo.Peer {
+ if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), *clusterTag) {
+ peers = append(peers, v)
+ }
+ }
+ } else {
+ // we are not in clustering mode I guess?
+ panic("todo")
+ }
+
+ ipp := ipPool{
+ v4Ranges: v4Prefixes,
+ dnsAddr: dnsAddr,
+ }
+
+ ipp.StartConsensus(peers, ts)
c := &connector{
ts: ts,
@@ -144,6 +165,7 @@ func main() {
v4Ranges: v4Prefixes,
v6ULA: ula(uint16(*siteID)),
ignoreDsts: ignoreDstTable,
+ ipAddrs: &ipp,
}
c.run(ctx)
}
@@ -165,7 +187,7 @@ type connector struct {
// v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
v6ULA netip.Prefix
- perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState]
+ ipAddrs *ipPool
// ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil)
// It is never mutated, only used for lookups.
@@ -332,16 +354,15 @@ var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
// generateDNSResponse generates a DNS response for the given request. The from
// argument is the NodeID of the node that sent the request.
func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) {
- pm, _ := c.perPeerMap.LoadOrStore(from, &perPeerState{c: c})
var addrs []netip.Addr
if len(req.Questions) > 0 {
switch req.Questions[0].Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
- var err error
- addrs, err = pm.ipForDomain(req.Questions[0].Name.String())
+ v4, err := c.ipAddrs.IpForDomain(from, req.Questions[0].Name.String())
if err != nil {
return nil, err
}
+ addrs = []netip.Addr{v4, c.v6ForV4(v4)}
}
}
return dnsResponse(req, addrs)
@@ -429,14 +450,13 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con
}
from := who.Node.ID
- ps, ok := c.perPeerMap.Load(from)
- if !ok {
- log.Printf("handleTCPFlow: no perPeerState for %v", from)
- return nil, false
+ dstAddr := dst.Addr()
+ if dstAddr.Is6() {
+ dstAddr = v4ForV6(dstAddr)
}
- domain, ok := ps.domainForIP(dst.Addr())
- if !ok {
- log.Printf("handleTCPFlow: no domain for IP %v\n", dst.Addr())
+ domain := c.ipAddrs.DomainForIP(from, dstAddr, time.Now())
+ if domain == "" {
+ log.Print("handleTCPFlow: found no domain")
return nil, false
}
return func(conn net.Conn) {
@@ -480,96 +500,18 @@ func proxyTCPConn(c net.Conn, dest string) {
p.Start()
}
-// perPeerState holds the state for a single peer.
-type perPeerState struct {
- c *connector
-
- mu sync.Mutex
- domainToAddr map[string][]netip.Addr
- addrToDomain *bart.Table[string]
-}
-
-// domainForIP returns the domain name assigned to the given IP address and
-// whether it was found.
-func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) {
- ps.mu.Lock()
- defer ps.mu.Unlock()
- if ps.addrToDomain == nil {
- return "", false
- }
- return ps.addrToDomain.Lookup(ip)
-}
-
-// ipForDomain assigns a pair of unique IP addresses for the given domain and
-// returns them. The first address is an IPv4 address and the second is an IPv6
-// address. If the domain already has assigned addresses, it returns them.
-func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) {
- fqdn, err := dnsname.ToFQDN(domain)
- if err != nil {
- return nil, err
- }
- domain = fqdn.WithoutTrailingDot()
-
- ps.mu.Lock()
- defer ps.mu.Unlock()
- if addrs, ok := ps.domainToAddr[domain]; ok {
- return addrs, nil
- }
- addrs := ps.assignAddrsLocked(domain)
- return addrs, nil
-}
-
-// isIPUsedLocked reports whether the given IP address is already assigned to a
-// domain.
-// ps.mu must be held.
-func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool {
- _, ok := ps.addrToDomain.Lookup(ip)
- return ok
-}
-
-// unusedIPv4Locked returns an unused IPv4 address from the available ranges.
-func (ps *perPeerState) unusedIPv4Locked() netip.Addr {
- // TODO: skip ranges that have been exhausted
- for _, r := range ps.c.v4Ranges {
- ip := randV4(r)
- for r.Contains(ip) {
- if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr {
- return ip
- }
- ip = ip.Next()
- }
- }
- return netip.Addr{}
-}
-
-// randV4 returns a random IPv4 address within the given prefix.
-func randV4(maskedPfx netip.Prefix) netip.Addr {
- bits := 32 - maskedPfx.Bits()
- randBits := rand.Uint32N(1 << uint(bits))
-
- ip4 := maskedPfx.Addr().As4()
- pn := binary.BigEndian.Uint32(ip4[:])
- binary.BigEndian.PutUint32(ip4[:], randBits|pn)
- return netip.AddrFrom4(ip4)
-}
-
-// assignAddrsLocked assigns a pair of unique IP addresses for the given domain
-// and returns them. The first address is an IPv4 address and the second is an
-// IPv6 address. It does not check if the domain already has assigned addresses.
-// ps.mu must be held.
-func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr {
- if ps.addrToDomain == nil {
- ps.addrToDomain = &bart.Table[string]{}
- }
- v4 := ps.unusedIPv4Locked()
- as16 := ps.c.v6ULA.Addr().As16()
+func (c *connector) v6ForV4(v4 netip.Addr) netip.Addr {
+ as16 := c.v6ULA.Addr().As16()
as4 := v4.As4()
copy(as16[12:], as4[:])
v6 := netip.AddrFrom16(as16)
- addrs := []netip.Addr{v4, v6}
- mak.Set(&ps.domainToAddr, domain, addrs)
- for _, a := range addrs {
- ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain)
- }
- return addrs
+ return v6
+}
+
+func v4ForV6(v6 netip.Addr) netip.Addr {
+ as16 := v6.As16()
+ var as4 [4]byte
+ copy(as4[:], as16[12:])
+ v4 := netip.AddrFrom4(as4)
+ return v4
}