feat: support CNAME records in customDNS mappings (#1352)

Co-authored-by: Ben McHone <ben@mchone.dev>
This commit is contained in:
Ben 2024-01-29 10:22:03 -06:00 committed by GitHub
parent 3817d98e74
commit b8b4dc323a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 463 additions and 102 deletions

1
.gitignore vendored
View File

@ -15,3 +15,4 @@ vendor/
coverage.html coverage.html
coverage.txt coverage.txt
coverage/ coverage/
blocky

View File

@ -235,6 +235,17 @@ blocking:
Expect(err.Error()).Should(ContainSubstring("invalid IP address '192.168.178.WRONG'")) Expect(err.Error()).Should(ContainSubstring("invalid IP address '192.168.178.WRONG'"))
}) })
}) })
When("CustomDNS hast wrong IPv6 defined", func() {
It("should return error", func() {
cfg := Config{}
data := `customDNS:
mapping:
someDomain: 2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334`
err := unmarshalConfig(logger, []byte(data), &cfg)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("invalid IP address '2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334'"))
})
})
When("Conditional mapping hast wrong defined upstreams", func() { When("Conditional mapping hast wrong defined upstreams", func() {
It("should return error", func() { It("should return error", func() {
cfg := Config{} cfg := Config{}
@ -866,12 +877,24 @@ func defaultTestFileConfig(config *Config) {
Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8")) Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8"))
Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4")) Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4"))
Expect(config.Upstreams.Groups["default"][2].Host).Should(Equal("1.1.1.1")) Expect(config.Upstreams.Groups["default"][2].Host).Should(Equal("1.1.1.1"))
Expect(config.CustomDNS.Mapping.HostIPs).Should(HaveLen(2)) Expect(config.CustomDNS.Mapping).Should(HaveLen(2))
Expect(config.CustomDNS.Mapping.HostIPs["my.duckdns.org"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][0]).Should(Equal(net.ParseIP("192.168.178.3"))) duckDNSEntry := config.CustomDNS.Mapping["my.duckdns.org"][0]
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][1]).Should(Equal(net.ParseIP("192.168.178.4"))) duckDNSA := duckDNSEntry.(*dns.A)
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][2]).Should(Equal( Expect(duckDNSA.A).Should(Equal(net.ParseIP("192.168.178.3")))
net.ParseIP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344")))
multipleIpsEntry := config.CustomDNS.Mapping["multiple.ips"][0]
multipleIpsA := multipleIpsEntry.(*dns.A)
Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.3")))
multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][1]
multipleIpsA = multipleIpsEntry.(*dns.A)
Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.4")))
multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][2]
multipleIpsAAAA := multipleIpsEntry.(*dns.AAAA)
Expect(multipleIpsAAAA.AAAA).Should(Equal(net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7344")))
Expect(config.Conditional.Mapping.Upstreams).Should(HaveLen(2)) Expect(config.Conditional.Mapping.Upstreams).Should(HaveLen(2))
Expect(config.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1)) Expect(config.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1))
Expect(config.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2)) Expect(config.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2))

View File

@ -5,6 +5,7 @@ import (
"net" "net"
"strings" "strings"
"github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -16,14 +17,45 @@ type CustomDNS struct {
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
} }
// CustomDNSMapping mapping for the custom DNS configuration type (
type CustomDNSMapping struct { CustomDNSMapping map[string]CustomDNSEntries
HostIPs map[string][]net.IP `yaml:"hostIPs"` CustomDNSEntries []dns.RR
)
func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}
parts := strings.Split(input, ",")
result := make(CustomDNSEntries, len(parts))
containsCNAME := false
for i, part := range parts {
rr, err := configToRR(part)
if err != nil {
return err
}
_, isCNAME := rr.(*dns.CNAME)
containsCNAME = containsCNAME || isCNAME
result[i] = rr
}
if containsCNAME && len(result) > 1 {
return fmt.Errorf("when a CNAME record is present, it must be the only record in the mapping")
}
*c = result
return nil
} }
// IsEnabled implements `config.Configurable`. // IsEnabled implements `config.Configurable`.
func (c *CustomDNS) IsEnabled() bool { func (c *CustomDNS) IsEnabled() bool {
return len(c.Mapping.HostIPs) != 0 return len(c.Mapping) != 0
} }
// LogConfig implements `config.Configurable`. // LogConfig implements `config.Configurable`.
@ -33,36 +65,52 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
logger.Info("mapping:") logger.Info("mapping:")
for key, val := range c.Mapping.HostIPs { for key, val := range c.Mapping {
logger.Infof(" %s = %s", key, val) logger.Infof(" %s = %s", key, val)
} }
} }
// UnmarshalYAML implements `yaml.Unmarshaler`. func removePrefixSuffix(in, prefix string) string {
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix))
var input map[string]string in = strings.TrimSuffix(in, ")")
if err := unmarshal(&input); err != nil {
return err return strings.TrimSpace(in)
}
func configToRR(part string) (dns.RR, error) {
if strings.HasPrefix(part, "CNAME(") {
domain := removePrefixSuffix(part, "CNAME")
domain = dns.Fqdn(domain)
cname := &dns.CNAME{Target: domain}
return cname, nil
} }
result := make(map[string][]net.IP, len(input)) // Fall back to A/AAAA records to maintain backwards compatibility in config.yml
// We will still remove the A() or AAAA() if it exists
if strings.Contains(part, ".") { // IPV4 address
ipStr := removePrefixSuffix(part, "A")
ip := net.ParseIP(ipStr)
for k, v := range input { if ip == nil {
var ips []net.IP return nil, fmt.Errorf("invalid IP address '%s'", part)
for _, part := range strings.Split(v, ",") {
ip := net.ParseIP(strings.TrimSpace(part))
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
ips = append(ips, ip)
} }
result[k] = ips a := new(dns.A)
a.A = ip
return a, nil
} else { // IPV6 address
ipStr := removePrefixSuffix(part, "AAAA")
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", part)
}
aaaa := new(dns.AAAA)
aaaa.AAAA = ip
return aaaa, nil
} }
c.HostIPs = result
return nil
} }

View File

@ -5,6 +5,7 @@ import (
"net" "net"
"github.com/creasty/defaults" "github.com/creasty/defaults"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -17,15 +18,14 @@ var _ = Describe("CustomDNSConfig", func() {
BeforeEach(func() { BeforeEach(func() {
cfg = CustomDNS{ cfg = CustomDNS{
Mapping: CustomDNSMapping{ Mapping: CustomDNSMapping{
HostIPs: map[string][]net.IP{ "custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}},
"custom.domain": {net.ParseIP("192.168.143.123")}, "ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}},
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, "multiple.ips": {
"multiple.ips": { &dns.A{A: net.ParseIP("192.168.143.123")},
net.ParseIP("192.168.143.123"), &dns.A{A: net.ParseIP("192.168.143.125")},
net.ParseIP("192.168.143.125"), &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
},
}, },
"cname.domain": {&dns.CNAME{Target: "custom.domain"}},
}, },
} }
}) })
@ -60,27 +60,41 @@ var _ = Describe("CustomDNSConfig", func() {
Expect(hook.Calls).ShouldNot(BeEmpty()) Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElements( Expect(hook.Messages).Should(ContainElements(
ContainSubstring("custom.domain = "), ContainSubstring("custom.domain = "),
ContainSubstring("ip6.domain = "),
ContainSubstring("multiple.ips = "), ContainSubstring("multiple.ips = "),
ContainSubstring("cname.domain = "),
)) ))
}) })
}) })
Describe("UnmarshalYAML", func() { Describe("UnmarshalYAML", func() {
It("Should parse config as map", func() { It("Should parse config as map", func() {
c := &CustomDNSMapping{} c := CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error { err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} *i.(*string) = "1.2.3.4"
return nil return nil
}) })
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(c.HostIPs).Should(HaveLen(1)) Expect(c).Should(HaveLen(1))
Expect(c.HostIPs["key"]).Should(HaveLen(1))
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4"))) aRecord := c[0].(*dns.A)
Expect(aRecord.A).Should(Equal(net.ParseIP("1.2.3.4")))
})
It("Should return an error if a CNAME is accomanied by any other record", func() {
c := CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "CNAME(example.com),A(1.2.3.4)"
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("when a CNAME record is present, it must be the only record in the mapping"))
}) })
It("should fail if wrong YAML format", func() { It("should fail if wrong YAML format", func() {
c := &CustomDNSMapping{} c := &CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error { err := c.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err") return errors.New("some err")
}) })

View File

@ -47,6 +47,7 @@ customDNS:
example.com: printer.lan example.com: printer.lan
mapping: mapping:
printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
second-printer-address.lan: CNAME(printer.lan)
# optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma # optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma
# Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name # Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name

View File

@ -259,12 +259,12 @@ You can define your own domain name to IP mappings. For example, you can use a u
or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one
domain must be separated by a comma. domain must be separated by a comma.
| Parameter | Type | Mandatory | Default value | | Parameter | Type | Mandatory | Default value |
| ------------------- | --------------------------------------- | --------- | ------------- | | ------------------- | ------------------------------------------- | --------- | ------------- |
| customTTL | duration (no unit is minutes) | no | 1h | | customTTL | duration (no unit is minutes) | no | 1h |
| rewrite | string: string (domain: domain) | no | | | rewrite | string: string (domain: domain) | no | |
| mapping | string: string (hostname: address list) | no | | | mapping | string: string (hostname: address or CNAME) | no | |
| filterUnmappedTypes | boolean | no | true | | filterUnmappedTypes | boolean | no | true |
!!! example !!! example
@ -278,11 +278,14 @@ domain must be separated by a comma.
mapping: mapping:
printer.lan: 192.168.178.3 printer.lan: 192.168.178.3
otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
anothername.lan: CNAME(otherdevice.lan)
``` ```
This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of
`printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3. `printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3.
CNAME records are supported by setting the value of the mapping to `CNAME(target)`. Note that the target will be recursively resolved and will return an error if a loop is detected.
With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the
resolver lookup is performed. resolver lookup is performed.
The query "printer.home" will be rewritten to "printer.lan" and return 192.168.178.3. The query "printer.home" will be rewritten to "printer.lan" and return 192.168.178.3.

View File

@ -66,7 +66,7 @@ func createDNSMokkaContainer(ctx context.Context, alias string, rules ...string)
// createHTTPServerContainer creates a static HTTP server container that serves one file with the given lines // createHTTPServerContainer creates a static HTTP server container that serves one file with the given lines
// and is attached to the test network under the given alias. // and is attached to the test network under the given alias.
// It is automatically terminated when the test is finished. // It is automatically terminated when the test is finished.
func createHTTPServerContainer(ctx context.Context, alias string, filename string, lines ...string, func createHTTPServerContainer(ctx context.Context, alias, filename string, lines ...string,
) (testcontainers.Container, error) { ) (testcontainers.Container, error) {
file := createTempFile(lines...) file := createTempFile(lines...)

View File

@ -212,6 +212,8 @@ func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error
return v.A.String() == matcher.answer, nil return v.A.String() == matcher.answer, nil
case *dns.AAAA: case *dns.AAAA:
return v.AAAA.String() == matcher.answer, nil return v.AAAA.String() == matcher.answer, nil
case *dns.CNAME:
return v.Target == matcher.answer, nil
case *dns.PTR: case *dns.PTR:
return v.Ptr == matcher.answer, nil return v.Ptr == matcher.answer, nil
case *dns.MX: case *dns.MX:

View File

@ -2,7 +2,9 @@ package resolver
import ( import (
"context" "context"
"fmt"
"net" "net"
"slices"
"strings" "strings"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -14,27 +16,41 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type createAnswerFunc func(question dns.Question, ip net.IP, ttl uint32) (dns.RR, error)
// CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map // CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map
type CustomDNSResolver struct { type CustomDNSResolver struct {
configurable[*config.CustomDNS] configurable[*config.CustomDNS]
NextResolver NextResolver
typed typed
mapping map[string][]net.IP createAnswerFromQuestion createAnswerFunc
reverseAddresses map[string][]string mapping config.CustomDNSMapping
reverseAddresses map[string][]string
} }
// NewCustomDNSResolver creates new resolver instance // NewCustomDNSResolver creates new resolver instance
func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver { func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs)) m := make(config.CustomDNSMapping, len(cfg.Mapping))
reverse := make(map[string][]string, len(cfg.Mapping.HostIPs)) reverse := make(map[string][]string, len(cfg.Mapping))
for url, ips := range cfg.Mapping.HostIPs { for url, entries := range cfg.Mapping {
m[strings.ToLower(url)] = ips m[strings.ToLower(url)] = entries
for _, ip := range ips { for _, entry := range entries {
r, _ := dns.ReverseAddr(ip.String()) a, isA := entry.(*dns.A)
reverse[r] = append(reverse[r], url)
if isA {
r, _ := dns.ReverseAddr(a.A.String())
reverse[r] = append(reverse[r], url)
}
aaaa, isAAAA := entry.(*dns.AAAA)
if isAAAA {
r, _ := dns.ReverseAddr(aaaa.AAAA.String())
reverse[r] = append(reverse[r], url)
}
} }
} }
@ -42,8 +58,9 @@ func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
configurable: withConfig(&cfg), configurable: withConfig(&cfg),
typed: withType("custom_dns"), typed: withType("custom_dns"),
mapping: m, createAnswerFromQuestion: util.CreateAnswerFromQuestion,
reverseAddresses: reverse, mapping: m,
reverseAddresses: reverse,
} }
} }
@ -75,7 +92,26 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil return nil
} }
func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response { func (r *CustomDNSResolver) forwardResponse(
logger *logrus.Entry,
ctx context.Context,
request *model.Request,
) (*model.Response, error) {
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
forwardResponse, err := r.next.Resolve(ctx, request)
if err != nil {
return nil, err
}
return forwardResponse, nil
}
func (r *CustomDNSResolver) processRequest(
ctx context.Context,
request *model.Request,
resolvedCnames []string,
) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "custom_dns_resolver") logger := log.WithPrefix(request.Log, "custom_dns_resolver")
response := new(dns.Msg) response := new(dns.Msg)
@ -85,13 +121,20 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
domain := util.ExtractDomain(question) domain := util.ExtractDomain(question)
for len(domain) > 0 { for len(domain) > 0 {
ips, found := r.mapping[domain] if err := ctx.Err(); err != nil {
return nil, err
}
entries, found := r.mapping[domain]
if found { if found {
for _, ip := range ips { for _, entry := range entries {
if isSupportedType(ip, question) { result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry)
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32()) if err != nil {
response.Answer = append(response.Answer, rr) return nil, err
} }
response.Answer = append(response.Answer, result...)
} }
if len(response.Answer) > 0 { if len(response.Answer) > 0 {
@ -100,7 +143,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
"domain": domain, "domain": domain,
}).Debugf("returning custom dns entry") }).Debugf("returning custom dns entry")
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
} }
// Mapping exists for this domain, but for another type // Mapping exists for this domain, but for another type
@ -110,7 +153,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
} }
// return NOERROR with empty result // return NOERROR with empty result
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
} }
if i := strings.Index(domain, "."); i >= 0 { if i := strings.Index(domain, "."); i >= 0 {
@ -120,26 +163,99 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
} }
} }
return nil return r.forwardResponse(logger, ctx, request)
}
func (r *CustomDNSResolver) processDNSEntry(
ctx context.Context,
request *model.Request,
resolvedCnames []string,
question dns.Question,
entry dns.RR,
) ([]dns.RR, error) {
switch v := entry.(type) {
case *dns.A:
return r.processIP(v.A, question)
case *dns.AAAA:
return r.processIP(v.AAAA, question)
case *dns.CNAME:
return r.processCNAME(ctx, request, *v, resolvedCnames, question)
}
return nil, fmt.Errorf("unsupported customDNS RR type %T", entry)
} }
// Resolve uses internal mapping to resolve the query // Resolve uses internal mapping to resolve the query
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
reverseResp := r.handleReverseDNS(request) reverseResp := r.handleReverseDNS(request)
if reverseResp != nil { if reverseResp != nil {
return reverseResp, nil return reverseResp, nil
} }
if len(r.mapping) > 0 { resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping)))
resp := r.processRequest(request) if err != nil {
if resp != nil { return nil, err
return resp, nil
}
} }
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") return resp, nil
}
return r.next.Resolve(ctx, request)
func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question) (result []dns.RR, err error) {
result = make([]dns.RR, 0)
if isSupportedType(ip, question) {
rr, err := r.createAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32())
if err != nil {
return nil, err
}
result = append(result, rr)
}
return result, nil
}
func (r *CustomDNSResolver) processCNAME(
ctx context.Context,
request *model.Request,
targetCname dns.CNAME,
resolvedCnames []string,
question dns.Question,
) (result []dns.RR, err error) {
cname := new(dns.CNAME)
ttl := r.cfg.CustomTTL.SecondsU32()
cname.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeCNAME, Name: question.Name}
cname.Target = dns.Fqdn(targetCname.Target)
result = append(result, cname)
if question.Qtype == dns.TypeCNAME {
return result, nil
}
targetWithoutDot := strings.TrimSuffix(targetCname.Target, ".")
if slices.Contains(resolvedCnames, targetWithoutDot) {
return nil, fmt.Errorf("CNAME loop detected: %v", append(resolvedCnames, targetWithoutDot))
}
cnames := resolvedCnames
cnames = append(cnames, targetWithoutDot)
clientIP := request.ClientIP.String()
clientID := request.RequestClientID
targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID)
// resolve the target recursively
targetResp, err := r.processRequest(ctx, targetRequest, cnames)
if err != nil {
return nil, err
}
result = append(result, targetResp.Res.Answer...)
return result, nil
}
func (r *CustomDNSResolver) CreateAnswerFromQuestion(newFunc createAnswerFunc) {
r.createAnswerFromQuestion = newFunc
} }

View File

@ -2,6 +2,7 @@ package resolver
import ( import (
"context" "context"
"fmt"
"net" "net"
"time" "time"
@ -38,15 +39,20 @@ var _ = Describe("CustomDNSResolver", func() {
DeferCleanup(cancelFn) DeferCleanup(cancelFn)
cfg = config.CustomDNS{ cfg = config.CustomDNS{
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{ Mapping: config.CustomDNSMapping{
"custom.domain": {net.ParseIP("192.168.143.123")}, "custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}},
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, "ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}},
"multiple.ips": { "multiple.ips": {
net.ParseIP("192.168.143.123"), &dns.A{A: net.ParseIP("192.168.143.123")},
net.ParseIP("192.168.143.125"), &dns.A{A: net.ParseIP("192.168.143.125")},
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
}, },
}}, "cname.domain": {&dns.CNAME{Target: "custom.domain"}},
"cname.ip6": {&dns.CNAME{Target: "ip6.domain"}},
"cname.example": {&dns.CNAME{Target: "example.com"}},
"cname.recursive": {&dns.CNAME{Target: "cname.recursive"}},
"mx.domain": {&dns.MX{Mx: "mx.domain"}},
},
CustomTTL: config.Duration(time.Duration(TTL) * time.Second), CustomTTL: config.Duration(time.Duration(TTL) * time.Second),
FilterUnmappedTypes: true, FilterUnmappedTypes: true,
} }
@ -76,6 +82,57 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
Describe("Resolving custom name via CustomDNSResolver", func() { Describe("Resolving custom name via CustomDNSResolver", func() {
When("The parent context has an error ", func() {
It("should return the error", func() {
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err := sut.Resolve(cancelledCtx, newRequest("custom.domain.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("context canceled"))
})
})
When("Creating the IP response returns an error ", func() {
It("should return the error", func() {
createAnswerMock := func(_ dns.Question, _ net.IP, _ uint32) (dns.RR, error) {
return nil, fmt.Errorf("create answer error")
}
sut.CreateAnswerFromQuestion(createAnswerMock)
_, err := sut.Resolve(ctx, newRequest("custom.domain.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("create answer error"))
})
})
When("The forward request returns an error ", func() {
It("should return the error if the error occurs when checking ipv4 forward addresses", func() {
err := fmt.Errorf("forward error")
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(nil, err)
sut.Next(m)
_, err = sut.Resolve(ctx, newRequest("cname.example.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("forward error"))
})
It("should return the error if the error occurs when checking ipv6 forward addresses", func() {
err := fmt.Errorf("forward error")
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(nil, err)
sut.Next(m)
_, err = sut.Resolve(ctx, newRequest("cname.example.", AAAA))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("forward error"))
})
})
When("Ip 4 mapping is defined for custom domain and", func() { When("Ip 4 mapping is defined for custom domain and", func() {
Context("filterUnmappedTypes is true", func() { Context("filterUnmappedTypes is true", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = true }) BeforeEach(func() { cfg.FilterUnmappedTypes = true })
@ -211,6 +268,101 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
}) })
}) })
When("A CNAME record is defined for custom domain ", func() {
It("should not recurse if the request is strictly a CNAME request", func() {
By("CNAME query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.domain", CNAME))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(1),
ContainElements(
BeDNSRecord("cname.domain.", CNAME, "custom.domain.")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
It("all CNAMES for the current type should be recursively resolved when relying on other Mappings", func() {
By("A query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.domain", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("cname.domain.", CNAME, "custom.domain."),
BeDNSRecord("custom.domain.", A, "192.168.143.123")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
By("AAAA query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.ip6", AAAA))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("cname.ip6.", CNAME, "ip6.domain."),
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
It("should return an error when the CNAME is recursive", func() {
By("CNAME query", func() {
_, err := sut.Resolve(ctx, newRequest("cname.recursive", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("CNAME loop detected:"))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
It("all CNAMES for the current type should be returned when relying on public DNS", func() {
By("CNAME query", func() {
Expect(sut.Resolve(ctx, newRequest("cname.example", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
ContainElements(
BeDNSRecord("cname.example.", CNAME, "example.com.")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will delegate to next resolver
m.AssertCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
})
When("An unsupported DNS query type is queried from the resolver but found in the config mapping ", func() {
It("an error should be returned", func() {
By("MX query", func() {
_, err := sut.Resolve(ctx, newRequest("mx.domain", MX))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("unsupported customDNS RR type *dns.MX"))
})
})
})
When("Reverse DNS request is received", func() { When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() { It("should resolve the defined domain name", func() {
By("ipv4", func() { By("ipv4", func() {

View File

@ -613,6 +613,13 @@ func (s *Server) OnRequest(
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) { func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
var response *model.Response var response *model.Response
contextUpstreamTimeoutMultiplier := 100
timeoutDuration := time.Duration(contextUpstreamTimeoutMultiplier) * s.cfg.Upstreams.Timeout.ToDuration()
ctx, cancel := context.WithTimeout(ctx, timeoutDuration)
defer cancel()
switch { switch {
case len(request.Req.Question) == 0: case len(request.Req.Question) == 0:
m := new(dns.Msg) m := new(dns.Msg)

View File

@ -104,10 +104,8 @@ var _ = BeforeSuite(func() {
CustomDNS: config.CustomDNS{ CustomDNS: config.CustomDNS{
CustomTTL: config.Duration(3600 * time.Second), CustomTTL: config.Duration(3600 * time.Second),
Mapping: config.CustomDNSMapping{ Mapping: config.CustomDNSMapping{
HostIPs: map[string][]net.IP{ "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
"custom.lan": {net.ParseIP("192.168.178.55")}, "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
"lan.home": {net.ParseIP("192.168.178.56")},
},
}, },
}, },
Conditional: config.ConditionalUpstream{ Conditional: config.ConditionalUpstream{
@ -596,10 +594,8 @@ var _ = Describe("Running DNS server", func() {
}, },
CustomDNS: config.CustomDNS{ CustomDNS: config.CustomDNS{
Mapping: config.CustomDNSMapping{ Mapping: config.CustomDNSMapping{
HostIPs: map[string][]net.IP{ "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
"custom.lan": {net.ParseIP("192.168.178.55")}, "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
"lan.home": {net.ParseIP("192.168.178.56")},
},
}, },
}, },
Blocking: config.Blocking{BlockType: "zeroIp"}, Blocking: config.Blocking{BlockType: "zeroIp"},
@ -642,10 +638,8 @@ var _ = Describe("Running DNS server", func() {
}, },
CustomDNS: config.CustomDNS{ CustomDNS: config.CustomDNS{
Mapping: config.CustomDNSMapping{ Mapping: config.CustomDNSMapping{
HostIPs: map[string][]net.IP{ "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
"custom.lan": {net.ParseIP("192.168.178.55")}, "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
"lan.home": {net.ParseIP("192.168.178.56")},
},
}, },
}, },
Blocking: config.Blocking{BlockType: "zeroIp"}, Blocking: config.Blocking{BlockType: "zeroIp"},