From dece894bd6d62205f2ec69379850e2a526667c8d Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 19 Dec 2023 19:24:08 -0500 Subject: [PATCH] fix(rewrite): support the case where upstream doesn't echo the question Apparently Tailscale's magic DNS does this. --- resolver/rewriter_resolver.go | 37 ++++++++++++++++++++++-------- resolver/rewriter_resolver_test.go | 34 ++++++++++++++++++++------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/resolver/rewriter_resolver.go b/resolver/rewriter_resolver.go index 20742193..bf5304fe 100644 --- a/resolver/rewriter_resolver.go +++ b/resolver/rewriter_resolver.go @@ -95,13 +95,25 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) return r.next.Resolve(ctx, request) } - // Revert the rewrite in r.inner's response - if rewritten != nil { - for i := range originalNames { - response.Res.Question[i].Name = originalNames[i] + if rewritten == nil { + return response, nil + } - if i < len(response.Res.Answer) { - response.Res.Answer[i].Header().Name = originalNames[i] + // Revert the rewrite in r.inner's response + + n := max(len(response.Res.Question), len(response.Res.Answer)) + for i := 0; i < n; i++ { + if i < len(response.Res.Question) { + original, ok := originalNames[response.Res.Question[i].Name] + if ok { + response.Res.Question[i].Name = original + } + } + + if i < len(response.Res.Answer) { + original, ok := originalNames[response.Res.Answer[i].Header().Name] + if ok { + response.Res.Answer[i].Header().Name = original } } } @@ -109,22 +121,27 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) return response, nil } -func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg) (rewritten *dns.Msg, originalNames []string) { //nolint: lll - originalNames = make([]string, len(request.Question)) +func (r *RewriterResolver) rewriteRequest( + logger *logrus.Entry, request *dns.Msg, +) (rewritten *dns.Msg, originalNames map[string]string) { + originalNames = make(map[string]string, len(request.Question)) for i := range request.Question { nameOriginal := request.Question[i].Name - originalNames[i] = nameOriginal domainOriginal := util.ExtractDomainOnly(nameOriginal) domainRewritten, rewriteKey := r.rewriteDomain(domainOriginal) if domainRewritten != domainOriginal { + rewrittenFQDN := dns.Fqdn(domainRewritten) + + originalNames[rewrittenFQDN] = nameOriginal + if rewritten == nil { rewritten = request.Copy() } - rewritten.Question[i].Name = dns.Fqdn(domainRewritten) + rewritten.Question[i].Name = rewrittenFQDN logger.WithFields(logrus.Fields{ "rewrite": util.Obfuscate(rewriteKey) + ":" + util.Obfuscate(r.cfg.Rewrite[rewriteKey]), diff --git a/resolver/rewriter_resolver_test.go b/resolver/rewriter_resolver_test.go index 6027c745..d1d77961 100644 --- a/resolver/rewriter_resolver_test.go +++ b/resolver/rewriter_resolver_test.go @@ -67,17 +67,14 @@ var _ = Describe("RewriterResolver", func() { }) When("has rewrite", func() { - var request *model.Request - var expectNilAnswer bool + var ( + request *model.Request + expectNilAnswer bool + ) BeforeEach(func() { expectNilAnswer = false - }) - AfterEach(func() { - request = newRequest(fqdnOriginal, dns.Type(dns.TypeA)) - - mInner.On("Resolve", mock.Anything) mInner.ResponseFn = func(req *dns.Msg) *dns.Msg { Expect(req).Should(Equal(request.Req)) @@ -95,11 +92,19 @@ var _ = Describe("RewriterResolver", func() { return res } + }) + + AfterEach(func() { + request = newRequest(fqdnOriginal, dns.Type(dns.TypeA)) + + mInner.On("Resolve", mock.Anything) resp, err := sut.Resolve(context.Background(), request) Expect(err).Should(Succeed()) if resp != mNextResponse { - Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal)) + if len(resp.Res.Question) != 0 { + Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal)) + } if expectNilAnswer { Expect(resp.Res.Answer).Should(BeEmpty()) } else { @@ -128,6 +133,19 @@ var _ = Describe("RewriterResolver", func() { fqdnRewritten = fqdnOriginal }) + It("should support a reply without the question", func() { + fqdnOriginal = sampleOriginal + fqdnRewritten = sampleRewritten + + origResponseFn := mInner.ResponseFn + mInner.ResponseFn = func(req *dns.Msg) *dns.Msg { + res := origResponseFn(req) + res.Question = nil + + return res + } + }) + It("should call next resolver", func() { fqdnOriginal = sampleOriginal fqdnRewritten = sampleRewritten