Add rewrite support to custom DNS (#449)

This commit extracts rewriting logic from `ConditionalUpstreamResolver`
into the new `RewriterResolver`, and uses that to enable rewriting for
the `CustomDNSResolver`.
`RewriterResolver` wraps a resolver and applies the rewrite to the
request that is forwarded to the inner resolver.

It also introduces a new optional interface: `NamedResolver`.
This allows a `Resolver` to choose what its user friendly name is,
instead of always being its type name.
This commit is contained in:
ThinkChaos 2022-03-17 17:30:21 -04:00 committed by GitHub
parent e6a5af33f2
commit f8b6e59ef4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 478 additions and 139 deletions

View File

@ -53,7 +53,7 @@ Blocky is a DNS proxy and ad-blocker for the local network written in Go with fo
* Supports modern DNS extensions: DNSSEC, eDNS, ...
* Free configurable blocking lists - no hidden filtering etc.
* Provides DoH Endpoint
* Uses random upstream resolvers from the configuration - increases you privacy though the distribution of your DNS
* Uses random upstream resolvers from the configuration - increases your privacy through the distribution of your DNS
traffic over multiple provider
* Blocky does **NOT** collect any user data, telemetry, statistics etc.

View File

@ -103,7 +103,7 @@ func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) e
return err
}
result := make(map[string][]Upstream)
result := make(map[string][]Upstream, len(input))
for k, v := range input {
var upstreams []Upstream
@ -132,7 +132,7 @@ func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) erro
return err
}
result := make(map[string][]net.IP)
result := make(map[string][]net.IP, len(input))
for k, v := range input {
var ips []net.IP
@ -310,10 +310,16 @@ type UpstreamConfig struct {
ExternalResolvers map[string][]Upstream `yaml:",inline"`
}
// RewriteConfig custom DNS configuration
type RewriteConfig struct {
Rewrite map[string]string `yaml:"rewrite"`
}
// CustomDNSConfig custom DNS configuration
type CustomDNSConfig struct {
CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"`
RewriteConfig `yaml:",inline"`
CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"`
}
// CustomDNSMapping mapping for the custom DNS configuration
@ -323,8 +329,8 @@ type CustomDNSMapping struct {
// ConditionalUpstreamConfig conditional upstream configuration
type ConditionalUpstreamConfig struct {
Rewrite map[string]string `yaml:"rewrite"`
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
RewriteConfig `yaml:",inline"`
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
}
// ConditionalUpstreamMapping mapping for conditional configuration

View File

@ -23,6 +23,9 @@ upstreamTimeout: 2s
# example: query "printer.lan" or "my.printer.lan" will return 192.168.178.3
customDNS:
customTTL: 1h
# optional: replace domain in the query with other domain before resolver lookup in the mapping
rewrite:
example.com: printer.lan
mapping:
printer.lan: 192.168.178.3,2001:0db8:85a3:08d3:1319:8a2e:0370:7344

View File

@ -118,6 +118,7 @@ domain must be separated by a comma.
| Parameter | Type | Mandatory | Default value |
|-----------|-----------------------------------------|-----------|---------------|
| customTTL | duration (no unit is minutes) | no | 1h |
| rewrite | string: string (domain: domain) | no | |
| mapping | string: string (hostname: address list) | no | |
!!! example
@ -125,13 +126,20 @@ domain must be separated by a comma.
```yaml
customDNS:
customTTL: 1h
mapping:
printer.lan: 192.168.178.3
otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
rewrite:
home: lan
replace-me.com: with-this.com
mapping:
printer.lan: 192.168.178.3
otherdevice.lan: 192.168.178.15,2001:0db8:85a3:08d3:1319:8a2e:0370:7344
```
This configuration will also resolve any subdomain of the defined domain. For example a query "printer.lan" or "
my.printer.lan" will return 192.168.178.3 as IP address.
my.printer.lan" will return 192.168.178.3 as IP address.
With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the
resolver lookup is performed.
The query "printer.home" will be rewritten to "printer.lan" and return 192.168.178.3.
## Conditional DNS resolution
@ -139,8 +147,7 @@ You can define, which DNS resolver(s) should be used for queries for the particu
is for example useful, if you want to reach devices in your local network by the name. Since only your router know which
hostname belongs to which IP address, all DNS queries for the local network should be redirected to the router.
With the optional parameter `rewrite` you can replace domain part of the query with the defined part **before** the
resolver lookup is performed.
The optional parameter `rewrite` behaves the same as with custom DNS.
!!! example
@ -163,7 +170,7 @@ resolver lookup is performed.
You can use `.` as wildcard for all non full qualified domains (domains without dot)
In this example, a DNS query "client.fritz.box" will be redirected to the router's DNS server at 192.168.178.1 and client.lan.net to 192.170.1.2 and 192.170.1.3.
The query client.example.com will be rewritten to "client.fritz.box" and also redirected to the resolver at 192.168.178.1. All unqualified hostnames (e.g. 'test')
The query "client.example.com" will be rewritten to "client.fritz.box" and also redirected to the resolver at 192.168.178.1. All unqualified hostnames (e.g. "test")
will be redirected to the DNS server at 168.168.0.1

View File

@ -113,7 +113,7 @@ func NewBlockingResolver(cfg config.BlockingConfig, redis *redis.Client) (Chaine
return nil, multierror.Prefix(err, "blocking resolver: ")
}
cgb := make(map[string][]string)
cgb := make(map[string][]string, len(cfg.ClientGroupsBlock))
for identifier, cfgGroups := range cfg.ClientGroupsBlock {
for _, ipart := range strings.Split(identifier, ",") {
@ -181,7 +181,7 @@ func (r *BlockingResolver) RefreshLists() {
// nolint:prealloc
func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
groups := make(map[string]bool)
groups := make(map[string]bool, len(r.cfg.BlackLists))
for group := range r.cfg.BlackLists {
groups[group] = true
@ -282,7 +282,7 @@ func (r *BlockingResolver) BlockingStatus() api.BlockingStatus {
// returns groups, which have only whitelist entries
func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result map[string]bool) {
result = make(map[string]bool)
result = make(map[string]bool, len(cfg.WhiteLists))
for g, links := range cfg.WhiteLists {
if len(links) > 0 {

View File

@ -43,7 +43,7 @@ var _ = Describe("BlockingResolver", func() {
var (
sut *BlockingResolver
sutConfig config.BlockingConfig
m *resolverMock
m *MockResolver
mockAnswer *dns.Msg
err error
@ -64,7 +64,7 @@ var _ = Describe("BlockingResolver", func() {
})
JustBeforeEach(func() {
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
tmp, _ := NewBlockingResolver(sutConfig, nil)
sut = tmp.(*BlockingResolver)
@ -126,16 +126,13 @@ var _ = Describe("BlockingResolver", func() {
When("Full-qualified group name is used", func() {
It("bla", func() {
tmp, _ := NewBlockingResolver(sutConfig, nil)
sut = tmp.(*BlockingResolver)
sut.Next(&MockResolver{AnswerFn: func(t uint16, qName string) *dns.Msg {
m.AnswerFn = func(t uint16, qName string) *dns.Msg {
if t == dns.TypeA && qName == "full.qualified.com." {
a, _ := util.NewMsgWithAnswer(qName, 60*60, dns.TypeA, "192.168.178.39")
return a
}
return nil
}})
sut.RefreshLists()
}
Bus().Publish(ApplicationStarted, "")
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "192.168.178.39", "client1"))
@ -829,7 +826,7 @@ var _ = Describe("BlockingResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})

View File

@ -23,7 +23,7 @@ var _ = Describe("CachingResolver", func() {
var (
sut ChainedResolver
sutConfig config.CachingConfig
m *resolverMock
m *MockResolver
mockAnswer *dns.Msg
err error
@ -45,7 +45,7 @@ var _ = Describe("CachingResolver", func() {
JustBeforeEach(func() {
sut = NewCachingResolver(sutConfig, nil)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)
})
@ -444,7 +444,7 @@ var _ = Describe("CachingResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
@ -469,7 +469,7 @@ var _ = Describe("CachingResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
Expect(c).Should(ContainElement(ContainSubstring("prefetchThreshold")))
})
})
@ -509,7 +509,7 @@ var _ = Describe("CachingResolver", func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, dns.TypeA, "1.1.1.1")
sut = NewCachingResolver(sutConfig, redisClient)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)
})

View File

@ -20,7 +20,7 @@ var _ = Describe("ClientResolver", func() {
var (
sut *ClientNamesResolver
sutConfig config.ClientLookupConfig
m *resolverMock
m *MockResolver
mockReverseUpstream config.Upstream
mockReverseUpstreamCallCount int
mockReverseUpstreamAnswer *dns.Msg
@ -47,7 +47,7 @@ var _ = Describe("ClientResolver", func() {
JustBeforeEach(func() {
sut = NewClientNamesResolver(sutConfig).(*ClientNamesResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -250,9 +250,9 @@ var _ = Describe("ClientResolver", func() {
})
When("Upstream produces error", func() {
JustBeforeEach(func() {
clientResolverMock := &resolverMock{}
clientResolverMock.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
sut.externalResolver = clientResolverMock
clientMockResolver := &MockResolver{}
clientMockResolver.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
sut.externalResolver = clientMockResolver
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.TypeA, "192.168.178.25")
@ -303,7 +303,7 @@ var _ = Describe("ClientResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})

View File

@ -16,13 +16,11 @@ import (
type ConditionalUpstreamResolver struct {
NextResolver
mapping map[string]Resolver
rewrite map[string]string
}
// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(cfg config.ConditionalUpstreamConfig) ChainedResolver {
m := make(map[string]Resolver)
rewrite := make(map[string]string)
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstream := range cfg.Mapping.Upstreams {
upstreams := make(map[string][]config.Upstream)
@ -30,11 +28,7 @@ func NewConditionalUpstreamResolver(cfg config.ConditionalUpstreamConfig) Chaine
m[strings.ToLower(domain)] = NewParallelBestResolver(upstreams)
}
for k, v := range cfg.Rewrite {
rewrite[strings.ToLower(k)] = strings.ToLower(v)
}
return &ConditionalUpstreamResolver{mapping: m, rewrite: rewrite}
return &ConditionalUpstreamResolver{mapping: m}
}
// Configuration returns current configuration
@ -43,13 +37,6 @@ func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
for key, val := range r.mapping {
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
}
if len(r.rewrite) > 0 {
result = append(result, "rewrite:")
for key, val := range r.rewrite {
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
}
}
} else {
result = []string{"deactivated"}
}
@ -57,22 +44,12 @@ func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
return
}
func (r *ConditionalUpstreamResolver) applyRewrite(domain string) string {
for k, v := range r.rewrite {
if strings.HasSuffix(domain, "."+k) {
return strings.TrimSuffix(domain, "."+k) + "." + v
}
}
return domain
}
// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := withPrefix(request.Log, "conditional_resolver")
if len(r.mapping) > 0 {
domainFromQuestion := r.applyRewrite(util.ExtractDomain(request.Req.Question[0]))
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion
if !strings.Contains(domainFromQuestion, ".") {

View File

@ -17,7 +17,7 @@ import (
var _ = Describe("ConditionalUpstreamResolver", func() {
var (
sut ChainedResolver
m *resolverMock
m *MockResolver
err error
resp *Response
)
@ -28,7 +28,6 @@ var _ = Describe("ConditionalUpstreamResolver", func() {
BeforeEach(func() {
sut = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
Rewrite: map[string]string{"example.com": "fritz.box"},
Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{
"fritz.box": {TestUDPUpstream(func(request *dns.Msg) (response *dns.Msg) {
@ -48,7 +47,7 @@ var _ = Describe("ConditionalUpstreamResolver", func() {
})},
}},
})
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
@ -96,22 +95,6 @@ var _ = Describe("ConditionalUpstreamResolver", func() {
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
})
When("rewrite mapping is defined", func() {
It("Should resolve the IP via defined resolver after applying the rewrite", func() {
resp, err = sut.Resolve(newRequest("test.example.com.", dns.TypeA))
Expect(resp.Res.Answer).Should(BeDNSRecord("test.fritz.box.", dns.TypeA, 123, "123.124.122.122"))
// no call to next resolver
Expect(m.Calls).Should(BeEmpty())
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
It("Should delegate to next resolver if there is no subdomain after rewrite", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.TypeA))
m.AssertExpectations(GinkgoT())
})
})
})
Describe("Delegation to next resolver", func() {
When("Query doesn't match defined mapping", func() {
@ -127,7 +110,7 @@ var _ = Describe("ConditionalUpstreamResolver", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {

View File

@ -24,8 +24,8 @@ type CustomDNSResolver struct {
// NewCustomDNSResolver creates new resolver instance
func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver {
m := make(map[string][]net.IP)
reverse := make(map[string][]string)
m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs))
reverse := make(map[string][]string, len(cfg.Mapping.HostIPs))
for url, ips := range cfg.Mapping.HostIPs {
m[strings.ToLower(url)] = ips

View File

@ -16,7 +16,7 @@ import (
var _ = Describe("CustomDNSResolver", func() {
var (
sut ChainedResolver
m *resolverMock
m *MockResolver
err error
resp *Response
)
@ -35,7 +35,7 @@ var _ = Describe("CustomDNSResolver", func() {
}},
CustomTTL: config.Duration(time.Duration(TTL) * time.Second),
})
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})

View File

@ -17,7 +17,7 @@ import (
var _ = Describe("HostsFileResolver", func() {
var (
sut *HostsFileResolver
m *resolverMock
m *MockResolver
err error
resp *Response
)
@ -31,7 +31,7 @@ var _ = Describe("HostsFileResolver", func() {
RefreshPeriod: config.Duration(30 * time.Minute),
}
sut = NewHostsFileResolver(cfg).(*HostsFileResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
@ -44,7 +44,7 @@ var _ = Describe("HostsFileResolver", func() {
Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()),
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
}).(*HostsFileResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
@ -62,7 +62,7 @@ var _ = Describe("HostsFileResolver", func() {
When("Hosts file is not set", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{}).(*HostsFileResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})

View File

@ -14,7 +14,7 @@ import (
var _ = Describe("IPv6DisablingResolver", func() {
var (
sut *IPv6DisablingResolver
m *resolverMock
m *MockResolver
mockAnswer *dns.Msg
disableIPv6 *bool
query = newRequest("example.com", dns.TypeAAAA)
@ -23,7 +23,7 @@ var _ = Describe("IPv6DisablingResolver", func() {
JustBeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1230, dns.TypeAAAA, "2001:0db8:85a3:08d3:1319:8a2e:0370:7344")
sut = NewIPv6Checker(*disableIPv6).(*IPv6DisablingResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
sut.Next(m)
})

View File

@ -18,14 +18,14 @@ import (
var _ = Describe("MetricResolver", func() {
var (
sut *MetricsResolver
m *resolverMock
m *MockResolver
err error
resp *Response
)
BeforeEach(func() {
sut = NewMetricsResolver(config.PrometheusConfig{Enable: true}).(*MetricsResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
@ -47,7 +47,7 @@ var _ = Describe("MetricResolver", func() {
})
When("Error occurs while request processing", func() {
BeforeEach(func() {
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
sut.Next(m)
})
@ -65,7 +65,7 @@ var _ = Describe("MetricResolver", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
})

View File

@ -18,50 +18,60 @@ import (
)
type MockResolver struct {
AnswerFn func(t uint16, qName string) *dns.Msg
}
func (m *MockResolver) Resolve(req *model.Request) (*model.Response, error) {
for _, question := range req.Req.Question {
answer := m.AnswerFn(question.Qtype, question.Name)
if answer != nil {
return &model.Response{
Res: answer,
Reason: "",
RType: model.ResponseTypeRESOLVED,
}, nil
}
}
response := new(dns.Msg)
response.SetRcode(req.Req, dns.RcodeBadName)
return &model.Response{
Res: response,
Reason: "",
RType: model.ResponseTypeRESOLVED,
}, nil
}
func (m *MockResolver) Configuration() []string {
return []string{}
}
type resolverMock struct {
mock.Mock
NextResolver
ResolveFn func(req *model.Request) (*model.Response, error)
ResponseFn func(req *dns.Msg) *dns.Msg
AnswerFn func(t uint16, qName string) *dns.Msg
}
func (r *resolverMock) Configuration() (result []string) {
return
func (r *MockResolver) Configuration() []string {
args := r.Called()
return args.Get(0).([]string)
}
func (r *resolverMock) Resolve(req *model.Request) (*model.Response, error) {
func (r *MockResolver) Resolve(req *model.Request) (*model.Response, error) {
args := r.Called(req)
if r.ResolveFn != nil {
return r.ResolveFn(req)
}
if r.ResponseFn != nil {
return &model.Response{
Res: r.ResponseFn(req.Req),
Reason: "",
RType: model.ResponseTypeRESOLVED,
}, nil
}
if r.AnswerFn != nil {
for _, question := range req.Req.Question {
answer := r.AnswerFn(question.Qtype, question.Name)
if answer != nil {
return &model.Response{
Res: answer,
Reason: "",
RType: model.ResponseTypeRESOLVED,
}, nil
}
}
response := new(dns.Msg)
response.SetRcode(req.Req, dns.RcodeBadName)
return &model.Response{
Res: response,
Reason: "",
RType: model.ResponseTypeRESOLVED,
}, nil
}
resp, ok := args.Get(0).(*model.Response)
if ok {
return resp, args.Error((1))
return resp, args.Error(1)
}
return nil, args.Error(1)

22
resolver/noop_resolver.go Normal file
View File

@ -0,0 +1,22 @@
package resolver
import (
"github.com/0xERR0R/blocky/model"
)
var NoResponse = &model.Response{} // nolint:gochecknoglobals
// NoOpResolver is used to finish a resolver branch as created in RewriterResolver
type NoOpResolver struct{}
func NewNoOpResolver() Resolver {
return NoOpResolver{}
}
func (r NoOpResolver) Configuration() (result []string) {
return nil
}
func (r NoOpResolver) Resolve(request *model.Request) (*model.Response, error) {
return NoResponse, nil
}

View File

@ -0,0 +1,30 @@
package resolver
import (
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("NoOpResolver", func() {
var sut NoOpResolver
BeforeEach(func() {
sut = NewNoOpResolver().(NoOpResolver)
})
Describe("Resolving", func() {
It("returns no response", func() {
resp, err := sut.Resolve(newRequest("test.tld", dns.TypeA))
Expect(err).Should(Succeed())
Expect(resp).Should(Equal(NoResponse))
})
})
Describe("Configuration output", func() {
It("returns nothing", func() {
c := sut.Configuration()
Expect(c).Should(BeNil())
})
})
})

View File

@ -37,7 +37,7 @@ type requestResponse struct {
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(upstreamResolvers map[string][]config.Upstream) Resolver {
s := make(map[string][]*upstreamResolverStatus)
s := make(map[string][]*upstreamResolverStatus, len(upstreamResolvers))
logger := logger(parallelResolverLogger)
for name, res := range upstreamResolvers {

View File

@ -300,7 +300,7 @@ var _ = Describe("ParallelBestResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})

View File

@ -42,7 +42,7 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig config.QueryLogConfig
err error
resp *Response
m *resolverMock
m *MockResolver
tmpDir string
mockAnswer *dns.Msg
)
@ -55,7 +55,7 @@ var _ = Describe("QueryLoggingResolver", func() {
JustBeforeEach(func() {
sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver)
m = &resolverMock{}
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
sut.Next(m)
})
@ -220,7 +220,7 @@ var _ = Describe("QueryLoggingResolver", func() {
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c) > 1).Should(BeTrue())
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
})

View File

@ -56,7 +56,7 @@ type Resolver interface {
// Resolve performs resolution of a DNS request
Resolve(req *model.Request) (*model.Response, error)
// Configuration prints current resolver configuration
// Configuration returns current resolver configuration
Configuration() []string
}
@ -86,6 +86,13 @@ func (r *NextResolver) GetNext() Resolver {
return r.next
}
// NamedResolver is a resolver with a special name
type NamedResolver interface {
// Name returns the full name of the resolver
Name() string
}
func logger(prefix string) *logrus.Entry {
return log.PrefixedLog(prefix)
}
@ -109,5 +116,14 @@ func Chain(resolvers ...Resolver) Resolver {
// Name returns a user-friendly name of a resolver
func Name(resolver Resolver) string {
if named, ok := resolver.(NamedResolver); ok {
return named.Name()
}
return defaultName(resolver)
}
// defaultName returns a short user-friendly name of a resolver
func defaultName(resolver Resolver) string {
return strings.Split(fmt.Sprintf("%T", resolver), ".")[1]
}

View File

@ -20,12 +20,23 @@ var _ = Describe("Resolver", func() {
Expect(next).ShouldNot(BeNil())
})
})
When("'Name' will be called", func() {
When("'Name' is called", func() {
It("should return resolver name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil)
name := Name(br)
Expect(name).Should(Equal("BlockingResolver"))
})
})
When("'Name' is called on a NamedResolver", func() {
It("should return it's custom name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil)
cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}}
r := NewRewriterResolver(cfg, br)
name := Name(r)
Expect(name).Should(Equal("BlockingResolver w/ RewriterResolver"))
})
})
})
})

View File

@ -0,0 +1,136 @@
package resolver
import (
"fmt"
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// RewriterResolver is different from other resolvers, in the sense that
// it creates a branch in the resolver chain.
// The branch is where the rewrite is active. If the branch doesn't
// yield a result, the normal resolving is continued.
type RewriterResolver struct {
NextResolver
rewrite map[string]string
inner Resolver
}
func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) ChainedResolver {
if len(cfg.Rewrite) == 0 {
return inner
}
for k, v := range cfg.Rewrite {
cfg.Rewrite[strings.ToLower(k)] = strings.ToLower(v)
}
inner.Next(NewNoOpResolver())
return &RewriterResolver{
rewrite: cfg.Rewrite,
inner: inner,
}
}
func (r *RewriterResolver) Name() string {
return fmt.Sprintf("%s w/ %s", Name(r.inner), defaultName(r))
}
// Configuration returns current resolver configuration
func (r *RewriterResolver) Configuration() (result []string) {
result = append(result, "rewrite:")
for key, val := range r.rewrite {
result = append(result, fmt.Sprintf(" %s = \"%s\"", key, val))
}
innerCfg := r.inner.Configuration()
result = append(result, innerCfg...)
return result
}
// Resolve uses the inner resolver to resolve the rewritten query
func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := withPrefix(request.Log, "rewriter_resolver")
original := request.Req
rewritten, originalNames := r.rewriteRequest(logger, original)
if rewritten != nil {
request.Req = rewritten
}
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
response, err := r.inner.Resolve(request)
if err != nil {
return response, err
}
// Revert the request: must be done before calling r.next
request.Req = original
if response == NoResponse {
// Inner resolver had no response, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
}
// Revert the rewrite in r.inner's response
if rewritten != nil {
for i := range originalNames {
response.Res.Question[i].Name = originalNames[i]
if i < len(response.Res.Answer) {
response.Res.Answer[i].Header().Name = originalNames[i]
}
}
}
return response, nil
}
func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg) (rewritten *dns.Msg, originalNames []string) { // nolint: lll
originalNames = make([]string, len(request.Question))
for i := range request.Question {
nameOriginal := request.Question[i].Name
originalNames[i] = nameOriginal
domainOriginal := util.ExtractDomainOnly(nameOriginal)
domainRewritten, rewriteKey := r.rewriteDomain(domainOriginal)
if domainRewritten != domainOriginal {
if rewritten == nil {
rewritten = request.Copy()
}
rewritten.Question[i].Name = dns.Fqdn(domainRewritten)
logger.WithFields(logrus.Fields{
"domain": domainOriginal,
"rewrite": rewriteKey + ":" + r.rewrite[rewriteKey],
}).Debugf("rewriting %q to %q", domainOriginal, domainRewritten)
}
}
return rewritten, originalNames
}
func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
for k, v := range r.rewrite {
if strings.HasSuffix(domain, "."+k) {
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
return newDomain, k
}
}
return domain, ""
}

View File

@ -0,0 +1,141 @@
package resolver
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/mock"
)
var _ = Describe("RewriterResolver", func() {
var (
sut ChainedResolver
sutConfig config.RewriteConfig
mInner *MockResolver
mNext *MockResolver
fqdnOriginal string
fqdnRewritten string
mNextResponse *model.Response
)
BeforeEach(func() {
mInner = &MockResolver{}
mNext = &MockResolver{}
sutConfig = config.RewriteConfig{Rewrite: map[string]string{"original": "rewritten"}}
})
JustBeforeEach(func() {
sut = NewRewriterResolver(sutConfig, mInner)
sut.Next(mNext)
})
AfterEach(func() {
mInner.AssertExpectations(GinkgoT())
mNext.AssertExpectations(GinkgoT())
})
When("has no configuration", func() {
BeforeEach(func() {
sutConfig = config.RewriteConfig{}
})
It("should return the inner resolver", func() {
Expect(sut).Should(BeIdenticalTo(mInner))
})
})
When("has rewrite", func() {
var request *model.Request
AfterEach(func() {
request = newRequest(fqdnOriginal, dns.TypeA)
mInner.On("Resolve", mock.Anything)
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
Expect(req).Should(Equal(request.Req))
// Inner should see fqdnRewritten
q := req.Question[0]
Expect(q.Name).Should(Equal(fqdnRewritten))
res := new(dns.Msg)
res.SetReply(req)
ptr := new(dns.PTR)
ptr.Ptr = fqdnRewritten
ptr.Hdr = util.CreateHeader(q, 1)
res.Answer = append(res.Answer, ptr)
return res
}
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
if resp != mNextResponse {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
Expect(resp.Res.Answer[0].Header().Name).Should(Equal(fqdnOriginal))
}
})
It("should modify names", func() {
fqdnOriginal = "test.original."
fqdnRewritten = "test.rewritten."
})
It("should modify subdomains", func() {
fqdnOriginal = "sub.test.original."
fqdnRewritten = "sub.test.rewritten."
})
It("should not modify unknown names", func() {
fqdnOriginal = "test.untouched."
fqdnRewritten = fqdnOriginal
})
It("should not modify name if subdomain", func() {
fqdnOriginal = "test.original.untouched."
fqdnRewritten = fqdnOriginal
})
It("should call next resolver", func() {
fqdnOriginal = "test.original."
fqdnRewritten = "test.rewritten."
// Make inner call the NoOpResolver
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
return mInner.next.Resolve(req)
}
// Resolver after RewriterResolver should see `fqdnOriginal`
mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil
}
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
innerOutput := []string{"inner:", "config-output"}
mInner.On("Configuration").Return(innerOutput)
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", len(innerOutput)))
})
})
})
})

View File

@ -198,11 +198,11 @@ func createQueryResolver(cfg *config.Config, redisClient *redis.Client) (resolve
resolver.NewClientNamesResolver(cfg.ClientLookup),
resolver.NewQueryLoggingResolver(cfg.QueryLog),
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewCustomDNSResolver(cfg.CustomDNS),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
resolver.NewHostsFileResolver(cfg.HostsFile),
br,
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewConditionalUpstreamResolver(cfg.Conditional),
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, resolver.NewConditionalUpstreamResolver(cfg.Conditional)),
resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers),
), brErr
}