feat: support multiple hosts files

ThinkChaos 2023-04-17 12:21:56 -04:00
parent 5e4c155793
commit cfc3699ab5
29 changed files with 1506 additions and 822 deletions


@ -80,6 +80,7 @@ issues:
# Exclude some linters from running on tests files.
- path: _test\.go
linters:
- gochecknoglobals
- dupl
- funlen
- gochecknoglobals
- gosec


@ -1,8 +1,6 @@
package config
import (
"strings"
. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
@ -10,32 +8,40 @@ import (
// BlockingConfig configuration for query blocking
type BlockingConfig struct {
BlackLists map[string][]string `yaml:"blackLists"`
WhiteLists map[string][]string `yaml:"whiteLists"`
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
BlockType string `yaml:"blockType" default:"ZEROIP"`
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
MaxErrorsPerFile int `yaml:"maxErrorsPerFile" default:"5"`
BlackLists map[string][]BytesSource `yaml:"blackLists"`
WhiteLists map[string][]BytesSource `yaml:"whiteLists"`
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
BlockType string `yaml:"blockType" default:"ZEROIP"`
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
Loading SourceLoadingConfig `yaml:"loading"`
// Deprecated options
Deprecated struct {
FailStartOnListError *bool `yaml:"failStartOnListError"`
DownloadTimeout *Duration `yaml:"downloadTimeout"`
DownloadAttempts *uint `yaml:"downloadAttempts"`
DownloadCooldown *Duration `yaml:"downloadCooldown"`
RefreshPeriod *Duration `yaml:"refreshPeriod"`
FailStartOnListError *bool `yaml:"failStartOnListError"`
ProcessingConcurrency *uint `yaml:"processingConcurrency"`
StartStrategy *StartStrategyType `yaml:"startStrategy"`
MaxErrorsPerFile *int `yaml:"maxErrorsPerFile"`
} `yaml:",inline"`
}
func (c *BlockingConfig) migrate(logger *logrus.Entry) bool {
return Migrate(logger, "blocking", c.Deprecated, map[string]Migrator{
"failStartOnListError": Apply(To("startStrategy", c), func(oldValue bool) {
if oldValue && c.StartStrategy != StartStrategyTypeFast {
c.StartStrategy = StartStrategyTypeFailOnError
"downloadTimeout": Move(To("loading.downloads.timeout", &c.Loading.Downloads)),
"downloadAttempts": Move(To("loading.downloads.attempts", &c.Loading.Downloads)),
"downloadCooldown": Move(To("loading.downloads.cooldown", &c.Loading.Downloads)),
"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
"failStartOnListError": Apply(To("loading.strategy", &c.Loading), func(oldValue bool) {
if oldValue {
c.Loading.Strategy = StartStrategyTypeFailOnError
}
}),
"processingConcurrency": Move(To("loading.concurrency", &c.Loading)),
"startStrategy": Move(To("loading.strategy", &c.Loading)),
"maxErrorsPerFile": Move(To("loading.maxErrorsPerSource", &c.Loading)),
})
}
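The `Move` and `To` helpers come from `config/migration` and are not shown in this diff; as a hedged sketch (semantics inferred from the migration tests below), each `Move` copies a deprecated value, when set, into its new location and logs the rename:

```go
// Hypothetical expansion of Move(To("loading.downloads.timeout", &c.Loading.Downloads)):
if c.Deprecated.DownloadTimeout != nil {
    c.Loading.Downloads.Timeout = *c.Deprecated.DownloadTimeout
    logger.Warn("blocking.downloadTimeout is deprecated, use blocking.loading.downloads.timeout")
}
```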
@ -44,7 +50,7 @@ func (c *BlockingConfig) IsEnabled() bool {
return len(c.ClientGroupsBlock) != 0
}
// IsEnabled implements `config.Configurable`.
// LogConfig implements `config.Configurable`.
func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
logger.Info("clientGroupsBlock:")
@ -58,17 +64,8 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("blockTTL = %s", c.BlockTTL)
}
logger.Infof("downloadTimeout = %s", c.DownloadTimeout)
logger.Infof("startStrategy = %s", c.StartStrategy)
logger.Infof("maxErrorsPerFile = %d", c.MaxErrorsPerFile)
if c.RefreshPeriod > 0 {
logger.Infof("refresh = every %s", c.RefreshPeriod)
} else {
logger.Debug("refresh = disabled")
}
logger.Info("loading:")
log.WithIndent(logger, " ", c.Loading.LogConfig)
logger.Info("blacklist:")
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
@ -81,18 +78,12 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
})
}
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) {
for group, links := range listGroups {
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]BytesSource) {
for group, sources := range listGroups {
logger.Infof("%s:", group)
for _, link := range links {
if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char
link = link[:idx] // first line only
logger.Infof(" - %s [...]", link)
} else {
logger.Infof(" - %s", link)
}
for _, source := range sources {
logger.Infof(" - %s", source)
}
}
}


@ -6,7 +6,6 @@ import (
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
)
var _ = Describe("BlockingConfig", func() {
@ -18,13 +17,12 @@ var _ = Describe("BlockingConfig", func() {
cfg = BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: Duration(time.Minute),
BlackLists: map[string][]string{
"gr1": {"/a/file/path"},
BlackLists: map[string][]BytesSource{
"gr1": NewBytesSources("/a/file/path"),
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
RefreshPeriod: Duration(time.Hour),
}
})
@ -59,26 +57,7 @@ var _ = Describe("BlockingConfig", func() {
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:"))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
})
When("refresh is disabled", func() {
It("should reflect that", func() {
cfg.RefreshPeriod = Duration(-1)
logger.Logger.Level = logrus.InfoLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
logger.Logger.Level = logrus.TraceLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
})
Expect(hook.Messages).Should(ContainElement(Equal("blockType = ZEROIP")))
})
})
})

config/bytes_source.go (new file)

@ -0,0 +1,111 @@
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values
package config
import (
"fmt"
"strings"
)
const maxTextSourceDisplayLen = 12
// var BytesSourceNone = BytesSource{}
// BytesSourceType supported BytesSource types. ENUM(
// text=1 // Inline YAML block.
// http // HTTP(S).
// file // Local file.
// )
type BytesSourceType uint16
type BytesSource struct {
Type BytesSourceType
From string
}
func (s BytesSource) String() string {
switch s.Type {
case BytesSourceTypeText:
break
case BytesSourceTypeHttp:
return s.From
case BytesSourceTypeFile:
return fmt.Sprintf("file://%s", s.From)
default:
return fmt.Sprintf("unknown source (%s: %s)", s.Type, s.From)
}
text := s.From
truncated := false
if idx := strings.IndexRune(text, '\n'); idx != -1 {
truncated = idx < len(text)-1 // don't count a trailing newline as truncation
text = text[:idx] // first line only
}
if len(text) > maxTextSourceDisplayLen { // truncate
text = text[:maxTextSourceDisplayLen]
truncated = true
}
if truncated {
return text + "..." // text is already at most maxTextSourceDisplayLen chars
}
return text
}
// UnmarshalText implements `encoding.TextUnmarshaler`.
func (s *BytesSource) UnmarshalText(data []byte) error {
source := string(data)
switch {
// Inline definition in YAML (with literal style Block Scalar)
case strings.ContainsAny(source, "\n"):
*s = BytesSource{Type: BytesSourceTypeText, From: source}
// HTTP(S)
case strings.HasPrefix(source, "http"):
*s = BytesSource{Type: BytesSourceTypeHttp, From: source}
// Probably path to a local file
default:
*s = BytesSource{Type: BytesSourceTypeFile, From: strings.TrimPrefix(source, "file://")}
}
return nil
}
func newBytesSource(source string) BytesSource {
var res BytesSource
// UnmarshalText never returns an error
_ = res.UnmarshalText([]byte(source))
return res
}
func NewBytesSources(sources ...string) []BytesSource {
res := make([]BytesSource, 0, len(sources))
for _, source := range sources {
res = append(res, newBytesSource(source))
}
return res
}
func TextBytesSource(lines ...string) BytesSource {
return BytesSource{Type: BytesSourceTypeText, From: inlineList(lines...)}
}
func inlineList(lines ...string) string {
res := strings.Join(lines, "\n")
// ensure at least one line ending so it's parsed as an inline block
res += "\n"
return res
}
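A small, hypothetical `main` program showing how `UnmarshalText` classifies raw strings (the import path is the repository's module path; the sample values are illustrative):

```go
package main

import (
    "fmt"

    "github.com/0xERR0R/blocky/config"
)

func main() {
    for _, raw := range []string{
        "https://example.com/hosts", // starts with "http" -> http source
        "/etc/hosts",                // fallback -> local file
        "file:///etc/hosts",         // "file://" prefix is stripped -> local file
        "0.0.0.0 ads.example\n",     // contains a newline -> inline text
    } {
        sources := config.NewBytesSources(raw)
        fmt.Printf("%q -> %s\n", raw, sources[0].Type)
    }
}
```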

config/bytes_source_enum.go (new file)

@ -0,0 +1,101 @@
// Code generated by go-enum DO NOT EDIT.
// Version:
// Revision:
// Build Date:
// Built By:
package config
import (
"fmt"
"strings"
)
const (
// BytesSourceTypeText is a BytesSourceType of type Text.
// Inline YAML block.
BytesSourceTypeText BytesSourceType = iota + 1
// BytesSourceTypeHttp is a BytesSourceType of type Http.
// HTTP(S).
BytesSourceTypeHttp
// BytesSourceTypeFile is a BytesSourceType of type File.
// Local file.
BytesSourceTypeFile
)
var ErrInvalidBytesSourceType = fmt.Errorf("not a valid BytesSourceType, try [%s]", strings.Join(_BytesSourceTypeNames, ", "))
const _BytesSourceTypeName = "texthttpfile"
var _BytesSourceTypeNames = []string{
_BytesSourceTypeName[0:4],
_BytesSourceTypeName[4:8],
_BytesSourceTypeName[8:12],
}
// BytesSourceTypeNames returns a list of possible string values of BytesSourceType.
func BytesSourceTypeNames() []string {
tmp := make([]string, len(_BytesSourceTypeNames))
copy(tmp, _BytesSourceTypeNames)
return tmp
}
// BytesSourceTypeValues returns a list of the values for BytesSourceType
func BytesSourceTypeValues() []BytesSourceType {
return []BytesSourceType{
BytesSourceTypeText,
BytesSourceTypeHttp,
BytesSourceTypeFile,
}
}
var _BytesSourceTypeMap = map[BytesSourceType]string{
BytesSourceTypeText: _BytesSourceTypeName[0:4],
BytesSourceTypeHttp: _BytesSourceTypeName[4:8],
BytesSourceTypeFile: _BytesSourceTypeName[8:12],
}
// String implements the Stringer interface.
func (x BytesSourceType) String() string {
if str, ok := _BytesSourceTypeMap[x]; ok {
return str
}
return fmt.Sprintf("BytesSourceType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x BytesSourceType) IsValid() bool {
_, ok := _BytesSourceTypeMap[x]
return ok
}
var _BytesSourceTypeValue = map[string]BytesSourceType{
_BytesSourceTypeName[0:4]: BytesSourceTypeText,
_BytesSourceTypeName[4:8]: BytesSourceTypeHttp,
_BytesSourceTypeName[8:12]: BytesSourceTypeFile,
}
// ParseBytesSourceType attempts to convert a string to a BytesSourceType.
func ParseBytesSourceType(name string) (BytesSourceType, error) {
if x, ok := _BytesSourceTypeValue[name]; ok {
return x, nil
}
return BytesSourceType(0), fmt.Errorf("%s is %w", name, ErrInvalidBytesSourceType)
}
// MarshalText implements the text marshaller method.
func (x BytesSourceType) MarshalText() ([]byte, error) {
return []byte(x.String()), nil
}
// UnmarshalText implements the text unmarshaller method.
func (x *BytesSourceType) UnmarshalText(text []byte) error {
name := string(text)
tmp, err := ParseBytesSourceType(name)
if err != nil {
return err
}
*x = tmp
return nil
}


@ -60,4 +60,20 @@ var _ = Describe("CachingConfig", func() {
})
})
})
Describe("EnablePrefetch", func() {
When("prefetching is enabled", func() {
BeforeEach(func() {
cfg = CachingConfig{}
})
It("should return configuration", func() {
cfg.EnablePrefetch()
Expect(cfg.Prefetching).Should(BeTrue())
Expect(cfg.PrefetchThreshold).Should(Equal(0))
Expect(cfg.MaxCachingTime).ShouldNot(BeZero())
})
})
})
})


@ -2,6 +2,7 @@
package config
import (
"context"
"errors"
"fmt"
"net"
@ -10,6 +11,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
@ -32,7 +34,7 @@ type Configurable interface {
// LogConfig logs the receiver's configuration.
//
// Calling this method when `IsEnabled` returns false is undefined.
// The behavior of this method is undefined when `IsEnabled` returns false.
LogConfig(*logrus.Entry)
}
@ -93,6 +95,30 @@ type QueryLogType int16
// )
type StartStrategyType uint16
func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
if *s == StartStrategyTypeFast {
go func() {
err := setup()
if err != nil {
logErr(err)
}
}()
return nil
}
err := setup()
if err != nil {
logErr(err)
if *s == StartStrategyTypeFailOnError {
return err
}
}
return nil
}
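Since `do` is unexported, it is only callable from within the `config` package. A hedged sketch of what each strategy means for a caller (`strategy`, `loadSources`, and `logger` are illustrative names):

```go
setup := func() error { return loadSources() }
logErr := func(err error) { logger.Error(err) }

// blocking:    runs setup synchronously, logs any error, returns nil.
// failOnError: runs setup synchronously, logs any error, and returns it.
// fast:        starts setup in a new goroutine and returns nil immediately.
err := strategy.do(setup, logErr)
```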
// QueryLogField data field to be logged
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
type QueryLogField string
@ -259,6 +285,86 @@ func (c *toEnable) LogConfig(logger *logrus.Entry) {
logger.Info("enabled")
}
type SourceLoadingConfig struct {
Concurrency uint `yaml:"concurrency" default:"4"`
MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
Strategy StartStrategyType `yaml:"strategy" default:"blocking"`
Downloads DownloaderConfig `yaml:"downloads"`
}
func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("concurrency = %d", c.Concurrency)
logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource)
logger.Debugf("strategy = %s", c.Strategy)
if c.RefreshPeriod > 0 {
logger.Infof("refresh = every %s", c.RefreshPeriod)
} else {
logger.Debug("refresh = disabled")
}
logger.Info("downloads:")
log.WithIndent(logger, " ", c.Downloads.LogConfig)
}
func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error {
refreshAndRecover := func(ctx context.Context) (rerr error) {
defer func() {
if val := recover(); val != nil {
rerr = fmt.Errorf("refresh function panicked: %v", val)
}
}()
return refresh(ctx)
}
err := c.Strategy.do(func() error { return refreshAndRecover(context.Background()) }, logErr)
if err != nil {
return err
}
if c.RefreshPeriod > 0 {
go c.periodically(refreshAndRecover, logErr)
}
return nil
}
func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) {
ticker := time.NewTicker(c.RefreshPeriod.ToDuration())
defer ticker.Stop()
for range ticker.C {
err := refresh(context.Background())
if err != nil {
logErr(err)
}
}
}
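A minimal wiring sketch (`refreshLists` and `logger` are illustrative); note that panics inside the refresh function are turned into errors by the recover wrapper above:

```go
err := cfg.StartPeriodicRefresh(
    func(ctx context.Context) error { return refreshLists(ctx) },
    func(err error) { logger.WithError(err).Error("periodic refresh failed") },
)
if err != nil {
    return err // only reachable with the failOnError strategy
}
```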
type DownloaderConfig struct {
Timeout Duration `yaml:"timeout" default:"5s"`
Attempts uint `yaml:"attempts" default:"3"`
Cooldown Duration `yaml:"cooldown" default:"500ms"`
}
func (c *DownloaderConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("timeout = %s", c.Timeout)
logger.Infof("attempts = %d", c.Attempts)
logger.Debugf("cooldown = %s", c.Cooldown)
}
func WithDefaults[T any]() (T, error) {
var cfg T
if err := defaults.Set(&cfg); err != nil {
return cfg, fmt.Errorf("can't apply %T defaults: %w", cfg, err)
}
return cfg, nil
}
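A quick usage sketch, using the `DownloaderConfig` defined above:

```go
// Build a config value pre-populated from its `default:"..."` struct tags.
dl, err := config.WithDefaults[config.DownloaderConfig]()
// On success: dl.Timeout == 5s, dl.Attempts == 3, dl.Cooldown == 500ms (per the tags above).
```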
//nolint:gochecknoglobals
var (
config = &Config{}
@ -270,9 +376,9 @@ func LoadConfig(path string, mandatory bool) (*Config, error) {
cfgLock.Lock()
defer cfgLock.Unlock()
cfg := Config{}
if err := defaults.Set(&cfg); err != nil {
return nil, fmt.Errorf("can't apply default values: %w", err)
cfg, err := WithDefaults[Config]()
if err != nil {
return nil, err
}
fs, err := os.Stat(path)
@ -398,6 +504,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
})
usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts
usesDepredOpts = cfg.HostsFile.migrate(logger) || usesDepredOpts
return usesDepredOpts
}


@ -63,6 +63,13 @@ func (x IPVersion) String() string {
return fmt.Sprintf("IPVersion(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x IPVersion) IsValid() bool {
_, ok := _IPVersionMap[x]
return ok
}
var _IPVersionValue = map[string]IPVersion{
_IPVersionName[0:4]: IPVersionDual,
_IPVersionName[4:6]: IPVersionV4,
@ -145,6 +152,13 @@ func (x NetProtocol) String() string {
return fmt.Sprintf("NetProtocol(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x NetProtocol) IsValid() bool {
_, ok := _NetProtocolMap[x]
return ok
}
var _NetProtocolValue = map[string]NetProtocol{
_NetProtocolName[0:7]: NetProtocolTcpUdp,
_NetProtocolName[7:14]: NetProtocolTcpTls,
@ -225,7 +239,8 @@ func (x QueryLogField) String() string {
return string(x)
}
// String implements the Stringer interface.
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x QueryLogField) IsValid() bool {
_, err := ParseQueryLogField(string(x))
return err == nil
@ -333,6 +348,13 @@ func (x QueryLogType) String() string {
return fmt.Sprintf("QueryLogType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x QueryLogType) IsValid() bool {
_, ok := _QueryLogTypeMap[x]
return ok
}
var _QueryLogTypeValue = map[string]QueryLogType{
_QueryLogTypeName[0:7]: QueryLogTypeConsole,
_QueryLogTypeName[7:11]: QueryLogTypeNone,
@ -418,6 +440,13 @@ func (x StartStrategyType) String() string {
return fmt.Sprintf("StartStrategyType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x StartStrategyType) IsValid() bool {
_, ok := _StartStrategyTypeMap[x]
return ok
}
var _StartStrategyTypeValue = map[string]StartStrategyType{
_StartStrategyTypeName[0:8]: StartStrategyTypeBlocking,
_StartStrategyTypeName[8:19]: StartStrategyTypeFailOnError,


@ -1,11 +1,15 @@
package config
import (
"context"
"errors"
"net"
"sync/atomic"
"time"
"github.com/creasty/defaults"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -48,15 +52,17 @@ var _ = Describe("Config", func() {
BeforeEach(func() {
c.Blocking.Deprecated.FailStartOnListError = ptrOf(true)
})
It("should change StartStrategy blocking to failOnError", func() {
c.Blocking.StartStrategy = StartStrategyTypeBlocking
It("should change loading.strategy blocking to failOnError", func() {
c.Blocking.Loading.Strategy = StartStrategyTypeBlocking
c.migrate(logger)
Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFailOnError))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy")))
Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFailOnError))
})
It("shouldn't change StartStrategy if set to fast", func() {
c.Blocking.StartStrategy = StartStrategyTypeFast
It("shouldn't change loading.strategy if set to fast", func() {
c.Blocking.Loading.Strategy = StartStrategyTypeFast
c.migrate(logger)
Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFast))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy")))
Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFast))
})
})
@ -206,8 +212,10 @@ var _ = Describe("Config", func() {
When("duration is in wrong format", func() {
It("should return error", func() {
cfg := Config{}
data := `blocking:
refreshPeriod: wrongduration`
data := `
blocking:
loading:
refreshPeriod: wrongduration`
err := unmarshalConfig([]byte(data), &cfg)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("invalid duration \"wrongduration\""))
@ -534,6 +542,222 @@ bootstrapDns:
"tcp-tls:[fd00::6cd4:d7e0:d99d:2952]",
),
)
Describe("SourceLoadingConfig", func() {
var cfg SourceLoadingConfig
BeforeEach(func() {
cfg = SourceLoadingConfig{
Concurrency: 12,
RefreshPeriod: Duration(time.Hour),
}
})
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages[0]).Should(Equal("concurrency = 12"))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
})
When("refresh is disabled", func() {
BeforeEach(func() {
cfg.RefreshPeriod = Duration(-1)
})
It("should reflect that", func() {
logger.Logger.Level = logrus.InfoLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
logger.Logger.Level = logrus.TraceLevel
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
})
})
})
})
Describe("StartStrategyType", func() {
Describe("StartStrategyTypeBlocking", func() {
It("runs in the current goroutine", func() {
sut := StartStrategyTypeBlocking
panicVal := new(int)
defer func() {
// recover will catch the panic if it happened in the same goroutine
Expect(recover()).Should(BeIdenticalTo(panicVal))
}()
_ = sut.do(func() error {
panic(panicVal)
}, nil)
Fail("unreachable")
})
It("logs errors and doesn't return them", func() {
sut := StartStrategyTypeBlocking
expectedErr := errors.New("test")
err := sut.do(func() error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
})
Expect(err).Should(Succeed())
})
})
Describe("StartStrategyTypeFailOnError", func() {
It("runs in the current goroutine", func() {
sut := StartStrategyTypeFailOnError
panicVal := new(int)
defer func() {
// recover will catch the panic if it happened in the same goroutine
Expect(recover()).Should(BeIdenticalTo(panicVal))
}()
_ = sut.do(func() error {
panic(panicVal)
}, nil)
Fail("unreachable")
})
It("logs errors and returns them", func() {
sut := StartStrategyTypeFailOnError
expectedErr := errors.New("test")
err := sut.do(func() error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
})
Expect(err).Should(MatchError(expectedErr))
})
})
Describe("StartStrategyTypeFast", func() {
It("runs in a new goroutine", func() {
sut := StartStrategyTypeFast
events := make(chan string)
wait := make(chan struct{})
err := sut.do(func() error {
events <- "start"
<-wait
events <- "done"
return nil
}, nil)
Eventually(events, "50ms").Should(Receive(Equal("start")))
Expect(err).Should(Succeed())
Consistently(events).ShouldNot(Receive())
close(wait)
Eventually(events, "50ms").Should(Receive(Equal("done")))
})
It("logs errors", func() {
sut := StartStrategyTypeFast
expectedErr := errors.New("test")
wait := make(chan struct{})
err := sut.do(func() error {
return expectedErr
}, func(err error) {
Expect(err).Should(MatchError(expectedErr))
close(wait)
})
Expect(err).Should(Succeed())
Eventually(wait, "50ms").Should(BeClosed())
})
})
})
Describe("SourceLoadingConfig", func() {
It("handles panics", func() {
sut := SourceLoadingConfig{
Strategy: StartStrategyTypeFailOnError,
}
panicMsg := "panic value"
err := sut.StartPeriodicRefresh(func(context.Context) error {
panic(panicMsg)
}, func(err error) {
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
})
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
})
It("periodically calls refresh", func() {
sut := SourceLoadingConfig{
Strategy: StartStrategyTypeFast,
RefreshPeriod: Duration(5 * time.Millisecond),
}
panicMsg := "panic value"
calls := make(chan int32)
var call atomic.Int32
err := sut.StartPeriodicRefresh(func(context.Context) error {
call := call.Add(1)
calls <- call
if call == 3 {
panic(panicMsg)
}
return nil
}, func(err error) {
defer GinkgoRecover()
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
Expect(call.Load()).Should(Equal(int32(3)))
})
Expect(err).Should(Succeed())
Eventually(calls, "50ms").Should(Receive(Equal(int32(1))))
Eventually(calls, "50ms").Should(Receive(Equal(int32(2))))
Eventually(calls, "50ms").Should(Receive(Equal(int32(3))))
})
})
Describe("WithDefaults", func() {
It("use valid defaults", func() {
type T struct {
X int `default:"1"`
}
t, err := WithDefaults[T]()
Expect(err).Should(Succeed())
Expect(t.X).Should(Equal(1))
})
It("return an error if the tag is invalid", func() {
type T struct {
X struct{} `default:"fail"`
}
_, err := WithDefaults[T]()
Expect(err).ShouldNot(Succeed())
})
})
})
func defaultTestFileConfig() {
@ -557,7 +781,7 @@ func defaultTestFileConfig() {
Expect(config.Blocking.WhiteLists).Should(HaveLen(1))
Expect(config.Blocking.ClientGroupsBlock).Should(HaveLen(2))
Expect(config.Blocking.BlockTTL).Should(Equal(Duration(time.Minute)))
Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
Expect(config.Blocking.Loading.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
Expect(config.Filtering.QueryTypes).Should(HaveLen(2))
Expect(config.FqdnOnly.Enable).Should(BeTrue())
@ -613,7 +837,8 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
" Laptop-D.fritz.box:",
" - ads",
" blockTTL: 1m",
" refreshPeriod: 120",
" loading:",
" refreshPeriod: 120",
"clientLookup:",
" upstream: 192.168.178.1",
" singleNameOrder:",
@ -629,7 +854,6 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
"startVerifyUpstream: false")
}
//nolint:funlen
func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
f1 := tmpDir.CreateStringFile("config1.yaml",
"upstream:",
@ -675,7 +899,8 @@ func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
" Laptop-D.fritz.box:",
" - ads",
" blockTTL: 1m",
" refreshPeriod: 120",
" loading:",
" refreshPeriod: 120",
"clientLookup:",
" upstream: 192.168.178.1",
" singleNameOrder:",


@ -1,25 +1,49 @@
package config
import (
. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
)
type HostsFileConfig struct {
Filepath string `yaml:"filePath"`
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
FilterLoopback bool `yaml:"filterLoopback"`
Sources []BytesSource `yaml:"sources"`
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
FilterLoopback bool `yaml:"filterLoopback"`
Loading SourceLoadingConfig `yaml:"loading"`
// Deprecated options
Deprecated struct {
RefreshPeriod *Duration `yaml:"refreshPeriod"`
Filepath *BytesSource `yaml:"filePath"`
} `yaml:",inline"`
}
func (c *HostsFileConfig) migrate(logger *logrus.Entry) bool {
return Migrate(logger, "hostsFile", c.Deprecated, map[string]Migrator{
"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
"filePath": Apply(To("sources", c), func(value BytesSource) {
c.Sources = append(c.Sources, value)
}),
})
}
// IsEnabled implements `config.Configurable`.
func (c *HostsFileConfig) IsEnabled() bool {
return len(c.Filepath) != 0
return len(c.Sources) != 0
}
// LogConfig implements `config.Configurable`.
func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("file path: %s", c.Filepath)
logger.Infof("TTL: %s", c.HostsTTL)
logger.Infof("refresh period: %s", c.RefreshPeriod)
logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
logger.Info("loading:")
log.WithIndent(logger, " ", c.Loading.LogConfig)
logger.Info("sources:")
for _, source := range c.Sources {
logger.Infof(" - %s", source)
}
}


@ -15,9 +15,12 @@ var _ = Describe("HostsFileConfig", func() {
BeforeEach(func() {
cfg = HostsFileConfig{
Filepath: "/dev/null",
Sources: append(
NewBytesSources("/a/file/path"),
TextBytesSource("127.0.0.1 localhost"),
),
HostsTTL: Duration(29 * time.Minute),
RefreshPeriod: Duration(30 * time.Minute),
Loading: SourceLoadingConfig{RefreshPeriod: Duration(30 * time.Minute)},
FilterLoopback: true,
}
})
@ -50,7 +53,28 @@ var _ = Describe("HostsFileConfig", func() {
cfg.LogConfig(logger)
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("- file:///a/file/path")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("- 127.0.0.1 lo...")))
})
})
Describe("migrate", func() {
It("should", func() {
cfg, err := WithDefaults[HostsFileConfig]()
Expect(err).Should(Succeed())
cfg.Deprecated.Filepath = ptrOf(newBytesSource("/a/file/path"))
cfg.Deprecated.RefreshPeriod = ptrOf(Duration(time.Hour))
migrated := cfg.migrate(logger)
Expect(migrated).Should(BeTrue())
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.loading.refreshPeriod")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.sources")))
Expect(cfg.Sources).Should(Equal([]BytesSource{*cfg.Deprecated.Filepath}))
Expect(cfg.Loading.RefreshPeriod).Should(Equal(*cfg.Deprecated.RefreshPeriod))
})
})
})


@ -1,3 +1,5 @@
# REVIEW: manual changelog entry
upstream:
# these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query
# format for resolver: [net:]host:[port][/path]. net could be empty (default, shortcut for tcp+udp), tcp+udp, tcp, udp, tcp-tls or https (DoH). If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls, 443 for https (Doh))
@ -99,22 +101,33 @@ blocking:
# optional: TTL for answers to blocked domains
# default: 6h
blockTTL: 1m
# optional: automatically list refresh period (in duration format). Default: 4h.
# Negative value -> deactivate automatically refresh.
# 0 value -> use default
refreshPeriod: 4h
# optional: timeout for list download (each url). Default: 60s. Use large values for big lists or slow internet connections
downloadTimeout: 4m
# optional: Download attempt timeout. Default: 60s
downloadAttempts: 5
# optional: Time between the download attempts. Default: 1s
downloadCooldown: 10s
# optional: if failOnError, application startup will fail if at least one list can't be downloaded / opened. Default: blocking
startStrategy: failOnError
# Number of errors allowed in a list before it is considered invalid.
# A value of -1 disables the limit.
# Default: 5
maxErrorsPerFile: 5
# optional: Configure how lists, AKA sources, are loaded
loading:
# optional: list refresh period in duration format.
# Set to a value <= 0 to disable.
# default: 4h
refreshPeriod: 24h
# optional: Applies only to lists that are downloaded (HTTP URLs).
downloads:
# optional: timeout for list download (each url). Use large values for big lists or slow internet connections
# default: 5s
timeout: 60s
# optional: Maximum download attempts
# default: 3
attempts: 5
# optional: Time between the download attempts
# default: 500ms
cooldown: 10s
# optional: Maximum number of lists to process in parallel.
# default: 4
concurrency: 16
# optional: if failOnError, application startup will fail if at least one list can't be downloaded/opened
# default: blocking
strategy: failOnError
# Number of errors allowed in a list before it is considered invalid.
# A value of -1 disables the limit.
# default: 5
maxErrorsPerSource: 5
# optional: configuration for caching of DNS responses
caching:
@ -161,6 +174,7 @@ clientLookup:
clients:
laptop:
- 192.168.178.29
# optional: configuration for prometheus metrics endpoint
prometheus:
# enabled if true
@ -214,6 +228,7 @@ redis:
# optional: Minimal TLS version that the DoH and DoT server will use
minTlsServeVersion: 1.3
# if https port > 0: path to cert and key file for SSL encryption. if not set, self-signed certificate will be generated
#certFile: server.crt
#keyFile: server.key
@ -238,14 +253,45 @@ fqdnOnly:
# optional: if path defined, use this file for query resolution (A, AAAA and rDNS). Default: empty
hostsFile:
# optional: Path to hosts file (e.g. /etc/hosts on Linux)
filePath: /etc/hosts
# optional: Hosts files to parse
sources:
- /etc/hosts
- https://example.com/hosts
- |
# inline hosts
127.0.0.1 example.com
# optional: TTL, default: 1h
hostsTTL: 60m
# optional: Time between hosts file refresh, default: 1h
refreshPeriod: 30m
# optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not, default: false
hostsTTL: 30m
# optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not
# default: false
filterLoopback: true
# optional: Configure how sources are loaded
loading:
# optional: file refresh period in duration format.
# Set to a value <= 0 to disable.
# default: 4h
refreshPeriod: 24h
# optional: Applies only to files that are downloaded (HTTP URLs).
downloads:
# optional: timeout for file download (each url). Use large values for big files or slow internet connections
# default: 5s
timeout: 60s
# optional: Maximum download attempts
# default: 3
attempts: 5
# optional: Time between the download attempts
# default: 500ms
cooldown: 10s
# optional: Maximum number of files to process in parallel.
# default: 4
concurrency: 16
# optional: if failOnError, application startup will fail if at least one file can't be downloaded/opened
# default: blocking
strategy: failOnError
# Number of errors allowed in a file before it is considered invalid.
# A value of -1 disables the limit.
# default: 5
maxErrorsPerSource: 5
# optional: ports configuration
ports:
@ -272,4 +318,4 @@ log:
# optional: add EDE error codes to dns response
ede:
# enabled if true, Default: false
enable: true
enable: true


@ -330,20 +330,24 @@ contains a map of client name and multiple IP addresses.
## Blocking and whitelisting
Blocky can download and use external lists with domains or IP addresses to block DNS query (e.g. advertisement, malware,
Blocky can use lists of domains and IPs to block (e.g. advertisement, malware,
trackers, adult sites). You can group several list sources together and define the blocking behavior per client.
External blacklists must be either in the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file)) or just
a plain domain list (one domain per line). Blocky also supports regex as more powerful tool to define patterns to block.
Blocking uses the [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach. For each DNS query, the domain name from
the request, IP address from the response, and any CNAME records will be checked to determine whether to block the query or not.
Blocky uses [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach to block a DNS query. Domain name from
the request, IP address from the response, and the CNAME record will be checked against configured blacklists.
To avoid over-blocking, you can define or use already existing whitelists.
To avoid over-blocking, you can use whitelists.
### Definition black and whitelists
Each black or whitelist can be either a path to the local file, a URL to download or inline list definition of a domains
in hosts format (YAML literal block scalar style). All Urls must be grouped to a group name.
Lists are defined in groups. This allows using different sets of lists for different clients.
Each list in a group is a "source" and can be downloaded, read from a file, or inlined in the config. See [Sources](#sources) for details, including how sources are loaded and reloaded/refreshed.
The supported list formats are:
1. the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file))
2. one domain per line (plain domain list)
3. one regex per line
!!! example
@ -354,35 +358,38 @@ in hosts format (YAML literal block scalar style). All Urls must be grouped to a
- https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt
- https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts
- |
# inline definition with YAML literal block scalar style
# inline definition using YAML literal block scalar style
# content is in plain domain list format
someadsdomain.com
anotheradsdomain.com
# this is a regex
- |
# inline definition with a regex
/^banners?[_.-]/
special:
- https://raw.githubusercontent.com/StevenBlack/hosts/master/alternates/fakenews/hosts
whiteLists:
ads:
- whitelist.txt
- /path/to/file.txt
- |
# inline definition with YAML literal block scalar style
whitelistdomain.com
```
In this example you can see 2 groups: **ads** with 2 lists and **special** with one list. One local whitelist was defined for the **ads** group.
In this example you can see 2 groups: **ads** and **special**. The **ads** group includes 2 downloaded lists and 2 inline definitions, while **special** has a single list.
!!! warning
If the same group has black and whitelists, whitelists will be used to disable particular blacklist entries.
If a group has **only** whitelist entries, then only domains from that list are allowed; all other domains will
be blocked
be blocked.
!!! note
Please define also client group mapping, otherwise you black and whitelist definition will have no effect
!!! warning
You must also define the client group mapping, otherwise your black and whitelist definitions will have no effect.
#### Regex support
You can use regex to define patterns to block. A regex entry must start and end with the slash character (/). Some
You can use regex to define patterns to block. A regex entry must start and end with the slash character (`/`). Some
Examples:
- `/baddomain/` will block `www.baddomain.com`, `baddomain.com`, but also `mybaddomain-sometext.com`
@ -395,7 +402,7 @@ In this configuration section, you can define, which blocking group(s) should be
Example: All clients should use the **ads** group, which blocks advertisement and kids devices should use the **adult**
group, which blocks adult sites.
Clients without a group assignment will use automatically the **default** group.
Clients without an explicit group assignment will use the **default** group.
You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address, client's full-qualified domain name
or a client subnet as CIDR notation.
@ -460,82 +467,9 @@ after receiving the custom value.
blockTTL: 10s
```
### List refresh period
### Lists Loading
To keep the list cache up-to-date, blocky will periodically download and reload all external lists. Default period is **
4 hours**. You can configure this by setting the `blocking.refreshPeriod` parameter to a value in **duration format**.
Negative value will deactivate automatically refresh.
!!! example
```yaml
blocking:
refreshPeriod: 60m
```
Refresh every hour.
### Download
You can configure the list download attempts according to your internet connection:
| Parameter | Type | Mandatory | Default value | Description |
|------------------|-----------------|-----------|---------------|------------------------------------------------|
| downloadTimeout | duration format | no | 60s | Download attempt timeout |
| downloadAttempts | int | no | 3 | How many download attempts should be performed |
| downloadCooldown | duration format | no | 1s | Time between the download attempts |
!!! example
```yaml
blocking:
downloadTimeout: 4m
downloadAttempts: 5
downloadCooldown: 10s
```
### Start strategy
You can configure the blocking behavior during application start of blocky.
If no strategy is selected blocking will be used.
| startStrategy | Description |
|---------------|-------------------------------------------------------------------------------------------------------|
| blocking | all blocking lists will be loaded before DNS resolution starts |
| failOnError | like blocking but blocky will shut down if any download fails |
| fast | DNS resolution starts immediately without blocking which will be enabled after list load is completed |
!!! example
```yaml
blocking:
startStrategy: failOnError
```
### Max Errors per file
Number of errors allowed in a list before it is considered invalid and parsing stops.
A value of -1 disables the limit.
!!! example
```yaml
blocking:
maxErrorsPerFile: 10
```
### Concurrency
Blocky downloads and processes links in a single group concurrently. With parameter `processingConcurrency` you can adjust
how many links can be processed in the same time. Higher value can reduce the overall list refresh time, but more parallel
download and processing jobs need more RAM. Please consider to reduce this value on systems with limited memory. Default value is 4.
!!! example
```yaml
blocking:
processingConcurrency: 10
```
See [Sources Loading](#sources-loading).
## Caching
@ -716,7 +650,7 @@ Configuration parameters:
```yaml
hostsFile:
filePath: /etc/hosts
hostsTTL: 60m
hostsTTL: 1h
refreshPeriod: 30m
```
@ -745,3 +679,127 @@ for detailed information, how to create and configure SSL certificates.
DoH url: `https://host:port/dns-query`
--8<-- "docs/includes/abbreviations.md"
## Sources
Sources are a concept shared by the blocking and hosts file resolvers: they define where each resolver loads its data from.
The supported source types are:
- HTTP(S) URL (any source starting with `http`)
- inline configuration (any source containing a newline)
- local file path (any source not matching the above rules)
!!! note
The format/content of the sources depends on the context: lists and hosts files have different, but overlapping, supported formats.
!!! example
```yaml
- https://example.com/a/source # blocky will download and parse the file
- /a/file/path # blocky will read the local file
- | # blocky will parse the content of this multi-line string
# inline configuration
```
### Sources Loading
This section covers the `loading` configuration shared by the blocking and hosts file resolvers.
Each resolver has its own `loading` block, and its settings apply only to the resolver under which they are nested.
!!! example
```yaml
blocking:
loading:
# only applies to white/blacklists
hostsFile:
loading:
# only applies to hostsFile sources
```
#### Refresh / Reload
To keep source contents up-to-date, blocky can periodically refresh and reparse them. Default period is **
4 hours**. You can configure this by setting the `refreshPeriod` parameter to a value in **duration format**.
A value of zero or less will disable this feature.
!!! example
```yaml
loading:
refreshPeriod: 1h
```
Refresh every hour.
#### Downloads
Configures how HTTP(S) sources are downloaded:
| Parameter | Type | Mandatory | Default value | Description |
|-----------|----------|-----------|---------------|------------------------------------------------|
| timeout | duration | no | 5s | Download attempt timeout |
| attempts | int | no | 3 | How many download attempts should be performed |
| cooldown | duration | no | 500ms | Time between the download attempts |
!!! example
```yaml
loading:
downloads:
timeout: 4m
attempts: 5
cooldown: 10s
```
#### Strategy
This configures how Blocky loads sources at startup.
The default strategy is `blocking`.
| strategy | Description |
|-------------|------------------------------------------------------------------------------------------------------------------------------------------|
| blocking | all sources are loaded before DNS resolution starts |
| failOnError | like blocking but blocky will shut down if any source fails to load |
| fast | blocky starts serving DNS immediately and sources are loaded asynchronously; features that require the sources are enabled once loading completes |
!!! example
```yaml
loading:
strategy: failOnError
```
#### Max Errors per Source
Number of errors allowed when parsing a source before it is considered invalid and parsing stops.
A value of -1 disables the limit.
!!! example
```yaml
loading:
maxErrorsPerSource: 10
```
#### Concurrency
Blocky downloads and processes sources concurrently. This setting limits how many are processed at the same time.
Larger values can reduce the overall list refresh time at the cost of using more RAM. Please consider reducing this value on systems with limited memory.
Default value is 4.
!!! example
```yaml
loading:
concurrency: 10
```
!!! note
As with other settings under `loading`, the limit applies to the blocking and hosts file resolvers separately.
The total number of sources processed concurrently can therefore reach the sum of both values.
For example if blocking has a limit set to 8 and hosts file's is 4, there could be up to 12 concurrent jobs.


@ -19,7 +19,7 @@ var _ = Describe("External lists and query blocking", func() {
})
Describe("List download on startup", func() {
When("external blacklist ist not available", func() {
Context("startStrategy = blocking", func() {
Context("loading.strategy = blocking", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"log:",
@ -28,7 +28,8 @@ var _ = Describe("External lists and query blocking", func() {
" default:",
" - moka",
"blocking:",
" startStrategy: blocking",
" loading:",
" strategy: blocking",
" blackLists:",
" ads:",
" - http://wrong.domain.url/list.txt",
@ -54,7 +55,7 @@ var _ = Describe("External lists and query blocking", func() {
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: ")))
})
})
Context("startStrategy = failOnError", func() {
Context("loading.strategy = failOnError", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
"log:",
@ -63,7 +64,8 @@ var _ = Describe("External lists and query blocking", func() {
" default:",
" - moka",
"blocking:",
" startStrategy: failOnError",
" loading:",
" strategy: failOnError",
" blackLists:",
" ads:",
" - http://wrong.domain.url/list.txt",


@ -29,7 +29,7 @@ const (
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count
CachingDomainsToPrefetchCountChanged = "caching:domainsToPrefetchCountChanged"
// CachingFailedDownloadChanged fires, if a download of a blocking list fails
// CachingFailedDownloadChanged fires, if a download of a blocking list or hosts file fails
CachingFailedDownloadChanged = "caching:failedDownload"
// ApplicationStarted fires on start of the application. Parameter: version number, build time

go.mod

@ -37,6 +37,7 @@ require (
require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5
github.com/docker/go-connections v0.4.0
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
github.com/testcontainers/testcontainers-go v0.21.0
@ -56,7 +57,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/klauspost/compress v1.11.13 // indirect
github.com/magiconair/properties v1.8.7 // indirect

go.sum

@ -14,6 +14,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa
github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/hcsshim v0.9.7 h1:mKNHW/Xvv1aFH87Jb6ERDzXTJTLPlmzfZ28VBFD/bfg=
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI=
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
github.com/abice/go-enum v0.5.6 h1:Ury51IQXUppbIl56MqRU/++A8SSeLG4plePphPjxW1s=
github.com/abice/go-enum v0.5.6/go.mod h1:X2GpCT8VkCXLkVm48hebWx3cVgFJ8zM5nY5iUrJZO1Q=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
@ -108,8 +110,8 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d h1:um9/pc7tKMINFfP1eE7Wv6PRGXlcCSJkVajF7KJw3uQ=
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@ -124,7 +126,6 @@ github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+l
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
@ -319,7 +320,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=


@ -6,18 +6,12 @@ import (
"io"
"net"
"net/http"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/avast/retry-go/v4"
)
const (
defaultDownloadTimeout = time.Second
defaultDownloadAttempts = uint(1)
defaultDownloadCooldown = 500 * time.Millisecond
)
// TransientError represents a temporary error like timeout, network errors...
type TransientError struct {
inner error
@ -36,74 +30,35 @@ type FileDownloader interface {
DownloadFile(link string) (io.ReadCloser, error)
}
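Callers can detect retryable failures with `errors.As`, as the downloader tests below do (sketch):

```go
var transientErr *TransientError
if errors.As(err, &transientErr) {
    // Temporary failure (timeout, DNS error, 5xx status); a later retry may succeed.
}
```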
// HTTPDownloader downloads files via HTTP protocol
type HTTPDownloader struct {
downloadTimeout time.Duration
downloadAttempts uint
downloadCooldown time.Duration
httpTransport *http.Transport
// httpDownloader downloads files via HTTP protocol
type httpDownloader struct {
cfg config.DownloaderConfig
client http.Client
}
type DownloaderOption func(c *HTTPDownloader)
func NewDownloader(options ...DownloaderOption) *HTTPDownloader {
d := &HTTPDownloader{
downloadTimeout: defaultDownloadTimeout,
downloadAttempts: defaultDownloadAttempts,
downloadCooldown: defaultDownloadCooldown,
httpTransport: &http.Transport{},
}
for _, opt := range options {
opt(d)
}
return d
func NewDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) FileDownloader {
return newDownloader(cfg, transport)
}
// WithTimeout sets the download timeout
func WithTimeout(timeout time.Duration) DownloaderOption {
return func(d *HTTPDownloader) {
d.downloadTimeout = timeout
func newDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) *httpDownloader {
return &httpDownloader{
cfg: cfg,
client: http.Client{
Transport: transport,
Timeout: cfg.Timeout.ToDuration(),
},
}
}
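A hedged usage sketch (`dlCfg` is a `config.DownloaderConfig`; passing a nil transport makes `http.Client` fall back to `http.DefaultTransport`):

```go
dl := NewDownloader(dlCfg, nil)

rc, err := dl.DownloadFile("https://example.com/list.txt")
if err == nil {
    defer rc.Close()
    // parse the list from rc ...
}
```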
// WithTimeout sets the pause between 2 download attempts
func WithCooldown(cooldown time.Duration) DownloaderOption {
return func(d *HTTPDownloader) {
d.downloadCooldown = cooldown
}
}
// WithTimeout sets the attempt number for retry
func WithAttempts(downloadAttempts uint) DownloaderOption {
return func(d *HTTPDownloader) {
d.downloadAttempts = downloadAttempts
}
}
// WithTimeout sets the HTTP transport
func WithTransport(httpTransport *http.Transport) DownloaderOption {
return func(d *HTTPDownloader) {
d.httpTransport = httpTransport
}
}
func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
client := http.Client{
Timeout: d.downloadTimeout,
Transport: d.httpTransport,
}
logger().WithField("link", link).Info("starting download")
func (d *httpDownloader) DownloadFile(link string) (io.ReadCloser, error) {
var body io.ReadCloser
err := retry.Do(
func() error {
var resp *http.Response
var httpErr error
if resp, httpErr = client.Get(link); httpErr == nil {
resp, httpErr := d.client.Get(link)
if httpErr == nil {
if resp.StatusCode == http.StatusOK {
body = resp.Body
@ -121,17 +76,18 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
return httpErr
},
retry.Attempts(d.downloadAttempts),
retry.Attempts(d.cfg.Attempts),
retry.DelayType(retry.FixedDelay),
retry.Delay(d.downloadCooldown),
retry.Delay(d.cfg.Cooldown.ToDuration()),
retry.LastErrorOnly(true),
retry.OnRetry(func(n uint, err error) {
var transientErr *TransientError
var dnsErr *net.DNSError
logger := logger().WithField("link", link).WithField("attempt",
fmt.Sprintf("%d/%d", n+1, d.downloadAttempts))
logger := logger().
WithField("link", link).
WithField("attempt", fmt.Sprintf("%d/%d", n+1, d.cfg.Attempts))
switch {
case errors.As(err, &transientErr):


@ -10,6 +10,7 @@ import (
"sync/atomic"
"time"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/evt"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -20,11 +21,17 @@ import (
var _ = Describe("Downloader", func() {
var (
sut *HTTPDownloader
sutConfig config.DownloaderConfig
sut *httpDownloader
failedDownloadCountEvtChannel chan string
loggerHook *test.Hook
)
BeforeEach(func() {
var err error
sutConfig, err = config.WithDefaults[config.DownloaderConfig]()
Expect(err).Should(Succeed())
failedDownloadCountEvtChannel = make(chan string, 5)
// collect received events in the channel
fn := func(url string) {
@ -40,33 +47,27 @@ var _ = Describe("Downloader", func() {
DeferCleanup(loggerHook.Reset)
})
Describe("Construct downloader", func() {
When("No options are provided", func() {
BeforeEach(func() {
sut = NewDownloader()
})
It("Should provide default valus", func() {
Expect(sut.downloadAttempts).Should(BeNumerically("==", defaultDownloadAttempts))
Expect(sut.downloadTimeout).Should(BeNumerically("==", defaultDownloadTimeout))
Expect(sut.downloadCooldown).Should(BeNumerically("==", defaultDownloadCooldown))
})
})
When("Options are provided", func() {
JustBeforeEach(func() {
sut = newDownloader(sutConfig, nil)
})
Describe("NewDownloader", func() {
It("Should use provided parameters", func() {
transport := &http.Transport{}
BeforeEach(func() {
sut = NewDownloader(
WithAttempts(5),
WithCooldown(2*time.Second),
WithTimeout(5*time.Second),
WithTransport(transport),
)
})
It("Should use provided parameters", func() {
Expect(sut.downloadAttempts).Should(BeNumerically("==", 5))
Expect(sut.downloadTimeout).Should(BeNumerically("==", 5*time.Second))
Expect(sut.downloadCooldown).Should(BeNumerically("==", 2*time.Second))
Expect(sut.httpTransport).Should(BeIdenticalTo(transport))
})
sut = NewDownloader(
config.DownloaderConfig{
Attempts: 5,
Cooldown: config.Duration(2 * time.Second),
Timeout: config.Duration(5 * time.Second),
},
transport,
).(*httpDownloader)
Expect(sut.cfg.Attempts).Should(BeNumerically("==", 5))
Expect(sut.cfg.Timeout).Should(BeNumerically("==", 5*time.Second))
Expect(sut.cfg.Cooldown).Should(BeNumerically("==", 2*time.Second))
Expect(sut.client.Transport).Should(BeIdenticalTo(transport))
})
})
@ -77,7 +78,7 @@ var _ = Describe("Downloader", func() {
server = TestServer("line.one\nline.two")
DeferCleanup(server.Close)
sut = NewDownloader()
sut = newDownloader(sutConfig, nil)
})
It("Should return all lines from the file", func() {
reader, err := sut.DownloadFile(server.URL)
@ -98,7 +99,7 @@ var _ = Describe("Downloader", func() {
}))
DeferCleanup(server.Close)
sut = NewDownloader(WithAttempts(3))
sutConfig.Attempts = 3
})
It("Should return error", func() {
reader, err := sut.DownloadFile(server.URL)
@ -112,7 +113,7 @@ var _ = Describe("Downloader", func() {
})
When("Wrong URL is defined", func() {
BeforeEach(func() {
sut = NewDownloader()
sutConfig.Attempts = 1
})
It("Should return error", func() {
_, err := sut.DownloadFile("somewrongurl")
@ -129,10 +130,11 @@ var _ = Describe("Downloader", func() {
var attempt uint64 = 1
BeforeEach(func() {
sut = NewDownloader(
WithTimeout(20*time.Millisecond),
WithAttempts(3),
WithCooldown(time.Millisecond))
sutConfig = config.DownloaderConfig{
Timeout: config.Duration(20 * time.Millisecond),
Attempts: 3,
Cooldown: config.Duration(time.Millisecond),
}
// should produce a timeout on first attempt
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -166,24 +168,23 @@ var _ = Describe("Downloader", func() {
})
When("If timeout occurs on all request", func() {
BeforeEach(func() {
sut = NewDownloader(
WithTimeout(100*time.Millisecond),
WithAttempts(3),
WithCooldown(time.Millisecond))
sutConfig = config.DownloaderConfig{
Timeout: config.Duration(10 * time.Millisecond),
Attempts: 3,
Cooldown: config.Duration(time.Millisecond),
}
// should always produce a timeout
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(200 * time.Millisecond)
time.Sleep(20 * time.Millisecond)
}))
DeferCleanup(server.Close)
})
It("Should perform a retry until max retry attempt count is reached and return TransientError", func() {
reader, err := sut.DownloadFile(server.URL)
Expect(err).Should(HaveOccurred())
err2 := unwrapTransientErr(err)
Expect(err2.Error()).Should(ContainSubstring("Timeout"))
Expect(errors.As(err, new(*TransientError))).Should(BeTrue())
Expect(err.Error()).Should(ContainSubstring("Timeout"))
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
@ -193,19 +194,18 @@ var _ = Describe("Downloader", func() {
})
When("DNS resolution of passed URL fails", func() {
BeforeEach(func() {
sut = NewDownloader(
WithTimeout(500*time.Millisecond),
WithAttempts(3),
WithCooldown(200*time.Millisecond))
sutConfig = config.DownloaderConfig{
Timeout: config.Duration(500 * time.Millisecond),
Attempts: 3,
Cooldown: config.Duration(200 * time.Millisecond),
}
})
It("Should perform a retry until max retry attempt count is reached and return DNSError", func() {
reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
Expect(err).Should(HaveOccurred())
err2 := unwrapTransientErr(err)
var dnsError *net.DNSError
Expect(errors.As(err2, &dnsError)).To(BeTrue(), "received error %v", err)
Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %v", err)
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
@ -216,12 +216,3 @@ var _ = Describe("Downloader", func() {
})
})
})
func unwrapTransientErr(origErr error) error {
var transientErr *TransientError
if errors.As(origErr, &transientErr) {
return transientErr.Unwrap()
}
return origErr
}
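Before the next file, a minimal sketch of the reworked constructor these tests exercise: retry tuning moves from functional options into config.DownloaderConfig, and the transport is passed directly. It assumes only the fields and the NewDownloader(cfg, transport) signature visible above; newTestDownloader is a hypothetical helper.

import (
	"net/http"
	"time"

	"github.com/0xERR0R/blocky/config"
	"github.com/0xERR0R/blocky/lists"
)

// newTestDownloader builds a downloader the post-change way.
func newTestDownloader() lists.FileDownloader {
	cfg := config.DownloaderConfig{
		Attempts: 3,                                // retry up to three times
		Timeout:  config.Duration(5 * time.Second), // per-request timeout
		Cooldown: config.Duration(time.Second),     // pause between retries
	}

	// nil is also accepted for the transport, as the tests above show
	return lists.NewDownloader(cfg, &http.Transport{})
}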

View File

@ -5,25 +5,20 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"time"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
"github.com/0xERR0R/blocky/cache/stringcache"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
"github.com/ThinkChaos/parcour"
"github.com/ThinkChaos/parcour/jobgroup"
)
const (
defaultProcessingConcurrency = 4
chanCap = 1000
)
const groupProducersBufferCap = 1000
// ListCacheType represents the type of cached list ENUM(
// blacklist // is a list with blocked domains
@ -41,19 +36,17 @@ type Matcher interface {
type ListCache struct {
groupedCache stringcache.GroupedStringCache
groupToLinks map[string][]string
refreshPeriod time.Duration
downloader FileDownloader
listType ListCacheType
processingConcurrency uint
maxErrorsPerFile int
cfg config.SourceLoadingConfig
listType ListCacheType
groupSources map[string][]config.BytesSource
downloader FileDownloader
}
// LogConfig implements `config.Configurable`.
func (b *ListCache) LogConfig(logger *logrus.Entry) {
var total int
for group := range b.groupToLinks {
for group := range b.groupSources {
count := b.groupedCache.ElementCount(group)
logger.Infof("%s: %d entries", group, count)
total += count
@ -63,132 +56,36 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) {
}
// NewListCache creates new list instance
func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration,
downloader FileDownloader, processingConcurrency uint, async bool, maxErrorsPerFile int,
func NewListCache(
t ListCacheType, cfg config.SourceLoadingConfig,
groupSources map[string][]config.BytesSource, downloader FileDownloader,
) (*ListCache, error) {
if processingConcurrency == 0 {
processingConcurrency = defaultProcessingConcurrency
}
b := &ListCache{
c := &ListCache{
groupedCache: stringcache.NewChainedGroupedCache(
stringcache.NewInMemoryGroupedStringCache(),
stringcache.NewInMemoryGroupedRegexCache(),
),
groupToLinks: groupToLinks,
refreshPeriod: refreshPeriod,
downloader: downloader,
listType: t,
processingConcurrency: processingConcurrency,
maxErrorsPerFile: maxErrorsPerFile,
cfg: cfg,
listType: t,
groupSources: groupSources,
downloader: downloader,
}
var initError error
if async {
initError = nil
// start list refresh in the background
go b.Refresh()
} else {
initError = b.refresh(true)
err := cfg.StartPeriodicRefresh(c.refresh, func(err error) {
logger().WithError(err).Errorf("could not init %s", t)
})
if err != nil {
return nil, err
}
if initError == nil {
go periodicUpdate(b)
}
return b, initError
}
// periodicUpdate triggers periodical refresh (and download) of list entries
func periodicUpdate(cache *ListCache) {
if cache.refreshPeriod > 0 {
ticker := time.NewTicker(cache.refreshPeriod)
defer ticker.Stop()
for {
<-ticker.C
cache.Refresh()
}
}
return c, nil
}
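A hedged aside on the new loading API: both call sites in this commit (NewListCache above and the hosts file resolver later in the diff) drive refreshes through config.SourceLoadingConfig.StartPeriodicRefresh. The sketch below assumes only what those call sites imply, a refresh func(context.Context) error, an error callback for background failures, and an error return reflecting the configured start strategy; startRefresh and reloadAll are hypothetical names.

import (
	"context"

	"github.com/0xERR0R/blocky/config"
	"github.com/0xERR0R/blocky/log"
)

// startRefresh wires a reload function into the shared loading config.
// StartPeriodicRefresh is assumed to run it once according to cfg.Strategy
// (blocking, fast or failOnError) and then every cfg.RefreshPeriod if > 0.
func startRefresh(cfg config.SourceLoadingConfig, reloadAll func(context.Context) error) error {
	return cfg.StartPeriodicRefresh(reloadAll, func(err error) {
		// only background (asynchronous) refresh errors land here
		log.PrefixedLog("example").WithError(err).Error("refresh failed")
	})
}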
func logger() *logrus.Entry {
return log.PrefixedLog("list_cache")
}
// downloads and reads files with domain names and creates cache for them
//
//nolint:funlen // will refactor in a later commit
func (b *ListCache) createCacheForGroup(group string, links []string) (created bool, err error) {
groupFactory := b.groupedCache.Refresh(group)
fileLinesChan := make(chan string, chanCap)
errChan := make(chan error, chanCap)
workerDoneChan := make(chan bool, len(links))
// guard channel is used to limit the number of concurrent executions of the function
guard := make(chan struct{}, b.processingConcurrency)
processingLinkJobs := len(links)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// loop over links (http/local) or inline definitions
// start a new goroutine for each link, but limit to max. number (see processingConcurrency)
for idx, link := range links {
go func(idx int, link string) {
// try to write in this channel -> this will block if max amount of goroutines are being executed
guard <- struct{}{}
defer func() {
// remove from guard channel to allow other blocked goroutines to continue
<-guard
workerDoneChan <- true
}()
name := linkName(idx, link)
err := b.parseFile(ctx, name, link, fileLinesChan)
if err != nil {
errChan <- err
}
}(idx, link)
}
Loop:
for {
select {
case line := <-fileLinesChan:
groupFactory.AddEntry(line)
case e := <-errChan:
var transientErr *TransientError
if errors.As(e, &transientErr) {
return false, e
}
err = multierror.Append(err, e)
case <-workerDoneChan:
processingLinkJobs--
default:
if processingLinkJobs == 0 {
break Loop
}
}
}
if groupFactory.Count() == 0 && err != nil {
return false, err
}
groupFactory.Finish()
return true, err
}
// Match matches passed domain name against cached list entries
func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []string) {
return b.groupedCache.Contains(domain, groupsToCheck)
@ -196,65 +93,123 @@ func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []strin
// Refresh triggers the refresh of a list
func (b *ListCache) Refresh() {
_ = b.refresh(false)
_ = b.refresh(context.Background())
}
func (b *ListCache) refresh(isInit bool) error {
var err error
func (b *ListCache) refresh(ctx context.Context) error {
unlimitedGrp, _ := jobgroup.WithContext(ctx)
defer unlimitedGrp.Close()
for group, links := range b.groupToLinks {
created, e := b.createCacheForGroup(group, links)
if e != nil {
err = multierror.Append(err, multierror.Prefix(e, fmt.Sprintf("can't create cache group '%s':", group)))
}
producersGrp := jobgroup.WithMaxConcurrency(unlimitedGrp, b.cfg.Concurrency)
defer producersGrp.Close()
count := b.groupedCache.ElementCount(group)
for group, sources := range b.groupSources {
group, sources := group, sources
if !created {
logger := logger().WithFields(logrus.Fields{
"group": group,
"total_count": count,
})
unlimitedGrp.Go(func(ctx context.Context) error {
err := b.createCacheForGroup(producersGrp, unlimitedGrp, group, sources)
if err != nil {
count := b.groupedCache.ElementCount(group)
if count == 0 || isInit {
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
} else {
logger.Warn("Populating of group cache failed, using existing cache, if any")
logger := logger().WithFields(logrus.Fields{
"group": group,
"total_count": count,
})
if count == 0 {
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
} else {
logger.Warn("Populating of group cache failed, using existing cache, if any")
}
return err
}
continue
}
count := b.groupedCache.ElementCount(group)
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count)
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count)
logger().WithFields(logrus.Fields{
"group": group,
"total_count": count,
}).Info("group import finished")
logger().WithFields(logrus.Fields{
"group": group,
"total_count": count,
}).Info("group import finished")
return nil
})
}
return err
return unlimitedGrp.Wait()
}
func readFile(file string) (io.ReadCloser, error) {
logger().WithField("file", file).Info("starting processing of file")
file = strings.TrimPrefix(file, "file://")
func (b *ListCache) createCacheForGroup(
producersGrp, consumersGrp jobgroup.JobGroup, group string, sources []config.BytesSource,
) error {
groupFactory := b.groupedCache.Refresh(group)
return os.Open(file)
producers := parcour.NewProducersWithBuffer[string](producersGrp, consumersGrp, groupProducersBufferCap)
defer producers.Close()
for i, source := range sources {
i, source := i, source
producers.GoProduce(func(ctx context.Context, hostsChan chan<- string) error {
locInfo := fmt.Sprintf("item #%d of group %s", i, group)
opener, err := NewSourceOpener(locInfo, source, b.downloader)
if err != nil {
return err
}
return b.parseFile(ctx, opener, hostsChan)
})
}
hasEntries := false
producers.GoConsume(func(ctx context.Context, ch <-chan string) error {
for host := range ch {
hasEntries = true
groupFactory.AddEntry(host)
}
return nil
})
err := producers.Wait()
if err != nil {
if !hasEntries {
// Always fail the group if no entries were parsed
return err
}
var transientErr *TransientError
if errors.As(err, &transientErr) {
// Temporary error: fail the whole group to retry later
return err
}
}
groupFactory.Finish()
return nil
}
// downloads file (or reads local file) and writes each line in the file to the result channel
func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh chan<- string) error {
func (b *ListCache) parseFile(ctx context.Context, opener SourceOpener, resultCh chan<- string) error {
count := 0
logger := func() *logrus.Entry {
return logger().WithFields(logrus.Fields{
"source": name,
"source": opener.String(),
"count": count,
})
}
r, err := b.newLinkReader(link)
logger().Debug("starting processing of source")
r, err := opener.Open()
if err != nil {
logger().Error("cannot open source: ", err)
@ -262,7 +217,7 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c
}
defer r.Close()
p := parsers.AllowErrors(parsers.Hosts(r), b.maxErrorsPerFile)
p := parsers.AllowErrors(parsers.Hosts(r), b.cfg.MaxErrorsPerSource)
p.OnErr(func(err error) {
logger().Warnf("parse error: %s, trying to continue", err)
})
@ -303,27 +258,3 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c
return nil
}
func linkName(linkIdx int, link string) string {
if strings.ContainsAny(link, "\n") {
return fmt.Sprintf("inline block (item #%d in group)", linkIdx)
}
return link
}
func (b *ListCache) newLinkReader(link string) (r io.ReadCloser, err error) {
switch {
// link contains a line break -> this is inline list definition in YAML (with literal style Block Scalar)
case strings.ContainsAny(link, "\n"):
r = io.NopCloser(strings.NewReader(link))
// link is http(s) -> download it
case strings.HasPrefix(link, "http"):
r, err = b.downloader.DownloadFile(link)
// probably path to a local file
default:
r, err = readFile(link)
}
return
}
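To see the shape of the new concurrent load without the cache plumbing, here is a condensed sketch of the producer/consumer pattern createCacheForGroup uses above. It relies only on the parcour/jobgroup calls that appear in this file (WithContext, WithMaxConcurrency, NewProducersWithBuffer, GoProduce, GoConsume, Wait) and the groupProducersBufferCap constant; collectHosts is a hypothetical helper, and the single `ch <- source.From` stands in for the real per-source parsing.

import (
	"context"

	"github.com/0xERR0R/blocky/config"
	"github.com/ThinkChaos/parcour"
	"github.com/ThinkChaos/parcour/jobgroup"
)

func collectHosts(ctx context.Context, sources []config.BytesSource, concurrency uint) ([]string, error) {
	unlimitedGrp, _ := jobgroup.WithContext(ctx)
	defer unlimitedGrp.Close()

	// producers are bounded; the single consumer runs in the unlimited group
	producersGrp := jobgroup.WithMaxConcurrency(unlimitedGrp, concurrency)
	defer producersGrp.Close()

	producers := parcour.NewProducersWithBuffer[string](producersGrp, unlimitedGrp, groupProducersBufferCap)
	defer producers.Close()

	for _, source := range sources {
		source := source // per-iteration copy for the closure

		producers.GoProduce(func(ctx context.Context, ch chan<- string) error {
			ch <- source.From // the real code parses the source into host entries

			return nil
		})
	}

	var hosts []string

	producers.GoConsume(func(ctx context.Context, ch <-chan string) error {
		for host := range ch {
			hosts = append(hosts, host)
		}

		return nil
	})

	return hosts, producers.Wait()
}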

View File

@ -2,17 +2,24 @@ package lists
import (
"testing"
"github.com/0xERR0R/blocky/config"
)
func BenchmarkRefresh(b *testing.B) {
file1, _ := createTestListFile(b.TempDir(), 100000)
file2, _ := createTestListFile(b.TempDir(), 150000)
file3, _ := createTestListFile(b.TempDir(), 130000)
lists := map[string][]string{
"gr1": {file1, file2, file3},
lists := map[string][]config.BytesSource{
"gr1": config.NewBytesSources(file1, file2, file3),
}
cache, _ := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(), 5, false, 5)
cfg := config.SourceLoadingConfig{
Concurrency: 5,
RefreshPeriod: config.Duration(-1),
}
downloader := NewDownloader(config.DownloaderConfig{}, nil)
cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader)
b.ReportAllocs()

View File

@ -2,6 +2,7 @@ package lists
import (
"bufio"
"context"
"errors"
"fmt"
"io"
@ -9,8 +10,8 @@ import (
"net/http/httptest"
"os"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
@ -27,13 +28,28 @@ var _ = Describe("ListCache", func() {
tmpDir *TmpFolder
emptyFile, file1, file2, file3 *TmpFile
server1, server2, server3 *httptest.Server
maxErrorsPerFile int
sut *ListCache
sutConfig config.SourceLoadingConfig
listCacheType ListCacheType
lists map[string][]config.BytesSource
downloader FileDownloader
mockDownloader *MockDownloader
)
BeforeEach(func() {
maxErrorsPerFile = 5
tmpDir = NewTmpFolder("ListCache")
Expect(tmpDir.Error).Should(Succeed())
DeferCleanup(tmpDir.Clean)
var err error
listCacheType = ListCacheTypeBlacklist
sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]()
Expect(err).Should(Succeed())
sutConfig.RefreshPeriod = -1
downloader = NewDownloader(config.DownloaderConfig{}, nil)
mockDownloader = nil
server1 = TestServer("blocked1.com\nblocked1a.com\n192.168.178.55")
DeferCleanup(server1.Close)
@ -42,6 +58,13 @@ var _ = Describe("ListCache", func() {
server3 = TestServer("blocked3.com\nblocked1a.com")
DeferCleanup(server3.Close)
tmpDir = NewTmpFolder("ListCache")
Expect(tmpDir.Error).Should(Succeed())
DeferCleanup(tmpDir.Clean)
emptyFile = tmpDir.CreateStringFile("empty", "#empty file")
Expect(emptyFile.Error).Should(Succeed())
emptyFile = tmpDir.CreateStringFile("empty", "#empty file")
Expect(emptyFile.Error).Should(Succeed())
file1 = tmpDir.CreateStringFile("file1", "blocked1.com", "blocked1a.com")
@ -52,61 +75,56 @@ var _ = Describe("ListCache", func() {
Expect(file3.Error).Should(Succeed())
})
JustBeforeEach(func() {
var err error
Expect(lists).ShouldNot(BeNil(), "bad test: forgot to set `lists`")
if mockDownloader != nil {
downloader = mockDownloader
}
sut, err = NewListCache(listCacheType, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
})
Describe("List cache and matching", func() {
When("Query with empty", func() {
It("should not panic", func() {
lists := map[string][]string{
"gr0": {emptyFile.Path},
}
sut, err := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
Expect(err).Should(Succeed())
group := sut.Match("", []string{"gr0"})
Expect(group).Should(BeEmpty())
})
})
When("List is empty", func() {
It("should not match anything", func() {
lists := map[string][]string{
"gr1": {emptyFile.Path},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr0": config.NewBytesSources(emptyFile.Path),
}
sut, err := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
Expect(err).Should(Succeed())
})
When("Query with empty", func() {
It("should not panic", func() {
group := sut.Match("", []string{"gr0"})
Expect(group).Should(BeEmpty())
})
})
It("should not match anything", func() {
group := sut.Match("google.com", []string{"gr1"})
Expect(group).Should(BeEmpty())
})
})
When("List becomes empty on refresh", func() {
It("should delete existing elements from group cache", func() {
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
BeforeEach(func() {
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
res <- "# nothing"
})
lists := map[string][]string{
lists = map[string][]config.BytesSource{
"gr1": {mockDownloader.ListSource()},
}
})
sut, err := NewListCache(
ListCacheTypeBlacklist, lists,
4*time.Hour,
mockDownloader,
defaultProcessingConcurrency,
false,
maxErrorsPerFile,
)
Expect(err).Should(Succeed())
It("should delete existing elements from group cache", func(ctx context.Context) {
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
err = sut.refresh(false)
err := sut.refresh(ctx)
Expect(err).Should(Succeed())
group = sut.Match("blocked1.com", []string{"gr1"})
@ -114,21 +132,19 @@ var _ = Describe("ListCache", func() {
})
})
When("List has invalid lines", func() {
It("should still other domains", func() {
lists := map[string][]string{
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": {
inlineList(
config.TextBytesSource(
"inlinedomain1.com",
"invaliddomain!",
"inlinedomain2.com",
),
},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should still other domains", func() {
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
@ -137,28 +153,20 @@ var _ = Describe("ListCache", func() {
})
})
When("a temporary/transient err occurs on download", func() {
It("should not delete existing elements from group cache", func() {
BeforeEach(func() {
// should produce a transient error on second and third attempt
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com\nblocked2.com\n"
err <- &TransientError{inner: errors.New("boom")}
err <- &TransientError{inner: errors.New("boom")}
})
lists := map[string][]string{
lists = map[string][]config.BytesSource{
"gr1": {mockDownloader.ListSource()},
}
})
sut, err := NewListCache(
ListCacheTypeBlacklist, lists,
4*time.Hour,
mockDownloader,
defaultProcessingConcurrency,
false,
maxErrorsPerFile,
)
Expect(err).Should(Succeed())
It("should not delete existing elements from group cache", func(ctx context.Context) {
By("Lists loaded without timeout", func() {
Eventually(func(g Gomega) {
group := sut.Match("blocked1.com", []string{"gr1"})
@ -166,7 +174,7 @@ var _ = Describe("ListCache", func() {
}, "1s").Should(Succeed())
})
Expect(sut.refresh(false)).Should(HaveOccurred())
Expect(sut.refresh(ctx)).Should(HaveOccurred())
By("List couldn't be loaded due to timeout", func() {
group := sut.Match("blocked1.com", []string{"gr1"})
@ -182,27 +190,25 @@ var _ = Describe("ListCache", func() {
})
})
When("non transient err occurs on download", func() {
It("should keep existing elements from group cache", func() {
BeforeEach(func() {
// should produce a non transient error on second attempt
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
err <- errors.New("boom")
})
lists := map[string][]string{
lists = map[string][]config.BytesSource{
"gr1": {mockDownloader.ListSource()},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader,
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should keep existing elements from group cache", func(ctx context.Context) {
By("Lists loaded without err", func() {
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
Expect(sut.refresh(false)).Should(HaveOccurred())
Expect(sut.refresh(ctx)).Should(HaveOccurred())
By("Lists from first load is kept", func() {
group := sut.Match("blocked1.com", []string{"gr1"})
@ -211,16 +217,14 @@ var _ = Describe("ListCache", func() {
})
})
When("Configuration has 3 external working urls", func() {
It("should download the list and match against", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {server3.URL},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": config.NewBytesSources(server1.URL, server2.URL),
"gr2": config.NewBytesSources(server3.URL),
}
})
sut, _ := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
It("should download the list and match against", func() {
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
@ -232,16 +236,14 @@ var _ = Describe("ListCache", func() {
})
})
When("Configuration has some faulty urls", func() {
It("should download the list and match against", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL, "doesnotexist"},
"gr2": {server3.URL, "someotherfile"},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": config.NewBytesSources(server1.URL, server2.URL, "doesnotexist"),
"gr2": config.NewBytesSources(server3.URL, "someotherfile"),
}
})
sut, _ := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
It("should download the list and match against", func() {
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
@ -253,39 +255,33 @@ var _ = Describe("ListCache", func() {
})
})
When("List will be updated", func() {
It("event should be fired and contain count of elements in downloaded lists", func() {
lists := map[string][]string{
"gr1": {server1.URL},
}
resultCnt := 0
resultCnt := 0
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": config.NewBytesSources(server1.URL),
}
_ = Bus().SubscribeOnce(BlockingCacheGroupChanged, func(listType ListCacheType, group string, cnt int) {
resultCnt = cnt
})
})
sut, err := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
Expect(err).Should(Succeed())
It("event should be fired and contain count of elements in downloaded lists", func() {
group := sut.Match("blocked1.com", []string{})
Expect(group).Should(BeEmpty())
Expect(resultCnt).Should(Equal(3))
})
})
When("multiple groups are passed", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {file1.Path, file2.Path},
"gr2": {"file://" + file3.Path},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": config.NewBytesSources(file1.Path, file2.Path),
"gr2": config.NewBytesSources("file://" + file3.Path),
}
})
sut, err := NewListCache(
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
)
Expect(err).Should(Succeed())
It("should match", func() {
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(3))
Expect(sut.groupedCache.ElementCount("gr2")).Should(Equal(2))
@ -304,31 +300,28 @@ var _ = Describe("ListCache", func() {
file1, lines1 := createTestListFile(GinkgoT().TempDir(), 10000)
file2, lines2 := createTestListFile(GinkgoT().TempDir(), 15000)
file3, lines3 := createTestListFile(GinkgoT().TempDir(), 13000)
lists := map[string][]string{
"gr1": {file1, file2, file3},
lists := map[string][]config.BytesSource{
"gr1": config.NewBytesSources(file1, file2, file3),
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3))
})
})
When("inline list content is defined", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {inlineList(
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": {config.TextBytesSource(
"inlinedomain1.com",
"#some comment",
"inlinedomain2.com",
)},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should match", func() {
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(2))
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
@ -338,65 +331,59 @@ var _ = Describe("ListCache", func() {
})
})
When("Text file can't be parsed", func() {
It("should still match already imported strings", func() {
lists := map[string][]string{
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": {
inlineList(
config.TextBytesSource(
"inlinedomain1.com",
"lineTooLong"+strings.Repeat("x", bufio.MaxScanTokenSize), // too long
),
},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should still match already imported strings", func() {
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("Text file has too many errors", func() {
BeforeEach(func() {
maxErrorsPerFile = 0
sutConfig.MaxErrorsPerSource = 0
sutConfig.Strategy = config.StartStrategyTypeFailOnError
})
It("should fail parsing", func() {
lists := map[string][]string{
lists := map[string][]config.BytesSource{
"gr1": {
inlineList("invaliddomain!"), // too many errors since `maxErrorsPerFile` is 0
config.TextBytesSource("invaliddomain!"), // too many errors since `maxErrorsPerSource` is 0
},
}
_, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(parsers.ErrTooManyErrors))
})
})
When("file has end of line comment", func() {
It("should still parse the domain", func() {
lists := map[string][]string{
"gr1": {inlineList("inlinedomain1.com#a comment")},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": {config.TextBytesSource("inlinedomain1.com#a comment")},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should still parse the domain", func() {
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("inline regex content is defined", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {inlineList("/^apple\\.(de|com)$/")},
BeforeEach(func() {
lists = map[string][]config.BytesSource{
"gr1": {config.TextBytesSource("/^apple\\.(de|com)$/")},
}
})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
Expect(err).Should(Succeed())
It("should match", func() {
group := sut.Match("apple.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
@ -416,13 +403,12 @@ var _ = Describe("ListCache", func() {
})
It("should print list configuration", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {inlineList("inline", "definition")},
lists := map[string][]config.BytesSource{
"gr1": config.NewBytesSources(server1.URL, server2.URL),
"gr2": {config.TextBytesSource("inline", "definition")},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
defaultProcessingConcurrency, false, maxErrorsPerFile)
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
sut.LogConfig(logger)
@ -435,13 +421,16 @@ var _ = Describe("ListCache", func() {
Describe("StartStrategy", func() {
When("async load is enabled", func() {
BeforeEach(func() {
sutConfig.Strategy = config.StartStrategyTypeFast
})
It("should never return an error", func() {
lists := map[string][]string{
"gr1": {"doesnotexist"},
lists := map[string][]config.BytesSource{
"gr1": config.NewBytesSources("doesnotexist"),
}
_, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(),
defaultProcessingConcurrency, true, maxErrorsPerFile)
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
})
})
@ -465,8 +454,11 @@ func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(str)), nil
}
func (m *MockDownloader) ListSource() string {
return "http://mock"
func (m *MockDownloader) ListSource() config.BytesSource {
return config.BytesSource{
Type: config.BytesSourceTypeHttp,
From: "http://mock-downloader",
}
}
func createTestListFile(dir string, totalLines int) (string, int) {
@ -502,12 +494,3 @@ func RandStringBytes(n int) string {
return string(b)
}
func inlineList(lines ...string) string {
res := strings.Join(lines, "\n")
// ensure at least one line ending so it's parsed as an inline block
res += "\n"
return res
}
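Note the removed inlineList helper above: tests now express inline lists through config.TextBytesSource instead. A minimal sketch of the assumed equivalence follows (the helper joining its arguments into a text-typed BytesSource, consumed by the textOpener in lists/sourcereader.go below); inlineSource is a hypothetical stand-in.

import "github.com/0xERR0R/blocky/config"

// inlineSource mirrors what the removed inlineList helper used to produce.
func inlineSource() config.BytesSource {
	src := config.TextBytesSource(
		"inlinedomain1.com",
		"#some comment",
		"inlinedomain2.com",
	)
	// src.Type is assumed to be config.BytesSourceTypeText, with the
	// newline-joined lines held in src.From.
	return src
}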

lists/sourcereader.go (new file, 69 lines)
View File

@ -0,0 +1,69 @@
package lists
import (
"fmt"
"io"
"os"
"strings"
"github.com/0xERR0R/blocky/config"
)
type SourceOpener interface {
fmt.Stringer
Open() (io.ReadCloser, error)
}
func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) {
switch source.Type {
case config.BytesSourceTypeText:
return &textOpener{source: source, locInfo: txtLocInfo}, nil
case config.BytesSourceTypeHttp:
return &httpOpener{source: source, downloader: downloader}, nil
case config.BytesSourceTypeFile:
return &fileOpener{source: source}, nil
}
return nil, fmt.Errorf("cannot open %s", source)
}
type textOpener struct {
source config.BytesSource
locInfo string
}
func (o *textOpener) Open() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(o.source.From)), nil
}
func (o *textOpener) String() string {
return fmt.Sprintf("%s: %s", o.locInfo, o.source)
}
type httpOpener struct {
source config.BytesSource
downloader FileDownloader
}
func (o *httpOpener) Open() (io.ReadCloser, error) {
return o.downloader.DownloadFile(o.source.From)
}
func (o *httpOpener) String() string {
return o.source.String()
}
type fileOpener struct {
source config.BytesSource
}
func (o *fileOpener) Open() (io.ReadCloser, error) {
return os.Open(o.source.From)
}
func (o *fileOpener) String() string {
return o.source.String()
}
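For completeness, a short usage sketch of the SourceOpener abstraction defined above, staying within its API; readSource is a hypothetical helper in package lists.

import (
	"io"

	"github.com/0xERR0R/blocky/config"
)

// readSource resolves any BytesSource (text, http or file) to raw bytes.
func readSource(source config.BytesSource, downloader FileDownloader) ([]byte, error) {
	opener, err := NewSourceOpener("example item", source, downloader)
	if err != nil {
		return nil, err
	}

	r, err := opener.Open()
	if err != nil {
		return nil, err
	}
	defer r.Close() // Open returns an io.ReadCloser

	return io.ReadAll(r)
}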

View File

@ -101,18 +101,14 @@ func NewBlockingResolver(
return nil, err
}
refreshPeriod := cfg.RefreshPeriod.ToDuration()
downloader := createDownloader(cfg, bootstrap)
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists,
refreshPeriod, downloader, cfg.ProcessingConcurrency,
(cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile)
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists,
refreshPeriod, downloader, cfg.ProcessingConcurrency,
(cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile)
downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport())
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader)
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader)
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()
if err != nil && cfg.StartStrategy == config.StartStrategyTypeFailOnError {
if err != nil {
return nil, err
}
@ -156,15 +152,6 @@ func NewBlockingResolver(
return res, nil
}
func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader {
return lists.NewDownloader(
lists.WithTimeout(cfg.DownloadTimeout.ToDuration()),
lists.WithAttempts(cfg.DownloadAttempts),
lists.WithCooldown(cfg.DownloadCooldown.ToDuration()),
lists.WithTransport(bootstrap.NewHTTPTransport()),
)
}
func setupRedisEnabledSubscriber(c *BlockingResolver) {
go func() {
for em := range c.redisClient.EnabledChannel {

View File

@ -97,9 +97,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{
"gr1": {group1File.Path},
"gr2": {group2File.Path},
BlackLists: map[string][]config.BytesSource{
"gr1": config.NewBytesSources(group1File.Path),
"gr2": config.NewBytesSources(group2File.Path),
},
}
})
@ -125,9 +125,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{
"gr1": {group1File.Path},
"gr2": {group2File.Path},
BlackLists: map[string][]config.BytesSource{
"gr1": config.NewBytesSources(group1File.Path),
"gr2": config.NewBytesSources(group2File.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
@ -164,13 +164,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{
"gr1": {"\n/regex/"},
BlackLists: map[string][]config.BytesSource{
"gr1": {config.TextBytesSource("/regex/")},
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
StartStrategy: config.StartStrategyTypeFast,
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFast},
}
})
@ -193,10 +193,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockTTL: config.Duration(6 * time.Hour),
BlackLists: map[string][]string{
"gr1": {group1File.Path},
"gr2": {group2File.Path},
"defaultGroup": {defaultGroupFile.Path},
BlackLists: map[string][]config.BytesSource{
"gr1": config.NewBytesSources(group1File.Path),
"gr2": config.NewBytesSources(group2File.Path),
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
},
ClientGroupsBlock: map[string][]string{
"Client1": {"gr1"},
@ -399,8 +399,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{
"defaultGroup": {defaultGroupFile.Path},
BlackLists: map[string][]config.BytesSource{
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"defaultGroup"},
@ -425,8 +425,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlackLists: map[string][]string{
"defaultGroup": {defaultGroupFile.Path},
BlackLists: map[string][]config.BytesSource{
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"defaultGroup"},
@ -470,8 +470,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockTTL: config.Duration(6 * time.Hour),
BlackLists: map[string][]string{
"defaultGroup": {defaultGroupFile.Path},
BlackLists: map[string][]config.BytesSource{
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"defaultGroup"},
@ -508,8 +508,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("BlockType is custom IP only for ipv4", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlackLists: map[string][]string{
"defaultGroup": {defaultGroupFile.Path},
BlackLists: map[string][]config.BytesSource{
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"defaultGroup"},
@ -601,8 +601,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{"gr1": {group1File.Path}},
WhiteLists: map[string][]string{"gr1": {group1File.Path}},
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
@ -627,9 +627,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "zeroIP",
BlockTTL: config.Duration(60 * time.Second),
WhiteLists: map[string][]string{
"gr1": {group1File.Path},
"gr2": {group2File.Path},
WhiteLists: map[string][]config.BytesSource{
"gr1": config.NewBytesSources(group1File.Path),
"gr2": config.NewBytesSources(group2File.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
@ -728,8 +728,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{"gr1": {group1File.Path}},
WhiteLists: map[string][]string{"gr1": {defaultGroupFile.Path}},
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(defaultGroupFile.Path)},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
@ -755,7 +755,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
BlackLists: map[string][]string{"gr1": {group1File.Path}},
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
},
@ -798,9 +798,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
Describe("Control status via API", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlackLists: map[string][]string{
"defaultGroup": {defaultGroupFile.Path},
"group1": {group1File.Path},
BlackLists: map[string][]config.BytesSource{
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
"group1": config.NewBytesSources(group1File.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"defaultGroup", "group1"},
@ -1118,13 +1118,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
MatchError("unknown blockType 'wrong', please use one of: ZeroIP, NxDomain or specify destination IP address(es)"))
})
})
When("startStrategy is failOnError", func() {
When("strategy is failOnError", func() {
It("should fail if lists can't be downloaded", func() {
_, err := NewBlockingResolver(config.BlockingConfig{
BlackLists: map[string][]string{"gr1": {"wrongPath"}},
WhiteLists: map[string][]string{"whitelist": {"wrongPath"}},
StartStrategy: config.StartStrategyTypeFailOnError,
BlockType: "zeroIp",
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")},
WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")},
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError},
BlockType: "zeroIp",
}, nil, systemResolverBootstrap)
Expect(err).Should(HaveOccurred())
})

View File

@ -4,13 +4,14 @@ import (
"context"
"fmt"
"net"
"os"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/lists"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/ThinkChaos/parcour"
"github.com/ThinkChaos/parcour/jobgroup"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
@ -18,33 +19,37 @@ import (
const (
// reduce initial capacity so we don't waste memory if there are fewer entries than before
memReleaseFactor = 2
producersBuffCap = 1000
)
type HostsFileEntry = parsers.HostsFileEntry
type HostsFileResolver struct {
configurable[*config.HostsFileConfig]
NextResolver
typed
hosts splitHostsFileData
hosts splitHostsFileData
downloader lists.FileDownloader
}
type HostsFileEntry = parsers.HostsFileEntry
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) {
r := HostsFileResolver{
configurable: withConfig(&cfg),
typed: withType("hosts_file"),
downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()),
}
if err := r.parseHostsFile(context.Background()); err != nil {
r.log().Errorf("disabling hosts file resolving due to error: %s", err)
r.cfg.Filepath = "" // don't try parsing the file again
} else {
go r.periodicUpdate()
err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) {
r.log().WithError(err).Errorf("could not load hosts files")
})
if err != nil {
return nil, err
}
return &r
return &r, nil
}
// LogConfig implements `config.Configurable`.
@ -102,7 +107,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
}
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
if r.cfg.Filepath == "" {
if !r.IsEnabled() {
return r.next.Resolve(request)
}
@ -144,27 +149,78 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
return response
}
func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
const maxErrorsPerFile = 5
if r.cfg.Filepath == "" {
func (r *HostsFileResolver) loadSources(ctx context.Context) error {
if !r.IsEnabled() {
return nil
}
f, err := os.Open(r.cfg.Filepath)
if err != nil {
return err
r.log().Debug("loading hosts files")
//nolint:ineffassign,staticcheck // keep `ctx :=` so if we use ctx in the future, we use the correct one
consumersGrp, ctx := jobgroup.WithContext(ctx)
defer consumersGrp.Close()
producersGrp := jobgroup.WithMaxConcurrency(consumersGrp, r.cfg.Loading.Concurrency)
defer producersGrp.Close()
producers := parcour.NewProducersWithBuffer[*HostsFileEntry](producersGrp, consumersGrp, producersBuffCap)
defer producers.Close()
for i, source := range r.cfg.Sources {
i, source := i, source
producers.GoProduce(func(ctx context.Context, hostsChan chan<- *HostsFileEntry) error {
locInfo := fmt.Sprintf("item #%d", i)
opener, err := lists.NewSourceOpener(locInfo, source, r.downloader)
if err != nil {
return err
}
err = r.parseFile(ctx, opener, hostsChan)
if err != nil {
return fmt.Errorf("error parsing %s: %w", opener, err) // err is parsers.ErrTooManyErrors
}
return nil
})
}
defer f.Close()
newHosts := newSplitHostsDataWithSameCapacity(r.hosts)
p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile)
p.OnErr(func(err error) {
r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err)
producers.GoConsume(func(ctx context.Context, ch <-chan *HostsFileEntry) error {
for entry := range ch {
newHosts.add(entry)
}
return nil
})
err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
err := producers.Wait()
if err != nil {
return err
}
r.hosts = newHosts
return nil
}
func (r *HostsFileResolver) parseFile(
ctx context.Context, opener lists.SourceOpener, hostsChan chan<- *HostsFileEntry,
) error {
reader, err := opener.Open()
if err != nil {
return err
}
defer reader.Close()
p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource)
p.OnErr(func(err error) {
r.log().Warnf("error parsing %s: %s, trying to continue", opener, err)
})
return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
if len(entry.Interface) != 0 {
// Ignore entries with a specific interface: we don't restrict what clients/interfaces we serve entries to,
// so this avoids returning entries that can't be accessed by the client.
@ -176,32 +232,10 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
return nil
}
newHosts.add(entry)
hostsChan <- entry
return nil
})
if err != nil {
return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors
}
r.hosts = newHosts
return nil
}
func (r *HostsFileResolver) periodicUpdate() {
if r.cfg.RefreshPeriod.ToDuration() > 0 {
ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration())
defer ticker.Stop()
for {
<-ticker.C
r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file")
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background()))
}
}
}
// stores hosts file data for IP versions separately
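This is where the commit's headline feature lands: the hosts file resolver now accepts multiple sources and shares the list-loading machinery. Below is a hedged sketch of constructing it with several sources at once, using only types and fields visible in this diff (Sources, HostsTTL, Loading, and the NewBytesSources/TextBytesSource helpers); the paths, URL, and helper name are placeholders.

import (
	"time"

	"github.com/0xERR0R/blocky/config"
)

// newMultiSourceHostsResolver is a hypothetical helper in package resolver.
func newMultiSourceHostsResolver(bootstrap *Bootstrap) (*HostsFileResolver, error) {
	cfg := config.HostsFileConfig{
		Sources: append(
			config.NewBytesSources("/etc/hosts", "https://example.com/extra-hosts.txt"),
			config.TextBytesSource("192.0.2.1 printer.lan"), // inline entry
		),
		HostsTTL: config.Duration(time.Hour),
		Loading: config.SourceLoadingConfig{
			Concurrency:        4,
			RefreshPeriod:      config.Duration(30 * time.Minute),
			MaxErrorsPerSource: 5,
		},
	}

	return NewHostsFileResolver(cfg, bootstrap)
}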

View File

@ -40,15 +40,22 @@ var _ = Describe("HostsFileResolver", func() {
Expect(tmpFile.Error).Should(Succeed())
sutConfig = config.HostsFileConfig{
Filepath: tmpFile.Path,
Sources: config.NewBytesSources(tmpFile.Path),
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
RefreshPeriod: config.Duration(30 * time.Minute),
FilterLoopback: true,
Loading: config.SourceLoadingConfig{
RefreshPeriod: -1,
MaxErrorsPerSource: 5,
},
}
})
JustBeforeEach(func() {
sut = NewHostsFileResolver(sutConfig)
var err error
sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap)
Expect(err).Should(Succeed())
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -74,12 +81,12 @@ var _ = Describe("HostsFileResolver", func() {
When("Hosts file cannot be located", func() {
BeforeEach(func() {
sutConfig = config.HostsFileConfig{
Filepath: "/this/file/does/not/exist",
Sources: config.NewBytesSources("/this/file/does/not/exist"),
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
}
})
It("should not parse any hosts", func() {
Expect(sut.cfg.Filepath).Should(BeEmpty())
Expect(sut.cfg.Sources).ShouldNot(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
@ -99,13 +106,15 @@ var _ = Describe("HostsFileResolver", func() {
When("Hosts file is not set", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{})
sutConfig.Deprecated.Filepath = new(config.BytesSource)
sutConfig.Sources = nil
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
})
It("should not return an error", func() {
err := sut.parseHostsFile(context.Background())
err := sut.loadSources(context.Background())
Expect(err).Should(Succeed())
})
It("should go to next resolver on query", func() {
@ -156,12 +165,12 @@ var _ = Describe("HostsFileResolver", func() {
)
Expect(tmpFile.Error).Should(Succeed())
sutConfig.Filepath = tmpFile.Path
sutConfig.Sources = config.NewBytesSources(tmpFile.Path)
})
It("should not be used", func() {
Expect(sut).ShouldNot(BeNil())
Expect(sut.cfg.Filepath).Should(BeEmpty())
Expect(sut.cfg.Sources).ShouldNot(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
Expect(sut.hosts.v4.aliases).Should(BeEmpty())

View File

@ -400,15 +400,17 @@ func createQueryResolver(
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream)
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
mErr := multierror.Append(
err = multierror.Append(
multierror.Prefix(blErr, "blocking resolver: "),
multierror.Prefix(pErr, "parallel resolver: "),
multierror.Prefix(cnErr, "client names resolver: "),
multierror.Prefix(cuErr, "conditional upstream resolver: "),
)
if mErr.ErrorOrNil() != nil {
return nil, mErr
multierror.Prefix(hfErr, "hosts file resolver: "),
).ErrorOrNil()
if err != nil {
return nil, err
}
r = resolver.Chain(
@ -419,7 +421,7 @@ func createQueryResolver(
resolver.NewQueryLoggingResolver(cfg.QueryLog),
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
resolver.NewHostsFileResolver(cfg.HostsFile),
hostsFile,
blocking,
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),

View File

@ -122,17 +122,17 @@ var _ = BeforeSuite(func() {
},
},
Blocking: config.BlockingConfig{
BlackLists: map[string][]string{
"ads": {
BlackLists: map[string][]config.BytesSource{
"ads": config.NewBytesSources(
doubleclickFile.Path,
bildFile.Path,
heiseFile.Path,
},
"youtube": {youtubeFile.Path},
),
"youtube": config.NewBytesSources(youtubeFile.Path),
},
WhiteLists: map[string][]string{
"ads": {heiseFile.Path},
"whitelist": {heiseFile.Path},
WhiteLists: map[string][]config.BytesSource{
"ads": config.NewBytesSources(heiseFile.Path),
"whitelist": config.NewBytesSources(heiseFile.Path),
},
ClientGroupsBlock: map[string][]string{
"default": {"ads"},