diff options
| -rw-r--r-- | cmd/tailscale/cli/cli.go | 1 | ||||
| -rw-r--r-- | cmd/tailscale/cli/rsh.go | 450 | ||||
| -rw-r--r-- | cmd/tailscale/cli/rsh_test.go | 142 | ||||
| -rw-r--r-- | feature/condregister/maybe_rsh.go | 8 | ||||
| -rw-r--r-- | feature/rsh/checkmode_test.go | 286 | ||||
| -rw-r--r-- | feature/rsh/localapi.go | 156 | ||||
| -rw-r--r-- | feature/rsh/policy.go | 152 | ||||
| -rw-r--r-- | feature/rsh/policy_test.go | 257 | ||||
| -rw-r--r-- | feature/rsh/protocol.go | 157 | ||||
| -rw-r--r-- | feature/rsh/protocol_test.go | 188 | ||||
| -rw-r--r-- | feature/rsh/rsh.go | 739 | ||||
| -rw-r--r-- | feature/rsh/rsh_netstack.go | 30 | ||||
| -rw-r--r-- | ipn/ipnlocal/local.go | 7 |
13 files changed, 2573 insertions, 0 deletions
diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index fda6b4546..a04807a47 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -259,6 +259,7 @@ change in the future. pingCmd, ncCmd, sshCmd, + rshCmd, nilOrCall(maybeFunnelCmd), nilOrCall(maybeServeCmd), versionCmd, diff --git a/cmd/tailscale/cli/rsh.go b/cmd/tailscale/cli/rsh.go new file mode 100644 index 000000000..7b352b52c --- /dev/null +++ b/cmd/tailscale/cli/rsh.go @@ -0,0 +1,450 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/tailscale/apitype" +) + +var rshArgs struct { + loginUser string // -l flag: SSH login user + sshOption string // -o flag: SSH option (ignored, for compatibility) +} + +var rshCmd = &ffcli.Command{ + Name: "rsh", + ShortUsage: "tailscale rsh [-l user] [user@]<host> [command...]", + ShortHelp: "Execute a remote command over Tailscale without SSH overhead", + LongHelp: strings.TrimSpace(` +The 'tailscale rsh' command executes a command on a remote Tailscale node +using a direct TCP connection over the Tailscale network. Unlike SSH, it +avoids double encryption (SSH + WireGuard) and SSH's suboptimal buffering. + +It is designed to be used as an rsync -e transport replacement: + + rsync -e 'tailscale rsh' -avz ./local/ user@host:/remote/ + +The remote node must have Tailscale SSH enabled, as rsh reuses the same +SSH access policy for authorization. + +SSH-compatible flags (-l user, -o option) are accepted and handled +appropriately so that rsync and similar tools can invoke rsh as a +drop-in remote shell replacement. + +When used without a command, it starts the user's default login shell. +`), + FlagSet: func() *flag.FlagSet { + fs := newFlagSet("rsh") + fs.StringVar(&rshArgs.loginUser, "l", "", "remote login user (SSH-compatible)") + fs.StringVar(&rshArgs.sshOption, "o", "", "SSH option (ignored, for compatibility)") + return fs + }(), + Exec: runRsh, +} + +// rshFraming constants matching feature/rsh/protocol.go. +const ( + rshChanStdin byte = 0x00 + rshChanStdout byte = 0x01 + rshChanStderr byte = 0x02 + rshChanExit byte = 0x03 + rshTokenLen = 32 + rshMaxFrame = 256 * 1024 + rshFrameHdrSize = 5 +) + +func runRsh(ctx context.Context, args []string) error { + if len(args) == 0 { + return errors.New("usage: tailscale rsh [user@]<host> [command...]") + } + + // Check tailscaled is running. + st, err := localClient.Status(ctx) + if err != nil { + return fixTailscaledConnectError(err) + } + description, ok := isRunningOrStarting(st) + if !ok { + printf("%s\n", description) + os.Exit(1) + } + + username, host, cmdArgs, err := parseRshArgs(args) + if err != nil { + return err + } + + // The -l flag parsed by ffcli takes priority over user@host. + // This handles cases like: tailscale rsh -l ubuntu james-ai + // where ffcli parses -l before runRsh sees the args. + if rshArgs.loginUser != "" { + username = rshArgs.loginUser + } + + // If no explicit user, default to the current OS user. + if username == "" { + u, err := currentUser() + if err != nil { + return fmt.Errorf("cannot determine current user: %w", err) + } + username = u + } + + // Resolve host to a peer. + ps, ok := peerStatusFromArg(st, host) + if !ok { + return fmt.Errorf("unknown host %q; not found in Tailscale network", host) + } + + // Build the command string (rsync passes it as separate args). + command := strings.Join(cmdArgs, " ") + + // Request an rsh session via the LocalAPI. + type localRshRequest struct { + PeerID string `json:"peer"` + User string `json:"user"` + Command string `json:"command,omitempty"` + } + reqBody := localRshRequest{ + PeerID: string(ps.ID), + User: username, + Command: command, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", + "http://"+apitype.LocalAPIHost+"/localapi/v0/rsh", + bytes.NewReader(bodyBytes)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := localClient.DoLocalRequest(req) + if err != nil { + return fmt.Errorf("rsh setup: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return fmt.Errorf("rsh setup failed: %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + + type rshResponse struct { + Addr string `json:"addr"` + Token string `json:"token"` + } + type rshStatusMessage struct { + Status string `json:"status"` + } + + var rshResp rshResponse + ct := resp.Header.Get("Content-Type") + if strings.HasPrefix(ct, "application/x-ndjson") { + // Streaming check mode: read newline-delimited JSON lines. + // Status messages go to stderr, the final rshResponse has addr+token. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 64*1024) + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + // Try to decode as rshResponse (has "addr" field). + var candidate rshResponse + if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" { + rshResp = candidate + continue + } + // Otherwise, treat as a status message. + var msg rshStatusMessage + if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" { + fmt.Fprintf(os.Stderr, "rsh: %s\n", msg.Status) + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("rsh: reading streaming response: %w", err) + } + } else { + // Simple JSON response. + if err := json.NewDecoder(resp.Body).Decode(&rshResp); err != nil { + return fmt.Errorf("rsh: invalid response: %w", err) + } + } + resp.Body.Close() + + if rshResp.Addr == "" || rshResp.Token == "" { + return errors.New("rsh: server returned empty address or token") + } + + // Parse the address to get host and port for DialTCP. + addrHost, portStr, err := splitHostPort(rshResp.Addr) + if err != nil { + return fmt.Errorf("rsh: invalid address %q: %w", rshResp.Addr, err) + } + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return fmt.Errorf("rsh: invalid port %q: %w", portStr, err) + } + + // Decode the token. + token, err := hex.DecodeString(rshResp.Token) + if err != nil || len(token) != rshTokenLen { + return fmt.Errorf("rsh: invalid token") + } + + // Connect to the data channel via tailscaled. + conn, err := localClient.DialTCP(ctx, addrHost, uint16(port)) + if err != nil { + return fmt.Errorf("rsh: connect to %s: %w", rshResp.Addr, err) + } + defer conn.Close() + + // Send the authentication token. + if _, err := conn.Write(token); err != nil { + return fmt.Errorf("rsh: send token: %w", err) + } + + // Run the framing protocol. + return rshPumpIO(conn) +} + +// rshPumpIO handles the framing protocol between the local stdin/stdout/stderr +// and the remote process over the connection. +func rshPumpIO(conn io.ReadWriteCloser) error { + // Goroutine: read stdin and send as ChanStdin frames. + stdinDone := make(chan struct{}) + go func() { + defer close(stdinDone) + buf := make([]byte, 64*1024) + for { + n, err := os.Stdin.Read(buf) + if n > 0 { + if werr := writeFrame(conn, rshChanStdin, buf[:n]); werr != nil { + return + } + } + if err != nil { + // Send a zero-length stdin frame to signal EOF. + writeFrame(conn, rshChanStdin, nil) + return + } + } + }() + + // Main loop: read frames from the connection and dispatch. + var hdr [rshFrameHdrSize]byte + for { + if _, err := io.ReadFull(conn, hdr[:]); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + // Connection closed without exit code. + return fmt.Errorf("rsh: connection closed unexpectedly") + } + return fmt.Errorf("rsh: read frame: %w", err) + } + ch := hdr[0] + n := binary.BigEndian.Uint32(hdr[1:]) + if n > rshMaxFrame { + return fmt.Errorf("rsh: frame too large: %d", n) + } + + switch ch { + case rshChanStdout: + if _, err := io.CopyN(os.Stdout, conn, int64(n)); err != nil { + return fmt.Errorf("rsh: stdout: %w", err) + } + case rshChanStderr: + if _, err := io.CopyN(os.Stderr, conn, int64(n)); err != nil { + return fmt.Errorf("rsh: stderr: %w", err) + } + case rshChanExit: + if n != 4 { + return fmt.Errorf("rsh: invalid exit frame size: %d", n) + } + var exitBuf [4]byte + if _, err := io.ReadFull(conn, exitBuf[:]); err != nil { + return fmt.Errorf("rsh: read exit code: %w", err) + } + code := int(binary.BigEndian.Uint32(exitBuf[:])) + if code != 0 { + os.Exit(code) + } + return nil + default: + // Unknown channel, skip the payload. + if _, err := io.CopyN(io.Discard, conn, int64(n)); err != nil { + return fmt.Errorf("rsh: skip unknown frame: %w", err) + } + } + } +} + +// writeFrame writes a single rsh protocol frame to w. +func writeFrame(w io.Writer, ch byte, data []byte) error { + var hdr [rshFrameHdrSize]byte + hdr[0] = ch + binary.BigEndian.PutUint32(hdr[1:], uint32(len(data))) + if _, err := w.Write(hdr[:]); err != nil { + return err + } + if len(data) > 0 { + if _, err := w.Write(data); err != nil { + return err + } + } + return nil +} + +// parseRshArgs parses SSH-compatible arguments as passed by rsync and +// similar tools when using rsh as a remote shell transport. +// +// rsync invokes the remote shell as: +// +// tailscale rsh [user@host] [-l user] [-o option]... <host> <command...> +// +// The user@host may appear as the first positional arg (from the rsync +// URI), while -l overrides the username. The bare hostname after flags +// is the actual target. Everything after that is the remote command. +// +// Returns the resolved username (may be empty if none specified), host, +// and command args. +func parseRshArgs(args []string) (username, host string, cmdArgs []string, err error) { + if len(args) == 0 { + return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]") + } + + // First, check if args[0] is a user@host or bare host (not a flag). + // rsync passes the user@host from the rsync URI as the first arg, + // before any -l flag. + i := 0 + if !strings.HasPrefix(args[0], "-") { + u, h, hasAt := strings.Cut(args[0], "@") + if hasAt { + username = u + host = h + } else { + // Bare hostname (no @). Record it; it may be + // overridden if a second bare hostname appears + // after flags (the rsync pattern). + host = args[0] + } + i = 1 + } + + // Parse SSH-compatible flags. + flagUser := "" + hadFlags := false + for i < len(args) { + a := args[i] + if a == "--" { + i++ + break + } + if !strings.HasPrefix(a, "-") { + break // first non-flag is the host + } + hadFlags = true + switch { + case a == "-l": + // -l <user> + i++ + if i >= len(args) { + return "", "", nil, errors.New("rsh: -l requires an argument") + } + flagUser = args[i] + i++ + case strings.HasPrefix(a, "-l"): + // -l<user> (no space) + flagUser = a[2:] + i++ + case a == "-o": + // -o <option>: SSH option, ignore. + i++ + if i < len(args) { + i++ // skip the option value + } + case strings.HasPrefix(a, "-o"): + // -o<option>: SSH option, ignore. + i++ + default: + // Unknown flag (e.g. -4, -6, -p, etc.); skip it. + // SSH has many flags; we silently ignore ones we + // don't understand since we don't need them. + i++ + } + } + + // After flags, the next non-flag arg is the host. rsync passes + // the bare hostname after -l flags, so we expect it here when + // flags were present. When there were no flags and we already + // have a host from args[0], the remaining args are the command. + if hadFlags && i < len(args) && !strings.HasPrefix(args[i], "-") { + host = args[i] + i++ + } + + // Everything remaining is the command. + cmdArgs = args[i:] + + // -l flag overrides any user from user@host. + if flagUser != "" { + username = flagUser + } + + if host == "" { + return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]") + } + + return username, host, cmdArgs, nil +} + +// splitHostPort splits a host:port string. Unlike net.SplitHostPort, +// it handles bare IPv4 addresses with port (100.1.2.3:1234) as well +// as [IPv6]:port format. +func splitHostPort(addr string) (host, port string, err error) { + // Handle IPv6 [::]:port format. + if strings.HasPrefix(addr, "[") { + end := strings.Index(addr, "]:") + if end < 0 { + return "", "", fmt.Errorf("invalid address: %s", addr) + } + return addr[1:end], addr[end+2:], nil + } + // Handle IPv4 host:port. + i := strings.LastIndex(addr, ":") + if i < 0 { + return "", "", fmt.Errorf("no port in address: %s", addr) + } + return addr[:i], addr[i+1:], nil +} + +// currentUser returns the current OS username. +func currentUser() (string, error) { + // os/user.Current() can fail in some environments (static builds, etc). + // Try it first, fall back to env vars. + if u := os.Getenv("USER"); u != "" { + return u, nil + } + return "", errors.New("cannot determine current user") +} diff --git a/cmd/tailscale/cli/rsh_test.go b/cmd/tailscale/cli/rsh_test.go new file mode 100644 index 000000000..964a59e3b --- /dev/null +++ b/cmd/tailscale/cli/rsh_test.go @@ -0,0 +1,142 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "strings" + "testing" +) + +func TestParseRshArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantUser string + wantHost string + wantCmd string // joined with space + wantErr bool + }{ + { + name: "simple_user_at_host_with_command", + args: []string{"alice@myhost", "ls", "-la"}, + wantUser: "alice", + wantHost: "myhost", + wantCmd: "ls -la", + }, + { + name: "bare_host_with_command", + args: []string{"myhost", "ls"}, + wantUser: "", + wantHost: "myhost", + wantCmd: "ls", + }, + { + name: "bare_host_no_command", + args: []string{"myhost"}, + wantUser: "", + wantHost: "myhost", + wantCmd: "", + }, + { + // This is the exact pattern from rsync: + // tailscale rsh ubuntu@james-ai -l ubuntu james-ai rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/ + name: "rsync_pattern", + args: []string{"ubuntu@james-ai", "-l", "ubuntu", "james-ai", "rsync", "--server", "--sender", "-vlogDtpre.iLsfxCIvu", ".", "ai/"}, + wantUser: "ubuntu", + wantHost: "james-ai", + wantCmd: "rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/", + }, + { + name: "l_flag_overrides_user_at_host", + args: []string{"alice@myhost", "-l", "bob", "myhost", "echo", "hi"}, + wantUser: "bob", + wantHost: "myhost", + wantCmd: "echo hi", + }, + { + name: "l_flag_no_space", + args: []string{"myhost", "-lubuntu", "myhost", "ls"}, + wantUser: "ubuntu", + wantHost: "myhost", + wantCmd: "ls", + }, + { + name: "l_flag_without_user_at_host", + args: []string{"-l", "ubuntu", "myhost", "rsync", "--server"}, + wantUser: "ubuntu", + wantHost: "myhost", + wantCmd: "rsync --server", + }, + { + name: "o_flag_ignored", + args: []string{"alice@myhost", "-o", "StrictHostKeyChecking=no", "myhost", "ls"}, + wantUser: "alice", + wantHost: "myhost", + wantCmd: "ls", + }, + { + name: "o_flag_no_space_ignored", + args: []string{"alice@myhost", "-oStrictHostKeyChecking=no", "myhost", "ls"}, + wantUser: "alice", + wantHost: "myhost", + wantCmd: "ls", + }, + { + name: "multiple_flags", + args: []string{"alice@myhost", "-o", "BatchMode=yes", "-l", "root", "myhost", "uptime"}, + wantUser: "root", + wantHost: "myhost", + wantCmd: "uptime", + }, + { + name: "unknown_flags_skipped", + args: []string{"alice@myhost", "-4", "-p", "myhost", "ls"}, + wantUser: "alice", + wantHost: "myhost", + wantCmd: "ls", + }, + { + name: "double_dash_separator", + args: []string{"myhost", "--", "-l", "this-is-command"}, + wantUser: "", + wantHost: "myhost", + wantCmd: "-l this-is-command", + }, + { + name: "empty_args", + args: []string{}, + wantErr: true, + }, + { + name: "l_flag_missing_value", + args: []string{"myhost", "-l"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, host, cmdArgs, err := parseRshArgs(tt.args) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user != tt.wantUser { + t.Errorf("user = %q, want %q", user, tt.wantUser) + } + if host != tt.wantHost { + t.Errorf("host = %q, want %q", host, tt.wantHost) + } + gotCmd := strings.Join(cmdArgs, " ") + if gotCmd != tt.wantCmd { + t.Errorf("command = %q, want %q", gotCmd, tt.wantCmd) + } + }) + } +} diff --git a/feature/condregister/maybe_rsh.go b/feature/condregister/maybe_rsh.go new file mode 100644 index 000000000..b2745d068 --- /dev/null +++ b/feature/condregister/maybe_rsh.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd) && !ts_omit_rsh + +package condregister + +import _ "tailscale.com/feature/rsh" diff --git a/feature/rsh/checkmode_test.go b/feature/rsh/checkmode_test.go new file mode 100644 index 000000000..86e53c21b --- /dev/null +++ b/feature/rsh/checkmode_test.go @@ -0,0 +1,286 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "bufio" + "bytes" + "encoding/json" + "net/netip" + "os/user" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/netmap" +) + +func TestExpandDelegateURL(t *testing.T) { + nm := &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + ID: 42, + StableID: "self-stable", + Key: key.NewNode().Public(), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + }, + }).View(), + } + + peerNode := (&tailcfg.Node{ + ID: 99, + StableID: "peer-stable", + Key: key.NewNode().Public(), + }).View() + + peerAddr := netip.MustParseAddr("100.64.1.2") + lu := &user.User{Username: "localice"} + + tests := []struct { + name string + url string + want string + }{ + { + name: "all_variables", + url: "https://control.example.com/check?src=$SRC_NODE_IP&srcid=$SRC_NODE_ID&dst=$DST_NODE_IP&dstid=$DST_NODE_ID&sshuser=$SSH_USER&local=$LOCAL_USER", + want: "https://control.example.com/check?src=100.64.1.2&srcid=99&dst=100.64.0.1&dstid=42&sshuser=alice&local=localice", + }, + { + name: "no_variables", + url: "https://control.example.com/check?static=true", + want: "https://control.example.com/check?static=true", + }, + { + name: "url_encoding", + url: "https://control.example.com/check?user=$SSH_USER", + want: "https://control.example.com/check?user=alice%40example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sshUser := "alice" + if tt.name == "url_encoding" { + sshUser = "alice@example.com" + } + got := expandDelegateURL(tt.url, nm, peerNode, peerAddr, sshUser, lu) + if got != tt.want { + t.Errorf("expandDelegateURL() =\n %s\nwant:\n %s", got, tt.want) + } + }) + } +} + +func TestWriteNDJSON(t *testing.T) { + var buf bytes.Buffer + + // Write a status message. + writeNDJSON(&buf, nil, rshStatusMessage{Status: "waiting"}) + + // Write an rshResponse. + writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.1:1234", Token: "abcd"}) + + // Verify output is two newline-delimited JSON lines. + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + if len(lines) != 2 { + t.Fatalf("got %d lines, want 2:\n%s", len(lines), buf.String()) + } + + // Verify first line is a status message. + var msg rshStatusMessage + if err := json.Unmarshal([]byte(lines[0]), &msg); err != nil { + t.Fatalf("unmarshal line 0: %v", err) + } + if msg.Status != "waiting" { + t.Errorf("status = %q, want %q", msg.Status, "waiting") + } + + // Verify second line is a response. + var resp rshResponse + if err := json.Unmarshal([]byte(lines[1]), &resp); err != nil { + t.Fatalf("unmarshal line 1: %v", err) + } + if resp.Addr != "100.64.0.1:1234" { + t.Errorf("addr = %q, want %q", resp.Addr, "100.64.0.1:1234") + } + if resp.Token != "abcd" { + t.Errorf("token = %q, want %q", resp.Token, "abcd") + } +} + +func TestNDJSONStreamParsing(t *testing.T) { + // Simulate a streaming NDJSON response as the CLI would see it. + var buf bytes.Buffer + writeNDJSON(&buf, nil, rshStatusMessage{Status: "Checking with control plane..."}) + writeNDJSON(&buf, nil, rshStatusMessage{Status: "Waiting for approval..."}) + writeNDJSON(&buf, nil, rshStatusMessage{Status: "Access approved"}) + writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.5:4567", Token: "deadbeef"}) + + // Parse like the CLI does. + scanner := bufio.NewScanner(&buf) + var statusMessages []string + var finalResp rshResponse + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var candidate rshResponse + if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" { + finalResp = candidate + continue + } + var msg rshStatusMessage + if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" { + statusMessages = append(statusMessages, msg.Status) + } + } + + if len(statusMessages) != 3 { + t.Fatalf("got %d status messages, want 3", len(statusMessages)) + } + if statusMessages[0] != "Checking with control plane..." { + t.Errorf("status[0] = %q, want %q", statusMessages[0], "Checking with control plane...") + } + if statusMessages[2] != "Access approved" { + t.Errorf("status[2] = %q, want %q", statusMessages[2], "Access approved") + } + if finalResp.Addr != "100.64.0.5:4567" { + t.Errorf("addr = %q, want %q", finalResp.Addr, "100.64.0.5:4567") + } + if finalResp.Token != "deadbeef" { + t.Errorf("token = %q, want %q", finalResp.Token, "deadbeef") + } +} + +func TestEvalSSHPolicyHoldAndDelegate(t *testing.T) { + now := timeVal(2025, 1, 1) + + node := (&tailcfg.Node{ + ID: 1, + StableID: "stable1", + Key: key.NewNode().Public(), + }).View() + + uprof := tailcfg.UserProfile{ + LoginName: "alice@example.com", + } + + srcAddr := netip.MustParseAddr("100.64.1.2") + + tests := []struct { + name string + pol *tailcfg.SSHPolicy + sshUser string + wantResult evalResult + wantUser string + wantURL string + }{ + { + name: "hold_and_delegate_with_message", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{ + HoldAndDelegate: "https://control.example.com/approve?user=$SSH_USER", + Message: "Please approve in the admin panel", + }, + }, + }, + }, + sshUser: "alice", + wantResult: evalHoldDelegate, + wantUser: "alice", + wantURL: "https://control.example.com/approve?user=$SSH_USER", + }, + { + name: "hold_with_specific_user_mapping", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}}, + SSHUsers: map[string]string{"root": "admin"}, + Action: &tailcfg.SSHAction{ + HoldAndDelegate: "https://control.example.com/check", + }, + }, + }, + }, + sshUser: "root", + wantResult: evalHoldDelegate, + wantUser: "admin", + wantURL: "https://control.example.com/check", + }, + { + name: "hold_rejects_unmapped_user", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"root": "admin"}, + Action: &tailcfg.SSHAction{ + HoldAndDelegate: "https://control.example.com/check", + }, + }, + }, + }, + sshUser: "unknown", + wantResult: evalRejectedUser, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now) + if result != tt.wantResult { + t.Errorf("result = %v, want %v", result, tt.wantResult) + } + if tt.wantUser != "" && localUser != tt.wantUser { + t.Errorf("localUser = %q, want %q", localUser, tt.wantUser) + } + if tt.wantURL != "" && action != nil && action.HoldAndDelegate != tt.wantURL { + t.Errorf("HoldAndDelegate = %q, want %q", action.HoldAndDelegate, tt.wantURL) + } + if tt.wantResult == evalHoldDelegate && action == nil { + t.Error("expected non-nil action for evalHoldDelegate result") + } + }) + } +} + +func TestExpandDelegateURLNilFields(t *testing.T) { + // Test with minimal/nil fields to ensure no panics. + lu := &user.User{Username: "bob"} + peerAddr := netip.MustParseAddr("100.64.0.2") + + // Nil netmap, invalid peer node. + got := expandDelegateURL( + "https://control.example.com/check?dst=$DST_NODE_ID&src=$SRC_NODE_ID", + nil, + tailcfg.NodeView{}, // invalid + peerAddr, + "bob", + lu, + ) + // Should not panic; missing IDs should be empty strings. + if strings.Contains(got, "$DST_NODE_ID") { + t.Errorf("unexpanded variable in URL: %s", got) + } + if strings.Contains(got, "$SRC_NODE_ID") { + t.Errorf("unexpanded variable in URL: %s", got) + } +} + +func timeVal(year, month, day int) time.Time { + return time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC) +} diff --git a/feature/rsh/localapi.go b/feature/rsh/localapi.go new file mode 100644 index 000000000..778ad545e --- /dev/null +++ b/feature/rsh/localapi.go @@ -0,0 +1,156 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "tailscale.com/ipn/localapi" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" +) + +var ( + metricLocalAPIRshCalls = clientmetric.NewCounter("localapi_rsh") +) + +func init() { + localapi.Register("rsh", serveRsh) +} + +// localRshRequest is the JSON body the CLI sends to POST /localapi/v0/rsh. +// It includes the target peer information. +type localRshRequest struct { + // PeerID is the StableNodeID of the target peer. + PeerID tailcfg.StableNodeID `json:"peer"` + + // User is the SSH user to connect as. + User string `json:"user"` + + // Command is the command to execute on the remote. + Command string `json:"command,omitempty"` +} + +// serveRsh proxies an rsh setup request to the target peer's PeerAPI. +// +// POST /localapi/v0/rsh +// +// Request body: JSON localRshRequest +// Response body: JSON rshResponse (addr + token from the remote) +func serveRsh(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + metricLocalAPIRshCalls.Add(1) + + if !h.PermitRead { + http.Error(w, "rsh access denied", http.StatusForbidden) + return + } + if r.Method != "POST" { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) + return + } + + var req localRshRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + if req.PeerID == "" { + http.Error(w, "peer is required", http.StatusBadRequest) + return + } + if req.User == "" { + http.Error(w, "user is required", http.StatusBadRequest) + return + } + + b := h.LocalBackend() + nm := b.NetMap() + if nm == nil { + http.Error(w, "no netmap available", http.StatusInternalServerError) + return + } + + // Find the peer and its PeerAPI base URL. + var peerAPIBaseURL string + for _, p := range nm.Peers { + if p.StableID() == req.PeerID { + peerAPIBaseURL = b.PeerAPIBase(p) + break + } + } + if peerAPIBaseURL == "" { + http.Error(w, "peer not found or no PeerAPI available", http.StatusNotFound) + return + } + + // Build the PeerAPI request. + peerReqBody := rshRequest{ + User: req.User, + Command: req.Command, + } + bodyBytes, err := json.Marshal(peerReqBody) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + peerURL := strings.TrimRight(peerAPIBaseURL, "/") + "/v0/rsh" + peerReq, err := http.NewRequestWithContext(r.Context(), "POST", peerURL, bytes.NewReader(bodyBytes)) + if err != nil { + http.Error(w, "internal error: "+err.Error(), http.StatusInternalServerError) + return + } + peerReq.Header.Set("Content-Type", "application/json") + + // Use the PeerAPI transport to dial the remote peer. + tr := b.Dialer().PeerAPITransport() + resp, err := tr.RoundTrip(peerReq) + if err != nil { + h.Logf("rsh: failed to reach peer %s: %v", req.PeerID, err) + http.Error(w, fmt.Sprintf("failed to reach peer: %v", err), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + h.Logf("rsh: peer returned status %d: %s", resp.StatusCode, string(body)) + http.Error(w, fmt.Sprintf("peer error: %s", string(body)), resp.StatusCode) + return + } + + // Pass through the response from the peer. If the peer is using + // streaming NDJSON (check mode / HoldAndDelegate), we forward + // each line as it arrives so the CLI can display status messages. + ct := resp.Header.Get("Content-Type") + w.Header().Set("Content-Type", ct) + if strings.HasPrefix(ct, "application/x-ndjson") { + // Streaming mode: flush each line as it arrives. + flusher, _ := w.(http.Flusher) + w.WriteHeader(http.StatusOK) + buf := make([]byte, 4096) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + w.Write(buf[:n]) + if flusher != nil { + flusher.Flush() + } + } + if err != nil { + return + } + } + } else { + // Simple JSON response: pass through directly. + io.Copy(w, resp.Body) + } +} diff --git a/feature/rsh/policy.go b/feature/rsh/policy.go new file mode 100644 index 000000000..261d21242 --- /dev/null +++ b/feature/rsh/policy.go @@ -0,0 +1,152 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "errors" + "net/netip" + "time" + + "tailscale.com/tailcfg" +) + +// evalResult is the result of SSH policy evaluation. +type evalResult int + +const ( + evalAccepted evalResult = iota // rule matched with Accept + evalRejected // no matching rule, or explicit Reject + evalRejectedUser // principal matched but user mapping failed + evalHoldDelegate // rule matched with HoldAndDelegate (check mode) +) + +// evalSSHPolicy evaluates the SSH policy for the given parameters. +// This replicates the core matching logic from ssh/tailssh without +// depending on the SSH connection type. +// +// It returns the matching action, the mapped local user, and the evaluation result. +func evalSSHPolicy( + pol *tailcfg.SSHPolicy, + node tailcfg.NodeView, + uprof tailcfg.UserProfile, + srcAddr netip.Addr, + sshUser string, + now time.Time, +) (action *tailcfg.SSHAction, localUser string, result evalResult) { + if pol == nil { + return nil, "", evalRejected + } + failedOnUser := false + for _, r := range pol.Rules { + if a, lu, err := matchRule(r, node, uprof, srcAddr, sshUser, now); err == nil { + if a.HoldAndDelegate != "" { + return a, lu, evalHoldDelegate + } + return a, lu, evalAccepted + } else if errors.Is(err, errUserMatch) { + failedOnUser = true + } + } + if failedOnUser { + return nil, "", evalRejectedUser + } + return nil, "", evalRejected +} + +var ( + errNilRule = errors.New("nil rule") + errNilAction = errors.New("nil action") + errRuleExpired = errors.New("rule expired") + errPrincipalMatch = errors.New("principal didn't match") + errUserMatch = errors.New("user didn't match") +) + +// matchRule checks whether a single SSHRule matches the given parameters. +func matchRule( + r *tailcfg.SSHRule, + node tailcfg.NodeView, + uprof tailcfg.UserProfile, + srcAddr netip.Addr, + sshUser string, + now time.Time, +) (action *tailcfg.SSHAction, localUser string, err error) { + if r == nil { + return nil, "", errNilRule + } + if r.Action == nil { + return nil, "", errNilAction + } + if r.RuleExpires != nil && r.RuleExpires.Before(now) { + return nil, "", errRuleExpired + } + if !anyPrincipalMatches(r.Principals, node, uprof, srcAddr) { + return nil, "", errPrincipalMatch + } + if !r.Action.Reject { + localUser = mapLocalUser(r.SSHUsers, sshUser) + if localUser == "" { + return nil, "", errUserMatch + } + } + return r.Action, localUser, nil +} + +// anyPrincipalMatches reports whether any of the given principals match +// the Tailscale identity of the connecting peer. +func anyPrincipalMatches( + ps []*tailcfg.SSHPrincipal, + node tailcfg.NodeView, + uprof tailcfg.UserProfile, + srcAddr netip.Addr, +) bool { + for _, p := range ps { + if p == nil { + continue + } + if principalMatchesTailscaleIdentity(p, node, uprof, srcAddr) { + return true + } + } + return false +} + +// principalMatchesTailscaleIdentity reports whether a principal matches +// the Tailscale identity of the connecting peer. +func principalMatchesTailscaleIdentity( + p *tailcfg.SSHPrincipal, + node tailcfg.NodeView, + uprof tailcfg.UserProfile, + srcAddr netip.Addr, +) bool { + if p.Any { + return true + } + if !p.Node.IsZero() && node.Valid() && p.Node == node.StableID() { + return true + } + if p.NodeIP != "" { + if ip, _ := netip.ParseAddr(p.NodeIP); ip == srcAddr { + return true + } + } + if p.UserLogin != "" && uprof.LoginName == p.UserLogin { + return true + } + return false +} + +// mapLocalUser maps an SSH user to a local user using the SSHUsers map +// from a policy rule. +func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) string { + v, ok := ruleSSHUsers[reqSSHUser] + if !ok { + v = ruleSSHUsers["*"] + } + if v == "=" { + return reqSSHUser + } + return v +} diff --git a/feature/rsh/policy_test.go b/feature/rsh/policy_test.go new file mode 100644 index 000000000..5fa88f8cb --- /dev/null +++ b/feature/rsh/policy_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "net/netip" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func TestEvalSSHPolicy(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + + node := (&tailcfg.Node{ + ID: 1, + StableID: "stable1", + Key: key.NewNode().Public(), + }).View() + + uprof := tailcfg.UserProfile{ + LoginName: "alice@example.com", + } + + srcAddr := netip.MustParseAddr("100.64.1.2") + + tests := []struct { + name string + pol *tailcfg.SSHPolicy + sshUser string + wantResult evalResult + wantUser string + }{ + { + name: "nil_policy", + pol: nil, + sshUser: "root", + wantResult: evalRejected, + }, + { + name: "empty_policy", + pol: &tailcfg.SSHPolicy{}, + sshUser: "root", + wantResult: evalRejected, + }, + { + name: "accept_any_wildcard_user", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalAccepted, + wantUser: "alice", + }, + { + name: "accept_specific_user_mapping", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"root": "admin"}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "root", + wantResult: evalAccepted, + wantUser: "admin", + }, + { + name: "reject_unmapped_user", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"root": "admin"}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "unknown", + wantResult: evalRejectedUser, + }, + { + name: "match_by_node_stable_id", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Node: "stable1"}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "bob", + wantResult: evalAccepted, + wantUser: "bob", + }, + { + name: "reject_wrong_node", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Node: "other-node"}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "bob", + wantResult: evalRejected, + }, + { + name: "match_by_node_ip", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalAccepted, + wantUser: "alice", + }, + { + name: "match_by_user_login", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalAccepted, + wantUser: "alice", + }, + { + name: "hold_and_delegate_returns_eval_hold", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{HoldAndDelegate: "https://example.com/approve"}, + }, + }, + }, + sshUser: "alice", + wantResult: evalHoldDelegate, + }, + { + name: "explicit_reject_action", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + Action: &tailcfg.SSHAction{Reject: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalAccepted, // matchRule succeeds for Reject rules (no SSHUsers check) + }, + { + name: "expired_rule_skipped", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + RuleExpires: timePtr(now.Add(-time.Hour)), + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalRejected, + }, + { + name: "first_matching_rule_wins", + pol: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}}, + SSHUsers: map[string]string{"alice": "localice"}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + { + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{"*": "="}, + Action: &tailcfg.SSHAction{Accept: true}, + }, + }, + }, + sshUser: "alice", + wantResult: evalAccepted, + wantUser: "localice", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now) + if result != tt.wantResult { + t.Errorf("result = %v, want %v", result, tt.wantResult) + } + if tt.wantUser != "" && localUser != tt.wantUser { + t.Errorf("localUser = %q, want %q", localUser, tt.wantUser) + } + _ = action // not checked in most tests + }) + } +} + +func TestMapLocalUser(t *testing.T) { + tests := []struct { + name string + sshUsers map[string]string + reqUser string + wantResult string + }{ + {"exact_match", map[string]string{"root": "admin"}, "root", "admin"}, + {"wildcard_match", map[string]string{"*": "defaultuser"}, "anyone", "defaultuser"}, + {"identity_match", map[string]string{"*": "="}, "alice", "alice"}, + {"no_match", map[string]string{"root": "admin"}, "unknown", ""}, + {"exact_over_wildcard", map[string]string{"root": "admin", "*": "default"}, "root", "admin"}, + {"nil_map", nil, "alice", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapLocalUser(tt.sshUsers, tt.reqUser) + if got != tt.wantResult { + t.Errorf("mapLocalUser(%v, %q) = %q, want %q", tt.sshUsers, tt.reqUser, got, tt.wantResult) + } + }) + } +} + +func timePtr(t time.Time) *time.Time { return &t } diff --git a/feature/rsh/protocol.go b/feature/rsh/protocol.go new file mode 100644 index 000000000..7389b9bae --- /dev/null +++ b/feature/rsh/protocol.go @@ -0,0 +1,157 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package rsh implements a fast remote shell transport over Tailscale, +// designed as an rsync -e compatible replacement for SSH. It uses a PeerAPI +// endpoint for session setup and a raw TCP data channel for I/O, +// avoiding SSH's double encryption and suboptimal buffering. +package rsh + +import ( + "encoding/binary" + "fmt" + "io" + "sync" +) + +// Channel type constants for the wire protocol. +// The protocol is length-prefixed framing: +// +// [1 byte: channel] [4 bytes: length (big-endian)] [N bytes: payload] +const ( + // ChanStdin is data from client to server (remote process stdin). + ChanStdin byte = 0x00 + + // ChanStdout is data from server to client (remote process stdout). + ChanStdout byte = 0x01 + + // ChanStderr is data from server to client (remote process stderr). + ChanStderr byte = 0x02 + + // ChanExit is the exit code from the remote process. + // Payload is a 4-byte big-endian signed integer exit code. + // Sent by server to client, then the server closes the connection. + ChanExit byte = 0x03 +) + +const ( + // tokenLen is the length of the one-time authentication token. + tokenLen = 32 + + // maxFrameSize is the maximum payload size for a single frame. + // 256KB is a good balance between throughput and memory usage, + // matching typical rsync block sizes. + maxFrameSize = 256 * 1024 + + // frameHeaderSize is the size of the frame header (channel + length). + frameHeaderSize = 5 +) + +// frameWriter writes length-prefixed frames to an underlying writer. +// It is safe for concurrent use. +type frameWriter struct { + mu sync.Mutex + w io.Writer +} + +// newFrameWriter creates a new frameWriter that writes to w. +func newFrameWriter(w io.Writer) *frameWriter { + return &frameWriter{w: w} +} + +// WriteFrame writes a single frame with the given channel and payload. +func (fw *frameWriter) WriteFrame(ch byte, data []byte) error { + if len(data) > maxFrameSize { + return fmt.Errorf("rsh: frame payload too large: %d > %d", len(data), maxFrameSize) + } + fw.mu.Lock() + defer fw.mu.Unlock() + + var hdr [frameHeaderSize]byte + hdr[0] = ch + binary.BigEndian.PutUint32(hdr[1:], uint32(len(data))) + + if _, err := fw.w.Write(hdr[:]); err != nil { + return err + } + if len(data) > 0 { + if _, err := fw.w.Write(data); err != nil { + return err + } + } + return nil +} + +// WriteExitCode writes an exit code frame and is a convenience wrapper. +func (fw *frameWriter) WriteExitCode(code int) error { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], uint32(code)) + return fw.WriteFrame(ChanExit, buf[:]) +} + +// frameReader reads length-prefixed frames from an underlying reader. +type frameReader struct { + r io.Reader + buf []byte // reusable buffer for payloads +} + +// newFrameReader creates a new frameReader that reads from r. +func newFrameReader(r io.Reader) *frameReader { + return &frameReader{ + r: r, + buf: make([]byte, 0, 32*1024), // start small, grow as needed + } +} + +// ReadFrame reads the next frame, returning the channel type and payload. +// The returned payload slice is valid until the next call to ReadFrame. +func (fr *frameReader) ReadFrame() (ch byte, data []byte, err error) { + var hdr [frameHeaderSize]byte + if _, err := io.ReadFull(fr.r, hdr[:]); err != nil { + return 0, nil, err + } + ch = hdr[0] + n := binary.BigEndian.Uint32(hdr[1:]) + if n > maxFrameSize { + return 0, nil, fmt.Errorf("rsh: frame too large: %d > %d", n, maxFrameSize) + } + if int(n) > cap(fr.buf) { + fr.buf = make([]byte, n) + } else { + fr.buf = fr.buf[:n] + } + if n > 0 { + if _, err := io.ReadFull(fr.r, fr.buf); err != nil { + return 0, nil, err + } + } + return ch, fr.buf, nil +} + +// channelWriter wraps a frameWriter to implement io.Writer for a specific channel. +type channelWriter struct { + fw *frameWriter + ch byte +} + +// newChannelWriter returns an io.Writer that writes all data as frames on +// the given channel. +func newChannelWriter(fw *frameWriter, ch byte) io.Writer { + return &channelWriter{fw: fw, ch: ch} +} + +func (cw *channelWriter) Write(p []byte) (int, error) { + written := 0 + for len(p) > 0 { + chunk := p + if len(chunk) > maxFrameSize { + chunk = chunk[:maxFrameSize] + } + if err := cw.fw.WriteFrame(cw.ch, chunk); err != nil { + return written, err + } + written += len(chunk) + p = p[len(chunk):] + } + return written, nil +} diff --git a/feature/rsh/protocol_test.go b/feature/rsh/protocol_test.go new file mode 100644 index 000000000..ca0ecd0d1 --- /dev/null +++ b/feature/rsh/protocol_test.go @@ -0,0 +1,188 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rsh + +import ( + "bytes" + "encoding/binary" + "io" + "testing" +) + +func TestFrameRoundtrip(t *testing.T) { + var buf bytes.Buffer + + fw := newFrameWriter(&buf) + fr := newFrameReader(&buf) + + // Write several frames. + if err := fw.WriteFrame(ChanStdout, []byte("hello")); err != nil { + t.Fatalf("WriteFrame stdout: %v", err) + } + if err := fw.WriteFrame(ChanStderr, []byte("world")); err != nil { + t.Fatalf("WriteFrame stderr: %v", err) + } + if err := fw.WriteFrame(ChanStdin, []byte("input")); err != nil { + t.Fatalf("WriteFrame stdin: %v", err) + } + if err := fw.WriteExitCode(42); err != nil { + t.Fatalf("WriteExitCode: %v", err) + } + + // Read them back. + ch, data, err := fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 1: %v", err) + } + if ch != ChanStdout || string(data) != "hello" { + t.Errorf("frame 1: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdout, "hello") + } + + ch, data, err = fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 2: %v", err) + } + if ch != ChanStderr || string(data) != "world" { + t.Errorf("frame 2: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStderr, "world") + } + + ch, data, err = fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 3: %v", err) + } + if ch != ChanStdin || string(data) != "input" { + t.Errorf("frame 3: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdin, "input") + } + + ch, data, err = fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 4: %v", err) + } + if ch != ChanExit { + t.Errorf("frame 4: got ch=%d, want ch=%d", ch, ChanExit) + } + if len(data) != 4 { + t.Fatalf("exit frame data len = %d, want 4", len(data)) + } + code := int(binary.BigEndian.Uint32(data)) + if code != 42 { + t.Errorf("exit code = %d, want 42", code) + } + + // Should get EOF now. + _, _, err = fr.ReadFrame() + if err != io.EOF && err != io.ErrUnexpectedEOF { + t.Errorf("expected EOF after all frames, got: %v", err) + } +} + +func TestFrameEmptyPayload(t *testing.T) { + var buf bytes.Buffer + fw := newFrameWriter(&buf) + fr := newFrameReader(&buf) + + if err := fw.WriteFrame(ChanStdin, nil); err != nil { + t.Fatalf("WriteFrame empty: %v", err) + } + + ch, data, err := fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if ch != ChanStdin { + t.Errorf("ch = %d, want %d", ch, ChanStdin) + } + if len(data) != 0 { + t.Errorf("data len = %d, want 0", len(data)) + } +} + +func TestFrameTooLarge(t *testing.T) { + fw := newFrameWriter(io.Discard) + data := make([]byte, maxFrameSize+1) + if err := fw.WriteFrame(ChanStdout, data); err == nil { + t.Error("expected error for oversized frame, got nil") + } +} + +func TestFrameReaderTooLarge(t *testing.T) { + // Construct a frame with length > maxFrameSize. + var buf bytes.Buffer + var hdr [frameHeaderSize]byte + hdr[0] = ChanStdout + binary.BigEndian.PutUint32(hdr[1:], maxFrameSize+1) + buf.Write(hdr[:]) + + fr := newFrameReader(&buf) + _, _, err := fr.ReadFrame() + if err == nil { + t.Error("expected error for oversized frame in reader, got nil") + } +} + +func TestChannelWriter(t *testing.T) { + var buf bytes.Buffer + fw := newFrameWriter(&buf) + cw := newChannelWriter(fw, ChanStdout) + + data := []byte("hello world from channel writer") + n, err := cw.Write(data) + if err != nil { + t.Fatalf("Write: %v", err) + } + if n != len(data) { + t.Errorf("n = %d, want %d", n, len(data)) + } + + fr := newFrameReader(&buf) + ch, got, err := fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if ch != ChanStdout { + t.Errorf("ch = %d, want %d", ch, ChanStdout) + } + if !bytes.Equal(got, data) { + t.Errorf("data mismatch: got %q, want %q", got, data) + } +} + +func TestChannelWriterChunking(t *testing.T) { + var buf bytes.Buffer + fw := newFrameWriter(&buf) + cw := newChannelWriter(fw, ChanStdout) + + // Write more than maxFrameSize to verify chunking. + data := make([]byte, maxFrameSize+100) + for i := range data { + data[i] = byte(i % 256) + } + + n, err := cw.Write(data) + if err != nil { + t.Fatalf("Write: %v", err) + } + if n != len(data) { + t.Errorf("n = %d, want %d", n, len(data)) + } + + // Should produce two frames. + fr := newFrameReader(&buf) + + ch, chunk1, err := fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 1: %v", err) + } + if ch != ChanStdout || len(chunk1) != maxFrameSize { + t.Errorf("frame 1: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk1), ChanStdout, maxFrameSize) + } + + ch, chunk2, err := fr.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame 2: %v", err) + } + if ch != ChanStdout || len(chunk2) != 100 { + t.Errorf("frame 2: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk2), ChanStdout, 100) + } +} diff --git a/feature/rsh/rsh.go b/feature/rsh/rsh.go new file mode 100644 index 000000000..87fd4a2e3 --- /dev/null +++ b/feature/rsh/rsh.go @@ -0,0 +1,739 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "os/exec" + "os/user" + "runtime" + "strings" + "time" + + "tailscale.com/envknob" + "tailscale.com/hostinfo" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" + "tailscale.com/util/backoff" + "tailscale.com/util/clientmetric" + "tailscale.com/util/osuser" +) + +var ( + metricRshCalls = clientmetric.NewCounter("peerapi_rsh") + metricRshAccepts = clientmetric.NewCounter("peerapi_rsh_accept") + metricRshRejects = clientmetric.NewCounter("peerapi_rsh_reject") +) + +func init() { + ipnlocal.RegisterPeerAPIHandler("/v0/rsh", handleRsh) +} + +// rshRequest is the JSON body sent to POST /v0/rsh. +type rshRequest struct { + // User is the requested SSH user (will be mapped via SSHUsers policy). + User string `json:"user"` + + // Command is the command to execute. If empty, the user's default + // login shell is started. + Command string `json:"command,omitempty"` +} + +// rshResponse is returned by a successful POST /v0/rsh. +// In streaming mode (check mode), this is the final JSON line in +// the newline-delimited JSON stream. +type rshResponse struct { + // Addr is the Tailscale IP:port to connect to for the data channel. + Addr string `json:"addr"` + + // Token is the hex-encoded one-time authentication token that must + // be sent as the first bytes on the data channel connection. + Token string `json:"token"` +} + +// rshStatusMessage is sent as a streaming JSON line during the +// HoldAndDelegate (check mode) flow before the final rshResponse. +// Each message is a newline-delimited JSON object. +type rshStatusMessage struct { + // Status is a human-readable status message to display to the user. + Status string `json:"status"` +} + +// netstackTCPListenerFunc is the type of a function that creates a TCP +// listener on the netstack (gVisor) network stack. It is set by the +// netstack package at init time. +// +// We use a function hook instead of a type assertion on NetstackImpl +// because netstack.Impl.ListenTCP returns *gonet.TCPListener (not +// net.Listener), and importing gonet would create an unwanted gVisor +// dependency. +var netstackListenTCP func(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error) + +const linux = "linux" + +func handleRsh(ph ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + metricRshCalls.Add(1) + logf := ph.Logf + + if r.Method != "POST" { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) + return + } + + b := ph.LocalBackend() + + // Check that SSH is enabled on this node. + if !b.ShouldRunSSH() { + logf("rsh: denied; SSH not enabled") + metricRshRejects.Add(1) + http.Error(w, "SSH not enabled on this node", http.StatusForbidden) + return + } + + // Parse request. + var req rshRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + logf("rsh: bad request body: %v", err) + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if req.User == "" { + http.Error(w, "user is required", http.StatusBadRequest) + return + } + + // Evaluate SSH policy. + peerNode := ph.Peer() + peerAddr := ph.RemoteAddr().Addr() + + nm := b.NetMap() + if nm == nil { + logf("rsh: no netmap") + http.Error(w, "no netmap available", http.StatusInternalServerError) + return + } + + sshPol := nm.SSHPolicy + if sshPol == nil { + logf("rsh: no SSH policy") + metricRshRejects.Add(1) + http.Error(w, "no SSH policy configured", http.StatusForbidden) + return + } + + // Look up the peer's user profile for policy matching. + _, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr()) + if !ok { + logf("rsh: unknown peer %v", ph.RemoteAddr()) + metricRshRejects.Add(1) + http.Error(w, "unknown peer", http.StatusForbidden) + return + } + + action, localUser, result := evalSSHPolicy(sshPol, peerNode, uprof, peerAddr, req.User, time.Now()) + + switch result { + case evalAccepted: + if action.Reject { + logf("rsh: policy explicitly rejects %v -> %s@%s", peerAddr, req.User, localUser) + metricRshRejects.Add(1) + http.Error(w, "access denied by policy", http.StatusForbidden) + return + } + // Good, accepted. action may still have a Message to send. + case evalHoldDelegate: + // Check mode: we need to poll the control plane for approval. + // The response uses streaming newline-delimited JSON so + // status messages can be sent while we wait. + case evalRejectedUser: + logf("rsh: user %q not mapped for peer %v", req.User, peerAddr) + metricRshRejects.Add(1) + http.Error(w, fmt.Sprintf("user %q not permitted", req.User), http.StatusForbidden) + return + case evalRejected: + logf("rsh: policy rejects %v -> %s", peerAddr, req.User) + metricRshRejects.Add(1) + http.Error(w, "access denied by policy", http.StatusForbidden) + return + } + + // Look up the local user. We need this for both the immediate accept + // path and the HoldAndDelegate path (to expand delegate URL variables). + lu, loginShell, err := osuser.LookupByUsernameWithShell(localUser) + if err != nil { + logf("rsh: user lookup failed for %q: %v", localUser, err) + http.Error(w, fmt.Sprintf("user %q not found", localUser), http.StatusInternalServerError) + return + } + groupIDs, err := osuser.GetGroupIds(lu) + if err != nil { + logf("rsh: group lookup failed for %q: %v", localUser, err) + http.Error(w, "failed to look up user groups", http.StatusInternalServerError) + return + } + + // If HoldAndDelegate, run the check mode loop to get a terminal action. + // We use streaming JSON so status messages can be sent to the client. + if result == evalHoldDelegate { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + + action, err = resolveCheckMode(r.Context(), b, action, nm, peerNode, peerAddr, req.User, lu, w, flusher, logf) + if err != nil { + // Connection is already streaming; send error as a status message. + logf("rsh: check mode failed: %v", err) + writeNDJSON(w, flusher, rshStatusMessage{Status: fmt.Sprintf("check mode error: %v", err)}) + return + } + if action.Reject { + logf("rsh: check mode rejected %v -> %s", peerAddr, req.User) + metricRshRejects.Add(1) + msg := "access denied" + if action.Message != "" { + msg = action.Message + } + writeNDJSON(w, flusher, rshStatusMessage{Status: msg}) + return + } + if !action.Accept { + logf("rsh: check mode returned non-terminal action for %v -> %s", peerAddr, req.User) + metricRshRejects.Add(1) + writeNDJSON(w, flusher, rshStatusMessage{Status: "unexpected response from control"}) + return + } + } + + // Find a local Tailscale IP to listen on. + listenAddr, err := pickListenAddr(nm, peerAddr) + if err != nil { + logf("rsh: no listen address: %v", err) + if result == evalHoldDelegate { + flusher, _ := w.(http.Flusher) + writeNDJSON(w, flusher, rshStatusMessage{Status: "no suitable listen address"}) + } else { + http.Error(w, "no suitable listen address", http.StatusInternalServerError) + } + return + } + + // Create the listener. + ln, err := listenTailscale(b, listenAddr) + if err != nil { + logf("rsh: listen failed: %v", err) + if result == evalHoldDelegate { + flusher, _ := w.(http.Flusher) + writeNDJSON(w, flusher, rshStatusMessage{Status: "failed to create listener"}) + } else { + http.Error(w, "failed to create listener", http.StatusInternalServerError) + } + return + } + + // Generate one-time token. + var tokenBytes [tokenLen]byte + if _, err := rand.Read(tokenBytes[:]); err != nil { + ln.Close() + logf("rsh: rand failed: %v", err) + if result == evalHoldDelegate { + flusher, _ := w.(http.Flusher) + writeNDJSON(w, flusher, rshStatusMessage{Status: "internal error"}) + } else { + http.Error(w, "internal error", http.StatusInternalServerError) + } + return + } + tokenHex := hex.EncodeToString(tokenBytes[:]) + + metricRshAccepts.Add(1) + + // Start the session handler in a goroutine. It will accept one + // connection, verify the token, and wire up the incubator process. + go handleRshSession(b, ln, tokenBytes[:], peerAddr, lu, loginShell, groupIDs, req, ph, logf) + + // Return the listen address and token to the client. + resp := rshResponse{ + Addr: ln.Addr().String(), + Token: tokenHex, + } + if result == evalHoldDelegate { + // Streaming mode: send a final accept message then the response. + flusher, _ := w.(http.Flusher) + if action.Message != "" { + writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message}) + } + writeNDJSON(w, flusher, resp) + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } +} + +// writeNDJSON writes v as a single newline-delimited JSON line to w +// and flushes. This is used for the streaming check mode response. +func writeNDJSON(w io.Writer, flusher http.Flusher, v any) { + json.NewEncoder(w).Encode(v) // Encode appends '\n' + if flusher != nil { + flusher.Flush() + } +} + +// resolveCheckMode runs the HoldAndDelegate loop, polling the control plane +// until a terminal action (Accept or Reject) is returned. It sends status +// messages to the client as streaming JSON lines while waiting. +// +// This is the rsh equivalent of SSH's clientAuth HoldAndDelegate loop. +func resolveCheckMode( + ctx context.Context, + b *ipnlocal.LocalBackend, + action *tailcfg.SSHAction, + nm *netmap.NetworkMap, + peerNode tailcfg.NodeView, + peerAddr netip.Addr, + sshUser string, + lu *user.User, + w io.Writer, + flusher http.Flusher, + logf func(string, ...any), +) (*tailcfg.SSHAction, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) + defer cancel() + + for { + if action.Message != "" { + writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message}) + } + + if action.Accept || action.Reject { + return action, nil + } + if action.HoldAndDelegate == "" { + return nil, fmt.Errorf("action has neither Accept, Reject, nor HoldAndDelegate") + } + + delegateURL := expandDelegateURL(action.HoldAndDelegate, nm, peerNode, peerAddr, sshUser, lu) + logf("rsh: check mode: polling %s", delegateURL) + writeNDJSON(w, flusher, rshStatusMessage{Status: "Waiting for approval..."}) + + var err error + action, err = fetchSSHAction(ctx, b, delegateURL, logf) + if err != nil { + return nil, fmt.Errorf("fetching SSH action: %w", err) + } + } +} + +// expandDelegateURL expands the variables in a HoldAndDelegate URL. +// The variables match those used by SSH: $SRC_NODE_IP, $SRC_NODE_ID, +// $DST_NODE_IP, $DST_NODE_ID, $SSH_USER, $LOCAL_USER. +func expandDelegateURL( + actionURL string, + nm *netmap.NetworkMap, + peerNode tailcfg.NodeView, + peerAddr netip.Addr, + sshUser string, + lu *user.User, +) string { + var dstNodeID string + if nm != nil { + dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID())) + } + var srcNodeID string + if peerNode.Valid() { + srcNodeID = fmt.Sprint(int64(peerNode.ID())) + } + var dstNodeIP string + if nm != nil { + addrs := nm.GetAddresses() + for _, pfx := range addrs.All() { + if pfx.IsSingleIP() { + dstNodeIP = pfx.Addr().String() + break + } + } + } + return strings.NewReplacer( + "$SRC_NODE_IP", url.QueryEscape(peerAddr.String()), + "$SRC_NODE_ID", srcNodeID, + "$DST_NODE_IP", url.QueryEscape(dstNodeIP), + "$DST_NODE_ID", dstNodeID, + "$SSH_USER", url.QueryEscape(sshUser), + "$LOCAL_USER", url.QueryEscape(lu.Username), + ).Replace(actionURL) +} + +// fetchSSHAction polls a control plane URL over the Noise transport +// and returns the SSHAction. It retries with exponential backoff on +// transient errors, matching the behavior of SSH's fetchSSHAction. +func fetchSSHAction(ctx context.Context, b *ipnlocal.LocalBackend, url string, logf func(string, ...any)) (*tailcfg.SSHAction, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) + defer cancel() + bo := backoff.NewBackoff("rsh-fetch-ssh-action", logf, 10*time.Second) + for { + if err := ctx.Err(); err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + res, err := b.DoNoiseRequest(req) + if err != nil { + bo.BackOff(ctx, err) + continue + } + if res.StatusCode != 200 { + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if len(body) > 1<<10 { + body = body[:1<<10] + } + logf("rsh: fetch of %v: %s, %s", url, res.Status, body) + bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status)) + continue + } + a := new(tailcfg.SSHAction) + err = json.NewDecoder(res.Body).Decode(a) + res.Body.Close() + if err != nil { + logf("rsh: invalid SSHAction JSON from %v: %v", url, err) + bo.BackOff(ctx, err) + continue + } + return a, nil + } +} + +// pickListenAddr selects a local Tailscale IP address that matches the +// address family of the peer. This ensures the data channel connection +// uses the same protocol version. +func pickListenAddr(nm *netmap.NetworkMap, peerAddr netip.Addr) (netip.Addr, error) { + addrs := nm.GetAddresses() + wantV4 := peerAddr.Is4() + + for _, pfx := range addrs.All() { + if !pfx.IsSingleIP() { + continue + } + a := pfx.Addr() + if wantV4 && a.Is4() { + return a, nil + } + if !wantV4 && a.Is6() { + return a, nil + } + } + // Fallback: return any address. + for _, pfx := range addrs.All() { + if pfx.IsSingleIP() { + return pfx.Addr(), nil + } + } + return netip.Addr{}, fmt.Errorf("no Tailscale addresses available") +} + +// listenTailscale creates a TCP listener on the given Tailscale IP. +// In netstack mode, it uses the gVisor stack via the netstackListenTCP hook. +// In kernel TUN mode, it uses the standard library. +func listenTailscale(b *ipnlocal.LocalBackend, addr netip.Addr) (net.Listener, error) { + network := "tcp4" + if addr.Is6() { + network = "tcp6" + } + listenAddr := netip.AddrPortFrom(addr, 0).String() + + if b.Sys().IsNetstack() { + // In full netstack mode, we need to use the gVisor stack to listen + // since all local IP traffic is handled by netstack. + if netstackListenTCP == nil { + return nil, fmt.Errorf("netstack listener not available (rsh_netstack not linked)") + } + return netstackListenTCP(b, network, listenAddr) + } + + // In kernel TUN mode, the Tailscale IP is assigned to the TUN device + // and the kernel handles routing. Standard net.Listen works. + return net.Listen(network, listenAddr) +} + +// handleRshSession is run in a goroutine. It accepts a single connection +// from the listener, verifies the token and source, then spawns the +// remote command via the incubator. +func handleRshSession( + b *ipnlocal.LocalBackend, + ln net.Listener, + token []byte, + expectedPeer netip.Addr, + lu *user.User, + loginShell string, + groupIDs []string, + req rshRequest, + ph ipnlocal.PeerAPIHandler, + logf func(string, ...any), +) { + defer ln.Close() + + // Set a deadline for the client to connect. + if dl, ok := ln.(interface{ SetDeadline(time.Time) error }); ok { + dl.SetDeadline(time.Now().Add(30 * time.Second)) + } + + conn, err := ln.Accept() + if err != nil { + logf("rsh: accept failed: %v", err) + return + } + ln.Close() // Only accept one connection. + + defer conn.Close() + + // Verify source IP. + tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + logf("rsh: unexpected remote addr type: %T", conn.RemoteAddr()) + return + } + remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP) + if !ok { + logf("rsh: invalid remote IP") + return + } + remoteIP = remoteIP.Unmap() + if remoteIP != expectedPeer { + logf("rsh: unexpected peer %v, expected %v", remoteIP, expectedPeer) + return + } + + // Read and verify token. + var gotToken [tokenLen]byte + if _, err := io.ReadFull(conn, gotToken[:]); err != nil { + logf("rsh: failed to read token: %v", err) + return + } + if subtle.ConstantTimeCompare(gotToken[:], token) != 1 { + logf("rsh: invalid token from %v", remoteIP) + return + } + + // Set TCP_NODELAY for low-latency rsync control messages. + if tc, ok := conn.(*net.TCPConn); ok { + tc.SetNoDelay(true) + } + + logf("rsh: session accepted from %v as %s, command=%q", remoteIP, lu.Username, req.Command) + + // Build and run the incubator command. + runIncubator(b, conn, lu, loginShell, groupIDs, req, ph, logf) +} + +// runIncubator spawns the remote command using the existing SSH incubator +// mechanism for privilege dropping and PAM integration. +func runIncubator( + b *ipnlocal.LocalBackend, + conn net.Conn, + lu *user.User, + loginShell string, + groupIDs []string, + req rshRequest, + ph ipnlocal.PeerAPIHandler, + logf func(string, ...any), +) { + tailscaledPath, err := os.Executable() + if err != nil { + logf("rsh: os.Executable: %v", err) + sendExitCode(conn, 1) + return + } + + peerNode := ph.Peer() + remoteUser := "unknown" + if peerNode.Valid() { + if peerNode.IsTagged() { + remoteUser = strings.Join(peerNode.Tags().AsSlice(), ",") + } else { + _, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr()) + if ok { + remoteUser = uprof.LoginName + } + } + } + + groups := strings.Join(groupIDs, ",") + isShell := req.Command == "" + + incubatorArgs := []string{ + "be-child", + "ssh", + "--login-shell=" + loginShell, + "--uid=" + lu.Uid, + "--gid=" + lu.Gid, + "--groups=" + groups, + "--local-user=" + lu.Username, + "--home-dir=" + lu.HomeDir, + "--remote-user=" + remoteUser, + "--remote-ip=" + ph.RemoteAddr().Addr().String(), + "--has-tty=false", + "--tty-name=", + } + + if runtime.GOOS == linux && hostinfo.IsSELinuxEnforcing() { + incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing") + } + + nm := b.NetMap() + if nm != nil && nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) { + incubatorArgs = append(incubatorArgs, "--force-v1-behavior") + } + + if isShell { + incubatorArgs = append(incubatorArgs, "--shell") + } else { + incubatorArgs = append(incubatorArgs, "--cmd="+req.Command) + } + + cmd := exec.Command(tailscaledPath, incubatorArgs...) + cmd.Dir = "/" + + // Set up the environment for the child. + cmd.Env = []string{ + "SHELL=" + loginShell, + "USER=" + lu.Username, + "HOME=" + lu.HomeDir, + "PATH=" + defaultPathForUser(lu), + } + + // Create stdin/stdout/stderr pipes. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + logf("rsh: stdin pipe: %v", err) + sendExitCode(conn, 1) + return + } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + logf("rsh: stdout pipe: %v", err) + sendExitCode(conn, 1) + return + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + logf("rsh: stderr pipe: %v", err) + sendExitCode(conn, 1) + return + } + + if err := cmd.Start(); err != nil { + logf("rsh: start incubator: %v", err) + sendExitCode(conn, 1) + return + } + + fw := newFrameWriter(conn) + fr := newFrameReader(conn) + + // Goroutine: read frames from client, write stdin to incubator. + stdinDone := make(chan struct{}) + go func() { + defer close(stdinDone) + defer stdinPipe.Close() + for { + ch, data, err := fr.ReadFrame() + if err != nil { + return + } + if ch == ChanStdin { + if _, err := stdinPipe.Write(data); err != nil { + return + } + } + } + }() + + // Goroutine: read stdout from incubator, write frames to client. + stdoutDone := make(chan struct{}) + go func() { + defer close(stdoutDone) + buf := make([]byte, 64*1024) + for { + n, err := stdoutPipe.Read(buf) + if n > 0 { + if werr := fw.WriteFrame(ChanStdout, buf[:n]); werr != nil { + return + } + } + if err != nil { + return + } + } + }() + + // Goroutine: read stderr from incubator, write frames to client. + stderrDone := make(chan struct{}) + go func() { + defer close(stderrDone) + buf := make([]byte, 64*1024) + for { + n, err := stderrPipe.Read(buf) + if n > 0 { + if werr := fw.WriteFrame(ChanStderr, buf[:n]); werr != nil { + return + } + } + if err != nil { + return + } + } + }() + + // Wait for the process to exit. + exitCode := 0 + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + logf("rsh: wait: %v", err) + exitCode = 1 + } + } + + // Wait for output goroutines to drain. + <-stdoutDone + <-stderrDone + + // Send exit code and close. + logf("rsh: session ended for %s, exit code %d", lu.Username, exitCode) + fw.WriteExitCode(exitCode) +} + +// sendExitCode is a helper used before the framing writer is set up. +func sendExitCode(conn net.Conn, code int) { + fw := newFrameWriter(conn) + fw.WriteExitCode(code) +} + +// defaultPathForUser returns an appropriate default PATH for the user. +// This is a simplified version of the logic in ssh/tailssh/user.go. +func defaultPathForUser(u *user.User) string { + if u.Uid == "0" { + return "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" + } + return "/usr/local/bin:/usr/bin:/bin" +} + +// envknobs for debugging. +var rshVerbose = envknob.RegisterBool("TS_DEBUG_RSH_VLOG") diff --git a/feature/rsh/rsh_netstack.go b/feature/rsh/rsh_netstack.go new file mode 100644 index 000000000..2b279eb16 --- /dev/null +++ b/feature/rsh/rsh_netstack.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd + +package rsh + +import ( + "net" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/wgengine/netstack" +) + +func init() { + netstackListenTCP = netstackListenTCPImpl +} + +func netstackListenTCPImpl(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error) { + ns, ok := b.Sys().Netstack.GetOK() + if !ok { + return nil, net.ErrClosed + } + // Type-assert to *netstack.Impl which has the ListenTCP method. + impl, ok := ns.(*netstack.Impl) + if !ok { + return nil, net.ErrClosed + } + return impl.ListenTCP(network, address) +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 27858484a..9aa252892 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -653,6 +653,13 @@ func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) { func (b *LocalBackend) Clock() tstime.Clock { return b.clock } func (b *LocalBackend) Sys() *tsd.System { return b.sys } +// PeerAPIBase returns the "http://ip:port" URL base to reach a peer's PeerAPI. +// It returns the empty string if the peer doesn't support PeerAPI or there's +// no matching address family. +func (b *LocalBackend) PeerAPIBase(peer tailcfg.NodeView) string { + return peerAPIBase(b.NetMap(), peer) +} + // NodeBackend returns the current node's NodeBackend interface. func (b *LocalBackend) NodeBackend() ipnext.NodeBackend { return b.currentNode() |
