refactor: configuration rework (usage and printing) (#920)

* refactor: make `config.Duration` a struct with `time.Duration` embed

This allows calling `time.Duration` methods directly.
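Roughly, the embedded form looks like the sketch below (illustrative only; a later squash commit below reverts it to a plain "new type" wrapper, whose final shape is `config/duration.go` further down in this diff):

```go
package config

import "time"

// Sketch of the embedded form: callers can use d.Seconds(), d.String(), etc.
// directly on a config.Duration value.
type Duration struct {
	time.Duration
}
```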

* refactor(HostsFileResolver): don't copy individual config items

The idea is to make adding configuration options easier and searching
for references straightforward.
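As an illustration of the pattern (field names below are hypothetical, not the actual resolver code), the resolver keeps its whole config section instead of copying each option into its own field:

```go
package resolver

import (
	"time"

	"github.com/0xERR0R/blocky/config"
)

// Before (sketch): every option copied into a dedicated field.
type hostsResolverBefore struct {
	filePath       string
	ttl            time.Duration
	filterLoopback bool
}

// After (sketch): the resolver holds the config section directly, so adding
// an option needs no extra plumbing and `cfg.<Field>` usages are easy to grep.
type hostsResolverAfter struct {
	cfg config.HostsFileConfig
}
```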

* refactor: move config printing to struct and use a logger

Using a logger gives us multiple log levels, so the whole configuration
can be printed in trace/verbose mode while only the important parts are
shown by default.
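A caller could look roughly like this (hypothetical helper; the `Configurable` interface it relies on is added in this diff):

```go
package logging

import (
	"github.com/0xERR0R/blocky/config"
	"github.com/sirupsen/logrus"
)

// logSection is a sketch: a section is only printed when it is enabled, and
// the entry's level decides whether Debug/Trace details make it to the output.
func logSection(logger *logrus.Entry, name string, cfg config.Configurable) {
	if !cfg.IsEnabled() {
		logger.Debugf("%s: disabled", name)

		return
	}

	logger.Infof("%s:", name)
	cfg.LogConfig(logger)
}
```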

* squash: rename `Cast` to `ToDuration`

* squash: revert `Duration` to a simple wrapper ("new type" pattern)

* squash: `Duration.IsZero` tests

* squash: refactor resolvers to rely on their config directly if possible

* squash: implement `IsEnabled` and `LogValues` for all resolvers

* refactor: use go-enum `--values` to simplify getting all log fields

* refactor: simplify `QType` unmarshaling

* squash: rename `ValueLogger` to `Configurable`

* squash: rename `UpstreamConfig` to `ParallelBestConfig`

* squash: rename `RewriteConfig` to `RewriterConfig`

* squash: config tests

* squash: resolver tests

* squash: add `ForEach` test and improve `Chain` ones

* squash: simplify implementing `config.Configurable`

* squash: minor changes for better coverage

* squash: more `UnmarshalYAML` -> `UnmarshalText`

* refactor: move `config.Upstream` into own file

* refactor: add `Resolver.Type` method

* squash: add `log` method to `typed` to use `Resolver.Type` as prefix
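The resolver-side helper is not part of this diff; a sketch of the idea, with assumed names:

```go
package resolver

import "github.com/sirupsen/logrus"

// typed (sketch) stores the resolver type once; embedding it gives a resolver
// a Type() accessor and a log() helper that prefixes every log entry with it.
type typed struct {
	typeName string
}

func (t *typed) Type() string {
	return t.typeName
}

func (t *typed) log() *logrus.Entry {
	return logrus.WithField("prefix", t.typeName)
}
```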

* squash: tweak startup config logging

* squash: add `LogResolverConfig` tests

* squash: make sure all options of type `Duration` use `%s`
ThinkChaos 2023-03-12 17:14:10 -04:00 committed by GitHub
parent bacb4437da
commit 5088c75a78
80 changed files with 2918 additions and 1557 deletions

81
config/blocking.go Normal file
View File

@ -0,0 +1,81 @@
package config
import (
"strings"
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
)
// BlockingConfig configuration for query blocking
type BlockingConfig struct {
BlackLists map[string][]string `yaml:"blackLists"`
WhiteLists map[string][]string `yaml:"whiteLists"`
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
BlockType string `yaml:"blockType" default:"ZEROIP"`
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
// Deprecated
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
}
// IsEnabled implements `config.Configurable`.
func (c *BlockingConfig) IsEnabled() bool {
return len(c.ClientGroupsBlock) != 0
}
// LogConfig implements `config.Configurable`.
func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
logger.Info("clientGroupsBlock:")
for key, val := range c.ClientGroupsBlock {
logger.Infof(" %s = %v", key, val)
}
logger.Infof("blockType = %s", c.BlockType)
if c.BlockType != "NXDOMAIN" {
logger.Infof("blockTTL = %s", c.BlockTTL)
}
logger.Infof("downloadTimeout = %s", c.DownloadTimeout)
logger.Infof("failStartOnListError = %t", c.FailStartOnListError)
if c.RefreshPeriod > 0 {
logger.Infof("refresh = every %s", c.RefreshPeriod)
} else {
logger.Debug("refresh = disabled")
}
logger.Info("blacklist:")
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
c.logListGroups(logger, c.BlackLists)
})
logger.Info("whitelist:")
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
c.logListGroups(logger, c.WhiteLists)
})
}
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) {
for group, links := range listGroups {
logger.Infof("%s:", group)
for _, link := range links {
if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char
link = link[:idx] // first line only
logger.Infof(" - %s [...]", link)
} else {
logger.Infof(" - %s", link)
}
}
}
}

86
config/blocking_test.go Normal file
View File

@ -0,0 +1,86 @@
package config
import (
"time"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
)
var _ = Describe("BlockingConfig", func() {
var (
cfg BlockingConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: Duration(time.Minute),
BlackLists: map[string][]string{
"gr1": {"/a/file/path"},
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
RefreshPeriod: Duration(time.Hour),
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := BlockingConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := BlockingConfig{
BlockTTL: Duration(-1),
}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:"))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
})
When("refresh is disabled", func() {
It("should reflect that", func() {
cfg.RefreshPeriod = Duration(-1)
logger.Logger.Level = logrus.InfoLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
logger.Logger.Level = logrus.TraceLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
})
})
})
})

52
config/caching.go Normal file
View File

@ -0,0 +1,52 @@
package config
import (
"time"
"github.com/sirupsen/logrus"
)
// CachingConfig configuration for domain caching
type CachingConfig struct {
MinCachingTime Duration `yaml:"minTime"`
MaxCachingTime Duration `yaml:"maxTime"`
CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"`
MaxItemsCount int `yaml:"maxItemsCount"`
Prefetching bool `yaml:"prefetching"`
PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"`
PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"`
PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"`
}
// IsEnabled implements `config.Configurable`.
func (c *CachingConfig) IsEnabled() bool {
return c.MaxCachingTime > 0
}
// LogConfig implements `config.Configurable`.
func (c *CachingConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("minTime = %s", c.MinCachingTime)
logger.Infof("maxTime = %s", c.MaxCachingTime)
logger.Infof("cacheTimeNegative = %s", c.CacheTimeNegative)
if c.Prefetching {
logger.Infof("prefetching:")
logger.Infof(" expires = %s", c.PrefetchExpires)
logger.Infof(" threshold = %d", c.PrefetchThreshold)
logger.Infof(" maxItems = %d", c.PrefetchMaxItemsCount)
} else {
logger.Debug("prefetching: disabled")
}
}
func (c *CachingConfig) EnablePrefetch() {
const day = Duration(24 * time.Hour)
if c.MaxCachingTime.IsZero() {
// make sure resolver gets enabled
c.MaxCachingTime = day
}
c.Prefetching = true
c.PrefetchThreshold = 0
}

65
config/caching_test.go Normal file
View File

@ -0,0 +1,65 @@
package config
import (
"time"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("CachingConfig", func() {
var (
cfg CachingConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = CachingConfig{
MaxCachingTime: Duration(time.Hour),
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := CachingConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("the config is enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("the config is disabled", func() {
It("should be false", func() {
cfg := CachingConfig{
MaxCachingTime: Duration(-1),
}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
When("prefetching is enabled", func() {
BeforeEach(func() {
cfg = CachingConfig{
Prefetching: true,
}
})
It("should return configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("prefetching:")))
})
})
})
})

36
config/client_lookup.go Normal file
View File

@ -0,0 +1,36 @@
package config
import (
"net"
"github.com/sirupsen/logrus"
)
// ClientLookupConfig configuration for the client lookup
type ClientLookupConfig struct {
ClientnameIPMapping map[string][]net.IP `yaml:"clients"`
Upstream Upstream `yaml:"upstream"`
SingleNameOrder []uint `yaml:"singleNameOrder"`
}
// IsEnabled implements `config.Configurable`.
func (c *ClientLookupConfig) IsEnabled() bool {
return !c.Upstream.IsDefault() || len(c.ClientnameIPMapping) != 0
}
// LogConfig implements `config.Configurable`.
func (c *ClientLookupConfig) LogConfig(logger *logrus.Entry) {
if !c.Upstream.IsDefault() {
logger.Infof("upstream = %s", c.Upstream)
}
logger.Infof("singleNameOrder = %v", c.SingleNameOrder)
if len(c.ClientnameIPMapping) > 0 {
logger.Infof("client IP mapping:")
for k, v := range c.ClientnameIPMapping {
logger.Infof(" %s = %s", k, v)
}
}
}

View File

@ -0,0 +1,68 @@
package config
import (
"net"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ClientLookupConfig", func() {
var (
cfg ClientLookupConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = ClientLookupConfig{
Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"},
SingleNameOrder: []uint{1, 2},
ClientnameIPMapping: map[string][]net.IP{
"client8": {net.ParseIP("1.2.3.5")},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg = ClientLookupConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
By("upstream", func() {
cfg := ClientLookupConfig{
Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"},
ClientnameIPMapping: nil,
}
Expect(cfg.IsEnabled()).Should(BeTrue())
})
By("mapping", func() {
cfg := ClientLookupConfig{
ClientnameIPMapping: map[string][]net.IP{
"client8": {net.ParseIP("1.2.3.5")},
},
}
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("client IP mapping:")))
})
})
})

View File

@ -0,0 +1,60 @@
package config
import (
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
// ConditionalUpstreamConfig conditional upstream configuration
type ConditionalUpstreamConfig struct {
RewriterConfig `yaml:",inline"`
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
}
// ConditionalUpstreamMapping mapping for conditional configuration
type ConditionalUpstreamMapping struct {
Upstreams map[string][]Upstream
}
// IsEnabled implements `config.Configurable`.
func (c *ConditionalUpstreamConfig) IsEnabled() bool {
return len(c.Mapping.Upstreams) != 0
}
// LogConfig implements `config.Configurable`.
func (c *ConditionalUpstreamConfig) LogConfig(logger *logrus.Entry) {
for key, val := range c.Mapping.Upstreams {
logger.Infof("%s = %v", key, val)
}
}
// UnmarshalYAML implements `yaml.Unmarshaler`.
func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input map[string]string
if err := unmarshal(&input); err != nil {
return err
}
result := make(map[string][]Upstream, len(input))
for k, v := range input {
var upstreams []Upstream
for _, part := range strings.Split(v, ",") {
upstream, err := ParseUpstream(strings.TrimSpace(part))
if err != nil {
return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err)
}
upstreams = append(upstreams, upstream)
}
result[k] = upstreams
}
c.Upstreams = result
return nil
}

View File

@ -0,0 +1,89 @@
package config
import (
"errors"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ConditionalUpstreamConfig", func() {
var (
cfg ConditionalUpstreamConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = ConditionalUpstreamConfig{
Mapping: ConditionalUpstreamMapping{
Upstreams: map[string][]Upstream{
"fritz.box": {Upstream{Net: NetProtocolTcpUdp, Host: "fbTest"}},
"other.box": {Upstream{Net: NetProtocolTcpUdp, Host: "otherTest"}},
".": {Upstream{Net: NetProtocolTcpUdp, Host: "dotTest"}},
},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := ConditionalUpstreamConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := ConditionalUpstreamConfig{
Mapping: ConditionalUpstreamMapping{Upstreams: map[string][]Upstream{}},
}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("fritz.box = ")))
})
})
Describe("UnmarshalYAML", func() {
It("Should parse config as map", func() {
c := &ConditionalUpstreamMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
Expect(c.Upstreams).Should(HaveLen(1))
Expect(c.Upstreams["key"]).Should(HaveLen(1))
Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{
Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53,
}))
})
It("should fail if wrong YAML format", func() {
c := &ConditionalUpstreamMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
})

View File

@ -1,4 +1,4 @@
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values
package config
import (
@ -7,16 +7,12 @@ import (
"net"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/hako/durafmt"
"github.com/sirupsen/logrus"
"github.com/0xERR0R/blocky/log"
"github.com/creasty/defaults"
@ -29,6 +25,16 @@ const (
httpsPort = 443
)
type Configurable interface {
// IsEnabled returns true when the receiver is configured.
IsEnabled() bool
// LogConfig logs the receiver's configuration.
//
// Calling this method when `IsEnabled` returns false is undefined.
LogConfig(*logrus.Entry)
}
// NetProtocol resolver protocol ENUM(
// tcp+udp // TCP and UDP protocols
// tcp-tls // TCP-TLS protocol
@ -90,44 +96,6 @@ type StartStrategyType uint16
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
type QueryLogField string
type QType dns.Type
func (c QType) String() string {
return dns.Type(c).String()
}
type QTypeSet map[QType]struct{}
func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
s := make(QTypeSet, len(qTypes))
for _, qType := range qTypes {
s.Insert(qType)
}
return s
}
func (s QTypeSet) Contains(qType dns.Type) bool {
_, found := s[QType(qType)]
return found
}
func (s *QTypeSet) Insert(qType dns.Type) {
if *s == nil {
*s = make(QTypeSet, 1)
}
(*s)[QType(qType)] = struct{}{}
}
type Duration time.Duration
func (c Duration) String() string {
return durafmt.Parse(time.Duration(c)).String()
}
//nolint:gochecknoglobals
var netDefaultPort = map[NetProtocol]uint16{
NetProtocolTcpUdp: udpPort,
@ -135,82 +103,12 @@ var netDefaultPort = map[NetProtocol]uint16{
NetProtocolHttps: httpsPort,
}
// Upstream is the definition of external DNS server
type Upstream struct {
Net NetProtocol
Host string
Port uint16
Path string
CommonName string // Common Name to use for certificate verification; optional. "" uses .Host
}
// IsDefault returns true if u is the default value
func (u *Upstream) IsDefault() bool {
return *u == Upstream{}
}
// String returns the string representation of u
func (u Upstream) String() string {
if u.IsDefault() {
return "no upstream"
}
var sb strings.Builder
sb.WriteString(u.Net.String())
sb.WriteRune(':')
if u.Net == NetProtocolHttps {
sb.WriteString("//")
}
isIPv6 := strings.ContainsRune(u.Host, ':')
if isIPv6 {
sb.WriteRune('[')
sb.WriteString(u.Host)
sb.WriteRune(']')
} else {
sb.WriteString(u.Host)
}
if u.Port != netDefaultPort[u.Net] {
sb.WriteRune(':')
sb.WriteString(fmt.Sprint(u.Port))
}
if u.Path != "" {
sb.WriteString(u.Path)
}
return sb.String()
}
// UnmarshalYAML creates Upstream from YAML
func (u *Upstream) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
upstream, err := ParseUpstream(s)
if err != nil {
return fmt.Errorf("can't convert upstream '%s': %w", s, err)
}
*u = upstream
return nil
}
// ListenConfig is a list of address(es) to listen on
type ListenConfig []string
// UnmarshalYAML creates ListenConfig from YAML
func (l *ListenConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
var addresses string
if err := unmarshal(&addresses); err != nil {
return err
}
// UnmarshalText implements `encoding.TextUnmarshaler`.
func (l *ListenConfig) UnmarshalText(data []byte) error {
addresses := string(data)
*l = strings.Split(addresses, ",")
@ -256,225 +154,11 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e
return nil
}
// UnmarshalYAML creates ConditionalUpstreamMapping from YAML
func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input map[string]string
if err := unmarshal(&input); err != nil {
return err
}
result := make(map[string][]Upstream, len(input))
for k, v := range input {
var upstreams []Upstream
for _, part := range strings.Split(v, ",") {
upstream, err := ParseUpstream(strings.TrimSpace(part))
if err != nil {
return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err)
}
upstreams = append(upstreams, upstream)
}
result[k] = upstreams
}
c.Upstreams = result
return nil
}
// UnmarshalYAML creates CustomDNSMapping from YAML
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input map[string]string
if err := unmarshal(&input); err != nil {
return err
}
result := make(map[string][]net.IP, len(input))
for k, v := range input {
var ips []net.IP
for _, part := range strings.Split(v, ",") {
ip := net.ParseIP(strings.TrimSpace(part))
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
ips = append(ips, ip)
}
result[k] = ips
}
c.HostIPs = result
return nil
}
// UnmarshalYAML creates Duration from YAML. If no unit is used, uses minutes
func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}
if minutes, err := strconv.Atoi(input); err == nil {
// duration is defined as number without unit
// use minutes to ensure back compatibility
*c = Duration(time.Duration(minutes) * time.Minute)
return nil
}
duration, err := time.ParseDuration(input)
if err == nil {
*c = Duration(duration)
return nil
}
return err
}
func (c *QType) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}
t, found := dns.StringToType[input]
if !found {
types := make([]string, 0, len(dns.StringToType))
for k := range dns.StringToType {
types = append(types, k)
}
sort.Strings(types)
return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'",
input, strings.Join(types, ", "))
}
*c = QType(t)
return nil
}
func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input []QType
if err := unmarshal(&input); err != nil {
return err
}
*s = make(QTypeSet, len(input))
for _, qType := range input {
(*s)[qType] = struct{}{}
}
return nil
}
var validDomain = regexp.MustCompile(
`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname]
func ParseUpstream(upstream string) (Upstream, error) {
var path string
var port uint16
commonName, upstream := extractCommonName(upstream)
n, upstream := extractNet(upstream)
path, upstream = extractPath(upstream)
host, portString, err := net.SplitHostPort(upstream)
// string contains host:port
if err == nil {
p, err := ConvertPort(portString)
if err != nil {
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
return Upstream{}, err
}
port = p
} else {
// only host, use default port
host = upstream
port = netDefaultPort[n]
// trim any IPv6 brackets
host = strings.TrimPrefix(host, "[")
host = strings.TrimSuffix(host, "]")
}
// validate hostname or ip
if ip := net.ParseIP(host); ip == nil {
// is not IP
if !validDomain.MatchString(host) {
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
}
}
return Upstream{
Net: n,
Host: host,
Port: port,
Path: path,
CommonName: commonName,
}, nil
}
func extractCommonName(in string) (string, string) {
upstream, cn, _ := strings.Cut(in, "#")
return cn, upstream
}
func extractPath(in string) (path, upstream string) {
slashIdx := strings.Index(in, "/")
if slashIdx >= 0 {
path = in[slashIdx:]
upstream = in[:slashIdx]
} else {
upstream = in
}
return
}
func extractNet(upstream string) (NetProtocol, string) {
tcpUDPPrefix := NetProtocolTcpUdp.String() + ":"
if strings.HasPrefix(upstream, tcpUDPPrefix) {
return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):]
}
tcpTLSPrefix := NetProtocolTcpTls.String() + ":"
if strings.HasPrefix(upstream, tcpTLSPrefix) {
return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):]
}
httpsPrefix := NetProtocolHttps.String() + ":"
if strings.HasPrefix(upstream, httpsPrefix) {
return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//")
}
return NetProtocolTcpUdp, upstream
}
// Config main configuration
//
//nolint:maligned
type Config struct {
Upstream UpstreamConfig `yaml:"upstream"`
Upstream ParallelBestConfig `yaml:"upstream"`
UpstreamTimeout Duration `yaml:"upstreamTimeout" default:"2s"`
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
CustomDNS CustomDNSConfig `yaml:"customDNS"`
@ -483,7 +167,7 @@ type Config struct {
ClientLookup ClientLookupConfig `yaml:"clientLookup"`
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus PrometheusConfig `yaml:"prometheus"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
@ -494,7 +178,7 @@ type Config struct {
KeyFile string `yaml:"keyFile"`
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
HostsFile HostsFileConfig `yaml:"hostsFile"`
FqdnOnly bool `yaml:"fqdnOnly" default:"false"`
FqdnOnly FqdnOnlyConfig `yaml:",inline"`
Filtering FilteringConfig `yaml:"filtering"`
Ede EdeConfig `yaml:"ede"`
// Deprecated
@ -508,7 +192,7 @@ type Config struct {
// Deprecated
LogTimestamp bool `yaml:"logTimestamp" default:"true"`
// Deprecated
DNSPorts ListenConfig `yaml:"port" default:"[\"53\"]"`
DNSPorts ListenConfig `yaml:"port" default:"\"53\""`
// Deprecated
HTTPPorts ListenConfig `yaml:"httpPort"`
// Deprecated
@ -518,12 +202,19 @@ type Config struct {
}
type PortsConfig struct {
DNS ListenConfig `yaml:"dns" default:"[\"53\"]"`
DNS ListenConfig `yaml:"dns" default:"\"53\""`
HTTP ListenConfig `yaml:"http"`
HTTPS ListenConfig `yaml:"https"`
TLS ListenConfig `yaml:"tls"`
}
func (c *PortsConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("DNS = %s", c.DNS)
logger.Infof("TLS = %s", c.TLS)
logger.Infof("HTTP = %s", c.HTTP)
logger.Infof("HTTPS = %s", c.HTTPS)
}
// split in two types to avoid infinite recursion. See `BootstrapDNSConfig.UnmarshalYAML`.
type (
BootstrapDNSConfig bootstrapDNSConfig
@ -539,105 +230,6 @@ type (
}
)
// PrometheusConfig contains the config values for prometheus
type PrometheusConfig struct {
Enable bool `yaml:"enable" default:"false"`
Path string `yaml:"path" default:"/metrics"`
}
// UpstreamConfig upstream server configuration
type UpstreamConfig struct {
ExternalResolvers map[string][]Upstream `yaml:",inline"`
}
// RewriteConfig custom DNS configuration
type RewriteConfig struct {
Rewrite map[string]string `yaml:"rewrite"`
FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"`
}
// CustomDNSConfig custom DNS configuration
type CustomDNSConfig struct {
RewriteConfig `yaml:",inline"`
CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"`
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
}
// CustomDNSMapping mapping for the custom DNS configuration
type CustomDNSMapping struct {
HostIPs map[string][]net.IP
}
// ConditionalUpstreamConfig conditional upstream configuration
type ConditionalUpstreamConfig struct {
RewriteConfig `yaml:",inline"`
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
}
// ConditionalUpstreamMapping mapping for conditional configuration
type ConditionalUpstreamMapping struct {
Upstreams map[string][]Upstream
}
// BlockingConfig configuration for query blocking
type BlockingConfig struct {
BlackLists map[string][]string `yaml:"blackLists"`
WhiteLists map[string][]string `yaml:"whiteLists"`
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
BlockType string `yaml:"blockType" default:"ZEROIP"`
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
// Deprecated
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
}
// ClientLookupConfig configuration for the client lookup
type ClientLookupConfig struct {
ClientnameIPMapping map[string][]net.IP `yaml:"clients"`
Upstream Upstream `yaml:"upstream"`
SingleNameOrder []uint `yaml:"singleNameOrder"`
}
// CachingConfig configuration for domain caching
type CachingConfig struct {
MinCachingTime Duration `yaml:"minTime"`
MaxCachingTime Duration `yaml:"maxTime"`
CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"`
MaxItemsCount int `yaml:"maxItemsCount"`
Prefetching bool `yaml:"prefetching"`
PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"`
PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"`
PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"`
}
func (c *CachingConfig) EnablePrefetch() {
const day = 24 * time.Hour
if c.MaxCachingTime == 0 {
// make sure resolver gets enabled
c.MaxCachingTime = Duration(day)
}
c.Prefetching = true
c.PrefetchThreshold = 0
}
// QueryLogConfig configuration for the query logging
type QueryLogConfig struct {
Target string `yaml:"target"`
Type QueryLogType `yaml:"type"`
LogRetentionDays uint64 `yaml:"logRetentionDays"`
CreationAttempts int `yaml:"creationAttempts" default:"3"`
CreationCooldown Duration `yaml:"creationCooldown" default:"2s"`
Fields []QueryLogField `yaml:"fields"`
}
// RedisConfig configuration for the redis connection
type RedisConfig struct {
Address string `yaml:"address"`
@ -652,21 +244,25 @@ type RedisConfig struct {
SentinelAddresses []string `yaml:"sentinelAddresses"`
}
type HostsFileConfig struct {
Filepath string `yaml:"filePath"`
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
FilterLoopback bool `yaml:"filterLoopback"`
}
type (
FqdnOnlyConfig = toEnable
EdeConfig = toEnable
)
type FilteringConfig struct {
QueryTypes QTypeSet `yaml:"queryTypes"`
}
type EdeConfig struct {
type toEnable struct {
Enable bool `yaml:"enable" default:"false"`
}
// IsEnabled implements `config.Configurable`.
func (c *toEnable) IsEnabled() bool {
return c.Enable
}
// LogConfig implements `config.Configurable`.
func (c *toEnable) LogConfig(logger *logrus.Entry) {
logger.Info("enabled")
}
//nolint:gochecknoglobals
var (
config = &Config{}

View File

@ -40,6 +40,15 @@ func IPVersionNames() []string {
return tmp
}
// IPVersionValues returns a list of the values for IPVersion
func IPVersionValues() []IPVersion {
return []IPVersion{
IPVersionDual,
IPVersionV4,
IPVersionV6,
}
}
var _IPVersionMap = map[IPVersion]string{
IPVersionDual: _IPVersionName[0:4],
IPVersionV4: _IPVersionName[4:6],
@ -113,6 +122,15 @@ func NetProtocolNames() []string {
return tmp
}
// NetProtocolValues returns a list of the values for NetProtocol
func NetProtocolValues() []NetProtocol {
return []NetProtocol{
NetProtocolTcpUdp,
NetProtocolTcpTls,
NetProtocolHttps,
}
}
var _NetProtocolMap = map[NetProtocol]string{
NetProtocolTcpUdp: _NetProtocolName[0:7],
NetProtocolTcpTls: _NetProtocolName[7:14],
@ -190,6 +208,18 @@ func QueryLogFieldNames() []string {
return tmp
}
// QueryLogFieldValues returns a list of the values for QueryLogField
func QueryLogFieldValues() []QueryLogField {
return []QueryLogField{
QueryLogFieldClientIP,
QueryLogFieldClientName,
QueryLogFieldResponseReason,
QueryLogFieldResponseAnswer,
QueryLogFieldQuestion,
QueryLogFieldDuration,
}
}
// String implements the Stringer interface.
func (x QueryLogField) String() string {
return string(x)
@ -274,6 +304,18 @@ func QueryLogTypeNames() []string {
return tmp
}
// QueryLogTypeValues returns a list of the values for QueryLogType
func QueryLogTypeValues() []QueryLogType {
return []QueryLogType{
QueryLogTypeConsole,
QueryLogTypeNone,
QueryLogTypeMysql,
QueryLogTypePostgresql,
QueryLogTypeCsv,
QueryLogTypeCsvClient,
}
}
var _QueryLogTypeMap = map[QueryLogType]string{
QueryLogTypeConsole: _QueryLogTypeName[0:7],
QueryLogTypeNone: _QueryLogTypeName[7:11],
@ -353,6 +395,15 @@ func StartStrategyTypeNames() []string {
return tmp
}
// StartStrategyTypeValues returns a list of the values for StartStrategyType
func StartStrategyTypeValues() []StartStrategyType {
return []StartStrategyType{
StartStrategyTypeBlocking,
StartStrategyTypeFailOnError,
StartStrategyTypeFast,
}
}
var _StartStrategyTypeMap = map[StartStrategyType]string{
StartStrategyTypeBlocking: _StartStrategyTypeName[0:8],
StartStrategyTypeFailOnError: _StartStrategyTypeName[8:19],

View File

@ -6,6 +6,12 @@ import (
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
)
var (
logger *logrus.Entry
hook *log.MockLoggerHook
)
func TestConfig(t *testing.T) {
@ -13,3 +19,9 @@ func TestConfig(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Config Suite")
}
func suiteBeforeEach() {
BeforeEach(func() {
logger, hook = log.NewMockEntry()
})
}

View File

@ -1,7 +1,6 @@
package config
import (
"errors"
"net"
"time"
@ -321,15 +320,11 @@ bootstrapDns:
})
})
Describe("YAML parsing", func() {
Describe("Parsing", func() {
Context("upstream", func() {
It("should create the upstream struct with data", func() {
u := &Upstream{}
err := u.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "tcp+udp:1.2.3.4"
return nil
})
err := u.UnmarshalText([]byte("tcp+udp:1.2.3.4"))
Expect(err).Should(Succeed())
Expect(u.Net).Should(Equal(NetProtocolTcpUdp))
Expect(u.Host).Should(Equal("1.2.3.4"))
@ -338,139 +333,18 @@ bootstrapDns:
It("should fail if the upstream is in wrong format", func() {
u := &Upstream{}
err := u.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
err := u.UnmarshalText([]byte("invalid!"))
Expect(err).Should(HaveOccurred())
})
})
Context("ListenConfig", func() {
It("should parse and split valid string config", func() {
l := &ListenConfig{}
err := l.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "55,:56"
return nil
})
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
})
It("should fail on error", func() {
l := &ListenConfig{}
err := l.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
})
})
Context("Duration", func() {
It("should parse duration with unit", func() {
d := Duration(0)
err := d.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "1m20s"
return nil
})
Expect(err).Should(Succeed())
Expect(d).Should(Equal(Duration(80 * time.Second)))
Expect(d.String()).Should(Equal("1 minute 20 seconds"))
})
It("should fail if duration is in wrong format", func() {
d := Duration(0)
err := d.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "wrong"
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("time: invalid duration \"wrong\""))
})
It("should fail if wrong YAML format", func() {
d := Duration(0)
err := d.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
Context("ConditionalUpstreamMapping", func() {
It("Should parse config as map", func() {
c := &ConditionalUpstreamMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
Expect(c.Upstreams).Should(HaveLen(1))
Expect(c.Upstreams["key"]).Should(HaveLen(1))
Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{
Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53,
}))
})
It("should fail if wrong YAML format", func() {
c := &ConditionalUpstreamMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
Context("CustomDNSMapping", func() {
It("Should parse config as map", func() {
c := &CustomDNSMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
Expect(c.HostIPs).Should(HaveLen(1))
Expect(c.HostIPs["key"]).Should(HaveLen(1))
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4")))
})
It("should fail if wrong YAML format", func() {
c := &CustomDNSMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
Context("QueryTyoe", func() {
It("Should parse existing DNS type as string", func() {
t := QType(0)
err := t.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "AAAA"
return nil
})
Expect(err).Should(Succeed())
Expect(t).Should(Equal(QType(dns.TypeAAAA)))
Expect(t.String()).Should(Equal("AAAA"))
})
It("should fail if DNS type does not exist", func() {
t := QType(0)
err := t.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "WRONGTYPE"
return nil
})
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'"))
})
It("should fail if wrong YAML format", func() {
d := QType(0)
err := d.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
})
@ -651,29 +525,6 @@ bootstrapDns:
"tcp-tls:[fd00::6cd4:d7e0:d99d:2952]",
),
)
Describe("QTypeSet", func() {
It("new should insert given qTypes", func() {
set := NewQTypeSet(dns.Type(dns.TypeA))
Expect(set).Should(HaveKey(QType(dns.TypeA)))
Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue())
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
})
It("should insert given qTypes", func() {
set := NewQTypeSet()
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
set.Insert(dns.Type(dns.TypeAAAA))
Expect(set).Should(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue())
})
})
})
func defaultTestFileConfig() {
@ -700,8 +551,8 @@ func defaultTestFileConfig() {
Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
Expect(config.Filtering.QueryTypes).Should(HaveLen(2))
Expect(config.Caching.MaxCachingTime).Should(Equal(Duration(0)))
Expect(config.Caching.MinCachingTime).Should(Equal(Duration(0)))
Expect(config.Caching.MaxCachingTime.IsZero()).Should(BeTrue())
Expect(config.Caching.MinCachingTime.IsZero()).Should(BeTrue())
Expect(config.DoHUserAgent).Should(Equal("testBlocky"))
Expect(config.MinTLSServeVer).Should(Equal("1.3"))

68
config/custom_dns.go Normal file
View File

@ -0,0 +1,68 @@
package config
import (
"fmt"
"net"
"strings"
"github.com/sirupsen/logrus"
)
// CustomDNSConfig custom DNS configuration
type CustomDNSConfig struct {
RewriterConfig `yaml:",inline"`
CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"`
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
}
// CustomDNSMapping mapping for the custom DNS configuration
type CustomDNSMapping struct {
HostIPs map[string][]net.IP `yaml:"hostIPs"`
}
// IsEnabled implements `config.Configurable`.
func (c *CustomDNSConfig) IsEnabled() bool {
return len(c.Mapping.HostIPs) != 0
}
// LogConfig implements `config.Configurable`.
func (c *CustomDNSConfig) LogConfig(logger *logrus.Entry) {
logger.Debugf("TTL = %s", c.CustomTTL)
logger.Debugf("filterUnmappedTypes = %t", c.FilterUnmappedTypes)
logger.Info("mapping:")
for key, val := range c.Mapping.HostIPs {
logger.Infof(" %s = %s", key, val)
}
}
// UnmarshalYAML implements `yaml.Unmarshaler`.
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input map[string]string
if err := unmarshal(&input); err != nil {
return err
}
result := make(map[string][]net.IP, len(input))
for k, v := range input {
var ips []net.IP
for _, part := range strings.Split(v, ",") {
ip := net.ParseIP(strings.TrimSpace(part))
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
ips = append(ips, ip)
}
result[k] = ips
}
c.HostIPs = result
return nil
}

91
config/custom_dns_test.go Normal file
View File

@ -0,0 +1,91 @@
package config
import (
"errors"
"net"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("CustomDNSConfig", func() {
var (
cfg CustomDNSConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = CustomDNSConfig{
Mapping: CustomDNSMapping{
HostIPs: map[string][]net.IP{
"custom.domain": {net.ParseIP("192.168.143.123")},
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
"multiple.ips": {
net.ParseIP("192.168.143.123"),
net.ParseIP("192.168.143.125"),
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
},
},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := CustomDNSConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := CustomDNSConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("custom.domain = ")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("multiple.ips = ")))
})
})
Describe("UnmarshalYAML", func() {
It("Should parse config as map", func() {
c := &CustomDNSMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
Expect(c.HostIPs).Should(HaveLen(1))
Expect(c.HostIPs["key"]).Should(HaveLen(1))
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4")))
})
It("should fail if wrong YAML format", func() {
c := &CustomDNSMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
return errors.New("some err")
})
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("some err"))
})
})
})

54
config/duration.go Normal file
View File

@ -0,0 +1,54 @@
package config
import (
"strconv"
"time"
"github.com/0xERR0R/blocky/log"
"github.com/hako/durafmt"
)
type Duration time.Duration
func (c Duration) ToDuration() time.Duration {
return time.Duration(c)
}
func (c Duration) IsZero() bool {
return c.ToDuration() == 0
}
func (c Duration) Seconds() float64 {
return c.ToDuration().Seconds()
}
func (c Duration) SecondsU32() uint32 {
return uint32(c.Seconds())
}
func (c Duration) String() string {
return durafmt.Parse(c.ToDuration()).String()
}
// UnmarshalText implements `encoding.TextUnmarshaler`.
func (c *Duration) UnmarshalText(data []byte) error {
input := string(data)
if minutes, err := strconv.Atoi(input); err == nil {
// number without unit: use minutes to ensure back compatibility
*c = Duration(time.Duration(minutes) * time.Minute)
log.Log().Warnf("Setting a duration without a unit is deprecated. Please use '%s min' instead.", input)
return nil
}
duration, err := time.ParseDuration(input)
if err == nil {
*c = Duration(duration)
return nil
}
return err
}

44
config/duration_test.go Normal file
View File

@ -0,0 +1,44 @@
package config
import (
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Duration", func() {
var d Duration
BeforeEach(func() {
var zero Duration
d = zero
})
Describe("UnmarshalText", func() {
It("should parse duration with unit", func() {
err := d.UnmarshalText([]byte("1m20s"))
Expect(err).Should(Succeed())
Expect(d).Should(Equal(Duration(80 * time.Second)))
Expect(d.String()).Should(Equal("1 minute 20 seconds"))
})
It("should fail if duration is in wrong format", func() {
err := d.UnmarshalText([]byte("wrong"))
Expect(err).Should(HaveOccurred())
Expect(err).Should(MatchError("time: invalid duration \"wrong\""))
})
})
Describe("IsZero", func() {
It("should be true for zero", func() {
Expect(d.IsZero()).Should(BeTrue())
Expect(Duration(0).IsZero()).Should(BeTrue())
})
It("should be false for non-zero", func() {
Expect(Duration(time.Second).IsZero()).Should(BeFalse())
})
})
})

23
config/filtering.go Normal file
View File

@ -0,0 +1,23 @@
package config
import (
"github.com/sirupsen/logrus"
)
type FilteringConfig struct {
QueryTypes QTypeSet `yaml:"queryTypes"`
}
// IsEnabled implements `config.Configurable`.
func (c *FilteringConfig) IsEnabled() bool {
return len(c.QueryTypes) != 0
}
// LogConfig implements `config.Configurable`.
func (c *FilteringConfig) LogConfig(logger *logrus.Entry) {
logger.Info("query types:")
for qType := range c.QueryTypes {
logger.Infof(" - %s", qType)
}
}

57
config/filtering_test.go Normal file
View File

@ -0,0 +1,57 @@
package config
import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("FilteringConfig", func() {
var (
cfg FilteringConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = FilteringConfig{
QueryTypes: NewQTypeSet(AAAA, MX),
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := FilteringConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := FilteringConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).Should(HaveLen(3))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("query types:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - AAAA")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - MX")))
})
})
})

25
config/hosts_file.go Normal file
View File

@ -0,0 +1,25 @@
package config
import (
"github.com/sirupsen/logrus"
)
type HostsFileConfig struct {
Filepath string `yaml:"filePath"`
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
FilterLoopback bool `yaml:"filterLoopback"`
}
// IsEnabled implements `config.Configurable`.
func (c *HostsFileConfig) IsEnabled() bool {
return len(c.Filepath) != 0
}
// LogConfig implements `config.Configurable`.
func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("file path: %s", c.Filepath)
logger.Infof("TTL: %s", c.HostsTTL)
logger.Infof("refresh period: %s", c.RefreshPeriod)
logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
}

58
config/hosts_file_test.go Normal file
View File

@ -0,0 +1,58 @@
package config
import (
"time"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("HostsFileConfig", func() {
var (
cfg HostsFileConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = HostsFileConfig{
Filepath: "/dev/null",
HostsTTL: Duration(29 * time.Minute),
RefreshPeriod: Duration(30 * time.Minute),
FilterLoopback: true,
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := HostsFileConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := HostsFileConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null")))
})
})
})

19
config/metrics.go Normal file
View File

@ -0,0 +1,19 @@
package config
import "github.com/sirupsen/logrus"
// MetricsConfig contains the config values for prometheus
type MetricsConfig struct {
Enable bool `yaml:"enable" default:"false"`
Path string `yaml:"path" default:"/metrics"`
}
// IsEnabled implements `config.Configurable`.
func (c *MetricsConfig) IsEnabled() bool {
return c.Enable
}
// LogConfig implements `config.Configurable`.
func (c *MetricsConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("url path: %s", c.Path)
}

54
config/metrics_test.go Normal file
View File

@ -0,0 +1,54 @@
package config
import (
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("MetricsConfig", func() {
var (
cfg MetricsConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = MetricsConfig{
Enable: true,
Path: "/custom/path",
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := MetricsConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := MetricsConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).Should(HaveLen(1))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("url path: /custom/path")))
})
})
})

32
config/parallel_best.go Normal file
View File

@ -0,0 +1,32 @@
package config
import (
"github.com/sirupsen/logrus"
)
const UpstreamDefaultCfgName = "default"
// ParallelBestConfig upstream server configuration
type ParallelBestConfig struct {
ExternalResolvers ParallelBestMapping `yaml:",inline"`
}
type ParallelBestMapping map[string][]Upstream
// IsEnabled implements `config.Configurable`.
func (c *ParallelBestConfig) IsEnabled() bool {
return len(c.ExternalResolvers) != 0
}
// LogConfig implements `config.Configurable`.
func (c *ParallelBestConfig) LogConfig(logger *logrus.Entry) {
logger.Info("upstream resolvers:")
for name, upstreams := range c.ExternalResolvers {
logger.Infof(" %s:", name)
for _, upstream := range upstreams {
logger.Infof(" - %s", upstream)
}
}
}

View File

@ -0,0 +1,59 @@
package config
import (
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ParallelBestConfig", func() {
var (
cfg ParallelBestConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = ParallelBestConfig{
ExternalResolvers: ParallelBestMapping{
UpstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
},
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := ParallelBestConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := ParallelBestConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstream resolvers:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
})
})
})

76
config/qtype_set.go Normal file
View File

@ -0,0 +1,76 @@
package config
import (
"fmt"
"sort"
"strings"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
)
type QTypeSet map[QType]struct{}
func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
s := make(QTypeSet, len(qTypes))
for _, qType := range qTypes {
s.Insert(qType)
}
return s
}
func (s QTypeSet) Contains(qType dns.Type) bool {
_, found := s[QType(qType)]
return found
}
func (s *QTypeSet) Insert(qType dns.Type) {
if *s == nil {
*s = make(QTypeSet, 1)
}
(*s)[QType(qType)] = struct{}{}
}
func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input []QType
if err := unmarshal(&input); err != nil {
return err
}
*s = make(QTypeSet, len(input))
for _, qType := range input {
(*s)[qType] = struct{}{}
}
return nil
}
type QType dns.Type
func (c QType) String() string {
return dns.Type(c).String()
}
// UnmarshalText implements `encoding.TextUnmarshaler`.
func (c *QType) UnmarshalText(data []byte) error {
input := string(data)
t, found := dns.StringToType[input]
if !found {
types := maps.Keys(dns.StringToType)
sort.Strings(types)
return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'",
input, strings.Join(types, ", "))
}
*c = QType(t)
return nil
}

60
config/qtype_set_test.go Normal file
View File

@ -0,0 +1,60 @@
package config
import (
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("QTypeSet", func() {
Describe("NewQTypeSet", func() {
It("should insert given qTypes", func() {
set := NewQTypeSet(dns.Type(dns.TypeA))
Expect(set).Should(HaveKey(QType(dns.TypeA)))
Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue())
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
})
})
Describe("Insert", func() {
It("should insert given qTypes", func() {
set := NewQTypeSet()
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
set.Insert(dns.Type(dns.TypeAAAA))
Expect(set).Should(HaveKey(QType(dns.TypeAAAA)))
Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue())
})
})
})
var _ = Describe("QType", func() {
Describe("UnmarshalText", func() {
It("Should parse existing DNS type as string", func() {
t := QType(0)
err := t.UnmarshalText([]byte("AAAA"))
Expect(err).Should(Succeed())
Expect(t).Should(Equal(QType(dns.TypeAAAA)))
Expect(t.String()).Should(Equal("AAAA"))
})
It("should fail if DNS type does not exist", func() {
t := QType(0)
err := t.UnmarshalText([]byte("WRONGTYPE"))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'"))
})
It("should fail if wrong YAML format", func() {
d := QType(0)
err := d.UnmarshalText([]byte("some err"))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'some err'"))
})
})
})

41
config/query_log.go Normal file
View File

@ -0,0 +1,41 @@
package config
import (
"github.com/sirupsen/logrus"
)
// QueryLogConfig configuration for the query logging
type QueryLogConfig struct {
Target string `yaml:"target"`
Type QueryLogType `yaml:"type"`
LogRetentionDays uint64 `yaml:"logRetentionDays"`
CreationAttempts int `yaml:"creationAttempts" default:"3"`
CreationCooldown Duration `yaml:"creationCooldown" default:"2s"`
Fields []QueryLogField `yaml:"fields"`
}
// SetDefaults implements `defaults.Setter`.
func (c *QueryLogConfig) SetDefaults() {
// Since the default depends on the enum values, set it dynamically
// to avoid having to repeat the values in the annotation.
c.Fields = QueryLogFieldValues()
}
// IsEnabled implements `config.Configurable`.
func (c *QueryLogConfig) IsEnabled() bool {
return c.Type != QueryLogTypeNone
}
// LogConfig implements `config.Configurable`.
func (c *QueryLogConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("type: %s", c.Type)
if c.Target != "" {
logger.Infof("target: %s", c.Target)
}
logger.Infof("logRetentionDays: %d", c.LogRetentionDays)
logger.Debugf("creationAttempts: %d", c.CreationAttempts)
logger.Debugf("creationCooldown: %s", c.CreationCooldown)
logger.Infof("fields: %s", c.Fields)
}

72
config/query_log_test.go Normal file
View File

@ -0,0 +1,72 @@
package config
import (
"time"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("QueryLogConfig", func() {
var (
cfg QueryLogConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = QueryLogConfig{
Target: "/dev/null",
Type: QueryLogTypeCsvClient,
LogRetentionDays: 0,
CreationAttempts: 1,
CreationCooldown: Duration(time.Millisecond),
}
})
Describe("IsEnabled", func() {
It("should be true by default", func() {
cfg := QueryLogConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeTrue())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := QueryLogConfig{
Type: QueryLogTypeNone,
}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("logRetentionDays:")))
})
})
Describe("SetDefaults", func() {
It("should log configuration", func() {
cfg := QueryLogConfig{}
Expect(cfg.Fields).Should(BeEmpty())
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.Fields).ShouldNot(BeEmpty())
})
})
})

27
config/rewriter.go Normal file
View File

@ -0,0 +1,27 @@
package config
import (
"github.com/sirupsen/logrus"
)
// RewriterConfig rewrite rule configuration
type RewriterConfig struct {
Rewrite map[string]string `yaml:"rewrite"`
FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"`
}
// IsEnabled implements `config.Configurable`.
func (c *RewriterConfig) IsEnabled() bool {
return len(c.Rewrite) != 0
}
// LogConfig implements `config.Configurable`.
func (c *RewriterConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("fallbackUpstream = %t", c.FallbackUpstream)
logger.Info("rules:")
for key, val := range c.Rewrite {
logger.Infof(" %s = %s", key, val)
}
}

57
config/rewriter_test.go Normal file
View File

@ -0,0 +1,57 @@
package config
import (
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("RewriterConfig", func() {
var (
cfg RewriterConfig
)
suiteBeforeEach()
BeforeEach(func() {
cfg = RewriterConfig{
Rewrite: map[string]string{
"original1": "rewritten1",
"original2": "rewritten2",
},
}
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := RewriterConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
})
When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})
When("disabled", func() {
It("should be false", func() {
cfg := RewriterConfig{}
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("rules:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("original2 =")))
})
})
})

config/upstream.go Normal file

@ -0,0 +1,164 @@
package config
import (
"fmt"
"net"
"regexp"
"strings"
)
var validDomain = regexp.MustCompile(
`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
// Upstream is the definition of external DNS server
type Upstream struct {
Net NetProtocol
Host string
Port uint16
Path string
CommonName string // Common Name to use for certificate verification; optional. "" uses .Host
}
// IsDefault returns true if u is the default value
func (u *Upstream) IsDefault() bool {
return *u == Upstream{}
}
// String returns the string representation of u
func (u Upstream) String() string {
if u.IsDefault() {
return "no upstream"
}
var sb strings.Builder
sb.WriteString(u.Net.String())
sb.WriteRune(':')
if u.Net == NetProtocolHttps {
sb.WriteString("//")
}
isIPv6 := strings.ContainsRune(u.Host, ':')
if isIPv6 {
sb.WriteRune('[')
sb.WriteString(u.Host)
sb.WriteRune(']')
} else {
sb.WriteString(u.Host)
}
if u.Port != netDefaultPort[u.Net] {
sb.WriteRune(':')
sb.WriteString(fmt.Sprint(u.Port))
}
if u.Path != "" {
sb.WriteString(u.Path)
}
return sb.String()
}
// UnmarshalText implements `encoding.TextUnmarshaler`.
func (u *Upstream) UnmarshalText(data []byte) error {
s := string(data)
upstream, err := ParseUpstream(s)
if err != nil {
return fmt.Errorf("can't convert upstream '%s': %w", s, err)
}
*u = upstream
return nil
}
// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname]
func ParseUpstream(upstream string) (Upstream, error) {
var path string
var port uint16
commonName, upstream := extractCommonName(upstream)
n, upstream := extractNet(upstream)
path, upstream = extractPath(upstream)
host, portString, err := net.SplitHostPort(upstream)
// string contains host:port
if err == nil {
p, err := ConvertPort(portString)
if err != nil {
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
return Upstream{}, err
}
port = p
} else {
// only host, use default port
host = upstream
port = netDefaultPort[n]
// trim any IPv6 brackets
host = strings.TrimPrefix(host, "[")
host = strings.TrimSuffix(host, "]")
}
// validate hostname or ip
if ip := net.ParseIP(host); ip == nil {
// is not IP
if !validDomain.MatchString(host) {
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
}
}
return Upstream{
Net: n,
Host: host,
Port: port,
Path: path,
CommonName: commonName,
}, nil
}
func extractCommonName(in string) (string, string) {
upstream, cn, _ := strings.Cut(in, "#")
return cn, upstream
}
func extractPath(in string) (path, upstream string) {
slashIdx := strings.Index(in, "/")
if slashIdx >= 0 {
path = in[slashIdx:]
upstream = in[:slashIdx]
} else {
upstream = in
}
return
}
func extractNet(upstream string) (NetProtocol, string) {
tcpUDPPrefix := NetProtocolTcpUdp.String() + ":"
if strings.HasPrefix(upstream, tcpUDPPrefix) {
return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):]
}
tcpTLSPrefix := NetProtocolTcpTls.String() + ":"
if strings.HasPrefix(upstream, tcpTLSPrefix) {
return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):]
}
httpsPrefix := NetProtocolHttps.String() + ":"
if strings.HasPrefix(upstream, httpsPrefix) {
return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//")
}
return NetProtocolTcpUdp, upstream
}
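To illustrate the `[net]:host[:port][/path][#commonname]` format parsed above, here is an editor's sketch (not part of the commit); the `tcp-tls` protocol string and its 853 default port are assumed from blocky's `NetProtocol` values.
package main
import (
	"fmt"
	"github.com/0xERR0R/blocky/config"
)
func main() {
	// DNS-over-TLS upstream with an explicit common name for certificate verification.
	u, err := config.ParseUpstream("tcp-tls:1.1.1.1:853#cloudflare-dns.com")
	if err != nil {
		panic(err)
	}
	fmt.Println(u.Host)       // 1.1.1.1
	fmt.Println(u.CommonName) // cloudflare-dns.com
	// String omits the port again because 853 matches the assumed tcp-tls default.
	fmt.Println(u)
}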

go.mod

@ -123,6 +123,7 @@ require (
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/exp v0.0.0-20230307190834-24139beb5833
golang.org/x/mod v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/term v0.6.0 // indirect

go.sum

@ -475,6 +475,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=


@ -12,7 +12,6 @@ import (
"sync"
"time"
"github.com/hako/durafmt"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
@ -38,9 +37,6 @@ type ListCacheType int
type Matcher interface {
// Match matches passed domain name against cached list entries
Match(domain string, groupsToCheck []string) (found bool, group string)
// Configuration returns current configuration and stats
Configuration() []string
}
// ListCache generic cache of strings divided in groups
@ -55,42 +51,19 @@ type ListCache struct {
processingConcurrency uint
}
// Configuration returns current configuration and stats
func (b *ListCache) Configuration() (result []string) {
if b.refreshPeriod > 0 {
result = append(result, fmt.Sprintf("refresh period: %s", durafmt.Parse(b.refreshPeriod)))
} else {
result = append(result, "refresh: disabled")
}
result = append(result, "group links:")
for group, links := range b.groupToLinks {
result = append(result, fmt.Sprintf(" %s:", group))
for _, link := range links {
if strings.Contains(link, "\n") {
link = "[INLINE DEFINITION]"
}
result = append(result, fmt.Sprintf(" - %s", link))
}
}
result = append(result, "group caches:")
// LogConfig implements `config.Configurable`.
func (b *ListCache) LogConfig(logger *logrus.Entry) {
var total int
b.lock.RLock()
defer b.lock.RUnlock()
for group, cache := range b.groupCaches {
result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount()))
logger.Infof("%s: %d entries", group, cache.ElementCount())
total += cache.ElementCount()
}
result = append(result, fmt.Sprintf(" TOTAL: %d entries", total))
return result
logger.Infof("TOTAL: %d entries", total)
}
// NewListCache creates new list instance


@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http/httptest"
"os"
@ -14,7 +13,9 @@ import (
. "github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/onsi/ginkgo/v2"
@ -414,36 +415,31 @@ var _ = Describe("ListCache", func() {
})
})
})
Describe("Configuration", func() {
When("refresh is enabled", func() {
It("should print list configuration", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {inlineList("inline", "definition")},
}
Describe("LogConfig", func() {
var (
logger *logrus.Entry
hook *log.MockLoggerHook
)
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
c := sut.Configuration()
Expect(c).Should(ContainElement("refresh period: 1 hour"))
Expect(len(c)).Should(BeNumerically(">", 1))
})
BeforeEach(func() {
logger, hook = log.NewMockEntry()
})
When("refresh is disabled", func() {
It("should print 'refresh disabled'", func() {
lists := map[string][]string{
"gr1": {emptyFile.Path},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(),
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
It("should print list configuration", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {inlineList("inline", "definition")},
}
c := sut.Configuration()
Expect(c).Should(ContainElement("refresh: disabled"))
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr1:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr2:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("TOTAL:")))
})
})
@ -486,7 +482,7 @@ func (m *MockDownloader) ListSource() string {
func createTestListFile(dir string, totalLines int) (string, int) {
file, err := os.CreateTemp(dir, "blocky")
if err != nil {
log.Fatal(err)
log.Log().Fatal(err)
}
w := bufio.NewWriter(file)


@ -6,9 +6,11 @@ import (
"fmt"
"io"
"strings"
"sync"
"github.com/sirupsen/logrus"
prefixed "github.com/x-cray/logrus-prefixed-formatter"
"golang.org/x/exp/maps"
)
const prefixField = "prefix"
@ -118,3 +120,50 @@ func ConfigureLogger(cfg *Config) {
func Silence() {
logger.Out = io.Discard
}
func WithIndent(log *logrus.Entry, prefix string, callback func(*logrus.Entry)) {
undo := indentMessages(prefix, log.Logger)
defer undo()
callback(log)
}
// indentMessages modifies a logger and adds `prefix` to all messages.
//
// The returned function must be called to remove the prefix.
func indentMessages(prefix string, logger *logrus.Logger) func() {
if _, ok := logger.Formatter.(*prefixed.TextFormatter); !ok {
// log is not plaintext, do nothing
return func() {}
}
oldHooks := maps.Clone(logger.Hooks)
logger.AddHook(prefixMsgHook{
prefix: prefix,
})
var once sync.Once
return func() {
once.Do(func() {
logger.ReplaceHooks(oldHooks)
})
}
}
type prefixMsgHook struct {
prefix string
}
// Levels implements `logrus.Hook`.
func (h prefixMsgHook) Levels() []logrus.Level {
return logrus.AllLevels
}
// Fire implements `logrus.Hook`.
func (h prefixMsgHook) Fire(entry *logrus.Entry) error {
entry.Message = h.prefix + entry.Message
return nil
}
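A rough usage sketch for `WithIndent` (editor's illustration, not part of the commit): the group name and entry count are made-up values, and the indentation only shows up when the logger uses the prefixed plaintext formatter, since `indentMessages` is a no-op otherwise.
package main
import (
	"github.com/sirupsen/logrus"
	"github.com/0xERR0R/blocky/log"
)
func main() {
	logger := logrus.NewEntry(log.Log())
	logger.Info("group caches:")
	log.WithIndent(logger, "  ", func(sub *logrus.Entry) {
		sub.Info("gr1: 42 entries") // rendered as "  gr1: 42 entries" with the prefixed formatter
	})
	logger.Info("done") // the prefix is removed once WithIndent returns
}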

log/mock_entry.go Normal file

@ -0,0 +1,42 @@
package log
import (
"io"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
)
func NewMockEntry() (*logrus.Entry, *MockLoggerHook) {
logger := logrus.New()
logger.Out = io.Discard
entry := logrus.Entry{Logger: logger}
hook := MockLoggerHook{}
entry.Logger.AddHook(&hook)
hook.On("Fire", mock.Anything).Return(nil)
return &entry, &hook
}
type MockLoggerHook struct {
mock.Mock
Messages []string
}
// Levels implements `logrus.Hook`.
func (h *MockLoggerHook) Levels() []logrus.Level {
return logrus.AllLevels
}
// Fire implements `logrus.Hook`.
func (h *MockLoggerHook) Fire(entry *logrus.Entry) error {
_ = h.Called()
h.Messages = append(h.Messages, entry.Message)
return nil
}


@ -18,7 +18,7 @@ func RegisterMetric(c prometheus.Collector) {
}
// Start starts prometheus endpoint
func Start(router *chi.Mux, cfg config.PrometheusConfig) {
func Start(router *chi.Mux, cfg config.MetricsConfig) {
if cfg.Enable {
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
_ = reg.Register(collectors.NewGoCollector())


@ -22,6 +22,31 @@ import (
// )
type ResponseType int
func (t ResponseType) ToExtendedErrorCode() uint16 {
switch t {
case ResponseTypeRESOLVED:
return dns.ExtendedErrorCodeOther
case ResponseTypeCACHED:
return dns.ExtendedErrorCodeCachedError
case ResponseTypeCONDITIONAL:
return dns.ExtendedErrorCodeForgedAnswer
case ResponseTypeCUSTOMDNS:
return dns.ExtendedErrorCodeForgedAnswer
case ResponseTypeHOSTSFILE:
return dns.ExtendedErrorCodeForgedAnswer
case ResponseTypeNOTFQDN:
return dns.ExtendedErrorCodeBlocked
case ResponseTypeBLOCKED:
return dns.ExtendedErrorCodeBlocked
case ResponseTypeFILTERED:
return dns.ExtendedErrorCodeFiltered
case ResponseTypeSPECIAL:
return dns.ExtendedErrorCodeFiltered
default:
return dns.ExtendedErrorCodeOther
}
}
// Response represents the response of a DNS query
type Response struct {
Res *dns.Msg


@ -83,7 +83,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
Password: cfg.Password,
DB: cfg.Database,
MaxRetries: cfg.ConnectionAttempts,
MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown),
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
})
} else {
rdb = redis.NewClient(&redis.Options{
@ -92,7 +92,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
Password: cfg.Password,
DB: cfg.Database,
MaxRetries: cfg.ConnectionAttempts,
MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown),
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
})
}


@ -35,7 +35,7 @@ func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
return nxDomainBlockHandler{}, nil
}
blockTime := uint32(time.Duration(cfg.BlockTTL).Seconds())
blockTime := cfg.BlockTTL.SecondsU32()
if strings.EqualFold(cfgBlockType, "ZEROIP") {
return zeroIPBlockHandler{
@ -78,16 +78,17 @@ type status struct {
// BlockingResolver checks request's question (domain name) against black and white lists
type BlockingResolver struct {
configurable[*config.BlockingConfig]
NextResolver
typed
blacklistMatcher *lists.ListCache
whitelistMatcher *lists.ListCache
cfg config.BlockingConfig
blockHandler blockHandler
whitelistOnlyGroups map[string]bool
status *status
clientGroupsBlock map[string][]string
redisClient *redis.Client
redisEnabled bool
fqdnIPCache expirationcache.ExpiringCache
}
@ -100,7 +101,7 @@ func NewBlockingResolver(
return nil, err
}
refreshPeriod := time.Duration(cfg.RefreshPeriod)
refreshPeriod := cfg.RefreshPeriod.ToDuration()
downloader := createDownloader(cfg, bootstrap)
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists,
refreshPeriod, downloader, cfg.ProcessingConcurrency,
@ -129,8 +130,10 @@ func NewBlockingResolver(
}
res := &BlockingResolver{
configurable: withConfig(&cfg),
typed: withType("blocking"),
blockHandler: blockHandler,
cfg: cfg,
blacklistMatcher: blacklistMatcher,
whitelistMatcher: whitelistMatcher,
whitelistOnlyGroups: whitelistOnlyGroups,
@ -140,10 +143,9 @@ func NewBlockingResolver(
},
clientGroupsBlock: cgb,
redisClient: redis,
redisEnabled: (redis != nil),
}
if res.redisEnabled {
if res.redisClient != nil {
setupRedisEnabledSubscriber(res)
}
@ -156,27 +158,25 @@ func NewBlockingResolver(
func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader {
return lists.NewDownloader(
lists.WithTimeout(time.Duration(cfg.DownloadTimeout)),
lists.WithTimeout(cfg.DownloadTimeout.ToDuration()),
lists.WithAttempts(cfg.DownloadAttempts),
lists.WithCooldown(time.Duration(cfg.DownloadCooldown)),
lists.WithCooldown(cfg.DownloadCooldown.ToDuration()),
lists.WithTransport(bootstrap.NewHTTPTransport()),
)
}
func setupRedisEnabledSubscriber(c *BlockingResolver) {
logger := log.PrefixedLog("blocking_resolver")
go func() {
for em := range c.redisClient.EnabledChannel {
if em != nil {
logger.Debug("Received state from redis: ", em)
c.log().Debug("Received state from redis: ", em)
if em.State {
c.internalEnableBlocking()
} else {
err := c.internalDisableBlocking(em.Duration, em.Groups)
if err != nil {
logger.Warn("Blocking couldn't be disabled:", err)
c.log().Warn("Blocking couldn't be disabled:", err)
}
}
}
@ -213,7 +213,7 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
func (r *BlockingResolver) EnableBlocking() {
r.internalEnableBlocking()
if r.redisEnabled {
if r.redisClient != nil {
r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true})
}
}
@ -232,7 +232,7 @@ func (r *BlockingResolver) internalEnableBlocking() {
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error {
err := r.internalDisableBlocking(duration, disableGroups)
if err == nil && r.redisEnabled {
if err == nil && r.redisClient != nil {
r.redisClient.PublishEnabled(&redis.EnabledMessage{
State: false,
Duration: duration,
@ -329,39 +329,15 @@ func (r *BlockingResolver) handleBlocked(logger *logrus.Entry,
return &model.Response{Res: response, RType: model.ResponseTypeBLOCKED, Reason: reason}, nil
}
// Configuration returns the current resolver configuration
func (r *BlockingResolver) Configuration() (result []string) {
if len(r.cfg.ClientGroupsBlock) == 0 {
return configDisabled
}
// LogConfig implements `config.Configurable`.
func (r *BlockingResolver) LogConfig(logger *logrus.Entry) {
r.cfg.LogConfig(logger)
result = append(result, "clientGroupsBlock")
for key, val := range r.cfg.ClientGroupsBlock {
result = append(result, fmt.Sprintf(" %s = \"%s\"", key, strings.Join(val, ";")))
}
logger.Info("blacklist cache entries:")
log.WithIndent(logger, " ", r.blacklistMatcher.LogConfig)
blockType := r.cfg.BlockType
result = append(result, fmt.Sprintf("blockType = \"%s\"", blockType))
if blockType != "NXDOMAIN" {
result = append(result, fmt.Sprintf("blockTTL = %s", r.cfg.BlockTTL.String()))
}
result = append(result, fmt.Sprintf("downloadTimeout = %s", r.cfg.DownloadTimeout.String()))
result = append(result, fmt.Sprintf("FailStartOnListError = %t", r.cfg.FailStartOnListError))
result = append(result, "blacklist:")
for _, c := range r.blacklistMatcher.Configuration() {
result = append(result, fmt.Sprintf(" %s", c))
}
result = append(result, "whitelist:")
for _, c := range r.whitelistMatcher.Configuration() {
result = append(result, fmt.Sprintf(" %s", c))
}
return result
logger.Info("whitelist cache entries:")
log.WithIndent(logger, " ", r.whitelistMatcher.LogConfig)
}
func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool {
@ -594,7 +570,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
}
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (result []net.IP, ttl time.Duration) {
prefixedLog := log.PrefixedLog("FQDNClientIdentifierCache")
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(&model.Request{


@ -7,6 +7,7 @@ import (
. "github.com/0xERR0R/blocky/evt"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/lists"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
@ -51,7 +52,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer *dns.Msg
)
systemResolverBootstrap := &Bootstrap{}
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sutConfig = config.BlockingConfig{
@ -71,6 +76,22 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Events", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
@ -1086,35 +1107,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{"gr1": {group1File.Path}},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{}
})
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
Describe("Create resolver with wrong parameter", func() {
When("Wrong blockType is used", func() {
It("should return error", func() {


@ -62,7 +62,11 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
return b, nil
}
parallelResolver, err := newParallelBestResolver(bootstraped.ResolverGroups())
// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
// where `ParallelBestResolver` uses its config, we can just use an empty one.
pbCfg := config.ParallelBestConfig{}
parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
if err != nil {
return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err)
}
@ -74,20 +78,16 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
cachingCfg := cfg.Caching
cachingCfg.EnablePrefetch()
if cachingCfg.MinCachingTime == 0 {
if cachingCfg.MinCachingTime.IsZero() {
// Set a min time in case the user didn't, to avoid prefetching too often
cachingCfg.MinCachingTime = config.Duration(time.Hour)
}
b.bootstraped = bootstraped
cachingResolver := NewCachingResolver(cachingCfg, nil)
// don't emit any metrics
cachingResolver.emitMetricEvents = false
b.resolver = Chain(
NewFilteringResolver(cfg.Filtering),
cachingResolver,
newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones
parallelResolver,
)
@ -116,10 +116,10 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
ctx := context.Background()
timeout := cfg.UpstreamTimeout
if timeout != 0 {
if timeout.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration())
defer cancel()
}


@ -282,7 +282,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
Expect(ips).Should(HaveLen(0))
Expect(ips).Should(BeEmpty())
})
})
@ -296,7 +296,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host"))
Expect(ips).Should(HaveLen(0))
Expect(ips).Should(BeEmpty())
})
})


@ -5,8 +5,6 @@ import (
"sync/atomic"
"time"
"github.com/hako/durafmt"
"github.com/0xERR0R/blocky/cache/expirationcache"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
@ -24,16 +22,15 @@ const defaultCachingCleanUpInterval = 5 * time.Second
// CachingResolver caches answers from dns queries with their TTL time,
// to avoid external resolver calls for recurrent queries
type CachingResolver struct {
configurable[*config.CachingConfig]
NextResolver
minCacheTimeSec, maxCacheTimeSec int
cacheTimeNegative time.Duration
resultCache expirationcache.ExpiringCache
prefetchExpires time.Duration
prefetchThreshold int
prefetchingNameCache expirationcache.ExpiringCache
redisClient *redis.Client
redisEnabled bool
emitMetricEvents bool
typed
emitMetricEvents bool // disabled by Bootstrap
resultCache expirationcache.ExpiringCache
prefetchingNameCache expirationcache.ExpiringCache
redisClient *redis.Client
}
// cacheValue includes query answer and prefetch flag
@ -44,18 +41,21 @@ type cacheValue struct {
// NewCachingResolver creates a new resolver instance
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
return newCachingResolver(cfg, redis, true)
}
func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver {
c := &CachingResolver{
minCacheTimeSec: int(time.Duration(cfg.MinCachingTime).Seconds()),
maxCacheTimeSec: int(time.Duration(cfg.MaxCachingTime).Seconds()),
cacheTimeNegative: time.Duration(cfg.CacheTimeNegative),
redisClient: redis,
redisEnabled: (redis != nil),
emitMetricEvents: true,
configurable: withConfig(&cfg),
typed: withType("caching"),
redisClient: redis,
emitMetricEvents: emitMetricEvents,
}
configureCaches(c, &cfg)
if c.redisEnabled {
if c.redisClient != nil {
setupRedisCacheSubscriber(c)
c.redisClient.GetRedisCache()
}
@ -68,10 +68,6 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
if cfg.Prefetching {
c.prefetchExpires = time.Duration(cfg.PrefetchExpires)
c.prefetchThreshold = cfg.PrefetchThreshold
c.prefetchingNameCache = expirationcache.NewCache(
expirationcache.WithCleanUpInterval(time.Minute),
expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount)),
@ -88,12 +84,10 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
}
func setupRedisCacheSubscriber(c *CachingResolver) {
logger := log.PrefixedLog("caching_resolver")
go func() {
for rc := range c.redisClient.CacheChannel {
if rc != nil {
logger.Debug("Received key from redis: ", rc.Key)
c.log().Debug("Received key from redis: ", rc.Key)
c.putInCache(rc.Key, rc.Response, false, false)
}
}
@ -102,22 +96,22 @@ func setupRedisCacheSubscriber(c *CachingResolver) {
// check if domain was queried > threshold in the time window
func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
if r.prefetchThreshold == 0 {
if r.cfg.PrefetchThreshold == 0 {
return true
}
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
return cnt != nil && cnt.(int) > r.prefetchThreshold
return cnt != nil && cnt.(int) > r.cfg.PrefetchThreshold
}
func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := log.PrefixedLog("caching_resolver")
if r.shouldPrefetch(cacheKey) {
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType.String())
logger := r.log()
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
response, err := r.next.Resolve(req)
@ -136,29 +130,11 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
return nil, 0
}
// Configuration returns a current resolver configuration
func (r *CachingResolver) Configuration() (result []string) {
if r.maxCacheTimeSec < 0 {
return configDisabled
}
// LogConfig implements `config.Configurable`.
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
r.cfg.LogConfig(logger)
result = append(result, fmt.Sprintf("minCacheTimeInSec = %d", r.minCacheTimeSec))
result = append(result, fmt.Sprintf("maxCacheTimeSec = %d", r.maxCacheTimeSec))
result = append(result, fmt.Sprintf("cacheTimeNegative = %s", durafmt.Parse(r.cacheTimeNegative)))
result = append(result, fmt.Sprintf("prefetching = %t", r.prefetchingNameCache != nil))
if r.prefetchingNameCache != nil {
result = append(result, fmt.Sprintf("prefetchExpires = %s", durafmt.Parse(r.prefetchExpires)))
result = append(result, fmt.Sprintf("prefetchThreshold = %d", r.prefetchThreshold))
}
result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.TotalCount()))
return
logger.Infof("cache entries = %d", r.resultCache.TotalCount())
}
// Resolve checks if the current query result is already in the cache and returns it
@ -166,7 +142,7 @@ func (r *CachingResolver) Configuration() (result []string) {
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")
if r.maxCacheTimeSec < 0 {
if r.cfg.MaxCachingTime < 0 {
logger.Debug("skip cache")
return r.next.Resolve(request)
@ -214,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
response, err = r.next.Resolve(request)
if err == nil {
r.putInCache(cacheKey, response, false, r.redisEnabled)
r.putInCache(cacheKey, response, false, true)
}
}
@ -228,7 +204,7 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log
domainCount = x.(int)
}
domainCount++
r.prefetchingNameCache.Put(cacheKey, domainCount, r.prefetchExpires)
r.prefetchingNameCache.Put(cacheKey, domainCount, r.cfg.PrefetchExpires.ToDuration())
totalCount := r.prefetchingNameCache.TotalCount()
logger.Debugf("domain '%s' was requested %d times, "+
@ -242,9 +218,9 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
// put value into cache
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer))
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cacheTimeNegative > 0 {
if r.cfg.CacheTimeNegative > 0 {
// put negative cache if result code is NXDOMAIN
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cacheTimeNegative)
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
}
}
@ -264,20 +240,20 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
var max uint32
if len(answer) == 0 {
return r.cacheTimeNegative
return r.cfg.CacheTimeNegative.ToDuration()
}
for _, a := range answer {
// if TTL < minTTL -> adjust the value, set minTTL
if r.minCacheTimeSec > 0 {
if atomic.LoadUint32(&a.Header().Ttl) < uint32(r.minCacheTimeSec) {
atomic.StoreUint32(&a.Header().Ttl, uint32(r.minCacheTimeSec))
if r.cfg.MinCachingTime > 0 {
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
}
}
if r.maxCacheTimeSec > 0 {
if atomic.LoadUint32(&a.Header().Ttl) > uint32(r.maxCacheTimeSec) {
atomic.StoreUint32(&a.Header().Ttl, uint32(r.maxCacheTimeSec))
if r.cfg.MaxCachingTime > 0 {
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
}
}


@ -7,6 +7,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/evt"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
@ -27,6 +28,12 @@ var _ = Describe("CachingResolver", func() {
mockAnswer *dns.Msg
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sutConfig = config.CachingConfig{}
if err := defaults.Set(&sutConfig); err != nil {
@ -42,6 +49,22 @@ var _ = Describe("CachingResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Caching responses", func() {
When("prefetching is enabled", func() {
BeforeEach(func() {
@ -103,6 +126,15 @@ var _ = Describe("CachingResolver", func() {
HaveTTL(BeNumerically("<=", 2))))
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
})
When("threshold is 0", func() {
BeforeEach(func() {
sutConfig.PrefetchThreshold = 0
})
It("should always prefetch", func() {
Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
})
})
})
When("min caching time is defined", func() {
BeforeEach(func() {
@ -364,6 +396,10 @@ var _ = Describe("CachingResolver", func() {
})
It("response should be cached", func() {
By("default config should enable negative caching", func() {
Expect(sutConfig.CacheTimeNegative).Should(BeNumerically(">", 0))
})
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Should(SatisfyAll(
@ -495,43 +531,6 @@ var _ = Describe("CachingResolver", func() {
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
BeforeEach(func() {
sutConfig = config.CachingConfig{}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {
BeforeEach(func() {
sutConfig = config.CachingConfig{
MaxCachingTime: config.Duration(time.Minute * -1),
}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
When("prefetching is enabled", func() {
BeforeEach(func() {
sutConfig = config.CachingConfig{
Prefetching: true,
}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
Expect(c).Should(ContainElement(ContainSubstring("prefetchThreshold")))
})
})
})
Describe("Redis is configured", func() {
var (
redisServer *miniredis.Miniredis


@ -1,7 +1,6 @@
package resolver
import (
"fmt"
"net"
"strings"
"time"
@ -18,11 +17,12 @@ import (
// ClientNamesResolver tries to determine client name by asking responsible DNS server via rDNS (reverse lookup)
type ClientNamesResolver struct {
configurable[*config.ClientLookupConfig]
NextResolver
typed
cache expirationcache.ExpiringCache
externalResolver Resolver
singleNameOrder []uint
clientIPMapping map[string][]net.IP
NextResolver
}
// NewClientNamesResolver creates new resolver instance
@ -38,38 +38,21 @@ func NewClientNamesResolver(
}
cr = &ClientNamesResolver{
configurable: withConfig(&cfg),
typed: withType("client_names"),
cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval(time.Hour)),
externalResolver: r,
singleNameOrder: cfg.SingleNameOrder,
clientIPMapping: cfg.ClientnameIPMapping,
}
return
}
// Configuration returns current resolver configuration
func (r *ClientNamesResolver) Configuration() (result []string) {
if r.externalResolver == nil && len(r.clientIPMapping) == 0 {
return append(configDisabled, "use only IP address")
}
// LogConfig implements `config.Configurable`.
func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
r.cfg.LogConfig(logger)
result = append(result, fmt.Sprintf("singleNameOrder = \"%v\"", r.singleNameOrder))
if r.externalResolver != nil {
result = append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver))
}
result = append(result, fmt.Sprintf("cache item count = %d", r.cache.TotalCount()))
if len(r.clientIPMapping) > 0 {
result = append(result, "client IP mapping:")
for k, v := range r.clientIPMapping {
result = append(result, fmt.Sprintf("%s -> %s", k, v))
}
}
return
logger.Infof("cache entries = %d", r.cache.TotalCount())
}
// Resolve tries to resolve the client name from the ip address
@ -89,13 +72,11 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
}
ip := request.ClientIP
if ip == nil {
return []string{}
}
c, _ := r.cache.Get(ip.String())
if c != nil {
if t, ok := c.([]string); ok {
return t
@ -103,6 +84,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
}
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
r.cache.Put(ip.String(), names, time.Hour)
return names
@ -127,7 +109,6 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
// try client mapping first
result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 {
return
}
@ -151,8 +132,8 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip)
// optional: if singleNameOrder is set, use only one name in the defined order
if len(r.singleNameOrder) > 0 {
for _, i := range r.singleNameOrder {
if len(r.cfg.SingleNameOrder) > 0 {
for _, i := range r.cfg.SingleNameOrder {
if i > 0 && int(i) <= len(clientNames) {
result = []string{clientNames[i-1]}
@ -169,7 +150,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
}
func (r *ClientNamesResolver) getNameFromIPMapping(ip net.IP, result []string) []string {
for name, ips := range r.clientIPMapping {
for name, ips := range r.cfg.ClientnameIPMapping {
for _, i := range ips {
if ip.String() == i.String() {
result = append(result, name)


@ -5,6 +5,7 @@ import (
"net"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
@ -22,6 +23,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
m *mockResolver
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
JustBeforeEach(func() {
res, err := NewClientNamesResolver(sutConfig, nil, false)
Expect(err).Should(Succeed())
@ -31,6 +38,22 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Resolve client name from request clientID", func() {
BeforeEach(func() {
sutConfig = config.ClientLookupConfig{}
@ -281,7 +304,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
})
When("Upstream produces error", func() {
BeforeEach(func() {
JustBeforeEach(func() {
sutConfig = config.ClientLookupConfig{}
clientMockResolver := &mockResolver{}
clientMockResolver.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
@ -348,32 +371,4 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
BeforeEach(func() {
sutConfig = config.ClientLookupConfig{
Upstream: config.Upstream{Net: config.NetProtocolTcpUdp, Host: "host"},
SingleNameOrder: []uint{1, 2},
ClientnameIPMapping: map[string][]net.IP{
"client8": {net.ParseIP("1.2.3.5")},
},
}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {
BeforeEach(func() {
sutConfig = config.ClientLookupConfig{}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
})
})


@ -1,7 +1,6 @@
package resolver
import (
"fmt"
"strings"
"github.com/0xERR0R/blocky/config"
@ -15,7 +14,10 @@ import (
// ConditionalUpstreamResolver delegates DNS question to other DNS resolver dependent on domain name in question
type ConditionalUpstreamResolver struct {
configurable[*config.ConditionalUpstreamConfig]
NextResolver
typed
mapping map[string]Resolver
}
@ -26,10 +28,13 @@ func NewConditionalUpstreamResolver(
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstream := range cfg.Mapping.Upstreams {
upstreams := make(map[string][]config.Upstream)
upstreams[upstreamDefaultCfgName] = upstream
pbCfg := config.ParallelBestConfig{
ExternalResolvers: config.ParallelBestMapping{
upstreamDefaultCfgName: upstream,
},
}
r, err := NewParallelBestResolver(upstreams, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
@ -37,20 +42,14 @@ func NewConditionalUpstreamResolver(
m[strings.ToLower(domain)] = r
}
return &ConditionalUpstreamResolver{mapping: m}, nil
}
r := ConditionalUpstreamResolver{
configurable: withConfig(&cfg),
typed: withType("conditional_upstream"),
// Configuration returns current configuration
func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
if len(r.mapping) == 0 {
return configDisabled
mapping: m,
}
for key, val := range r.mapping {
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
}
return
return &r, nil
}
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {


@ -3,6 +3,7 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -18,6 +19,12 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
m *mockResolver
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
@ -54,6 +61,22 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Resolve conditional DNS queries via defined DNS server", func() {
When("Query is exact equal defined condition in mapping", func() {
Context("first mapping entry", func() {
@ -149,22 +172,4 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
Expect(r).Should(BeNil())
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {
BeforeEach(func() {
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, nil, false)
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
})
})


@ -1,10 +1,8 @@
package resolver
import (
"fmt"
"net"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
@ -17,11 +15,12 @@ import (
// CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map
type CustomDNSResolver struct {
configurable[*config.CustomDNSConfig]
NextResolver
mapping map[string][]net.IP
reverseAddresses map[string][]string
ttl uint32
filterUnmappedTypes bool
typed
mapping map[string][]net.IP
reverseAddresses map[string][]string
}
// NewCustomDNSResolver creates new resolver instance
@ -38,27 +37,13 @@ func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver {
}
}
ttl := uint32(time.Duration(cfg.CustomTTL).Seconds())
return &CustomDNSResolver{
mapping: m,
reverseAddresses: reverse,
ttl: ttl,
filterUnmappedTypes: cfg.FilterUnmappedTypes,
}
}
configurable: withConfig(&cfg),
typed: withType("custom_dns"),
// Configuration returns current resolver configuration
func (r *CustomDNSResolver) Configuration() (result []string) {
if len(r.mapping) == 0 {
return configDisabled
mapping: m,
reverseAddresses: reverse,
}
for key, val := range r.mapping {
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
}
return
}
func isSupportedType(ip net.IP, question dns.Question) bool {
@ -75,7 +60,7 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
response.SetReply(request.Req)
for _, url := range urls {
h := util.CreateHeader(question, r.ttl)
h := util.CreateHeader(question, r.cfg.CustomTTL.SecondsU32())
ptr := new(dns.PTR)
ptr.Ptr = dns.Fqdn(url)
ptr.Hdr = h
@ -103,7 +88,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
if found {
for _, ip := range ips {
if isSupportedType(ip, question) {
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32())
response.Answer = append(response.Answer, rr)
}
}
@ -118,7 +103,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
}
// Mapping exists for this domain, but for another type
if !r.filterUnmappedTypes {
if !r.cfg.FilterUnmappedTypes {
// go to next resolver
break
}


@ -6,6 +6,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
@ -15,12 +16,18 @@ import (
var _ = Describe("CustomDNSResolver", func() {
var (
TTL = uint32(time.Now().Second())
sut ChainedResolver
m *mockResolver
cfg config.CustomDNSConfig
)
TTL := uint32(time.Now().Second())
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
cfg = config.CustomDNSConfig{
@ -45,6 +52,22 @@ var _ = Describe("CustomDNSResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Resolving custom name via CustomDNSResolver", func() {
When("Ip 4 mapping is defined for custom domain and", func() {
Context("filterUnmappedTypes is true", func() {
@ -257,23 +280,4 @@ var _ = Describe("CustomDNSResolver", func() {
})
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("resolver is disabled", func() {
BeforeEach(func() {
cfg = config.CustomDNSConfig{}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
})
})


@ -7,77 +7,47 @@ import (
)
type EdeResolver struct {
configurable[*config.EdeConfig]
NextResolver
config config.EdeConfig
typed
}
func NewEdeResolver(cfg config.EdeConfig) ChainedResolver {
return &EdeResolver{
config: cfg,
configurable: withConfig(&cfg),
typed: withType("extended_error_code"),
}
}
func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) {
if !r.cfg.Enable {
return r.next.Resolve(request)
}
resp, err := r.next.Resolve(request)
if err != nil {
return nil, err
}
if r.config.Enable {
addExtraReasoning(resp)
}
r.addExtraReasoning(resp)
return resp, nil
}
func (r *EdeResolver) Configuration() (result []string) {
if !r.config.Enable {
return configDisabled
func (r *EdeResolver) addExtraReasoning(res *model.Response) {
infocode := res.RType.ToExtendedErrorCode()
if infocode == dns.ExtendedErrorCodeOther {
// dns.ExtendedErrorCodeOther seems broken in some clients
return
}
return configEnabled
}
func addExtraReasoning(res *model.Response) {
// dns.ExtendedErrorCodeOther seems broken in some clients
infocode := convertToExtendedErrorCode(res.RType)
if infocode > 0 {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.Option = append(opt.Option, convertExtendedError(res, infocode))
res.Res.Extra = append(res.Res.Extra, opt)
}
}
func convertExtendedError(input *model.Response, infocode uint16) *dns.EDNS0_EDE {
return &dns.EDNS0_EDE{
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.Option = append(opt.Option, &dns.EDNS0_EDE{
InfoCode: infocode,
ExtraText: input.Reason,
}
}
func convertToExtendedErrorCode(input model.ResponseType) uint16 {
switch input {
case model.ResponseTypeRESOLVED:
return dns.ExtendedErrorCodeOther
case model.ResponseTypeCACHED:
return dns.ExtendedErrorCodeCachedError
case model.ResponseTypeCONDITIONAL:
return dns.ExtendedErrorCodeForgedAnswer
case model.ResponseTypeCUSTOMDNS:
return dns.ExtendedErrorCodeForgedAnswer
case model.ResponseTypeHOSTSFILE:
return dns.ExtendedErrorCodeForgedAnswer
case model.ResponseTypeNOTFQDN:
return dns.ExtendedErrorCodeBlocked
case model.ResponseTypeBLOCKED:
return dns.ExtendedErrorCodeBlocked
case model.ResponseTypeFILTERED:
return dns.ExtendedErrorCodeFiltered
case model.ResponseTypeSPECIAL:
return dns.ExtendedErrorCodeFiltered
default:
return dns.ExtendedErrorCodeOther
}
ExtraText: res.Reason,
})
res.Res.Extra = append(res.Res.Extra, opt)
}


@ -5,6 +5,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
@ -22,6 +23,12 @@ var _ = Describe("EdeResolver", func() {
mockAnswer *dns.Msg
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mockAnswer = new(dns.Msg)
})
@ -59,6 +66,12 @@ var _ = Describe("EdeResolver", func() {
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
})
When("ede is enabled", func() {
@ -106,26 +119,14 @@ var _ = Describe("EdeResolver", func() {
Expect(err).To(Equal(resolveErr))
})
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
BeforeEach(func() {
sutConfig = config.EdeConfig{Enable: true}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(c).Should(Equal(configEnabled))
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
When("resolver is disabled", func() {
BeforeEach(func() {
sutConfig = config.EdeConfig{Enable: false}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
})


@ -1,10 +1,6 @@
package resolver
import (
"fmt"
"sort"
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
@ -13,13 +9,21 @@ import (
// FilteringResolver filters DNS queries (for example can drop all AAAA queries)
// returns empty ANSWER with NOERROR
type FilteringResolver struct {
configurable[*config.FilteringConfig]
NextResolver
queryTypes config.QTypeSet
typed
}
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
return &FilteringResolver{
configurable: withConfig(&cfg),
typed: withType("filtering"),
}
}
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) {
qType := request.Req.Question[0].Qtype
if r.queryTypes.Contains(dns.Type(qType)) {
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
response := new(dns.Msg)
response.SetRcode(request.Req, dns.RcodeSuccess)
@ -28,27 +32,3 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
return r.next.Resolve(request)
}
func (r *FilteringResolver) Configuration() (result []string) {
if len(r.queryTypes) == 0 {
return configDisabled
}
qTypes := make([]string, 0, len(r.queryTypes))
for qType := range r.queryTypes {
qTypes = append(qTypes, qType.String())
}
sort.Strings(qTypes)
result = append(result, fmt.Sprintf("filtering query Types: '%v'", strings.Join(qTypes, ", ")))
return
}
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
return &FilteringResolver{
queryTypes: cfg.QueryTypes,
}
}


@ -3,7 +3,9 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -18,6 +20,12 @@ var _ = Describe("FilteringResolver", func() {
mockAnswer *dns.Msg
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mockAnswer = new(dns.Msg)
})
@ -29,6 +37,22 @@ var _ = Describe("FilteringResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
When("Filtering query types are defined", func() {
BeforeEach(func() {
sutConfig = config.FilteringConfig{
@ -59,10 +83,6 @@ var _ = Describe("FilteringResolver", func() {
// no call of next resolver
Expect(m.Calls).Should(BeZero())
})
It("Configure should output all query types", func() {
c := sut.Configuration()
Expect(c).Should(Equal([]string{"filtering query Types: 'AAAA, MX'"}))
})
})
When("No filtering query types are defined", func() {
@ -81,9 +101,5 @@ var _ = Describe("FilteringResolver", func() {
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Configure should output 'empty list'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
})


@ -10,18 +10,20 @@ import (
)
type FqdnOnlyResolver struct {
configurable[*config.FqdnOnlyConfig]
NextResolver
enabled bool
typed
}
func NewFqdnOnlyResolver(cfg config.Config) *FqdnOnlyResolver {
func NewFqdnOnlyResolver(cfg config.FqdnOnlyConfig) *FqdnOnlyResolver {
return &FqdnOnlyResolver{
enabled: cfg.FqdnOnly,
configurable: withConfig(&cfg),
typed: withType("fqdn_only"),
}
}
func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, error) {
if r.enabled {
if r.IsEnabled() {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
if !strings.Contains(domainFromQuestion, ".") {
response := new(dns.Msg)
@ -33,11 +35,3 @@ func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, err
return r.next.Resolve(request)
}
func (r *FqdnOnlyResolver) Configuration() (result []string) {
if !r.enabled {
return configDisabled
}
return configEnabled
}


@ -3,6 +3,7 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
@ -13,11 +14,17 @@ import (
var _ = Describe("FqdnOnlyResolver", func() {
var (
sut *FqdnOnlyResolver
sutConfig config.Config
sutConfig config.FqdnOnlyConfig
m *mockResolver
mockAnswer *dns.Msg
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mockAnswer = new(dns.Msg)
})
@ -29,11 +36,25 @@ var _ = Describe("FqdnOnlyResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
When("Fqdn only is enabled", func() {
BeforeEach(func() {
sutConfig = config.Config{
FqdnOnly: true,
}
sutConfig = config.FqdnOnlyConfig{Enable: true}
})
It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
@ -59,17 +80,27 @@ var _ = Describe("FqdnOnlyResolver", func() {
// no call of next resolver
Expect(m.Calls).Should(BeZero())
})
It("Configure should output enabled", func() {
c := sut.Configuration()
Expect(c).Should(Equal(configEnabled))
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
})
When("Fqdn only is disabled", func() {
BeforeEach(func() {
sutConfig = config.Config{
FqdnOnly: false,
}
sutConfig = config.FqdnOnlyConfig{Enable: false}
})
It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
@ -95,9 +126,11 @@ var _ = Describe("FqdnOnlyResolver", func() {
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
It("Configure should output disabled", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
})
})


@ -9,7 +9,6 @@ import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -17,23 +16,44 @@ import (
)
const (
hostsFileResolverLogger = "hosts_file_resolver"
// reduce initial capacity so we don't waste memory if there are fewer entries than before
memReleaseFactor = 2
)
type HostsFileResolver struct {
configurable[*config.HostsFileConfig]
NextResolver
HostsFilePath string
hosts splitHostsFileData
ttl uint32
refreshPeriod time.Duration
filterLoopback bool
typed
hosts splitHostsFileData
}
type HostsFileEntry = parsers.HostsFileEntry
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
r := HostsFileResolver{
configurable: withConfig(&cfg),
typed: withType("hosts_file"),
}
if err := r.parseHostsFile(context.Background()); err != nil {
r.log().Errorf("disabling hosts file resolving due to error: %s", err)
r.cfg.Filepath = "" // don't try parsing the file again
} else {
go r.periodicUpdate()
}
return &r
}
// LogConfig implements `config.Configurable`.
func (r *HostsFileResolver) LogConfig(logger *logrus.Entry) {
r.cfg.LogConfig(logger)
logger.Infof("cache entries = %d", r.hosts.len())
}
func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Response {
question := request.Req.Question[0]
if question.Qtype != dns.TypePTR {
@ -46,7 +66,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil
}
if r.filterLoopback && questionIP.IsLoopback() {
if r.cfg.FilterLoopback && questionIP.IsLoopback() {
// skip the search: we won't find anything
return nil
}
@ -64,13 +84,13 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
ptr := new(dns.PTR)
ptr.Ptr = dns.Fqdn(host)
ptr.Hdr = util.CreateHeader(question, r.ttl)
ptr.Hdr = util.CreateHeader(question, r.cfg.HostsTTL.SecondsU32())
response.Answer = append(response.Answer, ptr)
for _, alias := range hostData.Aliases {
ptrAlias := new(dns.PTR)
ptrAlias.Ptr = dns.Fqdn(alias)
ptrAlias.Hdr = util.CreateHeader(question, r.ttl)
ptrAlias.Hdr = ptr.Hdr
response.Answer = append(response.Answer, ptrAlias)
}
@ -82,9 +102,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
}
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, hostsFileResolverLogger)
if r.HostsFilePath == "" {
if r.cfg.Filepath == "" {
return r.next.Resolve(request)
}
@ -98,7 +116,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
response := r.resolve(request.Req, question, domain)
if response != nil {
logger.WithFields(logrus.Fields{
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": domain,
}).Debugf("returning hosts file entry")
@ -106,7 +124,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
}
logger.WithField("resolver", Name(r.next)).Trace("go to next resolver")
r.log().WithField("resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
}
@ -117,7 +135,7 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
return nil
}
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.HostsTTL.SecondsU32())
response := new(dns.Msg)
response.SetReply(req)
@ -126,47 +144,14 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
return response
}
func (r *HostsFileResolver) Configuration() (result []string) {
if r.HostsFilePath == "" || r.hosts.isEmpty() {
return configDisabled
}
result = append(result, fmt.Sprintf("file path: %s", r.HostsFilePath))
result = append(result, fmt.Sprintf("TTL: %d", r.ttl))
result = append(result, fmt.Sprintf("refresh period: %s", r.refreshPeriod.String()))
result = append(result, fmt.Sprintf("filter loopback addresses: %t", r.filterLoopback))
return
}
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
r := HostsFileResolver{
HostsFilePath: cfg.Filepath,
ttl: uint32(time.Duration(cfg.HostsTTL).Seconds()),
refreshPeriod: time.Duration(cfg.RefreshPeriod),
filterLoopback: cfg.FilterLoopback,
}
if err := r.parseHostsFile(context.Background()); err != nil {
logger := log.PrefixedLog(hostsFileResolverLogger)
logger.Warnf("hosts file resolving is disabled: %s", err)
r.HostsFilePath = "" // don't try parsing the file again
} else {
go r.periodicUpdate()
}
return &r
}
func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
const maxErrorsPerFile = 5
if r.HostsFilePath == "" {
if r.cfg.Filepath == "" {
return nil
}
f, err := os.Open(r.HostsFilePath)
f, err := os.Open(r.cfg.Filepath)
if err != nil {
return err
}
@ -176,7 +161,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile)
p.OnErr(func(err error) {
log.PrefixedLog(hostsFileResolverLogger).Warnf("error parsing %s: %s, trying to continue", r.HostsFilePath, err)
r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err)
})
err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
@ -187,7 +172,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
}
// Ignore loopback, if so configured
if r.filterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") {
if r.cfg.FilterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") {
return nil
}
@ -196,7 +181,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
return nil
})
if err != nil {
return fmt.Errorf("error parsing %s: %w", r.HostsFilePath, err) // err is parsers.ErrTooManyErrors
return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors
}
r.hosts = newHosts
@ -205,15 +190,14 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
}
func (r *HostsFileResolver) periodicUpdate() {
if r.refreshPeriod > 0 {
ticker := time.NewTicker(r.refreshPeriod)
if r.cfg.RefreshPeriod.ToDuration() > 0 {
ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration())
defer ticker.Stop()
for {
<-ticker.C
logger := log.PrefixedLog(hostsFileResolverLogger)
logger.WithField("file", r.HostsFilePath).Debug("refreshing hosts file")
r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file")
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background()))
}
@ -238,7 +222,11 @@ func newSplitHostsDataWithSameCapacity(other splitHostsFileData) splitHostsFileD
}
func (d splitHostsFileData) isEmpty() bool {
return d.v4.isEmpty() && d.v6.isEmpty()
return d.len() == 0
}
func (d splitHostsFileData) len() int {
return d.v4.len() + d.v6.len()
}
func (d splitHostsFileData) getIP(qType dns.Type, domain string) net.IP {
@ -277,8 +265,8 @@ func newHostsDataWithSameCapacity(other hostsFileData) hostsFileData {
}
}
func (d hostsFileData) isEmpty() bool {
return len(d.hosts) == 0 && len(d.aliases) == 0
func (d hostsFileData) len() int {
return len(d.hosts) + len(d.aliases)
}
func (d hostsFileData) getIP(hostname string) net.IP {

View File

@ -2,12 +2,11 @@ package resolver
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
@ -17,6 +16,8 @@ import (
var _ = Describe("HostsFileResolver", func() {
var (
TTL = uint32(time.Now().Second())
sut *HostsFileResolver
sutConfig config.HostsFileConfig
m *mockResolver
@ -24,7 +25,11 @@ var _ = Describe("HostsFileResolver", func() {
tmpFile *TmpFile
)
TTL := uint32(time.Now().Second())
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
tmpDir = NewTmpFolder("HostsFileResolver")
@ -49,16 +54,32 @@ var _ = Describe("HostsFileResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Using hosts file", func() {
When("Hosts file cannot be located", func() {
BeforeEach(func() {
sutConfig = config.HostsFileConfig{
Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()),
Filepath: "/this/file/does/not/exist",
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
}
})
It("should not parse any hosts", func() {
Expect(sut.HostsFilePath).Should(BeEmpty())
Expect(sut.cfg.Filepath).Should(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
@ -140,11 +161,11 @@ var _ = Describe("HostsFileResolver", func() {
It("should not be used", func() {
Expect(sut).ShouldNot(BeNil())
Expect(sut.HostsFilePath).Should(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(HaveLen(0))
Expect(sut.hosts.v6.hosts).Should(HaveLen(0))
Expect(sut.hosts.v4.aliases).Should(HaveLen(0))
Expect(sut.hosts.v6.aliases).Should(HaveLen(0))
Expect(sut.cfg.Filepath).Should(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
Expect(sut.hosts.v6.aliases).Should(BeEmpty())
})
})
@ -307,25 +328,6 @@ var _ = Describe("HostsFileResolver", func() {
})
})
Describe("Configuration output", func() {
When("hosts file is provided", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
When("hosts file is not provided", func() {
BeforeEach(func() {
sutConfig = config.HostsFileConfig{}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
Expect(c).Should(ContainElement(configStatusDisabled))
})
})
})
Describe("Delegating to next resolver", func() {
When("no hosts file is provided", func() {
It("should delegate to next resolver", func() {

View File

@ -1,7 +1,6 @@
package resolver
import (
"fmt"
"strings"
"time"
@ -15,8 +14,10 @@ import (
// MetricsResolver resolver that records metrics about requests/response
type MetricsResolver struct {
configurable[*config.MetricsConfig]
NextResolver
cfg config.PrometheusConfig
typed
totalQueries *prometheus.CounterVec
totalResponse *prometheus.CounterVec
totalErrors prometheus.Counter
@ -24,11 +25,11 @@ type MetricsResolver struct {
}
// Resolve resolves the passed request
func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
response, err := m.next.Resolve(request)
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
response, err := r.next.Resolve(request)
if m.cfg.Enable {
m.totalQueries.With(prometheus.Labels{
if r.cfg.Enable {
r.totalQueries.With(prometheus.Labels{
"client": strings.Join(request.ClientNames, ","),
"type": dns.TypeToString[request.Req.Question[0].Qtype],
}).Inc()
@ -40,12 +41,12 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
responseType = response.RType.String()
}
m.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs)
r.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs)
if err != nil {
m.totalErrors.Inc()
r.totalErrors.Inc()
} else {
m.totalResponse.With(prometheus.Labels{
r.totalResponse.With(prometheus.Labels{
"reason": response.Reason,
"response_code": dns.RcodeToString[response.Res.Rcode],
"response_type": response.RType.String(),
@ -56,38 +57,28 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
return response, err
}
// Configuration gets the config of this resolver in a string slice
func (m *MetricsResolver) Configuration() (result []string) {
if !m.cfg.Enable {
return configDisabled
// NewMetricsResolver creates a new instance of the MetricsResolver type
func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver {
m := MetricsResolver{
configurable: withConfig(&cfg),
typed: withType("metrics"),
durationHistogram: durationHistogram(),
totalQueries: totalQueriesMetric(),
totalResponse: totalResponseMetric(),
totalErrors: totalErrorMetric(),
}
result = append(result, "metrics:")
result = append(result, fmt.Sprintf(" Enable = %t", m.cfg.Enable))
result = append(result, fmt.Sprintf(" Path = %s", m.cfg.Path))
m.registerMetrics()
return
return &m
}
// NewMetricsResolver creates a new instance of the MetricsResolver type
func NewMetricsResolver(cfg config.PrometheusConfig) ChainedResolver {
durationHistogram := durationHistogram()
totalQueries := totalQueriesMetric()
totalResponse := totalResponseMetric()
totalErrors := totalErrorMetric()
metrics.RegisterMetric(durationHistogram)
metrics.RegisterMetric(totalQueries)
metrics.RegisterMetric(totalResponse)
metrics.RegisterMetric(totalErrors)
return &MetricsResolver{
cfg: cfg,
durationHistogram: durationHistogram,
totalQueries: totalQueries,
totalResponse: totalResponse,
totalErrors: totalErrors,
}
func (r *MetricsResolver) registerMetrics() {
metrics.RegisterMetric(r.durationHistogram)
metrics.RegisterMetric(r.totalQueries)
metrics.RegisterMetric(r.totalResponse)
metrics.RegisterMetric(r.totalErrors)
}
func totalQueriesMetric() *prometheus.CounterVec {

View File

@ -4,6 +4,7 @@ import (
"errors"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
@ -22,13 +23,35 @@ var _ = Describe("MetricResolver", func() {
m *mockResolver
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sut = NewMetricsResolver(config.PrometheusConfig{Enable: true}).(*MetricsResolver)
sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Recording prometheus metrics", func() {
Context("Recording request metrics", func() {
When("Request will be performed", func() {
@ -63,13 +86,4 @@ var _ = Describe("MetricResolver", func() {
})
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
})
})

View File

@ -11,6 +11,7 @@ import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
"github.com/0xERR0R/blocky/model"
@ -27,10 +28,21 @@ type mockResolver struct {
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
}
func (r *mockResolver) Configuration() []string {
// Type implements `Resolver`.
func (r *mockResolver) Type() string {
return "mock"
}
// IsEnabled implements `config.Configurable`.
func (r *mockResolver) IsEnabled() bool {
args := r.Called()
return args.Get(0).([]string)
return args.Get(0).(bool)
}
// LogConfig implements `config.Configurable`.
func (r *mockResolver) LogConfig(logger *logrus.Entry) {
r.Called()
}
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {

View File

@ -2,6 +2,7 @@ package resolver
import (
"github.com/0xERR0R/blocky/model"
"github.com/sirupsen/logrus"
)
var NoResponse = &model.Response{} //nolint:gochecknoglobals
@ -13,10 +14,20 @@ func NewNoOpResolver() Resolver {
return NoOpResolver{}
}
func (r NoOpResolver) Configuration() (result []string) {
return nil
// Type implements `Resolver`.
func (NoOpResolver) Type() string {
return "noop"
}
func (r NoOpResolver) Resolve(request *model.Request) (*model.Response, error) {
// IsEnabled implements `config.Configurable`.
func (NoOpResolver) IsEnabled() bool {
return true
}
// LogConfig implements `config.Configurable`.
func (NoOpResolver) LogConfig(*logrus.Entry) {
}
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) {
return NoResponse, nil
}

View File

@ -2,6 +2,7 @@ package resolver
import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -9,6 +10,12 @@ import (
var _ = Describe("NoOpResolver", func() {
var sut NoOpResolver
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sut = NewNoOpResolver().(NoOpResolver)
})
@ -21,10 +28,19 @@ var _ = Describe("NoOpResolver", func() {
})
})
Describe("Configuration output", func() {
It("returns nothing", func() {
c := sut.Configuration()
Expect(c).Should(BeNil())
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should not log anything", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).Should(BeEmpty())
})
})
})

View File

@ -20,13 +20,16 @@ import (
)
const (
upstreamDefaultCfgName = "default"
parallelResolverLogger = "parallel_best_resolver"
upstreamDefaultCfgName = config.UpstreamDefaultCfgName
parallelResolverType = "parallel_best"
resolverCount = 2
)
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
type ParallelBestResolver struct {
configurable[*config.ParallelBestConfig]
typed
resolversPerClient map[string][]*upstreamResolverStatus
}
@ -77,10 +80,11 @@ func testResolver(r *UpstreamResolver) error {
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(
upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (Resolver, error) {
logger := log.PrefixedLog(parallelResolverLogger)
cfg config.ParallelBestConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType)
upstreamResolvers := cfg.ExternalResolvers
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
for name, upstreamCfgs := range upstreamResolvers {
@ -114,10 +118,12 @@ func NewParallelBestResolver(
resolverGroups[name] = group
}
return newParallelBestResolver(resolverGroups)
return newParallelBestResolver(cfg, resolverGroups)
}
func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, error) {
func newParallelBestResolver(
cfg config.ParallelBestConfig, resolverGroups map[string][]Resolver,
) (*ParallelBestResolver, error) {
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
for groupName, resolvers := range resolverGroups {
@ -136,27 +142,21 @@ func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, er
}
r := ParallelBestResolver{
configurable: withConfig(&cfg),
typed: withType(parallelResolverType),
resolversPerClient: resolversPerClient,
}
return &r, nil
}
// Configuration returns current resolver configuration
func (r *ParallelBestResolver) Configuration() (result []string) {
result = append(result, "upstream resolvers:")
for name, res := range r.resolversPerClient {
result = append(result, fmt.Sprintf("- %s", name))
for _, r := range res {
result = append(result, fmt.Sprintf(" - %s", r.resolver))
}
}
return
func (r *ParallelBestResolver) Name() string {
return r.String()
}
func (r ParallelBestResolver) String() string {
result := make([]string, 0)
func (r *ParallelBestResolver) String() string {
result := make([]string, 0, len(r.resolversPerClient))
for name, res := range r.resolversPerClient {
tmp := make([]string, len(res))
@ -206,7 +206,7 @@ func (r *ParallelBestResolver) resolversForClient(request *model.Request) (resul
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverLogger)
logger := log.WithPrefix(request.Log, parallelResolverType)
resolvers := r.resolversForClient(request)

View File

@ -6,6 +6,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -19,13 +20,74 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
noVerifyUpstreams = false
)
systemResolverBootstrap := &Bootstrap{}
var (
sut *ParallelBestResolver
sutMapping config.ParallelBestMapping
sutVerify bool
err error
bootstrap *Bootstrap
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
},
config.Upstream{
Host: "127.0.0.2",
},
},
}
sutVerify = noVerifyUpstreams
bootstrap = systemResolverBootstrap
})
JustBeforeEach(func() {
sutConfig := config.ParallelBestConfig{ExternalResolvers: sutMapping}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
})
config.GetConfig().UpstreamTimeout = config.Duration(1000 * time.Millisecond)
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Name", func() {
It("should not be empty", func() {
Expect(sut.Name()).ShouldNot(BeEmpty())
})
})
When("default upstream resolvers are not defined", func() {
It("should fail on startup", func() {
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, nil, noVerifyUpstreams)
_, err := NewParallelBestResolver(config.ParallelBestConfig{
ExternalResolvers: config.ParallelBestMapping{},
}, nil, noVerifyUpstreams)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
})
@ -40,7 +102,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
defer mockUpstream.Close()
upstream := map[string][]config.Upstream{
upstream := config.ParallelBestMapping{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
@ -49,21 +111,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
},
}
_, err := NewParallelBestResolver(upstream, systemResolverBootstrap, verifyUpstreams)
_, err := NewParallelBestResolver(config.ParallelBestConfig{
ExternalResolvers: upstream,
}, systemResolverBootstrap, verifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
})
})
When("no upstream resolvers can be reached", func() {
var (
upstream map[string][]config.Upstream
b *Bootstrap
)
BeforeEach(func() {
b = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
upstream = map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {
config.Upstream{
Host: "wrong",
@ -75,22 +134,26 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}
})
It("should fail to start if strict checking is enabled", func() {
_, err := NewParallelBestResolver(upstream, b, verifyUpstreams)
Expect(err).Should(HaveOccurred())
When("strict checking is enabled", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
It("should start if strict checking is disabled", func() {
_, err := NewParallelBestResolver(upstream, b, noVerifyUpstreams)
Expect(err).Should(Not(HaveOccurred()))
When("strict checking is disabled", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
})
It("should start", func() {
Expect(err).Should(Not(HaveOccurred()))
})
})
})
Describe("Resolving result from fastest upstream resolver", func() {
var (
sut Resolver
err error
)
When("2 Upstream resolvers are defined", func() {
When("one resolver is fast and another is slow", func() {
BeforeEach(func() {
@ -107,10 +170,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
DeferCleanup(slowTestUpstream.Close)
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()},
}, nil, noVerifyUpstreams)
Expect(err).Should(Succeed())
}
})
It("Should use result from fastest one", func() {
request := newRequest("example.com.", A)
@ -136,9 +198,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
return response
})
DeferCleanup(slowTestUpstream.Close)
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {withErrorUpstream, slowTestUpstream.Start()},
}, systemResolverBootstrap, noVerifyUpstreams)
}
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
@ -158,9 +220,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
withError1 := config.Upstream{Host: "wrong"}
withError2 := config.Upstream{Host: "wrong"}
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {withError1, withError2},
}, systemResolverBootstrap, noVerifyUpstreams)
}
Expect(err).Should(Succeed())
})
It("Should return error", func() {
@ -194,14 +256,14 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
WithAnswerRR("example.com 123 IN A 123.124.122.126")
DeferCleanup(clientSpecificCIRDMockUpstream.Close)
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {defaultMockUpstream.Start()},
"laptop": {clientSpecificExactMockUpstream.Start()},
"client-*-m": {clientSpecificWildcardMockUpstream.Start()},
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
}, nil, noVerifyUpstreams)
}
})
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
@ -294,11 +356,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{
sutMapping = config.ParallelBestMapping{
upstreamDefaultCfgName: {
mockUpstream.Start(),
},
}, nil, noVerifyUpstreams)
}
})
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
@ -327,10 +389,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
tmp, _ := NewParallelBestResolver(map[string][]config.Upstream{
sut, _ := NewParallelBestResolver(config.ParallelBestConfig{ExternalResolvers: config.ParallelBestMapping{
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}, systemResolverBootstrap, noVerifyUpstreams)
sut := tmp.(*ParallelBestResolver)
}}, systemResolverBootstrap, noVerifyUpstreams)
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)
@ -391,26 +452,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b, verifyUpstreams)
r, err := NewParallelBestResolver(config.ParallelBestConfig{
ExternalResolvers: config.ParallelBestMapping{"test": {{Host: "example.com"}}},
}, b, verifyUpstreams)
Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
})
})
Describe("Configuration output", func() {
var sut Resolver
BeforeEach(func() {
config.GetConfig().StartVerifyUpstream = false
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{upstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
}}, nil, noVerifyUpstreams)
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
})

View File

@ -1,7 +1,6 @@
package resolver
import (
"fmt"
"time"
"github.com/0xERR0R/blocky/config"
@ -14,35 +13,32 @@ import (
)
const (
cleanUpRunPeriod = 12 * time.Hour
queryLoggingResolverPrefix = "query_logging_resolver"
logChanCap = 1000
defaultFlushPeriod = 30 * time.Second
cleanUpRunPeriod = 12 * time.Hour
queryLoggingResolverType = "query_logging"
logChanCap = 1000
defaultFlushPeriod = 30 * time.Second
)
// QueryLoggingResolver writes query information (question, answer, duration, ...)
type QueryLoggingResolver struct {
configurable[*config.QueryLogConfig]
NextResolver
target string
logRetentionDays uint64
logChan chan *querylog.LogEntry
writer querylog.Writer
logType config.QueryLogType
fields []config.QueryLogField
typed
logChan chan *querylog.LogEntry
writer querylog.Writer
}
// NewQueryLoggingResolver returns a new resolver instance
func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
logger := log.PrefixedLog(queryLoggingResolverPrefix)
logger := log.PrefixedLog(queryLoggingResolverType)
var writer querylog.Writer
logType := cfg.Type
err := retry.Do(
func() error {
var err error
switch logType {
switch cfg.Type {
case config.QueryLogTypeCsv:
writer, err = querylog.NewCSVWriter(cfg.Target, false, cfg.LogRetentionDays)
case config.QueryLogTypeCsvClient:
@ -61,7 +57,7 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
},
retry.Attempts(uint(cfg.CreationAttempts)),
retry.DelayType(retry.FixedDelay),
retry.Delay(time.Duration(cfg.CreationCooldown)),
retry.Delay(cfg.CreationCooldown.ToDuration()),
retry.OnRetry(func(n uint, err error) {
logger.Warnf(
"Error occurred on query writer creation, retry attempt %d/%d: %v", n+1, cfg.CreationAttempts, err,
@ -71,18 +67,17 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
logger.Error("can't create query log writer, using console as fallback: ", err)
writer = querylog.NewLoggerWriter()
logType = config.QueryLogTypeConsole
cfg.Type = config.QueryLogTypeConsole
}
logChan := make(chan *querylog.LogEntry, logChanCap)
resolver := QueryLoggingResolver{
target: cfg.Target,
logRetentionDays: cfg.LogRetentionDays,
logChan: logChan,
writer: writer,
logType: logType,
fields: resolveQueryLogFields(cfg),
configurable: withConfig(&cfg),
typed: withType(queryLoggingResolverType),
logChan: logChan,
writer: writer,
}
go resolver.writeLog()
@ -94,24 +89,6 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
return &resolver
}
func resolveQueryLogFields(cfg config.QueryLogConfig) []config.QueryLogField {
var fields []config.QueryLogField
if len(cfg.Fields) == 0 {
// no fields defined, use all fields as fallback
for _, v := range config.QueryLogFieldNames() {
qlt, err := config.ParseQueryLogField(v)
util.LogOnError("ignoring unknown query log field", err)
fields = append(fields, qlt)
}
} else {
fields = cfg.Fields
}
return fields
}
// periodically triggers cleanup of old log files
func (r *QueryLoggingResolver) periodicCleanUp() {
ticker := time.NewTicker(cleanUpRunPeriod)
@ -129,7 +106,7 @@ func (r *QueryLoggingResolver) doCleanUp() {
// Resolve logs the query, duration and the result
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, queryLoggingResolverPrefix)
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
start := time.Now()
@ -157,7 +134,7 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
ClientNames: []string{"none"},
}
for _, f := range r.fields {
for _, f := range r.cfg.Fields {
switch f {
case config.QueryLogFieldClientIP:
entry.ClientIP = request.ClientIP.String()
@ -196,18 +173,8 @@ func (r *QueryLoggingResolver) writeLog() {
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
if len(r.logChan) > halfCap {
log.PrefixedLog(queryLoggingResolverPrefix).WithField("channel_len",
r.log().WithField("channel_len",
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
}
}
}
// Configuration returns the current resolver configuration
func (r *QueryLoggingResolver) Configuration() (result []string) {
result = append(result, fmt.Sprintf("type: \"%s\"", r.logType))
result = append(result, fmt.Sprintf("target: \"%s\"", r.target))
result = append(result, fmt.Sprintf("logRetentionDays: %d", r.logRetentionDays))
result = append(result, fmt.Sprintf("fields: %s", r.fields))
return
}

View File

@ -10,6 +10,7 @@ import (
"time"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/querylog"
"github.com/0xERR0R/blocky/config"
@ -44,6 +45,12 @@ var _ = Describe("QueryLoggingResolver", func() {
mockAnswer *dns.Msg
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mockAnswer = new(dns.Msg)
tmpDir = NewTmpFolder("queryLoggingResolver")
@ -52,6 +59,10 @@ var _ = Describe("QueryLoggingResolver", func() {
})
JustBeforeEach(func() {
if len(sutConfig.Fields) == 0 {
sutConfig.SetDefaults() // not called when using a struct literal
}
sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver)
DeferCleanup(func() { close(sut.logChan) })
m = &mockResolver{}
@ -59,6 +70,22 @@ var _ = Describe("QueryLoggingResolver", func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Process request", func() {
When("Resolver has no configuration", func() {
BeforeEach(func() {
@ -276,24 +303,6 @@ var _ = Describe("QueryLoggingResolver", func() {
})
})
Describe("Configuration output", func() {
When("resolver is enabled", func() {
BeforeEach(func() {
sutConfig = config.QueryLogConfig{
Target: tmpDir.Path,
Type: config.QueryLogTypeCsvClient,
LogRetentionDays: 0,
CreationAttempts: 1,
CreationCooldown: config.Duration(time.Millisecond),
}
})
It("should return configuration", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", 1))
})
})
})
Describe("Clean up of query log directory", func() {
When("fallback logger is enabled, log retention is enabled", func() {
BeforeEach(func() {
@ -355,7 +364,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}
})
It("should use fallback", func() {
Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole))
Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole))
})
})
@ -369,7 +378,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}
})
It("should use fallback", func() {
Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole))
Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole))
})
})
})

View File

@ -1,11 +1,10 @@
package resolver
import (
"fmt"
"net"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -14,20 +13,6 @@ import (
"github.com/sirupsen/logrus"
)
// Resolver is not configured.
const (
configStatusEnabled string = "enabled"
configStatusDisabled string = "disabled"
)
var (
// note: this is not used by all resolvers: only those that don't print any other configuration
configEnabled = []string{configStatusEnabled} //nolint:gochecknoglobals
configDisabled = []string{configStatusDisabled} //nolint:gochecknoglobals
)
func newRequest(question string, rType dns.Type, logger ...*logrus.Entry) *model.Request {
var loggerEntry *logrus.Entry
if len(logger) == 1 {
@ -84,11 +69,15 @@ func newRequestWithClientID(question string, rType dns.Type, ip, requestClientID
// Resolver generic interface for all resolvers
type Resolver interface {
config.Configurable
// Type returns a short, user-friendly, name for the resolver.
//
// It should be the same for all instances of a specific Resolver type.
Type() string
// Resolve performs resolution of a DNS request
Resolve(req *model.Request) (*model.Response, error)
// Configuration returns current resolver configuration
Configuration() []string
}
// ChainedResolver represents a resolver, which can delegate result to the next one
@ -142,10 +131,73 @@ func Name(resolver Resolver) string {
return named.Name()
}
return defaultName(resolver)
return resolver.Type()
}
// defaultName returns a short user-friendly name of a resolver
func defaultName(resolver Resolver) string {
return strings.Split(fmt.Sprintf("%T", resolver), ".")[1]
// ForEach iterates over all resolvers in the chain.
//
// If resolver is not a chain, or is unlinked,
// the callback is called exactly once.
func ForEach(resolver Resolver, callback func(Resolver)) {
for resolver != nil {
callback(resolver)
if chained, ok := resolver.(ChainedResolver); ok {
resolver = chained.GetNext()
} else {
break
}
}
}
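
// Illustrative sketch, not part of this change: ForEach visits the head of the
// chain first, so it can be used to gather information about every resolver in
// a chain built with `Chain`. The helper below is hypothetical.
func chainTypes(head Resolver) []string {
	types := make([]string, 0)

	ForEach(head, func(r Resolver) {
		types = append(types, r.Type())
	})

	return types
}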
// LogResolverConfig logs the resolver's type and config.
func LogResolverConfig(res Resolver, logger *logrus.Entry) {
	// Use the type, not the full name, to avoid redundant information with the config
typeName := res.Type()
if !res.IsEnabled() {
logger.Debugf("-> %s: disabled", typeName)
return
}
logger.Infof("-> %s:", typeName)
log.WithIndent(logger, " ", res.LogConfig)
}
// Should be embedded in a Resolver to auto-implement `Resolver.Type`.
type typed struct {
typeName string
}
func withType(t string) typed {
return typed{typeName: t}
}
// Type implements `Resolver`.
func (t *typed) Type() string {
return t.typeName
}
func (t *typed) log() *logrus.Entry {
return log.PrefixedLog(t.Type())
}
// Should be embedded in a Resolver to auto-implement `config.Configurable`.
type configurable[T config.Configurable] struct {
cfg T
}
func withConfig[T config.Configurable](cfg T) configurable[T] {
return configurable[T]{cfg: cfg}
}
// IsEnabled implements `config.Configurable`.
func (c *configurable[T]) IsEnabled() bool {
return c.cfg.IsEnabled()
}
// LogConfig implements `config.Configurable`.
func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
c.cfg.LogConfig(logger)
}
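
// Sketch for illustration only: a minimal chained resolver that gets `Type`,
// `IsEnabled` and `LogConfig` for free by embedding `typed` and `configurable`.
// The type and its constructor are hypothetical; the config type reuses
// `config.FqdnOnlyConfig` from this change.
type exampleResolver struct {
	configurable[*config.FqdnOnlyConfig]
	NextResolver
	typed
}

func newExampleResolver(cfg config.FqdnOnlyConfig) *exampleResolver {
	return &exampleResolver{
		configurable: withConfig(&cfg),
		typed:        withType("example"),
	}
}

// Resolve just delegates to the next resolver in the chain.
func (r *exampleResolver) Resolve(request *model.Request) (*model.Response, error) {
	return r.next.Resolve(request)
}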

View File

@ -1,45 +1,127 @@
package resolver
import (
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var systemResolverBootstrap = &Bootstrap{}
var _ = Describe("Resolver", func() {
systemResolverBootstrap := &Bootstrap{}
Describe("Chains", func() {
var (
r1 ChainedResolver
r2 ChainedResolver
r3 ChainedResolver
r4 Resolver
)
Describe("Creating resolver chain", func() {
When("A chain of resolvers will be created", func() {
It("should be iterable by calling 'GetNext'", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, nil, false)
ch := Chain(br, cr)
c, ok := ch.(ChainedResolver)
Expect(ok).Should(BeTrue())
BeforeEach(func() {
r1 = &mockResolver{}
r2 = &mockResolver{}
r3 = &mockResolver{}
r4 = &NoOpResolver{}
})
next := c.GetNext()
Expect(next).ShouldNot(BeNil())
Describe("Chain", func() {
It("should create a chain iterable using `GetNext`", func() {
ch := Chain(r1, r2, r3, r4)
Expect(ch).ShouldNot(BeNil())
Expect(ch).Should(Equal(r1))
Expect(r1.GetNext()).Should(Equal(r2))
Expect(r2.GetNext()).Should(Equal(r3))
Expect(r3.GetNext()).Should(Equal(r4))
})
It("should not link a final ChainedResolver", func() {
ch := Chain(r1, r2)
Expect(ch).ShouldNot(BeNil())
Expect(r1.GetNext()).Should(Equal(r2))
Expect(r2.GetNext()).Should(BeNil())
})
})
Describe("ForEach", func() {
It("should iterate on all resolvers in the chain", func() {
ch := Chain(r1, r2, r3, r4)
Expect(ch).ShouldNot(BeNil())
var itResult []Resolver
ForEach(ch, func(r Resolver) {
itResult = append(itResult, r)
})
Expect(itResult).ShouldNot(BeEmpty())
Expect(itResult).Should(Equal([]Resolver{r1, r2, r3, r4}))
})
})
Describe("LogResolverConfig", func() {
It("should call the resolver's `LogConfig`", func() {
logger := logrus.NewEntry(log.Log())
m := &mockResolver{}
m.On("IsEnabled").Return(true)
m.On("LogConfig")
LogResolverConfig(m, logger)
m.AssertExpectations(GinkgoT())
})
When("the resolver is disabled", func() {
It("should not call the resolver's `LogConfig`", func() {
logger := logrus.NewEntry(log.Log())
m := &mockResolver{}
m.On("IsEnabled").Return(false)
LogResolverConfig(m, logger)
m.AssertExpectations(GinkgoT())
})
})
})
})
Describe("Name", func() {
When("'Name' is called", func() {
It("should return resolver name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
name := Name(br)
Expect(name).Should(Equal("BlockingResolver"))
Expect(name).Should(Equal("blocking"))
})
})
When("'Name' is called on a NamedResolver", func() {
It("should return it's custom name", func() {
It("should return its custom name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}}
cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}}
r := NewRewriterResolver(cfg, br)
name := Name(r)
Expect(name).Should(Equal("BlockingResolver w/ RewriterResolver"))
Expect(name).Should(Equal("blocking w/ rewrite"))
})
})
})
})
func expectValidResolverType(sut Resolver) {
By("it must not contain spaces", func() {
Expect(sut.Type()).ShouldNot(ContainSubstring(" "))
})
By("it must be lower case", func() {
Expect(sut.Type()).Should(Equal(strings.ToLower(sut.Type())))
})
By("it must not contain 'resolver'", func() {
Expect(sut.Type()).ShouldNot(ContainSubstring("resolver"))
})
}

View File

@ -18,13 +18,14 @@ import (
// The branch is where the rewrite is active. If the branch doesn't
// yield a result, the normal resolving is continued.
type RewriterResolver struct {
configurable[*config.RewriterConfig]
NextResolver
rewrite map[string]string
inner Resolver
fallbackUpstream bool
typed
inner Resolver
}
func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) ChainedResolver {
func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) ChainedResolver {
if len(cfg.Rewrite) == 0 {
return inner
}
@ -36,27 +37,22 @@ func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) Chaine
inner.Next(NewNoOpResolver())
return &RewriterResolver{
rewrite: cfg.Rewrite,
inner: inner,
fallbackUpstream: cfg.FallbackUpstream,
configurable: withConfig(&cfg),
typed: withType("rewrite"),
inner: inner,
}
}
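
// Sketch, not part of this change: a rewriter branch that maps the made-up
// suffix "home.lan" to "fritz.box" and falls back to the regular chain when
// the branch yields no answer. `inner` is any ChainedResolver, e.g. a
// conditional upstream resolver.
func newExampleRewriter(inner ChainedResolver) ChainedResolver {
	cfg := config.RewriterConfig{
		Rewrite:          map[string]string{"home.lan": "fritz.box"},
		FallbackUpstream: true,
	}

	return NewRewriterResolver(cfg, inner)
}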
func (r *RewriterResolver) Name() string {
return fmt.Sprintf("%s w/ %s", Name(r.inner), defaultName(r))
return fmt.Sprintf("%s w/ %s", Name(r.inner), r.Type())
}
// Configuration returns current resolver configuration
func (r *RewriterResolver) Configuration() (result []string) {
result = append(result, "rewrite:")
for key, val := range r.rewrite {
result = append(result, fmt.Sprintf(" %s = \"%s\"", key, val))
}
// LogConfig implements `config.Configurable`.
func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
LogResolverConfig(r.inner, logger)
innerCfg := r.inner.Configuration()
result = append(result, innerCfg...)
return result
r.cfg.LogConfig(logger)
}
// Resolve uses the inner resolver to resolve the rewritten query
@ -79,7 +75,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
request.Req = original
fallbackCondition := err != nil || (response != NoResponse && response.Res.Answer == nil)
if r.fallbackUpstream && fallbackCondition {
if r.cfg.FallbackUpstream && fallbackCondition {
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
@ -130,7 +126,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg
logger.WithFields(logrus.Fields{
"domain": domainOriginal,
"rewrite": rewriteKey + ":" + r.rewrite[rewriteKey],
"rewrite": rewriteKey + ":" + r.cfg.Rewrite[rewriteKey],
}).Debugf("rewriting %q to %q", domainOriginal, domainRewritten)
}
}
@ -139,7 +135,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg
}
func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
for k, v := range r.rewrite {
for k, v := range r.cfg.Rewrite {
if strings.HasSuffix(domain, "."+k) {
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v

View File

@ -2,8 +2,10 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
@ -19,7 +21,7 @@ const (
var _ = Describe("RewriterResolver", func() {
var (
sut ChainedResolver
sutConfig config.RewriteConfig
sutConfig config.RewriterConfig
mInner *mockResolver
mNext *mockResolver
@ -29,11 +31,17 @@ var _ = Describe("RewriterResolver", func() {
mNextResponse *model.Response
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mInner = &mockResolver{}
mNext = &mockResolver{}
sutConfig = config.RewriteConfig{Rewrite: map[string]string{"original": "rewritten"}}
sutConfig = config.RewriterConfig{Rewrite: map[string]string{"original": "rewritten"}}
})
JustBeforeEach(func() {
@ -48,7 +56,7 @@ var _ = Describe("RewriterResolver", func() {
When("has no configuration", func() {
BeforeEach(func() {
sutConfig = config.RewriteConfig{}
sutConfig = config.RewriterConfig{}
})
It("should return the inner resolver", func() {
@ -194,11 +202,10 @@ var _ = Describe("RewriterResolver", func() {
Describe("Configuration output", func() {
When("resolver is enabled", func() {
It("should return configuration", func() {
innerOutput := []string{"inner:", "config-output"}
mInner.On("Configuration").Return(innerOutput)
mInner.On("LogConfig")
mInner.On("IsEnabled").Return(true)
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">", len(innerOutput)))
sut.LogConfig(logrus.NewEntry(log.Log()))
})
})
})

View File

@ -7,6 +7,7 @@ import (
"github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
const (
@ -46,11 +47,15 @@ type defaultIPs struct {
type SpecialUseDomainNamesResolver struct {
NextResolver
typed
defaults *defaultIPs
}
func NewSpecialUseDomainNamesResolver() ChainedResolver {
return &SpecialUseDomainNamesResolver{
typed: withType("special_use_domains"),
defaults: &defaultIPs{
loopbackV4: net.ParseIP("127.0.0.1"),
loopbackV6: net.IPv6loopback,
@ -58,6 +63,17 @@ func NewSpecialUseDomainNamesResolver() ChainedResolver {
}
}
// IsEnabled implements `config.Configurable`.
func (r *SpecialUseDomainNamesResolver) IsEnabled() bool {
// RFC 6761 & 6762 are always active
return true
}
// LogConfig implements `config.Configurable`.
func (r *SpecialUseDomainNamesResolver) LogConfig(logger *logrus.Entry) {
logger.Info("enabled")
}
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
// RFC 6761 - negative
if r.isSpecial(request, sudnArpaSlice()...) ||
@ -78,11 +94,6 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
return r.next.Resolve(request)
}
// RFC 6761 & 6762 are always active
func (r *SpecialUseDomainNamesResolver) Configuration() []string {
return configEnabled
}
func (r *SpecialUseDomainNamesResolver) isSpecial(request *model.Request, names ...string) bool {
domainFromQuestion := request.Req.Question[0].Name
for _, n := range names {

View File

@ -2,6 +2,7 @@ package resolver
import (
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -16,6 +17,12 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
m *mockResolver
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
mockAnswer, err := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
Expect(err).Should(Succeed())
@ -27,6 +34,22 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
sut.Next(m)
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should not log anything", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Blocking special names", func() {
It("should block arpa", func() {
for _, arpa := range sudnArpaSlice() {
@ -133,12 +156,4 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
))
})
})
Describe("Configuration pseudo test", func() {
It("should always be empty", func() {
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">=", 1))
})
})
})

View File

@ -30,6 +30,8 @@ const (
// UpstreamResolver sends request to external DNS server
type UpstreamResolver struct {
typed
upstream config.Upstream
upstreamClient upstreamClient
bootstrap *Bootstrap
@ -51,7 +53,7 @@ type httpUpstreamClient struct {
}
func createUpstreamClient(cfg config.Upstream) upstreamClient {
timeout := time.Duration(config.GetConfig().UpstreamTimeout)
timeout := config.GetConfig().UpstreamTimeout.ToDuration()
tlsConfig := tls.Config{
ServerName: cfg.Host,
@ -209,25 +211,30 @@ func newUpstreamResolverUnchecked(upstream config.Upstream, bootstrap *Bootstrap
upstreamClient := createUpstreamClient(upstream)
return &UpstreamResolver{
typed: withType("upstream"),
upstream: upstream,
upstreamClient: upstreamClient,
bootstrap: bootstrap,
}
}
// Configuration return current resolver configuration
func (r *UpstreamResolver) Configuration() (result []string) {
return []string{r.String()}
// IsEnabled implements `config.Configurable`.
func (r *UpstreamResolver) IsEnabled() bool {
return true
}
// LogConfig implements `config.Configurable`.
func (r *UpstreamResolver) LogConfig(logger *logrus.Entry) {
logger.Info(r.upstream)
}
func (r UpstreamResolver) String() string {
return fmt.Sprintf("upstream '%s'", r.upstream.String())
return fmt.Sprintf("%s '%s'", r.Type(), r.upstream)
}
// Resolve calls external resolver
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "upstream_resolver")
ips, err := r.bootstrap.UpstreamIPs(r)
if err != nil {
return nil, err
@ -247,7 +254,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
var err error
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol)
if err == nil {
logger.WithFields(logrus.Fields{
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
@ -272,7 +279,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
return errors.As(err, &netErr) && netErr.Timeout()
}),
retry.OnRetry(func(n uint, err error) {
logger.WithFields(logrus.Fields{
r.log().WithFields(logrus.Fields{
"upstream": r.upstream.String(),
"upstream_ip": ip.String(),
"question": util.QuestionToString(request.Req.Question),

View File

@ -9,6 +9,7 @@ import (
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -17,7 +18,40 @@ import (
)
var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
systemResolverBootstrap := &Bootstrap{}
var (
sut *UpstreamResolver
sutConfig config.Upstream
)
BeforeEach(func() {
sutConfig = config.Upstream{Host: "localhost"}
})
JustBeforeEach(func() {
sut = newUpstreamResolverUnchecked(sutConfig, systemResolverBootstrap)
})
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
Describe("LogConfig", func() {
It("should log something", func() {
logger, hook := log.NewMockEntry()
sut.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
})
})
Describe("Using DNS upstream", func() {
When("Configured DNS resolver can resolve query", func() {
@ -35,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
)
})
})
@ -53,7 +87,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
)
})
})
@ -216,15 +250,4 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
})
})
Describe("Configuration", func() {
When("Configuration is called", func() {
It("should return configuration", func() {
sut := newUpstreamResolverUnchecked(config.Upstream{}, nil)
c := sut.Configuration()
Expect(len(c)).Should(BeNumerically(">=", 1))
})
})
})
})

View File

@ -395,7 +395,7 @@ func createQueryResolver(
redisClient *redis.Client,
) (r resolver.Resolver, err error) {
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap, cfg.StartVerifyUpstream)
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream)
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
@ -411,16 +411,16 @@ func createQueryResolver(
r = resolver.Chain(
resolver.NewFilteringResolver(cfg.Filtering),
resolver.NewFqdnOnlyResolver(*cfg),
resolver.NewFqdnOnlyResolver(cfg.FqdnOnly),
clientNames,
resolver.NewEdeResolver(cfg.Ede),
resolver.NewQueryLoggingResolver(cfg.QueryLog),
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
resolver.NewHostsFileResolver(cfg.HostsFile),
blocking,
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, condUpstream),
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
resolver.NewSpecialUseDomainNamesResolver(),
parallel,
)
@ -439,25 +439,12 @@ func (s *Server) registerDNSHandlers() {
func (s *Server) printConfiguration() {
logger().Info("current configuration:")
res := s.queryResolver
for res != nil {
logger().Infof("-> resolver: '%s'", resolver.Name(res))
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
resolver.LogResolverConfig(res, logger())
})
for _, c := range res.Configuration() {
logger().Infof(" %s", c)
}
if c, ok := res.(resolver.ChainedResolver); ok {
res = c.GetNext()
} else {
break
}
}
logger().Infof("- DNS listening on addrs/ports: %v", s.cfg.Ports.DNS)
logger().Infof("- TLS listening on addrs/ports: %v", s.cfg.Ports.TLS)
logger().Infof("- HTTP listening on addrs/ports: %v", s.cfg.Ports.HTTP)
logger().Infof("- HTTPS listening on addrs/ports: %v", s.cfg.Ports.HTTPS)
logger().Info("listeners:")
log.WithIndent(logger(), " ", s.cfg.Ports.LogConfig)
logger().Info("runtime information:")
@ -465,17 +452,19 @@ func (s *Server) printConfiguration() {
runtime.GC()
debug.FreeOSMemory()
logger().Infof(" numCPU = %d", runtime.NumCPU())
logger().Infof(" numGoroutine = %d", runtime.NumGoroutine())
// gather memory stats
var m runtime.MemStats
runtime.ReadMemStats(&m)
logger().Infof("MEM Alloc = %10v MB", toMB(m.Alloc))
logger().Infof("MEM HeapAlloc = %10v MB", toMB(m.HeapAlloc))
logger().Infof("MEM Sys = %10v MB", toMB(m.Sys))
logger().Infof("MEM NumGC = %10v", m.NumGC)
logger().Infof("RUN NumCPU = %10d", runtime.NumCPU())
logger().Infof("RUN NumGoroutine = %10d", runtime.NumGoroutine())
logger().Infof(" memory:")
logger().Infof(" alloc = %10v MB", toMB(m.Alloc))
logger().Infof(" heapAlloc = %10v MB", toMB(m.HeapAlloc))
logger().Infof(" sys = %10v MB", toMB(m.Sys))
logger().Infof(" numGC = %10v", m.NumGC)
}
func toMB(b uint64) uint64 {

View File

@ -105,7 +105,7 @@ var _ = BeforeSuite(func() {
// create server
sut, err = NewServer(&config.Config{
CustomDNS: config.CustomDNSConfig{
CustomTTL: config.Duration(time.Duration(3600) * time.Second),
CustomTTL: config.Duration(3600 * time.Second),
Mapping: config.CustomDNSMapping{
HostIPs: map[string][]net.IP{
"custom.lan": {net.ParseIP("192.168.178.55")},
@ -143,7 +143,7 @@ var _ = BeforeSuite(func() {
BlockType: "zeroIp",
BlockTTL: config.Duration(6 * time.Hour),
},
Upstream: config.UpstreamConfig{
Upstream: config.ParallelBestConfig{
ExternalResolvers: map[string][]config.Upstream{"default": {upstreamGoogle}},
},
ClientLookup: config.ClientLookupConfig{
@ -158,7 +158,7 @@ var _ = BeforeSuite(func() {
},
CertFile: certPem.Path,
KeyFile: keyPem.Path,
Prometheus: config.PrometheusConfig{
Prometheus: config.MetricsConfig{
Enable: true,
Path: "/metrics",
},
@ -643,7 +643,7 @@ var _ = Describe("Running DNS server", func() {
It("start was called 2 times, start should fail", func() {
// create server
server, err := NewServer(&config.Config{
Upstream: config.UpstreamConfig{
Upstream: config.ParallelBestConfig{
ExternalResolvers: map[string][]config.Upstream{
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
},
@ -685,7 +685,7 @@ var _ = Describe("Running DNS server", func() {
It("stop was called 2 times, start should fail", func() {
// create server
server, err := NewServer(&config.Config{
Upstream: config.UpstreamConfig{
Upstream: config.ParallelBestConfig{
ExternalResolvers: map[string][]config.Upstream{
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
},

View File

@ -109,7 +109,7 @@ func NewMsgWithQuestion(question string, qType dns.Type) *dns.Msg {
// NewMsgWithAnswer creates new DNS message with answer
func NewMsgWithAnswer(domain string, ttl uint, dnsType dns.Type, address string) (*dns.Msg, error) {
rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType.String(), address))
rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType, address))
if err != nil {
return nil, err
}