summaryrefslogtreecommitdiffhomepage
path: root/net/dns/resolver/forwarder.go
diff options
context:
space:
mode:
Diffstat (limited to 'net/dns/resolver/forwarder.go')
-rw-r--r--net/dns/resolver/forwarder.go55
1 files changed, 39 insertions, 16 deletions
diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go
index 5d1904468..8f4e641f9 100644
--- a/net/dns/resolver/forwarder.go
+++ b/net/dns/resolver/forwarder.go
@@ -529,28 +529,56 @@ type forwardQuery struct {
// forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error {
- domain, err := nameFromQuery(query.bs)
+ ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
+ defer cancel()
+
+ v, err := f.forwardQuery(ctx, query.bs)
if err != nil {
return err
}
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case f.responses <- packet{v, query.addr}:
+ return nil
+ }
+}
- clampEDNSSize(query.bs, maxResponseBytes)
+func (f *forwarder) Forward(ctx context.Context, bs []byte) ([]byte, error) {
+ ctx, cancel := context.WithTimeout(ctx, responseTimeout)
+ defer cancel()
+
+ go func() {
+ select {
+ case <-f.ctx.Done():
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+
+ return f.forwardQuery(ctx, bs)
+}
+
+func (f *forwarder) forwardQuery(ctx context.Context, bs []byte) ([]byte, error) {
+ domain, err := nameFromQuery(bs)
+ if err != nil {
+ return nil, err
+ }
+
+ clampEDNSSize(bs, maxResponseBytes)
resolvers := f.resolvers(domain)
if len(resolvers) == 0 {
- return errNoUpstreams
+ return nil, errNoUpstreams
}
fq := &forwardQuery{
- txid: getTxID(query.bs),
- packet: query.bs,
+ txid: getTxID(bs),
+ packet: bs,
closeOnCtxDone: new(closePool),
}
defer fq.closeOnCtxDone.Close()
- ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
- defer cancel()
-
resc := make(chan []byte, 1)
var (
mu sync.Mutex
@@ -586,19 +614,14 @@ func (f *forwarder) forward(query packet) error {
select {
case v := <-resc:
- select {
- case <-ctx.Done():
- return ctx.Err()
- case f.responses <- packet{v, query.addr}:
- return nil
- }
+ return v, nil
case <-ctx.Done():
mu.Lock()
defer mu.Unlock()
if firstErr != nil {
- return firstErr
+ return nil, firstErr
}
- return ctx.Err()
+ return nil, ctx.Err()
}
}