diff options
Diffstat (limited to 'tstest/integration/testcontrol/testcontrol.go')
| -rw-r--r-- | tstest/integration/testcontrol/testcontrol.go | 77 |
1 files changed, 43 insertions, 34 deletions
diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 52b96fe4d..71205f897 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -55,6 +55,10 @@ type Server struct { MagicDNSDomain string HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + // AllNodesSameUser, if true, makes all created nodes + // belong to the same user. + AllNodesSameUser bool + // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL @@ -96,9 +100,9 @@ type Server struct { logins map[key.NodePublic]*tailcfg.Login updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath - nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + nodeKeyAuthed set.Set[key.NodePublic] + msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. } // BaseURL returns the server's base URL, without trailing slash. @@ -522,6 +526,10 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) return u, s.logins[nodeKey] } id := tailcfg.UserID(len(s.users) + 1) + if s.AllNodesSameUser { + id = 123 + } + s.logf("Created user %v for node %s", id, nodeKey) loginName := fmt.Sprintf("user-%d@%s", id, domain) displayName := fmt.Sprintf("User %d", id) login := &tailcfg.Login{ @@ -582,10 +590,8 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool { if ap.nodeKey.IsZero() { panic("zero AuthPath.NodeKey") } - if s.nodeKeyAuthed == nil { - s.nodeKeyAuthed = map[key.NodePublic]bool{} - } - s.nodeKeyAuthed[ap.nodeKey] = true + s.nodeKeyAuthed.Make() + s.nodeKeyAuthed.Add(ap.nodeKey) ap.CompleteSuccessfully() return true } @@ -645,36 +651,40 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. if s.nodes == nil { s.nodes = map[key.NodePublic]*tailcfg.Node{} } - + _, ok := s.nodes[nk] machineAuthorized := true // TODO: add Server.RequireMachineAuth + if !ok { - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) - v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - - allowedIPs := []netip.Prefix{ - v4Prefix, - v6Prefix, - } + nodeID := len(s.nodes) + 1 + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(nodeID>>8), uint8(nodeID)), 32) + v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - s.nodes[nk] = &tailcfg.Node{ - ID: tailcfg.NodeID(user.ID), - StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), - User: user.ID, - Machine: mkey, - Key: req.NodeKey, - MachineAuthorized: machineAuthorized, - Addresses: allowedIPs, - AllowedIPs: allowedIPs, - Hostinfo: req.Hostinfo.View(), - Name: req.Hostinfo.Hostname, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityHTTPS, - tailcfg.NodeAttrFunnel, - tailcfg.CapabilityFunnelPorts + "?ports=8080,443", - }, + allowedIPs := []netip.Prefix{ + v4Prefix, + v6Prefix, + } + node := &tailcfg.Node{ + ID: tailcfg.NodeID(nodeID), + StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(nodeID))), + User: user.ID, + Machine: mkey, + Key: req.NodeKey, + MachineAuthorized: machineAuthorized, + Addresses: allowedIPs, + AllowedIPs: allowedIPs, + Hostinfo: req.Hostinfo.View(), + Name: req.Hostinfo.Hostname, + Capabilities: []tailcfg.NodeCapability{ + tailcfg.CapabilityHTTPS, + tailcfg.NodeAttrFunnel, + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityFunnelPorts + "?ports=8080,443", + }, + } + s.nodes[nk] = node } requireAuth := s.RequireAuth - if requireAuth && s.nodeKeyAuthed[nk] { + if requireAuth && s.nodeKeyAuthed.Contains(nk) { requireAuth = false } allExpired := s.allExpired @@ -951,7 +961,6 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, node.CapMap = nodeCapMap node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) - user, _ := s.getUser(nk) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) dns := s.DNSConfig if dns != nil && s.MagicDNSDomain != "" { @@ -1013,7 +1022,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, }) res.UserProfiles = s.allUserProfiles() - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(node.ID>>8), uint8(node.ID)), 32) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) res.Node.Addresses = []netip.Prefix{ |
