fix: `startVerifyUpstream` not disabling all start checks

This commit is contained in:
ThinkChaos 2022-11-29 12:55:12 -05:00
parent add591c5a4
commit c06c017a1a
14 changed files with 134 additions and 119 deletions

View File

@ -45,7 +45,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
It("should not start", func() {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Should(ContainElement(ContainSubstring("unable to reach any DNS resolvers configured for resolver group default")))
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
})

View File

@ -91,8 +91,9 @@ type BlockingResolver struct {
}
// NewBlockingResolver returns a new configured instance of the resolver
func NewBlockingResolver(cfg config.BlockingConfig,
redis *redis.Client, bootstrap *Bootstrap) (r ChainedResolver, err error) {
func NewBlockingResolver(
cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap,
) (r ChainedResolver, err error) {
blockHandler, err := createBlockHandler(cfg)
if err != nil {
return nil, err

View File

@ -54,6 +54,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
expectedReturnCode int
)
systemResolverBootstrap := &Bootstrap{}
BeforeEach(func() {
expectedReturnCode = dns.RcodeSuccess
@ -68,7 +70,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
JustBeforeEach(func() {
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
tmp, err := NewBlockingResolver(sutConfig, nil, skipUpstreamCheck)
tmp, err := NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
Expect(err).Should(Succeed())
sut = tmp.(*BlockingResolver)
sut.Next(m)
@ -102,7 +104,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
Expect(err).Should(Succeed())
// recreate to trigger a reload
tmp, err := NewBlockingResolver(sutConfig, nil, skipUpstreamCheck)
tmp, err := NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
Expect(err).Should(Succeed())
sut = tmp.(*BlockingResolver)
@ -852,7 +854,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
It("should return error", func() {
_, err := NewBlockingResolver(config.BlockingConfig{
BlockType: "wrong",
}, nil, skipUpstreamCheck)
}, nil, systemResolverBootstrap)
Expect(err).Should(
MatchError("unknown blockType 'wrong', please use one of: ZeroIP, NxDomain or specify destination IP address(es)"))
@ -865,7 +867,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
WhiteLists: map[string][]string{"whitelist": {"wrongPath"}},
StartStrategy: config.StartStrategyTypeFailOnError,
BlockType: "zeroIp",
}, nil, skipUpstreamCheck)
}, nil, systemResolverBootstrap)
Expect(err).Should(HaveOccurred())
})
})
@ -893,7 +895,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BlockTTL: config.Duration(time.Minute),
}
tmp, err2 := NewBlockingResolver(sutConfig, redisClient, skipUpstreamCheck)
tmp, err2 := NewBlockingResolver(sutConfig, redisClient, systemResolverBootstrap)
Expect(err2).Should(Succeed())
sut = tmp.(*BlockingResolver)
})

View File

@ -26,8 +26,7 @@ var (
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
type Bootstrap struct {
log *logrus.Entry
startVerifyUpstream bool
log *logrus.Entry
resolver Resolver
upstream Resolver // the upstream that's part of the above resolver
@ -65,10 +64,9 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
// This also prevents the GC to clean up these two structs, but is not currently an
// issue since they stay allocated until the process terminates
b = &Bootstrap{
log: log,
upstreamIPs: ips,
systemResolver: net.DefaultResolver, // allow replacing it during tests
startVerifyUpstream: cfg.StartVerifyUpstream,
log: log,
upstreamIPs: ips,
systemResolver: net.DefaultResolver, // allow replacing it during tests
}
if upstream.IsDefault() {
@ -87,12 +85,18 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
}
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
ips, err := b.resolveUpstream(r, r.upstream.Host)
hostname := r.upstream.Host
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil
}
ips, err := b.resolveUpstream(r, hostname)
if err != nil {
return nil, err
}
return &IPSet{values: ips}, nil
return newIPSet(ips), nil
}
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
@ -233,6 +237,10 @@ type IPSet struct {
index uint32
}
func newIPSet(ips []net.IP) *IPSet {
return &IPSet{values: ips}
}
func (ips *IPSet) Current() net.IP {
idx := atomic.LoadUint32(&ips.index)

View File

@ -25,10 +25,12 @@ type ClientNamesResolver struct {
}
// NewClientNamesResolver creates new resolver instance
func NewClientNamesResolver(cfg config.ClientLookupConfig, bootstrap *Bootstrap) (cr *ClientNamesResolver, err error) {
func NewClientNamesResolver(
cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (cr *ClientNamesResolver, err error) {
var r Resolver
if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap)
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}

View File

@ -22,7 +22,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
)
JustBeforeEach(func() {
res, err := NewClientNamesResolver(sutConfig, skipUpstreamCheck)
res, err := NewClientNamesResolver(sutConfig, nil, false)
Expect(err).Should(Succeed())
sut = res
m = &MockResolver{}
@ -315,7 +315,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
r, err := NewClientNamesResolver(config.ClientLookupConfig{
Upstream: config.Upstream{Host: "example.com"},
}, b)
}, b, true)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())

View File

@ -19,15 +19,16 @@ type ConditionalUpstreamResolver struct {
}
// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(cfg config.ConditionalUpstreamConfig,
bootstrap *Bootstrap) (ChainedResolver, error) {
func NewConditionalUpstreamResolver(
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (ChainedResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstream := range cfg.Mapping.Upstreams {
upstreams := make(map[string][]config.Upstream)
upstreams[upstreamDefaultCfgName] = upstream
r, err := NewParallelBestResolver(upstreams, bootstrap)
r, err := NewParallelBestResolver(upstreams, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}

View File

@ -55,7 +55,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
"other.box": {otherTestUpstream.Start()},
".": {dotTestUpstream.Start()},
}},
}, skipUpstreamCheck)
}, nil, false)
m = &MockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -125,7 +125,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
".": {config.Upstream{Host: "example.com"}},
},
},
}, b)
}, b, true)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
@ -141,7 +141,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
When("resolver is disabled", func() {
BeforeEach(func() {
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, skipUpstreamCheck)
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, nil, false)
})
It("should return 'disabled'", func() {
c := sut.Configuration()

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"errors"
"fmt"
"math"
"strings"
@ -32,6 +34,29 @@ type upstreamResolverStatus struct {
lastErrorTime atomic.Value
}
func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
status := &upstreamResolverStatus{
resolver: resolver,
}
status.lastErrorTime.Store(time.Unix(0, 0))
return status
}
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
resp, err := r.resolver.Resolve(req)
if err != nil && !errors.Is(err, context.Canceled) { // ignore `Canceled`: resolver lost the race, not an error
// update the last error time
r.lastErrorTime.Store(time.Now())
}
ch <- requestResponse{
response: resp,
err: err,
}
}
type requestResponse struct {
response *model.Response
err error
@ -50,46 +75,42 @@ func testResolver(r *UpstreamResolver) error {
}
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap) (Resolver, error) {
logger := logger("parallel resolver")
s := make(map[string][]*upstreamResolverStatus)
func NewParallelBestResolver(
upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (Resolver, error) {
logger := logger(parallelResolverLogger)
for name, res := range upstreamResolvers {
var resolvers []*upstreamResolverStatus
s := make(map[string][]*upstreamResolverStatus, len(upstreamResolvers))
var errResolvers int
for name, upstreamCfgs := range upstreamResolvers {
group := make([]*upstreamResolverStatus, 0, len(upstreamCfgs))
hasValidResolver := false
for _, u := range res {
r, err := NewUpstreamResolver(u, bootstrap)
for _, u := range upstreamCfgs {
r, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
if err != nil {
logger.Warnf("upstream group %s: %v", name, err)
errResolvers++
continue
}
if bootstrap != skipUpstreamCheck {
if shouldVerifyUpstreams {
err = testResolver(r)
if err != nil {
logger.Warn(err)
errResolvers++
} else {
hasValidResolver = true
}
}
resolver := &upstreamResolverStatus{
resolver: r,
}
resolver.lastErrorTime.Store(time.Unix(0, 0))
resolvers = append(resolvers, resolver)
group = append(group, newUpstreamResolverStatus(r))
}
if bootstrap != skipUpstreamCheck {
if bootstrap.startVerifyUpstream && errResolvers == len(res) {
return nil, fmt.Errorf("unable to reach any DNS resolvers configured for resolver group %s", name)
}
if shouldVerifyUpstreams && !hasValidResolver {
return nil, fmt.Errorf("no valid upstream for group %s", name)
}
s[name] = resolvers
s[name] = group
}
if len(s[upstreamDefaultCfgName]) == 0 {
@ -181,11 +202,11 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
logger.WithField("resolver", r1.resolver).Debug("delegating to resolver")
go resolve(request, r1, ch)
go r1.resolve(request, ch)
logger.WithField("resolver", r2.resolver).Debug("delegating to resolver")
go resolve(request, r2, ch)
go r2.resolve(request, ch)
//nolint: gosimple
for len(collectedErrors) < resolverCount {
@ -243,16 +264,3 @@ func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamRes
return c.Pick().(*upstreamResolverStatus)
}
func resolve(req *model.Request, resolver *upstreamResolverStatus, ch chan<- requestResponse) {
resp, err := resolver.resolver.Resolve(req)
// update the last error time
if err != nil {
resolver.lastErrorTime.Store(time.Now())
}
ch <- requestResponse{
response: resp,
err: err,
}
}

View File

@ -15,21 +15,25 @@ import (
var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
const (
verifyUpstreams = true
noVerifyUpstreams = false
)
systemResolverBootstrap := &Bootstrap{}
config.GetConfig().UpstreamTimeout = config.Duration(1000 * time.Millisecond)
Describe("Default upstream resolvers are not defined", func() {
When("default upstream resolvers are not defined", func() {
It("should fail on startup", func() {
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, skipUpstreamCheck)
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, nil, noVerifyUpstreams)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
})
})
Describe("Some default upstream resolvers cannot be reached", func() {
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
skipUpstreamCheck.startVerifyUpstream = true
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, dns.Type(dns.TypeA), "123.124.122.122")
@ -46,12 +50,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
},
}
_, err := NewParallelBestResolver(upstream, skipUpstreamCheck)
_, err := NewParallelBestResolver(upstream, systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
Describe("All default upstream resolvers cannot be reached", func() {
When("no upstream resolvers can be reached", func() {
var (
upstream map[string][]config.Upstream
b *Bootstrap
@ -73,16 +77,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
It("should fail to start if strict checking is enabled", func() {
b.startVerifyUpstream = true
_, err := NewParallelBestResolver(upstream, b)
_, err := NewParallelBestResolver(upstream, b, verifyUpstreams)
Expect(err).Should(HaveOccurred())
})
It("should start if strict checking is disabled", func() {
b.startVerifyUpstream = false
_, err := NewParallelBestResolver(upstream, b)
_, err := NewParallelBestResolver(upstream, b, noVerifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
@ -111,7 +111,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()},
}, skipUpstreamCheck)
}, nil, noVerifyUpstreams)
Expect(err).Should(Succeed())
})
It("Should use result from fastest one", func() {
@ -139,7 +139,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
DeferCleanup(slowTestUpstream.Close)
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
upstreamDefaultCfgName: {withErrorUpstream, slowTestUpstream.Start()},
}, skipUpstreamCheck)
}, systemResolverBootstrap, noVerifyUpstreams)
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
@ -160,7 +160,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
upstreamDefaultCfgName: {withError1, withError2},
}, skipUpstreamCheck)
}, systemResolverBootstrap, noVerifyUpstreams)
Expect(err).Should(Succeed())
})
It("Should return error", func() {
@ -203,7 +203,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
}, skipUpstreamCheck)
}, nil, noVerifyUpstreams)
})
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", dns.Type(dns.TypeA), "192.168.178.55", "test")
@ -270,7 +270,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreamDefaultCfgName: {
mockUpstream.Start(),
},
}, skipUpstreamCheck)
}, nil, noVerifyUpstreams)
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", dns.Type(dns.TypeA))
@ -297,7 +297,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
tmp, _ := NewParallelBestResolver(map[string][]config.Upstream{
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}, skipUpstreamCheck)
}, systemResolverBootstrap, noVerifyUpstreams)
sut := tmp.(*ParallelBestResolver)
By("all resolvers have same weight for random -> equal distribution", func() {
@ -359,7 +359,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() {
b := TestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b)
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b, verifyUpstreams)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
@ -376,7 +376,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{upstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
}}, skipUpstreamCheck)
}}, nil, noVerifyUpstreams)
})
It("should return configuration", func() {
c := sut.Configuration()

View File

@ -8,11 +8,14 @@ import (
)
var _ = Describe("Resolver", func() {
systemResolverBootstrap := &Bootstrap{}
Describe("Creating resolver chain", func() {
When("A chain of resolvers will be created", func() {
It("should be iterable by calling 'GetNext'", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, skipUpstreamCheck)
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, nil, false)
ch := Chain(br, cr)
c, ok := ch.(ChainedResolver)
Expect(ok).Should(BeTrue())
@ -23,14 +26,14 @@ var _ = Describe("Resolver", func() {
})
When("'Name' is called", func() {
It("should return resolver name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
name := Name(br)
Expect(name).Should(Equal("BlockingResolver"))
})
})
When("'Name' is called on a NamedResolver", func() {
It("should return it's custom name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, skipUpstreamCheck)
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}}
r := NewRewriterResolver(cfg, br)

View File

@ -28,12 +28,6 @@ const (
retryAttempts = 3
)
//nolint:gochecknoglobals
var (
// This is only set during tests (see upstream_resolver_test.go)
skipUpstreamCheck *Bootstrap
)
// UpstreamResolver sends request to external DNS server
type UpstreamResolver struct {
upstream config.Upstream
@ -198,10 +192,10 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
}
// NewUpstreamResolver creates new resolver instance
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap) (*UpstreamResolver, error) {
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) {
r := newUpstreamResolverUnchecked(upstream, bootstrap)
if skipUpstreamCheck == nil || r.bootstrap != skipUpstreamCheck { // skip check during tests
if verify {
_, err := r.bootstrap.UpstreamIPs(r)
if err != nil {
return nil, err

View File

@ -16,14 +16,10 @@ import (
. "github.com/onsi/gomega"
)
//nolint:gochecknoinits
func init() {
// Skips the constructor's check
// Resolves hostnames using system resolver
skipUpstreamCheck = &Bootstrap{}
}
var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
systemResolverBootstrap := &Bootstrap{}
Describe("Using DNS upstream", func() {
When("Configured DNS resolver can resolve query", func() {
@ -32,7 +28,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
sut := newUpstreamResolverUnchecked(upstream, nil)
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
@ -48,7 +44,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
sut := newUpstreamResolverUnchecked(upstream, nil)
resp, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(Succeed())
@ -64,7 +60,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut, _ := NewUpstreamResolver(upstream, skipUpstreamCheck)
sut := newUpstreamResolverUnchecked(upstream, nil)
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
Expect(err).Should(HaveOccurred())
@ -91,7 +87,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut, _ = NewUpstreamResolver(upstream, skipUpstreamCheck)
sut = newUpstreamResolverUnchecked(upstream, nil)
sut.upstreamClient.(*dnsUpstreamClient).udpClient.Timeout = 100 * time.Millisecond
})
It("should perform a retry with 3 attempts", func() {
@ -138,7 +134,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
JustBeforeEach(func() {
upstream = TestDOHUpstream(respFn, modifyHTTPRespFn)
sut, _ = NewUpstreamResolver(upstream, skipUpstreamCheck)
sut = newUpstreamResolverUnchecked(upstream, nil)
// use insecure certificates for test doh upstream
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
@ -196,10 +192,10 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
When("Configured DOH resolver does not respond", func() {
JustBeforeEach(func() {
sut, _ = NewUpstreamResolver(config.Upstream{
sut = newUpstreamResolverUnchecked(config.Upstream{
Net: config.NetProtocolHttps,
Host: "wronghost.example.com",
}, skipUpstreamCheck)
}, systemResolverBootstrap)
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", dns.Type(dns.TypeA)))
@ -211,7 +207,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Describe("Configuration", func() {
When("Configuration is called", func() {
It("should return nil, because upstream resolver is printed out by other resolvers", func() {
sut, _ := NewUpstreamResolver(config.Upstream{}, skipUpstreamCheck)
sut := newUpstreamResolverUnchecked(config.Upstream{}, nil)
c := sut.Configuration()

View File

@ -394,10 +394,10 @@ func createQueryResolver(
bootstrap *resolver.Bootstrap,
redisClient *redis.Client,
) (r resolver.Resolver, err error) {
blockingResolver, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
parallelResolver, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap)
clientNamesResolver, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap)
conditionalUpstreamResolver, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap)
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap, cfg.StartVerifyUpstream)
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
mErr := multierror.Append(
multierror.Prefix(blErr, "blocking resolver: "),
@ -412,17 +412,17 @@ func createQueryResolver(
r = resolver.Chain(
resolver.NewFilteringResolver(cfg.Filtering),
resolver.NewFqdnOnlyResolver(*cfg),
clientNamesResolver,
clientNames,
resolver.NewEdeResolver(cfg.Ede),
resolver.NewQueryLoggingResolver(cfg.QueryLog),
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
resolver.NewHostsFileResolver(cfg.HostsFile),
blockingResolver,
blocking,
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, conditionalUpstreamResolver),
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, condUpstream),
resolver.NewSpecialUseDomainNamesResolver(),
parallelResolver,
parallel,
)
return r, nil