Test refactoring (#798)

* test: refactor tests

* chore: fix possible race condition in cache
This commit is contained in:
Dimitri Herzog 2022-12-29 14:58:25 +01:00 committed by GitHub
parent 30086dc957
commit 53a7d4fccc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1790 additions and 1032 deletions

View File

@ -8,7 +8,6 @@ import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/testcontainers/testcontainers-go"
@ -37,9 +36,14 @@ var _ = Describe("Basic functional tests", func() {
DeferCleanup(blocky.Terminate)
})
It("Should start and answer DNS queries", func() {
msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("google.de.", A)
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 123, "1.2.3.4"))
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
})
It("should return 'healthy' container status (healthcheck)", func() {
Eventually(func(g Gomega) string {

View File

@ -3,7 +3,6 @@ package e2e
import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/testcontainers/testcontainers-go"
@ -43,9 +42,14 @@ var _ = Describe("External lists and query blocking", func() {
})
It("should start with warning in log work without errors", func() {
msg := util.NewMsgWithQuestion("google.com.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("google.com.", A)
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("google.com.", dns.TypeA, 123, "1.2.3.4"))
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("error during file processing")))
})
@ -108,9 +112,14 @@ var _ = Describe("External lists and query blocking", func() {
DeferCleanup(blocky.Terminate)
})
It("should download external list on startup and block queries", func() {
msg := util.NewMsgWithQuestion("blockeddomain.com.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("blockeddomain.com.", A)
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("blockeddomain.com.", dns.TypeA, 6*60*60, "0.0.0.0"))
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("blockeddomain.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 6*60*60)),
))
Expect(getContainerLogs(blocky)).Should(BeEmpty())
})

View File

@ -9,7 +9,6 @@ import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/testcontainers/testcontainers-go"
@ -74,28 +73,50 @@ var _ = Describe("Metrics functional tests", func() {
When("Some query results are cached", func() {
BeforeEach(func() {
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_entry_count 0"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_hit_count 0"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_miss_count 0"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 0"),
ContainElement("blocky_cache_hit_count 0"),
ContainElement("blocky_cache_miss_count 0"),
))
})
It("Should increment cache counts", func() {
msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("google.de.", A)
By("first query, should increment the cache miss count and the total count", func() {
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 123, "1.2.3.4"))
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_entry_count 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_hit_count 0"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_miss_count 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 1"),
ContainElement("blocky_cache_hit_count 0"),
ContainElement("blocky_cache_miss_count 1"),
))
})
By("Same query again, should increment the cache hit count", func() {
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 0, "1.2.3.4"))
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("<=", 123)),
))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_entry_count 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_hit_count 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).Should(ContainElement("blocky_cache_miss_count 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 1"),
ContainElement("blocky_cache_hit_count 1"),
ContainElement("blocky_cache_miss_count 1"),
))
})
})
})
@ -103,9 +124,11 @@ var _ = Describe("Metrics functional tests", func() {
When("Lists are loaded", func() {
It("Should expose list cache sizes per group as metrics", func() {
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Should(ContainElement("blocky_blacklist_cache{group=\"group1\"} 1"))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Should(ContainElement("blocky_blacklist_cache{group=\"group2\"} 3"))
Should(
SatisfyAll(
ContainElement("blocky_blacklist_cache{group=\"group1\"} 1"),
ContainElement("blocky_blacklist_cache{group=\"group2\"} 3"),
))
})
})
})

View File

@ -78,21 +78,25 @@ var _ = Describe("Query logs functional tests", func() {
Expect(entries).Should(HaveLen(2))
Expect(entries[0]).Should(SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
Expect(entries[0]).
Should(
SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
Expect(entries[1]).Should(SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "unknown.domain"),
HaveField("Answer", ""),
HaveField("ResponseCode", "NXDOMAIN"),
))
Expect(entries[1]).
Should(
SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "unknown.domain"),
HaveField("Answer", ""),
HaveField("ResponseCode", "NXDOMAIN"),
))
})
})
})
@ -149,21 +153,25 @@ var _ = Describe("Query logs functional tests", func() {
Expect(entries).Should(HaveLen(2))
Expect(entries[0]).Should(SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
Expect(entries[0]).
Should(
SatisfyAll(
HaveField("ResponseType", "RESOLVED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
Expect(entries[1]).Should(SatisfyAll(
HaveField("ResponseType", "CACHED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
Expect(entries[1]).
Should(
SatisfyAll(
HaveField("ResponseType", "CACHED"),
HaveField("QuestionType", "A"),
HaveField("QuestionName", "google.de"),
HaveField("Answer", "A (1.2.3.4)"),
HaveField("ResponseCode", "NOERROR"),
))
})
})
})

View File

@ -7,7 +7,6 @@ import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/util"
"github.com/go-redis/redis/v8"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/testcontainers/testcontainers-go"
@ -71,9 +70,14 @@ var _ = Describe("Redis configuration tests", func() {
DeferCleanup(blocky2.Terminate)
})
It("2nd instance of blocky should use cache from redis", func() {
msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("google.de.", A)
By("Query first blocky instance, should store cache in redis", func() {
Expect(doDNSRequest(blocky1, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 123, "1.2.3.4"))
Expect(doDNSRequest(blocky1, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
})
By("Check redis, must contain one cache entry", func() {
@ -85,7 +89,12 @@ var _ = Describe("Redis configuration tests", func() {
})
By("Query second blocky instance, should use cache from redis", func() {
Expect(doDNSRequest(blocky2, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 0, "1.2.3.4"))
Expect(doDNSRequest(blocky2, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("<=", 123)),
))
})
By("No warnings/errors in log", func() {
@ -113,9 +122,14 @@ var _ = Describe("Redis configuration tests", func() {
DeferCleanup(blocky1.Terminate)
})
It("should load cache from redis after start", func() {
msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("google.de.", A)
By("Query first blocky instance, should store cache in redis\"", func() {
Expect(doDNSRequest(blocky1, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 123, "1.2.3.4"))
Expect(doDNSRequest(blocky1, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
})
By("Check redis, must contain one cache entry", func() {
@ -142,7 +156,12 @@ var _ = Describe("Redis configuration tests", func() {
})
By("Query second blocky instance", func() {
Expect(doDNSRequest(blocky2, msg)).Should(BeDNSRecord("google.de.", dns.TypeA, 0, "1.2.3.4"))
Expect(doDNSRequest(blocky2, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("<=", 123)),
))
})
By("No warnings/errors in log", func() {

View File

@ -111,13 +111,17 @@ var _ = Describe("Upstream resolver configuration tests", func() {
})
It("should consider the timeout parameter", func() {
By("query without timeout", func() {
msg := util.NewMsgWithQuestion("example.com.", dns.Type(dns.TypeA))
Expect(doDNSRequest(blocky, msg)).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "1.2.3.4"))
msg := util.NewMsgWithQuestion("example.com.", A)
Expect(doDNSRequest(blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
})
By("query with timeout", func() {
msg := util.NewMsgWithQuestion("delay.com/.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("delay.com/.", A)
resp, err := doDNSRequest(blocky, msg)
Expect(err).Should(Succeed())

1
go.mod
View File

@ -37,6 +37,7 @@ require (
require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751
github.com/docker/go-connections v0.4.0
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
github.com/golangci/golangci-lint v1.50.1

1
go.sum
View File

@ -113,6 +113,7 @@ github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:H
github.com/abice/go-enum v0.5.4 h1:d6/smT5ZlVO/u9rSkYDP6XSp9NDMmQ16sUUR27O1gBA=
github.com/abice/go-enum v0.5.4/go.mod h1:TfHm+vl7PLBrd0oaWNUPRylZvC5XzScwIEGS2+jN+j4=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=

View File

@ -8,11 +8,22 @@ import (
"os"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
"github.com/onsi/gomega"
"github.com/onsi/gomega/types"
)
const (
A = dns.Type(dns.TypeA)
AAAA = dns.Type(dns.TypeAAAA)
HTTPS = dns.Type(dns.TypeHTTPS)
MX = dns.Type(dns.TypeMX)
PTR = dns.Type(dns.TypePTR)
TXT = dns.Type(dns.TypeTXT)
)
// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")
@ -52,27 +63,89 @@ func DoGetRequest(url string,
return rr, rr.Body
}
func ToAnswer(m *model.Response) []dns.RR {
return m.Res.Answer
}
func ToExtra(m *model.Response) []dns.RR {
return m.Res.Extra
}
func HaveNoAnswer() types.GomegaMatcher {
return gomega.WithTransform(ToAnswer, gomega.BeEmpty())
}
func HaveReason(reason string) types.GomegaMatcher {
return gomega.WithTransform(func(m *model.Response) string {
return m.Reason
}, gomega.Equal(reason))
}
func HaveResponseType(c model.ResponseType) types.GomegaMatcher {
return gomega.WithTransform(func(m *model.Response) model.ResponseType {
return m.RType
}, gomega.Equal(c))
}
func HaveReturnCode(code int) types.GomegaMatcher {
return gomega.WithTransform(func(m *model.Response) int {
return m.Res.Rcode
}, gomega.Equal(code))
}
func toFirstRR(actual interface{}) (dns.RR, error) {
switch i := actual.(type) {
case *model.Response:
return toFirstRR(i.Res)
case *dns.Msg:
return toFirstRR(i.Answer)
case []dns.RR:
if len(i) == 0 {
return nil, fmt.Errorf("answer must not be empty")
}
if len(i) == 1 {
return toFirstRR(i[0])
}
return nil, fmt.Errorf("supports only single RR in answer")
case dns.RR:
return i, nil
default:
return nil, fmt.Errorf("not supported type")
}
}
func HaveTTL(matcher types.GomegaMatcher) types.GomegaMatcher {
return gomega.WithTransform(func(actual interface{}) (uint32, error) {
rr, err := toFirstRR(actual)
if err != nil {
return 0, err
}
return rr.Header().Ttl, nil
}, matcher)
}
// BeDNSRecord returns new dns matcher
func BeDNSRecord(domain string, dnsType uint16, ttl uint32, answer string) types.GomegaMatcher {
func BeDNSRecord(domain string, dnsType dns.Type, answer string) types.GomegaMatcher {
return &dnsRecordMatcher{
domain: domain,
dnsType: dnsType,
TTL: ttl,
answer: answer,
}
}
type dnsRecordMatcher struct {
domain string
dnsType uint16
TTL uint32
dnsType dns.Type
answer string
}
func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error) {
if (rr.Header().Name != matcher.domain) ||
(rr.Header().Rrtype != matcher.dnsType) ||
(matcher.TTL > 0 && (rr.Header().Ttl != matcher.TTL)) {
(dns.Type(rr.Header().Rrtype) != matcher.dnsType) {
return false, nil
}
@ -92,30 +165,22 @@ func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error
// Match checks the DNS record
func (matcher *dnsRecordMatcher) Match(actual interface{}) (success bool, err error) {
switch i := actual.(type) {
case *dns.Msg:
return matcher.Match(i.Answer)
case dns.RR:
return matcher.matchSingle(i)
case []dns.RR:
if len(i) == 1 {
return matcher.matchSingle(i[0])
}
return false, fmt.Errorf("DNSRecord matcher expects []dns.RR with len == 1")
default:
return false, fmt.Errorf("DNSRecord matcher expects an dns.RR or []dns.RR")
rr, err := toFirstRR(actual)
if err != nil {
return false, err
}
return matcher.matchSingle(rr)
}
// FailureMessage generates a failure messge
// FailureMessage generates a failure message
func (matcher *dnsRecordMatcher) FailureMessage(actual interface{}) (message string) {
return fmt.Sprintf("Expected\n\t%s\n to contain\n\t domain '%s', ttl '%d', type '%s', answer '%s'",
actual, matcher.domain, matcher.TTL, dns.TypeToString[matcher.dnsType], matcher.answer)
return fmt.Sprintf("Expected\n\t%s\n to contain\n\t domain '%s', type '%s', answer '%s'",
actual, matcher.domain, dns.TypeToString[uint16(matcher.dnsType)], matcher.answer)
}
// NegatedFailureMessage creates negated message
func (matcher *dnsRecordMatcher) NegatedFailureMessage(actual interface{}) (message string) {
return fmt.Sprintf("Expected\n\t%s\n not to contain\n\t domain '%s', ttl '%d', type '%s', answer '%s'",
actual, matcher.domain, matcher.TTL, dns.TypeToString[matcher.dnsType], matcher.answer)
return fmt.Sprintf("Expected\n\t%s\n not to contain\n\t domain '%s', type '%s', answer '%s'",
actual, matcher.domain, dns.TypeToString[uint16(matcher.dnsType)], matcher.answer)
}

View File

@ -6,6 +6,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/0xERR0R/blocky/cache/expirationcache"
@ -93,7 +94,7 @@ type BlockingResolver struct {
// NewBlockingResolver returns a new configured instance of the resolver
func NewBlockingResolver(
cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap,
) (r ChainedResolver, err error) {
) (r *BlockingResolver, err error) {
blockHandler, err := createBlockHandler(cfg)
if err != nil {
return nil, err
@ -603,7 +604,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (result []
if err == nil && resp.Res.Rcode == dns.RcodeSuccess {
for _, rr := range resp.Res.Answer {
ttl = time.Duration(rr.Header().Ttl) * time.Second
ttl = time.Duration(atomic.LoadUint32(&rr.Header().Ttl)) * time.Second
switch v := rr.(type) {
case *dns.A:

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -178,13 +179,13 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("upstream returns an IPv6", func() {
It("it is used", func() {
bootstrapResponse, err := util.NewMsgWithAnswer(
"localhost.", 123, dns.Type(dns.TypeAAAA), net.IPv6loopback.String(),
"localhost.", 123, AAAA, net.IPv6loopback.String(),
)
Expect(err).Should(Succeed())
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("localhost", []dns.Type{dns.Type(dns.TypeAAAA)})
ips, err := sut.resolve("localhost", []dns.Type{AAAA})
Expect(err).Should(Succeed())
Expect(ips).Should(HaveLen(1))
@ -198,7 +199,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
ips, err := sut.resolve("localhost", []dns.Type{dns.Type(dns.TypeA)})
ips, err := sut.resolve("localhost", []dns.Type{A})
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
@ -212,7 +213,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{dns.Type(dns.TypeA)})
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A})
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host"))
@ -223,7 +224,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("called from another UpstreamResolver", func() {
It("uses the bootstrap upstream", func() {
mainReq := &model.Request{
Req: util.NewMsgWithQuestion("example.com.", dns.Type(dns.TypeA)),
Req: util.NewMsgWithQuestion("example.com.", A),
Log: logrus.NewEntry(log.Log()),
}
@ -234,7 +235,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
upstreamIP := upstream.Host
bootstrapResponse, err := util.NewMsgWithAnswer(
"localhost.", 123, dns.Type(dns.TypeA), upstreamIP,
"localhost.", 123, A, upstreamIP,
)
Expect(err).Should(Succeed())
@ -266,7 +267,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
Expect(err).Should(Succeed())
bootstrapResponse, err := util.NewMsgWithAnswer(
"localhost.", 123, dns.Type(dns.TypeA), host,
"localhost.", 123, A, host,
)
Expect(err).Should(Succeed())

View File

@ -2,6 +2,7 @@ package resolver
import (
"fmt"
"sync/atomic"
"time"
"github.com/hako/durafmt"
@ -40,7 +41,7 @@ type cacheValue struct {
}
// NewCachingResolver creates a new resolver instance
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedResolver {
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
c := &CachingResolver{
minCacheTimeSec: int(time.Duration(cfg.MinCachingTime).Seconds()),
maxCacheTimeSec: int(time.Duration(cfg.MaxCachingTime).Seconds()),
@ -183,9 +184,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
}
// Answer from successful request
resp.Answer = v.answer
for _, rr := range resp.Answer {
rr.Header().Ttl = uint32(ttl.Seconds())
for _, rr := range v.answer {
// make copy here since entries in cache can be modified by other goroutines (e.g. redis cache)
cp := dns.Copy(rr)
cp.Header().Ttl = uint32(ttl.Seconds())
resp.Answer = append(resp.Answer, cp)
}
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
@ -260,19 +264,20 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
for _, a := range answer {
// if TTL < mitTTL -> adjust the value, set minTTL
if r.minCacheTimeSec > 0 {
if a.Header().Ttl < uint32(r.minCacheTimeSec) {
a.Header().Ttl = uint32(r.minCacheTimeSec)
if atomic.LoadUint32(&a.Header().Ttl) < uint32(r.minCacheTimeSec) {
atomic.StoreUint32(&a.Header().Ttl, uint32(r.minCacheTimeSec))
}
}
if r.maxCacheTimeSec > 0 {
if a.Header().Ttl > uint32(r.maxCacheTimeSec) {
a.Header().Ttl = uint32(r.maxCacheTimeSec)
if atomic.LoadUint32(&a.Header().Ttl) > uint32(r.maxCacheTimeSec) {
atomic.StoreUint32(&a.Header().Ttl, uint32(r.maxCacheTimeSec))
}
}
if max < a.Header().Ttl {
max = a.Header().Ttl
headerTTL := atomic.LoadUint32(&a.Header().Ttl)
if max < headerTTL {
max = headerTTL
}
}

View File

@ -21,13 +21,10 @@ import (
var _ = Describe("CachingResolver", func() {
var (
sut ChainedResolver
sut *CachingResolver
sutConfig config.CachingConfig
m *mockResolver
mockAnswer *dns.Msg
err error
resp *Response
)
BeforeEach(func() {
@ -38,10 +35,6 @@ var _ = Describe("CachingResolver", func() {
mockAnswer = new(dns.Msg)
})
AfterEach(func() {
Expect(err).Should(Succeed())
})
JustBeforeEach(func() {
sut = NewCachingResolver(sutConfig, nil)
m = &mockResolver{}
@ -57,16 +50,16 @@ var _ = Describe("CachingResolver", func() {
PrefetchExpires: config.Duration(time.Minute * 120),
PrefetchThreshold: 5,
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 2, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 2, A, "123.122.121.120")
})
It("should prefetch domain if query count > threshold", func() {
// prepare resolver, set smaller caching times for testing
prefetchThreshold := 5
configureCaches(sut.(*CachingResolver), &sutConfig)
sut.(*CachingResolver).resultCache = expirationcache.NewCache(
configureCaches(sut, &sutConfig)
sut.resultCache = expirationcache.NewCache(
expirationcache.WithCleanUpInterval(100*time.Millisecond),
expirationcache.WithOnExpiredFn(sut.(*CachingResolver).onExpired))
expirationcache.WithOnExpiredFn(sut.onExpired))
domainPrefetched := make(chan string, 1)
prefetchHitDomain := make(chan string, 1)
@ -83,7 +76,7 @@ var _ = Describe("CachingResolver", func() {
})).Should(Succeed())
// first request
_, _ = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, _ = sut.Resolve(newRequest("example.com.", A))
// Domain is not prefetched
Expect(domainPrefetched).ShouldNot(Receive())
@ -93,7 +86,7 @@ var _ = Describe("CachingResolver", func() {
// now query again > threshold
for i := 0; i < prefetchThreshold+1; i++ {
_, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(Succeed())
}
@ -101,10 +94,13 @@ var _ = Describe("CachingResolver", func() {
Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com")))
// and it should hit from prefetch cache
res, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(res.RType).Should(Equal(ResponseTypeCACHED))
Expect(res.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
HaveTTL(BeNumerically("<=", 2))))
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
})
})
@ -116,7 +112,7 @@ var _ = Describe("CachingResolver", func() {
})
Context("response TTL is bigger than defined min caching time", func() {
BeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 600, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 600, A, "123.122.121.120")
})
It("should cache response and use response's TTL", func() {
@ -130,13 +126,15 @@ var _ = Describe("CachingResolver", func() {
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
totalCacheCount <- d
})
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
HaveTTL(BeNumerically("==", 600))))
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 600, "123.122.121.120"))
Expect(domain).Should(Receive(Equal("example.com")))
Expect(totalCacheCount).Should(Receive(Equal(1)))
@ -149,14 +147,17 @@ var _ = Describe("CachingResolver", func() {
domain <- d
})
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
// ttl is smaller
HaveTTL(BeNumerically("<=", 599))))
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// ttl is smaller
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 599, "123.122.121.120"))
g.Expect(domain).Should(Receive(Equal("example.com")))
}, "1s").Should(Succeed())
@ -166,30 +167,34 @@ var _ = Describe("CachingResolver", func() {
Context("response TTL is smaller than defined min caching time", func() {
Context("A query", func() {
BeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 123, A, "123.122.121.120")
})
It("should cache response and use min caching time as TTL", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
HaveTTL(BeNumerically("==", 300))))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 300, "123.122.121.120"))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
// ttl is smaller
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 299, "123.122.121.120"))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
// ttl is smaller
HaveTTL(BeNumerically("<=", 299))))
// still one call to upstream
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -197,32 +202,33 @@ var _ = Describe("CachingResolver", func() {
Context("AAAA query", func() {
BeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 123,
dns.Type(dns.TypeAAAA), "2001:0db8:85a3:08d3:1319:8a2e:0370:7344")
AAAA, "2001:0db8:85a3:08d3:1319:8a2e:0370:7344")
})
It("should cache response and use min caching time as TTL", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
HaveTTL(BeNumerically("==", 300))))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 300, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
// ttl is smaller
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 299, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
// ttl is smaller
HaveTTL(BeNumerically("<=", 299))))
// still one call to upstream
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -233,7 +239,7 @@ var _ = Describe("CachingResolver", func() {
mockAnswer, _ = util.NewMsgWithAnswer(
"example.com.",
1230,
dns.Type(dns.TypeAAAA),
AAAA,
"2001:0db8:85a3:08d3:1319:8a2e:0370:7344",
)
})
@ -246,26 +252,28 @@ var _ = Describe("CachingResolver", func() {
It("Shouldn't cache any responses", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
HaveTTL(BeNumerically("==", 1230))))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 1230, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// one more call to upstream
g.Expect(m.Calls).Should(HaveLen(2))
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 1230, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
// ttl is smaller
HaveTTL(BeNumerically("==", 1230))))
// one more call to upstream
Expect(m.Calls).Should(HaveLen(2))
})
})
})
@ -278,34 +286,36 @@ var _ = Describe("CachingResolver", func() {
})
It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 240, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.",
AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
HaveTTL(BeNumerically("==", 240))))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
// ttl is smaller
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeAAAA, 239, "2001:db8:85a3:8d3:1319:8a2e:370:7344"))
}, "1s").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.",
AAAA, "2001:db8:85a3:8d3:1319:8a2e:370:7344"),
// ttl is smaller
HaveTTL(BeNumerically("<=", 239))))
// still one call to upstream
Expect(m.Calls).Should(HaveLen(1))
})
})
})
})
When("Entry expires in cache", func() {
BeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1, dns.Type(dns.TypeA), "1.1.1.1")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1, A, "1.1.1.1")
})
Context("max caching time is defined", func() {
BeforeEach(func() {
@ -315,27 +325,31 @@ var _ = Describe("CachingResolver", func() {
})
It("should cache response and return 0 TTL if entry is expired", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.",
A, "1.1.1.1"),
HaveTTL(BeNumerically("==", 1))))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeA, 1, "1.1.1.1"))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
// ttl is 0
g.Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.",
dns.TypeA, 0, "1.1.1.1"))
}, "1100ms").Should(Succeed())
Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.",
A, "1.1.1.1"),
// ttl is 0
HaveTTL(BeNumerically("==", 0))))
// still one call to upstream
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -351,23 +365,27 @@ var _ = Describe("CachingResolver", func() {
It("response should be cached", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
HaveNoAnswer(),
))
Expect(m.Calls).Should(HaveLen(1))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Reason).Should(Equal("CACHED NEGATIVE"))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
// still one call to resolver
g.Expect(m.Calls).Should(HaveLen(1))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED NEGATIVE"),
HaveReturnCode(dns.RcodeNameError),
HaveNoAnswer(),
))
// still one call to resolver
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -381,21 +399,27 @@ var _ = Describe("CachingResolver", func() {
It("response shouldn't be cached", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
HaveNoAnswer(),
))
Expect(m.Calls).Should(HaveLen(1))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
g.Expect(m.Calls).Should(HaveLen(2))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReason(""),
HaveReturnCode(dns.RcodeNameError),
HaveNoAnswer(),
))
// one more call to upstream
Expect(m.Calls).Should(HaveLen(2))
})
})
})
@ -409,23 +433,27 @@ var _ = Describe("CachingResolver", func() {
It("response should be cached", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveNoAnswer(),
))
Expect(m.Calls).Should(HaveLen(1))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeAAAA)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Reason).Should(Equal("CACHED"))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
// still one call to resolver
g.Expect(m.Calls).Should(HaveLen(1))
}, "500ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"),
HaveReturnCode(dns.RcodeSuccess),
HaveNoAnswer(),
))
// still one call to resolver
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -435,27 +463,33 @@ var _ = Describe("CachingResolver", func() {
Describe("Not A / AAAA queries should also be cached", func() {
When("MX query will be performed", func() {
BeforeEach(func() {
mockAnswer, _ = util.NewMsgWithAnswer("google.de.", 180, dns.Type(dns.TypeMX), "10 alt1.aspmx.l.google.com.")
mockAnswer, _ = util.NewMsgWithAnswer("google.de.", 180, MX, "10 alt1.aspmx.l.google.com.")
})
It("Should be cached", func() {
By("first request", func() {
resp, err = sut.Resolve(newRequest("google.de.", dns.Type(dns.TypeMX)))
Expect(err).Should(Succeed())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(newRequest("google.de.", MX))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("google.de.", MX, "alt1.aspmx.l.google.com."),
HaveTTL(BeNumerically("==", 180)),
))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("google.de.", dns.TypeMX, 180, "alt1.aspmx.l.google.com."))
})
By("second request", func() {
Eventually(func(g Gomega) {
resp, err = sut.Resolve(newRequest("google.de.", dns.Type(dns.TypeMX)))
g.Expect(err).Should(Succeed())
g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED))
g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
g.Expect(m.Calls).Should(HaveLen(1))
g.Expect(resp.Res.Answer).Should(BeDNSRecord("google.de.", dns.TypeMX, 179, "alt1.aspmx.l.google.com."))
}, "1s").Should(Succeed())
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("google.de.", MX, "alt1.aspmx.l.google.com."),
HaveTTL(BeNumerically("<=", 179)),
))
// still one call to resolver
Expect(m.Calls).Should(HaveLen(1))
})
})
})
@ -503,6 +537,7 @@ var _ = Describe("CachingResolver", func() {
redisServer *miniredis.Miniredis
redisClient *redis.Client
redisConfig *config.RedisConfig
err error
)
BeforeEach(func() {
redisServer, err = miniredis.Run()
@ -529,7 +564,7 @@ var _ = Describe("CachingResolver", func() {
sutConfig = config.CachingConfig{
MaxCachingTime: config.Duration(time.Second * 10),
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, dns.Type(dns.TypeA), "1.1.1.1")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, A, "1.1.1.1")
sut = NewCachingResolver(sutConfig, redisClient)
m = &mockResolver{}
@ -538,18 +573,18 @@ var _ = Describe("CachingResolver", func() {
})
It("put in redis", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(HaveResponseType(ResponseTypeRESOLVED))
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}, "50ms").Should(HaveLen(1))
}).Should(HaveLen(1))
})
It("load", func() {
request := newRequest("example2.com.", dns.Type(dns.TypeA))
request := newRequest("example2.com.", A)
domain := util.ExtractDomain(request.Req.Question[0])
cacheKey := util.GenerateCacheKey(dns.Type(dns.TypeA), domain)
cacheKey := util.GenerateCacheKey(A, domain)
redisMockMsg := &redis.CacheMessage{
Key: cacheKey,
Response: &Response{
@ -559,13 +594,13 @@ var _ = Describe("CachingResolver", func() {
},
}
redisClient.CacheChannel <- redisMockMsg
time.Sleep(time.Second)
Eventually(func() error {
resp, err = sut.Resolve(request)
return err
}, "50ms").Should(Succeed())
Eventually(sut.Resolve).WithArguments(request).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveTTL(BeNumerically("<=", 10)),
))
})
})
})

View File

@ -6,6 +6,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
@ -41,21 +42,25 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use clientID if set", func() {
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("client123"))
Expect(request.ClientNames).Should(ConsistOf("client123"))
})
It("should use IP as fallback if clientID not set", func() {
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("1.2.3.4"))
Expect(request.ClientNames).Should(ConsistOf("1.2.3.4"))
})
})
Describe("Resolve client name with custom name mapping", Label("XXX"), func() {
@ -78,31 +83,37 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve defined name with ipv4 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("client7"))
Expect(request.ClientNames).Should(ConsistOf("client7"))
})
It("should resolve defined name with ipv6 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("client7"))
Expect(request.ClientNames).Should(ConsistOf("client7"))
})
It("should resolve multiple names defined names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(2))
Expect(request.ClientNames).Should(ContainElements("client7", "client8"))
Expect(request.ClientNames).Should(ConsistOf("client7", "client8"))
})
})
@ -128,21 +139,26 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() {
By("first request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames[0]).Should(Equal("host1"))
Expect(testUpstream.GetCallCount()).Should(Equal(1))
Expect(request.ClientNames).Should(ConsistOf("host1"))
})
By("second request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames[0]).Should(Equal("host1"))
Expect(request.ClientNames).Should(ConsistOf("host1"))
// use cache -> call count 1
Expect(testUpstream.GetCallCount()).Should(Equal(1))
})
@ -153,13 +169,15 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
By("third request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// no cache -> call count 2
Expect(testUpstream.GetCallCount()).Should(Equal(2))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames[0]).Should(Equal("host1"))
Expect(request.ClientNames).Should(ConsistOf("host1"))
})
})
})
@ -176,13 +194,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve all client names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(2))
Expect(request.ClientNames[0]).Should(Equal("myhost1"))
Expect(request.ClientNames[1]).Should(Equal("myhost2"))
Expect(request.ClientNames).Should(ConsistOf("myhost1", "myhost2"))
Expect(testUpstream.GetCallCount()).Should(Equal(1))
})
})
@ -203,11 +222,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames[0]).Should(Equal("host1"))
Expect(request.ClientNames).Should(ConsistOf("host1"))
Expect(testUpstream.GetCallCount()).Should(Equal(1))
})
})
@ -221,12 +243,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve the client name depending to defined order", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("myhost2"))
Expect(request.ClientNames).Should(ConsistOf("myhost2"))
Expect(testUpstream.GetCallCount()).Should(Equal(1))
})
})
@ -245,12 +269,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("192.168.178.25"))
Expect(request.ClientNames).Should(ConsistOf("192.168.178.25"))
Expect(testUpstream.GetCallCount()).Should(Equal(1))
})
})
@ -263,12 +289,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("192.168.178.25"))
Expect(request.ClientNames).Should(ConsistOf("192.168.178.25"))
})
})
@ -278,10 +306,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should resolve no names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(request.ClientNames).Should(BeEmpty())
})
})
@ -292,12 +322,14 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
resp, err := sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("192.168.178.25"))
Expect(request.ClientNames).Should(ConsistOf("192.168.178.25"))
})
})
})

View File

@ -3,46 +3,38 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
)
var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() {
var (
sut ChainedResolver
m *mockResolver
err error
resp *Response
sut ChainedResolver
m *mockResolver
)
AfterEach(func() {
Expect(err).Should(Succeed())
})
BeforeEach(func() {
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, dns.Type(dns.TypeA), "123.124.122.122")
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
return response
})
DeferCleanup(fbTestUpstream.Close)
otherTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 250, dns.Type(dns.TypeA), "192.192.192.192")
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 250, A, "192.192.192.192")
return response
})
DeferCleanup(otherTestUpstream.Close)
dotTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 223, dns.Type(dns.TypeA), "168.168.168.168")
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 223, A, "168.168.168.168")
return response
})
@ -66,51 +58,76 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("Query is exact equal defined condition in mapping", func() {
Context("first mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() {
resp, err = sut.Resolve(newRequest("fritz.box.", dns.Type(dns.TypeA), logrus.NewEntry(log.Log())))
Expect(sut.Resolve(newRequest("fritz.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("fritz.box.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeCONDITIONAL),
HaveReason("CONDITIONAL"),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Answer).Should(BeDNSRecord("fritz.box.", dns.TypeA, 123, "123.124.122.122"))
// no call to next resolver
Expect(m.Calls).Should(BeEmpty())
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
})
Context("last mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() {
resp, err = sut.Resolve(newRequest("other.box.", dns.Type(dns.TypeA)))
Expect(resp.Res.Answer).Should(BeDNSRecord("other.box.", dns.TypeA, 250, "192.192.192.192"))
Expect(sut.Resolve(newRequest("other.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("other.box.", A, "192.192.192.192"),
HaveTTL(BeNumerically("==", 250)),
HaveResponseType(ResponseTypeCONDITIONAL),
HaveReason("CONDITIONAL"),
HaveReturnCode(dns.RcodeSuccess),
))
// no call to next resolver
Expect(m.Calls).Should(BeEmpty())
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
})
})
When("Query is a subdomain of defined condition in mapping", func() {
It("Should resolve the IP of subdomain", func() {
resp, err = sut.Resolve(newRequest("test.fritz.box.", dns.Type(dns.TypeA)))
Expect(resp.Res.Answer).Should(BeDNSRecord("test.fritz.box.", dns.TypeA, 123, "123.124.122.122"))
Expect(sut.Resolve(newRequest("test.fritz.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeCONDITIONAL),
HaveReason("CONDITIONAL"),
HaveReturnCode(dns.RcodeSuccess),
))
// no call to next resolver
Expect(m.Calls).Should(BeEmpty())
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
})
When("Query is not fqdn and . condition is defined in mapping", func() {
It("Should resolve the IP of .", func() {
resp, err = sut.Resolve(newRequest("test.", dns.Type(dns.TypeA)))
Expect(resp.Res.Answer).Should(BeDNSRecord("test.", dns.TypeA, 223, "168.168.168.168"))
Expect(sut.Resolve(newRequest("test.", A))).
Should(
SatisfyAll(
BeDNSRecord("test.", A, "168.168.168.168"),
HaveTTL(BeNumerically("==", 223)),
HaveResponseType(ResponseTypeCONDITIONAL),
HaveReason("CONDITIONAL"),
HaveReturnCode(dns.RcodeSuccess),
))
// no call to next resolver
Expect(m.Calls).Should(BeEmpty())
Expect(resp.RType).Should(Equal(ResponseTypeCONDITIONAL))
})
})
})
Describe("Delegation to next resolver", func() {
When("Query doesn't match defined mapping", func() {
It("should delegate to next resolver", func() {
resp, err = sut.Resolve(newRequest("google.com.", dns.Type(dns.TypeA)))
Expect(sut.Resolve(newRequest("google.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
})
})

View File

@ -15,11 +15,9 @@ import (
var _ = Describe("CustomDNSResolver", func() {
var (
sut ChainedResolver
m *mockResolver
err error
resp *Response
cfg config.CustomDNSConfig
sut ChainedResolver
m *mockResolver
cfg config.CustomDNSConfig
)
TTL := uint32(time.Now().Second())
@ -52,26 +50,39 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is true", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
It("defined ip4 query should be resolved", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeA)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("custom.domain.", dns.TypeA, TTL, "192.168.143.123"))
Expect(sut.Resolve(newRequest("custom.domain.", A))).
Should(
SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
HaveTTL(BeNumerically("==", TTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("TXT query for defined mapping should return NOERROR and empty result", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeTXT)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(0))
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("ip6 query should return NOERROR and empty result", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeAAAA)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(0))
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
@ -80,21 +91,38 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is false", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = false })
It("defined ip4 query should be resolved", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeA)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("custom.domain.", dns.TypeA, TTL, "192.168.143.123"))
Expect(sut.Resolve(newRequest("custom.domain.", A))).
Should(
SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
HaveTTL(BeNumerically("==", TTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("TXT query for defined mapping should be delegated to next resolver", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeTXT)))
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegate was executed
m.AssertExpectations(GinkgoT())
})
It("ip6 query should return NOERROR and empty result", func() {
resp, err = sut.Resolve(newRequest("custom.domain.", dns.Type(dns.TypeAAAA)))
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegate was executed
m.AssertExpectations(GinkgoT())
@ -103,10 +131,15 @@ var _ = Describe("CustomDNSResolver", func() {
})
When("Ip 6 mapping is defined for custom domain ", func() {
It("ip6 query should be resolved", func() {
resp, err = sut.Resolve(newRequest("ip6.domain.", dns.Type(dns.TypeAAAA)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("ip6.domain.", dns.TypeAAAA, TTL, "2001:db8:85a3::8a2e:370:7334"))
Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))).
Should(
SatisfyAll(
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
HaveTTL(BeNumerically("==", TTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
@ -114,22 +147,35 @@ var _ = Describe("CustomDNSResolver", func() {
When("Multiple IPs are defined for custom domain ", func() {
It("all IPs for the current type should be returned", func() {
By("IPv6 query", func() {
resp, err = sut.Resolve(newRequest("multiple.ips.", dns.Type(dns.TypeAAAA)))
Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))).
Should(
SatisfyAll(
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
HaveTTL(BeNumerically("==", TTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("multiple.ips.", dns.TypeAAAA, TTL, "2001:db8:85a3::8a2e:370:7334"))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
By("IPv4 query", func() {
resp, err = sut.Resolve(newRequest("multiple.ips.", dns.Type(dns.TypeA)))
Expect(sut.Resolve(newRequest("multiple.ips.", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("multiple.ips.", A, "192.168.143.123"),
BeDNSRecord("multiple.ips.", A, "192.168.143.125")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(2))
Expect(resp.Res.Answer).Should(ContainElements(
BeDNSRecord("multiple.ips.", dns.TypeA, TTL, "192.168.143.123"),
BeDNSRecord("multiple.ips.", dns.TypeA, TTL, "192.168.143.125")))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
@ -138,27 +184,42 @@ var _ = Describe("CustomDNSResolver", func() {
When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() {
By("ipv4", func() {
resp, err = sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", dns.Type(dns.TypePTR)))
Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("123.143.168.192.in-addr.arpa.", PTR, "custom.domain."),
BeDNSRecord("123.143.168.192.in-addr.arpa.", PTR, "multiple.ips.")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(2))
Expect(resp.Res.Answer).Should(ContainElements(
BeDNSRecord("123.143.168.192.in-addr.arpa.", dns.TypePTR, TTL, "custom.domain."),
BeDNSRecord("123.143.168.192.in-addr.arpa.", dns.TypePTR, TTL, "multiple.ips.")))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
By("ipv6", func() {
resp, err = sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
dns.Type(dns.TypePTR)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(2))
Expect(resp.Res.Answer).Should(ContainElements(
BeDNSRecord("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
dns.TypePTR, TTL, "ip6.domain."),
BeDNSRecord("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
dns.TypePTR, TTL, "multiple.ips.")))
Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
PTR))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
ContainElements(
BeDNSRecord("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
PTR, "ip6.domain."),
BeDNSRecord("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
PTR, "multiple.ips.")),
)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
@ -166,25 +227,31 @@ var _ = Describe("CustomDNSResolver", func() {
})
When("Domain mapping is defined", func() {
It("subdomain must also match", func() {
resp, err = sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", dns.Type(dns.TypeA)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("ABC.CUSTOM.DOMAIN.", dns.TypeA, TTL, "192.168.143.123"))
Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))).
Should(
SatisfyAll(
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
HaveTTL(BeNumerically("==", TTL)),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReason("CUSTOM DNS"),
HaveReturnCode(dns.RcodeSuccess),
))
// will not delegate to next resolver
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
})
AfterEach(func() {
Expect(err).Should(Succeed())
})
})
Describe("Delegating to next resolver", func() {
When("no mapping for domain exist", func() {
It("should delegate to next resolver", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(err).Should(Succeed())
// delegate was executed
m.AssertExpectations(GinkgoT())
})

View File

@ -4,7 +4,9 @@ import (
"errors"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
@ -27,9 +29,9 @@ var _ = Describe("EdeResolver", func() {
JustBeforeEach(func() {
if m == nil {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&model.Response{
m.On("Resolve", mock.Anything).Return(&Response{
Res: mockAnswer,
RType: model.ResponseTypeCUSTOMDNS,
RType: ResponseTypeCUSTOMDNS,
Reason: "Test",
}, nil)
}
@ -45,12 +47,14 @@ var _ = Describe("EdeResolver", func() {
}
})
It("shouldn't add EDE information", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(model.ResponseTypeCUSTOMDNS))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(resp.Res.Extra).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToExtra, BeEmpty()),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
@ -63,20 +67,29 @@ var _ = Describe("EdeResolver", func() {
Enable: true,
}
})
extractFirstOptRecord := func(e []dns.RR) []dns.EDNS0 {
return e[0].(*dns.OPT).Option
}
It("should add EDE information", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(model.ResponseTypeCUSTOMDNS))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(resp.Res.Extra).Should(HaveLen(1))
opt, ok := resp.Res.Extra[0].(*dns.OPT)
Expect(ok).Should(BeTrue())
Expect(opt).ShouldNot(BeNil())
ede, ok := opt.Option[0].(*dns.EDNS0_EDE)
Expect(ok).Should(BeTrue())
Expect(ede.InfoCode).Should(Equal(dns.ExtendedErrorCodeForgedAnswer))
Expect(ede.ExtraText).Should(Equal("Test"))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReturnCode(dns.RcodeSuccess),
// extra should contain one OPT record
WithTransform(ToExtra,
SatisfyAll(
HaveLen(1),
WithTransform(extractFirstOptRecord,
SatisfyAll(
ContainElement(HaveField("InfoCode", Equal(dns.ExtendedErrorCodeForgedAnswer))),
ContainElement(HaveField("ExtraText", Equal("Test"))),
)),
)),
))
})
When("resolver returns an error", func() {
@ -88,7 +101,7 @@ var _ = Describe("EdeResolver", func() {
})
It("should return it", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
resp, err := sut.Resolve(newRequest("example.com", A))
Expect(resp).To(BeNil())
Expect(err).To(Equal(resolveErr))
})

View File

@ -2,8 +2,8 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -32,24 +32,29 @@ var _ = Describe("FilteringResolver", func() {
When("Filtering query types are defined", func() {
BeforeEach(func() {
sutConfig = config.FilteringConfig{
QueryTypes: config.NewQTypeSet(dns.Type(dns.TypeAAAA), dns.Type(dns.TypeMX)),
QueryTypes: config.NewQTypeSet(AAAA, MX),
}
})
It("Should delegate to next resolver if request query has other type", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Should return empty answer for defined query type", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeFILTERED))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeFILTERED),
HaveReturnCode(dns.RcodeSuccess),
))
// no call of next resolver
Expect(m.Calls).Should(BeZero())
@ -65,11 +70,16 @@ var _ = Describe("FilteringResolver", func() {
sutConfig = config.FilteringConfig{}
})
It("Should return empty answer without error", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(HaveLen(0))
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Configure should output 'empty list'", func() {
c := sut.Configuration()

View File

@ -14,7 +14,7 @@ type FqdnOnlyResolver struct {
enabled bool
}
func NewFqdnOnlyResolver(cfg config.Config) ChainedResolver {
func NewFqdnOnlyResolver(cfg config.Config) *FqdnOnlyResolver {
return &FqdnOnlyResolver{
enabled: cfg.FqdnOnly,
}

View File

@ -2,8 +2,8 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -23,7 +23,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
})
JustBeforeEach(func() {
sut = NewFqdnOnlyResolver(sutConfig).(*FqdnOnlyResolver)
sut = NewFqdnOnlyResolver(sutConfig)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)
@ -36,21 +36,25 @@ var _ = Describe("FqdnOnlyResolver", func() {
}
})
It("Should delegate to next resolver if request query is fqdn", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example.com", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Should return NXDOMAIN if request query is not fqdn", func() {
resp, err := sut.Resolve(newRequest("example", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(resp.RType).Should(Equal(ResponseTypeNOTFQDN))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeNOTFQDN),
HaveReturnCode(dns.RcodeNameError),
))
// no call of next resolver
Expect(m.Calls).Should(BeZero())
@ -68,21 +72,25 @@ var _ = Describe("FqdnOnlyResolver", func() {
}
})
It("Should delegate to next resolver if request query is fqdn", func() {
resp, err := sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example.com", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Should delegate to next resolver if request query is not fqdn", func() {
resp, err := sut.Resolve(newRequest("example", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeEmpty())
Expect(sut.Resolve(newRequest("example", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))

View File

@ -134,7 +134,7 @@ func (r *HostsFileResolver) Configuration() (result []string) {
return
}
func NewHostsFileResolver(cfg config.HostsFileConfig) ChainedResolver {
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
r := HostsFileResolver{
HostsFilePath: cfg.Filepath,
ttl: uint32(time.Duration(cfg.HostsTTL).Seconds()),

View File

@ -18,8 +18,6 @@ var _ = Describe("HostsFileResolver", func() {
var (
sut *HostsFileResolver
m *mockResolver
err error
resp *Response
tmpDir *TmpFolder
tmpFile *TmpFile
)
@ -40,7 +38,7 @@ var _ = Describe("HostsFileResolver", func() {
RefreshPeriod: config.Duration(30 * time.Minute),
FilterLoopback: true,
}
sut = NewHostsFileResolver(cfg).(*HostsFileResolver)
sut = NewHostsFileResolver(cfg)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -52,7 +50,7 @@ var _ = Describe("HostsFileResolver", func() {
sut = NewHostsFileResolver(config.HostsFileConfig{
Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()),
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
}).(*HostsFileResolver)
})
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -62,26 +60,34 @@ var _ = Describe("HostsFileResolver", func() {
Expect(sut.hosts).Should(HaveLen(0))
})
It("should go to next resolver on query", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
})
})
When("Hosts file is not set", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{}).(*HostsFileResolver)
sut = NewHostsFileResolver(config.HostsFileConfig{})
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
It("should not return an error", func() {
err = sut.parseHostsFile()
err := sut.parseHostsFile()
Expect(err).Should(Succeed())
})
It("should go to next resolver on query", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
})
})
@ -95,76 +101,96 @@ var _ = Describe("HostsFileResolver", func() {
When("IPv4 mapping is defined for a host", func() {
It("defined ipv4 query should be resolved", func() {
resp, err = sut.Resolve(newRequest("ipv4host.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(BeDNSRecord("ipv4host.", dns.TypeA, TTL, "192.168.2.1"))
Expect(sut.Resolve(newRequest("ipv4host.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("ipv4host.", A, "192.168.2.1"),
HaveTTL(BeNumerically("==", TTL)),
))
})
It("defined ipv4 query for alias should be resolved", func() {
resp, err = sut.Resolve(newRequest("router2.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(BeDNSRecord("router2.", dns.TypeA, TTL, "10.0.0.1"))
Expect(sut.Resolve(newRequest("router2.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("router2.", A, "10.0.0.1"),
HaveTTL(BeNumerically("==", TTL)),
))
})
It("ipv4 query should return NOERROR and empty result", func() {
resp, err = sut.Resolve(newRequest("does.not.existdns.Type(.", dns.Type(dns.TypeA)))
Expect(err).Should(BeNil())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(0))
Expect(sut.Resolve(newRequest("does.not.exist.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveReturnCode(dns.RcodeSuccess),
HaveResponseType(ResponseTypeRESOLVED),
))
})
})
When("IPv6 mapping is defined for a host", func() {
It("defined ipv6 query should be resolved", func() {
resp, err = sut.Resolve(newRequest("ipv6host.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(BeDNSRecord("ipv6host.", dns.TypeAAAA, TTL, "faaf:faaf:faaf:faaf::1"))
Expect(sut.Resolve(newRequest("ipv6host.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("ipv6host.", AAAA, "faaf:faaf:faaf:faaf::1"),
HaveTTL(BeNumerically("==", TTL)),
))
})
It("ipv6 query should return NOERROR and empty result", func() {
resp, err = sut.Resolve(newRequest("does.not.existdns.Type(.", dns.Type(dns.TypeAAAA)))
Expect(err).Should(BeNil())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(HaveLen(0))
Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveReturnCode(dns.RcodeSuccess),
HaveResponseType(ResponseTypeRESOLVED),
))
})
})
When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() {
By("ipv4 with one hostname", func() {
resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", dns.Type(dns.TypePTR)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(HaveLen(1))
Expect(resp.Res.Answer).Should(BeDNSRecord("2.0.0.10.in-addr.arpa.", dns.TypePTR, TTL, "router3."))
Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("2.0.0.10.in-addr.arpa.", PTR, "router3."),
HaveTTL(BeNumerically("==", TTL)),
))
})
By("ipv4 with aliases", func() {
resp, err = sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", dns.Type(dns.TypePTR)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(HaveLen(3))
Expect(resp.Res.Answer[0]).Should(BeDNSRecord("1.0.0.10.in-addr.arpa.", dns.TypePTR, TTL, "router0."))
Expect(resp.Res.Answer[1]).Should(BeDNSRecord("1.0.0.10.in-addr.arpa.", dns.TypePTR, TTL, "router1."))
Expect(resp.Res.Answer[2]).Should(BeDNSRecord("1.0.0.10.in-addr.arpa.", dns.TypePTR, TTL, "router2."))
Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToAnswer, ContainElements(
BeDNSRecord("1.0.0.10.in-addr.arpa.", PTR, "router0."),
BeDNSRecord("1.0.0.10.in-addr.arpa.", PTR, "router1."),
BeDNSRecord("1.0.0.10.in-addr.arpa.", PTR, "router2."),
)),
))
})
By("ipv6", func() {
resp, err = sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.",
dns.Type(dns.TypePTR)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeHOSTSFILE))
Expect(resp.Res.Answer).Should(HaveLen(2))
Expect(resp.Res.Answer[0]).Should(
BeDNSRecord("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.",
dns.TypePTR, TTL, "ipv6host."))
Expect(resp.Res.Answer[1]).Should(
BeDNSRecord("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.",
dns.TypePTR, TTL, "ipv6host.local.lan."))
Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToAnswer, ContainElements(
BeDNSRecord("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.",
PTR, "ipv6host."),
BeDNSRecord("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.",
PTR, "ipv6host.local.lan."),
)),
))
})
})
})
@ -180,7 +206,7 @@ var _ = Describe("HostsFileResolver", func() {
When("hosts file is not provided", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{}).(*HostsFileResolver)
sut = NewHostsFileResolver(config.HostsFileConfig{})
})
It("should return 'disabled'", func() {
c := sut.Configuration()
@ -192,7 +218,7 @@ var _ = Describe("HostsFileResolver", func() {
Describe("Delegating to next resolver", func() {
When("no hosts file is provided", func() {
It("should delegate to next resolver", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(Succeed())
// delegate was executed
m.AssertExpectations(GinkgoT())

View File

@ -5,6 +5,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
@ -17,10 +18,8 @@ import (
var _ = Describe("MetricResolver", func() {
var (
sut *MetricsResolver
m *mockResolver
err error
resp *Response
sut *MetricsResolver
m *mockResolver
)
BeforeEach(func() {
@ -34,14 +33,17 @@ var _ = Describe("MetricResolver", func() {
Context("Recording request metrics", func() {
When("Request will be performed", func() {
It("Should record metrics", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "", "client"))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
cnt, err := sut.totalQueries.GetMetricWith(prometheus.Labels{"client": "client", "type": "A"})
Expect(err).Should(Succeed())
Expect(testutil.ToFloat64(cnt)).Should(Equal(float64(1)))
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(testutil.ToFloat64(cnt)).Should(BeNumerically("==", 1))
m.AssertExpectations(GinkgoT())
})
})
@ -52,10 +54,11 @@ var _ = Describe("MetricResolver", func() {
sut.Next(m)
})
It("Error should be recorded", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "", "client"))
_, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))
Expect(err).Should(HaveOccurred())
Expect(testutil.ToFloat64(sut.totalErrors)).Should(Equal(float64(1)))
Expect(testutil.ToFloat64(sut.totalErrors)).Should(BeNumerically("==", 1))
})
})
})

View File

@ -1,7 +1,7 @@
package resolver
import (
"github.com/miekg/dns"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -15,7 +15,7 @@ var _ = Describe("NoOpResolver", func() {
Describe("Resolving", func() {
It("returns no response", func() {
resp, err := sut.Resolve(newRequest("test.tld", dns.Type(dns.TypeA)))
resp, err := sut.Resolve(newRequest("test.tld", A))
Expect(err).Should(Succeed())
Expect(resp).Should(Equal(NoResponse))
})

View File

@ -34,7 +34,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, dns.Type(dns.TypeA), "123.124.122.122")
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
return
})
@ -88,9 +88,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Describe("Resolving result from fastest upstream resolver", func() {
var (
sut Resolver
err error
resp *Response
sut Resolver
err error
)
When("2 Upstream resolvers are defined", func() {
When("one resolver is fast and another is slow", func() {
@ -99,7 +98,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
DeferCleanup(fastTestUpstream.Close)
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
@ -114,21 +113,22 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Expect(err).Should(Succeed())
})
It("Should use result from fastest one", func() {
request := newRequest("example.com.", dns.Type(dns.TypeA))
resp, err = sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("one resolver is slow, but another returns an error", func() {
BeforeEach(func() {
withErrorUpstream := config.Upstream{Host: "wrong"}
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
@ -142,14 +142,15 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
request := newRequest("example.com.", dns.Type(dns.TypeA))
resp, err = sut.Resolve(request)
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.123"))
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("all resolvers return errors", func() {
@ -163,8 +164,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Expect(err).Should(Succeed())
})
It("Should return error", func() {
request := newRequest("example.com.", dns.Type(dns.TypeA))
resp, err = sut.Resolve(request)
request := newRequest("example.com.", A)
_, err = sut.Resolve(request)
Expect(err).Should(HaveOccurred())
})
@ -203,67 +204,88 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}, nil, noVerifyUpstreams)
})
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "test")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches exact", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "laptop")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.123"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with wildcard", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "client-test-m")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.124"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.124"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with range wildcard", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "client7")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.124"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.124"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client IP matches", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.33", "cl")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.33", "cl")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.125"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.125"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client IP/name matches", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.33", "192.168.178.33")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "192.168.178.33", "192.168.178.33")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.125"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.125"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "10.43.8.64", "cl")
resp, err = sut.Resolve(request)
request := newRequestWithClient("example.com.", A, "10.43.8.64", "cl")
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.126"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.126"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
})
@ -279,12 +301,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}, nil, noVerifyUpstreams)
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", dns.Type(dns.TypeA))
resp, err = sut.Resolve(request)
request := newRequest("example.com.", A)
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
})
@ -311,7 +337,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
for i := 0; i < 100; i++ {
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
"example.com", dns.Type(dns.TypeA), "123.123.100.100",
"example.com", A, "123.123.100.100",
)))
res1 := r1.resolver
res2 := r2.resolver
@ -328,7 +354,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
By("perform 10 request, error upstream's weight will be reduced", func() {
// perform 10 requests
for i := 0; i < 100; i++ {
request := newRequest("example.com.", dns.Type(dns.TypeA))
request := newRequest("example.com.", A)
_, _ = sut.Resolve(request)
}
})
@ -338,7 +364,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
for i := 0; i < 100; i++ {
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
"example.com", dns.Type(dns.TypeA), "123.123.100.100",
"example.com", A, "123.123.100.100",
)))
res1 := r1.resolver.(*UpstreamResolver)
res2 := r2.resolver.(*UpstreamResolver)

View File

@ -9,7 +9,7 @@ import (
"os"
"time"
"github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/querylog"
"github.com/0xERR0R/blocky/config"
@ -39,16 +39,14 @@ var _ = Describe("QueryLoggingResolver", func() {
var (
sut *QueryLoggingResolver
sutConfig config.QueryLogConfig
err error
resp *Response
m *mockResolver
tmpDir *helpertest.TmpFolder
tmpDir *TmpFolder
mockAnswer *dns.Msg
)
BeforeEach(func() {
mockAnswer = new(dns.Msg)
tmpDir = helpertest.NewTmpFolder("queryLoggingResolver")
tmpDir = NewTmpFolder("queryLoggingResolver")
Expect(tmpDir.Error).Should(Succeed())
DeferCleanup(tmpDir.Clean)
})
@ -70,10 +68,14 @@ var _ = Describe("QueryLoggingResolver", func() {
}
})
It("should process request without query logging", func() {
resp, err = sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(sut.Resolve(newRequest("example.com", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
})
})
When("Configuration with logging per client", func() {
@ -84,18 +86,25 @@ var _ = Describe("QueryLoggingResolver", func() {
CreationAttempts: 1,
CreationCooldown: config.Duration(time.Millisecond),
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.122.121.120")
})
It("should create a log file per client", func() {
By("request from client 1", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.25", "client1"))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
By("request from client 2, has name with special chars, should be escaped", func() {
resp, err = sut.Resolve(newRequestWithClient(
"example.com.", dns.Type(dns.TypeA), "192.168.178.26", "cl/ient2\\$%&test",
))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient(
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
m.AssertExpectations(GinkgoT())
@ -145,16 +154,24 @@ var _ = Describe("QueryLoggingResolver", func() {
CreationAttempts: 1,
CreationCooldown: config.Duration(time.Millisecond),
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.122.121.120")
})
It("should create one log file for all clients", func() {
By("request from client 1", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.25", "client1"))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
By("request from client 2, has name with special chars, should be escaped", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.26", "client2"))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
m.AssertExpectations(GinkgoT())
@ -198,12 +215,16 @@ var _ = Describe("QueryLoggingResolver", func() {
CreationCooldown: config.Duration(time.Millisecond),
Fields: []config.QueryLogField{config.QueryLogFieldClientIP},
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, dns.Type(dns.TypeA), "123.122.121.120")
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.122.121.120")
})
It("should create one log file", func() {
By("request from client 1", func() {
resp, err = sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.25", "client1"))
Expect(err).Should(Succeed())
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
m.AssertExpectations(GinkgoT())
@ -246,7 +267,7 @@ var _ = Describe("QueryLoggingResolver", func() {
sut.writer = mockWriter
Eventually(func() int {
_, ierr := sut.Resolve(newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.25", "client1"))
_, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
Expect(ierr).Should(Succeed())
return len(sut.logChan)

View File

@ -1,8 +1,7 @@
package resolver
import (
"net"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -13,16 +12,12 @@ import (
var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
var (
sut *SpecialUseDomainNamesResolver
m *mockResolver
mockAnswer *dns.Msg
err error
resp *Response
sut *SpecialUseDomainNamesResolver
m *mockResolver
)
BeforeEach(func() {
mockAnswer, err = util.NewMsgWithAnswer("example.com.", 300, dns.Type(dns.TypeA), "123.145.123.145")
mockAnswer, err := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
Expect(err).Should(Succeed())
m = &mockResolver{}
@ -35,65 +30,107 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
Describe("Blocking special names", func() {
It("should block arpa", func() {
for _, arpa := range sudnArpaSlice() {
resp, err = sut.Resolve(newRequest(arpa, dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(arpa, A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
}
})
It("should block test", func() {
resp, err = sut.Resolve(newRequest(sudnTest, dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(sudnTest, A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
})
It("should block invalid", func() {
resp, err = sut.Resolve(newRequest(sudnInvalid, dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(sudnInvalid, A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
})
It("should block localhost none A", func() {
resp, err = sut.Resolve(newRequest(sudnLocalhost, dns.Type(dns.TypeHTTPS)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(sudnLocalhost, HTTPS))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
})
It("should block local", func() {
resp, err = sut.Resolve(newRequest(mdnsLocal, dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(mdnsLocal, A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
})
It("should block localhost none A", func() {
resp, err = sut.Resolve(newRequest(mdnsLocal, dns.Type(dns.TypeHTTPS)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(sut.Resolve(newRequest(mdnsLocal, HTTPS))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeNameError),
HaveReason("Special-Use Domain Name"),
))
})
})
Describe("Resolve localhost", func() {
It("should resolve IPv4 loopback", func() {
resp, err = sut.Resolve(newRequest(sudnLocalhost, dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer[0].(*dns.A).A).Should(Equal(sut.defaults.loopbackV4))
Expect(sut.Resolve(newRequest(sudnLocalhost, A))).
Should(
SatisfyAll(
BeDNSRecord(sudnLocalhost, A, sut.defaults.loopbackV4.String()),
HaveTTL(BeNumerically("==", 0)),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("should resolve IPv6 loopback", func() {
resp, err = sut.Resolve(newRequest(sudnLocalhost, dns.Type(dns.TypeAAAA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer[0].(*dns.AAAA).AAAA).Should(Equal(sut.defaults.loopbackV6))
Expect(sut.Resolve(newRequest(sudnLocalhost, AAAA))).
Should(
SatisfyAll(
BeDNSRecord(sudnLocalhost, AAAA, sut.defaults.loopbackV6.String()),
HaveTTL(BeNumerically("==", 0)),
HaveResponseType(ResponseTypeSPECIAL),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
Describe("Forward other", func() {
It("should forward example.com", func() {
resp, err = sut.Resolve(newRequest("example.com", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer[0].(*dns.A).A).Should(Equal(net.ParseIP("123.145.123.145")))
Expect(sut.Resolve(newRequest("example.com", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.145.123.145"),
HaveTTL(BeNumerically("==", 300)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})

View File

@ -28,12 +28,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
Expect(resp.Reason).Should(Equal(fmt.Sprintf("RESOLVED (%s)", upstream.String())))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
)
})
})
When("Configured DNS resolver can't resolve query", func() {
@ -44,11 +47,14 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeNameError))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Reason).Should(Equal(fmt.Sprintf("RESOLVED (%s)", upstream.String())))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
)
})
})
When("Configured DNS resolver fails", func() {
@ -60,7 +66,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
})
})
@ -75,7 +81,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
if atomic.LoadInt32(&counter) <= atomic.LoadInt32(&attemptsWithTimeout) {
time.Sleep(110 * time.Millisecond)
}
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.122")
Expect(err).Should(Succeed())
return response
@ -93,17 +99,20 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 2)
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
))
})
By("3 attempts with timeout -> should return error", func() {
atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 3)
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
})
@ -121,7 +130,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
BeforeEach(func() {
respFn = func(_ *dns.Msg) *dns.Msg {
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.122")
Expect(err).Should(Succeed())
@ -142,12 +151,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
When("Configured DOH resolver can resolve query", func() {
It("should return answer from DNS upstream", func() {
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(resp.RType).Should(Equal(ResponseTypeRESOLVED))
Expect(resp.Res.Answer).Should(BeDNSRecord("example.com.", dns.TypeA, 123, "123.124.122.122"))
Expect(resp.Reason).Should(Equal(fmt.Sprintf("RESOLVED (https://%s:%d)", upstream.Host, upstream.Port)))
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
HaveReason(fmt.Sprintf("RESOLVED (https://%s:%d)", upstream.Host, upstream.Port)),
))
})
})
When("Configured DOH resolver returns wrong http status code", func() {
@ -157,7 +169,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
})
@ -169,7 +181,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
@ -182,7 +194,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
})
@ -195,9 +207,12 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}, systemResolverBootstrap)
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
_, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(Or(ContainSubstring("no such host"), ContainSubstring("i/o timeout")))
Expect(err.Error()).Should(Or(
ContainSubstring("no such host"),
ContainSubstring("i/o timeout"),
ContainSubstring("Temporary failure in name resolution")))
})
})
})

View File

@ -19,7 +19,6 @@ import (
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/util"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -30,7 +29,6 @@ var (
mockClientName atomic.Value
sut *Server
err error
resp *dns.Msg
)
var _ = BeforeSuite(func() {
@ -40,7 +38,7 @@ var _ = BeforeSuite(func() {
return nil
}
response, err := util.NewMsgWithAnswer(
util.ExtractDomain(request.Question[0]), 123, dns.Type(dns.TypeA), "123.124.122.122",
util.ExtractDomain(request.Question[0]), 123, A, "123.124.122.122",
)
Expect(err).Should(Succeed())
@ -51,7 +49,7 @@ var _ = BeforeSuite(func() {
fritzboxMockUpstream := resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer(
util.ExtractDomain(request.Question[0]), 3600, dns.Type(dns.TypeA), "192.168.178.2",
util.ExtractDomain(request.Question[0]), 3600, A, "192.168.178.2",
)
Expect(err).Should(Succeed())
@ -199,119 +197,164 @@ var _ = Describe("Running DNS server", func() {
}
})
AfterEach(func() {
Expect(resp.Rcode).Should(Equal(dns.RcodeSuccess))
})
Context("DNS query is resolvable via external DNS", func() {
It("should return valid answer", func() {
resp = requestServer(util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("google.de.", dns.TypeA, 123, "123.124.122.122"))
Expect(requestServer(util.NewMsgWithQuestion("google.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
))
})
})
Context("Custom DNS entry with exact match", func() {
It("should return valid answer", func() {
resp = requestServer(util.NewMsgWithQuestion("custom.lan.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("custom.lan.", dns.TypeA, 3600, "192.168.178.55"))
Expect(requestServer(util.NewMsgWithQuestion("custom.lan.", A))).
Should(
SatisfyAll(
BeDNSRecord("custom.lan.", A, "192.168.178.55"),
HaveTTL(BeNumerically("==", 3600)),
))
})
})
Context("Custom DNS entry with sub domain", func() {
It("should return valid answer", func() {
resp = requestServer(util.NewMsgWithQuestion("host.lan.home.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("host.lan.home.", dns.TypeA, 3600, "192.168.178.56"))
Expect(requestServer(util.NewMsgWithQuestion("host.lan.home.", A))).
Should(
SatisfyAll(
BeDNSRecord("host.lan.home.", A, "192.168.178.56"),
HaveTTL(BeNumerically("==", 3600)),
))
})
})
Context("Conditional upstream", func() {
It("should resolve query via conditional upstream resolver", func() {
resp = requestServer(util.NewMsgWithQuestion("host.fritz.box.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("host.fritz.box.", dns.TypeA, 3600, "192.168.178.2"))
Expect(requestServer(util.NewMsgWithQuestion("host.fritz.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("host.fritz.box.", A, "192.168.178.2"),
HaveTTL(BeNumerically("==", 3600)),
))
})
})
Context("Conditional upstream blocking", func() {
It("Query should be blocked, domain is in default group", func() {
resp = requestServer(util.NewMsgWithQuestion("doubleclick.net.cn.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("doubleclick.net.cn.", dns.TypeA, 21600, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("doubleclick.net.cn.", A))).
Should(
SatisfyAll(
BeDNSRecord("doubleclick.net.cn.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("Blocking default group", func() {
It("Query should be blocked, domain is in default group", func() {
resp = requestServer(util.NewMsgWithQuestion("doubleclick.net.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("doubleclick.net.", dns.TypeA, 21600, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("doubleclick.net.", A))).
Should(
SatisfyAll(
BeDNSRecord("doubleclick.net.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("Blocking default group with sub domain", func() {
It("Query with subdomain should be blocked, domain is in default group", func() {
resp = requestServer(util.NewMsgWithQuestion("www.bild.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("www.bild.de.", dns.TypeA, 21600, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("www.bild.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("www.bild.de.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("no blocking default group with sub domain", func() {
It("Query with should not be blocked, sub domain is not in blacklist", func() {
resp = requestServer(util.NewMsgWithQuestion("bild.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("bild.de.", dns.TypeA, 0, "123.124.122.122"))
Expect(requestServer(util.NewMsgWithQuestion("bild.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("bild.de.", A, "123.124.122.122"),
HaveTTL(BeNumerically("<=", 123)),
))
})
})
Context("domain is on white and blacklist default group", func() {
It("Query with should not be blocked, domain is on white and blacklist", func() {
resp = requestServer(util.NewMsgWithQuestion("heise.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("heise.de.", dns.TypeA, 0, "123.124.122.122"))
Expect(requestServer(util.NewMsgWithQuestion("heise.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("heise.de.", A, "123.124.122.122"),
HaveTTL(BeNumerically("<=", 123)),
))
})
})
Context("domain is on client specific white list", func() {
It("Query with should not be blocked, domain is on client's white list", func() {
mockClientName.Store("clWhitelistOnly")
resp = requestServer(util.NewMsgWithQuestion("heise.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("heise.de.", dns.TypeA, 0, "123.124.122.122"))
Expect(requestServer(util.NewMsgWithQuestion("heise.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("heise.de.", A, "123.124.122.122"),
HaveTTL(BeNumerically("<=", 123)),
))
})
})
Context("block client whitelist only", func() {
It("Query with should be blocked, client has only whitelist, domain is not on client's white list", func() {
mockClientName.Store("clWhitelistOnly")
resp = requestServer(util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("google.de.", dns.TypeA, 0, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("google.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("block client with 2 groups", func() {
It("Query with should be blocked, domain is on black list", func() {
mockClientName.Store("clAdsAndYoutube")
resp = requestServer(util.NewMsgWithQuestion("www.bild.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("www.bild.de.", dns.TypeA, 0, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("www.bild.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("www.bild.de.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
resp = requestServer(util.NewMsgWithQuestion("youtube.com.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("youtube.com.", dns.TypeA, 0, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("youtube.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("youtube.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("client with 1 group: no block if domain in other group", func() {
It("Query with should not be blocked, domain is on black list in another group", func() {
mockClientName.Store("clYoutubeOnly")
resp = requestServer(util.NewMsgWithQuestion("www.bild.de.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("www.bild.de.", dns.TypeA, 0, "123.124.122.122"))
Expect(requestServer(util.NewMsgWithQuestion("www.bild.de.", A))).
Should(
SatisfyAll(
BeDNSRecord("www.bild.de.", A, "123.124.122.122"),
HaveTTL(BeNumerically("<=", 123)),
))
})
})
Context("block client with 1 group", func() {
It("Query with should not blocked, domain is on black list in client's group", func() {
mockClientName.Store("clYoutubeOnly")
resp = requestServer(util.NewMsgWithQuestion("youtube.com.", dns.Type(dns.TypeA)))
Expect(resp.Answer).Should(BeDNSRecord("youtube.com.", dns.TypeA, 0, "0.0.0.0"))
Expect(requestServer(util.NewMsgWithQuestion("youtube.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("youtube.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 21600)),
))
})
})
Context("health check", func() {
It("Should always return dummy response", func() {
resp = requestServer(util.NewMsgWithQuestion("healthcheck.blocky.", dns.Type(dns.TypeA)))
resp := requestServer(util.NewMsgWithQuestion("healthcheck.blocky.", A))
Expect(resp.Answer).Should(BeEmpty())
})
@ -332,8 +375,11 @@ var _ = Describe("Running DNS server", func() {
It("should return root page", func() {
resp, err := http.Get("http://localhost:4000/")
Expect(err).Should(Succeed())
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "text/html; charset=UTF-8"))
Expect(resp).Should(
SatisfyAll(
HaveHTTPStatus(http.StatusOK),
HaveHTTPHeaderWithValue("Content-type", "text/html; charset=UTF-8"),
))
})
})
})
@ -353,8 +399,11 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
Expect(resp).Should(
SatisfyAll(
HaveHTTPStatus(http.StatusOK),
HaveHTTPHeaderWithValue("Content-type", "application/json"),
))
var result api.QueryResult
err = json.NewDecoder(resp.Body).Decode(&result)
@ -374,7 +423,7 @@ var _ = Describe("Running DNS server", func() {
resp, err := http.Post("http://localhost:4000/api/query", "application/json", bytes.NewBuffer(jsonValue))
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp.StatusCode).Should(Equal(http.StatusInternalServerError))
})
@ -390,8 +439,8 @@ var _ = Describe("Running DNS server", func() {
resp, err := http.Post("http://localhost:4000/api/query", "application/json", bytes.NewBuffer(jsonValue))
Expect(err).Should(Succeed())
DeferCleanup(resp.Body.Close)
Expect(resp.StatusCode).Should(Equal(http.StatusInternalServerError))
_ = resp.Body.Close()
})
})
When("Request is malformed", func() {
@ -401,7 +450,7 @@ var _ = Describe("Running DNS server", func() {
resp, err := http.Post("http://localhost:4000/api/query", "application/json", bytes.NewBuffer(jsonValue))
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp.StatusCode).Should(Equal(http.StatusInternalServerError))
})
@ -414,7 +463,7 @@ var _ = Describe("Running DNS server", func() {
It("should get a valid response", func() {
resp, err := http.Get("http://localhost:4000/dns-query?dns=AAABAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB")
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
@ -426,14 +475,14 @@ var _ = Describe("Running DNS server", func() {
err = msg.Unpack(rawMsg)
Expect(err).Should(Succeed())
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", dns.TypeA, 0, "123.124.122.122"))
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", A, "123.124.122.122"))
})
})
When("Request does not contain a valid DNS message", func() {
It("should return 'Bad Request'", func() {
resp, err := http.Get("http://localhost:4000/dns-query?dns=xxxx")
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusBadRequest))
})
@ -442,7 +491,7 @@ var _ = Describe("Running DNS server", func() {
It("should return 'Bad Request'", func() {
resp, err := http.Get("http://localhost:4000/dns-query?dns=äöä")
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusBadRequest))
})
@ -451,7 +500,7 @@ var _ = Describe("Running DNS server", func() {
It("should return 'Bad Request'", func() {
resp, err := http.Get("http://localhost:4000/dns-query?test")
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusBadRequest))
})
@ -462,7 +511,7 @@ var _ = Describe("Running DNS server", func() {
resp, err := http.Get("http://localhost:4000/dns-query?dns=" + longBase64msg)
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusRequestURITooLong))
})
@ -471,16 +520,21 @@ var _ = Describe("Running DNS server", func() {
Context("DOH over POST (RFC 8484)", func() {
When("DOH post request with 'example.com' is performed", func() {
It("should get a valid response", func() {
msg := util.NewMsgWithQuestion("www.example.com.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("www.example.com.", A)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())
resp, err := http.Post("http://localhost:4000/dns-query",
"application/dns-message", bytes.NewReader(rawDNSMessage))
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
DeferCleanup(resp.Body.Close)
Expect(resp).Should(
SatisfyAll(
HaveHTTPStatus(http.StatusOK),
HaveHTTPHeaderWithValue("Content-type", "application/dns-message"),
))
rawMsg, err := io.ReadAll(resp.Body)
Expect(err).Should(Succeed())
@ -488,19 +542,23 @@ var _ = Describe("Running DNS server", func() {
err = msg.Unpack(rawMsg)
Expect(err).Should(Succeed())
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", dns.TypeA, 0, "123.124.122.122"))
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", A, "123.124.122.122"))
})
It("should get a valid response, clientId is passed", func() {
msg := util.NewMsgWithQuestion("www.example.com.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("www.example.com.", A)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())
resp, err := http.Post("http://localhost:4000/dns-query/client123",
"application/dns-message", bytes.NewReader(rawDNSMessage))
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
DeferCleanup(resp.Body.Close)
Expect(resp).Should(
SatisfyAll(
HaveHTTPStatus(http.StatusOK),
HaveHTTPHeaderWithValue("Content-type", "application/dns-message"),
))
rawMsg, err := io.ReadAll(resp.Body)
Expect(err).Should(Succeed())
@ -508,7 +566,7 @@ var _ = Describe("Running DNS server", func() {
err = msg.Unpack(rawMsg)
Expect(err).Should(Succeed())
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", dns.TypeA, 0, "123.124.122.122"))
Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", A, "123.124.122.122"))
})
})
When("POST payload exceeds 512 bytes", func() {
@ -517,7 +575,7 @@ var _ = Describe("Running DNS server", func() {
resp, err := http.Post("http://localhost:4000/dns-query", "application/dns-message", bytes.NewReader(largeMessage))
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusRequestEntityTooLarge))
})
@ -526,21 +584,22 @@ var _ = Describe("Running DNS server", func() {
It("should return 'Unsupported Media Type'", func() {
resp, err := http.Post("http://localhost:4000/dns-query", "application/text", bytes.NewReader([]byte("a")))
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusUnsupportedMediaType))
})
})
When("Internal error occurs", func() {
It("should return 'Internal server error'", func() {
msg := util.NewMsgWithQuestion("error.", dns.Type(dns.TypeA))
msg := util.NewMsgWithQuestion("error.", A)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())
resp, err := http.Post("http://localhost:4000/dns-query",
"application/dns-message", bytes.NewReader(rawDNSMessage))
Expect(err).Should(Succeed())
defer resp.Body.Close()
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusInternalServerError))
})
})