mirror of https://github.com/0xERR0R/blocky.git
refactor: configuration rework (usage and printing) (#920)
* refactor: make `config.Duration` a struct with `time.Duration` embed Allows directly calling `time.Duration` methods. * refactor(HostsFileResolver): don't copy individual config items The idea is to make adding configuration options easier, and searching for references straight forward. * refactor: move config printing to struct and use a logger Using a logger allows using multiple levels so the whole configuration can be printed in trace/verbose mode, but only important parts are shown by default. * squash: rename `Cast` to `ToDuration` * squash: revert `Duration` to a simple wrapper ("new type" pattern) * squash: `Duration.IsZero` tests * squash: refactor resolvers to rely on their config directly if possible * squash: implement `IsEnabled` and `LogValues` for all resolvers * refactor: use go-enum `--values` to simplify getting all log fields * refactor: simplify `QType` unmarshaling * squash: rename `ValueLogger` to `Configurable` * squash: rename `UpstreamConfig` to `ParallelBestConfig` * squash: rename `RewriteConfig` to `RewriterConfig` * squash: config tests * squash: resolver tests * squash: add `ForEach` test and improve `Chain` ones * squash: simplify implementing `config.Configurable` * squash: minor changes for better coverage * squash: more `UnmarshalYAML` -> `UnmarshalText` * refactor: move `config.Upstream` into own file * refactor: add `Resolver.Type` method * squash: add `log` method to `typed` to use `Resolover.Type` as prefix * squash: tweak startup config logging * squash: add `LogResolverConfig` tests * squash: make sure all options of type `Duration` use `%s`
This commit is contained in:
parent
bacb4437da
commit
5088c75a78
|
@ -0,0 +1,81 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// BlockingConfig configuration for query blocking
|
||||
type BlockingConfig struct {
|
||||
BlackLists map[string][]string `yaml:"blackLists"`
|
||||
WhiteLists map[string][]string `yaml:"whiteLists"`
|
||||
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
|
||||
BlockType string `yaml:"blockType" default:"ZEROIP"`
|
||||
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
|
||||
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
|
||||
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
|
||||
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
|
||||
// Deprecated
|
||||
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
|
||||
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
|
||||
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *BlockingConfig) IsEnabled() bool {
|
||||
return len(c.ClientGroupsBlock) != 0
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("clientGroupsBlock:")
|
||||
|
||||
for key, val := range c.ClientGroupsBlock {
|
||||
logger.Infof(" %s = %v", key, val)
|
||||
}
|
||||
|
||||
logger.Infof("blockType = %s", c.BlockType)
|
||||
|
||||
if c.BlockType != "NXDOMAIN" {
|
||||
logger.Infof("blockTTL = %s", c.BlockTTL)
|
||||
}
|
||||
|
||||
logger.Infof("downloadTimeout = %s", c.DownloadTimeout)
|
||||
|
||||
logger.Infof("failStartOnListError = %t", c.FailStartOnListError)
|
||||
|
||||
if c.RefreshPeriod > 0 {
|
||||
logger.Infof("refresh = every %s", c.RefreshPeriod)
|
||||
} else {
|
||||
logger.Debug("refresh = disabled")
|
||||
}
|
||||
|
||||
logger.Info("blacklist:")
|
||||
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
|
||||
c.logListGroups(logger, c.BlackLists)
|
||||
})
|
||||
|
||||
logger.Info("whitelist:")
|
||||
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
|
||||
c.logListGroups(logger, c.WhiteLists)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) {
|
||||
for group, links := range listGroups {
|
||||
logger.Infof("%s:", group)
|
||||
|
||||
for _, link := range links {
|
||||
if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char
|
||||
link = link[:idx] // first line only
|
||||
|
||||
logger.Infof(" - %s [...]", link)
|
||||
} else {
|
||||
logger.Infof(" - %s", link)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var _ = Describe("BlockingConfig", func() {
|
||||
var (
|
||||
cfg BlockingConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {"/a/file/path"},
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
RefreshPeriod: Duration(time.Hour),
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := BlockingConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := BlockingConfig{
|
||||
BlockTTL: Duration(-1),
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:"))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
|
||||
})
|
||||
When("refresh is disabled", func() {
|
||||
It("should reflect that", func() {
|
||||
cfg.RefreshPeriod = Duration(-1)
|
||||
|
||||
logger.Logger.Level = logrus.InfoLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
|
||||
logger.Logger.Level = logrus.TraceLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,52 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CachingConfig configuration for domain caching
|
||||
type CachingConfig struct {
|
||||
MinCachingTime Duration `yaml:"minTime"`
|
||||
MaxCachingTime Duration `yaml:"maxTime"`
|
||||
CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"`
|
||||
MaxItemsCount int `yaml:"maxItemsCount"`
|
||||
Prefetching bool `yaml:"prefetching"`
|
||||
PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"`
|
||||
PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"`
|
||||
PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *CachingConfig) IsEnabled() bool {
|
||||
return c.MaxCachingTime > 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *CachingConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("minTime = %s", c.MinCachingTime)
|
||||
logger.Infof("maxTime = %s", c.MaxCachingTime)
|
||||
logger.Infof("cacheTimeNegative = %s", c.CacheTimeNegative)
|
||||
|
||||
if c.Prefetching {
|
||||
logger.Infof("prefetching:")
|
||||
logger.Infof(" expires = %s", c.PrefetchExpires)
|
||||
logger.Infof(" threshold = %d", c.PrefetchThreshold)
|
||||
logger.Infof(" maxItems = %d", c.PrefetchMaxItemsCount)
|
||||
} else {
|
||||
logger.Debug("prefetching: disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CachingConfig) EnablePrefetch() {
|
||||
const day = Duration(24 * time.Hour)
|
||||
|
||||
if c.MaxCachingTime.IsZero() {
|
||||
// make sure resolver gets enabled
|
||||
c.MaxCachingTime = day
|
||||
}
|
||||
|
||||
c.Prefetching = true
|
||||
c.PrefetchThreshold = 0
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CachingConfig", func() {
|
||||
var (
|
||||
cfg CachingConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = CachingConfig{
|
||||
MaxCachingTime: Duration(time.Hour),
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := CachingConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("the config is enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("the config is disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := CachingConfig{
|
||||
MaxCachingTime: Duration(-1),
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
When("prefetching is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
cfg = CachingConfig{
|
||||
Prefetching: true,
|
||||
}
|
||||
})
|
||||
|
||||
It("should return configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("prefetching:")))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,36 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ClientLookupConfig configuration for the client lookup
|
||||
type ClientLookupConfig struct {
|
||||
ClientnameIPMapping map[string][]net.IP `yaml:"clients"`
|
||||
Upstream Upstream `yaml:"upstream"`
|
||||
SingleNameOrder []uint `yaml:"singleNameOrder"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *ClientLookupConfig) IsEnabled() bool {
|
||||
return !c.Upstream.IsDefault() || len(c.ClientnameIPMapping) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *ClientLookupConfig) LogConfig(logger *logrus.Entry) {
|
||||
if !c.Upstream.IsDefault() {
|
||||
logger.Infof("upstream = %s", c.Upstream)
|
||||
}
|
||||
|
||||
logger.Infof("singleNameOrder = %v", c.SingleNameOrder)
|
||||
|
||||
if len(c.ClientnameIPMapping) > 0 {
|
||||
logger.Infof("client IP mapping:")
|
||||
|
||||
for k, v := range c.ClientnameIPMapping {
|
||||
logger.Infof(" %s = %s", k, v)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ClientLookupConfig", func() {
|
||||
var (
|
||||
cfg ClientLookupConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = ClientLookupConfig{
|
||||
Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"},
|
||||
SingleNameOrder: []uint{1, 2},
|
||||
ClientnameIPMapping: map[string][]net.IP{
|
||||
"client8": {net.ParseIP("1.2.3.5")},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg = ClientLookupConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
By("upstream", func() {
|
||||
cfg := ClientLookupConfig{
|
||||
Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"},
|
||||
ClientnameIPMapping: nil,
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
|
||||
By("mapping", func() {
|
||||
cfg := ClientLookupConfig{
|
||||
ClientnameIPMapping: map[string][]net.IP{
|
||||
"client8": {net.ParseIP("1.2.3.5")},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("client IP mapping:")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,60 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ConditionalUpstreamConfig conditional upstream configuration
|
||||
type ConditionalUpstreamConfig struct {
|
||||
RewriterConfig `yaml:",inline"`
|
||||
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
|
||||
}
|
||||
|
||||
// ConditionalUpstreamMapping mapping for conditional configuration
|
||||
type ConditionalUpstreamMapping struct {
|
||||
Upstreams map[string][]Upstream
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *ConditionalUpstreamConfig) IsEnabled() bool {
|
||||
return len(c.Mapping.Upstreams) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *ConditionalUpstreamConfig) LogConfig(logger *logrus.Entry) {
|
||||
for key, val := range c.Mapping.Upstreams {
|
||||
logger.Infof("%s = %v", key, val)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements `yaml.Unmarshaler`.
|
||||
func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input map[string]string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make(map[string][]Upstream, len(input))
|
||||
|
||||
for k, v := range input {
|
||||
var upstreams []Upstream
|
||||
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
upstream, err := ParseUpstream(strings.TrimSpace(part))
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err)
|
||||
}
|
||||
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
|
||||
result[k] = upstreams
|
||||
}
|
||||
|
||||
c.Upstreams = result
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ConditionalUpstreamConfig", func() {
|
||||
var (
|
||||
cfg ConditionalUpstreamConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = ConditionalUpstreamConfig{
|
||||
Mapping: ConditionalUpstreamMapping{
|
||||
Upstreams: map[string][]Upstream{
|
||||
"fritz.box": {Upstream{Net: NetProtocolTcpUdp, Host: "fbTest"}},
|
||||
"other.box": {Upstream{Net: NetProtocolTcpUdp, Host: "otherTest"}},
|
||||
".": {Upstream{Net: NetProtocolTcpUdp, Host: "dotTest"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := ConditionalUpstreamConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := ConditionalUpstreamConfig{
|
||||
Mapping: ConditionalUpstreamMapping{Upstreams: map[string][]Upstream{}},
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("fritz.box = ")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UnmarshalYAML", func() {
|
||||
It("Should parse config as map", func() {
|
||||
c := &ConditionalUpstreamMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(c.Upstreams).Should(HaveLen(1))
|
||||
Expect(c.Upstreams["key"]).Should(HaveLen(1))
|
||||
Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{
|
||||
Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53,
|
||||
}))
|
||||
})
|
||||
|
||||
It("should fail if wrong YAML format", func() {
|
||||
c := &ConditionalUpstreamMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
})
|
488
config/config.go
488
config/config.go
|
@ -1,4 +1,4 @@
|
|||
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names
|
||||
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values
|
||||
package config
|
||||
|
||||
import (
|
||||
|
@ -7,16 +7,12 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/hako/durafmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/creasty/defaults"
|
||||
|
@ -29,6 +25,16 @@ const (
|
|||
httpsPort = 443
|
||||
)
|
||||
|
||||
type Configurable interface {
|
||||
// IsEnabled returns true when the receiver is configured.
|
||||
IsEnabled() bool
|
||||
|
||||
// LogConfig logs the receiver's configuration.
|
||||
//
|
||||
// Calling this method when `IsEnabled` returns false is undefined.
|
||||
LogConfig(*logrus.Entry)
|
||||
}
|
||||
|
||||
// NetProtocol resolver protocol ENUM(
|
||||
// tcp+udp // TCP and UDP protocols
|
||||
// tcp-tls // TCP-TLS protocol
|
||||
|
@ -90,44 +96,6 @@ type StartStrategyType uint16
|
|||
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
|
||||
type QueryLogField string
|
||||
|
||||
type QType dns.Type
|
||||
|
||||
func (c QType) String() string {
|
||||
return dns.Type(c).String()
|
||||
}
|
||||
|
||||
type QTypeSet map[QType]struct{}
|
||||
|
||||
func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
|
||||
s := make(QTypeSet, len(qTypes))
|
||||
|
||||
for _, qType := range qTypes {
|
||||
s.Insert(qType)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s QTypeSet) Contains(qType dns.Type) bool {
|
||||
_, found := s[QType(qType)]
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func (s *QTypeSet) Insert(qType dns.Type) {
|
||||
if *s == nil {
|
||||
*s = make(QTypeSet, 1)
|
||||
}
|
||||
|
||||
(*s)[QType(qType)] = struct{}{}
|
||||
}
|
||||
|
||||
type Duration time.Duration
|
||||
|
||||
func (c Duration) String() string {
|
||||
return durafmt.Parse(time.Duration(c)).String()
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var netDefaultPort = map[NetProtocol]uint16{
|
||||
NetProtocolTcpUdp: udpPort,
|
||||
|
@ -135,82 +103,12 @@ var netDefaultPort = map[NetProtocol]uint16{
|
|||
NetProtocolHttps: httpsPort,
|
||||
}
|
||||
|
||||
// Upstream is the definition of external DNS server
|
||||
type Upstream struct {
|
||||
Net NetProtocol
|
||||
Host string
|
||||
Port uint16
|
||||
Path string
|
||||
CommonName string // Common Name to use for certificate verification; optional. "" uses .Host
|
||||
}
|
||||
|
||||
// IsDefault returns true if u is the default value
|
||||
func (u *Upstream) IsDefault() bool {
|
||||
return *u == Upstream{}
|
||||
}
|
||||
|
||||
// String returns the string representation of u
|
||||
func (u Upstream) String() string {
|
||||
if u.IsDefault() {
|
||||
return "no upstream"
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(u.Net.String())
|
||||
sb.WriteRune(':')
|
||||
|
||||
if u.Net == NetProtocolHttps {
|
||||
sb.WriteString("//")
|
||||
}
|
||||
|
||||
isIPv6 := strings.ContainsRune(u.Host, ':')
|
||||
if isIPv6 {
|
||||
sb.WriteRune('[')
|
||||
sb.WriteString(u.Host)
|
||||
sb.WriteRune(']')
|
||||
} else {
|
||||
sb.WriteString(u.Host)
|
||||
}
|
||||
|
||||
if u.Port != netDefaultPort[u.Net] {
|
||||
sb.WriteRune(':')
|
||||
sb.WriteString(fmt.Sprint(u.Port))
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
sb.WriteString(u.Path)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// UnmarshalYAML creates Upstream from YAML
|
||||
func (u *Upstream) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
upstream, err := ParseUpstream(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't convert upstream '%s': %w", s, err)
|
||||
}
|
||||
|
||||
*u = upstream
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenConfig is a list of address(es) to listen on
|
||||
type ListenConfig []string
|
||||
|
||||
// UnmarshalYAML creates ListenConfig from YAML
|
||||
func (l *ListenConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var addresses string
|
||||
if err := unmarshal(&addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
// UnmarshalText implements `encoding.TextUnmarshaler`.
|
||||
func (l *ListenConfig) UnmarshalText(data []byte) error {
|
||||
addresses := string(data)
|
||||
|
||||
*l = strings.Split(addresses, ",")
|
||||
|
||||
|
@ -256,225 +154,11 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e
|
|||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML creates ConditionalUpstreamMapping from YAML
|
||||
func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input map[string]string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make(map[string][]Upstream, len(input))
|
||||
|
||||
for k, v := range input {
|
||||
var upstreams []Upstream
|
||||
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
upstream, err := ParseUpstream(strings.TrimSpace(part))
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err)
|
||||
}
|
||||
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
|
||||
result[k] = upstreams
|
||||
}
|
||||
|
||||
c.Upstreams = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML creates CustomDNSMapping from YAML
|
||||
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input map[string]string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make(map[string][]net.IP, len(input))
|
||||
|
||||
for k, v := range input {
|
||||
var ips []net.IP
|
||||
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
ip := net.ParseIP(strings.TrimSpace(part))
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address '%s'", part)
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
result[k] = ips
|
||||
}
|
||||
|
||||
c.HostIPs = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML creates Duration from YAML. If no unit is used, uses minutes
|
||||
func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if minutes, err := strconv.Atoi(input); err == nil {
|
||||
// duration is defined as number without unit
|
||||
// use minutes to ensure back compatibility
|
||||
*c = Duration(time.Duration(minutes) * time.Minute)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(input)
|
||||
if err == nil {
|
||||
*c = Duration(duration)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *QType) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t, found := dns.StringToType[input]
|
||||
if !found {
|
||||
types := make([]string, 0, len(dns.StringToType))
|
||||
for k := range dns.StringToType {
|
||||
types = append(types, k)
|
||||
}
|
||||
|
||||
sort.Strings(types)
|
||||
|
||||
return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'",
|
||||
input, strings.Join(types, ", "))
|
||||
}
|
||||
|
||||
*c = QType(t)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input []QType
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*s = make(QTypeSet, len(input))
|
||||
|
||||
for _, qType := range input {
|
||||
(*s)[qType] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var validDomain = regexp.MustCompile(
|
||||
`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
|
||||
// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname]
|
||||
func ParseUpstream(upstream string) (Upstream, error) {
|
||||
var path string
|
||||
|
||||
var port uint16
|
||||
|
||||
commonName, upstream := extractCommonName(upstream)
|
||||
|
||||
n, upstream := extractNet(upstream)
|
||||
|
||||
path, upstream = extractPath(upstream)
|
||||
|
||||
host, portString, err := net.SplitHostPort(upstream)
|
||||
|
||||
// string contains host:port
|
||||
if err == nil {
|
||||
p, err := ConvertPort(portString)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
|
||||
|
||||
return Upstream{}, err
|
||||
}
|
||||
|
||||
port = p
|
||||
} else {
|
||||
// only host, use default port
|
||||
host = upstream
|
||||
port = netDefaultPort[n]
|
||||
|
||||
// trim any IPv6 brackets
|
||||
host = strings.TrimPrefix(host, "[")
|
||||
host = strings.TrimSuffix(host, "]")
|
||||
}
|
||||
|
||||
// validate hostname or ip
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
// is not IP
|
||||
if !validDomain.MatchString(host) {
|
||||
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
|
||||
}
|
||||
}
|
||||
|
||||
return Upstream{
|
||||
Net: n,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Path: path,
|
||||
CommonName: commonName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractCommonName(in string) (string, string) {
|
||||
upstream, cn, _ := strings.Cut(in, "#")
|
||||
|
||||
return cn, upstream
|
||||
}
|
||||
|
||||
func extractPath(in string) (path, upstream string) {
|
||||
slashIdx := strings.Index(in, "/")
|
||||
|
||||
if slashIdx >= 0 {
|
||||
path = in[slashIdx:]
|
||||
upstream = in[:slashIdx]
|
||||
} else {
|
||||
upstream = in
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func extractNet(upstream string) (NetProtocol, string) {
|
||||
tcpUDPPrefix := NetProtocolTcpUdp.String() + ":"
|
||||
if strings.HasPrefix(upstream, tcpUDPPrefix) {
|
||||
return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):]
|
||||
}
|
||||
|
||||
tcpTLSPrefix := NetProtocolTcpTls.String() + ":"
|
||||
if strings.HasPrefix(upstream, tcpTLSPrefix) {
|
||||
return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):]
|
||||
}
|
||||
|
||||
httpsPrefix := NetProtocolHttps.String() + ":"
|
||||
if strings.HasPrefix(upstream, httpsPrefix) {
|
||||
return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//")
|
||||
}
|
||||
|
||||
return NetProtocolTcpUdp, upstream
|
||||
}
|
||||
|
||||
// Config main configuration
|
||||
//
|
||||
//nolint:maligned
|
||||
type Config struct {
|
||||
Upstream UpstreamConfig `yaml:"upstream"`
|
||||
Upstream ParallelBestConfig `yaml:"upstream"`
|
||||
UpstreamTimeout Duration `yaml:"upstreamTimeout" default:"2s"`
|
||||
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
|
||||
CustomDNS CustomDNSConfig `yaml:"customDNS"`
|
||||
|
@ -483,7 +167,7 @@ type Config struct {
|
|||
ClientLookup ClientLookupConfig `yaml:"clientLookup"`
|
||||
Caching CachingConfig `yaml:"caching"`
|
||||
QueryLog QueryLogConfig `yaml:"queryLog"`
|
||||
Prometheus PrometheusConfig `yaml:"prometheus"`
|
||||
Prometheus MetricsConfig `yaml:"prometheus"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
Log log.Config `yaml:"log"`
|
||||
Ports PortsConfig `yaml:"ports"`
|
||||
|
@ -494,7 +178,7 @@ type Config struct {
|
|||
KeyFile string `yaml:"keyFile"`
|
||||
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
|
||||
HostsFile HostsFileConfig `yaml:"hostsFile"`
|
||||
FqdnOnly bool `yaml:"fqdnOnly" default:"false"`
|
||||
FqdnOnly FqdnOnlyConfig `yaml:",inline"`
|
||||
Filtering FilteringConfig `yaml:"filtering"`
|
||||
Ede EdeConfig `yaml:"ede"`
|
||||
// Deprecated
|
||||
|
@ -508,7 +192,7 @@ type Config struct {
|
|||
// Deprecated
|
||||
LogTimestamp bool `yaml:"logTimestamp" default:"true"`
|
||||
// Deprecated
|
||||
DNSPorts ListenConfig `yaml:"port" default:"[\"53\"]"`
|
||||
DNSPorts ListenConfig `yaml:"port" default:"\"53\""`
|
||||
// Deprecated
|
||||
HTTPPorts ListenConfig `yaml:"httpPort"`
|
||||
// Deprecated
|
||||
|
@ -518,12 +202,19 @@ type Config struct {
|
|||
}
|
||||
|
||||
type PortsConfig struct {
|
||||
DNS ListenConfig `yaml:"dns" default:"[\"53\"]"`
|
||||
DNS ListenConfig `yaml:"dns" default:"\"53\""`
|
||||
HTTP ListenConfig `yaml:"http"`
|
||||
HTTPS ListenConfig `yaml:"https"`
|
||||
TLS ListenConfig `yaml:"tls"`
|
||||
}
|
||||
|
||||
func (c *PortsConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("DNS = %s", c.DNS)
|
||||
logger.Infof("TLS = %s", c.TLS)
|
||||
logger.Infof("HTTP = %s", c.HTTP)
|
||||
logger.Infof("HTTPS = %s", c.HTTPS)
|
||||
}
|
||||
|
||||
// split in two types to avoid infinite recursion. See `BootstrapDNSConfig.UnmarshalYAML`.
|
||||
type (
|
||||
BootstrapDNSConfig bootstrapDNSConfig
|
||||
|
@ -539,105 +230,6 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
// PrometheusConfig contains the config values for prometheus
|
||||
type PrometheusConfig struct {
|
||||
Enable bool `yaml:"enable" default:"false"`
|
||||
Path string `yaml:"path" default:"/metrics"`
|
||||
}
|
||||
|
||||
// UpstreamConfig upstream server configuration
|
||||
type UpstreamConfig struct {
|
||||
ExternalResolvers map[string][]Upstream `yaml:",inline"`
|
||||
}
|
||||
|
||||
// RewriteConfig custom DNS configuration
|
||||
type RewriteConfig struct {
|
||||
Rewrite map[string]string `yaml:"rewrite"`
|
||||
FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"`
|
||||
}
|
||||
|
||||
// CustomDNSConfig custom DNS configuration
|
||||
type CustomDNSConfig struct {
|
||||
RewriteConfig `yaml:",inline"`
|
||||
CustomTTL Duration `yaml:"customTTL" default:"1h"`
|
||||
Mapping CustomDNSMapping `yaml:"mapping"`
|
||||
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
|
||||
}
|
||||
|
||||
// CustomDNSMapping mapping for the custom DNS configuration
|
||||
type CustomDNSMapping struct {
|
||||
HostIPs map[string][]net.IP
|
||||
}
|
||||
|
||||
// ConditionalUpstreamConfig conditional upstream configuration
|
||||
type ConditionalUpstreamConfig struct {
|
||||
RewriteConfig `yaml:",inline"`
|
||||
Mapping ConditionalUpstreamMapping `yaml:"mapping"`
|
||||
}
|
||||
|
||||
// ConditionalUpstreamMapping mapping for conditional configuration
|
||||
type ConditionalUpstreamMapping struct {
|
||||
Upstreams map[string][]Upstream
|
||||
}
|
||||
|
||||
// BlockingConfig configuration for query blocking
|
||||
type BlockingConfig struct {
|
||||
BlackLists map[string][]string `yaml:"blackLists"`
|
||||
WhiteLists map[string][]string `yaml:"whiteLists"`
|
||||
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
|
||||
BlockType string `yaml:"blockType" default:"ZEROIP"`
|
||||
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
|
||||
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
|
||||
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
|
||||
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
|
||||
// Deprecated
|
||||
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
|
||||
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
|
||||
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
|
||||
}
|
||||
|
||||
// ClientLookupConfig configuration for the client lookup
|
||||
type ClientLookupConfig struct {
|
||||
ClientnameIPMapping map[string][]net.IP `yaml:"clients"`
|
||||
Upstream Upstream `yaml:"upstream"`
|
||||
SingleNameOrder []uint `yaml:"singleNameOrder"`
|
||||
}
|
||||
|
||||
// CachingConfig configuration for domain caching
|
||||
type CachingConfig struct {
|
||||
MinCachingTime Duration `yaml:"minTime"`
|
||||
MaxCachingTime Duration `yaml:"maxTime"`
|
||||
CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"`
|
||||
MaxItemsCount int `yaml:"maxItemsCount"`
|
||||
Prefetching bool `yaml:"prefetching"`
|
||||
PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"`
|
||||
PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"`
|
||||
PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"`
|
||||
}
|
||||
|
||||
func (c *CachingConfig) EnablePrefetch() {
|
||||
const day = 24 * time.Hour
|
||||
|
||||
if c.MaxCachingTime == 0 {
|
||||
// make sure resolver gets enabled
|
||||
c.MaxCachingTime = Duration(day)
|
||||
}
|
||||
|
||||
c.Prefetching = true
|
||||
c.PrefetchThreshold = 0
|
||||
}
|
||||
|
||||
// QueryLogConfig configuration for the query logging
|
||||
type QueryLogConfig struct {
|
||||
Target string `yaml:"target"`
|
||||
Type QueryLogType `yaml:"type"`
|
||||
LogRetentionDays uint64 `yaml:"logRetentionDays"`
|
||||
CreationAttempts int `yaml:"creationAttempts" default:"3"`
|
||||
CreationCooldown Duration `yaml:"creationCooldown" default:"2s"`
|
||||
Fields []QueryLogField `yaml:"fields"`
|
||||
}
|
||||
|
||||
// RedisConfig configuration for the redis connection
|
||||
type RedisConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
|
@ -652,21 +244,25 @@ type RedisConfig struct {
|
|||
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
||||
}
|
||||
|
||||
type HostsFileConfig struct {
|
||||
Filepath string `yaml:"filePath"`
|
||||
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
|
||||
FilterLoopback bool `yaml:"filterLoopback"`
|
||||
}
|
||||
type (
|
||||
FqdnOnlyConfig = toEnable
|
||||
EdeConfig = toEnable
|
||||
)
|
||||
|
||||
type FilteringConfig struct {
|
||||
QueryTypes QTypeSet `yaml:"queryTypes"`
|
||||
}
|
||||
|
||||
type EdeConfig struct {
|
||||
type toEnable struct {
|
||||
Enable bool `yaml:"enable" default:"false"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *toEnable) IsEnabled() bool {
|
||||
return c.Enable
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *toEnable) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("enabled")
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
config = &Config{}
|
||||
|
|
|
@ -40,6 +40,15 @@ func IPVersionNames() []string {
|
|||
return tmp
|
||||
}
|
||||
|
||||
// IPVersionValues returns a list of the values for IPVersion
|
||||
func IPVersionValues() []IPVersion {
|
||||
return []IPVersion{
|
||||
IPVersionDual,
|
||||
IPVersionV4,
|
||||
IPVersionV6,
|
||||
}
|
||||
}
|
||||
|
||||
var _IPVersionMap = map[IPVersion]string{
|
||||
IPVersionDual: _IPVersionName[0:4],
|
||||
IPVersionV4: _IPVersionName[4:6],
|
||||
|
@ -113,6 +122,15 @@ func NetProtocolNames() []string {
|
|||
return tmp
|
||||
}
|
||||
|
||||
// NetProtocolValues returns a list of the values for NetProtocol
|
||||
func NetProtocolValues() []NetProtocol {
|
||||
return []NetProtocol{
|
||||
NetProtocolTcpUdp,
|
||||
NetProtocolTcpTls,
|
||||
NetProtocolHttps,
|
||||
}
|
||||
}
|
||||
|
||||
var _NetProtocolMap = map[NetProtocol]string{
|
||||
NetProtocolTcpUdp: _NetProtocolName[0:7],
|
||||
NetProtocolTcpTls: _NetProtocolName[7:14],
|
||||
|
@ -190,6 +208,18 @@ func QueryLogFieldNames() []string {
|
|||
return tmp
|
||||
}
|
||||
|
||||
// QueryLogFieldValues returns a list of the values for QueryLogField
|
||||
func QueryLogFieldValues() []QueryLogField {
|
||||
return []QueryLogField{
|
||||
QueryLogFieldClientIP,
|
||||
QueryLogFieldClientName,
|
||||
QueryLogFieldResponseReason,
|
||||
QueryLogFieldResponseAnswer,
|
||||
QueryLogFieldQuestion,
|
||||
QueryLogFieldDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x QueryLogField) String() string {
|
||||
return string(x)
|
||||
|
@ -274,6 +304,18 @@ func QueryLogTypeNames() []string {
|
|||
return tmp
|
||||
}
|
||||
|
||||
// QueryLogTypeValues returns a list of the values for QueryLogType
|
||||
func QueryLogTypeValues() []QueryLogType {
|
||||
return []QueryLogType{
|
||||
QueryLogTypeConsole,
|
||||
QueryLogTypeNone,
|
||||
QueryLogTypeMysql,
|
||||
QueryLogTypePostgresql,
|
||||
QueryLogTypeCsv,
|
||||
QueryLogTypeCsvClient,
|
||||
}
|
||||
}
|
||||
|
||||
var _QueryLogTypeMap = map[QueryLogType]string{
|
||||
QueryLogTypeConsole: _QueryLogTypeName[0:7],
|
||||
QueryLogTypeNone: _QueryLogTypeName[7:11],
|
||||
|
@ -353,6 +395,15 @@ func StartStrategyTypeNames() []string {
|
|||
return tmp
|
||||
}
|
||||
|
||||
// StartStrategyTypeValues returns a list of the values for StartStrategyType
|
||||
func StartStrategyTypeValues() []StartStrategyType {
|
||||
return []StartStrategyType{
|
||||
StartStrategyTypeBlocking,
|
||||
StartStrategyTypeFailOnError,
|
||||
StartStrategyTypeFast,
|
||||
}
|
||||
}
|
||||
|
||||
var _StartStrategyTypeMap = map[StartStrategyType]string{
|
||||
StartStrategyTypeBlocking: _StartStrategyTypeName[0:8],
|
||||
StartStrategyTypeFailOnError: _StartStrategyTypeName[8:19],
|
||||
|
|
|
@ -6,6 +6,12 @@ import (
|
|||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
logger *logrus.Entry
|
||||
hook *log.MockLoggerHook
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
|
@ -13,3 +19,9 @@ func TestConfig(t *testing.T) {
|
|||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Config Suite")
|
||||
}
|
||||
|
||||
func suiteBeforeEach() {
|
||||
BeforeEach(func() {
|
||||
logger, hook = log.NewMockEntry()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -321,15 +320,11 @@ bootstrapDns:
|
|||
})
|
||||
})
|
||||
|
||||
Describe("YAML parsing", func() {
|
||||
Describe("Parsing", func() {
|
||||
Context("upstream", func() {
|
||||
It("should create the upstream struct with data", func() {
|
||||
u := &Upstream{}
|
||||
err := u.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "tcp+udp:1.2.3.4"
|
||||
|
||||
return nil
|
||||
})
|
||||
err := u.UnmarshalText([]byte("tcp+udp:1.2.3.4"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(u.Net).Should(Equal(NetProtocolTcpUdp))
|
||||
Expect(u.Host).Should(Equal("1.2.3.4"))
|
||||
|
@ -338,139 +333,18 @@ bootstrapDns:
|
|||
|
||||
It("should fail if the upstream is in wrong format", func() {
|
||||
u := &Upstream{}
|
||||
err := u.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
err := u.UnmarshalText([]byte("invalid!"))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
Context("ListenConfig", func() {
|
||||
It("should parse and split valid string config", func() {
|
||||
l := &ListenConfig{}
|
||||
err := l.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "55,:56"
|
||||
|
||||
return nil
|
||||
})
|
||||
err := l.UnmarshalText([]byte("55,:56"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(*l).Should(HaveLen(2))
|
||||
Expect(*l).Should(ContainElements("55", ":56"))
|
||||
})
|
||||
It("should fail on error", func() {
|
||||
l := &ListenConfig{}
|
||||
err := l.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
Context("Duration", func() {
|
||||
It("should parse duration with unit", func() {
|
||||
d := Duration(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "1m20s"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(d).Should(Equal(Duration(80 * time.Second)))
|
||||
Expect(d.String()).Should(Equal("1 minute 20 seconds"))
|
||||
})
|
||||
It("should fail if duration is in wrong format", func() {
|
||||
d := Duration(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "wrong"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("time: invalid duration \"wrong\""))
|
||||
})
|
||||
It("should fail if wrong YAML format", func() {
|
||||
d := Duration(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
Context("ConditionalUpstreamMapping", func() {
|
||||
It("Should parse config as map", func() {
|
||||
c := &ConditionalUpstreamMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(c.Upstreams).Should(HaveLen(1))
|
||||
Expect(c.Upstreams["key"]).Should(HaveLen(1))
|
||||
Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{
|
||||
Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53,
|
||||
}))
|
||||
})
|
||||
It("should fail if wrong YAML format", func() {
|
||||
c := &ConditionalUpstreamMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
Context("CustomDNSMapping", func() {
|
||||
It("Should parse config as map", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(c.HostIPs).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"]).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4")))
|
||||
})
|
||||
It("should fail if wrong YAML format", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
Context("QueryTyoe", func() {
|
||||
It("Should parse existing DNS type as string", func() {
|
||||
t := QType(0)
|
||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "AAAA"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(t).Should(Equal(QType(dns.TypeAAAA)))
|
||||
Expect(t.String()).Should(Equal("AAAA"))
|
||||
})
|
||||
It("should fail if DNS type does not exist", func() {
|
||||
t := QType(0)
|
||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "WRONGTYPE"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'"))
|
||||
})
|
||||
It("should fail if wrong YAML format", func() {
|
||||
d := QType(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -651,29 +525,6 @@ bootstrapDns:
|
|||
"tcp-tls:[fd00::6cd4:d7e0:d99d:2952]",
|
||||
),
|
||||
)
|
||||
|
||||
Describe("QTypeSet", func() {
|
||||
It("new should insert given qTypes", func() {
|
||||
set := NewQTypeSet(dns.Type(dns.TypeA))
|
||||
Expect(set).Should(HaveKey(QType(dns.TypeA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue())
|
||||
|
||||
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
|
||||
})
|
||||
|
||||
It("should insert given qTypes", func() {
|
||||
set := NewQTypeSet()
|
||||
|
||||
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
|
||||
|
||||
set.Insert(dns.Type(dns.TypeAAAA))
|
||||
|
||||
Expect(set).Should(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func defaultTestFileConfig() {
|
||||
|
@ -700,8 +551,8 @@ func defaultTestFileConfig() {
|
|||
Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
|
||||
Expect(config.Filtering.QueryTypes).Should(HaveLen(2))
|
||||
|
||||
Expect(config.Caching.MaxCachingTime).Should(Equal(Duration(0)))
|
||||
Expect(config.Caching.MinCachingTime).Should(Equal(Duration(0)))
|
||||
Expect(config.Caching.MaxCachingTime.IsZero()).Should(BeTrue())
|
||||
Expect(config.Caching.MinCachingTime.IsZero()).Should(BeTrue())
|
||||
|
||||
Expect(config.DoHUserAgent).Should(Equal("testBlocky"))
|
||||
Expect(config.MinTLSServeVer).Should(Equal("1.3"))
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CustomDNSConfig custom DNS configuration
|
||||
type CustomDNSConfig struct {
|
||||
RewriterConfig `yaml:",inline"`
|
||||
CustomTTL Duration `yaml:"customTTL" default:"1h"`
|
||||
Mapping CustomDNSMapping `yaml:"mapping"`
|
||||
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
|
||||
}
|
||||
|
||||
// CustomDNSMapping mapping for the custom DNS configuration
|
||||
type CustomDNSMapping struct {
|
||||
HostIPs map[string][]net.IP `yaml:"hostIPs"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *CustomDNSConfig) IsEnabled() bool {
|
||||
return len(c.Mapping.HostIPs) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *CustomDNSConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Debugf("TTL = %s", c.CustomTTL)
|
||||
logger.Debugf("filterUnmappedTypes = %t", c.FilterUnmappedTypes)
|
||||
|
||||
logger.Info("mapping:")
|
||||
|
||||
for key, val := range c.Mapping.HostIPs {
|
||||
logger.Infof(" %s = %s", key, val)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements `yaml.Unmarshaler`.
|
||||
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input map[string]string
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make(map[string][]net.IP, len(input))
|
||||
|
||||
for k, v := range input {
|
||||
var ips []net.IP
|
||||
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
ip := net.ParseIP(strings.TrimSpace(part))
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address '%s'", part)
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
result[k] = ips
|
||||
}
|
||||
|
||||
c.HostIPs = result
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CustomDNSConfig", func() {
|
||||
var (
|
||||
cfg CustomDNSConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = CustomDNSConfig{
|
||||
Mapping: CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.domain": {net.ParseIP("192.168.143.123")},
|
||||
"ip6.domain": {net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
"multiple.ips": {
|
||||
net.ParseIP("192.168.143.123"),
|
||||
net.ParseIP("192.168.143.125"),
|
||||
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := CustomDNSConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := CustomDNSConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("custom.domain = ")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("multiple.ips = ")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UnmarshalYAML", func() {
|
||||
It("Should parse config as map", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(c.HostIPs).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"]).Should(HaveLen(1))
|
||||
Expect(c.HostIPs["key"][0]).Should(Equal(net.ParseIP("1.2.3.4")))
|
||||
})
|
||||
|
||||
It("should fail if wrong YAML format", func() {
|
||||
c := &CustomDNSMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
return errors.New("some err")
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("some err"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,54 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/hako/durafmt"
|
||||
)
|
||||
|
||||
type Duration time.Duration
|
||||
|
||||
func (c Duration) ToDuration() time.Duration {
|
||||
return time.Duration(c)
|
||||
}
|
||||
|
||||
func (c Duration) IsZero() bool {
|
||||
return c.ToDuration() == 0
|
||||
}
|
||||
|
||||
func (c Duration) Seconds() float64 {
|
||||
return c.ToDuration().Seconds()
|
||||
}
|
||||
|
||||
func (c Duration) SecondsU32() uint32 {
|
||||
return uint32(c.Seconds())
|
||||
}
|
||||
|
||||
func (c Duration) String() string {
|
||||
return durafmt.Parse(c.ToDuration()).String()
|
||||
}
|
||||
|
||||
// UnmarshalText implements `encoding.TextUnmarshaler`.
|
||||
func (c *Duration) UnmarshalText(data []byte) error {
|
||||
input := string(data)
|
||||
|
||||
if minutes, err := strconv.Atoi(input); err == nil {
|
||||
// number without unit: use minutes to ensure back compatibility
|
||||
*c = Duration(time.Duration(minutes) * time.Minute)
|
||||
|
||||
log.Log().Warnf("Setting a duration without a unit is deprecated. Please use '%s min' instead.", input)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(input)
|
||||
if err == nil {
|
||||
*c = Duration(duration)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Duration", func() {
|
||||
var d Duration
|
||||
|
||||
BeforeEach(func() {
|
||||
var zero Duration
|
||||
|
||||
d = zero
|
||||
})
|
||||
|
||||
Describe("UnmarshalText", func() {
|
||||
It("should parse duration with unit", func() {
|
||||
err := d.UnmarshalText([]byte("1m20s"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(d).Should(Equal(Duration(80 * time.Second)))
|
||||
Expect(d.String()).Should(Equal("1 minute 20 seconds"))
|
||||
})
|
||||
|
||||
It("should fail if duration is in wrong format", func() {
|
||||
err := d.UnmarshalText([]byte("wrong"))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err).Should(MatchError("time: invalid duration \"wrong\""))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("IsZero", func() {
|
||||
It("should be true for zero", func() {
|
||||
Expect(d.IsZero()).Should(BeTrue())
|
||||
Expect(Duration(0).IsZero()).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should be false for non-zero", func() {
|
||||
Expect(Duration(time.Second).IsZero()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,23 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type FilteringConfig struct {
|
||||
QueryTypes QTypeSet `yaml:"queryTypes"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *FilteringConfig) IsEnabled() bool {
|
||||
return len(c.QueryTypes) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *FilteringConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("query types:")
|
||||
|
||||
for qType := range c.QueryTypes {
|
||||
logger.Infof(" - %s", qType)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("FilteringConfig", func() {
|
||||
var (
|
||||
cfg FilteringConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = FilteringConfig{
|
||||
QueryTypes: NewQTypeSet(AAAA, MX),
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := FilteringConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := FilteringConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).Should(HaveLen(3))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("query types:")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - AAAA")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring(" - MX")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,25 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type HostsFileConfig struct {
|
||||
Filepath string `yaml:"filePath"`
|
||||
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
|
||||
FilterLoopback bool `yaml:"filterLoopback"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *HostsFileConfig) IsEnabled() bool {
|
||||
return len(c.Filepath) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("file path: %s", c.Filepath)
|
||||
logger.Infof("TTL: %s", c.HostsTTL)
|
||||
logger.Infof("refresh period: %s", c.RefreshPeriod)
|
||||
logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("HostsFileConfig", func() {
|
||||
var (
|
||||
cfg HostsFileConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = HostsFileConfig{
|
||||
Filepath: "/dev/null",
|
||||
HostsTTL: Duration(29 * time.Minute),
|
||||
RefreshPeriod: Duration(30 * time.Minute),
|
||||
FilterLoopback: true,
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := HostsFileConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := HostsFileConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,19 @@
|
|||
package config
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
|
||||
// MetricsConfig contains the config values for prometheus
|
||||
type MetricsConfig struct {
|
||||
Enable bool `yaml:"enable" default:"false"`
|
||||
Path string `yaml:"path" default:"/metrics"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *MetricsConfig) IsEnabled() bool {
|
||||
return c.Enable
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *MetricsConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("url path: %s", c.Path)
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MetricsConfig", func() {
|
||||
var (
|
||||
cfg MetricsConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = MetricsConfig{
|
||||
Enable: true,
|
||||
Path: "/custom/path",
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := MetricsConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := MetricsConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).Should(HaveLen(1))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("url path: /custom/path")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,32 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const UpstreamDefaultCfgName = "default"
|
||||
|
||||
// ParallelBestConfig upstream server configuration
|
||||
type ParallelBestConfig struct {
|
||||
ExternalResolvers ParallelBestMapping `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ParallelBestMapping map[string][]Upstream
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *ParallelBestConfig) IsEnabled() bool {
|
||||
return len(c.ExternalResolvers) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *ParallelBestConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("upstream resolvers:")
|
||||
|
||||
for name, upstreams := range c.ExternalResolvers {
|
||||
logger.Infof(" %s:", name)
|
||||
|
||||
for _, upstream := range upstreams {
|
||||
logger.Infof(" - %s", upstream)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ParallelBestConfig", func() {
|
||||
var (
|
||||
cfg ParallelBestConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = ParallelBestConfig{
|
||||
ExternalResolvers: ParallelBestMapping{
|
||||
UpstreamDefaultCfgName: {
|
||||
{Host: "host1"},
|
||||
{Host: "host2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := ParallelBestConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := ParallelBestConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstream resolvers:")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,76 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type QTypeSet map[QType]struct{}
|
||||
|
||||
func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
|
||||
s := make(QTypeSet, len(qTypes))
|
||||
|
||||
for _, qType := range qTypes {
|
||||
s.Insert(qType)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s QTypeSet) Contains(qType dns.Type) bool {
|
||||
_, found := s[QType(qType)]
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func (s *QTypeSet) Insert(qType dns.Type) {
|
||||
if *s == nil {
|
||||
*s = make(QTypeSet, 1)
|
||||
}
|
||||
|
||||
(*s)[QType(qType)] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var input []QType
|
||||
if err := unmarshal(&input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*s = make(QTypeSet, len(input))
|
||||
|
||||
for _, qType := range input {
|
||||
(*s)[qType] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type QType dns.Type
|
||||
|
||||
func (c QType) String() string {
|
||||
return dns.Type(c).String()
|
||||
}
|
||||
|
||||
// UnmarshalText implements `encoding.TextUnmarshaler`.
|
||||
func (c *QType) UnmarshalText(data []byte) error {
|
||||
input := string(data)
|
||||
|
||||
t, found := dns.StringToType[input]
|
||||
if !found {
|
||||
types := maps.Keys(dns.StringToType)
|
||||
|
||||
sort.Strings(types)
|
||||
|
||||
return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'",
|
||||
input, strings.Join(types, ", "))
|
||||
}
|
||||
|
||||
*c = QType(t)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("QTypeSet", func() {
|
||||
Describe("NewQTypeSet", func() {
|
||||
It("should insert given qTypes", func() {
|
||||
set := NewQTypeSet(dns.Type(dns.TypeA))
|
||||
Expect(set).Should(HaveKey(QType(dns.TypeA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue())
|
||||
|
||||
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Insert", func() {
|
||||
It("should insert given qTypes", func() {
|
||||
set := NewQTypeSet()
|
||||
|
||||
Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
|
||||
|
||||
set.Insert(dns.Type(dns.TypeAAAA))
|
||||
|
||||
Expect(set).Should(HaveKey(QType(dns.TypeAAAA)))
|
||||
Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("QType", func() {
|
||||
Describe("UnmarshalText", func() {
|
||||
It("Should parse existing DNS type as string", func() {
|
||||
t := QType(0)
|
||||
err := t.UnmarshalText([]byte("AAAA"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(t).Should(Equal(QType(dns.TypeAAAA)))
|
||||
Expect(t.String()).Should(Equal("AAAA"))
|
||||
})
|
||||
|
||||
It("should fail if DNS type does not exist", func() {
|
||||
t := QType(0)
|
||||
err := t.UnmarshalText([]byte("WRONGTYPE"))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'"))
|
||||
})
|
||||
|
||||
It("should fail if wrong YAML format", func() {
|
||||
d := QType(0)
|
||||
err := d.UnmarshalText([]byte("some err"))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'some err'"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,41 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// QueryLogConfig configuration for the query logging
|
||||
type QueryLogConfig struct {
|
||||
Target string `yaml:"target"`
|
||||
Type QueryLogType `yaml:"type"`
|
||||
LogRetentionDays uint64 `yaml:"logRetentionDays"`
|
||||
CreationAttempts int `yaml:"creationAttempts" default:"3"`
|
||||
CreationCooldown Duration `yaml:"creationCooldown" default:"2s"`
|
||||
Fields []QueryLogField `yaml:"fields"`
|
||||
}
|
||||
|
||||
// SetDefaults implements `defaults.Setter`.
|
||||
func (c *QueryLogConfig) SetDefaults() {
|
||||
// Since the default depends on the enum values, set it dynamically
|
||||
// to avoid having to repeat the values in the annotation.
|
||||
c.Fields = QueryLogFieldValues()
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *QueryLogConfig) IsEnabled() bool {
|
||||
return c.Type != QueryLogTypeNone
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *QueryLogConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("type: %s", c.Type)
|
||||
|
||||
if c.Target != "" {
|
||||
logger.Infof("target: %s", c.Target)
|
||||
}
|
||||
|
||||
logger.Infof("logRetentionDays: %d", c.LogRetentionDays)
|
||||
logger.Debugf("creationAttempts: %d", c.CreationAttempts)
|
||||
logger.Debugf("creationCooldown: %s", c.CreationCooldown)
|
||||
logger.Infof("fields: %s", c.Fields)
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("QueryLogConfig", func() {
|
||||
var (
|
||||
cfg QueryLogConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = QueryLogConfig{
|
||||
Target: "/dev/null",
|
||||
Type: QueryLogTypeCsvClient,
|
||||
LogRetentionDays: 0,
|
||||
CreationAttempts: 1,
|
||||
CreationCooldown: Duration(time.Millisecond),
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be true by default", func() {
|
||||
cfg := QueryLogConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := QueryLogConfig{
|
||||
Type: QueryLogTypeNone,
|
||||
}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("logRetentionDays:")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetDefaults", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg := QueryLogConfig{}
|
||||
Expect(cfg.Fields).Should(BeEmpty())
|
||||
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.Fields).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,27 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RewriterConfig custom DNS configuration
|
||||
type RewriterConfig struct {
|
||||
Rewrite map[string]string `yaml:"rewrite"`
|
||||
FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *RewriterConfig) IsEnabled() bool {
|
||||
return len(c.Rewrite) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *RewriterConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("fallbackUpstream = %t", c.FallbackUpstream)
|
||||
|
||||
logger.Info("rules:")
|
||||
|
||||
for key, val := range c.Rewrite {
|
||||
logger.Infof(" %s = %s", key, val)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RewriterConfig", func() {
|
||||
var (
|
||||
cfg RewriterConfig
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = RewriterConfig{
|
||||
Rewrite: map[string]string{
|
||||
"original1": "rewritten1",
|
||||
"original2": "rewritten2",
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("should be false by default", func() {
|
||||
cfg := RewriterConfig{}
|
||||
Expect(defaults.Set(&cfg)).Should(Succeed())
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
|
||||
When("enabled", func() {
|
||||
It("should be true", func() {
|
||||
Expect(cfg.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
When("disabled", func() {
|
||||
It("should be false", func() {
|
||||
cfg := RewriterConfig{}
|
||||
|
||||
Expect(cfg.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("rules:")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("original2 =")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,164 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var validDomain = regexp.MustCompile(
|
||||
`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
|
||||
// Upstream is the definition of external DNS server
|
||||
type Upstream struct {
|
||||
Net NetProtocol
|
||||
Host string
|
||||
Port uint16
|
||||
Path string
|
||||
CommonName string // Common Name to use for certificate verification; optional. "" uses .Host
|
||||
}
|
||||
|
||||
// IsDefault returns true if u is the default value
|
||||
func (u *Upstream) IsDefault() bool {
|
||||
return *u == Upstream{}
|
||||
}
|
||||
|
||||
// String returns the string representation of u
|
||||
func (u Upstream) String() string {
|
||||
if u.IsDefault() {
|
||||
return "no upstream"
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(u.Net.String())
|
||||
sb.WriteRune(':')
|
||||
|
||||
if u.Net == NetProtocolHttps {
|
||||
sb.WriteString("//")
|
||||
}
|
||||
|
||||
isIPv6 := strings.ContainsRune(u.Host, ':')
|
||||
if isIPv6 {
|
||||
sb.WriteRune('[')
|
||||
sb.WriteString(u.Host)
|
||||
sb.WriteRune(']')
|
||||
} else {
|
||||
sb.WriteString(u.Host)
|
||||
}
|
||||
|
||||
if u.Port != netDefaultPort[u.Net] {
|
||||
sb.WriteRune(':')
|
||||
sb.WriteString(fmt.Sprint(u.Port))
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
sb.WriteString(u.Path)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// UnmarshalText implements `encoding.TextUnmarshaler`.
|
||||
func (u *Upstream) UnmarshalText(data []byte) error {
|
||||
s := string(data)
|
||||
|
||||
upstream, err := ParseUpstream(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't convert upstream '%s': %w", s, err)
|
||||
}
|
||||
|
||||
*u = upstream
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseUpstream creates new Upstream from passed string in format [net]:host[:port][/path][#commonname]
|
||||
func ParseUpstream(upstream string) (Upstream, error) {
|
||||
var path string
|
||||
|
||||
var port uint16
|
||||
|
||||
commonName, upstream := extractCommonName(upstream)
|
||||
|
||||
n, upstream := extractNet(upstream)
|
||||
|
||||
path, upstream = extractPath(upstream)
|
||||
|
||||
host, portString, err := net.SplitHostPort(upstream)
|
||||
|
||||
// string contains host:port
|
||||
if err == nil {
|
||||
p, err := ConvertPort(portString)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
|
||||
|
||||
return Upstream{}, err
|
||||
}
|
||||
|
||||
port = p
|
||||
} else {
|
||||
// only host, use default port
|
||||
host = upstream
|
||||
port = netDefaultPort[n]
|
||||
|
||||
// trim any IPv6 brackets
|
||||
host = strings.TrimPrefix(host, "[")
|
||||
host = strings.TrimSuffix(host, "]")
|
||||
}
|
||||
|
||||
// validate hostname or ip
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
// is not IP
|
||||
if !validDomain.MatchString(host) {
|
||||
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
|
||||
}
|
||||
}
|
||||
|
||||
return Upstream{
|
||||
Net: n,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Path: path,
|
||||
CommonName: commonName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractCommonName(in string) (string, string) {
|
||||
upstream, cn, _ := strings.Cut(in, "#")
|
||||
|
||||
return cn, upstream
|
||||
}
|
||||
|
||||
func extractPath(in string) (path, upstream string) {
|
||||
slashIdx := strings.Index(in, "/")
|
||||
|
||||
if slashIdx >= 0 {
|
||||
path = in[slashIdx:]
|
||||
upstream = in[:slashIdx]
|
||||
} else {
|
||||
upstream = in
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func extractNet(upstream string) (NetProtocol, string) {
|
||||
tcpUDPPrefix := NetProtocolTcpUdp.String() + ":"
|
||||
if strings.HasPrefix(upstream, tcpUDPPrefix) {
|
||||
return NetProtocolTcpUdp, upstream[len(tcpUDPPrefix):]
|
||||
}
|
||||
|
||||
tcpTLSPrefix := NetProtocolTcpTls.String() + ":"
|
||||
if strings.HasPrefix(upstream, tcpTLSPrefix) {
|
||||
return NetProtocolTcpTls, upstream[len(tcpTLSPrefix):]
|
||||
}
|
||||
|
||||
httpsPrefix := NetProtocolHttps.String() + ":"
|
||||
if strings.HasPrefix(upstream, httpsPrefix) {
|
||||
return NetProtocolHttps, strings.TrimPrefix(upstream[len(httpsPrefix):], "//")
|
||||
}
|
||||
|
||||
return NetProtocolTcpUdp, upstream
|
||||
}
|
1
go.mod
1
go.mod
|
@ -123,6 +123,7 @@ require (
|
|||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.0 // indirect
|
||||
golang.org/x/crypto v0.6.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230307190834-24139beb5833
|
||||
golang.org/x/mod v0.8.0 // indirect
|
||||
golang.org/x/sys v0.6.0 // indirect
|
||||
golang.org/x/term v0.6.0 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -475,6 +475,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
|||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
|
||||
golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hako/durafmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -38,9 +37,6 @@ type ListCacheType int
|
|||
type Matcher interface {
|
||||
// Match matches passed domain name against cached list entries
|
||||
Match(domain string, groupsToCheck []string) (found bool, group string)
|
||||
|
||||
// Configuration returns current configuration and stats
|
||||
Configuration() []string
|
||||
}
|
||||
|
||||
// ListCache generic cache of strings divided in groups
|
||||
|
@ -55,42 +51,19 @@ type ListCache struct {
|
|||
processingConcurrency uint
|
||||
}
|
||||
|
||||
// Configuration returns current configuration and stats
|
||||
func (b *ListCache) Configuration() (result []string) {
|
||||
if b.refreshPeriod > 0 {
|
||||
result = append(result, fmt.Sprintf("refresh period: %s", durafmt.Parse(b.refreshPeriod)))
|
||||
} else {
|
||||
result = append(result, "refresh: disabled")
|
||||
}
|
||||
|
||||
result = append(result, "group links:")
|
||||
for group, links := range b.groupToLinks {
|
||||
result = append(result, fmt.Sprintf(" %s:", group))
|
||||
|
||||
for _, link := range links {
|
||||
if strings.Contains(link, "\n") {
|
||||
link = "[INLINE DEFINITION]"
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf(" - %s", link))
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, "group caches:")
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (b *ListCache) LogConfig(logger *logrus.Entry) {
|
||||
var total int
|
||||
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
|
||||
for group, cache := range b.groupCaches {
|
||||
result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount()))
|
||||
logger.Infof("%s: %d entries", group, cache.ElementCount())
|
||||
total += cache.ElementCount()
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf(" TOTAL: %d entries", total))
|
||||
|
||||
return result
|
||||
logger.Infof("TOTAL: %d entries", total)
|
||||
}
|
||||
|
||||
// NewListCache creates new list instance
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
|
@ -14,7 +13,9 @@ import (
|
|||
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -414,36 +415,31 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
Describe("Configuration", func() {
|
||||
When("refresh is enabled", func() {
|
||||
It("should print list configuration", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL},
|
||||
"gr2": {inlineList("inline", "definition")},
|
||||
}
|
||||
Describe("LogConfig", func() {
|
||||
var (
|
||||
logger *logrus.Entry
|
||||
hook *log.MockLoggerHook
|
||||
)
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
|
||||
defaultProcessingConcurrency, false)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement("refresh period: 1 hour"))
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
BeforeEach(func() {
|
||||
logger, hook = log.NewMockEntry()
|
||||
})
|
||||
When("refresh is disabled", func() {
|
||||
It("should print 'refresh disabled'", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {emptyFile.Path},
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(),
|
||||
defaultProcessingConcurrency, false)
|
||||
Expect(err).Should(Succeed())
|
||||
It("should print list configuration", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL},
|
||||
"gr2": {inlineList("inline", "definition")},
|
||||
}
|
||||
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement("refresh: disabled"))
|
||||
})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
|
||||
defaultProcessingConcurrency, false)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
sut.LogConfig(logger)
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr1:")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("gr2:")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("TOTAL:")))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -486,7 +482,7 @@ func (m *MockDownloader) ListSource() string {
|
|||
func createTestListFile(dir string, totalLines int) (string, int) {
|
||||
file, err := os.CreateTemp(dir, "blocky")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
log.Log().Fatal(err)
|
||||
}
|
||||
|
||||
w := bufio.NewWriter(file)
|
||||
|
|
|
@ -6,9 +6,11 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
prefixed "github.com/x-cray/logrus-prefixed-formatter"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const prefixField = "prefix"
|
||||
|
@ -118,3 +120,50 @@ func ConfigureLogger(cfg *Config) {
|
|||
func Silence() {
|
||||
logger.Out = io.Discard
|
||||
}
|
||||
|
||||
func WithIndent(log *logrus.Entry, prefix string, callback func(*logrus.Entry)) {
|
||||
undo := indentMessages(prefix, log.Logger)
|
||||
defer undo()
|
||||
|
||||
callback(log)
|
||||
}
|
||||
|
||||
// indentMessages modifies a logger and adds `prefix` to all messages.
|
||||
//
|
||||
// The returned function must be called to remove the prefix.
|
||||
func indentMessages(prefix string, logger *logrus.Logger) func() {
|
||||
if _, ok := logger.Formatter.(*prefixed.TextFormatter); !ok {
|
||||
// log is not plaintext, do nothing
|
||||
return func() {}
|
||||
}
|
||||
|
||||
oldHooks := maps.Clone(logger.Hooks)
|
||||
|
||||
logger.AddHook(prefixMsgHook{
|
||||
prefix: prefix,
|
||||
})
|
||||
|
||||
var once sync.Once
|
||||
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
logger.ReplaceHooks(oldHooks)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type prefixMsgHook struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
// Levels implements `logrus.Hook`.
|
||||
func (h prefixMsgHook) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
}
|
||||
|
||||
// Fire implements `logrus.Hook`.
|
||||
func (h prefixMsgHook) Fire(entry *logrus.Entry) error {
|
||||
entry.Message = h.prefix + entry.Message
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func NewMockEntry() (*logrus.Entry, *MockLoggerHook) {
|
||||
logger := logrus.New()
|
||||
logger.Out = io.Discard
|
||||
|
||||
entry := logrus.Entry{Logger: logger}
|
||||
hook := MockLoggerHook{}
|
||||
|
||||
entry.Logger.AddHook(&hook)
|
||||
|
||||
hook.On("Fire", mock.Anything).Return(nil)
|
||||
|
||||
return &entry, &hook
|
||||
}
|
||||
|
||||
type MockLoggerHook struct {
|
||||
mock.Mock
|
||||
|
||||
Messages []string
|
||||
}
|
||||
|
||||
// Levels implements `logrus.Hook`.
|
||||
func (h *MockLoggerHook) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
}
|
||||
|
||||
// Fire implements `logrus.Hook`.
|
||||
func (h *MockLoggerHook) Fire(entry *logrus.Entry) error {
|
||||
_ = h.Called()
|
||||
|
||||
h.Messages = append(h.Messages, entry.Message)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -18,7 +18,7 @@ func RegisterMetric(c prometheus.Collector) {
|
|||
}
|
||||
|
||||
// Start starts prometheus endpoint
|
||||
func Start(router *chi.Mux, cfg config.PrometheusConfig) {
|
||||
func Start(router *chi.Mux, cfg config.MetricsConfig) {
|
||||
if cfg.Enable {
|
||||
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||
_ = reg.Register(collectors.NewGoCollector())
|
||||
|
|
|
@ -22,6 +22,31 @@ import (
|
|||
// )
|
||||
type ResponseType int
|
||||
|
||||
func (t ResponseType) ToExtendedErrorCode() uint16 {
|
||||
switch t {
|
||||
case ResponseTypeRESOLVED:
|
||||
return dns.ExtendedErrorCodeOther
|
||||
case ResponseTypeCACHED:
|
||||
return dns.ExtendedErrorCodeCachedError
|
||||
case ResponseTypeCONDITIONAL:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case ResponseTypeCUSTOMDNS:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case ResponseTypeHOSTSFILE:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case ResponseTypeNOTFQDN:
|
||||
return dns.ExtendedErrorCodeBlocked
|
||||
case ResponseTypeBLOCKED:
|
||||
return dns.ExtendedErrorCodeBlocked
|
||||
case ResponseTypeFILTERED:
|
||||
return dns.ExtendedErrorCodeFiltered
|
||||
case ResponseTypeSPECIAL:
|
||||
return dns.ExtendedErrorCodeFiltered
|
||||
default:
|
||||
return dns.ExtendedErrorCodeOther
|
||||
}
|
||||
}
|
||||
|
||||
// Response represents the response of a DNS query
|
||||
type Response struct {
|
||||
Res *dns.Msg
|
||||
|
|
|
@ -83,7 +83,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
Password: cfg.Password,
|
||||
DB: cfg.Database,
|
||||
MaxRetries: cfg.ConnectionAttempts,
|
||||
MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown),
|
||||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||
})
|
||||
} else {
|
||||
rdb = redis.NewClient(&redis.Options{
|
||||
|
@ -92,7 +92,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
Password: cfg.Password,
|
||||
DB: cfg.Database,
|
||||
MaxRetries: cfg.ConnectionAttempts,
|
||||
MaxRetryBackoff: time.Duration(cfg.ConnectionCooldown),
|
||||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
|
|||
return nxDomainBlockHandler{}, nil
|
||||
}
|
||||
|
||||
blockTime := uint32(time.Duration(cfg.BlockTTL).Seconds())
|
||||
blockTime := cfg.BlockTTL.SecondsU32()
|
||||
|
||||
if strings.EqualFold(cfgBlockType, "ZEROIP") {
|
||||
return zeroIPBlockHandler{
|
||||
|
@ -78,16 +78,17 @@ type status struct {
|
|||
|
||||
// BlockingResolver checks request's question (domain name) against black and white lists
|
||||
type BlockingResolver struct {
|
||||
configurable[*config.BlockingConfig]
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
blacklistMatcher *lists.ListCache
|
||||
whitelistMatcher *lists.ListCache
|
||||
cfg config.BlockingConfig
|
||||
blockHandler blockHandler
|
||||
whitelistOnlyGroups map[string]bool
|
||||
status *status
|
||||
clientGroupsBlock map[string][]string
|
||||
redisClient *redis.Client
|
||||
redisEnabled bool
|
||||
fqdnIPCache expirationcache.ExpiringCache
|
||||
}
|
||||
|
||||
|
@ -100,7 +101,7 @@ func NewBlockingResolver(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
refreshPeriod := time.Duration(cfg.RefreshPeriod)
|
||||
refreshPeriod := cfg.RefreshPeriod.ToDuration()
|
||||
downloader := createDownloader(cfg, bootstrap)
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists,
|
||||
refreshPeriod, downloader, cfg.ProcessingConcurrency,
|
||||
|
@ -129,8 +130,10 @@ func NewBlockingResolver(
|
|||
}
|
||||
|
||||
res := &BlockingResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("blocking"),
|
||||
|
||||
blockHandler: blockHandler,
|
||||
cfg: cfg,
|
||||
blacklistMatcher: blacklistMatcher,
|
||||
whitelistMatcher: whitelistMatcher,
|
||||
whitelistOnlyGroups: whitelistOnlyGroups,
|
||||
|
@ -140,10 +143,9 @@ func NewBlockingResolver(
|
|||
},
|
||||
clientGroupsBlock: cgb,
|
||||
redisClient: redis,
|
||||
redisEnabled: (redis != nil),
|
||||
}
|
||||
|
||||
if res.redisEnabled {
|
||||
if res.redisClient != nil {
|
||||
setupRedisEnabledSubscriber(res)
|
||||
}
|
||||
|
||||
|
@ -156,27 +158,25 @@ func NewBlockingResolver(
|
|||
|
||||
func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader {
|
||||
return lists.NewDownloader(
|
||||
lists.WithTimeout(time.Duration(cfg.DownloadTimeout)),
|
||||
lists.WithTimeout(cfg.DownloadTimeout.ToDuration()),
|
||||
lists.WithAttempts(cfg.DownloadAttempts),
|
||||
lists.WithCooldown(time.Duration(cfg.DownloadCooldown)),
|
||||
lists.WithCooldown(cfg.DownloadCooldown.ToDuration()),
|
||||
lists.WithTransport(bootstrap.NewHTTPTransport()),
|
||||
)
|
||||
}
|
||||
|
||||
func setupRedisEnabledSubscriber(c *BlockingResolver) {
|
||||
logger := log.PrefixedLog("blocking_resolver")
|
||||
|
||||
go func() {
|
||||
for em := range c.redisClient.EnabledChannel {
|
||||
if em != nil {
|
||||
logger.Debug("Received state from redis: ", em)
|
||||
c.log().Debug("Received state from redis: ", em)
|
||||
|
||||
if em.State {
|
||||
c.internalEnableBlocking()
|
||||
} else {
|
||||
err := c.internalDisableBlocking(em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
logger.Warn("Blocking couldn't be disabled:", err)
|
||||
c.log().Warn("Blocking couldn't be disabled:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -213,7 +213,7 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
|
|||
func (r *BlockingResolver) EnableBlocking() {
|
||||
r.internalEnableBlocking()
|
||||
|
||||
if r.redisEnabled {
|
||||
if r.redisClient != nil {
|
||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true})
|
||||
}
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ func (r *BlockingResolver) internalEnableBlocking() {
|
|||
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
|
||||
func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error {
|
||||
err := r.internalDisableBlocking(duration, disableGroups)
|
||||
if err == nil && r.redisEnabled {
|
||||
if err == nil && r.redisClient != nil {
|
||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{
|
||||
State: false,
|
||||
Duration: duration,
|
||||
|
@ -329,39 +329,15 @@ func (r *BlockingResolver) handleBlocked(logger *logrus.Entry,
|
|||
return &model.Response{Res: response, RType: model.ResponseTypeBLOCKED, Reason: reason}, nil
|
||||
}
|
||||
|
||||
// Configuration returns the current resolver configuration
|
||||
func (r *BlockingResolver) Configuration() (result []string) {
|
||||
if len(r.cfg.ClientGroupsBlock) == 0 {
|
||||
return configDisabled
|
||||
}
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *BlockingResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.cfg.LogConfig(logger)
|
||||
|
||||
result = append(result, "clientGroupsBlock")
|
||||
for key, val := range r.cfg.ClientGroupsBlock {
|
||||
result = append(result, fmt.Sprintf(" %s = \"%s\"", key, strings.Join(val, ";")))
|
||||
}
|
||||
logger.Info("blacklist cache entries:")
|
||||
log.WithIndent(logger, " ", r.blacklistMatcher.LogConfig)
|
||||
|
||||
blockType := r.cfg.BlockType
|
||||
result = append(result, fmt.Sprintf("blockType = \"%s\"", blockType))
|
||||
|
||||
if blockType != "NXDOMAIN" {
|
||||
result = append(result, fmt.Sprintf("blockTTL = %s", r.cfg.BlockTTL.String()))
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf("downloadTimeout = %s", r.cfg.DownloadTimeout.String()))
|
||||
|
||||
result = append(result, fmt.Sprintf("FailStartOnListError = %t", r.cfg.FailStartOnListError))
|
||||
|
||||
result = append(result, "blacklist:")
|
||||
for _, c := range r.blacklistMatcher.Configuration() {
|
||||
result = append(result, fmt.Sprintf(" %s", c))
|
||||
}
|
||||
|
||||
result = append(result, "whitelist:")
|
||||
for _, c := range r.whitelistMatcher.Configuration() {
|
||||
result = append(result, fmt.Sprintf(" %s", c))
|
||||
}
|
||||
|
||||
return result
|
||||
logger.Info("whitelist cache entries:")
|
||||
log.WithIndent(logger, " ", r.whitelistMatcher.LogConfig)
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool {
|
||||
|
@ -594,7 +570,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
|
|||
}
|
||||
|
||||
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (result []net.IP, ttl time.Duration) {
|
||||
prefixedLog := log.PrefixedLog("FQDNClientIdentifierCache")
|
||||
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
|
||||
|
||||
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
resp, err := r.next.Resolve(&model.Request{
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
. "github.com/0xERR0R/blocky/evt"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/lists"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/redis"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -51,7 +52,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
|
@ -71,6 +76,22 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Events", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
|
@ -1086,35 +1107,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{"gr1": {group1File.Path}},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{}
|
||||
})
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Create resolver with wrong parameter", func() {
|
||||
When("Wrong blockType is used", func() {
|
||||
It("should return error", func() {
|
||||
|
|
|
@ -62,7 +62,11 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
return b, nil
|
||||
}
|
||||
|
||||
parallelResolver, err := newParallelBestResolver(bootstraped.ResolverGroups())
|
||||
// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
|
||||
// where `ParallelBestResolver` uses its config, we can just use an empty one.
|
||||
pbCfg := config.ParallelBestConfig{}
|
||||
|
||||
parallelResolver, err := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create bootstrap ParallelBestResolver: %w", err)
|
||||
}
|
||||
|
@ -74,20 +78,16 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
cachingCfg := cfg.Caching
|
||||
cachingCfg.EnablePrefetch()
|
||||
|
||||
if cachingCfg.MinCachingTime == 0 {
|
||||
if cachingCfg.MinCachingTime.IsZero() {
|
||||
// Set a min time in case the user didn't to avoid prefetching too often
|
||||
cachingCfg.MinCachingTime = config.Duration(time.Hour)
|
||||
}
|
||||
|
||||
b.bootstraped = bootstraped
|
||||
|
||||
cachingResolver := NewCachingResolver(cachingCfg, nil)
|
||||
// don't emit any metrics
|
||||
cachingResolver.emitMetricEvents = false
|
||||
|
||||
b.resolver = Chain(
|
||||
NewFilteringResolver(cfg.Filtering),
|
||||
cachingResolver,
|
||||
newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones
|
||||
parallelResolver,
|
||||
)
|
||||
|
||||
|
@ -116,10 +116,10 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
|||
ctx := context.Background()
|
||||
|
||||
timeout := cfg.UpstreamTimeout
|
||||
if timeout != 0 {
|
||||
if timeout.IsZero() {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration())
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
|
|
|
@ -282,7 +282,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
||||
Expect(ips).Should(HaveLen(0))
|
||||
Expect(ips).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -296,7 +296,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
||||
Expect(ips).Should(HaveLen(0))
|
||||
Expect(ips).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hako/durafmt"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/expirationcache"
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
|
@ -24,16 +22,15 @@ const defaultCachingCleanUpInterval = 5 * time.Second
|
|||
// CachingResolver caches answers from dns queries with their TTL time,
|
||||
// to avoid external resolver calls for recurrent queries
|
||||
type CachingResolver struct {
|
||||
configurable[*config.CachingConfig]
|
||||
NextResolver
|
||||
minCacheTimeSec, maxCacheTimeSec int
|
||||
cacheTimeNegative time.Duration
|
||||
resultCache expirationcache.ExpiringCache
|
||||
prefetchExpires time.Duration
|
||||
prefetchThreshold int
|
||||
prefetchingNameCache expirationcache.ExpiringCache
|
||||
redisClient *redis.Client
|
||||
redisEnabled bool
|
||||
emitMetricEvents bool
|
||||
typed
|
||||
|
||||
emitMetricEvents bool // disabled by Bootstrap
|
||||
|
||||
resultCache expirationcache.ExpiringCache
|
||||
prefetchingNameCache expirationcache.ExpiringCache
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
// cacheValue includes query answer and prefetch flag
|
||||
|
@ -44,18 +41,21 @@ type cacheValue struct {
|
|||
|
||||
// NewCachingResolver creates a new resolver instance
|
||||
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
|
||||
return newCachingResolver(cfg, redis, true)
|
||||
}
|
||||
|
||||
func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver {
|
||||
c := &CachingResolver{
|
||||
minCacheTimeSec: int(time.Duration(cfg.MinCachingTime).Seconds()),
|
||||
maxCacheTimeSec: int(time.Duration(cfg.MaxCachingTime).Seconds()),
|
||||
cacheTimeNegative: time.Duration(cfg.CacheTimeNegative),
|
||||
redisClient: redis,
|
||||
redisEnabled: (redis != nil),
|
||||
emitMetricEvents: true,
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("caching"),
|
||||
|
||||
redisClient: redis,
|
||||
emitMetricEvents: emitMetricEvents,
|
||||
}
|
||||
|
||||
configureCaches(c, &cfg)
|
||||
|
||||
if c.redisEnabled {
|
||||
if c.redisClient != nil {
|
||||
setupRedisCacheSubscriber(c)
|
||||
c.redisClient.GetRedisCache()
|
||||
}
|
||||
|
@ -68,10 +68,6 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
|||
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
|
||||
|
||||
if cfg.Prefetching {
|
||||
c.prefetchExpires = time.Duration(cfg.PrefetchExpires)
|
||||
|
||||
c.prefetchThreshold = cfg.PrefetchThreshold
|
||||
|
||||
c.prefetchingNameCache = expirationcache.NewCache(
|
||||
expirationcache.WithCleanUpInterval(time.Minute),
|
||||
expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount)),
|
||||
|
@ -88,12 +84,10 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
|||
}
|
||||
|
||||
func setupRedisCacheSubscriber(c *CachingResolver) {
|
||||
logger := log.PrefixedLog("caching_resolver")
|
||||
|
||||
go func() {
|
||||
for rc := range c.redisClient.CacheChannel {
|
||||
if rc != nil {
|
||||
logger.Debug("Received key from redis: ", rc.Key)
|
||||
c.log().Debug("Received key from redis: ", rc.Key)
|
||||
c.putInCache(rc.Key, rc.Response, false, false)
|
||||
}
|
||||
}
|
||||
|
@ -102,22 +96,22 @@ func setupRedisCacheSubscriber(c *CachingResolver) {
|
|||
|
||||
// check if domain was queried > threshold in the time window
|
||||
func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
|
||||
if r.prefetchThreshold == 0 {
|
||||
if r.cfg.PrefetchThreshold == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
||||
|
||||
return cnt != nil && cnt.(int) > r.prefetchThreshold
|
||||
return cnt != nil && cnt.(int) > r.cfg.PrefetchThreshold
|
||||
}
|
||||
|
||||
func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
|
||||
logger := log.PrefixedLog("caching_resolver")
|
||||
|
||||
if r.shouldPrefetch(cacheKey) {
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType.String())
|
||||
logger := r.log()
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
|
||||
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
||||
response, err := r.next.Resolve(req)
|
||||
|
@ -136,29 +130,11 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
|
|||
return nil, 0
|
||||
}
|
||||
|
||||
// Configuration returns a current resolver configuration
|
||||
func (r *CachingResolver) Configuration() (result []string) {
|
||||
if r.maxCacheTimeSec < 0 {
|
||||
return configDisabled
|
||||
}
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.cfg.LogConfig(logger)
|
||||
|
||||
result = append(result, fmt.Sprintf("minCacheTimeInSec = %d", r.minCacheTimeSec))
|
||||
|
||||
result = append(result, fmt.Sprintf("maxCacheTimeSec = %d", r.maxCacheTimeSec))
|
||||
|
||||
result = append(result, fmt.Sprintf("cacheTimeNegative = %s", durafmt.Parse(r.cacheTimeNegative)))
|
||||
|
||||
result = append(result, fmt.Sprintf("prefetching = %t", r.prefetchingNameCache != nil))
|
||||
|
||||
if r.prefetchingNameCache != nil {
|
||||
result = append(result, fmt.Sprintf("prefetchExpires = %s", durafmt.Parse(r.prefetchExpires)))
|
||||
|
||||
result = append(result, fmt.Sprintf("prefetchThreshold = %d", r.prefetchThreshold))
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.TotalCount()))
|
||||
|
||||
return
|
||||
logger.Infof("cache entries = %d", r.resultCache.TotalCount())
|
||||
}
|
||||
|
||||
// Resolve checks if the current query result is already in the cache and returns it
|
||||
|
@ -166,7 +142,7 @@ func (r *CachingResolver) Configuration() (result []string) {
|
|||
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||
logger := log.WithPrefix(request.Log, "caching_resolver")
|
||||
|
||||
if r.maxCacheTimeSec < 0 {
|
||||
if r.cfg.MaxCachingTime < 0 {
|
||||
logger.Debug("skip cache")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
|
@ -214,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
|||
response, err = r.next.Resolve(request)
|
||||
|
||||
if err == nil {
|
||||
r.putInCache(cacheKey, response, false, r.redisEnabled)
|
||||
r.putInCache(cacheKey, response, false, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -228,7 +204,7 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log
|
|||
domainCount = x.(int)
|
||||
}
|
||||
domainCount++
|
||||
r.prefetchingNameCache.Put(cacheKey, domainCount, r.prefetchExpires)
|
||||
r.prefetchingNameCache.Put(cacheKey, domainCount, r.cfg.PrefetchExpires.ToDuration())
|
||||
totalCount := r.prefetchingNameCache.TotalCount()
|
||||
|
||||
logger.Debugf("domain '%s' was requested %d times, "+
|
||||
|
@ -242,9 +218,9 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
|
|||
// put value into cache
|
||||
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer))
|
||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||
if r.cacheTimeNegative > 0 {
|
||||
if r.cfg.CacheTimeNegative > 0 {
|
||||
// put negative cache if result code is NXDOMAIN
|
||||
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cacheTimeNegative)
|
||||
r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,20 +240,20 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
|
|||
var max uint32
|
||||
|
||||
if len(answer) == 0 {
|
||||
return r.cacheTimeNegative
|
||||
return r.cfg.CacheTimeNegative.ToDuration()
|
||||
}
|
||||
|
||||
for _, a := range answer {
|
||||
// if TTL < mitTTL -> adjust the value, set minTTL
|
||||
if r.minCacheTimeSec > 0 {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) < uint32(r.minCacheTimeSec) {
|
||||
atomic.StoreUint32(&a.Header().Ttl, uint32(r.minCacheTimeSec))
|
||||
if r.cfg.MinCachingTime > 0 {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
if r.maxCacheTimeSec > 0 {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) > uint32(r.maxCacheTimeSec) {
|
||||
atomic.StoreUint32(&a.Header().Ttl, uint32(r.maxCacheTimeSec))
|
||||
if r.cfg.MaxCachingTime > 0 {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/redis"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -27,6 +28,12 @@ var _ = Describe("CachingResolver", func() {
|
|||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.CachingConfig{}
|
||||
if err := defaults.Set(&sutConfig); err != nil {
|
||||
|
@ -42,6 +49,22 @@ var _ = Describe("CachingResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Caching responses", func() {
|
||||
When("prefetching is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
|
@ -103,6 +126,15 @@ var _ = Describe("CachingResolver", func() {
|
|||
HaveTTL(BeNumerically("<=", 2))))
|
||||
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
|
||||
})
|
||||
When("threshold is 0", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.PrefetchThreshold = 0
|
||||
})
|
||||
|
||||
It("should always prefetch", func() {
|
||||
Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
When("min caching time is defined", func() {
|
||||
BeforeEach(func() {
|
||||
|
@ -364,6 +396,10 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
It("response should be cached", func() {
|
||||
By("default config should enable negative caching", func() {
|
||||
Expect(sutConfig.CacheTimeNegative).Should(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Should(SatisfyAll(
|
||||
|
@ -495,43 +531,6 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.CachingConfig{}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.CachingConfig{
|
||||
MaxCachingTime: config.Duration(time.Minute * -1),
|
||||
}
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
|
||||
When("prefetching is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.CachingConfig{
|
||||
Prefetching: true,
|
||||
}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
Expect(c).Should(ContainElement(ContainSubstring("prefetchThreshold")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Redis is configured", func() {
|
||||
var (
|
||||
redisServer *miniredis.Miniredis
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -18,11 +17,12 @@ import (
|
|||
|
||||
// ClientNamesResolver tries to determine client name by asking responsible DNS server via rDNS (reverse lookup)
|
||||
type ClientNamesResolver struct {
|
||||
configurable[*config.ClientLookupConfig]
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
cache expirationcache.ExpiringCache
|
||||
externalResolver Resolver
|
||||
singleNameOrder []uint
|
||||
clientIPMapping map[string][]net.IP
|
||||
NextResolver
|
||||
}
|
||||
|
||||
// NewClientNamesResolver creates new resolver instance
|
||||
|
@ -38,38 +38,21 @@ func NewClientNamesResolver(
|
|||
}
|
||||
|
||||
cr = &ClientNamesResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("client_names"),
|
||||
|
||||
cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval(time.Hour)),
|
||||
externalResolver: r,
|
||||
singleNameOrder: cfg.SingleNameOrder,
|
||||
clientIPMapping: cfg.ClientnameIPMapping,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Configuration returns current resolver configuration
|
||||
func (r *ClientNamesResolver) Configuration() (result []string) {
|
||||
if r.externalResolver == nil && len(r.clientIPMapping) == 0 {
|
||||
return append(configDisabled, "use only IP address")
|
||||
}
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.cfg.LogConfig(logger)
|
||||
|
||||
result = append(result, fmt.Sprintf("singleNameOrder = \"%v\"", r.singleNameOrder))
|
||||
|
||||
if r.externalResolver != nil {
|
||||
result = append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver))
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf("cache item count = %d", r.cache.TotalCount()))
|
||||
|
||||
if len(r.clientIPMapping) > 0 {
|
||||
result = append(result, "client IP mapping:")
|
||||
|
||||
for k, v := range r.clientIPMapping {
|
||||
result = append(result, fmt.Sprintf("%s -> %s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
logger.Infof("cache entries = %d", r.cache.TotalCount())
|
||||
}
|
||||
|
||||
// Resolve tries to resolve the client name from the ip address
|
||||
|
@ -89,13 +72,11 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
|||
}
|
||||
|
||||
ip := request.ClientIP
|
||||
|
||||
if ip == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
c, _ := r.cache.Get(ip.String())
|
||||
|
||||
if c != nil {
|
||||
if t, ok := c.([]string); ok {
|
||||
return t
|
||||
|
@ -103,6 +84,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
|||
}
|
||||
|
||||
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
|
||||
|
||||
r.cache.Put(ip.String(), names, time.Hour)
|
||||
|
||||
return names
|
||||
|
@ -127,7 +109,6 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
|
|||
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
||||
// try client mapping first
|
||||
result = r.getNameFromIPMapping(ip, result)
|
||||
|
||||
if len(result) > 0 {
|
||||
return
|
||||
}
|
||||
|
@ -151,8 +132,8 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
|||
clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip)
|
||||
|
||||
// optional: if singleNameOrder is set, use only one name in the defined order
|
||||
if len(r.singleNameOrder) > 0 {
|
||||
for _, i := range r.singleNameOrder {
|
||||
if len(r.cfg.SingleNameOrder) > 0 {
|
||||
for _, i := range r.cfg.SingleNameOrder {
|
||||
if i > 0 && int(i) <= len(clientNames) {
|
||||
result = []string{clientNames[i-1]}
|
||||
|
||||
|
@ -169,7 +150,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
|||
}
|
||||
|
||||
func (r *ClientNamesResolver) getNameFromIPMapping(ip net.IP, result []string) []string {
|
||||
for name, ips := range r.clientIPMapping {
|
||||
for name, ips := range r.cfg.ClientnameIPMapping {
|
||||
for _, i := range ips {
|
||||
if ip.String() == i.String() {
|
||||
result = append(result, name)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
|
@ -22,6 +23,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
m *mockResolver
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
res, err := NewClientNamesResolver(sutConfig, nil, false)
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -31,6 +38,22 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolve client name from request clientID", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.ClientLookupConfig{}
|
||||
|
@ -281,7 +304,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
})
|
||||
When("Upstream produces error", func() {
|
||||
BeforeEach(func() {
|
||||
JustBeforeEach(func() {
|
||||
sutConfig = config.ClientLookupConfig{}
|
||||
clientMockResolver := &mockResolver{}
|
||||
clientMockResolver.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
|
||||
|
@ -348,32 +371,4 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.ClientLookupConfig{
|
||||
Upstream: config.Upstream{Net: config.NetProtocolTcpUdp, Host: "host"},
|
||||
SingleNameOrder: []uint{1, 2},
|
||||
ClientnameIPMapping: map[string][]net.IP{
|
||||
"client8": {net.ParseIP("1.2.3.5")},
|
||||
},
|
||||
}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.ClientLookupConfig{}
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -15,7 +14,10 @@ import (
|
|||
|
||||
// ConditionalUpstreamResolver delegates DNS question to other DNS resolver dependent on domain name in question
|
||||
type ConditionalUpstreamResolver struct {
|
||||
configurable[*config.ConditionalUpstreamConfig]
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
mapping map[string]Resolver
|
||||
}
|
||||
|
||||
|
@ -26,10 +28,13 @@ func NewConditionalUpstreamResolver(
|
|||
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
||||
|
||||
for domain, upstream := range cfg.Mapping.Upstreams {
|
||||
upstreams := make(map[string][]config.Upstream)
|
||||
upstreams[upstreamDefaultCfgName] = upstream
|
||||
pbCfg := config.ParallelBestConfig{
|
||||
ExternalResolvers: config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: upstream,
|
||||
},
|
||||
}
|
||||
|
||||
r, err := NewParallelBestResolver(upstreams, bootstrap, shouldVerifyUpstreams)
|
||||
r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,20 +42,14 @@ func NewConditionalUpstreamResolver(
|
|||
m[strings.ToLower(domain)] = r
|
||||
}
|
||||
|
||||
return &ConditionalUpstreamResolver{mapping: m}, nil
|
||||
}
|
||||
r := ConditionalUpstreamResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("conditional_upstream"),
|
||||
|
||||
// Configuration returns current configuration
|
||||
func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
|
||||
if len(r.mapping) == 0 {
|
||||
return configDisabled
|
||||
mapping: m,
|
||||
}
|
||||
|
||||
for key, val := range r.mapping {
|
||||
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
|
||||
}
|
||||
|
||||
return
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
|
||||
|
|
|
@ -3,6 +3,7 @@ package resolver
|
|||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
||||
|
@ -18,6 +19,12 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
m *mockResolver
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
|
||||
|
@ -54,6 +61,22 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolve conditional DNS queries via defined DNS server", func() {
|
||||
When("Query is exact equal defined condition in mapping", func() {
|
||||
Context("first mapping entry", func() {
|
||||
|
@ -149,22 +172,4 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
Expect(r).Should(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{}, nil, false)
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -17,11 +15,12 @@ import (
|
|||
|
||||
// CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map
|
||||
type CustomDNSResolver struct {
|
||||
configurable[*config.CustomDNSConfig]
|
||||
NextResolver
|
||||
mapping map[string][]net.IP
|
||||
reverseAddresses map[string][]string
|
||||
ttl uint32
|
||||
filterUnmappedTypes bool
|
||||
typed
|
||||
|
||||
mapping map[string][]net.IP
|
||||
reverseAddresses map[string][]string
|
||||
}
|
||||
|
||||
// NewCustomDNSResolver creates new resolver instance
|
||||
|
@ -38,27 +37,13 @@ func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver {
|
|||
}
|
||||
}
|
||||
|
||||
ttl := uint32(time.Duration(cfg.CustomTTL).Seconds())
|
||||
|
||||
return &CustomDNSResolver{
|
||||
mapping: m,
|
||||
reverseAddresses: reverse,
|
||||
ttl: ttl,
|
||||
filterUnmappedTypes: cfg.FilterUnmappedTypes,
|
||||
}
|
||||
}
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("custom_dns"),
|
||||
|
||||
// Configuration returns current resolver configuration
|
||||
func (r *CustomDNSResolver) Configuration() (result []string) {
|
||||
if len(r.mapping) == 0 {
|
||||
return configDisabled
|
||||
mapping: m,
|
||||
reverseAddresses: reverse,
|
||||
}
|
||||
|
||||
for key, val := range r.mapping {
|
||||
result = append(result, fmt.Sprintf("%s = \"%s\"", key, val))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func isSupportedType(ip net.IP, question dns.Question) bool {
|
||||
|
@ -75,7 +60,7 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
response.SetReply(request.Req)
|
||||
|
||||
for _, url := range urls {
|
||||
h := util.CreateHeader(question, r.ttl)
|
||||
h := util.CreateHeader(question, r.cfg.CustomTTL.SecondsU32())
|
||||
ptr := new(dns.PTR)
|
||||
ptr.Ptr = dns.Fqdn(url)
|
||||
ptr.Hdr = h
|
||||
|
@ -103,7 +88,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
if found {
|
||||
for _, ip := range ips {
|
||||
if isSupportedType(ip, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.CustomTTL.SecondsU32())
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
@ -118,7 +103,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
}
|
||||
|
||||
// Mapping exists for this domain, but for another type
|
||||
if !r.filterUnmappedTypes {
|
||||
if !r.cfg.FilterUnmappedTypes {
|
||||
// go to next resolver
|
||||
break
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -15,12 +16,18 @@ import (
|
|||
|
||||
var _ = Describe("CustomDNSResolver", func() {
|
||||
var (
|
||||
TTL = uint32(time.Now().Second())
|
||||
|
||||
sut ChainedResolver
|
||||
m *mockResolver
|
||||
cfg config.CustomDNSConfig
|
||||
)
|
||||
|
||||
TTL := uint32(time.Now().Second())
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = config.CustomDNSConfig{
|
||||
|
@ -45,6 +52,22 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolving custom name via CustomDNSResolver", func() {
|
||||
When("Ip 4 mapping is defined for custom domain and", func() {
|
||||
Context("filterUnmappedTypes is true", func() {
|
||||
|
@ -257,23 +280,4 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
cfg = config.CustomDNSConfig{}
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -7,77 +7,47 @@ import (
|
|||
)
|
||||
|
||||
type EdeResolver struct {
|
||||
configurable[*config.EdeConfig]
|
||||
NextResolver
|
||||
config config.EdeConfig
|
||||
typed
|
||||
}
|
||||
|
||||
func NewEdeResolver(cfg config.EdeConfig) ChainedResolver {
|
||||
return &EdeResolver{
|
||||
config: cfg,
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("extended_error_code"),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
if !r.cfg.Enable {
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
resp, err := r.next.Resolve(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.config.Enable {
|
||||
addExtraReasoning(resp)
|
||||
}
|
||||
r.addExtraReasoning(resp)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (r *EdeResolver) Configuration() (result []string) {
|
||||
if !r.config.Enable {
|
||||
return configDisabled
|
||||
func (r *EdeResolver) addExtraReasoning(res *model.Response) {
|
||||
infocode := res.RType.ToExtendedErrorCode()
|
||||
|
||||
if infocode == dns.ExtendedErrorCodeOther {
|
||||
// dns.ExtendedErrorCodeOther seams broken in some clients
|
||||
return
|
||||
}
|
||||
|
||||
return configEnabled
|
||||
}
|
||||
|
||||
func addExtraReasoning(res *model.Response) {
|
||||
// dns.ExtendedErrorCodeOther seams broken in some clients
|
||||
infocode := convertToExtendedErrorCode(res.RType)
|
||||
if infocode > 0 {
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
opt.Option = append(opt.Option, convertExtendedError(res, infocode))
|
||||
res.Res.Extra = append(res.Res.Extra, opt)
|
||||
}
|
||||
}
|
||||
|
||||
func convertExtendedError(input *model.Response, infocode uint16) *dns.EDNS0_EDE {
|
||||
return &dns.EDNS0_EDE{
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{
|
||||
InfoCode: infocode,
|
||||
ExtraText: input.Reason,
|
||||
}
|
||||
}
|
||||
|
||||
func convertToExtendedErrorCode(input model.ResponseType) uint16 {
|
||||
switch input {
|
||||
case model.ResponseTypeRESOLVED:
|
||||
return dns.ExtendedErrorCodeOther
|
||||
case model.ResponseTypeCACHED:
|
||||
return dns.ExtendedErrorCodeCachedError
|
||||
case model.ResponseTypeCONDITIONAL:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case model.ResponseTypeCUSTOMDNS:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case model.ResponseTypeHOSTSFILE:
|
||||
return dns.ExtendedErrorCodeForgedAnswer
|
||||
case model.ResponseTypeNOTFQDN:
|
||||
return dns.ExtendedErrorCodeBlocked
|
||||
case model.ResponseTypeBLOCKED:
|
||||
return dns.ExtendedErrorCodeBlocked
|
||||
case model.ResponseTypeFILTERED:
|
||||
return dns.ExtendedErrorCodeFiltered
|
||||
case model.ResponseTypeSPECIAL:
|
||||
return dns.ExtendedErrorCodeFiltered
|
||||
default:
|
||||
return dns.ExtendedErrorCodeOther
|
||||
}
|
||||
ExtraText: res.Reason,
|
||||
})
|
||||
res.Res.Extra = append(res.Res.Extra, opt)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
|
||||
|
@ -22,6 +23,12 @@ var _ = Describe("EdeResolver", func() {
|
|||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
@ -59,6 +66,12 @@ var _ = Describe("EdeResolver", func() {
|
|||
// delegated to next resolver
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("ede is enabled", func() {
|
||||
|
@ -106,26 +119,14 @@ var _ = Describe("EdeResolver", func() {
|
|||
Expect(err).To(Equal(resolveErr))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.EdeConfig{Enable: true}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(Equal(configEnabled))
|
||||
})
|
||||
})
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
When("resolver is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.EdeConfig{Enable: false}
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -13,13 +9,21 @@ import (
|
|||
// FilteringResolver filters DNS queries (for example can drop all AAAA query)
|
||||
// returns empty ANSWER with NOERROR
|
||||
type FilteringResolver struct {
|
||||
configurable[*config.FilteringConfig]
|
||||
NextResolver
|
||||
queryTypes config.QTypeSet
|
||||
typed
|
||||
}
|
||||
|
||||
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
|
||||
return &FilteringResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("filtering"),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
qType := request.Req.Question[0].Qtype
|
||||
if r.queryTypes.Contains(dns.Type(qType)) {
|
||||
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
|
||||
response := new(dns.Msg)
|
||||
response.SetRcode(request.Req, dns.RcodeSuccess)
|
||||
|
||||
|
@ -28,27 +32,3 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
func (r *FilteringResolver) Configuration() (result []string) {
|
||||
if len(r.queryTypes) == 0 {
|
||||
return configDisabled
|
||||
}
|
||||
|
||||
qTypes := make([]string, 0, len(r.queryTypes))
|
||||
|
||||
for qType := range r.queryTypes {
|
||||
qTypes = append(qTypes, qType.String())
|
||||
}
|
||||
|
||||
sort.Strings(qTypes)
|
||||
|
||||
result = append(result, fmt.Sprintf("filtering query Types: '%v'", strings.Join(qTypes, ", ")))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func NewFilteringResolver(cfg config.FilteringConfig) ChainedResolver {
|
||||
return &FilteringResolver{
|
||||
queryTypes: cfg.QueryTypes,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@ package resolver
|
|||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -18,6 +20,12 @@ var _ = Describe("FilteringResolver", func() {
|
|||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
@ -29,6 +37,22 @@ var _ = Describe("FilteringResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("Filtering query types are defined", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.FilteringConfig{
|
||||
|
@ -59,10 +83,6 @@ var _ = Describe("FilteringResolver", func() {
|
|||
// no call of next resolver
|
||||
Expect(m.Calls).Should(BeZero())
|
||||
})
|
||||
It("Configure should output all query types", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(Equal([]string{"filtering query Types: 'AAAA, MX'"}))
|
||||
})
|
||||
})
|
||||
|
||||
When("No filtering query types are defined", func() {
|
||||
|
@ -81,9 +101,5 @@ var _ = Describe("FilteringResolver", func() {
|
|||
// delegated to next resolver
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
It("Configure should output 'empty list'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -10,18 +10,20 @@ import (
|
|||
)
|
||||
|
||||
type FqdnOnlyResolver struct {
|
||||
configurable[*config.FqdnOnlyConfig]
|
||||
NextResolver
|
||||
enabled bool
|
||||
typed
|
||||
}
|
||||
|
||||
func NewFqdnOnlyResolver(cfg config.Config) *FqdnOnlyResolver {
|
||||
func NewFqdnOnlyResolver(cfg config.FqdnOnlyConfig) *FqdnOnlyResolver {
|
||||
return &FqdnOnlyResolver{
|
||||
enabled: cfg.FqdnOnly,
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("fqdn_only"),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
if r.enabled {
|
||||
if r.IsEnabled() {
|
||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||
if !strings.Contains(domainFromQuestion, ".") {
|
||||
response := new(dns.Msg)
|
||||
|
@ -33,11 +35,3 @@ func (r *FqdnOnlyResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
func (r *FqdnOnlyResolver) Configuration() (result []string) {
|
||||
if !r.enabled {
|
||||
return configDisabled
|
||||
}
|
||||
|
||||
return configEnabled
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package resolver
|
|||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -13,11 +14,17 @@ import (
|
|||
var _ = Describe("FqdnOnlyResolver", func() {
|
||||
var (
|
||||
sut *FqdnOnlyResolver
|
||||
sutConfig config.Config
|
||||
sutConfig config.FqdnOnlyConfig
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
@ -29,11 +36,25 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("Fqdn only is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.Config{
|
||||
FqdnOnly: true,
|
||||
}
|
||||
sutConfig = config.FqdnOnlyConfig{Enable: true}
|
||||
})
|
||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
|
@ -59,17 +80,27 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
// no call of next resolver
|
||||
Expect(m.Calls).Should(BeZero())
|
||||
})
|
||||
It("Configure should output enabled", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(Equal(configEnabled))
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("Fqdn only is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.Config{
|
||||
FqdnOnly: false,
|
||||
}
|
||||
sutConfig = config.FqdnOnlyConfig{Enable: false}
|
||||
})
|
||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
|
@ -95,9 +126,11 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
// delegated to next resolver
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
It("Configure should output disabled", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -17,23 +16,44 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
hostsFileResolverLogger = "hosts_file_resolver"
|
||||
|
||||
// reduce initial capacity so we don't waste memory if there are less entries than before
|
||||
memReleaseFactor = 2
|
||||
)
|
||||
|
||||
type HostsFileResolver struct {
|
||||
configurable[*config.HostsFileConfig]
|
||||
NextResolver
|
||||
HostsFilePath string
|
||||
hosts splitHostsFileData
|
||||
ttl uint32
|
||||
refreshPeriod time.Duration
|
||||
filterLoopback bool
|
||||
typed
|
||||
|
||||
hosts splitHostsFileData
|
||||
}
|
||||
|
||||
type HostsFileEntry = parsers.HostsFileEntry
|
||||
|
||||
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
|
||||
r := HostsFileResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("hosts_file"),
|
||||
}
|
||||
|
||||
if err := r.parseHostsFile(context.Background()); err != nil {
|
||||
r.log().Errorf("disabling hosts file resolving due to error: %s", err)
|
||||
|
||||
r.cfg.Filepath = "" // don't try parsing the file again
|
||||
} else {
|
||||
go r.periodicUpdate()
|
||||
}
|
||||
|
||||
return &r
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *HostsFileResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.cfg.LogConfig(logger)
|
||||
|
||||
logger.Infof("cache entries = %d", r.hosts.len())
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Response {
|
||||
question := request.Req.Question[0]
|
||||
if question.Qtype != dns.TypePTR {
|
||||
|
@ -46,7 +66,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
return nil
|
||||
}
|
||||
|
||||
if r.filterLoopback && questionIP.IsLoopback() {
|
||||
if r.cfg.FilterLoopback && questionIP.IsLoopback() {
|
||||
// skip the search: we won't find anything
|
||||
return nil
|
||||
}
|
||||
|
@ -64,13 +84,13 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
|
||||
ptr := new(dns.PTR)
|
||||
ptr.Ptr = dns.Fqdn(host)
|
||||
ptr.Hdr = util.CreateHeader(question, r.ttl)
|
||||
ptr.Hdr = util.CreateHeader(question, r.cfg.HostsTTL.SecondsU32())
|
||||
response.Answer = append(response.Answer, ptr)
|
||||
|
||||
for _, alias := range hostData.Aliases {
|
||||
ptrAlias := new(dns.PTR)
|
||||
ptrAlias.Ptr = dns.Fqdn(alias)
|
||||
ptrAlias.Hdr = util.CreateHeader(question, r.ttl)
|
||||
ptrAlias.Hdr = ptr.Hdr
|
||||
response.Answer = append(response.Answer, ptrAlias)
|
||||
}
|
||||
|
||||
|
@ -82,9 +102,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
}
|
||||
|
||||
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, hostsFileResolverLogger)
|
||||
|
||||
if r.HostsFilePath == "" {
|
||||
if r.cfg.Filepath == "" {
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
|
@ -98,7 +116,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
|
||||
response := r.resolve(request.Req, question, domain)
|
||||
if response != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(response.Answer),
|
||||
"domain": domain,
|
||||
}).Debugf("returning hosts file entry")
|
||||
|
@ -106,7 +124,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
|
||||
}
|
||||
|
||||
logger.WithField("resolver", Name(r.next)).Trace("go to next resolver")
|
||||
r.log().WithField("resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
@ -117,7 +135,7 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
|
|||
return nil
|
||||
}
|
||||
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.cfg.HostsTTL.SecondsU32())
|
||||
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(req)
|
||||
|
@ -126,47 +144,14 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
|
|||
return response
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) Configuration() (result []string) {
|
||||
if r.HostsFilePath == "" || r.hosts.isEmpty() {
|
||||
return configDisabled
|
||||
}
|
||||
|
||||
result = append(result, fmt.Sprintf("file path: %s", r.HostsFilePath))
|
||||
result = append(result, fmt.Sprintf("TTL: %d", r.ttl))
|
||||
result = append(result, fmt.Sprintf("refresh period: %s", r.refreshPeriod.String()))
|
||||
result = append(result, fmt.Sprintf("filter loopback addresses: %t", r.filterLoopback))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
|
||||
r := HostsFileResolver{
|
||||
HostsFilePath: cfg.Filepath,
|
||||
ttl: uint32(time.Duration(cfg.HostsTTL).Seconds()),
|
||||
refreshPeriod: time.Duration(cfg.RefreshPeriod),
|
||||
filterLoopback: cfg.FilterLoopback,
|
||||
}
|
||||
|
||||
if err := r.parseHostsFile(context.Background()); err != nil {
|
||||
logger := log.PrefixedLog(hostsFileResolverLogger)
|
||||
logger.Warnf("hosts file resolving is disabled: %s", err)
|
||||
|
||||
r.HostsFilePath = "" // don't try parsing the file again
|
||||
} else {
|
||||
go r.periodicUpdate()
|
||||
}
|
||||
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
||||
const maxErrorsPerFile = 5
|
||||
|
||||
if r.HostsFilePath == "" {
|
||||
if r.cfg.Filepath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(r.HostsFilePath)
|
||||
f, err := os.Open(r.cfg.Filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -176,7 +161,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
|||
|
||||
p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile)
|
||||
p.OnErr(func(err error) {
|
||||
log.PrefixedLog(hostsFileResolverLogger).Warnf("error parsing %s: %s, trying to continue", r.HostsFilePath, err)
|
||||
r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err)
|
||||
})
|
||||
|
||||
err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
|
||||
|
@ -187,7 +172,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Ignore loopback, if so configured
|
||||
if r.filterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") {
|
||||
if r.cfg.FilterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -196,7 +181,7 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
|||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing %s: %w", r.HostsFilePath, err) // err is parsers.ErrTooManyErrors
|
||||
return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors
|
||||
}
|
||||
|
||||
r.hosts = newHosts
|
||||
|
@ -205,15 +190,14 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (r *HostsFileResolver) periodicUpdate() {
|
||||
if r.refreshPeriod > 0 {
|
||||
ticker := time.NewTicker(r.refreshPeriod)
|
||||
if r.cfg.RefreshPeriod.ToDuration() > 0 {
|
||||
ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
|
||||
logger := log.PrefixedLog(hostsFileResolverLogger)
|
||||
logger.WithField("file", r.HostsFilePath).Debug("refreshing hosts file")
|
||||
r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file")
|
||||
|
||||
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background()))
|
||||
}
|
||||
|
@ -238,7 +222,11 @@ func newSplitHostsDataWithSameCapacity(other splitHostsFileData) splitHostsFileD
|
|||
}
|
||||
|
||||
func (d splitHostsFileData) isEmpty() bool {
|
||||
return d.v4.isEmpty() && d.v6.isEmpty()
|
||||
return d.len() == 0
|
||||
}
|
||||
|
||||
func (d splitHostsFileData) len() int {
|
||||
return d.v4.len() + d.v6.len()
|
||||
}
|
||||
|
||||
func (d splitHostsFileData) getIP(qType dns.Type, domain string) net.IP {
|
||||
|
@ -277,8 +265,8 @@ func newHostsDataWithSameCapacity(other hostsFileData) hostsFileData {
|
|||
}
|
||||
}
|
||||
|
||||
func (d hostsFileData) isEmpty() bool {
|
||||
return len(d.hosts) == 0 && len(d.aliases) == 0
|
||||
func (d hostsFileData) len() int {
|
||||
return len(d.hosts) + len(d.aliases)
|
||||
}
|
||||
|
||||
func (d hostsFileData) getIP(hostname string) net.IP {
|
||||
|
|
|
@ -2,12 +2,11 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -17,6 +16,8 @@ import (
|
|||
|
||||
var _ = Describe("HostsFileResolver", func() {
|
||||
var (
|
||||
TTL = uint32(time.Now().Second())
|
||||
|
||||
sut *HostsFileResolver
|
||||
sutConfig config.HostsFileConfig
|
||||
m *mockResolver
|
||||
|
@ -24,7 +25,11 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
tmpFile *TmpFile
|
||||
)
|
||||
|
||||
TTL := uint32(time.Now().Second())
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
tmpDir = NewTmpFolder("HostsFileResolver")
|
||||
|
@ -49,16 +54,32 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Using hosts file", func() {
|
||||
When("Hosts file cannot be located", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.HostsFileConfig{
|
||||
Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()),
|
||||
Filepath: "/this/file/does/not/exist",
|
||||
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
|
||||
}
|
||||
})
|
||||
It("should not parse any hosts", func() {
|
||||
Expect(sut.HostsFilePath).Should(BeEmpty())
|
||||
Expect(sut.cfg.Filepath).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
|
||||
|
@ -140,11 +161,11 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
It("should not be used", func() {
|
||||
Expect(sut).ShouldNot(BeNil())
|
||||
Expect(sut.HostsFilePath).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.hosts).Should(HaveLen(0))
|
||||
Expect(sut.hosts.v6.hosts).Should(HaveLen(0))
|
||||
Expect(sut.hosts.v4.aliases).Should(HaveLen(0))
|
||||
Expect(sut.hosts.v6.aliases).Should(HaveLen(0))
|
||||
Expect(sut.cfg.Filepath).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
|
||||
Expect(sut.hosts.v6.aliases).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -307,25 +328,6 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("hosts file is provided", func() {
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
|
||||
When("hosts file is not provided", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.HostsFileConfig{}
|
||||
})
|
||||
It("should return 'disabled'", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement(configStatusDisabled))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Delegating to next resolver", func() {
|
||||
When("no hosts file is provided", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -15,8 +14,10 @@ import (
|
|||
|
||||
// MetricsResolver resolver that records metrics about requests/response
|
||||
type MetricsResolver struct {
|
||||
configurable[*config.MetricsConfig]
|
||||
NextResolver
|
||||
cfg config.PrometheusConfig
|
||||
typed
|
||||
|
||||
totalQueries *prometheus.CounterVec
|
||||
totalResponse *prometheus.CounterVec
|
||||
totalErrors prometheus.Counter
|
||||
|
@ -24,11 +25,11 @@ type MetricsResolver struct {
|
|||
}
|
||||
|
||||
// Resolve resolves the passed request
|
||||
func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
response, err := m.next.Resolve(request)
|
||||
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
response, err := r.next.Resolve(request)
|
||||
|
||||
if m.cfg.Enable {
|
||||
m.totalQueries.With(prometheus.Labels{
|
||||
if r.cfg.Enable {
|
||||
r.totalQueries.With(prometheus.Labels{
|
||||
"client": strings.Join(request.ClientNames, ","),
|
||||
"type": dns.TypeToString[request.Req.Question[0].Qtype],
|
||||
}).Inc()
|
||||
|
@ -40,12 +41,12 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
|
|||
responseType = response.RType.String()
|
||||
}
|
||||
|
||||
m.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs)
|
||||
r.durationHistogram.WithLabelValues(responseType).Observe(reqDurationMs)
|
||||
|
||||
if err != nil {
|
||||
m.totalErrors.Inc()
|
||||
r.totalErrors.Inc()
|
||||
} else {
|
||||
m.totalResponse.With(prometheus.Labels{
|
||||
r.totalResponse.With(prometheus.Labels{
|
||||
"reason": response.Reason,
|
||||
"response_code": dns.RcodeToString[response.Res.Rcode],
|
||||
"response_type": response.RType.String(),
|
||||
|
@ -56,38 +57,28 @@ func (m *MetricsResolver) Resolve(request *model.Request) (*model.Response, erro
|
|||
return response, err
|
||||
}
|
||||
|
||||
// Configuration gets the config of this resolver in a string slice
|
||||
func (m *MetricsResolver) Configuration() (result []string) {
|
||||
if !m.cfg.Enable {
|
||||
return configDisabled
|
||||
// NewMetricsResolver creates a new intance of the MetricsResolver type
|
||||
func NewMetricsResolver(cfg config.MetricsConfig) ChainedResolver {
|
||||
m := MetricsResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("metrics"),
|
||||
|
||||
durationHistogram: durationHistogram(),
|
||||
totalQueries: totalQueriesMetric(),
|
||||
totalResponse: totalResponseMetric(),
|
||||
totalErrors: totalErrorMetric(),
|
||||
}
|
||||
|
||||
result = append(result, "metrics:")
|
||||
result = append(result, fmt.Sprintf(" Enable = %t", m.cfg.Enable))
|
||||
result = append(result, fmt.Sprintf(" Path = %s", m.cfg.Path))
|
||||
m.registerMetrics()
|
||||
|
||||
return
|
||||
return &m
|
||||
}
|
||||
|
||||
// NewMetricsResolver creates a new intance of the MetricsResolver type
|
||||
func NewMetricsResolver(cfg config.PrometheusConfig) ChainedResolver {
|
||||
durationHistogram := durationHistogram()
|
||||
totalQueries := totalQueriesMetric()
|
||||
totalResponse := totalResponseMetric()
|
||||
totalErrors := totalErrorMetric()
|
||||
|
||||
metrics.RegisterMetric(durationHistogram)
|
||||
metrics.RegisterMetric(totalQueries)
|
||||
metrics.RegisterMetric(totalResponse)
|
||||
metrics.RegisterMetric(totalErrors)
|
||||
|
||||
return &MetricsResolver{
|
||||
cfg: cfg,
|
||||
durationHistogram: durationHistogram,
|
||||
totalQueries: totalQueries,
|
||||
totalResponse: totalResponse,
|
||||
totalErrors: totalErrors,
|
||||
}
|
||||
func (r *MetricsResolver) registerMetrics() {
|
||||
metrics.RegisterMetric(r.durationHistogram)
|
||||
metrics.RegisterMetric(r.totalQueries)
|
||||
metrics.RegisterMetric(r.totalResponse)
|
||||
metrics.RegisterMetric(r.totalErrors)
|
||||
}
|
||||
|
||||
func totalQueriesMetric() *prometheus.CounterVec {
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
|
@ -22,13 +23,35 @@ var _ = Describe("MetricResolver", func() {
|
|||
m *mockResolver
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewMetricsResolver(config.PrometheusConfig{Enable: true}).(*MetricsResolver)
|
||||
sut = NewMetricsResolver(config.MetricsConfig{Enable: true}).(*MetricsResolver)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Recording prometheus metrics", func() {
|
||||
Context("Recording request metrics", func() {
|
||||
When("Request will be performed", func() {
|
||||
|
@ -63,13 +86,4 @@ var _ = Describe("MetricResolver", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
|
||||
|
@ -27,10 +28,21 @@ type mockResolver struct {
|
|||
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
func (r *mockResolver) Configuration() []string {
|
||||
// Type implements `Resolver`.
|
||||
func (r *mockResolver) Type() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (r *mockResolver) IsEnabled() bool {
|
||||
args := r.Called()
|
||||
|
||||
return args.Get(0).([]string)
|
||||
return args.Get(0).(bool)
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *mockResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.Called()
|
||||
}
|
||||
|
||||
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var NoResponse = &model.Response{} //nolint:gochecknoglobals
|
||||
|
@ -13,10 +14,20 @@ func NewNoOpResolver() Resolver {
|
|||
return NoOpResolver{}
|
||||
}
|
||||
|
||||
func (r NoOpResolver) Configuration() (result []string) {
|
||||
return nil
|
||||
// Type implements `Resolver`.
|
||||
func (NoOpResolver) Type() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (r NoOpResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (NoOpResolver) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (NoOpResolver) LogConfig(*logrus.Entry) {
|
||||
}
|
||||
|
||||
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) {
|
||||
return NoResponse, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -9,6 +10,12 @@ import (
|
|||
var _ = Describe("NoOpResolver", func() {
|
||||
var sut NoOpResolver
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewNoOpResolver().(NoOpResolver)
|
||||
})
|
||||
|
@ -21,10 +28,19 @@ var _ = Describe("NoOpResolver", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
It("returns nothing", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(BeNil())
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should not log anything", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -20,13 +20,16 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
upstreamDefaultCfgName = "default"
|
||||
parallelResolverLogger = "parallel_best_resolver"
|
||||
upstreamDefaultCfgName = config.UpstreamDefaultCfgName
|
||||
parallelResolverType = "parallel_best"
|
||||
resolverCount = 2
|
||||
)
|
||||
|
||||
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
|
||||
type ParallelBestResolver struct {
|
||||
configurable[*config.ParallelBestConfig]
|
||||
typed
|
||||
|
||||
resolversPerClient map[string][]*upstreamResolverStatus
|
||||
}
|
||||
|
||||
|
@ -77,10 +80,11 @@ func testResolver(r *UpstreamResolver) error {
|
|||
|
||||
// NewParallelBestResolver creates new resolver instance
|
||||
func NewParallelBestResolver(
|
||||
upstreamResolvers map[string][]config.Upstream, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (Resolver, error) {
|
||||
logger := log.PrefixedLog(parallelResolverLogger)
|
||||
cfg config.ParallelBestConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (*ParallelBestResolver, error) {
|
||||
logger := log.PrefixedLog(parallelResolverType)
|
||||
|
||||
upstreamResolvers := cfg.ExternalResolvers
|
||||
resolverGroups := make(map[string][]Resolver, len(upstreamResolvers))
|
||||
|
||||
for name, upstreamCfgs := range upstreamResolvers {
|
||||
|
@ -114,10 +118,12 @@ func NewParallelBestResolver(
|
|||
resolverGroups[name] = group
|
||||
}
|
||||
|
||||
return newParallelBestResolver(resolverGroups)
|
||||
return newParallelBestResolver(cfg, resolverGroups)
|
||||
}
|
||||
|
||||
func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, error) {
|
||||
func newParallelBestResolver(
|
||||
cfg config.ParallelBestConfig, resolverGroups map[string][]Resolver,
|
||||
) (*ParallelBestResolver, error) {
|
||||
resolversPerClient := make(map[string][]*upstreamResolverStatus, len(resolverGroups))
|
||||
|
||||
for groupName, resolvers := range resolverGroups {
|
||||
|
@ -136,27 +142,21 @@ func newParallelBestResolver(resolverGroups map[string][]Resolver) (Resolver, er
|
|||
}
|
||||
|
||||
r := ParallelBestResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType(parallelResolverType),
|
||||
|
||||
resolversPerClient: resolversPerClient,
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// Configuration returns current resolver configuration
|
||||
func (r *ParallelBestResolver) Configuration() (result []string) {
|
||||
result = append(result, "upstream resolvers:")
|
||||
for name, res := range r.resolversPerClient {
|
||||
result = append(result, fmt.Sprintf("- %s", name))
|
||||
for _, r := range res {
|
||||
result = append(result, fmt.Sprintf(" - %s", r.resolver))
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
func (r *ParallelBestResolver) Name() string {
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (r ParallelBestResolver) String() string {
|
||||
result := make([]string, 0)
|
||||
func (r *ParallelBestResolver) String() string {
|
||||
result := make([]string, 0, len(r.resolversPerClient))
|
||||
|
||||
for name, res := range r.resolversPerClient {
|
||||
tmp := make([]string, len(res))
|
||||
|
@ -206,7 +206,7 @@ func (r *ParallelBestResolver) resolversForClient(request *model.Request) (resul
|
|||
|
||||
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
||||
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, parallelResolverLogger)
|
||||
logger := log.WithPrefix(request.Log, parallelResolverType)
|
||||
|
||||
resolvers := r.resolversForClient(request)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -19,13 +20,74 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
noVerifyUpstreams = false
|
||||
)
|
||||
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
var (
|
||||
sut *ParallelBestResolver
|
||||
sutMapping config.ParallelBestMapping
|
||||
sutVerify bool
|
||||
|
||||
err error
|
||||
|
||||
bootstrap *Bootstrap
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
},
|
||||
config.Upstream{
|
||||
Host: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sutVerify = noVerifyUpstreams
|
||||
|
||||
bootstrap = systemResolverBootstrap
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sutConfig := config.ParallelBestConfig{ExternalResolvers: sutMapping}
|
||||
|
||||
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
|
||||
})
|
||||
|
||||
config.GetConfig().UpstreamTimeout = config.Duration(1000 * time.Millisecond)
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
It("should not be empty", func() {
|
||||
Expect(sut.Name()).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("default upstream resolvers are not defined", func() {
|
||||
It("should fail on startup", func() {
|
||||
_, err := NewParallelBestResolver(map[string][]config.Upstream{}, nil, noVerifyUpstreams)
|
||||
_, err := NewParallelBestResolver(config.ParallelBestConfig{
|
||||
ExternalResolvers: config.ParallelBestMapping{},
|
||||
}, nil, noVerifyUpstreams)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("no external DNS resolvers configured"))
|
||||
})
|
||||
|
@ -40,7 +102,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
defer mockUpstream.Close()
|
||||
|
||||
upstream := map[string][]config.Upstream{
|
||||
upstream := config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
|
@ -49,21 +111,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewParallelBestResolver(upstream, systemResolverBootstrap, verifyUpstreams)
|
||||
_, err := NewParallelBestResolver(config.ParallelBestConfig{
|
||||
ExternalResolvers: upstream,
|
||||
}, systemResolverBootstrap, verifyUpstreams)
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
|
||||
When("no upstream resolvers can be reached", func() {
|
||||
var (
|
||||
upstream map[string][]config.Upstream
|
||||
b *Bootstrap
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
b = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
upstream = map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
Host: "wrong",
|
||||
|
@ -75,22 +134,26 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
}
|
||||
})
|
||||
|
||||
It("should fail to start if strict checking is enabled", func() {
|
||||
_, err := NewParallelBestResolver(upstream, b, verifyUpstreams)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
When("strict checking is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = verifyUpstreams
|
||||
})
|
||||
It("should fail to start", func() {
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
It("should start if strict checking is disabled", func() {
|
||||
_, err := NewParallelBestResolver(upstream, b, noVerifyUpstreams)
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
When("strict checking is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = noVerifyUpstreams
|
||||
})
|
||||
It("should start", func() {
|
||||
Expect(err).Should(Not(HaveOccurred()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolving result from fastest upstream resolver", func() {
|
||||
var (
|
||||
sut Resolver
|
||||
err error
|
||||
)
|
||||
When("2 Upstream resolvers are defined", func() {
|
||||
When("one resolver is fast and another is slow", func() {
|
||||
BeforeEach(func() {
|
||||
|
@ -107,10 +170,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {fastTestUpstream.Start(), slowTestUpstream.Start()},
|
||||
}, nil, noVerifyUpstreams)
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
})
|
||||
It("Should use result from fastest one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
@ -136,9 +198,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
return response
|
||||
})
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {withErrorUpstream, slowTestUpstream.Start()},
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
}
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should use result from successful resolver", func() {
|
||||
|
@ -158,9 +220,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
withError1 := config.Upstream{Host: "wrong"}
|
||||
withError2 := config.Upstream{Host: "wrong"}
|
||||
|
||||
sut, err = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {withError1, withError2},
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
}
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should return error", func() {
|
||||
|
@ -194,14 +256,14 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
WithAnswerRR("example.com 123 IN A 123.124.122.126")
|
||||
DeferCleanup(clientSpecificCIRDMockUpstream.Close)
|
||||
|
||||
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {defaultMockUpstream.Start()},
|
||||
"laptop": {clientSpecificExactMockUpstream.Start()},
|
||||
"client-*-m": {clientSpecificWildcardMockUpstream.Start()},
|
||||
"client[0-9]": {clientSpecificWildcardMockUpstream.Start()},
|
||||
"192.168.178.33": {clientSpecificIPMockUpstream.Start()},
|
||||
"10.43.8.67/28": {clientSpecificCIRDMockUpstream.Start()},
|
||||
}, nil, noVerifyUpstreams)
|
||||
}
|
||||
})
|
||||
It("Should use default if client name or IP don't match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
||||
|
@ -294,11 +356,11 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream.Close)
|
||||
|
||||
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sutMapping = config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {
|
||||
mockUpstream.Start(),
|
||||
},
|
||||
}, nil, noVerifyUpstreams)
|
||||
}
|
||||
})
|
||||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
@ -327,10 +389,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream2.Close)
|
||||
|
||||
tmp, _ := NewParallelBestResolver(map[string][]config.Upstream{
|
||||
sut, _ := NewParallelBestResolver(config.ParallelBestConfig{ExternalResolvers: config.ParallelBestMapping{
|
||||
upstreamDefaultCfgName: {withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||
}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
sut := tmp.(*ParallelBestResolver)
|
||||
}}, systemResolverBootstrap, noVerifyUpstreams)
|
||||
|
||||
By("all resolvers have same weight for random -> equal distribution", func() {
|
||||
resolverCount := make(map[Resolver]int)
|
||||
|
@ -391,26 +452,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("errors during construction", func() {
|
||||
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewParallelBestResolver(map[string][]config.Upstream{"test": {{Host: "example.com"}}}, b, verifyUpstreams)
|
||||
r, err := NewParallelBestResolver(config.ParallelBestConfig{
|
||||
ExternalResolvers: config.ParallelBestMapping{"test": {{Host: "example.com"}}},
|
||||
}, b, verifyUpstreams)
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(r).Should(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
var sut Resolver
|
||||
BeforeEach(func() {
|
||||
config.GetConfig().StartVerifyUpstream = false
|
||||
|
||||
sut, _ = NewParallelBestResolver(map[string][]config.Upstream{upstreamDefaultCfgName: {
|
||||
{Host: "host1"},
|
||||
{Host: "host2"},
|
||||
}}, nil, noVerifyUpstreams)
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -14,35 +13,32 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
cleanUpRunPeriod = 12 * time.Hour
|
||||
queryLoggingResolverPrefix = "query_logging_resolver"
|
||||
logChanCap = 1000
|
||||
defaultFlushPeriod = 30 * time.Second
|
||||
cleanUpRunPeriod = 12 * time.Hour
|
||||
queryLoggingResolverType = "query_logging"
|
||||
logChanCap = 1000
|
||||
defaultFlushPeriod = 30 * time.Second
|
||||
)
|
||||
|
||||
// QueryLoggingResolver writes query information (question, answer, duration, ...)
|
||||
type QueryLoggingResolver struct {
|
||||
configurable[*config.QueryLogConfig]
|
||||
NextResolver
|
||||
target string
|
||||
logRetentionDays uint64
|
||||
logChan chan *querylog.LogEntry
|
||||
writer querylog.Writer
|
||||
logType config.QueryLogType
|
||||
fields []config.QueryLogField
|
||||
typed
|
||||
|
||||
logChan chan *querylog.LogEntry
|
||||
writer querylog.Writer
|
||||
}
|
||||
|
||||
// NewQueryLoggingResolver returns a new resolver instance
|
||||
func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
||||
logger := log.PrefixedLog(queryLoggingResolverPrefix)
|
||||
logger := log.PrefixedLog(queryLoggingResolverType)
|
||||
|
||||
var writer querylog.Writer
|
||||
|
||||
logType := cfg.Type
|
||||
|
||||
err := retry.Do(
|
||||
func() error {
|
||||
var err error
|
||||
switch logType {
|
||||
switch cfg.Type {
|
||||
case config.QueryLogTypeCsv:
|
||||
writer, err = querylog.NewCSVWriter(cfg.Target, false, cfg.LogRetentionDays)
|
||||
case config.QueryLogTypeCsvClient:
|
||||
|
@ -61,7 +57,7 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
|||
},
|
||||
retry.Attempts(uint(cfg.CreationAttempts)),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(time.Duration(cfg.CreationCooldown)),
|
||||
retry.Delay(cfg.CreationCooldown.ToDuration()),
|
||||
retry.OnRetry(func(n uint, err error) {
|
||||
logger.Warnf(
|
||||
"Error occurred on query writer creation, retry attempt %d/%d: %v", n+1, cfg.CreationAttempts, err,
|
||||
|
@ -71,18 +67,17 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
|||
logger.Error("can't create query log writer, using console as fallback: ", err)
|
||||
|
||||
writer = querylog.NewLoggerWriter()
|
||||
logType = config.QueryLogTypeConsole
|
||||
cfg.Type = config.QueryLogTypeConsole
|
||||
}
|
||||
|
||||
logChan := make(chan *querylog.LogEntry, logChanCap)
|
||||
|
||||
resolver := QueryLoggingResolver{
|
||||
target: cfg.Target,
|
||||
logRetentionDays: cfg.LogRetentionDays,
|
||||
logChan: logChan,
|
||||
writer: writer,
|
||||
logType: logType,
|
||||
fields: resolveQueryLogFields(cfg),
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType(queryLoggingResolverType),
|
||||
|
||||
logChan: logChan,
|
||||
writer: writer,
|
||||
}
|
||||
|
||||
go resolver.writeLog()
|
||||
|
@ -94,24 +89,6 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
|||
return &resolver
|
||||
}
|
||||
|
||||
func resolveQueryLogFields(cfg config.QueryLogConfig) []config.QueryLogField {
|
||||
var fields []config.QueryLogField
|
||||
|
||||
if len(cfg.Fields) == 0 {
|
||||
// no fields defined, use all fields as fallback
|
||||
for _, v := range config.QueryLogFieldNames() {
|
||||
qlt, err := config.ParseQueryLogField(v)
|
||||
util.LogOnError("ignoring unknown query log field", err)
|
||||
|
||||
fields = append(fields, qlt)
|
||||
}
|
||||
} else {
|
||||
fields = cfg.Fields
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// triggers periodically cleanup of old log files
|
||||
func (r *QueryLoggingResolver) periodicCleanUp() {
|
||||
ticker := time.NewTicker(cleanUpRunPeriod)
|
||||
|
@ -129,7 +106,7 @@ func (r *QueryLoggingResolver) doCleanUp() {
|
|||
|
||||
// Resolve logs the query, duration and the result
|
||||
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, queryLoggingResolverPrefix)
|
||||
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
|
@ -157,7 +134,7 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
|
|||
ClientNames: []string{"none"},
|
||||
}
|
||||
|
||||
for _, f := range r.fields {
|
||||
for _, f := range r.cfg.Fields {
|
||||
switch f {
|
||||
case config.QueryLogFieldClientIP:
|
||||
entry.ClientIP = request.ClientIP.String()
|
||||
|
@ -196,18 +173,8 @@ func (r *QueryLoggingResolver) writeLog() {
|
|||
|
||||
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
||||
if len(r.logChan) > halfCap {
|
||||
log.PrefixedLog(queryLoggingResolverPrefix).WithField("channel_len",
|
||||
r.log().WithField("channel_len",
|
||||
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration returns the current resolver configuration
|
||||
func (r *QueryLoggingResolver) Configuration() (result []string) {
|
||||
result = append(result, fmt.Sprintf("type: \"%s\"", r.logType))
|
||||
result = append(result, fmt.Sprintf("target: \"%s\"", r.target))
|
||||
result = append(result, fmt.Sprintf("logRetentionDays: %d", r.logRetentionDays))
|
||||
result = append(result, fmt.Sprintf("fields: %s", r.fields))
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/querylog"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -44,6 +45,12 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
mockAnswer *dns.Msg
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mockAnswer = new(dns.Msg)
|
||||
tmpDir = NewTmpFolder("queryLoggingResolver")
|
||||
|
@ -52,6 +59,10 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
if len(sutConfig.Fields) == 0 {
|
||||
sutConfig.SetDefaults() // not called when using a struct literal
|
||||
}
|
||||
|
||||
sut = NewQueryLoggingResolver(sutConfig).(*QueryLoggingResolver)
|
||||
DeferCleanup(func() { close(sut.logChan) })
|
||||
m = &mockResolver{}
|
||||
|
@ -59,6 +70,22 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Process request", func() {
|
||||
When("Resolver has no configuration", func() {
|
||||
BeforeEach(func() {
|
||||
|
@ -276,24 +303,6 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.QueryLogConfig{
|
||||
Target: tmpDir.Path,
|
||||
Type: config.QueryLogTypeCsvClient,
|
||||
LogRetentionDays: 0,
|
||||
CreationAttempts: 1,
|
||||
CreationCooldown: config.Duration(time.Millisecond),
|
||||
}
|
||||
})
|
||||
It("should return configuration", func() {
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", 1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Clean up of query log directory", func() {
|
||||
When("fallback logger is enabled, log retention is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
|
@ -355,7 +364,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
}
|
||||
})
|
||||
It("should use fallback", func() {
|
||||
Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole))
|
||||
Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -369,7 +378,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
}
|
||||
})
|
||||
It("should use fallback", func() {
|
||||
Expect(sut.logType).Should(Equal(config.QueryLogTypeConsole))
|
||||
Expect(sut.cfg.Type).Should(Equal(config.QueryLogTypeConsole))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -14,20 +13,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Resolver is not configured.
|
||||
const (
|
||||
configStatusEnabled string = "enabled"
|
||||
|
||||
configStatusDisabled string = "disabled"
|
||||
)
|
||||
|
||||
var (
|
||||
// note: this is not used by all resolvers: only those that don't print any other configuration
|
||||
configEnabled = []string{configStatusEnabled} //nolint:gochecknoglobals
|
||||
|
||||
configDisabled = []string{configStatusDisabled} //nolint:gochecknoglobals
|
||||
)
|
||||
|
||||
func newRequest(question string, rType dns.Type, logger ...*logrus.Entry) *model.Request {
|
||||
var loggerEntry *logrus.Entry
|
||||
if len(logger) == 1 {
|
||||
|
@ -84,11 +69,15 @@ func newRequestWithClientID(question string, rType dns.Type, ip, requestClientID
|
|||
|
||||
// Resolver generic interface for all resolvers
|
||||
type Resolver interface {
|
||||
config.Configurable
|
||||
|
||||
// Type returns a short, user-friendly, name for the resolver.
|
||||
//
|
||||
// It should be the same for all instances of a specific Resolver type.
|
||||
Type() string
|
||||
|
||||
// Resolve performs resolution of a DNS request
|
||||
Resolve(req *model.Request) (*model.Response, error)
|
||||
|
||||
// Configuration returns current resolver configuration
|
||||
Configuration() []string
|
||||
}
|
||||
|
||||
// ChainedResolver represents a resolver, which can delegate result to the next one
|
||||
|
@ -142,10 +131,73 @@ func Name(resolver Resolver) string {
|
|||
return named.Name()
|
||||
}
|
||||
|
||||
return defaultName(resolver)
|
||||
return resolver.Type()
|
||||
}
|
||||
|
||||
// defaultName returns a short user-friendly name of a resolver
|
||||
func defaultName(resolver Resolver) string {
|
||||
return strings.Split(fmt.Sprintf("%T", resolver), ".")[1]
|
||||
// ForEach iterates over all resolvers in the chain.
|
||||
//
|
||||
// If resolver is not a chain, or is unlinked,
|
||||
// the callback is called exactly once.
|
||||
func ForEach(resolver Resolver, callback func(Resolver)) {
|
||||
for resolver != nil {
|
||||
callback(resolver)
|
||||
|
||||
if chained, ok := resolver.(ChainedResolver); ok {
|
||||
resolver = chained.GetNext()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogResolverConfig logs the resolver's type and config.
|
||||
func LogResolverConfig(res Resolver, logger *logrus.Entry) {
|
||||
// Use the type, not the full typeName, to avoid redundant information with the config
|
||||
typeName := res.Type()
|
||||
|
||||
if !res.IsEnabled() {
|
||||
logger.Debugf("-> %s: disabled", typeName)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("-> %s:", typeName)
|
||||
log.WithIndent(logger, " ", res.LogConfig)
|
||||
}
|
||||
|
||||
// Should be embedded in a Resolver to auto-implement `Resolver.Type`.
|
||||
type typed struct {
|
||||
typeName string
|
||||
}
|
||||
|
||||
func withType(t string) typed {
|
||||
return typed{typeName: t}
|
||||
}
|
||||
|
||||
// Type implements `Resolver`.
|
||||
func (t *typed) Type() string {
|
||||
return t.typeName
|
||||
}
|
||||
|
||||
func (t *typed) log() *logrus.Entry {
|
||||
return log.PrefixedLog(t.Type())
|
||||
}
|
||||
|
||||
// Should be embedded in a Resolver to auto-implement `config.Configurable`.
|
||||
type configurable[T config.Configurable] struct {
|
||||
cfg T
|
||||
}
|
||||
|
||||
func withConfig[T config.Configurable](cfg T) configurable[T] {
|
||||
return configurable[T]{cfg: cfg}
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *configurable[T]) IsEnabled() bool {
|
||||
return c.cfg.IsEnabled()
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
|
||||
c.cfg.LogConfig(logger)
|
||||
}
|
||||
|
|
|
@ -1,45 +1,127 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var systemResolverBootstrap = &Bootstrap{}
|
||||
|
||||
var _ = Describe("Resolver", func() {
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
Describe("Chains", func() {
|
||||
var (
|
||||
r1 ChainedResolver
|
||||
r2 ChainedResolver
|
||||
r3 ChainedResolver
|
||||
r4 Resolver
|
||||
)
|
||||
|
||||
Describe("Creating resolver chain", func() {
|
||||
When("A chain of resolvers will be created", func() {
|
||||
It("should be iterable by calling 'GetNext'", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
cr, _ := NewClientNamesResolver(config.ClientLookupConfig{}, nil, false)
|
||||
ch := Chain(br, cr)
|
||||
c, ok := ch.(ChainedResolver)
|
||||
Expect(ok).Should(BeTrue())
|
||||
BeforeEach(func() {
|
||||
r1 = &mockResolver{}
|
||||
r2 = &mockResolver{}
|
||||
r3 = &mockResolver{}
|
||||
r4 = &NoOpResolver{}
|
||||
})
|
||||
|
||||
next := c.GetNext()
|
||||
Expect(next).ShouldNot(BeNil())
|
||||
Describe("Chain", func() {
|
||||
It("should create a chain iterable using `GetNext`", func() {
|
||||
ch := Chain(r1, r2, r3, r4)
|
||||
Expect(ch).ShouldNot(BeNil())
|
||||
Expect(ch).Should(Equal(r1))
|
||||
Expect(r1.GetNext()).Should(Equal(r2))
|
||||
Expect(r2.GetNext()).Should(Equal(r3))
|
||||
Expect(r3.GetNext()).Should(Equal(r4))
|
||||
})
|
||||
|
||||
It("should not link a final ChainedResolver", func() {
|
||||
ch := Chain(r1, r2)
|
||||
Expect(ch).ShouldNot(BeNil())
|
||||
|
||||
Expect(r1.GetNext()).Should(Equal(r2))
|
||||
Expect(r2.GetNext()).Should(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ForEach", func() {
|
||||
It("should iterate on all resolvers in the chain", func() {
|
||||
ch := Chain(r1, r2, r3, r4)
|
||||
Expect(ch).ShouldNot(BeNil())
|
||||
|
||||
var itResult []Resolver
|
||||
|
||||
ForEach(ch, func(r Resolver) {
|
||||
itResult = append(itResult, r)
|
||||
})
|
||||
|
||||
Expect(itResult).ShouldNot(BeEmpty())
|
||||
Expect(itResult).Should(Equal([]Resolver{r1, r2, r3, r4}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogResolverConfig", func() {
|
||||
It("should call the resolver's `LogConfig`", func() {
|
||||
logger := logrus.NewEntry(log.Log())
|
||||
|
||||
m := &mockResolver{}
|
||||
m.On("IsEnabled").Return(true)
|
||||
m.On("LogConfig")
|
||||
|
||||
LogResolverConfig(m, logger)
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
When("the resolver is disabled", func() {
|
||||
It("should not call the resolver's `LogConfig`", func() {
|
||||
logger := logrus.NewEntry(log.Log())
|
||||
|
||||
m := &mockResolver{}
|
||||
m.On("IsEnabled").Return(false)
|
||||
|
||||
LogResolverConfig(m, logger)
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
When("'Name' is called", func() {
|
||||
It("should return resolver name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
name := Name(br)
|
||||
Expect(name).Should(Equal("BlockingResolver"))
|
||||
Expect(name).Should(Equal("blocking"))
|
||||
})
|
||||
})
|
||||
When("'Name' is called on a NamedResolver", func() {
|
||||
It("should return it's custom name", func() {
|
||||
It("should return its custom name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
|
||||
cfg := config.RewriteConfig{Rewrite: map[string]string{"not": "empty"}}
|
||||
cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}}
|
||||
r := NewRewriterResolver(cfg, br)
|
||||
|
||||
name := Name(r)
|
||||
Expect(name).Should(Equal("BlockingResolver w/ RewriterResolver"))
|
||||
Expect(name).Should(Equal("blocking w/ rewrite"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func expectValidResolverType(sut Resolver) {
|
||||
By("it must not contain spaces", func() {
|
||||
Expect(sut.Type()).ShouldNot(ContainSubstring(" "))
|
||||
})
|
||||
By("it must be lower case", func() {
|
||||
Expect(sut.Type()).Should(Equal(strings.ToLower(sut.Type())))
|
||||
})
|
||||
By("it must not contain 'resolver'", func() {
|
||||
Expect(sut.Type()).ShouldNot(ContainSubstring("resolver"))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -18,13 +18,14 @@ import (
|
|||
// The branch is where the rewrite is active. If the branch doesn't
|
||||
// yield a result, the normal resolving is continued.
|
||||
type RewriterResolver struct {
|
||||
configurable[*config.RewriterConfig]
|
||||
NextResolver
|
||||
rewrite map[string]string
|
||||
inner Resolver
|
||||
fallbackUpstream bool
|
||||
typed
|
||||
|
||||
inner Resolver
|
||||
}
|
||||
|
||||
func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) ChainedResolver {
|
||||
func NewRewriterResolver(cfg config.RewriterConfig, inner ChainedResolver) ChainedResolver {
|
||||
if len(cfg.Rewrite) == 0 {
|
||||
return inner
|
||||
}
|
||||
|
@ -36,27 +37,22 @@ func NewRewriterResolver(cfg config.RewriteConfig, inner ChainedResolver) Chaine
|
|||
inner.Next(NewNoOpResolver())
|
||||
|
||||
return &RewriterResolver{
|
||||
rewrite: cfg.Rewrite,
|
||||
inner: inner,
|
||||
fallbackUpstream: cfg.FallbackUpstream,
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("rewrite"),
|
||||
|
||||
inner: inner,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RewriterResolver) Name() string {
|
||||
return fmt.Sprintf("%s w/ %s", Name(r.inner), defaultName(r))
|
||||
return fmt.Sprintf("%s w/ %s", Name(r.inner), r.Type())
|
||||
}
|
||||
|
||||
// Configuration returns current resolver configuration
|
||||
func (r *RewriterResolver) Configuration() (result []string) {
|
||||
result = append(result, "rewrite:")
|
||||
for key, val := range r.rewrite {
|
||||
result = append(result, fmt.Sprintf(" %s = \"%s\"", key, val))
|
||||
}
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
|
||||
LogResolverConfig(r.inner, logger)
|
||||
|
||||
innerCfg := r.inner.Configuration()
|
||||
result = append(result, innerCfg...)
|
||||
|
||||
return result
|
||||
r.cfg.LogConfig(logger)
|
||||
}
|
||||
|
||||
// Resolve uses the inner resolver to resolve the rewritten query
|
||||
|
@ -79,7 +75,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
request.Req = original
|
||||
|
||||
fallbackCondition := err != nil || (response != NoResponse && response.Res.Answer == nil)
|
||||
if r.fallbackUpstream && fallbackCondition {
|
||||
if r.cfg.FallbackUpstream && fallbackCondition {
|
||||
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
|
||||
|
||||
|
@ -130,7 +126,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg
|
|||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"domain": domainOriginal,
|
||||
"rewrite": rewriteKey + ":" + r.rewrite[rewriteKey],
|
||||
"rewrite": rewriteKey + ":" + r.cfg.Rewrite[rewriteKey],
|
||||
}).Debugf("rewriting %q to %q", domainOriginal, domainRewritten)
|
||||
}
|
||||
}
|
||||
|
@ -139,7 +135,7 @@ func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg
|
|||
}
|
||||
|
||||
func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
|
||||
for k, v := range r.rewrite {
|
||||
for k, v := range r.cfg.Rewrite {
|
||||
if strings.HasSuffix(domain, "."+k) {
|
||||
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
|
||||
|
||||
|
|
|
@ -2,8 +2,10 @@ package resolver
|
|||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -19,7 +21,7 @@ const (
|
|||
var _ = Describe("RewriterResolver", func() {
|
||||
var (
|
||||
sut ChainedResolver
|
||||
sutConfig config.RewriteConfig
|
||||
sutConfig config.RewriterConfig
|
||||
mInner *mockResolver
|
||||
mNext *mockResolver
|
||||
|
||||
|
@ -29,11 +31,17 @@ var _ = Describe("RewriterResolver", func() {
|
|||
mNextResponse *model.Response
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mInner = &mockResolver{}
|
||||
mNext = &mockResolver{}
|
||||
|
||||
sutConfig = config.RewriteConfig{Rewrite: map[string]string{"original": "rewritten"}}
|
||||
sutConfig = config.RewriterConfig{Rewrite: map[string]string{"original": "rewritten"}}
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
|
@ -48,7 +56,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
|
||||
When("has no configuration", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.RewriteConfig{}
|
||||
sutConfig = config.RewriterConfig{}
|
||||
})
|
||||
|
||||
It("should return the inner resolver", func() {
|
||||
|
@ -194,11 +202,10 @@ var _ = Describe("RewriterResolver", func() {
|
|||
Describe("Configuration output", func() {
|
||||
When("resolver is enabled", func() {
|
||||
It("should return configuration", func() {
|
||||
innerOutput := []string{"inner:", "config-output"}
|
||||
mInner.On("Configuration").Return(innerOutput)
|
||||
mInner.On("LogConfig")
|
||||
mInner.On("IsEnabled").Return(true)
|
||||
|
||||
c := sut.Configuration()
|
||||
Expect(len(c)).Should(BeNumerically(">", len(innerOutput)))
|
||||
sut.LogConfig(logrus.NewEntry(log.Log()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -46,11 +47,15 @@ type defaultIPs struct {
|
|||
|
||||
type SpecialUseDomainNamesResolver struct {
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
defaults *defaultIPs
|
||||
}
|
||||
|
||||
func NewSpecialUseDomainNamesResolver() ChainedResolver {
|
||||
return &SpecialUseDomainNamesResolver{
|
||||
typed: withType("special_use_domains"),
|
||||
|
||||
defaults: &defaultIPs{
|
||||
loopbackV4: net.ParseIP("127.0.0.1"),
|
||||
loopbackV6: net.IPv6loopback,
|
||||
|
@ -58,6 +63,17 @@ func NewSpecialUseDomainNamesResolver() ChainedResolver {
|
|||
}
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (r *SpecialUseDomainNamesResolver) IsEnabled() bool {
|
||||
// RFC 6761 & 6762 are always active
|
||||
return true
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *SpecialUseDomainNamesResolver) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("enabled")
|
||||
}
|
||||
|
||||
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
// RFC 6761 - negative
|
||||
if r.isSpecial(request, sudnArpaSlice()...) ||
|
||||
|
@ -78,11 +94,6 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
|
|||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
// RFC 6761 & 6762 are always active
|
||||
func (r *SpecialUseDomainNamesResolver) Configuration() []string {
|
||||
return configEnabled
|
||||
}
|
||||
|
||||
func (r *SpecialUseDomainNamesResolver) isSpecial(request *model.Request, names ...string) bool {
|
||||
domainFromQuestion := request.Req.Question[0].Name
|
||||
for _, n := range names {
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -16,6 +17,12 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
m *mockResolver
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
mockAnswer, err := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -27,6 +34,22 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should not log anything", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Blocking special names", func() {
|
||||
It("should block arpa", func() {
|
||||
for _, arpa := range sudnArpaSlice() {
|
||||
|
@ -133,12 +156,4 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Configuration pseudo test", func() {
|
||||
It("should always be empty", func() {
|
||||
c := sut.Configuration()
|
||||
|
||||
Expect(len(c)).Should(BeNumerically(">=", 1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -30,6 +30,8 @@ const (
|
|||
|
||||
// UpstreamResolver sends request to external DNS server
|
||||
type UpstreamResolver struct {
|
||||
typed
|
||||
|
||||
upstream config.Upstream
|
||||
upstreamClient upstreamClient
|
||||
bootstrap *Bootstrap
|
||||
|
@ -51,7 +53,7 @@ type httpUpstreamClient struct {
|
|||
}
|
||||
|
||||
func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||
timeout := time.Duration(config.GetConfig().UpstreamTimeout)
|
||||
timeout := config.GetConfig().UpstreamTimeout.ToDuration()
|
||||
|
||||
tlsConfig := tls.Config{
|
||||
ServerName: cfg.Host,
|
||||
|
@ -209,25 +211,30 @@ func newUpstreamResolverUnchecked(upstream config.Upstream, bootstrap *Bootstrap
|
|||
upstreamClient := createUpstreamClient(upstream)
|
||||
|
||||
return &UpstreamResolver{
|
||||
typed: withType("upstream"),
|
||||
|
||||
upstream: upstream,
|
||||
upstreamClient: upstreamClient,
|
||||
bootstrap: bootstrap,
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration return current resolver configuration
|
||||
func (r *UpstreamResolver) Configuration() (result []string) {
|
||||
return []string{r.String()}
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (r *UpstreamResolver) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *UpstreamResolver) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info(r.upstream)
|
||||
}
|
||||
|
||||
func (r UpstreamResolver) String() string {
|
||||
return fmt.Sprintf("upstream '%s'", r.upstream.String())
|
||||
return fmt.Sprintf("%s '%s'", r.Type(), r.upstream)
|
||||
}
|
||||
|
||||
// Resolve calls external resolver
|
||||
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||
logger := log.WithPrefix(request.Log, "upstream_resolver")
|
||||
|
||||
ips, err := r.bootstrap.UpstreamIPs(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -247,7 +254,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
|||
var err error
|
||||
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol)
|
||||
if err == nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(resp.Answer),
|
||||
"return_code": dns.RcodeToString[resp.Rcode],
|
||||
"upstream": r.upstream.String(),
|
||||
|
@ -272,7 +279,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
|||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}),
|
||||
retry.OnRetry(func(n uint, err error) {
|
||||
logger.WithFields(logrus.Fields{
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"upstream": r.upstream.String(),
|
||||
"upstream_ip": ip.String(),
|
||||
"question": util.QuestionToString(request.Req.Question),
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -17,7 +18,40 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||
systemResolverBootstrap := &Bootstrap{}
|
||||
var (
|
||||
sut *UpstreamResolver
|
||||
sutConfig config.Upstream
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.Upstream{Host: "localhost"}
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut = newUpstreamResolverUnchecked(sutConfig, systemResolverBootstrap)
|
||||
})
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
It("is true", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Using DNS upstream", func() {
|
||||
When("Configured DNS resolver can resolve query", func() {
|
||||
|
@ -35,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
|
||||
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
@ -53,7 +87,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
HaveNoAnswer(),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeNameError),
|
||||
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream.String()))),
|
||||
HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
@ -216,15 +250,4 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
Describe("Configuration", func() {
|
||||
When("Configuration is called", func() {
|
||||
It("should return configuration", func() {
|
||||
sut := newUpstreamResolverUnchecked(config.Upstream{}, nil)
|
||||
|
||||
c := sut.Configuration()
|
||||
|
||||
Expect(len(c)).Should(BeNumerically(">=", 1))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -395,7 +395,7 @@ func createQueryResolver(
|
|||
redisClient *redis.Client,
|
||||
) (r resolver.Resolver, err error) {
|
||||
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
|
||||
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers, bootstrap, cfg.StartVerifyUpstream)
|
||||
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
|
||||
|
@ -411,16 +411,16 @@ func createQueryResolver(
|
|||
|
||||
r = resolver.Chain(
|
||||
resolver.NewFilteringResolver(cfg.Filtering),
|
||||
resolver.NewFqdnOnlyResolver(*cfg),
|
||||
resolver.NewFqdnOnlyResolver(cfg.FqdnOnly),
|
||||
clientNames,
|
||||
resolver.NewEdeResolver(cfg.Ede),
|
||||
resolver.NewQueryLoggingResolver(cfg.QueryLog),
|
||||
resolver.NewMetricsResolver(cfg.Prometheus),
|
||||
resolver.NewRewriterResolver(cfg.CustomDNS.RewriteConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
|
||||
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
|
||||
resolver.NewHostsFileResolver(cfg.HostsFile),
|
||||
blocking,
|
||||
resolver.NewCachingResolver(cfg.Caching, redisClient),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriteConfig, condUpstream),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
|
||||
resolver.NewSpecialUseDomainNamesResolver(),
|
||||
parallel,
|
||||
)
|
||||
|
@ -439,25 +439,12 @@ func (s *Server) registerDNSHandlers() {
|
|||
func (s *Server) printConfiguration() {
|
||||
logger().Info("current configuration:")
|
||||
|
||||
res := s.queryResolver
|
||||
for res != nil {
|
||||
logger().Infof("-> resolver: '%s'", resolver.Name(res))
|
||||
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
|
||||
resolver.LogResolverConfig(res, logger())
|
||||
})
|
||||
|
||||
for _, c := range res.Configuration() {
|
||||
logger().Infof(" %s", c)
|
||||
}
|
||||
|
||||
if c, ok := res.(resolver.ChainedResolver); ok {
|
||||
res = c.GetNext()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
logger().Infof("- DNS listening on addrs/ports: %v", s.cfg.Ports.DNS)
|
||||
logger().Infof("- TLS listening on addrs/ports: %v", s.cfg.Ports.TLS)
|
||||
logger().Infof("- HTTP listening on addrs/ports: %v", s.cfg.Ports.HTTP)
|
||||
logger().Infof("- HTTPS listening on addrs/ports: %v", s.cfg.Ports.HTTPS)
|
||||
logger().Info("listeners:")
|
||||
log.WithIndent(logger(), " ", s.cfg.Ports.LogConfig)
|
||||
|
||||
logger().Info("runtime information:")
|
||||
|
||||
|
@ -465,17 +452,19 @@ func (s *Server) printConfiguration() {
|
|||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
logger().Infof(" numCPU = %d", runtime.NumCPU())
|
||||
logger().Infof(" numGoroutine = %d", runtime.NumGoroutine())
|
||||
|
||||
// gather memory stats
|
||||
var m runtime.MemStats
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
logger().Infof("MEM Alloc = %10v MB", toMB(m.Alloc))
|
||||
logger().Infof("MEM HeapAlloc = %10v MB", toMB(m.HeapAlloc))
|
||||
logger().Infof("MEM Sys = %10v MB", toMB(m.Sys))
|
||||
logger().Infof("MEM NumGC = %10v", m.NumGC)
|
||||
logger().Infof("RUN NumCPU = %10d", runtime.NumCPU())
|
||||
logger().Infof("RUN NumGoroutine = %10d", runtime.NumGoroutine())
|
||||
logger().Infof(" memory:")
|
||||
logger().Infof(" alloc = %10v MB", toMB(m.Alloc))
|
||||
logger().Infof(" heapAlloc = %10v MB", toMB(m.HeapAlloc))
|
||||
logger().Infof(" sys = %10v MB", toMB(m.Sys))
|
||||
logger().Infof(" numGC = %10v", m.NumGC)
|
||||
}
|
||||
|
||||
func toMB(b uint64) uint64 {
|
||||
|
|
|
@ -105,7 +105,7 @@ var _ = BeforeSuite(func() {
|
|||
// create server
|
||||
sut, err = NewServer(&config.Config{
|
||||
CustomDNS: config.CustomDNSConfig{
|
||||
CustomTTL: config.Duration(time.Duration(3600) * time.Second),
|
||||
CustomTTL: config.Duration(3600 * time.Second),
|
||||
Mapping: config.CustomDNSMapping{
|
||||
HostIPs: map[string][]net.IP{
|
||||
"custom.lan": {net.ParseIP("192.168.178.55")},
|
||||
|
@ -143,7 +143,7 @@ var _ = BeforeSuite(func() {
|
|||
BlockType: "zeroIp",
|
||||
BlockTTL: config.Duration(6 * time.Hour),
|
||||
},
|
||||
Upstream: config.UpstreamConfig{
|
||||
Upstream: config.ParallelBestConfig{
|
||||
ExternalResolvers: map[string][]config.Upstream{"default": {upstreamGoogle}},
|
||||
},
|
||||
ClientLookup: config.ClientLookupConfig{
|
||||
|
@ -158,7 +158,7 @@ var _ = BeforeSuite(func() {
|
|||
},
|
||||
CertFile: certPem.Path,
|
||||
KeyFile: keyPem.Path,
|
||||
Prometheus: config.PrometheusConfig{
|
||||
Prometheus: config.MetricsConfig{
|
||||
Enable: true,
|
||||
Path: "/metrics",
|
||||
},
|
||||
|
@ -643,7 +643,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
It("start was called 2 times, start should fail", func() {
|
||||
// create server
|
||||
server, err := NewServer(&config.Config{
|
||||
Upstream: config.UpstreamConfig{
|
||||
Upstream: config.ParallelBestConfig{
|
||||
ExternalResolvers: map[string][]config.Upstream{
|
||||
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
|
||||
},
|
||||
|
@ -685,7 +685,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
It("stop was called 2 times, start should fail", func() {
|
||||
// create server
|
||||
server, err := NewServer(&config.Config{
|
||||
Upstream: config.UpstreamConfig{
|
||||
Upstream: config.ParallelBestConfig{
|
||||
ExternalResolvers: map[string][]config.Upstream{
|
||||
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
|
||||
},
|
||||
|
|
|
@ -109,7 +109,7 @@ func NewMsgWithQuestion(question string, qType dns.Type) *dns.Msg {
|
|||
|
||||
// NewMsgWithAnswer creates new DNS message with answer
|
||||
func NewMsgWithAnswer(domain string, ttl uint, dnsType dns.Type, address string) (*dns.Msg, error) {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType.String(), address))
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType, address))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue