diff --git a/.gitignore b/.gitignore index 566e32e0..e4c451aa 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ vendor/ coverage.html coverage.txt coverage/ +blocky diff --git a/config/config_test.go b/config/config_test.go index c6059c1b..f5b7ab52 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -235,6 +235,17 @@ blocking: Expect(err.Error()).Should(ContainSubstring("invalid IP address '192.168.178.WRONG'")) }) }) + When("CustomDNS hast wrong IPv6 defined", func() { + It("should return error", func() { + cfg := Config{} + data := `customDNS: + mapping: + someDomain: 2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334` + err := unmarshalConfig(logger, []byte(data), &cfg) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("invalid IP address '2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334'")) + }) + }) When("Conditional mapping hast wrong defined upstreams", func() { It("should return error", func() { cfg := Config{} @@ -866,12 +877,24 @@ func defaultTestFileConfig(config *Config) { Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8")) Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4")) Expect(config.Upstreams.Groups["default"][2].Host).Should(Equal("1.1.1.1")) - Expect(config.CustomDNS.Mapping.HostIPs).Should(HaveLen(2)) - Expect(config.CustomDNS.Mapping.HostIPs["my.duckdns.org"][0]).Should(Equal(net.ParseIP("192.168.178.3"))) - Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][0]).Should(Equal(net.ParseIP("192.168.178.3"))) - Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][1]).Should(Equal(net.ParseIP("192.168.178.4"))) - Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][2]).Should(Equal( - net.ParseIP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344"))) + Expect(config.CustomDNS.Mapping).Should(HaveLen(2)) + + duckDNSEntry := config.CustomDNS.Mapping["my.duckdns.org"][0] + duckDNSA := duckDNSEntry.(*dns.A) + Expect(duckDNSA.A).Should(Equal(net.ParseIP("192.168.178.3"))) + + multipleIpsEntry := config.CustomDNS.Mapping["multiple.ips"][0] + multipleIpsA := multipleIpsEntry.(*dns.A) + Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.3"))) + + multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][1] + multipleIpsA = multipleIpsEntry.(*dns.A) + Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.4"))) + + multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][2] + multipleIpsAAAA := multipleIpsEntry.(*dns.AAAA) + Expect(multipleIpsAAAA.AAAA).Should(Equal(net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7344"))) + Expect(config.Conditional.Mapping.Upstreams).Should(HaveLen(2)) Expect(config.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1)) Expect(config.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2)) diff --git a/config/custom_dns.go b/config/custom_dns.go index 08b7f1a7..23e0183e 100644 --- a/config/custom_dns.go +++ b/config/custom_dns.go @@ -5,6 +5,7 @@ import ( "net" "strings" + "github.com/miekg/dns" "github.com/sirupsen/logrus" ) @@ -16,14 +17,45 @@ type CustomDNS struct { FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` } -// CustomDNSMapping mapping for the custom DNS configuration -type CustomDNSMapping struct { - HostIPs map[string][]net.IP `yaml:"hostIPs"` +type ( + CustomDNSMapping map[string]CustomDNSEntries + CustomDNSEntries []dns.RR +) + +func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error { + var input string + if err := unmarshal(&input); err != nil { + return err + } + + parts := strings.Split(input, ",") + result := make(CustomDNSEntries, len(parts)) + containsCNAME := false + + for i, part := range parts { + rr, err := configToRR(part) + if err != nil { + return err + } + + _, isCNAME := rr.(*dns.CNAME) + containsCNAME = containsCNAME || isCNAME + + result[i] = rr + } + + if containsCNAME && len(result) > 1 { + return fmt.Errorf("when a CNAME record is present, it must be the only record in the mapping") + } + + *c = result + + return nil } // IsEnabled implements `config.Configurable`. func (c *CustomDNS) IsEnabled() bool { - return len(c.Mapping.HostIPs) != 0 + return len(c.Mapping) != 0 } // LogConfig implements `config.Configurable`. @@ -33,36 +65,52 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) { logger.Info("mapping:") - for key, val := range c.Mapping.HostIPs { + for key, val := range c.Mapping { logger.Infof(" %s = %s", key, val) } } -// UnmarshalYAML implements `yaml.Unmarshaler`. -func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input map[string]string - if err := unmarshal(&input); err != nil { - return err +func removePrefixSuffix(in, prefix string) string { + in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix)) + in = strings.TrimSuffix(in, ")") + + return strings.TrimSpace(in) +} + +func configToRR(part string) (dns.RR, error) { + if strings.HasPrefix(part, "CNAME(") { + domain := removePrefixSuffix(part, "CNAME") + domain = dns.Fqdn(domain) + cname := &dns.CNAME{Target: domain} + + return cname, nil } - result := make(map[string][]net.IP, len(input)) + // Fall back to A/AAAA records to maintain backwards compatibility in config.yml + // We will still remove the A() or AAAA() if it exists + if strings.Contains(part, ".") { // IPV4 address + ipStr := removePrefixSuffix(part, "A") + ip := net.ParseIP(ipStr) - for k, v := range input { - var ips []net.IP - - for _, part := range strings.Split(v, ",") { - ip := net.ParseIP(strings.TrimSpace(part)) - if ip == nil { - return fmt.Errorf("invalid IP address '%s'", part) - } - - ips = append(ips, ip) + if ip == nil { + return nil, fmt.Errorf("invalid IP address '%s'", part) } - result[k] = ips + a := new(dns.A) + a.A = ip + + return a, nil + } else { // IPV6 address + ipStr := removePrefixSuffix(part, "AAAA") + ip := net.ParseIP(ipStr) + + if ip == nil { + return nil, fmt.Errorf("invalid IP address '%s'", part) + } + + aaaa := new(dns.AAAA) + aaaa.AAAA = ip + + return aaaa, nil } - - c.HostIPs = result - - return nil } diff --git a/config/custom_dns_test.go b/config/custom_dns_test.go index 9d83b2bd..d92b2e2c 100644 --- a/config/custom_dns_test.go +++ b/config/custom_dns_test.go @@ -5,6 +5,7 @@ import ( "net" "github.com/creasty/defaults" + "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -17,15 +18,14 @@ var _ = Describe("CustomDNSConfig", func() { BeforeEach(func() { cfg = CustomDNS{ Mapping: CustomDNSMapping{ - HostIPs: map[string][]net.IP{ - "custom.domain": {net.ParseIP("192.168.143.123")}, - "ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, - "multiple.ips": { - net.ParseIP("192.168.143.123"), - net.ParseIP("192.168.143.125"), - net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), - }, + "custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}}, + "ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}}, + "multiple.ips": { + &dns.A{A: net.ParseIP("192.168.143.123")}, + &dns.A{A: net.ParseIP("192.168.143.125")}, + &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, }, + "cname.domain": {&dns.CNAME{Target: "custom.domain"}}, }, } }) @@ -60,27 +60,41 @@ var _ = Describe("CustomDNSConfig", func() { Expect(hook.Calls).ShouldNot(BeEmpty()) Expect(hook.Messages).Should(ContainElements( ContainSubstring("custom.domain = "), + ContainSubstring("ip6.domain = "), ContainSubstring("multiple.ips = "), + ContainSubstring("cname.domain = "), )) }) }) Describe("UnmarshalYAML", func() { It("Should parse config as map", func() { - c := &CustomDNSMapping{} + c := CustomDNSEntries{} err := c.UnmarshalYAML(func(i interface{}) error { - *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} + *i.(*string) = "1.2.3.4" return nil }) Expect(err).Should(Succeed()) - Expect(c.HostIPs).Should(HaveLen(1)) - Expect(c.HostIPs["key"]).Should(HaveLen(1)) - Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4"))) + Expect(c).Should(HaveLen(1)) + + aRecord := c[0].(*dns.A) + Expect(aRecord.A).Should(Equal(net.ParseIP("1.2.3.4"))) + }) + + It("Should return an error if a CNAME is accomanied by any other record", func() { + c := CustomDNSEntries{} + err := c.UnmarshalYAML(func(i interface{}) error { + *i.(*string) = "CNAME(example.com),A(1.2.3.4)" + + return nil + }) + Expect(err).Should(HaveOccurred()) + Expect(err).Should(MatchError("when a CNAME record is present, it must be the only record in the mapping")) }) It("should fail if wrong YAML format", func() { - c := &CustomDNSMapping{} + c := &CustomDNSEntries{} err := c.UnmarshalYAML(func(i interface{}) error { return errors.New("some err") }) diff --git a/docs/config.yml b/docs/config.yml index 8b583d48..4d90226f 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -47,6 +47,7 @@ customDNS: example.com: printer.lan mapping: printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 + second-printer-address.lan: CNAME(printer.lan) # optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma # Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name diff --git a/docs/configuration.md b/docs/configuration.md index 5f24b924..705652f3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -259,12 +259,12 @@ You can define your own domain name to IP mappings. For example, you can use a u or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one domain must be separated by a comma. -| Parameter | Type | Mandatory | Default value | -| ------------------- | --------------------------------------- | --------- | ------------- | -| customTTL | duration (no unit is minutes) | no | 1h | -| rewrite | string: string (domain: domain) | no | | -| mapping | string: string (hostname: address list) | no | | -| filterUnmappedTypes | boolean | no | true | +| Parameter | Type | Mandatory | Default value | +| ------------------- | ------------------------------------------- | --------- | ------------- | +| customTTL | duration (no unit is minutes) | no | 1h | +| rewrite | string: string (domain: domain) | no | | +| mapping | string: string (hostname: address or CNAME) | no | | +| filterUnmappedTypes | boolean | no | true | !!! example @@ -278,11 +278,14 @@ domain must be separated by a comma. mapping: printer.lan: 192.168.178.3 otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 + anothername.lan: CNAME(otherdevice.lan) ``` This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of `printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3. +CNAME records are supported by setting the value of the mapping to `CNAME(target)`. Note that the target will be recursively resolved and will return an error if a loop is detected. + With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the resolver lookup is performed. The query "printer.home" will be rewritten to "printer.lan" and return 192.168.178.3. diff --git a/e2e/containers.go b/e2e/containers.go index 78ae85fd..e57b7d2d 100644 --- a/e2e/containers.go +++ b/e2e/containers.go @@ -66,7 +66,7 @@ func createDNSMokkaContainer(ctx context.Context, alias string, rules ...string) // createHTTPServerContainer creates a static HTTP server container that serves one file with the given lines // and is attached to the test network under the given alias. // It is automatically terminated when the test is finished. -func createHTTPServerContainer(ctx context.Context, alias string, filename string, lines ...string, +func createHTTPServerContainer(ctx context.Context, alias, filename string, lines ...string, ) (testcontainers.Container, error) { file := createTempFile(lines...) diff --git a/helpertest/helper.go b/helpertest/helper.go index 97caa50e..29b0421a 100644 --- a/helpertest/helper.go +++ b/helpertest/helper.go @@ -212,6 +212,8 @@ func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error return v.A.String() == matcher.answer, nil case *dns.AAAA: return v.AAAA.String() == matcher.answer, nil + case *dns.CNAME: + return v.Target == matcher.answer, nil case *dns.PTR: return v.Ptr == matcher.answer, nil case *dns.MX: diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index 422ce45c..d96ceb64 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -2,7 +2,9 @@ package resolver import ( "context" + "fmt" "net" + "slices" "strings" "github.com/0xERR0R/blocky/config" @@ -14,27 +16,41 @@ import ( "github.com/sirupsen/logrus" ) +type createAnswerFunc func(question dns.Question, ip net.IP, ttl uint32) (dns.RR, error) + // CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map type CustomDNSResolver struct { configurable[*config.CustomDNS] NextResolver typed - mapping map[string][]net.IP - reverseAddresses map[string][]string + createAnswerFromQuestion createAnswerFunc + mapping config.CustomDNSMapping + reverseAddresses map[string][]string } // NewCustomDNSResolver creates new resolver instance func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver { - m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs)) - reverse := make(map[string][]string, len(cfg.Mapping.HostIPs)) + m := make(config.CustomDNSMapping, len(cfg.Mapping)) + reverse := make(map[string][]string, len(cfg.Mapping)) - for url, ips := range cfg.Mapping.HostIPs { - m[strings.ToLower(url)] = ips + for url, entries := range cfg.Mapping { + m[strings.ToLower(url)] = entries - for _, ip := range ips { - r, _ := dns.ReverseAddr(ip.String()) - reverse[r] = append(reverse[r], url) + for _, entry := range entries { + a, isA := entry.(*dns.A) + + if isA { + r, _ := dns.ReverseAddr(a.A.String()) + reverse[r] = append(reverse[r], url) + } + + aaaa, isAAAA := entry.(*dns.AAAA) + + if isAAAA { + r, _ := dns.ReverseAddr(aaaa.AAAA.String()) + reverse[r] = append(reverse[r], url) + } } } @@ -42,8 +58,9 @@ func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver { configurable: withConfig(&cfg), typed: withType("custom_dns"), - mapping: m, - reverseAddresses: reverse, + createAnswerFromQuestion: util.CreateAnswerFromQuestion, + mapping: m, + reverseAddresses: reverse, } } @@ -75,7 +92,26 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp return nil } -func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response { +func (r *CustomDNSResolver) forwardResponse( + logger *logrus.Entry, + ctx context.Context, + request *model.Request, +) (*model.Response, error) { + logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") + + forwardResponse, err := r.next.Resolve(ctx, request) + if err != nil { + return nil, err + } + + return forwardResponse, nil +} + +func (r *CustomDNSResolver) processRequest( + ctx context.Context, + request *model.Request, + resolvedCnames []string, +) (*model.Response, error) { logger := log.WithPrefix(request.Log, "custom_dns_resolver") response := new(dns.Msg) @@ -85,13 +121,20 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon domain := util.ExtractDomain(question) for len(domain) > 0 { - ips, found := r.mapping[domain] + if err := ctx.Err(); err != nil { + return nil, err + } + + entries, found := r.mapping[domain] + if found { - for _, ip := range ips { - if isSupportedType(ip, question) { - rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32()) - response.Answer = append(response.Answer, rr) + for _, entry := range entries { + result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry) + if err != nil { + return nil, err } + + response.Answer = append(response.Answer, result...) } if len(response.Answer) > 0 { @@ -100,7 +143,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon "domain": domain, }).Debugf("returning custom dns entry") - return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} + return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil } // Mapping exists for this domain, but for another type @@ -110,7 +153,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon } // return NOERROR with empty result - return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} + return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil } if i := strings.Index(domain, "."); i >= 0 { @@ -120,26 +163,99 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon } } - return nil + return r.forwardResponse(logger, ctx, request) +} + +func (r *CustomDNSResolver) processDNSEntry( + ctx context.Context, + request *model.Request, + resolvedCnames []string, + question dns.Question, + entry dns.RR, +) ([]dns.RR, error) { + switch v := entry.(type) { + case *dns.A: + return r.processIP(v.A, question) + case *dns.AAAA: + return r.processIP(v.AAAA, question) + case *dns.CNAME: + return r.processCNAME(ctx, request, *v, resolvedCnames, question) + } + + return nil, fmt.Errorf("unsupported customDNS RR type %T", entry) } // Resolve uses internal mapping to resolve the query func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, "custom_dns_resolver") - reverseResp := r.handleReverseDNS(request) if reverseResp != nil { return reverseResp, nil } - if len(r.mapping) > 0 { - resp := r.processRequest(request) - if resp != nil { - return resp, nil - } + resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping))) + if err != nil { + return nil, err } - logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") - - return r.next.Resolve(ctx, request) + return resp, nil +} + +func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question) (result []dns.RR, err error) { + result = make([]dns.RR, 0) + + if isSupportedType(ip, question) { + rr, err := r.createAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32()) + if err != nil { + return nil, err + } + + result = append(result, rr) + } + + return result, nil +} + +func (r *CustomDNSResolver) processCNAME( + ctx context.Context, + request *model.Request, + targetCname dns.CNAME, + resolvedCnames []string, + question dns.Question, +) (result []dns.RR, err error) { + cname := new(dns.CNAME) + ttl := r.cfg.CustomTTL.SecondsU32() + cname.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeCNAME, Name: question.Name} + cname.Target = dns.Fqdn(targetCname.Target) + result = append(result, cname) + + if question.Qtype == dns.TypeCNAME { + return result, nil + } + + targetWithoutDot := strings.TrimSuffix(targetCname.Target, ".") + + if slices.Contains(resolvedCnames, targetWithoutDot) { + return nil, fmt.Errorf("CNAME loop detected: %v", append(resolvedCnames, targetWithoutDot)) + } + + cnames := resolvedCnames + cnames = append(cnames, targetWithoutDot) + + clientIP := request.ClientIP.String() + clientID := request.RequestClientID + targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID) + + // resolve the target recursively + targetResp, err := r.processRequest(ctx, targetRequest, cnames) + if err != nil { + return nil, err + } + + result = append(result, targetResp.Res.Answer...) + + return result, nil +} + +func (r *CustomDNSResolver) CreateAnswerFromQuestion(newFunc createAnswerFunc) { + r.createAnswerFromQuestion = newFunc } diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index dc359bf7..6cb1a308 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "fmt" "net" "time" @@ -38,15 +39,20 @@ var _ = Describe("CustomDNSResolver", func() { DeferCleanup(cancelFn) cfg = config.CustomDNS{ - Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{ - "custom.domain": {net.ParseIP("192.168.143.123")}, - "ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + Mapping: config.CustomDNSMapping{ + "custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}}, + "ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}}, "multiple.ips": { - net.ParseIP("192.168.143.123"), - net.ParseIP("192.168.143.125"), - net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + &dns.A{A: net.ParseIP("192.168.143.123")}, + &dns.A{A: net.ParseIP("192.168.143.125")}, + &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, }, - }}, + "cname.domain": {&dns.CNAME{Target: "custom.domain"}}, + "cname.ip6": {&dns.CNAME{Target: "ip6.domain"}}, + "cname.example": {&dns.CNAME{Target: "example.com"}}, + "cname.recursive": {&dns.CNAME{Target: "cname.recursive"}}, + "mx.domain": {&dns.MX{Mx: "mx.domain"}}, + }, CustomTTL: config.Duration(time.Duration(TTL) * time.Second), FilterUnmappedTypes: true, } @@ -76,6 +82,57 @@ var _ = Describe("CustomDNSResolver", func() { }) Describe("Resolving custom name via CustomDNSResolver", func() { + When("The parent context has an error ", func() { + It("should return the error", func() { + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := sut.Resolve(cancelledCtx, newRequest("custom.domain.", A)) + + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("context canceled")) + }) + }) + When("Creating the IP response returns an error ", func() { + It("should return the error", func() { + createAnswerMock := func(_ dns.Question, _ net.IP, _ uint32) (dns.RR, error) { + return nil, fmt.Errorf("create answer error") + } + + sut.CreateAnswerFromQuestion(createAnswerMock) + + _, err := sut.Resolve(ctx, newRequest("custom.domain.", A)) + + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("create answer error")) + }) + }) + When("The forward request returns an error ", func() { + It("should return the error if the error occurs when checking ipv4 forward addresses", func() { + err := fmt.Errorf("forward error") + m = &mockResolver{} + + m.On("Resolve", mock.Anything).Return(nil, err) + + sut.Next(m) + _, err = sut.Resolve(ctx, newRequest("cname.example.", A)) + + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("forward error")) + }) + It("should return the error if the error occurs when checking ipv6 forward addresses", func() { + err := fmt.Errorf("forward error") + m = &mockResolver{} + + m.On("Resolve", mock.Anything).Return(nil, err) + + sut.Next(m) + _, err = sut.Resolve(ctx, newRequest("cname.example.", AAAA)) + + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("forward error")) + }) + }) When("Ip 4 mapping is defined for custom domain and", func() { Context("filterUnmappedTypes is true", func() { BeforeEach(func() { cfg.FilterUnmappedTypes = true }) @@ -211,6 +268,101 @@ var _ = Describe("CustomDNSResolver", func() { }) }) }) + When("A CNAME record is defined for custom domain ", func() { + It("should not recurse if the request is strictly a CNAME request", func() { + By("CNAME query", func() { + Expect(sut.Resolve(ctx, newRequest("cname.domain", CNAME))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + HaveLen(1), + ContainElements( + BeDNSRecord("cname.domain.", CNAME, "custom.domain.")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + + // will not delegate to next resolver + m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) + }) + }) + It("all CNAMES for the current type should be recursively resolved when relying on other Mappings", func() { + By("A query", func() { + Expect(sut.Resolve(ctx, newRequest("cname.domain", A))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + HaveLen(2), + ContainElements( + BeDNSRecord("cname.domain.", CNAME, "custom.domain."), + BeDNSRecord("custom.domain.", A, "192.168.143.123")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + + // will not delegate to next resolver + m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) + }) + By("AAAA query", func() { + Expect(sut.Resolve(ctx, newRequest("cname.ip6", AAAA))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + HaveLen(2), + ContainElements( + BeDNSRecord("cname.ip6.", CNAME, "ip6.domain."), + BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + + // will not delegate to next resolver + m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) + }) + }) + It("should return an error when the CNAME is recursive", func() { + By("CNAME query", func() { + _, err := sut.Resolve(ctx, newRequest("cname.recursive", A)) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("CNAME loop detected:")) + // will not delegate to next resolver + m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) + }) + }) + It("all CNAMES for the current type should be returned when relying on public DNS", func() { + By("CNAME query", func() { + Expect(sut.Resolve(ctx, newRequest("cname.example", A))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + ContainElements( + BeDNSRecord("cname.example.", CNAME, "example.com.")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + + // will delegate to next resolver + m.AssertCalled(GinkgoT(), "Resolve", mock.Anything) + }) + }) + }) + When("An unsupported DNS query type is queried from the resolver but found in the config mapping ", func() { + It("an error should be returned", func() { + By("MX query", func() { + _, err := sut.Resolve(ctx, newRequest("mx.domain", MX)) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("unsupported customDNS RR type *dns.MX")) + }) + }) + }) When("Reverse DNS request is received", func() { It("should resolve the defined domain name", func() { By("ipv4", func() { diff --git a/server/server.go b/server/server.go index a4f3af3c..3780822b 100644 --- a/server/server.go +++ b/server/server.go @@ -613,6 +613,13 @@ func (s *Server) OnRequest( func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) { var response *model.Response + contextUpstreamTimeoutMultiplier := 100 + timeoutDuration := time.Duration(contextUpstreamTimeoutMultiplier) * s.cfg.Upstreams.Timeout.ToDuration() + + ctx, cancel := context.WithTimeout(ctx, timeoutDuration) + + defer cancel() + switch { case len(request.Req.Question) == 0: m := new(dns.Msg) diff --git a/server/server_test.go b/server/server_test.go index 6a3f56e2..85ea63a5 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -104,10 +104,8 @@ var _ = BeforeSuite(func() { CustomDNS: config.CustomDNS{ CustomTTL: config.Duration(3600 * time.Second), Mapping: config.CustomDNSMapping{ - HostIPs: map[string][]net.IP{ - "custom.lan": {net.ParseIP("192.168.178.55")}, - "lan.home": {net.ParseIP("192.168.178.56")}, - }, + "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}}, + "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}}, }, }, Conditional: config.ConditionalUpstream{ @@ -596,10 +594,8 @@ var _ = Describe("Running DNS server", func() { }, CustomDNS: config.CustomDNS{ Mapping: config.CustomDNSMapping{ - HostIPs: map[string][]net.IP{ - "custom.lan": {net.ParseIP("192.168.178.55")}, - "lan.home": {net.ParseIP("192.168.178.56")}, - }, + "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}}, + "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}}, }, }, Blocking: config.Blocking{BlockType: "zeroIp"}, @@ -642,10 +638,8 @@ var _ = Describe("Running DNS server", func() { }, CustomDNS: config.CustomDNS{ Mapping: config.CustomDNSMapping{ - HostIPs: map[string][]net.IP{ - "custom.lan": {net.ParseIP("192.168.178.55")}, - "lan.home": {net.ParseIP("192.168.178.56")}, - }, + "custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}}, + "lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}}, }, }, Blocking: config.Blocking{BlockType: "zeroIp"},