feat: add upstream strategy `strict` (#1093)

This commit is contained in:
DerRockWolf 2023-08-21 09:50:23 +02:00 committed by GitHub
parent 39208d860e
commit c112e86740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1233 additions and 245 deletions

View File

@ -123,6 +123,10 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
type QueryLogField string
// UpstreamStrategy data field to be logged
// ENUM(parallel_best,strict)
type UpstreamStrategy uint8
//nolint:gochecknoglobals
var netDefaultPort = map[NetProtocol]uint16{
NetProtocolTcpUdp: udpPort,

View File

@ -476,3 +476,83 @@ func (x *StartStrategyType) UnmarshalText(text []byte) error {
*x = tmp
return nil
}
const (
// UpstreamStrategyParallelBest is a UpstreamStrategy of type Parallel_best.
UpstreamStrategyParallelBest UpstreamStrategy = iota
// UpstreamStrategyStrict is a UpstreamStrategy of type Strict.
UpstreamStrategyStrict
)
var ErrInvalidUpstreamStrategy = fmt.Errorf("not a valid UpstreamStrategy, try [%s]", strings.Join(_UpstreamStrategyNames, ", "))
const _UpstreamStrategyName = "parallel_beststrict"
var _UpstreamStrategyNames = []string{
_UpstreamStrategyName[0:13],
_UpstreamStrategyName[13:19],
}
// UpstreamStrategyNames returns a list of possible string values of UpstreamStrategy.
func UpstreamStrategyNames() []string {
tmp := make([]string, len(_UpstreamStrategyNames))
copy(tmp, _UpstreamStrategyNames)
return tmp
}
// UpstreamStrategyValues returns a list of the values for UpstreamStrategy
func UpstreamStrategyValues() []UpstreamStrategy {
return []UpstreamStrategy{
UpstreamStrategyParallelBest,
UpstreamStrategyStrict,
}
}
var _UpstreamStrategyMap = map[UpstreamStrategy]string{
UpstreamStrategyParallelBest: _UpstreamStrategyName[0:13],
UpstreamStrategyStrict: _UpstreamStrategyName[13:19],
}
// String implements the Stringer interface.
func (x UpstreamStrategy) String() string {
if str, ok := _UpstreamStrategyMap[x]; ok {
return str
}
return fmt.Sprintf("UpstreamStrategy(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x UpstreamStrategy) IsValid() bool {
_, ok := _UpstreamStrategyMap[x]
return ok
}
var _UpstreamStrategyValue = map[string]UpstreamStrategy{
_UpstreamStrategyName[0:13]: UpstreamStrategyParallelBest,
_UpstreamStrategyName[13:19]: UpstreamStrategyStrict,
}
// ParseUpstreamStrategy attempts to convert a string to a UpstreamStrategy.
func ParseUpstreamStrategy(name string) (UpstreamStrategy, error) {
if x, ok := _UpstreamStrategyValue[name]; ok {
return x, nil
}
return UpstreamStrategy(0), fmt.Errorf("%s is %w", name, ErrInvalidUpstreamStrategy)
}
// MarshalText implements the text marshaller method.
func (x UpstreamStrategy) MarshalText() ([]byte, error) {
return []byte(x.String()), nil
}
// UnmarshalText implements the text unmarshaller method.
func (x *UpstreamStrategy) UnmarshalText(text []byte) error {
name := string(text)
tmp, err := ParseUpstreamStrategy(name)
if err != nil {
return err
}
*x = tmp
return nil
}

View File

@ -8,8 +8,9 @@ const UpstreamDefaultCfgName = "default"
// UpstreamsConfig upstream servers configuration
type UpstreamsConfig struct {
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
}
type UpstreamGroups map[string][]Upstream
@ -22,7 +23,7 @@ func (c *UpstreamsConfig) IsEnabled() bool {
// LogConfig implements `config.Configurable`.
func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) {
logger.Info("timeout: ", c.Timeout)
logger.Info("strategy: ", c.Strategy)
logger.Info("groups:")
for name, upstreams := range c.Groups {

View File

@ -18,6 +18,10 @@ upstreams:
# or single ip address / client subnet as CIDR notation
laptop*:
- 123.123.123.123
# optional: Determines what strategy blocky uses to choose the upstream servers.
# accepted: parallel_best, strict
# default: parallel_best
strategy: parallel_best
# optional: timeout to query the upstream resolver. Default: 2s
timeout: 2s

View File

@ -79,9 +79,9 @@ following network protocols (net part of the resolver URL):
!!! hint
You can (and should!) configure multiple DNS resolvers. Blocky picks 2 random resolvers from the list for each query and
returns the answer from the fastest one. This improves your network speed and increases your privacy - your DNS traffic
will be distributed over multiple providers.
You can (and should!) configure multiple DNS resolvers.
Per default blocky uses the `parallel_best` upstream strategy where blocky picks 2 random resolvers from the list for each query and
returns the answer from the fastest one.
Each resolver must be defined as a string in following format: `[net:]host:[port][/path][#commonName]`.
@ -92,13 +92,15 @@ Each resolver must be defined as a string in following format: `[net:]host:[port
| port | int (1 - 65535) | no | 53 for udp/tcp, 853 for tcp-tls and 443 for https |
| commonName | string | no | the host value |
The commonName parameter overrides the expected certificate common name value used for verification.
The `commonName` parameter overrides the expected certificate common name value used for verification.
Blocky needs at least the configuration of the **default** group. This group will be used as a fallback, if no client
specific resolver configuration is available.
!!! note
Blocky needs at least the configuration of the **default** group with at least one upstream DNS server. This group will be used as a fallback, if no client
specific resolver configuration is available.
You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address or a client subnet as
CIDR notation.
See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which public free DNS server you could use.
You can specify multiple upstream groups (additional to the `default` group) to use different upstream servers for different clients, based on client name (see [Client name lookup](#client-name-lookup)), client IP address or client subnet (as CIDR).
!!! tip
@ -121,15 +123,38 @@ CIDR notation.
- 9.9.9.9
```
Use `123.123.123.123` as single upstream DNS resolver for client laptop-home,
`1.1.1.1` and `9.9.9.9` for all clients in the sub-net `10.43.8.67/28` and 4 resolvers (default) for all others clients.
The above example results in:
!!! note
- `123.123.123.123` as the only upstream DNS resolver for clients with a name starting with "laptop"
- `1.1.1.1` and `9.9.9.9` for all clients in the subnet `10.43.8.67/28`
- 4 resolvers (default) for all others clients.
** Blocky needs at least one upstream DNS server **
The logic determining what group a client belongs to follows a strict order: IP, client name, CIDR
See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which
public free DNS server you could use.
If a client matches multiple client name or CIDR groups, a warning is logged and the first found group is used.
### Upstream strategy
Blocky supports different upstream strategies (default `parallel_best`) that determine how and to which upstream DNS servers requests are forwarded.
Currently available strategies:
- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
If an upstream failed to answer within the last hour, it is less likely to be chosen for the race.
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers
(When using 10 upstream servers, each upstream will get on average 20% of the DNS requests)
- `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on.
!!! example
```yaml
upstreams:
strategy: strict
groups:
default:
- 1.2.3.4
- 9.8.7.6
```
### Upstream lookup timeout

View File

@ -70,7 +70,9 @@ var _ = Describe("Upstream resolver configuration tests", func() {
It("should not start", func() {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
Should(ContainElements(
ContainSubstring("creation of upstream branches failed: "),
ContainSubstring("no valid upstream for group default")))
})
})
When("'startVerifyUpstream' is true and upstream server as host name is not reachable", func() {
@ -89,7 +91,9 @@ var _ = Describe("Upstream resolver configuration tests", func() {
It("should not start", func() {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
Should(ContainElements(
ContainSubstring("creation of upstream branches failed: "),
ContainSubstring("no valid upstream for group default")))
})
})
})

2
go.mod
View File

@ -37,7 +37,7 @@ require (
require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
github.com/docker/go-connections v0.4.0
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
github.com/testcontainers/testcontainers-go v0.22.0

4
go.sum
View File

@ -17,8 +17,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/Microsoft/hcsshim v0.10.0-rc.8 h1:YSZVvlIIDD1UxQpJp0h+dnpLUw+TrY0cx8obKsp3bek=
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI=
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef h1:lg6zRor4+PZN1Pxqtieo/NMhd61ZdV1Z/+bFURWIVfU=
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
github.com/abice/go-enum v0.5.7 h1:vOrobjpce5D/x5hYNqrWRkFUXFk7A6BlsJyVy4BS1jM=
github.com/abice/go-enum v0.5.7/go.mod h1:FBDp+2Ygv9ZZzgcd+Gx3XbyClH7xxFfw8ghMrOpwu+A=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=

View File

@ -49,6 +49,13 @@ func (x ListCacheType) String() string {
return fmt.Sprintf("ListCacheType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x ListCacheType) IsValid() bool {
_, ok := _ListCacheTypeMap[x]
return ok
}
var _ListCacheTypeValue = map[string]ListCacheType{
_ListCacheTypeName[0:9]: ListCacheTypeBlacklist,
_ListCacheTypeName[9:18]: ListCacheTypeWhitelist,

View File

@ -49,6 +49,13 @@ func (x FormatType) String() string {
return fmt.Sprintf("FormatType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x FormatType) IsValid() bool {
_, ok := _FormatTypeMap[x]
return ok
}
var _FormatTypeValue = map[string]FormatType{
_FormatTypeName[0:4]: FormatTypeText,
_FormatTypeName[4:8]: FormatTypeJson,
@ -130,6 +137,13 @@ func (x Level) String() string {
return fmt.Sprintf("Level(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x Level) IsValid() bool {
_, ok := _LevelMap[x]
return ok
}
var _LevelValue = map[string]Level{
_LevelName[0:4]: LevelInfo,
_LevelName[4:9]: LevelTrace,

View File

@ -49,6 +49,13 @@ func (x RequestProtocol) String() string {
return fmt.Sprintf("RequestProtocol(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x RequestProtocol) IsValid() bool {
_, ok := _RequestProtocolMap[x]
return ok
}
var _RequestProtocolValue = map[string]RequestProtocol{
_RequestProtocolName[0:3]: RequestProtocolTCP,
_RequestProtocolName[3:6]: RequestProtocolUDP,
@ -151,6 +158,13 @@ func (x ResponseType) String() string {
return fmt.Sprintf("ResponseType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x ResponseType) IsValid() bool {
_, ok := _ResponseTypeMap[x]
return ok
}
var _ResponseTypeValue = map[string]ResponseType{
_ResponseTypeName[0:8]: ResponseTypeRESOLVED,
_ResponseTypeName[8:14]: ResponseTypeCACHED,

View File

@ -69,6 +69,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
JustBeforeEach(func() {
var err error
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)

View File

@ -66,10 +66,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
// where `ParallelBestResolver` uses its config, we can just use an empty one.
var pbCfg config.UpstreamsConfig
parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
if err != nil {
return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err)
}
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
// Always enable prefetching to avoid stalling user requests
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:

View File

@ -30,9 +30,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
JustBeforeEach(func() {
res, err := NewClientNamesResolver(sutConfig, nil, false)
var err error
sut, err = NewClientNamesResolver(sutConfig, nil, false)
Expect(err).Should(Succeed())
sut = res
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)

View File

@ -24,7 +24,7 @@ type ConditionalUpstreamResolver struct {
// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (ChainedResolver, error) {
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstream := range cfg.Mapping.Upstreams {

View File

@ -15,7 +15,7 @@ import (
var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() {
var (
sut ChainedResolver
sut *ConditionalUpstreamResolver
m *mockResolver
)

View File

@ -24,7 +24,7 @@ type CustomDNSResolver struct {
}
// NewCustomDNSResolver creates new resolver instance
func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver {
func NewCustomDNSResolver(cfg config.CustomDNSConfig) *CustomDNSResolver {
m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs))
reverse := make(map[string][]string, len(cfg.Mapping.HostIPs))

View File

@ -18,7 +18,7 @@ var _ = Describe("CustomDNSResolver", func() {
var (
TTL = uint32(time.Now().Second())
sut ChainedResolver
sut *CustomDNSResolver
m *mockResolver
cfg config.CustomDNSConfig
)

View File

@ -12,7 +12,7 @@ type EdeResolver struct {
typed
}
func NewEdeResolver(cfg config.EdeConfig) ChainedResolver {
func NewEdeResolver(cfg config.EdeConfig) *EdeResolver {
return &EdeResolver{
configurable: withConfig(&cfg),
typed: withType("extended_error_code"),

View File

@ -43,7 +43,7 @@ var _ = Describe("EdeResolver", func() {
}, nil)
}
sut = NewEdeResolver(sutConfig).(*EdeResolver)
sut = NewEdeResolver(sutConfig)
sut.Next(m)
})

View File

@ -14,7 +14,7 @@ type FilteringResolver struct {
typed
}
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
return &FilteringResolver{
configurable: withConfig(&cfg),
typed: withType("filtering"),

View File

@ -31,7 +31,7 @@ var _ = Describe("FilteringResolver", func() {
})
JustBeforeEach(func() {
sut = NewFilteringResolver(sutConfig).(*FilteringResolver)
sut = NewFilteringResolver(sutConfig)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)

View File

@ -58,7 +58,7 @@ func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
}
// NewMetricsResolver creates a new intance of the MetricsResolver type
func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver {
func NewMetricsResolver(cfg config.MetricsConfig) *MetricsResolver {
m := MetricsResolver{
configurable: withConfig(&cfg),
typed: withType("metrics"),

View File

@ -30,7 +30,7 @@ var _ = Describe("MetricResolver", func() {
})
BeforeEach(func() {
sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver)
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)

View File

@ -10,8 +10,8 @@ var NoResponse = &model.Response{} //nolint:gochecknoglobals
// NoOpResolver is used to finish a resolver branch as created in RewriterResolver
type NoOpResolver struct{}
func NewNoOpResolver() Resolver {
return NoOpResolver{}
func NewNoOpResolver() *NoOpResolver {
return &NoOpResolver{}
}
// Type implements `Resolver`.

View File

@ -8,7 +8,7 @@ import (
)
var _ = Describe("NoOpResolver", func() {
var sut NoOpResolver
var sut *NoOpResolver
Describe("Type", func() {
It("follows conventions", func() {
@ -17,7 +17,7 @@ var _ = Describe("NoOpResolver", func() {
})
BeforeEach(func() {
sut = NewNoOpResolver().(NoOpResolver)
sut = NewNoOpResolver()
})
Describe("Resolving", func() {

View File

@ -120,12 +120,12 @@ func NewParallelBestResolver(
resolverGroups[name] = group
}
return newParallelBestResolver(cfg, resolverGroups)
return newParallelBestResolver(cfg, resolverGroups), nil
}
func newParallelBestResolver(
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
) (*ParallelBestResolver, error) {
) *ParallelBestResolver {
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
for groupName, resolvers := range resolverGroups {
@ -138,11 +138,6 @@ func newParallelBestResolver(
resolversPerClient[groupName] = resolverStatuses
}
if len(resolversPerClient[upstreamDefaultCfgName]) == 0 {
return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+
"Please configure at least one under '%s' configuration name", upstreamDefaultCfgName)
}
r := ParallelBestResolver{
configurable: withConfig(&cfg),
typed: withType(parallelResolverType),
@ -150,7 +145,7 @@ func newParallelBestResolver(
resolversPerClient: resolversPerClient,
}
return &r, nil
return &r
}
func (r *ParallelBestResolver) Name() string {
@ -172,45 +167,16 @@ func (r *ParallelBestResolver) String() string {
return fmt.Sprintf("parallel upstreams '%s'", strings.Join(result, "; "))
}
func (r *ParallelBestResolver) resolversForClient(request *model.Request) (result []*upstreamResolverStatus) {
clientIP := request.ClientIP.String()
// try client names
for _, cName := range request.ClientNames {
for clientDefinition, upstreams := range r.resolversPerClient {
if cName != clientIP && util.ClientNameMatchesGroupName(clientDefinition, cName) {
result = append(result, upstreams...)
}
}
}
// try IP
upstreams, found := r.resolversPerClient[clientIP]
if found {
result = append(result, upstreams...)
}
// try CIDR
for cidr, upstreams := range r.resolversPerClient {
if util.CidrContainsIP(cidr, request.ClientIP) {
result = append(result, upstreams...)
}
}
if len(result) == 0 {
// return default
result = r.resolversPerClient[upstreamDefaultCfgName]
}
return result
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType)
resolvers := r.resolversForClient(request)
var resolvers []*upstreamResolverStatus
for _, r := range r.resolversPerClient {
resolvers = r
break
}
if len(resolvers) == 1 {
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
@ -233,21 +199,19 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
go r2.resolve(request, ch)
//nolint: gosimple
for len(collectedErrors) < resolverCount {
select {
case result := <-ch:
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, result.err)
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
result := <-ch
return result.response, nil
}
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, result.err)
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
}
}
@ -266,9 +230,13 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
const errorWindowInSec = 60
var choices []weightedrand.Choice[*upstreamResolverStatus, uint]
choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
for _, res := range in {
if exclude == res.resolver {
continue
}
var weight float64 = errorWindowInSec
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
@ -277,9 +245,7 @@ func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamRes
weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes()))
}
if exclude != res.resolver {
choices = append(choices, weightedrand.NewChoice(res, uint(weight)))
}
choices = append(choices, weightedrand.NewChoice(res, uint(weight)))
}
c, err := weightedrand.NewChooser(choices...)

View File

@ -84,16 +84,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
})
When("default upstream resolvers are not defined", func() {
It("should fail on startup", func() {
_, err := NewParallelBestResolver(config.UpstreamsConfig{
Groups: config.UpstreamGroups{},
}, nil, noVerifyUpstreams)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
})
})
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
@ -234,124 +224,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
})
})
When("client specific resolvers are defined", func() {
When("client name matches", func() {
BeforeEach(func() {
defaultMockUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(defaultMockUpstream.Close)
clientSpecificExactMockUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(clientSpecificExactMockUpstream.Close)
clientSpecificWildcardMockUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.124")
DeferCleanup(clientSpecificWildcardMockUpstream.Close)
clientSpecificIPMockUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.125")
DeferCleanup(clientSpecificIPMockUpstream.Close)
clientSpecificCIRDMockUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.126")
DeferCleanup(clientSpecificCIRDMockUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {defaultMockUpstream.Start()},
"laptop": {clientSpecificExactMockUpstream.Start()},
"client-*-m": {clientSpecificWildcardMockUpstream.Start()},
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
}
})
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches exact", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.124"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with range wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.124"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client IP matches", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "cl")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.125"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client IP/name matches", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "192.168.178.33")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.125"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.64", "cl")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.126"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
})
When("only 1 upstream resolvers is defined", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
@ -390,17 +262,20 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ := NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{
sut, _ = NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}}, systemResolverBootstrap, noVerifyUpstreams)
}},
systemResolverBootstrap, noVerifyUpstreams)
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)
for i := 0; i < 1000; i++ {
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
"example.com", A, "123.123.100.100",
)))
var resolvers []*upstreamResolverStatus
for _, r := range sut.resolversPerClient {
resolvers = r
}
r1, r2 := pickRandom(resolvers)
res1 := r1.resolver
res2 := r2.resolver
Expect(res1).ShouldNot(Equal(res2))
@ -425,9 +300,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)
for i := 0; i < 100; i++ {
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
"example.com", A, "123.123.100.100",
)))
var resolvers []*upstreamResolverStatus
for _, r := range sut.resolversPerClient {
resolvers = r
}
r1, r2 := pickRandom(resolvers)
res1 := r1.resolver.(*UpstreamResolver)
res2 := r2.resolver.(*UpstreamResolver)
Expect(res1).ShouldNot(Equal(res2))

View File

@ -30,7 +30,7 @@ type QueryLoggingResolver struct {
}
// NewQueryLoggingResolver returns a new resolver instance
func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
logger := log.PrefixedLog(queryLoggingResolverType)
var writer querylog.Writer

View File

@ -63,7 +63,7 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig.SetDefaults() // not called when using a struct literal
}
sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver)
sut = NewQueryLoggingResolver(sutConfig)
DeferCleanup(func() { close(sut.logChan) })
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)

View File

@ -30,6 +30,7 @@ func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) Chain
return inner
}
// ensures that the rewrites map contains all rewrites in lower case
for k, v := range cfg.Rewrite {
cfg.Rewrite[strings.ToLower(k)] = strings.ToLower(v)
}

165
resolver/strict_resolver.go Normal file
View File

@ -0,0 +1,165 @@
package resolver
import (
"context"
"errors"
"fmt"
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
)
const (
strictResolverType = "strict"
)
// StrictResolver delegates the DNS message strictly to the first configured upstream resolver
// if it can't provide the answer in time the next resolver is used
type StrictResolver struct {
configurable[*config.UpstreamsConfig]
typed
resolversPerClient map[string][]*upstreamResolverStatus
}
// NewStrictResolver creates new resolver instance
func NewStrictResolver(
cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType)
upstreamResolvers := cfg.Groups
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
for name, upstreamCfgs := range upstreamResolvers {
group := make([]Resolver, 0, len(upstreamCfgs))
hasValidResolver := false
for _, u := range upstreamCfgs {
resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
if err != nil {
logger.Warnf("upstream group %s: %v", name, err)
continue
}
if shouldVerifyUpstreams {
err = testResolver(resolver)
if err != nil {
logger.Warn(err)
} else {
hasValidResolver = true
}
}
group = append(group, resolver)
}
if shouldVerifyUpstreams && !hasValidResolver {
return nil, fmt.Errorf("no valid upstream for group %s", name)
}
resolverGroups[name] = group
}
return newStrictResolver(cfg, resolverGroups), nil
}
func newStrictResolver(
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
) *StrictResolver {
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
for groupName, resolvers := range resolverGroups {
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
for _, r := range resolvers {
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
}
resolversPerClient[groupName] = resolverStatuses
}
r := StrictResolver{
configurable: withConfig(&cfg),
typed: withType(strictResolverType),
resolversPerClient: resolversPerClient,
}
return &r
}
func (r *StrictResolver) Name() string {
return r.String()
}
func (r *StrictResolver) String() string {
result := make([]string, 0, len(r.resolversPerClient))
for name, res := range r.resolversPerClient {
tmp := make([]string, len(res))
for i, s := range res {
tmp[i] = fmt.Sprintf("%s", s.resolver)
}
result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ",")))
}
return fmt.Sprintf("%s upstreams %q", strictResolverType, strings.Join(result, "; "))
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, strictResolverType)
var resolvers []*upstreamResolverStatus
for _, r := range r.resolversPerClient {
resolvers = r
break
}
// start with first resolver
for i := range resolvers {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// start in new go routine and cancel if
resolver := resolvers[i]
ch := make(chan requestResponse, resolverCount)
go resolver.resolve(request, ch)
select {
case <-ctx.Done():
// log debug/info that timeout exceeded, call `continue` to try next upstream
logger.WithField("resolver", resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
continue
case result := <-ch:
if result.err != nil {
// log error & call `continue` to try next upstream
logger.Debug("resolution failed from resolver, cause: ", result.err)
continue
}
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
}
}
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
}

View File

@ -0,0 +1,306 @@
package resolver
import (
"time"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("StrictResolver", Label("strictResolver"), func() {
const (
verifyUpstreams = true
noVerifyUpstreams = false
)
var (
sut *StrictResolver
sutMapping config.UpstreamGroups
sutVerify bool
err error
bootstrap *Bootstrap
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
}
sutVerify = noVerifyUpstreams
bootstrap = systemResolverBootstrap
})
JustBeforeEach(func() {
sutConfig := config.UpstreamsConfig{Groups: sutMapping}
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
})
config.GetConfig().Upstreams.Timeout = config.Duration(1000 * time.Millisecond)
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Type", func() {
It("should be correct", func() {
Expect(sut.Type()).ShouldNot(BeEmpty())
Expect(sut.Type()).Should(Equal(strictResolverType))
})
})
Describe("Name", func() {
It("should contain correct resolver", func() {
Expect(sut.Name()).ShouldNot(BeEmpty())
Expect(sut.Name()).Should(ContainSubstring(strictResolverType))
})
})
When("some default upstream resolvers cannot be reached", func() {
It("should start normally", func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
return
})
defer mockUpstream.Close()
upstream := config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
mockUpstream.Start(),
},
}
_, err := NewStrictResolver(config.UpstreamsConfig{
Groups: upstream,
}, systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
When("no upstream resolvers can be reached", func() {
BeforeEach(func() {
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
}
})
When("strict checking is enabled", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("strict checking is disabled", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
})
It("should start", func() {
Expect(err).Should(Not(HaveOccurred()))
})
})
})
Describe("Resolving request in strict order", func() {
When("2 Upstream resolvers are defined", func() {
When("Both are responding", func() {
When("they respond in time", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
})
It("Should use result from first one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("first upstream exceeds upstreamTimeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(time.Duration(config.GetConfig().Upstreams.Timeout) + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
})
It("should return response from next upstream", func() {
request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("all upstreams exceed upsteamTimeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
}
})
It("should return error", func() {
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
Expect(err).To(HaveOccurred())
})
})
})
When("Only second is working", func() {
BeforeEach(func() {
testUpstream1 := config.Upstream{Host: "wrong"}
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1, testUpstream2.Start()},
}
})
It("Should use result from second one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
When("None are working", func() {
BeforeEach(func() {
testUpstream1 := config.Upstream{Host: "wrong"}
testUpstream2 := config.Upstream{Host: "wrong"}
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {testUpstream1, testUpstream2},
}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
request := newRequest("example.com.", A)
_, err := sut.Resolve(request)
Expect(err).Should(HaveOccurred())
})
})
})
When("only 1 upstream resolvers is defined", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
mockUpstream.Start(),
},
}
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
})
})
})

View File

@ -92,7 +92,7 @@ type SpecialUseDomainNamesResolver struct {
configurable[*config.SUDNConfig]
}
func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) ChainedResolver {
func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) *SpecialUseDomainNamesResolver {
return &SpecialUseDomainNamesResolver{
typed: withType("special_use_domains"),
configurable: withConfig(&cfg),

View File

@ -41,7 +41,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut = NewSpecialUseDomainNamesResolver(sutConfig).(*SpecialUseDomainNamesResolver)
sut = NewSpecialUseDomainNamesResolver(sutConfig)
sut.Next(m)
})

View File

@ -0,0 +1,118 @@
package resolver
import (
"fmt"
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
)
const (
upstreamTreeResolverType = "upstream_tree"
)
type UpstreamTreeResolver struct {
configurable[*config.UpstreamsConfig]
typed
branches map[string]Resolver
}
func NewUpstreamTreeResolver(cfg config.UpstreamsConfig, branches map[string]Resolver) (Resolver, error) {
if len(cfg.Groups[upstreamDefaultCfgName]) == 0 {
return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+
"Please configure at least one under '%s' configuration name", upstreamDefaultCfgName)
}
if len(branches) != len(cfg.Groups) {
return nil, fmt.Errorf("amount of passed in branches (%d) does not match amount of configured upstream groups (%d)",
len(branches), len(cfg.Groups))
}
if len(branches) == 1 {
for _, r := range branches {
return r, nil
}
}
// return resolver that forwards request to specific resolver branch depending on the client
r := UpstreamTreeResolver{
configurable: withConfig(&cfg),
typed: withType(upstreamTreeResolverType),
branches: branches,
}
return &r, nil
}
func (r *UpstreamTreeResolver) Name() string {
return r.String()
}
func (r *UpstreamTreeResolver) String() string {
result := make([]string, 0, len(r.branches))
for group, res := range r.branches {
result = append(result, fmt.Sprintf("%s (%s)", group, res.Type()))
}
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
}
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
group := r.upstreamGroupByClient(request)
// delegate request to group resolver
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
return r.branches[group].Resolve(request)
}
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
groups := []string{}
clientIP := request.ClientIP.String()
// try IP
if _, exists := r.branches[clientIP]; exists {
return clientIP
}
// try client names
for _, name := range request.ClientNames {
for group := range r.branches {
if util.ClientNameMatchesGroupName(group, name) {
groups = append(groups, group)
}
}
}
// try CIDR (only if no client name matched)
if len(groups) == 0 {
for cidr := range r.branches {
if util.CidrContainsIP(cidr, request.ClientIP) {
groups = append(groups, cidr)
}
}
}
if len(groups) > 0 {
if len(groups) > 1 {
r.log().WithFields(logrus.Fields{
"clientNames": request.ClientNames,
"clientIP": clientIP,
"groups": groups,
}).Warn("client matches multiple groups")
}
return groups[0]
}
return upstreamDefaultCfgName
}

View File

@ -0,0 +1,328 @@
package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/mock"
)
var mockRes *mockResolver
var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
var (
sut Resolver
sutConfig config.UpstreamsConfig
branches map[string]Resolver
loggerHook *test.Hook
err error
)
BeforeEach(func() {
mockRes = &mockResolver{}
})
JustBeforeEach(func() {
sut, err = NewUpstreamTreeResolver(sutConfig, branches)
})
When("has no configuration", func() {
BeforeEach(func() {
sutConfig = config.UpstreamsConfig{}
})
It("should return error", func() {
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(ContainSubstring("no external DNS resolvers configured")))
Expect(sut).To(BeNil())
})
})
When("amount of passed in resolvers doesn't match amount of groups", func() {
BeforeEach(func() {
sutConfig = config.UpstreamsConfig{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
},
}
branches = map[string]Resolver{}
})
It("should return error", func() {
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(
"amount of passed in branches (0) does not match amount of configured upstream groups (1)"))
Expect(sut).To(BeNil())
})
})
When("has only default group", func() {
BeforeEach(func() {
sutConfig = config.UpstreamsConfig{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
},
}
branches = createBranchesMock(sutConfig)
})
Describe("Type", func() {
It("does not return error", func() {
Expect(err).ToNot(HaveOccurred())
})
It("follows conventions", func() {
expectValidResolverType(sut)
})
It("returns mock", func() {
Expect(sut.Type()).To(Equal("mock"))
})
})
})
When("has multiple groups", func() {
BeforeEach(func() {
sutConfig = config.UpstreamsConfig{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {
{Host: "wrong"},
{Host: "127.0.0.1"},
},
"test": {
{Host: "some-resolver"},
},
},
}
branches = createBranchesMock(sutConfig)
})
Describe("Type", func() {
It("does not return error", func() {
Expect(err).ToNot(HaveOccurred())
})
It("follows conventions", func() {
expectValidResolverType(sut)
})
It("returns upstream_tree", func() {
Expect(sut.Type()).To(Equal(upstreamTreeResolverType))
})
})
Describe("Configuration output", func() {
It("should return configuration", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ToNot(BeEmpty())
})
})
Describe("Name", func() {
var utrSut *UpstreamTreeResolver
JustBeforeEach(func() {
utrSut = sut.(*UpstreamTreeResolver)
})
It("should contain correct resolver", func() {
name := utrSut.Name()
Expect(name).ShouldNot(BeEmpty())
Expect(name).Should(ContainSubstring(upstreamTreeResolverType))
})
})
When("client specific resolvers are defined", func() {
BeforeEach(func() {
loggerHook = test.NewGlobal()
log.Log().AddHook(loggerHook)
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{}},
"laptop": {config.Upstream{}},
"client-*-m": {config.Upstream{}},
"client[0-9]": {config.Upstream{}},
"192.168.178.33": {config.Upstream{}},
"10.43.8.67/28": {config.Upstream{}},
"name-matches1": {config.Upstream{}},
"name-matches*": {config.Upstream{}},
}}
createMockResolver := func(group string) *mockResolver {
resolver := &mockResolver{}
resolver.On("Resolve", mock.Anything)
resolver.ResponseFn = func(req *dns.Msg) *dns.Msg {
res := new(dns.Msg)
res.SetReply(req)
ptr := new(dns.PTR)
ptr.Ptr = group
ptr.Hdr = util.CreateHeader(req.Question[0], 1)
res.Answer = append(res.Answer, ptr)
return res
}
return resolver
}
branches = map[string]Resolver{
upstreamDefaultCfgName: nil,
"laptop": nil,
"client-*-m": nil,
"client[0-9]": nil,
"192.168.178.33": nil,
"10.43.8.67/28": nil,
"name-matches1": nil,
"name-matches*": nil,
}
for group := range branches {
branches[group] = createMockResolver(group)
}
Expect(branches).To(HaveLen(8))
})
AfterEach(func() {
loggerHook.Reset()
})
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "default"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches exact", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client-*-m"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name matches with range wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client[0-9]"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client IP matches", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client name (containing IP) matches", func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use exact IP match before client name match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use client name match before CIDR match", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
})
It("Should use one of the matching resolvers & log warning", func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
SatisfyAny(
BeDNSRecord("example.com.", A, "name-matches1"),
BeDNSRecord("example.com.", A, "name-matches*"),
),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("client matches multiple groups"))
})
})
})
})
func createBranchesMock(cfg config.UpstreamsConfig) map[string]Resolver {
branches := make(map[string]Resolver, len(cfg.Groups))
for name := range cfg.Groups {
branches[name] = mockRes
}
return branches
}

View File

@ -396,15 +396,21 @@ func createQueryResolver(
bootstrap *resolver.Bootstrap,
redisClient *redis.Client,
) (r resolver.Resolver, err error) {
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap)
if uErr != nil {
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
}
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches)
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstreams, bootstrap, cfg.StartVerifyUpstream)
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
err = multierror.Append(
multierror.Prefix(utErr, "upstream tree resolver: "),
multierror.Prefix(blErr, "blocking resolver: "),
multierror.Prefix(pErr, "parallel resolver: "),
multierror.Prefix(cnErr, "client names resolver: "),
multierror.Prefix(cuErr, "conditional upstream resolver: "),
multierror.Prefix(hfErr, "hosts file resolver: "),
@ -426,12 +432,42 @@ func createQueryResolver(
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN),
parallel,
upstreamTree,
)
return r, nil
}
func createUpstreamBranches(
cfg *config.Config,
bootstrap *resolver.Bootstrap,
) (map[string]resolver.Resolver, error) {
upstreamBranches := make(map[string]resolver.Resolver, len(cfg.Upstreams.Groups))
var uErr error
for group, upstreams := range cfg.Upstreams.Groups {
var (
upstream resolver.Resolver
err error
)
resolverCfg := config.UpstreamsConfig{Groups: config.UpstreamGroups{group: upstreams}}
switch cfg.Upstreams.Strategy {
case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyParallelBest:
upstream, err = resolver.NewParallelBestResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
}
upstreamBranches[group] = upstream
uErr = multierror.Append(multierror.Prefix(err, fmt.Sprintf("group %s: ", group))).ErrorOrNil()
}
return upstreamBranches, uErr
}
func (s *Server) registerDNSHandlers() {
for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)

View File

@ -728,6 +728,45 @@ var _ = Describe("Running DNS server", func() {
})
})
Describe("NewServer with strict upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{
Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyStrict,
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
},
nil)
Expect(err).ToNot(HaveOccurred())
Expect(branches).ToNot(BeNil())
Expect(branches).To(HaveLen(1))
_ = branches["default"].(*resolver.StrictResolver)
})
})
Describe("create query resolver", func() {
When("some upstream returns error", func() {
It("create query resolver should return error", func() {
r, err := createQueryResolver(&config.Config{
StartVerifyUpstream: true,
Upstreams: config.UpstreamsConfig{
Groups: config.UpstreamGroups{
"default": {{Host: "0.0.0.0"}},
},
},
},
nil, nil)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(ContainSubstring("creation of upstream branches failed: ")))
Expect(r).To(BeNil())
})
})
})
Describe("resolve client IP", func() {
Context("UDP address", func() {
It("should correct resolve client IP", func() {