feat: Support defining records by dns zone format (#1360)

* feat: Support zonefile configuration for custom dns mapping

* docs: Update configuration.md

* Rename var to ok

* Linter fixes

* Remove hashes in test describe description

* Implement PR comments; zoneFileMapping -> zone, initialize with proper sizes

* Remove custom CNAME parsing

* Utilize TTL defined in zone file

* Link to wikipedia's example file

* Test to confirm that a relative zone entry without an $ORIGIN returns an error

* Write a test covering the $INCLUDE directive

* Write a test confirming that a dns zone can result in more than 1 RR

* Linting

* fix: Use proper matchers in CustomDNS Zone tests; Update configuration.md description

* Pull in config directory to support relative $INCLUDE

* Added tests to ensure the ability to use both bare filenames as well as relative filenames when using the $INCLUDE directive

* Shorten test description (Linting error)

* Move Assignment of z.RRs to the end of the UnmarshallYAML function

* Moved tests for relative $INCLUDE zones to config_test. Added test case when config param passed to blocky is a directory

* Corrected test case to _actually_ test againt bare file names
This commit is contained in:
Ben 2024-02-09 10:28:58 -06:00 committed by GitHub
parent 178dbb740e
commit 9f633f18d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 362 additions and 84 deletions

View File

@ -455,21 +455,30 @@ func loadConfig(logger *logrus.Entry, path string, mandatory bool) (rCfg *Config
return nil, fmt.Errorf("can't read config file(s): %w", err) return nil, fmt.Errorf("can't read config file(s): %w", err)
} }
var data []byte var (
data []byte
prettyPath string
)
if fs.IsDir() { if fs.IsDir() {
prettyPath = filepath.Join(path, "*")
data, err = readFromDir(path, data) data, err = readFromDir(path, data)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't read config files: %w", err) return nil, fmt.Errorf("can't read config files: %w", err)
} }
} else { } else {
prettyPath = path
data, err = os.ReadFile(path) data, err = os.ReadFile(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't read config file: %w", err) return nil, fmt.Errorf("can't read config file: %w", err)
} }
} }
cfg.CustomDNS.Zone.configPath = prettyPath
err = unmarshalConfig(logger, data, &cfg) err = unmarshalConfig(logger, data, &cfg)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -164,6 +164,94 @@ var _ = Describe("Config", func() {
defaultTestFileConfig(c) defaultTestFileConfig(c)
}) })
}) })
When("Test config file contains a zone file with $INCLUDE", func() {
When("The config path is set to the config file", func() {
It("Should support the $INCLUDE directive with a bare filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
cfgFile := writeConfigYmlWithLocalZoneFile(folder, "other.zone")
c, err = LoadConfig(cfgFile.Path, true)
Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))
Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
It("Should support the $INCLUDE directive with a relative filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
cfgFile := writeConfigYmlWithLocalZoneFile(folder, "./other.zone")
c, err = LoadConfig(cfgFile.Path, true)
Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))
Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
})
When("The config path is set to a directory", func() {
It("Should support the $INCLUDE directive with a bare filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
writeConfigYmlWithLocalZoneFile(folder, "other.zone")
c, err = LoadConfig(folder.Path, true)
Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))
Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
It("Should support the $INCLUDE directive with a relative filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
writeConfigYmlWithLocalZoneFile(folder, "./other.zone")
c, err = LoadConfig(folder.Path, true)
Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))
Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
})
})
When("Test file does not exist", func() { When("Test file does not exist", func() {
It("should fail", func() { It("should fail", func() {
_, err := LoadConfig(tmpDir.JoinPath("config-does-not-exist.yaml"), true) _, err := LoadConfig(tmpDir.JoinPath("config-does-not-exist.yaml"), true)
@ -977,6 +1065,33 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
) )
} }
func writeConfigYmlWithLocalZoneFile(tmpDir *helpertest.TmpFolder, includeStr string) *helpertest.TmpFile {
return tmpDir.CreateStringFile("config.yml",
"upstreams:",
" userAgent: testBlocky",
" init:",
" strategy: failOnError",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
" - tcp+udp:8.8.4.4",
" - 1.1.1.1",
"customDNS:",
" zone: |",
" $ORIGIN example.com.",
" $INCLUDE "+includeStr,
"filtering:",
" queryTypes:",
" - AAAA",
" - A",
"fqdnOnly:",
" enable: true",
"port: 55553,:55554,[::1]:55555",
"logLevel: debug",
"minTlsServeVersion: 1.3",
)
}
func writeConfigDir(tmpDir *helpertest.TmpFolder) { func writeConfigDir(tmpDir *helpertest.TmpFolder) {
tmpDir.CreateStringFile("config1.yaml", tmpDir.CreateStringFile("config1.yaml",
"upstreams:", "upstreams:",

View File

@ -14,14 +14,57 @@ type CustomDNS struct {
RewriterConfig `yaml:",inline"` RewriterConfig `yaml:",inline"`
CustomTTL Duration `yaml:"customTTL" default:"1h"` CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"` Mapping CustomDNSMapping `yaml:"mapping"`
Zone ZoneFileDNS `yaml:"zone" default:""`
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
} }
type ( type (
CustomDNSMapping map[string]CustomDNSEntries CustomDNSMapping map[string]CustomDNSEntries
CustomDNSEntries []dns.RR CustomDNSEntries []dns.RR
ZoneFileDNS struct {
RRs CustomDNSMapping
configPath string
}
) )
func (z *ZoneFileDNS) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}
result := make(CustomDNSMapping)
zoneParser := dns.NewZoneParser(strings.NewReader(input), "", z.configPath)
zoneParser.SetIncludeAllowed(true)
for {
zoneRR, ok := zoneParser.Next()
if !ok {
if zoneParser.Err() != nil {
return zoneParser.Err()
}
// Done
break
}
domain := zoneRR.Header().Name
if _, ok := result[domain]; !ok {
result[domain] = make(CustomDNSEntries, 0, 1)
}
result[domain] = append(result[domain], zoneRR)
}
z.RRs = result
return nil
}
func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error { func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string var input string
if err := unmarshal(&input); err != nil { if err := unmarshal(&input); err != nil {
@ -30,7 +73,6 @@ func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) erro
parts := strings.Split(input, ",") parts := strings.Split(input, ",")
result := make(CustomDNSEntries, len(parts)) result := make(CustomDNSEntries, len(parts))
containsCNAME := false
for i, part := range parts { for i, part := range parts {
rr, err := configToRR(part) rr, err := configToRR(part)
@ -38,16 +80,9 @@ func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) erro
return err return err
} }
_, isCNAME := rr.(*dns.CNAME)
containsCNAME = containsCNAME || isCNAME
result[i] = rr result[i] = rr
} }
if containsCNAME && len(result) > 1 {
return fmt.Errorf("when a CNAME record is present, it must be the only record in the mapping")
}
*c = result *c = result
return nil return nil
@ -70,47 +105,21 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
} }
} }
func removePrefixSuffix(in, prefix string) string { func configToRR(ipStr string) (dns.RR, error) {
in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix)) ip := net.ParseIP(ipStr)
in = strings.TrimSuffix(in, ")") if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", ipStr)
return strings.TrimSpace(in)
}
func configToRR(part string) (dns.RR, error) {
if strings.HasPrefix(part, "CNAME(") {
domain := removePrefixSuffix(part, "CNAME")
domain = dns.Fqdn(domain)
cname := &dns.CNAME{Target: domain}
return cname, nil
} }
// Fall back to A/AAAA records to maintain backwards compatibility in config.yml if ip.To4() != nil {
// We will still remove the A() or AAAA() if it exists
if strings.Contains(part, ".") { // IPV4 address
ipStr := removePrefixSuffix(part, "A")
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", part)
}
a := new(dns.A) a := new(dns.A)
a.A = ip a.A = ip
return a, nil return a, nil
} else { // IPV6 address
ipStr := removePrefixSuffix(part, "AAAA")
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", part)
}
aaaa := new(dns.AAAA)
aaaa.AAAA = ip
return aaaa, nil
} }
aaaa := new(dns.AAAA)
aaaa.AAAA = ip
return aaaa, nil
} }

View File

@ -2,8 +2,11 @@ package config
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"strings"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/creasty/defaults" "github.com/creasty/defaults"
"github.com/miekg/dns" "github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -25,7 +28,6 @@ var _ = Describe("CustomDNSConfig", func() {
&dns.A{A: net.ParseIP("192.168.143.125")}, &dns.A{A: net.ParseIP("192.168.143.125")},
&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
}, },
"cname.domain": {&dns.CNAME{Target: "custom.domain"}},
}, },
} }
}) })
@ -62,12 +64,11 @@ var _ = Describe("CustomDNSConfig", func() {
ContainSubstring("custom.domain = "), ContainSubstring("custom.domain = "),
ContainSubstring("ip6.domain = "), ContainSubstring("ip6.domain = "),
ContainSubstring("multiple.ips = "), ContainSubstring("multiple.ips = "),
ContainSubstring("cname.domain = "),
)) ))
}) })
}) })
Describe("UnmarshalYAML", func() { Describe("CustomDNSEntries UnmarshalYAML", func() {
It("Should parse config as map", func() { It("Should parse config as map", func() {
c := CustomDNSEntries{} c := CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error { err := c.UnmarshalYAML(func(i interface{}) error {
@ -82,17 +83,6 @@ var _ = Describe("CustomDNSConfig", func() {
Expect(aRecord.A).Should(Equal(net.ParseIP("1.2.3.4"))) Expect(aRecord.A).Should(Equal(net.ParseIP("1.2.3.4")))
}) })
It("Should return an error if a CNAME is accomanied by any other record", func() {
c := CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "CNAME(example.com),A(1.2.3.4)"
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("when a CNAME record is present, it must be the only record in the mapping"))
})
It("should fail if wrong YAML format", func() { It("should fail if wrong YAML format", func() {
c := &CustomDNSEntries{} c := &CustomDNSEntries{}
err := c.UnmarshalYAML(func(i interface{}) error { err := c.UnmarshalYAML(func(i interface{}) error {
@ -102,4 +92,116 @@ var _ = Describe("CustomDNSConfig", func() {
Expect(err).Should(MatchError("some err")) Expect(err).Should(MatchError("some err"))
}) })
}) })
Describe("ZoneFileDNS UnmarshalYAML", func() {
It("Should parse config as map", func() {
z := ZoneFileDNS{}
err := z.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = strings.TrimSpace(`
$ORIGIN example.com.
www 3600 A 1.2.3.4
www 3600 AAAA 2001:0db8:85a3:0000:0000:8a2e:0370:7334
www6 3600 AAAA 2001:0db8:85a3:0000:0000:8a2e:0370:7334
cname 3600 CNAME www
`)
return nil
})
Expect(err).Should(Succeed())
Expect(z.RRs).Should(HaveLen(3))
Expect(z.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(2),
ContainElements(
SatisfyAll(
BeDNSRecord("www.example.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 3600)),
),
SatisfyAll(
BeDNSRecord("www.example.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
HaveTTL(BeNumerically("==", 3600)),
))))
Expect(z.RRs["www6.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
BeDNSRecord("www6.example.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
HaveTTL(BeNumerically("==", 3600)),
))))
Expect(z.RRs["cname.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
BeDNSRecord("cname.example.com.", CNAME, "www.example.com."),
HaveTTL(BeNumerically("==", 3600)),
))))
})
It("Should support the $INCLUDE directive with an absolute path", func() {
folder := NewTmpFolder("zones")
file := folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
z := ZoneFileDNS{}
err := z.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = strings.TrimSpace(`
$ORIGIN example.com.
$INCLUDE ` + file.Path)
return nil
})
Expect(err).Should(Succeed())
Expect(z.RRs).Should(HaveLen(1))
Expect(z.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
BeDNSRecord("www.example.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 3600)),
)),
))
})
It("Should return an error if the zone file is malformed", func() {
z := ZoneFileDNS{}
err := z.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = strings.TrimSpace(`
$ORIGIN example.com.
www A 1.2.3.4
`)
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("dns: missing TTL with no previous value"))
})
It("Should return an error if a relative record is provided without an origin", func() {
z := ZoneFileDNS{}
err := z.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = strings.TrimSpace(`
$TTL 3600
www A 1.2.3.4
`)
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("dns: bad owner name: \"www\""))
})
It("Should return an error if the unmarshall function returns an error", func() {
z := ZoneFileDNS{}
err := z.UnmarshalYAML(func(i interface{}) error {
return fmt.Errorf("Failed to unmarshal")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("Failed to unmarshal"))
})
})
}) })

View File

@ -47,7 +47,6 @@ customDNS:
example.com: printer.lan example.com: printer.lan
mapping: mapping:
printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
second-printer-address.lan: CNAME(printer.lan)
# optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma # optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma
# Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name # Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name

View File

@ -259,12 +259,13 @@ You can define your own domain name to IP mappings. For example, you can use a u
or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one
domain must be separated by a comma. domain must be separated by a comma.
| Parameter | Type | Mandatory | Default value | | Parameter | Type | Mandatory | Default value |
| ------------------- | ------------------------------------------- | --------- | ------------- | | ------------------- | ------------------------------------------------------ | --------- | ------------- |
| customTTL | duration (no unit is minutes) | no | 1h | | customTTL | duration used for simple mappings (no unit is minutes) | no | 1h |
| rewrite | string: string (domain: domain) | no | | | rewrite | string: string (domain: domain) | no | |
| mapping | string: string (hostname: address or CNAME) | no | | | mapping | string: string (hostname: address or CNAME) | no | |
| filterUnmappedTypes | boolean | no | true | | zone | string containing a DNS Zone | no | |
| filterUnmappedTypes | boolean | no | true |
!!! example !!! example
@ -278,13 +279,22 @@ domain must be separated by a comma.
mapping: mapping:
printer.lan: 192.168.178.3 printer.lan: 192.168.178.3
otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344 otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
anothername.lan: CNAME(otherdevice.lan) zone: |
$ORIGIN example.com.
www 3600 A 1.2.3.4
@ 3600 CNAME www
``` ```
This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of
`printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3. `printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3.
CNAME records are supported by setting the value of the mapping to `CNAME(target)`. Note that the target will be recursively resolved and will return an error if a loop is detected. CNAME records are supported by utilizing the `zone` parameter. The zone file is a multiline string containing a [DNS Zone File](https://en.wikipedia.org/wiki/Zone_file#Example_file).
For records defined using the `zone` parameter, the `customTTL` parameter is unused. Instead, the TTL is defined in the zone directly.
The following directives are supported in the zone file:
* `$ORIGIN` - sets the origin for relative domain names
* `$TTL` - sets the default TTL for records in the zone
* `$INCLUDE` - includes another zone file relative to the blocky executable
* `$GENERATE` - generates a range of records
With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the
resolver lookup is performed. resolver lookup is performed.

View File

@ -31,12 +31,25 @@ type CustomDNSResolver struct {
// NewCustomDNSResolver creates new resolver instance // NewCustomDNSResolver creates new resolver instance
func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver { func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
m := make(config.CustomDNSMapping, len(cfg.Mapping)) dnsRecords := make(config.CustomDNSMapping, len(cfg.Mapping)+len(cfg.Zone.RRs))
reverse := make(map[string][]string, len(cfg.Mapping))
for url, entries := range cfg.Mapping { for url, entries := range cfg.Mapping {
m[strings.ToLower(url)] = entries url = util.ExtractDomainOnly(url)
dnsRecords[url] = entries
for _, entry := range entries {
entry.Header().Ttl = cfg.CustomTTL.SecondsU32()
}
}
for url, entries := range cfg.Zone.RRs {
url = util.ExtractDomainOnly(url)
dnsRecords[url] = entries
}
reverse := make(map[string][]string, len(dnsRecords))
for url, entries := range dnsRecords {
for _, entry := range entries { for _, entry := range entries {
a, isA := entry.(*dns.A) a, isA := entry.(*dns.A)
@ -59,7 +72,7 @@ func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
typed: withType("custom_dns"), typed: withType("custom_dns"),
createAnswerFromQuestion: util.CreateAnswerFromQuestion, createAnswerFromQuestion: util.CreateAnswerFromQuestion,
mapping: m, mapping: dnsRecords,
reverseAddresses: reverse, reverseAddresses: reverse,
} }
} }
@ -175,11 +188,11 @@ func (r *CustomDNSResolver) processDNSEntry(
) ([]dns.RR, error) { ) ([]dns.RR, error) {
switch v := entry.(type) { switch v := entry.(type) {
case *dns.A: case *dns.A:
return r.processIP(v.A, question) return r.processIP(v.A, question, v.Header().Ttl)
case *dns.AAAA: case *dns.AAAA:
return r.processIP(v.AAAA, question) return r.processIP(v.AAAA, question, v.Header().Ttl)
case *dns.CNAME: case *dns.CNAME:
return r.processCNAME(ctx, request, *v, resolvedCnames, question) return r.processCNAME(ctx, request, *v, resolvedCnames, question, v.Header().Ttl)
} }
return nil, fmt.Errorf("unsupported customDNS RR type %T", entry) return nil, fmt.Errorf("unsupported customDNS RR type %T", entry)
@ -200,11 +213,11 @@ func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request)
return resp, nil return resp, nil
} }
func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question) (result []dns.RR, err error) { func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint32) (result []dns.RR, err error) {
result = make([]dns.RR, 0) result = make([]dns.RR, 0)
if isSupportedType(ip, question) { if isSupportedType(ip, question) {
rr, err := r.createAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32()) rr, err := r.createAnswerFromQuestion(question, ip, ttl)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -221,9 +234,9 @@ func (r *CustomDNSResolver) processCNAME(
targetCname dns.CNAME, targetCname dns.CNAME,
resolvedCnames []string, resolvedCnames []string,
question dns.Question, question dns.Question,
ttl uint32,
) (result []dns.RR, err error) { ) (result []dns.RR, err error) {
cname := new(dns.CNAME) cname := new(dns.CNAME)
ttl := r.cfg.CustomTTL.SecondsU32()
cname.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeCNAME, Name: question.Name} cname.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeCNAME, Name: question.Name}
cname.Target = dns.Fqdn(targetCname.Target) cname.Target = dns.Fqdn(targetCname.Target)
result = append(result, cname) result = append(result, cname)

View File

@ -18,7 +18,8 @@ import (
var _ = Describe("CustomDNSResolver", func() { var _ = Describe("CustomDNSResolver", func() {
var ( var (
TTL = uint32(time.Now().Second()) TTL = uint32(time.Now().Second())
zoneTTL = uint32(time.Now().Second() * 2)
sut *CustomDNSResolver sut *CustomDNSResolver
m *mockResolver m *mockResolver
@ -38,6 +39,8 @@ var _ = Describe("CustomDNSResolver", func() {
ctx, cancelFn = context.WithCancel(context.Background()) ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn) DeferCleanup(cancelFn)
zoneHdr := dns.RR_Header{Ttl: zoneTTL}
cfg = config.CustomDNS{ cfg = config.CustomDNS{
Mapping: config.CustomDNSMapping{ Mapping: config.CustomDNSMapping{
"custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}}, "custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}},
@ -47,11 +50,16 @@ var _ = Describe("CustomDNSResolver", func() {
&dns.A{A: net.ParseIP("192.168.143.125")}, &dns.A{A: net.ParseIP("192.168.143.125")},
&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
}, },
"cname.domain": {&dns.CNAME{Target: "custom.domain"}}, },
"cname.ip6": {&dns.CNAME{Target: "ip6.domain"}}, Zone: config.ZoneFileDNS{
"cname.example": {&dns.CNAME{Target: "example.com"}}, RRs: config.CustomDNSMapping{
"cname.recursive": {&dns.CNAME{Target: "cname.recursive"}}, "example.zone.": {&dns.A{A: net.ParseIP("1.2.3.4"), Hdr: zoneHdr}},
"mx.domain": {&dns.MX{Mx: "mx.domain"}}, "cname.domain.": {&dns.CNAME{Target: "custom.domain", Hdr: zoneHdr}},
"cname.ip6.": {&dns.CNAME{Target: "ip6.domain", Hdr: zoneHdr}},
"cname.example.": {&dns.CNAME{Target: "example.com", Hdr: zoneHdr}},
"cname.recursive.": {&dns.CNAME{Target: "cname.recursive", Hdr: zoneHdr}},
"mx.domain.": {&dns.MX{Mx: "mx.domain", Hdr: zoneHdr}},
},
}, },
CustomTTL: config.Duration(time.Duration(TTL) * time.Second), CustomTTL: config.Duration(time.Duration(TTL) * time.Second),
FilterUnmappedTypes: true, FilterUnmappedTypes: true,
@ -136,6 +144,19 @@ var _ = Describe("CustomDNSResolver", func() {
When("Ip 4 mapping is defined for custom domain and", func() { When("Ip 4 mapping is defined for custom domain and", func() {
Context("filterUnmappedTypes is true", func() { Context("filterUnmappedTypes is true", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = true }) BeforeEach(func() { cfg.FilterUnmappedTypes = true })
It("defined ip4 query should be resolved from zone mappings and should use the TTL defined in the zone", func() {
Expect(sut.Resolve(ctx, newRequest("example.zone.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.zone.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", zoneTTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("defined ip4 query should be resolved", func() { It("defined ip4 query should be resolved", func() {
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
Should( Should(