feat: add upstream strategy `random` (#1221)

Also simplify code by getting rid of `resolversPerClient` and all surrounding logic.
This commit is contained in:
DerRockWolf 2023-11-18 20:42:14 +00:00 committed by GitHub
parent 4a5a395655
commit 94663eeaeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 656 additions and 387 deletions

View File

@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT.
// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT.
package api
import (

View File

@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT.
// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT.
package api
import (

View File

@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT.
// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT.
package api
// ApiBlockingStatus defines model for api.BlockingStatus.

View File

@ -124,7 +124,7 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
type QueryLogField string
// UpstreamStrategy data field to be logged
// ENUM(parallel_best,strict)
// ENUM(parallel_best,strict,random)
type UpstreamStrategy uint8
//nolint:gochecknoglobals

View File

@ -482,15 +482,18 @@ const (
UpstreamStrategyParallelBest UpstreamStrategy = iota
// UpstreamStrategyStrict is a UpstreamStrategy of type Strict.
UpstreamStrategyStrict
// UpstreamStrategyRandom is a UpstreamStrategy of type Random.
UpstreamStrategyRandom
)
var ErrInvalidUpstreamStrategy = fmt.Errorf("not a valid UpstreamStrategy, try [%s]", strings.Join(_UpstreamStrategyNames, ", "))
const _UpstreamStrategyName = "parallel_beststrict"
const _UpstreamStrategyName = "parallel_beststrictrandom"
var _UpstreamStrategyNames = []string{
_UpstreamStrategyName[0:13],
_UpstreamStrategyName[13:19],
_UpstreamStrategyName[19:25],
}
// UpstreamStrategyNames returns a list of possible string values of UpstreamStrategy.
@ -505,12 +508,14 @@ func UpstreamStrategyValues() []UpstreamStrategy {
return []UpstreamStrategy{
UpstreamStrategyParallelBest,
UpstreamStrategyStrict,
UpstreamStrategyRandom,
}
}
var _UpstreamStrategyMap = map[UpstreamStrategy]string{
UpstreamStrategyParallelBest: _UpstreamStrategyName[0:13],
UpstreamStrategyStrict: _UpstreamStrategyName[13:19],
UpstreamStrategyRandom: _UpstreamStrategyName[19:25],
}
// String implements the Stringer interface.
@ -531,6 +536,7 @@ func (x UpstreamStrategy) IsValid() bool {
var _UpstreamStrategyValue = map[string]UpstreamStrategy{
_UpstreamStrategyName[0:13]: UpstreamStrategyParallelBest,
_UpstreamStrategyName[13:19]: UpstreamStrategyStrict,
_UpstreamStrategyName[19:25]: UpstreamStrategyRandom,
}
// ParseUpstreamStrategy attempts to convert a string to a UpstreamStrategy.

View File

@ -34,3 +34,24 @@ func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) {
}
}
}
// UpstreamGroup represents the config for one group (upstream branch)
type UpstreamGroup struct {
Name string
Upstreams []Upstream
}
// IsEnabled implements `config.Configurable`.
func (c *UpstreamGroup) IsEnabled() bool {
return len(c.Upstreams) != 0
}
// LogConfig implements `config.Configurable`.
func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) {
logger.Info("group: ", c.Name)
logger.Info("upstreams:")
for _, upstream := range c.Upstreams {
logger.Infof(" - %s", upstream)
}
}

View File

@ -9,53 +9,104 @@ import (
)
var _ = Describe("ParallelBestConfig", func() {
var cfg UpstreamsConfig
suiteBeforeEach()
BeforeEach(func() {
cfg = UpstreamsConfig{
Timeout: Duration(5 * time.Second),
Groups: UpstreamGroups{
UpstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
Context("UpstreamsConfig", func() {
var cfg UpstreamsConfig
BeforeEach(func() {
cfg = UpstreamsConfig{
Timeout: Duration(5 * time.Second),
Groups: UpstreamGroups{
UpstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
},
},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamsConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
}
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamsConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := UpstreamsConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Context("UpstreamGroupConfig", func() {
var cfg UpstreamGroup
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
BeforeEach(func() {
cfg = UpstreamGroup{
Name: UpstreamDefaultCfgName,
Upstreams: []Upstream{
{Host: "host1"},
{Host: "host2"},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamGroup{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := UpstreamGroup{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
})
})
})
})

View File

@ -19,7 +19,7 @@ upstreams:
laptop*:
- 123.123.123.123
# optional: Determines what strategy blocky uses to choose the upstream servers.
# accepted: parallel_best, strict
# accepted: parallel_best, strict, random
# default: parallel_best
strategy: parallel_best
# optional: timeout to query the upstream resolver. Default: 2s

View File

@ -139,10 +139,14 @@ Blocky supports different upstream strategies (default `parallel_best`) that det
Currently available strategies:
- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
If an upstream failed to answer within the last hour, it is less likely to be chosen for the race.
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers.
(When using 10 upstream servers, each upstream will get on average 20% of the DNS requests)
- `random`: blocky picks one random (weighted) resolver from the upstream group for each query and, if it responds successfully, returns its response.
If the selected resolver fails to respond, a second resolver is picked and the query is sent to it instead.
The weighting is identical to the `parallel_best` strategy.
Although the `random` strategy might be slower than the `parallel_best` strategy, it offers more privacy since each request is sent to a single upstream.
- `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on.
!!! example

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"slices"
"sort"
"strings"
"sync"
@ -11,6 +12,7 @@ import (
"time"
"github.com/0xERR0R/blocky/cache/expirationcache"
"golang.org/x/exp/maps"
"github.com/hashicorp/go-multierror"
@ -206,21 +208,11 @@ func (r *BlockingResolver) RefreshLists() error {
return err.ErrorOrNil()
}
//nolint:prealloc
func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
groups := make(map[string]bool, len(r.cfg.BlackLists))
for group := range r.cfg.BlackLists {
groups[group] = true
}
var result []string
for k := range groups {
result = append(result, k)
}
result := maps.Keys(r.cfg.BlackLists)
result = append(result, "default")
sort.Strings(result)
slices.Sort(result)
return result
}
@ -615,11 +607,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
}
func (r *BlockingResolver) initFQDNIPCache() {
identifiers := make([]string, 0)
for identifier := range r.clientGroupsBlock {
identifiers = append(identifiers, identifier)
}
identifiers := maps.Keys(r.clientGroupsBlock)
for _, identifier := range identifiers {
if isFQDN(identifier) {

View File

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
const (
@ -77,9 +78,9 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
// where `ParallelBestResolver` uses its config, we can just use an empty one.
var pbCfg config.UpstreamsConfig
pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName}
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers())
// Always enable prefetching to avoid stalling user requests
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:
@ -300,16 +301,8 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots
return upstreamIPs, nil
}
func (br bootstrapedResolvers) ResolverGroups() map[string][]Resolver {
resolvers := make([]Resolver, 0, len(br))
for resolver := range br {
resolvers = append(resolvers, resolver)
}
return map[string][]Resolver{
upstreamDefaultCfgName: resolvers,
}
func (br bootstrapedResolvers) Resolvers() []Resolver {
return maps.Keys(br)
}
type IPSet struct {

View File

@ -32,6 +32,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
)
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
sutConfig = &config.Config{
BootstrapDNS: []config.BootstrappedUpstreamConfig{
{

View File

@ -190,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
}
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(request)
if err == nil {

View File

@ -27,14 +27,10 @@ func NewConditionalUpstreamResolver(
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstream := range cfg.Mapping.Upstreams {
pbCfg := config.UpstreamsConfig{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: upstream,
},
}
for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}

View File

@ -138,7 +138,7 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
}
}
logger.WithField("resolver", Name(r.next)).Trace("go to next resolver")
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
}

View File

@ -132,7 +132,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
}
r.log().WithField("resolver", Name(r.next)).Trace("go to next resolver")
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
}

View File

@ -20,17 +20,22 @@ import (
)
const (
upstreamDefaultCfgName = config.UpstreamDefaultCfgName
parallelResolverType = "parallel_best"
resolverCount = 2
upstreamDefaultCfgName = config.UpstreamDefaultCfgName
parallelResolverType = "parallel_best"
randomResolverType = "random"
parallelBestResolverCount = 2
)
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
type ParallelBestResolver struct {
configurable[*config.UpstreamsConfig]
configurable[*config.UpstreamGroup]
typed
resolversPerClient map[string][]*upstreamResolverStatus
groupName string
resolvers []*upstreamResolverStatus
resolverCount int
retryWithDifferentResolver bool
}
type upstreamResolverStatus struct {
@ -50,9 +55,13 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
resp, err := r.resolver.Resolve(req)
if err != nil && !errors.Is(err, context.Canceled) { // ignore `Canceled`: resolver lost the race, not an error
// update the last error time
r.lastErrorTime.Store(time.Now())
if err != nil {
// Ignore `Canceled`: resolver lost the race, not an error
if !errors.Is(err, context.Canceled) {
r.lastErrorTime.Store(time.Now())
}
err = fmt.Errorf("%s: %w", r.resolver, err)
}
ch <- requestResponse{
@ -82,67 +91,44 @@ func testResolver(r *UpstreamResolver) error {
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(
cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType)
upstreamResolvers := cfg.Groups
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
for name, upstreamCfgs := range upstreamResolvers {
group := make([]Resolver, 0, len(upstreamCfgs))
hasValidResolver := false
for _, u := range upstreamCfgs {
resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
if err != nil {
logger.Warnf("upstream group %s: %v", name, err)
continue
}
if shouldVerifyUpstreams {
err = testResolver(resolver)
if err != nil {
logger.Warn(err)
} else {
hasValidResolver = true
}
}
group = append(group, resolver)
}
if shouldVerifyUpstreams && !hasValidResolver {
return nil, fmt.Errorf("no valid upstream for group %s", name)
}
resolverGroups[name] = group
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
return newParallelBestResolver(cfg, resolverGroups), nil
return newParallelBestResolver(cfg, resolvers), nil
}
func newParallelBestResolver(
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
cfg config.UpstreamGroup, resolvers []Resolver,
) *ParallelBestResolver {
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for groupName, resolvers := range resolverGroups {
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
resolverCount := parallelBestResolverCount
retryWithDifferentResolver := false
resolversPerClient[groupName] = resolverStatuses
if config.GetConfig().Upstreams.Strategy == config.UpstreamStrategyRandom {
resolverCount = 1
retryWithDifferentResolver = true
}
r := ParallelBestResolver{
configurable: withConfig(&cfg),
typed: withType(parallelResolverType),
resolversPerClient: resolversPerClient,
groupName: cfg.Name,
resolvers: resolverStatuses,
resolverCount: resolverCount,
retryWithDifferentResolver: retryWithDifferentResolver,
}
return &r
@ -153,88 +139,137 @@ func (r *ParallelBestResolver) Name() string {
}
func (r *ParallelBestResolver) String() string {
result := make([]string, 0, len(r.resolversPerClient))
for name, res := range r.resolversPerClient {
tmp := make([]string, len(res))
for i, s := range res {
tmp[i] = fmt.Sprintf("%s", s.resolver)
}
result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ",")))
result := make([]string, len(r.resolvers))
for i, s := range r.resolvers {
result[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("parallel upstreams '%s'", strings.Join(result, "; "))
return fmt.Sprintf("%s (resolverCount: %d, retryWithDifferentResolver: %t) upstreams '%s (%s)'",
parallelResolverType, r.resolverCount, r.retryWithDifferentResolver, r.groupName, strings.Join(result, ","))
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType)
var resolvers []*upstreamResolverStatus
for _, r := range r.resolversPerClient {
resolvers = r
if len(r.resolvers) == 1 {
logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver")
break
return r.resolvers[0].resolver.Resolve(request)
}
if len(resolvers) == 1 {
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
ctx := context.Background()
return resolvers[0].resolver.Resolve(request)
// using context with timeout for random upstream strategy
if r.resolverCount == 1 {
var cancel context.CancelFunc
logger = log.WithPrefix(logger, "random")
timeout := config.GetConfig().Upstreams.Timeout
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
defer cancel()
}
r1, r2 := pickRandom(resolvers)
logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver)
resolvers := pickRandom(r.resolvers, r.resolverCount)
ch := make(chan requestResponse, len(resolvers))
ch := make(chan requestResponse, resolverCount)
for _, resolver := range resolvers {
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
var collectedErrors []error
go resolver.resolve(request, ch)
}
logger.WithField("resolver", r1.resolver).Debug("delegating to resolver")
response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers)
if response != nil {
return response, nil
}
go r1.resolve(request, ch)
if !r.retryWithDifferentResolver {
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
}
logger.WithField("resolver", r2.resolver).Debug("delegating to resolver")
return r.retryWithDifferent(logger, request, resolvers)
}
go r2.resolve(request, ch)
func evaluateResponses(
ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
) (*model.Response, []error) {
collectedErrors := make([]error, 0, len(resolvers))
for len(collectedErrors) < resolverCount {
result := <-ch
for len(collectedErrors) < len(resolvers) {
select {
case <-ctx.Done():
// this context currently only has a deadline when resolverCount == 1
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
logger.WithField("resolver", resolvers[0].resolver).
Debug("upstream exceeded timeout, trying other upstream")
resolvers[0].lastErrorTime.Store(time.Now())
}
case result := <-ch:
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, result.err)
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
return result.response, nil
}
}
}
return nil, fmt.Errorf("resolution was not successful, used resolvers: '%s' and '%s' errors: %v",
r1.resolver, r2.resolver, collectedErrors)
return nil, collectedErrors
}
// pick 2 different random resolvers from the resolver pool
func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upstreamResolverStatus) {
resolver1 = weightedRandom(resolvers, nil)
resolver2 = weightedRandom(resolvers, resolver1.resolver)
func (r *ParallelBestResolver) retryWithDifferent(
logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
) (*model.Response, error) {
// second try (if retryWithDifferentResolver == true)
resolver := weightedRandom(r.resolvers, resolvers)
logger.Debugf("using %s as second resolver", resolver.resolver)
return
ch := make(chan requestResponse, 1)
resolver.resolve(request, ch)
result := <-ch
if result.err != nil {
return nil, fmt.Errorf("resolution retry failed: %w", result.err)
}
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
}
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount)
for i := 0; i < resolverCount; i++ {
chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers))
}
return chosenResolvers
}
func weightedRandom(in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
const errorWindowInSec = 60
choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
outer:
for _, res := range in {
if exclude == res.resolver {
continue
for _, exclude := range excludedResolvers {
if exclude.resolver == res.resolver {
continue outer
}
}
var weight float64 = errorWindowInSec

View File

@ -22,11 +22,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
)
var (
sut *ParallelBestResolver
sutMapping config.UpstreamGroups
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
sut *ParallelBestResolver
upstreams []config.Upstream
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
err error
@ -40,19 +40,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
}
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}}
sutVerify = noVerifyUpstreams
@ -60,9 +53,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
JustBeforeEach(func() {
sutConfig := config.UpstreamsConfig{
Timeout: config.Duration(1000 * time.Millisecond),
Groups: sutMapping,
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
@ -85,8 +78,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
Describe("Name", func() {
It("should not be empty", func() {
It("should contain correct resolver", func() {
Expect(sut.Name()).ShouldNot(BeEmpty())
Expect(sut.Name()).Should(ContainSubstring(parallelResolverType))
})
})
@ -99,18 +93,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
defer mockUpstream.Close()
upstream := config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
mockUpstream.Start(),
},
upstreams := []config.Upstream{
{Host: "wrong"},
mockUpstream.Start(),
}
_, err := NewParallelBestResolver(config.UpstreamsConfig{
Groups: upstream,
}, systemResolverBootstrap, verifyUpstreams)
_, err := NewParallelBestResolver(config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
@ -119,16 +111,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
BeforeEach(func() {
bootstrap = newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
}
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}}
})
When("strict checking is enabled", func() {
@ -167,9 +150,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
DeferCleanup(slowTestUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()},
}
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
})
It("Should use result from fastest one", func() {
request := newRequest("example.com.", A)
@ -195,9 +176,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
return response
})
DeferCleanup(slowTestUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{Host: "wrong"}, slowTestUpstream.Start()},
}
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
request := newRequest("example.com.", A)
@ -216,9 +196,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
withError1 := config.Upstream{Host: "wrong"}
withError2 := config.Upstream{Host: "wrong"}
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {withError1, withError2},
}
upstreams = []config.Upstream{withError1, withError2}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
Expect(err).Should(Succeed())
@ -234,11 +213,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
mockUpstream.Start(),
},
}
upstreams = []config.Upstream{mockUpstream.Start()}
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
@ -267,34 +242,30 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}},
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
systemResolverBootstrap, noVerifyUpstreams)
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)
for i := 0; i < 1000; i++ {
var resolvers []*upstreamResolverStatus
for _, r := range sut.resolversPerClient {
resolvers = r
}
r1, r2 := pickRandom(resolvers)
res1 := r1.resolver
res2 := r2.resolver
resolvers := pickRandom(sut.resolvers, parallelBestResolverCount)
res1 := resolvers[0].resolver
res2 := resolvers[1].resolver
Expect(res1).ShouldNot(Equal(res2))
resolverCount[res1]++
resolverCount[res2]++
}
for _, v := range resolverCount {
// should be 500 ± 100
Expect(v).Should(BeNumerically("~", 500, 100))
// should be 500 ± 50
Expect(v).Should(BeNumerically("~", 500, 75))
}
})
By("perform 10 request, error upstream's weight will be reduced", func() {
// perform 10 requests
By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ {
request := newRequest("example.com.", A)
_, _ = sut.Resolve(request)
@ -305,13 +276,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)
for i := 0; i < 100; i++ {
var resolvers []*upstreamResolverStatus
for _, r := range sut.resolversPerClient {
resolvers = r
}
r1, r2 := pickRandom(resolvers)
res1 := r1.resolver.(*UpstreamResolver)
res2 := r2.resolver.(*UpstreamResolver)
resolvers := pickRandom(sut.resolvers, parallelBestResolverCount)
res1 := resolvers[0].resolver.(*UpstreamResolver)
res2 := resolvers[1].resolver.(*UpstreamResolver)
Expect(res1).ShouldNot(Equal(res2))
resolverCount[res1]++
@ -335,12 +302,235 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(config.UpstreamsConfig{
Groups: config.UpstreamGroups{"test": {{Host: "example.com"}}},
r, err := NewParallelBestResolver(config.UpstreamGroup{
Name: "test",
Upstreams: []config.Upstream{{Host: "example.com"}},
}, b, verifyUpstreams)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
})
})
Describe("random resolver strategy", func() {
const timeout = config.Duration(time.Second)
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
config.GetConfig().Upstreams.Timeout = timeout
})
Describe("Name", func() {
It("should contain correct resolver", func() {
Expect(sut.Name()).ShouldNot(BeEmpty())
Expect(sut.Name()).Should(ContainSubstring(parallelResolverType))
})
})
Describe("Resolving request in random order", func() {
When("Multiple upstream resolvers are defined", func() {
When("Both are responding", func() {
When("Both respond in time", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
It("Should return result from either one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(SatisfyAll(
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
Or(
BeDNSRecord("example.com.", A, "123.124.122.122"),
BeDNSRecord("example.com.", A, "123.124.122.123"),
),
))
})
})
When("one upstream exceeds timeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(time.Duration(timeout) + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
It("should ask a other random upstream and return its response", func() {
request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("two upstreams exceed timeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
time.Sleep(timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream2.Close)
testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3")
DeferCleanup(testUpstream3.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()}
})
// NOTE: these two tests are flaky; consider recreating the RandomResolver for each iteration.
It("should not return error (due to random selection the request could to through)", func() {
Eventually(func() error {
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
return err
}).WithTimeout(30 * time.Second).
Should(Not(HaveOccurred()))
})
It("should return error (because it can be possible that the two broken upstreams are chosen)", func() {
Eventually(func() error {
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
return err
}).WithTimeout(30 * time.Second).
Should(HaveOccurred())
})
})
})
When("None are working", func() {
BeforeEach(func() {
testUpstream1 := config.Upstream{Host: "wrong"}
testUpstream2 := config.Upstream{Host: "wrong"}
upstreams = []config.Upstream{testUpstream1, testUpstream2}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
request := newRequest("example.com.", A)
_, err := sut.Resolve(request)
Expect(err).Should(HaveOccurred())
})
})
})
When("only 1 upstream resolvers is defined", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
upstreams = []config.Upstream{mockUpstream.Start()}
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
})
Describe("Weighted random on resolver selection", func() {
When("4 upstream resolvers are defined", func() {
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
withError1 := config.Upstream{Host: "wrong1"}
withError2 := config.Upstream{Host: "wrong2"}
mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream1.Close)
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
systemResolverBootstrap, noVerifyUpstreams)
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)
for i := 0; i < 2000; i++ {
r := weightedRandom(sut.resolvers, nil)
resolverCount[r.resolver]++
}
for _, v := range resolverCount {
// should be 500 ± 100
Expect(v).Should(BeNumerically("~", 500, 100))
}
})
By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ {
request := newRequest("example.com.", A)
_, _ = sut.Resolve(request)
}
})
By("Resolvers without errors should be selected often", func() {
resolverCount := make(map[*UpstreamResolver]int)
for i := 0; i < 200; i++ {
r := weightedRandom(sut.resolvers, nil)
res := r.resolver.(*UpstreamResolver)
resolverCount[res]++
}
for k, v := range resolverCount {
if strings.Contains(k.String(), "wrong") {
// error resolvers: should be 0 - 10
Expect(v).Should(BeNumerically("~", 0, 10))
} else {
// should be 95 ± 20
Expect(v).Should(BeNumerically("~", 95, 20))
}
}
})
})
})
})
})
})

View File

@ -214,3 +214,36 @@ func (c *configurable[T]) IsEnabled() bool {
func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
c.cfg.LogConfig(logger)
}
// createResolvers builds an upstream resolver for each upstream configured in
// the group. Upstreams whose resolver cannot be constructed are skipped with a
// warning. When shouldVerifyUpstreams is true, each constructed resolver is
// additionally probed via testResolver; failing resolvers are still kept (they
// may recover later), but if not a single resolver in the group passes
// verification, an error is returned.
func createResolvers(
	logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) ([]Resolver, error) {
	resolvers := make([]Resolver, 0, len(cfg.Upstreams))
	hasValidResolvers := false

	for _, u := range cfg.Upstreams {
		resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
		if err != nil {
			// construction failed for this upstream only: log and continue with the rest
			logger.Warnf("upstream group %s: %v", cfg.Name, err)

			continue
		}

		if shouldVerifyUpstreams {
			err = testResolver(resolver)
			if err != nil {
				logger.Warn(err)
			} else {
				hasValidResolvers = true
			}
		}

		resolvers = append(resolvers, resolver)
	}

	if shouldVerifyUpstreams && !hasValidResolvers {
		return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name)
	}

	return resolvers, nil
}

View File

@ -21,75 +21,42 @@ const (
// StrictResolver delegates the DNS message strictly to the first configured upstream resolver
// if it can't provide the answer in time the next resolver is used
type StrictResolver struct {
configurable[*config.UpstreamsConfig]
configurable[*config.UpstreamGroup]
typed
resolversPerClient map[string][]*upstreamResolverStatus
groupName string
resolvers []*upstreamResolverStatus
}
// NewStrictResolver creates new resolver instance
// NewStrictResolver creates a new strict resolver instance
func NewStrictResolver(
cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType)
upstreamResolvers := cfg.Groups
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
for name, upstreamCfgs := range upstreamResolvers {
group := make([]Resolver, 0, len(upstreamCfgs))
hasValidResolver := false
for _, u := range upstreamCfgs {
resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
if err != nil {
logger.Warnf("upstream group %s: %v", name, err)
continue
}
if shouldVerifyUpstreams {
err = testResolver(resolver)
if err != nil {
logger.Warn(err)
} else {
hasValidResolver = true
}
}
group = append(group, resolver)
}
if shouldVerifyUpstreams && !hasValidResolver {
return nil, fmt.Errorf("no valid upstream for group %s", name)
}
resolverGroups[name] = group
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
return newStrictResolver(cfg, resolverGroups), nil
return newStrictResolver(cfg, resolvers), nil
}
func newStrictResolver(
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
cfg config.UpstreamGroup, resolvers []Resolver,
) *StrictResolver {
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for groupName, resolvers := range resolverGroups {
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
resolversPerClient[groupName] = resolverStatuses
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
r := StrictResolver{
configurable: withConfig(&cfg),
typed: withType(strictResolverType),
resolversPerClient: resolversPerClient,
groupName: cfg.Name,
resolvers: resolverStatuses,
}
return &r
@ -100,48 +67,36 @@ func (r *StrictResolver) Name() string {
}
func (r *StrictResolver) String() string {
result := make([]string, 0, len(r.resolversPerClient))
for name, res := range r.resolversPerClient {
tmp := make([]string, len(res))
for i, s := range res {
tmp[i] = fmt.Sprintf("%s", s.resolver)
}
result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ",")))
result := make([]string, len(r.resolvers))
for i, s := range r.resolvers {
result[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("%s upstreams %q", strictResolverType, strings.Join(result, "; "))
return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.groupName, strings.Join(result, ","))
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
// Resolve sends the query request in a strict order to the upstream resolvers
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, strictResolverType)
var resolvers []*upstreamResolverStatus
for _, r := range r.resolversPerClient {
resolvers = r
break
}
// start with first resolver
for i := range resolvers {
for i := range r.resolvers {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// start in new go routine and cancel if
resolver := resolvers[i]
ch := make(chan requestResponse, resolverCount)
resolver := r.resolvers[i]
logger.Debugf("using %s as resolver", resolver.resolver)
ch := make(chan requestResponse, 1)
go resolver.resolve(request, ch)
select {
case <-ctx.Done():
// log debug/info that timeout exceeded, call `continue` to try next upstream
logger.WithField("resolver", resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
continue
case result := <-ch:

View File

@ -20,9 +20,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
)
var (
sut *StrictResolver
sutMapping config.UpstreamGroups
sutVerify bool
sut *StrictResolver
upstreams []config.Upstream
sutVerify bool
err error
@ -36,15 +36,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
BeforeEach(func() {
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
upstreams = []config.Upstream{
{Host: "wrong"},
{Host: "127.0.0.2"},
}
sutVerify = noVerifyUpstreams
@ -53,12 +47,14 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
JustBeforeEach(func() {
sutConfig := config.UpstreamsConfig{Groups: sutMapping}
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
})
config.GetConfig().Upstreams.Timeout = config.Duration(1000 * time.Millisecond)
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
Describe("IsEnabled", func() {
It("is true", func() {
@ -99,33 +95,25 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
defer mockUpstream.Close()
upstream := config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
mockUpstream.Start(),
},
upstreams := []config.Upstream{
{Host: "wrong"},
mockUpstream.Start(),
}
_, err := NewStrictResolver(config.UpstreamsConfig{
Groups: upstream,
}, systemResolverBootstrap, verifyUpstreams)
_, err := NewStrictResolver(config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
When("no upstream resolvers can be reached", func() {
BeforeEach(func() {
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
upstreams = []config.Upstream{
{Host: "wrong"},
{Host: "127.0.0.2"},
}
})
@ -159,9 +147,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
It("Should use result from first one", func() {
request := newRequest("example.com.", A)
@ -190,9 +176,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
It("should return response from next upstream", func() {
request := newRequest("example.com", A)
@ -226,10 +210,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
return response
})
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
It("should return error", func() {
request := newRequest("example.com", A)
@ -245,9 +226,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1, testUpstream2.Start()},
}
upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()}
})
It("Should use result from second one", func() {
request := newRequest("example.com.", A)
@ -263,9 +242,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
When("None are working", func() {
BeforeEach(func() {
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{Host: "wrong"}, config.Upstream{Host: "wrong"}},
}
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "wrong"}}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
request := newRequest("example.com.", A)
@ -279,11 +257,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
mockUpstream.Start(),
},
}
upstreams = []config.Upstream{mockUpstream.Start()}
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)

View File

@ -447,13 +447,18 @@ func createUpstreamBranches(
err error
)
resolverCfg := config.UpstreamsConfig{Groups: config.UpstreamGroups{group: upstreams}}
groupConfig := config.UpstreamGroup{
Name: group,
Upstreams: upstreams,
}
switch cfg.Upstreams.Strategy {
case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyParallelBest:
upstream, err = resolver.NewParallelBestResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyRandom:
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
}
upstreamBranches[group] = upstream

View File

@ -704,8 +704,7 @@ var _ = Describe("Running DNS server", func() {
"default": {{Host: "0.0.0.0"}},
},
},
},
nil)
}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(branches).ToNot(BeNil())
@ -714,6 +713,24 @@ var _ = Describe("Running DNS server", func() {
})
})
Describe("NewServer with random upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{
Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyRandom,
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(branches).ToNot(BeNil())
Expect(branches).To(HaveLen(1))
_ = branches["default"].(*resolver.ParallelBestResolver)
})
})
Describe("create query resolver", func() {
When("some upstream returns error", func() {
It("create query resolver should return error", func() {