diff --git a/.golangci.yml b/.golangci.yml
index 53153135..3572589d 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -80,6 +80,7 @@ issues:
     # Exclude some linters from running on tests files.
     - path: _test\.go
       linters:
-        - gochecknoglobals
         - dupl
+        - funlen
+        - gochecknoglobals
         - gosec
diff --git a/config/blocking.go b/config/blocking.go
index f4f4ac12..364d8dfc 100644
--- a/config/blocking.go
+++ b/config/blocking.go
@@ -1,8 +1,6 @@
 package config
 
 import (
-	"strings"
-
 	. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
 	"github.com/0xERR0R/blocky/log"
 	"github.com/sirupsen/logrus"
@@ -10,32 +8,40 @@ import (
 
 // BlockingConfig configuration for query blocking
 type BlockingConfig struct {
-	BlackLists            map[string][]string `yaml:"blackLists"`
-	WhiteLists            map[string][]string `yaml:"whiteLists"`
-	ClientGroupsBlock     map[string][]string `yaml:"clientGroupsBlock"`
-	BlockType             string              `yaml:"blockType" default:"ZEROIP"`
-	BlockTTL              Duration            `yaml:"blockTTL" default:"6h"`
-	DownloadTimeout       Duration            `yaml:"downloadTimeout" default:"60s"`
-	DownloadAttempts      uint                `yaml:"downloadAttempts" default:"3"`
-	DownloadCooldown      Duration            `yaml:"downloadCooldown" default:"1s"`
-	RefreshPeriod         Duration            `yaml:"refreshPeriod" default:"4h"`
-	ProcessingConcurrency uint                `yaml:"processingConcurrency" default:"4"`
-	StartStrategy         StartStrategyType   `yaml:"startStrategy" default:"blocking"`
-	MaxErrorsPerFile      int                 `yaml:"maxErrorsPerFile" default:"5"`
+	BlackLists        map[string][]BytesSource `yaml:"blackLists"`
+	WhiteLists        map[string][]BytesSource `yaml:"whiteLists"`
+	ClientGroupsBlock map[string][]string      `yaml:"clientGroupsBlock"`
+	BlockType         string                   `yaml:"blockType" default:"ZEROIP"`
+	BlockTTL          Duration                 `yaml:"blockTTL" default:"6h"`
+	Loading           SourceLoadingConfig      `yaml:"loading"`
 
 	// Deprecated options
 	Deprecated struct {
-		FailStartOnListError *bool `yaml:"failStartOnListError"`
+		DownloadTimeout       *Duration          `yaml:"downloadTimeout"`
+		DownloadAttempts      *uint              `yaml:"downloadAttempts"`
+		DownloadCooldown      *Duration          `yaml:"downloadCooldown"`
+		RefreshPeriod         *Duration          `yaml:"refreshPeriod"`
+		FailStartOnListError  *bool              `yaml:"failStartOnListError"`
+		ProcessingConcurrency *uint              `yaml:"processingConcurrency"`
+		StartStrategy         *StartStrategyType `yaml:"startStrategy"`
+		MaxErrorsPerFile      *int               `yaml:"maxErrorsPerFile"`
 	} `yaml:",inline"`
 }
 
 func (c *BlockingConfig) migrate(logger *logrus.Entry) bool {
 	return Migrate(logger, "blocking", c.Deprecated, map[string]Migrator{
-		"failStartOnListError": Apply(To("startStrategy", c), func(oldValue bool) {
-			if oldValue && c.StartStrategy != StartStrategyTypeFast {
-				c.StartStrategy = StartStrategyTypeFailOnError
+		"downloadTimeout":  Move(To("loading.downloads.timeout", &c.Loading.Downloads)),
+		"downloadAttempts": Move(To("loading.downloads.attempts", &c.Loading.Downloads)),
+		"downloadCooldown": Move(To("loading.downloads.cooldown", &c.Loading.Downloads)),
+		"refreshPeriod":    Move(To("loading.refreshPeriod", &c.Loading)),
+		"failStartOnListError": Apply(To("loading.strategy", &c.Loading), func(oldValue bool) {
+			if oldValue && c.Loading.Strategy != StartStrategyTypeFast {
+				c.Loading.Strategy = StartStrategyTypeFailOnError
 			}
 		}),
+		"processingConcurrency": Move(To("loading.concurrency", &c.Loading)),
+		"startStrategy":         Move(To("loading.strategy", &c.Loading)),
+		"maxErrorsPerFile":      Move(To("loading.maxErrorsPerSource", &c.Loading)),
 	})
 }
 
@@ -44,7 +50,7 @@ func (c *BlockingConfig) IsEnabled() bool {
 	return len(c.ClientGroupsBlock) != 0
 }
 
-// IsEnabled implements `config.Configurable`.
+// LogConfig implements `config.Configurable`. func (c *BlockingConfig) LogConfig(logger *logrus.Entry) { logger.Info("clientGroupsBlock:") @@ -58,17 +64,8 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) { logger.Infof("blockTTL = %s", c.BlockTTL) } - logger.Infof("downloadTimeout = %s", c.DownloadTimeout) - - logger.Infof("startStrategy = %s", c.StartStrategy) - - logger.Infof("maxErrorsPerFile = %d", c.MaxErrorsPerFile) - - if c.RefreshPeriod > 0 { - logger.Infof("refresh = every %s", c.RefreshPeriod) - } else { - logger.Debug("refresh = disabled") - } + logger.Info("loading:") + log.WithIndent(logger, " ", c.Loading.LogConfig) logger.Info("blacklist:") log.WithIndent(logger, " ", func(logger *logrus.Entry) { @@ -81,18 +78,12 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) { }) } -func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) { - for group, links := range listGroups { +func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]BytesSource) { + for group, sources := range listGroups { logger.Infof("%s:", group) - for _, link := range links { - if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char - link = link[:idx] // first line only - - logger.Infof(" - %s [...]", link) - } else { - logger.Infof(" - %s", link) - } + for _, source := range sources { + logger.Infof(" - %s", source) } } } diff --git a/config/blocking_test.go b/config/blocking_test.go index 59258525..17b1a6ff 100644 --- a/config/blocking_test.go +++ b/config/blocking_test.go @@ -6,7 +6,6 @@ import ( "github.com/creasty/defaults" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/sirupsen/logrus" ) var _ = Describe("BlockingConfig", func() { @@ -18,13 +17,12 @@ var _ = Describe("BlockingConfig", func() { cfg = BlockingConfig{ BlockType: "ZEROIP", BlockTTL: Duration(time.Minute), - BlackLists: map[string][]string{ - "gr1": {"/a/file/path"}, + BlackLists: map[string][]BytesSource{ + "gr1": NewBytesSources("/a/file/path"), }, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, - RefreshPeriod: Duration(time.Hour), } }) @@ -59,26 +57,7 @@ var _ = Describe("BlockingConfig", func() { Expect(hook.Calls).ShouldNot(BeEmpty()) Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:")) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour"))) - }) - When("refresh is disabled", func() { - It("should reflect that", func() { - cfg.RefreshPeriod = Duration(-1) - - logger.Logger.Level = logrus.InfoLevel - - cfg.LogConfig(logger) - - Expect(hook.Calls).ShouldNot(BeEmpty()) - Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled"))) - - logger.Logger.Level = logrus.TraceLevel - - cfg.LogConfig(logger) - - Expect(hook.Calls).ShouldNot(BeEmpty()) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled"))) - }) + Expect(hook.Messages).Should(ContainElement(Equal("blockType = ZEROIP"))) }) }) }) diff --git a/config/bytes_source.go b/config/bytes_source.go new file mode 100644 index 00000000..c78277ae --- /dev/null +++ b/config/bytes_source.go @@ -0,0 +1,111 @@ +//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values +package config + +import ( + "fmt" + "strings" +) + +const maxTextSourceDisplayLen = 12 + +// var BytesSourceNone = BytesSource{} + +// BytesSourceType supported BytesSource types. ENUM( +// text=1 // Inline YAML block. +// http // HTTP(S). 
+// file // Local file.
+// )
+type BytesSourceType uint16
+
+type BytesSource struct {
+	Type BytesSourceType
+	From string
+}
+
+func (s BytesSource) String() string {
+	switch s.Type {
+	case BytesSourceTypeText:
+		break
+
+	case BytesSourceTypeHttp:
+		return s.From
+
+	case BytesSourceTypeFile:
+		return fmt.Sprintf("file://%s", s.From)
+
+	default:
+		return fmt.Sprintf("unknown source (%s: %s)", s.Type, s.From)
+	}
+
+	text := s.From
+	truncated := false
+
+	if idx := strings.IndexRune(text, '\n'); idx != -1 {
+		truncated = idx < len(text)-1 // don't count removing a trailing newline
+		text = text[:idx]             // first line only
+	}
+
+	if len(text) > maxTextSourceDisplayLen { // truncate
+		text = text[:maxTextSourceDisplayLen]
+		truncated = true
+	}
+
+	if truncated {
+		return fmt.Sprintf("%s...", text)
+	}
+
+	return text
+}
+
+// UnmarshalText implements `encoding.TextUnmarshaler`.
+func (s *BytesSource) UnmarshalText(data []byte) error {
+	source := string(data)
+
+	switch {
+	// Inline definition in YAML (with literal style Block Scalar)
+	case strings.ContainsAny(source, "\n"):
+		*s = BytesSource{Type: BytesSourceTypeText, From: source}
+
+	// HTTP(S)
+	case strings.HasPrefix(source, "http"):
+		*s = BytesSource{Type: BytesSourceTypeHttp, From: source}
+
+	// Probably path to a local file
+	default:
+		*s = BytesSource{Type: BytesSourceTypeFile, From: strings.TrimPrefix(source, "file://")}
+	}
+
+	return nil
+}
+
+func newBytesSource(source string) BytesSource {
+	var res BytesSource
+
+	// UnmarshalText never returns an error
+	_ = res.UnmarshalText([]byte(source))
+
+	return res
+}
+
+func NewBytesSources(sources ...string) []BytesSource {
+	res := make([]BytesSource, 0, len(sources))
+
+	for _, source := range sources {
+		res = append(res, newBytesSource(source))
+	}
+
+	return res
+}
+
+func TextBytesSource(lines ...string) BytesSource {
+	return BytesSource{Type: BytesSourceTypeText, From: inlineList(lines...)}
+}
+
+func inlineList(lines ...string) string {
+	res := strings.Join(lines, "\n")
+
+	// ensure at least one line ending so it's parsed as an inline block
+	res += "\n"
+
+	return res
+}
diff --git a/config/bytes_source_enum.go b/config/bytes_source_enum.go
new file mode 100644
index 00000000..7c4dd859
--- /dev/null
+++ b/config/bytes_source_enum.go
@@ -0,0 +1,101 @@
+// Code generated by go-enum DO NOT EDIT.
+// Version:
+// Revision:
+// Build Date:
+// Built By:
+
+package config
+
+import (
+	"fmt"
+	"strings"
+)
+
+const (
+	// BytesSourceTypeText is a BytesSourceType of type Text.
+	// Inline YAML block.
+	BytesSourceTypeText BytesSourceType = iota + 1
+	// BytesSourceTypeHttp is a BytesSourceType of type Http.
+	// HTTP(S).
+	BytesSourceTypeHttp
+	// BytesSourceTypeFile is a BytesSourceType of type File.
+	// Local file.
+	BytesSourceTypeFile
+)
+
+var ErrInvalidBytesSourceType = fmt.Errorf("not a valid BytesSourceType, try [%s]", strings.Join(_BytesSourceTypeNames, ", "))
+
+const _BytesSourceTypeName = "texthttpfile"
+
+var _BytesSourceTypeNames = []string{
+	_BytesSourceTypeName[0:4],
+	_BytesSourceTypeName[4:8],
+	_BytesSourceTypeName[8:12],
+}
+
+// BytesSourceTypeNames returns a list of possible string values of BytesSourceType.
+func BytesSourceTypeNames() []string { + tmp := make([]string, len(_BytesSourceTypeNames)) + copy(tmp, _BytesSourceTypeNames) + return tmp +} + +// BytesSourceTypeValues returns a list of the values for BytesSourceType +func BytesSourceTypeValues() []BytesSourceType { + return []BytesSourceType{ + BytesSourceTypeText, + BytesSourceTypeHttp, + BytesSourceTypeFile, + } +} + +var _BytesSourceTypeMap = map[BytesSourceType]string{ + BytesSourceTypeText: _BytesSourceTypeName[0:4], + BytesSourceTypeHttp: _BytesSourceTypeName[4:8], + BytesSourceTypeFile: _BytesSourceTypeName[8:12], +} + +// String implements the Stringer interface. +func (x BytesSourceType) String() string { + if str, ok := _BytesSourceTypeMap[x]; ok { + return str + } + return fmt.Sprintf("BytesSourceType(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x BytesSourceType) IsValid() bool { + _, ok := _BytesSourceTypeMap[x] + return ok +} + +var _BytesSourceTypeValue = map[string]BytesSourceType{ + _BytesSourceTypeName[0:4]: BytesSourceTypeText, + _BytesSourceTypeName[4:8]: BytesSourceTypeHttp, + _BytesSourceTypeName[8:12]: BytesSourceTypeFile, +} + +// ParseBytesSourceType attempts to convert a string to a BytesSourceType. +func ParseBytesSourceType(name string) (BytesSourceType, error) { + if x, ok := _BytesSourceTypeValue[name]; ok { + return x, nil + } + return BytesSourceType(0), fmt.Errorf("%s is %w", name, ErrInvalidBytesSourceType) +} + +// MarshalText implements the text marshaller method. +func (x BytesSourceType) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *BytesSourceType) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseBytesSourceType(name) + if err != nil { + return err + } + *x = tmp + return nil +} diff --git a/config/caching_test.go b/config/caching_test.go index 1b80c7ac..1cc2c635 100644 --- a/config/caching_test.go +++ b/config/caching_test.go @@ -60,4 +60,20 @@ var _ = Describe("CachingConfig", func() { }) }) }) + + Describe("EnablePrefetch", func() { + When("prefetching is enabled", func() { + BeforeEach(func() { + cfg = CachingConfig{} + }) + + It("should return configuration", func() { + cfg.EnablePrefetch() + + Expect(cfg.Prefetching).Should(BeTrue()) + Expect(cfg.PrefetchThreshold).Should(Equal(0)) + Expect(cfg.MaxCachingTime).ShouldNot(BeZero()) + }) + }) + }) }) diff --git a/config/config.go b/config/config.go index 1fa2e135..1580352f 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( + "context" "errors" "fmt" "net" @@ -10,6 +11,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/sirupsen/logrus" @@ -32,7 +34,7 @@ type Configurable interface { // LogConfig logs the receiver's configuration. // - // Calling this method when `IsEnabled` returns false is undefined. + // The behavior of this method is undefined when `IsEnabled` returns false. 
LogConfig(*logrus.Entry) } @@ -93,6 +95,30 @@ type QueryLogType int16 // ) type StartStrategyType uint16 +func (s *StartStrategyType) do(setup func() error, logErr func(error)) error { + if *s == StartStrategyTypeFast { + go func() { + err := setup() + if err != nil { + logErr(err) + } + }() + + return nil + } + + err := setup() + if err != nil { + logErr(err) + + if *s == StartStrategyTypeFailOnError { + return err + } + } + + return nil +} + // QueryLogField data field to be logged // ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration) type QueryLogField string @@ -259,6 +285,86 @@ func (c *toEnable) LogConfig(logger *logrus.Entry) { logger.Info("enabled") } +type SourceLoadingConfig struct { + Concurrency uint `yaml:"concurrency" default:"4"` + MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"` + RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"` + Strategy StartStrategyType `yaml:"strategy" default:"blocking"` + Downloads DownloaderConfig `yaml:"downloads"` +} + +func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("concurrency = %d", c.Concurrency) + logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource) + logger.Debugf("strategy = %s", c.Strategy) + + if c.RefreshPeriod > 0 { + logger.Infof("refresh = every %s", c.RefreshPeriod) + } else { + logger.Debug("refresh = disabled") + } + + logger.Info("downloads:") + log.WithIndent(logger, " ", c.Downloads.LogConfig) +} + +func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error { + refreshAndRecover := func(ctx context.Context) (rerr error) { + defer func() { + if val := recover(); val != nil { + rerr = fmt.Errorf("refresh function panicked: %v", val) + } + }() + + return refresh(ctx) + } + + err := c.Strategy.do(func() error { return refreshAndRecover(context.Background()) }, logErr) + if err != nil { + return err + } + + if c.RefreshPeriod > 0 { + go c.periodically(refreshAndRecover, logErr) + } + + return nil +} + +func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) { + ticker := time.NewTicker(c.RefreshPeriod.ToDuration()) + defer ticker.Stop() + + for range ticker.C { + err := refresh(context.Background()) + if err != nil { + logErr(err) + } + } +} + +type DownloaderConfig struct { + Timeout Duration `yaml:"timeout" default:"5s"` + Attempts uint `yaml:"attempts" default:"3"` + Cooldown Duration `yaml:"cooldown" default:"500ms"` +} + +func (c *DownloaderConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("timeout = %s", c.Timeout) + logger.Infof("attempts = %d", c.Attempts) + logger.Debugf("cooldown = %s", c.Cooldown) +} + +func WithDefaults[T any]() (T, error) { + var cfg T + + if err := defaults.Set(&cfg); err != nil { + return cfg, fmt.Errorf("can't apply %T defaults: %w", cfg, err) + } + + return cfg, nil +} + //nolint:gochecknoglobals var ( config = &Config{} @@ -270,9 +376,9 @@ func LoadConfig(path string, mandatory bool) (*Config, error) { cfgLock.Lock() defer cfgLock.Unlock() - cfg := Config{} - if err := defaults.Set(&cfg); err != nil { - return nil, fmt.Errorf("can't apply default values: %w", err) + cfg, err := WithDefaults[Config]() + if err != nil { + return nil, err } fs, err := os.Stat(path) @@ -398,6 +504,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool { }) usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts + usesDepredOpts = cfg.HostsFile.migrate(logger) || usesDepredOpts return usesDepredOpts } diff --git 
a/config/config_enum.go b/config/config_enum.go index 337b258b..e445ac5d 100644 --- a/config/config_enum.go +++ b/config/config_enum.go @@ -63,6 +63,13 @@ func (x IPVersion) String() string { return fmt.Sprintf("IPVersion(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x IPVersion) IsValid() bool { + _, ok := _IPVersionMap[x] + return ok +} + var _IPVersionValue = map[string]IPVersion{ _IPVersionName[0:4]: IPVersionDual, _IPVersionName[4:6]: IPVersionV4, @@ -145,6 +152,13 @@ func (x NetProtocol) String() string { return fmt.Sprintf("NetProtocol(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x NetProtocol) IsValid() bool { + _, ok := _NetProtocolMap[x] + return ok +} + var _NetProtocolValue = map[string]NetProtocol{ _NetProtocolName[0:7]: NetProtocolTcpUdp, _NetProtocolName[7:14]: NetProtocolTcpTls, @@ -225,7 +239,8 @@ func (x QueryLogField) String() string { return string(x) } -// String implements the Stringer interface. +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values func (x QueryLogField) IsValid() bool { _, err := ParseQueryLogField(string(x)) return err == nil @@ -333,6 +348,13 @@ func (x QueryLogType) String() string { return fmt.Sprintf("QueryLogType(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x QueryLogType) IsValid() bool { + _, ok := _QueryLogTypeMap[x] + return ok +} + var _QueryLogTypeValue = map[string]QueryLogType{ _QueryLogTypeName[0:7]: QueryLogTypeConsole, _QueryLogTypeName[7:11]: QueryLogTypeNone, @@ -418,6 +440,13 @@ func (x StartStrategyType) String() string { return fmt.Sprintf("StartStrategyType(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x StartStrategyType) IsValid() bool { + _, ok := _StartStrategyTypeMap[x] + return ok +} + var _StartStrategyTypeValue = map[string]StartStrategyType{ _StartStrategyTypeName[0:8]: StartStrategyTypeBlocking, _StartStrategyTypeName[8:19]: StartStrategyTypeFailOnError, diff --git a/config/config_test.go b/config/config_test.go index ccef7a31..5e2039c7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,11 +1,15 @@ package config import ( + "context" + "errors" "net" + "sync/atomic" "time" "github.com/creasty/defaults" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -48,15 +52,17 @@ var _ = Describe("Config", func() { BeforeEach(func() { c.Blocking.Deprecated.FailStartOnListError = ptrOf(true) }) - It("should change StartStrategy blocking to failOnError", func() { - c.Blocking.StartStrategy = StartStrategyTypeBlocking + It("should change loading.strategy blocking to failOnError", func() { + c.Blocking.Loading.Strategy = StartStrategyTypeBlocking c.migrate(logger) - Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFailOnError)) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy"))) + Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFailOnError)) }) - It("shouldn't change StartStrategy if set to fast", func() { - c.Blocking.StartStrategy = StartStrategyTypeFast + It("shouldn't change loading.strategy if set to fast", func() { + c.Blocking.Loading.Strategy = StartStrategyTypeFast 
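+				// the deprecated flag must not downgrade an explicitly configured fast strategy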
 				c.migrate(logger)
-				Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFast))
+				Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy")))
+				Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFast))
 			})
 		})
 
@@ -206,8 +212,10 @@ var _ = Describe("Config", func() {
 		When("duration is in wrong format", func() {
 			It("should return error", func() {
 				cfg := Config{}
-				data := `blocking:
-  refreshPeriod: wrongduration`
+				data := `
+blocking:
+  loading:
+    refreshPeriod: wrongduration`
 				err := unmarshalConfig([]byte(data), &cfg)
 				Expect(err).Should(HaveOccurred())
 				Expect(err.Error()).Should(ContainSubstring("invalid duration \"wrongduration\""))
@@ -534,6 +542,222 @@ bootstrapDns:
 				"tcp-tls:[fd00::6cd4:d7e0:d99d:2952]",
 			),
 		)
+
+	Describe("SourceLoadingConfig", func() {
+		var cfg SourceLoadingConfig
+
+		BeforeEach(func() {
+			cfg = SourceLoadingConfig{
+				Concurrency:   12,
+				RefreshPeriod: Duration(time.Hour),
+			}
+		})
+
+		Describe("LogConfig", func() {
+			It("should log configuration", func() {
+				cfg.LogConfig(logger)
+
+				Expect(hook.Calls).ShouldNot(BeEmpty())
+				Expect(hook.Messages[0]).Should(Equal("concurrency = 12"))
+				Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
+			})
+			When("refresh is disabled", func() {
+				BeforeEach(func() {
+					cfg.RefreshPeriod = Duration(-1)
+				})
+
+				It("should reflect that", func() {
+					logger.Logger.Level = logrus.InfoLevel
+
+					cfg.LogConfig(logger)
+
+					Expect(hook.Calls).ShouldNot(BeEmpty())
+					Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
+
+					logger.Logger.Level = logrus.TraceLevel
+
+					cfg.LogConfig(logger)
+
+					Expect(hook.Calls).ShouldNot(BeEmpty())
+					Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
+				})
+			})
+		})
+	})
+
+	Describe("StartStrategyType", func() {
+		Describe("StartStrategyTypeBlocking", func() {
+			It("runs in the current goroutine", func() {
+				sut := StartStrategyTypeBlocking
+				panicVal := new(int)
+
+				defer func() {
+					// recover will catch the panic if it happened in the same goroutine
+					Expect(recover()).Should(BeIdenticalTo(panicVal))
+				}()
+
+				_ = sut.do(func() error {
+					panic(panicVal)
+				}, nil)
+
+				Fail("unreachable")
+			})
+
+			It("logs errors and doesn't return them", func() {
+				sut := StartStrategyTypeBlocking
+				expectedErr := errors.New("test")
+
+				err := sut.do(func() error {
+					return expectedErr
+				}, func(err error) {
+					Expect(err).Should(MatchError(expectedErr))
+				})
+
+				Expect(err).Should(Succeed())
+			})
+		})
+
+		Describe("StartStrategyTypeFailOnError", func() {
+			It("runs in the current goroutine", func() {
+				sut := StartStrategyTypeFailOnError
+				panicVal := new(int)
+
+				defer func() {
+					// recover will catch the panic if it happened in the same goroutine
+					Expect(recover()).Should(BeIdenticalTo(panicVal))
+				}()
+
+				_ = sut.do(func() error {
+					panic(panicVal)
+				}, nil)
+
+				Fail("unreachable")
+			})
+
+			It("logs errors and returns them", func() {
+				sut := StartStrategyTypeFailOnError
+				expectedErr := errors.New("test")
+
+				err := sut.do(func() error {
+					return expectedErr
+				}, func(err error) {
+					Expect(err).Should(MatchError(expectedErr))
+				})
+
+				Expect(err).Should(MatchError(expectedErr))
+			})
+		})
+
+		Describe("StartStrategyTypeFast", func() {
+			It("runs in a new goroutine", func() {
+				sut := StartStrategyTypeFast
+				events := make(chan string)
+				wait := make(chan struct{})
+
+				err := sut.do(func() error {
+					events <- "start"
+					<-wait
+					events <- "done"
+
+					return nil
+				}, nil)
+
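+				// NOTE: with the fast strategy, do() returns before the callback completes;
+				// the events/wait channels below synchronize the test with that goroutine.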
+ Eventually(events, "50ms").Should(Receive(Equal("start"))) + Expect(err).Should(Succeed()) + Consistently(events).ShouldNot(Receive()) + close(wait) + Eventually(events, "50ms").Should(Receive(Equal("done"))) + }) + + It("logs errors", func() { + sut := StartStrategyTypeFast + expectedErr := errors.New("test") + wait := make(chan struct{}) + + err := sut.do(func() error { + return expectedErr + }, func(err error) { + Expect(err).Should(MatchError(expectedErr)) + close(wait) + }) + + Expect(err).Should(Succeed()) + Eventually(wait, "50ms").Should(BeClosed()) + }) + }) + }) + + Describe("SourceLoadingConfig", func() { + It("handles panics", func() { + sut := SourceLoadingConfig{ + Strategy: StartStrategyTypeFailOnError, + } + + panicMsg := "panic value" + + err := sut.StartPeriodicRefresh(func(context.Context) error { + panic(panicMsg) + }, func(err error) { + Expect(err).Should(MatchError(ContainSubstring(panicMsg))) + }) + + Expect(err).Should(MatchError(ContainSubstring(panicMsg))) + }) + + It("periodically calls refresh", func() { + sut := SourceLoadingConfig{ + Strategy: StartStrategyTypeFast, + RefreshPeriod: Duration(5 * time.Millisecond), + } + + panicMsg := "panic value" + calls := make(chan int32) + + var call atomic.Int32 + + err := sut.StartPeriodicRefresh(func(context.Context) error { + call := call.Add(1) + calls <- call + + if call == 3 { + panic(panicMsg) + } + + return nil + }, func(err error) { + defer GinkgoRecover() + + Expect(err).Should(MatchError(ContainSubstring(panicMsg))) + Expect(call.Load()).Should(Equal(int32(3))) + }) + + Expect(err).Should(Succeed()) + Eventually(calls, "50ms").Should(Receive(Equal(int32(1)))) + Eventually(calls, "50ms").Should(Receive(Equal(int32(2)))) + Eventually(calls, "50ms").Should(Receive(Equal(int32(3)))) + }) + }) + + Describe("WithDefaults", func() { + It("use valid defaults", func() { + type T struct { + X int `default:"1"` + } + + t, err := WithDefaults[T]() + Expect(err).Should(Succeed()) + Expect(t.X).Should(Equal(1)) + }) + + It("return an error if the tag is invalid", func() { + type T struct { + X struct{} `default:"fail"` + } + + _, err := WithDefaults[T]() + Expect(err).ShouldNot(Succeed()) + }) + }) }) func defaultTestFileConfig() { @@ -557,7 +781,7 @@ func defaultTestFileConfig() { Expect(config.Blocking.WhiteLists).Should(HaveLen(1)) Expect(config.Blocking.ClientGroupsBlock).Should(HaveLen(2)) Expect(config.Blocking.BlockTTL).Should(Equal(Duration(time.Minute))) - Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour))) + Expect(config.Blocking.Loading.RefreshPeriod).Should(Equal(Duration(2 * time.Hour))) Expect(config.Filtering.QueryTypes).Should(HaveLen(2)) Expect(config.FqdnOnly.Enable).Should(BeTrue()) @@ -613,7 +837,8 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { " Laptop-D.fritz.box:", " - ads", " blockTTL: 1m", - " refreshPeriod: 120", + " loading:", + " refreshPeriod: 120", "clientLookup:", " upstream: 192.168.178.1", " singleNameOrder:", @@ -629,7 +854,6 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { "startVerifyUpstream: false") } -//nolint:funlen func writeConfigDir(tmpDir *helpertest.TmpFolder) error { f1 := tmpDir.CreateStringFile("config1.yaml", "upstream:", @@ -675,7 +899,8 @@ func writeConfigDir(tmpDir *helpertest.TmpFolder) error { " Laptop-D.fritz.box:", " - ads", " blockTTL: 1m", - " refreshPeriod: 120", + " loading:", + " refreshPeriod: 120", "clientLookup:", " upstream: 192.168.178.1", " singleNameOrder:", diff --git 
a/config/hosts_file.go b/config/hosts_file.go
index 967b735d..b3402941 100644
--- a/config/hosts_file.go
+++ b/config/hosts_file.go
@@ -1,25 +1,49 @@
 package config
 
 import (
+	. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
+	"github.com/0xERR0R/blocky/log"
 	"github.com/sirupsen/logrus"
 )
 
 type HostsFileConfig struct {
-	Filepath       string   `yaml:"filePath"`
-	HostsTTL       Duration `yaml:"hostsTTL" default:"1h"`
-	RefreshPeriod  Duration `yaml:"refreshPeriod" default:"1h"`
-	FilterLoopback bool     `yaml:"filterLoopback"`
+	Sources        []BytesSource       `yaml:"sources"`
+	HostsTTL       Duration            `yaml:"hostsTTL" default:"1h"`
+	FilterLoopback bool                `yaml:"filterLoopback"`
+	Loading        SourceLoadingConfig `yaml:"loading"`
+
+	// Deprecated options
+	Deprecated struct {
+		RefreshPeriod *Duration    `yaml:"refreshPeriod"`
+		Filepath      *BytesSource `yaml:"filePath"`
+	} `yaml:",inline"`
+}
+
+func (c *HostsFileConfig) migrate(logger *logrus.Entry) bool {
+	return Migrate(logger, "hostsFile", c.Deprecated, map[string]Migrator{
+		"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
+		"filePath": Apply(To("sources", c), func(value BytesSource) {
+			c.Sources = append(c.Sources, value)
+		}),
+	})
 }
 
 // IsEnabled implements `config.Configurable`.
 func (c *HostsFileConfig) IsEnabled() bool {
-	return len(c.Filepath) != 0
+	return len(c.Sources) != 0
 }
 
 // LogConfig implements `config.Configurable`.
 func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) {
-	logger.Infof("file path: %s", c.Filepath)
 	logger.Infof("TTL: %s", c.HostsTTL)
-	logger.Infof("refresh period: %s", c.RefreshPeriod)
 	logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
+
+	logger.Info("loading:")
+	log.WithIndent(logger, "  ", c.Loading.LogConfig)
+
+	logger.Info("sources:")
+
+	for _, source := range c.Sources {
+		logger.Infof("  - %s", source)
+	}
 }
diff --git a/config/hosts_file_test.go b/config/hosts_file_test.go
index 5f67feb8..4ff5d9c1 100644
--- a/config/hosts_file_test.go
+++ b/config/hosts_file_test.go
@@ -15,9 +15,12 @@ var _ = Describe("HostsFileConfig", func() {
 
 	BeforeEach(func() {
 		cfg = HostsFileConfig{
-			Filepath:       "/dev/null",
+			Sources: append(
+				NewBytesSources("/a/file/path"),
+				TextBytesSource("127.0.0.1 localhost"),
+			),
 			HostsTTL:       Duration(29 * time.Minute),
-			RefreshPeriod:  Duration(30 * time.Minute),
+			Loading:        SourceLoadingConfig{RefreshPeriod: Duration(30 * time.Minute)},
 			FilterLoopback: true,
 		}
 	})
@@ -50,7 +53,28 @@
 			cfg.LogConfig(logger)
 
 			Expect(hook.Calls).ShouldNot(BeEmpty())
-			Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null")))
+			Expect(hook.Messages).Should(ContainElement(ContainSubstring("- file:///a/file/path")))
+			Expect(hook.Messages).Should(ContainElement(ContainSubstring("- 127.0.0.1 lo...")))
 		})
 	})
+
+	Describe("migrate", func() {
+		It("should migrate deprecated options", func() {
+			cfg, err := WithDefaults[HostsFileConfig]()
+			Expect(err).Should(Succeed())
+
+			cfg.Deprecated.Filepath = ptrOf(newBytesSource("/a/file/path"))
+			cfg.Deprecated.RefreshPeriod = ptrOf(Duration(time.Hour))
+
+			migrated := cfg.migrate(logger)
+			Expect(migrated).Should(BeTrue())
+
+			Expect(hook.Calls).ShouldNot(BeEmpty())
+			Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.loading.refreshPeriod")))
+			Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.sources")))
+
+			Expect(cfg.Sources).Should(Equal([]BytesSource{*cfg.Deprecated.Filepath}))
+			Expect(cfg.Loading.RefreshPeriod).Should(Equal(*cfg.Deprecated.RefreshPeriod))
+		})
+	})
 })
diff --git a/docs/config.yml b/docs/config.yml
index 89bb6751..4d290d5a 100644
--- a/docs/config.yml
+++ b/docs/config.yml
@@ -1,3 +1,5 @@
+# REVIEW: manual changelog entry
+
 upstream:
   # these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query
   # format for resolver: [net:]host:[port][/path]. net could be empty (default, shortcut for tcp+udp), tcp+udp, tcp, udp, tcp-tls or https (DoH). If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls, 443 for https (Doh))
@@ -99,22 +101,33 @@ blocking:
   # optional: TTL for answers to blocked domains
   # default: 6h
   blockTTL: 1m
-  # optional: automatically list refresh period (in duration format). Default: 4h.
-  # Negative value -> deactivate automatically refresh.
-  # 0 value -> use default
-  refreshPeriod: 4h
-  # optional: timeout for list download (each url). Default: 60s. Use large values for big lists or slow internet connections
-  downloadTimeout: 4m
-  # optional: Download attempt timeout. Default: 60s
-  downloadAttempts: 5
-  # optional: Time between the download attempts. Default: 1s
-  downloadCooldown: 10s
-  # optional: if failOnError, application startup will fail if at least one list can't be downloaded / opened. Default: blocking
-  startStrategy: failOnError
-  # Number of errors allowed in a list before it is considered invalid.
-  # A value of -1 disables the limit.
-  # Default: 5
-  maxErrorsPerFile: 5
+  # optional: Configure how lists, AKA sources, are loaded
+  loading:
+    # optional: list refresh period in duration format.
+    # Set to a value <= 0 to disable.
+    # default: 4h
+    refreshPeriod: 24h
+    # optional: Applies only to lists that are downloaded (HTTP URLs).
+    downloads:
+      # optional: timeout for list download (each url). Use large values for big lists or slow internet connections
+      # default: 5s
+      timeout: 60s
+      # optional: Maximum download attempts
+      # default: 3
+      attempts: 5
+      # optional: Time between the download attempts
+      # default: 500ms
+      cooldown: 10s
+    # optional: Maximum number of lists to process in parallel.
+    # default: 4
+    concurrency: 16
+    # optional: if failOnError, application startup will fail if at least one list can't be downloaded/opened
+    # default: blocking
+    strategy: failOnError
+    # Number of errors allowed in a list before it is considered invalid.
+    # A value of -1 disables the limit.
+    # default: 5
+    maxErrorsPerSource: 5
 
 # optional: configuration for caching of DNS responses
 caching:
@@ -161,6 +174,7 @@ clientLookup:
   clients:
     laptop:
       - 192.168.178.29
+
 # optional: configuration for prometheus metrics endpoint
 prometheus:
   # enabled if true
@@ -214,6 +228,7 @@ redis:
 
 # optional: Minimal TLS version that the DoH and DoT server will use
 minTlsServeVersion: 1.3
+
 # if https port > 0: path to cert and key file for SSL encryption. if not set, self-signed certificate will be generated
 #certFile: server.crt
 #keyFile: server.key
@@ -238,14 +253,45 @@ fqdnOnly:
 
 # optional: if path defined, use this file for query resolution (A, AAAA and rDNS). Default: empty
 hostsFile:
-  # optional: Path to hosts file (e.g. 
/etc/hosts on Linux) - filePath: /etc/hosts + # optional: Hosts files to parse + sources: + - /etc/hosts + - https://example.com/hosts + - | + # inline hosts + 127.0.0.1 example.com # optional: TTL, default: 1h - hostsTTL: 60m - # optional: Time between hosts file refresh, default: 1h - refreshPeriod: 30m - # optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not, default: false + hostsTTL: 30m + # optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not + # default: false filterLoopback: true + # optional: Configure how sources are loaded + loading: + # optional: file refresh period in duration format. + # Set to a value <= 0 to disable. + # default: 4h + refreshPeriod: 24h + # optional: Applies only to files that are downloaded (HTTP URLs). + downloads: + # optional: timeout for file download (each url). Use large values for big files or slow internet connections + # default: 5s + timeout: 60s + # optional: Maximum download attempts + # default: 3 + attempts: 5 + # optional: Time between the download attempts + # default: 500ms + cooldown: 10s + # optional: Maximum number of files to process in parallel. + # default: 4 + concurrency: 16 + # optional: if failOnError, application startup will fail if at least one file can't be downloaded/opened + # default: blocking + strategy: failOnError + # Number of errors allowed in a file before it is considered invalid. + # A value of -1 disables the limit. + # default: 5 + maxErrorsPerSource: 5 # optional: ports configuration ports: @@ -272,4 +318,4 @@ log: # optional: add EDE error codes to dns response ede: # enabled if true, Default: false - enable: true \ No newline at end of file + enable: true diff --git a/docs/configuration.md b/docs/configuration.md index 37ab1682..fc33bc52 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -330,20 +330,24 @@ contains a map of client name and multiple IP addresses. ## Blocking and whitelisting -Blocky can download and use external lists with domains or IP addresses to block DNS query (e.g. advertisement, malware, +Blocky can use lists of domains and IPs to block (e.g. advertisement, malware, trackers, adult sites). You can group several list sources together and define the blocking behavior per client. -External blacklists must be either in the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file)) or just -a plain domain list (one domain per line). Blocky also supports regex as more powerful tool to define patterns to block. +Blocking uses the [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach. For each DNS query, the domain name from +the request, IP address from the response, and any CNAME records will be checked to determine whether to block the query or not. -Blocky uses [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach to block a DNS query. Domain name from -the request, IP address from the response, and the CNAME record will be checked against configured blacklists. - -To avoid over-blocking, you can define or use already existing whitelists. +To avoid over-blocking, you can use whitelists. ### Definition black and whitelists -Each black or whitelist can be either a path to the local file, a URL to download or inline list definition of a domains -in hosts format (YAML literal block scalar style). All Urls must be grouped to a group name. +Lists are defined in groups. This allows using different sets of lists for different clients. 
+
+Each list in a group is a "source" and can be downloaded, read from a file, or inlined in the config. See [Sources](#sources) for details and for configuring how sources are loaded and refreshed.
+
+The supported list formats are:
+
+1. the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file))
+2. one domain per line (plain domain list)
+3. one regex per line
 
 !!! example
 
     ```yaml
    blackLists:
      ads:
        - https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt
        - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts
        - |
-          # inline definition with YAML literal block scalar style
+          # inline definition using YAML literal block scalar style
+          # content is in plain domain list format
          someadsdomain.com
          anotheradsdomain.com
-          # this is a regex
+        - |
+          # inline definition with a regex
          /^banners?[_.-]/
      special:
        - https://raw.githubusercontent.com/StevenBlack/hosts/master/alternates/fakenews/hosts
    whiteLists:
      ads:
        - whitelist.txt
+        - /path/to/file.txt
        - |
          # inline definition with YAML literal block scalar style
          whitelistdomain.com
     ```
 
-    In this example you can see 2 groups: **ads** with 2 lists and **special** with one list. One local whitelist was defined for the **ads** group.
+    In this example you can see 2 groups: **ads** and **special**, the latter with a single list. The **ads** group includes 2 inline lists.
 
 !!! warning
 
    If the same group has black and whitelists, whitelists will be used to disable particular blacklist entries. If a
    group has **only** whitelist entries -> this means only domains from this list are allowed, all other domains will
-    be blocked
+    be blocked.
 
-!!! note
-    Please define also client group mapping, otherwise you black and whitelist definition will have no effect
+!!! warning
+    You must also define client group mapping, otherwise your black and whitelist definitions will have no effect.
 
 #### Regex support
 
-You can use regex to define patterns to block. A regex entry must start and end with the slash character (/). Some
+You can use regex to define patterns to block. A regex entry must start and end with the slash character (`/`). Some
 Examples:
 
 - `/baddomain/` will block `www.baddomain.com`, `baddomain.com`, but also `mybaddomain-sometext.com`
 
@@ -395,7 +402,7 @@ In this configuration section, you can define, which blocking group(s) should be
 Example: All clients should use the **ads** group, which blocks advertisement and kids devices should use the **adult**
 group, which blocky adult sites.
 
-Clients without a group assignment will use automatically the **default** group.
+Clients without an explicit group assignment will use the **default** group.
 
 You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address, client's full-qualified
 domain name or a client subnet as CIDR notation.
 
@@ -460,82 +467,9 @@ after receiving the custom value.
         blockTTL: 10s
     ```
 
-### List refresh period
+### Lists Loading
 
-To keep the list cache up-to-date, blocky will periodically download and reload all external lists. Default period is **
-4 hours**. You can configure this by setting the `blocking.refreshPeriod` parameter to a value in **duration format**.
-Negative value will deactivate automatically refresh.
-
-!!! example
-
-    ```yaml
-    blocking:
-        refreshPeriod: 60m
-    ```
-
-Refresh every hour.
-
-### Download
-
-You can configure the list download attempts according to your internet connection:
-
-| Parameter        | Type            | Mandatory | Default value | Description                                     |
-|------------------|-----------------|-----------|---------------|-------------------------------------------------|
-| downloadTimeout  | duration format | no        | 60s           | Download attempt timeout                        |
-| downloadAttempts | int             | no        | 3             | How many download attempts should be performed  |
-| downloadCooldown | duration format | no        | 1s            | Time between the download attempts              |
-
-!!! example
-
-    ```yaml
-    blocking:
-        downloadTimeout: 4m
-        downloadAttempts: 5
-        downloadCooldown: 10s
-    ```
-
-### Start strategy
-
-You can configure the blocking behavior during application start of blocky.
-If no strategy is selected blocking will be used.
-
-| startStrategy | Description                                                                                             |
-|---------------|---------------------------------------------------------------------------------------------------------|
-| blocking      | all blocking lists will be loaded before DNS resolution starts                                          |
-| failOnError   | like blocking but blocky will shut down if any download fails                                           |
-| fast          | DNS resolution starts immediately without blocking which will be enabled after list load is completed   |
-
-!!! example
-
-    ```yaml
-    blocking:
-        startStrategy: failOnError
-    ```
-
-### Max Errors per file
-
-Number of errors allowed in a list before it is considered invalid and parsing stops.
-A value of -1 disables the limit.
-
-!!! example
-
-    ```yaml
-    blocking:
-        maxErrorsPerFile: 10
-    ```
-
-### Concurrency
-
-Blocky downloads and processes links in a single group concurrently. With parameter `processingConcurrency` you can adjust
-how many links can be processed in the same time. Higher value can reduce the overall list refresh time, but more parallel
- download and processing jobs need more RAM. Please consider to reduce this value on systems with limited memory. Default value is 4.
-
-!!! example
-
-    ```yaml
-    blocking:
-        processingConcurrency: 10
-    ```
+See [Sources Loading](#sources-loading).
 
 ## Caching
 
@@ -716,7 +650,7 @@ Configuration parameters:
     ```yaml
     hostsFile:
       filePath: /etc/hosts
-      hostsTTL: 60m
+      hostsTTL: 1h
       refreshPeriod: 30m
     ```
 
@@ -745,3 +679,127 @@ for detailed information, how to create and configure SSL certificates.
 
 DoH url: `https://host:port/dns-query`
 
 --8<-- "docs/includes/abbreviations.md"
+
+## Sources
+
+Sources are a concept shared by the blocking and hosts file resolvers. They define where each resolver loads its files from.
+
+The supported source types are:
+
+- HTTP(S) URL (any source starting with `http`)
+- inline configuration (any source containing a newline)
+- local file path (any source not matching the above rules)
+
+!!! note
+
+    The format/content of the sources depends on the context: lists and hosts files have different, but overlapping, supported formats.
+
+!!! example
+
+    ```yaml
+    - https://example.com/a/source # blocky will download and parse the file
+    - /a/file/path # blocky will read the local file
+    - | # blocky will parse the content of this multi-line string
+      # inline configuration
+    ```
+
+### Sources Loading
+
+This section covers the `loading` configuration that both the blocking and hosts file resolvers support.
+These settings apply only to the resolver under which they are nested.
+
+!!! example
+
+    ```yaml
+    blocking:
+      loading:
+        # only applies to white/blacklists
+
+    hostsFile:
+      loading:
+        # only applies to hostsFile sources
+    ```
+
+#### Refresh / Reload
+
+To keep source contents up-to-date, blocky can periodically refresh and reparse them. Default period is **
+4 hours**. You can configure this by setting the `refreshPeriod` parameter to a value in **duration format**.
+A value of zero or less will disable this feature.
+
+!!! example
+
+    ```yaml
+    loading:
+      refreshPeriod: 1h
+    ```
+
+    Refresh every hour.
+
+#### Downloads
+
+Configures how HTTP(S) sources are downloaded:
+
+| Parameter | Type     | Mandatory | Default value | Description                                     |
+|-----------|----------|-----------|---------------|-------------------------------------------------|
+| timeout   | duration | no        | 5s            | Download attempt timeout                        |
+| attempts  | int      | no        | 3             | How many download attempts should be performed  |
+| cooldown  | duration | no        | 500ms         | Time between the download attempts              |
+
+!!! example
+
+    ```yaml
+    loading:
+      downloads:
+        timeout: 4m
+        attempts: 5
+        cooldown: 10s
+    ```
+
+#### Strategy
+
+This configures how blocky behaves during startup.
+The default strategy is blocking.
+
+| strategy    | Description                                                                                                                  |
+|-------------|------------------------------------------------------------------------------------------------------------------------------|
+| blocking    | all sources are loaded before DNS resolution starts                                                                          |
+| failOnError | like blocking, but blocky will shut down if any source fails to load                                                         |
+| fast        | blocky starts serving DNS immediately and sources are loaded asynchronously; features that require the sources are enabled once loading completes |
+
+!!! example
+
+    ```yaml
+    loading:
+      strategy: failOnError
+    ```
+
+#### Max Errors per Source
+
+Number of errors allowed when parsing a source before it is considered invalid and parsing stops.
+A value of -1 disables the limit.
+
+!!! example
+
+    ```yaml
+    loading:
+      maxErrorsPerSource: 10
+    ```
+
+#### Concurrency
+
+Blocky downloads and processes sources concurrently. This allows limiting how many can be processed at the same time.
+Larger values can reduce the overall refresh time at the cost of using more RAM. Please consider reducing this value on systems with limited memory.
+Default value is 4.
+
+!!! example
+
+    ```yaml
+    loading:
+      concurrency: 10
+    ```
+
+!!! note
+
+    As with other settings under `loading`, the limit applies to the blocking and hosts file resolvers separately.
+    The total number of sources processed concurrently can reach the sum of both values.
+    For example, if blocking's limit is 8 and hosts file's is 4, there could be up to 12 concurrent jobs.
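For reference: the `loading` options documented above are backed by `config.SourceLoadingConfig` from the `config/config.go` hunks earlier in this diff. Below is a minimal sketch, not part of the patch, of how a component can consume it, modeled on the `NewListCache` change further down; the function name, the `refresh` callback, and the log message are illustrative assumptions:

```go
package example

import (
	"context"

	"github.com/0xERR0R/blocky/config"
	"github.com/sirupsen/logrus"
)

// startLoading sketches the SourceLoadingConfig pattern used by NewListCache:
// the strategy decides whether the initial refresh blocks startup (blocking),
// aborts it on failure (failOnError), or runs in the background (fast), and a
// periodic refresh goroutine is started afterwards when refreshPeriod > 0.
func startLoading(cfg config.SourceLoadingConfig, refresh func(context.Context) error) error {
	return cfg.StartPeriodicRefresh(refresh, func(err error) {
		logrus.WithError(err).Error("could not refresh sources") // illustrative message
	})
}
```

Note that with `strategy: fast` the returned error is always `nil` and failures surface only through the callback, while `failOnError` propagates the initial load error so startup can abort.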
diff --git a/e2e/blocking_test.go b/e2e/blocking_test.go index 9b3a86b1..087aa515 100644 --- a/e2e/blocking_test.go +++ b/e2e/blocking_test.go @@ -19,7 +19,7 @@ var _ = Describe("External lists and query blocking", func() { }) Describe("List download on startup", func() { When("external blacklist ist not available", func() { - Context("startStrategy = blocking", func() { + Context("loading.strategy = blocking", func() { BeforeEach(func() { blocky, err = createBlockyContainer(tmpDir, "log:", @@ -28,7 +28,8 @@ var _ = Describe("External lists and query blocking", func() { " default:", " - moka", "blocking:", - " startStrategy: blocking", + " loading:", + " strategy: blocking", " blackLists:", " ads:", " - http://wrong.domain.url/list.txt", @@ -54,7 +55,7 @@ var _ = Describe("External lists and query blocking", func() { Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: "))) }) }) - Context("startStrategy = failOnError", func() { + Context("loading.strategy = failOnError", func() { BeforeEach(func() { blocky, err = createBlockyContainer(tmpDir, "log:", @@ -63,7 +64,8 @@ var _ = Describe("External lists and query blocking", func() { " default:", " - moka", "blocking:", - " startStrategy: failOnError", + " loading:", + " strategy: failOnError", " blackLists:", " ads:", " - http://wrong.domain.url/list.txt", diff --git a/evt/events.go b/evt/events.go index 7d67f453..9fcc9136 100644 --- a/evt/events.go +++ b/evt/events.go @@ -29,7 +29,7 @@ const ( // CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count CachingDomainsToPrefetchCountChanged = "caching:domainsToPrefetchCountChanged" - // CachingFailedDownloadChanged fires, if a download of a blocking list fails + // CachingFailedDownloadChanged fires, if a download of a blocking list or hosts file fails CachingFailedDownloadChanged = "caching:failedDownload" // ApplicationStarted fires on start of the application. 
Parameter: version number, build time diff --git a/go.mod b/go.mod index 7db51cb0..aec27928 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( require ( github.com/DATA-DOG/go-sqlmock v1.5.0 + github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 github.com/docker/go-connections v0.4.0 github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198 github.com/testcontainers/testcontainers-go v0.21.0 @@ -56,7 +57,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/google/pprof v0.0.0-20230309165930-d61513b1440d // indirect github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/klauspost/compress v1.11.13 // indirect github.com/magiconair/properties v1.8.7 // indirect diff --git a/go.sum b/go.sum index f83b8bbd..80e94c2c 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/hcsshim v0.9.7 h1:mKNHW/Xvv1aFH87Jb6ERDzXTJTLPlmzfZ28VBFD/bfg= +github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI= +github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo= github.com/abice/go-enum v0.5.6 h1:Ury51IQXUppbIl56MqRU/++A8SSeLG4plePphPjxW1s= github.com/abice/go-enum v0.5.6/go.mod h1:X2GpCT8VkCXLkVm48hebWx3cVgFJ8zM5nY5iUrJZO1Q= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= @@ -108,8 +110,8 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20230309165930-d61513b1440d h1:um9/pc7tKMINFfP1eE7Wv6PRGXlcCSJkVajF7KJw3uQ= +github.com/google/pprof v0.0.0-20230309165930-d61513b1440d/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -124,7 +126,6 @@ github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+l github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM= 
github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= @@ -319,7 +320,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/lists/downloader.go b/lists/downloader.go index c86bf325..1536f79a 100644 --- a/lists/downloader.go +++ b/lists/downloader.go @@ -6,18 +6,12 @@ import ( "io" "net" "net/http" - "time" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/evt" "github.com/avast/retry-go/v4" ) -const ( - defaultDownloadTimeout = time.Second - defaultDownloadAttempts = uint(1) - defaultDownloadCooldown = 500 * time.Millisecond -) - // TransientError represents a temporary error like timeout, network errors... type TransientError struct { inner error @@ -36,74 +30,35 @@ type FileDownloader interface { DownloadFile(link string) (io.ReadCloser, error) } -// HTTPDownloader downloads files via HTTP protocol -type HTTPDownloader struct { - downloadTimeout time.Duration - downloadAttempts uint - downloadCooldown time.Duration - httpTransport *http.Transport +// httpDownloader downloads files via HTTP protocol +type httpDownloader struct { + cfg config.DownloaderConfig + + client http.Client } -type DownloaderOption func(c *HTTPDownloader) - -func NewDownloader(options ...DownloaderOption) *HTTPDownloader { - d := &HTTPDownloader{ - downloadTimeout: defaultDownloadTimeout, - downloadAttempts: defaultDownloadAttempts, - downloadCooldown: defaultDownloadCooldown, - httpTransport: &http.Transport{}, - } - - for _, opt := range options { - opt(d) - } - - return d +func NewDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) FileDownloader { + return newDownloader(cfg, transport) } -// WithTimeout sets the download timeout -func WithTimeout(timeout time.Duration) DownloaderOption { - return func(d *HTTPDownloader) { - d.downloadTimeout = timeout +func newDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) *httpDownloader { + return &httpDownloader{ + cfg: cfg, + + client: http.Client{ + Transport: transport, + Timeout: cfg.Timeout.ToDuration(), + }, } } -// WithTimeout sets the pause between 2 download attempts -func WithCooldown(cooldown time.Duration) DownloaderOption { - return func(d *HTTPDownloader) { - d.downloadCooldown = cooldown - } -} - -// WithTimeout sets the attempt number for retry -func WithAttempts(downloadAttempts uint) DownloaderOption { - return func(d *HTTPDownloader) { - d.downloadAttempts = downloadAttempts - } -} - -// WithTimeout sets the HTTP transport -func WithTransport(httpTransport *http.Transport) DownloaderOption { - return func(d *HTTPDownloader) { - d.httpTransport = httpTransport - } -} - -func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) { - client := http.Client{ - Timeout: 
d.downloadTimeout, - Transport: d.httpTransport, - } - - logger().WithField("link", link).Info("starting download") - +func (d *httpDownloader) DownloadFile(link string) (io.ReadCloser, error) { var body io.ReadCloser err := retry.Do( func() error { - var resp *http.Response - var httpErr error - if resp, httpErr = client.Get(link); httpErr == nil { + resp, httpErr := d.client.Get(link) + if httpErr == nil { if resp.StatusCode == http.StatusOK { body = resp.Body @@ -121,17 +76,18 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) { return httpErr }, - retry.Attempts(d.downloadAttempts), + retry.Attempts(d.cfg.Attempts), retry.DelayType(retry.FixedDelay), - retry.Delay(d.downloadCooldown), + retry.Delay(d.cfg.Cooldown.ToDuration()), retry.LastErrorOnly(true), retry.OnRetry(func(n uint, err error) { var transientErr *TransientError var dnsErr *net.DNSError - logger := logger().WithField("link", link).WithField("attempt", - fmt.Sprintf("%d/%d", n+1, d.downloadAttempts)) + logger := logger(). + WithField("link", link). + WithField("attempt", fmt.Sprintf("%d/%d", n+1, d.cfg.Attempts)) switch { case errors.As(err, &transientErr): diff --git a/lists/downloader_test.go b/lists/downloader_test.go index 27106a29..5387c863 100644 --- a/lists/downloader_test.go +++ b/lists/downloader_test.go @@ -10,6 +10,7 @@ import ( "sync/atomic" "time" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/evt" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -20,11 +21,17 @@ import ( var _ = Describe("Downloader", func() { var ( - sut *HTTPDownloader + sutConfig config.DownloaderConfig + sut *httpDownloader failedDownloadCountEvtChannel chan string loggerHook *test.Hook ) BeforeEach(func() { + var err error + + sutConfig, err = config.WithDefaults[config.DownloaderConfig]() + Expect(err).Should(Succeed()) + failedDownloadCountEvtChannel = make(chan string, 5) // collect received events in the channel fn := func(url string) { @@ -40,33 +47,27 @@ var _ = Describe("Downloader", func() { DeferCleanup(loggerHook.Reset) }) - Describe("Construct downloader", func() { - When("No options are provided", func() { - BeforeEach(func() { - sut = NewDownloader() - }) - It("Should provide default valus", func() { - Expect(sut.downloadAttempts).Should(BeNumerically("==", defaultDownloadAttempts)) - Expect(sut.downloadTimeout).Should(BeNumerically("==", defaultDownloadTimeout)) - Expect(sut.downloadCooldown).Should(BeNumerically("==", defaultDownloadCooldown)) - }) - }) - When("Options are provided", func() { + JustBeforeEach(func() { + sut = newDownloader(sutConfig, nil) + }) + + Describe("NewDownloader", func() { + It("Should use provided parameters", func() { transport := &http.Transport{} - BeforeEach(func() { - sut = NewDownloader( - WithAttempts(5), - WithCooldown(2*time.Second), - WithTimeout(5*time.Second), - WithTransport(transport), - ) - }) - It("Should use provided parameters", func() { - Expect(sut.downloadAttempts).Should(BeNumerically("==", 5)) - Expect(sut.downloadTimeout).Should(BeNumerically("==", 5*time.Second)) - Expect(sut.downloadCooldown).Should(BeNumerically("==", 2*time.Second)) - Expect(sut.httpTransport).Should(BeIdenticalTo(transport)) - }) + + sut = NewDownloader( + config.DownloaderConfig{ + Attempts: 5, + Cooldown: config.Duration(2 * time.Second), + Timeout: config.Duration(5 * time.Second), + }, + transport, + ).(*httpDownloader) + + Expect(sut.cfg.Attempts).Should(BeNumerically("==", 5)) + 
Expect(sut.cfg.Timeout).Should(BeNumerically("==", 5*time.Second))
+ Expect(sut.cfg.Cooldown).Should(BeNumerically("==", 2*time.Second))
+ Expect(sut.client.Transport).Should(BeIdenticalTo(transport))
})
})
@@ -77,7 +78,7 @@ var _ = Describe("Downloader", func() {
server = TestServer("line.one\nline.two")
DeferCleanup(server.Close)
- sut = NewDownloader()
+ sut = newDownloader(sutConfig, nil)
})
It("Should return all lines from the file", func() {
reader, err := sut.DownloadFile(server.URL)
@@ -98,7 +99,7 @@ var _ = Describe("Downloader", func() {
}))
DeferCleanup(server.Close)
- sut = NewDownloader(WithAttempts(3))
+ sutConfig.Attempts = 3
})
It("Should return error", func() {
reader, err := sut.DownloadFile(server.URL)
@@ -112,7 +113,7 @@ var _ = Describe("Downloader", func() {
})
When("Wrong URL is defined", func() {
BeforeEach(func() {
- sut = NewDownloader()
+ sutConfig.Attempts = 1
})
It("Should return error", func() {
_, err := sut.DownloadFile("somewrongurl")
@@ -129,10 +130,11 @@ var _ = Describe("Downloader", func() {
var attempt uint64 = 1
BeforeEach(func() {
- sut = NewDownloader(
- WithTimeout(20*time.Millisecond),
- WithAttempts(3),
- WithCooldown(time.Millisecond))
+ sutConfig = config.DownloaderConfig{
+ Timeout: config.Duration(20 * time.Millisecond),
+ Attempts: 3,
+ Cooldown: config.Duration(time.Millisecond),
+ }
// should produce a timeout on first attempt
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@@ -166,24 +168,23 @@ var _ = Describe("Downloader", func() {
})
When("If timeout occurs on all request", func() {
BeforeEach(func() {
- sut = NewDownloader(
- WithTimeout(100*time.Millisecond),
- WithAttempts(3),
- WithCooldown(time.Millisecond))
+ sutConfig = config.DownloaderConfig{
+ Timeout: config.Duration(10 * time.Millisecond),
+ Attempts: 3,
+ Cooldown: config.Duration(time.Millisecond),
+ }
// should always produce a timeout
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
- time.Sleep(200 * time.Millisecond)
+ time.Sleep(20 * time.Millisecond)
}))
DeferCleanup(server.Close)
})
It("Should perform a retry until max retry attempt count is reached and return TransientError", func() {
reader, err := sut.DownloadFile(server.URL)
Expect(err).Should(HaveOccurred())
-
- err2 := unwrapTransientErr(err)
-
- Expect(err2.Error()).Should(ContainSubstring("Timeout"))
+ Expect(errors.As(err, new(*TransientError))).Should(BeTrue())
+ Expect(err.Error()).Should(ContainSubstring("Timeout"))
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
@@ -193,19 +194,18 @@ var _ = Describe("Downloader", func() {
})
When("DNS resolution of passed URL fails", func() {
BeforeEach(func() {
- sut = NewDownloader(
- WithTimeout(500*time.Millisecond),
- WithAttempts(3),
- WithCooldown(200*time.Millisecond))
+ sutConfig = config.DownloaderConfig{
+ Timeout: config.Duration(500 * time.Millisecond),
+ Attempts: 3,
+ Cooldown: config.Duration(200 * time.Millisecond),
+ }
})
It("Should perform a retry until max retry attempt count is reached and return DNSError", func() {
reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
Expect(err).Should(HaveOccurred())
- err2 := unwrapTransientErr(err)
-
var dnsError *net.DNSError
- Expect(errors.As(err2, &dnsError)).To(BeTrue(), "received error %w", err)
+ Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %v", err)
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
@@ -216,12
+216,3 @@ var _ = Describe("Downloader", func() { }) }) }) - -func unwrapTransientErr(origErr error) error { - var transientErr *TransientError - if errors.As(origErr, &transientErr) { - return transientErr.Unwrap() - } - - return origErr -} diff --git a/lists/list_cache.go b/lists/list_cache.go index 8e17ae76..ba17c4ca 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -5,25 +5,20 @@ import ( "context" "errors" "fmt" - "io" "net" - "os" - "strings" - "time" - "github.com/hashicorp/go-multierror" "github.com/sirupsen/logrus" "github.com/0xERR0R/blocky/cache/stringcache" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/evt" "github.com/0xERR0R/blocky/lists/parsers" "github.com/0xERR0R/blocky/log" + "github.com/ThinkChaos/parcour" + "github.com/ThinkChaos/parcour/jobgroup" ) -const ( - defaultProcessingConcurrency = 4 - chanCap = 1000 -) +const groupProducersBufferCap = 1000 // ListCacheType represents the type of cached list ENUM( // blacklist // is a list with blocked domains @@ -41,19 +36,17 @@ type Matcher interface { type ListCache struct { groupedCache stringcache.GroupedStringCache - groupToLinks map[string][]string - refreshPeriod time.Duration - downloader FileDownloader - listType ListCacheType - processingConcurrency uint - maxErrorsPerFile int + cfg config.SourceLoadingConfig + listType ListCacheType + groupSources map[string][]config.BytesSource + downloader FileDownloader } // LogConfig implements `config.Configurable`. func (b *ListCache) LogConfig(logger *logrus.Entry) { var total int - for group := range b.groupToLinks { + for group := range b.groupSources { count := b.groupedCache.ElementCount(group) logger.Infof("%s: %d entries", group, count) total += count @@ -63,132 +56,36 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) { } // NewListCache creates new list instance -func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration, - downloader FileDownloader, processingConcurrency uint, async bool, maxErrorsPerFile int, +func NewListCache( + t ListCacheType, cfg config.SourceLoadingConfig, + groupSources map[string][]config.BytesSource, downloader FileDownloader, ) (*ListCache, error) { - if processingConcurrency == 0 { - processingConcurrency = defaultProcessingConcurrency - } - - b := &ListCache{ + c := &ListCache{ groupedCache: stringcache.NewChainedGroupedCache( stringcache.NewInMemoryGroupedStringCache(), stringcache.NewInMemoryGroupedRegexCache(), ), - groupToLinks: groupToLinks, - refreshPeriod: refreshPeriod, - downloader: downloader, - listType: t, - processingConcurrency: processingConcurrency, - maxErrorsPerFile: maxErrorsPerFile, + + cfg: cfg, + listType: t, + groupSources: groupSources, + downloader: downloader, } - var initError error - if async { - initError = nil - - // start list refresh in the background - go b.Refresh() - } else { - initError = b.refresh(true) + err := cfg.StartPeriodicRefresh(c.refresh, func(err error) { + logger().WithError(err).Errorf("could not init %s", t) + }) + if err != nil { + return nil, err } - if initError == nil { - go periodicUpdate(b) - } - - return b, initError -} - -// periodicUpdate triggers periodical refresh (and download) of list entries -func periodicUpdate(cache *ListCache) { - if cache.refreshPeriod > 0 { - ticker := time.NewTicker(cache.refreshPeriod) - defer ticker.Stop() - - for { - <-ticker.C - cache.Refresh() - } - } + return c, nil } func logger() *logrus.Entry { return log.PrefixedLog("list_cache") } -// downloads and reads files with 
domain names and creates cache for them -// -//nolint:funlen // will refactor in a later commit -func (b *ListCache) createCacheForGroup(group string, links []string) (created bool, err error) { - groupFactory := b.groupedCache.Refresh(group) - - fileLinesChan := make(chan string, chanCap) - errChan := make(chan error, chanCap) - - workerDoneChan := make(chan bool, len(links)) - - // guard channel is used to limit the number of concurrent executions of the function - guard := make(chan struct{}, b.processingConcurrency) - - processingLinkJobs := len(links) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // loop over links (http/local) or inline definitions - // start a new goroutine for each link, but limit to max. number (see processingConcurrency) - for idx, link := range links { - go func(idx int, link string) { - // try to write in this channel -> this will block if max amount of goroutines are being executed - guard <- struct{}{} - - defer func() { - // remove from guard channel to allow other blocked goroutines to continue - <-guard - workerDoneChan <- true - }() - - name := linkName(idx, link) - - err := b.parseFile(ctx, name, link, fileLinesChan) - if err != nil { - errChan <- err - } - }(idx, link) - } - -Loop: - for { - select { - case line := <-fileLinesChan: - groupFactory.AddEntry(line) - case e := <-errChan: - var transientErr *TransientError - - if errors.As(e, &transientErr) { - return false, e - } - err = multierror.Append(err, e) - case <-workerDoneChan: - processingLinkJobs-- - - default: - if processingLinkJobs == 0 { - break Loop - } - } - } - - if groupFactory.Count() == 0 && err != nil { - return false, err - } - - groupFactory.Finish() - - return true, err -} - // Match matches passed domain name against cached list entries func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []string) { return b.groupedCache.Contains(domain, groupsToCheck) @@ -196,65 +93,123 @@ func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []strin // Refresh triggers the refresh of a list func (b *ListCache) Refresh() { - _ = b.refresh(false) + _ = b.refresh(context.Background()) } -func (b *ListCache) refresh(isInit bool) error { - var err error +func (b *ListCache) refresh(ctx context.Context) error { + unlimitedGrp, _ := jobgroup.WithContext(ctx) + defer unlimitedGrp.Close() - for group, links := range b.groupToLinks { - created, e := b.createCacheForGroup(group, links) - if e != nil { - err = multierror.Append(err, multierror.Prefix(e, fmt.Sprintf("can't create cache group '%s':", group))) - } + producersGrp := jobgroup.WithMaxConcurrency(unlimitedGrp, b.cfg.Concurrency) + defer producersGrp.Close() - count := b.groupedCache.ElementCount(group) + for group, sources := range b.groupSources { + group, sources := group, sources - if !created { - logger := logger().WithFields(logrus.Fields{ - "group": group, - "total_count": count, - }) + unlimitedGrp.Go(func(ctx context.Context) error { + err := b.createCacheForGroup(producersGrp, unlimitedGrp, group, sources) + if err != nil { + count := b.groupedCache.ElementCount(group) - if count == 0 || isInit { - logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds") - } else { - logger.Warn("Populating of group cache failed, using existing cache, if any") + logger := logger().WithFields(logrus.Fields{ + "group": group, + "total_count": count, + }) + + if count == 0 { + logger.Warn("Populating of group cache failed, cache will be empty until 
refresh succeeds") + } else { + logger.Warn("Populating of group cache failed, using existing cache, if any") + } + + return err } - continue - } + count := b.groupedCache.ElementCount(group) - evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count) + evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count) - logger().WithFields(logrus.Fields{ - "group": group, - "total_count": count, - }).Info("group import finished") + logger().WithFields(logrus.Fields{ + "group": group, + "total_count": count, + }).Info("group import finished") + + return nil + }) } - return err + return unlimitedGrp.Wait() } -func readFile(file string) (io.ReadCloser, error) { - logger().WithField("file", file).Info("starting processing of file") - file = strings.TrimPrefix(file, "file://") +func (b *ListCache) createCacheForGroup( + producersGrp, consumersGrp jobgroup.JobGroup, group string, sources []config.BytesSource, +) error { + groupFactory := b.groupedCache.Refresh(group) - return os.Open(file) + producers := parcour.NewProducersWithBuffer[string](producersGrp, consumersGrp, groupProducersBufferCap) + defer producers.Close() + + for i, source := range sources { + i, source := i, source + + producers.GoProduce(func(ctx context.Context, hostsChan chan<- string) error { + locInfo := fmt.Sprintf("item #%d of group %s", i, group) + + opener, err := NewSourceOpener(locInfo, source, b.downloader) + if err != nil { + return err + } + + return b.parseFile(ctx, opener, hostsChan) + }) + } + + hasEntries := false + + producers.GoConsume(func(ctx context.Context, ch <-chan string) error { + for host := range ch { + hasEntries = true + + groupFactory.AddEntry(host) + } + + return nil + }) + + err := producers.Wait() + if err != nil { + if !hasEntries { + // Always fail the group if no entries were parsed + return err + } + + var transientErr *TransientError + + if errors.As(err, &transientErr) { + // Temporary error: fail the whole group to retry later + return err + } + } + + groupFactory.Finish() + + return nil } // downloads file (or reads local file) and writes each line in the file to the result channel -func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh chan<- string) error { +func (b *ListCache) parseFile(ctx context.Context, opener SourceOpener, resultCh chan<- string) error { count := 0 logger := func() *logrus.Entry { return logger().WithFields(logrus.Fields{ - "source": name, + "source": opener.String(), "count": count, }) } - r, err := b.newLinkReader(link) + logger().Debug("starting processing of source") + + r, err := opener.Open() if err != nil { logger().Error("cannot open source: ", err) @@ -262,7 +217,7 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c } defer r.Close() - p := parsers.AllowErrors(parsers.Hosts(r), b.maxErrorsPerFile) + p := parsers.AllowErrors(parsers.Hosts(r), b.cfg.MaxErrorsPerSource) p.OnErr(func(err error) { logger().Warnf("parse error: %s, trying to continue", err) }) @@ -303,27 +258,3 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c return nil } - -func linkName(linkIdx int, link string) string { - if strings.ContainsAny(link, "\n") { - return fmt.Sprintf("inline block (item #%d in group)", linkIdx) - } - - return link -} - -func (b *ListCache) newLinkReader(link string) (r io.ReadCloser, err error) { - switch { - // link contains a line break -> this is inline list definition in YAML (with literal style Block Scalar) - case strings.ContainsAny(link, 
"\n"): - r = io.NopCloser(strings.NewReader(link)) - // link is http(s) -> download it - case strings.HasPrefix(link, "http"): - r, err = b.downloader.DownloadFile(link) - // probably path to a local file - default: - r, err = readFile(link) - } - - return -} diff --git a/lists/list_cache_benchmark_test.go b/lists/list_cache_benchmark_test.go index 565da927..fedf2fd8 100644 --- a/lists/list_cache_benchmark_test.go +++ b/lists/list_cache_benchmark_test.go @@ -2,17 +2,24 @@ package lists import ( "testing" + + "github.com/0xERR0R/blocky/config" ) func BenchmarkRefresh(b *testing.B) { file1, _ := createTestListFile(b.TempDir(), 100000) file2, _ := createTestListFile(b.TempDir(), 150000) file3, _ := createTestListFile(b.TempDir(), 130000) - lists := map[string][]string{ - "gr1": {file1, file2, file3}, + lists := map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(file1, file2, file3), } - cache, _ := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(), 5, false, 5) + cfg := config.SourceLoadingConfig{ + Concurrency: 5, + RefreshPeriod: config.Duration(-1), + } + downloader := NewDownloader(config.DownloaderConfig{}, nil) + cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader) b.ReportAllocs() diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go index 73ef3953..2d08cfb1 100644 --- a/lists/list_cache_test.go +++ b/lists/list_cache_test.go @@ -2,6 +2,7 @@ package lists import ( "bufio" + "context" "errors" "fmt" "io" @@ -9,8 +10,8 @@ import ( "net/http/httptest" "os" "strings" - "time" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/evt" "github.com/0xERR0R/blocky/lists/parsers" "github.com/0xERR0R/blocky/log" @@ -27,13 +28,28 @@ var _ = Describe("ListCache", func() { tmpDir *TmpFolder emptyFile, file1, file2, file3 *TmpFile server1, server2, server3 *httptest.Server - maxErrorsPerFile int + + sut *ListCache + sutConfig config.SourceLoadingConfig + + listCacheType ListCacheType + lists map[string][]config.BytesSource + downloader FileDownloader + mockDownloader *MockDownloader ) + BeforeEach(func() { - maxErrorsPerFile = 5 - tmpDir = NewTmpFolder("ListCache") - Expect(tmpDir.Error).Should(Succeed()) - DeferCleanup(tmpDir.Clean) + var err error + + listCacheType = ListCacheTypeBlacklist + + sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]() + Expect(err).Should(Succeed()) + + sutConfig.RefreshPeriod = -1 + + downloader = NewDownloader(config.DownloaderConfig{}, nil) + mockDownloader = nil server1 = TestServer("blocked1.com\nblocked1a.com\n192.168.178.55") DeferCleanup(server1.Close) @@ -42,6 +58,13 @@ var _ = Describe("ListCache", func() { server3 = TestServer("blocked3.com\nblocked1a.com") DeferCleanup(server3.Close) + tmpDir = NewTmpFolder("ListCache") + Expect(tmpDir.Error).Should(Succeed()) + DeferCleanup(tmpDir.Clean) + + emptyFile = tmpDir.CreateStringFile("empty", "#empty file") + Expect(emptyFile.Error).Should(Succeed()) + emptyFile = tmpDir.CreateStringFile("empty", "#empty file") Expect(emptyFile.Error).Should(Succeed()) file1 = tmpDir.CreateStringFile("file1", "blocked1.com", "blocked1a.com") @@ -52,61 +75,56 @@ var _ = Describe("ListCache", func() { Expect(file3.Error).Should(Succeed()) }) + JustBeforeEach(func() { + var err error + + Expect(lists).ShouldNot(BeNil(), "bad test: forgot to set `lists`") + + if mockDownloader != nil { + downloader = mockDownloader + } + + sut, err = NewListCache(listCacheType, sutConfig, lists, downloader) + Expect(err).Should(Succeed()) + }) + Describe("List cache 
and matching", func() { - When("Query with empty", func() { - It("should not panic", func() { - lists := map[string][]string{ - "gr0": {emptyFile.Path}, - } - sut, err := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) - - group := sut.Match("", []string{"gr0"}) - Expect(group).Should(BeEmpty()) - }) - }) - When("List is empty", func() { - It("should not match anything", func() { - lists := map[string][]string{ - "gr1": {emptyFile.Path}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr0": config.NewBytesSources(emptyFile.Path), } - sut, err := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) + }) + When("Query with empty", func() { + It("should not panic", func() { + group := sut.Match("", []string{"gr0"}) + Expect(group).Should(BeEmpty()) + }) + }) + + It("should not match anything", func() { group := sut.Match("google.com", []string{"gr1"}) Expect(group).Should(BeEmpty()) }) }) When("List becomes empty on refresh", func() { - It("should delete existing elements from group cache", func() { - mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) { + BeforeEach(func() { + mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) { res <- "blocked1.com" res <- "# nothing" }) - lists := map[string][]string{ + lists = map[string][]config.BytesSource{ "gr1": {mockDownloader.ListSource()}, } + }) - sut, err := NewListCache( - ListCacheTypeBlacklist, lists, - 4*time.Hour, - mockDownloader, - defaultProcessingConcurrency, - false, - maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) - + It("should delete existing elements from group cache", func(ctx context.Context) { group := sut.Match("blocked1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) - err = sut.refresh(false) + err := sut.refresh(ctx) Expect(err).Should(Succeed()) group = sut.Match("blocked1.com", []string{"gr1"}) @@ -114,21 +132,19 @@ var _ = Describe("ListCache", func() { }) }) When("List has invalid lines", func() { - It("should still other domains", func() { - lists := map[string][]string{ + BeforeEach(func() { + lists = map[string][]config.BytesSource{ "gr1": { - inlineList( + config.TextBytesSource( "inlinedomain1.com", "invaliddomain!", "inlinedomain2.com", ), }, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should still other domains", func() { group := sut.Match("inlinedomain1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) @@ -137,28 +153,20 @@ var _ = Describe("ListCache", func() { }) }) When("a temporary/transient err occurs on download", func() { - It("should not delete existing elements from group cache", func() { + BeforeEach(func() { // should produce a transient error on second and third attempt - mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) { - res <- "blocked1.com" + mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) { + res <- "blocked1.com\nblocked2.com\n" err <- &TransientError{inner: errors.New("boom")} err <- &TransientError{inner: errors.New("boom")} }) - lists := map[string][]string{ + lists = map[string][]config.BytesSource{ "gr1": {mockDownloader.ListSource()}, } + }) - sut, err := NewListCache( - 
ListCacheTypeBlacklist, lists, - 4*time.Hour, - mockDownloader, - defaultProcessingConcurrency, - false, - maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) - + It("should not delete existing elements from group cache", func(ctx context.Context) { By("Lists loaded without timeout", func() { Eventually(func(g Gomega) { group := sut.Match("blocked1.com", []string{"gr1"}) @@ -166,7 +174,7 @@ var _ = Describe("ListCache", func() { }, "1s").Should(Succeed()) }) - Expect(sut.refresh(false)).Should(HaveOccurred()) + Expect(sut.refresh(ctx)).Should(HaveOccurred()) By("List couldn't be loaded due to timeout", func() { group := sut.Match("blocked1.com", []string{"gr1"}) @@ -182,27 +190,25 @@ var _ = Describe("ListCache", func() { }) }) When("non transient err occurs on download", func() { - It("should keep existing elements from group cache", func() { + BeforeEach(func() { // should produce a non transient error on second attempt - mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) { + mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) { res <- "blocked1.com" err <- errors.New("boom") }) - lists := map[string][]string{ + lists = map[string][]config.BytesSource{ "gr1": {mockDownloader.ListSource()}, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader, - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should keep existing elements from group cache", func(ctx context.Context) { By("Lists loaded without err", func() { group := sut.Match("blocked1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) }) - Expect(sut.refresh(false)).Should(HaveOccurred()) + Expect(sut.refresh(ctx)).Should(HaveOccurred()) By("Lists from first load is kept", func() { group := sut.Match("blocked1.com", []string{"gr1"}) @@ -211,16 +217,14 @@ var _ = Describe("ListCache", func() { }) }) When("Configuration has 3 external working urls", func() { - It("should download the list and match against", func() { - lists := map[string][]string{ - "gr1": {server1.URL, server2.URL}, - "gr2": {server3.URL}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(server1.URL, server2.URL), + "gr2": config.NewBytesSources(server3.URL), } + }) - sut, _ := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - + It("should download the list and match against", func() { group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) Expect(group).Should(ContainElement("gr1")) @@ -232,16 +236,14 @@ var _ = Describe("ListCache", func() { }) }) When("Configuration has some faulty urls", func() { - It("should download the list and match against", func() { - lists := map[string][]string{ - "gr1": {server1.URL, server2.URL, "doesnotexist"}, - "gr2": {server3.URL, "someotherfile"}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(server1.URL, server2.URL, "doesnotexist"), + "gr2": config.NewBytesSources(server3.URL, "someotherfile"), } + }) - sut, _ := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - + It("should download the list and match against", func() { group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) Expect(group).Should(ContainElement("gr1")) @@ -253,39 +255,33 @@ var _ = Describe("ListCache", func() { }) }) When("List will be updated", func() { - It("event 
should be fired and contain count of elements in downloaded lists", func() { - lists := map[string][]string{ - "gr1": {server1.URL}, - } + resultCnt := 0 - resultCnt := 0 + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(server1.URL), + } _ = Bus().SubscribeOnce(BlockingCacheGroupChanged, func(listType ListCacheType, group string, cnt int) { resultCnt = cnt }) + }) - sut, err := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) - + It("event should be fired and contain count of elements in downloaded lists", func() { group := sut.Match("blocked1.com", []string{}) Expect(group).Should(BeEmpty()) Expect(resultCnt).Should(Equal(3)) }) }) When("multiple groups are passed", func() { - It("should match", func() { - lists := map[string][]string{ - "gr1": {file1.Path, file2.Path}, - "gr2": {"file://" + file3.Path}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(file1.Path, file2.Path), + "gr2": config.NewBytesSources("file://" + file3.Path), } + }) - sut, err := NewListCache( - ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile, - ) - Expect(err).Should(Succeed()) - + It("should match", func() { Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(3)) Expect(sut.groupedCache.ElementCount("gr2")).Should(Equal(2)) @@ -304,31 +300,28 @@ var _ = Describe("ListCache", func() { file1, lines1 := createTestListFile(GinkgoT().TempDir(), 10000) file2, lines2 := createTestListFile(GinkgoT().TempDir(), 15000) file3, lines3 := createTestListFile(GinkgoT().TempDir(), 13000) - lists := map[string][]string{ - "gr1": {file1, file2, file3}, + lists := map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(file1, file2, file3), } - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) + sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3)) }) }) When("inline list content is defined", func() { - It("should match", func() { - lists := map[string][]string{ - "gr1": {inlineList( + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": {config.TextBytesSource( "inlinedomain1.com", "#some comment", "inlinedomain2.com", )}, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should match", func() { Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(2)) group := sut.Match("inlinedomain1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) @@ -338,65 +331,59 @@ var _ = Describe("ListCache", func() { }) }) When("Text file can't be parsed", func() { - It("should still match already imported strings", func() { - lists := map[string][]string{ + BeforeEach(func() { + lists = map[string][]config.BytesSource{ "gr1": { - inlineList( + config.TextBytesSource( "inlinedomain1.com", "lineTooLong"+strings.Repeat("x", bufio.MaxScanTokenSize), // too long ), }, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should still match already imported strings", func() 
{ group := sut.Match("inlinedomain1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) }) }) When("Text file has too many errors", func() { BeforeEach(func() { - maxErrorsPerFile = 0 + sutConfig.MaxErrorsPerSource = 0 + sutConfig.Strategy = config.StartStrategyTypeFailOnError }) It("should fail parsing", func() { - lists := map[string][]string{ + lists := map[string][]config.BytesSource{ "gr1": { - inlineList("invaliddomain!"), // too many errors since `maxErrorsPerFile` is 0 + config.TextBytesSource("invaliddomain!"), // too many errors since `maxErrorsPerSource` is 0 }, } - _, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) + _, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).ShouldNot(Succeed()) Expect(err).Should(MatchError(parsers.ErrTooManyErrors)) }) }) When("file has end of line comment", func() { - It("should still parse the domain", func() { - lists := map[string][]string{ - "gr1": {inlineList("inlinedomain1.com#a comment")}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": {config.TextBytesSource("inlinedomain1.com#a comment")}, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should still parse the domain", func() { group := sut.Match("inlinedomain1.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) }) }) When("inline regex content is defined", func() { - It("should match", func() { - lists := map[string][]string{ - "gr1": {inlineList("/^apple\\.(de|com)$/")}, + BeforeEach(func() { + lists = map[string][]config.BytesSource{ + "gr1": {config.TextBytesSource("/^apple\\.(de|com)$/")}, } + }) - sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) - Expect(err).Should(Succeed()) - + It("should match", func() { group := sut.Match("apple.com", []string{"gr1"}) Expect(group).Should(ContainElement("gr1")) @@ -416,13 +403,12 @@ var _ = Describe("ListCache", func() { }) It("should print list configuration", func() { - lists := map[string][]string{ - "gr1": {server1.URL, server2.URL}, - "gr2": {inlineList("inline", "definition")}, + lists := map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(server1.URL, server2.URL), + "gr2": {config.TextBytesSource("inline", "definition")}, } - sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(), - defaultProcessingConcurrency, false, maxErrorsPerFile) + sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) sut.LogConfig(logger) @@ -435,13 +421,16 @@ var _ = Describe("ListCache", func() { Describe("StartStrategy", func() { When("async load is enabled", func() { + BeforeEach(func() { + sutConfig.Strategy = config.StartStrategyTypeFast + }) + It("should never return an error", func() { - lists := map[string][]string{ - "gr1": {"doesnotexist"}, + lists := map[string][]config.BytesSource{ + "gr1": config.NewBytesSources("doesnotexist"), } - _, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(), - defaultProcessingConcurrency, true, maxErrorsPerFile) + _, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) }) }) @@ -465,8 +454,11 @@ func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) { 
return io.NopCloser(strings.NewReader(str)), nil } -func (m *MockDownloader) ListSource() string { - return "http://mock" +func (m *MockDownloader) ListSource() config.BytesSource { + return config.BytesSource{ + Type: config.BytesSourceTypeHttp, + From: "http://mock-downloader", + } } func createTestListFile(dir string, totalLines int) (string, int) { @@ -502,12 +494,3 @@ func RandStringBytes(n int) string { return string(b) } - -func inlineList(lines ...string) string { - res := strings.Join(lines, "\n") - - // ensure at least one line ending so it's parsed as an inline block - res += "\n" - - return res -} diff --git a/lists/sourcereader.go b/lists/sourcereader.go new file mode 100644 index 00000000..8e6541c8 --- /dev/null +++ b/lists/sourcereader.go @@ -0,0 +1,69 @@ +package lists + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/0xERR0R/blocky/config" +) + +type SourceOpener interface { + fmt.Stringer + + Open() (io.ReadCloser, error) +} + +func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) { + switch source.Type { + case config.BytesSourceTypeText: + return &textOpener{source: source, locInfo: txtLocInfo}, nil + + case config.BytesSourceTypeHttp: + return &httpOpener{source: source, downloader: downloader}, nil + + case config.BytesSourceTypeFile: + return &fileOpener{source: source}, nil + } + + return nil, fmt.Errorf("cannot open %s", source) +} + +type textOpener struct { + source config.BytesSource + locInfo string +} + +func (o *textOpener) Open() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(o.source.From)), nil +} + +func (o *textOpener) String() string { + return fmt.Sprintf("%s: %s", o.locInfo, o.source) +} + +type httpOpener struct { + source config.BytesSource + downloader FileDownloader +} + +func (o *httpOpener) Open() (io.ReadCloser, error) { + return o.downloader.DownloadFile(o.source.From) +} + +func (o *httpOpener) String() string { + return o.source.String() +} + +type fileOpener struct { + source config.BytesSource +} + +func (o *fileOpener) Open() (io.ReadCloser, error) { + return os.Open(o.source.From) +} + +func (o *fileOpener) String() string { + return o.source.String() +} diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 24285a7e..a501e0c1 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -101,18 +101,14 @@ func NewBlockingResolver( return nil, err } - refreshPeriod := cfg.RefreshPeriod.ToDuration() - downloader := createDownloader(cfg, bootstrap) - blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, - refreshPeriod, downloader, cfg.ProcessingConcurrency, - (cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile) - whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists, - refreshPeriod, downloader, cfg.ProcessingConcurrency, - (cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile) + downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()) + + blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader) + whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader) whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg) err = multierror.Append(err, blErr, wlErr).ErrorOrNil() - if err != nil && cfg.StartStrategy == config.StartStrategyTypeFailOnError { 
+ if err != nil { return nil, err } @@ -156,15 +152,6 @@ func NewBlockingResolver( return res, nil } -func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader { - return lists.NewDownloader( - lists.WithTimeout(cfg.DownloadTimeout.ToDuration()), - lists.WithAttempts(cfg.DownloadAttempts), - lists.WithCooldown(cfg.DownloadCooldown.ToDuration()), - lists.WithTransport(bootstrap.NewHTTPTransport()), - ) -} - func setupRedisEnabledSubscriber(c *BlockingResolver) { go func() { for em := range c.redisClient.EnabledChannel { diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index b6203b09..c80883ff 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -97,9 +97,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{ - "gr1": {group1File.Path}, - "gr2": {group2File.Path}, + BlackLists: map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(group1File.Path), + "gr2": config.NewBytesSources(group2File.Path), }, } }) @@ -125,9 +125,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{ - "gr1": {group1File.Path}, - "gr2": {group2File.Path}, + BlackLists: map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(group1File.Path), + "gr2": config.NewBytesSources(group2File.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, @@ -164,13 +164,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{ - "gr1": {"\n/regex/"}, + BlackLists: map[string][]config.BytesSource{ + "gr1": {config.TextBytesSource("/regex/")}, }, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, - StartStrategy: config.StartStrategyTypeFast, + Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFast}, } }) @@ -193,10 +193,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ BlockTTL: config.Duration(6 * time.Hour), - BlackLists: map[string][]string{ - "gr1": {group1File.Path}, - "gr2": {group2File.Path}, - "defaultGroup": {defaultGroupFile.Path}, + BlackLists: map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(group1File.Path), + "gr2": config.NewBytesSources(group2File.Path), + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), }, ClientGroupsBlock: map[string][]string{ "Client1": {"gr1"}, @@ -399,8 +399,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{ - "defaultGroup": {defaultGroupFile.Path}, + BlackLists: map[string][]config.BytesSource{ + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"defaultGroup"}, @@ -425,8 +425,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", - BlackLists: map[string][]string{ - "defaultGroup": {defaultGroupFile.Path}, + BlackLists: 
map[string][]config.BytesSource{ + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"defaultGroup"}, @@ -470,8 +470,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ BlockTTL: config.Duration(6 * time.Hour), - BlackLists: map[string][]string{ - "defaultGroup": {defaultGroupFile.Path}, + BlackLists: map[string][]config.BytesSource{ + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"defaultGroup"}, @@ -508,8 +508,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("BlockType is custom IP only for ipv4", func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ - BlackLists: map[string][]string{ - "defaultGroup": {defaultGroupFile.Path}, + BlackLists: map[string][]config.BytesSource{ + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"defaultGroup"}, @@ -601,8 +601,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{"gr1": {group1File.Path}}, - WhiteLists: map[string][]string{"gr1": {group1File.Path}}, + BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)}, + WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)}, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, @@ -627,9 +627,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "zeroIP", BlockTTL: config.Duration(60 * time.Second), - WhiteLists: map[string][]string{ - "gr1": {group1File.Path}, - "gr2": {group2File.Path}, + WhiteLists: map[string][]config.BytesSource{ + "gr1": config.NewBytesSources(group1File.Path), + "gr2": config.NewBytesSources(group2File.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, @@ -728,8 +728,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{"gr1": {group1File.Path}}, - WhiteLists: map[string][]string{"gr1": {defaultGroupFile.Path}}, + BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)}, + WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(defaultGroupFile.Path)}, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, @@ -755,7 +755,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), - BlackLists: map[string][]string{"gr1": {group1File.Path}}, + BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)}, ClientGroupsBlock: map[string][]string{ "default": {"gr1"}, }, @@ -798,9 +798,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Describe("Control status via API", func() { BeforeEach(func() { sutConfig = config.BlockingConfig{ - BlackLists: map[string][]string{ - "defaultGroup": {defaultGroupFile.Path}, - "group1": {group1File.Path}, + BlackLists: map[string][]config.BytesSource{ + "defaultGroup": config.NewBytesSources(defaultGroupFile.Path), + "group1": 
config.NewBytesSources(group1File.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"defaultGroup", "group1"}, @@ -1118,13 +1118,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { MatchError("unknown blockType 'wrong', please use one of: ZeroIP, NxDomain or specify destination IP address(es)")) }) }) - When("startStrategy is failOnError", func() { + When("strategy is failOnError", func() { It("should fail if lists can't be downloaded", func() { _, err := NewBlockingResolver(config.BlockingConfig{ - BlackLists: map[string][]string{"gr1": {"wrongPath"}}, - WhiteLists: map[string][]string{"whitelist": {"wrongPath"}}, - StartStrategy: config.StartStrategyTypeFailOnError, - BlockType: "zeroIp", + BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")}, + WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")}, + Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError}, + BlockType: "zeroIp", }, nil, systemResolverBootstrap) Expect(err).Should(HaveOccurred()) }) diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index f3518e41..3fc29151 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "net" - "os" - "time" "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/lists" "github.com/0xERR0R/blocky/lists/parsers" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" + "github.com/ThinkChaos/parcour" + "github.com/ThinkChaos/parcour/jobgroup" "github.com/miekg/dns" "github.com/sirupsen/logrus" ) @@ -18,33 +19,37 @@ import ( const ( // reduce initial capacity so we don't waste memory if there are less entries than before memReleaseFactor = 2 + + producersBuffCap = 1000 ) +type HostsFileEntry = parsers.HostsFileEntry + type HostsFileResolver struct { configurable[*config.HostsFileConfig] NextResolver typed - hosts splitHostsFileData + hosts splitHostsFileData + downloader lists.FileDownloader } -type HostsFileEntry = parsers.HostsFileEntry - -func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver { +func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) { r := HostsFileResolver{ configurable: withConfig(&cfg), typed: withType("hosts_file"), + + downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()), } - if err := r.parseHostsFile(context.Background()); err != nil { - r.log().Errorf("disabling hosts file resolving due to error: %s", err) - - r.cfg.Filepath = "" // don't try parsing the file again - } else { - go r.periodicUpdate() + err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) { + r.log().WithError(err).Errorf("could not load hosts files") + }) + if err != nil { + return nil, err } - return &r + return &r, nil } // LogConfig implements `config.Configurable`. 
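Review aid, not part of the patch: both NewListCache and NewHostsFileResolver now delegate their startup and refresh behavior to SourceLoadingConfig.StartPeriodicRefresh, whose implementation lives in config/ and is not included in this diff. The sketch below only captures the contract its two call sites appear to rely on; everything inside the function body is an assumption, not the real code.

	package sketch // hypothetical, for illustration only

	import (
		"context"
		"time"

		"github.com/0xERR0R/blocky/config"
	)

	// startPeriodicRefreshSketch is an assumed contract for
	// SourceLoadingConfig.StartPeriodicRefresh, inferred from its call sites.
	func startPeriodicRefreshSketch(
		cfg config.SourceLoadingConfig,
		refresh func(context.Context) error,
		logErr func(error),
	) error {
		ctx := context.Background()

		switch cfg.Strategy {
		case config.StartStrategyTypeFast:
			// "fast": load in the background; constructors never fail
			// (see the "async load is enabled" test above).
			go func() {
				if err := refresh(ctx); err != nil {
					logErr(err)
				}
			}()
		case config.StartStrategyTypeFailOnError:
			// "failOnError": a failed initial load fails the constructor
			// (see the "strategy is failOnError" test above).
			if err := refresh(ctx); err != nil {
				return err
			}
		default:
			// "blocking" (the default): load synchronously, log but tolerate errors
			// (see the "faulty urls" test, where construction still succeeds).
			if err := refresh(ctx); err != nil {
				logErr(err)
			}
		}

		if cfg.RefreshPeriod > 0 {
			go func() {
				ticker := time.NewTicker(cfg.RefreshPeriod.ToDuration())
				defer ticker.Stop()

				for range ticker.C {
					if err := refresh(ctx); err != nil {
						logErr(err)
					}
				}
			}()
		}

		return nil
	}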
@@ -102,7 +107,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp } func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) { - if r.cfg.Filepath == "" { + if !r.IsEnabled() { return r.next.Resolve(request) } @@ -144,27 +149,78 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain return response } -func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { - const maxErrorsPerFile = 5 - - if r.cfg.Filepath == "" { +func (r *HostsFileResolver) loadSources(ctx context.Context) error { + if !r.IsEnabled() { return nil } - f, err := os.Open(r.cfg.Filepath) - if err != nil { - return err + r.log().Debug("loading hosts files") + + //nolint:ineffassign,staticcheck // keep `ctx :=` so if we use ctx in the future, we use the correct one + consumersGrp, ctx := jobgroup.WithContext(ctx) + defer consumersGrp.Close() + + producersGrp := jobgroup.WithMaxConcurrency(consumersGrp, r.cfg.Loading.Concurrency) + defer producersGrp.Close() + + producers := parcour.NewProducersWithBuffer[*HostsFileEntry](producersGrp, consumersGrp, producersBuffCap) + defer producers.Close() + + for i, source := range r.cfg.Sources { + i, source := i, source + + producers.GoProduce(func(ctx context.Context, hostsChan chan<- *HostsFileEntry) error { + locInfo := fmt.Sprintf("item #%d", i) + + opener, err := lists.NewSourceOpener(locInfo, source, r.downloader) + if err != nil { + return err + } + + err = r.parseFile(ctx, opener, hostsChan) + if err != nil { + return fmt.Errorf("error parsing %s: %w", opener, err) // err is parsers.ErrTooManyErrors + } + + return nil + }) } - defer f.Close() newHosts := newSplitHostsDataWithSameCapacity(r.hosts) - p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile) - p.OnErr(func(err error) { - r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err) + producers.GoConsume(func(ctx context.Context, ch <-chan *HostsFileEntry) error { + for entry := range ch { + newHosts.add(entry) + } + + return nil }) - err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error { + err := producers.Wait() + if err != nil { + return err + } + + r.hosts = newHosts + + return nil +} + +func (r *HostsFileResolver) parseFile( + ctx context.Context, opener lists.SourceOpener, hostsChan chan<- *HostsFileEntry, +) error { + reader, err := opener.Open() + if err != nil { + return err + } + defer reader.Close() + + p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource) + p.OnErr(func(err error) { + r.log().Warnf("error parsing %s: %s, trying to continue", opener, err) + }) + + return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error { if len(entry.Interface) != 0 { // Ignore entries with a specific interface: we don't restrict what clients/interfaces we serve entries to, // so this avoids returning entries that can't be accessed by the client. 
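Review aid, not part of the patch: loadSources above and ListCache.createCacheForGroup earlier use the same parcour producer/consumer shape. Below it is reduced to its core, using only the jobgroup and parcour calls that appear in this patch; readInto and store are illustrative stand-ins for parsing and caching.

	package sketch // hypothetical, for illustration only

	import (
		"context"

		"github.com/ThinkChaos/parcour"
		"github.com/ThinkChaos/parcour/jobgroup"
	)

	// loadAll sketches the shared concurrency pattern: N bounded producers
	// feed one consumer through a buffered channel managed by parcour.
	func loadAll(
		ctx context.Context,
		sources []string,
		concurrency uint,
		readInto func(context.Context, string, chan<- string) error,
		store func(string),
	) error {
		// The consumer runs unbounded; producers are capped at `concurrency`.
		consumersGrp, _ := jobgroup.WithContext(ctx)
		defer consumersGrp.Close()

		producersGrp := jobgroup.WithMaxConcurrency(consumersGrp, concurrency)
		defer producersGrp.Close()

		producers := parcour.NewProducersWithBuffer[string](producersGrp, consumersGrp, 1000)
		defer producers.Close()

		for _, source := range sources {
			source := source // capture a per-iteration copy

			producers.GoProduce(func(ctx context.Context, ch chan<- string) error {
				return readInto(ctx, source, ch)
			})
		}

		producers.GoConsume(func(ctx context.Context, ch <-chan string) error {
			for item := range ch {
				store(item)
			}

			return nil
		})

		// Wait surfaces the first error from any producer or the consumer.
		return producers.Wait()
	}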
@@ -176,32 +232,10 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error { return nil } - newHosts.add(entry) + hostsChan <- entry return nil }) - if err != nil { - return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors - } - - r.hosts = newHosts - - return nil -} - -func (r *HostsFileResolver) periodicUpdate() { - if r.cfg.RefreshPeriod.ToDuration() > 0 { - ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration()) - defer ticker.Stop() - - for { - <-ticker.C - - r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file") - - util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background())) - } - } } // stores hosts file data for IP versions separately diff --git a/resolver/hosts_file_resolver_test.go b/resolver/hosts_file_resolver_test.go index 2851ccdd..eef4591d 100644 --- a/resolver/hosts_file_resolver_test.go +++ b/resolver/hosts_file_resolver_test.go @@ -40,15 +40,22 @@ var _ = Describe("HostsFileResolver", func() { Expect(tmpFile.Error).Should(Succeed()) sutConfig = config.HostsFileConfig{ - Filepath: tmpFile.Path, + Sources: config.NewBytesSources(tmpFile.Path), HostsTTL: config.Duration(time.Duration(TTL) * time.Second), - RefreshPeriod: config.Duration(30 * time.Minute), FilterLoopback: true, + Loading: config.SourceLoadingConfig{ + RefreshPeriod: -1, + MaxErrorsPerSource: 5, + }, } }) JustBeforeEach(func() { - sut = NewHostsFileResolver(sutConfig) + var err error + + sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap) + Expect(err).Should(Succeed()) + m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) @@ -74,12 +81,12 @@ var _ = Describe("HostsFileResolver", func() { When("Hosts file cannot be located", func() { BeforeEach(func() { sutConfig = config.HostsFileConfig{ - Filepath: "/this/file/does/not/exist", + Sources: config.NewBytesSources("/this/file/does/not/exist"), HostsTTL: config.Duration(time.Duration(TTL) * time.Second), } }) It("should not parse any hosts", func() { - Expect(sut.cfg.Filepath).Should(BeEmpty()) + Expect(sut.cfg.Sources).ShouldNot(BeEmpty()) Expect(sut.hosts.v4.hosts).Should(BeEmpty()) Expect(sut.hosts.v6.hosts).Should(BeEmpty()) Expect(sut.hosts.v4.aliases).Should(BeEmpty()) @@ -99,13 +106,15 @@ var _ = Describe("HostsFileResolver", func() { When("Hosts file is not set", func() { BeforeEach(func() { - sut = NewHostsFileResolver(config.HostsFileConfig{}) + sutConfig.Deprecated.Filepath = new(config.BytesSource) + sutConfig.Sources = nil + m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) }) It("should not return an error", func() { - err := sut.parseHostsFile(context.Background()) + err := sut.loadSources(context.Background()) Expect(err).Should(Succeed()) }) It("should go to next resolver on query", func() { @@ -156,12 +165,12 @@ var _ = Describe("HostsFileResolver", func() { ) Expect(tmpFile.Error).Should(Succeed()) - sutConfig.Filepath = tmpFile.Path + sutConfig.Sources = config.NewBytesSources(tmpFile.Path) }) It("should not be used", func() { Expect(sut).ShouldNot(BeNil()) - Expect(sut.cfg.Filepath).Should(BeEmpty()) + Expect(sut.cfg.Sources).ShouldNot(BeEmpty()) Expect(sut.hosts.v4.hosts).Should(BeEmpty()) Expect(sut.hosts.v6.hosts).Should(BeEmpty()) Expect(sut.hosts.v4.aliases).Should(BeEmpty()) diff --git a/server/server.go b/server/server.go index 2d36101b..8f949b7f 100644 --- a/server/server.go +++ b/server/server.go @@ -400,15 
+400,17 @@ func createQueryResolver( parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream) clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream) + hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap) - mErr := multierror.Append( + err = multierror.Append( multierror.Prefix(blErr, "blocking resolver: "), multierror.Prefix(pErr, "parallel resolver: "), multierror.Prefix(cnErr, "client names resolver: "), multierror.Prefix(cuErr, "conditional upstream resolver: "), - ) - if mErr.ErrorOrNil() != nil { - return nil, mErr + multierror.Prefix(hfErr, "hosts file resolver: "), + ).ErrorOrNil() + if err != nil { + return nil, err } r = resolver.Chain( @@ -419,7 +421,7 @@ func createQueryResolver( resolver.NewQueryLoggingResolver(cfg.QueryLog), resolver.NewMetricsResolver(cfg.Prometheus), resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)), - resolver.NewHostsFileResolver(cfg.HostsFile), + hostsFile, blocking, resolver.NewCachingResolver(cfg.Caching, redisClient), resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream), diff --git a/server/server_test.go b/server/server_test.go index 72bc8418..5d3a4ab3 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -122,17 +122,17 @@ var _ = BeforeSuite(func() { }, }, Blocking: config.BlockingConfig{ - BlackLists: map[string][]string{ - "ads": { + BlackLists: map[string][]config.BytesSource{ + "ads": config.NewBytesSources( doubleclickFile.Path, bildFile.Path, heiseFile.Path, - }, - "youtube": {youtubeFile.Path}, + ), + "youtube": config.NewBytesSources(youtubeFile.Path), }, - WhiteLists: map[string][]string{ - "ads": {heiseFile.Path}, - "whitelist": {heiseFile.Path}, + WhiteLists: map[string][]config.BytesSource{ + "ads": config.NewBytesSources(heiseFile.Path), + "whitelist": config.NewBytesSources(heiseFile.Path), }, ClientGroupsBlock: map[string][]string{ "default": {"ads"},
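Review aid, not part of the patch: all of the test updates above follow one string-to-BytesSource mapping. The snippet below summarizes the two helpers' observable behavior, inferred from their call sites in this patch and from the removed newLinkReader switch in lists/list_cache.go; it is not the code in config/bytes_source.go.

	package sketch // hypothetical, for illustration only

	import "github.com/0xERR0R/blocky/config"

	// Inferred behavior, for reviewers migrating other call sites:
	var blackLists = map[string][]config.BytesSource{
		// NewBytesSources classifies each string much like the removed newLinkReader:
		// "http(s)://..." becomes an HTTP source; other strings, including optional
		// "file://" prefixes, are treated as file paths (see the gr2 test above).
		"ads": config.NewBytesSources("https://example.com/list.txt", "/etc/blocky/blacklist.txt"),

		// TextBytesSource joins its arguments into a single inline (text) source,
		// replacing the removed inlineList test helper.
		"inline": {config.TextBytesSource("blocked1.com", "/^apple\\.(de|com)$/")},
	}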