Only recursively resolve CNAMES conditionally based on the question type

This commit is contained in:
Ben McHone 2024-01-28 16:34:38 -06:00
parent ba2c317fd5
commit 6262cdbcc1
2 changed files with 59 additions and 20 deletions

View File

@ -242,23 +242,26 @@ func (r *CustomDNSResolver) processCNAME(
cnames := resolvedCnames
cnames = append(cnames, targetWithoutDot)
// Resolve target recursively
targetResp, err := r.processRequest(ctx, aRequest, cnames)
if err != nil {
return nil, err
if question.Qtype == dns.TypeA {
// Resolve target recursively
targetResp, err := r.processRequest(ctx, aRequest, cnames)
if err != nil {
return nil, err
}
result = append(result, targetResp.Res.Answer...)
}
result = append(result, targetResp.Res.Answer...)
if question.Qtype == dns.TypeAAAA {
// Resolve ipv6 target recursively
targetResp, err := r.processRequest(ctx, aaaaRequest, cnames)
if err != nil {
return nil, err
}
// Resolve ipv6 target recursively
// Ignore the returned list of cnames, as the error would have been returned already
targetResp, err = r.processRequest(ctx, aaaaRequest, cnames)
if err != nil {
return nil, err
result = append(result, targetResp.Res.Answer...)
}
result = append(result, targetResp.Res.Answer...)
return result, nil
}

View File

@ -48,6 +48,7 @@ var _ = Describe("CustomDNSResolver", func() {
&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
},
"cname.domain": {&dns.CNAME{Target: "custom.domain"}},
"cname.ip6": {&dns.CNAME{Target: "ip6.domain"}},
"cname.example": {&dns.CNAME{Target: "example.com"}},
"cname.recursive": {&dns.CNAME{Target: "cname.recursive"}},
"mx.domain": {&dns.MX{Mx: "mx.domain"}},
@ -114,7 +115,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.On("Resolve", mock.Anything).Return(nil, err)
sut.Next(m)
_, err = sut.Resolve(ctx, newRequest("cname.example.", CNAME))
_, err = sut.Resolve(ctx, newRequest("cname.example.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("forward error"))
@ -123,12 +124,10 @@ var _ = Describe("CustomDNSResolver", func() {
err := fmt.Errorf("forward error")
m = &mockResolver{}
// The first call is for ipv4, the second for ipv6
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil).Once()
m.On("Resolve", mock.Anything).Return(nil, err).Once()
m.On("Resolve", mock.Anything).Return(nil, err)
sut.Next(m)
_, err = sut.Resolve(ctx, newRequest("cname.example.", CNAME))
_, err = sut.Resolve(ctx, newRequest("cname.example.", AAAA))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("forward error"))
@ -270,9 +269,28 @@ var _ = Describe("CustomDNSResolver", func() {
})
})
When("A CNAME record is defined for custom domain ", func() {
It("all CNAMES for the current type should be returned when relying on other Mappings", func() {
It("should not recurse if the request is strictly a CNAME request", func() {
By("CNAME query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.domain", CNAME))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(1),
ContainElements(
BeDNSRecord("cname.domain.", CNAME, "custom.domain.")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
It("all CNAMES for the current type should be recursively resolved when relying on other Mappings", func() {
By("A query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.domain", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
@ -289,10 +307,28 @@ var _ = Describe("CustomDNSResolver", func() {
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
By("AAAA query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.ip6", AAAA))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("cname.ip6.", CNAME, "ip6.domain."),
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
It("should return an error when the CNAME is recursive", func() {
By("CNAME query", func() {
_, err := sut.Resolve(ctx, newRequest("cname.recursive", CNAME))
_, err := sut.Resolve(ctx, newRequest("cname.recursive", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("CNAME loop detected:"))
// will not delegate to next resolver
@ -301,7 +337,7 @@ var _ = Describe("CustomDNSResolver", func() {
})
It("all CNAMES for the current type should be returned when relying on public DNS", func() {
By("CNAME query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.example", CNAME))).
Expect(sut.Resolve(ctx, newRequest("cname.example", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(