summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tsnet/tsnet.go22
-rw-r--r--tsnet/tsnet_test.go22
2 files changed, 38 insertions, 6 deletions
diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go
index a4a4bda04..f1ad0180e 100644
--- a/tsnet/tsnet.go
+++ b/tsnet/tsnet.go
@@ -349,12 +349,17 @@ func (s *Server) Close() error {
s.loopbackListener.Close()
}
+ var lns []*listener
+
s.mu.Lock()
- defer s.mu.Unlock()
for _, ln := range s.listeners {
+ lns = append(lns, ln)
+ }
+ s.mu.Unlock()
+
+ for _, ln := range lns {
ln.Close()
}
- s.listeners = nil
wg.Wait()
return nil
@@ -997,10 +1002,11 @@ type listenKey struct {
}
type listener struct {
- s *Server
- keys []listenKey
- addr string
- conn chan net.Conn
+ s *Server
+ keys []listenKey
+ addr string
+ conn chan net.Conn
+ closed bool // guarded by s.mu
}
func (ln *listener) Accept() (net.Conn, error) {
@@ -1015,12 +1021,16 @@ func (ln *listener) Addr() net.Addr { return addr{ln} }
func (ln *listener) Close() error {
ln.s.mu.Lock()
defer ln.s.mu.Unlock()
+ if ln.closed {
+ return fmt.Errorf("tsnet: %w", net.ErrClosed)
+ }
for _, key := range ln.keys {
if v, ok := ln.s.listeners[key]; ok && v == ln {
delete(ln.s.listeners, key)
}
}
close(ln.conn)
+ ln.closed = true
return nil
}
diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go
index ab55b7b60..0ab75406b 100644
--- a/tsnet/tsnet_test.go
+++ b/tsnet/tsnet_test.go
@@ -9,6 +9,7 @@ import (
"flag"
"fmt"
"io"
+ "net"
"net/http"
"net/http/httptest"
"net/netip"
@@ -344,3 +345,24 @@ func TestTailscaleIPs(t *testing.T) {
sIp4, upIp4, sIp6, upIp6)
}
}
+
+func TestListenerCleanup(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ controlURL := startControl(t)
+ s1, _ := startServer(t, ctx, controlURL, "s1")
+
+ ln, err := s1.Listen("tcp", ":8081")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := s1.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
+ t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
+ }
+}