summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--cmd/sniproxy/snipproxy.go114
1 files changed, 84 insertions, 30 deletions
diff --git a/cmd/sniproxy/snipproxy.go b/cmd/sniproxy/snipproxy.go
index 34a2a7fbb..9672551cd 100644
--- a/cmd/sniproxy/snipproxy.go
+++ b/cmd/sniproxy/snipproxy.go
@@ -9,14 +9,13 @@ package main
import (
"context"
"flag"
- "fmt"
"log"
"net"
"net/netip"
"strings"
"time"
- "github.com/miekg/dns"
+ "golang.org/x/net/dns/dnsmessage"
"inet.af/tcpproxy"
"tailscale.com/client/tailscale"
"tailscale.com/net/netutil"
@@ -26,6 +25,7 @@ import (
var (
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
dnsserv = flag.Bool("dns", true, "run a small DNS server to reply to any query with its own address")
+ tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
)
func main() {
@@ -113,46 +113,100 @@ func (s *server) getAddresses() (ip4, ip6 netip.Addr) {
}
func (s *server) serveDns() {
- dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
- switch r.Opcode {
- case dns.OpcodeQuery:
- m := s.dnsResponse(r)
- m.SetReply(r)
- w.WriteMsg(m)
- }
- })
-
+ buf := make([]byte, 1024)
pconn, err := s.ts.ListenPacket("udp", ":53")
if err != nil {
- log.Printf("Failed to start DNS listener: %s\n ", err.Error())
- return
+ log.Fatal(err)
}
- dnsServer := &dns.Server{PacketConn: pconn}
- err = dnsServer.ActivateAndServe()
- if err != nil {
- log.Printf("Failed to start DNS server: %s\n ", err.Error())
+ for {
+ _, addr, err := pconn.ReadFrom(buf)
+ if err != nil {
+ log.Printf("pconn.ReadFrom failed: %v\n ", err)
+ continue
+ }
+
+ var msg dnsmessage.Message
+ err = msg.Unpack(buf)
+ if err != nil {
+ log.Printf("dnsmessage.Message unpack failed: %v\n ", err)
+ continue
+ }
+
+ buf, err := s.dnsResponse(&msg)
+ if err != nil {
+ log.Printf("s.dnsResponse failed: %v\n", err)
+ continue
+ }
+
+ _, err = pconn.WriteTo(buf, addr)
+ if err != nil {
+ log.Printf("pconn.WriteTo failed: %v\n", err)
+ continue
+ }
}
}
-func (s *server) dnsResponse(requestMsg *dns.Msg) *dns.Msg {
- responseMsg := new(dns.Msg)
- if len(requestMsg.Question) == 0 {
- return responseMsg
+func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
+ resp := dnsmessage.NewBuilder(buf,
+ dnsmessage.Header{
+ ID: req.Header.ID,
+ Response: true,
+ Authoritative: true,
+ })
+ resp.EnableCompression()
+
+ if len(req.Questions) == 0 {
+ buf, _ = resp.Finish()
+ return
}
- q := requestMsg.Question[0]
- var rr dns.RR
+ q := req.Questions[0]
+ err = resp.StartQuestions()
+ if err != nil {
+ return
+ }
+ resp.Question(q)
+
ip4, ip6 := s.getAddresses()
+ err = resp.StartAnswers()
+ if err != nil {
+ return
+ }
- switch q.Qtype {
- case dns.TypeAAAA:
- rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN AAAA %s", q.Name, ip6.String()))
+ switch q.Type {
+ case dnsmessage.TypeAAAA:
+ err = resp.AAAAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AAAAResource{AAAA: ip6.As16()},
+ )
- case dns.TypeA:
- rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN A %s", q.Name, ip4.String()))
+ case dnsmessage.TypeA:
+ err = resp.AResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AResource{A: ip4.As4()},
+ )
+ case dnsmessage.TypeSOA:
+ err = resp.SOAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
+ Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
+ )
+ case dnsmessage.TypeNS:
+ err = resp.NSResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.NSResource{NS: tsMBox},
+ )
}
- responseMsg.Answer = append(responseMsg.Answer, rr)
- return responseMsg
+ if err != nil {
+ return
+ }
+
+ buf, err = resp.Finish()
+ if err != nil {
+ return
+ }
+
+ return
}