diff options
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/tswrap/main.go | 50 |
1 files changed, 19 insertions, 31 deletions
diff --git a/cmd/tswrap/main.go b/cmd/tswrap/main.go index 559d56b03..e6607ac0b 100644 --- a/cmd/tswrap/main.go +++ b/cmd/tswrap/main.go @@ -2,15 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build linux - // The tswrap binary runs a child process and makes it accessible over // Tailscale. package main import ( - "bufio" - "bytes" "context" "errors" "flag" @@ -18,19 +14,18 @@ import ( "io" "log" "net" - "net/netip" "os" "os/exec" "os/signal" "sort" "strconv" - "strings" + "syscall" "time" - "golang.org/x/sys/unix" "tailscale.com/client/tailscale" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" + "tailscale.com/portlist" "tailscale.com/syncs" "tailscale.com/tsnet" "tailscale.com/types/logger" @@ -43,7 +38,7 @@ var ( func main() { sigch := make(chan os.Signal, 1) - signal.Notify(sigch, os.Interrupt, unix.SIGTERM) + signal.Notify(sigch, os.Interrupt, syscall.SIGTERM) flag.Parse() @@ -188,7 +183,7 @@ func (p *proxy) Stop() { func (p *proxy) Wait() { <-p.shutdownCtx.Done() - p.cmd.Process.Signal(unix.SIGTERM) + p.cmd.Process.Signal(syscall.SIGTERM) p.ln.Close() if p.srv.Ephemeral { p.client.Logout(context.Background()) @@ -290,37 +285,30 @@ func portsOfCmd(cmd *exec.Cmd) (ports []int, err error) { return nil, errors.New("no process") } pid := cmd.Process.Pid - wantSub := fmt.Sprintf(" %d/", pid) - ns := exec.Command("netstat", "-p", "--inet", "-l", "-n") - out, err := ns.Output() + poller, err := portlist.NewPoller() if err != nil { - return nil, err + return nil, fmt.Errorf("creating port poller: %w", err) } - bs := bufio.NewScanner(bytes.NewReader(out)) - for bs.Scan() { - line := bs.Text() - if !strings.HasPrefix(line, "tcp") || - !strings.Contains(line, "LISTEN") || - !strings.Contains(line, wantSub) { - continue + defer poller.Close() + // TODO(raggi): timeout? + go poller.Run(context.Background()) + + c := poller.Updates() + for list := range c { + for _, p := range list { + if p.Pid == pid { + ports = append(ports, int(p.Port)) + } } - f := strings.Fields(line) - if len(f) < 4 { - continue + if len(ports) > 0 { + break } - ipp, err := netip.ParseAddrPort(f[3]) - if err == nil { - ports = append(ports, int(ipp.Port())) - continue - } - } - if err := bs.Err(); err != nil { - return nil, err } if len(ports) == 0 { return nil, errors.New("no listening ports found") } + sort.Ints(ports) return ports, nil } |
