summaryrefslogtreecommitdiffhomepage
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/local/local.go13
-rw-r--r--client/local/local_test.go51
2 files changed, 64 insertions, 0 deletions
diff --git a/client/local/local.go b/client/local/local.go
index e72589306..f9cf96654 100644
--- a/client/local/local.go
+++ b/client/local/local.go
@@ -972,6 +972,19 @@ func (lc *Client) UserDial(ctx context.Context, network, host string, port uint1
if res.StatusCode != http.StatusSwitchingProtocols {
body, _ := io.ReadAll(res.Body)
res.Body.Close()
+ if res.StatusCode == http.StatusOK && res.Header.Get("Dial-Self") == "true" {
+ // Server told us to dial the address ourselves rather than
+ // proxying through the daemon. This happens for non-Tailscale
+ // addresses where the daemon shouldn't dial as root on the
+ // client's behalf. The server provides the resolved address
+ // to avoid a TOCTOU race with DNS re-resolution.
+ addr := res.Header.Get("Dial-Addr")
+ if addr == "" {
+ return nil, errors.New("server returned Dial-Self without Dial-Addr")
+ }
+ var d net.Dialer
+ return d.DialContext(ctx, network, addr)
+ }
return nil, fmt.Errorf("unexpected HTTP response: %s, %s", res.Status, body)
}
// From here on, the underlying net.Conn is ours to use, but there
diff --git a/client/local/local_test.go b/client/local/local_test.go
index a5377fbd6..58a87b224 100644
--- a/client/local/local_test.go
+++ b/client/local/local_test.go
@@ -61,6 +61,57 @@ func TestWhoIsPeerNotFound(t *testing.T) {
}
}
+func TestUserDialSelf(t *testing.T) {
+ // Start a real TCP listener that the client should dial directly
+ // when the server tells it to dial-self.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Write([]byte("hello"))
+ c.Close()
+ }
+ }()
+ targetAddr := ln.Addr().(*net.TCPAddr)
+
+ // Mock LocalAPI server that returns Dial-Self response.
+ nw := nettest.GetNetwork(t)
+ ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Dial-Self", "true")
+ w.Header().Set("Dial-Addr", targetAddr.String())
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer ts.Close()
+
+ lc := &Client{
+ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nw.Dial(ctx, network, ts.Listener.Addr().String())
+ },
+ }
+
+ conn, err := lc.UserDial(context.Background(), "tcp", targetAddr.IP.String(), uint16(targetAddr.Port))
+ if err != nil {
+ t.Fatalf("UserDial: %v", err)
+ }
+ defer conn.Close()
+
+ buf := make([]byte, 5)
+ n, err := conn.Read(buf)
+ if err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+ if got := string(buf[:n]); got != "hello" {
+ t.Errorf("got %q, want %q", got, "hello")
+ }
+}
+
func TestDeps(t *testing.T) {
deptest.DepChecker{
BadDeps: map[string]string{