refactor: embed `Upstreams` in `UpstreamGroup` to make values accessible

Move `startVerifyUpstream` to `upstreams.startVerify` so it's accessible
via `UpstreamGroup` and we don't need to pass `startVerify` to all
resolver constructors that call `NewUpstreamResolver`.

Also has the nice benefit of greatly reducing the usage of `GetConfig`.
This commit is contained in:
ThinkChaos 2023-11-21 22:18:08 -05:00
parent 08a3df6e64
commit b386e22ebe
25 changed files with 483 additions and 537 deletions

View File

@ -189,44 +189,44 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e
//
//nolint:maligned
type Config struct {
Upstreams Upstreams `yaml:"upstreams"`
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
CustomDNS CustomDNS `yaml:"customDNS"`
Conditional ConditionalUpstream `yaml:"conditional"`
Blocking Blocking `yaml:"blocking"`
ClientLookup ClientLookup `yaml:"clientLookup"`
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
DoHUserAgent string `yaml:"dohUserAgent"`
MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"`
StartVerifyUpstream bool `yaml:"startVerifyUpstream" default:"false"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
HostsFile HostsFileConfig `yaml:"hostsFile"`
FQDNOnly FQDNOnly `yaml:"fqdnOnly"`
Filtering FilteringConfig `yaml:"filtering"`
EDE EDE `yaml:"ede"`
ECS ECS `yaml:"ecs"`
SUDN SUDN `yaml:"specialUseDomains"`
Upstreams Upstreams `yaml:"upstreams"`
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
CustomDNS CustomDNS `yaml:"customDNS"`
Conditional ConditionalUpstream `yaml:"conditional"`
Blocking Blocking `yaml:"blocking"`
ClientLookup ClientLookup `yaml:"clientLookup"`
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
DoHUserAgent string `yaml:"dohUserAgent"`
MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
HostsFile HostsFileConfig `yaml:"hostsFile"`
FQDNOnly FQDNOnly `yaml:"fqdnOnly"`
Filtering FilteringConfig `yaml:"filtering"`
EDE EDE `yaml:"ede"`
ECS ECS `yaml:"ecs"`
SUDN SUDN `yaml:"specialUseDomains"`
// Deprecated options
Deprecated struct {
Upstream *UpstreamGroups `yaml:"upstream"`
UpstreamTimeout *Duration `yaml:"upstreamTimeout"`
DisableIPv6 *bool `yaml:"disableIPv6"`
LogLevel *log.Level `yaml:"logLevel"`
LogFormat *log.FormatType `yaml:"logFormat"`
LogPrivacy *bool `yaml:"logPrivacy"`
LogTimestamp *bool `yaml:"logTimestamp"`
DNSPorts *ListenConfig `yaml:"port"`
HTTPPorts *ListenConfig `yaml:"httpPort"`
HTTPSPorts *ListenConfig `yaml:"httpsPort"`
TLSPorts *ListenConfig `yaml:"tlsPort"`
Upstream *UpstreamGroups `yaml:"upstream"`
UpstreamTimeout *Duration `yaml:"upstreamTimeout"`
DisableIPv6 *bool `yaml:"disableIPv6"`
LogLevel *log.Level `yaml:"logLevel"`
LogFormat *log.FormatType `yaml:"logFormat"`
LogPrivacy *bool `yaml:"logPrivacy"`
LogTimestamp *bool `yaml:"logTimestamp"`
DNSPorts *ListenConfig `yaml:"port"`
HTTPPorts *ListenConfig `yaml:"httpPort"`
HTTPSPorts *ListenConfig `yaml:"httpsPort"`
TLSPorts *ListenConfig `yaml:"tlsPort"`
StartVerifyUpstream *bool `yaml:"startVerifyUpstream"`
} `yaml:",inline"`
}
@ -514,14 +514,15 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
cfg.Filtering.QueryTypes.Insert(dns.Type(dns.TypeAAAA))
}
}),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"startVerifyUpstream": Move(To("upstreams.startVerify", &cfg.Upstreams)),
})
usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts

View File

@ -770,6 +770,7 @@ bootstrapDns:
func defaultTestFileConfig() {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.StartVerify).Should(BeFalse())
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8"))
Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4"))
@ -798,7 +799,6 @@ func defaultTestFileConfig() {
Expect(config.DoHUserAgent).Should(Equal("testBlocky"))
Expect(config.MinTLSServeVer).Should(Equal("1.3"))
Expect(config.StartVerifyUpstream).Should(BeFalse())
Expect(GetConfig()).Should(Not(BeNil()))
}
@ -806,6 +806,7 @@ func defaultTestFileConfig() {
func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
return tmpDir.CreateStringFile("config.yml",
"upstreams:",
" startVerify: false",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
@ -860,12 +861,13 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
"logLevel: debug",
"dohUserAgent: testBlocky",
"minTlsServeVersion: 1.3",
"startVerifyUpstream: false")
)
}
func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
f1 := tmpDir.CreateStringFile("config1.yaml",
"upstreams:",
" startVerify: false",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
@ -925,7 +927,7 @@ func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
"logLevel: debug",
"dohUserAgent: testBlocky",
"minTlsServeVersion: 1.3",
"startVerifyUpstream: false")
)
return f2.Error
}

View File

@ -8,9 +8,10 @@ const UpstreamDefaultCfgName = "default"
// Upstreams upstream servers configuration
type Upstreams struct {
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
StartVerify bool `yaml:"startVerify" default:"false"`
}
type UpstreamGroups map[string][]Upstream
@ -37,13 +38,32 @@ func (c *Upstreams) LogConfig(logger *logrus.Entry) {
// UpstreamGroup represents the config for one group (upstream branch)
type UpstreamGroup struct {
Name string
Upstreams []Upstream
Upstreams
Name string // group name
}
// NewUpstreamGroup creates an UpstreamGroup with the given name and upstreams.
//
// The upstreams from `cfg.Groups` are ignored.
func NewUpstreamGroup(name string, cfg Upstreams, upstreams []Upstream) UpstreamGroup {
group := UpstreamGroup{
Name: name,
Upstreams: cfg,
}
group.Groups = UpstreamGroups{name: upstreams}
return group
}
func (c *UpstreamGroup) GroupUpstreams() []Upstream {
return c.Groups[c.Name]
}
// IsEnabled implements `config.Configurable`.
func (c *UpstreamGroup) IsEnabled() bool {
return len(c.Upstreams) != 0
return len(c.GroupUpstreams()) != 0
}
// LogConfig implements `config.Configurable`.
@ -51,7 +71,7 @@ func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) {
logger.Info("group: ", c.Name)
logger.Info("upstreams:")
for _, upstream := range c.Upstreams {
for _, upstream := range c.GroupUpstreams() {
logger.Infof(" - %s", upstream)
}
}

View File

@ -11,7 +11,7 @@ import (
var _ = Describe("ParallelBestConfig", func() {
suiteBeforeEach()
Context("UpstreamsConfig", func() {
Context("Upstreams", func() {
var cfg Upstreams
BeforeEach(func() {
@ -65,13 +65,13 @@ var _ = Describe("ParallelBestConfig", func() {
var cfg UpstreamGroup
BeforeEach(func() {
cfg = UpstreamGroup{
Name: UpstreamDefaultCfgName,
Upstreams: []Upstream{
{Host: "host1"},
{Host: "host2"},
},
}
upstreamsCfg, err := WithDefaults[Upstreams]()
Expect(err).Should(Succeed())
cfg = NewUpstreamGroup("test", upstreamsCfg, []Upstream{
{Host: "host1"},
{Host: "host2"},
})
})
Describe("IsEnabled", func() {
@ -102,7 +102,7 @@ var _ = Describe("ParallelBestConfig", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: test")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))

View File

@ -24,9 +24,8 @@ upstreams:
strategy: parallel_best
# optional: timeout to query the upstream resolver. Default: 2s
timeout: 2s
# optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false
startVerifyUpstream: true
# optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false
startVerify: false
# optional: Determines how blocky will create outgoing connections. This impacts both upstreams, and lists.
# accepted: dual, v4, v6

View File

@ -17,7 +17,6 @@ configuration properties as [JSON](config.yml).
| keyFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated |
| dohUserAgent | string | no | | HTTP User Agent for DoH upstreams |
| minTlsServeVersion | string | no | 1.2 | Minimum TLS version that the DoT and DoH server use to serve those encrypted DNS requests |
| startVerifyUpstream | bool | no | false | If true, blocky will fail to start unless at least one upstream server per group is reachable. |
| connectIPVersion | enum (dual, v4, v6) | no | dual | IP version to use for outgoing connections (dual, v4, v6) |
!!! example
@ -70,6 +69,16 @@ All logging options are optional.
## Upstreams configuration
| Parameter | Type | Mandatory | Default value | Description |
| --------------------- | ------------------------------------ | --------- | ------------- | ----------------------------------------------------------------------------------------------- |
| usptreams.groups | map of name to upstream | yes | | Upstream DNS servers to use, in groups. |
| usptreams.startVerify | bool | no | false | If true, blocky will fail to start unless at least one upstream server per group is functional. |
| usptreams.strategy | enum (parallel_best, random, strict) | no | parallel_best | Upstream server usage strategy. |
| usptreams.timeout | duration | no | 2s | Upstream connection timeout. |
### Upstream Groups
To resolve a DNS query, blocky needs external public or private DNS resolvers. Blocky supports DNS resolvers with
following network protocols (net part of the resolver URL):
@ -133,6 +142,22 @@ The logic determining what group a client belongs to follows a strict order: IP,
If a client matches multiple client name or CIDR groups, a warning is logged and the first found group is used.
### Upstream connection timeout
Blocky will wait 2 seconds (default value) for the response from the external upstream DNS server. You can change this
value by setting the `timeout` configuration parameter (in **duration format**).
!!! example
```yaml
upstreams:
timeout: 5s
groups:
default:
- 46.182.19.48
- 80.241.218.68
```
### Upstream strategy
Blocky supports different upstream strategies (default `parallel_best`) that determine how and to which upstream DNS servers requests are forwarded.
@ -160,21 +185,6 @@ Currently available strategies:
- 9.8.7.6
```
### Upstream lookup timeout
Blocky will wait 2 seconds (default value) for the response from the external upstream DNS server. You can change this
value by setting the `timeout` configuration parameter (in **duration format**).
!!! example
```yaml
upstreams:
timeout: 5s
groups:
default:
- 46.182.19.48
- 80.241.218.68
```
## Bootstrap DNS configuration

View File

@ -13,8 +13,8 @@ var _ = Describe("Upstream resolver configuration tests", func() {
var blocky testcontainers.Container
var err error
Describe("'startVerifyUpstream' parameter handling", func() {
When("'startVerifyUpstream' is false and upstream server as IP is not reachable", func() {
Describe("'upstreams.startVerify' parameter handling", func() {
When("'upstreams.startVerify' is false and upstream server as IP is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"log:",
@ -23,7 +23,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
" groups:",
" default:",
" - 192.192.192.192",
"startVerifyUpstream: false",
" startVerify: false",
)
Expect(err).Should(Succeed())
@ -34,7 +34,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(getContainerLogs(blocky)).Should(BeEmpty())
})
})
When("'startVerifyUpstream' is false and upstream server as host name is not reachable", func() {
When("'upstreams.startVerify' is false and upstream server as host name is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"log:",
@ -43,7 +43,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
" groups:",
" default:",
" - some.wrong.host",
"startVerifyUpstream: false",
" startVerify: false",
)
Expect(err).Should(Succeed())
@ -54,14 +54,14 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(getContainerLogs(blocky)).Should(BeEmpty())
})
})
When("'startVerifyUpstream' is true and upstream as IP address server is not reachable", func() {
When("'upstreams.startVerify' is true and upstream as IP address server is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"upstreams:",
" groups:",
" default:",
" - 192.192.192.192",
"startVerifyUpstream: true",
" startVerify: true",
)
Expect(err).Should(HaveOccurred())
@ -70,19 +70,17 @@ var _ = Describe("Upstream resolver configuration tests", func() {
It("should not start", func() {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Should(ContainElements(
ContainSubstring("creation of upstream branches failed: "),
ContainSubstring("no valid upstream for group default")))
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
When("'startVerifyUpstream' is true and upstream server as host name is not reachable", func() {
When("'upstreams.startVerify' is true and upstream server as host name is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"upstreams:",
" groups:",
" default:",
" - some.wrong.host",
"startVerifyUpstream: true",
" startVerify: true",
)
Expect(err).Should(HaveOccurred())
@ -91,9 +89,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
It("should not start", func() {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Should(ContainElements(
ContainSubstring("creation of upstream branches failed: "),
ContainSubstring("no valid upstream for group default")))
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
})

View File

@ -65,7 +65,7 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
},
}
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
if err != nil {
return nil, err
}
@ -76,11 +76,8 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
return b, nil
}
// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
// where `ParallelBestResolver` uses its config, we can just use an empty one.
pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName}
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers())
pbCfg := config.NewUpstreamGroup("<bootstrap>", cfg.Upstreams, nil)
pbCfg.Upstreams.Groups = nil // To be on the safe side it doesn't try to use anything besides the bootstrap
// Always enable prefetching to avoid stalling user requests
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:
@ -100,14 +97,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
NewFilteringResolver(cfg.Filtering),
// false: no metrics, to not overwrite the main blocking resolver ones
newCachingResolver(ctx, cachingCfg, nil, false),
parallelResolver,
newParallelBestResolver(pbCfg, bootstraped.Resolvers()),
)
return b, nil
}
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.upstream.Host
hostname := r.cfg.Host
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil
@ -249,7 +246,9 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.
// map of bootstraped resolvers their hardcoded IPs
type bootstrapedResolvers map[Resolver][]net.IP
func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (bootstrapedResolvers, error) {
func newBootstrapedResolvers(
b *Bootstrap, cfg config.BootstrapDNSConfig, upstreamsCfg config.Upstreams,
) (bootstrapedResolvers, error) {
upstreamIPs := make(bootstrapedResolvers, len(cfg))
var multiErr *multierror.Error
@ -289,7 +288,7 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots
continue
}
resolver := newUpstreamResolverUnchecked(upstream, b)
resolver := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, upstreamsCfg), b)
upstreamIPs[resolver] = ips
}

View File

@ -32,7 +32,6 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
)
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
sutConfig = &config.Config{
BootstrapDNS: []config.BootstrappedUpstreamConfig{
{
@ -43,6 +42,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
IPs: []net.IP{net.IPv4zero},
},
},
Upstreams: defaultUpstreamsConfig,
}
ctx, cancelFn = context.WithCancel(context.Background())
@ -327,7 +327,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
upstream.Host = "localhost" // force bootstrap to do resolve, and not just return the IP as is
r := newUpstreamResolverUnchecked(upstream, sut)
r := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, sutConfig.Upstreams), sut)
rsp, err := r.Resolve(ctx, mainReq)
Expect(err).Should(Succeed())

View File

@ -28,11 +28,11 @@ type ClientNamesResolver struct {
// NewClientNamesResolver creates new resolver instance
func NewClientNamesResolver(ctx context.Context,
cfg config.ClientLookup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
cfg config.ClientLookup, upstreamsCfg config.Upstreams, bootstrap *Bootstrap,
) (cr *ClientNamesResolver, err error) {
var r Resolver
if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
r, err = NewUpstreamResolver(ctx, newUpstreamConfig(cfg.Upstream, upstreamsCfg), bootstrap)
if err != nil {
return nil, err
}

View File

@ -39,7 +39,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false)
sut, err = NewClientNamesResolver(ctx, sutConfig, defaultUpstreamsConfig, nil)
Expect(err).Should(Succeed())
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
@ -370,9 +370,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.StartVerify = true
r, err := NewClientNamesResolver(ctx, config.ClientLookup{
Upstream: config.Upstream{Host: "example.com"},
}, b, true)
}, upstreamsCfg, b)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())

View File

@ -2,6 +2,7 @@ package resolver
import (
"context"
"fmt"
"strings"
"github.com/0xERR0R/blocky/config"
@ -24,14 +25,15 @@ type ConditionalUpstreamResolver struct {
// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(
ctx context.Context, cfg config.ConditionalUpstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.ConditionalUpstream, upstreamsCfg config.Upstreams, bootstrap *Bootstrap,
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
name := fmt.Sprintf("<conditional in %s>", domain)
cfg := config.NewUpstreamGroup(name, upstreamsCfg, upstreams)
r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(ctx, cfg, bootstrap)
if err != nil {
return nil, err
}

View File

@ -17,8 +17,10 @@ import (
var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() {
var (
sut *ConditionalUpstreamResolver
m *mockResolver
sut *ConditionalUpstreamResolver
sutConfig config.ConditionalUpstream
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
@ -65,7 +67,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
DeferCleanup(refuseTestUpstream.Close)
sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstream{
sutConfig = config.ConditionalUpstream{
Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{
"fritz.box": {fbTestUpstream.Start()},
@ -74,7 +76,11 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
".": {dotTestUpstream.Start()},
},
},
}, nil, false)
}
})
JustBeforeEach(func() {
sut, _ = NewConditionalUpstreamResolver(ctx, sutConfig, defaultUpstreamsConfig, systemResolverBootstrap)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -194,14 +200,19 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstream{
upstreamsCfg := config.Upstreams{
StartVerify: true,
}
sutConfig := config.ConditionalUpstream{
Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{
".": {config.Upstream{Host: "example.com"}},
},
},
}, b, true)
}
r, err := NewConditionalUpstreamResolver(ctx, sutConfig, upstreamsCfg, b)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
})

View File

@ -13,7 +13,6 @@ import (
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
"github.com/mroth/weightedrand/v2"
"github.com/sirupsen/logrus"
@ -31,7 +30,6 @@ type ParallelBestResolver struct {
configurable[*config.UpstreamGroup]
typed
groupName string
resolvers []*upstreamResolverStatus
resolverCount int
@ -53,6 +51,16 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
return status
}
func newUpstreamResolverStatuses(resolvers []Resolver) []*upstreamResolverStatus {
statuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
statuses = append(statuses, newUpstreamResolverStatus(r))
}
return statuses
}
func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
resp, err := r.resolver.Resolve(ctx, req)
if err != nil {
@ -83,25 +91,13 @@ type requestResponse struct {
err error
}
// testResolver sends a test query to verify the resolver is reachable and working
func testResolver(ctx context.Context, r *UpstreamResolver) error {
request := newRequest("github.com.", dns.Type(dns.TypeA))
resp, err := r.Resolve(ctx, request)
if err != nil || resp.RType != model.ResponseTypeRESOLVED {
return fmt.Errorf("test resolve of upstream server failed: %w", err)
}
return nil
}
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap)
if err != nil {
return nil, err
}
@ -109,20 +105,12 @@ func NewParallelBestResolver(
return newParallelBestResolver(cfg, resolvers), nil
}
func newParallelBestResolver(
cfg config.UpstreamGroup, resolvers []Resolver,
) *ParallelBestResolver {
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *ParallelBestResolver {
typeName := "parallel_best"
resolverCount := parallelBestResolverCount
retryWithDifferentResolver := false
if config.GetConfig().Upstreams.Strategy == config.UpstreamStrategyRandom {
if cfg.Strategy == config.UpstreamStrategyRandom {
typeName = "random"
resolverCount = 1
retryWithDifferentResolver = true
@ -132,11 +120,9 @@ func newParallelBestResolver(
configurable: withConfig(&cfg),
typed: withType(typeName),
groupName: cfg.Name,
resolvers: resolverStatuses,
resolverCount: resolverCount,
retryWithDifferentResolver: retryWithDifferentResolver,
resolvers: newUpstreamResolverStatuses(resolvers),
}
return &r
@ -147,12 +133,12 @@ func (r *ParallelBestResolver) Name() string {
}
func (r *ParallelBestResolver) String() string {
result := make([]string, len(r.resolvers))
resolvers := make([]string, len(r.resolvers))
for i, s := range r.resolvers {
result[i] = fmt.Sprintf("%s", s.resolver)
resolvers[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.groupName, strings.Join(result, ","))
return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(resolvers, ","))
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result

View File

@ -9,7 +9,6 @@ import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -17,18 +16,19 @@ import (
var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
const (
timeout = 50 * time.Millisecond
verifyUpstreams = true
noVerifyUpstreams = false
timeout = 50 * time.Millisecond
)
var (
sut *ParallelBestResolver
upstreams []config.Upstream
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
sut *ParallelBestResolver
sutStrategy config.UpstreamStrategy
upstreams []config.Upstream
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
err error
@ -42,29 +42,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
BeforeEach(func() {
old := config.GetConfig().Upstreams.Timeout
DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old })
config.GetConfig().Upstreams.Timeout = config.Duration(timeout)
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}}
sutStrategy = config.UpstreamStrategyParallelBest
sutVerify = noVerifyUpstreams
bootstrap = systemResolverBootstrap
})
JustBeforeEach(func() {
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
upstreamsCfg := config.Upstreams{
StartVerify: sutVerify,
Strategy: sutStrategy,
Timeout: config.Duration(timeout),
}
sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify)
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap)
})
Describe("IsEnabled", func() {
@ -90,26 +88,14 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
})
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
When("default upstream resolvers are not defined", func() {
BeforeEach(func() {
upstreams = []config.Upstream{}
})
return
})
defer mockUpstream.Close()
upstreams := []config.Upstream{
{Host: "wrong"},
mockUpstream.Start(),
}
_, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
It("should fail on startup", func() {
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError(ContainSubstring("no external DNS resolvers configured")))
})
})
@ -230,13 +216,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
withError2 := config.Upstream{Host: "wrong"}
upstreams = []config.Upstream{withError1, withError2}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
Expect(err).Should(Succeed())
request := newRequest("example.com.", A)
_, err = sut.Resolve(ctx, request)
_, err = sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
})
})
@ -265,7 +250,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Describe("Weighted random on resolver selection", func() {
When("5 upstream resolvers are defined", func() {
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
BeforeEach(func() {
withError1 := config.Upstream{Host: "wrong1"}
withError2 := config.Upstream{Host: "wrong2"}
@ -275,12 +260,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
systemResolverBootstrap, noVerifyUpstreams)
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
})
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)
@ -335,11 +318,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: "test",
Upstreams: []config.Upstream{{Host: "example.com"}},
}, b, verifyUpstreams)
upstreamsCfg := sut.cfg.Upstreams
upstreamsCfg.StartVerify = true
group := config.NewUpstreamGroup("test", upstreamsCfg, []config.Upstream{{Host: "example.com"}})
r, err := NewParallelBestResolver(ctx, group, b)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
})
@ -347,7 +331,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Describe("random resolver strategy", func() {
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
sutStrategy = config.UpstreamStrategyRandom
})
Describe("Name", func() {
@ -468,7 +452,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Describe("Weighted random on resolver selection", func() {
When("4 upstream resolvers are defined", func() {
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
BeforeEach(func() {
withError1 := config.Upstream{Host: "wrong1"}
withError2 := config.Upstream{Host: "wrong2"}
@ -478,12 +462,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
systemResolverBootstrap, noVerifyUpstreams)
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
})
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)

View File

@ -218,21 +218,25 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
func createResolvers(
ctx context.Context, logger *logrus.Entry,
cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
cfg config.UpstreamGroup, bootstrap *Bootstrap,
) ([]Resolver, error) {
resolvers := make([]Resolver, 0, len(cfg.Upstreams))
if len(cfg.GroupUpstreams()) == 0 {
return nil, fmt.Errorf("no external DNS resolvers configured for group %s", cfg.Name)
}
resolvers := make([]Resolver, 0, len(cfg.GroupUpstreams()))
hasValidResolvers := false
for _, u := range cfg.Upstreams {
resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams)
for _, u := range cfg.GroupUpstreams() {
resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(u, cfg.Upstreams), bootstrap)
if err != nil {
logger.Warnf("upstream group %s: %v", cfg.Name, err)
continue
}
if shoudVerifyUpstreams {
err = testResolver(ctx, resolver)
if cfg.StartVerify {
err = resolver.testResolve(ctx)
if err != nil {
logger.Warn(err)
} else {
@ -243,7 +247,7 @@ func createResolvers(
resolvers = append(resolvers, resolver)
}
if shoudVerifyUpstreams && !hasValidResolvers {
if cfg.StartVerify && !hasValidResolvers {
return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name)
}

View File

@ -1,20 +1,33 @@
package resolver_test
package resolver
import (
"context"
"testing"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var defaultUpstreamsConfig config.Upstreams
func init() {
log.Silence()
redis.SetLogger(NoLogs{})
var err error
defaultUpstreamsConfig, err = config.WithDefaults[config.Upstreams]()
if err != nil {
panic(err)
}
// Shorter timeout for tests
defaultUpstreamsConfig.Timeout = config.Duration(50 * time.Millisecond)
}
func TestResolver(t *testing.T) {

View File

@ -24,17 +24,16 @@ type StrictResolver struct {
configurable[*config.UpstreamGroup]
typed
groupName string
resolvers []*upstreamResolverStatus
}
// NewStrictResolver creates a new strict resolver instance
func NewStrictResolver(
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap)
if err != nil {
return nil, err
}
@ -45,18 +44,11 @@ func NewStrictResolver(
func newStrictResolver(
cfg config.UpstreamGroup, resolvers []Resolver,
) *StrictResolver {
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
r := StrictResolver{
configurable: withConfig(&cfg),
typed: withType(strictResolverType),
groupName: cfg.Name,
resolvers: resolverStatuses,
resolvers: newUpstreamResolverStatuses(resolvers),
}
return &r
@ -72,7 +64,7 @@ func (r *StrictResolver) String() string {
result[i] = fmt.Sprintf("%s", s.resolver)
}
return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.groupName, strings.Join(result, ","))
return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(result, ","))
}
// Resolve sends the query request in a strict order to the upstream resolvers

View File

@ -54,11 +54,11 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
JustBeforeEach(func() {
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify)
upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.StartVerify = sutVerify
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap)
})
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
@ -94,24 +94,21 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
return
})
defer mockUpstream.Close()
DeferCleanup(mockUpstream.Close)
upstreams := []config.Upstream{
upstreams = []config.Upstream{
{Host: "wrong"},
mockUpstream.Start(),
}
})
_, err := NewStrictResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
systemResolverBootstrap, verifyUpstreams)
It("should start normally", func() {
Expect(err).Should(Not(HaveOccurred()))
})
})

View File

@ -29,11 +29,34 @@ const (
retryAttempts = 3
)
type upstreamConfig struct {
config.Upstreams
config.Upstream
}
func newUpstreamConfig(upstream config.Upstream, cfg config.Upstreams) upstreamConfig {
return upstreamConfig{cfg, upstream}
}
func (c upstreamConfig) String() string {
return c.Upstream.String()
}
// IsEnabled implements `config.Configurable`.
func (c upstreamConfig) IsEnabled() bool {
return true
}
// LogConfig implements `config.Configurable`.
func (c upstreamConfig) LogConfig(logger *logrus.Entry) {
logger.Info(c.Upstream)
}
// UpstreamResolver sends request to external DNS server
type UpstreamResolver struct {
typed
configurable[upstreamConfig]
upstream config.Upstream
upstreamClient upstreamClient
bootstrap *Bootstrap
}
@ -54,7 +77,7 @@ type httpUpstreamClient struct {
host string
}
func createUpstreamClient(cfg config.Upstream) upstreamClient {
func createUpstreamClient(cfg upstreamConfig) upstreamClient {
tlsConfig := tls.Config{
ServerName: cfg.Host,
MinVersion: tls.VersionTLS12,
@ -187,11 +210,11 @@ func (r *dnsUpstreamClient) callExternal(
// NewUpstreamResolver creates new resolver instance
func NewUpstreamResolver(
ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool,
ctx context.Context, cfg upstreamConfig, bootstrap *Bootstrap,
) (*UpstreamResolver, error) {
r := newUpstreamResolverUnchecked(upstream, bootstrap)
r := newUpstreamResolverUnchecked(cfg, bootstrap)
if verify {
if cfg.StartVerify {
_, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil {
return nil, err
@ -202,30 +225,34 @@ func NewUpstreamResolver(
}
// newUpstreamResolverUnchecked creates new resolver instance without validating the upstream
func newUpstreamResolverUnchecked(upstream config.Upstream, bootstrap *Bootstrap) *UpstreamResolver {
upstreamClient := createUpstreamClient(upstream)
func newUpstreamResolverUnchecked(cfg upstreamConfig, bootstrap *Bootstrap) *UpstreamResolver {
upstreamClient := createUpstreamClient(cfg)
return &UpstreamResolver{
typed: withType("upstream"),
typed: withType("upstream"),
configurable: withConfig(cfg),
upstream: upstream,
upstreamClient: upstreamClient,
bootstrap: bootstrap,
}
}
// IsEnabled implements `config.Configurable`.
func (r *UpstreamResolver) IsEnabled() bool {
return true
}
// LogConfig implements `config.Configurable`.
func (r *UpstreamResolver) LogConfig(logger *logrus.Entry) {
logger.Info(r.upstream)
}
func (r UpstreamResolver) String() string {
return fmt.Sprintf("%s '%s'", r.Type(), r.upstream)
return fmt.Sprintf("%s '%s'", r.Type(), r.cfg)
}
func (r *UpstreamResolver) log() *logrus.Entry {
return r.typed.log().WithField("upstream", r.cfg.String())
}
// testResolve sends a test query to verify the upstream is reachable and working
func (r *UpstreamResolver) testResolve(ctx context.Context) error {
// example.com MUST always resolve. See SUDN resolver
request := newRequest("example.com.", dns.Type(dns.TypeA))
_, err := r.Resolve(ctx, request)
return err
}
// Resolve calls external resolver
@ -242,15 +269,21 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
err = retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration())
defer cancel()
ip = ips.Current()
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path)
ctx := ctx // make sure we don't overwrite the outer function's context
if r.cfg.Timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
defer cancel()
}
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
if err != nil {
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.cfg, upstreamURL, err)
}
resp = response
@ -266,7 +299,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
retry.RetryIf(isTimeout),
retry.OnRetry(func(n uint, err error) {
r.log().WithFields(logrus.Fields{
"upstream": r.upstream.String(),
"upstream": r.cfg.String(),
"upstream_ip": ip.String(),
"question": util.QuestionToString(request.Req.Question),
"attempt": fmt.Sprintf("%d/%d", n+1, retryAttempts),
@ -275,25 +308,20 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
ips.Next()
}))
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
// Make the error more user friendly than just "context deadline exceeded"
err = fmt.Errorf("timeout (%w)", err)
}
return nil, err
}
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.cfg)}, nil
}
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
"upstream": r.cfg.String(),
"upstream_ip": ip.String(),
"protocol": request.Protocol,
"net": r.upstream.Net,
"net": r.cfg.Net,
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
}

View File

@ -21,7 +21,7 @@ import (
var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
var (
sut *UpstreamResolver
sutConfig config.Upstream
sutConfig upstreamConfig
ctx context.Context
cancelFn context.CancelFunc
@ -31,7 +31,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.Upstream{Host: "localhost"}
sutConfig = newUpstreamConfig(config.Upstream{Host: "localhost"}, defaultUpstreamsConfig)
})
JustBeforeEach(func() {
@ -66,8 +66,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
sutConfig.Upstream = mockUpstream.Start()
sut := newUpstreamResolverUnchecked(sutConfig, nil)
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
@ -76,7 +76,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream))),
)
})
})
@ -85,8 +85,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerError(dns.RcodeNameError)
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
sutConfig.Upstream = mockUpstream.Start()
sut := newUpstreamResolverUnchecked(sutConfig, nil)
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
@ -94,7 +94,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream))),
)
})
})
@ -104,8 +104,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
return nil
})
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
sutConfig.Upstream = mockUpstream.Start()
sut := newUpstreamResolverUnchecked(sutConfig, nil)
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
@ -114,27 +114,27 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
When("Timeout occurs", func() {
var counter int32
var attemptsWithTimeout int32
var sut *UpstreamResolver
BeforeEach(func() {
resolveFn := func(request *dns.Msg) (response *dns.Msg) {
atomic.AddInt32(&counter, 1)
timeout := sutConfig.Timeout.ToDuration() // avoid data race
resolveFn := func(request *dns.Msg) *dns.Msg {
// timeout on first x attempts
if atomic.LoadInt32(&counter) <= atomic.LoadInt32(&attemptsWithTimeout) {
time.Sleep(110 * time.Millisecond)
if atomic.AddInt32(&counter, 1) <= atomic.LoadInt32(&attemptsWithTimeout) {
time.Sleep(2 * timeout)
}
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.122")
Expect(err).Should(Succeed())
return response
}
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(resolveFn)
DeferCleanup(mockUpstream.Close)
upstream := mockUpstream.Start()
sut = newUpstreamResolverUnchecked(upstream, nil)
sut.upstreamClient.(*dnsUpstreamClient).udpClient.Timeout = 100 * time.Millisecond
sutConfig.Upstream = mockUpstream.Start()
})
It("should perform a retry with 3 attempts", func() {
By("2 attempts with timeout -> should resolve with third attempt", func() {
atomic.StoreInt32(&counter, 0)
@ -166,7 +166,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutConfig = mockUpstream.Start()
sutConfig.Upstream = mockUpstream.Start()
})
It("should retry with UDP", func() {
@ -188,8 +188,6 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Describe("Using Dns over HTTP (DOH) upstream", func() {
var (
sut *UpstreamResolver
upstream config.Upstream
respFn func(request *dns.Msg) (response *dns.Msg)
modifyHTTPRespFn func(w http.ResponseWriter)
)
@ -205,8 +203,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
JustBeforeEach(func() {
upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn)
sut = newUpstreamResolverUnchecked(upstream, nil)
sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn)
sut = newUpstreamResolverUnchecked(sutConfig, nil)
// use insecure certificates for test doh upstream
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
@ -224,7 +222,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
HaveReason(fmt.Sprintf("RESOLVED (https://%s:%d)", upstream.Host, upstream.Port)),
HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream)),
))
})
})
@ -267,10 +265,12 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
When("Configured DOH resolver does not respond", func() {
JustBeforeEach(func() {
sut = newUpstreamResolverUnchecked(config.Upstream{
sutConfig.Upstream = config.Upstream{
Net: config.NetProtocolHttps,
Host: "wronghost.example.com",
}, systemResolverBootstrap)
}
sut = newUpstreamResolverUnchecked(sutConfig, systemResolverBootstrap)
})
It("should return error", func() {
_, err := sut.Resolve(ctx, newRequest("example.com.", A))

View File

@ -2,6 +2,7 @@ package resolver
import (
"context"
"errors"
"fmt"
"strings"
@ -23,15 +24,15 @@ type UpstreamTreeResolver struct {
branches map[string]Resolver
}
func NewUpstreamTreeResolver(cfg config.Upstreams, branches map[string]Resolver) (Resolver, error) {
func NewUpstreamTreeResolver(ctx context.Context, cfg config.Upstreams, bootstrap *Bootstrap) (Resolver, error) {
if len(cfg.Groups[upstreamDefaultCfgName]) == 0 {
return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+
"Please configure at least one under '%s' configuration name", upstreamDefaultCfgName)
}
if len(branches) != len(cfg.Groups) {
return nil, fmt.Errorf("amount of passed in branches (%d) does not match amount of configured upstream groups (%d)",
len(branches), len(cfg.Groups))
branches, err := createUpstreamBranches(ctx, cfg, bootstrap)
if err != nil {
return nil, err
}
if len(branches) == 1 {
@ -51,6 +52,45 @@ func NewUpstreamTreeResolver(cfg config.Upstreams, branches map[string]Resolver)
return &r, nil
}
func createUpstreamBranches(
ctx context.Context, cfg config.Upstreams, bootstrap *Bootstrap,
) (map[string]Resolver, error) {
branches := make(map[string]Resolver, len(cfg.Groups))
errs := make([]error, 0, len(cfg.Groups))
for group, upstreams := range cfg.Groups {
var (
upstream Resolver
err error
)
groupConfig := config.NewUpstreamGroup(group, cfg, upstreams)
switch cfg.Strategy {
case config.UpstreamStrategyParallelBest:
fallthrough
case config.UpstreamStrategyRandom:
upstream, err = NewParallelBestResolver(ctx, groupConfig, bootstrap)
case config.UpstreamStrategyStrict:
upstream, err = NewStrictResolver(ctx, groupConfig, bootstrap)
}
if err != nil {
errs = append(errs, fmt.Errorf("group %s: %w", group, err))
continue
}
branches[group] = upstream
}
if len(errs) != 0 {
return nil, errors.Join(errs...)
}
return branches, nil
}
func (r *UpstreamTreeResolver) Name() string {
return r.String()
}
@ -77,7 +117,7 @@ func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Reque
}
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
groups := []string{}
groups := make([]string, 0, len(r.branches))
clientIP := request.ClientIP.String()
// try IP

View File

@ -2,38 +2,40 @@ package resolver
import (
"context"
"fmt"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/mock"
)
var mockRes *mockResolver
var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
var (
sut Resolver
sutConfig config.Upstreams
branches map[string]Resolver
err error
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
mockRes = &mockResolver{}
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = defaultUpstreamsConfig
})
JustBeforeEach(func() {
sut, err = NewUpstreamTreeResolver(sutConfig, branches)
sut, err = NewUpstreamTreeResolver(ctx, sutConfig, systemResolverBootstrap)
})
When("has no configuration", func() {
When("it has no configuration", func() {
BeforeEach(func() {
sutConfig = config.Upstreams{}
})
@ -45,67 +47,56 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
})
})
When("amount of passed in resolvers doesn't match amount of groups", func() {
When("it has only default group", func() {
BeforeEach(func() {
sutConfig = config.Upstreams{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
sutConfig.Groups = config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
}
branches = map[string]Resolver{}
})
It("should return error", func() {
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(
"amount of passed in branches (0) does not match amount of configured upstream groups (1)"))
Expect(sut).To(BeNil())
})
})
When("strategy is parallel", func() {
BeforeEach(func() {
sutConfig.Strategy = config.UpstreamStrategyParallelBest
})
When("has only default group", func() {
BeforeEach(func() {
sutConfig = config.Upstreams{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
},
}
branches = createBranchesMock(sutConfig)
})
Describe("Type", func() {
It("does not return error", func() {
It("returns the resolver directly", func() {
Expect(err).ToNot(HaveOccurred())
_, ok := sut.(*ParallelBestResolver)
Expect(ok).Should(BeTrue())
})
It("follows conventions", func() {
expectValidResolverType(sut)
})
When("strategy is strict", func() {
BeforeEach(func() {
sutConfig.Strategy = config.UpstreamStrategyStrict
})
It("returns mock", func() {
Expect(sut.Type()).To(Equal("mock"))
It("returns the resolver directly", func() {
Expect(err).ToNot(HaveOccurred())
_, ok := sut.(*StrictResolver)
Expect(ok).Should(BeTrue())
})
})
})
When("has multiple groups", func() {
When("it has multiple groups", func() {
BeforeEach(func() {
sutConfig = config.Upstreams{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
"test": {
{Host: "some-resolver"},
},
sutConfig.Groups = config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
"test": {
{Host: "some-resolver"},
},
}
branches = createBranchesMock(sutConfig)
})
Describe("Type", func() {
It("does not return error", func() {
Expect(err).ToNot(HaveOccurred())
@ -117,6 +108,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Type()).To(Equal(upstreamTreeResolverType))
})
})
Describe("Configuration output", func() {
It("should return configuration", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
@ -140,62 +132,41 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
})
})
When("start verify is enabled", func() {
BeforeEach(func() {
sutConfig.StartVerify = true
})
It("should fail", func() {
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(ContainSubstring("no valid upstream")))
Expect(sut).To(BeNil())
})
})
When("client specific resolvers are defined", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
groups := map[string]string{
upstreamDefaultCfgName: "127.0.0.1",
"laptop": "127.0.0.2",
"client-*-m": "127.0.0.3",
"client[0-9]": "127.0.0.4",
"192.168.178.33": "127.0.0.5",
"10.43.8.67/28": "127.0.0.6",
"name-matches1": "127.0.0.7",
"name-matches*": "127.0.0.8",
}
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig.Groups = make(config.UpstreamGroups, len(groups))
sutConfig = config.Upstreams{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{}},
"laptop": {config.Upstream{}},
"client-*-m": {config.Upstream{}},
"client[0-9]": {config.Upstream{}},
"192.168.178.33": {config.Upstream{}},
"10.43.8.67/28": {config.Upstream{}},
"name-matches1": {config.Upstream{}},
"name-matches*": {config.Upstream{}},
}}
for group, ip := range groups {
Expect(ip).ShouldNot(BeNil())
createMockResolver := func(group string) *mockResolver {
resolver := &mockResolver{}
server := NewMockUDPUpstreamServer().WithAnswerRR(fmt.Sprintf("example.com 123 IN A %s", ip))
sutConfig.Groups[group] = []config.Upstream{server.Start()}
resolver.On("Resolve", mock.Anything)
resolver.ResponseFn = func(req *dns.Msg) *dns.Msg {
res := new(dns.Msg)
res.SetReply(req)
ptr := new(dns.PTR)
ptr.Ptr = group
ptr.Hdr = util.CreateHeader(req.Question[0], 1)
res.Answer = append(res.Answer, ptr)
return res
}
return resolver
DeferCleanup(server.Close)
}
branches = map[string]Resolver{
upstreamDefaultCfgName: nil,
"laptop": nil,
"client-*-m": nil,
"client[0-9]": nil,
"192.168.178.33": nil,
"10.43.8.67/28": nil,
"name-matches1": nil,
"name-matches*": nil,
}
for group := range branches {
branches[group] = createMockResolver(group)
}
Expect(branches).To(HaveLen(8))
})
It("Should use default if client name or IP don't match", func() {
@ -204,7 +175,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "default"),
BeDNSRecord("example.com.", A, groups["default"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -215,7 +186,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
BeDNSRecord("example.com.", A, groups["laptop"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -226,7 +197,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client-*-m"),
BeDNSRecord("example.com.", A, groups["client-*-m"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -237,7 +208,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client[0-9]"),
BeDNSRecord("example.com.", A, groups["client[0-9]"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -248,7 +219,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
BeDNSRecord("example.com.", A, groups["192.168.178.33"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -259,7 +230,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
BeDNSRecord("example.com.", A, groups["192.168.178.33"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -270,7 +241,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
BeDNSRecord("example.com.", A, groups["10.43.8.67/28"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -281,7 +252,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
BeDNSRecord("example.com.", A, groups["192.168.178.33"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -292,7 +263,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
BeDNSRecord("example.com.", A, groups["laptop"]),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
@ -307,8 +278,8 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
Should(
SatisfyAll(
SatisfyAny(
BeDNSRecord("example.com.", A, "name-matches1"),
BeDNSRecord("example.com.", A, "name-matches*"),
BeDNSRecord("example.com.", A, groups["name-matches1"]),
BeDNSRecord("example.com.", A, groups["name-matches*"]),
),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -319,13 +290,3 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
})
})
})
func createBranchesMock(cfg config.Upstreams) map[string]Resolver {
branches := make(map[string]Resolver, len(cfg.Groups))
for name := range cfg.Groups {
branches[name] = mockRes
}
return branches
}

View File

@ -388,22 +388,14 @@ func createQueryResolver(
cfg *config.Config,
bootstrap *resolver.Bootstrap,
redisClient *redis.Client,
) (r resolver.ChainedResolver, err error) {
upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap)
if uErr != nil {
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
}
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches)
) (resolver.ChainedResolver, error) {
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(ctx, cfg.Upstreams, bootstrap)
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(
ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream,
)
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, cfg.Upstreams, bootstrap)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(ctx, cfg.Conditional, cfg.Upstreams, bootstrap)
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
err = multierror.Append(
err := multierror.Append(
multierror.Prefix(utErr, "upstream tree resolver: "),
multierror.Prefix(blErr, "blocking resolver: "),
multierror.Prefix(cnErr, "client names resolver: "),
@ -414,7 +406,7 @@ func createQueryResolver(
return nil, err
}
r = resolver.Chain(
r := resolver.Chain(
resolver.NewFilteringResolver(cfg.Filtering),
resolver.NewFQDNOnlyResolver(cfg.FQDNOnly),
resolver.NewECSResolver(cfg.ECS),
@ -434,42 +426,6 @@ func createQueryResolver(
return r, nil
}
func createUpstreamBranches(
ctx context.Context,
cfg *config.Config,
bootstrap *resolver.Bootstrap,
) (map[string]resolver.Resolver, error) {
upstreamBranches := make(map[string]resolver.Resolver, len(cfg.Upstreams.Groups))
var uErr error
for group, upstreams := range cfg.Upstreams.Groups {
var (
upstream resolver.Resolver
err error
)
groupConfig := config.UpstreamGroup{
Name: group,
Upstreams: upstreams,
}
switch cfg.Upstreams.Strategy {
case config.UpstreamStrategyParallelBest:
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyRandom:
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
}
upstreamBranches[group] = upstream
uErr = multierror.Append(multierror.Prefix(err, fmt.Sprintf("group %s: ", group))).ErrorOrNil()
}
return upstreamBranches, uErr
}
func (s *Server) registerDNSHandlers() {
for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)

View File

@ -697,62 +697,6 @@ var _ = Describe("Running DNS server", func() {
})
})
Describe("NewServer with strict upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.Upstreams{
Strategy: config.UpstreamStrategyStrict,
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(branches).ToNot(BeNil())
Expect(branches).To(HaveLen(1))
_ = branches["default"].(*resolver.StrictResolver)
})
})
Describe("NewServer with random upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.Upstreams{
Strategy: config.UpstreamStrategyRandom,
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(branches).ToNot(BeNil())
Expect(branches).To(HaveLen(1))
_ = branches["default"].(*resolver.ParallelBestResolver)
})
})
Describe("create query resolver", func() {
When("some upstream returns error", func() {
It("create query resolver should return error", func() {
r, err := createQueryResolver(ctx, &config.Config{
StartVerifyUpstream: true,
Upstreams: config.Upstreams{
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
},
nil, nil)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(ContainSubstring("creation of upstream branches failed: ")))
Expect(r).To(BeNil())
})
})
})
Describe("resolve client IP", func() {
Context("UDP address", func() {
It("should correct resolve client IP", func() {