diff options
| author | David Crawshaw <crawshaw@tailscale.com> | 2021-07-29 17:38:14 -0700 |
|---|---|---|
| committer | David Crawshaw <crawshaw@tailscale.com> | 2021-07-29 17:38:37 -0700 |
| commit | 8b9e9c0786021c1cd02d86fffd3ba56b523f28ef (patch) | |
| tree | d03ec0a7cbd83662dc094c92e17e1437b96c0a60 /net/dns/resolver/forwarder.go | |
| parent | d37451bac6f38cc09b853b08b1dc8359ba767fa1 (diff) | |
| download | tailscale-crawshaw/peerdoh.tar.xz tailscale-crawshaw/peerdoh.zip | |
ipnlocal, resolver, etc: add peer API DoHcrawshaw/peerdoh
Diffstat (limited to 'net/dns/resolver/forwarder.go')
| -rw-r--r-- | net/dns/resolver/forwarder.go | 55 |
1 files changed, 39 insertions, 16 deletions
diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 5d1904468..8f4e641f9 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -529,28 +529,56 @@ type forwardQuery struct { // forward forwards the query to all upstream nameservers and returns the first response. func (f *forwarder) forward(query packet) error { - domain, err := nameFromQuery(query.bs) + ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) + defer cancel() + + v, err := f.forwardQuery(ctx, query.bs) if err != nil { return err } + select { + case <-ctx.Done(): + return ctx.Err() + case f.responses <- packet{v, query.addr}: + return nil + } +} - clampEDNSSize(query.bs, maxResponseBytes) +func (f *forwarder) Forward(ctx context.Context, bs []byte) ([]byte, error) { + ctx, cancel := context.WithTimeout(ctx, responseTimeout) + defer cancel() + + go func() { + select { + case <-f.ctx.Done(): + cancel() + case <-ctx.Done(): + } + }() + + return f.forwardQuery(ctx, bs) +} + +func (f *forwarder) forwardQuery(ctx context.Context, bs []byte) ([]byte, error) { + domain, err := nameFromQuery(bs) + if err != nil { + return nil, err + } + + clampEDNSSize(bs, maxResponseBytes) resolvers := f.resolvers(domain) if len(resolvers) == 0 { - return errNoUpstreams + return nil, errNoUpstreams } fq := &forwardQuery{ - txid: getTxID(query.bs), - packet: query.bs, + txid: getTxID(bs), + packet: bs, closeOnCtxDone: new(closePool), } defer fq.closeOnCtxDone.Close() - ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) - defer cancel() - resc := make(chan []byte, 1) var ( mu sync.Mutex @@ -586,19 +614,14 @@ func (f *forwarder) forward(query packet) error { select { case v := <-resc: - select { - case <-ctx.Done(): - return ctx.Err() - case f.responses <- packet{v, query.addr}: - return nil - } + return v, nil case <-ctx.Done(): mu.Lock() defer mu.Unlock() if firstErr != nil { - return firstErr + return nil, firstErr } - return ctx.Err() + return nil, ctx.Err() } } |
