fix(rewrite): support the case where upstream doesn't echo the question

Apparently Tailscale's magic DNS does this.
This commit is contained in:
ThinkChaos 2023-12-19 19:24:08 -05:00
parent c6304e9d7f
commit dece894bd6
2 changed files with 53 additions and 18 deletions

View File

@ -95,13 +95,25 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request)
return r.next.Resolve(ctx, request)
}
// Revert the rewrite in r.inner's response
if rewritten != nil {
for i := range originalNames {
response.Res.Question[i].Name = originalNames[i]
if rewritten == nil {
return response, nil
}
if i < len(response.Res.Answer) {
response.Res.Answer[i].Header().Name = originalNames[i]
// Revert the rewrite in r.inner's response
n := max(len(response.Res.Question), len(response.Res.Answer))
for i := 0; i < n; i++ {
if i < len(response.Res.Question) {
original, ok := originalNames[response.Res.Question[i].Name]
if ok {
response.Res.Question[i].Name = original
}
}
if i < len(response.Res.Answer) {
original, ok := originalNames[response.Res.Answer[i].Header().Name]
if ok {
response.Res.Answer[i].Header().Name = original
}
}
}
@ -109,22 +121,27 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request)
return response, nil
}
func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg) (rewritten *dns.Msg, originalNames []string) { //nolint: lll
originalNames = make([]string, len(request.Question))
func (r *RewriterResolver) rewriteRequest(
logger *logrus.Entry, request *dns.Msg,
) (rewritten *dns.Msg, originalNames map[string]string) {
originalNames = make(map[string]string, len(request.Question))
for i := range request.Question {
nameOriginal := request.Question[i].Name
originalNames[i] = nameOriginal
domainOriginal := util.ExtractDomainOnly(nameOriginal)
domainRewritten, rewriteKey := r.rewriteDomain(domainOriginal)
if domainRewritten != domainOriginal {
rewrittenFQDN := dns.Fqdn(domainRewritten)
originalNames[rewrittenFQDN] = nameOriginal
if rewritten == nil {
rewritten = request.Copy()
}
rewritten.Question[i].Name = dns.Fqdn(domainRewritten)
rewritten.Question[i].Name = rewrittenFQDN
logger.WithFields(logrus.Fields{
"rewrite": util.Obfuscate(rewriteKey) + ":" + util.Obfuscate(r.cfg.Rewrite[rewriteKey]),

View File

@ -67,17 +67,14 @@ var _ = Describe("RewriterResolver", func() {
})
When("has rewrite", func() {
var request *model.Request
var expectNilAnswer bool
var (
request *model.Request
expectNilAnswer bool
)
BeforeEach(func() {
expectNilAnswer = false
})
AfterEach(func() {
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))
mInner.On("Resolve", mock.Anything)
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
Expect(req).Should(Equal(request.Req))
@ -95,11 +92,19 @@ var _ = Describe("RewriterResolver", func() {
return res
}
})
AfterEach(func() {
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))
mInner.On("Resolve", mock.Anything)
resp, err := sut.Resolve(context.Background(), request)
Expect(err).Should(Succeed())
if resp != mNextResponse {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
if len(resp.Res.Question) != 0 {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
}
if expectNilAnswer {
Expect(resp.Res.Answer).Should(BeEmpty())
} else {
@ -128,6 +133,19 @@ var _ = Describe("RewriterResolver", func() {
fqdnRewritten = fqdnOriginal
})
It("should support a reply without the question", func() {
fqdnOriginal = sampleOriginal
fqdnRewritten = sampleRewritten
origResponseFn := mInner.ResponseFn
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
res := origResponseFn(req)
res.Question = nil
return res
}
})
It("should call next resolver", func() {
fqdnOriginal = sampleOriginal
fqdnRewritten = sampleRewritten