feat: add `upstreams.init.strategy`

Replaces `startVerifyUpstream` and behaves just like
`blocking.loading.strategy`.

We use the bootstrap resolver for any requests that arrive before the
upstreams are initialized.
This commit is contained in:
ThinkChaos 2023-12-03 17:46:59 -05:00
parent 659076dd7b
commit 7a3c054b43
19 changed files with 429 additions and 207 deletions

View File

@ -34,13 +34,13 @@ func (c *Blocking) migrate(logger *logrus.Entry) bool {
"downloadAttempts": Move(To("loading.downloads.attempts", &c.Loading.Downloads)),
"downloadCooldown": Move(To("loading.downloads.cooldown", &c.Loading.Downloads)),
"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
"failStartOnListError": Apply(To("loading.strategy", &c.Loading), func(oldValue bool) {
"failStartOnListError": Apply(To("loading.strategy", &c.Loading.Init), func(oldValue bool) {
if oldValue {
c.Loading.Strategy = StartStrategyTypeFailOnError
}
}),
"processingConcurrency": Move(To("loading.concurrency", &c.Loading)),
"startStrategy": Move(To("loading.strategy", &c.Loading)),
"startStrategy": Move(To("loading.strategy", &c.Loading.Init)),
"maxErrorsPerFile": Move(To("loading.maxErrorsPerSource", &c.Loading)),
})
}

View File

@ -116,10 +116,14 @@ type QueryLogType int16
// )
type StartStrategyType uint16
func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
if *s == StartStrategyTypeFast {
func (s StartStrategyType) Do(ctx context.Context, init func(context.Context) error, logErr func(error)) error {
init = recoverToError(init, func(panicVal any) error {
return fmt.Errorf("panic during initialization: %v", panicVal)
})
if s == StartStrategyTypeFast {
go func() {
err := setup()
err := init(ctx)
if err != nil {
logErr(err)
}
@ -128,11 +132,11 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
return nil
}
err := setup()
err := init(ctx)
if err != nil {
logErr(err)
if *s == StartStrategyTypeFailOnError {
if s == StartStrategyTypeFailOnError {
return err
}
}
@ -309,18 +313,27 @@ func (c *toEnable) LogConfig(logger *logrus.Entry) {
logger.Info("enabled")
}
type Init struct {
Strategy StartStrategyType `yaml:"strategy" default:"blocking"`
}
func (c *Init) LogConfig(logger *logrus.Entry) {
logger.Debugf("strategy = %s", c.Strategy)
}
type SourceLoadingConfig struct {
Concurrency uint `yaml:"concurrency" default:"4"`
MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
Strategy StartStrategyType `yaml:"strategy" default:"blocking"`
Downloads DownloaderConfig `yaml:"downloads"`
Init `yaml:",inline"`
Concurrency uint `yaml:"concurrency" default:"4"`
MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
Downloads DownloaderConfig `yaml:"downloads"`
}
func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
c.Init.LogConfig(logger)
logger.Infof("concurrency = %d", c.Concurrency)
logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource)
logger.Debugf("strategy = %s", c.Strategy)
if c.RefreshPeriod.IsAboveZero() {
logger.Infof("refresh = every %s", c.RefreshPeriod)
@ -332,36 +345,28 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
log.WithIndent(logger, " ", c.Downloads.LogConfig)
}
func (c *SourceLoadingConfig) StartPeriodicRefresh(ctx context.Context,
refresh func(context.Context) error,
logErr func(error),
func (c *SourceLoadingConfig) StartPeriodicRefresh(
ctx context.Context, refresh func(context.Context) error, logErr func(error),
) error {
refreshAndRecover := func(ctx context.Context) (rerr error) {
defer func() {
if val := recover(); val != nil {
rerr = fmt.Errorf("refresh function panicked: %v", val)
}
}()
return refresh(ctx)
}
err := c.Strategy.do(func() error { return refreshAndRecover(context.Background()) }, logErr)
err := c.Strategy.Do(ctx, refresh, logErr)
if err != nil {
return err
}
if c.RefreshPeriod > 0 {
go c.periodically(ctx, refreshAndRecover, logErr)
go c.periodically(ctx, refresh, logErr)
}
return nil
}
func (c *SourceLoadingConfig) periodically(ctx context.Context,
refresh func(context.Context) error,
logErr func(error),
func (c *SourceLoadingConfig) periodically(
ctx context.Context, refresh func(context.Context) error, logErr func(error),
) {
refresh = recoverToError(refresh, func(panicVal any) error {
return fmt.Errorf("panic during refresh: %v", panicVal)
})
ticker := time.NewTicker(c.RefreshPeriod.ToDuration())
defer ticker.Stop()
@ -379,6 +384,18 @@ func (c *SourceLoadingConfig) periodically(ctx context.Context,
}
}
func recoverToError(do func(context.Context) error, onPanic func(any) error) func(context.Context) error {
return func(ctx context.Context) (rerr error) {
defer func() {
if val := recover(); val != nil {
rerr = onPanic(val)
}
}()
return do(ctx)
}
}
type DownloaderConfig struct {
Timeout Duration `yaml:"timeout" default:"5s"`
Attempts uint `yaml:"attempts" default:"3"`
@ -535,16 +552,22 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
cfg.Filtering.QueryTypes.Insert(dns.Type(dns.TypeAAAA))
}
}),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"startVerifyUpstream": Move(To("upstreams.startVerify", &cfg.Upstreams)),
"dohUserAgent": Move(To("upstreams.userAgent", &cfg.Upstreams)),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"dohUserAgent": Move(To("upstreams.userAgent", &cfg.Upstreams)),
"startVerifyUpstream": Apply(To("upstreams.init.strategy", &cfg.Upstreams.Init), func(value bool) {
if value {
cfg.Upstreams.Init.Strategy = StartStrategyTypeFailOnError
} else {
cfg.Upstreams.Init.Strategy = StartStrategyTypeFast
}
}),
})
usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts

View File

@ -142,6 +142,15 @@ var _ = Describe("Config", func() {
Expect(c.Ports.TLS).Should(Equal(ports))
})
})
When("parameter 'startVerifyUpstream' is set", func() {
It("should convert to upstreams.init.strategy", func() {
c.Deprecated.StartVerifyUpstream = ptrOf(true)
c.migrate(logger)
Expect(hook.Messages).Should(ContainElement(ContainSubstring("startVerifyUpstream")))
Expect(c.Upstreams.Init.Strategy).Should(Equal(StartStrategyTypeFailOnError))
})
})
})
Describe("Creation of Config", func() {
@ -552,8 +561,10 @@ bootstrapDns:
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages[0]).Should(Equal("concurrency = 12"))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
Expect(hook.Messages).Should(ContainElements(
ContainSubstring("concurrency = 12"),
ContainSubstring("refresh = every 1 hour"),
))
})
When("refresh is disabled", func() {
BeforeEach(func() {
@ -590,9 +601,11 @@ bootstrapDns:
Expect(recover()).Should(BeIdenticalTo(panicVal))
}()
_ = sut.do(func() error {
_ = sut.Do(context.Background(), func(context.Context) error {
return errors.New("trigger `logErr`")
}, func(err error) {
panic(panicVal)
}, nil)
})
Fail("unreachable")
})
@ -601,7 +614,7 @@ bootstrapDns:
sut := StartStrategyTypeBlocking
expectedErr := errors.New("test")
err := sut.do(func() error {
err := sut.Do(context.Background(), func(context.Context) error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
@ -609,11 +622,26 @@ bootstrapDns:
Expect(err).Should(Succeed())
})
It("logs panics and doesn't convert them to errors", func() {
sut := StartStrategyTypeBlocking
logged := false
err := sut.Do(context.Background(), func(context.Context) error {
panic(struct{}{})
}, func(err error) {
logged = true
Expect(err).Should(MatchError(ContainSubstring("panic")))
})
Expect(err).Should(Succeed())
Expect(logged).Should(BeTrue())
})
})
Describe("StartStrategyTypeFailOnError", func() {
It("runs in the current goroutine", func() {
sut := StartStrategyTypeBlocking
sut := StartStrategyTypeFailOnError
panicVal := new(int)
defer func() {
@ -621,9 +649,11 @@ bootstrapDns:
Expect(recover()).Should(BeIdenticalTo(panicVal))
}()
_ = sut.do(func() error {
_ = sut.Do(context.Background(), func(context.Context) error {
return errors.New("trigger `logErr`")
}, func(err error) {
panic(panicVal)
}, nil)
})
Fail("unreachable")
})
@ -632,7 +662,7 @@ bootstrapDns:
sut := StartStrategyTypeFailOnError
expectedErr := errors.New("test")
err := sut.do(func() error {
err := sut.Do(context.Background(), func(context.Context) error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
@ -640,6 +670,21 @@ bootstrapDns:
Expect(err).Should(MatchError(expectedErr))
})
It("returns logs panics and converts them to errors", func() {
sut := StartStrategyTypeFailOnError
logged := false
err := sut.Do(context.Background(), func(context.Context) error {
panic(struct{}{})
}, func(err error) {
logged = true
Expect(err).Should(MatchError(ContainSubstring("panic")))
})
Expect(err).Should(HaveOccurred())
Expect(logged).Should(BeTrue())
})
})
Describe("StartStrategyTypeFast", func() {
@ -648,7 +693,7 @@ bootstrapDns:
events := make(chan string)
wait := make(chan struct{})
err := sut.do(func() error {
err := sut.Do(context.Background(), func(context.Context) error {
events <- "start"
<-wait
events <- "done"
@ -668,7 +713,23 @@ bootstrapDns:
expectedErr := errors.New("test")
wait := make(chan struct{})
err := sut.do(func() error {
err := sut.Do(context.Background(), func(context.Context) error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
close(wait)
})
Expect(err).Should(Succeed())
Eventually(wait, "50ms").Should(BeClosed())
})
It("logs panics", func() {
sut := StartStrategyTypeFast
expectedErr := errors.New("test")
wait := make(chan struct{})
err := sut.Do(context.Background(), func(context.Context) error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
@ -717,7 +778,7 @@ bootstrapDns:
})
It("handles panics", func() {
sut := SourceLoadingConfig{
Strategy: StartStrategyTypeFailOnError,
Init: Init{Strategy: StartStrategyTypeFailOnError},
}
panicMsg := "panic value"
@ -733,7 +794,7 @@ bootstrapDns:
It("periodically calls refresh", func() {
sut := SourceLoadingConfig{
Strategy: StartStrategyTypeFast,
Init: Init{Strategy: StartStrategyTypeFast},
RefreshPeriod: Duration(5 * time.Millisecond),
}
@ -789,7 +850,7 @@ bootstrapDns:
func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.StartVerify).Should(BeFalse())
Expect(config.Upstreams.Init.Strategy).Should(Equal(StartStrategyTypeFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8"))
@ -824,8 +885,9 @@ func defaultTestFileConfig(config *Config) {
func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
return tmpDir.CreateStringFile("config.yml",
"upstreams:",
" startVerify: false",
" userAgent: testBlocky",
" init:",
" strategy: failOnError",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
@ -885,8 +947,9 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
func writeConfigDir(tmpDir *helpertest.TmpFolder) {
tmpDir.CreateStringFile("config1.yaml",
"upstreams:",
" startVerify: false",
" userAgent: testBlocky",
" init:",
" strategy: failOnError",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",

View File

@ -1,6 +1,7 @@
package config
import (
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
)
@ -8,11 +9,11 @@ const UpstreamDefaultCfgName = "default"
// Upstreams upstream servers configuration
type Upstreams struct {
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
StartVerify bool `yaml:"startVerify" default:"false"`
UserAgent string `yaml:"userAgent"`
Init Init `yaml:"init"`
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
UserAgent string `yaml:"userAgent"`
}
type UpstreamGroups map[string][]Upstream
@ -24,6 +25,9 @@ func (c *Upstreams) IsEnabled() bool {
// LogConfig implements `config.Configurable`.
func (c *Upstreams) LogConfig(logger *logrus.Entry) {
logger.Info("init:")
log.WithIndent(logger, " ", c.Init.LogConfig)
logger.Info("timeout: ", c.Timeout)
logger.Info("strategy: ", c.Strategy)
logger.Info("groups:")

View File

@ -1,6 +1,11 @@
# REVIEW: manual changelog entry
upstreams:
init:
# Configure startup behavior.
# accepted: blocking, failOnError, fast
# default: blocking
strategy: fast
groups:
# these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query
# format for resolver: [net:]host:[port][/path]. net could be empty (default, shortcut for tcp+udp), tcp+udp, tcp, udp, tcp-tls or https (DoH). If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls, 443 for https (Doh))
@ -24,8 +29,6 @@ upstreams:
strategy: parallel_best
# optional: timeout to query the upstream resolver. Default: 2s
timeout: 2s
# optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false
startVerify: false
# optional: HTTP User Agent when connecting to upstreams. Default: none
userAgent: "custom UA"
@ -126,7 +129,8 @@ blocking:
# optional: Maximum number of lists to process in parallel.
# default: 4
concurrency: 16
# optional: if failOnError, application startup will fail if at least one list can't be downloaded/opened
# Configure startup behavior.
# accepted: blocking, failOnError, fast
# default: blocking
strategy: failOnError
# Number of errors allowed in a list before it is considered invalid.
@ -292,7 +296,8 @@ hostsFile:
# optional: Maximum number of files to process in parallel.
# default: 4
concurrency: 16
# optional: if failOnError, application startup will fail if at least one file can't be downloaded/opened
# Configure startup behavior.
# accepted: blocking, failOnError, fast
# default: blocking
strategy: failOnError
# Number of errors allowed in a file before it is considered invalid.

View File

@ -66,15 +66,28 @@ All logging options are optional.
privacy: true
```
## Init Strategy
A couple of features use an "init/loading strategy" which configures behavior at Blocky startup.
This applies to all of them. The default strategy is blocking.
| strategy | Description |
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| blocking | Initialization happens before DNS resolution starts. Any errors are logged, but Blocky continues running if possible. |
| failOnError | Like blocking but Blocky will exit with an error if initialization fails. |
| fast | Blocky starts serving DNS immediately and initialization happens in the background. The feature requiring initialization will enable later on (if it succeeds). |
## Upstreams configuration
| Parameter | Type | Mandatory | Default value | Description |
| --------------------- | ------------------------------------ | --------- | ------------- | ----------------------------------------------------------------------------------------------- |
| usptreams.groups | map of name to upstream | yes | | Upstream DNS servers to use, in groups. |
| usptreams.startVerify | bool | no | false | If true, blocky will fail to start unless at least one upstream server per group is functional. |
| usptreams.strategy | enum (parallel_best, random, strict) | no | parallel_best | Upstream server usage strategy. |
| usptreams.timeout | duration | no | 2s | Upstream connection timeout. |
| usptreams.userAgent | string | no | | HTTP User Agent when connecting to upstreams. |
| Parameter | Type | Mandatory | Default value | Description |
| ----------------------- | ------------------------------------ | --------- | ------------- | ---------------------------------------------- |
| usptreams.groups | map of name to upstream | yes | | Upstream DNS servers to use, in groups. |
| usptreams.init.strategy | enum (blocking, failOnError, fast) | no | blocking | See [Init Strategy](#init-strategy) and below. |
| usptreams.strategy | enum (parallel_best, random, strict) | no | parallel_best | Upstream server usage strategy. |
| usptreams.timeout | duration | no | 2s | Upstream connection timeout. |
| usptreams.userAgent | string | no | | HTTP User Agent when connecting to upstreams. |
For `init.strategy`, the "init" is testing the given resolvers for each group. The potentially fatal error, depending on the strategy, is if a group has no functional resolvers.
### Upstream Groups
@ -847,14 +860,8 @@ Configures how HTTP(S) sources are downloaded:
### Strategy
This configures how Blocky startup works.
The default strategy is blocking.
| strategy | Description |
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| blocking | all sources are loaded before DNS resolution starts |
| failOnError | like blocking but blocky will shut down if any source fails to load |
| fast | blocky starts serving DNS immediately and sources are loaded asynchronously. The features requiring the sources should enable soon after |
See [Init Strategy](#init-strategy).
In this context, "init" is loading and parsing each source, and an error is a single source failing to load/parse.
!!! example

View File

@ -15,8 +15,8 @@ var _ = Describe("Upstream resolver configuration tests", func() {
var blocky testcontainers.Container
var err error
Describe("'upstreams.startVerify' parameter handling", func() {
When("'upstreams.startVerify' is false and upstream server as IP is not reachable", func() {
Describe("'upstreams.init.strategy' parameter handling", func() {
When("'upstreams.init.strategy' is fast and upstream server as IP is not reachable", func() {
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
@ -25,16 +25,19 @@ var _ = Describe("Upstream resolver configuration tests", func() {
" groups:",
" default:",
" - 192.192.192.192",
" startVerify: false",
" init:",
" strategy: fast",
)
Expect(err).Should(Succeed())
})
It("should start even if upstream server is not reachable", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeTrue())
Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty())
Eventually(ctx, func() ([]string, error) {
return getContainerLogs(ctx, blocky)
}).Should(ContainElement(ContainSubstring("initial resolver test failed")))
})
})
When("'upstreams.startVerify' is false and upstream server as host name is not reachable", func() {
When("'upstreams.init.strategy' is fast and upstream server as host name is not reachable", func() {
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
@ -43,23 +46,25 @@ var _ = Describe("Upstream resolver configuration tests", func() {
" groups:",
" default:",
" - some.wrong.host",
" startVerify: false",
" init:",
" strategy: fast",
)
Expect(err).Should(Succeed())
})
It("should start even if upstream server is not reachable", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeTrue())
Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky)).Should(ContainElement(ContainSubstring("initial resolver test failed")))
})
})
When("'upstreams.startVerify' is true and upstream as IP address server is not reachable", func() {
When("'upstreams.init.strategy' is failOnError and upstream as IP address server is not reachable", func() {
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
" - 192.192.192.192",
" startVerify: true",
" init:",
" strategy: failOnError",
)
Expect(err).Should(HaveOccurred())
})
@ -69,14 +74,15 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
When("'upstreams.startVerify' is true and upstream server as host name is not reachable", func() {
When("'upstreams.init.strategy' is failOnError and upstream server as host name is not reachable", func() {
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
" - some.wrong.host",
" startVerify: true",
" init:",
" strategy: failOnError",
)
Expect(err).Should(HaveOccurred())
})

View File

@ -170,7 +170,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFast},
Loading: config.SourceLoadingConfig{
Init: config.Init{Strategy: config.StartStrategyTypeFast},
},
}
})
@ -1124,8 +1126,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
_, err := NewBlockingResolver(ctx, config.Blocking{
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")},
WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")},
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError},
BlockType: "zeroIp",
Loading: config.SourceLoadingConfig{
Init: config.Init{Strategy: config.StartStrategyTypeFailOnError},
},
BlockType: "zeroIp",
}, nil, systemResolverBootstrap)
Expect(err).Should(HaveOccurred())
})

View File

@ -2,6 +2,7 @@ package resolver
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
@ -24,6 +25,10 @@ const (
defaultTimeout = 5 * time.Second
)
var errArbitrarySystemResolverRequest = errors.New(
"cannot resolve arbitrary requests using the system resolver",
)
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
type Bootstrap struct {
configurable[*config.BootstrapDNSConfig]
@ -108,7 +113,7 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model
if b.resolver == nil {
// We could implement most queries using the `b.systemResolver.Lookup*` functions,
// but that requires a lot of boilerplate to translate from `dns` to `net` and back.
return nil, errors.New("cannot resolve arbitrary requests using the system resolver")
return nil, errArbitrarySystemResolverRequest
}
// Add bootstrap prefix to all inner resolver logs

View File

@ -167,6 +167,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}
})
JustBeforeEach(func() {
// Don't count the resolver test
testUpstream.ResetCallCount()
})
It("should resolve client name", func() {
By("first request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
@ -209,6 +214,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
// no cache -> call count 2
Expect(request.ClientNames).Should(ConsistOf("host1"))
Expect(testUpstream.GetCallCount()).Should(Equal(2))
})
})
})
@ -223,6 +229,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}
})
JustBeforeEach(func() {
// Don't count the resolver test
testUpstream.ResetCallCount()
})
It("should resolve all client names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(ctx, request)).
@ -251,6 +262,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sutConfig.Upstream = testUpstream.Start()
})
JustBeforeEach(func() {
// Don't count the resolver test
testUpstream.ResetCallCount()
})
It("should resolve client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(ctx, request)).
@ -272,6 +288,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sutConfig.Upstream = testUpstream.Start()
})
JustBeforeEach(func() {
// Don't count the resolver test
testUpstream.ResetCallCount()
})
It("should resolve the client name depending to defined order", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(ctx, request)).
@ -298,6 +319,11 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}
})
JustBeforeEach(func() {
// Don't count the resolver test
testUpstream.ResetCallCount()
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(ctx, request)).
@ -371,7 +397,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.StartVerify = true
upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError
r, err := NewClientNamesResolver(ctx, config.ClientLookup{
Upstream: config.Upstream{Host: "example.com"},

View File

@ -196,9 +196,8 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
upstreamsCfg := config.Upstreams{
StartVerify: true,
}
upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError
sutConfig := config.ConditionalUpstream{
Mapping: config.ConditionalUpstreamMapping{

View File

@ -87,6 +87,10 @@ func (t *MockUDPUpstreamServer) GetCallCount() int {
return int(atomic.LoadInt32(&t.callCount))
}
func (t *MockUDPUpstreamServer) ResetCallCount() {
atomic.StoreInt32(&t.callCount, 0)
}
func (t *MockUDPUpstreamServer) Close() {
if t.ln != nil {
_ = t.ln.Close()

View File

@ -30,7 +30,7 @@ type ParallelBestResolver struct {
configurable[*config.UpstreamGroup]
typed
resolvers []*upstreamResolverStatus
resolvers atomic.Pointer[[]*upstreamResolverStatus]
resolverCount int
retryWithDifferentResolver bool
@ -95,14 +95,12 @@ type requestResponse struct {
func NewParallelBestResolver(
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType)
r := newParallelBestResolver(
cfg,
[]Resolver{bootstrap}, // if start strategy is fast, use bootstrap until init finishes
)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap)
if err != nil {
return nil, err
}
return newParallelBestResolver(cfg, resolvers), nil
return initGroupResolvers(ctx, r, cfg, bootstrap)
}
func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *ParallelBestResolver {
@ -122,31 +120,40 @@ func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *Pa
resolverCount: resolverCount,
retryWithDifferentResolver: retryWithDifferentResolver,
resolvers: newUpstreamResolverStatuses(resolvers),
}
r.setResolvers(newUpstreamResolverStatuses(resolvers))
return &r
}
func (r *ParallelBestResolver) setResolvers(resolvers []*upstreamResolverStatus) {
r.resolvers.Store(&resolvers)
}
func (r *ParallelBestResolver) Name() string {
return r.String()
}
func (r *ParallelBestResolver) String() string {
resolvers := make([]string, len(r.resolvers))
for i, s := range r.resolvers {
resolvers[i] = fmt.Sprintf("%s", s.resolver)
resolvers := *r.resolvers.Load()
upstreams := make([]string, len(resolvers))
for i, s := range resolvers {
upstreams[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(resolvers, ","))
return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(upstreams, ","))
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType)
if len(r.resolvers) == 1 {
resolver := r.resolvers[0]
allResolvers := *r.resolvers.Load()
if len(allResolvers) == 1 {
resolver := allResolvers[0]
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
return resolver.resolve(ctx, request)
@ -155,7 +162,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque
ctx, cancel := context.WithCancel(ctx)
defer cancel() // abort requests to resolvers that lost the race
resolvers := pickRandom(r.resolvers, r.resolverCount)
resolvers := pickRandom(allResolvers, r.resolverCount)
ch := make(chan requestResponse, len(resolvers))
for _, resolver := range resolvers {
@ -204,7 +211,7 @@ func (r *ParallelBestResolver) retryWithDifferent(
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
) (*model.Response, error) {
// second try (if retryWithDifferentResolver == true)
resolver := weightedRandom(r.resolvers, resolvers)
resolver := weightedRandom(*r.resolvers.Load(), resolvers)
logger.Debugf("using %s as second resolver", resolver.resolver)
resp, err := resolver.resolve(ctx, request)

View File

@ -14,18 +14,14 @@ import (
)
var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
const (
verifyUpstreams = true
noVerifyUpstreams = false
)
var (
sut *ParallelBestResolver
sutStrategy config.UpstreamStrategy
upstreams []config.Upstream
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
sut *ParallelBestResolver
sutStrategy config.UpstreamStrategy
sutStartStrategy config.StartStrategyType
upstreams []config.Upstream
ctx context.Context
cancelFn context.CancelFunc
err error
@ -44,17 +40,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}}
sutStartStrategy = config.StartStrategyTypeBlocking
sutStrategy = config.UpstreamStrategyParallelBest
sutVerify = noVerifyUpstreams
bootstrap = systemResolverBootstrap
})
JustBeforeEach(func() {
upstreamsCfg := config.Upstreams{
StartVerify: sutVerify,
Strategy: sutStrategy,
Timeout: config.Duration(timeout),
Init: config.Init{
Strategy: sutStartStrategy,
},
Strategy: sutStrategy,
Timeout: config.Duration(timeout),
}
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
@ -90,9 +88,31 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{}
})
It("should fail on startup", func() {
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError(ContainSubstring("no external DNS resolvers configured")))
When("using InitStrategyFailOnError", func() {
BeforeEach(func() {
sutStartStrategy = config.StartStrategyTypeFailOnError
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("using InitStrategyBlocking", func() {
BeforeEach(func() {
sutStartStrategy = config.StartStrategyTypeBlocking
})
It("should start", func() {
Expect(err).Should(Succeed())
})
})
When("using InitStrategyFast", func() {
BeforeEach(func() {
sutInitStrategy = config.InitStrategyFast
})
It("should start", func() {
Expect(err).Should(Succeed())
})
})
})
@ -103,18 +123,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}}
})
When("strict checking is enabled", func() {
When("using InitStrategyFailOnError", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
sutStartStrategy = config.StartStrategyTypeFailOnError
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("strict checking is disabled", func() {
When("using InitStrategyBlocking", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
sutStartStrategy = config.StartStrategyTypeBlocking
})
It("should start", func() {
Expect(err).Should(Succeed())
})
})
When("using InitStrategyFast", func() {
BeforeEach(func() {
sutInitStrategy = config.InitStrategyFast
})
It("should start", func() {
Expect(err).Should(Succeed())
@ -131,18 +160,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{timeoutUpstream.Start()}
})
When("strict checking is enabled", func() {
When("using InitStrategyFailOnError", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
sutStartStrategy = config.StartStrategyTypeFailOnError
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("strict checking is disabled", func() {
When("using InitStrategyBlocking", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
sutStartStrategy = config.StartStrategyTypeBlocking
})
It("should start", func() {
Expect(err).Should(Succeed())
@ -156,6 +185,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Expect(isTimeout(err)).Should(BeTrue())
})
})
When("using InitStrategyFast", func() {
BeforeEach(func() {
sutInitStrategy = config.InitStrategyFast
})
It("should start", func() {
Expect(err).Should(Succeed())
})
It("should not resolve", func() {
Expect(err).Should(Succeed())
request := newRequest("example.com.", A)
_, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
Expect(err).Should(SatisfyAny(
// The actual error depends on if the init has completed or not
MatchError(isTimeout, "isTimeout"),
MatchError(errArbitrarySystemResolverRequest),
))
})
})
})
Describe("Resolving result from fastest upstream resolver", func() {
@ -256,7 +306,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[Resolver]int)
for i := 0; i < 1000; i++ {
resolvers := pickRandom(sut.resolvers, parallelBestResolverCount)
resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
res1 := resolvers[0].resolver
res2 := resolvers[1].resolver
Expect(res1).ShouldNot(Equal(res2))
@ -280,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)
for i := 0; i < 100; i++ {
resolvers := pickRandom(sut.resolvers, parallelBestResolverCount)
resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
res1 := resolvers[0].resolver.(*UpstreamResolver)
res2 := resolvers[1].resolver.(*UpstreamResolver)
Expect(res1).ShouldNot(Equal(res2))
@ -307,7 +357,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
upstreamsCfg := sut.cfg.Upstreams
upstreamsCfg.StartVerify = true
upstreamsCfg.Init.Strategy = config.StartStrategyTypeFailOnError
group := config.NewUpstreamGroup("test", upstreamsCfg, []config.Upstream{{Host: "example.com"}})
@ -448,7 +498,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[Resolver]int)
for i := 0; i < 2000; i++ {
r := weightedRandom(sut.resolvers, nil)
r := weightedRandom(*sut.resolvers.Load(), nil)
resolverCount[r.resolver]++
}
for _, v := range resolverCount {
@ -467,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)
for i := 0; i < 200; i++ {
r := weightedRandom(sut.resolvers, nil)
r := weightedRandom(*sut.resolvers.Load(), nil)
res := r.resolver.(*UpstreamResolver)
resolverCount[res]++

View File

@ -216,38 +216,55 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
c.cfg.LogConfig(logger)
}
func createResolvers(
ctx context.Context, logger *logrus.Entry,
cfg config.UpstreamGroup, bootstrap *Bootstrap,
) ([]Resolver, error) {
if len(cfg.GroupUpstreams()) == 0 {
return nil, fmt.Errorf("no external DNS resolvers configured for group %s", cfg.Name)
}
type initializable interface {
log() *logrus.Entry
setResolvers([]*upstreamResolverStatus)
}
resolvers := make([]Resolver, 0, len(cfg.GroupUpstreams()))
hasValidResolvers := false
for _, u := range cfg.GroupUpstreams() {
resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(u, cfg.Upstreams), bootstrap)
func initGroupResolvers[T initializable](
ctx context.Context, r T, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) (T, error) {
init := func(ctx context.Context) error {
resolvers, err := createGroupResolvers(ctx, cfg, bootstrap)
if err != nil {
logger.Warnf("upstream group %s: %v", cfg.Name, err)
continue
return err
}
if cfg.StartVerify {
err = resolver.testResolve(ctx)
if err != nil {
logger.Warn(err)
} else {
hasValidResolvers = true
}
}
r.setResolvers(resolvers)
resolvers = append(resolvers, resolver)
return nil
}
if cfg.StartVerify && !hasValidResolvers {
onErr := func(err error) {
r.log().WithError(err).Error("upstream verification error, will continue to use bootstrap DNS")
}
err := cfg.Init.Strategy.Do(ctx, init, onErr)
if err != nil {
var zero T
return zero, err
}
return r, nil
}
func createGroupResolvers(
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) ([]*upstreamResolverStatus, error) {
upstreams := cfg.GroupUpstreams()
resolvers := make([]*upstreamResolverStatus, 0, len(upstreams))
for _, upstream := range upstreams {
resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(upstream, cfg.Upstreams), bootstrap)
if err != nil {
continue // err was already logged
}
resolvers = append(resolvers, newUpstreamResolverStatus(resolver))
}
if len(resolvers) == 0 {
return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name)
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"strings"
"sync/atomic"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
@ -24,21 +25,19 @@ type StrictResolver struct {
configurable[*config.UpstreamGroup]
typed
resolvers []*upstreamResolverStatus
resolvers atomic.Pointer[[]*upstreamResolverStatus]
}
// NewStrictResolver creates a new strict resolver instance
func NewStrictResolver(
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType)
r := newStrictResolver(
cfg,
[]Resolver{bootstrap}, // if start strategy is fast, use bootstrap until init finishes
)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap)
if err != nil {
return nil, err
}
return newStrictResolver(cfg, resolvers), nil
return initGroupResolvers(ctx, r, cfg, bootstrap)
}
func newStrictResolver(
@ -47,24 +46,30 @@ func newStrictResolver(
r := StrictResolver{
configurable: withConfig(&cfg),
typed: withType(strictResolverType),
resolvers: newUpstreamResolverStatuses(resolvers),
}
r.setResolvers(newUpstreamResolverStatuses(resolvers))
return &r
}
func (r *StrictResolver) setResolvers(resolvers []*upstreamResolverStatus) {
r.resolvers.Store(&resolvers)
}
func (r *StrictResolver) Name() string {
return r.String()
}
func (r *StrictResolver) String() string {
result := make([]string, len(r.resolvers))
for i, s := range r.resolvers {
result[i] = fmt.Sprintf("%s", s.resolver)
resolvers := *r.resolvers.Load()
upstreams := make([]string, len(resolvers))
for i, s := range resolvers {
upstreams[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(result, ","))
return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(upstreams, ","))
}
// Resolve sends the query request in a strict order to the upstream resolvers
@ -72,7 +77,7 @@ func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*
logger := log.WithPrefix(request.Log, strictResolverType)
// start with first resolver
for _, resolver := range r.resolvers {
for _, resolver := range *r.resolvers.Load() {
logger.Debugf("using %s as resolver", resolver.resolver)
resp, err := resolver.resolve(ctx, request)

View File

@ -15,15 +15,10 @@ import (
)
var _ = Describe("StrictResolver", Label("strictResolver"), func() {
const (
verifyUpstreams = true
noVerifyUpstreams = false
)
var (
sut *StrictResolver
upstreams []config.Upstream
sutVerify bool
sut *StrictResolver
sutStartStrategy config.StartStrategyType
upstreams []config.Upstream
err error
@ -52,14 +47,14 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
{Host: "127.0.0.2"},
}
sutVerify = noVerifyUpstreams
sutStartStrategy = config.StartStrategyTypeBlocking
bootstrap = systemResolverBootstrap
})
JustBeforeEach(func() {
upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.StartVerify = sutVerify
upstreamsCfg.Init.Strategy = sutStartStrategy
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
sutConfig.Timeout = config.Duration(timeout)
@ -125,7 +120,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
When("strict checking is enabled", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
sutStartStrategy = config.StartStrategyTypeFailOnError
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
@ -134,7 +129,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
When("strict checking is disabled", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
sutStartStrategy = config.StartStrategyTypeBlocking
})
It("should start", func() {
Expect(err).Should(Succeed())

View File

@ -216,11 +216,13 @@ func NewUpstreamResolver(
) (*UpstreamResolver, error) {
r := newUpstreamResolverUnchecked(cfg, bootstrap)
if cfg.StartVerify {
_, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil {
return nil, err
}
onErr := func(err error) {
r.log().WithError(err).Warn("initial resolver test failed")
}
err := cfg.Init.Strategy.Do(ctx, r.testResolve, onErr)
if err != nil {
return nil, err
}
return r, nil

View File

@ -132,9 +132,9 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
})
})
When("start verify is enabled", func() {
When("start strategy is failOnError", func() {
BeforeEach(func() {
sutConfig.StartVerify = true
sutConfig.Init.Strategy = config.StartStrategyTypeFailOnError
})
It("should fail", func() {