diff options
Diffstat (limited to 'wgengine/netstack/netstack.go')
| -rw-r--r-- | wgengine/netstack/netstack.go | 338 |
1 files changed, 224 insertions, 114 deletions
diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index b2b21fcba..5906f00bf 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -12,8 +12,8 @@ import ( "context" "errors" "fmt" + "io" "log" - "net" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -37,161 +38,270 @@ import ( "tailscale.com/wgengine/tstun" ) -func Impl(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) error { +// Impl contains the state for the netstack implementation, +// and implements wgengine.FakeImpl to act as a userspace network +// stack when Tailscale is running in fake mode. +type Impl struct { + ipstack *stack.Stack + linkEP *channel.Endpoint + tundev *tstun.TUN + e wgengine.Engine + mc *magicsock.Conn + logf logger.Logf +} + +const nicID = 1 + +// Create creates and populates a new Impl. +func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (wgengine.FakeImpl, error) { if mc == nil { - return errors.New("nil magicsock.Conn") + return nil, errors.New("nil magicsock.Conn") } if tundev == nil { - return errors.New("nil tundev") + return nil, errors.New("nil tundev") } if logf == nil { - return errors.New("nil logger") + return nil, errors.New("nil logger") } if e == nil { - return errors.New("nil Engine") + return nil, errors.New("nil Engine") } ipstack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, }) - const mtu = 1500 linkEP := channel.New(512, mtu, "") + if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { + return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) + } + // Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side + // are handled by the one fake NIC we use. + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) + ipv6Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 16)), tcpip.AddressMask(strings.Repeat("\x00", 16))) + ipstack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + { + Destination: ipv6Subnet, + NIC: nicID, + }, + }) + ns := &Impl{ + logf: logf, + ipstack: ipstack, + linkEP: linkEP, + tundev: tundev, + e: e, + mc: mc, + } + return ns, nil +} + +// Start sets up all the handlers so netstack can start working. Implements +// wgengine.FakeImpl. +func (ns *Impl) Start() error { + ns.e.AddNetworkMapCallback(ns.updateIPs) + // size = 0 means use default buffer size + const tcpReceiveBufferSize = 0 + const maxInFlightConnectionAttempts = 16 + tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP) + udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP) + ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) + go ns.injectOutbound() + ns.tundev.PostFilterIn = ns.injectInbound - const nicID = 1 - if err := ipstack.CreateNIC(nicID, linkEP); err != nil { - log.Fatal(err) + return nil +} + +func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { + oldIPs := make(map[tcpip.Address]bool) + for _, ip := range ns.ipstack.AllAddresses()[nicID] { + oldIPs[ip.AddressWithPrefix.Address] = true + } + newIPs := make(map[tcpip.Address]bool) + for _, ip := range nm.Addresses { + newIPs[tcpip.Address(ip.IP.IPAddr().IP)] = true } - e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) { - oldIPs := make(map[tcpip.Address]bool) - for _, ip := range ipstack.AllAddresses()[nicID] { - oldIPs[ip.AddressWithPrefix.Address] = true + ipsToBeAdded := make(map[tcpip.Address]bool) + for ip := range newIPs { + if !oldIPs[ip] { + ipsToBeAdded[ip] = true } - newIPs := make(map[tcpip.Address]bool) - for _, ip := range nm.Addresses { - newIPs[tcpip.Address(ip.IPNet().IP)] = true + } + ipsToBeRemoved := make(map[tcpip.Address]bool) + for ip := range oldIPs { + if !newIPs[ip] { + ipsToBeRemoved[ip] = true } + } - ipsToBeAdded := make(map[tcpip.Address]bool) - for ip := range newIPs { - if !oldIPs[ip] { - ipsToBeAdded[ip] = true - } + for ip := range ipsToBeRemoved { + err := ns.ipstack.RemoveAddress(nicID, ip) + if err != nil { + ns.logf("netstack: could not deregister IP %s: %v", ip, err) + } else { + ns.logf("[v2] netstack: deregistered IP %s", ip) + } + } + for ip := range ipsToBeAdded { + var err *tcpip.Error + if ip.To4() == "" { + err = ns.ipstack.AddAddress(nicID, ipv6.ProtocolNumber, ip) + } else { + err = ns.ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) } - ipsToBeRemoved := make(map[tcpip.Address]bool) - for ip := range oldIPs { - if !newIPs[ip] { - ipsToBeRemoved[ip] = true - } + if err != nil { + ns.logf("netstack: could not register IP %s: %v", ip, err) + } else { + ns.logf("[v2] netstack: registered IP %s", ip) } + } +} - for ip := range ipsToBeRemoved { - err := ipstack.RemoveAddress(nicID, ip) - if err != nil { - logf("netstack: could not deregister IP %s: %v", ip, err) - } else { - logf("netstack: deregistered IP %s", ip) - } - } - for ip := range ipsToBeAdded { - err := ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) - if err != nil { - logf("netstack: could not register IP %s: %v", ip, err) - } else { - logf("netstack: registered IP %s", ip) - } +func (ns *Impl) dialContextTCP(ctx context.Context, address string) (*gonet.TCPConn, error) { + remoteIPPort, err := netaddr.ParseIPPort(address) + if err != nil { + return nil, fmt.Errorf("could not parse IP:port: %w", err) + } + remoteAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(remoteIPPort.IP.IPAddr().IP), + Port: remoteIPPort.Port, + } + var ipType tcpip.NetworkProtocolNumber + if remoteIPPort.IP.Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + + return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType) +} + +func (ns *Impl) injectOutbound() { + for { + packetInfo, ok := ns.linkEP.ReadContext(context.Background()) + if !ok { + ns.logf("[v2] ReadContext-for-write = ok=false") + continue } - }) + pkt := packetInfo.Pkt + hdrNetwork := pkt.NetworkHeader() + hdrTransport := pkt.TransportHeader() - // Add 0.0.0.0/0 default route. - subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) - ipstack.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: nicID, - }, - }) + full := make([]byte, 0, pkt.Size()) + full = append(full, hdrNetwork.View()...) + full = append(full, hdrTransport.View()...) + full = append(full, pkt.Data.ToView()...) - // use Forwarder to accept any connection from stack - fwd := tcp.NewForwarder(ipstack, 0, 16, func(r *tcp.ForwarderRequest) { - logf("XXX ForwarderRequest: %v", r) - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - r.Complete(true) + ns.logf("[v2] packet Write out: % x", full) + if err := ns.tundev.InjectOutbound(full); err != nil { + log.Printf("netstack inject outbound: %v", err) return } - r.Complete(false) - c := gonet.NewTCPConn(&wq, ep) - // TCP echo - go echo(c, e, mc) - - }) - ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - go func() { - for { - packetInfo, ok := linkEP.ReadContext(context.Background()) - if !ok { - logf("XXX ReadContext-for-write = ok=false") - continue - } - pkt := packetInfo.Pkt - hdrNetwork := pkt.NetworkHeader() - hdrTransport := pkt.TransportHeader() + } +} - full := make([]byte, 0, pkt.Size()) - full = append(full, hdrNetwork.View()...) - full = append(full, hdrTransport.View()...) - full = append(full, pkt.Data.ToView()...) +func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.TUN) filter.Response { + var pn tcpip.NetworkProtocolNumber + switch p.IPVersion { + case 4: + pn = header.IPv4ProtocolNumber + case 6: + pn = header.IPv6ProtocolNumber + } + ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer()) + vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView() + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ns.linkEP.InjectInbound(pn, packetBuf) + return filter.Accept +} - logf("XXX packet Write out: % x", full) - if err := tundev.InjectOutbound(full); err != nil { - log.Printf("netstack inject outbound: %v", err) - return - } +func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { + ns.logf("[v2] ForwarderRequest: %v", r) + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + r.Complete(true) + return + } + localAddr, err := ep.GetLocalAddress() + ns.logf("[v2] forwarding port %v to 100.101.102.103:80", localAddr.Port) + if err != nil { + r.Complete(true) + return + } + r.Complete(false) + c := gonet.NewTCPConn(&wq, ep) + go ns.forwardTCP(c, &wq, "100.101.102.103:80") +} +func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address string) { + defer client.Close() + ns.logf("[v2] netstack: forwarding to address %s", address) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + done := make(chan bool) + // netstack doesn't close the notification channel automatically if there was no + // hup signal, so we close done after we're done to not leak the goroutine below. + defer close(done) + go func() { + select { + case <-notifyCh: + case <-done: } + cancel() }() - - tundev.PostFilterIn = func(p *packet.Parsed, t *tstun.TUN) filter.Response { - var pn tcpip.NetworkProtocolNumber - switch p.IPVersion { - case 4: - pn = header.IPv4ProtocolNumber - case 6: - pn = header.IPv6ProtocolNumber - } - logf("XXX packet in (from %v): % x", p.Src, p.Buffer()) - vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView() - packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - linkEP.InjectInbound(pn, packetBuf) - return filter.Accept + server, err := ns.dialContextTCP(ctx, address) + if err != nil { + ns.logf("netstack: could not connect to server %s: %s", address, err) + return } - return nil + defer server.Close() + connClosed := make(chan bool, 2) + go func() { + io.Copy(server, client) + connClosed <- true + }() + go func() { + io.Copy(client, server) + connClosed <- true + }() + <-connClosed + ns.logf("[v2] netstack: forwarder connection to %s closed", address) } -func echo(c *gonet.TCPConn, e wgengine.Engine, mc *magicsock.Conn) { - defer c.Close() - src, _ := netaddr.FromStdIP(c.RemoteAddr().(*net.TCPAddr).IP) - who := "" - if n, u, ok := mc.WhoIs(src); ok { - who = fmt.Sprintf("%v from %v", u.DisplayName, n.Name) +func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { + ns.logf("[v2] UDP ForwarderRequest: %v", r) + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + ns.logf("Could not create endpoint, exiting") + return } - fmt.Fprintf(c, "Hello, %s! Thanks for connecting to me on port %v (Try other ports too!)\nEchoing...\n", - who, - c.LocalAddr().(*net.TCPAddr).Port) + c := gonet.NewUDPConn(ns.ipstack, &wq, ep) + go echoUDP(c) +} + +func echoUDP(c *gonet.UDPConn) { buf := make([]byte, 1500) for { n, err := c.Read(buf) if err != nil { - log.Printf("Err: %v", err) break } c.Write(buf[:n]) } - log.Print("Connection closed") + c.Close() } |
