summaryrefslogtreecommitdiffhomepage
path: root/cmd/tswrap/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/tswrap/main.go')
-rw-r--r--cmd/tswrap/main.go50
1 files changed, 19 insertions, 31 deletions
diff --git a/cmd/tswrap/main.go b/cmd/tswrap/main.go
index 559d56b03..e6607ac0b 100644
--- a/cmd/tswrap/main.go
+++ b/cmd/tswrap/main.go
@@ -2,15 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build linux
-
// The tswrap binary runs a child process and makes it accessible over
// Tailscale.
package main
import (
- "bufio"
- "bytes"
"context"
"errors"
"flag"
@@ -18,19 +14,18 @@ import (
"io"
"log"
"net"
- "net/netip"
"os"
"os/exec"
"os/signal"
"sort"
"strconv"
- "strings"
+ "syscall"
"time"
- "golang.org/x/sys/unix"
"tailscale.com/client/tailscale"
"tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/store/mem"
+ "tailscale.com/portlist"
"tailscale.com/syncs"
"tailscale.com/tsnet"
"tailscale.com/types/logger"
@@ -43,7 +38,7 @@ var (
func main() {
sigch := make(chan os.Signal, 1)
- signal.Notify(sigch, os.Interrupt, unix.SIGTERM)
+ signal.Notify(sigch, os.Interrupt, syscall.SIGTERM)
flag.Parse()
@@ -188,7 +183,7 @@ func (p *proxy) Stop() {
func (p *proxy) Wait() {
<-p.shutdownCtx.Done()
- p.cmd.Process.Signal(unix.SIGTERM)
+ p.cmd.Process.Signal(syscall.SIGTERM)
p.ln.Close()
if p.srv.Ephemeral {
p.client.Logout(context.Background())
@@ -290,37 +285,30 @@ func portsOfCmd(cmd *exec.Cmd) (ports []int, err error) {
return nil, errors.New("no process")
}
pid := cmd.Process.Pid
- wantSub := fmt.Sprintf(" %d/", pid)
- ns := exec.Command("netstat", "-p", "--inet", "-l", "-n")
- out, err := ns.Output()
+ poller, err := portlist.NewPoller()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("creating port poller: %w", err)
}
- bs := bufio.NewScanner(bytes.NewReader(out))
- for bs.Scan() {
- line := bs.Text()
- if !strings.HasPrefix(line, "tcp") ||
- !strings.Contains(line, "LISTEN") ||
- !strings.Contains(line, wantSub) {
- continue
+ defer poller.Close()
+ // TODO(raggi): timeout?
+ go poller.Run(context.Background())
+
+ c := poller.Updates()
+ for list := range c {
+ for _, p := range list {
+ if p.Pid == pid {
+ ports = append(ports, int(p.Port))
+ }
}
- f := strings.Fields(line)
- if len(f) < 4 {
- continue
+ if len(ports) > 0 {
+ break
}
- ipp, err := netip.ParseAddrPort(f[3])
- if err == nil {
- ports = append(ports, int(ipp.Port()))
- continue
- }
- }
- if err := bs.Err(); err != nil {
- return nil, err
}
if len(ports) == 0 {
return nil, errors.New("no listening ports found")
}
+
sort.Ints(ports)
return ports, nil
}