diff options
| -rw-r--r-- | feature/conn25/conn25.go | 48 | ||||
| -rw-r--r-- | feature/conn25/conn25_test.go | 2 |
2 files changed, 25 insertions, 25 deletions
diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index e5db9619b..2afc06052 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -627,6 +627,7 @@ type client struct { v6MagicIPPool *ippool v6TransitIPPool *ippool assignments addrAssignments + byConnKey map[key.NodePublic]set.Set[netip.Prefix] config config } @@ -775,7 +776,7 @@ func (c *client) addTransitIPForConnector(tip netip.Addr, conn tailcfg.NodeView) c.mu.Lock() defer c.mu.Unlock() - return c.assignments.insertTransitConnMapping(tip, conn.Key()) + return c.insertTransitConnMapping(tip, conn.Key()) } func (e *extension) sendLoop(ctx context.Context) { @@ -820,7 +821,7 @@ func (c *client) enqueueAddressAssignment(addrs addrs) error { func (c *client) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { c.mu.Lock() defer c.mu.Unlock() - tips, ok := c.assignments.lookupTransitIPsByConnKey(k) + tips, ok := c.lookupTransitIPsByConnKey(k) if !ok { return views.Slice[netip.Prefix]{} } @@ -1186,7 +1187,6 @@ type addrAssignments struct { byMagicIP map[netip.Addr]addrs byTransitIP map[netip.Addr]addrs byDomainDst map[domainDst]addrs - byConnKey map[key.NodePublic]set.Set[netip.Prefix] } func (a *addrAssignments) insert(as addrs) error { @@ -1209,15 +1209,30 @@ func (a *addrAssignments) insert(as addrs) error { return nil } +func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { + v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] + return v, ok +} + +func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (addrs, bool) { + v, ok := a.byMagicIP[mip] + return v, ok +} + +func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { + v, ok := a.byTransitIP[tip] + return v, ok +} + // insertTransitConnMapping adds an entry to the byConnKey map // for the provided transitIP (as a prefix). // The provided transitIP must already be present in the byTransitIP map. -func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { - if _, ok := a.lookupByTransitIP(tip); !ok { +func (c *client) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { + if _, ok := c.assignments.lookupByTransitIP(tip); !ok { return errors.New("transit IP is not already known") } - ctips, ok := a.byConnKey[connKey] + ctips, ok := c.byConnKey[connKey] tipp := netip.PrefixFrom(tip, tip.BitLen()) if ok { if ctips.Contains(tipp) { @@ -1225,32 +1240,17 @@ func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.N } } else { ctips.Make() - mak.Set(&a.byConnKey, connKey, ctips) + mak.Set(&c.byConnKey, connKey, ctips) } ctips.Add(tipp) return nil } -func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { - v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] - return v, ok -} - -func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (addrs, bool) { - v, ok := a.byMagicIP[mip] - return v, ok -} - -func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { - v, ok := a.byTransitIP[tip] - return v, ok -} - // lookupTransitIPsByConnKey returns a slice containing the transit IPs (as netipPrefix) // associated with the given connector (identified by node key), or (nil, false) if there is no entry // for the given key. -func (a *addrAssignments) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { - s, ok := a.byConnKey[k] +func (c *client) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { + s, ok := c.byConnKey[k] if !ok { return nil, false } diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 8e829724b..1784ccb68 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -1665,7 +1665,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { // Check that each of the lookups behaves as expected for i, lu := range tt.lookups { - got, ok := ext.conn25.client.assignments.lookupTransitIPsByConnKey(lu.connKey) + got, ok := ext.conn25.client.lookupTransitIPsByConnKey(lu.connKey) if ok != lu.expectedOk { t.Fatalf("unexpected ok result at index %d wanted %v, got %v", i, lu.expectedOk, ok) } |
