summaryrefslogtreecommitdiffhomepage
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/tsdial/tsdial.go27
-rw-r--r--net/tsdial/tsdial_test.go97
2 files changed, 124 insertions, 0 deletions
diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go
index ebbafa52b..ca08810a3 100644
--- a/net/tsdial/tsdial.go
+++ b/net/tsdial/tsdial.go
@@ -515,6 +515,33 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn,
return stdDialer.DialContext(ctx, network, ipp.String())
}
+// UserDialPlan resolves addr and reports whether the dialer would
+// handle it via Tailscale. If viaTailscale is false, the resolved
+// address is not a Tailscale route and the caller may dial it directly.
+//
+// Warning: there is a TOCTOU race if addr contains a DNS name and the
+// caller subsequently passes the same DNS name to [Dialer.UserDial], as DNS
+// may resolve differently the second time. Callers who want to only
+// dial over Tailscale should call [Dialer.UserDial] with the returned
+// ipp.String() (an IP:port) rather than the original DNS name.
+func (d *Dialer) UserDialPlan(ctx context.Context, network, addr string) (ipp netip.AddrPort, viaTailscale bool, err error) {
+ ipp, err = d.userDialResolve(ctx, network, addr)
+ if err != nil {
+ return netip.AddrPort{}, false, err
+ }
+ if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) {
+ return ipp, true, nil
+ }
+ if routes := d.routes.Load(); routes != nil {
+ isTailscaleRoute, _ := routes.Lookup(ipp.Addr())
+ return ipp, isTailscaleRoute, nil
+ }
+ if version.IsMacGUIVariant() && tsaddr.IsTailscaleIP(ipp.Addr()) {
+ return ipp, true, nil
+ }
+ return ipp, false, nil
+}
+
// dialPeerAPI connects to a Tailscale peer's peerapi over TCP.
//
// network must a "tcp" type, and addr must be an ip:port. Name resolution
diff --git a/net/tsdial/tsdial_test.go b/net/tsdial/tsdial_test.go
new file mode 100644
index 000000000..92960acbe
--- /dev/null
+++ b/net/tsdial/tsdial_test.go
@@ -0,0 +1,97 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "context"
+ "net/netip"
+ "testing"
+
+ "github.com/gaissmai/bart"
+)
+
+func TestUserDialPlan(t *testing.T) {
+ tests := []struct {
+ name string
+ addr string
+ routes map[netip.Prefix]bool // nil means no routes configured
+ useNetstackFor func(netip.Addr) bool // nil means not set
+ wantVia bool
+ wantAddr netip.AddrPort
+ }{
+ {
+ name: "loopback_no_routes",
+ addr: "127.0.0.1:8080",
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("127.0.0.1:8080"),
+ },
+ {
+ name: "loopback_v6_no_routes",
+ addr: "[::1]:8080",
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("[::1]:8080"),
+ },
+ {
+ name: "tailscale_ip_in_routes",
+ addr: "100.64.1.1:22",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ },
+ wantVia: true,
+ wantAddr: netip.MustParseAddrPort("100.64.1.1:22"),
+ },
+ {
+ name: "non_tailscale_ip_in_local_routes",
+ addr: "10.0.0.5:80",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ netip.MustParsePrefix("10.0.0.0/8"): false, // local route
+ },
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("10.0.0.5:80"),
+ },
+ {
+ name: "loopback_with_routes_configured",
+ addr: "127.0.0.1:3000",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ },
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("127.0.0.1:3000"),
+ },
+ {
+ name: "netstack_for_ip",
+ addr: "100.100.100.100:53",
+ useNetstackFor: func(ip netip.Addr) bool {
+ return ip == netip.MustParseAddr("100.100.100.100")
+ },
+ wantVia: true,
+ wantAddr: netip.MustParseAddrPort("100.100.100.100:53"),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &Dialer{}
+ if tt.routes != nil {
+ rt := &bart.Table[bool]{}
+ for pfx, v := range tt.routes {
+ rt.Insert(pfx, v)
+ }
+ d.routes.Store(rt)
+ }
+ d.UseNetstackForIP = tt.useNetstackFor
+
+ ipp, viaTailscale, err := d.UserDialPlan(context.Background(), "tcp", tt.addr)
+ if err != nil {
+ t.Fatalf("UserDialPlan: %v", err)
+ }
+ if viaTailscale != tt.wantVia {
+ t.Errorf("viaTailscale = %v, want %v", viaTailscale, tt.wantVia)
+ }
+ if ipp != tt.wantAddr {
+ t.Errorf("addr = %v, want %v", ipp, tt.wantAddr)
+ }
+ })
+ }
+}