From 94663eeaeb8365e5f1feccd60a6bf6226877e547 Mon Sep 17 00:00:00 2001 From: DerRockWolf <50499906+DerRockWolf@users.noreply.github.com> Date: Sat, 18 Nov 2023 20:42:14 +0000 Subject: [PATCH] feat: add upstream strategy `random` (#1221) Also simplify code by getting rid of `resolversPerClient` and all surrounding logic. --- api/api_client.gen.go | 2 +- api/api_server.gen.go | 2 +- api/api_types.gen.go | 2 +- config/config.go | 2 +- config/config_enum.go | 8 +- config/upstreams.go | 21 ++ config/upstreams_test.go | 119 +++++--- docs/config.yml | 2 +- docs/configuration.md | 8 +- resolver/blocking_resolver.go | 22 +- resolver/bootstrap.go | 17 +- resolver/bootstrap_test.go | 1 + resolver/caching_resolver.go | 2 +- resolver/conditional_upstream_resolver.go | 10 +- resolver/custom_dns_resolver.go | 2 +- resolver/hosts_file_resolver.go | 2 +- resolver/parallel_best_resolver.go | 231 ++++++++------- resolver/parallel_best_resolver_test.go | 342 +++++++++++++++++----- resolver/resolver.go | 33 +++ resolver/strict_resolver.go | 97 ++---- resolver/strict_resolver_test.go | 84 ++---- server/server.go | 13 +- server/server_test.go | 21 +- 23 files changed, 656 insertions(+), 387 deletions(-) diff --git a/api/api_client.gen.go b/api/api_client.gen.go index 6add407c..f90364e4 100644 --- a/api/api_client.gen.go +++ b/api/api_client.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT. +// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT. package api import ( diff --git a/api/api_server.gen.go b/api/api_server.gen.go index acaf12e6..fdcdd2e0 100644 --- a/api/api_server.gen.go +++ b/api/api_server.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT. +// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT. package api import ( diff --git a/api/api_types.gen.go b/api/api_types.gen.go index 46c3ba9d..cda20516 100644 --- a/api/api_types.gen.go +++ b/api/api_types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT. +// Code generated by github.com/deepmap/oapi-codegen version v1.16.2 DO NOT EDIT. package api // ApiBlockingStatus defines model for api.BlockingStatus. diff --git a/config/config.go b/config/config.go index 1f181c38..6d249d70 100644 --- a/config/config.go +++ b/config/config.go @@ -124,7 +124,7 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error { type QueryLogField string // UpstreamStrategy data field to be logged -// ENUM(parallel_best,strict) +// ENUM(parallel_best,strict,random) type UpstreamStrategy uint8 //nolint:gochecknoglobals diff --git a/config/config_enum.go b/config/config_enum.go index 9bc7574b..422afd70 100644 --- a/config/config_enum.go +++ b/config/config_enum.go @@ -482,15 +482,18 @@ const ( UpstreamStrategyParallelBest UpstreamStrategy = iota // UpstreamStrategyStrict is a UpstreamStrategy of type Strict. UpstreamStrategyStrict + // UpstreamStrategyRandom is a UpstreamStrategy of type Random. 
+ UpstreamStrategyRandom ) var ErrInvalidUpstreamStrategy = fmt.Errorf("not a valid UpstreamStrategy, try [%s]", strings.Join(_UpstreamStrategyNames, ", ")) -const _UpstreamStrategyName = "parallel_beststrict" +const _UpstreamStrategyName = "parallel_beststrictrandom" var _UpstreamStrategyNames = []string{ _UpstreamStrategyName[0:13], _UpstreamStrategyName[13:19], + _UpstreamStrategyName[19:25], } // UpstreamStrategyNames returns a list of possible string values of UpstreamStrategy. @@ -505,12 +508,14 @@ func UpstreamStrategyValues() []UpstreamStrategy { return []UpstreamStrategy{ UpstreamStrategyParallelBest, UpstreamStrategyStrict, + UpstreamStrategyRandom, } } var _UpstreamStrategyMap = map[UpstreamStrategy]string{ UpstreamStrategyParallelBest: _UpstreamStrategyName[0:13], UpstreamStrategyStrict: _UpstreamStrategyName[13:19], + UpstreamStrategyRandom: _UpstreamStrategyName[19:25], } // String implements the Stringer interface. @@ -531,6 +536,7 @@ func (x UpstreamStrategy) IsValid() bool { var _UpstreamStrategyValue = map[string]UpstreamStrategy{ _UpstreamStrategyName[0:13]: UpstreamStrategyParallelBest, _UpstreamStrategyName[13:19]: UpstreamStrategyStrict, + _UpstreamStrategyName[19:25]: UpstreamStrategyRandom, } // ParseUpstreamStrategy attempts to convert a string to a UpstreamStrategy. diff --git a/config/upstreams.go b/config/upstreams.go index 536e8a87..5f1787d8 100644 --- a/config/upstreams.go +++ b/config/upstreams.go @@ -34,3 +34,24 @@ func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) { } } } + +// UpstreamGroup represents the config for one group (upstream branch) +type UpstreamGroup struct { + Name string + Upstreams []Upstream +} + +// IsEnabled implements `config.Configurable`. +func (c *UpstreamGroup) IsEnabled() bool { + return len(c.Upstreams) != 0 +} + +// LogConfig implements `config.Configurable`. 
+func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) { + logger.Info("group: ", c.Name) + logger.Info("upstreams:") + + for _, upstream := range c.Upstreams { + logger.Infof(" - %s", upstream) + } +} diff --git a/config/upstreams_test.go b/config/upstreams_test.go index 07b2653e..3dc6bd4a 100644 --- a/config/upstreams_test.go +++ b/config/upstreams_test.go @@ -9,53 +9,104 @@ import ( ) var _ = Describe("ParallelBestConfig", func() { - var cfg UpstreamsConfig - suiteBeforeEach() - BeforeEach(func() { - cfg = UpstreamsConfig{ - Timeout: Duration(5 * time.Second), - Groups: UpstreamGroups{ - UpstreamDefaultCfgName: { - {Host: "host1"}, - {Host: "host2"}, + Context("UpstreamsConfig", func() { + var cfg UpstreamsConfig + + BeforeEach(func() { + cfg = UpstreamsConfig{ + Timeout: Duration(5 * time.Second), + Groups: UpstreamGroups{ + UpstreamDefaultCfgName: { + {Host: "host1"}, + {Host: "host2"}, + }, }, - }, - } - }) - - Describe("IsEnabled", func() { - It("should be false by default", func() { - cfg := UpstreamsConfig{} - Expect(defaults.Set(&cfg)).Should(Succeed()) - - Expect(cfg.IsEnabled()).Should(BeFalse()) + } }) - When("enabled", func() { - It("should be true", func() { - Expect(cfg.IsEnabled()).Should(BeTrue()) - }) - }) - - When("disabled", func() { - It("should be false", func() { + Describe("IsEnabled", func() { + It("should be false by default", func() { cfg := UpstreamsConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) Expect(cfg.IsEnabled()).Should(BeFalse()) }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := UpstreamsConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:"))) + }) }) }) - Describe("LogConfig", func() { - It("should log configuration", func() { - cfg.LogConfig(logger) + Context("UpstreamGroupConfig", func() { + var cfg UpstreamGroup - Expect(hook.Calls).ShouldNot(BeEmpty()) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:"))) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:"))) - Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:"))) + BeforeEach(func() { + cfg = UpstreamGroup{ + Name: UpstreamDefaultCfgName, + Upstreams: []Upstream{ + {Host: "host1"}, + {Host: "host2"}, + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := UpstreamGroup{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := UpstreamGroup{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:"))) + 
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:"))) + }) }) }) }) diff --git a/docs/config.yml b/docs/config.yml index e1eecc81..04544706 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -19,7 +19,7 @@ upstreams: laptop*: - 123.123.123.123 # optional: Determines what strategy blocky uses to choose the upstream servers. - # accepted: parallel_best, strict + # accepted: parallel_best, strict, random # default: parallel_best strategy: parallel_best # optional: timeout to query the upstream resolver. Default: 2s diff --git a/docs/configuration.md b/docs/configuration.md index 5c70de1e..eca5bd0c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -139,10 +139,14 @@ Blocky supports different upstream strategies (default `parallel_best`) that det Currently available strategies: -- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one. +- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one. If an upstream failed to answer within the last hour, it is less likely to be chosen for the race. - This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers + This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers. (When using 10 upstream servers, each upstream will get on average 20% of the DNS requests) +- `random`: blocky picks one random (weighted) resolver from the upstream group for each query and if successful, returns its response. + If the selected resolver fails to respond, a second one is picked to which the query is sent. + The weighting is identical to the `parallel_best` strategy. + Although the `random` strategy might be slower than the `parallel_best` strategy, it offers more privacy since each request is sent to a single upstream. - `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on. !!! 
example diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 2a6a33fb..813dcadc 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "sort" "strings" "sync" @@ -11,6 +12,7 @@ import ( "time" "github.com/0xERR0R/blocky/cache/expirationcache" + "golang.org/x/exp/maps" "github.com/hashicorp/go-multierror" @@ -206,21 +208,11 @@ func (r *BlockingResolver) RefreshLists() error { return err.ErrorOrNil() } -//nolint:prealloc func (r *BlockingResolver) retrieveAllBlockingGroups() []string { - groups := make(map[string]bool, len(r.cfg.BlackLists)) - - for group := range r.cfg.BlackLists { - groups[group] = true - } - - var result []string - for k := range groups { - result = append(result, k) - } + result := maps.Keys(r.cfg.BlackLists) result = append(result, "default") - sort.Strings(result) + slices.Sort(result) return result } @@ -615,11 +607,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP } func (r *BlockingResolver) initFQDNIPCache() { - identifiers := make([]string, 0) - - for identifier := range r.clientGroupsBlock { - identifiers = append(identifiers, identifier) - } + identifiers := maps.Keys(r.clientGroupsBlock) for _, identifier := range identifiers { if isFQDN(identifier) { diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index 1d9a1c64..e548c0c7 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -77,9 +78,9 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er // Bootstrap doesn't have a `LogConfig` method, and since that's the only place // where `ParallelBestResolver` uses its config, we can just use an empty one. 
- var pbCfg config.UpstreamsConfig + pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName} - parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups()) + parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers()) // Always enable prefetching to avoid stalling user requests // Otherwise, a request to blocky could end up waiting for 2 DNS requests: @@ -300,16 +301,8 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots return upstreamIPs, nil } -func (br bootstrapedResolvers) ResolverGroups() map[string][]Resolver { - resolvers := make([]Resolver, 0, len(br)) - - for resolver := range br { - resolvers = append(resolvers, resolver) - } - - return map[string][]Resolver{ - upstreamDefaultCfgName: resolvers, - } +func (br bootstrapedResolvers) Resolvers() []Resolver { + return maps.Keys(br) } type IPSet struct { diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index b1e1477a..61b3722d 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -32,6 +32,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { ) BeforeEach(func() { + config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest sutConfig = &config.Config{ BootstrapDNS: []config.BootstrappedUpstreamConfig{ { diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 17869fd3..e818546e 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -190,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil } - logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver") + logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver") response, err = r.next.Resolve(request) if err == nil { diff --git a/resolver/conditional_upstream_resolver.go b/resolver/conditional_upstream_resolver.go index 13084c49..5f70133c 100644 --- a/resolver/conditional_upstream_resolver.go +++ b/resolver/conditional_upstream_resolver.go @@ -27,14 +27,10 @@ func NewConditionalUpstreamResolver( ) (*ConditionalUpstreamResolver, error) { m := make(map[string]Resolver, len(cfg.Mapping.Upstreams)) - for domain, upstream := range cfg.Mapping.Upstreams { - pbCfg := config.UpstreamsConfig{ - Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: upstream, - }, - } + for domain, upstreams := range cfg.Mapping.Upstreams { + cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams} - r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams) + r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams) if err != nil { return nil, err } diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index 38930066..f580be0c 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -138,7 +138,7 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er } } - logger.WithField("resolver", Name(r.next)).Trace("go to next resolver") + logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") return r.next.Resolve(request) } diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index 0a3c2d7f..03ed5464 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -132,7 +132,7 @@ func (r *HostsFileResolver) Resolve(request 
*model.Request) (*model.Response, er return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil } - r.log().WithField("resolver", Name(r.next)).Trace("go to next resolver") + r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver") return r.next.Resolve(request) } diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index cd7b02cf..8b5b2b91 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -20,17 +20,22 @@ import ( ) const ( - upstreamDefaultCfgName = config.UpstreamDefaultCfgName - parallelResolverType = "parallel_best" - resolverCount = 2 + upstreamDefaultCfgName = config.UpstreamDefaultCfgName + parallelResolverType = "parallel_best" + randomResolverType = "random" + parallelBestResolverCount = 2 ) // ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer type ParallelBestResolver struct { - configurable[*config.UpstreamsConfig] + configurable[*config.UpstreamGroup] typed - resolversPerClient map[string][]*upstreamResolverStatus + groupName string + resolvers []*upstreamResolverStatus + + resolverCount int + retryWithDifferentResolver bool } type upstreamResolverStatus struct { @@ -50,9 +55,13 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus { func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) { resp, err := r.resolver.Resolve(req) - if err != nil && !errors.Is(err, context.Canceled) { // ignore `Canceled`: resolver lost the race, not an error - // update the last error time - r.lastErrorTime.Store(time.Now()) + if err != nil { + // Ignore `Canceled`: resolver lost the race, not an error + if !errors.Is(err, context.Canceled) { + r.lastErrorTime.Store(time.Now()) + } + + err = fmt.Errorf("%s: %w", r.resolver, err) } ch <- requestResponse{ @@ -82,67 +91,44 @@ func testResolver(r *UpstreamResolver) error { // NewParallelBestResolver creates new resolver instance func NewParallelBestResolver( - cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ) (*ParallelBestResolver, error) { logger := log.PrefixedLog(parallelResolverType) - upstreamResolvers := cfg.Groups - resolverGroups := make(map[string][]Resolver, len(upstreamResolvers)) - - for name, upstreamCfgs := range upstreamResolvers { - group := make([]Resolver, 0, len(upstreamCfgs)) - hasValidResolver := false - - for _, u := range upstreamCfgs { - resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams) - if err != nil { - logger.Warnf("upstream group %s: %v", name, err) - - continue - } - - if shouldVerifyUpstreams { - err = testResolver(resolver) - if err != nil { - logger.Warn(err) - } else { - hasValidResolver = true - } - } - - group = append(group, resolver) - } - - if shouldVerifyUpstreams && !hasValidResolver { - return nil, fmt.Errorf("no valid upstream for group %s", name) - } - - resolverGroups[name] = group + resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams) + if err != nil { + return nil, err } - return newParallelBestResolver(cfg, resolverGroups), nil + return newParallelBestResolver(cfg, resolvers), nil } func newParallelBestResolver( - cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver, + cfg config.UpstreamGroup, resolvers []Resolver, ) *ParallelBestResolver { - resolversPerClient := 
make(map[string][]*upstreamResolverStatus, len(resolverGroups)) + resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) - for groupName, resolvers := range resolverGroups { - resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) + for _, r := range resolvers { + resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) + } - for _, r := range resolvers { - resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) - } + resolverCount := parallelBestResolverCount + retryWithDifferentResolver := false - resolversPerClient[groupName] = resolverStatuses + if config.GetConfig().Upstreams.Strategy == config.UpstreamStrategyRandom { + resolverCount = 1 + retryWithDifferentResolver = true } r := ParallelBestResolver{ configurable: withConfig(&cfg), typed: withType(parallelResolverType), - resolversPerClient: resolversPerClient, + groupName: cfg.Name, + resolvers: resolverStatuses, + + resolverCount: resolverCount, + retryWithDifferentResolver: retryWithDifferentResolver, } return &r @@ -153,88 +139,137 @@ func (r *ParallelBestResolver) Name() string { } func (r *ParallelBestResolver) String() string { - result := make([]string, 0, len(r.resolversPerClient)) - - for name, res := range r.resolversPerClient { - tmp := make([]string, len(res)) - for i, s := range res { - tmp[i] = fmt.Sprintf("%s", s.resolver) - } - - result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ","))) + result := make([]string, len(r.resolvers)) + for i, s := range r.resolvers { + result[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("parallel upstreams '%s'", strings.Join(result, "; ")) + return fmt.Sprintf("%s (resolverCount: %d, retryWithDifferentResolver: %t) upstreams '%s (%s)'", + parallelResolverType, r.resolverCount, r.retryWithDifferentResolver, r.groupName, strings.Join(result, ",")) } // Resolve sends the query request to multiple upstream resolvers and returns the fastest result func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) { logger := log.WithPrefix(request.Log, parallelResolverType) - var resolvers []*upstreamResolverStatus - for _, r := range r.resolversPerClient { - resolvers = r + if len(r.resolvers) == 1 { + logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver") - break + return r.resolvers[0].resolver.Resolve(request) } - if len(resolvers) == 1 { - logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver") + ctx := context.Background() - return resolvers[0].resolver.Resolve(request) + // using context with timeout for random upstream strategy + if r.resolverCount == 1 { + var cancel context.CancelFunc + + logger = log.WithPrefix(logger, "random") + timeout := config.GetConfig().Upstreams.Timeout + + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)) + defer cancel() } - r1, r2 := pickRandom(resolvers) - logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver) + resolvers := pickRandom(r.resolvers, r.resolverCount) + ch := make(chan requestResponse, len(resolvers)) - ch := make(chan requestResponse, resolverCount) + for _, resolver := range resolvers { + logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver") - var collectedErrors []error + go resolver.resolve(request, ch) + } - logger.WithField("resolver", r1.resolver).Debug("delegating to resolver") + response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers) + if response != nil { + return response, nil 
+ } - go r1.resolve(request, ch) + if !r.retryWithDifferentResolver { + return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...)) + } - logger.WithField("resolver", r2.resolver).Debug("delegating to resolver") + return r.retryWithDifferent(logger, request, resolvers) +} - go r2.resolve(request, ch) +func evaluateResponses( + ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus, +) (*model.Response, []error) { + collectedErrors := make([]error, 0, len(resolvers)) - for len(collectedErrors) < resolverCount { - result := <-ch + for len(collectedErrors) < len(resolvers) { + select { + case <-ctx.Done(): + // this context currently only has a deadline when resolverCount == 1 + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + logger.WithField("resolver", resolvers[0].resolver). + Debug("upstream exceeded timeout, trying other upstream") + resolvers[0].lastErrorTime.Store(time.Now()) + } + case result := <-ch: + if result.err != nil { + logger.Debug("resolution failed from resolver, cause: ", result.err) + collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err)) + } else { + logger.WithFields(logrus.Fields{ + "resolver": *result.resolver, + "answer": util.AnswerToString(result.response.Res.Answer), + }).Debug("using response from resolver") - if result.err != nil { - logger.Debug("resolution failed from resolver, cause: ", result.err) - collectedErrors = append(collectedErrors, result.err) - } else { - logger.WithFields(logrus.Fields{ - "resolver": *result.resolver, - "answer": util.AnswerToString(result.response.Res.Answer), - }).Debug("using response from resolver") - - return result.response, nil + return result.response, nil + } } } - return nil, fmt.Errorf("resolution was not successful, used resolvers: '%s' and '%s' errors: %v", - r1.resolver, r2.resolver, collectedErrors) + return nil, collectedErrors } -// pick 2 different random resolvers from the resolver pool -func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upstreamResolverStatus) { - resolver1 = weightedRandom(resolvers, nil) - resolver2 = weightedRandom(resolvers, resolver1.resolver) +func (r *ParallelBestResolver) retryWithDifferent( + logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus, +) (*model.Response, error) { + // second try (if retryWithDifferentResolver == true) + resolver := weightedRandom(r.resolvers, resolvers) + logger.Debugf("using %s as second resolver", resolver.resolver) - return + ch := make(chan requestResponse, 1) + + resolver.resolve(request, ch) + + result := <-ch + if result.err != nil { + return nil, fmt.Errorf("resolution retry failed: %w", result.err) + } + + logger.WithFields(logrus.Fields{ + "resolver": *result.resolver, + "answer": util.AnswerToString(result.response.Res.Answer), + }).Debug("using response from resolver") + + return result.response, nil } -func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus { +// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool +func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus { + chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount) + + for i := 0; i < resolverCount; i++ { + chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers)) + } + + return chosenResolvers +} + +func weightedRandom(in, excludedResolvers 
[]*upstreamResolverStatus) *upstreamResolverStatus { const errorWindowInSec = 60 choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in)) +outer: for _, res := range in { - if exclude == res.resolver { - continue + for _, exclude := range excludedResolvers { + if exclude.resolver == res.resolver { + continue outer + } } var weight float64 = errorWindowInSec diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 3ebb06b3..35d8d6e4 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -22,11 +22,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { ) var ( - sut *ParallelBestResolver - sutMapping config.UpstreamGroups - sutVerify bool - ctx context.Context - cancelFn context.CancelFunc + sut *ParallelBestResolver + upstreams []config.Upstream + sutVerify bool + ctx context.Context + cancelFn context.CancelFunc err error @@ -40,19 +40,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) BeforeEach(func() { + config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest + ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - config.Upstream{ - Host: "127.0.0.2", - }, - }, - } + upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}} sutVerify = noVerifyUpstreams @@ -60,9 +53,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) JustBeforeEach(func() { - sutConfig := config.UpstreamsConfig{ - Timeout: config.Duration(1000 * time.Millisecond), - Groups: sutMapping, + sutConfig := config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: upstreams, } sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) @@ -85,8 +78,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) Describe("Name", func() { - It("should not be empty", func() { + It("should contain correct resolver", func() { Expect(sut.Name()).ShouldNot(BeEmpty()) + Expect(sut.Name()).Should(ContainSubstring(parallelResolverType)) }) }) @@ -99,18 +93,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) defer mockUpstream.Close() - upstream := config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - mockUpstream.Start(), - }, + upstreams := []config.Upstream{ + {Host: "wrong"}, + mockUpstream.Start(), } - _, err := NewParallelBestResolver(config.UpstreamsConfig{ - Groups: upstream, - }, systemResolverBootstrap, verifyUpstreams) + _, err := NewParallelBestResolver(config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: upstreams, + }, + systemResolverBootstrap, verifyUpstreams) Expect(err).Should(Not(HaveOccurred())) }) }) @@ -119,16 +111,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { BeforeEach(func() { bootstrap = newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - config.Upstream{ - Host: "127.0.0.2", - }, - }, - } + upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}} }) When("strict checking is enabled", func() { @@ -167,9 +150,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) 
DeferCleanup(slowTestUpstream.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()}, - } + upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()} }) It("Should use result from fastest one", func() { request := newRequest("example.com.", A) @@ -195,9 +176,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { return response }) DeferCleanup(slowTestUpstream.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {config.Upstream{Host: "wrong"}, slowTestUpstream.Start()}, - } + upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()} + Expect(err).Should(Succeed()) }) It("Should use result from successful resolver", func() { request := newRequest("example.com.", A) @@ -216,9 +196,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { withError1 := config.Upstream{Host: "wrong"} withError2 := config.Upstream{Host: "wrong"} - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {withError1, withError2}, - } + upstreams = []config.Upstream{withError1, withError2} + Expect(err).Should(Succeed()) }) It("Should return error", func() { Expect(err).Should(Succeed()) @@ -234,11 +213,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - mockUpstream.Start(), - }, - } + upstreams = []config.Upstream{mockUpstream.Start()} }) It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) @@ -267,34 +242,30 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, - }}, + sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, + }, systemResolverBootstrap, noVerifyUpstreams) By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) for i := 0; i < 1000; i++ { - var resolvers []*upstreamResolverStatus - for _, r := range sut.resolversPerClient { - resolvers = r - } - r1, r2 := pickRandom(resolvers) - res1 := r1.resolver - res2 := r2.resolver + resolvers := pickRandom(sut.resolvers, parallelBestResolverCount) + res1 := resolvers[0].resolver + res2 := resolvers[1].resolver Expect(res1).ShouldNot(Equal(res2)) resolverCount[res1]++ resolverCount[res2]++ } for _, v := range resolverCount { - // should be 500 ± 100 - Expect(v).Should(BeNumerically("~", 500, 100)) + // should be 500 ± 50 + Expect(v).Should(BeNumerically("~", 500, 75)) } }) - By("perform 10 request, error upstream's weight will be reduced", func() { - // perform 10 requests + By("perform 100 request, error upstream's weight will be reduced", func() { for i := 0; i < 100; i++ { request := newRequest("example.com.", A) _, _ = sut.Resolve(request) @@ -305,13 +276,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := 
make(map[*UpstreamResolver]int) for i := 0; i < 100; i++ { - var resolvers []*upstreamResolverStatus - for _, r := range sut.resolversPerClient { - resolvers = r - } - r1, r2 := pickRandom(resolvers) - res1 := r1.resolver.(*UpstreamResolver) - res2 := r2.resolver.(*UpstreamResolver) + resolvers := pickRandom(sut.resolvers, parallelBestResolverCount) + res1 := resolvers[0].resolver.(*UpstreamResolver) + res2 := resolvers[1].resolver.(*UpstreamResolver) Expect(res1).ShouldNot(Equal(res2)) resolverCount[res1]++ @@ -335,12 +302,235 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewParallelBestResolver(config.UpstreamsConfig{ - Groups: config.UpstreamGroups{"test": {{Host: "example.com"}}}, + r, err := NewParallelBestResolver(config.UpstreamGroup{ + Name: "test", + Upstreams: []config.Upstream{{Host: "example.com"}}, }, b, verifyUpstreams) Expect(err).ShouldNot(Succeed()) Expect(r).Should(BeNil()) }) }) + + Describe("random resolver strategy", func() { + const timeout = config.Duration(time.Second) + + BeforeEach(func() { + config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom + config.GetConfig().Upstreams.Timeout = timeout + }) + + Describe("Name", func() { + It("should contain correct resolver", func() { + Expect(sut.Name()).ShouldNot(BeEmpty()) + Expect(sut.Name()).Should(ContainSubstring(parallelResolverType)) + }) + }) + + Describe("Resolving request in random order", func() { + When("Multiple upstream resolvers are defined", func() { + When("Both are responding", func() { + When("Both respond in time", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") + DeferCleanup(testUpstream2.Close) + + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} + }) + It("Should return result from either one", func() { + request := newRequest("example.com.", A) + Expect(sut.Resolve(request)). 
+ Should(SatisfyAll( + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + Or( + BeDNSRecord("example.com.", A, "123.124.122.122"), + BeDNSRecord("example.com.", A, "123.124.122.123"), + ), + )) + }) + }) + When("one upstream exceeds timeout", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") + time.Sleep(time.Duration(timeout) + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") + DeferCleanup(testUpstream2.Close) + + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} + }) + It("should ask a other random upstream and return its response", func() { + request := newRequest("example.com", A) + Expect(sut.Resolve(request)).Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.2"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + When("two upstreams exceed timeout", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") + time.Sleep(timeout.ToDuration() + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2") + time.Sleep(timeout.ToDuration() + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream2.Close) + + testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3") + DeferCleanup(testUpstream3.Close) + + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()} + }) + // These two tests are flaky -_- (maybe recreate the RandomResolver ) + It("should not return error (due to random selection the request could to through)", func() { + Eventually(func() error { + request := newRequest("example.com", A) + _, err := sut.Resolve(request) + + return err + }).WithTimeout(30 * time.Second). + Should(Not(HaveOccurred())) + }) + It("should return error (because it can be possible that the two broken upstreams are chosen)", func() { + Eventually(func() error { + sutConfig := config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: upstreams, + } + sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) + + request := newRequest("example.com", A) + _, err := sut.Resolve(request) + + return err + }).WithTimeout(30 * time.Second). 
+ Should(HaveOccurred()) + }) + }) + }) + When("None are working", func() { + BeforeEach(func() { + testUpstream1 := config.Upstream{Host: "wrong"} + testUpstream2 := config.Upstream{Host: "wrong"} + + upstreams = []config.Upstream{testUpstream1, testUpstream2} + Expect(err).Should(Succeed()) + }) + It("Should return error", func() { + request := newRequest("example.com.", A) + _, err := sut.Resolve(request) + Expect(err).Should(HaveOccurred()) + }) + }) + }) + When("only 1 upstream resolvers is defined", func() { + BeforeEach(func() { + mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(mockUpstream.Close) + + upstreams = []config.Upstream{mockUpstream.Start()} + }) + It("Should use result from defined resolver", func() { + request := newRequest("example.com.", A) + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.122"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + }) + + Describe("Weighted random on resolver selection", func() { + When("4 upstream resolvers are defined", func() { + It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { + withError1 := config.Upstream{Host: "wrong1"} + withError2 := config.Upstream{Host: "wrong2"} + + mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(mockUpstream1.Close) + + mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(mockUpstream2.Close) + + sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, + }, + systemResolverBootstrap, noVerifyUpstreams) + + By("all resolvers have same weight for random -> equal distribution", func() { + resolverCount := make(map[Resolver]int) + + for i := 0; i < 2000; i++ { + r := weightedRandom(sut.resolvers, nil) + resolverCount[r.resolver]++ + } + for _, v := range resolverCount { + // should be 500 ± 100 + Expect(v).Should(BeNumerically("~", 500, 100)) + } + }) + By("perform 100 request, error upstream's weight will be reduced", func() { + for i := 0; i < 100; i++ { + request := newRequest("example.com.", A) + _, _ = sut.Resolve(request) + } + }) + + By("Resolvers without errors should be selected often", func() { + resolverCount := make(map[*UpstreamResolver]int) + + for i := 0; i < 200; i++ { + r := weightedRandom(sut.resolvers, nil) + res := r.resolver.(*UpstreamResolver) + + resolverCount[res]++ + } + for k, v := range resolverCount { + if strings.Contains(k.String(), "wrong") { + // error resolvers: should be 0 - 10 + Expect(v).Should(BeNumerically("~", 0, 10)) + } else { + // should be 90 ± 10 + Expect(v).Should(BeNumerically("~", 95, 20)) + } + } + }) + }) + }) + }) + }) }) diff --git a/resolver/resolver.go b/resolver/resolver.go index 59382834..677b47f0 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -214,3 +214,36 @@ func (c *configurable[T]) IsEnabled() bool { func (c *configurable[T]) LogConfig(logger *logrus.Entry) { c.cfg.LogConfig(logger) } + +func createResolvers( + logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool, +) ([]Resolver, error) { + resolvers := make([]Resolver, 0, len(cfg.Upstreams)) + hasValidResolvers := false + + for _, u := range cfg.Upstreams { + 
resolver, err := NewUpstreamResolver(u, bootstrap, shoudVerifyUpstreams) + if err != nil { + logger.Warnf("upstream group %s: %v", cfg.Name, err) + + continue + } + + if shoudVerifyUpstreams { + err = testResolver(resolver) + if err != nil { + logger.Warn(err) + } else { + hasValidResolvers = true + } + } + + resolvers = append(resolvers, resolver) + } + + if shoudVerifyUpstreams && !hasValidResolvers { + return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name) + } + + return resolvers, nil +} diff --git a/resolver/strict_resolver.go b/resolver/strict_resolver.go index 2fd6c460..cbf3ab16 100644 --- a/resolver/strict_resolver.go +++ b/resolver/strict_resolver.go @@ -21,75 +21,42 @@ const ( // StrictResolver delegates the DNS message strictly to the first configured upstream resolver // if it can't provide the answer in time the next resolver is used type StrictResolver struct { - configurable[*config.UpstreamsConfig] + configurable[*config.UpstreamGroup] typed - resolversPerClient map[string][]*upstreamResolverStatus + groupName string + resolvers []*upstreamResolverStatus } -// NewStrictResolver creates new resolver instance +// NewStrictResolver creates a new strict resolver instance func NewStrictResolver( - cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ) (*StrictResolver, error) { logger := log.PrefixedLog(strictResolverType) - upstreamResolvers := cfg.Groups - resolverGroups := make(map[string][]Resolver, len(upstreamResolvers)) - - for name, upstreamCfgs := range upstreamResolvers { - group := make([]Resolver, 0, len(upstreamCfgs)) - hasValidResolver := false - - for _, u := range upstreamCfgs { - resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams) - if err != nil { - logger.Warnf("upstream group %s: %v", name, err) - - continue - } - - if shouldVerifyUpstreams { - err = testResolver(resolver) - if err != nil { - logger.Warn(err) - } else { - hasValidResolver = true - } - } - - group = append(group, resolver) - } - - if shouldVerifyUpstreams && !hasValidResolver { - return nil, fmt.Errorf("no valid upstream for group %s", name) - } - - resolverGroups[name] = group + resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams) + if err != nil { + return nil, err } - return newStrictResolver(cfg, resolverGroups), nil + return newStrictResolver(cfg, resolvers), nil } func newStrictResolver( - cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver, + cfg config.UpstreamGroup, resolvers []Resolver, ) *StrictResolver { - resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups)) + resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) - for groupName, resolvers := range resolverGroups { - resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) - - for _, r := range resolvers { - resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) - } - - resolversPerClient[groupName] = resolverStatuses + for _, r := range resolvers { + resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) } r := StrictResolver{ configurable: withConfig(&cfg), typed: withType(strictResolverType), - resolversPerClient: resolversPerClient, + groupName: cfg.Name, + resolvers: resolverStatuses, } return &r @@ -100,48 +67,36 @@ func (r *StrictResolver) Name() string { } func (r *StrictResolver) String() string { - result := make([]string, 0, 
len(r.resolversPerClient)) - - for name, res := range r.resolversPerClient { - tmp := make([]string, len(res)) - for i, s := range res { - tmp[i] = fmt.Sprintf("%s", s.resolver) - } - - result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ","))) + result := make([]string, len(r.resolvers)) + for i, s := range r.resolvers { + result[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("%s upstreams %q", strictResolverType, strings.Join(result, "; ")) + return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.groupName, strings.Join(result, ",")) } -// Resolve sends the query request to multiple upstream resolvers and returns the fastest result +// Resolve sends the query request in a strict order to the upstream resolvers func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) { logger := log.WithPrefix(request.Log, strictResolverType) - var resolvers []*upstreamResolverStatus - for _, r := range r.resolversPerClient { - resolvers = r - - break - } - // start with first resolver - for i := range resolvers { + for i := range r.resolvers { timeout := config.GetConfig().Upstreams.Timeout.ToDuration() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - // start in new go routine and cancel if - resolver := resolvers[i] - ch := make(chan requestResponse, resolverCount) + resolver := r.resolvers[i] + logger.Debugf("using %s as resolver", resolver.resolver) + + ch := make(chan requestResponse, 1) go resolver.resolve(request, ch) select { case <-ctx.Done(): // log debug/info that timeout exceeded, call `continue` to try next upstream - logger.WithField("resolver", resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream") + logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream") continue case result := <-ch: diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index c35a9d15..6f5e3dd7 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -20,9 +20,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { ) var ( - sut *StrictResolver - sutMapping config.UpstreamGroups - sutVerify bool + sut *StrictResolver + upstreams []config.Upstream + sutVerify bool err error @@ -36,15 +36,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) BeforeEach(func() { - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - config.Upstream{ - Host: "127.0.0.2", - }, - }, + upstreams = []config.Upstream{ + {Host: "wrong"}, + {Host: "127.0.0.2"}, } sutVerify = noVerifyUpstreams @@ -53,12 +47,14 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) JustBeforeEach(func() { - sutConfig := config.UpstreamsConfig{Groups: sutMapping} - + sutConfig := config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: upstreams, + } sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify) }) - config.GetConfig().Upstreams.Timeout = config.Duration(1000 * time.Millisecond) + config.GetConfig().Upstreams.Timeout = config.Duration(time.Second) Describe("IsEnabled", func() { It("is true", func() { @@ -99,33 +95,25 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) defer mockUpstream.Close() - upstream := config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - mockUpstream.Start(), - }, + upstreams := []config.Upstream{ + {Host: 
"wrong"}, + mockUpstream.Start(), } - _, err := NewStrictResolver(config.UpstreamsConfig{ - Groups: upstream, - }, systemResolverBootstrap, verifyUpstreams) + _, err := NewStrictResolver(config.UpstreamGroup{ + Name: upstreamDefaultCfgName, + Upstreams: upstreams, + }, + systemResolverBootstrap, verifyUpstreams) Expect(err).Should(Not(HaveOccurred())) }) }) When("no upstream resolvers can be reached", func() { BeforeEach(func() { - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - config.Upstream{ - Host: "wrong", - }, - config.Upstream{ - Host: "127.0.0.2", - }, - }, + upstreams = []config.Upstream{ + {Host: "wrong"}, + {Host: "127.0.0.2"}, } }) @@ -159,9 +147,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") DeferCleanup(testUpstream2.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, - } + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} }) It("Should use result from first one", func() { request := newRequest("example.com.", A) @@ -190,9 +176,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") DeferCleanup(testUpstream2.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, - } + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} }) It("should return response from next upstream", func() { request := newRequest("example.com", A) @@ -226,10 +210,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { return response }) DeferCleanup(testUpstream2.Close) - - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, - } + upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} }) It("should return error", func() { request := newRequest("example.com", A) @@ -245,9 +226,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") DeferCleanup(testUpstream2.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {testUpstream1, testUpstream2.Start()}, - } + upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()} }) It("Should use result from second one", func() { request := newRequest("example.com.", A) @@ -263,9 +242,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) When("None are working", func() { BeforeEach(func() { - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {config.Upstream{Host: "wrong"}, config.Upstream{Host: "wrong"}}, - } + upstreams = []config.Upstream{{Host: "wrong"}, {Host: "wrong"}} + Expect(err).Should(Succeed()) }) It("Should return error", func() { request := newRequest("example.com.", A) @@ -279,11 +257,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream.Close) - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: { - mockUpstream.Start(), - }, - } + upstreams = []config.Upstream{mockUpstream.Start()} }) It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) diff --git a/server/server.go 
b/server/server.go index 219cf4c7..bdbb4b17 100644 --- a/server/server.go +++ b/server/server.go @@ -447,13 +447,18 @@ func createUpstreamBranches( err error ) - resolverCfg := config.UpstreamsConfig{Groups: config.UpstreamGroups{group: upstreams}} + groupConfig := config.UpstreamGroup{ + Name: group, + Upstreams: upstreams, + } switch cfg.Upstreams.Strategy { - case config.UpstreamStrategyStrict: - upstream, err = resolver.NewStrictResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream) case config.UpstreamStrategyParallelBest: - upstream, err = resolver.NewParallelBestResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream) + upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) + case config.UpstreamStrategyStrict: + upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) + case config.UpstreamStrategyRandom: + upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) } upstreamBranches[group] = upstream diff --git a/server/server_test.go b/server/server_test.go index 12c27aad..5c902a72 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -704,8 +704,7 @@ var _ = Describe("Running DNS server", func() { "default": {{Host: "0.0.0.0"}}, }, }, - }, - nil) + }, nil) Expect(err).ToNot(HaveOccurred()) Expect(branches).ToNot(BeNil()) @@ -714,6 +713,24 @@ var _ = Describe("Running DNS server", func() { }) }) + Describe("NewServer with random upstream strategy", func() { + It("successfully returns upstream branches", func() { + branches, err := createUpstreamBranches(&config.Config{ + Upstreams: config.UpstreamsConfig{ + Strategy: config.UpstreamStrategyRandom, + Groups: config.UpstreamGroups{ + "default": {{Host: "0.0.0.0"}}, + }, + }, + }, nil) + + Expect(err).ToNot(HaveOccurred()) + Expect(branches).ToNot(BeNil()) + Expect(branches).To(HaveLen(1)) + _ = branches["default"].(*resolver.ParallelBestResolver) + }) + }) + Describe("create query resolver", func() { When("some upstream returns error", func() { It("create query resolver should return error", func() {
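---

For reviewers who want to try the new strategy, here is a minimal configuration sketch assembled from the keys visible in the docs/config.yml and docs/configuration.md hunks above. The upstream addresses are placeholders and not values taken from this patch; only the `strategy`, `groups`, `default`, and `timeout` keys are confirmed by the diff.

```yaml
upstreams:
  groups:
    # "default" is the fallback group used for clients without a more specific group
    default:
      - 1.1.1.1   # placeholder upstream
      - 9.9.9.9   # placeholder upstream
  # accepted: parallel_best, strict, random
  # default: parallel_best
  strategy: random
  # optional: timeout to query the upstream resolver. Default: 2s
  timeout: 2s
```

With `strategy: random`, blocky sends each query to a single weighted-random upstream of the group and only consults a second upstream if the first one fails or exceeds the timeout, trading some latency for not fanning every request out to two providers at once.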