diff --git a/config/config.go b/config/config.go index 54be72bd..44853b40 100644 --- a/config/config.go +++ b/config/config.go @@ -123,6 +123,10 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error { // ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration) type QueryLogField string +// UpstreamStrategy data field to be logged +// ENUM(parallel_best,strict) +type UpstreamStrategy uint8 + //nolint:gochecknoglobals var netDefaultPort = map[NetProtocol]uint16{ NetProtocolTcpUdp: udpPort, diff --git a/config/config_enum.go b/config/config_enum.go index e445ac5d..9bc7574b 100644 --- a/config/config_enum.go +++ b/config/config_enum.go @@ -476,3 +476,83 @@ func (x *StartStrategyType) UnmarshalText(text []byte) error { *x = tmp return nil } + +const ( + // UpstreamStrategyParallelBest is a UpstreamStrategy of type Parallel_best. + UpstreamStrategyParallelBest UpstreamStrategy = iota + // UpstreamStrategyStrict is a UpstreamStrategy of type Strict. + UpstreamStrategyStrict +) + +var ErrInvalidUpstreamStrategy = fmt.Errorf("not a valid UpstreamStrategy, try [%s]", strings.Join(_UpstreamStrategyNames, ", ")) + +const _UpstreamStrategyName = "parallel_beststrict" + +var _UpstreamStrategyNames = []string{ + _UpstreamStrategyName[0:13], + _UpstreamStrategyName[13:19], +} + +// UpstreamStrategyNames returns a list of possible string values of UpstreamStrategy. +func UpstreamStrategyNames() []string { + tmp := make([]string, len(_UpstreamStrategyNames)) + copy(tmp, _UpstreamStrategyNames) + return tmp +} + +// UpstreamStrategyValues returns a list of the values for UpstreamStrategy +func UpstreamStrategyValues() []UpstreamStrategy { + return []UpstreamStrategy{ + UpstreamStrategyParallelBest, + UpstreamStrategyStrict, + } +} + +var _UpstreamStrategyMap = map[UpstreamStrategy]string{ + UpstreamStrategyParallelBest: _UpstreamStrategyName[0:13], + UpstreamStrategyStrict: _UpstreamStrategyName[13:19], +} + +// String implements the Stringer interface. +func (x UpstreamStrategy) String() string { + if str, ok := _UpstreamStrategyMap[x]; ok { + return str + } + return fmt.Sprintf("UpstreamStrategy(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x UpstreamStrategy) IsValid() bool { + _, ok := _UpstreamStrategyMap[x] + return ok +} + +var _UpstreamStrategyValue = map[string]UpstreamStrategy{ + _UpstreamStrategyName[0:13]: UpstreamStrategyParallelBest, + _UpstreamStrategyName[13:19]: UpstreamStrategyStrict, +} + +// ParseUpstreamStrategy attempts to convert a string to a UpstreamStrategy. +func ParseUpstreamStrategy(name string) (UpstreamStrategy, error) { + if x, ok := _UpstreamStrategyValue[name]; ok { + return x, nil + } + return UpstreamStrategy(0), fmt.Errorf("%s is %w", name, ErrInvalidUpstreamStrategy) +} + +// MarshalText implements the text marshaller method. +func (x UpstreamStrategy) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *UpstreamStrategy) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseUpstreamStrategy(name) + if err != nil { + return err + } + *x = tmp + return nil +} diff --git a/config/upstreams.go b/config/upstreams.go index 7343f5c3..536e8a87 100644 --- a/config/upstreams.go +++ b/config/upstreams.go @@ -8,8 +8,9 @@ const UpstreamDefaultCfgName = "default" // UpstreamsConfig upstream servers configuration type UpstreamsConfig struct { - Timeout Duration `yaml:"timeout" default:"2s"` - Groups UpstreamGroups `yaml:"groups"` + Timeout Duration `yaml:"timeout" default:"2s"` + Groups UpstreamGroups `yaml:"groups"` + Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` } type UpstreamGroups map[string][]Upstream @@ -22,7 +23,7 @@ func (c *UpstreamsConfig) IsEnabled() bool { // LogConfig implements `config.Configurable`. func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) { logger.Info("timeout: ", c.Timeout) - + logger.Info("strategy: ", c.Strategy) logger.Info("groups:") for name, upstreams := range c.Groups { diff --git a/docs/config.yml b/docs/config.yml index 2acd36b4..de4a8d60 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -18,6 +18,10 @@ upstreams: # or single ip address / client subnet as CIDR notation laptop*: - 123.123.123.123 + # optional: Determines what strategy blocky uses to choose the upstream servers. + # accepted: parallel_best, strict + # default: parallel_best + strategy: parallel_best # optional: timeout to query the upstream resolver. Default: 2s timeout: 2s diff --git a/docs/configuration.md b/docs/configuration.md index 6cb85163..0be20351 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,9 +79,9 @@ following network protocols (net part of the resolver URL): !!! hint - You can (and should!) configure multiple DNS resolvers. Blocky picks 2 random resolvers from the list for each query and - returns the answer from the fastest one. This improves your network speed and increases your privacy - your DNS traffic - will be distributed over multiple providers. + You can (and should!) configure multiple DNS resolvers. + Per default blocky uses the `parallel_best` upstream strategy where blocky picks 2 random resolvers from the list for each query and + returns the answer from the fastest one. Each resolver must be defined as a string in following format: `[net:]host:[port][/path][#commonName]`. @@ -92,13 +92,15 @@ Each resolver must be defined as a string in following format: `[net:]host:[port | port | int (1 - 65535) | no | 53 for udp/tcp, 853 for tcp-tls and 443 for https | | commonName | string | no | the host value | -The commonName parameter overrides the expected certificate common name value used for verification. +The `commonName` parameter overrides the expected certificate common name value used for verification. -Blocky needs at least the configuration of the **default** group. This group will be used as a fallback, if no client -specific resolver configuration is available. +!!! note + Blocky needs at least the configuration of the **default** group with at least one upstream DNS server. This group will be used as a fallback, if no client + specific resolver configuration is available. -You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address or a client subnet as -CIDR notation. + See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which public free DNS server you could use. + +You can specify multiple upstream groups (additional to the `default` group) to use different upstream servers for different clients, based on client name (see [Client name lookup](#client-name-lookup)), client IP address or client subnet (as CIDR). !!! tip @@ -121,15 +123,38 @@ CIDR notation. - 9.9.9.9 ``` -Use `123.123.123.123` as single upstream DNS resolver for client laptop-home, -`1.1.1.1` and `9.9.9.9` for all clients in the sub-net `10.43.8.67/28` and 4 resolvers (default) for all others clients. +The above example results in: -!!! note +- `123.123.123.123` as the only upstream DNS resolver for clients with a name starting with "laptop" +- `1.1.1.1` and `9.9.9.9` for all clients in the subnet `10.43.8.67/28` +- 4 resolvers (default) for all others clients. - ** Blocky needs at least one upstream DNS server ** +The logic determining what group a client belongs to follows a strict order: IP, client name, CIDR -See [List of public DNS servers](additional_information.md#list-of-public-dns-servers) if you need some ideas, which -public free DNS server you could use. +If a client matches multiple client name or CIDR groups, a warning is logged and the first found group is used. + +### Upstream strategy + +Blocky supports different upstream strategies (default `parallel_best`) that determine how and to which upstream DNS servers requests are forwarded. + +Currently available strategies: + +- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one. + If an upstream failed to answer within the last hour, it is less likely to be chosen for the race. + This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers + (When using 10 upstream servers, each upstream will get on average 20% of the DNS requests) +- `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on. + +!!! example + + ```yaml + upstreams: + strategy: strict + groups: + default: + - 1.2.3.4 + - 9.8.7.6 + ``` ### Upstream lookup timeout diff --git a/e2e/upstream_test.go b/e2e/upstream_test.go index d2743661..b454abd5 100644 --- a/e2e/upstream_test.go +++ b/e2e/upstream_test.go @@ -70,7 +70,9 @@ var _ = Describe("Upstream resolver configuration tests", func() { It("should not start", func() { Expect(blocky.IsRunning()).Should(BeFalse()) Expect(getContainerLogs(blocky)). - Should(ContainElement(ContainSubstring("no valid upstream for group default"))) + Should(ContainElements( + ContainSubstring("creation of upstream branches failed: "), + ContainSubstring("no valid upstream for group default"))) }) }) When("'startVerifyUpstream' is true and upstream server as host name is not reachable", func() { @@ -89,7 +91,9 @@ var _ = Describe("Upstream resolver configuration tests", func() { It("should not start", func() { Expect(blocky.IsRunning()).Should(BeFalse()) Expect(getContainerLogs(blocky)). - Should(ContainElement(ContainSubstring("no valid upstream for group default"))) + Should(ContainElements( + ContainSubstring("creation of upstream branches failed: "), + ContainSubstring("no valid upstream for group default"))) }) }) }) diff --git a/go.mod b/go.mod index cb943815..20844c4c 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 + github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef github.com/docker/go-connections v0.4.0 github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198 github.com/testcontainers/testcontainers-go v0.22.0 diff --git a/go.sum b/go.sum index 76786cd6..c2030a4c 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Microsoft/hcsshim v0.10.0-rc.8 h1:YSZVvlIIDD1UxQpJp0h+dnpLUw+TrY0cx8obKsp3bek= -github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI= -github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo= +github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef h1:lg6zRor4+PZN1Pxqtieo/NMhd61ZdV1Z/+bFURWIVfU= +github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo= github.com/abice/go-enum v0.5.7 h1:vOrobjpce5D/x5hYNqrWRkFUXFk7A6BlsJyVy4BS1jM= github.com/abice/go-enum v0.5.7/go.mod h1:FBDp+2Ygv9ZZzgcd+Gx3XbyClH7xxFfw8ghMrOpwu+A= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= diff --git a/lists/list_cache_enum.go b/lists/list_cache_enum.go index c1fdf38e..39f41e52 100644 --- a/lists/list_cache_enum.go +++ b/lists/list_cache_enum.go @@ -49,6 +49,13 @@ func (x ListCacheType) String() string { return fmt.Sprintf("ListCacheType(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x ListCacheType) IsValid() bool { + _, ok := _ListCacheTypeMap[x] + return ok +} + var _ListCacheTypeValue = map[string]ListCacheType{ _ListCacheTypeName[0:9]: ListCacheTypeBlacklist, _ListCacheTypeName[9:18]: ListCacheTypeWhitelist, diff --git a/log/logger_enum.go b/log/logger_enum.go index 70da43db..d4bc69c2 100644 --- a/log/logger_enum.go +++ b/log/logger_enum.go @@ -49,6 +49,13 @@ func (x FormatType) String() string { return fmt.Sprintf("FormatType(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x FormatType) IsValid() bool { + _, ok := _FormatTypeMap[x] + return ok +} + var _FormatTypeValue = map[string]FormatType{ _FormatTypeName[0:4]: FormatTypeText, _FormatTypeName[4:8]: FormatTypeJson, @@ -130,6 +137,13 @@ func (x Level) String() string { return fmt.Sprintf("Level(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x Level) IsValid() bool { + _, ok := _LevelMap[x] + return ok +} + var _LevelValue = map[string]Level{ _LevelName[0:4]: LevelInfo, _LevelName[4:9]: LevelTrace, diff --git a/model/models_enum.go b/model/models_enum.go index ad6f939c..f6ac755d 100644 --- a/model/models_enum.go +++ b/model/models_enum.go @@ -49,6 +49,13 @@ func (x RequestProtocol) String() string { return fmt.Sprintf("RequestProtocol(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x RequestProtocol) IsValid() bool { + _, ok := _RequestProtocolMap[x] + return ok +} + var _RequestProtocolValue = map[string]RequestProtocol{ _RequestProtocolName[0:3]: RequestProtocolTCP, _RequestProtocolName[3:6]: RequestProtocolUDP, @@ -151,6 +158,13 @@ func (x ResponseType) String() string { return fmt.Sprintf("ResponseType(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x ResponseType) IsValid() bool { + _, ok := _ResponseTypeMap[x] + return ok +} + var _ResponseTypeValue = map[string]ResponseType{ _ResponseTypeName[0:8]: ResponseTypeRESOLVED, _ResponseTypeName[8:14]: ResponseTypeCACHED, diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index c80883ff..899fc972 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -69,6 +69,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { JustBeforeEach(func() { var err error + m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap) diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index bcf593a4..f100c4da 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -66,10 +66,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { // where `ParallelBestResolver` uses its config, we can just use an empty one. var pbCfg config.UpstreamsConfig - parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups()) - if err != nil { - return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err) - } + parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups()) // Always enable prefetching to avoid stalling user requests // Otherwise, a request to blocky could end up waiting for 2 DNS requests: diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index eff18f44..75374f59 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -30,9 +30,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) JustBeforeEach(func() { - res, err := NewClientNamesResolver(sutConfig, nil, false) + var err error + + sut, err = NewClientNamesResolver(sutConfig, nil, false) Expect(err).Should(Succeed()) - sut = res m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) diff --git a/resolver/conditional_upstream_resolver.go b/resolver/conditional_upstream_resolver.go index 1a0b2bb0..f3553d8c 100644 --- a/resolver/conditional_upstream_resolver.go +++ b/resolver/conditional_upstream_resolver.go @@ -24,7 +24,7 @@ type ConditionalUpstreamResolver struct { // NewConditionalUpstreamResolver returns new resolver instance func NewConditionalUpstreamResolver( cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, -) (ChainedResolver, error) { +) (*ConditionalUpstreamResolver, error) { m := make(map[string]Resolver, len(cfg.Mapping.Upstreams)) for domain, upstream := range cfg.Mapping.Upstreams { diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index d134d127..8bc6ccd9 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -15,7 +15,7 @@ import ( var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() { var ( - sut ChainedResolver + sut *ConditionalUpstreamResolver m *mockResolver ) diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index 56ff2b50..38930066 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -24,7 +24,7 @@ type CustomDNSResolver struct { } // NewCustomDNSResolver creates new resolver instance -func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver { +func NewCustomDNSResolver(cfg config.CustomDNSConfig) *CustomDNSResolver { m := make(map[string][]net.IP, len(cfg.Mapping.HostIPs)) reverse := make(map[string][]string, len(cfg.Mapping.HostIPs)) diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index 854f342f..9ad65f17 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -18,7 +18,7 @@ var _ = Describe("CustomDNSResolver", func() { var ( TTL = uint32(time.Now().Second()) - sut ChainedResolver + sut *CustomDNSResolver m *mockResolver cfg config.CustomDNSConfig ) diff --git a/resolver/ede_resolver.go b/resolver/ede_resolver.go index e924a892..5e458d91 100644 --- a/resolver/ede_resolver.go +++ b/resolver/ede_resolver.go @@ -12,7 +12,7 @@ type EdeResolver struct { typed } -func NewEdeResolver(cfg config.EdeConfig) ChainedResolver { +func NewEdeResolver(cfg config.EdeConfig) *EdeResolver { return &EdeResolver{ configurable: withConfig(&cfg), typed: withType("extended_error_code"), diff --git a/resolver/ede_resolver_test.go b/resolver/ede_resolver_test.go index 5038db79..fc7e7ad5 100644 --- a/resolver/ede_resolver_test.go +++ b/resolver/ede_resolver_test.go @@ -43,7 +43,7 @@ var _ = Describe("EdeResolver", func() { }, nil) } - sut = NewEdeResolver(sutConfig).(*EdeResolver) + sut = NewEdeResolver(sutConfig) sut.Next(m) }) diff --git a/resolver/filtering_resolver.go b/resolver/filtering_resolver.go index 738c970f..8ff68909 100644 --- a/resolver/filtering_resolver.go +++ b/resolver/filtering_resolver.go @@ -14,7 +14,7 @@ type FilteringResolver struct { typed } -func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver { +func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver { return &FilteringResolver{ configurable: withConfig(&cfg), typed: withType("filtering"), diff --git a/resolver/filtering_resolver_test.go b/resolver/filtering_resolver_test.go index 2c0a02b1..aeaf5600 100644 --- a/resolver/filtering_resolver_test.go +++ b/resolver/filtering_resolver_test.go @@ -31,7 +31,7 @@ var _ = Describe("FilteringResolver", func() { }) JustBeforeEach(func() { - sut = NewFilteringResolver(sutConfig).(*FilteringResolver) + sut = NewFilteringResolver(sutConfig) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) sut.Next(m) diff --git a/resolver/metrics_resolver.go b/resolver/metrics_resolver.go index 0c15c00f..96459d7a 100644 --- a/resolver/metrics_resolver.go +++ b/resolver/metrics_resolver.go @@ -58,7 +58,7 @@ func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro } // NewMetricsResolver creates a new intance of the MetricsResolver type -func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver { +func NewMetricsResolver(cfg config.MetricsConfig) *MetricsResolver { m := MetricsResolver{ configurable: withConfig(&cfg), typed: withType("metrics"), diff --git a/resolver/metrics_resolver_test.go b/resolver/metrics_resolver_test.go index eca8ab69..a686e537 100644 --- a/resolver/metrics_resolver_test.go +++ b/resolver/metrics_resolver_test.go @@ -30,7 +30,7 @@ var _ = Describe("MetricResolver", func() { }) BeforeEach(func() { - sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver) + sut = NewMetricsResolver(config.MetricsConfig{Enable: true}) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) diff --git a/resolver/noop_resolver.go b/resolver/noop_resolver.go index 9b3f6885..6cd12f6c 100644 --- a/resolver/noop_resolver.go +++ b/resolver/noop_resolver.go @@ -10,8 +10,8 @@ var NoResponse = &model.Response{} //nolint:gochecknoglobals // NoOpResolver is used to finish a resolver branch as created in RewriterResolver type NoOpResolver struct{} -func NewNoOpResolver() Resolver { - return NoOpResolver{} +func NewNoOpResolver() *NoOpResolver { + return &NoOpResolver{} } // Type implements `Resolver`. diff --git a/resolver/noop_resolver_test.go b/resolver/noop_resolver_test.go index 6493fef1..b9f47096 100644 --- a/resolver/noop_resolver_test.go +++ b/resolver/noop_resolver_test.go @@ -8,7 +8,7 @@ import ( ) var _ = Describe("NoOpResolver", func() { - var sut NoOpResolver + var sut *NoOpResolver Describe("Type", func() { It("follows conventions", func() { @@ -17,7 +17,7 @@ var _ = Describe("NoOpResolver", func() { }) BeforeEach(func() { - sut = NewNoOpResolver().(NoOpResolver) + sut = NewNoOpResolver() }) Describe("Resolving", func() { diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index d3d87f02..cd7b02cf 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -120,12 +120,12 @@ func NewParallelBestResolver( resolverGroups[name] = group } - return newParallelBestResolver(cfg, resolverGroups) + return newParallelBestResolver(cfg, resolverGroups), nil } func newParallelBestResolver( cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver, -) (*ParallelBestResolver, error) { +) *ParallelBestResolver { resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups)) for groupName, resolvers := range resolverGroups { @@ -138,11 +138,6 @@ func newParallelBestResolver( resolversPerClient[groupName] = resolverStatuses } - if len(resolversPerClient[upstreamDefaultCfgName]) == 0 { - return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+ - "Please configure at least one under '%s' configuration name", upstreamDefaultCfgName) - } - r := ParallelBestResolver{ configurable: withConfig(&cfg), typed: withType(parallelResolverType), @@ -150,7 +145,7 @@ func newParallelBestResolver( resolversPerClient: resolversPerClient, } - return &r, nil + return &r } func (r *ParallelBestResolver) Name() string { @@ -172,45 +167,16 @@ func (r *ParallelBestResolver) String() string { return fmt.Sprintf("parallel upstreams '%s'", strings.Join(result, "; ")) } -func (r *ParallelBestResolver) resolversForClient(request *model.Request) (result []*upstreamResolverStatus) { - clientIP := request.ClientIP.String() - - // try client names - for _, cName := range request.ClientNames { - for clientDefinition, upstreams := range r.resolversPerClient { - if cName != clientIP && util.ClientNameMatchesGroupName(clientDefinition, cName) { - result = append(result, upstreams...) - } - } - } - - // try IP - upstreams, found := r.resolversPerClient[clientIP] - - if found { - result = append(result, upstreams...) - } - - // try CIDR - for cidr, upstreams := range r.resolversPerClient { - if util.CidrContainsIP(cidr, request.ClientIP) { - result = append(result, upstreams...) - } - } - - if len(result) == 0 { - // return default - result = r.resolversPerClient[upstreamDefaultCfgName] - } - - return result -} - // Resolve sends the query request to multiple upstream resolvers and returns the fastest result func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) { logger := log.WithPrefix(request.Log, parallelResolverType) - resolvers := r.resolversForClient(request) + var resolvers []*upstreamResolverStatus + for _, r := range r.resolversPerClient { + resolvers = r + + break + } if len(resolvers) == 1 { logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver") @@ -233,21 +199,19 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, go r2.resolve(request, ch) - //nolint: gosimple for len(collectedErrors) < resolverCount { - select { - case result := <-ch: - if result.err != nil { - logger.Debug("resolution failed from resolver, cause: ", result.err) - collectedErrors = append(collectedErrors, result.err) - } else { - logger.WithFields(logrus.Fields{ - "resolver": *result.resolver, - "answer": util.AnswerToString(result.response.Res.Answer), - }).Debug("using response from resolver") + result := <-ch - return result.response, nil - } + if result.err != nil { + logger.Debug("resolution failed from resolver, cause: ", result.err) + collectedErrors = append(collectedErrors, result.err) + } else { + logger.WithFields(logrus.Fields{ + "resolver": *result.resolver, + "answer": util.AnswerToString(result.response.Res.Answer), + }).Debug("using response from resolver") + + return result.response, nil } } @@ -266,9 +230,13 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus { const errorWindowInSec = 60 - var choices []weightedrand.Choice[*upstreamResolverStatus, uint] + choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in)) for _, res := range in { + if exclude == res.resolver { + continue + } + var weight float64 = errorWindowInSec if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour { @@ -277,9 +245,7 @@ func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamRes weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes())) } - if exclude != res.resolver { - choices = append(choices, weightedrand.NewChoice(res, uint(weight))) - } + choices = append(choices, weightedrand.NewChoice(res, uint(weight))) } c, err := weightedrand.NewChooser(choices...) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 5544e6c5..0278be84 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -84,16 +84,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) }) - When("default upstream resolvers are not defined", func() { - It("should fail on startup", func() { - _, err := NewParallelBestResolver(config.UpstreamsConfig{ - Groups: config.UpstreamGroups{}, - }, nil, noVerifyUpstreams) - Expect(err).Should(HaveOccurred()) - Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured")) - }) - }) - When("some default upstream resolvers cannot be reached", func() { It("should start normally", func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { @@ -234,124 +224,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) }) }) - When("client specific resolvers are defined", func() { - When("client name matches", func() { - BeforeEach(func() { - defaultMockUpstream := NewMockUDPUpstreamServer(). - WithAnswerRR("example.com 123 IN A 123.124.122.122") - DeferCleanup(defaultMockUpstream.Close) - - clientSpecificExactMockUpstream := NewMockUDPUpstreamServer(). - WithAnswerRR("example.com 123 IN A 123.124.122.123") - DeferCleanup(clientSpecificExactMockUpstream.Close) - - clientSpecificWildcardMockUpstream := NewMockUDPUpstreamServer(). - WithAnswerRR("example.com 123 IN A 123.124.122.124") - DeferCleanup(clientSpecificWildcardMockUpstream.Close) - - clientSpecificIPMockUpstream := NewMockUDPUpstreamServer(). - WithAnswerRR("example.com 123 IN A 123.124.122.125") - DeferCleanup(clientSpecificIPMockUpstream.Close) - - clientSpecificCIRDMockUpstream := NewMockUDPUpstreamServer(). - WithAnswerRR("example.com 123 IN A 123.124.122.126") - DeferCleanup(clientSpecificCIRDMockUpstream.Close) - - sutMapping = config.UpstreamGroups{ - upstreamDefaultCfgName: {defaultMockUpstream.Start()}, - "laptop": {clientSpecificExactMockUpstream.Start()}, - "client-*-m": {clientSpecificWildcardMockUpstream.Start()}, - "client[0-9]": {clientSpecificWildcardMockUpstream.Start()}, - "192.168.178.33": {clientSpecificIPMockUpstream.Start()}, - "10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()}, - } - }) - It("Should use default if client name or IP don't match", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.122"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client name matches exact", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.123"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client name matches with wildcard", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.124"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client name matches with range wildcard", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.124"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client IP matches", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.33", "cl") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.125"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client IP/name matches", func() { - request := newRequestWithClient("example.com.", A, "192.168.178.33", "192.168.178.33") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.125"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() { - request := newRequestWithClient("example.com.", A, "10.43.8.64", "cl") - - Expect(sut.Resolve(request)). - Should( - SatisfyAll( - BeDNSRecord("example.com.", A, "123.124.122.126"), - HaveTTL(BeNumerically("==", 123)), - HaveResponseType(ResponseTypeRESOLVED), - HaveReturnCode(dns.RcodeSuccess), - )) - }) - }) - }) When("only 1 upstream resolvers is defined", func() { BeforeEach(func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") @@ -390,17 +262,20 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ := NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{ + sut, _ = NewParallelBestResolver(config.UpstreamsConfig{Groups: config.UpstreamGroups{ upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, - }}, systemResolverBootstrap, noVerifyUpstreams) + }}, + systemResolverBootstrap, noVerifyUpstreams) By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) for i := 0; i < 1000; i++ { - r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient( - "example.com", A, "123.123.100.100", - ))) + var resolvers []*upstreamResolverStatus + for _, r := range sut.resolversPerClient { + resolvers = r + } + r1, r2 := pickRandom(resolvers) res1 := r1.resolver res2 := r2.resolver Expect(res1).ShouldNot(Equal(res2)) @@ -425,9 +300,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[*UpstreamResolver]int) for i := 0; i < 100; i++ { - r1, r2 := pickRandom(sut.resolversForClient(newRequestWithClient( - "example.com", A, "123.123.100.100", - ))) + var resolvers []*upstreamResolverStatus + for _, r := range sut.resolversPerClient { + resolvers = r + } + r1, r2 := pickRandom(resolvers) res1 := r1.resolver.(*UpstreamResolver) res2 := r2.resolver.(*UpstreamResolver) Expect(res1).ShouldNot(Equal(res2)) diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go index 1b80a6f8..b21bf304 100644 --- a/resolver/query_logging_resolver.go +++ b/resolver/query_logging_resolver.go @@ -30,7 +30,7 @@ type QueryLoggingResolver struct { } // NewQueryLoggingResolver returns a new resolver instance -func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { +func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver { logger := log.PrefixedLog(queryLoggingResolverType) var writer querylog.Writer diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index 1a4d02ff..7386654d 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -63,7 +63,7 @@ var _ = Describe("QueryLoggingResolver", func() { sutConfig.SetDefaults() // not called when using a struct literal } - sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver) + sut = NewQueryLoggingResolver(sutConfig) DeferCleanup(func() { close(sut.logChan) }) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil) diff --git a/resolver/rewriter_resolver.go b/resolver/rewriter_resolver.go index 8024b546..320c94d4 100644 --- a/resolver/rewriter_resolver.go +++ b/resolver/rewriter_resolver.go @@ -30,6 +30,7 @@ func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) Chain return inner } + // ensures that the rewrites map contains all rewrites in lower case for k, v := range cfg.Rewrite { cfg.Rewrite[strings.ToLower(k)] = strings.ToLower(v) } diff --git a/resolver/strict_resolver.go b/resolver/strict_resolver.go new file mode 100644 index 00000000..2fd6c460 --- /dev/null +++ b/resolver/strict_resolver.go @@ -0,0 +1,165 @@ +package resolver + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" + "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" + + "github.com/sirupsen/logrus" +) + +const ( + strictResolverType = "strict" +) + +// StrictResolver delegates the DNS message strictly to the first configured upstream resolver +// if it can't provide the answer in time the next resolver is used +type StrictResolver struct { + configurable[*config.UpstreamsConfig] + typed + + resolversPerClient map[string][]*upstreamResolverStatus +} + +// NewStrictResolver creates new resolver instance +func NewStrictResolver( + cfg config.UpstreamsConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, +) (*StrictResolver, error) { + logger := log.PrefixedLog(strictResolverType) + + upstreamResolvers := cfg.Groups + resolverGroups := make(map[string][]Resolver, len(upstreamResolvers)) + + for name, upstreamCfgs := range upstreamResolvers { + group := make([]Resolver, 0, len(upstreamCfgs)) + hasValidResolver := false + + for _, u := range upstreamCfgs { + resolver, err := NewUpstreamResolver(u, bootstrap, shouldVerifyUpstreams) + if err != nil { + logger.Warnf("upstream group %s: %v", name, err) + + continue + } + + if shouldVerifyUpstreams { + err = testResolver(resolver) + if err != nil { + logger.Warn(err) + } else { + hasValidResolver = true + } + } + + group = append(group, resolver) + } + + if shouldVerifyUpstreams && !hasValidResolver { + return nil, fmt.Errorf("no valid upstream for group %s", name) + } + + resolverGroups[name] = group + } + + return newStrictResolver(cfg, resolverGroups), nil +} + +func newStrictResolver( + cfg config.UpstreamsConfig, resolverGroups map[string][]Resolver, +) *StrictResolver { + resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups)) + + for groupName, resolvers := range resolverGroups { + resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) + + for _, r := range resolvers { + resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) + } + + resolversPerClient[groupName] = resolverStatuses + } + + r := StrictResolver{ + configurable: withConfig(&cfg), + typed: withType(strictResolverType), + + resolversPerClient: resolversPerClient, + } + + return &r +} + +func (r *StrictResolver) Name() string { + return r.String() +} + +func (r *StrictResolver) String() string { + result := make([]string, 0, len(r.resolversPerClient)) + + for name, res := range r.resolversPerClient { + tmp := make([]string, len(res)) + for i, s := range res { + tmp[i] = fmt.Sprintf("%s", s.resolver) + } + + result = append(result, fmt.Sprintf("%s (%s)", name, strings.Join(tmp, ","))) + } + + return fmt.Sprintf("%s upstreams %q", strictResolverType, strings.Join(result, "; ")) +} + +// Resolve sends the query request to multiple upstream resolvers and returns the fastest result +func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) { + logger := log.WithPrefix(request.Log, strictResolverType) + + var resolvers []*upstreamResolverStatus + for _, r := range r.resolversPerClient { + resolvers = r + + break + } + + // start with first resolver + for i := range resolvers { + timeout := config.GetConfig().Upstreams.Timeout.ToDuration() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // start in new go routine and cancel if + + resolver := resolvers[i] + ch := make(chan requestResponse, resolverCount) + + go resolver.resolve(request, ch) + + select { + case <-ctx.Done(): + // log debug/info that timeout exceeded, call `continue` to try next upstream + logger.WithField("resolver", resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream") + + continue + case result := <-ch: + if result.err != nil { + // log error & call `continue` to try next upstream + logger.Debug("resolution failed from resolver, cause: ", result.err) + + continue + } + + logger.WithFields(logrus.Fields{ + "resolver": *result.resolver, + "answer": util.AnswerToString(result.response.Res.Answer), + }).Debug("using response from resolver") + + return result.response, nil + } + } + + return nil, errors.New("resolution was not successful, no resolver returned an answer in time") +} diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go new file mode 100644 index 00000000..84719c80 --- /dev/null +++ b/resolver/strict_resolver_test.go @@ -0,0 +1,306 @@ +package resolver + +import ( + "time" + + "github.com/0xERR0R/blocky/config" + . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" + . "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" + "github.com/miekg/dns" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("StrictResolver", Label("strictResolver"), func() { + const ( + verifyUpstreams = true + noVerifyUpstreams = false + ) + + var ( + sut *StrictResolver + sutMapping config.UpstreamGroups + sutVerify bool + + err error + + bootstrap *Bootstrap + ) + + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + + BeforeEach(func() { + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: { + config.Upstream{ + Host: "wrong", + }, + config.Upstream{ + Host: "127.0.0.2", + }, + }, + } + + sutVerify = noVerifyUpstreams + + bootstrap = systemResolverBootstrap + }) + + JustBeforeEach(func() { + sutConfig := config.UpstreamsConfig{Groups: sutMapping} + + sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify) + }) + + config.GetConfig().Upstreams.Timeout = config.Duration(1000 * time.Millisecond) + + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + + Describe("Type", func() { + It("should be correct", func() { + Expect(sut.Type()).ShouldNot(BeEmpty()) + Expect(sut.Type()).Should(Equal(strictResolverType)) + }) + }) + + Describe("Name", func() { + It("should contain correct resolver", func() { + Expect(sut.Name()).ShouldNot(BeEmpty()) + Expect(sut.Name()).Should(ContainSubstring(strictResolverType)) + }) + }) + + When("some default upstream resolvers cannot be reached", func() { + It("should start normally", func() { + mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") + + return + }) + defer mockUpstream.Close() + + upstream := config.UpstreamGroups{ + upstreamDefaultCfgName: { + config.Upstream{ + Host: "wrong", + }, + mockUpstream.Start(), + }, + } + + _, err := NewStrictResolver(config.UpstreamsConfig{ + Groups: upstream, + }, systemResolverBootstrap, verifyUpstreams) + Expect(err).Should(Not(HaveOccurred())) + }) + }) + + When("no upstream resolvers can be reached", func() { + BeforeEach(func() { + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: { + config.Upstream{ + Host: "wrong", + }, + config.Upstream{ + Host: "127.0.0.2", + }, + }, + } + }) + + When("strict checking is enabled", func() { + BeforeEach(func() { + sutVerify = verifyUpstreams + }) + It("should fail to start", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + + When("strict checking is disabled", func() { + BeforeEach(func() { + sutVerify = noVerifyUpstreams + }) + It("should start", func() { + Expect(err).Should(Not(HaveOccurred())) + }) + }) + }) + + Describe("Resolving request in strict order", func() { + When("2 Upstream resolvers are defined", func() { + When("Both are responding", func() { + When("they respond in time", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") + DeferCleanup(testUpstream2.Close) + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, + } + }) + It("Should use result from first one", func() { + request := newRequest("example.com.", A) + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.122"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + When("first upstream exceeds upstreamTimeout", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") + time.Sleep(time.Duration(config.GetConfig().Upstreams.Timeout) + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") + DeferCleanup(testUpstream2.Close) + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, + } + }) + It("should return response from next upstream", func() { + request := newRequest("example.com", A) + Expect(sut.Resolve(request)).Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.2"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + When("all upstreams exceed upsteamTimeout", func() { + BeforeEach(func() { + testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") + time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream1.Close) + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2") + time.Sleep(config.GetConfig().Upstreams.Timeout.ToDuration() + 2*time.Second) + + Expect(err).To(Succeed()) + + return response + }) + DeferCleanup(testUpstream2.Close) + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: {testUpstream1.Start(), testUpstream2.Start()}, + } + }) + It("should return error", func() { + request := newRequest("example.com", A) + _, err := sut.Resolve(request) + Expect(err).To(HaveOccurred()) + }) + }) + }) + When("Only second is working", func() { + BeforeEach(func() { + testUpstream1 := config.Upstream{Host: "wrong"} + + testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") + DeferCleanup(testUpstream2.Close) + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: {testUpstream1, testUpstream2.Start()}, + } + }) + It("Should use result from second one", func() { + request := newRequest("example.com.", A) + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.123"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + When("None are working", func() { + BeforeEach(func() { + testUpstream1 := config.Upstream{Host: "wrong"} + testUpstream2 := config.Upstream{Host: "wrong"} + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: {testUpstream1, testUpstream2}, + } + Expect(err).Should(Succeed()) + }) + It("Should return error", func() { + request := newRequest("example.com.", A) + _, err := sut.Resolve(request) + Expect(err).Should(HaveOccurred()) + }) + }) + }) + When("only 1 upstream resolvers is defined", func() { + BeforeEach(func() { + mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + DeferCleanup(mockUpstream.Close) + + sutMapping = config.UpstreamGroups{ + upstreamDefaultCfgName: { + mockUpstream.Start(), + }, + } + }) + It("Should use result from defined resolver", func() { + request := newRequest("example.com.", A) + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "123.124.122.122"), + HaveTTL(BeNumerically("==", 123)), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) + }) +}) diff --git a/resolver/sudn_resolver.go b/resolver/sudn_resolver.go index 44508a2d..f70c51a5 100644 --- a/resolver/sudn_resolver.go +++ b/resolver/sudn_resolver.go @@ -92,7 +92,7 @@ type SpecialUseDomainNamesResolver struct { configurable[*config.SUDNConfig] } -func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) ChainedResolver { +func NewSpecialUseDomainNamesResolver(cfg config.SUDNConfig) *SpecialUseDomainNamesResolver { return &SpecialUseDomainNamesResolver{ typed: withType("special_use_domains"), configurable: withConfig(&cfg), diff --git a/resolver/sudn_resolver_test.go b/resolver/sudn_resolver_test.go index 5cf9ed00..1c004b20 100644 --- a/resolver/sudn_resolver_test.go +++ b/resolver/sudn_resolver_test.go @@ -41,7 +41,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) - sut = NewSpecialUseDomainNamesResolver(sutConfig).(*SpecialUseDomainNamesResolver) + sut = NewSpecialUseDomainNamesResolver(sutConfig) sut.Next(m) }) diff --git a/resolver/upstream_tree_resolver.go b/resolver/upstream_tree_resolver.go new file mode 100644 index 00000000..72a13795 --- /dev/null +++ b/resolver/upstream_tree_resolver.go @@ -0,0 +1,118 @@ +package resolver + +import ( + "fmt" + "strings" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" + "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" + "github.com/sirupsen/logrus" +) + +const ( + upstreamTreeResolverType = "upstream_tree" +) + +type UpstreamTreeResolver struct { + configurable[*config.UpstreamsConfig] + typed + + branches map[string]Resolver +} + +func NewUpstreamTreeResolver(cfg config.UpstreamsConfig, branches map[string]Resolver) (Resolver, error) { + if len(cfg.Groups[upstreamDefaultCfgName]) == 0 { + return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+ + "Please configure at least one under '%s' configuration name", upstreamDefaultCfgName) + } + + if len(branches) != len(cfg.Groups) { + return nil, fmt.Errorf("amount of passed in branches (%d) does not match amount of configured upstream groups (%d)", + len(branches), len(cfg.Groups)) + } + + if len(branches) == 1 { + for _, r := range branches { + return r, nil + } + } + + // return resolver that forwards request to specific resolver branch depending on the client + r := UpstreamTreeResolver{ + configurable: withConfig(&cfg), + typed: withType(upstreamTreeResolverType), + + branches: branches, + } + + return &r, nil +} + +func (r *UpstreamTreeResolver) Name() string { + return r.String() +} + +func (r *UpstreamTreeResolver) String() string { + result := make([]string, 0, len(r.branches)) + + for group, res := range r.branches { + result = append(result, fmt.Sprintf("%s (%s)", group, res.Type())) + } + + return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", ")) +} + +func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) { + logger := log.WithPrefix(request.Log, upstreamTreeResolverType) + + group := r.upstreamGroupByClient(request) + + // delegate request to group resolver + logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver") + + return r.branches[group].Resolve(request) +} + +func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string { + groups := []string{} + clientIP := request.ClientIP.String() + + // try IP + if _, exists := r.branches[clientIP]; exists { + return clientIP + } + + // try client names + for _, name := range request.ClientNames { + for group := range r.branches { + if util.ClientNameMatchesGroupName(group, name) { + groups = append(groups, group) + } + } + } + + // try CIDR (only if no client name matched) + if len(groups) == 0 { + for cidr := range r.branches { + if util.CidrContainsIP(cidr, request.ClientIP) { + groups = append(groups, cidr) + } + } + } + + if len(groups) > 0 { + if len(groups) > 1 { + r.log().WithFields(logrus.Fields{ + "clientNames": request.ClientNames, + "clientIP": clientIP, + "groups": groups, + }).Warn("client matches multiple groups") + } + + return groups[0] + } + + return upstreamDefaultCfgName +} diff --git a/resolver/upstream_tree_resolver_test.go b/resolver/upstream_tree_resolver_test.go new file mode 100644 index 00000000..11faab42 --- /dev/null +++ b/resolver/upstream_tree_resolver_test.go @@ -0,0 +1,328 @@ +package resolver + +import ( + "github.com/0xERR0R/blocky/config" + . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" + . "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" + "github.com/miekg/dns" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/mock" +) + +var mockRes *mockResolver + +var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { + var ( + sut Resolver + sutConfig config.UpstreamsConfig + branches map[string]Resolver + + loggerHook *test.Hook + + err error + ) + + BeforeEach(func() { + mockRes = &mockResolver{} + }) + + JustBeforeEach(func() { + sut, err = NewUpstreamTreeResolver(sutConfig, branches) + }) + + When("has no configuration", func() { + BeforeEach(func() { + sutConfig = config.UpstreamsConfig{} + }) + + It("should return error", func() { + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("no external DNS resolvers configured"))) + Expect(sut).To(BeNil()) + }) + }) + + When("amount of passed in resolvers doesn't match amount of groups", func() { + BeforeEach(func() { + sutConfig = config.UpstreamsConfig{ + Groups: config.UpstreamGroups{ + upstreamDefaultCfgName: { + {Host: "wrong"}, + {Host: "127.0.0.1"}, + }, + }, + } + branches = map[string]Resolver{} + }) + + It("should return error", func() { + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError( + "amount of passed in branches (0) does not match amount of configured upstream groups (1)")) + Expect(sut).To(BeNil()) + }) + }) + + When("has only default group", func() { + BeforeEach(func() { + sutConfig = config.UpstreamsConfig{ + Groups: config.UpstreamGroups{ + upstreamDefaultCfgName: { + {Host: "wrong"}, + {Host: "127.0.0.1"}, + }, + }, + } + branches = createBranchesMock(sutConfig) + }) + Describe("Type", func() { + It("does not return error", func() { + Expect(err).ToNot(HaveOccurred()) + }) + It("follows conventions", func() { + expectValidResolverType(sut) + }) + It("returns mock", func() { + Expect(sut.Type()).To(Equal("mock")) + }) + }) + }) + + When("has multiple groups", func() { + BeforeEach(func() { + sutConfig = config.UpstreamsConfig{ + Groups: config.UpstreamGroups{ + upstreamDefaultCfgName: { + {Host: "wrong"}, + {Host: "127.0.0.1"}, + }, + "test": { + {Host: "some-resolver"}, + }, + }, + } + branches = createBranchesMock(sutConfig) + }) + Describe("Type", func() { + It("does not return error", func() { + Expect(err).ToNot(HaveOccurred()) + }) + It("follows conventions", func() { + expectValidResolverType(sut) + }) + It("returns upstream_tree", func() { + Expect(sut.Type()).To(Equal(upstreamTreeResolverType)) + }) + }) + Describe("Configuration output", func() { + It("should return configuration", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + + logger, hook := log.NewMockEntry() + sut.LogConfig(logger) + Expect(hook.Calls).ToNot(BeEmpty()) + }) + }) + + Describe("Name", func() { + var utrSut *UpstreamTreeResolver + JustBeforeEach(func() { + utrSut = sut.(*UpstreamTreeResolver) + }) + + It("should contain correct resolver", func() { + name := utrSut.Name() + Expect(name).ShouldNot(BeEmpty()) + Expect(name).Should(ContainSubstring(upstreamTreeResolverType)) + }) + }) + + When("client specific resolvers are defined", func() { + BeforeEach(func() { + loggerHook = test.NewGlobal() + log.Log().AddHook(loggerHook) + + sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{ + upstreamDefaultCfgName: {config.Upstream{}}, + "laptop": {config.Upstream{}}, + "client-*-m": {config.Upstream{}}, + "client[0-9]": {config.Upstream{}}, + "192.168.178.33": {config.Upstream{}}, + "10.43.8.67/28": {config.Upstream{}}, + "name-matches1": {config.Upstream{}}, + "name-matches*": {config.Upstream{}}, + }} + + createMockResolver := func(group string) *mockResolver { + resolver := &mockResolver{} + + resolver.On("Resolve", mock.Anything) + resolver.ResponseFn = func(req *dns.Msg) *dns.Msg { + res := new(dns.Msg) + res.SetReply(req) + + ptr := new(dns.PTR) + ptr.Ptr = group + ptr.Hdr = util.CreateHeader(req.Question[0], 1) + res.Answer = append(res.Answer, ptr) + + return res + } + + return resolver + } + + branches = map[string]Resolver{ + upstreamDefaultCfgName: nil, + "laptop": nil, + "client-*-m": nil, + "client[0-9]": nil, + "192.168.178.33": nil, + "10.43.8.67/28": nil, + "name-matches1": nil, + "name-matches*": nil, + } + + for group := range branches { + branches[group] = createMockResolver(group) + } + + Expect(branches).To(HaveLen(8)) + }) + + AfterEach(func() { + loggerHook.Reset() + }) + + It("Should use default if client name or IP don't match", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "default"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client name matches exact", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "laptop"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client name matches with wildcard", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "client-*-m"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client name matches with range wildcard", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "client[0-9]"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client IP matches", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "192.168.178.33"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client name (containing IP) matches", func() { + request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "192.168.178.33"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() { + request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "10.43.8.67/28"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use exact IP match before client name match", func() { + request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "192.168.178.33"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use client name match before CIDR match", func() { + request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + BeDNSRecord("example.com.", A, "laptop"), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Should use one of the matching resolvers & log warning", func() { + request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1") + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + SatisfyAny( + BeDNSRecord("example.com.", A, "name-matches1"), + BeDNSRecord("example.com.", A, "name-matches*"), + ), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + )) + + Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("client matches multiple groups")) + }) + }) + }) +}) + +func createBranchesMock(cfg config.UpstreamsConfig) map[string]Resolver { + branches := make(map[string]Resolver, len(cfg.Groups)) + + for name := range cfg.Groups { + branches[name] = mockRes + } + + return branches +} diff --git a/server/server.go b/server/server.go index 415a7e36..1b553995 100644 --- a/server/server.go +++ b/server/server.go @@ -396,15 +396,21 @@ func createQueryResolver( bootstrap *resolver.Bootstrap, redisClient *redis.Client, ) (r resolver.Resolver, err error) { + upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap) + if uErr != nil { + return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr) + } + + upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches) + blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap) - parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstreams, bootstrap, cfg.StartVerifyUpstream) clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream) hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap) err = multierror.Append( + multierror.Prefix(utErr, "upstream tree resolver: "), multierror.Prefix(blErr, "blocking resolver: "), - multierror.Prefix(pErr, "parallel resolver: "), multierror.Prefix(cnErr, "client names resolver: "), multierror.Prefix(cuErr, "conditional upstream resolver: "), multierror.Prefix(hfErr, "hosts file resolver: "), @@ -426,12 +432,42 @@ func createQueryResolver( resolver.NewCachingResolver(cfg.Caching, redisClient), resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream), resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN), - parallel, + upstreamTree, ) return r, nil } +func createUpstreamBranches( + cfg *config.Config, + bootstrap *resolver.Bootstrap, +) (map[string]resolver.Resolver, error) { + upstreamBranches := make(map[string]resolver.Resolver, len(cfg.Upstreams.Groups)) + + var uErr error + + for group, upstreams := range cfg.Upstreams.Groups { + var ( + upstream resolver.Resolver + err error + ) + + resolverCfg := config.UpstreamsConfig{Groups: config.UpstreamGroups{group: upstreams}} + + switch cfg.Upstreams.Strategy { + case config.UpstreamStrategyStrict: + upstream, err = resolver.NewStrictResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream) + case config.UpstreamStrategyParallelBest: + upstream, err = resolver.NewParallelBestResolver(resolverCfg, bootstrap, cfg.StartVerifyUpstream) + } + + upstreamBranches[group] = upstream + uErr = multierror.Append(multierror.Prefix(err, fmt.Sprintf("group %s: ", group))).ErrorOrNil() + } + + return upstreamBranches, uErr +} + func (s *Server) registerDNSHandlers() { for _, server := range s.dnsServers { handler := server.Handler.(*dns.ServeMux) diff --git a/server/server_test.go b/server/server_test.go index 1e553334..65b30bc2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -728,6 +728,45 @@ var _ = Describe("Running DNS server", func() { }) }) + Describe("NewServer with strict upstream strategy", func() { + It("successfully returns upstream branches", func() { + branches, err := createUpstreamBranches(&config.Config{ + Upstreams: config.UpstreamsConfig{ + Strategy: config.UpstreamStrategyStrict, + Groups: config.UpstreamGroups{ + "default": {{Host: "0.0.0.0"}}, + }, + }, + }, + nil) + + Expect(err).ToNot(HaveOccurred()) + Expect(branches).ToNot(BeNil()) + Expect(branches).To(HaveLen(1)) + _ = branches["default"].(*resolver.StrictResolver) + }) + }) + + Describe("create query resolver", func() { + When("some upstream returns error", func() { + It("create query resolver should return error", func() { + r, err := createQueryResolver(&config.Config{ + StartVerifyUpstream: true, + Upstreams: config.UpstreamsConfig{ + Groups: config.UpstreamGroups{ + "default": {{Host: "0.0.0.0"}}, + }, + }, + }, + nil, nil) + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("creation of upstream branches failed: "))) + Expect(r).To(BeNil()) + }) + }) + }) + Describe("resolve client IP", func() { Context("UDP address", func() { It("should correct resolve client IP", func() {