mirror of https://github.com/0xERR0R/blocky.git
feat: add upstream strategy `strict` (#1093)
This commit is contained in:
parent
39208d860e
commit
c112e86740
|
@ -123,6 +123,10 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
|
|||
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
|
||||
type QueryLogField string
|
||||
|
||||
// UpstreamStrategy data field to be logged
|
||||
// ENUM(parallel_best,strict)
|
||||
type UpstreamStrategy uint8
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var netDefaultPort = map[NetProtocol]uint16{
|
||||
NetProtocolTcpUdp: udpPort,
|
||||
|
|
|
@ -476,3 +476,83 @@ func (x *StartStrategyType) UnmarshalText(text []byte) error {
|
|||
*x = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// UpstreamStrategyParallelBest is a UpstreamStrategy of type Parallel_best.
|
||||
UpstreamStrategyParallelBest UpstreamStrategy = iota
|
||||
// UpstreamStrategyStrict is a UpstreamStrategy of type Strict.
|
||||
UpstreamStrategyStrict
|
||||
)
|
||||
|
||||
var ErrInvalidUpstreamStrategy = fmt.Errorf("not a valid UpstreamStrategy, try [%s]", strings.Join(_UpstreamStrategyNames, ", "))
|
||||
|
||||
const _UpstreamStrategyName = "parallel_beststrict"
|
||||
|
||||
var _UpstreamStrategyNames = []string{
|
||||
_UpstreamStrategyName[0:13],
|
||||
_UpstreamStrategyName[13:19],
|
||||
}
|
||||
|
||||
// UpstreamStrategyNames returns a list of possible string values of UpstreamStrategy.
|
||||
func UpstreamStrategyNames() []string {
|
||||
tmp := make([]string, len(_UpstreamStrategyNames))
|
||||
copy(tmp, _UpstreamStrategyNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// UpstreamStrategyValues returns a list of the values for UpstreamStrategy
|
||||
func UpstreamStrategyValues() []UpstreamStrategy {
|
||||
return []UpstreamStrategy{
|
||||
UpstreamStrategyParallelBest,
|
||||
UpstreamStrategyStrict,
|
||||
}
|
||||
}
|
||||
|
||||
var _UpstreamStrategyMap = map[UpstreamStrategy]string{
|
||||
UpstreamStrategyParallelBest: _UpstreamStrategyName[0:13],
|
||||
UpstreamStrategyStrict: _UpstreamStrategyName[13:19],
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x UpstreamStrategy) String() string {
|
||||
if str, ok := _UpstreamStrategyMap[x]; ok {
|
||||
return str
|
||||
}
|
||||
return fmt.Sprintf("UpstreamStrategy(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x UpstreamStrategy) IsValid() bool {
|
||||
_, ok := _UpstreamStrategyMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _UpstreamStrategyValue = map[string]UpstreamStrategy{
|
||||
_UpstreamStrategyName[0:13]: UpstreamStrategyParallelBest,
|
||||
_UpstreamStrategyName[13:19]: UpstreamStrategyStrict,
|
||||
}
|
||||
|
||||
// ParseUpstreamStrategy attempts to convert a string to a UpstreamStrategy.
|
||||
func ParseUpstreamStrategy(name string) (UpstreamStrategy, error) {
|
||||
if x, ok := _UpstreamStrategyValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return UpstreamStrategy(0), fmt.Errorf("%s is %w", name, ErrInvalidUpstreamStrategy)
|
||||
}
|
||||
|
||||
// MarshalText implements the text marshaller method.
|
||||
func (x UpstreamStrategy) MarshalText() ([]byte, error) {
|
||||
return []byte(x.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the text unmarshaller method.
|
||||
func (x *UpstreamStrategy) UnmarshalText(text []byte) error {
|
||||
name := string(text)
|
||||
tmp, err := ParseUpstreamStrategy(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = tmp
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,8 +8,9 @@ const UpstreamDefaultCfgName = "default"
|
|||
|
||||
// UpstreamsConfig upstream servers configuration
|
||||
type UpstreamsConfig struct {
|
||||
Timeout Duration `yaml:"timeout" default:"2s"`
|
||||
Groups UpstreamGroups `yaml:"groups"`
|
||||
Timeout Duration `yaml:"timeout" default:"2s"`
|
||||
Groups UpstreamGroups `yaml:"groups"`
|
||||
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
|
||||
}
|
||||
|
||||
type UpstreamGroups map[string][]Upstream
|
||||
|
@ -22,7 +23,7 @@ func (c *UpstreamsConfig) IsEnabled() bool {
|
|||
// LogConfig implements `config.Configurable`.
|
||||
func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("timeout: ", c.Timeout)
|
||||
|
||||
logger.Info("strategy: ", c.Strategy)
|
||||
logger.Info("groups:")
|
||||
|
||||
for name, upstreams := range c.Groups {
|
||||
|
|
|
@ -18,6 +18,10 @@ upstreams:
|
|||
# or single ip address / client subnet as CIDR notation
|
||||
laptop*:
|
||||
- 123.123.123.123
|
||||
# optional: Determines what strategy blocky uses to choose the upstream servers.
|
||||
# accepted: parallel_best, strict
|
||||
# default: parallel_best
|
||||
strategy: parallel_best
|
||||
# optional: timeout to query the upstream resolver. Default: 2s
|
||||
timeout: 2s
|
||||
|
||||
|
|
|
@ -79,9 +79,9 @@ following network protocols (net part of the resolver URL):
|
|||
|
||||
!!! hint
|
||||
|
||||
You can (and should!) configure multiple DNS resolvers. Blocky picks 2 random resolvers from the list for each query and
|
||||
returns the answer from the fastest one. This improves your network speed and increases your privacy - your DNS traffic
|
||||
will be distributed over multiple providers.
|
||||
You can (and should!) configure multiple DNS resolvers.
|
||||
Per default blocky uses the `parallel_best` upstream strategy where blocky picks 2 random resolvers from the list for each query and
|
||||
returns the answer from the fastest one.
|
||||
|
||||
Each resolver must be defined as a string in following format: `[net:]host:[port][/path][#commonName]`.
|
||||
|
||||
|
@ -92,13 +92,15 @@ Each resolver must be defined as a string in following format: `[net:]host:[port
|
|||
| port | int (1 - 65535) | no | 53 for udp/tcp, 853 for tcp-tls and 443 for https |
|
||||
| commonName | string | no | the host value |
|
||||
|
||||
The commonName parameter overrides the expected certificate common name value used for verification.
|
||||
The `commonName` parameter overrides the expected certificate common name value used for verification.
|
||||
|
||||
Blocky needs at least the configuration of the **default** group. This group will be used as a fallback, if no client
|
||||
specific resolver configuration is available.
|
||||
!!! note
|
||||
Blocky needs at least the configuration of the **default** group with at least one upstream DNS server. This group will be used as a fallback, if no client
|
||||
specific resolver configuration is available.
|
||||
|
||||
You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address or a client subnet as
|
||||
CIDR notation.
|
||||
See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which public free DNS server you could use.
|
||||
|
||||
You can specify multiple upstream groups (additional to the `default` group) to use different upstream servers for different clients, based on client name (see [Client name lookup](#client-name-lookup)), client IP address or client subnet (as CIDR).
|
||||
|
||||
!!! tip
|
||||
|
||||
|
@ -121,15 +123,38 @@ CIDR notation.
|
|||
- 9.9.9.9
|
||||
```
|
||||
|
||||
Use `123.123.123.123` as single upstream DNS resolver for client laptop-home,
|
||||
`1.1.1.1` and `9.9.9.9` for all clients in the sub-net `10.43.8.67/28` and 4 resolvers (default) for all others clients.
|
||||
The above example results in:
|
||||
|
||||
!!! note
|
||||
- `123.123.123.123` as the only upstream DNS resolver for clients with a name starting with "laptop"
|
||||
- `1.1.1.1` and `9.9.9.9` for all clients in the subnet `10.43.8.67/28`
|
||||
- 4 resolvers (default) for all others clients.
|
||||
|
||||
** Blocky needs at least one upstream DNS server **
|
||||
The logic determining what group a client belongs to follows a strict order: IP, client name, CIDR
|
||||
|
||||
See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which
|
||||
public free DNS server you could use.
|
||||
If a client matches multiple client name or CIDR groups, a warning is logged and the first found group is used.
|
||||
|
||||
### Upstream strategy
|
||||
|
||||
Blocky supports different upstream strategies (default `parallel_best`) that determine how and to which upstream DNS servers requests are forwarded.
|
||||
|
||||
Currently available strategies:
|
||||
|
||||
- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
|
||||
If an upstream failed to answer within the last hour, it is less likely to be chosen for the race.
|
||||
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers
|
||||
(When using 10 upstream servers, each upstream will get on average 20% of the DNS requests)
|
||||
- `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
upstreams:
|
||||
strategy: strict
|
||||
groups:
|
||||
default:
|
||||
- 1.2.3.4
|
||||
- 9.8.7.6
|
||||
```
|
||||
|
||||
### Upstream lookup timeout
|
||||
|
||||
|
|
|
@ -70,7 +70,9 @@ var _ = Describe("Upstream resolver configuration tests", func() {
|
|||
It("should not start", func() {
|
||||
Expect(blocky.IsRunning()).Should(BeFalse())
|
||||
Expect(getContainerLogs(blocky)).
|
||||
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
|
||||
Should(ContainElements(
|
||||
ContainSubstring("creation of upstream branches failed: "),
|
||||
ContainSubstring("no valid upstream for group default")))
|
||||
})
|
||||
})
|
||||
When("'startVerifyUpstream' is true and upstream server as host name is not reachable", func() {
|
||||
|
@ -89,7 +91,9 @@ var _ = Describe("Upstream resolver configuration tests", func() {
|
|||
It("should not start", func() {
|
||||
Expect(blocky.IsRunning()).Should(BeFalse())
|
||||
Expect(getContainerLogs(blocky)).
|
||||
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
|
||||
Should(ContainElements(
|
||||
ContainSubstring("creation of upstream branches failed: "),
|
||||
ContainSubstring("no valid upstream for group default")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
2
go.mod
2
go.mod
|
@ -37,7 +37,7 @@ require (
|
|||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
|
||||
github.com/testcontainers/testcontainers-go v0.22.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -17,8 +17,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa
|
|||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/Microsoft/hcsshim v0.10.0-rc.8 h1:YSZVvlIIDD1UxQpJp0h+dnpLUw+TrY0cx8obKsp3bek=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef h1:lg6zRor4+PZN1Pxqtieo/NMhd61ZdV1Z/+bFURWIVfU=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
|
||||
github.com/abice/go-enum v0.5.7 h1:vOrobjpce5D/x5hYNqrWRkFUXFk7A6BlsJyVy4BS1jM=
|
||||
github.com/abice/go-enum v0.5.7/go.mod h1:FBDp+2Ygv9ZZzgcd+Gx3XbyClH7xxFfw8ghMrOpwu+A=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x ListCacheType) String() string {
|
|||
return fmt.Sprintf("ListCacheType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x ListCacheType) IsValid() bool {
|
||||
_, ok := _ListCacheTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _ListCacheTypeValue = map[string]ListCacheType{
|
||||
_ListCacheTypeName[0:9]: ListCacheTypeBlacklist,
|
||||
_ListCacheTypeName[9:18]: ListCacheTypeWhitelist,
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x FormatType) String() string {
|
|||
return fmt.Sprintf("FormatType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x FormatType) IsValid() bool {
|
||||
_, ok := _FormatTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _FormatTypeValue = map[string]FormatType{
|
||||
_FormatTypeName[0:4]: FormatTypeText,
|
||||
_FormatTypeName[4:8]: FormatTypeJson,
|
||||
|
@ -130,6 +137,13 @@ func (x Level) String() string {
|
|||
return fmt.Sprintf("Level(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x Level) IsValid() bool {
|
||||
_, ok := _LevelMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _LevelValue = map[string]Level{
|
||||
_LevelName[0:4]: LevelInfo,
|
||||
_LevelName[4:9]: LevelTrace,
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x RequestProtocol) String() string {
|
|||
return fmt.Sprintf("RequestProtocol(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x RequestProtocol) IsValid() bool {
|
||||
_, ok := _RequestProtocolMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _RequestProtocolValue = map[string]RequestProtocol{
|
||||
_RequestProtocolName[0:3]: RequestProtocolTCP,
|
||||
_RequestProtocolName[3:6]: RequestProtocolUDP,
|
||||
|
@ -151,6 +158,13 @@ func (x ResponseType) String() string {
|
|||
return fmt.Sprintf("ResponseType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x ResponseType) IsValid() bool {
|
||||
_, ok := _ResponseTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _ResponseTypeValue = map[string]ResponseType{
|
||||
_ResponseTypeName[0:8]: ResponseTypeRESOLVED,
|
||||
_ResponseTypeName[8:14]: ResponseTypeCACHED,
|
||||
|
|
|
@ -69,6 +69,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
JustBeforeEach(func() {
|
||||
var err error
|
||||
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
|
|
|
@ -66,10 +66,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
// where `ParallelBestResolver` uses its config, we can just use an empty one.
|
||||
var pbCfg config.UpstreamsConfig
|
||||
|
||||
parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err)
|
||||
}
|
||||
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
|
||||
|
||||
// Always enable prefetching to avoid stalling user requests
|
||||
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:
|
||||
|
|
|
@ -30,9 +30,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
res, err := NewClientNamesResolver(sutConfig, nil, false)
|
||||
var err error
|
||||
|
||||
sut, err = NewClientNamesResolver(sutConfig, nil, false)
|
||||
Expect(err).Should(Succeed())
|
||||
sut = res
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
|
|
|
@ -24,7 +24,7 @@ type ConditionalUpstreamResolver struct {
|
|||
// NewConditionalUpstreamResolver returns new resolver instance
|
||||
func NewConditionalUpstreamResolver(
|
||||
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (ChainedResolver, error) {
|
||||
) (*ConditionalUpstreamResolver, error) {
|
||||
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
||||
|
||||
for domain, upstream := range cfg.Mapping.Upstreams {
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() {
|
||||
var (
|
||||
sut ChainedResolver
|
||||
sut *ConditionalUpstreamResolver
|
||||
m *mockResolver
|
||||
)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ type CustomDNSResolver struct {
|
|||
}
|
||||
|
||||
// NewCustomDNSResolver creates new resolver instance
|
||||
func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver {
|
||||
func NewCustomDNSResolver(cfg config.CustomDNSConfig) *CustomDNSResolver {
|
||||
m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs))
|
||||
reverse := make(map[string][]string, len(cfg.Mapping.HostIPs))
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
var (
|
||||
TTL = uint32(time.Now().Second())
|
||||
|
||||
sut ChainedResolver
|
||||
sut *CustomDNSResolver
|
||||
m *mockResolver
|
||||
cfg config.CustomDNSConfig
|
||||
)
|
||||
|
|
|
@ -12,7 +12,7 @@ type EdeResolver struct {
|
|||
typed
|
||||
}
|
||||
|
||||
func NewEdeResolver(cfg config.EdeConfig) ChainedResolver {
|
||||
func NewEdeResolver(cfg config.EdeConfig) *EdeResolver {
|
||||
return &EdeResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("extended_error_code"),
|
||||
|
|
|
@ -43,7 +43,7 @@ var _ = Describe("EdeResolver", func() {
|
|||
}, nil)
|
||||
}
|
||||
|
||||
sut = NewEdeResolver(sutConfig).(*EdeResolver)
|
||||
sut = NewEdeResolver(sutConfig)
|
||||
sut.Next(m)
|
||||
})
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ type FilteringResolver struct {
|
|||
typed
|
||||
}
|
||||
|
||||
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
|
||||
func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
|
||||
return &FilteringResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("filtering"),
|
||||
|
|
|
@ -31,7 +31,7 @@ var _ = Describe("FilteringResolver", func() {
|
|||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut = NewFilteringResolver(sutConfig).(*FilteringResolver)
|
||||
sut = NewFilteringResolver(sutConfig)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut.Next(m)
|
||||
|
|
|
@ -58,7 +58,7 @@ func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
|
|||
}
|
||||
|
||||
// NewMetricsResolver creates a new intance of the MetricsResolver type
|
||||
func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver {
|
||||
func NewMetricsResolver(cfg config.MetricsConfig) *MetricsResolver {
|
||||
m := MetricsResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("metrics"),
|
||||
|
|
|
@ -30,7 +30,7 @@ var _ = Describe("MetricResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver)
|
||||
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
|
|
|
@ -10,8 +10,8 @@ var NoResponse = &model.Response{} //nolint:gochecknoglobals
|
|||
// NoOpResolver is used to finish a resolver branch as created in RewriterResolver
|
||||
type NoOpResolver struct{}
|
||||
|
||||
func NewNoOpResolver() Resolver {
|
||||
return NoOpResolver{}
|
||||
func NewNoOpResolver() *NoOpResolver {
|
||||
return &NoOpResolver{}
|
||||
}
|
||||
|
||||
// Type implements `Resolver`.
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("NoOpResolver", func() {
|
||||
var sut NoOpResolver
|
||||
var sut *NoOpResolver
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
|
@ -17,7 +17,7 @@ var _ = Describe("NoOpResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewNoOpResolver().(NoOpResolver)
|
||||
sut = NewNoOpResolver()
|
||||
})
|
||||
|
||||
Describe("Resolving", func() {
|
||||
|
|
|
@ -120,12 +120,12 @@ func NewParallelBestResolver(
|
|||
resolverGroups[name] = group
|
||||
}
|
||||
|
||||
return newParallelBestResolver(cfg, resolverGroups)
|
||||
return newParallelBestResolver(cfg, resolverGroups), nil
|
||||
}
|
||||
|
||||
func newParallelBestResolver(
|
||||
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
|
||||
) (*ParallelBestResolver, error) {
|
||||
) *ParallelBestResolver {
|
||||
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
|
||||
|
||||
for groupName, resolvers := range resolverGroups {
|
||||
|
@ -138,11 +138,6 @@ func newParallelBestResolver(
|
|||
resolversPerClient[groupName] = resolverStatuses
|
||||
}
|
||||
|
||||
if len(resolversPerClient[upstreamDefaultCfgName]) == 0 {
|
||||
return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+
|
||||
"Please configure at least one under '%s' configuration name", upstreamDefaultCfgName)
|
||||
}
|
||||
|
||||
r := ParallelBestResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType(parallelResolverType),
|
||||
|
@ -150,7 +145,7 @@ func newParallelBestResolver(
|
|||
resolversPerClient: resolversPerClient,
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *ParallelBestResolver) Name() string {
|
||||
|
@ -172,45 +167,16 @@ func (r *ParallelBestResolver) String() string {
|
|||
return fmt.Sprintf("parallel upstreams '%s'", strings.Join(result, "; "))
|
||||
}
|
||||
|
||||
func (r *ParallelBestResolver) resolversForClient(request *model.Request) (result []*upstreamResolverStatus) {
|
||||
clientIP := request.ClientIP.String()
|
||||
|
||||
// try client names
|
||||
for _, cName := range request.ClientNames {
|
||||
for clientDefinition, upstreams := range r.resolversPerClient {
|
||||
if cName != clientIP && util.ClientNameMatchesGroupName(clientDefinition, cName) {
|
||||
result = append(result, upstreams...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try IP
|
||||
upstreams, found := r.resolversPerClient[clientIP]
|
||||
|
||||
if found {
|
||||
result = append(result, upstreams...)
|
||||
}
|
||||
|
||||
// try CIDR
|
||||
for cidr, upstreams := range r.resolversPerClient {
|
||||
if util.CidrContainsIP(cidr, request.ClientIP) {
|
||||
result = append(result, upstreams...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
// return default
|
||||
result = r.resolversPerClient[upstreamDefaultCfgName]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
||||
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, parallelResolverType)
|
||||
|
||||
resolvers := r.resolversForClient(request)
|
||||
var resolvers []*upstreamResolverStatus
|
||||
for _, r := range r.resolversPerClient {
|
||||
resolvers = r
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if len(resolvers) == 1 {
|
||||
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
|
||||
|
@ -233,21 +199,19 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
|
||||
go r2.resolve(request, ch)
|
||||
|
||||
//nolint: gosimple
|
||||
for len(collectedErrors) < resolverCount {
|
||||
select {
|
||||
case result := <-ch:
|
||||
if result.err != nil {
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
collectedErrors = append(collectedErrors, result.err)
|
||||
} else {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
result := <-ch
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
if result.err != nil {
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
collectedErrors = append(collectedErrors, result.err)
|
||||
} else {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -266,9 +230,13 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst
|
|||
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
|
||||
const errorWindowInSec = 60
|
||||
|
||||
var choices []weightedrand.Choice[*upstreamResolverStatus, uint]
|
||||
choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
|
||||
|
||||
for _, res := range in {
|
||||
if exclude == res.resolver {
|
||||
continue
|
||||
}
|
||||
|
||||
var weight float64 = errorWindowInSec
|
||||
|
||||
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
|
||||
|
@ -277,9 +245,7 @@ func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamRes
|
|||
weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes()))
|
||||
}
|
||||
|
||||
if exclude != res.resolver {
|
||||
choices = append(choices, weightedrand.NewChoice(res, uint(weight)))
|
||||
}
|
||||
choices = append(choices, weightedrand.NewChoice(res, uint(weight)))
|
||||
}
|
||||
|
||||
c, err := weightedrand.NewChooser(choices...)
|
||||
|
|
|
@ -84,16 +84,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
})
|
||||
|
||||
When("default upstream resolvers are not defined", func() {
|
||||
It("should fail on startup", func() {
|
||||
_, err := NewParallelBestResolver(config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{},
|
||||
}, nil, noVerifyUpstreams)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
|
||||
})
|
||||
})
|
||||
|
||||
When("some default upstream resolvers cannot be reached", func() {
|
||||
It("should start normally", func() {
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
|
@ -234,124 +224,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
When("client specific resolvers are defined", func() {
|
||||
When("client name matches", func() {
|
||||
BeforeEach(func() {
|
||||
defaultMockUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(defaultMockUpstream.Close)
|
||||
|
||||
clientSpecificExactMockUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
DeferCleanup(clientSpecificExactMockUpstream.Close)
|
||||
|
||||
clientSpecificWildcardMockUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.124")
|
||||
DeferCleanup(clientSpecificWildcardMockUpstream.Close)
|
||||
|
||||
clientSpecificIPMockUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.125")
|
||||
DeferCleanup(clientSpecificIPMockUpstream.Close)
|
||||
|
||||
clientSpecificCIRDMockUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.126")
|
||||
DeferCleanup(clientSpecificCIRDMockUpstream.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {defaultMockUpstream.Start()},
|
||||
"laptop": {clientSpecificExactMockUpstream.Start()},
|
||||
"client-*-m": {clientSpecificWildcardMockUpstream.Start()},
|
||||
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
|
||||
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
|
||||
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
|
||||
}
|
||||
})
|
||||
It("Should use default if client name or IP don't match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches exact", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches with wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.124"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches with range wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.124"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client IP matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "cl")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.125"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client IP/name matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "192.168.178.33")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.125"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "10.43.8.64", "cl")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.126"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
})
|
||||
})
|
||||
When("only 1 upstream resolvers is defined", func() {
|
||||
BeforeEach(func() {
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
|
@ -390,17 +262,20 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream2.Close)
|
||||
|
||||
sut, _ := NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
||||
sut, _ = NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||
}}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
}},
|
||||
systemResolverBootstrap, noVerifyUpstreams)
|
||||
|
||||
By("all resolvers have same weight for random -> equal distribution", func() {
|
||||
resolverCount := make(map[Resolver]int)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
|
||||
"example.com", A, "123.123.100.100",
|
||||
)))
|
||||
var resolvers []*upstreamResolverStatus
|
||||
for _, r := range sut.resolversPerClient {
|
||||
resolvers = r
|
||||
}
|
||||
r1, r2 := pickRandom(resolvers)
|
||||
res1 := r1.resolver
|
||||
res2 := r2.resolver
|
||||
Expect(res1).ShouldNot(Equal(res2))
|
||||
|
@ -425,9 +300,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
resolverCount := make(map[*UpstreamResolver]int)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient(
|
||||
"example.com", A, "123.123.100.100",
|
||||
)))
|
||||
var resolvers []*upstreamResolverStatus
|
||||
for _, r := range sut.resolversPerClient {
|
||||
resolvers = r
|
||||
}
|
||||
r1, r2 := pickRandom(resolvers)
|
||||
res1 := r1.resolver.(*UpstreamResolver)
|
||||
res2 := r2.resolver.(*UpstreamResolver)
|
||||
Expect(res1).ShouldNot(Equal(res2))
|
||||
|
|
|
@ -30,7 +30,7 @@ type QueryLoggingResolver struct {
|
|||
}
|
||||
|
||||
// NewQueryLoggingResolver returns a new resolver instance
|
||||
func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
||||
func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
|
||||
logger := log.PrefixedLog(queryLoggingResolverType)
|
||||
|
||||
var writer querylog.Writer
|
||||
|
|
|
@ -63,7 +63,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
sutConfig.SetDefaults() // not called when using a struct literal
|
||||
}
|
||||
|
||||
sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver)
|
||||
sut = NewQueryLoggingResolver(sutConfig)
|
||||
DeferCleanup(func() { close(sut.logChan) })
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
|
||||
|
|
|
@ -30,6 +30,7 @@ func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) Chain
|
|||
return inner
|
||||
}
|
||||
|
||||
// ensures that the rewrites map contains all rewrites in lower case
|
||||
for k, v := range cfg.Rewrite {
|
||||
cfg.Rewrite[strings.ToLower(k)] = strings.ToLower(v)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
strictResolverType = "strict"
|
||||
)
|
||||
|
||||
// StrictResolver delegates the DNS message strictly to the first configured upstream resolver
|
||||
// if it can't provide the answer in time the next resolver is used
|
||||
type StrictResolver struct {
|
||||
configurable[*config.UpstreamsConfig]
|
||||
typed
|
||||
|
||||
resolversPerClient map[string][]*upstreamResolverStatus
|
||||
}
|
||||
|
||||
// NewStrictResolver creates new resolver instance
|
||||
func NewStrictResolver(
|
||||
cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (*StrictResolver, error) {
|
||||
logger := log.PrefixedLog(strictResolverType)
|
||||
|
||||
upstreamResolvers := cfg.Groups
|
||||
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
|
||||
|
||||
for name, upstreamCfgs := range upstreamResolvers {
|
||||
group := make([]Resolver, 0, len(upstreamCfgs))
|
||||
hasValidResolver := false
|
||||
|
||||
for _, u := range upstreamCfgs {
|
||||
resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
logger.Warnf("upstream group %s: %v", name, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldVerifyUpstreams {
|
||||
err = testResolver(resolver)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
hasValidResolver = true
|
||||
}
|
||||
}
|
||||
|
||||
group = append(group, resolver)
|
||||
}
|
||||
|
||||
if shouldVerifyUpstreams && !hasValidResolver {
|
||||
return nil, fmt.Errorf("no valid upstream for group %s", name)
|
||||
}
|
||||
|
||||
resolverGroups[name] = group
|
||||
}
|
||||
|
||||
return newStrictResolver(cfg, resolverGroups), nil
|
||||
}
|
||||
|
||||
func newStrictResolver(
|
||||
cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver,
|
||||
) *StrictResolver {
|
||||
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
|
||||
|
||||
for groupName, resolvers := range resolverGroups {
|
||||
resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers))
|
||||
|
||||
for _, r := range resolvers {
|
||||
resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r))
|
||||
}
|
||||
|
||||
resolversPerClient[groupName] = resolverStatuses
|
||||
}
|
||||
|
||||
r := StrictResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType(strictResolverType),
|
||||
|
||||
resolversPerClient: resolversPerClient,
|
||||
}
|
||||
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *StrictResolver) Name() string {
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (r *StrictResolver) String() string {
|
||||
result := make([]string, 0, len(r.resolversPerClient))
|
||||
|
||||
for name, res := range r.resolversPerClient {
|
||||
tmp := make([]string, len(res))
|
||||
for i, s := range res {
|
||||
tmp[i] = fmt.Sprintf("%s", s.resolver)
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ",")))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s upstreams %q", strictResolverType, strings.Join(result, "; "))
|
||||
}
|
||||
|
||||
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
||||
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, strictResolverType)
|
||||
|
||||
var resolvers []*upstreamResolverStatus
|
||||
for _, r := range r.resolversPerClient {
|
||||
resolvers = r
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// start with first resolver
|
||||
for i := range resolvers {
|
||||
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
// start in new go routine and cancel if
|
||||
|
||||
resolver := resolvers[i]
|
||||
ch := make(chan requestResponse, resolverCount)
|
||||
|
||||
go resolver.resolve(request, ch)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// log debug/info that timeout exceeded, call `continue` to try next upstream
|
||||
logger.WithField("resolver", resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
|
||||
|
||||
continue
|
||||
case result := <-ch:
|
||||
if result.err != nil {
|
||||
// log error & call `continue` to try next upstream
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
|
||||
}
|
|
@ -0,0 +1,306 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||
const (
|
||||
verifyUpstreams = true
|
||||
noVerifyUpstreams = false
|
||||
)
|
||||
|
||||
var (
|
||||
sut *StrictResolver
|
||||
sutMapping config.UpstreamGroups
|
||||
sutVerify bool
|
||||
|
||||
err error
|
||||
|
||||
bootstrap *Bootstrap
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
},
|
||||
config.Upstream{
|
||||
Host: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sutVerify = noVerifyUpstreams
|
||||
|
||||
bootstrap = systemResolverBootstrap
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sutConfig := config.UpstreamsConfig{Groups: sutMapping}
|
||||
|
||||
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
|
||||
})
|
||||
|
||||
config.GetConfig().Upstreams.Timeout = config.Duration(1000 * time.Millisecond)
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Type", func() {
|
||||
It("should be correct", func() {
|
||||
Expect(sut.Type()).ShouldNot(BeEmpty())
|
||||
Expect(sut.Type()).Should(Equal(strictResolverType))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
It("should contain correct resolver", func() {
|
||||
Expect(sut.Name()).ShouldNot(BeEmpty())
|
||||
Expect(sut.Name()).Should(ContainSubstring(strictResolverType))
|
||||
})
|
||||
})
|
||||
|
||||
When("some default upstream resolvers cannot be reached", func() {
|
||||
It("should start normally", func() {
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
|
||||
|
||||
return
|
||||
})
|
||||
defer mockUpstream.Close()
|
||||
|
||||
upstream := config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
},
|
||||
mockUpstream.Start(),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewStrictResolver(config.UpstreamsConfig{
|
||||
Groups: upstream,
|
||||
}, systemResolverBootstrap, verifyUpstreams)
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
|
||||
When("no upstream resolvers can be reached", func() {
|
||||
BeforeEach(func() {
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
},
|
||||
config.Upstream{
|
||||
Host: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
When("strict checking is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = verifyUpstreams
|
||||
})
|
||||
It("should fail to start", func() {
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
When("strict checking is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = noVerifyUpstreams
|
||||
})
|
||||
It("should start", func() {
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolving request in strict order", func() {
|
||||
When("2 Upstream resolvers are defined", func() {
|
||||
When("Both are responding", func() {
|
||||
When("they respond in time", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
|
||||
}
|
||||
})
|
||||
It("Should use result from first one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
})
|
||||
When("first upstream exceeds upstreamTimeout", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(time.Duration(config.GetConfig().Upstreams.Timeout) + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
|
||||
}
|
||||
})
|
||||
It("should return response from next upstream", func() {
|
||||
request := newRequest("example.com", A)
|
||||
Expect(sut.Resolve(request)).Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
})
|
||||
When("all upstreams exceed upsteamTimeout", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
||||
time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()},
|
||||
}
|
||||
})
|
||||
It("should return error", func() {
|
||||
request := newRequest("example.com", A)
|
||||
_, err := sut.Resolve(request)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
When("Only second is working", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := config.Upstream{Host: "wrong"}
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {testUpstream1, testUpstream2.Start()},
|
||||
}
|
||||
})
|
||||
It("Should use result from second one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
})
|
||||
When("None are working", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := config.Upstream{Host: "wrong"}
|
||||
testUpstream2 := config.Upstream{Host: "wrong"}
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {testUpstream1, testUpstream2},
|
||||
}
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should return error", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
_, err := sut.Resolve(request)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
When("only 1 upstream resolvers is defined", func() {
|
||||
BeforeEach(func() {
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream.Close)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
mockUpstream.Start(),
|
||||
},
|
||||
}
|
||||
})
|
||||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -92,7 +92,7 @@ type SpecialUseDomainNamesResolver struct {
|
|||
configurable[*config.SUDNConfig]
|
||||
}
|
||||
|
||||
func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) ChainedResolver {
|
||||
func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) *SpecialUseDomainNamesResolver {
|
||||
return &SpecialUseDomainNamesResolver{
|
||||
typed: withType("special_use_domains"),
|
||||
configurable: withConfig(&cfg),
|
||||
|
|
|
@ -41,7 +41,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
|
||||
sut = NewSpecialUseDomainNamesResolver(sutConfig).(*SpecialUseDomainNamesResolver)
|
||||
sut = NewSpecialUseDomainNamesResolver(sutConfig)
|
||||
sut.Next(m)
|
||||
})
|
||||
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
upstreamTreeResolverType = "upstream_tree"
|
||||
)
|
||||
|
||||
type UpstreamTreeResolver struct {
|
||||
configurable[*config.UpstreamsConfig]
|
||||
typed
|
||||
|
||||
branches map[string]Resolver
|
||||
}
|
||||
|
||||
func NewUpstreamTreeResolver(cfg config.UpstreamsConfig, branches map[string]Resolver) (Resolver, error) {
|
||||
if len(cfg.Groups[upstreamDefaultCfgName]) == 0 {
|
||||
return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+
|
||||
"Please configure at least one under '%s' configuration name", upstreamDefaultCfgName)
|
||||
}
|
||||
|
||||
if len(branches) != len(cfg.Groups) {
|
||||
return nil, fmt.Errorf("amount of passed in branches (%d) does not match amount of configured upstream groups (%d)",
|
||||
len(branches), len(cfg.Groups))
|
||||
}
|
||||
|
||||
if len(branches) == 1 {
|
||||
for _, r := range branches {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
// return resolver that forwards request to specific resolver branch depending on the client
|
||||
r := UpstreamTreeResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType(upstreamTreeResolverType),
|
||||
|
||||
branches: branches,
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) Name() string {
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) String() string {
|
||||
result := make([]string, 0, len(r.branches))
|
||||
|
||||
for group, res := range r.branches {
|
||||
result = append(result, fmt.Sprintf("%s (%s)", group, res.Type()))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
|
||||
|
||||
group := r.upstreamGroupByClient(request)
|
||||
|
||||
// delegate request to group resolver
|
||||
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
|
||||
|
||||
return r.branches[group].Resolve(request)
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
|
||||
groups := []string{}
|
||||
clientIP := request.ClientIP.String()
|
||||
|
||||
// try IP
|
||||
if _, exists := r.branches[clientIP]; exists {
|
||||
return clientIP
|
||||
}
|
||||
|
||||
// try client names
|
||||
for _, name := range request.ClientNames {
|
||||
for group := range r.branches {
|
||||
if util.ClientNameMatchesGroupName(group, name) {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try CIDR (only if no client name matched)
|
||||
if len(groups) == 0 {
|
||||
for cidr := range r.branches {
|
||||
if util.CidrContainsIP(cidr, request.ClientIP) {
|
||||
groups = append(groups, cidr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groups) > 0 {
|
||||
if len(groups) > 1 {
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"clientNames": request.ClientNames,
|
||||
"clientIP": clientIP,
|
||||
"groups": groups,
|
||||
}).Warn("client matches multiple groups")
|
||||
}
|
||||
|
||||
return groups[0]
|
||||
}
|
||||
|
||||
return upstreamDefaultCfgName
|
||||
}
|
|
@ -0,0 +1,328 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var mockRes *mockResolver
|
||||
|
||||
var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||
var (
|
||||
sut Resolver
|
||||
sutConfig config.UpstreamsConfig
|
||||
branches map[string]Resolver
|
||||
|
||||
loggerHook *test.Hook
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mockRes = &mockResolver{}
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut, err = NewUpstreamTreeResolver(sutConfig, branches)
|
||||
})
|
||||
|
||||
When("has no configuration", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.UpstreamsConfig{}
|
||||
})
|
||||
|
||||
It("should return error", func() {
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(MatchError(ContainSubstring("no external DNS resolvers configured")))
|
||||
Expect(sut).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
When("amount of passed in resolvers doesn't match amount of groups", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
{Host: "wrong"},
|
||||
{Host: "127.0.0.1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
branches = map[string]Resolver{}
|
||||
})
|
||||
|
||||
It("should return error", func() {
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(MatchError(
|
||||
"amount of passed in branches (0) does not match amount of configured upstream groups (1)"))
|
||||
Expect(sut).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
When("has only default group", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
{Host: "wrong"},
|
||||
{Host: "127.0.0.1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
branches = createBranchesMock(sutConfig)
|
||||
})
|
||||
Describe("Type", func() {
|
||||
It("does not return error", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
It("returns mock", func() {
|
||||
Expect(sut.Type()).To(Equal("mock"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("has multiple groups", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
{Host: "wrong"},
|
||||
{Host: "127.0.0.1"},
|
||||
},
|
||||
"test": {
|
||||
{Host: "some-resolver"},
|
||||
},
|
||||
},
|
||||
}
|
||||
branches = createBranchesMock(sutConfig)
|
||||
})
|
||||
Describe("Type", func() {
|
||||
It("does not return error", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
It("returns upstream_tree", func() {
|
||||
Expect(sut.Type()).To(Equal(upstreamTreeResolverType))
|
||||
})
|
||||
})
|
||||
Describe("Configuration output", func() {
|
||||
It("should return configuration", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
|
||||
logger, hook := log.NewMockEntry()
|
||||
sut.LogConfig(logger)
|
||||
Expect(hook.Calls).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
var utrSut *UpstreamTreeResolver
|
||||
JustBeforeEach(func() {
|
||||
utrSut = sut.(*UpstreamTreeResolver)
|
||||
})
|
||||
|
||||
It("should contain correct resolver", func() {
|
||||
name := utrSut.Name()
|
||||
Expect(name).ShouldNot(BeEmpty())
|
||||
Expect(name).Should(ContainSubstring(upstreamTreeResolverType))
|
||||
})
|
||||
})
|
||||
|
||||
When("client specific resolvers are defined", func() {
|
||||
BeforeEach(func() {
|
||||
loggerHook = test.NewGlobal()
|
||||
log.Log().AddHook(loggerHook)
|
||||
|
||||
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {config.Upstream{}},
|
||||
"laptop": {config.Upstream{}},
|
||||
"client-*-m": {config.Upstream{}},
|
||||
"client[0-9]": {config.Upstream{}},
|
||||
"192.168.178.33": {config.Upstream{}},
|
||||
"10.43.8.67/28": {config.Upstream{}},
|
||||
"name-matches1": {config.Upstream{}},
|
||||
"name-matches*": {config.Upstream{}},
|
||||
}}
|
||||
|
||||
createMockResolver := func(group string) *mockResolver {
|
||||
resolver := &mockResolver{}
|
||||
|
||||
resolver.On("Resolve", mock.Anything)
|
||||
resolver.ResponseFn = func(req *dns.Msg) *dns.Msg {
|
||||
res := new(dns.Msg)
|
||||
res.SetReply(req)
|
||||
|
||||
ptr := new(dns.PTR)
|
||||
ptr.Ptr = group
|
||||
ptr.Hdr = util.CreateHeader(req.Question[0], 1)
|
||||
res.Answer = append(res.Answer, ptr)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
return resolver
|
||||
}
|
||||
|
||||
branches = map[string]Resolver{
|
||||
upstreamDefaultCfgName: nil,
|
||||
"laptop": nil,
|
||||
"client-*-m": nil,
|
||||
"client[0-9]": nil,
|
||||
"192.168.178.33": nil,
|
||||
"10.43.8.67/28": nil,
|
||||
"name-matches1": nil,
|
||||
"name-matches*": nil,
|
||||
}
|
||||
|
||||
for group := range branches {
|
||||
branches[group] = createMockResolver(group)
|
||||
}
|
||||
|
||||
Expect(branches).To(HaveLen(8))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
loggerHook.Reset()
|
||||
})
|
||||
|
||||
It("Should use default if client name or IP don't match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "default"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches exact", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "laptop"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches with wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "client-*-m"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name matches with range wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "client[0-9]"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client IP matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client name (containing IP) matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use exact IP match before client name match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use client name match before CIDR match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "laptop"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
})
|
||||
It("Should use one of the matching resolvers & log warning", func() {
|
||||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
SatisfyAny(
|
||||
BeDNSRecord("example.com.", A, "name-matches1"),
|
||||
BeDNSRecord("example.com.", A, "name-matches*"),
|
||||
),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("client matches multiple groups"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func createBranchesMock(cfg config.UpstreamsConfig) map[string]Resolver {
|
||||
branches := make(map[string]Resolver, len(cfg.Groups))
|
||||
|
||||
for name := range cfg.Groups {
|
||||
branches[name] = mockRes
|
||||
}
|
||||
|
||||
return branches
|
||||
}
|
|
@ -396,15 +396,21 @@ func createQueryResolver(
|
|||
bootstrap *resolver.Bootstrap,
|
||||
redisClient *redis.Client,
|
||||
) (r resolver.Resolver, err error) {
|
||||
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap)
|
||||
if uErr != nil {
|
||||
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
|
||||
}
|
||||
|
||||
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches)
|
||||
|
||||
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
|
||||
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstreams, bootstrap, cfg.StartVerifyUpstream)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
|
||||
|
||||
err = multierror.Append(
|
||||
multierror.Prefix(utErr, "upstream tree resolver: "),
|
||||
multierror.Prefix(blErr, "blocking resolver: "),
|
||||
multierror.Prefix(pErr, "parallel resolver: "),
|
||||
multierror.Prefix(cnErr, "client names resolver: "),
|
||||
multierror.Prefix(cuErr, "conditional upstream resolver: "),
|
||||
multierror.Prefix(hfErr, "hosts file resolver: "),
|
||||
|
@ -426,12 +432,42 @@ func createQueryResolver(
|
|||
resolver.NewCachingResolver(cfg.Caching, redisClient),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
|
||||
resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN),
|
||||
parallel,
|
||||
upstreamTree,
|
||||
)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func createUpstreamBranches(
|
||||
cfg *config.Config,
|
||||
bootstrap *resolver.Bootstrap,
|
||||
) (map[string]resolver.Resolver, error) {
|
||||
upstreamBranches := make(map[string]resolver.Resolver, len(cfg.Upstreams.Groups))
|
||||
|
||||
var uErr error
|
||||
|
||||
for group, upstreams := range cfg.Upstreams.Groups {
|
||||
var (
|
||||
upstream resolver.Resolver
|
||||
err error
|
||||
)
|
||||
|
||||
resolverCfg := config.UpstreamsConfig{Groups: config.UpstreamGroups{group: upstreams}}
|
||||
|
||||
switch cfg.Upstreams.Strategy {
|
||||
case config.UpstreamStrategyStrict:
|
||||
upstream, err = resolver.NewStrictResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
|
||||
case config.UpstreamStrategyParallelBest:
|
||||
upstream, err = resolver.NewParallelBestResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream)
|
||||
}
|
||||
|
||||
upstreamBranches[group] = upstream
|
||||
uErr = multierror.Append(multierror.Prefix(err, fmt.Sprintf("group %s: ", group))).ErrorOrNil()
|
||||
}
|
||||
|
||||
return upstreamBranches, uErr
|
||||
}
|
||||
|
||||
func (s *Server) registerDNSHandlers() {
|
||||
for _, server := range s.dnsServers {
|
||||
handler := server.Handler.(*dns.ServeMux)
|
||||
|
|
|
@ -728,6 +728,45 @@ var _ = Describe("Running DNS server", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("NewServer with strict upstream strategy", func() {
|
||||
It("successfully returns upstream branches", func() {
|
||||
branches, err := createUpstreamBranches(&config.Config{
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Strategy: config.UpstreamStrategyStrict,
|
||||
Groups: config.UpstreamGroups{
|
||||
"default": {{Host: "0.0.0.0"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(branches).ToNot(BeNil())
|
||||
Expect(branches).To(HaveLen(1))
|
||||
_ = branches["default"].(*resolver.StrictResolver)
|
||||
})
|
||||
})
|
||||
|
||||
Describe("create query resolver", func() {
|
||||
When("some upstream returns error", func() {
|
||||
It("create query resolver should return error", func() {
|
||||
r, err := createQueryResolver(&config.Config{
|
||||
StartVerifyUpstream: true,
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{
|
||||
"default": {{Host: "0.0.0.0"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil, nil)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(MatchError(ContainSubstring("creation of upstream branches failed: ")))
|
||||
Expect(r).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("resolve client IP", func() {
|
||||
Context("UDP address", func() {
|
||||
It("should correct resolve client IP", func() {
|
||||
|
|
Loading…
Reference in New Issue