summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Crawshaw <crawshaw@tailscale.com>2023-02-26 17:32:36 -0500
committerDavid Crawshaw <crawshaw@tailscale.com>2023-02-26 17:47:52 -0500
commit5f256f114f178da7bf59da97537ab77052b379dc (patch)
treec9f55593ff60c44a769020191b1f9df38cb46aab
parente484e1c0fc0481194971b5eaa0f3123ef76f3c36 (diff)
downloadtailscale-crawshaw/pidlisten.tar.xz
tailscale-crawshaw/pidlisten.zip
net/pidlisten: new package that restricts dials to the current processcrawshaw/pidlisten
To be used in the C library wrapping tsnet to provide LocalAPI access. This commit contains a linux implementation. More operating systems to follow. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
-rw-r--r--net/pidlisten/pidlisten.go42
-rw-r--r--net/pidlisten/pidlisten_linux.go63
-rw-r--r--net/pidlisten/pidlisten_noimpl.go13
-rw-r--r--net/pidlisten/pidlisten_test.go122
4 files changed, 240 insertions, 0 deletions
diff --git a/net/pidlisten/pidlisten.go b/net/pidlisten/pidlisten.go
new file mode 100644
index 000000000..520ef7765
--- /dev/null
+++ b/net/pidlisten/pidlisten.go
@@ -0,0 +1,42 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package pidlisten implements a TCP listener that only
+// accepts connections from the current process.
+package pidlisten
+
+import (
+ "fmt"
+ "net"
+)
+
+type listener struct {
+ ln net.Listener
+}
+
+func (pln *listener) Accept() (net.Conn, error) {
+ for {
+ conn, err := pln.ln.Accept()
+ if err != nil {
+ return nil, err
+ }
+ ok, err := checkPIDLocal(conn)
+ if err != nil {
+ conn.Close()
+ return nil, fmt.Errorf("pidlisten: %w", err)
+ }
+ if !ok {
+ conn.Close()
+ continue
+ }
+ return conn, nil
+ }
+}
+
+func (pln *listener) Close() error {
+ return pln.ln.Close()
+}
+
+func (pln *listener) Addr() net.Addr {
+ return pln.ln.Addr()
+}
diff --git a/net/pidlisten/pidlisten_linux.go b/net/pidlisten/pidlisten_linux.go
new file mode 100644
index 000000000..6e4ffbdf2
--- /dev/null
+++ b/net/pidlisten/pidlisten_linux.go
@@ -0,0 +1,63 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package pidlisten
+
+import (
+ "errors"
+ "fmt"
+ "go4.org/mem"
+ "io/fs"
+ "net"
+ "os"
+ "path/filepath"
+ "tailscale.com/util/dirwalk"
+
+ "github.com/vishvananda/netlink"
+)
+
+// NewPIDListener wraps a net.Listener so that it only accepts connections from the current process.
+func NewPIDListener(ln net.Listener) net.Listener {
+ return &listener{ln: ln}
+}
+
+var errFoundSocket = errors.New("found socket")
+
+func checkPIDLocal(conn net.Conn) (bool, error) {
+ remoteAddr := conn.RemoteAddr()
+ var remoteIP net.IP
+ switch remoteAddr.Network() {
+ case "tcp":
+ remoteIP = remoteAddr.(*net.TCPAddr).IP
+ case "udp":
+ remoteIP = remoteAddr.(*net.UDPAddr).IP
+ default:
+ return false, nil
+ }
+ if !remoteIP.IsLoopback() {
+ return false, nil
+ }
+
+ // You can look up a net.Conn in both directions.
+ // There are different inodes for remote->local and local->remote.
+ // We want to look up the starting side of the net.Conn and check
+ // that its inode belongs to the current PID.
+ s, err := netlink.SocketGet(conn.RemoteAddr(), conn.LocalAddr())
+ if err != nil {
+ return false, err
+ }
+
+ want := fmt.Sprintf("socket:[%d]", s.INode)
+ dir := fmt.Sprintf("/proc/%d/fd", os.Getpid())
+ err = dirwalk.WalkShallow(mem.S(dir), func(name mem.RO, de fs.DirEntry) error {
+ n, err := os.Readlink(filepath.Join(dir, name.StringCopy()))
+ if err == nil && want == n {
+ return errFoundSocket
+ }
+ return nil
+ })
+ if err == errFoundSocket {
+ return true, nil
+ }
+ return false, err
+}
diff --git a/net/pidlisten/pidlisten_noimpl.go b/net/pidlisten/pidlisten_noimpl.go
new file mode 100644
index 000000000..395425731
--- /dev/null
+++ b/net/pidlisten/pidlisten_noimpl.go
@@ -0,0 +1,13 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+// +build !linux
+
+package pidlisten
+
+import "net"
+
+func checkPIDLocal(conn net.Conn) (bool, error) {
+ panic("not implemented")
+}
diff --git a/net/pidlisten/pidlisten_test.go b/net/pidlisten/pidlisten_test.go
new file mode 100644
index 000000000..cce02c2b8
--- /dev/null
+++ b/net/pidlisten/pidlisten_test.go
@@ -0,0 +1,122 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+// +build linux
+
+package pidlisten
+
+import (
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "os/exec"
+ "testing"
+ "time"
+)
+
+var flagDial = flag.String("dial", "", "if set, dials the given addr and reads until close")
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+ if *flagDial != "" {
+ conn, err := net.DialTimeout("tcp", *flagDial, 5*time.Second)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+ conn.SetDeadline(time.Now().Add(5 * time.Second))
+ b, err := io.ReadAll(conn)
+ fmt.Fprintf(os.Stderr, "%s", b)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+ os.Exit(0)
+ }
+ os.Exit(m.Run())
+}
+
+func TestPIDLocal(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ clientConn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientConn.Close()
+
+ conn, err := ln.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ok, err := checkPIDLocal(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !ok {
+ t.Errorf("checkPIDLocal=false, want true")
+ }
+}
+
+func testExternalProcess(t *testing.T, ln net.Listener) string {
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ panic(err)
+ }
+ fmt.Fprintf(c, "hello\n")
+ c.Close()
+ }
+ }()
+
+ exe, err := os.Executable()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ out, err := exec.Command(exe, "-dial="+ln.Addr().String()).CombinedOutput()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(out)
+}
+
+func TestExternalDialWorks(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ out := testExternalProcess(t, ln)
+ if out != "hello\n" {
+ t.Errorf("out=%q, want hello", out)
+ }
+}
+
+func TestPIDExternal(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ ln = NewPIDListener(ln)
+ out := testExternalProcess(t, ln)
+
+ if len(out) != 0 {
+ t.Errorf("unexpected socket output: %q", out)
+ }
+}