summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorVimT <me@vimt.me>2024-06-25 23:40:42 +0800
committerBrad Fitzpatrick <brad@danga.com>2024-08-05 09:25:24 -0700
commite3f047618bd09f352e755bebcbbeb7db06734679 (patch)
treee7a84f82d7bc456ed1fc4ac7485c2110be345bea
parent91d2e1772ddc359a1c015d8d63cfe19f813b34ef (diff)
downloadtailscale-e3f047618bd09f352e755bebcbbeb7db06734679.tar.xz
tailscale-e3f047618bd09f352e755bebcbbeb7db06734679.zip
net/socks5: support UDP
Updates #7581 Signed-off-by: VimT <me@vimt.me>
-rw-r--r--net/socks5/socks5.go452
-rw-r--r--net/socks5/socks5_test.go113
2 files changed, 484 insertions, 81 deletions
diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go
index b774ebe24..0d651537f 100644
--- a/net/socks5/socks5.go
+++ b/net/socks5/socks5.go
@@ -13,8 +13,10 @@
package socks5
import (
+ "bytes"
"context"
"encoding/binary"
+ "errors"
"fmt"
"io"
"log"
@@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
}
go func() {
defer c.Close()
- conn := &Conn{clientConn: c, srv: s}
+ conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
err := conn.Run()
if err != nil {
s.logf("client connection failed: %v", err)
@@ -136,9 +138,12 @@ type Conn struct {
// The struct is filled by each of the internal
// methods in turn as the transaction progresses.
+ logf logger.Logf
srv *Server
clientConn net.Conn
request *request
+
+ udpClientAddr net.Addr
}
// Run starts the new connection.
@@ -172,58 +177,59 @@ func (c *Conn) Run() error {
func (c *Conn) handleRequest() error {
req, err := parseClientRequest(c.clientConn)
if err != nil {
- res := &response{reply: generalFailure}
+ res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
- if req.command != connect {
- res := &response{reply: commandNotSupported}
+
+ c.request = req
+ switch req.command {
+ case connect:
+ return c.handleTCP()
+ case udpAssociate:
+ return c.handleUDP()
+ default:
+ res := errorResponse(commandNotSupported)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return fmt.Errorf("unsupported command %v", req.command)
}
- c.request = req
+}
+func (c *Conn) handleTCP() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
srv, err := c.srv.dial(
ctx,
"tcp",
- net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
+ c.request.destination.hostPort(),
)
if err != nil {
- res := &response{reply: generalFailure}
+ res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer srv.Close()
- serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
+
+ localAddr := srv.LocalAddr().String()
+ serverAddr, serverPort, err := splitHostPort(localAddr)
if err != nil {
return err
}
- serverPort, _ := strconv.Atoi(serverPortStr)
- var bindAddrType addrType
- if ip := net.ParseIP(serverAddr); ip != nil {
- if ip.To4() != nil {
- bindAddrType = ipv4
- } else {
- bindAddrType = ipv6
- }
- } else {
- bindAddrType = domainName
- }
res := &response{
- reply: success,
- bindAddrType: bindAddrType,
- bindAddr: serverAddr,
- bindPort: uint16(serverPort),
+ reply: success,
+ bindAddr: socksAddr{
+ addrType: getAddrType(serverAddr),
+ addr: serverAddr,
+ port: serverPort,
+ },
}
buf, err := res.marshal()
if err != nil {
- res = &response{reply: generalFailure}
+ res = errorResponse(generalFailure)
buf, _ = res.marshal()
}
c.clientConn.Write(buf)
@@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
return <-errc
}
+func (c *Conn) handleUDP() error {
+ // The DST.ADDR and DST.PORT fields contain the address and port that
+ // the client expects to use to send UDP datagrams on for the
+ // association. The server MAY use this information to limit access
+ // to the association.
+ // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
+ //
+ // We do NOT limit the access from the client currently in this implementation.
+ _ = c.request.destination
+
+ addr := c.clientConn.LocalAddr()
+ host, _, err := net.SplitHostPort(addr.String())
+ if err != nil {
+ return err
+ }
+ clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
+ if err != nil {
+ res := errorResponse(generalFailure)
+ buf, _ := res.marshal()
+ c.clientConn.Write(buf)
+ return err
+ }
+ defer clientUDPConn.Close()
+
+ serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
+ if err != nil {
+ res := errorResponse(generalFailure)
+ buf, _ := res.marshal()
+ c.clientConn.Write(buf)
+ return err
+ }
+ defer serverUDPConn.Close()
+
+ bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
+ if err != nil {
+ return err
+ }
+
+ res := &response{
+ reply: success,
+ bindAddr: socksAddr{
+ addrType: getAddrType(bindAddr),
+ addr: bindAddr,
+ port: bindPort,
+ },
+ }
+ buf, err := res.marshal()
+ if err != nil {
+ res = errorResponse(generalFailure)
+ buf, _ = res.marshal()
+ }
+ c.clientConn.Write(buf)
+
+ return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
+}
+
+func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ const bufferSize = 8 * 1024
+ const readTimeout = 5 * time.Second
+
+ // client -> target
+ go func() {
+ defer cancel()
+ buf := make([]byte, bufferSize)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
+ if err != nil {
+ if isTimeout(err) {
+ continue
+ }
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ c.logf("udp transfer: handle udp request fail: %v", err)
+ }
+ }
+ }
+ }()
+
+ // target -> client
+ go func() {
+ defer cancel()
+ buf := make([]byte, bufferSize)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
+ if err != nil {
+ if isTimeout(err) {
+ continue
+ }
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ c.logf("udp transfer: handle udp response fail: %v", err)
+ }
+ }
+ }
+ }()
+
+ // A UDP association terminates when the TCP connection that the UDP
+ // ASSOCIATE request arrived on terminates. RFC1928
+ _, err := io.Copy(io.Discard, associatedTCP)
+ if err != nil {
+ err = fmt.Errorf("udp associated tcp conn: %w", err)
+ }
+ return err
+}
+
+func (c *Conn) handleUDPRequest(
+ clientConn net.PacketConn,
+ targetConn net.PacketConn,
+ buf []byte,
+ readTimeout time.Duration,
+) error {
+ // add a deadline for the read to avoid blocking forever
+ _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
+ n, addr, err := clientConn.ReadFrom(buf)
+ if err != nil {
+ return fmt.Errorf("read from client: %w", err)
+ }
+ c.udpClientAddr = addr
+ req, data, err := parseUDPRequest(buf[:n])
+ if err != nil {
+ return fmt.Errorf("parse udp request: %w", err)
+ }
+ targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
+ if err != nil {
+ c.logf("resolve target addr fail: %v", err)
+ }
+
+ nn, err := targetConn.WriteTo(data, targetAddr)
+ if err != nil {
+ return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
+ }
+ if nn != len(data) {
+ return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
+ }
+ return nil
+}
+
+func (c *Conn) handleUDPResponse(
+ targetConn net.PacketConn,
+ clientConn net.PacketConn,
+ buf []byte,
+ readTimeout time.Duration,
+) error {
+ // add a deadline for the read to avoid blocking forever
+ _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
+ n, addr, err := targetConn.ReadFrom(buf)
+ if err != nil {
+ return fmt.Errorf("read from target: %w", err)
+ }
+ host, port, err := splitHostPort(addr.String())
+ if err != nil {
+ return fmt.Errorf("split host port: %w", err)
+ }
+ hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
+ pkt, err := hdr.marshal()
+ if err != nil {
+ return fmt.Errorf("marshal udp request: %w", err)
+ }
+ data := append(pkt, buf[:n]...)
+ // use addr from client to send back
+ nn, err := clientConn.WriteTo(data, c.udpClientAddr)
+ if err != nil {
+ return fmt.Errorf("write to client: %w", err)
+ }
+ if nn != len(data) {
+ return fmt.Errorf("write to client: %w", io.ErrShortWrite)
+ }
+ return nil
+}
+
+func isTimeout(err error) bool {
+ terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
+ return ok && terr.Timeout()
+}
+
+func splitHostPort(hostport string) (host string, port uint16, err error) {
+ host, portStr, err := net.SplitHostPort(hostport)
+ if err != nil {
+ return "", 0, err
+ }
+ portInt, err := strconv.Atoi(portStr)
+ if err != nil {
+ return "", 0, err
+ }
+ if portInt < 0 || portInt > 65535 {
+ return "", 0, fmt.Errorf("invalid port number %d", portInt)
+ }
+ return host, uint16(portInt), nil
+}
+
// parseClientGreeting parses a request initiation packet.
func parseClientGreeting(r io.Reader, authMethod byte) error {
var hdr [2]byte
@@ -295,123 +503,205 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
return string(usrBytes), string(pwdBytes), nil
}
+func getAddrType(addr string) addrType {
+ if ip := net.ParseIP(addr); ip != nil {
+ if ip.To4() != nil {
+ return ipv4
+ }
+ return ipv6
+ }
+ return domainName
+}
+
// request represents data contained within a SOCKS5
// connection request packet.
type request struct {
- command commandType
- destination string
- port uint16
- destAddrType addrType
+ command commandType
+ destination socksAddr
}
// parseClientRequest converts raw packet bytes into a
// SOCKS5Request struct.
func parseClientRequest(r io.Reader) (*request, error) {
- var hdr [4]byte
+ var hdr [3]byte
_, err := io.ReadFull(r, hdr[:])
if err != nil {
return nil, fmt.Errorf("could not read packet header")
}
cmd := hdr[1]
- destAddrType := addrType(hdr[3])
- var destination string
- var port uint16
+ destination, err := parseSocksAddr(r)
+ return &request{
+ command: commandType(cmd),
+ destination: destination,
+ }, err
+}
+
+type socksAddr struct {
+ addrType addrType
+ addr string
+ port uint16
+}
+
+var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
+
+func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
+ var addrTypeData [1]byte
+ _, err = io.ReadFull(r, addrTypeData[:])
+ if err != nil {
+ return socksAddr{}, fmt.Errorf("could not read address type")
+ }
- if destAddrType == ipv4 {
+ dstAddrType := addrType(addrTypeData[0])
+ var destination string
+ switch dstAddrType {
+ case ipv4:
var ip [4]byte
_, err = io.ReadFull(r, ip[:])
if err != nil {
- return nil, fmt.Errorf("could not read IPv4 address")
+ return socksAddr{}, fmt.Errorf("could not read IPv4 address")
}
destination = net.IP(ip[:]).String()
- } else if destAddrType == domainName {
+ case domainName:
var dstSizeByte [1]byte
_, err = io.ReadFull(r, dstSizeByte[:])
if err != nil {
- return nil, fmt.Errorf("could not read domain name size")
+ return socksAddr{}, fmt.Errorf("could not read domain name size")
}
dstSize := int(dstSizeByte[0])
domainName := make([]byte, dstSize)
_, err = io.ReadFull(r, domainName)
if err != nil {
- return nil, fmt.Errorf("could not read domain name")
+ return socksAddr{}, fmt.Errorf("could not read domain name")
}
destination = string(domainName)
- } else if destAddrType == ipv6 {
+ case ipv6:
var ip [16]byte
_, err = io.ReadFull(r, ip[:])
if err != nil {
- return nil, fmt.Errorf("could not read IPv6 address")
+ return socksAddr{}, fmt.Errorf("could not read IPv6 address")
}
destination = net.IP(ip[:]).String()
- } else {
- return nil, fmt.Errorf("unsupported address type")
+ default:
+ return socksAddr{}, fmt.Errorf("unsupported address type")
}
var portBytes [2]byte
_, err = io.ReadFull(r, portBytes[:])
if err != nil {
- return nil, fmt.Errorf("could not read port")
+ return socksAddr{}, fmt.Errorf("could not read port")
}
- port = binary.BigEndian.Uint16(portBytes[:])
-
- return &request{
- command: commandType(cmd),
- destination: destination,
- port: port,
- destAddrType: destAddrType,
+ port := binary.BigEndian.Uint16(portBytes[:])
+ return socksAddr{
+ addrType: dstAddrType,
+ addr: destination,
+ port: port,
}, nil
}
+func (s socksAddr) marshal() ([]byte, error) {
+ var addr []byte
+ switch s.addrType {
+ case ipv4:
+ addr = net.ParseIP(s.addr).To4()
+ if addr == nil {
+ return nil, fmt.Errorf("invalid IPv4 address for binding")
+ }
+ case domainName:
+ if len(s.addr) > 255 {
+ return nil, fmt.Errorf("invalid domain name for binding")
+ }
+ addr = make([]byte, 0, len(s.addr)+1)
+ addr = append(addr, byte(len(s.addr)))
+ addr = append(addr, []byte(s.addr)...)
+ case ipv6:
+ addr = net.ParseIP(s.addr).To16()
+ if addr == nil {
+ return nil, fmt.Errorf("invalid IPv6 address for binding")
+ }
+ default:
+ return nil, fmt.Errorf("unsupported address type")
+ }
+
+ pkt := []byte{byte(s.addrType)}
+ pkt = append(pkt, addr...)
+ pkt = binary.BigEndian.AppendUint16(pkt, s.port)
+ return pkt, nil
+}
+func (s socksAddr) hostPort() string {
+ return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
+}
+
// response contains the contents of
// a response packet sent from the proxy
// to the client.
type response struct {
- reply replyCode
- bindAddrType addrType
- bindAddr string
- bindPort uint16
+ reply replyCode
+ bindAddr socksAddr
+}
+
+func errorResponse(code replyCode) *response {
+ return &response{reply: code, bindAddr: zeroSocksAddr}
}
// marshal converts a SOCKS5Response struct into
// a packet. If res.reply == Success, it may throw an error on
// receiving an invalid bind address. Otherwise, it will not throw.
func (res *response) marshal() ([]byte, error) {
- pkt := make([]byte, 4)
+ pkt := make([]byte, 3)
pkt[0] = socks5Version
pkt[1] = byte(res.reply)
pkt[2] = 0 // null reserved byte
- pkt[3] = byte(res.bindAddrType)
- if res.reply != success {
- return pkt, nil
+ addrPkt, err := res.bindAddr.marshal()
+ if err != nil {
+ return nil, err
}
- var addr []byte
- switch res.bindAddrType {
- case ipv4:
- addr = net.ParseIP(res.bindAddr).To4()
- if addr == nil {
- return nil, fmt.Errorf("invalid IPv4 address for binding")
- }
- case domainName:
- if len(res.bindAddr) > 255 {
- return nil, fmt.Errorf("invalid domain name for binding")
- }
- addr = make([]byte, 0, len(res.bindAddr)+1)
- addr = append(addr, byte(len(res.bindAddr)))
- addr = append(addr, []byte(res.bindAddr)...)
- case ipv6:
- addr = net.ParseIP(res.bindAddr).To16()
- if addr == nil {
- return nil, fmt.Errorf("invalid IPv6 address for binding")
- }
- default:
- return nil, fmt.Errorf("unsupported address type")
+ return append(pkt, addrPkt...), nil
+}
+
+type udpRequest struct {
+ frag byte
+ addr socksAddr
+}
+
+// +----+------+------+----------+----------+----------+
+// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
+// +----+------+------+----------+----------+----------+
+// | 2 | 1 | 1 | Variable | 2 | Variable |
+// +----+------+------+----------+----------+----------+
+func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
+ if len(data) < 4 {
+ return nil, nil, fmt.Errorf("invalid packet length")
}
- pkt = append(pkt, addr...)
- pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
+ // reserved bytes
+ if !(data[0] == 0 && data[1] == 0) {
+ return nil, nil, fmt.Errorf("invalid udp request header")
+ }
- return pkt, nil
+ frag := data[2]
+
+ reader := bytes.NewReader(data[3:])
+ addr, err := parseSocksAddr(reader)
+ bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
+ body := data[len(data)-bodyLen:]
+ return &udpRequest{
+ frag: frag,
+ addr: addr,
+ }, body, err
+}
+
+func (u *udpRequest) marshal() ([]byte, error) {
+ pkt := make([]byte, 3)
+ pkt[0] = 0
+ pkt[1] = 0
+ pkt[2] = u.frag
+
+ addrPkt, err := u.addr.marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ return append(pkt, addrPkt...), nil
}
diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go
index 201a66575..11ea59d4b 100644
--- a/net/socks5/socks5_test.go
+++ b/net/socks5/socks5_test.go
@@ -4,6 +4,7 @@
package socks5
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -32,6 +33,19 @@ func backendServer(listener net.Listener) {
listener.Close()
}
+func udpEchoServer(conn net.PacketConn) {
+ var buf [1024]byte
+ n, addr, err := conn.ReadFrom(buf[:])
+ if err != nil {
+ panic(err)
+ }
+ _, err = conn.WriteTo(buf[:n], addr)
+ if err != nil {
+ panic(err)
+ }
+ conn.Close()
+}
+
func TestRead(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
listener, err := net.Listen("tcp", ":0")
@@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestUDP(t *testing.T) {
+ // backend UDP server which we'll use SOCKS5 to connect to
+ listener, err := net.ListenPacket("udp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
+ go udpEchoServer(listener)
+
+ // SOCKS5 server
+ socks5, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ socks5Port := socks5.Addr().(*net.TCPAddr).Port
+ go socks5Server(socks5)
+
+ // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
+ conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, 1024)
+ n, err := conn.Read(buf) // server hello
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
+ t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
+ }
+
+ targetAddr := socksAddr{
+ addrType: domainName,
+ addr: "localhost",
+ port: uint16(backendServerPort),
+ }
+ targetAddrPkt, err := targetAddr.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ n, err = conn.Read(buf) // server response
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
+ t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+ }
+ udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
+ if err != nil {
+ t.Fatal(err)
+ }
+ udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ udpPayload = append(udpPayload, []byte("Test")...)
+ _, err = udpConn.Write(udpPayload) // send udp package
+ if err != nil {
+ t.Fatal(err)
+ }
+ n, _, err = udpConn.ReadFrom(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(responseBody) != "Test" {
+ t.Fatalf("got: %q want: Test", responseBody)
+ }
+ err = udpConn.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = conn.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+}