summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tsconsensus/authorization.go80
-rw-r--r--tsconsensus/http.go18
-rw-r--r--tsconsensus/tsconsensus.go136
3 files changed, 137 insertions, 97 deletions
diff --git a/tsconsensus/authorization.go b/tsconsensus/authorization.go
new file mode 100644
index 000000000..67e685def
--- /dev/null
+++ b/tsconsensus/authorization.go
@@ -0,0 +1,80 @@
+package tsconsensus
+
+import (
+ "context"
+ "net/netip"
+ "slices"
+
+ "tailscale.com/ipn/ipnstate"
+ "tailscale.com/tsnet"
+)
+
+type authorization struct {
+ ts *tsnet.Server
+ tag string
+ peers *peers
+}
+
+func (a *authorization) refresh(ctx context.Context) error {
+ lc, err := a.ts.LocalClient()
+ if err != nil {
+ return err
+ }
+ tStatus, err := lc.Status(ctx)
+ if err != nil {
+ return err
+ }
+ a.peers = newPeers(tStatus)
+ return nil
+}
+
+func (a *authorization) allowsHost(addr netip.Addr) bool {
+ return a.peers.peerExists(addr, a.tag)
+}
+
+func (a *authorization) selfAllowed() bool {
+ return a.peers.status.Self.Tags != nil && slices.Contains(a.peers.status.Self.Tags.AsSlice(), a.tag)
+}
+
+func (a *authorization) allowedPeers() []*ipnstate.PeerStatus {
+ if a.peers.allowedPeers == nil {
+ return []*ipnstate.PeerStatus{}
+ }
+ return a.peers.allowedPeers
+}
+
+type peers struct {
+ status *ipnstate.Status
+ peerByIPAddressAndTag map[netip.Addr]map[string]*ipnstate.PeerStatus
+ allowedPeers []*ipnstate.PeerStatus
+}
+
+func (ps *peers) peerExists(a netip.Addr, tag string) bool {
+ byTag, ok := ps.peerByIPAddressAndTag[a]
+ if !ok {
+ return false
+ }
+ _, ok = byTag[tag]
+ return ok
+}
+
+func newPeers(status *ipnstate.Status) *peers {
+ ps := &peers{
+ peerByIPAddressAndTag: map[netip.Addr]map[string]*ipnstate.PeerStatus{},
+ status: status,
+ }
+ for _, p := range status.Peer {
+ for _, addr := range p.TailscaleIPs {
+ if ps.peerByIPAddressAndTag[addr] == nil {
+ ps.peerByIPAddressAndTag[addr] = map[string]*ipnstate.PeerStatus{}
+ }
+ if p.Tags != nil {
+ for _, tag := range p.Tags.AsSlice() {
+ ps.peerByIPAddressAndTag[addr][tag] = p
+ ps.allowedPeers = append(ps.allowedPeers, p)
+ }
+ }
+ }
+ }
+ return ps
+}
diff --git a/tsconsensus/http.go b/tsconsensus/http.go
index 301127687..7570e2936 100644
--- a/tsconsensus/http.go
+++ b/tsconsensus/http.go
@@ -9,8 +9,6 @@ import (
"io"
"net/http"
"time"
-
- "tailscale.com/tsnet"
)
type joinRequest struct {
@@ -79,13 +77,19 @@ func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult,
return cr, nil
}
-func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
+func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
- allowed, err := allowedPeer(r.RemoteAddr, tag, ts)
+ err := auth.refresh(r.Context())
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ a, err := addrFromServerAddress(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ allowed := auth.allowsHost(a)
if !allowed {
http.Error(w, "peer not allowed", http.StatusBadRequest)
return
@@ -94,9 +98,9 @@ func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http
}
}
-func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux {
+func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
mux := http.NewServeMux()
- mux.HandleFunc("/join", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) {
+ mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
@@ -122,7 +126,7 @@ func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux
return
}
}))
- mux.HandleFunc("/executeCommand", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) {
+ mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go
index c6fb4f30b..a5aa75635 100644
--- a/tsconsensus/tsconsensus.go
+++ b/tsconsensus/tsconsensus.go
@@ -9,7 +9,6 @@ import (
"net"
"net/http"
"net/netip"
- "slices"
"time"
"github.com/hashicorp/raft"
@@ -79,42 +78,39 @@ func DefaultConfig() Config {
}
}
+func addrFromServerAddress(sa string) (netip.Addr, error) {
+ sAddr, _, err := net.SplitHostPort(sa)
+ if err != nil {
+ return netip.Addr{}, err
+ }
+ return netip.ParseAddr(sAddr)
+}
+
// StreamLayer implements an interface asked for by raft.NetworkTransport.
// It does the raft interprocess communication via tailscale.
type StreamLayer struct {
net.Listener
- s *tsnet.Server
- tag string
+ auth *authorization
+ s *tsnet.Server
}
// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
- allowed, err := allowedPeer(string(address), sl.tag, sl.s)
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ err := sl.auth.refresh(ctx)
if err != nil {
return nil, err
}
- if !allowed {
- return nil, errors.New("peer is not allowed")
- }
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- return sl.s.Dial(ctx, "tcp", string(address))
-}
-func allowedPeer(remoteAddr string, tag string, s *tsnet.Server) (bool, error) {
- sAddr, _, err := net.SplitHostPort(remoteAddr)
+ addr, err := addrFromServerAddress(string(address))
if err != nil {
- return false, err
- }
- a, err := netip.ParseAddr(sAddr)
- if err != nil {
- return false, err
+ return nil, err
}
- ctx := context.Background() // TODO very much a sign I shouldn't be doing this here
- peers, err := taggedNodesFromStatus(ctx, tag, s)
- if err != nil {
- return false, err
+
+ if !sl.auth.allowsHost(addr) {
+ return nil, errors.New("peer is not allowed")
}
- return peers.has(a), nil
+ return sl.s.Dial(ctx, "tcp", string(address))
}
func (sl StreamLayer) Accept() (net.Conn, error) {
@@ -123,74 +119,26 @@ func (sl StreamLayer) Accept() (net.Conn, error) {
if err != nil || conn == nil {
return conn, err
}
- allowed, err := allowedPeer(conn.RemoteAddr().String(), sl.tag, sl.s)
+ ctx := context.Background() // TODO
+ err = sl.auth.refresh(ctx)
+ if err != nil {
+ // TODO should we stay alive here?
+ return nil, err
+ }
+
+ addr, err := addrFromServerAddress(conn.RemoteAddr().String())
if err != nil {
// TODO should we stay alive here?
return nil, err
}
- if !allowed {
+
+ if !sl.auth.allowsHost(addr) {
continue
}
return conn, err
}
}
-type allowedPeers struct {
- self *ipnstate.PeerStatus
- peers []*ipnstate.PeerStatus
- peerByIPAddress map[netip.Addr]*ipnstate.PeerStatus
- clusterTag string
-}
-
-func (ap *allowedPeers) allowed(n *ipnstate.PeerStatus) bool {
- return n.Tags != nil && slices.Contains(n.Tags.AsSlice(), ap.clusterTag)
-}
-
-func (ap *allowedPeers) addPeerIfAllowed(p *ipnstate.PeerStatus) {
- if !ap.allowed(p) {
- return
- }
- ap.peers = append(ap.peers, p)
- for _, addr := range p.TailscaleIPs {
- ap.peerByIPAddress[addr] = p
- }
-}
-
-func (ap *allowedPeers) addSelfIfAllowed(n *ipnstate.PeerStatus) {
- if ap.allowed(n) {
- ap.self = n
- }
-}
-
-func (ap *allowedPeers) has(a netip.Addr) bool {
- _, ok := ap.peerByIPAddress[a]
- return ok
-}
-
-func taggedNodesFromStatus(ctx context.Context, clusterTag string, ts *tsnet.Server) (*allowedPeers, error) {
- lc, err := ts.LocalClient()
- if err != nil {
- return nil, err
- }
- tStatus, err := lc.Status(ctx)
- if err != nil {
- return nil, err
- }
- ap := newAllowedPeers(clusterTag)
- for _, v := range tStatus.Peer {
- ap.addPeerIfAllowed(v)
- }
- ap.addSelfIfAllowed(tStatus.Self)
- return ap, nil
-}
-
-func newAllowedPeers(tag string) *allowedPeers {
- return &allowedPeers{
- peerByIPAddress: map[netip.Addr]*ipnstate.PeerStatus{},
- clusterTag: tag,
- }
-}
-
// Start returns a pointer to a running Consensus instance.
func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) {
if clusterTag == "" {
@@ -211,22 +159,30 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
Config: cfg,
}
- tnfs, err := taggedNodesFromStatus(ctx, clusterTag, ts)
- if tnfs.self == nil {
+ auth := &authorization{
+ tag: clusterTag,
+ ts: ts,
+ }
+ err := auth.refresh(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ if !auth.selfAllowed() {
return nil, errors.New("this node is not tagged with the cluster tag")
}
- r, err := startRaft(ts, &fsm, c.Self, clusterTag, cfg)
+ r, err := startRaft(ts, &fsm, c.Self, auth, cfg)
if err != nil {
return nil, err
}
c.Raft = r
- srv, err := c.serveCmdHttp(ts, clusterTag)
+ srv, err := c.serveCmdHttp(ts, auth)
if err != nil {
return nil, err
}
c.cmdHttpServer = srv
- c.bootstrap(tnfs.peers)
+ c.bootstrap(auth.allowedPeers())
srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
if err != nil {
return nil, err
@@ -235,7 +191,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
return &c, nil
}
-func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag string, cfg Config) (*raft.Raft, error) {
+func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) {
config := cfg.Raft
config.LocalID = raft.ServerID(self.ID)
@@ -251,9 +207,9 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag st
}
transport := raft.NewNetworkTransport(StreamLayer{
- s: ts,
Listener: ln,
- tag: clusterTag,
+ auth: auth,
+ s: ts,
},
cfg.MaxConnPool,
cfg.ConnTimeout,
@@ -386,12 +342,12 @@ func (e lookElsewhereError) Error() string {
var ErrLeaderUnknown = errors.New("Leader Unknown")
-func (c *Consensus) serveCmdHttp(ts *tsnet.Server, tag string) (*http.Server, error) {
+func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.Server, error) {
ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host))
if err != nil {
return nil, err
}
- mux := c.makeCommandMux(ts, tag)
+ mux := c.makeCommandMux(auth)
srv := &http.Server{Handler: mux}
go func() {
err := srv.Serve(ln)