summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrew Dunham <andrew@du.nham.ca>2024-07-10 16:46:31 -0500
committerAndrew Dunham <andrew@du.nham.ca>2024-08-02 16:05:14 -0400
commit9939374c48aff28ea9bee63a749869312d0954ef (patch)
tree6ff65b4aed28157af6363cf1af3e30265bd9736b
parent4055b63b9b5f4662254cd4a5d926265e9ff7734f (diff)
downloadtailscale-9939374c48aff28ea9bee63a749869312d0954ef.tar.xz
tailscale-9939374c48aff28ea9bee63a749869312d0954ef.zip
wgengine/magicsock: use cloud metadata to get public IPs
Updates #12774 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I1661b6a2da7966ab667b075894837afd96f4742f
-rw-r--r--wgengine/magicsock/cloudinfo.go182
-rw-r--r--wgengine/magicsock/cloudinfo_nocloud.go23
-rw-r--r--wgengine/magicsock/cloudinfo_test.go123
-rw-r--r--wgengine/magicsock/magicsock.go31
-rw-r--r--wgengine/magicsock/magicsock_test.go10
5 files changed, 360 insertions, 9 deletions
diff --git a/wgengine/magicsock/cloudinfo.go b/wgengine/magicsock/cloudinfo.go
new file mode 100644
index 000000000..1de369631
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo.go
@@ -0,0 +1,182 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !(ios || android || js)
+
+package magicsock
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/netip"
+ "slices"
+ "strings"
+ "time"
+
+ "tailscale.com/types/logger"
+ "tailscale.com/util/cloudenv"
+)
+
+const maxCloudInfoWait = 2 * time.Second
+
+type cloudInfo struct {
+ client http.Client
+ logf logger.Logf
+
+ // The following parameters are fixed for the lifetime of the cloudInfo
+ // object, but are used for testing.
+ cloud cloudenv.Cloud
+ endpoint string
+}
+
+func newCloudInfo(logf logger.Logf) *cloudInfo {
+ tr := &http.Transport{
+ DisableKeepAlives: true,
+ Dial: (&net.Dialer{
+ Timeout: maxCloudInfoWait,
+ }).Dial,
+ }
+
+ return &cloudInfo{
+ client: http.Client{Transport: tr},
+ logf: logf,
+ cloud: cloudenv.Get(),
+ endpoint: "http://" + cloudenv.CommonNonRoutableMetadataIP,
+ }
+}
+
+// GetPublicIPs returns any public IPs attached to the current cloud instance,
+// if the tailscaled process is running in a known cloud and there are any such
+// IPs present.
+func (ci *cloudInfo) GetPublicIPs(ctx context.Context) ([]netip.Addr, error) {
+ switch ci.cloud {
+ case cloudenv.AWS:
+ ret, err := ci.getAWS(ctx)
+ ci.logf("[v1] cloudinfo.GetPublicIPs: AWS: %v, %v", ret, err)
+ return ret, err
+ }
+
+ return nil, nil
+}
+
+// getAWSMetadata makes a request to the AWS metadata service at the given
+// path, authenticating with the provided IMDSv2 token. The returned metadata
+// is split by newline and returned as a slice.
+func (ci *cloudInfo) getAWSMetadata(ctx context.Context, token, path string) ([]string, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", ci.endpoint+path, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request to %q: %w", path, err)
+ }
+ req.Header.Set("X-aws-ec2-metadata-token", token)
+
+ resp, err := ci.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("making request to metadata service %q: %w", path, err)
+ }
+ defer resp.Body.Close()
+
+ switch resp.StatusCode {
+ case http.StatusOK:
+ // Good
+ case http.StatusNotFound:
+ // Nothing found, but this isn't an error; just return
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("reading response body for %q: %w", path, err)
+ }
+
+ return strings.Split(strings.TrimSpace(string(body)), "\n"), nil
+}
+
+// getAWS returns all public IPv4 and IPv6 addresses present in the AWS instance metadata.
+func (ci *cloudInfo) getAWS(ctx context.Context) ([]netip.Addr, error) {
+ ctx, cancel := context.WithTimeout(ctx, maxCloudInfoWait)
+ defer cancel()
+
+ // Get a token so we can query the metadata service.
+ req, err := http.NewRequestWithContext(ctx, "PUT", ci.endpoint+"/latest/api/token", nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating token request: %w", err)
+ }
+ req.Header.Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", "10")
+
+ resp, err := ci.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("making token request to metadata service: %w", err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ return nil, fmt.Errorf("reading token response body: %w", err)
+ }
+ token := string(body)
+
+ server := resp.Header.Get("Server")
+ if server != "EC2ws" {
+ return nil, fmt.Errorf("unexpected server header: %q", server)
+ }
+
+ // Iterate over all interfaces and get their public IP addresses, both IPv4 and IPv6.
+ macAddrs, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/")
+ if err != nil {
+ return nil, fmt.Errorf("getting interface MAC addresses: %w", err)
+ }
+
+ var (
+ addrs []netip.Addr
+ errs []error
+ )
+
+ addAddr := func(addr string) {
+ ip, err := netip.ParseAddr(addr)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("parsing IP address %q: %w", addr, err))
+ return
+ }
+ addrs = append(addrs, ip)
+ }
+ for _, mac := range macAddrs {
+ ips, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/public-ipv4s")
+ if err != nil {
+ errs = append(errs, fmt.Errorf("getting IPv4 addresses for %q: %w", mac, err))
+ continue
+ }
+
+ for _, ip := range ips {
+ addAddr(ip)
+ }
+
+ // Try querying for IPv6 addresses.
+ ips, err = ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/ipv6s")
+ if err != nil {
+ errs = append(errs, fmt.Errorf("getting IPv6 addresses for %q: %w", mac, err))
+ continue
+ }
+ for _, ip := range ips {
+ addAddr(ip)
+ }
+ }
+
+ // Sort the returned addresses for determinism.
+ slices.SortFunc(addrs, func(a, b netip.Addr) int {
+ return a.Compare(b)
+ })
+
+ // Preferentially return any addresses we found, even if there were errors.
+ if len(addrs) > 0 {
+ return addrs, nil
+ }
+ if len(errs) > 0 {
+ return nil, fmt.Errorf("getting IP addresses: %w", errors.Join(errs...))
+ }
+ return nil, nil
+}
diff --git a/wgengine/magicsock/cloudinfo_nocloud.go b/wgengine/magicsock/cloudinfo_nocloud.go
new file mode 100644
index 000000000..b4414d318
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo_nocloud.go
@@ -0,0 +1,23 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ios || android || js
+
+package magicsock
+
+import (
+ "context"
+ "net/netip"
+
+ "tailscale.com/types/logger"
+)
+
+type cloudInfo struct{}
+
+func newCloudInfo(_ logger.Logf) *cloudInfo {
+ return &cloudInfo{}
+}
+
+func (ci *cloudInfo) GetPublicIPs(_ context.Context) ([]netip.Addr, error) {
+ return nil, nil
+}
diff --git a/wgengine/magicsock/cloudinfo_test.go b/wgengine/magicsock/cloudinfo_test.go
new file mode 100644
index 000000000..15191aeef
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo_test.go
@@ -0,0 +1,123 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/netip"
+ "slices"
+ "testing"
+
+ "tailscale.com/util/cloudenv"
+)
+
+func TestCloudInfo_AWS(t *testing.T) {
+ const (
+ mac1 = "06:1d:00:00:00:00"
+ mac2 = "06:1d:00:00:00:01"
+ publicV4 = "1.2.3.4"
+ otherV4_1 = "5.6.7.8"
+ otherV4_2 = "11.12.13.14"
+ v6addr = "2001:db8::1"
+
+ macsPrefix = "/latest/meta-data/network/interfaces/macs/"
+ )
+ // Launch a fake AWS IMDS server
+ fake := &fakeIMDS{
+ tb: t,
+ paths: map[string]string{
+ macsPrefix: mac1 + "\n" + mac2,
+ // This is the "main" public IP address for the instance
+ macsPrefix + mac1 + "/public-ipv4s": publicV4,
+
+ // There's another interface with two public IPs
+ // attached to it and an IPv6 address, all of which we
+ // should discover.
+ macsPrefix + mac2 + "/public-ipv4s": otherV4_1 + "\n" + otherV4_2,
+ macsPrefix + mac2 + "/ipv6s": v6addr,
+ },
+ }
+
+ srv := httptest.NewServer(fake)
+ defer srv.Close()
+
+ ci := newCloudInfo(t.Logf)
+ ci.cloud = cloudenv.AWS
+ ci.endpoint = srv.URL
+
+ ips, err := ci.GetPublicIPs(context.Background())
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ wantIPs := []netip.Addr{
+ netip.MustParseAddr(publicV4),
+ netip.MustParseAddr(otherV4_1),
+ netip.MustParseAddr(otherV4_2),
+ netip.MustParseAddr(v6addr),
+ }
+ if !slices.Equal(ips, wantIPs) {
+ t.Fatalf("got %v, want %v", ips, wantIPs)
+ }
+}
+
+func TestCloudInfo_AWSNotPublic(t *testing.T) {
+ returns404 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "PUT" && r.URL.Path == "/latest/api/token" {
+ w.Header().Set("Server", "EC2ws")
+ w.Write([]byte("fake-imds-token"))
+ return
+ }
+ http.NotFound(w, r)
+ })
+ srv := httptest.NewServer(returns404)
+ defer srv.Close()
+
+ ci := newCloudInfo(t.Logf)
+ ci.cloud = cloudenv.AWS
+ ci.endpoint = srv.URL
+
+ // If the IMDS server doesn't return any public IPs, it's not an error
+ // and we should just get an empty list.
+ ips, err := ci.GetPublicIPs(context.Background())
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(ips) != 0 {
+ t.Fatalf("got %v, want none", ips)
+ }
+}
+
+type fakeIMDS struct {
+ tb testing.TB
+ paths map[string]string
+}
+
+func (f *fakeIMDS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ f.tb.Logf("%s %s", r.Method, r.URL.Path)
+ path := r.URL.Path
+
+ // Handle the /latest/api/token case
+ const token = "fake-imds-token"
+ if r.Method == "PUT" && path == "/latest/api/token" {
+ w.Header().Set("Server", "EC2ws")
+ w.Write([]byte(token))
+ return
+ }
+
+ // Otherwise, require the IMDSv2 token to be set
+ if r.Header.Get("X-aws-ec2-metadata-token") != token {
+ f.tb.Errorf("missing or invalid IMDSv2 token")
+ http.Error(w, "missing or invalid IMDSv2 token", http.StatusForbidden)
+ return
+ }
+
+ if v, ok := f.paths[path]; ok {
+ w.Write([]byte(v))
+ return
+ }
+ http.NotFound(w, r)
+}
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index cd7fb23da..5ac53c771 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -133,6 +133,9 @@ type Conn struct {
// bind is the wireguard-go conn.Bind for Conn.
bind *connBind
+ // cloudInfo is used to query cloud metadata services.
+ cloudInfo *cloudInfo
+
// ============================================================
// Fields that must be accessed via atomic load/stores.
@@ -425,9 +428,10 @@ func (o *Options) derpActiveFunc() func() {
// newConn is the error-free, network-listening-side-effect-free based
// of NewConn. Mostly for tests.
-func newConn() *Conn {
+func newConn(logf logger.Logf) *Conn {
discoPrivate := key.NewDisco()
c := &Conn{
+ logf: logf,
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
derpStarted: make(chan struct{}),
peerLastDerp: make(map[key.NodePublic]int),
@@ -435,6 +439,7 @@ func newConn() *Conn {
discoInfo: make(map[key.DiscoPublic]*discoInfo),
discoPrivate: discoPrivate,
discoPublic: discoPrivate.Public(),
+ cloudInfo: newCloudInfo(logf),
}
c.discoShort = c.discoPublic.ShortString()
c.bind = &connBind{Conn: c, closed: true}
@@ -462,10 +467,9 @@ func NewConn(opts Options) (*Conn, error) {
return nil, errors.New("magicsock.Options.NetMon must be non-nil")
}
- c := newConn()
+ c := newConn(opts.logf())
c.port.Store(uint32(opts.Port))
c.controlKnobs = opts.ControlKnobs
- c.logf = opts.logf()
c.epFunc = opts.endpointsFunc()
c.derpActiveFunc = opts.derpActiveFunc()
c.idleFunc = opts.IdleFunc
@@ -952,6 +956,27 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
addAddr(ap, tailcfg.EndpointExplicitConf)
}
+ // If we're on a cloud instance, we might have a public IPv4 or IPv6
+ // address that we can be reached at. Find those, if they exist, and
+ // add them.
+ if addrs, err := c.cloudInfo.GetPublicIPs(ctx); err == nil {
+ var port4, port6 uint16
+ if addr := c.pconn4.LocalAddr(); addr != nil {
+ port4 = uint16(addr.Port)
+ }
+ if addr := c.pconn6.LocalAddr(); addr != nil {
+ port6 = uint16(addr.Port)
+ }
+
+ for _, addr := range addrs {
+ if addr.Is4() && port4 > 0 {
+ addAddr(netip.AddrPortFrom(addr, port4), tailcfg.EndpointLocal)
+ } else if addr.Is6() && port6 > 0 {
+ addAddr(netip.AddrPortFrom(addr, port6), tailcfg.EndpointLocal)
+ }
+ }
+ }
+
// Update our set of endpoints by adding any endpoints that we
// previously found but haven't expired yet. This also updates the
// cache with the set of endpoints discovered in this function.
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index cec05dffc..a721c24e4 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -452,7 +452,7 @@ func TestPickDERPFallback(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
- c := newConn()
+ c := newConn(t.Logf)
dm := &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {},
@@ -483,7 +483,7 @@ func TestPickDERPFallback(t *testing.T) {
// distribution over nodes works.
got := map[int]int{}
for range 50 {
- c = newConn()
+ c = newConn(t.Logf)
c.derpMap = dm
got[c.pickDERPFallback()]++
}
@@ -1185,8 +1185,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
}
func TestDiscoMessage(t *testing.T) {
- c := newConn()
- c.logf = t.Logf
+ c := newConn(t.Logf)
c.privateKey = key.NewNode()
peer1Pub := c.DiscoPublicKey()
@@ -3161,8 +3160,7 @@ func TestMaybeSetNearestDERP(t *testing.T) {
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ht := new(health.Tracker)
- c := newConn()
- c.logf = t.Logf
+ c := newConn(t.Logf)
c.myDerp = tt.old
c.derpMap = derpMap
c.health = ht