mirror of https://github.com/0xERR0R/blocky.git
fix: `startVerifyUpstream` not disabling all start checks
This commit is contained in:
parent
add591c5a4
commit
c06c017a1a
|
@ -45,7 +45,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
|
|||
It("should not start", func() {
|
||||
Expect(blocky.IsRunning()).Should(BeFalse())
|
||||
Expect(getContainerLogs(blocky)).
|
||||
Should(ContainElement(ContainSubstring("unable to reach any DNS resolvers configured for resolver group default")))
|
||||
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -91,8 +91,9 @@ type BlockingResolver struct {
|
|||
}
|
||||
|
||||
// NewBlockingResolver returns a new configured instance of the resolver
|
||||
func NewBlockingResolver(cfg config.BlockingConfig,
|
||||
redis *redis.Client, bootstrap *Bootstrap) (r ChainedResolver, err error) {
|
||||
func NewBlockingResolver(
|
||||
cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap,
|
||||
) (r ChainedResolver, err error) {
|
||||
blockHandler, err := createBlockHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -54,6 +54,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
expectedReturnCode int
|
||||
)
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
|
||||
BeforeEach(func() {
|
||||
expectedReturnCode = dns.RcodeSuccess
|
||||
|
||||
|
@ -68,7 +70,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
JustBeforeEach(func() {
|
||||
m = &MockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
tmp, err := NewBlockingResolver(sutConfig, nil, skipUpstreamCheck)
|
||||
tmp, err := NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
sut = tmp.(*BlockingResolver)
|
||||
sut.Next(m)
|
||||
|
@ -102,7 +104,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
Expect(err).Should(Succeed())
|
||||
|
||||
// recreate to trigger a reload
|
||||
tmp, err := NewBlockingResolver(sutConfig, nil, skipUpstreamCheck)
|
||||
tmp, err := NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
sut = tmp.(*BlockingResolver)
|
||||
|
||||
|
@ -852,7 +854,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
It("should return error", func() {
|
||||
_, err := NewBlockingResolver(config.BlockingConfig{
|
||||
BlockType: "wrong",
|
||||
}, nil, skipUpstreamCheck)
|
||||
}, nil, systemResolverBootstrap)
|
||||
|
||||
Expect(err).Should(
|
||||
MatchError("unknown blockType 'wrong', please use one of: ZeroIP, NxDomain or specify destination IP address(es)"))
|
||||
|
@ -865,7 +867,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
WhiteLists: map[string][]string{"whitelist": {"wrongPath"}},
|
||||
StartStrategy: config.StartStrategyTypeFailOnError,
|
||||
BlockType: "zeroIp",
|
||||
}, nil, skipUpstreamCheck)
|
||||
}, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -893,7 +895,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
|
||||
tmp, err2 := NewBlockingResolver(sutConfig, redisClient, skipUpstreamCheck)
|
||||
tmp, err2 := NewBlockingResolver(sutConfig, redisClient, systemResolverBootstrap)
|
||||
Expect(err2).Should(Succeed())
|
||||
sut = tmp.(*BlockingResolver)
|
||||
})
|
||||
|
|
|
@ -26,8 +26,7 @@ var (
|
|||
|
||||
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
|
||||
type Bootstrap struct {
|
||||
log *logrus.Entry
|
||||
startVerifyUpstream bool
|
||||
log *logrus.Entry
|
||||
|
||||
resolver Resolver
|
||||
upstream Resolver // the upstream that's part of the above resolver
|
||||
|
@ -65,10 +64,9 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
// This also prevents the GC to clean up these two structs, but is not currently an
|
||||
// issue since they stay allocated until the process terminates
|
||||
b = &Bootstrap{
|
||||
log: log,
|
||||
upstreamIPs: ips,
|
||||
systemResolver: net.DefaultResolver, // allow replacing it during tests
|
||||
startVerifyUpstream: cfg.StartVerifyUpstream,
|
||||
log: log,
|
||||
upstreamIPs: ips,
|
||||
systemResolver: net.DefaultResolver, // allow replacing it during tests
|
||||
}
|
||||
|
||||
if upstream.IsDefault() {
|
||||
|
@ -87,12 +85,18 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
}
|
||||
|
||||
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
||||
ips, err := b.resolveUpstream(r, r.upstream.Host)
|
||||
hostname := r.upstream.Host
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
|
||||
return newIPSet([]net.IP{ip}), nil
|
||||
}
|
||||
|
||||
ips, err := b.resolveUpstream(r, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &IPSet{values: ips}, nil
|
||||
return newIPSet(ips), nil
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
||||
|
@ -233,6 +237,10 @@ type IPSet struct {
|
|||
index uint32
|
||||
}
|
||||
|
||||
func newIPSet(ips []net.IP) *IPSet {
|
||||
return &IPSet{values: ips}
|
||||
}
|
||||
|
||||
func (ips *IPSet) Current() net.IP {
|
||||
idx := atomic.LoadUint32(&ips.index)
|
||||
|
||||
|
|
|
@ -25,10 +25,12 @@ type ClientNamesResolver struct {
|
|||
}
|
||||
|
||||
// NewClientNamesResolver creates new resolver instance
|
||||
func NewClientNamesResolver(cfg config.ClientLookupConfig, bootstrap *Bootstrap) (cr *ClientNamesResolver, err error) {
|
||||
func NewClientNamesResolver(
|
||||
cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (cr *ClientNamesResolver, err error) {
|
||||
var r Resolver
|
||||
if !cfg.Upstream.IsDefault() {
|
||||
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap)
|
||||
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
)
|
||||
|
||||
JustBeforeEach(func() {
|
||||
res, err := NewClientNamesResolver(sutConfig, skipUpstreamCheck)
|
||||
res, err := NewClientNamesResolver(sutConfig, nil, false)
|
||||
Expect(err).Should(Succeed())
|
||||
sut = res
|
||||
m = &MockResolver{}
|
||||
|
@ -315,7 +315,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
r, err := NewClientNamesResolver(config.ClientLookupConfig{
|
||||
Upstream: config.Upstream{Host: "example.com"},
|
||||
}, b)
|
||||
}, b, true)
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(r).Should(BeNil())
|
||||
|
|
|
@ -19,15 +19,16 @@ type ConditionalUpstreamResolver struct {
|
|||
}
|
||||
|
||||
// NewConditionalUpstreamResolver returns new resolver instance
|
||||
func NewConditionalUpstreamResolver(cfg config.ConditionalUpstreamConfig,
|
||||
bootstrap *Bootstrap) (ChainedResolver, error) {
|
||||
func NewConditionalUpstreamResolver(
|
||||
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (ChainedResolver, error) {
|
||||
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
||||
|
||||
for domain, upstream := range cfg.Mapping.Upstreams {
|
||||
upstreams := make(map[string][]config.Upstream)
|
||||
upstreams[upstreamDefaultCfgName] = upstream
|
||||
|
||||
r, err := NewParallelBestResolver(upstreams, bootstrap)
|
||||
r, err := NewParallelBestResolver(upstreams, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
"other.box": {otherTestUpstream.Start()},
|
||||
".": {dotTestUpstream.Start()},
|
||||
}},
|
||||
}, skipUpstreamCheck)
|
||||
}, nil, false)
|
||||
m = &MockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
|
@ -125,7 +125,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
".": {config.Upstream{Host: "example.com"}},
|
||||
},
|
||||
},
|
||||
}, b)
|
||||
}, b, true)
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(r).Should(BeNil())
|
||||
|
@ -141,7 +141,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, skipUpstreamCheck)
|
||||
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, nil, false)
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
@ -32,6 +34,29 @@ type upstreamResolverStatus struct {
|
|||
lastErrorTime atomic.Value
|
||||
}
|
||||
|
||||
func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
|
||||
status := &upstreamResolverStatus{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
status.lastErrorTime.Store(time.Unix(0, 0))
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
|
||||
resp, err := r.resolver.Resolve(req)
|
||||
if err != nil && !errors.Is(err, context.Canceled) { // ignore `Canceled`: resolver lost the race, not an error
|
||||
// update the last error time
|
||||
r.lastErrorTime.Store(time.Now())
|
||||
}
|
||||
|
||||
ch <- requestResponse{
|
||||
response: resp,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type requestResponse struct {
|
||||
response *model.Response
|
||||
err error
|
||||
|
@ -50,46 +75,42 @@ func testResolver(r *UpstreamResolver) error {
|
|||
}
|
||||
|
||||
// NewParallelBestResolver creates new resolver instance
|
||||
func NewParallelBestResolver(upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap) (Resolver, error) {
|
||||
logger := logger("parallel resolver")
|
||||
s := make(map[string][]*upstreamResolverStatus)
|
||||
func NewParallelBestResolver(
|
||||
upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (Resolver, error) {
|
||||
logger := logger(parallelResolverLogger)
|
||||
|
||||
for name, res := range upstreamResolvers {
|
||||
var resolvers []*upstreamResolverStatus
|
||||
s := make(map[string][]*upstreamResolverStatus, len(upstreamResolvers))
|
||||
|
||||
var errResolvers int
|
||||
for name, upstreamCfgs := range upstreamResolvers {
|
||||
group := make([]*upstreamResolverStatus, 0, len(upstreamCfgs))
|
||||
hasValidResolver := false
|
||||
|
||||
for _, u := range res {
|
||||
r, err := NewUpstreamResolver(u, bootstrap)
|
||||
for _, u := range upstreamCfgs {
|
||||
r, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
logger.Warnf("upstream group %s: %v", name, err)
|
||||
errResolvers++
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if bootstrap != skipUpstreamCheck {
|
||||
if shouldVerifyUpstreams {
|
||||
err = testResolver(r)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
errResolvers++
|
||||
} else {
|
||||
hasValidResolver = true
|
||||
}
|
||||
}
|
||||
|
||||
resolver := &upstreamResolverStatus{
|
||||
resolver: r,
|
||||
}
|
||||
resolver.lastErrorTime.Store(time.Unix(0, 0))
|
||||
resolvers = append(resolvers, resolver)
|
||||
group = append(group, newUpstreamResolverStatus(r))
|
||||
}
|
||||
|
||||
if bootstrap != skipUpstreamCheck {
|
||||
if bootstrap.startVerifyUpstream && errResolvers == len(res) {
|
||||
return nil, fmt.Errorf("unable to reach any DNS resolvers configured for resolver group %s", name)
|
||||
}
|
||||
if shouldVerifyUpstreams && !hasValidResolver {
|
||||
return nil, fmt.Errorf("no valid upstream for group %s", name)
|
||||
}
|
||||
|
||||
s[name] = resolvers
|
||||
s[name] = group
|
||||
}
|
||||
|
||||
if len(s[upstreamDefaultCfgName]) == 0 {
|
||||
|
@ -181,11 +202,11 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
|
||||
logger.WithField("resolver", r1.resolver).Debug("delegating to resolver")
|
||||
|
||||
go resolve(request, r1, ch)
|
||||
go r1.resolve(request, ch)
|
||||
|
||||
logger.WithField("resolver", r2.resolver).Debug("delegating to resolver")
|
||||
|
||||
go resolve(request, r2, ch)
|
||||
go r2.resolve(request, ch)
|
||||
|
||||
//nolint: gosimple
|
||||
for len(collectedErrors) < resolverCount {
|
||||
|
@ -243,16 +264,3 @@ func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamRes
|
|||
|
||||
return c.Pick().(*upstreamResolverStatus)
|
||||
}
|
||||
|
||||
func resolve(req *model.Request, resolver *upstreamResolverStatus, ch chan<- requestResponse) {
|
||||
resp, err := resolver.resolver.Resolve(req)
|
||||
|
||||
// update the last error time
|
||||
if err != nil {
|
||||
resolver.lastErrorTime.Store(time.Now())
|
||||
}
|
||||
ch <- requestResponse{
|
||||
response: resp,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,21 +15,25 @@ import (
|
|||
|
||||
var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||
|
||||
const (
|
||||
verifyUpstreams = true
|
||||
noVerifyUpstreams = false
|
||||
)
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
|
||||
config.GetConfig().UpstreamTimeout = config.Duration(1000 * time.Millisecond)
|
||||
|
||||
Describe("Default upstream resolvers are not defined", func() {
|
||||
When("default upstream resolvers are not defined", func() {
|
||||
It("should fail on startup", func() {
|
||||
|
||||
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, skipUpstreamCheck)
|
||||
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, nil, noVerifyUpstreams)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Some default upstream resolvers cannot be reached", func() {
|
||||
When("some default upstream resolvers cannot be reached", func() {
|
||||
It("should start normally", func() {
|
||||
skipUpstreamCheck.startVerifyUpstream = true
|
||||
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, dns.Type(dns.TypeA), "123.124.122.122")
|
||||
|
||||
|
@ -46,12 +50,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewParallelBestResolver(upstream, skipUpstreamCheck)
|
||||
_, err := NewParallelBestResolver(upstream, systemResolverBootstrap, verifyUpstreams)
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("All default upstream resolvers cannot be reached", func() {
|
||||
When("no upstream resolvers can be reached", func() {
|
||||
var (
|
||||
upstream map[string][]config.Upstream
|
||||
b *Bootstrap
|
||||
|
@ -73,16 +77,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
|
||||
It("should fail to start if strict checking is enabled", func() {
|
||||
b.startVerifyUpstream = true
|
||||
|
||||
_, err := NewParallelBestResolver(upstream, b)
|
||||
_, err := NewParallelBestResolver(upstream, b, verifyUpstreams)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should start if strict checking is disabled", func() {
|
||||
b.startVerifyUpstream = false
|
||||
|
||||
_, err := NewParallelBestResolver(upstream, b)
|
||||
_, err := NewParallelBestResolver(upstream, b, noVerifyUpstreams)
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
|
@ -111,7 +111,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()},
|
||||
}, skipUpstreamCheck)
|
||||
}, nil, noVerifyUpstreams)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should use result from fastest one", func() {
|
||||
|
@ -139,7 +139,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
DeferCleanup(slowTestUpstream.Close)
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
upstreamDefaultCfgName: {withErrorUpstream, slowTestUpstream.Start()},
|
||||
}, skipUpstreamCheck)
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should use result from successful resolver", func() {
|
||||
|
@ -160,7 +160,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
upstreamDefaultCfgName: {withError1, withError2},
|
||||
}, skipUpstreamCheck)
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should return error", func() {
|
||||
|
@ -203,7 +203,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
|
||||
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
|
||||
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
|
||||
}, skipUpstreamCheck)
|
||||
}, nil, noVerifyUpstreams)
|
||||
})
|
||||
It("Should use default if client name or IP don't match", func() {
|
||||
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "test")
|
||||
|
@ -270,7 +270,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
upstreamDefaultCfgName: {
|
||||
mockUpstream.Start(),
|
||||
},
|
||||
}, skipUpstreamCheck)
|
||||
}, nil, noVerifyUpstreams)
|
||||
})
|
||||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", dns.Type(dns.TypeA))
|
||||
|
@ -297,7 +297,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
tmp, _ := NewParallelBestResolver(map[string][]config.Upstream{
|
||||
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||
}, skipUpstreamCheck)
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
sut := tmp.(*ParallelBestResolver)
|
||||
|
||||
By("all resolvers have same weight for random -> equal distribution", func() {
|
||||
|
@ -359,7 +359,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("errors during construction", func() {
|
||||
b := TestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b)
|
||||
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b, verifyUpstreams)
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(r).Should(BeNil())
|
||||
|
@ -376,7 +376,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{upstreamDefaultCfgName: {
|
||||
{Host: "host1"},
|
||||
{Host: "host2"},
|
||||
}}, skipUpstreamCheck)
|
||||
}}, nil, noVerifyUpstreams)
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
|
|
|
@ -8,11 +8,14 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Resolver", func() {
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
|
||||
Describe("Creating resolver chain", func() {
|
||||
When("A chain of resolvers will be created", func() {
|
||||
It("should be iterable by calling 'GetNext'", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
|
||||
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, skipUpstreamCheck)
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, nil, false)
|
||||
ch := Chain(br, cr)
|
||||
c, ok := ch.(ChainedResolver)
|
||||
Expect(ok).Should(BeTrue())
|
||||
|
@ -23,14 +26,14 @@ var _ = Describe("Resolver", func() {
|
|||
})
|
||||
When("'Name' is called", func() {
|
||||
It("should return resolver name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
name := Name(br)
|
||||
Expect(name).Should(Equal("BlockingResolver"))
|
||||
})
|
||||
})
|
||||
When("'Name' is called on a NamedResolver", func() {
|
||||
It("should return it's custom name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
|
||||
cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}}
|
||||
r := NewRewriterResolver(cfg, br)
|
||||
|
|
|
@ -28,12 +28,6 @@ const (
|
|||
retryAttempts = 3
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
// This is only set during tests (see upstream_resolver_test.go)
|
||||
skipUpstreamCheck *Bootstrap
|
||||
)
|
||||
|
||||
// UpstreamResolver sends request to external DNS server
|
||||
type UpstreamResolver struct {
|
||||
upstream config.Upstream
|
||||
|
@ -198,10 +192,10 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
|
|||
}
|
||||
|
||||
// NewUpstreamResolver creates new resolver instance
|
||||
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap) (*UpstreamResolver, error) {
|
||||
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) {
|
||||
r := newUpstreamResolverUnchecked(upstream, bootstrap)
|
||||
|
||||
if skipUpstreamCheck == nil || r.bootstrap != skipUpstreamCheck { // skip check during tests
|
||||
if verify {
|
||||
_, err := r.bootstrap.UpstreamIPs(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -16,14 +16,10 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
//nolint:gochecknoinits
|
||||
func init() {
|
||||
// Skips the constructor's check
|
||||
// Resolves hostnames using system resolver
|
||||
skipUpstreamCheck = &Bootstrap{}
|
||||
}
|
||||
|
||||
var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
|
||||
Describe("Using DNS upstream", func() {
|
||||
|
||||
When("Configured DNS resolver can resolve query", func() {
|
||||
|
@ -32,7 +28,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
DeferCleanup(mockUpstream.Close)
|
||||
|
||||
upstream := mockUpstream.Start()
|
||||
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -48,7 +44,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
DeferCleanup(mockUpstream.Close)
|
||||
|
||||
upstream := mockUpstream.Start()
|
||||
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -64,7 +60,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
})
|
||||
DeferCleanup(mockUpstream.Close)
|
||||
upstream := mockUpstream.Start()
|
||||
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
@ -91,7 +87,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
|
||||
upstream := mockUpstream.Start()
|
||||
|
||||
sut, _ = NewUpstreamResolver(upstream, skipUpstreamCheck)
|
||||
sut = newUpstreamResolverUnchecked(upstream, nil)
|
||||
sut.upstreamClient.(*dnsUpstreamClient).udpClient.Timeout = 100 * time.Millisecond
|
||||
})
|
||||
It("should perform a retry with 3 attempts", func() {
|
||||
|
@ -138,7 +134,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
|
||||
JustBeforeEach(func() {
|
||||
upstream = TestDOHUpstream(respFn, modifyHTTPRespFn)
|
||||
sut, _ = NewUpstreamResolver(upstream, skipUpstreamCheck)
|
||||
sut = newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
// use insecure certificates for test doh upstream
|
||||
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
|
||||
|
@ -196,10 +192,10 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
})
|
||||
When("Configured DOH resolver does not respond", func() {
|
||||
JustBeforeEach(func() {
|
||||
sut, _ = NewUpstreamResolver(config.Upstream{
|
||||
sut = newUpstreamResolverUnchecked(config.Upstream{
|
||||
Net: config.NetProtocolHttps,
|
||||
Host: "wronghost.example.com",
|
||||
}, skipUpstreamCheck)
|
||||
}, systemResolverBootstrap)
|
||||
})
|
||||
It("should return error", func() {
|
||||
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
|
||||
|
@ -211,7 +207,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
Describe("Configuration", func() {
|
||||
When("Configuration is called", func() {
|
||||
It("should return nil, because upstream resolver is printed out by other resolvers", func() {
|
||||
sut, _ := NewUpstreamResolver(config.Upstream{}, skipUpstreamCheck)
|
||||
sut := newUpstreamResolverUnchecked(config.Upstream{}, nil)
|
||||
|
||||
c := sut.Configuration()
|
||||
|
||||
|
|
|
@ -394,10 +394,10 @@ func createQueryResolver(
|
|||
bootstrap *resolver.Bootstrap,
|
||||
redisClient *redis.Client,
|
||||
) (r resolver.Resolver, err error) {
|
||||
blockingResolver, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
|
||||
parallelResolver, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap)
|
||||
clientNamesResolver, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap)
|
||||
conditionalUpstreamResolver, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap)
|
||||
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
|
||||
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap, cfg.StartVerifyUpstream)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
|
||||
mErr := multierror.Append(
|
||||
multierror.Prefix(blErr, "blocking resolver: "),
|
||||
|
@ -412,17 +412,17 @@ func createQueryResolver(
|
|||
r = resolver.Chain(
|
||||
resolver.NewFilteringResolver(cfg.Filtering),
|
||||
resolver.NewFqdnOnlyResolver(*cfg),
|
||||
clientNamesResolver,
|
||||
clientNames,
|
||||
resolver.NewEdeResolver(cfg.Ede),
|
||||
resolver.NewQueryLoggingResolver(cfg.QueryLog),
|
||||
resolver.NewMetricsResolver(cfg.Prometheus),
|
||||
resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
|
||||
resolver.NewHostsFileResolver(cfg.HostsFile),
|
||||
blockingResolver,
|
||||
blocking,
|
||||
resolver.NewCachingResolver(cfg.Caching, redisClient),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, conditionalUpstreamResolver),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, condUpstream),
|
||||
resolver.NewSpecialUseDomainNamesResolver(),
|
||||
parallelResolver,
|
||||
parallel,
|
||||
)
|
||||
|
||||
return r, nil
|
||||
|
|
Loading…
Reference in New Issue