diff options
| author | Jordan Whited <jordan@tailscale.com> | 2024-07-11 15:34:32 -0700 |
|---|---|---|
| committer | Jordan Whited <jordan@tailscale.com> | 2024-07-18 14:38:39 -0700 |
| commit | b1ab0264587ed9ba16e62d33c0f1b219d73fd690 (patch) | |
| tree | 8ddc8d894841c0abb63174f7d1fef7f5d2ddcb42 | |
| parent | 2742153f84d00a665a8f9a14e5fd42d38aca9379 (diff) | |
| download | tailscale-jwhited/gVisor-gso-gro.tar.xz tailscale-jwhited/gVisor-gso-gro.zip | |
net/tstun,wgengine/netstack: GSO from gVisor experiment (branch: jwhited/gVisor-gso-gro)
Signed-off-by: Jordan Whited <jordan@tailscale.com>
| -rw-r--r-- | go.mod | 2 | ||||
| -rw-r--r-- | go.sum | 4 | ||||
| -rw-r--r-- | net/tstun/wrap.go | 83 | ||||
| -rw-r--r-- | wgengine/netstack/link_endpoint.go | 287 | ||||
| -rw-r--r-- | wgengine/netstack/netstack.go | 25 |
5 files changed, 372 insertions, 29 deletions
@@ -104,7 +104,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 gopkg.in/square/go-jose.v2 v2.6.0 - gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 + gvisor.dev/gvisor v0.0.0-20240713103206-39d6c232e61d honnef.co/go/tools v0.4.6 k8s.io/api v0.30.1 k8s.io/apimachinery v0.30.1 @@ -1474,8 +1474,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= -gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= -gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= +gvisor.dev/gvisor v0.0.0-20240713103206-39d6c232e61d h1:dFTIljP/5ReqgM7nMR4DauApFatUaSP8r9btX0sd8a8= +gvisor.dev/gvisor v0.0.0-20240713103206-39d6c232e61d/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 8ea73b4b2..382d9b386 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -20,6 +20,7 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "go4.org/mem" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/disco" "tailscale.com/net/connstats" @@ -104,6 +105,7 @@ type Wrapper struct { // peerConfig stores the current NAT configuration. 
peerConfig atomic.Pointer[peerConfigTable] + buf []byte // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is // allocated in wrap() and the underlying arrays should never grow. vectorBuffer [][]byte @@ -159,7 +161,8 @@ type Wrapper struct { // and therefore sees the packets that may be later dropped by it. PreFilterPacketInboundFromWireGuard FilterFunc // PostFilterPacketInboundFromWireGuard is the inbound filter function that runs after the main filter. - PostFilterPacketInboundFromWireGuard FilterFunc + PostFilterPacketInboundFromWireGuard FilterFunc + PostFilterPacketInboundFromWireGuardFlush func() // PreFilterPacketOutboundToWireGuardNetstackIntercept is a filter function that runs before the main filter // for packets from the local system. This filter is populated by netstack to hook // packets that should be handled by netstack. If set, this filter runs before @@ -262,6 +265,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { startCh: make(chan struct{}), } + w.buf = make([]byte, 65535) w.vectorBuffer = make([][]byte, tdev.BatchSize()) for i := range w.vectorBuffer { w.vectorBuffer[i] = make([]byte, maxBufferSize) @@ -894,13 +898,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { return 0, res.err } if res.data == nil { - n, err := t.injectedRead(res.injected, buffs[0], offset) - sizes[0] = n - if err != nil && n == 0 { - return 0, err - } - - return 1, err + return t.injectedRead(t.buf, res.injected, buffs, sizes, offset) } metricPacketOut.Add(int64(len(res.data))) @@ -956,25 +954,26 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { } // injectedRead handles injected reads, which bypass filters. 
-func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int, error) { - metricPacketOut.Add(1) - - var n int - if !res.packet.IsNil() { - - n = copy(buf[offset:], res.packet.NetworkHeader().Slice()) - n += copy(buf[offset+n:], res.packet.TransportHeader().Slice()) - n += copy(buf[offset+n:], res.packet.Data().AsRange().ToSlice()) +func (t *Wrapper) injectedRead(buf []byte, res tunInjectedRead, outBuffs [][]byte, sizes []int, offset int) (n int, err error) { + var ( + buffN int + gso stack.GSO + ) + if res.packet != nil { + buffN = copy(buf, res.packet.NetworkHeader().Slice()) + buffN += copy(buf[buffN:], res.packet.TransportHeader().Slice()) + buffN += copy(buf[buffN:], res.packet.Data().AsRange().ToSlice()) + gso = res.packet.GSOOptions res.packet.DecRef() } else { - n = copy(buf[offset:], res.data) + buffN = copy(buf, res.data) } pc := t.peerConfig.Load() p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) - p.Decode(buf[offset : offset+n]) + p.Decode(buf[:buffN]) pc.snat(p) if m := t.destIPActivity.Load(); m != nil { @@ -984,10 +983,51 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int } if stats := t.stats.Load(); stats != nil { - stats.UpdateTxVirtual(buf[offset:][:n]) + stats.UpdateTxVirtual(buf[:buffN]) } + + // gVisor can pass us gso.Type=stack.GSOTCPv{4,6} and gso.NeedsCsum=true for + // a TCP segment that is too small to split. This varies from Linux virtio + // where we get the equivalent of stack.GSONone if it's too small to split. + // So, we have to check size before falling into GSO logic, otherwise + // tun.GSOSplit() will clear checksum(s) and return early, resulting in a + // packet being fed up to wireguard-go with invalid checksums. 
+ // TODO(jwhited): bounds checks and consider res.data was non-nil + if gso.Type == stack.GSONone || buffN-int(gso.L3HdrLen) <= int(gso.MSS) { + if gso.NeedsCsum { + err = tun.GSONoneChecksum(buf[:buffN], gso.L3HdrLen, gso.CsumOffset) + } + n = 1 + sizes[0] = buffN + copy(outBuffs[0][offset:], buf[:buffN]) + } else if gso.Type == stack.GSOTCPv4 || gso.Type == stack.GSOTCPv6 { + tcphLen := uint16((buf[gso.L3HdrLen+12] >> 4) * 4) // TODO(jwhited): bounds checks + hdr := tun.VirtioNetHdr{ + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: gso.L3HdrLen + tcphLen, + GSOSize: gso.MSS, + CsumStart: gso.L3HdrLen, + CsumOffset: gso.CsumOffset, + } + if gso.Type == stack.GSOTCPv6 { + hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + } + // TODO(jwhited): tun.GSOSplit() is an unmodified export of + // tun.gsoSplit(). This will need to be refactored into its own + // package. Its eventual API should not require virtio_net_hdr, but + // something more intermediary/generic. Its 'in' arg is assumed to be + // non-overlapping with 'outBuffs', but it would be more performant if + // we could just assign/copy into outBuffs[0] for 'in' for this use + // case, instead. + n, err = tun.GSOSplit(buf[:buffN], hdr, outBuffs, sizes, offset, gso.Type == stack.GSOTCPv6) + } else { + // TODO(jwhited): unexpected + panic("unexpected") + } + t.noteActivity() - return n, nil + metricPacketOut.Add(int64(n)) + return n, err } func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response { @@ -1112,6 +1152,7 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { } } } + t.PostFilterPacketInboundFromWireGuardFlush() if t.disableFilter { i = len(buffs) } diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go new file mode 100644 index 000000000..df67aecb1 --- /dev/null +++ b/wgengine/netstack/link_endpoint.go @@ -0,0 +1,287 @@ +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "context" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/stack/gro" +) + +type queue struct { + // c is the outbound packet channel. + c chan *stack.PacketBuffer + mu sync.RWMutex + // +checklocks:mu + closed bool +} + +func (q *queue) Close() { + q.mu.Lock() + defer q.mu.Unlock() + if !q.closed { + close(q.c) + } + q.closed = true +} + +func (q *queue) Read() *stack.PacketBuffer { + select { + case p := <-q.c: + return p + default: + return nil + } +} + +func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer { + select { + case pkt := <-q.c: + return pkt + case <-ctx.Done(): + return nil + } +} + +func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error { + // q holds the PacketBuffer. + q.mu.RLock() + if q.closed { + q.mu.RUnlock() + return &tcpip.ErrClosedForSend{} + } + + wrote := false + select { + case q.c <- pkt.IncRef(): + wrote = true + default: + pkt.DecRef() + } + q.mu.RUnlock() + + if wrote { + return nil + } + return &tcpip.ErrNoBufferSpace{} +} + +func (q *queue) Num() int { + return len(q.c) +} + +var _ stack.LinkEndpoint = (*linkEndpoint)(nil) +var _ stack.GSOEndpoint = (*linkEndpoint)(nil) + +// linkEndpoint is link layer endpoint that stores outbound packets in a channel +// and allows injection of inbound packets. 
+// +// +stateify savable +type linkEndpoint struct { + LinkEPCapabilities stack.LinkEndpointCapabilities + SupportedGSOKind stack.SupportedGSO + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + dispatcher stack.NetworkDispatcher + // +checklocks:mu + linkAddr tcpip.LinkAddress + // +checklocks:mu + mtu uint32 + + // Outbound packet queue. + q *queue + + gro *gro.GRO +} + +// newLinkEndpoint creates a new channel endpoint. +func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *linkEndpoint { + ep := &linkEndpoint{ + q: &queue{ + c: make(chan *stack.PacketBuffer, size), + }, + mtu: mtu, + linkAddr: linkAddr, + gro: &gro.GRO{}, + } + ep.gro.Init(true) + return ep +} + +// Close closes e. Further packet injections will return an error, and all pending +// packets are discarded. Close may be called concurrently with WritePackets. +func (e *linkEndpoint) Close() { + e.q.Close() + e.Drain() +} + +// Read does non-blocking read one packet from the outbound packet queue. +func (e *linkEndpoint) Read() *stack.PacketBuffer { + return e.q.Read() +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns nil. +func (e *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { + return e.q.ReadContext(ctx) +} + +// Drain removes all outbound packets from the channel and counts them. +func (e *linkEndpoint) Drain() int { + c := 0 + for pkt := e.Read(); pkt != nil; pkt = e.Read() { + pkt.DecRef() + c++ + } + return c +} + +// NumQueued returns the number of packet queued for outbound. +func (e *linkEndpoint) NumQueued() int { + return e.q.Num() +} + +// InjectInbound injects an inbound packet. If the endpoint is not attached, the +// packet is not delivered. 
+func (e *linkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(protocol, pkt) + } +} + +func (e *linkEndpoint) GROEnqueue(pkt *stack.PacketBuffer) { + e.mu.RLock() + defer e.mu.RUnlock() + if e.gro.Dispatcher == nil { + pkt.DecRef() + return + } + e.gro.Enqueue(pkt) +} + +func (e *linkEndpoint) GROFlush() { + e.mu.RLock() + defer e.mu.RUnlock() + if e.gro.Dispatcher == nil { + return + } + e.gro.Flush() +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + defer e.mu.Unlock() + e.dispatcher = dispatcher + e.gro.Dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *linkEndpoint) IsAttached() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. +func (e *linkEndpoint) MTU() uint32 { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mtu +} + +// SetMTU implements stack.LinkEndpoint.SetMTU. +func (e *linkEndpoint) SetMTU(mtu uint32) { + e.mu.Lock() + defer e.mu.Unlock() + e.mtu = mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.LinkEPCapabilities +} + +// GSOMaxSize implements stack.GSOEndpoint. +func (*linkEndpoint) GSOMaxSize() uint32 { + return 1<<16 - 1 +} + +// SupportedGSO implements stack.GSOEndpoint. +func (e *linkEndpoint) SupportedGSO() stack.SupportedGSO { + return e.SupportedGSOKind +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*linkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. 
+func (e *linkEndpoint) LinkAddress() tcpip.LinkAddress { + e.mu.RLock() + defer e.mu.RUnlock() + return e.linkAddr +} + +// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress. +func (e *linkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { + e.mu.Lock() + defer e.mu.Unlock() + e.linkAddr = addr +} + +// WritePackets stores outbound packets into the channel. +// Multiple concurrent calls are permitted. +func (e *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + n := 0 + for _, pkt := range pkts.AsSlice() { + if err := e.q.Write(pkt); err != nil { + if _, ok := err.(*tcpip.ErrNoBufferSpace); !ok && n == 0 { + return 0, err + } + break + } + n++ + } + + return n, nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*linkEndpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*linkEndpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*linkEndpoint) AddHeader(*stack.PacketBuffer) {} + +// ParseHeader implements stack.LinkEndpoint.ParseHeader. +func (*linkEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } + +// SetOnCloseAction implements stack.LinkEndpoint. 
+func (*linkEndpoint) SetOnCloseAction(func()) {} diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 4d08a20ed..086ea802b 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -176,7 +175,7 @@ type Impl struct { ProcessSubnets bool ipstack *stack.Stack - linkEP *channel.Endpoint + linkEP *linkEndpoint tundev *tstun.Wrapper e wgengine.Engine pm *proxymap.Mapper @@ -285,10 +284,19 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) } } - linkEP := channel.New(512, uint32(tstun.DefaultTUNMTU()), "") + linkEP := newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "") + linkEP.LinkEPCapabilities = stack.CapabilityRXChecksumOffload if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) } + go func() { + for { + <-time.After(time.Second * 2) + log.Printf("XXX IP Stats: %+v", ipstack.Stats().IP) + log.Printf("XXX TCP Stats: %+v", ipstack.Stats().TCP) + } + }() + linkEP.SupportedGSOKind = stack.HostGSOSupported // By default the netstack NIC will only accept packets for the IPs // registered to it. 
Since in some cases we dynamically register IPs // based on the packets that arrive, the NIC needs to accept all @@ -333,6 +341,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound + ns.tundev.PostFilterPacketInboundFromWireGuardFlush = ns.groFlush ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets stacksForMetrics.Store(ns, struct{}{}) return ns, nil @@ -791,7 +800,7 @@ func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet. func (ns *Impl) inject() { for { pkt := ns.linkEP.ReadContext(ns.ctx) - if pkt.IsNil() { + if pkt == nil { if ns.ctx.Err() != nil { // Return without logging. return @@ -1000,6 +1009,10 @@ func (ns *Impl) userPing(dstIP netip.Addr, pingResPkt []byte, direction userPing } } +func (ns *Impl) groFlush() { + ns.linkEP.GROFlush() +} + // injectInbound is installed as a packet hook on the 'inbound' (from a // WireGuard peer) path. Returning filter.Accept releases the packet to // continue normally (typically being delivered to the host networking stack), @@ -1048,7 +1061,9 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(bytes.Clone(p.Buffer())), }) - ns.linkEP.InjectInbound(pn, packetBuf) + packetBuf.NetworkProtocolNumber = pn + //packetBuf.RXChecksumValidated = true + ns.linkEP.GROEnqueue(packetBuf) packetBuf.DecRef() // We've now delivered this to netstack, so we're done. |
