From 7a3c054b4367d81653b70d3f33e7dfb967fb04ed Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sun, 3 Dec 2023 17:46:59 -0500 Subject: [PATCH] feat: add `upstreams.init.strategy` Replaces `startVerifyUpstream` and behaves just like `blocking.loading.strategy`. We use the bootstrap resolver for any requests that arrive before the upstreams are initialized. --- config/blocking.go | 4 +- config/config.go | 101 ++++++++++------ config/config_test.go | 95 ++++++++++++--- config/upstreams.go | 14 ++- docs/config.yml | 13 +- docs/configuration.md | 37 +++--- e2e/upstream_test.go | 28 +++-- resolver/blocking_resolver_test.go | 10 +- resolver/bootstrap.go | 7 +- resolver/client_names_resolver_test.go | 28 ++++- .../conditional_upstream_resolver_test.go | 5 +- resolver/mock_udp_upstream_server.go | 4 + resolver/parallel_best_resolver.go | 41 ++++--- resolver/parallel_best_resolver_test.go | 112 +++++++++++++----- resolver/resolver.go | 67 +++++++---- resolver/strict_resolver.go | 35 +++--- resolver/strict_resolver_test.go | 19 ++- resolver/upstream_resolver.go | 12 +- resolver/upstream_tree_resolver_test.go | 4 +- 19 files changed, 429 insertions(+), 207 deletions(-) diff --git a/config/blocking.go b/config/blocking.go index 45768ab4..bd548598 100644 --- a/config/blocking.go +++ b/config/blocking.go @@ -34,13 +34,13 @@ func (c *Blocking) migrate(logger *logrus.Entry) bool { "downloadAttempts": Move(To("loading.downloads.attempts", &c.Loading.Downloads)), "downloadCooldown": Move(To("loading.downloads.cooldown", &c.Loading.Downloads)), "refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)), - "failStartOnListError": Apply(To("loading.strategy", &c.Loading), func(oldValue bool) { + "failStartOnListError": Apply(To("loading.strategy", &c.Loading.Init), func(oldValue bool) { if oldValue { c.Loading.Strategy = StartStrategyTypeFailOnError } }), "processingConcurrency": Move(To("loading.concurrency", &c.Loading)), - "startStrategy": Move(To("loading.strategy", &c.Loading)), + "startStrategy": Move(To("loading.strategy", &c.Loading.Init)), "maxErrorsPerFile": Move(To("loading.maxErrorsPerSource", &c.Loading)), }) } diff --git a/config/config.go b/config/config.go index 93e3eea8..33ca165f 100644 --- a/config/config.go +++ b/config/config.go @@ -116,10 +116,14 @@ type QueryLogType int16 // ) type StartStrategyType uint16 -func (s *StartStrategyType) do(setup func() error, logErr func(error)) error { - if *s == StartStrategyTypeFast { +func (s StartStrategyType) Do(ctx context.Context, init func(context.Context) error, logErr func(error)) error { + init = recoverToError(init, func(panicVal any) error { + return fmt.Errorf("panic during initialization: %v", panicVal) + }) + + if s == StartStrategyTypeFast { go func() { - err := setup() + err := init(ctx) if err != nil { logErr(err) } @@ -128,11 +132,11 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error { return nil } - err := setup() + err := init(ctx) if err != nil { logErr(err) - if *s == StartStrategyTypeFailOnError { + if s == StartStrategyTypeFailOnError { return err } } @@ -309,18 +313,27 @@ func (c *toEnable) LogConfig(logger *logrus.Entry) { logger.Info("enabled") } +type Init struct { + Strategy StartStrategyType `yaml:"strategy" default:"blocking"` +} + +func (c *Init) LogConfig(logger *logrus.Entry) { + logger.Debugf("strategy = %s", c.Strategy) +} + type SourceLoadingConfig struct { - Concurrency uint `yaml:"concurrency" default:"4"` - MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"` - RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"` - Strategy StartStrategyType `yaml:"strategy" default:"blocking"` - Downloads DownloaderConfig `yaml:"downloads"` + Init `yaml:",inline"` + + Concurrency uint `yaml:"concurrency" default:"4"` + MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"` + RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"` + Downloads DownloaderConfig `yaml:"downloads"` } func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) { + c.Init.LogConfig(logger) logger.Infof("concurrency = %d", c.Concurrency) logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource) - logger.Debugf("strategy = %s", c.Strategy) if c.RefreshPeriod.IsAboveZero() { logger.Infof("refresh = every %s", c.RefreshPeriod) @@ -332,36 +345,28 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) { log.WithIndent(logger, " ", c.Downloads.LogConfig) } -func (c *SourceLoadingConfig) StartPeriodicRefresh(ctx context.Context, - refresh func(context.Context) error, - logErr func(error), +func (c *SourceLoadingConfig) StartPeriodicRefresh( + ctx context.Context, refresh func(context.Context) error, logErr func(error), ) error { - refreshAndRecover := func(ctx context.Context) (rerr error) { - defer func() { - if val := recover(); val != nil { - rerr = fmt.Errorf("refresh function panicked: %v", val) - } - }() - - return refresh(ctx) - } - - err := c.Strategy.do(func() error { return refreshAndRecover(context.Background()) }, logErr) + err := c.Strategy.Do(ctx, refresh, logErr) if err != nil { return err } if c.RefreshPeriod > 0 { - go c.periodically(ctx, refreshAndRecover, logErr) + go c.periodically(ctx, refresh, logErr) } return nil } -func (c *SourceLoadingConfig) periodically(ctx context.Context, - refresh func(context.Context) error, - logErr func(error), +func (c *SourceLoadingConfig) periodically( + ctx context.Context, refresh func(context.Context) error, logErr func(error), ) { + refresh = recoverToError(refresh, func(panicVal any) error { + return fmt.Errorf("panic during refresh: %v", panicVal) + }) + ticker := time.NewTicker(c.RefreshPeriod.ToDuration()) defer ticker.Stop() @@ -379,6 +384,18 @@ func (c *SourceLoadingConfig) periodically(ctx context.Context, } } +func recoverToError(do func(context.Context) error, onPanic func(any) error) func(context.Context) error { + return func(ctx context.Context) (rerr error) { + defer func() { + if val := recover(); val != nil { + rerr = onPanic(val) + } + }() + + return do(ctx) + } +} + type DownloaderConfig struct { Timeout Duration `yaml:"timeout" default:"5s"` Attempts uint `yaml:"attempts" default:"3"` @@ -535,16 +552,22 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool { cfg.Filtering.QueryTypes.Insert(dns.Type(dns.TypeAAAA)) } }), - "port": Move(To("ports.dns", &cfg.Ports)), - "httpPort": Move(To("ports.http", &cfg.Ports)), - "httpsPort": Move(To("ports.https", &cfg.Ports)), - "tlsPort": Move(To("ports.tls", &cfg.Ports)), - "logLevel": Move(To("log.level", &cfg.Log)), - "logFormat": Move(To("log.format", &cfg.Log)), - "logPrivacy": Move(To("log.privacy", &cfg.Log)), - "logTimestamp": Move(To("log.timestamp", &cfg.Log)), - "startVerifyUpstream": Move(To("upstreams.startVerify", &cfg.Upstreams)), - "dohUserAgent": Move(To("upstreams.userAgent", &cfg.Upstreams)), + "port": Move(To("ports.dns", &cfg.Ports)), + "httpPort": Move(To("ports.http", &cfg.Ports)), + "httpsPort": Move(To("ports.https", &cfg.Ports)), + "tlsPort": Move(To("ports.tls", &cfg.Ports)), + "logLevel": Move(To("log.level", &cfg.Log)), + "logFormat": Move(To("log.format", &cfg.Log)), + "logPrivacy": Move(To("log.privacy", &cfg.Log)), + "logTimestamp": Move(To("log.timestamp", &cfg.Log)), + "dohUserAgent": Move(To("upstreams.userAgent", &cfg.Upstreams)), + "startVerifyUpstream": Apply(To("upstreams.init.strategy", &cfg.Upstreams.Init), func(value bool) { + if value { + cfg.Upstreams.Init.Strategy = StartStrategyTypeFailOnError + } else { + cfg.Upstreams.Init.Strategy = StartStrategyTypeFast + } + }), }) usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts diff --git a/config/config_test.go b/config/config_test.go index 182f20a4..4bf74de5 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -142,6 +142,15 @@ var _ = Describe("Config", func() { Expect(c.Ports.TLS).Should(Equal(ports)) }) }) + + When("parameter 'startVerifyUpstream' is set", func() { + It("should convert to upstreams.init.strategy", func() { + c.Deprecated.StartVerifyUpstream = ptrOf(true) + c.migrate(logger) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("startVerifyUpstream"))) + Expect(c.Upstreams.Init.Strategy).Should(Equal(StartStrategyTypeFailOnError)) + }) + }) }) Describe("Creation of Config", func() { @@ -552,8 +561,10 @@ bootstrapDns: cfg.LogConfig(logger) Expect(hook.Calls).ShouldNot(BeEmpty()) - Expect(hook.Messages[0]).Should(Equal("concurrency = 12")) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour"))) + Expect(hook.Messages).Should(ContainElements( + ContainSubstring("concurrency = 12"), + ContainSubstring("refresh = every 1 hour"), + )) }) When("refresh is disabled", func() { BeforeEach(func() { @@ -590,9 +601,11 @@ bootstrapDns: Expect(recover()).Should(BeIdenticalTo(panicVal)) }() - _ = sut.do(func() error { + _ = sut.Do(context.Background(), func(context.Context) error { + return errors.New("trigger `logErr`") + }, func(err error) { panic(panicVal) - }, nil) + }) Fail("unreachable") }) @@ -601,7 +614,7 @@ bootstrapDns: sut := StartStrategyTypeBlocking expectedErr := errors.New("test") - err := sut.do(func() error { + err := sut.Do(context.Background(), func(context.Context) error { return expectedErr }, func(err error) { Expect(err).Should(MatchError(expectedErr)) @@ -609,11 +622,26 @@ bootstrapDns: Expect(err).Should(Succeed()) }) + + It("logs panics and doesn't convert them to errors", func() { + sut := StartStrategyTypeBlocking + + logged := false + err := sut.Do(context.Background(), func(context.Context) error { + panic(struct{}{}) + }, func(err error) { + logged = true + Expect(err).Should(MatchError(ContainSubstring("panic"))) + }) + + Expect(err).Should(Succeed()) + Expect(logged).Should(BeTrue()) + }) }) Describe("StartStrategyTypeFailOnError", func() { It("runs in the current goroutine", func() { - sut := StartStrategyTypeBlocking + sut := StartStrategyTypeFailOnError panicVal := new(int) defer func() { @@ -621,9 +649,11 @@ bootstrapDns: Expect(recover()).Should(BeIdenticalTo(panicVal)) }() - _ = sut.do(func() error { + _ = sut.Do(context.Background(), func(context.Context) error { + return errors.New("trigger `logErr`") + }, func(err error) { panic(panicVal) - }, nil) + }) Fail("unreachable") }) @@ -632,7 +662,7 @@ bootstrapDns: sut := StartStrategyTypeFailOnError expectedErr := errors.New("test") - err := sut.do(func() error { + err := sut.Do(context.Background(), func(context.Context) error { return expectedErr }, func(err error) { Expect(err).Should(MatchError(expectedErr)) @@ -640,6 +670,21 @@ bootstrapDns: Expect(err).Should(MatchError(expectedErr)) }) + + It("returns logs panics and converts them to errors", func() { + sut := StartStrategyTypeFailOnError + + logged := false + err := sut.Do(context.Background(), func(context.Context) error { + panic(struct{}{}) + }, func(err error) { + logged = true + Expect(err).Should(MatchError(ContainSubstring("panic"))) + }) + + Expect(err).Should(HaveOccurred()) + Expect(logged).Should(BeTrue()) + }) }) Describe("StartStrategyTypeFast", func() { @@ -648,7 +693,7 @@ bootstrapDns: events := make(chan string) wait := make(chan struct{}) - err := sut.do(func() error { + err := sut.Do(context.Background(), func(context.Context) error { events <- "start" <-wait events <- "done" @@ -668,7 +713,23 @@ bootstrapDns: expectedErr := errors.New("test") wait := make(chan struct{}) - err := sut.do(func() error { + err := sut.Do(context.Background(), func(context.Context) error { + return expectedErr + }, func(err error) { + Expect(err).Should(MatchError(expectedErr)) + close(wait) + }) + + Expect(err).Should(Succeed()) + Eventually(wait, "50ms").Should(BeClosed()) + }) + + It("logs panics", func() { + sut := StartStrategyTypeFast + expectedErr := errors.New("test") + wait := make(chan struct{}) + + err := sut.Do(context.Background(), func(context.Context) error { return expectedErr }, func(err error) { Expect(err).Should(MatchError(expectedErr)) @@ -717,7 +778,7 @@ bootstrapDns: }) It("handles panics", func() { sut := SourceLoadingConfig{ - Strategy: StartStrategyTypeFailOnError, + Init: Init{Strategy: StartStrategyTypeFailOnError}, } panicMsg := "panic value" @@ -733,7 +794,7 @@ bootstrapDns: It("periodically calls refresh", func() { sut := SourceLoadingConfig{ - Strategy: StartStrategyTypeFast, + Init: Init{Strategy: StartStrategyTypeFast}, RefreshPeriod: Duration(5 * time.Millisecond), } @@ -789,7 +850,7 @@ bootstrapDns: func defaultTestFileConfig(config *Config) { Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"})) - Expect(config.Upstreams.StartVerify).Should(BeFalse()) + Expect(config.Upstreams.Init.Strategy).Should(Equal(StartStrategyTypeFailOnError)) Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky")) Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3)) Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8")) @@ -824,8 +885,9 @@ func defaultTestFileConfig(config *Config) { func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { return tmpDir.CreateStringFile("config.yml", "upstreams:", - " startVerify: false", " userAgent: testBlocky", + " init:", + " strategy: failOnError", " groups:", " default:", " - tcp+udp:8.8.8.8", @@ -885,8 +947,9 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { func writeConfigDir(tmpDir *helpertest.TmpFolder) { tmpDir.CreateStringFile("config1.yaml", "upstreams:", - " startVerify: false", " userAgent: testBlocky", + " init:", + " strategy: failOnError", " groups:", " default:", " - tcp+udp:8.8.8.8", diff --git a/config/upstreams.go b/config/upstreams.go index 45294b88..698301ce 100644 --- a/config/upstreams.go +++ b/config/upstreams.go @@ -1,6 +1,7 @@ package config import ( + "github.com/0xERR0R/blocky/log" "github.com/sirupsen/logrus" ) @@ -8,11 +9,11 @@ const UpstreamDefaultCfgName = "default" // Upstreams upstream servers configuration type Upstreams struct { - Timeout Duration `yaml:"timeout" default:"2s"` - Groups UpstreamGroups `yaml:"groups"` - Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` - StartVerify bool `yaml:"startVerify" default:"false"` - UserAgent string `yaml:"userAgent"` + Init Init `yaml:"init"` + Timeout Duration `yaml:"timeout" default:"2s"` + Groups UpstreamGroups `yaml:"groups"` + Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` + UserAgent string `yaml:"userAgent"` } type UpstreamGroups map[string][]Upstream @@ -24,6 +25,9 @@ func (c *Upstreams) IsEnabled() bool { // LogConfig implements `config.Configurable`. func (c *Upstreams) LogConfig(logger *logrus.Entry) { + logger.Info("init:") + log.WithIndent(logger, " ", c.Init.LogConfig) + logger.Info("timeout: ", c.Timeout) logger.Info("strategy: ", c.Strategy) logger.Info("groups:") diff --git a/docs/config.yml b/docs/config.yml index 1e21e024..4b675bc8 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -1,6 +1,11 @@ # REVIEW: manual changelog entry upstreams: + init: + # Configure startup behavior. + # accepted: blocking, failOnError, fast + # default: blocking + strategy: fast groups: # these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query # format for resolver: [net:]host:[port][/path]. net could be empty (default, shortcut for tcp+udp), tcp+udp, tcp, udp, tcp-tls or https (DoH). If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls, 443 for https (Doh)) @@ -24,8 +29,6 @@ upstreams: strategy: parallel_best # optional: timeout to query the upstream resolver. Default: 2s timeout: 2s - # optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false - startVerify: false # optional: HTTP User Agent when connecting to upstreams. Default: none userAgent: "custom UA" @@ -126,7 +129,8 @@ blocking: # optional: Maximum number of lists to process in parallel. # default: 4 concurrency: 16 - # optional: if failOnError, application startup will fail if at least one list can't be downloaded/opened + # Configure startup behavior. + # accepted: blocking, failOnError, fast # default: blocking strategy: failOnError # Number of errors allowed in a list before it is considered invalid. @@ -292,7 +296,8 @@ hostsFile: # optional: Maximum number of files to process in parallel. # default: 4 concurrency: 16 - # optional: if failOnError, application startup will fail if at least one file can't be downloaded/opened + # Configure startup behavior. + # accepted: blocking, failOnError, fast # default: blocking strategy: failOnError # Number of errors allowed in a file before it is considered invalid. diff --git a/docs/configuration.md b/docs/configuration.md index 2c1d804e..bbdec8d4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -66,15 +66,28 @@ All logging options are optional. privacy: true ``` +## Init Strategy + +A couple of features use an "init/loading strategy" which configures behavior at Blocky startup. +This applies to all of them. The default strategy is blocking. + +| strategy | Description | +| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| blocking | Initialization happens before DNS resolution starts. Any errors are logged, but Blocky continues running if possible. | +| failOnError | Like blocking but Blocky will exit with an error if initialization fails. | +| fast | Blocky starts serving DNS immediately and initialization happens in the background. The feature requiring initialization will enable later on (if it succeeds). | + ## Upstreams configuration -| Parameter | Type | Mandatory | Default value | Description | -| --------------------- | ------------------------------------ | --------- | ------------- | ----------------------------------------------------------------------------------------------- | -| usptreams.groups | map of name to upstream | yes | | Upstream DNS servers to use, in groups. | -| usptreams.startVerify | bool | no | false | If true, blocky will fail to start unless at least one upstream server per group is functional. | -| usptreams.strategy | enum (parallel_best, random, strict) | no | parallel_best | Upstream server usage strategy. | -| usptreams.timeout | duration | no | 2s | Upstream connection timeout. | -| usptreams.userAgent | string | no | | HTTP User Agent when connecting to upstreams. | +| Parameter | Type | Mandatory | Default value | Description | +| ----------------------- | ------------------------------------ | --------- | ------------- | ---------------------------------------------- | +| usptreams.groups | map of name to upstream | yes | | Upstream DNS servers to use, in groups. | +| usptreams.init.strategy | enum (blocking, failOnError, fast) | no | blocking | See [Init Strategy](#init-strategy) and below. | +| usptreams.strategy | enum (parallel_best, random, strict) | no | parallel_best | Upstream server usage strategy. | +| usptreams.timeout | duration | no | 2s | Upstream connection timeout. | +| usptreams.userAgent | string | no | | HTTP User Agent when connecting to upstreams. | + +For `init.strategy`, the "init" is testing the given resolvers for each group. The potentially fatal error, depending on the strategy, is if a group has no functional resolvers. ### Upstream Groups @@ -847,14 +860,8 @@ Configures how HTTP(S) sources are downloaded: ### Strategy -This configures how Blocky startup works. -The default strategy is blocking. - -| strategy | Description | -| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------- | -| blocking | all sources are loaded before DNS resolution starts | -| failOnError | like blocking but blocky will shut down if any source fails to load | -| fast | blocky starts serving DNS immediately and sources are loaded asynchronously. The features requiring the sources should enable soon after | +See [Init Strategy](#init-strategy). +In this context, "init" is loading and parsing each source, and an error is a single source failing to load/parse. !!! example diff --git a/e2e/upstream_test.go b/e2e/upstream_test.go index 6f4e5887..b02dd1ff 100644 --- a/e2e/upstream_test.go +++ b/e2e/upstream_test.go @@ -15,8 +15,8 @@ var _ = Describe("Upstream resolver configuration tests", func() { var blocky testcontainers.Container var err error - Describe("'upstreams.startVerify' parameter handling", func() { - When("'upstreams.startVerify' is false and upstream server as IP is not reachable", func() { + Describe("'upstreams.init.strategy' parameter handling", func() { + When("'upstreams.init.strategy' is fast and upstream server as IP is not reachable", func() { BeforeEach(func(ctx context.Context) { blocky, err = createBlockyContainer(ctx, tmpDir, "log:", @@ -25,16 +25,19 @@ var _ = Describe("Upstream resolver configuration tests", func() { " groups:", " default:", " - 192.192.192.192", - " startVerify: false", + " init:", + " strategy: fast", ) Expect(err).Should(Succeed()) }) It("should start even if upstream server is not reachable", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeTrue()) - Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty()) + Eventually(ctx, func() ([]string, error) { + return getContainerLogs(ctx, blocky) + }).Should(ContainElement(ContainSubstring("initial resolver test failed"))) }) }) - When("'upstreams.startVerify' is false and upstream server as host name is not reachable", func() { + When("'upstreams.init.strategy' is fast and upstream server as host name is not reachable", func() { BeforeEach(func(ctx context.Context) { blocky, err = createBlockyContainer(ctx, tmpDir, "log:", @@ -43,23 +46,25 @@ var _ = Describe("Upstream resolver configuration tests", func() { " groups:", " default:", " - some.wrong.host", - " startVerify: false", + " init:", + " strategy: fast", ) Expect(err).Should(Succeed()) }) It("should start even if upstream server is not reachable", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeTrue()) - Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky)).Should(ContainElement(ContainSubstring("initial resolver test failed"))) }) }) - When("'upstreams.startVerify' is true and upstream as IP address server is not reachable", func() { + When("'upstreams.init.strategy' is failOnError and upstream as IP address server is not reachable", func() { BeforeEach(func(ctx context.Context) { blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", " - 192.192.192.192", - " startVerify: true", + " init:", + " strategy: failOnError", ) Expect(err).Should(HaveOccurred()) }) @@ -69,14 +74,15 @@ var _ = Describe("Upstream resolver configuration tests", func() { Should(ContainElement(ContainSubstring("no valid upstream for group default"))) }) }) - When("'upstreams.startVerify' is true and upstream server as host name is not reachable", func() { + When("'upstreams.init.strategy' is failOnError and upstream server as host name is not reachable", func() { BeforeEach(func(ctx context.Context) { blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", " - some.wrong.host", - " startVerify: true", + " init:", + " strategy: failOnError", ) Expect(err).Should(HaveOccurred()) }) diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index 9c87e80d..57d8b379 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -170,7 +170,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, - Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFast}, + Loading: config.SourceLoadingConfig{ + Init: config.Init{Strategy: config.StartStrategyTypeFast}, + }, } }) @@ -1124,8 +1126,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { _, err := NewBlockingResolver(ctx, config.Blocking{ BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")}, WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")}, - Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError}, - BlockType: "zeroIp", + Loading: config.SourceLoadingConfig{ + Init: config.Init{Strategy: config.StartStrategyTypeFailOnError}, + }, + BlockType: "zeroIp", }, nil, systemResolverBootstrap) Expect(err).Should(HaveOccurred()) }) diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index 21f0d83c..ef5537ab 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "errors" "fmt" "math/rand" "net" @@ -24,6 +25,10 @@ const ( defaultTimeout = 5 * time.Second ) +var errArbitrarySystemResolverRequest = errors.New( + "cannot resolve arbitrary requests using the system resolver", +) + // Bootstrap allows resolving hostnames using the configured bootstrap DNS. type Bootstrap struct { configurable[*config.BootstrapDNSConfig] @@ -108,7 +113,7 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model if b.resolver == nil { // We could implement most queries using the `b.systemResolver.Lookup*` functions, // but that requires a lot of boilerplate to translate from `dns` to `net` and back. - return nil, errors.New("cannot resolve arbitrary requests using the system resolver") + return nil, errArbitrarySystemResolverRequest } // Add bootstrap prefix to all inner resolver logs diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index 0c819758..ee977b62 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -167,6 +167,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { } }) + JustBeforeEach(func() { + // Don't count the resolver test + testUpstream.ResetCallCount() + }) + It("should resolve client name", func() { By("first request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") @@ -209,6 +214,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { // no cache -> call count 2 Expect(request.ClientNames).Should(ConsistOf("host1")) + Expect(testUpstream.GetCallCount()).Should(Equal(2)) }) }) }) @@ -223,6 +229,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { } }) + JustBeforeEach(func() { + // Don't count the resolver test + testUpstream.ResetCallCount() + }) + It("should resolve all client names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") Expect(sut.Resolve(ctx, request)). @@ -251,6 +262,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sutConfig.Upstream = testUpstream.Start() }) + JustBeforeEach(func() { + // Don't count the resolver test + testUpstream.ResetCallCount() + }) + It("should resolve client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") Expect(sut.Resolve(ctx, request)). @@ -272,6 +288,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sutConfig.Upstream = testUpstream.Start() }) + JustBeforeEach(func() { + // Don't count the resolver test + testUpstream.ResetCallCount() + }) + It("should resolve the client name depending to defined order", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") Expect(sut.Resolve(ctx, request)). @@ -298,6 +319,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { } }) + JustBeforeEach(func() { + // Don't count the resolver test + testUpstream.ResetCallCount() + }) + It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") Expect(sut.Resolve(ctx, request)). @@ -371,7 +397,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) upstreamsCfg := defaultUpstreamsConfig - upstreamsCfg.StartVerify = true + upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError r, err := NewClientNamesResolver(ctx, config.ClientLookup{ Upstream: config.Upstream{Host: "example.com"}, diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 451932c0..d59c94c0 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -196,9 +196,8 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - upstreamsCfg := config.Upstreams{ - StartVerify: true, - } + upstreamsCfg := defaultUpstreamsConfig + upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError sutConfig := config.ConditionalUpstream{ Mapping: config.ConditionalUpstreamMapping{ diff --git a/resolver/mock_udp_upstream_server.go b/resolver/mock_udp_upstream_server.go index ce8f1a63..9c0ea432 100644 --- a/resolver/mock_udp_upstream_server.go +++ b/resolver/mock_udp_upstream_server.go @@ -87,6 +87,10 @@ func (t *MockUDPUpstreamServer) GetCallCount() int { return int(atomic.LoadInt32(&t.callCount)) } +func (t *MockUDPUpstreamServer) ResetCallCount() { + atomic.StoreInt32(&t.callCount, 0) +} + func (t *MockUDPUpstreamServer) Close() { if t.ln != nil { _ = t.ln.Close() diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index de87f88b..a2f6670a 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -30,7 +30,7 @@ type ParallelBestResolver struct { configurable[*config.UpstreamGroup] typed - resolvers []*upstreamResolverStatus + resolvers atomic.Pointer[[]*upstreamResolverStatus] resolverCount int retryWithDifferentResolver bool @@ -95,14 +95,12 @@ type requestResponse struct { func NewParallelBestResolver( ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, ) (*ParallelBestResolver, error) { - logger := log.PrefixedLog(parallelResolverType) + r := newParallelBestResolver( + cfg, + []Resolver{bootstrap}, // if start strategy is fast, use bootstrap until init finishes + ) - resolvers, err := createResolvers(ctx, logger, cfg, bootstrap) - if err != nil { - return nil, err - } - - return newParallelBestResolver(cfg, resolvers), nil + return initGroupResolvers(ctx, r, cfg, bootstrap) } func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *ParallelBestResolver { @@ -122,31 +120,40 @@ func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *Pa resolverCount: resolverCount, retryWithDifferentResolver: retryWithDifferentResolver, - resolvers: newUpstreamResolverStatuses(resolvers), } + r.setResolvers(newUpstreamResolverStatuses(resolvers)) + return &r } +func (r *ParallelBestResolver) setResolvers(resolvers []*upstreamResolverStatus) { + r.resolvers.Store(&resolvers) +} + func (r *ParallelBestResolver) Name() string { return r.String() } func (r *ParallelBestResolver) String() string { - resolvers := make([]string, len(r.resolvers)) - for i, s := range r.resolvers { - resolvers[i] = fmt.Sprintf("%s", s.resolver) + resolvers := *r.resolvers.Load() + + upstreams := make([]string, len(resolvers)) + for i, s := range resolvers { + upstreams[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(resolvers, ",")) + return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(upstreams, ",")) } // Resolve sends the query request to multiple upstream resolvers and returns the fastest result func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { logger := log.WithPrefix(request.Log, parallelResolverType) - if len(r.resolvers) == 1 { - resolver := r.resolvers[0] + allResolvers := *r.resolvers.Load() + + if len(allResolvers) == 1 { + resolver := allResolvers[0] logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver") return resolver.resolve(ctx, request) @@ -155,7 +162,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque ctx, cancel := context.WithCancel(ctx) defer cancel() // abort requests to resolvers that lost the race - resolvers := pickRandom(r.resolvers, r.resolverCount) + resolvers := pickRandom(allResolvers, r.resolverCount) ch := make(chan requestResponse, len(resolvers)) for _, resolver := range resolvers { @@ -204,7 +211,7 @@ func (r *ParallelBestResolver) retryWithDifferent( ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus, ) (*model.Response, error) { // second try (if retryWithDifferentResolver == true) - resolver := weightedRandom(r.resolvers, resolvers) + resolver := weightedRandom(*r.resolvers.Load(), resolvers) logger.Debugf("using %s as second resolver", resolver.resolver) resp, err := resolver.resolve(ctx, request) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 4f8bd555..51f2a761 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -14,18 +14,14 @@ import ( ) var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { - const ( - verifyUpstreams = true - noVerifyUpstreams = false - ) - var ( - sut *ParallelBestResolver - sutStrategy config.UpstreamStrategy - upstreams []config.Upstream - sutVerify bool - ctx context.Context - cancelFn context.CancelFunc + sut *ParallelBestResolver + sutStrategy config.UpstreamStrategy + sutStartStrategy config.StartStrategyType + upstreams []config.Upstream + + ctx context.Context + cancelFn context.CancelFunc err error @@ -44,17 +40,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}} + sutStartStrategy = config.StartStrategyTypeBlocking sutStrategy = config.UpstreamStrategyParallelBest - sutVerify = noVerifyUpstreams bootstrap = systemResolverBootstrap }) JustBeforeEach(func() { upstreamsCfg := config.Upstreams{ - StartVerify: sutVerify, - Strategy: sutStrategy, - Timeout: config.Duration(timeout), + Init: config.Init{ + Strategy: sutStartStrategy, + }, + Strategy: sutStrategy, + Timeout: config.Duration(timeout), } sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams) @@ -90,9 +88,31 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{} }) - It("should fail on startup", func() { - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError(ContainSubstring("no external DNS resolvers configured"))) + When("using InitStrategyFailOnError", func() { + BeforeEach(func() { + sutStartStrategy = config.StartStrategyTypeFailOnError + }) + It("should fail to start", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + + When("using InitStrategyBlocking", func() { + BeforeEach(func() { + sutStartStrategy = config.StartStrategyTypeBlocking + }) + It("should start", func() { + Expect(err).Should(Succeed()) + }) + }) + + When("using InitStrategyFast", func() { + BeforeEach(func() { + sutInitStrategy = config.InitStrategyFast + }) + It("should start", func() { + Expect(err).Should(Succeed()) + }) }) }) @@ -103,18 +123,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}} }) - When("strict checking is enabled", func() { + When("using InitStrategyFailOnError", func() { BeforeEach(func() { - sutVerify = verifyUpstreams + sutStartStrategy = config.StartStrategyTypeFailOnError }) It("should fail to start", func() { Expect(err).Should(HaveOccurred()) }) }) - When("strict checking is disabled", func() { + When("using InitStrategyBlocking", func() { BeforeEach(func() { - sutVerify = noVerifyUpstreams + sutStartStrategy = config.StartStrategyTypeBlocking + }) + It("should start", func() { + Expect(err).Should(Succeed()) + }) + }) + + When("using InitStrategyFast", func() { + BeforeEach(func() { + sutInitStrategy = config.InitStrategyFast }) It("should start", func() { Expect(err).Should(Succeed()) @@ -131,18 +160,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{timeoutUpstream.Start()} }) - When("strict checking is enabled", func() { + When("using InitStrategyFailOnError", func() { BeforeEach(func() { - sutVerify = verifyUpstreams + sutStartStrategy = config.StartStrategyTypeFailOnError }) It("should fail to start", func() { Expect(err).Should(HaveOccurred()) }) }) - When("strict checking is disabled", func() { + When("using InitStrategyBlocking", func() { BeforeEach(func() { - sutVerify = noVerifyUpstreams + sutStartStrategy = config.StartStrategyTypeBlocking }) It("should start", func() { Expect(err).Should(Succeed()) @@ -156,6 +185,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Expect(isTimeout(err)).Should(BeTrue()) }) }) + + When("using InitStrategyFast", func() { + BeforeEach(func() { + sutInitStrategy = config.InitStrategyFast + }) + It("should start", func() { + Expect(err).Should(Succeed()) + }) + It("should not resolve", func() { + Expect(err).Should(Succeed()) + + request := newRequest("example.com.", A) + _, err := sut.Resolve(ctx, request) + Expect(err).Should(HaveOccurred()) + Expect(err).Should(SatisfyAny( + // The actual error depends on if the init has completed or not + MatchError(isTimeout, "isTimeout"), + MatchError(errArbitrarySystemResolverRequest), + )) + }) + }) }) Describe("Resolving result from fastest upstream resolver", func() { @@ -256,7 +306,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[Resolver]int) for i := 0; i < 1000; i++ { - resolvers := pickRandom(sut.resolvers, parallelBestResolverCount) + resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount) res1 := resolvers[0].resolver res2 := resolvers[1].resolver Expect(res1).ShouldNot(Equal(res2)) @@ -280,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[*UpstreamResolver]int) for i := 0; i < 100; i++ { - resolvers := pickRandom(sut.resolvers, parallelBestResolverCount) + resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount) res1 := resolvers[0].resolver.(*UpstreamResolver) res2 := resolvers[1].resolver.(*UpstreamResolver) Expect(res1).ShouldNot(Equal(res2)) @@ -307,7 +357,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) upstreamsCfg := sut.cfg.Upstreams - upstreamsCfg.StartVerify = true + upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError group := config.NewUpstreamGroup("test", upstreamsCfg, []config.Upstream{{Host: "example.com"}}) @@ -448,7 +498,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[Resolver]int) for i := 0; i < 2000; i++ { - r := weightedRandom(sut.resolvers, nil) + r := weightedRandom(*sut.resolvers.Load(), nil) resolverCount[r.resolver]++ } for _, v := range resolverCount { @@ -467,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[*UpstreamResolver]int) for i := 0; i < 200; i++ { - r := weightedRandom(sut.resolvers, nil) + r := weightedRandom(*sut.resolvers.Load(), nil) res := r.resolver.(*UpstreamResolver) resolverCount[res]++ diff --git a/resolver/resolver.go b/resolver/resolver.go index 9a130334..de8c898e 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -216,38 +216,55 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) { c.cfg.LogConfig(logger) } -func createResolvers( - ctx context.Context, logger *logrus.Entry, - cfg config.UpstreamGroup, bootstrap *Bootstrap, -) ([]Resolver, error) { - if len(cfg.GroupUpstreams()) == 0 { - return nil, fmt.Errorf("no external DNS resolvers configured for group %s", cfg.Name) - } +type initializable interface { + log() *logrus.Entry + setResolvers([]*upstreamResolverStatus) +} - resolvers := make([]Resolver, 0, len(cfg.GroupUpstreams())) - hasValidResolvers := false - - for _, u := range cfg.GroupUpstreams() { - resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(u, cfg.Upstreams), bootstrap) +func initGroupResolvers[T initializable]( + ctx context.Context, r T, cfg config.UpstreamGroup, bootstrap *Bootstrap, +) (T, error) { + init := func(ctx context.Context) error { + resolvers, err := createGroupResolvers(ctx, cfg, bootstrap) if err != nil { - logger.Warnf("upstream group %s: %v", cfg.Name, err) - - continue + return err } - if cfg.StartVerify { - err = resolver.testResolve(ctx) - if err != nil { - logger.Warn(err) - } else { - hasValidResolvers = true - } - } + r.setResolvers(resolvers) - resolvers = append(resolvers, resolver) + return nil } - if cfg.StartVerify && !hasValidResolvers { + onErr := func(err error) { + r.log().WithError(err).Error("upstream verification error, will continue to use bootstrap DNS") + } + + err := cfg.Init.Strategy.Do(ctx, init, onErr) + if err != nil { + var zero T + + return zero, err + } + + return r, nil +} + +func createGroupResolvers( + ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, +) ([]*upstreamResolverStatus, error) { + upstreams := cfg.GroupUpstreams() + resolvers := make([]*upstreamResolverStatus, 0, len(upstreams)) + + for _, upstream := range upstreams { + resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(upstream, cfg.Upstreams), bootstrap) + if err != nil { + continue // err was already logged + } + + resolvers = append(resolvers, newUpstreamResolverStatus(resolver)) + } + + if len(resolvers) == 0 { return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name) } diff --git a/resolver/strict_resolver.go b/resolver/strict_resolver.go index cde91e93..11cdc498 100644 --- a/resolver/strict_resolver.go +++ b/resolver/strict_resolver.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" @@ -24,21 +25,19 @@ type StrictResolver struct { configurable[*config.UpstreamGroup] typed - resolvers []*upstreamResolverStatus + resolvers atomic.Pointer[[]*upstreamResolverStatus] } // NewStrictResolver creates a new strict resolver instance func NewStrictResolver( ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, ) (*StrictResolver, error) { - logger := log.PrefixedLog(strictResolverType) + r := newStrictResolver( + cfg, + []Resolver{bootstrap}, // if start strategy is fast, use bootstrap until init finishes + ) - resolvers, err := createResolvers(ctx, logger, cfg, bootstrap) - if err != nil { - return nil, err - } - - return newStrictResolver(cfg, resolvers), nil + return initGroupResolvers(ctx, r, cfg, bootstrap) } func newStrictResolver( @@ -47,24 +46,30 @@ func newStrictResolver( r := StrictResolver{ configurable: withConfig(&cfg), typed: withType(strictResolverType), - - resolvers: newUpstreamResolverStatuses(resolvers), } + r.setResolvers(newUpstreamResolverStatuses(resolvers)) + return &r } +func (r *StrictResolver) setResolvers(resolvers []*upstreamResolverStatus) { + r.resolvers.Store(&resolvers) +} + func (r *StrictResolver) Name() string { return r.String() } func (r *StrictResolver) String() string { - result := make([]string, len(r.resolvers)) - for i, s := range r.resolvers { - result[i] = fmt.Sprintf("%s", s.resolver) + resolvers := *r.resolvers.Load() + + upstreams := make([]string, len(resolvers)) + for i, s := range resolvers { + upstreams[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(result, ",")) + return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(upstreams, ",")) } // Resolve sends the query request in a strict order to the upstream resolvers @@ -72,7 +77,7 @@ func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (* logger := log.WithPrefix(request.Log, strictResolverType) // start with first resolver - for _, resolver := range r.resolvers { + for _, resolver := range *r.resolvers.Load() { logger.Debugf("using %s as resolver", resolver.resolver) resp, err := resolver.resolve(ctx, request) diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index b21e36a7..2fb47fe5 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -15,15 +15,10 @@ import ( ) var _ = Describe("StrictResolver", Label("strictResolver"), func() { - const ( - verifyUpstreams = true - noVerifyUpstreams = false - ) - var ( - sut *StrictResolver - upstreams []config.Upstream - sutVerify bool + sut *StrictResolver + sutStartStrategy config.StartStrategyType + upstreams []config.Upstream err error @@ -52,14 +47,14 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { {Host: "127.0.0.2"}, } - sutVerify = noVerifyUpstreams + sutStartStrategy = config.StartStrategyTypeBlocking bootstrap = systemResolverBootstrap }) JustBeforeEach(func() { upstreamsCfg := defaultUpstreamsConfig - upstreamsCfg.StartVerify = sutVerify + upstreamsCfg.Init.Strategy = sutStartStrategy sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams) sutConfig.Timeout = config.Duration(timeout) @@ -125,7 +120,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { When("strict checking is enabled", func() { BeforeEach(func() { - sutVerify = verifyUpstreams + sutStartStrategy = config.StartStrategyTypeFailOnError }) It("should fail to start", func() { Expect(err).Should(HaveOccurred()) @@ -134,7 +129,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { When("strict checking is disabled", func() { BeforeEach(func() { - sutVerify = noVerifyUpstreams + sutStartStrategy = config.StartStrategyTypeBlocking }) It("should start", func() { Expect(err).Should(Succeed()) diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index ce7b7f78..b22f90b0 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -216,11 +216,13 @@ func NewUpstreamResolver( ) (*UpstreamResolver, error) { r := newUpstreamResolverUnchecked(cfg, bootstrap) - if cfg.StartVerify { - _, err := r.bootstrap.UpstreamIPs(ctx, r) - if err != nil { - return nil, err - } + onErr := func(err error) { + r.log().WithError(err).Warn("initial resolver test failed") + } + + err := cfg.Init.Strategy.Do(ctx, r.testResolve, onErr) + if err != nil { + return nil, err } return r, nil diff --git a/resolver/upstream_tree_resolver_test.go b/resolver/upstream_tree_resolver_test.go index 6424fc22..5338c870 100644 --- a/resolver/upstream_tree_resolver_test.go +++ b/resolver/upstream_tree_resolver_test.go @@ -132,9 +132,9 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) }) - When("start verify is enabled", func() { + When("start strategy is failOnError", func() { BeforeEach(func() { - sutConfig.StartVerify = true + sutConfig.Init.Strategy = config.StartStrategyTypeFailOnError }) It("should fail", func() {