mirror of https://github.com/0xERR0R/blocky.git
feat: support CNAME records in customDNS mappings (#1352)
Co-authored-by: Ben McHone <ben@mchone.dev>
This commit is contained in:
parent
3817d98e74
commit
b8b4dc323a
|
@ -15,3 +15,4 @@ vendor/
|
|||
coverage.html
|
||||
coverage.txt
|
||||
coverage/
|
||||
blocky
|
||||
|
|
|
@ -235,6 +235,17 @@ blocking:
|
|||
Expect(err.Error()).Should(ContainSubstring("invalid IP address '192.168.178.WRONG'"))
|
||||
})
|
||||
})
|
||||
When("CustomDNS hast wrong IPv6 defined", func() {
|
||||
It("should return error", func() {
|
||||
cfg := Config{}
|
||||
data := `customDNS:
|
||||
mapping:
|
||||
someDomain: 2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334`
|
||||
err := unmarshalConfig(logger, []byte(data), &cfg)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("invalid IP address '2001:MALFORMED:IP:ADDRESS:0000:8a2e:0370:7334'"))
|
||||
})
|
||||
})
|
||||
When("Conditional mapping hast wrong defined upstreams", func() {
|
||||
It("should return error", func() {
|
||||
cfg := Config{}
|
||||
|
@ -866,12 +877,24 @@ func defaultTestFileConfig(config *Config) {
|
|||
Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8"))
|
||||
Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4"))
|
||||
Expect(config.Upstreams.Groups["default"][2].Host).Should(Equal("1.1.1.1"))
|
||||
Expect(config.CustomDNS.Mapping.HostIPs).Should(HaveLen(2))
|
||||
Expect(config.CustomDNS.Mapping.HostIPs["my.duckdns.org"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
|
||||
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
|
||||
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][1]).Should(Equal(net.ParseIP("192.168.178.4")))
|
||||
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][2]).Should(Equal(
|
||||
net.ParseIP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344")))
|
||||
Expect(config.CustomDNS.Mapping).Should(HaveLen(2))
|
||||
|
||||
duckDNSEntry := config.CustomDNS.Mapping["my.duckdns.org"][0]
|
||||
duckDNSA := duckDNSEntry.(*dns.A)
|
||||
Expect(duckDNSA.A).Should(Equal(net.ParseIP("192.168.178.3")))
|
||||
|
||||
multipleIpsEntry := config.CustomDNS.Mapping["multiple.ips"][0]
|
||||
multipleIpsA := multipleIpsEntry.(*dns.A)
|
||||
Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.3")))
|
||||
|
||||
multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][1]
|
||||
multipleIpsA = multipleIpsEntry.(*dns.A)
|
||||
Expect(multipleIpsA.A).Should(Equal(net.ParseIP("192.168.178.4")))
|
||||
|
||||
multipleIpsEntry = config.CustomDNS.Mapping["multiple.ips"][2]
|
||||
multipleIpsAAAA := multipleIpsEntry.(*dns.AAAA)
|
||||
Expect(multipleIpsAAAA.AAAA).Should(Equal(net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7344")))
|
||||
|
||||
Expect(config.Conditional.Mapping.Upstreams).Should(HaveLen(2))
|
||||
Expect(config.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1))
|
||||
Expect(config.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2))
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -16,14 +17,45 @@ type CustomDNS struct {
|
|||
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
|
||||
}
|
||||
|
||||
// CustomDNSMapping mapping for the custom DNS configuration
|
||||
type CustomDNSMapping struct {
|
||||
HostIPs map[string][]net.IP `yaml:"hostIPs"`
|
||||
type (
|
||||
CustomDNSMapping map[string]CustomDNSEntries
|
||||
CustomDNSEntries []dns.RR
|
||||
)
|
||||
|
||||
func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parts := strings.Split(input, ",")
|
||||
result := make(CustomDNSEntries, len(parts))
|
||||
containsCNAME := false
|
||||
|
||||
for i, part := range parts {
|
||||
rr, err := configToRR(part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, isCNAME := rr.(*dns.CNAME)
|
||||
containsCNAME = containsCNAME || isCNAME
|
||||
|
||||
result[i] = rr
|
||||
}
|
||||
|
||||
if containsCNAME && len(result) > 1 {
|
||||
return fmt.Errorf("when a CNAME record is present, it must be the only record in the mapping")
|
||||
}
|
||||
|
||||
*c = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *CustomDNS) IsEnabled() bool {
|
||||
return len(c.Mapping.HostIPs) != 0
|
||||
return len(c.Mapping) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
|
@ -33,36 +65,52 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
|
|||
|
||||
logger.Info("mapping:")
|
||||
|
||||
for key, val := range c.Mapping.HostIPs {
|
||||
for key, val := range c.Mapping {
|
||||
logger.Infof(" %s = %s", key, val)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements `yaml.Unmarshaler`.
|
||||
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input map[string]string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
func removePrefixSuffix(in, prefix string) string {
|
||||
in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix))
|
||||
in = strings.TrimSuffix(in, ")")
|
||||
|
||||
return strings.TrimSpace(in)
|
||||
}
|
||||
|
||||
func configToRR(part string) (dns.RR, error) {
|
||||
if strings.HasPrefix(part, "CNAME(") {
|
||||
domain := removePrefixSuffix(part, "CNAME")
|
||||
domain = dns.Fqdn(domain)
|
||||
cname := &dns.CNAME{Target: domain}
|
||||
|
||||
return cname, nil
|
||||
}
|
||||
|
||||
result := make(map[string][]net.IP, len(input))
|
||||
// Fall back to A/AAAA records to maintain backwards compatibility in config.yml
|
||||
// We will still remove the A() or AAAA() if it exists
|
||||
if strings.Contains(part, ".") { // IPV4 address
|
||||
ipStr := removePrefixSuffix(part, "A")
|
||||
ip := net.ParseIP(ipStr)
|
||||
|
||||
for k, v := range input {
|
||||
var ips []net.IP
|
||||
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
ip := net.ParseIP(strings.TrimSpace(part))
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address '%s'", part)
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address '%s'", part)
|
||||
}
|
||||
|
||||
result[k] = ips
|
||||
a := new(dns.A)
|
||||
a.A = ip
|
||||
|
||||
return a, nil
|
||||
} else { // IPV6 address
|
||||
ipStr := removePrefixSuffix(part, "AAAA")
|
||||
ip := net.ParseIP(ipStr)
|
||||
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address '%s'", part)
|
||||
}
|
||||
|
||||
aaaa := new(dns.AAAA)
|
||||
aaaa.AAAA = ip
|
||||
|
||||
return aaaa, nil
|
||||
}
|
||||
|
||||
c.HostIPs = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -17,15 +18,14 @@ var _ = Describe("CustomDNSConfig", func() {
|
|||
BeforeEach(func() {
|
||||
cfg = CustomDNS{
|
||||
Mapping: CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.domain": {net.ParseIP("192.168.143.123")},
|
||||
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
"multiple.ips": {
|
||||
net.ParseIP("192.168.143.123"),
|
||||
net.ParseIP("192.168.143.125"),
|
||||
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
},
|
||||
"custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}},
|
||||
"ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}},
|
||||
"multiple.ips": {
|
||||
&dns.A{A: net.ParseIP("192.168.143.123")},
|
||||
&dns.A{A: net.ParseIP("192.168.143.125")},
|
||||
&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
},
|
||||
"cname.domain": {&dns.CNAME{Target: "custom.domain"}},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
@ -60,27 +60,41 @@ var _ = Describe("CustomDNSConfig", func() {
|
|||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElements(
|
||||
ContainSubstring("custom.domain = "),
|
||||
ContainSubstring("ip6.domain = "),
|
||||
ContainSubstring("multiple.ips = "),
|
||||
ContainSubstring("cname.domain = "),
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UnmarshalYAML", func() {
|
||||
It("Should parse config as map", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
c := CustomDNSEntries{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
*i.(*string) = "1.2.3.4"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(c.HostIPs).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"]).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4")))
|
||||
Expect(c).Should(HaveLen(1))
|
||||
|
||||
aRecord := c[0].(*dns.A)
|
||||
Expect(aRecord.A).Should(Equal(net.ParseIP("1.2.3.4")))
|
||||
})
|
||||
|
||||
It("Should return an error if a CNAME is accomanied by any other record", func() {
|
||||
c := CustomDNSEntries{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "CNAME(example.com),A(1.2.3.4)"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("when a CNAME record is present, it must be the only record in the mapping"))
|
||||
})
|
||||
|
||||
It("should fail if wrong YAML format", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
c := &CustomDNSEntries{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
|
|
|
@ -47,6 +47,7 @@ customDNS:
|
|||
example.com: printer.lan
|
||||
mapping:
|
||||
printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
|
||||
second-printer-address.lan: CNAME(printer.lan)
|
||||
|
||||
# optional: definition, which DNS resolver(s) should be used for queries to the domain (with all sub-domains). Multiple resolvers must be separated by a comma
|
||||
# Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name
|
||||
|
|
|
@ -259,12 +259,12 @@ You can define your own domain name to IP mappings. For example, you can use a u
|
|||
or define a domain name for your local device on order to use the HTTPS certificate. Multiple IP addresses for one
|
||||
domain must be separated by a comma.
|
||||
|
||||
| Parameter | Type | Mandatory | Default value |
|
||||
| ------------------- | --------------------------------------- | --------- | ------------- |
|
||||
| customTTL | duration (no unit is minutes) | no | 1h |
|
||||
| rewrite | string: string (domain: domain) | no | |
|
||||
| mapping | string: string (hostname: address list) | no | |
|
||||
| filterUnmappedTypes | boolean | no | true |
|
||||
| Parameter | Type | Mandatory | Default value |
|
||||
| ------------------- | ------------------------------------------- | --------- | ------------- |
|
||||
| customTTL | duration (no unit is minutes) | no | 1h |
|
||||
| rewrite | string: string (domain: domain) | no | |
|
||||
| mapping | string: string (hostname: address or CNAME) | no | |
|
||||
| filterUnmappedTypes | boolean | no | true |
|
||||
|
||||
!!! example
|
||||
|
||||
|
@ -278,11 +278,14 @@ domain must be separated by a comma.
|
|||
mapping:
|
||||
printer.lan: 192.168.178.3
|
||||
otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
|
||||
anothername.lan: CNAME(otherdevice.lan)
|
||||
```
|
||||
|
||||
This configuration will also resolve any subdomain of the defined domain, recursively. For example querying any of
|
||||
`printer.lan`, `my.printer.lan` or `i.love.my.printer.lan` will return 192.168.178.3.
|
||||
|
||||
CNAME records are supported by setting the value of the mapping to `CNAME(target)`. Note that the target will be recursively resolved and will return an error if a loop is detected.
|
||||
|
||||
With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the
|
||||
resolver lookup is performed.
|
||||
The query "printer.home" will be rewritten to "printer.lan" and return 192.168.178.3.
|
||||
|
|
|
@ -66,7 +66,7 @@ func createDNSMokkaContainer(ctx context.Context, alias string, rules ...string)
|
|||
// createHTTPServerContainer creates a static HTTP server container that serves one file with the given lines
|
||||
// and is attached to the test network under the given alias.
|
||||
// It is automatically terminated when the test is finished.
|
||||
func createHTTPServerContainer(ctx context.Context, alias string, filename string, lines ...string,
|
||||
func createHTTPServerContainer(ctx context.Context, alias, filename string, lines ...string,
|
||||
) (testcontainers.Container, error) {
|
||||
file := createTempFile(lines...)
|
||||
|
||||
|
|
|
@ -212,6 +212,8 @@ func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error
|
|||
return v.A.String() == matcher.answer, nil
|
||||
case *dns.AAAA:
|
||||
return v.AAAA.String() == matcher.answer, nil
|
||||
case *dns.CNAME:
|
||||
return v.Target == matcher.answer, nil
|
||||
case *dns.PTR:
|
||||
return v.Ptr == matcher.answer, nil
|
||||
case *dns.MX:
|
||||
|
|
|
@ -2,7 +2,9 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -14,27 +16,41 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type createAnswerFunc func(question dns.Question, ip net.IP, ttl uint32) (dns.RR, error)
|
||||
|
||||
// CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map
|
||||
type CustomDNSResolver struct {
|
||||
configurable[*config.CustomDNS]
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
mapping map[string][]net.IP
|
||||
reverseAddresses map[string][]string
|
||||
createAnswerFromQuestion createAnswerFunc
|
||||
mapping config.CustomDNSMapping
|
||||
reverseAddresses map[string][]string
|
||||
}
|
||||
|
||||
// NewCustomDNSResolver creates new resolver instance
|
||||
func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
|
||||
m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs))
|
||||
reverse := make(map[string][]string, len(cfg.Mapping.HostIPs))
|
||||
m := make(config.CustomDNSMapping, len(cfg.Mapping))
|
||||
reverse := make(map[string][]string, len(cfg.Mapping))
|
||||
|
||||
for url, ips := range cfg.Mapping.HostIPs {
|
||||
m[strings.ToLower(url)] = ips
|
||||
for url, entries := range cfg.Mapping {
|
||||
m[strings.ToLower(url)] = entries
|
||||
|
||||
for _, ip := range ips {
|
||||
r, _ := dns.ReverseAddr(ip.String())
|
||||
reverse[r] = append(reverse[r], url)
|
||||
for _, entry := range entries {
|
||||
a, isA := entry.(*dns.A)
|
||||
|
||||
if isA {
|
||||
r, _ := dns.ReverseAddr(a.A.String())
|
||||
reverse[r] = append(reverse[r], url)
|
||||
}
|
||||
|
||||
aaaa, isAAAA := entry.(*dns.AAAA)
|
||||
|
||||
if isAAAA {
|
||||
r, _ := dns.ReverseAddr(aaaa.AAAA.String())
|
||||
reverse[r] = append(reverse[r], url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,8 +58,9 @@ func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
|
|||
configurable: withConfig(&cfg),
|
||||
typed: withType("custom_dns"),
|
||||
|
||||
mapping: m,
|
||||
reverseAddresses: reverse,
|
||||
createAnswerFromQuestion: util.CreateAnswerFromQuestion,
|
||||
mapping: m,
|
||||
reverseAddresses: reverse,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,7 +92,26 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response {
|
||||
func (r *CustomDNSResolver) forwardResponse(
|
||||
logger *logrus.Entry,
|
||||
ctx context.Context,
|
||||
request *model.Request,
|
||||
) (*model.Response, error) {
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
forwardResponse, err := r.next.Resolve(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return forwardResponse, nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processRequest(
|
||||
ctx context.Context,
|
||||
request *model.Request,
|
||||
resolvedCnames []string,
|
||||
) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
|
||||
|
||||
response := new(dns.Msg)
|
||||
|
@ -85,13 +121,20 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
domain := util.ExtractDomain(question)
|
||||
|
||||
for len(domain) > 0 {
|
||||
ips, found := r.mapping[domain]
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, found := r.mapping[domain]
|
||||
|
||||
if found {
|
||||
for _, ip := range ips {
|
||||
if isSupportedType(ip, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32())
|
||||
response.Answer = append(response.Answer, rr)
|
||||
for _, entry := range entries {
|
||||
result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response.Answer = append(response.Answer, result...)
|
||||
}
|
||||
|
||||
if len(response.Answer) > 0 {
|
||||
|
@ -100,7 +143,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
"domain": domain,
|
||||
}).Debugf("returning custom dns entry")
|
||||
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
||||
}
|
||||
|
||||
// Mapping exists for this domain, but for another type
|
||||
|
@ -110,7 +153,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
}
|
||||
|
||||
// return NOERROR with empty result
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(domain, "."); i >= 0 {
|
||||
|
@ -120,26 +163,99 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return r.forwardResponse(logger, ctx, request)
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processDNSEntry(
|
||||
ctx context.Context,
|
||||
request *model.Request,
|
||||
resolvedCnames []string,
|
||||
question dns.Question,
|
||||
entry dns.RR,
|
||||
) ([]dns.RR, error) {
|
||||
switch v := entry.(type) {
|
||||
case *dns.A:
|
||||
return r.processIP(v.A, question)
|
||||
case *dns.AAAA:
|
||||
return r.processIP(v.AAAA, question)
|
||||
case *dns.CNAME:
|
||||
return r.processCNAME(ctx, request, *v, resolvedCnames, question)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported customDNS RR type %T", entry)
|
||||
}
|
||||
|
||||
// Resolve uses internal mapping to resolve the query
|
||||
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
|
||||
|
||||
reverseResp := r.handleReverseDNS(request)
|
||||
if reverseResp != nil {
|
||||
return reverseResp, nil
|
||||
}
|
||||
|
||||
if len(r.mapping) > 0 {
|
||||
resp := r.processRequest(request)
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(ctx, request)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question) (result []dns.RR, err error) {
|
||||
result = make([]dns.RR, 0)
|
||||
|
||||
if isSupportedType(ip, question) {
|
||||
rr, err := r.createAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, rr)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processCNAME(
|
||||
ctx context.Context,
|
||||
request *model.Request,
|
||||
targetCname dns.CNAME,
|
||||
resolvedCnames []string,
|
||||
question dns.Question,
|
||||
) (result []dns.RR, err error) {
|
||||
cname := new(dns.CNAME)
|
||||
ttl := r.cfg.CustomTTL.SecondsU32()
|
||||
cname.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeCNAME, Name: question.Name}
|
||||
cname.Target = dns.Fqdn(targetCname.Target)
|
||||
result = append(result, cname)
|
||||
|
||||
if question.Qtype == dns.TypeCNAME {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
targetWithoutDot := strings.TrimSuffix(targetCname.Target, ".")
|
||||
|
||||
if slices.Contains(resolvedCnames, targetWithoutDot) {
|
||||
return nil, fmt.Errorf("CNAME loop detected: %v", append(resolvedCnames, targetWithoutDot))
|
||||
}
|
||||
|
||||
cnames := resolvedCnames
|
||||
cnames = append(cnames, targetWithoutDot)
|
||||
|
||||
clientIP := request.ClientIP.String()
|
||||
clientID := request.RequestClientID
|
||||
targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID)
|
||||
|
||||
// resolve the target recursively
|
||||
targetResp, err := r.processRequest(ctx, targetRequest, cnames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, targetResp.Res.Answer...)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) CreateAnswerFromQuestion(newFunc createAnswerFunc) {
|
||||
r.createAnswerFromQuestion = newFunc
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -38,15 +39,20 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
DeferCleanup(cancelFn)
|
||||
|
||||
cfg = config.CustomDNS{
|
||||
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
|
||||
"custom.domain": {net.ParseIP("192.168.143.123")},
|
||||
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
Mapping: config.CustomDNSMapping{
|
||||
"custom.domain": {&dns.A{A: net.ParseIP("192.168.143.123")}},
|
||||
"ip6.domain": {&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}},
|
||||
"multiple.ips": {
|
||||
net.ParseIP("192.168.143.123"),
|
||||
net.ParseIP("192.168.143.125"),
|
||||
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
&dns.A{A: net.ParseIP("192.168.143.123")},
|
||||
&dns.A{A: net.ParseIP("192.168.143.125")},
|
||||
&dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
},
|
||||
}},
|
||||
"cname.domain": {&dns.CNAME{Target: "custom.domain"}},
|
||||
"cname.ip6": {&dns.CNAME{Target: "ip6.domain"}},
|
||||
"cname.example": {&dns.CNAME{Target: "example.com"}},
|
||||
"cname.recursive": {&dns.CNAME{Target: "cname.recursive"}},
|
||||
"mx.domain": {&dns.MX{Mx: "mx.domain"}},
|
||||
},
|
||||
CustomTTL: config.Duration(time.Duration(TTL) * time.Second),
|
||||
FilterUnmappedTypes: true,
|
||||
}
|
||||
|
@ -76,6 +82,57 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
|
||||
Describe("Resolving custom name via CustomDNSResolver", func() {
|
||||
When("The parent context has an error ", func() {
|
||||
It("should return the error", func() {
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := sut.Resolve(cancelledCtx, newRequest("custom.domain.", A))
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("context canceled"))
|
||||
})
|
||||
})
|
||||
When("Creating the IP response returns an error ", func() {
|
||||
It("should return the error", func() {
|
||||
createAnswerMock := func(_ dns.Question, _ net.IP, _ uint32) (dns.RR, error) {
|
||||
return nil, fmt.Errorf("create answer error")
|
||||
}
|
||||
|
||||
sut.CreateAnswerFromQuestion(createAnswerMock)
|
||||
|
||||
_, err := sut.Resolve(ctx, newRequest("custom.domain.", A))
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("create answer error"))
|
||||
})
|
||||
})
|
||||
When("The forward request returns an error ", func() {
|
||||
It("should return the error if the error occurs when checking ipv4 forward addresses", func() {
|
||||
err := fmt.Errorf("forward error")
|
||||
m = &mockResolver{}
|
||||
|
||||
m.On("Resolve", mock.Anything).Return(nil, err)
|
||||
|
||||
sut.Next(m)
|
||||
_, err = sut.Resolve(ctx, newRequest("cname.example.", A))
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("forward error"))
|
||||
})
|
||||
It("should return the error if the error occurs when checking ipv6 forward addresses", func() {
|
||||
err := fmt.Errorf("forward error")
|
||||
m = &mockResolver{}
|
||||
|
||||
m.On("Resolve", mock.Anything).Return(nil, err)
|
||||
|
||||
sut.Next(m)
|
||||
_, err = sut.Resolve(ctx, newRequest("cname.example.", AAAA))
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("forward error"))
|
||||
})
|
||||
})
|
||||
When("Ip 4 mapping is defined for custom domain and", func() {
|
||||
Context("filterUnmappedTypes is true", func() {
|
||||
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
|
||||
|
@ -211,6 +268,101 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
When("A CNAME record is defined for custom domain ", func() {
|
||||
It("should not recurse if the request is strictly a CNAME request", func() {
|
||||
By("CNAME query", func() {
|
||||
Expect(sut.Resolve(ctx, newRequest("cname.domain", CNAME))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
HaveLen(1),
|
||||
ContainElements(
|
||||
BeDNSRecord("cname.domain.", CNAME, "custom.domain.")),
|
||||
)),
|
||||
HaveResponseType(ResponseTypeCUSTOMDNS),
|
||||
HaveReason("CUSTOM DNS"),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// will not delegate to next resolver
|
||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
})
|
||||
It("all CNAMES for the current type should be recursively resolved when relying on other Mappings", func() {
|
||||
By("A query", func() {
|
||||
Expect(sut.Resolve(ctx, newRequest("cname.domain", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
HaveLen(2),
|
||||
ContainElements(
|
||||
BeDNSRecord("cname.domain.", CNAME, "custom.domain."),
|
||||
BeDNSRecord("custom.domain.", A, "192.168.143.123")),
|
||||
)),
|
||||
HaveResponseType(ResponseTypeCUSTOMDNS),
|
||||
HaveReason("CUSTOM DNS"),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// will not delegate to next resolver
|
||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
By("AAAA query", func() {
|
||||
Expect(sut.Resolve(ctx, newRequest("cname.ip6", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
HaveLen(2),
|
||||
ContainElements(
|
||||
BeDNSRecord("cname.ip6.", CNAME, "ip6.domain."),
|
||||
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334")),
|
||||
)),
|
||||
HaveResponseType(ResponseTypeCUSTOMDNS),
|
||||
HaveReason("CUSTOM DNS"),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// will not delegate to next resolver
|
||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
})
|
||||
It("should return an error when the CNAME is recursive", func() {
|
||||
By("CNAME query", func() {
|
||||
_, err := sut.Resolve(ctx, newRequest("cname.recursive", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("CNAME loop detected:"))
|
||||
// will not delegate to next resolver
|
||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
})
|
||||
It("all CNAMES for the current type should be returned when relying on public DNS", func() {
|
||||
By("CNAME query", func() {
|
||||
Expect(sut.Resolve(ctx, newRequest("cname.example", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
ContainElements(
|
||||
BeDNSRecord("cname.example.", CNAME, "example.com.")),
|
||||
)),
|
||||
HaveResponseType(ResponseTypeCUSTOMDNS),
|
||||
HaveReason("CUSTOM DNS"),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// will delegate to next resolver
|
||||
m.AssertCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
})
|
||||
})
|
||||
When("An unsupported DNS query type is queried from the resolver but found in the config mapping ", func() {
|
||||
It("an error should be returned", func() {
|
||||
By("MX query", func() {
|
||||
_, err := sut.Resolve(ctx, newRequest("mx.domain", MX))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("unsupported customDNS RR type *dns.MX"))
|
||||
})
|
||||
})
|
||||
})
|
||||
When("Reverse DNS request is received", func() {
|
||||
It("should resolve the defined domain name", func() {
|
||||
By("ipv4", func() {
|
||||
|
|
|
@ -613,6 +613,13 @@ func (s *Server) OnRequest(
|
|||
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
var response *model.Response
|
||||
|
||||
contextUpstreamTimeoutMultiplier := 100
|
||||
timeoutDuration := time.Duration(contextUpstreamTimeoutMultiplier) * s.cfg.Upstreams.Timeout.ToDuration()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeoutDuration)
|
||||
|
||||
defer cancel()
|
||||
|
||||
switch {
|
||||
case len(request.Req.Question) == 0:
|
||||
m := new(dns.Msg)
|
||||
|
|
|
@ -104,10 +104,8 @@ var _ = BeforeSuite(func() {
|
|||
CustomDNS: config.CustomDNS{
|
||||
CustomTTL: config.Duration(3600 * time.Second),
|
||||
Mapping: config.CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.lan": {net.ParseIP("192.168.178.55")},
|
||||
"lan.home": {net.ParseIP("192.168.178.56")},
|
||||
},
|
||||
"custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
|
||||
"lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
|
||||
},
|
||||
},
|
||||
Conditional: config.ConditionalUpstream{
|
||||
|
@ -596,10 +594,8 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
CustomDNS: config.CustomDNS{
|
||||
Mapping: config.CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.lan": {net.ParseIP("192.168.178.55")},
|
||||
"lan.home": {net.ParseIP("192.168.178.56")},
|
||||
},
|
||||
"custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
|
||||
"lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
|
||||
},
|
||||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
|
@ -642,10 +638,8 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
CustomDNS: config.CustomDNS{
|
||||
Mapping: config.CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.lan": {net.ParseIP("192.168.178.55")},
|
||||
"lan.home": {net.ParseIP("192.168.178.56")},
|
||||
},
|
||||
"custom.lan": {&dns.A{A: net.ParseIP("192.168.178.55")}},
|
||||
"lan.home": {&dns.A{A: net.ParseIP("192.168.178.56")}},
|
||||
},
|
||||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
|
|
Loading…
Reference in New Issue