diff --git a/config/blocking.go b/config/blocking.go new file mode 100644 index 00000000..be13e435 --- /dev/null +++ b/config/blocking.go @@ -0,0 +1,81 @@ +package config + +import ( + "strings" + + "github.com/0xERR0R/blocky/log" + "github.com/sirupsen/logrus" +) + +// BlockingConfig configuration for query blocking +type BlockingConfig struct { + BlackLists map[string][]string `yaml:"blackLists"` + WhiteLists map[string][]string `yaml:"whiteLists"` + ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"` + BlockType string `yaml:"blockType" default:"ZEROIP"` + BlockTTL Duration `yaml:"blockTTL" default:"6h"` + DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"` + DownloadAttempts uint `yaml:"downloadAttempts" default:"3"` + DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"` + RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"` + // Deprecated + FailStartOnListError bool `yaml:"failStartOnListError" default:"false"` + ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"` + StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"` +} + +// IsEnabled implements `config.Configurable`. +func (c *BlockingConfig) IsEnabled() bool { + return len(c.ClientGroupsBlock) != 0 +} + +// IsEnabled implements `config.Configurable`. +func (c *BlockingConfig) LogConfig(logger *logrus.Entry) { + logger.Info("clientGroupsBlock:") + + for key, val := range c.ClientGroupsBlock { + logger.Infof(" %s = %v", key, val) + } + + logger.Infof("blockType = %s", c.BlockType) + + if c.BlockType != "NXDOMAIN" { + logger.Infof("blockTTL = %s", c.BlockTTL) + } + + logger.Infof("downloadTimeout = %s", c.DownloadTimeout) + + logger.Infof("failStartOnListError = %t", c.FailStartOnListError) + + if c.RefreshPeriod > 0 { + logger.Infof("refresh = every %s", c.RefreshPeriod) + } else { + logger.Debug("refresh = disabled") + } + + logger.Info("blacklist:") + log.WithIndent(logger, " ", func(logger *logrus.Entry) { + c.logListGroups(logger, c.BlackLists) + }) + + logger.Info("whitelist:") + log.WithIndent(logger, " ", func(logger *logrus.Entry) { + c.logListGroups(logger, c.WhiteLists) + }) +} + +func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) { + for group, links := range listGroups { + logger.Infof("%s:", group) + + for _, link := range links { + if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char + link = link[:idx] // first line only + + logger.Infof(" - %s [...]", link) + } else { + logger.Infof(" - %s", link) + } + } + } +} diff --git a/config/blocking_test.go b/config/blocking_test.go new file mode 100644 index 00000000..b806b028 --- /dev/null +++ b/config/blocking_test.go @@ -0,0 +1,86 @@ +package config + +import ( + "time" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/sirupsen/logrus" +) + +var _ = Describe("BlockingConfig", func() { + var ( + cfg BlockingConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = BlockingConfig{ + BlockType: "ZEROIP", + BlockTTL: Duration(time.Minute), + BlackLists: map[string][]string{ + "gr1": {"/a/file/path"}, + }, + ClientGroupsBlock: map[string][]string{ + "default": {"gr1"}, + }, + RefreshPeriod: Duration(time.Hour), + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := BlockingConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := BlockingConfig{ + BlockTTL: Duration(-1), + } + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:")) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour"))) + }) + When("refresh is disabled", func() { + It("should reflect that", func() { + cfg.RefreshPeriod = Duration(-1) + + logger.Logger.Level = logrus.InfoLevel + + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled"))) + + logger.Logger.Level = logrus.TraceLevel + + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled"))) + }) + }) + }) +}) diff --git a/config/caching.go b/config/caching.go new file mode 100644 index 00000000..0393d34f --- /dev/null +++ b/config/caching.go @@ -0,0 +1,52 @@ +package config + +import ( + "time" + + "github.com/sirupsen/logrus" +) + +// CachingConfig configuration for domain caching +type CachingConfig struct { + MinCachingTime Duration `yaml:"minTime"` + MaxCachingTime Duration `yaml:"maxTime"` + CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"` + MaxItemsCount int `yaml:"maxItemsCount"` + Prefetching bool `yaml:"prefetching"` + PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"` + PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"` + PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"` +} + +// IsEnabled implements `config.Configurable`. +func (c *CachingConfig) IsEnabled() bool { + return c.MaxCachingTime > 0 +} + +// LogConfig implements `config.Configurable`. +func (c *CachingConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("minTime = %s", c.MinCachingTime) + logger.Infof("maxTime = %s", c.MaxCachingTime) + logger.Infof("cacheTimeNegative = %s", c.CacheTimeNegative) + + if c.Prefetching { + logger.Infof("prefetching:") + logger.Infof(" expires = %s", c.PrefetchExpires) + logger.Infof(" threshold = %d", c.PrefetchThreshold) + logger.Infof(" maxItems = %d", c.PrefetchMaxItemsCount) + } else { + logger.Debug("prefetching: disabled") + } +} + +func (c *CachingConfig) EnablePrefetch() { + const day = Duration(24 * time.Hour) + + if c.MaxCachingTime.IsZero() { + // make sure resolver gets enabled + c.MaxCachingTime = day + } + + c.Prefetching = true + c.PrefetchThreshold = 0 +} diff --git a/config/caching_test.go b/config/caching_test.go new file mode 100644 index 00000000..a295d0c9 --- /dev/null +++ b/config/caching_test.go @@ -0,0 +1,65 @@ +package config + +import ( + "time" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("CachingConfig", func() { + var ( + cfg CachingConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = CachingConfig{ + MaxCachingTime: Duration(time.Hour), + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := CachingConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("the config is enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("the config is disabled", func() { + It("should be false", func() { + cfg := CachingConfig{ + MaxCachingTime: Duration(-1), + } + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + When("prefetching is enabled", func() { + BeforeEach(func() { + cfg = CachingConfig{ + Prefetching: true, + } + }) + + It("should return configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("prefetching:"))) + }) + }) + }) +}) diff --git a/config/client_lookup.go b/config/client_lookup.go new file mode 100644 index 00000000..00da5046 --- /dev/null +++ b/config/client_lookup.go @@ -0,0 +1,36 @@ +package config + +import ( + "net" + + "github.com/sirupsen/logrus" +) + +// ClientLookupConfig configuration for the client lookup +type ClientLookupConfig struct { + ClientnameIPMapping map[string][]net.IP `yaml:"clients"` + Upstream Upstream `yaml:"upstream"` + SingleNameOrder []uint `yaml:"singleNameOrder"` +} + +// IsEnabled implements `config.Configurable`. +func (c *ClientLookupConfig) IsEnabled() bool { + return !c.Upstream.IsDefault() || len(c.ClientnameIPMapping) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *ClientLookupConfig) LogConfig(logger *logrus.Entry) { + if !c.Upstream.IsDefault() { + logger.Infof("upstream = %s", c.Upstream) + } + + logger.Infof("singleNameOrder = %v", c.SingleNameOrder) + + if len(c.ClientnameIPMapping) > 0 { + logger.Infof("client IP mapping:") + + for k, v := range c.ClientnameIPMapping { + logger.Infof(" %s = %s", k, v) + } + } +} diff --git a/config/client_lookup_test.go b/config/client_lookup_test.go new file mode 100644 index 00000000..bdf169b7 --- /dev/null +++ b/config/client_lookup_test.go @@ -0,0 +1,68 @@ +package config + +import ( + "net" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ClientLookupConfig", func() { + var ( + cfg ClientLookupConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = ClientLookupConfig{ + Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"}, + SingleNameOrder: []uint{1, 2}, + ClientnameIPMapping: map[string][]net.IP{ + "client8": {net.ParseIP("1.2.3.5")}, + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg = ClientLookupConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + By("upstream", func() { + cfg := ClientLookupConfig{ + Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"}, + ClientnameIPMapping: nil, + } + + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + + By("mapping", func() { + cfg := ClientLookupConfig{ + ClientnameIPMapping: map[string][]net.IP{ + "client8": {net.ParseIP("1.2.3.5")}, + }, + } + + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("client IP mapping:"))) + }) + }) +}) diff --git a/config/conditional_upstream.go b/config/conditional_upstream.go new file mode 100644 index 00000000..fe92e719 --- /dev/null +++ b/config/conditional_upstream.go @@ -0,0 +1,60 @@ +package config + +import ( + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +// ConditionalUpstreamConfig conditional upstream configuration +type ConditionalUpstreamConfig struct { + RewriterConfig `yaml:",inline"` + Mapping ConditionalUpstreamMapping `yaml:"mapping"` +} + +// ConditionalUpstreamMapping mapping for conditional configuration +type ConditionalUpstreamMapping struct { + Upstreams map[string][]Upstream +} + +// IsEnabled implements `config.Configurable`. +func (c *ConditionalUpstreamConfig) IsEnabled() bool { + return len(c.Mapping.Upstreams) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *ConditionalUpstreamConfig) LogConfig(logger *logrus.Entry) { + for key, val := range c.Mapping.Upstreams { + logger.Infof("%s = %v", key, val) + } +} + +// UnmarshalYAML implements `yaml.Unmarshaler`. +func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { + var input map[string]string + if err := unmarshal(&input); err != nil { + return err + } + + result := make(map[string][]Upstream, len(input)) + + for k, v := range input { + var upstreams []Upstream + + for _, part := range strings.Split(v, ",") { + upstream, err := ParseUpstream(strings.TrimSpace(part)) + if err != nil { + return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err) + } + + upstreams = append(upstreams, upstream) + } + + result[k] = upstreams + } + + c.Upstreams = result + + return nil +} diff --git a/config/conditional_upstream_test.go b/config/conditional_upstream_test.go new file mode 100644 index 00000000..7016a14c --- /dev/null +++ b/config/conditional_upstream_test.go @@ -0,0 +1,89 @@ +package config + +import ( + "errors" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ConditionalUpstreamConfig", func() { + var ( + cfg ConditionalUpstreamConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = ConditionalUpstreamConfig{ + Mapping: ConditionalUpstreamMapping{ + Upstreams: map[string][]Upstream{ + "fritz.box": {Upstream{Net: NetProtocolTcpUdp, Host: "fbTest"}}, + "other.box": {Upstream{Net: NetProtocolTcpUdp, Host: "otherTest"}}, + ".": {Upstream{Net: NetProtocolTcpUdp, Host: "dotTest"}}, + }, + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := ConditionalUpstreamConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := ConditionalUpstreamConfig{ + Mapping: ConditionalUpstreamMapping{Upstreams: map[string][]Upstream{}}, + } + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("fritz.box = "))) + }) + }) + + Describe("UnmarshalYAML", func() { + It("Should parse config as map", func() { + c := &ConditionalUpstreamMapping{} + err := c.UnmarshalYAML(func(i interface{}) error { + *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} + + return nil + }) + Expect(err).Should(Succeed()) + Expect(c.Upstreams).Should(HaveLen(1)) + Expect(c.Upstreams["key"]).Should(HaveLen(1)) + Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{ + Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53, + })) + }) + + It("should fail if wrong YAML format", func() { + c := &ConditionalUpstreamMapping{} + err := c.UnmarshalYAML(func(i interface{}) error { + return errors.New("some err") + }) + Expect(err).Should(HaveOccurred()) + Expect(err).Should(MatchError("some err")) + }) + }) +}) diff --git a/config/config.go b/config/config.go index 082aa501..6299b567 100644 --- a/config/config.go +++ b/config/config.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names +//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values package config import ( @@ -7,16 +7,12 @@ import ( "net" "os" "path/filepath" - "regexp" - "sort" "strconv" "strings" "sync" - "time" "github.com/miekg/dns" - - "github.com/hako/durafmt" + "github.com/sirupsen/logrus" "github.com/0xERR0R/blocky/log" "github.com/creasty/defaults" @@ -29,6 +25,16 @@ const ( httpsPort = 443 ) +type Configurable interface { + // IsEnabled returns true when the receiver is configured. + IsEnabled() bool + + // LogConfig logs the receiver's configuration. + // + // Calling this method when `IsEnabled` returns false is undefined. + LogConfig(*logrus.Entry) +} + // NetProtocol resolver protocol ENUM( // tcp+udp // TCP and UDP protocols // tcp-tls // TCP-TLS protocol @@ -90,44 +96,6 @@ type StartStrategyType uint16 // ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration) type QueryLogField string -type QType dns.Type - -func (c QType) String() string { - return dns.Type(c).String() -} - -type QTypeSet map[QType]struct{} - -func NewQTypeSet(qTypes ...dns.Type) QTypeSet { - s := make(QTypeSet, len(qTypes)) - - for _, qType := range qTypes { - s.Insert(qType) - } - - return s -} - -func (s QTypeSet) Contains(qType dns.Type) bool { - _, found := s[QType(qType)] - - return found -} - -func (s *QTypeSet) Insert(qType dns.Type) { - if *s == nil { - *s = make(QTypeSet, 1) - } - - (*s)[QType(qType)] = struct{}{} -} - -type Duration time.Duration - -func (c Duration) String() string { - return durafmt.Parse(time.Duration(c)).String() -} - //nolint:gochecknoglobals var netDefaultPort = map[NetProtocol]uint16{ NetProtocolTcpUdp: udpPort, @@ -135,82 +103,12 @@ var netDefaultPort = map[NetProtocol]uint16{ NetProtocolHttps: httpsPort, } -// Upstream is the definition of external DNS server -type Upstream struct { - Net NetProtocol - Host string - Port uint16 - Path string - CommonName string // Common Name to use for certificate verification; optional. "" uses .Host -} - -// IsDefault returns true if u is the default value -func (u *Upstream) IsDefault() bool { - return *u == Upstream{} -} - -// String returns the string representation of u -func (u Upstream) String() string { - if u.IsDefault() { - return "no upstream" - } - - var sb strings.Builder - - sb.WriteString(u.Net.String()) - sb.WriteRune(':') - - if u.Net == NetProtocolHttps { - sb.WriteString("//") - } - - isIPv6 := strings.ContainsRune(u.Host, ':') - if isIPv6 { - sb.WriteRune('[') - sb.WriteString(u.Host) - sb.WriteRune(']') - } else { - sb.WriteString(u.Host) - } - - if u.Port != netDefaultPort[u.Net] { - sb.WriteRune(':') - sb.WriteString(fmt.Sprint(u.Port)) - } - - if u.Path != "" { - sb.WriteString(u.Path) - } - - return sb.String() -} - -// UnmarshalYAML creates Upstream from YAML -func (u *Upstream) UnmarshalYAML(unmarshal func(interface{}) error) error { - var s string - if err := unmarshal(&s); err != nil { - return err - } - - upstream, err := ParseUpstream(s) - if err != nil { - return fmt.Errorf("can't convert upstream '%s': %w", s, err) - } - - *u = upstream - - return nil -} - // ListenConfig is a list of address(es) to listen on type ListenConfig []string -// UnmarshalYAML creates ListenConfig from YAML -func (l *ListenConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - var addresses string - if err := unmarshal(&addresses); err != nil { - return err - } +// UnmarshalText implements `encoding.TextUnmarshaler`. +func (l *ListenConfig) UnmarshalText(data []byte) error { + addresses := string(data) *l = strings.Split(addresses, ",") @@ -256,225 +154,11 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e return nil } -// UnmarshalYAML creates ConditionalUpstreamMapping from YAML -func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input map[string]string - if err := unmarshal(&input); err != nil { - return err - } - - result := make(map[string][]Upstream, len(input)) - - for k, v := range input { - var upstreams []Upstream - - for _, part := range strings.Split(v, ",") { - upstream, err := ParseUpstream(strings.TrimSpace(part)) - if err != nil { - return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err) - } - - upstreams = append(upstreams, upstream) - } - - result[k] = upstreams - } - - c.Upstreams = result - - return nil -} - -// UnmarshalYAML creates CustomDNSMapping from YAML -func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input map[string]string - if err := unmarshal(&input); err != nil { - return err - } - - result := make(map[string][]net.IP, len(input)) - - for k, v := range input { - var ips []net.IP - - for _, part := range strings.Split(v, ",") { - ip := net.ParseIP(strings.TrimSpace(part)) - if ip == nil { - return fmt.Errorf("invalid IP address '%s'", part) - } - - ips = append(ips, ip) - } - - result[k] = ips - } - - c.HostIPs = result - - return nil -} - -// UnmarshalYAML creates Duration from YAML. If no unit is used, uses minutes -func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input string - if err := unmarshal(&input); err != nil { - return err - } - - if minutes, err := strconv.Atoi(input); err == nil { - // duration is defined as number without unit - // use minutes to ensure back compatibility - *c = Duration(time.Duration(minutes) * time.Minute) - - return nil - } - - duration, err := time.ParseDuration(input) - if err == nil { - *c = Duration(duration) - - return nil - } - - return err -} - -func (c *QType) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input string - if err := unmarshal(&input); err != nil { - return err - } - - t, found := dns.StringToType[input] - if !found { - types := make([]string, 0, len(dns.StringToType)) - for k := range dns.StringToType { - types = append(types, k) - } - - sort.Strings(types) - - return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'", - input, strings.Join(types, ", ")) - } - - *c = QType(t) - - return nil -} - -func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error { - var input []QType - if err := unmarshal(&input); err != nil { - return err - } - - *s = make(QTypeSet, len(input)) - - for _, qType := range input { - (*s)[qType] = struct{}{} - } - - return nil -} - -var validDomain = regexp.MustCompile( - `^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) - -// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname] -func ParseUpstream(upstream string) (Upstream, error) { - var path string - - var port uint16 - - commonName, upstream := extractCommonName(upstream) - - n, upstream := extractNet(upstream) - - path, upstream = extractPath(upstream) - - host, portString, err := net.SplitHostPort(upstream) - - // string contains host:port - if err == nil { - p, err := ConvertPort(portString) - if err != nil { - err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err) - - return Upstream{}, err - } - - port = p - } else { - // only host, use default port - host = upstream - port = netDefaultPort[n] - - // trim any IPv6 brackets - host = strings.TrimPrefix(host, "[") - host = strings.TrimSuffix(host, "]") - } - - // validate hostname or ip - if ip := net.ParseIP(host); ip == nil { - // is not IP - if !validDomain.MatchString(host) { - return Upstream{}, fmt.Errorf("wrong host name '%s'", host) - } - } - - return Upstream{ - Net: n, - Host: host, - Port: port, - Path: path, - CommonName: commonName, - }, nil -} - -func extractCommonName(in string) (string, string) { - upstream, cn, _ := strings.Cut(in, "#") - - return cn, upstream -} - -func extractPath(in string) (path, upstream string) { - slashIdx := strings.Index(in, "/") - - if slashIdx >= 0 { - path = in[slashIdx:] - upstream = in[:slashIdx] - } else { - upstream = in - } - - return -} - -func extractNet(upstream string) (NetProtocol, string) { - tcpUDPPrefix := NetProtocolTcpUdp.String() + ":" - if strings.HasPrefix(upstream, tcpUDPPrefix) { - return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):] - } - - tcpTLSPrefix := NetProtocolTcpTls.String() + ":" - if strings.HasPrefix(upstream, tcpTLSPrefix) { - return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):] - } - - httpsPrefix := NetProtocolHttps.String() + ":" - if strings.HasPrefix(upstream, httpsPrefix) { - return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//") - } - - return NetProtocolTcpUdp, upstream -} - // Config main configuration // //nolint:maligned type Config struct { - Upstream UpstreamConfig `yaml:"upstream"` + Upstream ParallelBestConfig `yaml:"upstream"` UpstreamTimeout Duration `yaml:"upstreamTimeout" default:"2s"` ConnectIPVersion IPVersion `yaml:"connectIPVersion"` CustomDNS CustomDNSConfig `yaml:"customDNS"` @@ -483,7 +167,7 @@ type Config struct { ClientLookup ClientLookupConfig `yaml:"clientLookup"` Caching CachingConfig `yaml:"caching"` QueryLog QueryLogConfig `yaml:"queryLog"` - Prometheus PrometheusConfig `yaml:"prometheus"` + Prometheus MetricsConfig `yaml:"prometheus"` Redis RedisConfig `yaml:"redis"` Log log.Config `yaml:"log"` Ports PortsConfig `yaml:"ports"` @@ -494,7 +178,7 @@ type Config struct { KeyFile string `yaml:"keyFile"` BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"` HostsFile HostsFileConfig `yaml:"hostsFile"` - FqdnOnly bool `yaml:"fqdnOnly" default:"false"` + FqdnOnly FqdnOnlyConfig `yaml:",inline"` Filtering FilteringConfig `yaml:"filtering"` Ede EdeConfig `yaml:"ede"` // Deprecated @@ -508,7 +192,7 @@ type Config struct { // Deprecated LogTimestamp bool `yaml:"logTimestamp" default:"true"` // Deprecated - DNSPorts ListenConfig `yaml:"port" default:"[\"53\"]"` + DNSPorts ListenConfig `yaml:"port" default:"\"53\""` // Deprecated HTTPPorts ListenConfig `yaml:"httpPort"` // Deprecated @@ -518,12 +202,19 @@ type Config struct { } type PortsConfig struct { - DNS ListenConfig `yaml:"dns" default:"[\"53\"]"` + DNS ListenConfig `yaml:"dns" default:"\"53\""` HTTP ListenConfig `yaml:"http"` HTTPS ListenConfig `yaml:"https"` TLS ListenConfig `yaml:"tls"` } +func (c *PortsConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("DNS = %s", c.DNS) + logger.Infof("TLS = %s", c.TLS) + logger.Infof("HTTP = %s", c.HTTP) + logger.Infof("HTTPS = %s", c.HTTPS) +} + // split in two types to avoid infinite recursion. See `BootstrapDNSConfig.UnmarshalYAML`. type ( BootstrapDNSConfig bootstrapDNSConfig @@ -539,105 +230,6 @@ type ( } ) -// PrometheusConfig contains the config values for prometheus -type PrometheusConfig struct { - Enable bool `yaml:"enable" default:"false"` - Path string `yaml:"path" default:"/metrics"` -} - -// UpstreamConfig upstream server configuration -type UpstreamConfig struct { - ExternalResolvers map[string][]Upstream `yaml:",inline"` -} - -// RewriteConfig custom DNS configuration -type RewriteConfig struct { - Rewrite map[string]string `yaml:"rewrite"` - FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"` -} - -// CustomDNSConfig custom DNS configuration -type CustomDNSConfig struct { - RewriteConfig `yaml:",inline"` - CustomTTL Duration `yaml:"customTTL" default:"1h"` - Mapping CustomDNSMapping `yaml:"mapping"` - FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` -} - -// CustomDNSMapping mapping for the custom DNS configuration -type CustomDNSMapping struct { - HostIPs map[string][]net.IP -} - -// ConditionalUpstreamConfig conditional upstream configuration -type ConditionalUpstreamConfig struct { - RewriteConfig `yaml:",inline"` - Mapping ConditionalUpstreamMapping `yaml:"mapping"` -} - -// ConditionalUpstreamMapping mapping for conditional configuration -type ConditionalUpstreamMapping struct { - Upstreams map[string][]Upstream -} - -// BlockingConfig configuration for query blocking -type BlockingConfig struct { - BlackLists map[string][]string `yaml:"blackLists"` - WhiteLists map[string][]string `yaml:"whiteLists"` - ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"` - BlockType string `yaml:"blockType" default:"ZEROIP"` - BlockTTL Duration `yaml:"blockTTL" default:"6h"` - DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"` - DownloadAttempts uint `yaml:"downloadAttempts" default:"3"` - DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"` - RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"` - // Deprecated - FailStartOnListError bool `yaml:"failStartOnListError" default:"false"` - ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"` - StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"` -} - -// ClientLookupConfig configuration for the client lookup -type ClientLookupConfig struct { - ClientnameIPMapping map[string][]net.IP `yaml:"clients"` - Upstream Upstream `yaml:"upstream"` - SingleNameOrder []uint `yaml:"singleNameOrder"` -} - -// CachingConfig configuration for domain caching -type CachingConfig struct { - MinCachingTime Duration `yaml:"minTime"` - MaxCachingTime Duration `yaml:"maxTime"` - CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"` - MaxItemsCount int `yaml:"maxItemsCount"` - Prefetching bool `yaml:"prefetching"` - PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"` - PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"` - PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"` -} - -func (c *CachingConfig) EnablePrefetch() { - const day = 24 * time.Hour - - if c.MaxCachingTime == 0 { - // make sure resolver gets enabled - c.MaxCachingTime = Duration(day) - } - - c.Prefetching = true - c.PrefetchThreshold = 0 -} - -// QueryLogConfig configuration for the query logging -type QueryLogConfig struct { - Target string `yaml:"target"` - Type QueryLogType `yaml:"type"` - LogRetentionDays uint64 `yaml:"logRetentionDays"` - CreationAttempts int `yaml:"creationAttempts" default:"3"` - CreationCooldown Duration `yaml:"creationCooldown" default:"2s"` - Fields []QueryLogField `yaml:"fields"` -} - // RedisConfig configuration for the redis connection type RedisConfig struct { Address string `yaml:"address"` @@ -652,21 +244,25 @@ type RedisConfig struct { SentinelAddresses []string `yaml:"sentinelAddresses"` } -type HostsFileConfig struct { - Filepath string `yaml:"filePath"` - HostsTTL Duration `yaml:"hostsTTL" default:"1h"` - RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"` - FilterLoopback bool `yaml:"filterLoopback"` -} +type ( + FqdnOnlyConfig = toEnable + EdeConfig = toEnable +) -type FilteringConfig struct { - QueryTypes QTypeSet `yaml:"queryTypes"` -} - -type EdeConfig struct { +type toEnable struct { Enable bool `yaml:"enable" default:"false"` } +// IsEnabled implements `config.Configurable`. +func (c *toEnable) IsEnabled() bool { + return c.Enable +} + +// LogConfig implements `config.Configurable`. +func (c *toEnable) LogConfig(logger *logrus.Entry) { + logger.Info("enabled") +} + //nolint:gochecknoglobals var ( config = &Config{} diff --git a/config/config_enum.go b/config/config_enum.go index 4afa8fcd..337b258b 100644 --- a/config/config_enum.go +++ b/config/config_enum.go @@ -40,6 +40,15 @@ func IPVersionNames() []string { return tmp } +// IPVersionValues returns a list of the values for IPVersion +func IPVersionValues() []IPVersion { + return []IPVersion{ + IPVersionDual, + IPVersionV4, + IPVersionV6, + } +} + var _IPVersionMap = map[IPVersion]string{ IPVersionDual: _IPVersionName[0:4], IPVersionV4: _IPVersionName[4:6], @@ -113,6 +122,15 @@ func NetProtocolNames() []string { return tmp } +// NetProtocolValues returns a list of the values for NetProtocol +func NetProtocolValues() []NetProtocol { + return []NetProtocol{ + NetProtocolTcpUdp, + NetProtocolTcpTls, + NetProtocolHttps, + } +} + var _NetProtocolMap = map[NetProtocol]string{ NetProtocolTcpUdp: _NetProtocolName[0:7], NetProtocolTcpTls: _NetProtocolName[7:14], @@ -190,6 +208,18 @@ func QueryLogFieldNames() []string { return tmp } +// QueryLogFieldValues returns a list of the values for QueryLogField +func QueryLogFieldValues() []QueryLogField { + return []QueryLogField{ + QueryLogFieldClientIP, + QueryLogFieldClientName, + QueryLogFieldResponseReason, + QueryLogFieldResponseAnswer, + QueryLogFieldQuestion, + QueryLogFieldDuration, + } +} + // String implements the Stringer interface. func (x QueryLogField) String() string { return string(x) @@ -274,6 +304,18 @@ func QueryLogTypeNames() []string { return tmp } +// QueryLogTypeValues returns a list of the values for QueryLogType +func QueryLogTypeValues() []QueryLogType { + return []QueryLogType{ + QueryLogTypeConsole, + QueryLogTypeNone, + QueryLogTypeMysql, + QueryLogTypePostgresql, + QueryLogTypeCsv, + QueryLogTypeCsvClient, + } +} + var _QueryLogTypeMap = map[QueryLogType]string{ QueryLogTypeConsole: _QueryLogTypeName[0:7], QueryLogTypeNone: _QueryLogTypeName[7:11], @@ -353,6 +395,15 @@ func StartStrategyTypeNames() []string { return tmp } +// StartStrategyTypeValues returns a list of the values for StartStrategyType +func StartStrategyTypeValues() []StartStrategyType { + return []StartStrategyType{ + StartStrategyTypeBlocking, + StartStrategyTypeFailOnError, + StartStrategyTypeFast, + } +} + var _StartStrategyTypeMap = map[StartStrategyType]string{ StartStrategyTypeBlocking: _StartStrategyTypeName[0:8], StartStrategyTypeFailOnError: _StartStrategyTypeName[8:19], diff --git a/config/config_suite_test.go b/config/config_suite_test.go index 1ad724a1..4b200bab 100644 --- a/config/config_suite_test.go +++ b/config/config_suite_test.go @@ -6,6 +6,12 @@ import ( "github.com/0xERR0R/blocky/log" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/sirupsen/logrus" +) + +var ( + logger *logrus.Entry + hook *log.MockLoggerHook ) func TestConfig(t *testing.T) { @@ -13,3 +19,9 @@ func TestConfig(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Config Suite") } + +func suiteBeforeEach() { + BeforeEach(func() { + logger, hook = log.NewMockEntry() + }) +} diff --git a/config/config_test.go b/config/config_test.go index ba571ea3..7eea1e66 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,6 @@ package config import ( - "errors" "net" "time" @@ -321,15 +320,11 @@ bootstrapDns: }) }) - Describe("YAML parsing", func() { + Describe("Parsing", func() { Context("upstream", func() { It("should create the upstream struct with data", func() { u := &Upstream{} - err := u.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "tcp+udp:1.2.3.4" - - return nil - }) + err := u.UnmarshalText([]byte("tcp+udp:1.2.3.4")) Expect(err).Should(Succeed()) Expect(u.Net).Should(Equal(NetProtocolTcpUdp)) Expect(u.Host).Should(Equal("1.2.3.4")) @@ -338,139 +333,18 @@ bootstrapDns: It("should fail if the upstream is in wrong format", func() { u := &Upstream{} - err := u.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) + err := u.UnmarshalText([]byte("invalid!")) Expect(err).Should(HaveOccurred()) }) }) Context("ListenConfig", func() { It("should parse and split valid string config", func() { l := &ListenConfig{} - err := l.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "55,:56" - - return nil - }) + err := l.UnmarshalText([]byte("55,:56")) Expect(err).Should(Succeed()) Expect(*l).Should(HaveLen(2)) Expect(*l).Should(ContainElements("55", ":56")) }) - It("should fail on error", func() { - l := &ListenConfig{} - err := l.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) - Expect(err).Should(HaveOccurred()) - }) - }) - Context("Duration", func() { - It("should parse duration with unit", func() { - d := Duration(0) - err := d.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "1m20s" - - return nil - }) - Expect(err).Should(Succeed()) - Expect(d).Should(Equal(Duration(80 * time.Second))) - Expect(d.String()).Should(Equal("1 minute 20 seconds")) - }) - It("should fail if duration is in wrong format", func() { - d := Duration(0) - err := d.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "wrong" - - return nil - }) - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError("time: invalid duration \"wrong\"")) - }) - It("should fail if wrong YAML format", func() { - d := Duration(0) - err := d.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError("some err")) - }) - }) - Context("ConditionalUpstreamMapping", func() { - It("Should parse config as map", func() { - c := &ConditionalUpstreamMapping{} - err := c.UnmarshalYAML(func(i interface{}) error { - *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} - - return nil - }) - Expect(err).Should(Succeed()) - Expect(c.Upstreams).Should(HaveLen(1)) - Expect(c.Upstreams["key"]).Should(HaveLen(1)) - Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{ - Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53, - })) - }) - It("should fail if wrong YAML format", func() { - c := &ConditionalUpstreamMapping{} - err := c.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError("some err")) - }) - }) - Context("CustomDNSMapping", func() { - It("Should parse config as map", func() { - c := &CustomDNSMapping{} - err := c.UnmarshalYAML(func(i interface{}) error { - *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} - - return nil - }) - Expect(err).Should(Succeed()) - Expect(c.HostIPs).Should(HaveLen(1)) - Expect(c.HostIPs["key"]).Should(HaveLen(1)) - Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4"))) - }) - It("should fail if wrong YAML format", func() { - c := &CustomDNSMapping{} - err := c.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError("some err")) - }) - }) - Context("QueryTyoe", func() { - It("Should parse existing DNS type as string", func() { - t := QType(0) - err := t.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "AAAA" - - return nil - }) - Expect(err).Should(Succeed()) - Expect(t).Should(Equal(QType(dns.TypeAAAA))) - Expect(t.String()).Should(Equal("AAAA")) - }) - It("should fail if DNS type does not exist", func() { - t := QType(0) - err := t.UnmarshalYAML(func(i interface{}) error { - *i.(*string) = "WRONGTYPE" - - return nil - }) - Expect(err).Should(HaveOccurred()) - Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'")) - }) - It("should fail if wrong YAML format", func() { - d := QType(0) - err := d.UnmarshalYAML(func(i interface{}) error { - return errors.New("some err") - }) - Expect(err).Should(HaveOccurred()) - Expect(err).Should(MatchError("some err")) - }) }) }) @@ -651,29 +525,6 @@ bootstrapDns: "tcp-tls:[fd00::6cd4:d7e0:d99d:2952]", ), ) - - Describe("QTypeSet", func() { - It("new should insert given qTypes", func() { - set := NewQTypeSet(dns.Type(dns.TypeA)) - Expect(set).Should(HaveKey(QType(dns.TypeA))) - Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue()) - - Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) - Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) - }) - - It("should insert given qTypes", func() { - set := NewQTypeSet() - - Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) - Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) - - set.Insert(dns.Type(dns.TypeAAAA)) - - Expect(set).Should(HaveKey(QType(dns.TypeAAAA))) - Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue()) - }) - }) }) func defaultTestFileConfig() { @@ -700,8 +551,8 @@ func defaultTestFileConfig() { Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour))) Expect(config.Filtering.QueryTypes).Should(HaveLen(2)) - Expect(config.Caching.MaxCachingTime).Should(Equal(Duration(0))) - Expect(config.Caching.MinCachingTime).Should(Equal(Duration(0))) + Expect(config.Caching.MaxCachingTime.IsZero()).Should(BeTrue()) + Expect(config.Caching.MinCachingTime.IsZero()).Should(BeTrue()) Expect(config.DoHUserAgent).Should(Equal("testBlocky")) Expect(config.MinTLSServeVer).Should(Equal("1.3")) diff --git a/config/custom_dns.go b/config/custom_dns.go new file mode 100644 index 00000000..a9de22c2 --- /dev/null +++ b/config/custom_dns.go @@ -0,0 +1,68 @@ +package config + +import ( + "fmt" + "net" + "strings" + + "github.com/sirupsen/logrus" +) + +// CustomDNSConfig custom DNS configuration +type CustomDNSConfig struct { + RewriterConfig `yaml:",inline"` + CustomTTL Duration `yaml:"customTTL" default:"1h"` + Mapping CustomDNSMapping `yaml:"mapping"` + FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` +} + +// CustomDNSMapping mapping for the custom DNS configuration +type CustomDNSMapping struct { + HostIPs map[string][]net.IP `yaml:"hostIPs"` +} + +// IsEnabled implements `config.Configurable`. +func (c *CustomDNSConfig) IsEnabled() bool { + return len(c.Mapping.HostIPs) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *CustomDNSConfig) LogConfig(logger *logrus.Entry) { + logger.Debugf("TTL = %s", c.CustomTTL) + logger.Debugf("filterUnmappedTypes = %t", c.FilterUnmappedTypes) + + logger.Info("mapping:") + + for key, val := range c.Mapping.HostIPs { + logger.Infof(" %s = %s", key, val) + } +} + +// UnmarshalYAML implements `yaml.Unmarshaler`. +func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { + var input map[string]string + if err := unmarshal(&input); err != nil { + return err + } + + result := make(map[string][]net.IP, len(input)) + + for k, v := range input { + var ips []net.IP + + for _, part := range strings.Split(v, ",") { + ip := net.ParseIP(strings.TrimSpace(part)) + if ip == nil { + return fmt.Errorf("invalid IP address '%s'", part) + } + + ips = append(ips, ip) + } + + result[k] = ips + } + + c.HostIPs = result + + return nil +} diff --git a/config/custom_dns_test.go b/config/custom_dns_test.go new file mode 100644 index 00000000..0361bf20 --- /dev/null +++ b/config/custom_dns_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "errors" + "net" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("CustomDNSConfig", func() { + var ( + cfg CustomDNSConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = CustomDNSConfig{ + Mapping: CustomDNSMapping{ + HostIPs: map[string][]net.IP{ + "custom.domain": {net.ParseIP("192.168.143.123")}, + "ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + "multiple.ips": { + net.ParseIP("192.168.143.123"), + net.ParseIP("192.168.143.125"), + net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + }, + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := CustomDNSConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := CustomDNSConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("custom.domain = "))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("multiple.ips = "))) + }) + }) + + Describe("UnmarshalYAML", func() { + It("Should parse config as map", func() { + c := &CustomDNSMapping{} + err := c.UnmarshalYAML(func(i interface{}) error { + *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} + + return nil + }) + Expect(err).Should(Succeed()) + Expect(c.HostIPs).Should(HaveLen(1)) + Expect(c.HostIPs["key"]).Should(HaveLen(1)) + Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4"))) + }) + + It("should fail if wrong YAML format", func() { + c := &CustomDNSMapping{} + err := c.UnmarshalYAML(func(i interface{}) error { + return errors.New("some err") + }) + Expect(err).Should(HaveOccurred()) + Expect(err).Should(MatchError("some err")) + }) + }) +}) diff --git a/config/duration.go b/config/duration.go new file mode 100644 index 00000000..5ca316ac --- /dev/null +++ b/config/duration.go @@ -0,0 +1,54 @@ +package config + +import ( + "strconv" + "time" + + "github.com/0xERR0R/blocky/log" + "github.com/hako/durafmt" +) + +type Duration time.Duration + +func (c Duration) ToDuration() time.Duration { + return time.Duration(c) +} + +func (c Duration) IsZero() bool { + return c.ToDuration() == 0 +} + +func (c Duration) Seconds() float64 { + return c.ToDuration().Seconds() +} + +func (c Duration) SecondsU32() uint32 { + return uint32(c.Seconds()) +} + +func (c Duration) String() string { + return durafmt.Parse(c.ToDuration()).String() +} + +// UnmarshalText implements `encoding.TextUnmarshaler`. +func (c *Duration) UnmarshalText(data []byte) error { + input := string(data) + + if minutes, err := strconv.Atoi(input); err == nil { + // number without unit: use minutes to ensure back compatibility + *c = Duration(time.Duration(minutes) * time.Minute) + + log.Log().Warnf("Setting a duration without a unit is deprecated. Please use '%s min' instead.", input) + + return nil + } + + duration, err := time.ParseDuration(input) + if err == nil { + *c = Duration(duration) + + return nil + } + + return err +} diff --git a/config/duration_test.go b/config/duration_test.go new file mode 100644 index 00000000..882f038f --- /dev/null +++ b/config/duration_test.go @@ -0,0 +1,44 @@ +package config + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Duration", func() { + var d Duration + + BeforeEach(func() { + var zero Duration + + d = zero + }) + + Describe("UnmarshalText", func() { + It("should parse duration with unit", func() { + err := d.UnmarshalText([]byte("1m20s")) + Expect(err).Should(Succeed()) + Expect(d).Should(Equal(Duration(80 * time.Second))) + Expect(d.String()).Should(Equal("1 minute 20 seconds")) + }) + + It("should fail if duration is in wrong format", func() { + err := d.UnmarshalText([]byte("wrong")) + Expect(err).Should(HaveOccurred()) + Expect(err).Should(MatchError("time: invalid duration \"wrong\"")) + }) + }) + + Describe("IsZero", func() { + It("should be true for zero", func() { + Expect(d.IsZero()).Should(BeTrue()) + Expect(Duration(0).IsZero()).Should(BeTrue()) + }) + + It("should be false for non-zero", func() { + Expect(Duration(time.Second).IsZero()).Should(BeFalse()) + }) + }) +}) diff --git a/config/filtering.go b/config/filtering.go new file mode 100644 index 00000000..dcc37ac7 --- /dev/null +++ b/config/filtering.go @@ -0,0 +1,23 @@ +package config + +import ( + "github.com/sirupsen/logrus" +) + +type FilteringConfig struct { + QueryTypes QTypeSet `yaml:"queryTypes"` +} + +// IsEnabled implements `config.Configurable`. +func (c *FilteringConfig) IsEnabled() bool { + return len(c.QueryTypes) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *FilteringConfig) LogConfig(logger *logrus.Entry) { + logger.Info("query types:") + + for qType := range c.QueryTypes { + logger.Infof(" - %s", qType) + } +} diff --git a/config/filtering_test.go b/config/filtering_test.go new file mode 100644 index 00000000..648cd186 --- /dev/null +++ b/config/filtering_test.go @@ -0,0 +1,57 @@ +package config + +import ( + . "github.com/0xERR0R/blocky/helpertest" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("FilteringConfig", func() { + var ( + cfg FilteringConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = FilteringConfig{ + QueryTypes: NewQTypeSet(AAAA, MX), + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := FilteringConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := FilteringConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).Should(HaveLen(3)) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("query types:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - AAAA"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - MX"))) + }) + }) +}) diff --git a/config/hosts_file.go b/config/hosts_file.go new file mode 100644 index 00000000..967b735d --- /dev/null +++ b/config/hosts_file.go @@ -0,0 +1,25 @@ +package config + +import ( + "github.com/sirupsen/logrus" +) + +type HostsFileConfig struct { + Filepath string `yaml:"filePath"` + HostsTTL Duration `yaml:"hostsTTL" default:"1h"` + RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"` + FilterLoopback bool `yaml:"filterLoopback"` +} + +// IsEnabled implements `config.Configurable`. +func (c *HostsFileConfig) IsEnabled() bool { + return len(c.Filepath) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("file path: %s", c.Filepath) + logger.Infof("TTL: %s", c.HostsTTL) + logger.Infof("refresh period: %s", c.RefreshPeriod) + logger.Infof("filter loopback addresses: %t", c.FilterLoopback) +} diff --git a/config/hosts_file_test.go b/config/hosts_file_test.go new file mode 100644 index 00000000..aa112080 --- /dev/null +++ b/config/hosts_file_test.go @@ -0,0 +1,58 @@ +package config + +import ( + "time" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HostsFileConfig", func() { + var ( + cfg HostsFileConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = HostsFileConfig{ + Filepath: "/dev/null", + HostsTTL: Duration(29 * time.Minute), + RefreshPeriod: Duration(30 * time.Minute), + FilterLoopback: true, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := HostsFileConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := HostsFileConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null"))) + }) + }) +}) diff --git a/config/metrics.go b/config/metrics.go new file mode 100644 index 00000000..69d9e4f7 --- /dev/null +++ b/config/metrics.go @@ -0,0 +1,19 @@ +package config + +import "github.com/sirupsen/logrus" + +// MetricsConfig contains the config values for prometheus +type MetricsConfig struct { + Enable bool `yaml:"enable" default:"false"` + Path string `yaml:"path" default:"/metrics"` +} + +// IsEnabled implements `config.Configurable`. +func (c *MetricsConfig) IsEnabled() bool { + return c.Enable +} + +// LogConfig implements `config.Configurable`. +func (c *MetricsConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("url path: %s", c.Path) +} diff --git a/config/metrics_test.go b/config/metrics_test.go new file mode 100644 index 00000000..afedbe77 --- /dev/null +++ b/config/metrics_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("MetricsConfig", func() { + var ( + cfg MetricsConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = MetricsConfig{ + Enable: true, + Path: "/custom/path", + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := MetricsConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := MetricsConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).Should(HaveLen(1)) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("url path: /custom/path"))) + }) + }) +}) diff --git a/config/parallel_best.go b/config/parallel_best.go new file mode 100644 index 00000000..6ce73543 --- /dev/null +++ b/config/parallel_best.go @@ -0,0 +1,32 @@ +package config + +import ( + "github.com/sirupsen/logrus" +) + +const UpstreamDefaultCfgName = "default" + +// ParallelBestConfig upstream server configuration +type ParallelBestConfig struct { + ExternalResolvers ParallelBestMapping `yaml:",inline"` +} + +type ParallelBestMapping map[string][]Upstream + +// IsEnabled implements `config.Configurable`. +func (c *ParallelBestConfig) IsEnabled() bool { + return len(c.ExternalResolvers) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *ParallelBestConfig) LogConfig(logger *logrus.Entry) { + logger.Info("upstream resolvers:") + + for name, upstreams := range c.ExternalResolvers { + logger.Infof(" %s:", name) + + for _, upstream := range upstreams { + logger.Infof(" - %s", upstream) + } + } +} diff --git a/config/parallel_best_test.go b/config/parallel_best_test.go new file mode 100644 index 00000000..1f8f6e95 --- /dev/null +++ b/config/parallel_best_test.go @@ -0,0 +1,59 @@ +package config + +import ( + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ParallelBestConfig", func() { + var ( + cfg ParallelBestConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = ParallelBestConfig{ + ExternalResolvers: ParallelBestMapping{ + UpstreamDefaultCfgName: { + {Host: "host1"}, + {Host: "host2"}, + }, + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := ParallelBestConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := ParallelBestConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstream resolvers:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:"))) + }) + }) +}) diff --git a/config/qtype_set.go b/config/qtype_set.go new file mode 100644 index 00000000..4508b82f --- /dev/null +++ b/config/qtype_set.go @@ -0,0 +1,76 @@ +package config + +import ( + "fmt" + "sort" + "strings" + + "github.com/miekg/dns" + "golang.org/x/exp/maps" +) + +type QTypeSet map[QType]struct{} + +func NewQTypeSet(qTypes ...dns.Type) QTypeSet { + s := make(QTypeSet, len(qTypes)) + + for _, qType := range qTypes { + s.Insert(qType) + } + + return s +} + +func (s QTypeSet) Contains(qType dns.Type) bool { + _, found := s[QType(qType)] + + return found +} + +func (s *QTypeSet) Insert(qType dns.Type) { + if *s == nil { + *s = make(QTypeSet, 1) + } + + (*s)[QType(qType)] = struct{}{} +} + +func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error { + var input []QType + if err := unmarshal(&input); err != nil { + return err + } + + *s = make(QTypeSet, len(input)) + + for _, qType := range input { + (*s)[qType] = struct{}{} + } + + return nil +} + +type QType dns.Type + +func (c QType) String() string { + return dns.Type(c).String() +} + +// UnmarshalText implements `encoding.TextUnmarshaler`. +func (c *QType) UnmarshalText(data []byte) error { + input := string(data) + + t, found := dns.StringToType[input] + if !found { + types := maps.Keys(dns.StringToType) + + sort.Strings(types) + + return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'", + input, strings.Join(types, ", ")) + } + + *c = QType(t) + + return nil +} diff --git a/config/qtype_set_test.go b/config/qtype_set_test.go new file mode 100644 index 00000000..0515d8df --- /dev/null +++ b/config/qtype_set_test.go @@ -0,0 +1,60 @@ +package config + +import ( + "github.com/miekg/dns" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("QTypeSet", func() { + Describe("NewQTypeSet", func() { + It("should insert given qTypes", func() { + set := NewQTypeSet(dns.Type(dns.TypeA)) + Expect(set).Should(HaveKey(QType(dns.TypeA))) + Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue()) + + Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) + Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) + }) + }) + + Describe("Insert", func() { + It("should insert given qTypes", func() { + set := NewQTypeSet() + + Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) + Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) + + set.Insert(dns.Type(dns.TypeAAAA)) + + Expect(set).Should(HaveKey(QType(dns.TypeAAAA))) + Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue()) + }) + }) +}) + +var _ = Describe("QType", func() { + Describe("UnmarshalText", func() { + It("Should parse existing DNS type as string", func() { + t := QType(0) + err := t.UnmarshalText([]byte("AAAA")) + Expect(err).Should(Succeed()) + Expect(t).Should(Equal(QType(dns.TypeAAAA))) + Expect(t.String()).Should(Equal("AAAA")) + }) + + It("should fail if DNS type does not exist", func() { + t := QType(0) + err := t.UnmarshalText([]byte("WRONGTYPE")) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'")) + }) + + It("should fail if wrong YAML format", func() { + d := QType(0) + err := d.UnmarshalText([]byte("some err")) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'some err'")) + }) + }) +}) diff --git a/config/query_log.go b/config/query_log.go new file mode 100644 index 00000000..bb4636b5 --- /dev/null +++ b/config/query_log.go @@ -0,0 +1,41 @@ +package config + +import ( + "github.com/sirupsen/logrus" +) + +// QueryLogConfig configuration for the query logging +type QueryLogConfig struct { + Target string `yaml:"target"` + Type QueryLogType `yaml:"type"` + LogRetentionDays uint64 `yaml:"logRetentionDays"` + CreationAttempts int `yaml:"creationAttempts" default:"3"` + CreationCooldown Duration `yaml:"creationCooldown" default:"2s"` + Fields []QueryLogField `yaml:"fields"` +} + +// SetDefaults implements `defaults.Setter`. +func (c *QueryLogConfig) SetDefaults() { + // Since the default depends on the enum values, set it dynamically + // to avoid having to repeat the values in the annotation. + c.Fields = QueryLogFieldValues() +} + +// IsEnabled implements `config.Configurable`. +func (c *QueryLogConfig) IsEnabled() bool { + return c.Type != QueryLogTypeNone +} + +// LogConfig implements `config.Configurable`. +func (c *QueryLogConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("type: %s", c.Type) + + if c.Target != "" { + logger.Infof("target: %s", c.Target) + } + + logger.Infof("logRetentionDays: %d", c.LogRetentionDays) + logger.Debugf("creationAttempts: %d", c.CreationAttempts) + logger.Debugf("creationCooldown: %s", c.CreationCooldown) + logger.Infof("fields: %s", c.Fields) +} diff --git a/config/query_log_test.go b/config/query_log_test.go new file mode 100644 index 00000000..208f755c --- /dev/null +++ b/config/query_log_test.go @@ -0,0 +1,72 @@ +package config + +import ( + "time" + + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("QueryLogConfig", func() { + var ( + cfg QueryLogConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = QueryLogConfig{ + Target: "/dev/null", + Type: QueryLogTypeCsvClient, + LogRetentionDays: 0, + CreationAttempts: 1, + CreationCooldown: Duration(time.Millisecond), + } + }) + + Describe("IsEnabled", func() { + It("should be true by default", func() { + cfg := QueryLogConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := QueryLogConfig{ + Type: QueryLogTypeNone, + } + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("logRetentionDays:"))) + }) + }) + + Describe("SetDefaults", func() { + It("should log configuration", func() { + cfg := QueryLogConfig{} + Expect(cfg.Fields).Should(BeEmpty()) + + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.Fields).ShouldNot(BeEmpty()) + }) + }) +}) diff --git a/config/rewriter.go b/config/rewriter.go new file mode 100644 index 00000000..b23f1769 --- /dev/null +++ b/config/rewriter.go @@ -0,0 +1,27 @@ +package config + +import ( + "github.com/sirupsen/logrus" +) + +// RewriterConfig custom DNS configuration +type RewriterConfig struct { + Rewrite map[string]string `yaml:"rewrite"` + FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"` +} + +// IsEnabled implements `config.Configurable`. +func (c *RewriterConfig) IsEnabled() bool { + return len(c.Rewrite) != 0 +} + +// LogConfig implements `config.Configurable`. +func (c *RewriterConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("fallbackUpstream = %t", c.FallbackUpstream) + + logger.Info("rules:") + + for key, val := range c.Rewrite { + logger.Infof(" %s = %s", key, val) + } +} diff --git a/config/rewriter_test.go b/config/rewriter_test.go new file mode 100644 index 00000000..5b92083e --- /dev/null +++ b/config/rewriter_test.go @@ -0,0 +1,57 @@ +package config + +import ( + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("RewriterConfig", func() { + var ( + cfg RewriterConfig + ) + + suiteBeforeEach() + + BeforeEach(func() { + cfg = RewriterConfig{ + Rewrite: map[string]string{ + "original1": "rewritten1", + "original2": "rewritten2", + }, + } + }) + + Describe("IsEnabled", func() { + It("should be false by default", func() { + cfg := RewriterConfig{} + Expect(defaults.Set(&cfg)).Should(Succeed()) + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + + When("enabled", func() { + It("should be true", func() { + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + }) + + When("disabled", func() { + It("should be false", func() { + cfg := RewriterConfig{} + + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + Describe("LogConfig", func() { + It("should log configuration", func() { + cfg.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("rules:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("original2 ="))) + }) + }) +}) diff --git a/config/upstream.go b/config/upstream.go new file mode 100644 index 00000000..ce4f2cd3 --- /dev/null +++ b/config/upstream.go @@ -0,0 +1,164 @@ +package config + +import ( + "fmt" + "net" + "regexp" + "strings" +) + +var validDomain = regexp.MustCompile( + `^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) + +// Upstream is the definition of external DNS server +type Upstream struct { + Net NetProtocol + Host string + Port uint16 + Path string + CommonName string // Common Name to use for certificate verification; optional. "" uses .Host +} + +// IsDefault returns true if u is the default value +func (u *Upstream) IsDefault() bool { + return *u == Upstream{} +} + +// String returns the string representation of u +func (u Upstream) String() string { + if u.IsDefault() { + return "no upstream" + } + + var sb strings.Builder + + sb.WriteString(u.Net.String()) + sb.WriteRune(':') + + if u.Net == NetProtocolHttps { + sb.WriteString("//") + } + + isIPv6 := strings.ContainsRune(u.Host, ':') + if isIPv6 { + sb.WriteRune('[') + sb.WriteString(u.Host) + sb.WriteRune(']') + } else { + sb.WriteString(u.Host) + } + + if u.Port != netDefaultPort[u.Net] { + sb.WriteRune(':') + sb.WriteString(fmt.Sprint(u.Port)) + } + + if u.Path != "" { + sb.WriteString(u.Path) + } + + return sb.String() +} + +// UnmarshalText implements `encoding.TextUnmarshaler`. +func (u *Upstream) UnmarshalText(data []byte) error { + s := string(data) + + upstream, err := ParseUpstream(s) + if err != nil { + return fmt.Errorf("can't convert upstream '%s': %w", s, err) + } + + *u = upstream + + return nil +} + +// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname] +func ParseUpstream(upstream string) (Upstream, error) { + var path string + + var port uint16 + + commonName, upstream := extractCommonName(upstream) + + n, upstream := extractNet(upstream) + + path, upstream = extractPath(upstream) + + host, portString, err := net.SplitHostPort(upstream) + + // string contains host:port + if err == nil { + p, err := ConvertPort(portString) + if err != nil { + err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err) + + return Upstream{}, err + } + + port = p + } else { + // only host, use default port + host = upstream + port = netDefaultPort[n] + + // trim any IPv6 brackets + host = strings.TrimPrefix(host, "[") + host = strings.TrimSuffix(host, "]") + } + + // validate hostname or ip + if ip := net.ParseIP(host); ip == nil { + // is not IP + if !validDomain.MatchString(host) { + return Upstream{}, fmt.Errorf("wrong host name '%s'", host) + } + } + + return Upstream{ + Net: n, + Host: host, + Port: port, + Path: path, + CommonName: commonName, + }, nil +} + +func extractCommonName(in string) (string, string) { + upstream, cn, _ := strings.Cut(in, "#") + + return cn, upstream +} + +func extractPath(in string) (path, upstream string) { + slashIdx := strings.Index(in, "/") + + if slashIdx >= 0 { + path = in[slashIdx:] + upstream = in[:slashIdx] + } else { + upstream = in + } + + return +} + +func extractNet(upstream string) (NetProtocol, string) { + tcpUDPPrefix := NetProtocolTcpUdp.String() + ":" + if strings.HasPrefix(upstream, tcpUDPPrefix) { + return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):] + } + + tcpTLSPrefix := NetProtocolTcpTls.String() + ":" + if strings.HasPrefix(upstream, tcpTLSPrefix) { + return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):] + } + + httpsPrefix := NetProtocolHttps.String() + ":" + if strings.HasPrefix(upstream, httpsPrefix) { + return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//") + } + + return NetProtocolTcpUdp, upstream +} diff --git a/go.mod b/go.mod index 1d376d22..0e0c21f8 100644 --- a/go.mod +++ b/go.mod @@ -123,6 +123,7 @@ require ( github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect golang.org/x/crypto v0.6.0 // indirect + golang.org/x/exp v0.0.0-20230307190834-24139beb5833 golang.org/x/mod v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/term v0.6.0 // indirect diff --git a/go.sum b/go.sum index 62968888..6b4a7556 100644 --- a/go.sum +++ b/go.sum @@ -475,6 +475,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s= +golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/lists/list_cache.go b/lists/list_cache.go index 2b3e0761..982f19d3 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/hako/durafmt" "github.com/hashicorp/go-multierror" "github.com/sirupsen/logrus" @@ -38,9 +37,6 @@ type ListCacheType int type Matcher interface { // Match matches passed domain name against cached list entries Match(domain string, groupsToCheck []string) (found bool, group string) - - // Configuration returns current configuration and stats - Configuration() []string } // ListCache generic cache of strings divided in groups @@ -55,42 +51,19 @@ type ListCache struct { processingConcurrency uint } -// Configuration returns current configuration and stats -func (b *ListCache) Configuration() (result []string) { - if b.refreshPeriod > 0 { - result = append(result, fmt.Sprintf("refresh period: %s", durafmt.Parse(b.refreshPeriod))) - } else { - result = append(result, "refresh: disabled") - } - - result = append(result, "group links:") - for group, links := range b.groupToLinks { - result = append(result, fmt.Sprintf(" %s:", group)) - - for _, link := range links { - if strings.Contains(link, "\n") { - link = "[INLINE DEFINITION]" - } - - result = append(result, fmt.Sprintf(" - %s", link)) - } - } - - result = append(result, "group caches:") - +// LogConfig implements `config.Configurable`. +func (b *ListCache) LogConfig(logger *logrus.Entry) { var total int b.lock.RLock() defer b.lock.RUnlock() for group, cache := range b.groupCaches { - result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount())) + logger.Infof("%s: %d entries", group, cache.ElementCount()) total += cache.ElementCount() } - result = append(result, fmt.Sprintf(" TOTAL: %d entries", total)) - - return result + logger.Infof("TOTAL: %d entries", total) } // NewListCache creates new list instance diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go index 39f2c8e6..1cab91c2 100644 --- a/lists/list_cache_test.go +++ b/lists/list_cache_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "log" "math/rand" "net/http/httptest" "os" @@ -14,7 +13,9 @@ import ( . "github.com/0xERR0R/blocky/evt" "github.com/0xERR0R/blocky/lists/parsers" + "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/util" + "github.com/sirupsen/logrus" . "github.com/0xERR0R/blocky/helpertest" . "github.com/onsi/ginkgo/v2" @@ -414,36 +415,31 @@ var _ = Describe("ListCache", func() { }) }) }) - Describe("Configuration", func() { - When("refresh is enabled", func() { - It("should print list configuration", func() { - lists := map[string][]string{ - "gr1": {server1.URL, server2.URL}, - "gr2": {inlineList("inline", "definition")}, - } + Describe("LogConfig", func() { + var ( + logger *logrus.Entry + hook *log.MockLoggerHook + ) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(), - defaultProcessingConcurrency, false) - Expect(err).Should(Succeed()) - - c := sut.Configuration() - Expect(c).Should(ContainElement("refresh period: 1 hour")) - Expect(len(c)).Should(BeNumerically(">", 1)) - }) + BeforeEach(func() { + logger, hook = log.NewMockEntry() }) - When("refresh is disabled", func() { - It("should print 'refresh disabled'", func() { - lists := map[string][]string{ - "gr1": {emptyFile.Path}, - } - sut, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(), - defaultProcessingConcurrency, false) - Expect(err).Should(Succeed()) + It("should print list configuration", func() { + lists := map[string][]string{ + "gr1": {server1.URL, server2.URL}, + "gr2": {inlineList("inline", "definition")}, + } - c := sut.Configuration() - Expect(c).Should(ContainElement("refresh: disabled")) - }) + sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(), + defaultProcessingConcurrency, false) + Expect(err).Should(Succeed()) + + sut.LogConfig(logger) + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr1:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr2:"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("TOTAL:"))) }) }) @@ -486,7 +482,7 @@ func (m *MockDownloader) ListSource() string { func createTestListFile(dir string, totalLines int) (string, int) { file, err := os.CreateTemp(dir, "blocky") if err != nil { - log.Fatal(err) + log.Log().Fatal(err) } w := bufio.NewWriter(file) diff --git a/log/logger.go b/log/logger.go index 53d27839..ad98db3f 100644 --- a/log/logger.go +++ b/log/logger.go @@ -6,9 +6,11 @@ import ( "fmt" "io" "strings" + "sync" "github.com/sirupsen/logrus" prefixed "github.com/x-cray/logrus-prefixed-formatter" + "golang.org/x/exp/maps" ) const prefixField = "prefix" @@ -118,3 +120,50 @@ func ConfigureLogger(cfg *Config) { func Silence() { logger.Out = io.Discard } + +func WithIndent(log *logrus.Entry, prefix string, callback func(*logrus.Entry)) { + undo := indentMessages(prefix, log.Logger) + defer undo() + + callback(log) +} + +// indentMessages modifies a logger and adds `prefix` to all messages. +// +// The returned function must be called to remove the prefix. +func indentMessages(prefix string, logger *logrus.Logger) func() { + if _, ok := logger.Formatter.(*prefixed.TextFormatter); !ok { + // log is not plaintext, do nothing + return func() {} + } + + oldHooks := maps.Clone(logger.Hooks) + + logger.AddHook(prefixMsgHook{ + prefix: prefix, + }) + + var once sync.Once + + return func() { + once.Do(func() { + logger.ReplaceHooks(oldHooks) + }) + } +} + +type prefixMsgHook struct { + prefix string +} + +// Levels implements `logrus.Hook`. +func (h prefixMsgHook) Levels() []logrus.Level { + return logrus.AllLevels +} + +// Fire implements `logrus.Hook`. +func (h prefixMsgHook) Fire(entry *logrus.Entry) error { + entry.Message = h.prefix + entry.Message + + return nil +} diff --git a/log/mock_entry.go b/log/mock_entry.go new file mode 100644 index 00000000..0415a5ef --- /dev/null +++ b/log/mock_entry.go @@ -0,0 +1,42 @@ +package log + +import ( + "io" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" +) + +func NewMockEntry() (*logrus.Entry, *MockLoggerHook) { + logger := logrus.New() + logger.Out = io.Discard + + entry := logrus.Entry{Logger: logger} + hook := MockLoggerHook{} + + entry.Logger.AddHook(&hook) + + hook.On("Fire", mock.Anything).Return(nil) + + return &entry, &hook +} + +type MockLoggerHook struct { + mock.Mock + + Messages []string +} + +// Levels implements `logrus.Hook`. +func (h *MockLoggerHook) Levels() []logrus.Level { + return logrus.AllLevels +} + +// Fire implements `logrus.Hook`. +func (h *MockLoggerHook) Fire(entry *logrus.Entry) error { + _ = h.Called() + + h.Messages = append(h.Messages, entry.Message) + + return nil +} diff --git a/metrics/metrics.go b/metrics/metrics.go index b0057fc7..a35c7526 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -18,7 +18,7 @@ func RegisterMetric(c prometheus.Collector) { } // Start starts prometheus endpoint -func Start(router *chi.Mux, cfg config.PrometheusConfig) { +func Start(router *chi.Mux, cfg config.MetricsConfig) { if cfg.Enable { _ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) _ = reg.Register(collectors.NewGoCollector()) diff --git a/model/models.go b/model/models.go index 99aa2ebe..43b0292b 100644 --- a/model/models.go +++ b/model/models.go @@ -22,6 +22,31 @@ import ( // ) type ResponseType int +func (t ResponseType) ToExtendedErrorCode() uint16 { + switch t { + case ResponseTypeRESOLVED: + return dns.ExtendedErrorCodeOther + case ResponseTypeCACHED: + return dns.ExtendedErrorCodeCachedError + case ResponseTypeCONDITIONAL: + return dns.ExtendedErrorCodeForgedAnswer + case ResponseTypeCUSTOMDNS: + return dns.ExtendedErrorCodeForgedAnswer + case ResponseTypeHOSTSFILE: + return dns.ExtendedErrorCodeForgedAnswer + case ResponseTypeNOTFQDN: + return dns.ExtendedErrorCodeBlocked + case ResponseTypeBLOCKED: + return dns.ExtendedErrorCodeBlocked + case ResponseTypeFILTERED: + return dns.ExtendedErrorCodeFiltered + case ResponseTypeSPECIAL: + return dns.ExtendedErrorCodeFiltered + default: + return dns.ExtendedErrorCodeOther + } +} + // Response represents the response of a DNS query type Response struct { Res *dns.Msg diff --git a/redis/redis.go b/redis/redis.go index a4055738..adf22f1e 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -83,7 +83,7 @@ func New(cfg *config.RedisConfig) (*Client, error) { Password: cfg.Password, DB: cfg.Database, MaxRetries: cfg.ConnectionAttempts, - MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown), + MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(), }) } else { rdb = redis.NewClient(&redis.Options{ @@ -92,7 +92,7 @@ func New(cfg *config.RedisConfig) (*Client, error) { Password: cfg.Password, DB: cfg.Database, MaxRetries: cfg.ConnectionAttempts, - MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown), + MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(), }) } diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 1115b563..9c581161 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -35,7 +35,7 @@ func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) { return nxDomainBlockHandler{}, nil } - blockTime := uint32(time.Duration(cfg.BlockTTL).Seconds()) + blockTime := cfg.BlockTTL.SecondsU32() if strings.EqualFold(cfgBlockType, "ZEROIP") { return zeroIPBlockHandler{ @@ -78,16 +78,17 @@ type status struct { // BlockingResolver checks request's question (domain name) against black and white lists type BlockingResolver struct { + configurable[*config.BlockingConfig] NextResolver + typed + blacklistMatcher *lists.ListCache whitelistMatcher *lists.ListCache - cfg config.BlockingConfig blockHandler blockHandler whitelistOnlyGroups map[string]bool status *status clientGroupsBlock map[string][]string redisClient *redis.Client - redisEnabled bool fqdnIPCache expirationcache.ExpiringCache } @@ -100,7 +101,7 @@ func NewBlockingResolver( return nil, err } - refreshPeriod := time.Duration(cfg.RefreshPeriod) + refreshPeriod := cfg.RefreshPeriod.ToDuration() downloader := createDownloader(cfg, bootstrap) blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, refreshPeriod, downloader, cfg.ProcessingConcurrency, @@ -129,8 +130,10 @@ func NewBlockingResolver( } res := &BlockingResolver{ + configurable: withConfig(&cfg), + typed: withType("blocking"), + blockHandler: blockHandler, - cfg: cfg, blacklistMatcher: blacklistMatcher, whitelistMatcher: whitelistMatcher, whitelistOnlyGroups: whitelistOnlyGroups, @@ -140,10 +143,9 @@ func NewBlockingResolver( }, clientGroupsBlock: cgb, redisClient: redis, - redisEnabled: (redis != nil), } - if res.redisEnabled { + if res.redisClient != nil { setupRedisEnabledSubscriber(res) } @@ -156,27 +158,25 @@ func NewBlockingResolver( func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader { return lists.NewDownloader( - lists.WithTimeout(time.Duration(cfg.DownloadTimeout)), + lists.WithTimeout(cfg.DownloadTimeout.ToDuration()), lists.WithAttempts(cfg.DownloadAttempts), - lists.WithCooldown(time.Duration(cfg.DownloadCooldown)), + lists.WithCooldown(cfg.DownloadCooldown.ToDuration()), lists.WithTransport(bootstrap.NewHTTPTransport()), ) } func setupRedisEnabledSubscriber(c *BlockingResolver) { - logger := log.PrefixedLog("blocking_resolver") - go func() { for em := range c.redisClient.EnabledChannel { if em != nil { - logger.Debug("Received state from redis: ", em) + c.log().Debug("Received state from redis: ", em) if em.State { c.internalEnableBlocking() } else { err := c.internalDisableBlocking(em.Duration, em.Groups) if err != nil { - logger.Warn("Blocking couldn't be disabled:", err) + c.log().Warn("Blocking couldn't be disabled:", err) } } } @@ -213,7 +213,7 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string { func (r *BlockingResolver) EnableBlocking() { r.internalEnableBlocking() - if r.redisEnabled { + if r.redisClient != nil { r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true}) } } @@ -232,7 +232,7 @@ func (r *BlockingResolver) internalEnableBlocking() { // DisableBlocking deactivates the blocking for a particular duration (or forever if 0). func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error { err := r.internalDisableBlocking(duration, disableGroups) - if err == nil && r.redisEnabled { + if err == nil && r.redisClient != nil { r.redisClient.PublishEnabled(&redis.EnabledMessage{ State: false, Duration: duration, @@ -329,39 +329,15 @@ func (r *BlockingResolver) handleBlocked(logger *logrus.Entry, return &model.Response{Res: response, RType: model.ResponseTypeBLOCKED, Reason: reason}, nil } -// Configuration returns the current resolver configuration -func (r *BlockingResolver) Configuration() (result []string) { - if len(r.cfg.ClientGroupsBlock) == 0 { - return configDisabled - } +// LogConfig implements `config.Configurable`. +func (r *BlockingResolver) LogConfig(logger *logrus.Entry) { + r.cfg.LogConfig(logger) - result = append(result, "clientGroupsBlock") - for key, val := range r.cfg.ClientGroupsBlock { - result = append(result, fmt.Sprintf(" %s = \"%s\"", key, strings.Join(val, ";"))) - } + logger.Info("blacklist cache entries:") + log.WithIndent(logger, " ", r.blacklistMatcher.LogConfig) - blockType := r.cfg.BlockType - result = append(result, fmt.Sprintf("blockType = \"%s\"", blockType)) - - if blockType != "NXDOMAIN" { - result = append(result, fmt.Sprintf("blockTTL = %s", r.cfg.BlockTTL.String())) - } - - result = append(result, fmt.Sprintf("downloadTimeout = %s", r.cfg.DownloadTimeout.String())) - - result = append(result, fmt.Sprintf("FailStartOnListError = %t", r.cfg.FailStartOnListError)) - - result = append(result, "blacklist:") - for _, c := range r.blacklistMatcher.Configuration() { - result = append(result, fmt.Sprintf(" %s", c)) - } - - result = append(result, "whitelist:") - for _, c := range r.whitelistMatcher.Configuration() { - result = append(result, fmt.Sprintf(" %s", c)) - } - - return result + logger.Info("whitelist cache entries:") + log.WithIndent(logger, " ", r.whitelistMatcher.LogConfig) } func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool { @@ -594,7 +570,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) { } func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (result []net.IP, ttl time.Duration) { - prefixedLog := log.PrefixedLog("FQDNClientIdentifierCache") + prefixedLog := log.WithPrefix(r.log(), "client_id_cache") for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} { resp, err := r.next.Resolve(&model.Request{ diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index 5c2e68c6..b6203b09 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -7,6 +7,7 @@ import ( . "github.com/0xERR0R/blocky/evt" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/lists" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/util" @@ -51,7 +52,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer *dns.Msg ) - systemResolverBootstrap := &Bootstrap{} + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) BeforeEach(func() { sutConfig = config.BlockingConfig{ @@ -71,6 +76,22 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Events", func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ @@ -1086,35 +1107,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) }) - Describe("Configuration output", func() { - When("resolver is enabled", func() { - BeforeEach(func() { - sutConfig = config.BlockingConfig{ - BlockType: "ZEROIP", - BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{"gr1": {group1File.Path}}, - ClientGroupsBlock: map[string][]string{ - "default": {"gr1"}, - }, - } - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - - When("resolver is disabled", func() { - BeforeEach(func() { - sutConfig = config.BlockingConfig{} - }) - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - Describe("Create resolver with wrong parameter", func() { When("Wrong blockType is used", func() { It("should return error", func() { diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index f4160922..d8d95692 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -62,7 +62,11 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { return b, nil } - parallelResolver, err := newParallelBestResolver(bootstraped.ResolverGroups()) + // Bootstrap doesn't have a `LogConfig` method, and since that's the only place + // where `ParallelBestResolver` uses its config, we can just use an empty one. + pbCfg := config.ParallelBestConfig{} + + parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups()) if err != nil { return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err) } @@ -74,20 +78,16 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { cachingCfg := cfg.Caching cachingCfg.EnablePrefetch() - if cachingCfg.MinCachingTime == 0 { + if cachingCfg.MinCachingTime.IsZero() { // Set a min time in case the user didn't to avoid prefetching too often cachingCfg.MinCachingTime = config.Duration(time.Hour) } b.bootstraped = bootstraped - cachingResolver := NewCachingResolver(cachingCfg, nil) - // don't emit any metrics - cachingResolver.emitMetricEvents = false - b.resolver = Chain( NewFilteringResolver(cfg.Filtering), - cachingResolver, + newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones parallelResolver, ) @@ -116,10 +116,10 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) { ctx := context.Background() timeout := cfg.UpstreamTimeout - if timeout != 0 { + if timeout.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)) + ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration()) defer cancel() } diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index 9f9026a3..97907e3c 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -282,7 +282,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) - Expect(ips).Should(HaveLen(0)) + Expect(ips).Should(BeEmpty()) }) }) @@ -296,7 +296,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no such host")) - Expect(ips).Should(HaveLen(0)) + Expect(ips).Should(BeEmpty()) }) }) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index fac3c1c0..53ebb17d 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -5,8 +5,6 @@ import ( "sync/atomic" "time" - "github.com/hako/durafmt" - "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/evt" @@ -24,16 +22,15 @@ const defaultCachingCleanUpInterval = 5 * time.Second // CachingResolver caches answers from dns queries with their TTL time, // to avoid external resolver calls for recurrent queries type CachingResolver struct { + configurable[*config.CachingConfig] NextResolver - minCacheTimeSec, maxCacheTimeSec int - cacheTimeNegative time.Duration - resultCache expirationcache.ExpiringCache - prefetchExpires time.Duration - prefetchThreshold int - prefetchingNameCache expirationcache.ExpiringCache - redisClient *redis.Client - redisEnabled bool - emitMetricEvents bool + typed + + emitMetricEvents bool // disabled by Bootstrap + + resultCache expirationcache.ExpiringCache + prefetchingNameCache expirationcache.ExpiringCache + redisClient *redis.Client } // cacheValue includes query answer and prefetch flag @@ -44,18 +41,21 @@ type cacheValue struct { // NewCachingResolver creates a new resolver instance func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver { + return newCachingResolver(cfg, redis, true) +} + +func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver { c := &CachingResolver{ - minCacheTimeSec: int(time.Duration(cfg.MinCachingTime).Seconds()), - maxCacheTimeSec: int(time.Duration(cfg.MaxCachingTime).Seconds()), - cacheTimeNegative: time.Duration(cfg.CacheTimeNegative), - redisClient: redis, - redisEnabled: (redis != nil), - emitMetricEvents: true, + configurable: withConfig(&cfg), + typed: withType("caching"), + + redisClient: redis, + emitMetricEvents: emitMetricEvents, } configureCaches(c, &cfg) - if c.redisEnabled { + if c.redisClient != nil { setupRedisCacheSubscriber(c) c.redisClient.GetRedisCache() } @@ -68,10 +68,6 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount)) if cfg.Prefetching { - c.prefetchExpires = time.Duration(cfg.PrefetchExpires) - - c.prefetchThreshold = cfg.PrefetchThreshold - c.prefetchingNameCache = expirationcache.NewCache( expirationcache.WithCleanUpInterval(time.Minute), expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount)), @@ -88,12 +84,10 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { } func setupRedisCacheSubscriber(c *CachingResolver) { - logger := log.PrefixedLog("caching_resolver") - go func() { for rc := range c.redisClient.CacheChannel { if rc != nil { - logger.Debug("Received key from redis: ", rc.Key) + c.log().Debug("Received key from redis: ", rc.Key) c.putInCache(rc.Key, rc.Response, false, false) } } @@ -102,22 +96,22 @@ func setupRedisCacheSubscriber(c *CachingResolver) { // check if domain was queried > threshold in the time window func (r *CachingResolver) shouldPrefetch(cacheKey string) bool { - if r.prefetchThreshold == 0 { + if r.cfg.PrefetchThreshold == 0 { return true } cnt, _ := r.prefetchingNameCache.Get(cacheKey) - return cnt != nil && cnt.(int) > r.prefetchThreshold + return cnt != nil && cnt.(int) > r.cfg.PrefetchThreshold } func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.Duration) { qType, domainName := util.ExtractCacheKey(cacheKey) - logger := log.PrefixedLog("caching_resolver") - if r.shouldPrefetch(cacheKey) { - logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType.String()) + logger := r.log() + + logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger) response, err := r.next.Resolve(req) @@ -136,29 +130,11 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time. return nil, 0 } -// Configuration returns a current resolver configuration -func (r *CachingResolver) Configuration() (result []string) { - if r.maxCacheTimeSec < 0 { - return configDisabled - } +// LogConfig implements `config.Configurable`. +func (r *CachingResolver) LogConfig(logger *logrus.Entry) { + r.cfg.LogConfig(logger) - result = append(result, fmt.Sprintf("minCacheTimeInSec = %d", r.minCacheTimeSec)) - - result = append(result, fmt.Sprintf("maxCacheTimeSec = %d", r.maxCacheTimeSec)) - - result = append(result, fmt.Sprintf("cacheTimeNegative = %s", durafmt.Parse(r.cacheTimeNegative))) - - result = append(result, fmt.Sprintf("prefetching = %t", r.prefetchingNameCache != nil)) - - if r.prefetchingNameCache != nil { - result = append(result, fmt.Sprintf("prefetchExpires = %s", durafmt.Parse(r.prefetchExpires))) - - result = append(result, fmt.Sprintf("prefetchThreshold = %d", r.prefetchThreshold)) - } - - result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.TotalCount())) - - return + logger.Infof("cache entries = %d", r.resultCache.TotalCount()) } // Resolve checks if the current query result is already in the cache and returns it @@ -166,7 +142,7 @@ func (r *CachingResolver) Configuration() (result []string) { func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) { logger := log.WithPrefix(request.Log, "caching_resolver") - if r.maxCacheTimeSec < 0 { + if r.cfg.MaxCachingTime < 0 { logger.Debug("skip cache") return r.next.Resolve(request) @@ -214,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo response, err = r.next.Resolve(request) if err == nil { - r.putInCache(cacheKey, response, false, r.redisEnabled) + r.putInCache(cacheKey, response, false, true) } } @@ -228,7 +204,7 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log domainCount = x.(int) } domainCount++ - r.prefetchingNameCache.Put(cacheKey, domainCount, r.prefetchExpires) + r.prefetchingNameCache.Put(cacheKey, domainCount, r.cfg.PrefetchExpires.ToDuration()) totalCount := r.prefetchingNameCache.TotalCount() logger.Debugf("domain '%s' was requested %d times, "+ @@ -242,9 +218,9 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, // put value into cache r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer)) } else if response.Res.Rcode == dns.RcodeNameError { - if r.cacheTimeNegative > 0 { + if r.cfg.CacheTimeNegative > 0 { // put negative cache if result code is NXDOMAIN - r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cacheTimeNegative) + r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration()) } } @@ -264,20 +240,20 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) { var max uint32 if len(answer) == 0 { - return r.cacheTimeNegative + return r.cfg.CacheTimeNegative.ToDuration() } for _, a := range answer { // if TTL < mitTTL -> adjust the value, set minTTL - if r.minCacheTimeSec > 0 { - if atomic.LoadUint32(&a.Header().Ttl) < uint32(r.minCacheTimeSec) { - atomic.StoreUint32(&a.Header().Ttl, uint32(r.minCacheTimeSec)) + if r.cfg.MinCachingTime > 0 { + if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() { + atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32()) } } - if r.maxCacheTimeSec > 0 { - if atomic.LoadUint32(&a.Header().Ttl) > uint32(r.maxCacheTimeSec) { - atomic.StoreUint32(&a.Header().Ttl, uint32(r.maxCacheTimeSec)) + if r.cfg.MaxCachingTime > 0 { + if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() { + atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32()) } } diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index fdef8a3f..de735450 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -7,6 +7,7 @@ import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/evt" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/util" @@ -27,6 +28,12 @@ var _ = Describe("CachingResolver", func() { mockAnswer *dns.Msg ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { sutConfig = config.CachingConfig{} if err := defaults.Set(&sutConfig); err != nil { @@ -42,6 +49,22 @@ var _ = Describe("CachingResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Caching responses", func() { When("prefetching is enabled", func() { BeforeEach(func() { @@ -103,6 +126,15 @@ var _ = Describe("CachingResolver", func() { HaveTTL(BeNumerically("<=", 2)))) Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com"))) }) + When("threshold is 0", func() { + BeforeEach(func() { + sutConfig.PrefetchThreshold = 0 + }) + + It("should always prefetch", func() { + Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue()) + }) + }) }) When("min caching time is defined", func() { BeforeEach(func() { @@ -364,6 +396,10 @@ var _ = Describe("CachingResolver", func() { }) It("response should be cached", func() { + By("default config should enable negative caching", func() { + Expect(sutConfig.CacheTimeNegative).Should(BeNumerically(">", 0)) + }) + By("first request", func() { Expect(sut.Resolve(newRequest("example.com.", AAAA))). Should(SatisfyAll( @@ -495,43 +531,6 @@ var _ = Describe("CachingResolver", func() { }) }) - Describe("Configuration output", func() { - When("resolver is enabled", func() { - BeforeEach(func() { - sutConfig = config.CachingConfig{} - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - - When("resolver is disabled", func() { - BeforeEach(func() { - sutConfig = config.CachingConfig{ - MaxCachingTime: config.Duration(time.Minute * -1), - } - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - - When("prefetching is enabled", func() { - BeforeEach(func() { - sutConfig = config.CachingConfig{ - Prefetching: true, - } - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - Expect(c).Should(ContainElement(ContainSubstring("prefetchThreshold"))) - }) - }) - }) - Describe("Redis is configured", func() { var ( redisServer *miniredis.Miniredis diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index b5bcd2ae..177c2f81 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -1,7 +1,6 @@ package resolver import ( - "fmt" "net" "strings" "time" @@ -18,11 +17,12 @@ import ( // ClientNamesResolver tries to determine client name by asking responsible DNS server via rDNS (reverse lookup) type ClientNamesResolver struct { + configurable[*config.ClientLookupConfig] + NextResolver + typed + cache expirationcache.ExpiringCache externalResolver Resolver - singleNameOrder []uint - clientIPMapping map[string][]net.IP - NextResolver } // NewClientNamesResolver creates new resolver instance @@ -38,38 +38,21 @@ func NewClientNamesResolver( } cr = &ClientNamesResolver{ + configurable: withConfig(&cfg), + typed: withType("client_names"), + cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval(time.Hour)), externalResolver: r, - singleNameOrder: cfg.SingleNameOrder, - clientIPMapping: cfg.ClientnameIPMapping, } return } -// Configuration returns current resolver configuration -func (r *ClientNamesResolver) Configuration() (result []string) { - if r.externalResolver == nil && len(r.clientIPMapping) == 0 { - return append(configDisabled, "use only IP address") - } +// LogConfig implements `config.Configurable`. +func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) { + r.cfg.LogConfig(logger) - result = append(result, fmt.Sprintf("singleNameOrder = \"%v\"", r.singleNameOrder)) - - if r.externalResolver != nil { - result = append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver)) - } - - result = append(result, fmt.Sprintf("cache item count = %d", r.cache.TotalCount())) - - if len(r.clientIPMapping) > 0 { - result = append(result, "client IP mapping:") - - for k, v := range r.clientIPMapping { - result = append(result, fmt.Sprintf("%s -> %s", k, v)) - } - } - - return + logger.Infof("cache entries = %d", r.cache.TotalCount()) } // Resolve tries to resolve the client name from the ip address @@ -89,13 +72,11 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string { } ip := request.ClientIP - if ip == nil { return []string{} } c, _ := r.cache.Get(ip.String()) - if c != nil { if t, ok := c.([]string); ok { return t @@ -103,6 +84,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string { } names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver")) + r.cache.Put(ip.String(), names, time.Hour) return names @@ -127,7 +109,6 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) { // try client mapping first result = r.getNameFromIPMapping(ip, result) - if len(result) > 0 { return } @@ -151,8 +132,8 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip) // optional: if singleNameOrder is set, use only one name in the defined order - if len(r.singleNameOrder) > 0 { - for _, i := range r.singleNameOrder { + if len(r.cfg.SingleNameOrder) > 0 { + for _, i := range r.cfg.SingleNameOrder { if i > 0 && int(i) <= len(clientNames) { result = []string{clientNames[i-1]} @@ -169,7 +150,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry } func (r *ClientNamesResolver) getNameFromIPMapping(ip net.IP, result []string) []string { - for name, ips := range r.clientIPMapping { + for name, ips := range r.cfg.ClientnameIPMapping { for _, i := range ips { if ip.String() == i.String() { result = append(result, name) diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index 24ed63d0..eff18f44 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -5,6 +5,7 @@ import ( "net" "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/model" @@ -22,6 +23,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { m *mockResolver ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + JustBeforeEach(func() { res, err := NewClientNamesResolver(sutConfig, nil, false) Expect(err).Should(Succeed()) @@ -31,6 +38,22 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Resolve client name from request clientID", func() { BeforeEach(func() { sutConfig = config.ClientLookupConfig{} @@ -281,7 +304,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) }) When("Upstream produces error", func() { - BeforeEach(func() { + JustBeforeEach(func() { sutConfig = config.ClientLookupConfig{} clientMockResolver := &mockResolver{} clientMockResolver.On("Resolve", mock.Anything).Return(nil, errors.New("error")) @@ -348,32 +371,4 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) }) }) - - Describe("Configuration output", func() { - When("resolver is enabled", func() { - BeforeEach(func() { - sutConfig = config.ClientLookupConfig{ - Upstream: config.Upstream{Net: config.NetProtocolTcpUdp, Host: "host"}, - SingleNameOrder: []uint{1, 2}, - ClientnameIPMapping: map[string][]net.IP{ - "client8": {net.ParseIP("1.2.3.5")}, - }, - } - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - - When("resolver is disabled", func() { - BeforeEach(func() { - sutConfig = config.ClientLookupConfig{} - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - }) }) diff --git a/resolver/conditional_upstream_resolver.go b/resolver/conditional_upstream_resolver.go index bf1492aa..3173fc47 100644 --- a/resolver/conditional_upstream_resolver.go +++ b/resolver/conditional_upstream_resolver.go @@ -1,7 +1,6 @@ package resolver import ( - "fmt" "strings" "github.com/0xERR0R/blocky/config" @@ -15,7 +14,10 @@ import ( // ConditionalUpstreamResolver delegates DNS question to other DNS resolver dependent on domain name in question type ConditionalUpstreamResolver struct { + configurable[*config.ConditionalUpstreamConfig] NextResolver + typed + mapping map[string]Resolver } @@ -26,10 +28,13 @@ func NewConditionalUpstreamResolver( m := make(map[string]Resolver, len(cfg.Mapping.Upstreams)) for domain, upstream := range cfg.Mapping.Upstreams { - upstreams := make(map[string][]config.Upstream) - upstreams[upstreamDefaultCfgName] = upstream + pbCfg := config.ParallelBestConfig{ + ExternalResolvers: config.ParallelBestMapping{ + upstreamDefaultCfgName: upstream, + }, + } - r, err := NewParallelBestResolver(upstreams, bootstrap, shouldVerifyUpstreams) + r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams) if err != nil { return nil, err } @@ -37,20 +42,14 @@ func NewConditionalUpstreamResolver( m[strings.ToLower(domain)] = r } - return &ConditionalUpstreamResolver{mapping: m}, nil -} + r := ConditionalUpstreamResolver{ + configurable: withConfig(&cfg), + typed: withType("conditional_upstream"), -// Configuration returns current configuration -func (r *ConditionalUpstreamResolver) Configuration() (result []string) { - if len(r.mapping) == 0 { - return configDisabled + mapping: m, } - for key, val := range r.mapping { - result = append(result, fmt.Sprintf("%s = \"%s\"", key, val)) - } - - return + return &r, nil } func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) { diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 59ef5a88..d134d127 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -3,6 +3,7 @@ package resolver import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -18,6 +19,12 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu m *mockResolver ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") @@ -54,6 +61,22 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Resolve conditional DNS queries via defined DNS server", func() { When("Query is exact equal defined condition in mapping", func() { Context("first mapping entry", func() { @@ -149,22 +172,4 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu Expect(r).Should(BeNil()) }) }) - - Describe("Configuration output", func() { - When("resolver is enabled", func() { - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - When("resolver is disabled", func() { - BeforeEach(func() { - sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, nil, false) - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - }) }) diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index 333a137d..56ff2b50 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -1,10 +1,8 @@ package resolver import ( - "fmt" "net" "strings" - "time" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" @@ -17,11 +15,12 @@ import ( // CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map type CustomDNSResolver struct { + configurable[*config.CustomDNSConfig] NextResolver - mapping map[string][]net.IP - reverseAddresses map[string][]string - ttl uint32 - filterUnmappedTypes bool + typed + + mapping map[string][]net.IP + reverseAddresses map[string][]string } // NewCustomDNSResolver creates new resolver instance @@ -38,27 +37,13 @@ func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver { } } - ttl := uint32(time.Duration(cfg.CustomTTL).Seconds()) - return &CustomDNSResolver{ - mapping: m, - reverseAddresses: reverse, - ttl: ttl, - filterUnmappedTypes: cfg.FilterUnmappedTypes, - } -} + configurable: withConfig(&cfg), + typed: withType("custom_dns"), -// Configuration returns current resolver configuration -func (r *CustomDNSResolver) Configuration() (result []string) { - if len(r.mapping) == 0 { - return configDisabled + mapping: m, + reverseAddresses: reverse, } - - for key, val := range r.mapping { - result = append(result, fmt.Sprintf("%s = \"%s\"", key, val)) - } - - return } func isSupportedType(ip net.IP, question dns.Question) bool { @@ -75,7 +60,7 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp response.SetReply(request.Req) for _, url := range urls { - h := util.CreateHeader(question, r.ttl) + h := util.CreateHeader(question, r.cfg.CustomTTL.SecondsU32()) ptr := new(dns.PTR) ptr.Ptr = dns.Fqdn(url) ptr.Hdr = h @@ -103,7 +88,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon if found { for _, ip := range ips { if isSupportedType(ip, question) { - rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl) + rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32()) response.Answer = append(response.Answer, rr) } } @@ -118,7 +103,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon } // Mapping exists for this domain, but for another type - if !r.filterUnmappedTypes { + if !r.cfg.FilterUnmappedTypes { // go to next resolver break } diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index 3996ba40..854f342f 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -6,6 +6,7 @@ import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" @@ -15,12 +16,18 @@ import ( var _ = Describe("CustomDNSResolver", func() { var ( + TTL = uint32(time.Now().Second()) + sut ChainedResolver m *mockResolver cfg config.CustomDNSConfig ) - TTL := uint32(time.Now().Second()) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) BeforeEach(func() { cfg = config.CustomDNSConfig{ @@ -45,6 +52,22 @@ var _ = Describe("CustomDNSResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Resolving custom name via CustomDNSResolver", func() { When("Ip 4 mapping is defined for custom domain and", func() { Context("filterUnmappedTypes is true", func() { @@ -257,23 +280,4 @@ var _ = Describe("CustomDNSResolver", func() { }) }) }) - - Describe("Configuration output", func() { - When("resolver is enabled", func() { - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - - When("resolver is disabled", func() { - BeforeEach(func() { - cfg = config.CustomDNSConfig{} - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - }) }) diff --git a/resolver/ede_resolver.go b/resolver/ede_resolver.go index a5ee4797..e924a892 100644 --- a/resolver/ede_resolver.go +++ b/resolver/ede_resolver.go @@ -7,77 +7,47 @@ import ( ) type EdeResolver struct { + configurable[*config.EdeConfig] NextResolver - config config.EdeConfig + typed } func NewEdeResolver(cfg config.EdeConfig) ChainedResolver { return &EdeResolver{ - config: cfg, + configurable: withConfig(&cfg), + typed: withType("extended_error_code"), } } func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) { + if !r.cfg.Enable { + return r.next.Resolve(request) + } + resp, err := r.next.Resolve(request) if err != nil { return nil, err } - if r.config.Enable { - addExtraReasoning(resp) - } + r.addExtraReasoning(resp) return resp, nil } -func (r *EdeResolver) Configuration() (result []string) { - if !r.config.Enable { - return configDisabled +func (r *EdeResolver) addExtraReasoning(res *model.Response) { + infocode := res.RType.ToExtendedErrorCode() + + if infocode == dns.ExtendedErrorCodeOther { + // dns.ExtendedErrorCodeOther seams broken in some clients + return } - return configEnabled -} - -func addExtraReasoning(res *model.Response) { - // dns.ExtendedErrorCodeOther seams broken in some clients - infocode := convertToExtendedErrorCode(res.RType) - if infocode > 0 { - opt := new(dns.OPT) - opt.Hdr.Name = "." - opt.Hdr.Rrtype = dns.TypeOPT - opt.Option = append(opt.Option, convertExtendedError(res, infocode)) - res.Res.Extra = append(res.Res.Extra, opt) - } -} - -func convertExtendedError(input *model.Response, infocode uint16) *dns.EDNS0_EDE { - return &dns.EDNS0_EDE{ + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + opt.Option = append(opt.Option, &dns.EDNS0_EDE{ InfoCode: infocode, - ExtraText: input.Reason, - } -} - -func convertToExtendedErrorCode(input model.ResponseType) uint16 { - switch input { - case model.ResponseTypeRESOLVED: - return dns.ExtendedErrorCodeOther - case model.ResponseTypeCACHED: - return dns.ExtendedErrorCodeCachedError - case model.ResponseTypeCONDITIONAL: - return dns.ExtendedErrorCodeForgedAnswer - case model.ResponseTypeCUSTOMDNS: - return dns.ExtendedErrorCodeForgedAnswer - case model.ResponseTypeHOSTSFILE: - return dns.ExtendedErrorCodeForgedAnswer - case model.ResponseTypeNOTFQDN: - return dns.ExtendedErrorCodeBlocked - case model.ResponseTypeBLOCKED: - return dns.ExtendedErrorCodeBlocked - case model.ResponseTypeFILTERED: - return dns.ExtendedErrorCodeFiltered - case model.ResponseTypeSPECIAL: - return dns.ExtendedErrorCodeFiltered - default: - return dns.ExtendedErrorCodeOther - } + ExtraText: res.Reason, + }) + res.Res.Extra = append(res.Res.Extra, opt) } diff --git a/resolver/ede_resolver_test.go b/resolver/ede_resolver_test.go index 35209d45..5038db79 100644 --- a/resolver/ede_resolver_test.go +++ b/resolver/ede_resolver_test.go @@ -5,6 +5,7 @@ import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" @@ -22,6 +23,12 @@ var _ = Describe("EdeResolver", func() { mockAnswer *dns.Msg ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mockAnswer = new(dns.Msg) }) @@ -59,6 +66,12 @@ var _ = Describe("EdeResolver", func() { // delegated to next resolver Expect(m.Calls).Should(HaveLen(1)) }) + + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) }) When("ede is enabled", func() { @@ -106,26 +119,14 @@ var _ = Describe("EdeResolver", func() { Expect(err).To(Equal(resolveErr)) }) }) - }) - Describe("Configuration output", func() { - When("resolver is enabled", func() { - BeforeEach(func() { - sutConfig = config.EdeConfig{Enable: true} - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(c).Should(Equal(configEnabled)) - }) - }) + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() - When("resolver is disabled", func() { - BeforeEach(func() { - sutConfig = config.EdeConfig{Enable: false} - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) }) }) }) diff --git a/resolver/filtering_resolver.go b/resolver/filtering_resolver.go index 57ebf4d2..738c970f 100644 --- a/resolver/filtering_resolver.go +++ b/resolver/filtering_resolver.go @@ -1,10 +1,6 @@ package resolver import ( - "fmt" - "sort" - "strings" - "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/model" "github.com/miekg/dns" @@ -13,13 +9,21 @@ import ( // FilteringResolver filters DNS queries (for example can drop all AAAA query) // returns empty ANSWER with NOERROR type FilteringResolver struct { + configurable[*config.FilteringConfig] NextResolver - queryTypes config.QTypeSet + typed +} + +func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver { + return &FilteringResolver{ + configurable: withConfig(&cfg), + typed: withType("filtering"), + } } func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) { qType := request.Req.Question[0].Qtype - if r.queryTypes.Contains(dns.Type(qType)) { + if r.cfg.QueryTypes.Contains(dns.Type(qType)) { response := new(dns.Msg) response.SetRcode(request.Req, dns.RcodeSuccess) @@ -28,27 +32,3 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er return r.next.Resolve(request) } - -func (r *FilteringResolver) Configuration() (result []string) { - if len(r.queryTypes) == 0 { - return configDisabled - } - - qTypes := make([]string, 0, len(r.queryTypes)) - - for qType := range r.queryTypes { - qTypes = append(qTypes, qType.String()) - } - - sort.Strings(qTypes) - - result = append(result, fmt.Sprintf("filtering query Types: '%v'", strings.Join(qTypes, ", "))) - - return -} - -func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver { - return &FilteringResolver{ - queryTypes: cfg.QueryTypes, - } -} diff --git a/resolver/filtering_resolver_test.go b/resolver/filtering_resolver_test.go index 26f5701e..2c0a02b1 100644 --- a/resolver/filtering_resolver_test.go +++ b/resolver/filtering_resolver_test.go @@ -3,7 +3,9 @@ package resolver import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" + "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -18,6 +20,12 @@ var _ = Describe("FilteringResolver", func() { mockAnswer *dns.Msg ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mockAnswer = new(dns.Msg) }) @@ -29,6 +37,22 @@ var _ = Describe("FilteringResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + When("Filtering query types are defined", func() { BeforeEach(func() { sutConfig = config.FilteringConfig{ @@ -59,10 +83,6 @@ var _ = Describe("FilteringResolver", func() { // no call of next resolver Expect(m.Calls).Should(BeZero()) }) - It("Configure should output all query types", func() { - c := sut.Configuration() - Expect(c).Should(Equal([]string{"filtering query Types: 'AAAA, MX'"})) - }) }) When("No filtering query types are defined", func() { @@ -81,9 +101,5 @@ var _ = Describe("FilteringResolver", func() { // delegated to next resolver Expect(m.Calls).Should(HaveLen(1)) }) - It("Configure should output 'empty list'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) }) }) diff --git a/resolver/fqdn_only_resolver.go b/resolver/fqdn_only_resolver.go index ce0c8dfe..25093f4b 100644 --- a/resolver/fqdn_only_resolver.go +++ b/resolver/fqdn_only_resolver.go @@ -10,18 +10,20 @@ import ( ) type FqdnOnlyResolver struct { + configurable[*config.FqdnOnlyConfig] NextResolver - enabled bool + typed } -func NewFqdnOnlyResolver(cfg config.Config) *FqdnOnlyResolver { +func NewFqdnOnlyResolver(cfg config.FqdnOnlyConfig) *FqdnOnlyResolver { return &FqdnOnlyResolver{ - enabled: cfg.FqdnOnly, + configurable: withConfig(&cfg), + typed: withType("fqdn_only"), } } func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, error) { - if r.enabled { + if r.IsEnabled() { domainFromQuestion := util.ExtractDomain(request.Req.Question[0]) if !strings.Contains(domainFromQuestion, ".") { response := new(dns.Msg) @@ -33,11 +35,3 @@ func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, err return r.next.Resolve(request) } - -func (r *FqdnOnlyResolver) Configuration() (result []string) { - if !r.enabled { - return configDisabled - } - - return configEnabled -} diff --git a/resolver/fqdn_only_resolver_test.go b/resolver/fqdn_only_resolver_test.go index 77aee0fb..7838cc34 100644 --- a/resolver/fqdn_only_resolver_test.go +++ b/resolver/fqdn_only_resolver_test.go @@ -3,6 +3,7 @@ package resolver import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" @@ -13,11 +14,17 @@ import ( var _ = Describe("FqdnOnlyResolver", func() { var ( sut *FqdnOnlyResolver - sutConfig config.Config + sutConfig config.FqdnOnlyConfig m *mockResolver mockAnswer *dns.Msg ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mockAnswer = new(dns.Msg) }) @@ -29,11 +36,25 @@ var _ = Describe("FqdnOnlyResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + When("Fqdn only is enabled", func() { BeforeEach(func() { - sutConfig = config.Config{ - FqdnOnly: true, - } + sutConfig = config.FqdnOnlyConfig{Enable: true} }) It("Should delegate to next resolver if request query is fqdn", func() { Expect(sut.Resolve(newRequest("example.com", A))). @@ -59,17 +80,27 @@ var _ = Describe("FqdnOnlyResolver", func() { // no call of next resolver Expect(m.Calls).Should(BeZero()) }) - It("Configure should output enabled", func() { - c := sut.Configuration() - Expect(c).Should(Equal(configEnabled)) + + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) }) }) When("Fqdn only is disabled", func() { BeforeEach(func() { - sutConfig = config.Config{ - FqdnOnly: false, - } + sutConfig = config.FqdnOnlyConfig{Enable: false} }) It("Should delegate to next resolver if request query is fqdn", func() { Expect(sut.Resolve(newRequest("example.com", A))). @@ -95,9 +126,11 @@ var _ = Describe("FqdnOnlyResolver", func() { // delegated to next resolver Expect(m.Calls).Should(HaveLen(1)) }) - It("Configure should output disabled", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) + + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) }) }) }) diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index 49994bbc..f3518e41 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -9,7 +9,6 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/lists/parsers" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -17,23 +16,44 @@ import ( ) const ( - hostsFileResolverLogger = "hosts_file_resolver" - // reduce initial capacity so we don't waste memory if there are less entries than before memReleaseFactor = 2 ) type HostsFileResolver struct { + configurable[*config.HostsFileConfig] NextResolver - HostsFilePath string - hosts splitHostsFileData - ttl uint32 - refreshPeriod time.Duration - filterLoopback bool + typed + + hosts splitHostsFileData } type HostsFileEntry = parsers.HostsFileEntry +func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver { + r := HostsFileResolver{ + configurable: withConfig(&cfg), + typed: withType("hosts_file"), + } + + if err := r.parseHostsFile(context.Background()); err != nil { + r.log().Errorf("disabling hosts file resolving due to error: %s", err) + + r.cfg.Filepath = "" // don't try parsing the file again + } else { + go r.periodicUpdate() + } + + return &r +} + +// LogConfig implements `config.Configurable`. +func (r *HostsFileResolver) LogConfig(logger *logrus.Entry) { + r.cfg.LogConfig(logger) + + logger.Infof("cache entries = %d", r.hosts.len()) +} + func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Response { question := request.Req.Question[0] if question.Qtype != dns.TypePTR { @@ -46,7 +66,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp return nil } - if r.filterLoopback && questionIP.IsLoopback() { + if r.cfg.FilterLoopback && questionIP.IsLoopback() { // skip the search: we won't find anything return nil } @@ -64,13 +84,13 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp ptr := new(dns.PTR) ptr.Ptr = dns.Fqdn(host) - ptr.Hdr = util.CreateHeader(question, r.ttl) + ptr.Hdr = util.CreateHeader(question, r.cfg.HostsTTL.SecondsU32()) response.Answer = append(response.Answer, ptr) for _, alias := range hostData.Aliases { ptrAlias := new(dns.PTR) ptrAlias.Ptr = dns.Fqdn(alias) - ptrAlias.Hdr = util.CreateHeader(question, r.ttl) + ptrAlias.Hdr = ptr.Hdr response.Answer = append(response.Answer, ptrAlias) } @@ -82,9 +102,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp } func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, hostsFileResolverLogger) - - if r.HostsFilePath == "" { + if r.cfg.Filepath == "" { return r.next.Resolve(request) } @@ -98,7 +116,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er response := r.resolve(request.Req, question, domain) if response != nil { - logger.WithFields(logrus.Fields{ + r.log().WithFields(logrus.Fields{ "answer": util.AnswerToString(response.Answer), "domain": domain, }).Debugf("returning hosts file entry") @@ -106,7 +124,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil } - logger.WithField("resolver", Name(r.next)).Trace("go to next resolver") + r.log().WithField("resolver", Name(r.next)).Trace("go to next resolver") return r.next.Resolve(request) } @@ -117,7 +135,7 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain return nil } - rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl) + rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.HostsTTL.SecondsU32()) response := new(dns.Msg) response.SetReply(req) @@ -126,47 +144,14 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain return response } -func (r *HostsFileResolver) Configuration() (result []string) { - if r.HostsFilePath == "" || r.hosts.isEmpty() { - return configDisabled - } - - result = append(result, fmt.Sprintf("file path: %s", r.HostsFilePath)) - result = append(result, fmt.Sprintf("TTL: %d", r.ttl)) - result = append(result, fmt.Sprintf("refresh period: %s", r.refreshPeriod.String())) - result = append(result, fmt.Sprintf("filter loopback addresses: %t", r.filterLoopback)) - - return -} - -func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver { - r := HostsFileResolver{ - HostsFilePath: cfg.Filepath, - ttl: uint32(time.Duration(cfg.HostsTTL).Seconds()), - refreshPeriod: time.Duration(cfg.RefreshPeriod), - filterLoopback: cfg.FilterLoopback, - } - - if err := r.parseHostsFile(context.Background()); err != nil { - logger := log.PrefixedLog(hostsFileResolverLogger) - logger.Warnf("hosts file resolving is disabled: %s", err) - - r.HostsFilePath = "" // don't try parsing the file again - } else { - go r.periodicUpdate() - } - - return &r -} - func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { const maxErrorsPerFile = 5 - if r.HostsFilePath == "" { + if r.cfg.Filepath == "" { return nil } - f, err := os.Open(r.HostsFilePath) + f, err := os.Open(r.cfg.Filepath) if err != nil { return err } @@ -176,7 +161,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile) p.OnErr(func(err error) { - log.PrefixedLog(hostsFileResolverLogger).Warnf("error parsing %s: %s, trying to continue", r.HostsFilePath, err) + r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err) }) err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error { @@ -187,7 +172,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { } // Ignore loopback, if so configured - if r.filterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") { + if r.cfg.FilterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") { return nil } @@ -196,7 +181,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { return nil }) if err != nil { - return fmt.Errorf("error parsing %s: %w", r.HostsFilePath, err) // err is parsers.ErrTooManyErrors + return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors } r.hosts = newHosts @@ -205,15 +190,14 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { } func (r *HostsFileResolver) periodicUpdate() { - if r.refreshPeriod > 0 { - ticker := time.NewTicker(r.refreshPeriod) + if r.cfg.RefreshPeriod.ToDuration() > 0 { + ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration()) defer ticker.Stop() for { <-ticker.C - logger := log.PrefixedLog(hostsFileResolverLogger) - logger.WithField("file", r.HostsFilePath).Debug("refreshing hosts file") + r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file") util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background())) } @@ -238,7 +222,11 @@ func newSplitHostsDataWithSameCapacity(other splitHostsFileData) splitHostsFileD } func (d splitHostsFileData) isEmpty() bool { - return d.v4.isEmpty() && d.v6.isEmpty() + return d.len() == 0 +} + +func (d splitHostsFileData) len() int { + return d.v4.len() + d.v6.len() } func (d splitHostsFileData) getIP(qType dns.Type, domain string) net.IP { @@ -277,8 +265,8 @@ func newHostsDataWithSameCapacity(other hostsFileData) hostsFileData { } } -func (d hostsFileData) isEmpty() bool { - return len(d.hosts) == 0 && len(d.aliases) == 0 +func (d hostsFileData) len() int { + return len(d.hosts) + len(d.aliases) } func (d hostsFileData) getIP(hostname string) net.IP { diff --git a/resolver/hosts_file_resolver_test.go b/resolver/hosts_file_resolver_test.go index 1d108fc6..2851ccdd 100644 --- a/resolver/hosts_file_resolver_test.go +++ b/resolver/hosts_file_resolver_test.go @@ -2,12 +2,11 @@ package resolver import ( "context" - "fmt" - "math/rand" "time" "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" @@ -17,6 +16,8 @@ import ( var _ = Describe("HostsFileResolver", func() { var ( + TTL = uint32(time.Now().Second()) + sut *HostsFileResolver sutConfig config.HostsFileConfig m *mockResolver @@ -24,7 +25,11 @@ var _ = Describe("HostsFileResolver", func() { tmpFile *TmpFile ) - TTL := uint32(time.Now().Second()) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) BeforeEach(func() { tmpDir = NewTmpFolder("HostsFileResolver") @@ -49,16 +54,32 @@ var _ = Describe("HostsFileResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Using hosts file", func() { When("Hosts file cannot be located", func() { BeforeEach(func() { sutConfig = config.HostsFileConfig{ - Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()), + Filepath: "/this/file/does/not/exist", HostsTTL: config.Duration(time.Duration(TTL) * time.Second), } }) It("should not parse any hosts", func() { - Expect(sut.HostsFilePath).Should(BeEmpty()) + Expect(sut.cfg.Filepath).Should(BeEmpty()) Expect(sut.hosts.v4.hosts).Should(BeEmpty()) Expect(sut.hosts.v6.hosts).Should(BeEmpty()) Expect(sut.hosts.v4.aliases).Should(BeEmpty()) @@ -140,11 +161,11 @@ var _ = Describe("HostsFileResolver", func() { It("should not be used", func() { Expect(sut).ShouldNot(BeNil()) - Expect(sut.HostsFilePath).Should(BeEmpty()) - Expect(sut.hosts.v4.hosts).Should(HaveLen(0)) - Expect(sut.hosts.v6.hosts).Should(HaveLen(0)) - Expect(sut.hosts.v4.aliases).Should(HaveLen(0)) - Expect(sut.hosts.v6.aliases).Should(HaveLen(0)) + Expect(sut.cfg.Filepath).Should(BeEmpty()) + Expect(sut.hosts.v4.hosts).Should(BeEmpty()) + Expect(sut.hosts.v6.hosts).Should(BeEmpty()) + Expect(sut.hosts.v4.aliases).Should(BeEmpty()) + Expect(sut.hosts.v6.aliases).Should(BeEmpty()) }) }) @@ -307,25 +328,6 @@ var _ = Describe("HostsFileResolver", func() { }) }) - Describe("Configuration output", func() { - When("hosts file is provided", func() { - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - - When("hosts file is not provided", func() { - BeforeEach(func() { - sutConfig = config.HostsFileConfig{} - }) - It("should return 'disabled'", func() { - c := sut.Configuration() - Expect(c).Should(ContainElement(configStatusDisabled)) - }) - }) - }) - Describe("Delegating to next resolver", func() { When("no hosts file is provided", func() { It("should delegate to next resolver", func() { diff --git a/resolver/metrics_resolver.go b/resolver/metrics_resolver.go index af916854..0c15c00f 100644 --- a/resolver/metrics_resolver.go +++ b/resolver/metrics_resolver.go @@ -1,7 +1,6 @@ package resolver import ( - "fmt" "strings" "time" @@ -15,8 +14,10 @@ import ( // MetricsResolver resolver that records metrics about requests/response type MetricsResolver struct { + configurable[*config.MetricsConfig] NextResolver - cfg config.PrometheusConfig + typed + totalQueries *prometheus.CounterVec totalResponse *prometheus.CounterVec totalErrors prometheus.Counter @@ -24,11 +25,11 @@ type MetricsResolver struct { } // Resolve resolves the passed request -func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) { - response, err := m.next.Resolve(request) +func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) { + response, err := r.next.Resolve(request) - if m.cfg.Enable { - m.totalQueries.With(prometheus.Labels{ + if r.cfg.Enable { + r.totalQueries.With(prometheus.Labels{ "client": strings.Join(request.ClientNames, ","), "type": dns.TypeToString[request.Req.Question[0].Qtype], }).Inc() @@ -40,12 +41,12 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro responseType = response.RType.String() } - m.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs) + r.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs) if err != nil { - m.totalErrors.Inc() + r.totalErrors.Inc() } else { - m.totalResponse.With(prometheus.Labels{ + r.totalResponse.With(prometheus.Labels{ "reason": response.Reason, "response_code": dns.RcodeToString[response.Res.Rcode], "response_type": response.RType.String(), @@ -56,38 +57,28 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro return response, err } -// Configuration gets the config of this resolver in a string slice -func (m *MetricsResolver) Configuration() (result []string) { - if !m.cfg.Enable { - return configDisabled +// NewMetricsResolver creates a new intance of the MetricsResolver type +func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver { + m := MetricsResolver{ + configurable: withConfig(&cfg), + typed: withType("metrics"), + + durationHistogram: durationHistogram(), + totalQueries: totalQueriesMetric(), + totalResponse: totalResponseMetric(), + totalErrors: totalErrorMetric(), } - result = append(result, "metrics:") - result = append(result, fmt.Sprintf(" Enable = %t", m.cfg.Enable)) - result = append(result, fmt.Sprintf(" Path = %s", m.cfg.Path)) + m.registerMetrics() - return + return &m } -// NewMetricsResolver creates a new intance of the MetricsResolver type -func NewMetricsResolver(cfg config.PrometheusConfig) ChainedResolver { - durationHistogram := durationHistogram() - totalQueries := totalQueriesMetric() - totalResponse := totalResponseMetric() - totalErrors := totalErrorMetric() - - metrics.RegisterMetric(durationHistogram) - metrics.RegisterMetric(totalQueries) - metrics.RegisterMetric(totalResponse) - metrics.RegisterMetric(totalErrors) - - return &MetricsResolver{ - cfg: cfg, - durationHistogram: durationHistogram, - totalQueries: totalQueries, - totalResponse: totalResponse, - totalErrors: totalErrors, - } +func (r *MetricsResolver) registerMetrics() { + metrics.RegisterMetric(r.durationHistogram) + metrics.RegisterMetric(r.totalQueries) + metrics.RegisterMetric(r.totalResponse) + metrics.RegisterMetric(r.totalErrors) } func totalQueriesMetric() *prometheus.CounterVec { diff --git a/resolver/metrics_resolver_test.go b/resolver/metrics_resolver_test.go index a5a184dc..eca8ab69 100644 --- a/resolver/metrics_resolver_test.go +++ b/resolver/metrics_resolver_test.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/model" @@ -22,13 +23,35 @@ var _ = Describe("MetricResolver", func() { m *mockResolver ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { - sut = NewMetricsResolver(config.PrometheusConfig{Enable: true}).(*MetricsResolver) + sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Recording prometheus metrics", func() { Context("Recording request metrics", func() { When("Request will be performed", func() { @@ -63,13 +86,4 @@ var _ = Describe("MetricResolver", func() { }) }) }) - - Describe("Configuration output", func() { - When("resolver is enabled", func() { - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - }) }) diff --git a/resolver/mocks_test.go b/resolver/mocks_test.go index 09557a01..567e0600 100644 --- a/resolver/mocks_test.go +++ b/resolver/mocks_test.go @@ -11,6 +11,7 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/util" + "github.com/sirupsen/logrus" "github.com/0xERR0R/blocky/model" @@ -27,10 +28,21 @@ type mockResolver struct { AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error) } -func (r *mockResolver) Configuration() []string { +// Type implements `Resolver`. +func (r *mockResolver) Type() string { + return "mock" +} + +// IsEnabled implements `config.Configurable`. +func (r *mockResolver) IsEnabled() bool { args := r.Called() - return args.Get(0).([]string) + return args.Get(0).(bool) +} + +// LogConfig implements `config.Configurable`. +func (r *mockResolver) LogConfig(logger *logrus.Entry) { + r.Called() } func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) { diff --git a/resolver/noop_resolver.go b/resolver/noop_resolver.go index 89335ae3..9b3f6885 100644 --- a/resolver/noop_resolver.go +++ b/resolver/noop_resolver.go @@ -2,6 +2,7 @@ package resolver import ( "github.com/0xERR0R/blocky/model" + "github.com/sirupsen/logrus" ) var NoResponse = &model.Response{} //nolint:gochecknoglobals @@ -13,10 +14,20 @@ func NewNoOpResolver() Resolver { return NoOpResolver{} } -func (r NoOpResolver) Configuration() (result []string) { - return nil +// Type implements `Resolver`. +func (NoOpResolver) Type() string { + return "noop" } -func (r NoOpResolver) Resolve(request *model.Request) (*model.Response, error) { +// IsEnabled implements `config.Configurable`. +func (NoOpResolver) IsEnabled() bool { + return true +} + +// LogConfig implements `config.Configurable`. +func (NoOpResolver) LogConfig(*logrus.Entry) { +} + +func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) { return NoResponse, nil } diff --git a/resolver/noop_resolver_test.go b/resolver/noop_resolver_test.go index 31cd9cbd..6493fef1 100644 --- a/resolver/noop_resolver_test.go +++ b/resolver/noop_resolver_test.go @@ -2,6 +2,7 @@ package resolver import ( . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -9,6 +10,12 @@ import ( var _ = Describe("NoOpResolver", func() { var sut NoOpResolver + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { sut = NewNoOpResolver().(NoOpResolver) }) @@ -21,10 +28,19 @@ var _ = Describe("NoOpResolver", func() { }) }) - Describe("Configuration output", func() { - It("returns nothing", func() { - c := sut.Configuration() - Expect(c).Should(BeNil()) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should not log anything", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).Should(BeEmpty()) }) }) }) diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index 60fc0df9..4cb8bcd8 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -20,13 +20,16 @@ import ( ) const ( - upstreamDefaultCfgName = "default" - parallelResolverLogger = "parallel_best_resolver" + upstreamDefaultCfgName = config.UpstreamDefaultCfgName + parallelResolverType = "parallel_best" resolverCount = 2 ) // ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer type ParallelBestResolver struct { + configurable[*config.ParallelBestConfig] + typed + resolversPerClient map[string][]*upstreamResolverStatus } @@ -77,10 +80,11 @@ func testResolver(r *UpstreamResolver) error { // NewParallelBestResolver creates new resolver instance func NewParallelBestResolver( - upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool, -) (Resolver, error) { - logger := log.PrefixedLog(parallelResolverLogger) + cfg config.ParallelBestConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, +) (*ParallelBestResolver, error) { + logger := log.PrefixedLog(parallelResolverType) + upstreamResolvers := cfg.ExternalResolvers resolverGroups := make(map[string][]Resolver, len(upstreamResolvers)) for name, upstreamCfgs := range upstreamResolvers { @@ -114,10 +118,12 @@ func NewParallelBestResolver( resolverGroups[name] = group } - return newParallelBestResolver(resolverGroups) + return newParallelBestResolver(cfg, resolverGroups) } -func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, error) { +func newParallelBestResolver( + cfg config.ParallelBestConfig, resolverGroups map[string][]Resolver, +) (*ParallelBestResolver, error) { resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups)) for groupName, resolvers := range resolverGroups { @@ -136,27 +142,21 @@ func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, er } r := ParallelBestResolver{ + configurable: withConfig(&cfg), + typed: withType(parallelResolverType), + resolversPerClient: resolversPerClient, } return &r, nil } -// Configuration returns current resolver configuration -func (r *ParallelBestResolver) Configuration() (result []string) { - result = append(result, "upstream resolvers:") - for name, res := range r.resolversPerClient { - result = append(result, fmt.Sprintf("- %s", name)) - for _, r := range res { - result = append(result, fmt.Sprintf(" - %s", r.resolver)) - } - } - - return +func (r *ParallelBestResolver) Name() string { + return r.String() } -func (r ParallelBestResolver) String() string { - result := make([]string, 0) +func (r *ParallelBestResolver) String() string { + result := make([]string, 0, len(r.resolversPerClient)) for name, res := range r.resolversPerClient { tmp := make([]string, len(res)) @@ -206,7 +206,7 @@ func (r *ParallelBestResolver) resolversForClient(request *model.Request) (resul // Resolve sends the query request to multiple upstream resolvers and returns the fastest result func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, parallelResolverLogger) + logger := log.WithPrefix(request.Log, parallelResolverType) resolvers := r.resolversForClient(request) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 47494699..1a80368f 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -6,6 +6,7 @@ import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -19,13 +20,74 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { noVerifyUpstreams = false ) - systemResolverBootstrap := &Bootstrap{} + var ( + sut *ParallelBestResolver + sutMapping config.ParallelBestMapping + sutVerify bool + + err error + + bootstrap *Bootstrap + ) + + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + + BeforeEach(func() { + sutMapping = config.ParallelBestMapping{ + upstreamDefaultCfgName: { + config.Upstream{ + Host: "wrong", + }, + config.Upstream{ + Host: "127.0.0.2", + }, + }, + } + + sutVerify = noVerifyUpstreams + + bootstrap = systemResolverBootstrap + }) + + JustBeforeEach(func() { + sutConfig := config.ParallelBestConfig{ExternalResolvers: sutMapping} + + sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) + }) config.GetConfig().UpstreamTimeout = config.Duration(1000 * time.Millisecond) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + + Describe("Name", func() { + It("should not be empty", func() { + Expect(sut.Name()).ShouldNot(BeEmpty()) + }) + }) + When("default upstream resolvers are not defined", func() { It("should fail on startup", func() { - _, err := NewParallelBestResolver(map[string][]config.Upstream{}, nil, noVerifyUpstreams) + _, err := NewParallelBestResolver(config.ParallelBestConfig{ + ExternalResolvers: config.ParallelBestMapping{}, + }, nil, noVerifyUpstreams) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured")) }) @@ -40,7 +102,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) defer mockUpstream.Close() - upstream := map[string][]config.Upstream{ + upstream := config.ParallelBestMapping{ upstreamDefaultCfgName: { config.Upstream{ Host: "wrong", @@ -49,21 +111,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }, } - _, err := NewParallelBestResolver(upstream, systemResolverBootstrap, verifyUpstreams) + _, err := NewParallelBestResolver(config.ParallelBestConfig{ + ExternalResolvers: upstream, + }, systemResolverBootstrap, verifyUpstreams) Expect(err).Should(Not(HaveOccurred())) }) }) When("no upstream resolvers can be reached", func() { - var ( - upstream map[string][]config.Upstream - b *Bootstrap - ) - BeforeEach(func() { - b = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - upstream = map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: { config.Upstream{ Host: "wrong", @@ -75,22 +134,26 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { } }) - It("should fail to start if strict checking is enabled", func() { - _, err := NewParallelBestResolver(upstream, b, verifyUpstreams) - Expect(err).Should(HaveOccurred()) + When("strict checking is enabled", func() { + BeforeEach(func() { + sutVerify = verifyUpstreams + }) + It("should fail to start", func() { + Expect(err).Should(HaveOccurred()) + }) }) - It("should start if strict checking is disabled", func() { - _, err := NewParallelBestResolver(upstream, b, noVerifyUpstreams) - Expect(err).Should(Not(HaveOccurred())) + When("strict checking is disabled", func() { + BeforeEach(func() { + sutVerify = noVerifyUpstreams + }) + It("should start", func() { + Expect(err).Should(Not(HaveOccurred())) + }) }) }) Describe("Resolving result from fastest upstream resolver", func() { - var ( - sut Resolver - err error - ) When("2 Upstream resolvers are defined", func() { When("one resolver is fast and another is slow", func() { BeforeEach(func() { @@ -107,10 +170,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) DeferCleanup(slowTestUpstream.Close) - sut, err = NewParallelBestResolver(map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()}, - }, nil, noVerifyUpstreams) - Expect(err).Should(Succeed()) + } }) It("Should use result from fastest one", func() { request := newRequest("example.com.", A) @@ -136,9 +198,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { return response }) DeferCleanup(slowTestUpstream.Close) - sut, err = NewParallelBestResolver(map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: {withErrorUpstream, slowTestUpstream.Start()}, - }, systemResolverBootstrap, noVerifyUpstreams) + } Expect(err).Should(Succeed()) }) It("Should use result from successful resolver", func() { @@ -158,9 +220,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { withError1 := config.Upstream{Host: "wrong"} withError2 := config.Upstream{Host: "wrong"} - sut, err = NewParallelBestResolver(map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: {withError1, withError2}, - }, systemResolverBootstrap, noVerifyUpstreams) + } Expect(err).Should(Succeed()) }) It("Should return error", func() { @@ -194,14 +256,14 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { WithAnswerRR("example.com 123 IN A 123.124.122.126") DeferCleanup(clientSpecificCIRDMockUpstream.Close) - sut, _ = NewParallelBestResolver(map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: {defaultMockUpstream.Start()}, "laptop": {clientSpecificExactMockUpstream.Start()}, "client-*-m": {clientSpecificWildcardMockUpstream.Start()}, "client[0-9]": {clientSpecificWildcardMockUpstream.Start()}, "192.168.178.33": {clientSpecificIPMockUpstream.Start()}, "10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()}, - }, nil, noVerifyUpstreams) + } }) It("Should use default if client name or IP don't match", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") @@ -294,11 +356,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream.Close) - sut, _ = NewParallelBestResolver(map[string][]config.Upstream{ + sutMapping = config.ParallelBestMapping{ upstreamDefaultCfgName: { mockUpstream.Start(), }, - }, nil, noVerifyUpstreams) + } }) It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) @@ -327,10 +389,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - tmp, _ := NewParallelBestResolver(map[string][]config.Upstream{ + sut, _ := NewParallelBestResolver(config.ParallelBestConfig{ExternalResolvers: config.ParallelBestMapping{ upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, - }, systemResolverBootstrap, noVerifyUpstreams) - sut := tmp.(*ParallelBestResolver) + }}, systemResolverBootstrap, noVerifyUpstreams) By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) @@ -391,26 +452,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b, verifyUpstreams) + r, err := NewParallelBestResolver(config.ParallelBestConfig{ + ExternalResolvers: config.ParallelBestMapping{"test": {{Host: "example.com"}}}, + }, b, verifyUpstreams) Expect(err).ShouldNot(Succeed()) Expect(r).Should(BeNil()) }) }) - - Describe("Configuration output", func() { - var sut Resolver - BeforeEach(func() { - config.GetConfig().StartVerifyUpstream = false - - sut, _ = NewParallelBestResolver(map[string][]config.Upstream{upstreamDefaultCfgName: { - {Host: "host1"}, - {Host: "host2"}, - }}, nil, noVerifyUpstreams) - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) }) diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go index a06d5694..1b80a6f8 100644 --- a/resolver/query_logging_resolver.go +++ b/resolver/query_logging_resolver.go @@ -1,7 +1,6 @@ package resolver import ( - "fmt" "time" "github.com/0xERR0R/blocky/config" @@ -14,35 +13,32 @@ import ( ) const ( - cleanUpRunPeriod = 12 * time.Hour - queryLoggingResolverPrefix = "query_logging_resolver" - logChanCap = 1000 - defaultFlushPeriod = 30 * time.Second + cleanUpRunPeriod = 12 * time.Hour + queryLoggingResolverType = "query_logging" + logChanCap = 1000 + defaultFlushPeriod = 30 * time.Second ) // QueryLoggingResolver writes query information (question, answer, duration, ...) type QueryLoggingResolver struct { + configurable[*config.QueryLogConfig] NextResolver - target string - logRetentionDays uint64 - logChan chan *querylog.LogEntry - writer querylog.Writer - logType config.QueryLogType - fields []config.QueryLogField + typed + + logChan chan *querylog.LogEntry + writer querylog.Writer } // NewQueryLoggingResolver returns a new resolver instance func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { - logger := log.PrefixedLog(queryLoggingResolverPrefix) + logger := log.PrefixedLog(queryLoggingResolverType) var writer querylog.Writer - logType := cfg.Type - err := retry.Do( func() error { var err error - switch logType { + switch cfg.Type { case config.QueryLogTypeCsv: writer, err = querylog.NewCSVWriter(cfg.Target, false, cfg.LogRetentionDays) case config.QueryLogTypeCsvClient: @@ -61,7 +57,7 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { }, retry.Attempts(uint(cfg.CreationAttempts)), retry.DelayType(retry.FixedDelay), - retry.Delay(time.Duration(cfg.CreationCooldown)), + retry.Delay(cfg.CreationCooldown.ToDuration()), retry.OnRetry(func(n uint, err error) { logger.Warnf( "Error occurred on query writer creation, retry attempt %d/%d: %v", n+1, cfg.CreationAttempts, err, @@ -71,18 +67,17 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { logger.Error("can't create query log writer, using console as fallback: ", err) writer = querylog.NewLoggerWriter() - logType = config.QueryLogTypeConsole + cfg.Type = config.QueryLogTypeConsole } logChan := make(chan *querylog.LogEntry, logChanCap) resolver := QueryLoggingResolver{ - target: cfg.Target, - logRetentionDays: cfg.LogRetentionDays, - logChan: logChan, - writer: writer, - logType: logType, - fields: resolveQueryLogFields(cfg), + configurable: withConfig(&cfg), + typed: withType(queryLoggingResolverType), + + logChan: logChan, + writer: writer, } go resolver.writeLog() @@ -94,24 +89,6 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { return &resolver } -func resolveQueryLogFields(cfg config.QueryLogConfig) []config.QueryLogField { - var fields []config.QueryLogField - - if len(cfg.Fields) == 0 { - // no fields defined, use all fields as fallback - for _, v := range config.QueryLogFieldNames() { - qlt, err := config.ParseQueryLogField(v) - util.LogOnError("ignoring unknown query log field", err) - - fields = append(fields, qlt) - } - } else { - fields = cfg.Fields - } - - return fields -} - // triggers periodically cleanup of old log files func (r *QueryLoggingResolver) periodicCleanUp() { ticker := time.NewTicker(cleanUpRunPeriod) @@ -129,7 +106,7 @@ func (r *QueryLoggingResolver) doCleanUp() { // Resolve logs the query, duration and the result func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, queryLoggingResolverPrefix) + logger := log.WithPrefix(request.Log, queryLoggingResolverType) start := time.Now() @@ -157,7 +134,7 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response * ClientNames: []string{"none"}, } - for _, f := range r.fields { + for _, f := range r.cfg.Fields { switch f { case config.QueryLogFieldClientIP: entry.ClientIP = request.ClientIP.String() @@ -196,18 +173,8 @@ func (r *QueryLoggingResolver) writeLog() { // if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.) if len(r.logChan) > halfCap { - log.PrefixedLog(queryLoggingResolverPrefix).WithField("channel_len", + r.log().WithField("channel_len", len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) } } } - -// Configuration returns the current resolver configuration -func (r *QueryLoggingResolver) Configuration() (result []string) { - result = append(result, fmt.Sprintf("type: \"%s\"", r.logType)) - result = append(result, fmt.Sprintf("target: \"%s\"", r.target)) - result = append(result, fmt.Sprintf("logRetentionDays: %d", r.logRetentionDays)) - result = append(result, fmt.Sprintf("fields: %s", r.fields)) - - return -} diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index 4d3cd0d2..1a4d02ff 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -10,6 +10,7 @@ import ( "time" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/querylog" "github.com/0xERR0R/blocky/config" @@ -44,6 +45,12 @@ var _ = Describe("QueryLoggingResolver", func() { mockAnswer *dns.Msg ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mockAnswer = new(dns.Msg) tmpDir = NewTmpFolder("queryLoggingResolver") @@ -52,6 +59,10 @@ var _ = Describe("QueryLoggingResolver", func() { }) JustBeforeEach(func() { + if len(sutConfig.Fields) == 0 { + sutConfig.SetDefaults() // not called when using a struct literal + } + sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver) DeferCleanup(func() { close(sut.logChan) }) m = &mockResolver{} @@ -59,6 +70,22 @@ var _ = Describe("QueryLoggingResolver", func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Process request", func() { When("Resolver has no configuration", func() { BeforeEach(func() { @@ -276,24 +303,6 @@ var _ = Describe("QueryLoggingResolver", func() { }) }) - Describe("Configuration output", func() { - When("resolver is enabled", func() { - BeforeEach(func() { - sutConfig = config.QueryLogConfig{ - Target: tmpDir.Path, - Type: config.QueryLogTypeCsvClient, - LogRetentionDays: 0, - CreationAttempts: 1, - CreationCooldown: config.Duration(time.Millisecond), - } - }) - It("should return configuration", func() { - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", 1)) - }) - }) - }) - Describe("Clean up of query log directory", func() { When("fallback logger is enabled, log retention is enabled", func() { BeforeEach(func() { @@ -355,7 +364,7 @@ var _ = Describe("QueryLoggingResolver", func() { } }) It("should use fallback", func() { - Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole)) + Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole)) }) }) @@ -369,7 +378,7 @@ var _ = Describe("QueryLoggingResolver", func() { } }) It("should use fallback", func() { - Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole)) + Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole)) }) }) }) diff --git a/resolver/resolver.go b/resolver/resolver.go index 786e472f..e1e9c83e 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -1,11 +1,10 @@ package resolver import ( - "fmt" "net" - "strings" "time" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -14,20 +13,6 @@ import ( "github.com/sirupsen/logrus" ) -// Resolver is not configured. -const ( - configStatusEnabled string = "enabled" - - configStatusDisabled string = "disabled" -) - -var ( - // note: this is not used by all resolvers: only those that don't print any other configuration - configEnabled = []string{configStatusEnabled} //nolint:gochecknoglobals - - configDisabled = []string{configStatusDisabled} //nolint:gochecknoglobals -) - func newRequest(question string, rType dns.Type, logger ...*logrus.Entry) *model.Request { var loggerEntry *logrus.Entry if len(logger) == 1 { @@ -84,11 +69,15 @@ func newRequestWithClientID(question string, rType dns.Type, ip, requestClientID // Resolver generic interface for all resolvers type Resolver interface { + config.Configurable + + // Type returns a short, user-friendly, name for the resolver. + // + // It should be the same for all instances of a specific Resolver type. + Type() string + // Resolve performs resolution of a DNS request Resolve(req *model.Request) (*model.Response, error) - - // Configuration returns current resolver configuration - Configuration() []string } // ChainedResolver represents a resolver, which can delegate result to the next one @@ -142,10 +131,73 @@ func Name(resolver Resolver) string { return named.Name() } - return defaultName(resolver) + return resolver.Type() } -// defaultName returns a short user-friendly name of a resolver -func defaultName(resolver Resolver) string { - return strings.Split(fmt.Sprintf("%T", resolver), ".")[1] +// ForEach iterates over all resolvers in the chain. +// +// If resolver is not a chain, or is unlinked, +// the callback is called exactly once. +func ForEach(resolver Resolver, callback func(Resolver)) { + for resolver != nil { + callback(resolver) + + if chained, ok := resolver.(ChainedResolver); ok { + resolver = chained.GetNext() + } else { + break + } + } +} + +// LogResolverConfig logs the resolver's type and config. +func LogResolverConfig(res Resolver, logger *logrus.Entry) { + // Use the type, not the full typeName, to avoid redundant information with the config + typeName := res.Type() + + if !res.IsEnabled() { + logger.Debugf("-> %s: disabled", typeName) + + return + } + + logger.Infof("-> %s:", typeName) + log.WithIndent(logger, " ", res.LogConfig) +} + +// Should be embedded in a Resolver to auto-implement `Resolver.Type`. +type typed struct { + typeName string +} + +func withType(t string) typed { + return typed{typeName: t} +} + +// Type implements `Resolver`. +func (t *typed) Type() string { + return t.typeName +} + +func (t *typed) log() *logrus.Entry { + return log.PrefixedLog(t.Type()) +} + +// Should be embedded in a Resolver to auto-implement `config.Configurable`. +type configurable[T config.Configurable] struct { + cfg T +} + +func withConfig[T config.Configurable](cfg T) configurable[T] { + return configurable[T]{cfg: cfg} +} + +// IsEnabled implements `config.Configurable`. +func (c *configurable[T]) IsEnabled() bool { + return c.cfg.IsEnabled() +} + +// LogConfig implements `config.Configurable`. +func (c *configurable[T]) LogConfig(logger *logrus.Entry) { + c.cfg.LogConfig(logger) } diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 1a0f85f8..003e18da 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -1,45 +1,127 @@ package resolver import ( + "strings" + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" + "github.com/sirupsen/logrus" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var systemResolverBootstrap = &Bootstrap{} + var _ = Describe("Resolver", func() { - systemResolverBootstrap := &Bootstrap{} + Describe("Chains", func() { + var ( + r1 ChainedResolver + r2 ChainedResolver + r3 ChainedResolver + r4 Resolver + ) - Describe("Creating resolver chain", func() { - When("A chain of resolvers will be created", func() { - It("should be iterable by calling 'GetNext'", func() { - br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) - cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, nil, false) - ch := Chain(br, cr) - c, ok := ch.(ChainedResolver) - Expect(ok).Should(BeTrue()) + BeforeEach(func() { + r1 = &mockResolver{} + r2 = &mockResolver{} + r3 = &mockResolver{} + r4 = &NoOpResolver{} + }) - next := c.GetNext() - Expect(next).ShouldNot(BeNil()) + Describe("Chain", func() { + It("should create a chain iterable using `GetNext`", func() { + ch := Chain(r1, r2, r3, r4) + Expect(ch).ShouldNot(BeNil()) + Expect(ch).Should(Equal(r1)) + Expect(r1.GetNext()).Should(Equal(r2)) + Expect(r2.GetNext()).Should(Equal(r3)) + Expect(r3.GetNext()).Should(Equal(r4)) + }) + + It("should not link a final ChainedResolver", func() { + ch := Chain(r1, r2) + Expect(ch).ShouldNot(BeNil()) + + Expect(r1.GetNext()).Should(Equal(r2)) + Expect(r2.GetNext()).Should(BeNil()) }) }) + + Describe("ForEach", func() { + It("should iterate on all resolvers in the chain", func() { + ch := Chain(r1, r2, r3, r4) + Expect(ch).ShouldNot(BeNil()) + + var itResult []Resolver + + ForEach(ch, func(r Resolver) { + itResult = append(itResult, r) + }) + + Expect(itResult).ShouldNot(BeEmpty()) + Expect(itResult).Should(Equal([]Resolver{r1, r2, r3, r4})) + }) + }) + + Describe("LogResolverConfig", func() { + It("should call the resolver's `LogConfig`", func() { + logger := logrus.NewEntry(log.Log()) + + m := &mockResolver{} + m.On("IsEnabled").Return(true) + m.On("LogConfig") + + LogResolverConfig(m, logger) + + m.AssertExpectations(GinkgoT()) + }) + + When("the resolver is disabled", func() { + It("should not call the resolver's `LogConfig`", func() { + logger := logrus.NewEntry(log.Log()) + + m := &mockResolver{} + m.On("IsEnabled").Return(false) + + LogResolverConfig(m, logger) + + m.AssertExpectations(GinkgoT()) + }) + }) + }) + }) + + Describe("Name", func() { When("'Name' is called", func() { It("should return resolver name", func() { br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) name := Name(br) - Expect(name).Should(Equal("BlockingResolver")) + Expect(name).Should(Equal("blocking")) }) }) When("'Name' is called on a NamedResolver", func() { - It("should return it's custom name", func() { + It("should return its custom name", func() { br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) - cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}} + cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}} r := NewRewriterResolver(cfg, br) name := Name(r) - Expect(name).Should(Equal("BlockingResolver w/ RewriterResolver")) + Expect(name).Should(Equal("blocking w/ rewrite")) }) }) }) }) + +func expectValidResolverType(sut Resolver) { + By("it must not contain spaces", func() { + Expect(sut.Type()).ShouldNot(ContainSubstring(" ")) + }) + By("it must be lower case", func() { + Expect(sut.Type()).Should(Equal(strings.ToLower(sut.Type()))) + }) + By("it must not contain 'resolver'", func() { + Expect(sut.Type()).ShouldNot(ContainSubstring("resolver")) + }) +} diff --git a/resolver/rewriter_resolver.go b/resolver/rewriter_resolver.go index d8cabfd2..8024b546 100644 --- a/resolver/rewriter_resolver.go +++ b/resolver/rewriter_resolver.go @@ -18,13 +18,14 @@ import ( // The branch is where the rewrite is active. If the branch doesn't // yield a result, the normal resolving is continued. type RewriterResolver struct { + configurable[*config.RewriterConfig] NextResolver - rewrite map[string]string - inner Resolver - fallbackUpstream bool + typed + + inner Resolver } -func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) ChainedResolver { +func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) ChainedResolver { if len(cfg.Rewrite) == 0 { return inner } @@ -36,27 +37,22 @@ func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) Chaine inner.Next(NewNoOpResolver()) return &RewriterResolver{ - rewrite: cfg.Rewrite, - inner: inner, - fallbackUpstream: cfg.FallbackUpstream, + configurable: withConfig(&cfg), + typed: withType("rewrite"), + + inner: inner, } } func (r *RewriterResolver) Name() string { - return fmt.Sprintf("%s w/ %s", Name(r.inner), defaultName(r)) + return fmt.Sprintf("%s w/ %s", Name(r.inner), r.Type()) } -// Configuration returns current resolver configuration -func (r *RewriterResolver) Configuration() (result []string) { - result = append(result, "rewrite:") - for key, val := range r.rewrite { - result = append(result, fmt.Sprintf(" %s = \"%s\"", key, val)) - } +// LogConfig implements `config.Configurable`. +func (r *RewriterResolver) LogConfig(logger *logrus.Entry) { + LogResolverConfig(r.inner, logger) - innerCfg := r.inner.Configuration() - result = append(result, innerCfg...) - - return result + r.cfg.LogConfig(logger) } // Resolve uses the inner resolver to resolve the rewritten query @@ -79,7 +75,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err request.Req = original fallbackCondition := err != nil || (response != NoResponse && response.Res.Answer == nil) - if r.fallbackUpstream && fallbackCondition { + if r.cfg.FallbackUpstream && fallbackCondition { // Inner resolver had no answer, configuration requests fallback, continue with the normal chain logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver") @@ -130,7 +126,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg logger.WithFields(logrus.Fields{ "domain": domainOriginal, - "rewrite": rewriteKey + ":" + r.rewrite[rewriteKey], + "rewrite": rewriteKey + ":" + r.cfg.Rewrite[rewriteKey], }).Debugf("rewriting %q to %q", domainOriginal, domainRewritten) } } @@ -139,7 +135,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg } func (r *RewriterResolver) rewriteDomain(domain string) (string, string) { - for k, v := range r.rewrite { + for k, v := range r.cfg.Rewrite { if strings.HasSuffix(domain, "."+k) { newDomain := strings.TrimSuffix(domain, "."+k) + "." + v diff --git a/resolver/rewriter_resolver_test.go b/resolver/rewriter_resolver_test.go index 488aab1e..a74d0f31 100644 --- a/resolver/rewriter_resolver_test.go +++ b/resolver/rewriter_resolver_test.go @@ -2,8 +2,10 @@ package resolver import ( "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" + "github.com/sirupsen/logrus" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" @@ -19,7 +21,7 @@ const ( var _ = Describe("RewriterResolver", func() { var ( sut ChainedResolver - sutConfig config.RewriteConfig + sutConfig config.RewriterConfig mInner *mockResolver mNext *mockResolver @@ -29,11 +31,17 @@ var _ = Describe("RewriterResolver", func() { mNextResponse *model.Response ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mInner = &mockResolver{} mNext = &mockResolver{} - sutConfig = config.RewriteConfig{Rewrite: map[string]string{"original": "rewritten"}} + sutConfig = config.RewriterConfig{Rewrite: map[string]string{"original": "rewritten"}} }) JustBeforeEach(func() { @@ -48,7 +56,7 @@ var _ = Describe("RewriterResolver", func() { When("has no configuration", func() { BeforeEach(func() { - sutConfig = config.RewriteConfig{} + sutConfig = config.RewriterConfig{} }) It("should return the inner resolver", func() { @@ -194,11 +202,10 @@ var _ = Describe("RewriterResolver", func() { Describe("Configuration output", func() { When("resolver is enabled", func() { It("should return configuration", func() { - innerOutput := []string{"inner:", "config-output"} - mInner.On("Configuration").Return(innerOutput) + mInner.On("LogConfig") + mInner.On("IsEnabled").Return(true) - c := sut.Configuration() - Expect(len(c)).Should(BeNumerically(">", len(innerOutput))) + sut.LogConfig(logrus.NewEntry(log.Log())) }) }) }) diff --git a/resolver/sudn_resolver.go b/resolver/sudn_resolver.go index 928c40e4..772f81a1 100644 --- a/resolver/sudn_resolver.go +++ b/resolver/sudn_resolver.go @@ -7,6 +7,7 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) const ( @@ -46,11 +47,15 @@ type defaultIPs struct { type SpecialUseDomainNamesResolver struct { NextResolver + typed + defaults *defaultIPs } func NewSpecialUseDomainNamesResolver() ChainedResolver { return &SpecialUseDomainNamesResolver{ + typed: withType("special_use_domains"), + defaults: &defaultIPs{ loopbackV4: net.ParseIP("127.0.0.1"), loopbackV6: net.IPv6loopback, @@ -58,6 +63,17 @@ func NewSpecialUseDomainNamesResolver() ChainedResolver { } } +// IsEnabled implements `config.Configurable`. +func (r *SpecialUseDomainNamesResolver) IsEnabled() bool { + // RFC 6761 & 6762 are always active + return true +} + +// LogConfig implements `config.Configurable`. +func (r *SpecialUseDomainNamesResolver) LogConfig(logger *logrus.Entry) { + logger.Info("enabled") +} + func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) { // RFC 6761 - negative if r.isSpecial(request, sudnArpaSlice()...) || @@ -78,11 +94,6 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model. return r.next.Resolve(request) } -// RFC 6761 & 6762 are always active -func (r *SpecialUseDomainNamesResolver) Configuration() []string { - return configEnabled -} - func (r *SpecialUseDomainNamesResolver) isSpecial(request *model.Request, names ...string) bool { domainFromQuestion := request.Req.Question[0].Name for _, n := range names { diff --git a/resolver/sudn_resolver_test.go b/resolver/sudn_resolver_test.go index f8e20163..ff5af9bf 100644 --- a/resolver/sudn_resolver_test.go +++ b/resolver/sudn_resolver_test.go @@ -2,6 +2,7 @@ package resolver import ( . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -16,6 +17,12 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { m *mockResolver ) + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + BeforeEach(func() { mockAnswer, err := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") Expect(err).Should(Succeed()) @@ -27,6 +34,22 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { sut.Next(m) }) + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should not log anything", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) + Describe("Blocking special names", func() { It("should block arpa", func() { for _, arpa := range sudnArpaSlice() { @@ -133,12 +156,4 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { )) }) }) - - Describe("Configuration pseudo test", func() { - It("should always be empty", func() { - c := sut.Configuration() - - Expect(len(c)).Should(BeNumerically(">=", 1)) - }) - }) }) diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index a6f2af38..e42926d4 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -30,6 +30,8 @@ const ( // UpstreamResolver sends request to external DNS server type UpstreamResolver struct { + typed + upstream config.Upstream upstreamClient upstreamClient bootstrap *Bootstrap @@ -51,7 +53,7 @@ type httpUpstreamClient struct { } func createUpstreamClient(cfg config.Upstream) upstreamClient { - timeout := time.Duration(config.GetConfig().UpstreamTimeout) + timeout := config.GetConfig().UpstreamTimeout.ToDuration() tlsConfig := tls.Config{ ServerName: cfg.Host, @@ -209,25 +211,30 @@ func newUpstreamResolverUnchecked(upstream config.Upstream, bootstrap *Bootstrap upstreamClient := createUpstreamClient(upstream) return &UpstreamResolver{ + typed: withType("upstream"), + upstream: upstream, upstreamClient: upstreamClient, bootstrap: bootstrap, } } -// Configuration return current resolver configuration -func (r *UpstreamResolver) Configuration() (result []string) { - return []string{r.String()} +// IsEnabled implements `config.Configurable`. +func (r *UpstreamResolver) IsEnabled() bool { + return true +} + +// LogConfig implements `config.Configurable`. +func (r *UpstreamResolver) LogConfig(logger *logrus.Entry) { + logger.Info(r.upstream) } func (r UpstreamResolver) String() string { - return fmt.Sprintf("upstream '%s'", r.upstream.String()) + return fmt.Sprintf("%s '%s'", r.Type(), r.upstream) } // Resolve calls external resolver func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) { - logger := log.WithPrefix(request.Log, "upstream_resolver") - ips, err := r.bootstrap.UpstreamIPs(r) if err != nil { return nil, err @@ -247,7 +254,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp var err error resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol) if err == nil { - logger.WithFields(logrus.Fields{ + r.log().WithFields(logrus.Fields{ "answer": util.AnswerToString(resp.Answer), "return_code": dns.RcodeToString[resp.Rcode], "upstream": r.upstream.String(), @@ -272,7 +279,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp return errors.As(err, &netErr) && netErr.Timeout() }), retry.OnRetry(func(n uint, err error) { - logger.WithFields(logrus.Fields{ + r.log().WithFields(logrus.Fields{ "upstream": r.upstream.String(), "upstream_ip": ip.String(), "question": util.QuestionToString(request.Req.Question), diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go index 748e4ccc..8f5a2606 100644 --- a/resolver/upstream_resolver_test.go +++ b/resolver/upstream_resolver_test.go @@ -9,6 +9,7 @@ import ( "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" + "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -17,7 +18,40 @@ import ( ) var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { - systemResolverBootstrap := &Bootstrap{} + var ( + sut *UpstreamResolver + sutConfig config.Upstream + ) + + BeforeEach(func() { + sutConfig = config.Upstream{Host: "localhost"} + }) + + JustBeforeEach(func() { + sut = newUpstreamResolverUnchecked(sutConfig, systemResolverBootstrap) + }) + + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + Describe("LogConfig", func() { + It("should log something", func() { + logger, hook := log.NewMockEntry() + + sut.LogConfig(logger) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + }) + }) Describe("Using DNS upstream", func() { When("Configured DNS resolver can resolve query", func() { @@ -35,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), HaveTTL(BeNumerically("==", 123)), - HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))), + HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))), ) }) }) @@ -53,7 +87,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { HaveNoAnswer(), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), - HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))), + HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))), ) }) }) @@ -216,15 +250,4 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) }) }) - Describe("Configuration", func() { - When("Configuration is called", func() { - It("should return configuration", func() { - sut := newUpstreamResolverUnchecked(config.Upstream{}, nil) - - c := sut.Configuration() - - Expect(len(c)).Should(BeNumerically(">=", 1)) - }) - }) - }) }) diff --git a/server/server.go b/server/server.go index 59e70870..8586bb6a 100644 --- a/server/server.go +++ b/server/server.go @@ -395,7 +395,7 @@ func createQueryResolver( redisClient *redis.Client, ) (r resolver.Resolver, err error) { blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap) - parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap, cfg.StartVerifyUpstream) + parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream) clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream) @@ -411,16 +411,16 @@ func createQueryResolver( r = resolver.Chain( resolver.NewFilteringResolver(cfg.Filtering), - resolver.NewFqdnOnlyResolver(*cfg), + resolver.NewFqdnOnlyResolver(cfg.FqdnOnly), clientNames, resolver.NewEdeResolver(cfg.Ede), resolver.NewQueryLoggingResolver(cfg.QueryLog), resolver.NewMetricsResolver(cfg.Prometheus), - resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)), + resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)), resolver.NewHostsFileResolver(cfg.HostsFile), blocking, resolver.NewCachingResolver(cfg.Caching, redisClient), - resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, condUpstream), + resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream), resolver.NewSpecialUseDomainNamesResolver(), parallel, ) @@ -439,25 +439,12 @@ func (s *Server) registerDNSHandlers() { func (s *Server) printConfiguration() { logger().Info("current configuration:") - res := s.queryResolver - for res != nil { - logger().Infof("-> resolver: '%s'", resolver.Name(res)) + resolver.ForEach(s.queryResolver, func(res resolver.Resolver) { + resolver.LogResolverConfig(res, logger()) + }) - for _, c := range res.Configuration() { - logger().Infof(" %s", c) - } - - if c, ok := res.(resolver.ChainedResolver); ok { - res = c.GetNext() - } else { - break - } - } - - logger().Infof("- DNS listening on addrs/ports: %v", s.cfg.Ports.DNS) - logger().Infof("- TLS listening on addrs/ports: %v", s.cfg.Ports.TLS) - logger().Infof("- HTTP listening on addrs/ports: %v", s.cfg.Ports.HTTP) - logger().Infof("- HTTPS listening on addrs/ports: %v", s.cfg.Ports.HTTPS) + logger().Info("listeners:") + log.WithIndent(logger(), " ", s.cfg.Ports.LogConfig) logger().Info("runtime information:") @@ -465,17 +452,19 @@ func (s *Server) printConfiguration() { runtime.GC() debug.FreeOSMemory() + logger().Infof(" numCPU = %d", runtime.NumCPU()) + logger().Infof(" numGoroutine = %d", runtime.NumGoroutine()) + // gather memory stats var m runtime.MemStats runtime.ReadMemStats(&m) - logger().Infof("MEM Alloc = %10v MB", toMB(m.Alloc)) - logger().Infof("MEM HeapAlloc = %10v MB", toMB(m.HeapAlloc)) - logger().Infof("MEM Sys = %10v MB", toMB(m.Sys)) - logger().Infof("MEM NumGC = %10v", m.NumGC) - logger().Infof("RUN NumCPU = %10d", runtime.NumCPU()) - logger().Infof("RUN NumGoroutine = %10d", runtime.NumGoroutine()) + logger().Infof(" memory:") + logger().Infof(" alloc = %10v MB", toMB(m.Alloc)) + logger().Infof(" heapAlloc = %10v MB", toMB(m.HeapAlloc)) + logger().Infof(" sys = %10v MB", toMB(m.Sys)) + logger().Infof(" numGC = %10v", m.NumGC) } func toMB(b uint64) uint64 { diff --git a/server/server_test.go b/server/server_test.go index 4b8f9e96..72bc8418 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -105,7 +105,7 @@ var _ = BeforeSuite(func() { // create server sut, err = NewServer(&config.Config{ CustomDNS: config.CustomDNSConfig{ - CustomTTL: config.Duration(time.Duration(3600) * time.Second), + CustomTTL: config.Duration(3600 * time.Second), Mapping: config.CustomDNSMapping{ HostIPs: map[string][]net.IP{ "custom.lan": {net.ParseIP("192.168.178.55")}, @@ -143,7 +143,7 @@ var _ = BeforeSuite(func() { BlockType: "zeroIp", BlockTTL: config.Duration(6 * time.Hour), }, - Upstream: config.UpstreamConfig{ + Upstream: config.ParallelBestConfig{ ExternalResolvers: map[string][]config.Upstream{"default": {upstreamGoogle}}, }, ClientLookup: config.ClientLookupConfig{ @@ -158,7 +158,7 @@ var _ = BeforeSuite(func() { }, CertFile: certPem.Path, KeyFile: keyPem.Path, - Prometheus: config.PrometheusConfig{ + Prometheus: config.MetricsConfig{ Enable: true, Path: "/metrics", }, @@ -643,7 +643,7 @@ var _ = Describe("Running DNS server", func() { It("start was called 2 times, start should fail", func() { // create server server, err := NewServer(&config.Config{ - Upstream: config.UpstreamConfig{ + Upstream: config.ParallelBestConfig{ ExternalResolvers: map[string][]config.Upstream{ "default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}}, }, @@ -685,7 +685,7 @@ var _ = Describe("Running DNS server", func() { It("stop was called 2 times, start should fail", func() { // create server server, err := NewServer(&config.Config{ - Upstream: config.UpstreamConfig{ + Upstream: config.ParallelBestConfig{ ExternalResolvers: map[string][]config.Upstream{ "default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}}, }, diff --git a/util/common.go b/util/common.go index b794bd9a..b95ae05b 100644 --- a/util/common.go +++ b/util/common.go @@ -109,7 +109,7 @@ func NewMsgWithQuestion(question string, qType dns.Type) *dns.Msg { // NewMsgWithAnswer creates new DNS message with answer func NewMsgWithAnswer(domain string, ttl uint, dnsType dns.Type, address string) (*dns.Msg, error) { - rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType.String(), address)) + rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType, address)) if err != nil { return nil, err }